From 3f42ea2c61f195d02ffd5aed6d04e44dee12a5bf Mon Sep 17 00:00:00 2001 From: Mark Date: Tue, 3 Feb 2026 12:05:39 +0800 Subject: [PATCH] [add] bedrock claude support --- api/app/core/agent/langchain_agent.py | 52 ++++--- api/app/core/models/base.py | 7 +- api/app/services/draft_run_service.py | 12 +- api/app/services/multimodal_service.py | 189 +++++++++++++++++++++++-- 4 files changed, 224 insertions(+), 36 deletions(-) diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index 40db9568..019fe4ce 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -248,28 +248,48 @@ class LangChainAgent: if context: user_content = f"参考信息:\n{context}\n\n用户问题:\n{user_content}" - # 如果有文件,构建多模态消息(使用通义千问原生格式) + # 构建用户消息(支持多模态) if files and len(files) > 0: - # 通义千问多模态格式: [{"text": "..."}, {"image": "url"}] - # 注意:不使用 LangChain 的标准格式,因为它会转换为 OpenAI 格式 - content_parts = [{"text": user_content}] - - # 添加文件内容(已经是通义千问格式) - for file_item in files: - if file_item.get("type") == "image": - # 通义千问图片格式: {"image": "url"} - content_parts.append({"image": file_item["image"]}) - elif file_item.get("type") == "text": - # 文本内容 - content_parts.append({"text": file_item["text"]}) - - logger.debug(f"构建多模态消息,content_parts: {content_parts}") + content_parts = self._build_multimodal_content(user_content, files) messages.append(HumanMessage(content=content_parts)) else: - # 纯文本消息(向后兼容) + # 纯文本消息 messages.append(HumanMessage(content=user_content)) return messages + + def _build_multimodal_content(self, text: str, files: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + 构建多模态消息内容 + + Args: + text: 文本内容 + files: 文件列表(已由 MultimodalService 处理为对应 provider 的格式) + + Returns: + List[Dict]: 消息内容列表 + """ + # 根据 provider 使用不同的文本格式 + if self.provider.lower() in ["bedrock", "anthropic"]: + # Anthropic/Bedrock: {"type": "text", "text": "..."} + content_parts = [{"type": "text", "text": text}] + else: + # 通义千问等: {"text": "..."} + content_parts = [{"text": text}] + + # 添加文件内容 + # MultimodalService 已经根据 provider 返回了正确格式,直接使用 + content_parts.extend(files) + + logger.debug( + f"构建多模态消息: provider={self.provider}, " + f"parts={len(content_parts)}, " + f"files={len(files)}" + ) + + return content_parts + + return messages async def term_memory_save(self,long_term_messages,actual_config_id,end_user_id,type): db = next(get_db()) diff --git a/api/app/core/models/base.py b/api/app/core/models/base.py index f92a0cb3..f5f49af0 100644 --- a/api/app/core/models/base.py +++ b/api/app/core/models/base.py @@ -81,6 +81,8 @@ class RedBearModelFactory: # api_key 格式: "access_key_id:secret_access_key" 或只是 access_key_id # region 从 base_url 或 extra_params 获取 from botocore.config import Config as BotoConfig + from app.core.models.bedrock_model_mapper import normalize_bedrock_model_id + max_pool_connections = int(os.getenv("BEDROCK_MAX_POOL_CONNECTIONS", "50")) max_retries = int(os.getenv("BEDROCK_MAX_RETRIES", "2")) # Configure with increased connection pool @@ -89,8 +91,11 @@ class RedBearModelFactory: retries={'max_attempts': max_retries, 'mode': 'adaptive'} ) + # 标准化模型 ID(自动转换简化名称为完整 Bedrock Model ID) + model_id = normalize_bedrock_model_id(config.model_name) + params = { - "model_id": config.model_name, + "model_id": model_id, "config": boto_config, **config.extra_params } diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index 1f0f459d..17f9db85 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -413,9 +413,11 @@ class DraftRunService: # 6. 处理多模态文件 processed_files = None if files: - multimodal_service = MultimodalService(self.db) + # 获取 provider 信息 + provider = api_key_config.get("provider", "openai") + multimodal_service = MultimodalService(self.db, provider=provider) processed_files = await multimodal_service.process_files(files) - logger.info(f"处理了 {len(processed_files)} 个文件") + logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}") # 7. 知识库检索 context = None @@ -659,9 +661,11 @@ class DraftRunService: # 6. 处理多模态文件 processed_files = None if files: - multimodal_service = MultimodalService(self.db) + # 获取 provider 信息 + provider = api_key_config.get("provider", "openai") + multimodal_service = MultimodalService(self.db, provider=provider) processed_files = await multimodal_service.process_files(files) - logger.info(f"处理了 {len(processed_files)} 个文件") + logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}") # 7. 知识库检索 context = None diff --git a/api/app/services/multimodal_service.py b/api/app/services/multimodal_service.py index 81735ef4..a460a7ba 100644 --- a/api/app/services/multimodal_service.py +++ b/api/app/services/multimodal_service.py @@ -3,12 +3,13 @@ 处理图片、文档等多模态文件,转换为 LLM 可用的格式 -格式说明: -- 当前使用通义千问格式 -- 通义千问格式: {"type": "image", "image": "url"} +支持的 Provider: +- DashScope (通义千问): 支持 URL 格式 +- Bedrock/Anthropic: 仅支持 base64 格式 +- OpenAI: 支持 URL 和 base64 格式 """ import uuid -from typing import List, Dict, Any, Optional +from typing import List, Dict, Any, Optional, Protocol from sqlalchemy.orm import Session from app.core.logging_config import get_business_logger @@ -20,11 +21,105 @@ from app.models.generic_file_model import GenericFile logger = get_business_logger() +class ImageFormatStrategy(Protocol): + """图片格式策略接口""" + + async def format_image(self, url: str) -> Dict[str, Any]: + """将图片 URL 转换为特定 provider 的格式""" + ... + + +class DashScopeImageStrategy: + """通义千问图片格式策略""" + + async def format_image(self, url: str) -> Dict[str, Any]: + """通义千问格式: {"type": "image", "image": "url"}""" + return { + "type": "image", + "image": url + } + + +class BedrockImageStrategy: + """Bedrock/Anthropic 图片格式策略""" + + async def format_image(self, url: str) -> Dict[str, Any]: + """ + Bedrock/Anthropic 格式: base64 编码 + {"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}} + """ + import httpx + import base64 + from mimetypes import guess_type + + logger.info(f"下载并编码图片: {url}") + + # 下载图片 + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.get(url) + response.raise_for_status() + + # 获取图片数据 + image_data = response.content + + # 确定 media type + content_type = response.headers.get("content-type") + if content_type and content_type.startswith("image/"): + media_type = content_type + else: + guessed_type, _ = guess_type(url) + media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg" + + # 转换为 base64 + base64_data = base64.b64encode(image_data).decode("utf-8") + + logger.info(f"图片编码完成: media_type={media_type}, size={len(base64_data)}") + + return { + "type": "image", + "source": { + "type": "base64", + "media_type": media_type, + "data": base64_data + } + } + + +class OpenAIImageStrategy: + """OpenAI 图片格式策略""" + + async def format_image(self, url: str) -> Dict[str, Any]: + """OpenAI 格式: {"type": "image_url", "image_url": {"url": "..."}}""" + return { + "type": "image_url", + "image_url": { + "url": url + } + } + + +# Provider 到策略的映射 +PROVIDER_STRATEGIES = { + "dashscope": DashScopeImageStrategy, + "bedrock": BedrockImageStrategy, + "anthropic": BedrockImageStrategy, + "openai": OpenAIImageStrategy, +} + + class MultimodalService: """多模态文件处理服务""" - def __init__(self, db: Session): + def __init__(self, db: Session, provider: str = "dashscope"): + """ + 初始化多模态服务 + + Args: + db: 数据库会话 + provider: 模型提供商(dashscope, bedrock, anthropic 等) + """ self.db = db + self.provider = provider.lower() async def process_files( self, @@ -37,7 +132,7 @@ class MultimodalService: files: 文件输入列表 Returns: - List[Dict]: LLM 可用的内容格式列表 + List[Dict]: LLM 可用的内容格式列表(根据 provider 返回不同格式) """ if not files: return [] @@ -74,7 +169,7 @@ class MultimodalService: "text": f"[文件处理失败: {str(e)}]" }) - logger.info(f"成功处理 {len(result)}/{len(files)} 个文件") + logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}") return result async def _process_image(self, file: FileInput) -> Dict[str, Any]: @@ -85,24 +180,88 @@ class MultimodalService: file: 图片文件输入 Returns: - Dict: 通义千问格式 {"type": "image", "image": "url"} + Dict: 根据 provider 返回不同格式 + - Anthropic/Bedrock: {"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}} + - 通义千问: {"type": "image", "image": "url"} """ if file.transfer_method == TransferMethod.REMOTE_URL: - # 远程 URL,使用通义千问格式 - logger.debug(f"处理远程图片: {file.url}") - return { - "type": "image", - "image": file.url - } + url = file.url else: # 本地文件,获取访问 URL url = await self._get_file_url(file.upload_file_id) - logger.debug(f"处理本地图片: {url}") + + logger.debug(f"处理图片: {url}, provider={self.provider}") + + # 根据 provider 返回不同格式 + if self.provider in ["bedrock", "anthropic"]: + # Anthropic/Bedrock 只支持 base64 格式,需要下载并转换 + try: + logger.info(f"开始下载并编码图片: {url}") + base64_data, media_type = await self._download_and_encode_image(url) + result = { + "type": "image", + "source": { + "type": "base64", + "media_type": media_type, + "data": base64_data[:100] + "..." # 只记录前100个字符 + } + } + logger.info(f"图片编码完成: media_type={media_type}, data_length={len(base64_data)}") + # 返回完整数据 + result["source"]["data"] = base64_data + return result + except Exception as e: + logger.error(f"下载并编码图片失败: {e}", exc_info=True) + # 返回错误提示 + return { + "type": "text", + "text": f"[图片加载失败: {str(e)}]" + } + else: + # 通义千问等其他格式支持 URL return { "type": "image", "image": url } + async def _download_and_encode_image(self, url: str) -> tuple[str, str]: + """ + 下载图片并转换为 base64 + + Args: + url: 图片 URL + + Returns: + tuple: (base64_data, media_type) + """ + import httpx + import base64 + from mimetypes import guess_type + + # 下载图片 + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.get(url) + response.raise_for_status() + + # 获取图片数据 + image_data = response.content + + # 确定 media type + content_type = response.headers.get("content-type") + if content_type and content_type.startswith("image/"): + media_type = content_type + else: + # 从 URL 推断 + guessed_type, _ = guess_type(url) + media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg" + + # 转换为 base64 + base64_data = base64.b64encode(image_data).decode("utf-8") + + logger.debug(f"图片编码完成: media_type={media_type}, size={len(base64_data)}") + + return base64_data, media_type + async def _process_document(self, file: FileInput) -> Dict[str, Any]: """ 处理文档文件(PDF、Word 等)