From bd8a4518790e97b00afbbaf5bc45991d59ad785b Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Wed, 4 Feb 2026 11:01:16 +0800 Subject: [PATCH] feat(workflow): enforce strong typing for runtime variables - Reduce exposed information in release workflows --- .../controllers/public_share_controller.py | 3 +- api/app/core/workflow/executor.py | 800 +++++++++++------- api/app/core/workflow/expression_evaluator.py | 171 ++-- api/app/core/workflow/graph_builder.py | 52 +- api/app/core/workflow/nodes/__init__.py | 2 - api/app/core/workflow/nodes/agent/config.py | 3 +- api/app/core/workflow/nodes/agent/node.py | 32 +- api/app/core/workflow/nodes/assigner/node.py | 48 +- api/app/core/workflow/nodes/base_config.py | 70 +- api/app/core/workflow/nodes/base_node.py | 614 +++++++------- api/app/core/workflow/nodes/breaker/node.py | 10 +- api/app/core/workflow/nodes/code/config.py | 5 +- api/app/core/workflow/nodes/code/node.py | 15 +- api/app/core/workflow/nodes/configs.py | 4 - .../core/workflow/nodes/cycle_graph/config.py | 8 +- .../workflow/nodes/cycle_graph/iteration.py | 31 +- .../core/workflow/nodes/cycle_graph/loop.py | 78 +- .../core/workflow/nodes/cycle_graph/node.py | 44 +- api/app/core/workflow/nodes/end/config.py | 3 +- api/app/core/workflow/nodes/end/node.py | 12 +- api/app/core/workflow/nodes/enums.py | 1 - .../core/workflow/nodes/http_request/node.py | 45 +- api/app/core/workflow/nodes/if_else/node.py | 19 +- .../core/workflow/nodes/jinja_render/node.py | 12 +- api/app/core/workflow/nodes/knowledge/node.py | 12 +- api/app/core/workflow/nodes/llm/config.py | 3 +- api/app/core/workflow/nodes/llm/node.py | 38 +- api/app/core/workflow/nodes/memory/node.py | 23 +- api/app/core/workflow/nodes/node_factory.py | 3 - api/app/core/workflow/nodes/operators.py | 101 ++- .../nodes/parameter_extractor/config.py | 2 +- .../nodes/parameter_extractor/node.py | 15 +- .../nodes/question_classifier/node.py | 12 +- api/app/core/workflow/nodes/start/config.py | 3 +- api/app/core/workflow/nodes/start/node.py | 61 +- api/app/core/workflow/nodes/tool/node.py | 19 +- .../core/workflow/nodes/transform/__init__.py | 6 - .../core/workflow/nodes/transform/config.py | 80 -- api/app/core/workflow/nodes/transform/node.py | 60 -- .../nodes/variable_aggregator/config.py | 6 + .../nodes/variable_aggregator/node.py | 19 +- api/app/core/workflow/template_renderer.py | 31 +- api/app/core/workflow/validator.py | 9 +- api/app/core/workflow/variable/__init__.py | 0 .../core/workflow/variable/base_variable.py | 162 ++++ .../workflow/variable/variable_objects.py | 137 +++ api/app/core/workflow/variable_pool.py | 363 ++++---- api/app/services/app_chat_service.py | 4 +- api/app/services/workflow_service.py | 44 +- sandbox/app/controllers/sandbox_controller.py | 2 +- 50 files changed, 1925 insertions(+), 1372 deletions(-) delete mode 100644 api/app/core/workflow/nodes/transform/__init__.py delete mode 100644 api/app/core/workflow/nodes/transform/config.py delete mode 100644 api/app/core/workflow/nodes/transform/node.py create mode 100644 api/app/core/workflow/variable/__init__.py create mode 100644 api/app/core/workflow/variable/base_variable.py create mode 100644 api/app/core/workflow/variable/variable_objects.py diff --git a/api/app/controllers/public_share_controller.py b/api/app/controllers/public_share_controller.py index 536dffd9..9435fc9b 100644 --- a/api/app/controllers/public_share_controller.py +++ b/api/app/controllers/public_share_controller.py @@ -587,7 +587,8 @@ async def chat( user_rag_memory_id=user_rag_memory_id, app_id=release.app_id, workspace_id=workspace_id, - release_id=release.id + release_id=release.id, + public=True ): event_type = event.get("event", "message") event_data = event.get("data", {}) diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index b7abf659..f3763955 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -11,19 +11,20 @@ from typing import Any from langchain_core.runnables import RunnableConfig from langgraph.graph.state import CompiledStateGraph -from app.core.workflow.expression_evaluator import evaluate_expression from app.core.workflow.graph_builder import GraphBuilder, StreamOutputConfig from app.core.workflow.nodes import WorkflowState -from app.core.workflow.nodes.base_config import VariableType from app.core.workflow.nodes.enums import NodeType +from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE +from app.core.workflow.variable_pool import VariablePool logger = logging.getLogger(__name__) class WorkflowExecutor: - """工作流执行器 + """Workflow Executor. - 负责将工作流配置转换为 LangGraph 并执行。 + Converts workflow configuration into a LangGraph and executes it, + supporting both synchronous and streaming execution modes. """ def __init__( @@ -31,15 +32,29 @@ class WorkflowExecutor: workflow_config: dict[str, Any], execution_id: str, workspace_id: str, - user_id: str + user_id: str, ): - """初始化执行器 + """Initialize Workflow Executor. + + Converts a workflow configuration into an executor instance that can + run the workflow in both streaming and non-streaming modes. Args: - workflow_config: 工作流配置 - execution_id: 执行 ID - workspace_id: 工作空间 ID - user_id: 用户 ID + workflow_config (dict): The workflow configuration dictionary. + execution_id (str): Unique identifier for this workflow execution. + workspace_id (str): Workspace or project ID. + user_id (str): User ID executing the workflow. + + Attributes: + self.nodes (list): List of node definitions from workflow_config. + self.edges (list): List of edge definitions from workflow_config. + self.execution_config (dict): Optional execution parameters from workflow_config. + self.start_node_id (str | None): ID of the Start node, set after graph build. + self.end_outputs (dict[str, StreamOutputConfig]): End node output configs. + self.activate_end (str | None): Currently active End node ID for streaming outputs. + self.variable_pool (VariablePool | None): Variable pool instance. + self.graph (CompiledStateGraph | None): Compiled workflow graph. + self.checkpoint_config (RunnableConfig): Config for LangGraph checkpointing. """ self.workflow_config = workflow_config self.execution_id = execution_id @@ -52,73 +67,108 @@ class WorkflowExecutor: self.start_node_id = None self.end_outputs: dict[str, StreamOutputConfig] = {} self.activate_end: str | None = None + self.variable_pool: VariablePool | None = None + self.graph: CompiledStateGraph | None = None self.checkpoint_config = RunnableConfig( configurable={ "thread_id": uuid.uuid4(), } ) - def _prepare_initial_state(self, input_data: dict[str, Any]) -> WorkflowState: - """准备初始状态(注入系统变量和会话变量) + async def __init_variable_pool(self, input_data: dict[str, Any]): + """Initialize the variable pool with system, conversation, and input variables. - 变量命名空间: - - sys.xxx - 系统变量(execution_id, workspace_id, user_id, message, input_variables 等) - - conv.xxx - 会话变量(跨多轮对话保持) - - node_id.xxx - 节点输出(执行时动态生成) + This method populates the VariablePool instance with: + - Conversation-level variables (`conv` namespace) from workflow config or provided values. + - System variables (`sys` namespace) such as message, files, conversation_id, execution_id, workspace_id, user_id, and input_variables. Args: - input_data: 输入数据 - - Returns: - 初始化的工作流状态 + input_data (dict): Input data for workflow execution, may contain: + - "message": user message (str) + - "file": list of user-uploaded files + - "conv": existing conversation variables (dict) + - "variables": custom variables for the Start node (dict) + - "conversation_id": conversation identifier """ user_message = input_data.get("message") or "" - conversation_messages = input_data.get("conv_messages") or [] + user_file = input_data.get("file") or [] - # 会话变量处理:从配置文件获取变量定义列表,转换为字典(name -> default value) config_variables_list = self.workflow_config.get("variables") or [] - conversation_vars = {} + conv_vars = input_data.get("conv", {}) + + # Initialize conversation variables (conv namespace) for var_def in config_variables_list: - if isinstance(var_def, dict): - var_name = var_def.get("name") - var_default = var_def.get("default") - if var_name: - if var_default: - conversation_vars[var_name] = var_default - else: - var_type = var_def.get("type") - match var_type: - case VariableType.STRING: - conversation_vars[var_name] = "" - case VariableType.NUMBER: - conversation_vars[var_name] = 0 - case VariableType.OBJECT: - conversation_vars[var_name] = {} - case VariableType.BOOLEAN: - conversation_vars[var_name] = False - case VariableType.ARRAY_NUMBER | VariableType.ARRAY_OBJECT | VariableType.ARRAY_BOOLEAN | VariableType.ARRAY_STRING: - conversation_vars[var_name] = [] - input_variables = input_data.get("variables") or {} # Start 节点的自定义变量 - conversation_vars = conversation_vars | input_data.get("conv", {}) - # 构建分层的变量结构 - variables = { - "sys": { - "message": user_message, # 用户消息 - "conversation_id": input_data.get("conversation_id"), # 会话 ID - "execution_id": self.execution_id, # 执行 ID - "workspace_id": self.workspace_id, # 工作空间 ID - "user_id": self.user_id, # 用户 ID - "input_variables": input_variables, # 自定义输入变量(给 Start 节点使用) - }, - "conv": conversation_vars # 会话级变量(跨多轮对话保持) + var_name = var_def.get("name") + var_default = conv_vars.get(var_name, var_def.get("default")) + var_type = var_def.get("type") + if var_name: + if var_default: + var_value = var_default + else: + var_value = DEFAULT_VALUE(var_type) + await self.variable_pool.new( + namespace="conv", + key=var_name, + value=var_value, + var_type=var_type, + mut=True + ) + + # Initialize system variables (sys namespace) + input_variables = input_data.get("variables") or {} + sys_vars = { + "message": (user_message, VariableType.STRING), + "file": (user_file, VariableType.ARRAY_FILE), + "conversation_id": (input_data.get("conversation_id"), VariableType.STRING), + "execution_id": (self.execution_id, VariableType.STRING), + "workspace_id": (self.workspace_id, VariableType.STRING), + "user_id": (self.user_id, VariableType.STRING), + "input_variables": (input_variables, VariableType.OBJECT), } + for key, var_def in sys_vars.items(): + value = var_def[0] + var_type = var_def[1] + await self.variable_pool.new( + namespace='sys', + key=key, + value=value, + var_type=var_type, + mut=False + ) + + def _prepare_initial_state(self, input_data: dict[str, Any]) -> WorkflowState: + """Generate the initial workflow state for execution. + + This method prepares the runtime state dictionary with system variables, + conversation variables, node outputs, loop tracking, and activation flags. + + Args: + input_data (dict): The input payload for workflow execution. + Expected keys: + - "conv_messages" (list, optional): Historical conversation messages + to include in the workflow state. + + Returns: + WorkflowState: A dictionary representing the initialized workflow state + with the following keys: + - "messages": List of conversation messages + - "node_outputs": Empty dict to store outputs of executed nodes + - "execution_id": Current workflow execution ID + - "workspace_id": Current workspace ID + - "user_id": ID of the user triggering execution + - "error": None initially, will store error message if a node fails + - "error_node": None initially, will store ID of node that caused error + - "cycle_nodes": List of node IDs that are of type LOOP or ITERATION + - "looping": Integer flag indicating loop execution state (0 = not looping) + - "activate": Dict mapping node IDs to activation status; initially + only the start node is active + """ + conversation_messages = input_data.get("conv_messages") or [] return { "messages": conversation_messages, - "variables": variables, "node_outputs": {}, - "runtime_vars": {}, # 运行时节点变量(简化版,供快速访问) "execution_id": self.execution_id, "workspace_id": self.workspace_id, "user_id": self.user_id, @@ -136,18 +186,47 @@ class WorkflowExecutor: } def _build_final_output(self, result, elapsed_time, final_output): + """Construct the final standardized output of the workflow execution. + + This method aggregates node outputs, token usage, conversation and system + variables, messages, and other metadata into a consistent dictionary + structure suitable for returning from workflow execution. + + Args: + result (dict): The runtime state returned by the workflow graph execution. + Expected keys include: + - "node_outputs" (dict): Outputs of executed nodes. + - "messages" (list): Conversation messages exchanged during execution. + - "error" (str, optional): Error message if any node failed. + elapsed_time (float): Total execution time in seconds. + final_output (Any): The aggregated or final output content of the workflow + (e.g., combined messages from all End nodes). + + Returns: + dict: A dictionary containing the final workflow execution result with keys: + - "status": Execution status ("completed") + - "output": Aggregated final output content + - "variables": Namespace dictionary with: + - "conv": Conversation variables + - "sys": System variables + - "node_outputs": Outputs from all executed nodes + - "messages": Conversation messages exchanged + - "conversation_id": ID of the current conversation + - "elapsed_time": Total execution time in seconds + - "token_usage": Aggregated token usage across nodes (if available) + - "error": Error message if any occurred during execution + """ node_outputs = result.get("node_outputs", {}) token_usage = self._aggregate_token_usage(node_outputs) - conversation_id = None - for node_id, node_output in node_outputs.items(): - if node_output.get("node_type") == "start": - conversation_id = node_output.get("output", {}).get("conversation_id") - break + conversation_id = self.variable_pool.get_value("sys.conversation_id") return { "status": "completed", "output": final_output, - "variables": result.get("variables", {}), + "variables": { + "conv": self.variable_pool.get_all_conversation_vars(), + "sys": self.variable_pool.get_all_system_vars() + }, "node_outputs": node_outputs, "messages": result.get("messages", []), "conversation_id": conversation_id, @@ -163,7 +242,7 @@ class WorkflowExecutor: Iterates over all End nodes in `self.end_outputs` and calls `update_activate` on each, which may: - Activate variable segments that depend on the completed node/scope. - - Activate the entire End node output if all control conditions are met. + - Activate the entire End node output if any control conditions are met. If any End node becomes active and `self.activate_end` is not yet set, this node will be marked as the currently active End node. @@ -197,18 +276,11 @@ class WorkflowExecutor: """ for node_id in data.keys(): if activate.get(node_id): - node_output_status = ( - data[node_id] - .get('runtime_vars', {}) - .get(node_id) - .get("output") - ) + node_output_status = self.variable_pool.get_value(f"{node_id}.output", default=None, strict=False) self._update_scope_activate(node_id, status=node_output_status) async def _emit_active_chunks( self, - node_outputs: dict, - variables: dict, force=False ): """ @@ -231,8 +303,6 @@ class WorkflowExecutor: and reset `activate_end` to None. Args: - node_outputs (dict): Current runtime node outputs, used for variable evaluation. - variables (dict): Current runtime variables, used for variable evaluation. force (bool, default=False): If True, process segments even if `activate=False`. Yields: @@ -260,14 +330,9 @@ class WorkflowExecutor: else: # Variable segment: evaluate and transform try: - chunk = evaluate_expression( - current_segment.literal, - variables=variables, - node_outputs=node_outputs - ) - chunk = self._trans_output_string(chunk) + chunk = self.variable_pool.get_literal(current_segment.literal) final_chunk += chunk - except ValueError: + except KeyError: # Log failed evaluation but continue streaming logger.warning(f"[STREAM] Failed to evaluate segment: {current_segment.literal}") @@ -287,63 +352,338 @@ class WorkflowExecutor: self.end_outputs.pop(self.activate_end) self.activate_end = None - @staticmethod - def _trans_output_string(content): - if isinstance(content, str): - return content - elif isinstance(content, list): - return "\n".join(content) - else: - return str(content) + async def _handle_updates_event(self, data): + """ + Handle workflow state update events ("updates") and stream active End node outputs. + + Steps: + 1. Retrieve the current graph state. + 2. Extract node activation information from the state. + 3. Update the activation status of all End nodes. + 4. While there is an active End node: + - Call _emit_active_chunks() to yield all currently active output segments. + - After all segments are processed, update activate_end if there are remaining End nodes. + 5. Log a debug message indicating state update received. + + Args: + data (dict): The latest node state updates. + + Yields: + dict: Streamed output event, each chunk in the format: + {"event": "message", "data": {"chunk": ...}} + """ + # Get the latest workflow state + state = self.graph.get_state(config=self.checkpoint_config).values + activate = state.get("activate", {}) + + # Update End node activation based on the new state + self._update_stream_output_status(activate, data) + wait = False + while self.activate_end and not wait: + async for msg_event in self._emit_active_chunks(): + yield msg_event + + if self.activate_end: + wait = True + else: + self._update_stream_output_status(activate, data) + + logger.debug(f"[UPDATES] Received state update from nodes: {list(data.keys())} " + f"- execution_id: {self.execution_id}") + + async def _handle_node_chunk_event(self, data): + """ + Handle streaming chunk events from individual nodes ("node_chunk"). + + This method processes output segments for the currently active End node. + If the segment depends on the provided node_id: + - If the node has finished execution (`done=True`), advance the cursor. + - If all segments are processed, deactivate the End node. + - Otherwise, yield the current chunk as a streaming message. + + Args: + data (dict): Node chunk event data, expected keys: + - "node_id": ID of the node producing this chunk + - "chunk": Chunk of output text + - "done": Boolean indicating whether the node finished producing output + + Yields: + dict: Streaming message event in the format: + {"event": "message", "data": {"chunk": ...}} + """ + node_id = data.get("node_id") + if self.activate_end: + end_info = self.end_outputs.get(self.activate_end) + if not end_info or end_info.cursor >= len(end_info.outputs): + return + current_output = end_info.outputs[end_info.cursor] + if current_output.is_variable and current_output.depends_on_scope(node_id): + if data.get("done"): + end_info.cursor += 1 + if end_info.cursor >= len(end_info.outputs): + self.end_outputs.pop(self.activate_end) + self.activate_end = None + else: + yield { + "event": "message", + "data": { + "chunk": data.get("chunk") + } + } + + async def _handle_node_error_event(self, data): + """ + Handle node error events ("node_error") during workflow execution. + + This method streams an error event for a node that has failed. The event + contains the node ID, status, input data, elapsed time, and error message. + + Args: + data (dict): Node error event data, expected keys: + - "node_id": ID of the node that failed + - "input_data": The input data that caused the error + - "elapsed_time": Execution time before the error occurred + - "error": Error message or exception string + + Yields: + dict: Node error event in the format: + { + "event": "node_error", + "data": { + "node_id": str, + "status": "failed", + "input": ..., + "elapsed_time": float, + "output": None, + "error": str + } + } + """ + node_id = data.get("node_id") + yield { + "event": "node_error", + "data": { + "node_id": node_id, + "status": "failed", + "input": data.get("input_data"), + "elapsed_time": data.get("elapsed_time"), + "output": None, + "error": data.get("error") + } + } + + async def _handle_debug_event(self, data, input_data): + """ + Handle debug events ("debug") related to node execution status. + + This method streams debug events for nodes, including when a node starts + execution ("node_start") and when it completes execution ("node_end"). + It filters out nodes with names starting with "nop" as no-operation nodes. + + Args: + data (dict): Debug event data, expected keys: + - "type": Event type ("task" for start, "task_result" for completion) + - "payload": Node-related information, including: + - "name": Node name / ID + - "input": Node input data (for "task" type) + - "result": Node execution result (for "task_result" type) + - "timestamp": ISO timestamp string of the event + input_data (dict): Original workflow input data (used to get conversation_id) + + Yields: + dict: Node debug event in one of the following formats: + 1. Node start: + { + "event": "node_start", + "data": { + "node_id": str, + "conversation_id": str, + "execution_id": str, + "timestamp": int (ms) + } + } + 2. Node end: + { + "event": "node_end", + "data": { + "node_id": str, + "conversation_id": str, + "execution_id": str, + "timestamp": int (ms), + "input": dict, + "output": Any, + "elapsed_time": float + } + } + """ + event_type = data.get("type") + payload = data.get("payload", {}) + node_name = payload.get("name") + + # Skip no-operation nodes + if node_name and node_name.startswith("nop"): + return + + if event_type == "task": + # Node starts execution + inputv = payload.get("input", {}) + if not inputv.get("activate", {}).get(node_name): + return + conversation_id = input_data.get("conversation_id") + logger.info(f"[NODE-START] Node '{node_name}' execution started - execution_id: {self.execution_id}") + + yield { + "event": "node_start", + "data": { + "node_id": node_name, + "conversation_id": conversation_id, + "execution_id": self.execution_id, + "timestamp": int(datetime.datetime.fromisoformat( + data.get("timestamp") + ).timestamp() * 1000), + } + } + elif event_type == "task_result": + # Node execution completed + result = payload.get("result", {}) + if not result.get("activate", {}).get(node_name): + return + + conversation_id = input_data.get("conversation_id") + logger.info(f"[NODE-END] Node '{node_name}' execution completed - execution_id: {self.execution_id}") + + yield { + "event": "node_end", + "data": { + "node_id": node_name, + "conversation_id": conversation_id, + "execution_id": self.execution_id, + "timestamp": int(datetime.datetime.fromisoformat( + data.get("timestamp") + ).timestamp() * 1000), + "input": result.get("node_outputs", {}).get(node_name, {}).get("input"), + "output": result.get("node_outputs", {}).get(node_name, {}).get("output"), + "elapsed_time": result.get("node_outputs", {}).get(node_name, {}).get("elapsed_time"), + } + } + + async def _flush_remaining_chunk(self): + """ + Flush and yield all remaining output segments from active End nodes. + + This method ensures that any remaining chunks of output, which may not have + been emitted during normal streaming due to activation conditions, are fully + processed. It is typically called at the end of a workflow to guarantee + that all output is delivered. + + Behavior: + 1. Filter `end_outputs` to only keep End nodes that are still active. + 2. While there is an active End node (`self.activate_end`): + - Call `_emit_active_chunks(force=True)` to emit all segments regardless + of their activation state. + - If the current End node finishes, move to the next active End node + if any remain. + + Yields: + dict: Streamed output events in the format: + {"event": "message", "data": {"chunk": ...}} + """ + # Keep only active End nodes + self.end_outputs = { + node_id: node_info + for node_id, node_info in self.end_outputs.items() + if node_info.activate + } + + if self.end_outputs or self.activate_end: + while self.activate_end: + # Force emit all remaining chunks of the active End node + async for msg_event in self._emit_active_chunks(force=True): + yield msg_event + + # Move to next active End node if current one is done + if not self.activate_end and self.end_outputs: + self.activate_end = list(self.end_outputs.keys())[0] def build_graph(self, stream=False) -> CompiledStateGraph: - """构建 LangGraph + """ + Build the workflow graph using LangGraph. + + This method initializes a GraphBuilder with the workflow configuration, + builds the compiled state graph, and sets up the executor's key attributes: + - `start_node_id`: the ID of the start node in the workflow + - `end_outputs`: mapping of End nodes and their output configurations + - `variable_pool`: pool containing workflow variables + - `graph`: the compiled state graph ready for execution + + Args: + stream (bool, optional): Whether to enable streaming mode. Defaults to False. Returns: - 编译后的状态图 + CompiledStateGraph: The compiled and ready-to-run state graph. """ - logger.info(f"开始构建工作流图: execution_id={self.execution_id}") + logger.info(f"Starting workflow graph build: execution_id={self.execution_id}") builder = GraphBuilder( self.workflow_config, stream=stream, ) self.start_node_id = builder.start_node_id self.end_outputs = builder.end_node_map - graph = builder.build() - logger.info(f"工作流图构建完成: execution_id={self.execution_id}") + self.variable_pool = builder.variable_pool + self.graph = builder.build() + logger.info(f"Workflow graph build completed: execution_id={self.execution_id}") - return graph + return self.graph async def execute( self, input_data: dict[str, Any] ) -> dict[str, Any]: - """执行工作流(非流式) + """ + Execute the workflow in non-streaming (batch) mode. + + Steps: + 1. Build the workflow graph. + 2. Initialize the variable pool and inject system variables. + 3. Prepare the initial workflow state. + 4. Invoke the compiled graph and collect outputs. + 5. Aggregate outputs, messages, and token usage. Args: - input_data: 输入数据,包含 message 和 variables + input_data (dict): Input data including 'message' and 'variables'. Returns: - 执行结果,包含 status, output, node_outputs, elapsed_time, token_usage + dict: Execution result containing: + - status: "completed" or "failed" + - output: aggregated output string from all End nodes + - variables: current conversation and system variables + - node_outputs: all node outputs + - messages: list of messages including user and assistant content + - elapsed_time: workflow execution time in seconds + - token_usage: aggregated token usage if available + - error: error message if any """ - logger.info(f"开始执行工作流: execution_id={self.execution_id}") + logger.info(f"Starting workflow execution: execution_id={self.execution_id}") - # 记录开始时间 start_time = datetime.datetime.now() - # 1. 构建图 + # Build the workflow graph graph = self.build_graph() - # 2. 初始化状态(自动注入系统变量) + # Initialize the variable pool with input data + await self.__init_variable_pool(input_data) initial_state = self._prepare_initial_state(input_data) - # 3. 执行工作流 + # Execute the workflow try: - result = await graph.ainvoke(initial_state, config=self.checkpoint_config) + + # Aggregate output from all End nodes full_content = '' for end_id in self.end_outputs.keys(): - full_content += result.get('runtime_vars', {}).get(end_id, {}).get('output', '') + full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False) + + # Append messages for user and assistant result["messages"].extend( [ { @@ -356,20 +696,19 @@ class WorkflowExecutor: } ] ) - # 计算耗时 + # Calculate elapsed time end_time = datetime.datetime.now() elapsed_time = (end_time - start_time).total_seconds() - logger.info(f"工作流执行完成: execution_id={self.execution_id}, elapsed_time={elapsed_time:.2f}s") + logger.info(f"Workflow execution completed: execution_id={self.execution_id}, elapsed_time={elapsed_time:.2f}s") return self._build_final_output(result, elapsed_time, full_content) except Exception as e: - # 计算耗时(即使失败也记录) end_time = datetime.datetime.now() elapsed_time = (end_time - start_time).total_seconds() - logger.error(f"工作流执行失败: execution_id={self.execution_id}, error={e}", exc_info=True) + logger.error(f"Workflow execution failed: execution_id={self.execution_id}, error={e}", exc_info=True) return { "status": "failed", "error": str(e), @@ -383,48 +722,52 @@ class WorkflowExecutor: self, input_data: dict[str, Any] ): - """执行工作流(流式) + """ + Execute the workflow in streaming mode. - 使用多个 stream_mode 来获取: - 1. "updates" - 节点的 state 更新和流式 chunk - 2. "debug" - 节点执行的详细信息(开始/完成时间) - 3. "custom" - 自定义流式数据(chunks) + Supports multiple streaming modes: + 1. "updates" - Node state updates and streaming chunks. + 2. "debug" - Detailed node execution info (start/end). + 3. "custom" - Custom streaming chunks from nodes. Args: - input_data: 输入数据 + input_data (dict): Input data including 'message', 'variables', etc. Yields: - 流式事件,格式: - { - "event": "workflow_start" | "workflow_end" | "node_start" | "node_end" | "node_chunk" | "message", - "data": {...} - } + dict: Streaming events in the format: + { + "event": "workflow_start" | "workflow_end" | "node_start" | + "node_end" | "node_chunk" | "message", + "data": {...} + } """ - logger.info(f"开始执行工作流(流式): execution_id={self.execution_id}") + logger.info(f"Starting workflow execution (streaming): execution_id={self.execution_id}") - # 记录开始时间 start_time = datetime.datetime.now() - # 发送 workflow_start 事件 yield { "event": "workflow_start", "data": { "execution_id": self.execution_id, "workspace_id": self.workspace_id, + "conversation_id": input_data.get("conversation_id"), "timestamp": int(start_time.timestamp() * 1000) } } - # 1. 构建图 + # Build the workflow graph in streaming mode graph = self.build_graph(stream=True) - # 2. 初始化状态(自动注入系统变量) + # Initialize the variable pool and system variables + await self.__init_variable_pool(input_data) initial_state = self._prepare_initial_state(input_data) - # 3. Execute workflow + + try: - chunk_count = 0 full_content = '' self._update_scope_activate("sys") + + # Execute the workflow with streaming async for event in graph.astream( initial_state, stream_mode=["updates", "debug", "custom"], # Use updates + debug + custom mode @@ -442,153 +785,37 @@ class WorkflowExecutor: if mode == "custom": # Handle custom streaming events (chunks from nodes via stream writer) - chunk_count += 1 event_type = data.get("type", "node_chunk") # "message" or "node_chunk" if event_type == "node_chunk": - node_id = data.get("node_id") - if self.activate_end: - end_info = self.end_outputs.get(self.activate_end) - if not end_info or end_info.cursor >= len(end_info.outputs): - continue - current_output = end_info.outputs[end_info.cursor] - if current_output.is_variable and current_output.depends_on_scope(node_id): - if data.get("done"): - end_info.cursor += 1 - if end_info.cursor >= len(end_info.outputs): - self.end_outputs.pop(self.activate_end) - self.activate_end = None - else: - full_content += data.get("chunk") - yield { - "event": "message", - "data": { - "chunk": data.get("chunk") - } - } - logger.info(f"[CUSTOM] ✅ 收到 {event_type} #{chunk_count} from {data.get('node_id')}" - f"- execution_id: {self.execution_id}") - - elif event_type == "node_error": - yield { - "event": event_type, # "message" or "node_chunk" - "data": { - "node_id": data.get("node_id"), - "status": "failed", - "input": data.get("input_data"), - "elapsed_time": data.get("elapsed_time"), - "output": None, - "error": data.get("error") - } - } - - elif mode == "debug": - # Handle debug information (node execution status) - event_type = data.get("type") - payload = data.get("payload", {}) - node_name = payload.get("name") - - if node_name and node_name.startswith("nop"): - continue - - if event_type == "task": - # Node starts execution - inputv = payload.get("input", {}) - if not inputv.get("activate", {}).get(node_name): - continue - conversation_id = input_data.get("conversation_id") - logger.info(f"[NODE-START] Node starts execution: {node_name} " - f"- execution_id: {self.execution_id}") - yield { - "event": "node_start", - "data": { - "node_id": node_name, - "conversation_id": conversation_id, - "execution_id": self.execution_id, - "timestamp": int(datetime.datetime.fromisoformat( - data.get("timestamp") - ).timestamp() * 1000), - } - } - elif event_type == "task_result": - # Node execution completed - result = payload.get("result", {}) - if not result.get("activate", {}).get(node_name): - continue - - conversation_id = input_data.get("conversation_id") - logger.info(f"[NODE-END] Node execution completed: {node_name} " - f"- execution_id: {self.execution_id}") - - yield { - "event": "node_end", - "data": { - "node_id": node_name, - "conversation_id": conversation_id, - "execution_id": self.execution_id, - "timestamp": int(datetime.datetime.fromisoformat( - data.get("timestamp") - ).timestamp() * 1000), - "input": result.get("node_outputs", {}).get(node_name, {}).get("input"), - "output": result.get("node_outputs", {}).get(node_name, {}).get("output"), - "elapsed_time": result.get("node_outputs", {}).get(node_name, {}).get("elapsed_time"), - } - } - - elif mode == "updates": - # Handle state updates - store final state - state = graph.get_state(config=self.checkpoint_config).values - node_outputs = state.get("runtime_vars", {}) - variables = state.get("variables", {}) - activate = state.get("activate", {}) - for _, node_data in data.items(): - node_outputs |= node_data.get("runtime_vars", {}) - variables |= node_data.get("variables", {}) - - self._update_stream_output_status(activate, data) - wait = False - while self.activate_end and not wait: - async for msg_event in self._emit_active_chunks( - node_outputs=node_outputs, - variables=variables - ): - full_content += msg_event["data"]['chunk'] + async for msg_event in self._handle_node_chunk_event(data): + full_content += data.get("chunk") yield msg_event - if self.activate_end: - wait = True - else: - self._update_stream_output_status(activate, data) + elif event_type == "node_error": + async for error_event in self._handle_node_error_event(data): + yield error_event + elif mode == "debug": + async for debug_event in self._handle_debug_event(data, input_data): + yield debug_event + + elif mode == "updates": logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())} " f"- execution_id: {self.execution_id}") - - result = graph.get_state(self.checkpoint_config).values - node_outputs = result.get("runtime_vars", {}) - variables = result.get("variables", {}) - self.end_outputs = { - node_id: node_info - for node_id, node_info in self.end_outputs.items() - if node_info.activate - } - - if self.end_outputs or self.activate_end: - while self.activate_end: - async for msg_event in self._emit_active_chunks( - node_outputs=node_outputs, - variables=variables, - force=True - ): + async for msg_event in self._handle_updates_event(data): full_content += msg_event["data"]['chunk'] yield msg_event - if not self.activate_end and self.end_outputs: - self.activate_end = list(self.end_outputs.keys())[0] + # Flush any remaining chunks + async for msg_event in self._flush_remaining_chunk(): + full_content += msg_event["data"]['chunk'] + yield msg_event - # 计算耗时 + result = graph.get_state(self.checkpoint_config).values end_time = datetime.datetime.now() elapsed_time = (end_time - start_time).total_seconds() - result = graph.get_state(self.checkpoint_config).values - logger.info(result) + + # Append messages for user and assistant result["messages"].extend( [ { @@ -603,23 +830,20 @@ class WorkflowExecutor: ) logger.info( f"Workflow execution completed (streaming), " - f"total chunks: {chunk_count}, elapsed: {elapsed_time:.2f}s, execution_id: {self.execution_id}" + f"elapsed: {elapsed_time:.2f}s, execution_id: {self.execution_id}" ) - # 发送 workflow_end 事件 yield { "event": "workflow_end", "data": self._build_final_output(result, elapsed_time, full_content) } except Exception as e: - # 计算耗时(即使失败也记录) end_time = datetime.datetime.now() elapsed_time = (end_time - start_time).total_seconds() - logger.error(f"工作流执行失败: execution_id={self.execution_id}, error={e}", exc_info=True) + logger.error(f"Workflow execution failed: execution_id={self.execution_id}, error={e}", exc_info=True) - # 发送 workflow_end 事件(失败) yield { "event": "workflow_end", "data": { @@ -633,14 +857,20 @@ class WorkflowExecutor: @staticmethod def _aggregate_token_usage(node_outputs: dict[str, Any]) -> dict[str, int] | None: - """聚合所有节点的 token 使用情况 + """ + Aggregate token usage statistics across all nodes. Args: - node_outputs: 所有节点的输出 + node_outputs (dict): A dictionary of all node outputs. Returns: - 聚合的 token 使用情况 {"prompt_tokens": x, "completion_tokens": y, "total_tokens": z} - 如果没有 token 使用信息,返回 None + dict | None: Aggregated token usage in the format: + { + "prompt_tokens": int, + "completion_tokens": int, + "total_tokens": int + } + Returns None if no token usage information is available. """ total_prompt_tokens = 0 total_completion_tokens = 0 @@ -673,17 +903,18 @@ async def execute_workflow( workspace_id: str, user_id: str ) -> dict[str, Any]: - """执行工作流(便捷函数) + """ + Execute a workflow (convenience function, non-streaming). Args: - workflow_config: 工作流配置 - input_data: 输入数据 - execution_id: 执行 ID - workspace_id: 工作空间 ID - user_id: 用户 ID + workflow_config (dict): The workflow configuration. + input_data (dict): Input data for the workflow. + execution_id (str): Execution ID. + workspace_id (str): Workspace ID. + user_id (str): User ID. Returns: - 执行结果 + dict: Workflow execution result. """ executor = WorkflowExecutor( workflow_config=workflow_config, @@ -701,17 +932,18 @@ async def execute_workflow_stream( workspace_id: str, user_id: str ): - """执行工作流(流式,便捷函数) + """ + Execute a workflow in streaming mode (convenience function). Args: - workflow_config: 工作流配置 - input_data: 输入数据 - execution_id: 执行 ID - workspace_id: 工作空间 ID - user_id: 用户 ID + workflow_config (dict): The workflow configuration. + input_data (dict): Input data for the workflow. + execution_id (str): Execution ID. + workspace_id (str): Workspace ID. + user_id (str): User ID. Yields: - 流式事件 + dict: Streaming workflow events, e.g. node start, node end, chunk messages, workflow end. """ executor = WorkflowExecutor( workflow_config=workflow_config, diff --git a/api/app/core/workflow/expression_evaluator.py b/api/app/core/workflow/expression_evaluator.py index 1a8b101e..26f0c41c 100644 --- a/api/app/core/workflow/expression_evaluator.py +++ b/api/app/core/workflow/expression_evaluator.py @@ -1,9 +1,3 @@ -""" -安全的表达式求值器 - -使用 simpleeval 库提供安全的表达式评估,避免代码注入攻击。 -""" - import logging import re from typing import Any @@ -14,160 +8,119 @@ logger = logging.getLogger(__name__) class ExpressionEvaluator: - """安全的表达式求值器""" + """Safe expression evaluator for workflow variables and node outputs.""" - # 保留的命名空间 + # Reserved namespaces RESERVED_NAMESPACES = {"var", "node", "sys", "nodes"} @staticmethod def evaluate( expression: str, - variables: dict[str, Any], + conv_vars: dict[str, Any], node_outputs: dict[str, Any], system_vars: dict[str, Any] | None = None ) -> Any: - """安全地评估表达式 - - Args: - expression: 表达式字符串,如 "{{var.score}} > 0.8" - variables: 用户定义的变量 - node_outputs: 节点输出结果 - system_vars: 系统变量 - - Returns: - 表达式求值结果 - - Raises: - ValueError: 表达式无效或求值失败 - - Examples: - >>> evaluator = ExpressionEvaluator() - >>> evaluator.evaluate( - ... "var.score > 0.8", - ... {"score": 0.9}, - ... {}, - ... {} - ... ) - True - - >>> evaluator.evaluate( - ... "node.intent.output == '售前咨询'", - ... {}, - ... {"intent": {"output": "售前咨询"}}, - ... {} - ... ) - True """ - # 移除 Jinja2 模板语法的花括号(如果存在) + Safely evaluate an expression using workflow variables. + + Args: + expression (str): The expression string, e.g., "var.score > 0.8" + conv_vars (dict): Conversation-level variables + node_outputs (dict): Outputs from workflow nodes + system_vars (dict, optional): System variables + + Returns: + Any: Result of the evaluated expression + + Raises: + ValueError: If the expression is invalid or evaluation fails + """ + # Remove Jinja2-style brackets if present expression = expression.strip() - # "{{system.message}} == {{ user.messge }}" -> "system.message == user.message" pattern = r"\{\{\s*(.*?)\s*\}\}" expression = re.sub(pattern, r"\1", expression).strip() - # 构建命名空间上下文 + # Build context for evaluation context = { - "var": variables, # 用户变量 - "node": node_outputs, # 节点输出 - "sys": system_vars or {}, # 系统变量 + "conv": conv_vars, # conversation variables + "node": node_outputs, # node outputs + "sys": system_vars or {}, # system variables } - - # 为了向后兼容,也支持直接访问(但会在日志中警告) - context.update(variables) + + context.update(conv_vars) context["nodes"] = node_outputs context.update(node_outputs) try: - # simpleeval 只支持安全的操作: - # - 算术运算: +, -, *, /, //, %, ** - # - 比较运算: ==, !=, <, <=, >, >= - # - 逻辑运算: and, or, not - # - 成员运算: in, not in - # - 属性访问: obj.attr - # - 字典/列表访问: obj["key"], obj[0] - # 不支持:函数调用、导入、赋值等危险操作 + # simpleeval supports safe operations: + # arithmetic, comparisons, logical ops, attribute/dict/list access result = simple_eval(expression, names=context) return result except NameNotDefined as e: - logger.error(f"表达式中引用了未定义的变量: {expression}, 错误: {e}") - raise ValueError(f"未定义的变量: {e}") + logger.error(f"Undefined variable in expression: {expression}, error: {e}") + raise ValueError(f"Undefined variable: {e}") except InvalidExpression as e: - logger.error(f"表达式语法无效: {expression}, 错误: {e}") - raise ValueError(f"表达式语法无效: {e}") + logger.error(f"Invalid expression syntax: {expression}, error: {e}") + raise ValueError(f"Invalid expression syntax: {e}") except SyntaxError as e: - logger.error(f"表达式语法错误: {expression}, 错误: {e}") - raise ValueError(f"表达式语法错误: {e}") + logger.error(f"Syntax error in expression: {expression}, error: {e}") + raise ValueError(f"Syntax error: {e}") except Exception as e: - logger.error(f"表达式求值异常: {expression}, 错误: {e}") - raise ValueError(f"表达式求值失败: {e}") + logger.error(f"Expression evaluation failed: {expression}, error: {e}") + raise ValueError(f"Expression evaluation failed: {e}") @staticmethod def evaluate_bool( expression: str, - variables: dict[str, Any], + conv_var: dict[str, Any], node_outputs: dict[str, Any], system_vars: dict[str, Any] | None = None ) -> bool: - """评估布尔表达式(用于条件判断) - + """ + Evaluate a boolean expression (for conditions). + Args: - expression: 布尔表达式 - variables: 用户变量 - node_outputs: 节点输出 - system_vars: 系统变量 - + expression (str): Boolean expression + conv_var (dict): Conversation variables + node_outputs (dict): Node outputs + system_vars (dict, optional): System variables + Returns: - 布尔值结果 - - Examples: - >>> ExpressionEvaluator.evaluate_bool( - ... "var.count >= 10 and var.status == 'active'", - ... {"count": 15, "status": "active"}, - ... {}, - ... {} - ... ) - True + bool: Boolean result """ result = ExpressionEvaluator.evaluate( - expression, variables, node_outputs, system_vars + expression, conv_var, node_outputs, system_vars ) return bool(result) @staticmethod def validate_variable_names(variables: list[dict]) -> list[str]: - """验证变量名是否合法 - + """ + Validate variable names for legality. + Args: - variables: 变量定义列表 - + variables (list[dict]): List of variable definitions + Returns: - 错误列表,如果为空则验证通过 - - Examples: - >>> ExpressionEvaluator.validate_variable_names([ - ... {"name": "user_input"}, - ... {"name": "var"} # 保留字 - ... ]) - ["变量名 'var' 是保留的命名空间,请使用其他名称"] + list[str]: List of error messages. Empty if all names are valid. """ errors = [] for var in variables: var_name = var.get("name", "") - - # 检查是否为保留命名空间 + if var_name in ExpressionEvaluator.RESERVED_NAMESPACES: errors.append( - f"变量名 '{var_name}' 是保留的命名空间,请使用其他名称" + f"Variable name '{var_name}' is a reserved namespace, please use another name" ) - - # 检查是否为有效的 Python 标识符 + if not var_name.isidentifier(): errors.append( - f"变量名 '{var_name}' 不是有效的标识符" + f"Variable name '{var_name}' is not a valid Python identifier" ) return errors @@ -176,23 +129,23 @@ class ExpressionEvaluator: # 便捷函数 def evaluate_expression( expression: str, - variables: dict[str, Any], + conv_var: dict[str, Any], node_outputs: dict[str, Any], - system_vars: dict[str, Any] | None = None + system_vars: dict[str, Any] ) -> Any: - """评估表达式(便捷函数)""" + """Evaluate an expression (convenience function).""" return ExpressionEvaluator.evaluate( - expression, variables, node_outputs, system_vars + expression, conv_var, node_outputs, system_vars ) def evaluate_condition( expression: str, - variables: dict[str, Any], + conv_var: dict[str, Any], node_outputs: dict[str, Any], system_vars: dict[str, Any] | None = None ) -> bool: - """评估条件表达式(便捷函数)""" + """Evaluate a boolean condition expression (convenience function).""" return ExpressionEvaluator.evaluate_bool( - expression, variables, node_outputs, system_vars + expression, conv_var, node_outputs, system_vars ) diff --git a/api/app/core/workflow/graph_builder.py b/api/app/core/workflow/graph_builder.py index b1d43e08..46a594d7 100644 --- a/api/app/core/workflow/graph_builder.py +++ b/api/app/core/workflow/graph_builder.py @@ -14,9 +14,14 @@ from pydantic import BaseModel, Field from app.core.workflow.expression_evaluator import evaluate_condition from app.core.workflow.nodes import WorkflowState, NodeFactory from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES +from app.core.workflow.variable_pool import VariablePool logger = logging.getLogger(__name__) +SCOPE_PATTERN = re.compile( + r"\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\.[a-zA-Z0-9_]+\s*}}" +) + class OutputContent(BaseModel): """ @@ -53,6 +58,12 @@ class OutputContent(BaseModel): ) ) + _SCOPE: str | None = None + + def get_scope(self) -> str: + self._SCOPE = SCOPE_PATTERN.findall(self.literal)[0] + return self._SCOPE + def depends_on_scope(self, scope: str) -> bool: """ Check if this segment depends on a given scope. @@ -63,8 +74,9 @@ class OutputContent(BaseModel): Returns: bool: True if this segment references the given scope. """ - pattern = rf"\{{\{{\s*{re.escape(scope)}\.[a-zA-Z0-9_]+\s*\}}\}}" - return bool(re.search(pattern, self.literal)) + if self._SCOPE: + return self._SCOPE == scope + return self.get_scope() == scope class StreamOutputConfig(BaseModel): @@ -167,6 +179,7 @@ class GraphBuilder: workflow_config: dict[str, Any], stream: bool = False, subgraph: bool = False, + variable_pool: VariablePool | None = None ): self.workflow_config = workflow_config @@ -180,6 +193,10 @@ class GraphBuilder: self._find_upstream_branch_node = lru_cache( maxsize=len(self.nodes) * 2 )(self._find_upstream_branch_node) + if variable_pool: + self.variable_pool = variable_pool + else: + self.variable_pool = VariablePool() self.graph = StateGraph(WorkflowState) self.add_nodes() @@ -452,9 +469,9 @@ class GraphBuilder: if self.stream: # Stream mode: create an async generator function # LangGraph collects all yielded values; the last yielded dictionary is merged into the state - def make_stream_func(inst): + def make_stream_func(inst, variable_pool=self.variable_pool): async def node_func(state: WorkflowState): - async for item in inst.run_stream(state): + async for item in inst.run_stream(state, variable_pool): yield item return node_func @@ -462,9 +479,9 @@ class GraphBuilder: self.graph.add_node(node_id, make_stream_func(node_instance)) else: # Non-stream mode: create an async function - def make_func(inst): + def make_func(inst, variable_pool=self.variable_pool): async def node_func(state: WorkflowState): - return await inst.run(state) + return await inst.run(state, variable_pool) return node_func @@ -567,27 +584,28 @@ class GraphBuilder: for target in branch_info["target"]: waiting_edges[target].append(branch_info["node"]["name"]) - def router_fn(state: WorkflowState) -> list[Send]: + def router_fn(state: WorkflowState, variable_pool: VariablePool = self.variable_pool) -> list[Send]: branch_activate = [] new_state = state.copy() new_state["activate"] = dict(state.get("activate", {})) # deep copy of activate - + node_output = variable_pool.get_node_output(src, defalut=dict(), strict=False) for label, branch in unique_branch.items(): - if evaluate_condition( + if node_output and evaluate_condition( branch["condition"], - state.get("variables", {}), - state.get("runtime_vars", {}), - { - "execution_id": state.get("execution_id"), - "workspace_id": state.get("workspace_id"), - "user_id": state.get("user_id") - } + {}, + {src: node_output}, + {} ): logger.debug(f"Conditional routing {src}: selected branch {label}") new_state["activate"][branch["node"]["name"]] = True + branch_activate.append( + Send( + branch['node']['name'], + new_state + ) + ) continue new_state["activate"][branch["node"]["name"]] = False - for label, branch in unique_branch.items(): branch_activate.append( Send( branch['node']['name'], diff --git a/api/app/core/workflow/nodes/__init__.py b/api/app/core/workflow/nodes/__init__.py index 926f86e4..1f2eb15b 100644 --- a/api/app/core/workflow/nodes/__init__.py +++ b/api/app/core/workflow/nodes/__init__.py @@ -15,7 +15,6 @@ from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode from app.core.workflow.nodes.llm import LLMNode from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode from app.core.workflow.nodes.start import StartNode -from app.core.workflow.nodes.transform import TransformNode from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode from app.core.workflow.nodes.question_classifier import QuestionClassifierNode from app.core.workflow.nodes.tool import ToolNode @@ -25,7 +24,6 @@ __all__ = [ "WorkflowState", "LLMNode", "AgentNode", - "TransformNode", "IfElseNode", "StartNode", "EndNode", diff --git a/api/app/core/workflow/nodes/agent/config.py b/api/app/core/workflow/nodes/agent/config.py index 413ce606..4d428a4b 100644 --- a/api/app/core/workflow/nodes/agent/config.py +++ b/api/app/core/workflow/nodes/agent/config.py @@ -2,7 +2,8 @@ from pydantic import Field -from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType +from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition +from app.core.workflow.variable.base_variable import VariableType class AgentNodeConfig(BaseNodeConfig): diff --git a/api/app/core/workflow/nodes/agent/node.py b/api/app/core/workflow/nodes/agent/node.py index e4525d88..0818749c 100644 --- a/api/app/core/workflow/nodes/agent/node.py +++ b/api/app/core/workflow/nodes/agent/node.py @@ -2,6 +2,7 @@ Agent 节点实现 调用已发布的 Agent 应用。 +# TODO """ import logging @@ -9,6 +10,8 @@ from typing import Any from langchain_core.messages import AIMessage from app.core.workflow.nodes.base_node import BaseNode, WorkflowState +from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.variable_pool import VariablePool from app.services.draft_run_service import DraftRunService from app.models import AppRelease from app.db import get_db @@ -30,19 +33,22 @@ class AgentNode(BaseNode): } } """ - - def _prepare_agent(self, state: WorkflowState) -> tuple[DraftRunService, AppRelease, str]: + + def _output_types(self) -> dict[str, VariableType]: + return {"output": VariableType.STRING} + + def _prepare_agent(self, variable_pool: VariablePool) -> tuple[DraftRunService, AppRelease, str]: """准备 Agent(公共逻辑) Args: - state: 工作流状态 + variable_pool: 变量池 Returns: (draft_service, release, message): 服务实例、发布配置、消息 """ # 1. 渲染消息 message_template = self.config.get("message", "") - message = self._render_template(message_template, state) + message = self._render_template(message_template, variable_pool) # 2. 获取 Agent 配置 agent_id = self.config.get("agent_id") @@ -61,16 +67,17 @@ class AgentNode(BaseNode): return draft_service, release, message - async def execute(self, state: WorkflowState) -> dict[str, Any]: + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: """非流式执行 Args: state: 工作流状态 + variable_pool: 变量池 Returns: 状态更新字典 """ - draft_service, release, message = self._prepare_agent(state) + draft_service, release, message = self._prepare_agent(variable_pool) logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(非流式)") @@ -79,9 +86,9 @@ class AgentNode(BaseNode): agent_config=release.config, model_config=None, message=message, - workspace_id=state.get("workspace_id"), + workspace_id=variable_pool.get_value("sys.workspace_id"), user_id=state.get("user_id"), - variables=state.get("variables", {}) + variables=variable_pool.get_all_conversation_vars() ) response = result.get("response", "") @@ -99,16 +106,17 @@ class AgentNode(BaseNode): } } - async def execute_stream(self, state: WorkflowState): + async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool): """流式执行 Args: state: 工作流状态 + variable_pool: 变量池 Yields: 流式事件字典 """ - draft_service, release, message = self._prepare_agent(state) + draft_service, release, message = self._prepare_agent(variable_pool) logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(流式)") @@ -120,9 +128,9 @@ class AgentNode(BaseNode): agent_config=release.config, model_config=None, message=message, - workspace_id=state.get("workspace_id"), + workspace_id=variable_pool.get_value("sys.workspace_id"), user_id=state.get("user_id"), - variables=state.get("variables", {}) + variables=variable_pool.get_all_conversation_vars() ): # 提取内容 content = chunk.get("content", "") diff --git a/api/app/core/workflow/nodes/assigner/node.py b/api/app/core/workflow/nodes/assigner/node.py index 6f2583b4..e1bb6e9d 100644 --- a/api/app/core/workflow/nodes/assigner/node.py +++ b/api/app/core/workflow/nodes/assigner/node.py @@ -6,6 +6,7 @@ from app.core.workflow.nodes.assigner.config import AssignerNodeConfig from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.enums import AssignmentOperator from app.core.workflow.nodes.operators import AssignmentOperatorInstance, AssignmentOperatorResolver +from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.variable_pool import VariablePool logger = logging.getLogger(__name__) @@ -17,13 +18,17 @@ class AssignerNode(BaseNode): self.variable_updater = True self.typed_config: AssignerNodeConfig | None = None - async def execute(self, state: WorkflowState) -> Any: + def _output_types(self) -> dict[str, VariableType]: + return {} + + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: """ Execute the assignment operation defined by this node. Args: state: The current workflow state, including conversation variables, node outputs, and system variables. + variable_pool: variable pool Returns: None or the result of the assignment operation. @@ -31,60 +36,57 @@ class AssignerNode(BaseNode): # Initialize a variable pool for accessing conversation, node, and system variables self.typed_config = AssignerNodeConfig(**self.config) logger.info(f"节点 {self.node_id} 开始执行") - pool = VariablePool(state) + pattern = r"\{\{\s*(.*?)\s*\}\}" + for assignment in self.typed_config.assignments: # Get the target variable selector (e.g., "conv.test") variable_selector = assignment.variable_selector - if isinstance(variable_selector, str): - # Support dot-separated string paths, e.g., "conv.test" -> ["conv", "test"] - pattern = r"\{\{\s*(.*?)\s*\}\}" - expression = re.sub(pattern, r"\1", variable_selector).strip() - variable_selector = expression.split('.') + namespace = re.sub(pattern, r"\1", variable_selector).split('.')[0] # Only conversation variables ('conv') are allowed - if variable_selector[0] != 'conv' and variable_selector[0] not in state["cycle_nodes"]: - raise ValueError("Only conversation or cycle variables can be assigned.") + if namespace != 'conv' and namespace not in state["cycle_nodes"]: + raise ValueError(f"Only conversation or cycle variables can be assigned. - {variable_selector}") # Get the value or expression to assign value = assignment.value logger.debug(f"left:{variable_selector}, right: {value}") - pattern = r"\{\{\s*(.*?)\s*\}\}" + if isinstance(value, str): expression = re.match(pattern, value) if expression: expression = expression.group(1) expression = re.sub(pattern, r"\1", expression).strip() - value = self.get_variable(expression, state) + value = self.get_variable(expression, variable_pool, default=value, strict=False) # Select the appropriate assignment operator instance based on the target variable type operator: AssignmentOperatorInstance = AssignmentOperatorResolver.resolve_by_value( - pool.get(variable_selector) + variable_pool.get_value(variable_selector) )( - pool, variable_selector, value + variable_pool, variable_selector, value ) # Execute the configured assignment operation match assignment.operation: case AssignmentOperator.COVER: - operator.assign() + await operator.assign() case AssignmentOperator.ASSIGN: - operator.assign() + await operator.assign() case AssignmentOperator.CLEAR: - operator.clear() + await operator.clear() case AssignmentOperator.ADD: - operator.add() + await operator.add() case AssignmentOperator.SUBTRACT: - operator.subtract() + await operator.subtract() case AssignmentOperator.MULTIPLY: - operator.multiply() + await operator.multiply() case AssignmentOperator.DIVIDE: - operator.divide() + await operator.divide() case AssignmentOperator.APPEND: - operator.append() + await operator.append() case AssignmentOperator.REMOVE_FIRST: - operator.remove_first() + await operator.remove_first() case AssignmentOperator.REMOVE_LAST: - operator.remove_last() + await operator.remove_last() case _: raise ValueError(f"Invalid Operator: {assignment.operation}") logger.info(f"Node {self.node_id}: execution completed") diff --git a/api/app/core/workflow/nodes/base_config.py b/api/app/core/workflow/nodes/base_config.py index a6b33928..973e120d 100644 --- a/api/app/core/workflow/nodes/base_config.py +++ b/api/app/core/workflow/nodes/base_config.py @@ -3,79 +3,13 @@ 定义所有节点配置的通用字段和数据结构。 """ -from enum import StrEnum -from typing import Any +from pydantic import BaseModel, Field -from pydantic import BaseModel, Field, ConfigDict +from app.core.workflow.variable.base_variable import VariableType VARIABLE_PATTERN = r"\{\{\s*(.*?)\s*\}\}" -class VariableType(StrEnum): - """变量类型枚举""" - - STRING = "string" - NUMBER = "number" - BOOLEAN = "boolean" - OBJECT = "object" - - ARRAY_STRING = "array[string]" - ARRAY_NUMBER = "array[number]" - ARRAY_BOOLEAN = "array[boolean]" - ARRAY_OBJECT = "array[object]" - - -class TypedVariable(BaseModel): - """ - TODO: 强类型限制 - Strongly typed variable that validates value on assignment. - """ - - value: Any = Field(..., description="Variable value") - type: VariableType = Field(..., description="Declared type of the variable") - - model_config = ConfigDict( - validate_assignment=True - ) - - def __setattr__(self, name, value): - if name == "value": - self._validate_value(value) - if name == "type": - raise RuntimeError("Cannot modify variable type at runtime") - super().__setattr__(name, value) - - def _validate_value(self, v: Any): - t = self.type - match t: - case VariableType.STRING: - if not isinstance(v, str): - raise TypeError("Variable value does not match type STRING") - case VariableType.BOOLEAN: - if not isinstance(v, bool): - raise TypeError("Variable value does not match type BOOLEAN") - case VariableType.NUMBER: - if not isinstance(v, (int, float)): - raise TypeError("Variable value does not match type NUMBER") - case VariableType.OBJECT: - if not isinstance(v, dict): - raise TypeError("Variable value does not match type OBJECT") - case VariableType.ARRAY_STRING: - if not isinstance(v, list) or not all(isinstance(i, str) for i in v): - raise TypeError("Variable value does not match type ARRAY_STRING") - case VariableType.ARRAY_NUMBER: - if not isinstance(v, list) or not all(isinstance(i, (int, float)) for i in v): - raise TypeError("Variable value does not match type ARRAY_NUMBER") - case VariableType.ARRAY_BOOLEAN: - if not isinstance(v, list) or not all(isinstance(i, bool) for i in v): - raise TypeError("Variable value does not match type ARRAY_BOOLEAN") - case VariableType.ARRAY_OBJECT: - if not isinstance(v, list) or not all(isinstance(i, dict) for i in v): - raise TypeError("Variable value does not match type ARRAY_OBJECT") - case _: - raise TypeError(f"Unknown variable type: {t}") - - class VariableDefinition(BaseModel): """变量定义 diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index 4dcdf2bb..2bf748f2 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -1,12 +1,7 @@ -""" -工作流节点基类 - -定义节点的基本接口和通用功能。 -""" - import asyncio import logging from abc import ABC, abstractmethod +from functools import cached_property from typing import Any, AsyncGenerator from langgraph.config import get_stream_writer @@ -14,6 +9,7 @@ from typing_extensions import TypedDict, Annotated from app.core.config import settings from app.core.workflow.nodes.enums import BRANCH_NODES +from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.variable_pool import VariablePool logger = logging.getLogger(__name__) @@ -42,22 +38,10 @@ class WorkflowState(TypedDict): cycle_nodes: list looping: Annotated[int, merge_looping_state] - # Input variables (passed from configured variables) - # Uses a deep merge function, supporting nested dict updates (e.g., conv.xxx) - variables: Annotated[dict[str, Any], lambda x, y: { - **x, - **{k: {**x.get(k, {}), **v} if isinstance(v, dict) and isinstance(x.get(k), dict) else v - for k, v in y.items()} - }] - # Node outputs (stores execution results of each node for variable references) # Uses a custom merge function to combine new node outputs into the existing dictionary node_outputs: Annotated[dict[str, Any], lambda x, y: {**x, **y}] - # Runtime node variables (simplified version, stores business data for fast access between nodes) - # Format: {node_id: business_result} - runtime_vars: Annotated[dict[str, Any], lambda x, y: {**x, **y}] - # Execution context execution_id: str workspace_id: str @@ -72,17 +56,17 @@ class WorkflowState(TypedDict): class BaseNode(ABC): - """节点基类 - - 所有节点类型都应该继承此基类,实现 execute 方法。 + """Base class for workflow nodes. + + All node types should inherit from this class and implement the `execute` method. """ def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - """初始化节点 - + """Initialize the node. + Args: - node_config: 节点配置 - workflow_config: 工作流配置 + node_config: Configuration of the node. + workflow_config: Configuration of the workflow. """ self.node_config = node_config self.workflow_config = workflow_config @@ -94,7 +78,27 @@ class BaseNode(ABC): self.config = node_config.get("config") or {} self.error_handling = node_config.get("error_handling") or {} - self.variable_updater = False + self.variable_change_able = False + + @cached_property + def output_types(self) -> dict[str, VariableType]: + """Returns the output variable types of the node. + + This property is cached to avoid recomputation. + """ + return self._output_types() + + @abstractmethod + def _output_types(self) -> dict[str, VariableType]: + """Defines output variable types for the node. + + Subclasses must override this method to declare the variables + produced by the node and their corresponding types. + + Returns: + A mapping from output variable names to ``VariableType``. + """ + return {} def check_activate(self, state: WorkflowState): """Check if the current node is activated in the workflow state. @@ -136,92 +140,84 @@ class BaseNode(ABC): } @abstractmethod - async def execute(self, state: WorkflowState) -> Any: - """执行节点业务逻辑(非流式) - - 节点只需要返回业务结果,不需要关心输出格式、时间统计等。 - BaseNode 会自动包装成标准格式。 - + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: + """Executes the node business logic (non-streaming). + + The node implementation should only return the business result. + It does not need to handle output formatting, timing, or statistics. + The ``BaseNode`` will automatically wrap the result into a standard + response format. + Args: - state: 工作流状态 - + state: The current workflow state. + variable_pool: The variable pool used for reading and writing variables. + Returns: - 业务结果(任意类型) - - Examples: - >>> # LLM 节点 - >>> "这是 AI 的回复" - - >>> # Transform 节点 - >>> {"processed_data": [...]} - - >>> # Start/End 节点 - >>> {"message": "开始", "conversation_id": "xxx"} + The business result produced by the node. The return value can be + of any type. """ pass - async def execute_stream(self, state: WorkflowState): - """执行节点业务逻辑(流式) - - 子类可以重写此方法以支持流式输出。 - 默认实现:执行非流式方法并一次性返回。 - - 节点需要: - 1. yield 中间结果(如文本片段) - 2. 最后 yield 一个特殊的完成标记:{"__final__": True, "result": final_result} - - Args: - state: 工作流状态 - - Yields: - 业务数据(chunk)或完成标记 - - Examples: - # 流式 LLM 节点 - full_response = "" - async for chunk in llm.astream(prompt): - full_response += chunk - yield chunk # yield 文本片段 + async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool): + """Executes the node business logic in streaming mode. - # 最后 yield 完成标记 - yield {"__final__": True, "result": AIMessage(content=full_response)} + Subclasses may override this method to support streaming output. + The default implementation executes the non-streaming method and + yields a single final result. + + For streaming execution, a node implementation should: + 1. Yield intermediate results (e.g. text chunks). + 2. Yield a final completion marker in the following format: + ``{"__final__": True, "result": final_result}``. + + Args: + state: The current workflow state. + variable_pool: The variable pool used for reading and writing variables. + + Yields: + Business data chunks or a final completion marker. """ - result = await self.execute(state) - # 默认实现:直接 yield 完成标记 + result = await self.execute(state, variable_pool) + # Default implementation: yield a single final completion marker. yield {"__final__": True, "result": result} def supports_streaming(self) -> bool: - """节点是否支持流式输出 - + """Returns whether the node supports streaming output. + + A node is considered to support streaming if its class overrides + the ``execute_stream`` method. If the default implementation from + ``BaseNode`` is used, streaming is not supported. + Returns: - 是否支持流式输出 + True if the node supports streaming output, False otherwise. """ - # 检查子类是否重写了 execute_stream 方法 + # Check whether the subclass overrides the execute_stream method. return self.__class__.execute_stream is not BaseNode.execute_stream - def get_timeout(self) -> int: - """获取超时时间(秒) - + @staticmethod + def get_timeout() -> int: + """Returns the execution timeout in seconds. + Returns: - 超时时间 + The timeout duration, in seconds. """ return settings.WORKFLOW_NODE_TIMEOUT - # return self.error_handling.get("timeout", 60) - async def run(self, state: WorkflowState) -> dict[str, Any]: - """执行节点(带错误处理和输出包装,非流式) - - 这个方法由 Executor 调用,负责: - 1. 时间统计 - 2. 调用节点的 execute() 方法 - 3. 将业务结果包装成标准输出格式 - 4. 错误处理 - + async def run(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: + """Runs the node with error handling and output wrapping (non-streaming). + + This method is invoked by the Executor and is responsible for: + 1. Execution time measurement. + 2. Invoking the node's ``execute()`` method. + 3. Wrapping the business result into a standardized output format. + 4. Handling execution errors. + Args: - state: 工作流状态 - + state: The current workflow state. + variable_pool: The variable pool used for reading and writing variables. + Returns: - 标准化的状态更新字典 + A standardized state update dictionary. """ if not self.check_activate(state): return self.trans_activate(state) @@ -233,70 +229,78 @@ class BaseNode(ABC): timeout = self.get_timeout() try: - # 调用节点的业务逻辑 + # Invoke the node business logic. business_result = await asyncio.wait_for( - self.execute(state), + self.execute(state, variable_pool), timeout=timeout ) elapsed_time = time.time() - start_time - # 提取处理后的输出(调用子类的 _extract_output) + # Extract processed outputs using subclass-defined logic. extracted_output = self._extract_output(business_result) - # 包装成标准输出格式 - wrapped_output = self._wrap_output(business_result, elapsed_time, state) + # Wrap the business result into the standard output format. + wrapped_output = self._wrap_output(business_result, elapsed_time, state, variable_pool) - # 将提取后的输出存储到运行时变量中(供后续节点快速访问) - # 如果提取后的输出是字典,拆包存储;否则存储为 output 字段 - if isinstance(extracted_output, dict): - runtime_var = extracted_output - else: - runtime_var = {"output": extracted_output} + # Store extracted outputs as runtime variables for downstream nodes. + if extracted_output is not None: + runtime_vars = extracted_output + if not isinstance(extracted_output, dict): + runtime_vars = {"output": extracted_output} + for k, v in runtime_vars.items(): + await variable_pool.new(self.node_id, k, v, self.output_types[k], mut=self.variable_change_able) - # 返回包装后的输出和运行时变量 + # Return the wrapped output along with activation state updates. return { **wrapped_output, - "messages": state["messages"], - "runtime_vars": { - self.node_id: runtime_var - }, "looping": state["looping"] } | self.trans_activate(state) except TimeoutError: elapsed_time = time.time() - start_time - logger.error(f"节点 {self.node_id} 执行超时({timeout}秒)") - return self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state) + logger.error( + f"Node {self.node_id} execution timed out ({timeout} seconds)." + ) + return self._wrap_error( + f"Node execution timed out ({timeout} seconds).", + elapsed_time, + state, + variable_pool, + ) except Exception as e: elapsed_time = time.time() - start_time - logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True) - return self._wrap_error(str(e), elapsed_time, state) + logger.error( + f"Node {self.node_id} execution failed: {e}", + exc_info=True, + ) + return self._wrap_error(str(e), elapsed_time, state, variable_pool) + + async def run_stream( + self, state: WorkflowState, + variable_pool: VariablePool + ) -> AsyncGenerator[dict[str, Any], Any]: + """Executes the node with error handling and output wrapping (streaming). - async def run_stream(self, state: WorkflowState) -> AsyncGenerator[dict[str, Any], Any]: - """Execute node with error handling and output wrapping (streaming) - This method is called by the Executor and is responsible for: - 1. Time tracking - 2. Calling the node's execute_stream() method - 3. Using LangGraph's stream writer to send chunks - 4. Updating streaming buffer in state for downstream nodes - 5. Wrapping business data into standard output format - 6. Error handling - - Special handling for End nodes: - - End nodes don't send chunks via writer (prefix and LLM content already sent) - - End nodes only yield suffix for final result assembly - + 1. Tracking execution time. + 2. Calling the node's ``execute_stream()`` method. + 3. Sending streaming chunks via LangGraph's stream writer. + 4. Updating activation-related state for downstream nodes. + 5. Wrapping business data into a standardized output format. + 6. Handling execution errors. + Args: - state: Workflow state - + state: The current workflow state. + variable_pool: The variable pool used for reading and writing variables. + Yields: - State updates with streaming buffer and final result + Incremental state updates, including activation state changes and + the final wrapped result. """ if not self.check_activate(state): yield self.trans_activate(state) - logger.info(f"jump node: {self.node_id}") + logger.debug(f"jump node: {self.node_id}") return import time @@ -317,7 +321,7 @@ class BaseNode(ABC): # Stream chunks in real-time loop_start = asyncio.get_event_loop().time() - async for item in self.execute_stream(state): + async for item in self.execute_stream(state, variable_pool): # Check timeout if asyncio.get_event_loop().time() - loop_start > timeout: raise TimeoutError() @@ -332,7 +336,7 @@ class BaseNode(ABC): chunks.append(content) # Send chunks for all nodes (including End nodes for suffix) - logger.debug(f"节点 {self.node_id} 发送 chunk #{chunk_count}: {content[:50]}...") + logger.debug(f"Node {self.node_id} sent chunk #{chunk_count}: {content[:50]}...") # 1. Send via stream writer (for real-time client updates) writer({ @@ -344,27 +348,26 @@ class BaseNode(ABC): elapsed_time = time.time() - start_time - logger.info(f"节点 {self.node_id} 流式执行完成,耗时: {elapsed_time:.2f}s, chunks: {chunk_count}") + logger.info(f"Node {self.node_id} streaming execution finished, " + f"time elapsed: {elapsed_time:.2f}s, chunks: {chunk_count}") # Extract processed output (call subclass's _extract_output) extracted_output = self._extract_output(final_result) # Wrap final result - final_output = self._wrap_output(final_result, elapsed_time, state) + final_output = self._wrap_output(final_result, elapsed_time, state, variable_pool) # Store extracted output in runtime variables (for quick access by subsequent nodes) - if isinstance(extracted_output, dict): - runtime_var = extracted_output - else: - runtime_var = {"output": extracted_output} + if extracted_output is not None: + runtime_vars = extracted_output + if not isinstance(extracted_output, dict): + runtime_vars = {"output": extracted_output} + for k, v in runtime_vars.items(): + await variable_pool.new(self.node_id, k, v, self.output_types[k], mut=self.variable_change_able) # Build complete state update (including node_outputs, runtime_vars, and final streaming buffer) state_update = { **final_output, - "messages": state["messages"], - "runtime_vars": { - self.node_id: runtime_var - }, "looping": state["looping"] } @@ -374,41 +377,49 @@ class BaseNode(ABC): except TimeoutError: elapsed_time = time.time() - start_time - logger.error(f"节点 {self.node_id} 执行超时 ({timeout}s)") - error_output = self._wrap_error(f"节点执行超时 ({timeout}s)", elapsed_time, state) + logger.error(f"Node {self.node_id} execution timed out ({timeout}s)") + error_output = self._wrap_error( + f"Node execution timed out ({timeout}s)", + elapsed_time, + state, + variable_pool + ) yield error_output except Exception as e: elapsed_time = time.time() - start_time - logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True) - error_output = self._wrap_error(str(e), elapsed_time, state) + logger.error(f"Node {self.node_id} execution failed: {e}", exc_info=True) + error_output = self._wrap_error(str(e), elapsed_time, state, variable_pool) yield error_output def _wrap_output( self, business_result: Any, elapsed_time: float, - state: WorkflowState + state: WorkflowState, + variable_pool: VariablePool ) -> dict[str, Any]: - """将业务结果包装成标准输出格式 - - Args: - business_result: 节点返回的业务结果 - elapsed_time: 执行耗时 - state: 工作流状态 - - Returns: - 标准化的状态更新字典 - """ - # 提取输入数据(用于记录) - input_data = self._extract_input(state) + """Wraps the business result into a standardized node output format. - # 提取 token 使用情况(如果有) + Args: + business_result: The result returned by the node's business logic. + elapsed_time: Time elapsed during node execution (in seconds). + state: The current workflow state. + variable_pool: The variable pool used for reading and writing variables. + + Returns: + A dictionary representing the standardized state update for this node, + including node outputs, input, output, elapsed time, token usage, and status. + """ + # Extract input data (for logging or audit purposes) + input_data = self._extract_input(state, variable_pool) + + # Extract token usage information (if applicable) token_usage = self._extract_token_usage(business_result) - # 提取实际输出(去除元数据) + # Extract actual output (strip any metadata) output = self._extract_output(business_result) - # 构建标准节点输出 + # Construct standardized node output node_output = { "node_id": self.node_id, "node_type": self.node_type, @@ -423,8 +434,6 @@ class BaseNode(ABC): final_output = { "node_outputs": {self.node_id: node_output}, } - if self.variable_updater: - final_output = final_output | {"variables": state["variables"]} return final_output @@ -432,25 +441,33 @@ class BaseNode(ABC): self, error_message: str, elapsed_time: float, - state: WorkflowState + state: WorkflowState, + variable_pool: VariablePool ) -> dict[str, Any]: - """将错误包装成标准输出格式 - + """Wraps an error into a standardized node output format. + + This method handles both cases: + - If an error edge is defined, the workflow can continue to the error handling node. + - If no error edge exists, the workflow is stopped by raising an exception. + Args: - error_message: 错误信息 - elapsed_time: 执行耗时 - state: 工作流状态 - + error_message: The error message describing the failure. + elapsed_time: Time elapsed during node execution (in seconds). + state: The current workflow state. + variable_pool: The variable pool used for reading and writing variables. + Returns: - 标准化的状态更新字典 + A dictionary representing the standardized state update for this node + when an error edge exists. If no error edge exists, this method + raises an exception to stop the workflow. """ - # 查找错误边 + # Check if the node has an error edge defined error_edge = self._find_error_edge() - # 提取输入数据 - input_data = self._extract_input(state) + # Extract input data (for logging or audit purposes) + input_data = self._extract_input(state, variable_pool) - # 构建错误输出 + # Construct the standardized node output for the error node_output = { "node_id": self.node_id, "node_type": self.node_type, @@ -464,9 +481,9 @@ class BaseNode(ABC): } if error_edge: - # 有错误边:记录错误并继续 + # If an error edge exists, log a warning and continue to error node logger.warning( - f"节点 {self.node_id} 执行失败,跳转到错误处理节点: {error_edge['target']}" + f"Node {self.node_id} execution failed, redirecting to error node: {error_edge['target']}" ) return { "node_outputs": { @@ -476,198 +493,161 @@ class BaseNode(ABC): "error_node": self.node_id } else: + # If no error edge, send the error via stream writer and stop the workflow writer = get_stream_writer() writer({ "type": "node_error", **node_output }) - # 无错误边:抛出异常停止工作流 - logger.error(f"节点 {self.node_id} 执行失败,停止工作流: {error_message}") - raise Exception(f"节点 {self.node_id} 执行失败: {error_message}") + logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}") + raise Exception(f"Node {self.node_id} execution failed: {error_message}") + + def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: + """Extracts the input data for this node (used for logging or audit). + + Subclasses may override this method to customize what input data + should be recorded. - def _extract_input(self, state: WorkflowState) -> dict[str, Any]: - """提取节点输入数据(用于记录) - - 子类可以重写此方法来自定义输入记录。 - Args: - state: 工作流状态 - + state: The current workflow state. + variable_pool: The variable pool used for reading and writing variables. + Returns: - 输入数据字典 + A dictionary containing the node's input data. """ - # 默认返回配置 + # Default implementation returns the node configuration return {"config": self.config} def _extract_output(self, business_result: Any) -> Any: - """从业务结果中提取实际输出 - - 子类可以重写此方法来自定义输出提取。 - + """Extracts the actual output from the business result. + + Subclasses may override this method to customize how the node's + output is extracted. + Args: - business_result: 业务结果 - + business_result: The result returned by the node's business logic. + Returns: - 实际输出 + The actual output extracted from the business result. """ - # 默认直接返回业务结果 + # Default implementation returns the business result directly return business_result def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None: - """从业务结果中提取 token 使用情况 - - 子类可以重写此方法来提取 token 信息。 - + """Extracts token usage information from the business result. + + Subclasses may override this method to extract token usage statistics + (e.g., for LLM nodes). + Args: - business_result: 业务结果 - + business_result: The result returned by the node's business logic. + Returns: - token 使用情况或 None + A dictionary mapping token types to counts, or None if not applicable. """ - # 默认返回 None + # Default implementation returns None return None def _find_error_edge(self) -> dict[str, Any] | None: - """查找错误边 - + """Finds the error edge for this node, if any. + + An error edge is used to redirect workflow execution when this node + fails. + Returns: - 错误边配置或 None + A dictionary representing the error edge configuration if it exists, + or None if no error edge is defined. """ for edge in self.workflow_config.get("edges", []): if edge.get("source") == self.node_id and edge.get("type") == "error": return edge return None - def _render_template(self, template: str, state: WorkflowState | None, strict: bool = True) -> str: - """渲染模板 - - 支持的变量命名空间: - - sys.xxx: 系统变量(message, execution_id, workspace_id, user_id, conversation_id) - - conv.xxx: 会话变量(跨多轮对话保持) - - node_id.xxx: 节点输出 - + @staticmethod + def _render_template(template: str, variable_pool: VariablePool, strict: bool = True) -> str: + """Renders a template string using the provided variable pool. + + Supported variable namespaces: + - sys.xxx: System variables (e.g., message, execution_id, workspace_id, + user_id, conversation_id) + - conv.xxx: Conversation variables (persist across multiple turns) + - node_id.xxx: Node outputs + Args: - template: 模板字符串 - state: 工作流状态 - + template: The template string to render. + variable_pool: The variable pool containing system, conversation, and + node variables. + strict: If True, missing variables will raise an error; if False, + missing variables are ignored. + Returns: - 渲染后的字符串 + The rendered string with all variables substituted. """ from app.core.workflow.template_renderer import render_template - # 处理 state 为 None 的情况 - if state is None: - state = {} - - # 使用变量池获取变量 - pool = VariablePool(state) - - # 构建完整的 variables 结构 - variables = { - "sys": pool.get_all_system_vars(), - "conv": pool.get_all_conversation_vars() - } - return render_template( template=template, - variables=variables, - node_outputs=pool.get_all_node_outputs(), - system_vars=pool.get_all_system_vars(), + conv_vars=variable_pool.get_all_conversation_vars(), + node_outputs=variable_pool.get_all_node_outputs(), + system_vars=variable_pool.get_all_system_vars(), strict=strict ) - def _evaluate_condition(self, expression: str, state: WorkflowState | None) -> bool: - """评估条件表达式 - - 支持的变量命名空间: - - sys.xxx: 系统变量 - - conv.xxx: 会话变量 - - node_id.xxx: 节点输出 - + @staticmethod + def _evaluate_condition(expression: str, variable_pool: VariablePool) -> bool: + """Evaluates a conditional expression using the provided variable pool. + + Supported variable namespaces: + - sys.xxx: System variables + - conv.xxx: Conversation variables + - node_id.xxx: Node outputs + Args: - expression: 条件表达式 - state: 工作流状态 - + expression: The conditional expression to evaluate. + variable_pool: The variable pool containing system, conversation, and + node variables. + Returns: - 布尔值结果 + The boolean result of evaluating the expression. """ from app.core.workflow.expression_evaluator import evaluate_condition - # 处理 state 为 None 的情况 - if state is None: - state = {} - - # 使用变量池获取变量 - pool = VariablePool(state) - - # 构建完整的 variables 结构(包含 sys 和 conv) - variables = { - "sys": pool.get_all_system_vars(), - "conv": pool.get_all_conversation_vars() - } - return evaluate_condition( expression=expression, - variables=variables, - node_outputs=pool.get_all_node_outputs(), - system_vars=pool.get_all_system_vars() + conv_var=variable_pool.get_all_conversation_vars(), + node_outputs=variable_pool.get_all_node_outputs(), + system_vars=variable_pool.get_all_system_vars() ) - def get_variable_pool(self, state: WorkflowState) -> VariablePool: - """获取变量池实例 - - VariablePool 是轻量级包装器,只持有 state 的引用,创建成本极低。 - - Args: - state: 工作流状态 - - Returns: - VariablePool 实例 - - Examples: - >>> pool = self.get_variable_pool(state) - >>> message = pool.get("sys.message") - >>> llm_output = pool.get("llm_qa.output") - """ - return VariablePool(state) - + @staticmethod def get_variable( - self, - selector: list[str] | str, - state: WorkflowState, - default: Any = None + selector: str, + variable_pool: VariablePool, + default: Any = None, + strict: bool = True ) -> Any: - """获取变量值(便捷方法) - - Args: - selector: 变量选择器 - state: 工作流状态 - default: 默认值 - - Returns: - 变量值 - - Examples: - >>> message = self.get_variable("sys.message", state) - >>> output = self.get_variable(["llm_qa", "output"], state) - >>> custom = self.get_variable("var.custom", state, default="默认值") - """ - pool = VariablePool(state) - return pool.get(selector, default=default) + """Retrieves a variable value from the variable pool (convenience method). - def has_variable(self, selector: list[str] | str, state: WorkflowState) -> bool: - """检查变量是否存在(便捷方法) - Args: - selector: 变量选择器 - state: 工作流状态 - + selector: The variable selector (can be namespaced, e.g., sys.xxx, conv.xxx, node_id.xxx). + variable_pool: The variable pool from which to fetch the value. + default: The default value to return if the variable does not exist. + strict: If True, raise an error when the variable is missing; if False, return the default. + Returns: - 变量是否存在 - - Examples: - >>> if self.has_variable("llm_qa.output", state): - ... output = self.get_variable("llm_qa.output", state) + The value of the selected variable, or the default if not found and strict is False. """ - pool = VariablePool(state) - return pool.has(selector) + return variable_pool.get_value(selector, default, strict=strict) + + @staticmethod + def has_variable(selector: str, variable_pool: VariablePool) -> bool: + """Checks whether a variable exists in the variable pool (convenience method). + + Args: + selector: The variable selector (can be namespaced, e.g., sys.xxx, conv.xxx, node_id.xxx). + variable_pool: The variable pool to check. + + Returns: + True if the variable exists in the pool, False otherwise. + """ + return variable_pool.has(selector) diff --git a/api/app/core/workflow/nodes/breaker/node.py b/api/app/core/workflow/nodes/breaker/node.py index f00015d1..8b772d6a 100644 --- a/api/app/core/workflow/nodes/breaker/node.py +++ b/api/app/core/workflow/nodes/breaker/node.py @@ -2,6 +2,8 @@ import logging from typing import Any from app.core.workflow.nodes import BaseNode, WorkflowState +from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.variable_pool import VariablePool logger = logging.getLogger(__name__) @@ -14,15 +16,19 @@ class BreakNode(BaseNode): to False, signaling the outer loop runtime to terminate further iterations. """ - async def execute(self, state: WorkflowState) -> Any: + def _output_types(self) -> dict[str, VariableType]: + return {} + + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: """ Execute the break node. Args: state: Current workflow state, including loop control flags. + variable_pool: Pool of variables for the workflow. Effects: - - Sets 'looping' in the state to False to stop the loop. + - Sets 'looping' in the state too False to stop the loop. - Logs the action for debugging purposes. Returns: diff --git a/api/app/core/workflow/nodes/code/config.py b/api/app/core/workflow/nodes/code/config.py index 8af13f12..e17e841f 100644 --- a/api/app/core/workflow/nodes/code/config.py +++ b/api/app/core/workflow/nodes/code/config.py @@ -1,7 +1,8 @@ from typing import Literal from pydantic import Field, BaseModel -from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableType +from app.core.workflow.nodes.base_config import BaseNodeConfig +from app.core.workflow.variable.base_variable import VariableType class InputVariable(BaseModel): @@ -44,7 +45,7 @@ class CodeNodeConfig(BaseNodeConfig): description="code content" ) - language: Literal['python3', 'nodejs'] = Field( + language: Literal['python3', 'javascript'] = Field( ..., description="language" ) diff --git a/api/app/core/workflow/nodes/code/node.py b/api/app/core/workflow/nodes/code/node.py index 892708f2..bfd176e8 100644 --- a/api/app/core/workflow/nodes/code/node.py +++ b/api/app/core/workflow/nodes/code/node.py @@ -9,8 +9,9 @@ from typing import Any import httpx from app.core.workflow.nodes import BaseNode, WorkflowState -from app.core.workflow.nodes.base_config import VariableType +from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.nodes.code.config import CodeNodeConfig +from app.core.workflow.variable_pool import VariablePool logger = logging.getLogger(__name__) @@ -52,6 +53,12 @@ class CodeNode(BaseNode): super().__init__(node_config, workflow_config) self.typed_config: CodeNodeConfig | None = None + def _output_types(self) -> dict[str, VariableType]: + output_dict = {} + for output in self.typed_config.output_variables: + output_dict[output.name] = output.type + return output_dict + def extract_result(self, content: str): match = re.search(r'<>(.*?)<>', content, re.DOTALL) if match: @@ -92,11 +99,11 @@ class CodeNode(BaseNode): else: raise RuntimeError("The output of main must be a dictionary") - async def execute(self, state: WorkflowState) -> Any: + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: self.typed_config = CodeNodeConfig(**self.config) input_variable_dict = {} for input_variable in self.typed_config.input_variables: - input_variable_dict[input_variable.name] = self.get_variable(input_variable.variable, state) + input_variable_dict[input_variable.name] = self.get_variable(input_variable.variable, variable_pool) code = base64.b64decode( self.typed_config.code @@ -110,7 +117,7 @@ class CodeNode(BaseNode): code=code, inputs_variable=input_variable_dict, ) - elif self.typed_config.language == 'nodejs': + elif self.typed_config.language == 'javascript': final_script = NODEJS_SCRIPT_TEMPLATE.substitute( code=code, inputs_variable=input_variable_dict, diff --git a/api/app/core/workflow/nodes/configs.py b/api/app/core/workflow/nodes/configs.py index d73754f6..e4e418fe 100644 --- a/api/app/core/workflow/nodes/configs.py +++ b/api/app/core/workflow/nodes/configs.py @@ -8,7 +8,6 @@ from app.core.workflow.nodes.assigner.config import AssignerNodeConfig from app.core.workflow.nodes.base_config import ( BaseNodeConfig, VariableDefinition, - VariableType, ) from app.core.workflow.nodes.code.config import CodeNodeConfig from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig @@ -23,21 +22,18 @@ from app.core.workflow.nodes.parameter_extractor.config import ParameterExtracto from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig from app.core.workflow.nodes.start.config import StartNodeConfig from app.core.workflow.nodes.tool.config import ToolNodeConfig -from app.core.workflow.nodes.transform.config import TransformNodeConfig from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig __all__ = [ # 基础类 "BaseNodeConfig", "VariableDefinition", - "VariableType", # 节点配置 "StartNodeConfig", "EndNodeConfig", "LLMNodeConfig", "MessageConfig", "AgentNodeConfig", - "TransformNodeConfig", "IfElseNodeConfig", "KnowledgeRetrievalNodeConfig", "AssignerNodeConfig", diff --git a/api/app/core/workflow/nodes/cycle_graph/config.py b/api/app/core/workflow/nodes/cycle_graph/config.py index 445ddd9a..72971286 100644 --- a/api/app/core/workflow/nodes/cycle_graph/config.py +++ b/api/app/core/workflow/nodes/cycle_graph/config.py @@ -2,7 +2,8 @@ from typing import Any from pydantic import Field, BaseModel, field_validator -from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableType +from app.core.workflow.nodes.base_config import BaseNodeConfig +from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType @@ -127,4 +128,9 @@ class IterationNodeConfig(BaseNodeConfig): description="Output of the loop iteration" ) + output_type: VariableType = Field( + ..., + description="Data type of the loop iteration output" + ) + diff --git a/api/app/core/workflow/nodes/cycle_graph/iteration.py b/api/app/core/workflow/nodes/cycle_graph/iteration.py index cd63d233..762da847 100644 --- a/api/app/core/workflow/nodes/cycle_graph/iteration.py +++ b/api/app/core/workflow/nodes/cycle_graph/iteration.py @@ -7,6 +7,7 @@ from langgraph.graph.state import CompiledStateGraph from app.core.workflow.nodes import WorkflowState from app.core.workflow.nodes.cycle_graph import IterationNodeConfig +from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.variable_pool import VariablePool logger = logging.getLogger(__name__) @@ -28,6 +29,8 @@ class IterationRuntime: node_id: str, config: dict[str, Any], state: WorkflowState, + variable_pool: VariablePool, + child_variable_pool: VariablePool, ): """ Initialize the iteration runtime. @@ -44,11 +47,13 @@ class IterationRuntime: self.node_id = node_id self.typed_config = IterationNodeConfig(**config) self.looping = True + self.variable_pool = variable_pool + self.child_variable_pool = child_variable_pool self.output_value = None self.result: list = [] - def _init_iteration_state(self, item, idx): + async def _init_iteration_state(self, item, idx): """ Initialize a per-iteration copy of the workflow state. @@ -62,10 +67,9 @@ class IterationRuntime: loopstate = WorkflowState( **self.state ) - loopstate["runtime_vars"][self.node_id] = { - "item": item, - "index": idx, - } + self.child_variable_pool.copy(self.variable_pool) + await self.child_variable_pool.new(self.node_id, "item", item, VariableType.type_map(item), mut=True) + await self.child_variable_pool.new(self.node_id, "index", item, VariableType.type_map(item), mut=True) loopstate["node_outputs"][self.node_id] = { "item": item, "index": idx, @@ -74,6 +78,11 @@ class IterationRuntime: loopstate["activate"][self.start_id] = True return loopstate + def merge_conv_vars(self): + self.variable_pool.get_all_conversation_vars().update( + self.child_variable_pool.get_all_conversation_vars() + ) + async def run_task(self, item, idx): """ Execute a single iteration asynchronously. @@ -82,8 +91,8 @@ class IterationRuntime: item: The input element for this iteration. idx: The index of this iteration. """ - result = await self.graph.ainvoke(self._init_iteration_state(item, idx)) - output = VariablePool(result).get(self.output_value) + result = await self.graph.ainvoke(await self._init_iteration_state(item, idx)) + output = self.child_variable_pool.get_value(self.output_value) if isinstance(output, list) and self.typed_config.flatten: self.result.extend(output) else: @@ -125,7 +134,7 @@ class IterationRuntime: input_expression = re.sub(pattern, r"\1", self.typed_config.input).strip() self.output_value = re.sub(pattern, r"\1", self.typed_config.output).strip() - array_obj = VariablePool(self.state).get(input_expression) + array_obj = self.variable_pool.get_value(input_expression) if not isinstance(array_obj, list): raise RuntimeError("Cannot iterate over a non-list variable") child_state = [] @@ -137,14 +146,16 @@ class IterationRuntime: logger.info(f"Iteration node {self.node_id}: running, concurrency {len(tasks)}") idx += self.typed_config.parallel_count child_state.extend(await asyncio.gather(*tasks)) + self.merge_conv_vars() else: # Execute iterations sequentially while idx < len(array_obj) and self.looping: logger.info(f"Iteration node {self.node_id}: running") item = array_obj[idx] - result = await self.graph.ainvoke(self._init_iteration_state(item, idx)) + result = await self.graph.ainvoke(await self._init_iteration_state(item, idx)) child_state.append(result) - output = VariablePool(result).get(self.output_value) + output = self.child_variable_pool.get_value(self.output_value) + self.merge_conv_vars() if isinstance(output, list) and self.typed_config.flatten: self.result.extend(output) else: diff --git a/api/app/core/workflow/nodes/cycle_graph/loop.py b/api/app/core/workflow/nodes/cycle_graph/loop.py index 6a15891f..7204a642 100644 --- a/api/app/core/workflow/nodes/cycle_graph/loop.py +++ b/api/app/core/workflow/nodes/cycle_graph/loop.py @@ -31,6 +31,8 @@ class LoopRuntime: node_id: str, config: dict[str, Any], state: WorkflowState, + variable_pool: VariablePool, + child_variable_pool: VariablePool ): """ Initialize the loop runtime executor. @@ -40,6 +42,8 @@ class LoopRuntime: node_id: The unique identifier of the loop node in the workflow. config: Raw configuration dictionary for the loop node. state: The current workflow state before entering the loop. + variable_pool: A VariablePool instance for accessing and modifying workflow variables. + child_variable_pool: A VariablePool instance for managing child node outputs. """ self.start_id = start_id self.graph = graph @@ -47,8 +51,10 @@ class LoopRuntime: self.node_id = node_id self.typed_config = LoopNodeConfig(**config) self.looping = True + self.variable_pool = variable_pool + self.child_variable_pool = child_variable_pool - def _init_loop_state(self): + async def _init_loop_state(self): """ Initialize workflow state for loop execution. @@ -62,33 +68,35 @@ class LoopRuntime: Returns: WorkflowState: A prepared workflow state used for loop execution. """ - pool = VariablePool(self.state) # 循环变量 - self.state["runtime_vars"][self.node_id] = { - variable.name: evaluate_expression( - expression=variable.value, - variables=pool.get_all_conversation_vars(), - node_outputs=pool.get_all_node_outputs(), - system_vars=pool.get_all_system_vars(), - ) - if variable.input_type == ValueInputType.VARIABLE - else TypeTransformer.transform(variable.value, variable.type) - for variable in self.typed_config.cycle_vars - } - self.state["node_outputs"][self.node_id] = { - variable.name: evaluate_expression( - expression=variable.value, - variables=pool.get_all_conversation_vars(), - node_outputs=pool.get_all_node_outputs(), - system_vars=pool.get_all_system_vars(), - ) - if variable.input_type == ValueInputType.VARIABLE - else TypeTransformer.transform(variable.value, variable.type) - for variable in self.typed_config.cycle_vars - } + self.child_variable_pool.copy(self.variable_pool) + + for variable in self.typed_config.cycle_vars: + if variable.input_type == ValueInputType.VARIABLE: + value = evaluate_expression( + expression=variable.value, + conv_var=self.variable_pool.get_all_conversation_vars(), + node_outputs=self.variable_pool.get_all_node_outputs(), + system_vars=self.variable_pool.get_all_system_vars(), + ) + else: + value = TypeTransformer.transform(variable.value, variable.type) + await self.child_variable_pool.new(self.node_id, variable.name, value, variable.type, mut=True) loopstate = WorkflowState( **self.state ) + loopstate["node_outputs"][self.node_id] = { + variable.name: evaluate_expression( + expression=variable.value, + conv_var=self.variable_pool.get_all_conversation_vars(), + node_outputs=self.variable_pool.get_all_node_outputs(), + system_vars=self.variable_pool.get_all_system_vars(), + ) + if variable.input_type == ValueInputType.VARIABLE + else TypeTransformer.transform(variable.value, variable.type) + for variable in self.typed_config.cycle_vars + } + loopstate["looping"] = 1 loopstate["activate"][self.start_id] = True return loopstate @@ -134,7 +142,12 @@ class LoopRuntime: case _: raise ValueError(f"Invalid condition: {operator}") - def evaluate_conditional(self, state) -> bool: + def merge_conv_vars(self): + self.variable_pool.variables["conv"].update( + self.child_variable_pool.variables.get("conv", {}) + ) + + def evaluate_conditional(self) -> bool: """ Evaluate the loop continuation condition at runtime. @@ -143,18 +156,15 @@ class LoopRuntime: - Evaluates each comparison expression immediately - Combines results using the configured logical operator (AND / OR) - Args: - state: The current workflow state during loop execution. - Returns: bool: True if the loop should continue, False otherwise. """ conditions = [] for expression in self.typed_config.condition.expressions: - left_value = VariablePool(state).get(expression.left) + left_value = self.child_variable_pool.get_value(expression.left) evaluator = ConditionExpressionResolver.resolve_by_value(left_value)( - VariablePool(state), + self.child_variable_pool, expression.left, expression.right, expression.input_type @@ -177,16 +187,18 @@ class LoopRuntime: Returns: dict[str, Any]: The final runtime variables of this loop node. """ - loopstate = self._init_loop_state() + loopstate = await self._init_loop_state() loop_time = self.typed_config.max_loop child_state = [] - while self.evaluate_conditional(loopstate) and self.looping and loop_time > 0: + while not self.evaluate_conditional() and self.looping and loop_time > 0: logger.info(f"loop node {self.node_id}: running") result = await self.graph.ainvoke(loopstate) child_state.append(result) + + self.merge_conv_vars() if result["looping"] == 2: self.looping = False loop_time -= 1 logger.info(f"loop node {self.node_id}: execution completed") - return loopstate["runtime_vars"][self.node_id] | {"__child_state": child_state} + return self.child_variable_pool.get_node_output(self.node_id) | {"__child_state": child_state} diff --git a/api/app/core/workflow/nodes/cycle_graph/node.py b/api/app/core/workflow/nodes/cycle_graph/node.py index 82782658..9a3cf6b2 100644 --- a/api/app/core/workflow/nodes/cycle_graph/node.py +++ b/api/app/core/workflow/nodes/cycle_graph/node.py @@ -6,9 +6,12 @@ from langgraph.graph.state import CompiledStateGraph from app.core.workflow.nodes import WorkflowState from app.core.workflow.nodes.base_node import BaseNode +from app.core.workflow.nodes.cycle_graph import LoopNodeConfig, IterationNodeConfig from app.core.workflow.nodes.cycle_graph.iteration import IterationRuntime from app.core.workflow.nodes.cycle_graph.loop import LoopRuntime from app.core.workflow.nodes.enums import NodeType +from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.variable_pool import VariablePool logger = logging.getLogger(__name__) @@ -35,9 +38,38 @@ class CycleGraphNode(BaseNode): self.start_node_id = None # ID of the start node within the cycle self.graph: StateGraph | CompiledStateGraph | None = None + self.child_variable_pool: VariablePool | None = None self.build_graph() self.iteration_flag = True + def _output_types(self) -> dict[str, VariableType]: + outputs = {"__child_state": VariableType.ARRAY_OBJECT} + if self.node_type == NodeType.LOOP: + # Loop node outputs the final state of the loop + config = LoopNodeConfig(**self.config) + for var_def in config.cycle_vars: + outputs[var_def.name] = var_def.type + return outputs + elif self.node_type == NodeType.ITERATION: + # Iteration node outputs the processed collection + config = IterationNodeConfig(**self.config) + if config.output_type in [ + VariableType.ARRAY_FILE, + VariableType.ARRAY_STRING, + VariableType.NUMBER, + VariableType.ARRAY_OBJECT, + VariableType.BOOLEAN + ]: + if config.flatten: + outputs['output'] = config.output_type + else: + outputs['output'] = VariableType.ARRAY_STRING + else: + outputs['output'] = VariableType(f"array[{config.output_type}]") + return outputs + else: + raise KeyError(f"Valid Cycle Node Type - {self.node_type}") + def pure_cycle_graph(self) -> tuple[list, list]: """ Extract cycle-scoped nodes and internal edges from the workflow configuration. @@ -103,17 +135,20 @@ class CycleGraphNode(BaseNode): """ from app.core.workflow.graph_builder import GraphBuilder self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph() + self.child_variable_pool = VariablePool() builder = GraphBuilder( { "nodes": self.cycle_nodes, "edges": self.cycle_edges, }, - subgraph=True + subgraph=True, + variable_pool=self.child_variable_pool ) self.start_node_id = builder.start_node_id self.graph = builder.build() + self.child_variable_pool = builder.variable_pool - async def execute(self, state: WorkflowState) -> Any: + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: """ Execute the cycle node at runtime. @@ -123,6 +158,7 @@ class CycleGraphNode(BaseNode): Args: state: The current workflow state when entering the cycle node. + variable_pool: Variable Pool Returns: Any: The runtime result produced by the loop or iteration executor. @@ -137,6 +173,8 @@ class CycleGraphNode(BaseNode): node_id=self.node_id, config=self.config, state=state, + variable_pool=variable_pool, + child_variable_pool=self.child_variable_pool, ).run() if self.node_type == NodeType.ITERATION: return await IterationRuntime( @@ -145,5 +183,7 @@ class CycleGraphNode(BaseNode): node_id=self.node_id, config=self.config, state=state, + variable_pool=variable_pool, + child_variable_pool=self.child_variable_pool ).run() raise RuntimeError("Unknown cycle node type") diff --git a/api/app/core/workflow/nodes/end/config.py b/api/app/core/workflow/nodes/end/config.py index 50e84a36..f534dfb5 100644 --- a/api/app/core/workflow/nodes/end/config.py +++ b/api/app/core/workflow/nodes/end/config.py @@ -2,7 +2,8 @@ from pydantic import Field -from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType +from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition +from app.core.workflow.variable.base_variable import VariableType class EndNodeConfig(BaseNodeConfig): diff --git a/api/app/core/workflow/nodes/end/node.py b/api/app/core/workflow/nodes/end/node.py index 3a5153a9..a13a8153 100644 --- a/api/app/core/workflow/nodes/end/node.py +++ b/api/app/core/workflow/nodes/end/node.py @@ -7,6 +7,8 @@ End 节点实现 import logging from app.core.workflow.nodes.base_node import BaseNode, WorkflowState +from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.variable_pool import VariablePool logger = logging.getLogger(__name__) @@ -17,12 +19,18 @@ class EndNode(BaseNode): 工作流的结束节点,根据配置的模板输出最终结果。 支持实时流式输出:如果模板引用了上游节点的输出,会实时监听其流式缓冲区。 """ + def _output_types(self) -> dict[str, VariableType]: + """声明此节点的输出类型""" + return { + "output": VariableType.STRING + } - async def execute(self, state: WorkflowState) -> str: + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> str: """执行 end 节点业务逻辑 Args: state: 工作流状态 + variable_pool: 变量池 Returns: 最终输出字符串 @@ -34,7 +42,7 @@ class EndNode(BaseNode): # 如果配置了输出模板,使用模板渲染;否则使用默认输出 if output_template: - output = self._render_template(output_template, state, strict=False) + output = self._render_template(output_template, variable_pool, strict=False) else: output = "" diff --git a/api/app/core/workflow/nodes/enums.py b/api/app/core/workflow/nodes/enums.py index aaf49a11..6ad1c6a8 100644 --- a/api/app/core/workflow/nodes/enums.py +++ b/api/app/core/workflow/nodes/enums.py @@ -9,7 +9,6 @@ class NodeType(StrEnum): KNOWLEDGE_RETRIEVAL = "knowledge-retrieval" IF_ELSE = "if-else" CODE = "code" - TRANSFORM = "transform" QUESTION_CLASSIFIER = "question-classifier" HTTP_REQUEST = "http-request" TOOL = "tool" diff --git a/api/app/core/workflow/nodes/http_request/node.py b/api/app/core/workflow/nodes/http_request/node.py index 141cba79..64fdfcb9 100644 --- a/api/app/core/workflow/nodes/http_request/node.py +++ b/api/app/core/workflow/nodes/http_request/node.py @@ -10,6 +10,8 @@ from httpx import AsyncClient, Response, Timeout from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput +from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.variable_pool import VariablePool logger = logging.getLogger(__file__) @@ -34,6 +36,14 @@ class HttpRequestNode(BaseNode): super().__init__(node_config, workflow_config) self.typed_config: HttpRequestNodeConfig | None = None + def _output_types(self) -> dict[str, VariableType]: + return { + "body": VariableType.STRING, + "status_code": VariableType.NUMBER, + "headers": VariableType.OBJECT, + "output": VariableType.STRING + } + def _build_timeout(self) -> Timeout: """ Build httpx Timeout configuration. @@ -50,7 +60,7 @@ class HttpRequestNode(BaseNode): ) return timeout - def _build_auth(self, state: WorkflowState) -> dict[str, str]: + def _build_auth(self, variable_pool: VariablePool) -> dict[str, str]: """ Build authentication-related HTTP headers. @@ -58,12 +68,12 @@ class HttpRequestNode(BaseNode): the current workflow runtime state. Args: - state: Current workflow runtime state. + variable_pool: Variable Pool Returns: A dictionary of HTTP headers used for authentication. """ - api_key = self._render_template(self.typed_config.auth.api_key, state) + api_key = self._render_template(self.typed_config.auth.api_key, variable_pool) match self.typed_config.auth.auth_type: case HttpAuthType.NONE: return {} @@ -82,7 +92,7 @@ class HttpRequestNode(BaseNode): case _: raise RuntimeError(f"Auth type not supported: {self.typed_config.auth.auth_type}") - def _build_header(self, state: WorkflowState) -> dict[str, str]: + def _build_header(self, variable_pool: VariablePool) -> dict[str, str]: """ Build HTTP request headers. @@ -90,10 +100,10 @@ class HttpRequestNode(BaseNode): """ headers = {} for key, value in self.typed_config.headers.items(): - headers[self._render_template(key, state)] = self._render_template(value, state) + headers[self._render_template(key, variable_pool)] = self._render_template(value, variable_pool) return headers - def _build_params(self, state: WorkflowState) -> dict[str, str]: + def _build_params(self, variable_pool: VariablePool) -> dict[str, str]: """ Build URL query parameters. @@ -101,10 +111,10 @@ class HttpRequestNode(BaseNode): """ params = {} for key, value in self.typed_config.params.items(): - params[self._render_template(key, state)] = self._render_template(value, state) + params[self._render_template(key, variable_pool)] = self._render_template(value, variable_pool) return params - def _build_content(self, state) -> dict[str, Any]: + def _build_content(self, variable_pool: VariablePool) -> dict[str, Any]: """ Build HTTP request body arguments for httpx request methods. @@ -120,13 +130,13 @@ class HttpRequestNode(BaseNode): return {} case HttpContentType.JSON: content["json"] = json.loads(self._render_template( - self.typed_config.body.data, state + self.typed_config.body.data, variable_pool )) case HttpContentType.FROM_DATA: data = {} for item in self.typed_config.body.data: if item.type == "text": - data[self._render_template(item.key, state)] = self._render_template(item.value, state) + data[self._render_template(item.key, variable_pool)] = self._render_template(item.value, variable_pool) elif item.type == "file": # TODO: File support (Feature) pass @@ -136,11 +146,11 @@ class HttpRequestNode(BaseNode): pass case HttpContentType.WWW_FORM: content["data"] = json.loads(self._render_template( - json.dumps(self.typed_config.body.data), state + json.dumps(self.typed_config.body.data), variable_pool )) case HttpContentType.RAW: - content["content"] = self._render_template(self.typed_config.body.data, state) + content["content"] = self._render_template(self.typed_config.body.data, variable_pool) case _: raise RuntimeError(f"Content type not supported: {self.typed_config.body.content_type}") return content @@ -165,7 +175,7 @@ class HttpRequestNode(BaseNode): case _: raise RuntimeError(f"HttpRequest method not supported: {self.typed_config.method}") - async def execute(self, state: WorkflowState) -> dict | str: + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict | str: """ Execute the HTTP request node. @@ -176,6 +186,7 @@ class HttpRequestNode(BaseNode): Args: state: Current workflow runtime state. + variable_pool: Variable Pool Returns: - dict: Serialized HttpRequestNodeOutput on success @@ -185,8 +196,8 @@ class HttpRequestNode(BaseNode): async with httpx.AsyncClient( verify=self.typed_config.verify_ssl, timeout=self._build_timeout(), - headers=self._build_header(state) | self._build_auth(state), - params=self._build_params(state), + headers=self._build_header(variable_pool) | self._build_auth(variable_pool), + params=self._build_params(variable_pool), follow_redirects=True ) as client: retries = self.typed_config.retry.max_attempts @@ -194,8 +205,8 @@ class HttpRequestNode(BaseNode): try: request_func = self._get_client_method(client) resp = await request_func( - url=self._render_template(self.typed_config.url, state), - **self._build_content(state) + url=self._render_template(self.typed_config.url, variable_pool), + **self._build_content(variable_pool) ) resp.raise_for_status() logger.info(f"Node {self.node_id}: HTTP request succeeded") diff --git a/api/app/core/workflow/nodes/if_else/node.py b/api/app/core/workflow/nodes/if_else/node.py index cf5a1499..3c6d0e36 100644 --- a/api/app/core/workflow/nodes/if_else/node.py +++ b/api/app/core/workflow/nodes/if_else/node.py @@ -6,6 +6,8 @@ from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator from app.core.workflow.nodes.if_else import IfElseNodeConfig from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance +from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.variable_pool import VariablePool logger = logging.getLogger(__name__) @@ -15,6 +17,11 @@ class IfElseNode(BaseNode): super().__init__(node_config, workflow_config) self.typed_config: IfElseNodeConfig | None = None + def _output_types(self) -> dict[str, VariableType]: + return { + "output": VariableType.STRING + } + @staticmethod def _evaluate(operator, instance: CompareOperatorInstance) -> Any: match operator: @@ -45,7 +52,7 @@ class IfElseNode(BaseNode): case _: raise ValueError(f"Invalid condition: {operator}") - def evaluate_conditional_edge_expressions(self, state) -> list[bool]: + def evaluate_conditional_edge_expressions(self, variable_pool: VariablePool) -> list[bool]: """ Build conditional edge expressions for the If-Else node. @@ -72,11 +79,11 @@ class IfElseNode(BaseNode): pattern = r"\{\{\s*(.*?)\s*\}\}" left_string = re.sub(pattern, r"\1", expression.left).strip() try: - left_value = self.get_variable(left_string, state) + left_value = self.get_variable(left_string, variable_pool) except KeyError: left_value = None evaluator = ConditionExpressionResolver.resolve_by_value(left_value)( - self.get_variable_pool(state), + variable_pool, expression.left, expression.right, expression.input_type @@ -95,7 +102,7 @@ class IfElseNode(BaseNode): return conditions - async def execute(self, state: WorkflowState) -> Any: + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: """ Execute the conditional branching logic of the node. @@ -105,13 +112,13 @@ class IfElseNode(BaseNode): Args: state (WorkflowState): The current workflow state, containing variables, messages, node outputs, etc. + variable_pool: Variable Pool Returns: str: The matched branch identifier, e.g., 'CASE1', 'CASE2', ..., used for node transitions. """ self.typed_config = IfElseNodeConfig(**self.config) - expressions = self.evaluate_conditional_edge_expressions(state) - # TODO: 变量类型及文本类型解析 + expressions = self.evaluate_conditional_edge_expressions(variable_pool) for i in range(len(expressions)): if expressions[i]: logger.info(f"Node {self.node_id}: switched to branch CASE {i + 1}") diff --git a/api/app/core/workflow/nodes/jinja_render/node.py b/api/app/core/workflow/nodes/jinja_render/node.py index 822f1918..240b003b 100644 --- a/api/app/core/workflow/nodes/jinja_render/node.py +++ b/api/app/core/workflow/nodes/jinja_render/node.py @@ -5,6 +5,8 @@ from app.core.workflow.nodes import WorkflowState from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.jinja_render.config import JinjaRenderNodeConfig from app.core.workflow.template_renderer import TemplateRenderer +from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.variable_pool import VariablePool logger = logging.getLogger(__name__) @@ -14,7 +16,12 @@ class JinjaRenderNode(BaseNode): super().__init__(node_config, workflow_config) self.typed_config: JinjaRenderNodeConfig | None = None - async def execute(self, state: WorkflowState) -> Any: + def _output_types(self) -> dict[str, VariableType]: + return { + "output": VariableType.STRING + } + + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: """ Execute the node: render the Jinja2 template with mapped variables. @@ -24,6 +31,7 @@ class JinjaRenderNode(BaseNode): Args: state (WorkflowState): Current workflow state containing variables, node outputs, and runtime variables. + variable_pool: Variable Pool Returns: dict[str, Any]: Node output dictionary containing the rendered result @@ -40,7 +48,7 @@ class JinjaRenderNode(BaseNode): context = {} for variable in self.typed_config.mapping: try: - context[variable.name] = self.get_variable(variable.value, state) + context[variable.name] = self.get_variable(variable.value, variable_pool) except Exception: logger.info(f"variable not found, var: {variable.value}") continue diff --git a/api/app/core/workflow/nodes/knowledge/node.py b/api/app/core/workflow/nodes/knowledge/node.py index 997135f3..1e146721 100644 --- a/api/app/core/workflow/nodes/knowledge/node.py +++ b/api/app/core/workflow/nodes/knowledge/node.py @@ -8,6 +8,8 @@ from app.core.models import RedBearRerank, RedBearModelConfig from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig +from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.variable_pool import VariablePool from app.db import get_db_read from app.models import knowledge_model, knowledgeshare_model, ModelType from app.repositories import knowledge_repository, knowledgeshare_repository @@ -22,6 +24,11 @@ class KnowledgeRetrievalNode(BaseNode): super().__init__(node_config, workflow_config) self.typed_config: KnowledgeRetrievalNodeConfig | None = None + def _output_types(self) -> dict[str, VariableType]: + return { + "output": VariableType.ARRAY_STRING + } + @staticmethod def _build_kb_filter(kb_ids: list[uuid.UUID], permission: knowledge_model.PermissionType): """ @@ -149,7 +156,7 @@ class KnowledgeRetrievalNode(BaseNode): ) return reranker - async def execute(self, state: WorkflowState) -> Any: + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: """ Execute the knowledge retrieval workflow node. @@ -163,6 +170,7 @@ class KnowledgeRetrievalNode(BaseNode): Args: state (WorkflowState): Current workflow execution state. + variable_pool: Variable Pool Returns: Any: List of retrieved knowledge chunks (dict format). @@ -171,7 +179,7 @@ class KnowledgeRetrievalNode(BaseNode): RuntimeError: If no valid knowledge base is found or access is denied. """ self.typed_config = KnowledgeRetrievalNodeConfig(**self.config) - query = self._render_template(self.typed_config.query, state) + query = self._render_template(self.typed_config.query, variable_pool) with get_db_read() as db: knowledge_bases = self.typed_config.knowledge_bases existing_ids = self._get_existing_kb_ids(db, [kb.kb_id for kb in knowledge_bases]) diff --git a/api/app/core/workflow/nodes/llm/config.py b/api/app/core/workflow/nodes/llm/config.py index 265724f3..48c51aa1 100644 --- a/api/app/core/workflow/nodes/llm/config.py +++ b/api/app/core/workflow/nodes/llm/config.py @@ -4,7 +4,8 @@ from typing import Any from pydantic import BaseModel, Field, field_validator -from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType +from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition +from app.core.workflow.variable.base_variable import VariableType class MessageConfig(BaseModel): diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index f315b238..1246324d 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -15,6 +15,8 @@ from app.core.exceptions import BusinessException from app.core.models import RedBearLLM, RedBearModelConfig from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.llm.config import LLMNodeConfig +from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.variable_pool import VariablePool from app.db import get_db_context from app.models import ModelType from app.services.model_service import ModelConfigService @@ -66,19 +68,27 @@ class LLMNode(BaseNode): - ai/assistant: AI 消息(AIMessage) """ + def _output_types(self) -> dict[str, VariableType]: + return {"output": VariableType.STRING} + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) self.typed_config: LLMNodeConfig | None = None - def _render_context(self, message, state): - context = f"{self._render_template(self.typed_config.context, state)}" + def _render_context(self, message: str, variable_pool: VariablePool): + context = f"{self._render_template(self.typed_config.context, variable_pool)}" return re.sub(r"{{context}}", context, message) - def _prepare_llm(self, state: WorkflowState, stream: bool = False) -> tuple[RedBearLLM, list | str]: + def _prepare_llm( + self, + state: WorkflowState, + variable_pool: VariablePool, + stream: bool = False + ) -> tuple[RedBearLLM, list | str]: """准备 LLM 实例(公共逻辑) Args: - state: 工作流状态 + variable_pool: 变量池 Returns: (llm, messages_or_prompt): LLM 实例和消息列表或 prompt 字符串 @@ -94,8 +104,8 @@ class LLMNode(BaseNode): for msg_config in messages_config: role = msg_config.role.lower() content_template = msg_config.content - content_template = self._render_context(content_template, state) - content = self._render_template(content_template, state) + content_template = self._render_context(content_template, variable_pool) + content = self._render_template(content_template, variable_pool) # 根据角色创建对应的消息对象 if role == "system": @@ -115,7 +125,7 @@ class LLMNode(BaseNode): else: # 使用简单的 prompt 格式(向后兼容) prompt_template = self.config.get("prompt", "") - prompt_or_messages = self._render_template(prompt_template, state) + prompt_or_messages = self._render_template(prompt_template, variable_pool) # 2. 获取模型配置 model_id = self.config.get("model_id") @@ -159,17 +169,18 @@ class LLMNode(BaseNode): return llm, prompt_or_messages - async def execute(self, state: WorkflowState) -> AIMessage: + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> AIMessage: """非流式执行 LLM 调用 Args: state: 工作流状态 + variable_pool: 变量池 Returns: LLM 响应消息 """ # self.typed_config = LLMNodeConfig(**self.config) - llm, prompt_or_messages = self._prepare_llm(state, True) + llm, prompt_or_messages = self._prepare_llm(state, variable_pool, False) logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)") @@ -186,9 +197,9 @@ class LLMNode(BaseNode): # 返回 AIMessage(包含响应元数据) return response if isinstance(response, AIMessage) else AIMessage(content=content) - def _extract_input(self, state: WorkflowState) -> dict[str, Any]: + def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: """提取输入数据(用于记录)""" - _, prompt_or_messages = self._prepare_llm(state) + _, prompt_or_messages = self._prepare_llm(state, variable_pool) return { "prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None, @@ -221,18 +232,19 @@ class LLMNode(BaseNode): } return None - async def execute_stream(self, state: WorkflowState): + async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool): """流式执行 LLM 调用 Args: state: 工作流状态 + variable_pool: 变量池 Yields: 文本片段(chunk)或完成标记 """ self.typed_config = LLMNodeConfig(**self.config) - llm, prompt_or_messages = self._prepare_llm(state, True) + llm, prompt_or_messages = self._prepare_llm(state, variable_pool, True) logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)") logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}") diff --git a/api/app/core/workflow/nodes/memory/node.py b/api/app/core/workflow/nodes/memory/node.py index 13860bec..ddbe4b99 100644 --- a/api/app/core/workflow/nodes/memory/node.py +++ b/api/app/core/workflow/nodes/memory/node.py @@ -3,6 +3,8 @@ from typing import Any from app.core.workflow.nodes import WorkflowState from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig +from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.variable_pool import VariablePool from app.db import get_db_read from app.services.memory_agent_service import MemoryAgentService from app.tasks import write_message_task @@ -13,17 +15,23 @@ class MemoryReadNode(BaseNode): super().__init__(node_config, workflow_config) self.typed_config: MemoryReadNodeConfig | None = None - async def execute(self, state: WorkflowState) -> Any: + def _output_types(self) -> dict[str, VariableType]: + return { + "answer": VariableType.STRING, + "intermediate_outputs": VariableType.ARRAY_OBJECT + } + + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: self.typed_config = MemoryReadNodeConfig(**self.config) with get_db_read() as db: - end_user_id = self.get_variable("sys.user_id", state) + end_user_id = self.get_variable("sys.user_id", variable_pool) if not end_user_id: raise RuntimeError("End user id is required") return await MemoryAgentService().read_memory( end_user_id=end_user_id, - message=self._render_template(self.typed_config.message, state), + message=self._render_template(self.typed_config.message, variable_pool), config_id=self.typed_config.config_id, search_switch=self.typed_config.search_switch, history=[], @@ -38,16 +46,19 @@ class MemoryWriteNode(BaseNode): super().__init__(node_config, workflow_config) self.typed_config: MemoryWriteNodeConfig | None = None - async def execute(self, state: WorkflowState) -> Any: + def _output_types(self) -> dict[str, VariableType]: + return {"output": VariableType.STRING} + + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: self.typed_config = MemoryWriteNodeConfig(**self.config) - end_user_id = self.get_variable("sys.user_id", state) + end_user_id = self.get_variable("sys.user_id", variable_pool) if not end_user_id: raise RuntimeError("End user id is required") write_message_task.delay( end_user_id, - self._render_template(self.typed_config.message, state), + self._render_template(self.typed_config.message, variable_pool), str(self.typed_config.config_id), "neo4j", "" diff --git a/api/app/core/workflow/nodes/node_factory.py b/api/app/core/workflow/nodes/node_factory.py index fb2fe00f..00120ca0 100644 --- a/api/app/core/workflow/nodes/node_factory.py +++ b/api/app/core/workflow/nodes/node_factory.py @@ -22,7 +22,6 @@ from app.core.workflow.nodes.llm import LLMNode from app.core.workflow.nodes.memory import MemoryReadNode, MemoryWriteNode from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode from app.core.workflow.nodes.start import StartNode -from app.core.workflow.nodes.transform import TransformNode from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode from app.core.workflow.nodes.question_classifier import QuestionClassifierNode from app.core.workflow.nodes.breaker import BreakNode @@ -37,7 +36,6 @@ WorkflowNode = Union[ LLMNode, IfElseNode, AgentNode, - TransformNode, AssignerNode, HttpRequestNode, KnowledgeRetrievalNode, @@ -67,7 +65,6 @@ class NodeFactory: NodeType.END: EndNode, NodeType.LLM: LLMNode, NodeType.AGENT: AgentNode, - NodeType.TRANSFORM: TransformNode, NodeType.IF_ELSE: IfElseNode, NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode, NodeType.ASSIGNER: AssignerNode, diff --git a/api/app/core/workflow/nodes/operators.py b/api/app/core/workflow/nodes/operators.py index ad38284a..251d6a79 100644 --- a/api/app/core/workflow/nodes/operators.py +++ b/api/app/core/workflow/nodes/operators.py @@ -1,9 +1,9 @@ import json import re from abc import ABC -from typing import Union, Type, NoReturn +from typing import Union, Type, NoReturn, Any -from app.core.workflow.nodes.base_config import VariableType +from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.nodes.enums import ValueInputType from app.core.workflow.variable_pool import VariablePool @@ -69,7 +69,7 @@ class TypeTransformer: class OperatorBase(ABC): - def __init__(self, pool: VariablePool, left_selector, right): + def __init__(self, pool: VariablePool, left_selector: str, right: Any): self.pool = pool self.left_selector = left_selector self.right = right @@ -77,7 +77,7 @@ class OperatorBase(ABC): self.type_limit: type[str, int, dict, list] = None def check(self, no_right=False): - left = self.pool.get(self.left_selector) + left = self.pool.get_value(self.left_selector) if not isinstance(left, self.type_limit): raise TypeError(f"The variable to be operated on must be of {self.type_limit} type") @@ -92,13 +92,13 @@ class StringOperator(OperatorBase): super().__init__(pool, left_selector, right) self.type_limit = str - def assign(self) -> None: + async def assign(self) -> None: self.check() - self.pool.set(self.left_selector, self.right) + await self.pool.set(self.left_selector, self.right) - def clear(self) -> None: + async def clear(self) -> None: self.check(no_right=True) - self.pool.set(self.left_selector, '') + await self.pool.set(self.left_selector, '') class NumberOperator(OperatorBase): @@ -106,33 +106,33 @@ class NumberOperator(OperatorBase): super().__init__(pool, left_selector, right) self.type_limit = (float, int) - def assign(self) -> None: + async def assign(self) -> None: self.check() - self.pool.set(self.left_selector, self.right) + await self.pool.set(self.left_selector, self.right) - def clear(self) -> None: + async def clear(self) -> None: self.check(no_right=True) - self.pool.set(self.left_selector, 0) + await self.pool.set(self.left_selector, 0) - def add(self) -> None: + async def add(self) -> None: self.check() - origin = self.pool.get(self.left_selector) - self.pool.set(self.left_selector, origin + self.right) + origin = self.pool.get_value(self.left_selector) + await self.pool.set(self.left_selector, origin + self.right) - def subtract(self) -> None: + async def subtract(self) -> None: self.check() - origin = self.pool.get(self.left_selector) - self.pool.set(self.left_selector, origin - self.right) + origin = self.pool.get_value(self.left_selector) + await self.pool.set(self.left_selector, origin - self.right) - def multiply(self) -> None: + async def multiply(self) -> None: self.check() - origin = self.pool.get(self.left_selector) - self.pool.set(self.left_selector, origin * self.right) + origin = self.pool.get_value(self.left_selector) + await self.pool.set(self.left_selector, origin * self.right) - def divide(self) -> None: + async def divide(self) -> None: self.check() - origin = self.pool.get(self.left_selector) - self.pool.set(self.left_selector, origin / self.right) + origin = self.pool.get_value(self.left_selector) + await self.pool.set(self.left_selector, origin / self.right) class BooleanOperator(OperatorBase): @@ -140,13 +140,13 @@ class BooleanOperator(OperatorBase): super().__init__(pool, left_selector, right) self.type_limit = bool - def assign(self) -> None: + async def assign(self) -> None: self.check() - self.pool.set(self.left_selector, self.right) + await self.pool.set(self.left_selector, self.right) - def clear(self) -> None: + async def clear(self) -> None: self.check(no_right=True) - self.pool.set(self.left_selector, False) + await self.pool.set(self.left_selector, False) class ArrayOperator(OperatorBase): @@ -154,38 +154,37 @@ class ArrayOperator(OperatorBase): super().__init__(pool, left_selector, right) self.type_limit = list - def assign(self) -> None: + async def assign(self) -> None: self.check() - self.pool.set(self.left_selector, self.right) + await self.pool.set(self.left_selector, self.right) - def clear(self) -> None: + async def clear(self) -> None: self.check(no_right=True) - self.pool.set(self.left_selector, list()) + await self.pool.set(self.left_selector, list()) - def append(self) -> None: + async def append(self) -> None: self.check(no_right=True) - # TODO:require type limit in list - origin = self.pool.get(self.left_selector) + origin = self.pool.get_value(self.left_selector) origin.append(self.right) - self.pool.set(self.left_selector, origin) + await self.pool.set(self.left_selector, origin) - def extend(self) -> None: + async def extend(self) -> None: self.check(no_right=True) - origin = self.pool.get(self.left_selector) + origin = self.pool.get_value(self.left_selector) origin.extend(self.right) - self.pool.set(self.left_selector, origin) + await self.pool.set(self.left_selector, origin) - def remove_last(self) -> None: + async def remove_last(self) -> None: self.check(no_right=True) - origin = self.pool.get(self.left_selector) + origin = self.pool.get_value(self.left_selector) origin.pop() - self.pool.set(self.left_selector, origin) + await self.pool.set(self.left_selector, origin) - def remove_first(self) -> None: + async def remove_first(self) -> None: self.check(no_right=True) - origin = self.pool.get(self.left_selector) + origin = self.pool.get_value(self.left_selector) origin.pop(0) - self.pool.set(self.left_selector, origin) + await self.pool.set(self.left_selector, origin) class ObjectOperator(OperatorBase): @@ -193,13 +192,13 @@ class ObjectOperator(OperatorBase): super().__init__(pool, left_selector, right) self.type_limit = dict - def assign(self) -> None: + async def assign(self) -> None: self.check() - self.pool.set(self.left_selector, self.right) + await self.pool.set(self.left_selector, self.right) - def clear(self) -> None: + async def clear(self) -> None: self.check(no_right=True) - self.pool.set(self.left_selector, dict()) + await self.pool.set(self.left_selector, dict()) class AssignmentOperatorResolver: @@ -245,7 +244,7 @@ class ConditionBase(ABC): self.right_selector = right_selector self.input_type = input_type - self.left_value = self.pool.get(self.left_selector) + self.left_value = self.pool.get_value(self.left_selector) self.right_value = self.resolve_right_literal_value() self.type_limit = getattr(self, "type_limit", None) @@ -254,7 +253,7 @@ class ConditionBase(ABC): if self.input_type == ValueInputType.VARIABLE: pattern = r"\{\{\s*(.*?)\s*\}\}" right_expression = re.sub(pattern, r"\1", self.right_selector).strip() - return self.pool.get(right_expression) + return self.pool.get_value(right_expression) elif self.input_type == ValueInputType.CONSTANT: return self.right_selector raise RuntimeError("Unsupported variable type") diff --git a/api/app/core/workflow/nodes/parameter_extractor/config.py b/api/app/core/workflow/nodes/parameter_extractor/config.py index cfbd9c14..a0b9c032 100644 --- a/api/app/core/workflow/nodes/parameter_extractor/config.py +++ b/api/app/core/workflow/nodes/parameter_extractor/config.py @@ -1,7 +1,7 @@ import uuid +from enum import StrEnum from pydantic import Field, BaseModel -from enum import StrEnum from app.core.workflow.nodes.base_config import BaseNodeConfig diff --git a/api/app/core/workflow/nodes/parameter_extractor/node.py b/api/app/core/workflow/nodes/parameter_extractor/node.py index ec58d96c..475c54fe 100644 --- a/api/app/core/workflow/nodes/parameter_extractor/node.py +++ b/api/app/core/workflow/nodes/parameter_extractor/node.py @@ -12,6 +12,8 @@ from app.core.models import RedBearLLM, RedBearModelConfig from app.core.workflow.nodes import WorkflowState from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig +from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.variable_pool import VariablePool from app.db import get_db_read from app.models import ModelType from app.services.model_service import ModelConfigService @@ -24,6 +26,12 @@ class ParameterExtractorNode(BaseNode): super().__init__(node_config, workflow_config) self.typed_config: ParameterExtractorNodeConfig | None = None + def _output_types(self) -> dict[str, VariableType]: + outputs = {} + for param in self.typed_config.params: + outputs[param.name] = param.type + return outputs + @staticmethod def _get_prompt(): """ @@ -120,7 +128,7 @@ class ParameterExtractorNode(BaseNode): field_type[param.name] = f'{param.type}, required:{str(param.required)}' return field_type - async def execute(self, state: WorkflowState) -> Any: + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: """ Main execution function for this node. @@ -138,6 +146,7 @@ class ParameterExtractorNode(BaseNode): Args: state (WorkflowState): Current state of the workflow, used for template rendering. + variable_pool (VariablePool): Used for accessing and setting variables during execution. Returns: dict[str, Any]: Dictionary containing extracted parameters under the "output" key. @@ -153,7 +162,7 @@ class ParameterExtractorNode(BaseNode): rendered_user_prompt = user_prompt_teplate.render( field_descriptions=str(self._get_field_desc()), field_type=str(self._get_field_type()), - text_input=self._render_template(self.typed_config.text, state) + text_input=self._render_template(self.typed_config.text, variable_pool) ) messages = [ @@ -162,7 +171,7 @@ class ParameterExtractorNode(BaseNode): ] if self.typed_config.prompt: messages.extend([ - ("user", self._render_template(self.typed_config.prompt, state)), + ("user", self._render_template(self.typed_config.prompt, variable_pool)), ("user", rendered_user_prompt), ]) else: diff --git a/api/app/core/workflow/nodes/question_classifier/node.py b/api/app/core/workflow/nodes/question_classifier/node.py index 6df410cb..d7496f12 100644 --- a/api/app/core/workflow/nodes/question_classifier/node.py +++ b/api/app/core/workflow/nodes/question_classifier/node.py @@ -6,6 +6,8 @@ from app.core.workflow.nodes.question_classifier.config import QuestionClassifie from app.core.models import RedBearLLM, RedBearModelConfig from app.core.exceptions import BusinessException from app.core.error_codes import BizCode +from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.variable_pool import VariablePool from app.db import get_db_read from app.models import ModelType from app.services.model_service import ModelConfigService @@ -24,6 +26,12 @@ class QuestionClassifierNode(BaseNode): self.typed_config: QuestionClassifierNodeConfig | None = None self.category_to_case_map = {} + def _output_types(self) -> dict[str, VariableType]: + return { + "class_name": VariableType.STRING, + "output": VariableType.STRING + } + def _get_llm_instance(self) -> RedBearLLM: """获取LLM实例""" with get_db_read() as db: @@ -65,7 +73,7 @@ class QuestionClassifierNode(BaseNode): category_map[category_name] = case_tag return category_map - async def execute(self, state: WorkflowState) -> dict: + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict: """执行问题分类""" self.typed_config = QuestionClassifierNodeConfig(**self.config) self.category_to_case_map = self._build_category_case_map() @@ -102,7 +110,7 @@ class QuestionClassifierNode(BaseNode): categories=", ".join(category_names), supplement_prompt=supplement_prompt ), - state + variable_pool ) messages = [ diff --git a/api/app/core/workflow/nodes/start/config.py b/api/app/core/workflow/nodes/start/config.py index 1544f89f..98390bf7 100644 --- a/api/app/core/workflow/nodes/start/config.py +++ b/api/app/core/workflow/nodes/start/config.py @@ -2,7 +2,8 @@ from pydantic import Field -from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType +from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition +from app.core.workflow.variable.base_variable import VariableType class StartNodeConfig(BaseNodeConfig): diff --git a/api/app/core/workflow/nodes/start/node.py b/api/app/core/workflow/nodes/start/node.py index 69560422..db66bc65 100644 --- a/api/app/core/workflow/nodes/start/node.py +++ b/api/app/core/workflow/nodes/start/node.py @@ -7,9 +7,10 @@ Start 节点实现 import logging from typing import Any -from app.core.workflow.nodes.base_config import VariableType +from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.start.config import StartNodeConfig +from app.core.workflow.variable_pool import VariablePool logger = logging.getLogger(__name__) @@ -36,14 +37,25 @@ class StartNode(BaseNode): # 解析并验证配置 self.typed_config: StartNodeConfig | None = None + self.output_var_types = {} - async def execute(self, state: WorkflowState) -> dict[str, Any]: + def _output_types(self) -> dict[str, VariableType]: + return self.output_var_types | { + "message": VariableType.STRING, + "execution_id": VariableType.STRING, + "conversation_id": VariableType.STRING, + "workspace_id": VariableType.STRING, + "user_id": VariableType.STRING, + } + + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: """执行 start 节点业务逻辑 Start 节点输出系统变量、会话变量和自定义变量。 Args: state: 工作流状态 + variable_pool: 变量池 Returns: 包含系统参数、会话变量和自定义变量的字典 @@ -51,19 +63,16 @@ class StartNode(BaseNode): self.typed_config = StartNodeConfig(**self.config) logger.info(f"节点 {self.node_id} (Start) 开始执行") - # 创建变量池实例(在方法内复用) - pool = self.get_variable_pool(state) - # 处理自定义变量(传入 pool 避免重复创建) - custom_vars = self._process_custom_variables(pool) + custom_vars = self._process_custom_variables(variable_pool) # 返回业务数据(包含自定义变量) result = { - "message": pool.get("sys.message"), - "execution_id": pool.get("sys.execution_id"), - "conversation_id": pool.get("sys.conversation_id"), - "workspace_id": pool.get("sys.workspace_id"), - "user_id": pool.get("sys.user_id"), + "message": variable_pool.get_value("sys.message"), + "execution_id": variable_pool.get_value("sys.execution_id"), + "conversation_id": variable_pool.get_value("sys.conversation_id"), + "workspace_id": variable_pool.get_value("sys.workspace_id"), + "user_id": variable_pool.get_value("sys.user_id"), **custom_vars # 自定义变量作为节点输出的一部分 } @@ -74,7 +83,7 @@ class StartNode(BaseNode): return result - def _process_custom_variables(self, pool) -> dict[str, Any]: + def _process_custom_variables(self, pool: VariablePool) -> dict[str, Any]: """处理自定义变量 从输入数据中提取自定义变量,应用默认值和验证。 @@ -89,13 +98,14 @@ class StartNode(BaseNode): ValueError: 缺少必需变量 """ # 获取输入数据中的自定义变量 - input_variables = pool.get("sys.input_variables", default={}) + input_variables = pool.get_value("sys.input_variables", default={}, strict=False) processed = {} # 遍历配置的变量定义 for var_def in self.typed_config.variables: var_name = var_def.name + var_type = var_def.type # 检查变量是否存在 if var_name in input_variables: @@ -116,21 +126,12 @@ class StartNode(BaseNode): f"变量 '{var_name}' 使用默认值: {var_def.default}" ) else: - match var_def.type: - case VariableType.STRING: - processed[var_name] = "" - case VariableType.NUMBER: - processed[var_name] = 0 - case VariableType.OBJECT: - processed[var_name] = {} - case VariableType.BOOLEAN: - processed[var_name] = False - case VariableType.ARRAY_NUMBER | VariableType.ARRAY_OBJECT | VariableType.ARRAY_BOOLEAN | VariableType.ARRAY_STRING: - processed[var_name] = [] + processed[var_name] = DEFAULT_VALUE(var_type) + self.output_var_types[var_name] = var_type return processed - def _extract_input(self, state: WorkflowState) -> dict[str, Any]: + def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: """提取输入数据(用于记录) Args: @@ -139,11 +140,9 @@ class StartNode(BaseNode): Returns: 输入数据字典 """ - pool = self.get_variable_pool(state) - return { - "execution_id": pool.get("sys.execution_id"), - "conversation_id": pool.get("sys.conversation_id"), - "message": pool.get("sys.message"), - "conversation_vars": pool.get_all_conversation_vars() + "execution_id": variable_pool.get_value("sys.execution_id"), + "conversation_id": variable_pool.get_value("sys.conversation_id"), + "message": variable_pool.get_value("sys.message"), + "conversation_vars": variable_pool.get_all_conversation_vars() } diff --git a/api/app/core/workflow/nodes/tool/node.py b/api/app/core/workflow/nodes/tool/node.py index aba96303..adc55d87 100644 --- a/api/app/core/workflow/nodes/tool/node.py +++ b/api/app/core/workflow/nodes/tool/node.py @@ -6,6 +6,8 @@ from typing import Any from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.tool.config import ToolNodeConfig +from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.variable_pool import VariablePool from app.services.tool_service import ToolService from app.db import get_db_read @@ -21,13 +23,20 @@ class ToolNode(BaseNode): super().__init__(node_config, workflow_config) self.typed_config: ToolNodeConfig | None = None - async def execute(self, state: WorkflowState) -> dict[str, Any]: + def _output_types(self) -> dict[str, VariableType]: + return { + "data": VariableType.STRING, + "error_code": VariableType.STRING, + "execution_time": VariableType.NUMBER + } + + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: """执行工具""" self.typed_config = ToolNodeConfig(**self.config) # 获取租户ID和用户ID - tenant_id = self.get_variable("sys.tenant_id", state) - user_id = self.get_variable("sys.user_id", state) - workspace_id = self.get_variable("sys.workspace_id", state) + tenant_id = self.get_variable("sys.tenant_id", variable_pool, strict=False) + user_id = self.get_variable("sys.user_id", variable_pool) + workspace_id = self.get_variable("sys.workspace_id", variable_pool) # 如果没有租户ID,尝试从工作流ID获取 if not tenant_id: @@ -48,7 +57,7 @@ class ToolNode(BaseNode): for param_name, param_template in self.typed_config.tool_parameters.items(): if isinstance(param_template, str) and TEMPLATE_PATTERN.search(param_template): try: - rendered_value = self._render_template(param_template, state) + rendered_value = self._render_template(param_template, variable_pool) except Exception as e: raise ValueError(f"模板渲染失败:参数 {param_name} 的模板 {param_template} 解析错误") from e else: diff --git a/api/app/core/workflow/nodes/transform/__init__.py b/api/app/core/workflow/nodes/transform/__init__.py deleted file mode 100644 index 384b818c..00000000 --- a/api/app/core/workflow/nodes/transform/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Transform 节点""" - -from app.core.workflow.nodes.transform.node import TransformNode -from app.core.workflow.nodes.transform.config import TransformNodeConfig - -__all__ = ["TransformNode", "TransformNodeConfig"] diff --git a/api/app/core/workflow/nodes/transform/config.py b/api/app/core/workflow/nodes/transform/config.py deleted file mode 100644 index 47d2a6ac..00000000 --- a/api/app/core/workflow/nodes/transform/config.py +++ /dev/null @@ -1,80 +0,0 @@ -"""Transform 节点配置""" - -from typing import Literal - -from pydantic import Field - -from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType - - -class TransformNodeConfig(BaseNodeConfig): - """Transform 节点配置 - - 用于数据转换和处理。 - """ - - transform_type: Literal["template", "code", "json"] = Field( - default="template", - description="转换类型:template(模板), code(代码), json(JSON处理)" - ) - - # 模板模式 - template: str | None = Field( - default=None, - description="转换模板,支持变量引用" - ) - - # 代码模式 - code: str | None = Field( - default=None, - description="Python 代码,用于数据转换" - ) - - # JSON 模式 - json_path: str | None = Field( - default=None, - description="JSON 路径表达式" - ) - - # 输入变量 - inputs: dict[str, str] | None = Field( - default=None, - description="输入变量映射,key 为变量名,value 为变量选择器" - ) - - # 输出变量 - output_key: str = Field( - default="result", - description="输出变量的键名" - ) - - # 输出变量定义 - output_variables: list[VariableDefinition] = Field( - default_factory=lambda: [ - VariableDefinition( - name="result", - type=VariableType.STRING, - description="转换后的结果" - ) - ], - description="输出变量定义(根据 output_key 动态生成)" - ) - - class Config: - json_schema_extra = { - "examples": [ - { - "transform_type": "template", - "template": "用户问题:{{ sys.message }}\n回答:{{ llm_qa.output }}", - "output_key": "formatted_result" - }, - { - "transform_type": "code", - "code": "result = input_text.upper()", - "inputs": { - "input_text": "{{ sys.message }}" - }, - "output_key": "uppercase_text" - } - ] - } diff --git a/api/app/core/workflow/nodes/transform/node.py b/api/app/core/workflow/nodes/transform/node.py deleted file mode 100644 index 4211c510..00000000 --- a/api/app/core/workflow/nodes/transform/node.py +++ /dev/null @@ -1,60 +0,0 @@ -""" -Transform 节点实现 - -数据转换节点,用于处理和转换数据。 -""" - -import logging -from typing import Any - -from app.core.workflow.nodes.base_node import BaseNode, WorkflowState - -logger = logging.getLogger(__name__) - - -class TransformNode(BaseNode): - """数据转换节点 - - 配置示例: - { - "type": "transform", - "config": { - "mapping": { - "output_field": "{{node.previous.output}}", - "processed": "{{var.input | upper}}" - } - } - } - """ - - async def execute(self, state: WorkflowState) -> dict[str, Any]: - """执行数据转换 - - Args: - state: 工作流状态 - - Returns: - 状态更新字典 - """ - logger.info(f"节点 {self.node_id} 开始执行数据转换") - - # 获取映射配置 - mapping = self.config.get("mapping", {}) - - # 执行数据转换 - transformed_data = {} - for target_key, source_template in mapping.items(): - # 渲染模板获取值 - value = self._render_template(str(source_template), state) - transformed_data[target_key] = value - - logger.info(f"节点 {self.node_id} 数据转换完成,输出字段: {list(transformed_data.keys())}") - - return { - "node_outputs": { - self.node_id: { - "output": transformed_data, - "status": "completed" - } - } - } diff --git a/api/app/core/workflow/nodes/variable_aggregator/config.py b/api/app/core/workflow/nodes/variable_aggregator/config.py index ac1419a4..dbcc08e2 100644 --- a/api/app/core/workflow/nodes/variable_aggregator/config.py +++ b/api/app/core/workflow/nodes/variable_aggregator/config.py @@ -1,6 +1,7 @@ from pydantic import Field, field_validator from app.core.workflow.nodes.base_config import BaseNodeConfig +from app.core.workflow.variable.base_variable import VariableType class VariableAggregatorNodeConfig(BaseNodeConfig): @@ -14,6 +15,11 @@ class VariableAggregatorNodeConfig(BaseNodeConfig): description="需要被聚合的变量" ) + group_type: dict[str, VariableType] = Field( + ..., + description="每个分组的变量类型" + ) + @field_validator("group_variables") @classmethod def group_variables_validator(cls, v, info): diff --git a/api/app/core/workflow/nodes/variable_aggregator/node.py b/api/app/core/workflow/nodes/variable_aggregator/node.py index 5bff8e33..48ee9f85 100644 --- a/api/app/core/workflow/nodes/variable_aggregator/node.py +++ b/api/app/core/workflow/nodes/variable_aggregator/node.py @@ -5,6 +5,8 @@ from typing import Any from app.core.workflow.nodes import WorkflowState from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig +from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE +from app.core.workflow.variable_pool import VariablePool logger = logging.getLogger(__name__) @@ -14,6 +16,13 @@ class VariableAggregatorNode(BaseNode): super().__init__(node_config, workflow_config) self.typed_config: VariableAggregatorNodeConfig | None = None + def _output_types(self) -> dict[str, VariableType]: + config = VariableAggregatorNodeConfig(**self.config) + output = {} + for var_type in config.group_type: + output[var_type] = config.group_type[var_type] + return output + @staticmethod def _get_express(variable_string: str) -> Any: """ @@ -29,7 +38,7 @@ class VariableAggregatorNode(BaseNode): expression = re.sub(pattern, r"\1", variable_string).strip() return expression - async def execute(self, state: WorkflowState) -> Any: + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: """ Execute the variable aggregation logic. @@ -45,7 +54,7 @@ class VariableAggregatorNode(BaseNode): for variable in self.typed_config.group_variables: var_express = self._get_express(variable) try: - value = self.get_variable(var_express, state) + value = self.get_variable(var_express, variable_pool) except Exception as e: logger.warning(f"Failed to get variable '{var_express}': {e}") continue @@ -55,7 +64,7 @@ class VariableAggregatorNode(BaseNode): return value logger.info("No variable found in non-group mode; returning empty string.") - return "" + return DEFAULT_VALUE(self.typed_config.group_type["output"]) # -------------------------- # Group mode @@ -65,7 +74,7 @@ class VariableAggregatorNode(BaseNode): for variable in variables: var_express = self._get_express(variable) try: - value = self.get_variable(var_express, state) + value = self.get_variable(var_express, variable_pool) except Exception as e: logger.warning(f"Failed to get variable '{var_express}' in group '{group_name}': {e}") continue @@ -74,7 +83,7 @@ class VariableAggregatorNode(BaseNode): result[group_name] = value break else: - result[group_name] = "" + result[group_name] = DEFAULT_VALUE(self.typed_config.group_type[group_name]) logger.info(f"No variable found for group '{group_name}'; set empty string.") logger.info(f"Node: {self.node_id} variable aggregation result: {result}") return result diff --git a/api/app/core/workflow/template_renderer.py b/api/app/core/workflow/template_renderer.py index c2d7f255..9e2a28e8 100644 --- a/api/app/core/workflow/template_renderer.py +++ b/api/app/core/workflow/template_renderer.py @@ -43,7 +43,7 @@ class TemplateRenderer: def render( self, template: str, - variables: dict[str, Any], + conv_vars: dict[str, Any], node_outputs: dict[str, Any], system_vars: dict[str, Any] | None = None ) -> str: @@ -51,7 +51,7 @@ class TemplateRenderer: Args: template: 模板字符串 - variables: 用户定义的变量 + conv_vars: 会话变量 node_outputs: 节点输出结果 system_vars: 系统变量 @@ -80,20 +80,11 @@ class TemplateRenderer: '分析结果: 正面情绪' """ # 构建命名空间上下文 - # variables 的结构:{"sys": {...}, "conv": {...}} - sys_vars = variables.get("sys", {}) if isinstance(variables, dict) else {} - conv_vars = variables.get("conv", {}) if isinstance(variables, dict) else {} - if self.strict: - context = defaultdict(dict) - context["conv"] = conv_vars - context["node"] = node_outputs - context["sys"] = {**(system_vars or {}), **sys_vars} - else: - context = { - "conv": conv_vars, # 会话变量:{{conv.user_name}} - "node": node_outputs, # 节点输出:{{node.node_1.output}} - "sys": {**(system_vars or {}), **sys_vars}, # 系统变量:{{sys.execution_id}}(合并两个来源) - } + context = { + "conv": conv_vars, # 会话变量:{{conv.user_name}} + "node": node_outputs, # 节点输出:{{node.node_1.output}} + "sys": system_vars, # 系统变量:{{sys.execution_id}} + } # 支持直接通过节点ID访问节点输出:{{llm_qa.output}} # 将所有节点输出添加到顶层上下文 @@ -157,9 +148,9 @@ _default_renderer = TemplateRenderer(strict=True) def render_template( template: str, - variables: dict[str, Any], + conv_vars: dict[str, Any], node_outputs: dict[str, Any], - system_vars: dict[str, Any] | None = None, + system_vars: dict[str, Any], strict: bool = True ) -> str: """渲染模板(便捷函数) @@ -167,7 +158,7 @@ def render_template( Args: strict: 严格模式 template: 模板字符串 - variables: 用户变量 + conv_vars: 会话变量 node_outputs: 节点输出 system_vars: 系统变量 @@ -184,7 +175,7 @@ def render_template( '请分析: 这是一段文本' """ renderer = TemplateRenderer(strict=strict) - return renderer.render(template, variables, node_outputs, system_vars) + return renderer.render(template, conv_vars, node_outputs, system_vars) def validate_template(template: str) -> list[str]: diff --git a/api/app/core/workflow/validator.py b/api/app/core/workflow/validator.py index 6daf415d..96fc35ad 100644 --- a/api/app/core/workflow/validator.py +++ b/api/app/core/workflow/validator.py @@ -5,10 +5,13 @@ """ import logging -from typing import Any, Union +from typing import Any, Union, TYPE_CHECKING from app.core.workflow.nodes.enums import NodeType +if TYPE_CHECKING: + from app.schemas.workflow_schema import WorkflowConfig + logger = logging.getLogger(__name__) @@ -64,7 +67,7 @@ class WorkflowValidator: return cycle_nodes, cycle_edges @classmethod - def get_subgraph(cls, workflow_config: Union[dict[str, Any], Any]) -> list: + def get_subgraph(cls, workflow_config: Union[dict[str, Any], "WorkflowConfig"]) -> list: if not isinstance(workflow_config, dict): workflow_config = { "nodes": workflow_config.nodes, @@ -331,7 +334,7 @@ class WorkflowValidator: def validate_workflow_config( - workflow_config: dict[str, Any], + workflow_config: Union[dict[str, Any], 'WorkflowConfig'], for_publish: bool = False ) -> tuple[bool, list[str]]: """验证工作流配置(便捷函数) diff --git a/api/app/core/workflow/variable/__init__.py b/api/app/core/workflow/variable/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/api/app/core/workflow/variable/base_variable.py b/api/app/core/workflow/variable/base_variable.py new file mode 100644 index 00000000..221c7917 --- /dev/null +++ b/api/app/core/workflow/variable/base_variable.py @@ -0,0 +1,162 @@ +from enum import StrEnum +from abc import abstractmethod, ABC +from typing import Any + + +class VariableType(StrEnum): + """Enumeration of supported variable types in the workflow.""" + + STRING = "string" + NUMBER = "number" + BOOLEAN = "boolean" + OBJECT = "object" + FILE = "file" + + ARRAY_STRING = "array[string]" + ARRAY_NUMBER = "array[number]" + ARRAY_BOOLEAN = "array[boolean]" + ARRAY_OBJECT = "array[object]" + ARRAY_FILE = "array[file]" + + NESTED_ARRAY = "array_nest" + + @classmethod + def type_map(cls, var: Any) -> "VariableType": + """Maps a Python value to a corresponding VariableType. + + Args: + var: The Python value to map. + + Returns: + The VariableType corresponding to the input value. + + Raises: + TypeError: If the type of the input value is not supported. + """ + var_type = type(var) + if isinstance(var_type, str): + return cls.STRING + elif isinstance(var_type, (int, float)): + return cls.NUMBER + elif isinstance(var_type, bool): + return cls.BOOLEAN + elif isinstance(var_type, FileObj): + return cls.FILE + elif isinstance(var_type, dict): + return cls.OBJECT + elif isinstance(var_type, list): + if len(var) == 0: + return cls.ARRAY_STRING + else: + child_type = type(var[0]) + if child_type == str: + return cls.ARRAY_STRING + elif child_type == int or child_type == float: + return cls.ARRAY_NUMBER + elif child_type == bool: + return cls.ARRAY_BOOLEAN + elif child_type == dict: + return cls.ARRAY_OBJECT + elif child_type == list: + return cls.NESTED_ARRAY + else: + raise TypeError(f"Unsupported array child type - {child_type}") + raise TypeError(f"Unsupported type - {var_type}") + + +def DEFAULT_VALUE(var_type: VariableType) -> Any: + """Returns the default value for a given VariableType. + + Args: + var_type: The variable type for which to get the default value. + + Returns: + The default Python value corresponding to the VariableType. + + Raises: + TypeError: If the VariableType is invalid. + """ + match var_type: + case VariableType.STRING: + return "" + case VariableType.NUMBER: + return 0 + case VariableType.BOOLEAN: + return False + case VariableType.OBJECT: + return {} + case VariableType.FILE: + return None + case VariableType.ARRAY_STRING: + return [] + case VariableType.ARRAY_NUMBER: + return [] + case VariableType.ARRAY_BOOLEAN: + return [] + case VariableType.ARRAY_OBJECT: + return [] + case VariableType.ARRAY_FILE: + return [] + case _: + raise TypeError(f"Invalid type - {type}") + + +class FileObj: + pass + + +class BaseVariable(ABC): + """Abstract base class for all workflow variables. + + Subclasses must implement validation and serialization methods. + """ + type = None + + def __init__(self, value: Any): + """Initializes a variable instance. + + Args: + value: The initial value for the variable. + + Attributes: + self.value: The validated value stored in the variable. + self.literal: A string representation of the variable. + """ + self.value = self.valid_value(value) + self.literal = self.to_literal() + + @abstractmethod + def valid_value(self, value) -> Any: + """Validates or converts a value to the correct type for the variable. + + Args: + value: The value to validate. + + Returns: + The validated or converted value. + + Raises: + TypeError: If the value is invalid. + """ + pass + + @abstractmethod + def to_literal(self) -> str: + """Converts the variable value to a string literal representation. + + Returns: + A string representing the variable's value. + """ + pass + + def get_value(self) -> Any: + """Returns the current value of the variable.""" + return self.value + + def set(self, value): + """Sets the variable to a new value after validation. + + Args: + value: The new value to assign to the variable. + """ + self.value = self.valid_value(value) diff --git a/api/app/core/workflow/variable/variable_objects.py b/api/app/core/workflow/variable/variable_objects.py new file mode 100644 index 00000000..3af7d1bb --- /dev/null +++ b/api/app/core/workflow/variable/variable_objects.py @@ -0,0 +1,137 @@ +from typing import Any, TypeVar, Type, Generic + +from app.core.workflow.variable.base_variable import BaseVariable, VariableType + +T = TypeVar("T", bound=BaseVariable) + + +class StringObject(BaseVariable): + type = 'str' + + def valid_value(self, value) -> str: + if not isinstance(value, str): + raise TypeError("Value must be a string") + return value + + def to_literal(self) -> str: + return self.value + + +class NumberObject(BaseVariable): + type = 'number' + + def valid_value(self, value) -> int | float: + if not isinstance(value, (int, float)): + raise TypeError("Value must be a number") + return value + + def to_literal(self) -> str: + return str(self.value) + + +class BooleanObject(BaseVariable): + type = 'boolean' + + def valid_value(self, value) -> bool: + if not isinstance(value, bool): + raise TypeError("Value must be a boolean") + return value + + def to_literal(self) -> str: + return str(self.value).lower() + + +class DictObject(BaseVariable): + type = 'object' + + def valid_value(self, value) -> dict: + if not isinstance(value, dict): + raise TypeError("Value must be a dict") + return value + + def to_literal(self) -> str: + return str(self.value) + + +class FileObject(BaseVariable): + type = 'file' + + def valid_value(self, value) -> Any: + pass + + def to_literal(self) -> str: + pass + + +class ArrayObject(BaseVariable, Generic[T]): + type = 'array' + + def __init__(self, child_type: Type[T], value: list[Any]): + if not issubclass(child_type, BaseVariable): + raise TypeError("child_type must be a subclass of BaseVariable") + self.child_type = child_type + super().__init__(value) + + def valid_value(self, value: list[Any]) -> list[T]: + if not isinstance(value, list): + raise TypeError("Value must be a list") + final_value = [] + for v in value: + try: + final_value.append(self.child_type(v)) + except: + raise TypeError(f"All elements must be of type {self.child_type.type}") + return final_value + + def to_literal(self) -> str: + return "\n".join([v.to_literal() for v in self.value]) + + +class NestedArrayObject(BaseVariable): + type = 'array_nest' + + def valid_value(self, value: list[T]) -> list[T]: + if not isinstance(value, list): + raise TypeError("Value must be a list") + final_value = [] + for v in value: + if not isinstance(v, ArrayObject): + raise TypeError("All elements must be of type list") + final_value.append(v) + return final_value + + def to_literal(self) -> str: + return "\n".join(["\n".join([item.to_literal() for item in row]) for row in self.value]) + + def get_value(self) -> Any: + return [[item.get_value() for item in row] for row in self.value] + + +def make_array(child_type: Type[T], value: list[Any]) -> ArrayObject[T]: + """简化 ArrayObject 创建,不需要重复写类型""" + + return ArrayObject(child_type, value) + + +def create_variable_instance(var_type: VariableType, value: Any) -> T: + match var_type: + case VariableType.STRING: + return StringObject(value) + case VariableType.NUMBER: + return NumberObject(value) + case VariableType.BOOLEAN: + return BooleanObject(value) + case VariableType.OBJECT: + return DictObject(value) + case VariableType.ARRAY_STRING: + return make_array(StringObject, value) + case VariableType.ARRAY_NUMBER: + return make_array(NumberObject, value) + case VariableType.ARRAY_BOOLEAN: + return make_array(BooleanObject, value) + case VariableType.ARRAY_OBJECT: + return make_array(DictObject, value) + case VariableType.ARRAY_FILE: + return make_array(FileObject, value) + case _: + raise TypeError(f"Invalid type - {var_type}") diff --git a/api/app/core/workflow/variable_pool.py b/api/app/core/workflow/variable_pool.py index 7d4b0609..32bfc5e1 100644 --- a/api/app/core/workflow/variable_pool.py +++ b/api/app/core/workflow/variable_pool.py @@ -11,10 +11,15 @@ import logging import re -from typing import Any, TYPE_CHECKING +from asyncio import Lock +from collections import defaultdict +from copy import deepcopy +from typing import Any, Generic -if TYPE_CHECKING: - from app.core.workflow.nodes import WorkflowState +from pydantic import BaseModel + +from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.variable.variable_objects import T, create_variable_instance logger = logging.getLogger(__name__) @@ -23,11 +28,6 @@ class VariableSelector: """变量选择器 用于引用变量的路径表示。 - - Examples: - >>> selector = VariableSelector(["sys", "message"]) - >>> selector = VariableSelector(["node_A", "output"]) - >>> selector = VariableSelector.from_string("sys.message") """ def __init__(self, path: list[str]): @@ -52,10 +52,6 @@ class VariableSelector: Returns: VariableSelector 实例 - - Examples: - >>> selector = VariableSelector.from_string("sys.message") - >>> selector = VariableSelector.from_string("llm_qa.output") """ path = selector_str.split(".") return cls(path) @@ -67,160 +63,212 @@ class VariableSelector: return f"VariableSelector({self.path})" +class VariableStruct(BaseModel, Generic[T]): + """A typed variable struct. + + Represents a runtime variable with an associated logical type and + a concrete value object. + + This class bridges the static type system (via generics) and the + runtime type system (via ``VariableType``). + + Attributes: + type: + Logical variable type descriptor used for runtime validation, + serialization, and workflow type checking. + instance: + The concrete variable object. The actual Python type is + represented by the generic parameter ``T`` (e.g. StringObject, + NumberObject, ArrayObject[StringObject]). + mut: + Whether the variable is mutable. + """ + type: VariableType + instance: T + mut: bool + + model_config = { + "arbitrary_types_allowed": True + } + + class VariablePool: - """变量池 - - 管理工作流执行过程中的所有变量。 - - 变量命名空间: - - sys.*: 系统变量(message, execution_id, workspace_id, user_id, conversation_id) - - conv.*: 会话变量(跨多轮对话保持的变量) - - .*: 节点输出 - - Examples: - >>> pool = VariablePool(state) - >>> pool.get(["sys", "message"]) - "用户的问题" - >>> pool.get(["llm_qa", "output"]) - "AI 的回答" - >>> pool.set(["conv", "user_name"], "张三") + """Variable pool. + + Manages all variables during workflow execution, including storage, + namespacing, and concurrency control. + + Variable namespace conventions: + - ``sys.*``: + System variables (e.g. message, execution_id, workspace_id, + user_id, conversation_id). + - ``conv.*``: + Conversation-level variables that persist across multiple turns. + - ``.*``: + Variables produced by workflow nodes. """ - def __init__(self, state: "WorkflowState"): - """初始化变量池 - - Args: - state: 工作流状态(LangGraph State) - """ - self.state = state + def __init__(self): + """Initialize the variable pool. + + Attributes: + self.locks: + A per-key lock table used for fine-grained concurrency control. + + self.variables: + Storage for all variables managed by the pool. + """ + self.locks = defaultdict(Lock) + self.variables: dict[str, dict[str, VariableStruct[Any]]] = {} + + @staticmethod + def transform_selector(selector): + pattern = r"\{\{\s*(.*?)\s*\}\}" + variable_literal = re.sub(pattern, r"\1", selector).strip() + selector = VariableSelector.from_string(variable_literal).path + if len(selector) != 2: + raise ValueError(f"Selector not valid - {selector}") + return selector + + def _get_variable_struct( + self, + selector: str + ) -> VariableStruct[T] | None: + """Retrieve a variable struct from the variable pool. - def get(self, selector: list[str] | str, default: Any = None) -> Any: - """获取变量值 - Args: - selector: 变量选择器,可以是列表或字符串 - default: 默认值(变量不存在时返回) - + selector: + Variable selector, either: + - A string variable literal (e.g. "{{ sys.message }}") + Returns: - 变量值 - - Examples: - >>> pool.get(["sys", "message"]) - >>> pool.get("sys.message") - >>> pool.get(["llm_qa", "output"]) - >>> pool.get("llm_qa.output") - - Raises: - KeyError: 变量不存在且未提供默认值 + The variable's struct if it exists; otherwise returns None. """ - # 转换为 VariableSelector - if isinstance(selector, str): - pattern = r"\{\{\s*(.*?)\s*\}\}" - variable_literal = re.sub(pattern, r"\1", selector).strip() - selector = VariableSelector.from_string(variable_literal).path - - if not selector or len(selector) < 1: - raise ValueError("变量选择器不能为空") + selector = self.transform_selector(selector) namespace = selector[0] + variable_name = selector[1] - try: - # 系统变量 - if namespace == "sys": - key = selector[1] if len(selector) > 1 else None - if not key: - return self.state.get("variables", {}).get("sys", {}) - return self.state.get("variables", {}).get("sys", {}).get(key, default) + namespace_variables = self.variables.get(namespace) + if namespace_variables is None: + return None - # 会话变量 - elif namespace == "conv": - key = selector[1] if len(selector) > 1 else None - if not key: - return self.state.get("variables", {}).get("conv", {}) - return self.state.get("variables", {}).get("conv", {}).get(key, default) + var_instance = namespace_variables.get(variable_name) + if var_instance is None: + return None + return var_instance - # 节点输出(从 runtime_vars 读取) - else: - node_id = namespace - runtime_vars = self.state.get("runtime_vars", {}) + def get_value( + self, + selector: str, + default: Any = None, + strict: bool = True, + ) -> Any: + """Retrieve a variable value from the variable pool. - if node_id not in runtime_vars: - if default is not None: - return default - raise KeyError(f"节点 '{node_id}' 的输出不存在") + Args: + selector: + Variable selector, either: + - A list of path components (e.g. ["sys", "message"]) + - A string variable literal (e.g. "{{ sys.message }}") + default: + The value to return if the variable does not exist. + strict: + If True, raises KeyError when the variable does not exist. - node_var = runtime_vars[node_id] + Returns: + The variable's value if it exists; otherwise returns `default`. - # 如果只有节点 ID,返回整个变量 - if len(selector) == 1: - return node_var + Raises: + KeyError: If strict is True and the variable does not exist. + """ + variable_struct = self._get_variable_struct(selector) + if variable_struct is None: + if strict: + raise KeyError(f"{selector} not exist") + return default - # 获取特定字段 - # 支持嵌套访问,如 node_id.field.subfield - result = node_var - for k in selector[1:]: - if isinstance(result, dict): - result = result.get(k) - if result is None: - if default is not None: - return default - raise KeyError(f"字段 '{'.'.join(selector)}' 不存在") - else: - if default is not None: - return default - raise KeyError(f"无法访问 '{'.'.join(selector)}'") + return variable_struct.instance.get_value() - return result + def get_literal( + self, + selector: str, + default: Any = None, + strict: bool = True, + ) -> Any: + """Retrieve a variable value from the variable pool. - except KeyError: - if default is not None: - return default - raise + Args: + selector: + Variable selector, either: + - A list of path components (e.g. ["sys", "message"]) + - A string variable literal (e.g. "{{ sys.message }}") + default: + The value to return if the variable does not exist. + strict: + If True, raises KeyError when the variable does not exist. - def set(self, selector: list[str] | str, value: Any): + Returns: + The variable's value if it exists; otherwise returns `default`. + + Raises: + KeyError: If strict is True and the variable does not exist. + """ + variable_struct = self._get_variable_struct(selector) + if variable_struct is None: + if strict: + raise KeyError(f"{selector} not exist") + return default + + return variable_struct.instance.to_literal() + + async def set( + self, + selector: str, + value: Any + ): """设置变量值 Args: selector: 变量选择器 value: 变量值 - - Examples: - >>> pool.set(["conv", "user_name"], "张三") - >>> pool.set("conv.user_name", "张三") - + Note: - 只能设置会话变量 (conv.*) - 系统变量和节点输出是只读的 """ - # 转换为 VariableSelector - if isinstance(selector, str): - selector = VariableSelector.from_string(selector).path + variable_struct = self._get_variable_struct(selector) + if variable_struct is None: + raise KeyError(f"Variable {selector} is not defined") + if not variable_struct.mut: + raise KeyError(f"{selector} cannot be modified") + async with self.locks[selector]: + variable_struct.instance.set(value) - if not selector or len(selector) < 2: - raise ValueError("变量选择器必须包含命名空间和键名") + async def new( + self, + namespace: str, + key: str, + value: Any, + var_type: VariableType, + mut: bool + ): + if self.has(f"{namespace}.{key}"): + try: + await self.set(f"{namespace}.{key}", value) + except KeyError: + pass + instance = create_variable_instance(var_type, value) + variable_struct = VariableStruct(type=var_type, instance=instance, mut=mut) + namespace_variable = self.variables.get(namespace) + if namespace_variable is None: + self.variables[namespace] = { + key: variable_struct + } + else: + self.variables[namespace][key] = variable_struct - namespace = selector[0] - - if namespace != "conv" and namespace not in self.state["cycle_nodes"]: - raise ValueError("Only conversation or cycle variables can be assigned.") - - key = selector[1] - - # 确保 variables 结构存在 - if "variables" not in self.state: - self.state["variables"] = {"sys": {}, "conv": {}} - if namespace == "conv": - if "conv" not in self.state["variables"]: - self.state["variables"]["conv"] = {} - - # 设置值 - self.state["variables"]["conv"][key] = value - elif namespace in self.state["cycle_nodes"]: - self.state["runtime_vars"][namespace][key] = value - - logger.debug(f"设置变量: {'.'.join(selector)} = {value}") - - def has(self, selector: list[str] | str) -> bool: + def has(self, selector: str) -> bool: """检查变量是否存在 Args: @@ -228,18 +276,8 @@ class VariablePool: Returns: 变量是否存在 - - Examples: - >>> pool.has(["sys", "message"]) - True - >>> pool.has("llm_qa.output") - False """ - try: - self.get(selector) - return True - except KeyError: - return False + return self._get_variable_struct(selector) is not None def get_all_system_vars(self) -> dict[str, Any]: """获取所有系统变量 @@ -247,7 +285,8 @@ class VariablePool: Returns: 系统变量字典 """ - return self.state.get("variables", {}).get("sys", {}) + sys_namespace = self.variables.get("sys", {}) + return {k: v.instance.value for k, v in sys_namespace.items()} def get_all_conversation_vars(self) -> dict[str, Any]: """获取所有会话变量 @@ -255,7 +294,8 @@ class VariablePool: Returns: 会话变量字典 """ - return self.state.get("variables", {}).get("conv", {}) + conv_namespace = self.variables.get("conv", {}) + return {k: v.instance.value for k, v in conv_namespace.items()} def get_all_node_outputs(self) -> dict[str, Any]: """获取所有节点输出(运行时变量) @@ -263,18 +303,37 @@ class VariablePool: Returns: 节点输出字典,键为节点 ID """ - return self.state.get("runtime_vars", {}) + runtime_vars = { + namespace: { + k: v.instance.value + for k, v in vars_dict.items() + } + for namespace, vars_dict in self.variables.items() + if namespace not in ("sys", "conv") + } + return runtime_vars - def get_node_output(self, node_id: str) -> dict[str, Any] | None: + def get_node_output(self, node_id: str, defalut: Any = None, strict: bool = True) -> dict[str, Any] | None: """获取指定节点的输出(运行时变量) Args: node_id: 节点 ID + defalut: 默认值 + strict: 是否严格模式 Returns: 节点输出或 None """ - return self.state.get("runtime_vars", {}).get(node_id) + node_namespace = self.variables.get(node_id) + if node_namespace: + return {k: v.instance.value for k, v in node_namespace.items()} + if strict: + raise KeyError(f"node {node_id} output not exist") + else: + return defalut + + def copy(self, pool: 'VariablePool'): + self.variables = deepcopy(pool.variables) def to_dict(self) -> dict[str, Any]: """导出为字典 diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index 1d9ab4a8..fcd4bc79 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -618,6 +618,7 @@ class AppChatService: memory: bool = True, storage_type: Optional[str] = None, user_rag_memory_id: Optional[str] = None, + public=False ) -> AsyncGenerator[dict, None]: """聊天(流式)""" @@ -634,7 +635,8 @@ class AppChatService: payload=payload, config=config, workspace_id=workspace_id, - release_id=release_id + release_id=release_id, + public=public ): yield event diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index 8ce5fa37..f19e2d41 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -4,9 +4,8 @@ import datetime import logging import uuid -from typing import Any, Annotated, AsyncGenerator, Optional +from typing import Any, Annotated, Optional -from deprecated import deprecated from fastapi import Depends from sqlalchemy.orm import Session @@ -566,6 +565,41 @@ class WorkflowService: message=f"工作流执行失败: {str(e)}" ) + @staticmethod + def _map_public_event(event: dict) -> dict | None: + event_type = event.get("event") + payload = event.get("data") + match event_type: + case "workflow_start": + return { + "event": "start", + "data": { + "conversation_id": payload.get("conversation_id"), + } + } + case "workflow_end": + return { + "event": "end", + "data": { + "elapsed_time": payload.get("elapsed_time"), + "message_length": len(payload.get("output", "")) + } + } + case "node_start" | "node_end" | "node_error": + return None + case _: + return event + + def _emit(self, public: bool, internal_event: dict): + """ + decide + """ + if public: + mapped = self._map_public_event(internal_event) + else: + mapped = internal_event + return mapped + async def run_stream( self, app_id: uuid.UUID, @@ -663,7 +697,7 @@ class WorkflowService: input_data=input_data, execution_id=execution.execution_id, workspace_id=str(workspace_id), - user_id=payload.user_id + user_id=payload.user_id, ): if event.get("event") == "workflow_end": @@ -694,7 +728,9 @@ class WorkflowService: ) else: logger.error(f"unexpect workflow run status, status: {status}") - yield event + event = self._emit(public, event) + if event: + yield event except Exception as e: logger.error(f"工作流流式执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True) diff --git a/sandbox/app/controllers/sandbox_controller.py b/sandbox/app/controllers/sandbox_controller.py index c5cce40c..f9bc3fc0 100644 --- a/sandbox/app/controllers/sandbox_controller.py +++ b/sandbox/app/controllers/sandbox_controller.py @@ -33,7 +33,7 @@ async def run_code(request: RunCodeRequest): """Execute code in sandbox""" if request.language == "python3": return await run_python_code(request.code, request.preload, request.options) - elif request.language == "nodejs": + elif request.language == "javascript": return await run_nodejs_code(request.code, request.preload, request.options) else: return error_response(-400, "unsupported language")