Merge pull request #682 from SuanmoSuanyangTechnology/pref/workflow-engine

pref(workflow): optimize workflow execution performance and reduce logging noise
This commit is contained in:
Mark
2026-03-25 18:59:27 +08:00
committed by GitHub
37 changed files with 268 additions and 347 deletions

View File

@@ -1099,7 +1099,6 @@ class ExtractionOrchestrator:
metadata=chunk.metadata, metadata=chunk.metadata,
) )
chunk_nodes.append(chunk_node) chunk_nodes.append(chunk_node)
logger.error(f"chunk file: {chunk.files}")
for p, file_type in chunk.files: for p, file_type in chunk.files:

View File

@@ -9,7 +9,7 @@ from typing import Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from app.core.workflow.adapters.errors import ExceptionDefineition from app.core.workflow.adapters.errors import ExceptionDefinition
from app.schemas.workflow_schema import ( from app.schemas.workflow_schema import (
EdgeDefinition, EdgeDefinition,
NodeDefinition, NodeDefinition,
@@ -40,8 +40,8 @@ class WorkflowParserResult(BaseModel):
edges: list[EdgeDefinition] = Field(default_factory=list) edges: list[EdgeDefinition] = Field(default_factory=list)
nodes: list[NodeDefinition] = Field(default_factory=list) nodes: list[NodeDefinition] = Field(default_factory=list)
variables: list[VariableDefinition] = Field(default_factory=list) variables: list[VariableDefinition] = Field(default_factory=list)
warnings: list[ExceptionDefineition] = Field(default_factory=list) warnings: list[ExceptionDefinition] = Field(default_factory=list)
errors: list[ExceptionDefineition] = Field(default_factory=list) errors: list[ExceptionDefinition] = Field(default_factory=list)
class WorkflowImportResult(BaseModel): class WorkflowImportResult(BaseModel):
@@ -51,8 +51,8 @@ class WorkflowImportResult(BaseModel):
edges: list[EdgeDefinition] = Field(default_factory=list) edges: list[EdgeDefinition] = Field(default_factory=list)
nodes: list[NodeDefinition] = Field(default_factory=list) nodes: list[NodeDefinition] = Field(default_factory=list)
variables: list[VariableDefinition] = Field(default_factory=list) variables: list[VariableDefinition] = Field(default_factory=list)
warnings: list[ExceptionDefineition] = Field(default_factory=list) warnings: list[ExceptionDefinition] = Field(default_factory=list)
errors: list[ExceptionDefineition] = Field(default_factory=list) errors: list[ExceptionDefinition] = Field(default_factory=list)
class BasePlatformAdapter(ABC): class BasePlatformAdapter(ABC):

View File

@@ -9,9 +9,9 @@ from urllib.parse import quote
from app.core.workflow.adapters.base_converter import BaseConverter from app.core.workflow.adapters.base_converter import BaseConverter
from app.core.workflow.adapters.errors import ( from app.core.workflow.adapters.errors import (
UnsupportVariableType, UnsupportedVariableType,
UnknowModelWarning, UnknownModelWarning,
ExceptionDefineition, ExceptionDefinition,
ExceptionType ExceptionType
) )
from app.core.workflow.nodes.assigner.config import AssignmentItem from app.core.workflow.nodes.assigner.config import AssignmentItem
@@ -54,7 +54,7 @@ from app.core.workflow.nodes.http_request.config import (
HttpFormData, HttpFormData,
HttpTimeOutConfig, HttpTimeOutConfig,
HttpRetryConfig, HttpRetryConfig,
HttpErrorDefaultTamplete, HttpErrorDefaultTemplate,
HttpErrorHandleConfig HttpErrorHandleConfig
) )
from app.core.workflow.nodes.if_else.config import ConditionDetail, ConditionBranchConfig from app.core.workflow.nodes.if_else.config import ConditionDetail, ConditionBranchConfig
@@ -108,7 +108,7 @@ class DifyConverter(BaseConverter):
try: try:
return config.model_validate(value) return config.model_validate(value)
except Exception as e: except Exception as e:
self.errors.append(ExceptionDefineition( self.errors.append(ExceptionDefinition(
type=ExceptionType.CONFIG, type=ExceptionType.CONFIG,
node_id=node_id, node_id=node_id,
node_name=node_name, node_name=node_name,
@@ -138,7 +138,7 @@ class DifyConverter(BaseConverter):
var_selector = mapping.get(var_selector, var_selector) var_selector = mapping.get(var_selector, var_selector)
return var_selector return var_selector
def _process_list_variable_litearl(self, variable_selector: list) -> str | None: def _process_list_variable_literal(self, variable_selector: list) -> str | None:
if not self.process_var_selector(".".join(variable_selector)): if not self.process_var_selector(".".join(variable_selector)):
return None return None
return "{{" + self.process_var_selector(".".join(variable_selector)) + "}}" return "{{" + self.process_var_selector(".".join(variable_selector)) + "}}"
@@ -269,7 +269,7 @@ class DifyConverter(BaseConverter):
var_type = self.variable_type_map(var["type"]) var_type = self.variable_type_map(var["type"])
if not var_type: if not var_type:
self.errors.append( self.errors.append(
UnsupportVariableType( UnsupportedVariableType(
scope=node["id"], scope=node["id"],
name=var["variable"], name=var["variable"],
var_type=var["type"], var_type=var["type"],
@@ -281,7 +281,7 @@ class DifyConverter(BaseConverter):
if var_type in ["file", "array[file]"]: if var_type in ["file", "array[file]"]:
self.errors.append( self.errors.append(
ExceptionDefineition( ExceptionDefinition(
type=ExceptionType.VARIABLE, type=ExceptionType.VARIABLE,
node_id=node["id"], node_id=node["id"],
node_name=node_data["title"], node_name=node_data["title"],
@@ -311,7 +311,7 @@ class DifyConverter(BaseConverter):
def convert_question_classifier_node_config(self, node: dict) -> dict: def convert_question_classifier_node_config(self, node: dict) -> dict:
node_data = node["data"] node_data = node["data"]
self.warnings.append( self.warnings.append(
UnknowModelWarning( UnknownModelWarning(
node_id=node["id"], node_id=node["id"],
node_name=node_data["title"], node_name=node_data["title"],
model_name=node_data["model"].get("name") model_name=node_data["model"].get("name")
@@ -327,7 +327,7 @@ class DifyConverter(BaseConverter):
) )
result = QuestionClassifierNodeConfig.model_construct( result = QuestionClassifierNodeConfig.model_construct(
input_variable=self._process_list_variable_litearl(node_data.get("query_variable_selector")), input_variable=self._process_list_variable_literal(node_data.get("query_variable_selector")),
user_supplement_prompt=self.trans_variable_format(node_data.get("instructions", "")), user_supplement_prompt=self.trans_variable_format(node_data.get("instructions", "")),
categories=categories, categories=categories,
).model_dump() ).model_dump()
@@ -337,13 +337,13 @@ class DifyConverter(BaseConverter):
def convert_llm_node_config(self, node: dict) -> dict: def convert_llm_node_config(self, node: dict) -> dict:
node_data = node["data"] node_data = node["data"]
self.warnings.append( self.warnings.append(
UnknowModelWarning( UnknownModelWarning(
node_id=node["id"], node_id=node["id"],
node_name=node_data["title"], node_name=node_data["title"],
model_name=node_data["model"].get("name") model_name=node_data["model"].get("name")
) )
) )
context = self._process_list_variable_litearl(node_data["context"]["variable_selector"]) context = self._process_list_variable_literal(node_data["context"]["variable_selector"])
memory = MemoryWindowSetting( memory = MemoryWindowSetting(
enable=bool(node_data.get("memory")), enable=bool(node_data.get("memory")),
enable_window=bool(node_data.get("memory", {}).get("window", {}).get("enabled", False)), enable_window=bool(node_data.get("memory", {}).get("window", {}).get("enabled", False)),
@@ -367,7 +367,7 @@ class DifyConverter(BaseConverter):
) )
) )
vision = node_data["vision"]["enabled"] vision = node_data["vision"]["enabled"]
vision_input = self._process_list_variable_litearl( vision_input = self._process_list_variable_literal(
node_data["vision"]["configs"]["variable_selector"] node_data["vision"]["configs"]["variable_selector"]
) if vision else None ) if vision else None
result = LLMNodeConfig.model_construct( result = LLMNodeConfig.model_construct(
@@ -433,7 +433,7 @@ class DifyConverter(BaseConverter):
conditions.append( conditions.append(
LoopConditionDetail.model_construct( LoopConditionDetail.model_construct(
operator=self.convert_compare_operator(condition["comparison_operator"]), operator=self.convert_compare_operator(condition["comparison_operator"]),
left=self._process_list_variable_litearl(condition["variable_selector"]), left=self._process_list_variable_literal(condition["variable_selector"]),
right=self.trans_variable_format( right=self.trans_variable_format(
right_value right_value
) if isinstance(right_value, str) and self.is_variable(right_value) else self.convert_variable_type( ) if isinstance(right_value, str) and self.is_variable(right_value) else self.convert_variable_type(
@@ -453,7 +453,7 @@ class DifyConverter(BaseConverter):
right_input_type = variable["value_type"] right_input_type = variable["value_type"]
right_value_type = self.variable_type_map(variable["var_type"]) right_value_type = self.variable_type_map(variable["var_type"])
if right_input_type == ValueInputType.VARIABLE: if right_input_type == ValueInputType.VARIABLE:
right_value = self._process_list_variable_litearl(variable.get("value", "")) right_value = self._process_list_variable_literal(variable.get("value", ""))
else: else:
right_value = self.convert_variable_type(right_value_type, variable.get("value", "")) right_value = self.convert_variable_type(right_value_type, variable.get("value", ""))
loop_variables.append( loop_variables.append(
@@ -475,10 +475,10 @@ class DifyConverter(BaseConverter):
def convert_iteration_node_config(self, node: dict) -> dict: def convert_iteration_node_config(self, node: dict) -> dict:
node_data = node["data"] node_data = node["data"]
result = IterationNodeConfig.model_construct( result = IterationNodeConfig.model_construct(
input=self._process_list_variable_litearl(node_data["iterator_selector"]), input=self._process_list_variable_literal(node_data["iterator_selector"]),
parallel=node_data["is_parallel"], parallel=node_data["is_parallel"],
parallel_count=node_data["parallel_nums"], parallel_count=node_data["parallel_nums"],
output=self._process_list_variable_litearl(node_data["output_selector"]), output=self._process_list_variable_literal(node_data["output_selector"]),
output_type=self.variable_type_map(node_data.get("output_type")), output_type=self.variable_type_map(node_data.get("output_type")),
flatten=node_data["flatten_output"], flatten=node_data["flatten_output"],
).model_dump() ).model_dump()
@@ -494,8 +494,8 @@ class DifyConverter(BaseConverter):
continue continue
assignments.append( assignments.append(
AssignmentItem( AssignmentItem(
variable_selector=self._process_list_variable_litearl(assignment["variable_selector"]), variable_selector=self._process_list_variable_literal(assignment["variable_selector"]),
value=self._process_list_variable_litearl( value=self._process_list_variable_literal(
assignment["value"] assignment["value"]
) if assignment["input_type"] == ValueInputType.VARIABLE else assignment["value"], ) if assignment["input_type"] == ValueInputType.VARIABLE else assignment["value"],
operation=self.convert_assignment_operator(assignment["operation"]) operation=self.convert_assignment_operator(assignment["operation"])
@@ -514,7 +514,7 @@ class DifyConverter(BaseConverter):
input_variables.append( input_variables.append(
InputVariable.model_construct( InputVariable.model_construct(
name=input_variable["variable"], name=input_variable["variable"],
variable=self._process_list_variable_litearl(input_variable["value_selector"]), variable=self._process_list_variable_literal(input_variable["value_selector"]),
) )
) )
@@ -570,7 +570,7 @@ class DifyConverter(BaseConverter):
else: else:
if node_data["body"]["data"]: if node_data["body"]["data"]:
body_content = (node_data["body"]["data"][0].get("value") or body_content = (node_data["body"]["data"][0].get("value") or
self._process_list_variable_litearl(node_data["body"]["data"][0].get("file"))) self._process_list_variable_literal(node_data["body"]["data"][0].get("file")))
else: else:
body_content = "" body_content = ""
@@ -585,7 +585,7 @@ class DifyConverter(BaseConverter):
self.trans_variable_format(key_value[0]) self.trans_variable_format(key_value[0])
] = self.trans_variable_format(key_value[1]) ] = self.trans_variable_format(key_value[1])
else: else:
self.warnings.append(ExceptionDefineition( self.warnings.append(ExceptionDefinition(
type=ExceptionType.CONFIG, type=ExceptionType.CONFIG,
node_id=node["id"], node_id=node["id"],
node_name=node_data["title"], node_name=node_data["title"],
@@ -603,7 +603,7 @@ class DifyConverter(BaseConverter):
self.trans_variable_format(key_value[0]) self.trans_variable_format(key_value[0])
] = self.trans_variable_format(key_value[1]) ] = self.trans_variable_format(key_value[1])
else: else:
self.warnings.append(ExceptionDefineition( self.warnings.append(ExceptionDefinition(
type=ExceptionType.CONFIG, type=ExceptionType.CONFIG,
node_id=node["id"], node_id=node["id"],
node_name=node_data["title"], node_name=node_data["title"],
@@ -625,7 +625,7 @@ class DifyConverter(BaseConverter):
default_header = var["value"] default_header = var["value"]
elif var["key"] == "status_code": elif var["key"] == "status_code":
default_status_code = var["value"] default_status_code = var["value"]
default_value = HttpErrorDefaultTamplete( default_value = HttpErrorDefaultTemplate(
body=default_body, body=default_body,
headers=default_header, headers=default_header,
status_code=default_status_code, status_code=default_status_code,
@@ -668,7 +668,7 @@ class DifyConverter(BaseConverter):
for variable in node_data["variables"]: for variable in node_data["variables"]:
mapping.append(VariablesMappingConfig.model_construct( mapping.append(VariablesMappingConfig.model_construct(
name=variable["variable"], name=variable["variable"],
value=self._process_list_variable_litearl(variable["value_selector"]) value=self._process_list_variable_literal(variable["value_selector"])
)) ))
result = JinjaRenderNodeConfig.model_construct( result = JinjaRenderNodeConfig.model_construct(
template=node_data["template"], template=node_data["template"],
@@ -679,14 +679,14 @@ class DifyConverter(BaseConverter):
def convert_knowledge_node_config(self, node: dict) -> dict: def convert_knowledge_node_config(self, node: dict) -> dict:
node_data = node["data"] node_data = node["data"]
self.warnings.append(ExceptionDefineition( self.warnings.append(ExceptionDefinition(
node_id=node["id"], node_id=node["id"],
node_name=node_data["title"], node_name=node_data["title"],
type=ExceptionType.CONFIG, type=ExceptionType.CONFIG,
detail=f"Please reconfigure the Knowledge Retrieval node.", detail=f"Please reconfigure the Knowledge Retrieval node.",
)) ))
result = KnowledgeRetrievalNodeConfig.model_construct( result = KnowledgeRetrievalNodeConfig.model_construct(
query=self._process_list_variable_litearl(node_data["query_variable_selector"]), query=self._process_list_variable_literal(node_data["query_variable_selector"]),
).model_dump() ).model_dump()
self.config_validate(node["id"], node["data"]["title"], KnowledgeRetrievalNodeConfig, result) self.config_validate(node["id"], node["data"]["title"], KnowledgeRetrievalNodeConfig, result)
@@ -695,7 +695,7 @@ class DifyConverter(BaseConverter):
def convert_parameter_extractor_node_config(self, node: dict) -> dict: def convert_parameter_extractor_node_config(self, node: dict) -> dict:
node_data = node["data"] node_data = node["data"]
self.warnings.append( self.warnings.append(
UnknowModelWarning( UnknownModelWarning(
node_id=node["id"], node_id=node["id"],
node_name=node_data["title"], node_name=node_data["title"],
model_name=node_data["model"].get("name") model_name=node_data["model"].get("name")
@@ -712,7 +712,7 @@ class DifyConverter(BaseConverter):
) )
) )
result = ParameterExtractorNodeConfig.model_construct( result = ParameterExtractorNodeConfig.model_construct(
text=self._process_list_variable_litearl(node_data["query"]), text=self._process_list_variable_literal(node_data["query"]),
params=params, params=params,
prompt=node_data.get("instruction") prompt=node_data.get("instruction")
).model_dump() ).model_dump()
@@ -727,14 +727,14 @@ class DifyConverter(BaseConverter):
group_type = {} group_type = {}
if not advanced_settings or not advanced_settings["group_enabled"]: if not advanced_settings or not advanced_settings["group_enabled"]:
group_variables = [ group_variables = [
self._process_list_variable_litearl(variable) self._process_list_variable_literal(variable)
for variable in node_data["variables"] for variable in node_data["variables"]
] ]
group_type["output"] = node_data["output_type"] group_type["output"] = node_data["output_type"]
else: else:
for group in advanced_settings["groups"]: for group in advanced_settings["groups"]:
group_variables[group["group_name"]] = [ group_variables[group["group_name"]] = [
self._process_list_variable_litearl(variable) self._process_list_variable_literal(variable)
for variable in group["variables"] for variable in group["variables"]
] ]
group_type[group["group_name"]] = group["output_type"] group_type[group["group_name"]] = group["output_type"]
@@ -751,7 +751,7 @@ class DifyConverter(BaseConverter):
def convert_tool_node_config(self, node: dict) -> dict: def convert_tool_node_config(self, node: dict) -> dict:
node_data = node["data"] node_data = node["data"]
self.warnings.append(ExceptionDefineition( self.warnings.append(ExceptionDefinition(
node_id=node["id"], node_id=node["id"],
node_name=node_data["title"], node_name=node_data["title"],
type=ExceptionType.CONFIG, type=ExceptionType.CONFIG,

View File

@@ -12,7 +12,7 @@ from app.core.workflow.adapters.base_adapter import (
WorkflowParserResult WorkflowParserResult
) )
from app.core.workflow.adapters.dify.converter import DifyConverter from app.core.workflow.adapters.dify.converter import DifyConverter
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType from app.core.workflow.adapters.errors import ExceptionDefinition, ExceptionType
from app.core.workflow.nodes.enums import NodeType from app.core.workflow.nodes.enums import NodeType
from app.schemas.workflow_schema import ( from app.schemas.workflow_schema import (
NodeDefinition, NodeDefinition,
@@ -85,7 +85,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
if not all(field in self.config for field in require_fields): if not all(field in self.config for field in require_fields):
return False return False
if self.config.get("app", {}).get("mode") == "workflow": if self.config.get("app", {}).get("mode") == "workflow":
self.errors.append(ExceptionDefineition( self.errors.append(ExceptionDefinition(
type=ExceptionType.PLATFORM, type=ExceptionType.PLATFORM,
detail="workflow mode is not supported" detail="workflow mode is not supported"
)) ))
@@ -111,12 +111,12 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
edge = self._convert_edge(edge) edge = self._convert_edge(edge)
if edge: if edge:
self.edges.append(edge) self.edges.append(edge)
#
for variable in self.config.get("workflow").get("conversation_variables"): for variable in self.config.get("workflow").get("conversation_variables"):
con_var = self._convert_variable(variable) con_var = self._convert_variable(variable)
if variable: if variable:
self.conv_variables.append(con_var) self.conv_variables.append(con_var)
#
# for variables in config.get("workflow").get("environment_variables"): # for variables in config.get("workflow").get("environment_variables"):
# variable = self._convert_variable(variables) # variable = self._convert_variable(variables)
# conv_variables.append(variable) # conv_variables.append(variable)
@@ -152,7 +152,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
"y": node["position"]["y"] + position["y"] "y": node["position"]["y"] + position["y"]
} }
self.errors.append( self.errors.append(
ExceptionDefineition( ExceptionDefinition(
type=ExceptionType.NODE, type=ExceptionType.NODE,
node_id=node_id, node_id=node_id,
detail="parent cycle node not found" detail="parent cycle node not found"
@@ -189,7 +189,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
node_data = node["data"] node_data = node["data"]
converter = self.get_node_convert(node_type) converter = self.get_node_convert(node_type)
if node_type == NodeType.UNKNOWN: if node_type == NodeType.UNKNOWN:
self.errors.append(ExceptionDefineition( self.errors.append(ExceptionDefinition(
type=ExceptionType.NODE, type=ExceptionType.NODE,
node_id=node["id"], node_id=node["id"],
node_name=node["data"]["title"], node_name=node["data"]["title"],
@@ -197,7 +197,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
)) ))
return converter(node) return converter(node)
except Exception as e: except Exception as e:
self.errors.append(ExceptionDefineition( self.errors.append(ExceptionDefinition(
type=ExceptionType.NODE, type=ExceptionType.NODE,
node_id=node["id"], node_id=node["id"],
node_name=node["data"]["title"], node_name=node["data"]["title"],
@@ -207,7 +207,6 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
def _convert_edge(self, edge: dict[str, Any]) -> EdgeDefinition | None: def _convert_edge(self, edge: dict[str, Any]) -> EdgeDefinition | None:
try: try:
source = edge["source"] source = edge["source"]
target = edge["target"] target = edge["target"]
label = None label = None
@@ -230,7 +229,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
label=label, label=label,
) )
except Exception as e: except Exception as e:
self.errors.append(ExceptionDefineition( self.errors.append(ExceptionDefinition(
type=ExceptionType.EDGE, type=ExceptionType.EDGE,
detail=f"convert edge error - {e}", detail=f"convert edge error - {e}",
)) ))
@@ -246,7 +245,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
description=variable.get("description") description=variable.get("description")
) )
except Exception as e: except Exception as e:
self.errors.append(ExceptionDefineition( self.errors.append(ExceptionDefinition(
type=ExceptionType.VARIABLE, type=ExceptionType.VARIABLE,
name=variable.get("name"), name=variable.get("name"),
detail=f"convert variable error - {e}", detail=f"convert variable error - {e}",

View File

@@ -18,7 +18,7 @@ class ExceptionType(StrEnum):
UNKNOWN = "unknown" UNKNOWN = "unknown"
class ExceptionDefineition(BaseModel): class ExceptionDefinition(BaseModel):
type: ExceptionType type: ExceptionType
detail: str detail: str
@@ -29,7 +29,7 @@ class ExceptionDefineition(BaseModel):
name: str | None = None name: str | None = None
class UnknowModelWarning(ExceptionDefineition): class UnknownModelWarning(ExceptionDefinition):
type: ExceptionType = ExceptionType.NODE type: ExceptionType = ExceptionType.NODE
def __init__(self, node_id, node_name, model_name): def __init__(self, node_id, node_name, model_name):
@@ -40,36 +40,36 @@ class UnknowModelWarning(ExceptionDefineition):
) )
class UnknowError(ExceptionDefineition): class UnknownError(ExceptionDefinition):
type: ExceptionType = ExceptionType.UNKNOWN type: ExceptionType = ExceptionType.UNKNOWN
def __init__(self, detail: str, **kwargs): def __init__(self, detail: str, **kwargs):
super().__init__(detail=detail, **kwargs) super().__init__(detail=detail, **kwargs)
class UnsupportPlatform(ExceptionDefineition): class UnsupportedPlatform(ExceptionDefinition):
type: ExceptionType = ExceptionType.PLATFORM type: ExceptionType = ExceptionType.PLATFORM
def __init__(self, platform: str): def __init__(self, platform: str):
super().__init__(detail=f"Unsupport platform {platform}") super().__init__(detail=f"Unsupported platform {platform}")
class UnsupportVariableType(ExceptionDefineition): class UnsupportedVariableType(ExceptionDefinition):
type: ExceptionType = ExceptionType.VARIABLE type: ExceptionType = ExceptionType.VARIABLE
def __init__(self, scope, name, var_type: str, **kwargs): def __init__(self, scope, name, var_type: str, **kwargs):
super().__init__(scope=scope, name=name, detail=f"Unsupport variable type[{var_type}]", **kwargs) super().__init__(scope=scope, name=name, detail=f"Unsupported variable type: [{var_type}]", **kwargs)
class InvalidConfiguration(ExceptionDefineition): class InvalidConfiguration(ExceptionDefinition):
type: ExceptionType = ExceptionType.CONFIG type: ExceptionType = ExceptionType.CONFIG
def __init__(self): def __init__(self):
super().__init__(detail="Invalid workflow configuration format") super().__init__(detail="Invalid workflow configuration format")
class UnsupportNodeType(ExceptionDefineition): class UnsupportedNodeType(ExceptionDefinition):
type: ExceptionType = ExceptionType.NODE type: ExceptionType = ExceptionType.NODE
def __init__(self, node_id: str, node_type: str): def __init__(self, node_id: str, node_type: str):
super().__init__(node_id=node_id, detail=f"Unsupport node Type {node_type}") super().__init__(node_id=node_id, detail=f"Unsupported node type {node_type}")

View File

@@ -11,7 +11,7 @@ from app.core.workflow.adapters.base_adapter import (
BasePlatformAdapter, BasePlatformAdapter,
WorkflowParserResult WorkflowParserResult
) )
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType, UnsupportNodeType from app.core.workflow.adapters.errors import ExceptionDefinition, ExceptionType, UnsupportedNodeType
from app.core.workflow.adapters.memory_bear.memory_bear_converter import MemoryBearConverter from app.core.workflow.adapters.memory_bear.memory_bear_converter import MemoryBearConverter
from app.core.workflow.nodes.enums import NodeType from app.core.workflow.nodes.enums import NodeType
from app.schemas.workflow_schema import ExecutionConfig, NodeDefinition, EdgeDefinition, VariableDefinition from app.schemas.workflow_schema import ExecutionConfig, NodeDefinition, EdgeDefinition, VariableDefinition
@@ -73,7 +73,7 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
try: try:
node_type = self.map_node_type(node["type"]) node_type = self.map_node_type(node["type"])
if node_type == NodeType.UNKNOWN: if node_type == NodeType.UNKNOWN:
self.errors.append(UnsupportNodeType( self.errors.append(UnsupportedNodeType(
node_id=node_id, node_id=node_id,
node_type=node["type"] node_type=node["type"]
)) ))
@@ -85,7 +85,7 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
return NodeDefinition(**node) return NodeDefinition(**node)
except Exception as e: except Exception as e:
self.errors.append(ExceptionDefineition( self.errors.append(ExceptionDefinition(
type=ExceptionType.NODE, type=ExceptionType.NODE,
node_id=node_id, node_id=node_id,
node_name=node_name, node_name=node_name,
@@ -97,14 +97,14 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
def _convert_edge(self, edge: dict[str, Any], valid_node_ids: set) -> EdgeDefinition | None: def _convert_edge(self, edge: dict[str, Any], valid_node_ids: set) -> EdgeDefinition | None:
try: try:
if edge.get("source") not in valid_node_ids or edge.get("target") not in valid_node_ids: if edge.get("source") not in valid_node_ids or edge.get("target") not in valid_node_ids:
self.warnings.append(ExceptionDefineition( self.warnings.append(ExceptionDefinition(
type=ExceptionType.EDGE, type=ExceptionType.EDGE,
detail=f"edge {edge.get('id')} skipped: source or target node not found" detail=f"edge {edge.get('id')} skipped: source or target node not found"
)) ))
return None return None
return EdgeDefinition(**edge) return EdgeDefinition(**edge)
except Exception as e: except Exception as e:
self.errors.append(ExceptionDefineition( self.errors.append(ExceptionDefinition(
type=ExceptionType.EDGE, type=ExceptionType.EDGE,
detail=f"convert edge error - {e}" detail=f"convert edge error - {e}"
)) ))
@@ -115,7 +115,7 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
try: try:
return VariableDefinition(**variable) return VariableDefinition(**variable)
except Exception as e: except Exception as e:
self.warnings.append(ExceptionDefineition( self.warnings.append(ExceptionDefinition(
type=ExceptionType.VARIABLE, type=ExceptionType.VARIABLE,
name=variable.get("name"), name=variable.get("name"),
detail=f"convert variable error - {e}" detail=f"convert variable error - {e}"

View File

@@ -1,6 +1,6 @@
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
from app.core.workflow.adapters.base_converter import BaseConverter from app.core.workflow.adapters.base_converter import BaseConverter
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType from app.core.workflow.adapters.errors import ExceptionDefinition, ExceptionType
from app.core.workflow.nodes.base_config import BaseNodeConfig from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.core.workflow.nodes.configs import ( from app.core.workflow.nodes.configs import (
StartNodeConfig, StartNodeConfig,
@@ -65,7 +65,7 @@ class MemoryBearConverter(BaseConverter):
try: try:
return config_cls.model_validate(value) return config_cls.model_validate(value)
except Exception as e: except Exception as e:
self.errors.append(ExceptionDefineition( self.errors.append(ExceptionDefinition(
type=ExceptionType.CONFIG, type=ExceptionType.CONFIG,
node_id=node_id, node_id=node_id,
node_name=node_name, node_name=node_name,

View File

@@ -7,7 +7,7 @@ import re
import uuid import uuid
from collections import defaultdict from collections import defaultdict
from functools import lru_cache from functools import lru_cache
from typing import Any, Iterable from typing import Any, Iterable, Callable
from langgraph.checkpoint.memory import InMemorySaver from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import START, END from langgraph.graph import START, END
@@ -41,48 +41,31 @@ class GraphBuilder:
self, self,
workflow_config: dict[str, Any], workflow_config: dict[str, Any],
stream: bool = False, stream: bool = False,
subgraph: bool = False, cycle: str = '',
variable_pool: VariablePool | None = None variable_pool: VariablePool | None = None
): ):
self.workflow_config = workflow_config self.workflow_config = workflow_config
self.stream = stream self.stream = stream
self.subgraph = subgraph self.cycle = cycle
self.start_node_id: str | None = None self.start_node_id: str | None = None
self.node_map = {node["id"]: node for node in self.nodes} self.node_map: dict[str, dict] = {}
self.end_node_map: dict[str, StreamOutputConfig] = {} self.end_node_map: dict[str, StreamOutputConfig] = {}
self._find_upstream_activation_dep = lru_cache( self._find_upstream_activation_dep: Callable = self._find_upstream_activation_dep
maxsize=len(self.nodes) * 2
)(self._find_upstream_activation_dep)
if variable_pool: if variable_pool:
self.variable_pool = variable_pool self.variable_pool = variable_pool
else: else:
self.variable_pool = VariablePool() self.variable_pool = VariablePool()
self.graph = StateGraph(WorkflowState) self.graph: StateGraph | None = None
self.add_nodes() self.nodes: list = []
self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges) self.edges: list = []
self.end_nodes = [ self.reachable_nodes: set[str] | None = None
node self.end_nodes: list[dict] = []
for node in self.nodes
if node.get("type") == "end" and node.get("id") in self.reachable_nodes
]
self.add_edges()
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
self._reverse_adj: dict[str, list[dict]] = defaultdict(list) self._reverse_adj: dict[str, list[dict]] = defaultdict(list)
self._build_reverse_adj() self._adj: dict[str, list[str]] = defaultdict(list)
self._analyze_end_node_output()
@property
def nodes(self) -> list[dict[str, Any]]:
return self.workflow_config.get("nodes", [])
@property
def edges(self) -> list[dict[str, Any]]:
return self.workflow_config.get("edges", [])
def get_node_type(self, node_id: str) -> str: def get_node_type(self, node_id: str) -> str:
"""Retrieve the type of node given its ID. """Retrieve the type of node given its ID.
@@ -108,13 +91,14 @@ class GraphBuilder:
result[node[0]].append(node[1]) result[node[0]].append(node[1])
return result return result
def _build_reverse_adj(self): def _build_adj(self):
for edge in self.edges: for edge in self.edges:
if edge["source"] not in self.reachable_nodes: if edge["source"] not in self.reachable_nodes:
continue continue
self._reverse_adj[edge.get("target")].append({ self._reverse_adj[edge.get("target")].append({
"id": edge["source"], "branch": edge.get("label") "id": edge["source"], "branch": edge.get("label")
}) })
self._adj[edge.get("source")].append(edge["target"])
def _find_upstream_activation_dep( def _find_upstream_activation_dep(
self, self,
@@ -302,22 +286,13 @@ class GraphBuilder:
""" """
for node in self.nodes: for node in self.nodes:
node_type = node.get("type") node_type = node.get("type")
if node_type == NodeType.NOTES:
continue
node_id = node.get("id") node_id = node.get("id")
cycle_node = node.get("cycle") if node_id not in self.reachable_nodes:
if cycle_node: continue
# Nodes within a loop subgraph are constructed by CycleGraphNode
if not self.subgraph:
continue
# Record start and end node IDs
if node_type in [NodeType.START, NodeType.CYCLE_START]:
self.start_node_id = node_id
# Create node instance (start and end nodes are also created) # Create node instance (start and end nodes are also created)
# NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph # NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph
node_instance = NodeFactory.create_node(node, self.workflow_config) node_instance = NodeFactory.create_node(node, self.workflow_config, self._adj[node_id])
if node_type in BRANCH_NODES: if node_type in BRANCH_NODES:
@@ -413,11 +388,12 @@ class GraphBuilder:
# Add conditional edges # Add conditional edges
for source_node, branches in conditional_edges.items(): for source_node, branches in conditional_edges.items():
def make_router(src, branch_list): def make_router(src, branch_list):
"""reate a router function for each source node that routes to a NOP node for later merging.""" """Create a router function for each source node that routes to a NOP node for later merging."""
def make_branch_node(node_name, targets): def make_branch_node(node_name, targets):
def node(s): def node(s):
# NOTE: NOP NODE MUST NOT MODIFY STATE # NOTE: NOP NODE USED FOR ROUTING ONLY.
# MUST NOT MUTATE STATE DIRECTLY; ONLY EMIT ACTIVATE SIGNALS.
return { return {
"activate": { "activate": {
node_id: s["activate"][node_name] node_id: s["activate"][node_name]
@@ -504,14 +480,52 @@ class GraphBuilder:
logger.debug(f"Added waiting edge: {sources} -> {target}") logger.debug(f"Added waiting edge: {sources} -> {target}")
# Connect End nodes to the global END node # Connect End nodes to the global END node
for end_node in self.end_nodes: for node in self.reachable_nodes:
end_node_id = end_node.get("id") if not self._adj[node]:
if end_node_id: self.graph.add_edge(node, END)
self.graph.add_edge(end_node_id, END)
logger.debug(f"Added edge: {end_node_id} -> END")
return return
def build(self) -> CompiledStateGraph: def build(self) -> CompiledStateGraph:
nodes = self.workflow_config.get("nodes", [])
edges = self.workflow_config.get("edges", [])
for node in nodes:
if (node.get("cycle") or '') == self.cycle:
node_type = node.get("type")
if node_type in [NodeType.START, NodeType.CYCLE_START]:
self.start_node_id = node.get("id")
elif node_type == NodeType.NOTES:
continue
self.nodes.append(node)
self.node_map[node.get("id")] = node
for edge in edges:
source_in = edge.get("source") in self.node_map
target_in = edge.get("target") in self.node_map
if source_in ^ target_in:
raise ValueError(
f"Cycle node is connected to external node, "
f"source: {edge.get('source')}, target: {edge.get('target')}"
)
if source_in and target_in:
self.edges.append(edge)
self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges)
self.end_nodes = [
node
for node in self.nodes
if node.get("type") == "end" and node.get("id") in self.reachable_nodes
]
self._build_adj()
self._find_upstream_activation_dep: Callable = lru_cache(
maxsize=len(self.nodes)*2
)(self._find_upstream_activation_dep)
self.graph = StateGraph(WorkflowState)
self.add_nodes()
self.add_edges()
self._analyze_end_node_output()
checkpointer = InMemorySaver() checkpointer = InMemorySaver()
self.graph = self.graph.compile(checkpointer=checkpointer) return self.graph.compile(checkpointer=checkpointer)
return self.graph

View File

@@ -2,6 +2,7 @@
# Author: Eternity # Author: Eternity
# @Email: 1533512157@qq.com # @Email: 1533512157@qq.com
# @Time : 2026/2/10 13:33 # @Time : 2026/2/10 13:33
from app.core.workflow.engine.runtime_schema import ExecutionContext
from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.engine.variable_pool import VariablePool
@@ -9,6 +10,7 @@ class WorkflowResultBuilder:
def build_final_output( def build_final_output(
self, self,
result: dict, result: dict,
execution_context: ExecutionContext,
variable_pool: VariablePool, variable_pool: VariablePool,
elapsed_time: float, elapsed_time: float,
final_output: str, final_output: str,
@@ -26,6 +28,8 @@ class WorkflowResultBuilder:
- "node_outputs" (dict): Outputs of executed nodes. - "node_outputs" (dict): Outputs of executed nodes.
- "messages" (list): Conversation messages exchanged during execution. - "messages" (list): Conversation messages exchanged during execution.
- "error" (str, optional): Error message if any node failed. - "error" (str, optional): Error message if any node failed.
execution_context (ExecutionContext): The execution context containing metadata like
execution ID, workspace ID, and user ID.)
variable_pool (VariablePool): Variable Pool variable_pool (VariablePool): Variable Pool
elapsed_time (float): Total execution time in seconds. elapsed_time (float): Total execution time in seconds.
final_output (Any): The aggregated or final output content of the workflow final_output (Any): The aggregated or final output content of the workflow
@@ -48,18 +52,23 @@ class WorkflowResultBuilder:
""" """
node_outputs = result.get("node_outputs", {}) node_outputs = result.get("node_outputs", {})
token_usage = self.aggregate_token_usage(node_outputs) token_usage = self.aggregate_token_usage(node_outputs)
conversation_id = variable_pool.get_value("sys.conversation_id") conversation_vars = {}
sys_vars = {}
if variable_pool:
conversation_vars = variable_pool.get_all_conversation_vars()
sys_vars = variable_pool.get_all_system_vars()
return { return {
"status": "completed" if success else "failed", "status": "completed" if success else "failed",
"output": final_output, "output": final_output,
"variables": { "variables": {
"conv": variable_pool.get_all_conversation_vars(), "conv": conversation_vars,
"sys": variable_pool.get_all_system_vars() "sys": sys_vars
}, },
"node_outputs": node_outputs, "node_outputs": node_outputs,
"messages": result.get("messages", []), "messages": result.get("messages", []),
"conversation_id": conversation_id, "conversation_id": execution_context.conversation_id,
"elapsed_time": elapsed_time, "elapsed_time": elapsed_time,
"token_usage": token_usage, "token_usage": token_usage,
"error": result.get("error"), "error": result.get("error"),

View File

@@ -12,6 +12,7 @@ class ExecutionContext(BaseModel):
execution_id: str execution_id: str
workspace_id: str workspace_id: str
user_id: str user_id: str
conversation_id: str
memory_storage_type: str memory_storage_type: str
user_rag_memory_id: str user_rag_memory_id: str
checkpoint_config: RunnableConfig checkpoint_config: RunnableConfig
@@ -22,6 +23,7 @@ class ExecutionContext(BaseModel):
execution_id: str, execution_id: str,
workspace_id: str, workspace_id: str,
user_id: str, user_id: str,
conversation_id: str,
memory_storage_type: str, memory_storage_type: str,
user_rag_memory_id: str user_rag_memory_id: str
): ):
@@ -29,6 +31,7 @@ class ExecutionContext(BaseModel):
execution_id=execution_id, execution_id=execution_id,
workspace_id=workspace_id, workspace_id=workspace_id,
user_id=user_id, user_id=user_id,
conversation_id=conversation_id,
memory_storage_type=memory_storage_type, memory_storage_type=memory_storage_type,
user_rag_memory_id=user_rag_memory_id, user_rag_memory_id=user_rag_memory_id,

View File

@@ -3,6 +3,7 @@
# @Email: 1533512157@qq.com # @Email: 1533512157@qq.com
# @Time : 2026/2/9 13:51 # @Time : 2026/2/9 13:51
import datetime import datetime
import time
import logging import logging
from typing import Any from typing import Any
@@ -82,13 +83,15 @@ class WorkflowExecutor:
CompiledStateGraph: The compiled and ready-to-run state graph. CompiledStateGraph: The compiled and ready-to-run state graph.
""" """
logger.info(f"Starting workflow graph build: execution_id={self.execution_context.execution_id}") logger.info(f"Starting workflow graph build: execution_id={self.execution_context.execution_id}")
start_time = time.time()
builder = GraphBuilder( builder = GraphBuilder(
self.workflow_config, self.workflow_config,
stream=stream, stream=stream,
) )
self.graph = builder.build()
self.start_node_id = builder.start_node_id self.start_node_id = builder.start_node_id
self.variable_pool = builder.variable_pool self.variable_pool = builder.variable_pool
self.graph = builder.build()
self.stream_coordinator.initialize_end_outputs(builder.end_node_map) self.stream_coordinator.initialize_end_outputs(builder.end_node_map)
self.event_handler = EventStreamHandler( self.event_handler = EventStreamHandler(
@@ -96,7 +99,8 @@ class WorkflowExecutor:
variable_pool=self.variable_pool, variable_pool=self.variable_pool,
execution_id=self.execution_context.execution_id execution_id=self.execution_context.execution_id
) )
logger.info(f"Workflow graph build completed: execution_id={self.execution_context.execution_id}") logger.info(f"Workflow graph build completed: execution_id={self.execution_context.execution_id}, "
f"cost: {time.time() - start_time:.4f}s")
return self.graph return self.graph
@@ -134,94 +138,12 @@ class WorkflowExecutor:
return event.get("data") return event.get("data")
return self.result_builder.build_final_output( return self.result_builder.build_final_output(
{"error": "Workflow execution did not end as expected"}, {"error": "Workflow execution did not end as expected"},
self.execution_context,
self.variable_pool, self.variable_pool,
(datetime.datetime.now() - start).total_seconds(), (datetime.datetime.now() - start).total_seconds(),
"", "",
success=False success=False
) )
# logger.info(f"Starting workflow execution: execution_id={self.execution_context.execution_id}")
#
# start_time = datetime.datetime.now()
#
# # Execute the workflow
# try:
# # Build the workflow graph
# graph = self.build_graph()
#
# # Initialize the variable pool with input data
# await self.variable_initializer.initialize(
# variable_pool=self.variable_pool,
# input_data=input_data,
# execution_context=self.execution_context
# )
# initial_state = self.state_manager.create_initial_state(
# workflow_config=self.workflow_config,
# input_data=input_data,
# execution_context=self.execution_context,
# start_node_id=self.start_node_id
# )
#
# result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config)
#
# # Aggregate output from all End nodes
# full_content = ''
# for end_id in self.stream_coordinator.end_outputs.keys():
# full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False)
#
# # Append messages for user and assistant
# if input_data.get("files"):
# result["messages"].extend(
# [
# {
# "role": "user",
# "content": input_data.get("message", '')
# },
# {
# "role": "user",
# "content": input_data.get("files")
# },
# {
# "role": "assistant",
# "content": full_content
# }
# ]
# )
# else:
# result["messages"].extend(
# [
# {
# "role": "user",
# "content": input_data.get("message", '')
# },
# {
# "role": "assistant",
# "content": full_content
# }
# ]
# )
# # Calculate elapsed time
# end_time = datetime.datetime.now()
# elapsed_time = (end_time - start_time).total_seconds()
#
# logger.info(
# f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}ms")
#
# return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
#
# except Exception as e:
# end_time = datetime.datetime.now()
# elapsed_time = (end_time - start_time).total_seconds()
#
# logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
# exc_info=True)
# return {
# "status": "failed",
# "error": str(e),
# "output": None,
# "node_outputs": {},
# "elapsed_time": elapsed_time,
# "token_usage": None
# }
async def execute_stream( async def execute_stream(
self, self,
@@ -255,7 +177,7 @@ class WorkflowExecutor:
"data": { "data": {
"execution_id": self.execution_context.execution_id, "execution_id": self.execution_context.execution_id,
"workspace_id": self.execution_context.workspace_id, "workspace_id": self.execution_context.workspace_id,
"conversation_id": input_data.get("conversation_id"), "conversation_id": self.execution_context.conversation_id,
"timestamp": int(start_time.timestamp() * 1000) "timestamp": int(start_time.timestamp() * 1000)
} }
} }
@@ -376,6 +298,7 @@ class WorkflowExecutor:
"event": "workflow_end", "event": "workflow_end",
"data": self.result_builder.build_final_output( "data": self.result_builder.build_final_output(
result, result,
self.execution_context,
self.variable_pool, self.variable_pool,
elapsed_time, elapsed_time,
full_content, full_content,
@@ -396,6 +319,7 @@ class WorkflowExecutor:
"event": "workflow_end", "event": "workflow_end",
"data": self.result_builder.build_final_output( "data": self.result_builder.build_final_output(
result, result,
self.execution_context,
self.variable_pool, self.variable_pool,
elapsed_time, elapsed_time,
full_content, full_content,
@@ -432,6 +356,7 @@ async def execute_workflow(
execution_id=execution_id, execution_id=execution_id,
workspace_id=workspace_id, workspace_id=workspace_id,
user_id=user_id, user_id=user_id,
conversation_id=input_data.get("conversation_id"),
memory_storage_type=memory_storage_type, memory_storage_type=memory_storage_type,
user_rag_memory_id=user_rag_memory_id user_rag_memory_id=user_rag_memory_id
) )
@@ -471,6 +396,7 @@ async def execute_workflow_stream(
workspace_id=workspace_id, workspace_id=workspace_id,
user_id=user_id, user_id=user_id,
memory_storage_type=memory_storage_type, memory_storage_type=memory_storage_type,
conversation_id=input_data.get("conversation_id"),
user_rag_memory_id=user_rag_memory_id user_rag_memory_id=user_rag_memory_id
) )
executor = WorkflowExecutor( executor = WorkflowExecutor(

View File

@@ -65,8 +65,6 @@ class AgentNode(BaseNode):
if not release: if not release:
raise ValueError(f"Agent 不存在: {agent_id}") raise ValueError(f"Agent 不存在: {agent_id}")
return release, message return release, message
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:

View File

@@ -14,8 +14,8 @@ logger = logging.getLogger(__name__)
class AssignerNode(BaseNode): class AssignerNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.variable_updater = True self.variable_updater = True
self.typed_config: AssignerNodeConfig | None = None self.typed_config: AssignerNodeConfig | None = None

View File

@@ -28,7 +28,7 @@ class BaseNode(ABC):
All node types should inherit from this class and implement the `execute` method. All node types should inherit from this class and implement the `execute` method.
""" """
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
"""Initialize the node. """Initialize the node.
Args: Args:
@@ -41,6 +41,7 @@ class BaseNode(ABC):
self.node_type = node_config["type"] self.node_type = node_config["type"]
self.cycle = node_config.get("cycle") self.cycle = node_config.get("cycle")
self.node_name = node_config.get("name", self.node_id) self.node_name = node_config.get("name", self.node_id)
self.down_stream_nodes = down_stream_nodes
# 使用 or 运算符处理 None 值 # 使用 or 运算符处理 None 值
self.config = node_config.get("config") or {} self.config = node_config.get("config") or {}
self.error_handling = node_config.get("error_handling") or {} self.error_handling = node_config.get("error_handling") or {}
@@ -93,18 +94,16 @@ class BaseNode(ABC):
dict: A dict with a single key 'activate', mapping node IDs to dict: A dict with a single key 'activate', mapping node IDs to
their activation status (True/False). their activation status (True/False).
""" """
edges = self.workflow_config.get("edges") activate_flag = self.check_activate(state)
under_stream_nodes = [
edge.get("target") if self.node_type not in BRANCH_NODES:
for edge in edges activate = {node_id: activate_flag for node_id in self.down_stream_nodes}
if edge.get("source") == self.node_id and self.node_type not in BRANCH_NODES else:
] activate = {}
return {
"activate": { activate[self.node_id] = activate_flag
node_id: self.check_activate(state)
for node_id in under_stream_nodes return {"activate": activate}
} | {self.node_id: self.check_activate(state)}
}
@abstractmethod @abstractmethod
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
@@ -315,8 +314,8 @@ class BaseNode(ABC):
elapsed_time = (time.time() - start_time) * 1000 elapsed_time = (time.time() - start_time) * 1000
logger.info(f"Node {self.node_id} streaming execution finished, " logger.debug(f"Node {self.node_id} streaming execution finished, "
f"time elapsed: {elapsed_time:.2f}ms, chunks: {chunk_count}") f"time elapsed: {elapsed_time:.2f}ms, chunks: {chunk_count}")
# Extract processed output (call subclass's _extract_output) # Extract processed output (call subclass's _extract_output)
extracted_output = self._extract_output(final_result) extracted_output = self._extract_output(final_result)
@@ -428,8 +427,8 @@ class BaseNode(ABC):
when an error edge exists. If no error edge exists, this method when an error edge exists. If no error edge exists, this method
raises an exception to stop the workflow. raises an exception to stop the workflow.
""" """
# Check if the node has an error edge defined # # Check if the node has an error edge defined
error_edge = self._find_error_edge() # error_edge = self._find_error_edge()
# Extract input data (for logging or audit purposes) # Extract input data (for logging or audit purposes)
input_data = self._extract_input(state, variable_pool) input_data = self._extract_input(state, variable_pool)
@@ -447,27 +446,26 @@ class BaseNode(ABC):
"error": error_message "error": error_message
} }
if error_edge: # if error_edge:
# If an error edge exists, log a warning and continue to error node # # If an error edge exists, log a warning and continue to error node
logger.warning( # logger.warning(
f"Node {self.node_id} execution failed, redirecting to error node: {error_edge['target']}" # f"Node {self.node_id} execution failed, redirecting to error node: {error_edge['target']}"
) # )
return { # return {
"node_outputs": { # "node_outputs": {
self.node_id: node_output # self.node_id: node_output
}, # },
"error": error_message, # "error": error_message,
"error_node": self.node_id # "error_node": self.node_id
} # }
else: # else:
# If no error edge, send the error via stream writer and stop the workflow writer = get_stream_writer()
writer = get_stream_writer() writer({
writer({ "type": "node_error",
"type": "node_error", **node_output
**node_output })
}) logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}")
logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}") raise Exception(f"Node {self.node_id} execution failed: {error_message}")
raise Exception(f"Node {self.node_id} execution failed: {error_message}")
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""Extracts the input data for this node (used for logging or audit). """Extracts the input data for this node (used for logging or audit).
@@ -644,7 +642,7 @@ class BaseNode(ABC):
if content.content_cache.get(f"{provider}_{ModelInfo.is_omni}"): if content.content_cache.get(f"{provider}_{ModelInfo.is_omni}"):
return content.content_cache[f"{provider}_{ModelInfo.is_omni}"] return content.content_cache[f"{provider}_{ModelInfo.is_omni}"]
with get_db_read() as db: with get_db_read() as db:
multimodel_service = MultimodalService(db, api_config=api_config) multimodal_service = MultimodalService(db, api_config=api_config)
file_obj = FileInput( file_obj = FileInput(
type=content.type, type=content.type,
url=content.url, url=content.url,
@@ -653,7 +651,7 @@ class BaseNode(ABC):
upload_file_id=uuid.UUID(content.file_id) if content.file_id else None, upload_file_id=uuid.UUID(content.file_id) if content.file_id else None,
) )
file_obj.set_content(content.get_content()) file_obj.set_content(content.get_content())
message = await multimodel_service.process_files( message = await multimodal_service.process_files(
[file_obj], [file_obj],
) )
content.set_content(file_obj.get_content()) content.set_content(file_obj.get_content())
@@ -661,7 +659,7 @@ class BaseNode(ABC):
content.content_cache[f"{provider}_{ModelInfo.is_omni}"] = message content.content_cache[f"{provider}_{ModelInfo.is_omni}"] = message
return message return message
return None return None
raise TypeError(f'Unexpect input value type - {type(content)}') raise TypeError(f'Unexpected input value type - {type(content)}')
@staticmethod @staticmethod
def process_model_output(content) -> str: def process_model_output(content) -> str:

View File

@@ -51,8 +51,8 @@ console.log(result)
class CodeNode(BaseNode): class CodeNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: CodeNodeConfig | None = None self.typed_config: CodeNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:

View File

@@ -30,17 +30,13 @@ class CycleGraphNode(BaseNode):
It acts as a container and execution controller for a subgraph. It acts as a container and execution controller for a subgraph.
""" """
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
self.cycle_nodes = list() # Nodes belonging to this cycle
self.cycle_edges = list() # Edges connecting nodes within the cycle
self.start_node_id = None # ID of the start node within the cycle self.start_node_id = None # ID of the start node within the cycle
self.graph: StateGraph | CompiledStateGraph | None = None self.graph: StateGraph | CompiledStateGraph | None = None
self.child_variable_pool: VariablePool | None = None self.child_variable_pool: VariablePool | None = None
self.build_graph()
self.iteration_flag = True
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:
outputs = {"__child_state": VariableType.ARRAY_OBJECT} outputs = {"__child_state": VariableType.ARRAY_OBJECT}
@@ -119,11 +115,11 @@ class CycleGraphNode(BaseNode):
else: else:
remain_edges.append(edge) remain_edges.append(edge)
# Update workflow_config by removing cycle nodes and internal edges # # Update workflow_config by removing cycle nodes and internal edges
self.workflow_config["nodes"] = [ # self.workflow_config["nodes"] = [
node for node in nodes if node.get("cycle") != self.node_id # node for node in nodes if node.get("cycle") != self.node_id
] # ]
self.workflow_config["edges"] = remain_edges # self.workflow_config["edges"] = remain_edges
return cycle_nodes, cycle_edges return cycle_nodes, cycle_edges
@@ -137,18 +133,18 @@ class CycleGraphNode(BaseNode):
3. Compile the graph for runtime execution 3. Compile the graph for runtime execution
""" """
from app.core.workflow.engine.graph_builder import GraphBuilder from app.core.workflow.engine.graph_builder import GraphBuilder
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
self.child_variable_pool = VariablePool() self.child_variable_pool = VariablePool()
builder = GraphBuilder( builder = GraphBuilder(
{ {
"nodes": self.cycle_nodes, "nodes": self.cycle_nodes,
"edges": self.cycle_edges, "edges": self.cycle_edges,
}, },
subgraph=True, variable_pool=self.child_variable_pool,
variable_pool=self.child_variable_pool cycle=self.node_id
) )
self.start_node_id = builder.start_node_id
self.graph = builder.build() self.graph = builder.build()
self.start_node_id = builder.start_node_id
self.child_variable_pool = builder.variable_pool self.child_variable_pool = builder.variable_pool
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
@@ -169,6 +165,7 @@ class CycleGraphNode(BaseNode):
Raises: Raises:
RuntimeError: If the node type is unsupported. RuntimeError: If the node type is unsupported.
""" """
self.build_graph()
if self.node_type == NodeType.LOOP: if self.node_type == NodeType.LOOP:
return await LoopRuntime( return await LoopRuntime(
start_id=self.start_node_id, start_id=self.start_node_id,
@@ -194,6 +191,7 @@ class CycleGraphNode(BaseNode):
raise RuntimeError("Unknown cycle node type") raise RuntimeError("Unknown cycle node type")
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool): async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
self.build_graph()
if self.node_type == NodeType.LOOP: if self.node_type == NodeType.LOOP:
yield { yield {
"__final__": True, "__final__": True,

View File

@@ -1,9 +1,7 @@
"""End 节点配置""" """End 节点配置"""
from pydantic import Field from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.core.workflow.variable.base_variable import VariableType
class EndNodeConfig(BaseNodeConfig): class EndNodeConfig(BaseNodeConfig):

View File

@@ -36,8 +36,6 @@ class EndNode(BaseNode):
Returns: Returns:
最终输出字符串 最终输出字符串
""" """
logger.info(f"节点 {self.node_id} (End) 开始执行")
# 获取配置的输出模板 # 获取配置的输出模板
output_template = self.config.get("output") output_template = self.config.get("output")
@@ -46,11 +44,4 @@ class EndNode(BaseNode):
output = self._render_template(output_template, variable_pool, strict=False) output = self._render_template(output_template, variable_pool, strict=False)
else: else:
output = "" output = ""
# 统计信息(用于日志)
node_outputs = state.get("node_outputs", {})
total_nodes = len(node_outputs)
logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
return output return output

View File

@@ -28,7 +28,7 @@ class NodeType(StrEnum):
NOTES = "notes" NOTES = "notes"
BRANCH_NODES = [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER] BRANCH_NODES = frozenset({NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER})
class ComparisonOperator(StrEnum): class ComparisonOperator(StrEnum):

View File

@@ -115,7 +115,7 @@ class HttpRetryConfig(BaseModel):
) )
class HttpErrorDefaultTamplete(BaseModel): class HttpErrorDefaultTemplate(BaseModel):
body: str = Field( body: str = Field(
default="", default="",
description="Default body returned on HTTP error", description="Default body returned on HTTP error",
@@ -143,7 +143,7 @@ class HttpErrorHandleConfig(BaseModel):
description="Error handling strategy: 'none', 'default', or 'branch'", description="Error handling strategy: 'none', 'default', or 'branch'",
) )
default: HttpErrorDefaultTamplete | None = Field( default: HttpErrorDefaultTemplate | None = Field(
default=None, default=None,
description="Default response template for error handling", description="Default response template for error handling",
) )

View File

@@ -16,7 +16,7 @@ from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput
from app.core.workflow.utils.file_processer import mime_to_file_type from app.core.workflow.utils.file_processor import mime_to_file_type
from app.core.workflow.variable.base_variable import VariableType, FileObject from app.core.workflow.variable.base_variable import VariableType, FileObject
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
from app.schemas import FileType, TransferMethod from app.schemas import FileType, TransferMethod
@@ -157,8 +157,8 @@ class HttpRequestNode(BaseNode):
or a branch identifier string when error branching is enabled. or a branch identifier string when error branching is enabled.
""" """
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: HttpRequestNodeConfig | None = None self.typed_config: HttpRequestNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:

View File

@@ -14,8 +14,8 @@ logger = logging.getLogger(__name__)
class IfElseNode(BaseNode): class IfElseNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: IfElseNodeConfig | None = None self.typed_config: IfElseNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:

View File

@@ -12,8 +12,8 @@ logger = logging.getLogger(__name__)
class JinjaRenderNode(BaseNode): class JinjaRenderNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: JinjaRenderNodeConfig | None = None self.typed_config: JinjaRenderNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:

View File

@@ -21,8 +21,8 @@ logger = logging.getLogger(__name__)
class KnowledgeRetrievalNode(BaseNode): class KnowledgeRetrievalNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: KnowledgeRetrievalNodeConfig | None = None self.typed_config: KnowledgeRetrievalNodeConfig | None = None
self.vector_service: ElasticSearchVector | None = None self.vector_service: ElasticSearchVector | None = None

View File

@@ -70,8 +70,8 @@ class LLMNode(BaseNode):
- ai/assistant: AI 消息AIMessage - ai/assistant: AI 消息AIMessage
""" """
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: LLMNodeConfig | None = None self.typed_config: LLMNodeConfig | None = None
self.messages = [] self.messages = []

View File

@@ -14,8 +14,8 @@ from app.tasks import write_message_task
class MemoryReadNode(BaseNode): class MemoryReadNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: MemoryReadNodeConfig | None = None self.typed_config: MemoryReadNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:
@@ -45,8 +45,8 @@ class MemoryReadNode(BaseNode):
class MemoryWriteNode(BaseNode): class MemoryWriteNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: MemoryWriteNodeConfig | None = None self.typed_config: MemoryWriteNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:

View File

@@ -104,13 +104,15 @@ class NodeFactory:
def create_node( def create_node(
cls, cls,
node_config: dict[str, Any], node_config: dict[str, Any],
workflow_config: dict[str, Any] workflow_config: dict[str, Any],
down_stream_nodes: list[str]
) -> WorkflowNode | None: ) -> WorkflowNode | None:
"""创建节点实例 """创建节点实例
Args: Args:
node_config: 节点配置 node_config: 节点配置
workflow_config: 工作流配置 workflow_config: 工作流配置
down_stream_nodes: 下游节点
Returns: Returns:
节点实例或 None对于不支持的节点类型 节点实例或 None对于不支持的节点类型
@@ -127,7 +129,7 @@ class NodeFactory:
# 创建节点实例 # 创建节点实例
logger.debug(f"create node instance: {node_config.get('id')} (type={node_type})") logger.debug(f"create node instance: {node_config.get('id')} (type={node_type})")
return node_class(node_config, workflow_config) return node_class(node_config, workflow_config, down_stream_nodes)
@classmethod @classmethod
def get_supported_types(cls) -> list[str]: def get_supported_types(cls) -> list[str]:

View File

@@ -21,8 +21,8 @@ logger = logging.getLogger(__name__)
class ParameterExtractorNode(BaseNode): class ParameterExtractorNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: ParameterExtractorNodeConfig | None = None self.typed_config: ParameterExtractorNodeConfig | None = None
self.response_metadata = {} self.response_metadata = {}

View File

@@ -22,8 +22,8 @@ DEFAULT_EMPTY_QUESTION_CASE = f"{DEFAULT_CASE_PREFIX}1"
class QuestionClassifierNode(BaseNode): class QuestionClassifierNode(BaseNode):
"""问题分类器节点""" """问题分类器节点"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: QuestionClassifierNodeConfig | None = None self.typed_config: QuestionClassifierNodeConfig | None = None
self.category_to_case_map = {} self.category_to_case_map = {}
self.response_metadata = {} self.response_metadata = {}

View File

@@ -27,14 +27,8 @@ class StartNode(BaseNode):
注意:变量的验证和默认值处理由 Executor 在初始化时完成。 注意:变量的验证和默认值处理由 Executor 在初始化时完成。
""" """
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
"""初始化 Start 节点 super().__init__(node_config, workflow_config, down_stream_nodes)
Args:
node_config: 节点配置
workflow_config: 工作流配置
"""
super().__init__(node_config, workflow_config)
# 解析并验证配置 # 解析并验证配置
self.typed_config: StartNodeConfig | None = None self.typed_config: StartNodeConfig | None = None
@@ -62,7 +56,6 @@ class StartNode(BaseNode):
包含系统参数、会话变量和自定义变量的字典 包含系统参数、会话变量和自定义变量的字典
""" """
self.typed_config = StartNodeConfig(**self.config) self.typed_config = StartNodeConfig(**self.config)
logger.info(f"节点 {self.node_id} (Start) 开始执行")
# 处理自定义变量(传入 pool 避免重复创建) # 处理自定义变量(传入 pool 避免重复创建)
custom_vars = self._process_custom_variables(variable_pool) custom_vars = self._process_custom_variables(variable_pool)
@@ -77,9 +70,9 @@ class StartNode(BaseNode):
**custom_vars # 自定义变量作为节点输出的一部分 **custom_vars # 自定义变量作为节点输出的一部分
} }
logger.info( logger.debug(
f"节点 {self.node_id} (Start) 执行完成," f"Node {self.node_id} (Start) execution completed, "
f"输出了 {len(custom_vars)} 个自定义变量" f"outputting {len(custom_vars)} custom variables"
) )
return result return result

View File

@@ -20,8 +20,8 @@ TEMPLATE_PATTERN = re.compile(r"\{\{.*?}}")
class ToolNode(BaseNode): class ToolNode(BaseNode):
"""工具节点""" """工具节点"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: ToolNodeConfig | None = None self.typed_config: ToolNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:

View File

@@ -12,8 +12,8 @@ logger = logging.getLogger(__name__)
class VariableAggregatorNode(BaseNode): class VariableAggregatorNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: VariableAggregatorNodeConfig | None = None self.typed_config: VariableAggregatorNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:

View File

@@ -153,7 +153,8 @@ class TemplateRenderer:
# 全局渲染器实例(严格模式) # 全局渲染器实例(严格模式)
_default_renderer = TemplateRenderer(strict=True) _strict_renderer = TemplateRenderer(strict=True)
_lenient_renderer = TemplateRenderer(strict=False)
def render_template( def render_template(
@@ -184,7 +185,7 @@ def render_template(
... ) ... )
'请分析: 这是一段文本' '请分析: 这是一段文本'
""" """
renderer = TemplateRenderer(strict=strict) renderer = _strict_renderer if strict else _lenient_renderer
return renderer.render(template, conv_vars, node_outputs, system_vars) return renderer.render(template, conv_vars, node_outputs, system_vars)
@@ -197,4 +198,4 @@ def validate_template(template: str) -> list[str]:
Returns: Returns:
错误列表 错误列表
""" """
return _default_renderer.validate(template) return _strict_renderer.validate(template)

View File

@@ -6,6 +6,7 @@
import copy import copy
import logging import logging
from collections import defaultdict, deque
from typing import Any, Union, TYPE_CHECKING from typing import Any, Union, TYPE_CHECKING
from app.core.workflow.nodes.enums import NodeType from app.core.workflow.nodes.enums import NodeType
@@ -119,7 +120,6 @@ class WorkflowValidator:
errors = [] errors = []
graphs = cls.get_subgraph(workflow_config) graphs = cls.get_subgraph(workflow_config)
logger.info(graphs)
for index, graph in enumerate(graphs): for index, graph in enumerate(graphs):
nodes = graph.get("nodes", []) nodes = graph.get("nodes", [])
edges = graph.get("edges", []) edges = graph.get("edges", [])
@@ -183,7 +183,7 @@ class WorkflowValidator:
has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges) has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges)
if has_cycle: if has_cycle:
errors.append( errors.append(
f"工作流存在循环依赖(请使用 loop 节点实现循环): {' -> '.join(cycle_path)}" f"工作流存在循环依赖(请使用 loop/iteration 节点实现循环): {' -> '.join(cycle_path)}"
) )
# 8. 验证变量名 # 8. 验证变量名
@@ -204,18 +204,18 @@ class WorkflowValidator:
Returns: Returns:
可达节点 ID 集合 可达节点 ID 集合
""" """
adj = defaultdict(list)
for edge in edges:
adj[edge["source"]].append(edge["target"])
reachable = {start_id} reachable = {start_id}
queue = [start_id] queue = deque([start_id])
while queue: while queue:
current = queue.pop(0) current = queue.popleft()
for edge in edges: for target in adj[current]:
if edge.get("source") == current: if target not in reachable:
target = edge.get("target") reachable.add(target)
if target and target not in reachable: queue.append(target)
reachable.add(target)
queue.append(target)
return reachable return reachable
@staticmethod @staticmethod
@@ -229,10 +229,6 @@ class WorkflowValidator:
Returns: Returns:
(has_cycle, cycle_path): 是否有循环和循环路径 (has_cycle, cycle_path): 是否有循环和循环路径
""" """
# 排除 loop 类型的节点
loop_nodes = {n["id"] for n in nodes if n.get("type") == "loop"}
# 构建邻接表(排除 loop 节点的边和错误边)
graph: dict[str, list[str]] = {} graph: dict[str, list[str]] = {}
for edge in edges: for edge in edges:
source = edge.get("source") source = edge.get("source")
@@ -243,10 +239,6 @@ class WorkflowValidator:
if edge_type == "error": if edge_type == "error":
continue continue
# 如果涉及 loop 节点,跳过
if source in loop_nodes or target in loop_nodes:
continue
if source and target: if source and target:
if source not in graph: if source not in graph:
graph[source] = [] graph[source] = []

View File

@@ -54,7 +54,7 @@ class DictVariable(BaseVariable):
def valid_value(self, value) -> dict: def valid_value(self, value) -> dict:
if not isinstance(value, dict): if not isinstance(value, dict):
raise TypeError(f"Value must be a dict - {type(value)}:{value}") raise TypeError(f"Value must be a dict - {type(value)}:{value}")
return value return value
def to_literal(self) -> str: def to_literal(self) -> str:

View File

@@ -12,7 +12,7 @@ from app.aioRedis import aio_redis_set, aio_redis_get
from app.core.config import settings from app.core.config import settings
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
from app.core.workflow.adapters.base_adapter import WorkflowImportResult, WorkflowParserResult from app.core.workflow.adapters.base_adapter import WorkflowImportResult, WorkflowParserResult
from app.core.workflow.adapters.errors import UnsupportPlatform, InvalidConfiguration from app.core.workflow.adapters.errors import UnsupportedPlatform, InvalidConfiguration
from app.core.workflow.adapters.registry import PlatformAdapterRegistry from app.core.workflow.adapters.registry import PlatformAdapterRegistry
from app.schemas import AppCreate from app.schemas import AppCreate
from app.schemas.workflow_schema import WorkflowConfigCreate from app.schemas.workflow_schema import WorkflowConfigCreate
@@ -46,7 +46,7 @@ class WorkflowImportService:
success=False, success=False,
temp_id=None, temp_id=None,
workflow_id=None, workflow_id=None,
errors=[UnsupportPlatform(platform=platform)] errors=[UnsupportedPlatform(platform=platform)]
) )
adapter = self.registry.get_adapter(platform, config) adapter = self.registry.get_adapter(platform, config)