Merge pull request #527 from SuanmoSuanyangTechnology/feature/agent-tool_xjn

fix(app)
This commit is contained in:
Mark
2026-03-10 16:22:06 +08:00
committed by GitHub
4 changed files with 221 additions and 26 deletions

View File

@@ -4,65 +4,145 @@
# @Time : 2026/2/25 14:11
from typing import Any
from app.core.logging_config import get_logger
from app.core.workflow.adapters.base_adapter import (
PlatformMetadata,
PlatformType,
BasePlatformAdapter,
WorkflowParserResult
)
from app.schemas.workflow_schema import ExecutionConfig
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType, UnsupportNodeType
from app.core.workflow.adapters.memory_bear.memory_bear_converter import MemoryBearConverter
from app.core.workflow.nodes.enums import NodeType
from app.schemas.workflow_schema import ExecutionConfig, NodeDefinition, EdgeDefinition, VariableDefinition
logger = get_logger()
VALID_NODE_TYPES = frozenset(t.value for t in NodeType if t != NodeType.UNKNOWN)
class MemoryBearAdapter(BasePlatformAdapter):
NODE_TYPE_MAPPING = {}
class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
NODE_TYPE_MAPPING = {t.value: t for t in NodeType}
def __init__(self, config: dict[str, Any]):
MemoryBearConverter.__init__(self)
BasePlatformAdapter.__init__(self, config)
@property
def origin_nodes(self):
return self.config.get("workflow").get("nodes")
return self.config.get("workflow").get("nodes") or []
@property
def origin_edges(self):
return self.config.get("workflow").get("edges")
return self.config.get("workflow").get("edges") or []
@property
def origin_variables(self):
return self.config.get("workflow").get("variables")
return self.config.get("workflow").get("variables") or []
def get_metadata(self) -> PlatformMetadata:
return PlatformMetadata(
platform_name=PlatformType.MEMORY_BEAR,
version="0.2.5",
support_node_types=list(self.NODE_TYPE_MAPPING.keys())
support_node_types=list(VALID_NODE_TYPES)
)
def map_node_type(self, platform_node_type) -> str:
return platform_node_type
def map_node_type(self, platform_node_type: str) -> NodeType:
return self.NODE_TYPE_MAPPING.get(platform_node_type, NodeType.UNKNOWN)
@staticmethod
def _valid_nodes(node: dict[str, Any]):
if "type" not in node["data"]:
return False
def _valid_node(node: dict[str, Any]) -> bool:
if "id" not in node or "type" not in node:
return False
if not isinstance(node.get("config"), dict):
return False
return True
def validate_config(self) -> bool:
require_fields = frozenset({'app', 'workflow'})
if not all(field in self.config for field in require_fields):
return False
for node in self.origin_nodes:
if not self._valid_nodes(node):
if not self._valid_node(node):
return False
return True
def _convert_node(self, node: dict[str, Any]) -> NodeDefinition | None:
    """Convert one raw platform node dict into a NodeDefinition.

    Unknown node types are recorded in ``self.errors`` and skipped; any
    other failure is captured as a node-level error so parsing of the
    remaining nodes can continue. Returns None when conversion fails.
    """
    raw_id = node.get("id")
    raw_name = node.get("name")
    try:
        mapped = self.map_node_type(node["type"])
        if mapped is NodeType.UNKNOWN:
            # The platform declared a type we have no mapping for.
            self.errors.append(UnsupportNodeType(
                node_id=raw_id,
                node_type=node["type"]
            ))
            return None
        # Run the per-type config validator; it appends its own errors
        # to self.errors when the config is invalid.
        validator = self.get_node_convert(mapped)
        validator(raw_id, raw_name, node.get("config") or {})
        return NodeDefinition(**node)
    except Exception as e:
        self.errors.append(ExceptionDefineition(
            type=ExceptionType.NODE,
            node_id=raw_id,
            node_name=raw_name,
            detail=f"convert node error - {e}"
        ))
        logger.debug(f"MemoryBear convert node error - {e}", exc_info=True)
        return None
def _convert_edge(self, edge: dict[str, Any], valid_node_ids: set) -> EdgeDefinition | None:
    """Convert one raw edge dict into an EdgeDefinition.

    Edges whose source or target is not among ``valid_node_ids`` are
    skipped with a warning rather than an error, since the missing node
    was already reported during node conversion. Returns None on skip
    or failure.
    """
    try:
        endpoints_present = (
            edge.get("source") in valid_node_ids
            and edge.get("target") in valid_node_ids
        )
        if not endpoints_present:
            self.warnings.append(ExceptionDefineition(
                type=ExceptionType.EDGE,
                detail=f"edge {edge.get('id')} skipped: source or target node not found"
            ))
            return None
        return EdgeDefinition(**edge)
    except Exception as e:
        self.errors.append(ExceptionDefineition(
            type=ExceptionType.EDGE,
            detail=f"convert edge error - {e}"
        ))
        logger.debug(f"MemoryBear convert edge error - {e}", exc_info=True)
        return None
def _convert_variable(self, variable: dict[str, Any]) -> VariableDefinition | None:
    """Convert one raw variable dict into a VariableDefinition.

    A malformed variable is downgraded to a warning (not an error) and
    dropped, so a bad variable never blocks the whole import.
    """
    try:
        return VariableDefinition(**variable)
    except Exception as exc:
        self.warnings.append(ExceptionDefineition(
            type=ExceptionType.VARIABLE,
            name=variable.get("name"),
            detail=f"convert variable error - {exc}"
        ))
        logger.debug(f"MemoryBear convert variable error - {exc}", exc_info=True)
        return None
def parse_workflow(self) -> WorkflowParserResult:
self.nodes = self.origin_nodes
self.edges = self.origin_edges
self.conv_variables = self.origin_variables
for node in self.origin_nodes:
converted = self._convert_node(node)
if converted:
self.nodes.append(converted)
valid_node_ids = {n.id for n in self.nodes}
for edge in self.origin_edges:
converted = self._convert_edge(edge, valid_node_ids)
if converted:
self.edges.append(converted)
for variable in self.origin_variables:
converted = self._convert_variable(variable)
if converted:
self.conv_variables.append(converted)
return WorkflowParserResult(
success=True,
success=not self.errors and not self.warnings,
platform=self.get_metadata(),
execution_config=ExecutionConfig(),
origin_config=self.config,
@@ -72,5 +152,4 @@ class MemoryBearAdapter(BasePlatformAdapter):
variables=self.conv_variables,
warnings=self.warnings,
errors=self.errors,
)

View File

@@ -0,0 +1,85 @@
# -*- coding: UTF-8 -*-
from app.core.workflow.adapters.base_converter import BaseConverter
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType
from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.core.workflow.nodes.configs import (
StartNodeConfig,
EndNodeConfig,
LLMNodeConfig,
AgentNodeConfig,
IfElseNodeConfig,
KnowledgeRetrievalNodeConfig,
AssignerNodeConfig,
CodeNodeConfig,
HttpRequestNodeConfig,
JinjaRenderNodeConfig,
VariableAggregatorNodeConfig,
ParameterExtractorNodeConfig,
LoopNodeConfig,
IterationNodeConfig,
QuestionClassifierNodeConfig,
ToolNodeConfig,
MemoryReadNodeConfig,
MemoryWriteNodeConfig,
NoteNodeConfig,
)
from app.core.workflow.nodes.enums import NodeType
class MemoryBearConverter(BaseConverter):
    """Maps MemoryBear node types to their config schema classes and
    validates raw node configs against them.

    ``errors`` / ``warnings`` are expected to be provided by the class
    this mixin is combined with (see the adapter); validation failures
    are appended there rather than raised.
    """

    # Declared for type-checkers; the concrete lists live on the host class.
    errors: list
    warnings: list

    # One config schema per supported node type. Types absent from this
    # map are passed through unvalidated by get_node_convert().
    CONFIG_CLASS_MAP: dict[NodeType, type[BaseNodeConfig]] = {
        NodeType.START: StartNodeConfig,
        NodeType.END: EndNodeConfig,
        NodeType.ANSWER: EndNodeConfig,
        NodeType.LLM: LLMNodeConfig,
        NodeType.AGENT: AgentNodeConfig,
        NodeType.IF_ELSE: IfElseNodeConfig,
        NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNodeConfig,
        NodeType.ASSIGNER: AssignerNodeConfig,
        NodeType.CODE: CodeNodeConfig,
        NodeType.HTTP_REQUEST: HttpRequestNodeConfig,
        NodeType.JINJARENDER: JinjaRenderNodeConfig,
        NodeType.VAR_AGGREGATOR: VariableAggregatorNodeConfig,
        NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNodeConfig,
        NodeType.LOOP: LoopNodeConfig,
        NodeType.ITERATION: IterationNodeConfig,
        NodeType.QUESTION_CLASSIFIER: QuestionClassifierNodeConfig,
        NodeType.TOOL: ToolNodeConfig,
        NodeType.MEMORY_READ: MemoryReadNodeConfig,
        NodeType.MEMORY_WRITE: MemoryWriteNodeConfig,
        NodeType.NOTES: NoteNodeConfig,
    }

    @staticmethod
    def _convert_file(var):
        # File variables carry no importable payload; drop the value.
        return None

    @staticmethod
    def _convert_array_file(var):
        # File-array variables likewise import as an empty collection.
        return []

    def config_validate(self, node_id: str, node_name: str, config_cls: type[BaseNodeConfig], value: dict):
        """Validate *value* against *config_cls*.

        Returns the validated model, or None after recording a CONFIG
        error on ``self.errors``.
        """
        try:
            return config_cls.model_validate(value)
        except Exception as exc:
            self.errors.append(ExceptionDefineition(
                type=ExceptionType.CONFIG,
                node_id=node_id,
                node_name=node_name,
                detail=str(exc)
            ))
            return None

    def get_node_convert(self, node_type: NodeType):
        """Return a ``(node_id, node_name, config) -> config`` callable.

        The callable validates the config (recording errors as a side
        effect) and hands back the original config dict unchanged. Node
        types without a registered schema get an identity pass-through.
        """
        config_cls = self.CONFIG_CLASS_MAP.get(node_type)
        if config_cls is None:
            # No schema registered: accept the config as-is.
            return lambda node_id, node_name, config: config

        def _validate(node_id: str, node_name: str, config: dict):
            self.config_validate(node_id, node_name, config_cls, config)
            return config

        return _validate

View File

@@ -17,6 +17,7 @@ from app.models.models_model import ModelConfig
from app.models.tool_model import ToolConfig as ToolConfigModel
from app.models.workflow_model import WorkflowConfig
from app.services.workflow_service import WorkflowService
from app.core.workflow.adapters.memory_bear.memory_bear_adapter import MemoryBearAdapter
class AppDslService:
@@ -243,7 +244,7 @@ class AppDslService:
model_parameters=cfg.get("model_parameters"),
default_model_config_id=self._resolve_model(cfg.get("default_model_config_ref"), tenant_id, warnings),
knowledge_retrieval=self._resolve_knowledge_retrieval(cfg.get("knowledge_retrieval"), workspace_id, warnings),
memory=cfg.get("memory"),
memory=self._resolve_memory(cfg.get("memory"), workspace_id, warnings),
variables=cfg.get("variables", []),
tools=self._resolve_tools(cfg.get("tools", []), tenant_id, warnings),
skills=cfg.get("skills", {}),
@@ -272,12 +273,20 @@ class AppDslService:
))
elif app_type == AppType.WORKFLOW:
adapter = MemoryBearAdapter(dsl)
if not adapter.validate_config():
raise BusinessException("工作流配置格式无效", BizCode.BAD_REQUEST)
result = adapter.parse_workflow()
for e in result.errors:
warnings.append(f"[节点错误] {e.node_name or e.node_id}: {e.detail}")
for w in result.warnings:
warnings.append(f"[节点警告] {w.node_name or w.node_id}: {w.detail}")
wf = dsl.get("workflow") or {}
WorkflowService(self.db).create_workflow_config(
app_id=new_app.id,
nodes=wf.get("nodes", []),
edges=wf.get("edges", []),
variables=wf.get("variables", []),
nodes=[n.model_dump() for n in result.nodes],
edges=[e.model_dump() for e in result.edges],
variables=[v.model_dump() for v in result.variables],
execution_config=wf.get("execution_config", {}),
triggers=wf.get("triggers", []),
validate=False,
@@ -376,15 +385,37 @@ class AppDslService:
for kb in kr.get("knowledge_bases", []):
ref = kb.get("_ref") or ({"name": kb.get("kb_id")} if kb.get("kb_id") else None)
entry = {k: v for k, v in kb.items() if k != "_ref"}
entry["kb_id"] = self._resolve_kb(ref, workspace_id, warnings)
resolved_id = self._resolve_kb(ref, workspace_id, warnings)
if resolved_id is None:
continue
entry["kb_id"] = resolved_id
resolved_kbs.append(entry)
return {k: v for k, v in kr.items() if k != "knowledge_bases"} | {"knowledge_bases": resolved_kbs}
def _resolve_memory(self, memory: Optional[dict], workspace_id: uuid.UUID, warnings: list) -> Optional[dict]:
    """Resolve an imported memory reference against this workspace.

    When the referenced memory config does not exist in the target
    workspace, the reference is cleared and memory is disabled, and a
    human-readable warning is appended for the importer. A missing or
    unreferenced memory section is returned untouched.
    """
    if not memory:
        return memory
    ref_id = memory.get("memory_config_id") or memory.get("memory_content")
    if not ref_id:
        # Nothing to resolve; keep the section as imported.
        return memory
    # Imported locally to avoid a module-level import cycle.
    from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
    match = (
        self.db.query(MemoryConfigModel)
        .filter(
            MemoryConfigModel.config_id == ref_id,
            MemoryConfigModel.workspace_id == workspace_id,
        )
        .first()
    )
    if match:
        return memory
    warnings.append(f"记忆配置 '{ref_id}' 未匹配,已置空,请导入后手动配置")
    # Disable memory so the app imports cleanly; the user re-links later.
    return {**memory, "memory_config_id": None, "enabled": False}
def _resolve_tools(self, tools: list, tenant_id: uuid.UUID, warnings: list) -> list:
result = []
for t in (tools or []):
ref = t.get("_ref") or ({"name": t.get("tool_id")} if t.get("tool_id") else None)
entry = {k: v for k, v in t.items() if k != "_ref"}
entry["tool_id"] = self._resolve_tool(ref, tenant_id, warnings)
resolved_id = self._resolve_tool(ref, tenant_id, warnings)
if resolved_id is None:
continue
entry["tool_id"] = resolved_id
result.append(entry)
return result

View File

@@ -75,7 +75,7 @@ REFRESH_TOKEN_EXPIRE_DAYS=7
ENABLE_SINGLE_SESSION=
# File Upload
MAX_FILE_SIZE=52428800 # 50MB:10 * 1024 * 1024
MAX_FILE_SIZE=52428800 # 50MB:50 * 1024 * 1024
FILE_PATH=/files
FILE_LOCAL_SERVER_URL="http://localhost:8000/api"