Merge pull request #682 from SuanmoSuanyangTechnology/pref/workflow-engine
pref(workflow): optimize workflow execution performance and reduce logging noise
This commit is contained in:
@@ -1099,7 +1099,6 @@ class ExtractionOrchestrator:
|
||||
metadata=chunk.metadata,
|
||||
)
|
||||
chunk_nodes.append(chunk_node)
|
||||
logger.error(f"chunk file: {chunk.files}")
|
||||
|
||||
for p, file_type in chunk.files:
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.workflow.adapters.errors import ExceptionDefineition
|
||||
from app.core.workflow.adapters.errors import ExceptionDefinition
|
||||
from app.schemas.workflow_schema import (
|
||||
EdgeDefinition,
|
||||
NodeDefinition,
|
||||
@@ -40,8 +40,8 @@ class WorkflowParserResult(BaseModel):
|
||||
edges: list[EdgeDefinition] = Field(default_factory=list)
|
||||
nodes: list[NodeDefinition] = Field(default_factory=list)
|
||||
variables: list[VariableDefinition] = Field(default_factory=list)
|
||||
warnings: list[ExceptionDefineition] = Field(default_factory=list)
|
||||
errors: list[ExceptionDefineition] = Field(default_factory=list)
|
||||
warnings: list[ExceptionDefinition] = Field(default_factory=list)
|
||||
errors: list[ExceptionDefinition] = Field(default_factory=list)
|
||||
|
||||
|
||||
class WorkflowImportResult(BaseModel):
|
||||
@@ -51,8 +51,8 @@ class WorkflowImportResult(BaseModel):
|
||||
edges: list[EdgeDefinition] = Field(default_factory=list)
|
||||
nodes: list[NodeDefinition] = Field(default_factory=list)
|
||||
variables: list[VariableDefinition] = Field(default_factory=list)
|
||||
warnings: list[ExceptionDefineition] = Field(default_factory=list)
|
||||
errors: list[ExceptionDefineition] = Field(default_factory=list)
|
||||
warnings: list[ExceptionDefinition] = Field(default_factory=list)
|
||||
errors: list[ExceptionDefinition] = Field(default_factory=list)
|
||||
|
||||
|
||||
class BasePlatformAdapter(ABC):
|
||||
|
||||
@@ -9,9 +9,9 @@ from urllib.parse import quote
|
||||
|
||||
from app.core.workflow.adapters.base_converter import BaseConverter
|
||||
from app.core.workflow.adapters.errors import (
|
||||
UnsupportVariableType,
|
||||
UnknowModelWarning,
|
||||
ExceptionDefineition,
|
||||
UnsupportedVariableType,
|
||||
UnknownModelWarning,
|
||||
ExceptionDefinition,
|
||||
ExceptionType
|
||||
)
|
||||
from app.core.workflow.nodes.assigner.config import AssignmentItem
|
||||
@@ -54,7 +54,7 @@ from app.core.workflow.nodes.http_request.config import (
|
||||
HttpFormData,
|
||||
HttpTimeOutConfig,
|
||||
HttpRetryConfig,
|
||||
HttpErrorDefaultTamplete,
|
||||
HttpErrorDefaultTemplate,
|
||||
HttpErrorHandleConfig
|
||||
)
|
||||
from app.core.workflow.nodes.if_else.config import ConditionDetail, ConditionBranchConfig
|
||||
@@ -108,7 +108,7 @@ class DifyConverter(BaseConverter):
|
||||
try:
|
||||
return config.model_validate(value)
|
||||
except Exception as e:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
self.errors.append(ExceptionDefinition(
|
||||
type=ExceptionType.CONFIG,
|
||||
node_id=node_id,
|
||||
node_name=node_name,
|
||||
@@ -138,7 +138,7 @@ class DifyConverter(BaseConverter):
|
||||
var_selector = mapping.get(var_selector, var_selector)
|
||||
return var_selector
|
||||
|
||||
def _process_list_variable_litearl(self, variable_selector: list) -> str | None:
|
||||
def _process_list_variable_literal(self, variable_selector: list) -> str | None:
|
||||
if not self.process_var_selector(".".join(variable_selector)):
|
||||
return None
|
||||
return "{{" + self.process_var_selector(".".join(variable_selector)) + "}}"
|
||||
@@ -269,7 +269,7 @@ class DifyConverter(BaseConverter):
|
||||
var_type = self.variable_type_map(var["type"])
|
||||
if not var_type:
|
||||
self.errors.append(
|
||||
UnsupportVariableType(
|
||||
UnsupportedVariableType(
|
||||
scope=node["id"],
|
||||
name=var["variable"],
|
||||
var_type=var["type"],
|
||||
@@ -281,7 +281,7 @@ class DifyConverter(BaseConverter):
|
||||
|
||||
if var_type in ["file", "array[file]"]:
|
||||
self.errors.append(
|
||||
ExceptionDefineition(
|
||||
ExceptionDefinition(
|
||||
type=ExceptionType.VARIABLE,
|
||||
node_id=node["id"],
|
||||
node_name=node_data["title"],
|
||||
@@ -311,7 +311,7 @@ class DifyConverter(BaseConverter):
|
||||
def convert_question_classifier_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
self.warnings.append(
|
||||
UnknowModelWarning(
|
||||
UnknownModelWarning(
|
||||
node_id=node["id"],
|
||||
node_name=node_data["title"],
|
||||
model_name=node_data["model"].get("name")
|
||||
@@ -327,7 +327,7 @@ class DifyConverter(BaseConverter):
|
||||
)
|
||||
|
||||
result = QuestionClassifierNodeConfig.model_construct(
|
||||
input_variable=self._process_list_variable_litearl(node_data.get("query_variable_selector")),
|
||||
input_variable=self._process_list_variable_literal(node_data.get("query_variable_selector")),
|
||||
user_supplement_prompt=self.trans_variable_format(node_data.get("instructions", "")),
|
||||
categories=categories,
|
||||
).model_dump()
|
||||
@@ -337,13 +337,13 @@ class DifyConverter(BaseConverter):
|
||||
def convert_llm_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
self.warnings.append(
|
||||
UnknowModelWarning(
|
||||
UnknownModelWarning(
|
||||
node_id=node["id"],
|
||||
node_name=node_data["title"],
|
||||
model_name=node_data["model"].get("name")
|
||||
)
|
||||
)
|
||||
context = self._process_list_variable_litearl(node_data["context"]["variable_selector"])
|
||||
context = self._process_list_variable_literal(node_data["context"]["variable_selector"])
|
||||
memory = MemoryWindowSetting(
|
||||
enable=bool(node_data.get("memory")),
|
||||
enable_window=bool(node_data.get("memory", {}).get("window", {}).get("enabled", False)),
|
||||
@@ -367,7 +367,7 @@ class DifyConverter(BaseConverter):
|
||||
)
|
||||
)
|
||||
vision = node_data["vision"]["enabled"]
|
||||
vision_input = self._process_list_variable_litearl(
|
||||
vision_input = self._process_list_variable_literal(
|
||||
node_data["vision"]["configs"]["variable_selector"]
|
||||
) if vision else None
|
||||
result = LLMNodeConfig.model_construct(
|
||||
@@ -433,7 +433,7 @@ class DifyConverter(BaseConverter):
|
||||
conditions.append(
|
||||
LoopConditionDetail.model_construct(
|
||||
operator=self.convert_compare_operator(condition["comparison_operator"]),
|
||||
left=self._process_list_variable_litearl(condition["variable_selector"]),
|
||||
left=self._process_list_variable_literal(condition["variable_selector"]),
|
||||
right=self.trans_variable_format(
|
||||
right_value
|
||||
) if isinstance(right_value, str) and self.is_variable(right_value) else self.convert_variable_type(
|
||||
@@ -453,7 +453,7 @@ class DifyConverter(BaseConverter):
|
||||
right_input_type = variable["value_type"]
|
||||
right_value_type = self.variable_type_map(variable["var_type"])
|
||||
if right_input_type == ValueInputType.VARIABLE:
|
||||
right_value = self._process_list_variable_litearl(variable.get("value", ""))
|
||||
right_value = self._process_list_variable_literal(variable.get("value", ""))
|
||||
else:
|
||||
right_value = self.convert_variable_type(right_value_type, variable.get("value", ""))
|
||||
loop_variables.append(
|
||||
@@ -475,10 +475,10 @@ class DifyConverter(BaseConverter):
|
||||
def convert_iteration_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
result = IterationNodeConfig.model_construct(
|
||||
input=self._process_list_variable_litearl(node_data["iterator_selector"]),
|
||||
input=self._process_list_variable_literal(node_data["iterator_selector"]),
|
||||
parallel=node_data["is_parallel"],
|
||||
parallel_count=node_data["parallel_nums"],
|
||||
output=self._process_list_variable_litearl(node_data["output_selector"]),
|
||||
output=self._process_list_variable_literal(node_data["output_selector"]),
|
||||
output_type=self.variable_type_map(node_data.get("output_type")),
|
||||
flatten=node_data["flatten_output"],
|
||||
).model_dump()
|
||||
@@ -494,8 +494,8 @@ class DifyConverter(BaseConverter):
|
||||
continue
|
||||
assignments.append(
|
||||
AssignmentItem(
|
||||
variable_selector=self._process_list_variable_litearl(assignment["variable_selector"]),
|
||||
value=self._process_list_variable_litearl(
|
||||
variable_selector=self._process_list_variable_literal(assignment["variable_selector"]),
|
||||
value=self._process_list_variable_literal(
|
||||
assignment["value"]
|
||||
) if assignment["input_type"] == ValueInputType.VARIABLE else assignment["value"],
|
||||
operation=self.convert_assignment_operator(assignment["operation"])
|
||||
@@ -514,7 +514,7 @@ class DifyConverter(BaseConverter):
|
||||
input_variables.append(
|
||||
InputVariable.model_construct(
|
||||
name=input_variable["variable"],
|
||||
variable=self._process_list_variable_litearl(input_variable["value_selector"]),
|
||||
variable=self._process_list_variable_literal(input_variable["value_selector"]),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -570,7 +570,7 @@ class DifyConverter(BaseConverter):
|
||||
else:
|
||||
if node_data["body"]["data"]:
|
||||
body_content = (node_data["body"]["data"][0].get("value") or
|
||||
self._process_list_variable_litearl(node_data["body"]["data"][0].get("file")))
|
||||
self._process_list_variable_literal(node_data["body"]["data"][0].get("file")))
|
||||
else:
|
||||
body_content = ""
|
||||
|
||||
@@ -585,7 +585,7 @@ class DifyConverter(BaseConverter):
|
||||
self.trans_variable_format(key_value[0])
|
||||
] = self.trans_variable_format(key_value[1])
|
||||
else:
|
||||
self.warnings.append(ExceptionDefineition(
|
||||
self.warnings.append(ExceptionDefinition(
|
||||
type=ExceptionType.CONFIG,
|
||||
node_id=node["id"],
|
||||
node_name=node_data["title"],
|
||||
@@ -603,7 +603,7 @@ class DifyConverter(BaseConverter):
|
||||
self.trans_variable_format(key_value[0])
|
||||
] = self.trans_variable_format(key_value[1])
|
||||
else:
|
||||
self.warnings.append(ExceptionDefineition(
|
||||
self.warnings.append(ExceptionDefinition(
|
||||
type=ExceptionType.CONFIG,
|
||||
node_id=node["id"],
|
||||
node_name=node_data["title"],
|
||||
@@ -625,7 +625,7 @@ class DifyConverter(BaseConverter):
|
||||
default_header = var["value"]
|
||||
elif var["key"] == "status_code":
|
||||
default_status_code = var["value"]
|
||||
default_value = HttpErrorDefaultTamplete(
|
||||
default_value = HttpErrorDefaultTemplate(
|
||||
body=default_body,
|
||||
headers=default_header,
|
||||
status_code=default_status_code,
|
||||
@@ -668,7 +668,7 @@ class DifyConverter(BaseConverter):
|
||||
for variable in node_data["variables"]:
|
||||
mapping.append(VariablesMappingConfig.model_construct(
|
||||
name=variable["variable"],
|
||||
value=self._process_list_variable_litearl(variable["value_selector"])
|
||||
value=self._process_list_variable_literal(variable["value_selector"])
|
||||
))
|
||||
result = JinjaRenderNodeConfig.model_construct(
|
||||
template=node_data["template"],
|
||||
@@ -679,14 +679,14 @@ class DifyConverter(BaseConverter):
|
||||
|
||||
def convert_knowledge_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
self.warnings.append(ExceptionDefineition(
|
||||
self.warnings.append(ExceptionDefinition(
|
||||
node_id=node["id"],
|
||||
node_name=node_data["title"],
|
||||
type=ExceptionType.CONFIG,
|
||||
detail=f"Please reconfigure the Knowledge Retrieval node.",
|
||||
))
|
||||
result = KnowledgeRetrievalNodeConfig.model_construct(
|
||||
query=self._process_list_variable_litearl(node_data["query_variable_selector"]),
|
||||
query=self._process_list_variable_literal(node_data["query_variable_selector"]),
|
||||
).model_dump()
|
||||
|
||||
self.config_validate(node["id"], node["data"]["title"], KnowledgeRetrievalNodeConfig, result)
|
||||
@@ -695,7 +695,7 @@ class DifyConverter(BaseConverter):
|
||||
def convert_parameter_extractor_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
self.warnings.append(
|
||||
UnknowModelWarning(
|
||||
UnknownModelWarning(
|
||||
node_id=node["id"],
|
||||
node_name=node_data["title"],
|
||||
model_name=node_data["model"].get("name")
|
||||
@@ -712,7 +712,7 @@ class DifyConverter(BaseConverter):
|
||||
)
|
||||
)
|
||||
result = ParameterExtractorNodeConfig.model_construct(
|
||||
text=self._process_list_variable_litearl(node_data["query"]),
|
||||
text=self._process_list_variable_literal(node_data["query"]),
|
||||
params=params,
|
||||
prompt=node_data.get("instruction")
|
||||
).model_dump()
|
||||
@@ -727,14 +727,14 @@ class DifyConverter(BaseConverter):
|
||||
group_type = {}
|
||||
if not advanced_settings or not advanced_settings["group_enabled"]:
|
||||
group_variables = [
|
||||
self._process_list_variable_litearl(variable)
|
||||
self._process_list_variable_literal(variable)
|
||||
for variable in node_data["variables"]
|
||||
]
|
||||
group_type["output"] = node_data["output_type"]
|
||||
else:
|
||||
for group in advanced_settings["groups"]:
|
||||
group_variables[group["group_name"]] = [
|
||||
self._process_list_variable_litearl(variable)
|
||||
self._process_list_variable_literal(variable)
|
||||
for variable in group["variables"]
|
||||
]
|
||||
group_type[group["group_name"]] = group["output_type"]
|
||||
@@ -751,7 +751,7 @@ class DifyConverter(BaseConverter):
|
||||
|
||||
def convert_tool_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
self.warnings.append(ExceptionDefineition(
|
||||
self.warnings.append(ExceptionDefinition(
|
||||
node_id=node["id"],
|
||||
node_name=node_data["title"],
|
||||
type=ExceptionType.CONFIG,
|
||||
|
||||
@@ -12,7 +12,7 @@ from app.core.workflow.adapters.base_adapter import (
|
||||
WorkflowParserResult
|
||||
)
|
||||
from app.core.workflow.adapters.dify.converter import DifyConverter
|
||||
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType
|
||||
from app.core.workflow.adapters.errors import ExceptionDefinition, ExceptionType
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
from app.schemas.workflow_schema import (
|
||||
NodeDefinition,
|
||||
@@ -85,7 +85,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
if not all(field in self.config for field in require_fields):
|
||||
return False
|
||||
if self.config.get("app", {}).get("mode") == "workflow":
|
||||
self.errors.append(ExceptionDefineition(
|
||||
self.errors.append(ExceptionDefinition(
|
||||
type=ExceptionType.PLATFORM,
|
||||
detail="workflow mode is not supported"
|
||||
))
|
||||
@@ -111,12 +111,12 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
edge = self._convert_edge(edge)
|
||||
if edge:
|
||||
self.edges.append(edge)
|
||||
#
|
||||
|
||||
for variable in self.config.get("workflow").get("conversation_variables"):
|
||||
con_var = self._convert_variable(variable)
|
||||
if variable:
|
||||
self.conv_variables.append(con_var)
|
||||
#
|
||||
|
||||
# for variables in config.get("workflow").get("environment_variables"):
|
||||
# variable = self._convert_variable(variables)
|
||||
# conv_variables.append(variable)
|
||||
@@ -152,7 +152,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
"y": node["position"]["y"] + position["y"]
|
||||
}
|
||||
self.errors.append(
|
||||
ExceptionDefineition(
|
||||
ExceptionDefinition(
|
||||
type=ExceptionType.NODE,
|
||||
node_id=node_id,
|
||||
detail="parent cycle node not found"
|
||||
@@ -189,7 +189,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
node_data = node["data"]
|
||||
converter = self.get_node_convert(node_type)
|
||||
if node_type == NodeType.UNKNOWN:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
self.errors.append(ExceptionDefinition(
|
||||
type=ExceptionType.NODE,
|
||||
node_id=node["id"],
|
||||
node_name=node["data"]["title"],
|
||||
@@ -197,7 +197,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
))
|
||||
return converter(node)
|
||||
except Exception as e:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
self.errors.append(ExceptionDefinition(
|
||||
type=ExceptionType.NODE,
|
||||
node_id=node["id"],
|
||||
node_name=node["data"]["title"],
|
||||
@@ -207,7 +207,6 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
|
||||
def _convert_edge(self, edge: dict[str, Any]) -> EdgeDefinition | None:
|
||||
try:
|
||||
|
||||
source = edge["source"]
|
||||
target = edge["target"]
|
||||
label = None
|
||||
@@ -230,7 +229,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
label=label,
|
||||
)
|
||||
except Exception as e:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
self.errors.append(ExceptionDefinition(
|
||||
type=ExceptionType.EDGE,
|
||||
detail=f"convert edge error - {e}",
|
||||
))
|
||||
@@ -246,7 +245,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
description=variable.get("description")
|
||||
)
|
||||
except Exception as e:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
self.errors.append(ExceptionDefinition(
|
||||
type=ExceptionType.VARIABLE,
|
||||
name=variable.get("name"),
|
||||
detail=f"convert variable error - {e}",
|
||||
|
||||
@@ -18,7 +18,7 @@ class ExceptionType(StrEnum):
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
class ExceptionDefineition(BaseModel):
|
||||
class ExceptionDefinition(BaseModel):
|
||||
type: ExceptionType
|
||||
detail: str
|
||||
|
||||
@@ -29,7 +29,7 @@ class ExceptionDefineition(BaseModel):
|
||||
name: str | None = None
|
||||
|
||||
|
||||
class UnknowModelWarning(ExceptionDefineition):
|
||||
class UnknownModelWarning(ExceptionDefinition):
|
||||
type: ExceptionType = ExceptionType.NODE
|
||||
|
||||
def __init__(self, node_id, node_name, model_name):
|
||||
@@ -40,36 +40,36 @@ class UnknowModelWarning(ExceptionDefineition):
|
||||
)
|
||||
|
||||
|
||||
class UnknowError(ExceptionDefineition):
|
||||
class UnknownError(ExceptionDefinition):
|
||||
type: ExceptionType = ExceptionType.UNKNOWN
|
||||
|
||||
def __init__(self, detail: str, **kwargs):
|
||||
super().__init__(detail=detail, **kwargs)
|
||||
|
||||
|
||||
class UnsupportPlatform(ExceptionDefineition):
|
||||
class UnsupportedPlatform(ExceptionDefinition):
|
||||
type: ExceptionType = ExceptionType.PLATFORM
|
||||
|
||||
def __init__(self, platform: str):
|
||||
super().__init__(detail=f"Unsupport platform {platform}")
|
||||
super().__init__(detail=f"Unsupported platform {platform}")
|
||||
|
||||
|
||||
class UnsupportVariableType(ExceptionDefineition):
|
||||
class UnsupportedVariableType(ExceptionDefinition):
|
||||
type: ExceptionType = ExceptionType.VARIABLE
|
||||
|
||||
def __init__(self, scope, name, var_type: str, **kwargs):
|
||||
super().__init__(scope=scope, name=name, detail=f"Unsupport variable type:[{var_type}]", **kwargs)
|
||||
super().__init__(scope=scope, name=name, detail=f"Unsupported variable type: [{var_type}]", **kwargs)
|
||||
|
||||
|
||||
class InvalidConfiguration(ExceptionDefineition):
|
||||
class InvalidConfiguration(ExceptionDefinition):
|
||||
type: ExceptionType = ExceptionType.CONFIG
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(detail="Invalid workflow configuration format")
|
||||
|
||||
|
||||
class UnsupportNodeType(ExceptionDefineition):
|
||||
class UnsupportedNodeType(ExceptionDefinition):
|
||||
type: ExceptionType = ExceptionType.NODE
|
||||
|
||||
def __init__(self, node_id: str, node_type: str):
|
||||
super().__init__(node_id=node_id, detail=f"Unsupport node Type {node_type}")
|
||||
super().__init__(node_id=node_id, detail=f"Unsupported node type {node_type}")
|
||||
|
||||
@@ -11,7 +11,7 @@ from app.core.workflow.adapters.base_adapter import (
|
||||
BasePlatformAdapter,
|
||||
WorkflowParserResult
|
||||
)
|
||||
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType, UnsupportNodeType
|
||||
from app.core.workflow.adapters.errors import ExceptionDefinition, ExceptionType, UnsupportedNodeType
|
||||
from app.core.workflow.adapters.memory_bear.memory_bear_converter import MemoryBearConverter
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
from app.schemas.workflow_schema import ExecutionConfig, NodeDefinition, EdgeDefinition, VariableDefinition
|
||||
@@ -73,7 +73,7 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
|
||||
try:
|
||||
node_type = self.map_node_type(node["type"])
|
||||
if node_type == NodeType.UNKNOWN:
|
||||
self.errors.append(UnsupportNodeType(
|
||||
self.errors.append(UnsupportedNodeType(
|
||||
node_id=node_id,
|
||||
node_type=node["type"]
|
||||
))
|
||||
@@ -85,7 +85,7 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
|
||||
|
||||
return NodeDefinition(**node)
|
||||
except Exception as e:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
self.errors.append(ExceptionDefinition(
|
||||
type=ExceptionType.NODE,
|
||||
node_id=node_id,
|
||||
node_name=node_name,
|
||||
@@ -97,14 +97,14 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
|
||||
def _convert_edge(self, edge: dict[str, Any], valid_node_ids: set) -> EdgeDefinition | None:
|
||||
try:
|
||||
if edge.get("source") not in valid_node_ids or edge.get("target") not in valid_node_ids:
|
||||
self.warnings.append(ExceptionDefineition(
|
||||
self.warnings.append(ExceptionDefinition(
|
||||
type=ExceptionType.EDGE,
|
||||
detail=f"edge {edge.get('id')} skipped: source or target node not found"
|
||||
))
|
||||
return None
|
||||
return EdgeDefinition(**edge)
|
||||
except Exception as e:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
self.errors.append(ExceptionDefinition(
|
||||
type=ExceptionType.EDGE,
|
||||
detail=f"convert edge error - {e}"
|
||||
))
|
||||
@@ -115,7 +115,7 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
|
||||
try:
|
||||
return VariableDefinition(**variable)
|
||||
except Exception as e:
|
||||
self.warnings.append(ExceptionDefineition(
|
||||
self.warnings.append(ExceptionDefinition(
|
||||
type=ExceptionType.VARIABLE,
|
||||
name=variable.get("name"),
|
||||
detail=f"convert variable error - {e}"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
from app.core.workflow.adapters.base_converter import BaseConverter
|
||||
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType
|
||||
from app.core.workflow.adapters.errors import ExceptionDefinition, ExceptionType
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
from app.core.workflow.nodes.configs import (
|
||||
StartNodeConfig,
|
||||
@@ -65,7 +65,7 @@ class MemoryBearConverter(BaseConverter):
|
||||
try:
|
||||
return config_cls.model_validate(value)
|
||||
except Exception as e:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
self.errors.append(ExceptionDefinition(
|
||||
type=ExceptionType.CONFIG,
|
||||
node_id=node_id,
|
||||
node_name=node_name,
|
||||
|
||||
@@ -7,7 +7,7 @@ import re
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from functools import lru_cache
|
||||
from typing import Any, Iterable
|
||||
from typing import Any, Iterable, Callable
|
||||
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph import START, END
|
||||
@@ -41,48 +41,31 @@ class GraphBuilder:
|
||||
self,
|
||||
workflow_config: dict[str, Any],
|
||||
stream: bool = False,
|
||||
subgraph: bool = False,
|
||||
cycle: str = '',
|
||||
variable_pool: VariablePool | None = None
|
||||
):
|
||||
self.workflow_config = workflow_config
|
||||
|
||||
self.stream = stream
|
||||
self.subgraph = subgraph
|
||||
self.cycle = cycle
|
||||
|
||||
self.start_node_id: str | None = None
|
||||
|
||||
self.node_map = {node["id"]: node for node in self.nodes}
|
||||
self.node_map: dict[str, dict] = {}
|
||||
self.end_node_map: dict[str, StreamOutputConfig] = {}
|
||||
self._find_upstream_activation_dep = lru_cache(
|
||||
maxsize=len(self.nodes) * 2
|
||||
)(self._find_upstream_activation_dep)
|
||||
self._find_upstream_activation_dep: Callable = self._find_upstream_activation_dep
|
||||
if variable_pool:
|
||||
self.variable_pool = variable_pool
|
||||
else:
|
||||
self.variable_pool = VariablePool()
|
||||
|
||||
self.graph = StateGraph(WorkflowState)
|
||||
self.add_nodes()
|
||||
self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges)
|
||||
self.end_nodes = [
|
||||
node
|
||||
for node in self.nodes
|
||||
if node.get("type") == "end" and node.get("id") in self.reachable_nodes
|
||||
]
|
||||
self.add_edges()
|
||||
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
|
||||
|
||||
self.graph: StateGraph | None = None
|
||||
self.nodes: list = []
|
||||
self.edges: list = []
|
||||
self.reachable_nodes: set[str] | None = None
|
||||
self.end_nodes: list[dict] = []
|
||||
self._reverse_adj: dict[str, list[dict]] = defaultdict(list)
|
||||
self._build_reverse_adj()
|
||||
self._analyze_end_node_output()
|
||||
|
||||
@property
|
||||
def nodes(self) -> list[dict[str, Any]]:
|
||||
return self.workflow_config.get("nodes", [])
|
||||
|
||||
@property
|
||||
def edges(self) -> list[dict[str, Any]]:
|
||||
return self.workflow_config.get("edges", [])
|
||||
self._adj: dict[str, list[str]] = defaultdict(list)
|
||||
|
||||
def get_node_type(self, node_id: str) -> str:
|
||||
"""Retrieve the type of node given its ID.
|
||||
@@ -108,13 +91,14 @@ class GraphBuilder:
|
||||
result[node[0]].append(node[1])
|
||||
return result
|
||||
|
||||
def _build_reverse_adj(self):
|
||||
def _build_adj(self):
|
||||
for edge in self.edges:
|
||||
if edge["source"] not in self.reachable_nodes:
|
||||
continue
|
||||
self._reverse_adj[edge.get("target")].append({
|
||||
"id": edge["source"], "branch": edge.get("label")
|
||||
})
|
||||
self._adj[edge.get("source")].append(edge["target"])
|
||||
|
||||
def _find_upstream_activation_dep(
|
||||
self,
|
||||
@@ -302,22 +286,13 @@ class GraphBuilder:
|
||||
"""
|
||||
for node in self.nodes:
|
||||
node_type = node.get("type")
|
||||
if node_type == NodeType.NOTES:
|
||||
continue
|
||||
node_id = node.get("id")
|
||||
cycle_node = node.get("cycle")
|
||||
if cycle_node:
|
||||
# Nodes within a loop subgraph are constructed by CycleGraphNode
|
||||
if not self.subgraph:
|
||||
continue
|
||||
|
||||
# Record start and end node IDs
|
||||
if node_type in [NodeType.START, NodeType.CYCLE_START]:
|
||||
self.start_node_id = node_id
|
||||
if node_id not in self.reachable_nodes:
|
||||
continue
|
||||
|
||||
# Create node instance (start and end nodes are also created)
|
||||
# NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph
|
||||
node_instance = NodeFactory.create_node(node, self.workflow_config)
|
||||
node_instance = NodeFactory.create_node(node, self.workflow_config, self._adj[node_id])
|
||||
|
||||
if node_type in BRANCH_NODES:
|
||||
|
||||
@@ -413,11 +388,12 @@ class GraphBuilder:
|
||||
# Add conditional edges
|
||||
for source_node, branches in conditional_edges.items():
|
||||
def make_router(src, branch_list):
|
||||
"""reate a router function for each source node that routes to a NOP node for later merging."""
|
||||
"""Create a router function for each source node that routes to a NOP node for later merging."""
|
||||
|
||||
def make_branch_node(node_name, targets):
|
||||
def node(s):
|
||||
# NOTE: NOP NODE MUST NOT MODIFY STATE
|
||||
# NOTE: NOP NODE USED FOR ROUTING ONLY.
|
||||
# MUST NOT MUTATE STATE DIRECTLY; ONLY EMIT ACTIVATE SIGNALS.
|
||||
return {
|
||||
"activate": {
|
||||
node_id: s["activate"][node_name]
|
||||
@@ -504,14 +480,52 @@ class GraphBuilder:
|
||||
logger.debug(f"Added waiting edge: {sources} -> {target}")
|
||||
|
||||
# Connect End nodes to the global END node
|
||||
for end_node in self.end_nodes:
|
||||
end_node_id = end_node.get("id")
|
||||
if end_node_id:
|
||||
self.graph.add_edge(end_node_id, END)
|
||||
logger.debug(f"Added edge: {end_node_id} -> END")
|
||||
for node in self.reachable_nodes:
|
||||
if not self._adj[node]:
|
||||
self.graph.add_edge(node, END)
|
||||
return
|
||||
|
||||
def build(self) -> CompiledStateGraph:
|
||||
nodes = self.workflow_config.get("nodes", [])
|
||||
edges = self.workflow_config.get("edges", [])
|
||||
|
||||
for node in nodes:
|
||||
if (node.get("cycle") or '') == self.cycle:
|
||||
node_type = node.get("type")
|
||||
if node_type in [NodeType.START, NodeType.CYCLE_START]:
|
||||
self.start_node_id = node.get("id")
|
||||
elif node_type == NodeType.NOTES:
|
||||
continue
|
||||
self.nodes.append(node)
|
||||
self.node_map[node.get("id")] = node
|
||||
|
||||
for edge in edges:
|
||||
source_in = edge.get("source") in self.node_map
|
||||
target_in = edge.get("target") in self.node_map
|
||||
if source_in ^ target_in:
|
||||
raise ValueError(
|
||||
f"Cycle node is connected to external node, "
|
||||
f"source: {edge.get('source')}, target: {edge.get('target')}"
|
||||
)
|
||||
|
||||
if source_in and target_in:
|
||||
self.edges.append(edge)
|
||||
|
||||
self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges)
|
||||
self.end_nodes = [
|
||||
node
|
||||
for node in self.nodes
|
||||
if node.get("type") == "end" and node.get("id") in self.reachable_nodes
|
||||
]
|
||||
self._build_adj()
|
||||
self._find_upstream_activation_dep: Callable = lru_cache(
|
||||
maxsize=len(self.nodes)*2
|
||||
)(self._find_upstream_activation_dep)
|
||||
|
||||
self.graph = StateGraph(WorkflowState)
|
||||
self.add_nodes()
|
||||
self.add_edges()
|
||||
|
||||
self._analyze_end_node_output()
|
||||
checkpointer = InMemorySaver()
|
||||
self.graph = self.graph.compile(checkpointer=checkpointer)
|
||||
return self.graph
|
||||
return self.graph.compile(checkpointer=checkpointer)
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/10 13:33
|
||||
from app.core.workflow.engine.runtime_schema import ExecutionContext
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
|
||||
|
||||
@@ -9,6 +10,7 @@ class WorkflowResultBuilder:
|
||||
def build_final_output(
|
||||
self,
|
||||
result: dict,
|
||||
execution_context: ExecutionContext,
|
||||
variable_pool: VariablePool,
|
||||
elapsed_time: float,
|
||||
final_output: str,
|
||||
@@ -26,6 +28,8 @@ class WorkflowResultBuilder:
|
||||
- "node_outputs" (dict): Outputs of executed nodes.
|
||||
- "messages" (list): Conversation messages exchanged during execution.
|
||||
- "error" (str, optional): Error message if any node failed.
|
||||
execution_context (ExecutionContext): The execution context containing metadata like
|
||||
execution ID, workspace ID, and user ID.)
|
||||
variable_pool (VariablePool): Variable Pool
|
||||
elapsed_time (float): Total execution time in seconds.
|
||||
final_output (Any): The aggregated or final output content of the workflow
|
||||
@@ -48,18 +52,23 @@ class WorkflowResultBuilder:
|
||||
"""
|
||||
node_outputs = result.get("node_outputs", {})
|
||||
token_usage = self.aggregate_token_usage(node_outputs)
|
||||
conversation_id = variable_pool.get_value("sys.conversation_id")
|
||||
conversation_vars = {}
|
||||
sys_vars = {}
|
||||
|
||||
if variable_pool:
|
||||
conversation_vars = variable_pool.get_all_conversation_vars()
|
||||
sys_vars = variable_pool.get_all_system_vars()
|
||||
|
||||
return {
|
||||
"status": "completed" if success else "failed",
|
||||
"output": final_output,
|
||||
"variables": {
|
||||
"conv": variable_pool.get_all_conversation_vars(),
|
||||
"sys": variable_pool.get_all_system_vars()
|
||||
"conv": conversation_vars,
|
||||
"sys": sys_vars
|
||||
},
|
||||
"node_outputs": node_outputs,
|
||||
"messages": result.get("messages", []),
|
||||
"conversation_id": conversation_id,
|
||||
"conversation_id": execution_context.conversation_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
"token_usage": token_usage,
|
||||
"error": result.get("error"),
|
||||
|
||||
@@ -12,6 +12,7 @@ class ExecutionContext(BaseModel):
|
||||
execution_id: str
|
||||
workspace_id: str
|
||||
user_id: str
|
||||
conversation_id: str
|
||||
memory_storage_type: str
|
||||
user_rag_memory_id: str
|
||||
checkpoint_config: RunnableConfig
|
||||
@@ -22,6 +23,7 @@ class ExecutionContext(BaseModel):
|
||||
execution_id: str,
|
||||
workspace_id: str,
|
||||
user_id: str,
|
||||
conversation_id: str,
|
||||
memory_storage_type: str,
|
||||
user_rag_memory_id: str
|
||||
):
|
||||
@@ -29,6 +31,7 @@ class ExecutionContext(BaseModel):
|
||||
execution_id=execution_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id,
|
||||
conversation_id=conversation_id,
|
||||
memory_storage_type=memory_storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/9 13:51
|
||||
import datetime
|
||||
import time
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
@@ -82,13 +83,15 @@ class WorkflowExecutor:
|
||||
CompiledStateGraph: The compiled and ready-to-run state graph.
|
||||
"""
|
||||
logger.info(f"Starting workflow graph build: execution_id={self.execution_context.execution_id}")
|
||||
start_time = time.time()
|
||||
builder = GraphBuilder(
|
||||
self.workflow_config,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
self.graph = builder.build()
|
||||
self.start_node_id = builder.start_node_id
|
||||
self.variable_pool = builder.variable_pool
|
||||
self.graph = builder.build()
|
||||
|
||||
self.stream_coordinator.initialize_end_outputs(builder.end_node_map)
|
||||
self.event_handler = EventStreamHandler(
|
||||
@@ -96,7 +99,8 @@ class WorkflowExecutor:
|
||||
variable_pool=self.variable_pool,
|
||||
execution_id=self.execution_context.execution_id
|
||||
)
|
||||
logger.info(f"Workflow graph build completed: execution_id={self.execution_context.execution_id}")
|
||||
logger.info(f"Workflow graph build completed: execution_id={self.execution_context.execution_id}, "
|
||||
f"cost: {time.time() - start_time:.4f}s")
|
||||
|
||||
return self.graph
|
||||
|
||||
@@ -134,94 +138,12 @@ class WorkflowExecutor:
|
||||
return event.get("data")
|
||||
return self.result_builder.build_final_output(
|
||||
{"error": "Workflow execution did not end as expected"},
|
||||
self.execution_context,
|
||||
self.variable_pool,
|
||||
(datetime.datetime.now() - start).total_seconds(),
|
||||
"",
|
||||
success=False
|
||||
)
|
||||
# logger.info(f"Starting workflow execution: execution_id={self.execution_context.execution_id}")
|
||||
#
|
||||
# start_time = datetime.datetime.now()
|
||||
#
|
||||
# # Execute the workflow
|
||||
# try:
|
||||
# # Build the workflow graph
|
||||
# graph = self.build_graph()
|
||||
#
|
||||
# # Initialize the variable pool with input data
|
||||
# await self.variable_initializer.initialize(
|
||||
# variable_pool=self.variable_pool,
|
||||
# input_data=input_data,
|
||||
# execution_context=self.execution_context
|
||||
# )
|
||||
# initial_state = self.state_manager.create_initial_state(
|
||||
# workflow_config=self.workflow_config,
|
||||
# input_data=input_data,
|
||||
# execution_context=self.execution_context,
|
||||
# start_node_id=self.start_node_id
|
||||
# )
|
||||
#
|
||||
# result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config)
|
||||
#
|
||||
# # Aggregate output from all End nodes
|
||||
# full_content = ''
|
||||
# for end_id in self.stream_coordinator.end_outputs.keys():
|
||||
# full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False)
|
||||
#
|
||||
# # Append messages for user and assistant
|
||||
# if input_data.get("files"):
|
||||
# result["messages"].extend(
|
||||
# [
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": input_data.get("message", '')
|
||||
# },
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": input_data.get("files")
|
||||
# },
|
||||
# {
|
||||
# "role": "assistant",
|
||||
# "content": full_content
|
||||
# }
|
||||
# ]
|
||||
# )
|
||||
# else:
|
||||
# result["messages"].extend(
|
||||
# [
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": input_data.get("message", '')
|
||||
# },
|
||||
# {
|
||||
# "role": "assistant",
|
||||
# "content": full_content
|
||||
# }
|
||||
# ]
|
||||
# )
|
||||
# # Calculate elapsed time
|
||||
# end_time = datetime.datetime.now()
|
||||
# elapsed_time = (end_time - start_time).total_seconds()
|
||||
#
|
||||
# logger.info(
|
||||
# f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}ms")
|
||||
#
|
||||
# return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
|
||||
#
|
||||
# except Exception as e:
|
||||
# end_time = datetime.datetime.now()
|
||||
# elapsed_time = (end_time - start_time).total_seconds()
|
||||
#
|
||||
# logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
|
||||
# exc_info=True)
|
||||
# return {
|
||||
# "status": "failed",
|
||||
# "error": str(e),
|
||||
# "output": None,
|
||||
# "node_outputs": {},
|
||||
# "elapsed_time": elapsed_time,
|
||||
# "token_usage": None
|
||||
# }
|
||||
|
||||
async def execute_stream(
|
||||
self,
|
||||
@@ -255,7 +177,7 @@ class WorkflowExecutor:
|
||||
"data": {
|
||||
"execution_id": self.execution_context.execution_id,
|
||||
"workspace_id": self.execution_context.workspace_id,
|
||||
"conversation_id": input_data.get("conversation_id"),
|
||||
"conversation_id": self.execution_context.conversation_id,
|
||||
"timestamp": int(start_time.timestamp() * 1000)
|
||||
}
|
||||
}
|
||||
@@ -376,6 +298,7 @@ class WorkflowExecutor:
|
||||
"event": "workflow_end",
|
||||
"data": self.result_builder.build_final_output(
|
||||
result,
|
||||
self.execution_context,
|
||||
self.variable_pool,
|
||||
elapsed_time,
|
||||
full_content,
|
||||
@@ -396,6 +319,7 @@ class WorkflowExecutor:
|
||||
"event": "workflow_end",
|
||||
"data": self.result_builder.build_final_output(
|
||||
result,
|
||||
self.execution_context,
|
||||
self.variable_pool,
|
||||
elapsed_time,
|
||||
full_content,
|
||||
@@ -432,6 +356,7 @@ async def execute_workflow(
|
||||
execution_id=execution_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id,
|
||||
conversation_id=input_data.get("conversation_id"),
|
||||
memory_storage_type=memory_storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
@@ -471,6 +396,7 @@ async def execute_workflow_stream(
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id,
|
||||
memory_storage_type=memory_storage_type,
|
||||
conversation_id=input_data.get("conversation_id"),
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
executor = WorkflowExecutor(
|
||||
|
||||
@@ -64,9 +64,7 @@ class AgentNode(BaseNode):
|
||||
|
||||
if not release:
|
||||
raise ValueError(f"Agent 不存在: {agent_id}")
|
||||
|
||||
|
||||
|
||||
return release, message
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
|
||||
@@ -14,8 +14,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AssignerNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||
self.variable_updater = True
|
||||
self.typed_config: AssignerNodeConfig | None = None
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ class BaseNode(ABC):
|
||||
All node types should inherit from this class and implement the `execute` method.
|
||||
"""
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||
"""Initialize the node.
|
||||
|
||||
Args:
|
||||
@@ -41,6 +41,7 @@ class BaseNode(ABC):
|
||||
self.node_type = node_config["type"]
|
||||
self.cycle = node_config.get("cycle")
|
||||
self.node_name = node_config.get("name", self.node_id)
|
||||
self.down_stream_nodes = down_stream_nodes
|
||||
# 使用 or 运算符处理 None 值
|
||||
self.config = node_config.get("config") or {}
|
||||
self.error_handling = node_config.get("error_handling") or {}
|
||||
@@ -93,18 +94,16 @@ class BaseNode(ABC):
|
||||
dict: A dict with a single key 'activate', mapping node IDs to
|
||||
their activation status (True/False).
|
||||
"""
|
||||
edges = self.workflow_config.get("edges")
|
||||
under_stream_nodes = [
|
||||
edge.get("target")
|
||||
for edge in edges
|
||||
if edge.get("source") == self.node_id and self.node_type not in BRANCH_NODES
|
||||
]
|
||||
return {
|
||||
"activate": {
|
||||
node_id: self.check_activate(state)
|
||||
for node_id in under_stream_nodes
|
||||
} | {self.node_id: self.check_activate(state)}
|
||||
}
|
||||
activate_flag = self.check_activate(state)
|
||||
|
||||
if self.node_type not in BRANCH_NODES:
|
||||
activate = {node_id: activate_flag for node_id in self.down_stream_nodes}
|
||||
else:
|
||||
activate = {}
|
||||
|
||||
activate[self.node_id] = activate_flag
|
||||
|
||||
return {"activate": activate}
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
@@ -315,8 +314,8 @@ class BaseNode(ABC):
|
||||
|
||||
elapsed_time = (time.time() - start_time) * 1000
|
||||
|
||||
logger.info(f"Node {self.node_id} streaming execution finished, "
|
||||
f"time elapsed: {elapsed_time:.2f}ms, chunks: {chunk_count}")
|
||||
logger.debug(f"Node {self.node_id} streaming execution finished, "
|
||||
f"time elapsed: {elapsed_time:.2f}ms, chunks: {chunk_count}")
|
||||
|
||||
# Extract processed output (call subclass's _extract_output)
|
||||
extracted_output = self._extract_output(final_result)
|
||||
@@ -428,8 +427,8 @@ class BaseNode(ABC):
|
||||
when an error edge exists. If no error edge exists, this method
|
||||
raises an exception to stop the workflow.
|
||||
"""
|
||||
# Check if the node has an error edge defined
|
||||
error_edge = self._find_error_edge()
|
||||
# # Check if the node has an error edge defined
|
||||
# error_edge = self._find_error_edge()
|
||||
|
||||
# Extract input data (for logging or audit purposes)
|
||||
input_data = self._extract_input(state, variable_pool)
|
||||
@@ -447,27 +446,26 @@ class BaseNode(ABC):
|
||||
"error": error_message
|
||||
}
|
||||
|
||||
if error_edge:
|
||||
# If an error edge exists, log a warning and continue to error node
|
||||
logger.warning(
|
||||
f"Node {self.node_id} execution failed, redirecting to error node: {error_edge['target']}"
|
||||
)
|
||||
return {
|
||||
"node_outputs": {
|
||||
self.node_id: node_output
|
||||
},
|
||||
"error": error_message,
|
||||
"error_node": self.node_id
|
||||
}
|
||||
else:
|
||||
# If no error edge, send the error via stream writer and stop the workflow
|
||||
writer = get_stream_writer()
|
||||
writer({
|
||||
"type": "node_error",
|
||||
**node_output
|
||||
})
|
||||
logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}")
|
||||
raise Exception(f"Node {self.node_id} execution failed: {error_message}")
|
||||
# if error_edge:
|
||||
# # If an error edge exists, log a warning and continue to error node
|
||||
# logger.warning(
|
||||
# f"Node {self.node_id} execution failed, redirecting to error node: {error_edge['target']}"
|
||||
# )
|
||||
# return {
|
||||
# "node_outputs": {
|
||||
# self.node_id: node_output
|
||||
# },
|
||||
# "error": error_message,
|
||||
# "error_node": self.node_id
|
||||
# }
|
||||
# else:
|
||||
writer = get_stream_writer()
|
||||
writer({
|
||||
"type": "node_error",
|
||||
**node_output
|
||||
})
|
||||
logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}")
|
||||
raise Exception(f"Node {self.node_id} execution failed: {error_message}")
|
||||
|
||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
"""Extracts the input data for this node (used for logging or audit).
|
||||
@@ -644,7 +642,7 @@ class BaseNode(ABC):
|
||||
if content.content_cache.get(f"{provider}_{ModelInfo.is_omni}"):
|
||||
return content.content_cache[f"{provider}_{ModelInfo.is_omni}"]
|
||||
with get_db_read() as db:
|
||||
multimodel_service = MultimodalService(db, api_config=api_config)
|
||||
multimodal_service = MultimodalService(db, api_config=api_config)
|
||||
file_obj = FileInput(
|
||||
type=content.type,
|
||||
url=content.url,
|
||||
@@ -653,7 +651,7 @@ class BaseNode(ABC):
|
||||
upload_file_id=uuid.UUID(content.file_id) if content.file_id else None,
|
||||
)
|
||||
file_obj.set_content(content.get_content())
|
||||
message = await multimodel_service.process_files(
|
||||
message = await multimodal_service.process_files(
|
||||
[file_obj],
|
||||
)
|
||||
content.set_content(file_obj.get_content())
|
||||
@@ -661,7 +659,7 @@ class BaseNode(ABC):
|
||||
content.content_cache[f"{provider}_{ModelInfo.is_omni}"] = message
|
||||
return message
|
||||
return None
|
||||
raise TypeError(f'Unexpect input value type - {type(content)}')
|
||||
raise TypeError(f'Unexpected input value type - {type(content)}')
|
||||
|
||||
@staticmethod
|
||||
def process_model_output(content) -> str:
|
||||
|
||||
@@ -51,8 +51,8 @@ console.log(result)
|
||||
|
||||
|
||||
class CodeNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||
self.typed_config: CodeNodeConfig | None = None
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
|
||||
@@ -30,17 +30,13 @@ class CycleGraphNode(BaseNode):
|
||||
It acts as a container and execution controller for a subgraph.
|
||||
"""
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
|
||||
self.cycle_nodes = list() # Nodes belonging to this cycle
|
||||
self.cycle_edges = list() # Edges connecting nodes within the cycle
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
|
||||
self.start_node_id = None # ID of the start node within the cycle
|
||||
|
||||
self.graph: StateGraph | CompiledStateGraph | None = None
|
||||
self.child_variable_pool: VariablePool | None = None
|
||||
self.build_graph()
|
||||
self.iteration_flag = True
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
outputs = {"__child_state": VariableType.ARRAY_OBJECT}
|
||||
@@ -119,11 +115,11 @@ class CycleGraphNode(BaseNode):
|
||||
else:
|
||||
remain_edges.append(edge)
|
||||
|
||||
# Update workflow_config by removing cycle nodes and internal edges
|
||||
self.workflow_config["nodes"] = [
|
||||
node for node in nodes if node.get("cycle") != self.node_id
|
||||
]
|
||||
self.workflow_config["edges"] = remain_edges
|
||||
# # Update workflow_config by removing cycle nodes and internal edges
|
||||
# self.workflow_config["nodes"] = [
|
||||
# node for node in nodes if node.get("cycle") != self.node_id
|
||||
# ]
|
||||
# self.workflow_config["edges"] = remain_edges
|
||||
|
||||
return cycle_nodes, cycle_edges
|
||||
|
||||
@@ -137,18 +133,18 @@ class CycleGraphNode(BaseNode):
|
||||
3. Compile the graph for runtime execution
|
||||
"""
|
||||
from app.core.workflow.engine.graph_builder import GraphBuilder
|
||||
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
|
||||
|
||||
self.child_variable_pool = VariablePool()
|
||||
builder = GraphBuilder(
|
||||
{
|
||||
"nodes": self.cycle_nodes,
|
||||
"edges": self.cycle_edges,
|
||||
},
|
||||
subgraph=True,
|
||||
variable_pool=self.child_variable_pool
|
||||
variable_pool=self.child_variable_pool,
|
||||
cycle=self.node_id
|
||||
)
|
||||
self.start_node_id = builder.start_node_id
|
||||
self.graph = builder.build()
|
||||
self.start_node_id = builder.start_node_id
|
||||
self.child_variable_pool = builder.variable_pool
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
@@ -169,6 +165,7 @@ class CycleGraphNode(BaseNode):
|
||||
Raises:
|
||||
RuntimeError: If the node type is unsupported.
|
||||
"""
|
||||
self.build_graph()
|
||||
if self.node_type == NodeType.LOOP:
|
||||
return await LoopRuntime(
|
||||
start_id=self.start_node_id,
|
||||
@@ -194,6 +191,7 @@ class CycleGraphNode(BaseNode):
|
||||
raise RuntimeError("Unknown cycle node type")
|
||||
|
||||
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
|
||||
self.build_graph()
|
||||
if self.node_type == NodeType.LOOP:
|
||||
yield {
|
||||
"__final__": True,
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
"""End 节点配置"""
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
|
||||
|
||||
class EndNodeConfig(BaseNodeConfig):
|
||||
|
||||
@@ -36,8 +36,6 @@ class EndNode(BaseNode):
|
||||
Returns:
|
||||
最终输出字符串
|
||||
"""
|
||||
logger.info(f"节点 {self.node_id} (End) 开始执行")
|
||||
|
||||
# 获取配置的输出模板
|
||||
output_template = self.config.get("output")
|
||||
|
||||
@@ -46,11 +44,4 @@ class EndNode(BaseNode):
|
||||
output = self._render_template(output_template, variable_pool, strict=False)
|
||||
else:
|
||||
output = ""
|
||||
|
||||
# 统计信息(用于日志)
|
||||
node_outputs = state.get("node_outputs", {})
|
||||
total_nodes = len(node_outputs)
|
||||
|
||||
logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
|
||||
|
||||
return output
|
||||
|
||||
@@ -28,7 +28,7 @@ class NodeType(StrEnum):
|
||||
NOTES = "notes"
|
||||
|
||||
|
||||
BRANCH_NODES = [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER]
|
||||
BRANCH_NODES = frozenset({NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER})
|
||||
|
||||
|
||||
class ComparisonOperator(StrEnum):
|
||||
|
||||
@@ -115,7 +115,7 @@ class HttpRetryConfig(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class HttpErrorDefaultTamplete(BaseModel):
|
||||
class HttpErrorDefaultTemplate(BaseModel):
|
||||
body: str = Field(
|
||||
default="",
|
||||
description="Default body returned on HTTP error",
|
||||
@@ -143,7 +143,7 @@ class HttpErrorHandleConfig(BaseModel):
|
||||
description="Error handling strategy: 'none', 'default', or 'branch'",
|
||||
)
|
||||
|
||||
default: HttpErrorDefaultTamplete | None = Field(
|
||||
default: HttpErrorDefaultTemplate | None = Field(
|
||||
default=None,
|
||||
description="Default response template for error handling",
|
||||
)
|
||||
|
||||
@@ -16,7 +16,7 @@ from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType
|
||||
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput
|
||||
from app.core.workflow.utils.file_processer import mime_to_file_type
|
||||
from app.core.workflow.utils.file_processor import mime_to_file_type
|
||||
from app.core.workflow.variable.base_variable import VariableType, FileObject
|
||||
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
|
||||
from app.schemas import FileType, TransferMethod
|
||||
@@ -157,8 +157,8 @@ class HttpRequestNode(BaseNode):
|
||||
or a branch identifier string when error branching is enabled.
|
||||
"""
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||
self.typed_config: HttpRequestNodeConfig | None = None
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
|
||||
@@ -14,8 +14,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IfElseNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||
self.typed_config: IfElseNodeConfig | None = None
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
|
||||
@@ -12,8 +12,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class JinjaRenderNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||
self.typed_config: JinjaRenderNodeConfig | None = None
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
|
||||
@@ -21,8 +21,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class KnowledgeRetrievalNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
|
||||
self.vector_service: ElasticSearchVector | None = None
|
||||
|
||||
|
||||
@@ -70,8 +70,8 @@ class LLMNode(BaseNode):
|
||||
- ai/assistant: AI 消息(AIMessage)
|
||||
"""
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||
self.typed_config: LLMNodeConfig | None = None
|
||||
self.messages = []
|
||||
|
||||
|
||||
@@ -14,8 +14,8 @@ from app.tasks import write_message_task
|
||||
|
||||
|
||||
class MemoryReadNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||
self.typed_config: MemoryReadNodeConfig | None = None
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
@@ -45,8 +45,8 @@ class MemoryReadNode(BaseNode):
|
||||
|
||||
|
||||
class MemoryWriteNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||
self.typed_config: MemoryWriteNodeConfig | None = None
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
|
||||
@@ -104,13 +104,15 @@ class NodeFactory:
|
||||
def create_node(
|
||||
cls,
|
||||
node_config: dict[str, Any],
|
||||
workflow_config: dict[str, Any]
|
||||
workflow_config: dict[str, Any],
|
||||
down_stream_nodes: list[str]
|
||||
) -> WorkflowNode | None:
|
||||
"""创建节点实例
|
||||
|
||||
Args:
|
||||
node_config: 节点配置
|
||||
workflow_config: 工作流配置
|
||||
down_stream_nodes: 下游节点
|
||||
|
||||
Returns:
|
||||
节点实例或 None(对于不支持的节点类型)
|
||||
@@ -127,7 +129,7 @@ class NodeFactory:
|
||||
|
||||
# 创建节点实例
|
||||
logger.debug(f"create node instance: {node_config.get('id')} (type={node_type})")
|
||||
return node_class(node_config, workflow_config)
|
||||
return node_class(node_config, workflow_config, down_stream_nodes)
|
||||
|
||||
@classmethod
|
||||
def get_supported_types(cls) -> list[str]:
|
||||
|
||||
@@ -21,8 +21,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ParameterExtractorNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||
self.typed_config: ParameterExtractorNodeConfig | None = None
|
||||
self.response_metadata = {}
|
||||
|
||||
|
||||
@@ -22,8 +22,8 @@ DEFAULT_EMPTY_QUESTION_CASE = f"{DEFAULT_CASE_PREFIX}1"
|
||||
class QuestionClassifierNode(BaseNode):
|
||||
"""问题分类器节点"""
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||
self.typed_config: QuestionClassifierNodeConfig | None = None
|
||||
self.category_to_case_map = {}
|
||||
self.response_metadata = {}
|
||||
|
||||
@@ -27,14 +27,8 @@ class StartNode(BaseNode):
|
||||
注意:变量的验证和默认值处理由 Executor 在初始化时完成。
|
||||
"""
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
"""初始化 Start 节点
|
||||
|
||||
Args:
|
||||
node_config: 节点配置
|
||||
workflow_config: 工作流配置
|
||||
"""
|
||||
super().__init__(node_config, workflow_config)
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||
|
||||
# 解析并验证配置
|
||||
self.typed_config: StartNodeConfig | None = None
|
||||
@@ -62,7 +56,6 @@ class StartNode(BaseNode):
|
||||
包含系统参数、会话变量和自定义变量的字典
|
||||
"""
|
||||
self.typed_config = StartNodeConfig(**self.config)
|
||||
logger.info(f"节点 {self.node_id} (Start) 开始执行")
|
||||
|
||||
# 处理自定义变量(传入 pool 避免重复创建)
|
||||
custom_vars = self._process_custom_variables(variable_pool)
|
||||
@@ -77,9 +70,9 @@ class StartNode(BaseNode):
|
||||
**custom_vars # 自定义变量作为节点输出的一部分
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"节点 {self.node_id} (Start) 执行完成,"
|
||||
f"输出了 {len(custom_vars)} 个自定义变量"
|
||||
logger.debug(
|
||||
f"Node {self.node_id} (Start) execution completed, "
|
||||
f"outputting {len(custom_vars)} custom variables"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@@ -20,8 +20,8 @@ TEMPLATE_PATTERN = re.compile(r"\{\{.*?}}")
|
||||
class ToolNode(BaseNode):
|
||||
"""工具节点"""
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||
self.typed_config: ToolNodeConfig | None = None
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
|
||||
@@ -12,8 +12,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VariableAggregatorNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||
self.typed_config: VariableAggregatorNodeConfig | None = None
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
|
||||
@@ -153,7 +153,8 @@ class TemplateRenderer:
|
||||
|
||||
|
||||
# 全局渲染器实例(严格模式)
|
||||
_default_renderer = TemplateRenderer(strict=True)
|
||||
_strict_renderer = TemplateRenderer(strict=True)
|
||||
_lenient_renderer = TemplateRenderer(strict=False)
|
||||
|
||||
|
||||
def render_template(
|
||||
@@ -184,7 +185,7 @@ def render_template(
|
||||
... )
|
||||
'请分析: 这是一段文本'
|
||||
"""
|
||||
renderer = TemplateRenderer(strict=strict)
|
||||
renderer = _strict_renderer if strict else _lenient_renderer
|
||||
return renderer.render(template, conv_vars, node_outputs, system_vars)
|
||||
|
||||
|
||||
@@ -197,4 +198,4 @@ def validate_template(template: str) -> list[str]:
|
||||
Returns:
|
||||
错误列表
|
||||
"""
|
||||
return _default_renderer.validate(template)
|
||||
return _strict_renderer.validate(template)
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
import copy
|
||||
import logging
|
||||
from collections import defaultdict, deque
|
||||
from typing import Any, Union, TYPE_CHECKING
|
||||
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
@@ -119,7 +120,6 @@ class WorkflowValidator:
|
||||
errors = []
|
||||
|
||||
graphs = cls.get_subgraph(workflow_config)
|
||||
logger.info(graphs)
|
||||
for index, graph in enumerate(graphs):
|
||||
nodes = graph.get("nodes", [])
|
||||
edges = graph.get("edges", [])
|
||||
@@ -183,7 +183,7 @@ class WorkflowValidator:
|
||||
has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges)
|
||||
if has_cycle:
|
||||
errors.append(
|
||||
f"工作流存在循环依赖(请使用 loop 节点实现循环): {' -> '.join(cycle_path)}"
|
||||
f"工作流存在循环依赖(请使用 loop/iteration 节点实现循环): {' -> '.join(cycle_path)}"
|
||||
)
|
||||
|
||||
# 8. 验证变量名
|
||||
@@ -204,18 +204,18 @@ class WorkflowValidator:
|
||||
Returns:
|
||||
可达节点 ID 集合
|
||||
"""
|
||||
adj = defaultdict(list)
|
||||
for edge in edges:
|
||||
adj[edge["source"]].append(edge["target"])
|
||||
|
||||
reachable = {start_id}
|
||||
queue = [start_id]
|
||||
|
||||
queue = deque([start_id])
|
||||
while queue:
|
||||
current = queue.pop(0)
|
||||
for edge in edges:
|
||||
if edge.get("source") == current:
|
||||
target = edge.get("target")
|
||||
if target and target not in reachable:
|
||||
reachable.add(target)
|
||||
queue.append(target)
|
||||
|
||||
current = queue.popleft()
|
||||
for target in adj[current]:
|
||||
if target not in reachable:
|
||||
reachable.add(target)
|
||||
queue.append(target)
|
||||
return reachable
|
||||
|
||||
@staticmethod
|
||||
@@ -229,10 +229,6 @@ class WorkflowValidator:
|
||||
Returns:
|
||||
(has_cycle, cycle_path): 是否有循环和循环路径
|
||||
"""
|
||||
# 排除 loop 类型的节点
|
||||
loop_nodes = {n["id"] for n in nodes if n.get("type") == "loop"}
|
||||
|
||||
# 构建邻接表(排除 loop 节点的边和错误边)
|
||||
graph: dict[str, list[str]] = {}
|
||||
for edge in edges:
|
||||
source = edge.get("source")
|
||||
@@ -243,10 +239,6 @@ class WorkflowValidator:
|
||||
if edge_type == "error":
|
||||
continue
|
||||
|
||||
# 如果涉及 loop 节点,跳过
|
||||
if source in loop_nodes or target in loop_nodes:
|
||||
continue
|
||||
|
||||
if source and target:
|
||||
if source not in graph:
|
||||
graph[source] = []
|
||||
|
||||
@@ -54,7 +54,7 @@ class DictVariable(BaseVariable):
|
||||
|
||||
def valid_value(self, value) -> dict:
|
||||
if not isinstance(value, dict):
|
||||
raise TypeError(f"Value must be a dict - {type(value)}:{value}")
|
||||
raise TypeError(f"Value must be a dict - {type(value)}:{value}")
|
||||
return value
|
||||
|
||||
def to_literal(self) -> str:
|
||||
|
||||
@@ -12,7 +12,7 @@ from app.aioRedis import aio_redis_set, aio_redis_get
|
||||
from app.core.config import settings
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.workflow.adapters.base_adapter import WorkflowImportResult, WorkflowParserResult
|
||||
from app.core.workflow.adapters.errors import UnsupportPlatform, InvalidConfiguration
|
||||
from app.core.workflow.adapters.errors import UnsupportedPlatform, InvalidConfiguration
|
||||
from app.core.workflow.adapters.registry import PlatformAdapterRegistry
|
||||
from app.schemas import AppCreate
|
||||
from app.schemas.workflow_schema import WorkflowConfigCreate
|
||||
@@ -46,7 +46,7 @@ class WorkflowImportService:
|
||||
success=False,
|
||||
temp_id=None,
|
||||
workflow_id=None,
|
||||
errors=[UnsupportPlatform(platform=platform)]
|
||||
errors=[UnsupportedPlatform(platform=platform)]
|
||||
)
|
||||
|
||||
adapter = self.registry.get_adapter(platform, config)
|
||||
|
||||
Reference in New Issue
Block a user