diff --git a/api/app/core/workflow/adapters/memory_bear/memory_bear_adapter.py b/api/app/core/workflow/adapters/memory_bear/memory_bear_adapter.py index 0e3f459f..3516cb58 100644 --- a/api/app/core/workflow/adapters/memory_bear/memory_bear_adapter.py +++ b/api/app/core/workflow/adapters/memory_bear/memory_bear_adapter.py @@ -4,65 +4,145 @@ # @Time : 2026/2/25 14:11 from typing import Any +from app.core.logging_config import get_logger from app.core.workflow.adapters.base_adapter import ( PlatformMetadata, PlatformType, BasePlatformAdapter, WorkflowParserResult ) -from app.schemas.workflow_schema import ExecutionConfig +from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType, UnsupportNodeType +from app.core.workflow.adapters.memory_bear.memory_bear_converter import MemoryBearConverter +from app.core.workflow.nodes.enums import NodeType +from app.schemas.workflow_schema import ExecutionConfig, NodeDefinition, EdgeDefinition, VariableDefinition + +logger = get_logger() + +VALID_NODE_TYPES = frozenset(t.value for t in NodeType if t != NodeType.UNKNOWN) -class MemoryBearAdapter(BasePlatformAdapter): - NODE_TYPE_MAPPING = {} +class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter): + NODE_TYPE_MAPPING = {t.value: t for t in NodeType} + + def __init__(self, config: dict[str, Any]): + MemoryBearConverter.__init__(self) + BasePlatformAdapter.__init__(self, config) @property def origin_nodes(self): - return self.config.get("workflow").get("nodes") + return self.config.get("workflow").get("nodes") or [] @property def origin_edges(self): - return self.config.get("workflow").get("edges") + return self.config.get("workflow").get("edges") or [] @property def origin_variables(self): - return self.config.get("workflow").get("variables") + return self.config.get("workflow").get("variables") or [] def get_metadata(self) -> PlatformMetadata: return PlatformMetadata( platform_name=PlatformType.MEMORY_BEAR, version="0.2.5", - support_node_types=list(self.NODE_TYPE_MAPPING.keys()) + support_node_types=list(VALID_NODE_TYPES) ) - def map_node_type(self, platform_node_type) -> str: - return platform_node_type + def map_node_type(self, platform_node_type: str) -> NodeType: + return self.NODE_TYPE_MAPPING.get(platform_node_type, NodeType.UNKNOWN) @staticmethod - def _valid_nodes(node: dict[str, Any]): - if "type" not in node["data"]: - return False + def _valid_node(node: dict[str, Any]) -> bool: if "id" not in node or "type" not in node: return False + if not isinstance(node.get("config"), dict): + return False return True def validate_config(self) -> bool: require_fields = frozenset({'app', 'workflow'}) if not all(field in self.config for field in require_fields): return False - for node in self.origin_nodes: - if not self._valid_nodes(node): + if not self._valid_node(node): return False return True + def _convert_node(self, node: dict[str, Any]) -> NodeDefinition | None: + node_id = node.get("id") + node_name = node.get("name") + try: + node_type = self.map_node_type(node["type"]) + if node_type == NodeType.UNKNOWN: + self.errors.append(UnsupportNodeType( + node_id=node_id, + node_type=node["type"] + )) + return None + + config = node.get("config") or {} + converter = self.get_node_convert(node_type) + converter(node_id, node_name, config) # validates and appends errors if invalid + + return NodeDefinition(**node) + except Exception as e: + self.errors.append(ExceptionDefineition( + type=ExceptionType.NODE, + node_id=node_id, + node_name=node_name, + detail=f"convert node error - {e}" + )) + logger.debug(f"MemoryBear convert node error - {e}", exc_info=True) + return None + + def _convert_edge(self, edge: dict[str, Any], valid_node_ids: set) -> EdgeDefinition | None: + try: + if edge.get("source") not in valid_node_ids or edge.get("target") not in valid_node_ids: + self.warnings.append(ExceptionDefineition( + type=ExceptionType.EDGE, + detail=f"edge {edge.get('id')} skipped: source or target node not found" + )) + return None + return EdgeDefinition(**edge) + except Exception as e: + self.errors.append(ExceptionDefineition( + type=ExceptionType.EDGE, + detail=f"convert edge error - {e}" + )) + logger.debug(f"MemoryBear convert edge error - {e}", exc_info=True) + return None + + def _convert_variable(self, variable: dict[str, Any]) -> VariableDefinition | None: + try: + return VariableDefinition(**variable) + except Exception as e: + self.warnings.append(ExceptionDefineition( + type=ExceptionType.VARIABLE, + name=variable.get("name"), + detail=f"convert variable error - {e}" + )) + logger.debug(f"MemoryBear convert variable error - {e}", exc_info=True) + return None + def parse_workflow(self) -> WorkflowParserResult: - self.nodes = self.origin_nodes - self.edges = self.origin_edges - self.conv_variables = self.origin_variables + for node in self.origin_nodes: + converted = self._convert_node(node) + if converted: + self.nodes.append(converted) + + valid_node_ids = {n.id for n in self.nodes} + + for edge in self.origin_edges: + converted = self._convert_edge(edge, valid_node_ids) + if converted: + self.edges.append(converted) + + for variable in self.origin_variables: + converted = self._convert_variable(variable) + if converted: + self.conv_variables.append(converted) return WorkflowParserResult( - success=True, + success=not self.errors and not self.warnings, platform=self.get_metadata(), execution_config=ExecutionConfig(), origin_config=self.config, @@ -72,5 +152,4 @@ class MemoryBearAdapter(BasePlatformAdapter): variables=self.conv_variables, warnings=self.warnings, errors=self.errors, - ) diff --git a/api/app/core/workflow/adapters/memory_bear/memory_bear_converter.py b/api/app/core/workflow/adapters/memory_bear/memory_bear_converter.py new file mode 100644 index 00000000..031c7025 --- /dev/null +++ b/api/app/core/workflow/adapters/memory_bear/memory_bear_converter.py @@ -0,0 +1,85 @@ +# -*- coding: UTF-8 -*- +from app.core.workflow.adapters.base_converter import BaseConverter +from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType +from app.core.workflow.nodes.base_config import BaseNodeConfig +from app.core.workflow.nodes.configs import ( + StartNodeConfig, + EndNodeConfig, + LLMNodeConfig, + AgentNodeConfig, + IfElseNodeConfig, + KnowledgeRetrievalNodeConfig, + AssignerNodeConfig, + CodeNodeConfig, + HttpRequestNodeConfig, + JinjaRenderNodeConfig, + VariableAggregatorNodeConfig, + ParameterExtractorNodeConfig, + LoopNodeConfig, + IterationNodeConfig, + QuestionClassifierNodeConfig, + ToolNodeConfig, + MemoryReadNodeConfig, + MemoryWriteNodeConfig, + NoteNodeConfig, +) +from app.core.workflow.nodes.enums import NodeType + + +class MemoryBearConverter(BaseConverter): + errors: list + warnings: list + + CONFIG_CLASS_MAP: dict[NodeType, type[BaseNodeConfig]] = { + NodeType.START: StartNodeConfig, + NodeType.END: EndNodeConfig, + NodeType.ANSWER: EndNodeConfig, + NodeType.LLM: LLMNodeConfig, + NodeType.AGENT: AgentNodeConfig, + NodeType.IF_ELSE: IfElseNodeConfig, + NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNodeConfig, + NodeType.ASSIGNER: AssignerNodeConfig, + NodeType.CODE: CodeNodeConfig, + NodeType.HTTP_REQUEST: HttpRequestNodeConfig, + NodeType.JINJARENDER: JinjaRenderNodeConfig, + NodeType.VAR_AGGREGATOR: VariableAggregatorNodeConfig, + NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNodeConfig, + NodeType.LOOP: LoopNodeConfig, + NodeType.ITERATION: IterationNodeConfig, + NodeType.QUESTION_CLASSIFIER: QuestionClassifierNodeConfig, + NodeType.TOOL: ToolNodeConfig, + NodeType.MEMORY_READ: MemoryReadNodeConfig, + NodeType.MEMORY_WRITE: MemoryWriteNodeConfig, + NodeType.NOTES: NoteNodeConfig, + } + + @staticmethod + def _convert_file(var): + return None + + @staticmethod + def _convert_array_file(var): + return [] + + def config_validate(self, node_id: str, node_name: str, config_cls: type[BaseNodeConfig], value: dict): + try: + return config_cls.model_validate(value) + except Exception as e: + self.errors.append(ExceptionDefineition( + type=ExceptionType.CONFIG, + node_id=node_id, + node_name=node_name, + detail=str(e) + )) + return None + + def get_node_convert(self, node_type: NodeType): + config_cls = self.CONFIG_CLASS_MAP.get(node_type) + if not config_cls: + return lambda node_id, node_name, config: config + + def validate(node_id: str, node_name: str, config: dict): + self.config_validate(node_id, node_name, config_cls, config) + return config + + return validate diff --git a/api/app/services/app_dsl_service.py b/api/app/services/app_dsl_service.py index c120d98b..fc071177 100644 --- a/api/app/services/app_dsl_service.py +++ b/api/app/services/app_dsl_service.py @@ -17,6 +17,7 @@ from app.models.models_model import ModelConfig from app.models.tool_model import ToolConfig as ToolConfigModel from app.models.workflow_model import WorkflowConfig from app.services.workflow_service import WorkflowService +from app.core.workflow.adapters.memory_bear.memory_bear_adapter import MemoryBearAdapter class AppDslService: @@ -243,7 +244,7 @@ class AppDslService: model_parameters=cfg.get("model_parameters"), default_model_config_id=self._resolve_model(cfg.get("default_model_config_ref"), tenant_id, warnings), knowledge_retrieval=self._resolve_knowledge_retrieval(cfg.get("knowledge_retrieval"), workspace_id, warnings), - memory=cfg.get("memory"), + memory=self._resolve_memory(cfg.get("memory"), workspace_id, warnings), variables=cfg.get("variables", []), tools=self._resolve_tools(cfg.get("tools", []), tenant_id, warnings), skills=cfg.get("skills", {}), @@ -272,12 +273,20 @@ class AppDslService: )) elif app_type == AppType.WORKFLOW: + adapter = MemoryBearAdapter(dsl) + if not adapter.validate_config(): + raise BusinessException("工作流配置格式无效", BizCode.BAD_REQUEST) + result = adapter.parse_workflow() + for e in result.errors: + warnings.append(f"[节点错误] {e.node_name or e.node_id}: {e.detail}") + for w in result.warnings: + warnings.append(f"[节点警告] {w.node_name or w.node_id}: {w.detail}") wf = dsl.get("workflow") or {} WorkflowService(self.db).create_workflow_config( app_id=new_app.id, - nodes=wf.get("nodes", []), - edges=wf.get("edges", []), - variables=wf.get("variables", []), + nodes=[n.model_dump() for n in result.nodes], + edges=[e.model_dump() for e in result.edges], + variables=[v.model_dump() for v in result.variables], execution_config=wf.get("execution_config", {}), triggers=wf.get("triggers", []), validate=False, @@ -376,15 +385,37 @@ class AppDslService: for kb in kr.get("knowledge_bases", []): ref = kb.get("_ref") or ({"name": kb.get("kb_id")} if kb.get("kb_id") else None) entry = {k: v for k, v in kb.items() if k != "_ref"} - entry["kb_id"] = self._resolve_kb(ref, workspace_id, warnings) + resolved_id = self._resolve_kb(ref, workspace_id, warnings) + if resolved_id is None: + continue + entry["kb_id"] = resolved_id resolved_kbs.append(entry) return {k: v for k, v in kr.items() if k != "knowledge_bases"} | {"knowledge_bases": resolved_kbs} + def _resolve_memory(self, memory: Optional[dict], workspace_id: uuid.UUID, warnings: list) -> Optional[dict]: + if not memory: + return memory + config_id = memory.get("memory_config_id") or memory.get("memory_content") + if not config_id: + return memory + from app.models.memory_config_model import MemoryConfig as MemoryConfigModel + exists = self.db.query(MemoryConfigModel).filter( + MemoryConfigModel.config_id == config_id, + MemoryConfigModel.workspace_id == workspace_id + ).first() + if not exists: + warnings.append(f"记忆配置 '{config_id}' 未匹配,已置空,请导入后手动配置") + return {**memory, "memory_config_id": None, "enabled": False} + return memory + def _resolve_tools(self, tools: list, tenant_id: uuid.UUID, warnings: list) -> list: result = [] for t in (tools or []): ref = t.get("_ref") or ({"name": t.get("tool_id")} if t.get("tool_id") else None) entry = {k: v for k, v in t.items() if k != "_ref"} - entry["tool_id"] = self._resolve_tool(ref, tenant_id, warnings) + resolved_id = self._resolve_tool(ref, tenant_id, warnings) + if resolved_id is None: + continue + entry["tool_id"] = resolved_id result.append(entry) return result diff --git a/api/env.example b/api/env.example index bd7f3dae..e324d1e5 100644 --- a/api/env.example +++ b/api/env.example @@ -75,7 +75,7 @@ REFRESH_TOKEN_EXPIRE_DAYS=7 ENABLE_SINGLE_SESSION= # File Upload -MAX_FILE_SIZE=52428800 # 50MB:10 * 1024 * 1024 +MAX_FILE_SIZE=52428800 # 50MB:50 * 1024 * 1024 FILE_PATH=/files FILE_LOCAL_SERVER_URL="http://localhost:8000/api"