Merge #78 into develop from feature/20251219_myh

Merge branch 'develop' into feature/20251219_myh

* feature/20251219_myh: (8 commits squashed)

  - feat(workflow): update reranker model configuration for knowledge base retrieval

  - fix(workflow): fix output issue in parameter extraction node

  - fix(workflow): fix output issue in parameter extraction node

  - feat(workflow): add user prompt to parameter extraction node

  - perf(workflow): change grouped variable input to key-value format in variable aggregator

  - feat(workflow): Add new cycle node for iterative workflow execution
    
    - Introduce a new Loop/Iteration node in the workflow engine.
    - Supports both conditional loops and iteration over lists.
    - Allows parallel execution and flattening of iteration outputs.
    - Maintains runtime state, node outputs, and loop variables for downstream nodes.
    - Enhances workflow flexibility for complex, repeated operations.

  - Merge branch 'develop' into feature/20251219_myh
    
    # Conflicts:
    #	api/app/core/workflow/nodes/configs.py
    #	api/app/core/workflow/nodes/node_factory.py

  - feat(workflow): Add new cycle node for iterative workflow execution
    
    - Introduce a new Loop/Iteration node in the workflow engine.
    - Supports both conditional loops and iteration over lists.
    - Allows parallel execution and flattening of iteration outputs.
    - Maintains runtime state, node outputs, and loop variables for downstream nodes.
    - Enhances workflow flexibility for complex, repeated operations.

Signed-off-by: Eternity <1533512157@qq.com>
Reviewed-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com>
Reviewed-by: aliyun6762716068 <accounts_68cb7c6b61f5dcc4200d6251@mail.teambition.com>
Merged-by: aliyun6762716068 <accounts_68cb7c6b61f5dcc4200d6251@mail.teambition.com>

CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/78
This commit is contained in:
孟永豪
2025-12-29 11:57:13 +00:00
committed by 孙科
parent 6defcaf982
commit b376c3d648
27 changed files with 967 additions and 243 deletions

View File

@@ -1,4 +1,3 @@
from typing import Any, Dict, List, Optional, Sequence, Type, Union
from copy import deepcopy
from urllib.parse import urlparse
@@ -8,8 +7,10 @@ from langchain_core.callbacks import Callbacks
from app.core.models.base import RedBearModelConfig, get_provider_rerank_class, RedBearModelFactory
from app.models import ModelProvider
class RedBearRerank(BaseDocumentCompressor):
""" Rerank → 作为 Runnable 插入任意 LCEL 链"""
def __init__(self, config: RedBearModelConfig):
self._model = self._create_model(config)
self._config = config
@@ -22,10 +23,10 @@ class RedBearRerank(BaseDocumentCompressor):
return model_class(**model_params)
def compress_documents(
self,
documents: Sequence[Document],
query: str,
callbacks: Optional[Callbacks] = None,
self,
documents: Sequence[Document],
query: str,
callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]:
"""
Compress documents using Jina's Rerank API.
@@ -46,17 +47,17 @@ class RedBearRerank(BaseDocumentCompressor):
compressed.append(doc_copy)
return compressed
def rerank(
self,
documents: Sequence[Union[str, Document, dict]],
query: str,
*,
top_n: Optional[int] = -1,
) -> List[Dict[str, Any]]:
provider = self._config.provider.lower()
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
self,
documents: Sequence[Union[str, Document, dict]],
query: str,
*,
top_n: Optional[int] = -1,
) -> List[Dict[str, Any]]:
provider = self._config.provider.lower()
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
import langchain_community.document_compressors.jina_rerank as jina_mod
# 规范化:如果不以 /v1/rerank 结尾,则补齐;若已以 /v1 结尾,则补 /rerank
def _normalize_jina_base(base_url: Optional[str]) -> Optional[str]:
if not base_url:
@@ -73,8 +74,7 @@ class RedBearRerank(BaseDocumentCompressor):
# 设置完整的 rerank 端点,例如 http://host:port/v1/rerank
jina_mod.JINA_API_URL = jina_base
from langchain_community.document_compressors import JinaRerank
model_instance : JinaRerank = self._model
return model_instance.rerank(documents = documents, query = query, top_n=top_n)
model_instance: JinaRerank = self._model
return model_instance.rerank(documents=documents, query=query, top_n=top_n)
else:
raise ValueError(f"不支持的模型提供商: {provider}")

View File

@@ -4,9 +4,9 @@
基于 LangGraph 的工作流执行引擎。
"""
import logging
# import uuid
import datetime
import logging
from typing import Any
from langchain_core.messages import HumanMessage
@@ -107,7 +107,13 @@ class WorkflowExecutor:
"user_id": self.user_id,
"error": None,
"error_node": None,
"streaming_buffer": {} # 流式缓冲区
"streaming_buffer": {}, # 流式缓冲区
"cycle_nodes": [
node.get("id")
for node in self.workflow_config.get("nodes")
if node.get("type") in [NodeType.LOOP, NodeType.ITERATION]
], # loop, iteration node id
"looping": False # loop runing flag, only use in loop node,not use in main loop
}
def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]:
@@ -199,6 +205,10 @@ class WorkflowExecutor:
for node in self.nodes:
node_type = node.get("type")
node_id = node.get("id")
cycle_node = node.get("cycle")
if cycle_node:
# 处于循环子图中的节点由 CycleGraphNode 进行构建处理
continue
# 记录 start 和 end 节点 ID
if node_type == NodeType.START:
@@ -271,7 +281,7 @@ class WorkflowExecutor:
workflow.add_edge(START, start_node_id)
logger.debug(f"添加边: START -> {start_node_id}")
for edge in self.edges:
for edge in self.workflow_config.get("edges", []):
source = edge.get("source")
target = edge.get("target")
edge_type = edge.get("type")
@@ -284,12 +294,12 @@ class WorkflowExecutor:
logger.debug(f"添加边: {source} -> {target}")
continue
# 处理到 end 节点的边
if target in end_node_ids:
# 连接到 end 节点
workflow.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}")
continue
# # 处理到 end 节点的边
# if target in end_node_ids:
# # 连接到 end 节点
# workflow.add_edge(source, target)
# logger.debug(f"添加边: {source} -> {target}")
# continue
# 跳过错误边(在节点内部处理)
if edge_type == "error":

View File

@@ -74,6 +74,7 @@ class ExpressionEvaluator:
# 为了向后兼容,也支持直接访问(但会在日志中警告)
context.update(variables)
context["nodes"] = node_outputs
context.update(node_outputs)
try:
# simpleeval 只支持安全的操作:

View File

@@ -6,7 +6,7 @@ from app.core.workflow.expression_evaluator import ExpressionEvaluator
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import AssignmentOperator
from app.core.workflow.nodes.operators import AssignmentOperatorInstance
from app.core.workflow.nodes.operators import AssignmentOperatorInstance, AssignmentOperatorResolver
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
@@ -40,8 +40,8 @@ class AssignerNode(BaseNode):
variable_selector = expression.split('.')
# Only conversation variables ('conv') are allowed
if variable_selector[0] != 'conv': # TODO: Loop node variable support (Feature)
raise ValueError("Only conversation variables can be assigned.")
if variable_selector[0] != 'conv' and variable_selector[0] not in state["cycle_nodes"]:
raise ValueError("Only conversation or cycle variables can be assigned.")
# Get the value or expression to assign
value = assignment.value
@@ -55,7 +55,9 @@ class AssignerNode(BaseNode):
)
# Select the appropriate assignment operator instance based on the target variable type
operator: AssignmentOperatorInstance = AssignmentOperator.get_operator(pool.get(variable_selector))(
operator: AssignmentOperatorInstance = AssignmentOperatorResolver.resolve_by_value(
pool.get(variable_selector)
)(
pool, variable_selector, value
)

View File

@@ -14,9 +14,13 @@ class VariableType(StrEnum):
STRING = "string"
NUMBER = "number"
BOOLEAN = "boolean"
ARRAY = "array"
OBJECT = "object"
ARRAY_STRING = "array[string]"
ARRAY_NUMBER = "array[number]"
ARRAY_BOOLEAN = "array[boolean]"
ARRAY_OBJECT = "array[object]"
class VariableDefinition(BaseModel):
"""变量定义

View File

@@ -20,40 +20,44 @@ logger = logging.getLogger(__name__)
class WorkflowState(TypedDict):
"""工作流状态
在节点间传递的状态对象,包含消息、变量、节点输出等信息。
"""Workflow state
The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc.
"""
# 消息列表(追加模式)
# List of messages (append mode)
messages: Annotated[list[AnyMessage], add]
# 输入变量(从配置的 variables 传入)
# 使用深度合并函数,支持嵌套字典的更新(如 conv.xxx
# Set of loop node IDs, used for assigning values in loop nodes
cycle_nodes: list
looping: bool
# Input variables (passed from configured variables)
# Uses a deep merge function, supporting nested dict updates (e.g., conv.xxx)
variables: Annotated[dict[str, Any], lambda x, y: {
**x,
**{k: {**x.get(k, {}), **v} if isinstance(v, dict) and isinstance(x.get(k), dict) else v
for k, v in y.items()}
}]
# 节点输出(存储每个节点的执行结果,用于变量引用)
# 使用自定义合并函数,将新的节点输出合并到现有字典中
# Node outputs (stores execution results of each node for variable references)
# Uses a custom merge function to combine new node outputs into the existing dictionary
node_outputs: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
# 运行时节点变量(简化版,只存储业务数据,供节点间快速访问)
# 格式:{node_id: business_result}
# Runtime node variables (simplified version, stores business data for fast access between nodes)
# Format: {node_id: business_result}
runtime_vars: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
# 执行上下文
# Execution context
execution_id: str
workspace_id: str
user_id: str
# 错误信息(用于错误边)
# Error information (for error edges)
error: str | None
error_node: str | None
# 流式缓冲区(存储节点的实时流式输出)
# 格式:{node_id: {"chunks": [...], "full_content": "..."}}
# Streaming buffer (stores real-time streaming output of nodes)
# Format: {node_id: {"chunks": [...], "full_content": "..."}}
streaming_buffer: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
@@ -170,10 +174,10 @@ class BaseNode(ABC):
import time
start_time = time.time()
timeout = self.get_timeout()
try:
timeout = self.get_timeout()
# 调用节点的业务逻辑
business_result = await asyncio.wait_for(
self.execute(state),
@@ -200,7 +204,8 @@ class BaseNode(ABC):
**wrapped_output,
"runtime_vars": {
self.node_id: runtime_var
}
},
"looping": state["looping"]
}
except TimeoutError:
@@ -236,10 +241,10 @@ class BaseNode(ABC):
import time
start_time = time.time()
timeout = self.get_timeout()
try:
timeout = self.get_timeout()
# Get LangGraph's stream writer for sending custom data
writer = get_stream_writer()

View File

@@ -0,0 +1,3 @@
from app.core.workflow.nodes.breaker.node import BreakNode
__all__ = ["BreakNode"]

View File

@@ -0,0 +1,32 @@
import logging
from typing import Any
from app.core.workflow.nodes import BaseNode, WorkflowState
logger = logging.getLogger(__name__)
class BreakNode(BaseNode):
"""
Workflow node that immediately stops loop execution.
When executed, this node sets the 'looping' flag in the workflow state
to False, signaling the outer loop runtime to terminate further iterations.
"""
async def execute(self, state: WorkflowState) -> Any:
"""
Execute the break node.
Args:
state: Current workflow state, including loop control flags.
Effects:
- Sets 'looping' in the state to False to stop the loop.
- Logs the action for debugging purposes.
Returns:
Optional dictionary indicating the loop has been stopped.
"""
state["looping"] = False
logger.info(f"run break node, looping={state['looping']}")

View File

@@ -22,6 +22,7 @@ from app.core.workflow.nodes.variable_aggregator.config import VariableAggregato
from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig
from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
__all__ = [
# 基础类
"BaseNodeConfig",
@@ -41,5 +42,7 @@ __all__ = [
"JinjaRenderNodeConfig",
"VariableAggregatorNodeConfig",
"ParameterExtractorNodeConfig",
"LoopNodeConfig",
"IterationNodeConfig",
"QuestionClassifierNodeConfig"
]

View File

@@ -0,0 +1,4 @@
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
from app.core.workflow.nodes.cycle_graph.node import CycleGraphNode
__all__ = ['CycleGraphNode', 'LoopNodeConfig', 'IterationNodeConfig']

View File

@@ -0,0 +1,96 @@
from pydantic import Field, BaseModel
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableType
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator
class CycleVariable(BaseNodeConfig):
name: str = Field(
...,
description="Name of the loop variable"
)
type: VariableType = Field(
...,
description="Data type of the loop variable"
)
value: str = Field(
...,
description="Initial or current value of the loop variable"
)
class ConditionDetail(BaseModel):
comparison_operator: ComparisonOperator = Field(
...,
description="Operator used to compare the left and right operands"
)
left: str = Field(
...,
description="Left-hand operand of the comparison expression"
)
right: str = Field(
...,
description="Right-hand operand of the comparison expression"
)
class ConditionsConfig(BaseModel):
"""Configuration for loop condition evaluation"""
logical_operator: LogicOperator = Field(
default=LogicOperator.AND.value,
description="Logical operator used to combine multiple condition expressions"
)
expressions: list[ConditionDetail] = Field(
...,
description="Collection of condition expressions to be evaluated"
)
class LoopNodeConfig(BaseNodeConfig):
condition: ConditionsConfig = Field(
default_factory=list,
description="Conditional configuration that controls loop execution"
)
cycle_vars: list[CycleVariable] = Field(
default_factory=list,
description="List of variables used and updated during the loop"
)
max_loop: int = Field(
default=10,
description="Maximum number of loop iterations"
)
class IterationNodeConfig(BaseNodeConfig):
input: str = Field(
...,
description="Input of the loop iteration"
)
parallel: bool = Field(
default=False,
description="Whether to execute loop iterations in parallel"
)
parallel_count: int = Field(
default=4,
description="Number of iterations to run in parallel"
)
flatten: bool = Field(
default=False,
description="Whether to flatten the output list from iterations"
)
output: str = Field(
...,
description="Output of the loop iteration"
)

View File

@@ -0,0 +1,146 @@
import asyncio
import copy
import re
from typing import Any
from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.cycle_graph import IterationNodeConfig
from app.core.workflow.variable_pool import VariablePool
class IterationRuntime:
"""
Runtime executor for loop/iteration nodes in a workflow.
This class handles executing iterations over a list variable, supporting
optional parallel execution, flattening of output, and loop control via
the workflow state.
"""
def __init__(
self,
graph: CompiledStateGraph,
node_id: str,
config: dict[str, Any],
state: WorkflowState,
):
"""
Initialize the iteration runtime.
Args:
graph: Compiled workflow graph capable of async invocation.
node_id: Unique identifier of the loop node.
config: Dictionary containing iteration node configuration.
state: Current workflow state at the point of iteration.
"""
self.graph = graph
self.state = state
self.node_id = node_id
self.typed_config = IterationNodeConfig(**config)
self.looping = True
self.output_value = None
self.result: list = []
def _init_iteration_state(self, item, idx):
"""
Initialize a per-iteration copy of the workflow state.
Args:
item: Current element from the input array for this iteration.
idx: Index of the element in the input array.
Returns:
A deep copy of the workflow state with iteration-specific variables set.
"""
loopstate = WorkflowState(
**copy.deepcopy(self.state)
)
loopstate["runtime_vars"][self.node_id] = {
"item": item,
"index": idx,
}
loopstate["node_outputs"][self.node_id] = {
"item": item,
"index": idx,
}
loopstate["looping"] = True
return loopstate
async def run_task(self, item, idx):
"""
Execute a single iteration asynchronously.
Args:
item: The input element for this iteration.
idx: The index of this iteration.
"""
result = await self.graph.ainvoke(self._init_iteration_state(item, idx))
output = VariablePool(result).get(self.output_value)
if isinstance(output, list) and self.typed_config.flatten:
self.result.extend(output)
else:
self.result.append(output)
if not result["looping"]:
self.looping = False
def _create_iteration_tasks(self, array_obj, idx):
"""
Create async tasks for a batch of iterations based on parallel count.
Args:
array_obj: The input array to iterate over.
idx: Starting index for this batch of iterations.
Returns:
List of coroutine tasks ready to be executed in parallel.
"""
tasks = []
for i in range(self.typed_config.parallel_count):
if idx + i >= len(array_obj):
break
item = array_obj[idx + i]
tasks.append(self.run_task(item, idx + i))
return tasks
async def run(self):
"""
Execute the loop over the input array according to configuration.
Returns:
A list of outputs from all iterations, optionally flattened.
Raises:
RuntimeError: If the input variable is not a list.
"""
pattern = r"\{\{\s*(.*?)\s*\}\}"
input_expression = re.sub(pattern, r"\1", self.typed_config.input).strip()
self.output_value = re.sub(pattern, r"\1", self.typed_config.output).strip()
array_obj = VariablePool(self.state).get(input_expression)
if not isinstance(array_obj, list):
raise RuntimeError("Cannot iterate over a non-list variable")
idx = 0
if self.typed_config.parallel:
# Execute iterations in parallel batches
while idx < len(array_obj) and self.looping:
tasks = self._create_iteration_tasks(array_obj, idx)
idx += self.typed_config.parallel_count
await asyncio.gather(*tasks)
return self.result
else:
# Execute iterations sequentially
while idx < len(array_obj) and self.looping:
item = array_obj[idx]
result = await self.graph.ainvoke(self._init_iteration_state(item, idx))
output = VariablePool(result).get(self.output_value)
if isinstance(output, list) and self.typed_config.flatten:
self.result.extend(output)
else:
self.result.append(output)
if not result["looping"]:
self.looping = False
idx += 1
return self.result

View File

@@ -0,0 +1,124 @@
from typing import Any
from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.expression_evaluator import evaluate_condition, evaluate_expression
from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig
from app.core.workflow.nodes.operators import ConditionExpressionBuilder
from app.core.workflow.variable_pool import VariablePool
class LoopRuntime:
"""
Runtime executor for loop nodes in a workflow.
Handles iterative execution of a loop node according to defined loop variables
and conditional expressions. Supports maximum loop count and loop control
through the workflow state.
"""
def __init__(
self,
graph: CompiledStateGraph,
node_id: str,
config: dict[str, Any],
state: WorkflowState,
):
"""
Initialize the loop runtime.
Args:
graph: Compiled workflow graph capable of async invocation.
node_id: Unique identifier of the loop node.
config: Dictionary containing loop node configuration.
state: Current workflow state at the point of loop execution.
"""
self.graph = graph
self.state = state
self.node_id = node_id
self.typed_config = LoopNodeConfig(**config)
def _init_loop_state(self):
"""
Initialize workflow state for loop execution.
- Evaluates initial values of loop variables.
- Stores loop variables in runtime_vars and node_outputs.
- Marks the loop as active by setting 'looping' to True.
Returns:
A copy of the workflow state prepared for the loop execution.
"""
pool = VariablePool(self.state)
# 循环变量
self.state["runtime_vars"][self.node_id] = {
variable.name: evaluate_expression(
expression=variable.value,
variables=pool.get_all_conversation_vars(),
node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars(),
)
for variable in self.typed_config.cycle_vars
}
self.state["node_outputs"][self.node_id] = {
variable.name: evaluate_expression(
expression=variable.value,
variables=pool.get_all_conversation_vars(),
node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars(),
)
for variable in self.typed_config.cycle_vars
}
loopstate = WorkflowState(
**self.state
)
loopstate["looping"] = True
return loopstate
def _get_loop_expression(self):
"""
Build the Python boolean expression for evaluating the loop condition.
- Converts each condition in the loop configuration into a Python expression string.
- Combines multiple conditions with the configured logical operator (AND/OR).
Returns:
A string representing the combined loop condition expression.
"""
branch_conditions = [
ConditionExpressionBuilder(
left=condition.left,
operator=condition.comparison_operator,
right=condition.right
).build()
for condition in self.typed_config.condition.expressions
]
if len(branch_conditions) > 1:
combined_condition = f' {self.typed_config.condition.logical_operator} '.join(branch_conditions)
else:
combined_condition = branch_conditions[0]
return combined_condition
async def run(self):
"""
Execute the loop node until the condition is no longer met, the loop is
manually stopped, or the maximum loop count is reached.
Returns:
The final runtime variables of this loop node after completion.
"""
loopstate = self._init_loop_state()
expression = self._get_loop_expression()
loop_variable_pool = VariablePool(loopstate)
loop_time = self.typed_config.max_loop
while evaluate_condition(
expression=expression,
variables=loop_variable_pool.get_all_conversation_vars(),
node_outputs=loop_variable_pool.get_all_node_outputs(),
system_vars=loop_variable_pool.get_all_system_vars(),
) and loopstate["looping"] and loop_time > 0:
await self.graph.ainvoke(loopstate)
loop_time -= 1
return loopstate["runtime_vars"][self.node_id]

View File

@@ -0,0 +1,223 @@
import logging
from typing import Any
from langgraph.graph import StateGraph, START, END
from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.expression_evaluator import evaluate_condition
from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
from app.core.workflow.nodes.cycle_graph.iteration import IterationRuntime
from app.core.workflow.nodes.cycle_graph.loop import LoopRuntime
from app.core.workflow.nodes.enums import NodeType
logger = logging.getLogger(__name__)
class CycleGraphNode(BaseNode):
"""
Node representing a cycle (loop) subgraph within the workflow.
This node manages internal loop/iteration nodes, builds a subgraph
for execution, handles conditional routing, and executes loop
or iteration logic based on node type.
"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config: LoopNodeConfig | IterationNodeConfig | None = None
self.cycle_nodes = list() # Nodes belonging to this cycle
self.cycle_edges = list() # Edges connecting nodes within the cycle
self.start_node_id = None # ID of the start node within the cycle
self.end_node_ids = [] # IDs of end nodes within the cycle
self.graph: StateGraph | CompiledStateGraph | None = None
self.build_graph()
self.iteration_flag = True
def pure_cycle_graph(self) -> tuple[list, list]:
"""
Extract cycle nodes and internal edges from the workflow configuration,
removing them from the global workflow.
Raises:
ValueError: If cycle nodes are connected to external nodes improperly.
Returns:
Tuple containing:
- cycle_nodes: List of removed nodes
- cycle_edges: List of removed edges
"""
nodes = self.workflow_config.get("nodes", [])
edges = self.workflow_config.get("edges", [])
# Select all nodes that belong to the current cycle
cycle_nodes = [node for node in nodes if node.get("cycle") == self.node_id]
cycle_node_ids = {node.get("id") for node in cycle_nodes}
cycle_edges = []
remain_edges = []
for edge in edges:
source_in = edge.get("source") in cycle_node_ids
target_in = edge.get("target") in cycle_node_ids
# Raise error if cycle nodes are connected with external nodes
if source_in ^ target_in:
raise ValueError(f"循环节点与外部节点存在连接,soruce: {edge.get("source")}, target:{edge.get("target")}")
if source_in and target_in:
cycle_edges.append(edge)
else:
remain_edges.append(edge)
# Update workflow_config by removing cycle nodes and internal edges
self.workflow_config["nodes"] = [
node for node in nodes if node.get("cycle") != self.node_id
]
self.workflow_config["edges"] = remain_edges
return cycle_nodes, cycle_edges
def create_node(self):
"""
Instantiate node objects for each node in the cycle subgraph and add them to the graph.
Special handling is applied for conditional nodes to generate
edge conditions based on node outputs.
"""
from app.core.workflow.nodes import NodeFactory
for node in self.cycle_nodes:
node_type = node.get("type")
node_id = node.get("id")
if node_type == NodeType.CYCLE_START:
self.start_node_id = node_id
continue
elif node_type == NodeType.END:
self.end_node_ids.append(node_id)
node_instance = NodeFactory.create_node(node, self.workflow_config)
if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST]:
expressions = node_instance.build_conditional_edge_expressions()
# Number of branches, usually matches the number of conditional expressions
branch_number = len(expressions)
# Find all edges whose source is the current node
related_edge = [edge for edge in self.cycle_edges if edge.get("source") == node_id]
# Iterate over each branch
for idx in range(branch_number):
# Generate a condition expression for each edge
# Used later to determine which branch to take based on the node's output
# Assumes node output `node.<node_id>.output` matches the edge's label
# For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1'
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
def make_func(inst):
async def node_func(state: WorkflowState):
return await inst.run(state)
return node_func
self.graph.add_node(node_id, make_func(node_instance))
def create_edge(self):
"""
Connect nodes within the cycle subgraph by adding edges to the internal graph.
Conditional edges are routed based on evaluated expressions.
Start and end nodes are connected to global START and END nodes.
"""
for edge in self.cycle_edges:
source = edge.get("source")
target = edge.get("target")
edge_type = edge.get("type")
condition = edge.get("condition")
# 跳过从 start 节点出发的边(因为已经从 START 连接到 start
if source == self.start_node_id:
# 但要连接 start 到下一个节点
self.graph.add_edge(START, target)
logger.debug(f"添加边: {source} -> {target}")
continue
if condition:
# 条件边
def router(state: WorkflowState, cond=condition, tgt=target):
"""条件路由函数"""
if evaluate_condition(
cond,
state.get("variables", {}),
state.get("node_outputs", {}),
{
"execution_id": state.get("execution_id"),
"workspace_id": state.get("workspace_id"),
"user_id": state.get("user_id")
}
):
return tgt
return END # 条件不满足,结束
self.graph.add_conditional_edges(source, router)
logger.debug(f"添加条件边: {source} -> {target} (condition={condition})")
else:
# 普通边
self.graph.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}")
# 从 end 节点连接到 END
for end_node_id in self.end_node_ids:
self.graph.add_edge(end_node_id, END)
logger.debug(f"添加边: {end_node_id} -> END")
def build_graph(self):
"""
Build the internal subgraph for the cycle node.
Steps:
1. Extract cycle nodes and edges.
2. Create node instances and add them to the graph.
3. Connect edges and conditional routes.
4. Compile the graph for execution.
"""
self.graph = StateGraph(WorkflowState)
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
self.create_node()
self.create_edge()
self.graph = self.graph.compile()
async def execute(self, state: WorkflowState) -> Any:
"""
Execute the cycle node at runtime.
Depending on the node type, runs either a loop (LoopRuntime)
or an iteration (IterationRuntime) over the internal subgraph.
Args:
state: Current workflow state.
Returns:
Runtime result of the cycle, typically the final loop/iteration variables.
Raises:
RuntimeError: If node type is unrecognized.
"""
if self.node_type == NodeType.LOOP:
return await LoopRuntime(
graph=self.graph,
node_id=self.node_id,
config=self.config,
state=state,
).run()
if self.node_type == NodeType.ITERATION:
return await IterationRuntime(
graph=self.graph,
node_id=self.node_id,
config=self.config,
state=state,
).run()
raise RuntimeError("未知循环节点类型")

View File

@@ -1,14 +1,5 @@
from enum import StrEnum
from app.core.workflow.nodes.operators import (
StringOperator,
NumberOperator,
AssignmentOperatorType,
BooleanOperator,
ArrayOperator,
ObjectOperator
)
class NodeType(StrEnum):
START = "start"
@@ -27,6 +18,10 @@ class NodeType(StrEnum):
JINJARENDER = "jinja-render"
VAR_AGGREGATOR = "var-aggregator"
PARAMETER_EXTRACTOR = "parameter-extractor"
LOOP = "loop"
ITERATION = "iteration"
CYCLE_START = "cycle-start"
BREAK = "break"
class ComparisonOperator(StrEnum):
@@ -62,21 +57,6 @@ class AssignmentOperator(StrEnum):
REMOVE_LAST = "remove_last"
REMOVE_FIRST = "remove_first"
@classmethod
def get_operator(cls, obj) -> AssignmentOperatorType:
if isinstance(obj, str):
return StringOperator
elif isinstance(obj, bool):
return BooleanOperator
elif isinstance(obj, (int, float)):
return NumberOperator
elif isinstance(obj, list):
return ArrayOperator
elif isinstance(obj, dict):
return ObjectOperator
raise TypeError(f"Unsupported variable type ({type(obj)})")
class HttpRequestMethod(StrEnum):
GET = "GET"

View File

@@ -30,7 +30,7 @@ class ConditionBranchConfig(BaseModel):
description="Logical operator used to combine multiple condition expressions"
)
conditions: list[ConditionDetail] = Field(
expressions: list[ConditionDetail] = Field(
...,
description="List of condition expressions within this branch"
)
@@ -57,7 +57,7 @@ class IfElseNodeConfig(BaseNodeConfig):
# CASE1 / IF Branch
{
"logical_operator": "and",
"conditions": [
"expressions": [
[
{
"left": "node.userinput.message",
@@ -75,7 +75,7 @@ class IfElseNodeConfig(BaseNodeConfig):
# CASE1 / ELIF Branch
{
"logical_operator": "or",
"conditions": [
"expressions": [
[
{
"left": "node.userinput.test",

View File

@@ -2,93 +2,13 @@ import logging
from typing import Any
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import ComparisonOperator
from app.core.workflow.nodes.if_else import IfElseNodeConfig
from app.core.workflow.nodes.if_else.config import ConditionDetail
from app.core.workflow.nodes.operators import ConditionExpressionBuilder
logger = logging.getLogger(__name__)
class ConditionExpressionBuilder:
"""
Build a Python boolean expression string based on a comparison operator.
This class does not evaluate the expression.
It only generates a valid Python expression string
that can be evaluated later in a workflow context.
"""
def __init__(self, left: str, operator: ComparisonOperator, right: str):
self.left = left
self.operator = operator
self.right = right
def _empty(self):
return f"{self.left} == ''"
def _not_empty(self):
return f"{self.left} != ''"
def _contains(self):
return f"{self.right} in {self.left}"
def _not_contains(self):
return f"{self.right} not in {self.left}"
def _startwith(self):
return f'{self.left}.startswith({self.right})'
def _endwith(self):
return f'{self.left}.endswith({self.right})'
def _eq(self):
return f"{self.left} == {self.right}"
def _ne(self):
return f"{self.left} != {self.right}"
def _lt(self):
return f"{self.left} < {self.right}"
def _le(self):
return f"{self.left} <= {self.right}"
def _gt(self):
return f"{self.left} > {self.right}"
def _ge(self):
return f"{self.left} >= {self.right}"
def build(self):
match self.operator:
case ComparisonOperator.EMPTY:
return self._empty()
case ComparisonOperator.NOT_EMPTY:
return self._not_empty()
case ComparisonOperator.CONTAINS:
return self._contains()
case ComparisonOperator.NOT_CONTAINS:
return self._not_contains()
case ComparisonOperator.START_WITH:
return self._startwith()
case ComparisonOperator.END_WITH:
return self._endwith()
case ComparisonOperator.EQ:
return self._eq()
case ComparisonOperator.NE:
return self._ne()
case ComparisonOperator.LT:
return self._lt()
case ComparisonOperator.LE:
return self._le()
case ComparisonOperator.GT:
return self._gt()
case ComparisonOperator.GE:
return self._ge()
case _:
raise ValueError(f"Invalid condition: {self.operator}")
class IfElseNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
@@ -143,7 +63,7 @@ class IfElseNode(BaseNode):
branch_conditions = [
self._build_condition_expression(condition)
for condition in case_branch.conditions
for condition in case_branch.expressions
]
if len(branch_conditions) > 1:
combined_condition = f' {case_branch.logical_operator} '.join(branch_conditions)

View File

@@ -1,18 +1,13 @@
from uuid import UUID
from pydantic import Field
from pydantic import Field, BaseModel
from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.schemas.chunk_schema import RetrieveType
class KnowledgeRetrievalNodeConfig(BaseNodeConfig):
query: str = Field(
...,
description="Search query string"
)
kb_ids: list[UUID] = Field(
class KnowledgeBaseConfig(BaseModel):
kb_id: UUID = Field(
...,
description="Knowledge base IDs"
)
@@ -37,18 +32,42 @@ class KnowledgeRetrievalNodeConfig(BaseNodeConfig):
description="Retrieve type"
)
class KnowledgeRetrievalNodeConfig(BaseNodeConfig):
query: str = Field(
...,
description="Search query string"
)
knowledge_bases: list[KnowledgeBaseConfig] = Field(
...,
description="Knowledge base config"
)
reranker_id: UUID = Field(
...,
description="Reranker top k"
)
reranker_top_k: int = Field(
default=4,
description="Knowledge base top k"
)
class Config:
json_schema_extra = {
"examples": [
{
"query": "{{sys.message}}",
"kb_ids": [
"xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"
],
"similarity_threshold": 0.2,
"vector_similarity_weight": 0.3,
"top_k": 1,
"retrieve_type": "hybrid"
"knowledge_bases": [{
"kb_id": "xxxxxxxx-xxxx-xxxx-xxxxxxxxxxxxxxxxx",
"similarity_threshold": 0.2,
"vector_similarity_weight": 0.3,
"top_k": 4,
"retrieve_type": "hybrid"
}],
"reranker_top_k": 1,
"reranker_id": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"
}
]
}

View File

@@ -2,14 +2,18 @@ import logging
import uuid
from typing import Any
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.models import RedBearRerank, RedBearModelConfig
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
from app.db import get_db_read
from app.models import knowledge_model, knowledgeshare_model
from app.models import knowledge_model, knowledgeshare_model, ModelType
from app.repositories import knowledge_repository
from app.schemas.chunk_schema import RetrieveType
from app.services import knowledge_service, knowledgeshare_service
from app.services.model_service import ModelConfigService
logger = logging.getLogger(__name__)
@@ -108,6 +112,44 @@ class KnowledgeRetrievalNode(BaseNode):
existing_ids.extend(items)
return existing_ids
def get_reranker_model(self) -> RedBearRerank:
"""
Retrieve and initialize a RedBear reranker model based on configuration.
Raises:
BusinessException: If configuration is missing or API keys are not set.
RuntimeError: If the configured model is not of type RERANK.
"""
with get_db_read() as db:
config = ModelConfigService.get_model_by_id(db=db, model_id=self.typed_config.reranker_id)
if not config:
raise BusinessException("Configured model does not exist", BizCode.NOT_FOUND)
if not config.api_keys or len(config.api_keys) == 0:
raise BusinessException("Model configuration is missing API Key", BizCode.INVALID_PARAMETER)
# 在 Session 关闭前提取所有需要的数据
api_config = config.api_keys[0]
model_name = api_config.model_name
provider = api_config.provider
api_key = api_config.api_key
api_base = api_config.api_base
model_type = config.type
if model_type != ModelType.RERANK:
raise RuntimeError("Model is not a reranker")
reranker = RedBearRerank(
RedBearModelConfig(
model_name=model_name,
provider=provider,
api_key=api_key,
base_url=api_base,
)
)
return reranker
async def execute(self, state: WorkflowState) -> Any:
"""
Execute the knowledge retrieval workflow node.
@@ -131,38 +173,41 @@ class KnowledgeRetrievalNode(BaseNode):
"""
query = self._render_template(self.typed_config.query, state)
with get_db_read() as db:
existing_ids = self._get_existing_kb_ids(db, self.typed_config.kb_ids)
knowledge_bases = self.typed_config.knowledge_bases
existing_ids = self._get_existing_kb_ids(db, [kb.kb_id for kb in knowledge_bases])
if not existing_ids:
raise RuntimeError("Knowledge base retrieval failed: the knowledge base does not exist.")
kb_id = existing_ids[0]
uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in existing_ids]
indices = ",".join(uuid_strs)
rs = []
for kb_config in knowledge_bases:
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id)
if not db_knowledge:
raise RuntimeError("The knowledge base does not exist or access is denied.")
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_id)
if not db_knowledge:
raise RuntimeError("The knowledge base does not exist or access is denied.")
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
match self.typed_config.retrieve_type:
case RetrieveType.PARTICIPLE:
rs = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k,
indices=indices,
score_threshold=self.typed_config.similarity_threshold)
case RetrieveType.SEMANTIC:
rs = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k,
indices=indices,
score_threshold=self.typed_config.vector_similarity_weight)
case _:
rs1 = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k,
indices=indices,
score_threshold=self.typed_config.vector_similarity_weight)
rs2 = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k,
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
indices = f"Vector_index_{kb_config.kb_id}_Node".lower()
match kb_config.retrieve_type:
case RetrieveType.PARTICIPLE:
rs.extend(vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
indices=indices,
score_threshold=kb_config.similarity_threshold))
case RetrieveType.SEMANTIC:
rs.extend(vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
indices=indices,
score_threshold=self.typed_config.similarity_threshold)
# Deduplicate hybrid retrieval results
unique_rs = self._deduplicate_docs(rs1, rs2)
rs = vector_service.rerank(query=query, docs=unique_rs, top_k=self.typed_config.top_k)
return [chunk.model_dump() for chunk in rs]
score_threshold=kb_config.vector_similarity_weight))
case RetrieveType.HYBRID:
rs1 = vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
indices=indices,
score_threshold=kb_config.vector_similarity_weight)
rs2 = vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
indices=indices,
score_threshold=kb_config.similarity_threshold)
# Deduplicate hybrid retrieval results
unique_rs = self._deduplicate_docs(rs1, rs2)
vector_service.reranker = self.get_reranker_model()
rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
case _:
raise RuntimeError("Unknown retrieval type")
final_rs = vector_service.rerank(query=query, docs=rs, top_k=kb_config.top_k)
return [chunk.model_dump() for chunk in final_rs]

View File

@@ -10,6 +10,7 @@ from typing import Any, Union
from app.core.workflow.nodes.agent import AgentNode
from app.core.workflow.nodes.assigner import AssignerNode
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.cycle_graph.node import CycleGraphNode
from app.core.workflow.nodes.end import EndNode
from app.core.workflow.nodes.enums import NodeType
from app.core.workflow.nodes.http_request import HttpRequestNode
@@ -22,6 +23,7 @@ from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.transform import TransformNode
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode
from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
from app.core.workflow.nodes.breaker import BreakNode
logger = logging.getLogger(__name__)
@@ -39,6 +41,9 @@ WorkflowNode = Union[
JinjaRenderNode,
VariableAggregatorNode,
ParameterExtractorNode,
CycleGraphNode,
BreakNode,
ParameterExtractorNode,
QuestionClassifierNode
]
@@ -64,6 +69,9 @@ class NodeFactory:
NodeType.VAR_AGGREGATOR: VariableAggregatorNode,
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode,
NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
NodeType.LOOP: CycleGraphNode,
NodeType.ITERATION: CycleGraphNode,
NodeType.BREAK: BreakNode,
}
@classmethod

View File

@@ -1,6 +1,7 @@
from abc import ABC
from typing import Union, Type
from app.core.workflow.nodes.enums import ComparisonOperator
from app.core.workflow.variable_pool import VariablePool
@@ -136,6 +137,23 @@ class ObjectOperator(OperatorBase):
self.pool.set(self.left_selector, dict())
class AssignmentOperatorResolver:
@classmethod
def resolve_by_value(cls, value):
if isinstance(value, str):
return StringOperator
elif isinstance(value, bool):
return BooleanOperator
elif isinstance(value, (int, float)):
return NumberOperator
elif isinstance(value, list):
return ArrayOperator
elif isinstance(value, dict):
return ObjectOperator
else:
raise TypeError(f"Unsupported variable type: {type(value)}")
AssignmentOperatorInstance = Union[
StringOperator,
NumberOperator,
@@ -144,3 +162,83 @@ AssignmentOperatorInstance = Union[
ObjectOperator
]
AssignmentOperatorType = Type[AssignmentOperatorInstance]
class ConditionExpressionBuilder:
"""
Build a Python boolean expression string based on a comparison operator.
This class does not evaluate the expression.
It only generates a valid Python expression string
that can be evaluated later in a workflow context.
"""
def __init__(self, left: str, operator: ComparisonOperator, right: str):
self.left = left
self.operator = operator
self.right = right
def _empty(self):
return f"{self.left} == ''"
def _not_empty(self):
return f"{self.left} != ''"
def _contains(self):
return f"{self.right} in {self.left}"
def _not_contains(self):
return f"{self.right} not in {self.left}"
def _startswith(self):
return f'{self.left}.startswith({self.right})'
def _endswith(self):
return f'{self.left}.endswith({self.right})'
def _eq(self):
return f"{self.left} == {self.right}"
def _ne(self):
return f"{self.left} != {self.right}"
def _lt(self):
return f"{self.left} < {self.right}"
def _le(self):
return f"{self.left} <= {self.right}"
def _gt(self):
return f"{self.left} > {self.right}"
def _ge(self):
return f"{self.left} >= {self.right}"
def build(self):
match self.operator:
case ComparisonOperator.EMPTY:
return self._empty()
case ComparisonOperator.NOT_EMPTY:
return self._not_empty()
case ComparisonOperator.CONTAINS:
return self._contains()
case ComparisonOperator.NOT_CONTAINS:
return self._not_contains()
case ComparisonOperator.START_WITH:
return self._startswith()
case ComparisonOperator.END_WITH:
return self._endswith()
case ComparisonOperator.EQ:
return self._eq()
case ComparisonOperator.NE:
return self._ne()
case ComparisonOperator.LT:
return self._lt()
case ComparisonOperator.LE:
return self._le()
case ComparisonOperator.GT:
return self._gt()
case ComparisonOperator.GE:
return self._ge()
case _:
raise ValueError(f"Invalid condition: {self.operator}")

View File

@@ -36,6 +36,11 @@ class ParamsConfig(BaseModel):
description="Description of the parameter"
)
required: bool = Field(
...,
description="Whether the parameter is required"
)
class ParameterExtractorNodeConfig(BaseNodeConfig):
model_id: uuid.UUID = Field(
@@ -52,3 +57,8 @@ class ParameterExtractorNodeConfig(BaseNodeConfig):
...,
description="List of parameters"
)
prompt: str = Field(
...,
description="User-provided supplemental prompt"
)

View File

@@ -1,4 +1,5 @@
import os
import logging
import json_repair
from typing import Any
@@ -15,6 +16,8 @@ from app.db import get_db_read
from app.models import ModelType
from app.services.model_service import ModelConfigService
logger = logging.getLogger(__name__)
class ParameterExtractorNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
@@ -114,7 +117,7 @@ class ParameterExtractorNode(BaseNode):
"""
field_type = {}
for param in self.typed_config.params:
field_type[param.name] = param.type
field_type[param.name] = f'{param.type}, required:{str(param.required)}'
return field_type
async def execute(self, state: WorkflowState) -> Any:
@@ -154,12 +157,12 @@ class ParameterExtractorNode(BaseNode):
messages = [
("system", system_prompt),
("user", self._render_template(self.typed_config.prompt, state)),
("user", rendered_user_prompt),
]
model_resp = await llm.ainvoke(messages)
result = json_repair.repair_json(model_resp.content)
result = json_repair.repair_json(model_resp.content, return_objects=True)
logger.info(f"get prarms:{result}")
return {
"output": result,
}
return result

View File

@@ -9,43 +9,27 @@ class VariableAggregatorNodeConfig(BaseNodeConfig):
description="输出变量是否需要分组",
)
group_names: list[str] = Field(
default_factory=lambda: ["output"],
description="各个分组的名称"
)
group_variables: list[str] | list[list[str]] = Field(
group_variables: list[str] | dict[str, list[str]] = Field(
...,
description="需要被聚合的变量"
)
@field_validator("group_names", mode="before")
@classmethod
def group_names_validator(cls, v, info):
group_status = info.data.get("group")
if not group_status or not v:
return ["output"]
return v
@field_validator("group_variables")
@classmethod
def group_variables_validator(cls, v, info):
group_status = info.data.get("group")
group_names = info.data.get("group_names")
if not isinstance(v, list):
raise ValueError("group_variables must be a list")
if not group_status:
for variable in v:
if not isinstance(variable, str):
raise ValueError("When group=False, group_variables must be a list of strings")
else:
if len(group_names) != len(v):
raise ValueError("group_names and group_variables length mismatch")
for group in v:
if not isinstance(group, list):
if not isinstance(v, dict):
raise ValueError("When group=True, group_variables must be a dict")
for group_name, group_values in v.items():
if not isinstance(group_name, str):
raise ValueError("When group=True, each element of group_variables must be a list")
for variable in group:
for variable in group_values:
if not isinstance(variable, str):
raise ValueError("Each element inside group_variables lists must be a string")
return v

View File

@@ -59,7 +59,7 @@ class VariableAggregatorNode(BaseNode):
# Group mode
# --------------------------
result = {}
for group_name, variables in zip(self.typed_config.group_names, self.typed_config.group_variables):
for group_name, variables in self.typed_config.group_variables.items():
for variable in variables:
var_express = self._get_express(variable)
try:

View File

@@ -198,19 +198,22 @@ class VariablePool:
namespace = selector[0]
if namespace != "conv":
raise ValueError("只能设置会话变量 (conv.*)")
if namespace != "conv" and namespace not in self.state["cycle_nodes"]:
raise ValueError("Only conversation or cycle variables can be assigned.")
key = selector[1]
# 确保 variables 结构存在
if "variables" not in self.state:
self.state["variables"] = {"sys": {}, "conv": {}}
if "conv" not in self.state["variables"]:
self.state["variables"]["conv"] = {}
# 设置值
self.state["variables"]["conv"][key] = value
if namespace == "conv":
if "conv" not in self.state["variables"]:
self.state["variables"]["conv"] = {}
# 设置值
self.state["variables"]["conv"][key] = value
elif namespace in self.state["cycle_nodes"]:
self.state["runtime_vars"][namespace][key] = value
logger.debug(f"设置变量: {'.'.join(selector)} = {value}")

View File

@@ -20,6 +20,7 @@ class NodeDefinition(BaseModel):
id: str = Field(..., description="节点唯一标识")
type: str = Field(..., description="节点类型: start, end, llm, agent, tool, condition, loop, transform, human, code")
name: str | None = Field(None, description="节点名称")
cycle: str | None = Field(None, description="父循环节点id")
description: str | None = Field(None, description="节点描述")
config: dict[str, Any] = Field(default_factory=dict, description="节点配置")
position: dict[str, float] | None = Field(None, description="节点位置 {x, y}")