feat: support model load balancing and add message_id to API responses

This commit is contained in:
Eternity
2026-03-06 19:29:31 +08:00
parent 3e5f6176af
commit b756f0c86c
2 changed files with 17 additions and 13 deletions

View File

@@ -618,7 +618,12 @@ class BaseNode(ABC):
return variable_pool.has(selector) return variable_pool.has(selector)
@staticmethod @staticmethod
async def process_message(provider: str, content: str | dict | FileObject, enable_file=False) -> list | str | None: async def process_message(
provider: str,
is_omni: bool,
content: str | dict | FileObject,
enable_file=False
) -> list | str | None:
if isinstance(content, dict): if isinstance(content, dict):
content = FileObject( content = FileObject(
type=content.get("type"), type=content.get("type"),
@@ -637,7 +642,7 @@ class BaseNode(ABC):
if content.content_cache.get(provider): if content.content_cache.get(provider):
return content.content_cache[provider] return content.content_cache[provider]
with get_db_read() as db: with get_db_read() as db:
multimodel_service = MultimodalService(db, provider) multimodel_service = MultimodalService(db, provider, is_omni=is_omni)
message = await multimodel_service.process_files( message = await multimodel_service.process_files(
[FileInput.model_construct( [FileInput.model_construct(
type=content.type, type=content.type,
@@ -647,7 +652,6 @@ class BaseNode(ABC):
upload_file_id=content.file_id upload_file_id=content.file_id
)] )]
) )
if message: if message:
content.content_cache[provider] = message content.content_cache[provider] = message
return message return message
@@ -669,11 +673,11 @@ class BaseNode(ABC):
return content return content
return result return result
def model_balance(self, model_config: ModelConfig) -> ModelApiKey: @staticmethod
def model_balance(model_config: ModelConfig) -> ModelApiKey:
api_keys = [key for key in model_config.api_keys if key.is_active] api_keys = [key for key in model_config.api_keys if key.is_active]
if not api_keys: if not api_keys:
raise ValueError("No active API keys available for model") raise ValueError("No active API keys available for model")
if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN: if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN:
if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN: return min(api_keys, key=lambda x: (int(x.usage_count or "0"), x.last_used_at or datetime.min))
return min(api_keys, key=lambda x: (int(x.usage_count or "0"), x.last_used_at or datetime.min))
return api_keys[0] return api_keys[0]

View File

@@ -153,30 +153,30 @@ class LLMNode(BaseNode):
if role == "system": if role == "system":
messages.append({ messages.append({
"role": "system", "role": "system",
"content": await self.process_message(provider, content, self.typed_config.vision) "content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
}) })
elif role in ["user", "human"]: elif role in ["user", "human"]:
messages.append({ messages.append({
"role": "user", "role": "user",
"content": await self.process_message(provider, content, self.typed_config.vision) "content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
}) })
elif role in ["ai", "assistant"]: elif role in ["ai", "assistant"]:
messages.append({ messages.append({
"role": "assistant", "role": "assistant",
"content": await self.process_message(provider, content, self.typed_config.vision) "content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
}) })
else: else:
logger.warning(f"未知的消息角色: {role},默认使用 user") logger.warning(f"未知的消息角色: {role},默认使用 user")
messages.append({ messages.append({
"role": "user", "role": "user",
"content": await self.process_message(provider, content, self.typed_config.vision) "content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
}) })
if self.typed_config.vision_input and self.typed_config.vision: if self.typed_config.vision_input and self.typed_config.vision:
file_content = [] file_content = []
files = variable_pool.get_instance(self.typed_config.vision_input) files = variable_pool.get_instance(self.typed_config.vision_input)
for file in files.value: for file in files.value:
content = await self.process_message(provider, file.value, self.typed_config.vision) content = await self.process_message(provider, is_omni, file.value, self.typed_config.vision)
if content: if content:
file_content.extend(content) file_content.extend(content)
if messages and messages[-1]["role"] == 'user': if messages and messages[-1]["role"] == 'user':
@@ -190,14 +190,14 @@ class LLMNode(BaseNode):
if isinstance(message["content"], list): if isinstance(message["content"], list):
file_content = [] file_content = []
for file in message["content"]: for file in message["content"]:
content = await self.process_message(provider, file, self.typed_config.vision) content = await self.process_message(provider, is_omni, file, self.typed_config.vision)
if content: if content:
file_content.extend(content) file_content.extend(content)
history_message.append( history_message.append(
{"role": message["role"], "content": file_content} {"role": message["role"], "content": file_content}
) )
else: else:
message["content"] = await self.process_message(provider, message["content"], self.typed_config.vision) message["content"] = await self.process_message(provider, is_omni, message["content"], self.typed_config.vision)
history_message.append(message) history_message.append(message)
messages = messages[:-1] + history_message + messages[-1:] messages = messages[:-1] + history_message + messages[-1:]
self.messages = messages self.messages = messages