diff --git a/api/app/aioRedis.py b/api/app/aioRedis.py index c729a3dc..f758dd15 100644 --- a/api/app/aioRedis.py +++ b/api/app/aioRedis.py @@ -10,7 +10,6 @@ from app.core.config import settings # 设置日志记录器 logger = logging.getLogger(__name__) - # 创建连接池 pool = ConnectionPool.from_url( f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}", @@ -21,6 +20,7 @@ pool = ConnectionPool.from_url( ) aio_redis = redis.StrictRedis(connection_pool=pool) + async def get_redis_connection(): """获取Redis连接""" try: @@ -29,7 +29,8 @@ async def get_redis_connection(): logger.error(f"Redis连接失败: {str(e)}") return None -async def aio_redis_set(key: str, val: str|dict, expire: int = None): + +async def aio_redis_set(key: str, val: str | dict, expire: int = None): """设置Redis键值 Args: @@ -40,7 +41,7 @@ async def aio_redis_set(key: str, val: str|dict, expire: int = None): try: if isinstance(val, dict): val = json.dumps(val, ensure_ascii=False) - + if expire is not None: # 设置带过期时间的键值 await aio_redis.set(key, val, ex=expire) @@ -50,6 +51,7 @@ async def aio_redis_set(key: str, val: str|dict, expire: int = None): except Exception as e: logger.error(f"Redis set错误: {str(e)}") + async def aio_redis_get(key: str): """获取Redis键值""" try: @@ -58,6 +60,7 @@ async def aio_redis_get(key: str): logger.error(f"Redis get错误: {str(e)}") return None + async def aio_redis_delete(key: str): """删除Redis键""" try: @@ -66,6 +69,7 @@ async def aio_redis_delete(key: str): logger.error(f"Redis delete错误: {str(e)}") return None + async def aio_redis_publish(channel: str, message: Dict[str, Any]) -> bool: """发布消息到Redis频道""" try: @@ -78,9 +82,10 @@ async def aio_redis_publish(channel: str, message: Dict[str, Any]) -> bool: logger.error(f"Redis发布错误: {str(e)}") return False + class RedisSubscriber: """Redis订阅器""" - + def __init__(self, channel: str): self.channel = channel self.conn = None @@ -88,25 +93,25 @@ class RedisSubscriber: self.is_closed = False self._queue = asyncio.Queue() self._task = None - + async def start(self): """开始订阅""" if self.is_closed or self._task: return - + self._task = asyncio.create_task(self._receive_messages()) logger.info(f"开始订阅: {self.channel}") - + async def _receive_messages(self): """接收消息""" try: self.conn = await get_redis_connection() if not self.conn: return - + self.pubsub = self.conn.pubsub() await self.pubsub.subscribe(self.channel) - + while not self.is_closed: try: message = await self.pubsub.get_message(ignore_subscribe_messages=True, timeout=0.01) @@ -127,7 +132,7 @@ class RedisSubscriber: finally: await self._queue.put(None) await self._cleanup() - + async def _cleanup(self): """清理资源""" if self.pubsub: @@ -141,7 +146,7 @@ class RedisSubscriber: await self.conn.close() except Exception: pass - + async def get_message(self) -> Optional[Dict[str, Any]]: """获取消息""" if self.is_closed: @@ -153,7 +158,7 @@ class RedisSubscriber: except Exception as e: logger.error(f"获取消息错误: {str(e)}") return None - + async def close(self): """关闭订阅器""" if self.is_closed: @@ -163,32 +168,33 @@ class RedisSubscriber: self._task.cancel() await self._cleanup() + class RedisPubSubManager: """Redis发布订阅管理器""" - + def __init__(self): self.subscribers = {} - + async def publish(self, channel: str, message: Dict[str, Any]) -> bool: return await aio_redis_publish(channel, message) - + def get_subscriber(self, channel: str) -> RedisSubscriber: if channel in self.subscribers: subscriber = self.subscribers[channel] if not subscriber.is_closed: return subscriber - + subscriber = RedisSubscriber(channel) self.subscribers[channel] = subscriber return subscriber - + def cancel_subscription(self, channel: str) -> bool: if channel in self.subscribers: asyncio.create_task(self.subscribers[channel].close()) del self.subscribers[channel] return True return False - + def cancel_all_subscriptions(self) -> int: count = len(self.subscribers) for subscriber in self.subscribers.values(): @@ -196,6 +202,6 @@ class RedisPubSubManager: self.subscribers.clear() return count + # 全局实例 pubsub_manager = RedisPubSubManager() - diff --git a/api/app/controllers/app_controller.py b/api/app/controllers/app_controller.py index f1508114..e2849ad6 100644 --- a/api/app/controllers/app_controller.py +++ b/api/app/controllers/app_controller.py @@ -1,7 +1,8 @@ import uuid from typing import Optional, Annotated -from fastapi import APIRouter, Depends, Path +import yaml +from fastapi import APIRouter, Depends, Path, Form, UploadFile, File from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session @@ -17,12 +18,13 @@ from app.repositories.end_user_repository import EndUserRepository from app.schemas import app_schema from app.schemas.response_schema import PageData, PageMeta from app.schemas.workflow_schema import WorkflowConfig as WorkflowConfigSchema -from app.schemas.workflow_schema import WorkflowConfigUpdate +from app.schemas.workflow_schema import WorkflowConfigUpdate, WorkflowImportSave from app.services import app_service, workspace_service from app.services.agent_config_helper import enrich_agent_config from app.services.app_service import AppService -from app.services.workflow_service import WorkflowService, get_workflow_service from app.services.app_statistics_service import AppStatisticsService +from app.services.workflow_import_service import WorkflowImportService +from app.services.workflow_service import WorkflowService, get_workflow_service router = APIRouter(prefix="/apps", tags=["Apps"]) logger = get_business_logger() @@ -65,7 +67,7 @@ def list_apps( # 当 ids 存在且不为 None 时,根据 ids 获取应用 if ids is not None: - app_ids = [id.strip() for id in ids.split(',') if id.strip()] + app_ids = [app_id.strip() for app_id in ids.split(',') if app_id.strip()] items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id) items = [service._convert_to_schema(app, workspace_id) for app in items_orm] return success(data=items) @@ -879,6 +881,60 @@ async def update_workflow_config( return success(data=WorkflowConfigSchema.model_validate(cfg)) +@router.get("/{app_id}/workflow/export") +@cur_workspace_access_guard() +async def export_workflow_config( + app_id: uuid.UUID, + db: Annotated[Session, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_user)] +): + """导出工作流配置为YAML文件""" + workflow_service = WorkflowService(db) + + return success(data={ + "content": workflow_service.export_workflow_dsl(app_id=app_id), + }) + + +@router.post("/workflow/import") +@cur_workspace_access_guard() +async def import_workflow_config( + file: UploadFile = File(...), + platform: str = Form(...), + app_id: str = Form(None), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) + +): + """从YAML内容导入工作流配置""" + if not file.filename.lower().endswith((".yaml", ".yml")): + return fail(msg="Only yaml file is allowed", code=BizCode.BAD_REQUEST) + + raw_text = (await file.read()).decode("utf-8") + import_service = WorkflowImportService(db) + config = yaml.safe_load(raw_text) + result = await import_service.upload_config(platform, config) + return success(data=result) + + +@router.post("/workflow/import/save") +@cur_workspace_access_guard() +async def save_workflow_import( + data: WorkflowImportSave, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + import_service = WorkflowImportService(db) + app = await import_service.save_workflow( + user_id=current_user.id, + workspace_id=current_user.current_workspace_id, + temp_id=data.temp_id, + name=data.name, + description=data.description, + ) + return success(data=app_schema.App.model_validate(app)) + + @router.get("/{app_id}/statistics", summary="应用统计数据") @cur_workspace_access_guard() def get_app_statistics( @@ -889,12 +945,14 @@ def get_app_statistics( current_user=Depends(get_current_user), ): """获取应用统计数据 - + Args: app_id: 应用ID start_date: 开始时间戳(毫秒) end_date: 结束时间戳(毫秒) - + db: 数据库连接 + current_user: 当前用户 + Returns: - daily_conversations: 每日会话数统计 - total_conversations: 总会话数 @@ -931,6 +989,8 @@ def get_workspace_api_statistics( Args: start_date: 开始时间戳(毫秒) end_date: 结束时间戳(毫秒) + db: 数据库连接 + current_user: 当前用户 Returns: 每日统计数据列表,每项包含: diff --git a/api/app/core/config.py b/api/app/core/config.py index 3a0c97b4..0962b545 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -16,18 +16,18 @@ class Settings: # cloud: SaaS 云服务版(全功能,按量计费) # enterprise: 企业私有化版(License 控制) DEPLOYMENT_MODE: str = os.getenv("DEPLOYMENT_MODE", "community") - + # License 配置(企业版) LICENSE_FILE: str = os.getenv("LICENSE_FILE", "/etc/app/license.json") LICENSE_SERVER_URL: str = os.getenv("LICENSE_SERVER_URL", "https://license.yourcompany.com") - + # 计费服务配置(SaaS 版) BILLING_SERVICE_URL: str = os.getenv("BILLING_SERVICE_URL", "") - + # 基础 URL(用于 SSO 回调等) BASE_URL: str = os.getenv("BASE_URL", "http://localhost:8000") FRONTEND_URL: str = os.getenv("FRONTEND_URL", "http://localhost:3000") - + ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true" # API Keys Configuration OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "") @@ -57,7 +57,6 @@ class Settings: REDIS_PORT: int = int(os.getenv("REDIS_PORT", "6379")) REDIS_DB: int = int(os.getenv("REDIS_DB", "1")) REDIS_PASSWORD: str = os.getenv("REDIS_PASSWORD", "") - # ElasticSearch configuration ELASTICSEARCH_HOST: str = os.getenv("ELASTICSEARCH_HOST", "https://127.0.0.1") @@ -91,7 +90,7 @@ class Settings: # Single Sign-On configuration ENABLE_SINGLE_SESSION: bool = os.getenv("ENABLE_SINGLE_SESSION", "false").lower() == "true" - + # SSO 免登配置 SSO_TOKEN_EXPIRE_SECONDS: int = int(os.getenv("SSO_TOKEN_EXPIRE_SECONDS", "300")) SSO_TRUSTED_SOURCES_CONFIG: str = os.getenv("SSO_TRUSTED_SOURCES_CONFIG", "{}") @@ -130,7 +129,7 @@ class Settings: # Server Configuration SERVER_IP: str = os.getenv("SERVER_IP", "127.0.0.1") - FILE_LOCAL_SERVER_URL : str = os.getenv("FILE_LOCAL_SERVER_URL", "http://localhost:8000/api") + FILE_LOCAL_SERVER_URL: str = os.getenv("FILE_LOCAL_SERVER_URL", "http://localhost:8000/api") # ======================================================================== # Internal Configuration (not in .env, used by application code) @@ -225,6 +224,7 @@ class Settings: LOAD_MODEL: bool = os.getenv("LOAD_MODEL", "false").lower() == "true" # workflow config + WORKFLOW_IMPORT_CACHE_TIMEOUT: int = int(os.getenv("WORKFLOW_IMPORT_CACHE_TIMEOUT", 1800)) WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600)) # ======================================================================== @@ -232,20 +232,20 @@ class Settings: # ======================================================================== # 通用本体文件路径列表(逗号分隔) GENERAL_ONTOLOGY_FILES: str = os.getenv("GENERAL_ONTOLOGY_FILES", "General_purpose_entity.ttl") - + # 是否启用通用本体类型功能 ENABLE_GENERAL_ONTOLOGY_TYPES: bool = os.getenv("ENABLE_GENERAL_ONTOLOGY_TYPES", "true").lower() == "true" - + # Prompt 中最大类型数量 MAX_ONTOLOGY_TYPES_IN_PROMPT: int = int(os.getenv("MAX_ONTOLOGY_TYPES_IN_PROMPT", "50")) - + # 核心通用类型列表(逗号分隔) CORE_GENERAL_TYPES: str = os.getenv( "CORE_GENERAL_TYPES", "Person,Organization,Company,GovernmentAgency,Place,Location,City,Country,Building," "Event,SportsEvent,SocialEvent,Work,Book,Film,Software,Concept,TopicalConcept,AcademicSubject" ) - + # 实验模式开关(允许通过 API 动态切换本体配置) ONTOLOGY_EXPERIMENT_MODE: bool = os.getenv("ONTOLOGY_EXPERIMENT_MODE", "true").lower() == "true" diff --git a/api/app/core/workflow/adapters/__init__.py b/api/app/core/workflow/adapters/__init__.py new file mode 100644 index 00000000..141aa4ab --- /dev/null +++ b/api/app/core/workflow/adapters/__init__.py @@ -0,0 +1,8 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/24 15:54 +from app.core.workflow.adapters.dify.dify_adapter import DifyAdapter +from app.core.workflow.adapters.memory_bear.memory_bear_adapter import MemoryBearAdapter + +__all__ = ["DifyAdapter", "MemoryBearAdapter"] diff --git a/api/app/core/workflow/adapters/base_adapter.py b/api/app/core/workflow/adapters/base_adapter.py new file mode 100644 index 00000000..601c8ff2 --- /dev/null +++ b/api/app/core/workflow/adapters/base_adapter.py @@ -0,0 +1,88 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/24 15:58 +from abc import ABC, abstractmethod +from collections import defaultdict +from enum import StrEnum +from typing import Any + +from pydantic import BaseModel, Field + +from app.core.workflow.adapters.errors import ExceptionDefineition +from app.schemas.workflow_schema import ( + EdgeDefinition, + NodeDefinition, + VariableDefinition, + ExecutionConfig, + TriggerConfig +) + + +class PlatformType(StrEnum): + MEMORY_BEAR = "memory_bear" + DIFY = "dify" + COZE = "coze" + + +class PlatformMetadata(BaseModel): + platform_name: str + version: str + support_node_types: list[str] + + +class WorkflowParserResult(BaseModel): + success: bool + platform: PlatformMetadata + execution_config: ExecutionConfig + origin_config: dict[str, Any] + trigger: TriggerConfig | None + edges: list[EdgeDefinition] = Field(default_factory=list) + nodes: list[NodeDefinition] = Field(default_factory=list) + variables: list[VariableDefinition] = Field(default_factory=list) + warnings: list[ExceptionDefineition] = Field(default_factory=list) + errors: list[ExceptionDefineition] = Field(default_factory=list) + + +class WorkflowImportResult(BaseModel): + success: bool + temp_id: str | None = Field(..., description="cache id") + workflow_id: str | None = Field(..., description="workflow id") + edges: list[EdgeDefinition] = Field(default_factory=list) + nodes: list[NodeDefinition] = Field(default_factory=list) + variables: list[VariableDefinition] = Field(default_factory=list) + warnings: list[ExceptionDefineition] = Field(default_factory=list) + errors: list[ExceptionDefineition] = Field(default_factory=list) + + +class BasePlatformAdapter(ABC): + def __init__(self, config: dict[str, Any]): + self.config = config + self.nodes: list[NodeDefinition] = [] + self.edges: list[EdgeDefinition] = [] + self.conv_variables: list[VariableDefinition] = [] + + self.errors = [] + self.warnings = [] + + self.branch_node_cache = defaultdict(list) + self.error_branch_node_cache = [] + + @abstractmethod + def get_metadata(self) -> PlatformMetadata: + """get platform metadata""" + pass + + @abstractmethod + def validate_config(self) -> bool: + """platform configuration validate""" + pass + + @abstractmethod + def parse_workflow(self) -> WorkflowParserResult: + """parse platform configuration to local config""" + pass + + @abstractmethod + def map_node_type(self, platform_node_type: str) -> str: + pass diff --git a/api/app/core/workflow/adapters/base_converter.py b/api/app/core/workflow/adapters/base_converter.py new file mode 100644 index 00000000..eebde971 --- /dev/null +++ b/api/app/core/workflow/adapters/base_converter.py @@ -0,0 +1,75 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/26 14:32 +from abc import ABC, abstractmethod + +from app.core.workflow.variable.base_variable import DEFAULT_VALUE, VariableType + + +class BaseConverter(ABC): + @staticmethod + def _convert_string(var): + try: + return str(var) + except: + return DEFAULT_VALUE(VariableType.STRING) + + @staticmethod + def _convert_boolean(var): + try: + return bool(var) + except: + return DEFAULT_VALUE(VariableType.BOOLEAN) + + @staticmethod + def _convert_number(var): + try: + return float(var) + except: + return DEFAULT_VALUE(VariableType.NUMBER) + + @staticmethod + def _convert_object(var): + try: + return dict(var) + except: + return DEFAULT_VALUE(VariableType.OBJECT) + + @staticmethod + @abstractmethod + def _convert_file(var): + pass + + @staticmethod + def _convert_array_string(var): + try: + return list(var) + except: + return DEFAULT_VALUE(VariableType.ARRAY_STRING) + + @staticmethod + def _convert_array_number(var): + try: + return list(var) + except: + return DEFAULT_VALUE(VariableType.ARRAY_NUMBER) + + @staticmethod + def _convert_array_boolean(var): + try: + return list(var) + except: + return DEFAULT_VALUE(VariableType.ARRAY_BOOLEAN) + + @staticmethod + def _convert_array_object(var): + try: + return list(var) + except: + return DEFAULT_VALUE(VariableType.ARRAY_OBJECT) + + @staticmethod + @abstractmethod + def _convert_array_file(var): + pass diff --git a/api/app/core/workflow/adapters/dify/__init__.py b/api/app/core/workflow/adapters/dify/__init__.py new file mode 100644 index 00000000..7774dcaa --- /dev/null +++ b/api/app/core/workflow/adapters/dify/__init__.py @@ -0,0 +1,4 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/25 18:20 diff --git a/api/app/core/workflow/adapters/dify/converter.py b/api/app/core/workflow/adapters/dify/converter.py new file mode 100644 index 00000000..0e92b2c7 --- /dev/null +++ b/api/app/core/workflow/adapters/dify/converter.py @@ -0,0 +1,659 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/25 18:21 +import base64 +import re +from typing import Any +from urllib.parse import quote + +from app.core.workflow.adapters.base_converter import BaseConverter +from app.core.workflow.adapters.errors import UnsupportVariableType, UnknowModelWarning, ExceptionDefineition, \ + ExceptionType +from app.core.workflow.nodes.assigner import AssignerNodeConfig +from app.core.workflow.nodes.assigner.config import AssignmentItem +from app.core.workflow.nodes.base_config import VariableDefinition +from app.core.workflow.nodes.code import CodeNodeConfig +from app.core.workflow.nodes.code.config import InputVariable, OutputVariable +from app.core.workflow.nodes.configs import StartNodeConfig, LLMNodeConfig +from app.core.workflow.nodes.cycle_graph import LoopNodeConfig, IterationNodeConfig +from app.core.workflow.nodes.cycle_graph.config import ConditionDetail as LoopConditionDetail, ConditionsConfig, \ + CycleVariable +from app.core.workflow.nodes.end import EndNodeConfig +from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, AssignmentOperator, HttpAuthType, \ + HttpContentType, HttpErrorHandle +from app.core.workflow.nodes.http_request import HttpRequestNodeConfig +from app.core.workflow.nodes.http_request.config import HttpAuthConfig, HttpContentTypeConfig, HttpFormData, \ + HttpTimeOutConfig, HttpRetryConfig, HttpErrorDefaultTamplete, HttpErrorHandleConfig +from app.core.workflow.nodes.if_else import IfElseNodeConfig +from app.core.workflow.nodes.if_else.config import ConditionDetail, ConditionBranchConfig +from app.core.workflow.nodes.jinja_render import JinjaRenderNodeConfig +from app.core.workflow.nodes.jinja_render.config import VariablesMappingConfig +from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig +from app.core.workflow.nodes.llm.config import MemoryWindowSetting, MessageConfig +from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNodeConfig +from app.core.workflow.nodes.parameter_extractor.config import ParamsConfig +from app.core.workflow.nodes.question_classifier import QuestionClassifierNodeConfig +from app.core.workflow.nodes.question_classifier.config import ClassifierConfig +from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNodeConfig +from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE + + +class DifyConverter(BaseConverter): + errors: list + warnings: list + branch_node_cache: dict + error_branch_node_cache: list + + def __init__(self): + self.CONFIG_CONVERT_MAP = { + "start": self.convert_start_node_config, + "llm": self.convert_llm_node_config, + "answer": self.convert_end_node_config, + "if-else": self.convert_if_else_node_config, + "loop": self.convert_loop_node_config, + "iteration": self.convert_iteration_node_config, + "assigner": self.convert_assigner_node_config, + "code": self.convert_code_node_config, + "http-request": self.convert_http_node_config, + "template-transform": self.convert_jinja_render_node_config, + "knowledge-retrieval": self.convert_knowledge_node_config, + "parameter-extractor": self.convert_parameter_extractor_node_config, + "question-classifier": self.convert_question_classifier_node_config, + "variable-aggregator": self.convert_variable_aggregator, + "loop-start": lambda x: {}, + "iteration-start": lambda x: {}, + "loop-end": lambda x: {}, + } + + def get_node_convert(self, node_type): + func = self.CONFIG_CONVERT_MAP.get(node_type, None) + return func + + @staticmethod + def is_variable(expression) -> bool: + return bool(re.match(r"\{\{#(.*?)#}}", expression)) + + @staticmethod + def process_var_selector(var_selector): + if not var_selector: + return "" + selector = var_selector.split('.') + if len(selector) != 2: + raise Exception(f"invalid variable selector: {var_selector}") + if selector[0] == "conversation": + selector[0] = "conv" + var_selector = ".".join(selector) + mapping = { + "sys.query": "sys.message" + } + + var_selector = mapping.get(var_selector, var_selector) + return var_selector + + def _process_list_variable_litearl(self, variable_selector: list) -> str | None: + if not self.process_var_selector(".".join(variable_selector)): + return None + return "{{" + self.process_var_selector(".".join(variable_selector)) + "}}" + + def trans_variable_format(self, content): + pattern = re.compile(r"\{\{#(.*?)#}}") + + def replacer(match: re.Match) -> str: + raw_name = match.group(1) + new_name = self.process_var_selector(raw_name) + return f"{{{{{new_name}}}}}" + + return pattern.sub(replacer, content) + + @staticmethod + def _convert_file(var): + pass + + @staticmethod + def _convert_array_file(var): + pass + + @staticmethod + def variable_type_map(source_type) -> VariableType | None: + type_map = { + "file": VariableType.FILE, + "paragraph": VariableType.STRING, + "text-input": VariableType.STRING, + "number": VariableType.NUMBER, + "checkbox": VariableType.BOOLEAN, + "file-list": VariableType.ARRAY_FILE, + "select": VariableType.STRING, + } + var_type = type_map.get(source_type, source_type) + return var_type + + def convert_variable_type(self, target_type: VariableType, origin_value: Any): + if not origin_value: + return DEFAULT_VALUE(target_type) + try: + match target_type: + case VariableType.STRING: + return self._convert_string(origin_value) + case VariableType.NUMBER: + return self._convert_number(origin_value) + case VariableType.BOOLEAN: + return self._convert_boolean(origin_value) + case VariableType.FILE: + return self._convert_file(origin_value) + case VariableType.ARRAY_FILE: + return self._convert_array_file(origin_value) + case _: + return origin_value + except: + raise Exception(f"convert variable failed: {target_type}") + + @staticmethod + def convert_compare_operator(operator): + operator_map = { + "is": ComparisonOperator.EQ, + "is not": ComparisonOperator.NE, + "=": ComparisonOperator.EQ, + "≠": ComparisonOperator.NE, + ">": ComparisonOperator.GT, + "<": ComparisonOperator.LT, + "≥": ComparisonOperator.GE, + "≤": ComparisonOperator.LE, + "not empty": ComparisonOperator.NOT_EMPTY, + } + return operator_map.get(operator, operator) + + @staticmethod + def convert_assignment_operator(operator): + operator_map = { + "+=": AssignmentOperator.ADD, + "-=": AssignmentOperator.SUBTRACT, + "*=": AssignmentOperator.MULTIPLY, + "/=": AssignmentOperator.DIVIDE, + "over-write": AssignmentOperator.COVER, + "remove-last": AssignmentOperator.REMOVE_LAST, + "remove-first": AssignmentOperator.REMOVE_FIRST, + + } + return operator_map.get(operator, operator) + + @staticmethod + def convert_http_auth_type(auth_type): + auth_type_map = { + "no-auth": HttpAuthType.NONE, + "bearer": HttpAuthType.BEARER, + "basic": HttpAuthType.BASIC, + "custom": HttpAuthType.CUSTOM, + } + return auth_type_map.get(auth_type, auth_type) + + @staticmethod + def convert_http_content_type(content_type): + content_type_map = { + "none": HttpContentType.NONE, + "form-data": HttpContentType.FROM_DATA, + "x-www-form-urlencoded": HttpContentType.WWW_FORM, + "json": HttpContentType.JSON, + "raw-text": HttpContentType.RAW, + "binary": HttpContentType.BINARY, + } + return content_type_map.get(content_type, content_type) + + @staticmethod + def convert_http_error_handle_type(handle_type): + handle_type_map = { + "none": HttpErrorHandle.NONE, + "fail-branch": HttpErrorHandle.BRANCH, + "default-value": HttpErrorHandle.DEFAULT, + } + return handle_type_map.get(handle_type, handle_type) + + def convert_start_node_config(self, node: dict) -> dict: + node_data = node["data"] + start_vars = [] + for var in node_data["variables"]: + var_type = self.variable_type_map(var["type"]) + if not var_type: + self.errors.append( + UnsupportVariableType( + scope=node["id"], + name=var["variable"], + var_type=var["type"], + node_id=node["id"], + node_name=node_data["title"] + ) + ) + continue + + if var_type in ["file", "array[file]"]: + self.errors.append( + ExceptionDefineition( + type=ExceptionType.VARIABLE, + node_id=node["id"], + node_name=node_data["title"], + name=var["variable"], + detail=f"Unsupport Variable type for start node: {var_type}" + ) + ) + continue + + var_def = VariableDefinition( + name=var["variable"], + type=var_type, + required=var["required"], + default=self.convert_variable_type( + var_type, var["default"] + ), + description=var["label"], + max_length=var.get("max_length"), + ) + start_vars.append(var_def) + return StartNodeConfig( + variables=start_vars + ).model_dump() + + def convert_question_classifier_node_config(self, node: dict) -> dict: + node_data = node["data"] + self.warnings.append( + UnknowModelWarning( + node_id=node["id"], + node_name=node_data["title"], + model_name=node_data["model"].get("name") + ) + ) + categories = [] + for category in node_data["classes"]: + self.branch_node_cache[node["id"]].append(category["id"]) + categories.append( + ClassifierConfig( + class_name=category["name"], + ) + ) + + return QuestionClassifierNodeConfig.model_construct( + input_variable=self._process_list_variable_litearl(node_data["query_variable_selector"]), + user_supplement_prompt=self.trans_variable_format(node_data["instructions"]), + categories=categories, + ).model_dump() + + def convert_llm_node_config(self, node: dict) -> dict: + node_data = node["data"] + self.warnings.append( + UnknowModelWarning( + node_id=node["id"], + node_name=node_data["title"], + model_name=node_data["model"].get("name") + ) + ) + context = self._process_list_variable_litearl(node_data["context"]["variable_selector"]) + memory = MemoryWindowSetting( + enable=bool(node_data.get("memory")), + enable_window=bool(node_data.get("memory", {}).get("window", {}).get("enabled", False)), + window_size=int(node_data.get("memory", {}).get("window", {}).get("size", 20)) + ) + messages = [] + for message in node_data["prompt_template"]: + messages.append( + MessageConfig( + role=message["role"], + content=self.trans_variable_format(message["text"]) + ) + ) + if memory.enable: + messages.append( + MessageConfig( + role="user", + content=self.trans_variable_format(node_data["memory"]["query_prompt_template"]) + ) + ) + vision = node_data["vision"]["enabled"] + vision_input = self._process_list_variable_litearl( + node_data["vision"]["configs"]["variable_selector"] + ) if vision else None + return LLMNodeConfig.model_construct( + model_id=None, + context=context, + memory=memory, + vision=vision, + vision_input=vision_input, + messages=messages + ).model_dump() + + def convert_end_node_config(self, node: dict) -> dict: + node_data = node["data"] + return EndNodeConfig( + output=self.trans_variable_format(node_data["answer"]), + ).model_dump() + + def convert_if_else_node_config(self, node: dict) -> dict: + node_data = node["data"] + cases = [] + for case in node_data["cases"]: + case_id = case["id"] + logical_operator = case["logical_operator"] + conditions = [] + for condition in case["conditions"]: + right_value = condition["value"] + condition_detail = ConditionDetail( + operator=self.convert_compare_operator(condition["comparison_operator"]), + left="{{" + self.process_var_selector(".".join(condition["variable_selector"])) + "}}", + right=self.trans_variable_format( + right_value + ) if isinstance(right_value, str) and self.is_variable(right_value) else self.convert_variable_type( + self.variable_type_map(condition["varType"]), + condition["value"] + ), + input_type=ValueInputType.VARIABLE + if isinstance(right_value, str) and self.is_variable(right_value) else ValueInputType.CONSTANT, + ) + conditions.append(condition_detail) + cases.append( + ConditionBranchConfig( + logical_operator=logical_operator, + expressions=conditions + ) + ) + self.branch_node_cache[node["id"]].append(case_id) + return IfElseNodeConfig( + cases=cases + ).model_dump() + + def convert_loop_node_config(self, node: dict) -> dict: + node_data = node["data"] + logical_operator = node_data["logical_operator"] + conditions = [] + for condition in node_data["break_conditions"]: + right_value = condition["value"] + conditions.append( + LoopConditionDetail( + operator=self.convert_compare_operator(condition["comparison_operator"]), + left=self._process_list_variable_litearl(condition["variable_selector"]), + right=self.trans_variable_format( + right_value + ) if isinstance(right_value, str) and self.is_variable(right_value) else self.convert_variable_type( + self.variable_type_map(condition["varType"]), + condition["value"] + ), + input_type=ValueInputType.VARIABLE + if isinstance(right_value, str) and self.is_variable(right_value) else ValueInputType.CONSTANT, + ) + ) + condition_config = ConditionsConfig( + logical_operator=logical_operator, + expressions=conditions + ) + loop_variables = [] + for variable in node_data["loop_variables"]: + right_input_type = variable["value_type"] + right_value_type = self.variable_type_map(variable["var_type"]) + if right_input_type == ValueInputType.VARIABLE: + right_value = self._process_list_variable_litearl(variable["value"]) + else: + right_value = self.convert_variable_type(right_value_type, variable["value"]) + loop_variables.append( + CycleVariable( + name=variable["label"], + type=right_value_type, + value=right_value, + input_type=right_input_type + ) + ) + return LoopNodeConfig( + condition=condition_config, + cycle_vars=loop_variables, + max_loop=node_data["loop_count"] + ).model_dump() + + def convert_iteration_node_config(self, node: dict) -> dict: + node_data = node["data"] + return IterationNodeConfig( + input=self._process_list_variable_litearl(node_data["iterator_selector"]), + parallel=node_data["is_parallel"], + parallel_count=node_data["parallel_nums"], + output=self._process_list_variable_litearl(node_data["output_selector"]), + output_type=self.variable_type_map(node_data["output_type"]), + flatten=node_data["flatten_output"], + ).model_dump() + + def convert_assigner_node_config(self, node: dict) -> dict: + node_data = node["data"] + assignments = [] + for assignment in node_data["items"]: + if assignment.get("operation") is None or assignment.get("value") is None: + continue + assignments.append( + AssignmentItem( + variable_selector=self._process_list_variable_litearl(assignment["variable_selector"]), + value=self._process_list_variable_litearl( + assignment["value"] + ) if assignment["input_type"] == ValueInputType.VARIABLE else assignment["value"], + operation=self.convert_assignment_operator(assignment["operation"]) + ) + ) + return AssignerNodeConfig( + assignments=assignments + ).model_dump() + + def convert_code_node_config(self, node: dict) -> dict: + node_data = node["data"] + input_variables = [] + for input_variable in node_data["variables"]: + input_variables.append( + InputVariable( + name=input_variable["variable"], + variable=self._process_list_variable_litearl(input_variable["value_selector"]), + ) + ) + + output_variables = [] + for output_variable in node_data["outputs"]: + output_variables.append( + OutputVariable( + name=output_variable, + type=node_data["outputs"][output_variable]["type"], + ) + ) + + code = base64.b64encode(quote(node_data["code"]).encode("utf-8")).decode("utf-8") + + return CodeNodeConfig( + input_variables=input_variables, + language=node_data["code_language"], + output_variables=output_variables, + code=code + ).model_dump() + + def convert_http_node_config(self, node: dict) -> dict: + node_data = node["data"] + if node_data["authorization"] != 'no-auth': + auth_type = self.convert_http_auth_type(node_data["authorization"]["config"]["type"]) + auth_config = HttpAuthConfig( + auth_type=auth_type, + header=node_data["authorization"]["config"].get("header"), + api_key=node_data["authorization"]["config"].get("api_key"), + ) + else: + auth_config = HttpAuthConfig() + + content_type = self.convert_http_content_type(node_data["body"]["type"]) + if content_type == HttpContentType.FROM_DATA: + body_content = [] + for content in node_data["body"]["data"]: + body_content.append( + HttpFormData( + key=self.trans_variable_format(content["key"]), + type=content["type"], + value=self.trans_variable_format(content["value"]), + ) + ) + elif content_type == HttpContentType.WWW_FORM: + body_content = {} + for content in node_data["body"]["data"]: + body_content[ + self.trans_variable_format(content["key"]) + ] = self.trans_variable_format(content["value"]) + else: + if node_data["body"]["data"]: + body_content = node_data["body"]["data"][0]["value"] + else: + body_content = "" + + headers = {} + for header in node_data["headers"].split("\n"): + if not header: + continue + + key_value = header.split(":") + if len(key_value) == 2: + headers[ + self.trans_variable_format(key_value[0]) + ] = self.trans_variable_format(key_value[1]) + else: + self.warnings.append(ExceptionDefineition( + type=ExceptionType.CONFIG, + node_id=node["id"], + node_name=node_data["title"], + detail=f"Invalid header/param - {header}", + )) + + params = {} + for param in node_data["params"].split("\n"): + if not param: + continue + + key_value = param.split(":") + if len(key_value) == 2: + params[ + self.trans_variable_format(key_value[0]) + ] = self.trans_variable_format(key_value[1]) + else: + self.warnings.append(ExceptionDefineition( + type=ExceptionType.CONFIG, + node_id=node["id"], + node_name=node_data["title"], + detail=f"Invalid header/param - {param}", + )) + + error_handle_type = self.convert_http_error_handle_type( + node_data.get("error_strategy", "none") + ) + default_value = None + if error_handle_type == HttpErrorHandle.DEFAULT: + default_body = "" + default_header = {} + default_status_code = 0 + for var in node_data["default_value"]: + if var["key"] == "body": + default_body = var["value"] + elif var["key"] == "header": + default_header = var["value"] + elif var["key"] == "status_code": + default_status_code = var["value"] + default_value = HttpErrorDefaultTamplete( + body=default_body, + headers=default_header, + status_code=default_status_code, + ) + + self.error_branch_node_cache.append(node['id']) + return HttpRequestNodeConfig( + method=node_data["method"].upper(), + url=node_data["url"], + auth=auth_config, + body=HttpContentTypeConfig( + content_type=self.convert_http_content_type(node_data["body"]["type"]), + data=body_content, + ), + headers=headers, + params=params, + verify_ssl=node_data["ssl_verify"], + timeouts=HttpTimeOutConfig( + connect_timeout=node_data["timeout"]["max_connect_timeout"] or 5, + read_timeout=node_data["timeout"]["max_read_timeout"] or 5, + write_timeout=node_data["timeout"]["max_write_timeout"] or 5, + ), + retry=HttpRetryConfig( + enable=node_data["retry_config"]["retry_enabled"], + max_attempts=node_data["retry_config"]["max_retries"], + retry_interval=node_data["retry_config"]["retry_interval"], + ), + error_handle=HttpErrorHandleConfig( + method=error_handle_type, + default=default_value, + ) + ).model_dump() + + def convert_jinja_render_node_config(self, node: dict) -> dict: + node_data = node["data"] + mapping = [] + for variable in node_data["variables"]: + mapping.append(VariablesMappingConfig( + name=variable["variable"], + value=self._process_list_variable_litearl(variable["value_selector"]) + )) + return JinjaRenderNodeConfig( + template=node_data["template"], + mapping=mapping, + ).model_dump() + + def convert_knowledge_node_config(self, node: dict) -> dict: + node_data = node["data"] + self.warnings.append(ExceptionDefineition( + node_id=node["id"], + node_name=node_data["title"], + type=ExceptionType.CONFIG, + detail=f"Please reconfigure the Knowledge Retrieval node.", + )) + return KnowledgeRetrievalNodeConfig.model_construct( + query=self._process_list_variable_litearl(node_data["query_variable_selector"]), + ).model_dump() + + def convert_parameter_extractor_node_config(self, node: dict) -> dict: + node_data = node["data"] + self.warnings.append( + UnknowModelWarning( + node_id=node["id"], + node_name=node_data["title"], + model_name=node_data["model"].get("name") + ) + ) + params = [] + for param in node_data["parameters"]: + params.append( + ParamsConfig( + name=param["name"], + desc=param["description"], + required=param["required"], + type=param["type"], + ) + ) + return ParameterExtractorNodeConfig.model_construct( + text=self._process_list_variable_litearl(node_data["query"]), + params=params, + prompt=node_data["instruction"] + ).model_dump() + + def convert_variable_aggregator(self, node: dict) -> dict: + node_data = node["data"] + group_enable = node_data["advanced_settings"]["group_enabled"] + group_variables = {} + group_type = {} + if not group_enable: + group_variables["output"] = [ + self._process_list_variable_litearl(variable) + for variable in node_data["variables"] + ] + group_type["output"] = node_data["output_type"] + else: + for group in node_data["advanced_settings"]["groups"]: + group_variables[group["group_name"]] = [ + self._process_list_variable_litearl(variable) + for variable in group["variables"] + ] + group_type[group["group_name"]] = group["output_type"] + + return VariableAggregatorNodeConfig( + group=group_enable, + group_variables=group_variables, + group_type=group_type, + ).model_dump() diff --git a/api/app/core/workflow/adapters/dify/dify_adapter.py b/api/app/core/workflow/adapters/dify/dify_adapter.py new file mode 100644 index 00000000..48a0cbd6 --- /dev/null +++ b/api/app/core/workflow/adapters/dify/dify_adapter.py @@ -0,0 +1,239 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/24 16:05 +from typing import Any + +from app.core.logging_config import get_logger +from app.core.workflow.adapters.base_adapter import ( + BasePlatformAdapter, + PlatformMetadata, + PlatformType, + WorkflowParserResult +) +from app.core.workflow.adapters.dify.converter import DifyConverter +from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType +from app.core.workflow.nodes.enums import NodeType +from app.schemas.workflow_schema import ( + NodeDefinition, + EdgeDefinition, + VariableDefinition, + TriggerConfig, + ExecutionConfig +) + +logger = get_logger() + + +class DifyAdapter(BasePlatformAdapter, DifyConverter): + NODE_TYPE_MAPPING = { + "start": NodeType.START, + "llm": NodeType.LLM, + "answer": NodeType.END, + "if-else": NodeType.IF_ELSE, + "loop-start": NodeType.CYCLE_START, + "iteration-start": NodeType.CYCLE_START, + "assigner": NodeType.ASSIGNER, + "loop": NodeType.LOOP, + "iteration": NodeType.ITERATION, + "loop-end": NodeType.BREAK, + "code": NodeType.CODE, + "http-request": NodeType.HTTP_REQUEST, + "template-transform": NodeType.JINJARENDER, + "knowledge-retrieval": NodeType.KNOWLEDGE_RETRIEVAL, + "parameter-extractor": NodeType.PARAMETER_EXTRACTOR, + "question-classifier": NodeType.QUESTION_CLASSIFIER, + "variable-aggregator": NodeType.VAR_AGGREGATOR + } + + def __init__(self, config: dict[str, Any]): + DifyConverter.__init__(self) + BasePlatformAdapter.__init__(self, config) + + def get_metadata(self) -> PlatformMetadata: + return PlatformMetadata( + platform_name=PlatformType.DIFY, + version="0.5.0", + support_node_types=list(self.NODE_TYPE_MAPPING.keys()) + ) + + def map_node_type(self, platform_node_type) -> str: + return self.NODE_TYPE_MAPPING.get(platform_node_type) + + @property + def origin_nodes(self): + return self.config.get("workflow").get("graph").get("nodes") + + @property + def origin_edges(self): + return self.config.get("workflow").get("graph").get("edges") + + @staticmethod + def _valid_nodes(node: dict[str, Any]): + if "data" not in node: + return False + if "type" not in node["data"]: + return False + if "id" not in node or "type" not in node: + return False + return True + + def validate_config(self) -> bool: + require_fields = frozenset({'app', 'dependencies', 'kind', 'version', 'workflow'}) + if not all(field in self.config for field in require_fields): + return False + + for node in self.origin_nodes: + if not self._valid_nodes(node): + return False + return True + + def parse_workflow(self) -> WorkflowParserResult: + for node in self.origin_nodes: + node = self._convert_node(node) + if node: + self.nodes.append(node) + nodes_id = [node.id for node in self.nodes] + for edge in self.origin_edges: + source = edge["source"] + target = edge["target"] + if source not in nodes_id or target not in nodes_id: + continue + edge = self._convert_edge(edge) + if edge: + self.edges.append(edge) + # + for variable in self.config.get("workflow").get("conversation_variables"): + con_var = self._convert_variable(variable) + if variable: + self.conv_variables.append(con_var) + # + # for variables in config.get("workflow").get("environment_variables"): + # variable = self._convert_variable(variables) + # conv_variables.append(variable) + + trigger = self._convert_trigger({}) + execution_config = self._convert_execution({}) + + return WorkflowParserResult( + success=not self.errors and not self.warnings, + platform=self.get_metadata(), + execution_config=execution_config, + origin_config=self.config, + trigger=trigger, + edges=self.edges, + nodes=self.nodes, + variables=self.conv_variables, + warnings=self.warnings, + errors=self.errors + ) + + def _convert_cycle_node_position(self, node_id: str, position: dict): + for node in self.origin_nodes: + if node["id"] == node_id: + return { + "x": node["position"]["x"] + position["x"], + "y": node["position"]["y"] + position["y"] + } + self.errors.append( + ExceptionDefineition( + type=ExceptionType.NODE, + node_id=node_id, + detail="parent cycle node not found" + ) + ) + raise Exception("parent cycle node not found") + + def _convert_node(self, node: dict[str, Any]) -> NodeDefinition | None: + node_data = node["data"] + try: + return NodeDefinition( + id=node["id"], + type=self.map_node_type(node_data["type"]), + name=node_data.get("title"), + cycle=node.get("parentId"), + description=None, + config=self._convert_node_config(node), + position={ + "x": node["position"]["x"], + "y": node["position"]["y"] + } if node.get("parentId") is None else self._convert_cycle_node_position( + node["parentId"], + node["position"] + ), + error_handling=None, + cache=None + ) + except Exception as e: + logger.debug(f"convert node error - {e}", exc_info=True) + + def _convert_node_config(self, node: dict): + node_data = node["data"] + node_type = node_data["type"] + try: + converter = self.get_node_convert(node_type) + if converter is None: + raise Exception(f"node type not supported - {node_type}") + return converter(node) + except Exception as e: + self.errors.append(ExceptionDefineition( + type=ExceptionType.NODE, + node_id=node["id"], + node_name=node["data"]["title"], + detail=f"convert node error - {e}", + )) + raise e + + def _convert_edge(self, edge: dict[str, Any]) -> EdgeDefinition | None: + try: + + source = edge["source"] + target = edge["target"] + edge_id = edge["id"] + label = None + if source in self.branch_node_cache: + case_id = "-".join(edge_id.split("-")[1:-2]) + if case_id == "false": + label = f'CASE{len(self.branch_node_cache[source])+1}' + else: + label = f'CASE{self.branch_node_cache[source].index(case_id) + 1}' + if source in self.error_branch_node_cache: + case_id = "-".join(edge_id.split("-")[1:-2]) + if case_id == "source": + label = "SUCCESS" + else: + label = "ERROR" + return EdgeDefinition( + id=edge["id"], + source=source, + target=target, + label=label, + ) + except Exception as e: + self.errors.append(ExceptionDefineition( + type=ExceptionType.EDGE, + detail=f"convert edge error - {e}", + )) + return None + + def _convert_variable(self, variable) -> VariableDefinition | None: + try: + return VariableDefinition( + name=variable["name"], + default=variable["value"], + type=variable["value_type"], + ) + except Exception as e: + self.errors.append(ExceptionDefineition( + type=ExceptionType.VARIABLE, + name=variable.get("name"), + detail=f"convert variable error - {e}", + )) + + def _convert_trigger(self, trigger: dict[str, Any]) -> TriggerConfig | None: + pass + + def _convert_execution(self, execution: dict[str, Any]) -> ExecutionConfig: + return ExecutionConfig() + + diff --git a/api/app/core/workflow/adapters/errors.py b/api/app/core/workflow/adapters/errors.py new file mode 100644 index 00000000..c0340a5e --- /dev/null +++ b/api/app/core/workflow/adapters/errors.py @@ -0,0 +1,75 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/26 11:29 +from enum import StrEnum + +from pydantic import BaseModel + + +class ExceptionType(StrEnum): + NODE = "node" + EDGE = "edge" + VARIABLE = "variable" + TRIGGER = "trigger" + EXECUTION = "execution" + CONFIG = "config" + PLATFORM = "platform" + UNKNOWN = "unknown" + + +class ExceptionDefineition(BaseModel): + type: ExceptionType + detail: str + + node_id: str | None = None + node_name: str | None = None + + scope: str | None = None + name: str | None = None + + +class UnknowModelWarning(ExceptionDefineition): + type: ExceptionType = ExceptionType.NODE + + def __init__(self, node_id, node_name, model_name): + super().__init__( + detail=f"Please specify the model mapping manually for model: {model_name}", + node_id=node_id, + node_name=node_name + ) + + +class UnknowError(ExceptionDefineition): + type: ExceptionType = ExceptionType.UNKNOWN + + def __init__(self, detail: str, **kwargs): + super().__init__(detail=detail, **kwargs) + + +class UnsupportPlatform(ExceptionDefineition): + type: ExceptionType = ExceptionType.PLATFORM + + def __init__(self, platform: str): + super().__init__(detail=f"Unsupport platform {platform}") + + +class UnsupportVariableType(ExceptionDefineition): + type: ExceptionType = ExceptionType.VARIABLE + + def __init__(self, scope, name, var_type: str, **kwargs): + super().__init__(scope=scope, name=name, detail=f"Unsupport variable type:[{var_type}]", **kwargs) + + +class InvalidConfiguration(ExceptionDefineition): + type: ExceptionType = ExceptionType.CONFIG + + def __init__(self): + super().__init__(detail="Invalid workflow configuration format") + + +class UnsupportNodeType(ExceptionDefineition): + type: ExceptionType = ExceptionType.NODE + + def __init__(self, node_id: str, node_type: str): + super().__init__(node_id=node_id, detail=f"Unsupport node Type {node_type}") diff --git a/api/app/core/workflow/adapters/memory_bear/__init__.py b/api/app/core/workflow/adapters/memory_bear/__init__.py new file mode 100644 index 00000000..f314662f --- /dev/null +++ b/api/app/core/workflow/adapters/memory_bear/__init__.py @@ -0,0 +1,4 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/26 11:30 diff --git a/api/app/core/workflow/adapters/memory_bear/memory_bear_adapter.py b/api/app/core/workflow/adapters/memory_bear/memory_bear_adapter.py new file mode 100644 index 00000000..0e3f459f --- /dev/null +++ b/api/app/core/workflow/adapters/memory_bear/memory_bear_adapter.py @@ -0,0 +1,76 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/25 14:11 +from typing import Any + +from app.core.workflow.adapters.base_adapter import ( + PlatformMetadata, + PlatformType, + BasePlatformAdapter, + WorkflowParserResult +) +from app.schemas.workflow_schema import ExecutionConfig + + +class MemoryBearAdapter(BasePlatformAdapter): + NODE_TYPE_MAPPING = {} + + @property + def origin_nodes(self): + return self.config.get("workflow").get("nodes") + + @property + def origin_edges(self): + return self.config.get("workflow").get("edges") + + @property + def origin_variables(self): + return self.config.get("workflow").get("variables") + + def get_metadata(self) -> PlatformMetadata: + return PlatformMetadata( + platform_name=PlatformType.MEMORY_BEAR, + version="0.2.5", + support_node_types=list(self.NODE_TYPE_MAPPING.keys()) + ) + + def map_node_type(self, platform_node_type) -> str: + return platform_node_type + + @staticmethod + def _valid_nodes(node: dict[str, Any]): + if "type" not in node["data"]: + return False + if "id" not in node or "type" not in node: + return False + return True + + def validate_config(self) -> bool: + require_fields = frozenset({'app', 'workflow'}) + if not all(field in self.config for field in require_fields): + return False + + for node in self.origin_nodes: + if not self._valid_nodes(node): + return False + return True + + def parse_workflow(self) -> WorkflowParserResult: + self.nodes = self.origin_nodes + self.edges = self.origin_edges + self.conv_variables = self.origin_variables + + return WorkflowParserResult( + success=True, + platform=self.get_metadata(), + execution_config=ExecutionConfig(), + origin_config=self.config, + trigger=None, + edges=self.edges, + nodes=self.nodes, + variables=self.conv_variables, + warnings=self.warnings, + errors=self.errors, + + ) diff --git a/api/app/core/workflow/adapters/registry.py b/api/app/core/workflow/adapters/registry.py new file mode 100644 index 00000000..10012676 --- /dev/null +++ b/api/app/core/workflow/adapters/registry.py @@ -0,0 +1,34 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/25 14:19 +from typing import Any + +from app.core.workflow.adapters import DifyAdapter, MemoryBearAdapter +from app.core.workflow.adapters.base_adapter import BasePlatformAdapter, PlatformType + + +class PlatformAdapterRegistry: + _adapters: dict[str, type[BasePlatformAdapter]] = {} + + @classmethod + def register(cls, platform: str, adapter: type[BasePlatformAdapter]): + cls._adapters[platform] = adapter + + @classmethod + def get_adapter(cls, platform: str, config: dict[str, Any]) -> BasePlatformAdapter: + if platform not in cls._adapters: + raise ValueError(f"Unsupported platform: {platform}") + return cls._adapters.get(platform)(config) + + @classmethod + def list_platforms(cls) -> list[str]: + return list(cls._adapters.keys()) + + @classmethod + def is_supported(cls, platform: str) -> bool: + return platform in cls._adapters + + +PlatformAdapterRegistry.register(PlatformType.MEMORY_BEAR, MemoryBearAdapter) +PlatformAdapterRegistry.register(PlatformType.DIFY, DifyAdapter) diff --git a/api/app/core/workflow/engine/stream_output_coordinator.py b/api/app/core/workflow/engine/stream_output_coordinator.py index 5155a76f..ba6af156 100644 --- a/api/app/core/workflow/engine/stream_output_coordinator.py +++ b/api/app/core/workflow/engine/stream_output_coordinator.py @@ -13,7 +13,7 @@ from app.core.workflow.engine.variable_pool import VariablePool logger = get_logger(__name__) SCOPE_PATTERN = re.compile( - r"\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\.[a-zA-Z0-9_]+\s*}}" + r"\{\{\s*([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\s*}}" ) diff --git a/api/app/core/workflow/nodes/assigner/node.py b/api/app/core/workflow/nodes/assigner/node.py index be51f81d..4c897d5a 100644 --- a/api/app/core/workflow/nodes/assigner/node.py +++ b/api/app/core/workflow/nodes/assigner/node.py @@ -88,6 +88,8 @@ class AssignerNode(BaseNode): await operator.remove_first() case AssignmentOperator.REMOVE_LAST: await operator.remove_last() + case AssignmentOperator.EXTEND: + await operator.extend() case _: raise ValueError(f"Invalid Operator: {assignment.operation}") logger.info(f"Node {self.node_id}: execution completed") diff --git a/api/app/core/workflow/nodes/end/config.py b/api/app/core/workflow/nodes/end/config.py index f534dfb5..5c2a6c2a 100644 --- a/api/app/core/workflow/nodes/end/config.py +++ b/api/app/core/workflow/nodes/end/config.py @@ -17,17 +17,17 @@ class EndNodeConfig(BaseNodeConfig): description="输出模板,支持引用前置节点的输出,如:{{ llm_qa.output }}" ) - # 输出变量定义 - output_variables: list[VariableDefinition] = Field( - default_factory=lambda: [ - VariableDefinition( - name="output", - type=VariableType.STRING, - description="工作流的最终输出" - ) - ], - description="输出变量定义(自动生成,通常不需要修改)" - ) + # # 输出变量定义 + # output_variables: list[VariableDefinition] = Field( + # default_factory=lambda: [ + # VariableDefinition( + # name="output", + # type=VariableType.STRING, + # description="工作流的最终输出" + # ) + # ], + # description="输出变量定义(自动生成,通常不需要修改)" + # ) class Config: json_schema_extra = { diff --git a/api/app/core/workflow/nodes/enums.py b/api/app/core/workflow/nodes/enums.py index 6ad1c6a8..0579bdf5 100644 --- a/api/app/core/workflow/nodes/enums.py +++ b/api/app/core/workflow/nodes/enums.py @@ -61,6 +61,7 @@ class AssignmentOperator(StrEnum): APPEND = "append" REMOVE_LAST = "remove_last" REMOVE_FIRST = "remove_first" + EXTEND = "extend" class HttpRequestMethod(StrEnum): diff --git a/api/app/core/workflow/nodes/http_request/node.py b/api/app/core/workflow/nodes/http_request/node.py index cdb34b57..df899940 100644 --- a/api/app/core/workflow/nodes/http_request/node.py +++ b/api/app/core/workflow/nodes/http_request/node.py @@ -236,5 +236,5 @@ class HttpRequestNode(BaseNode): logger.warning( f"Node {self.node_id}: HTTP request failed, switching to error handling branch" ) - return "ERROR" + return {"output": "ERROR"} raise RuntimeError("http request failed") diff --git a/api/app/core/workflow/nodes/knowledge/config.py b/api/app/core/workflow/nodes/knowledge/config.py index 5475636e..56afe004 100644 --- a/api/app/core/workflow/nodes/knowledge/config.py +++ b/api/app/core/workflow/nodes/knowledge/config.py @@ -40,7 +40,7 @@ class KnowledgeRetrievalNodeConfig(BaseNodeConfig): ) knowledge_bases: list[KnowledgeBaseConfig] = Field( - ..., + default_factory=list, description="Knowledge base config" ) diff --git a/api/app/core/workflow/nodes/start/config.py b/api/app/core/workflow/nodes/start/config.py index 98390bf7..3f795f1e 100644 --- a/api/app/core/workflow/nodes/start/config.py +++ b/api/app/core/workflow/nodes/start/config.py @@ -3,7 +3,6 @@ from pydantic import Field from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition -from app.core.workflow.variable.base_variable import VariableType class StartNodeConfig(BaseNodeConfig): @@ -21,42 +20,42 @@ class StartNodeConfig(BaseNodeConfig): description="自定义输入变量列表,这些变量会作为 Start 节点的输出" ) - # 输出变量定义 - output_variables: list[VariableDefinition] = Field( - default_factory=lambda: [ - VariableDefinition( - name="message", - type=VariableType.STRING, - description="用户输入的消息" - ), - VariableDefinition( - name="conversation_vars", - type=VariableType.OBJECT, - description="会话级变量" - ), - VariableDefinition( - name="execution_id", - type=VariableType.STRING, - description="执行 ID" - ), - VariableDefinition( - name="conversation_id", - type=VariableType.STRING, - description="会话 ID" - ), - VariableDefinition( - name="workspace_id", - type=VariableType.STRING, - description="工作空间 ID" - ), - VariableDefinition( - name="user_id", - type=VariableType.STRING, - description="用户 ID" - ) - ], - description="输出变量定义(自动生成,通常不需要修改)" - ) + # # 输出变量定义 + # output_variables: list[VariableDefinition] = Field( + # default_factory=lambda: [ + # VariableDefinition( + # name="message", + # type=VariableType.STRING, + # description="用户输入的消息" + # ), + # VariableDefinition( + # name="conversation_vars", + # type=VariableType.OBJECT, + # description="会话级变量" + # ), + # VariableDefinition( + # name="execution_id", + # type=VariableType.STRING, + # description="执行 ID" + # ), + # VariableDefinition( + # name="conversation_id", + # type=VariableType.STRING, + # description="会话 ID" + # ), + # VariableDefinition( + # name="workspace_id", + # type=VariableType.STRING, + # description="工作空间 ID" + # ), + # VariableDefinition( + # name="user_id", + # type=VariableType.STRING, + # description="用户 ID" + # ) + # ], + # description="输出变量定义(自动生成,通常不需要修改)" + # ) class Config: json_schema_extra = { diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index 8cf81b92..eeb73a01 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -5,6 +5,8 @@ from enum import Enum, StrEnum from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator +from app.schemas.workflow_schema import WorkflowConfigCreate + # ---------- Multimodal File Support ---------- @@ -196,6 +198,8 @@ class AppCreate(BaseModel): # only for type=multi_agent multi_agent_config: Optional[Dict[str, Any]] = None + workflow_config: Optional[WorkflowConfigCreate] = None + class AppUpdate(BaseModel): name: Optional[str] = None diff --git a/api/app/schemas/workflow_schema.py b/api/app/schemas/workflow_schema.py index bdef825e..9e15f227 100644 --- a/api/app/schemas/workflow_schema.py +++ b/api/app/schemas/workflow_schema.py @@ -18,7 +18,10 @@ class NodeConfig(BaseModel): class NodeDefinition(BaseModel): """节点定义""" id: str = Field(..., description="节点唯一标识") - type: str = Field(..., description="节点类型: start, end, llm, agent, tool, condition, loop, transform, human, code") + type: str = Field( + ..., + description="节点类型: start, end, llm, agent, tool, condition, loop, transform, human, code" + ) name: str | None = Field(None, description="节点名称") cycle: str | None = Field(None, description="父循环节点id") description: str | None = Field(None, description="节点描述") @@ -30,12 +33,12 @@ class NodeDefinition(BaseModel): class EdgeDefinition(BaseModel): """边定义""" - id: str | None = Field(None, description="边唯一标识(可选)") + id: str | None = Field(default=None, description="边唯一标识(可选)") source: str = Field(..., description="源节点 ID") target: str = Field(..., description="目标节点 ID") - type: str | None = Field(None, description="边类型: normal, error") - condition: str | None = Field(None, description="条件表达式(条件边)") - label: str | None = Field(None, description="边标签") + type: str | None = Field(default=None, description="边类型: normal, error") + condition: str | None = Field(default=None, description="条件表达式(条件边)") + label: str | None = Field(default=None, description="边标签") class VariableDefinition(BaseModel): @@ -44,7 +47,7 @@ class VariableDefinition(BaseModel): type: str = Field(default="string", description="变量类型: string, number, boolean, object, array") required: bool = Field(default=False, description="是否必填") default: Any = Field(None, description="默认值") - description: str | None = Field(None, description="变量描述") + description: str | None = Field(default=None, description="变量描述") class ExecutionConfig(BaseModel): @@ -61,6 +64,13 @@ class TriggerConfig(BaseModel): config: dict[str, Any] = Field(default_factory=dict, description="触发器配置") +class WorkflowImportSave(BaseModel): + """工作流导入请求""" + temp_id: str + name: str + description: str + + # ==================== 工作流配置 ==================== class WorkflowConfigCreate(BaseModel): @@ -84,7 +94,7 @@ class WorkflowConfigUpdate(BaseModel): class WorkflowConfig(BaseModel): """工作流配置输出""" model_config = ConfigDict(from_attributes=True) - + id: uuid.UUID app_id: uuid.UUID nodes: list[dict[str, Any]] @@ -95,11 +105,11 @@ class WorkflowConfig(BaseModel): is_active: bool created_at: datetime.datetime updated_at: datetime.datetime - + @field_serializer("created_at", when_used="json") def _serialize_created_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None - + @field_serializer("updated_at", when_used="json") def _serialize_updated_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None @@ -123,7 +133,8 @@ class WorkflowExecutionResponse(BaseModel): output_data: dict[str, Any] | None = Field(None, description="所有节点的详细输出数据") error_message: str | None = Field(None, description="错误信息") elapsed_time: float | None = Field(None, description="耗时(秒)") - token_usage: dict[str, Any] | None = Field(None, description="Token 使用情况 {prompt_tokens, completion_tokens, total_tokens}") + token_usage: dict[str, Any] | None = Field(None, + description="Token 使用情况 {prompt_tokens, completion_tokens, total_tokens}") class WorkflowExecutionStreamChunk(BaseModel): @@ -136,7 +147,7 @@ class WorkflowExecutionStreamChunk(BaseModel): class WorkflowExecution(BaseModel): """工作流执行记录输出""" model_config = ConfigDict(from_attributes=True) - + id: uuid.UUID workflow_config_id: uuid.UUID app_id: uuid.UUID @@ -156,15 +167,15 @@ class WorkflowExecution(BaseModel): token_usage: dict[str, Any] | None meta_data: dict[str, Any] created_at: datetime.datetime - + @field_serializer("started_at", when_used="json") def _serialize_started_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None - + @field_serializer("completed_at", when_used="json") def _serialize_completed_at(self, dt: datetime.datetime | None): return int(dt.timestamp() * 1000) if dt else None - + @field_serializer("created_at", when_used="json") def _serialize_created_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None @@ -173,7 +184,7 @@ class WorkflowExecution(BaseModel): class WorkflowNodeExecution(BaseModel): """工作流节点执行记录输出""" model_config = ConfigDict(from_attributes=True) - + id: uuid.UUID execution_id: uuid.UUID node_id: str @@ -193,15 +204,15 @@ class WorkflowNodeExecution(BaseModel): cache_key: str | None meta_data: dict[str, Any] created_at: datetime.datetime - + @field_serializer("started_at", when_used="json") def _serialize_started_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None - + @field_serializer("completed_at", when_used="json") def _serialize_completed_at(self, dt: datetime.datetime | None): return int(dt.timestamp() * 1000) if dt else None - + @field_serializer("created_at", when_used="json") def _serialize_created_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py index f3c6260a..6e6e0ecb 100644 --- a/api/app/services/app_service.py +++ b/api/app/services/app_service.py @@ -321,6 +321,26 @@ class AppService: self.db.add(agent_cfg) logger.debug("Agent 配置已创建", extra={"app_id": str(app_id)}) + def _create_workflow_config( + self, + app_id: uuid.UUID, + data: app_schema.WorkflowConfigCreate, + now: datetime.datetime + ): + workflow_cfg = WorkflowConfig( + id=uuid.uuid4(), + app_id=app_id, + nodes=[node.model_dump() for node in data.nodes] if data.nodes else [], + edges=[edge.model_dump() for edge in data.edges] if data.edges else [], + variables=[var.model_dump() for var in data.variables] if data.variables else [], + execution_config=data.execution_config.model_dump() if data.execution_config else {}, + triggers=[trigger.model_dump() for trigger in data.triggers] if data.triggers else [], + is_active=True, + created_at=now, + updated_at=now + ) + self.db.add(workflow_cfg) + def _create_multi_agent_config( self, app_id: uuid.UUID, @@ -532,6 +552,9 @@ class AppService: if app.type == "multi_agent" and data.multi_agent_config: self._create_multi_agent_config(app.id, data.multi_agent_config, now) + if app.type == "workflow" and data.workflow_config: + self._create_workflow_config(app.id, data.workflow_config, now) + self.db.commit() self.db.refresh(app) @@ -968,7 +991,7 @@ class AppService: config = self.db.scalars(stmt).first() try: - config_memory=config.memory + config_memory = config.memory if 'memory_content' in config_memory: config.memory['memory_config_id'] = config.memory.pop('memory_content') except: @@ -1189,9 +1212,9 @@ class AppService: # ==================== 记忆配置提取方法 ==================== def _extract_memory_config_id( - self, - app_type: str, - config: Dict[str, Any] + self, + app_type: str, + config: Dict[str, Any] ) -> Tuple[Optional[uuid.UUID], bool]: """从发布配置中提取 memory_config_id(委托给 MemoryConfigService) @@ -1205,13 +1228,13 @@ class AppService: - is_legacy_int: 是否检测到旧格式 int 数据,需要回退到工作空间默认配置 """ from app.services.memory_config_service import MemoryConfigService - + service = MemoryConfigService(self.db) return service.extract_memory_config_id(app_type, config) def _get_workspace_default_memory_config_id( - self, - workspace_id: uuid.UUID + self, + workspace_id: uuid.UUID ) -> Optional[uuid.UUID]: """获取工作空间的默认记忆配置ID @@ -1222,22 +1245,22 @@ class AppService: Optional[uuid.UUID]: 默认记忆配置ID,如果不存在则返回 None """ from app.services.memory_config_service import MemoryConfigService - + service = MemoryConfigService(self.db) config = service.get_workspace_default_config(workspace_id) - + if not config: logger.warning( f"工作空间没有可用的记忆配置: workspace_id={workspace_id}" ) return None - + return config.config_id def _update_endusers_memory_config( - self, - app_id: uuid.UUID, - memory_config_id: uuid.UUID + self, + app_id: uuid.UUID, + memory_config_id: uuid.UUID ) -> int: """批量更新应用下所有终端用户的 memory_config_id @@ -1249,13 +1272,13 @@ class AppService: int: 更新的终端用户数量 """ from app.repositories.end_user_repository import EndUserRepository - + repo = EndUserRepository(self.db) updated_count = repo.batch_update_memory_config_id( app_id=app_id, memory_config_id=memory_config_id ) - + return updated_count # ==================== 应用发布管理 ==================== @@ -1403,7 +1426,7 @@ class AppService: # 提取记忆配置ID并更新终端用户 memory_config_id, is_legacy_int = self._extract_memory_config_id(app.type, config) - + # 如果检测到旧格式 int 数据,回退到工作空间默认配置 if is_legacy_int and not memory_config_id: memory_config_id = self._get_workspace_default_memory_config_id(app.workspace_id) @@ -1412,7 +1435,7 @@ class AppService: f"发布时使用工作空间默认记忆配置(旧数据兼容): app_id={app_id}, " f"workspace_id={app.workspace_id}, memory_config_id={memory_config_id}" ) - + if memory_config_id: updated_count = self._update_endusers_memory_config(app_id, memory_config_id) logger.info( @@ -1537,7 +1560,7 @@ class AppService: # 提取记忆配置ID并更新终端用户 memory_config_id, is_legacy_int = self._extract_memory_config_id(release.type, release.config) - + # 如果检测到旧格式 int 数据,回退到工作空间默认配置 if is_legacy_int and not memory_config_id: memory_config_id = self._get_workspace_default_memory_config_id(app.workspace_id) @@ -1546,7 +1569,7 @@ class AppService: f"回滚时使用工作空间默认记忆配置(旧数据兼容): app_id={app_id}, " f"workspace_id={app.workspace_id}, memory_config_id={memory_config_id}" ) - + if memory_config_id: updated_count = self._update_endusers_memory_config(app_id, memory_config_id) logger.info( diff --git a/api/app/services/workflow_import_service.py b/api/app/services/workflow_import_service.py new file mode 100644 index 00000000..2e17f404 --- /dev/null +++ b/api/app/services/workflow_import_service.py @@ -0,0 +1,102 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/25 14:39 +import json +import uuid +from typing import Any + +from sqlalchemy.orm import Session + +from app.aioRedis import aio_redis_set, aio_redis_get +from app.core.config import settings +from app.core.exceptions import BusinessException +from app.core.workflow.adapters.base_adapter import WorkflowImportResult, WorkflowParserResult +from app.core.workflow.adapters.errors import UnsupportPlatform, InvalidConfiguration +from app.core.workflow.adapters.registry import PlatformAdapterRegistry +from app.schemas import AppCreate +from app.schemas.workflow_schema import WorkflowConfigCreate +from app.services.app_service import AppService +from app.services.workflow_service import WorkflowService + + +class WorkflowImportService: + def __init__(self, db: Session): + self.db = db + self.registry = PlatformAdapterRegistry + self.cache_timeout = settings.WORKFLOW_IMPORT_CACHE_TIMEOUT + + self.app_service = AppService(db) + self.workflow_service = WorkflowService(db) + + async def flush_config(self, temp_id: str, config: WorkflowParserResult): + config_cache = await aio_redis_get(temp_id) + if not config_cache: + raise BusinessException("Workflow configuration has expired. Please re-upload it.") + await aio_redis_set(temp_id, config.model_dump_json(), expire=self.cache_timeout) + + async def upload_config( + self, + platform: str, + config: dict[str, Any], + ): + + if not self.registry.is_supported(platform): + return WorkflowImportResult( + success=False, + temp_id=None, + workflow_id=None, + errors=[UnsupportPlatform(platform=platform)] + ) + + adapter = self.registry.get_adapter(platform, config) + + if not adapter.validate_config(): + return WorkflowImportResult( + success=False, + temp_id=None, + workflow_id=None, + errors=[InvalidConfiguration()] + ) + + workflow_config = adapter.parse_workflow() + temp_id = uuid.uuid4().hex + await aio_redis_set(temp_id, workflow_config.model_dump(), expire=self.cache_timeout) + return WorkflowImportResult( + success=True, + temp_id=temp_id, + workflow_id=None, + edges=workflow_config.edges, + nodes=workflow_config.nodes, + variables=workflow_config.variables, + warnings=workflow_config.warnings, + errors=workflow_config.errors + ) + + async def save_workflow( + self, + user_id: uuid.UUID, + workspace_id: uuid.UUID, + temp_id: str, + name: str, + description: str | None, + ): + config = await aio_redis_get(temp_id) + if config is None: + raise BusinessException("Configuration import timed out. Please try again.") + config = json.loads(config) + app = self.app_service.create_app( + user_id=user_id, + workspace_id=workspace_id, + data=AppCreate( + name=name, + description=description, + type="workflow", + workflow_config=WorkflowConfigCreate( + nodes=config["nodes"], + edges=config["edges"], + variables=config["variables"] + ) + ) + ) + return app diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index d06a05d7..188ef6cd 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -6,13 +6,16 @@ import logging import uuid from typing import Any, Annotated, Optional +import yaml from fastapi import Depends from sqlalchemy.orm import Session from app.core.error_codes import BizCode from app.core.exceptions import BusinessException +from app.core.workflow.adapters.registry import PlatformAdapterRegistry from app.core.workflow.validator import validate_workflow_config from app.db import get_db +from app.models import App from app.models.workflow_model import WorkflowConfig, WorkflowExecution from app.repositories.workflow_repository import ( WorkflowConfigRepository, @@ -38,6 +41,8 @@ class WorkflowService: self.conversation_service = ConversationService(db) self.multimodal_service = MultimodalService(db) + self.registry = PlatformAdapterRegistry + # ==================== 配置管理 ==================== def create_workflow_config( @@ -200,6 +205,32 @@ class WorkflowService: logger.info(f"删除工作流配置成功: app_id={app_id}, config_id={config.id}") return True + def export_workflow_dsl(self, app_id: uuid.UUID): + config = self.get_workflow_config(app_id) + if not config: + raise BusinessException( + code=BizCode.NOT_FOUND, + message=f"工作流配置不存在: app_id={app_id}" + ) + + app: App = config.app + dsl_info = { + "app": { + "name": app.name, + "description": app.description, + "icon": app.icon, + "icon_type": app.icon_type + }, + "workflow": { + "variables": config.variables, + "edges": config.edges, + "nodes": config.nodes, + "execution_config": config.execution_config, + "triggers": config.triggers + } + } + return yaml.dump(dsl_info, default_flow_style=False, allow_unicode=True) + def check_config(self, app_id: uuid.UUID) -> WorkflowConfig: """检查工作流配置的完整性