diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index b84011d3..496454ba 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -618,7 +618,12 @@ class BaseNode(ABC): return variable_pool.has(selector) @staticmethod - async def process_message(provider: str, content: str | dict | FileObject, enable_file=False) -> list | str | None: + async def process_message( + provider: str, + is_omni: bool, + content: str | dict | FileObject, + enable_file=False + ) -> list | str | None: if isinstance(content, dict): content = FileObject( type=content.get("type"), @@ -637,7 +642,7 @@ class BaseNode(ABC): if content.content_cache.get(provider): return content.content_cache[provider] with get_db_read() as db: - multimodel_service = MultimodalService(db, provider) + multimodel_service = MultimodalService(db, provider, is_omni=is_omni) message = await multimodel_service.process_files( [FileInput.model_construct( type=content.type, @@ -647,7 +652,6 @@ class BaseNode(ABC): upload_file_id=content.file_id )] ) - if message: content.content_cache[provider] = message return message @@ -669,11 +673,11 @@ class BaseNode(ABC): return content return result - def model_balance(self, model_config: ModelConfig) -> ModelApiKey: + @staticmethod + def model_balance(model_config: ModelConfig) -> ModelApiKey: api_keys = [key for key in model_config.api_keys if key.is_active] if not api_keys: raise ValueError("No active API keys available for model") if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN: - if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN: - return min(api_keys, key=lambda x: (int(x.usage_count or "0"), x.last_used_at or datetime.min)) + return min(api_keys, key=lambda x: (int(x.usage_count or "0"), x.last_used_at or datetime.min)) return api_keys[0] diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index 92a0dff7..186c204f 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -153,30 +153,30 @@ class LLMNode(BaseNode): if role == "system": messages.append({ "role": "system", - "content": await self.process_message(provider, content, self.typed_config.vision) + "content": await self.process_message(provider, is_omni, content, self.typed_config.vision) }) elif role in ["user", "human"]: messages.append({ "role": "user", - "content": await self.process_message(provider, content, self.typed_config.vision) + "content": await self.process_message(provider, is_omni, content, self.typed_config.vision) }) elif role in ["ai", "assistant"]: messages.append({ "role": "assistant", - "content": await self.process_message(provider, content, self.typed_config.vision) + "content": await self.process_message(provider, is_omni, content, self.typed_config.vision) }) else: logger.warning(f"未知的消息角色: {role},默认使用 user") messages.append({ "role": "user", - "content": await self.process_message(provider, content, self.typed_config.vision) + "content": await self.process_message(provider, is_omni, content, self.typed_config.vision) }) if self.typed_config.vision_input and self.typed_config.vision: file_content = [] files = variable_pool.get_instance(self.typed_config.vision_input) for file in files.value: - content = await self.process_message(provider, file.value, self.typed_config.vision) + content = await self.process_message(provider, is_omni, file.value, self.typed_config.vision) if content: file_content.extend(content) if messages and messages[-1]["role"] == 'user': @@ -190,14 +190,14 @@ class LLMNode(BaseNode): if isinstance(message["content"], list): file_content = [] for file in message["content"]: - content = await self.process_message(provider, file, self.typed_config.vision) + content = await self.process_message(provider, is_omni, file, self.typed_config.vision) if content: file_content.extend(content) history_message.append( {"role": message["role"], "content": file_content} ) else: - message["content"] = await self.process_message(provider, message["content"], self.typed_config.vision) + message["content"] = await self.process_message(provider, is_omni, message["content"], self.typed_config.vision) history_message.append(message) messages = messages[:-1] + history_message + messages[-1:] self.messages = messages