feat: support model load balancing and add message_id to API responses
This commit is contained in:
@@ -618,7 +618,12 @@ class BaseNode(ABC):
|
||||
return variable_pool.has(selector)
|
||||
|
||||
@staticmethod
|
||||
async def process_message(provider: str, content: str | dict | FileObject, enable_file=False) -> list | str | None:
|
||||
async def process_message(
|
||||
provider: str,
|
||||
is_omni: bool,
|
||||
content: str | dict | FileObject,
|
||||
enable_file=False
|
||||
) -> list | str | None:
|
||||
if isinstance(content, dict):
|
||||
content = FileObject(
|
||||
type=content.get("type"),
|
||||
@@ -637,7 +642,7 @@ class BaseNode(ABC):
|
||||
if content.content_cache.get(provider):
|
||||
return content.content_cache[provider]
|
||||
with get_db_read() as db:
|
||||
multimodel_service = MultimodalService(db, provider)
|
||||
multimodel_service = MultimodalService(db, provider, is_omni=is_omni)
|
||||
message = await multimodel_service.process_files(
|
||||
[FileInput.model_construct(
|
||||
type=content.type,
|
||||
@@ -647,7 +652,6 @@ class BaseNode(ABC):
|
||||
upload_file_id=content.file_id
|
||||
)]
|
||||
)
|
||||
|
||||
if message:
|
||||
content.content_cache[provider] = message
|
||||
return message
|
||||
@@ -669,11 +673,11 @@ class BaseNode(ABC):
|
||||
return content
|
||||
return result
|
||||
|
||||
def model_balance(self, model_config: ModelConfig) -> ModelApiKey:
|
||||
@staticmethod
|
||||
def model_balance(model_config: ModelConfig) -> ModelApiKey:
|
||||
api_keys = [key for key in model_config.api_keys if key.is_active]
|
||||
if not api_keys:
|
||||
raise ValueError("No active API keys available for model")
|
||||
if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN:
|
||||
if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN:
|
||||
return min(api_keys, key=lambda x: (int(x.usage_count or "0"), x.last_used_at or datetime.min))
|
||||
return min(api_keys, key=lambda x: (int(x.usage_count or "0"), x.last_used_at or datetime.min))
|
||||
return api_keys[0]
|
||||
|
||||
@@ -153,30 +153,30 @@ class LLMNode(BaseNode):
|
||||
if role == "system":
|
||||
messages.append({
|
||||
"role": "system",
|
||||
"content": await self.process_message(provider, content, self.typed_config.vision)
|
||||
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
|
||||
})
|
||||
elif role in ["user", "human"]:
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": await self.process_message(provider, content, self.typed_config.vision)
|
||||
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
|
||||
})
|
||||
elif role in ["ai", "assistant"]:
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": await self.process_message(provider, content, self.typed_config.vision)
|
||||
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
|
||||
})
|
||||
else:
|
||||
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": await self.process_message(provider, content, self.typed_config.vision)
|
||||
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
|
||||
})
|
||||
|
||||
if self.typed_config.vision_input and self.typed_config.vision:
|
||||
file_content = []
|
||||
files = variable_pool.get_instance(self.typed_config.vision_input)
|
||||
for file in files.value:
|
||||
content = await self.process_message(provider, file.value, self.typed_config.vision)
|
||||
content = await self.process_message(provider, is_omni, file.value, self.typed_config.vision)
|
||||
if content:
|
||||
file_content.extend(content)
|
||||
if messages and messages[-1]["role"] == 'user':
|
||||
@@ -190,14 +190,14 @@ class LLMNode(BaseNode):
|
||||
if isinstance(message["content"], list):
|
||||
file_content = []
|
||||
for file in message["content"]:
|
||||
content = await self.process_message(provider, file, self.typed_config.vision)
|
||||
content = await self.process_message(provider, is_omni, file, self.typed_config.vision)
|
||||
if content:
|
||||
file_content.extend(content)
|
||||
history_message.append(
|
||||
{"role": message["role"], "content": file_content}
|
||||
)
|
||||
else:
|
||||
message["content"] = await self.process_message(provider, message["content"], self.typed_config.vision)
|
||||
message["content"] = await self.process_message(provider, is_omni, message["content"], self.typed_config.vision)
|
||||
history_message.append(message)
|
||||
messages = messages[:-1] + history_message + messages[-1:]
|
||||
self.messages = messages
|
||||
|
||||
Reference in New Issue
Block a user