refactor(workflow): add new engine and utils modules
- Add engine/ directory with core components: - graph_builder: workflow graph construction - variable_pool: variable management - state_manager: execution state tracking - event_stream_handler: event processing - stream_output_coordinator: streaming output control - result_builder: result aggregation - runtime_schema: runtime type definitions - Add utils/ directory with utilities: - expression_evaluator: safe expression evaluation - template_renderer: Jinja2 template rendering
This commit is contained in:
4
api/app/core/__init__.py
Normal file
4
api/app/core/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/9 16:24
|
||||
4
api/app/core/workflow/engine/__init__.py
Normal file
4
api/app/core/workflow/engine/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/9 16:28
|
||||
273
api/app/core/workflow/engine/event_stream_handler.py
Normal file
273
api/app/core/workflow/engine/event_stream_handler.py
Normal file
@@ -0,0 +1,273 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/10 13:33
|
||||
import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.core.logging_config import get_logger
|
||||
from app.core.workflow.engine.stream_output_coordinator import StreamOutputCoordinator
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class EventStreamHandler:
|
||||
def __init__(
|
||||
self,
|
||||
output_coordinator: StreamOutputCoordinator,
|
||||
variable_pool: VariablePool,
|
||||
execution_id: str,
|
||||
):
|
||||
self.coordinator = output_coordinator
|
||||
self.variable_pool = variable_pool
|
||||
self.execution_id = execution_id
|
||||
|
||||
def update_stream_output_status(self, activate: dict, data: dict):
|
||||
"""
|
||||
Update the stream output state of End nodes based on workflow state updates.
|
||||
|
||||
This method checks which nodes/scopes are activated and propagates
|
||||
activation to End nodes accordingly.
|
||||
|
||||
Args:
|
||||
activate (dict): Mapping of node_id -> bool indicating which nodes/scopes are activated.
|
||||
data (dict): Mapping of node_id -> node runtime data, including outputs.
|
||||
|
||||
Behavior:
|
||||
For each node in `data`:
|
||||
1. If the node is activated (`activate[node_id]` is True),
|
||||
retrieve its output status from `runtime_vars`.
|
||||
2. Call `_update_scope_activate` to propagate the activation
|
||||
to all relevant End nodes and update `self.activate_end`.
|
||||
"""
|
||||
for node_id in data.keys():
|
||||
if activate.get(node_id):
|
||||
node_output_status = self.variable_pool.get_value(f"{node_id}.output", default=None, strict=False)
|
||||
self.coordinator.update_scope_activation(node_id, status=node_output_status)
|
||||
|
||||
async def handle_updates_event(
|
||||
self,
|
||||
data: dict,
|
||||
graph: CompiledStateGraph,
|
||||
checkpoint_config: RunnableConfig
|
||||
):
|
||||
"""
|
||||
Handle workflow state update events ("updates") and stream active End node outputs.
|
||||
|
||||
Steps:
|
||||
1. Retrieve the current graph state.
|
||||
2. Extract node activation information from the state.
|
||||
3. Update the activation status of all End nodes.
|
||||
4. While there is an active End node:
|
||||
- Call _emit_active_chunks() to yield all currently active output segments.
|
||||
- After all segments are processed, update activate_end if there are remaining End nodes.
|
||||
5. Log a debug message indicating state update received.
|
||||
|
||||
Args:
|
||||
data (dict): The latest node state updates.
|
||||
graph (CompiledStateGraph): The compiled LangGraph state machine.
|
||||
checkpoint_config (RunnableConfig): Configuration for the current execution context.)
|
||||
|
||||
Yields:
|
||||
dict: Streamed output event, each chunk in the format:
|
||||
{"event": "message", "data": {"chunk": ...}}
|
||||
"""
|
||||
state = graph.get_state(config=checkpoint_config).values
|
||||
activate = state.get("activate", {})
|
||||
|
||||
self.update_stream_output_status(activate, data)
|
||||
wait = False
|
||||
while self.coordinator.activate_end and not wait:
|
||||
async for msg_event in self.coordinator.emit_activate_chunk(self.variable_pool):
|
||||
yield msg_event
|
||||
|
||||
if self.coordinator.activate_end:
|
||||
wait = True
|
||||
else:
|
||||
self.update_stream_output_status(activate, data)
|
||||
|
||||
logger.debug(f"[UPDATES] Received state update from nodes: {list(data.keys())} "
|
||||
f"- execution_id: {self.execution_id}")
|
||||
|
||||
async def handle_node_chunk_event(self, data: dict):
|
||||
"""
|
||||
Handle streaming chunk events from individual nodes ("node_chunk").
|
||||
|
||||
This method processes output segments for the currently active End node.
|
||||
If the segment depends on the provided node_id:
|
||||
- If the node has finished execution (`done=True`), advance the cursor.
|
||||
- If all segments are processed, deactivate the End node.
|
||||
- Otherwise, yield the current chunk as a streaming message.
|
||||
|
||||
Args:
|
||||
data (dict): Node chunk event data, expected keys:
|
||||
- "node_id": ID of the node producing this chunk
|
||||
- "chunk": Chunk of output text
|
||||
- "done": Boolean indicating whether the node finished producing output
|
||||
|
||||
Yields:
|
||||
dict: Streaming message event in the format:
|
||||
{"event": "message", "data": {"chunk": ...}}
|
||||
"""
|
||||
node_id = data.get("node_id")
|
||||
if self.coordinator.activate_end:
|
||||
end_info = self.coordinator.current_activate_end_info
|
||||
if not end_info or end_info.cursor >= len(end_info.outputs):
|
||||
return
|
||||
current_output = end_info.outputs[end_info.cursor]
|
||||
if current_output.is_variable and current_output.depends_on_scope(node_id):
|
||||
if data.get("done"):
|
||||
end_info.cursor += 1
|
||||
if end_info.cursor >= len(end_info.outputs):
|
||||
self.coordinator.pop_current_activate_end()
|
||||
else:
|
||||
yield {
|
||||
"event": "message",
|
||||
"data": {
|
||||
"chunk": data.get("chunk")
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def handle_node_error_event(data: dict):
|
||||
"""
|
||||
Handle node error events ("node_error") during workflow execution.
|
||||
|
||||
This method streams an error event for a node that has failed. The event
|
||||
contains the node ID, status, input data, elapsed time, and error message.
|
||||
|
||||
Args:
|
||||
data (dict): Node error event data, expected keys:
|
||||
- "node_id": ID of the node that failed
|
||||
- "input_data": The input data that caused the error
|
||||
- "elapsed_time": Execution time before the error occurred
|
||||
- "error": Error message or exception string
|
||||
|
||||
Yields:
|
||||
dict: Node error event in the format:
|
||||
{
|
||||
"event": "node_error",
|
||||
"data": {
|
||||
"node_id": str,
|
||||
"status": "failed",
|
||||
"input": ...,
|
||||
"elapsed_time": float,
|
||||
"output": None,
|
||||
"error": str
|
||||
}
|
||||
}
|
||||
"""
|
||||
node_id = data.get("node_id")
|
||||
yield {
|
||||
"event": "node_error",
|
||||
"data": {
|
||||
"node_id": node_id,
|
||||
"status": "failed",
|
||||
"input": data.get("input_data"),
|
||||
"elapsed_time": data.get("elapsed_time"),
|
||||
"output": None,
|
||||
"error": data.get("error")
|
||||
}
|
||||
}
|
||||
|
||||
async def handle_debug_event(self, data: dict, input_data: dict):
|
||||
"""
|
||||
Handle debug events ("debug") related to node execution status.
|
||||
|
||||
This method streams debug events for nodes, including when a node starts
|
||||
execution ("node_start") and when it completes execution ("node_end").
|
||||
It filters out nodes with names starting with "nop" as no-operation nodes.
|
||||
|
||||
Args:
|
||||
data (dict): Debug event data, expected keys:
|
||||
- "type": Event type ("task" for start, "task_result" for completion)
|
||||
- "payload": Node-related information, including:
|
||||
- "name": Node name / ID
|
||||
- "input": Node input data (for "task" type)
|
||||
- "result": Node execution result (for "task_result" type)
|
||||
- "timestamp": ISO timestamp string of the event
|
||||
input_data (dict): Original workflow input data (used to get conversation_id)
|
||||
|
||||
Yields:
|
||||
dict: Node debug event in one of the following formats:
|
||||
1. Node start:
|
||||
{
|
||||
"event": "node_start",
|
||||
"data": {
|
||||
"node_id": str,
|
||||
"conversation_id": str,
|
||||
"execution_id": str,
|
||||
"timestamp": int (ms)
|
||||
}
|
||||
}
|
||||
2. Node end:
|
||||
{
|
||||
"event": "node_end",
|
||||
"data": {
|
||||
"node_id": str,
|
||||
"conversation_id": str,
|
||||
"execution_id": str,
|
||||
"timestamp": int (ms),
|
||||
"input": dict,
|
||||
"output": Any,
|
||||
"elapsed_time": float
|
||||
}
|
||||
}
|
||||
"""
|
||||
event_type = data.get("type")
|
||||
payload = data.get("payload", {})
|
||||
node_name = payload.get("name")
|
||||
conversation_id = input_data.get("conversation_id")
|
||||
|
||||
# Skip no-operation nodes
|
||||
if node_name and node_name.startswith("nop"):
|
||||
return
|
||||
|
||||
if event_type == "task":
|
||||
# Node starts execution
|
||||
inputv = payload.get("input", {})
|
||||
if not inputv.get("activate", {}).get(node_name):
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"[NODE-START] Node '{node_name}' execution started - execution_id: {self.execution_id}")
|
||||
|
||||
yield {
|
||||
"event": "node_start",
|
||||
"data": {
|
||||
"node_id": node_name,
|
||||
"conversation_id": conversation_id,
|
||||
"execution_id": self.execution_id,
|
||||
"timestamp": int(datetime.datetime.fromisoformat(
|
||||
data.get("timestamp")
|
||||
).timestamp() * 1000),
|
||||
}
|
||||
}
|
||||
elif event_type == "task_result":
|
||||
# Node execution completed
|
||||
result = payload.get("result", {})
|
||||
if not result.get("activate", {}).get(node_name):
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"[NODE-END] Node '{node_name}' execution completed - execution_id: {self.execution_id}")
|
||||
|
||||
yield {
|
||||
"event": "node_end",
|
||||
"data": {
|
||||
"node_id": node_name,
|
||||
"conversation_id": conversation_id,
|
||||
"execution_id": self.execution_id,
|
||||
"timestamp": int(datetime.datetime.fromisoformat(
|
||||
data.get("timestamp")
|
||||
).timestamp() * 1000),
|
||||
"input": result.get("node_outputs", {}).get(node_name, {}).get("input"),
|
||||
"output": result.get("node_outputs", {}).get(node_name, {}).get("output"),
|
||||
"elapsed_time": result.get("node_outputs", {}).get(node_name, {}).get("elapsed_time"),
|
||||
"token_usage": result.get("node_outputs", {}).get(node_name, {}).get("token_usage")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/10 13:33
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
@@ -9,169 +13,16 @@ from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph import START, END
|
||||
from langgraph.graph.state import CompiledStateGraph, StateGraph
|
||||
from langgraph.types import Send
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.workflow.expression_evaluator import evaluate_condition
|
||||
from app.core.workflow.nodes import WorkflowState, NodeFactory
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.stream_output_coordinator import OutputContent, StreamOutputConfig
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes import NodeFactory
|
||||
from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.core.workflow.utils.expression_evaluator import evaluate_condition
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SCOPE_PATTERN = re.compile(
|
||||
r"\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\.[a-zA-Z0-9_]+\s*}}"
|
||||
)
|
||||
|
||||
|
||||
class OutputContent(BaseModel):
|
||||
"""
|
||||
Represents a single output segment of an End node.
|
||||
|
||||
An output segment can be either:
|
||||
- literal text (static string)
|
||||
- a variable placeholder (e.g. {{ node.field }})
|
||||
|
||||
Each segment has its own activation state, which is especially
|
||||
important in stream mode.
|
||||
"""
|
||||
|
||||
literal: str = Field(
|
||||
...,
|
||||
description="Raw output content. Can be literal text or a variable placeholder."
|
||||
)
|
||||
|
||||
activate: bool = Field(
|
||||
...,
|
||||
description=(
|
||||
"Whether this output segment is currently active.\n"
|
||||
"- True: allowed to be emitted/output\n"
|
||||
"- False: blocked until activated by branch control"
|
||||
)
|
||||
)
|
||||
|
||||
is_variable: bool = Field(
|
||||
...,
|
||||
description=(
|
||||
"Whether this segment represents a variable placeholder.\n"
|
||||
"True -> variable (e.g. {{ node.field }})\n"
|
||||
"False -> literal text"
|
||||
)
|
||||
)
|
||||
|
||||
_SCOPE: str | None = None
|
||||
|
||||
def get_scope(self) -> str:
|
||||
self._SCOPE = SCOPE_PATTERN.findall(self.literal)[0]
|
||||
return self._SCOPE
|
||||
|
||||
def depends_on_scope(self, scope: str) -> bool:
|
||||
"""
|
||||
Check if this segment depends on a given scope.
|
||||
|
||||
Args:
|
||||
scope (str): Node ID or special variable prefix (e.g., "sys").
|
||||
|
||||
Returns:
|
||||
bool: True if this segment references the given scope.
|
||||
"""
|
||||
if self._SCOPE:
|
||||
return self._SCOPE == scope
|
||||
return self.get_scope() == scope
|
||||
|
||||
|
||||
class StreamOutputConfig(BaseModel):
|
||||
"""
|
||||
Streaming output configuration for an End node.
|
||||
|
||||
This configuration describes how the End node output behaves in streaming mode,
|
||||
including:
|
||||
- whether output emission is globally activated
|
||||
- which upstream branch/control nodes gate the activation
|
||||
- how each parsed output segment is streamed and activated
|
||||
"""
|
||||
|
||||
activate: bool = Field(
|
||||
...,
|
||||
description=(
|
||||
"Global activation flag for the End node output.\n"
|
||||
"When False, output segments should not be emitted even if available.\n"
|
||||
"This flag typically becomes True once required control branch conditions "
|
||||
"are satisfied."
|
||||
)
|
||||
)
|
||||
|
||||
control_nodes: dict[str, list[str]] = Field(
|
||||
...,
|
||||
description=(
|
||||
"Control branch conditions for this End node output.\n"
|
||||
"Mapping of `branch_node_id -> expected_branch_label`.\n"
|
||||
"The End node output becomes globally active when a controlling branch node "
|
||||
"reports a matching completion status."
|
||||
)
|
||||
)
|
||||
|
||||
outputs: list[OutputContent] = Field(
|
||||
...,
|
||||
description=(
|
||||
"Ordered list of output segments parsed from the output template.\n"
|
||||
"Each segment represents either a literal text block or a variable placeholder "
|
||||
"that may be activated independently."
|
||||
)
|
||||
)
|
||||
|
||||
cursor: int = Field(
|
||||
...,
|
||||
description=(
|
||||
"Streaming cursor index.\n"
|
||||
"Indicates the next output segment index to be emitted.\n"
|
||||
"Segments with index < cursor are considered already streamed."
|
||||
)
|
||||
)
|
||||
|
||||
def update_activate(self, scope: str, status=None):
|
||||
"""
|
||||
Update streaming activation state based on an upstream node or special variable.
|
||||
|
||||
Args:
|
||||
scope (str):
|
||||
Identifier of the completed upstream entity.
|
||||
- If a control branch node, it should match a key in `control_nodes`.
|
||||
- If a variable placeholder (e.g., "sys.xxx"), it may appear in output segments.
|
||||
status (optional):
|
||||
Completion status of the control branch node.
|
||||
Required when `scope` refers to a control node.
|
||||
|
||||
Behavior:
|
||||
1. Control branch nodes:
|
||||
- If `scope` matches a key in `control_nodes` and `status` matches the expected
|
||||
branch label, the End node output becomes globally active (`activate = True`).
|
||||
|
||||
2. Variable output segments:
|
||||
- For each segment that is a variable (`is_variable=True`):
|
||||
- If the segment literal references `scope`, mark the segment as active.
|
||||
- This applies both to regular node variables (e.g., "node_id.field")
|
||||
and special system variables (e.g., "sys.xxx").
|
||||
|
||||
Notes:
|
||||
- This method does not emit output or advance the streaming cursor.
|
||||
- It only updates activation flags based on upstream events or special variables.
|
||||
"""
|
||||
|
||||
# Case 1: resolve control branch dependency
|
||||
if scope in self.control_nodes.keys():
|
||||
if status is None:
|
||||
raise RuntimeError("[Stream Output] Control node activation status not provided")
|
||||
if status in self.control_nodes[scope]:
|
||||
self.activate = True
|
||||
|
||||
# Case 2: activate variable segments related to this node
|
||||
for i in range(len(self.outputs)):
|
||||
if (
|
||||
self.outputs[i].is_variable
|
||||
and self.outputs[i].depends_on_scope(scope)
|
||||
):
|
||||
self.outputs[i].activate = True
|
||||
|
||||
|
||||
class GraphBuilder:
|
||||
def __init__(
|
||||
@@ -230,7 +81,7 @@ class GraphBuilder:
|
||||
raise RuntimeError(f"Node not found: Id={node_id}")
|
||||
|
||||
@staticmethod
|
||||
def _merge_control_nodes(control_nodes: list[tuple[str, str]]) -> dict[str, list]:
|
||||
def _merge_control_nodes(control_nodes: tuple[tuple[str, str]]) -> dict[str, list]:
|
||||
result = defaultdict(list)
|
||||
for node in control_nodes:
|
||||
result[node[0]].append(node[1])
|
||||
104
api/app/core/workflow/engine/result_builder.py
Normal file
104
api/app/core/workflow/engine/result_builder.py
Normal file
@@ -0,0 +1,104 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/10 13:33
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
|
||||
|
||||
class WorkflowResultBuiler:
|
||||
def build_final_output(
|
||||
self,
|
||||
result: dict,
|
||||
variable_pool: VariablePool,
|
||||
elapsed_time: float,
|
||||
final_output: str,
|
||||
):
|
||||
"""Construct the final standardized output of the workflow execution.
|
||||
|
||||
This method aggregates node outputs, token usage, conversation and system
|
||||
variables, messages, and other metadata into a consistent dictionary
|
||||
structure suitable for returning from workflow execution.
|
||||
|
||||
Args:
|
||||
result (dict): The runtime state returned by the workflow graph execution.
|
||||
Expected keys include:
|
||||
- "node_outputs" (dict): Outputs of executed nodes.
|
||||
- "messages" (list): Conversation messages exchanged during execution.
|
||||
- "error" (str, optional): Error message if any node failed.
|
||||
variable_pool (VariablePool): Variable Pool
|
||||
elapsed_time (float): Total execution time in seconds.
|
||||
final_output (Any): The aggregated or final output content of the workflow
|
||||
(e.g., combined messages from all End nodes).
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the final workflow execution result with keys:
|
||||
- "status": Execution status ("completed")
|
||||
- "output": Aggregated final output content
|
||||
- "variables": Namespace dictionary with:
|
||||
- "conv": Conversation variables
|
||||
- "sys": System variables
|
||||
- "node_outputs": Outputs from all executed nodes
|
||||
- "messages": Conversation messages exchanged
|
||||
- "conversation_id": ID of the current conversation
|
||||
- "elapsed_time": Total execution time in seconds
|
||||
- "token_usage": Aggregated token usage across nodes (if available)
|
||||
- "error": Error message if any occurred during execution
|
||||
"""
|
||||
node_outputs = result.get("node_outputs", {})
|
||||
token_usage = self.aggregate_token_usage(node_outputs)
|
||||
conversation_id = variable_pool.get_value("sys.conversation_id")
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"output": final_output,
|
||||
"variables": {
|
||||
"conv": variable_pool.get_all_conversation_vars(),
|
||||
"sys": variable_pool.get_all_system_vars()
|
||||
},
|
||||
"node_outputs": node_outputs,
|
||||
"messages": result.get("messages", []),
|
||||
"conversation_id": conversation_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
"token_usage": token_usage,
|
||||
"error": result.get("error"),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def aggregate_token_usage(node_outputs: dict) -> dict[str, int] | None:
|
||||
"""
|
||||
Aggregate token usage statistics across all nodes.
|
||||
|
||||
Args:
|
||||
node_outputs (dict): A dictionary of all node outputs.
|
||||
|
||||
Returns:
|
||||
dict | None: Aggregated token usage in the format:
|
||||
{
|
||||
"prompt_tokens": int,
|
||||
"completion_tokens": int,
|
||||
"total_tokens": int
|
||||
}
|
||||
Returns None if no token usage information is available.
|
||||
"""
|
||||
total_prompt_tokens = 0
|
||||
total_completion_tokens = 0
|
||||
total_tokens = 0
|
||||
has_token_info = False
|
||||
|
||||
for node_output in node_outputs.values():
|
||||
if isinstance(node_output, dict):
|
||||
token_usage = node_output.get("token_usage")
|
||||
if token_usage and isinstance(token_usage, dict):
|
||||
has_token_info = True
|
||||
total_prompt_tokens += token_usage.get("prompt_tokens", 0)
|
||||
total_completion_tokens += token_usage.get("completion_tokens", 0)
|
||||
total_tokens += token_usage.get("total_tokens", 0)
|
||||
|
||||
if not has_token_info:
|
||||
return None
|
||||
|
||||
return {
|
||||
"prompt_tokens": total_prompt_tokens,
|
||||
"completion_tokens": total_completion_tokens,
|
||||
"total_tokens": total_tokens
|
||||
}
|
||||
29
api/app/core/workflow/engine/runtime_schema.py
Normal file
29
api/app/core/workflow/engine/runtime_schema.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/10 13:33
|
||||
import uuid
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ExecutionContext(BaseModel):
|
||||
execution_id: str
|
||||
workspace_id: str
|
||||
user_id: str
|
||||
checkpoint_config: RunnableConfig
|
||||
|
||||
@classmethod
|
||||
def create(cls, execution_id: str, workspace_id: str, user_id: str):
|
||||
return cls(
|
||||
execution_id=execution_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id,
|
||||
checkpoint_config=RunnableConfig(
|
||||
configurable={
|
||||
"thread_id": uuid.uuid4(),
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
99
api/app/core/workflow/engine/state_manager.py
Normal file
99
api/app/core/workflow/engine/state_manager.py
Normal file
@@ -0,0 +1,99 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/10 13:33
|
||||
from typing import Annotated, Any
|
||||
|
||||
from app.core.workflow.engine.runtime_schema import ExecutionContext
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
|
||||
|
||||
def merge_activate_state(x, y):
|
||||
return {
|
||||
k: x.get(k, False) or y.get(k, False)
|
||||
for k in set(x) | set(y)
|
||||
}
|
||||
|
||||
|
||||
def merge_looping_state(x, y):
|
||||
return y if y > x else x
|
||||
|
||||
|
||||
class WorkflowState(dict):
|
||||
"""Workflow state
|
||||
|
||||
The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc.
|
||||
"""
|
||||
__required_keys__ = frozenset({
|
||||
"messages",
|
||||
"cycle_nodes",
|
||||
"looping",
|
||||
"node_outputs",
|
||||
"execution_id",
|
||||
"workspace_id",
|
||||
"user_id",
|
||||
"activate",
|
||||
})
|
||||
__optional_keys__ = frozenset({
|
||||
"error",
|
||||
"error_node",
|
||||
})
|
||||
|
||||
# List of messages (append mode)
|
||||
messages: Annotated[list[dict[str, str]], lambda x, y: y]
|
||||
|
||||
# Set of loop node IDs, used for assigning values in loop nodes
|
||||
cycle_nodes: list
|
||||
looping: Annotated[int, merge_looping_state]
|
||||
|
||||
# Node outputs (stores execution results of each node for variable references)
|
||||
# Uses a custom merge function to combine new node outputs into the existing dictionary
|
||||
node_outputs: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
|
||||
|
||||
# Execution context
|
||||
execution_id: str
|
||||
workspace_id: str
|
||||
user_id: str
|
||||
|
||||
# Error information (for error edges)
|
||||
error: str | None
|
||||
error_node: str | None
|
||||
|
||||
# node activate status
|
||||
activate: Annotated[dict[str, bool], merge_activate_state]
|
||||
|
||||
|
||||
class WorkflowStateManager:
|
||||
def create_initial_state(
|
||||
self,
|
||||
workflow_config: dict,
|
||||
input_data: dict,
|
||||
execution_context: ExecutionContext,
|
||||
start_node_id: str
|
||||
) -> WorkflowState:
|
||||
conversation_messages = input_data.get("conv_messages", [])
|
||||
|
||||
return WorkflowState(
|
||||
messages=conversation_messages,
|
||||
node_outputs={},
|
||||
execution_id=execution_context.execution_id,
|
||||
workspace_id=execution_context.workspace_id,
|
||||
user_id=execution_context.user_id,
|
||||
error=None,
|
||||
error_node=None,
|
||||
cycle_nodes=self._identify_cycle_nodes(workflow_config),
|
||||
looping=0,
|
||||
activate={
|
||||
start_node_id: True
|
||||
}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _identify_cycle_nodes(
|
||||
workflow_config: dict
|
||||
):
|
||||
return [
|
||||
node.get("id")
|
||||
for node in workflow_config.get("nodes")
|
||||
if node.get("type") in [NodeType.LOOP, NodeType.ITERATION]
|
||||
]
|
||||
328
api/app/core/workflow/engine/stream_output_coordinator.py
Normal file
328
api/app/core/workflow/engine/stream_output_coordinator.py
Normal file
@@ -0,0 +1,328 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/9 15:11
|
||||
import re
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.logging_config import get_logger
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
SCOPE_PATTERN = re.compile(
|
||||
r"\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\.[a-zA-Z0-9_]+\s*}}"
|
||||
)
|
||||
|
||||
|
||||
class OutputContent(BaseModel):
|
||||
"""
|
||||
Represents a single output segment of an End node.
|
||||
|
||||
An output segment can be either:
|
||||
- literal text (static string)
|
||||
- a variable placeholder (e.g. {{ node.field }})
|
||||
|
||||
Each segment has its own activation state, which is especially
|
||||
important in stream mode.
|
||||
"""
|
||||
|
||||
literal: str = Field(
|
||||
...,
|
||||
description="Raw output content. Can be literal text or a variable placeholder."
|
||||
)
|
||||
|
||||
activate: bool = Field(
|
||||
...,
|
||||
description=(
|
||||
"Whether this output segment is currently active.\n"
|
||||
"- True: allowed to be emitted/output\n"
|
||||
"- False: blocked until activated by branch control"
|
||||
)
|
||||
)
|
||||
|
||||
is_variable: bool = Field(
|
||||
...,
|
||||
description=(
|
||||
"Whether this segment represents a variable placeholder.\n"
|
||||
"True -> variable (e.g. {{ node.field }})\n"
|
||||
"False -> literal text"
|
||||
)
|
||||
)
|
||||
|
||||
_SCOPE: str | None = None
|
||||
|
||||
def get_scope(self) -> str:
|
||||
self._SCOPE = SCOPE_PATTERN.findall(self.literal)[0]
|
||||
return self._SCOPE
|
||||
|
||||
def depends_on_scope(self, scope: str) -> bool:
|
||||
"""
|
||||
Check if this segment depends on a given scope.
|
||||
|
||||
Args:
|
||||
scope (str): Node ID or special variable prefix (e.g., "sys").
|
||||
|
||||
Returns:
|
||||
bool: True if this segment references the given scope.
|
||||
"""
|
||||
if self._SCOPE:
|
||||
return self._SCOPE == scope
|
||||
return self.get_scope() == scope
|
||||
|
||||
|
||||
class StreamOutputConfig(BaseModel):
|
||||
"""
|
||||
Streaming output configuration for an End node.
|
||||
|
||||
This configuration describes how the End node output behaves in streaming mode,
|
||||
including:
|
||||
- whether output emission is globally activated
|
||||
- which upstream branch/control nodes gate the activation
|
||||
- how each parsed output segment is streamed and activated
|
||||
"""
|
||||
|
||||
activate: bool = Field(
|
||||
...,
|
||||
description=(
|
||||
"Global activation flag for the End node output.\n"
|
||||
"When False, output segments should not be emitted even if available.\n"
|
||||
"This flag typically becomes True once required control branch conditions "
|
||||
"are satisfied."
|
||||
)
|
||||
)
|
||||
|
||||
control_nodes: dict[str, list[str]] = Field(
|
||||
...,
|
||||
description=(
|
||||
"Control branch conditions for this End node output.\n"
|
||||
"Mapping of `branch_node_id -> expected_branch_label`.\n"
|
||||
"The End node output becomes globally active when a controlling branch node "
|
||||
"reports a matching completion status."
|
||||
)
|
||||
)
|
||||
|
||||
outputs: list[OutputContent] = Field(
|
||||
...,
|
||||
description=(
|
||||
"Ordered list of output segments parsed from the output template.\n"
|
||||
"Each segment represents either a literal text block or a variable placeholder "
|
||||
"that may be activated independently."
|
||||
)
|
||||
)
|
||||
|
||||
cursor: int = Field(
|
||||
...,
|
||||
description=(
|
||||
"Streaming cursor index.\n"
|
||||
"Indicates the next output segment index to be emitted.\n"
|
||||
"Segments with index < cursor are considered already streamed."
|
||||
)
|
||||
)
|
||||
|
||||
def update_activate(self, scope: str, status=None):
|
||||
"""
|
||||
Update streaming activation state based on an upstream node or special variable.
|
||||
|
||||
Args:
|
||||
scope (str):
|
||||
Identifier of the completed upstream entity.
|
||||
- If a control branch node, it should match a key in `control_nodes`.
|
||||
- If a variable placeholder (e.g., "sys.xxx"), it may appear in output segments.
|
||||
status (optional):
|
||||
Completion status of the control branch node.
|
||||
Required when `scope` refers to a control node.
|
||||
|
||||
Behavior:
|
||||
1. Control branch nodes:
|
||||
- If `scope` matches a key in `control_nodes` and `status` matches the expected
|
||||
branch label, the End node output becomes globally active (`activate = True`).
|
||||
|
||||
2. Variable output segments:
|
||||
- For each segment that is a variable (`is_variable=True`):
|
||||
- If the segment literal references `scope`, mark the segment as active.
|
||||
- This applies both to regular node variables (e.g., "node_id.field")
|
||||
and special system variables (e.g., "sys.xxx").
|
||||
|
||||
Notes:
|
||||
- This method does not emit output or advance the streaming cursor.
|
||||
- It only updates activation flags based on upstream events or special variables.
|
||||
"""
|
||||
|
||||
# Case 1: resolve control branch dependency
|
||||
if scope in self.control_nodes.keys():
|
||||
if status is None:
|
||||
raise RuntimeError("[Stream Output] Control node activation status not provided")
|
||||
if status in self.control_nodes[scope]:
|
||||
self.activate = True
|
||||
|
||||
# Case 2: activate variable segments related to this node
|
||||
for i in range(len(self.outputs)):
|
||||
if (
|
||||
self.outputs[i].is_variable
|
||||
and self.outputs[i].depends_on_scope(scope)
|
||||
):
|
||||
self.outputs[i].activate = True
|
||||
|
||||
|
||||
class StreamOutputCoordinator:
|
||||
def __init__(self):
|
||||
self.end_outputs: dict[str, StreamOutputConfig] = {}
|
||||
self.activate_end: str | None = None
|
||||
|
||||
def initialize_end_outputs(
|
||||
self,
|
||||
end_node_map: dict[str, StreamOutputConfig]
|
||||
):
|
||||
self.end_outputs = end_node_map
|
||||
|
||||
@property
|
||||
def current_activate_end_info(self):
|
||||
return self.end_outputs.get(self.activate_end)
|
||||
|
||||
def pop_current_activate_end(self):
|
||||
self.end_outputs.pop(self.activate_end)
|
||||
self.activate_end = None
|
||||
|
||||
def update_scope_activation(
|
||||
self,
|
||||
scope: str,
|
||||
status: str | None = None
|
||||
):
|
||||
"""
|
||||
Update the activation state of all End nodes based on a completed scope (node or variable).
|
||||
|
||||
Iterates over all End nodes in `self.end_outputs` and calls
|
||||
`update_activate` on each, which may:
|
||||
- Activate variable segments that depend on the completed node/scope.
|
||||
- Activate the entire End node output if any control conditions are met.
|
||||
|
||||
If any End node becomes active and `self.activate_end` is not yet set,
|
||||
this node will be marked as the currently active End node.
|
||||
|
||||
Args:
|
||||
scope (str): The node ID or scope that has completed execution.
|
||||
status (str | None): Optional status of the node (used for branch/control nodes).
|
||||
"""
|
||||
for node in self.end_outputs.keys():
|
||||
self.end_outputs[node].update_activate(scope, status)
|
||||
if self.end_outputs[node].activate and self.activate_end is None:
|
||||
self.activate_end = node
|
||||
|
||||
async def emit_activate_chunk(
|
||||
self,
|
||||
variable_pool: VariablePool,
|
||||
force: bool = False
|
||||
) -> AsyncGenerator[dict[str, str | dict], None]:
|
||||
"""
|
||||
Process and yield all currently active output segments for the currently active End node.
|
||||
|
||||
This method handles stream-mode output for an End node by iterating through its output segments
|
||||
(`OutputContent`). Only segments marked as active (`activate=True`) are processed, unless
|
||||
`force=True`, which allows all segments to be processed regardless of their activation state.
|
||||
|
||||
Behavior:
|
||||
1. Iterates from the current `cursor` position to the end of the outputs list.
|
||||
2. For each segment:
|
||||
- If the segment is literal text (`is_variable=False`), append it directly.
|
||||
- If the segment is a variable (`is_variable=True`), evaluate it using
|
||||
`evaluate_expression` with the given `node_outputs` and `variables`,
|
||||
then transform the result with `_trans_output_string`.
|
||||
3. Yield a stream event of type "message" containing the processed chunk.
|
||||
4. Move the `cursor` forward after processing each segment.
|
||||
5. When all segments have been processed, remove this End node from `end_outputs`
|
||||
and reset `activate_end` to None.
|
||||
|
||||
Args:
|
||||
variable_pool (VariablePool): Pool of variables for evaluating segment values.
|
||||
force (bool, default=False): If True, process segments even if `activate=False`.
|
||||
|
||||
Yields:
|
||||
dict: A stream event of type "message" containing the processed chunk.
|
||||
|
||||
Notes:
|
||||
- Segments that fail evaluation (ValueError) are skipped with a warning logged.
|
||||
- This method only processes the currently active End node (`self.activate_end`).
|
||||
- Use `force=True` for final emission regardless of activation state.
|
||||
"""
|
||||
end_info = self.end_outputs[self.activate_end]
|
||||
|
||||
while end_info.cursor < len(end_info.outputs):
|
||||
final_chunk = ''
|
||||
current_segment = end_info.outputs[end_info.cursor]
|
||||
|
||||
if not current_segment.activate and not force:
|
||||
# Stop processing until this segment becomes active
|
||||
break
|
||||
|
||||
# Literal segment
|
||||
if not current_segment.is_variable:
|
||||
final_chunk += current_segment.literal
|
||||
else:
|
||||
# Variable segment: evaluate and transform
|
||||
try:
|
||||
# Simulate evaluation (replace with actual logic)
|
||||
chunk = variable_pool.get_literal(current_segment.literal)
|
||||
final_chunk += chunk
|
||||
except Exception as e:
|
||||
# Log failed evaluation but continue streaming
|
||||
logger.warning(f"[STREAM] Failed to evaluate segment: {current_segment.literal}, error: {e}")
|
||||
|
||||
if final_chunk:
|
||||
logger.warning(f"[STREAM] StreamOutput Node:{self.activate_end}, chunk:{final_chunk}")
|
||||
yield {
|
||||
"event": "message",
|
||||
"data": {
|
||||
"chunk": final_chunk
|
||||
}
|
||||
}
|
||||
|
||||
# Advance cursor after processing
|
||||
end_info.cursor += 1
|
||||
|
||||
if end_info.cursor >= len(end_info.outputs):
|
||||
self.end_outputs.pop(self.activate_end)
|
||||
self.activate_end = None
|
||||
|
||||
async def flush_remaining_chunk(
|
||||
self,
|
||||
variable_pool: VariablePool
|
||||
) -> AsyncGenerator[dict[str, str | dict], None]:
|
||||
"""
|
||||
Flush and yield all remaining output segments from active End nodes.
|
||||
|
||||
This method ensures that any remaining chunks of output, which may not have
|
||||
been emitted during normal streaming due to activation conditions, are fully
|
||||
processed. It is typically called at the end of a workflow to guarantee
|
||||
that all output is delivered.
|
||||
|
||||
Behavior:
|
||||
1. Filter `end_outputs` to only keep End nodes that are still active.
|
||||
2. While there is an active End node (`self.activate_end`):
|
||||
- Call `_emit_active_chunks(force=True)` to emit all segments regardless
|
||||
of their activation state.
|
||||
- If the current End node finishes, move to the next active End node
|
||||
if any remain.
|
||||
|
||||
Yields:
|
||||
dict: Streamed output events in the format:
|
||||
{"event": "message", "data": {"chunk": ...}}
|
||||
"""
|
||||
# Keep only active End nodes
|
||||
self.end_outputs = {
|
||||
node_id: node_info
|
||||
for node_id, node_info in self.end_outputs.items()
|
||||
if node_info.activate
|
||||
}
|
||||
|
||||
if self.activate_end or self.activate_end:
|
||||
while self.activate_end:
|
||||
# Force emit all remaining chunks of the active End node
|
||||
async for msg_event in self.emit_activate_chunk(variable_pool, force=True):
|
||||
yield msg_event
|
||||
|
||||
# Move to next active End node if current one is done
|
||||
if not self.activate_end and self.end_outputs:
|
||||
self.activate_end = list(self.end_outputs.keys())[0]
|
||||
@@ -1,14 +1,7 @@
|
||||
"""
|
||||
变量池 (Variable Pool)
|
||||
|
||||
工作流执行的数据中心,管理所有变量的存储和访问。
|
||||
|
||||
变量类型:
|
||||
1. 系统变量 (sys.*) - 系统内置变量(execution_id, workspace_id, user_id, message 等)
|
||||
2. 节点输出 (node_id.*) - 节点执行结果
|
||||
3. 会话变量 (conv.*) - 会话级变量(跨多轮对话保持)
|
||||
"""
|
||||
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2025/12/15 19:50
|
||||
import logging
|
||||
import re
|
||||
from asyncio import Lock
|
||||
@@ -18,7 +11,8 @@ from typing import Any, Generic
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.engine.runtime_schema import ExecutionContext
|
||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||
from app.core.workflow.variable.variable_objects import T, create_variable_instance
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -359,3 +353,74 @@ class VariablePool:
|
||||
f" runtime_vars={len(runtime_vars)}\n"
|
||||
f")"
|
||||
)
|
||||
|
||||
|
||||
class VariablePoolInitializer:
|
||||
def __init__(self, workflow_config: dict):
|
||||
self.workflow_config = workflow_config
|
||||
|
||||
async def initialize(
|
||||
self,
|
||||
variable_pool: VariablePool,
|
||||
input_data: dict,
|
||||
execution_context: ExecutionContext
|
||||
) -> None:
|
||||
await self._init_conversation_vars(variable_pool, input_data)
|
||||
await self._init_system_vars(variable_pool, input_data, execution_context)
|
||||
|
||||
async def _init_conversation_vars(
|
||||
self,
|
||||
variable_pool: VariablePool,
|
||||
input_data: dict
|
||||
):
|
||||
init_conv_vars: list[dict] = self.workflow_config.get("variables") or []
|
||||
runtime_conv_vars: dict[str, Any] = input_data.get("conv", {})
|
||||
|
||||
for var_def in init_conv_vars:
|
||||
var_name = var_def.get("name")
|
||||
var_default = runtime_conv_vars.get(var_name, var_def.get("default"))
|
||||
var_type = var_def.get("type")
|
||||
if var_name:
|
||||
if var_default:
|
||||
var_value = var_default
|
||||
else:
|
||||
var_value = DEFAULT_VALUE(var_type)
|
||||
await variable_pool.new(
|
||||
namespace="conv",
|
||||
key=var_name,
|
||||
value=var_value,
|
||||
var_type=var_type,
|
||||
mut=True
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _init_system_vars(
|
||||
variable_pool: VariablePool,
|
||||
input_data: dict,
|
||||
context: ExecutionContext
|
||||
):
|
||||
user_message = input_data.get("message") or ""
|
||||
user_files = input_data.get("files") or []
|
||||
|
||||
input_variables = input_data.get("variables") or {}
|
||||
sys_vars = {
|
||||
"message": (user_message, VariableType.STRING),
|
||||
"conversation_id": (input_data.get("conversation_id"), VariableType.STRING),
|
||||
"execution_id": (context.execution_id, VariableType.STRING),
|
||||
"workspace_id": (context.workspace_id, VariableType.STRING),
|
||||
"user_id": (context.user_id, VariableType.STRING),
|
||||
"input_variables": (input_variables, VariableType.OBJECT),
|
||||
"files": (user_files, VariableType.ARRAY_FILE)
|
||||
}
|
||||
for key, var_def in sys_vars.items():
|
||||
value = var_def[0]
|
||||
var_type = var_def[1]
|
||||
await variable_pool.new(
|
||||
namespace='sys',
|
||||
key=key,
|
||||
value=value,
|
||||
var_type=var_type,
|
||||
mut=False
|
||||
)
|
||||
|
||||
|
||||
@@ -1,21 +1,20 @@
|
||||
"""
|
||||
工作流执行器
|
||||
|
||||
基于 LangGraph 的工作流执行引擎。
|
||||
"""
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/9 13:51
|
||||
import datetime
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.core.workflow.graph_builder import GraphBuilder, StreamOutputConfig
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.core.workflow.engine.event_stream_handler import EventStreamHandler
|
||||
from app.core.workflow.engine.graph_builder import GraphBuilder
|
||||
from app.core.workflow.engine.result_builder import WorkflowResultBuiler
|
||||
from app.core.workflow.engine.runtime_schema import ExecutionContext
|
||||
from app.core.workflow.engine.state_manager import WorkflowStateManager
|
||||
from app.core.workflow.engine.stream_output_coordinator import StreamOutputCoordinator
|
||||
from app.core.workflow.engine.variable_pool import VariablePool, VariablePoolInitializer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -30,9 +29,7 @@ class WorkflowExecutor:
|
||||
def __init__(
|
||||
self,
|
||||
workflow_config: dict[str, Any],
|
||||
execution_id: str,
|
||||
workspace_id: str,
|
||||
user_id: str,
|
||||
execution_context: ExecutionContext,
|
||||
):
|
||||
"""Initialize Workflow Executor.
|
||||
|
||||
@@ -41,13 +38,10 @@ class WorkflowExecutor:
|
||||
|
||||
Args:
|
||||
workflow_config (dict): The workflow configuration dictionary.
|
||||
execution_id (str): Unique identifier for this workflow execution.
|
||||
workspace_id (str): Workspace or project ID.
|
||||
user_id (str): User ID executing the workflow.
|
||||
execution_context (ExecutionContext): The workflow execution context
|
||||
include execution_id, workspace_id, user_id, checkpoint_config
|
||||
|
||||
Attributes:
|
||||
self.nodes (list): List of node definitions from workflow_config.
|
||||
self.edges (list): List of edge definitions from workflow_config.
|
||||
self.execution_config (dict): Optional execution parameters from workflow_config.
|
||||
self.start_node_id (str | None): ID of the Start node, set after graph build.
|
||||
self.end_outputs (dict[str, StreamOutputConfig]): End node output configs.
|
||||
@@ -57,555 +51,18 @@ class WorkflowExecutor:
|
||||
self.checkpoint_config (RunnableConfig): Config for LangGraph checkpointing.
|
||||
"""
|
||||
self.workflow_config = workflow_config
|
||||
self.execution_id = execution_id
|
||||
self.workspace_id = workspace_id
|
||||
self.user_id = user_id
|
||||
self.nodes = workflow_config.get("nodes", [])
|
||||
self.edges = workflow_config.get("edges", [])
|
||||
self.execution_context = execution_context
|
||||
self.execution_config = workflow_config.get("execution_config", {})
|
||||
|
||||
self.start_node_id = None
|
||||
self.end_outputs: dict[str, StreamOutputConfig] = {}
|
||||
self.activate_end: str | None = None
|
||||
self.start_node_id: str | None = None
|
||||
self.variable_pool: VariablePool | None = None
|
||||
|
||||
self.graph: CompiledStateGraph | None = None
|
||||
self.checkpoint_config = RunnableConfig(
|
||||
configurable={
|
||||
"thread_id": uuid.uuid4(),
|
||||
}
|
||||
)
|
||||
|
||||
async def __init_variable_pool(self, input_data: dict[str, Any]):
|
||||
"""Initialize the variable pool with system, conversation, and input variables.
|
||||
|
||||
This method populates the VariablePool instance with:
|
||||
- Conversation-level variables (`conv` namespace) from workflow config or provided values.
|
||||
- System variables (`sys` namespace) such as message, files, conversation_id, execution_id, workspace_id, user_id, and input_variables.
|
||||
|
||||
Args:
|
||||
input_data (dict): Input data for workflow execution, may contain:
|
||||
- "message": user message (str)
|
||||
- "file": list of user-uploaded files
|
||||
- "conv": existing conversation variables (dict)
|
||||
- "variables": custom variables for the Start node (dict)
|
||||
- "conversation_id": conversation identifier
|
||||
"""
|
||||
user_message = input_data.get("message") or ""
|
||||
user_files = input_data.get("files") or []
|
||||
|
||||
config_variables_list = self.workflow_config.get("variables") or []
|
||||
conv_vars = input_data.get("conv", {})
|
||||
|
||||
# Initialize conversation variables (conv namespace)
|
||||
for var_def in config_variables_list:
|
||||
var_name = var_def.get("name")
|
||||
var_default = conv_vars.get(var_name, var_def.get("default"))
|
||||
var_type = var_def.get("type")
|
||||
if var_name:
|
||||
if var_default:
|
||||
var_value = var_default
|
||||
else:
|
||||
var_value = DEFAULT_VALUE(var_type)
|
||||
await self.variable_pool.new(
|
||||
namespace="conv",
|
||||
key=var_name,
|
||||
value=var_value,
|
||||
var_type=var_type,
|
||||
mut=True
|
||||
)
|
||||
|
||||
# Initialize system variables (sys namespace)
|
||||
input_variables = input_data.get("variables") or {}
|
||||
sys_vars = {
|
||||
"message": (user_message, VariableType.STRING),
|
||||
"conversation_id": (input_data.get("conversation_id"), VariableType.STRING),
|
||||
"execution_id": (self.execution_id, VariableType.STRING),
|
||||
"workspace_id": (self.workspace_id, VariableType.STRING),
|
||||
"user_id": (self.user_id, VariableType.STRING),
|
||||
"input_variables": (input_variables, VariableType.OBJECT),
|
||||
"files": (user_files, VariableType.ARRAY_FILE)
|
||||
}
|
||||
for key, var_def in sys_vars.items():
|
||||
value = var_def[0]
|
||||
var_type = var_def[1]
|
||||
await self.variable_pool.new(
|
||||
namespace='sys',
|
||||
key=key,
|
||||
value=value,
|
||||
var_type=var_type,
|
||||
mut=False
|
||||
)
|
||||
|
||||
def _prepare_initial_state(self, input_data: dict[str, Any]) -> WorkflowState:
|
||||
"""Generate the initial workflow state for execution.
|
||||
|
||||
This method prepares the runtime state dictionary with system variables,
|
||||
conversation variables, node outputs, loop tracking, and activation flags.
|
||||
|
||||
Args:
|
||||
input_data (dict): The input payload for workflow execution.
|
||||
Expected keys:
|
||||
- "conv_messages" (list, optional): Historical conversation messages
|
||||
to include in the workflow state.
|
||||
|
||||
Returns:
|
||||
WorkflowState: A dictionary representing the initialized workflow state
|
||||
with the following keys:
|
||||
- "messages": List of conversation messages
|
||||
- "node_outputs": Empty dict to store outputs of executed nodes
|
||||
- "execution_id": Current workflow execution ID
|
||||
- "workspace_id": Current workspace ID
|
||||
- "user_id": ID of the user triggering execution
|
||||
- "error": None initially, will store error message if a node fails
|
||||
- "error_node": None initially, will store ID of node that caused error
|
||||
- "cycle_nodes": List of node IDs that are of type LOOP or ITERATION
|
||||
- "looping": Integer flag indicating loop execution state (0 = not looping)
|
||||
- "activate": Dict mapping node IDs to activation status; initially
|
||||
only the start node is active
|
||||
"""
|
||||
conversation_messages = input_data.get("conv_messages") or []
|
||||
|
||||
return {
|
||||
"messages": conversation_messages,
|
||||
"node_outputs": {},
|
||||
"execution_id": self.execution_id,
|
||||
"workspace_id": self.workspace_id,
|
||||
"user_id": self.user_id,
|
||||
"error": None,
|
||||
"error_node": None,
|
||||
"cycle_nodes": [
|
||||
node.get("id")
|
||||
for node in self.workflow_config.get("nodes")
|
||||
if node.get("type") in [NodeType.LOOP, NodeType.ITERATION]
|
||||
], # loop, iteration node id
|
||||
"looping": 0, # loop runing flag, only use in loop node,not use in main loop
|
||||
"activate": {
|
||||
self.start_node_id: True
|
||||
}
|
||||
}
|
||||
|
||||
def _build_final_output(self, result, elapsed_time, final_output):
|
||||
"""Construct the final standardized output of the workflow execution.
|
||||
|
||||
This method aggregates node outputs, token usage, conversation and system
|
||||
variables, messages, and other metadata into a consistent dictionary
|
||||
structure suitable for returning from workflow execution.
|
||||
|
||||
Args:
|
||||
result (dict): The runtime state returned by the workflow graph execution.
|
||||
Expected keys include:
|
||||
- "node_outputs" (dict): Outputs of executed nodes.
|
||||
- "messages" (list): Conversation messages exchanged during execution.
|
||||
- "error" (str, optional): Error message if any node failed.
|
||||
elapsed_time (float): Total execution time in seconds.
|
||||
final_output (Any): The aggregated or final output content of the workflow
|
||||
(e.g., combined messages from all End nodes).
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the final workflow execution result with keys:
|
||||
- "status": Execution status ("completed")
|
||||
- "output": Aggregated final output content
|
||||
- "variables": Namespace dictionary with:
|
||||
- "conv": Conversation variables
|
||||
- "sys": System variables
|
||||
- "node_outputs": Outputs from all executed nodes
|
||||
- "messages": Conversation messages exchanged
|
||||
- "conversation_id": ID of the current conversation
|
||||
- "elapsed_time": Total execution time in seconds
|
||||
- "token_usage": Aggregated token usage across nodes (if available)
|
||||
- "error": Error message if any occurred during execution
|
||||
"""
|
||||
node_outputs = result.get("node_outputs", {})
|
||||
token_usage = self._aggregate_token_usage(node_outputs)
|
||||
conversation_id = self.variable_pool.get_value("sys.conversation_id")
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"output": final_output,
|
||||
"variables": {
|
||||
"conv": self.variable_pool.get_all_conversation_vars(),
|
||||
"sys": self.variable_pool.get_all_system_vars()
|
||||
},
|
||||
"node_outputs": node_outputs,
|
||||
"messages": result.get("messages", []),
|
||||
"conversation_id": conversation_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
"token_usage": token_usage,
|
||||
"error": result.get("error"),
|
||||
}
|
||||
|
||||
def _update_scope_activate(self, scope, status=None):
|
||||
"""
|
||||
Update the activation state of all End nodes based on a completed scope (node or variable).
|
||||
|
||||
Iterates over all End nodes in `self.end_outputs` and calls
|
||||
`update_activate` on each, which may:
|
||||
- Activate variable segments that depend on the completed node/scope.
|
||||
- Activate the entire End node output if any control conditions are met.
|
||||
|
||||
If any End node becomes active and `self.activate_end` is not yet set,
|
||||
this node will be marked as the currently active End node.
|
||||
|
||||
Args:
|
||||
scope (str): The node ID or scope that has completed execution.
|
||||
status (str | None): Optional status of the node (used for branch/control nodes).
|
||||
"""
|
||||
for node in self.end_outputs.keys():
|
||||
self.end_outputs[node].update_activate(scope, status)
|
||||
if self.end_outputs[node].activate and self.activate_end is None:
|
||||
self.activate_end = node
|
||||
|
||||
def _update_stream_output_status(self, activate, data):
|
||||
"""
|
||||
Update the stream output state of End nodes based on workflow state updates.
|
||||
|
||||
This method checks which nodes/scopes are activated and propagates
|
||||
activation to End nodes accordingly.
|
||||
|
||||
Args:
|
||||
activate (dict): Mapping of node_id -> bool indicating which nodes/scopes are activated.
|
||||
data (dict): Mapping of node_id -> node runtime data, including outputs.
|
||||
|
||||
Behavior:
|
||||
For each node in `data`:
|
||||
1. If the node is activated (`activate[node_id]` is True),
|
||||
retrieve its output status from `runtime_vars`.
|
||||
2. Call `_update_scope_activate` to propagate the activation
|
||||
to all relevant End nodes and update `self.activate_end`.
|
||||
"""
|
||||
for node_id in data.keys():
|
||||
if activate.get(node_id):
|
||||
node_output_status = self.variable_pool.get_value(f"{node_id}.output", default=None, strict=False)
|
||||
self._update_scope_activate(node_id, status=node_output_status)
|
||||
|
||||
async def _emit_active_chunks(
|
||||
self,
|
||||
force=False
|
||||
):
|
||||
"""
|
||||
Process and yield all currently active output segments for the currently active End node.
|
||||
|
||||
This method handles stream-mode output for an End node by iterating through its output segments
|
||||
(`OutputContent`). Only segments marked as active (`activate=True`) are processed, unless
|
||||
`force=True`, which allows all segments to be processed regardless of their activation state.
|
||||
|
||||
Behavior:
|
||||
1. Iterates from the current `cursor` position to the end of the outputs list.
|
||||
2. For each segment:
|
||||
- If the segment is literal text (`is_variable=False`), append it directly.
|
||||
- If the segment is a variable (`is_variable=True`), evaluate it using
|
||||
`evaluate_expression` with the given `node_outputs` and `variables`,
|
||||
then transform the result with `_trans_output_string`.
|
||||
3. Yield a stream event of type "message" containing the processed chunk.
|
||||
4. Move the `cursor` forward after processing each segment.
|
||||
5. When all segments have been processed, remove this End node from `end_outputs`
|
||||
and reset `activate_end` to None.
|
||||
|
||||
Args:
|
||||
force (bool, default=False): If True, process segments even if `activate=False`.
|
||||
|
||||
Yields:
|
||||
dict: A stream event of type "message" containing the processed chunk.
|
||||
|
||||
Notes:
|
||||
- Segments that fail evaluation (ValueError) are skipped with a warning logged.
|
||||
- This method only processes the currently active End node (`self.activate_end`).
|
||||
- Use `force=True` for final emission regardless of activation state.
|
||||
"""
|
||||
|
||||
end_info = self.end_outputs[self.activate_end]
|
||||
|
||||
while end_info.cursor < len(end_info.outputs):
|
||||
final_chunk = ''
|
||||
current_segment = end_info.outputs[end_info.cursor]
|
||||
|
||||
if not current_segment.activate and not force:
|
||||
# Stop processing until this segment becomes active
|
||||
break
|
||||
|
||||
# Literal segment
|
||||
if not current_segment.is_variable:
|
||||
final_chunk += current_segment.literal
|
||||
else:
|
||||
# Variable segment: evaluate and transform
|
||||
try:
|
||||
chunk = self.variable_pool.get_literal(current_segment.literal)
|
||||
final_chunk += chunk
|
||||
except KeyError:
|
||||
# Log failed evaluation but continue streaming
|
||||
logger.warning(f"[STREAM] Failed to evaluate segment: {current_segment.literal}")
|
||||
|
||||
if final_chunk:
|
||||
logger.info(f"[STREAM] StreamOutput Node:{self.activate_end}, chunk:{final_chunk}")
|
||||
yield {
|
||||
"event": "message",
|
||||
"data": {
|
||||
"chunk": final_chunk
|
||||
}
|
||||
}
|
||||
|
||||
# Advance cursor after processing
|
||||
end_info.cursor += 1
|
||||
|
||||
# Remove End node from active tracking if all segments have been processed
|
||||
if end_info.cursor >= len(end_info.outputs):
|
||||
self.end_outputs.pop(self.activate_end)
|
||||
self.activate_end = None
|
||||
|
||||
async def _handle_updates_event(self, data):
|
||||
"""
|
||||
Handle workflow state update events ("updates") and stream active End node outputs.
|
||||
|
||||
Steps:
|
||||
1. Retrieve the current graph state.
|
||||
2. Extract node activation information from the state.
|
||||
3. Update the activation status of all End nodes.
|
||||
4. While there is an active End node:
|
||||
- Call _emit_active_chunks() to yield all currently active output segments.
|
||||
- After all segments are processed, update activate_end if there are remaining End nodes.
|
||||
5. Log a debug message indicating state update received.
|
||||
|
||||
Args:
|
||||
data (dict): The latest node state updates.
|
||||
|
||||
Yields:
|
||||
dict: Streamed output event, each chunk in the format:
|
||||
{"event": "message", "data": {"chunk": ...}}
|
||||
"""
|
||||
# Get the latest workflow state
|
||||
state = self.graph.get_state(config=self.checkpoint_config).values
|
||||
activate = state.get("activate", {})
|
||||
|
||||
# Update End node activation based on the new state
|
||||
self._update_stream_output_status(activate, data)
|
||||
wait = False
|
||||
while self.activate_end and not wait:
|
||||
async for msg_event in self._emit_active_chunks():
|
||||
yield msg_event
|
||||
|
||||
if self.activate_end:
|
||||
wait = True
|
||||
else:
|
||||
self._update_stream_output_status(activate, data)
|
||||
|
||||
logger.debug(f"[UPDATES] Received state update from nodes: {list(data.keys())} "
|
||||
f"- execution_id: {self.execution_id}")
|
||||
|
||||
async def _handle_node_chunk_event(self, data):
|
||||
"""
|
||||
Handle streaming chunk events from individual nodes ("node_chunk").
|
||||
|
||||
This method processes output segments for the currently active End node.
|
||||
If the segment depends on the provided node_id:
|
||||
- If the node has finished execution (`done=True`), advance the cursor.
|
||||
- If all segments are processed, deactivate the End node.
|
||||
- Otherwise, yield the current chunk as a streaming message.
|
||||
|
||||
Args:
|
||||
data (dict): Node chunk event data, expected keys:
|
||||
- "node_id": ID of the node producing this chunk
|
||||
- "chunk": Chunk of output text
|
||||
- "done": Boolean indicating whether the node finished producing output
|
||||
|
||||
Yields:
|
||||
dict: Streaming message event in the format:
|
||||
{"event": "message", "data": {"chunk": ...}}
|
||||
"""
|
||||
node_id = data.get("node_id")
|
||||
if self.activate_end:
|
||||
end_info = self.end_outputs.get(self.activate_end)
|
||||
if not end_info or end_info.cursor >= len(end_info.outputs):
|
||||
return
|
||||
current_output = end_info.outputs[end_info.cursor]
|
||||
if current_output.is_variable and current_output.depends_on_scope(node_id):
|
||||
if data.get("done"):
|
||||
end_info.cursor += 1
|
||||
if end_info.cursor >= len(end_info.outputs):
|
||||
self.end_outputs.pop(self.activate_end)
|
||||
self.activate_end = None
|
||||
else:
|
||||
yield {
|
||||
"event": "message",
|
||||
"data": {
|
||||
"chunk": data.get("chunk")
|
||||
}
|
||||
}
|
||||
|
||||
async def _handle_node_error_event(self, data):
|
||||
"""
|
||||
Handle node error events ("node_error") during workflow execution.
|
||||
|
||||
This method streams an error event for a node that has failed. The event
|
||||
contains the node ID, status, input data, elapsed time, and error message.
|
||||
|
||||
Args:
|
||||
data (dict): Node error event data, expected keys:
|
||||
- "node_id": ID of the node that failed
|
||||
- "input_data": The input data that caused the error
|
||||
- "elapsed_time": Execution time before the error occurred
|
||||
- "error": Error message or exception string
|
||||
|
||||
Yields:
|
||||
dict: Node error event in the format:
|
||||
{
|
||||
"event": "node_error",
|
||||
"data": {
|
||||
"node_id": str,
|
||||
"status": "failed",
|
||||
"input": ...,
|
||||
"elapsed_time": float,
|
||||
"output": None,
|
||||
"error": str
|
||||
}
|
||||
}
|
||||
"""
|
||||
node_id = data.get("node_id")
|
||||
yield {
|
||||
"event": "node_error",
|
||||
"data": {
|
||||
"node_id": node_id,
|
||||
"status": "failed",
|
||||
"input": data.get("input_data"),
|
||||
"elapsed_time": data.get("elapsed_time"),
|
||||
"output": None,
|
||||
"error": data.get("error")
|
||||
}
|
||||
}
|
||||
|
||||
async def _handle_debug_event(self, data, input_data):
|
||||
"""
|
||||
Handle debug events ("debug") related to node execution status.
|
||||
|
||||
This method streams debug events for nodes, including when a node starts
|
||||
execution ("node_start") and when it completes execution ("node_end").
|
||||
It filters out nodes with names starting with "nop" as no-operation nodes.
|
||||
|
||||
Args:
|
||||
data (dict): Debug event data, expected keys:
|
||||
- "type": Event type ("task" for start, "task_result" for completion)
|
||||
- "payload": Node-related information, including:
|
||||
- "name": Node name / ID
|
||||
- "input": Node input data (for "task" type)
|
||||
- "result": Node execution result (for "task_result" type)
|
||||
- "timestamp": ISO timestamp string of the event
|
||||
input_data (dict): Original workflow input data (used to get conversation_id)
|
||||
|
||||
Yields:
|
||||
dict: Node debug event in one of the following formats:
|
||||
1. Node start:
|
||||
{
|
||||
"event": "node_start",
|
||||
"data": {
|
||||
"node_id": str,
|
||||
"conversation_id": str,
|
||||
"execution_id": str,
|
||||
"timestamp": int (ms)
|
||||
}
|
||||
}
|
||||
2. Node end:
|
||||
{
|
||||
"event": "node_end",
|
||||
"data": {
|
||||
"node_id": str,
|
||||
"conversation_id": str,
|
||||
"execution_id": str,
|
||||
"timestamp": int (ms),
|
||||
"input": dict,
|
||||
"output": Any,
|
||||
"elapsed_time": float
|
||||
}
|
||||
}
|
||||
"""
|
||||
event_type = data.get("type")
|
||||
payload = data.get("payload", {})
|
||||
node_name = payload.get("name")
|
||||
|
||||
# Skip no-operation nodes
|
||||
if node_name and node_name.startswith("nop"):
|
||||
return
|
||||
|
||||
if event_type == "task":
|
||||
# Node starts execution
|
||||
inputv = payload.get("input", {})
|
||||
if not inputv.get("activate", {}).get(node_name):
|
||||
return
|
||||
conversation_id = input_data.get("conversation_id")
|
||||
logger.info(f"[NODE-START] Node '{node_name}' execution started - execution_id: {self.execution_id}")
|
||||
|
||||
yield {
|
||||
"event": "node_start",
|
||||
"data": {
|
||||
"node_id": node_name,
|
||||
"conversation_id": conversation_id,
|
||||
"execution_id": self.execution_id,
|
||||
"timestamp": int(datetime.datetime.fromisoformat(
|
||||
data.get("timestamp")
|
||||
).timestamp() * 1000),
|
||||
}
|
||||
}
|
||||
elif event_type == "task_result":
|
||||
# Node execution completed
|
||||
result = payload.get("result", {})
|
||||
if not result.get("activate", {}).get(node_name):
|
||||
return
|
||||
|
||||
conversation_id = input_data.get("conversation_id")
|
||||
logger.info(f"[NODE-END] Node '{node_name}' execution completed - execution_id: {self.execution_id}")
|
||||
|
||||
yield {
|
||||
"event": "node_end",
|
||||
"data": {
|
||||
"node_id": node_name,
|
||||
"conversation_id": conversation_id,
|
||||
"execution_id": self.execution_id,
|
||||
"timestamp": int(datetime.datetime.fromisoformat(
|
||||
data.get("timestamp")
|
||||
).timestamp() * 1000),
|
||||
"input": result.get("node_outputs", {}).get(node_name, {}).get("input"),
|
||||
"output": result.get("node_outputs", {}).get(node_name, {}).get("output"),
|
||||
"elapsed_time": result.get("node_outputs", {}).get(node_name, {}).get("elapsed_time"),
|
||||
"token_usage": result.get("node_outputs", {}).get(node_name, {}).get("token_usage")
|
||||
}
|
||||
}
|
||||
|
||||
async def _flush_remaining_chunk(self):
|
||||
"""
|
||||
Flush and yield all remaining output segments from active End nodes.
|
||||
|
||||
This method ensures that any remaining chunks of output, which may not have
|
||||
been emitted during normal streaming due to activation conditions, are fully
|
||||
processed. It is typically called at the end of a workflow to guarantee
|
||||
that all output is delivered.
|
||||
|
||||
Behavior:
|
||||
1. Filter `end_outputs` to only keep End nodes that are still active.
|
||||
2. While there is an active End node (`self.activate_end`):
|
||||
- Call `_emit_active_chunks(force=True)` to emit all segments regardless
|
||||
of their activation state.
|
||||
- If the current End node finishes, move to the next active End node
|
||||
if any remain.
|
||||
|
||||
Yields:
|
||||
dict: Streamed output events in the format:
|
||||
{"event": "message", "data": {"chunk": ...}}
|
||||
"""
|
||||
# Keep only active End nodes
|
||||
self.end_outputs = {
|
||||
node_id: node_info
|
||||
for node_id, node_info in self.end_outputs.items()
|
||||
if node_info.activate
|
||||
}
|
||||
|
||||
if self.end_outputs or self.activate_end:
|
||||
while self.activate_end:
|
||||
# Force emit all remaining chunks of the active End node
|
||||
async for msg_event in self._emit_active_chunks(force=True):
|
||||
yield msg_event
|
||||
|
||||
# Move to next active End node if current one is done
|
||||
if not self.activate_end and self.end_outputs:
|
||||
self.activate_end = list(self.end_outputs.keys())[0]
|
||||
self.variable_initializer = VariablePoolInitializer(workflow_config)
|
||||
self.state_manager = WorkflowStateManager()
|
||||
self.result_builder = WorkflowResultBuiler()
|
||||
self.stream_coordinator = StreamOutputCoordinator()
|
||||
self.event_handler: EventStreamHandler | None = None
|
||||
|
||||
def build_graph(self, stream=False) -> CompiledStateGraph:
|
||||
"""
|
||||
@@ -624,16 +81,22 @@ class WorkflowExecutor:
|
||||
Returns:
|
||||
CompiledStateGraph: The compiled and ready-to-run state graph.
|
||||
"""
|
||||
logger.info(f"Starting workflow graph build: execution_id={self.execution_id}")
|
||||
logger.info(f"Starting workflow graph build: execution_id={self.execution_context.execution_id}")
|
||||
builder = GraphBuilder(
|
||||
self.workflow_config,
|
||||
stream=stream,
|
||||
)
|
||||
self.start_node_id = builder.start_node_id
|
||||
self.end_outputs = builder.end_node_map
|
||||
self.variable_pool = builder.variable_pool
|
||||
self.graph = builder.build()
|
||||
logger.info(f"Workflow graph build completed: execution_id={self.execution_id}")
|
||||
|
||||
self.stream_coordinator.initialize_end_outputs(builder.end_node_map)
|
||||
self.event_handler = EventStreamHandler(
|
||||
output_coordinator=self.stream_coordinator,
|
||||
variable_pool=self.variable_pool,
|
||||
execution_id=self.execution_context.execution_id
|
||||
)
|
||||
logger.info(f"Workflow graph build completed: execution_id={self.execution_context.execution_id}")
|
||||
|
||||
return self.graph
|
||||
|
||||
@@ -665,7 +128,7 @@ class WorkflowExecutor:
|
||||
- token_usage: aggregated token usage if available
|
||||
- error: error message if any
|
||||
"""
|
||||
logger.info(f"Starting workflow execution: execution_id={self.execution_id}")
|
||||
logger.info(f"Starting workflow execution: execution_id={self.execution_context.execution_id}")
|
||||
|
||||
start_time = datetime.datetime.now()
|
||||
|
||||
@@ -673,16 +136,25 @@ class WorkflowExecutor:
|
||||
graph = self.build_graph()
|
||||
|
||||
# Initialize the variable pool with input data
|
||||
await self.__init_variable_pool(input_data)
|
||||
initial_state = self._prepare_initial_state(input_data)
|
||||
await self.variable_initializer.initialize(
|
||||
variable_pool=self.variable_pool,
|
||||
input_data=input_data,
|
||||
execution_context=self.execution_context
|
||||
)
|
||||
initial_state = self.state_manager.create_initial_state(
|
||||
workflow_config=self.workflow_config,
|
||||
input_data=input_data,
|
||||
execution_context=self.execution_context,
|
||||
start_node_id=self.start_node_id
|
||||
)
|
||||
|
||||
# Execute the workflow
|
||||
try:
|
||||
result = await graph.ainvoke(initial_state, config=self.checkpoint_config)
|
||||
result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config)
|
||||
|
||||
# Aggregate output from all End nodes
|
||||
full_content = ''
|
||||
for end_id in self.end_outputs.keys():
|
||||
for end_id in self.stream_coordinator.end_outputs.keys():
|
||||
full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False)
|
||||
|
||||
# Append messages for user and assistant
|
||||
@@ -703,15 +175,16 @@ class WorkflowExecutor:
|
||||
elapsed_time = (end_time - start_time).total_seconds()
|
||||
|
||||
logger.info(
|
||||
f"Workflow execution completed: execution_id={self.execution_id}, elapsed_time={elapsed_time:.2f}s")
|
||||
f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}s")
|
||||
|
||||
return self._build_final_output(result, elapsed_time, full_content)
|
||||
return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
|
||||
|
||||
except Exception as e:
|
||||
end_time = datetime.datetime.now()
|
||||
elapsed_time = (end_time - start_time).total_seconds()
|
||||
|
||||
logger.error(f"Workflow execution failed: execution_id={self.execution_id}, error={e}", exc_info=True)
|
||||
logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
|
||||
exc_info=True)
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
@@ -744,15 +217,15 @@ class WorkflowExecutor:
|
||||
"data": {...}
|
||||
}
|
||||
"""
|
||||
logger.info(f"Starting workflow execution (streaming): execution_id={self.execution_id}")
|
||||
logger.info(f"Starting workflow execution (streaming): execution_id={self.execution_context.execution_id}")
|
||||
|
||||
start_time = datetime.datetime.now()
|
||||
|
||||
yield {
|
||||
"event": "workflow_start",
|
||||
"data": {
|
||||
"execution_id": self.execution_id,
|
||||
"workspace_id": self.workspace_id,
|
||||
"execution_id": self.execution_context.execution_id,
|
||||
"workspace_id": self.execution_context.workspace_id,
|
||||
"conversation_id": input_data.get("conversation_id"),
|
||||
"timestamp": int(start_time.timestamp() * 1000)
|
||||
}
|
||||
@@ -762,18 +235,27 @@ class WorkflowExecutor:
|
||||
graph = self.build_graph(stream=True)
|
||||
|
||||
# Initialize the variable pool and system variables
|
||||
await self.__init_variable_pool(input_data)
|
||||
initial_state = self._prepare_initial_state(input_data)
|
||||
await self.variable_initializer.initialize(
|
||||
variable_pool=self.variable_pool,
|
||||
input_data=input_data,
|
||||
execution_context=self.execution_context
|
||||
)
|
||||
initial_state = self.state_manager.create_initial_state(
|
||||
workflow_config=self.workflow_config,
|
||||
input_data=input_data,
|
||||
execution_context=self.execution_context,
|
||||
start_node_id=self.start_node_id
|
||||
)
|
||||
|
||||
try:
|
||||
full_content = ''
|
||||
self._update_scope_activate("sys")
|
||||
self.stream_coordinator.update_scope_activation("sys")
|
||||
|
||||
# Execute the workflow with streaming
|
||||
async for event in graph.astream(
|
||||
initial_state,
|
||||
stream_mode=["updates", "debug", "custom"], # Use updates + debug + custom mode
|
||||
config=self.checkpoint_config
|
||||
config=self.execution_context.checkpoint_config
|
||||
):
|
||||
# event should be a tuple: (mode, data)
|
||||
# But let's handle both cases
|
||||
@@ -782,38 +264,42 @@ class WorkflowExecutor:
|
||||
else:
|
||||
# Unexpected format, log and skip
|
||||
logger.warning(f"[STREAM] Unexpected event format: {type(event)}, value: {event}"
|
||||
f"- execution_id: {self.execution_id}")
|
||||
f"- execution_id: {self.execution_context.execution_id}")
|
||||
continue
|
||||
|
||||
if mode == "custom":
|
||||
# Handle custom streaming events (chunks from nodes via stream writer)
|
||||
event_type = data.get("type", "node_chunk") # "message" or "node_chunk"
|
||||
if event_type == "node_chunk":
|
||||
async for msg_event in self._handle_node_chunk_event(data):
|
||||
async for msg_event in self.event_handler.handle_node_chunk_event(data):
|
||||
full_content += msg_event["data"]["chunk"]
|
||||
yield msg_event
|
||||
|
||||
elif event_type == "node_error":
|
||||
async for error_event in self._handle_node_error_event(data):
|
||||
async for error_event in self.event_handler.handle_node_error_event(data):
|
||||
yield error_event
|
||||
|
||||
elif mode == "debug":
|
||||
async for debug_event in self._handle_debug_event(data, input_data):
|
||||
async for debug_event in self.event_handler.handle_debug_event(data, input_data):
|
||||
yield debug_event
|
||||
|
||||
elif mode == "updates":
|
||||
logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())} "
|
||||
f"- execution_id: {self.execution_id}")
|
||||
async for msg_event in self._handle_updates_event(data):
|
||||
f"- execution_id: {self.execution_context.execution_id}")
|
||||
async for msg_event in self.event_handler.handle_updates_event(
|
||||
data,
|
||||
self.graph,
|
||||
self.execution_context.checkpoint_config
|
||||
):
|
||||
full_content += msg_event["data"]['chunk']
|
||||
yield msg_event
|
||||
|
||||
# Flush any remaining chunks
|
||||
async for msg_event in self._flush_remaining_chunk():
|
||||
async for msg_event in self.stream_coordinator.flush_remaining_chunk(self.variable_pool):
|
||||
full_content += msg_event["data"]['chunk']
|
||||
yield msg_event
|
||||
|
||||
result = graph.get_state(self.checkpoint_config).values
|
||||
result = graph.get_state(self.execution_context.checkpoint_config).values
|
||||
end_time = datetime.datetime.now()
|
||||
elapsed_time = (end_time - start_time).total_seconds()
|
||||
|
||||
@@ -832,24 +318,25 @@ class WorkflowExecutor:
|
||||
)
|
||||
logger.info(
|
||||
f"Workflow execution completed (streaming), "
|
||||
f"elapsed: {elapsed_time:.2f}s, execution_id: {self.execution_id}"
|
||||
f"elapsed: {elapsed_time:.2f}s, execution_id: {self.execution_context.execution_id}"
|
||||
)
|
||||
|
||||
yield {
|
||||
"event": "workflow_end",
|
||||
"data": self._build_final_output(result, elapsed_time, full_content)
|
||||
"data": self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
end_time = datetime.datetime.now()
|
||||
elapsed_time = (end_time - start_time).total_seconds()
|
||||
|
||||
logger.error(f"Workflow execution failed: execution_id={self.execution_id}, error={e}", exc_info=True)
|
||||
logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
|
||||
exc_info=True)
|
||||
|
||||
yield {
|
||||
"event": "workflow_end",
|
||||
"data": {
|
||||
"execution_id": self.execution_id,
|
||||
"execution_id": self.execution_context.execution_id,
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
"elapsed_time": elapsed_time,
|
||||
@@ -857,46 +344,6 @@ class WorkflowExecutor:
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _aggregate_token_usage(node_outputs: dict[str, Any]) -> dict[str, int] | None:
|
||||
"""
|
||||
Aggregate token usage statistics across all nodes.
|
||||
|
||||
Args:
|
||||
node_outputs (dict): A dictionary of all node outputs.
|
||||
|
||||
Returns:
|
||||
dict | None: Aggregated token usage in the format:
|
||||
{
|
||||
"prompt_tokens": int,
|
||||
"completion_tokens": int,
|
||||
"total_tokens": int
|
||||
}
|
||||
Returns None if no token usage information is available.
|
||||
"""
|
||||
total_prompt_tokens = 0
|
||||
total_completion_tokens = 0
|
||||
total_tokens = 0
|
||||
has_token_info = False
|
||||
|
||||
for node_output in node_outputs.values():
|
||||
if isinstance(node_output, dict):
|
||||
token_usage = node_output.get("token_usage")
|
||||
if token_usage and isinstance(token_usage, dict):
|
||||
has_token_info = True
|
||||
total_prompt_tokens += token_usage.get("prompt_tokens", 0)
|
||||
total_completion_tokens += token_usage.get("completion_tokens", 0)
|
||||
total_tokens += token_usage.get("total_tokens", 0)
|
||||
|
||||
if not has_token_info:
|
||||
return None
|
||||
|
||||
return {
|
||||
"prompt_tokens": total_prompt_tokens,
|
||||
"completion_tokens": total_completion_tokens,
|
||||
"total_tokens": total_tokens
|
||||
}
|
||||
|
||||
|
||||
async def execute_workflow(
|
||||
workflow_config: dict[str, Any],
|
||||
@@ -918,12 +365,15 @@ async def execute_workflow(
|
||||
Returns:
|
||||
dict: Workflow execution result.
|
||||
"""
|
||||
executor = WorkflowExecutor(
|
||||
workflow_config=workflow_config,
|
||||
execution_context = ExecutionContext.create(
|
||||
execution_id=execution_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id
|
||||
)
|
||||
executor = WorkflowExecutor(
|
||||
workflow_config=workflow_config,
|
||||
execution_context=execution_context
|
||||
)
|
||||
return await executor.execute(input_data)
|
||||
|
||||
|
||||
@@ -947,11 +397,14 @@ async def execute_workflow_stream(
|
||||
Yields:
|
||||
dict: Streaming workflow events, e.g. node start, node end, chunk messages, workflow end.
|
||||
"""
|
||||
executor = WorkflowExecutor(
|
||||
workflow_config=workflow_config,
|
||||
execution_context = ExecutionContext.create(
|
||||
execution_id=execution_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id
|
||||
)
|
||||
executor = WorkflowExecutor(
|
||||
workflow_config=workflow_config,
|
||||
execution_context=execution_context
|
||||
)
|
||||
async for event in executor.execute_stream(input_data):
|
||||
yield event
|
||||
|
||||
@@ -6,7 +6,8 @@
|
||||
|
||||
from app.core.workflow.nodes.agent import AgentNode
|
||||
from app.core.workflow.nodes.assigner import AssignerNode
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.code import CodeNode
|
||||
from app.core.workflow.nodes.end import EndNode
|
||||
from app.core.workflow.nodes.http_request import HttpRequestNode
|
||||
from app.core.workflow.nodes.if_else import IfElseNode
|
||||
@@ -14,16 +15,14 @@ from app.core.workflow.nodes.jinja_render import JinjaRenderNode
|
||||
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
|
||||
from app.core.workflow.nodes.llm import LLMNode
|
||||
from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode
|
||||
from app.core.workflow.nodes.start import StartNode
|
||||
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode
|
||||
from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
|
||||
from app.core.workflow.nodes.start import StartNode
|
||||
from app.core.workflow.nodes.tool import ToolNode
|
||||
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode
|
||||
from app.core.workflow.nodes.code import CodeNode
|
||||
|
||||
__all__ = [
|
||||
"BaseNode",
|
||||
"WorkflowState",
|
||||
"LLMNode",
|
||||
"AgentNode",
|
||||
"IfElseNode",
|
||||
|
||||
@@ -7,14 +7,16 @@ Agent 节点实现
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
from app.models import AppRelease
|
||||
from app.db import get_db
|
||||
from app.models import AppRelease
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -2,12 +2,13 @@ import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.enums import AssignmentOperator
|
||||
from app.core.workflow.nodes.operators import AssignmentOperatorInstance, AssignmentOperatorResolver
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -5,57 +5,17 @@ from functools import cached_property
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
from langgraph.config import get_stream_writer
|
||||
from typing_extensions import TypedDict, Annotated
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.enums import BRANCH_NODES
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.services.multimodal_service import PROVIDER_STRATEGIES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def merge_activate_state(x, y):
|
||||
return {
|
||||
k: x.get(k, False) or y.get(k, False)
|
||||
for k in set(x) | set(y)
|
||||
}
|
||||
|
||||
|
||||
def merge_looping_state(x, y):
|
||||
return y if y > x else x
|
||||
|
||||
|
||||
class WorkflowState(TypedDict):
|
||||
"""Workflow state
|
||||
|
||||
The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc.
|
||||
"""
|
||||
# List of messages (append mode)
|
||||
messages: Annotated[list[dict[str, str]], lambda x, y: y]
|
||||
|
||||
# Set of loop node IDs, used for assigning values in loop nodes
|
||||
cycle_nodes: list
|
||||
looping: Annotated[int, merge_looping_state]
|
||||
|
||||
# Node outputs (stores execution results of each node for variable references)
|
||||
# Uses a custom merge function to combine new node outputs into the existing dictionary
|
||||
node_outputs: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
|
||||
|
||||
# Execution context
|
||||
execution_id: str
|
||||
workspace_id: str
|
||||
user_id: str
|
||||
|
||||
# Error information (for error edges)
|
||||
error: str | None
|
||||
error_node: str | None
|
||||
|
||||
# node activate status
|
||||
activate: Annotated[dict[str, bool], merge_activate_state]
|
||||
|
||||
|
||||
class BaseNode(ABC):
|
||||
"""Base class for workflow nodes.
|
||||
|
||||
@@ -584,7 +544,7 @@ class BaseNode(ABC):
|
||||
Returns:
|
||||
The rendered string with all variables substituted.
|
||||
"""
|
||||
from app.core.workflow.template_renderer import render_template
|
||||
from app.core.workflow.utils.template_renderer import render_template
|
||||
|
||||
return render_template(
|
||||
template=template,
|
||||
@@ -611,7 +571,7 @@ class BaseNode(ABC):
|
||||
Returns:
|
||||
The boolean result of evaluating the expression.
|
||||
"""
|
||||
from app.core.workflow.expression_evaluator import evaluate_condition
|
||||
from app.core.workflow.utils.expression_evaluator import evaluate_condition
|
||||
|
||||
return evaluate_condition(
|
||||
expression=expression,
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes import BaseNode, WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes import BaseNode
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -6,13 +6,14 @@ import urllib.parse
|
||||
from string import Template
|
||||
from textwrap import dedent
|
||||
from typing import Any
|
||||
import urllib.parse
|
||||
|
||||
import httpx
|
||||
|
||||
from app.core.workflow.nodes import BaseNode, WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes import BaseNode
|
||||
from app.core.workflow.nodes.code.config import CodeNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -5,10 +5,10 @@ from typing import Any
|
||||
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.cycle_graph import IterationNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -3,12 +3,12 @@ from typing import Any
|
||||
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.core.workflow.expression_evaluator import evaluate_expression
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig
|
||||
from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator
|
||||
from app.core.workflow.nodes.operators import TypeTransformer, ConditionExpressionResolver, CompareOperatorInstance
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.core.workflow.utils.expression_evaluator import evaluate_expression
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -4,14 +4,14 @@ from typing import Any
|
||||
from langgraph.graph import StateGraph
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig, IterationNodeConfig
|
||||
from app.core.workflow.nodes.cycle_graph.iteration import IterationRuntime
|
||||
from app.core.workflow.nodes.cycle_graph.loop import LoopRuntime
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -136,7 +136,7 @@ class CycleGraphNode(BaseNode):
|
||||
2. Construct a StateGraph using GraphBuilder in subgraph mode
|
||||
3. Compile the graph for runtime execution
|
||||
"""
|
||||
from app.core.workflow.graph_builder import GraphBuilder
|
||||
from app.core.workflow.engine.graph_builder import GraphBuilder
|
||||
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
|
||||
self.child_variable_pool = VariablePool()
|
||||
builder = GraphBuilder(
|
||||
|
||||
@@ -6,9 +6,10 @@ End 节点实现
|
||||
|
||||
import logging
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -7,11 +7,12 @@ import httpx
|
||||
# import filetypes # TODO: File support (Feature)
|
||||
from httpx import AsyncClient, Response, Timeout
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType
|
||||
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ class IfElseNodeConfig(BaseNodeConfig):
|
||||
|
||||
@field_validator("cases")
|
||||
@classmethod
|
||||
def validate_case_number(cls, v, info):
|
||||
def validate_case_number(cls, v):
|
||||
if len(v) < 1:
|
||||
raise ValueError("At least one cases are required")
|
||||
return v
|
||||
|
||||
@@ -2,12 +2,13 @@ import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator
|
||||
from app.core.workflow.nodes.if_else import IfElseNodeConfig
|
||||
from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.jinja_render.config import JinjaRenderNodeConfig
|
||||
from app.core.workflow.template_renderer import TemplateRenderer
|
||||
from app.core.workflow.utils.template_renderer import TemplateRenderer
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -6,10 +6,11 @@ from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.models import RedBearRerank, RedBearModelConfig
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.db import get_db_read
|
||||
from app.models import knowledge_model, knowledgeshare_model, ModelType
|
||||
from app.repositories import knowledge_repository, knowledgeshare_repository
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""LLM 节点配置"""
|
||||
|
||||
from typing import Any
|
||||
import uuid
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
@@ -56,7 +57,7 @@ class LLMNodeConfig(BaseNodeConfig):
|
||||
2. 消息模式:使用 messages 字段(推荐)
|
||||
"""
|
||||
|
||||
model_id: str = Field(
|
||||
model_id: uuid.UUID = Field(
|
||||
...,
|
||||
description="模型配置 ID"
|
||||
)
|
||||
@@ -148,7 +149,7 @@ class LLMNodeConfig(BaseNodeConfig):
|
||||
|
||||
@field_validator("messages", "prompt")
|
||||
@classmethod
|
||||
def validate_input_mode(cls, v, info):
|
||||
def validate_input_mode(cls, v):
|
||||
"""验证输入模式:prompt 和 messages 至少有一个"""
|
||||
# 这个验证在 model_validator 中更合适
|
||||
return v
|
||||
|
||||
@@ -13,10 +13,11 @@ from langchain_core.messages import AIMessage
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.llm.config import LLMNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.db import get_db_context
|
||||
from app.models import ModelType
|
||||
from app.services.model_service import ModelConfigService
|
||||
@@ -268,7 +269,7 @@ class LLMNode(BaseNode):
|
||||
llm = await self._prepare_llm(state, variable_pool, True)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
|
||||
logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
|
||||
# logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
|
||||
|
||||
# 累积完整响应
|
||||
full_response = ""
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.db import get_db_read
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.tasks import write_message_task
|
||||
|
||||
@@ -3,9 +3,9 @@ import re
|
||||
from abc import ABC
|
||||
from typing import Union, Type, NoReturn, Any
|
||||
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.enums import ValueInputType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
|
||||
|
||||
class TypeTransformer:
|
||||
|
||||
@@ -1,19 +1,18 @@
|
||||
import os
|
||||
import logging
|
||||
|
||||
import json_repair
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import json_repair
|
||||
from jinja2 import Template
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.db import get_db_read
|
||||
from app.models import ModelType
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.db import get_db_read
|
||||
from app.models import ModelType
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
@@ -7,10 +7,11 @@ Start 节点实现
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.start.config import StartNodeConfig
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -4,16 +4,17 @@ import re
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.tool.config import ToolNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.services.tool_service import ToolService
|
||||
from app.db import get_db_read
|
||||
from app.services.tool_service import ToolService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TEMPLATE_PATTERN = re.compile(r"\{\{.*?\}\}")
|
||||
TEMPLATE_PATTERN = re.compile(r"\{\{.*?}}")
|
||||
|
||||
|
||||
class ToolNode(BaseNode):
|
||||
|
||||
@@ -2,11 +2,11 @@ import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
4
api/app/core/workflow/utils/__init__.py
Normal file
4
api/app/core/workflow/utils/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/9 16:24
|
||||
@@ -5,7 +5,6 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndefined, Undefined
|
||||
@@ -187,7 +187,7 @@ class WorkflowValidator:
|
||||
)
|
||||
|
||||
# 8. 验证变量名
|
||||
from app.core.workflow.expression_evaluator import ExpressionEvaluator
|
||||
from app.core.workflow.utils.expression_evaluator import ExpressionEvaluator
|
||||
var_errors = ExpressionEvaluator.validate_variable_names(variables)
|
||||
errors.extend(var_errors)
|
||||
|
||||
|
||||
@@ -4,8 +4,8 @@
|
||||
# @Time : 2026/2/6
|
||||
import pytest
|
||||
|
||||
from app.core.workflow.engine.variable_pool import VariablePool, VariableSelector
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool, VariableSelector
|
||||
|
||||
|
||||
# ==================== VariableSelector 测试 ====================
|
||||
|
||||
@@ -6,8 +6,8 @@ import os
|
||||
|
||||
import pytest
|
||||
|
||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
|
||||
TEST_WORKSPACE_ID = "test_workspace_id"
|
||||
TEST_USER_ID = "test_user_id"
|
||||
|
||||
@@ -4,11 +4,11 @@
|
||||
# @Time : 2026/2/6
|
||||
import pytest
|
||||
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes import StartNode
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from tests.workflow.nodes.base import (
|
||||
simple_state,
|
||||
simple_state,
|
||||
simple_vairable_pool,
|
||||
TEST_EXECUTION_ID,
|
||||
TEST_WORKSPACE_ID,
|
||||
|
||||
Reference in New Issue
Block a user