Merge pull request #638 from SuanmoSuanyangTechnology/feature/multi-end-stream
fix(workflow): unify streaming and non-stream execution outputs
This commit is contained in:
@@ -20,9 +20,21 @@ from app.core.workflow.engine.variable_pool import VariablePool
|
|||||||
from app.core.workflow.nodes import NodeFactory
|
from app.core.workflow.nodes import NodeFactory
|
||||||
from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES
|
from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES
|
||||||
from app.core.workflow.utils.expression_evaluator import evaluate_condition
|
from app.core.workflow.utils.expression_evaluator import evaluate_condition
|
||||||
|
from app.core.workflow.validator import WorkflowValidator
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Regex to split output into:
|
||||||
|
# - variable placeholders: {{ ... }}
|
||||||
|
# - normal literal text
|
||||||
|
#
|
||||||
|
# Example:
|
||||||
|
# "Hello {{user.name}}!" ->
|
||||||
|
# ["Hello ", "{{user.name}}", "!"]
|
||||||
|
_OUTPUT_PATTERN = re.compile(r'\{\{.*?}}|[^{}]+')
|
||||||
|
# Strict variable format: {{ node_id.field_name }}
|
||||||
|
_VARIABLE_PATTERN = re.compile(r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*}}')
|
||||||
|
|
||||||
|
|
||||||
class GraphBuilder:
|
class GraphBuilder:
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -37,13 +49,13 @@ class GraphBuilder:
|
|||||||
self.stream = stream
|
self.stream = stream
|
||||||
self.subgraph = subgraph
|
self.subgraph = subgraph
|
||||||
|
|
||||||
self.start_node_id = None
|
self.start_node_id: str | None = None
|
||||||
self.end_node_ids = []
|
|
||||||
self.node_map = {node["id"]: node for node in self.nodes}
|
self.node_map = {node["id"]: node for node in self.nodes}
|
||||||
self.end_node_map: dict[str, StreamOutputConfig] = {}
|
self.end_node_map: dict[str, StreamOutputConfig] = {}
|
||||||
self._find_upstream_branch_node = lru_cache(
|
self._find_upstream_activation_dep = lru_cache(
|
||||||
maxsize=len(self.nodes) * 2
|
maxsize=len(self.nodes) * 2
|
||||||
)(self._find_upstream_branch_node)
|
)(self._find_upstream_activation_dep)
|
||||||
if variable_pool:
|
if variable_pool:
|
||||||
self.variable_pool = variable_pool
|
self.variable_pool = variable_pool
|
||||||
else:
|
else:
|
||||||
@@ -51,10 +63,19 @@ class GraphBuilder:
|
|||||||
|
|
||||||
self.graph = StateGraph(WorkflowState)
|
self.graph = StateGraph(WorkflowState)
|
||||||
self.add_nodes()
|
self.add_nodes()
|
||||||
|
self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges)
|
||||||
|
self.end_nodes = [
|
||||||
|
node
|
||||||
|
for node in self.nodes
|
||||||
|
if node.get("type") == "end" and node.get("id") in self.reachable_nodes
|
||||||
|
]
|
||||||
self.add_edges()
|
self.add_edges()
|
||||||
self._analyze_end_node_output()
|
|
||||||
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
|
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
|
||||||
|
|
||||||
|
self._reverse_adj: dict[str, list[dict]] = defaultdict(list)
|
||||||
|
self._build_reverse_adj()
|
||||||
|
self._analyze_end_node_output()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def nodes(self) -> list[dict[str, Any]]:
|
def nodes(self) -> list[dict[str, Any]]:
|
||||||
return self.workflow_config.get("nodes", [])
|
return self.workflow_config.get("nodes", [])
|
||||||
@@ -87,60 +108,50 @@ class GraphBuilder:
|
|||||||
result[node[0]].append(node[1])
|
result[node[0]].append(node[1])
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _find_upstream_branch_node(self, target_node: str) -> tuple[bool, tuple[tuple[str, str]]]:
|
def _build_reverse_adj(self):
|
||||||
"""
|
for edge in self.edges:
|
||||||
Recursively find all upstream branch (control) nodes that influence the execution
|
if edge["source"] not in self.reachable_nodes:
|
||||||
of the given target node.
|
continue
|
||||||
|
self._reverse_adj[edge.get("target")].append({
|
||||||
|
"id": edge["source"], "branch": edge.get("label")
|
||||||
|
})
|
||||||
|
|
||||||
This method walks upstream along the workflow graph starting from `target_node`.
|
def _find_upstream_activation_dep(
|
||||||
It distinguishes between:
|
self,
|
||||||
- branch nodes (node types listed in `BRANCH_NODES`)
|
target_node: str
|
||||||
- non-branch nodes (ordinary processing nodes)
|
) -> tuple[tuple[tuple[str, str]], tuple[str]]:
|
||||||
|
"""Find upstream dependencies that affect the activation of a target node.
|
||||||
|
|
||||||
Traversal rules:
|
Walks upstream along the workflow graph from the target node, collecting
|
||||||
1. For each immediate upstream node:
|
two types of dependencies:
|
||||||
- If it is a branch node, it is recorded as an affecting control node.
|
- Branch control nodes: upstream branch nodes (e.g. if-else) whose
|
||||||
- If it is a non-branch node, the traversal continues recursively upstream.
|
routing outcome determines whether the target node executes.
|
||||||
2. If ANY upstream path reaches a START / CYCLE_START node without encountering
|
- Output nodes: upstream END nodes that must complete their output
|
||||||
a branch node, the traversal is considered invalid:
|
before the target node can activate.
|
||||||
- `has_branch` will be False
|
|
||||||
- no branch nodes are returned.
|
|
||||||
3. Only when ALL upstream non-branch paths eventually lead to at least one
|
|
||||||
branch node will `has_branch` be True.
|
|
||||||
|
|
||||||
Special case:
|
The traversal terminates early and returns empty tuples if any upstream
|
||||||
- If `target_node` has no upstream nodes AND its type is START or CYCLE_START,
|
path reaches START/CYCLE_START without encountering a branch or output
|
||||||
it is considered directly reachable from the workflow entry, and therefore
|
node, indicating the target node is directly reachable and should be
|
||||||
has no controlling branch nodes.
|
activated immediately.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
target_node (str):
|
target_node: The ID of the node whose upstream activation
|
||||||
The identifier of the node whose upstream control branches
|
dependencies are to be resolved.
|
||||||
are to be resolved.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple[bool, tuple[tuple[str, str]]]:
|
A tuple of two elements:
|
||||||
- has_branch (bool):
|
- A deduplicated tuple of (branch_node_id, branch_label) pairs
|
||||||
True if every upstream path from `target_node` encounters
|
representing upstream branch control dependencies. Empty if
|
||||||
at least one branch node.
|
any clean path to START exists.
|
||||||
False if any path reaches a start node without a branch.
|
- A deduplicated tuple of upstream output node IDs that must
|
||||||
- branch_nodes (tuple[tuple[str, str]]):
|
complete before this node activates.
|
||||||
A deduplicated tuple of `(branch_node_id, branch_label)` pairs
|
|
||||||
representing all branch nodes that can influence `target_node`.
|
|
||||||
Returns an empty tuple if `has_branch` is False.
|
|
||||||
"""
|
"""
|
||||||
source_nodes = [
|
source_nodes = self._reverse_adj[target_node]
|
||||||
{
|
|
||||||
"id": edge.get("source"),
|
|
||||||
"branch": edge.get("label")
|
|
||||||
}
|
|
||||||
for edge in self.edges
|
|
||||||
if edge.get("target") == target_node
|
|
||||||
]
|
|
||||||
if not source_nodes and self.get_node_type(target_node) in [NodeType.START, NodeType.CYCLE_START]:
|
if not source_nodes and self.get_node_type(target_node) in [NodeType.START, NodeType.CYCLE_START]:
|
||||||
return False, tuple()
|
return tuple(), tuple()
|
||||||
|
|
||||||
branch_nodes = []
|
branch_nodes = []
|
||||||
|
output_nodes = []
|
||||||
non_branch_nodes = []
|
non_branch_nodes = []
|
||||||
|
|
||||||
for node_info in source_nodes:
|
for node_info in source_nodes:
|
||||||
@@ -149,19 +160,23 @@ class GraphBuilder:
|
|||||||
(node_info["id"], node_info["branch"])
|
(node_info["id"], node_info["branch"])
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
if self.get_node_type(node_info["id"]) == NodeType.END:
|
||||||
|
output_nodes.append(node_info["id"])
|
||||||
non_branch_nodes.append(node_info["id"])
|
non_branch_nodes.append(node_info["id"])
|
||||||
|
|
||||||
has_branch = True
|
has_branch = True
|
||||||
for node_id in non_branch_nodes:
|
for node_id in non_branch_nodes:
|
||||||
node_has_branch, nodes = self._find_upstream_branch_node(node_id)
|
upstream_control_nodes, upstream_output_nodes = self._find_upstream_activation_dep(node_id)
|
||||||
has_branch = has_branch and node_has_branch
|
if not upstream_control_nodes:
|
||||||
if not has_branch:
|
if not upstream_output_nodes and node_id not in output_nodes:
|
||||||
break
|
return tuple(), tuple()
|
||||||
branch_nodes.extend(nodes)
|
branch_nodes = []
|
||||||
if not has_branch:
|
has_branch = False
|
||||||
branch_nodes = []
|
if has_branch:
|
||||||
|
branch_nodes.extend(upstream_control_nodes)
|
||||||
|
output_nodes.extend(upstream_output_nodes)
|
||||||
|
|
||||||
return has_branch, tuple(set(branch_nodes))
|
return tuple(set(branch_nodes)), tuple(set(output_nodes))
|
||||||
|
|
||||||
def _analyze_end_node_output(self):
|
def _analyze_end_node_output(self):
|
||||||
"""
|
"""
|
||||||
@@ -182,11 +197,10 @@ class GraphBuilder:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Collect all End nodes in the workflow
|
# Collect all End nodes in the workflow
|
||||||
end_nodes = [node for node in self.nodes if node.get("type") == "end"]
|
logger.info(f"[Prefix Analysis] Found {len(self.end_nodes)} End nodes")
|
||||||
logger.info(f"[Prefix Analysis] Found {len(end_nodes)} End nodes")
|
|
||||||
|
|
||||||
# Iterate through each End node to analyze its output
|
# Iterate through each End node to analyze its output
|
||||||
for end_node in end_nodes:
|
for end_node in self.end_nodes:
|
||||||
end_node_id = end_node.get("id")
|
end_node_id = end_node.get("id")
|
||||||
config = end_node.get("config", {})
|
config = end_node.get("config", {})
|
||||||
output = config.get("output")
|
output = config.get("output")
|
||||||
@@ -195,42 +209,33 @@ class GraphBuilder:
|
|||||||
if not output:
|
if not output:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Regex to split output into:
|
|
||||||
# - variable placeholders: {{ ... }}
|
|
||||||
# - normal literal text
|
|
||||||
#
|
|
||||||
# Example:
|
|
||||||
# "Hello {{user.name}}!" ->
|
|
||||||
# ["Hello ", "{{user.name}}", "!"]
|
|
||||||
pattern = r'\{\{.*?\}\}|[^{}]+'
|
|
||||||
|
|
||||||
# Strict variable format: {{ node_id.field_name }}
|
|
||||||
variable_pattern_string = r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*\}\}'
|
|
||||||
variable_pattern = re.compile(variable_pattern_string)
|
|
||||||
|
|
||||||
# Split output into ordered segments
|
# Split output into ordered segments
|
||||||
output_template = list(re.findall(pattern, output))
|
output_template = list(_OUTPUT_PATTERN.findall(output))
|
||||||
|
|
||||||
# Determine whether each segment is literal text
|
# Determine whether each segment is literal text
|
||||||
# True -> literal (can be directly output)
|
# True -> literal (can be directly output)
|
||||||
# False -> variable placeholder (needs runtime value)
|
# False -> variable placeholder (needs runtime value)
|
||||||
output_flag = [
|
output_flag = [
|
||||||
not bool(variable_pattern.match(item))
|
not bool(_VARIABLE_PATTERN.match(item))
|
||||||
for item in output_template
|
for item in output_template
|
||||||
]
|
]
|
||||||
|
|
||||||
# Stream mode: output activation depends on upstream branch nodes
|
# Stream mode: output activation depends on upstream branch nodes
|
||||||
if self.stream:
|
if self.stream:
|
||||||
# Find upstream branch nodes that can control this End node
|
# Find upstream branch nodes that can control this End node
|
||||||
has_branch, control_nodes = self._find_upstream_branch_node(end_node_id)
|
upstream_control_nodes, upstream_output_nodes = self._find_upstream_activation_dep(end_node_id)
|
||||||
|
activate = not bool(upstream_control_nodes) and not bool(upstream_output_nodes)
|
||||||
# Build StreamOutputConfig for this End node
|
# Build StreamOutputConfig for this End node
|
||||||
self.end_node_map[end_node_id] = StreamOutputConfig(
|
self.end_node_map[end_node_id] = StreamOutputConfig(
|
||||||
|
id=end_node_id,
|
||||||
# If there is no upstream branch, output is active immediately
|
# If there is no upstream branch, output is active immediately
|
||||||
activate=not has_branch,
|
activate=activate,
|
||||||
|
|
||||||
# Branch nodes that control activation of this End node
|
# Branch nodes that control activation of this End node
|
||||||
control_nodes=self._merge_control_nodes(control_nodes),
|
control_nodes=self._merge_control_nodes(upstream_control_nodes),
|
||||||
|
upstream_output_nodes=list(upstream_output_nodes),
|
||||||
|
control_resolved=not bool(upstream_control_nodes),
|
||||||
|
output_resolved=not bool(upstream_output_nodes),
|
||||||
|
|
||||||
# Convert output segments into OutputContent objects
|
# Convert output segments into OutputContent objects
|
||||||
outputs=list(
|
outputs=list(
|
||||||
@@ -249,14 +254,16 @@ class GraphBuilder:
|
|||||||
cursor=0
|
cursor=0
|
||||||
)
|
)
|
||||||
logger.info(f"[Stream Analysis] end_id: {end_node_id}, "
|
logger.info(f"[Stream Analysis] end_id: {end_node_id}, "
|
||||||
f"activate: {not has_branch}, "
|
f"activate: {activate}, "
|
||||||
f"control_nodes: {control_nodes},"
|
f"control_nodes: {upstream_control_nodes},"
|
||||||
|
f"ref_outputs: {upstream_output_nodes},"
|
||||||
f"output: {output_template},"
|
f"output: {output_template},"
|
||||||
f"output_activate: {output_flag}")
|
f"output_activate: {output_flag}")
|
||||||
|
|
||||||
# Non-stream mode: all outputs are activated by default
|
# Non-stream mode: all outputs are activated by default
|
||||||
else:
|
else:
|
||||||
self.end_node_map[end_node_id] = StreamOutputConfig(
|
self.end_node_map[end_node_id] = StreamOutputConfig(
|
||||||
|
id=end_node_id,
|
||||||
activate=True,
|
activate=True,
|
||||||
control_nodes={},
|
control_nodes={},
|
||||||
outputs=list(
|
outputs=list(
|
||||||
@@ -269,7 +276,10 @@ class GraphBuilder:
|
|||||||
for output_string, activate in zip(output_template, output_flag)
|
for output_string, activate in zip(output_template, output_flag)
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
cursor=0
|
cursor=0,
|
||||||
|
upstream_output_nodes=[],
|
||||||
|
control_resolved=True,
|
||||||
|
output_resolved=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_nodes(self):
|
def add_nodes(self):
|
||||||
@@ -304,8 +314,6 @@ class GraphBuilder:
|
|||||||
# Record start and end node IDs
|
# Record start and end node IDs
|
||||||
if node_type in [NodeType.START, NodeType.CYCLE_START]:
|
if node_type in [NodeType.START, NodeType.CYCLE_START]:
|
||||||
self.start_node_id = node_id
|
self.start_node_id = node_id
|
||||||
elif node_type == NodeType.END:
|
|
||||||
self.end_node_ids.append(node_id)
|
|
||||||
|
|
||||||
# Create node instance (start and end nodes are also created)
|
# Create node instance (start and end nodes are also created)
|
||||||
# NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph
|
# NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph
|
||||||
@@ -494,9 +502,11 @@ class GraphBuilder:
|
|||||||
logger.debug(f"Added waiting edge: {sources} -> {target}")
|
logger.debug(f"Added waiting edge: {sources} -> {target}")
|
||||||
|
|
||||||
# Connect End nodes to the global END node
|
# Connect End nodes to the global END node
|
||||||
for end_node_id in self.end_node_ids:
|
for end_node in self.end_nodes:
|
||||||
self.graph.add_edge(end_node_id, END)
|
end_node_id = end_node.get("id")
|
||||||
logger.debug(f"Added edge: {end_node_id} -> END")
|
if end_node_id:
|
||||||
|
self.graph.add_edge(end_node_id, END)
|
||||||
|
logger.debug(f"Added edge: {end_node_id} -> END")
|
||||||
return
|
return
|
||||||
|
|
||||||
def build(self) -> CompiledStateGraph:
|
def build(self) -> CompiledStateGraph:
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ class WorkflowResultBuilder:
|
|||||||
variable_pool: VariablePool,
|
variable_pool: VariablePool,
|
||||||
elapsed_time: float,
|
elapsed_time: float,
|
||||||
final_output: str,
|
final_output: str,
|
||||||
|
success: bool
|
||||||
):
|
):
|
||||||
"""Construct the final standardized output of the workflow execution.
|
"""Construct the final standardized output of the workflow execution.
|
||||||
|
|
||||||
@@ -29,6 +30,7 @@ class WorkflowResultBuilder:
|
|||||||
elapsed_time (float): Total execution time in seconds.
|
elapsed_time (float): Total execution time in seconds.
|
||||||
final_output (Any): The aggregated or final output content of the workflow
|
final_output (Any): The aggregated or final output content of the workflow
|
||||||
(e.g., combined messages from all End nodes).
|
(e.g., combined messages from all End nodes).
|
||||||
|
success (bool): Whether the execution was successful.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: A dictionary containing the final workflow execution result with keys:
|
dict: A dictionary containing the final workflow execution result with keys:
|
||||||
@@ -49,7 +51,7 @@ class WorkflowResultBuilder:
|
|||||||
conversation_id = variable_pool.get_value("sys.conversation_id")
|
conversation_id = variable_pool.get_value("sys.conversation_id")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "completed",
|
"status": "completed" if success else "failed",
|
||||||
"output": final_output,
|
"output": final_output,
|
||||||
"variables": {
|
"variables": {
|
||||||
"conv": variable_pool.get_all_conversation_vars(),
|
"conv": variable_pool.get_all_conversation_vars(),
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
# @Email: 1533512157@qq.com
|
# @Email: 1533512157@qq.com
|
||||||
# @Time : 2026/2/9 15:11
|
# @Time : 2026/2/9 15:11
|
||||||
import re
|
import re
|
||||||
|
from queue import Queue
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, PrivateAttr
|
from pydantic import BaseModel, Field, PrivateAttr
|
||||||
@@ -37,8 +38,8 @@ class OutputContent(BaseModel):
|
|||||||
activate: bool = Field(
|
activate: bool = Field(
|
||||||
...,
|
...,
|
||||||
description=(
|
description=(
|
||||||
"Whether this output segment is currently active.\n"
|
"Whether this output segment is currently active."
|
||||||
"- True: allowed to be emitted/output\n"
|
"- True: allowed to be emitted/output"
|
||||||
"- False: blocked until activated by branch control"
|
"- False: blocked until activated by branch control"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -46,8 +47,8 @@ class OutputContent(BaseModel):
|
|||||||
is_variable: bool = Field(
|
is_variable: bool = Field(
|
||||||
...,
|
...,
|
||||||
description=(
|
description=(
|
||||||
"Whether this segment represents a variable placeholder.\n"
|
"Whether this segment represents a variable placeholder."
|
||||||
"True -> variable (e.g. {{ node.field }})\n"
|
"True -> variable (e.g. {{ node.field }})"
|
||||||
"False -> literal text"
|
"False -> literal text"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -86,12 +87,16 @@ class StreamOutputConfig(BaseModel):
|
|||||||
- which upstream branch/control nodes gate the activation
|
- which upstream branch/control nodes gate the activation
|
||||||
- how each parsed output segment is streamed and activated
|
- how each parsed output segment is streamed and activated
|
||||||
"""
|
"""
|
||||||
|
id: str = Field(
|
||||||
|
...,
|
||||||
|
description="ID of the End node this configuration belongs to."
|
||||||
|
)
|
||||||
|
|
||||||
activate: bool = Field(
|
activate: bool = Field(
|
||||||
...,
|
...,
|
||||||
description=(
|
description=(
|
||||||
"Global activation flag for the End node output.\n"
|
"Global activation flag for the End node output."
|
||||||
"When False, output segments should not be emitted even if available.\n"
|
"When False, output segments should not be emitted even if available."
|
||||||
"This flag typically becomes True once required control branch conditions "
|
"This flag typically becomes True once required control branch conditions "
|
||||||
"are satisfied."
|
"are satisfied."
|
||||||
)
|
)
|
||||||
@@ -100,17 +105,46 @@ class StreamOutputConfig(BaseModel):
|
|||||||
control_nodes: dict[str, list[str]] = Field(
|
control_nodes: dict[str, list[str]] = Field(
|
||||||
...,
|
...,
|
||||||
description=(
|
description=(
|
||||||
"Control branch conditions for this End node output.\n"
|
"Control branch conditions for this End node output."
|
||||||
"Mapping of `branch_node_id -> expected_branch_label`.\n"
|
"Mapping of `branch_node_id -> expected_branch_label`."
|
||||||
"The End node output becomes globally active when a controlling branch node "
|
"The End node output becomes globally active when a controlling branch node "
|
||||||
"reports a matching completion status."
|
"reports a matching completion status."
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
upstream_output_nodes: list[str] = Field(
|
||||||
|
...,
|
||||||
|
description=(
|
||||||
|
"Upstream output node dependencies (data flow)."
|
||||||
|
"Represents END/output nodes that this output depends on."
|
||||||
|
"These nodes provide data sources required before this output can be activated "
|
||||||
|
"or streamed."
|
||||||
|
"Used to ensure correct ordering and dependency resolution in streaming mode."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
control_resolved: bool = Field(
|
||||||
|
...,
|
||||||
|
description=(
|
||||||
|
"Whether all upstream branch control dependencies have been satisfied."
|
||||||
|
"True if no upstream branch nodes exist or the required branch "
|
||||||
|
"conditions have been met."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
output_resolved: bool = Field(
|
||||||
|
...,
|
||||||
|
description=(
|
||||||
|
"Whether all upstream output node dependencies have been completed."
|
||||||
|
"True if no upstream output nodes exist or all upstream output "
|
||||||
|
"nodes have finished their output."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
outputs: list[OutputContent] = Field(
|
outputs: list[OutputContent] = Field(
|
||||||
...,
|
...,
|
||||||
description=(
|
description=(
|
||||||
"Ordered list of output segments parsed from the output template.\n"
|
"Ordered list of output segments parsed from the output template."
|
||||||
"Each segment represents either a literal text block or a variable placeholder "
|
"Each segment represents either a literal text block or a variable placeholder "
|
||||||
"that may be activated independently."
|
"that may be activated independently."
|
||||||
)
|
)
|
||||||
@@ -119,49 +153,97 @@ class StreamOutputConfig(BaseModel):
|
|||||||
cursor: int = Field(
|
cursor: int = Field(
|
||||||
...,
|
...,
|
||||||
description=(
|
description=(
|
||||||
"Streaming cursor index.\n"
|
"Streaming cursor index."
|
||||||
"Indicates the next output segment index to be emitted.\n"
|
"Indicates the next output segment index to be emitted."
|
||||||
"Segments with index < cursor are considered already streamed."
|
"Segments with index < cursor are considered already streamed."
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
force: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description=(
|
||||||
|
"Force flag for output emission."
|
||||||
|
"When True, all output segments are emitted regardless of activation state."
|
||||||
|
"Triggered when this output node has finished execution."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def update_activate(self, scope: str, status=None):
|
def update_activate(self, scope: str, status=None):
|
||||||
"""
|
"""
|
||||||
Update streaming activation state based on an upstream node or special variable.
|
Update streaming activation state based on upstream events.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
scope (str):
|
scope (str):
|
||||||
Identifier of the completed upstream entity.
|
Identifier of the completed upstream entity.
|
||||||
- If a control branch node, it should match a key in `control_nodes`.
|
- If a control branch node, it should match a key in `control_nodes`.
|
||||||
- If a variable placeholder (e.g., "sys.xxx"), it may appear in output segments.
|
- If an upstream output node, it should match an entry in `upstream_output_nodes`.
|
||||||
|
- If a variable placeholder (e.g., "sys.xxx" or "node_id.field"),
|
||||||
|
it may appear in output segments.
|
||||||
|
|
||||||
status (optional):
|
status (optional):
|
||||||
Completion status of the control branch node.
|
Completion status of the control branch node.
|
||||||
Required when `scope` refers to a control node.
|
Required when `scope` refers to a control node.
|
||||||
|
|
||||||
Behavior:
|
Behavior:
|
||||||
1. Control branch nodes:
|
1. Force activation:
|
||||||
- If `scope` matches a key in `control_nodes` and `status` matches the expected
|
- If `self.force` is True, the method returns immediately.
|
||||||
branch label, the End node output becomes globally active (`activate = True`).
|
- If `scope == self.id`, the node marks itself as completed:
|
||||||
|
- `activate = True`
|
||||||
|
- `force = True`
|
||||||
|
This is typically used for final flushing when the node finishes execution.
|
||||||
|
|
||||||
2. Variable output segments:
|
2. Control dependency resolution:
|
||||||
- For each segment that is a variable (`is_variable=True`):
|
- If `scope` matches a key in `control_nodes`:
|
||||||
- If the segment literal references `scope`, mark the segment as active.
|
- `status` must be provided.
|
||||||
- This applies both to regular node variables (e.g., "node_id.field")
|
- If `status` matches expected branch labels, mark control as resolved
|
||||||
and special system variables (e.g., "sys.xxx").
|
(`control_resolved = True`).
|
||||||
|
|
||||||
|
3. Upstream output dependency resolution:
|
||||||
|
- If `scope` is in `upstream_output_nodes`,
|
||||||
|
mark data dependency as resolved (`output_resolved = True`).
|
||||||
|
|
||||||
|
4. Global activation condition:
|
||||||
|
- The node becomes active when BOTH conditions are satisfied:
|
||||||
|
- control_resolved == True
|
||||||
|
- output_resolved == True
|
||||||
|
- Once activated, `activate` remains True.
|
||||||
|
|
||||||
|
5. Variable segment activation:
|
||||||
|
- For each output segment that is a variable (`is_variable=True`):
|
||||||
|
- If the segment depends on the given `scope`,
|
||||||
|
mark the segment as active.
|
||||||
|
- This applies to both node variables (e.g., "node_id.field")
|
||||||
|
and system variables (e.g., "sys.xxx").
|
||||||
|
|
||||||
Notes:
|
Notes:
|
||||||
- This method does not emit output or advance the streaming cursor.
|
- This method does NOT emit output or advance the streaming cursor.
|
||||||
- It only updates activation flags based on upstream events or special variables.
|
- It only updates activation and dependency resolution states.
|
||||||
|
- Activation is driven by both control flow (branch nodes) and
|
||||||
|
data flow (upstream output nodes).
|
||||||
"""
|
"""
|
||||||
|
if self.force:
|
||||||
|
return
|
||||||
|
|
||||||
# Case 1: resolve control branch dependency
|
if scope == self.id:
|
||||||
|
self.activate = True
|
||||||
|
self.force = True
|
||||||
|
return
|
||||||
|
|
||||||
|
# resolve control branch dependency
|
||||||
if scope in self.control_nodes:
|
if scope in self.control_nodes:
|
||||||
if status is None:
|
if status is None:
|
||||||
raise RuntimeError("[Stream Output] Control node activation status not provided")
|
raise RuntimeError("[Stream Output] Control node activation status not provided")
|
||||||
if status in self.control_nodes[scope]:
|
if status in self.control_nodes[scope]:
|
||||||
self.activate = True
|
self.control_resolved = True
|
||||||
|
|
||||||
# Case 2: activate variable segments related to this node
|
if scope in self.upstream_output_nodes:
|
||||||
|
self.upstream_output_nodes.remove(scope)
|
||||||
|
if not self.upstream_output_nodes:
|
||||||
|
self.output_resolved = True
|
||||||
|
|
||||||
|
self.activate = self.activate or (self.control_resolved and self.output_resolved)
|
||||||
|
|
||||||
|
# activate variable segments related to this node
|
||||||
for i in range(len(self.outputs)):
|
for i in range(len(self.outputs)):
|
||||||
if (
|
if (
|
||||||
self.outputs[i].is_variable
|
self.outputs[i].is_variable
|
||||||
@@ -174,12 +256,17 @@ class StreamOutputCoordinator:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.end_outputs: dict[str, StreamOutputConfig] = {}
|
self.end_outputs: dict[str, StreamOutputConfig] = {}
|
||||||
self.activate_end: str | None = None
|
self.activate_end: str | None = None
|
||||||
|
self.output_queue: Queue = Queue()
|
||||||
|
self.processed_outputs = []
|
||||||
|
|
||||||
def initialize_end_outputs(
|
def initialize_end_outputs(
|
||||||
self,
|
self,
|
||||||
end_node_map: dict[str, StreamOutputConfig]
|
end_node_map: dict[str, StreamOutputConfig]
|
||||||
):
|
):
|
||||||
self.end_outputs = end_node_map
|
self.end_outputs = end_node_map
|
||||||
|
self.processed_outputs = []
|
||||||
|
self.activate_end = None
|
||||||
|
self.output_queue = Queue()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def current_activate_end_info(self):
|
def current_activate_end_info(self):
|
||||||
@@ -211,8 +298,11 @@ class StreamOutputCoordinator:
|
|||||||
"""
|
"""
|
||||||
for node in self.end_outputs.keys():
|
for node in self.end_outputs.keys():
|
||||||
self.end_outputs[node].update_activate(scope, status)
|
self.end_outputs[node].update_activate(scope, status)
|
||||||
if self.end_outputs[node].activate and self.activate_end is None:
|
if self.end_outputs[node].activate and node not in self.processed_outputs:
|
||||||
self.activate_end = node
|
self.output_queue.put(node)
|
||||||
|
self.processed_outputs.append(node)
|
||||||
|
if self.activate_end is None and not self.output_queue.empty():
|
||||||
|
self.activate_end = self.output_queue.get_nowait()
|
||||||
|
|
||||||
async def emit_activate_chunk(
|
async def emit_activate_chunk(
|
||||||
self,
|
self,
|
||||||
@@ -256,7 +346,7 @@ class StreamOutputCoordinator:
|
|||||||
final_chunk = ''
|
final_chunk = ''
|
||||||
current_segment = end_info.outputs[end_info.cursor]
|
current_segment = end_info.outputs[end_info.cursor]
|
||||||
|
|
||||||
if not current_segment.activate and not force:
|
if not current_segment.activate and not force and not end_info.force:
|
||||||
# Stop processing until this segment becomes active
|
# Stop processing until this segment becomes active
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -273,7 +363,7 @@ class StreamOutputCoordinator:
|
|||||||
logger.warning(f"[STREAM] Failed to evaluate segment: {current_segment.literal}, error: {e}")
|
logger.warning(f"[STREAM] Failed to evaluate segment: {current_segment.literal}, error: {e}")
|
||||||
|
|
||||||
if final_chunk:
|
if final_chunk:
|
||||||
logger.info(f"[STREAM] StreamOutput Node:{self.activate_end}, chunk:{final_chunk}")
|
logger.info(f"[STREAM] StreamOutput Node:{self.activate_end}, chunk_length:{len(final_chunk)}")
|
||||||
yield {
|
yield {
|
||||||
"event": "message",
|
"event": "message",
|
||||||
"data": {
|
"data": {
|
||||||
@@ -285,8 +375,7 @@ class StreamOutputCoordinator:
|
|||||||
end_info.cursor += 1
|
end_info.cursor += 1
|
||||||
|
|
||||||
if end_info.cursor >= len(end_info.outputs):
|
if end_info.cursor >= len(end_info.outputs):
|
||||||
self.end_outputs.pop(self.activate_end)
|
self.pop_current_activate_end()
|
||||||
self.activate_end = None
|
|
||||||
|
|
||||||
async def flush_remaining_chunk(
|
async def flush_remaining_chunk(
|
||||||
self,
|
self,
|
||||||
@@ -325,6 +414,8 @@ class StreamOutputCoordinator:
|
|||||||
async for msg_event in self.emit_activate_chunk(variable_pool, force=True):
|
async for msg_event in self.emit_activate_chunk(variable_pool, force=True):
|
||||||
yield msg_event
|
yield msg_event
|
||||||
|
|
||||||
|
if not self.output_queue.empty():
|
||||||
|
self.activate_end = self.output_queue.get_nowait()
|
||||||
# Move to next active End node if current one is done
|
# Move to next active End node if current one is done
|
||||||
if not self.activate_end and self.end_outputs:
|
if not self.activate_end and self.end_outputs:
|
||||||
self.activate_end = list(self.end_outputs.keys())[0]
|
self.activate_end = list(self.end_outputs.keys())[0]
|
||||||
|
|||||||
@@ -128,89 +128,100 @@ class WorkflowExecutor:
|
|||||||
- token_usage: aggregated token usage if available
|
- token_usage: aggregated token usage if available
|
||||||
- error: error message if any
|
- error: error message if any
|
||||||
"""
|
"""
|
||||||
logger.info(f"Starting workflow execution: execution_id={self.execution_context.execution_id}")
|
start = datetime.datetime.now()
|
||||||
|
async for event in self.execute_stream(input_data):
|
||||||
start_time = datetime.datetime.now()
|
if event.get("event") == "workflow_end":
|
||||||
|
return event.get("data")
|
||||||
# Execute the workflow
|
return self.result_builder.build_final_output(
|
||||||
try:
|
{"error": "Workflow execution did not end as expected"},
|
||||||
# Build the workflow graph
|
self.variable_pool,
|
||||||
graph = self.build_graph()
|
(datetime.datetime.now() - start).total_seconds(),
|
||||||
|
"",
|
||||||
# Initialize the variable pool with input data
|
success=False
|
||||||
await self.variable_initializer.initialize(
|
)
|
||||||
variable_pool=self.variable_pool,
|
# logger.info(f"Starting workflow execution: execution_id={self.execution_context.execution_id}")
|
||||||
input_data=input_data,
|
#
|
||||||
execution_context=self.execution_context
|
# start_time = datetime.datetime.now()
|
||||||
)
|
#
|
||||||
initial_state = self.state_manager.create_initial_state(
|
# # Execute the workflow
|
||||||
workflow_config=self.workflow_config,
|
# try:
|
||||||
input_data=input_data,
|
# # Build the workflow graph
|
||||||
execution_context=self.execution_context,
|
# graph = self.build_graph()
|
||||||
start_node_id=self.start_node_id
|
#
|
||||||
)
|
# # Initialize the variable pool with input data
|
||||||
|
# await self.variable_initializer.initialize(
|
||||||
result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config)
|
# variable_pool=self.variable_pool,
|
||||||
|
# input_data=input_data,
|
||||||
# Aggregate output from all End nodes
|
# execution_context=self.execution_context
|
||||||
full_content = ''
|
# )
|
||||||
for end_id in self.stream_coordinator.end_outputs.keys():
|
# initial_state = self.state_manager.create_initial_state(
|
||||||
full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False)
|
# workflow_config=self.workflow_config,
|
||||||
|
# input_data=input_data,
|
||||||
# Append messages for user and assistant
|
# execution_context=self.execution_context,
|
||||||
if input_data.get("files"):
|
# start_node_id=self.start_node_id
|
||||||
result["messages"].extend(
|
# )
|
||||||
[
|
#
|
||||||
{
|
# result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config)
|
||||||
"role": "user",
|
#
|
||||||
"content": input_data.get("message", '')
|
# # Aggregate output from all End nodes
|
||||||
},
|
# full_content = ''
|
||||||
{
|
# for end_id in self.stream_coordinator.end_outputs.keys():
|
||||||
"role": "user",
|
# full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False)
|
||||||
"content": input_data.get("files")
|
#
|
||||||
},
|
# # Append messages for user and assistant
|
||||||
{
|
# if input_data.get("files"):
|
||||||
"role": "assistant",
|
# result["messages"].extend(
|
||||||
"content": full_content
|
# [
|
||||||
}
|
# {
|
||||||
]
|
# "role": "user",
|
||||||
)
|
# "content": input_data.get("message", '')
|
||||||
else:
|
# },
|
||||||
result["messages"].extend(
|
# {
|
||||||
[
|
# "role": "user",
|
||||||
{
|
# "content": input_data.get("files")
|
||||||
"role": "user",
|
# },
|
||||||
"content": input_data.get("message", '')
|
# {
|
||||||
},
|
# "role": "assistant",
|
||||||
{
|
# "content": full_content
|
||||||
"role": "assistant",
|
# }
|
||||||
"content": full_content
|
# ]
|
||||||
}
|
# )
|
||||||
]
|
# else:
|
||||||
)
|
# result["messages"].extend(
|
||||||
# Calculate elapsed time
|
# [
|
||||||
end_time = datetime.datetime.now()
|
# {
|
||||||
elapsed_time = (end_time - start_time).total_seconds()
|
# "role": "user",
|
||||||
|
# "content": input_data.get("message", '')
|
||||||
logger.info(
|
# },
|
||||||
f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}ms")
|
# {
|
||||||
|
# "role": "assistant",
|
||||||
return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
|
# "content": full_content
|
||||||
|
# }
|
||||||
except Exception as e:
|
# ]
|
||||||
end_time = datetime.datetime.now()
|
# )
|
||||||
elapsed_time = (end_time - start_time).total_seconds()
|
# # Calculate elapsed time
|
||||||
|
# end_time = datetime.datetime.now()
|
||||||
logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
|
# elapsed_time = (end_time - start_time).total_seconds()
|
||||||
exc_info=True)
|
#
|
||||||
return {
|
# logger.info(
|
||||||
"status": "failed",
|
# f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}ms")
|
||||||
"error": str(e),
|
#
|
||||||
"output": None,
|
# return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
|
||||||
"node_outputs": {},
|
#
|
||||||
"elapsed_time": elapsed_time,
|
# except Exception as e:
|
||||||
"token_usage": None
|
# end_time = datetime.datetime.now()
|
||||||
}
|
# elapsed_time = (end_time - start_time).total_seconds()
|
||||||
|
#
|
||||||
|
# logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
|
||||||
|
# exc_info=True)
|
||||||
|
# return {
|
||||||
|
# "status": "failed",
|
||||||
|
# "error": str(e),
|
||||||
|
# "output": None,
|
||||||
|
# "node_outputs": {},
|
||||||
|
# "elapsed_time": elapsed_time,
|
||||||
|
# "token_usage": None
|
||||||
|
# }
|
||||||
|
|
||||||
async def execute_stream(
|
async def execute_stream(
|
||||||
self,
|
self,
|
||||||
@@ -248,7 +259,8 @@ class WorkflowExecutor:
|
|||||||
"timestamp": int(start_time.timestamp() * 1000)
|
"timestamp": int(start_time.timestamp() * 1000)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
result = None
|
||||||
|
full_content = ''
|
||||||
try:
|
try:
|
||||||
# Build the workflow graph in streaming mode
|
# Build the workflow graph in streaming mode
|
||||||
graph = self.build_graph(stream=True)
|
graph = self.build_graph(stream=True)
|
||||||
@@ -266,7 +278,6 @@ class WorkflowExecutor:
|
|||||||
start_node_id=self.start_node_id
|
start_node_id=self.start_node_id
|
||||||
)
|
)
|
||||||
|
|
||||||
full_content = ''
|
|
||||||
self.stream_coordinator.update_scope_activation("sys")
|
self.stream_coordinator.update_scope_activation("sys")
|
||||||
|
|
||||||
# Execute the workflow with streaming
|
# Execute the workflow with streaming
|
||||||
@@ -363,7 +374,12 @@ class WorkflowExecutor:
|
|||||||
|
|
||||||
yield {
|
yield {
|
||||||
"event": "workflow_end",
|
"event": "workflow_end",
|
||||||
"data": self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
|
"data": self.result_builder.build_final_output(
|
||||||
|
result,
|
||||||
|
self.variable_pool,
|
||||||
|
elapsed_time,
|
||||||
|
full_content,
|
||||||
|
success=True)
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -372,16 +388,19 @@ class WorkflowExecutor:
|
|||||||
|
|
||||||
logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
|
logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
|
||||||
exc_info=True)
|
exc_info=True)
|
||||||
|
if result is None:
|
||||||
|
result = {"error": str(e)}
|
||||||
|
else:
|
||||||
|
result["error"] = str(e)
|
||||||
yield {
|
yield {
|
||||||
"event": "workflow_end",
|
"event": "workflow_end",
|
||||||
"data": {
|
"data": self.result_builder.build_final_output(
|
||||||
"execution_id": self.execution_context.execution_id,
|
result,
|
||||||
"status": "failed",
|
self.variable_pool,
|
||||||
"error": str(e),
|
elapsed_time,
|
||||||
"elapsed_time": elapsed_time,
|
full_content,
|
||||||
"timestamp": end_time.isoformat()
|
success=False
|
||||||
}
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -128,7 +128,7 @@ class CodeNode(BaseNode):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported language: {self.typed_config.language}")
|
raise ValueError(f"Unsupported language: {self.typed_config.language}")
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient(timeout=60) as client:
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
"http://sandbox:8194/v1/sandbox/run",
|
"http://sandbox:8194/v1/sandbox/run",
|
||||||
headers={
|
headers={
|
||||||
|
|||||||
@@ -33,11 +33,11 @@ class IfElseNode(BaseNode):
|
|||||||
"right": expression.right
|
"right": expression.right
|
||||||
if expression.input_type == ValueInputType.CONSTANT
|
if expression.input_type == ValueInputType.CONSTANT
|
||||||
else self.get_variable(expression.right, variable_pool, strict=False),
|
else self.get_variable(expression.right, variable_pool, strict=False),
|
||||||
"operator": expression.operator,
|
"operator": str(expression.operator),
|
||||||
})
|
})
|
||||||
result.append({
|
result.append({
|
||||||
"expressions": expressions,
|
"expressions": expressions,
|
||||||
"logical_operator": case.logical_operator,
|
"logical_operator": str(case.logical_operator),
|
||||||
})
|
})
|
||||||
return {
|
return {
|
||||||
"cases": result
|
"cases": result
|
||||||
|
|||||||
@@ -170,7 +170,7 @@ class WorkflowValidator:
|
|||||||
# 仅在发布时验证所有节点可达
|
# 仅在发布时验证所有节点可达
|
||||||
# 6. 验证所有节点可达(从 start 节点出发)
|
# 6. 验证所有节点可达(从 start 节点出发)
|
||||||
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
|
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
|
||||||
reachable = WorkflowValidator._get_reachable_nodes(
|
reachable = WorkflowValidator.get_reachable_nodes(
|
||||||
start_nodes[0]["id"],
|
start_nodes[0]["id"],
|
||||||
edges
|
edges
|
||||||
)
|
)
|
||||||
@@ -194,7 +194,7 @@ class WorkflowValidator:
|
|||||||
return len(errors) == 0, errors
|
return len(errors) == 0, errors
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]:
|
def get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]:
|
||||||
"""获取从 start 节点可达的所有节点
|
"""获取从 start 节点可达的所有节点
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from enum import StrEnum
|
|||||||
from abc import abstractmethod, ABC
|
from abc import abstractmethod, ABC
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, PrivateAttr
|
||||||
|
|
||||||
from app.schemas import FileType
|
from app.schemas import FileType
|
||||||
|
|
||||||
@@ -41,10 +41,10 @@ class VariableType(StrEnum):
|
|||||||
"""
|
"""
|
||||||
if isinstance(var, str):
|
if isinstance(var, str):
|
||||||
return cls.STRING
|
return cls.STRING
|
||||||
elif isinstance(var, (int, float)):
|
|
||||||
return cls.NUMBER
|
|
||||||
elif isinstance(var, bool):
|
elif isinstance(var, bool):
|
||||||
return cls.BOOLEAN
|
return cls.BOOLEAN
|
||||||
|
elif isinstance(var, (int, float)):
|
||||||
|
return cls.NUMBER
|
||||||
elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('is_file')):
|
elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('is_file')):
|
||||||
return cls.FILE
|
return cls.FILE
|
||||||
elif isinstance(var, dict):
|
elif isinstance(var, dict):
|
||||||
@@ -116,7 +116,7 @@ class FileObject(BaseModel):
|
|||||||
content_cache: dict = Field(default_factory=dict)
|
content_cache: dict = Field(default_factory=dict)
|
||||||
is_file: bool
|
is_file: bool
|
||||||
|
|
||||||
_byte_content: bytes | None = None
|
_byte_content: bytes | None = PrivateAttr(default=None)
|
||||||
|
|
||||||
def get_content(self):
|
def get_content(self):
|
||||||
return self._byte_content
|
return self._byte_content
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ T = TypeVar("T", bound=BaseVariable)
|
|||||||
|
|
||||||
|
|
||||||
class StringVariable(BaseVariable):
|
class StringVariable(BaseVariable):
|
||||||
|
value: str
|
||||||
type = 'str'
|
type = 'str'
|
||||||
|
|
||||||
def valid_value(self, value) -> str:
|
def valid_value(self, value) -> str:
|
||||||
@@ -22,6 +23,7 @@ class StringVariable(BaseVariable):
|
|||||||
|
|
||||||
|
|
||||||
class NumberVariable(BaseVariable):
|
class NumberVariable(BaseVariable):
|
||||||
|
value: int | float
|
||||||
type = 'number'
|
type = 'number'
|
||||||
|
|
||||||
def valid_value(self, value) -> int | float:
|
def valid_value(self, value) -> int | float:
|
||||||
@@ -34,6 +36,7 @@ class NumberVariable(BaseVariable):
|
|||||||
|
|
||||||
|
|
||||||
class BooleanVariable(BaseVariable):
|
class BooleanVariable(BaseVariable):
|
||||||
|
value: bool
|
||||||
type = 'boolean'
|
type = 'boolean'
|
||||||
|
|
||||||
def valid_value(self, value) -> bool:
|
def valid_value(self, value) -> bool:
|
||||||
@@ -46,6 +49,7 @@ class BooleanVariable(BaseVariable):
|
|||||||
|
|
||||||
|
|
||||||
class DictVariable(BaseVariable):
|
class DictVariable(BaseVariable):
|
||||||
|
value: dict
|
||||||
type = 'object'
|
type = 'object'
|
||||||
|
|
||||||
def valid_value(self, value) -> dict:
|
def valid_value(self, value) -> dict:
|
||||||
@@ -58,6 +62,7 @@ class DictVariable(BaseVariable):
|
|||||||
|
|
||||||
|
|
||||||
class FileVariable(BaseVariable):
|
class FileVariable(BaseVariable):
|
||||||
|
value: FileObject
|
||||||
type = 'file'
|
type = 'file'
|
||||||
|
|
||||||
def valid_value(self, value) -> FileObject:
|
def valid_value(self, value) -> FileObject:
|
||||||
@@ -102,6 +107,7 @@ class FileVariable(BaseVariable):
|
|||||||
|
|
||||||
|
|
||||||
class ArrayVariable(BaseVariable, Generic[T]):
|
class ArrayVariable(BaseVariable, Generic[T]):
|
||||||
|
value: list[T]
|
||||||
type = 'array'
|
type = 'array'
|
||||||
|
|
||||||
def __init__(self, child_type: Type[T], value: list[Any]):
|
def __init__(self, child_type: Type[T], value: list[Any]):
|
||||||
@@ -129,6 +135,7 @@ class ArrayVariable(BaseVariable, Generic[T]):
|
|||||||
|
|
||||||
|
|
||||||
class NestedArrayVariable(BaseVariable):
|
class NestedArrayVariable(BaseVariable):
|
||||||
|
value: list[ArrayVariable]
|
||||||
type = 'array_nest'
|
type = 'array_nest'
|
||||||
|
|
||||||
def valid_value(self, value: list[T]) -> list[T]:
|
def valid_value(self, value: list[T]) -> list[T]:
|
||||||
@@ -153,6 +160,7 @@ class NestedArrayVariable(BaseVariable):
|
|||||||
category=RuntimeWarning
|
category=RuntimeWarning
|
||||||
)
|
)
|
||||||
class AnyVariable(BaseVariable):
|
class AnyVariable(BaseVariable):
|
||||||
|
value: Any
|
||||||
type = 'any'
|
type = 'any'
|
||||||
|
|
||||||
def valid_value(self, value: Any) -> Any:
|
def valid_value(self, value: Any) -> Any:
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from app.repositories.workflow_repository import (
|
|||||||
WorkflowExecutionRepository,
|
WorkflowExecutionRepository,
|
||||||
WorkflowNodeExecutionRepository
|
WorkflowNodeExecutionRepository
|
||||||
)
|
)
|
||||||
from app.schemas import DraftRunRequest, FileInput, FileType
|
from app.schemas import DraftRunRequest, FileInput
|
||||||
from app.services.conversation_service import ConversationService
|
from app.services.conversation_service import ConversationService
|
||||||
from app.services.multi_agent_service import convert_uuids_to_str
|
from app.services.multi_agent_service import convert_uuids_to_str
|
||||||
from app.services.multimodal_service import MultimodalService
|
from app.services.multimodal_service import MultimodalService
|
||||||
|
|||||||
Reference in New Issue
Block a user