Merge pull request #1018 from SuanmoSuanyangTechnology/feat/wxy-dev
feat(workflow): incorporate model references and streamline parsing logic
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
@@ -22,6 +23,9 @@ from app.services.multimodal_service import MultimodalService
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 匹配模板变量 {{xxx}} 的正则
|
||||||
|
_TEMPLATE_PATTERN = re.compile(r"\{\{.*?\}\}")
|
||||||
|
|
||||||
|
|
||||||
class NodeExecutionError(Exception):
|
class NodeExecutionError(Exception):
|
||||||
"""节点执行失败异常。
|
"""节点执行失败异常。
|
||||||
@@ -503,10 +507,29 @@ class BaseNode(ABC):
|
|||||||
variable_pool: The variable pool used for reading and writing variables.
|
variable_pool: The variable pool used for reading and writing variables.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A dictionary containing the node's input data.
|
A dictionary containing the node's input data with all template
|
||||||
|
variables resolved to their actual runtime values.
|
||||||
"""
|
"""
|
||||||
# Default implementation returns the node configuration
|
return {"config": self._resolve_config(self.config, variable_pool)}
|
||||||
return {"config": self.config}
|
|
||||||
|
@staticmethod
|
||||||
|
def _resolve_config(config: Any, variable_pool: VariablePool) -> Any:
|
||||||
|
"""递归解析 config 中的模板变量,将 {{xxx}} 替换为实际值。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: 节点的原始配置(可能包含模板变量)。
|
||||||
|
variable_pool: 变量池,用于解析模板变量。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
解析后的配置,所有字符串中的 {{变量}} 已被替换为真实值。
|
||||||
|
"""
|
||||||
|
if isinstance(config, str) and _TEMPLATE_PATTERN.search(config):
|
||||||
|
return BaseNode._render_template(config, variable_pool, strict=False)
|
||||||
|
elif isinstance(config, dict):
|
||||||
|
return {k: BaseNode._resolve_config(v, variable_pool) for k, v in config.items()}
|
||||||
|
elif isinstance(config, list):
|
||||||
|
return [BaseNode._resolve_config(item, variable_pool) for item in config]
|
||||||
|
return config
|
||||||
|
|
||||||
def _extract_output(self, business_result: Any) -> Any:
|
def _extract_output(self, business_result: Any) -> Any:
|
||||||
"""Extracts the actual output from the business result.
|
"""Extracts the actual output from the business result.
|
||||||
|
|||||||
@@ -121,7 +121,10 @@ class DocExtractorNode(BaseNode):
|
|||||||
return business_result
|
return business_result
|
||||||
|
|
||||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||||
return {"file_selector": self.config.get("file_selector")}
|
file_selector = self.config.get("file_selector", "")
|
||||||
|
# 将变量选择器(如 sys.files)解析为实际值
|
||||||
|
resolved = self.get_variable(file_selector, variable_pool, strict=False, default=file_selector)
|
||||||
|
return {"file_selector": resolved}
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||||
config = DocExtractorNodeConfig(**self.config)
|
config = DocExtractorNodeConfig(**self.config)
|
||||||
|
|||||||
@@ -102,6 +102,11 @@ class AppDslService:
|
|||||||
{**r, "_ref": self._agent_ref(r.get("target_agent_id"))} for r in (cfg["routing_rules"] or [])
|
{**r, "_ref": self._agent_ref(r.get("target_agent_id"))} for r in (cfg["routing_rules"] or [])
|
||||||
]
|
]
|
||||||
return enriched
|
return enriched
|
||||||
|
if app_type == AppType.WORKFLOW:
|
||||||
|
enriched = {**cfg}
|
||||||
|
if "nodes" in cfg:
|
||||||
|
enriched["nodes"] = self._enrich_workflow_nodes(cfg["nodes"])
|
||||||
|
return enriched
|
||||||
return cfg
|
return cfg
|
||||||
|
|
||||||
def _export_draft(self, app: App, meta: dict, app_meta: dict) -> tuple[str, str]:
|
def _export_draft(self, app: App, meta: dict, app_meta: dict) -> tuple[str, str]:
|
||||||
@@ -110,7 +115,7 @@ class AppDslService:
|
|||||||
config_data = {
|
config_data = {
|
||||||
"variables": config.variables if config else [],
|
"variables": config.variables if config else [],
|
||||||
"edges": config.edges if config else [],
|
"edges": config.edges if config else [],
|
||||||
"nodes": config.nodes if config else [],
|
"nodes": self._enrich_workflow_nodes(config.nodes) if config else [],
|
||||||
"features": config.features if config else {},
|
"features": config.features if config else {},
|
||||||
"execution_config": config.execution_config if config else {},
|
"execution_config": config.execution_config if config else {},
|
||||||
"triggers": config.triggers if config else [],
|
"triggers": config.triggers if config else [],
|
||||||
@@ -190,6 +195,23 @@ class AppDslService:
|
|||||||
def _enrich_tools(self, tools: list) -> list:
|
def _enrich_tools(self, tools: list) -> list:
|
||||||
return [{**t, "_ref": self._tool_ref(t.get("tool_id"))} for t in (tools or [])]
|
return [{**t, "_ref": self._tool_ref(t.get("tool_id"))} for t in (tools or [])]
|
||||||
|
|
||||||
|
def _enrich_workflow_nodes(self, nodes: list) -> list:
|
||||||
|
"""enrich 工作流节点中的模型引用,添加 name、provider、type 信息"""
|
||||||
|
from app.core.workflow.nodes.enums import NodeType
|
||||||
|
enriched_nodes = []
|
||||||
|
for node in (nodes or []):
|
||||||
|
node_type = node.get("type")
|
||||||
|
config = dict(node.get("config") or {})
|
||||||
|
|
||||||
|
if node_type in (NodeType.LLM.value, NodeType.QUESTION_CLASSIFIER.value, NodeType.PARAMETER_EXTRACTOR.value):
|
||||||
|
model_id = config.get("model_id")
|
||||||
|
if model_id:
|
||||||
|
config["model_ref"] = self._model_ref(model_id)
|
||||||
|
del config["model_id"]
|
||||||
|
|
||||||
|
enriched_nodes.append({**node, "config": config})
|
||||||
|
return enriched_nodes
|
||||||
|
|
||||||
def _skill_ref(self, skill_id) -> Optional[dict]:
|
def _skill_ref(self, skill_id) -> Optional[dict]:
|
||||||
if not skill_id:
|
if not skill_id:
|
||||||
return None
|
return None
|
||||||
@@ -620,16 +642,16 @@ class AppDslService:
|
|||||||
warnings.append(f"[{node_label}] 知识库 '{kb_id}' 未匹配,已移除,请导入后手动配置")
|
warnings.append(f"[{node_label}] 知识库 '{kb_id}' 未匹配,已移除,请导入后手动配置")
|
||||||
config["knowledge_bases"] = resolved_kbs
|
config["knowledge_bases"] = resolved_kbs
|
||||||
elif node_type in (NodeType.LLM.value, NodeType.QUESTION_CLASSIFIER.value, NodeType.PARAMETER_EXTRACTOR.value):
|
elif node_type in (NodeType.LLM.value, NodeType.QUESTION_CLASSIFIER.value, NodeType.PARAMETER_EXTRACTOR.value):
|
||||||
model_ref = config.get("model_id")
|
model_ref = config.get("model_ref") or config.get("model_id")
|
||||||
if model_ref:
|
if model_ref:
|
||||||
ref_dict = None
|
ref_dict = None
|
||||||
if isinstance(model_ref, dict):
|
if isinstance(model_ref, dict):
|
||||||
ref_id = model_ref.get("id")
|
ref_dict = {
|
||||||
ref_name = model_ref.get("name")
|
"id": model_ref.get("id"),
|
||||||
if ref_id:
|
"name": model_ref.get("name"),
|
||||||
ref_dict = {"id": ref_id}
|
"provider": model_ref.get("provider"),
|
||||||
elif ref_name is not None:
|
"type": model_ref.get("type")
|
||||||
ref_dict = {"name": ref_name, "provider": model_ref.get("provider"), "type": model_ref.get("type")}
|
}
|
||||||
elif isinstance(model_ref, str):
|
elif isinstance(model_ref, str):
|
||||||
try:
|
try:
|
||||||
uuid.UUID(model_ref)
|
uuid.UUID(model_ref)
|
||||||
@@ -640,12 +662,18 @@ class AppDslService:
|
|||||||
resolved_model_id = self._resolve_model(ref_dict, tenant_id, warnings)
|
resolved_model_id = self._resolve_model(ref_dict, tenant_id, warnings)
|
||||||
if resolved_model_id:
|
if resolved_model_id:
|
||||||
config["model_id"] = resolved_model_id
|
config["model_id"] = resolved_model_id
|
||||||
|
if "model_ref" in config:
|
||||||
|
del config["model_ref"]
|
||||||
else:
|
else:
|
||||||
warnings.append(f"[{node_label}] 模型未匹配,已置空,请导入后手动配置")
|
warnings.append(f"[{node_label}] 模型未匹配,已置空,请导入后手动配置")
|
||||||
config["model_id"] = None
|
config["model_id"] = None
|
||||||
|
if "model_ref" in config:
|
||||||
|
del config["model_ref"]
|
||||||
else:
|
else:
|
||||||
warnings.append(f"[{node_label}] 模型未匹配,已置空,请导入后手动配置")
|
warnings.append(f"[{node_label}] 模型未匹配,已置空,请导入后手动配置")
|
||||||
config["model_id"] = None
|
config["model_id"] = None
|
||||||
|
if "model_ref" in config:
|
||||||
|
del config["model_ref"]
|
||||||
resolved_nodes.append({**node, "config": config})
|
resolved_nodes.append({**node, "config": config})
|
||||||
return resolved_nodes
|
return resolved_nodes
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user