feat: support model load balancing and add message_id to API responses
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from functools import cached_property
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
@@ -13,6 +13,7 @@ from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.enums import BRANCH_NODES
|
||||
from app.core.workflow.variable.base_variable import VariableType, FileObject
|
||||
from app.db import get_db_read
|
||||
from app.models import ModelConfig, ModelApiKey, LoadBalanceStrategy
|
||||
from app.schemas import FileInput
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
|
||||
@@ -629,7 +630,7 @@ class BaseNode(ABC):
|
||||
)
|
||||
if isinstance(content, str):
|
||||
if enable_file:
|
||||
return [{"text": content}]
|
||||
return [{"type": "text", "text": content}]
|
||||
return content
|
||||
|
||||
elif isinstance(content, FileObject):
|
||||
@@ -667,3 +668,12 @@ class BaseNode(ABC):
|
||||
elif isinstance(content, str):
|
||||
return content
|
||||
return result
|
||||
|
||||
def model_balance(self, model_config: ModelConfig) -> ModelApiKey:
    """Select the API key to use for a call with the given model configuration.

    Only keys whose ``is_active`` flag is truthy are candidates. When the
    configuration's ``load_balance_strategy`` is
    ``LoadBalanceStrategy.ROUND_ROBIN``, the key with the lowest
    ``usage_count`` is chosen; ties go to the key with the oldest
    ``last_used_at``, and a key without ``last_used_at`` ranks ahead of any
    key that has one. For every other strategy the first active key is
    returned.

    This method only selects a key; it does not modify ``usage_count`` or
    ``last_used_at``.

    Args:
        model_config: Model configuration whose ``api_keys`` are the candidates.

    Returns:
        The selected active API key.

    Raises:
        ValueError: If the configuration has no active API key.
    """
    api_keys = [key for key in model_config.api_keys if key.is_active]
    if not api_keys:
        raise ValueError("No active API keys available for model")

    if model_config.load_balance_strategy != LoadBalanceStrategy.ROUND_ROBIN:
        return api_keys[0]

    def _usage_rank(key):
        # usage_count may be empty/None; the original code treats that as 0.
        # NOTE(review): the "0" fallback suggests the count is stored as a
        # string -- confirm against the ModelApiKey model.
        count = int(key.usage_count or "0")
        has_timestamp = bool(key.last_used_at)
        # The flag sorts never-used keys first and guarantees that a real
        # timestamp is never compared with the naive ``datetime.min``
        # placeholder, which would raise TypeError if the stored timestamps
        # are timezone-aware.
        last_used = key.last_used_at if has_timestamp else datetime.min
        return (count, has_timestamp, last_used)

    return min(api_keys, key=_usage_rank)
|
||||
|
||||
@@ -112,11 +112,12 @@ class LLMNode(BaseNode):
|
||||
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
|
||||
|
||||
# Extract all required data before the Session is closed
|
||||
api_config = config.api_keys[0]
|
||||
api_config = self.model_balance(config)
|
||||
model_name = api_config.model_name
|
||||
provider = api_config.provider
|
||||
api_key = api_config.api_key
|
||||
api_base = api_config.api_base
|
||||
is_omni = api_config.is_omni
|
||||
model_type = config.type
|
||||
|
||||
# 4. Create the LLM instance (using the data extracted above)
|
||||
@@ -129,7 +130,8 @@ class LLMNode(BaseNode):
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
extra_params=extra_params
|
||||
extra_params=extra_params,
|
||||
is_omni=is_omni
|
||||
),
|
||||
type=ModelType(model_type)
|
||||
)
|
||||
|
||||
@@ -95,11 +95,12 @@ class ParameterExtractorNode(BaseNode):
|
||||
if not config.api_keys or len(config.api_keys) == 0:
|
||||
raise BusinessException("Model configuration is missing API Key", BizCode.INVALID_PARAMETER)
|
||||
|
||||
api_config = config.api_keys[0]
|
||||
api_config = self.model_balance(config)
|
||||
model_name = api_config.model_name
|
||||
provider = api_config.provider
|
||||
api_key = api_config.api_key
|
||||
api_base = api_config.api_base
|
||||
is_omni = api_config.is_omni
|
||||
model_type = config.type
|
||||
|
||||
llm = RedBearLLM(
|
||||
@@ -108,6 +109,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
is_omni=is_omni
|
||||
),
|
||||
type=ModelType(model_type)
|
||||
)
|
||||
|
||||
@@ -56,11 +56,12 @@ class QuestionClassifierNode(BaseNode):
|
||||
if not config.api_keys or len(config.api_keys) == 0:
|
||||
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
|
||||
|
||||
api_config = config.api_keys[0]
|
||||
api_config = self.model_balance(config)
|
||||
model_name = api_config.model_name
|
||||
provider = api_config.provider
|
||||
api_key = api_config.api_key
|
||||
base_url = api_config.api_base
|
||||
is_omni = api_config.is_omni
|
||||
model_type = config.type
|
||||
|
||||
return RedBearLLM(
|
||||
@@ -69,6 +70,7 @@ class QuestionClassifierNode(BaseNode):
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
is_omni=is_omni
|
||||
),
|
||||
type=ModelType(model_type)
|
||||
)
|
||||
|
||||
@@ -86,6 +86,7 @@ class ChatResponse(BaseModel):
|
||||
"""聊天响应(非流式)"""
|
||||
conversation_id: uuid.UUID
|
||||
message: str
|
||||
message_id: str
|
||||
usage: Optional[Dict[str, Any]] = None
|
||||
elapsed_time: Optional[float] = None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user