Merge #80 into develop from feature/20251219_myh
feat(workflow): support cycle nodes in workflow config validation and enhance node logging
* feature/20251219_myh: (11 commits squashed)
- feat(workflow): update reranker model configuration for knowledge base retrieval
- fix(workflow): fix output issue in parameter extraction node
- fix(workflow): fix output issue in parameter extraction node
- feat(workflow): add user prompt to parameter extraction node
- perf(workflow): change grouped variable input to key-value format in variable aggregator
- feat(workflow): Add new cycle node for iterative workflow execution
- Introduce a new Loop/Iteration node in the workflow engine.
- Supports both conditional loops and iteration over lists.
- Allows parallel execution and flattening of iteration outputs.
- Maintains runtime state, node outputs, and loop variables for downstream nodes.
- Enhances workflow flexibility for complex, repeated operations.
- Merge branch 'develop' into feature/20251219_myh
# Conflicts:
#	api/app/core/workflow/nodes/configs.py
#	api/app/core/workflow/nodes/node_factory.py
- feat(workflow): Add new cycle node for iterative workflow execution
- Introduce a new Loop/Iteration node in the workflow engine.
- Supports both conditional loops and iteration over lists.
- Allows parallel execution and flattening of iteration outputs.
- Maintains runtime state, node outputs, and loop variables for downstream nodes.
- Enhances workflow flexibility for complex, repeated operations.
- feat(workflow): support cycle nodes in workflow config validation and enhance node logging
- feat(workflow): support cycle nodes in workflow config validation and enhance node logging
- fix(workflow): fix compatibility with some legacy node configurations
Signed-off-by: Eternity <1533512157@qq.com>
Reviewed-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com>
Reviewed-by: aliyun6762716068 <accounts_68cb7c6b61f5dcc4200d6251@mail.teambition.com>
Merged-by: aliyun6762716068 <accounts_68cb7c6b61f5dcc4200d6251@mail.teambition.com>
CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/80
This commit is contained in:
@@ -83,3 +83,5 @@ class AssignerNode(BaseNode):
|
||||
operator.remove_last()
|
||||
case _:
|
||||
raise ValueError(f"Invalid Operator: {assignment.operation}")
|
||||
logger.info(f"Node {self.node_id}: execution completed")
|
||||
|
||||
|
||||
@@ -78,6 +78,7 @@ class BaseNode(ABC):
|
||||
self.workflow_config = workflow_config
|
||||
self.node_id = node_config["id"]
|
||||
self.node_type = node_config["type"]
|
||||
self.cycle = node_config.get("cycle")
|
||||
self.node_name = node_config.get("name", self.node_id)
|
||||
# 使用 or 运算符处理 None 值
|
||||
self.config = node_config.get("config") or {}
|
||||
|
||||
@@ -29,4 +29,5 @@ class BreakNode(BaseNode):
|
||||
Optional dictionary indicating the loop has been stopped.
|
||||
"""
|
||||
state["looping"] = False
|
||||
logger.info(f"run break node, looping={state['looping']}")
|
||||
logger.info(f"Setting cycle node exit flag, cycle={self.cycle}, looping={state['looping']}")
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
@@ -9,6 +10,8 @@ from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.nodes.cycle_graph import IterationNodeConfig
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IterationRuntime:
|
||||
"""
|
||||
@@ -127,12 +130,15 @@ class IterationRuntime:
|
||||
# Execute iterations in parallel batches
|
||||
while idx < len(array_obj) and self.looping:
|
||||
tasks = self._create_iteration_tasks(array_obj, idx)
|
||||
logger.info(f"Iteration node {self.node_id}: running, concurrency {len(tasks)}")
|
||||
idx += self.typed_config.parallel_count
|
||||
await asyncio.gather(*tasks)
|
||||
logger.info(f"Iteration node {self.node_id}: execution completed")
|
||||
return self.result
|
||||
else:
|
||||
# Execute iterations sequentially
|
||||
while idx < len(array_obj) and self.looping:
|
||||
logger.info(f"Iteration node {self.node_id}: running")
|
||||
item = array_obj[idx]
|
||||
result = await self.graph.ainvoke(self._init_iteration_state(item, idx))
|
||||
output = VariablePool(result).get(self.output_value)
|
||||
@@ -143,4 +149,6 @@ class IterationRuntime:
|
||||
if not result["looping"]:
|
||||
self.looping = False
|
||||
idx += 1
|
||||
|
||||
logger.info(f"Iteration node {self.node_id}: execution completed")
|
||||
return self.result
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
@@ -8,6 +9,8 @@ from app.core.workflow.nodes.cycle_graph import LoopNodeConfig
|
||||
from app.core.workflow.nodes.operators import ConditionExpressionBuilder
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoopRuntime:
|
||||
"""
|
||||
@@ -119,6 +122,9 @@ class LoopRuntime:
|
||||
node_outputs=loop_variable_pool.get_all_node_outputs(),
|
||||
system_vars=loop_variable_pool.get_all_system_vars(),
|
||||
) and loopstate["looping"] and loop_time > 0:
|
||||
logger.info(f"loop node {self.node_id}: running")
|
||||
await self.graph.ainvoke(loopstate)
|
||||
loop_time -= 1
|
||||
|
||||
logger.info(f"loop node {self.node_id}: execution completed")
|
||||
return loopstate["runtime_vars"][self.node_id]
|
||||
|
||||
@@ -65,7 +65,10 @@ class CycleGraphNode(BaseNode):
|
||||
|
||||
# Raise error if cycle nodes are connected with external nodes
|
||||
if source_in ^ target_in:
|
||||
raise ValueError(f"循环节点与外部节点存在连接,soruce: {edge.get("source")}, target:{edge.get("target")}")
|
||||
raise ValueError(
|
||||
f"Cycle node is connected to external node, "
|
||||
f"source: {edge.get('source')}, target: {edge.get('target')}"
|
||||
)
|
||||
|
||||
if source_in and target_in:
|
||||
cycle_edges.append(edge)
|
||||
@@ -220,4 +223,4 @@ class CycleGraphNode(BaseNode):
|
||||
config=self.config,
|
||||
state=state,
|
||||
).run()
|
||||
raise RuntimeError("未知循环节点类型")
|
||||
raise RuntimeError("Unknown cycle node type")
|
||||
|
||||
@@ -215,6 +215,7 @@ class HttpRequestNode(BaseNode):
|
||||
**self._build_content(state)
|
||||
)
|
||||
resp.raise_for_status()
|
||||
logger.info(f"Node {self.node_id}: HTTP request succeeded")
|
||||
return HttpRequestNodeOutput(
|
||||
body=resp.text,
|
||||
status_code=resp.status_code,
|
||||
@@ -228,12 +229,21 @@ class HttpRequestNode(BaseNode):
|
||||
else:
|
||||
match self.typed_config.error_handle.method:
|
||||
case HttpErrorHandle.NONE:
|
||||
logger.warning(
|
||||
f"Node {self.node_id}: HTTP request failed, returning error response"
|
||||
)
|
||||
return HttpRequestNodeOutput(
|
||||
body="",
|
||||
status_code=resp.status_code,
|
||||
headers=resp.headers,
|
||||
).model_dump()
|
||||
case HttpErrorHandle.DEFAULT:
|
||||
logger.warning(
|
||||
f"Node {self.node_id}: HTTP request failed, returning default result"
|
||||
)
|
||||
return self.typed_config.error_handle.default.model_dump()
|
||||
case HttpErrorHandle.BRANCH:
|
||||
logger.warning(
|
||||
f"Node {self.node_id}: HTTP request failed, switching to error handling branch"
|
||||
)
|
||||
return "ERROR"
|
||||
|
||||
@@ -94,5 +94,6 @@ class IfElseNode(BaseNode):
|
||||
for i in range(len(expressions)):
|
||||
logger.info(expressions[i])
|
||||
if self._evaluate_condition(expressions[i], state):
|
||||
logger.info(f"Node {self.node_id}: switched to branch CASE {i + 1}")
|
||||
return f'CASE{i + 1}'
|
||||
return f'CASE{len(expressions)}'
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
@@ -5,6 +6,7 @@ from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.jinja_render.config import JinjaRenderNodeConfig
|
||||
from app.core.workflow.template_renderer import TemplateRenderer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class JinjaRenderNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
@@ -41,5 +43,5 @@ class JinjaRenderNode(BaseNode):
|
||||
res = render.env.from_string(self.typed_config.template).render(**context)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"JinjaRender Node {self.node_name} render failed: {e}") from e
|
||||
|
||||
logger.info(f"Node {self.node_id}: Jinja template rendering completed")
|
||||
return res
|
||||
|
||||
@@ -190,12 +190,12 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
match kb_config.retrieve_type:
|
||||
case RetrieveType.PARTICIPLE:
|
||||
rs.extend(vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.similarity_threshold))
|
||||
indices=indices,
|
||||
score_threshold=kb_config.similarity_threshold))
|
||||
case RetrieveType.SEMANTIC:
|
||||
rs.extend(vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.vector_similarity_weight))
|
||||
indices=indices,
|
||||
score_threshold=kb_config.vector_similarity_weight))
|
||||
case RetrieveType.HYBRID:
|
||||
rs1 = vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
@@ -209,5 +209,9 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
|
||||
case _:
|
||||
raise RuntimeError("Unknown retrieval type")
|
||||
final_rs = vector_service.rerank(query=query, docs=rs, top_k=kb_config.top_k)
|
||||
vector_service.reranker = self.get_reranker_model()
|
||||
final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
|
||||
logger.info(
|
||||
f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}"
|
||||
)
|
||||
return [chunk.model_dump() for chunk in final_rs]
|
||||
|
||||
@@ -163,6 +163,6 @@ class ParameterExtractorNode(BaseNode):
|
||||
|
||||
model_resp = await llm.ainvoke(messages)
|
||||
result = json_repair.repair_json(model_resp.content, return_objects=True)
|
||||
logger.info(f"get prarms:{result}")
|
||||
logger.info(f"node: {self.node_id} get params:{result}")
|
||||
|
||||
return result
|
||||
|
||||
@@ -50,6 +50,7 @@ class VariableAggregatorNode(BaseNode):
|
||||
continue
|
||||
|
||||
if value is not None:
|
||||
logger.info(f"Node: {self.node_id} variable aggregation result: {value}")
|
||||
return value
|
||||
|
||||
logger.info("No variable found in non-group mode; returning empty string.")
|
||||
@@ -74,5 +75,5 @@ class VariableAggregatorNode(BaseNode):
|
||||
else:
|
||||
result[group_name] = ""
|
||||
logger.info(f"No variable found for group '{group_name}'; set empty string.")
|
||||
|
||||
logger.info(f"Node: {self.node_id} variable aggregation result: {result}")
|
||||
return result
|
||||
|
||||
@@ -7,14 +7,87 @@
|
||||
import logging
|
||||
from typing import Any, Union
|
||||
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowValidator:
|
||||
"""工作流配置验证器"""
|
||||
|
||||
@staticmethod
|
||||
def validate(workflow_config: Union[dict[str, Any], Any]) -> tuple[bool, list[str]]:
|
||||
|
||||
@classmethod
|
||||
def pure_cycle_graph(cls, workflow_config: Union[dict[str, Any], Any], node_id) -> tuple[list, list]:
|
||||
"""
|
||||
Extract cycle nodes and internal edges from the workflow configuration,
|
||||
removing them from the global workflow.
|
||||
|
||||
Raises:
|
||||
ValueError: If cycle nodes are connected to external nodes improperly.
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
- cycle_nodes: List of removed nodes
|
||||
- cycle_edges: List of removed edges
|
||||
"""
|
||||
nodes = workflow_config.get("nodes", [])
|
||||
edges = workflow_config.get("edges", [])
|
||||
|
||||
# Select all nodes that belong to the current cycle
|
||||
cycle_nodes = [node for node in nodes if node.get("cycle") == node_id]
|
||||
cycle_node_ids = {node.get("id") for node in cycle_nodes}
|
||||
|
||||
cycle_edges = []
|
||||
remain_edges = []
|
||||
|
||||
for edge in edges:
|
||||
source_in = edge.get("source") in cycle_node_ids
|
||||
target_in = edge.get("target") in cycle_node_ids
|
||||
|
||||
# Raise error if cycle nodes are connected with external nodes
|
||||
if source_in ^ target_in:
|
||||
raise ValueError(
|
||||
f"Cycle node is connected to external node, "
|
||||
f"source: {edge.get('source')}, target: {edge.get('target')}"
|
||||
)
|
||||
|
||||
if source_in and target_in:
|
||||
cycle_edges.append(edge)
|
||||
else:
|
||||
remain_edges.append(edge)
|
||||
|
||||
# Update workflow_config by removing cycle nodes and internal edges
|
||||
workflow_config["nodes"] = [
|
||||
node for node in nodes if node.get("cycle") != node_id
|
||||
]
|
||||
workflow_config["edges"] = remain_edges
|
||||
|
||||
return cycle_nodes, cycle_edges
|
||||
|
||||
@classmethod
|
||||
def get_subgraph(cls, workflow_config: Union[dict[str, Any], Any]) -> list:
|
||||
if not isinstance(workflow_config, dict):
|
||||
workflow_config = {
|
||||
"nodes": workflow_config.nodes,
|
||||
"edges": workflow_config.edges,
|
||||
"variables": workflow_config.variables,
|
||||
}
|
||||
cycle_nodes = [
|
||||
node.get("id")
|
||||
for node in workflow_config.get("nodes", [])
|
||||
if node.get("type") in [NodeType.LOOP, NodeType.ITERATION]
|
||||
]
|
||||
graphs = []
|
||||
for cycle_node in cycle_nodes:
|
||||
nodes, edges = cls.pure_cycle_graph(workflow_config, cycle_node)
|
||||
graphs.append({
|
||||
"nodes": nodes,
|
||||
"edges": edges,
|
||||
})
|
||||
graphs.append(workflow_config)
|
||||
return graphs
|
||||
|
||||
@classmethod
|
||||
def validate(cls, workflow_config: Union[dict[str, Any], Any]) -> tuple[bool, list[str]]:
|
||||
"""验证工作流配置
|
||||
|
||||
Args:
|
||||
@@ -38,84 +111,79 @@ class WorkflowValidator:
|
||||
True
|
||||
"""
|
||||
errors = []
|
||||
|
||||
# 支持字典和 Pydantic 模型
|
||||
if isinstance(workflow_config, dict):
|
||||
nodes = workflow_config.get("nodes", [])
|
||||
edges = workflow_config.get("edges", [])
|
||||
variables = workflow_config.get("variables", [])
|
||||
else:
|
||||
# Pydantic 模型
|
||||
nodes = getattr(workflow_config, "nodes", [])
|
||||
edges = getattr(workflow_config, "edges", [])
|
||||
variables = getattr(workflow_config, "variables", [])
|
||||
|
||||
# 1. 验证 start 节点(有且只有一个)
|
||||
start_nodes = [n for n in nodes if n.get("type") == "start"]
|
||||
if len(start_nodes) == 0:
|
||||
errors.append("工作流必须有一个 start 节点")
|
||||
elif len(start_nodes) > 1:
|
||||
errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)} 个")
|
||||
|
||||
# 2. 验证 end 节点(至少一个)
|
||||
end_nodes = [n for n in nodes if n.get("type") == "end"]
|
||||
if len(end_nodes) == 0:
|
||||
errors.append("工作流必须至少有一个 end 节点")
|
||||
|
||||
# 3. 验证节点 ID 唯一性
|
||||
node_ids = [n.get("id") for n in nodes]
|
||||
if len(node_ids) != len(set(node_ids)):
|
||||
duplicates = [nid for nid in node_ids if node_ids.count(nid) > 1]
|
||||
errors.append(f"节点 ID 必须唯一,重复的 ID: {set(duplicates)}")
|
||||
|
||||
# 4. 验证节点必须有 id 和 type
|
||||
for i, node in enumerate(nodes):
|
||||
if not node.get("id"):
|
||||
errors.append(f"节点 #{i} 缺少 id 字段")
|
||||
if not node.get("type"):
|
||||
errors.append(f"节点 #{i} (id={node.get('id', 'unknown')}) 缺少 type 字段")
|
||||
|
||||
# 5. 验证边的有效性
|
||||
node_id_set = set(node_ids)
|
||||
for i, edge in enumerate(edges):
|
||||
source = edge.get("source")
|
||||
target = edge.get("target")
|
||||
|
||||
if not source:
|
||||
errors.append(f"边 #{i} 缺少 source 字段")
|
||||
elif source not in node_id_set:
|
||||
errors.append(f"边 #{i} 的 source 节点不存在: {source}")
|
||||
|
||||
if not target:
|
||||
errors.append(f"边 #{i} 缺少 target 字段")
|
||||
elif target not in node_id_set:
|
||||
errors.append(f"边 #{i} 的 target 节点不存在: {target}")
|
||||
|
||||
# 6. 验证所有节点可达(从 start 节点出发)
|
||||
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
|
||||
reachable = WorkflowValidator._get_reachable_nodes(
|
||||
start_nodes[0]["id"],
|
||||
edges
|
||||
)
|
||||
unreachable = node_id_set - reachable
|
||||
if unreachable:
|
||||
errors.append(f"以下节点无法从 start 节点到达: {unreachable}")
|
||||
|
||||
# 7. 检测循环依赖(非 loop 节点)
|
||||
if not errors: # 只有在前面验证通过时才检查循环
|
||||
has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges)
|
||||
if has_cycle:
|
||||
errors.append(
|
||||
f"工作流存在循环依赖(请使用 loop 节点实现循环): {' -> '.join(cycle_path)}"
|
||||
|
||||
graphs = cls.get_subgraph(workflow_config)
|
||||
logger.info(graphs)
|
||||
for graph in graphs:
|
||||
nodes = graph.get("nodes", [])
|
||||
edges = graph.get("edges", [])
|
||||
variables = graph.get("variables", [])
|
||||
# 1. 验证 start 节点(有且只有一个)
|
||||
start_nodes = [n for n in nodes if n.get("type") in [NodeType.START, NodeType.CYCLE_START]]
|
||||
if len(start_nodes) == 0:
|
||||
errors.append("工作流必须有一个 start 节点")
|
||||
elif len(start_nodes) > 1:
|
||||
errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)} 个")
|
||||
|
||||
# 2. 验证 end 节点(至少一个)
|
||||
end_nodes = [n for n in nodes if n.get("type") == NodeType.END]
|
||||
if len(end_nodes) == 0:
|
||||
errors.append("工作流必须至少有一个 end 节点")
|
||||
|
||||
# 3. 验证节点 ID 唯一性
|
||||
node_ids = [n.get("id") for n in nodes]
|
||||
if len(node_ids) != len(set(node_ids)):
|
||||
duplicates = [nid for nid in node_ids if node_ids.count(nid) > 1]
|
||||
errors.append(f"节点 ID 必须唯一,重复的 ID: {set(duplicates)}")
|
||||
|
||||
# 4. 验证节点必须有 id 和 type
|
||||
for i, node in enumerate(nodes):
|
||||
if not node.get("id"):
|
||||
errors.append(f"节点 #{i} 缺少 id 字段")
|
||||
if not node.get("type"):
|
||||
errors.append(f"节点 #{i} (id={node.get('id', 'unknown')}) 缺少 type 字段")
|
||||
|
||||
# 5. 验证边的有效性
|
||||
node_id_set = set(node_ids)
|
||||
for i, edge in enumerate(edges):
|
||||
source = edge.get("source")
|
||||
target = edge.get("target")
|
||||
|
||||
if not source:
|
||||
errors.append(f"边 #{i} 缺少 source 字段")
|
||||
elif source not in node_id_set:
|
||||
errors.append(f"边 #{i} 的 source 节点不存在: {source}")
|
||||
|
||||
if not target:
|
||||
errors.append(f"边 #{i} 缺少 target 字段")
|
||||
elif target not in node_id_set:
|
||||
errors.append(f"边 #{i} 的 target 节点不存在: {target}")
|
||||
|
||||
# 6. 验证所有节点可达(从 start 节点出发)
|
||||
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
|
||||
reachable = WorkflowValidator._get_reachable_nodes(
|
||||
start_nodes[0]["id"],
|
||||
edges
|
||||
)
|
||||
|
||||
# 8. 验证变量名
|
||||
from app.core.workflow.expression_evaluator import ExpressionEvaluator
|
||||
var_errors = ExpressionEvaluator.validate_variable_names(variables)
|
||||
errors.extend(var_errors)
|
||||
|
||||
unreachable = node_id_set - reachable
|
||||
if unreachable:
|
||||
errors.append(f"以下节点无法从 start 节点到达: {unreachable}")
|
||||
|
||||
# 7. 检测循环依赖(非 loop 节点)
|
||||
if not errors: # 只有在前面验证通过时才检查循环
|
||||
has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges)
|
||||
if has_cycle:
|
||||
errors.append(
|
||||
f"工作流存在循环依赖(请使用 loop 节点实现循环): {' -> '.join(cycle_path)}"
|
||||
)
|
||||
|
||||
# 8. 验证变量名
|
||||
from app.core.workflow.expression_evaluator import ExpressionEvaluator
|
||||
var_errors = ExpressionEvaluator.validate_variable_names(variables)
|
||||
errors.extend(var_errors)
|
||||
|
||||
return len(errors) == 0, errors
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]:
|
||||
"""获取从 start 节点可达的所有节点
|
||||
@@ -129,7 +197,7 @@ class WorkflowValidator:
|
||||
"""
|
||||
reachable = {start_id}
|
||||
queue = [start_id]
|
||||
|
||||
|
||||
while queue:
|
||||
current = queue.pop(0)
|
||||
for edge in edges:
|
||||
@@ -138,9 +206,9 @@ class WorkflowValidator:
|
||||
if target and target not in reachable:
|
||||
reachable.add(target)
|
||||
queue.append(target)
|
||||
|
||||
|
||||
return reachable
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _has_cycle(nodes: list[dict], edges: list[dict]) -> tuple[bool, list[str]]:
|
||||
"""检测是否存在循环依赖(DFS)
|
||||
@@ -154,39 +222,39 @@ class WorkflowValidator:
|
||||
"""
|
||||
# 排除 loop 类型的节点
|
||||
loop_nodes = {n["id"] for n in nodes if n.get("type") == "loop"}
|
||||
|
||||
|
||||
# 构建邻接表(排除 loop 节点的边和错误边)
|
||||
graph: dict[str, list[str]] = {}
|
||||
for edge in edges:
|
||||
source = edge.get("source")
|
||||
target = edge.get("target")
|
||||
edge_type = edge.get("type")
|
||||
|
||||
|
||||
# 跳过错误边
|
||||
if edge_type == "error":
|
||||
continue
|
||||
|
||||
|
||||
# 如果涉及 loop 节点,跳过
|
||||
if source in loop_nodes or target in loop_nodes:
|
||||
continue
|
||||
|
||||
|
||||
if source and target:
|
||||
if source not in graph:
|
||||
graph[source] = []
|
||||
graph[source].append(target)
|
||||
|
||||
|
||||
# DFS 检测环
|
||||
visited = set()
|
||||
rec_stack = set()
|
||||
path = []
|
||||
cycle_path = []
|
||||
|
||||
|
||||
def dfs(node: str) -> bool:
|
||||
"""DFS 检测环,返回是否找到环"""
|
||||
visited.add(node)
|
||||
rec_stack.add(node)
|
||||
path.append(node)
|
||||
|
||||
|
||||
for neighbor in graph.get(node, []):
|
||||
if neighbor not in visited:
|
||||
if dfs(neighbor):
|
||||
@@ -196,19 +264,19 @@ class WorkflowValidator:
|
||||
cycle_start = path.index(neighbor)
|
||||
cycle_path.extend([*path[cycle_start:], neighbor])
|
||||
return True
|
||||
|
||||
|
||||
rec_stack.remove(node)
|
||||
path.pop()
|
||||
return False
|
||||
|
||||
|
||||
# 检查所有节点
|
||||
for node_id in graph:
|
||||
if node_id not in visited:
|
||||
if dfs(node_id):
|
||||
return True, cycle_path
|
||||
|
||||
|
||||
return False, []
|
||||
|
||||
|
||||
@staticmethod
|
||||
def validate_for_publish(workflow_config: dict[str, Any]) -> tuple[bool, list[str]]:
|
||||
"""验证工作流配置是否可以发布(更严格的验证)
|
||||
@@ -221,30 +289,30 @@ class WorkflowValidator:
|
||||
"""
|
||||
# 先执行基础验证
|
||||
is_valid, errors = WorkflowValidator.validate(workflow_config)
|
||||
|
||||
|
||||
if not is_valid:
|
||||
return False, errors
|
||||
|
||||
|
||||
# 额外的发布验证
|
||||
nodes = workflow_config.get("nodes", [])
|
||||
|
||||
|
||||
# 1. 验证所有节点都有名称
|
||||
for node in nodes:
|
||||
if node.get("type") not in ["start", "end"] and not node.get("name"):
|
||||
if node.get("type") not in [NodeType.START, NodeType.CYCLE_START, NodeType.END] and not node.get("name"):
|
||||
errors.append(
|
||||
f"节点 {node.get('id')} 缺少名称(发布时必须提供)"
|
||||
)
|
||||
|
||||
|
||||
# 2. 验证所有非 start/end 节点都有配置
|
||||
for node in nodes:
|
||||
node_type = node.get("type")
|
||||
if node_type not in ["start", "end"]:
|
||||
if node_type not in [NodeType.START, NodeType.CYCLE_START, NodeType.END, NodeType.BREAK]:
|
||||
config = node.get("config")
|
||||
if not config or not isinstance(config, dict):
|
||||
errors.append(
|
||||
f"节点 {node.get('id')} 缺少配置(发布时必须提供)"
|
||||
)
|
||||
|
||||
|
||||
# 3. 验证必填变量
|
||||
variables = workflow_config.get("variables", [])
|
||||
required_vars = [v for v in variables if v.get("required")]
|
||||
@@ -254,13 +322,13 @@ class WorkflowValidator:
|
||||
f"工作流包含 {len(required_vars)} 个必填变量: "
|
||||
f"{[v.get('name') for v in required_vars]}"
|
||||
)
|
||||
|
||||
|
||||
return len(errors) == 0, errors
|
||||
|
||||
|
||||
def validate_workflow_config(
|
||||
workflow_config: dict[str, Any],
|
||||
for_publish: bool = False
|
||||
workflow_config: dict[str, Any],
|
||||
for_publish: bool = False
|
||||
) -> tuple[bool, list[str]]:
|
||||
"""验证工作流配置(便捷函数)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user