diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index dacbef85..b84011d3 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -1,7 +1,7 @@ import asyncio import logging -import uuid from abc import ABC, abstractmethod +from datetime import datetime from functools import cached_property from typing import Any, AsyncGenerator @@ -13,6 +13,7 @@ from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.nodes.enums import BRANCH_NODES from app.core.workflow.variable.base_variable import VariableType, FileObject from app.db import get_db_read +from app.models import ModelConfig, ModelApiKey, LoadBalanceStrategy from app.schemas import FileInput from app.services.multimodal_service import MultimodalService @@ -629,7 +630,7 @@ class BaseNode(ABC): ) if isinstance(content, str): if enable_file: - return [{"text": content}] + return [{"type": "text", "text": content}] return content elif isinstance(content, FileObject): @@ -667,3 +668,12 @@ class BaseNode(ABC): elif isinstance(content, str): return content return result + + def model_balance(self, model_config: ModelConfig) -> ModelApiKey: + api_keys = [key for key in model_config.api_keys if key.is_active] + if not api_keys: + raise ValueError("No active API keys available for model") + if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN: + if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN: + return min(api_keys, key=lambda x: (int(x.usage_count or "0"), x.last_used_at or datetime.min)) + return api_keys[0] diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index 4b63bc4e..92a0dff7 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -112,11 +112,12 @@ class LLMNode(BaseNode): raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER) # 在 Session 关闭前提取所有需要的数据 - api_config = config.api_keys[0] + api_config = self.model_balance(config) model_name = api_config.model_name provider = api_config.provider api_key = api_config.api_key api_base = api_config.api_base + is_omni = api_config.is_omni model_type = config.type # 4. 创建 LLM 实例(使用已提取的数据) @@ -129,7 +130,8 @@ class LLMNode(BaseNode): provider=provider, api_key=api_key, base_url=api_base, - extra_params=extra_params + extra_params=extra_params, + is_omni=is_omni ), type=ModelType(model_type) ) diff --git a/api/app/core/workflow/nodes/parameter_extractor/node.py b/api/app/core/workflow/nodes/parameter_extractor/node.py index 4811c118..700ed85f 100644 --- a/api/app/core/workflow/nodes/parameter_extractor/node.py +++ b/api/app/core/workflow/nodes/parameter_extractor/node.py @@ -95,11 +95,12 @@ class ParameterExtractorNode(BaseNode): if not config.api_keys or len(config.api_keys) == 0: raise BusinessException("Model configuration is missing API Key", BizCode.INVALID_PARAMETER) - api_config = config.api_keys[0] + api_config = self.model_balance(config) model_name = api_config.model_name provider = api_config.provider api_key = api_config.api_key api_base = api_config.api_base + is_omni = api_config.is_omni model_type = config.type llm = RedBearLLM( @@ -108,6 +109,7 @@ class ParameterExtractorNode(BaseNode): provider=provider, api_key=api_key, base_url=api_base, + is_omni=is_omni ), type=ModelType(model_type) ) diff --git a/api/app/core/workflow/nodes/question_classifier/node.py b/api/app/core/workflow/nodes/question_classifier/node.py index e2fd97ae..5cebd886 100644 --- a/api/app/core/workflow/nodes/question_classifier/node.py +++ b/api/app/core/workflow/nodes/question_classifier/node.py @@ -56,11 +56,12 @@ class QuestionClassifierNode(BaseNode): if not config.api_keys or len(config.api_keys) == 0: raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER) - api_config = config.api_keys[0] + api_config = self.model_balance(config) model_name = api_config.model_name provider = api_config.provider api_key = api_config.api_key base_url = api_config.api_base + is_omni = api_config.is_omni model_type = config.type return RedBearLLM( @@ -69,6 +70,7 @@ class QuestionClassifierNode(BaseNode): provider=provider, api_key=api_key, base_url=base_url, + is_omni=is_omni ), type=ModelType(model_type) ) diff --git a/api/app/schemas/conversation_schema.py b/api/app/schemas/conversation_schema.py index 0fcbc718..13766ef6 100644 --- a/api/app/schemas/conversation_schema.py +++ b/api/app/schemas/conversation_schema.py @@ -86,6 +86,7 @@ class ChatResponse(BaseModel): """聊天响应(非流式)""" conversation_id: uuid.UUID message: str + message_id: str usage: Optional[Dict[str, Any]] = None elapsed_time: Optional[float] = None