Merge #80 into develop from feature/20251219_myh

feat(workflow): support cycle nodes in workflow config validation and enhance node logging

* feature/20251219_myh: (11 commits squashed)

  - feat(workflow): update reranker model configuration for knowledge base retrieval

  - fix(workflow): fix output issue in parameter extraction node

  - fix(workflow): fix output issue in parameter extraction node

  - feat(workflow): add user prompt to parameter extraction node

  - perf(workflow): change grouped variable input to key-value format in variable aggregator

  - feat(workflow): Add new cycle node for iterative workflow execution
    
    - Introduce a new Loop/Iteration node in the workflow engine.
    - Supports both conditional loops and iteration over lists.
    - Allows parallel execution and flattening of iteration outputs.
    - Maintains runtime state, node outputs, and loop variables for downstream nodes.
    - Enhances workflow flexibility for complex, repeated operations.

  - Merge branch 'develop' into feature/20251219_myh
    
    # Conflicts:
    #	api/app/core/workflow/nodes/configs.py
    #	api/app/core/workflow/nodes/node_factory.py

  - feat(workflow): Add new cycle node for iterative workflow execution
    
    - Introduce a new Loop/Iteration node in the workflow engine.
    - Supports both conditional loops and iteration over lists.
    - Allows parallel execution and flattening of iteration outputs.
    - Maintains runtime state, node outputs, and loop variables for downstream nodes.
    - Enhances workflow flexibility for complex, repeated operations.

  - feat(workflow): support cycle nodes in workflow config validation and enhance node logging

  - feat(workflow): support cycle nodes in workflow config validation and enhance node logging

  - fix(workflow): fix compatibility with some legacy node configurations

Signed-off-by: Eternity <1533512157@qq.com>
Reviewed-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com>
Reviewed-by: aliyun6762716068 <accounts_68cb7c6b61f5dcc4200d6251@mail.teambition.com>
Merged-by: aliyun6762716068 <accounts_68cb7c6b61f5dcc4200d6251@mail.teambition.com>

CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/80
This commit is contained in:
孟永豪
2025-12-30 03:40:38 +00:00
committed by 孙科
parent b376c3d648
commit 9bedcadca4
13 changed files with 221 additions and 114 deletions

View File

@@ -83,3 +83,5 @@ class AssignerNode(BaseNode):
operator.remove_last() operator.remove_last()
case _: case _:
raise ValueError(f"Invalid Operator: {assignment.operation}") raise ValueError(f"Invalid Operator: {assignment.operation}")
logger.info(f"Node {self.node_id}: execution completed")

View File

@@ -78,6 +78,7 @@ class BaseNode(ABC):
self.workflow_config = workflow_config self.workflow_config = workflow_config
self.node_id = node_config["id"] self.node_id = node_config["id"]
self.node_type = node_config["type"] self.node_type = node_config["type"]
self.cycle = node_config.get("cycle")
self.node_name = node_config.get("name", self.node_id) self.node_name = node_config.get("name", self.node_id)
# 使用 or 运算符处理 None 值 # 使用 or 运算符处理 None 值
self.config = node_config.get("config") or {} self.config = node_config.get("config") or {}

View File

@@ -29,4 +29,5 @@ class BreakNode(BaseNode):
Optional dictionary indicating the loop has been stopped. Optional dictionary indicating the loop has been stopped.
""" """
state["looping"] = False state["looping"] = False
logger.info(f"run break node, looping={state['looping']}") logger.info(f"Setting cycle node exit flag, cycle={self.cycle}, looping={state['looping']}")

View File

@@ -1,5 +1,6 @@
import asyncio import asyncio
import copy import copy
import logging
import re import re
from typing import Any from typing import Any
@@ -9,6 +10,8 @@ from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.cycle_graph import IterationNodeConfig from app.core.workflow.nodes.cycle_graph import IterationNodeConfig
from app.core.workflow.variable_pool import VariablePool from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
class IterationRuntime: class IterationRuntime:
""" """
@@ -127,12 +130,15 @@ class IterationRuntime:
# Execute iterations in parallel batches # Execute iterations in parallel batches
while idx < len(array_obj) and self.looping: while idx < len(array_obj) and self.looping:
tasks = self._create_iteration_tasks(array_obj, idx) tasks = self._create_iteration_tasks(array_obj, idx)
logger.info(f"Iteration node {self.node_id}: running, concurrency {len(tasks)}")
idx += self.typed_config.parallel_count idx += self.typed_config.parallel_count
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
logger.info(f"Iteration node {self.node_id}: execution completed")
return self.result return self.result
else: else:
# Execute iterations sequentially # Execute iterations sequentially
while idx < len(array_obj) and self.looping: while idx < len(array_obj) and self.looping:
logger.info(f"Iteration node {self.node_id}: running")
item = array_obj[idx] item = array_obj[idx]
result = await self.graph.ainvoke(self._init_iteration_state(item, idx)) result = await self.graph.ainvoke(self._init_iteration_state(item, idx))
output = VariablePool(result).get(self.output_value) output = VariablePool(result).get(self.output_value)
@@ -143,4 +149,6 @@ class IterationRuntime:
if not result["looping"]: if not result["looping"]:
self.looping = False self.looping = False
idx += 1 idx += 1
logger.info(f"Iteration node {self.node_id}: execution completed")
return self.result return self.result

View File

@@ -1,3 +1,4 @@
import logging
from typing import Any from typing import Any
from langgraph.graph.state import CompiledStateGraph from langgraph.graph.state import CompiledStateGraph
@@ -8,6 +9,8 @@ from app.core.workflow.nodes.cycle_graph import LoopNodeConfig
from app.core.workflow.nodes.operators import ConditionExpressionBuilder from app.core.workflow.nodes.operators import ConditionExpressionBuilder
from app.core.workflow.variable_pool import VariablePool from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
class LoopRuntime: class LoopRuntime:
""" """
@@ -119,6 +122,9 @@ class LoopRuntime:
node_outputs=loop_variable_pool.get_all_node_outputs(), node_outputs=loop_variable_pool.get_all_node_outputs(),
system_vars=loop_variable_pool.get_all_system_vars(), system_vars=loop_variable_pool.get_all_system_vars(),
) and loopstate["looping"] and loop_time > 0: ) and loopstate["looping"] and loop_time > 0:
logger.info(f"loop node {self.node_id}: running")
await self.graph.ainvoke(loopstate) await self.graph.ainvoke(loopstate)
loop_time -= 1 loop_time -= 1
logger.info(f"loop node {self.node_id}: execution completed")
return loopstate["runtime_vars"][self.node_id] return loopstate["runtime_vars"][self.node_id]

View File

@@ -65,7 +65,10 @@ class CycleGraphNode(BaseNode):
# Raise error if cycle nodes are connected with external nodes # Raise error if cycle nodes are connected with external nodes
if source_in ^ target_in: if source_in ^ target_in:
raise ValueError(f"循环节点与外部节点存在连接,soruce: {edge.get("source")}, target:{edge.get("target")}") raise ValueError(
f"Cycle node is connected to external node, "
f"source: {edge.get('source')}, target: {edge.get('target')}"
)
if source_in and target_in: if source_in and target_in:
cycle_edges.append(edge) cycle_edges.append(edge)
@@ -220,4 +223,4 @@ class CycleGraphNode(BaseNode):
config=self.config, config=self.config,
state=state, state=state,
).run() ).run()
raise RuntimeError("未知循环节点类型") raise RuntimeError("Unknown cycle node type")

View File

@@ -215,6 +215,7 @@ class HttpRequestNode(BaseNode):
**self._build_content(state) **self._build_content(state)
) )
resp.raise_for_status() resp.raise_for_status()
logger.info(f"Node {self.node_id}: HTTP request succeeded")
return HttpRequestNodeOutput( return HttpRequestNodeOutput(
body=resp.text, body=resp.text,
status_code=resp.status_code, status_code=resp.status_code,
@@ -228,12 +229,21 @@ class HttpRequestNode(BaseNode):
else: else:
match self.typed_config.error_handle.method: match self.typed_config.error_handle.method:
case HttpErrorHandle.NONE: case HttpErrorHandle.NONE:
logger.warning(
f"Node {self.node_id}: HTTP request failed, returning error response"
)
return HttpRequestNodeOutput( return HttpRequestNodeOutput(
body="", body="",
status_code=resp.status_code, status_code=resp.status_code,
headers=resp.headers, headers=resp.headers,
).model_dump() ).model_dump()
case HttpErrorHandle.DEFAULT: case HttpErrorHandle.DEFAULT:
logger.warning(
f"Node {self.node_id}: HTTP request failed, returning default result"
)
return self.typed_config.error_handle.default.model_dump() return self.typed_config.error_handle.default.model_dump()
case HttpErrorHandle.BRANCH: case HttpErrorHandle.BRANCH:
logger.warning(
f"Node {self.node_id}: HTTP request failed, switching to error handling branch"
)
return "ERROR" return "ERROR"

View File

@@ -94,5 +94,6 @@ class IfElseNode(BaseNode):
for i in range(len(expressions)): for i in range(len(expressions)):
logger.info(expressions[i]) logger.info(expressions[i])
if self._evaluate_condition(expressions[i], state): if self._evaluate_condition(expressions[i], state):
logger.info(f"Node {self.node_id}: switched to branch CASE {i + 1}")
return f'CASE{i + 1}' return f'CASE{i + 1}'
return f'CASE{len(expressions)}' return f'CASE{len(expressions)}'

View File

@@ -1,3 +1,4 @@
import logging
from typing import Any from typing import Any
from app.core.workflow.nodes import WorkflowState from app.core.workflow.nodes import WorkflowState
@@ -5,6 +6,7 @@ from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.jinja_render.config import JinjaRenderNodeConfig from app.core.workflow.nodes.jinja_render.config import JinjaRenderNodeConfig
from app.core.workflow.template_renderer import TemplateRenderer from app.core.workflow.template_renderer import TemplateRenderer
logger = logging.getLogger(__name__)
class JinjaRenderNode(BaseNode): class JinjaRenderNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
@@ -41,5 +43,5 @@ class JinjaRenderNode(BaseNode):
res = render.env.from_string(self.typed_config.template).render(**context) res = render.env.from_string(self.typed_config.template).render(**context)
except Exception as e: except Exception as e:
raise RuntimeError(f"JinjaRender Node {self.node_name} render failed: {e}") from e raise RuntimeError(f"JinjaRender Node {self.node_name} render failed: {e}") from e
logger.info(f"Node {self.node_id}: Jinja template rendering completed")
return res return res

View File

@@ -190,12 +190,12 @@ class KnowledgeRetrievalNode(BaseNode):
match kb_config.retrieve_type: match kb_config.retrieve_type:
case RetrieveType.PARTICIPLE: case RetrieveType.PARTICIPLE:
rs.extend(vector_service.search_by_full_text(query=query, top_k=kb_config.top_k, rs.extend(vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
indices=indices, indices=indices,
score_threshold=kb_config.similarity_threshold)) score_threshold=kb_config.similarity_threshold))
case RetrieveType.SEMANTIC: case RetrieveType.SEMANTIC:
rs.extend(vector_service.search_by_vector(query=query, top_k=kb_config.top_k, rs.extend(vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
indices=indices, indices=indices,
score_threshold=kb_config.vector_similarity_weight)) score_threshold=kb_config.vector_similarity_weight))
case RetrieveType.HYBRID: case RetrieveType.HYBRID:
rs1 = vector_service.search_by_vector(query=query, top_k=kb_config.top_k, rs1 = vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
indices=indices, indices=indices,
@@ -209,5 +209,9 @@ class KnowledgeRetrievalNode(BaseNode):
rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k)) rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
case _: case _:
raise RuntimeError("Unknown retrieval type") raise RuntimeError("Unknown retrieval type")
final_rs = vector_service.rerank(query=query, docs=rs, top_k=kb_config.top_k) vector_service.reranker = self.get_reranker_model()
final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
logger.info(
f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}"
)
return [chunk.model_dump() for chunk in final_rs] return [chunk.model_dump() for chunk in final_rs]

View File

@@ -163,6 +163,6 @@ class ParameterExtractorNode(BaseNode):
model_resp = await llm.ainvoke(messages) model_resp = await llm.ainvoke(messages)
result = json_repair.repair_json(model_resp.content, return_objects=True) result = json_repair.repair_json(model_resp.content, return_objects=True)
logger.info(f"get prarms:{result}") logger.info(f"node: {self.node_id} get params:{result}")
return result return result

View File

@@ -50,6 +50,7 @@ class VariableAggregatorNode(BaseNode):
continue continue
if value is not None: if value is not None:
logger.info(f"Node: {self.node_id} variable aggregation result: {value}")
return value return value
logger.info("No variable found in non-group mode; returning empty string.") logger.info("No variable found in non-group mode; returning empty string.")
@@ -74,5 +75,5 @@ class VariableAggregatorNode(BaseNode):
else: else:
result[group_name] = "" result[group_name] = ""
logger.info(f"No variable found for group '{group_name}'; set empty string.") logger.info(f"No variable found for group '{group_name}'; set empty string.")
logger.info(f"Node: {self.node_id} variable aggregation result: {result}")
return result return result

View File

@@ -7,14 +7,87 @@
import logging import logging
from typing import Any, Union from typing import Any, Union
from app.core.workflow.nodes.enums import NodeType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class WorkflowValidator: class WorkflowValidator:
"""工作流配置验证器""" """工作流配置验证器"""
@staticmethod @classmethod
def validate(workflow_config: Union[dict[str, Any], Any]) -> tuple[bool, list[str]]: def pure_cycle_graph(cls, workflow_config: Union[dict[str, Any], Any], node_id) -> tuple[list, list]:
"""
Extract cycle nodes and internal edges from the workflow configuration,
removing them from the global workflow.
Raises:
ValueError: If cycle nodes are connected to external nodes improperly.
Returns:
Tuple containing:
- cycle_nodes: List of removed nodes
- cycle_edges: List of removed edges
"""
nodes = workflow_config.get("nodes", [])
edges = workflow_config.get("edges", [])
# Select all nodes that belong to the current cycle
cycle_nodes = [node for node in nodes if node.get("cycle") == node_id]
cycle_node_ids = {node.get("id") for node in cycle_nodes}
cycle_edges = []
remain_edges = []
for edge in edges:
source_in = edge.get("source") in cycle_node_ids
target_in = edge.get("target") in cycle_node_ids
# Raise error if cycle nodes are connected with external nodes
if source_in ^ target_in:
raise ValueError(
f"Cycle node is connected to external node, "
f"source: {edge.get('source')}, target: {edge.get('target')}"
)
if source_in and target_in:
cycle_edges.append(edge)
else:
remain_edges.append(edge)
# Update workflow_config by removing cycle nodes and internal edges
workflow_config["nodes"] = [
node for node in nodes if node.get("cycle") != node_id
]
workflow_config["edges"] = remain_edges
return cycle_nodes, cycle_edges
@classmethod
def get_subgraph(cls, workflow_config: Union[dict[str, Any], Any]) -> list:
if not isinstance(workflow_config, dict):
workflow_config = {
"nodes": workflow_config.nodes,
"edges": workflow_config.edges,
"variables": workflow_config.variables,
}
cycle_nodes = [
node.get("id")
for node in workflow_config.get("nodes", [])
if node.get("type") in [NodeType.LOOP, NodeType.ITERATION]
]
graphs = []
for cycle_node in cycle_nodes:
nodes, edges = cls.pure_cycle_graph(workflow_config, cycle_node)
graphs.append({
"nodes": nodes,
"edges": edges,
})
graphs.append(workflow_config)
return graphs
@classmethod
def validate(cls, workflow_config: Union[dict[str, Any], Any]) -> tuple[bool, list[str]]:
"""验证工作流配置 """验证工作流配置
Args: Args:
@@ -39,80 +112,75 @@ class WorkflowValidator:
""" """
errors = [] errors = []
# 支持字典和 Pydantic 模型 graphs = cls.get_subgraph(workflow_config)
if isinstance(workflow_config, dict): logger.info(graphs)
nodes = workflow_config.get("nodes", []) for graph in graphs:
edges = workflow_config.get("edges", []) nodes = graph.get("nodes", [])
variables = workflow_config.get("variables", []) edges = graph.get("edges", [])
else: variables = graph.get("variables", [])
# Pydantic 模型 # 1. 验证 start 节点(有且只有一个)
nodes = getattr(workflow_config, "nodes", []) start_nodes = [n for n in nodes if n.get("type") in [NodeType.START, NodeType.CYCLE_START]]
edges = getattr(workflow_config, "edges", []) if len(start_nodes) == 0:
variables = getattr(workflow_config, "variables", []) errors.append("工作流必须有一个 start 节点")
elif len(start_nodes) > 1:
errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)}")
# 1. 验证 start 节点(有且只有一个) # 2. 验证 end 节点(至少一个)
start_nodes = [n for n in nodes if n.get("type") == "start"] end_nodes = [n for n in nodes if n.get("type") == NodeType.END]
if len(start_nodes) == 0: if len(end_nodes) == 0:
errors.append("工作流必须有一个 start 节点") errors.append("工作流必须至少有一个 end 节点")
elif len(start_nodes) > 1:
errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)}")
# 2. 验证 end 节点(至少一个) # 3. 验证节点 ID 唯一性
end_nodes = [n for n in nodes if n.get("type") == "end"] node_ids = [n.get("id") for n in nodes]
if len(end_nodes) == 0: if len(node_ids) != len(set(node_ids)):
errors.append("工作流必须至少有一个 end 节点") duplicates = [nid for nid in node_ids if node_ids.count(nid) > 1]
errors.append(f"节点 ID 必须唯一,重复的 ID: {set(duplicates)}")
# 3. 验证节点 ID 唯一性 # 4. 验证节点必须有 id 和 type
node_ids = [n.get("id") for n in nodes] for i, node in enumerate(nodes):
if len(node_ids) != len(set(node_ids)): if not node.get("id"):
duplicates = [nid for nid in node_ids if node_ids.count(nid) > 1] errors.append(f"节点 #{i} 缺少 id 字段")
errors.append(f"节点 ID 必须唯一,重复的 ID: {set(duplicates)}") if not node.get("type"):
errors.append(f"节点 #{i} (id={node.get('id', 'unknown')}) 缺少 type 字段")
# 4. 验证节点必须有 id 和 type # 5. 验证边的有效性
for i, node in enumerate(nodes): node_id_set = set(node_ids)
if not node.get("id"): for i, edge in enumerate(edges):
errors.append(f"节点 #{i} 缺少 id 字段") source = edge.get("source")
if not node.get("type"): target = edge.get("target")
errors.append(f"节点 #{i} (id={node.get('id', 'unknown')}) 缺少 type 字段")
# 5. 验证边的有效性 if not source:
node_id_set = set(node_ids) errors.append(f"边 #{i} 缺少 source 字段")
for i, edge in enumerate(edges): elif source not in node_id_set:
source = edge.get("source") errors.append(f"边 #{i} 的 source 节点不存在: {source}")
target = edge.get("target")
if not source: if not target:
errors.append(f"边 #{i} 缺少 source 字段") errors.append(f"边 #{i} 缺少 target 字段")
elif source not in node_id_set: elif target not in node_id_set:
errors.append(f"边 #{i}source 节点不存在: {source}") errors.append(f"边 #{i}target 节点不存在: {target}")
if not target: # 6. 验证所有节点可达(从 start 节点出发)
errors.append(f"边 #{i} 缺少 target 字段") if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
elif target not in node_id_set: reachable = WorkflowValidator._get_reachable_nodes(
errors.append(f"边 #{i} 的 target 节点不存在: {target}") start_nodes[0]["id"],
edges
# 6. 验证所有节点可达(从 start 节点出发)
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
reachable = WorkflowValidator._get_reachable_nodes(
start_nodes[0]["id"],
edges
)
unreachable = node_id_set - reachable
if unreachable:
errors.append(f"以下节点无法从 start 节点到达: {unreachable}")
# 7. 检测循环依赖(非 loop 节点)
if not errors: # 只有在前面验证通过时才检查循环
has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges)
if has_cycle:
errors.append(
f"工作流存在循环依赖(请使用 loop 节点实现循环): {' -> '.join(cycle_path)}"
) )
unreachable = node_id_set - reachable
if unreachable:
errors.append(f"以下节点无法从 start 节点到达: {unreachable}")
# 8. 验证变量名 # 7. 检测循环依赖(非 loop 节点)
from app.core.workflow.expression_evaluator import ExpressionEvaluator if not errors: # 只有在前面验证通过时才检查循环
var_errors = ExpressionEvaluator.validate_variable_names(variables) has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges)
errors.extend(var_errors) if has_cycle:
errors.append(
f"工作流存在循环依赖(请使用 loop 节点实现循环): {' -> '.join(cycle_path)}"
)
# 8. 验证变量名
from app.core.workflow.expression_evaluator import ExpressionEvaluator
var_errors = ExpressionEvaluator.validate_variable_names(variables)
errors.extend(var_errors)
return len(errors) == 0, errors return len(errors) == 0, errors
@@ -230,7 +298,7 @@ class WorkflowValidator:
# 1. 验证所有节点都有名称 # 1. 验证所有节点都有名称
for node in nodes: for node in nodes:
if node.get("type") not in ["start", "end"] and not node.get("name"): if node.get("type") not in [NodeType.START, NodeType.CYCLE_START, NodeType.END] and not node.get("name"):
errors.append( errors.append(
f"节点 {node.get('id')} 缺少名称(发布时必须提供)" f"节点 {node.get('id')} 缺少名称(发布时必须提供)"
) )
@@ -238,7 +306,7 @@ class WorkflowValidator:
# 2. 验证所有非 start/end 节点都有配置 # 2. 验证所有非 start/end 节点都有配置
for node in nodes: for node in nodes:
node_type = node.get("type") node_type = node.get("type")
if node_type not in ["start", "end"]: if node_type not in [NodeType.START, NodeType.CYCLE_START, NodeType.END, NodeType.BREAK]:
config = node.get("config") config = node.get("config")
if not config or not isinstance(config, dict): if not config or not isinstance(config, dict):
errors.append( errors.append(
@@ -259,8 +327,8 @@ class WorkflowValidator:
def validate_workflow_config( def validate_workflow_config(
workflow_config: dict[str, Any], workflow_config: dict[str, Any],
for_publish: bool = False for_publish: bool = False
) -> tuple[bool, list[str]]: ) -> tuple[bool, list[str]]:
"""验证工作流配置(便捷函数) """验证工作流配置(便捷函数)