Merge pull request #638 from SuanmoSuanyangTechnology/feature/multi-end-stream
fix(workflow): unify streaming and non-stream execution outputs
This commit is contained in:
@@ -20,9 +20,21 @@ from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes import NodeFactory
|
||||
from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES
|
||||
from app.core.workflow.utils.expression_evaluator import evaluate_condition
|
||||
from app.core.workflow.validator import WorkflowValidator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Regex to split output into:
|
||||
# - variable placeholders: {{ ... }}
|
||||
# - normal literal text
|
||||
#
|
||||
# Example:
|
||||
# "Hello {{user.name}}!" ->
|
||||
# ["Hello ", "{{user.name}}", "!"]
|
||||
_OUTPUT_PATTERN = re.compile(r'\{\{.*?}}|[^{}]+')
|
||||
# Strict variable format: {{ node_id.field_name }}
|
||||
_VARIABLE_PATTERN = re.compile(r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*}}')
|
||||
|
||||
|
||||
class GraphBuilder:
|
||||
def __init__(
|
||||
@@ -37,13 +49,13 @@ class GraphBuilder:
|
||||
self.stream = stream
|
||||
self.subgraph = subgraph
|
||||
|
||||
self.start_node_id = None
|
||||
self.end_node_ids = []
|
||||
self.start_node_id: str | None = None
|
||||
|
||||
self.node_map = {node["id"]: node for node in self.nodes}
|
||||
self.end_node_map: dict[str, StreamOutputConfig] = {}
|
||||
self._find_upstream_branch_node = lru_cache(
|
||||
self._find_upstream_activation_dep = lru_cache(
|
||||
maxsize=len(self.nodes) * 2
|
||||
)(self._find_upstream_branch_node)
|
||||
)(self._find_upstream_activation_dep)
|
||||
if variable_pool:
|
||||
self.variable_pool = variable_pool
|
||||
else:
|
||||
@@ -51,10 +63,19 @@ class GraphBuilder:
|
||||
|
||||
self.graph = StateGraph(WorkflowState)
|
||||
self.add_nodes()
|
||||
self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges)
|
||||
self.end_nodes = [
|
||||
node
|
||||
for node in self.nodes
|
||||
if node.get("type") == "end" and node.get("id") in self.reachable_nodes
|
||||
]
|
||||
self.add_edges()
|
||||
self._analyze_end_node_output()
|
||||
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
|
||||
|
||||
self._reverse_adj: dict[str, list[dict]] = defaultdict(list)
|
||||
self._build_reverse_adj()
|
||||
self._analyze_end_node_output()
|
||||
|
||||
@property
|
||||
def nodes(self) -> list[dict[str, Any]]:
|
||||
return self.workflow_config.get("nodes", [])
|
||||
@@ -87,60 +108,50 @@ class GraphBuilder:
|
||||
result[node[0]].append(node[1])
|
||||
return result
|
||||
|
||||
def _find_upstream_branch_node(self, target_node: str) -> tuple[bool, tuple[tuple[str, str]]]:
|
||||
"""
|
||||
Recursively find all upstream branch (control) nodes that influence the execution
|
||||
of the given target node.
|
||||
def _build_reverse_adj(self):
|
||||
for edge in self.edges:
|
||||
if edge["source"] not in self.reachable_nodes:
|
||||
continue
|
||||
self._reverse_adj[edge.get("target")].append({
|
||||
"id": edge["source"], "branch": edge.get("label")
|
||||
})
|
||||
|
||||
This method walks upstream along the workflow graph starting from `target_node`.
|
||||
It distinguishes between:
|
||||
- branch nodes (node types listed in `BRANCH_NODES`)
|
||||
- non-branch nodes (ordinary processing nodes)
|
||||
def _find_upstream_activation_dep(
|
||||
self,
|
||||
target_node: str
|
||||
) -> tuple[tuple[tuple[str, str]], tuple[str]]:
|
||||
"""Find upstream dependencies that affect the activation of a target node.
|
||||
|
||||
Traversal rules:
|
||||
1. For each immediate upstream node:
|
||||
- If it is a branch node, it is recorded as an affecting control node.
|
||||
- If it is a non-branch node, the traversal continues recursively upstream.
|
||||
2. If ANY upstream path reaches a START / CYCLE_START node without encountering
|
||||
a branch node, the traversal is considered invalid:
|
||||
- `has_branch` will be False
|
||||
- no branch nodes are returned.
|
||||
3. Only when ALL upstream non-branch paths eventually lead to at least one
|
||||
branch node will `has_branch` be True.
|
||||
Walks upstream along the workflow graph from the target node, collecting
|
||||
two types of dependencies:
|
||||
- Branch control nodes: upstream branch nodes (e.g. if-else) whose
|
||||
routing outcome determines whether the target node executes.
|
||||
- Output nodes: upstream END nodes that must complete their output
|
||||
before the target node can activate.
|
||||
|
||||
Special case:
|
||||
- If `target_node` has no upstream nodes AND its type is START or CYCLE_START,
|
||||
it is considered directly reachable from the workflow entry, and therefore
|
||||
has no controlling branch nodes.
|
||||
The traversal terminates early and returns empty tuples if any upstream
|
||||
path reaches START/CYCLE_START without encountering a branch or output
|
||||
node, indicating the target node is directly reachable and should be
|
||||
activated immediately.
|
||||
|
||||
Args:
|
||||
target_node (str):
|
||||
The identifier of the node whose upstream control branches
|
||||
are to be resolved.
|
||||
target_node: The ID of the node whose upstream activation
|
||||
dependencies are to be resolved.
|
||||
|
||||
Returns:
|
||||
tuple[bool, tuple[tuple[str, str]]]:
|
||||
- has_branch (bool):
|
||||
True if every upstream path from `target_node` encounters
|
||||
at least one branch node.
|
||||
False if any path reaches a start node without a branch.
|
||||
- branch_nodes (tuple[tuple[str, str]]):
|
||||
A deduplicated tuple of `(branch_node_id, branch_label)` pairs
|
||||
representing all branch nodes that can influence `target_node`.
|
||||
Returns an empty tuple if `has_branch` is False.
|
||||
A tuple of two elements:
|
||||
- A deduplicated tuple of (branch_node_id, branch_label) pairs
|
||||
representing upstream branch control dependencies. Empty if
|
||||
any clean path to START exists.
|
||||
- A deduplicated tuple of upstream output node IDs that must
|
||||
complete before this node activates.
|
||||
"""
|
||||
source_nodes = [
|
||||
{
|
||||
"id": edge.get("source"),
|
||||
"branch": edge.get("label")
|
||||
}
|
||||
for edge in self.edges
|
||||
if edge.get("target") == target_node
|
||||
]
|
||||
source_nodes = self._reverse_adj[target_node]
|
||||
if not source_nodes and self.get_node_type(target_node) in [NodeType.START, NodeType.CYCLE_START]:
|
||||
return False, tuple()
|
||||
return tuple(), tuple()
|
||||
|
||||
branch_nodes = []
|
||||
output_nodes = []
|
||||
non_branch_nodes = []
|
||||
|
||||
for node_info in source_nodes:
|
||||
@@ -149,19 +160,23 @@ class GraphBuilder:
|
||||
(node_info["id"], node_info["branch"])
|
||||
)
|
||||
else:
|
||||
if self.get_node_type(node_info["id"]) == NodeType.END:
|
||||
output_nodes.append(node_info["id"])
|
||||
non_branch_nodes.append(node_info["id"])
|
||||
|
||||
has_branch = True
|
||||
for node_id in non_branch_nodes:
|
||||
node_has_branch, nodes = self._find_upstream_branch_node(node_id)
|
||||
has_branch = has_branch and node_has_branch
|
||||
if not has_branch:
|
||||
break
|
||||
branch_nodes.extend(nodes)
|
||||
if not has_branch:
|
||||
branch_nodes = []
|
||||
upstream_control_nodes, upstream_output_nodes = self._find_upstream_activation_dep(node_id)
|
||||
if not upstream_control_nodes:
|
||||
if not upstream_output_nodes and node_id not in output_nodes:
|
||||
return tuple(), tuple()
|
||||
branch_nodes = []
|
||||
has_branch = False
|
||||
if has_branch:
|
||||
branch_nodes.extend(upstream_control_nodes)
|
||||
output_nodes.extend(upstream_output_nodes)
|
||||
|
||||
return has_branch, tuple(set(branch_nodes))
|
||||
return tuple(set(branch_nodes)), tuple(set(output_nodes))
|
||||
|
||||
def _analyze_end_node_output(self):
|
||||
"""
|
||||
@@ -182,11 +197,10 @@ class GraphBuilder:
|
||||
"""
|
||||
|
||||
# Collect all End nodes in the workflow
|
||||
end_nodes = [node for node in self.nodes if node.get("type") == "end"]
|
||||
logger.info(f"[Prefix Analysis] Found {len(end_nodes)} End nodes")
|
||||
logger.info(f"[Prefix Analysis] Found {len(self.end_nodes)} End nodes")
|
||||
|
||||
# Iterate through each End node to analyze its output
|
||||
for end_node in end_nodes:
|
||||
for end_node in self.end_nodes:
|
||||
end_node_id = end_node.get("id")
|
||||
config = end_node.get("config", {})
|
||||
output = config.get("output")
|
||||
@@ -195,42 +209,33 @@ class GraphBuilder:
|
||||
if not output:
|
||||
continue
|
||||
|
||||
# Regex to split output into:
|
||||
# - variable placeholders: {{ ... }}
|
||||
# - normal literal text
|
||||
#
|
||||
# Example:
|
||||
# "Hello {{user.name}}!" ->
|
||||
# ["Hello ", "{{user.name}}", "!"]
|
||||
pattern = r'\{\{.*?\}\}|[^{}]+'
|
||||
|
||||
# Strict variable format: {{ node_id.field_name }}
|
||||
variable_pattern_string = r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*\}\}'
|
||||
variable_pattern = re.compile(variable_pattern_string)
|
||||
|
||||
# Split output into ordered segments
|
||||
output_template = list(re.findall(pattern, output))
|
||||
output_template = list(_OUTPUT_PATTERN.findall(output))
|
||||
|
||||
# Determine whether each segment is literal text
|
||||
# True -> literal (can be directly output)
|
||||
# False -> variable placeholder (needs runtime value)
|
||||
output_flag = [
|
||||
not bool(variable_pattern.match(item))
|
||||
not bool(_VARIABLE_PATTERN.match(item))
|
||||
for item in output_template
|
||||
]
|
||||
|
||||
# Stream mode: output activation depends on upstream branch nodes
|
||||
if self.stream:
|
||||
# Find upstream branch nodes that can control this End node
|
||||
has_branch, control_nodes = self._find_upstream_branch_node(end_node_id)
|
||||
|
||||
upstream_control_nodes, upstream_output_nodes = self._find_upstream_activation_dep(end_node_id)
|
||||
activate = not bool(upstream_control_nodes) and not bool(upstream_output_nodes)
|
||||
# Build StreamOutputConfig for this End node
|
||||
self.end_node_map[end_node_id] = StreamOutputConfig(
|
||||
id=end_node_id,
|
||||
# If there is no upstream branch, output is active immediately
|
||||
activate=not has_branch,
|
||||
activate=activate,
|
||||
|
||||
# Branch nodes that control activation of this End node
|
||||
control_nodes=self._merge_control_nodes(control_nodes),
|
||||
control_nodes=self._merge_control_nodes(upstream_control_nodes),
|
||||
upstream_output_nodes=list(upstream_output_nodes),
|
||||
control_resolved=not bool(upstream_control_nodes),
|
||||
output_resolved=not bool(upstream_output_nodes),
|
||||
|
||||
# Convert output segments into OutputContent objects
|
||||
outputs=list(
|
||||
@@ -249,14 +254,16 @@ class GraphBuilder:
|
||||
cursor=0
|
||||
)
|
||||
logger.info(f"[Stream Analysis] end_id: {end_node_id}, "
|
||||
f"activate: {not has_branch}, "
|
||||
f"control_nodes: {control_nodes},"
|
||||
f"activate: {activate}, "
|
||||
f"control_nodes: {upstream_control_nodes},"
|
||||
f"ref_outputs: {upstream_output_nodes},"
|
||||
f"output: {output_template},"
|
||||
f"output_activate: {output_flag}")
|
||||
|
||||
# Non-stream mode: all outputs are activated by default
|
||||
else:
|
||||
self.end_node_map[end_node_id] = StreamOutputConfig(
|
||||
id=end_node_id,
|
||||
activate=True,
|
||||
control_nodes={},
|
||||
outputs=list(
|
||||
@@ -269,7 +276,10 @@ class GraphBuilder:
|
||||
for output_string, activate in zip(output_template, output_flag)
|
||||
]
|
||||
),
|
||||
cursor=0
|
||||
cursor=0,
|
||||
upstream_output_nodes=[],
|
||||
control_resolved=True,
|
||||
output_resolved=True,
|
||||
)
|
||||
|
||||
def add_nodes(self):
|
||||
@@ -304,8 +314,6 @@ class GraphBuilder:
|
||||
# Record start and end node IDs
|
||||
if node_type in [NodeType.START, NodeType.CYCLE_START]:
|
||||
self.start_node_id = node_id
|
||||
elif node_type == NodeType.END:
|
||||
self.end_node_ids.append(node_id)
|
||||
|
||||
# Create node instance (start and end nodes are also created)
|
||||
# NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph
|
||||
@@ -494,9 +502,11 @@ class GraphBuilder:
|
||||
logger.debug(f"Added waiting edge: {sources} -> {target}")
|
||||
|
||||
# Connect End nodes to the global END node
|
||||
for end_node_id in self.end_node_ids:
|
||||
self.graph.add_edge(end_node_id, END)
|
||||
logger.debug(f"Added edge: {end_node_id} -> END")
|
||||
for end_node in self.end_nodes:
|
||||
end_node_id = end_node.get("id")
|
||||
if end_node_id:
|
||||
self.graph.add_edge(end_node_id, END)
|
||||
logger.debug(f"Added edge: {end_node_id} -> END")
|
||||
return
|
||||
|
||||
def build(self) -> CompiledStateGraph:
|
||||
|
||||
@@ -12,6 +12,7 @@ class WorkflowResultBuilder:
|
||||
variable_pool: VariablePool,
|
||||
elapsed_time: float,
|
||||
final_output: str,
|
||||
success: bool
|
||||
):
|
||||
"""Construct the final standardized output of the workflow execution.
|
||||
|
||||
@@ -29,6 +30,7 @@ class WorkflowResultBuilder:
|
||||
elapsed_time (float): Total execution time in seconds.
|
||||
final_output (Any): The aggregated or final output content of the workflow
|
||||
(e.g., combined messages from all End nodes).
|
||||
success (bool): Whether the execution was successful.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the final workflow execution result with keys:
|
||||
@@ -49,7 +51,7 @@ class WorkflowResultBuilder:
|
||||
conversation_id = variable_pool.get_value("sys.conversation_id")
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"status": "completed" if success else "failed",
|
||||
"output": final_output,
|
||||
"variables": {
|
||||
"conv": variable_pool.get_all_conversation_vars(),
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/9 15:11
|
||||
import re
|
||||
from queue import Queue
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
@@ -37,8 +38,8 @@ class OutputContent(BaseModel):
|
||||
activate: bool = Field(
|
||||
...,
|
||||
description=(
|
||||
"Whether this output segment is currently active.\n"
|
||||
"- True: allowed to be emitted/output\n"
|
||||
"Whether this output segment is currently active."
|
||||
"- True: allowed to be emitted/output"
|
||||
"- False: blocked until activated by branch control"
|
||||
)
|
||||
)
|
||||
@@ -46,8 +47,8 @@ class OutputContent(BaseModel):
|
||||
is_variable: bool = Field(
|
||||
...,
|
||||
description=(
|
||||
"Whether this segment represents a variable placeholder.\n"
|
||||
"True -> variable (e.g. {{ node.field }})\n"
|
||||
"Whether this segment represents a variable placeholder."
|
||||
"True -> variable (e.g. {{ node.field }})"
|
||||
"False -> literal text"
|
||||
)
|
||||
)
|
||||
@@ -86,12 +87,16 @@ class StreamOutputConfig(BaseModel):
|
||||
- which upstream branch/control nodes gate the activation
|
||||
- how each parsed output segment is streamed and activated
|
||||
"""
|
||||
id: str = Field(
|
||||
...,
|
||||
description="ID of the End node this configuration belongs to."
|
||||
)
|
||||
|
||||
activate: bool = Field(
|
||||
...,
|
||||
description=(
|
||||
"Global activation flag for the End node output.\n"
|
||||
"When False, output segments should not be emitted even if available.\n"
|
||||
"Global activation flag for the End node output."
|
||||
"When False, output segments should not be emitted even if available."
|
||||
"This flag typically becomes True once required control branch conditions "
|
||||
"are satisfied."
|
||||
)
|
||||
@@ -100,17 +105,46 @@ class StreamOutputConfig(BaseModel):
|
||||
control_nodes: dict[str, list[str]] = Field(
|
||||
...,
|
||||
description=(
|
||||
"Control branch conditions for this End node output.\n"
|
||||
"Mapping of `branch_node_id -> expected_branch_label`.\n"
|
||||
"Control branch conditions for this End node output."
|
||||
"Mapping of `branch_node_id -> expected_branch_label`."
|
||||
"The End node output becomes globally active when a controlling branch node "
|
||||
"reports a matching completion status."
|
||||
)
|
||||
)
|
||||
|
||||
upstream_output_nodes: list[str] = Field(
|
||||
...,
|
||||
description=(
|
||||
"Upstream output node dependencies (data flow)."
|
||||
"Represents END/output nodes that this output depends on."
|
||||
"These nodes provide data sources required before this output can be activated "
|
||||
"or streamed."
|
||||
"Used to ensure correct ordering and dependency resolution in streaming mode."
|
||||
)
|
||||
)
|
||||
|
||||
control_resolved: bool = Field(
|
||||
...,
|
||||
description=(
|
||||
"Whether all upstream branch control dependencies have been satisfied."
|
||||
"True if no upstream branch nodes exist or the required branch "
|
||||
"conditions have been met."
|
||||
)
|
||||
)
|
||||
|
||||
output_resolved: bool = Field(
|
||||
...,
|
||||
description=(
|
||||
"Whether all upstream output node dependencies have been completed."
|
||||
"True if no upstream output nodes exist or all upstream output "
|
||||
"nodes have finished their output."
|
||||
)
|
||||
)
|
||||
|
||||
outputs: list[OutputContent] = Field(
|
||||
...,
|
||||
description=(
|
||||
"Ordered list of output segments parsed from the output template.\n"
|
||||
"Ordered list of output segments parsed from the output template."
|
||||
"Each segment represents either a literal text block or a variable placeholder "
|
||||
"that may be activated independently."
|
||||
)
|
||||
@@ -119,49 +153,97 @@ class StreamOutputConfig(BaseModel):
|
||||
cursor: int = Field(
|
||||
...,
|
||||
description=(
|
||||
"Streaming cursor index.\n"
|
||||
"Indicates the next output segment index to be emitted.\n"
|
||||
"Streaming cursor index."
|
||||
"Indicates the next output segment index to be emitted."
|
||||
"Segments with index < cursor are considered already streamed."
|
||||
)
|
||||
)
|
||||
|
||||
force: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"Force flag for output emission."
|
||||
"When True, all output segments are emitted regardless of activation state."
|
||||
"Triggered when this output node has finished execution."
|
||||
)
|
||||
)
|
||||
|
||||
def update_activate(self, scope: str, status=None):
|
||||
"""
|
||||
Update streaming activation state based on an upstream node or special variable.
|
||||
Update streaming activation state based on upstream events.
|
||||
|
||||
Args:
|
||||
scope (str):
|
||||
Identifier of the completed upstream entity.
|
||||
- If a control branch node, it should match a key in `control_nodes`.
|
||||
- If a variable placeholder (e.g., "sys.xxx"), it may appear in output segments.
|
||||
- If an upstream output node, it should match an entry in `upstream_output_nodes`.
|
||||
- If a variable placeholder (e.g., "sys.xxx" or "node_id.field"),
|
||||
it may appear in output segments.
|
||||
|
||||
status (optional):
|
||||
Completion status of the control branch node.
|
||||
Required when `scope` refers to a control node.
|
||||
|
||||
Behavior:
|
||||
1. Control branch nodes:
|
||||
- If `scope` matches a key in `control_nodes` and `status` matches the expected
|
||||
branch label, the End node output becomes globally active (`activate = True`).
|
||||
1. Force activation:
|
||||
- If `self.force` is True, the method returns immediately.
|
||||
- If `scope == self.id`, the node marks itself as completed:
|
||||
- `activate = True`
|
||||
- `force = True`
|
||||
This is typically used for final flushing when the node finishes execution.
|
||||
|
||||
2. Variable output segments:
|
||||
- For each segment that is a variable (`is_variable=True`):
|
||||
- If the segment literal references `scope`, mark the segment as active.
|
||||
- This applies both to regular node variables (e.g., "node_id.field")
|
||||
and special system variables (e.g., "sys.xxx").
|
||||
2. Control dependency resolution:
|
||||
- If `scope` matches a key in `control_nodes`:
|
||||
- `status` must be provided.
|
||||
- If `status` matches expected branch labels, mark control as resolved
|
||||
(`control_resolved = True`).
|
||||
|
||||
3. Upstream output dependency resolution:
|
||||
- If `scope` is in `upstream_output_nodes`,
|
||||
mark data dependency as resolved (`output_resolved = True`).
|
||||
|
||||
4. Global activation condition:
|
||||
- The node becomes active when BOTH conditions are satisfied:
|
||||
- control_resolved == True
|
||||
- output_resolved == True
|
||||
- Once activated, `activate` remains True.
|
||||
|
||||
5. Variable segment activation:
|
||||
- For each output segment that is a variable (`is_variable=True`):
|
||||
- If the segment depends on the given `scope`,
|
||||
mark the segment as active.
|
||||
- This applies to both node variables (e.g., "node_id.field")
|
||||
and system variables (e.g., "sys.xxx").
|
||||
|
||||
Notes:
|
||||
- This method does not emit output or advance the streaming cursor.
|
||||
- It only updates activation flags based on upstream events or special variables.
|
||||
- This method does NOT emit output or advance the streaming cursor.
|
||||
- It only updates activation and dependency resolution states.
|
||||
- Activation is driven by both control flow (branch nodes) and
|
||||
data flow (upstream output nodes).
|
||||
"""
|
||||
if self.force:
|
||||
return
|
||||
|
||||
# Case 1: resolve control branch dependency
|
||||
if scope == self.id:
|
||||
self.activate = True
|
||||
self.force = True
|
||||
return
|
||||
|
||||
# resolve control branch dependency
|
||||
if scope in self.control_nodes:
|
||||
if status is None:
|
||||
raise RuntimeError("[Stream Output] Control node activation status not provided")
|
||||
if status in self.control_nodes[scope]:
|
||||
self.activate = True
|
||||
self.control_resolved = True
|
||||
|
||||
# Case 2: activate variable segments related to this node
|
||||
if scope in self.upstream_output_nodes:
|
||||
self.upstream_output_nodes.remove(scope)
|
||||
if not self.upstream_output_nodes:
|
||||
self.output_resolved = True
|
||||
|
||||
self.activate = self.activate or (self.control_resolved and self.output_resolved)
|
||||
|
||||
# activate variable segments related to this node
|
||||
for i in range(len(self.outputs)):
|
||||
if (
|
||||
self.outputs[i].is_variable
|
||||
@@ -174,12 +256,17 @@ class StreamOutputCoordinator:
|
||||
def __init__(self):
|
||||
self.end_outputs: dict[str, StreamOutputConfig] = {}
|
||||
self.activate_end: str | None = None
|
||||
self.output_queue: Queue = Queue()
|
||||
self.processed_outputs = []
|
||||
|
||||
def initialize_end_outputs(
|
||||
self,
|
||||
end_node_map: dict[str, StreamOutputConfig]
|
||||
):
|
||||
self.end_outputs = end_node_map
|
||||
self.processed_outputs = []
|
||||
self.activate_end = None
|
||||
self.output_queue = Queue()
|
||||
|
||||
@property
|
||||
def current_activate_end_info(self):
|
||||
@@ -211,8 +298,11 @@ class StreamOutputCoordinator:
|
||||
"""
|
||||
for node in self.end_outputs.keys():
|
||||
self.end_outputs[node].update_activate(scope, status)
|
||||
if self.end_outputs[node].activate and self.activate_end is None:
|
||||
self.activate_end = node
|
||||
if self.end_outputs[node].activate and node not in self.processed_outputs:
|
||||
self.output_queue.put(node)
|
||||
self.processed_outputs.append(node)
|
||||
if self.activate_end is None and not self.output_queue.empty():
|
||||
self.activate_end = self.output_queue.get_nowait()
|
||||
|
||||
async def emit_activate_chunk(
|
||||
self,
|
||||
@@ -256,7 +346,7 @@ class StreamOutputCoordinator:
|
||||
final_chunk = ''
|
||||
current_segment = end_info.outputs[end_info.cursor]
|
||||
|
||||
if not current_segment.activate and not force:
|
||||
if not current_segment.activate and not force and not end_info.force:
|
||||
# Stop processing until this segment becomes active
|
||||
break
|
||||
|
||||
@@ -273,7 +363,7 @@ class StreamOutputCoordinator:
|
||||
logger.warning(f"[STREAM] Failed to evaluate segment: {current_segment.literal}, error: {e}")
|
||||
|
||||
if final_chunk:
|
||||
logger.info(f"[STREAM] StreamOutput Node:{self.activate_end}, chunk:{final_chunk}")
|
||||
logger.info(f"[STREAM] StreamOutput Node:{self.activate_end}, chunk_length:{len(final_chunk)}")
|
||||
yield {
|
||||
"event": "message",
|
||||
"data": {
|
||||
@@ -285,8 +375,7 @@ class StreamOutputCoordinator:
|
||||
end_info.cursor += 1
|
||||
|
||||
if end_info.cursor >= len(end_info.outputs):
|
||||
self.end_outputs.pop(self.activate_end)
|
||||
self.activate_end = None
|
||||
self.pop_current_activate_end()
|
||||
|
||||
async def flush_remaining_chunk(
|
||||
self,
|
||||
@@ -325,6 +414,8 @@ class StreamOutputCoordinator:
|
||||
async for msg_event in self.emit_activate_chunk(variable_pool, force=True):
|
||||
yield msg_event
|
||||
|
||||
if not self.output_queue.empty():
|
||||
self.activate_end = self.output_queue.get_nowait()
|
||||
# Move to next active End node if current one is done
|
||||
if not self.activate_end and self.end_outputs:
|
||||
self.activate_end = list(self.end_outputs.keys())[0]
|
||||
|
||||
@@ -128,89 +128,100 @@ class WorkflowExecutor:
|
||||
- token_usage: aggregated token usage if available
|
||||
- error: error message if any
|
||||
"""
|
||||
logger.info(f"Starting workflow execution: execution_id={self.execution_context.execution_id}")
|
||||
|
||||
start_time = datetime.datetime.now()
|
||||
|
||||
# Execute the workflow
|
||||
try:
|
||||
# Build the workflow graph
|
||||
graph = self.build_graph()
|
||||
|
||||
# Initialize the variable pool with input data
|
||||
await self.variable_initializer.initialize(
|
||||
variable_pool=self.variable_pool,
|
||||
input_data=input_data,
|
||||
execution_context=self.execution_context
|
||||
)
|
||||
initial_state = self.state_manager.create_initial_state(
|
||||
workflow_config=self.workflow_config,
|
||||
input_data=input_data,
|
||||
execution_context=self.execution_context,
|
||||
start_node_id=self.start_node_id
|
||||
)
|
||||
|
||||
result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config)
|
||||
|
||||
# Aggregate output from all End nodes
|
||||
full_content = ''
|
||||
for end_id in self.stream_coordinator.end_outputs.keys():
|
||||
full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False)
|
||||
|
||||
# Append messages for user and assistant
|
||||
if input_data.get("files"):
|
||||
result["messages"].extend(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": input_data.get("message", '')
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": input_data.get("files")
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": full_content
|
||||
}
|
||||
]
|
||||
)
|
||||
else:
|
||||
result["messages"].extend(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": input_data.get("message", '')
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": full_content
|
||||
}
|
||||
]
|
||||
)
|
||||
# Calculate elapsed time
|
||||
end_time = datetime.datetime.now()
|
||||
elapsed_time = (end_time - start_time).total_seconds()
|
||||
|
||||
logger.info(
|
||||
f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}ms")
|
||||
|
||||
return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
|
||||
|
||||
except Exception as e:
|
||||
end_time = datetime.datetime.now()
|
||||
elapsed_time = (end_time - start_time).total_seconds()
|
||||
|
||||
logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
|
||||
exc_info=True)
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
"output": None,
|
||||
"node_outputs": {},
|
||||
"elapsed_time": elapsed_time,
|
||||
"token_usage": None
|
||||
}
|
||||
start = datetime.datetime.now()
|
||||
async for event in self.execute_stream(input_data):
|
||||
if event.get("event") == "workflow_end":
|
||||
return event.get("data")
|
||||
return self.result_builder.build_final_output(
|
||||
{"error": "Workflow execution did not end as expected"},
|
||||
self.variable_pool,
|
||||
(datetime.datetime.now() - start).total_seconds(),
|
||||
"",
|
||||
success=False
|
||||
)
|
||||
# logger.info(f"Starting workflow execution: execution_id={self.execution_context.execution_id}")
|
||||
#
|
||||
# start_time = datetime.datetime.now()
|
||||
#
|
||||
# # Execute the workflow
|
||||
# try:
|
||||
# # Build the workflow graph
|
||||
# graph = self.build_graph()
|
||||
#
|
||||
# # Initialize the variable pool with input data
|
||||
# await self.variable_initializer.initialize(
|
||||
# variable_pool=self.variable_pool,
|
||||
# input_data=input_data,
|
||||
# execution_context=self.execution_context
|
||||
# )
|
||||
# initial_state = self.state_manager.create_initial_state(
|
||||
# workflow_config=self.workflow_config,
|
||||
# input_data=input_data,
|
||||
# execution_context=self.execution_context,
|
||||
# start_node_id=self.start_node_id
|
||||
# )
|
||||
#
|
||||
# result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config)
|
||||
#
|
||||
# # Aggregate output from all End nodes
|
||||
# full_content = ''
|
||||
# for end_id in self.stream_coordinator.end_outputs.keys():
|
||||
# full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False)
|
||||
#
|
||||
# # Append messages for user and assistant
|
||||
# if input_data.get("files"):
|
||||
# result["messages"].extend(
|
||||
# [
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": input_data.get("message", '')
|
||||
# },
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": input_data.get("files")
|
||||
# },
|
||||
# {
|
||||
# "role": "assistant",
|
||||
# "content": full_content
|
||||
# }
|
||||
# ]
|
||||
# )
|
||||
# else:
|
||||
# result["messages"].extend(
|
||||
# [
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": input_data.get("message", '')
|
||||
# },
|
||||
# {
|
||||
# "role": "assistant",
|
||||
# "content": full_content
|
||||
# }
|
||||
# ]
|
||||
# )
|
||||
# # Calculate elapsed time
|
||||
# end_time = datetime.datetime.now()
|
||||
# elapsed_time = (end_time - start_time).total_seconds()
|
||||
#
|
||||
# logger.info(
|
||||
# f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}ms")
|
||||
#
|
||||
# return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
|
||||
#
|
||||
# except Exception as e:
|
||||
# end_time = datetime.datetime.now()
|
||||
# elapsed_time = (end_time - start_time).total_seconds()
|
||||
#
|
||||
# logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
|
||||
# exc_info=True)
|
||||
# return {
|
||||
# "status": "failed",
|
||||
# "error": str(e),
|
||||
# "output": None,
|
||||
# "node_outputs": {},
|
||||
# "elapsed_time": elapsed_time,
|
||||
# "token_usage": None
|
||||
# }
|
||||
|
||||
async def execute_stream(
|
||||
self,
|
||||
@@ -248,7 +259,8 @@ class WorkflowExecutor:
|
||||
"timestamp": int(start_time.timestamp() * 1000)
|
||||
}
|
||||
}
|
||||
|
||||
result = None
|
||||
full_content = ''
|
||||
try:
|
||||
# Build the workflow graph in streaming mode
|
||||
graph = self.build_graph(stream=True)
|
||||
@@ -266,7 +278,6 @@ class WorkflowExecutor:
|
||||
start_node_id=self.start_node_id
|
||||
)
|
||||
|
||||
full_content = ''
|
||||
self.stream_coordinator.update_scope_activation("sys")
|
||||
|
||||
# Execute the workflow with streaming
|
||||
@@ -363,7 +374,12 @@ class WorkflowExecutor:
|
||||
|
||||
yield {
|
||||
"event": "workflow_end",
|
||||
"data": self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
|
||||
"data": self.result_builder.build_final_output(
|
||||
result,
|
||||
self.variable_pool,
|
||||
elapsed_time,
|
||||
full_content,
|
||||
success=True)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@@ -372,16 +388,19 @@ class WorkflowExecutor:
|
||||
|
||||
logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
|
||||
exc_info=True)
|
||||
|
||||
if result is None:
|
||||
result = {"error": str(e)}
|
||||
else:
|
||||
result["error"] = str(e)
|
||||
yield {
|
||||
"event": "workflow_end",
|
||||
"data": {
|
||||
"execution_id": self.execution_context.execution_id,
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
"elapsed_time": elapsed_time,
|
||||
"timestamp": end_time.isoformat()
|
||||
}
|
||||
"data": self.result_builder.build_final_output(
|
||||
result,
|
||||
self.variable_pool,
|
||||
elapsed_time,
|
||||
full_content,
|
||||
success=False
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -128,7 +128,7 @@ class CodeNode(BaseNode):
|
||||
else:
|
||||
raise ValueError(f"Unsupported language: {self.typed_config.language}")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with httpx.AsyncClient(timeout=60) as client:
|
||||
response = await client.post(
|
||||
"http://sandbox:8194/v1/sandbox/run",
|
||||
headers={
|
||||
|
||||
@@ -33,11 +33,11 @@ class IfElseNode(BaseNode):
|
||||
"right": expression.right
|
||||
if expression.input_type == ValueInputType.CONSTANT
|
||||
else self.get_variable(expression.right, variable_pool, strict=False),
|
||||
"operator": expression.operator,
|
||||
"operator": str(expression.operator),
|
||||
})
|
||||
result.append({
|
||||
"expressions": expressions,
|
||||
"logical_operator": case.logical_operator,
|
||||
"logical_operator": str(case.logical_operator),
|
||||
})
|
||||
return {
|
||||
"cases": result
|
||||
|
||||
@@ -170,7 +170,7 @@ class WorkflowValidator:
|
||||
# 仅在发布时验证所有节点可达
|
||||
# 6. 验证所有节点可达(从 start 节点出发)
|
||||
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
|
||||
reachable = WorkflowValidator._get_reachable_nodes(
|
||||
reachable = WorkflowValidator.get_reachable_nodes(
|
||||
start_nodes[0]["id"],
|
||||
edges
|
||||
)
|
||||
@@ -194,7 +194,7 @@ class WorkflowValidator:
|
||||
return len(errors) == 0, errors
|
||||
|
||||
@staticmethod
|
||||
def _get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]:
|
||||
def get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]:
|
||||
"""获取从 start 节点可达的所有节点
|
||||
|
||||
Args:
|
||||
|
||||
@@ -2,7 +2,7 @@ from enum import StrEnum
|
||||
from abc import abstractmethod, ABC
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
|
||||
from app.schemas import FileType
|
||||
|
||||
@@ -41,10 +41,10 @@ class VariableType(StrEnum):
|
||||
"""
|
||||
if isinstance(var, str):
|
||||
return cls.STRING
|
||||
elif isinstance(var, (int, float)):
|
||||
return cls.NUMBER
|
||||
elif isinstance(var, bool):
|
||||
return cls.BOOLEAN
|
||||
elif isinstance(var, (int, float)):
|
||||
return cls.NUMBER
|
||||
elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('is_file')):
|
||||
return cls.FILE
|
||||
elif isinstance(var, dict):
|
||||
@@ -116,7 +116,7 @@ class FileObject(BaseModel):
|
||||
content_cache: dict = Field(default_factory=dict)
|
||||
is_file: bool
|
||||
|
||||
_byte_content: bytes | None = None
|
||||
_byte_content: bytes | None = PrivateAttr(default=None)
|
||||
|
||||
def get_content(self):
|
||||
return self._byte_content
|
||||
|
||||
@@ -10,6 +10,7 @@ T = TypeVar("T", bound=BaseVariable)
|
||||
|
||||
|
||||
class StringVariable(BaseVariable):
|
||||
value: str
|
||||
type = 'str'
|
||||
|
||||
def valid_value(self, value) -> str:
|
||||
@@ -22,6 +23,7 @@ class StringVariable(BaseVariable):
|
||||
|
||||
|
||||
class NumberVariable(BaseVariable):
|
||||
value: int | float
|
||||
type = 'number'
|
||||
|
||||
def valid_value(self, value) -> int | float:
|
||||
@@ -34,6 +36,7 @@ class NumberVariable(BaseVariable):
|
||||
|
||||
|
||||
class BooleanVariable(BaseVariable):
|
||||
value: bool
|
||||
type = 'boolean'
|
||||
|
||||
def valid_value(self, value) -> bool:
|
||||
@@ -46,6 +49,7 @@ class BooleanVariable(BaseVariable):
|
||||
|
||||
|
||||
class DictVariable(BaseVariable):
|
||||
value: dict
|
||||
type = 'object'
|
||||
|
||||
def valid_value(self, value) -> dict:
|
||||
@@ -58,6 +62,7 @@ class DictVariable(BaseVariable):
|
||||
|
||||
|
||||
class FileVariable(BaseVariable):
|
||||
value: FileObject
|
||||
type = 'file'
|
||||
|
||||
def valid_value(self, value) -> FileObject:
|
||||
@@ -102,6 +107,7 @@ class FileVariable(BaseVariable):
|
||||
|
||||
|
||||
class ArrayVariable(BaseVariable, Generic[T]):
|
||||
value: list[T]
|
||||
type = 'array'
|
||||
|
||||
def __init__(self, child_type: Type[T], value: list[Any]):
|
||||
@@ -129,6 +135,7 @@ class ArrayVariable(BaseVariable, Generic[T]):
|
||||
|
||||
|
||||
class NestedArrayVariable(BaseVariable):
|
||||
value: list[ArrayVariable]
|
||||
type = 'array_nest'
|
||||
|
||||
def valid_value(self, value: list[T]) -> list[T]:
|
||||
@@ -153,6 +160,7 @@ class NestedArrayVariable(BaseVariable):
|
||||
category=RuntimeWarning
|
||||
)
|
||||
class AnyVariable(BaseVariable):
|
||||
value: Any
|
||||
type = 'any'
|
||||
|
||||
def valid_value(self, value: Any) -> Any:
|
||||
|
||||
@@ -25,7 +25,7 @@ from app.repositories.workflow_repository import (
|
||||
WorkflowExecutionRepository,
|
||||
WorkflowNodeExecutionRepository
|
||||
)
|
||||
from app.schemas import DraftRunRequest, FileInput, FileType
|
||||
from app.schemas import DraftRunRequest, FileInput
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.multi_agent_service import convert_uuids_to_str
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
|
||||
Reference in New Issue
Block a user