[add] bedrock claude support
This commit is contained in:
@@ -248,28 +248,48 @@ class LangChainAgent:
|
|||||||
if context:
|
if context:
|
||||||
user_content = f"参考信息:\n{context}\n\n用户问题:\n{user_content}"
|
user_content = f"参考信息:\n{context}\n\n用户问题:\n{user_content}"
|
||||||
|
|
||||||
# 如果有文件,构建多模态消息(使用通义千问原生格式)
|
# 构建用户消息(支持多模态)
|
||||||
if files and len(files) > 0:
|
if files and len(files) > 0:
|
||||||
# 通义千问多模态格式: [{"text": "..."}, {"image": "url"}]
|
content_parts = self._build_multimodal_content(user_content, files)
|
||||||
# 注意:不使用 LangChain 的标准格式,因为它会转换为 OpenAI 格式
|
|
||||||
content_parts = [{"text": user_content}]
|
|
||||||
|
|
||||||
# 添加文件内容(已经是通义千问格式)
|
|
||||||
for file_item in files:
|
|
||||||
if file_item.get("type") == "image":
|
|
||||||
# 通义千问图片格式: {"image": "url"}
|
|
||||||
content_parts.append({"image": file_item["image"]})
|
|
||||||
elif file_item.get("type") == "text":
|
|
||||||
# 文本内容
|
|
||||||
content_parts.append({"text": file_item["text"]})
|
|
||||||
|
|
||||||
logger.debug(f"构建多模态消息,content_parts: {content_parts}")
|
|
||||||
messages.append(HumanMessage(content=content_parts))
|
messages.append(HumanMessage(content=content_parts))
|
||||||
else:
|
else:
|
||||||
# 纯文本消息(向后兼容)
|
# 纯文本消息
|
||||||
messages.append(HumanMessage(content=user_content))
|
messages.append(HumanMessage(content=user_content))
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
def _build_multimodal_content(self, text: str, files: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
构建多模态消息内容
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 文本内容
|
||||||
|
files: 文件列表(已由 MultimodalService 处理为对应 provider 的格式)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Dict]: 消息内容列表
|
||||||
|
"""
|
||||||
|
# 根据 provider 使用不同的文本格式
|
||||||
|
if self.provider.lower() in ["bedrock", "anthropic"]:
|
||||||
|
# Anthropic/Bedrock: {"type": "text", "text": "..."}
|
||||||
|
content_parts = [{"type": "text", "text": text}]
|
||||||
|
else:
|
||||||
|
# 通义千问等: {"text": "..."}
|
||||||
|
content_parts = [{"text": text}]
|
||||||
|
|
||||||
|
# 添加文件内容
|
||||||
|
# MultimodalService 已经根据 provider 返回了正确格式,直接使用
|
||||||
|
content_parts.extend(files)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"构建多模态消息: provider={self.provider}, "
|
||||||
|
f"parts={len(content_parts)}, "
|
||||||
|
f"files={len(files)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return content_parts
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
async def term_memory_save(self,long_term_messages,actual_config_id,end_user_id,type):
|
async def term_memory_save(self,long_term_messages,actual_config_id,end_user_id,type):
|
||||||
db = next(get_db())
|
db = next(get_db())
|
||||||
|
|||||||
@@ -81,6 +81,8 @@ class RedBearModelFactory:
|
|||||||
# api_key 格式: "access_key_id:secret_access_key" 或只是 access_key_id
|
# api_key 格式: "access_key_id:secret_access_key" 或只是 access_key_id
|
||||||
# region 从 base_url 或 extra_params 获取
|
# region 从 base_url 或 extra_params 获取
|
||||||
from botocore.config import Config as BotoConfig
|
from botocore.config import Config as BotoConfig
|
||||||
|
from app.core.models.bedrock_model_mapper import normalize_bedrock_model_id
|
||||||
|
|
||||||
max_pool_connections = int(os.getenv("BEDROCK_MAX_POOL_CONNECTIONS", "50"))
|
max_pool_connections = int(os.getenv("BEDROCK_MAX_POOL_CONNECTIONS", "50"))
|
||||||
max_retries = int(os.getenv("BEDROCK_MAX_RETRIES", "2"))
|
max_retries = int(os.getenv("BEDROCK_MAX_RETRIES", "2"))
|
||||||
# Configure with increased connection pool
|
# Configure with increased connection pool
|
||||||
@@ -89,8 +91,11 @@ class RedBearModelFactory:
|
|||||||
retries={'max_attempts': max_retries, 'mode': 'adaptive'}
|
retries={'max_attempts': max_retries, 'mode': 'adaptive'}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 标准化模型 ID(自动转换简化名称为完整 Bedrock Model ID)
|
||||||
|
model_id = normalize_bedrock_model_id(config.model_name)
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"model_id": config.model_name,
|
"model_id": model_id,
|
||||||
"config": boto_config,
|
"config": boto_config,
|
||||||
**config.extra_params
|
**config.extra_params
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -413,9 +413,11 @@ class DraftRunService:
|
|||||||
# 6. 处理多模态文件
|
# 6. 处理多模态文件
|
||||||
processed_files = None
|
processed_files = None
|
||||||
if files:
|
if files:
|
||||||
multimodal_service = MultimodalService(self.db)
|
# 获取 provider 信息
|
||||||
|
provider = api_key_config.get("provider", "openai")
|
||||||
|
multimodal_service = MultimodalService(self.db, provider=provider)
|
||||||
processed_files = await multimodal_service.process_files(files)
|
processed_files = await multimodal_service.process_files(files)
|
||||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
|
||||||
|
|
||||||
# 7. 知识库检索
|
# 7. 知识库检索
|
||||||
context = None
|
context = None
|
||||||
@@ -659,9 +661,11 @@ class DraftRunService:
|
|||||||
# 6. 处理多模态文件
|
# 6. 处理多模态文件
|
||||||
processed_files = None
|
processed_files = None
|
||||||
if files:
|
if files:
|
||||||
multimodal_service = MultimodalService(self.db)
|
# 获取 provider 信息
|
||||||
|
provider = api_key_config.get("provider", "openai")
|
||||||
|
multimodal_service = MultimodalService(self.db, provider=provider)
|
||||||
processed_files = await multimodal_service.process_files(files)
|
processed_files = await multimodal_service.process_files(files)
|
||||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
|
||||||
|
|
||||||
# 7. 知识库检索
|
# 7. 知识库检索
|
||||||
context = None
|
context = None
|
||||||
|
|||||||
@@ -3,12 +3,13 @@
|
|||||||
|
|
||||||
处理图片、文档等多模态文件,转换为 LLM 可用的格式
|
处理图片、文档等多模态文件,转换为 LLM 可用的格式
|
||||||
|
|
||||||
格式说明:
|
支持的 Provider:
|
||||||
- 当前使用通义千问格式
|
- DashScope (通义千问): 支持 URL 格式
|
||||||
- 通义千问格式: {"type": "image", "image": "url"}
|
- Bedrock/Anthropic: 仅支持 base64 格式
|
||||||
|
- OpenAI: 支持 URL 和 base64 格式
|
||||||
"""
|
"""
|
||||||
import uuid
|
import uuid
|
||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any, Optional, Protocol
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
@@ -20,11 +21,105 @@ from app.models.generic_file_model import GenericFile
|
|||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class ImageFormatStrategy(Protocol):
|
||||||
|
"""图片格式策略接口"""
|
||||||
|
|
||||||
|
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||||
|
"""将图片 URL 转换为特定 provider 的格式"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class DashScopeImageStrategy:
|
||||||
|
"""通义千问图片格式策略"""
|
||||||
|
|
||||||
|
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||||
|
"""通义千问格式: {"type": "image", "image": "url"}"""
|
||||||
|
return {
|
||||||
|
"type": "image",
|
||||||
|
"image": url
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockImageStrategy:
|
||||||
|
"""Bedrock/Anthropic 图片格式策略"""
|
||||||
|
|
||||||
|
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Bedrock/Anthropic 格式: base64 编码
|
||||||
|
{"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}}
|
||||||
|
"""
|
||||||
|
import httpx
|
||||||
|
import base64
|
||||||
|
from mimetypes import guess_type
|
||||||
|
|
||||||
|
logger.info(f"下载并编码图片: {url}")
|
||||||
|
|
||||||
|
# 下载图片
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
response = await client.get(url)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
# 获取图片数据
|
||||||
|
image_data = response.content
|
||||||
|
|
||||||
|
# 确定 media type
|
||||||
|
content_type = response.headers.get("content-type")
|
||||||
|
if content_type and content_type.startswith("image/"):
|
||||||
|
media_type = content_type
|
||||||
|
else:
|
||||||
|
guessed_type, _ = guess_type(url)
|
||||||
|
media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg"
|
||||||
|
|
||||||
|
# 转换为 base64
|
||||||
|
base64_data = base64.b64encode(image_data).decode("utf-8")
|
||||||
|
|
||||||
|
logger.info(f"图片编码完成: media_type={media_type}, size={len(base64_data)}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"type": "image",
|
||||||
|
"source": {
|
||||||
|
"type": "base64",
|
||||||
|
"media_type": media_type,
|
||||||
|
"data": base64_data
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIImageStrategy:
|
||||||
|
"""OpenAI 图片格式策略"""
|
||||||
|
|
||||||
|
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||||
|
"""OpenAI 格式: {"type": "image_url", "image_url": {"url": "..."}}"""
|
||||||
|
return {
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": url
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Provider 到策略的映射
|
||||||
|
PROVIDER_STRATEGIES = {
|
||||||
|
"dashscope": DashScopeImageStrategy,
|
||||||
|
"bedrock": BedrockImageStrategy,
|
||||||
|
"anthropic": BedrockImageStrategy,
|
||||||
|
"openai": OpenAIImageStrategy,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class MultimodalService:
|
class MultimodalService:
|
||||||
"""多模态文件处理服务"""
|
"""多模态文件处理服务"""
|
||||||
|
|
||||||
def __init__(self, db: Session):
|
def __init__(self, db: Session, provider: str = "dashscope"):
|
||||||
|
"""
|
||||||
|
初始化多模态服务
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库会话
|
||||||
|
provider: 模型提供商(dashscope, bedrock, anthropic 等)
|
||||||
|
"""
|
||||||
self.db = db
|
self.db = db
|
||||||
|
self.provider = provider.lower()
|
||||||
|
|
||||||
async def process_files(
|
async def process_files(
|
||||||
self,
|
self,
|
||||||
@@ -37,7 +132,7 @@ class MultimodalService:
|
|||||||
files: 文件输入列表
|
files: 文件输入列表
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Dict]: LLM 可用的内容格式列表
|
List[Dict]: LLM 可用的内容格式列表(根据 provider 返回不同格式)
|
||||||
"""
|
"""
|
||||||
if not files:
|
if not files:
|
||||||
return []
|
return []
|
||||||
@@ -74,7 +169,7 @@ class MultimodalService:
|
|||||||
"text": f"[文件处理失败: {str(e)}]"
|
"text": f"[文件处理失败: {str(e)}]"
|
||||||
})
|
})
|
||||||
|
|
||||||
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件")
|
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def _process_image(self, file: FileInput) -> Dict[str, Any]:
|
async def _process_image(self, file: FileInput) -> Dict[str, Any]:
|
||||||
@@ -85,24 +180,88 @@ class MultimodalService:
|
|||||||
file: 图片文件输入
|
file: 图片文件输入
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict: 通义千问格式 {"type": "image", "image": "url"}
|
Dict: 根据 provider 返回不同格式
|
||||||
|
- Anthropic/Bedrock: {"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}}
|
||||||
|
- 通义千问: {"type": "image", "image": "url"}
|
||||||
"""
|
"""
|
||||||
if file.transfer_method == TransferMethod.REMOTE_URL:
|
if file.transfer_method == TransferMethod.REMOTE_URL:
|
||||||
# 远程 URL,使用通义千问格式
|
url = file.url
|
||||||
logger.debug(f"处理远程图片: {file.url}")
|
|
||||||
return {
|
|
||||||
"type": "image",
|
|
||||||
"image": file.url
|
|
||||||
}
|
|
||||||
else:
|
else:
|
||||||
# 本地文件,获取访问 URL
|
# 本地文件,获取访问 URL
|
||||||
url = await self._get_file_url(file.upload_file_id)
|
url = await self._get_file_url(file.upload_file_id)
|
||||||
logger.debug(f"处理本地图片: {url}")
|
|
||||||
|
logger.debug(f"处理图片: {url}, provider={self.provider}")
|
||||||
|
|
||||||
|
# 根据 provider 返回不同格式
|
||||||
|
if self.provider in ["bedrock", "anthropic"]:
|
||||||
|
# Anthropic/Bedrock 只支持 base64 格式,需要下载并转换
|
||||||
|
try:
|
||||||
|
logger.info(f"开始下载并编码图片: {url}")
|
||||||
|
base64_data, media_type = await self._download_and_encode_image(url)
|
||||||
|
result = {
|
||||||
|
"type": "image",
|
||||||
|
"source": {
|
||||||
|
"type": "base64",
|
||||||
|
"media_type": media_type,
|
||||||
|
"data": base64_data[:100] + "..." # 只记录前100个字符
|
||||||
|
}
|
||||||
|
}
|
||||||
|
logger.info(f"图片编码完成: media_type={media_type}, data_length={len(base64_data)}")
|
||||||
|
# 返回完整数据
|
||||||
|
result["source"]["data"] = base64_data
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"下载并编码图片失败: {e}", exc_info=True)
|
||||||
|
# 返回错误提示
|
||||||
|
return {
|
||||||
|
"type": "text",
|
||||||
|
"text": f"[图片加载失败: {str(e)}]"
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# 通义千问等其他格式支持 URL
|
||||||
return {
|
return {
|
||||||
"type": "image",
|
"type": "image",
|
||||||
"image": url
|
"image": url
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async def _download_and_encode_image(self, url: str) -> tuple[str, str]:
|
||||||
|
"""
|
||||||
|
下载图片并转换为 base64
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: 图片 URL
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (base64_data, media_type)
|
||||||
|
"""
|
||||||
|
import httpx
|
||||||
|
import base64
|
||||||
|
from mimetypes import guess_type
|
||||||
|
|
||||||
|
# 下载图片
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
response = await client.get(url)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
# 获取图片数据
|
||||||
|
image_data = response.content
|
||||||
|
|
||||||
|
# 确定 media type
|
||||||
|
content_type = response.headers.get("content-type")
|
||||||
|
if content_type and content_type.startswith("image/"):
|
||||||
|
media_type = content_type
|
||||||
|
else:
|
||||||
|
# 从 URL 推断
|
||||||
|
guessed_type, _ = guess_type(url)
|
||||||
|
media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg"
|
||||||
|
|
||||||
|
# 转换为 base64
|
||||||
|
base64_data = base64.b64encode(image_data).decode("utf-8")
|
||||||
|
|
||||||
|
logger.debug(f"图片编码完成: media_type={media_type}, size={len(base64_data)}")
|
||||||
|
|
||||||
|
return base64_data, media_type
|
||||||
|
|
||||||
async def _process_document(self, file: FileInput) -> Dict[str, Any]:
|
async def _process_document(self, file: FileInput) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
处理文档文件(PDF、Word 等)
|
处理文档文件(PDF、Word 等)
|
||||||
|
|||||||
Reference in New Issue
Block a user