feat(workflow): add Dify workflow import adapter and related APIs

This commit is contained in:
Eternity
2026-02-28 10:29:14 +08:00
parent e9ff742162
commit 9916cf3265
25 changed files with 1625 additions and 124 deletions

View File

@@ -10,7 +10,6 @@ from app.core.config import settings
# 设置日志记录器
logger = logging.getLogger(__name__)
# 创建连接池
pool = ConnectionPool.from_url(
f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}",
@@ -21,6 +20,7 @@ pool = ConnectionPool.from_url(
)
aio_redis = redis.StrictRedis(connection_pool=pool)
async def get_redis_connection():
"""获取Redis连接"""
try:
@@ -29,7 +29,8 @@ async def get_redis_connection():
logger.error(f"Redis连接失败: {str(e)}")
return None
async def aio_redis_set(key: str, val: str|dict, expire: int = None):
async def aio_redis_set(key: str, val: str | dict, expire: int = None):
"""设置Redis键值
Args:
@@ -40,7 +41,7 @@ async def aio_redis_set(key: str, val: str|dict, expire: int = None):
try:
if isinstance(val, dict):
val = json.dumps(val, ensure_ascii=False)
if expire is not None:
# 设置带过期时间的键值
await aio_redis.set(key, val, ex=expire)
@@ -50,6 +51,7 @@ async def aio_redis_set(key: str, val: str|dict, expire: int = None):
except Exception as e:
logger.error(f"Redis set错误: {str(e)}")
async def aio_redis_get(key: str):
"""获取Redis键值"""
try:
@@ -58,6 +60,7 @@ async def aio_redis_get(key: str):
logger.error(f"Redis get错误: {str(e)}")
return None
async def aio_redis_delete(key: str):
"""删除Redis键"""
try:
@@ -66,6 +69,7 @@ async def aio_redis_delete(key: str):
logger.error(f"Redis delete错误: {str(e)}")
return None
async def aio_redis_publish(channel: str, message: Dict[str, Any]) -> bool:
"""发布消息到Redis频道"""
try:
@@ -78,9 +82,10 @@ async def aio_redis_publish(channel: str, message: Dict[str, Any]) -> bool:
logger.error(f"Redis发布错误: {str(e)}")
return False
class RedisSubscriber:
"""Redis订阅器"""
def __init__(self, channel: str):
self.channel = channel
self.conn = None
@@ -88,25 +93,25 @@ class RedisSubscriber:
self.is_closed = False
self._queue = asyncio.Queue()
self._task = None
async def start(self):
    """Begin consuming messages from the channel (idempotent: no-op when
    already started or closed)."""
    already_active = self.is_closed or self._task
    if already_active:
        return
    # Receiving runs in the background; get_message() drains the queue.
    self._task = asyncio.create_task(self._receive_messages())
    logger.info(f"开始订阅: {self.channel}")
async def _receive_messages(self):
"""接收消息"""
try:
self.conn = await get_redis_connection()
if not self.conn:
return
self.pubsub = self.conn.pubsub()
await self.pubsub.subscribe(self.channel)
while not self.is_closed:
try:
message = await self.pubsub.get_message(ignore_subscribe_messages=True, timeout=0.01)
@@ -127,7 +132,7 @@ class RedisSubscriber:
finally:
await self._queue.put(None)
await self._cleanup()
async def _cleanup(self):
"""清理资源"""
if self.pubsub:
@@ -141,7 +146,7 @@ class RedisSubscriber:
await self.conn.close()
except Exception:
pass
async def get_message(self) -> Optional[Dict[str, Any]]:
"""获取消息"""
if self.is_closed:
@@ -153,7 +158,7 @@ class RedisSubscriber:
except Exception as e:
logger.error(f"获取消息错误: {str(e)}")
return None
async def close(self):
"""关闭订阅器"""
if self.is_closed:
@@ -163,32 +168,33 @@ class RedisSubscriber:
self._task.cancel()
await self._cleanup()
class RedisPubSubManager:
"""Redis发布订阅管理器"""
def __init__(self):
    """Create an empty manager; maps channel name -> RedisSubscriber."""
    self.subscribers: dict = {}
async def publish(self, channel: str, message: Dict[str, Any]) -> bool:
    """Publish *message* on *channel*; True on success (thin delegate to
    the module-level aio_redis_publish helper)."""
    return await aio_redis_publish(channel, message)
def get_subscriber(self, channel: str) -> RedisSubscriber:
    """Return a live subscriber for *channel*, replacing any closed one."""
    existing = self.subscribers.get(channel)
    if existing is not None and not existing.is_closed:
        return existing
    fresh = RedisSubscriber(channel)
    self.subscribers[channel] = fresh
    return fresh
def cancel_subscription(self, channel: str) -> bool:
    """Close and forget the subscriber for *channel*; True if one existed."""
    if channel in self.subscribers:
        # NOTE(review): the close() task is fire-and-forget with no saved
        # reference; per the asyncio docs such tasks can be garbage-collected
        # before completion — confirm cleanup reliably runs.
        asyncio.create_task(self.subscribers[channel].close())
        del self.subscribers[channel]
        return True
    return False
def cancel_all_subscriptions(self) -> int:
count = len(self.subscribers)
for subscriber in self.subscribers.values():
@@ -196,6 +202,6 @@ class RedisPubSubManager:
self.subscribers.clear()
return count
# 全局实例
pubsub_manager = RedisPubSubManager()

View File

@@ -1,7 +1,8 @@
import uuid
from typing import Optional, Annotated
from fastapi import APIRouter, Depends, Path
import yaml
from fastapi import APIRouter, Depends, Path, Form, UploadFile, File
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
@@ -17,12 +18,13 @@ from app.repositories.end_user_repository import EndUserRepository
from app.schemas import app_schema
from app.schemas.response_schema import PageData, PageMeta
from app.schemas.workflow_schema import WorkflowConfig as WorkflowConfigSchema
from app.schemas.workflow_schema import WorkflowConfigUpdate
from app.schemas.workflow_schema import WorkflowConfigUpdate, WorkflowImportSave
from app.services import app_service, workspace_service
from app.services.agent_config_helper import enrich_agent_config
from app.services.app_service import AppService
from app.services.workflow_service import WorkflowService, get_workflow_service
from app.services.app_statistics_service import AppStatisticsService
from app.services.workflow_import_service import WorkflowImportService
from app.services.workflow_service import WorkflowService, get_workflow_service
router = APIRouter(prefix="/apps", tags=["Apps"])
logger = get_business_logger()
@@ -65,7 +67,7 @@ def list_apps(
# 当 ids 存在且不为 None 时,根据 ids 获取应用
if ids is not None:
app_ids = [id.strip() for id in ids.split(',') if id.strip()]
app_ids = [app_id.strip() for app_id in ids.split(',') if app_id.strip()]
items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id)
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
return success(data=items)
@@ -879,6 +881,60 @@ async def update_workflow_config(
return success(data=WorkflowConfigSchema.model_validate(cfg))
@router.get("/{app_id}/workflow/export")
@cur_workspace_access_guard()
async def export_workflow_config(
    app_id: uuid.UUID,
    db: Annotated[Session, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_user)]
):
    """Export the app's workflow configuration as a YAML DSL string."""
    service = WorkflowService(db)
    dsl_text = service.export_workflow_dsl(app_id=app_id)
    return success(data={"content": dsl_text})
@router.post("/workflow/import")
@cur_workspace_access_guard()
async def import_workflow_config(
    file: UploadFile = File(...),
    platform: str = Form(...),
    app_id: str = Form(None),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Import a workflow configuration from an uploaded YAML file.

    Args:
        file: uploaded .yaml/.yml workflow export
        platform: source platform identifier (e.g. dify)
        app_id: optional target app id
    Returns:
        success() with the parsed import result, or fail() on bad input.
    """
    # file.filename may be None for some clients — guard before lower().
    if not file.filename or not file.filename.lower().endswith((".yaml", ".yml")):
        return fail(msg="Only yaml file is allowed", code=BizCode.BAD_REQUEST)
    # Bad user uploads must yield a 4xx-style response, not an unhandled 500.
    try:
        raw_text = (await file.read()).decode("utf-8")
        config = yaml.safe_load(raw_text)
    except UnicodeDecodeError:
        return fail(msg="File must be UTF-8 encoded", code=BizCode.BAD_REQUEST)
    except yaml.YAMLError:
        return fail(msg="Invalid yaml content", code=BizCode.BAD_REQUEST)
    import_service = WorkflowImportService(db)
    result = await import_service.upload_config(platform, config)
    return success(data=result)
@router.post("/workflow/import/save")
@cur_workspace_access_guard()
async def save_workflow_import(
    data: WorkflowImportSave,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Persist a previously uploaded workflow import as an app owned by the
    current user's workspace."""
    service = WorkflowImportService(db)
    created_app = await service.save_workflow(
        user_id=current_user.id,
        workspace_id=current_user.current_workspace_id,
        temp_id=data.temp_id,
        name=data.name,
        description=data.description,
    )
    return success(data=app_schema.App.model_validate(created_app))
@router.get("/{app_id}/statistics", summary="应用统计数据")
@cur_workspace_access_guard()
def get_app_statistics(
@@ -889,12 +945,14 @@ def get_app_statistics(
current_user=Depends(get_current_user),
):
"""获取应用统计数据
Args:
app_id: 应用ID
start_date: 开始时间戳(毫秒)
end_date: 结束时间戳(毫秒)
db: 数据库连接
current_user: 当前用户
Returns:
- daily_conversations: 每日会话数统计
- total_conversations: 总会话数
@@ -931,6 +989,8 @@ def get_workspace_api_statistics(
Args:
start_date: 开始时间戳(毫秒)
end_date: 结束时间戳(毫秒)
db: 数据库连接
current_user: 当前用户
Returns:
每日统计数据列表,每项包含:

View File

@@ -16,18 +16,18 @@ class Settings:
# cloud: SaaS 云服务版(全功能,按量计费)
# enterprise: 企业私有化版License 控制)
DEPLOYMENT_MODE: str = os.getenv("DEPLOYMENT_MODE", "community")
# License 配置(企业版)
LICENSE_FILE: str = os.getenv("LICENSE_FILE", "/etc/app/license.json")
LICENSE_SERVER_URL: str = os.getenv("LICENSE_SERVER_URL", "https://license.yourcompany.com")
# 计费服务配置SaaS 版)
BILLING_SERVICE_URL: str = os.getenv("BILLING_SERVICE_URL", "")
# 基础 URL用于 SSO 回调等)
BASE_URL: str = os.getenv("BASE_URL", "http://localhost:8000")
FRONTEND_URL: str = os.getenv("FRONTEND_URL", "http://localhost:3000")
ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true"
# API Keys Configuration
OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")
@@ -57,7 +57,6 @@ class Settings:
REDIS_PORT: int = int(os.getenv("REDIS_PORT", "6379"))
REDIS_DB: int = int(os.getenv("REDIS_DB", "1"))
REDIS_PASSWORD: str = os.getenv("REDIS_PASSWORD", "")
# ElasticSearch configuration
ELASTICSEARCH_HOST: str = os.getenv("ELASTICSEARCH_HOST", "https://127.0.0.1")
@@ -91,7 +90,7 @@ class Settings:
# Single Sign-On configuration
ENABLE_SINGLE_SESSION: bool = os.getenv("ENABLE_SINGLE_SESSION", "false").lower() == "true"
# SSO 免登配置
SSO_TOKEN_EXPIRE_SECONDS: int = int(os.getenv("SSO_TOKEN_EXPIRE_SECONDS", "300"))
SSO_TRUSTED_SOURCES_CONFIG: str = os.getenv("SSO_TRUSTED_SOURCES_CONFIG", "{}")
@@ -130,7 +129,7 @@ class Settings:
# Server Configuration
SERVER_IP: str = os.getenv("SERVER_IP", "127.0.0.1")
FILE_LOCAL_SERVER_URL : str = os.getenv("FILE_LOCAL_SERVER_URL", "http://localhost:8000/api")
FILE_LOCAL_SERVER_URL: str = os.getenv("FILE_LOCAL_SERVER_URL", "http://localhost:8000/api")
# ========================================================================
# Internal Configuration (not in .env, used by application code)
@@ -225,6 +224,7 @@ class Settings:
LOAD_MODEL: bool = os.getenv("LOAD_MODEL", "false").lower() == "true"
# workflow config
WORKFLOW_IMPORT_CACHE_TIMEOUT: int = int(os.getenv("WORKFLOW_IMPORT_CACHE_TIMEOUT", 1800))
WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600))
# ========================================================================
@@ -232,20 +232,20 @@ class Settings:
# ========================================================================
# 通用本体文件路径列表(逗号分隔)
GENERAL_ONTOLOGY_FILES: str = os.getenv("GENERAL_ONTOLOGY_FILES", "General_purpose_entity.ttl")
# 是否启用通用本体类型功能
ENABLE_GENERAL_ONTOLOGY_TYPES: bool = os.getenv("ENABLE_GENERAL_ONTOLOGY_TYPES", "true").lower() == "true"
# Prompt 中最大类型数量
MAX_ONTOLOGY_TYPES_IN_PROMPT: int = int(os.getenv("MAX_ONTOLOGY_TYPES_IN_PROMPT", "50"))
# 核心通用类型列表(逗号分隔)
CORE_GENERAL_TYPES: str = os.getenv(
"CORE_GENERAL_TYPES",
"Person,Organization,Company,GovernmentAgency,Place,Location,City,Country,Building,"
"Event,SportsEvent,SocialEvent,Work,Book,Film,Software,Concept,TopicalConcept,AcademicSubject"
)
# 实验模式开关(允许通过 API 动态切换本体配置)
ONTOLOGY_EXPERIMENT_MODE: bool = os.getenv("ONTOLOGY_EXPERIMENT_MODE", "true").lower() == "true"

View File

@@ -0,0 +1,8 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/24 15:54
"""Public exports of the platform workflow-import adapters (Dify, MemoryBear)."""
from app.core.workflow.adapters.dify.dify_adapter import DifyAdapter
from app.core.workflow.adapters.memory_bear.memory_bear_adapter import MemoryBearAdapter
__all__ = ["DifyAdapter", "MemoryBearAdapter"]

View File

@@ -0,0 +1,88 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/24 15:58
from abc import ABC, abstractmethod
from collections import defaultdict
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field
from app.core.workflow.adapters.errors import ExceptionDefineition
from app.schemas.workflow_schema import (
EdgeDefinition,
NodeDefinition,
VariableDefinition,
ExecutionConfig,
TriggerConfig
)
class PlatformType(StrEnum):
    """Source platforms a workflow definition can be imported from."""
    MEMORY_BEAR = "memory_bear"
    DIFY = "dify"
    COZE = "coze"
class PlatformMetadata(BaseModel):
    """Identity and capability description of a source platform adapter."""
    platform_name: str
    version: str
    # Node type identifiers this adapter knows how to convert.
    support_node_types: list[str]
class WorkflowParserResult(BaseModel):
    """Outcome of parsing one platform workflow export into local definitions."""
    success: bool
    platform: PlatformMetadata
    execution_config: ExecutionConfig
    # Untouched source configuration, retained for debugging/round-tripping.
    origin_config: dict[str, Any]
    trigger: TriggerConfig | None
    edges: list[EdgeDefinition] = Field(default_factory=list)
    nodes: list[NodeDefinition] = Field(default_factory=list)
    variables: list[VariableDefinition] = Field(default_factory=list)
    # Non-fatal issues (e.g. unknown models) vs. fatal conversion errors.
    warnings: list[ExceptionDefineition] = Field(default_factory=list)
    errors: list[ExceptionDefineition] = Field(default_factory=list)
class WorkflowImportResult(BaseModel):
    """Result returned to the API layer after uploading a workflow config."""
    success: bool
    # NOTE(review): Field(...) makes these required-but-nullable; if they are
    # meant to be omissible the default should be None — confirm with callers.
    temp_id: str | None = Field(..., description="cache id")
    workflow_id: str | None = Field(..., description="workflow id")
    edges: list[EdgeDefinition] = Field(default_factory=list)
    nodes: list[NodeDefinition] = Field(default_factory=list)
    variables: list[VariableDefinition] = Field(default_factory=list)
    warnings: list[ExceptionDefineition] = Field(default_factory=list)
    errors: list[ExceptionDefineition] = Field(default_factory=list)
class BasePlatformAdapter(ABC):
    """Abstract base for platform-specific workflow import adapters.

    Subclasses turn a platform's exported config dict into local node/edge/
    variable definitions, accumulating errors and warnings along the way.
    """
    def __init__(self, config: dict[str, Any]):
        # Raw platform export being converted.
        self.config = config
        self.nodes: list[NodeDefinition] = []
        self.edges: list[EdgeDefinition] = []
        self.conv_variables: list[VariableDefinition] = []
        # Fatal vs. non-fatal issues collected during parsing.
        self.errors: list = []
        self.warnings: list = []
        # branch node id -> ordered list of branch handle ids (edge remapping).
        self.branch_node_cache = defaultdict(list)
        # Node ids that expose an error/fail branch.
        self.error_branch_node_cache: list = []
    @abstractmethod
    def get_metadata(self) -> PlatformMetadata:
        """Return metadata describing the source platform."""
        pass
    @abstractmethod
    def validate_config(self) -> bool:
        """Validate the raw platform configuration before parsing."""
        pass
    @abstractmethod
    def parse_workflow(self) -> WorkflowParserResult:
        """Parse the platform configuration into the local workflow config."""
        pass
    @abstractmethod
    def map_node_type(self, platform_node_type: str) -> str:
        # Maps a platform node type identifier to the local node type name.
        pass

View File

@@ -0,0 +1,75 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/26 14:32
from abc import ABC, abstractmethod
from app.core.workflow.variable.base_variable import DEFAULT_VALUE, VariableType
class BaseConverter(ABC):
    """Shared best-effort value coercion helpers for platform converters.

    Each ``_convert_*`` helper attempts the native Python cast and, on any
    failure, falls back to the declared default for the target VariableType
    instead of propagating the error. File conversions are platform-specific
    and therefore abstract.

    FIX: the original used bare ``except:`` clauses, which also swallow
    ``KeyboardInterrupt``/``SystemExit``; narrowed to ``except Exception``.
    """

    @staticmethod
    def _convert_string(var):
        """Cast to str, defaulting on failure."""
        try:
            return str(var)
        except Exception:
            return DEFAULT_VALUE(VariableType.STRING)

    @staticmethod
    def _convert_boolean(var):
        """Cast to bool (Python truthiness), defaulting on failure."""
        try:
            return bool(var)
        except Exception:
            return DEFAULT_VALUE(VariableType.BOOLEAN)

    @staticmethod
    def _convert_number(var):
        """Cast to float (numbers are normalized to float), defaulting on failure."""
        try:
            return float(var)
        except Exception:
            return DEFAULT_VALUE(VariableType.NUMBER)

    @staticmethod
    def _convert_object(var):
        """Cast to dict, defaulting on failure."""
        try:
            return dict(var)
        except Exception:
            return DEFAULT_VALUE(VariableType.OBJECT)

    @staticmethod
    @abstractmethod
    def _convert_file(var):
        """Platform-specific file conversion."""
        pass

    @staticmethod
    def _convert_array_string(var):
        """Cast any iterable to a list, defaulting on failure."""
        try:
            return list(var)
        except Exception:
            return DEFAULT_VALUE(VariableType.ARRAY_STRING)

    @staticmethod
    def _convert_array_number(var):
        """Cast any iterable to a list, defaulting on failure."""
        try:
            return list(var)
        except Exception:
            return DEFAULT_VALUE(VariableType.ARRAY_NUMBER)

    @staticmethod
    def _convert_array_boolean(var):
        """Cast any iterable to a list, defaulting on failure."""
        try:
            return list(var)
        except Exception:
            return DEFAULT_VALUE(VariableType.ARRAY_BOOLEAN)

    @staticmethod
    def _convert_array_object(var):
        """Cast any iterable to a list, defaulting on failure."""
        try:
            return list(var)
        except Exception:
            return DEFAULT_VALUE(VariableType.ARRAY_OBJECT)

    @staticmethod
    @abstractmethod
    def _convert_array_file(var):
        """Platform-specific file-list conversion."""
        pass

View File

@@ -0,0 +1,4 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/25 18:20

View File

@@ -0,0 +1,659 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/25 18:21
import base64
import re
from typing import Any
from urllib.parse import quote
from app.core.workflow.adapters.base_converter import BaseConverter
from app.core.workflow.adapters.errors import UnsupportVariableType, UnknowModelWarning, ExceptionDefineition, \
ExceptionType
from app.core.workflow.nodes.assigner import AssignerNodeConfig
from app.core.workflow.nodes.assigner.config import AssignmentItem
from app.core.workflow.nodes.base_config import VariableDefinition
from app.core.workflow.nodes.code import CodeNodeConfig
from app.core.workflow.nodes.code.config import InputVariable, OutputVariable
from app.core.workflow.nodes.configs import StartNodeConfig, LLMNodeConfig
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig, IterationNodeConfig
from app.core.workflow.nodes.cycle_graph.config import ConditionDetail as LoopConditionDetail, ConditionsConfig, \
CycleVariable
from app.core.workflow.nodes.end import EndNodeConfig
from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, AssignmentOperator, HttpAuthType, \
HttpContentType, HttpErrorHandle
from app.core.workflow.nodes.http_request import HttpRequestNodeConfig
from app.core.workflow.nodes.http_request.config import HttpAuthConfig, HttpContentTypeConfig, HttpFormData, \
HttpTimeOutConfig, HttpRetryConfig, HttpErrorDefaultTamplete, HttpErrorHandleConfig
from app.core.workflow.nodes.if_else import IfElseNodeConfig
from app.core.workflow.nodes.if_else.config import ConditionDetail, ConditionBranchConfig
from app.core.workflow.nodes.jinja_render import JinjaRenderNodeConfig
from app.core.workflow.nodes.jinja_render.config import VariablesMappingConfig
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
from app.core.workflow.nodes.llm.config import MemoryWindowSetting, MessageConfig
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNodeConfig
from app.core.workflow.nodes.parameter_extractor.config import ParamsConfig
from app.core.workflow.nodes.question_classifier import QuestionClassifierNodeConfig
from app.core.workflow.nodes.question_classifier.config import ClassifierConfig
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNodeConfig
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
class DifyConverter(BaseConverter):
errors: list
warnings: list
branch_node_cache: dict
error_branch_node_cache: list
def __init__(self):
    """Initialize per-run conversion state and the node-type dispatch table.

    FIX: the class annotates errors/warnings/branch_node_cache/
    error_branch_node_cache but the constructor never assigned them, so any
    convert_* call appending to them would raise AttributeError unless a
    caller injected the attributes first. They are initialized here; callers
    may still overwrite them (e.g. to share lists with an adapter).
    """
    from collections import defaultdict  # local: module has no top-level import
    self.errors = []
    self.warnings = []
    # branch node id -> ordered list of branch/case ids, for edge remapping
    self.branch_node_cache = defaultdict(list)
    # node ids that expose a fail branch
    self.error_branch_node_cache = []
    # Dify node "type" -> converter bound method. Loop/iteration scaffold
    # nodes carry no local config, hence the empty-dict lambdas.
    self.CONFIG_CONVERT_MAP = {
        "start": self.convert_start_node_config,
        "llm": self.convert_llm_node_config,
        "answer": self.convert_end_node_config,
        "if-else": self.convert_if_else_node_config,
        "loop": self.convert_loop_node_config,
        "iteration": self.convert_iteration_node_config,
        "assigner": self.convert_assigner_node_config,
        "code": self.convert_code_node_config,
        "http-request": self.convert_http_node_config,
        "template-transform": self.convert_jinja_render_node_config,
        "knowledge-retrieval": self.convert_knowledge_node_config,
        "parameter-extractor": self.convert_parameter_extractor_node_config,
        "question-classifier": self.convert_question_classifier_node_config,
        "variable-aggregator": self.convert_variable_aggregator,
        "loop-start": lambda x: {},
        "iteration-start": lambda x: {},
        "loop-end": lambda x: {},
    }
def get_node_convert(self, node_type):
    """Return the converter callable for *node_type*, or None when unsupported."""
    return self.CONFIG_CONVERT_MAP.get(node_type)
@staticmethod
def is_variable(expression) -> bool:
    """True when *expression* begins with a Dify variable reference ({{#...#}})."""
    return re.match(r"\{\{#(.*?)#}}", expression) is not None
@staticmethod
def process_var_selector(var_selector):
    """Translate a dotted two-part Dify selector into the local naming scheme.

    "conversation.*" becomes "conv.*" and "sys.query" is aliased to
    "sys.message"; an empty selector yields "".
    """
    if not var_selector:
        return ""
    parts = var_selector.split('.')
    if len(parts) != 2:
        raise Exception(f"invalid variable selector: {var_selector}")
    scope, name = parts
    if scope == "conversation":
        scope = "conv"
    joined = f"{scope}.{name}"
    aliases = {
        "sys.query": "sys.message"
    }
    return aliases.get(joined, joined)
def _process_list_variable_litearl(self, variable_selector: list) -> str | None:
    """Render a list-form selector as a "{{scope.name}}" literal.

    Returns None when the selector is empty/unresolvable.
    FIX: the original called process_var_selector twice on the same input;
    it is computed once here (the helper is pure, so behavior is unchanged).
    """
    processed = self.process_var_selector(".".join(variable_selector))
    if not processed:
        return None
    return "{{" + processed + "}}"
def trans_variable_format(self, content):
    """Rewrite every Dify {{#scope.name#}} reference in *content* to the
    local {{scope.name}} form."""
    def _swap(found: re.Match) -> str:
        local_name = self.process_var_selector(found.group(1))
        return "{{" + local_name + "}}"
    return re.sub(r"\{\{#(.*?)#}}", _swap, content)
@staticmethod
def _convert_file(var):
    # File values cannot be migrated from a Dify export; no-op stub (returns None).
    pass
@staticmethod
def _convert_array_file(var):
    # File-list values cannot be migrated from a Dify export; no-op stub (returns None).
    pass
@staticmethod
def variable_type_map(source_type) -> VariableType | None:
    """Map a Dify input/variable type token to the local VariableType.

    NOTE(review): unmapped tokens fall through unchanged (the declared
    ``| None`` return only occurs for falsy input) — callers that test
    ``if not var_type`` rely on that; confirm this passthrough is intended.
    """
    type_map = {
        "file": VariableType.FILE,
        "paragraph": VariableType.STRING,
        "text-input": VariableType.STRING,
        "number": VariableType.NUMBER,
        "checkbox": VariableType.BOOLEAN,
        "file-list": VariableType.ARRAY_FILE,
        "select": VariableType.STRING,
    }
    var_type = type_map.get(source_type, source_type)
    return var_type
def convert_variable_type(self, target_type: VariableType, origin_value: Any):
if not origin_value:
return DEFAULT_VALUE(target_type)
try:
match target_type:
case VariableType.STRING:
return self._convert_string(origin_value)
case VariableType.NUMBER:
return self._convert_number(origin_value)
case VariableType.BOOLEAN:
return self._convert_boolean(origin_value)
case VariableType.FILE:
return self._convert_file(origin_value)
case VariableType.ARRAY_FILE:
return self._convert_array_file(origin_value)
case _:
return origin_value
except:
raise Exception(f"convert variable failed: {target_type}")
@staticmethod
def convert_compare_operator(operator):
    """Map a Dify comparison operator token to the local ComparisonOperator.

    Unknown tokens pass through unchanged.
    FIX: three symbolic keys had been mojibake'd to duplicate empty strings
    (so only the last survived and ""→LE, while ≠/≥ never matched); restored
    to the Unicode operators the Dify DSL emits. NOTE(review): verify the
    exact glyphs against the Dify export schema.
    """
    operator_map = {
        "is": ComparisonOperator.EQ,
        "is not": ComparisonOperator.NE,
        "=": ComparisonOperator.EQ,
        "≠": ComparisonOperator.NE,
        ">": ComparisonOperator.GT,
        "<": ComparisonOperator.LT,
        "≥": ComparisonOperator.GE,
        "≤": ComparisonOperator.LE,
        "not empty": ComparisonOperator.NOT_EMPTY,
    }
    return operator_map.get(operator, operator)
@staticmethod
def convert_assignment_operator(operator):
    """Translate a Dify assignment-operation token into AssignmentOperator;
    unknown tokens pass through unchanged."""
    translated = {
        "+=": AssignmentOperator.ADD,
        "-=": AssignmentOperator.SUBTRACT,
        "*=": AssignmentOperator.MULTIPLY,
        "/=": AssignmentOperator.DIVIDE,
        "over-write": AssignmentOperator.COVER,
        "remove-last": AssignmentOperator.REMOVE_LAST,
        "remove-first": AssignmentOperator.REMOVE_FIRST,
    }.get(operator)
    return operator if translated is None else translated
@staticmethod
def convert_http_auth_type(auth_type):
    """Translate a Dify auth-type token into HttpAuthType; unknown tokens
    pass through unchanged."""
    translated = {
        "no-auth": HttpAuthType.NONE,
        "bearer": HttpAuthType.BEARER,
        "basic": HttpAuthType.BASIC,
        "custom": HttpAuthType.CUSTOM,
    }.get(auth_type)
    return auth_type if translated is None else translated
@staticmethod
def convert_http_content_type(content_type):
    """Translate a Dify body-type token into HttpContentType; unknown tokens
    pass through unchanged."""
    translated = {
        "none": HttpContentType.NONE,
        "form-data": HttpContentType.FROM_DATA,
        "x-www-form-urlencoded": HttpContentType.WWW_FORM,
        "json": HttpContentType.JSON,
        "raw-text": HttpContentType.RAW,
        "binary": HttpContentType.BINARY,
    }.get(content_type)
    return content_type if translated is None else translated
@staticmethod
def convert_http_error_handle_type(handle_type):
    """Translate a Dify error-strategy token into HttpErrorHandle; unknown
    tokens pass through unchanged."""
    translated = {
        "none": HttpErrorHandle.NONE,
        "fail-branch": HttpErrorHandle.BRANCH,
        "default-value": HttpErrorHandle.DEFAULT,
    }.get(handle_type)
    return handle_type if translated is None else translated
def convert_start_node_config(self, node: dict) -> dict:
    """Convert a Dify `start` node into the local StartNodeConfig dict.

    Variables with no local type mapping are recorded in self.errors and
    skipped; file-typed inputs are rejected for start nodes.
    """
    node_data = node["data"]
    start_vars = []
    for var in node_data["variables"]:
        var_type = self.variable_type_map(var["type"])
        if not var_type:
            # No local mapping -> record the problem and drop this variable.
            self.errors.append(
                UnsupportVariableType(
                    scope=node["id"],
                    name=var["variable"],
                    var_type=var["type"],
                    node_id=node["id"],
                    node_name=node_data["title"]
                )
            )
            continue
        # NOTE(review): compares the mapped type against raw strings; this
        # only matches if VariableType is a str-enum whose FILE/ARRAY_FILE
        # values are exactly "file"/"array[file]" — confirm.
        if var_type in ["file", "array[file]"]:
            self.errors.append(
                ExceptionDefineition(
                    type=ExceptionType.VARIABLE,
                    node_id=node["id"],
                    node_name=node_data["title"],
                    name=var["variable"],
                    detail=f"Unsupport Variable type for start node: {var_type}"
                )
            )
            continue
        var_def = VariableDefinition(
            name=var["variable"],
            type=var_type,
            required=var["required"],
            # Defaults are coerced to the mapped type (falsy -> type default).
            default=self.convert_variable_type(
                var_type, var["default"]
            ),
            description=var["label"],
            max_length=var.get("max_length"),
        )
        start_vars.append(var_def)
    return StartNodeConfig(
        variables=start_vars
    ).model_dump()
def convert_question_classifier_node_config(self, node: dict) -> dict:
    """Convert a Dify question-classifier node.

    Always records an UnknowModelWarning (the source model cannot be mapped
    to a local model id). Class/branch ids are cached in order so outgoing
    edges can be rewired later.
    """
    node_data = node["data"]
    self.warnings.append(
        UnknowModelWarning(
            node_id=node["id"],
            node_name=node_data["title"],
            model_name=node_data["model"].get("name")
        )
    )
    categories = []
    for category in node_data["classes"]:
        # Order matters: branch ids must line up with the produced categories.
        self.branch_node_cache[node["id"]].append(category["id"])
        categories.append(
            ClassifierConfig(
                class_name=category["name"],
            )
        )
    return QuestionClassifierNodeConfig.model_construct(
        input_variable=self._process_list_variable_litearl(node_data["query_variable_selector"]),
        user_supplement_prompt=self.trans_variable_format(node_data["instructions"]),
        categories=categories,
    ).model_dump()
def convert_llm_node_config(self, node: dict) -> dict:
    """Convert a Dify llm node into the local LLMNodeConfig dict.

    The source model cannot be mapped to a local model id (model_id=None),
    so an UnknowModelWarning is always recorded. Memory, vision and prompt
    messages are carried over with variable references rewritten.
    """
    node_data = node["data"]
    self.warnings.append(
        UnknowModelWarning(
            node_id=node["id"],
            node_name=node_data["title"],
            model_name=node_data["model"].get("name")
        )
    )
    context = self._process_list_variable_litearl(node_data["context"]["variable_selector"])
    # Memory is considered enabled whenever a truthy "memory" mapping exists.
    memory = MemoryWindowSetting(
        enable=bool(node_data.get("memory")),
        enable_window=bool(node_data.get("memory", {}).get("window", {}).get("enabled", False)),
        window_size=int(node_data.get("memory", {}).get("window", {}).get("size", 20))
    )
    messages = []
    for message in node_data["prompt_template"]:
        messages.append(
            MessageConfig(
                role=message["role"],
                content=self.trans_variable_format(message["text"])
            )
        )
    # With memory enabled, Dify keeps the user turn in a separate template;
    # append it as a trailing user message.
    if memory.enable:
        messages.append(
            MessageConfig(
                role="user",
                content=self.trans_variable_format(node_data["memory"]["query_prompt_template"])
            )
        )
    vision = node_data["vision"]["enabled"]
    # The vision input selector only exists when vision is enabled.
    vision_input = self._process_list_variable_litearl(
        node_data["vision"]["configs"]["variable_selector"]
    ) if vision else None
    return LLMNodeConfig.model_construct(
        model_id=None,  # must be re-selected by the user after import
        context=context,
        memory=memory,
        vision=vision,
        vision_input=vision_input,
        messages=messages
    ).model_dump()
def convert_end_node_config(self, node: dict) -> dict:
    """Map a Dify `answer` node onto the local End node config."""
    answer_template = self.trans_variable_format(node["data"]["answer"])
    return EndNodeConfig(output=answer_template).model_dump()
def convert_if_else_node_config(self, node: dict) -> dict:
    """Convert a Dify if-else node; case ids are cached in order so edges
    can be matched to the produced branches."""
    node_data = node["data"]
    cases = []
    for case in node_data["cases"]:
        case_id = case["id"]
        logical_operator = case["logical_operator"]
        conditions = []
        for condition in case["conditions"]:
            right_value = condition["value"]
            # A string right-hand side in {{#...#}} form is a variable
            # reference; anything else is coerced to a typed constant.
            condition_detail = ConditionDetail(
                operator=self.convert_compare_operator(condition["comparison_operator"]),
                left="{{" + self.process_var_selector(".".join(condition["variable_selector"])) + "}}",
                right=self.trans_variable_format(
                    right_value
                ) if isinstance(right_value, str) and self.is_variable(right_value) else self.convert_variable_type(
                    self.variable_type_map(condition["varType"]),
                    condition["value"]
                ),
                input_type=ValueInputType.VARIABLE
                if isinstance(right_value, str) and self.is_variable(right_value) else ValueInputType.CONSTANT,
            )
            conditions.append(condition_detail)
        cases.append(
            ConditionBranchConfig(
                logical_operator=logical_operator,
                expressions=conditions
            )
        )
        # Order matters: case ids must line up with the produced cases.
        self.branch_node_cache[node["id"]].append(case_id)
    return IfElseNodeConfig(
        cases=cases
    ).model_dump()
def convert_loop_node_config(self, node: dict) -> dict:
    """Convert a Dify loop node: break conditions, loop variables and the
    maximum iteration count."""
    node_data = node["data"]
    logical_operator = node_data["logical_operator"]
    conditions = []
    for condition in node_data["break_conditions"]:
        right_value = condition["value"]
        # String values in {{#...#}} form are variable refs; others are
        # coerced to typed constants.
        conditions.append(
            LoopConditionDetail(
                operator=self.convert_compare_operator(condition["comparison_operator"]),
                left=self._process_list_variable_litearl(condition["variable_selector"]),
                right=self.trans_variable_format(
                    right_value
                ) if isinstance(right_value, str) and self.is_variable(right_value) else self.convert_variable_type(
                    self.variable_type_map(condition["varType"]),
                    condition["value"]
                ),
                input_type=ValueInputType.VARIABLE
                if isinstance(right_value, str) and self.is_variable(right_value) else ValueInputType.CONSTANT,
            )
        )
    condition_config = ConditionsConfig(
        logical_operator=logical_operator,
        expressions=conditions
    )
    loop_variables = []
    for variable in node_data["loop_variables"]:
        right_input_type = variable["value_type"]
        right_value_type = self.variable_type_map(variable["var_type"])
        # Variable-typed initial values keep their selector; constants are coerced.
        if right_input_type == ValueInputType.VARIABLE:
            right_value = self._process_list_variable_litearl(variable["value"])
        else:
            right_value = self.convert_variable_type(right_value_type, variable["value"])
        loop_variables.append(
            CycleVariable(
                name=variable["label"],
                type=right_value_type,
                value=right_value,
                input_type=right_input_type
            )
        )
    return LoopNodeConfig(
        condition=condition_config,
        cycle_vars=loop_variables,
        max_loop=node_data["loop_count"]
    ).model_dump()
def convert_iteration_node_config(self, node: dict) -> dict:
    """Map a Dify iteration node onto the local Iteration node config."""
    data = node["data"]
    config = IterationNodeConfig(
        input=self._process_list_variable_litearl(data["iterator_selector"]),
        parallel=data["is_parallel"],
        parallel_count=data["parallel_nums"],
        output=self._process_list_variable_litearl(data["output_selector"]),
        output_type=self.variable_type_map(data["output_type"]),
        flatten=data["flatten_output"],
    )
    return config.model_dump()
def convert_assigner_node_config(self, node: dict) -> dict:
    """Map a Dify variable-assigner node; items missing an operation or a
    value are silently skipped."""
    data = node["data"]
    assignments = []
    for item in data["items"]:
        if item.get("operation") is None or item.get("value") is None:
            continue
        if item["input_type"] == ValueInputType.VARIABLE:
            assigned_value = self._process_list_variable_litearl(item["value"])
        else:
            assigned_value = item["value"]
        assignments.append(
            AssignmentItem(
                variable_selector=self._process_list_variable_litearl(item["variable_selector"]),
                value=assigned_value,
                operation=self.convert_assignment_operator(item["operation"])
            )
        )
    return AssignerNodeConfig(assignments=assignments).model_dump()
def convert_code_node_config(self, node: dict) -> dict:
    """Map a Dify code node; source is URL-quoted, then base64-encoded."""
    data = node["data"]
    inputs = [
        InputVariable(
            name=item["variable"],
            variable=self._process_list_variable_litearl(item["value_selector"]),
        )
        for item in data["variables"]
    ]
    outputs = [
        OutputVariable(name=out_name, type=out_spec["type"])
        for out_name, out_spec in data["outputs"].items()
    ]
    encoded_code = base64.b64encode(quote(data["code"]).encode("utf-8")).decode("utf-8")
    return CodeNodeConfig(
        input_variables=inputs,
        language=data["code_language"],
        output_variables=outputs,
        code=encoded_code
    ).model_dump()
def convert_http_node_config(self, node: dict) -> dict:
    """Convert a Dify http-request node into the local HttpRequestNodeConfig dict.

    Covers auth, body encodings (form-data, urlencoded, raw/json/binary),
    newline-separated "key:value" header/param text, SSL verification,
    timeout/retry settings and the error-handling strategy. Malformed
    header/param lines produce warnings instead of failing the import.
    """
    node_data = node["data"]
    # --- authentication ---
    # FIX: the original compared the whole authorization dict against the
    # string 'no-auth' (always unequal, so the no-auth branch was dead code);
    # the declared auth mode lives under authorization["type"].
    if node_data["authorization"].get("type") != 'no-auth':
        auth_type = self.convert_http_auth_type(node_data["authorization"]["config"]["type"])
        auth_config = HttpAuthConfig(
            auth_type=auth_type,
            header=node_data["authorization"]["config"].get("header"),
            api_key=node_data["authorization"]["config"].get("api_key"),
        )
    else:
        auth_config = HttpAuthConfig()
    # --- request body ---
    content_type = self.convert_http_content_type(node_data["body"]["type"])
    if content_type == HttpContentType.FROM_DATA:
        body_content = []
        for content in node_data["body"]["data"]:
            body_content.append(
                HttpFormData(
                    key=self.trans_variable_format(content["key"]),
                    type=content["type"],
                    value=self.trans_variable_format(content["value"]),
                )
            )
    elif content_type == HttpContentType.WWW_FORM:
        body_content = {}
        for content in node_data["body"]["data"]:
            body_content[
                self.trans_variable_format(content["key"])
            ] = self.trans_variable_format(content["value"])
    else:
        # Raw/json/binary bodies: Dify stores a single value entry.
        if node_data["body"]["data"]:
            body_content = node_data["body"]["data"][0]["value"]
        else:
            body_content = ""
    # --- headers (newline-separated "key:value" text) ---
    headers = {}
    for header in node_data["headers"].split("\n"):
        if not header:
            continue
        key_value = header.split(":")
        if len(key_value) == 2:
            headers[
                self.trans_variable_format(key_value[0])
            ] = self.trans_variable_format(key_value[1])
        else:
            self.warnings.append(ExceptionDefineition(
                type=ExceptionType.CONFIG,
                node_id=node["id"],
                node_name=node_data["title"],
                detail=f"Invalid header/param - {header}",
            ))
    # --- query params (same newline-separated format) ---
    params = {}
    for param in node_data["params"].split("\n"):
        if not param:
            continue
        key_value = param.split(":")
        if len(key_value) == 2:
            params[
                self.trans_variable_format(key_value[0])
            ] = self.trans_variable_format(key_value[1])
        else:
            self.warnings.append(ExceptionDefineition(
                type=ExceptionType.CONFIG,
                node_id=node["id"],
                node_name=node_data["title"],
                detail=f"Invalid header/param - {param}",
            ))
    # --- error handling strategy ---
    error_handle_type = self.convert_http_error_handle_type(
        node_data.get("error_strategy", "none")
    )
    default_value = None
    if error_handle_type == HttpErrorHandle.DEFAULT:
        default_body = ""
        default_header = {}
        default_status_code = 0
        for var in node_data["default_value"]:
            if var["key"] == "body":
                default_body = var["value"]
            elif var["key"] == "header":
                default_header = var["value"]
            elif var["key"] == "status_code":
                default_status_code = var["value"]
        default_value = HttpErrorDefaultTamplete(
            body=default_body,
            headers=default_header,
            status_code=default_status_code,
        )
    # NOTE(review): registered unconditionally here; if the error-branch
    # cache should only contain fail-branch nodes, guard this with
    # error_handle_type == HttpErrorHandle.BRANCH — confirm with the edge
    # remapping logic.
    self.error_branch_node_cache.append(node['id'])
    return HttpRequestNodeConfig(
        method=node_data["method"].upper(),
        url=node_data["url"],
        auth=auth_config,
        body=HttpContentTypeConfig(
            content_type=self.convert_http_content_type(node_data["body"]["type"]),
            data=body_content,
        ),
        headers=headers,
        params=params,
        verify_ssl=node_data["ssl_verify"],
        timeouts=HttpTimeOutConfig(
            connect_timeout=node_data["timeout"]["max_connect_timeout"] or 5,
            read_timeout=node_data["timeout"]["max_read_timeout"] or 5,
            write_timeout=node_data["timeout"]["max_write_timeout"] or 5,
        ),
        retry=HttpRetryConfig(
            enable=node_data["retry_config"]["retry_enabled"],
            max_attempts=node_data["retry_config"]["max_retries"],
            retry_interval=node_data["retry_config"]["retry_interval"],
        ),
        error_handle=HttpErrorHandleConfig(
            method=error_handle_type,
            default=default_value,
        )
    ).model_dump()
def convert_jinja_render_node_config(self, node: dict) -> dict:
    """Translate a Dify template-transform node into a Jinja render node config.

    Each Dify variable/value-selector pair becomes one entry of the internal
    variables mapping; the template string is carried over unchanged.
    """
    node_data = node["data"]
    mapping = [
        VariablesMappingConfig(
            name=item["variable"],
            value=self._process_list_variable_litearl(item["value_selector"]),
        )
        for item in node_data["variables"]
    ]
    return JinjaRenderNodeConfig(
        template=node_data["template"],
        mapping=mapping,
    ).model_dump()
def convert_knowledge_node_config(self, node: dict) -> dict:
    """Convert a Dify knowledge-retrieval node into the internal config.

    Only the query selector can be mapped automatically; the knowledge-base
    binding is platform-specific, so a CONFIG warning asks the user to
    reconfigure the node after import.
    """
    node_data = node["data"]
    self.warnings.append(ExceptionDefineition(
        node_id=node["id"],
        node_name=node_data["title"],
        type=ExceptionType.CONFIG,
        # Fixed: was an f-string with no placeholders (ruff F541); the
        # literal value is unchanged.
        detail="Please reconfigure the Knowledge Retrieval node.",
    ))
    # model_construct skips validation: the required knowledge_bases field
    # cannot be populated from the Dify DSL here.
    return KnowledgeRetrievalNodeConfig.model_construct(
        query=self._process_list_variable_litearl(node_data["query_variable_selector"]),
    ).model_dump()
def convert_parameter_extractor_node_config(self, node: dict) -> dict:
    """Map a Dify parameter-extractor node onto the internal config model."""
    node_data = node["data"]
    # The concrete model cannot be resolved automatically; surface a warning
    # so the user can pick a model mapping by hand after import.
    self.warnings.append(
        UnknowModelWarning(
            node_id=node["id"],
            node_name=node_data["title"],
            model_name=node_data["model"].get("name")
        )
    )
    params = [
        ParamsConfig(
            name=item["name"],
            desc=item["description"],
            required=item["required"],
            type=item["type"],
        )
        for item in node_data["parameters"]
    ]
    # model_construct skips validation: some required fields are unknown here.
    return ParameterExtractorNodeConfig.model_construct(
        text=self._process_list_variable_litearl(node_data["query"]),
        params=params,
        prompt=node_data["instruction"]
    ).model_dump()
def convert_variable_aggregator(self, node: dict) -> dict:
    """Convert a Dify variable-aggregator node, with or without grouping."""
    node_data = node["data"]
    grouping_on = node_data["advanced_settings"]["group_enabled"]
    group_variables = {}
    group_type = {}
    if grouping_on:
        for group in node_data["advanced_settings"]["groups"]:
            group_name = group["group_name"]
            group_variables[group_name] = [
                self._process_list_variable_litearl(selector)
                for selector in group["variables"]
            ]
            group_type[group_name] = group["output_type"]
    else:
        # Grouping disabled: everything aggregates into one implicit
        # "output" group.
        group_variables["output"] = [
            self._process_list_variable_litearl(selector)
            for selector in node_data["variables"]
        ]
        group_type["output"] = node_data["output_type"]
    return VariableAggregatorNodeConfig(
        group=grouping_on,
        group_variables=group_variables,
        group_type=group_type,
    ).model_dump()

View File

@@ -0,0 +1,239 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/24 16:05
from typing import Any
from app.core.logging_config import get_logger
from app.core.workflow.adapters.base_adapter import (
BasePlatformAdapter,
PlatformMetadata,
PlatformType,
WorkflowParserResult
)
from app.core.workflow.adapters.dify.converter import DifyConverter
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType
from app.core.workflow.nodes.enums import NodeType
from app.schemas.workflow_schema import (
NodeDefinition,
EdgeDefinition,
VariableDefinition,
TriggerConfig,
ExecutionConfig
)
logger = get_logger()
class DifyAdapter(BasePlatformAdapter, DifyConverter):
    """Adapter that parses a Dify DSL export into the internal workflow model.

    ``BasePlatformAdapter`` supplies the adapter contract (config storage,
    result accumulators); ``DifyConverter`` supplies the per-node-type
    conversion helpers used by ``_convert_node_config``.
    """

    # Dify "data.type" -> internal NodeType. Types absent from this map are
    # unsupported: map_node_type returns None and the node is dropped.
    NODE_TYPE_MAPPING = {
        "start": NodeType.START,
        "llm": NodeType.LLM,
        "answer": NodeType.END,
        "if-else": NodeType.IF_ELSE,
        "loop-start": NodeType.CYCLE_START,
        "iteration-start": NodeType.CYCLE_START,
        "assigner": NodeType.ASSIGNER,
        "loop": NodeType.LOOP,
        "iteration": NodeType.ITERATION,
        "loop-end": NodeType.BREAK,
        "code": NodeType.CODE,
        "http-request": NodeType.HTTP_REQUEST,
        "template-transform": NodeType.JINJARENDER,
        "knowledge-retrieval": NodeType.KNOWLEDGE_RETRIEVAL,
        "parameter-extractor": NodeType.PARAMETER_EXTRACTOR,
        "question-classifier": NodeType.QUESTION_CLASSIFIER,
        "variable-aggregator": NodeType.VAR_AGGREGATOR
    }

    def __init__(self, config: dict[str, Any]):
        # Initialize both bases explicitly: DifyConverter sets up conversion
        # caches, BasePlatformAdapter stores the raw config.
        DifyConverter.__init__(self)
        BasePlatformAdapter.__init__(self, config)

    def get_metadata(self) -> PlatformMetadata:
        """Describe this adapter: platform, supported DSL version, node types."""
        return PlatformMetadata(
            platform_name=PlatformType.DIFY,
            version="0.5.0",
            support_node_types=list(self.NODE_TYPE_MAPPING.keys())
        )

    def map_node_type(self, platform_node_type) -> str:
        # Returns None for unsupported types; NodeDefinition validation then
        # fails inside _convert_node and the node is skipped.
        return self.NODE_TYPE_MAPPING.get(platform_node_type)

    @property
    def origin_nodes(self):
        # Raw node list from the Dify graph section of the DSL.
        return self.config.get("workflow").get("graph").get("nodes")

    @property
    def origin_edges(self):
        # Raw edge list from the Dify graph section of the DSL.
        return self.config.get("workflow").get("graph").get("edges")

    @staticmethod
    def _valid_nodes(node: dict[str, Any]):
        """Return True when *node* carries the keys conversion relies on."""
        if "data" not in node:
            return False
        if "type" not in node["data"]:
            return False
        if "id" not in node or "type" not in node:
            return False
        return True

    def validate_config(self) -> bool:
        """Check top-level DSL fields and the basic structure of every node."""
        require_fields = frozenset({'app', 'dependencies', 'kind', 'version', 'workflow'})
        if not all(field in self.config for field in require_fields):
            return False
        for node in self.origin_nodes:
            if not self._valid_nodes(node):
                return False
        return True

    def parse_workflow(self) -> WorkflowParserResult:
        """Convert nodes, edges and conversation variables, collecting issues.

        Returns a WorkflowParserResult whose ``success`` flag is True only
        when no errors AND no warnings were recorded.
        """
        for node in self.origin_nodes:
            node = self._convert_node(node)
            if node:
                self.nodes.append(node)
        nodes_id = [node.id for node in self.nodes]
        for edge in self.origin_edges:
            source = edge["source"]
            target = edge["target"]
            # Drop edges whose endpoints were not converted successfully.
            if source not in nodes_id or target not in nodes_id:
                continue
            edge = self._convert_edge(edge)
            if edge:
                self.edges.append(edge)
        for variable in self.config.get("workflow").get("conversation_variables"):
            con_var = self._convert_variable(variable)
            # BUGFIX: previously tested `variable` (the source dict, always
            # truthy), so a failed conversion appended None to conv_variables.
            if con_var is not None:
                self.conv_variables.append(con_var)
        # NOTE(review): Dify environment_variables are intentionally not
        # imported yet (see commented-out block in the original).
        trigger = self._convert_trigger({})
        execution_config = self._convert_execution({})
        return WorkflowParserResult(
            success=not self.errors and not self.warnings,
            platform=self.get_metadata(),
            execution_config=execution_config,
            origin_config=self.config,
            trigger=trigger,
            edges=self.edges,
            nodes=self.nodes,
            variables=self.conv_variables,
            warnings=self.warnings,
            errors=self.errors
        )

    def _convert_cycle_node_position(self, node_id: str, position: dict):
        """Translate a child node's relative position into absolute coordinates.

        Dify stores children of loop/iteration nodes with positions relative
        to the parent; add the parent's own position to get the absolute one.

        Raises:
            Exception: when the parent cycle node cannot be found (also
            recorded in self.errors).
        """
        for node in self.origin_nodes:
            if node["id"] == node_id:
                return {
                    "x": node["position"]["x"] + position["x"],
                    "y": node["position"]["y"] + position["y"]
                }
        self.errors.append(
            ExceptionDefineition(
                type=ExceptionType.NODE,
                node_id=node_id,
                detail="parent cycle node not found"
            )
        )
        raise Exception("parent cycle node not found")

    def _convert_node(self, node: dict[str, Any]) -> NodeDefinition | None:
        """Convert one Dify node; return None (and log) on any failure."""
        node_data = node["data"]
        try:
            return NodeDefinition(
                id=node["id"],
                type=self.map_node_type(node_data["type"]),
                name=node_data.get("title"),
                cycle=node.get("parentId"),
                description=None,
                config=self._convert_node_config(node),
                position={
                    "x": node["position"]["x"],
                    "y": node["position"]["y"]
                } if node.get("parentId") is None else self._convert_cycle_node_position(
                    node["parentId"],
                    node["position"]
                ),
                error_handling=None,
                cache=None
            )
        except Exception as e:
            logger.debug(f"convert node error - {e}", exc_info=True)

    def _convert_node_config(self, node: dict):
        """Dispatch to the DifyConverter helper for this node's type.

        Records an ExceptionDefineition and re-raises on failure so the
        caller (_convert_node) can skip the node.
        """
        node_data = node["data"]
        node_type = node_data["type"]
        try:
            converter = self.get_node_convert(node_type)
            if converter is None:
                raise Exception(f"node type not supported - {node_type}")
            return converter(node)
        except Exception as e:
            self.errors.append(ExceptionDefineition(
                type=ExceptionType.NODE,
                node_id=node["id"],
                node_name=node["data"]["title"],
                detail=f"convert node error - {e}",
            ))
            raise e

    def _convert_edge(self, edge: dict[str, Any]) -> EdgeDefinition | None:
        """Convert one Dify edge, labelling branch and error-handling edges.

        Branch sources get CASE<n> labels derived from the case id embedded
        in the edge id; HTTP error-branch sources get SUCCESS/ERROR labels.
        Returns None (and records the error) on failure.
        """
        try:
            source = edge["source"]
            target = edge["target"]
            edge_id = edge["id"]
            label = None
            if source in self.branch_node_cache:
                # Dify encodes the case id in the middle of the edge id.
                case_id = "-".join(edge_id.split("-")[1:-2])
                if case_id == "false":
                    # The else branch is numbered after all explicit cases.
                    label = f'CASE{len(self.branch_node_cache[source])+1}'
                else:
                    label = f'CASE{self.branch_node_cache[source].index(case_id) + 1}'
            if source in self.error_branch_node_cache:
                case_id = "-".join(edge_id.split("-")[1:-2])
                if case_id == "source":
                    label = "SUCCESS"
                else:
                    label = "ERROR"
            return EdgeDefinition(
                id=edge["id"],
                source=source,
                target=target,
                label=label,
            )
        except Exception as e:
            self.errors.append(ExceptionDefineition(
                type=ExceptionType.EDGE,
                detail=f"convert edge error - {e}",
            ))
            return None

    def _convert_variable(self, variable) -> VariableDefinition | None:
        """Convert one conversation variable; return None on failure."""
        try:
            return VariableDefinition(
                name=variable["name"],
                default=variable["value"],
                type=variable["value_type"],
            )
        except Exception as e:
            self.errors.append(ExceptionDefineition(
                type=ExceptionType.VARIABLE,
                name=variable.get("name"),
                detail=f"convert variable error - {e}",
            ))

    def _convert_trigger(self, trigger: dict[str, Any]) -> TriggerConfig | None:
        # Dify DSL carries no trigger information; nothing to convert.
        pass

    def _convert_execution(self, execution: dict[str, Any]) -> ExecutionConfig:
        # No execution settings in the Dify DSL; fall back to defaults.
        return ExecutionConfig()

View File

@@ -0,0 +1,75 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/26 11:29
from enum import StrEnum
from pydantic import BaseModel
class ExceptionType(StrEnum):
    """Category of a problem detected during workflow import/conversion."""
    NODE = "node"
    EDGE = "edge"
    VARIABLE = "variable"
    TRIGGER = "trigger"
    EXECUTION = "execution"
    CONFIG = "config"
    PLATFORM = "platform"
    UNKNOWN = "unknown"
class ExceptionDefineition(BaseModel):
    """Structured error/warning record emitted by workflow adapters.

    NOTE(review): the class name misspells "Definition"; it is kept as-is
    because other modules import it under this name.
    """
    type: ExceptionType
    # Human-readable description of the problem.
    detail: str
    # Populated for node-scoped problems.
    node_id: str | None = None
    node_name: str | None = None
    # Populated for variable-scoped problems.
    scope: str | None = None
    name: str | None = None
class UnknowModelWarning(ExceptionDefineition):
    """Warning: a node references a model with no known mapping."""

    type: ExceptionType = ExceptionType.NODE

    def __init__(self, node_id, node_name, model_name):
        message = f"Please specify the model mapping manually for model: {model_name}"
        super().__init__(
            detail=message,
            node_id=node_id,
            node_name=node_name
        )
class UnknowError(ExceptionDefineition):
    """Catch-all record for unexpected failures (type UNKNOWN)."""
    type: ExceptionType = ExceptionType.UNKNOWN

    def __init__(self, detail: str, **kwargs):
        super().__init__(detail=detail, **kwargs)
class UnsupportPlatform(ExceptionDefineition):
    """Error: no adapter is registered for the requested platform."""

    type: ExceptionType = ExceptionType.PLATFORM

    def __init__(self, platform: str):
        message = f"Unsupport platform {platform}"
        super().__init__(detail=message)
class UnsupportVariableType(ExceptionDefineition):
    """Error: a variable's type has no internal equivalent."""

    type: ExceptionType = ExceptionType.VARIABLE

    def __init__(self, scope, name, var_type: str, **kwargs):
        message = f"Unsupport variable type[{var_type}]"
        super().__init__(scope=scope, name=name, detail=message, **kwargs)
class InvalidConfiguration(ExceptionDefineition):
    """Error: the uploaded workflow configuration failed structural validation."""
    type: ExceptionType = ExceptionType.CONFIG

    def __init__(self):
        super().__init__(detail="Invalid workflow configuration format")
class UnsupportNodeType(ExceptionDefineition):
    """Error: a node's type cannot be converted by the adapter."""

    type: ExceptionType = ExceptionType.NODE

    def __init__(self, node_id: str, node_type: str):
        message = f"Unsupport node Type {node_type}"
        super().__init__(node_id=node_id, detail=message)

View File

@@ -0,0 +1,4 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/26 11:30

View File

@@ -0,0 +1,76 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/25 14:11
from typing import Any
from app.core.workflow.adapters.base_adapter import (
PlatformMetadata,
PlatformType,
BasePlatformAdapter,
WorkflowParserResult
)
from app.schemas.workflow_schema import ExecutionConfig
class MemoryBearAdapter(BasePlatformAdapter):
    """Identity adapter for workflows exported by this platform itself.

    The native export format already matches the internal model, so parsing
    is a pass-through of nodes, edges and variables.
    """

    # Native format needs no node-type translation.
    NODE_TYPE_MAPPING = {}

    @property
    def origin_nodes(self):
        return self.config.get("workflow").get("nodes")

    @property
    def origin_edges(self):
        return self.config.get("workflow").get("edges")

    @property
    def origin_variables(self):
        return self.config.get("workflow").get("variables")

    def get_metadata(self) -> PlatformMetadata:
        """Describe this adapter (native platform, current format version)."""
        return PlatformMetadata(
            platform_name=PlatformType.MEMORY_BEAR,
            version="0.2.5",
            support_node_types=list(self.NODE_TYPE_MAPPING.keys())
        )

    def map_node_type(self, platform_node_type) -> str:
        # Native node types are used as-is.
        return platform_node_type

    @staticmethod
    def _valid_nodes(node: dict[str, Any]):
        """Return True when *node* has the structure parsing relies on."""
        # BUGFIX: guard the "data" key before indexing into it. Previously a
        # malformed node without "data" raised KeyError instead of returning
        # False (now consistent with DifyAdapter._valid_nodes).
        if "data" not in node:
            return False
        if "type" not in node["data"]:
            return False
        if "id" not in node or "type" not in node:
            return False
        return True

    def validate_config(self) -> bool:
        """Check required top-level fields and the structure of every node."""
        require_fields = frozenset({'app', 'workflow'})
        if not all(field in self.config for field in require_fields):
            return False
        for node in self.origin_nodes:
            if not self._valid_nodes(node):
                return False
        return True

    def parse_workflow(self) -> WorkflowParserResult:
        """Pass the native config through unchanged."""
        self.nodes = self.origin_nodes
        self.edges = self.origin_edges
        self.conv_variables = self.origin_variables
        return WorkflowParserResult(
            success=True,
            platform=self.get_metadata(),
            execution_config=ExecutionConfig(),
            origin_config=self.config,
            trigger=None,
            edges=self.edges,
            nodes=self.nodes,
            variables=self.conv_variables,
            warnings=self.warnings,
            errors=self.errors,
        )

View File

@@ -0,0 +1,34 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/25 14:19
from typing import Any
from app.core.workflow.adapters import DifyAdapter, MemoryBearAdapter
from app.core.workflow.adapters.base_adapter import BasePlatformAdapter, PlatformType
class PlatformAdapterRegistry:
    """Process-wide registry mapping platform names to adapter classes."""

    _adapters: dict[str, type[BasePlatformAdapter]] = {}

    @classmethod
    def register(cls, platform: str, adapter: type[BasePlatformAdapter]):
        """Associate *adapter* with *platform*, replacing any previous entry."""
        cls._adapters[platform] = adapter

    @classmethod
    def get_adapter(cls, platform: str, config: dict[str, Any]) -> BasePlatformAdapter:
        """Instantiate the adapter registered for *platform* with *config*.

        Raises:
            ValueError: when no adapter is registered for *platform*.
        """
        try:
            adapter_cls = cls._adapters[platform]
        except KeyError:
            raise ValueError(f"Unsupported platform: {platform}") from None
        return adapter_cls(config)

    @classmethod
    def list_platforms(cls) -> list[str]:
        """Return the names of all registered platforms."""
        return list(cls._adapters.keys())

    @classmethod
    def is_supported(cls, platform: str) -> bool:
        """Return True when an adapter is registered for *platform*."""
        return platform in cls._adapters
# Register the built-in adapters at import time so callers can resolve them
# immediately via PlatformAdapterRegistry.get_adapter.
PlatformAdapterRegistry.register(PlatformType.MEMORY_BEAR, MemoryBearAdapter)
PlatformAdapterRegistry.register(PlatformType.DIFY, DifyAdapter)

View File

@@ -13,7 +13,7 @@ from app.core.workflow.engine.variable_pool import VariablePool
logger = get_logger(__name__)
SCOPE_PATTERN = re.compile(
r"\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\.[a-zA-Z0-9_]+\s*}}"
r"\{\{\s*([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\s*}}"
)

View File

@@ -88,6 +88,8 @@ class AssignerNode(BaseNode):
await operator.remove_first()
case AssignmentOperator.REMOVE_LAST:
await operator.remove_last()
case AssignmentOperator.EXTEND:
await operator.extend()
case _:
raise ValueError(f"Invalid Operator: {assignment.operation}")
logger.info(f"Node {self.node_id}: execution completed")

View File

@@ -17,17 +17,17 @@ class EndNodeConfig(BaseNodeConfig):
description="输出模板,支持引用前置节点的输出,如:{{ llm_qa.output }}"
)
# 输出变量定义
output_variables: list[VariableDefinition] = Field(
default_factory=lambda: [
VariableDefinition(
name="output",
type=VariableType.STRING,
description="工作流的最终输出"
)
],
description="输出变量定义(自动生成,通常不需要修改)"
)
# # 输出变量定义
# output_variables: list[VariableDefinition] = Field(
# default_factory=lambda: [
# VariableDefinition(
# name="output",
# type=VariableType.STRING,
# description="工作流的最终输出"
# )
# ],
# description="输出变量定义(自动生成,通常不需要修改)"
# )
class Config:
json_schema_extra = {

View File

@@ -61,6 +61,7 @@ class AssignmentOperator(StrEnum):
APPEND = "append"
REMOVE_LAST = "remove_last"
REMOVE_FIRST = "remove_first"
EXTEND = "extend"
class HttpRequestMethod(StrEnum):

View File

@@ -236,5 +236,5 @@ class HttpRequestNode(BaseNode):
logger.warning(
f"Node {self.node_id}: HTTP request failed, switching to error handling branch"
)
return "ERROR"
return {"output": "ERROR"}
raise RuntimeError("http request failed")

View File

@@ -40,7 +40,7 @@ class KnowledgeRetrievalNodeConfig(BaseNodeConfig):
)
knowledge_bases: list[KnowledgeBaseConfig] = Field(
...,
default_factory=list,
description="Knowledge base config"
)

View File

@@ -3,7 +3,6 @@
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
from app.core.workflow.variable.base_variable import VariableType
class StartNodeConfig(BaseNodeConfig):
@@ -21,42 +20,42 @@ class StartNodeConfig(BaseNodeConfig):
description="自定义输入变量列表,这些变量会作为 Start 节点的输出"
)
# 输出变量定义
output_variables: list[VariableDefinition] = Field(
default_factory=lambda: [
VariableDefinition(
name="message",
type=VariableType.STRING,
description="用户输入的消息"
),
VariableDefinition(
name="conversation_vars",
type=VariableType.OBJECT,
description="会话级变量"
),
VariableDefinition(
name="execution_id",
type=VariableType.STRING,
description="执行 ID"
),
VariableDefinition(
name="conversation_id",
type=VariableType.STRING,
description="会话 ID"
),
VariableDefinition(
name="workspace_id",
type=VariableType.STRING,
description="工作空间 ID"
),
VariableDefinition(
name="user_id",
type=VariableType.STRING,
description="用户 ID"
)
],
description="输出变量定义(自动生成,通常不需要修改)"
)
# # 输出变量定义
# output_variables: list[VariableDefinition] = Field(
# default_factory=lambda: [
# VariableDefinition(
# name="message",
# type=VariableType.STRING,
# description="用户输入的消息"
# ),
# VariableDefinition(
# name="conversation_vars",
# type=VariableType.OBJECT,
# description="会话级变量"
# ),
# VariableDefinition(
# name="execution_id",
# type=VariableType.STRING,
# description="执行 ID"
# ),
# VariableDefinition(
# name="conversation_id",
# type=VariableType.STRING,
# description="会话 ID"
# ),
# VariableDefinition(
# name="workspace_id",
# type=VariableType.STRING,
# description="工作空间 ID"
# ),
# VariableDefinition(
# name="user_id",
# type=VariableType.STRING,
# description="用户 ID"
# )
# ],
# description="输出变量定义(自动生成,通常不需要修改)"
# )
class Config:
json_schema_extra = {

View File

@@ -5,6 +5,8 @@ from enum import Enum, StrEnum
from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator
from app.schemas.workflow_schema import WorkflowConfigCreate
# ---------- Multimodal File Support ----------
@@ -196,6 +198,8 @@ class AppCreate(BaseModel):
# only for type=multi_agent
multi_agent_config: Optional[Dict[str, Any]] = None
workflow_config: Optional[WorkflowConfigCreate] = None
class AppUpdate(BaseModel):
name: Optional[str] = None

View File

@@ -18,7 +18,10 @@ class NodeConfig(BaseModel):
class NodeDefinition(BaseModel):
"""节点定义"""
id: str = Field(..., description="节点唯一标识")
type: str = Field(..., description="节点类型: start, end, llm, agent, tool, condition, loop, transform, human, code")
type: str = Field(
...,
description="节点类型: start, end, llm, agent, tool, condition, loop, transform, human, code"
)
name: str | None = Field(None, description="节点名称")
cycle: str | None = Field(None, description="父循环节点id")
description: str | None = Field(None, description="节点描述")
@@ -30,12 +33,12 @@ class NodeDefinition(BaseModel):
class EdgeDefinition(BaseModel):
"""边定义"""
id: str | None = Field(None, description="边唯一标识(可选)")
id: str | None = Field(default=None, description="边唯一标识(可选)")
source: str = Field(..., description="源节点 ID")
target: str = Field(..., description="目标节点 ID")
type: str | None = Field(None, description="边类型: normal, error")
condition: str | None = Field(None, description="条件表达式(条件边)")
label: str | None = Field(None, description="边标签")
type: str | None = Field(default=None, description="边类型: normal, error")
condition: str | None = Field(default=None, description="条件表达式(条件边)")
label: str | None = Field(default=None, description="边标签")
class VariableDefinition(BaseModel):
@@ -44,7 +47,7 @@ class VariableDefinition(BaseModel):
type: str = Field(default="string", description="变量类型: string, number, boolean, object, array")
required: bool = Field(default=False, description="是否必填")
default: Any = Field(None, description="默认值")
description: str | None = Field(None, description="变量描述")
description: str | None = Field(default=None, description="变量描述")
class ExecutionConfig(BaseModel):
@@ -61,6 +64,13 @@ class TriggerConfig(BaseModel):
config: dict[str, Any] = Field(default_factory=dict, description="触发器配置")
class WorkflowImportSave(BaseModel):
"""工作流导入请求"""
temp_id: str
name: str
description: str
# ==================== 工作流配置 ====================
class WorkflowConfigCreate(BaseModel):
@@ -84,7 +94,7 @@ class WorkflowConfigUpdate(BaseModel):
class WorkflowConfig(BaseModel):
"""工作流配置输出"""
model_config = ConfigDict(from_attributes=True)
id: uuid.UUID
app_id: uuid.UUID
nodes: list[dict[str, Any]]
@@ -95,11 +105,11 @@ class WorkflowConfig(BaseModel):
is_active: bool
created_at: datetime.datetime
updated_at: datetime.datetime
@field_serializer("created_at", when_used="json")
def _serialize_created_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
@field_serializer("updated_at", when_used="json")
def _serialize_updated_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
@@ -123,7 +133,8 @@ class WorkflowExecutionResponse(BaseModel):
output_data: dict[str, Any] | None = Field(None, description="所有节点的详细输出数据")
error_message: str | None = Field(None, description="错误信息")
elapsed_time: float | None = Field(None, description="耗时(秒)")
token_usage: dict[str, Any] | None = Field(None, description="Token 使用情况 {prompt_tokens, completion_tokens, total_tokens}")
token_usage: dict[str, Any] | None = Field(None,
description="Token 使用情况 {prompt_tokens, completion_tokens, total_tokens}")
class WorkflowExecutionStreamChunk(BaseModel):
@@ -136,7 +147,7 @@ class WorkflowExecutionStreamChunk(BaseModel):
class WorkflowExecution(BaseModel):
"""工作流执行记录输出"""
model_config = ConfigDict(from_attributes=True)
id: uuid.UUID
workflow_config_id: uuid.UUID
app_id: uuid.UUID
@@ -156,15 +167,15 @@ class WorkflowExecution(BaseModel):
token_usage: dict[str, Any] | None
meta_data: dict[str, Any]
created_at: datetime.datetime
@field_serializer("started_at", when_used="json")
def _serialize_started_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
@field_serializer("completed_at", when_used="json")
def _serialize_completed_at(self, dt: datetime.datetime | None):
return int(dt.timestamp() * 1000) if dt else None
@field_serializer("created_at", when_used="json")
def _serialize_created_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
@@ -173,7 +184,7 @@ class WorkflowExecution(BaseModel):
class WorkflowNodeExecution(BaseModel):
"""工作流节点执行记录输出"""
model_config = ConfigDict(from_attributes=True)
id: uuid.UUID
execution_id: uuid.UUID
node_id: str
@@ -193,15 +204,15 @@ class WorkflowNodeExecution(BaseModel):
cache_key: str | None
meta_data: dict[str, Any]
created_at: datetime.datetime
@field_serializer("started_at", when_used="json")
def _serialize_started_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
@field_serializer("completed_at", when_used="json")
def _serialize_completed_at(self, dt: datetime.datetime | None):
return int(dt.timestamp() * 1000) if dt else None
@field_serializer("created_at", when_used="json")
def _serialize_created_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None

View File

@@ -321,6 +321,26 @@ class AppService:
self.db.add(agent_cfg)
logger.debug("Agent 配置已创建", extra={"app_id": str(app_id)})
def _create_workflow_config(
self,
app_id: uuid.UUID,
data: app_schema.WorkflowConfigCreate,
now: datetime.datetime
):
workflow_cfg = WorkflowConfig(
id=uuid.uuid4(),
app_id=app_id,
nodes=[node.model_dump() for node in data.nodes] if data.nodes else [],
edges=[edge.model_dump() for edge in data.edges] if data.edges else [],
variables=[var.model_dump() for var in data.variables] if data.variables else [],
execution_config=data.execution_config.model_dump() if data.execution_config else {},
triggers=[trigger.model_dump() for trigger in data.triggers] if data.triggers else [],
is_active=True,
created_at=now,
updated_at=now
)
self.db.add(workflow_cfg)
def _create_multi_agent_config(
self,
app_id: uuid.UUID,
@@ -532,6 +552,9 @@ class AppService:
if app.type == "multi_agent" and data.multi_agent_config:
self._create_multi_agent_config(app.id, data.multi_agent_config, now)
if app.type == "workflow" and data.workflow_config:
self._create_workflow_config(app.id, data.workflow_config, now)
self.db.commit()
self.db.refresh(app)
@@ -968,7 +991,7 @@ class AppService:
config = self.db.scalars(stmt).first()
try:
config_memory=config.memory
config_memory = config.memory
if 'memory_content' in config_memory:
config.memory['memory_config_id'] = config.memory.pop('memory_content')
except:
@@ -1189,9 +1212,9 @@ class AppService:
# ==================== 记忆配置提取方法 ====================
def _extract_memory_config_id(
self,
app_type: str,
config: Dict[str, Any]
self,
app_type: str,
config: Dict[str, Any]
) -> Tuple[Optional[uuid.UUID], bool]:
"""从发布配置中提取 memory_config_id委托给 MemoryConfigService
@@ -1205,13 +1228,13 @@ class AppService:
- is_legacy_int: 是否检测到旧格式 int 数据,需要回退到工作空间默认配置
"""
from app.services.memory_config_service import MemoryConfigService
service = MemoryConfigService(self.db)
return service.extract_memory_config_id(app_type, config)
def _get_workspace_default_memory_config_id(
self,
workspace_id: uuid.UUID
self,
workspace_id: uuid.UUID
) -> Optional[uuid.UUID]:
"""获取工作空间的默认记忆配置ID
@@ -1222,22 +1245,22 @@ class AppService:
Optional[uuid.UUID]: 默认记忆配置ID如果不存在则返回 None
"""
from app.services.memory_config_service import MemoryConfigService
service = MemoryConfigService(self.db)
config = service.get_workspace_default_config(workspace_id)
if not config:
logger.warning(
f"工作空间没有可用的记忆配置: workspace_id={workspace_id}"
)
return None
return config.config_id
def _update_endusers_memory_config(
self,
app_id: uuid.UUID,
memory_config_id: uuid.UUID
self,
app_id: uuid.UUID,
memory_config_id: uuid.UUID
) -> int:
"""批量更新应用下所有终端用户的 memory_config_id
@@ -1249,13 +1272,13 @@ class AppService:
int: 更新的终端用户数量
"""
from app.repositories.end_user_repository import EndUserRepository
repo = EndUserRepository(self.db)
updated_count = repo.batch_update_memory_config_id(
app_id=app_id,
memory_config_id=memory_config_id
)
return updated_count
# ==================== 应用发布管理 ====================
@@ -1403,7 +1426,7 @@ class AppService:
# 提取记忆配置ID并更新终端用户
memory_config_id, is_legacy_int = self._extract_memory_config_id(app.type, config)
# 如果检测到旧格式 int 数据,回退到工作空间默认配置
if is_legacy_int and not memory_config_id:
memory_config_id = self._get_workspace_default_memory_config_id(app.workspace_id)
@@ -1412,7 +1435,7 @@ class AppService:
f"发布时使用工作空间默认记忆配置(旧数据兼容): app_id={app_id}, "
f"workspace_id={app.workspace_id}, memory_config_id={memory_config_id}"
)
if memory_config_id:
updated_count = self._update_endusers_memory_config(app_id, memory_config_id)
logger.info(
@@ -1537,7 +1560,7 @@ class AppService:
# 提取记忆配置ID并更新终端用户
memory_config_id, is_legacy_int = self._extract_memory_config_id(release.type, release.config)
# 如果检测到旧格式 int 数据,回退到工作空间默认配置
if is_legacy_int and not memory_config_id:
memory_config_id = self._get_workspace_default_memory_config_id(app.workspace_id)
@@ -1546,7 +1569,7 @@ class AppService:
f"回滚时使用工作空间默认记忆配置(旧数据兼容): app_id={app_id}, "
f"workspace_id={app.workspace_id}, memory_config_id={memory_config_id}"
)
if memory_config_id:
updated_count = self._update_endusers_memory_config(app_id, memory_config_id)
logger.info(

View File

@@ -0,0 +1,102 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/25 14:39
import json
import uuid
from typing import Any
from sqlalchemy.orm import Session
from app.aioRedis import aio_redis_set, aio_redis_get
from app.core.config import settings
from app.core.exceptions import BusinessException
from app.core.workflow.adapters.base_adapter import WorkflowImportResult, WorkflowParserResult
from app.core.workflow.adapters.errors import UnsupportPlatform, InvalidConfiguration
from app.core.workflow.adapters.registry import PlatformAdapterRegistry
from app.schemas import AppCreate
from app.schemas.workflow_schema import WorkflowConfigCreate
from app.services.app_service import AppService
from app.services.workflow_service import WorkflowService
class WorkflowImportService:
    """Imports third-party workflow DSLs via platform adapters.

    Flow: ``upload_config`` parses the DSL and caches the intermediate result
    in Redis under a temp id; ``save_workflow`` later turns the cached result
    into a workflow app; ``flush_config`` refreshes a cached entry after
    manual edits.
    """

    def __init__(self, db: Session):
        self.db = db
        self.registry = PlatformAdapterRegistry
        # Seconds the parsed config stays valid in Redis.
        self.cache_timeout = settings.WORKFLOW_IMPORT_CACHE_TIMEOUT
        self.app_service = AppService(db)
        self.workflow_service = WorkflowService(db)

    async def flush_config(self, temp_id: str, config: WorkflowParserResult):
        """Overwrite the cached parse result for *temp_id*, resetting its TTL.

        Raises:
            BusinessException: when the cache entry has already expired.
        """
        config_cache = await aio_redis_get(temp_id)
        if not config_cache:
            raise BusinessException("Workflow configuration has expired. Please re-upload it.")
        await aio_redis_set(temp_id, config.model_dump_json(), expire=self.cache_timeout)

    async def upload_config(
        self,
        platform: str,
        config: dict[str, Any],
    ):
        """Parse *config* with the adapter for *platform* and cache the result.

        Returns a WorkflowImportResult; failures (unknown platform, invalid
        config) are reported through its ``errors`` field instead of raising.
        """
        if not self.registry.is_supported(platform):
            return WorkflowImportResult(
                success=False,
                temp_id=None,
                workflow_id=None,
                errors=[UnsupportPlatform(platform=platform)]
            )
        adapter = self.registry.get_adapter(platform, config)
        if not adapter.validate_config():
            return WorkflowImportResult(
                success=False,
                temp_id=None,
                workflow_id=None,
                errors=[InvalidConfiguration()]
            )
        workflow_config = adapter.parse_workflow()
        temp_id = uuid.uuid4().hex
        # FIX: dump in JSON mode so enums/UUIDs/datetimes are coerced to
        # JSON-safe values before aio_redis_set json.dumps's the dict
        # (consistent with flush_config's model_dump_json()).
        await aio_redis_set(
            temp_id,
            workflow_config.model_dump(mode="json"),
            expire=self.cache_timeout,
        )
        return WorkflowImportResult(
            success=True,
            temp_id=temp_id,
            workflow_id=None,
            edges=workflow_config.edges,
            nodes=workflow_config.nodes,
            variables=workflow_config.variables,
            warnings=workflow_config.warnings,
            errors=workflow_config.errors
        )

    async def save_workflow(
        self,
        user_id: uuid.UUID,
        workspace_id: uuid.UUID,
        temp_id: str,
        name: str,
        description: str | None,
    ):
        """Materialize a cached parse result as a new workflow app.

        Raises:
            BusinessException: when the cached config has expired.
        """
        config = await aio_redis_get(temp_id)
        if config is None:
            raise BusinessException("Configuration import timed out. Please try again.")
        config = json.loads(config)
        app = self.app_service.create_app(
            user_id=user_id,
            workspace_id=workspace_id,
            data=AppCreate(
                name=name,
                description=description,
                type="workflow",
                workflow_config=WorkflowConfigCreate(
                    nodes=config["nodes"],
                    edges=config["edges"],
                    variables=config["variables"]
                )
            )
        )
        return app

View File

@@ -6,13 +6,16 @@ import logging
import uuid
from typing import Any, Annotated, Optional
import yaml
from fastapi import Depends
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.workflow.adapters.registry import PlatformAdapterRegistry
from app.core.workflow.validator import validate_workflow_config
from app.db import get_db
from app.models import App
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
from app.repositories.workflow_repository import (
WorkflowConfigRepository,
@@ -38,6 +41,8 @@ class WorkflowService:
self.conversation_service = ConversationService(db)
self.multimodal_service = MultimodalService(db)
self.registry = PlatformAdapterRegistry
# ==================== 配置管理 ====================
def create_workflow_config(
@@ -200,6 +205,32 @@ class WorkflowService:
logger.info(f"删除工作流配置成功: app_id={app_id}, config_id={config.id}")
return True
def export_workflow_dsl(self, app_id: uuid.UUID):
    """Serialize an app's workflow configuration to a YAML DSL string.

    Raises:
        BusinessException: when no workflow configuration exists for *app_id*.
    """
    workflow_config = self.get_workflow_config(app_id)
    if not workflow_config:
        raise BusinessException(
            code=BizCode.NOT_FOUND,
            message=f"工作流配置不存在: app_id={app_id}"
        )
    app: App = workflow_config.app
    app_section = {
        "name": app.name,
        "description": app.description,
        "icon": app.icon,
        "icon_type": app.icon_type
    }
    workflow_section = {
        "variables": workflow_config.variables,
        "edges": workflow_config.edges,
        "nodes": workflow_config.nodes,
        "execution_config": workflow_config.execution_config,
        "triggers": workflow_config.triggers
    }
    dsl_info = {"app": app_section, "workflow": workflow_section}
    # allow_unicode keeps Chinese text readable instead of \u-escaped.
    return yaml.dump(dsl_info, default_flow_style=False, allow_unicode=True)
def check_config(self, app_id: uuid.UUID) -> WorkflowConfig:
"""检查工作流配置的完整性