Merge #80 into develop from feature/20251219_myh
feat(workflow): support cycle nodes in workflow config validation and enhance node logging
* feature/20251219_myh: (11 commits squashed)
- feat(workflow): update reranker model configuration for knowledge base retrieval
- fix(workflow): fix output issue in parameter extraction node
- fix(workflow): fix output issue in parameter extraction node
- feat(workflow): add user prompt to parameter extraction node
- perf(workflow): change grouped variable input to key-value format in variable aggregator
- feat(workflow): Add new cycle node for iterative workflow execution
- Introduce a new Loop/Iteration node in the workflow engine.
- Supports both conditional loops and iteration over lists.
- Allows parallel execution and flattening of iteration outputs.
- Maintains runtime state, node outputs, and loop variables for downstream nodes.
- Enhances workflow flexibility for complex, repeated operations.
- Merge branch 'develop' into feature/20251219_myh
# Conflicts:
#	api/app/core/workflow/nodes/configs.py
#	api/app/core/workflow/nodes/node_factory.py
- feat(workflow): Add new cycle node for iterative workflow execution
- Introduce a new Loop/Iteration node in the workflow engine.
- Supports both conditional loops and iteration over lists.
- Allows parallel execution and flattening of iteration outputs.
- Maintains runtime state, node outputs, and loop variables for downstream nodes.
- Enhances workflow flexibility for complex, repeated operations.
- feat(workflow): support cycle nodes in workflow config validation and enhance node logging
- feat(workflow): support cycle nodes in workflow config validation and enhance node logging
- fix(workflow): fix compatibility with some legacy node configurations
Signed-off-by: Eternity <1533512157@qq.com>
Reviewed-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com>
Reviewed-by: aliyun6762716068 <accounts_68cb7c6b61f5dcc4200d6251@mail.teambition.com>
Merged-by: aliyun6762716068 <accounts_68cb7c6b61f5dcc4200d6251@mail.teambition.com>
CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/80
This commit is contained in:
@@ -83,3 +83,5 @@ class AssignerNode(BaseNode):
|
|||||||
operator.remove_last()
|
operator.remove_last()
|
||||||
case _:
|
case _:
|
||||||
raise ValueError(f"Invalid Operator: {assignment.operation}")
|
raise ValueError(f"Invalid Operator: {assignment.operation}")
|
||||||
|
logger.info(f"Node {self.node_id}: execution completed")
|
||||||
|
|
||||||
|
|||||||
@@ -78,6 +78,7 @@ class BaseNode(ABC):
|
|||||||
self.workflow_config = workflow_config
|
self.workflow_config = workflow_config
|
||||||
self.node_id = node_config["id"]
|
self.node_id = node_config["id"]
|
||||||
self.node_type = node_config["type"]
|
self.node_type = node_config["type"]
|
||||||
|
self.cycle = node_config.get("cycle")
|
||||||
self.node_name = node_config.get("name", self.node_id)
|
self.node_name = node_config.get("name", self.node_id)
|
||||||
# 使用 or 运算符处理 None 值
|
# 使用 or 运算符处理 None 值
|
||||||
self.config = node_config.get("config") or {}
|
self.config = node_config.get("config") or {}
|
||||||
|
|||||||
@@ -29,4 +29,5 @@ class BreakNode(BaseNode):
|
|||||||
Optional dictionary indicating the loop has been stopped.
|
Optional dictionary indicating the loop has been stopped.
|
||||||
"""
|
"""
|
||||||
state["looping"] = False
|
state["looping"] = False
|
||||||
logger.info(f"run break node, looping={state['looping']}")
|
logger.info(f"Setting cycle node exit flag, cycle={self.cycle}, looping={state['looping']}")
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import copy
|
import copy
|
||||||
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -9,6 +10,8 @@ from app.core.workflow.nodes import WorkflowState
|
|||||||
from app.core.workflow.nodes.cycle_graph import IterationNodeConfig
|
from app.core.workflow.nodes.cycle_graph import IterationNodeConfig
|
||||||
from app.core.workflow.variable_pool import VariablePool
|
from app.core.workflow.variable_pool import VariablePool
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class IterationRuntime:
|
class IterationRuntime:
|
||||||
"""
|
"""
|
||||||
@@ -127,12 +130,15 @@ class IterationRuntime:
|
|||||||
# Execute iterations in parallel batches
|
# Execute iterations in parallel batches
|
||||||
while idx < len(array_obj) and self.looping:
|
while idx < len(array_obj) and self.looping:
|
||||||
tasks = self._create_iteration_tasks(array_obj, idx)
|
tasks = self._create_iteration_tasks(array_obj, idx)
|
||||||
|
logger.info(f"Iteration node {self.node_id}: running, concurrency {len(tasks)}")
|
||||||
idx += self.typed_config.parallel_count
|
idx += self.typed_config.parallel_count
|
||||||
await asyncio.gather(*tasks)
|
await asyncio.gather(*tasks)
|
||||||
|
logger.info(f"Iteration node {self.node_id}: execution completed")
|
||||||
return self.result
|
return self.result
|
||||||
else:
|
else:
|
||||||
# Execute iterations sequentially
|
# Execute iterations sequentially
|
||||||
while idx < len(array_obj) and self.looping:
|
while idx < len(array_obj) and self.looping:
|
||||||
|
logger.info(f"Iteration node {self.node_id}: running")
|
||||||
item = array_obj[idx]
|
item = array_obj[idx]
|
||||||
result = await self.graph.ainvoke(self._init_iteration_state(item, idx))
|
result = await self.graph.ainvoke(self._init_iteration_state(item, idx))
|
||||||
output = VariablePool(result).get(self.output_value)
|
output = VariablePool(result).get(self.output_value)
|
||||||
@@ -143,4 +149,6 @@ class IterationRuntime:
|
|||||||
if not result["looping"]:
|
if not result["looping"]:
|
||||||
self.looping = False
|
self.looping = False
|
||||||
idx += 1
|
idx += 1
|
||||||
|
|
||||||
|
logger.info(f"Iteration node {self.node_id}: execution completed")
|
||||||
return self.result
|
return self.result
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langgraph.graph.state import CompiledStateGraph
|
from langgraph.graph.state import CompiledStateGraph
|
||||||
@@ -8,6 +9,8 @@ from app.core.workflow.nodes.cycle_graph import LoopNodeConfig
|
|||||||
from app.core.workflow.nodes.operators import ConditionExpressionBuilder
|
from app.core.workflow.nodes.operators import ConditionExpressionBuilder
|
||||||
from app.core.workflow.variable_pool import VariablePool
|
from app.core.workflow.variable_pool import VariablePool
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class LoopRuntime:
|
class LoopRuntime:
|
||||||
"""
|
"""
|
||||||
@@ -119,6 +122,9 @@ class LoopRuntime:
|
|||||||
node_outputs=loop_variable_pool.get_all_node_outputs(),
|
node_outputs=loop_variable_pool.get_all_node_outputs(),
|
||||||
system_vars=loop_variable_pool.get_all_system_vars(),
|
system_vars=loop_variable_pool.get_all_system_vars(),
|
||||||
) and loopstate["looping"] and loop_time > 0:
|
) and loopstate["looping"] and loop_time > 0:
|
||||||
|
logger.info(f"loop node {self.node_id}: running")
|
||||||
await self.graph.ainvoke(loopstate)
|
await self.graph.ainvoke(loopstate)
|
||||||
loop_time -= 1
|
loop_time -= 1
|
||||||
|
|
||||||
|
logger.info(f"loop node {self.node_id}: execution completed")
|
||||||
return loopstate["runtime_vars"][self.node_id]
|
return loopstate["runtime_vars"][self.node_id]
|
||||||
|
|||||||
@@ -65,7 +65,10 @@ class CycleGraphNode(BaseNode):
|
|||||||
|
|
||||||
# Raise error if cycle nodes are connected with external nodes
|
# Raise error if cycle nodes are connected with external nodes
|
||||||
if source_in ^ target_in:
|
if source_in ^ target_in:
|
||||||
raise ValueError(f"循环节点与外部节点存在连接,soruce: {edge.get("source")}, target:{edge.get("target")}")
|
raise ValueError(
|
||||||
|
f"Cycle node is connected to external node, "
|
||||||
|
f"source: {edge.get('source')}, target: {edge.get('target')}"
|
||||||
|
)
|
||||||
|
|
||||||
if source_in and target_in:
|
if source_in and target_in:
|
||||||
cycle_edges.append(edge)
|
cycle_edges.append(edge)
|
||||||
@@ -220,4 +223,4 @@ class CycleGraphNode(BaseNode):
|
|||||||
config=self.config,
|
config=self.config,
|
||||||
state=state,
|
state=state,
|
||||||
).run()
|
).run()
|
||||||
raise RuntimeError("未知循环节点类型")
|
raise RuntimeError("Unknown cycle node type")
|
||||||
|
|||||||
@@ -215,6 +215,7 @@ class HttpRequestNode(BaseNode):
|
|||||||
**self._build_content(state)
|
**self._build_content(state)
|
||||||
)
|
)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
|
logger.info(f"Node {self.node_id}: HTTP request succeeded")
|
||||||
return HttpRequestNodeOutput(
|
return HttpRequestNodeOutput(
|
||||||
body=resp.text,
|
body=resp.text,
|
||||||
status_code=resp.status_code,
|
status_code=resp.status_code,
|
||||||
@@ -228,12 +229,21 @@ class HttpRequestNode(BaseNode):
|
|||||||
else:
|
else:
|
||||||
match self.typed_config.error_handle.method:
|
match self.typed_config.error_handle.method:
|
||||||
case HttpErrorHandle.NONE:
|
case HttpErrorHandle.NONE:
|
||||||
|
logger.warning(
|
||||||
|
f"Node {self.node_id}: HTTP request failed, returning error response"
|
||||||
|
)
|
||||||
return HttpRequestNodeOutput(
|
return HttpRequestNodeOutput(
|
||||||
body="",
|
body="",
|
||||||
status_code=resp.status_code,
|
status_code=resp.status_code,
|
||||||
headers=resp.headers,
|
headers=resp.headers,
|
||||||
).model_dump()
|
).model_dump()
|
||||||
case HttpErrorHandle.DEFAULT:
|
case HttpErrorHandle.DEFAULT:
|
||||||
|
logger.warning(
|
||||||
|
f"Node {self.node_id}: HTTP request failed, returning default result"
|
||||||
|
)
|
||||||
return self.typed_config.error_handle.default.model_dump()
|
return self.typed_config.error_handle.default.model_dump()
|
||||||
case HttpErrorHandle.BRANCH:
|
case HttpErrorHandle.BRANCH:
|
||||||
|
logger.warning(
|
||||||
|
f"Node {self.node_id}: HTTP request failed, switching to error handling branch"
|
||||||
|
)
|
||||||
return "ERROR"
|
return "ERROR"
|
||||||
|
|||||||
@@ -94,5 +94,6 @@ class IfElseNode(BaseNode):
|
|||||||
for i in range(len(expressions)):
|
for i in range(len(expressions)):
|
||||||
logger.info(expressions[i])
|
logger.info(expressions[i])
|
||||||
if self._evaluate_condition(expressions[i], state):
|
if self._evaluate_condition(expressions[i], state):
|
||||||
|
logger.info(f"Node {self.node_id}: switched to branch CASE {i + 1}")
|
||||||
return f'CASE{i + 1}'
|
return f'CASE{i + 1}'
|
||||||
return f'CASE{len(expressions)}'
|
return f'CASE{len(expressions)}'
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.core.workflow.nodes import WorkflowState
|
from app.core.workflow.nodes import WorkflowState
|
||||||
@@ -5,6 +6,7 @@ from app.core.workflow.nodes.base_node import BaseNode
|
|||||||
from app.core.workflow.nodes.jinja_render.config import JinjaRenderNodeConfig
|
from app.core.workflow.nodes.jinja_render.config import JinjaRenderNodeConfig
|
||||||
from app.core.workflow.template_renderer import TemplateRenderer
|
from app.core.workflow.template_renderer import TemplateRenderer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class JinjaRenderNode(BaseNode):
|
class JinjaRenderNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
@@ -41,5 +43,5 @@ class JinjaRenderNode(BaseNode):
|
|||||||
res = render.env.from_string(self.typed_config.template).render(**context)
|
res = render.env.from_string(self.typed_config.template).render(**context)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"JinjaRender Node {self.node_name} render failed: {e}") from e
|
raise RuntimeError(f"JinjaRender Node {self.node_name} render failed: {e}") from e
|
||||||
|
logger.info(f"Node {self.node_id}: Jinja template rendering completed")
|
||||||
return res
|
return res
|
||||||
|
|||||||
@@ -190,12 +190,12 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
match kb_config.retrieve_type:
|
match kb_config.retrieve_type:
|
||||||
case RetrieveType.PARTICIPLE:
|
case RetrieveType.PARTICIPLE:
|
||||||
rs.extend(vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
rs.extend(vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
||||||
indices=indices,
|
indices=indices,
|
||||||
score_threshold=kb_config.similarity_threshold))
|
score_threshold=kb_config.similarity_threshold))
|
||||||
case RetrieveType.SEMANTIC:
|
case RetrieveType.SEMANTIC:
|
||||||
rs.extend(vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
rs.extend(vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
||||||
indices=indices,
|
indices=indices,
|
||||||
score_threshold=kb_config.vector_similarity_weight))
|
score_threshold=kb_config.vector_similarity_weight))
|
||||||
case RetrieveType.HYBRID:
|
case RetrieveType.HYBRID:
|
||||||
rs1 = vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
rs1 = vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
||||||
indices=indices,
|
indices=indices,
|
||||||
@@ -209,5 +209,9 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
|
rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
|
||||||
case _:
|
case _:
|
||||||
raise RuntimeError("Unknown retrieval type")
|
raise RuntimeError("Unknown retrieval type")
|
||||||
final_rs = vector_service.rerank(query=query, docs=rs, top_k=kb_config.top_k)
|
vector_service.reranker = self.get_reranker_model()
|
||||||
|
final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
|
||||||
|
logger.info(
|
||||||
|
f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}"
|
||||||
|
)
|
||||||
return [chunk.model_dump() for chunk in final_rs]
|
return [chunk.model_dump() for chunk in final_rs]
|
||||||
|
|||||||
@@ -163,6 +163,6 @@ class ParameterExtractorNode(BaseNode):
|
|||||||
|
|
||||||
model_resp = await llm.ainvoke(messages)
|
model_resp = await llm.ainvoke(messages)
|
||||||
result = json_repair.repair_json(model_resp.content, return_objects=True)
|
result = json_repair.repair_json(model_resp.content, return_objects=True)
|
||||||
logger.info(f"get prarms:{result}")
|
logger.info(f"node: {self.node_id} get params:{result}")
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -50,6 +50,7 @@ class VariableAggregatorNode(BaseNode):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
if value is not None:
|
if value is not None:
|
||||||
|
logger.info(f"Node: {self.node_id} variable aggregation result: {value}")
|
||||||
return value
|
return value
|
||||||
|
|
||||||
logger.info("No variable found in non-group mode; returning empty string.")
|
logger.info("No variable found in non-group mode; returning empty string.")
|
||||||
@@ -74,5 +75,5 @@ class VariableAggregatorNode(BaseNode):
|
|||||||
else:
|
else:
|
||||||
result[group_name] = ""
|
result[group_name] = ""
|
||||||
logger.info(f"No variable found for group '{group_name}'; set empty string.")
|
logger.info(f"No variable found for group '{group_name}'; set empty string.")
|
||||||
|
logger.info(f"Node: {self.node_id} variable aggregation result: {result}")
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -7,14 +7,87 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Union
|
from typing import Any, Union
|
||||||
|
|
||||||
|
from app.core.workflow.nodes.enums import NodeType
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class WorkflowValidator:
|
class WorkflowValidator:
|
||||||
"""工作流配置验证器"""
|
"""工作流配置验证器"""
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def validate(workflow_config: Union[dict[str, Any], Any]) -> tuple[bool, list[str]]:
|
def pure_cycle_graph(cls, workflow_config: Union[dict[str, Any], Any], node_id) -> tuple[list, list]:
|
||||||
|
"""
|
||||||
|
Extract cycle nodes and internal edges from the workflow configuration,
|
||||||
|
removing them from the global workflow.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If cycle nodes are connected to external nodes improperly.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple containing:
|
||||||
|
- cycle_nodes: List of removed nodes
|
||||||
|
- cycle_edges: List of removed edges
|
||||||
|
"""
|
||||||
|
nodes = workflow_config.get("nodes", [])
|
||||||
|
edges = workflow_config.get("edges", [])
|
||||||
|
|
||||||
|
# Select all nodes that belong to the current cycle
|
||||||
|
cycle_nodes = [node for node in nodes if node.get("cycle") == node_id]
|
||||||
|
cycle_node_ids = {node.get("id") for node in cycle_nodes}
|
||||||
|
|
||||||
|
cycle_edges = []
|
||||||
|
remain_edges = []
|
||||||
|
|
||||||
|
for edge in edges:
|
||||||
|
source_in = edge.get("source") in cycle_node_ids
|
||||||
|
target_in = edge.get("target") in cycle_node_ids
|
||||||
|
|
||||||
|
# Raise error if cycle nodes are connected with external nodes
|
||||||
|
if source_in ^ target_in:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cycle node is connected to external node, "
|
||||||
|
f"source: {edge.get('source')}, target: {edge.get('target')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if source_in and target_in:
|
||||||
|
cycle_edges.append(edge)
|
||||||
|
else:
|
||||||
|
remain_edges.append(edge)
|
||||||
|
|
||||||
|
# Update workflow_config by removing cycle nodes and internal edges
|
||||||
|
workflow_config["nodes"] = [
|
||||||
|
node for node in nodes if node.get("cycle") != node_id
|
||||||
|
]
|
||||||
|
workflow_config["edges"] = remain_edges
|
||||||
|
|
||||||
|
return cycle_nodes, cycle_edges
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_subgraph(cls, workflow_config: Union[dict[str, Any], Any]) -> list:
|
||||||
|
if not isinstance(workflow_config, dict):
|
||||||
|
workflow_config = {
|
||||||
|
"nodes": workflow_config.nodes,
|
||||||
|
"edges": workflow_config.edges,
|
||||||
|
"variables": workflow_config.variables,
|
||||||
|
}
|
||||||
|
cycle_nodes = [
|
||||||
|
node.get("id")
|
||||||
|
for node in workflow_config.get("nodes", [])
|
||||||
|
if node.get("type") in [NodeType.LOOP, NodeType.ITERATION]
|
||||||
|
]
|
||||||
|
graphs = []
|
||||||
|
for cycle_node in cycle_nodes:
|
||||||
|
nodes, edges = cls.pure_cycle_graph(workflow_config, cycle_node)
|
||||||
|
graphs.append({
|
||||||
|
"nodes": nodes,
|
||||||
|
"edges": edges,
|
||||||
|
})
|
||||||
|
graphs.append(workflow_config)
|
||||||
|
return graphs
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate(cls, workflow_config: Union[dict[str, Any], Any]) -> tuple[bool, list[str]]:
|
||||||
"""验证工作流配置
|
"""验证工作流配置
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -38,84 +111,79 @@ class WorkflowValidator:
|
|||||||
True
|
True
|
||||||
"""
|
"""
|
||||||
errors = []
|
errors = []
|
||||||
|
|
||||||
# 支持字典和 Pydantic 模型
|
graphs = cls.get_subgraph(workflow_config)
|
||||||
if isinstance(workflow_config, dict):
|
logger.info(graphs)
|
||||||
nodes = workflow_config.get("nodes", [])
|
for graph in graphs:
|
||||||
edges = workflow_config.get("edges", [])
|
nodes = graph.get("nodes", [])
|
||||||
variables = workflow_config.get("variables", [])
|
edges = graph.get("edges", [])
|
||||||
else:
|
variables = graph.get("variables", [])
|
||||||
# Pydantic 模型
|
# 1. 验证 start 节点(有且只有一个)
|
||||||
nodes = getattr(workflow_config, "nodes", [])
|
start_nodes = [n for n in nodes if n.get("type") in [NodeType.START, NodeType.CYCLE_START]]
|
||||||
edges = getattr(workflow_config, "edges", [])
|
if len(start_nodes) == 0:
|
||||||
variables = getattr(workflow_config, "variables", [])
|
errors.append("工作流必须有一个 start 节点")
|
||||||
|
elif len(start_nodes) > 1:
|
||||||
# 1. 验证 start 节点(有且只有一个)
|
errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)} 个")
|
||||||
start_nodes = [n for n in nodes if n.get("type") == "start"]
|
|
||||||
if len(start_nodes) == 0:
|
# 2. 验证 end 节点(至少一个)
|
||||||
errors.append("工作流必须有一个 start 节点")
|
end_nodes = [n for n in nodes if n.get("type") == NodeType.END]
|
||||||
elif len(start_nodes) > 1:
|
if len(end_nodes) == 0:
|
||||||
errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)} 个")
|
errors.append("工作流必须至少有一个 end 节点")
|
||||||
|
|
||||||
# 2. 验证 end 节点(至少一个)
|
# 3. 验证节点 ID 唯一性
|
||||||
end_nodes = [n for n in nodes if n.get("type") == "end"]
|
node_ids = [n.get("id") for n in nodes]
|
||||||
if len(end_nodes) == 0:
|
if len(node_ids) != len(set(node_ids)):
|
||||||
errors.append("工作流必须至少有一个 end 节点")
|
duplicates = [nid for nid in node_ids if node_ids.count(nid) > 1]
|
||||||
|
errors.append(f"节点 ID 必须唯一,重复的 ID: {set(duplicates)}")
|
||||||
# 3. 验证节点 ID 唯一性
|
|
||||||
node_ids = [n.get("id") for n in nodes]
|
# 4. 验证节点必须有 id 和 type
|
||||||
if len(node_ids) != len(set(node_ids)):
|
for i, node in enumerate(nodes):
|
||||||
duplicates = [nid for nid in node_ids if node_ids.count(nid) > 1]
|
if not node.get("id"):
|
||||||
errors.append(f"节点 ID 必须唯一,重复的 ID: {set(duplicates)}")
|
errors.append(f"节点 #{i} 缺少 id 字段")
|
||||||
|
if not node.get("type"):
|
||||||
# 4. 验证节点必须有 id 和 type
|
errors.append(f"节点 #{i} (id={node.get('id', 'unknown')}) 缺少 type 字段")
|
||||||
for i, node in enumerate(nodes):
|
|
||||||
if not node.get("id"):
|
# 5. 验证边的有效性
|
||||||
errors.append(f"节点 #{i} 缺少 id 字段")
|
node_id_set = set(node_ids)
|
||||||
if not node.get("type"):
|
for i, edge in enumerate(edges):
|
||||||
errors.append(f"节点 #{i} (id={node.get('id', 'unknown')}) 缺少 type 字段")
|
source = edge.get("source")
|
||||||
|
target = edge.get("target")
|
||||||
# 5. 验证边的有效性
|
|
||||||
node_id_set = set(node_ids)
|
if not source:
|
||||||
for i, edge in enumerate(edges):
|
errors.append(f"边 #{i} 缺少 source 字段")
|
||||||
source = edge.get("source")
|
elif source not in node_id_set:
|
||||||
target = edge.get("target")
|
errors.append(f"边 #{i} 的 source 节点不存在: {source}")
|
||||||
|
|
||||||
if not source:
|
if not target:
|
||||||
errors.append(f"边 #{i} 缺少 source 字段")
|
errors.append(f"边 #{i} 缺少 target 字段")
|
||||||
elif source not in node_id_set:
|
elif target not in node_id_set:
|
||||||
errors.append(f"边 #{i} 的 source 节点不存在: {source}")
|
errors.append(f"边 #{i} 的 target 节点不存在: {target}")
|
||||||
|
|
||||||
if not target:
|
# 6. 验证所有节点可达(从 start 节点出发)
|
||||||
errors.append(f"边 #{i} 缺少 target 字段")
|
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
|
||||||
elif target not in node_id_set:
|
reachable = WorkflowValidator._get_reachable_nodes(
|
||||||
errors.append(f"边 #{i} 的 target 节点不存在: {target}")
|
start_nodes[0]["id"],
|
||||||
|
edges
|
||||||
# 6. 验证所有节点可达(从 start 节点出发)
|
|
||||||
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
|
|
||||||
reachable = WorkflowValidator._get_reachable_nodes(
|
|
||||||
start_nodes[0]["id"],
|
|
||||||
edges
|
|
||||||
)
|
|
||||||
unreachable = node_id_set - reachable
|
|
||||||
if unreachable:
|
|
||||||
errors.append(f"以下节点无法从 start 节点到达: {unreachable}")
|
|
||||||
|
|
||||||
# 7. 检测循环依赖(非 loop 节点)
|
|
||||||
if not errors: # 只有在前面验证通过时才检查循环
|
|
||||||
has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges)
|
|
||||||
if has_cycle:
|
|
||||||
errors.append(
|
|
||||||
f"工作流存在循环依赖(请使用 loop 节点实现循环): {' -> '.join(cycle_path)}"
|
|
||||||
)
|
)
|
||||||
|
unreachable = node_id_set - reachable
|
||||||
# 8. 验证变量名
|
if unreachable:
|
||||||
from app.core.workflow.expression_evaluator import ExpressionEvaluator
|
errors.append(f"以下节点无法从 start 节点到达: {unreachable}")
|
||||||
var_errors = ExpressionEvaluator.validate_variable_names(variables)
|
|
||||||
errors.extend(var_errors)
|
# 7. 检测循环依赖(非 loop 节点)
|
||||||
|
if not errors: # 只有在前面验证通过时才检查循环
|
||||||
|
has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges)
|
||||||
|
if has_cycle:
|
||||||
|
errors.append(
|
||||||
|
f"工作流存在循环依赖(请使用 loop 节点实现循环): {' -> '.join(cycle_path)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 8. 验证变量名
|
||||||
|
from app.core.workflow.expression_evaluator import ExpressionEvaluator
|
||||||
|
var_errors = ExpressionEvaluator.validate_variable_names(variables)
|
||||||
|
errors.extend(var_errors)
|
||||||
|
|
||||||
return len(errors) == 0, errors
|
return len(errors) == 0, errors
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]:
|
def _get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]:
|
||||||
"""获取从 start 节点可达的所有节点
|
"""获取从 start 节点可达的所有节点
|
||||||
@@ -129,7 +197,7 @@ class WorkflowValidator:
|
|||||||
"""
|
"""
|
||||||
reachable = {start_id}
|
reachable = {start_id}
|
||||||
queue = [start_id]
|
queue = [start_id]
|
||||||
|
|
||||||
while queue:
|
while queue:
|
||||||
current = queue.pop(0)
|
current = queue.pop(0)
|
||||||
for edge in edges:
|
for edge in edges:
|
||||||
@@ -138,9 +206,9 @@ class WorkflowValidator:
|
|||||||
if target and target not in reachable:
|
if target and target not in reachable:
|
||||||
reachable.add(target)
|
reachable.add(target)
|
||||||
queue.append(target)
|
queue.append(target)
|
||||||
|
|
||||||
return reachable
|
return reachable
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _has_cycle(nodes: list[dict], edges: list[dict]) -> tuple[bool, list[str]]:
|
def _has_cycle(nodes: list[dict], edges: list[dict]) -> tuple[bool, list[str]]:
|
||||||
"""检测是否存在循环依赖(DFS)
|
"""检测是否存在循环依赖(DFS)
|
||||||
@@ -154,39 +222,39 @@ class WorkflowValidator:
|
|||||||
"""
|
"""
|
||||||
# 排除 loop 类型的节点
|
# 排除 loop 类型的节点
|
||||||
loop_nodes = {n["id"] for n in nodes if n.get("type") == "loop"}
|
loop_nodes = {n["id"] for n in nodes if n.get("type") == "loop"}
|
||||||
|
|
||||||
# 构建邻接表(排除 loop 节点的边和错误边)
|
# 构建邻接表(排除 loop 节点的边和错误边)
|
||||||
graph: dict[str, list[str]] = {}
|
graph: dict[str, list[str]] = {}
|
||||||
for edge in edges:
|
for edge in edges:
|
||||||
source = edge.get("source")
|
source = edge.get("source")
|
||||||
target = edge.get("target")
|
target = edge.get("target")
|
||||||
edge_type = edge.get("type")
|
edge_type = edge.get("type")
|
||||||
|
|
||||||
# 跳过错误边
|
# 跳过错误边
|
||||||
if edge_type == "error":
|
if edge_type == "error":
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 如果涉及 loop 节点,跳过
|
# 如果涉及 loop 节点,跳过
|
||||||
if source in loop_nodes or target in loop_nodes:
|
if source in loop_nodes or target in loop_nodes:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if source and target:
|
if source and target:
|
||||||
if source not in graph:
|
if source not in graph:
|
||||||
graph[source] = []
|
graph[source] = []
|
||||||
graph[source].append(target)
|
graph[source].append(target)
|
||||||
|
|
||||||
# DFS 检测环
|
# DFS 检测环
|
||||||
visited = set()
|
visited = set()
|
||||||
rec_stack = set()
|
rec_stack = set()
|
||||||
path = []
|
path = []
|
||||||
cycle_path = []
|
cycle_path = []
|
||||||
|
|
||||||
def dfs(node: str) -> bool:
|
def dfs(node: str) -> bool:
|
||||||
"""DFS 检测环,返回是否找到环"""
|
"""DFS 检测环,返回是否找到环"""
|
||||||
visited.add(node)
|
visited.add(node)
|
||||||
rec_stack.add(node)
|
rec_stack.add(node)
|
||||||
path.append(node)
|
path.append(node)
|
||||||
|
|
||||||
for neighbor in graph.get(node, []):
|
for neighbor in graph.get(node, []):
|
||||||
if neighbor not in visited:
|
if neighbor not in visited:
|
||||||
if dfs(neighbor):
|
if dfs(neighbor):
|
||||||
@@ -196,19 +264,19 @@ class WorkflowValidator:
|
|||||||
cycle_start = path.index(neighbor)
|
cycle_start = path.index(neighbor)
|
||||||
cycle_path.extend([*path[cycle_start:], neighbor])
|
cycle_path.extend([*path[cycle_start:], neighbor])
|
||||||
return True
|
return True
|
||||||
|
|
||||||
rec_stack.remove(node)
|
rec_stack.remove(node)
|
||||||
path.pop()
|
path.pop()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 检查所有节点
|
# 检查所有节点
|
||||||
for node_id in graph:
|
for node_id in graph:
|
||||||
if node_id not in visited:
|
if node_id not in visited:
|
||||||
if dfs(node_id):
|
if dfs(node_id):
|
||||||
return True, cycle_path
|
return True, cycle_path
|
||||||
|
|
||||||
return False, []
|
return False, []
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def validate_for_publish(workflow_config: dict[str, Any]) -> tuple[bool, list[str]]:
|
def validate_for_publish(workflow_config: dict[str, Any]) -> tuple[bool, list[str]]:
|
||||||
"""验证工作流配置是否可以发布(更严格的验证)
|
"""验证工作流配置是否可以发布(更严格的验证)
|
||||||
@@ -221,30 +289,30 @@ class WorkflowValidator:
|
|||||||
"""
|
"""
|
||||||
# 先执行基础验证
|
# 先执行基础验证
|
||||||
is_valid, errors = WorkflowValidator.validate(workflow_config)
|
is_valid, errors = WorkflowValidator.validate(workflow_config)
|
||||||
|
|
||||||
if not is_valid:
|
if not is_valid:
|
||||||
return False, errors
|
return False, errors
|
||||||
|
|
||||||
# 额外的发布验证
|
# 额外的发布验证
|
||||||
nodes = workflow_config.get("nodes", [])
|
nodes = workflow_config.get("nodes", [])
|
||||||
|
|
||||||
# 1. 验证所有节点都有名称
|
# 1. 验证所有节点都有名称
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
if node.get("type") not in ["start", "end"] and not node.get("name"):
|
if node.get("type") not in [NodeType.START, NodeType.CYCLE_START, NodeType.END] and not node.get("name"):
|
||||||
errors.append(
|
errors.append(
|
||||||
f"节点 {node.get('id')} 缺少名称(发布时必须提供)"
|
f"节点 {node.get('id')} 缺少名称(发布时必须提供)"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2. 验证所有非 start/end 节点都有配置
|
# 2. 验证所有非 start/end 节点都有配置
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
node_type = node.get("type")
|
node_type = node.get("type")
|
||||||
if node_type not in ["start", "end"]:
|
if node_type not in [NodeType.START, NodeType.CYCLE_START, NodeType.END, NodeType.BREAK]:
|
||||||
config = node.get("config")
|
config = node.get("config")
|
||||||
if not config or not isinstance(config, dict):
|
if not config or not isinstance(config, dict):
|
||||||
errors.append(
|
errors.append(
|
||||||
f"节点 {node.get('id')} 缺少配置(发布时必须提供)"
|
f"节点 {node.get('id')} 缺少配置(发布时必须提供)"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. 验证必填变量
|
# 3. 验证必填变量
|
||||||
variables = workflow_config.get("variables", [])
|
variables = workflow_config.get("variables", [])
|
||||||
required_vars = [v for v in variables if v.get("required")]
|
required_vars = [v for v in variables if v.get("required")]
|
||||||
@@ -254,13 +322,13 @@ class WorkflowValidator:
|
|||||||
f"工作流包含 {len(required_vars)} 个必填变量: "
|
f"工作流包含 {len(required_vars)} 个必填变量: "
|
||||||
f"{[v.get('name') for v in required_vars]}"
|
f"{[v.get('name') for v in required_vars]}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return len(errors) == 0, errors
|
return len(errors) == 0, errors
|
||||||
|
|
||||||
|
|
||||||
def validate_workflow_config(
|
def validate_workflow_config(
|
||||||
workflow_config: dict[str, Any],
|
workflow_config: dict[str, Any],
|
||||||
for_publish: bool = False
|
for_publish: bool = False
|
||||||
) -> tuple[bool, list[str]]:
|
) -> tuple[bool, list[str]]:
|
||||||
"""验证工作流配置(便捷函数)
|
"""验证工作流配置(便捷函数)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user