feat(workflow): add Dify workflow import adapter and related APIs
This commit is contained in:
@@ -10,7 +10,6 @@ from app.core.config import settings
|
||||
# 设置日志记录器
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# 创建连接池
|
||||
pool = ConnectionPool.from_url(
|
||||
f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}",
|
||||
@@ -21,6 +20,7 @@ pool = ConnectionPool.from_url(
|
||||
)
|
||||
aio_redis = redis.StrictRedis(connection_pool=pool)
|
||||
|
||||
|
||||
async def get_redis_connection():
|
||||
"""获取Redis连接"""
|
||||
try:
|
||||
@@ -29,7 +29,8 @@ async def get_redis_connection():
|
||||
logger.error(f"Redis连接失败: {str(e)}")
|
||||
return None
|
||||
|
||||
async def aio_redis_set(key: str, val: str|dict, expire: int = None):
|
||||
|
||||
async def aio_redis_set(key: str, val: str | dict, expire: int = None):
|
||||
"""设置Redis键值
|
||||
|
||||
Args:
|
||||
@@ -40,7 +41,7 @@ async def aio_redis_set(key: str, val: str|dict, expire: int = None):
|
||||
try:
|
||||
if isinstance(val, dict):
|
||||
val = json.dumps(val, ensure_ascii=False)
|
||||
|
||||
|
||||
if expire is not None:
|
||||
# 设置带过期时间的键值
|
||||
await aio_redis.set(key, val, ex=expire)
|
||||
@@ -50,6 +51,7 @@ async def aio_redis_set(key: str, val: str|dict, expire: int = None):
|
||||
except Exception as e:
|
||||
logger.error(f"Redis set错误: {str(e)}")
|
||||
|
||||
|
||||
async def aio_redis_get(key: str):
|
||||
"""获取Redis键值"""
|
||||
try:
|
||||
@@ -58,6 +60,7 @@ async def aio_redis_get(key: str):
|
||||
logger.error(f"Redis get错误: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
async def aio_redis_delete(key: str):
|
||||
"""删除Redis键"""
|
||||
try:
|
||||
@@ -66,6 +69,7 @@ async def aio_redis_delete(key: str):
|
||||
logger.error(f"Redis delete错误: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
async def aio_redis_publish(channel: str, message: Dict[str, Any]) -> bool:
|
||||
"""发布消息到Redis频道"""
|
||||
try:
|
||||
@@ -78,9 +82,10 @@ async def aio_redis_publish(channel: str, message: Dict[str, Any]) -> bool:
|
||||
logger.error(f"Redis发布错误: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
class RedisSubscriber:
|
||||
"""Redis订阅器"""
|
||||
|
||||
|
||||
def __init__(self, channel: str):
|
||||
self.channel = channel
|
||||
self.conn = None
|
||||
@@ -88,25 +93,25 @@ class RedisSubscriber:
|
||||
self.is_closed = False
|
||||
self._queue = asyncio.Queue()
|
||||
self._task = None
|
||||
|
||||
|
||||
async def start(self):
|
||||
"""开始订阅"""
|
||||
if self.is_closed or self._task:
|
||||
return
|
||||
|
||||
|
||||
self._task = asyncio.create_task(self._receive_messages())
|
||||
logger.info(f"开始订阅: {self.channel}")
|
||||
|
||||
|
||||
async def _receive_messages(self):
|
||||
"""接收消息"""
|
||||
try:
|
||||
self.conn = await get_redis_connection()
|
||||
if not self.conn:
|
||||
return
|
||||
|
||||
|
||||
self.pubsub = self.conn.pubsub()
|
||||
await self.pubsub.subscribe(self.channel)
|
||||
|
||||
|
||||
while not self.is_closed:
|
||||
try:
|
||||
message = await self.pubsub.get_message(ignore_subscribe_messages=True, timeout=0.01)
|
||||
@@ -127,7 +132,7 @@ class RedisSubscriber:
|
||||
finally:
|
||||
await self._queue.put(None)
|
||||
await self._cleanup()
|
||||
|
||||
|
||||
async def _cleanup(self):
|
||||
"""清理资源"""
|
||||
if self.pubsub:
|
||||
@@ -141,7 +146,7 @@ class RedisSubscriber:
|
||||
await self.conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
async def get_message(self) -> Optional[Dict[str, Any]]:
|
||||
"""获取消息"""
|
||||
if self.is_closed:
|
||||
@@ -153,7 +158,7 @@ class RedisSubscriber:
|
||||
except Exception as e:
|
||||
logger.error(f"获取消息错误: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
async def close(self):
|
||||
"""关闭订阅器"""
|
||||
if self.is_closed:
|
||||
@@ -163,32 +168,33 @@ class RedisSubscriber:
|
||||
self._task.cancel()
|
||||
await self._cleanup()
|
||||
|
||||
|
||||
class RedisPubSubManager:
|
||||
"""Redis发布订阅管理器"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.subscribers = {}
|
||||
|
||||
|
||||
async def publish(self, channel: str, message: Dict[str, Any]) -> bool:
|
||||
return await aio_redis_publish(channel, message)
|
||||
|
||||
|
||||
def get_subscriber(self, channel: str) -> RedisSubscriber:
|
||||
if channel in self.subscribers:
|
||||
subscriber = self.subscribers[channel]
|
||||
if not subscriber.is_closed:
|
||||
return subscriber
|
||||
|
||||
|
||||
subscriber = RedisSubscriber(channel)
|
||||
self.subscribers[channel] = subscriber
|
||||
return subscriber
|
||||
|
||||
|
||||
def cancel_subscription(self, channel: str) -> bool:
|
||||
if channel in self.subscribers:
|
||||
asyncio.create_task(self.subscribers[channel].close())
|
||||
del self.subscribers[channel]
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def cancel_all_subscriptions(self) -> int:
|
||||
count = len(self.subscribers)
|
||||
for subscriber in self.subscribers.values():
|
||||
@@ -196,6 +202,6 @@ class RedisPubSubManager:
|
||||
self.subscribers.clear()
|
||||
return count
|
||||
|
||||
|
||||
# 全局实例
|
||||
pubsub_manager = RedisPubSubManager()
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import uuid
|
||||
from typing import Optional, Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Path
|
||||
import yaml
|
||||
from fastapi import APIRouter, Depends, Path, Form, UploadFile, File
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -17,12 +18,13 @@ from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.schemas import app_schema
|
||||
from app.schemas.response_schema import PageData, PageMeta
|
||||
from app.schemas.workflow_schema import WorkflowConfig as WorkflowConfigSchema
|
||||
from app.schemas.workflow_schema import WorkflowConfigUpdate
|
||||
from app.schemas.workflow_schema import WorkflowConfigUpdate, WorkflowImportSave
|
||||
from app.services import app_service, workspace_service
|
||||
from app.services.agent_config_helper import enrich_agent_config
|
||||
from app.services.app_service import AppService
|
||||
from app.services.workflow_service import WorkflowService, get_workflow_service
|
||||
from app.services.app_statistics_service import AppStatisticsService
|
||||
from app.services.workflow_import_service import WorkflowImportService
|
||||
from app.services.workflow_service import WorkflowService, get_workflow_service
|
||||
|
||||
router = APIRouter(prefix="/apps", tags=["Apps"])
|
||||
logger = get_business_logger()
|
||||
@@ -65,7 +67,7 @@ def list_apps(
|
||||
|
||||
# 当 ids 存在且不为 None 时,根据 ids 获取应用
|
||||
if ids is not None:
|
||||
app_ids = [id.strip() for id in ids.split(',') if id.strip()]
|
||||
app_ids = [app_id.strip() for app_id in ids.split(',') if app_id.strip()]
|
||||
items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id)
|
||||
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
|
||||
return success(data=items)
|
||||
@@ -879,6 +881,60 @@ async def update_workflow_config(
|
||||
return success(data=WorkflowConfigSchema.model_validate(cfg))
|
||||
|
||||
|
||||
@router.get("/{app_id}/workflow/export")
|
||||
@cur_workspace_access_guard()
|
||||
async def export_workflow_config(
|
||||
app_id: uuid.UUID,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
):
|
||||
"""导出工作流配置为YAML文件"""
|
||||
workflow_service = WorkflowService(db)
|
||||
|
||||
return success(data={
|
||||
"content": workflow_service.export_workflow_dsl(app_id=app_id),
|
||||
})
|
||||
|
||||
|
||||
@router.post("/workflow/import")
|
||||
@cur_workspace_access_guard()
|
||||
async def import_workflow_config(
|
||||
file: UploadFile = File(...),
|
||||
platform: str = Form(...),
|
||||
app_id: str = Form(None),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
|
||||
):
|
||||
"""从YAML内容导入工作流配置"""
|
||||
if not file.filename.lower().endswith((".yaml", ".yml")):
|
||||
return fail(msg="Only yaml file is allowed", code=BizCode.BAD_REQUEST)
|
||||
|
||||
raw_text = (await file.read()).decode("utf-8")
|
||||
import_service = WorkflowImportService(db)
|
||||
config = yaml.safe_load(raw_text)
|
||||
result = await import_service.upload_config(platform, config)
|
||||
return success(data=result)
|
||||
|
||||
|
||||
@router.post("/workflow/import/save")
|
||||
@cur_workspace_access_guard()
|
||||
async def save_workflow_import(
|
||||
data: WorkflowImportSave,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
import_service = WorkflowImportService(db)
|
||||
app = await import_service.save_workflow(
|
||||
user_id=current_user.id,
|
||||
workspace_id=current_user.current_workspace_id,
|
||||
temp_id=data.temp_id,
|
||||
name=data.name,
|
||||
description=data.description,
|
||||
)
|
||||
return success(data=app_schema.App.model_validate(app))
|
||||
|
||||
|
||||
@router.get("/{app_id}/statistics", summary="应用统计数据")
|
||||
@cur_workspace_access_guard()
|
||||
def get_app_statistics(
|
||||
@@ -889,12 +945,14 @@ def get_app_statistics(
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""获取应用统计数据
|
||||
|
||||
|
||||
Args:
|
||||
app_id: 应用ID
|
||||
start_date: 开始时间戳(毫秒)
|
||||
end_date: 结束时间戳(毫秒)
|
||||
|
||||
db: 数据库连接
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
- daily_conversations: 每日会话数统计
|
||||
- total_conversations: 总会话数
|
||||
@@ -931,6 +989,8 @@ def get_workspace_api_statistics(
|
||||
Args:
|
||||
start_date: 开始时间戳(毫秒)
|
||||
end_date: 结束时间戳(毫秒)
|
||||
db: 数据库连接
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
每日统计数据列表,每项包含:
|
||||
|
||||
@@ -16,18 +16,18 @@ class Settings:
|
||||
# cloud: SaaS 云服务版(全功能,按量计费)
|
||||
# enterprise: 企业私有化版(License 控制)
|
||||
DEPLOYMENT_MODE: str = os.getenv("DEPLOYMENT_MODE", "community")
|
||||
|
||||
|
||||
# License 配置(企业版)
|
||||
LICENSE_FILE: str = os.getenv("LICENSE_FILE", "/etc/app/license.json")
|
||||
LICENSE_SERVER_URL: str = os.getenv("LICENSE_SERVER_URL", "https://license.yourcompany.com")
|
||||
|
||||
|
||||
# 计费服务配置(SaaS 版)
|
||||
BILLING_SERVICE_URL: str = os.getenv("BILLING_SERVICE_URL", "")
|
||||
|
||||
|
||||
# 基础 URL(用于 SSO 回调等)
|
||||
BASE_URL: str = os.getenv("BASE_URL", "http://localhost:8000")
|
||||
FRONTEND_URL: str = os.getenv("FRONTEND_URL", "http://localhost:3000")
|
||||
|
||||
|
||||
ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true"
|
||||
# API Keys Configuration
|
||||
OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")
|
||||
@@ -57,7 +57,6 @@ class Settings:
|
||||
REDIS_PORT: int = int(os.getenv("REDIS_PORT", "6379"))
|
||||
REDIS_DB: int = int(os.getenv("REDIS_DB", "1"))
|
||||
REDIS_PASSWORD: str = os.getenv("REDIS_PASSWORD", "")
|
||||
|
||||
|
||||
# ElasticSearch configuration
|
||||
ELASTICSEARCH_HOST: str = os.getenv("ELASTICSEARCH_HOST", "https://127.0.0.1")
|
||||
@@ -91,7 +90,7 @@ class Settings:
|
||||
|
||||
# Single Sign-On configuration
|
||||
ENABLE_SINGLE_SESSION: bool = os.getenv("ENABLE_SINGLE_SESSION", "false").lower() == "true"
|
||||
|
||||
|
||||
# SSO 免登配置
|
||||
SSO_TOKEN_EXPIRE_SECONDS: int = int(os.getenv("SSO_TOKEN_EXPIRE_SECONDS", "300"))
|
||||
SSO_TRUSTED_SOURCES_CONFIG: str = os.getenv("SSO_TRUSTED_SOURCES_CONFIG", "{}")
|
||||
@@ -130,7 +129,7 @@ class Settings:
|
||||
|
||||
# Server Configuration
|
||||
SERVER_IP: str = os.getenv("SERVER_IP", "127.0.0.1")
|
||||
FILE_LOCAL_SERVER_URL : str = os.getenv("FILE_LOCAL_SERVER_URL", "http://localhost:8000/api")
|
||||
FILE_LOCAL_SERVER_URL: str = os.getenv("FILE_LOCAL_SERVER_URL", "http://localhost:8000/api")
|
||||
|
||||
# ========================================================================
|
||||
# Internal Configuration (not in .env, used by application code)
|
||||
@@ -225,6 +224,7 @@ class Settings:
|
||||
LOAD_MODEL: bool = os.getenv("LOAD_MODEL", "false").lower() == "true"
|
||||
|
||||
# workflow config
|
||||
WORKFLOW_IMPORT_CACHE_TIMEOUT: int = int(os.getenv("WORKFLOW_IMPORT_CACHE_TIMEOUT", 1800))
|
||||
WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600))
|
||||
|
||||
# ========================================================================
|
||||
@@ -232,20 +232,20 @@ class Settings:
|
||||
# ========================================================================
|
||||
# 通用本体文件路径列表(逗号分隔)
|
||||
GENERAL_ONTOLOGY_FILES: str = os.getenv("GENERAL_ONTOLOGY_FILES", "General_purpose_entity.ttl")
|
||||
|
||||
|
||||
# 是否启用通用本体类型功能
|
||||
ENABLE_GENERAL_ONTOLOGY_TYPES: bool = os.getenv("ENABLE_GENERAL_ONTOLOGY_TYPES", "true").lower() == "true"
|
||||
|
||||
|
||||
# Prompt 中最大类型数量
|
||||
MAX_ONTOLOGY_TYPES_IN_PROMPT: int = int(os.getenv("MAX_ONTOLOGY_TYPES_IN_PROMPT", "50"))
|
||||
|
||||
|
||||
# 核心通用类型列表(逗号分隔)
|
||||
CORE_GENERAL_TYPES: str = os.getenv(
|
||||
"CORE_GENERAL_TYPES",
|
||||
"Person,Organization,Company,GovernmentAgency,Place,Location,City,Country,Building,"
|
||||
"Event,SportsEvent,SocialEvent,Work,Book,Film,Software,Concept,TopicalConcept,AcademicSubject"
|
||||
)
|
||||
|
||||
|
||||
# 实验模式开关(允许通过 API 动态切换本体配置)
|
||||
ONTOLOGY_EXPERIMENT_MODE: bool = os.getenv("ONTOLOGY_EXPERIMENT_MODE", "true").lower() == "true"
|
||||
|
||||
|
||||
8
api/app/core/workflow/adapters/__init__.py
Normal file
8
api/app/core/workflow/adapters/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/24 15:54
|
||||
from app.core.workflow.adapters.dify.dify_adapter import DifyAdapter
|
||||
from app.core.workflow.adapters.memory_bear.memory_bear_adapter import MemoryBearAdapter
|
||||
|
||||
__all__ = ["DifyAdapter", "MemoryBearAdapter"]
|
||||
88
api/app/core/workflow/adapters/base_adapter.py
Normal file
88
api/app/core/workflow/adapters/base_adapter.py
Normal file
@@ -0,0 +1,88 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/24 15:58
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.workflow.adapters.errors import ExceptionDefineition
|
||||
from app.schemas.workflow_schema import (
|
||||
EdgeDefinition,
|
||||
NodeDefinition,
|
||||
VariableDefinition,
|
||||
ExecutionConfig,
|
||||
TriggerConfig
|
||||
)
|
||||
|
||||
|
||||
class PlatformType(StrEnum):
|
||||
MEMORY_BEAR = "memory_bear"
|
||||
DIFY = "dify"
|
||||
COZE = "coze"
|
||||
|
||||
|
||||
class PlatformMetadata(BaseModel):
|
||||
platform_name: str
|
||||
version: str
|
||||
support_node_types: list[str]
|
||||
|
||||
|
||||
class WorkflowParserResult(BaseModel):
|
||||
success: bool
|
||||
platform: PlatformMetadata
|
||||
execution_config: ExecutionConfig
|
||||
origin_config: dict[str, Any]
|
||||
trigger: TriggerConfig | None
|
||||
edges: list[EdgeDefinition] = Field(default_factory=list)
|
||||
nodes: list[NodeDefinition] = Field(default_factory=list)
|
||||
variables: list[VariableDefinition] = Field(default_factory=list)
|
||||
warnings: list[ExceptionDefineition] = Field(default_factory=list)
|
||||
errors: list[ExceptionDefineition] = Field(default_factory=list)
|
||||
|
||||
|
||||
class WorkflowImportResult(BaseModel):
|
||||
success: bool
|
||||
temp_id: str | None = Field(..., description="cache id")
|
||||
workflow_id: str | None = Field(..., description="workflow id")
|
||||
edges: list[EdgeDefinition] = Field(default_factory=list)
|
||||
nodes: list[NodeDefinition] = Field(default_factory=list)
|
||||
variables: list[VariableDefinition] = Field(default_factory=list)
|
||||
warnings: list[ExceptionDefineition] = Field(default_factory=list)
|
||||
errors: list[ExceptionDefineition] = Field(default_factory=list)
|
||||
|
||||
|
||||
class BasePlatformAdapter(ABC):
|
||||
def __init__(self, config: dict[str, Any]):
|
||||
self.config = config
|
||||
self.nodes: list[NodeDefinition] = []
|
||||
self.edges: list[EdgeDefinition] = []
|
||||
self.conv_variables: list[VariableDefinition] = []
|
||||
|
||||
self.errors = []
|
||||
self.warnings = []
|
||||
|
||||
self.branch_node_cache = defaultdict(list)
|
||||
self.error_branch_node_cache = []
|
||||
|
||||
@abstractmethod
|
||||
def get_metadata(self) -> PlatformMetadata:
|
||||
"""get platform metadata"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def validate_config(self) -> bool:
|
||||
"""platform configuration validate"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def parse_workflow(self) -> WorkflowParserResult:
|
||||
"""parse platform configuration to local config"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def map_node_type(self, platform_node_type: str) -> str:
|
||||
pass
|
||||
75
api/app/core/workflow/adapters/base_converter.py
Normal file
75
api/app/core/workflow/adapters/base_converter.py
Normal file
@@ -0,0 +1,75 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/26 14:32
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from app.core.workflow.variable.base_variable import DEFAULT_VALUE, VariableType
|
||||
|
||||
|
||||
class BaseConverter(ABC):
|
||||
@staticmethod
|
||||
def _convert_string(var):
|
||||
try:
|
||||
return str(var)
|
||||
except:
|
||||
return DEFAULT_VALUE(VariableType.STRING)
|
||||
|
||||
@staticmethod
|
||||
def _convert_boolean(var):
|
||||
try:
|
||||
return bool(var)
|
||||
except:
|
||||
return DEFAULT_VALUE(VariableType.BOOLEAN)
|
||||
|
||||
@staticmethod
|
||||
def _convert_number(var):
|
||||
try:
|
||||
return float(var)
|
||||
except:
|
||||
return DEFAULT_VALUE(VariableType.NUMBER)
|
||||
|
||||
@staticmethod
|
||||
def _convert_object(var):
|
||||
try:
|
||||
return dict(var)
|
||||
except:
|
||||
return DEFAULT_VALUE(VariableType.OBJECT)
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def _convert_file(var):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _convert_array_string(var):
|
||||
try:
|
||||
return list(var)
|
||||
except:
|
||||
return DEFAULT_VALUE(VariableType.ARRAY_STRING)
|
||||
|
||||
@staticmethod
|
||||
def _convert_array_number(var):
|
||||
try:
|
||||
return list(var)
|
||||
except:
|
||||
return DEFAULT_VALUE(VariableType.ARRAY_NUMBER)
|
||||
|
||||
@staticmethod
|
||||
def _convert_array_boolean(var):
|
||||
try:
|
||||
return list(var)
|
||||
except:
|
||||
return DEFAULT_VALUE(VariableType.ARRAY_BOOLEAN)
|
||||
|
||||
@staticmethod
|
||||
def _convert_array_object(var):
|
||||
try:
|
||||
return list(var)
|
||||
except:
|
||||
return DEFAULT_VALUE(VariableType.ARRAY_OBJECT)
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def _convert_array_file(var):
|
||||
pass
|
||||
4
api/app/core/workflow/adapters/dify/__init__.py
Normal file
4
api/app/core/workflow/adapters/dify/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/25 18:20
|
||||
659
api/app/core/workflow/adapters/dify/converter.py
Normal file
659
api/app/core/workflow/adapters/dify/converter.py
Normal file
@@ -0,0 +1,659 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/25 18:21
|
||||
import base64
|
||||
import re
|
||||
from typing import Any
|
||||
from urllib.parse import quote
|
||||
|
||||
from app.core.workflow.adapters.base_converter import BaseConverter
|
||||
from app.core.workflow.adapters.errors import UnsupportVariableType, UnknowModelWarning, ExceptionDefineition, \
|
||||
ExceptionType
|
||||
from app.core.workflow.nodes.assigner import AssignerNodeConfig
|
||||
from app.core.workflow.nodes.assigner.config import AssignmentItem
|
||||
from app.core.workflow.nodes.base_config import VariableDefinition
|
||||
from app.core.workflow.nodes.code import CodeNodeConfig
|
||||
from app.core.workflow.nodes.code.config import InputVariable, OutputVariable
|
||||
from app.core.workflow.nodes.configs import StartNodeConfig, LLMNodeConfig
|
||||
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig, IterationNodeConfig
|
||||
from app.core.workflow.nodes.cycle_graph.config import ConditionDetail as LoopConditionDetail, ConditionsConfig, \
|
||||
CycleVariable
|
||||
from app.core.workflow.nodes.end import EndNodeConfig
|
||||
from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, AssignmentOperator, HttpAuthType, \
|
||||
HttpContentType, HttpErrorHandle
|
||||
from app.core.workflow.nodes.http_request import HttpRequestNodeConfig
|
||||
from app.core.workflow.nodes.http_request.config import HttpAuthConfig, HttpContentTypeConfig, HttpFormData, \
|
||||
HttpTimeOutConfig, HttpRetryConfig, HttpErrorDefaultTamplete, HttpErrorHandleConfig
|
||||
from app.core.workflow.nodes.if_else import IfElseNodeConfig
|
||||
from app.core.workflow.nodes.if_else.config import ConditionDetail, ConditionBranchConfig
|
||||
from app.core.workflow.nodes.jinja_render import JinjaRenderNodeConfig
|
||||
from app.core.workflow.nodes.jinja_render.config import VariablesMappingConfig
|
||||
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
|
||||
from app.core.workflow.nodes.llm.config import MemoryWindowSetting, MessageConfig
|
||||
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNodeConfig
|
||||
from app.core.workflow.nodes.parameter_extractor.config import ParamsConfig
|
||||
from app.core.workflow.nodes.question_classifier import QuestionClassifierNodeConfig
|
||||
from app.core.workflow.nodes.question_classifier.config import ClassifierConfig
|
||||
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||
|
||||
|
||||
class DifyConverter(BaseConverter):
|
||||
errors: list
|
||||
warnings: list
|
||||
branch_node_cache: dict
|
||||
error_branch_node_cache: list
|
||||
|
||||
def __init__(self):
|
||||
self.CONFIG_CONVERT_MAP = {
|
||||
"start": self.convert_start_node_config,
|
||||
"llm": self.convert_llm_node_config,
|
||||
"answer": self.convert_end_node_config,
|
||||
"if-else": self.convert_if_else_node_config,
|
||||
"loop": self.convert_loop_node_config,
|
||||
"iteration": self.convert_iteration_node_config,
|
||||
"assigner": self.convert_assigner_node_config,
|
||||
"code": self.convert_code_node_config,
|
||||
"http-request": self.convert_http_node_config,
|
||||
"template-transform": self.convert_jinja_render_node_config,
|
||||
"knowledge-retrieval": self.convert_knowledge_node_config,
|
||||
"parameter-extractor": self.convert_parameter_extractor_node_config,
|
||||
"question-classifier": self.convert_question_classifier_node_config,
|
||||
"variable-aggregator": self.convert_variable_aggregator,
|
||||
"loop-start": lambda x: {},
|
||||
"iteration-start": lambda x: {},
|
||||
"loop-end": lambda x: {},
|
||||
}
|
||||
|
||||
def get_node_convert(self, node_type):
|
||||
func = self.CONFIG_CONVERT_MAP.get(node_type, None)
|
||||
return func
|
||||
|
||||
@staticmethod
|
||||
def is_variable(expression) -> bool:
|
||||
return bool(re.match(r"\{\{#(.*?)#}}", expression))
|
||||
|
||||
@staticmethod
|
||||
def process_var_selector(var_selector):
|
||||
if not var_selector:
|
||||
return ""
|
||||
selector = var_selector.split('.')
|
||||
if len(selector) != 2:
|
||||
raise Exception(f"invalid variable selector: {var_selector}")
|
||||
if selector[0] == "conversation":
|
||||
selector[0] = "conv"
|
||||
var_selector = ".".join(selector)
|
||||
mapping = {
|
||||
"sys.query": "sys.message"
|
||||
}
|
||||
|
||||
var_selector = mapping.get(var_selector, var_selector)
|
||||
return var_selector
|
||||
|
||||
def _process_list_variable_litearl(self, variable_selector: list) -> str | None:
|
||||
if not self.process_var_selector(".".join(variable_selector)):
|
||||
return None
|
||||
return "{{" + self.process_var_selector(".".join(variable_selector)) + "}}"
|
||||
|
||||
def trans_variable_format(self, content):
|
||||
pattern = re.compile(r"\{\{#(.*?)#}}")
|
||||
|
||||
def replacer(match: re.Match) -> str:
|
||||
raw_name = match.group(1)
|
||||
new_name = self.process_var_selector(raw_name)
|
||||
return f"{{{{{new_name}}}}}"
|
||||
|
||||
return pattern.sub(replacer, content)
|
||||
|
||||
@staticmethod
|
||||
def _convert_file(var):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _convert_array_file(var):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def variable_type_map(source_type) -> VariableType | None:
|
||||
type_map = {
|
||||
"file": VariableType.FILE,
|
||||
"paragraph": VariableType.STRING,
|
||||
"text-input": VariableType.STRING,
|
||||
"number": VariableType.NUMBER,
|
||||
"checkbox": VariableType.BOOLEAN,
|
||||
"file-list": VariableType.ARRAY_FILE,
|
||||
"select": VariableType.STRING,
|
||||
}
|
||||
var_type = type_map.get(source_type, source_type)
|
||||
return var_type
|
||||
|
||||
def convert_variable_type(self, target_type: VariableType, origin_value: Any):
|
||||
if not origin_value:
|
||||
return DEFAULT_VALUE(target_type)
|
||||
try:
|
||||
match target_type:
|
||||
case VariableType.STRING:
|
||||
return self._convert_string(origin_value)
|
||||
case VariableType.NUMBER:
|
||||
return self._convert_number(origin_value)
|
||||
case VariableType.BOOLEAN:
|
||||
return self._convert_boolean(origin_value)
|
||||
case VariableType.FILE:
|
||||
return self._convert_file(origin_value)
|
||||
case VariableType.ARRAY_FILE:
|
||||
return self._convert_array_file(origin_value)
|
||||
case _:
|
||||
return origin_value
|
||||
except:
|
||||
raise Exception(f"convert variable failed: {target_type}")
|
||||
|
||||
@staticmethod
|
||||
def convert_compare_operator(operator):
|
||||
operator_map = {
|
||||
"is": ComparisonOperator.EQ,
|
||||
"is not": ComparisonOperator.NE,
|
||||
"=": ComparisonOperator.EQ,
|
||||
"≠": ComparisonOperator.NE,
|
||||
">": ComparisonOperator.GT,
|
||||
"<": ComparisonOperator.LT,
|
||||
"≥": ComparisonOperator.GE,
|
||||
"≤": ComparisonOperator.LE,
|
||||
"not empty": ComparisonOperator.NOT_EMPTY,
|
||||
}
|
||||
return operator_map.get(operator, operator)
|
||||
|
||||
@staticmethod
|
||||
def convert_assignment_operator(operator):
|
||||
operator_map = {
|
||||
"+=": AssignmentOperator.ADD,
|
||||
"-=": AssignmentOperator.SUBTRACT,
|
||||
"*=": AssignmentOperator.MULTIPLY,
|
||||
"/=": AssignmentOperator.DIVIDE,
|
||||
"over-write": AssignmentOperator.COVER,
|
||||
"remove-last": AssignmentOperator.REMOVE_LAST,
|
||||
"remove-first": AssignmentOperator.REMOVE_FIRST,
|
||||
|
||||
}
|
||||
return operator_map.get(operator, operator)
|
||||
|
||||
@staticmethod
|
||||
def convert_http_auth_type(auth_type):
|
||||
auth_type_map = {
|
||||
"no-auth": HttpAuthType.NONE,
|
||||
"bearer": HttpAuthType.BEARER,
|
||||
"basic": HttpAuthType.BASIC,
|
||||
"custom": HttpAuthType.CUSTOM,
|
||||
}
|
||||
return auth_type_map.get(auth_type, auth_type)
|
||||
|
||||
@staticmethod
|
||||
def convert_http_content_type(content_type):
|
||||
content_type_map = {
|
||||
"none": HttpContentType.NONE,
|
||||
"form-data": HttpContentType.FROM_DATA,
|
||||
"x-www-form-urlencoded": HttpContentType.WWW_FORM,
|
||||
"json": HttpContentType.JSON,
|
||||
"raw-text": HttpContentType.RAW,
|
||||
"binary": HttpContentType.BINARY,
|
||||
}
|
||||
return content_type_map.get(content_type, content_type)
|
||||
|
||||
@staticmethod
|
||||
def convert_http_error_handle_type(handle_type):
|
||||
handle_type_map = {
|
||||
"none": HttpErrorHandle.NONE,
|
||||
"fail-branch": HttpErrorHandle.BRANCH,
|
||||
"default-value": HttpErrorHandle.DEFAULT,
|
||||
}
|
||||
return handle_type_map.get(handle_type, handle_type)
|
||||
|
||||
def convert_start_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
start_vars = []
|
||||
for var in node_data["variables"]:
|
||||
var_type = self.variable_type_map(var["type"])
|
||||
if not var_type:
|
||||
self.errors.append(
|
||||
UnsupportVariableType(
|
||||
scope=node["id"],
|
||||
name=var["variable"],
|
||||
var_type=var["type"],
|
||||
node_id=node["id"],
|
||||
node_name=node_data["title"]
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if var_type in ["file", "array[file]"]:
|
||||
self.errors.append(
|
||||
ExceptionDefineition(
|
||||
type=ExceptionType.VARIABLE,
|
||||
node_id=node["id"],
|
||||
node_name=node_data["title"],
|
||||
name=var["variable"],
|
||||
detail=f"Unsupport Variable type for start node: {var_type}"
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
var_def = VariableDefinition(
|
||||
name=var["variable"],
|
||||
type=var_type,
|
||||
required=var["required"],
|
||||
default=self.convert_variable_type(
|
||||
var_type, var["default"]
|
||||
),
|
||||
description=var["label"],
|
||||
max_length=var.get("max_length"),
|
||||
)
|
||||
start_vars.append(var_def)
|
||||
return StartNodeConfig(
|
||||
variables=start_vars
|
||||
).model_dump()
|
||||
|
||||
def convert_question_classifier_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
self.warnings.append(
|
||||
UnknowModelWarning(
|
||||
node_id=node["id"],
|
||||
node_name=node_data["title"],
|
||||
model_name=node_data["model"].get("name")
|
||||
)
|
||||
)
|
||||
categories = []
|
||||
for category in node_data["classes"]:
|
||||
self.branch_node_cache[node["id"]].append(category["id"])
|
||||
categories.append(
|
||||
ClassifierConfig(
|
||||
class_name=category["name"],
|
||||
)
|
||||
)
|
||||
|
||||
return QuestionClassifierNodeConfig.model_construct(
|
||||
input_variable=self._process_list_variable_litearl(node_data["query_variable_selector"]),
|
||||
user_supplement_prompt=self.trans_variable_format(node_data["instructions"]),
|
||||
categories=categories,
|
||||
).model_dump()
|
||||
|
||||
def convert_llm_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
self.warnings.append(
|
||||
UnknowModelWarning(
|
||||
node_id=node["id"],
|
||||
node_name=node_data["title"],
|
||||
model_name=node_data["model"].get("name")
|
||||
)
|
||||
)
|
||||
context = self._process_list_variable_litearl(node_data["context"]["variable_selector"])
|
||||
memory = MemoryWindowSetting(
|
||||
enable=bool(node_data.get("memory")),
|
||||
enable_window=bool(node_data.get("memory", {}).get("window", {}).get("enabled", False)),
|
||||
window_size=int(node_data.get("memory", {}).get("window", {}).get("size", 20))
|
||||
)
|
||||
messages = []
|
||||
for message in node_data["prompt_template"]:
|
||||
messages.append(
|
||||
MessageConfig(
|
||||
role=message["role"],
|
||||
content=self.trans_variable_format(message["text"])
|
||||
)
|
||||
)
|
||||
if memory.enable:
|
||||
messages.append(
|
||||
MessageConfig(
|
||||
role="user",
|
||||
content=self.trans_variable_format(node_data["memory"]["query_prompt_template"])
|
||||
)
|
||||
)
|
||||
vision = node_data["vision"]["enabled"]
|
||||
vision_input = self._process_list_variable_litearl(
|
||||
node_data["vision"]["configs"]["variable_selector"]
|
||||
) if vision else None
|
||||
return LLMNodeConfig.model_construct(
|
||||
model_id=None,
|
||||
context=context,
|
||||
memory=memory,
|
||||
vision=vision,
|
||||
vision_input=vision_input,
|
||||
messages=messages
|
||||
).model_dump()
|
||||
|
||||
def convert_end_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
return EndNodeConfig(
|
||||
output=self.trans_variable_format(node_data["answer"]),
|
||||
).model_dump()
|
||||
|
||||
def convert_if_else_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
cases = []
|
||||
for case in node_data["cases"]:
|
||||
case_id = case["id"]
|
||||
logical_operator = case["logical_operator"]
|
||||
conditions = []
|
||||
for condition in case["conditions"]:
|
||||
right_value = condition["value"]
|
||||
condition_detail = ConditionDetail(
|
||||
operator=self.convert_compare_operator(condition["comparison_operator"]),
|
||||
left="{{" + self.process_var_selector(".".join(condition["variable_selector"])) + "}}",
|
||||
right=self.trans_variable_format(
|
||||
right_value
|
||||
) if isinstance(right_value, str) and self.is_variable(right_value) else self.convert_variable_type(
|
||||
self.variable_type_map(condition["varType"]),
|
||||
condition["value"]
|
||||
),
|
||||
input_type=ValueInputType.VARIABLE
|
||||
if isinstance(right_value, str) and self.is_variable(right_value) else ValueInputType.CONSTANT,
|
||||
)
|
||||
conditions.append(condition_detail)
|
||||
cases.append(
|
||||
ConditionBranchConfig(
|
||||
logical_operator=logical_operator,
|
||||
expressions=conditions
|
||||
)
|
||||
)
|
||||
self.branch_node_cache[node["id"]].append(case_id)
|
||||
return IfElseNodeConfig(
|
||||
cases=cases
|
||||
).model_dump()
|
||||
|
||||
def convert_loop_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
logical_operator = node_data["logical_operator"]
|
||||
conditions = []
|
||||
for condition in node_data["break_conditions"]:
|
||||
right_value = condition["value"]
|
||||
conditions.append(
|
||||
LoopConditionDetail(
|
||||
operator=self.convert_compare_operator(condition["comparison_operator"]),
|
||||
left=self._process_list_variable_litearl(condition["variable_selector"]),
|
||||
right=self.trans_variable_format(
|
||||
right_value
|
||||
) if isinstance(right_value, str) and self.is_variable(right_value) else self.convert_variable_type(
|
||||
self.variable_type_map(condition["varType"]),
|
||||
condition["value"]
|
||||
),
|
||||
input_type=ValueInputType.VARIABLE
|
||||
if isinstance(right_value, str) and self.is_variable(right_value) else ValueInputType.CONSTANT,
|
||||
)
|
||||
)
|
||||
condition_config = ConditionsConfig(
|
||||
logical_operator=logical_operator,
|
||||
expressions=conditions
|
||||
)
|
||||
loop_variables = []
|
||||
for variable in node_data["loop_variables"]:
|
||||
right_input_type = variable["value_type"]
|
||||
right_value_type = self.variable_type_map(variable["var_type"])
|
||||
if right_input_type == ValueInputType.VARIABLE:
|
||||
right_value = self._process_list_variable_litearl(variable["value"])
|
||||
else:
|
||||
right_value = self.convert_variable_type(right_value_type, variable["value"])
|
||||
loop_variables.append(
|
||||
CycleVariable(
|
||||
name=variable["label"],
|
||||
type=right_value_type,
|
||||
value=right_value,
|
||||
input_type=right_input_type
|
||||
)
|
||||
)
|
||||
return LoopNodeConfig(
|
||||
condition=condition_config,
|
||||
cycle_vars=loop_variables,
|
||||
max_loop=node_data["loop_count"]
|
||||
).model_dump()
|
||||
|
||||
def convert_iteration_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
return IterationNodeConfig(
|
||||
input=self._process_list_variable_litearl(node_data["iterator_selector"]),
|
||||
parallel=node_data["is_parallel"],
|
||||
parallel_count=node_data["parallel_nums"],
|
||||
output=self._process_list_variable_litearl(node_data["output_selector"]),
|
||||
output_type=self.variable_type_map(node_data["output_type"]),
|
||||
flatten=node_data["flatten_output"],
|
||||
).model_dump()
|
||||
|
||||
def convert_assigner_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
assignments = []
|
||||
for assignment in node_data["items"]:
|
||||
if assignment.get("operation") is None or assignment.get("value") is None:
|
||||
continue
|
||||
assignments.append(
|
||||
AssignmentItem(
|
||||
variable_selector=self._process_list_variable_litearl(assignment["variable_selector"]),
|
||||
value=self._process_list_variable_litearl(
|
||||
assignment["value"]
|
||||
) if assignment["input_type"] == ValueInputType.VARIABLE else assignment["value"],
|
||||
operation=self.convert_assignment_operator(assignment["operation"])
|
||||
)
|
||||
)
|
||||
return AssignerNodeConfig(
|
||||
assignments=assignments
|
||||
).model_dump()
|
||||
|
||||
def convert_code_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
input_variables = []
|
||||
for input_variable in node_data["variables"]:
|
||||
input_variables.append(
|
||||
InputVariable(
|
||||
name=input_variable["variable"],
|
||||
variable=self._process_list_variable_litearl(input_variable["value_selector"]),
|
||||
)
|
||||
)
|
||||
|
||||
output_variables = []
|
||||
for output_variable in node_data["outputs"]:
|
||||
output_variables.append(
|
||||
OutputVariable(
|
||||
name=output_variable,
|
||||
type=node_data["outputs"][output_variable]["type"],
|
||||
)
|
||||
)
|
||||
|
||||
code = base64.b64encode(quote(node_data["code"]).encode("utf-8")).decode("utf-8")
|
||||
|
||||
return CodeNodeConfig(
|
||||
input_variables=input_variables,
|
||||
language=node_data["code_language"],
|
||||
output_variables=output_variables,
|
||||
code=code
|
||||
).model_dump()
|
||||
|
||||
def convert_http_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
if node_data["authorization"] != 'no-auth':
|
||||
auth_type = self.convert_http_auth_type(node_data["authorization"]["config"]["type"])
|
||||
auth_config = HttpAuthConfig(
|
||||
auth_type=auth_type,
|
||||
header=node_data["authorization"]["config"].get("header"),
|
||||
api_key=node_data["authorization"]["config"].get("api_key"),
|
||||
)
|
||||
else:
|
||||
auth_config = HttpAuthConfig()
|
||||
|
||||
content_type = self.convert_http_content_type(node_data["body"]["type"])
|
||||
if content_type == HttpContentType.FROM_DATA:
|
||||
body_content = []
|
||||
for content in node_data["body"]["data"]:
|
||||
body_content.append(
|
||||
HttpFormData(
|
||||
key=self.trans_variable_format(content["key"]),
|
||||
type=content["type"],
|
||||
value=self.trans_variable_format(content["value"]),
|
||||
)
|
||||
)
|
||||
elif content_type == HttpContentType.WWW_FORM:
|
||||
body_content = {}
|
||||
for content in node_data["body"]["data"]:
|
||||
body_content[
|
||||
self.trans_variable_format(content["key"])
|
||||
] = self.trans_variable_format(content["value"])
|
||||
else:
|
||||
if node_data["body"]["data"]:
|
||||
body_content = node_data["body"]["data"][0]["value"]
|
||||
else:
|
||||
body_content = ""
|
||||
|
||||
headers = {}
|
||||
for header in node_data["headers"].split("\n"):
|
||||
if not header:
|
||||
continue
|
||||
|
||||
key_value = header.split(":")
|
||||
if len(key_value) == 2:
|
||||
headers[
|
||||
self.trans_variable_format(key_value[0])
|
||||
] = self.trans_variable_format(key_value[1])
|
||||
else:
|
||||
self.warnings.append(ExceptionDefineition(
|
||||
type=ExceptionType.CONFIG,
|
||||
node_id=node["id"],
|
||||
node_name=node_data["title"],
|
||||
detail=f"Invalid header/param - {header}",
|
||||
))
|
||||
|
||||
params = {}
|
||||
for param in node_data["params"].split("\n"):
|
||||
if not param:
|
||||
continue
|
||||
|
||||
key_value = param.split(":")
|
||||
if len(key_value) == 2:
|
||||
params[
|
||||
self.trans_variable_format(key_value[0])
|
||||
] = self.trans_variable_format(key_value[1])
|
||||
else:
|
||||
self.warnings.append(ExceptionDefineition(
|
||||
type=ExceptionType.CONFIG,
|
||||
node_id=node["id"],
|
||||
node_name=node_data["title"],
|
||||
detail=f"Invalid header/param - {param}",
|
||||
))
|
||||
|
||||
error_handle_type = self.convert_http_error_handle_type(
|
||||
node_data.get("error_strategy", "none")
|
||||
)
|
||||
default_value = None
|
||||
if error_handle_type == HttpErrorHandle.DEFAULT:
|
||||
default_body = ""
|
||||
default_header = {}
|
||||
default_status_code = 0
|
||||
for var in node_data["default_value"]:
|
||||
if var["key"] == "body":
|
||||
default_body = var["value"]
|
||||
elif var["key"] == "header":
|
||||
default_header = var["value"]
|
||||
elif var["key"] == "status_code":
|
||||
default_status_code = var["value"]
|
||||
default_value = HttpErrorDefaultTamplete(
|
||||
body=default_body,
|
||||
headers=default_header,
|
||||
status_code=default_status_code,
|
||||
)
|
||||
|
||||
self.error_branch_node_cache.append(node['id'])
|
||||
return HttpRequestNodeConfig(
|
||||
method=node_data["method"].upper(),
|
||||
url=node_data["url"],
|
||||
auth=auth_config,
|
||||
body=HttpContentTypeConfig(
|
||||
content_type=self.convert_http_content_type(node_data["body"]["type"]),
|
||||
data=body_content,
|
||||
),
|
||||
headers=headers,
|
||||
params=params,
|
||||
verify_ssl=node_data["ssl_verify"],
|
||||
timeouts=HttpTimeOutConfig(
|
||||
connect_timeout=node_data["timeout"]["max_connect_timeout"] or 5,
|
||||
read_timeout=node_data["timeout"]["max_read_timeout"] or 5,
|
||||
write_timeout=node_data["timeout"]["max_write_timeout"] or 5,
|
||||
),
|
||||
retry=HttpRetryConfig(
|
||||
enable=node_data["retry_config"]["retry_enabled"],
|
||||
max_attempts=node_data["retry_config"]["max_retries"],
|
||||
retry_interval=node_data["retry_config"]["retry_interval"],
|
||||
),
|
||||
error_handle=HttpErrorHandleConfig(
|
||||
method=error_handle_type,
|
||||
default=default_value,
|
||||
)
|
||||
).model_dump()
|
||||
|
||||
def convert_jinja_render_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
mapping = []
|
||||
for variable in node_data["variables"]:
|
||||
mapping.append(VariablesMappingConfig(
|
||||
name=variable["variable"],
|
||||
value=self._process_list_variable_litearl(variable["value_selector"])
|
||||
))
|
||||
return JinjaRenderNodeConfig(
|
||||
template=node_data["template"],
|
||||
mapping=mapping,
|
||||
).model_dump()
|
||||
|
||||
def convert_knowledge_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
self.warnings.append(ExceptionDefineition(
|
||||
node_id=node["id"],
|
||||
node_name=node_data["title"],
|
||||
type=ExceptionType.CONFIG,
|
||||
detail=f"Please reconfigure the Knowledge Retrieval node.",
|
||||
))
|
||||
return KnowledgeRetrievalNodeConfig.model_construct(
|
||||
query=self._process_list_variable_litearl(node_data["query_variable_selector"]),
|
||||
).model_dump()
|
||||
|
||||
def convert_parameter_extractor_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
self.warnings.append(
|
||||
UnknowModelWarning(
|
||||
node_id=node["id"],
|
||||
node_name=node_data["title"],
|
||||
model_name=node_data["model"].get("name")
|
||||
)
|
||||
)
|
||||
params = []
|
||||
for param in node_data["parameters"]:
|
||||
params.append(
|
||||
ParamsConfig(
|
||||
name=param["name"],
|
||||
desc=param["description"],
|
||||
required=param["required"],
|
||||
type=param["type"],
|
||||
)
|
||||
)
|
||||
return ParameterExtractorNodeConfig.model_construct(
|
||||
text=self._process_list_variable_litearl(node_data["query"]),
|
||||
params=params,
|
||||
prompt=node_data["instruction"]
|
||||
).model_dump()
|
||||
|
||||
def convert_variable_aggregator(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
group_enable = node_data["advanced_settings"]["group_enabled"]
|
||||
group_variables = {}
|
||||
group_type = {}
|
||||
if not group_enable:
|
||||
group_variables["output"] = [
|
||||
self._process_list_variable_litearl(variable)
|
||||
for variable in node_data["variables"]
|
||||
]
|
||||
group_type["output"] = node_data["output_type"]
|
||||
else:
|
||||
for group in node_data["advanced_settings"]["groups"]:
|
||||
group_variables[group["group_name"]] = [
|
||||
self._process_list_variable_litearl(variable)
|
||||
for variable in group["variables"]
|
||||
]
|
||||
group_type[group["group_name"]] = group["output_type"]
|
||||
|
||||
return VariableAggregatorNodeConfig(
|
||||
group=group_enable,
|
||||
group_variables=group_variables,
|
||||
group_type=group_type,
|
||||
).model_dump()
|
||||
239
api/app/core/workflow/adapters/dify/dify_adapter.py
Normal file
239
api/app/core/workflow/adapters/dify/dify_adapter.py
Normal file
@@ -0,0 +1,239 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/24 16:05
|
||||
from typing import Any
|
||||
|
||||
from app.core.logging_config import get_logger
|
||||
from app.core.workflow.adapters.base_adapter import (
|
||||
BasePlatformAdapter,
|
||||
PlatformMetadata,
|
||||
PlatformType,
|
||||
WorkflowParserResult
|
||||
)
|
||||
from app.core.workflow.adapters.dify.converter import DifyConverter
|
||||
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
from app.schemas.workflow_schema import (
|
||||
NodeDefinition,
|
||||
EdgeDefinition,
|
||||
VariableDefinition,
|
||||
TriggerConfig,
|
||||
ExecutionConfig
|
||||
)
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
NODE_TYPE_MAPPING = {
|
||||
"start": NodeType.START,
|
||||
"llm": NodeType.LLM,
|
||||
"answer": NodeType.END,
|
||||
"if-else": NodeType.IF_ELSE,
|
||||
"loop-start": NodeType.CYCLE_START,
|
||||
"iteration-start": NodeType.CYCLE_START,
|
||||
"assigner": NodeType.ASSIGNER,
|
||||
"loop": NodeType.LOOP,
|
||||
"iteration": NodeType.ITERATION,
|
||||
"loop-end": NodeType.BREAK,
|
||||
"code": NodeType.CODE,
|
||||
"http-request": NodeType.HTTP_REQUEST,
|
||||
"template-transform": NodeType.JINJARENDER,
|
||||
"knowledge-retrieval": NodeType.KNOWLEDGE_RETRIEVAL,
|
||||
"parameter-extractor": NodeType.PARAMETER_EXTRACTOR,
|
||||
"question-classifier": NodeType.QUESTION_CLASSIFIER,
|
||||
"variable-aggregator": NodeType.VAR_AGGREGATOR
|
||||
}
|
||||
|
||||
def __init__(self, config: dict[str, Any]):
|
||||
DifyConverter.__init__(self)
|
||||
BasePlatformAdapter.__init__(self, config)
|
||||
|
||||
def get_metadata(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
platform_name=PlatformType.DIFY,
|
||||
version="0.5.0",
|
||||
support_node_types=list(self.NODE_TYPE_MAPPING.keys())
|
||||
)
|
||||
|
||||
def map_node_type(self, platform_node_type) -> str:
|
||||
return self.NODE_TYPE_MAPPING.get(platform_node_type)
|
||||
|
||||
@property
|
||||
def origin_nodes(self):
|
||||
return self.config.get("workflow").get("graph").get("nodes")
|
||||
|
||||
@property
|
||||
def origin_edges(self):
|
||||
return self.config.get("workflow").get("graph").get("edges")
|
||||
|
||||
@staticmethod
|
||||
def _valid_nodes(node: dict[str, Any]):
|
||||
if "data" not in node:
|
||||
return False
|
||||
if "type" not in node["data"]:
|
||||
return False
|
||||
if "id" not in node or "type" not in node:
|
||||
return False
|
||||
return True
|
||||
|
||||
def validate_config(self) -> bool:
|
||||
require_fields = frozenset({'app', 'dependencies', 'kind', 'version', 'workflow'})
|
||||
if not all(field in self.config for field in require_fields):
|
||||
return False
|
||||
|
||||
for node in self.origin_nodes:
|
||||
if not self._valid_nodes(node):
|
||||
return False
|
||||
return True
|
||||
|
||||
def parse_workflow(self) -> WorkflowParserResult:
|
||||
for node in self.origin_nodes:
|
||||
node = self._convert_node(node)
|
||||
if node:
|
||||
self.nodes.append(node)
|
||||
nodes_id = [node.id for node in self.nodes]
|
||||
for edge in self.origin_edges:
|
||||
source = edge["source"]
|
||||
target = edge["target"]
|
||||
if source not in nodes_id or target not in nodes_id:
|
||||
continue
|
||||
edge = self._convert_edge(edge)
|
||||
if edge:
|
||||
self.edges.append(edge)
|
||||
#
|
||||
for variable in self.config.get("workflow").get("conversation_variables"):
|
||||
con_var = self._convert_variable(variable)
|
||||
if variable:
|
||||
self.conv_variables.append(con_var)
|
||||
#
|
||||
# for variables in config.get("workflow").get("environment_variables"):
|
||||
# variable = self._convert_variable(variables)
|
||||
# conv_variables.append(variable)
|
||||
|
||||
trigger = self._convert_trigger({})
|
||||
execution_config = self._convert_execution({})
|
||||
|
||||
return WorkflowParserResult(
|
||||
success=not self.errors and not self.warnings,
|
||||
platform=self.get_metadata(),
|
||||
execution_config=execution_config,
|
||||
origin_config=self.config,
|
||||
trigger=trigger,
|
||||
edges=self.edges,
|
||||
nodes=self.nodes,
|
||||
variables=self.conv_variables,
|
||||
warnings=self.warnings,
|
||||
errors=self.errors
|
||||
)
|
||||
|
||||
def _convert_cycle_node_position(self, node_id: str, position: dict):
|
||||
for node in self.origin_nodes:
|
||||
if node["id"] == node_id:
|
||||
return {
|
||||
"x": node["position"]["x"] + position["x"],
|
||||
"y": node["position"]["y"] + position["y"]
|
||||
}
|
||||
self.errors.append(
|
||||
ExceptionDefineition(
|
||||
type=ExceptionType.NODE,
|
||||
node_id=node_id,
|
||||
detail="parent cycle node not found"
|
||||
)
|
||||
)
|
||||
raise Exception("parent cycle node not found")
|
||||
|
||||
def _convert_node(self, node: dict[str, Any]) -> NodeDefinition | None:
|
||||
node_data = node["data"]
|
||||
try:
|
||||
return NodeDefinition(
|
||||
id=node["id"],
|
||||
type=self.map_node_type(node_data["type"]),
|
||||
name=node_data.get("title"),
|
||||
cycle=node.get("parentId"),
|
||||
description=None,
|
||||
config=self._convert_node_config(node),
|
||||
position={
|
||||
"x": node["position"]["x"],
|
||||
"y": node["position"]["y"]
|
||||
} if node.get("parentId") is None else self._convert_cycle_node_position(
|
||||
node["parentId"],
|
||||
node["position"]
|
||||
),
|
||||
error_handling=None,
|
||||
cache=None
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"convert node error - {e}", exc_info=True)
|
||||
|
||||
def _convert_node_config(self, node: dict):
|
||||
node_data = node["data"]
|
||||
node_type = node_data["type"]
|
||||
try:
|
||||
converter = self.get_node_convert(node_type)
|
||||
if converter is None:
|
||||
raise Exception(f"node type not supported - {node_type}")
|
||||
return converter(node)
|
||||
except Exception as e:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
type=ExceptionType.NODE,
|
||||
node_id=node["id"],
|
||||
node_name=node["data"]["title"],
|
||||
detail=f"convert node error - {e}",
|
||||
))
|
||||
raise e
|
||||
|
||||
def _convert_edge(self, edge: dict[str, Any]) -> EdgeDefinition | None:
|
||||
try:
|
||||
|
||||
source = edge["source"]
|
||||
target = edge["target"]
|
||||
edge_id = edge["id"]
|
||||
label = None
|
||||
if source in self.branch_node_cache:
|
||||
case_id = "-".join(edge_id.split("-")[1:-2])
|
||||
if case_id == "false":
|
||||
label = f'CASE{len(self.branch_node_cache[source])+1}'
|
||||
else:
|
||||
label = f'CASE{self.branch_node_cache[source].index(case_id) + 1}'
|
||||
if source in self.error_branch_node_cache:
|
||||
case_id = "-".join(edge_id.split("-")[1:-2])
|
||||
if case_id == "source":
|
||||
label = "SUCCESS"
|
||||
else:
|
||||
label = "ERROR"
|
||||
return EdgeDefinition(
|
||||
id=edge["id"],
|
||||
source=source,
|
||||
target=target,
|
||||
label=label,
|
||||
)
|
||||
except Exception as e:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
type=ExceptionType.EDGE,
|
||||
detail=f"convert edge error - {e}",
|
||||
))
|
||||
return None
|
||||
|
||||
def _convert_variable(self, variable) -> VariableDefinition | None:
|
||||
try:
|
||||
return VariableDefinition(
|
||||
name=variable["name"],
|
||||
default=variable["value"],
|
||||
type=variable["value_type"],
|
||||
)
|
||||
except Exception as e:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
type=ExceptionType.VARIABLE,
|
||||
name=variable.get("name"),
|
||||
detail=f"convert variable error - {e}",
|
||||
))
|
||||
|
||||
def _convert_trigger(self, trigger: dict[str, Any]) -> TriggerConfig | None:
|
||||
pass
|
||||
|
||||
def _convert_execution(self, execution: dict[str, Any]) -> ExecutionConfig:
|
||||
return ExecutionConfig()
|
||||
|
||||
|
||||
75
api/app/core/workflow/adapters/errors.py
Normal file
75
api/app/core/workflow/adapters/errors.py
Normal file
@@ -0,0 +1,75 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/26 11:29
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ExceptionType(StrEnum):
|
||||
NODE = "node"
|
||||
EDGE = "edge"
|
||||
VARIABLE = "variable"
|
||||
TRIGGER = "trigger"
|
||||
EXECUTION = "execution"
|
||||
CONFIG = "config"
|
||||
PLATFORM = "platform"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
class ExceptionDefineition(BaseModel):
|
||||
type: ExceptionType
|
||||
detail: str
|
||||
|
||||
node_id: str | None = None
|
||||
node_name: str | None = None
|
||||
|
||||
scope: str | None = None
|
||||
name: str | None = None
|
||||
|
||||
|
||||
class UnknowModelWarning(ExceptionDefineition):
|
||||
type: ExceptionType = ExceptionType.NODE
|
||||
|
||||
def __init__(self, node_id, node_name, model_name):
|
||||
super().__init__(
|
||||
detail=f"Please specify the model mapping manually for model: {model_name}",
|
||||
node_id=node_id,
|
||||
node_name=node_name
|
||||
)
|
||||
|
||||
|
||||
class UnknowError(ExceptionDefineition):
|
||||
type: ExceptionType = ExceptionType.UNKNOWN
|
||||
|
||||
def __init__(self, detail: str, **kwargs):
|
||||
super().__init__(detail=detail, **kwargs)
|
||||
|
||||
|
||||
class UnsupportPlatform(ExceptionDefineition):
|
||||
type: ExceptionType = ExceptionType.PLATFORM
|
||||
|
||||
def __init__(self, platform: str):
|
||||
super().__init__(detail=f"Unsupport platform {platform}")
|
||||
|
||||
|
||||
class UnsupportVariableType(ExceptionDefineition):
|
||||
type: ExceptionType = ExceptionType.VARIABLE
|
||||
|
||||
def __init__(self, scope, name, var_type: str, **kwargs):
|
||||
super().__init__(scope=scope, name=name, detail=f"Unsupport variable type:[{var_type}]", **kwargs)
|
||||
|
||||
|
||||
class InvalidConfiguration(ExceptionDefineition):
|
||||
type: ExceptionType = ExceptionType.CONFIG
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(detail="Invalid workflow configuration format")
|
||||
|
||||
|
||||
class UnsupportNodeType(ExceptionDefineition):
|
||||
type: ExceptionType = ExceptionType.NODE
|
||||
|
||||
def __init__(self, node_id: str, node_type: str):
|
||||
super().__init__(node_id=node_id, detail=f"Unsupport node Type {node_type}")
|
||||
4
api/app/core/workflow/adapters/memory_bear/__init__.py
Normal file
4
api/app/core/workflow/adapters/memory_bear/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/26 11:30
|
||||
@@ -0,0 +1,76 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/25 14:11
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.adapters.base_adapter import (
|
||||
PlatformMetadata,
|
||||
PlatformType,
|
||||
BasePlatformAdapter,
|
||||
WorkflowParserResult
|
||||
)
|
||||
from app.schemas.workflow_schema import ExecutionConfig
|
||||
|
||||
|
||||
class MemoryBearAdapter(BasePlatformAdapter):
|
||||
NODE_TYPE_MAPPING = {}
|
||||
|
||||
@property
|
||||
def origin_nodes(self):
|
||||
return self.config.get("workflow").get("nodes")
|
||||
|
||||
@property
|
||||
def origin_edges(self):
|
||||
return self.config.get("workflow").get("edges")
|
||||
|
||||
@property
|
||||
def origin_variables(self):
|
||||
return self.config.get("workflow").get("variables")
|
||||
|
||||
def get_metadata(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
platform_name=PlatformType.MEMORY_BEAR,
|
||||
version="0.2.5",
|
||||
support_node_types=list(self.NODE_TYPE_MAPPING.keys())
|
||||
)
|
||||
|
||||
def map_node_type(self, platform_node_type) -> str:
|
||||
return platform_node_type
|
||||
|
||||
@staticmethod
|
||||
def _valid_nodes(node: dict[str, Any]):
|
||||
if "type" not in node["data"]:
|
||||
return False
|
||||
if "id" not in node or "type" not in node:
|
||||
return False
|
||||
return True
|
||||
|
||||
def validate_config(self) -> bool:
|
||||
require_fields = frozenset({'app', 'workflow'})
|
||||
if not all(field in self.config for field in require_fields):
|
||||
return False
|
||||
|
||||
for node in self.origin_nodes:
|
||||
if not self._valid_nodes(node):
|
||||
return False
|
||||
return True
|
||||
|
||||
def parse_workflow(self) -> WorkflowParserResult:
|
||||
self.nodes = self.origin_nodes
|
||||
self.edges = self.origin_edges
|
||||
self.conv_variables = self.origin_variables
|
||||
|
||||
return WorkflowParserResult(
|
||||
success=True,
|
||||
platform=self.get_metadata(),
|
||||
execution_config=ExecutionConfig(),
|
||||
origin_config=self.config,
|
||||
trigger=None,
|
||||
edges=self.edges,
|
||||
nodes=self.nodes,
|
||||
variables=self.conv_variables,
|
||||
warnings=self.warnings,
|
||||
errors=self.errors,
|
||||
|
||||
)
|
||||
34
api/app/core/workflow/adapters/registry.py
Normal file
34
api/app/core/workflow/adapters/registry.py
Normal file
@@ -0,0 +1,34 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/25 14:19
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.adapters import DifyAdapter, MemoryBearAdapter
|
||||
from app.core.workflow.adapters.base_adapter import BasePlatformAdapter, PlatformType
|
||||
|
||||
|
||||
class PlatformAdapterRegistry:
|
||||
_adapters: dict[str, type[BasePlatformAdapter]] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, platform: str, adapter: type[BasePlatformAdapter]):
|
||||
cls._adapters[platform] = adapter
|
||||
|
||||
@classmethod
|
||||
def get_adapter(cls, platform: str, config: dict[str, Any]) -> BasePlatformAdapter:
|
||||
if platform not in cls._adapters:
|
||||
raise ValueError(f"Unsupported platform: {platform}")
|
||||
return cls._adapters.get(platform)(config)
|
||||
|
||||
@classmethod
|
||||
def list_platforms(cls) -> list[str]:
|
||||
return list(cls._adapters.keys())
|
||||
|
||||
@classmethod
|
||||
def is_supported(cls, platform: str) -> bool:
|
||||
return platform in cls._adapters
|
||||
|
||||
|
||||
PlatformAdapterRegistry.register(PlatformType.MEMORY_BEAR, MemoryBearAdapter)
|
||||
PlatformAdapterRegistry.register(PlatformType.DIFY, DifyAdapter)
|
||||
@@ -13,7 +13,7 @@ from app.core.workflow.engine.variable_pool import VariablePool
|
||||
logger = get_logger(__name__)
|
||||
|
||||
SCOPE_PATTERN = re.compile(
|
||||
r"\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\.[a-zA-Z0-9_]+\s*}}"
|
||||
r"\{\{\s*([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\s*}}"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -88,6 +88,8 @@ class AssignerNode(BaseNode):
|
||||
await operator.remove_first()
|
||||
case AssignmentOperator.REMOVE_LAST:
|
||||
await operator.remove_last()
|
||||
case AssignmentOperator.EXTEND:
|
||||
await operator.extend()
|
||||
case _:
|
||||
raise ValueError(f"Invalid Operator: {assignment.operation}")
|
||||
logger.info(f"Node {self.node_id}: execution completed")
|
||||
|
||||
@@ -17,17 +17,17 @@ class EndNodeConfig(BaseNodeConfig):
|
||||
description="输出模板,支持引用前置节点的输出,如:{{ llm_qa.output }}"
|
||||
)
|
||||
|
||||
# 输出变量定义
|
||||
output_variables: list[VariableDefinition] = Field(
|
||||
default_factory=lambda: [
|
||||
VariableDefinition(
|
||||
name="output",
|
||||
type=VariableType.STRING,
|
||||
description="工作流的最终输出"
|
||||
)
|
||||
],
|
||||
description="输出变量定义(自动生成,通常不需要修改)"
|
||||
)
|
||||
# # 输出变量定义
|
||||
# output_variables: list[VariableDefinition] = Field(
|
||||
# default_factory=lambda: [
|
||||
# VariableDefinition(
|
||||
# name="output",
|
||||
# type=VariableType.STRING,
|
||||
# description="工作流的最终输出"
|
||||
# )
|
||||
# ],
|
||||
# description="输出变量定义(自动生成,通常不需要修改)"
|
||||
# )
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
|
||||
@@ -61,6 +61,7 @@ class AssignmentOperator(StrEnum):
|
||||
APPEND = "append"
|
||||
REMOVE_LAST = "remove_last"
|
||||
REMOVE_FIRST = "remove_first"
|
||||
EXTEND = "extend"
|
||||
|
||||
|
||||
class HttpRequestMethod(StrEnum):
|
||||
|
||||
@@ -236,5 +236,5 @@ class HttpRequestNode(BaseNode):
|
||||
logger.warning(
|
||||
f"Node {self.node_id}: HTTP request failed, switching to error handling branch"
|
||||
)
|
||||
return "ERROR"
|
||||
return {"output": "ERROR"}
|
||||
raise RuntimeError("http request failed")
|
||||
|
||||
@@ -40,7 +40,7 @@ class KnowledgeRetrievalNodeConfig(BaseNodeConfig):
|
||||
)
|
||||
|
||||
knowledge_bases: list[KnowledgeBaseConfig] = Field(
|
||||
...,
|
||||
default_factory=list,
|
||||
description="Knowledge base config"
|
||||
)
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
from pydantic import Field
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
|
||||
|
||||
class StartNodeConfig(BaseNodeConfig):
|
||||
@@ -21,42 +20,42 @@ class StartNodeConfig(BaseNodeConfig):
|
||||
description="自定义输入变量列表,这些变量会作为 Start 节点的输出"
|
||||
)
|
||||
|
||||
# 输出变量定义
|
||||
output_variables: list[VariableDefinition] = Field(
|
||||
default_factory=lambda: [
|
||||
VariableDefinition(
|
||||
name="message",
|
||||
type=VariableType.STRING,
|
||||
description="用户输入的消息"
|
||||
),
|
||||
VariableDefinition(
|
||||
name="conversation_vars",
|
||||
type=VariableType.OBJECT,
|
||||
description="会话级变量"
|
||||
),
|
||||
VariableDefinition(
|
||||
name="execution_id",
|
||||
type=VariableType.STRING,
|
||||
description="执行 ID"
|
||||
),
|
||||
VariableDefinition(
|
||||
name="conversation_id",
|
||||
type=VariableType.STRING,
|
||||
description="会话 ID"
|
||||
),
|
||||
VariableDefinition(
|
||||
name="workspace_id",
|
||||
type=VariableType.STRING,
|
||||
description="工作空间 ID"
|
||||
),
|
||||
VariableDefinition(
|
||||
name="user_id",
|
||||
type=VariableType.STRING,
|
||||
description="用户 ID"
|
||||
)
|
||||
],
|
||||
description="输出变量定义(自动生成,通常不需要修改)"
|
||||
)
|
||||
# # 输出变量定义
|
||||
# output_variables: list[VariableDefinition] = Field(
|
||||
# default_factory=lambda: [
|
||||
# VariableDefinition(
|
||||
# name="message",
|
||||
# type=VariableType.STRING,
|
||||
# description="用户输入的消息"
|
||||
# ),
|
||||
# VariableDefinition(
|
||||
# name="conversation_vars",
|
||||
# type=VariableType.OBJECT,
|
||||
# description="会话级变量"
|
||||
# ),
|
||||
# VariableDefinition(
|
||||
# name="execution_id",
|
||||
# type=VariableType.STRING,
|
||||
# description="执行 ID"
|
||||
# ),
|
||||
# VariableDefinition(
|
||||
# name="conversation_id",
|
||||
# type=VariableType.STRING,
|
||||
# description="会话 ID"
|
||||
# ),
|
||||
# VariableDefinition(
|
||||
# name="workspace_id",
|
||||
# type=VariableType.STRING,
|
||||
# description="工作空间 ID"
|
||||
# ),
|
||||
# VariableDefinition(
|
||||
# name="user_id",
|
||||
# type=VariableType.STRING,
|
||||
# description="用户 ID"
|
||||
# )
|
||||
# ],
|
||||
# description="输出变量定义(自动生成,通常不需要修改)"
|
||||
# )
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
|
||||
@@ -5,6 +5,8 @@ from enum import Enum, StrEnum
|
||||
|
||||
from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator
|
||||
|
||||
from app.schemas.workflow_schema import WorkflowConfigCreate
|
||||
|
||||
|
||||
# ---------- Multimodal File Support ----------
|
||||
|
||||
@@ -196,6 +198,8 @@ class AppCreate(BaseModel):
|
||||
# only for type=multi_agent
|
||||
multi_agent_config: Optional[Dict[str, Any]] = None
|
||||
|
||||
workflow_config: Optional[WorkflowConfigCreate] = None
|
||||
|
||||
|
||||
class AppUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
|
||||
@@ -18,7 +18,10 @@ class NodeConfig(BaseModel):
|
||||
class NodeDefinition(BaseModel):
|
||||
"""节点定义"""
|
||||
id: str = Field(..., description="节点唯一标识")
|
||||
type: str = Field(..., description="节点类型: start, end, llm, agent, tool, condition, loop, transform, human, code")
|
||||
type: str = Field(
|
||||
...,
|
||||
description="节点类型: start, end, llm, agent, tool, condition, loop, transform, human, code"
|
||||
)
|
||||
name: str | None = Field(None, description="节点名称")
|
||||
cycle: str | None = Field(None, description="父循环节点id")
|
||||
description: str | None = Field(None, description="节点描述")
|
||||
@@ -30,12 +33,12 @@ class NodeDefinition(BaseModel):
|
||||
|
||||
class EdgeDefinition(BaseModel):
|
||||
"""边定义"""
|
||||
id: str | None = Field(None, description="边唯一标识(可选)")
|
||||
id: str | None = Field(default=None, description="边唯一标识(可选)")
|
||||
source: str = Field(..., description="源节点 ID")
|
||||
target: str = Field(..., description="目标节点 ID")
|
||||
type: str | None = Field(None, description="边类型: normal, error")
|
||||
condition: str | None = Field(None, description="条件表达式(条件边)")
|
||||
label: str | None = Field(None, description="边标签")
|
||||
type: str | None = Field(default=None, description="边类型: normal, error")
|
||||
condition: str | None = Field(default=None, description="条件表达式(条件边)")
|
||||
label: str | None = Field(default=None, description="边标签")
|
||||
|
||||
|
||||
class VariableDefinition(BaseModel):
|
||||
@@ -44,7 +47,7 @@ class VariableDefinition(BaseModel):
|
||||
type: str = Field(default="string", description="变量类型: string, number, boolean, object, array")
|
||||
required: bool = Field(default=False, description="是否必填")
|
||||
default: Any = Field(None, description="默认值")
|
||||
description: str | None = Field(None, description="变量描述")
|
||||
description: str | None = Field(default=None, description="变量描述")
|
||||
|
||||
|
||||
class ExecutionConfig(BaseModel):
|
||||
@@ -61,6 +64,13 @@ class TriggerConfig(BaseModel):
|
||||
config: dict[str, Any] = Field(default_factory=dict, description="触发器配置")
|
||||
|
||||
|
||||
class WorkflowImportSave(BaseModel):
|
||||
"""工作流导入请求"""
|
||||
temp_id: str
|
||||
name: str
|
||||
description: str
|
||||
|
||||
|
||||
# ==================== 工作流配置 ====================
|
||||
|
||||
class WorkflowConfigCreate(BaseModel):
|
||||
@@ -84,7 +94,7 @@ class WorkflowConfigUpdate(BaseModel):
|
||||
class WorkflowConfig(BaseModel):
|
||||
"""工作流配置输出"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
id: uuid.UUID
|
||||
app_id: uuid.UUID
|
||||
nodes: list[dict[str, Any]]
|
||||
@@ -95,11 +105,11 @@ class WorkflowConfig(BaseModel):
|
||||
is_active: bool
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
|
||||
@field_serializer("updated_at", when_used="json")
|
||||
def _serialize_updated_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
@@ -123,7 +133,8 @@ class WorkflowExecutionResponse(BaseModel):
|
||||
output_data: dict[str, Any] | None = Field(None, description="所有节点的详细输出数据")
|
||||
error_message: str | None = Field(None, description="错误信息")
|
||||
elapsed_time: float | None = Field(None, description="耗时(秒)")
|
||||
token_usage: dict[str, Any] | None = Field(None, description="Token 使用情况 {prompt_tokens, completion_tokens, total_tokens}")
|
||||
token_usage: dict[str, Any] | None = Field(None,
|
||||
description="Token 使用情况 {prompt_tokens, completion_tokens, total_tokens}")
|
||||
|
||||
|
||||
class WorkflowExecutionStreamChunk(BaseModel):
|
||||
@@ -136,7 +147,7 @@ class WorkflowExecutionStreamChunk(BaseModel):
|
||||
class WorkflowExecution(BaseModel):
|
||||
"""工作流执行记录输出"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
id: uuid.UUID
|
||||
workflow_config_id: uuid.UUID
|
||||
app_id: uuid.UUID
|
||||
@@ -156,15 +167,15 @@ class WorkflowExecution(BaseModel):
|
||||
token_usage: dict[str, Any] | None
|
||||
meta_data: dict[str, Any]
|
||||
created_at: datetime.datetime
|
||||
|
||||
|
||||
@field_serializer("started_at", when_used="json")
|
||||
def _serialize_started_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
|
||||
@field_serializer("completed_at", when_used="json")
|
||||
def _serialize_completed_at(self, dt: datetime.datetime | None):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
@@ -173,7 +184,7 @@ class WorkflowExecution(BaseModel):
|
||||
class WorkflowNodeExecution(BaseModel):
|
||||
"""工作流节点执行记录输出"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
id: uuid.UUID
|
||||
execution_id: uuid.UUID
|
||||
node_id: str
|
||||
@@ -193,15 +204,15 @@ class WorkflowNodeExecution(BaseModel):
|
||||
cache_key: str | None
|
||||
meta_data: dict[str, Any]
|
||||
created_at: datetime.datetime
|
||||
|
||||
|
||||
@field_serializer("started_at", when_used="json")
|
||||
def _serialize_started_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
|
||||
@field_serializer("completed_at", when_used="json")
|
||||
def _serialize_completed_at(self, dt: datetime.datetime | None):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
@@ -321,6 +321,26 @@ class AppService:
|
||||
self.db.add(agent_cfg)
|
||||
logger.debug("Agent 配置已创建", extra={"app_id": str(app_id)})
|
||||
|
||||
def _create_workflow_config(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
data: app_schema.WorkflowConfigCreate,
|
||||
now: datetime.datetime
|
||||
):
|
||||
workflow_cfg = WorkflowConfig(
|
||||
id=uuid.uuid4(),
|
||||
app_id=app_id,
|
||||
nodes=[node.model_dump() for node in data.nodes] if data.nodes else [],
|
||||
edges=[edge.model_dump() for edge in data.edges] if data.edges else [],
|
||||
variables=[var.model_dump() for var in data.variables] if data.variables else [],
|
||||
execution_config=data.execution_config.model_dump() if data.execution_config else {},
|
||||
triggers=[trigger.model_dump() for trigger in data.triggers] if data.triggers else [],
|
||||
is_active=True,
|
||||
created_at=now,
|
||||
updated_at=now
|
||||
)
|
||||
self.db.add(workflow_cfg)
|
||||
|
||||
def _create_multi_agent_config(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
@@ -532,6 +552,9 @@ class AppService:
|
||||
if app.type == "multi_agent" and data.multi_agent_config:
|
||||
self._create_multi_agent_config(app.id, data.multi_agent_config, now)
|
||||
|
||||
if app.type == "workflow" and data.workflow_config:
|
||||
self._create_workflow_config(app.id, data.workflow_config, now)
|
||||
|
||||
self.db.commit()
|
||||
self.db.refresh(app)
|
||||
|
||||
@@ -968,7 +991,7 @@ class AppService:
|
||||
config = self.db.scalars(stmt).first()
|
||||
|
||||
try:
|
||||
config_memory=config.memory
|
||||
config_memory = config.memory
|
||||
if 'memory_content' in config_memory:
|
||||
config.memory['memory_config_id'] = config.memory.pop('memory_content')
|
||||
except:
|
||||
@@ -1189,9 +1212,9 @@ class AppService:
|
||||
# ==================== 记忆配置提取方法 ====================
|
||||
|
||||
def _extract_memory_config_id(
|
||||
self,
|
||||
app_type: str,
|
||||
config: Dict[str, Any]
|
||||
self,
|
||||
app_type: str,
|
||||
config: Dict[str, Any]
|
||||
) -> Tuple[Optional[uuid.UUID], bool]:
|
||||
"""从发布配置中提取 memory_config_id(委托给 MemoryConfigService)
|
||||
|
||||
@@ -1205,13 +1228,13 @@ class AppService:
|
||||
- is_legacy_int: 是否检测到旧格式 int 数据,需要回退到工作空间默认配置
|
||||
"""
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
|
||||
service = MemoryConfigService(self.db)
|
||||
return service.extract_memory_config_id(app_type, config)
|
||||
|
||||
def _get_workspace_default_memory_config_id(
|
||||
self,
|
||||
workspace_id: uuid.UUID
|
||||
self,
|
||||
workspace_id: uuid.UUID
|
||||
) -> Optional[uuid.UUID]:
|
||||
"""获取工作空间的默认记忆配置ID
|
||||
|
||||
@@ -1222,22 +1245,22 @@ class AppService:
|
||||
Optional[uuid.UUID]: 默认记忆配置ID,如果不存在则返回 None
|
||||
"""
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
|
||||
service = MemoryConfigService(self.db)
|
||||
config = service.get_workspace_default_config(workspace_id)
|
||||
|
||||
|
||||
if not config:
|
||||
logger.warning(
|
||||
f"工作空间没有可用的记忆配置: workspace_id={workspace_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
return config.config_id
|
||||
|
||||
def _update_endusers_memory_config(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
memory_config_id: uuid.UUID
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
memory_config_id: uuid.UUID
|
||||
) -> int:
|
||||
"""批量更新应用下所有终端用户的 memory_config_id
|
||||
|
||||
@@ -1249,13 +1272,13 @@ class AppService:
|
||||
int: 更新的终端用户数量
|
||||
"""
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
|
||||
|
||||
repo = EndUserRepository(self.db)
|
||||
updated_count = repo.batch_update_memory_config_id(
|
||||
app_id=app_id,
|
||||
memory_config_id=memory_config_id
|
||||
)
|
||||
|
||||
|
||||
return updated_count
|
||||
|
||||
# ==================== 应用发布管理 ====================
|
||||
@@ -1403,7 +1426,7 @@ class AppService:
|
||||
|
||||
# 提取记忆配置ID并更新终端用户
|
||||
memory_config_id, is_legacy_int = self._extract_memory_config_id(app.type, config)
|
||||
|
||||
|
||||
# 如果检测到旧格式 int 数据,回退到工作空间默认配置
|
||||
if is_legacy_int and not memory_config_id:
|
||||
memory_config_id = self._get_workspace_default_memory_config_id(app.workspace_id)
|
||||
@@ -1412,7 +1435,7 @@ class AppService:
|
||||
f"发布时使用工作空间默认记忆配置(旧数据兼容): app_id={app_id}, "
|
||||
f"workspace_id={app.workspace_id}, memory_config_id={memory_config_id}"
|
||||
)
|
||||
|
||||
|
||||
if memory_config_id:
|
||||
updated_count = self._update_endusers_memory_config(app_id, memory_config_id)
|
||||
logger.info(
|
||||
@@ -1537,7 +1560,7 @@ class AppService:
|
||||
|
||||
# 提取记忆配置ID并更新终端用户
|
||||
memory_config_id, is_legacy_int = self._extract_memory_config_id(release.type, release.config)
|
||||
|
||||
|
||||
# 如果检测到旧格式 int 数据,回退到工作空间默认配置
|
||||
if is_legacy_int and not memory_config_id:
|
||||
memory_config_id = self._get_workspace_default_memory_config_id(app.workspace_id)
|
||||
@@ -1546,7 +1569,7 @@ class AppService:
|
||||
f"回滚时使用工作空间默认记忆配置(旧数据兼容): app_id={app_id}, "
|
||||
f"workspace_id={app.workspace_id}, memory_config_id={memory_config_id}"
|
||||
)
|
||||
|
||||
|
||||
if memory_config_id:
|
||||
updated_count = self._update_endusers_memory_config(app_id, memory_config_id)
|
||||
logger.info(
|
||||
|
||||
102
api/app/services/workflow_import_service.py
Normal file
102
api/app/services/workflow_import_service.py
Normal file
@@ -0,0 +1,102 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/25 14:39
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.aioRedis import aio_redis_set, aio_redis_get
|
||||
from app.core.config import settings
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.workflow.adapters.base_adapter import WorkflowImportResult, WorkflowParserResult
|
||||
from app.core.workflow.adapters.errors import UnsupportPlatform, InvalidConfiguration
|
||||
from app.core.workflow.adapters.registry import PlatformAdapterRegistry
|
||||
from app.schemas import AppCreate
|
||||
from app.schemas.workflow_schema import WorkflowConfigCreate
|
||||
from app.services.app_service import AppService
|
||||
from app.services.workflow_service import WorkflowService
|
||||
|
||||
|
||||
class WorkflowImportService:
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self.registry = PlatformAdapterRegistry
|
||||
self.cache_timeout = settings.WORKFLOW_IMPORT_CACHE_TIMEOUT
|
||||
|
||||
self.app_service = AppService(db)
|
||||
self.workflow_service = WorkflowService(db)
|
||||
|
||||
async def flush_config(self, temp_id: str, config: WorkflowParserResult):
|
||||
config_cache = await aio_redis_get(temp_id)
|
||||
if not config_cache:
|
||||
raise BusinessException("Workflow configuration has expired. Please re-upload it.")
|
||||
await aio_redis_set(temp_id, config.model_dump_json(), expire=self.cache_timeout)
|
||||
|
||||
async def upload_config(
|
||||
self,
|
||||
platform: str,
|
||||
config: dict[str, Any],
|
||||
):
|
||||
|
||||
if not self.registry.is_supported(platform):
|
||||
return WorkflowImportResult(
|
||||
success=False,
|
||||
temp_id=None,
|
||||
workflow_id=None,
|
||||
errors=[UnsupportPlatform(platform=platform)]
|
||||
)
|
||||
|
||||
adapter = self.registry.get_adapter(platform, config)
|
||||
|
||||
if not adapter.validate_config():
|
||||
return WorkflowImportResult(
|
||||
success=False,
|
||||
temp_id=None,
|
||||
workflow_id=None,
|
||||
errors=[InvalidConfiguration()]
|
||||
)
|
||||
|
||||
workflow_config = adapter.parse_workflow()
|
||||
temp_id = uuid.uuid4().hex
|
||||
await aio_redis_set(temp_id, workflow_config.model_dump(), expire=self.cache_timeout)
|
||||
return WorkflowImportResult(
|
||||
success=True,
|
||||
temp_id=temp_id,
|
||||
workflow_id=None,
|
||||
edges=workflow_config.edges,
|
||||
nodes=workflow_config.nodes,
|
||||
variables=workflow_config.variables,
|
||||
warnings=workflow_config.warnings,
|
||||
errors=workflow_config.errors
|
||||
)
|
||||
|
||||
async def save_workflow(
|
||||
self,
|
||||
user_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
temp_id: str,
|
||||
name: str,
|
||||
description: str | None,
|
||||
):
|
||||
config = await aio_redis_get(temp_id)
|
||||
if config is None:
|
||||
raise BusinessException("Configuration import timed out. Please try again.")
|
||||
config = json.loads(config)
|
||||
app = self.app_service.create_app(
|
||||
user_id=user_id,
|
||||
workspace_id=workspace_id,
|
||||
data=AppCreate(
|
||||
name=name,
|
||||
description=description,
|
||||
type="workflow",
|
||||
workflow_config=WorkflowConfigCreate(
|
||||
nodes=config["nodes"],
|
||||
edges=config["edges"],
|
||||
variables=config["variables"]
|
||||
)
|
||||
)
|
||||
)
|
||||
return app
|
||||
@@ -6,13 +6,16 @@ import logging
|
||||
import uuid
|
||||
from typing import Any, Annotated, Optional
|
||||
|
||||
import yaml
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.workflow.adapters.registry import PlatformAdapterRegistry
|
||||
from app.core.workflow.validator import validate_workflow_config
|
||||
from app.db import get_db
|
||||
from app.models import App
|
||||
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
|
||||
from app.repositories.workflow_repository import (
|
||||
WorkflowConfigRepository,
|
||||
@@ -38,6 +41,8 @@ class WorkflowService:
|
||||
self.conversation_service = ConversationService(db)
|
||||
self.multimodal_service = MultimodalService(db)
|
||||
|
||||
self.registry = PlatformAdapterRegistry
|
||||
|
||||
# ==================== 配置管理 ====================
|
||||
|
||||
def create_workflow_config(
|
||||
@@ -200,6 +205,32 @@ class WorkflowService:
|
||||
logger.info(f"删除工作流配置成功: app_id={app_id}, config_id={config.id}")
|
||||
return True
|
||||
|
||||
def export_workflow_dsl(self, app_id: uuid.UUID):
|
||||
config = self.get_workflow_config(app_id)
|
||||
if not config:
|
||||
raise BusinessException(
|
||||
code=BizCode.NOT_FOUND,
|
||||
message=f"工作流配置不存在: app_id={app_id}"
|
||||
)
|
||||
|
||||
app: App = config.app
|
||||
dsl_info = {
|
||||
"app": {
|
||||
"name": app.name,
|
||||
"description": app.description,
|
||||
"icon": app.icon,
|
||||
"icon_type": app.icon_type
|
||||
},
|
||||
"workflow": {
|
||||
"variables": config.variables,
|
||||
"edges": config.edges,
|
||||
"nodes": config.nodes,
|
||||
"execution_config": config.execution_config,
|
||||
"triggers": config.triggers
|
||||
}
|
||||
}
|
||||
return yaml.dump(dsl_info, default_flow_style=False, allow_unicode=True)
|
||||
|
||||
def check_config(self, app_id: uuid.UUID) -> WorkflowConfig:
|
||||
"""检查工作流配置的完整性
|
||||
|
||||
|
||||
Reference in New Issue
Block a user