From 373b91143d7d49324e17b97e7d56cd2b8f82bf37 Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Mon, 5 Jan 2026 10:49:30 +0800 Subject: [PATCH 01/10] feat(workflow): support variable types(TODO) --- api/app/core/workflow/nodes/base_config.py | 80 ++++++++++++++++++---- 1 file changed, 67 insertions(+), 13 deletions(-) diff --git a/api/app/core/workflow/nodes/base_config.py b/api/app/core/workflow/nodes/base_config.py index 1550584a..a6b33928 100644 --- a/api/app/core/workflow/nodes/base_config.py +++ b/api/app/core/workflow/nodes/base_config.py @@ -4,13 +4,16 @@ """ from enum import StrEnum +from typing import Any -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ConfigDict + +VARIABLE_PATTERN = r"\{\{\s*(.*?)\s*\}\}" class VariableType(StrEnum): """变量类型枚举""" - + STRING = "string" NUMBER = "number" BOOLEAN = "boolean" @@ -22,43 +25,94 @@ class VariableType(StrEnum): ARRAY_OBJECT = "array[object]" +class TypedVariable(BaseModel): + """ + TODO: 强类型限制 + Strongly typed variable that validates value on assignment. + """ + + value: Any = Field(..., description="Variable value") + type: VariableType = Field(..., description="Declared type of the variable") + + model_config = ConfigDict( + validate_assignment=True + ) + + def __setattr__(self, name, value): + if name == "value": + self._validate_value(value) + if name == "type": + raise RuntimeError("Cannot modify variable type at runtime") + super().__setattr__(name, value) + + def _validate_value(self, v: Any): + t = self.type + match t: + case VariableType.STRING: + if not isinstance(v, str): + raise TypeError("Variable value does not match type STRING") + case VariableType.BOOLEAN: + if not isinstance(v, bool): + raise TypeError("Variable value does not match type BOOLEAN") + case VariableType.NUMBER: + if not isinstance(v, (int, float)): + raise TypeError("Variable value does not match type NUMBER") + case VariableType.OBJECT: + if not isinstance(v, dict): + raise TypeError("Variable value does not match type OBJECT") + case VariableType.ARRAY_STRING: + if not isinstance(v, list) or not all(isinstance(i, str) for i in v): + raise TypeError("Variable value does not match type ARRAY_STRING") + case VariableType.ARRAY_NUMBER: + if not isinstance(v, list) or not all(isinstance(i, (int, float)) for i in v): + raise TypeError("Variable value does not match type ARRAY_NUMBER") + case VariableType.ARRAY_BOOLEAN: + if not isinstance(v, list) or not all(isinstance(i, bool) for i in v): + raise TypeError("Variable value does not match type ARRAY_BOOLEAN") + case VariableType.ARRAY_OBJECT: + if not isinstance(v, list) or not all(isinstance(i, dict) for i in v): + raise TypeError("Variable value does not match type ARRAY_OBJECT") + case _: + raise TypeError(f"Unknown variable type: {t}") + + class VariableDefinition(BaseModel): """变量定义 定义工作流或节点的输入/输出变量。 这是一个通用的数据结构,可以在多个地方使用。 """ - + name: str = Field( ..., description="变量名称" ) - + type: VariableType = Field( default=VariableType.STRING, description="变量类型" ) - + required: bool = Field( default=False, description="是否必需" ) - + default: str | int | float | bool | list | dict | None = Field( default=None, description="默认值" ) - + description: str | None = Field( default=None, description="变量描述" ) - + max_length: int = Field( default=200, description="只对字符串类型生效" ) - + class Config: json_schema_extra = { "examples": [ @@ -96,22 +150,22 @@ class BaseNodeConfig(BaseModel): - description: 节点描述 - tags: 节点标签(用于分类和搜索) """ - + name: str | None = Field( default=None, description="节点名称(显示名称),如果不设置则使用节点 ID" ) - + description: str | None = Field( default=None, description="节点描述,说明节点的作用" ) - + tags: list[str] = Field( default_factory=list, description="节点标签,用于分类和搜索" ) - + class Config: """Pydantic 配置""" # 允许额外字段(向后兼容) From 55dac533d962f1d19d0c00772ebc4239f91c5c17 Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Mon, 5 Jan 2026 10:50:32 +0800 Subject: [PATCH 02/10] fix(workflow): fix passing of loop variable termination condition --- api/app/core/workflow/nodes/base_node.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index e9fa4f25..8eb31fb4 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -356,7 +356,8 @@ class BaseNode(ABC): **final_output, "runtime_vars": { self.node_id: runtime_var - } + }, + "looping": state["looping"] } # Add streaming buffer for non-End nodes From bf6ede64bd47374c0ae3d10e701e4d223dece819 Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Mon, 5 Jan 2026 10:51:57 +0800 Subject: [PATCH 03/10] feat(workflow): add support for passing workspace ID --- api/app/controllers/app_controller.py | 5 +++-- api/app/services/workflow_service.py | 10 ++++++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/api/app/controllers/app_controller.py b/api/app/controllers/app_controller.py index 8374680b..a0df7d67 100644 --- a/api/app/controllers/app_controller.py +++ b/api/app/controllers/app_controller.py @@ -597,7 +597,8 @@ async def draft_run( async for event in workflow_service.run_stream( app_id=app_id, payload=payload, - config=config + config=config, + workspace_id=current_user.current_workspace_id ): # 提取事件类型和数据 event_type = event.get("event", "message") @@ -627,7 +628,7 @@ async def draft_run( } ) - result = await workflow_service.run(app_id, payload,config) + result = await workflow_service.run(app_id, payload, config, current_user.current_workspace_id) logger.debug( "工作流试运行返回结果", diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index 058767d9..917a40f9 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -410,7 +410,8 @@ class WorkflowService: self, app_id: uuid.UUID, payload: DraftRunRequest, - config: WorkflowConfig + config: WorkflowConfig, + workspace_id: uuid.UUID, ): """运行工作流 @@ -484,7 +485,7 @@ class WorkflowService: workflow_config=workflow_config_dict, input_data=input_data, execution_id=execution.execution_id, - workspace_id="", + workspace_id=str(workspace_id), user_id=payload.user_id ) @@ -530,7 +531,8 @@ class WorkflowService: self, app_id: uuid.UUID, payload: DraftRunRequest, - config: WorkflowConfig + config: WorkflowConfig, + workspace_id: uuid.UUID, ): """运行工作流(流式) @@ -603,7 +605,7 @@ class WorkflowService: workflow_config=workflow_config_dict, input_data=input_data, execution_id=execution.execution_id, - workspace_id="", + workspace_id=str(workspace_id), user_id=payload.user_id ): # 直接转发 executor 的事件(已经是正确的格式) From fc831e04c1428846e4396674dff9ee848a284506 Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Mon, 5 Jan 2026 10:52:46 +0800 Subject: [PATCH 04/10] feat(workflow): support retrieving variables wrapped in {{}} from variable pool --- api/app/core/workflow/variable_pool.py | 71 ++++++++++++++------------ 1 file changed, 37 insertions(+), 34 deletions(-) diff --git a/api/app/core/workflow/variable_pool.py b/api/app/core/workflow/variable_pool.py index b7814f28..7d4b0609 100644 --- a/api/app/core/workflow/variable_pool.py +++ b/api/app/core/workflow/variable_pool.py @@ -10,6 +10,7 @@ """ import logging +import re from typing import Any, TYPE_CHECKING if TYPE_CHECKING: @@ -28,7 +29,7 @@ class VariableSelector: >>> selector = VariableSelector(["node_A", "output"]) >>> selector = VariableSelector.from_string("sys.message") """ - + def __init__(self, path: list[str]): """初始化变量选择器 @@ -37,11 +38,11 @@ class VariableSelector: """ if not path or len(path) < 1: raise ValueError("变量路径不能为空") - + self.path = path self.namespace = path[0] # sys, var, 或 node_id self.key = path[1] if len(path) > 1 else None - + @classmethod def from_string(cls, selector_str: str) -> "VariableSelector": """从字符串创建选择器 @@ -58,10 +59,10 @@ class VariableSelector: """ path = selector_str.split(".") return cls(path) - + def __str__(self) -> str: return ".".join(self.path) - + def __repr__(self) -> str: return f"VariableSelector({self.path})" @@ -84,7 +85,7 @@ class VariablePool: "AI 的回答" >>> pool.set(["conv", "user_name"], "张三") """ - + def __init__(self, state: "WorkflowState"): """初始化变量池 @@ -92,7 +93,7 @@ class VariablePool: state: 工作流状态(LangGraph State) """ self.state = state - + def get(self, selector: list[str] | str, default: Any = None) -> Any: """获取变量值 @@ -114,13 +115,15 @@ class VariablePool: """ # 转换为 VariableSelector if isinstance(selector, str): - selector = VariableSelector.from_string(selector).path - + pattern = r"\{\{\s*(.*?)\s*\}\}" + variable_literal = re.sub(pattern, r"\1", selector).strip() + selector = VariableSelector.from_string(variable_literal).path + if not selector or len(selector) < 1: raise ValueError("变量选择器不能为空") - + namespace = selector[0] - + try: # 系统变量 if namespace == "sys": @@ -128,30 +131,30 @@ class VariablePool: if not key: return self.state.get("variables", {}).get("sys", {}) return self.state.get("variables", {}).get("sys", {}).get(key, default) - + # 会话变量 elif namespace == "conv": key = selector[1] if len(selector) > 1 else None if not key: return self.state.get("variables", {}).get("conv", {}) return self.state.get("variables", {}).get("conv", {}).get(key, default) - + # 节点输出(从 runtime_vars 读取) else: node_id = namespace runtime_vars = self.state.get("runtime_vars", {}) - + if node_id not in runtime_vars: if default is not None: return default raise KeyError(f"节点 '{node_id}' 的输出不存在") - + node_var = runtime_vars[node_id] - + # 如果只有节点 ID,返回整个变量 if len(selector) == 1: return node_var - + # 获取特定字段 # 支持嵌套访问,如 node_id.field.subfield result = node_var @@ -166,14 +169,14 @@ class VariablePool: if default is not None: return default raise KeyError(f"无法访问 '{'.'.join(selector)}'") - + return result - + except KeyError: if default is not None: return default raise - + def set(self, selector: list[str] | str, value: Any): """设置变量值 @@ -192,17 +195,17 @@ class VariablePool: # 转换为 VariableSelector if isinstance(selector, str): selector = VariableSelector.from_string(selector).path - + if not selector or len(selector) < 2: raise ValueError("变量选择器必须包含命名空间和键名") - + namespace = selector[0] - + if namespace != "conv" and namespace not in self.state["cycle_nodes"]: raise ValueError("Only conversation or cycle variables can be assigned.") - + key = selector[1] - + # 确保 variables 结构存在 if "variables" not in self.state: self.state["variables"] = {"sys": {}, "conv": {}} @@ -214,9 +217,9 @@ class VariablePool: self.state["variables"]["conv"][key] = value elif namespace in self.state["cycle_nodes"]: self.state["runtime_vars"][namespace][key] = value - + logger.debug(f"设置变量: {'.'.join(selector)} = {value}") - + def has(self, selector: list[str] | str) -> bool: """检查变量是否存在 @@ -237,7 +240,7 @@ class VariablePool: return True except KeyError: return False - + def get_all_system_vars(self) -> dict[str, Any]: """获取所有系统变量 @@ -245,7 +248,7 @@ class VariablePool: 系统变量字典 """ return self.state.get("variables", {}).get("sys", {}) - + def get_all_conversation_vars(self) -> dict[str, Any]: """获取所有会话变量 @@ -253,7 +256,7 @@ class VariablePool: 会话变量字典 """ return self.state.get("variables", {}).get("conv", {}) - + def get_all_node_outputs(self) -> dict[str, Any]: """获取所有节点输出(运行时变量) @@ -261,7 +264,7 @@ class VariablePool: 节点输出字典,键为节点 ID """ return self.state.get("runtime_vars", {}) - + def get_node_output(self, node_id: str) -> dict[str, Any] | None: """获取指定节点的输出(运行时变量) @@ -272,7 +275,7 @@ class VariablePool: 节点输出或 None """ return self.state.get("runtime_vars", {}).get(node_id) - + def to_dict(self) -> dict[str, Any]: """导出为字典 @@ -284,12 +287,12 @@ class VariablePool: "conversation": self.get_all_conversation_vars(), "nodes": self.get_all_node_outputs() # 从 runtime_vars 读取 } - + def __repr__(self) -> str: sys_vars = self.get_all_system_vars() conv_vars = self.get_all_conversation_vars() runtime_vars = self.get_all_node_outputs() - + return ( f"VariablePool(\n" f" system_vars={len(sys_vars)},\n" From eaf24376338fc68cd237ef4a82c3013621e7cd9e Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Mon, 5 Jan 2026 10:53:53 +0800 Subject: [PATCH 05/10] feat(prompt_opt): support streaming output for prompt optimization API --- .../prompt_optimizer_controller.py | 56 ++++++++---------- api/app/core/workflow/template_loader.py | 4 +- api/app/services/prompt_optimizer_service.py | 57 ++++++++++++++++--- .../prompt/prompt_optimizer_system.jinja2 | 2 +- 4 files changed, 75 insertions(+), 44 deletions(-) diff --git a/api/app/controllers/prompt_optimizer_controller.py b/api/app/controllers/prompt_optimizer_controller.py index d73ea0df..2069dd66 100644 --- a/api/app/controllers/prompt_optimizer_controller.py +++ b/api/app/controllers/prompt_optimizer_controller.py @@ -1,7 +1,9 @@ import uuid +import json from fastapi import APIRouter, Depends, Path from sqlalchemy.orm import Session +from starlette.responses import StreamingResponse from app.core.logging_config import get_api_logger from app.core.response_utils import success @@ -70,12 +72,12 @@ def get_prompt_session( SessionMessage(role=role, content=content) for role, content in history ] - + result = SessionHistoryResponse( session_id=session_id, messages=messages ) - + return success(data=result) @@ -104,35 +106,25 @@ async def get_prompt_opt( ApiResponse: Contains the optimized prompt, description, and a list of variables. """ service = PromptOptimizerService(db) - service.create_message( - tenant_id=current_user.tenant_id, - session_id=session_id, - user_id=current_user.id, - role=RoleType.USER, - content=data.message - ) - opt_result = await service.optimize_prompt( - tenant_id=current_user.tenant_id, - model_id=data.model_id, - session_id=session_id, - user_id=current_user.id, - current_prompt=data.current_prompt, - user_require=data.message - ) - service.create_message( - tenant_id=current_user.tenant_id, - session_id=session_id, - user_id=current_user.id, - role=RoleType.ASSISTANT, - content=opt_result.desc - ) - variables = service.parser_prompt_variables(opt_result.prompt) - result = { - "prompt": opt_result.prompt, - "desc": opt_result.desc, - "variables": variables - } - result_schema = OptimizePromptResponse.model_validate(result) - return success(data=result_schema) + async def event_generator(): + async for chunk in service.optimize_prompt( + tenant_id=current_user.tenant_id, + model_id=data.model_id, + session_id=session_id, + user_id=current_user.id, + current_prompt=data.current_prompt, + user_require=data.message + ): + # chunk 是 prompt 的增量内容 + yield f"event:'message'\ndata: {json.dumps(chunk)}\n\n" + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no" + } + ) diff --git a/api/app/core/workflow/template_loader.py b/api/app/core/workflow/template_loader.py index ab5bd9fa..4ef49ba5 100644 --- a/api/app/core/workflow/template_loader.py +++ b/api/app/core/workflow/template_loader.py @@ -4,11 +4,11 @@ 从文件系统加载预定义的工作流模板 """ -import os -import yaml from pathlib import Path from typing import Optional +import yaml + class TemplateLoader: """工作流模板加载器""" diff --git a/api/app/services/prompt_optimizer_service.py b/api/app/services/prompt_optimizer_service.py index 5f325a1b..482e8213 100644 --- a/api/app/services/prompt_optimizer_service.py +++ b/api/app/services/prompt_optimizer_service.py @@ -1,5 +1,6 @@ import re import uuid +from typing import Any, AsyncGenerator import json_repair from langchain_core.prompts import ChatPromptTemplate @@ -123,7 +124,7 @@ class PromptOptimizerService: user_id: uuid.UUID, current_prompt: str, user_require: str - ) -> OptimizePromptResult: + ) -> AsyncGenerator[dict[str, str | Any], Any]: """ Optimize a user-provided prompt using a configured prompt optimizer LLM. @@ -161,6 +162,7 @@ class PromptOptimizerService: BusinessException: If the LLM response cannot be parsed as valid JSON or does not conform to the expected output format. """ + self.create_message(tenant_id, session_id, user_id, role=RoleType.USER, content=user_require) model_config = self.get_model_config(tenant_id, model_id) session_history = self.get_session_message_history(session_id=session_id, user_id=user_id) @@ -202,17 +204,54 @@ class PromptOptimizerService: messages.extend(session_history[:-1]) # last message is current message messages.extend([(RoleType.USER.value, rendered_user_message)]) logger.info(f"Prompt optimization message: {messages}") - optim_resp = await llm.ainvoke(messages) - logger.info(optim_resp.content) - optim_result = json_repair.repair_json(optim_resp.content, return_objects=True) - prompt = optim_result.get("prompt") - desc = optim_result.get("desc") + buffer = "" + prompt_started = False + prompt_finished = False + idx = 0 - return OptimizePromptResult( - prompt=prompt, - desc=desc + async for chunk in llm.astream(messages): + content = getattr(chunk, "content", chunk) + if not content: + continue + buffer += content + cache = buffer[:-20] + + # 尝试找到 "prompt": " 开始位置 + if prompt_finished: + continue + + if not prompt_started: + m = re.search(r'"prompt"\s*:\s*"', cache) + if m: + prompt_started = True + prompt_index = m.end() + idx = prompt_index + else: + m = re.search(r'"\s*,\s*\\?n?\s*"desc"\s*:\s*"', buffer) + if m: + prompt_index = m.start() + prompt_finished = True + yield {"type": "delta", "content": buffer[idx:prompt_index]} + else: + yield {"type": "delta", "content": cache[idx:]} + if len(cache) != 0: + idx = len(cache) + + # optim_resp = await llm.astream(messages) + logger.info(buffer) + optim_result = json_repair.repair_json(buffer, return_objects=True) + # prompt = optim_result.get("prompt") + desc = optim_result.get("desc") + self.create_message( + tenant_id=tenant_id, + session_id=session_id, + user_id=user_id, + role=RoleType.ASSISTANT, + content=desc ) + yield {"type": "done", "desc": optim_result.get("desc")} + @staticmethod def parser_prompt_variables(prompt: str): try: diff --git a/api/app/templates/prompt/prompt_optimizer_system.jinja2 b/api/app/templates/prompt/prompt_optimizer_system.jinja2 index ae19a6ab..b9060f68 100644 --- a/api/app/templates/prompt/prompt_optimizer_system.jinja2 +++ b/api/app/templates/prompt/prompt_optimizer_system.jinja2 @@ -25,7 +25,7 @@ Rules Basic Principles Priority Rule: When historical requirements conflict with current requirements, unconditionally prioritize current requirements. Completeness Rule: If the original prompt is empty, generate a complete prompt based on the current requirements. -Structure Rule: Use a clear block structure including [Role], [Task], [Requirements], [Input], [Output], [Constraints] labels. +Structure Rule: Use a clear block structure, and the contents of each block are roles, tasks, requirements, inputs, outputs, and constraints Language Rule: All label languages must fully match the user input language. Behavior Guidelines From b56994b9994dab22c3fe2d4e2b9cf6d873d11dbc Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Mon, 5 Jan 2026 10:57:44 +0800 Subject: [PATCH 06/10] feat(workflow): update workflow conditional logic --- .../core/workflow/nodes/cycle_graph/loop.py | 142 +++++--- api/app/core/workflow/nodes/end/node.py | 2 +- api/app/core/workflow/nodes/enums.py | 8 +- api/app/core/workflow/nodes/if_else/config.py | 14 +- api/app/core/workflow/nodes/if_else/node.py | 92 +++-- .../core/workflow/nodes/jinja_render/node.py | 1 + api/app/core/workflow/nodes/operators.py | 329 +++++++++++++----- .../nodes/question_classifier/node.py | 32 +- .../nodes/variable_aggregator/__init__.py | 2 +- 9 files changed, 436 insertions(+), 186 deletions(-) diff --git a/api/app/core/workflow/nodes/cycle_graph/loop.py b/api/app/core/workflow/nodes/cycle_graph/loop.py index af75d372..2e2ab4fb 100644 --- a/api/app/core/workflow/nodes/cycle_graph/loop.py +++ b/api/app/core/workflow/nodes/cycle_graph/loop.py @@ -3,10 +3,11 @@ from typing import Any from langgraph.graph.state import CompiledStateGraph -from app.core.workflow.expression_evaluator import evaluate_condition, evaluate_expression +from app.core.workflow.expression_evaluator import evaluate_expression from app.core.workflow.nodes import WorkflowState from app.core.workflow.nodes.cycle_graph import LoopNodeConfig -from app.core.workflow.nodes.operators import ConditionExpressionBuilder +from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator +from app.core.workflow.nodes.operators import TypeTransformer, ConditionExpressionResolver, CompareOperatorInstance from app.core.workflow.variable_pool import VariablePool logger = logging.getLogger(__name__) @@ -14,11 +15,13 @@ logger = logging.getLogger(__name__) class LoopRuntime: """ - Runtime executor for loop nodes in a workflow. + Runtime executor for a loop node in a workflow graph. - Handles iterative execution of a loop node according to defined loop variables - and conditional expressions. Supports maximum loop count and loop control - through the workflow state. + This class is responsible for executing a loop node at runtime: + - Initializing loop-scoped variables + - Evaluating loop continuation conditions + - Repeatedly invoking a compiled sub-graph + - Enforcing maximum loop count and external stop signals """ def __init__( @@ -29,13 +32,13 @@ class LoopRuntime: state: WorkflowState, ): """ - Initialize the loop runtime. + Initialize the loop runtime executor. Args: - graph: Compiled workflow graph capable of async invocation. - node_id: Unique identifier of the loop node. - config: Dictionary containing loop node configuration. - state: Current workflow state at the point of loop execution. + graph: A compiled LangGraph state graph representing the loop body. + node_id: The unique identifier of the loop node in the workflow. + config: Raw configuration dictionary for the loop node. + state: The current workflow state before entering the loop. """ self.graph = graph self.state = state @@ -46,12 +49,15 @@ class LoopRuntime: """ Initialize workflow state for loop execution. - - Evaluates initial values of loop variables. - - Stores loop variables in runtime_vars and node_outputs. - - Marks the loop as active by setting 'looping' to True. + This method: + - Evaluates initial values of loop variables + - Stores loop variables into both `runtime_vars` and `node_outputs` + under the current loop node's scope + - Creates a shallow copy of the workflow state + - Marks the loop as active by setting `looping = True` Returns: - A copy of the workflow state prepared for the loop execution. + WorkflowState: A prepared workflow state used for loop execution. """ pool = VariablePool(self.state) # 循环变量 @@ -61,7 +67,7 @@ class LoopRuntime: variables=pool.get_all_conversation_vars(), node_outputs=pool.get_all_node_outputs(), system_vars=pool.get_all_system_vars(), - ) + ) if variable.input_type == ValueInputType.VARIABLE else TypeTransformer.transform(variable.value, variable.type) for variable in self.typed_config.cycle_vars } self.state["node_outputs"][self.node_id] = { @@ -70,7 +76,7 @@ class LoopRuntime: variables=pool.get_all_conversation_vars(), node_outputs=pool.get_all_node_outputs(), system_vars=pool.get_all_system_vars(), - ) + ) if variable.input_type == ValueInputType.VARIABLE else TypeTransformer.transform(variable.value, variable.type) for variable in self.typed_config.cycle_vars } loopstate = WorkflowState( @@ -79,49 +85,93 @@ class LoopRuntime: loopstate["looping"] = True return loopstate - def _get_loop_expression(self): + @staticmethod + def _evaluate(operator, instance: CompareOperatorInstance) -> Any: """ - Build the Python boolean expression for evaluating the loop condition. + Dispatch and execute a comparison operator against a resolved + CompareOperatorInstance. - - Converts each condition in the loop configuration into a Python expression string. - - Combines multiple conditions with the configured logical operator (AND/OR). + Args: + operator: A ComparisonOperator enum value. + instance: A CompareOperatorInstance bound to concrete operands. Returns: - A string representing the combined loop condition expression. + Any: The evaluation result, typically a boolean. """ - branch_conditions = [ - ConditionExpressionBuilder( - left=condition.left, - operator=condition.comparison_operator, - right=condition.right - ).build() - for condition in self.typed_config.condition.expressions - ] - if len(branch_conditions) > 1: - combined_condition = f' {self.typed_config.condition.logical_operator} '.join(branch_conditions) - else: - combined_condition = branch_conditions[0] + match operator: + case ComparisonOperator.EMPTY: + return instance.empty() + case ComparisonOperator.NOT_EMPTY: + return instance.not_empty() + case ComparisonOperator.CONTAINS: + return instance.contains() + case ComparisonOperator.NOT_CONTAINS: + return instance.not_contains() + case ComparisonOperator.START_WITH: + return instance.startswith() + case ComparisonOperator.END_WITH: + return instance.endswith() + case ComparisonOperator.EQ: + return instance.eq() + case ComparisonOperator.NE: + return instance.ne() + case ComparisonOperator.LT: + return instance.lt() + case ComparisonOperator.LE: + return instance.le() + case ComparisonOperator.GT: + return instance.gt() + case ComparisonOperator.GE: + return instance.ge() + case _: + raise ValueError(f"Invalid condition: {operator}") - return combined_condition + def evaluate_conditional(self, state) -> bool: + """ + Evaluate the loop continuation condition at runtime. + + This method: + - Resolves all condition expressions against the current workflow state + - Evaluates each comparison expression immediately + - Combines results using the configured logical operator (AND / OR) + + Args: + state: The current workflow state during loop execution. + + Returns: + bool: True if the loop should continue, False otherwise. + """ + conditions = [] + + for expression in self.typed_config.condition.expressions: + left_value = VariablePool(state).get(expression.left) + evaluator = ConditionExpressionResolver.resolve_by_value(left_value)( + VariablePool(state), + expression.left, + expression.right, + expression.input_type + ) + conditions.append(self._evaluate(expression.operator, evaluator)) + if self.typed_config.condition.logical_operator == LogicOperator.AND: + return all(conditions) + else: + return any(conditions) async def run(self): """ - Execute the loop node until the condition is no longer met, the loop is - manually stopped, or the maximum loop count is reached. + Execute the loop node until termination conditions are met. + + The loop terminates when any of the following occurs: + - The loop condition evaluates to False + - The `looping` flag in the workflow state is set to False + - The maximum loop count is reached Returns: - The final runtime variables of this loop node after completion. + dict[str, Any]: The final runtime variables of this loop node. """ loopstate = self._init_loop_state() - expression = self._get_loop_expression() - loop_variable_pool = VariablePool(loopstate) loop_time = self.typed_config.max_loop - while evaluate_condition( - expression=expression, - variables=loop_variable_pool.get_all_conversation_vars(), - node_outputs=loop_variable_pool.get_all_node_outputs(), - system_vars=loop_variable_pool.get_all_system_vars(), - ) and loopstate["looping"] and loop_time > 0: + while self.evaluate_conditional(loopstate) and loopstate["looping"] and loop_time > 0: logger.info(f"loop node {self.node_id}: running") await self.graph.ainvoke(loopstate) loop_time -= 1 diff --git a/api/app/core/workflow/nodes/end/node.py b/api/app/core/workflow/nodes/end/node.py index 65bb6cb5..efc62dc5 100644 --- a/api/app/core/workflow/nodes/end/node.py +++ b/api/app/core/workflow/nodes/end/node.py @@ -61,7 +61,7 @@ class EndNode(BaseNode): 引用的节点 ID 列表 """ # 匹配 {{node_id.xxx}} 格式 - pattern = r'\{\{([a-zA-Z0-9_-]+)\.[a-zA-Z0-9_]+\}\}' + pattern = r'\{\{([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\}\}' matches = re.findall(pattern, template) return list(set(matches)) # 去重 diff --git a/api/app/core/workflow/nodes/enums.py b/api/app/core/workflow/nodes/enums.py index 0492a7bf..b1c9d687 100644 --- a/api/app/core/workflow/nodes/enums.py +++ b/api/app/core/workflow/nodes/enums.py @@ -45,7 +45,8 @@ class LogicOperator(StrEnum): class AssignmentOperator(StrEnum): - ASSIGN = "assign" + COVER = "cover" # 覆盖 + ASSIGN = "assign" # 设置 CLEAR = "clear" ADD = "add" # += @@ -87,3 +88,8 @@ class HttpErrorHandle(StrEnum): NONE = "none" DEFAULT = "default" BRANCH = "branch" + + +class ValueInputType(StrEnum): + VARIABLE = "Variable" + CONSTANT = "Constant" diff --git a/api/app/core/workflow/nodes/if_else/config.py b/api/app/core/workflow/nodes/if_else/config.py index 9eddb473..4dcb00d1 100644 --- a/api/app/core/workflow/nodes/if_else/config.py +++ b/api/app/core/workflow/nodes/if_else/config.py @@ -1,12 +1,13 @@ """Condition Configuration""" +from typing import Any from pydantic import Field, BaseModel, field_validator from app.core.workflow.nodes.base_config import BaseNodeConfig -from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator +from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType class ConditionDetail(BaseModel): - comparison_operator: ComparisonOperator = Field( + operator: ComparisonOperator = Field( ..., description="Comparison operator used to evaluate the condition" ) @@ -16,17 +17,22 @@ class ConditionDetail(BaseModel): description="Value to compare against" ) - right: str = Field( + right: Any = Field( ..., description="Value to compare with" ) + input_type: ValueInputType = Field( + ..., + description="Value input type for comparison" + ) + class ConditionBranchConfig(BaseModel): """Configuration for a conditional branch""" logical_operator: LogicOperator = Field( - default=LogicOperator.AND.value, + default=LogicOperator.AND, description="Logical operator used to combine multiple condition expressions" ) diff --git a/api/app/core/workflow/nodes/if_else/node.py b/api/app/core/workflow/nodes/if_else/node.py index 1450a28f..fd5864a8 100644 --- a/api/app/core/workflow/nodes/if_else/node.py +++ b/api/app/core/workflow/nodes/if_else/node.py @@ -1,10 +1,11 @@ import logging +import re from typing import Any from app.core.workflow.nodes.base_node import BaseNode, WorkflowState +from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator from app.core.workflow.nodes.if_else import IfElseNodeConfig -from app.core.workflow.nodes.if_else.config import ConditionDetail -from app.core.workflow.nodes.operators import ConditionExpressionBuilder +from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance logger = logging.getLogger(__name__) @@ -15,30 +16,36 @@ class IfElseNode(BaseNode): self.typed_config = IfElseNodeConfig(**self.config) @staticmethod - def _build_condition_expression( - condition: ConditionDetail, - ) -> str: - """ - Build a single boolean condition expression string. + def _evaluate(operator, instance: CompareOperatorInstance) -> Any: + match operator: + case ComparisonOperator.EMPTY: + return instance.empty() + case ComparisonOperator.NOT_EMPTY: + return instance.not_empty() + case ComparisonOperator.CONTAINS: + return instance.contains() + case ComparisonOperator.NOT_CONTAINS: + return instance.not_contains() + case ComparisonOperator.START_WITH: + return instance.startswith() + case ComparisonOperator.END_WITH: + return instance.endswith() + case ComparisonOperator.EQ: + return instance.eq() + case ComparisonOperator.NE: + return instance.ne() + case ComparisonOperator.LT: + return instance.lt() + case ComparisonOperator.LE: + return instance.le() + case ComparisonOperator.GT: + return instance.gt() + case ComparisonOperator.GE: + return instance.ge() + case _: + raise ValueError(f"Invalid condition: {operator}") - This method does NOT evaluate the condition. - It only generates a valid Python boolean expression string - (e.g. "x > 10", "'a' in name") that can later be used - in a conditional edge or evaluated by the workflow engine. - - Args: - condition (ConditionDetail): Definition of a single comparison condition. - - Returns: - str: A Python boolean expression string. - """ - return ConditionExpressionBuilder( - left=condition.left, - operator=condition.comparison_operator, - right=condition.right - ).build() - - def build_conditional_edge_expressions(self) -> list[str]: + def evaluate_conditional_edge_expressions(self, state) -> list[bool]: """ Build conditional edge expressions for the If-Else node. @@ -60,19 +67,28 @@ class IfElseNode(BaseNode): for case_branch in self.typed_config.cases: branch_index += 1 - - branch_conditions = [ - self._build_condition_expression(condition) - for condition in case_branch.expressions - ] - if len(branch_conditions) > 1: - combined_condition = f' {case_branch.logical_operator} '.join(branch_conditions) + branch_result = [] + for expression in case_branch.expressions: + pattern = r"\{\{\s*(.*?)\s*\}\}" + left_string = re.sub(pattern, r"\1", expression.left).strip() + left_value = self.get_variable(left_string, state) + evaluator = ConditionExpressionResolver.resolve_by_value(left_value)( + self.get_variable_pool(state), + expression.left, + expression.right, + expression.input_type + ) + branch_result.append(self._evaluate(expression.operator, evaluator)) + if case_branch.logical_operator == LogicOperator.AND: + conditions.append(all(branch_result)) else: - combined_condition = branch_conditions[0] - conditions.append(combined_condition) + condition_res = any(branch_result) + conditions.append(condition_res) + if condition_res: + return conditions # Default fallback branch - conditions.append("True") + conditions.append(True) return conditions @@ -90,10 +106,10 @@ class IfElseNode(BaseNode): Returns: str: The matched branch identifier, e.g., 'CASE1', 'CASE2', ..., used for node transitions. """ - expressions = self.build_conditional_edge_expressions() + expressions = self.evaluate_conditional_edge_expressions(state) + # TODO: 变量类型及文本类型解析 for i in range(len(expressions)): - logger.info(expressions[i]) - if self._evaluate_condition(expressions[i], state): + if expressions[i]: logger.info(f"Node {self.node_id}: switched to branch CASE {i + 1}") return f'CASE{i + 1}' return f'CASE{len(expressions)}' diff --git a/api/app/core/workflow/nodes/jinja_render/node.py b/api/app/core/workflow/nodes/jinja_render/node.py index 6130c30a..e18a2001 100644 --- a/api/app/core/workflow/nodes/jinja_render/node.py +++ b/api/app/core/workflow/nodes/jinja_render/node.py @@ -8,6 +8,7 @@ from app.core.workflow.template_renderer import TemplateRenderer logger = logging.getLogger(__name__) + class JinjaRenderNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) diff --git a/api/app/core/workflow/nodes/operators.py b/api/app/core/workflow/nodes/operators.py index 70668b6a..fc856aee 100644 --- a/api/app/core/workflow/nodes/operators.py +++ b/api/app/core/workflow/nodes/operators.py @@ -1,10 +1,73 @@ +import json +import re from abc import ABC -from typing import Union, Type +from typing import Union, Type, NoReturn -from app.core.workflow.nodes.enums import ComparisonOperator +from app.core.workflow.nodes.base_config import VariableType +from app.core.workflow.nodes.enums import ValueInputType from app.core.workflow.variable_pool import VariablePool +class TypeTransformer: + @classmethod + def _fail(cls, value, target) -> NoReturn: + raise TypeError(f"Cannot convert {value!r} to {target} type") + + @classmethod + def _json_load(cls, value, target): + try: + return json.loads(value) + except Exception: + cls._fail(value, target) + + @classmethod + def transform(cls, variable_literal: str | bool, target_type: VariableType): + match target_type: + case VariableType.STRING: + return str(variable_literal) + + case VariableType.NUMBER: + for caster in (int, float): + try: + return caster(variable_literal) + except Exception: + pass + cls._fail(variable_literal, target_type) + + case VariableType.BOOLEAN: + if isinstance(variable_literal, bool): + return variable_literal + cls._fail(variable_literal, target_type) + + case VariableType.OBJECT: + obj = cls._json_load(variable_literal, target_type) + if isinstance(obj, dict): + return obj + cls._fail(variable_literal, target_type) + + case VariableType.ARRAY_BOOLEAN: + return cls._parse_list(variable_literal, bool, target_type) + + case VariableType.ARRAY_NUMBER: + return cls._parse_list(variable_literal, (int, float), target_type) + + case VariableType.ARRAY_STRING: + return cls._parse_list(variable_literal, str, target_type) + + case VariableType.ARRAY_OBJECT: + return cls._parse_list(variable_literal, dict, target_type) + + case _: + raise TypeError("Invalid type") + + @classmethod + def _parse_list(cls, value, item_type, target): + arr = cls._json_load(value, target) + if isinstance(arr, list) and all(isinstance(i, item_type) for i in arr): + return arr + cls._fail(value, target) + + class OperatorBase(ABC): def __init__(self, pool: VariablePool, left_selector, right): self.pool = pool @@ -19,7 +82,9 @@ class OperatorBase(ABC): raise TypeError(f"The variable to be operated on must be of {self.type_limit} type") if not no_right and not isinstance(self.right, self.type_limit): - raise TypeError(f"The value assigned to the string variable must also be of {self.type_limit} type") + raise TypeError( + f"The value assigned must be of {self.type_limit} type" + ) class StringOperator(OperatorBase): @@ -126,7 +191,7 @@ class ArrayOperator(OperatorBase): class ObjectOperator(OperatorBase): def __init__(self, pool: VariablePool, left_selector, right): super().__init__(pool, left_selector, right) - self.type_limit = object + self.type_limit = dict def assign(self) -> None: self.check() @@ -138,20 +203,21 @@ class ObjectOperator(OperatorBase): class AssignmentOperatorResolver: + OPERATOR_MAP = { + str: StringOperator, + bool: BooleanOperator, + int: NumberOperator, + float: NumberOperator, + list: ArrayOperator, + dict: ObjectOperator, + } + @classmethod def resolve_by_value(cls, value): - if isinstance(value, str): - return StringOperator - elif isinstance(value, bool): - return BooleanOperator - elif isinstance(value, (int, float)): - return NumberOperator - elif isinstance(value, list): - return ArrayOperator - elif isinstance(value, dict): - return ObjectOperator - else: - raise TypeError(f"Unsupported variable type: {type(value)}") + for t, op in cls.OPERATOR_MAP.items(): + if isinstance(value, t): + return op + raise TypeError(f"Unsupported variable type: {type(value)}") AssignmentOperatorInstance = Union[ @@ -164,81 +230,186 @@ AssignmentOperatorInstance = Union[ AssignmentOperatorType = Type[AssignmentOperatorInstance] -class ConditionExpressionBuilder: - """ - Build a Python boolean expression string based on a comparison operator. +class ConditionBase(ABC): + type_limit: type[str, int, dict, list] = None - This class does not evaluate the expression. - It only generates a valid Python expression string - that can be evaluated later in a workflow context. - """ + def __init__( + self, + pool: VariablePool, + left_selector, + right_selector: str, + input_type: ValueInputType + ): + self.pool = pool + self.left_selector = left_selector + self.right_selector = right_selector + self.input_type = input_type - def __init__(self, left: str, operator: ComparisonOperator, right: str): - self.left = left - self.operator = operator - self.right = right + self.left_value = self.pool.get(self.left_selector) + self.right_value = self.resolve_right_literal_value() - def _empty(self): - return f"{self.left} == ''" + self.type_limit = getattr(self, "type_limit", None) - def _not_empty(self): - return f"{self.left} != ''" + def resolve_right_literal_value(self): + if self.input_type == ValueInputType.VARIABLE: + pattern = r"\{\{\s*(.*?)\s*\}\}" + right_expression = re.sub(pattern, r"\1", self.right_selector).strip() + return self.pool.get(right_expression) + elif self.input_type == ValueInputType.CONSTANT: + return self.right_selector + raise RuntimeError("Unsupported variable type") - def _contains(self): - return f"{self.right} in {self.left}" + def check(self, no_right=False): + left = self.pool.get(self.left_selector.variable_selector) + if not isinstance(left, self.type_limit): + raise TypeError(f"The variable to be compared on must be of {self.type_limit} type") + if not no_right: + right = self.resolve_right_literal_value() + if not isinstance(right, self.type_limit): + raise TypeError( + f"The compared variable must be of {self.type_limit} type" + ) - def _not_contains(self): - return f"{self.right} not in {self.left}" - def _startswith(self): - return f'{self.left}.startswith({self.right})' +class StringComparisonOperator(ConditionBase): + type_limit = str - def _endswith(self): - return f'{self.left}.endswith({self.right})' + def __init__(self, pool: VariablePool, left_selector, right_selector, input_type): + super().__init__(pool, left_selector, right_selector, input_type) - def _eq(self): - return f"{self.left} == {self.right}" + def empty(self): + self.check(no_right=True) + return self.left_value == "" - def _ne(self): - return f"{self.left} != {self.right}" + def not_empty(self): + return not self.empty() - def _lt(self): - return f"{self.left} < {self.right}" + def contains(self): + self.check() + return self.right_value in self.left_value - def _le(self): - return f"{self.left} <= {self.right}" + def not_contains(self): + return self.right_value not in self.left_value - def _gt(self): - return f"{self.left} > {self.right}" + def startswith(self): + self.check() + return self.left_value.startswith(self.right_value) - def _ge(self): - return f"{self.left} >= {self.right}" + def endswith(self): + return self.left_value.endswith(self.right_value) - def build(self): - match self.operator: - case ComparisonOperator.EMPTY: - return self._empty() - case ComparisonOperator.NOT_EMPTY: - return self._not_empty() - case ComparisonOperator.CONTAINS: - return self._contains() - case ComparisonOperator.NOT_CONTAINS: - return self._not_contains() - case ComparisonOperator.START_WITH: - return self._startswith() - case ComparisonOperator.END_WITH: - return self._endswith() - case ComparisonOperator.EQ: - return self._eq() - case ComparisonOperator.NE: - return self._ne() - case ComparisonOperator.LT: - return self._lt() - case ComparisonOperator.LE: - return self._le() - case ComparisonOperator.GT: - return self._gt() - case ComparisonOperator.GE: - return self._ge() - case _: - raise ValueError(f"Invalid condition: {self.operator}") + def eq(self): + return self.left_value == self.right_value + + def ne(self): + return self.left_value != self.right_value + + +class NumberComparisonOperator(ConditionBase): + type_limit = (int, float) + + def __init__(self, pool: VariablePool, left_selector, right_selector, input_type): + super().__init__(pool, left_selector, right_selector, input_type) + + def empty(self): + return self.left_value == 0 + + def not_empty(self): + return self.left_value != 0 + + def eq(self): + return self.left_value == self.right_value + + def ne(self): + return self.left_value != self.right_value + + def lt(self): + return self.left_value < self.right_value + + def le(self): + return self.left_value <= self.right_value + + def gt(self): + return self.left_value > self.right_value + + def ge(self): + return self.left_value >= self.right_value + + +class BooleanComparisonOperator(ConditionBase): + type_limit = bool + + def __init__(self, pool: VariablePool, left_selector, right_selector, input_type): + super().__init__(pool, left_selector, right_selector, input_type) + + def eq(self): + return self.left_value == self.right_value + + def ne(self): + return self.left_value != self.right_value + + +class ObjectComparisonOperator(ConditionBase): + type_limit = dict + + def __init__(self, pool: VariablePool, left_selector, right_selector, input_type): + super().__init__(pool, left_selector, right_selector, input_type) + + def eq(self): + return self.left_value == self.right_value + + def ne(self): + return self.left_value != self.right_value + + def empty(self): + return not self.left_value + + def not_empty(self): + return bool(self.left_value) + + +class ArrayComparisonOperator(ConditionBase): + type_limit = list + + def __init__(self, pool: VariablePool, left_selector, right_selector, input_type): + super().__init__(pool, left_selector, right_selector, input_type) + + def empty(self): + return not self.left_value + + def not_empty(self): + return bool(self.left_value) + + def contains(self): + return self.right_value in self.left_value + + def not_contains(self): + return self.right_value not in self.left_value + + +CompareOperatorInstance = Union[ + StringComparisonOperator, + NumberComparisonOperator, + BooleanComparisonOperator, + ArrayComparisonOperator, + ObjectComparisonOperator +] +CompareOperatorType = Type[CompareOperatorInstance] + + +class ConditionExpressionResolver: + CONDITION_OPERATOR_MAP = { + str: StringComparisonOperator, + bool: BooleanComparisonOperator, + int: NumberComparisonOperator, + float: NumberComparisonOperator, + list: ArrayComparisonOperator, + dict: ObjectComparisonOperator, + } + + @classmethod + def resolve_by_value(cls, value) -> CompareOperatorType: + for t, op in cls.CONDITION_OPERATOR_MAP.items(): + if isinstance(value, t): + return op + raise TypeError(f"Unsupported variable type: {type(value)}") diff --git a/api/app/core/workflow/nodes/question_classifier/node.py b/api/app/core/workflow/nodes/question_classifier/node.py index bd3c8752..7e6a40b2 100644 --- a/api/app/core/workflow/nodes/question_classifier/node.py +++ b/api/app/core/workflow/nodes/question_classifier/node.py @@ -15,29 +15,29 @@ logger = logging.getLogger(__name__) class QuestionClassifierNode(BaseNode): """问题分类器节点""" - + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) self.typed_config = QuestionClassifierNodeConfig(**self.config) - + def _get_llm_instance(self) -> RedBearLLM: """获取LLM实例""" with get_db_read() as db: config = ModelConfigService.get_model_by_id(db=db, model_id=self.typed_config.model_id) - + if not config: raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND) - + if not config.api_keys or len(config.api_keys) == 0: raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER) - + api_config = config.api_keys[0] model_name = api_config.model_name provider = api_config.provider api_key = api_config.api_key base_url = api_config.api_base model_type = config.type - + return RedBearLLM( RedBearModelConfig( model_name=model_name, @@ -47,7 +47,7 @@ class QuestionClassifierNode(BaseNode): ), type=ModelType(model_type) ) - + async def execute(self, state: WorkflowState) -> dict[str, Any]: """执行问题分类""" question = self.typed_config.input_variable @@ -55,15 +55,15 @@ class QuestionClassifierNode(BaseNode): supplement_prompt = "" if self.typed_config.user_supplement_prompt is not None: supplement_prompt = self.typed_config.user_supplement_prompt - + category_names = [class_item.class_name for class_item in self.typed_config.categories] - + if not question: logger.warning(f"节点 {self.node_id} 未获取到输入问题") return {self.typed_config.output_variable: category_names[0] if category_names else "unknown"} - + llm = self._get_llm_instance() - + # 渲染用户提示词模板,支持工作流变量 user_prompt = self._render_template( self.typed_config.user_prompt.format( @@ -73,15 +73,15 @@ class QuestionClassifierNode(BaseNode): ), state ) - + messages = [ ("system", self.typed_config.system_prompt), ("user", user_prompt), ] - + response = await llm.ainvoke(messages) result = response.content.strip() - + if result in category_names: category = result else: @@ -90,5 +90,5 @@ class QuestionClassifierNode(BaseNode): log_supplement = supplement_prompt if supplement_prompt else "无" logger.info(f"节点 {self.node_id} 分类结果: {category}, 用户补充提示词:{log_supplement}") - - return {self.typed_config.output_variable: category} \ No newline at end of file + + return {self.typed_config.output_variable: category} diff --git a/api/app/core/workflow/nodes/variable_aggregator/__init__.py b/api/app/core/workflow/nodes/variable_aggregator/__init__.py index 7bc9afa7..d7eda8f5 100644 --- a/api/app/core/workflow/nodes/variable_aggregator/__init__.py +++ b/api/app/core/workflow/nodes/variable_aggregator/__init__.py @@ -1,4 +1,4 @@ from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig from app.core.workflow.nodes.variable_aggregator.node import VariableAggregatorNode -__all__ = ["VariableAggregatorNode", "VariableAggregatorNodeConfig"] \ No newline at end of file +__all__ = ["VariableAggregatorNode", "VariableAggregatorNodeConfig"] From 1f6835a8e0edaeab1df753ad58b896168857cb01 Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Mon, 5 Jan 2026 11:00:50 +0800 Subject: [PATCH 07/10] feat(workflow): enable front-end to cover pre-rendered non-variable values --- .../core/workflow/nodes/assigner/config.py | 4 +++- api/app/core/workflow/nodes/assigner/node.py | 19 ++++++++++--------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/api/app/core/workflow/nodes/assigner/config.py b/api/app/core/workflow/nodes/assigner/config.py index d9721e99..092f0b51 100644 --- a/api/app/core/workflow/nodes/assigner/config.py +++ b/api/app/core/workflow/nodes/assigner/config.py @@ -1,3 +1,5 @@ +from typing import Any + from pydantic import Field, BaseModel from app.core.workflow.nodes.base_config import BaseNodeConfig @@ -19,7 +21,7 @@ class AssignmentItem(BaseModel): description="Assignment operator", ) - value: str | list[str] = Field( + value: Any = Field( ..., description="Value(s) to assign to the variable(s)", ) diff --git a/api/app/core/workflow/nodes/assigner/node.py b/api/app/core/workflow/nodes/assigner/node.py index a637a8c1..008002ed 100644 --- a/api/app/core/workflow/nodes/assigner/node.py +++ b/api/app/core/workflow/nodes/assigner/node.py @@ -2,7 +2,6 @@ import logging import re from typing import Any -from app.core.workflow.expression_evaluator import ExpressionEvaluator from app.core.workflow.nodes.assigner.config import AssignerNodeConfig from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.enums import AssignmentOperator @@ -29,6 +28,7 @@ class AssignerNode(BaseNode): None or the result of the assignment operation. """ # Initialize a variable pool for accessing conversation, node, and system variables + logger.info(f"节点 {self.node_id} 开始执行") pool = VariablePool(state) for assignment in self.typed_config.assignments: # Get the target variable selector (e.g., "conv.test") @@ -45,14 +45,13 @@ class AssignerNode(BaseNode): # Get the value or expression to assign value = assignment.value - if isinstance(value, list): - value = '.'.join(value) - value = ExpressionEvaluator.evaluate( - expression=value, - variables=pool.get_all_conversation_vars(), - node_outputs=pool.get_all_node_outputs(), - system_vars=pool.get_all_system_vars(), - ) + pattern = r"\{\{\s*(.*?)\s*\}\}" + if isinstance(value, str): + expression = re.match(pattern, value) + if expression: + expression = expression.group(1) + expression = re.sub(pattern, r"\1", expression).strip() + value = self.get_variable(expression, state) # Select the appropriate assignment operator instance based on the target variable type operator: AssignmentOperatorInstance = AssignmentOperatorResolver.resolve_by_value( @@ -63,6 +62,8 @@ class AssignerNode(BaseNode): # Execute the configured assignment operation match assignment.operation: + case AssignmentOperator.COVER: + operator.assign() case AssignmentOperator.ASSIGN: operator.assign() case AssignmentOperator.CLEAR: From 5957eb9c1afda3b6f4dcdfb0e728cf7ea9102f9b Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Mon, 5 Jan 2026 11:02:17 +0800 Subject: [PATCH 08/10] fix(workflow): ensure default values are properly retrieved in HTTP nodes --- .../core/workflow/nodes/http_request/config.py | 15 ++++++++++++--- .../core/workflow/nodes/http_request/node.py | 18 ------------------ 2 files changed, 12 insertions(+), 21 deletions(-) diff --git a/api/app/core/workflow/nodes/http_request/config.py b/api/app/core/workflow/nodes/http_request/config.py index 406d4d0e..6bb7baaf 100644 --- a/api/app/core/workflow/nodes/http_request/config.py +++ b/api/app/core/workflow/nodes/http_request/config.py @@ -63,7 +63,7 @@ class HttpContentTypeConfig(BaseModel): ) data: list[HttpFormData] | dict | str = Field( - ..., + default="", description="Data of the HTTP request body; type depends on content_type", ) @@ -98,6 +98,10 @@ class HttpTimeOutConfig(BaseModel): class HttpRetryConfig(BaseModel): + enable: bool = Field( + ..., + description="Enable/disable retry logic", + ) max_attempts: int = Field( default=1, description="Maximum number of retry attempts for failed requests", @@ -124,6 +128,11 @@ class HttpErrorDefaultTamplete(BaseModel): description="Default HTTP headers returned on error", ) + output: str = Field( + default="SUCCESS", + description="HTTP response body", + ) + class HttpErrorHandleConfig(BaseModel): method: HttpErrorHandle = Field( @@ -131,8 +140,8 @@ class HttpErrorHandleConfig(BaseModel): description="Error handling strategy: 'none', 'default', or 'branch'", ) - default: HttpErrorDefaultTamplete = Field( - ..., + default: HttpErrorDefaultTamplete | None = Field( + default=None, description="Default response template for error handling", ) diff --git a/api/app/core/workflow/nodes/http_request/node.py b/api/app/core/workflow/nodes/http_request/node.py index d23ccd03..55919998 100644 --- a/api/app/core/workflow/nodes/http_request/node.py +++ b/api/app/core/workflow/nodes/http_request/node.py @@ -165,24 +165,6 @@ class HttpRequestNode(BaseNode): case _: raise RuntimeError(f"HttpRequest method not supported: {self.typed_config.method}") - def build_conditional_edge_expressions(self): - """ - Build conditional edge expressions for workflow branching. - - When the HTTP error handling strategy is set to `BRANCH`, - this node exposes a single conditional output labeled "ERROR". - The workflow engine uses this output to create an explicit - error-handling branch for downstream nodes. - - Returns: - list[str]: - - ["ERROR"] if error handling strategy is BRANCH - - An empty list if no conditional branching is required - """ - if self.typed_config.error_handle.method == HttpErrorHandle.BRANCH: - return ["ERROR"] - return [] - async def execute(self, state: WorkflowState) -> dict | str: """ Execute the HTTP request node. From 4685fd14adcac42d8cf990cfc779feb6353b66b0 Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Mon, 5 Jan 2026 11:06:21 +0800 Subject: [PATCH 09/10] refactor(workflow): refactor graph construction to support subgraph building --- api/app/core/workflow/executor.py | 162 +---------- api/app/core/workflow/graph_builder.py | 253 ++++++++++++++++++ .../core/workflow/nodes/cycle_graph/config.py | 22 +- .../core/workflow/nodes/cycle_graph/node.py | 164 +++--------- api/app/core/workflow/nodes/node_factory.py | 1 + 5 files changed, 321 insertions(+), 281 deletions(-) create mode 100644 api/app/core/workflow/graph_builder.py diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index 0d0879d7..7274764a 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -10,11 +10,10 @@ import logging from typing import Any from langchain_core.messages import HumanMessage -from langgraph.graph import StateGraph, START, END from langgraph.graph.state import CompiledStateGraph -from app.core.workflow.expression_evaluator import evaluate_condition -from app.core.workflow.nodes import WorkflowState, NodeFactory +from app.core.workflow.graph_builder import GraphBuilder +from app.core.workflow.nodes import WorkflowState from app.core.workflow.nodes.enums import NodeType # from app.core.tools.registry import ToolRegistry @@ -191,159 +190,10 @@ class WorkflowExecutor: 编译后的状态图 """ logger.info(f"开始构建工作流图: execution_id={self.execution_id}") - - # 分析 End 节点的前缀配置和相邻且被引用的节点 - end_prefixes, adjacent_and_referenced = self._analyze_end_node_prefixes() if stream else ({}, set()) - - # 1. 创建状态图 - workflow = StateGraph(WorkflowState) - - # 2. 添加所有节点(包括 start 和 end) - start_node_id = None - end_node_ids = [] - - for node in self.nodes: - node_type = node.get("type") - node_id = node.get("id") - cycle_node = node.get("cycle") - if cycle_node: - # 处于循环子图中的节点由 CycleGraphNode 进行构建处理 - continue - - # 记录 start 和 end 节点 ID - if node_type == NodeType.START: - start_node_id = node_id - elif node_type == NodeType.END: - end_node_ids.append(node_id) - - # 创建节点实例(现在 start 和 end 也会被创建) - node_instance = NodeFactory.create_node(node, self.workflow_config) - - if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST]: - expressions = node_instance.build_conditional_edge_expressions() - - # Number of branches, usually matches the number of conditional expressions - branch_number = len(expressions) - - # Find all edges whose source is the current node - related_edge = [edge for edge in self.edges if edge.get("source") == node_id] - - # Iterate over each branch - for idx in range(branch_number): - # Generate a condition expression for each edge - # Used later to determine which branch to take based on the node's output - # Assumes node output `node..output` matches the edge's label - # For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1' - related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'" - - if node_instance: - # 如果是流式模式,且节点有 End 前缀配置,注入配置 - if stream and node_id in end_prefixes: - # 将 End 前缀配置注入到节点实例 - node_instance._end_node_prefix = end_prefixes[node_id] - logger.info(f"为节点 {node_id} 注入 End 前缀配置") - - # 如果是流式模式,标记节点是否与 End 相邻且被引用 - if stream: - node_instance._is_adjacent_to_end = node_id in adjacent_and_referenced - if node_id in adjacent_and_referenced: - logger.info(f"节点 {node_id} 标记为与 End 相邻且被引用") - - # 包装节点的 run 方法 - # 使用函数工厂避免闭包问题 - if stream: - # 流式模式:创建 async generator 函数 - # LangGraph 会收集所有 yield 的值,最后一个 yield 的字典会被合并到 state - def make_stream_func(inst): - async def node_func(state: WorkflowState): - # logger.debug(f"流式执行节点: {inst.node_id}, 支持流式: {inst.supports_streaming()}") - async for item in inst.run_stream(state): - yield item - - return node_func - - workflow.add_node(node_id, make_stream_func(node_instance)) - else: - # 非流式模式:创建 async function - def make_func(inst): - async def node_func(state: WorkflowState): - return await inst.run(state) - - return node_func - - workflow.add_node(node_id, make_func(node_instance)) - - logger.debug(f"添加节点: {node_id} (type={node_type}, stream={stream})") - - # 3. 添加边 - # 从 START 连接到 start 节点 - if start_node_id: - workflow.add_edge(START, start_node_id) - logger.debug(f"添加边: START -> {start_node_id}") - - for edge in self.workflow_config.get("edges", []): - source = edge.get("source") - target = edge.get("target") - edge_type = edge.get("type") - condition = edge.get("condition") - - # 跳过从 start 节点出发的边(因为已经从 START 连接到 start) - if source == start_node_id: - # 但要连接 start 到下一个节点 - workflow.add_edge(source, target) - logger.debug(f"添加边: {source} -> {target}") - continue - - # # 处理到 end 节点的边 - # if target in end_node_ids: - # # 连接到 end 节点 - # workflow.add_edge(source, target) - # logger.debug(f"添加边: {source} -> {target}") - # continue - - # 跳过错误边(在节点内部处理) - if edge_type == "error": - continue - - if condition: - # 条件边 - def make_router(cond, tgt): - """Dynamically generate a conditional router function to ensure each branch has a unique name.""" - - - def router_fn(state: WorkflowState): - if evaluate_condition( - cond, - state.get("variables", {}), - state.get("node_outputs", {}), - { - "execution_id": state.get("execution_id"), - "workspace_id": state.get("workspace_id"), - "user_id": state.get("user_id") - } - ): - return tgt - return END - - # 动态修改函数名,避免重复 - router_fn.__name__ = f"router_{tgt}" - return router_fn - - router_fn = make_router(condition, target) - workflow.add_conditional_edges(source, router_fn) - logger.debug(f"添加条件边: {source} -> {target} (condition={condition})") - else: - # 普通边 - workflow.add_edge(source, target) - logger.debug(f"添加边: {source} -> {target}") - - # 从 end 节点连接到 END - for end_node_id in end_node_ids: - workflow.add_edge(end_node_id, END) - logger.debug(f"添加边: {end_node_id} -> END") - - # 4. 编译图 - graph = workflow.compile() + graph = GraphBuilder( + self.workflow_config, + stream=stream, + ).build() logger.info(f"工作流图构建完成: execution_id={self.execution_id}") return graph diff --git a/api/app/core/workflow/graph_builder.py b/api/app/core/workflow/graph_builder.py new file mode 100644 index 00000000..9e80db33 --- /dev/null +++ b/api/app/core/workflow/graph_builder.py @@ -0,0 +1,253 @@ +import logging +import uuid +from typing import Any + +from langgraph.graph.state import CompiledStateGraph, StateGraph +from langgraph.graph import START, END + +from app.core.workflow.expression_evaluator import evaluate_condition +from app.core.workflow.nodes import WorkflowState, NodeFactory +from app.core.workflow.nodes.enums import NodeType + +logger = logging.getLogger(__name__) + + +# TODO: 子图拆解支持 +class GraphBuilder: + def __init__( + self, + workflow_config: dict[str, Any], + stream: bool = False, + subgraph: bool = False, + ): + self.workflow_config = workflow_config + + self.stream = stream + self.subgraph = subgraph + + self.start_node_id = None + self.end_node_ids = [] + + self.graph: StateGraph | CompiledStateGraph | None = None + + @property + def nodes(self) -> list[dict[str, Any]]: + return self.workflow_config.get("nodes", []) + + @property + def edges(self) -> list[dict[str, Any]]: + return self.workflow_config.get("edges", []) + + def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]: + """分析 End 节点的前缀配置 + + 检查每个 End 节点的模板,找到直接上游节点的引用, + 提取该引用之前的前缀部分。 + + Returns: + 元组:({上游节点ID: End节点前缀}, {与End相邻且被引用的节点ID集合}) + """ + import re + + prefixes = {} + adjacent_and_referenced = set() # 记录与 End 节点相邻且被引用的节点 + + # 找到所有 End 节点 + end_nodes = [node for node in self.nodes if node.get("type") == "end"] + logger.info(f"[前缀分析] 找到 {len(end_nodes)} 个 End 节点") + + for end_node in end_nodes: + end_node_id = end_node.get("id") + output_template = end_node.get("config", {}).get("output") + + logger.info(f"[前缀分析] End 节点 {end_node_id} 模板: {output_template}") + + if not output_template: + continue + + # 查找模板中引用了哪些节点 + # 匹配 {{node_id.xxx}} 或 {{ node_id.xxx }} 格式(支持空格) + pattern = r'\{\{\s*([a-zA-Z0-9_-]+)\.[a-zA-Z0-9_]+\s*\}\}' + matches = list(re.finditer(pattern, output_template)) + + logger.info(f"[前缀分析] 模板中找到 {len(matches)} 个节点引用") + + # 找到所有直接连接到 End 节点的上游节点 + direct_upstream_nodes = [] + for edge in self.edges: + if edge.get("target") == end_node_id: + source_node_id = edge.get("source") + direct_upstream_nodes.append(source_node_id) + + logger.info(f"[前缀分析] End 节点的直接上游节点: {direct_upstream_nodes}") + + # 找到第一个直接上游节点的引用 + for match in matches: + referenced_node_id = match.group(1) + logger.info(f"[前缀分析] 检查引用: {referenced_node_id}") + + if referenced_node_id in direct_upstream_nodes: + # 这是直接上游节点的引用,提取前缀 + prefix = output_template[:match.start()] + + logger.info(f"[前缀分析] ✅ 找到直接上游节点 {referenced_node_id} 的引用,前缀: '{prefix}'") + + # 标记这个节点为"相邻且被引用" + adjacent_and_referenced.add(referenced_node_id) + + if prefix: + prefixes[referenced_node_id] = prefix + logger.info(f"✅ [前缀分析] 为节点 {referenced_node_id} 配置前缀: '{prefix[:50]}...'") + + # 只处理第一个直接上游节点的引用 + break + + logger.info(f"[前缀分析] 最终配置: {prefixes}") + logger.info(f"[前缀分析] 与 End 相邻且被引用的节点: {adjacent_and_referenced}") + return prefixes, adjacent_and_referenced + + def add_nodes(self): + end_prefixes, adjacent_and_referenced = self._analyze_end_node_prefixes() if self.stream else ({}, set()) + + for node in self.nodes: + node_type = node.get("type") + node_id = node.get("id") + cycle_node = node.get("cycle") + if cycle_node: + # 处于循环子图中的节点由 CycleGraphNode 进行构建处理 + if not self.subgraph: + continue + + # 记录 start 和 end 节点 ID + if node_type in [NodeType.START, NodeType.CYCLE_START]: + self.start_node_id = node_id + elif node_type == NodeType.END: + self.end_node_ids.append(node_id) + + # 创建节点实例(现在 start 和 end 也会被创建) + # NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph + node_instance = NodeFactory.create_node(node, self.workflow_config) + + if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST]: + + # Find all edges whose source is the current node + related_edge = [edge for edge in self.edges if edge.get("source") == node_id] + + # Iterate over each branch + for idx in range(len(related_edge)): + # Generate a condition expression for each edge + # Used later to determine which branch to take based on the node's output + # Assumes node output `node..output` matches the edge's label + # For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1' + related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'" + + if node_instance: + # 如果是流式模式,且节点有 End 前缀配置,注入配置 + if self.stream and node_id in end_prefixes: + # 将 End 前缀配置注入到节点实例 + node_instance._end_node_prefix = end_prefixes[node_id] + logger.info(f"为节点 {node_id} 注入 End 前缀配置") + + # 如果是流式模式,标记节点是否与 End 相邻且被引用 + if self.stream: + node_instance._is_adjacent_to_end = node_id in adjacent_and_referenced + if node_id in adjacent_and_referenced: + logger.info(f"节点 {node_id} 标记为与 End 相邻且被引用") + + # 包装节点的 run 方法 + # 使用函数工厂避免闭包问题 + if self.stream: + # 流式模式:创建 async generator 函数 + # LangGraph 会收集所有 yield 的值,最后一个 yield 的字典会被合并到 state + def make_stream_func(inst): + async def node_func(state: WorkflowState): + # logger.debug(f"流式执行节点: {inst.node_id}, 支持流式: {inst.supports_streaming()}") + async for item in inst.run_stream(state): + yield item + + return node_func + + self.graph.add_node(node_id, make_stream_func(node_instance)) + else: + # 非流式模式:创建 async function + def make_func(inst): + async def node_func(state: WorkflowState): + return await inst.run(state) + + return node_func + + self.graph.add_node(node_id, make_func(node_instance)) + + logger.debug(f"添加节点: {node_id} (type={node_type}, stream={self.stream})") + + def add_edges(self): + if self.start_node_id: + self.graph.add_edge(START, self.start_node_id) + logger.debug(f"添加边: START -> {self.start_node_id}") + + for edge in self.edges: + source = edge.get("source") + target = edge.get("target") + edge_type = edge.get("type") + condition = edge.get("condition") + + # 跳过从 start 节点出发的边(因为已经从 START 连接到 start) + if source == self.start_node_id: + # 但要连接 start 到下一个节点 + self.graph.add_edge(source, target) + logger.debug(f"添加边: {source} -> {target}") + continue + + # # 处理到 end 节点的边 + # if target in end_node_ids: + # # 连接到 end 节点 + # workflow.add_edge(source, target) + # logger.debug(f"添加边: {source} -> {target}") + # continue + + # 跳过错误边(在节点内部处理) + if edge_type == "error": + continue + + if condition: + # 条件边 + def make_router(cond, tgt): + """Dynamically generate a conditional router function to ensure each branch has a unique name.""" + + def router_fn(state: WorkflowState): + if evaluate_condition( + cond, + state.get("variables", {}), + state.get("runtime_vars", {}), + { + "execution_id": state.get("execution_id"), + "workspace_id": state.get("workspace_id"), + "user_id": state.get("user_id") + } + ): + return tgt + return END + + # 动态修改函数名,避免重复 + router_fn.__name__ = f"router_{uuid.uuid4().hex[:8]}_{tgt}" + return router_fn + + router_fn = make_router(condition, target) + self.graph.add_conditional_edges(source, router_fn) + logger.debug(f"添加条件边: {source} -> {target} (condition={condition})") + else: + # 普通边 + self.graph.add_edge(source, target) + logger.debug(f"添加边: {source} -> {target}") + + # 从 end 节点连接到 END + for end_node_id in self.end_node_ids: + self.graph.add_edge(end_node_id, END) + logger.debug(f"添加边: {end_node_id} -> END") + return + + def build(self) -> CompiledStateGraph: + self.graph = StateGraph(WorkflowState) + self.add_nodes() + self.add_edges() # 添加边必须在添加节点之后 + return self.graph.compile() diff --git a/api/app/core/workflow/nodes/cycle_graph/config.py b/api/app/core/workflow/nodes/cycle_graph/config.py index b1b613a4..fcf65717 100644 --- a/api/app/core/workflow/nodes/cycle_graph/config.py +++ b/api/app/core/workflow/nodes/cycle_graph/config.py @@ -1,7 +1,9 @@ +from typing import Any + from pydantic import Field, BaseModel from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableType -from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator +from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType class CycleVariable(BaseNodeConfig): @@ -9,18 +11,25 @@ class CycleVariable(BaseNodeConfig): ..., description="Name of the loop variable" ) + type: VariableType = Field( ..., description="Data type of the loop variable" ) - value: str = Field( + + input_type: ValueInputType = Field( + ..., + description="Input type of the loop variable" + ) + + value: Any = Field( ..., description="Initial or current value of the loop variable" ) class ConditionDetail(BaseModel): - comparison_operator: ComparisonOperator = Field( + operator: ComparisonOperator = Field( ..., description="Operator used to compare the left and right operands" ) @@ -30,11 +39,16 @@ class ConditionDetail(BaseModel): description="Left-hand operand of the comparison expression" ) - right: str = Field( + right: Any = Field( ..., description="Right-hand operand of the comparison expression" ) + input_type: ValueInputType = Field( + ..., + description="Input type of the loop variable" + ) + class ConditionsConfig(BaseModel): """Configuration for loop condition evaluation""" diff --git a/api/app/core/workflow/nodes/cycle_graph/node.py b/api/app/core/workflow/nodes/cycle_graph/node.py index 2428ef46..fb062f39 100644 --- a/api/app/core/workflow/nodes/cycle_graph/node.py +++ b/api/app/core/workflow/nodes/cycle_graph/node.py @@ -1,10 +1,9 @@ import logging from typing import Any -from langgraph.graph import StateGraph, START, END +from langgraph.graph import StateGraph from langgraph.graph.state import CompiledStateGraph -from app.core.workflow.expression_evaluator import evaluate_condition from app.core.workflow.nodes import WorkflowState from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig @@ -17,12 +16,18 @@ logger = logging.getLogger(__name__) class CycleGraphNode(BaseNode): """ - Node representing a cycle (loop) subgraph within the workflow. + Node representing a cyclic (loop or iteration) subgraph within the workflow. - This node manages internal loop/iteration nodes, builds a subgraph - for execution, handles conditional routing, and executes loop - or iteration logic based on node type. + A CycleGraphNode is a structural node that: + - Extracts a group of nodes marked as belonging to the same cycle + - Builds an isolated internal StateGraph (subgraph) + - Delegates runtime execution to LoopRuntime or IterationRuntime + depending on the node type + + This node itself does NOT execute business logic directly. + It acts as a container and execution controller for a subgraph. """ + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) self.typed_config: LoopNodeConfig | IterationNodeConfig | None = None @@ -38,16 +43,23 @@ class CycleGraphNode(BaseNode): def pure_cycle_graph(self) -> tuple[list, list]: """ - Extract cycle nodes and internal edges from the workflow configuration, - removing them from the global workflow. + Extract cycle-scoped nodes and internal edges from the workflow configuration. - Raises: - ValueError: If cycle nodes are connected to external nodes improperly. + This method: + - Identifies all nodes marked with `cycle == self.node_id` + - Collects edges that fully connect cycle nodes + - Removes extracted nodes and edges from the global workflow configuration + + Safety check: + - Raises an error if a cycle node is connected to an external node Returns: - Tuple containing: - - cycle_nodes: List of removed nodes - - cycle_edges: List of removed edges + tuple[list, list]: + - cycle_nodes: Nodes belonging to this cycle + - cycle_edges: Edges connecting nodes within the cycle + + Raises: + ValueError: If a cycle node is improperly connected to an external node. """ nodes = self.workflow_config.get("nodes", []) edges = self.workflow_config.get("edges", []) @@ -83,131 +95,41 @@ class CycleGraphNode(BaseNode): return cycle_nodes, cycle_edges - def create_node(self): - """ - Instantiate node objects for each node in the cycle subgraph and add them to the graph. - - Special handling is applied for conditional nodes to generate - edge conditions based on node outputs. - """ - from app.core.workflow.nodes import NodeFactory - for node in self.cycle_nodes: - node_type = node.get("type") - node_id = node.get("id") - - if node_type == NodeType.CYCLE_START: - self.start_node_id = node_id - continue - elif node_type == NodeType.END: - self.end_node_ids.append(node_id) - - node_instance = NodeFactory.create_node(node, self.workflow_config) - - if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST]: - expressions = node_instance.build_conditional_edge_expressions() - - # Number of branches, usually matches the number of conditional expressions - branch_number = len(expressions) - - # Find all edges whose source is the current node - related_edge = [edge for edge in self.cycle_edges if edge.get("source") == node_id] - - # Iterate over each branch - for idx in range(branch_number): - # Generate a condition expression for each edge - # Used later to determine which branch to take based on the node's output - # Assumes node output `node..output` matches the edge's label - # For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1' - related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'" - - def make_func(inst): - async def node_func(state: WorkflowState): - return await inst.run(state) - - return node_func - - self.graph.add_node(node_id, make_func(node_instance)) - - def create_edge(self): - """ - Connect nodes within the cycle subgraph by adding edges to the internal graph. - - Conditional edges are routed based on evaluated expressions. - Start and end nodes are connected to global START and END nodes. - """ - for edge in self.cycle_edges: - source = edge.get("source") - target = edge.get("target") - edge_type = edge.get("type") - condition = edge.get("condition") - - # 跳过从 start 节点出发的边(因为已经从 START 连接到 start) - if source == self.start_node_id: - # 但要连接 start 到下一个节点 - self.graph.add_edge(START, target) - logger.debug(f"添加边: {source} -> {target}") - continue - - if condition: - # 条件边 - def router(state: WorkflowState, cond=condition, tgt=target): - """条件路由函数""" - if evaluate_condition( - cond, - state.get("variables", {}), - state.get("node_outputs", {}), - { - "execution_id": state.get("execution_id"), - "workspace_id": state.get("workspace_id"), - "user_id": state.get("user_id") - } - ): - return tgt - return END # 条件不满足,结束 - - self.graph.add_conditional_edges(source, router) - logger.debug(f"添加条件边: {source} -> {target} (condition={condition})") - else: - # 普通边 - self.graph.add_edge(source, target) - logger.debug(f"添加边: {source} -> {target}") - - # 从 end 节点连接到 END - for end_node_id in self.end_node_ids: - self.graph.add_edge(end_node_id, END) - logger.debug(f"添加边: {end_node_id} -> END") - def build_graph(self): """ - Build the internal subgraph for the cycle node. + Build and compile the internal subgraph for this cycle node. Steps: - 1. Extract cycle nodes and edges. - 2. Create node instances and add them to the graph. - 3. Connect edges and conditional routes. - 4. Compile the graph for execution. + 1. Extract cycle nodes and internal edges from the workflow + 2. Construct a StateGraph using GraphBuilder in subgraph mode + 3. Compile the graph for runtime execution """ - self.graph = StateGraph(WorkflowState) + from app.core.workflow.graph_builder import GraphBuilder self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph() - self.create_node() - self.create_edge() - self.graph = self.graph.compile() + self.graph = GraphBuilder( + { + "nodes": self.cycle_nodes, + "edges": self.cycle_edges, + }, + subgraph=True + ).build() async def execute(self, state: WorkflowState) -> Any: """ Execute the cycle node at runtime. - Depending on the node type, runs either a loop (LoopRuntime) - or an iteration (IterationRuntime) over the internal subgraph. + Based on the node type: + - LOOP: Executes LoopRuntime, repeatedly invoking the subgraph + - ITERATION: Executes IterationRuntime, iterating over a collection Args: - state: Current workflow state. + state: The current workflow state when entering the cycle node. Returns: - Runtime result of the cycle, typically the final loop/iteration variables. + Any: The runtime result produced by the loop or iteration executor. Raises: - RuntimeError: If node type is unrecognized. + RuntimeError: If the node type is unsupported. """ if self.node_type == NodeType.LOOP: return await LoopRuntime( diff --git a/api/app/core/workflow/nodes/node_factory.py b/api/app/core/workflow/nodes/node_factory.py index ed26533d..f86a2b9b 100644 --- a/api/app/core/workflow/nodes/node_factory.py +++ b/api/app/core/workflow/nodes/node_factory.py @@ -72,6 +72,7 @@ class NodeFactory: NodeType.LOOP: CycleGraphNode, NodeType.ITERATION: CycleGraphNode, NodeType.BREAK: BreakNode, + NodeType.CYCLE_START: StartNode, } @classmethod From fc4cf418e095e9a6c0b1b8ccdbe149c8ae847ec7 Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Mon, 5 Jan 2026 11:29:27 +0800 Subject: [PATCH 10/10] feat(workflow): add support for question classifier in graph construction --- api/app/core/workflow/graph_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/app/core/workflow/graph_builder.py b/api/app/core/workflow/graph_builder.py index 9e80db33..b24d5202 100644 --- a/api/app/core/workflow/graph_builder.py +++ b/api/app/core/workflow/graph_builder.py @@ -128,7 +128,7 @@ class GraphBuilder: # NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph node_instance = NodeFactory.create_node(node, self.workflow_config) - if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST]: + if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER]: # Find all edges whose source is the current node related_edge = [edge for edge in self.edges if edge.get("source") == node_id]