fix(app): Workflow import verification
This commit is contained in:
@@ -4,65 +4,145 @@
|
||||
# @Time : 2026/2/25 14:11
|
||||
from typing import Any
|
||||
|
||||
from app.core.logging_config import get_logger
|
||||
from app.core.workflow.adapters.base_adapter import (
|
||||
PlatformMetadata,
|
||||
PlatformType,
|
||||
BasePlatformAdapter,
|
||||
WorkflowParserResult
|
||||
)
|
||||
from app.schemas.workflow_schema import ExecutionConfig
|
||||
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType, UnsupportNodeType
|
||||
from app.core.workflow.adapters.memory_bear.memory_bear_converter import MemoryBearConverter
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
from app.schemas.workflow_schema import ExecutionConfig, NodeDefinition, EdgeDefinition, VariableDefinition
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
VALID_NODE_TYPES = frozenset(t.value for t in NodeType if t != NodeType.UNKNOWN)
|
||||
|
||||
|
||||
class MemoryBearAdapter(BasePlatformAdapter):
|
||||
NODE_TYPE_MAPPING = {}
|
||||
class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
|
||||
NODE_TYPE_MAPPING = {t.value: t for t in NodeType}
|
||||
|
||||
def __init__(self, config: dict[str, Any]):
|
||||
MemoryBearConverter.__init__(self)
|
||||
BasePlatformAdapter.__init__(self, config)
|
||||
|
||||
@property
|
||||
def origin_nodes(self):
|
||||
return self.config.get("workflow").get("nodes")
|
||||
return self.config.get("workflow").get("nodes") or []
|
||||
|
||||
@property
|
||||
def origin_edges(self):
|
||||
return self.config.get("workflow").get("edges")
|
||||
return self.config.get("workflow").get("edges") or []
|
||||
|
||||
@property
|
||||
def origin_variables(self):
|
||||
return self.config.get("workflow").get("variables")
|
||||
return self.config.get("workflow").get("variables") or []
|
||||
|
||||
def get_metadata(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
platform_name=PlatformType.MEMORY_BEAR,
|
||||
version="0.2.5",
|
||||
support_node_types=list(self.NODE_TYPE_MAPPING.keys())
|
||||
support_node_types=list(VALID_NODE_TYPES)
|
||||
)
|
||||
|
||||
def map_node_type(self, platform_node_type) -> str:
|
||||
return platform_node_type
|
||||
def map_node_type(self, platform_node_type: str) -> NodeType:
|
||||
return self.NODE_TYPE_MAPPING.get(platform_node_type, NodeType.UNKNOWN)
|
||||
|
||||
@staticmethod
|
||||
def _valid_nodes(node: dict[str, Any]):
|
||||
if "type" not in node["data"]:
|
||||
return False
|
||||
def _valid_node(node: dict[str, Any]) -> bool:
|
||||
if "id" not in node or "type" not in node:
|
||||
return False
|
||||
if not isinstance(node.get("config"), dict):
|
||||
return False
|
||||
return True
|
||||
|
||||
def validate_config(self) -> bool:
|
||||
require_fields = frozenset({'app', 'workflow'})
|
||||
if not all(field in self.config for field in require_fields):
|
||||
return False
|
||||
|
||||
for node in self.origin_nodes:
|
||||
if not self._valid_nodes(node):
|
||||
if not self._valid_node(node):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _convert_node(self, node: dict[str, Any]) -> NodeDefinition | None:
|
||||
node_id = node.get("id")
|
||||
node_name = node.get("name")
|
||||
try:
|
||||
node_type = self.map_node_type(node["type"])
|
||||
if node_type == NodeType.UNKNOWN:
|
||||
self.errors.append(UnsupportNodeType(
|
||||
node_id=node_id,
|
||||
node_type=node["type"]
|
||||
))
|
||||
return None
|
||||
|
||||
config = node.get("config") or {}
|
||||
converter = self.get_node_convert(node_type)
|
||||
converter(node_id, node_name, config) # validates and appends errors if invalid
|
||||
|
||||
return NodeDefinition(**node)
|
||||
except Exception as e:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
type=ExceptionType.NODE,
|
||||
node_id=node_id,
|
||||
node_name=node_name,
|
||||
detail=f"convert node error - {e}"
|
||||
))
|
||||
logger.debug(f"MemoryBear convert node error - {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
def _convert_edge(self, edge: dict[str, Any], valid_node_ids: set) -> EdgeDefinition | None:
|
||||
try:
|
||||
if edge.get("source") not in valid_node_ids or edge.get("target") not in valid_node_ids:
|
||||
self.warnings.append(ExceptionDefineition(
|
||||
type=ExceptionType.EDGE,
|
||||
detail=f"edge {edge.get('id')} skipped: source or target node not found"
|
||||
))
|
||||
return None
|
||||
return EdgeDefinition(**edge)
|
||||
except Exception as e:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
type=ExceptionType.EDGE,
|
||||
detail=f"convert edge error - {e}"
|
||||
))
|
||||
logger.debug(f"MemoryBear convert edge error - {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
def _convert_variable(self, variable: dict[str, Any]) -> VariableDefinition | None:
|
||||
try:
|
||||
return VariableDefinition(**variable)
|
||||
except Exception as e:
|
||||
self.warnings.append(ExceptionDefineition(
|
||||
type=ExceptionType.VARIABLE,
|
||||
name=variable.get("name"),
|
||||
detail=f"convert variable error - {e}"
|
||||
))
|
||||
logger.debug(f"MemoryBear convert variable error - {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
def parse_workflow(self) -> WorkflowParserResult:
|
||||
self.nodes = self.origin_nodes
|
||||
self.edges = self.origin_edges
|
||||
self.conv_variables = self.origin_variables
|
||||
for node in self.origin_nodes:
|
||||
converted = self._convert_node(node)
|
||||
if converted:
|
||||
self.nodes.append(converted)
|
||||
|
||||
valid_node_ids = {n.id for n in self.nodes}
|
||||
|
||||
for edge in self.origin_edges:
|
||||
converted = self._convert_edge(edge, valid_node_ids)
|
||||
if converted:
|
||||
self.edges.append(converted)
|
||||
|
||||
for variable in self.origin_variables:
|
||||
converted = self._convert_variable(variable)
|
||||
if converted:
|
||||
self.conv_variables.append(converted)
|
||||
|
||||
return WorkflowParserResult(
|
||||
success=True,
|
||||
success=not self.errors and not self.warnings,
|
||||
platform=self.get_metadata(),
|
||||
execution_config=ExecutionConfig(),
|
||||
origin_config=self.config,
|
||||
@@ -72,5 +152,4 @@ class MemoryBearAdapter(BasePlatformAdapter):
|
||||
variables=self.conv_variables,
|
||||
warnings=self.warnings,
|
||||
errors=self.errors,
|
||||
|
||||
)
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
from app.core.workflow.adapters.base_converter import BaseConverter
|
||||
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
from app.core.workflow.nodes.configs import (
|
||||
StartNodeConfig,
|
||||
EndNodeConfig,
|
||||
LLMNodeConfig,
|
||||
AgentNodeConfig,
|
||||
IfElseNodeConfig,
|
||||
KnowledgeRetrievalNodeConfig,
|
||||
AssignerNodeConfig,
|
||||
CodeNodeConfig,
|
||||
HttpRequestNodeConfig,
|
||||
JinjaRenderNodeConfig,
|
||||
VariableAggregatorNodeConfig,
|
||||
ParameterExtractorNodeConfig,
|
||||
LoopNodeConfig,
|
||||
IterationNodeConfig,
|
||||
QuestionClassifierNodeConfig,
|
||||
ToolNodeConfig,
|
||||
MemoryReadNodeConfig,
|
||||
MemoryWriteNodeConfig,
|
||||
NoteNodeConfig,
|
||||
)
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
|
||||
|
||||
class MemoryBearConverter(BaseConverter):
|
||||
errors: list
|
||||
warnings: list
|
||||
|
||||
CONFIG_CLASS_MAP: dict[NodeType, type[BaseNodeConfig]] = {
|
||||
NodeType.START: StartNodeConfig,
|
||||
NodeType.END: EndNodeConfig,
|
||||
NodeType.ANSWER: EndNodeConfig,
|
||||
NodeType.LLM: LLMNodeConfig,
|
||||
NodeType.AGENT: AgentNodeConfig,
|
||||
NodeType.IF_ELSE: IfElseNodeConfig,
|
||||
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNodeConfig,
|
||||
NodeType.ASSIGNER: AssignerNodeConfig,
|
||||
NodeType.CODE: CodeNodeConfig,
|
||||
NodeType.HTTP_REQUEST: HttpRequestNodeConfig,
|
||||
NodeType.JINJARENDER: JinjaRenderNodeConfig,
|
||||
NodeType.VAR_AGGREGATOR: VariableAggregatorNodeConfig,
|
||||
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNodeConfig,
|
||||
NodeType.LOOP: LoopNodeConfig,
|
||||
NodeType.ITERATION: IterationNodeConfig,
|
||||
NodeType.QUESTION_CLASSIFIER: QuestionClassifierNodeConfig,
|
||||
NodeType.TOOL: ToolNodeConfig,
|
||||
NodeType.MEMORY_READ: MemoryReadNodeConfig,
|
||||
NodeType.MEMORY_WRITE: MemoryWriteNodeConfig,
|
||||
NodeType.NOTES: NoteNodeConfig,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _convert_file(var):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _convert_array_file(var):
|
||||
return []
|
||||
|
||||
def config_validate(self, node_id: str, node_name: str, config_cls: type[BaseNodeConfig], value: dict):
|
||||
try:
|
||||
return config_cls.model_validate(value)
|
||||
except Exception as e:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
type=ExceptionType.CONFIG,
|
||||
node_id=node_id,
|
||||
node_name=node_name,
|
||||
detail=str(e)
|
||||
))
|
||||
return None
|
||||
|
||||
def get_node_convert(self, node_type: NodeType):
|
||||
config_cls = self.CONFIG_CLASS_MAP.get(node_type)
|
||||
if not config_cls:
|
||||
return lambda node_id, node_name, config: config
|
||||
|
||||
def validate(node_id: str, node_name: str, config: dict):
|
||||
self.config_validate(node_id, node_name, config_cls, config)
|
||||
return config
|
||||
|
||||
return validate
|
||||
@@ -17,6 +17,7 @@ from app.models.models_model import ModelConfig
|
||||
from app.models.tool_model import ToolConfig as ToolConfigModel
|
||||
from app.models.workflow_model import WorkflowConfig
|
||||
from app.services.workflow_service import WorkflowService
|
||||
from app.core.workflow.adapters.memory_bear.memory_bear_adapter import MemoryBearAdapter
|
||||
|
||||
|
||||
class AppDslService:
|
||||
@@ -243,7 +244,7 @@ class AppDslService:
|
||||
model_parameters=cfg.get("model_parameters"),
|
||||
default_model_config_id=self._resolve_model(cfg.get("default_model_config_ref"), tenant_id, warnings),
|
||||
knowledge_retrieval=self._resolve_knowledge_retrieval(cfg.get("knowledge_retrieval"), workspace_id, warnings),
|
||||
memory=cfg.get("memory"),
|
||||
memory=self._resolve_memory(cfg.get("memory"), workspace_id, warnings),
|
||||
variables=cfg.get("variables", []),
|
||||
tools=self._resolve_tools(cfg.get("tools", []), tenant_id, warnings),
|
||||
skills=cfg.get("skills", {}),
|
||||
@@ -272,12 +273,20 @@ class AppDslService:
|
||||
))
|
||||
|
||||
elif app_type == AppType.WORKFLOW:
|
||||
adapter = MemoryBearAdapter(dsl)
|
||||
if not adapter.validate_config():
|
||||
raise BusinessException("工作流配置格式无效", BizCode.BAD_REQUEST)
|
||||
result = adapter.parse_workflow()
|
||||
for e in result.errors:
|
||||
warnings.append(f"[节点错误] {e.node_name or e.node_id}: {e.detail}")
|
||||
for w in result.warnings:
|
||||
warnings.append(f"[节点警告] {w.node_name or w.node_id}: {w.detail}")
|
||||
wf = dsl.get("workflow") or {}
|
||||
WorkflowService(self.db).create_workflow_config(
|
||||
app_id=new_app.id,
|
||||
nodes=wf.get("nodes", []),
|
||||
edges=wf.get("edges", []),
|
||||
variables=wf.get("variables", []),
|
||||
nodes=[n.model_dump() for n in result.nodes],
|
||||
edges=[e.model_dump() for e in result.edges],
|
||||
variables=[v.model_dump() for v in result.variables],
|
||||
execution_config=wf.get("execution_config", {}),
|
||||
triggers=wf.get("triggers", []),
|
||||
validate=False,
|
||||
@@ -376,15 +385,37 @@ class AppDslService:
|
||||
for kb in kr.get("knowledge_bases", []):
|
||||
ref = kb.get("_ref") or ({"name": kb.get("kb_id")} if kb.get("kb_id") else None)
|
||||
entry = {k: v for k, v in kb.items() if k != "_ref"}
|
||||
entry["kb_id"] = self._resolve_kb(ref, workspace_id, warnings)
|
||||
resolved_id = self._resolve_kb(ref, workspace_id, warnings)
|
||||
if resolved_id is None:
|
||||
continue
|
||||
entry["kb_id"] = resolved_id
|
||||
resolved_kbs.append(entry)
|
||||
return {k: v for k, v in kr.items() if k != "knowledge_bases"} | {"knowledge_bases": resolved_kbs}
|
||||
|
||||
def _resolve_memory(self, memory: Optional[dict], workspace_id: uuid.UUID, warnings: list) -> Optional[dict]:
|
||||
if not memory:
|
||||
return memory
|
||||
config_id = memory.get("memory_config_id") or memory.get("memory_content")
|
||||
if not config_id:
|
||||
return memory
|
||||
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
|
||||
exists = self.db.query(MemoryConfigModel).filter(
|
||||
MemoryConfigModel.config_id == config_id,
|
||||
MemoryConfigModel.workspace_id == workspace_id
|
||||
).first()
|
||||
if not exists:
|
||||
warnings.append(f"记忆配置 '{config_id}' 未匹配,已置空,请导入后手动配置")
|
||||
return {**memory, "memory_config_id": None, "enabled": False}
|
||||
return memory
|
||||
|
||||
def _resolve_tools(self, tools: list, tenant_id: uuid.UUID, warnings: list) -> list:
|
||||
result = []
|
||||
for t in (tools or []):
|
||||
ref = t.get("_ref") or ({"name": t.get("tool_id")} if t.get("tool_id") else None)
|
||||
entry = {k: v for k, v in t.items() if k != "_ref"}
|
||||
entry["tool_id"] = self._resolve_tool(ref, tenant_id, warnings)
|
||||
resolved_id = self._resolve_tool(ref, tenant_id, warnings)
|
||||
if resolved_id is None:
|
||||
continue
|
||||
entry["tool_id"] = resolved_id
|
||||
result.append(entry)
|
||||
return result
|
||||
|
||||
@@ -75,7 +75,7 @@ REFRESH_TOKEN_EXPIRE_DAYS=7
|
||||
ENABLE_SINGLE_SESSION=
|
||||
|
||||
# File Upload
|
||||
MAX_FILE_SIZE=52428800 # 50MB:10 * 1024 * 1024
|
||||
MAX_FILE_SIZE=52428800 # 50MB:50 * 1024 * 1024
|
||||
FILE_PATH=/files
|
||||
|
||||
FILE_LOCAL_SERVER_URL="http://localhost:8000/api"
|
||||
|
||||
Reference in New Issue
Block a user