Merge pull request #415 from SuanmoSuanyangTechnology/feature/workflow-adapter-dify
feat(workflow): add Dify workflow import adapter and related APIs
This commit is contained in:
@@ -10,7 +10,6 @@ from app.core.config import settings
|
|||||||
# 设置日志记录器
|
# 设置日志记录器
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# 创建连接池
|
# 创建连接池
|
||||||
pool = ConnectionPool.from_url(
|
pool = ConnectionPool.from_url(
|
||||||
f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}",
|
f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}",
|
||||||
@@ -21,6 +20,7 @@ pool = ConnectionPool.from_url(
|
|||||||
)
|
)
|
||||||
aio_redis = redis.StrictRedis(connection_pool=pool)
|
aio_redis = redis.StrictRedis(connection_pool=pool)
|
||||||
|
|
||||||
|
|
||||||
async def get_redis_connection():
|
async def get_redis_connection():
|
||||||
"""获取Redis连接"""
|
"""获取Redis连接"""
|
||||||
try:
|
try:
|
||||||
@@ -29,7 +29,8 @@ async def get_redis_connection():
|
|||||||
logger.error(f"Redis连接失败: {str(e)}")
|
logger.error(f"Redis连接失败: {str(e)}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def aio_redis_set(key: str, val: str|dict, expire: int = None):
|
|
||||||
|
async def aio_redis_set(key: str, val: str | dict, expire: int = None):
|
||||||
"""设置Redis键值
|
"""设置Redis键值
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -50,6 +51,7 @@ async def aio_redis_set(key: str, val: str|dict, expire: int = None):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Redis set错误: {str(e)}")
|
logger.error(f"Redis set错误: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
async def aio_redis_get(key: str):
|
async def aio_redis_get(key: str):
|
||||||
"""获取Redis键值"""
|
"""获取Redis键值"""
|
||||||
try:
|
try:
|
||||||
@@ -58,6 +60,7 @@ async def aio_redis_get(key: str):
|
|||||||
logger.error(f"Redis get错误: {str(e)}")
|
logger.error(f"Redis get错误: {str(e)}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def aio_redis_delete(key: str):
|
async def aio_redis_delete(key: str):
|
||||||
"""删除Redis键"""
|
"""删除Redis键"""
|
||||||
try:
|
try:
|
||||||
@@ -66,6 +69,7 @@ async def aio_redis_delete(key: str):
|
|||||||
logger.error(f"Redis delete错误: {str(e)}")
|
logger.error(f"Redis delete错误: {str(e)}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def aio_redis_publish(channel: str, message: Dict[str, Any]) -> bool:
|
async def aio_redis_publish(channel: str, message: Dict[str, Any]) -> bool:
|
||||||
"""发布消息到Redis频道"""
|
"""发布消息到Redis频道"""
|
||||||
try:
|
try:
|
||||||
@@ -78,6 +82,7 @@ async def aio_redis_publish(channel: str, message: Dict[str, Any]) -> bool:
|
|||||||
logger.error(f"Redis发布错误: {str(e)}")
|
logger.error(f"Redis发布错误: {str(e)}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
class RedisSubscriber:
|
class RedisSubscriber:
|
||||||
"""Redis订阅器"""
|
"""Redis订阅器"""
|
||||||
|
|
||||||
@@ -163,6 +168,7 @@ class RedisSubscriber:
|
|||||||
self._task.cancel()
|
self._task.cancel()
|
||||||
await self._cleanup()
|
await self._cleanup()
|
||||||
|
|
||||||
|
|
||||||
class RedisPubSubManager:
|
class RedisPubSubManager:
|
||||||
"""Redis发布订阅管理器"""
|
"""Redis发布订阅管理器"""
|
||||||
|
|
||||||
@@ -196,6 +202,6 @@ class RedisPubSubManager:
|
|||||||
self.subscribers.clear()
|
self.subscribers.clear()
|
||||||
return count
|
return count
|
||||||
|
|
||||||
|
|
||||||
# 全局实例
|
# 全局实例
|
||||||
pubsub_manager = RedisPubSubManager()
|
pubsub_manager = RedisPubSubManager()
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import Optional, Annotated
|
from typing import Optional, Annotated
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Path
|
import yaml
|
||||||
|
from fastapi import APIRouter, Depends, Path, Form, UploadFile, File
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
@@ -17,12 +18,13 @@ from app.repositories.end_user_repository import EndUserRepository
|
|||||||
from app.schemas import app_schema
|
from app.schemas import app_schema
|
||||||
from app.schemas.response_schema import PageData, PageMeta
|
from app.schemas.response_schema import PageData, PageMeta
|
||||||
from app.schemas.workflow_schema import WorkflowConfig as WorkflowConfigSchema
|
from app.schemas.workflow_schema import WorkflowConfig as WorkflowConfigSchema
|
||||||
from app.schemas.workflow_schema import WorkflowConfigUpdate
|
from app.schemas.workflow_schema import WorkflowConfigUpdate, WorkflowImportSave
|
||||||
from app.services import app_service, workspace_service
|
from app.services import app_service, workspace_service
|
||||||
from app.services.agent_config_helper import enrich_agent_config
|
from app.services.agent_config_helper import enrich_agent_config
|
||||||
from app.services.app_service import AppService
|
from app.services.app_service import AppService
|
||||||
from app.services.workflow_service import WorkflowService, get_workflow_service
|
|
||||||
from app.services.app_statistics_service import AppStatisticsService
|
from app.services.app_statistics_service import AppStatisticsService
|
||||||
|
from app.services.workflow_import_service import WorkflowImportService
|
||||||
|
from app.services.workflow_service import WorkflowService, get_workflow_service
|
||||||
|
|
||||||
router = APIRouter(prefix="/apps", tags=["Apps"])
|
router = APIRouter(prefix="/apps", tags=["Apps"])
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
@@ -65,7 +67,7 @@ def list_apps(
|
|||||||
|
|
||||||
# 当 ids 存在且不为 None 时,根据 ids 获取应用
|
# 当 ids 存在且不为 None 时,根据 ids 获取应用
|
||||||
if ids is not None:
|
if ids is not None:
|
||||||
app_ids = [id.strip() for id in ids.split(',') if id.strip()]
|
app_ids = [app_id.strip() for app_id in ids.split(',') if app_id.strip()]
|
||||||
items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id)
|
items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id)
|
||||||
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
|
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
|
||||||
return success(data=items)
|
return success(data=items)
|
||||||
@@ -879,6 +881,60 @@ async def update_workflow_config(
|
|||||||
return success(data=WorkflowConfigSchema.model_validate(cfg))
|
return success(data=WorkflowConfigSchema.model_validate(cfg))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{app_id}/workflow/export")
|
||||||
|
@cur_workspace_access_guard()
|
||||||
|
async def export_workflow_config(
|
||||||
|
app_id: uuid.UUID,
|
||||||
|
db: Annotated[Session, Depends(get_db)],
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)]
|
||||||
|
):
|
||||||
|
"""导出工作流配置为YAML文件"""
|
||||||
|
workflow_service = WorkflowService(db)
|
||||||
|
|
||||||
|
return success(data={
|
||||||
|
"content": workflow_service.export_workflow_dsl(app_id=app_id),
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/workflow/import")
|
||||||
|
@cur_workspace_access_guard()
|
||||||
|
async def import_workflow_config(
|
||||||
|
file: UploadFile = File(...),
|
||||||
|
platform: str = Form(...),
|
||||||
|
app_id: str = Form(None),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
|
||||||
|
):
|
||||||
|
"""从YAML内容导入工作流配置"""
|
||||||
|
if not file.filename.lower().endswith((".yaml", ".yml")):
|
||||||
|
return fail(msg="Only yaml file is allowed", code=BizCode.BAD_REQUEST)
|
||||||
|
|
||||||
|
raw_text = (await file.read()).decode("utf-8")
|
||||||
|
import_service = WorkflowImportService(db)
|
||||||
|
config = yaml.safe_load(raw_text)
|
||||||
|
result = await import_service.upload_config(platform, config)
|
||||||
|
return success(data=result)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/workflow/import/save")
|
||||||
|
@cur_workspace_access_guard()
|
||||||
|
async def save_workflow_import(
|
||||||
|
data: WorkflowImportSave,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
import_service = WorkflowImportService(db)
|
||||||
|
app = await import_service.save_workflow(
|
||||||
|
user_id=current_user.id,
|
||||||
|
workspace_id=current_user.current_workspace_id,
|
||||||
|
temp_id=data.temp_id,
|
||||||
|
name=data.name,
|
||||||
|
description=data.description,
|
||||||
|
)
|
||||||
|
return success(data=app_schema.App.model_validate(app))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{app_id}/statistics", summary="应用统计数据")
|
@router.get("/{app_id}/statistics", summary="应用统计数据")
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
def get_app_statistics(
|
def get_app_statistics(
|
||||||
@@ -894,6 +950,8 @@ def get_app_statistics(
|
|||||||
app_id: 应用ID
|
app_id: 应用ID
|
||||||
start_date: 开始时间戳(毫秒)
|
start_date: 开始时间戳(毫秒)
|
||||||
end_date: 结束时间戳(毫秒)
|
end_date: 结束时间戳(毫秒)
|
||||||
|
db: 数据库连接
|
||||||
|
current_user: 当前用户
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- daily_conversations: 每日会话数统计
|
- daily_conversations: 每日会话数统计
|
||||||
@@ -931,6 +989,8 @@ def get_workspace_api_statistics(
|
|||||||
Args:
|
Args:
|
||||||
start_date: 开始时间戳(毫秒)
|
start_date: 开始时间戳(毫秒)
|
||||||
end_date: 结束时间戳(毫秒)
|
end_date: 结束时间戳(毫秒)
|
||||||
|
db: 数据库连接
|
||||||
|
current_user: 当前用户
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
每日统计数据列表,每项包含:
|
每日统计数据列表,每项包含:
|
||||||
|
|||||||
@@ -58,7 +58,6 @@ class Settings:
|
|||||||
REDIS_DB: int = int(os.getenv("REDIS_DB", "1"))
|
REDIS_DB: int = int(os.getenv("REDIS_DB", "1"))
|
||||||
REDIS_PASSWORD: str = os.getenv("REDIS_PASSWORD", "")
|
REDIS_PASSWORD: str = os.getenv("REDIS_PASSWORD", "")
|
||||||
|
|
||||||
|
|
||||||
# ElasticSearch configuration
|
# ElasticSearch configuration
|
||||||
ELASTICSEARCH_HOST: str = os.getenv("ELASTICSEARCH_HOST", "https://127.0.0.1")
|
ELASTICSEARCH_HOST: str = os.getenv("ELASTICSEARCH_HOST", "https://127.0.0.1")
|
||||||
ELASTICSEARCH_PORT: int = int(os.getenv("ELASTICSEARCH_PORT", "9200"))
|
ELASTICSEARCH_PORT: int = int(os.getenv("ELASTICSEARCH_PORT", "9200"))
|
||||||
@@ -130,7 +129,7 @@ class Settings:
|
|||||||
|
|
||||||
# Server Configuration
|
# Server Configuration
|
||||||
SERVER_IP: str = os.getenv("SERVER_IP", "127.0.0.1")
|
SERVER_IP: str = os.getenv("SERVER_IP", "127.0.0.1")
|
||||||
FILE_LOCAL_SERVER_URL : str = os.getenv("FILE_LOCAL_SERVER_URL", "http://localhost:8000/api")
|
FILE_LOCAL_SERVER_URL: str = os.getenv("FILE_LOCAL_SERVER_URL", "http://localhost:8000/api")
|
||||||
|
|
||||||
# ========================================================================
|
# ========================================================================
|
||||||
# Internal Configuration (not in .env, used by application code)
|
# Internal Configuration (not in .env, used by application code)
|
||||||
@@ -225,6 +224,7 @@ class Settings:
|
|||||||
LOAD_MODEL: bool = os.getenv("LOAD_MODEL", "false").lower() == "true"
|
LOAD_MODEL: bool = os.getenv("LOAD_MODEL", "false").lower() == "true"
|
||||||
|
|
||||||
# workflow config
|
# workflow config
|
||||||
|
WORKFLOW_IMPORT_CACHE_TIMEOUT: int = int(os.getenv("WORKFLOW_IMPORT_CACHE_TIMEOUT", 1800))
|
||||||
WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600))
|
WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600))
|
||||||
|
|
||||||
# ========================================================================
|
# ========================================================================
|
||||||
|
|||||||
8
api/app/core/workflow/adapters/__init__.py
Normal file
8
api/app/core/workflow/adapters/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
# -*- coding: UTF-8 -*-
|
||||||
|
# Author: Eternity
|
||||||
|
# @Email: 1533512157@qq.com
|
||||||
|
# @Time : 2026/2/24 15:54
|
||||||
|
from app.core.workflow.adapters.dify.dify_adapter import DifyAdapter
|
||||||
|
from app.core.workflow.adapters.memory_bear.memory_bear_adapter import MemoryBearAdapter
|
||||||
|
|
||||||
|
__all__ = ["DifyAdapter", "MemoryBearAdapter"]
|
||||||
88
api/app/core/workflow/adapters/base_adapter.py
Normal file
88
api/app/core/workflow/adapters/base_adapter.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
# -*- coding: UTF-8 -*-
|
||||||
|
# Author: Eternity
|
||||||
|
# @Email: 1533512157@qq.com
|
||||||
|
# @Time : 2026/2/24 15:58
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from collections import defaultdict
|
||||||
|
from enum import StrEnum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.core.workflow.adapters.errors import ExceptionDefineition
|
||||||
|
from app.schemas.workflow_schema import (
|
||||||
|
EdgeDefinition,
|
||||||
|
NodeDefinition,
|
||||||
|
VariableDefinition,
|
||||||
|
ExecutionConfig,
|
||||||
|
TriggerConfig
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PlatformType(StrEnum):
|
||||||
|
MEMORY_BEAR = "memory_bear"
|
||||||
|
DIFY = "dify"
|
||||||
|
COZE = "coze"
|
||||||
|
|
||||||
|
|
||||||
|
class PlatformMetadata(BaseModel):
|
||||||
|
platform_name: str
|
||||||
|
version: str
|
||||||
|
support_node_types: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowParserResult(BaseModel):
|
||||||
|
success: bool
|
||||||
|
platform: PlatformMetadata
|
||||||
|
execution_config: ExecutionConfig
|
||||||
|
origin_config: dict[str, Any]
|
||||||
|
trigger: TriggerConfig | None
|
||||||
|
edges: list[EdgeDefinition] = Field(default_factory=list)
|
||||||
|
nodes: list[NodeDefinition] = Field(default_factory=list)
|
||||||
|
variables: list[VariableDefinition] = Field(default_factory=list)
|
||||||
|
warnings: list[ExceptionDefineition] = Field(default_factory=list)
|
||||||
|
errors: list[ExceptionDefineition] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowImportResult(BaseModel):
|
||||||
|
success: bool
|
||||||
|
temp_id: str | None = Field(..., description="cache id")
|
||||||
|
workflow_id: str | None = Field(..., description="workflow id")
|
||||||
|
edges: list[EdgeDefinition] = Field(default_factory=list)
|
||||||
|
nodes: list[NodeDefinition] = Field(default_factory=list)
|
||||||
|
variables: list[VariableDefinition] = Field(default_factory=list)
|
||||||
|
warnings: list[ExceptionDefineition] = Field(default_factory=list)
|
||||||
|
errors: list[ExceptionDefineition] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class BasePlatformAdapter(ABC):
|
||||||
|
def __init__(self, config: dict[str, Any]):
|
||||||
|
self.config = config
|
||||||
|
self.nodes: list[NodeDefinition] = []
|
||||||
|
self.edges: list[EdgeDefinition] = []
|
||||||
|
self.conv_variables: list[VariableDefinition] = []
|
||||||
|
|
||||||
|
self.errors = []
|
||||||
|
self.warnings = []
|
||||||
|
|
||||||
|
self.branch_node_cache = defaultdict(list)
|
||||||
|
self.error_branch_node_cache = []
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_metadata(self) -> PlatformMetadata:
|
||||||
|
"""get platform metadata"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def validate_config(self) -> bool:
|
||||||
|
"""platform configuration validate"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def parse_workflow(self) -> WorkflowParserResult:
|
||||||
|
"""parse platform configuration to local config"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def map_node_type(self, platform_node_type: str) -> str:
|
||||||
|
pass
|
||||||
75
api/app/core/workflow/adapters/base_converter.py
Normal file
75
api/app/core/workflow/adapters/base_converter.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
# -*- coding: UTF-8 -*-
|
||||||
|
# Author: Eternity
|
||||||
|
# @Email: 1533512157@qq.com
|
||||||
|
# @Time : 2026/2/26 14:32
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from app.core.workflow.variable.base_variable import DEFAULT_VALUE, VariableType
|
||||||
|
|
||||||
|
|
||||||
|
class BaseConverter(ABC):
|
||||||
|
@staticmethod
|
||||||
|
def _convert_string(var):
|
||||||
|
try:
|
||||||
|
return str(var)
|
||||||
|
except:
|
||||||
|
return DEFAULT_VALUE(VariableType.STRING)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _convert_boolean(var):
|
||||||
|
try:
|
||||||
|
return bool(var)
|
||||||
|
except:
|
||||||
|
return DEFAULT_VALUE(VariableType.BOOLEAN)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _convert_number(var):
|
||||||
|
try:
|
||||||
|
return float(var)
|
||||||
|
except:
|
||||||
|
return DEFAULT_VALUE(VariableType.NUMBER)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _convert_object(var):
|
||||||
|
try:
|
||||||
|
return dict(var)
|
||||||
|
except:
|
||||||
|
return DEFAULT_VALUE(VariableType.OBJECT)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@abstractmethod
|
||||||
|
def _convert_file(var):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _convert_array_string(var):
|
||||||
|
try:
|
||||||
|
return list(var)
|
||||||
|
except:
|
||||||
|
return DEFAULT_VALUE(VariableType.ARRAY_STRING)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _convert_array_number(var):
|
||||||
|
try:
|
||||||
|
return list(var)
|
||||||
|
except:
|
||||||
|
return DEFAULT_VALUE(VariableType.ARRAY_NUMBER)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _convert_array_boolean(var):
|
||||||
|
try:
|
||||||
|
return list(var)
|
||||||
|
except:
|
||||||
|
return DEFAULT_VALUE(VariableType.ARRAY_BOOLEAN)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _convert_array_object(var):
|
||||||
|
try:
|
||||||
|
return list(var)
|
||||||
|
except:
|
||||||
|
return DEFAULT_VALUE(VariableType.ARRAY_OBJECT)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@abstractmethod
|
||||||
|
def _convert_array_file(var):
|
||||||
|
pass
|
||||||
4
api/app/core/workflow/adapters/dify/__init__.py
Normal file
4
api/app/core/workflow/adapters/dify/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
# -*- coding: UTF-8 -*-
|
||||||
|
# Author: Eternity
|
||||||
|
# @Email: 1533512157@qq.com
|
||||||
|
# @Time : 2026/2/25 18:20
|
||||||
659
api/app/core/workflow/adapters/dify/converter.py
Normal file
659
api/app/core/workflow/adapters/dify/converter.py
Normal file
@@ -0,0 +1,659 @@
|
|||||||
|
# -*- coding: UTF-8 -*-
|
||||||
|
# Author: Eternity
|
||||||
|
# @Email: 1533512157@qq.com
|
||||||
|
# @Time : 2026/2/25 18:21
|
||||||
|
import base64
|
||||||
|
import re
|
||||||
|
from typing import Any
|
||||||
|
from urllib.parse import quote
|
||||||
|
|
||||||
|
from app.core.workflow.adapters.base_converter import BaseConverter
|
||||||
|
from app.core.workflow.adapters.errors import UnsupportVariableType, UnknowModelWarning, ExceptionDefineition, \
|
||||||
|
ExceptionType
|
||||||
|
from app.core.workflow.nodes.assigner import AssignerNodeConfig
|
||||||
|
from app.core.workflow.nodes.assigner.config import AssignmentItem
|
||||||
|
from app.core.workflow.nodes.base_config import VariableDefinition
|
||||||
|
from app.core.workflow.nodes.code import CodeNodeConfig
|
||||||
|
from app.core.workflow.nodes.code.config import InputVariable, OutputVariable
|
||||||
|
from app.core.workflow.nodes.configs import StartNodeConfig, LLMNodeConfig
|
||||||
|
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig, IterationNodeConfig
|
||||||
|
from app.core.workflow.nodes.cycle_graph.config import ConditionDetail as LoopConditionDetail, ConditionsConfig, \
|
||||||
|
CycleVariable
|
||||||
|
from app.core.workflow.nodes.end import EndNodeConfig
|
||||||
|
from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, AssignmentOperator, HttpAuthType, \
|
||||||
|
HttpContentType, HttpErrorHandle
|
||||||
|
from app.core.workflow.nodes.http_request import HttpRequestNodeConfig
|
||||||
|
from app.core.workflow.nodes.http_request.config import HttpAuthConfig, HttpContentTypeConfig, HttpFormData, \
|
||||||
|
HttpTimeOutConfig, HttpRetryConfig, HttpErrorDefaultTamplete, HttpErrorHandleConfig
|
||||||
|
from app.core.workflow.nodes.if_else import IfElseNodeConfig
|
||||||
|
from app.core.workflow.nodes.if_else.config import ConditionDetail, ConditionBranchConfig
|
||||||
|
from app.core.workflow.nodes.jinja_render import JinjaRenderNodeConfig
|
||||||
|
from app.core.workflow.nodes.jinja_render.config import VariablesMappingConfig
|
||||||
|
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
|
||||||
|
from app.core.workflow.nodes.llm.config import MemoryWindowSetting, MessageConfig
|
||||||
|
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNodeConfig
|
||||||
|
from app.core.workflow.nodes.parameter_extractor.config import ParamsConfig
|
||||||
|
from app.core.workflow.nodes.question_classifier import QuestionClassifierNodeConfig
|
||||||
|
from app.core.workflow.nodes.question_classifier.config import ClassifierConfig
|
||||||
|
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNodeConfig
|
||||||
|
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||||
|
|
||||||
|
|
||||||
|
class DifyConverter(BaseConverter):
|
||||||
|
errors: list
|
||||||
|
warnings: list
|
||||||
|
branch_node_cache: dict
|
||||||
|
error_branch_node_cache: list
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.CONFIG_CONVERT_MAP = {
|
||||||
|
"start": self.convert_start_node_config,
|
||||||
|
"llm": self.convert_llm_node_config,
|
||||||
|
"answer": self.convert_end_node_config,
|
||||||
|
"if-else": self.convert_if_else_node_config,
|
||||||
|
"loop": self.convert_loop_node_config,
|
||||||
|
"iteration": self.convert_iteration_node_config,
|
||||||
|
"assigner": self.convert_assigner_node_config,
|
||||||
|
"code": self.convert_code_node_config,
|
||||||
|
"http-request": self.convert_http_node_config,
|
||||||
|
"template-transform": self.convert_jinja_render_node_config,
|
||||||
|
"knowledge-retrieval": self.convert_knowledge_node_config,
|
||||||
|
"parameter-extractor": self.convert_parameter_extractor_node_config,
|
||||||
|
"question-classifier": self.convert_question_classifier_node_config,
|
||||||
|
"variable-aggregator": self.convert_variable_aggregator,
|
||||||
|
"loop-start": lambda x: {},
|
||||||
|
"iteration-start": lambda x: {},
|
||||||
|
"loop-end": lambda x: {},
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_node_convert(self, node_type):
|
||||||
|
func = self.CONFIG_CONVERT_MAP.get(node_type, None)
|
||||||
|
return func
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_variable(expression) -> bool:
|
||||||
|
return bool(re.match(r"\{\{#(.*?)#}}", expression))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def process_var_selector(var_selector):
|
||||||
|
if not var_selector:
|
||||||
|
return ""
|
||||||
|
selector = var_selector.split('.')
|
||||||
|
if len(selector) != 2:
|
||||||
|
raise Exception(f"invalid variable selector: {var_selector}")
|
||||||
|
if selector[0] == "conversation":
|
||||||
|
selector[0] = "conv"
|
||||||
|
var_selector = ".".join(selector)
|
||||||
|
mapping = {
|
||||||
|
"sys.query": "sys.message"
|
||||||
|
}
|
||||||
|
|
||||||
|
var_selector = mapping.get(var_selector, var_selector)
|
||||||
|
return var_selector
|
||||||
|
|
||||||
|
def _process_list_variable_litearl(self, variable_selector: list) -> str | None:
|
||||||
|
if not self.process_var_selector(".".join(variable_selector)):
|
||||||
|
return None
|
||||||
|
return "{{" + self.process_var_selector(".".join(variable_selector)) + "}}"
|
||||||
|
|
||||||
|
def trans_variable_format(self, content):
|
||||||
|
pattern = re.compile(r"\{\{#(.*?)#}}")
|
||||||
|
|
||||||
|
def replacer(match: re.Match) -> str:
|
||||||
|
raw_name = match.group(1)
|
||||||
|
new_name = self.process_var_selector(raw_name)
|
||||||
|
return f"{{{{{new_name}}}}}"
|
||||||
|
|
||||||
|
return pattern.sub(replacer, content)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _convert_file(var):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _convert_array_file(var):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def variable_type_map(source_type) -> VariableType | None:
|
||||||
|
type_map = {
|
||||||
|
"file": VariableType.FILE,
|
||||||
|
"paragraph": VariableType.STRING,
|
||||||
|
"text-input": VariableType.STRING,
|
||||||
|
"number": VariableType.NUMBER,
|
||||||
|
"checkbox": VariableType.BOOLEAN,
|
||||||
|
"file-list": VariableType.ARRAY_FILE,
|
||||||
|
"select": VariableType.STRING,
|
||||||
|
}
|
||||||
|
var_type = type_map.get(source_type, source_type)
|
||||||
|
return var_type
|
||||||
|
|
||||||
|
def convert_variable_type(self, target_type: VariableType, origin_value: Any):
|
||||||
|
if not origin_value:
|
||||||
|
return DEFAULT_VALUE(target_type)
|
||||||
|
try:
|
||||||
|
match target_type:
|
||||||
|
case VariableType.STRING:
|
||||||
|
return self._convert_string(origin_value)
|
||||||
|
case VariableType.NUMBER:
|
||||||
|
return self._convert_number(origin_value)
|
||||||
|
case VariableType.BOOLEAN:
|
||||||
|
return self._convert_boolean(origin_value)
|
||||||
|
case VariableType.FILE:
|
||||||
|
return self._convert_file(origin_value)
|
||||||
|
case VariableType.ARRAY_FILE:
|
||||||
|
return self._convert_array_file(origin_value)
|
||||||
|
case _:
|
||||||
|
return origin_value
|
||||||
|
except:
|
||||||
|
raise Exception(f"convert variable failed: {target_type}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def convert_compare_operator(operator):
|
||||||
|
operator_map = {
|
||||||
|
"is": ComparisonOperator.EQ,
|
||||||
|
"is not": ComparisonOperator.NE,
|
||||||
|
"=": ComparisonOperator.EQ,
|
||||||
|
"≠": ComparisonOperator.NE,
|
||||||
|
">": ComparisonOperator.GT,
|
||||||
|
"<": ComparisonOperator.LT,
|
||||||
|
"≥": ComparisonOperator.GE,
|
||||||
|
"≤": ComparisonOperator.LE,
|
||||||
|
"not empty": ComparisonOperator.NOT_EMPTY,
|
||||||
|
}
|
||||||
|
return operator_map.get(operator, operator)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def convert_assignment_operator(operator):
|
||||||
|
operator_map = {
|
||||||
|
"+=": AssignmentOperator.ADD,
|
||||||
|
"-=": AssignmentOperator.SUBTRACT,
|
||||||
|
"*=": AssignmentOperator.MULTIPLY,
|
||||||
|
"/=": AssignmentOperator.DIVIDE,
|
||||||
|
"over-write": AssignmentOperator.COVER,
|
||||||
|
"remove-last": AssignmentOperator.REMOVE_LAST,
|
||||||
|
"remove-first": AssignmentOperator.REMOVE_FIRST,
|
||||||
|
|
||||||
|
}
|
||||||
|
return operator_map.get(operator, operator)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def convert_http_auth_type(auth_type):
|
||||||
|
auth_type_map = {
|
||||||
|
"no-auth": HttpAuthType.NONE,
|
||||||
|
"bearer": HttpAuthType.BEARER,
|
||||||
|
"basic": HttpAuthType.BASIC,
|
||||||
|
"custom": HttpAuthType.CUSTOM,
|
||||||
|
}
|
||||||
|
return auth_type_map.get(auth_type, auth_type)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def convert_http_content_type(content_type):
|
||||||
|
content_type_map = {
|
||||||
|
"none": HttpContentType.NONE,
|
||||||
|
"form-data": HttpContentType.FROM_DATA,
|
||||||
|
"x-www-form-urlencoded": HttpContentType.WWW_FORM,
|
||||||
|
"json": HttpContentType.JSON,
|
||||||
|
"raw-text": HttpContentType.RAW,
|
||||||
|
"binary": HttpContentType.BINARY,
|
||||||
|
}
|
||||||
|
return content_type_map.get(content_type, content_type)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def convert_http_error_handle_type(handle_type):
|
||||||
|
handle_type_map = {
|
||||||
|
"none": HttpErrorHandle.NONE,
|
||||||
|
"fail-branch": HttpErrorHandle.BRANCH,
|
||||||
|
"default-value": HttpErrorHandle.DEFAULT,
|
||||||
|
}
|
||||||
|
return handle_type_map.get(handle_type, handle_type)
|
||||||
|
|
||||||
|
def convert_start_node_config(self, node: dict) -> dict:
|
||||||
|
node_data = node["data"]
|
||||||
|
start_vars = []
|
||||||
|
for var in node_data["variables"]:
|
||||||
|
var_type = self.variable_type_map(var["type"])
|
||||||
|
if not var_type:
|
||||||
|
self.errors.append(
|
||||||
|
UnsupportVariableType(
|
||||||
|
scope=node["id"],
|
||||||
|
name=var["variable"],
|
||||||
|
var_type=var["type"],
|
||||||
|
node_id=node["id"],
|
||||||
|
node_name=node_data["title"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if var_type in ["file", "array[file]"]:
|
||||||
|
self.errors.append(
|
||||||
|
ExceptionDefineition(
|
||||||
|
type=ExceptionType.VARIABLE,
|
||||||
|
node_id=node["id"],
|
||||||
|
node_name=node_data["title"],
|
||||||
|
name=var["variable"],
|
||||||
|
detail=f"Unsupport Variable type for start node: {var_type}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
var_def = VariableDefinition(
|
||||||
|
name=var["variable"],
|
||||||
|
type=var_type,
|
||||||
|
required=var["required"],
|
||||||
|
default=self.convert_variable_type(
|
||||||
|
var_type, var["default"]
|
||||||
|
),
|
||||||
|
description=var["label"],
|
||||||
|
max_length=var.get("max_length"),
|
||||||
|
)
|
||||||
|
start_vars.append(var_def)
|
||||||
|
return StartNodeConfig(
|
||||||
|
variables=start_vars
|
||||||
|
).model_dump()
|
||||||
|
|
||||||
|
def convert_question_classifier_node_config(self, node: dict) -> dict:
|
||||||
|
node_data = node["data"]
|
||||||
|
self.warnings.append(
|
||||||
|
UnknowModelWarning(
|
||||||
|
node_id=node["id"],
|
||||||
|
node_name=node_data["title"],
|
||||||
|
model_name=node_data["model"].get("name")
|
||||||
|
)
|
||||||
|
)
|
||||||
|
categories = []
|
||||||
|
for category in node_data["classes"]:
|
||||||
|
self.branch_node_cache[node["id"]].append(category["id"])
|
||||||
|
categories.append(
|
||||||
|
ClassifierConfig(
|
||||||
|
class_name=category["name"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return QuestionClassifierNodeConfig.model_construct(
|
||||||
|
input_variable=self._process_list_variable_litearl(node_data["query_variable_selector"]),
|
||||||
|
user_supplement_prompt=self.trans_variable_format(node_data["instructions"]),
|
||||||
|
categories=categories,
|
||||||
|
).model_dump()
|
||||||
|
|
||||||
|
def convert_llm_node_config(self, node: dict) -> dict:
|
||||||
|
node_data = node["data"]
|
||||||
|
self.warnings.append(
|
||||||
|
UnknowModelWarning(
|
||||||
|
node_id=node["id"],
|
||||||
|
node_name=node_data["title"],
|
||||||
|
model_name=node_data["model"].get("name")
|
||||||
|
)
|
||||||
|
)
|
||||||
|
context = self._process_list_variable_litearl(node_data["context"]["variable_selector"])
|
||||||
|
memory = MemoryWindowSetting(
|
||||||
|
enable=bool(node_data.get("memory")),
|
||||||
|
enable_window=bool(node_data.get("memory", {}).get("window", {}).get("enabled", False)),
|
||||||
|
window_size=int(node_data.get("memory", {}).get("window", {}).get("size", 20))
|
||||||
|
)
|
||||||
|
messages = []
|
||||||
|
for message in node_data["prompt_template"]:
|
||||||
|
messages.append(
|
||||||
|
MessageConfig(
|
||||||
|
role=message["role"],
|
||||||
|
content=self.trans_variable_format(message["text"])
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if memory.enable:
|
||||||
|
messages.append(
|
||||||
|
MessageConfig(
|
||||||
|
role="user",
|
||||||
|
content=self.trans_variable_format(node_data["memory"]["query_prompt_template"])
|
||||||
|
)
|
||||||
|
)
|
||||||
|
vision = node_data["vision"]["enabled"]
|
||||||
|
vision_input = self._process_list_variable_litearl(
|
||||||
|
node_data["vision"]["configs"]["variable_selector"]
|
||||||
|
) if vision else None
|
||||||
|
return LLMNodeConfig.model_construct(
|
||||||
|
model_id=None,
|
||||||
|
context=context,
|
||||||
|
memory=memory,
|
||||||
|
vision=vision,
|
||||||
|
vision_input=vision_input,
|
||||||
|
messages=messages
|
||||||
|
).model_dump()
|
||||||
|
|
||||||
|
def convert_end_node_config(self, node: dict) -> dict:
|
||||||
|
node_data = node["data"]
|
||||||
|
return EndNodeConfig(
|
||||||
|
output=self.trans_variable_format(node_data["answer"]),
|
||||||
|
).model_dump()
|
||||||
|
|
||||||
|
def convert_if_else_node_config(self, node: dict) -> dict:
|
||||||
|
node_data = node["data"]
|
||||||
|
cases = []
|
||||||
|
for case in node_data["cases"]:
|
||||||
|
case_id = case["id"]
|
||||||
|
logical_operator = case["logical_operator"]
|
||||||
|
conditions = []
|
||||||
|
for condition in case["conditions"]:
|
||||||
|
right_value = condition["value"]
|
||||||
|
condition_detail = ConditionDetail(
|
||||||
|
operator=self.convert_compare_operator(condition["comparison_operator"]),
|
||||||
|
left="{{" + self.process_var_selector(".".join(condition["variable_selector"])) + "}}",
|
||||||
|
right=self.trans_variable_format(
|
||||||
|
right_value
|
||||||
|
) if isinstance(right_value, str) and self.is_variable(right_value) else self.convert_variable_type(
|
||||||
|
self.variable_type_map(condition["varType"]),
|
||||||
|
condition["value"]
|
||||||
|
),
|
||||||
|
input_type=ValueInputType.VARIABLE
|
||||||
|
if isinstance(right_value, str) and self.is_variable(right_value) else ValueInputType.CONSTANT,
|
||||||
|
)
|
||||||
|
conditions.append(condition_detail)
|
||||||
|
cases.append(
|
||||||
|
ConditionBranchConfig(
|
||||||
|
logical_operator=logical_operator,
|
||||||
|
expressions=conditions
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.branch_node_cache[node["id"]].append(case_id)
|
||||||
|
return IfElseNodeConfig(
|
||||||
|
cases=cases
|
||||||
|
).model_dump()
|
||||||
|
|
||||||
|
def convert_loop_node_config(self, node: dict) -> dict:
|
||||||
|
node_data = node["data"]
|
||||||
|
logical_operator = node_data["logical_operator"]
|
||||||
|
conditions = []
|
||||||
|
for condition in node_data["break_conditions"]:
|
||||||
|
right_value = condition["value"]
|
||||||
|
conditions.append(
|
||||||
|
LoopConditionDetail(
|
||||||
|
operator=self.convert_compare_operator(condition["comparison_operator"]),
|
||||||
|
left=self._process_list_variable_litearl(condition["variable_selector"]),
|
||||||
|
right=self.trans_variable_format(
|
||||||
|
right_value
|
||||||
|
) if isinstance(right_value, str) and self.is_variable(right_value) else self.convert_variable_type(
|
||||||
|
self.variable_type_map(condition["varType"]),
|
||||||
|
condition["value"]
|
||||||
|
),
|
||||||
|
input_type=ValueInputType.VARIABLE
|
||||||
|
if isinstance(right_value, str) and self.is_variable(right_value) else ValueInputType.CONSTANT,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
condition_config = ConditionsConfig(
|
||||||
|
logical_operator=logical_operator,
|
||||||
|
expressions=conditions
|
||||||
|
)
|
||||||
|
loop_variables = []
|
||||||
|
for variable in node_data["loop_variables"]:
|
||||||
|
right_input_type = variable["value_type"]
|
||||||
|
right_value_type = self.variable_type_map(variable["var_type"])
|
||||||
|
if right_input_type == ValueInputType.VARIABLE:
|
||||||
|
right_value = self._process_list_variable_litearl(variable["value"])
|
||||||
|
else:
|
||||||
|
right_value = self.convert_variable_type(right_value_type, variable["value"])
|
||||||
|
loop_variables.append(
|
||||||
|
CycleVariable(
|
||||||
|
name=variable["label"],
|
||||||
|
type=right_value_type,
|
||||||
|
value=right_value,
|
||||||
|
input_type=right_input_type
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return LoopNodeConfig(
|
||||||
|
condition=condition_config,
|
||||||
|
cycle_vars=loop_variables,
|
||||||
|
max_loop=node_data["loop_count"]
|
||||||
|
).model_dump()
|
||||||
|
|
||||||
|
def convert_iteration_node_config(self, node: dict) -> dict:
|
||||||
|
node_data = node["data"]
|
||||||
|
return IterationNodeConfig(
|
||||||
|
input=self._process_list_variable_litearl(node_data["iterator_selector"]),
|
||||||
|
parallel=node_data["is_parallel"],
|
||||||
|
parallel_count=node_data["parallel_nums"],
|
||||||
|
output=self._process_list_variable_litearl(node_data["output_selector"]),
|
||||||
|
output_type=self.variable_type_map(node_data["output_type"]),
|
||||||
|
flatten=node_data["flatten_output"],
|
||||||
|
).model_dump()
|
||||||
|
|
||||||
|
def convert_assigner_node_config(self, node: dict) -> dict:
|
||||||
|
node_data = node["data"]
|
||||||
|
assignments = []
|
||||||
|
for assignment in node_data["items"]:
|
||||||
|
if assignment.get("operation") is None or assignment.get("value") is None:
|
||||||
|
continue
|
||||||
|
assignments.append(
|
||||||
|
AssignmentItem(
|
||||||
|
variable_selector=self._process_list_variable_litearl(assignment["variable_selector"]),
|
||||||
|
value=self._process_list_variable_litearl(
|
||||||
|
assignment["value"]
|
||||||
|
) if assignment["input_type"] == ValueInputType.VARIABLE else assignment["value"],
|
||||||
|
operation=self.convert_assignment_operator(assignment["operation"])
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return AssignerNodeConfig(
|
||||||
|
assignments=assignments
|
||||||
|
).model_dump()
|
||||||
|
|
||||||
|
def convert_code_node_config(self, node: dict) -> dict:
|
||||||
|
node_data = node["data"]
|
||||||
|
input_variables = []
|
||||||
|
for input_variable in node_data["variables"]:
|
||||||
|
input_variables.append(
|
||||||
|
InputVariable(
|
||||||
|
name=input_variable["variable"],
|
||||||
|
variable=self._process_list_variable_litearl(input_variable["value_selector"]),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
output_variables = []
|
||||||
|
for output_variable in node_data["outputs"]:
|
||||||
|
output_variables.append(
|
||||||
|
OutputVariable(
|
||||||
|
name=output_variable,
|
||||||
|
type=node_data["outputs"][output_variable]["type"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
code = base64.b64encode(quote(node_data["code"]).encode("utf-8")).decode("utf-8")
|
||||||
|
|
||||||
|
return CodeNodeConfig(
|
||||||
|
input_variables=input_variables,
|
||||||
|
language=node_data["code_language"],
|
||||||
|
output_variables=output_variables,
|
||||||
|
code=code
|
||||||
|
).model_dump()
|
||||||
|
|
||||||
|
def convert_http_node_config(self, node: dict) -> dict:
|
||||||
|
node_data = node["data"]
|
||||||
|
if node_data["authorization"] != 'no-auth':
|
||||||
|
auth_type = self.convert_http_auth_type(node_data["authorization"]["config"]["type"])
|
||||||
|
auth_config = HttpAuthConfig(
|
||||||
|
auth_type=auth_type,
|
||||||
|
header=node_data["authorization"]["config"].get("header"),
|
||||||
|
api_key=node_data["authorization"]["config"].get("api_key"),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
auth_config = HttpAuthConfig()
|
||||||
|
|
||||||
|
content_type = self.convert_http_content_type(node_data["body"]["type"])
|
||||||
|
if content_type == HttpContentType.FROM_DATA:
|
||||||
|
body_content = []
|
||||||
|
for content in node_data["body"]["data"]:
|
||||||
|
body_content.append(
|
||||||
|
HttpFormData(
|
||||||
|
key=self.trans_variable_format(content["key"]),
|
||||||
|
type=content["type"],
|
||||||
|
value=self.trans_variable_format(content["value"]),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif content_type == HttpContentType.WWW_FORM:
|
||||||
|
body_content = {}
|
||||||
|
for content in node_data["body"]["data"]:
|
||||||
|
body_content[
|
||||||
|
self.trans_variable_format(content["key"])
|
||||||
|
] = self.trans_variable_format(content["value"])
|
||||||
|
else:
|
||||||
|
if node_data["body"]["data"]:
|
||||||
|
body_content = node_data["body"]["data"][0]["value"]
|
||||||
|
else:
|
||||||
|
body_content = ""
|
||||||
|
|
||||||
|
headers = {}
|
||||||
|
for header in node_data["headers"].split("\n"):
|
||||||
|
if not header:
|
||||||
|
continue
|
||||||
|
|
||||||
|
key_value = header.split(":")
|
||||||
|
if len(key_value) == 2:
|
||||||
|
headers[
|
||||||
|
self.trans_variable_format(key_value[0])
|
||||||
|
] = self.trans_variable_format(key_value[1])
|
||||||
|
else:
|
||||||
|
self.warnings.append(ExceptionDefineition(
|
||||||
|
type=ExceptionType.CONFIG,
|
||||||
|
node_id=node["id"],
|
||||||
|
node_name=node_data["title"],
|
||||||
|
detail=f"Invalid header/param - {header}",
|
||||||
|
))
|
||||||
|
|
||||||
|
params = {}
|
||||||
|
for param in node_data["params"].split("\n"):
|
||||||
|
if not param:
|
||||||
|
continue
|
||||||
|
|
||||||
|
key_value = param.split(":")
|
||||||
|
if len(key_value) == 2:
|
||||||
|
params[
|
||||||
|
self.trans_variable_format(key_value[0])
|
||||||
|
] = self.trans_variable_format(key_value[1])
|
||||||
|
else:
|
||||||
|
self.warnings.append(ExceptionDefineition(
|
||||||
|
type=ExceptionType.CONFIG,
|
||||||
|
node_id=node["id"],
|
||||||
|
node_name=node_data["title"],
|
||||||
|
detail=f"Invalid header/param - {param}",
|
||||||
|
))
|
||||||
|
|
||||||
|
error_handle_type = self.convert_http_error_handle_type(
|
||||||
|
node_data.get("error_strategy", "none")
|
||||||
|
)
|
||||||
|
default_value = None
|
||||||
|
if error_handle_type == HttpErrorHandle.DEFAULT:
|
||||||
|
default_body = ""
|
||||||
|
default_header = {}
|
||||||
|
default_status_code = 0
|
||||||
|
for var in node_data["default_value"]:
|
||||||
|
if var["key"] == "body":
|
||||||
|
default_body = var["value"]
|
||||||
|
elif var["key"] == "header":
|
||||||
|
default_header = var["value"]
|
||||||
|
elif var["key"] == "status_code":
|
||||||
|
default_status_code = var["value"]
|
||||||
|
default_value = HttpErrorDefaultTamplete(
|
||||||
|
body=default_body,
|
||||||
|
headers=default_header,
|
||||||
|
status_code=default_status_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.error_branch_node_cache.append(node['id'])
|
||||||
|
return HttpRequestNodeConfig(
|
||||||
|
method=node_data["method"].upper(),
|
||||||
|
url=node_data["url"],
|
||||||
|
auth=auth_config,
|
||||||
|
body=HttpContentTypeConfig(
|
||||||
|
content_type=self.convert_http_content_type(node_data["body"]["type"]),
|
||||||
|
data=body_content,
|
||||||
|
),
|
||||||
|
headers=headers,
|
||||||
|
params=params,
|
||||||
|
verify_ssl=node_data["ssl_verify"],
|
||||||
|
timeouts=HttpTimeOutConfig(
|
||||||
|
connect_timeout=node_data["timeout"]["max_connect_timeout"] or 5,
|
||||||
|
read_timeout=node_data["timeout"]["max_read_timeout"] or 5,
|
||||||
|
write_timeout=node_data["timeout"]["max_write_timeout"] or 5,
|
||||||
|
),
|
||||||
|
retry=HttpRetryConfig(
|
||||||
|
enable=node_data["retry_config"]["retry_enabled"],
|
||||||
|
max_attempts=node_data["retry_config"]["max_retries"],
|
||||||
|
retry_interval=node_data["retry_config"]["retry_interval"],
|
||||||
|
),
|
||||||
|
error_handle=HttpErrorHandleConfig(
|
||||||
|
method=error_handle_type,
|
||||||
|
default=default_value,
|
||||||
|
)
|
||||||
|
).model_dump()
|
||||||
|
|
||||||
|
def convert_jinja_render_node_config(self, node: dict) -> dict:
|
||||||
|
node_data = node["data"]
|
||||||
|
mapping = []
|
||||||
|
for variable in node_data["variables"]:
|
||||||
|
mapping.append(VariablesMappingConfig(
|
||||||
|
name=variable["variable"],
|
||||||
|
value=self._process_list_variable_litearl(variable["value_selector"])
|
||||||
|
))
|
||||||
|
return JinjaRenderNodeConfig(
|
||||||
|
template=node_data["template"],
|
||||||
|
mapping=mapping,
|
||||||
|
).model_dump()
|
||||||
|
|
||||||
|
def convert_knowledge_node_config(self, node: dict) -> dict:
|
||||||
|
node_data = node["data"]
|
||||||
|
self.warnings.append(ExceptionDefineition(
|
||||||
|
node_id=node["id"],
|
||||||
|
node_name=node_data["title"],
|
||||||
|
type=ExceptionType.CONFIG,
|
||||||
|
detail=f"Please reconfigure the Knowledge Retrieval node.",
|
||||||
|
))
|
||||||
|
return KnowledgeRetrievalNodeConfig.model_construct(
|
||||||
|
query=self._process_list_variable_litearl(node_data["query_variable_selector"]),
|
||||||
|
).model_dump()
|
||||||
|
|
||||||
|
def convert_parameter_extractor_node_config(self, node: dict) -> dict:
|
||||||
|
node_data = node["data"]
|
||||||
|
self.warnings.append(
|
||||||
|
UnknowModelWarning(
|
||||||
|
node_id=node["id"],
|
||||||
|
node_name=node_data["title"],
|
||||||
|
model_name=node_data["model"].get("name")
|
||||||
|
)
|
||||||
|
)
|
||||||
|
params = []
|
||||||
|
for param in node_data["parameters"]:
|
||||||
|
params.append(
|
||||||
|
ParamsConfig(
|
||||||
|
name=param["name"],
|
||||||
|
desc=param["description"],
|
||||||
|
required=param["required"],
|
||||||
|
type=param["type"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return ParameterExtractorNodeConfig.model_construct(
|
||||||
|
text=self._process_list_variable_litearl(node_data["query"]),
|
||||||
|
params=params,
|
||||||
|
prompt=node_data["instruction"]
|
||||||
|
).model_dump()
|
||||||
|
|
||||||
|
def convert_variable_aggregator(self, node: dict) -> dict:
|
||||||
|
node_data = node["data"]
|
||||||
|
group_enable = node_data["advanced_settings"]["group_enabled"]
|
||||||
|
group_variables = {}
|
||||||
|
group_type = {}
|
||||||
|
if not group_enable:
|
||||||
|
group_variables["output"] = [
|
||||||
|
self._process_list_variable_litearl(variable)
|
||||||
|
for variable in node_data["variables"]
|
||||||
|
]
|
||||||
|
group_type["output"] = node_data["output_type"]
|
||||||
|
else:
|
||||||
|
for group in node_data["advanced_settings"]["groups"]:
|
||||||
|
group_variables[group["group_name"]] = [
|
||||||
|
self._process_list_variable_litearl(variable)
|
||||||
|
for variable in group["variables"]
|
||||||
|
]
|
||||||
|
group_type[group["group_name"]] = group["output_type"]
|
||||||
|
|
||||||
|
return VariableAggregatorNodeConfig(
|
||||||
|
group=group_enable,
|
||||||
|
group_variables=group_variables,
|
||||||
|
group_type=group_type,
|
||||||
|
).model_dump()
|
||||||
239
api/app/core/workflow/adapters/dify/dify_adapter.py
Normal file
239
api/app/core/workflow/adapters/dify/dify_adapter.py
Normal file
@@ -0,0 +1,239 @@
|
|||||||
|
# -*- coding: UTF-8 -*-
|
||||||
|
# Author: Eternity
|
||||||
|
# @Email: 1533512157@qq.com
|
||||||
|
# @Time : 2026/2/24 16:05
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.core.logging_config import get_logger
|
||||||
|
from app.core.workflow.adapters.base_adapter import (
|
||||||
|
BasePlatformAdapter,
|
||||||
|
PlatformMetadata,
|
||||||
|
PlatformType,
|
||||||
|
WorkflowParserResult
|
||||||
|
)
|
||||||
|
from app.core.workflow.adapters.dify.converter import DifyConverter
|
||||||
|
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType
|
||||||
|
from app.core.workflow.nodes.enums import NodeType
|
||||||
|
from app.schemas.workflow_schema import (
|
||||||
|
NodeDefinition,
|
||||||
|
EdgeDefinition,
|
||||||
|
VariableDefinition,
|
||||||
|
TriggerConfig,
|
||||||
|
ExecutionConfig
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||||
|
NODE_TYPE_MAPPING = {
|
||||||
|
"start": NodeType.START,
|
||||||
|
"llm": NodeType.LLM,
|
||||||
|
"answer": NodeType.END,
|
||||||
|
"if-else": NodeType.IF_ELSE,
|
||||||
|
"loop-start": NodeType.CYCLE_START,
|
||||||
|
"iteration-start": NodeType.CYCLE_START,
|
||||||
|
"assigner": NodeType.ASSIGNER,
|
||||||
|
"loop": NodeType.LOOP,
|
||||||
|
"iteration": NodeType.ITERATION,
|
||||||
|
"loop-end": NodeType.BREAK,
|
||||||
|
"code": NodeType.CODE,
|
||||||
|
"http-request": NodeType.HTTP_REQUEST,
|
||||||
|
"template-transform": NodeType.JINJARENDER,
|
||||||
|
"knowledge-retrieval": NodeType.KNOWLEDGE_RETRIEVAL,
|
||||||
|
"parameter-extractor": NodeType.PARAMETER_EXTRACTOR,
|
||||||
|
"question-classifier": NodeType.QUESTION_CLASSIFIER,
|
||||||
|
"variable-aggregator": NodeType.VAR_AGGREGATOR
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, config: dict[str, Any]):
|
||||||
|
DifyConverter.__init__(self)
|
||||||
|
BasePlatformAdapter.__init__(self, config)
|
||||||
|
|
||||||
|
def get_metadata(self) -> PlatformMetadata:
|
||||||
|
return PlatformMetadata(
|
||||||
|
platform_name=PlatformType.DIFY,
|
||||||
|
version="0.5.0",
|
||||||
|
support_node_types=list(self.NODE_TYPE_MAPPING.keys())
|
||||||
|
)
|
||||||
|
|
||||||
|
def map_node_type(self, platform_node_type) -> str:
|
||||||
|
return self.NODE_TYPE_MAPPING.get(platform_node_type)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def origin_nodes(self):
|
||||||
|
return self.config.get("workflow").get("graph").get("nodes")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def origin_edges(self):
|
||||||
|
return self.config.get("workflow").get("graph").get("edges")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _valid_nodes(node: dict[str, Any]):
|
||||||
|
if "data" not in node:
|
||||||
|
return False
|
||||||
|
if "type" not in node["data"]:
|
||||||
|
return False
|
||||||
|
if "id" not in node or "type" not in node:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def validate_config(self) -> bool:
|
||||||
|
require_fields = frozenset({'app', 'dependencies', 'kind', 'version', 'workflow'})
|
||||||
|
if not all(field in self.config for field in require_fields):
|
||||||
|
return False
|
||||||
|
|
||||||
|
for node in self.origin_nodes:
|
||||||
|
if not self._valid_nodes(node):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def parse_workflow(self) -> WorkflowParserResult:
|
||||||
|
for node in self.origin_nodes:
|
||||||
|
node = self._convert_node(node)
|
||||||
|
if node:
|
||||||
|
self.nodes.append(node)
|
||||||
|
nodes_id = [node.id for node in self.nodes]
|
||||||
|
for edge in self.origin_edges:
|
||||||
|
source = edge["source"]
|
||||||
|
target = edge["target"]
|
||||||
|
if source not in nodes_id or target not in nodes_id:
|
||||||
|
continue
|
||||||
|
edge = self._convert_edge(edge)
|
||||||
|
if edge:
|
||||||
|
self.edges.append(edge)
|
||||||
|
#
|
||||||
|
for variable in self.config.get("workflow").get("conversation_variables"):
|
||||||
|
con_var = self._convert_variable(variable)
|
||||||
|
if variable:
|
||||||
|
self.conv_variables.append(con_var)
|
||||||
|
#
|
||||||
|
# for variables in config.get("workflow").get("environment_variables"):
|
||||||
|
# variable = self._convert_variable(variables)
|
||||||
|
# conv_variables.append(variable)
|
||||||
|
|
||||||
|
trigger = self._convert_trigger({})
|
||||||
|
execution_config = self._convert_execution({})
|
||||||
|
|
||||||
|
return WorkflowParserResult(
|
||||||
|
success=not self.errors and not self.warnings,
|
||||||
|
platform=self.get_metadata(),
|
||||||
|
execution_config=execution_config,
|
||||||
|
origin_config=self.config,
|
||||||
|
trigger=trigger,
|
||||||
|
edges=self.edges,
|
||||||
|
nodes=self.nodes,
|
||||||
|
variables=self.conv_variables,
|
||||||
|
warnings=self.warnings,
|
||||||
|
errors=self.errors
|
||||||
|
)
|
||||||
|
|
||||||
|
def _convert_cycle_node_position(self, node_id: str, position: dict):
|
||||||
|
for node in self.origin_nodes:
|
||||||
|
if node["id"] == node_id:
|
||||||
|
return {
|
||||||
|
"x": node["position"]["x"] + position["x"],
|
||||||
|
"y": node["position"]["y"] + position["y"]
|
||||||
|
}
|
||||||
|
self.errors.append(
|
||||||
|
ExceptionDefineition(
|
||||||
|
type=ExceptionType.NODE,
|
||||||
|
node_id=node_id,
|
||||||
|
detail="parent cycle node not found"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
raise Exception("parent cycle node not found")
|
||||||
|
|
||||||
|
def _convert_node(self, node: dict[str, Any]) -> NodeDefinition | None:
|
||||||
|
node_data = node["data"]
|
||||||
|
try:
|
||||||
|
return NodeDefinition(
|
||||||
|
id=node["id"],
|
||||||
|
type=self.map_node_type(node_data["type"]),
|
||||||
|
name=node_data.get("title"),
|
||||||
|
cycle=node.get("parentId"),
|
||||||
|
description=None,
|
||||||
|
config=self._convert_node_config(node),
|
||||||
|
position={
|
||||||
|
"x": node["position"]["x"],
|
||||||
|
"y": node["position"]["y"]
|
||||||
|
} if node.get("parentId") is None else self._convert_cycle_node_position(
|
||||||
|
node["parentId"],
|
||||||
|
node["position"]
|
||||||
|
),
|
||||||
|
error_handling=None,
|
||||||
|
cache=None
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"convert node error - {e}", exc_info=True)
|
||||||
|
|
||||||
|
def _convert_node_config(self, node: dict):
|
||||||
|
node_data = node["data"]
|
||||||
|
node_type = node_data["type"]
|
||||||
|
try:
|
||||||
|
converter = self.get_node_convert(node_type)
|
||||||
|
if converter is None:
|
||||||
|
raise Exception(f"node type not supported - {node_type}")
|
||||||
|
return converter(node)
|
||||||
|
except Exception as e:
|
||||||
|
self.errors.append(ExceptionDefineition(
|
||||||
|
type=ExceptionType.NODE,
|
||||||
|
node_id=node["id"],
|
||||||
|
node_name=node["data"]["title"],
|
||||||
|
detail=f"convert node error - {e}",
|
||||||
|
))
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def _convert_edge(self, edge: dict[str, Any]) -> EdgeDefinition | None:
|
||||||
|
try:
|
||||||
|
|
||||||
|
source = edge["source"]
|
||||||
|
target = edge["target"]
|
||||||
|
edge_id = edge["id"]
|
||||||
|
label = None
|
||||||
|
if source in self.branch_node_cache:
|
||||||
|
case_id = "-".join(edge_id.split("-")[1:-2])
|
||||||
|
if case_id == "false":
|
||||||
|
label = f'CASE{len(self.branch_node_cache[source])+1}'
|
||||||
|
else:
|
||||||
|
label = f'CASE{self.branch_node_cache[source].index(case_id) + 1}'
|
||||||
|
if source in self.error_branch_node_cache:
|
||||||
|
case_id = "-".join(edge_id.split("-")[1:-2])
|
||||||
|
if case_id == "source":
|
||||||
|
label = "SUCCESS"
|
||||||
|
else:
|
||||||
|
label = "ERROR"
|
||||||
|
return EdgeDefinition(
|
||||||
|
id=edge["id"],
|
||||||
|
source=source,
|
||||||
|
target=target,
|
||||||
|
label=label,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
self.errors.append(ExceptionDefineition(
|
||||||
|
type=ExceptionType.EDGE,
|
||||||
|
detail=f"convert edge error - {e}",
|
||||||
|
))
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _convert_variable(self, variable) -> VariableDefinition | None:
|
||||||
|
try:
|
||||||
|
return VariableDefinition(
|
||||||
|
name=variable["name"],
|
||||||
|
default=variable["value"],
|
||||||
|
type=variable["value_type"],
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
self.errors.append(ExceptionDefineition(
|
||||||
|
type=ExceptionType.VARIABLE,
|
||||||
|
name=variable.get("name"),
|
||||||
|
detail=f"convert variable error - {e}",
|
||||||
|
))
|
||||||
|
|
||||||
|
def _convert_trigger(self, trigger: dict[str, Any]) -> TriggerConfig | None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _convert_execution(self, execution: dict[str, Any]) -> ExecutionConfig:
|
||||||
|
return ExecutionConfig()
|
||||||
|
|
||||||
|
|
||||||
75
api/app/core/workflow/adapters/errors.py
Normal file
75
api/app/core/workflow/adapters/errors.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
# -*- coding: UTF-8 -*-
|
||||||
|
# Author: Eternity
|
||||||
|
# @Email: 1533512157@qq.com
|
||||||
|
# @Time : 2026/2/26 11:29
|
||||||
|
from enum import StrEnum
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class ExceptionType(StrEnum):
|
||||||
|
NODE = "node"
|
||||||
|
EDGE = "edge"
|
||||||
|
VARIABLE = "variable"
|
||||||
|
TRIGGER = "trigger"
|
||||||
|
EXECUTION = "execution"
|
||||||
|
CONFIG = "config"
|
||||||
|
PLATFORM = "platform"
|
||||||
|
UNKNOWN = "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
class ExceptionDefineition(BaseModel):
|
||||||
|
type: ExceptionType
|
||||||
|
detail: str
|
||||||
|
|
||||||
|
node_id: str | None = None
|
||||||
|
node_name: str | None = None
|
||||||
|
|
||||||
|
scope: str | None = None
|
||||||
|
name: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class UnknowModelWarning(ExceptionDefineition):
|
||||||
|
type: ExceptionType = ExceptionType.NODE
|
||||||
|
|
||||||
|
def __init__(self, node_id, node_name, model_name):
|
||||||
|
super().__init__(
|
||||||
|
detail=f"Please specify the model mapping manually for model: {model_name}",
|
||||||
|
node_id=node_id,
|
||||||
|
node_name=node_name
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UnknowError(ExceptionDefineition):
|
||||||
|
type: ExceptionType = ExceptionType.UNKNOWN
|
||||||
|
|
||||||
|
def __init__(self, detail: str, **kwargs):
|
||||||
|
super().__init__(detail=detail, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class UnsupportPlatform(ExceptionDefineition):
|
||||||
|
type: ExceptionType = ExceptionType.PLATFORM
|
||||||
|
|
||||||
|
def __init__(self, platform: str):
|
||||||
|
super().__init__(detail=f"Unsupport platform {platform}")
|
||||||
|
|
||||||
|
|
||||||
|
class UnsupportVariableType(ExceptionDefineition):
|
||||||
|
type: ExceptionType = ExceptionType.VARIABLE
|
||||||
|
|
||||||
|
def __init__(self, scope, name, var_type: str, **kwargs):
|
||||||
|
super().__init__(scope=scope, name=name, detail=f"Unsupport variable type:[{var_type}]", **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidConfiguration(ExceptionDefineition):
|
||||||
|
type: ExceptionType = ExceptionType.CONFIG
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(detail="Invalid workflow configuration format")
|
||||||
|
|
||||||
|
|
||||||
|
class UnsupportNodeType(ExceptionDefineition):
|
||||||
|
type: ExceptionType = ExceptionType.NODE
|
||||||
|
|
||||||
|
def __init__(self, node_id: str, node_type: str):
|
||||||
|
super().__init__(node_id=node_id, detail=f"Unsupport node Type {node_type}")
|
||||||
4
api/app/core/workflow/adapters/memory_bear/__init__.py
Normal file
4
api/app/core/workflow/adapters/memory_bear/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
# -*- coding: UTF-8 -*-
|
||||||
|
# Author: Eternity
|
||||||
|
# @Email: 1533512157@qq.com
|
||||||
|
# @Time : 2026/2/26 11:30
|
||||||
@@ -0,0 +1,76 @@
|
|||||||
|
# -*- coding: UTF-8 -*-
|
||||||
|
# Author: Eternity
|
||||||
|
# @Email: 1533512157@qq.com
|
||||||
|
# @Time : 2026/2/25 14:11
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.core.workflow.adapters.base_adapter import (
|
||||||
|
PlatformMetadata,
|
||||||
|
PlatformType,
|
||||||
|
BasePlatformAdapter,
|
||||||
|
WorkflowParserResult
|
||||||
|
)
|
||||||
|
from app.schemas.workflow_schema import ExecutionConfig
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryBearAdapter(BasePlatformAdapter):
|
||||||
|
NODE_TYPE_MAPPING = {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def origin_nodes(self):
|
||||||
|
return self.config.get("workflow").get("nodes")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def origin_edges(self):
|
||||||
|
return self.config.get("workflow").get("edges")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def origin_variables(self):
|
||||||
|
return self.config.get("workflow").get("variables")
|
||||||
|
|
||||||
|
def get_metadata(self) -> PlatformMetadata:
|
||||||
|
return PlatformMetadata(
|
||||||
|
platform_name=PlatformType.MEMORY_BEAR,
|
||||||
|
version="0.2.5",
|
||||||
|
support_node_types=list(self.NODE_TYPE_MAPPING.keys())
|
||||||
|
)
|
||||||
|
|
||||||
|
def map_node_type(self, platform_node_type) -> str:
|
||||||
|
return platform_node_type
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _valid_nodes(node: dict[str, Any]):
|
||||||
|
if "type" not in node["data"]:
|
||||||
|
return False
|
||||||
|
if "id" not in node or "type" not in node:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def validate_config(self) -> bool:
|
||||||
|
require_fields = frozenset({'app', 'workflow'})
|
||||||
|
if not all(field in self.config for field in require_fields):
|
||||||
|
return False
|
||||||
|
|
||||||
|
for node in self.origin_nodes:
|
||||||
|
if not self._valid_nodes(node):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def parse_workflow(self) -> WorkflowParserResult:
|
||||||
|
self.nodes = self.origin_nodes
|
||||||
|
self.edges = self.origin_edges
|
||||||
|
self.conv_variables = self.origin_variables
|
||||||
|
|
||||||
|
return WorkflowParserResult(
|
||||||
|
success=True,
|
||||||
|
platform=self.get_metadata(),
|
||||||
|
execution_config=ExecutionConfig(),
|
||||||
|
origin_config=self.config,
|
||||||
|
trigger=None,
|
||||||
|
edges=self.edges,
|
||||||
|
nodes=self.nodes,
|
||||||
|
variables=self.conv_variables,
|
||||||
|
warnings=self.warnings,
|
||||||
|
errors=self.errors,
|
||||||
|
|
||||||
|
)
|
||||||
34
api/app/core/workflow/adapters/registry.py
Normal file
34
api/app/core/workflow/adapters/registry.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
# -*- coding: UTF-8 -*-
|
||||||
|
# Author: Eternity
|
||||||
|
# @Email: 1533512157@qq.com
|
||||||
|
# @Time : 2026/2/25 14:19
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.core.workflow.adapters import DifyAdapter, MemoryBearAdapter
|
||||||
|
from app.core.workflow.adapters.base_adapter import BasePlatformAdapter, PlatformType
|
||||||
|
|
||||||
|
|
||||||
|
class PlatformAdapterRegistry:
|
||||||
|
_adapters: dict[str, type[BasePlatformAdapter]] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register(cls, platform: str, adapter: type[BasePlatformAdapter]):
|
||||||
|
cls._adapters[platform] = adapter
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_adapter(cls, platform: str, config: dict[str, Any]) -> BasePlatformAdapter:
|
||||||
|
if platform not in cls._adapters:
|
||||||
|
raise ValueError(f"Unsupported platform: {platform}")
|
||||||
|
return cls._adapters.get(platform)(config)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def list_platforms(cls) -> list[str]:
|
||||||
|
return list(cls._adapters.keys())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_supported(cls, platform: str) -> bool:
|
||||||
|
return platform in cls._adapters
|
||||||
|
|
||||||
|
|
||||||
|
PlatformAdapterRegistry.register(PlatformType.MEMORY_BEAR, MemoryBearAdapter)
|
||||||
|
PlatformAdapterRegistry.register(PlatformType.DIFY, DifyAdapter)
|
||||||
@@ -13,7 +13,7 @@ from app.core.workflow.engine.variable_pool import VariablePool
|
|||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
SCOPE_PATTERN = re.compile(
|
SCOPE_PATTERN = re.compile(
|
||||||
r"\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\.[a-zA-Z0-9_]+\s*}}"
|
r"\{\{\s*([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\s*}}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -88,6 +88,8 @@ class AssignerNode(BaseNode):
|
|||||||
await operator.remove_first()
|
await operator.remove_first()
|
||||||
case AssignmentOperator.REMOVE_LAST:
|
case AssignmentOperator.REMOVE_LAST:
|
||||||
await operator.remove_last()
|
await operator.remove_last()
|
||||||
|
case AssignmentOperator.EXTEND:
|
||||||
|
await operator.extend()
|
||||||
case _:
|
case _:
|
||||||
raise ValueError(f"Invalid Operator: {assignment.operation}")
|
raise ValueError(f"Invalid Operator: {assignment.operation}")
|
||||||
logger.info(f"Node {self.node_id}: execution completed")
|
logger.info(f"Node {self.node_id}: execution completed")
|
||||||
|
|||||||
@@ -17,17 +17,17 @@ class EndNodeConfig(BaseNodeConfig):
|
|||||||
description="输出模板,支持引用前置节点的输出,如:{{ llm_qa.output }}"
|
description="输出模板,支持引用前置节点的输出,如:{{ llm_qa.output }}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 输出变量定义
|
# # 输出变量定义
|
||||||
output_variables: list[VariableDefinition] = Field(
|
# output_variables: list[VariableDefinition] = Field(
|
||||||
default_factory=lambda: [
|
# default_factory=lambda: [
|
||||||
VariableDefinition(
|
# VariableDefinition(
|
||||||
name="output",
|
# name="output",
|
||||||
type=VariableType.STRING,
|
# type=VariableType.STRING,
|
||||||
description="工作流的最终输出"
|
# description="工作流的最终输出"
|
||||||
)
|
# )
|
||||||
],
|
# ],
|
||||||
description="输出变量定义(自动生成,通常不需要修改)"
|
# description="输出变量定义(自动生成,通常不需要修改)"
|
||||||
)
|
# )
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
json_schema_extra = {
|
json_schema_extra = {
|
||||||
|
|||||||
@@ -61,6 +61,7 @@ class AssignmentOperator(StrEnum):
|
|||||||
APPEND = "append"
|
APPEND = "append"
|
||||||
REMOVE_LAST = "remove_last"
|
REMOVE_LAST = "remove_last"
|
||||||
REMOVE_FIRST = "remove_first"
|
REMOVE_FIRST = "remove_first"
|
||||||
|
EXTEND = "extend"
|
||||||
|
|
||||||
|
|
||||||
class HttpRequestMethod(StrEnum):
|
class HttpRequestMethod(StrEnum):
|
||||||
|
|||||||
@@ -236,5 +236,5 @@ class HttpRequestNode(BaseNode):
|
|||||||
logger.warning(
|
logger.warning(
|
||||||
f"Node {self.node_id}: HTTP request failed, switching to error handling branch"
|
f"Node {self.node_id}: HTTP request failed, switching to error handling branch"
|
||||||
)
|
)
|
||||||
return "ERROR"
|
return {"output": "ERROR"}
|
||||||
raise RuntimeError("http request failed")
|
raise RuntimeError("http request failed")
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ class KnowledgeRetrievalNodeConfig(BaseNodeConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
knowledge_bases: list[KnowledgeBaseConfig] = Field(
|
knowledge_bases: list[KnowledgeBaseConfig] = Field(
|
||||||
...,
|
default_factory=list,
|
||||||
description="Knowledge base config"
|
description="Knowledge base config"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
|
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
|
||||||
from app.core.workflow.variable.base_variable import VariableType
|
|
||||||
|
|
||||||
|
|
||||||
class StartNodeConfig(BaseNodeConfig):
|
class StartNodeConfig(BaseNodeConfig):
|
||||||
@@ -21,42 +20,42 @@ class StartNodeConfig(BaseNodeConfig):
|
|||||||
description="自定义输入变量列表,这些变量会作为 Start 节点的输出"
|
description="自定义输入变量列表,这些变量会作为 Start 节点的输出"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 输出变量定义
|
# # 输出变量定义
|
||||||
output_variables: list[VariableDefinition] = Field(
|
# output_variables: list[VariableDefinition] = Field(
|
||||||
default_factory=lambda: [
|
# default_factory=lambda: [
|
||||||
VariableDefinition(
|
# VariableDefinition(
|
||||||
name="message",
|
# name="message",
|
||||||
type=VariableType.STRING,
|
# type=VariableType.STRING,
|
||||||
description="用户输入的消息"
|
# description="用户输入的消息"
|
||||||
),
|
# ),
|
||||||
VariableDefinition(
|
# VariableDefinition(
|
||||||
name="conversation_vars",
|
# name="conversation_vars",
|
||||||
type=VariableType.OBJECT,
|
# type=VariableType.OBJECT,
|
||||||
description="会话级变量"
|
# description="会话级变量"
|
||||||
),
|
# ),
|
||||||
VariableDefinition(
|
# VariableDefinition(
|
||||||
name="execution_id",
|
# name="execution_id",
|
||||||
type=VariableType.STRING,
|
# type=VariableType.STRING,
|
||||||
description="执行 ID"
|
# description="执行 ID"
|
||||||
),
|
# ),
|
||||||
VariableDefinition(
|
# VariableDefinition(
|
||||||
name="conversation_id",
|
# name="conversation_id",
|
||||||
type=VariableType.STRING,
|
# type=VariableType.STRING,
|
||||||
description="会话 ID"
|
# description="会话 ID"
|
||||||
),
|
# ),
|
||||||
VariableDefinition(
|
# VariableDefinition(
|
||||||
name="workspace_id",
|
# name="workspace_id",
|
||||||
type=VariableType.STRING,
|
# type=VariableType.STRING,
|
||||||
description="工作空间 ID"
|
# description="工作空间 ID"
|
||||||
),
|
# ),
|
||||||
VariableDefinition(
|
# VariableDefinition(
|
||||||
name="user_id",
|
# name="user_id",
|
||||||
type=VariableType.STRING,
|
# type=VariableType.STRING,
|
||||||
description="用户 ID"
|
# description="用户 ID"
|
||||||
)
|
# )
|
||||||
],
|
# ],
|
||||||
description="输出变量定义(自动生成,通常不需要修改)"
|
# description="输出变量定义(自动生成,通常不需要修改)"
|
||||||
)
|
# )
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
json_schema_extra = {
|
json_schema_extra = {
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ from enum import Enum, StrEnum
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator
|
from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator
|
||||||
|
|
||||||
|
from app.schemas.workflow_schema import WorkflowConfigCreate
|
||||||
|
|
||||||
|
|
||||||
# ---------- Multimodal File Support ----------
|
# ---------- Multimodal File Support ----------
|
||||||
|
|
||||||
@@ -196,6 +198,8 @@ class AppCreate(BaseModel):
|
|||||||
# only for type=multi_agent
|
# only for type=multi_agent
|
||||||
multi_agent_config: Optional[Dict[str, Any]] = None
|
multi_agent_config: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
workflow_config: Optional[WorkflowConfigCreate] = None
|
||||||
|
|
||||||
|
|
||||||
class AppUpdate(BaseModel):
|
class AppUpdate(BaseModel):
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
|
|||||||
@@ -18,7 +18,10 @@ class NodeConfig(BaseModel):
|
|||||||
class NodeDefinition(BaseModel):
|
class NodeDefinition(BaseModel):
|
||||||
"""节点定义"""
|
"""节点定义"""
|
||||||
id: str = Field(..., description="节点唯一标识")
|
id: str = Field(..., description="节点唯一标识")
|
||||||
type: str = Field(..., description="节点类型: start, end, llm, agent, tool, condition, loop, transform, human, code")
|
type: str = Field(
|
||||||
|
...,
|
||||||
|
description="节点类型: start, end, llm, agent, tool, condition, loop, transform, human, code"
|
||||||
|
)
|
||||||
name: str | None = Field(None, description="节点名称")
|
name: str | None = Field(None, description="节点名称")
|
||||||
cycle: str | None = Field(None, description="父循环节点id")
|
cycle: str | None = Field(None, description="父循环节点id")
|
||||||
description: str | None = Field(None, description="节点描述")
|
description: str | None = Field(None, description="节点描述")
|
||||||
@@ -30,12 +33,12 @@ class NodeDefinition(BaseModel):
|
|||||||
|
|
||||||
class EdgeDefinition(BaseModel):
|
class EdgeDefinition(BaseModel):
|
||||||
"""边定义"""
|
"""边定义"""
|
||||||
id: str | None = Field(None, description="边唯一标识(可选)")
|
id: str | None = Field(default=None, description="边唯一标识(可选)")
|
||||||
source: str = Field(..., description="源节点 ID")
|
source: str = Field(..., description="源节点 ID")
|
||||||
target: str = Field(..., description="目标节点 ID")
|
target: str = Field(..., description="目标节点 ID")
|
||||||
type: str | None = Field(None, description="边类型: normal, error")
|
type: str | None = Field(default=None, description="边类型: normal, error")
|
||||||
condition: str | None = Field(None, description="条件表达式(条件边)")
|
condition: str | None = Field(default=None, description="条件表达式(条件边)")
|
||||||
label: str | None = Field(None, description="边标签")
|
label: str | None = Field(default=None, description="边标签")
|
||||||
|
|
||||||
|
|
||||||
class VariableDefinition(BaseModel):
|
class VariableDefinition(BaseModel):
|
||||||
@@ -44,7 +47,7 @@ class VariableDefinition(BaseModel):
|
|||||||
type: str = Field(default="string", description="变量类型: string, number, boolean, object, array")
|
type: str = Field(default="string", description="变量类型: string, number, boolean, object, array")
|
||||||
required: bool = Field(default=False, description="是否必填")
|
required: bool = Field(default=False, description="是否必填")
|
||||||
default: Any = Field(None, description="默认值")
|
default: Any = Field(None, description="默认值")
|
||||||
description: str | None = Field(None, description="变量描述")
|
description: str | None = Field(default=None, description="变量描述")
|
||||||
|
|
||||||
|
|
||||||
class ExecutionConfig(BaseModel):
|
class ExecutionConfig(BaseModel):
|
||||||
@@ -61,6 +64,13 @@ class TriggerConfig(BaseModel):
|
|||||||
config: dict[str, Any] = Field(default_factory=dict, description="触发器配置")
|
config: dict[str, Any] = Field(default_factory=dict, description="触发器配置")
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowImportSave(BaseModel):
|
||||||
|
"""工作流导入请求"""
|
||||||
|
temp_id: str
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
|
||||||
|
|
||||||
# ==================== 工作流配置 ====================
|
# ==================== 工作流配置 ====================
|
||||||
|
|
||||||
class WorkflowConfigCreate(BaseModel):
|
class WorkflowConfigCreate(BaseModel):
|
||||||
@@ -123,7 +133,8 @@ class WorkflowExecutionResponse(BaseModel):
|
|||||||
output_data: dict[str, Any] | None = Field(None, description="所有节点的详细输出数据")
|
output_data: dict[str, Any] | None = Field(None, description="所有节点的详细输出数据")
|
||||||
error_message: str | None = Field(None, description="错误信息")
|
error_message: str | None = Field(None, description="错误信息")
|
||||||
elapsed_time: float | None = Field(None, description="耗时(秒)")
|
elapsed_time: float | None = Field(None, description="耗时(秒)")
|
||||||
token_usage: dict[str, Any] | None = Field(None, description="Token 使用情况 {prompt_tokens, completion_tokens, total_tokens}")
|
token_usage: dict[str, Any] | None = Field(None,
|
||||||
|
description="Token 使用情况 {prompt_tokens, completion_tokens, total_tokens}")
|
||||||
|
|
||||||
|
|
||||||
class WorkflowExecutionStreamChunk(BaseModel):
|
class WorkflowExecutionStreamChunk(BaseModel):
|
||||||
|
|||||||
@@ -321,6 +321,26 @@ class AppService:
|
|||||||
self.db.add(agent_cfg)
|
self.db.add(agent_cfg)
|
||||||
logger.debug("Agent 配置已创建", extra={"app_id": str(app_id)})
|
logger.debug("Agent 配置已创建", extra={"app_id": str(app_id)})
|
||||||
|
|
||||||
|
def _create_workflow_config(
|
||||||
|
self,
|
||||||
|
app_id: uuid.UUID,
|
||||||
|
data: app_schema.WorkflowConfigCreate,
|
||||||
|
now: datetime.datetime
|
||||||
|
):
|
||||||
|
workflow_cfg = WorkflowConfig(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
app_id=app_id,
|
||||||
|
nodes=[node.model_dump() for node in data.nodes] if data.nodes else [],
|
||||||
|
edges=[edge.model_dump() for edge in data.edges] if data.edges else [],
|
||||||
|
variables=[var.model_dump() for var in data.variables] if data.variables else [],
|
||||||
|
execution_config=data.execution_config.model_dump() if data.execution_config else {},
|
||||||
|
triggers=[trigger.model_dump() for trigger in data.triggers] if data.triggers else [],
|
||||||
|
is_active=True,
|
||||||
|
created_at=now,
|
||||||
|
updated_at=now
|
||||||
|
)
|
||||||
|
self.db.add(workflow_cfg)
|
||||||
|
|
||||||
def _create_multi_agent_config(
|
def _create_multi_agent_config(
|
||||||
self,
|
self,
|
||||||
app_id: uuid.UUID,
|
app_id: uuid.UUID,
|
||||||
@@ -532,6 +552,9 @@ class AppService:
|
|||||||
if app.type == "multi_agent" and data.multi_agent_config:
|
if app.type == "multi_agent" and data.multi_agent_config:
|
||||||
self._create_multi_agent_config(app.id, data.multi_agent_config, now)
|
self._create_multi_agent_config(app.id, data.multi_agent_config, now)
|
||||||
|
|
||||||
|
if app.type == "workflow" and data.workflow_config:
|
||||||
|
self._create_workflow_config(app.id, data.workflow_config, now)
|
||||||
|
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
self.db.refresh(app)
|
self.db.refresh(app)
|
||||||
|
|
||||||
@@ -968,7 +991,7 @@ class AppService:
|
|||||||
config = self.db.scalars(stmt).first()
|
config = self.db.scalars(stmt).first()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
config_memory=config.memory
|
config_memory = config.memory
|
||||||
if 'memory_content' in config_memory:
|
if 'memory_content' in config_memory:
|
||||||
config.memory['memory_config_id'] = config.memory.pop('memory_content')
|
config.memory['memory_config_id'] = config.memory.pop('memory_content')
|
||||||
except:
|
except:
|
||||||
@@ -1189,9 +1212,9 @@ class AppService:
|
|||||||
# ==================== 记忆配置提取方法 ====================
|
# ==================== 记忆配置提取方法 ====================
|
||||||
|
|
||||||
def _extract_memory_config_id(
|
def _extract_memory_config_id(
|
||||||
self,
|
self,
|
||||||
app_type: str,
|
app_type: str,
|
||||||
config: Dict[str, Any]
|
config: Dict[str, Any]
|
||||||
) -> Tuple[Optional[uuid.UUID], bool]:
|
) -> Tuple[Optional[uuid.UUID], bool]:
|
||||||
"""从发布配置中提取 memory_config_id(委托给 MemoryConfigService)
|
"""从发布配置中提取 memory_config_id(委托给 MemoryConfigService)
|
||||||
|
|
||||||
@@ -1210,8 +1233,8 @@ class AppService:
|
|||||||
return service.extract_memory_config_id(app_type, config)
|
return service.extract_memory_config_id(app_type, config)
|
||||||
|
|
||||||
def _get_workspace_default_memory_config_id(
|
def _get_workspace_default_memory_config_id(
|
||||||
self,
|
self,
|
||||||
workspace_id: uuid.UUID
|
workspace_id: uuid.UUID
|
||||||
) -> Optional[uuid.UUID]:
|
) -> Optional[uuid.UUID]:
|
||||||
"""获取工作空间的默认记忆配置ID
|
"""获取工作空间的默认记忆配置ID
|
||||||
|
|
||||||
@@ -1235,9 +1258,9 @@ class AppService:
|
|||||||
return config.config_id
|
return config.config_id
|
||||||
|
|
||||||
def _update_endusers_memory_config(
|
def _update_endusers_memory_config(
|
||||||
self,
|
self,
|
||||||
app_id: uuid.UUID,
|
app_id: uuid.UUID,
|
||||||
memory_config_id: uuid.UUID
|
memory_config_id: uuid.UUID
|
||||||
) -> int:
|
) -> int:
|
||||||
"""批量更新应用下所有终端用户的 memory_config_id
|
"""批量更新应用下所有终端用户的 memory_config_id
|
||||||
|
|
||||||
|
|||||||
102
api/app/services/workflow_import_service.py
Normal file
102
api/app/services/workflow_import_service.py
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
# -*- coding: UTF-8 -*-
|
||||||
|
# Author: Eternity
|
||||||
|
# @Email: 1533512157@qq.com
|
||||||
|
# @Time : 2026/2/25 14:39
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.aioRedis import aio_redis_set, aio_redis_get
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.core.exceptions import BusinessException
|
||||||
|
from app.core.workflow.adapters.base_adapter import WorkflowImportResult, WorkflowParserResult
|
||||||
|
from app.core.workflow.adapters.errors import UnsupportPlatform, InvalidConfiguration
|
||||||
|
from app.core.workflow.adapters.registry import PlatformAdapterRegistry
|
||||||
|
from app.schemas import AppCreate
|
||||||
|
from app.schemas.workflow_schema import WorkflowConfigCreate
|
||||||
|
from app.services.app_service import AppService
|
||||||
|
from app.services.workflow_service import WorkflowService
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowImportService:
|
||||||
|
def __init__(self, db: Session):
|
||||||
|
self.db = db
|
||||||
|
self.registry = PlatformAdapterRegistry
|
||||||
|
self.cache_timeout = settings.WORKFLOW_IMPORT_CACHE_TIMEOUT
|
||||||
|
|
||||||
|
self.app_service = AppService(db)
|
||||||
|
self.workflow_service = WorkflowService(db)
|
||||||
|
|
||||||
|
async def flush_config(self, temp_id: str, config: WorkflowParserResult):
|
||||||
|
config_cache = await aio_redis_get(temp_id)
|
||||||
|
if not config_cache:
|
||||||
|
raise BusinessException("Workflow configuration has expired. Please re-upload it.")
|
||||||
|
await aio_redis_set(temp_id, config.model_dump_json(), expire=self.cache_timeout)
|
||||||
|
|
||||||
|
async def upload_config(
|
||||||
|
self,
|
||||||
|
platform: str,
|
||||||
|
config: dict[str, Any],
|
||||||
|
):
|
||||||
|
|
||||||
|
if not self.registry.is_supported(platform):
|
||||||
|
return WorkflowImportResult(
|
||||||
|
success=False,
|
||||||
|
temp_id=None,
|
||||||
|
workflow_id=None,
|
||||||
|
errors=[UnsupportPlatform(platform=platform)]
|
||||||
|
)
|
||||||
|
|
||||||
|
adapter = self.registry.get_adapter(platform, config)
|
||||||
|
|
||||||
|
if not adapter.validate_config():
|
||||||
|
return WorkflowImportResult(
|
||||||
|
success=False,
|
||||||
|
temp_id=None,
|
||||||
|
workflow_id=None,
|
||||||
|
errors=[InvalidConfiguration()]
|
||||||
|
)
|
||||||
|
|
||||||
|
workflow_config = adapter.parse_workflow()
|
||||||
|
temp_id = uuid.uuid4().hex
|
||||||
|
await aio_redis_set(temp_id, workflow_config.model_dump(), expire=self.cache_timeout)
|
||||||
|
return WorkflowImportResult(
|
||||||
|
success=True,
|
||||||
|
temp_id=temp_id,
|
||||||
|
workflow_id=None,
|
||||||
|
edges=workflow_config.edges,
|
||||||
|
nodes=workflow_config.nodes,
|
||||||
|
variables=workflow_config.variables,
|
||||||
|
warnings=workflow_config.warnings,
|
||||||
|
errors=workflow_config.errors
|
||||||
|
)
|
||||||
|
|
||||||
|
async def save_workflow(
|
||||||
|
self,
|
||||||
|
user_id: uuid.UUID,
|
||||||
|
workspace_id: uuid.UUID,
|
||||||
|
temp_id: str,
|
||||||
|
name: str,
|
||||||
|
description: str | None,
|
||||||
|
):
|
||||||
|
config = await aio_redis_get(temp_id)
|
||||||
|
if config is None:
|
||||||
|
raise BusinessException("Configuration import timed out. Please try again.")
|
||||||
|
config = json.loads(config)
|
||||||
|
app = self.app_service.create_app(
|
||||||
|
user_id=user_id,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
data=AppCreate(
|
||||||
|
name=name,
|
||||||
|
description=description,
|
||||||
|
type="workflow",
|
||||||
|
workflow_config=WorkflowConfigCreate(
|
||||||
|
nodes=config["nodes"],
|
||||||
|
edges=config["edges"],
|
||||||
|
variables=config["variables"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return app
|
||||||
@@ -6,13 +6,16 @@ import logging
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Annotated, Optional
|
from typing import Any, Annotated, Optional
|
||||||
|
|
||||||
|
import yaml
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
|
from app.core.workflow.adapters.registry import PlatformAdapterRegistry
|
||||||
from app.core.workflow.validator import validate_workflow_config
|
from app.core.workflow.validator import validate_workflow_config
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
|
from app.models import App
|
||||||
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
|
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
|
||||||
from app.repositories.workflow_repository import (
|
from app.repositories.workflow_repository import (
|
||||||
WorkflowConfigRepository,
|
WorkflowConfigRepository,
|
||||||
@@ -38,6 +41,8 @@ class WorkflowService:
|
|||||||
self.conversation_service = ConversationService(db)
|
self.conversation_service = ConversationService(db)
|
||||||
self.multimodal_service = MultimodalService(db)
|
self.multimodal_service = MultimodalService(db)
|
||||||
|
|
||||||
|
self.registry = PlatformAdapterRegistry
|
||||||
|
|
||||||
# ==================== 配置管理 ====================
|
# ==================== 配置管理 ====================
|
||||||
|
|
||||||
def create_workflow_config(
|
def create_workflow_config(
|
||||||
@@ -200,6 +205,32 @@ class WorkflowService:
|
|||||||
logger.info(f"删除工作流配置成功: app_id={app_id}, config_id={config.id}")
|
logger.info(f"删除工作流配置成功: app_id={app_id}, config_id={config.id}")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def export_workflow_dsl(self, app_id: uuid.UUID):
|
||||||
|
config = self.get_workflow_config(app_id)
|
||||||
|
if not config:
|
||||||
|
raise BusinessException(
|
||||||
|
code=BizCode.NOT_FOUND,
|
||||||
|
message=f"工作流配置不存在: app_id={app_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
app: App = config.app
|
||||||
|
dsl_info = {
|
||||||
|
"app": {
|
||||||
|
"name": app.name,
|
||||||
|
"description": app.description,
|
||||||
|
"icon": app.icon,
|
||||||
|
"icon_type": app.icon_type
|
||||||
|
},
|
||||||
|
"workflow": {
|
||||||
|
"variables": config.variables,
|
||||||
|
"edges": config.edges,
|
||||||
|
"nodes": config.nodes,
|
||||||
|
"execution_config": config.execution_config,
|
||||||
|
"triggers": config.triggers
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return yaml.dump(dsl_info, default_flow_style=False, allow_unicode=True)
|
||||||
|
|
||||||
def check_config(self, app_id: uuid.UUID) -> WorkflowConfig:
|
def check_config(self, app_id: uuid.UUID) -> WorkflowConfig:
|
||||||
"""检查工作流配置的完整性
|
"""检查工作流配置的完整性
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user