Merge #80 into develop from feature/20251219_myh

feat(workflow): support cycle nodes in workflow config validation and enhance node logging

* feature/20251219_myh: (11 commits squashed)

  - feat(workflow): update reranker model configuration for knowledge base retrieval

  - fix(workflow): fix output issue in parameter extraction node

  - fix(workflow): fix output issue in parameter extraction node

  - feat(workflow): add user prompt to parameter extraction node

  - perf(workflow): change grouped variable input to key-value format in variable aggregator

  - feat(workflow): Add new cycle node for iterative workflow execution
    
    - Introduce a new Loop/Iteration node in the workflow engine.
    - Supports both conditional loops and iteration over lists.
    - Allows parallel execution and flattening of iteration outputs.
    - Maintains runtime state, node outputs, and loop variables for downstream nodes.
    - Enhances workflow flexibility for complex, repeated operations.

  - Merge branch 'develop' into feature/20251219_myh
    
    # Conflicts:
    #	api/app/core/workflow/nodes/configs.py
    #	api/app/core/workflow/nodes/node_factory.py

  - feat(workflow): Add new cycle node for iterative workflow execution
    
    - Introduce a new Loop/Iteration node in the workflow engine.
    - Supports both conditional loops and iteration over lists.
    - Allows parallel execution and flattening of iteration outputs.
    - Maintains runtime state, node outputs, and loop variables for downstream nodes.
    - Enhances workflow flexibility for complex, repeated operations.

  - feat(workflow): support cycle nodes in workflow config validation and enhance node logging

  - feat(workflow): support cycle nodes in workflow config validation and enhance node logging

  - fix(workflow): fix compatibility with some legacy node configurations

Signed-off-by: Eternity <1533512157@qq.com>
Reviewed-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com>
Reviewed-by: aliyun6762716068 <accounts_68cb7c6b61f5dcc4200d6251@mail.teambition.com>
Merged-by: aliyun6762716068 <accounts_68cb7c6b61f5dcc4200d6251@mail.teambition.com>

CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/80
This commit is contained in:
孟永豪
2025-12-30 03:40:38 +00:00
committed by 孙科
parent b376c3d648
commit 9bedcadca4
13 changed files with 221 additions and 114 deletions

View File

@@ -83,3 +83,5 @@ class AssignerNode(BaseNode):
operator.remove_last()
case _:
raise ValueError(f"Invalid Operator: {assignment.operation}")
logger.info(f"Node {self.node_id}: execution completed")

View File

@@ -78,6 +78,7 @@ class BaseNode(ABC):
self.workflow_config = workflow_config
self.node_id = node_config["id"]
self.node_type = node_config["type"]
self.cycle = node_config.get("cycle")
self.node_name = node_config.get("name", self.node_id)
# 使用 or 运算符处理 None 值
self.config = node_config.get("config") or {}

View File

@@ -29,4 +29,5 @@ class BreakNode(BaseNode):
Optional dictionary indicating the loop has been stopped.
"""
state["looping"] = False
logger.info(f"run break node, looping={state['looping']}")
logger.info(f"Setting cycle node exit flag, cycle={self.cycle}, looping={state['looping']}")

View File

@@ -1,5 +1,6 @@
import asyncio
import copy
import logging
import re
from typing import Any
@@ -9,6 +10,8 @@ from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.cycle_graph import IterationNodeConfig
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
class IterationRuntime:
"""
@@ -127,12 +130,15 @@ class IterationRuntime:
# Execute iterations in parallel batches
while idx < len(array_obj) and self.looping:
tasks = self._create_iteration_tasks(array_obj, idx)
logger.info(f"Iteration node {self.node_id}: running, concurrency {len(tasks)}")
idx += self.typed_config.parallel_count
await asyncio.gather(*tasks)
logger.info(f"Iteration node {self.node_id}: execution completed")
return self.result
else:
# Execute iterations sequentially
while idx < len(array_obj) and self.looping:
logger.info(f"Iteration node {self.node_id}: running")
item = array_obj[idx]
result = await self.graph.ainvoke(self._init_iteration_state(item, idx))
output = VariablePool(result).get(self.output_value)
@@ -143,4 +149,6 @@ class IterationRuntime:
if not result["looping"]:
self.looping = False
idx += 1
logger.info(f"Iteration node {self.node_id}: execution completed")
return self.result

View File

@@ -1,3 +1,4 @@
import logging
from typing import Any
from langgraph.graph.state import CompiledStateGraph
@@ -8,6 +9,8 @@ from app.core.workflow.nodes.cycle_graph import LoopNodeConfig
from app.core.workflow.nodes.operators import ConditionExpressionBuilder
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
class LoopRuntime:
"""
@@ -119,6 +122,9 @@ class LoopRuntime:
node_outputs=loop_variable_pool.get_all_node_outputs(),
system_vars=loop_variable_pool.get_all_system_vars(),
) and loopstate["looping"] and loop_time > 0:
logger.info(f"loop node {self.node_id}: running")
await self.graph.ainvoke(loopstate)
loop_time -= 1
logger.info(f"loop node {self.node_id}: execution completed")
return loopstate["runtime_vars"][self.node_id]

View File

@@ -65,7 +65,10 @@ class CycleGraphNode(BaseNode):
# Raise error if cycle nodes are connected with external nodes
if source_in ^ target_in:
raise ValueError(f"循环节点与外部节点存在连接,soruce: {edge.get("source")}, target:{edge.get("target")}")
raise ValueError(
f"Cycle node is connected to external node, "
f"source: {edge.get('source')}, target: {edge.get('target')}"
)
if source_in and target_in:
cycle_edges.append(edge)
@@ -220,4 +223,4 @@ class CycleGraphNode(BaseNode):
config=self.config,
state=state,
).run()
raise RuntimeError("未知循环节点类型")
raise RuntimeError("Unknown cycle node type")

View File

@@ -215,6 +215,7 @@ class HttpRequestNode(BaseNode):
**self._build_content(state)
)
resp.raise_for_status()
logger.info(f"Node {self.node_id}: HTTP request succeeded")
return HttpRequestNodeOutput(
body=resp.text,
status_code=resp.status_code,
@@ -228,12 +229,21 @@ class HttpRequestNode(BaseNode):
else:
match self.typed_config.error_handle.method:
case HttpErrorHandle.NONE:
logger.warning(
f"Node {self.node_id}: HTTP request failed, returning error response"
)
return HttpRequestNodeOutput(
body="",
status_code=resp.status_code,
headers=resp.headers,
).model_dump()
case HttpErrorHandle.DEFAULT:
logger.warning(
f"Node {self.node_id}: HTTP request failed, returning default result"
)
return self.typed_config.error_handle.default.model_dump()
case HttpErrorHandle.BRANCH:
logger.warning(
f"Node {self.node_id}: HTTP request failed, switching to error handling branch"
)
return "ERROR"

View File

@@ -94,5 +94,6 @@ class IfElseNode(BaseNode):
for i in range(len(expressions)):
logger.info(expressions[i])
if self._evaluate_condition(expressions[i], state):
logger.info(f"Node {self.node_id}: switched to branch CASE {i + 1}")
return f'CASE{i + 1}'
return f'CASE{len(expressions)}'

View File

@@ -1,3 +1,4 @@
import logging
from typing import Any
from app.core.workflow.nodes import WorkflowState
@@ -5,6 +6,7 @@ from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.jinja_render.config import JinjaRenderNodeConfig
from app.core.workflow.template_renderer import TemplateRenderer
logger = logging.getLogger(__name__)
class JinjaRenderNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
@@ -41,5 +43,5 @@ class JinjaRenderNode(BaseNode):
res = render.env.from_string(self.typed_config.template).render(**context)
except Exception as e:
raise RuntimeError(f"JinjaRender Node {self.node_name} render failed: {e}") from e
logger.info(f"Node {self.node_id}: Jinja template rendering completed")
return res

View File

@@ -190,12 +190,12 @@ class KnowledgeRetrievalNode(BaseNode):
match kb_config.retrieve_type:
case RetrieveType.PARTICIPLE:
rs.extend(vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
indices=indices,
score_threshold=kb_config.similarity_threshold))
indices=indices,
score_threshold=kb_config.similarity_threshold))
case RetrieveType.SEMANTIC:
rs.extend(vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
indices=indices,
score_threshold=kb_config.vector_similarity_weight))
indices=indices,
score_threshold=kb_config.vector_similarity_weight))
case RetrieveType.HYBRID:
rs1 = vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
indices=indices,
@@ -209,5 +209,9 @@ class KnowledgeRetrievalNode(BaseNode):
rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
case _:
raise RuntimeError("Unknown retrieval type")
final_rs = vector_service.rerank(query=query, docs=rs, top_k=kb_config.top_k)
vector_service.reranker = self.get_reranker_model()
final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
logger.info(
f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}"
)
return [chunk.model_dump() for chunk in final_rs]

View File

@@ -163,6 +163,6 @@ class ParameterExtractorNode(BaseNode):
model_resp = await llm.ainvoke(messages)
result = json_repair.repair_json(model_resp.content, return_objects=True)
logger.info(f"get prarms:{result}")
logger.info(f"node: {self.node_id} get params:{result}")
return result

View File

@@ -50,6 +50,7 @@ class VariableAggregatorNode(BaseNode):
continue
if value is not None:
logger.info(f"Node: {self.node_id} variable aggregation result: {value}")
return value
logger.info("No variable found in non-group mode; returning empty string.")
@@ -74,5 +75,5 @@ class VariableAggregatorNode(BaseNode):
else:
result[group_name] = ""
logger.info(f"No variable found for group '{group_name}'; set empty string.")
logger.info(f"Node: {self.node_id} variable aggregation result: {result}")
return result

View File

@@ -7,14 +7,87 @@
import logging
from typing import Any, Union
from app.core.workflow.nodes.enums import NodeType
logger = logging.getLogger(__name__)
class WorkflowValidator:
"""工作流配置验证器"""
@staticmethod
def validate(workflow_config: Union[dict[str, Any], Any]) -> tuple[bool, list[str]]:
@classmethod
def pure_cycle_graph(cls, workflow_config: Union[dict[str, Any], Any], node_id) -> tuple[list, list]:
"""
Extract cycle nodes and internal edges from the workflow configuration,
removing them from the global workflow.
Raises:
ValueError: If cycle nodes are connected to external nodes improperly.
Returns:
Tuple containing:
- cycle_nodes: List of removed nodes
- cycle_edges: List of removed edges
"""
nodes = workflow_config.get("nodes", [])
edges = workflow_config.get("edges", [])
# Select all nodes that belong to the current cycle
cycle_nodes = [node for node in nodes if node.get("cycle") == node_id]
cycle_node_ids = {node.get("id") for node in cycle_nodes}
cycle_edges = []
remain_edges = []
for edge in edges:
source_in = edge.get("source") in cycle_node_ids
target_in = edge.get("target") in cycle_node_ids
# Raise error if cycle nodes are connected with external nodes
if source_in ^ target_in:
raise ValueError(
f"Cycle node is connected to external node, "
f"source: {edge.get('source')}, target: {edge.get('target')}"
)
if source_in and target_in:
cycle_edges.append(edge)
else:
remain_edges.append(edge)
# Update workflow_config by removing cycle nodes and internal edges
workflow_config["nodes"] = [
node for node in nodes if node.get("cycle") != node_id
]
workflow_config["edges"] = remain_edges
return cycle_nodes, cycle_edges
@classmethod
def get_subgraph(cls, workflow_config: Union[dict[str, Any], Any]) -> list:
    """
    Split a workflow configuration into separately checkable graphs.

    For every node whose type is LOOP or ITERATION, the nodes and edges
    belonging to that cycle are pulled out into their own
    ``{"nodes": ..., "edges": ...}`` graph. The remaining top-level
    workflow (including its "variables", when present) is appended last.

    The caller's configuration is not modified: extraction runs on a
    private shallow copy, because ``pure_cycle_graph`` rebinds the
    "nodes" and "edges" entries of the mapping it is given.

    Args:
        workflow_config: Either a dict with "nodes" / "edges" /
            "variables" entries, or an object exposing those three names
            as attributes.

    Raises:
        ValueError: Propagated from ``pure_cycle_graph`` when an edge
            links a cycle member to a node outside that cycle.

    Returns:
        List of graph dicts: one per cycle node, followed by the
        top-level graph with all cycle members removed.
    """
    if isinstance(workflow_config, dict):
        # Shallow copy so that rebinding "nodes"/"edges" below does not
        # strip the cycle members out of the caller's own configuration.
        workflow_config = dict(workflow_config)
    else:
        workflow_config = {
            "nodes": workflow_config.nodes,
            "edges": workflow_config.edges,
            "variables": workflow_config.variables,
        }

    # Collect the ids first: the node list is rebound during extraction.
    cycle_node_ids = [
        node.get("id")
        for node in workflow_config.get("nodes", [])
        if node.get("type") in (NodeType.LOOP, NodeType.ITERATION)
    ]

    graphs = []
    for cycle_node_id in cycle_node_ids:
        nodes, edges = cls.pure_cycle_graph(workflow_config, cycle_node_id)
        graphs.append({
            "nodes": nodes,
            "edges": edges,
        })

    # What is left after removing every cycle is the top-level graph.
    graphs.append(workflow_config)
    return graphs
@classmethod
def validate(cls, workflow_config: Union[dict[str, Any], Any]) -> tuple[bool, list[str]]:
"""验证工作流配置
Args:
@@ -38,84 +111,79 @@ class WorkflowValidator:
True
"""
errors = []
# 支持字典和 Pydantic 模型
if isinstance(workflow_config, dict):
nodes = workflow_config.get("nodes", [])
edges = workflow_config.get("edges", [])
variables = workflow_config.get("variables", [])
else:
# Pydantic 模型
nodes = getattr(workflow_config, "nodes", [])
edges = getattr(workflow_config, "edges", [])
variables = getattr(workflow_config, "variables", [])
# 1. 验证 start 节点(有且只有一个)
start_nodes = [n for n in nodes if n.get("type") == "start"]
if len(start_nodes) == 0:
errors.append("工作流必须有一个 start 节点")
elif len(start_nodes) > 1:
errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)}")
# 2. 验证 end 节点(至少一个)
end_nodes = [n for n in nodes if n.get("type") == "end"]
if len(end_nodes) == 0:
errors.append("工作流必须至少有一个 end 节点")
# 3. 验证节点 ID 唯一性
node_ids = [n.get("id") for n in nodes]
if len(node_ids) != len(set(node_ids)):
duplicates = [nid for nid in node_ids if node_ids.count(nid) > 1]
errors.append(f"节点 ID 必须唯一,重复的 ID: {set(duplicates)}")
# 4. 验证节点必须有 id 和 type
for i, node in enumerate(nodes):
if not node.get("id"):
errors.append(f"节点 #{i} 缺少 id 字段")
if not node.get("type"):
errors.append(f"节点 #{i} (id={node.get('id', 'unknown')}) 缺少 type 字段")
# 5. 验证边的有效性
node_id_set = set(node_ids)
for i, edge in enumerate(edges):
source = edge.get("source")
target = edge.get("target")
if not source:
errors.append(f"边 #{i} 缺少 source 字段")
elif source not in node_id_set:
errors.append(f"边 #{i}source 节点不存在: {source}")
if not target:
errors.append(f"边 #{i} 缺少 target 字段")
elif target not in node_id_set:
errors.append(f"边 #{i} 的 target 节点不存在: {target}")
# 6. 验证所有节点可达(从 start 节点出发)
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
reachable = WorkflowValidator._get_reachable_nodes(
start_nodes[0]["id"],
edges
)
unreachable = node_id_set - reachable
if unreachable:
errors.append(f"以下节点无法从 start 节点到达: {unreachable}")
# 7. 检测循环依赖(非 loop 节点)
if not errors: # 只有在前面验证通过时才检查循环
has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges)
if has_cycle:
errors.append(
f"工作流存在循环依赖(请使用 loop 节点实现循环): {' -> '.join(cycle_path)}"
graphs = cls.get_subgraph(workflow_config)
logger.info(graphs)
for graph in graphs:
nodes = graph.get("nodes", [])
edges = graph.get("edges", [])
variables = graph.get("variables", [])
# 1. 验证 start 节点(有且只有一个)
start_nodes = [n for n in nodes if n.get("type") in [NodeType.START, NodeType.CYCLE_START]]
if len(start_nodes) == 0:
errors.append("工作流必须有一个 start 节点")
elif len(start_nodes) > 1:
errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)}")
# 2. 验证 end 节点(至少一个)
end_nodes = [n for n in nodes if n.get("type") == NodeType.END]
if len(end_nodes) == 0:
errors.append("工作流必须至少有一个 end 节点")
# 3. 验证节点 ID 唯一性
node_ids = [n.get("id") for n in nodes]
if len(node_ids) != len(set(node_ids)):
duplicates = [nid for nid in node_ids if node_ids.count(nid) > 1]
errors.append(f"节点 ID 必须唯一,重复的 ID: {set(duplicates)}")
# 4. 验证节点必须有 id 和 type
for i, node in enumerate(nodes):
if not node.get("id"):
errors.append(f"节点 #{i} 缺少 id 字段")
if not node.get("type"):
errors.append(f"节点 #{i} (id={node.get('id', 'unknown')}) 缺少 type 字段")
# 5. 验证边的有效性
node_id_set = set(node_ids)
for i, edge in enumerate(edges):
source = edge.get("source")
target = edge.get("target")
if not source:
errors.append(f"边 #{i} 缺少 source 字段")
elif source not in node_id_set:
errors.append(f"边 #{i} 的 source 节点不存在: {source}")
if not target:
errors.append(f"边 #{i} 缺少 target 字段")
elif target not in node_id_set:
errors.append(f"边 #{i}target 节点不存在: {target}")
# 6. 验证所有节点可达(从 start 节点出发)
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
reachable = WorkflowValidator._get_reachable_nodes(
start_nodes[0]["id"],
edges
)
# 8. 验证变量名
from app.core.workflow.expression_evaluator import ExpressionEvaluator
var_errors = ExpressionEvaluator.validate_variable_names(variables)
errors.extend(var_errors)
unreachable = node_id_set - reachable
if unreachable:
errors.append(f"以下节点无法从 start 节点到达: {unreachable}")
# 7. 检测循环依赖(非 loop 节点)
if not errors: # 只有在前面验证通过时才检查循环
has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges)
if has_cycle:
errors.append(
f"工作流存在循环依赖(请使用 loop 节点实现循环): {' -> '.join(cycle_path)}"
)
# 8. 验证变量名
from app.core.workflow.expression_evaluator import ExpressionEvaluator
var_errors = ExpressionEvaluator.validate_variable_names(variables)
errors.extend(var_errors)
return len(errors) == 0, errors
@staticmethod
def _get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]:
"""获取从 start 节点可达的所有节点
@@ -129,7 +197,7 @@ class WorkflowValidator:
"""
reachable = {start_id}
queue = [start_id]
while queue:
current = queue.pop(0)
for edge in edges:
@@ -138,9 +206,9 @@ class WorkflowValidator:
if target and target not in reachable:
reachable.add(target)
queue.append(target)
return reachable
@staticmethod
def _has_cycle(nodes: list[dict], edges: list[dict]) -> tuple[bool, list[str]]:
"""检测是否存在循环依赖DFS
@@ -154,39 +222,39 @@ class WorkflowValidator:
"""
# 排除 loop 类型的节点
loop_nodes = {n["id"] for n in nodes if n.get("type") == "loop"}
# 构建邻接表(排除 loop 节点的边和错误边)
graph: dict[str, list[str]] = {}
for edge in edges:
source = edge.get("source")
target = edge.get("target")
edge_type = edge.get("type")
# 跳过错误边
if edge_type == "error":
continue
# 如果涉及 loop 节点,跳过
if source in loop_nodes or target in loop_nodes:
continue
if source and target:
if source not in graph:
graph[source] = []
graph[source].append(target)
# DFS 检测环
visited = set()
rec_stack = set()
path = []
cycle_path = []
def dfs(node: str) -> bool:
"""DFS 检测环,返回是否找到环"""
visited.add(node)
rec_stack.add(node)
path.append(node)
for neighbor in graph.get(node, []):
if neighbor not in visited:
if dfs(neighbor):
@@ -196,19 +264,19 @@ class WorkflowValidator:
cycle_start = path.index(neighbor)
cycle_path.extend([*path[cycle_start:], neighbor])
return True
rec_stack.remove(node)
path.pop()
return False
# 检查所有节点
for node_id in graph:
if node_id not in visited:
if dfs(node_id):
return True, cycle_path
return False, []
@staticmethod
def validate_for_publish(workflow_config: dict[str, Any]) -> tuple[bool, list[str]]:
"""验证工作流配置是否可以发布(更严格的验证)
@@ -221,30 +289,30 @@ class WorkflowValidator:
"""
# 先执行基础验证
is_valid, errors = WorkflowValidator.validate(workflow_config)
if not is_valid:
return False, errors
# 额外的发布验证
nodes = workflow_config.get("nodes", [])
# 1. 验证所有节点都有名称
for node in nodes:
if node.get("type") not in ["start", "end"] and not node.get("name"):
if node.get("type") not in [NodeType.START, NodeType.CYCLE_START, NodeType.END] and not node.get("name"):
errors.append(
f"节点 {node.get('id')} 缺少名称(发布时必须提供)"
)
# 2. 验证所有非 start/end 节点都有配置
for node in nodes:
node_type = node.get("type")
if node_type not in ["start", "end"]:
if node_type not in [NodeType.START, NodeType.CYCLE_START, NodeType.END, NodeType.BREAK]:
config = node.get("config")
if not config or not isinstance(config, dict):
errors.append(
f"节点 {node.get('id')} 缺少配置(发布时必须提供)"
)
# 3. 验证必填变量
variables = workflow_config.get("variables", [])
required_vars = [v for v in variables if v.get("required")]
@@ -254,13 +322,13 @@ class WorkflowValidator:
f"工作流包含 {len(required_vars)} 个必填变量: "
f"{[v.get('name') for v in required_vars]}"
)
return len(errors) == 0, errors
def validate_workflow_config(
workflow_config: dict[str, Any],
for_publish: bool = False
workflow_config: dict[str, Any],
for_publish: bool = False
) -> tuple[bool, list[str]]:
"""验证工作流配置(便捷函数)