Merge pull request #178 from SuanmoSuanyangTechnology/fix/workflow-cycle
fix(workflow): fix loop node scheduling and I/O issues
This commit is contained in:
@@ -9,7 +9,6 @@ import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.config import get_stream_writer
|
||||
from typing_extensions import TypedDict, Annotated
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ class IterationRuntime:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
start_id: str,
|
||||
graph: CompiledStateGraph,
|
||||
node_id: str,
|
||||
config: dict[str, Any],
|
||||
@@ -38,6 +39,7 @@ class IterationRuntime:
|
||||
config: Dictionary containing iteration node configuration.
|
||||
state: Current workflow state at the point of iteration.
|
||||
"""
|
||||
self.start_id = start_id
|
||||
self.graph = graph
|
||||
self.state = state
|
||||
self.node_id = node_id
|
||||
@@ -70,6 +72,7 @@ class IterationRuntime:
|
||||
"index": idx,
|
||||
}
|
||||
loopstate["looping"] = True
|
||||
loopstate["activate"][self.start_id] = True
|
||||
return loopstate
|
||||
|
||||
async def run_task(self, item, idx):
|
||||
|
||||
@@ -26,6 +26,7 @@ class LoopRuntime:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
start_id: str,
|
||||
graph: CompiledStateGraph,
|
||||
node_id: str,
|
||||
config: dict[str, Any],
|
||||
@@ -40,6 +41,7 @@ class LoopRuntime:
|
||||
config: Raw configuration dictionary for the loop node.
|
||||
state: The current workflow state before entering the loop.
|
||||
"""
|
||||
self.start_id = start_id
|
||||
self.graph = graph
|
||||
self.state = state
|
||||
self.node_id = node_id
|
||||
@@ -87,6 +89,7 @@ class LoopRuntime:
|
||||
**self.state
|
||||
)
|
||||
loopstate["looping"] = True
|
||||
loopstate["activate"][self.start_id] = True
|
||||
return loopstate
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -34,7 +34,6 @@ class CycleGraphNode(BaseNode):
|
||||
self.cycle_nodes = list() # Nodes belonging to this cycle
|
||||
self.cycle_edges = list() # Edges connecting nodes within the cycle
|
||||
self.start_node_id = None # ID of the start node within the cycle
|
||||
self.end_node_ids = [] # IDs of end nodes within the cycle
|
||||
|
||||
self.graph: StateGraph | CompiledStateGraph | None = None
|
||||
self.build_graph()
|
||||
@@ -105,13 +104,15 @@ class CycleGraphNode(BaseNode):
|
||||
"""
|
||||
from app.core.workflow.graph_builder import GraphBuilder
|
||||
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
|
||||
self.graph = GraphBuilder(
|
||||
builder = GraphBuilder(
|
||||
{
|
||||
"nodes": self.cycle_nodes,
|
||||
"edges": self.cycle_edges,
|
||||
},
|
||||
subgraph=True
|
||||
).build()
|
||||
)
|
||||
self.start_node_id = builder.start_node_id
|
||||
self.graph = builder.build()
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
"""
|
||||
@@ -132,6 +133,7 @@ class CycleGraphNode(BaseNode):
|
||||
"""
|
||||
if self.node_type == NodeType.LOOP:
|
||||
return await LoopRuntime(
|
||||
start_id=self.start_node_id,
|
||||
graph=self.graph,
|
||||
node_id=self.node_id,
|
||||
config=self.config,
|
||||
@@ -139,6 +141,7 @@ class CycleGraphNode(BaseNode):
|
||||
).run()
|
||||
if self.node_type == NodeType.ITERATION:
|
||||
return await IterationRuntime(
|
||||
start_id=self.start_node_id,
|
||||
graph=self.graph,
|
||||
node_id=self.node_id,
|
||||
config=self.config,
|
||||
|
||||
Reference in New Issue
Block a user