Merge pull request #505 from SuanmoSuanyangTechnology/fix/bug-patch
feat: support model load balancing and add message_id to API responses
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from datetime import datetime
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import Any, AsyncGenerator
|
from typing import Any, AsyncGenerator
|
||||||
|
|
||||||
@@ -13,6 +13,7 @@ from app.core.workflow.engine.variable_pool import VariablePool
|
|||||||
from app.core.workflow.nodes.enums import BRANCH_NODES
|
from app.core.workflow.nodes.enums import BRANCH_NODES
|
||||||
from app.core.workflow.variable.base_variable import VariableType, FileObject
|
from app.core.workflow.variable.base_variable import VariableType, FileObject
|
||||||
from app.db import get_db_read
|
from app.db import get_db_read
|
||||||
|
from app.models import ModelConfig, ModelApiKey, LoadBalanceStrategy
|
||||||
from app.schemas import FileInput
|
from app.schemas import FileInput
|
||||||
from app.services.multimodal_service import MultimodalService
|
from app.services.multimodal_service import MultimodalService
|
||||||
|
|
||||||
@@ -617,7 +618,12 @@ class BaseNode(ABC):
|
|||||||
return variable_pool.has(selector)
|
return variable_pool.has(selector)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def process_message(provider: str, content: str | dict | FileObject, enable_file=False) -> list | str | None:
|
async def process_message(
|
||||||
|
provider: str,
|
||||||
|
is_omni: bool,
|
||||||
|
content: str | dict | FileObject,
|
||||||
|
enable_file=False
|
||||||
|
) -> list | str | None:
|
||||||
if isinstance(content, dict):
|
if isinstance(content, dict):
|
||||||
content = FileObject(
|
content = FileObject(
|
||||||
type=content.get("type"),
|
type=content.get("type"),
|
||||||
@@ -629,14 +635,14 @@ class BaseNode(ABC):
|
|||||||
)
|
)
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
if enable_file:
|
if enable_file:
|
||||||
return [{"text": content}]
|
return [{"type": "text", "text": content}]
|
||||||
return content
|
return content
|
||||||
|
|
||||||
elif isinstance(content, FileObject):
|
elif isinstance(content, FileObject):
|
||||||
if content.content_cache.get(provider):
|
if content.content_cache.get(provider):
|
||||||
return content.content_cache[provider]
|
return content.content_cache[provider]
|
||||||
with get_db_read() as db:
|
with get_db_read() as db:
|
||||||
multimodel_service = MultimodalService(db, provider)
|
multimodel_service = MultimodalService(db, provider, is_omni=is_omni)
|
||||||
message = await multimodel_service.process_files(
|
message = await multimodel_service.process_files(
|
||||||
[FileInput.model_construct(
|
[FileInput.model_construct(
|
||||||
type=content.type,
|
type=content.type,
|
||||||
@@ -646,7 +652,6 @@ class BaseNode(ABC):
|
|||||||
upload_file_id=content.file_id
|
upload_file_id=content.file_id
|
||||||
)]
|
)]
|
||||||
)
|
)
|
||||||
|
|
||||||
if message:
|
if message:
|
||||||
content.content_cache[provider] = message
|
content.content_cache[provider] = message
|
||||||
return message
|
return message
|
||||||
@@ -667,3 +672,12 @@ class BaseNode(ABC):
|
|||||||
elif isinstance(content, str):
|
elif isinstance(content, str):
|
||||||
return content
|
return content
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@staticmethod
def model_balance(model_config: ModelConfig) -> ModelApiKey:
    """Pick the API key to use for a model configuration.

    Only keys whose ``is_active`` flag is truthy are considered. Under the
    ``ROUND_ROBIN`` strategy the key with the lowest usage count is chosen;
    ties go first to a key that has never been used and then to the key
    used least recently. For any other strategy the first active key is
    returned.

    Args:
        model_config: Model configuration exposing ``api_keys`` and
            ``load_balance_strategy``.

    Returns:
        The selected active API key.

    Raises:
        ValueError: If the configuration has no active API key.
    """
    active_keys = [key for key in model_config.api_keys if key.is_active]
    if not active_keys:
        raise ValueError("No active API keys available for model")

    if model_config.load_balance_strategy != LoadBalanceStrategy.ROUND_ROBIN:
        return active_keys[0]

    def usage_rank(key):
        last_used = key.last_used_at
        return (
            # A missing/empty usage_count is treated as zero uses.
            # NOTE(review): the "0" fallback suggests the count is stored
            # as a string -- confirm against the model definition.
            int(key.usage_count or "0"),
            # Never-used keys (None) sort before used ones. Deciding this
            # with a flag means the datetime.min placeholder below is never
            # compared against a real timestamp, which would raise
            # TypeError if last_used_at is timezone-aware.
            last_used is not None,
            last_used or datetime.min,
        )

    return min(active_keys, key=usage_rank)
|
||||||
|
|||||||
@@ -112,11 +112,12 @@ class LLMNode(BaseNode):
|
|||||||
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
|
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
|
||||||
|
|
||||||
# 在 Session 关闭前提取所有需要的数据
|
# 在 Session 关闭前提取所有需要的数据
|
||||||
api_config = config.api_keys[0]
|
api_config = self.model_balance(config)
|
||||||
model_name = api_config.model_name
|
model_name = api_config.model_name
|
||||||
provider = api_config.provider
|
provider = api_config.provider
|
||||||
api_key = api_config.api_key
|
api_key = api_config.api_key
|
||||||
api_base = api_config.api_base
|
api_base = api_config.api_base
|
||||||
|
is_omni = api_config.is_omni
|
||||||
model_type = config.type
|
model_type = config.type
|
||||||
|
|
||||||
# 4. 创建 LLM 实例(使用已提取的数据)
|
# 4. 创建 LLM 实例(使用已提取的数据)
|
||||||
@@ -129,7 +130,8 @@ class LLMNode(BaseNode):
|
|||||||
provider=provider,
|
provider=provider,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
base_url=api_base,
|
base_url=api_base,
|
||||||
extra_params=extra_params
|
extra_params=extra_params,
|
||||||
|
is_omni=is_omni
|
||||||
),
|
),
|
||||||
type=ModelType(model_type)
|
type=ModelType(model_type)
|
||||||
)
|
)
|
||||||
@@ -151,30 +153,30 @@ class LLMNode(BaseNode):
|
|||||||
if role == "system":
|
if role == "system":
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "system",
|
"role": "system",
|
||||||
"content": await self.process_message(provider, content, self.typed_config.vision)
|
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
|
||||||
})
|
})
|
||||||
elif role in ["user", "human"]:
|
elif role in ["user", "human"]:
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": await self.process_message(provider, content, self.typed_config.vision)
|
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
|
||||||
})
|
})
|
||||||
elif role in ["ai", "assistant"]:
|
elif role in ["ai", "assistant"]:
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": await self.process_message(provider, content, self.typed_config.vision)
|
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
|
||||||
})
|
})
|
||||||
else:
|
else:
|
||||||
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": await self.process_message(provider, content, self.typed_config.vision)
|
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
|
||||||
})
|
})
|
||||||
|
|
||||||
if self.typed_config.vision_input and self.typed_config.vision:
|
if self.typed_config.vision_input and self.typed_config.vision:
|
||||||
file_content = []
|
file_content = []
|
||||||
files = variable_pool.get_instance(self.typed_config.vision_input)
|
files = variable_pool.get_instance(self.typed_config.vision_input)
|
||||||
for file in files.value:
|
for file in files.value:
|
||||||
content = await self.process_message(provider, file.value, self.typed_config.vision)
|
content = await self.process_message(provider, is_omni, file.value, self.typed_config.vision)
|
||||||
if content:
|
if content:
|
||||||
file_content.extend(content)
|
file_content.extend(content)
|
||||||
if messages and messages[-1]["role"] == 'user':
|
if messages and messages[-1]["role"] == 'user':
|
||||||
@@ -188,14 +190,14 @@ class LLMNode(BaseNode):
|
|||||||
if isinstance(message["content"], list):
|
if isinstance(message["content"], list):
|
||||||
file_content = []
|
file_content = []
|
||||||
for file in message["content"]:
|
for file in message["content"]:
|
||||||
content = await self.process_message(provider, file, self.typed_config.vision)
|
content = await self.process_message(provider, is_omni, file, self.typed_config.vision)
|
||||||
if content:
|
if content:
|
||||||
file_content.extend(content)
|
file_content.extend(content)
|
||||||
history_message.append(
|
history_message.append(
|
||||||
{"role": message["role"], "content": file_content}
|
{"role": message["role"], "content": file_content}
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
message["content"] = await self.process_message(provider, message["content"], self.typed_config.vision)
|
message["content"] = await self.process_message(provider, is_omni, message["content"], self.typed_config.vision)
|
||||||
history_message.append(message)
|
history_message.append(message)
|
||||||
messages = messages[:-1] + history_message + messages[-1:]
|
messages = messages[:-1] + history_message + messages[-1:]
|
||||||
self.messages = messages
|
self.messages = messages
|
||||||
|
|||||||
@@ -95,11 +95,12 @@ class ParameterExtractorNode(BaseNode):
|
|||||||
if not config.api_keys or len(config.api_keys) == 0:
|
if not config.api_keys or len(config.api_keys) == 0:
|
||||||
raise BusinessException("Model configuration is missing API Key", BizCode.INVALID_PARAMETER)
|
raise BusinessException("Model configuration is missing API Key", BizCode.INVALID_PARAMETER)
|
||||||
|
|
||||||
api_config = config.api_keys[0]
|
api_config = self.model_balance(config)
|
||||||
model_name = api_config.model_name
|
model_name = api_config.model_name
|
||||||
provider = api_config.provider
|
provider = api_config.provider
|
||||||
api_key = api_config.api_key
|
api_key = api_config.api_key
|
||||||
api_base = api_config.api_base
|
api_base = api_config.api_base
|
||||||
|
is_omni = api_config.is_omni
|
||||||
model_type = config.type
|
model_type = config.type
|
||||||
|
|
||||||
llm = RedBearLLM(
|
llm = RedBearLLM(
|
||||||
@@ -108,6 +109,7 @@ class ParameterExtractorNode(BaseNode):
|
|||||||
provider=provider,
|
provider=provider,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
base_url=api_base,
|
base_url=api_base,
|
||||||
|
is_omni=is_omni
|
||||||
),
|
),
|
||||||
type=ModelType(model_type)
|
type=ModelType(model_type)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -56,11 +56,12 @@ class QuestionClassifierNode(BaseNode):
|
|||||||
if not config.api_keys or len(config.api_keys) == 0:
|
if not config.api_keys or len(config.api_keys) == 0:
|
||||||
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
|
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
|
||||||
|
|
||||||
api_config = config.api_keys[0]
|
api_config = self.model_balance(config)
|
||||||
model_name = api_config.model_name
|
model_name = api_config.model_name
|
||||||
provider = api_config.provider
|
provider = api_config.provider
|
||||||
api_key = api_config.api_key
|
api_key = api_config.api_key
|
||||||
base_url = api_config.api_base
|
base_url = api_config.api_base
|
||||||
|
is_omni = api_config.is_omni
|
||||||
model_type = config.type
|
model_type = config.type
|
||||||
|
|
||||||
return RedBearLLM(
|
return RedBearLLM(
|
||||||
@@ -69,6 +70,7 @@ class QuestionClassifierNode(BaseNode):
|
|||||||
provider=provider,
|
provider=provider,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
base_url=base_url,
|
base_url=base_url,
|
||||||
|
is_omni=is_omni
|
||||||
),
|
),
|
||||||
type=ModelType(model_type)
|
type=ModelType(model_type)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -86,6 +86,7 @@ class ChatResponse(BaseModel):
|
|||||||
"""聊天响应(非流式)"""
|
"""聊天响应(非流式)"""
|
||||||
conversation_id: uuid.UUID
|
conversation_id: uuid.UUID
|
||||||
message: str
|
message: str
|
||||||
|
message_id: str
|
||||||
usage: Optional[Dict[str, Any]] = None
|
usage: Optional[Dict[str, Any]] = None
|
||||||
elapsed_time: Optional[float] = None
|
elapsed_time: Optional[float] = None
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user