Merge pull request #638 from SuanmoSuanyangTechnology/feature/multi-end-stream

fix(workflow): unify streaming and non-stream execution outputs
This commit is contained in:
Mark
2026-03-20 15:19:55 +08:00
committed by GitHub
10 changed files with 353 additions and 223 deletions

View File

@@ -20,9 +20,21 @@ from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes import NodeFactory from app.core.workflow.nodes import NodeFactory
from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES
from app.core.workflow.utils.expression_evaluator import evaluate_condition from app.core.workflow.utils.expression_evaluator import evaluate_condition
from app.core.workflow.validator import WorkflowValidator
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Regex to split output into:
# - variable placeholders: {{ ... }}
# - normal literal text
#
# Example:
# "Hello {{user.name}}!" ->
# ["Hello ", "{{user.name}}", "!"]
_OUTPUT_PATTERN = re.compile(r'\{\{.*?}}|[^{}]+')
# Strict variable format: {{ node_id.field_name }}
_VARIABLE_PATTERN = re.compile(r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*}}')
class GraphBuilder: class GraphBuilder:
def __init__( def __init__(
@@ -37,13 +49,13 @@ class GraphBuilder:
self.stream = stream self.stream = stream
self.subgraph = subgraph self.subgraph = subgraph
self.start_node_id = None self.start_node_id: str | None = None
self.end_node_ids = []
self.node_map = {node["id"]: node for node in self.nodes} self.node_map = {node["id"]: node for node in self.nodes}
self.end_node_map: dict[str, StreamOutputConfig] = {} self.end_node_map: dict[str, StreamOutputConfig] = {}
self._find_upstream_branch_node = lru_cache( self._find_upstream_activation_dep = lru_cache(
maxsize=len(self.nodes) * 2 maxsize=len(self.nodes) * 2
)(self._find_upstream_branch_node) )(self._find_upstream_activation_dep)
if variable_pool: if variable_pool:
self.variable_pool = variable_pool self.variable_pool = variable_pool
else: else:
@@ -51,10 +63,19 @@ class GraphBuilder:
self.graph = StateGraph(WorkflowState) self.graph = StateGraph(WorkflowState)
self.add_nodes() self.add_nodes()
self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges)
self.end_nodes = [
node
for node in self.nodes
if node.get("type") == "end" and node.get("id") in self.reachable_nodes
]
self.add_edges() self.add_edges()
self._analyze_end_node_output()
# EDGES MUST BE ADDED AFTER NODES ARE ADDED. # EDGES MUST BE ADDED AFTER NODES ARE ADDED.
self._reverse_adj: dict[str, list[dict]] = defaultdict(list)
self._build_reverse_adj()
self._analyze_end_node_output()
@property @property
def nodes(self) -> list[dict[str, Any]]: def nodes(self) -> list[dict[str, Any]]:
return self.workflow_config.get("nodes", []) return self.workflow_config.get("nodes", [])
@@ -87,60 +108,50 @@ class GraphBuilder:
result[node[0]].append(node[1]) result[node[0]].append(node[1])
return result return result
def _find_upstream_branch_node(self, target_node: str) -> tuple[bool, tuple[tuple[str, str]]]: def _build_reverse_adj(self):
""" for edge in self.edges:
Recursively find all upstream branch (control) nodes that influence the execution if edge["source"] not in self.reachable_nodes:
of the given target node. continue
self._reverse_adj[edge.get("target")].append({
"id": edge["source"], "branch": edge.get("label")
})
This method walks upstream along the workflow graph starting from `target_node`. def _find_upstream_activation_dep(
It distinguishes between: self,
- branch nodes (node types listed in `BRANCH_NODES`) target_node: str
- non-branch nodes (ordinary processing nodes) ) -> tuple[tuple[tuple[str, str]], tuple[str]]:
"""Find upstream dependencies that affect the activation of a target node.
Traversal rules: Walks upstream along the workflow graph from the target node, collecting
1. For each immediate upstream node: two types of dependencies:
- If it is a branch node, it is recorded as an affecting control node. - Branch control nodes: upstream branch nodes (e.g. if-else) whose
- If it is a non-branch node, the traversal continues recursively upstream. routing outcome determines whether the target node executes.
2. If ANY upstream path reaches a START / CYCLE_START node without encountering - Output nodes: upstream END nodes that must complete their output
a branch node, the traversal is considered invalid: before the target node can activate.
- `has_branch` will be False
- no branch nodes are returned.
3. Only when ALL upstream non-branch paths eventually lead to at least one
branch node will `has_branch` be True.
Special case: The traversal terminates early and returns empty tuples if any upstream
- If `target_node` has no upstream nodes AND its type is START or CYCLE_START, path reaches START/CYCLE_START without encountering a branch or output
it is considered directly reachable from the workflow entry, and therefore node, indicating the target node is directly reachable and should be
has no controlling branch nodes. activated immediately.
Args: Args:
target_node (str): target_node: The ID of the node whose upstream activation
The identifier of the node whose upstream control branches dependencies are to be resolved.
are to be resolved.
Returns: Returns:
tuple[bool, tuple[tuple[str, str]]]: A tuple of two elements:
- has_branch (bool): - A deduplicated tuple of (branch_node_id, branch_label) pairs
True if every upstream path from `target_node` encounters representing upstream branch control dependencies. Empty if
at least one branch node. any clean path to START exists.
False if any path reaches a start node without a branch. - A deduplicated tuple of upstream output node IDs that must
- branch_nodes (tuple[tuple[str, str]]): complete before this node activates.
A deduplicated tuple of `(branch_node_id, branch_label)` pairs
representing all branch nodes that can influence `target_node`.
Returns an empty tuple if `has_branch` is False.
""" """
source_nodes = [ source_nodes = self._reverse_adj[target_node]
{
"id": edge.get("source"),
"branch": edge.get("label")
}
for edge in self.edges
if edge.get("target") == target_node
]
if not source_nodes and self.get_node_type(target_node) in [NodeType.START, NodeType.CYCLE_START]: if not source_nodes and self.get_node_type(target_node) in [NodeType.START, NodeType.CYCLE_START]:
return False, tuple() return tuple(), tuple()
branch_nodes = [] branch_nodes = []
output_nodes = []
non_branch_nodes = [] non_branch_nodes = []
for node_info in source_nodes: for node_info in source_nodes:
@@ -149,19 +160,23 @@ class GraphBuilder:
(node_info["id"], node_info["branch"]) (node_info["id"], node_info["branch"])
) )
else: else:
if self.get_node_type(node_info["id"]) == NodeType.END:
output_nodes.append(node_info["id"])
non_branch_nodes.append(node_info["id"]) non_branch_nodes.append(node_info["id"])
has_branch = True has_branch = True
for node_id in non_branch_nodes: for node_id in non_branch_nodes:
node_has_branch, nodes = self._find_upstream_branch_node(node_id) upstream_control_nodes, upstream_output_nodes = self._find_upstream_activation_dep(node_id)
has_branch = has_branch and node_has_branch if not upstream_control_nodes:
if not has_branch: if not upstream_output_nodes and node_id not in output_nodes:
break return tuple(), tuple()
branch_nodes.extend(nodes) branch_nodes = []
if not has_branch: has_branch = False
branch_nodes = [] if has_branch:
branch_nodes.extend(upstream_control_nodes)
output_nodes.extend(upstream_output_nodes)
return has_branch, tuple(set(branch_nodes)) return tuple(set(branch_nodes)), tuple(set(output_nodes))
def _analyze_end_node_output(self): def _analyze_end_node_output(self):
""" """
@@ -182,11 +197,10 @@ class GraphBuilder:
""" """
# Collect all End nodes in the workflow # Collect all End nodes in the workflow
end_nodes = [node for node in self.nodes if node.get("type") == "end"] logger.info(f"[Prefix Analysis] Found {len(self.end_nodes)} End nodes")
logger.info(f"[Prefix Analysis] Found {len(end_nodes)} End nodes")
# Iterate through each End node to analyze its output # Iterate through each End node to analyze its output
for end_node in end_nodes: for end_node in self.end_nodes:
end_node_id = end_node.get("id") end_node_id = end_node.get("id")
config = end_node.get("config", {}) config = end_node.get("config", {})
output = config.get("output") output = config.get("output")
@@ -195,42 +209,33 @@ class GraphBuilder:
if not output: if not output:
continue continue
# Regex to split output into:
# - variable placeholders: {{ ... }}
# - normal literal text
#
# Example:
# "Hello {{user.name}}!" ->
# ["Hello ", "{{user.name}}", "!"]
pattern = r'\{\{.*?\}\}|[^{}]+'
# Strict variable format: {{ node_id.field_name }}
variable_pattern_string = r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*\}\}'
variable_pattern = re.compile(variable_pattern_string)
# Split output into ordered segments # Split output into ordered segments
output_template = list(re.findall(pattern, output)) output_template = list(_OUTPUT_PATTERN.findall(output))
# Determine whether each segment is literal text # Determine whether each segment is literal text
# True -> literal (can be directly output) # True -> literal (can be directly output)
# False -> variable placeholder (needs runtime value) # False -> variable placeholder (needs runtime value)
output_flag = [ output_flag = [
not bool(variable_pattern.match(item)) not bool(_VARIABLE_PATTERN.match(item))
for item in output_template for item in output_template
] ]
# Stream mode: output activation depends on upstream branch nodes # Stream mode: output activation depends on upstream branch nodes
if self.stream: if self.stream:
# Find upstream branch nodes that can control this End node # Find upstream branch nodes that can control this End node
has_branch, control_nodes = self._find_upstream_branch_node(end_node_id) upstream_control_nodes, upstream_output_nodes = self._find_upstream_activation_dep(end_node_id)
activate = not bool(upstream_control_nodes) and not bool(upstream_output_nodes)
# Build StreamOutputConfig for this End node # Build StreamOutputConfig for this End node
self.end_node_map[end_node_id] = StreamOutputConfig( self.end_node_map[end_node_id] = StreamOutputConfig(
id=end_node_id,
# If there is no upstream branch, output is active immediately # If there is no upstream branch, output is active immediately
activate=not has_branch, activate=activate,
# Branch nodes that control activation of this End node # Branch nodes that control activation of this End node
control_nodes=self._merge_control_nodes(control_nodes), control_nodes=self._merge_control_nodes(upstream_control_nodes),
upstream_output_nodes=list(upstream_output_nodes),
control_resolved=not bool(upstream_control_nodes),
output_resolved=not bool(upstream_output_nodes),
# Convert output segments into OutputContent objects # Convert output segments into OutputContent objects
outputs=list( outputs=list(
@@ -249,14 +254,16 @@ class GraphBuilder:
cursor=0 cursor=0
) )
logger.info(f"[Stream Analysis] end_id: {end_node_id}, " logger.info(f"[Stream Analysis] end_id: {end_node_id}, "
f"activate: {not has_branch}, " f"activate: {activate}, "
f"control_nodes: {control_nodes}," f"control_nodes: {upstream_control_nodes},"
f"ref_outputs: {upstream_output_nodes},"
f"output: {output_template}," f"output: {output_template},"
f"output_activate: {output_flag}") f"output_activate: {output_flag}")
# Non-stream mode: all outputs are activated by default # Non-stream mode: all outputs are activated by default
else: else:
self.end_node_map[end_node_id] = StreamOutputConfig( self.end_node_map[end_node_id] = StreamOutputConfig(
id=end_node_id,
activate=True, activate=True,
control_nodes={}, control_nodes={},
outputs=list( outputs=list(
@@ -269,7 +276,10 @@ class GraphBuilder:
for output_string, activate in zip(output_template, output_flag) for output_string, activate in zip(output_template, output_flag)
] ]
), ),
cursor=0 cursor=0,
upstream_output_nodes=[],
control_resolved=True,
output_resolved=True,
) )
def add_nodes(self): def add_nodes(self):
@@ -304,8 +314,6 @@ class GraphBuilder:
# Record start and end node IDs # Record start and end node IDs
if node_type in [NodeType.START, NodeType.CYCLE_START]: if node_type in [NodeType.START, NodeType.CYCLE_START]:
self.start_node_id = node_id self.start_node_id = node_id
elif node_type == NodeType.END:
self.end_node_ids.append(node_id)
# Create node instance (start and end nodes are also created) # Create node instance (start and end nodes are also created)
# NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph # NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph
@@ -494,9 +502,11 @@ class GraphBuilder:
logger.debug(f"Added waiting edge: {sources} -> {target}") logger.debug(f"Added waiting edge: {sources} -> {target}")
# Connect End nodes to the global END node # Connect End nodes to the global END node
for end_node_id in self.end_node_ids: for end_node in self.end_nodes:
self.graph.add_edge(end_node_id, END) end_node_id = end_node.get("id")
logger.debug(f"Added edge: {end_node_id} -> END") if end_node_id:
self.graph.add_edge(end_node_id, END)
logger.debug(f"Added edge: {end_node_id} -> END")
return return
def build(self) -> CompiledStateGraph: def build(self) -> CompiledStateGraph:

View File

@@ -12,6 +12,7 @@ class WorkflowResultBuilder:
variable_pool: VariablePool, variable_pool: VariablePool,
elapsed_time: float, elapsed_time: float,
final_output: str, final_output: str,
success: bool
): ):
"""Construct the final standardized output of the workflow execution. """Construct the final standardized output of the workflow execution.
@@ -29,6 +30,7 @@ class WorkflowResultBuilder:
elapsed_time (float): Total execution time in seconds. elapsed_time (float): Total execution time in seconds.
final_output (Any): The aggregated or final output content of the workflow final_output (Any): The aggregated or final output content of the workflow
(e.g., combined messages from all End nodes). (e.g., combined messages from all End nodes).
success (bool): Whether the execution was successful.
Returns: Returns:
dict: A dictionary containing the final workflow execution result with keys: dict: A dictionary containing the final workflow execution result with keys:
@@ -49,7 +51,7 @@ class WorkflowResultBuilder:
conversation_id = variable_pool.get_value("sys.conversation_id") conversation_id = variable_pool.get_value("sys.conversation_id")
return { return {
"status": "completed", "status": "completed" if success else "failed",
"output": final_output, "output": final_output,
"variables": { "variables": {
"conv": variable_pool.get_all_conversation_vars(), "conv": variable_pool.get_all_conversation_vars(),

View File

@@ -3,6 +3,7 @@
# @Email: 1533512157@qq.com # @Email: 1533512157@qq.com
# @Time : 2026/2/9 15:11 # @Time : 2026/2/9 15:11
import re import re
from queue import Queue
from typing import AsyncGenerator from typing import AsyncGenerator
from pydantic import BaseModel, Field, PrivateAttr from pydantic import BaseModel, Field, PrivateAttr
@@ -37,8 +38,8 @@ class OutputContent(BaseModel):
activate: bool = Field( activate: bool = Field(
..., ...,
description=( description=(
"Whether this output segment is currently active.\n" "Whether this output segment is currently active."
"- True: allowed to be emitted/output\n" "- True: allowed to be emitted/output"
"- False: blocked until activated by branch control" "- False: blocked until activated by branch control"
) )
) )
@@ -46,8 +47,8 @@ class OutputContent(BaseModel):
is_variable: bool = Field( is_variable: bool = Field(
..., ...,
description=( description=(
"Whether this segment represents a variable placeholder.\n" "Whether this segment represents a variable placeholder."
"True -> variable (e.g. {{ node.field }})\n" "True -> variable (e.g. {{ node.field }})"
"False -> literal text" "False -> literal text"
) )
) )
@@ -86,12 +87,16 @@ class StreamOutputConfig(BaseModel):
- which upstream branch/control nodes gate the activation - which upstream branch/control nodes gate the activation
- how each parsed output segment is streamed and activated - how each parsed output segment is streamed and activated
""" """
id: str = Field(
...,
description="ID of the End node this configuration belongs to."
)
activate: bool = Field( activate: bool = Field(
..., ...,
description=( description=(
"Global activation flag for the End node output.\n" "Global activation flag for the End node output."
"When False, output segments should not be emitted even if available.\n" "When False, output segments should not be emitted even if available."
"This flag typically becomes True once required control branch conditions " "This flag typically becomes True once required control branch conditions "
"are satisfied." "are satisfied."
) )
@@ -100,17 +105,46 @@ class StreamOutputConfig(BaseModel):
control_nodes: dict[str, list[str]] = Field( control_nodes: dict[str, list[str]] = Field(
..., ...,
description=( description=(
"Control branch conditions for this End node output.\n" "Control branch conditions for this End node output."
"Mapping of `branch_node_id -> expected_branch_label`.\n" "Mapping of `branch_node_id -> expected_branch_label`."
"The End node output becomes globally active when a controlling branch node " "The End node output becomes globally active when a controlling branch node "
"reports a matching completion status." "reports a matching completion status."
) )
) )
upstream_output_nodes: list[str] = Field(
...,
description=(
"Upstream output node dependencies (data flow)."
"Represents END/output nodes that this output depends on."
"These nodes provide data sources required before this output can be activated "
"or streamed."
"Used to ensure correct ordering and dependency resolution in streaming mode."
)
)
control_resolved: bool = Field(
...,
description=(
"Whether all upstream branch control dependencies have been satisfied."
"True if no upstream branch nodes exist or the required branch "
"conditions have been met."
)
)
output_resolved: bool = Field(
...,
description=(
"Whether all upstream output node dependencies have been completed."
"True if no upstream output nodes exist or all upstream output "
"nodes have finished their output."
)
)
outputs: list[OutputContent] = Field( outputs: list[OutputContent] = Field(
..., ...,
description=( description=(
"Ordered list of output segments parsed from the output template.\n" "Ordered list of output segments parsed from the output template."
"Each segment represents either a literal text block or a variable placeholder " "Each segment represents either a literal text block or a variable placeholder "
"that may be activated independently." "that may be activated independently."
) )
@@ -119,49 +153,97 @@ class StreamOutputConfig(BaseModel):
cursor: int = Field( cursor: int = Field(
..., ...,
description=( description=(
"Streaming cursor index.\n" "Streaming cursor index."
"Indicates the next output segment index to be emitted.\n" "Indicates the next output segment index to be emitted."
"Segments with index < cursor are considered already streamed." "Segments with index < cursor are considered already streamed."
) )
) )
force: bool = Field(
default=False,
description=(
"Force flag for output emission."
"When True, all output segments are emitted regardless of activation state."
"Triggered when this output node has finished execution."
)
)
def update_activate(self, scope: str, status=None): def update_activate(self, scope: str, status=None):
""" """
Update streaming activation state based on an upstream node or special variable. Update streaming activation state based on upstream events.
Args: Args:
scope (str): scope (str):
Identifier of the completed upstream entity. Identifier of the completed upstream entity.
- If a control branch node, it should match a key in `control_nodes`. - If a control branch node, it should match a key in `control_nodes`.
- If a variable placeholder (e.g., "sys.xxx"), it may appear in output segments. - If an upstream output node, it should match an entry in `upstream_output_nodes`.
- If a variable placeholder (e.g., "sys.xxx" or "node_id.field"),
it may appear in output segments.
status (optional): status (optional):
Completion status of the control branch node. Completion status of the control branch node.
Required when `scope` refers to a control node. Required when `scope` refers to a control node.
Behavior: Behavior:
1. Control branch nodes: 1. Force activation:
- If `scope` matches a key in `control_nodes` and `status` matches the expected - If `self.force` is True, the method returns immediately.
branch label, the End node output becomes globally active (`activate = True`). - If `scope == self.id`, the node marks itself as completed:
- `activate = True`
- `force = True`
This is typically used for final flushing when the node finishes execution.
2. Variable output segments: 2. Control dependency resolution:
- For each segment that is a variable (`is_variable=True`): - If `scope` matches a key in `control_nodes`:
- If the segment literal references `scope`, mark the segment as active. - `status` must be provided.
- This applies both to regular node variables (e.g., "node_id.field") - If `status` matches expected branch labels, mark control as resolved
and special system variables (e.g., "sys.xxx"). (`control_resolved = True`).
3. Upstream output dependency resolution:
- If `scope` is in `upstream_output_nodes`,
mark data dependency as resolved (`output_resolved = True`).
4. Global activation condition:
- The node becomes active when BOTH conditions are satisfied:
- control_resolved == True
- output_resolved == True
- Once activated, `activate` remains True.
5. Variable segment activation:
- For each output segment that is a variable (`is_variable=True`):
- If the segment depends on the given `scope`,
mark the segment as active.
- This applies to both node variables (e.g., "node_id.field")
and system variables (e.g., "sys.xxx").
Notes: Notes:
- This method does not emit output or advance the streaming cursor. - This method does NOT emit output or advance the streaming cursor.
- It only updates activation flags based on upstream events or special variables. - It only updates activation and dependency resolution states.
- Activation is driven by both control flow (branch nodes) and
data flow (upstream output nodes).
""" """
if self.force:
return
# Case 1: resolve control branch dependency if scope == self.id:
self.activate = True
self.force = True
return
# resolve control branch dependency
if scope in self.control_nodes: if scope in self.control_nodes:
if status is None: if status is None:
raise RuntimeError("[Stream Output] Control node activation status not provided") raise RuntimeError("[Stream Output] Control node activation status not provided")
if status in self.control_nodes[scope]: if status in self.control_nodes[scope]:
self.activate = True self.control_resolved = True
# Case 2: activate variable segments related to this node if scope in self.upstream_output_nodes:
self.upstream_output_nodes.remove(scope)
if not self.upstream_output_nodes:
self.output_resolved = True
self.activate = self.activate or (self.control_resolved and self.output_resolved)
# activate variable segments related to this node
for i in range(len(self.outputs)): for i in range(len(self.outputs)):
if ( if (
self.outputs[i].is_variable self.outputs[i].is_variable
@@ -174,12 +256,17 @@ class StreamOutputCoordinator:
def __init__(self): def __init__(self):
self.end_outputs: dict[str, StreamOutputConfig] = {} self.end_outputs: dict[str, StreamOutputConfig] = {}
self.activate_end: str | None = None self.activate_end: str | None = None
self.output_queue: Queue = Queue()
self.processed_outputs = []
def initialize_end_outputs( def initialize_end_outputs(
self, self,
end_node_map: dict[str, StreamOutputConfig] end_node_map: dict[str, StreamOutputConfig]
): ):
self.end_outputs = end_node_map self.end_outputs = end_node_map
self.processed_outputs = []
self.activate_end = None
self.output_queue = Queue()
@property @property
def current_activate_end_info(self): def current_activate_end_info(self):
@@ -211,8 +298,11 @@ class StreamOutputCoordinator:
""" """
for node in self.end_outputs.keys(): for node in self.end_outputs.keys():
self.end_outputs[node].update_activate(scope, status) self.end_outputs[node].update_activate(scope, status)
if self.end_outputs[node].activate and self.activate_end is None: if self.end_outputs[node].activate and node not in self.processed_outputs:
self.activate_end = node self.output_queue.put(node)
self.processed_outputs.append(node)
if self.activate_end is None and not self.output_queue.empty():
self.activate_end = self.output_queue.get_nowait()
async def emit_activate_chunk( async def emit_activate_chunk(
self, self,
@@ -256,7 +346,7 @@ class StreamOutputCoordinator:
final_chunk = '' final_chunk = ''
current_segment = end_info.outputs[end_info.cursor] current_segment = end_info.outputs[end_info.cursor]
if not current_segment.activate and not force: if not current_segment.activate and not force and not end_info.force:
# Stop processing until this segment becomes active # Stop processing until this segment becomes active
break break
@@ -273,7 +363,7 @@ class StreamOutputCoordinator:
logger.warning(f"[STREAM] Failed to evaluate segment: {current_segment.literal}, error: {e}") logger.warning(f"[STREAM] Failed to evaluate segment: {current_segment.literal}, error: {e}")
if final_chunk: if final_chunk:
logger.info(f"[STREAM] StreamOutput Node:{self.activate_end}, chunk:{final_chunk}") logger.info(f"[STREAM] StreamOutput Node:{self.activate_end}, chunk_length:{len(final_chunk)}")
yield { yield {
"event": "message", "event": "message",
"data": { "data": {
@@ -285,8 +375,7 @@ class StreamOutputCoordinator:
end_info.cursor += 1 end_info.cursor += 1
if end_info.cursor >= len(end_info.outputs): if end_info.cursor >= len(end_info.outputs):
self.end_outputs.pop(self.activate_end) self.pop_current_activate_end()
self.activate_end = None
async def flush_remaining_chunk( async def flush_remaining_chunk(
self, self,
@@ -325,6 +414,8 @@ class StreamOutputCoordinator:
async for msg_event in self.emit_activate_chunk(variable_pool, force=True): async for msg_event in self.emit_activate_chunk(variable_pool, force=True):
yield msg_event yield msg_event
if not self.output_queue.empty():
self.activate_end = self.output_queue.get_nowait()
# Move to next active End node if current one is done # Move to next active End node if current one is done
if not self.activate_end and self.end_outputs: if not self.activate_end and self.end_outputs:
self.activate_end = list(self.end_outputs.keys())[0] self.activate_end = list(self.end_outputs.keys())[0]

View File

@@ -128,89 +128,100 @@ class WorkflowExecutor:
- token_usage: aggregated token usage if available - token_usage: aggregated token usage if available
- error: error message if any - error: error message if any
""" """
logger.info(f"Starting workflow execution: execution_id={self.execution_context.execution_id}") start = datetime.datetime.now()
async for event in self.execute_stream(input_data):
start_time = datetime.datetime.now() if event.get("event") == "workflow_end":
return event.get("data")
# Execute the workflow return self.result_builder.build_final_output(
try: {"error": "Workflow execution did not end as expected"},
# Build the workflow graph self.variable_pool,
graph = self.build_graph() (datetime.datetime.now() - start).total_seconds(),
"",
# Initialize the variable pool with input data success=False
await self.variable_initializer.initialize( )
variable_pool=self.variable_pool, # logger.info(f"Starting workflow execution: execution_id={self.execution_context.execution_id}")
input_data=input_data, #
execution_context=self.execution_context # start_time = datetime.datetime.now()
) #
initial_state = self.state_manager.create_initial_state( # # Execute the workflow
workflow_config=self.workflow_config, # try:
input_data=input_data, # # Build the workflow graph
execution_context=self.execution_context, # graph = self.build_graph()
start_node_id=self.start_node_id #
) # # Initialize the variable pool with input data
# await self.variable_initializer.initialize(
result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config) # variable_pool=self.variable_pool,
# input_data=input_data,
# Aggregate output from all End nodes # execution_context=self.execution_context
full_content = '' # )
for end_id in self.stream_coordinator.end_outputs.keys(): # initial_state = self.state_manager.create_initial_state(
full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False) # workflow_config=self.workflow_config,
# input_data=input_data,
# Append messages for user and assistant # execution_context=self.execution_context,
if input_data.get("files"): # start_node_id=self.start_node_id
result["messages"].extend( # )
[ #
{ # result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config)
"role": "user", #
"content": input_data.get("message", '') # # Aggregate output from all End nodes
}, # full_content = ''
{ # for end_id in self.stream_coordinator.end_outputs.keys():
"role": "user", # full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False)
"content": input_data.get("files") #
}, # # Append messages for user and assistant
{ # if input_data.get("files"):
"role": "assistant", # result["messages"].extend(
"content": full_content # [
} # {
] # "role": "user",
) # "content": input_data.get("message", '')
else: # },
result["messages"].extend( # {
[ # "role": "user",
{ # "content": input_data.get("files")
"role": "user", # },
"content": input_data.get("message", '') # {
}, # "role": "assistant",
{ # "content": full_content
"role": "assistant", # }
"content": full_content # ]
} # )
] # else:
) # result["messages"].extend(
# Calculate elapsed time # [
end_time = datetime.datetime.now() # {
elapsed_time = (end_time - start_time).total_seconds() # "role": "user",
# "content": input_data.get("message", '')
logger.info( # },
f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}ms") # {
# "role": "assistant",
return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content) # "content": full_content
# }
except Exception as e: # ]
end_time = datetime.datetime.now() # )
elapsed_time = (end_time - start_time).total_seconds() # # Calculate elapsed time
# end_time = datetime.datetime.now()
logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}", # elapsed_time = (end_time - start_time).total_seconds()
exc_info=True) #
return { # logger.info(
"status": "failed", # f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}ms")
"error": str(e), #
"output": None, # return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
"node_outputs": {}, #
"elapsed_time": elapsed_time, # except Exception as e:
"token_usage": None # end_time = datetime.datetime.now()
} # elapsed_time = (end_time - start_time).total_seconds()
#
# logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
# exc_info=True)
# return {
# "status": "failed",
# "error": str(e),
# "output": None,
# "node_outputs": {},
# "elapsed_time": elapsed_time,
# "token_usage": None
# }
async def execute_stream( async def execute_stream(
self, self,
@@ -248,7 +259,8 @@ class WorkflowExecutor:
"timestamp": int(start_time.timestamp() * 1000) "timestamp": int(start_time.timestamp() * 1000)
} }
} }
result = None
full_content = ''
try: try:
# Build the workflow graph in streaming mode # Build the workflow graph in streaming mode
graph = self.build_graph(stream=True) graph = self.build_graph(stream=True)
@@ -266,7 +278,6 @@ class WorkflowExecutor:
start_node_id=self.start_node_id start_node_id=self.start_node_id
) )
full_content = ''
self.stream_coordinator.update_scope_activation("sys") self.stream_coordinator.update_scope_activation("sys")
# Execute the workflow with streaming # Execute the workflow with streaming
@@ -363,7 +374,12 @@ class WorkflowExecutor:
yield { yield {
"event": "workflow_end", "event": "workflow_end",
"data": self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content) "data": self.result_builder.build_final_output(
result,
self.variable_pool,
elapsed_time,
full_content,
success=True)
} }
except Exception as e: except Exception as e:
@@ -372,16 +388,19 @@ class WorkflowExecutor:
logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}", logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
exc_info=True) exc_info=True)
if result is None:
result = {"error": str(e)}
else:
result["error"] = str(e)
yield { yield {
"event": "workflow_end", "event": "workflow_end",
"data": { "data": self.result_builder.build_final_output(
"execution_id": self.execution_context.execution_id, result,
"status": "failed", self.variable_pool,
"error": str(e), elapsed_time,
"elapsed_time": elapsed_time, full_content,
"timestamp": end_time.isoformat() success=False
} )
} }

View File

@@ -128,7 +128,7 @@ class CodeNode(BaseNode):
else: else:
raise ValueError(f"Unsupported language: {self.typed_config.language}") raise ValueError(f"Unsupported language: {self.typed_config.language}")
async with httpx.AsyncClient() as client: async with httpx.AsyncClient(timeout=60) as client:
response = await client.post( response = await client.post(
"http://sandbox:8194/v1/sandbox/run", "http://sandbox:8194/v1/sandbox/run",
headers={ headers={

View File

@@ -33,11 +33,11 @@ class IfElseNode(BaseNode):
"right": expression.right "right": expression.right
if expression.input_type == ValueInputType.CONSTANT if expression.input_type == ValueInputType.CONSTANT
else self.get_variable(expression.right, variable_pool, strict=False), else self.get_variable(expression.right, variable_pool, strict=False),
"operator": expression.operator, "operator": str(expression.operator),
}) })
result.append({ result.append({
"expressions": expressions, "expressions": expressions,
"logical_operator": case.logical_operator, "logical_operator": str(case.logical_operator),
}) })
return { return {
"cases": result "cases": result

View File

@@ -170,7 +170,7 @@ class WorkflowValidator:
# 仅在发布时验证所有节点可达 # 仅在发布时验证所有节点可达
# 6. 验证所有节点可达(从 start 节点出发) # 6. 验证所有节点可达(从 start 节点出发)
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性 if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
reachable = WorkflowValidator._get_reachable_nodes( reachable = WorkflowValidator.get_reachable_nodes(
start_nodes[0]["id"], start_nodes[0]["id"],
edges edges
) )
@@ -194,7 +194,7 @@ class WorkflowValidator:
return len(errors) == 0, errors return len(errors) == 0, errors
@staticmethod @staticmethod
def _get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]: def get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]:
"""获取从 start 节点可达的所有节点 """获取从 start 节点可达的所有节点
Args: Args:

View File

@@ -2,7 +2,7 @@ from enum import StrEnum
from abc import abstractmethod, ABC from abc import abstractmethod, ABC
from typing import Any from typing import Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, PrivateAttr
from app.schemas import FileType from app.schemas import FileType
@@ -41,10 +41,10 @@ class VariableType(StrEnum):
""" """
if isinstance(var, str): if isinstance(var, str):
return cls.STRING return cls.STRING
elif isinstance(var, (int, float)):
return cls.NUMBER
elif isinstance(var, bool): elif isinstance(var, bool):
return cls.BOOLEAN return cls.BOOLEAN
elif isinstance(var, (int, float)):
return cls.NUMBER
elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('is_file')): elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('is_file')):
return cls.FILE return cls.FILE
elif isinstance(var, dict): elif isinstance(var, dict):
@@ -116,7 +116,7 @@ class FileObject(BaseModel):
content_cache: dict = Field(default_factory=dict) content_cache: dict = Field(default_factory=dict)
is_file: bool is_file: bool
_byte_content: bytes | None = None _byte_content: bytes | None = PrivateAttr(default=None)
def get_content(self): def get_content(self):
return self._byte_content return self._byte_content

View File

@@ -10,6 +10,7 @@ T = TypeVar("T", bound=BaseVariable)
class StringVariable(BaseVariable): class StringVariable(BaseVariable):
value: str
type = 'str' type = 'str'
def valid_value(self, value) -> str: def valid_value(self, value) -> str:
@@ -22,6 +23,7 @@ class StringVariable(BaseVariable):
class NumberVariable(BaseVariable): class NumberVariable(BaseVariable):
value: int | float
type = 'number' type = 'number'
def valid_value(self, value) -> int | float: def valid_value(self, value) -> int | float:
@@ -34,6 +36,7 @@ class NumberVariable(BaseVariable):
class BooleanVariable(BaseVariable): class BooleanVariable(BaseVariable):
value: bool
type = 'boolean' type = 'boolean'
def valid_value(self, value) -> bool: def valid_value(self, value) -> bool:
@@ -46,6 +49,7 @@ class BooleanVariable(BaseVariable):
class DictVariable(BaseVariable): class DictVariable(BaseVariable):
value: dict
type = 'object' type = 'object'
def valid_value(self, value) -> dict: def valid_value(self, value) -> dict:
@@ -58,6 +62,7 @@ class DictVariable(BaseVariable):
class FileVariable(BaseVariable): class FileVariable(BaseVariable):
value: FileObject
type = 'file' type = 'file'
def valid_value(self, value) -> FileObject: def valid_value(self, value) -> FileObject:
@@ -102,6 +107,7 @@ class FileVariable(BaseVariable):
class ArrayVariable(BaseVariable, Generic[T]): class ArrayVariable(BaseVariable, Generic[T]):
value: list[T]
type = 'array' type = 'array'
def __init__(self, child_type: Type[T], value: list[Any]): def __init__(self, child_type: Type[T], value: list[Any]):
@@ -129,6 +135,7 @@ class ArrayVariable(BaseVariable, Generic[T]):
class NestedArrayVariable(BaseVariable): class NestedArrayVariable(BaseVariable):
value: list[ArrayVariable]
type = 'array_nest' type = 'array_nest'
def valid_value(self, value: list[T]) -> list[T]: def valid_value(self, value: list[T]) -> list[T]:
@@ -153,6 +160,7 @@ class NestedArrayVariable(BaseVariable):
category=RuntimeWarning category=RuntimeWarning
) )
class AnyVariable(BaseVariable): class AnyVariable(BaseVariable):
value: Any
type = 'any' type = 'any'
def valid_value(self, value: Any) -> Any: def valid_value(self, value: Any) -> Any:

View File

@@ -25,7 +25,7 @@ from app.repositories.workflow_repository import (
WorkflowExecutionRepository, WorkflowExecutionRepository,
WorkflowNodeExecutionRepository WorkflowNodeExecutionRepository
) )
from app.schemas import DraftRunRequest, FileInput, FileType from app.schemas import DraftRunRequest, FileInput
from app.services.conversation_service import ConversationService from app.services.conversation_service import ConversationService
from app.services.multi_agent_service import convert_uuids_to_str from app.services.multi_agent_service import convert_uuids_to_str
from app.services.multimodal_service import MultimodalService from app.services.multimodal_service import MultimodalService