Merge pull request #957 from SuanmoSuanyangTechnology/feature/agent-tool_xjn
feat(workflow)
This commit is contained in:
@@ -81,6 +81,7 @@ class DifyConverter(BaseConverter):
|
|||||||
NodeType.START: self.convert_start_node_config,
|
NodeType.START: self.convert_start_node_config,
|
||||||
NodeType.LLM: self.convert_llm_node_config,
|
NodeType.LLM: self.convert_llm_node_config,
|
||||||
NodeType.END: self.convert_end_node_config,
|
NodeType.END: self.convert_end_node_config,
|
||||||
|
NodeType.OUTPUT: self.convert_output_node_config,
|
||||||
NodeType.IF_ELSE: self.convert_if_else_node_config,
|
NodeType.IF_ELSE: self.convert_if_else_node_config,
|
||||||
NodeType.LOOP: self.convert_loop_node_config,
|
NodeType.LOOP: self.convert_loop_node_config,
|
||||||
NodeType.ITERATION: self.convert_iteration_node_config,
|
NodeType.ITERATION: self.convert_iteration_node_config,
|
||||||
@@ -174,12 +175,20 @@ class DifyConverter(BaseConverter):
|
|||||||
"file": VariableType.FILE,
|
"file": VariableType.FILE,
|
||||||
"paragraph": VariableType.STRING,
|
"paragraph": VariableType.STRING,
|
||||||
"text-input": VariableType.STRING,
|
"text-input": VariableType.STRING,
|
||||||
|
"string": VariableType.STRING,
|
||||||
"number": VariableType.NUMBER,
|
"number": VariableType.NUMBER,
|
||||||
"checkbox": VariableType.BOOLEAN,
|
|
||||||
"file-list": VariableType.ARRAY_FILE,
|
|
||||||
"select": VariableType.STRING,
|
|
||||||
"integer": VariableType.NUMBER,
|
"integer": VariableType.NUMBER,
|
||||||
"float": VariableType.NUMBER,
|
"float": VariableType.NUMBER,
|
||||||
|
"checkbox": VariableType.BOOLEAN,
|
||||||
|
"boolean": VariableType.BOOLEAN,
|
||||||
|
"object": VariableType.OBJECT,
|
||||||
|
"file-list": VariableType.ARRAY_FILE,
|
||||||
|
"array[string]": VariableType.ARRAY_STRING,
|
||||||
|
"array[number]": VariableType.ARRAY_NUMBER,
|
||||||
|
"array[boolean]": VariableType.ARRAY_BOOLEAN,
|
||||||
|
"array[object]": VariableType.ARRAY_OBJECT,
|
||||||
|
"array[file]": VariableType.ARRAY_FILE,
|
||||||
|
"select": VariableType.STRING,
|
||||||
}
|
}
|
||||||
var_type = type_map.get(source_type, source_type)
|
var_type = type_map.get(source_type, source_type)
|
||||||
return var_type
|
return var_type
|
||||||
@@ -274,7 +283,18 @@ class DifyConverter(BaseConverter):
|
|||||||
def convert_start_node_config(self, node: dict) -> dict:
|
def convert_start_node_config(self, node: dict) -> dict:
|
||||||
node_data = node["data"]
|
node_data = node["data"]
|
||||||
start_vars = []
|
start_vars = []
|
||||||
for var in node_data["variables"]:
|
# workflow mode 用 user_input_form,advanced-chat 用 variables
|
||||||
|
raw_vars = node_data.get("variables") or []
|
||||||
|
if not raw_vars:
|
||||||
|
for form_item in node_data.get("user_input_form") or []:
|
||||||
|
# 每个 form_item 是 {"text-input": {...}} 或 {"paragraph": {...}} 等
|
||||||
|
for input_type, var in form_item.items():
|
||||||
|
var["type"] = input_type
|
||||||
|
var.setdefault("variable", var.get("variable", ""))
|
||||||
|
var.setdefault("required", var.get("required", False))
|
||||||
|
var.setdefault("label", var.get("label", ""))
|
||||||
|
raw_vars.append(var)
|
||||||
|
for var in raw_vars:
|
||||||
var_type = self.variable_type_map(var["type"])
|
var_type = self.variable_type_map(var["type"])
|
||||||
if not var_type:
|
if not var_type:
|
||||||
self.errors.append(
|
self.errors.append(
|
||||||
@@ -404,6 +424,19 @@ class DifyConverter(BaseConverter):
|
|||||||
self.config_validate(node["id"], node["data"]["title"], EndNodeConfig, result)
|
self.config_validate(node["id"], node["data"]["title"], EndNodeConfig, result)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def convert_output_node_config(self, node: dict) -> dict:
|
||||||
|
node_data = node["data"]
|
||||||
|
outputs = []
|
||||||
|
for item in node_data.get("outputs", []):
|
||||||
|
value_selector = item.get("value_selector") or []
|
||||||
|
var_type = self.variable_type_map(item.get("value_type", "string")) or VariableType.STRING
|
||||||
|
outputs.append({
|
||||||
|
"name": item.get("variable") or item.get("name", ""),
|
||||||
|
"type": var_type,
|
||||||
|
"value": self._process_list_variable_literal(value_selector) or "",
|
||||||
|
})
|
||||||
|
return {"outputs": outputs}
|
||||||
|
|
||||||
def convert_if_else_node_config(self, node: dict) -> dict:
|
def convert_if_else_node_config(self, node: dict) -> dict:
|
||||||
node_data = node["data"]
|
node_data = node["data"]
|
||||||
cases = []
|
cases = []
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
|||||||
"start": NodeType.START,
|
"start": NodeType.START,
|
||||||
"llm": NodeType.LLM,
|
"llm": NodeType.LLM,
|
||||||
"answer": NodeType.END,
|
"answer": NodeType.END,
|
||||||
|
"end": NodeType.OUTPUT,
|
||||||
"if-else": NodeType.IF_ELSE,
|
"if-else": NodeType.IF_ELSE,
|
||||||
"loop-start": NodeType.CYCLE_START,
|
"loop-start": NodeType.CYCLE_START,
|
||||||
"iteration-start": NodeType.CYCLE_START,
|
"iteration-start": NodeType.CYCLE_START,
|
||||||
@@ -86,13 +87,6 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
|||||||
require_fields = frozenset({'app', 'kind', 'version', 'workflow'})
|
require_fields = frozenset({'app', 'kind', 'version', 'workflow'})
|
||||||
if not all(field in self.config for field in require_fields):
|
if not all(field in self.config for field in require_fields):
|
||||||
return False
|
return False
|
||||||
if self.config.get("app", {}).get("mode") == "workflow":
|
|
||||||
self.errors.append(ExceptionDefinition(
|
|
||||||
type=ExceptionType.PLATFORM,
|
|
||||||
detail="workflow mode is not supported"
|
|
||||||
))
|
|
||||||
return False
|
|
||||||
|
|
||||||
for node in self.origin_nodes:
|
for node in self.origin_nodes:
|
||||||
if not self._valid_nodes(node):
|
if not self._valid_nodes(node):
|
||||||
return False
|
return False
|
||||||
@@ -114,7 +108,11 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
|||||||
if edge:
|
if edge:
|
||||||
self.edges.append(edge)
|
self.edges.append(edge)
|
||||||
|
|
||||||
for variable in self.config.get("workflow").get("conversation_variables"):
|
mode = self.config.get("app", {}).get("mode", "advanced-chat")
|
||||||
|
conv_variables = self.config.get("workflow").get("conversation_variables") or []
|
||||||
|
if mode == "workflow":
|
||||||
|
conv_variables = []
|
||||||
|
for variable in conv_variables:
|
||||||
con_var = self._convert_variable(variable)
|
con_var = self._convert_variable(variable)
|
||||||
if variable:
|
if variable:
|
||||||
self.conv_variables.append(con_var)
|
self.conv_variables.append(con_var)
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from app.core.workflow.nodes.configs import (
|
|||||||
NoteNodeConfig,
|
NoteNodeConfig,
|
||||||
ListOperatorNodeConfig,
|
ListOperatorNodeConfig,
|
||||||
DocExtractorNodeConfig,
|
DocExtractorNodeConfig,
|
||||||
|
OutputNodeConfig,
|
||||||
)
|
)
|
||||||
from app.core.workflow.nodes.enums import NodeType
|
from app.core.workflow.nodes.enums import NodeType
|
||||||
|
|
||||||
@@ -36,6 +37,7 @@ class MemoryBearConverter(BaseConverter):
|
|||||||
NodeType.START: StartNodeConfig,
|
NodeType.START: StartNodeConfig,
|
||||||
NodeType.END: EndNodeConfig,
|
NodeType.END: EndNodeConfig,
|
||||||
NodeType.ANSWER: EndNodeConfig,
|
NodeType.ANSWER: EndNodeConfig,
|
||||||
|
NodeType.OUTPUT: OutputNodeConfig,
|
||||||
NodeType.LLM: LLMNodeConfig,
|
NodeType.LLM: LLMNodeConfig,
|
||||||
NodeType.AGENT: AgentNodeConfig,
|
NodeType.AGENT: AgentNodeConfig,
|
||||||
NodeType.IF_ELSE: IfElseNodeConfig,
|
NodeType.IF_ELSE: IfElseNodeConfig,
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from app.core.workflow.nodes import NodeFactory
|
|||||||
from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES
|
from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES
|
||||||
from app.core.workflow.utils.expression_evaluator import evaluate_condition
|
from app.core.workflow.utils.expression_evaluator import evaluate_condition
|
||||||
from app.core.workflow.validator import WorkflowValidator
|
from app.core.workflow.validator import WorkflowValidator
|
||||||
|
from app.core.workflow.variable.base_variable import VariableType
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -144,7 +145,7 @@ class GraphBuilder:
|
|||||||
(node_info["id"], node_info["branch"])
|
(node_info["id"], node_info["branch"])
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if self.get_node_type(node_info["id"]) == NodeType.END:
|
if self.get_node_type(node_info["id"]) in (NodeType.END, NodeType.OUTPUT):
|
||||||
output_nodes.append(node_info["id"])
|
output_nodes.append(node_info["id"])
|
||||||
non_branch_nodes.append(node_info["id"])
|
non_branch_nodes.append(node_info["id"])
|
||||||
|
|
||||||
@@ -187,7 +188,17 @@ class GraphBuilder:
|
|||||||
for end_node in self.end_nodes:
|
for end_node in self.end_nodes:
|
||||||
end_node_id = end_node.get("id")
|
end_node_id = end_node.get("id")
|
||||||
config = end_node.get("config", {})
|
config = end_node.get("config", {})
|
||||||
output = config.get("output")
|
node_type = end_node.get("type")
|
||||||
|
|
||||||
|
# Output node: STRING type items participate in streaming text output
|
||||||
|
if node_type == NodeType.OUTPUT:
|
||||||
|
outputs_list = config.get("outputs", [])
|
||||||
|
output = "\n".join(
|
||||||
|
item.get("value", "") for item in outputs_list
|
||||||
|
if item.get("value") and item.get("type", VariableType.STRING) == VariableType.STRING
|
||||||
|
) or None
|
||||||
|
else:
|
||||||
|
output = config.get("output")
|
||||||
|
|
||||||
# Skip End nodes without output configuration
|
# Skip End nodes without output configuration
|
||||||
if not output:
|
if not output:
|
||||||
@@ -515,7 +526,7 @@ class GraphBuilder:
|
|||||||
self.end_nodes = [
|
self.end_nodes = [
|
||||||
node
|
node
|
||||||
for node in self.nodes
|
for node in self.nodes
|
||||||
if node.get("type") == "end" and node.get("id") in self.reachable_nodes
|
if node.get("type") in ("end", "output") and node.get("id") in self.reachable_nodes
|
||||||
]
|
]
|
||||||
self._build_adj()
|
self._build_adj()
|
||||||
self._find_upstream_activation_dep: Callable = lru_cache(
|
self._find_upstream_activation_dep: Callable = lru_cache(
|
||||||
|
|||||||
@@ -258,6 +258,21 @@ class WorkflowExecutor:
|
|||||||
end_time = datetime.datetime.now()
|
end_time = datetime.datetime.now()
|
||||||
elapsed_time = (end_time - start_time).total_seconds()
|
elapsed_time = (end_time - start_time).total_seconds()
|
||||||
|
|
||||||
|
# For output nodes, collect structured results from variable_pool and serialize to JSON
|
||||||
|
output_node_ids = [
|
||||||
|
node["id"] for node in self.workflow_config.get("nodes", [])
|
||||||
|
if node.get("type") == "output"
|
||||||
|
]
|
||||||
|
if output_node_ids:
|
||||||
|
structured_output = {}
|
||||||
|
for node_id in output_node_ids:
|
||||||
|
node_output = self.variable_pool.get_node_output(node_id, default=None, strict=False)
|
||||||
|
if node_output:
|
||||||
|
structured_output.update(node_output)
|
||||||
|
final_output = structured_output if structured_output else full_content
|
||||||
|
else:
|
||||||
|
final_output = full_content
|
||||||
|
|
||||||
# Append messages for user and assistant
|
# Append messages for user and assistant
|
||||||
if input_data.get("files"):
|
if input_data.get("files"):
|
||||||
result["messages"].extend(
|
result["messages"].extend(
|
||||||
@@ -301,7 +316,7 @@ class WorkflowExecutor:
|
|||||||
self.execution_context,
|
self.execution_context,
|
||||||
self.variable_pool,
|
self.variable_pool,
|
||||||
elapsed_time,
|
elapsed_time,
|
||||||
full_content,
|
final_output,
|
||||||
success=True)
|
success=True)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from app.core.workflow.nodes.variable_aggregator.config import VariableAggregato
|
|||||||
from app.core.workflow.nodes.notes.config import NoteNodeConfig
|
from app.core.workflow.nodes.notes.config import NoteNodeConfig
|
||||||
from app.core.workflow.nodes.list_operator.config import ListOperatorNodeConfig
|
from app.core.workflow.nodes.list_operator.config import ListOperatorNodeConfig
|
||||||
from app.core.workflow.nodes.document_extractor.config import DocExtractorNodeConfig
|
from app.core.workflow.nodes.document_extractor.config import DocExtractorNodeConfig
|
||||||
|
from app.core.workflow.nodes.output.config import OutputNodeConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# 基础类
|
# 基础类
|
||||||
@@ -54,4 +55,5 @@ __all__ = [
|
|||||||
"NoteNodeConfig",
|
"NoteNodeConfig",
|
||||||
"ListOperatorNodeConfig",
|
"ListOperatorNodeConfig",
|
||||||
"DocExtractorNodeConfig",
|
"DocExtractorNodeConfig",
|
||||||
|
"OutputNodeConfig"
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ class NodeType(StrEnum):
|
|||||||
MEMORY_WRITE = "memory-write"
|
MEMORY_WRITE = "memory-write"
|
||||||
DOCUMENT_EXTRACTOR = "document-extractor"
|
DOCUMENT_EXTRACTOR = "document-extractor"
|
||||||
LIST_OPERATOR = "list-operator"
|
LIST_OPERATOR = "list-operator"
|
||||||
|
OUTPUT = "output"
|
||||||
|
|
||||||
UNKNOWN = "unknown"
|
UNKNOWN = "unknown"
|
||||||
NOTES = "notes"
|
NOTES = "notes"
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ LLM 节点实现
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage
|
from langchain_core.messages import AIMessage
|
||||||
@@ -81,7 +80,7 @@ class LLMNode(BaseNode):
|
|||||||
|
|
||||||
def _render_context(self, message: str, variable_pool: VariablePool):
|
def _render_context(self, message: str, variable_pool: VariablePool):
|
||||||
context = f"<context>{self._render_template(self.typed_config.context, variable_pool)}</context>"
|
context = f"<context>{self._render_template(self.typed_config.context, variable_pool)}</context>"
|
||||||
return re.sub(r"{{context}}", context, message)
|
return message.replace("{{context}}", context)
|
||||||
|
|
||||||
async def _prepare_llm(
|
async def _prepare_llm(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from app.core.workflow.nodes.breaker import BreakNode
|
|||||||
from app.core.workflow.nodes.tool import ToolNode
|
from app.core.workflow.nodes.tool import ToolNode
|
||||||
from app.core.workflow.nodes.document_extractor import DocExtractorNode
|
from app.core.workflow.nodes.document_extractor import DocExtractorNode
|
||||||
from app.core.workflow.nodes.list_operator import ListOperatorNode
|
from app.core.workflow.nodes.list_operator import ListOperatorNode
|
||||||
|
from app.core.workflow.nodes.output import OutputNode
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -53,7 +54,8 @@ WorkflowNode = Union[
|
|||||||
MemoryWriteNode,
|
MemoryWriteNode,
|
||||||
CodeNode,
|
CodeNode,
|
||||||
DocExtractorNode,
|
DocExtractorNode,
|
||||||
ListOperatorNode
|
ListOperatorNode,
|
||||||
|
OutputNode
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -86,7 +88,8 @@ class NodeFactory:
|
|||||||
NodeType.MEMORY_WRITE: MemoryWriteNode,
|
NodeType.MEMORY_WRITE: MemoryWriteNode,
|
||||||
NodeType.CODE: CodeNode,
|
NodeType.CODE: CodeNode,
|
||||||
NodeType.DOCUMENT_EXTRACTOR: DocExtractorNode,
|
NodeType.DOCUMENT_EXTRACTOR: DocExtractorNode,
|
||||||
NodeType.LIST_OPERATOR: ListOperatorNode
|
NodeType.LIST_OPERATOR: ListOperatorNode,
|
||||||
|
NodeType.OUTPUT: OutputNode,
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
4
api/app/core/workflow/nodes/output/__init__.py
Normal file
4
api/app/core/workflow/nodes/output/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
from app.core.workflow.nodes.output.node import OutputNode
|
||||||
|
from app.core.workflow.nodes.output.config import OutputNodeConfig
|
||||||
|
|
||||||
|
__all__ = ["OutputNode", "OutputNodeConfig"]
|
||||||
14
api/app/core/workflow/nodes/output/config.py
Normal file
14
api/app/core/workflow/nodes/output/config.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
from typing import Any
|
||||||
|
from pydantic import Field
|
||||||
|
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||||
|
from app.core.workflow.variable.base_variable import VariableType
|
||||||
|
|
||||||
|
|
||||||
|
class OutputItemConfig(BaseNodeConfig):
|
||||||
|
name: str
|
||||||
|
type: VariableType = VariableType.STRING
|
||||||
|
value: Any = ""
|
||||||
|
|
||||||
|
|
||||||
|
class OutputNodeConfig(BaseNodeConfig):
|
||||||
|
outputs: list[OutputItemConfig] = Field(default_factory=list)
|
||||||
49
api/app/core/workflow/nodes/output/node.py
Normal file
49
api/app/core/workflow/nodes/output/node.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
"""
|
||||||
|
Output 节点实现
|
||||||
|
|
||||||
|
工作流的输出节点(类似 Dify workflow 的 end 节点),
|
||||||
|
用于定义工作流的最终输出变量,不产生流式输出。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.core.workflow.engine.state_manager import WorkflowState
|
||||||
|
from app.core.workflow.engine.variable_pool import VariablePool
|
||||||
|
from app.core.workflow.nodes.base_node import BaseNode
|
||||||
|
from app.core.workflow.variable.base_variable import VariableType
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OutputNode(BaseNode):
|
||||||
|
"""
|
||||||
|
Output 节点
|
||||||
|
|
||||||
|
工作流的输出节点,收集并输出指定变量的值。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
|
outputs = self.config.get("outputs", [])
|
||||||
|
return {
|
||||||
|
item["name"]: VariableType(item.get("type", VariableType.STRING))
|
||||||
|
for item in outputs if item.get("name")
|
||||||
|
}
|
||||||
|
|
||||||
|
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||||
|
outputs = self.config.get("outputs", [])
|
||||||
|
result = {}
|
||||||
|
for item in outputs:
|
||||||
|
name = item.get("name")
|
||||||
|
if not name:
|
||||||
|
continue
|
||||||
|
var_type = VariableType(item.get("type", VariableType.STRING))
|
||||||
|
value = item.get("value", "")
|
||||||
|
if var_type == VariableType.STRING:
|
||||||
|
result[name] = self._render_template(str(value), variable_pool, strict=False)
|
||||||
|
elif isinstance(value, str) and value.strip().startswith("{{") and value.strip().endswith("}}"):
|
||||||
|
selector = value.strip()[2:-2].strip()
|
||||||
|
result[name] = variable_pool.get_value(selector, default=None, strict=False)
|
||||||
|
else:
|
||||||
|
result[name] = value
|
||||||
|
return result
|
||||||
@@ -132,10 +132,10 @@ class WorkflowValidator:
|
|||||||
errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)} 个")
|
errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)} 个")
|
||||||
|
|
||||||
if index == len(graphs) - 1:
|
if index == len(graphs) - 1:
|
||||||
# 2. 验证 主图end 节点(至少一个)
|
# 2. 验证 主图end 节点(至少一个,output 节点也可作为终止节点)
|
||||||
end_nodes = [n for n in nodes if n.get("type") == NodeType.END]
|
end_nodes = [n for n in nodes if n.get("type") in [NodeType.END, NodeType.OUTPUT]]
|
||||||
if len(end_nodes) == 0:
|
if len(end_nodes) == 0:
|
||||||
errors.append("工作流必须至少有一个 end 节点")
|
errors.append("工作流必须至少有一个 end 节点 或 output 节点")
|
||||||
|
|
||||||
# 3. 验证节点 ID 唯一性
|
# 3. 验证节点 ID 唯一性
|
||||||
node_ids = [n.get("id") for n in nodes if n.get("type") != NodeType.NOTES]
|
node_ids = [n.get("id") for n in nodes if n.get("type") != NodeType.NOTES]
|
||||||
|
|||||||
Reference in New Issue
Block a user