fix(workflow): fix streaming output issues with multi-output End nodes
End nodes with multiple output segments could cause cursor errors or leave some segments inactive, resulting in incorrect final outputs. Unified _emit_active_chunks and _update_scope_activate to ensure all segments are activated in order and streamed correctly.
This commit is contained in:
@@ -16,7 +16,6 @@ from app.core.workflow.graph_builder import GraphBuilder, StreamOutputConfig
|
|||||||
from app.core.workflow.nodes import WorkflowState
|
from app.core.workflow.nodes import WorkflowState
|
||||||
from app.core.workflow.nodes.base_config import VariableType
|
from app.core.workflow.nodes.base_config import VariableType
|
||||||
from app.core.workflow.nodes.enums import NodeType
|
from app.core.workflow.nodes.enums import NodeType
|
||||||
from app.core.workflow.template_renderer import render_template
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -157,12 +156,137 @@ class WorkflowExecutor:
|
|||||||
"error": result.get("error"),
|
"error": result.get("error"),
|
||||||
}
|
}
|
||||||
|
|
||||||
def _update_end_activate(self, node_id):
|
def _update_scope_activate(self, scope, status=None):
|
||||||
|
"""
|
||||||
|
Update the activation state of all End nodes based on a completed scope (node or variable).
|
||||||
|
|
||||||
|
Iterates over all End nodes in `self.end_outputs` and calls
|
||||||
|
`update_activate` on each, which may:
|
||||||
|
- Activate variable segments that depend on the completed node/scope.
|
||||||
|
- Activate the entire End node output if all control conditions are met.
|
||||||
|
|
||||||
|
If any End node becomes active and `self.activate_end` is not yet set,
|
||||||
|
this node will be marked as the currently active End node.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scope (str): The node ID or scope that has completed execution.
|
||||||
|
status (str | None): Optional status of the node (used for branch/control nodes).
|
||||||
|
"""
|
||||||
for node in self.end_outputs.keys():
|
for node in self.end_outputs.keys():
|
||||||
self.end_outputs[node].update_activate(node_id)
|
self.end_outputs[node].update_activate(scope, status)
|
||||||
if self.end_outputs[node].activate and self.activate_end is None:
|
if self.end_outputs[node].activate and self.activate_end is None:
|
||||||
self.activate_end = node
|
self.activate_end = node
|
||||||
|
|
||||||
|
def _update_stream_output_status(self, activate, data):
|
||||||
|
"""
|
||||||
|
Update the stream output state of End nodes based on workflow state updates.
|
||||||
|
|
||||||
|
This method checks which nodes/scopes are activated and propagates
|
||||||
|
activation to End nodes accordingly.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
activate (dict): Mapping of node_id -> bool indicating which nodes/scopes are activated.
|
||||||
|
data (dict): Mapping of node_id -> node runtime data, including outputs.
|
||||||
|
|
||||||
|
Behavior:
|
||||||
|
For each node in `data`:
|
||||||
|
1. If the node is activated (`activate[node_id]` is True),
|
||||||
|
retrieve its output status from `runtime_vars`.
|
||||||
|
2. Call `_update_scope_activate` to propagate the activation
|
||||||
|
to all relevant End nodes and update `self.activate_end`.
|
||||||
|
"""
|
||||||
|
for node_id in data.keys():
|
||||||
|
if activate.get(node_id):
|
||||||
|
node_output_status = (
|
||||||
|
data[node_id]
|
||||||
|
.get('runtime_vars', {})
|
||||||
|
.get(node_id)
|
||||||
|
.get("output")
|
||||||
|
)
|
||||||
|
self._update_scope_activate(node_id, status=node_output_status)
|
||||||
|
|
||||||
|
async def _emit_active_chunks(
|
||||||
|
self,
|
||||||
|
node_outputs: dict,
|
||||||
|
variables: dict,
|
||||||
|
force=False
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Process and yield all currently active output segments for the currently active End node.
|
||||||
|
|
||||||
|
This method handles stream-mode output for an End node by iterating through its output segments
|
||||||
|
(`OutputContent`). Only segments marked as active (`activate=True`) are processed, unless
|
||||||
|
`force=True`, which allows all segments to be processed regardless of their activation state.
|
||||||
|
|
||||||
|
Behavior:
|
||||||
|
1. Iterates from the current `cursor` position to the end of the outputs list.
|
||||||
|
2. For each segment:
|
||||||
|
- If the segment is literal text (`is_variable=False`), append it directly.
|
||||||
|
- If the segment is a variable (`is_variable=True`), evaluate it using
|
||||||
|
`evaluate_expression` with the given `node_outputs` and `variables`,
|
||||||
|
then transform the result with `_trans_output_string`.
|
||||||
|
3. Yield a stream event of type "message" containing the processed chunk.
|
||||||
|
4. Move the `cursor` forward after processing each segment.
|
||||||
|
5. When all segments have been processed, remove this End node from `end_outputs`
|
||||||
|
and reset `activate_end` to None.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_outputs (dict): Current runtime node outputs, used for variable evaluation.
|
||||||
|
variables (dict): Current runtime variables, used for variable evaluation.
|
||||||
|
force (bool, default=False): If True, process segments even if `activate=False`.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
dict: A stream event of type "message" containing the processed chunk.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- Segments that fail evaluation (ValueError) are skipped with a warning logged.
|
||||||
|
- This method only processes the currently active End node (`self.activate_end`).
|
||||||
|
- Use `force=True` for final emission regardless of activation state.
|
||||||
|
"""
|
||||||
|
|
||||||
|
end_info = self.end_outputs[self.activate_end]
|
||||||
|
|
||||||
|
while end_info.cursor < len(end_info.outputs):
|
||||||
|
final_chunk = ''
|
||||||
|
current_segment = end_info.outputs[end_info.cursor]
|
||||||
|
|
||||||
|
if not current_segment.activate and not force:
|
||||||
|
# Stop processing until this segment becomes active
|
||||||
|
break
|
||||||
|
|
||||||
|
# Literal segment
|
||||||
|
if not current_segment.is_variable:
|
||||||
|
final_chunk += current_segment.literal
|
||||||
|
else:
|
||||||
|
# Variable segment: evaluate and transform
|
||||||
|
try:
|
||||||
|
chunk = evaluate_expression(
|
||||||
|
current_segment.literal,
|
||||||
|
variables=variables,
|
||||||
|
node_outputs=node_outputs
|
||||||
|
)
|
||||||
|
chunk = self._trans_output_string(chunk)
|
||||||
|
final_chunk += chunk
|
||||||
|
except ValueError:
|
||||||
|
# Log failed evaluation but continue streaming
|
||||||
|
logger.warning(f"[STREAM] Failed to evaluate segment: {current_segment.literal}")
|
||||||
|
|
||||||
|
if final_chunk:
|
||||||
|
yield {
|
||||||
|
"event": "message",
|
||||||
|
"data": {
|
||||||
|
"chunk": final_chunk
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Advance cursor after processing
|
||||||
|
end_info.cursor += 1
|
||||||
|
|
||||||
|
# Remove End node from active tracking if all segments have been processed
|
||||||
|
if end_info.cursor >= len(end_info.outputs):
|
||||||
|
self.end_outputs.pop(self.activate_end)
|
||||||
|
self.activate_end = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _trans_output_string(content):
|
def _trans_output_string(content):
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
@@ -218,14 +342,8 @@ class WorkflowExecutor:
|
|||||||
|
|
||||||
result = await graph.ainvoke(initial_state, config=self.checkpoint_config)
|
result = await graph.ainvoke(initial_state, config=self.checkpoint_config)
|
||||||
full_content = ''
|
full_content = ''
|
||||||
for end_info in self.end_outputs.values():
|
for end_id in self.end_outputs.keys():
|
||||||
output_template = "".join([output.literal for output in end_info.outputs])
|
full_content += result.get('runtime_vars', {}).get(end_id, {}).get('output', '')
|
||||||
full_content += render_template(
|
|
||||||
output_template,
|
|
||||||
result.get("variables", {}),
|
|
||||||
result.get("runtime_vars", {}),
|
|
||||||
strict=False
|
|
||||||
)
|
|
||||||
result["messages"].extend(
|
result["messages"].extend(
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
@@ -306,7 +424,7 @@ class WorkflowExecutor:
|
|||||||
try:
|
try:
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
full_content = ''
|
full_content = ''
|
||||||
|
self._update_scope_activate("sys")
|
||||||
async for event in graph.astream(
|
async for event in graph.astream(
|
||||||
initial_state,
|
initial_state,
|
||||||
stream_mode=["updates", "debug", "custom"], # Use updates + debug + custom mode
|
stream_mode=["updates", "debug", "custom"], # Use updates + debug + custom mode
|
||||||
@@ -333,9 +451,12 @@ class WorkflowExecutor:
|
|||||||
if not end_info or end_info.cursor >= len(end_info.outputs):
|
if not end_info or end_info.cursor >= len(end_info.outputs):
|
||||||
continue
|
continue
|
||||||
current_output = end_info.outputs[end_info.cursor]
|
current_output = end_info.outputs[end_info.cursor]
|
||||||
if current_output.is_variable and current_output.depends_on_node(node_id):
|
if current_output.is_variable and current_output.depends_on_scope(node_id):
|
||||||
if data.get("done"):
|
if data.get("done"):
|
||||||
end_info.cursor += 1
|
end_info.cursor += 1
|
||||||
|
if end_info.cursor >= len(end_info.outputs):
|
||||||
|
self.end_outputs.pop(self.activate_end)
|
||||||
|
self.activate_end = None
|
||||||
else:
|
else:
|
||||||
full_content += data.get("chunk")
|
full_content += data.get("chunk")
|
||||||
yield {
|
yield {
|
||||||
@@ -415,91 +536,53 @@ class WorkflowExecutor:
|
|||||||
|
|
||||||
elif mode == "updates":
|
elif mode == "updates":
|
||||||
# Handle state updates - store final state
|
# Handle state updates - store final state
|
||||||
for node_id in data.keys():
|
state = graph.get_state(config=self.checkpoint_config).values
|
||||||
self._update_end_activate(node_id)
|
node_outputs = state.get("runtime_vars", {})
|
||||||
wait = False
|
variables = state.get("variables", {})
|
||||||
state = graph.get_state(config=self.checkpoint_config)
|
activate = state.get("activate", {})
|
||||||
node_outputs = state.values.get("runtime_vars", {})
|
for _, node_data in data.items():
|
||||||
for _ in data.keys():
|
node_outputs |= node_data.get("runtime_vars", {})
|
||||||
node_outputs = node_outputs | data.get(_).get("runtime_vars", {})
|
variables |= node_data.get("variables", {})
|
||||||
|
|
||||||
|
self._update_stream_output_status(activate, data)
|
||||||
|
wait = False
|
||||||
while self.activate_end and not wait:
|
while self.activate_end and not wait:
|
||||||
message = ''
|
async for msg_event in self._emit_active_chunks(
|
||||||
logger.info(self.activate_end)
|
node_outputs=node_outputs,
|
||||||
end_info = self.end_outputs[self.activate_end]
|
variables=variables
|
||||||
content = end_info.outputs[end_info.cursor]
|
):
|
||||||
while content.activate:
|
full_content += msg_event["data"]['chunk']
|
||||||
if not content.is_variable:
|
yield msg_event
|
||||||
full_content += content.literal
|
|
||||||
message += content.literal
|
if self.activate_end:
|
||||||
else:
|
|
||||||
try:
|
|
||||||
chunk = evaluate_expression(
|
|
||||||
content.literal,
|
|
||||||
variables={},
|
|
||||||
node_outputs=node_outputs
|
|
||||||
)
|
|
||||||
chunk = self._trans_output_string(chunk)
|
|
||||||
message += chunk
|
|
||||||
full_content += chunk
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
end_info.cursor += 1
|
|
||||||
if end_info.cursor == len(end_info.outputs):
|
|
||||||
break
|
|
||||||
content = end_info.outputs[end_info.cursor]
|
|
||||||
if end_info.cursor != len(end_info.outputs):
|
|
||||||
wait = True
|
wait = True
|
||||||
else:
|
else:
|
||||||
self.end_outputs.pop(self.activate_end)
|
self._update_stream_output_status(activate, data)
|
||||||
self.activate_end = None
|
|
||||||
for node_id in data.keys():
|
|
||||||
self._update_end_activate(node_id)
|
|
||||||
if message:
|
|
||||||
yield {
|
|
||||||
"event": "message",
|
|
||||||
"data": {
|
|
||||||
"chunk": message
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())} "
|
logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())} "
|
||||||
f"- execution_id: {self.execution_id}")
|
f"- execution_id: {self.execution_id}")
|
||||||
|
|
||||||
result = graph.get_state(self.checkpoint_config).values
|
result = graph.get_state(self.checkpoint_config).values
|
||||||
while self.activate_end:
|
node_outputs = result.get("runtime_vars", {})
|
||||||
message = ''
|
variables = result.get("variables", {})
|
||||||
end_info = self.end_outputs[self.activate_end]
|
self.end_outputs = {
|
||||||
content = end_info.outputs[end_info.cursor]
|
node_id: node_info
|
||||||
if not content.is_variable:
|
for node_id, node_info in self.end_outputs.items()
|
||||||
message += content.literal
|
if node_info.activate
|
||||||
else:
|
}
|
||||||
node_outputs = result.get("runtime_vars", {})
|
|
||||||
variables = result.get("variables", {})
|
if self.end_outputs or self.activate_end:
|
||||||
try:
|
while self.activate_end:
|
||||||
chunk = evaluate_expression(
|
async for msg_event in self._emit_active_chunks(
|
||||||
content.literal,
|
node_outputs=node_outputs,
|
||||||
variables=variables,
|
variables=variables,
|
||||||
node_outputs=node_outputs
|
force=True
|
||||||
)
|
):
|
||||||
chunk = self._trans_output_string(chunk)
|
full_content += msg_event["data"]['chunk']
|
||||||
message += chunk
|
yield msg_event
|
||||||
full_content += chunk
|
|
||||||
except ValueError:
|
if not self.activate_end and self.end_outputs:
|
||||||
pass
|
|
||||||
end_info.cursor += 1
|
|
||||||
if end_info.cursor == len(end_info.outputs):
|
|
||||||
self.end_outputs.pop(self.activate_end)
|
|
||||||
self.activate_end = None
|
|
||||||
if self.end_outputs:
|
|
||||||
self.activate_end = list(self.end_outputs.keys())[0]
|
self.activate_end = list(self.end_outputs.keys())[0]
|
||||||
if message:
|
|
||||||
yield {
|
|
||||||
"event": "message",
|
|
||||||
"data": {
|
|
||||||
"chunk": message
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# 计算耗时
|
# 计算耗时
|
||||||
end_time = datetime.datetime.now()
|
end_time = datetime.datetime.now()
|
||||||
|
|||||||
@@ -53,114 +53,110 @@ class OutputContent(BaseModel):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def depends_on_node(self, node_id: str) -> bool:
|
def depends_on_scope(self, scope: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if this output segment depends on a specific node's variable.
|
Check if this segment depends on a given scope.
|
||||||
|
|
||||||
This method examines the `literal` of the output segment to see if it
|
|
||||||
contains a variable placeholder referencing the given node in the form:
|
|
||||||
|
|
||||||
{{ node_id.field_name }}
|
|
||||||
|
|
||||||
It uses a regular expression to match the exact node ID, avoiding
|
|
||||||
false positives from substring matches (e.g., 'node1' should not match 'node10').
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_id (str): The ID of the node to check for in this segment's variable placeholders.
|
scope (str): Node ID or special variable prefix (e.g., "sys").
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool:
|
bool: True if this segment references the given scope.
|
||||||
- True if the segment contains a variable referencing the given node.
|
|
||||||
- False otherwise.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
literal = "{{node1.name}}"
|
|
||||||
|
|
||||||
depends_on_node("node1") -> True
|
|
||||||
depends_on_node("node2") -> False
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
This method is primarily used in stream mode to determine whether
|
|
||||||
a particular variable output segment should be activated when a
|
|
||||||
specific upstream node completes execution.
|
|
||||||
"""
|
"""
|
||||||
variable_pattern = rf"\{{\{{\s*{re.escape(node_id)}\.[a-zA-Z0-9_]+\s*\}}\}}"
|
pattern = rf"\{{\{{\s*{re.escape(scope)}\.[a-zA-Z0-9_]+\s*\}}\}}"
|
||||||
pattern = re.compile(variable_pattern)
|
return bool(re.search(pattern, self.literal))
|
||||||
match = pattern.search(self.literal)
|
|
||||||
if match:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
class StreamOutputConfig(BaseModel):
|
class StreamOutputConfig(BaseModel):
|
||||||
"""
|
"""
|
||||||
Streaming output configuration for an End node.
|
Streaming output configuration for an End node.
|
||||||
|
|
||||||
This structure controls:
|
This configuration describes how the End node output behaves in streaming mode,
|
||||||
- whether the End node output is globally active
|
including:
|
||||||
- which upstream branch nodes are responsible for activation
|
- whether output emission is globally activated
|
||||||
- how each output segment behaves in streaming mode
|
- which upstream branch/control nodes gate the activation
|
||||||
|
- how each parsed output segment is streamed and activated
|
||||||
"""
|
"""
|
||||||
|
|
||||||
activate: bool = Field(
|
activate: bool = Field(
|
||||||
...,
|
...,
|
||||||
description=(
|
description=(
|
||||||
"Global activation state of the End node output.\n"
|
"Global activation flag for the End node output.\n"
|
||||||
"If False, no output should be emitted until all control nodes are resolved."
|
"When False, output segments should not be emitted even if available.\n"
|
||||||
|
"This flag typically becomes True once required control branch conditions "
|
||||||
|
"are satisfied."
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
control_nodes: list[str] = Field(
|
control_nodes: dict[str, str] = Field(
|
||||||
...,
|
...,
|
||||||
description=(
|
description=(
|
||||||
"List of upstream branch node IDs that control this End node.\n"
|
"Control branch conditions for this End node output.\n"
|
||||||
"Each node must signal completion before output becomes active."
|
"Mapping of `branch_node_id -> expected_branch_label`.\n"
|
||||||
|
"The End node output becomes globally active when a controlling branch node "
|
||||||
|
"reports a matching completion status."
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
outputs: list[OutputContent] = Field(
|
outputs: list[OutputContent] = Field(
|
||||||
...,
|
...,
|
||||||
description="Ordered list of output segments parsed from the output template."
|
description=(
|
||||||
|
"Ordered list of output segments parsed from the output template.\n"
|
||||||
|
"Each segment represents either a literal text block or a variable placeholder "
|
||||||
|
"that may be activated independently."
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
cursor: int = Field(
|
cursor: int = Field(
|
||||||
...,
|
...,
|
||||||
description=(
|
description=(
|
||||||
"Streaming cursor index.\n"
|
"Streaming cursor index.\n"
|
||||||
"Indicates how many output segments have already been emitted."
|
"Indicates the next output segment index to be emitted.\n"
|
||||||
|
"Segments with index < cursor are considered already streamed."
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_activate(self, node_id):
|
def update_activate(self, scope: str, status=None):
|
||||||
"""
|
"""
|
||||||
Update activation state based on an upstream node completion.
|
Update streaming activation state based on an upstream node or special variable.
|
||||||
|
|
||||||
This method is typically called when a branch/control node finishes execution.
|
Args:
|
||||||
|
scope (str):
|
||||||
|
Identifier of the completed upstream entity.
|
||||||
|
- If a control branch node, it should match a key in `control_nodes`.
|
||||||
|
- If a variable placeholder (e.g., "sys.xxx"), it may appear in output segments.
|
||||||
|
status (optional):
|
||||||
|
Completion status of the control branch node.
|
||||||
|
Required when `scope` refers to a control node.
|
||||||
|
|
||||||
Behavior:
|
Behavior:
|
||||||
1. If the node is a control node:
|
1. Control branch nodes:
|
||||||
- Remove it from `control_nodes`
|
- If `scope` matches a key in `control_nodes` and `status` matches the expected
|
||||||
- If all control nodes are resolved, activate the entire output
|
branch label, the End node output becomes globally active (`activate = True`).
|
||||||
|
|
||||||
2. Activate variable output segments that depend on this node:
|
2. Variable output segments:
|
||||||
- If an output segment is a variable
|
- For each segment that is a variable (`is_variable=True`):
|
||||||
- And its literal references the completed node_id
|
- If the segment literal references `scope`, mark the segment as active.
|
||||||
- Mark that segment as active
|
- This applies both to regular node variables (e.g., "node_id.field")
|
||||||
|
and special system variables (e.g., "sys.xxx").
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- This method does not emit output or advance the streaming cursor.
|
||||||
|
- It only updates activation flags based on upstream events or special variables.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Case 1: resolve control branch dependency
|
# Case 1: resolve control branch dependency
|
||||||
if node_id in self.control_nodes:
|
if scope in self.control_nodes.keys():
|
||||||
self.control_nodes.remove(node_id)
|
if status is None:
|
||||||
|
raise RuntimeError("[Stream Output] Control node activation status not provided")
|
||||||
# All branch constraints resolved → enable output
|
if status == self.control_nodes[scope]:
|
||||||
if not self.control_nodes:
|
|
||||||
self.activate = True
|
self.activate = True
|
||||||
|
|
||||||
# Case 2: activate variable segments related to this node
|
# Case 2: activate variable segments related to this node
|
||||||
for i in range(len(self.outputs)):
|
for i in range(len(self.outputs)):
|
||||||
if (
|
if (
|
||||||
self.outputs[i].is_variable
|
self.outputs[i].is_variable
|
||||||
and self.outputs[i].depends_on_node(node_id)
|
and self.outputs[i].depends_on_scope(scope)
|
||||||
):
|
):
|
||||||
self.outputs[i].activate = True
|
self.outputs[i].activate = True
|
||||||
|
|
||||||
@@ -184,11 +180,11 @@ class GraphBuilder:
|
|||||||
self._find_upstream_branch_node = lru_cache(
|
self._find_upstream_branch_node = lru_cache(
|
||||||
maxsize=len(self.nodes) * 2
|
maxsize=len(self.nodes) * 2
|
||||||
)(self._find_upstream_branch_node)
|
)(self._find_upstream_branch_node)
|
||||||
self._analyze_end_node_output()
|
|
||||||
|
|
||||||
self.graph = StateGraph(WorkflowState)
|
self.graph = StateGraph(WorkflowState)
|
||||||
self.add_nodes()
|
self.add_nodes()
|
||||||
self.add_edges()
|
self.add_edges()
|
||||||
|
self._analyze_end_node_output()
|
||||||
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
|
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -216,30 +212,53 @@ class GraphBuilder:
|
|||||||
except KeyError:
|
except KeyError:
|
||||||
raise RuntimeError(f"Node not found: Id={node_id}")
|
raise RuntimeError(f"Node not found: Id={node_id}")
|
||||||
|
|
||||||
def _find_upstream_branch_node(self, target_node: str) -> tuple[bool, tuple[str]]:
|
def _find_upstream_branch_node(self, target_node: str) -> tuple[bool, tuple[tuple[str, str]]]:
|
||||||
"""Find upstream branch nodes for a given target node in the workflow graph.
|
"""
|
||||||
|
Recursively find all upstream branch (control) nodes that influence the execution
|
||||||
|
of the given target node.
|
||||||
|
|
||||||
This method identifies all upstream control (branch) nodes that can affect
|
This method walks upstream along the workflow graph starting from `target_node`.
|
||||||
the execution of `target_node`. If `target_node` is reachable from a start
|
It distinguishes between:
|
||||||
node (i.e., a node with no upstream nodes), the method returns an empty tuple.
|
- branch nodes (node types listed in `BRANCH_NODES`)
|
||||||
|
- non-branch nodes (ordinary processing nodes)
|
||||||
|
|
||||||
The function distinguishes between branch nodes (defined in `BRANCH_NODES`)
|
Traversal rules:
|
||||||
and non-branch nodes, recursively traversing upstream through non-branch
|
1. For each immediate upstream node:
|
||||||
nodes. If any non-branch upstream path does not lead to a branch node,
|
- If it is a branch node, it is recorded as an affecting control node.
|
||||||
the result will indicate that no valid upstream branch node exists.
|
- If it is a non-branch node, the traversal continues recursively upstream.
|
||||||
|
2. If ANY upstream path reaches a START / CYCLE_START node without encountering
|
||||||
|
a branch node, the traversal is considered invalid:
|
||||||
|
- `has_branch` will be False
|
||||||
|
- no branch nodes are returned.
|
||||||
|
3. Only when ALL upstream non-branch paths eventually lead to at least one
|
||||||
|
branch node will `has_branch` be True.
|
||||||
|
|
||||||
|
Special case:
|
||||||
|
- If `target_node` has no upstream nodes AND its type is START or CYCLE_START,
|
||||||
|
it is considered directly reachable from the workflow entry, and therefore
|
||||||
|
has no controlling branch nodes.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
target_node (str): The identifier of the target node.
|
target_node (str):
|
||||||
|
The identifier of the node whose upstream control branches
|
||||||
|
are to be resolved.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple[bool, tuple[str]]:
|
tuple[bool, tuple[tuple[str, str]]]:
|
||||||
- has_branch (bool): True if all upstream non-branch paths lead to at least
|
- has_branch (bool):
|
||||||
one branch node; False if any path reaches a start node without a branch.
|
True if every upstream path from `target_node` encounters
|
||||||
- branch_nodes (tuple[str]): A deduplicated tuple of upstream branch node IDs
|
at least one branch node.
|
||||||
affecting `target_node`. Returns an empty tuple if `has_branch` is False.
|
False if any path reaches a start node without a branch.
|
||||||
|
- branch_nodes (tuple[tuple[str, str]]):
|
||||||
|
A deduplicated tuple of `(branch_node_id, branch_label)` pairs
|
||||||
|
representing all branch nodes that can influence `target_node`.
|
||||||
|
Returns an empty tuple if `has_branch` is False.
|
||||||
"""
|
"""
|
||||||
source_nodes = [
|
source_nodes = [
|
||||||
edge.get("source")
|
{
|
||||||
|
"id": edge.get("source"),
|
||||||
|
"branch": edge.get("label")
|
||||||
|
}
|
||||||
for edge in self.edges
|
for edge in self.edges
|
||||||
if edge.get("target") == target_node
|
if edge.get("target") == target_node
|
||||||
]
|
]
|
||||||
@@ -249,11 +268,13 @@ class GraphBuilder:
|
|||||||
branch_nodes = []
|
branch_nodes = []
|
||||||
non_branch_nodes = []
|
non_branch_nodes = []
|
||||||
|
|
||||||
for node_id in source_nodes:
|
for node_info in source_nodes:
|
||||||
if self.get_node_type(node_id) in BRANCH_NODES:
|
if self.get_node_type(node_info["id"]) in BRANCH_NODES:
|
||||||
branch_nodes.append(node_id)
|
branch_nodes.append(
|
||||||
|
(node_info["id"], node_info["branch"])
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
non_branch_nodes.append(node_id)
|
non_branch_nodes.append(node_info["id"])
|
||||||
|
|
||||||
has_branch = True
|
has_branch = True
|
||||||
for node_id in non_branch_nodes:
|
for node_id in non_branch_nodes:
|
||||||
@@ -334,7 +355,7 @@ class GraphBuilder:
|
|||||||
activate=not has_branch,
|
activate=not has_branch,
|
||||||
|
|
||||||
# Branch nodes that control activation of this End node
|
# Branch nodes that control activation of this End node
|
||||||
control_nodes=list(control_nodes),
|
control_nodes=dict(control_nodes),
|
||||||
|
|
||||||
# Convert output segments into OutputContent objects
|
# Convert output segments into OutputContent objects
|
||||||
outputs=list(
|
outputs=list(
|
||||||
@@ -362,7 +383,7 @@ class GraphBuilder:
|
|||||||
else:
|
else:
|
||||||
self.end_node_map[end_node_id] = StreamOutputConfig(
|
self.end_node_map[end_node_id] = StreamOutputConfig(
|
||||||
activate=True,
|
activate=True,
|
||||||
control_nodes=[],
|
control_nodes={},
|
||||||
outputs=list(
|
outputs=list(
|
||||||
[
|
[
|
||||||
OutputContent(
|
OutputContent(
|
||||||
|
|||||||
Reference in New Issue
Block a user