Merge #80 into develop from feature/20251219_myh
feat(workflow): support cycle nodes in workflow config validation and enhance node logging
* feature/20251219_myh: (11 commits squashed)
- feat(workflow): update reranker model configuration for knowledge base retrieval
- fix(workflow): fix output issue in parameter extraction node
- fix(workflow): fix output issue in parameter extraction node
- feat(workflow): add user prompt to parameter extraction node
- perf(workflow): change grouped variable input to key-value format in variable aggregator
- feat(workflow): Add new cycle node for iterative workflow execution
- Introduce a new Loop/Iteration node in the workflow engine.
- Supports both conditional loops and iteration over lists.
- Allows parallel execution and flattening of iteration outputs.
- Maintains runtime state, node outputs, and loop variables for downstream nodes.
- Enhances workflow flexibility for complex, repeated operations.
- Merge branch 'develop' into feature/20251219_myh
# Conflicts:
#	api/app/core/workflow/nodes/configs.py
#	api/app/core/workflow/nodes/node_factory.py
- feat(workflow): Add new cycle node for iterative workflow execution
- Introduce a new Loop/Iteration node in the workflow engine.
- Supports both conditional loops and iteration over lists.
- Allows parallel execution and flattening of iteration outputs.
- Maintains runtime state, node outputs, and loop variables for downstream nodes.
- Enhances workflow flexibility for complex, repeated operations.
- feat(workflow): support cycle nodes in workflow config validation and enhance node logging
- feat(workflow): support cycle nodes in workflow config validation and enhance node logging
- fix(workflow): fix compatibility with some legacy node configurations
Signed-off-by: Eternity <1533512157@qq.com>
Reviewed-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com>
Reviewed-by: aliyun6762716068 <accounts_68cb7c6b61f5dcc4200d6251@mail.teambition.com>
Merged-by: aliyun6762716068 <accounts_68cb7c6b61f5dcc4200d6251@mail.teambition.com>
CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/80
This commit is contained in:
@@ -83,3 +83,5 @@ class AssignerNode(BaseNode):
|
||||
operator.remove_last()
|
||||
case _:
|
||||
raise ValueError(f"Invalid Operator: {assignment.operation}")
|
||||
logger.info(f"Node {self.node_id}: execution completed")
|
||||
|
||||
|
||||
@@ -78,6 +78,7 @@ class BaseNode(ABC):
|
||||
self.workflow_config = workflow_config
|
||||
self.node_id = node_config["id"]
|
||||
self.node_type = node_config["type"]
|
||||
self.cycle = node_config.get("cycle")
|
||||
self.node_name = node_config.get("name", self.node_id)
|
||||
# 使用 or 运算符处理 None 值
|
||||
self.config = node_config.get("config") or {}
|
||||
|
||||
@@ -29,4 +29,5 @@ class BreakNode(BaseNode):
|
||||
Optional dictionary indicating the loop has been stopped.
|
||||
"""
|
||||
state["looping"] = False
|
||||
logger.info(f"run break node, looping={state['looping']}")
|
||||
logger.info(f"Setting cycle node exit flag, cycle={self.cycle}, looping={state['looping']}")
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
@@ -9,6 +10,8 @@ from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.nodes.cycle_graph import IterationNodeConfig
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IterationRuntime:
|
||||
"""
|
||||
@@ -127,12 +130,15 @@ class IterationRuntime:
|
||||
# Execute iterations in parallel batches
|
||||
while idx < len(array_obj) and self.looping:
|
||||
tasks = self._create_iteration_tasks(array_obj, idx)
|
||||
logger.info(f"Iteration node {self.node_id}: running, concurrency {len(tasks)}")
|
||||
idx += self.typed_config.parallel_count
|
||||
await asyncio.gather(*tasks)
|
||||
logger.info(f"Iteration node {self.node_id}: execution completed")
|
||||
return self.result
|
||||
else:
|
||||
# Execute iterations sequentially
|
||||
while idx < len(array_obj) and self.looping:
|
||||
logger.info(f"Iteration node {self.node_id}: running")
|
||||
item = array_obj[idx]
|
||||
result = await self.graph.ainvoke(self._init_iteration_state(item, idx))
|
||||
output = VariablePool(result).get(self.output_value)
|
||||
@@ -143,4 +149,6 @@ class IterationRuntime:
|
||||
if not result["looping"]:
|
||||
self.looping = False
|
||||
idx += 1
|
||||
|
||||
logger.info(f"Iteration node {self.node_id}: execution completed")
|
||||
return self.result
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
@@ -8,6 +9,8 @@ from app.core.workflow.nodes.cycle_graph import LoopNodeConfig
|
||||
from app.core.workflow.nodes.operators import ConditionExpressionBuilder
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoopRuntime:
|
||||
"""
|
||||
@@ -119,6 +122,9 @@ class LoopRuntime:
|
||||
node_outputs=loop_variable_pool.get_all_node_outputs(),
|
||||
system_vars=loop_variable_pool.get_all_system_vars(),
|
||||
) and loopstate["looping"] and loop_time > 0:
|
||||
logger.info(f"loop node {self.node_id}: running")
|
||||
await self.graph.ainvoke(loopstate)
|
||||
loop_time -= 1
|
||||
|
||||
logger.info(f"loop node {self.node_id}: execution completed")
|
||||
return loopstate["runtime_vars"][self.node_id]
|
||||
|
||||
@@ -65,7 +65,10 @@ class CycleGraphNode(BaseNode):
|
||||
|
||||
# Raise error if cycle nodes are connected with external nodes
|
||||
if source_in ^ target_in:
|
||||
raise ValueError(f"循环节点与外部节点存在连接,soruce: {edge.get("source")}, target:{edge.get("target")}")
|
||||
raise ValueError(
|
||||
f"Cycle node is connected to external node, "
|
||||
f"source: {edge.get('source')}, target: {edge.get('target')}"
|
||||
)
|
||||
|
||||
if source_in and target_in:
|
||||
cycle_edges.append(edge)
|
||||
@@ -220,4 +223,4 @@ class CycleGraphNode(BaseNode):
|
||||
config=self.config,
|
||||
state=state,
|
||||
).run()
|
||||
raise RuntimeError("未知循环节点类型")
|
||||
raise RuntimeError("Unknown cycle node type")
|
||||
|
||||
@@ -215,6 +215,7 @@ class HttpRequestNode(BaseNode):
|
||||
**self._build_content(state)
|
||||
)
|
||||
resp.raise_for_status()
|
||||
logger.info(f"Node {self.node_id}: HTTP request succeeded")
|
||||
return HttpRequestNodeOutput(
|
||||
body=resp.text,
|
||||
status_code=resp.status_code,
|
||||
@@ -228,12 +229,21 @@ class HttpRequestNode(BaseNode):
|
||||
else:
|
||||
match self.typed_config.error_handle.method:
|
||||
case HttpErrorHandle.NONE:
|
||||
logger.warning(
|
||||
f"Node {self.node_id}: HTTP request failed, returning error response"
|
||||
)
|
||||
return HttpRequestNodeOutput(
|
||||
body="",
|
||||
status_code=resp.status_code,
|
||||
headers=resp.headers,
|
||||
).model_dump()
|
||||
case HttpErrorHandle.DEFAULT:
|
||||
logger.warning(
|
||||
f"Node {self.node_id}: HTTP request failed, returning default result"
|
||||
)
|
||||
return self.typed_config.error_handle.default.model_dump()
|
||||
case HttpErrorHandle.BRANCH:
|
||||
logger.warning(
|
||||
f"Node {self.node_id}: HTTP request failed, switching to error handling branch"
|
||||
)
|
||||
return "ERROR"
|
||||
|
||||
@@ -94,5 +94,6 @@ class IfElseNode(BaseNode):
|
||||
for i in range(len(expressions)):
|
||||
logger.info(expressions[i])
|
||||
if self._evaluate_condition(expressions[i], state):
|
||||
logger.info(f"Node {self.node_id}: switched to branch CASE {i + 1}")
|
||||
return f'CASE{i + 1}'
|
||||
return f'CASE{len(expressions)}'
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
@@ -5,6 +6,7 @@ from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.jinja_render.config import JinjaRenderNodeConfig
|
||||
from app.core.workflow.template_renderer import TemplateRenderer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class JinjaRenderNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
@@ -41,5 +43,5 @@ class JinjaRenderNode(BaseNode):
|
||||
res = render.env.from_string(self.typed_config.template).render(**context)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"JinjaRender Node {self.node_name} render failed: {e}") from e
|
||||
|
||||
logger.info(f"Node {self.node_id}: Jinja template rendering completed")
|
||||
return res
|
||||
|
||||
@@ -209,5 +209,9 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
|
||||
case _:
|
||||
raise RuntimeError("Unknown retrieval type")
|
||||
final_rs = vector_service.rerank(query=query, docs=rs, top_k=kb_config.top_k)
|
||||
vector_service.reranker = self.get_reranker_model()
|
||||
final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
|
||||
logger.info(
|
||||
f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}"
|
||||
)
|
||||
return [chunk.model_dump() for chunk in final_rs]
|
||||
|
||||
@@ -163,6 +163,6 @@ class ParameterExtractorNode(BaseNode):
|
||||
|
||||
model_resp = await llm.ainvoke(messages)
|
||||
result = json_repair.repair_json(model_resp.content, return_objects=True)
|
||||
logger.info(f"get prarms:{result}")
|
||||
logger.info(f"node: {self.node_id} get params:{result}")
|
||||
|
||||
return result
|
||||
|
||||
@@ -50,6 +50,7 @@ class VariableAggregatorNode(BaseNode):
|
||||
continue
|
||||
|
||||
if value is not None:
|
||||
logger.info(f"Node: {self.node_id} variable aggregation result: {value}")
|
||||
return value
|
||||
|
||||
logger.info("No variable found in non-group mode; returning empty string.")
|
||||
@@ -74,5 +75,5 @@ class VariableAggregatorNode(BaseNode):
|
||||
else:
|
||||
result[group_name] = ""
|
||||
logger.info(f"No variable found for group '{group_name}'; set empty string.")
|
||||
|
||||
logger.info(f"Node: {self.node_id} variable aggregation result: {result}")
|
||||
return result
|
||||
|
||||
@@ -7,14 +7,87 @@
|
||||
import logging
|
||||
from typing import Any, Union
|
||||
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowValidator:
|
||||
"""工作流配置验证器"""
|
||||
|
||||
@staticmethod
|
||||
def validate(workflow_config: Union[dict[str, Any], Any]) -> tuple[bool, list[str]]:
|
||||
@classmethod
def pure_cycle_graph(cls, workflow_config: Union[dict[str, Any], Any], node_id) -> tuple[list, list]:
    """Split the sub-graph of a single cycle node out of the workflow config.

    A node belongs to the cycle when its "cycle" field equals ``node_id``.
    An edge belongs to the cycle when both of its endpoints do. On success
    ``workflow_config`` is modified in place: its "nodes" and "edges" entries
    are rebound to new lists that no longer contain the extracted items.
    Nothing is modified when an error is raised.

    Args:
        workflow_config: Mapping read through ``.get("nodes")`` / ``.get("edges")``
            and written through item assignment.
        node_id: Identifier of the cycle node whose members are extracted.

    Raises:
        ValueError: If an edge has exactly one endpoint inside the cycle.

    Returns:
        Tuple containing:
            - cycle_nodes: List of removed nodes
            - cycle_edges: List of removed edges
    """
    all_nodes = workflow_config.get("nodes", [])
    all_edges = workflow_config.get("edges", [])

    # Partition the nodes into members of this cycle and everything else.
    inner_nodes = []
    outer_nodes = []
    for candidate in all_nodes:
        if candidate.get("cycle") == node_id:
            inner_nodes.append(candidate)
        else:
            outer_nodes.append(candidate)

    inner_ids = {member.get("id") for member in inner_nodes}

    inner_edges = []
    outer_edges = []
    for link in all_edges:
        src = link.get("source")
        dst = link.get("target")
        src_inside = src in inner_ids
        dst_inside = dst in inner_ids

        # An edge crossing the cycle boundary is not allowed.
        if src_inside != dst_inside:
            raise ValueError(
                f"Cycle node is connected to external node, "
                f"source: {src}, target: {dst}"
            )

        # Past the check above both flags agree, so one of them decides the bucket.
        bucket = inner_edges if src_inside else outer_edges
        bucket.append(link)

    # Only touch the config once every edge has been accepted.
    workflow_config["nodes"] = outer_nodes
    workflow_config["edges"] = outer_edges

    return inner_nodes, inner_edges
|
||||
|
||||
@classmethod
def get_subgraph(cls, workflow_config: Union[dict[str, Any], Any]) -> list:
    """Split a workflow configuration into per-cycle sub-graphs plus the main graph.

    Every node whose type is ``NodeType.LOOP`` or ``NodeType.ITERATION`` is
    treated as a cycle container. The nodes and edges that belong to it are
    extracted with ``pure_cycle_graph``.

    The caller's configuration is never modified. Extraction runs on a private
    shallow copy, so inspecting a workflow does not strip the cycle nodes and
    their internal edges out of the object that was passed in.

    Args:
        workflow_config: Either a dict with "nodes" / "edges" / "variables"
            keys, or an object exposing ``nodes``, ``edges`` and ``variables``
            attributes.

    Returns:
        A list of graph dicts: one ``{"nodes": ..., "edges": ...}`` entry per
        cycle node, followed by the remaining main graph as the last element.

    Raises:
        ValueError: Propagated from ``pure_cycle_graph`` when an edge links a
            node inside a cycle with a node outside of it.
    """
    if isinstance(workflow_config, dict):
        # pure_cycle_graph rebinds the "nodes" / "edges" keys of the mapping it
        # receives; a shallow copy keeps those rebinds away from the caller.
        workflow_config = dict(workflow_config)
    else:
        workflow_config = {
            "nodes": workflow_config.nodes,
            "edges": workflow_config.edges,
            "variables": workflow_config.variables,
        }

    # Collect the ids up front: the node list is rewritten during extraction.
    cycle_node_ids = [
        node.get("id")
        for node in workflow_config.get("nodes", [])
        if node.get("type") in (NodeType.LOOP, NodeType.ITERATION)
    ]

    graphs = []
    for cycle_node_id in cycle_node_ids:
        nodes, edges = cls.pure_cycle_graph(workflow_config, cycle_node_id)
        graphs.append({
            "nodes": nodes,
            "edges": edges,
        })

    # Whatever is left after removing every cycle body is the main graph.
    graphs.append(workflow_config)
    return graphs
|
||||
|
||||
@classmethod
|
||||
def validate(cls, workflow_config: Union[dict[str, Any], Any]) -> tuple[bool, list[str]]:
|
||||
"""验证工作流配置
|
||||
|
||||
Args:
|
||||
@@ -39,26 +112,21 @@ class WorkflowValidator:
|
||||
"""
|
||||
errors = []
|
||||
|
||||
# 支持字典和 Pydantic 模型
|
||||
if isinstance(workflow_config, dict):
|
||||
nodes = workflow_config.get("nodes", [])
|
||||
edges = workflow_config.get("edges", [])
|
||||
variables = workflow_config.get("variables", [])
|
||||
else:
|
||||
# Pydantic 模型
|
||||
nodes = getattr(workflow_config, "nodes", [])
|
||||
edges = getattr(workflow_config, "edges", [])
|
||||
variables = getattr(workflow_config, "variables", [])
|
||||
|
||||
graphs = cls.get_subgraph(workflow_config)
|
||||
logger.info(graphs)
|
||||
for graph in graphs:
|
||||
nodes = graph.get("nodes", [])
|
||||
edges = graph.get("edges", [])
|
||||
variables = graph.get("variables", [])
|
||||
# 1. 验证 start 节点(有且只有一个)
|
||||
start_nodes = [n for n in nodes if n.get("type") == "start"]
|
||||
start_nodes = [n for n in nodes if n.get("type") in [NodeType.START, NodeType.CYCLE_START]]
|
||||
if len(start_nodes) == 0:
|
||||
errors.append("工作流必须有一个 start 节点")
|
||||
elif len(start_nodes) > 1:
|
||||
errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)} 个")
|
||||
|
||||
# 2. 验证 end 节点(至少一个)
|
||||
end_nodes = [n for n in nodes if n.get("type") == "end"]
|
||||
end_nodes = [n for n in nodes if n.get("type") == NodeType.END]
|
||||
if len(end_nodes) == 0:
|
||||
errors.append("工作流必须至少有一个 end 节点")
|
||||
|
||||
@@ -230,7 +298,7 @@ class WorkflowValidator:
|
||||
|
||||
# 1. 验证所有节点都有名称
|
||||
for node in nodes:
|
||||
if node.get("type") not in ["start", "end"] and not node.get("name"):
|
||||
if node.get("type") not in [NodeType.START, NodeType.CYCLE_START, NodeType.END] and not node.get("name"):
|
||||
errors.append(
|
||||
f"节点 {node.get('id')} 缺少名称(发布时必须提供)"
|
||||
)
|
||||
@@ -238,7 +306,7 @@ class WorkflowValidator:
|
||||
# 2. 验证所有非 start/end 节点都有配置
|
||||
for node in nodes:
|
||||
node_type = node.get("type")
|
||||
if node_type not in ["start", "end"]:
|
||||
if node_type not in [NodeType.START, NodeType.CYCLE_START, NodeType.END, NodeType.BREAK]:
|
||||
config = node.get("config")
|
||||
if not config or not isinstance(config, dict):
|
||||
errors.append(
|
||||
|
||||
Reference in New Issue
Block a user