Merge branch 'develop' into fix/memoryconfig-update

This commit is contained in:
Ke Sun
2026-03-26 14:31:40 +08:00
63 changed files with 1527 additions and 416 deletions

View File

@@ -574,8 +574,12 @@ async def get_file_url(
# For local storage, generate signed URL with expiration # For local storage, generate signed URL with expiration
url = generate_signed_url(str(file_id), expires) url = generate_signed_url(str(file_id), expires)
else: else:
# For remote storage (OSS/S3), get presigned URL # For remote storage (OSS/S3), get presigned URL with forced download
url = await storage_service.get_file_url(file_key, expires=expires) url = await storage_service.get_file_url(
file_key,
expires=expires,
file_name=file_metadata.file_name,
)
url = _match_scheme(request, url) url = _match_scheme(request, url)
api_logger.info(f"Generated file URL: file_id={file_id}") api_logger.info(f"Generated file URL: file_id={file_id}")
@@ -786,7 +790,7 @@ async def permanent_download_file(
# For remote storage, redirect to presigned URL with long expiration # For remote storage, redirect to presigned URL with long expiration
try: try:
# Use a very long expiration (7 days max for most cloud providers) # Use a very long expiration (7 days max for most cloud providers)
presigned_url = await storage_service.get_file_url(file_key, expires=604800) presigned_url = await storage_service.get_file_url(file_key, expires=604800, file_name=file_metadata.file_name)
presigned_url = _match_scheme(request, presigned_url) presigned_url = _match_scheme(request, presigned_url)
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND) return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
except Exception as e: except Exception as e:

View File

@@ -91,9 +91,11 @@ async def get_mcp_servers(
try: try:
cookies = api.get_cookies(token) cookies = api.get_cookies(token)
headers=api.builder_headers(api.headers)
headers['Authorization'] = f'Bearer {token}'
r = api.session.put( r = api.session.put(
url=api.mcp_base_url, url=api.mcp_base_url,
headers=api.builder_headers(api.headers), headers=headers,
json=body, json=body,
cookies=cookies) cookies=cookies)
raise_for_http_status(r) raise_for_http_status(r)
@@ -173,6 +175,7 @@ async def get_operational_mcp_servers(
url = f'{api.mcp_base_url}/operational' url = f'{api.mcp_base_url}/operational'
headers = api.builder_headers(api.headers) headers = api.builder_headers(api.headers)
headers['Authorization'] = f'Bearer {token}'
try: try:
cookies = api.get_cookies(access_token=token, cookies_required=True) cookies = api.get_cookies(access_token=token, cookies_required=True)
@@ -260,7 +263,9 @@ async def create_mcp_market_config(
api.login(create_data.token) api.login(create_data.token)
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None} body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
cookies = api.get_cookies(create_data.token) cookies = api.get_cookies(create_data.token)
r = api.session.put(url=api.mcp_base_url, headers=api.builder_headers(api.headers), json=body, cookies=cookies) headers = api.builder_headers(api.headers)
headers['Authorization'] = f'Bearer {create_data.token}'
r = api.session.put(url=api.mcp_base_url, headers=headers, json=body, cookies=cookies)
raise_for_http_status(r) raise_for_http_status(r)
except Exception as e: except Exception as e:
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}") api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")
@@ -290,9 +295,11 @@ async def create_mcp_market_config(
'search': "" 'search': ""
} }
cookies = api.get_cookies(token) cookies = api.get_cookies(token)
headers = api.builder_headers(api.headers)
headers['Authorization'] = f'Bearer {token}'
r = api.session.put( r = api.session.put(
url=api.mcp_base_url, url=api.mcp_base_url,
headers=api.builder_headers(api.headers), headers=headers,
json=body, json=body,
cookies=cookies) cookies=cookies)
raise_for_http_status(r) raise_for_http_status(r)
@@ -393,7 +400,9 @@ async def update_mcp_market_config(
api.login(update_data.token) api.login(update_data.token)
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None} body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
cookies = api.get_cookies(update_data.token) cookies = api.get_cookies(update_data.token)
r = api.session.put(url=api.mcp_base_url, headers=api.builder_headers(api.headers), json=body, cookies=cookies) headers = api.builder_headers(api.headers)
headers['Authorization'] = f'Bearer {update_data.token}'
r = api.session.put(url=api.mcp_base_url, headers=headers, json=body, cookies=cookies)
raise_for_http_status(r) raise_for_http_status(r)
except Exception as e: except Exception as e:
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}") api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")

View File

@@ -669,6 +669,7 @@ async def config_query(
content = { content = {
"app_type": release.app.type, "app_type": release.app.type,
"variables": release.config.get("variables"), "variables": release.config.get("variables"),
"memory": release.config.get("memory", {}).get("enabled"),
"features": release.config.get("features") "features": release.config.get("features")
} }
elif release.app.type == AppType.MULTI_AGENT: elif release.app.type == AppType.MULTI_AGENT:

View File

@@ -178,7 +178,7 @@ async def window_dialogue(end_user_id, langchain_messages, memory_config, scope)
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages) count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
elif int(is_end_user_id) == int(scope): elif int(is_end_user_id) == int(scope):
logger.info('写入长期记忆NEO4J') logger.info('写入长期记忆NEO4J')
formatted_messages = (redis_messages) formatted_messages = redis_messages
# Get config_id (if memory_config is an object, extract config_id; otherwise use directly) # Get config_id (if memory_config is an object, extract config_id; otherwise use directly)
if hasattr(memory_config, 'config_id'): if hasattr(memory_config, 'config_id'):
config_id = memory_config.config_id config_id = memory_config.config_id

View File

@@ -2,6 +2,7 @@
OpenAI Embedder 客户端实现 OpenAI Embedder 客户端实现
基于 LangChain 和 RedBearEmbeddings 的 OpenAI 嵌入模型客户端实现。 基于 LangChain 和 RedBearEmbeddings 的 OpenAI 嵌入模型客户端实现。
自动支持火山引擎的多模态 Embedding。
""" """
from typing import List from typing import List
@@ -13,6 +14,7 @@ from app.core.memory.llm_tools.embedder_client import (
) )
from app.core.models.base import RedBearModelConfig from app.core.models.base import RedBearModelConfig
from app.core.models.embedding import RedBearEmbeddings from app.core.models.embedding import RedBearEmbeddings
from app.models.models_model import ModelProvider
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -25,6 +27,7 @@ class OpenAIEmbedderClient(EmbedderClient):
- 批量文本嵌入 - 批量文本嵌入
- 自动重试机制 - 自动重试机制
- 错误处理 - 错误处理
- 火山引擎多模态 Embedding自动识别
""" """
def __init__(self, model_config: RedBearModelConfig): def __init__(self, model_config: RedBearModelConfig):
@@ -36,7 +39,7 @@ class OpenAIEmbedderClient(EmbedderClient):
""" """
super().__init__(model_config) super().__init__(model_config)
# 初始化 RedBearEmbeddings 模型 # 初始化 RedBearEmbeddings(自动支持火山引擎多模态)
self.model = RedBearEmbeddings( self.model = RedBearEmbeddings(
RedBearModelConfig( RedBearModelConfig(
model_name=self.model_name, model_name=self.model_name,
@@ -47,8 +50,9 @@ class OpenAIEmbedderClient(EmbedderClient):
timeout=self.timeout, timeout=self.timeout,
) )
) )
self.is_multimodal = self.model.is_multimodal_supported()
logger.info("OpenAI Embedder 客户端初始化完成") logger.info(f"OpenAI Embedder 客户端初始化完成 (provider={self.provider}, multimodal={self.is_multimodal})")
async def response( async def response(
self, self,
@@ -77,7 +81,14 @@ class OpenAIEmbedderClient(EmbedderClient):
return [] return []
# 生成嵌入向量 # 生成嵌入向量
embeddings = await self.model.aembed_documents(texts) if self.is_multimodal:
# 火山引擎多模态 Embedding
embeddings = await self.model.aembed_multimodal(
[{"type": "text", "text": text} for text in texts]
)
else:
# 普通 Embedding
embeddings = await self.model.aembed_documents(texts)
logger.debug(f"成功生成 {len(embeddings)} 个嵌入向量") logger.debug(f"成功生成 {len(embeddings)} 个嵌入向量")
return embeddings return embeddings

View File

@@ -1099,7 +1099,6 @@ class ExtractionOrchestrator:
metadata=chunk.metadata, metadata=chunk.metadata,
) )
chunk_nodes.append(chunk_node) chunk_nodes.append(chunk_node)
logger.error(f"chunk file: {chunk.files}")
for p, file_type in chunk.files: for p, file_type in chunk.files:

View File

@@ -2,6 +2,7 @@ from .base import RedBearModelConfig, get_provider_llm_class, RedBearModelFacto
from .llm import RedBearLLM from .llm import RedBearLLM
from .embedding import RedBearEmbeddings from .embedding import RedBearEmbeddings
from .rerank import RedBearRerank from .rerank import RedBearRerank
from .generation import RedBearImageGenerator, RedBearVideoGenerator
__all__ = [ __all__ = [
"RedBearModelConfig", "RedBearModelConfig",
@@ -9,5 +10,7 @@ __all__ = [
"RedBearEmbeddings", "RedBearEmbeddings",
"RedBearRerank", "RedBearRerank",
"RedBearModelFactory", "RedBearModelFactory",
"get_provider_llm_class" "get_provider_llm_class",
"RedBearImageGenerator",
"RedBearVideoGenerator"
] ]

View File

@@ -67,7 +67,7 @@ class RedBearModelFactory:
**config.extra_params **config.extra_params
} }
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA]: if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA, ModelProvider.VOLCANO]:
# 使用 httpx.Timeout 对象来设置详细的超时配置 # 使用 httpx.Timeout 对象来设置详细的超时配置
# 这样可以分别控制连接超时和读取超时 # 这样可以分别控制连接超时和读取超时
import httpx import httpx
@@ -160,11 +160,13 @@ def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelTy
# dashscope 的 omni 模型使用 OpenAI 兼容模式 # dashscope 的 omni 模型使用 OpenAI 兼容模式
if provider == ModelProvider.DASHSCOPE and config.is_omni: if provider == ModelProvider.DASHSCOPE and config.is_omni:
return ChatOpenAI return ChatOpenAI
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]: if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.VOLCANO]:
if type == ModelType.LLM: if type == ModelType.LLM:
return OpenAI return OpenAI
elif type == ModelType.CHAT: elif type == ModelType.CHAT:
return ChatOpenAI return ChatOpenAI
else:
raise BusinessException(f"不支持的模型提供商及类型: {provider}-{type}", code=BizCode.PROVIDER_NOT_SUPPORTED)
elif provider == ModelProvider.DASHSCOPE: elif provider == ModelProvider.DASHSCOPE:
return ChatTongyi return ChatTongyi
elif provider == ModelProvider.OLLAMA: elif provider == ModelProvider.OLLAMA:

View File

@@ -1,23 +1,190 @@
from typing import Any, Dict, List, Optional, TypeVar, Callable from typing import Any, Dict, List, Optional, Union
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from app.core.models.base import RedBearModelConfig,get_provider_embedding_class,RedBearModelFactory from app.core.models.base import RedBearModelConfig, get_provider_embedding_class, RedBearModelFactory
from app.models.models_model import ModelProvider
class RedBearEmbeddings(Embeddings): class RedBearEmbeddings(Embeddings):
"""Embedding → 完全符合 LangChain Embeddings""" """统一的 Embedding 类,自动支持多模态(根据 provider 判断)"""
def __init__(self, config: RedBearModelConfig): def __init__(self, config: RedBearModelConfig):
self._model = self._create_model(config)
self._config = config self._config = config
self._is_volcano = config.provider.lower() == ModelProvider.VOLCANO
if self._is_volcano:
# 火山引擎使用 Ark SDK
self._client = self._create_volcano_client(config)
self._model = None
else:
# 其他 provider 使用 LangChain
self._model = self._create_model(config)
self._client = None
def _create_model(self, config: RedBearModelConfig) -> Embeddings: def _create_model(self, config: RedBearModelConfig) -> Embeddings:
"""根据配置创建模型""" """根据配置创建 LangChain 模型"""
embedding_class = get_provider_embedding_class(config.provider) embedding_class = get_provider_embedding_class(config.provider)
model_params = RedBearModelFactory.get_model_params(config) model_params = RedBearModelFactory.get_model_params(config)
return embedding_class(**model_params) return embedding_class(**model_params)
def _create_volcano_client(self, config: RedBearModelConfig):
"""创建火山引擎客户端"""
from volcenginesdkarkruntime import Ark
return Ark(api_key=config.api_key, base_url=config.base_url)
# ==================== LangChain 标准接口 ====================
def embed_documents(self, texts: list[str]) -> list[list[float]]: def embed_documents(self, texts: list[str]) -> list[list[float]]:
return self._model.embed_documents(texts) """批量文本向量化LangChain 标准接口)"""
if self._is_volcano:
# 火山引擎多模态 Embedding
contents = [{"type": "text", "text": text} for text in texts]
response = self._client.multimodal_embeddings.create(
model=self._config.model_name,
input=contents,
encoding_format="float"
)
return [response.data.embedding]
else:
# 其他 provider
return self._model.embed_documents(texts)
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> List[float]:
return self._model.embed_query(text) """单个文本向量化LangChain 标准接口)"""
if self._is_volcano:
# 火山引擎多模态 Embedding
result = self.embed_documents([text])
return result[0] if result else []
else:
# 其他 provider
return self._model.embed_query(text)
# ==================== 多模态扩展方法 ====================
def embed_multimodal(
self,
contents: List[Dict[str, Any]],
**kwargs
) -> List[List[float]]:
"""
多模态向量化(仅火山引擎支持)
Args:
contents: 内容列表,格式:
- 文本: {"type": "text", "text": "..."}
- 图片: {"type": "image_url", "image_url": {"url": "..."}}
- 视频: {"type": "video_url", "video_url": {"url": "..."}}
**kwargs: 其他参数
Returns:
向量列表
"""
if not self._is_volcano:
raise NotImplementedError(
f"多模态 Embedding 仅支持火山引擎,当前 provider: {self._config.provider}"
)
response = self._client.multimodal_embeddings.create(
model=self._config.model_name,
input=contents,
**kwargs
)
return [response.data.embedding]
async def aembed_multimodal(
self,
contents: List[Dict[str, Any]],
**kwargs
) -> List[List[float]]:
"""异步多模态向量化"""
# 火山引擎 SDK 暂不支持异步,使用同步方法
return self.embed_multimodal(contents, **kwargs)
def embed_text(self, text: str, **kwargs) -> List[float]:
"""文本向量化(便捷方法)"""
if self._is_volcano:
result = self.embed_multimodal(
[{"type": "text", "text": text}],
**kwargs
)
return result[0] if result else []
else:
return self.embed_query(text)
def embed_image(self, image_url: str, **kwargs) -> List[float]:
"""图片向量化(仅火山引擎支持)"""
if not self._is_volcano:
raise NotImplementedError(
f"图片向量化仅支持火山引擎,当前 provider: {self._config.provider}"
)
result = self.embed_multimodal(
[{"type": "image_url", "image_url": {"url": image_url}}],
**kwargs
)
return result[0] if result else []
def embed_video(self, video_url: str, **kwargs) -> List[float]:
"""视频向量化(仅火山引擎支持)"""
if not self._is_volcano:
raise NotImplementedError(
f"视频向量化仅支持火山引擎,当前 provider: {self._config.provider}"
)
result = self.embed_multimodal(
[{"type": "video_url", "video_url": {"url": video_url}}],
**kwargs
)
return result[0] if result else []
def embed_batch(
self,
items: List[Union[str, Dict[str, Any]]],
**kwargs
) -> List[List[float]]:
"""
批量向量化(支持混合类型)
Args:
items: 可以是字符串列表或内容字典列表
**kwargs: 其他参数
Returns:
向量列表
"""
# 如果全是字符串,使用标准方法
if all(isinstance(item, str) for item in items):
return self.embed_documents(items)
# 如果包含字典,需要多模态支持
if not self._is_volcano:
raise NotImplementedError(
f"混合类型批量向量化仅支持火山引擎,当前 provider: {self._config.provider}"
)
# 标准化输入格式
contents = []
for item in items:
if isinstance(item, str):
contents.append({"type": "text", "text": item})
elif isinstance(item, dict):
contents.append(item)
else:
raise ValueError(f"不支持的输入类型: {type(item)}")
return self.embed_multimodal(contents, **kwargs)
# ==================== 工具方法 ====================
def is_multimodal_supported(self) -> bool:
"""检查是否支持多模态"""
return self._is_volcano
def get_provider(self) -> str:
"""获取 provider"""
return self._config.provider
# 保留 RedBearMultimodalEmbeddings 作为别名,向后兼容
RedBearMultimodalEmbeddings = RedBearEmbeddings

View File

@@ -0,0 +1,344 @@
"""
图片和视频生成模型封装
支持的 Provider:
- Volcano (火山引擎): 使用 volcenginesdkarkruntime
- OpenAI: 使用 openai SDK
"""
from typing import Any, Dict, Optional
from volcenginesdkarkruntime import Ark
from volcenginesdkarkruntime.types.images.images import (
SequentialImageGenerationOptions,
ContentGenerationTool,
OptimizePromptOptions
)
from app.core.models.base import RedBearModelConfig
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.models.models_model import ModelProvider
class RedBearImageGenerator:
"""图片生成模型封装"""
def __init__(self, config: RedBearModelConfig):
self._config = config
self._client = self._create_client(config)
def _create_client(self, config: RedBearModelConfig):
"""根据 provider 创建客户端"""
provider = config.provider.lower()
if provider == ModelProvider.VOLCANO:
return Ark(api_key=config.api_key, base_url=config.base_url)
# elif provider == ModelProvider.OPENAI:
# from openai import OpenAI
# return OpenAI(api_key=config.api_key, base_url=config.base_url)
else:
raise BusinessException(
f"不支持的图片生成提供商: {provider}",
code=BizCode.PROVIDER_NOT_SUPPORTED
)
def generate(
self,
prompt: str,
image: Optional[Any] = None,
size: Optional[str] = "2K",
output_format: str = "png",
response_format: str = "url",
watermark: bool = False,
sequential_image_generation: Optional[str] = None,
sequential_image_generation_options: Optional[Dict] = None,
tools: Optional[list] = None,
optimize_prompt_options: Optional[Dict] = None,
stream: bool = False,
**kwargs
) -> Dict[str, Any]:
"""
生成图片
Args:
prompt: 提示词
image: 参考图片URL或URL列表图文生图/多图融合)
size: 图片尺寸,支持 "2K", "2048x2048", "1920x1080"至少3686400像素
output_format: 输出格式,如 "png", "jpg"
response_format: 返回格式,"url""b64_json"
watermark: 是否添加水印
sequential_image_generation: 组图生成模式,"auto""disabled"
sequential_image_generation_options: 组图生成选项,如 {"max_images": 4}
tools: 工具列表,如 [{"type": "web_search"}] 用于联网搜索生图
optimize_prompt_options: 提示词优化选项,如 {"mode": "fast"}
stream: 是否使用流式生成
**kwargs: 其他参数
Returns:
生成结果
"""
provider = self._config.provider.lower()
if provider == ModelProvider.VOLCANO:
params = {
"model": self._config.model_name,
"prompt": prompt,
"size": size,
"output_format": output_format,
"response_format": response_format,
"watermark": watermark,
}
if image is not None:
params["image"] = image
if sequential_image_generation:
params["sequential_image_generation"] = sequential_image_generation
if sequential_image_generation_options:
params["sequential_image_generation_options"] = SequentialImageGenerationOptions(
**sequential_image_generation_options
)
if tools:
params["tools"] = [ContentGenerationTool(**tool) if isinstance(tool, dict) else tool for tool in tools]
if optimize_prompt_options:
params["optimize_prompt_options"] = OptimizePromptOptions(**optimize_prompt_options)
if stream:
params["stream"] = True
params.update(kwargs)
response = self._client.images.generate(**params)
# elif provider == ModelProvider.OPENAI:
# response = self._client.images.generate(
# model=self._config.model_name,
# prompt=prompt,
# size=size,
# n=n,
# **kwargs
# )
else:
raise BusinessException(
f"不支持的提供商: {provider}",
code=BizCode.PROVIDER_NOT_SUPPORTED
)
return response.model_dump() if hasattr(response, 'model_dump') else response
async def agenerate(
self,
prompt: str,
image: Optional[Any] = None,
size: Optional[str] = "2K",
output_format: str = "png",
response_format: str = "url",
watermark: bool = False,
**kwargs
) -> Dict[str, Any]:
"""异步生成图片"""
return self.generate(prompt, image, size, output_format, response_format, watermark, **kwargs)
class RedBearVideoGenerator:
"""视频生成模型封装"""
def __init__(self, config: RedBearModelConfig):
self._config = config
self._client = self._create_client(config)
def _create_client(self, config: RedBearModelConfig):
"""根据 provider 创建客户端"""
provider = config.provider.lower()
if provider == ModelProvider.VOLCANO:
return Ark(api_key=config.api_key, base_url=config.base_url)
else:
raise BusinessException(
f"不支持的视频生成提供商: {provider}",
code=BizCode.PROVIDER_NOT_SUPPORTED
)
def generate(
self,
prompt: str,
image_url: Optional[str] = None,
first_frame_url: Optional[str] = None,
last_frame_url: Optional[str] = None,
reference_images: Optional[list] = None,
draft_task_id: Optional[str] = None,
duration: Optional[int] = None,
frames: Optional[int] = None,
ratio: Optional[str] = None,
resolution: Optional[str] = None,
generate_audio: bool = False,
watermark: bool = False,
camera_fixed: bool = False,
seed: Optional[int] = None,
return_last_frame: bool = False,
service_tier: str = "default",
execution_expires_after: Optional[int] = None,
draft: bool = False,
**kwargs
) -> Dict[str, Any]:
"""
生成视频
Args:
prompt: 提示词
image_url: 首帧图片URL图生视频-基于首帧)
first_frame_url: 首帧图片URL图生视频-基于首尾帧)
last_frame_url: 尾帧图片URL图生视频-基于首尾帧)
reference_images: 参考图片URL列表图生视频-基于参考图)
draft_task_id: Draft任务ID基于Draft生成正式视频
duration: 视频时长与frames二选一
frames: 视频帧数与duration二选一
ratio: 视频比例,如 "16:9", "9:16", "adaptive"
resolution: 视频分辨率,如 "720p", "1080p"
generate_audio: 是否生成音频
watermark: 是否添加水印
camera_fixed: 是否固定镜头
seed: 随机种子
return_last_frame: 是否返回最后一帧
service_tier: 服务层级,"default""flex"(离线推理)
execution_expires_after: 任务过期时间(秒)
draft: 是否生成样片
**kwargs: 其他参数
Returns:
生成结果包含任务ID需要轮询获取结果
"""
provider = self._config.provider.lower()
if provider == ModelProvider.VOLCANO:
content = [{"type": "text", "text": prompt}]
if draft_task_id:
content = [{"type": "draft_task", "draft_task": {"id": draft_task_id}}]
else:
if image_url:
content.append({"type": "image_url", "image_url": {"url": image_url}})
if first_frame_url:
content.append({"type": "image_url", "image_url": {"url": first_frame_url}, "role": "first_frame"})
if last_frame_url:
content.append({"type": "image_url", "image_url": {"url": last_frame_url}, "role": "last_frame"})
if reference_images:
for ref_url in reference_images:
content.append({"type": "image_url", "image_url": {"url": ref_url}, "role": "reference_image"})
params = {"model": self._config.model_name, "content": content, "watermark": watermark}
if duration:
params["duration"] = duration
if frames:
params["frames"] = frames
if ratio:
params["ratio"] = ratio
if resolution:
params["resolution"] = resolution
if generate_audio:
params["generate_audio"] = generate_audio
if camera_fixed:
params["camera_fixed"] = camera_fixed
if seed is not None:
params["seed"] = seed
if return_last_frame:
params["return_last_frame"] = return_last_frame
if service_tier != "default":
params["service_tier"] = service_tier
if execution_expires_after:
params["execution_expires_after"] = execution_expires_after
if draft:
params["draft"] = draft
params.update(kwargs)
response = self._client.content_generation.tasks.create(**params)
else:
raise BusinessException(
f"不支持的提供商: {provider}",
code=BizCode.PROVIDER_NOT_SUPPORTED
)
return response.model_dump() if hasattr(response, 'model_dump') else response
async def agenerate(
self,
prompt: str,
image_url: Optional[str] = None,
duration: Optional[int] = None,
**kwargs
) -> Dict[str, Any]:
"""异步生成视频"""
return self.generate(prompt, image_url=image_url, duration=duration, **kwargs)
def get_task_status(self, task_id: str) -> Dict[str, Any]:
"""
查询视频生成任务状态
Args:
task_id: 任务ID
Returns:
任务状态信息
"""
provider = self._config.provider.lower()
if provider == ModelProvider.VOLCANO:
response = self._client.content_generation.tasks.get(task_id=task_id)
return response.model_dump() if hasattr(response, 'model_dump') else response
else:
raise BusinessException(
f"不支持的提供商: {provider}",
code=BizCode.PROVIDER_NOT_SUPPORTED
)
async def aget_task_status(self, task_id: str) -> Dict[str, Any]:
"""异步查询任务状态"""
return self.get_task_status(task_id)
def list_tasks(self, page_size: int = 10, status: Optional[str] = None, **kwargs) -> Dict[str, Any]:
"""
查询视频生成任务列表
Args:
page_size: 每页数量
status: 任务状态筛选,如 "succeeded", "failed", "pending"
**kwargs: 其他参数
Returns:
任务列表
"""
provider = self._config.provider.lower()
if provider == ModelProvider.VOLCANO:
params = {"page_size": page_size}
if status:
params["status"] = status
params.update(kwargs)
response = self._client.content_generation.tasks.list(**params)
return response.model_dump() if hasattr(response, 'model_dump') else response
else:
raise BusinessException(
f"不支持的提供商: {provider}",
code=BizCode.PROVIDER_NOT_SUPPORTED
)
def delete_task(self, task_id: str) -> None:
"""
删除或取消视频生成任务
Args:
task_id: 任务ID
"""
provider = self._config.provider.lower()
if provider == ModelProvider.VOLCANO:
self._client.content_generation.tasks.delete(task_id=task_id)
else:
raise BusinessException(
f"不支持的提供商: {provider}",
code=BizCode.PROVIDER_NOT_SUPPORTED
)

View File

@@ -0,0 +1,334 @@
provider: volcano
models:
# Doubao-Seed 2.0 系列
- name: doubao-seed-2-0-pro-260215
type: chat
provider: volcano
description: 旗舰级全能通用模型,面向 Agent 时代的复杂推理与长链路任务执行场景。强调多模态理解、长上下文推理、结构化生成与工具增强执行。复杂指令与多约束执行能力突出,可稳定应对多步复杂规划、复杂图文推理、视频内容理解与高难度分析等场景。侧重长链路推理能力与复杂任务稳定性,适配真实业务中的复杂场景。
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
logo: volcano
- name: doubao-seed-2-0-lite-260215
type: chat
provider: volcano
description: 面向高频企业场景兼顾性能与成本的均衡型模型综合能力超越上一代Doubao-Seed-1.8。胜任非结构化信息处理、内容创作、搜索推荐、数据分析等生产型工作,支持长上下文、多源信息融合、多步指令执行与高保真结构化输出。在保障稳定效果的同时显著优化成本。兼顾生成质量与响应速度,适合作为通用生产级模型。
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
logo: volcano
- name: doubao-seed-2-0-mini-260215
type: chat
provider: volcano
description: 面向低时延、高并发与成本敏感场景提供极致的模型推理速度。模型效果与Doubao-Seed-1.6相当。支持256k上下文、4档思考长度和多模态理解适合成本和速度优先的轻量级任务。
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
logo: volcano
- name: doubao-seed-2-0-code-preview-260215
type: chat
provider: volcano
description: 面向真实编程环境优化的 Coding 模型,能稳定调用 Claude Code 等常见 IDE 中的工具。模型特别优化了前端能力,在使用常见的前端框架时能有良好表现。模型支持使用 Skills可以配合多种自定义技能使用。Seed 2.0 的编程加强版,更适合 Agentic Coding。
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 代码模型
logo: volcano
# Doubao-Seed 1.x 系列
- name: doubao-seed-1-8-251228
type: chat
provider: volcano
description: Doubao-Seed-1.8 面向多模态 Agent 场景定向优化。Agent 能力上Tool Use、复杂指令遵循等能力均大幅增强。多模态理解方面视觉基础能力显著提升可低帧率理解超长视频视频运动理解、复杂空间理解及文档结构化解析能力也有所优化还原生支持智能上下文管理用户可配置上下文策略。
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
logo: volcano
- name: doubao-seed-1-6-251015
type: chat
provider: volcano
description: Doubao-Seed-1.6全新多模态深度思考模型同时支持minimal/low/medium/high 四种reasoning effort。 更强模型效果,服务复杂任务和有挑战场景。支持 256k 上下文窗口,输出长度支持最大 32k tokens。
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
logo: volcano
- name: doubao-seed-1-6-lite-251015
type: chat
provider: volcano
description: 更高性价比常见任务的最佳选择支持minimal、low、medium、high 四种reasoning_effort思考深度
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
logo: volcano
- name: doubao-seed-1-6-flash-250828
type: chat
provider: volcano
description: Doubao-Seed-1.6-flash推理速度极致的多模态深度思考模型TPOT低至10ms 同时支持文本和视觉理解文本理解能力超过上一代lite视觉理解比肩友商pro系列模型。支持 256k 上下文窗口,输出长度支持最大 16k tokens。
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
logo: volcano
- name: doubao-seed-code-preview-251028
type: chat
provider: volcano
description: 面向Agentic编程任务进行了深度优化。
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 代码模型
logo: volcano
- name: doubao-seed-1-6-vision-250815
type: chat
provider: volcano
description: 全新Doubao-Seed-1.6系列视觉深度思考模型视觉理解能力显著增强并支持image_process视觉工具
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
logo: volcano
# Doubao 1.5 系列
- name: doubao-1-5-vision-pro-32k-250115
type: chat
provider: volcano
description: 全新升级的多模态大模型,支持任意分辨率和极端长宽比图像识别,增强视觉推理、文档识别、细节信息理解和指令遵循能力。支持 32k 上下文窗口,输出长度支持最大 12k tokens。
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- 多模态模型
logo: volcano
- name: doubao-1-5-pro-32k-250115
type: chat
provider: volcano
description: 全新一代主力模型,性能全面升级,在知识、代码、推理等方面表现卓越。最大支持 128k 上下文窗口,输出长度支持最大 12k tokens。
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
logo: volcano
- name: doubao-1-5-lite-32k-250115
type: chat
provider: volcano
description: 全新一代轻量版模型,极致响应速度,效果与时延均达到全球一流水平。支持 32k 上下文窗口,输出长度支持最大 12k tokens。
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
logo: volcano
# Doubao-Seedance 视频生成系列
- name: doubao-seedance-1-5-pro-251215
type: video
provider: volcano
description: 豆包视频生成模型Seedance 1.5 pro 作为全球领先的视频生成模型,可生成音画高精同步的视频内容。支持多人多语言对白,全面覆盖环境音、动作音、合成音、乐器音、背景音及人声,支持首尾帧,实现影视级叙事效果,满足影视、漫剧、电商及广告领域的高阶创作需求。
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 视频生成
logo: volcano
- name: doubao-seedance-1-0-pro-250528
type: video
provider: volcano
description: 一款支持多镜头叙事的视频生成基础模型,在各维度表现出色。它在语义理解与指令遵循能力上取得突破,能生成运动流畅、细节丰富、风格多样且具备影视级美感的 1080P 高清视频
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 视频生成
logo: volcano
- name: doubao-seedance-1-0-pro-fast-251015
type: video
provider: volcano
description: 一款价格触底、效能封顶的全面模型在视频生成质量、速度、价格之间取得了卓越平衡。它继承了Seedance 1.0 pro 核心优势,同时生成速度提升、价格更具竞争力,为创作者带来效率与成本双重优化的体验。
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 视频生成
logo: volcano
- name: doubao-seedance-1-0-lite-i2v-250428
type: video
provider: volcano
description: 基于首帧图片、尾帧图片(可选)、参考图片(可选)和文本提示词(可选)相结合的方式生成视频
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 视频生成
- 图生视频
logo: volcano
- name: doubao-seedance-1-0-lite-t2v-250428
type: video
provider: volcano
description: 基于文本提示词生成视频
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 视频生成
- 文生视频
logo: volcano
# Doubao-Seedream 图像生成系列
- name: doubao-seedream-5-0-260128
type: image
provider: volcano
description: 字节跳动发布的最新图像创作模型。该模型首次搭载联网检索功能,能融合实时网络信息,提升生图时效性。同时,模型的聪明度进一步升级,能够精准解析复杂指令和视觉内容。此外,模型在世界知识广度、参考一致性及专业场景生成质量上均有增强,可更好地满足企业级视觉创作需求。
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 图像生成
logo: volcano
- name: doubao-seedream-4-5-251128
type: image
provider: volcano
description: 字节跳动最新推出的图像多模态模型整合了文生图、图生图、组图输出等能力融合常识和推理能力。相比前代4.0模型生成效果大幅提升,具备更好的编辑一致性和多图融合效果,能更精准的控制画面细节,小字、小人脸生成更自然,图片排版、色彩更和谐,美感提升。
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 图像生成
logo: volcano
- name: doubao-seedream-4-0-250828
type: image
provider: volcano
description: 基于领先架构的SOTA级多模态图像创作模型其生成美感、指令遵循、结构完整度、主体保持一致性处于世界头部水平。模型采用同一套架构实现文生图与编辑能力的统一原生支持文本 、单图和多图输入,并能通过对提示词的深度推理,自动适配最优的图像比例尺寸与生成数量,可一次性连续输出最多 15 张内容关联的图像,支持 4K 超高清输出。
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 图像生成
logo: volcano
- name: doubao-seedream-3-0-t2i-250415
type: image
provider: volcano
description: 一款支持原生高分辨率的中英双语图像生成基础模型综合能力媲美GPT-4o处于世界第一梯队。支持原生 2K 分辨率输出;响应速度更快;小字生成更准确,文本排版效果增强;指令遵循能力强,美感&结构提升,保真度和细节表现较好。
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 图像生成
- 文生图
logo: volcano
# Doubao 翻译系列
- name: doubao-seed-translation-250915
type: chat
provider: volcano
description: 通用多语言翻译模型支持30余种语言互译支持 4K 上下文窗口,输出长度支持最大 3K tokens
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 翻译模型
logo: volcano
# Doubao Embedding 系列
- name: doubao-embedding-vision-251215
type: embedding
provider: volcano
description: 主要面向图文多模向量检索的使用场景,支持图片输入及中、英双语文本输入,最长 128K 上下文长度。
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 向量模型
- 多模态模型
logo: volcano

View File

@@ -61,24 +61,16 @@ class ElasticSearchConfig(BaseModel):
class ElasticSearchVector(BaseVector): class ElasticSearchVector(BaseVector):
def __init__(self, index_name: str, config: ElasticSearchConfig, embedding_config: ModelApiKey, reranker_config: ModelApiKey): def __init__(self, index_name: str, config: ElasticSearchConfig, embedding_config: ModelApiKey, reranker_config: ModelApiKey):
super().__init__(index_name.lower()) super().__init__(index_name.lower())
# self.embeddings = XinferenceEmbeddings(
# server_url=os.getenv("XINFERENCE_URL", "http://127.0.0.1"), # Default Xinference port # 初始化 Embedding 模型(自动支持火山引擎多模态)
# model_uid="bge-m3" # replace model_uid with the model UID return from launching the model
# )
# Remove debug printing to avoid leaking sensitive information
# print("embedding:" + embedding_config.model_name + "|" + embedding_config.provider + "|" + embedding_config.api_key + "|" + embedding_config.api_base)
self.embeddings = RedBearEmbeddings(RedBearModelConfig( self.embeddings = RedBearEmbeddings(RedBearModelConfig(
model_name=embedding_config.model_name, model_name=embedding_config.model_name,
provider=embedding_config.provider, provider=embedding_config.provider,
api_key=embedding_config.api_key, api_key=embedding_config.api_key,
base_url=embedding_config.api_base base_url=embedding_config.api_base
)) ))
# self.reranker = XinferenceRerank( self.is_multimodal_embedding = self.embeddings.is_multimodal_supported()
# server_url=os.getenv("XINFERENCE_URL", "http://127.0.0.1"),
# model_uid="bge-reranker-large"
# )
# Remove debug printing to avoid leaking sensitive information
# print("reranker:"+ reranker_config.model_name + "|" + reranker_config.provider + "|" + reranker_config.api_key + "|" + reranker_config.api_base)
self.reranker = RedBearRerank(RedBearModelConfig( self.reranker = RedBearRerank(RedBearModelConfig(
model_name=reranker_config.model_name, model_name=reranker_config.model_name,
provider=reranker_config.provider, provider=reranker_config.provider,
@@ -144,7 +136,11 @@ class ElasticSearchVector(BaseVector):
def add_chunks(self, chunks: list[DocumentChunk], **kwargs): def add_chunks(self, chunks: list[DocumentChunk], **kwargs):
# 实现 Elasticsearch 保存向量 # 实现 Elasticsearch 保存向量
texts = [chunk.page_content for chunk in chunks] texts = [chunk.page_content for chunk in chunks]
embeddings = self.embeddings.embed_documents(list(texts)) if self.is_multimodal_embedding:
# 火山引擎多模态 Embedding
embeddings = self.embeddings.embed_batch(texts)
else:
embeddings = self.embeddings.embed_documents(list(texts))
self.create(chunks, embeddings, **kwargs) self.create(chunks, embeddings, **kwargs)
def create(self, chunks: list[DocumentChunk], embeddings: list[list[float]], **kwargs): def create(self, chunks: list[DocumentChunk], embeddings: list[list[float]], **kwargs):
@@ -394,7 +390,11 @@ class ElasticSearchVector(BaseVector):
updated count. updated count.
""" """
indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index availableetc "index1,index2,index3" indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index availableetc "index1,index2,index3"
chunk.vector = self.embeddings.embed_query(chunk.page_content) if self.is_multimodal_embedding:
# 火山引擎多模态 Embedding
chunk.vector = self.embeddings.embed_text(chunk.page_content)
else:
chunk.vector = self.embeddings.embed_query(chunk.page_content)
body = { body = {
"script": { "script": {
@@ -454,7 +454,11 @@ class ElasticSearchVector(BaseVector):
def search_by_vector(self, query: str, **kwargs: Any) -> list[DocumentChunk]: def search_by_vector(self, query: str, **kwargs: Any) -> list[DocumentChunk]:
"""Search the nearest neighbors to a vector.""" """Search the nearest neighbors to a vector."""
query_vector = self.embeddings.embed_query(query) if self.is_multimodal_embedding:
# 火山引擎多模态 Embedding
query_vector = self.embeddings.embed_text(query)
else:
query_vector = self.embeddings.embed_query(query)
top_k = kwargs.get("top_k", 1024) top_k = kwargs.get("top_k", 1024)
score_threshold = float(kwargs.get("score_threshold") or 0.3) score_threshold = float(kwargs.get("score_threshold") or 0.3)
indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index availableetc "index1,index2,index3" indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index availableetc "index1,index2,index3"

View File

@@ -109,17 +109,13 @@ class StorageBackend(ABC):
pass pass
@abstractmethod @abstractmethod
async def get_url(self, file_key: str, expires: int = 3600) -> str: async def get_url(
""" self,
Get an access URL for the file. file_key: str,
expires: int = 3600,
Args: file_name: Optional[str] = None
file_key: Unique identifier for the file in the storage system. ) -> str:
expires: URL validity period in seconds (default: 1 hour). """Get an access URL for the file."""
Returns:
URL for accessing the file.
"""
pass pass
async def get_permanent_url(self, file_key: str) -> Optional[str]: async def get_permanent_url(self, file_key: str) -> Optional[str]:

View File

@@ -210,7 +210,12 @@ class LocalStorage(StorageBackend):
cause=e, cause=e,
) )
async def get_url(self, file_key: str, expires: int = 3600) -> str: async def get_url(
self,
file_key: str,
expires: int = 3600,
file_name: Optional[str] = None
) -> str:
""" """
Get an access URL for the file. Get an access URL for the file.
@@ -220,6 +225,7 @@ class LocalStorage(StorageBackend):
Args: Args:
file_key: Unique identifier for the file in the storage system. file_key: Unique identifier for the file in the storage system.
expires: URL validity period in seconds (not used for local storage). expires: URL validity period in seconds (not used for local storage).
file_name: If set, adds Content-Disposition: attachment to force download.
Returns: Returns:
A relative URL path for accessing the file. A relative URL path for accessing the file.

View File

@@ -7,6 +7,7 @@ Storage Service (OSS) using the oss2 SDK.
import io import io
import logging import logging
import urllib.parse
from typing import AsyncIterator, Optional from typing import AsyncIterator, Optional
import oss2 import oss2
@@ -242,24 +243,33 @@ class OSSStorage(StorageBackend):
logger.error(f"Failed to check file existence in OSS {file_key}: {e}") logger.error(f"Failed to check file existence in OSS {file_key}: {e}")
return False return False
async def get_url(self, file_key: str, expires: int = 3600) -> str: async def get_url(
self,
file_key: str,
expires: int = 3600,
file_name: Optional[str] = None,
) -> str:
""" """
Get a presigned URL for accessing the file. Get a presigned URL for accessing the file.
Args: Args:
file_key: Unique identifier for the file in the storage system. file_key: Unique identifier for the file in the storage system.
expires: URL validity period in seconds (default: 1 hour). expires: URL validity period in seconds (default: 1 hour).
file_name: If set, adds Content-Disposition: attachment to force download.
Returns: Returns:
A presigned URL for accessing the file. A presigned URL for accessing the file.
""" """
try: try:
url = self.bucket.sign_url("GET", file_key, expires) params = {}
if file_name:
filename_encoded = urllib.parse.quote(file_name.encode("utf-8"))
params["response-content-disposition"] = f"attachment; filename*=UTF-8''{filename_encoded}"
url = self.bucket.sign_url("GET", file_key, expires, params=params if params else None)
logger.debug(f"Generated presigned URL for {file_key}, expires in {expires}s") logger.debug(f"Generated presigned URL for {file_key}, expires in {expires}s")
return url return url
except Exception as e: except Exception as e:
logger.error(f"Failed to generate presigned URL for {file_key}: {e}") logger.error(f"Failed to generate presigned URL for {file_key}: {e}")
# Return a basic URL format as fallback
return f"https://{self.bucket_name}.{self.endpoint.replace('https://', '').replace('http://', '')}/{file_key}" return f"https://{self.bucket_name}.{self.endpoint.replace('https://', '').replace('http://', '')}/{file_key}"
async def get_permanent_url(self, file_key: str) -> str: async def get_permanent_url(self, file_key: str) -> str:

View File

@@ -6,6 +6,7 @@ using the boto3 SDK.
""" """
import io import io
import urllib.parse
import logging import logging
from typing import AsyncIterator, Optional from typing import AsyncIterator, Optional
@@ -352,31 +353,37 @@ class S3Storage(StorageBackend):
logger.error(f"Failed to check file existence in S3 {file_key}: {e}") logger.error(f"Failed to check file existence in S3 {file_key}: {e}")
return False return False
async def get_url(self, file_key: str, expires: int = 3600) -> str: async def get_url(
self,
file_key: str,
expires: int = 3600,
file_name: Optional[str] = None,
) -> str:
""" """
Get a presigned URL for accessing the file. Get a presigned URL for accessing the file.
Args: Args:
file_key: Unique identifier for the file in the storage system. file_key: Unique identifier for the file in the storage system.
expires: URL validity period in seconds (default: 1 hour). expires: URL validity period in seconds (default: 1 hour).
file_name: If set, adds Content-Disposition: attachment to force download.
Returns: Returns:
A presigned URL for accessing the file. A presigned URL for accessing the file.
""" """
try: try:
params = {"Bucket": self.bucket_name, "Key": file_key}
if file_name:
filename_encoded = urllib.parse.quote(file_name.encode("utf-8"))
params["ResponseContentDisposition"] = f"attachment; filename*=UTF-8''{filename_encoded}"
url = self.client.generate_presigned_url( url = self.client.generate_presigned_url(
"get_object", "get_object",
Params={ Params=params,
"Bucket": self.bucket_name,
"Key": file_key,
},
ExpiresIn=expires, ExpiresIn=expires,
) )
logger.debug(f"Generated presigned URL for {file_key}, expires in {expires}s") logger.debug(f"Generated presigned URL for {file_key}, expires in {expires}s")
return url return url
except Exception as e: except Exception as e:
logger.error(f"Failed to generate presigned URL for {file_key}: {e}") logger.error(f"Failed to generate presigned URL for {file_key}: {e}")
# Return a basic URL format as fallback
return f"https://{self.bucket_name}.s3.{self.region}.amazonaws.com/{file_key}" return f"https://{self.bucket_name}.s3.{self.region}.amazonaws.com/{file_key}"
async def get_permanent_url(self, file_key: str) -> str: async def get_permanent_url(self, file_key: str) -> str:

View File

@@ -9,7 +9,7 @@ from typing import Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from app.core.workflow.adapters.errors import ExceptionDefineition from app.core.workflow.adapters.errors import ExceptionDefinition
from app.schemas.workflow_schema import ( from app.schemas.workflow_schema import (
EdgeDefinition, EdgeDefinition,
NodeDefinition, NodeDefinition,
@@ -40,8 +40,8 @@ class WorkflowParserResult(BaseModel):
edges: list[EdgeDefinition] = Field(default_factory=list) edges: list[EdgeDefinition] = Field(default_factory=list)
nodes: list[NodeDefinition] = Field(default_factory=list) nodes: list[NodeDefinition] = Field(default_factory=list)
variables: list[VariableDefinition] = Field(default_factory=list) variables: list[VariableDefinition] = Field(default_factory=list)
warnings: list[ExceptionDefineition] = Field(default_factory=list) warnings: list[ExceptionDefinition] = Field(default_factory=list)
errors: list[ExceptionDefineition] = Field(default_factory=list) errors: list[ExceptionDefinition] = Field(default_factory=list)
class WorkflowImportResult(BaseModel): class WorkflowImportResult(BaseModel):
@@ -51,8 +51,8 @@ class WorkflowImportResult(BaseModel):
edges: list[EdgeDefinition] = Field(default_factory=list) edges: list[EdgeDefinition] = Field(default_factory=list)
nodes: list[NodeDefinition] = Field(default_factory=list) nodes: list[NodeDefinition] = Field(default_factory=list)
variables: list[VariableDefinition] = Field(default_factory=list) variables: list[VariableDefinition] = Field(default_factory=list)
warnings: list[ExceptionDefineition] = Field(default_factory=list) warnings: list[ExceptionDefinition] = Field(default_factory=list)
errors: list[ExceptionDefineition] = Field(default_factory=list) errors: list[ExceptionDefinition] = Field(default_factory=list)
class BasePlatformAdapter(ABC): class BasePlatformAdapter(ABC):

View File

@@ -9,9 +9,9 @@ from urllib.parse import quote
from app.core.workflow.adapters.base_converter import BaseConverter from app.core.workflow.adapters.base_converter import BaseConverter
from app.core.workflow.adapters.errors import ( from app.core.workflow.adapters.errors import (
UnsupportVariableType, UnsupportedVariableType,
UnknowModelWarning, UnknownModelWarning,
ExceptionDefineition, ExceptionDefinition,
ExceptionType ExceptionType
) )
from app.core.workflow.nodes.assigner.config import AssignmentItem from app.core.workflow.nodes.assigner.config import AssignmentItem
@@ -54,7 +54,7 @@ from app.core.workflow.nodes.http_request.config import (
HttpFormData, HttpFormData,
HttpTimeOutConfig, HttpTimeOutConfig,
HttpRetryConfig, HttpRetryConfig,
HttpErrorDefaultTamplete, HttpErrorDefaultTemplate,
HttpErrorHandleConfig HttpErrorHandleConfig
) )
from app.core.workflow.nodes.if_else.config import ConditionDetail, ConditionBranchConfig from app.core.workflow.nodes.if_else.config import ConditionDetail, ConditionBranchConfig
@@ -108,7 +108,7 @@ class DifyConverter(BaseConverter):
try: try:
return config.model_validate(value) return config.model_validate(value)
except Exception as e: except Exception as e:
self.errors.append(ExceptionDefineition( self.errors.append(ExceptionDefinition(
type=ExceptionType.CONFIG, type=ExceptionType.CONFIG,
node_id=node_id, node_id=node_id,
node_name=node_name, node_name=node_name,
@@ -138,7 +138,7 @@ class DifyConverter(BaseConverter):
var_selector = mapping.get(var_selector, var_selector) var_selector = mapping.get(var_selector, var_selector)
return var_selector return var_selector
def _process_list_variable_litearl(self, variable_selector: list) -> str | None: def _process_list_variable_literal(self, variable_selector: list) -> str | None:
if not self.process_var_selector(".".join(variable_selector)): if not self.process_var_selector(".".join(variable_selector)):
return None return None
return "{{" + self.process_var_selector(".".join(variable_selector)) + "}}" return "{{" + self.process_var_selector(".".join(variable_selector)) + "}}"
@@ -269,7 +269,7 @@ class DifyConverter(BaseConverter):
var_type = self.variable_type_map(var["type"]) var_type = self.variable_type_map(var["type"])
if not var_type: if not var_type:
self.errors.append( self.errors.append(
UnsupportVariableType( UnsupportedVariableType(
scope=node["id"], scope=node["id"],
name=var["variable"], name=var["variable"],
var_type=var["type"], var_type=var["type"],
@@ -281,7 +281,7 @@ class DifyConverter(BaseConverter):
if var_type in ["file", "array[file]"]: if var_type in ["file", "array[file]"]:
self.errors.append( self.errors.append(
ExceptionDefineition( ExceptionDefinition(
type=ExceptionType.VARIABLE, type=ExceptionType.VARIABLE,
node_id=node["id"], node_id=node["id"],
node_name=node_data["title"], node_name=node_data["title"],
@@ -311,7 +311,7 @@ class DifyConverter(BaseConverter):
def convert_question_classifier_node_config(self, node: dict) -> dict: def convert_question_classifier_node_config(self, node: dict) -> dict:
node_data = node["data"] node_data = node["data"]
self.warnings.append( self.warnings.append(
UnknowModelWarning( UnknownModelWarning(
node_id=node["id"], node_id=node["id"],
node_name=node_data["title"], node_name=node_data["title"],
model_name=node_data["model"].get("name") model_name=node_data["model"].get("name")
@@ -327,7 +327,7 @@ class DifyConverter(BaseConverter):
) )
result = QuestionClassifierNodeConfig.model_construct( result = QuestionClassifierNodeConfig.model_construct(
input_variable=self._process_list_variable_litearl(node_data.get("query_variable_selector")), input_variable=self._process_list_variable_literal(node_data.get("query_variable_selector")),
user_supplement_prompt=self.trans_variable_format(node_data.get("instructions", "")), user_supplement_prompt=self.trans_variable_format(node_data.get("instructions", "")),
categories=categories, categories=categories,
).model_dump() ).model_dump()
@@ -337,13 +337,13 @@ class DifyConverter(BaseConverter):
def convert_llm_node_config(self, node: dict) -> dict: def convert_llm_node_config(self, node: dict) -> dict:
node_data = node["data"] node_data = node["data"]
self.warnings.append( self.warnings.append(
UnknowModelWarning( UnknownModelWarning(
node_id=node["id"], node_id=node["id"],
node_name=node_data["title"], node_name=node_data["title"],
model_name=node_data["model"].get("name") model_name=node_data["model"].get("name")
) )
) )
context = self._process_list_variable_litearl(node_data["context"]["variable_selector"]) context = self._process_list_variable_literal(node_data["context"]["variable_selector"])
memory = MemoryWindowSetting( memory = MemoryWindowSetting(
enable=bool(node_data.get("memory")), enable=bool(node_data.get("memory")),
enable_window=bool(node_data.get("memory", {}).get("window", {}).get("enabled", False)), enable_window=bool(node_data.get("memory", {}).get("window", {}).get("enabled", False)),
@@ -367,7 +367,7 @@ class DifyConverter(BaseConverter):
) )
) )
vision = node_data["vision"]["enabled"] vision = node_data["vision"]["enabled"]
vision_input = self._process_list_variable_litearl( vision_input = self._process_list_variable_literal(
node_data["vision"]["configs"]["variable_selector"] node_data["vision"]["configs"]["variable_selector"]
) if vision else None ) if vision else None
result = LLMNodeConfig.model_construct( result = LLMNodeConfig.model_construct(
@@ -433,7 +433,7 @@ class DifyConverter(BaseConverter):
conditions.append( conditions.append(
LoopConditionDetail.model_construct( LoopConditionDetail.model_construct(
operator=self.convert_compare_operator(condition["comparison_operator"]), operator=self.convert_compare_operator(condition["comparison_operator"]),
left=self._process_list_variable_litearl(condition["variable_selector"]), left=self._process_list_variable_literal(condition["variable_selector"]),
right=self.trans_variable_format( right=self.trans_variable_format(
right_value right_value
) if isinstance(right_value, str) and self.is_variable(right_value) else self.convert_variable_type( ) if isinstance(right_value, str) and self.is_variable(right_value) else self.convert_variable_type(
@@ -453,7 +453,7 @@ class DifyConverter(BaseConverter):
right_input_type = variable["value_type"] right_input_type = variable["value_type"]
right_value_type = self.variable_type_map(variable["var_type"]) right_value_type = self.variable_type_map(variable["var_type"])
if right_input_type == ValueInputType.VARIABLE: if right_input_type == ValueInputType.VARIABLE:
right_value = self._process_list_variable_litearl(variable.get("value", "")) right_value = self._process_list_variable_literal(variable.get("value", ""))
else: else:
right_value = self.convert_variable_type(right_value_type, variable.get("value", "")) right_value = self.convert_variable_type(right_value_type, variable.get("value", ""))
loop_variables.append( loop_variables.append(
@@ -475,10 +475,10 @@ class DifyConverter(BaseConverter):
def convert_iteration_node_config(self, node: dict) -> dict: def convert_iteration_node_config(self, node: dict) -> dict:
node_data = node["data"] node_data = node["data"]
result = IterationNodeConfig.model_construct( result = IterationNodeConfig.model_construct(
input=self._process_list_variable_litearl(node_data["iterator_selector"]), input=self._process_list_variable_literal(node_data["iterator_selector"]),
parallel=node_data["is_parallel"], parallel=node_data["is_parallel"],
parallel_count=node_data["parallel_nums"], parallel_count=node_data["parallel_nums"],
output=self._process_list_variable_litearl(node_data["output_selector"]), output=self._process_list_variable_literal(node_data["output_selector"]),
output_type=self.variable_type_map(node_data.get("output_type")), output_type=self.variable_type_map(node_data.get("output_type")),
flatten=node_data["flatten_output"], flatten=node_data["flatten_output"],
).model_dump() ).model_dump()
@@ -494,8 +494,8 @@ class DifyConverter(BaseConverter):
continue continue
assignments.append( assignments.append(
AssignmentItem( AssignmentItem(
variable_selector=self._process_list_variable_litearl(assignment["variable_selector"]), variable_selector=self._process_list_variable_literal(assignment["variable_selector"]),
value=self._process_list_variable_litearl( value=self._process_list_variable_literal(
assignment["value"] assignment["value"]
) if assignment["input_type"] == ValueInputType.VARIABLE else assignment["value"], ) if assignment["input_type"] == ValueInputType.VARIABLE else assignment["value"],
operation=self.convert_assignment_operator(assignment["operation"]) operation=self.convert_assignment_operator(assignment["operation"])
@@ -514,7 +514,7 @@ class DifyConverter(BaseConverter):
input_variables.append( input_variables.append(
InputVariable.model_construct( InputVariable.model_construct(
name=input_variable["variable"], name=input_variable["variable"],
variable=self._process_list_variable_litearl(input_variable["value_selector"]), variable=self._process_list_variable_literal(input_variable["value_selector"]),
) )
) )
@@ -570,7 +570,7 @@ class DifyConverter(BaseConverter):
else: else:
if node_data["body"]["data"]: if node_data["body"]["data"]:
body_content = (node_data["body"]["data"][0].get("value") or body_content = (node_data["body"]["data"][0].get("value") or
self._process_list_variable_litearl(node_data["body"]["data"][0].get("file"))) self._process_list_variable_literal(node_data["body"]["data"][0].get("file")))
else: else:
body_content = "" body_content = ""
@@ -585,7 +585,7 @@ class DifyConverter(BaseConverter):
self.trans_variable_format(key_value[0]) self.trans_variable_format(key_value[0])
] = self.trans_variable_format(key_value[1]) ] = self.trans_variable_format(key_value[1])
else: else:
self.warnings.append(ExceptionDefineition( self.warnings.append(ExceptionDefinition(
type=ExceptionType.CONFIG, type=ExceptionType.CONFIG,
node_id=node["id"], node_id=node["id"],
node_name=node_data["title"], node_name=node_data["title"],
@@ -603,7 +603,7 @@ class DifyConverter(BaseConverter):
self.trans_variable_format(key_value[0]) self.trans_variable_format(key_value[0])
] = self.trans_variable_format(key_value[1]) ] = self.trans_variable_format(key_value[1])
else: else:
self.warnings.append(ExceptionDefineition( self.warnings.append(ExceptionDefinition(
type=ExceptionType.CONFIG, type=ExceptionType.CONFIG,
node_id=node["id"], node_id=node["id"],
node_name=node_data["title"], node_name=node_data["title"],
@@ -625,7 +625,7 @@ class DifyConverter(BaseConverter):
default_header = var["value"] default_header = var["value"]
elif var["key"] == "status_code": elif var["key"] == "status_code":
default_status_code = var["value"] default_status_code = var["value"]
default_value = HttpErrorDefaultTamplete( default_value = HttpErrorDefaultTemplate(
body=default_body, body=default_body,
headers=default_header, headers=default_header,
status_code=default_status_code, status_code=default_status_code,
@@ -668,7 +668,7 @@ class DifyConverter(BaseConverter):
for variable in node_data["variables"]: for variable in node_data["variables"]:
mapping.append(VariablesMappingConfig.model_construct( mapping.append(VariablesMappingConfig.model_construct(
name=variable["variable"], name=variable["variable"],
value=self._process_list_variable_litearl(variable["value_selector"]) value=self._process_list_variable_literal(variable["value_selector"])
)) ))
result = JinjaRenderNodeConfig.model_construct( result = JinjaRenderNodeConfig.model_construct(
template=node_data["template"], template=node_data["template"],
@@ -679,14 +679,14 @@ class DifyConverter(BaseConverter):
def convert_knowledge_node_config(self, node: dict) -> dict: def convert_knowledge_node_config(self, node: dict) -> dict:
node_data = node["data"] node_data = node["data"]
self.warnings.append(ExceptionDefineition( self.warnings.append(ExceptionDefinition(
node_id=node["id"], node_id=node["id"],
node_name=node_data["title"], node_name=node_data["title"],
type=ExceptionType.CONFIG, type=ExceptionType.CONFIG,
detail=f"Please reconfigure the Knowledge Retrieval node.", detail=f"Please reconfigure the Knowledge Retrieval node.",
)) ))
result = KnowledgeRetrievalNodeConfig.model_construct( result = KnowledgeRetrievalNodeConfig.model_construct(
query=self._process_list_variable_litearl(node_data["query_variable_selector"]), query=self._process_list_variable_literal(node_data["query_variable_selector"]),
).model_dump() ).model_dump()
self.config_validate(node["id"], node["data"]["title"], KnowledgeRetrievalNodeConfig, result) self.config_validate(node["id"], node["data"]["title"], KnowledgeRetrievalNodeConfig, result)
@@ -695,7 +695,7 @@ class DifyConverter(BaseConverter):
def convert_parameter_extractor_node_config(self, node: dict) -> dict: def convert_parameter_extractor_node_config(self, node: dict) -> dict:
node_data = node["data"] node_data = node["data"]
self.warnings.append( self.warnings.append(
UnknowModelWarning( UnknownModelWarning(
node_id=node["id"], node_id=node["id"],
node_name=node_data["title"], node_name=node_data["title"],
model_name=node_data["model"].get("name") model_name=node_data["model"].get("name")
@@ -712,7 +712,7 @@ class DifyConverter(BaseConverter):
) )
) )
result = ParameterExtractorNodeConfig.model_construct( result = ParameterExtractorNodeConfig.model_construct(
text=self._process_list_variable_litearl(node_data["query"]), text=self._process_list_variable_literal(node_data["query"]),
params=params, params=params,
prompt=node_data.get("instruction") prompt=node_data.get("instruction")
).model_dump() ).model_dump()
@@ -727,14 +727,14 @@ class DifyConverter(BaseConverter):
group_type = {} group_type = {}
if not advanced_settings or not advanced_settings["group_enabled"]: if not advanced_settings or not advanced_settings["group_enabled"]:
group_variables = [ group_variables = [
self._process_list_variable_litearl(variable) self._process_list_variable_literal(variable)
for variable in node_data["variables"] for variable in node_data["variables"]
] ]
group_type["output"] = node_data["output_type"] group_type["output"] = node_data["output_type"]
else: else:
for group in advanced_settings["groups"]: for group in advanced_settings["groups"]:
group_variables[group["group_name"]] = [ group_variables[group["group_name"]] = [
self._process_list_variable_litearl(variable) self._process_list_variable_literal(variable)
for variable in group["variables"] for variable in group["variables"]
] ]
group_type[group["group_name"]] = group["output_type"] group_type[group["group_name"]] = group["output_type"]
@@ -751,7 +751,7 @@ class DifyConverter(BaseConverter):
def convert_tool_node_config(self, node: dict) -> dict: def convert_tool_node_config(self, node: dict) -> dict:
node_data = node["data"] node_data = node["data"]
self.warnings.append(ExceptionDefineition( self.warnings.append(ExceptionDefinition(
node_id=node["id"], node_id=node["id"],
node_name=node_data["title"], node_name=node_data["title"],
type=ExceptionType.CONFIG, type=ExceptionType.CONFIG,

View File

@@ -12,7 +12,7 @@ from app.core.workflow.adapters.base_adapter import (
WorkflowParserResult WorkflowParserResult
) )
from app.core.workflow.adapters.dify.converter import DifyConverter from app.core.workflow.adapters.dify.converter import DifyConverter
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType from app.core.workflow.adapters.errors import ExceptionDefinition, ExceptionType
from app.core.workflow.nodes.enums import NodeType from app.core.workflow.nodes.enums import NodeType
from app.schemas.workflow_schema import ( from app.schemas.workflow_schema import (
NodeDefinition, NodeDefinition,
@@ -85,7 +85,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
if not all(field in self.config for field in require_fields): if not all(field in self.config for field in require_fields):
return False return False
if self.config.get("app", {}).get("mode") == "workflow": if self.config.get("app", {}).get("mode") == "workflow":
self.errors.append(ExceptionDefineition( self.errors.append(ExceptionDefinition(
type=ExceptionType.PLATFORM, type=ExceptionType.PLATFORM,
detail="workflow mode is not supported" detail="workflow mode is not supported"
)) ))
@@ -111,12 +111,12 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
edge = self._convert_edge(edge) edge = self._convert_edge(edge)
if edge: if edge:
self.edges.append(edge) self.edges.append(edge)
#
for variable in self.config.get("workflow").get("conversation_variables"): for variable in self.config.get("workflow").get("conversation_variables"):
con_var = self._convert_variable(variable) con_var = self._convert_variable(variable)
if variable: if variable:
self.conv_variables.append(con_var) self.conv_variables.append(con_var)
#
# for variables in config.get("workflow").get("environment_variables"): # for variables in config.get("workflow").get("environment_variables"):
# variable = self._convert_variable(variables) # variable = self._convert_variable(variables)
# conv_variables.append(variable) # conv_variables.append(variable)
@@ -152,7 +152,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
"y": node["position"]["y"] + position["y"] "y": node["position"]["y"] + position["y"]
} }
self.errors.append( self.errors.append(
ExceptionDefineition( ExceptionDefinition(
type=ExceptionType.NODE, type=ExceptionType.NODE,
node_id=node_id, node_id=node_id,
detail="parent cycle node not found" detail="parent cycle node not found"
@@ -189,7 +189,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
node_data = node["data"] node_data = node["data"]
converter = self.get_node_convert(node_type) converter = self.get_node_convert(node_type)
if node_type == NodeType.UNKNOWN: if node_type == NodeType.UNKNOWN:
self.errors.append(ExceptionDefineition( self.errors.append(ExceptionDefinition(
type=ExceptionType.NODE, type=ExceptionType.NODE,
node_id=node["id"], node_id=node["id"],
node_name=node["data"]["title"], node_name=node["data"]["title"],
@@ -197,7 +197,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
)) ))
return converter(node) return converter(node)
except Exception as e: except Exception as e:
self.errors.append(ExceptionDefineition( self.errors.append(ExceptionDefinition(
type=ExceptionType.NODE, type=ExceptionType.NODE,
node_id=node["id"], node_id=node["id"],
node_name=node["data"]["title"], node_name=node["data"]["title"],
@@ -207,7 +207,6 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
def _convert_edge(self, edge: dict[str, Any]) -> EdgeDefinition | None: def _convert_edge(self, edge: dict[str, Any]) -> EdgeDefinition | None:
try: try:
source = edge["source"] source = edge["source"]
target = edge["target"] target = edge["target"]
label = None label = None
@@ -230,7 +229,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
label=label, label=label,
) )
except Exception as e: except Exception as e:
self.errors.append(ExceptionDefineition( self.errors.append(ExceptionDefinition(
type=ExceptionType.EDGE, type=ExceptionType.EDGE,
detail=f"convert edge error - {e}", detail=f"convert edge error - {e}",
)) ))
@@ -246,7 +245,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
description=variable.get("description") description=variable.get("description")
) )
except Exception as e: except Exception as e:
self.errors.append(ExceptionDefineition( self.errors.append(ExceptionDefinition(
type=ExceptionType.VARIABLE, type=ExceptionType.VARIABLE,
name=variable.get("name"), name=variable.get("name"),
detail=f"convert variable error - {e}", detail=f"convert variable error - {e}",

View File

@@ -18,7 +18,7 @@ class ExceptionType(StrEnum):
UNKNOWN = "unknown" UNKNOWN = "unknown"
class ExceptionDefineition(BaseModel): class ExceptionDefinition(BaseModel):
type: ExceptionType type: ExceptionType
detail: str detail: str
@@ -29,7 +29,7 @@ class ExceptionDefineition(BaseModel):
name: str | None = None name: str | None = None
class UnknowModelWarning(ExceptionDefineition): class UnknownModelWarning(ExceptionDefinition):
type: ExceptionType = ExceptionType.NODE type: ExceptionType = ExceptionType.NODE
def __init__(self, node_id, node_name, model_name): def __init__(self, node_id, node_name, model_name):
@@ -40,36 +40,36 @@ class UnknowModelWarning(ExceptionDefineition):
) )
class UnknowError(ExceptionDefineition): class UnknownError(ExceptionDefinition):
type: ExceptionType = ExceptionType.UNKNOWN type: ExceptionType = ExceptionType.UNKNOWN
def __init__(self, detail: str, **kwargs): def __init__(self, detail: str, **kwargs):
super().__init__(detail=detail, **kwargs) super().__init__(detail=detail, **kwargs)
class UnsupportPlatform(ExceptionDefineition): class UnsupportedPlatform(ExceptionDefinition):
type: ExceptionType = ExceptionType.PLATFORM type: ExceptionType = ExceptionType.PLATFORM
def __init__(self, platform: str): def __init__(self, platform: str):
super().__init__(detail=f"Unsupport platform {platform}") super().__init__(detail=f"Unsupported platform {platform}")
class UnsupportVariableType(ExceptionDefineition): class UnsupportedVariableType(ExceptionDefinition):
type: ExceptionType = ExceptionType.VARIABLE type: ExceptionType = ExceptionType.VARIABLE
def __init__(self, scope, name, var_type: str, **kwargs): def __init__(self, scope, name, var_type: str, **kwargs):
super().__init__(scope=scope, name=name, detail=f"Unsupport variable type[{var_type}]", **kwargs) super().__init__(scope=scope, name=name, detail=f"Unsupported variable type: [{var_type}]", **kwargs)
class InvalidConfiguration(ExceptionDefineition): class InvalidConfiguration(ExceptionDefinition):
type: ExceptionType = ExceptionType.CONFIG type: ExceptionType = ExceptionType.CONFIG
def __init__(self): def __init__(self):
super().__init__(detail="Invalid workflow configuration format") super().__init__(detail="Invalid workflow configuration format")
class UnsupportNodeType(ExceptionDefineition): class UnsupportedNodeType(ExceptionDefinition):
type: ExceptionType = ExceptionType.NODE type: ExceptionType = ExceptionType.NODE
def __init__(self, node_id: str, node_type: str): def __init__(self, node_id: str, node_type: str):
super().__init__(node_id=node_id, detail=f"Unsupport node Type {node_type}") super().__init__(node_id=node_id, detail=f"Unsupported node type {node_type}")

View File

@@ -11,7 +11,7 @@ from app.core.workflow.adapters.base_adapter import (
BasePlatformAdapter, BasePlatformAdapter,
WorkflowParserResult WorkflowParserResult
) )
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType, UnsupportNodeType from app.core.workflow.adapters.errors import ExceptionDefinition, ExceptionType, UnsupportedNodeType
from app.core.workflow.adapters.memory_bear.memory_bear_converter import MemoryBearConverter from app.core.workflow.adapters.memory_bear.memory_bear_converter import MemoryBearConverter
from app.core.workflow.nodes.enums import NodeType from app.core.workflow.nodes.enums import NodeType
from app.schemas.workflow_schema import ExecutionConfig, NodeDefinition, EdgeDefinition, VariableDefinition from app.schemas.workflow_schema import ExecutionConfig, NodeDefinition, EdgeDefinition, VariableDefinition
@@ -73,7 +73,7 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
try: try:
node_type = self.map_node_type(node["type"]) node_type = self.map_node_type(node["type"])
if node_type == NodeType.UNKNOWN: if node_type == NodeType.UNKNOWN:
self.errors.append(UnsupportNodeType( self.errors.append(UnsupportedNodeType(
node_id=node_id, node_id=node_id,
node_type=node["type"] node_type=node["type"]
)) ))
@@ -85,7 +85,7 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
return NodeDefinition(**node) return NodeDefinition(**node)
except Exception as e: except Exception as e:
self.errors.append(ExceptionDefineition( self.errors.append(ExceptionDefinition(
type=ExceptionType.NODE, type=ExceptionType.NODE,
node_id=node_id, node_id=node_id,
node_name=node_name, node_name=node_name,
@@ -97,14 +97,14 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
def _convert_edge(self, edge: dict[str, Any], valid_node_ids: set) -> EdgeDefinition | None: def _convert_edge(self, edge: dict[str, Any], valid_node_ids: set) -> EdgeDefinition | None:
try: try:
if edge.get("source") not in valid_node_ids or edge.get("target") not in valid_node_ids: if edge.get("source") not in valid_node_ids or edge.get("target") not in valid_node_ids:
self.warnings.append(ExceptionDefineition( self.warnings.append(ExceptionDefinition(
type=ExceptionType.EDGE, type=ExceptionType.EDGE,
detail=f"edge {edge.get('id')} skipped: source or target node not found" detail=f"edge {edge.get('id')} skipped: source or target node not found"
)) ))
return None return None
return EdgeDefinition(**edge) return EdgeDefinition(**edge)
except Exception as e: except Exception as e:
self.errors.append(ExceptionDefineition( self.errors.append(ExceptionDefinition(
type=ExceptionType.EDGE, type=ExceptionType.EDGE,
detail=f"convert edge error - {e}" detail=f"convert edge error - {e}"
)) ))
@@ -115,7 +115,7 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
try: try:
return VariableDefinition(**variable) return VariableDefinition(**variable)
except Exception as e: except Exception as e:
self.warnings.append(ExceptionDefineition( self.warnings.append(ExceptionDefinition(
type=ExceptionType.VARIABLE, type=ExceptionType.VARIABLE,
name=variable.get("name"), name=variable.get("name"),
detail=f"convert variable error - {e}" detail=f"convert variable error - {e}"

View File

@@ -1,6 +1,6 @@
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
from app.core.workflow.adapters.base_converter import BaseConverter from app.core.workflow.adapters.base_converter import BaseConverter
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType from app.core.workflow.adapters.errors import ExceptionDefinition, ExceptionType
from app.core.workflow.nodes.base_config import BaseNodeConfig from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.core.workflow.nodes.configs import ( from app.core.workflow.nodes.configs import (
StartNodeConfig, StartNodeConfig,
@@ -65,7 +65,7 @@ class MemoryBearConverter(BaseConverter):
try: try:
return config_cls.model_validate(value) return config_cls.model_validate(value)
except Exception as e: except Exception as e:
self.errors.append(ExceptionDefineition( self.errors.append(ExceptionDefinition(
type=ExceptionType.CONFIG, type=ExceptionType.CONFIG,
node_id=node_id, node_id=node_id,
node_name=node_name, node_name=node_name,

View File

@@ -7,7 +7,7 @@ import re
import uuid import uuid
from collections import defaultdict from collections import defaultdict
from functools import lru_cache from functools import lru_cache
from typing import Any, Iterable from typing import Any, Iterable, Callable
from langgraph.checkpoint.memory import InMemorySaver from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import START, END from langgraph.graph import START, END
@@ -41,48 +41,31 @@ class GraphBuilder:
self, self,
workflow_config: dict[str, Any], workflow_config: dict[str, Any],
stream: bool = False, stream: bool = False,
subgraph: bool = False, cycle: str = '',
variable_pool: VariablePool | None = None variable_pool: VariablePool | None = None
): ):
self.workflow_config = workflow_config self.workflow_config = workflow_config
self.stream = stream self.stream = stream
self.subgraph = subgraph self.cycle = cycle
self.start_node_id: str | None = None self.start_node_id: str | None = None
self.node_map = {node["id"]: node for node in self.nodes} self.node_map: dict[str, dict] = {}
self.end_node_map: dict[str, StreamOutputConfig] = {} self.end_node_map: dict[str, StreamOutputConfig] = {}
self._find_upstream_activation_dep = lru_cache( self._find_upstream_activation_dep: Callable = self._find_upstream_activation_dep
maxsize=len(self.nodes) * 2
)(self._find_upstream_activation_dep)
if variable_pool: if variable_pool:
self.variable_pool = variable_pool self.variable_pool = variable_pool
else: else:
self.variable_pool = VariablePool() self.variable_pool = VariablePool()
self.graph = StateGraph(WorkflowState) self.graph: StateGraph | None = None
self.add_nodes() self.nodes: list = []
self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges) self.edges: list = []
self.end_nodes = [ self.reachable_nodes: set[str] | None = None
node self.end_nodes: list[dict] = []
for node in self.nodes
if node.get("type") == "end" and node.get("id") in self.reachable_nodes
]
self.add_edges()
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
self._reverse_adj: dict[str, list[dict]] = defaultdict(list) self._reverse_adj: dict[str, list[dict]] = defaultdict(list)
self._build_reverse_adj() self._adj: dict[str, list[str]] = defaultdict(list)
self._analyze_end_node_output()
@property
def nodes(self) -> list[dict[str, Any]]:
return self.workflow_config.get("nodes", [])
@property
def edges(self) -> list[dict[str, Any]]:
return self.workflow_config.get("edges", [])
def get_node_type(self, node_id: str) -> str: def get_node_type(self, node_id: str) -> str:
"""Retrieve the type of node given its ID. """Retrieve the type of node given its ID.
@@ -108,13 +91,14 @@ class GraphBuilder:
result[node[0]].append(node[1]) result[node[0]].append(node[1])
return result return result
def _build_reverse_adj(self): def _build_adj(self):
for edge in self.edges: for edge in self.edges:
if edge["source"] not in self.reachable_nodes: if edge["source"] not in self.reachable_nodes:
continue continue
self._reverse_adj[edge.get("target")].append({ self._reverse_adj[edge.get("target")].append({
"id": edge["source"], "branch": edge.get("label") "id": edge["source"], "branch": edge.get("label")
}) })
self._adj[edge.get("source")].append(edge["target"])
def _find_upstream_activation_dep( def _find_upstream_activation_dep(
self, self,
@@ -302,22 +286,13 @@ class GraphBuilder:
""" """
for node in self.nodes: for node in self.nodes:
node_type = node.get("type") node_type = node.get("type")
if node_type == NodeType.NOTES:
continue
node_id = node.get("id") node_id = node.get("id")
cycle_node = node.get("cycle") if node_id not in self.reachable_nodes:
if cycle_node: continue
# Nodes within a loop subgraph are constructed by CycleGraphNode
if not self.subgraph:
continue
# Record start and end node IDs
if node_type in [NodeType.START, NodeType.CYCLE_START]:
self.start_node_id = node_id
# Create node instance (start and end nodes are also created) # Create node instance (start and end nodes are also created)
# NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph # NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph
node_instance = NodeFactory.create_node(node, self.workflow_config) node_instance = NodeFactory.create_node(node, self.workflow_config, self._adj[node_id])
if node_type in BRANCH_NODES: if node_type in BRANCH_NODES:
@@ -413,11 +388,12 @@ class GraphBuilder:
# Add conditional edges # Add conditional edges
for source_node, branches in conditional_edges.items(): for source_node, branches in conditional_edges.items():
def make_router(src, branch_list): def make_router(src, branch_list):
"""reate a router function for each source node that routes to a NOP node for later merging.""" """Create a router function for each source node that routes to a NOP node for later merging."""
def make_branch_node(node_name, targets): def make_branch_node(node_name, targets):
def node(s): def node(s):
# NOTE: NOP NODE MUST NOT MODIFY STATE # NOTE: NOP NODE USED FOR ROUTING ONLY.
# MUST NOT MUTATE STATE DIRECTLY; ONLY EMIT ACTIVATE SIGNALS.
return { return {
"activate": { "activate": {
node_id: s["activate"][node_name] node_id: s["activate"][node_name]
@@ -504,14 +480,52 @@ class GraphBuilder:
logger.debug(f"Added waiting edge: {sources} -> {target}") logger.debug(f"Added waiting edge: {sources} -> {target}")
# Connect End nodes to the global END node # Connect End nodes to the global END node
for end_node in self.end_nodes: for node in self.reachable_nodes:
end_node_id = end_node.get("id") if not self._adj[node]:
if end_node_id: self.graph.add_edge(node, END)
self.graph.add_edge(end_node_id, END)
logger.debug(f"Added edge: {end_node_id} -> END")
return return
def build(self) -> CompiledStateGraph: def build(self) -> CompiledStateGraph:
nodes = self.workflow_config.get("nodes", [])
edges = self.workflow_config.get("edges", [])
for node in nodes:
if (node.get("cycle") or '') == self.cycle:
node_type = node.get("type")
if node_type in [NodeType.START, NodeType.CYCLE_START]:
self.start_node_id = node.get("id")
elif node_type == NodeType.NOTES:
continue
self.nodes.append(node)
self.node_map[node.get("id")] = node
for edge in edges:
source_in = edge.get("source") in self.node_map
target_in = edge.get("target") in self.node_map
if source_in ^ target_in:
raise ValueError(
f"Cycle node is connected to external node, "
f"source: {edge.get('source')}, target: {edge.get('target')}"
)
if source_in and target_in:
self.edges.append(edge)
self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges)
self.end_nodes = [
node
for node in self.nodes
if node.get("type") == "end" and node.get("id") in self.reachable_nodes
]
self._build_adj()
self._find_upstream_activation_dep: Callable = lru_cache(
maxsize=len(self.nodes)*2
)(self._find_upstream_activation_dep)
self.graph = StateGraph(WorkflowState)
self.add_nodes()
self.add_edges()
self._analyze_end_node_output()
checkpointer = InMemorySaver() checkpointer = InMemorySaver()
self.graph = self.graph.compile(checkpointer=checkpointer) return self.graph.compile(checkpointer=checkpointer)
return self.graph

View File

@@ -2,6 +2,7 @@
# Author: Eternity # Author: Eternity
# @Email: 1533512157@qq.com # @Email: 1533512157@qq.com
# @Time : 2026/2/10 13:33 # @Time : 2026/2/10 13:33
from app.core.workflow.engine.runtime_schema import ExecutionContext
from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.engine.variable_pool import VariablePool
@@ -9,6 +10,7 @@ class WorkflowResultBuilder:
def build_final_output( def build_final_output(
self, self,
result: dict, result: dict,
execution_context: ExecutionContext,
variable_pool: VariablePool, variable_pool: VariablePool,
elapsed_time: float, elapsed_time: float,
final_output: str, final_output: str,
@@ -26,6 +28,8 @@ class WorkflowResultBuilder:
- "node_outputs" (dict): Outputs of executed nodes. - "node_outputs" (dict): Outputs of executed nodes.
- "messages" (list): Conversation messages exchanged during execution. - "messages" (list): Conversation messages exchanged during execution.
- "error" (str, optional): Error message if any node failed. - "error" (str, optional): Error message if any node failed.
execution_context (ExecutionContext): The execution context containing metadata like
execution ID, workspace ID, and user ID.)
variable_pool (VariablePool): Variable Pool variable_pool (VariablePool): Variable Pool
elapsed_time (float): Total execution time in seconds. elapsed_time (float): Total execution time in seconds.
final_output (Any): The aggregated or final output content of the workflow final_output (Any): The aggregated or final output content of the workflow
@@ -48,18 +52,23 @@ class WorkflowResultBuilder:
""" """
node_outputs = result.get("node_outputs", {}) node_outputs = result.get("node_outputs", {})
token_usage = self.aggregate_token_usage(node_outputs) token_usage = self.aggregate_token_usage(node_outputs)
conversation_id = variable_pool.get_value("sys.conversation_id") conversation_vars = {}
sys_vars = {}
if variable_pool:
conversation_vars = variable_pool.get_all_conversation_vars()
sys_vars = variable_pool.get_all_system_vars()
return { return {
"status": "completed" if success else "failed", "status": "completed" if success else "failed",
"output": final_output, "output": final_output,
"variables": { "variables": {
"conv": variable_pool.get_all_conversation_vars(), "conv": conversation_vars,
"sys": variable_pool.get_all_system_vars() "sys": sys_vars
}, },
"node_outputs": node_outputs, "node_outputs": node_outputs,
"messages": result.get("messages", []), "messages": result.get("messages", []),
"conversation_id": conversation_id, "conversation_id": execution_context.conversation_id,
"elapsed_time": elapsed_time, "elapsed_time": elapsed_time,
"token_usage": token_usage, "token_usage": token_usage,
"error": result.get("error"), "error": result.get("error"),

View File

@@ -12,6 +12,7 @@ class ExecutionContext(BaseModel):
execution_id: str execution_id: str
workspace_id: str workspace_id: str
user_id: str user_id: str
conversation_id: str
memory_storage_type: str memory_storage_type: str
user_rag_memory_id: str user_rag_memory_id: str
checkpoint_config: RunnableConfig checkpoint_config: RunnableConfig
@@ -22,6 +23,7 @@ class ExecutionContext(BaseModel):
execution_id: str, execution_id: str,
workspace_id: str, workspace_id: str,
user_id: str, user_id: str,
conversation_id: str,
memory_storage_type: str, memory_storage_type: str,
user_rag_memory_id: str user_rag_memory_id: str
): ):
@@ -29,6 +31,7 @@ class ExecutionContext(BaseModel):
execution_id=execution_id, execution_id=execution_id,
workspace_id=workspace_id, workspace_id=workspace_id,
user_id=user_id, user_id=user_id,
conversation_id=conversation_id,
memory_storage_type=memory_storage_type, memory_storage_type=memory_storage_type,
user_rag_memory_id=user_rag_memory_id, user_rag_memory_id=user_rag_memory_id,

View File

@@ -3,6 +3,7 @@
# @Email: 1533512157@qq.com # @Email: 1533512157@qq.com
# @Time : 2026/2/9 13:51 # @Time : 2026/2/9 13:51
import datetime import datetime
import time
import logging import logging
from typing import Any from typing import Any
@@ -82,13 +83,15 @@ class WorkflowExecutor:
CompiledStateGraph: The compiled and ready-to-run state graph. CompiledStateGraph: The compiled and ready-to-run state graph.
""" """
logger.info(f"Starting workflow graph build: execution_id={self.execution_context.execution_id}") logger.info(f"Starting workflow graph build: execution_id={self.execution_context.execution_id}")
start_time = time.time()
builder = GraphBuilder( builder = GraphBuilder(
self.workflow_config, self.workflow_config,
stream=stream, stream=stream,
) )
self.graph = builder.build()
self.start_node_id = builder.start_node_id self.start_node_id = builder.start_node_id
self.variable_pool = builder.variable_pool self.variable_pool = builder.variable_pool
self.graph = builder.build()
self.stream_coordinator.initialize_end_outputs(builder.end_node_map) self.stream_coordinator.initialize_end_outputs(builder.end_node_map)
self.event_handler = EventStreamHandler( self.event_handler = EventStreamHandler(
@@ -96,7 +99,8 @@ class WorkflowExecutor:
variable_pool=self.variable_pool, variable_pool=self.variable_pool,
execution_id=self.execution_context.execution_id execution_id=self.execution_context.execution_id
) )
logger.info(f"Workflow graph build completed: execution_id={self.execution_context.execution_id}") logger.info(f"Workflow graph build completed: execution_id={self.execution_context.execution_id}, "
f"cost: {time.time() - start_time:.4f}s")
return self.graph return self.graph
@@ -134,94 +138,12 @@ class WorkflowExecutor:
return event.get("data") return event.get("data")
return self.result_builder.build_final_output( return self.result_builder.build_final_output(
{"error": "Workflow execution did not end as expected"}, {"error": "Workflow execution did not end as expected"},
self.execution_context,
self.variable_pool, self.variable_pool,
(datetime.datetime.now() - start).total_seconds(), (datetime.datetime.now() - start).total_seconds(),
"", "",
success=False success=False
) )
# logger.info(f"Starting workflow execution: execution_id={self.execution_context.execution_id}")
#
# start_time = datetime.datetime.now()
#
# # Execute the workflow
# try:
# # Build the workflow graph
# graph = self.build_graph()
#
# # Initialize the variable pool with input data
# await self.variable_initializer.initialize(
# variable_pool=self.variable_pool,
# input_data=input_data,
# execution_context=self.execution_context
# )
# initial_state = self.state_manager.create_initial_state(
# workflow_config=self.workflow_config,
# input_data=input_data,
# execution_context=self.execution_context,
# start_node_id=self.start_node_id
# )
#
# result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config)
#
# # Aggregate output from all End nodes
# full_content = ''
# for end_id in self.stream_coordinator.end_outputs.keys():
# full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False)
#
# # Append messages for user and assistant
# if input_data.get("files"):
# result["messages"].extend(
# [
# {
# "role": "user",
# "content": input_data.get("message", '')
# },
# {
# "role": "user",
# "content": input_data.get("files")
# },
# {
# "role": "assistant",
# "content": full_content
# }
# ]
# )
# else:
# result["messages"].extend(
# [
# {
# "role": "user",
# "content": input_data.get("message", '')
# },
# {
# "role": "assistant",
# "content": full_content
# }
# ]
# )
# # Calculate elapsed time
# end_time = datetime.datetime.now()
# elapsed_time = (end_time - start_time).total_seconds()
#
# logger.info(
# f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}ms")
#
# return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
#
# except Exception as e:
# end_time = datetime.datetime.now()
# elapsed_time = (end_time - start_time).total_seconds()
#
# logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
# exc_info=True)
# return {
# "status": "failed",
# "error": str(e),
# "output": None,
# "node_outputs": {},
# "elapsed_time": elapsed_time,
# "token_usage": None
# }
async def execute_stream( async def execute_stream(
self, self,
@@ -255,7 +177,7 @@ class WorkflowExecutor:
"data": { "data": {
"execution_id": self.execution_context.execution_id, "execution_id": self.execution_context.execution_id,
"workspace_id": self.execution_context.workspace_id, "workspace_id": self.execution_context.workspace_id,
"conversation_id": input_data.get("conversation_id"), "conversation_id": self.execution_context.conversation_id,
"timestamp": int(start_time.timestamp() * 1000) "timestamp": int(start_time.timestamp() * 1000)
} }
} }
@@ -376,6 +298,7 @@ class WorkflowExecutor:
"event": "workflow_end", "event": "workflow_end",
"data": self.result_builder.build_final_output( "data": self.result_builder.build_final_output(
result, result,
self.execution_context,
self.variable_pool, self.variable_pool,
elapsed_time, elapsed_time,
full_content, full_content,
@@ -396,6 +319,7 @@ class WorkflowExecutor:
"event": "workflow_end", "event": "workflow_end",
"data": self.result_builder.build_final_output( "data": self.result_builder.build_final_output(
result, result,
self.execution_context,
self.variable_pool, self.variable_pool,
elapsed_time, elapsed_time,
full_content, full_content,
@@ -432,6 +356,7 @@ async def execute_workflow(
execution_id=execution_id, execution_id=execution_id,
workspace_id=workspace_id, workspace_id=workspace_id,
user_id=user_id, user_id=user_id,
conversation_id=input_data.get("conversation_id"),
memory_storage_type=memory_storage_type, memory_storage_type=memory_storage_type,
user_rag_memory_id=user_rag_memory_id user_rag_memory_id=user_rag_memory_id
) )
@@ -471,6 +396,7 @@ async def execute_workflow_stream(
workspace_id=workspace_id, workspace_id=workspace_id,
user_id=user_id, user_id=user_id,
memory_storage_type=memory_storage_type, memory_storage_type=memory_storage_type,
conversation_id=input_data.get("conversation_id"),
user_rag_memory_id=user_rag_memory_id user_rag_memory_id=user_rag_memory_id
) )
executor = WorkflowExecutor( executor = WorkflowExecutor(

View File

@@ -64,9 +64,7 @@ class AgentNode(BaseNode):
if not release: if not release:
raise ValueError(f"Agent 不存在: {agent_id}") raise ValueError(f"Agent 不存在: {agent_id}")
return release, message return release, message
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:

View File

@@ -14,8 +14,8 @@ logger = logging.getLogger(__name__)
class AssignerNode(BaseNode): class AssignerNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.variable_updater = True self.variable_updater = True
self.typed_config: AssignerNodeConfig | None = None self.typed_config: AssignerNodeConfig | None = None

View File

@@ -28,7 +28,7 @@ class BaseNode(ABC):
All node types should inherit from this class and implement the `execute` method. All node types should inherit from this class and implement the `execute` method.
""" """
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
"""Initialize the node. """Initialize the node.
Args: Args:
@@ -41,6 +41,7 @@ class BaseNode(ABC):
self.node_type = node_config["type"] self.node_type = node_config["type"]
self.cycle = node_config.get("cycle") self.cycle = node_config.get("cycle")
self.node_name = node_config.get("name", self.node_id) self.node_name = node_config.get("name", self.node_id)
self.down_stream_nodes = down_stream_nodes
# 使用 or 运算符处理 None 值 # 使用 or 运算符处理 None 值
self.config = node_config.get("config") or {} self.config = node_config.get("config") or {}
self.error_handling = node_config.get("error_handling") or {} self.error_handling = node_config.get("error_handling") or {}
@@ -93,18 +94,16 @@ class BaseNode(ABC):
dict: A dict with a single key 'activate', mapping node IDs to dict: A dict with a single key 'activate', mapping node IDs to
their activation status (True/False). their activation status (True/False).
""" """
edges = self.workflow_config.get("edges") activate_flag = self.check_activate(state)
under_stream_nodes = [
edge.get("target") if self.node_type not in BRANCH_NODES:
for edge in edges activate = {node_id: activate_flag for node_id in self.down_stream_nodes}
if edge.get("source") == self.node_id and self.node_type not in BRANCH_NODES else:
] activate = {}
return {
"activate": { activate[self.node_id] = activate_flag
node_id: self.check_activate(state)
for node_id in under_stream_nodes return {"activate": activate}
} | {self.node_id: self.check_activate(state)}
}
@abstractmethod @abstractmethod
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
@@ -315,8 +314,8 @@ class BaseNode(ABC):
elapsed_time = (time.time() - start_time) * 1000 elapsed_time = (time.time() - start_time) * 1000
logger.info(f"Node {self.node_id} streaming execution finished, " logger.debug(f"Node {self.node_id} streaming execution finished, "
f"time elapsed: {elapsed_time:.2f}ms, chunks: {chunk_count}") f"time elapsed: {elapsed_time:.2f}ms, chunks: {chunk_count}")
# Extract processed output (call subclass's _extract_output) # Extract processed output (call subclass's _extract_output)
extracted_output = self._extract_output(final_result) extracted_output = self._extract_output(final_result)
@@ -428,8 +427,8 @@ class BaseNode(ABC):
when an error edge exists. If no error edge exists, this method when an error edge exists. If no error edge exists, this method
raises an exception to stop the workflow. raises an exception to stop the workflow.
""" """
# Check if the node has an error edge defined # # Check if the node has an error edge defined
error_edge = self._find_error_edge() # error_edge = self._find_error_edge()
# Extract input data (for logging or audit purposes) # Extract input data (for logging or audit purposes)
input_data = self._extract_input(state, variable_pool) input_data = self._extract_input(state, variable_pool)
@@ -447,27 +446,26 @@ class BaseNode(ABC):
"error": error_message "error": error_message
} }
if error_edge: # if error_edge:
# If an error edge exists, log a warning and continue to error node # # If an error edge exists, log a warning and continue to error node
logger.warning( # logger.warning(
f"Node {self.node_id} execution failed, redirecting to error node: {error_edge['target']}" # f"Node {self.node_id} execution failed, redirecting to error node: {error_edge['target']}"
) # )
return { # return {
"node_outputs": { # "node_outputs": {
self.node_id: node_output # self.node_id: node_output
}, # },
"error": error_message, # "error": error_message,
"error_node": self.node_id # "error_node": self.node_id
} # }
else: # else:
# If no error edge, send the error via stream writer and stop the workflow writer = get_stream_writer()
writer = get_stream_writer() writer({
writer({ "type": "node_error",
"type": "node_error", **node_output
**node_output })
}) logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}")
logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}") raise Exception(f"Node {self.node_id} execution failed: {error_message}")
raise Exception(f"Node {self.node_id} execution failed: {error_message}")
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""Extracts the input data for this node (used for logging or audit). """Extracts the input data for this node (used for logging or audit).
@@ -644,7 +642,7 @@ class BaseNode(ABC):
if content.content_cache.get(f"{provider}_{ModelInfo.is_omni}"): if content.content_cache.get(f"{provider}_{ModelInfo.is_omni}"):
return content.content_cache[f"{provider}_{ModelInfo.is_omni}"] return content.content_cache[f"{provider}_{ModelInfo.is_omni}"]
with get_db_read() as db: with get_db_read() as db:
multimodel_service = MultimodalService(db, api_config=api_config) multimodal_service = MultimodalService(db, api_config=api_config)
file_obj = FileInput( file_obj = FileInput(
type=content.type, type=content.type,
url=content.url, url=content.url,
@@ -653,7 +651,7 @@ class BaseNode(ABC):
upload_file_id=uuid.UUID(content.file_id) if content.file_id else None, upload_file_id=uuid.UUID(content.file_id) if content.file_id else None,
) )
file_obj.set_content(content.get_content()) file_obj.set_content(content.get_content())
message = await multimodel_service.process_files( message = await multimodal_service.process_files(
[file_obj], [file_obj],
) )
content.set_content(file_obj.get_content()) content.set_content(file_obj.get_content())
@@ -661,7 +659,7 @@ class BaseNode(ABC):
content.content_cache[f"{provider}_{ModelInfo.is_omni}"] = message content.content_cache[f"{provider}_{ModelInfo.is_omni}"] = message
return message return message
return None return None
raise TypeError(f'Unexpect input value type - {type(content)}') raise TypeError(f'Unexpected input value type - {type(content)}')
@staticmethod @staticmethod
def process_model_output(content) -> str: def process_model_output(content) -> str:

View File

@@ -51,8 +51,8 @@ console.log(result)
class CodeNode(BaseNode): class CodeNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: CodeNodeConfig | None = None self.typed_config: CodeNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:

View File

@@ -30,17 +30,13 @@ class CycleGraphNode(BaseNode):
It acts as a container and execution controller for a subgraph. It acts as a container and execution controller for a subgraph.
""" """
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
self.cycle_nodes = list() # Nodes belonging to this cycle
self.cycle_edges = list() # Edges connecting nodes within the cycle
self.start_node_id = None # ID of the start node within the cycle self.start_node_id = None # ID of the start node within the cycle
self.graph: StateGraph | CompiledStateGraph | None = None self.graph: StateGraph | CompiledStateGraph | None = None
self.child_variable_pool: VariablePool | None = None self.child_variable_pool: VariablePool | None = None
self.build_graph()
self.iteration_flag = True
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:
outputs = {"__child_state": VariableType.ARRAY_OBJECT} outputs = {"__child_state": VariableType.ARRAY_OBJECT}
@@ -119,11 +115,11 @@ class CycleGraphNode(BaseNode):
else: else:
remain_edges.append(edge) remain_edges.append(edge)
# Update workflow_config by removing cycle nodes and internal edges # # Update workflow_config by removing cycle nodes and internal edges
self.workflow_config["nodes"] = [ # self.workflow_config["nodes"] = [
node for node in nodes if node.get("cycle") != self.node_id # node for node in nodes if node.get("cycle") != self.node_id
] # ]
self.workflow_config["edges"] = remain_edges # self.workflow_config["edges"] = remain_edges
return cycle_nodes, cycle_edges return cycle_nodes, cycle_edges
@@ -137,18 +133,18 @@ class CycleGraphNode(BaseNode):
3. Compile the graph for runtime execution 3. Compile the graph for runtime execution
""" """
from app.core.workflow.engine.graph_builder import GraphBuilder from app.core.workflow.engine.graph_builder import GraphBuilder
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
self.child_variable_pool = VariablePool() self.child_variable_pool = VariablePool()
builder = GraphBuilder( builder = GraphBuilder(
{ {
"nodes": self.cycle_nodes, "nodes": self.cycle_nodes,
"edges": self.cycle_edges, "edges": self.cycle_edges,
}, },
subgraph=True, variable_pool=self.child_variable_pool,
variable_pool=self.child_variable_pool cycle=self.node_id
) )
self.start_node_id = builder.start_node_id
self.graph = builder.build() self.graph = builder.build()
self.start_node_id = builder.start_node_id
self.child_variable_pool = builder.variable_pool self.child_variable_pool = builder.variable_pool
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
@@ -169,6 +165,7 @@ class CycleGraphNode(BaseNode):
Raises: Raises:
RuntimeError: If the node type is unsupported. RuntimeError: If the node type is unsupported.
""" """
self.build_graph()
if self.node_type == NodeType.LOOP: if self.node_type == NodeType.LOOP:
return await LoopRuntime( return await LoopRuntime(
start_id=self.start_node_id, start_id=self.start_node_id,
@@ -194,6 +191,7 @@ class CycleGraphNode(BaseNode):
raise RuntimeError("Unknown cycle node type") raise RuntimeError("Unknown cycle node type")
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool): async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
self.build_graph()
if self.node_type == NodeType.LOOP: if self.node_type == NodeType.LOOP:
yield { yield {
"__final__": True, "__final__": True,

View File

@@ -1,9 +1,7 @@
"""End 节点配置""" """End 节点配置"""
from pydantic import Field from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.core.workflow.variable.base_variable import VariableType
class EndNodeConfig(BaseNodeConfig): class EndNodeConfig(BaseNodeConfig):

View File

@@ -36,8 +36,6 @@ class EndNode(BaseNode):
Returns: Returns:
最终输出字符串 最终输出字符串
""" """
logger.info(f"节点 {self.node_id} (End) 开始执行")
# 获取配置的输出模板 # 获取配置的输出模板
output_template = self.config.get("output") output_template = self.config.get("output")
@@ -46,11 +44,4 @@ class EndNode(BaseNode):
output = self._render_template(output_template, variable_pool, strict=False) output = self._render_template(output_template, variable_pool, strict=False)
else: else:
output = "" output = ""
# 统计信息(用于日志)
node_outputs = state.get("node_outputs", {})
total_nodes = len(node_outputs)
logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
return output return output

View File

@@ -28,7 +28,7 @@ class NodeType(StrEnum):
NOTES = "notes" NOTES = "notes"
BRANCH_NODES = [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER] BRANCH_NODES = frozenset({NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER})
class ComparisonOperator(StrEnum): class ComparisonOperator(StrEnum):

View File

@@ -115,7 +115,7 @@ class HttpRetryConfig(BaseModel):
) )
class HttpErrorDefaultTamplete(BaseModel): class HttpErrorDefaultTemplate(BaseModel):
body: str = Field( body: str = Field(
default="", default="",
description="Default body returned on HTTP error", description="Default body returned on HTTP error",
@@ -143,7 +143,7 @@ class HttpErrorHandleConfig(BaseModel):
description="Error handling strategy: 'none', 'default', or 'branch'", description="Error handling strategy: 'none', 'default', or 'branch'",
) )
default: HttpErrorDefaultTamplete | None = Field( default: HttpErrorDefaultTemplate | None = Field(
default=None, default=None,
description="Default response template for error handling", description="Default response template for error handling",
) )

View File

@@ -16,7 +16,7 @@ from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput
from app.core.workflow.utils.file_processer import mime_to_file_type from app.core.workflow.utils.file_processor import mime_to_file_type
from app.core.workflow.variable.base_variable import VariableType, FileObject from app.core.workflow.variable.base_variable import VariableType, FileObject
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
from app.schemas import FileType, TransferMethod from app.schemas import FileType, TransferMethod
@@ -157,8 +157,8 @@ class HttpRequestNode(BaseNode):
or a branch identifier string when error branching is enabled. or a branch identifier string when error branching is enabled.
""" """
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: HttpRequestNodeConfig | None = None self.typed_config: HttpRequestNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:

View File

@@ -14,8 +14,8 @@ logger = logging.getLogger(__name__)
class IfElseNode(BaseNode): class IfElseNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: IfElseNodeConfig | None = None self.typed_config: IfElseNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:

View File

@@ -12,8 +12,8 @@ logger = logging.getLogger(__name__)
class JinjaRenderNode(BaseNode): class JinjaRenderNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: JinjaRenderNodeConfig | None = None self.typed_config: JinjaRenderNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:

View File

@@ -21,8 +21,8 @@ logger = logging.getLogger(__name__)
class KnowledgeRetrievalNode(BaseNode): class KnowledgeRetrievalNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: KnowledgeRetrievalNodeConfig | None = None self.typed_config: KnowledgeRetrievalNodeConfig | None = None
self.vector_service: ElasticSearchVector | None = None self.vector_service: ElasticSearchVector | None = None

View File

@@ -70,8 +70,8 @@ class LLMNode(BaseNode):
- ai/assistant: AI 消息AIMessage - ai/assistant: AI 消息AIMessage
""" """
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: LLMNodeConfig | None = None self.typed_config: LLMNodeConfig | None = None
self.messages = [] self.messages = []

View File

@@ -14,8 +14,8 @@ from app.tasks import write_message_task
class MemoryReadNode(BaseNode): class MemoryReadNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: MemoryReadNodeConfig | None = None self.typed_config: MemoryReadNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:
@@ -45,8 +45,8 @@ class MemoryReadNode(BaseNode):
class MemoryWriteNode(BaseNode): class MemoryWriteNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: MemoryWriteNodeConfig | None = None self.typed_config: MemoryWriteNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:

View File

@@ -104,13 +104,15 @@ class NodeFactory:
def create_node( def create_node(
cls, cls,
node_config: dict[str, Any], node_config: dict[str, Any],
workflow_config: dict[str, Any] workflow_config: dict[str, Any],
down_stream_nodes: list[str]
) -> WorkflowNode | None: ) -> WorkflowNode | None:
"""创建节点实例 """创建节点实例
Args: Args:
node_config: 节点配置 node_config: 节点配置
workflow_config: 工作流配置 workflow_config: 工作流配置
down_stream_nodes: 下游节点
Returns: Returns:
节点实例或 None对于不支持的节点类型 节点实例或 None对于不支持的节点类型
@@ -127,7 +129,7 @@ class NodeFactory:
# 创建节点实例 # 创建节点实例
logger.debug(f"create node instance: {node_config.get('id')} (type={node_type})") logger.debug(f"create node instance: {node_config.get('id')} (type={node_type})")
return node_class(node_config, workflow_config) return node_class(node_config, workflow_config, down_stream_nodes)
@classmethod @classmethod
def get_supported_types(cls) -> list[str]: def get_supported_types(cls) -> list[str]:

View File

@@ -21,8 +21,8 @@ logger = logging.getLogger(__name__)
class ParameterExtractorNode(BaseNode): class ParameterExtractorNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: ParameterExtractorNodeConfig | None = None self.typed_config: ParameterExtractorNodeConfig | None = None
self.response_metadata = {} self.response_metadata = {}

View File

@@ -22,8 +22,8 @@ DEFAULT_EMPTY_QUESTION_CASE = f"{DEFAULT_CASE_PREFIX}1"
class QuestionClassifierNode(BaseNode): class QuestionClassifierNode(BaseNode):
"""问题分类器节点""" """问题分类器节点"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: QuestionClassifierNodeConfig | None = None self.typed_config: QuestionClassifierNodeConfig | None = None
self.category_to_case_map = {} self.category_to_case_map = {}
self.response_metadata = {} self.response_metadata = {}

View File

@@ -27,14 +27,8 @@ class StartNode(BaseNode):
注意:变量的验证和默认值处理由 Executor 在初始化时完成。 注意:变量的验证和默认值处理由 Executor 在初始化时完成。
""" """
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
"""初始化 Start 节点 super().__init__(node_config, workflow_config, down_stream_nodes)
Args:
node_config: 节点配置
workflow_config: 工作流配置
"""
super().__init__(node_config, workflow_config)
# 解析并验证配置 # 解析并验证配置
self.typed_config: StartNodeConfig | None = None self.typed_config: StartNodeConfig | None = None
@@ -62,7 +56,6 @@ class StartNode(BaseNode):
包含系统参数、会话变量和自定义变量的字典 包含系统参数、会话变量和自定义变量的字典
""" """
self.typed_config = StartNodeConfig(**self.config) self.typed_config = StartNodeConfig(**self.config)
logger.info(f"节点 {self.node_id} (Start) 开始执行")
# 处理自定义变量(传入 pool 避免重复创建) # 处理自定义变量(传入 pool 避免重复创建)
custom_vars = self._process_custom_variables(variable_pool) custom_vars = self._process_custom_variables(variable_pool)
@@ -77,9 +70,9 @@ class StartNode(BaseNode):
**custom_vars # 自定义变量作为节点输出的一部分 **custom_vars # 自定义变量作为节点输出的一部分
} }
logger.info( logger.debug(
f"节点 {self.node_id} (Start) 执行完成," f"Node {self.node_id} (Start) execution completed, "
f"输出了 {len(custom_vars)} 个自定义变量" f"outputting {len(custom_vars)} custom variables"
) )
return result return result

View File

@@ -20,8 +20,8 @@ TEMPLATE_PATTERN = re.compile(r"\{\{.*?}}")
class ToolNode(BaseNode): class ToolNode(BaseNode):
"""工具节点""" """工具节点"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: ToolNodeConfig | None = None self.typed_config: ToolNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:

View File

@@ -12,8 +12,8 @@ logger = logging.getLogger(__name__)
class VariableAggregatorNode(BaseNode): class VariableAggregatorNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: VariableAggregatorNodeConfig | None = None self.typed_config: VariableAggregatorNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:

View File

@@ -153,7 +153,8 @@ class TemplateRenderer:
# 全局渲染器实例(严格模式) # 全局渲染器实例(严格模式)
_default_renderer = TemplateRenderer(strict=True) _strict_renderer = TemplateRenderer(strict=True)
_lenient_renderer = TemplateRenderer(strict=False)
def render_template( def render_template(
@@ -184,7 +185,7 @@ def render_template(
... ) ... )
'请分析: 这是一段文本' '请分析: 这是一段文本'
""" """
renderer = TemplateRenderer(strict=strict) renderer = _strict_renderer if strict else _lenient_renderer
return renderer.render(template, conv_vars, node_outputs, system_vars) return renderer.render(template, conv_vars, node_outputs, system_vars)
@@ -197,4 +198,4 @@ def validate_template(template: str) -> list[str]:
Returns: Returns:
错误列表 错误列表
""" """
return _default_renderer.validate(template) return _strict_renderer.validate(template)

View File

@@ -6,6 +6,7 @@
import copy import copy
import logging import logging
from collections import defaultdict, deque
from typing import Any, Union, TYPE_CHECKING from typing import Any, Union, TYPE_CHECKING
from app.core.workflow.nodes.enums import NodeType from app.core.workflow.nodes.enums import NodeType
@@ -119,7 +120,6 @@ class WorkflowValidator:
errors = [] errors = []
graphs = cls.get_subgraph(workflow_config) graphs = cls.get_subgraph(workflow_config)
logger.info(graphs)
for index, graph in enumerate(graphs): for index, graph in enumerate(graphs):
nodes = graph.get("nodes", []) nodes = graph.get("nodes", [])
edges = graph.get("edges", []) edges = graph.get("edges", [])
@@ -183,7 +183,7 @@ class WorkflowValidator:
has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges) has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges)
if has_cycle: if has_cycle:
errors.append( errors.append(
f"工作流存在循环依赖(请使用 loop 节点实现循环): {' -> '.join(cycle_path)}" f"工作流存在循环依赖(请使用 loop/iteration 节点实现循环): {' -> '.join(cycle_path)}"
) )
# 8. 验证变量名 # 8. 验证变量名
@@ -204,18 +204,18 @@ class WorkflowValidator:
Returns: Returns:
可达节点 ID 集合 可达节点 ID 集合
""" """
adj = defaultdict(list)
for edge in edges:
adj[edge["source"]].append(edge["target"])
reachable = {start_id} reachable = {start_id}
queue = [start_id] queue = deque([start_id])
while queue: while queue:
current = queue.pop(0) current = queue.popleft()
for edge in edges: for target in adj[current]:
if edge.get("source") == current: if target not in reachable:
target = edge.get("target") reachable.add(target)
if target and target not in reachable: queue.append(target)
reachable.add(target)
queue.append(target)
return reachable return reachable
@staticmethod @staticmethod
@@ -229,10 +229,6 @@ class WorkflowValidator:
Returns: Returns:
(has_cycle, cycle_path): 是否有循环和循环路径 (has_cycle, cycle_path): 是否有循环和循环路径
""" """
# 排除 loop 类型的节点
loop_nodes = {n["id"] for n in nodes if n.get("type") == "loop"}
# 构建邻接表(排除 loop 节点的边和错误边)
graph: dict[str, list[str]] = {} graph: dict[str, list[str]] = {}
for edge in edges: for edge in edges:
source = edge.get("source") source = edge.get("source")
@@ -243,10 +239,6 @@ class WorkflowValidator:
if edge_type == "error": if edge_type == "error":
continue continue
# 如果涉及 loop 节点,跳过
if source in loop_nodes or target in loop_nodes:
continue
if source and target: if source and target:
if source not in graph: if source not in graph:
graph[source] = [] graph[source] = []

View File

@@ -54,7 +54,7 @@ class DictVariable(BaseVariable):
def valid_value(self, value) -> dict: def valid_value(self, value) -> dict:
if not isinstance(value, dict): if not isinstance(value, dict):
raise TypeError(f"Value must be a dict - {type(value)}:{value}") raise TypeError(f"Value must be a dict - {type(value)}:{value}")
return value return value
def to_literal(self) -> str: def to_literal(self) -> str:

View File

@@ -27,9 +27,9 @@ class ModelType(StrEnum):
RERANK = "rerank" RERANK = "rerank"
# TTS = "tts" # TTS = "tts"
# SPEECH2TEXT = "speech2text" # SPEECH2TEXT = "speech2text"
# IMAGE = "image" IMAGE = "image"
# AUDIO = "audio" # AUDIO = "audio"
# VISION = "vision" VIDEO = "video"
class ModelProvider(StrEnum): class ModelProvider(StrEnum):
@@ -46,6 +46,7 @@ class ModelProvider(StrEnum):
XINFERENCE = "xinference" XINFERENCE = "xinference"
GPUSTACK = "gpustack" GPUSTACK = "gpustack"
BEDROCK = "bedrock" BEDROCK = "bedrock"
VOLCANO = "volcano"
COMPOSITE = "composite" COMPOSITE = "composite"

View File

@@ -23,6 +23,17 @@ class Tenants(Base):
# 国际化语言配置字段 # 国际化语言配置字段
default_language = Column(String(10), nullable=False, default='zh', server_default='zh', index=True) # 租户默认语言 default_language = Column(String(10), nullable=False, default='zh', server_default='zh', index=True) # 租户默认语言
supported_languages = Column(ARRAY(String(10)), nullable=False, default=lambda: ['zh', 'en'], server_default=text("'{zh,en}'")) # 租户支持的语言列表 supported_languages = Column(ARRAY(String(10)), nullable=False, default=lambda: ['zh', 'en'], server_default=text("'{zh,en}'")) # 租户支持的语言列表
# 租户联系信息
contact_name = Column(String(100), nullable=True) # 联系人姓名
contact_email = Column(String(255), nullable=True) # 联系人邮箱
contact_phone = Column(String(50), nullable=True) # 联系人电话
# 租户套餐信息
plan = Column(String(50), nullable=True) # 套餐类型
plan_expired_at = Column(DateTime, nullable=True) # 套餐到期时间
api_ops_rate_limit = Column(String(100), nullable=True) # API 调用频率限制
status = Column(String(50), nullable=True, default='active') # 租户状态
# Relationship to users - one tenant has many users # Relationship to users - one tenant has many users
users = relationship("User", back_populates="tenant") users = relationship("User", back_populates="tenant")

View File

@@ -439,7 +439,6 @@ class ModelConfigRepository:
ModelConfig.is_public ModelConfig.is_public
), ),
ModelConfig.provider == provider, ModelConfig.provider == provider,
ModelConfig.is_active,
~ModelConfig.is_composite ~ModelConfig.is_composite
) )
).all() ).all()

View File

@@ -325,27 +325,30 @@ class FileStorageService:
) )
raise raise
async def get_file_url(self, file_key: str, expires: int = 3600) -> str: async def get_file_url(
self,
file_key: str,
expires: int = 3600,
file_name: Optional[str] = None,
) -> str:
""" """
Get an access URL for a file. Get an access URL for a file.
Args: Args:
file_key: The file key. file_key: The file key.
expires: URL validity period in seconds (default: 1 hour). expires: URL validity period in seconds (default: 1 hour).
file_name: If set, adds Content-Disposition: attachment to force download.
Returns: Returns:
URL for accessing the file. URL for accessing the file.
""" """
logger.debug(f"Getting file URL: file_key={file_key}, expires={expires}s") logger.debug(f"Getting file URL: file_key={file_key}, expires={expires}s")
try: try:
url = await self.storage.get_url(file_key, expires) url = await self.storage.get_url(file_key, expires, file_name=file_name)
logger.debug(f"File URL generated: file_key={file_key}") logger.debug(f"File URL generated: file_key={file_key}")
return url return url
except Exception as e: except Exception as e:
logger.error( logger.error(f"Error getting file URL: file_key={file_key}, error={str(e)}")
f"Error getting file URL: file_key={file_key}, error={str(e)}"
)
raise raise

View File

@@ -0,0 +1,162 @@
"""
图片和视频生成服务
提供统一的生成接口,支持多种 Provider
"""
from typing import Dict, Any, Optional
from sqlalchemy.orm import Session
import uuid
from app.core.models import RedBearModelConfig, RedBearImageGenerator, RedBearVideoGenerator
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.models.models_model import ModelType
from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository
from app.services.model_service import ModelApiKeyService
class GenerationService:
"""生成服务"""
def __init__(self, db: Session):
self.db = db
async def generate_image(
self,
model_config_id: str,
prompt: str,
size: Optional[str] = "2k",
**kwargs
) -> Dict[str, Any]:
"""
生成图片
Args:
model_config_id: 模型配置ID
prompt: 提示词
size: 图片尺寸
**kwargs: 其他参数
Returns:
生成结果
"""
# 获取模型配置
model_config = ModelConfigRepository.get_by_id(self.db, uuid.UUID(model_config_id))
if not model_config:
raise BusinessException("模型配置不存在", code=BizCode.NOT_FOUND)
if model_config.type != ModelType.IMAGE:
raise BusinessException(
f"模型类型错误,期望 {ModelType.IMAGE},实际 {model_config.type}",
code=BizCode.INVALID_PARAMETER
)
# 获取 API Key
api_key_info = ModelApiKeyService.get_available_api_key(self.db, uuid.UUID(model_config_id))
if not api_key_info:
raise BusinessException("没有可用的 API Key", code=BizCode.NOT_FOUND)
# 创建配置
config = RedBearModelConfig(
model_name=api_key_info.model_name,
provider=api_key_info.provider,
api_key=api_key_info.api_key,
base_url=api_key_info.api_base,
extra_params=api_key_info.config or {}
)
# 生成图片
generator = RedBearImageGenerator(config)
result = await generator.agenerate(prompt, size, **kwargs)
return result
async def generate_video(
self,
model_config_id: str,
prompt: str,
duration: Optional[int] = None,
**kwargs
) -> Dict[str, Any]:
"""
生成视频
Args:
model_config_id: 模型配置ID
prompt: 提示词
duration: 视频时长(秒)
**kwargs: 其他参数
Returns:
生成结果包含任务ID
"""
# 获取模型配置
model_config = ModelConfigRepository.get_by_id(self.db, uuid.UUID(model_config_id))
if not model_config:
raise BusinessException("模型配置不存在", code=BizCode.NOT_FOUND)
if model_config.type != ModelType.VIDEO:
raise BusinessException(
f"模型类型错误,期望 {ModelType.VIDEO},实际 {model_config.type}",
code=BizCode.INVALID_PARAMETER
)
# 获取 API Key
api_key_info = ModelApiKeyService.get_available_api_key(self.db, uuid.UUID(model_config_id))
if not api_key_info:
raise BusinessException("没有可用的 API Key", code=BizCode.NOT_FOUND)
# 创建配置
config = RedBearModelConfig(
model_name=api_key_info.model_name,
provider=api_key_info.provider,
api_key=api_key_info.api_key,
base_url=api_key_info.api_base,
extra_params=api_key_info.config or {}
)
# 生成视频
generator = RedBearVideoGenerator(config)
result = await generator.agenerate(prompt, duration, **kwargs)
return result
async def get_video_task_status(
self,
model_config_id: str,
task_id: str
) -> Dict[str, Any]:
"""
查询视频生成任务状态
Args:
model_config_id: 模型配置ID
task_id: 任务ID
Returns:
任务状态信息
"""
# 获取模型配置
model_config = ModelConfigRepository.get_by_id(self.db, uuid.UUID(model_config_id))
if not model_config:
raise BusinessException("模型配置不存在", code=BizCode.NOT_FOUND)
# 获取 API Key
api_key_info = ModelApiKeyService.get_available_api_key(self.db, uuid.UUID(model_config_id))
if not api_key_info:
raise BusinessException("没有可用的 API Key", code=BizCode.NOT_FOUND)
# 创建配置
config = RedBearModelConfig(
model_name=api_key_info.model_name,
provider=api_key_info.provider,
api_key=api_key_info.api_key,
base_url=api_key_info.api_base,
extra_params=api_key_info.config or {}
)
# 查询任务状态
generator = RedBearVideoGenerator(config)
result = await generator.aget_task_status(task_id)
return result

View File

@@ -357,6 +357,7 @@ class MemoryAgentService:
if file_object is None: if file_object is None:
continue continue
message["file_content"].append((file_object, file["type"])) message["file_content"].append((file_object, file["type"]))
logger.info(messages)
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
try: try:
@@ -606,7 +607,7 @@ class MemoryAgentService:
retrieved_content.append({query: statements}) retrieved_content.append({query: statements})
# 如果 retrieved_content 为空,设置为空字符串 # 如果 retrieved_content 为空,设置为空字符串
if retrieved_content == []: if not retrieved_content:
retrieved_content = '' retrieved_content = ''
# 只有当回答不是"信息不足"且不是快速检索时才保存 # 只有当回答不是"信息不足"且不是快速检索时才保存

View File

@@ -154,10 +154,17 @@ class ModelConfigService:
} }
elif model_type_lower == "embedding": elif model_type_lower == "embedding":
# Embedding 模型验证(在线程中运行同步方法) # Embedding 模型验证
# 统一使用 RedBearEmbeddings自动支持火山引擎多模态
embedding = RedBearEmbeddings(model_config) embedding = RedBearEmbeddings(model_config)
test_texts = [test_message, "测试文本"] test_texts = [test_message, "测试文本"]
vectors = await asyncio.to_thread(embedding.embed_documents, test_texts)
# 火山引擎使用 embed_batch其他使用 embed_documents
if provider.lower() == "volcano":
vectors = await asyncio.to_thread(embedding.embed_batch, test_texts)
else:
vectors = await asyncio.to_thread(embedding.embed_documents, test_texts)
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
return { return {
@@ -193,6 +200,56 @@ class ModelConfigService:
}, },
"error": None "error": None
} }
elif model_type_lower == "image":
# 图片生成模型验证
from app.core.models.generation import RedBearImageGenerator
generator = RedBearImageGenerator(model_config)
result = await generator.agenerate(
prompt="a cute panda",
size="2K"
)
elapsed_time = time.time() - start_time
logger.info(f"成功生成图片,结果: {result}")
return {
"valid": True,
"message": "图片生成模型配置验证成功",
"response": f"成功生成图片,结果: {result}",
"elapsed_time": elapsed_time,
"usage": {
"prompt_length": len("a cute panda"),
"image_count": 1
},
"error": None
}
elif model_type_lower == "video":
# 视频生成模型验证
from app.core.models.generation import RedBearVideoGenerator
generator = RedBearVideoGenerator(model_config)
result = await generator.agenerate(
prompt="a cute panda playing in bamboo forest",
duration=5
)
elapsed_time = time.time() - start_time
# 视频生成是异步任务返回任务ID
task_id = result.get("task_id") if isinstance(result, dict) else None
return {
"valid": True,
"message": "视频生成模型配置验证成功",
"response": f"成功创建视频生成任务任务ID: {task_id}",
"elapsed_time": elapsed_time,
"usage": {
"prompt_length": len("a cute panda playing in bamboo forest"),
"task_id": task_id
},
"error": None
}
else: else:
return { return {

View File

@@ -294,6 +294,7 @@ PROVIDER_STRATEGIES = {
"bedrock": BedrockFormatStrategy, "bedrock": BedrockFormatStrategy,
"anthropic": BedrockFormatStrategy, "anthropic": BedrockFormatStrategy,
"openai": OpenAIFormatStrategy, "openai": OpenAIFormatStrategy,
"volcano": OpenAIFormatStrategy,
} }

View File

@@ -250,6 +250,20 @@ def deactivate_user(db: Session, user_id_to_deactivate: uuid.UUID, current_user:
} }
) )
# 检查是否为租户联系人
from app.models.tenant_model import Tenants
tenant = db.query(Tenants).filter(Tenants.id == db_user.tenant_id).first()
if tenant and tenant.contact_email and tenant.contact_email == db_user.email:
business_logger.warning(f"尝试停用租户联系人: {db_user.email}, tenant_id={db_user.tenant_id}")
raise BusinessException(
"该管理员是租户联系人,请先在租户信息中更换联系邮箱,再禁用此管理员",
code=BizCode.FORBIDDEN,
context={
"user_id": str(user_id_to_deactivate),
"tenant_id": str(db_user.tenant_id)
}
)
# 停用用户 # 停用用户
business_logger.debug(f"执行用户停用: {db_user.username} (ID: {user_id_to_deactivate})") business_logger.debug(f"执行用户停用: {db_user.username} (ID: {user_id_to_deactivate})")
db_user.is_active = False db_user.is_active = False

View File

@@ -12,7 +12,7 @@ from app.aioRedis import aio_redis_set, aio_redis_get
from app.core.config import settings from app.core.config import settings
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
from app.core.workflow.adapters.base_adapter import WorkflowImportResult, WorkflowParserResult from app.core.workflow.adapters.base_adapter import WorkflowImportResult, WorkflowParserResult
from app.core.workflow.adapters.errors import UnsupportPlatform, InvalidConfiguration from app.core.workflow.adapters.errors import UnsupportedPlatform, InvalidConfiguration
from app.core.workflow.adapters.registry import PlatformAdapterRegistry from app.core.workflow.adapters.registry import PlatformAdapterRegistry
from app.schemas import AppCreate from app.schemas import AppCreate
from app.schemas.workflow_schema import WorkflowConfigCreate from app.schemas.workflow_schema import WorkflowConfigCreate
@@ -46,7 +46,7 @@ class WorkflowImportService:
success=False, success=False,
temp_id=None, temp_id=None,
workflow_id=None, workflow_id=None,
errors=[UnsupportPlatform(platform=platform)] errors=[UnsupportedPlatform(platform=platform)]
) )
adapter = self.registry.get_adapter(platform, config) adapter = self.registry.get_adapter(platform, config)

View File

@@ -0,0 +1,42 @@
"""202603252115
Revision ID: 1ea8fe97b5b7
Revises: e28bcc212da5
Create Date: 2026-03-25 21:14:41.825048
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '1ea8fe97b5b7'
down_revision: Union[str, None] = 'e28bcc212da5'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('tenants', sa.Column('contact_name', sa.String(length=100), nullable=True))
op.add_column('tenants', sa.Column('contact_email', sa.String(length=255), nullable=True))
op.add_column('tenants', sa.Column('contact_phone', sa.String(length=50), nullable=True))
op.add_column('tenants', sa.Column('plan', sa.String(length=50), nullable=True))
op.add_column('tenants', sa.Column('plan_expired_at', sa.DateTime(), nullable=True))
op.add_column('tenants', sa.Column('api_ops_rate_limit', sa.String(length=100), nullable=True))
op.add_column('tenants', sa.Column('status', sa.String(length=50), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('tenants', 'status')
op.drop_column('tenants', 'api_ops_rate_limit')
op.drop_column('tenants', 'plan_expired_at')
op.drop_column('tenants', 'plan')
op.drop_column('tenants', 'contact_phone')
op.drop_column('tenants', 'contact_email')
op.drop_column('tenants', 'contact_name')
# ### end Alembic commands ###

View File

@@ -147,6 +147,7 @@ dependencies = [
"modelscope>=1.34.0", "modelscope>=1.34.0",
"python-magic>=0.4.14; sys_platform == 'linux' or sys_platform == 'darwin'", "python-magic>=0.4.14; sys_platform == 'linux' or sys_platform == 'darwin'",
"python-magic-bin>=0.4.14; sys_platform=='win32'", "python-magic-bin>=0.4.14; sys_platform=='win32'",
"volcengine-python-sdk[ark]==5.0.19"
] ]
[tool.pytest.ini_options] [tool.pytest.ini_options]