diff --git a/api/app/controllers/file_storage_controller.py b/api/app/controllers/file_storage_controller.py index 14962a72..4e1ba74c 100644 --- a/api/app/controllers/file_storage_controller.py +++ b/api/app/controllers/file_storage_controller.py @@ -574,8 +574,12 @@ async def get_file_url( # For local storage, generate signed URL with expiration url = generate_signed_url(str(file_id), expires) else: - # For remote storage (OSS/S3), get presigned URL - url = await storage_service.get_file_url(file_key, expires=expires) + # For remote storage (OSS/S3), get presigned URL with forced download + url = await storage_service.get_file_url( + file_key, + expires=expires, + file_name=file_metadata.file_name, + ) url = _match_scheme(request, url) api_logger.info(f"Generated file URL: file_id={file_id}") @@ -786,7 +790,7 @@ async def permanent_download_file( # For remote storage, redirect to presigned URL with long expiration try: # Use a very long expiration (7 days max for most cloud providers) - presigned_url = await storage_service.get_file_url(file_key, expires=604800) + presigned_url = await storage_service.get_file_url(file_key, expires=604800, file_name=file_metadata.file_name) presigned_url = _match_scheme(request, presigned_url) return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND) except Exception as e: diff --git a/api/app/controllers/mcp_market_config_controller.py b/api/app/controllers/mcp_market_config_controller.py index 0f2da3b0..6f27d87a 100644 --- a/api/app/controllers/mcp_market_config_controller.py +++ b/api/app/controllers/mcp_market_config_controller.py @@ -91,9 +91,11 @@ async def get_mcp_servers( try: cookies = api.get_cookies(token) + headers=api.builder_headers(api.headers) + headers['Authorization'] = f'Bearer {token}' r = api.session.put( url=api.mcp_base_url, - headers=api.builder_headers(api.headers), + headers=headers, json=body, cookies=cookies) raise_for_http_status(r) @@ -173,6 +175,7 @@ async def 
get_operational_mcp_servers( url = f'{api.mcp_base_url}/operational' headers = api.builder_headers(api.headers) + headers['Authorization'] = f'Bearer {token}' try: cookies = api.get_cookies(access_token=token, cookies_required=True) @@ -260,7 +263,9 @@ async def create_mcp_market_config( api.login(create_data.token) body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None} cookies = api.get_cookies(create_data.token) - r = api.session.put(url=api.mcp_base_url, headers=api.builder_headers(api.headers), json=body, cookies=cookies) + headers = api.builder_headers(api.headers) + headers['Authorization'] = f'Bearer {create_data.token}' + r = api.session.put(url=api.mcp_base_url, headers=headers, json=body, cookies=cookies) raise_for_http_status(r) except Exception as e: api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}") @@ -290,9 +295,11 @@ async def create_mcp_market_config( 'search': "" } cookies = api.get_cookies(token) + headers = api.builder_headers(api.headers) + headers['Authorization'] = f'Bearer {token}' r = api.session.put( url=api.mcp_base_url, - headers=api.builder_headers(api.headers), + headers=headers, json=body, cookies=cookies) raise_for_http_status(r) @@ -393,7 +400,9 @@ async def update_mcp_market_config( api.login(update_data.token) body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None} cookies = api.get_cookies(update_data.token) - r = api.session.put(url=api.mcp_base_url, headers=api.builder_headers(api.headers), json=body, cookies=cookies) + headers = api.builder_headers(api.headers) + headers['Authorization'] = f'Bearer {update_data.token}' + r = api.session.put(url=api.mcp_base_url, headers=headers, json=body, cookies=cookies) raise_for_http_status(r) except Exception as e: api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}") diff --git a/api/app/controllers/public_share_controller.py b/api/app/controllers/public_share_controller.py index 
33d7b60c..f5284b46 100644 --- a/api/app/controllers/public_share_controller.py +++ b/api/app/controllers/public_share_controller.py @@ -669,6 +669,7 @@ async def config_query( content = { "app_type": release.app.type, "variables": release.config.get("variables"), + "memory": release.config.get("memory", {}).get("enabled"), "features": release.config.get("features") } elif release.app.type == AppType.MULTI_AGENT: diff --git a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py index 6176caf5..2074b6ca 100644 --- a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py +++ b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py @@ -178,7 +178,7 @@ async def window_dialogue(end_user_id, langchain_messages, memory_config, scope) count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages) elif int(is_end_user_id) == int(scope): logger.info('写入长期记忆NEO4J') - formatted_messages = (redis_messages) + formatted_messages = redis_messages # Get config_id (if memory_config is an object, extract config_id; otherwise use directly) if hasattr(memory_config, 'config_id'): config_id = memory_config.config_id diff --git a/api/app/core/memory/llm_tools/openai_embedder.py b/api/app/core/memory/llm_tools/openai_embedder.py index 2d6fccbc..6ae87887 100644 --- a/api/app/core/memory/llm_tools/openai_embedder.py +++ b/api/app/core/memory/llm_tools/openai_embedder.py @@ -2,6 +2,7 @@ OpenAI Embedder 客户端实现 基于 LangChain 和 RedBearEmbeddings 的 OpenAI 嵌入模型客户端实现。 +自动支持火山引擎的多模态 Embedding。 """ from typing import List @@ -13,6 +14,7 @@ from app.core.memory.llm_tools.embedder_client import ( ) from app.core.models.base import RedBearModelConfig from app.core.models.embedding import RedBearEmbeddings +from app.models.models_model import ModelProvider logger = logging.getLogger(__name__) @@ -25,6 +27,7 @@ class OpenAIEmbedderClient(EmbedderClient): - 批量文本嵌入 - 自动重试机制 - 错误处理 + - 
火山引擎多模态 Embedding(自动识别) """ def __init__(self, model_config: RedBearModelConfig): @@ -36,7 +39,7 @@ class OpenAIEmbedderClient(EmbedderClient): """ super().__init__(model_config) - # 初始化 RedBearEmbeddings 模型 + # 初始化 RedBearEmbeddings(自动支持火山引擎多模态) self.model = RedBearEmbeddings( RedBearModelConfig( model_name=self.model_name, @@ -47,8 +50,9 @@ class OpenAIEmbedderClient(EmbedderClient): timeout=self.timeout, ) ) + self.is_multimodal = self.model.is_multimodal_supported() - logger.info("OpenAI Embedder 客户端初始化完成") + logger.info(f"OpenAI Embedder 客户端初始化完成 (provider={self.provider}, multimodal={self.is_multimodal})") async def response( self, @@ -77,7 +81,14 @@ class OpenAIEmbedderClient(EmbedderClient): return [] # 生成嵌入向量 - embeddings = await self.model.aembed_documents(texts) + if self.is_multimodal: + # 火山引擎多模态 Embedding + embeddings = await self.model.aembed_multimodal( + [{"type": "text", "text": text} for text in texts] + ) + else: + # 普通 Embedding + embeddings = await self.model.aembed_documents(texts) logger.debug(f"成功生成 {len(embeddings)} 个嵌入向量") return embeddings diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index da10c497..e0b86d8c 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -1099,7 +1099,6 @@ class ExtractionOrchestrator: metadata=chunk.metadata, ) chunk_nodes.append(chunk_node) - logger.error(f"chunk file: {chunk.files}") for p, file_type in chunk.files: diff --git a/api/app/core/models/__init__.py b/api/app/core/models/__init__.py index f54afc08..f98d073f 100644 --- a/api/app/core/models/__init__.py +++ b/api/app/core/models/__init__.py @@ -2,6 +2,7 @@ from .base import RedBearModelConfig, get_provider_llm_class, RedBearModelFacto from .llm import RedBearLLM from .embedding import 
RedBearEmbeddings from .rerank import RedBearRerank +from .generation import RedBearImageGenerator, RedBearVideoGenerator __all__ = [ "RedBearModelConfig", @@ -9,5 +10,7 @@ __all__ = [ "RedBearEmbeddings", "RedBearRerank", "RedBearModelFactory", - "get_provider_llm_class" + "get_provider_llm_class", + "RedBearImageGenerator", + "RedBearVideoGenerator" ] \ No newline at end of file diff --git a/api/app/core/models/base.py b/api/app/core/models/base.py index 4a453c6b..80117f27 100644 --- a/api/app/core/models/base.py +++ b/api/app/core/models/base.py @@ -67,7 +67,7 @@ class RedBearModelFactory: **config.extra_params } - if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA]: + if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA, ModelProvider.VOLCANO]: # 使用 httpx.Timeout 对象来设置详细的超时配置 # 这样可以分别控制连接超时和读取超时 import httpx @@ -160,11 +160,13 @@ def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelTy # dashscope 的 omni 模型使用 OpenAI 兼容模式 if provider == ModelProvider.DASHSCOPE and config.is_omni: return ChatOpenAI - if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]: + if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.VOLCANO]: if type == ModelType.LLM: return OpenAI elif type == ModelType.CHAT: return ChatOpenAI + else: + raise BusinessException(f"不支持的模型提供商及类型: {provider}-{type}", code=BizCode.PROVIDER_NOT_SUPPORTED) elif provider == ModelProvider.DASHSCOPE: return ChatTongyi elif provider == ModelProvider.OLLAMA: diff --git a/api/app/core/models/embedding.py b/api/app/core/models/embedding.py index 16af2567..3269e1d0 100644 --- a/api/app/core/models/embedding.py +++ b/api/app/core/models/embedding.py @@ -1,23 +1,190 @@ -from typing import Any, Dict, List, Optional, TypeVar, Callable +from typing import Any, Dict, List, Optional, Union from 
class RedBearEmbeddings(Embeddings):
    """Unified embedding wrapper that switches backend on the provider.

    When ``config.provider`` (lower-cased) equals ``ModelProvider.VOLCANO`` an
    ``Ark`` client is created and every request goes through its
    ``multimodal_embeddings.create`` endpoint. For every other provider the
    LangChain embedding class returned by ``get_provider_embedding_class`` is
    instantiated and the standard calls are delegated to it.
    """

    def __init__(self, config: RedBearModelConfig):
        """Build either the Ark client or the LangChain model for ``config``."""
        self._config = config
        self._is_volcano = config.provider.lower() == ModelProvider.VOLCANO

        if self._is_volcano:
            # Volcano goes through the Ark SDK; no LangChain model is built.
            self._client = self._create_volcano_client(config)
            self._model = None
        else:
            # Every other provider is served by a LangChain embedding class.
            self._model = self._create_model(config)
            self._client = None

    def _create_model(self, config: RedBearModelConfig) -> Embeddings:
        """Instantiate the LangChain embedding class registered for the provider."""
        embedding_class = get_provider_embedding_class(config.provider)
        model_params = RedBearModelFactory.get_model_params(config)
        return embedding_class(**model_params)

    def _create_volcano_client(self, config: RedBearModelConfig):
        """Create the Ark client used on the Volcano path."""
        # Local import: the SDK module is only loaded when this method runs.
        from volcenginesdkarkruntime import Ark
        return Ark(api_key=config.api_key, base_url=config.base_url)

    def _embed_single_content(self, content: Dict[str, Any], **kwargs) -> List[float]:
        """Send exactly one content item to Ark and return its vector.

        The endpoint's response exposes a single ``data.embedding`` per
        request, so a request carrying several items cannot yield one vector
        per item. Issuing one request per item keeps the output aligned with
        the input.
        NOTE(review): this costs one HTTP round trip per item - confirm that
        is acceptable for large batches.
        """
        response = self._client.multimodal_embeddings.create(
            model=self._config.model_name,
            input=[content],
            **kwargs,
        )
        return response.data.embedding

    # ==================== LangChain standard interface ====================

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed a batch of texts, returning one vector per input text.

        Args:
            texts: Texts to embed.

        Returns:
            A list with the same length and order as ``texts``.
        """
        if not self._is_volcano:
            return self._model.embed_documents(texts)

        # Previously all texts were sent in one request and the single
        # resulting vector was wrapped in a one-element list, so N texts
        # produced 1 vector and callers pairing vectors with chunks broke.
        return [
            self._embed_single_content(
                {"type": "text", "text": text},
                encoding_format="float",
            )
            for text in texts
        ]

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query text."""
        if not self._is_volcano:
            return self._model.embed_query(text)

        vectors = self.embed_documents([text])
        return vectors[0] if vectors else []

    # ==================== Multimodal extensions ====================

    def embed_multimodal(
        self,
        contents: List[Dict[str, Any]],
        **kwargs
    ) -> List[List[float]]:
        """Embed multimodal content items (Volcano only), one vector per item.

        Args:
            contents: Content items, each shaped like one of:
                - text:  {"type": "text", "text": "..."}
                - image: {"type": "image_url", "image_url": {"url": "..."}}
                - video: {"type": "video_url", "video_url": {"url": "..."}}
            **kwargs: Extra keyword arguments forwarded to the SDK call.

        Returns:
            A list of vectors with the same length and order as ``contents``.

        Raises:
            NotImplementedError: If the configured provider is not Volcano.
        """
        if not self._is_volcano:
            raise NotImplementedError(
                f"多模态 Embedding 仅支持火山引擎,当前 provider: {self._config.provider}"
            )

        # One request per item so that callers (embed_batch, the embedder
        # client) receive as many vectors as they passed items.
        return [self._embed_single_content(item, **kwargs) for item in contents]

    async def aembed_multimodal(
        self,
        contents: List[Dict[str, Any]],
        **kwargs
    ) -> List[List[float]]:
        """Async variant of :meth:`embed_multimodal`.

        The client used here is synchronous, so the call is moved to a worker
        thread instead of blocking the running event loop.
        """
        import asyncio

        return await asyncio.to_thread(self.embed_multimodal, contents, **kwargs)

    def embed_text(self, text: str, **kwargs) -> List[float]:
        """Embed one text; ``kwargs`` are only used on the Volcano path."""
        if not self._is_volcano:
            return self.embed_query(text)

        vectors = self.embed_multimodal(
            [{"type": "text", "text": text}],
            **kwargs
        )
        return vectors[0] if vectors else []

    def embed_image(self, image_url: str, **kwargs) -> List[float]:
        """Embed one image given by URL (Volcano only).

        Raises:
            NotImplementedError: If the configured provider is not Volcano.
        """
        if not self._is_volcano:
            raise NotImplementedError(
                f"图片向量化仅支持火山引擎,当前 provider: {self._config.provider}"
            )

        vectors = self.embed_multimodal(
            [{"type": "image_url", "image_url": {"url": image_url}}],
            **kwargs
        )
        return vectors[0] if vectors else []

    def embed_video(self, video_url: str, **kwargs) -> List[float]:
        """Embed one video given by URL (Volcano only).

        Raises:
            NotImplementedError: If the configured provider is not Volcano.
        """
        if not self._is_volcano:
            raise NotImplementedError(
                f"视频向量化仅支持火山引擎,当前 provider: {self._config.provider}"
            )

        vectors = self.embed_multimodal(
            [{"type": "video_url", "video_url": {"url": video_url}}],
            **kwargs
        )
        return vectors[0] if vectors else []

    def embed_batch(
        self,
        items: List[Union[str, Dict[str, Any]]],
        **kwargs
    ) -> List[List[float]]:
        """Embed a batch that may mix plain strings and content dicts.

        Args:
            items: Strings and/or content dicts (see :meth:`embed_multimodal`).
            **kwargs: Forwarded to :meth:`embed_multimodal` when the batch
                contains at least one dict; ignored for all-string batches.

        Returns:
            One vector per item, in input order.

        Raises:
            NotImplementedError: If a dict is present and the provider is not
                Volcano.
            ValueError: If an item is neither ``str`` nor ``dict``.
        """
        # Pure-text batches go through the standard LangChain entry point.
        if all(isinstance(item, str) for item in items):
            return self.embed_documents(items)

        if not self._is_volcano:
            raise NotImplementedError(
                f"混合类型批量向量化仅支持火山引擎,当前 provider: {self._config.provider}"
            )

        # Normalise every item to the content-dict shape the endpoint expects.
        contents = []
        for item in items:
            if isinstance(item, str):
                contents.append({"type": "text", "text": item})
            elif isinstance(item, dict):
                contents.append(item)
            else:
                raise ValueError(f"不支持的输入类型: {type(item)}")

        return self.embed_multimodal(contents, **kwargs)

    # ==================== Utilities ====================

    def is_multimodal_supported(self) -> bool:
        """Return True when the Volcano (multimodal) backend is in use."""
        return self._is_volcano

    def get_provider(self) -> str:
        """Return the provider string from the configuration."""
        return self._config.provider


# Alias kept for backward compatibility with the former class name.
RedBearMultimodalEmbeddings = RedBearEmbeddings
def _dump_response(response: Any) -> Any:
    """Return ``response.model_dump()`` when available, else the object itself."""
    return response.model_dump() if hasattr(response, 'model_dump') else response


class RedBearImageGenerator:
    """Image-generation wrapper; only the Volcano (Ark) provider is handled."""

    def __init__(self, config: RedBearModelConfig):
        self._config = config
        self._client = self._create_client(config)

    def _create_client(self, config: RedBearModelConfig):
        """Create the SDK client for the configured provider.

        Raises:
            BusinessException: If the provider is not Volcano.
        """
        provider = config.provider.lower()

        if provider == ModelProvider.VOLCANO:
            return Ark(api_key=config.api_key, base_url=config.base_url)

        raise BusinessException(
            f"不支持的图片生成提供商: {provider}",
            code=BizCode.PROVIDER_NOT_SUPPORTED
        )

    def generate(
        self,
        prompt: str,
        image: Optional[Any] = None,
        size: Optional[str] = "2K",
        output_format: str = "png",
        response_format: str = "url",
        watermark: bool = False,
        sequential_image_generation: Optional[str] = None,
        sequential_image_generation_options: Optional[Dict] = None,
        tools: Optional[list] = None,
        optimize_prompt_options: Optional[Dict] = None,
        stream: bool = False,
        **kwargs
    ) -> Dict[str, Any]:
        """Generate one or more images.

        Args:
            prompt: Text prompt.
            image: Reference image URL or list of URLs (image-to-image / fusion).
            size: Image size such as "2K", "2048x2048" or "1920x1080".
            output_format: Output format, e.g. "png" or "jpg".
            response_format: "url" or "b64_json".
            watermark: Whether to add a watermark.
            sequential_image_generation: Image-set mode, "auto" or "disabled".
            sequential_image_generation_options: Options for image-set mode,
                e.g. {"max_images": 4}; only used when the mode is set.
            tools: Tool list, e.g. [{"type": "web_search"}]; dict entries are
                converted to ``ContentGenerationTool``.
            optimize_prompt_options: Prompt optimisation options,
                e.g. {"mode": "fast"}.
            stream: Whether to request streaming generation.
            **kwargs: Extra parameters; applied last, so they override the above.

        Returns:
            ``response.model_dump()`` when the SDK response provides it,
            otherwise the raw response object.
            NOTE(review): with ``stream=True`` the SDK presumably returns a
            stream object without ``model_dump`` - confirm the return type.

        Raises:
            BusinessException: If the provider is not Volcano.
        """
        provider = self._config.provider.lower()

        if provider != ModelProvider.VOLCANO:
            raise BusinessException(
                f"不支持的提供商: {provider}",
                code=BizCode.PROVIDER_NOT_SUPPORTED
            )

        params = {
            "model": self._config.model_name,
            "prompt": prompt,
            "size": size,
            "output_format": output_format,
            "response_format": response_format,
            "watermark": watermark,
        }

        if image is not None:
            params["image"] = image

        if sequential_image_generation:
            params["sequential_image_generation"] = sequential_image_generation
            # Options are only meaningful together with the mode flag.
            if sequential_image_generation_options:
                params["sequential_image_generation_options"] = SequentialImageGenerationOptions(
                    **sequential_image_generation_options
                )

        if tools:
            params["tools"] = [
                ContentGenerationTool(**tool) if isinstance(tool, dict) else tool
                for tool in tools
            ]

        if optimize_prompt_options:
            params["optimize_prompt_options"] = OptimizePromptOptions(**optimize_prompt_options)

        if stream:
            params["stream"] = True

        params.update(kwargs)
        return _dump_response(self._client.images.generate(**params))

    async def agenerate(
        self,
        prompt: str,
        image: Optional[Any] = None,
        size: Optional[str] = "2K",
        output_format: str = "png",
        response_format: str = "url",
        watermark: bool = False,
        **kwargs
    ) -> Dict[str, Any]:
        """Async variant of :meth:`generate`.

        The client call is synchronous, so it is run on a worker thread
        instead of blocking the event loop.
        """
        import asyncio

        return await asyncio.to_thread(
            self.generate,
            prompt, image, size, output_format, response_format, watermark,
            **kwargs
        )


class RedBearVideoGenerator:
    """Video-generation wrapper; only the Volcano (Ark) provider is handled."""

    def __init__(self, config: RedBearModelConfig):
        self._config = config
        self._client = self._create_client(config)

    def _create_client(self, config: RedBearModelConfig):
        """Create the SDK client for the configured provider.

        Raises:
            BusinessException: If the provider is not Volcano.
        """
        provider = config.provider.lower()

        if provider == ModelProvider.VOLCANO:
            return Ark(api_key=config.api_key, base_url=config.base_url)

        raise BusinessException(
            f"不支持的视频生成提供商: {provider}",
            code=BizCode.PROVIDER_NOT_SUPPORTED
        )

    def _require_volcano(self) -> None:
        """Raise ``BusinessException`` unless the configured provider is Volcano."""
        provider = self._config.provider.lower()
        if provider != ModelProvider.VOLCANO:
            raise BusinessException(
                f"不支持的提供商: {provider}",
                code=BizCode.PROVIDER_NOT_SUPPORTED
            )

    def generate(
        self,
        prompt: str,
        image_url: Optional[str] = None,
        first_frame_url: Optional[str] = None,
        last_frame_url: Optional[str] = None,
        reference_images: Optional[list] = None,
        draft_task_id: Optional[str] = None,
        duration: Optional[int] = None,
        frames: Optional[int] = None,
        ratio: Optional[str] = None,
        resolution: Optional[str] = None,
        generate_audio: bool = False,
        watermark: bool = False,
        camera_fixed: bool = False,
        seed: Optional[int] = None,
        return_last_frame: bool = False,
        service_tier: str = "default",
        execution_expires_after: Optional[int] = None,
        draft: bool = False,
        **kwargs
    ) -> Dict[str, Any]:
        """Create a video-generation task.

        Args:
            prompt: Text prompt (not sent when ``draft_task_id`` is given).
            image_url: First-frame image URL (first-frame mode).
            first_frame_url: First-frame image URL (first/last-frame mode).
            last_frame_url: Last-frame image URL (first/last-frame mode).
            reference_images: Reference image URLs (reference-image mode).
            draft_task_id: Draft task id to turn into a final video; when set,
                the content consists of that draft reference only.
            duration: Length in seconds (alternative to ``frames``).
            frames: Number of frames (alternative to ``duration``).
            ratio: Aspect ratio such as "16:9", "9:16" or "adaptive".
            resolution: Resolution such as "720p" or "1080p".
            generate_audio: Whether to generate audio.
            watermark: Whether to add a watermark (always sent).
            camera_fixed: Whether to keep the camera fixed.
            seed: Random seed; sent whenever it is not ``None``.
            return_last_frame: Whether to return the last frame.
            service_tier: "default" or "flex"; only sent when not "default".
            execution_expires_after: Task expiry in seconds.
            draft: Whether to produce a draft sample.
            **kwargs: Extra parameters; applied last, so they override the above.

        Returns:
            ``response.model_dump()`` when available, otherwise the raw
            response; per the original documentation it carries a task id to
            be polled for the result.

        Raises:
            BusinessException: If the provider is not Volcano.
        """
        self._require_volcano()

        if draft_task_id:
            # A draft reference replaces the prompt and all image inputs.
            content = [{"type": "draft_task", "draft_task": {"id": draft_task_id}}]
        else:
            content = [{"type": "text", "text": prompt}]

            if image_url:
                content.append({"type": "image_url", "image_url": {"url": image_url}})

            if first_frame_url:
                content.append({"type": "image_url", "image_url": {"url": first_frame_url}, "role": "first_frame"})
            if last_frame_url:
                content.append({"type": "image_url", "image_url": {"url": last_frame_url}, "role": "last_frame"})

            if reference_images:
                content.extend(
                    {"type": "image_url", "image_url": {"url": ref_url}, "role": "reference_image"}
                    for ref_url in reference_images
                )

        params = {"model": self._config.model_name, "content": content, "watermark": watermark}

        # These options are forwarded only when truthy, leaving the rest to
        # the service defaults.
        truthy_only = {
            "duration": duration,
            "frames": frames,
            "ratio": ratio,
            "resolution": resolution,
            "generate_audio": generate_audio,
            "camera_fixed": camera_fixed,
            "return_last_frame": return_last_frame,
            "execution_expires_after": execution_expires_after,
            "draft": draft,
        }
        params.update({key: value for key, value in truthy_only.items() if value})

        # 0 is a legitimate seed, so test against None rather than truthiness.
        if seed is not None:
            params["seed"] = seed
        if service_tier != "default":
            params["service_tier"] = service_tier

        params.update(kwargs)
        return _dump_response(self._client.content_generation.tasks.create(**params))

    async def agenerate(
        self,
        prompt: str,
        image_url: Optional[str] = None,
        duration: Optional[int] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """Async variant of :meth:`generate`, run on a worker thread."""
        import asyncio

        return await asyncio.to_thread(
            self.generate, prompt, image_url=image_url, duration=duration, **kwargs
        )

    def get_task_status(self, task_id: str) -> Dict[str, Any]:
        """Fetch the status of a video-generation task.

        Args:
            task_id: Task id.

        Returns:
            Task status information.

        Raises:
            BusinessException: If the provider is not Volcano.
        """
        self._require_volcano()
        return _dump_response(self._client.content_generation.tasks.get(task_id=task_id))

    async def aget_task_status(self, task_id: str) -> Dict[str, Any]:
        """Async variant of :meth:`get_task_status`, run on a worker thread."""
        import asyncio

        return await asyncio.to_thread(self.get_task_status, task_id)

    def list_tasks(self, page_size: int = 10, status: Optional[str] = None, **kwargs) -> Dict[str, Any]:
        """List video-generation tasks.

        Args:
            page_size: Page size.
            status: Optional status filter such as "succeeded", "failed" or
                "pending".
            **kwargs: Extra parameters; applied last.

        Returns:
            The task list.

        Raises:
            BusinessException: If the provider is not Volcano.
        """
        self._require_volcano()

        params = {"page_size": page_size}
        if status:
            params["status"] = status
        params.update(kwargs)
        return _dump_response(self._client.content_generation.tasks.list(**params))

    def delete_task(self, task_id: str) -> None:
        """Delete or cancel a video-generation task.

        Args:
            task_id: Task id.

        Raises:
            BusinessException: If the provider is not Volcano.
        """
        self._require_volcano()
        self._client.content_generation.tasks.delete(task_id=task_id)
doubao-seed-2-0-code-preview-260215 + type: chat + provider: volcano + description: 面向真实编程环境优化的 Coding 模型,能稳定调用 Claude Code 等常见 IDE 中的工具。模型特别优化了前端能力,在使用常见的前端框架时能有良好表现。模型支持使用 Skills,可以配合多种自定义技能使用。Seed 2.0 的编程加强版,更适合 Agentic Coding。 + is_deprecated: false + is_official: true + capability: + - vision + - video + is_omni: false + tags: + - 大语言模型 + - 代码模型 + logo: volcano + +# Doubao-Seed 1.x 系列 +- name: doubao-seed-1-8-251228 + type: chat + provider: volcano + description: Doubao-Seed-1.8 面向多模态 Agent 场景定向优化。Agent 能力上,Tool Use、复杂指令遵循等能力均大幅增强。多模态理解方面,视觉基础能力显著提升,可低帧率理解超长视频,视频运动理解、复杂空间理解及文档结构化解析能力也有所优化,还原生支持智能上下文管理,用户可配置上下文策略。 + is_deprecated: false + is_official: true + capability: + - vision + - video + is_omni: false + tags: + - 大语言模型 + logo: volcano + +- name: doubao-seed-1-6-251015 + type: chat + provider: volcano + description: Doubao-Seed-1.6全新多模态深度思考模型,同时支持minimal/low/medium/high 四种reasoning effort。 更强模型效果,服务复杂任务和有挑战场景。支持 256k 上下文窗口,输出长度支持最大 32k tokens。 + is_deprecated: false + is_official: true + capability: + - vision + - video + is_omni: false + tags: + - 大语言模型 + logo: volcano + +- name: doubao-seed-1-6-lite-251015 + type: chat + provider: volcano + description: 更高性价比,常见任务的最佳选择,支持minimal、low、medium、high 四种reasoning_effort思考深度 + is_deprecated: false + is_official: true + capability: + - vision + - video + is_omni: false + tags: + - 大语言模型 + logo: volcano + +- name: doubao-seed-1-6-flash-250828 + type: chat + provider: volcano + description: Doubao-Seed-1.6-flash推理速度极致的多模态深度思考模型,TPOT低至10ms; 同时支持文本和视觉理解,文本理解能力超过上一代lite,视觉理解比肩友商pro系列模型。支持 256k 上下文窗口,输出长度支持最大 16k tokens。 + is_deprecated: false + is_official: true + capability: + - vision + - video + is_omni: false + tags: + - 大语言模型 + logo: volcano + +- name: doubao-seed-code-preview-251028 + type: chat + provider: volcano + description: 面向Agentic编程任务进行了深度优化。 + is_deprecated: false + is_official: true + capability: + - vision + - video + is_omni: false + tags: + - 大语言模型 + - 代码模型 + logo: volcano + +- name: 
doubao-seed-1-6-vision-250815 + type: chat + provider: volcano + description: 全新Doubao-Seed-1.6系列视觉深度思考模型,视觉理解能力显著增强,并支持image_process视觉工具 + is_deprecated: false + is_official: true + capability: + - vision + - video + is_omni: false + tags: + - 大语言模型 + - 多模态模型 + logo: volcano + +# Doubao 1.5 系列 +- name: doubao-1-5-vision-pro-32k-250115 + type: chat + provider: volcano + description: 全新升级的多模态大模型,支持任意分辨率和极端长宽比图像识别,增强视觉推理、文档识别、细节信息理解和指令遵循能力。支持 32k 上下文窗口,输出长度支持最大 12k tokens。 + is_deprecated: false + is_official: true + capability: + - vision + is_omni: false + tags: + - 大语言模型 + - 多模态模型 + logo: volcano + +- name: doubao-1-5-pro-32k-250115 + type: chat + provider: volcano + description: 全新一代主力模型,性能全面升级,在知识、代码、推理等方面表现卓越。最大支持 128k 上下文窗口,输出长度支持最大 12k tokens。 + is_deprecated: false + is_official: true + capability: [] + is_omni: false + tags: + - 大语言模型 + logo: volcano + +- name: doubao-1-5-lite-32k-250115 + type: chat + provider: volcano + description: 全新一代轻量版模型,极致响应速度,效果与时延均达到全球一流水平。支持 32k 上下文窗口,输出长度支持最大 12k tokens。 + is_deprecated: false + is_official: true + capability: [] + is_omni: false + tags: + - 大语言模型 + logo: volcano + +# Doubao-Seedance 视频生成系列 +- name: doubao-seedance-1-5-pro-251215 + type: video + provider: volcano + description: 豆包视频生成模型Seedance 1.5 pro 作为全球领先的视频生成模型,可生成音画高精同步的视频内容。支持多人多语言对白,全面覆盖环境音、动作音、合成音、乐器音、背景音及人声,支持首尾帧,实现影视级叙事效果,满足影视、漫剧、电商及广告领域的高阶创作需求。 + is_deprecated: false + is_official: true + capability: + - vision + is_omni: false + tags: + - 视频生成 + logo: volcano + +- name: doubao-seedance-1-0-pro-250528 + type: video + provider: volcano + description: 一款支持多镜头叙事的视频生成基础模型,在各维度表现出色。它在语义理解与指令遵循能力上取得突破,能生成运动流畅、细节丰富、风格多样且具备影视级美感的 1080P 高清视频 + is_deprecated: false + is_official: true + capability: + - vision + is_omni: false + tags: + - 视频生成 + logo: volcano + +- name: doubao-seedance-1-0-pro-fast-251015 + type: video + provider: volcano + description: 一款价格触底、效能封顶的全面模型,在视频生成质量、速度、价格之间取得了卓越平衡。它继承了Seedance 1.0 pro 核心优势,同时生成速度提升、价格更具竞争力,为创作者带来效率与成本双重优化的体验。 + 
is_deprecated: false + is_official: true + capability: + - vision + is_omni: false + tags: + - 视频生成 + logo: volcano + +- name: doubao-seedance-1-0-lite-i2v-250428 + type: video + provider: volcano + description: 基于首帧图片、尾帧图片(可选)、参考图片(可选)和文本提示词(可选)相结合的方式生成视频 + is_deprecated: false + is_official: true + capability: + - vision + is_omni: false + tags: + - 视频生成 + - 图生视频 + logo: volcano + +- name: doubao-seedance-1-0-lite-t2v-250428 + type: video + provider: volcano + description: 基于文本提示词生成视频 + is_deprecated: false + is_official: true + capability: [] + is_omni: false + tags: + - 视频生成 + - 文生视频 + logo: volcano + +# Doubao-Seedream 图像生成系列 +- name: doubao-seedream-5-0-260128 + type: image + provider: volcano + description: 字节跳动发布的最新图像创作模型。该模型首次搭载联网检索功能,能融合实时网络信息,提升生图时效性。同时,模型的聪明度进一步升级,能够精准解析复杂指令和视觉内容。此外,模型在世界知识广度、参考一致性及专业场景生成质量上均有增强,可更好地满足企业级视觉创作需求。 + is_deprecated: false + is_official: true + capability: + - vision + is_omni: false + tags: + - 图像生成 + logo: volcano + +- name: doubao-seedream-4-5-251128 + type: image + provider: volcano + description: 字节跳动最新推出的图像多模态模型,整合了文生图、图生图、组图输出等能力,融合常识和推理能力。相比前代4.0模型生成效果大幅提升,具备更好的编辑一致性和多图融合效果,能更精准的控制画面细节,小字、小人脸生成更自然,图片排版、色彩更和谐,美感提升。 + is_deprecated: false + is_official: true + capability: + - vision + is_omni: false + tags: + - 图像生成 + logo: volcano + +- name: doubao-seedream-4-0-250828 + type: image + provider: volcano + description: 基于领先架构的SOTA级多模态图像创作模型,其生成美感、指令遵循、结构完整度、主体保持一致性处于世界头部水平。模型采用同一套架构实现文生图与编辑能力的统一,原生支持文本 、单图和多图输入,并能通过对提示词的深度推理,自动适配最优的图像比例尺寸与生成数量,可一次性连续输出最多 15 张内容关联的图像,支持 4K 超高清输出。 + is_deprecated: false + is_official: true + capability: + - vision + is_omni: false + tags: + - 图像生成 + logo: volcano + +- name: doubao-seedream-3-0-t2i-250415 + type: image + provider: volcano + description: 一款支持原生高分辨率的中英双语图像生成基础模型,综合能力媲美GPT-4o,处于世界第一梯队。支持原生 2K 分辨率输出;响应速度更快;小字生成更准确,文本排版效果增强;指令遵循能力强,美感&结构提升,保真度和细节表现较好。 + is_deprecated: false + is_official: true + capability: [] + is_omni: false + tags: + - 图像生成 + - 文生图 + logo: volcano + +# Doubao 
翻译系列 +- name: doubao-seed-translation-250915 + type: chat + provider: volcano + description: 通用多语言翻译模型,支持30余种语言互译,支持 4K 上下文窗口,输出长度支持最大 3K tokens + is_deprecated: false + is_official: true + capability: [] + is_omni: false + tags: + - 翻译模型 + logo: volcano + +# Doubao Embedding 系列 +- name: doubao-embedding-vision-251215 + type: embedding + provider: volcano + description: 主要面向图文多模向量检索的使用场景,支持图片输入及中、英双语文本输入,最长 128K 上下文长度。 + is_deprecated: false + is_official: true + capability: + - vision + - video + is_omni: false + tags: + - 向量模型 + - 多模态模型 + logo: volcano diff --git a/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py b/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py index 198d1473..386920e0 100644 --- a/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py @@ -61,24 +61,16 @@ class ElasticSearchConfig(BaseModel): class ElasticSearchVector(BaseVector): def __init__(self, index_name: str, config: ElasticSearchConfig, embedding_config: ModelApiKey, reranker_config: ModelApiKey): super().__init__(index_name.lower()) - # self.embeddings = XinferenceEmbeddings( - # server_url=os.getenv("XINFERENCE_URL", "http://127.0.0.1"), # Default Xinference port - # model_uid="bge-m3" # replace model_uid with the model UID return from launching the model - # ) - # Remove debug printing to avoid leaking sensitive information - # print("embedding:" + embedding_config.model_name + "|" + embedding_config.provider + "|" + embedding_config.api_key + "|" + embedding_config.api_base) + + # 初始化 Embedding 模型(自动支持火山引擎多模态) self.embeddings = RedBearEmbeddings(RedBearModelConfig( model_name=embedding_config.model_name, provider=embedding_config.provider, api_key=embedding_config.api_key, base_url=embedding_config.api_base )) - # self.reranker = XinferenceRerank( - # server_url=os.getenv("XINFERENCE_URL", "http://127.0.0.1"), - # model_uid="bge-reranker-large" - # ) - # Remove debug printing to avoid leaking 
sensitive information - # print("reranker:"+ reranker_config.model_name + "|" + reranker_config.provider + "|" + reranker_config.api_key + "|" + reranker_config.api_base) + self.is_multimodal_embedding = self.embeddings.is_multimodal_supported() + self.reranker = RedBearRerank(RedBearModelConfig( model_name=reranker_config.model_name, provider=reranker_config.provider, @@ -144,7 +136,11 @@ class ElasticSearchVector(BaseVector): def add_chunks(self, chunks: list[DocumentChunk], **kwargs): # 实现 Elasticsearch 保存向量 texts = [chunk.page_content for chunk in chunks] - embeddings = self.embeddings.embed_documents(list(texts)) + if self.is_multimodal_embedding: + # 火山引擎多模态 Embedding + embeddings = self.embeddings.embed_batch(texts) + else: + embeddings = self.embeddings.embed_documents(list(texts)) self.create(chunks, embeddings, **kwargs) def create(self, chunks: list[DocumentChunk], embeddings: list[list[float]], **kwargs): @@ -394,7 +390,11 @@ class ElasticSearchVector(BaseVector): updated count. 
""" indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available,etc "index1,index2,index3" - chunk.vector = self.embeddings.embed_query(chunk.page_content) + if self.is_multimodal_embedding: + # 火山引擎多模态 Embedding + chunk.vector = self.embeddings.embed_text(chunk.page_content) + else: + chunk.vector = self.embeddings.embed_query(chunk.page_content) body = { "script": { @@ -454,7 +454,11 @@ class ElasticSearchVector(BaseVector): def search_by_vector(self, query: str, **kwargs: Any) -> list[DocumentChunk]: """Search the nearest neighbors to a vector.""" - query_vector = self.embeddings.embed_query(query) + if self.is_multimodal_embedding: + # 火山引擎多模态 Embedding + query_vector = self.embeddings.embed_text(query) + else: + query_vector = self.embeddings.embed_query(query) top_k = kwargs.get("top_k", 1024) score_threshold = float(kwargs.get("score_threshold") or 0.3) indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available,etc "index1,index2,index3" diff --git a/api/app/core/storage/base.py b/api/app/core/storage/base.py index 8ab0fcde..09824c3f 100644 --- a/api/app/core/storage/base.py +++ b/api/app/core/storage/base.py @@ -109,17 +109,13 @@ class StorageBackend(ABC): pass @abstractmethod - async def get_url(self, file_key: str, expires: int = 3600) -> str: - """ - Get an access URL for the file. - - Args: - file_key: Unique identifier for the file in the storage system. - expires: URL validity period in seconds (default: 1 hour). - - Returns: - URL for accessing the file. 
- """ + async def get_url( + self, + file_key: str, + expires: int = 3600, + file_name: Optional[str] = None + ) -> str: + """Get an access URL for the file.""" pass async def get_permanent_url(self, file_key: str) -> Optional[str]: diff --git a/api/app/core/storage/local.py b/api/app/core/storage/local.py index 4b8ae829..13adfc20 100644 --- a/api/app/core/storage/local.py +++ b/api/app/core/storage/local.py @@ -210,7 +210,12 @@ class LocalStorage(StorageBackend): cause=e, ) - async def get_url(self, file_key: str, expires: int = 3600) -> str: + async def get_url( + self, + file_key: str, + expires: int = 3600, + file_name: Optional[str] = None + ) -> str: """ Get an access URL for the file. @@ -220,6 +225,7 @@ class LocalStorage(StorageBackend): Args: file_key: Unique identifier for the file in the storage system. expires: URL validity period in seconds (not used for local storage). + file_name: If set, adds Content-Disposition: attachment to force download. Returns: A relative URL path for accessing the file. diff --git a/api/app/core/storage/oss.py b/api/app/core/storage/oss.py index 27669ffa..1db86fef 100644 --- a/api/app/core/storage/oss.py +++ b/api/app/core/storage/oss.py @@ -7,6 +7,7 @@ Storage Service (OSS) using the oss2 SDK. import io import logging +import urllib.parse from typing import AsyncIterator, Optional import oss2 @@ -242,24 +243,33 @@ class OSSStorage(StorageBackend): logger.error(f"Failed to check file existence in OSS {file_key}: {e}") return False - async def get_url(self, file_key: str, expires: int = 3600) -> str: + async def get_url( + self, + file_key: str, + expires: int = 3600, + file_name: Optional[str] = None, + ) -> str: """ Get a presigned URL for accessing the file. Args: file_key: Unique identifier for the file in the storage system. expires: URL validity period in seconds (default: 1 hour). + file_name: If set, adds Content-Disposition: attachment to force download. Returns: A presigned URL for accessing the file. 
""" try: - url = self.bucket.sign_url("GET", file_key, expires) + params = {} + if file_name: + filename_encoded = urllib.parse.quote(file_name.encode("utf-8")) + params["response-content-disposition"] = f"attachment; filename*=UTF-8''{filename_encoded}" + url = self.bucket.sign_url("GET", file_key, expires, params=params if params else None) logger.debug(f"Generated presigned URL for {file_key}, expires in {expires}s") return url except Exception as e: logger.error(f"Failed to generate presigned URL for {file_key}: {e}") - # Return a basic URL format as fallback return f"https://{self.bucket_name}.{self.endpoint.replace('https://', '').replace('http://', '')}/{file_key}" async def get_permanent_url(self, file_key: str) -> str: diff --git a/api/app/core/storage/s3.py b/api/app/core/storage/s3.py index c7b33ffe..f156f4a7 100644 --- a/api/app/core/storage/s3.py +++ b/api/app/core/storage/s3.py @@ -6,6 +6,7 @@ using the boto3 SDK. """ import io +import urllib.parse import logging from typing import AsyncIterator, Optional @@ -352,31 +353,37 @@ class S3Storage(StorageBackend): logger.error(f"Failed to check file existence in S3 {file_key}: {e}") return False - async def get_url(self, file_key: str, expires: int = 3600) -> str: + async def get_url( + self, + file_key: str, + expires: int = 3600, + file_name: Optional[str] = None, + ) -> str: """ Get a presigned URL for accessing the file. Args: file_key: Unique identifier for the file in the storage system. expires: URL validity period in seconds (default: 1 hour). + file_name: If set, adds Content-Disposition: attachment to force download. Returns: A presigned URL for accessing the file. 
""" try: + params = {"Bucket": self.bucket_name, "Key": file_key} + if file_name: + filename_encoded = urllib.parse.quote(file_name.encode("utf-8")) + params["ResponseContentDisposition"] = f"attachment; filename*=UTF-8''{filename_encoded}" url = self.client.generate_presigned_url( "get_object", - Params={ - "Bucket": self.bucket_name, - "Key": file_key, - }, + Params=params, ExpiresIn=expires, ) logger.debug(f"Generated presigned URL for {file_key}, expires in {expires}s") return url except Exception as e: logger.error(f"Failed to generate presigned URL for {file_key}: {e}") - # Return a basic URL format as fallback return f"https://{self.bucket_name}.s3.{self.region}.amazonaws.com/{file_key}" async def get_permanent_url(self, file_key: str) -> str: diff --git a/api/app/core/workflow/adapters/base_adapter.py b/api/app/core/workflow/adapters/base_adapter.py index 49321b89..2e24d085 100644 --- a/api/app/core/workflow/adapters/base_adapter.py +++ b/api/app/core/workflow/adapters/base_adapter.py @@ -9,7 +9,7 @@ from typing import Any from pydantic import BaseModel, Field -from app.core.workflow.adapters.errors import ExceptionDefineition +from app.core.workflow.adapters.errors import ExceptionDefinition from app.schemas.workflow_schema import ( EdgeDefinition, NodeDefinition, @@ -40,8 +40,8 @@ class WorkflowParserResult(BaseModel): edges: list[EdgeDefinition] = Field(default_factory=list) nodes: list[NodeDefinition] = Field(default_factory=list) variables: list[VariableDefinition] = Field(default_factory=list) - warnings: list[ExceptionDefineition] = Field(default_factory=list) - errors: list[ExceptionDefineition] = Field(default_factory=list) + warnings: list[ExceptionDefinition] = Field(default_factory=list) + errors: list[ExceptionDefinition] = Field(default_factory=list) class WorkflowImportResult(BaseModel): @@ -51,8 +51,8 @@ class WorkflowImportResult(BaseModel): edges: list[EdgeDefinition] = Field(default_factory=list) nodes: list[NodeDefinition] = 
Field(default_factory=list) variables: list[VariableDefinition] = Field(default_factory=list) - warnings: list[ExceptionDefineition] = Field(default_factory=list) - errors: list[ExceptionDefineition] = Field(default_factory=list) + warnings: list[ExceptionDefinition] = Field(default_factory=list) + errors: list[ExceptionDefinition] = Field(default_factory=list) class BasePlatformAdapter(ABC): diff --git a/api/app/core/workflow/adapters/dify/converter.py b/api/app/core/workflow/adapters/dify/converter.py index 467beb07..4fa9508b 100644 --- a/api/app/core/workflow/adapters/dify/converter.py +++ b/api/app/core/workflow/adapters/dify/converter.py @@ -9,9 +9,9 @@ from urllib.parse import quote from app.core.workflow.adapters.base_converter import BaseConverter from app.core.workflow.adapters.errors import ( - UnsupportVariableType, - UnknowModelWarning, - ExceptionDefineition, + UnsupportedVariableType, + UnknownModelWarning, + ExceptionDefinition, ExceptionType ) from app.core.workflow.nodes.assigner.config import AssignmentItem @@ -54,7 +54,7 @@ from app.core.workflow.nodes.http_request.config import ( HttpFormData, HttpTimeOutConfig, HttpRetryConfig, - HttpErrorDefaultTamplete, + HttpErrorDefaultTemplate, HttpErrorHandleConfig ) from app.core.workflow.nodes.if_else.config import ConditionDetail, ConditionBranchConfig @@ -108,7 +108,7 @@ class DifyConverter(BaseConverter): try: return config.model_validate(value) except Exception as e: - self.errors.append(ExceptionDefineition( + self.errors.append(ExceptionDefinition( type=ExceptionType.CONFIG, node_id=node_id, node_name=node_name, @@ -138,7 +138,7 @@ class DifyConverter(BaseConverter): var_selector = mapping.get(var_selector, var_selector) return var_selector - def _process_list_variable_litearl(self, variable_selector: list) -> str | None: + def _process_list_variable_literal(self, variable_selector: list) -> str | None: if not self.process_var_selector(".".join(variable_selector)): return None return "{{" + 
self.process_var_selector(".".join(variable_selector)) + "}}" @@ -269,7 +269,7 @@ class DifyConverter(BaseConverter): var_type = self.variable_type_map(var["type"]) if not var_type: self.errors.append( - UnsupportVariableType( + UnsupportedVariableType( scope=node["id"], name=var["variable"], var_type=var["type"], @@ -281,7 +281,7 @@ class DifyConverter(BaseConverter): if var_type in ["file", "array[file]"]: self.errors.append( - ExceptionDefineition( + ExceptionDefinition( type=ExceptionType.VARIABLE, node_id=node["id"], node_name=node_data["title"], @@ -311,7 +311,7 @@ class DifyConverter(BaseConverter): def convert_question_classifier_node_config(self, node: dict) -> dict: node_data = node["data"] self.warnings.append( - UnknowModelWarning( + UnknownModelWarning( node_id=node["id"], node_name=node_data["title"], model_name=node_data["model"].get("name") @@ -327,7 +327,7 @@ class DifyConverter(BaseConverter): ) result = QuestionClassifierNodeConfig.model_construct( - input_variable=self._process_list_variable_litearl(node_data.get("query_variable_selector")), + input_variable=self._process_list_variable_literal(node_data.get("query_variable_selector")), user_supplement_prompt=self.trans_variable_format(node_data.get("instructions", "")), categories=categories, ).model_dump() @@ -337,13 +337,13 @@ class DifyConverter(BaseConverter): def convert_llm_node_config(self, node: dict) -> dict: node_data = node["data"] self.warnings.append( - UnknowModelWarning( + UnknownModelWarning( node_id=node["id"], node_name=node_data["title"], model_name=node_data["model"].get("name") ) ) - context = self._process_list_variable_litearl(node_data["context"]["variable_selector"]) + context = self._process_list_variable_literal(node_data["context"]["variable_selector"]) memory = MemoryWindowSetting( enable=bool(node_data.get("memory")), enable_window=bool(node_data.get("memory", {}).get("window", {}).get("enabled", False)), @@ -367,7 +367,7 @@ class DifyConverter(BaseConverter): ) ) 
vision = node_data["vision"]["enabled"] - vision_input = self._process_list_variable_litearl( + vision_input = self._process_list_variable_literal( node_data["vision"]["configs"]["variable_selector"] ) if vision else None result = LLMNodeConfig.model_construct( @@ -433,7 +433,7 @@ class DifyConverter(BaseConverter): conditions.append( LoopConditionDetail.model_construct( operator=self.convert_compare_operator(condition["comparison_operator"]), - left=self._process_list_variable_litearl(condition["variable_selector"]), + left=self._process_list_variable_literal(condition["variable_selector"]), right=self.trans_variable_format( right_value ) if isinstance(right_value, str) and self.is_variable(right_value) else self.convert_variable_type( @@ -453,7 +453,7 @@ class DifyConverter(BaseConverter): right_input_type = variable["value_type"] right_value_type = self.variable_type_map(variable["var_type"]) if right_input_type == ValueInputType.VARIABLE: - right_value = self._process_list_variable_litearl(variable.get("value", "")) + right_value = self._process_list_variable_literal(variable.get("value", "")) else: right_value = self.convert_variable_type(right_value_type, variable.get("value", "")) loop_variables.append( @@ -475,10 +475,10 @@ class DifyConverter(BaseConverter): def convert_iteration_node_config(self, node: dict) -> dict: node_data = node["data"] result = IterationNodeConfig.model_construct( - input=self._process_list_variable_litearl(node_data["iterator_selector"]), + input=self._process_list_variable_literal(node_data["iterator_selector"]), parallel=node_data["is_parallel"], parallel_count=node_data["parallel_nums"], - output=self._process_list_variable_litearl(node_data["output_selector"]), + output=self._process_list_variable_literal(node_data["output_selector"]), output_type=self.variable_type_map(node_data.get("output_type")), flatten=node_data["flatten_output"], ).model_dump() @@ -494,8 +494,8 @@ class DifyConverter(BaseConverter): continue 
assignments.append( AssignmentItem( - variable_selector=self._process_list_variable_litearl(assignment["variable_selector"]), - value=self._process_list_variable_litearl( + variable_selector=self._process_list_variable_literal(assignment["variable_selector"]), + value=self._process_list_variable_literal( assignment["value"] ) if assignment["input_type"] == ValueInputType.VARIABLE else assignment["value"], operation=self.convert_assignment_operator(assignment["operation"]) @@ -514,7 +514,7 @@ class DifyConverter(BaseConverter): input_variables.append( InputVariable.model_construct( name=input_variable["variable"], - variable=self._process_list_variable_litearl(input_variable["value_selector"]), + variable=self._process_list_variable_literal(input_variable["value_selector"]), ) ) @@ -570,7 +570,7 @@ class DifyConverter(BaseConverter): else: if node_data["body"]["data"]: body_content = (node_data["body"]["data"][0].get("value") or - self._process_list_variable_litearl(node_data["body"]["data"][0].get("file"))) + self._process_list_variable_literal(node_data["body"]["data"][0].get("file"))) else: body_content = "" @@ -585,7 +585,7 @@ class DifyConverter(BaseConverter): self.trans_variable_format(key_value[0]) ] = self.trans_variable_format(key_value[1]) else: - self.warnings.append(ExceptionDefineition( + self.warnings.append(ExceptionDefinition( type=ExceptionType.CONFIG, node_id=node["id"], node_name=node_data["title"], @@ -603,7 +603,7 @@ class DifyConverter(BaseConverter): self.trans_variable_format(key_value[0]) ] = self.trans_variable_format(key_value[1]) else: - self.warnings.append(ExceptionDefineition( + self.warnings.append(ExceptionDefinition( type=ExceptionType.CONFIG, node_id=node["id"], node_name=node_data["title"], @@ -625,7 +625,7 @@ class DifyConverter(BaseConverter): default_header = var["value"] elif var["key"] == "status_code": default_status_code = var["value"] - default_value = HttpErrorDefaultTamplete( + default_value = HttpErrorDefaultTemplate( 
body=default_body, headers=default_header, status_code=default_status_code, @@ -668,7 +668,7 @@ class DifyConverter(BaseConverter): for variable in node_data["variables"]: mapping.append(VariablesMappingConfig.model_construct( name=variable["variable"], - value=self._process_list_variable_litearl(variable["value_selector"]) + value=self._process_list_variable_literal(variable["value_selector"]) )) result = JinjaRenderNodeConfig.model_construct( template=node_data["template"], @@ -679,14 +679,14 @@ class DifyConverter(BaseConverter): def convert_knowledge_node_config(self, node: dict) -> dict: node_data = node["data"] - self.warnings.append(ExceptionDefineition( + self.warnings.append(ExceptionDefinition( node_id=node["id"], node_name=node_data["title"], type=ExceptionType.CONFIG, detail=f"Please reconfigure the Knowledge Retrieval node.", )) result = KnowledgeRetrievalNodeConfig.model_construct( - query=self._process_list_variable_litearl(node_data["query_variable_selector"]), + query=self._process_list_variable_literal(node_data["query_variable_selector"]), ).model_dump() self.config_validate(node["id"], node["data"]["title"], KnowledgeRetrievalNodeConfig, result) @@ -695,7 +695,7 @@ class DifyConverter(BaseConverter): def convert_parameter_extractor_node_config(self, node: dict) -> dict: node_data = node["data"] self.warnings.append( - UnknowModelWarning( + UnknownModelWarning( node_id=node["id"], node_name=node_data["title"], model_name=node_data["model"].get("name") @@ -712,7 +712,7 @@ class DifyConverter(BaseConverter): ) ) result = ParameterExtractorNodeConfig.model_construct( - text=self._process_list_variable_litearl(node_data["query"]), + text=self._process_list_variable_literal(node_data["query"]), params=params, prompt=node_data.get("instruction") ).model_dump() @@ -727,14 +727,14 @@ class DifyConverter(BaseConverter): group_type = {} if not advanced_settings or not advanced_settings["group_enabled"]: group_variables = [ - 
self._process_list_variable_litearl(variable) + self._process_list_variable_literal(variable) for variable in node_data["variables"] ] group_type["output"] = node_data["output_type"] else: for group in advanced_settings["groups"]: group_variables[group["group_name"]] = [ - self._process_list_variable_litearl(variable) + self._process_list_variable_literal(variable) for variable in group["variables"] ] group_type[group["group_name"]] = group["output_type"] @@ -751,7 +751,7 @@ class DifyConverter(BaseConverter): def convert_tool_node_config(self, node: dict) -> dict: node_data = node["data"] - self.warnings.append(ExceptionDefineition( + self.warnings.append(ExceptionDefinition( node_id=node["id"], node_name=node_data["title"], type=ExceptionType.CONFIG, diff --git a/api/app/core/workflow/adapters/dify/dify_adapter.py b/api/app/core/workflow/adapters/dify/dify_adapter.py index 10397ad0..abd95408 100644 --- a/api/app/core/workflow/adapters/dify/dify_adapter.py +++ b/api/app/core/workflow/adapters/dify/dify_adapter.py @@ -12,7 +12,7 @@ from app.core.workflow.adapters.base_adapter import ( WorkflowParserResult ) from app.core.workflow.adapters.dify.converter import DifyConverter -from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType +from app.core.workflow.adapters.errors import ExceptionDefinition, ExceptionType from app.core.workflow.nodes.enums import NodeType from app.schemas.workflow_schema import ( NodeDefinition, @@ -85,7 +85,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): if not all(field in self.config for field in require_fields): return False if self.config.get("app", {}).get("mode") == "workflow": - self.errors.append(ExceptionDefineition( + self.errors.append(ExceptionDefinition( type=ExceptionType.PLATFORM, detail="workflow mode is not supported" )) @@ -111,12 +111,12 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): edge = self._convert_edge(edge) if edge: self.edges.append(edge) - # + for variable in 
self.config.get("workflow").get("conversation_variables"): con_var = self._convert_variable(variable) if variable: self.conv_variables.append(con_var) - # + # for variables in config.get("workflow").get("environment_variables"): # variable = self._convert_variable(variables) # conv_variables.append(variable) @@ -152,7 +152,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): "y": node["position"]["y"] + position["y"] } self.errors.append( - ExceptionDefineition( + ExceptionDefinition( type=ExceptionType.NODE, node_id=node_id, detail="parent cycle node not found" @@ -189,7 +189,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): node_data = node["data"] converter = self.get_node_convert(node_type) if node_type == NodeType.UNKNOWN: - self.errors.append(ExceptionDefineition( + self.errors.append(ExceptionDefinition( type=ExceptionType.NODE, node_id=node["id"], node_name=node["data"]["title"], @@ -197,7 +197,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): )) return converter(node) except Exception as e: - self.errors.append(ExceptionDefineition( + self.errors.append(ExceptionDefinition( type=ExceptionType.NODE, node_id=node["id"], node_name=node["data"]["title"], @@ -207,7 +207,6 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): def _convert_edge(self, edge: dict[str, Any]) -> EdgeDefinition | None: try: - source = edge["source"] target = edge["target"] label = None @@ -230,7 +229,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): label=label, ) except Exception as e: - self.errors.append(ExceptionDefineition( + self.errors.append(ExceptionDefinition( type=ExceptionType.EDGE, detail=f"convert edge error - {e}", )) @@ -246,7 +245,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): description=variable.get("description") ) except Exception as e: - self.errors.append(ExceptionDefineition( + self.errors.append(ExceptionDefinition( type=ExceptionType.VARIABLE, name=variable.get("name"), detail=f"convert variable error - 
{e}", diff --git a/api/app/core/workflow/adapters/errors.py b/api/app/core/workflow/adapters/errors.py index c0340a5e..cb743c68 100644 --- a/api/app/core/workflow/adapters/errors.py +++ b/api/app/core/workflow/adapters/errors.py @@ -18,7 +18,7 @@ class ExceptionType(StrEnum): UNKNOWN = "unknown" -class ExceptionDefineition(BaseModel): +class ExceptionDefinition(BaseModel): type: ExceptionType detail: str @@ -29,7 +29,7 @@ class ExceptionDefineition(BaseModel): name: str | None = None -class UnknowModelWarning(ExceptionDefineition): +class UnknownModelWarning(ExceptionDefinition): type: ExceptionType = ExceptionType.NODE def __init__(self, node_id, node_name, model_name): @@ -40,36 +40,36 @@ class UnknowModelWarning(ExceptionDefineition): ) -class UnknowError(ExceptionDefineition): +class UnknownError(ExceptionDefinition): type: ExceptionType = ExceptionType.UNKNOWN def __init__(self, detail: str, **kwargs): super().__init__(detail=detail, **kwargs) -class UnsupportPlatform(ExceptionDefineition): +class UnsupportedPlatform(ExceptionDefinition): type: ExceptionType = ExceptionType.PLATFORM def __init__(self, platform: str): - super().__init__(detail=f"Unsupport platform {platform}") + super().__init__(detail=f"Unsupported platform {platform}") -class UnsupportVariableType(ExceptionDefineition): +class UnsupportedVariableType(ExceptionDefinition): type: ExceptionType = ExceptionType.VARIABLE def __init__(self, scope, name, var_type: str, **kwargs): - super().__init__(scope=scope, name=name, detail=f"Unsupport variable type:[{var_type}]", **kwargs) + super().__init__(scope=scope, name=name, detail=f"Unsupported variable type: [{var_type}]", **kwargs) -class InvalidConfiguration(ExceptionDefineition): +class InvalidConfiguration(ExceptionDefinition): type: ExceptionType = ExceptionType.CONFIG def __init__(self): super().__init__(detail="Invalid workflow configuration format") -class UnsupportNodeType(ExceptionDefineition): +class 
UnsupportedNodeType(ExceptionDefinition): type: ExceptionType = ExceptionType.NODE def __init__(self, node_id: str, node_type: str): - super().__init__(node_id=node_id, detail=f"Unsupport node Type {node_type}") + super().__init__(node_id=node_id, detail=f"Unsupported node type {node_type}") diff --git a/api/app/core/workflow/adapters/memory_bear/memory_bear_adapter.py b/api/app/core/workflow/adapters/memory_bear/memory_bear_adapter.py index 3516cb58..a2608a01 100644 --- a/api/app/core/workflow/adapters/memory_bear/memory_bear_adapter.py +++ b/api/app/core/workflow/adapters/memory_bear/memory_bear_adapter.py @@ -11,7 +11,7 @@ from app.core.workflow.adapters.base_adapter import ( BasePlatformAdapter, WorkflowParserResult ) -from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType, UnsupportNodeType +from app.core.workflow.adapters.errors import ExceptionDefinition, ExceptionType, UnsupportedNodeType from app.core.workflow.adapters.memory_bear.memory_bear_converter import MemoryBearConverter from app.core.workflow.nodes.enums import NodeType from app.schemas.workflow_schema import ExecutionConfig, NodeDefinition, EdgeDefinition, VariableDefinition @@ -73,7 +73,7 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter): try: node_type = self.map_node_type(node["type"]) if node_type == NodeType.UNKNOWN: - self.errors.append(UnsupportNodeType( + self.errors.append(UnsupportedNodeType( node_id=node_id, node_type=node["type"] )) @@ -85,7 +85,7 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter): return NodeDefinition(**node) except Exception as e: - self.errors.append(ExceptionDefineition( + self.errors.append(ExceptionDefinition( type=ExceptionType.NODE, node_id=node_id, node_name=node_name, @@ -97,14 +97,14 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter): def _convert_edge(self, edge: dict[str, Any], valid_node_ids: set) -> EdgeDefinition | None: try: if edge.get("source") not in valid_node_ids 
or edge.get("target") not in valid_node_ids: - self.warnings.append(ExceptionDefineition( + self.warnings.append(ExceptionDefinition( type=ExceptionType.EDGE, detail=f"edge {edge.get('id')} skipped: source or target node not found" )) return None return EdgeDefinition(**edge) except Exception as e: - self.errors.append(ExceptionDefineition( + self.errors.append(ExceptionDefinition( type=ExceptionType.EDGE, detail=f"convert edge error - {e}" )) @@ -115,7 +115,7 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter): try: return VariableDefinition(**variable) except Exception as e: - self.warnings.append(ExceptionDefineition( + self.warnings.append(ExceptionDefinition( type=ExceptionType.VARIABLE, name=variable.get("name"), detail=f"convert variable error - {e}" diff --git a/api/app/core/workflow/adapters/memory_bear/memory_bear_converter.py b/api/app/core/workflow/adapters/memory_bear/memory_bear_converter.py index 031c7025..e96e0bf2 100644 --- a/api/app/core/workflow/adapters/memory_bear/memory_bear_converter.py +++ b/api/app/core/workflow/adapters/memory_bear/memory_bear_converter.py @@ -1,6 +1,6 @@ # -*- coding: UTF-8 -*- from app.core.workflow.adapters.base_converter import BaseConverter -from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType +from app.core.workflow.adapters.errors import ExceptionDefinition, ExceptionType from app.core.workflow.nodes.base_config import BaseNodeConfig from app.core.workflow.nodes.configs import ( StartNodeConfig, @@ -65,7 +65,7 @@ class MemoryBearConverter(BaseConverter): try: return config_cls.model_validate(value) except Exception as e: - self.errors.append(ExceptionDefineition( + self.errors.append(ExceptionDefinition( type=ExceptionType.CONFIG, node_id=node_id, node_name=node_name, diff --git a/api/app/core/workflow/engine/graph_builder.py b/api/app/core/workflow/engine/graph_builder.py index c5cf3324..daef6e82 100644 --- a/api/app/core/workflow/engine/graph_builder.py +++ 
b/api/app/core/workflow/engine/graph_builder.py @@ -7,7 +7,7 @@ import re import uuid from collections import defaultdict from functools import lru_cache -from typing import Any, Iterable +from typing import Any, Iterable, Callable from langgraph.checkpoint.memory import InMemorySaver from langgraph.graph import START, END @@ -41,48 +41,31 @@ class GraphBuilder: self, workflow_config: dict[str, Any], stream: bool = False, - subgraph: bool = False, + cycle: str = '', variable_pool: VariablePool | None = None ): self.workflow_config = workflow_config self.stream = stream - self.subgraph = subgraph + self.cycle = cycle self.start_node_id: str | None = None - self.node_map = {node["id"]: node for node in self.nodes} + self.node_map: dict[str, dict] = {} self.end_node_map: dict[str, StreamOutputConfig] = {} - self._find_upstream_activation_dep = lru_cache( - maxsize=len(self.nodes) * 2 - )(self._find_upstream_activation_dep) + self._find_upstream_activation_dep: Callable = self._find_upstream_activation_dep if variable_pool: self.variable_pool = variable_pool else: self.variable_pool = VariablePool() - self.graph = StateGraph(WorkflowState) - self.add_nodes() - self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges) - self.end_nodes = [ - node - for node in self.nodes - if node.get("type") == "end" and node.get("id") in self.reachable_nodes - ] - self.add_edges() - # EDGES MUST BE ADDED AFTER NODES ARE ADDED. 
- + self.graph: StateGraph | None = None + self.nodes: list = [] + self.edges: list = [] + self.reachable_nodes: set[str] | None = None + self.end_nodes: list[dict] = [] self._reverse_adj: dict[str, list[dict]] = defaultdict(list) - self._build_reverse_adj() - self._analyze_end_node_output() - - @property - def nodes(self) -> list[dict[str, Any]]: - return self.workflow_config.get("nodes", []) - - @property - def edges(self) -> list[dict[str, Any]]: - return self.workflow_config.get("edges", []) + self._adj: dict[str, list[str]] = defaultdict(list) def get_node_type(self, node_id: str) -> str: """Retrieve the type of node given its ID. @@ -108,13 +91,14 @@ class GraphBuilder: result[node[0]].append(node[1]) return result - def _build_reverse_adj(self): + def _build_adj(self): for edge in self.edges: if edge["source"] not in self.reachable_nodes: continue self._reverse_adj[edge.get("target")].append({ "id": edge["source"], "branch": edge.get("label") }) + self._adj[edge.get("source")].append(edge["target"]) def _find_upstream_activation_dep( self, @@ -302,22 +286,13 @@ class GraphBuilder: """ for node in self.nodes: node_type = node.get("type") - if node_type == NodeType.NOTES: - continue node_id = node.get("id") - cycle_node = node.get("cycle") - if cycle_node: - # Nodes within a loop subgraph are constructed by CycleGraphNode - if not self.subgraph: - continue - - # Record start and end node IDs - if node_type in [NodeType.START, NodeType.CYCLE_START]: - self.start_node_id = node_id + if node_id not in self.reachable_nodes: + continue # Create node instance (start and end nodes are also created) # NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph - node_instance = NodeFactory.create_node(node, self.workflow_config) + node_instance = NodeFactory.create_node(node, self.workflow_config, self._adj[node_id]) if node_type in BRANCH_NODES: @@ -413,11 +388,12 @@ class GraphBuilder: # Add conditional edges for 
source_node, branches in conditional_edges.items(): def make_router(src, branch_list): - """reate a router function for each source node that routes to a NOP node for later merging.""" + """Create a router function for each source node that routes to a NOP node for later merging.""" def make_branch_node(node_name, targets): def node(s): - # NOTE: NOP NODE MUST NOT MODIFY STATE + # NOTE: NOP NODE USED FOR ROUTING ONLY. + # MUST NOT MUTATE STATE DIRECTLY; ONLY EMIT ACTIVATE SIGNALS. return { "activate": { node_id: s["activate"][node_name] @@ -504,14 +480,52 @@ class GraphBuilder: logger.debug(f"Added waiting edge: {sources} -> {target}") # Connect End nodes to the global END node - for end_node in self.end_nodes: - end_node_id = end_node.get("id") - if end_node_id: - self.graph.add_edge(end_node_id, END) - logger.debug(f"Added edge: {end_node_id} -> END") + for node in self.reachable_nodes: + if not self._adj[node]: + self.graph.add_edge(node, END) return def build(self) -> CompiledStateGraph: + nodes = self.workflow_config.get("nodes", []) + edges = self.workflow_config.get("edges", []) + + for node in nodes: + if (node.get("cycle") or '') == self.cycle: + node_type = node.get("type") + if node_type in [NodeType.START, NodeType.CYCLE_START]: + self.start_node_id = node.get("id") + elif node_type == NodeType.NOTES: + continue + self.nodes.append(node) + self.node_map[node.get("id")] = node + + for edge in edges: + source_in = edge.get("source") in self.node_map + target_in = edge.get("target") in self.node_map + if source_in ^ target_in: + raise ValueError( + f"Cycle node is connected to external node, " + f"source: {edge.get('source')}, target: {edge.get('target')}" + ) + + if source_in and target_in: + self.edges.append(edge) + + self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges) + self.end_nodes = [ + node + for node in self.nodes + if node.get("type") == "end" and node.get("id") in self.reachable_nodes + ] + 
self._build_adj() + self._find_upstream_activation_dep: Callable = lru_cache( + maxsize=len(self.nodes)*2 + )(self._find_upstream_activation_dep) + + self.graph = StateGraph(WorkflowState) + self.add_nodes() + self.add_edges() + + self._analyze_end_node_output() checkpointer = InMemorySaver() - self.graph = self.graph.compile(checkpointer=checkpointer) - return self.graph + return self.graph.compile(checkpointer=checkpointer) diff --git a/api/app/core/workflow/engine/result_builder.py b/api/app/core/workflow/engine/result_builder.py index e5a03c1c..be0c957a 100644 --- a/api/app/core/workflow/engine/result_builder.py +++ b/api/app/core/workflow/engine/result_builder.py @@ -2,6 +2,7 @@ # Author: Eternity # @Email: 1533512157@qq.com # @Time : 2026/2/10 13:33 +from app.core.workflow.engine.runtime_schema import ExecutionContext from app.core.workflow.engine.variable_pool import VariablePool @@ -9,6 +10,7 @@ class WorkflowResultBuilder: def build_final_output( self, result: dict, + execution_context: ExecutionContext, variable_pool: VariablePool, elapsed_time: float, final_output: str, @@ -26,6 +28,8 @@ class WorkflowResultBuilder: - "node_outputs" (dict): Outputs of executed nodes. - "messages" (list): Conversation messages exchanged during execution. - "error" (str, optional): Error message if any node failed. + execution_context (ExecutionContext): The execution context containing metadata like + execution ID, workspace ID, and user ID.) variable_pool (VariablePool): Variable Pool elapsed_time (float): Total execution time in seconds. 
final_output (Any): The aggregated or final output content of the workflow @@ -48,18 +52,23 @@ class WorkflowResultBuilder: """ node_outputs = result.get("node_outputs", {}) token_usage = self.aggregate_token_usage(node_outputs) - conversation_id = variable_pool.get_value("sys.conversation_id") + conversation_vars = {} + sys_vars = {} + + if variable_pool: + conversation_vars = variable_pool.get_all_conversation_vars() + sys_vars = variable_pool.get_all_system_vars() return { "status": "completed" if success else "failed", "output": final_output, "variables": { - "conv": variable_pool.get_all_conversation_vars(), - "sys": variable_pool.get_all_system_vars() + "conv": conversation_vars, + "sys": sys_vars }, "node_outputs": node_outputs, "messages": result.get("messages", []), - "conversation_id": conversation_id, + "conversation_id": execution_context.conversation_id, "elapsed_time": elapsed_time, "token_usage": token_usage, "error": result.get("error"), diff --git a/api/app/core/workflow/engine/runtime_schema.py b/api/app/core/workflow/engine/runtime_schema.py index 48eafaa9..036ce0e8 100644 --- a/api/app/core/workflow/engine/runtime_schema.py +++ b/api/app/core/workflow/engine/runtime_schema.py @@ -12,6 +12,7 @@ class ExecutionContext(BaseModel): execution_id: str workspace_id: str user_id: str + conversation_id: str memory_storage_type: str user_rag_memory_id: str checkpoint_config: RunnableConfig @@ -22,6 +23,7 @@ class ExecutionContext(BaseModel): execution_id: str, workspace_id: str, user_id: str, + conversation_id: str, memory_storage_type: str, user_rag_memory_id: str ): @@ -29,6 +31,7 @@ class ExecutionContext(BaseModel): execution_id=execution_id, workspace_id=workspace_id, user_id=user_id, + conversation_id=conversation_id, memory_storage_type=memory_storage_type, user_rag_memory_id=user_rag_memory_id, diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index 6a127e96..0a820826 100644 --- a/api/app/core/workflow/executor.py 
+++ b/api/app/core/workflow/executor.py @@ -3,6 +3,7 @@ # @Email: 1533512157@qq.com # @Time : 2026/2/9 13:51 import datetime +import time import logging from typing import Any @@ -82,13 +83,15 @@ class WorkflowExecutor: CompiledStateGraph: The compiled and ready-to-run state graph. """ logger.info(f"Starting workflow graph build: execution_id={self.execution_context.execution_id}") + start_time = time.time() builder = GraphBuilder( self.workflow_config, stream=stream, ) + + self.graph = builder.build() self.start_node_id = builder.start_node_id self.variable_pool = builder.variable_pool - self.graph = builder.build() self.stream_coordinator.initialize_end_outputs(builder.end_node_map) self.event_handler = EventStreamHandler( @@ -96,7 +99,8 @@ class WorkflowExecutor: variable_pool=self.variable_pool, execution_id=self.execution_context.execution_id ) - logger.info(f"Workflow graph build completed: execution_id={self.execution_context.execution_id}") + logger.info(f"Workflow graph build completed: execution_id={self.execution_context.execution_id}, " + f"cost: {time.time() - start_time:.4f}s") return self.graph @@ -134,94 +138,12 @@ class WorkflowExecutor: return event.get("data") return self.result_builder.build_final_output( {"error": "Workflow execution did not end as expected"}, + self.execution_context, self.variable_pool, (datetime.datetime.now() - start).total_seconds(), "", success=False ) - # logger.info(f"Starting workflow execution: execution_id={self.execution_context.execution_id}") - # - # start_time = datetime.datetime.now() - # - # # Execute the workflow - # try: - # # Build the workflow graph - # graph = self.build_graph() - # - # # Initialize the variable pool with input data - # await self.variable_initializer.initialize( - # variable_pool=self.variable_pool, - # input_data=input_data, - # execution_context=self.execution_context - # ) - # initial_state = self.state_manager.create_initial_state( - # workflow_config=self.workflow_config, - # 
input_data=input_data, - # execution_context=self.execution_context, - # start_node_id=self.start_node_id - # ) - # - # result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config) - # - # # Aggregate output from all End nodes - # full_content = '' - # for end_id in self.stream_coordinator.end_outputs.keys(): - # full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False) - # - # # Append messages for user and assistant - # if input_data.get("files"): - # result["messages"].extend( - # [ - # { - # "role": "user", - # "content": input_data.get("message", '') - # }, - # { - # "role": "user", - # "content": input_data.get("files") - # }, - # { - # "role": "assistant", - # "content": full_content - # } - # ] - # ) - # else: - # result["messages"].extend( - # [ - # { - # "role": "user", - # "content": input_data.get("message", '') - # }, - # { - # "role": "assistant", - # "content": full_content - # } - # ] - # ) - # # Calculate elapsed time - # end_time = datetime.datetime.now() - # elapsed_time = (end_time - start_time).total_seconds() - # - # logger.info( - # f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}ms") - # - # return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content) - # - # except Exception as e: - # end_time = datetime.datetime.now() - # elapsed_time = (end_time - start_time).total_seconds() - # - # logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}", - # exc_info=True) - # return { - # "status": "failed", - # "error": str(e), - # "output": None, - # "node_outputs": {}, - # "elapsed_time": elapsed_time, - # "token_usage": None - # } async def execute_stream( self, @@ -255,7 +177,7 @@ class WorkflowExecutor: "data": { "execution_id": self.execution_context.execution_id, "workspace_id": self.execution_context.workspace_id, - 
"conversation_id": input_data.get("conversation_id"), + "conversation_id": self.execution_context.conversation_id, "timestamp": int(start_time.timestamp() * 1000) } } @@ -376,6 +298,7 @@ class WorkflowExecutor: "event": "workflow_end", "data": self.result_builder.build_final_output( result, + self.execution_context, self.variable_pool, elapsed_time, full_content, @@ -396,6 +319,7 @@ class WorkflowExecutor: "event": "workflow_end", "data": self.result_builder.build_final_output( result, + self.execution_context, self.variable_pool, elapsed_time, full_content, @@ -432,6 +356,7 @@ async def execute_workflow( execution_id=execution_id, workspace_id=workspace_id, user_id=user_id, + conversation_id=input_data.get("conversation_id"), memory_storage_type=memory_storage_type, user_rag_memory_id=user_rag_memory_id ) @@ -471,6 +396,7 @@ async def execute_workflow_stream( workspace_id=workspace_id, user_id=user_id, memory_storage_type=memory_storage_type, + conversation_id=input_data.get("conversation_id"), user_rag_memory_id=user_rag_memory_id ) executor = WorkflowExecutor( diff --git a/api/app/core/workflow/nodes/agent/node.py b/api/app/core/workflow/nodes/agent/node.py index 8959e27c..7b146a9c 100644 --- a/api/app/core/workflow/nodes/agent/node.py +++ b/api/app/core/workflow/nodes/agent/node.py @@ -64,9 +64,7 @@ class AgentNode(BaseNode): if not release: raise ValueError(f"Agent 不存在: {agent_id}") - - return release, message async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: diff --git a/api/app/core/workflow/nodes/assigner/node.py b/api/app/core/workflow/nodes/assigner/node.py index 4c897d5a..f5bdf000 100644 --- a/api/app/core/workflow/nodes/assigner/node.py +++ b/api/app/core/workflow/nodes/assigner/node.py @@ -14,8 +14,8 @@ logger = logging.getLogger(__name__) class AssignerNode(BaseNode): - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) + def 
__init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) self.variable_updater = True self.typed_config: AssignerNodeConfig | None = None diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index 7f2b8aa6..0b31c9e3 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -28,7 +28,7 @@ class BaseNode(ABC): All node types should inherit from this class and implement the `execute` method. """ - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): """Initialize the node. Args: @@ -41,6 +41,7 @@ class BaseNode(ABC): self.node_type = node_config["type"] self.cycle = node_config.get("cycle") self.node_name = node_config.get("name", self.node_id) + self.down_stream_nodes = down_stream_nodes # 使用 or 运算符处理 None 值 self.config = node_config.get("config") or {} self.error_handling = node_config.get("error_handling") or {} @@ -93,18 +94,16 @@ class BaseNode(ABC): dict: A dict with a single key 'activate', mapping node IDs to their activation status (True/False). 
""" - edges = self.workflow_config.get("edges") - under_stream_nodes = [ - edge.get("target") - for edge in edges - if edge.get("source") == self.node_id and self.node_type not in BRANCH_NODES - ] - return { - "activate": { - node_id: self.check_activate(state) - for node_id in under_stream_nodes - } | {self.node_id: self.check_activate(state)} - } + activate_flag = self.check_activate(state) + + if self.node_type not in BRANCH_NODES: + activate = {node_id: activate_flag for node_id in self.down_stream_nodes} + else: + activate = {} + + activate[self.node_id] = activate_flag + + return {"activate": activate} @abstractmethod async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: @@ -315,8 +314,8 @@ class BaseNode(ABC): elapsed_time = (time.time() - start_time) * 1000 - logger.info(f"Node {self.node_id} streaming execution finished, " - f"time elapsed: {elapsed_time:.2f}ms, chunks: {chunk_count}") + logger.debug(f"Node {self.node_id} streaming execution finished, " + f"time elapsed: {elapsed_time:.2f}ms, chunks: {chunk_count}") # Extract processed output (call subclass's _extract_output) extracted_output = self._extract_output(final_result) @@ -428,8 +427,8 @@ class BaseNode(ABC): when an error edge exists. If no error edge exists, this method raises an exception to stop the workflow. 
""" - # Check if the node has an error edge defined - error_edge = self._find_error_edge() + # # Check if the node has an error edge defined + # error_edge = self._find_error_edge() # Extract input data (for logging or audit purposes) input_data = self._extract_input(state, variable_pool) @@ -447,27 +446,26 @@ class BaseNode(ABC): "error": error_message } - if error_edge: - # If an error edge exists, log a warning and continue to error node - logger.warning( - f"Node {self.node_id} execution failed, redirecting to error node: {error_edge['target']}" - ) - return { - "node_outputs": { - self.node_id: node_output - }, - "error": error_message, - "error_node": self.node_id - } - else: - # If no error edge, send the error via stream writer and stop the workflow - writer = get_stream_writer() - writer({ - "type": "node_error", - **node_output - }) - logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}") - raise Exception(f"Node {self.node_id} execution failed: {error_message}") + # if error_edge: + # # If an error edge exists, log a warning and continue to error node + # logger.warning( + # f"Node {self.node_id} execution failed, redirecting to error node: {error_edge['target']}" + # ) + # return { + # "node_outputs": { + # self.node_id: node_output + # }, + # "error": error_message, + # "error_node": self.node_id + # } + # else: + writer = get_stream_writer() + writer({ + "type": "node_error", + **node_output + }) + logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}") + raise Exception(f"Node {self.node_id} execution failed: {error_message}") def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: """Extracts the input data for this node (used for logging or audit). 
@@ -644,7 +642,7 @@ class BaseNode(ABC): if content.content_cache.get(f"{provider}_{ModelInfo.is_omni}"): return content.content_cache[f"{provider}_{ModelInfo.is_omni}"] with get_db_read() as db: - multimodel_service = MultimodalService(db, api_config=api_config) + multimodal_service = MultimodalService(db, api_config=api_config) file_obj = FileInput( type=content.type, url=content.url, @@ -653,7 +651,7 @@ class BaseNode(ABC): upload_file_id=uuid.UUID(content.file_id) if content.file_id else None, ) file_obj.set_content(content.get_content()) - message = await multimodel_service.process_files( + message = await multimodal_service.process_files( [file_obj], ) content.set_content(file_obj.get_content()) @@ -661,7 +659,7 @@ class BaseNode(ABC): content.content_cache[f"{provider}_{ModelInfo.is_omni}"] = message return message return None - raise TypeError(f'Unexpect input value type - {type(content)}') + raise TypeError(f'Unexpected input value type - {type(content)}') @staticmethod def process_model_output(content) -> str: diff --git a/api/app/core/workflow/nodes/code/node.py b/api/app/core/workflow/nodes/code/node.py index 1e055002..d89b208b 100644 --- a/api/app/core/workflow/nodes/code/node.py +++ b/api/app/core/workflow/nodes/code/node.py @@ -51,8 +51,8 @@ console.log(result) class CodeNode(BaseNode): - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) self.typed_config: CodeNodeConfig | None = None def _output_types(self) -> dict[str, VariableType]: diff --git a/api/app/core/workflow/nodes/cycle_graph/node.py b/api/app/core/workflow/nodes/cycle_graph/node.py index 71e0dbdb..fc80939f 100644 --- a/api/app/core/workflow/nodes/cycle_graph/node.py +++ b/api/app/core/workflow/nodes/cycle_graph/node.py @@ -30,17 
+30,13 @@ class CycleGraphNode(BaseNode): It acts as a container and execution controller for a subgraph. """ - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) - - self.cycle_nodes = list() # Nodes belonging to this cycle - self.cycle_edges = list() # Edges connecting nodes within the cycle + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) + self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph() self.start_node_id = None # ID of the start node within the cycle self.graph: StateGraph | CompiledStateGraph | None = None self.child_variable_pool: VariablePool | None = None - self.build_graph() - self.iteration_flag = True def _output_types(self) -> dict[str, VariableType]: outputs = {"__child_state": VariableType.ARRAY_OBJECT} @@ -119,11 +115,11 @@ class CycleGraphNode(BaseNode): else: remain_edges.append(edge) - # Update workflow_config by removing cycle nodes and internal edges - self.workflow_config["nodes"] = [ - node for node in nodes if node.get("cycle") != self.node_id - ] - self.workflow_config["edges"] = remain_edges + # # Update workflow_config by removing cycle nodes and internal edges + # self.workflow_config["nodes"] = [ + # node for node in nodes if node.get("cycle") != self.node_id + # ] + # self.workflow_config["edges"] = remain_edges return cycle_nodes, cycle_edges @@ -137,18 +133,18 @@ class CycleGraphNode(BaseNode): 3. 
Compile the graph for runtime execution """ from app.core.workflow.engine.graph_builder import GraphBuilder - self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph() + self.child_variable_pool = VariablePool() builder = GraphBuilder( { "nodes": self.cycle_nodes, "edges": self.cycle_edges, }, - subgraph=True, - variable_pool=self.child_variable_pool + variable_pool=self.child_variable_pool, + cycle=self.node_id ) - self.start_node_id = builder.start_node_id self.graph = builder.build() + self.start_node_id = builder.start_node_id self.child_variable_pool = builder.variable_pool async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: @@ -169,6 +165,7 @@ class CycleGraphNode(BaseNode): Raises: RuntimeError: If the node type is unsupported. """ + self.build_graph() if self.node_type == NodeType.LOOP: return await LoopRuntime( start_id=self.start_node_id, @@ -194,6 +191,7 @@ class CycleGraphNode(BaseNode): raise RuntimeError("Unknown cycle node type") async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool): + self.build_graph() if self.node_type == NodeType.LOOP: yield { "__final__": True, diff --git a/api/app/core/workflow/nodes/end/config.py b/api/app/core/workflow/nodes/end/config.py index 5c2a6c2a..02df5091 100644 --- a/api/app/core/workflow/nodes/end/config.py +++ b/api/app/core/workflow/nodes/end/config.py @@ -1,9 +1,7 @@ """End 节点配置""" - from pydantic import Field -from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition -from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.nodes.base_config import BaseNodeConfig class EndNodeConfig(BaseNodeConfig): diff --git a/api/app/core/workflow/nodes/end/node.py b/api/app/core/workflow/nodes/end/node.py index 2799316a..770cf328 100644 --- a/api/app/core/workflow/nodes/end/node.py +++ b/api/app/core/workflow/nodes/end/node.py @@ -36,8 +36,6 @@ class EndNode(BaseNode): Returns: 最终输出字符串 """ - 
logger.info(f"节点 {self.node_id} (End) 开始执行") - # 获取配置的输出模板 output_template = self.config.get("output") @@ -46,11 +44,4 @@ class EndNode(BaseNode): output = self._render_template(output_template, variable_pool, strict=False) else: output = "" - - # 统计信息(用于日志) - node_outputs = state.get("node_outputs", {}) - total_nodes = len(node_outputs) - - logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点") - return output diff --git a/api/app/core/workflow/nodes/enums.py b/api/app/core/workflow/nodes/enums.py index 43ab593b..5a603ac9 100644 --- a/api/app/core/workflow/nodes/enums.py +++ b/api/app/core/workflow/nodes/enums.py @@ -28,7 +28,7 @@ class NodeType(StrEnum): NOTES = "notes" -BRANCH_NODES = [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER] +BRANCH_NODES = frozenset({NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER}) class ComparisonOperator(StrEnum): diff --git a/api/app/core/workflow/nodes/http_request/config.py b/api/app/core/workflow/nodes/http_request/config.py index fe38fafb..e1b84f0c 100644 --- a/api/app/core/workflow/nodes/http_request/config.py +++ b/api/app/core/workflow/nodes/http_request/config.py @@ -115,7 +115,7 @@ class HttpRetryConfig(BaseModel): ) -class HttpErrorDefaultTamplete(BaseModel): +class HttpErrorDefaultTemplate(BaseModel): body: str = Field( default="", description="Default body returned on HTTP error", @@ -143,7 +143,7 @@ class HttpErrorHandleConfig(BaseModel): description="Error handling strategy: 'none', 'default', or 'branch'", ) - default: HttpErrorDefaultTamplete | None = Field( + default: HttpErrorDefaultTemplate | None = Field( default=None, description="Default response template for error handling", ) diff --git a/api/app/core/workflow/nodes/http_request/node.py b/api/app/core/workflow/nodes/http_request/node.py index 23378c83..086bee4a 100644 --- a/api/app/core/workflow/nodes/http_request/node.py +++ b/api/app/core/workflow/nodes/http_request/node.py @@ -16,7 +16,7 @@ from 
app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput -from app.core.workflow.utils.file_processer import mime_to_file_type +from app.core.workflow.utils.file_processor import mime_to_file_type from app.core.workflow.variable.base_variable import VariableType, FileObject from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable from app.schemas import FileType, TransferMethod @@ -157,8 +157,8 @@ class HttpRequestNode(BaseNode): or a branch identifier string when error branching is enabled. """ - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) self.typed_config: HttpRequestNodeConfig | None = None def _output_types(self) -> dict[str, VariableType]: diff --git a/api/app/core/workflow/nodes/if_else/node.py b/api/app/core/workflow/nodes/if_else/node.py index 5d2bdf9a..ec46b20b 100644 --- a/api/app/core/workflow/nodes/if_else/node.py +++ b/api/app/core/workflow/nodes/if_else/node.py @@ -14,8 +14,8 @@ logger = logging.getLogger(__name__) class IfElseNode(BaseNode): - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) self.typed_config: IfElseNodeConfig | None = None def _output_types(self) -> dict[str, VariableType]: diff --git 
a/api/app/core/workflow/nodes/jinja_render/node.py b/api/app/core/workflow/nodes/jinja_render/node.py index e13709d4..abf21524 100644 --- a/api/app/core/workflow/nodes/jinja_render/node.py +++ b/api/app/core/workflow/nodes/jinja_render/node.py @@ -12,8 +12,8 @@ logger = logging.getLogger(__name__) class JinjaRenderNode(BaseNode): - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) self.typed_config: JinjaRenderNodeConfig | None = None def _output_types(self) -> dict[str, VariableType]: diff --git a/api/app/core/workflow/nodes/knowledge/node.py b/api/app/core/workflow/nodes/knowledge/node.py index d3e9efd9..92699cb4 100644 --- a/api/app/core/workflow/nodes/knowledge/node.py +++ b/api/app/core/workflow/nodes/knowledge/node.py @@ -21,8 +21,8 @@ logger = logging.getLogger(__name__) class KnowledgeRetrievalNode(BaseNode): - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) self.typed_config: KnowledgeRetrievalNodeConfig | None = None self.vector_service: ElasticSearchVector | None = None diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index 66a0f1ac..a691001f 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -70,8 +70,8 @@ class LLMNode(BaseNode): - ai/assistant: AI 消息(AIMessage) """ - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, 
Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) self.typed_config: LLMNodeConfig | None = None self.messages = [] diff --git a/api/app/core/workflow/nodes/memory/node.py b/api/app/core/workflow/nodes/memory/node.py index a28247e4..73c52b79 100644 --- a/api/app/core/workflow/nodes/memory/node.py +++ b/api/app/core/workflow/nodes/memory/node.py @@ -14,8 +14,8 @@ from app.tasks import write_message_task class MemoryReadNode(BaseNode): - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) self.typed_config: MemoryReadNodeConfig | None = None def _output_types(self) -> dict[str, VariableType]: @@ -45,8 +45,8 @@ class MemoryReadNode(BaseNode): class MemoryWriteNode(BaseNode): - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) self.typed_config: MemoryWriteNodeConfig | None = None def _output_types(self) -> dict[str, VariableType]: diff --git a/api/app/core/workflow/nodes/node_factory.py b/api/app/core/workflow/nodes/node_factory.py index 864e3251..9e5a7d24 100644 --- a/api/app/core/workflow/nodes/node_factory.py +++ b/api/app/core/workflow/nodes/node_factory.py @@ -104,13 +104,15 @@ class NodeFactory: def create_node( cls, node_config: dict[str, Any], - workflow_config: dict[str, Any] + workflow_config: dict[str, Any], + down_stream_nodes: list[str] ) -> WorkflowNode | None: """创建节点实例 Args: node_config: 节点配置 workflow_config: 工作流配置 + down_stream_nodes: 下游节点 Returns: 
节点实例或 None(对于不支持的节点类型) @@ -127,7 +129,7 @@ class NodeFactory: # 创建节点实例 logger.debug(f"create node instance: {node_config.get('id')} (type={node_type})") - return node_class(node_config, workflow_config) + return node_class(node_config, workflow_config, down_stream_nodes) @classmethod def get_supported_types(cls) -> list[str]: diff --git a/api/app/core/workflow/nodes/parameter_extractor/node.py b/api/app/core/workflow/nodes/parameter_extractor/node.py index acac09e4..3dc5fcc3 100644 --- a/api/app/core/workflow/nodes/parameter_extractor/node.py +++ b/api/app/core/workflow/nodes/parameter_extractor/node.py @@ -21,8 +21,8 @@ logger = logging.getLogger(__name__) class ParameterExtractorNode(BaseNode): - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) self.typed_config: ParameterExtractorNodeConfig | None = None self.response_metadata = {} diff --git a/api/app/core/workflow/nodes/question_classifier/node.py b/api/app/core/workflow/nodes/question_classifier/node.py index 5cebd886..31fadaf6 100644 --- a/api/app/core/workflow/nodes/question_classifier/node.py +++ b/api/app/core/workflow/nodes/question_classifier/node.py @@ -22,8 +22,8 @@ DEFAULT_EMPTY_QUESTION_CASE = f"{DEFAULT_CASE_PREFIX}1" class QuestionClassifierNode(BaseNode): """问题分类器节点""" - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) self.typed_config: QuestionClassifierNodeConfig | None = None self.category_to_case_map = {} self.response_metadata = {} diff --git 
a/api/app/core/workflow/nodes/start/node.py b/api/app/core/workflow/nodes/start/node.py index a9618f7b..7a324cc4 100644 --- a/api/app/core/workflow/nodes/start/node.py +++ b/api/app/core/workflow/nodes/start/node.py @@ -27,14 +27,8 @@ class StartNode(BaseNode): 注意:变量的验证和默认值处理由 Executor 在初始化时完成。 """ - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - """初始化 Start 节点 - - Args: - node_config: 节点配置 - workflow_config: 工作流配置 - """ - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) # 解析并验证配置 self.typed_config: StartNodeConfig | None = None @@ -62,7 +56,6 @@ class StartNode(BaseNode): 包含系统参数、会话变量和自定义变量的字典 """ self.typed_config = StartNodeConfig(**self.config) - logger.info(f"节点 {self.node_id} (Start) 开始执行") # 处理自定义变量(传入 pool 避免重复创建) custom_vars = self._process_custom_variables(variable_pool) @@ -77,9 +70,9 @@ class StartNode(BaseNode): **custom_vars # 自定义变量作为节点输出的一部分 } - logger.info( - f"节点 {self.node_id} (Start) 执行完成," - f"输出了 {len(custom_vars)} 个自定义变量" + logger.debug( + f"Node {self.node_id} (Start) execution completed, " + f"outputting {len(custom_vars)} custom variables" ) return result diff --git a/api/app/core/workflow/nodes/tool/node.py b/api/app/core/workflow/nodes/tool/node.py index 0e9d3c62..72c5c6a8 100644 --- a/api/app/core/workflow/nodes/tool/node.py +++ b/api/app/core/workflow/nodes/tool/node.py @@ -20,8 +20,8 @@ TEMPLATE_PATTERN = re.compile(r"\{\{.*?}}") class ToolNode(BaseNode): """工具节点""" - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) self.typed_config: ToolNodeConfig | None = None 
def _output_types(self) -> dict[str, VariableType]: diff --git a/api/app/core/workflow/nodes/variable_aggregator/node.py b/api/app/core/workflow/nodes/variable_aggregator/node.py index de82f8ff..9a9c5566 100644 --- a/api/app/core/workflow/nodes/variable_aggregator/node.py +++ b/api/app/core/workflow/nodes/variable_aggregator/node.py @@ -12,8 +12,8 @@ logger = logging.getLogger(__name__) class VariableAggregatorNode(BaseNode): - def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - super().__init__(node_config, workflow_config) + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): + super().__init__(node_config, workflow_config, down_stream_nodes) self.typed_config: VariableAggregatorNodeConfig | None = None def _output_types(self) -> dict[str, VariableType]: diff --git a/api/app/core/workflow/utils/file_processer.py b/api/app/core/workflow/utils/file_processor.py similarity index 100% rename from api/app/core/workflow/utils/file_processer.py rename to api/app/core/workflow/utils/file_processor.py diff --git a/api/app/core/workflow/utils/template_renderer.py b/api/app/core/workflow/utils/template_renderer.py index 424fdf20..6a73efc4 100644 --- a/api/app/core/workflow/utils/template_renderer.py +++ b/api/app/core/workflow/utils/template_renderer.py @@ -153,7 +153,8 @@ class TemplateRenderer: # 全局渲染器实例(严格模式) -_default_renderer = TemplateRenderer(strict=True) +_strict_renderer = TemplateRenderer(strict=True) +_lenient_renderer = TemplateRenderer(strict=False) def render_template( @@ -184,7 +185,7 @@ def render_template( ... 
) '请分析: 这是一段文本' """ - renderer = TemplateRenderer(strict=strict) + renderer = _strict_renderer if strict else _lenient_renderer return renderer.render(template, conv_vars, node_outputs, system_vars) @@ -197,4 +198,4 @@ def validate_template(template: str) -> list[str]: Returns: 错误列表 """ - return _default_renderer.validate(template) + return _strict_renderer.validate(template) diff --git a/api/app/core/workflow/validator.py b/api/app/core/workflow/validator.py index fe4aea19..0ad74865 100644 --- a/api/app/core/workflow/validator.py +++ b/api/app/core/workflow/validator.py @@ -6,6 +6,7 @@ import copy import logging +from collections import defaultdict, deque from typing import Any, Union, TYPE_CHECKING from app.core.workflow.nodes.enums import NodeType @@ -119,7 +120,6 @@ class WorkflowValidator: errors = [] graphs = cls.get_subgraph(workflow_config) - logger.info(graphs) for index, graph in enumerate(graphs): nodes = graph.get("nodes", []) edges = graph.get("edges", []) @@ -183,7 +183,7 @@ class WorkflowValidator: has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges) if has_cycle: errors.append( - f"工作流存在循环依赖(请使用 loop 节点实现循环): {' -> '.join(cycle_path)}" + f"工作流存在循环依赖(请使用 loop/iteration 节点实现循环): {' -> '.join(cycle_path)}" ) # 8. 
验证变量名 @@ -204,18 +204,18 @@ class WorkflowValidator: Returns: 可达节点 ID 集合 """ + adj = defaultdict(list) + for edge in edges: + adj[edge["source"]].append(edge["target"]) + reachable = {start_id} - queue = [start_id] - + queue = deque([start_id]) while queue: - current = queue.pop(0) - for edge in edges: - if edge.get("source") == current: - target = edge.get("target") - if target and target not in reachable: - reachable.add(target) - queue.append(target) - + current = queue.popleft() + for target in adj[current]: + if target not in reachable: + reachable.add(target) + queue.append(target) return reachable @staticmethod @@ -229,10 +229,6 @@ class WorkflowValidator: Returns: (has_cycle, cycle_path): 是否有循环和循环路径 """ - # 排除 loop 类型的节点 - loop_nodes = {n["id"] for n in nodes if n.get("type") == "loop"} - - # 构建邻接表(排除 loop 节点的边和错误边) graph: dict[str, list[str]] = {} for edge in edges: source = edge.get("source") @@ -243,10 +239,6 @@ class WorkflowValidator: if edge_type == "error": continue - # 如果涉及 loop 节点,跳过 - if source in loop_nodes or target in loop_nodes: - continue - if source and target: if source not in graph: graph[source] = [] diff --git a/api/app/core/workflow/variable/variable_objects.py b/api/app/core/workflow/variable/variable_objects.py index 5e8e3f1e..79e023c1 100644 --- a/api/app/core/workflow/variable/variable_objects.py +++ b/api/app/core/workflow/variable/variable_objects.py @@ -54,7 +54,7 @@ class DictVariable(BaseVariable): def valid_value(self, value) -> dict: if not isinstance(value, dict): - raise TypeError(f"Value must be a dict - {type(value)}:{value}") + raise TypeError(f"Value must be a dict - {type(value)}:{value}") return value def to_literal(self) -> str: diff --git a/api/app/models/models_model.py b/api/app/models/models_model.py index 44a844d0..69bedc3d 100644 --- a/api/app/models/models_model.py +++ b/api/app/models/models_model.py @@ -27,9 +27,9 @@ class ModelType(StrEnum): RERANK = "rerank" # TTS = "tts" # SPEECH2TEXT = "speech2text" - # 
IMAGE = "image" + IMAGE = "image" # AUDIO = "audio" - # VISION = "vision" + VIDEO = "video" class ModelProvider(StrEnum): @@ -46,6 +46,7 @@ class ModelProvider(StrEnum): XINFERENCE = "xinference" GPUSTACK = "gpustack" BEDROCK = "bedrock" + VOLCANO = "volcano" COMPOSITE = "composite" diff --git a/api/app/models/tenant_model.py b/api/app/models/tenant_model.py index 044857d2..a92b5629 100644 --- a/api/app/models/tenant_model.py +++ b/api/app/models/tenant_model.py @@ -23,6 +23,17 @@ class Tenants(Base): # 国际化语言配置字段 default_language = Column(String(10), nullable=False, default='zh', server_default='zh', index=True) # 租户默认语言 supported_languages = Column(ARRAY(String(10)), nullable=False, default=lambda: ['zh', 'en'], server_default=text("'{zh,en}'")) # 租户支持的语言列表 + + # 租户联系信息 + contact_name = Column(String(100), nullable=True) # 联系人姓名 + contact_email = Column(String(255), nullable=True) # 联系人邮箱 + contact_phone = Column(String(50), nullable=True) # 联系人电话 + + # 租户套餐信息 + plan = Column(String(50), nullable=True) # 套餐类型 + plan_expired_at = Column(DateTime, nullable=True) # 套餐到期时间 + api_ops_rate_limit = Column(String(100), nullable=True) # API 调用频率限制 + status = Column(String(50), nullable=True, default='active') # 租户状态 # Relationship to users - one tenant has many users users = relationship("User", back_populates="tenant") diff --git a/api/app/repositories/model_repository.py b/api/app/repositories/model_repository.py index fd95c793..8c477d39 100644 --- a/api/app/repositories/model_repository.py +++ b/api/app/repositories/model_repository.py @@ -439,7 +439,6 @@ class ModelConfigRepository: ModelConfig.is_public ), ModelConfig.provider == provider, - ModelConfig.is_active, ~ModelConfig.is_composite ) ).all() diff --git a/api/app/services/file_storage_service.py b/api/app/services/file_storage_service.py index 2ebc5d9a..5897936b 100644 --- a/api/app/services/file_storage_service.py +++ b/api/app/services/file_storage_service.py @@ -325,27 +325,30 @@ class FileStorageService: ) 
raise - async def get_file_url(self, file_key: str, expires: int = 3600) -> str: + async def get_file_url( + self, + file_key: str, + expires: int = 3600, + file_name: Optional[str] = None, + ) -> str: """ Get an access URL for a file. Args: file_key: The file key. expires: URL validity period in seconds (default: 1 hour). + file_name: If set, adds Content-Disposition: attachment to force download. Returns: URL for accessing the file. """ logger.debug(f"Getting file URL: file_key={file_key}, expires={expires}s") - try: - url = await self.storage.get_url(file_key, expires) + url = await self.storage.get_url(file_key, expires, file_name=file_name) logger.debug(f"File URL generated: file_key={file_key}") return url except Exception as e: - logger.error( - f"Error getting file URL: file_key={file_key}, error={str(e)}" - ) + logger.error(f"Error getting file URL: file_key={file_key}, error={str(e)}") raise diff --git a/api/app/services/generation_service.py b/api/app/services/generation_service.py new file mode 100644 index 00000000..2505793c --- /dev/null +++ b/api/app/services/generation_service.py @@ -0,0 +1,162 @@ +""" +图片和视频生成服务 + +提供统一的生成接口,支持多种 Provider +""" +from typing import Dict, Any, Optional +from sqlalchemy.orm import Session +import uuid + +from app.core.models import RedBearModelConfig, RedBearImageGenerator, RedBearVideoGenerator +from app.core.exceptions import BusinessException +from app.core.error_codes import BizCode +from app.models.models_model import ModelType +from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository +from app.services.model_service import ModelApiKeyService + + +class GenerationService: + """生成服务""" + + def __init__(self, db: Session): + self.db = db + + async def generate_image( + self, + model_config_id: str, + prompt: str, + size: Optional[str] = "2k", + **kwargs + ) -> Dict[str, Any]: + """ + 生成图片 + + Args: + model_config_id: 模型配置ID + prompt: 提示词 + size: 图片尺寸 + **kwargs: 其他参数 + + Returns: + 
生成结果 + """ + # 获取模型配置 + model_config = ModelConfigRepository.get_by_id(self.db, uuid.UUID(model_config_id)) + if not model_config: + raise BusinessException("模型配置不存在", code=BizCode.NOT_FOUND) + + if model_config.type != ModelType.IMAGE: + raise BusinessException( + f"模型类型错误,期望 {ModelType.IMAGE},实际 {model_config.type}", + code=BizCode.INVALID_PARAMETER + ) + + # 获取 API Key + api_key_info = ModelApiKeyService.get_available_api_key(self.db, uuid.UUID(model_config_id)) + if not api_key_info: + raise BusinessException("没有可用的 API Key", code=BizCode.NOT_FOUND) + + # 创建配置 + config = RedBearModelConfig( + model_name=api_key_info.model_name, + provider=api_key_info.provider, + api_key=api_key_info.api_key, + base_url=api_key_info.api_base, + extra_params=api_key_info.config or {} + ) + + # 生成图片 + generator = RedBearImageGenerator(config) + result = await generator.agenerate(prompt, size, **kwargs) + + return result + + async def generate_video( + self, + model_config_id: str, + prompt: str, + duration: Optional[int] = None, + **kwargs + ) -> Dict[str, Any]: + """ + 生成视频 + + Args: + model_config_id: 模型配置ID + prompt: 提示词 + duration: 视频时长(秒) + **kwargs: 其他参数 + + Returns: + 生成结果(包含任务ID) + """ + # 获取模型配置 + model_config = ModelConfigRepository.get_by_id(self.db, uuid.UUID(model_config_id)) + if not model_config: + raise BusinessException("模型配置不存在", code=BizCode.NOT_FOUND) + + if model_config.type != ModelType.VIDEO: + raise BusinessException( + f"模型类型错误,期望 {ModelType.VIDEO},实际 {model_config.type}", + code=BizCode.INVALID_PARAMETER + ) + + # 获取 API Key + api_key_info = ModelApiKeyService.get_available_api_key(self.db, uuid.UUID(model_config_id)) + if not api_key_info: + raise BusinessException("没有可用的 API Key", code=BizCode.NOT_FOUND) + + # 创建配置 + config = RedBearModelConfig( + model_name=api_key_info.model_name, + provider=api_key_info.provider, + api_key=api_key_info.api_key, + base_url=api_key_info.api_base, + extra_params=api_key_info.config or {} + ) + + # 生成视频 + generator = 
RedBearVideoGenerator(config) + result = await generator.agenerate(prompt, duration, **kwargs) + + return result + + async def get_video_task_status( + self, + model_config_id: str, + task_id: str + ) -> Dict[str, Any]: + """ + 查询视频生成任务状态 + + Args: + model_config_id: 模型配置ID + task_id: 任务ID + + Returns: + 任务状态信息 + """ + # 获取模型配置 + model_config = ModelConfigRepository.get_by_id(self.db, uuid.UUID(model_config_id)) + if not model_config: + raise BusinessException("模型配置不存在", code=BizCode.NOT_FOUND) + + # 获取 API Key + api_key_info = ModelApiKeyService.get_available_api_key(self.db, uuid.UUID(model_config_id)) + if not api_key_info: + raise BusinessException("没有可用的 API Key", code=BizCode.NOT_FOUND) + + # 创建配置 + config = RedBearModelConfig( + model_name=api_key_info.model_name, + provider=api_key_info.provider, + api_key=api_key_info.api_key, + base_url=api_key_info.api_base, + extra_params=api_key_info.config or {} + ) + + # 查询任务状态 + generator = RedBearVideoGenerator(config) + result = await generator.aget_task_status(task_id) + + return result diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index e5c34492..289fd74c 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -357,6 +357,7 @@ class MemoryAgentService: if file_object is None: continue message["file_content"].append((file_object, file["type"])) + logger.info(messages) message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) try: @@ -606,7 +607,7 @@ class MemoryAgentService: retrieved_content.append({query: statements}) # 如果 retrieved_content 为空,设置为空字符串 - if retrieved_content == []: + if not retrieved_content: retrieved_content = '' # 只有当回答不是"信息不足"且不是快速检索时才保存 diff --git a/api/app/services/model_service.py b/api/app/services/model_service.py index a7398504..b98674ba 100644 --- a/api/app/services/model_service.py +++ b/api/app/services/model_service.py @@ -154,10 +154,17 @@ class 
ModelConfigService: } elif model_type_lower == "embedding": - # Embedding 模型验证(在线程中运行同步方法) + # Embedding 模型验证 + # 统一使用 RedBearEmbeddings(自动支持火山引擎多模态) embedding = RedBearEmbeddings(model_config) test_texts = [test_message, "测试文本"] - vectors = await asyncio.to_thread(embedding.embed_documents, test_texts) + + # 火山引擎使用 embed_batch,其他使用 embed_documents + if provider.lower() == "volcano": + vectors = await asyncio.to_thread(embedding.embed_batch, test_texts) + else: + vectors = await asyncio.to_thread(embedding.embed_documents, test_texts) + elapsed_time = time.time() - start_time return { @@ -193,6 +200,56 @@ class ModelConfigService: }, "error": None } + + elif model_type_lower == "image": + # 图片生成模型验证 + from app.core.models.generation import RedBearImageGenerator + + generator = RedBearImageGenerator(model_config) + result = await generator.agenerate( + prompt="a cute panda", + size="2K" + ) + elapsed_time = time.time() - start_time + logger.info(f"成功生成图片,结果: {result}") + + return { + "valid": True, + "message": "图片生成模型配置验证成功", + "response": f"成功生成图片,结果: {result}", + "elapsed_time": elapsed_time, + "usage": { + "prompt_length": len("a cute panda"), + "image_count": 1 + }, + "error": None + } + + elif model_type_lower == "video": + # 视频生成模型验证 + from app.core.models.generation import RedBearVideoGenerator + + generator = RedBearVideoGenerator(model_config) + result = await generator.agenerate( + prompt="a cute panda playing in bamboo forest", + duration=5 + ) + elapsed_time = time.time() - start_time + + # 视频生成是异步任务,返回任务ID + task_id = result.get("task_id") if isinstance(result, dict) else None + + return { + "valid": True, + "message": "视频生成模型配置验证成功", + "response": f"成功创建视频生成任务,任务ID: {task_id}", + "elapsed_time": elapsed_time, + "usage": { + "prompt_length": len("a cute panda playing in bamboo forest"), + "task_id": task_id + }, + "error": None + } else: return { diff --git a/api/app/services/multimodal_service.py b/api/app/services/multimodal_service.py index 
c4af12d5..3afd6206 100644 --- a/api/app/services/multimodal_service.py +++ b/api/app/services/multimodal_service.py @@ -294,6 +294,7 @@ PROVIDER_STRATEGIES = { "bedrock": BedrockFormatStrategy, "anthropic": BedrockFormatStrategy, "openai": OpenAIFormatStrategy, + "volcano": OpenAIFormatStrategy, } diff --git a/api/app/services/user_service.py b/api/app/services/user_service.py index b5522b74..3122d282 100644 --- a/api/app/services/user_service.py +++ b/api/app/services/user_service.py @@ -250,6 +250,20 @@ def deactivate_user(db: Session, user_id_to_deactivate: uuid.UUID, current_user: } ) + # 检查是否为租户联系人 + from app.models.tenant_model import Tenants + tenant = db.query(Tenants).filter(Tenants.id == db_user.tenant_id).first() + if tenant and tenant.contact_email and tenant.contact_email == db_user.email: + business_logger.warning(f"尝试停用租户联系人: {db_user.email}, tenant_id={db_user.tenant_id}") + raise BusinessException( + "该管理员是租户联系人,请先在租户信息中更换联系邮箱,再禁用此管理员", + code=BizCode.FORBIDDEN, + context={ + "user_id": str(user_id_to_deactivate), + "tenant_id": str(db_user.tenant_id) + } + ) + # 停用用户 business_logger.debug(f"执行用户停用: {db_user.username} (ID: {user_id_to_deactivate})") db_user.is_active = False diff --git a/api/app/services/workflow_import_service.py b/api/app/services/workflow_import_service.py index 2b36c5ea..fd8f25f3 100644 --- a/api/app/services/workflow_import_service.py +++ b/api/app/services/workflow_import_service.py @@ -12,7 +12,7 @@ from app.aioRedis import aio_redis_set, aio_redis_get from app.core.config import settings from app.core.exceptions import BusinessException from app.core.workflow.adapters.base_adapter import WorkflowImportResult, WorkflowParserResult -from app.core.workflow.adapters.errors import UnsupportPlatform, InvalidConfiguration +from app.core.workflow.adapters.errors import UnsupportedPlatform, InvalidConfiguration from app.core.workflow.adapters.registry import PlatformAdapterRegistry from app.schemas import AppCreate from 
app.schemas.workflow_schema import WorkflowConfigCreate @@ -46,7 +46,7 @@ class WorkflowImportService: success=False, temp_id=None, workflow_id=None, - errors=[UnsupportPlatform(platform=platform)] + errors=[UnsupportedPlatform(platform=platform)] ) adapter = self.registry.get_adapter(platform, config) diff --git a/api/migrations/versions/1ea8fe97b5b7_202603252115.py b/api/migrations/versions/1ea8fe97b5b7_202603252115.py new file mode 100644 index 00000000..1f0df3e7 --- /dev/null +++ b/api/migrations/versions/1ea8fe97b5b7_202603252115.py @@ -0,0 +1,42 @@ +"""202603252115 + +Revision ID: 1ea8fe97b5b7 +Revises: e28bcc212da5 +Create Date: 2026-03-25 21:14:41.825048 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '1ea8fe97b5b7' +down_revision: Union[str, None] = 'e28bcc212da5' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('tenants', sa.Column('contact_name', sa.String(length=100), nullable=True)) + op.add_column('tenants', sa.Column('contact_email', sa.String(length=255), nullable=True)) + op.add_column('tenants', sa.Column('contact_phone', sa.String(length=50), nullable=True)) + op.add_column('tenants', sa.Column('plan', sa.String(length=50), nullable=True)) + op.add_column('tenants', sa.Column('plan_expired_at', sa.DateTime(), nullable=True)) + op.add_column('tenants', sa.Column('api_ops_rate_limit', sa.String(length=100), nullable=True)) + op.add_column('tenants', sa.Column('status', sa.String(length=50), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_column('tenants', 'status') + op.drop_column('tenants', 'api_ops_rate_limit') + op.drop_column('tenants', 'plan_expired_at') + op.drop_column('tenants', 'plan') + op.drop_column('tenants', 'contact_phone') + op.drop_column('tenants', 'contact_email') + op.drop_column('tenants', 'contact_name') + # ### end Alembic commands ### diff --git a/api/pyproject.toml b/api/pyproject.toml index e6fddea8..8ced574c 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -147,6 +147,7 @@ dependencies = [ "modelscope>=1.34.0", "python-magic>=0.4.14; sys_platform == 'linux' or sys_platform == 'darwin'", "python-magic-bin>=0.4.14; sys_platform=='win32'", + "volcengine-python-sdk[ark]==5.0.19" ] [tool.pytest.ini_options]