diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index dacbef85..496454ba 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -1,7 +1,7 @@ import asyncio import logging -import uuid from abc import ABC, abstractmethod +from datetime import datetime from functools import cached_property from typing import Any, AsyncGenerator @@ -13,6 +13,7 @@ from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.nodes.enums import BRANCH_NODES from app.core.workflow.variable.base_variable import VariableType, FileObject from app.db import get_db_read +from app.models import ModelConfig, ModelApiKey, LoadBalanceStrategy from app.schemas import FileInput from app.services.multimodal_service import MultimodalService @@ -617,7 +618,12 @@ class BaseNode(ABC): return variable_pool.has(selector) @staticmethod - async def process_message(provider: str, content: str | dict | FileObject, enable_file=False) -> list | str | None: + async def process_message( + provider: str, + is_omni: bool, + content: str | dict | FileObject, + enable_file=False + ) -> list | str | None: if isinstance(content, dict): content = FileObject( type=content.get("type"), @@ -629,14 +635,14 @@ class BaseNode(ABC): ) if isinstance(content, str): if enable_file: - return [{"text": content}] + return [{"type": "text", "text": content}] return content elif isinstance(content, FileObject): if content.content_cache.get(provider): return content.content_cache[provider] with get_db_read() as db: - multimodel_service = MultimodalService(db, provider) + multimodel_service = MultimodalService(db, provider, is_omni=is_omni) message = await multimodel_service.process_files( [FileInput.model_construct( type=content.type, @@ -646,7 +652,6 @@ class BaseNode(ABC): upload_file_id=content.file_id )] ) - if message: content.content_cache[provider] = message return message @@ -667,3 +672,12 @@ class BaseNode(ABC): elif isinstance(content, str): return content return result + + @staticmethod + def model_balance(model_config: ModelConfig) -> ModelApiKey: + api_keys = [key for key in model_config.api_keys if key.is_active] + if not api_keys: + raise ValueError("No active API keys available for model") + if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN: + return min(api_keys, key=lambda x: (int(x.usage_count or "0"), x.last_used_at or datetime.min)) + return api_keys[0] diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index 4b63bc4e..186c204f 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -112,11 +112,12 @@ class LLMNode(BaseNode): raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER) # 在 Session 关闭前提取所有需要的数据 - api_config = config.api_keys[0] + api_config = self.model_balance(config) model_name = api_config.model_name provider = api_config.provider api_key = api_config.api_key api_base = api_config.api_base + is_omni = api_config.is_omni model_type = config.type # 4. 创建 LLM 实例(使用已提取的数据) @@ -129,7 +130,8 @@ class LLMNode(BaseNode): provider=provider, api_key=api_key, base_url=api_base, - extra_params=extra_params + extra_params=extra_params, + is_omni=is_omni ), type=ModelType(model_type) ) @@ -151,30 +153,30 @@ class LLMNode(BaseNode): if role == "system": messages.append({ "role": "system", - "content": await self.process_message(provider, content, self.typed_config.vision) + "content": await self.process_message(provider, is_omni, content, self.typed_config.vision) }) elif role in ["user", "human"]: messages.append({ "role": "user", - "content": await self.process_message(provider, content, self.typed_config.vision) + "content": await self.process_message(provider, is_omni, content, self.typed_config.vision) }) elif role in ["ai", "assistant"]: messages.append({ "role": "assistant", - "content": await self.process_message(provider, content, self.typed_config.vision) + "content": await self.process_message(provider, is_omni, content, self.typed_config.vision) }) else: logger.warning(f"未知的消息角色: {role},默认使用 user") messages.append({ "role": "user", - "content": await self.process_message(provider, content, self.typed_config.vision) + "content": await self.process_message(provider, is_omni, content, self.typed_config.vision) }) if self.typed_config.vision_input and self.typed_config.vision: file_content = [] files = variable_pool.get_instance(self.typed_config.vision_input) for file in files.value: - content = await self.process_message(provider, file.value, self.typed_config.vision) + content = await self.process_message(provider, is_omni, file.value, self.typed_config.vision) if content: file_content.extend(content) if messages and messages[-1]["role"] == 'user': @@ -188,14 +190,14 @@ class LLMNode(BaseNode): if isinstance(message["content"], list): file_content = [] for file in message["content"]: - content = await self.process_message(provider, file, self.typed_config.vision) + content = await self.process_message(provider, is_omni, file, self.typed_config.vision) if content: file_content.extend(content) history_message.append( {"role": message["role"], "content": file_content} ) else: - message["content"] = await self.process_message(provider, message["content"], self.typed_config.vision) + message["content"] = await self.process_message(provider, is_omni, message["content"], self.typed_config.vision) history_message.append(message) messages = messages[:-1] + history_message + messages[-1:] self.messages = messages diff --git a/api/app/core/workflow/nodes/parameter_extractor/node.py b/api/app/core/workflow/nodes/parameter_extractor/node.py index 4811c118..700ed85f 100644 --- a/api/app/core/workflow/nodes/parameter_extractor/node.py +++ b/api/app/core/workflow/nodes/parameter_extractor/node.py @@ -95,11 +95,12 @@ class ParameterExtractorNode(BaseNode): if not config.api_keys or len(config.api_keys) == 0: raise BusinessException("Model configuration is missing API Key", BizCode.INVALID_PARAMETER) - api_config = config.api_keys[0] + api_config = self.model_balance(config) model_name = api_config.model_name provider = api_config.provider api_key = api_config.api_key api_base = api_config.api_base + is_omni = api_config.is_omni model_type = config.type llm = RedBearLLM( @@ -108,6 +109,7 @@ class ParameterExtractorNode(BaseNode): provider=provider, api_key=api_key, base_url=api_base, + is_omni=is_omni ), type=ModelType(model_type) ) diff --git a/api/app/core/workflow/nodes/question_classifier/node.py b/api/app/core/workflow/nodes/question_classifier/node.py index e2fd97ae..5cebd886 100644 --- a/api/app/core/workflow/nodes/question_classifier/node.py +++ b/api/app/core/workflow/nodes/question_classifier/node.py @@ -56,11 +56,12 @@ class QuestionClassifierNode(BaseNode): if not config.api_keys or len(config.api_keys) == 0: raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER) - api_config = config.api_keys[0] + api_config = self.model_balance(config) model_name = api_config.model_name provider = api_config.provider api_key = api_config.api_key base_url = api_config.api_base + is_omni = api_config.is_omni model_type = config.type return RedBearLLM( @@ -69,6 +70,7 @@ class QuestionClassifierNode(BaseNode): provider=provider, api_key=api_key, base_url=base_url, + is_omni=is_omni ), type=ModelType(model_type) ) diff --git a/api/app/schemas/conversation_schema.py b/api/app/schemas/conversation_schema.py index 0fcbc718..13766ef6 100644 --- a/api/app/schemas/conversation_schema.py +++ b/api/app/schemas/conversation_schema.py @@ -86,6 +86,7 @@ class ChatResponse(BaseModel): """聊天响应(非流式)""" conversation_id: uuid.UUID message: str + message_id: str usage: Optional[Dict[str, Any]] = None elapsed_time: Optional[float] = None