diff --git a/api/app/core/models/rerank.py b/api/app/core/models/rerank.py index 64b3b566..c4b91e25 100644 --- a/api/app/core/models/rerank.py +++ b/api/app/core/models/rerank.py @@ -1,4 +1,3 @@ - from typing import Any, Dict, List, Optional, Sequence, Type, Union from copy import deepcopy from urllib.parse import urlparse @@ -8,8 +7,10 @@ from langchain_core.callbacks import Callbacks from app.core.models.base import RedBearModelConfig, get_provider_rerank_class, RedBearModelFactory from app.models import ModelProvider + class RedBearRerank(BaseDocumentCompressor): """ Rerank → 作为 Runnable 插入任意 LCEL 链""" + def __init__(self, config: RedBearModelConfig): self._model = self._create_model(config) self._config = config @@ -22,10 +23,10 @@ class RedBearRerank(BaseDocumentCompressor): return model_class(**model_params) def compress_documents( - self, - documents: Sequence[Document], - query: str, - callbacks: Optional[Callbacks] = None, + self, + documents: Sequence[Document], + query: str, + callbacks: Optional[Callbacks] = None, ) -> Sequence[Document]: """ Compress documents using Jina's Rerank API. @@ -46,17 +47,17 @@ class RedBearRerank(BaseDocumentCompressor): compressed.append(doc_copy) return compressed - def rerank( - self, - documents: Sequence[Union[str, Document, dict]], - query: str, - *, - top_n: Optional[int] = -1, - ) -> List[Dict[str, Any]]: - provider = self._config.provider.lower() - if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] : + self, + documents: Sequence[Union[str, Document, dict]], + query: str, + *, + top_n: Optional[int] = -1, + ) -> List[Dict[str, Any]]: + provider = self._config.provider.lower() + if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]: import langchain_community.document_compressors.jina_rerank as jina_mod + # 规范化:如果不以 /v1/rerank 结尾,则补齐;若已以 /v1 结尾,则补 /rerank def _normalize_jina_base(base_url: Optional[str]) -> Optional[str]: if not base_url: @@ -73,8 +74,7 @@ class RedBearRerank(BaseDocumentCompressor): # 设置完整的 rerank 端点,例如 http://host:port/v1/rerank jina_mod.JINA_API_URL = jina_base from langchain_community.document_compressors import JinaRerank - model_instance : JinaRerank = self._model - return model_instance.rerank(documents = documents, query = query, top_n=top_n) + model_instance: JinaRerank = self._model + return model_instance.rerank(documents=documents, query=query, top_n=top_n) else: raise ValueError(f"不支持的模型提供商: {provider}") - \ No newline at end of file diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index de10f6f6..1cbfb66b 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -4,9 +4,9 @@ 基于 LangGraph 的工作流执行引擎。 """ -import logging # import uuid import datetime +import logging from typing import Any from langchain_core.messages import HumanMessage @@ -107,7 +107,13 @@ class WorkflowExecutor: "user_id": self.user_id, "error": None, "error_node": None, - "streaming_buffer": {} # 流式缓冲区 + "streaming_buffer": {}, # 流式缓冲区 + "cycle_nodes": [ + node.get("id") + for node in self.workflow_config.get("nodes") + if node.get("type") in [NodeType.LOOP, NodeType.ITERATION] + ], # loop, iteration node id + "looping": False # loop runing flag, only use in loop node,not use in main loop } def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]: @@ -199,6 +205,10 @@ class WorkflowExecutor: for node in self.nodes: node_type = node.get("type") node_id = node.get("id") + cycle_node = node.get("cycle") + if cycle_node: + # 处于循环子图中的节点由 CycleGraphNode 进行构建处理 + continue # 记录 start 和 end 节点 ID if node_type == NodeType.START: @@ -271,7 +281,7 @@ class WorkflowExecutor: workflow.add_edge(START, start_node_id) logger.debug(f"添加边: START -> {start_node_id}") - for edge in self.edges: + for edge in self.workflow_config.get("edges", []): source = edge.get("source") target = edge.get("target") edge_type = edge.get("type") @@ -284,12 +294,12 @@ class WorkflowExecutor: logger.debug(f"添加边: {source} -> {target}") continue - # 处理到 end 节点的边 - if target in end_node_ids: - # 连接到 end 节点 - workflow.add_edge(source, target) - logger.debug(f"添加边: {source} -> {target}") - continue + # # 处理到 end 节点的边 + # if target in end_node_ids: + # # 连接到 end 节点 + # workflow.add_edge(source, target) + # logger.debug(f"添加边: {source} -> {target}") + # continue # 跳过错误边(在节点内部处理) if edge_type == "error": diff --git a/api/app/core/workflow/expression_evaluator.py b/api/app/core/workflow/expression_evaluator.py index 81ab25dc..1a8b101e 100644 --- a/api/app/core/workflow/expression_evaluator.py +++ b/api/app/core/workflow/expression_evaluator.py @@ -74,6 +74,7 @@ class ExpressionEvaluator: # 为了向后兼容,也支持直接访问(但会在日志中警告) context.update(variables) context["nodes"] = node_outputs + context.update(node_outputs) try: # simpleeval 只支持安全的操作: diff --git a/api/app/core/workflow/nodes/assigner/node.py b/api/app/core/workflow/nodes/assigner/node.py index c174f52a..01ffc992 100644 --- a/api/app/core/workflow/nodes/assigner/node.py +++ b/api/app/core/workflow/nodes/assigner/node.py @@ -6,7 +6,7 @@ from app.core.workflow.expression_evaluator import ExpressionEvaluator from app.core.workflow.nodes.assigner.config import AssignerNodeConfig from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.enums import AssignmentOperator -from app.core.workflow.nodes.operators import AssignmentOperatorInstance +from app.core.workflow.nodes.operators import AssignmentOperatorInstance, AssignmentOperatorResolver from app.core.workflow.variable_pool import VariablePool logger = logging.getLogger(__name__) @@ -40,8 +40,8 @@ class AssignerNode(BaseNode): variable_selector = expression.split('.') # Only conversation variables ('conv') are allowed - if variable_selector[0] != 'conv': # TODO: Loop node variable support (Feature) - raise ValueError("Only conversation variables can be assigned.") + if variable_selector[0] != 'conv' and variable_selector[0] not in state["cycle_nodes"]: + raise ValueError("Only conversation or cycle variables can be assigned.") # Get the value or expression to assign value = assignment.value @@ -55,7 +55,9 @@ class AssignerNode(BaseNode): ) # Select the appropriate assignment operator instance based on the target variable type - operator: AssignmentOperatorInstance = AssignmentOperator.get_operator(pool.get(variable_selector))( + operator: AssignmentOperatorInstance = AssignmentOperatorResolver.resolve_by_value( + pool.get(variable_selector) + )( pool, variable_selector, value ) diff --git a/api/app/core/workflow/nodes/base_config.py b/api/app/core/workflow/nodes/base_config.py index 90d02732..1550584a 100644 --- a/api/app/core/workflow/nodes/base_config.py +++ b/api/app/core/workflow/nodes/base_config.py @@ -14,9 +14,13 @@ class VariableType(StrEnum): STRING = "string" NUMBER = "number" BOOLEAN = "boolean" - ARRAY = "array" OBJECT = "object" + ARRAY_STRING = "array[string]" + ARRAY_NUMBER = "array[number]" + ARRAY_BOOLEAN = "array[boolean]" + ARRAY_OBJECT = "array[object]" + class VariableDefinition(BaseModel): """变量定义 diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index 82f3d9b8..146541cc 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -20,40 +20,44 @@ logger = logging.getLogger(__name__) class WorkflowState(TypedDict): - """工作流状态 - - 在节点间传递的状态对象,包含消息、变量、节点输出等信息。 + """Workflow state + + The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc. """ - # 消息列表(追加模式) + # List of messages (append mode) messages: Annotated[list[AnyMessage], add] - - # 输入变量(从配置的 variables 传入) - # 使用深度合并函数,支持嵌套字典的更新(如 conv.xxx) + + # Set of loop node IDs, used for assigning values in loop nodes + cycle_nodes: list + looping: bool + + # Input variables (passed from configured variables) + # Uses a deep merge function, supporting nested dict updates (e.g., conv.xxx) variables: Annotated[dict[str, Any], lambda x, y: { **x, **{k: {**x.get(k, {}), **v} if isinstance(v, dict) and isinstance(x.get(k), dict) else v for k, v in y.items()} }] - - # 节点输出(存储每个节点的执行结果,用于变量引用) - # 使用自定义合并函数,将新的节点输出合并到现有字典中 + + # Node outputs (stores execution results of each node for variable references) + # Uses a custom merge function to combine new node outputs into the existing dictionary node_outputs: Annotated[dict[str, Any], lambda x, y: {**x, **y}] - - # 运行时节点变量(简化版,只存储业务数据,供节点间快速访问) - # 格式:{node_id: business_result} + + # Runtime node variables (simplified version, stores business data for fast access between nodes) + # Format: {node_id: business_result} runtime_vars: Annotated[dict[str, Any], lambda x, y: {**x, **y}] - # 执行上下文 + # Execution context execution_id: str workspace_id: str user_id: str - # 错误信息(用于错误边) + # Error information (for error edges) error: str | None error_node: str | None - - # 流式缓冲区(存储节点的实时流式输出) - # 格式:{node_id: {"chunks": [...], "full_content": "..."}} + + # Streaming buffer (stores real-time streaming output of nodes) + # Format: {node_id: {"chunks": [...], "full_content": "..."}} streaming_buffer: Annotated[dict[str, Any], lambda x, y: {**x, **y}] @@ -170,10 +174,10 @@ class BaseNode(ABC): import time start_time = time.time() + + timeout = self.get_timeout() try: - timeout = self.get_timeout() - # 调用节点的业务逻辑 business_result = await asyncio.wait_for( self.execute(state), @@ -200,7 +204,8 @@ class BaseNode(ABC): **wrapped_output, "runtime_vars": { self.node_id: runtime_var - } + }, + "looping": state["looping"] } except TimeoutError: @@ -236,10 +241,10 @@ class BaseNode(ABC): import time start_time = time.time() + + timeout = self.get_timeout() try: - timeout = self.get_timeout() - # Get LangGraph's stream writer for sending custom data writer = get_stream_writer() diff --git a/api/app/core/workflow/nodes/breaker/__init__.py b/api/app/core/workflow/nodes/breaker/__init__.py new file mode 100644 index 00000000..d028cc25 --- /dev/null +++ b/api/app/core/workflow/nodes/breaker/__init__.py @@ -0,0 +1,3 @@ +from app.core.workflow.nodes.breaker.node import BreakNode + +__all__ = ["BreakNode"] diff --git a/api/app/core/workflow/nodes/breaker/node.py b/api/app/core/workflow/nodes/breaker/node.py new file mode 100644 index 00000000..45568f76 --- /dev/null +++ b/api/app/core/workflow/nodes/breaker/node.py @@ -0,0 +1,32 @@ +import logging +from typing import Any + +from app.core.workflow.nodes import BaseNode, WorkflowState + +logger = logging.getLogger(__name__) + + +class BreakNode(BaseNode): + """ + Workflow node that immediately stops loop execution. + + When executed, this node sets the 'looping' flag in the workflow state + to False, signaling the outer loop runtime to terminate further iterations. + """ + + async def execute(self, state: WorkflowState) -> Any: + """ + Execute the break node. + + Args: + state: Current workflow state, including loop control flags. + + Effects: + - Sets 'looping' in the state to False to stop the loop. + - Logs the action for debugging purposes. + + Returns: + Optional dictionary indicating the loop has been stopped. + """ + state["looping"] = False + logger.info(f"run break node, looping={state['looping']}") diff --git a/api/app/core/workflow/nodes/configs.py b/api/app/core/workflow/nodes/configs.py index b1c64227..2ba23d4c 100644 --- a/api/app/core/workflow/nodes/configs.py +++ b/api/app/core/workflow/nodes/configs.py @@ -22,6 +22,7 @@ from app.core.workflow.nodes.variable_aggregator.config import VariableAggregato from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig +from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig __all__ = [ # 基础类 "BaseNodeConfig", @@ -41,5 +42,7 @@ __all__ = [ "JinjaRenderNodeConfig", "VariableAggregatorNodeConfig", "ParameterExtractorNodeConfig", + "LoopNodeConfig", + "IterationNodeConfig", "QuestionClassifierNodeConfig" ] diff --git a/api/app/core/workflow/nodes/cycle_graph/__init__.py b/api/app/core/workflow/nodes/cycle_graph/__init__.py new file mode 100644 index 00000000..dc2d72e0 --- /dev/null +++ b/api/app/core/workflow/nodes/cycle_graph/__init__.py @@ -0,0 +1,4 @@ +from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig +from app.core.workflow.nodes.cycle_graph.node import CycleGraphNode + +__all__ = ['CycleGraphNode', 'LoopNodeConfig', 'IterationNodeConfig'] diff --git a/api/app/core/workflow/nodes/cycle_graph/config.py b/api/app/core/workflow/nodes/cycle_graph/config.py new file mode 100644 index 00000000..b1b613a4 --- /dev/null +++ b/api/app/core/workflow/nodes/cycle_graph/config.py @@ -0,0 +1,96 @@ +from pydantic import Field, BaseModel + +from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableType +from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator + + +class CycleVariable(BaseNodeConfig): + name: str = Field( + ..., + description="Name of the loop variable" + ) + type: VariableType = Field( + ..., + description="Data type of the loop variable" + ) + value: str = Field( + ..., + description="Initial or current value of the loop variable" + ) + + +class ConditionDetail(BaseModel): + comparison_operator: ComparisonOperator = Field( + ..., + description="Operator used to compare the left and right operands" + ) + + left: str = Field( + ..., + description="Left-hand operand of the comparison expression" + ) + + right: str = Field( + ..., + description="Right-hand operand of the comparison expression" + ) + + +class ConditionsConfig(BaseModel): + """Configuration for loop condition evaluation""" + + logical_operator: LogicOperator = Field( + default=LogicOperator.AND.value, + description="Logical operator used to combine multiple condition expressions" + ) + + expressions: list[ConditionDetail] = Field( + ..., + description="Collection of condition expressions to be evaluated" + ) + + +class LoopNodeConfig(BaseNodeConfig): + condition: ConditionsConfig = Field( + default_factory=list, + description="Conditional configuration that controls loop execution" + ) + + cycle_vars: list[CycleVariable] = Field( + default_factory=list, + description="List of variables used and updated during the loop" + ) + + max_loop: int = Field( + default=10, + description="Maximum number of loop iterations" + ) + + +class IterationNodeConfig(BaseNodeConfig): + input: str = Field( + ..., + description="Input of the loop iteration" + ) + + parallel: bool = Field( + default=False, + description="Whether to execute loop iterations in parallel" + ) + + parallel_count: int = Field( + default=4, + description="Number of iterations to run in parallel" + ) + + flatten: bool = Field( + default=False, + description="Whether to flatten the output list from iterations" + ) + + output: str = Field( + ..., + description="Output of the loop iteration" + ) + + diff --git a/api/app/core/workflow/nodes/cycle_graph/iteration.py b/api/app/core/workflow/nodes/cycle_graph/iteration.py new file mode 100644 index 00000000..a1b93f24 --- /dev/null +++ b/api/app/core/workflow/nodes/cycle_graph/iteration.py @@ -0,0 +1,146 @@ +import asyncio +import copy +import re +from typing import Any + +from langgraph.graph.state import CompiledStateGraph + +from app.core.workflow.nodes import WorkflowState +from app.core.workflow.nodes.cycle_graph import IterationNodeConfig +from app.core.workflow.variable_pool import VariablePool + + +class IterationRuntime: + """ + Runtime executor for loop/iteration nodes in a workflow. + + This class handles executing iterations over a list variable, supporting + optional parallel execution, flattening of output, and loop control via + the workflow state. + """ + def __init__( + self, + graph: CompiledStateGraph, + node_id: str, + config: dict[str, Any], + state: WorkflowState, + ): + """ + Initialize the iteration runtime. + + Args: + graph: Compiled workflow graph capable of async invocation. + node_id: Unique identifier of the loop node. + config: Dictionary containing iteration node configuration. + state: Current workflow state at the point of iteration. + """ + self.graph = graph + self.state = state + self.node_id = node_id + self.typed_config = IterationNodeConfig(**config) + self.looping = True + + self.output_value = None + self.result: list = [] + + def _init_iteration_state(self, item, idx): + """ + Initialize a per-iteration copy of the workflow state. + + Args: + item: Current element from the input array for this iteration. + idx: Index of the element in the input array. + + Returns: + A deep copy of the workflow state with iteration-specific variables set. + """ + loopstate = WorkflowState( + **copy.deepcopy(self.state) + ) + loopstate["runtime_vars"][self.node_id] = { + "item": item, + "index": idx, + } + loopstate["node_outputs"][self.node_id] = { + "item": item, + "index": idx, + } + loopstate["looping"] = True + return loopstate + + async def run_task(self, item, idx): + """ + Execute a single iteration asynchronously. + + Args: + item: The input element for this iteration. + idx: The index of this iteration. + """ + result = await self.graph.ainvoke(self._init_iteration_state(item, idx)) + output = VariablePool(result).get(self.output_value) + if isinstance(output, list) and self.typed_config.flatten: + self.result.extend(output) + else: + self.result.append(output) + if not result["looping"]: + self.looping = False + + def _create_iteration_tasks(self, array_obj, idx): + """ + Create async tasks for a batch of iterations based on parallel count. + + Args: + array_obj: The input array to iterate over. + idx: Starting index for this batch of iterations. + + Returns: + List of coroutine tasks ready to be executed in parallel. + """ + tasks = [] + for i in range(self.typed_config.parallel_count): + if idx + i >= len(array_obj): + break + item = array_obj[idx + i] + tasks.append(self.run_task(item, idx + i)) + return tasks + + async def run(self): + """ + Execute the loop over the input array according to configuration. + + Returns: + A list of outputs from all iterations, optionally flattened. + + Raises: + RuntimeError: If the input variable is not a list. + """ + pattern = r"\{\{\s*(.*?)\s*\}\}" + input_expression = re.sub(pattern, r"\1", self.typed_config.input).strip() + self.output_value = re.sub(pattern, r"\1", self.typed_config.output).strip() + + array_obj = VariablePool(self.state).get(input_expression) + if not isinstance(array_obj, list): + raise RuntimeError("Cannot iterate over a non-list variable") + + idx = 0 + if self.typed_config.parallel: + # Execute iterations in parallel batches + while idx < len(array_obj) and self.looping: + tasks = self._create_iteration_tasks(array_obj, idx) + idx += self.typed_config.parallel_count + await asyncio.gather(*tasks) + return self.result + else: + # Execute iterations sequentially + while idx < len(array_obj) and self.looping: + item = array_obj[idx] + result = await self.graph.ainvoke(self._init_iteration_state(item, idx)) + output = VariablePool(result).get(self.output_value) + if isinstance(output, list) and self.typed_config.flatten: + self.result.extend(output) + else: + self.result.append(output) + if not result["looping"]: + self.looping = False + idx += 1 + return self.result diff --git a/api/app/core/workflow/nodes/cycle_graph/loop.py b/api/app/core/workflow/nodes/cycle_graph/loop.py new file mode 100644 index 00000000..247425c3 --- /dev/null +++ b/api/app/core/workflow/nodes/cycle_graph/loop.py @@ -0,0 +1,124 @@ +from typing import Any + +from langgraph.graph.state import CompiledStateGraph + +from app.core.workflow.expression_evaluator import evaluate_condition, evaluate_expression +from app.core.workflow.nodes import WorkflowState +from app.core.workflow.nodes.cycle_graph import LoopNodeConfig +from app.core.workflow.nodes.operators import ConditionExpressionBuilder +from app.core.workflow.variable_pool import VariablePool + + +class LoopRuntime: + """ + Runtime executor for loop nodes in a workflow. + + Handles iterative execution of a loop node according to defined loop variables + and conditional expressions. Supports maximum loop count and loop control + through the workflow state. + """ + + def __init__( + self, + graph: CompiledStateGraph, + node_id: str, + config: dict[str, Any], + state: WorkflowState, + ): + """ + Initialize the loop runtime. + + Args: + graph: Compiled workflow graph capable of async invocation. + node_id: Unique identifier of the loop node. + config: Dictionary containing loop node configuration. + state: Current workflow state at the point of loop execution. + """ + self.graph = graph + self.state = state + self.node_id = node_id + self.typed_config = LoopNodeConfig(**config) + + def _init_loop_state(self): + """ + Initialize workflow state for loop execution. + + - Evaluates initial values of loop variables. + - Stores loop variables in runtime_vars and node_outputs. + - Marks the loop as active by setting 'looping' to True. + + Returns: + A copy of the workflow state prepared for the loop execution. + """ + pool = VariablePool(self.state) + # 循环变量 + self.state["runtime_vars"][self.node_id] = { + variable.name: evaluate_expression( + expression=variable.value, + variables=pool.get_all_conversation_vars(), + node_outputs=pool.get_all_node_outputs(), + system_vars=pool.get_all_system_vars(), + ) + for variable in self.typed_config.cycle_vars + } + self.state["node_outputs"][self.node_id] = { + variable.name: evaluate_expression( + expression=variable.value, + variables=pool.get_all_conversation_vars(), + node_outputs=pool.get_all_node_outputs(), + system_vars=pool.get_all_system_vars(), + ) + for variable in self.typed_config.cycle_vars + } + loopstate = WorkflowState( + **self.state + ) + loopstate["looping"] = True + return loopstate + + def _get_loop_expression(self): + """ + Build the Python boolean expression for evaluating the loop condition. + + - Converts each condition in the loop configuration into a Python expression string. + - Combines multiple conditions with the configured logical operator (AND/OR). + + Returns: + A string representing the combined loop condition expression. + """ + branch_conditions = [ + ConditionExpressionBuilder( + left=condition.left, + operator=condition.comparison_operator, + right=condition.right + ).build() + for condition in self.typed_config.condition.expressions + ] + if len(branch_conditions) > 1: + combined_condition = f' {self.typed_config.condition.logical_operator} '.join(branch_conditions) + else: + combined_condition = branch_conditions[0] + + return combined_condition + + async def run(self): + """ + Execute the loop node until the condition is no longer met, the loop is + manually stopped, or the maximum loop count is reached. + + Returns: + The final runtime variables of this loop node after completion. + """ + loopstate = self._init_loop_state() + expression = self._get_loop_expression() + loop_variable_pool = VariablePool(loopstate) + loop_time = self.typed_config.max_loop + while evaluate_condition( + expression=expression, + variables=loop_variable_pool.get_all_conversation_vars(), + node_outputs=loop_variable_pool.get_all_node_outputs(), + system_vars=loop_variable_pool.get_all_system_vars(), + ) and loopstate["looping"] and loop_time > 0: + await self.graph.ainvoke(loopstate) + loop_time -= 1 + return loopstate["runtime_vars"][self.node_id] diff --git a/api/app/core/workflow/nodes/cycle_graph/node.py b/api/app/core/workflow/nodes/cycle_graph/node.py new file mode 100644 index 00000000..a8744035 --- /dev/null +++ b/api/app/core/workflow/nodes/cycle_graph/node.py @@ -0,0 +1,223 @@ +import logging +from typing import Any + +from langgraph.graph import StateGraph, START, END +from langgraph.graph.state import CompiledStateGraph + +from app.core.workflow.expression_evaluator import evaluate_condition +from app.core.workflow.nodes import WorkflowState +from app.core.workflow.nodes.base_node import BaseNode +from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig +from app.core.workflow.nodes.cycle_graph.iteration import IterationRuntime +from app.core.workflow.nodes.cycle_graph.loop import LoopRuntime +from app.core.workflow.nodes.enums import NodeType + +logger = logging.getLogger(__name__) + + +class CycleGraphNode(BaseNode): + """ + Node representing a cycle (loop) subgraph within the workflow. + + This node manages internal loop/iteration nodes, builds a subgraph + for execution, handles conditional routing, and executes loop + or iteration logic based on node type. + """ + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): + super().__init__(node_config, workflow_config) + self.typed_config: LoopNodeConfig | IterationNodeConfig | None = None + + self.cycle_nodes = list() # Nodes belonging to this cycle + self.cycle_edges = list() # Edges connecting nodes within the cycle + self.start_node_id = None # ID of the start node within the cycle + self.end_node_ids = [] # IDs of end nodes within the cycle + + self.graph: StateGraph | CompiledStateGraph | None = None + self.build_graph() + self.iteration_flag = True + + def pure_cycle_graph(self) -> tuple[list, list]: + """ + Extract cycle nodes and internal edges from the workflow configuration, + removing them from the global workflow. + + Raises: + ValueError: If cycle nodes are connected to external nodes improperly. + + Returns: + Tuple containing: + - cycle_nodes: List of removed nodes + - cycle_edges: List of removed edges + """ + nodes = self.workflow_config.get("nodes", []) + edges = self.workflow_config.get("edges", []) + + # Select all nodes that belong to the current cycle + cycle_nodes = [node for node in nodes if node.get("cycle") == self.node_id] + cycle_node_ids = {node.get("id") for node in cycle_nodes} + + cycle_edges = [] + remain_edges = [] + + for edge in edges: + source_in = edge.get("source") in cycle_node_ids + target_in = edge.get("target") in cycle_node_ids + + # Raise error if cycle nodes are connected with external nodes + if source_in ^ target_in: + raise ValueError(f"循环节点与外部节点存在连接,soruce: {edge.get("source")}, target:{edge.get("target")}") + + if source_in and target_in: + cycle_edges.append(edge) + else: + remain_edges.append(edge) + + # Update workflow_config by removing cycle nodes and internal edges + self.workflow_config["nodes"] = [ + node for node in nodes if node.get("cycle") != self.node_id + ] + self.workflow_config["edges"] = remain_edges + + return cycle_nodes, cycle_edges + + def create_node(self): + """ + Instantiate node objects for each node in the cycle subgraph and add them to the graph. + + Special handling is applied for conditional nodes to generate + edge conditions based on node outputs. + """ + from app.core.workflow.nodes import NodeFactory + for node in self.cycle_nodes: + node_type = node.get("type") + node_id = node.get("id") + + if node_type == NodeType.CYCLE_START: + self.start_node_id = node_id + continue + elif node_type == NodeType.END: + self.end_node_ids.append(node_id) + + node_instance = NodeFactory.create_node(node, self.workflow_config) + + if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST]: + expressions = node_instance.build_conditional_edge_expressions() + + # Number of branches, usually matches the number of conditional expressions + branch_number = len(expressions) + + # Find all edges whose source is the current node + related_edge = [edge for edge in self.cycle_edges if edge.get("source") == node_id] + + # Iterate over each branch + for idx in range(branch_number): + # Generate a condition expression for each edge + # Used later to determine which branch to take based on the node's output + # Assumes node output `node..output` matches the edge's label + # For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1' + related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'" + + def make_func(inst): + async def node_func(state: WorkflowState): + return await inst.run(state) + + return node_func + + self.graph.add_node(node_id, make_func(node_instance)) + + def create_edge(self): + """ + Connect nodes within the cycle subgraph by adding edges to the internal graph. + + Conditional edges are routed based on evaluated expressions. + Start and end nodes are connected to global START and END nodes. + """ + for edge in self.cycle_edges: + source = edge.get("source") + target = edge.get("target") + edge_type = edge.get("type") + condition = edge.get("condition") + + # 跳过从 start 节点出发的边(因为已经从 START 连接到 start) + if source == self.start_node_id: + # 但要连接 start 到下一个节点 + self.graph.add_edge(START, target) + logger.debug(f"添加边: {source} -> {target}") + continue + + if condition: + # 条件边 + def router(state: WorkflowState, cond=condition, tgt=target): + """条件路由函数""" + if evaluate_condition( + cond, + state.get("variables", {}), + state.get("node_outputs", {}), + { + "execution_id": state.get("execution_id"), + "workspace_id": state.get("workspace_id"), + "user_id": state.get("user_id") + } + ): + return tgt + return END # 条件不满足,结束 + + self.graph.add_conditional_edges(source, router) + logger.debug(f"添加条件边: {source} -> {target} (condition={condition})") + else: + # 普通边 + self.graph.add_edge(source, target) + logger.debug(f"添加边: {source} -> {target}") + + # 从 end 节点连接到 END + for end_node_id in self.end_node_ids: + self.graph.add_edge(end_node_id, END) + logger.debug(f"添加边: {end_node_id} -> END") + + def build_graph(self): + """ + Build the internal subgraph for the cycle node. + + Steps: + 1. Extract cycle nodes and edges. + 2. Create node instances and add them to the graph. + 3. Connect edges and conditional routes. + 4. Compile the graph for execution. + """ + self.graph = StateGraph(WorkflowState) + self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph() + self.create_node() + self.create_edge() + self.graph = self.graph.compile() + + async def execute(self, state: WorkflowState) -> Any: + """ + Execute the cycle node at runtime. + + Depending on the node type, runs either a loop (LoopRuntime) + or an iteration (IterationRuntime) over the internal subgraph. + + Args: + state: Current workflow state. + + Returns: + Runtime result of the cycle, typically the final loop/iteration variables. + + Raises: + RuntimeError: If node type is unrecognized. + """ + if self.node_type == NodeType.LOOP: + return await LoopRuntime( + graph=self.graph, + node_id=self.node_id, + config=self.config, + state=state, + ).run() + if self.node_type == NodeType.ITERATION: + return await IterationRuntime( + graph=self.graph, + node_id=self.node_id, + config=self.config, + state=state, + ).run() + raise RuntimeError("未知循环节点类型") diff --git a/api/app/core/workflow/nodes/enums.py b/api/app/core/workflow/nodes/enums.py index b4cc0634..0492a7bf 100644 --- a/api/app/core/workflow/nodes/enums.py +++ b/api/app/core/workflow/nodes/enums.py @@ -1,14 +1,5 @@ from enum import StrEnum -from app.core.workflow.nodes.operators import ( - StringOperator, - NumberOperator, - AssignmentOperatorType, - BooleanOperator, - ArrayOperator, - ObjectOperator -) - class NodeType(StrEnum): START = "start" @@ -27,6 +18,10 @@ class NodeType(StrEnum): JINJARENDER = "jinja-render" VAR_AGGREGATOR = "var-aggregator" PARAMETER_EXTRACTOR = "parameter-extractor" + LOOP = "loop" + ITERATION = "iteration" + CYCLE_START = "cycle-start" + BREAK = "break" class ComparisonOperator(StrEnum): @@ -62,21 +57,6 @@ class AssignmentOperator(StrEnum): REMOVE_LAST = "remove_last" REMOVE_FIRST = "remove_first" - @classmethod - def get_operator(cls, obj) -> AssignmentOperatorType: - if isinstance(obj, str): - return StringOperator - elif isinstance(obj, bool): - return BooleanOperator - elif isinstance(obj, (int, float)): - return NumberOperator - elif isinstance(obj, list): - return ArrayOperator - elif isinstance(obj, dict): - return ObjectOperator - - raise TypeError(f"Unsupported variable type ({type(obj)})") - class HttpRequestMethod(StrEnum): GET = "GET" diff --git a/api/app/core/workflow/nodes/if_else/config.py b/api/app/core/workflow/nodes/if_else/config.py index 4e424b54..9eddb473 100644 --- a/api/app/core/workflow/nodes/if_else/config.py +++ b/api/app/core/workflow/nodes/if_else/config.py @@ -30,7 +30,7 @@ class ConditionBranchConfig(BaseModel): description="Logical operator used to combine multiple condition expressions" ) - conditions: list[ConditionDetail] = Field( + expressions: list[ConditionDetail] = Field( ..., description="List of condition expressions within this branch" ) @@ -57,7 +57,7 @@ class IfElseNodeConfig(BaseNodeConfig): # CASE1 / IF Branch { "logical_operator": "and", - "conditions": [ + "expressions": [ [ { "left": "node.userinput.message", @@ -75,7 +75,7 @@ class IfElseNodeConfig(BaseNodeConfig): # CASE1 / ELIF Branch { "logical_operator": "or", - "conditions": [ + "expressions": [ [ { "left": "node.userinput.test", diff --git a/api/app/core/workflow/nodes/if_else/node.py b/api/app/core/workflow/nodes/if_else/node.py index 579c2840..03a2b430 100644 --- a/api/app/core/workflow/nodes/if_else/node.py +++ b/api/app/core/workflow/nodes/if_else/node.py @@ -2,93 +2,13 @@ import logging from typing import Any from app.core.workflow.nodes.base_node import BaseNode, WorkflowState -from app.core.workflow.nodes.enums import ComparisonOperator from app.core.workflow.nodes.if_else import IfElseNodeConfig from app.core.workflow.nodes.if_else.config import ConditionDetail +from app.core.workflow.nodes.operators import ConditionExpressionBuilder logger = logging.getLogger(__name__) -class ConditionExpressionBuilder: - """ - Build a Python boolean expression string based on a comparison operator. - - This class does not evaluate the expression. - It only generates a valid Python expression string - that can be evaluated later in a workflow context. - """ - - def __init__(self, left: str, operator: ComparisonOperator, right: str): - self.left = left - self.operator = operator - self.right = right - - def _empty(self): - return f"{self.left} == ''" - - def _not_empty(self): - return f"{self.left} != ''" - - def _contains(self): - return f"{self.right} in {self.left}" - - def _not_contains(self): - return f"{self.right} not in {self.left}" - - def _startwith(self): - return f'{self.left}.startswith({self.right})' - - def _endwith(self): - return f'{self.left}.endswith({self.right})' - - def _eq(self): - return f"{self.left} == {self.right}" - - def _ne(self): - return f"{self.left} != {self.right}" - - def _lt(self): - return f"{self.left} < {self.right}" - - def _le(self): - return f"{self.left} <= {self.right}" - - def _gt(self): - return f"{self.left} > {self.right}" - - def _ge(self): - return f"{self.left} >= {self.right}" - - def build(self): - match self.operator: - case ComparisonOperator.EMPTY: - return self._empty() - case ComparisonOperator.NOT_EMPTY: - return self._not_empty() - case ComparisonOperator.CONTAINS: - return self._contains() - case ComparisonOperator.NOT_CONTAINS: - return self._not_contains() - case ComparisonOperator.START_WITH: - return self._startwith() - case ComparisonOperator.END_WITH: - return self._endwith() - case ComparisonOperator.EQ: - return self._eq() - case ComparisonOperator.NE: - return self._ne() - case ComparisonOperator.LT: - return self._lt() - case ComparisonOperator.LE: - return self._le() - case ComparisonOperator.GT: - return self._gt() - case ComparisonOperator.GE: - return self._ge() - case _: - raise ValueError(f"Invalid condition: {self.operator}") - - class IfElseNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) @@ -143,7 +63,7 @@ class IfElseNode(BaseNode): branch_conditions = [ self._build_condition_expression(condition) - for condition in case_branch.conditions + for condition in case_branch.expressions ] if len(branch_conditions) > 1: combined_condition = f' {case_branch.logical_operator} '.join(branch_conditions) diff --git a/api/app/core/workflow/nodes/knowledge/config.py b/api/app/core/workflow/nodes/knowledge/config.py index 09c23855..cdb83131 100644 --- a/api/app/core/workflow/nodes/knowledge/config.py +++ b/api/app/core/workflow/nodes/knowledge/config.py @@ -1,18 +1,13 @@ from uuid import UUID -from pydantic import Field +from pydantic import Field, BaseModel from app.core.workflow.nodes.base_config import BaseNodeConfig from app.schemas.chunk_schema import RetrieveType -class KnowledgeRetrievalNodeConfig(BaseNodeConfig): - query: str = Field( - ..., - description="Search query string" - ) - - kb_ids: list[UUID] = Field( +class KnowledgeBaseConfig(BaseModel): + kb_id: UUID = Field( ..., description="Knowledge base IDs" ) @@ -37,18 +32,42 @@ class KnowledgeRetrievalNodeConfig(BaseNodeConfig): description="Retrieve type" ) + +class KnowledgeRetrievalNodeConfig(BaseNodeConfig): + query: str = Field( + ..., + description="Search query string" + ) + + knowledge_bases: list[KnowledgeBaseConfig] = Field( + ..., + description="Knowledge base config" + ) + + reranker_id: UUID = Field( + ..., + description="Reranker top k" + ) + + reranker_top_k: int = Field( + default=4, + description="Knowledge base top k" + ) + class Config: json_schema_extra = { "examples": [ { "query": "{{sys.message}}", - "kb_ids": [ - "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" - ], - "similarity_threshold": 0.2, - "vector_similarity_weight": 0.3, - "top_k": 1, - "retrieve_type": "hybrid" + "knowledge_bases": [{ + "kb_id": "xxxxxxxx-xxxx-xxxx-xxxxxxxxxxxxxxxxx", + "similarity_threshold": 0.2, + "vector_similarity_weight": 0.3, + "top_k": 4, + "retrieve_type": "hybrid" + }], + "reranker_top_k": 1, + "reranker_id": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" } ] } diff --git a/api/app/core/workflow/nodes/knowledge/node.py b/api/app/core/workflow/nodes/knowledge/node.py index 319a0b88..60a8a4de 100644 --- a/api/app/core/workflow/nodes/knowledge/node.py +++ b/api/app/core/workflow/nodes/knowledge/node.py @@ -2,14 +2,18 @@ import logging import uuid from typing import Any +from app.core.error_codes import BizCode +from app.core.exceptions import BusinessException +from app.core.models import RedBearRerank, RedBearModelConfig from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig from app.db import get_db_read -from app.models import knowledge_model, knowledgeshare_model +from app.models import knowledge_model, knowledgeshare_model, ModelType from app.repositories import knowledge_repository from app.schemas.chunk_schema import RetrieveType from app.services import knowledge_service, knowledgeshare_service +from app.services.model_service import ModelConfigService logger = logging.getLogger(__name__) @@ -108,6 +112,44 @@ class KnowledgeRetrievalNode(BaseNode): existing_ids.extend(items) return existing_ids + def get_reranker_model(self) -> RedBearRerank: + """ + Retrieve and initialize a RedBear reranker model based on configuration. + + Raises: + BusinessException: If configuration is missing or API keys are not set. + RuntimeError: If the configured model is not of type RERANK. + """ + with get_db_read() as db: + config = ModelConfigService.get_model_by_id(db=db, model_id=self.typed_config.reranker_id) + + if not config: + raise BusinessException("Configured model does not exist", BizCode.NOT_FOUND) + + if not config.api_keys or len(config.api_keys) == 0: + raise BusinessException("Model configuration is missing API Key", BizCode.INVALID_PARAMETER) + + # 在 Session 关闭前提取所有需要的数据 + api_config = config.api_keys[0] + model_name = api_config.model_name + provider = api_config.provider + api_key = api_config.api_key + api_base = api_config.api_base + model_type = config.type + + if model_type != ModelType.RERANK: + raise RuntimeError("Model is not a reranker") + + reranker = RedBearRerank( + RedBearModelConfig( + model_name=model_name, + provider=provider, + api_key=api_key, + base_url=api_base, + ) + ) + return reranker + async def execute(self, state: WorkflowState) -> Any: """ Execute the knowledge retrieval workflow node. @@ -131,38 +173,41 @@ class KnowledgeRetrievalNode(BaseNode): """ query = self._render_template(self.typed_config.query, state) with get_db_read() as db: - existing_ids = self._get_existing_kb_ids(db, self.typed_config.kb_ids) + knowledge_bases = self.typed_config.knowledge_bases + existing_ids = self._get_existing_kb_ids(db, [kb.kb_id for kb in knowledge_bases]) if not existing_ids: raise RuntimeError("Knowledge base retrieval failed: the knowledge base does not exist.") - kb_id = existing_ids[0] - uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in existing_ids] - indices = ",".join(uuid_strs) + rs = [] + for kb_config in knowledge_bases: + db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id) + if not db_knowledge: + raise RuntimeError("The knowledge base does not exist or access is denied.") - db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_id) - if not db_knowledge: - raise RuntimeError("The knowledge base does not exist or access is denied.") - - vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) - - match self.typed_config.retrieve_type: - case RetrieveType.PARTICIPLE: - rs = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k, - indices=indices, - score_threshold=self.typed_config.similarity_threshold) - case RetrieveType.SEMANTIC: - rs = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k, - indices=indices, - score_threshold=self.typed_config.vector_similarity_weight) - case _: - rs1 = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k, - indices=indices, - score_threshold=self.typed_config.vector_similarity_weight) - rs2 = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k, + vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) + indices = f"Vector_index_{kb_config.kb_id}_Node".lower() + match kb_config.retrieve_type: + case RetrieveType.PARTICIPLE: + rs.extend(vector_service.search_by_full_text(query=query, top_k=kb_config.top_k, + indices=indices, + score_threshold=kb_config.similarity_threshold)) + case RetrieveType.SEMANTIC: + rs.extend(vector_service.search_by_vector(query=query, top_k=kb_config.top_k, indices=indices, - score_threshold=self.typed_config.similarity_threshold) - # Deduplicate hybrid retrieval results - unique_rs = self._deduplicate_docs(rs1, rs2) - rs = vector_service.rerank(query=query, docs=unique_rs, top_k=self.typed_config.top_k) - return [chunk.model_dump() for chunk in rs] + score_threshold=kb_config.vector_similarity_weight)) + case RetrieveType.HYBRID: + rs1 = vector_service.search_by_vector(query=query, top_k=kb_config.top_k, + indices=indices, + score_threshold=kb_config.vector_similarity_weight) + rs2 = vector_service.search_by_full_text(query=query, top_k=kb_config.top_k, + indices=indices, + score_threshold=kb_config.similarity_threshold) + # Deduplicate hybrid retrieval results + unique_rs = self._deduplicate_docs(rs1, rs2) + vector_service.reranker = self.get_reranker_model() + rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k)) + case _: + raise RuntimeError("Unknown retrieval type") + final_rs = vector_service.rerank(query=query, docs=rs, top_k=kb_config.top_k) + return [chunk.model_dump() for chunk in final_rs] diff --git a/api/app/core/workflow/nodes/node_factory.py b/api/app/core/workflow/nodes/node_factory.py index 90c48ac0..ed26533d 100644 --- a/api/app/core/workflow/nodes/node_factory.py +++ b/api/app/core/workflow/nodes/node_factory.py @@ -10,6 +10,7 @@ from typing import Any, Union from app.core.workflow.nodes.agent import AgentNode from app.core.workflow.nodes.assigner import AssignerNode from app.core.workflow.nodes.base_node import BaseNode +from app.core.workflow.nodes.cycle_graph.node import CycleGraphNode from app.core.workflow.nodes.end import EndNode from app.core.workflow.nodes.enums import NodeType from app.core.workflow.nodes.http_request import HttpRequestNode @@ -22,6 +23,7 @@ from app.core.workflow.nodes.start import StartNode from app.core.workflow.nodes.transform import TransformNode from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode from app.core.workflow.nodes.question_classifier import QuestionClassifierNode +from app.core.workflow.nodes.breaker import BreakNode logger = logging.getLogger(__name__) @@ -39,6 +41,9 @@ WorkflowNode = Union[ JinjaRenderNode, VariableAggregatorNode, ParameterExtractorNode, + CycleGraphNode, + BreakNode, + ParameterExtractorNode, QuestionClassifierNode ] @@ -64,6 +69,9 @@ class NodeFactory: NodeType.VAR_AGGREGATOR: VariableAggregatorNode, NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode, NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode, + NodeType.LOOP: CycleGraphNode, + NodeType.ITERATION: CycleGraphNode, + NodeType.BREAK: BreakNode, } @classmethod diff --git a/api/app/core/workflow/nodes/operators.py b/api/app/core/workflow/nodes/operators.py index a80cf326..70668b6a 100644 --- a/api/app/core/workflow/nodes/operators.py +++ b/api/app/core/workflow/nodes/operators.py @@ -1,6 +1,7 @@ from abc import ABC from typing import Union, Type +from app.core.workflow.nodes.enums import ComparisonOperator from app.core.workflow.variable_pool import VariablePool @@ -136,6 +137,23 @@ class ObjectOperator(OperatorBase): self.pool.set(self.left_selector, dict()) +class AssignmentOperatorResolver: + @classmethod + def resolve_by_value(cls, value): + if isinstance(value, str): + return StringOperator + elif isinstance(value, bool): + return BooleanOperator + elif isinstance(value, (int, float)): + return NumberOperator + elif isinstance(value, list): + return ArrayOperator + elif isinstance(value, dict): + return ObjectOperator + else: + raise TypeError(f"Unsupported variable type: {type(value)}") + + AssignmentOperatorInstance = Union[ StringOperator, NumberOperator, @@ -144,3 +162,83 @@ AssignmentOperatorInstance = Union[ ObjectOperator ] AssignmentOperatorType = Type[AssignmentOperatorInstance] + + +class ConditionExpressionBuilder: + """ + Build a Python boolean expression string based on a comparison operator. + + This class does not evaluate the expression. + It only generates a valid Python expression string + that can be evaluated later in a workflow context. + """ + + def __init__(self, left: str, operator: ComparisonOperator, right: str): + self.left = left + self.operator = operator + self.right = right + + def _empty(self): + return f"{self.left} == ''" + + def _not_empty(self): + return f"{self.left} != ''" + + def _contains(self): + return f"{self.right} in {self.left}" + + def _not_contains(self): + return f"{self.right} not in {self.left}" + + def _startswith(self): + return f'{self.left}.startswith({self.right})' + + def _endswith(self): + return f'{self.left}.endswith({self.right})' + + def _eq(self): + return f"{self.left} == {self.right}" + + def _ne(self): + return f"{self.left} != {self.right}" + + def _lt(self): + return f"{self.left} < {self.right}" + + def _le(self): + return f"{self.left} <= {self.right}" + + def _gt(self): + return f"{self.left} > {self.right}" + + def _ge(self): + return f"{self.left} >= {self.right}" + + def build(self): + match self.operator: + case ComparisonOperator.EMPTY: + return self._empty() + case ComparisonOperator.NOT_EMPTY: + return self._not_empty() + case ComparisonOperator.CONTAINS: + return self._contains() + case ComparisonOperator.NOT_CONTAINS: + return self._not_contains() + case ComparisonOperator.START_WITH: + return self._startswith() + case ComparisonOperator.END_WITH: + return self._endswith() + case ComparisonOperator.EQ: + return self._eq() + case ComparisonOperator.NE: + return self._ne() + case ComparisonOperator.LT: + return self._lt() + case ComparisonOperator.LE: + return self._le() + case ComparisonOperator.GT: + return self._gt() + case ComparisonOperator.GE: + return self._ge() + case _: + raise ValueError(f"Invalid condition: {self.operator}") diff --git a/api/app/core/workflow/nodes/parameter_extractor/config.py b/api/app/core/workflow/nodes/parameter_extractor/config.py index 30c0e1ef..3b5607c5 100644 --- a/api/app/core/workflow/nodes/parameter_extractor/config.py +++ b/api/app/core/workflow/nodes/parameter_extractor/config.py @@ -36,6 +36,11 @@ class ParamsConfig(BaseModel): description="Description of the parameter" ) + required: bool = Field( + ..., + description="Whether the parameter is required" + ) + class ParameterExtractorNodeConfig(BaseNodeConfig): model_id: uuid.UUID = Field( @@ -52,3 +57,8 @@ class ParameterExtractorNodeConfig(BaseNodeConfig): ..., description="List of parameters" ) + + prompt: str = Field( + ..., + description="User-provided supplemental prompt" + ) diff --git a/api/app/core/workflow/nodes/parameter_extractor/node.py b/api/app/core/workflow/nodes/parameter_extractor/node.py index 0eb3bfd4..0e311215 100644 --- a/api/app/core/workflow/nodes/parameter_extractor/node.py +++ b/api/app/core/workflow/nodes/parameter_extractor/node.py @@ -1,4 +1,5 @@ import os +import logging import json_repair from typing import Any @@ -15,6 +16,8 @@ from app.db import get_db_read from app.models import ModelType from app.services.model_service import ModelConfigService +logger = logging.getLogger(__name__) + class ParameterExtractorNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): @@ -114,7 +117,7 @@ class ParameterExtractorNode(BaseNode): """ field_type = {} for param in self.typed_config.params: - field_type[param.name] = param.type + field_type[param.name] = f'{param.type}, required:{str(param.required)}' return field_type async def execute(self, state: WorkflowState) -> Any: @@ -154,12 +157,12 @@ class ParameterExtractorNode(BaseNode): messages = [ ("system", system_prompt), + ("user", self._render_template(self.typed_config.prompt, state)), ("user", rendered_user_prompt), ] model_resp = await llm.ainvoke(messages) - result = json_repair.repair_json(model_resp.content) + result = json_repair.repair_json(model_resp.content, return_objects=True) + logger.info(f"get prarms:{result}") - return { - "output": result, - } + return result diff --git a/api/app/core/workflow/nodes/variable_aggregator/config.py b/api/app/core/workflow/nodes/variable_aggregator/config.py index 84f82487..ac1419a4 100644 --- a/api/app/core/workflow/nodes/variable_aggregator/config.py +++ b/api/app/core/workflow/nodes/variable_aggregator/config.py @@ -9,43 +9,27 @@ class VariableAggregatorNodeConfig(BaseNodeConfig): description="输出变量是否需要分组", ) - group_names: list[str] = Field( - default_factory=lambda: ["output"], - description="各个分组的名称" - ) - - group_variables: list[str] | list[list[str]] = Field( + group_variables: list[str] | dict[str, list[str]] = Field( ..., description="需要被聚合的变量" ) - @field_validator("group_names", mode="before") - @classmethod - def group_names_validator(cls, v, info): - group_status = info.data.get("group") - if not group_status or not v: - return ["output"] - return v - @field_validator("group_variables") @classmethod def group_variables_validator(cls, v, info): group_status = info.data.get("group") - group_names = info.data.get("group_names") - if not isinstance(v, list): - raise ValueError("group_variables must be a list") if not group_status: for variable in v: if not isinstance(variable, str): raise ValueError("When group=False, group_variables must be a list of strings") else: - if len(group_names) != len(v): - raise ValueError("group_names and group_variables length mismatch") - for group in v: - if not isinstance(group, list): + if not isinstance(v, dict): + raise ValueError("When group=True, group_variables must be a dict") + for group_name, group_values in v.items(): + if not isinstance(group_name, str): raise ValueError("When group=True, each element of group_variables must be a list") - for variable in group: + for variable in group_values: if not isinstance(variable, str): raise ValueError("Each element inside group_variables lists must be a string") return v diff --git a/api/app/core/workflow/nodes/variable_aggregator/node.py b/api/app/core/workflow/nodes/variable_aggregator/node.py index f53f9269..d4cc8c55 100644 --- a/api/app/core/workflow/nodes/variable_aggregator/node.py +++ b/api/app/core/workflow/nodes/variable_aggregator/node.py @@ -59,7 +59,7 @@ class VariableAggregatorNode(BaseNode): # Group mode # -------------------------- result = {} - for group_name, variables in zip(self.typed_config.group_names, self.typed_config.group_variables): + for group_name, variables in self.typed_config.group_variables.items(): for variable in variables: var_express = self._get_express(variable) try: diff --git a/api/app/core/workflow/variable_pool.py b/api/app/core/workflow/variable_pool.py index 0f97c349..b7814f28 100644 --- a/api/app/core/workflow/variable_pool.py +++ b/api/app/core/workflow/variable_pool.py @@ -198,19 +198,22 @@ class VariablePool: namespace = selector[0] - if namespace != "conv": - raise ValueError("只能设置会话变量 (conv.*)") + if namespace != "conv" and namespace not in self.state["cycle_nodes"]: + raise ValueError("Only conversation or cycle variables can be assigned.") key = selector[1] # 确保 variables 结构存在 if "variables" not in self.state: self.state["variables"] = {"sys": {}, "conv": {}} - if "conv" not in self.state["variables"]: - self.state["variables"]["conv"] = {} - - # 设置值 - self.state["variables"]["conv"][key] = value + if namespace == "conv": + if "conv" not in self.state["variables"]: + self.state["variables"]["conv"] = {} + + # 设置值 + self.state["variables"]["conv"][key] = value + elif namespace in self.state["cycle_nodes"]: + self.state["runtime_vars"][namespace][key] = value logger.debug(f"设置变量: {'.'.join(selector)} = {value}") diff --git a/api/app/schemas/workflow_schema.py b/api/app/schemas/workflow_schema.py index eb337298..bdef825e 100644 --- a/api/app/schemas/workflow_schema.py +++ b/api/app/schemas/workflow_schema.py @@ -20,6 +20,7 @@ class NodeDefinition(BaseModel): id: str = Field(..., description="节点唯一标识") type: str = Field(..., description="节点类型: start, end, llm, agent, tool, condition, loop, transform, human, code") name: str | None = Field(None, description="节点名称") + cycle: str | None = Field(None, description="父循环节点id") description: str | None = Field(None, description="节点描述") config: dict[str, Any] = Field(default_factory=dict, description="节点配置") position: dict[str, float] | None = Field(None, description="节点位置 {x, y}")