Files
MemoryBear/api/app/services/multimodal_service.py

609 lines
21 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
多模态文件处理服务
处理图片、文档等多模态文件,转换为 LLM 可用的格式
支持的 Provider:
- DashScope (通义千问): 支持 URL 格式
- Bedrock/Anthropic: 仅支持 base64 格式
- OpenAI: 支持 URL 和 base64 格式
"""
import base64
import io
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
import PyPDF2
import httpx
import magic
from docx import Document
from sqlalchemy.orm import Session
from app.core.config import settings
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.models.file_metadata_model import FileMetadata
from app.schemas.app_schema import FileInput, FileType, TransferMethod
from app.services.audio_transcription_service import AudioTranscriptionService
logger = get_business_logger()
TEXT_MIME = ['text/plain', 'text/x-markdown']
PDF_MIME = ['application/pdf']
DOC_MIME = [
'application/msword',
'application/vnd.openxmlformats-officedocument.wordprocessingml.document'
]
class MultimodalFormatStrategy(ABC):
"""多模态格式策略基类"""
def __init__(self, file: FileInput):
self.file = file
@abstractmethod
async def format_image(self, url: str, content: bytes | None = None) -> Dict[str, Any]:
"""格式化图片"""
pass
@abstractmethod
async def format_document(self, file_name: str, text: str) -> Dict[str, Any]:
"""格式化文档"""
pass
@abstractmethod
async def format_audio(self, file_type: str, url: str, content: bytes | None = None) -> Dict[str, Any]:
"""格式化音频"""
pass
@abstractmethod
async def format_video(self, url: str) -> Dict[str, Any]:
"""格式化视频"""
pass
class DashScopeFormatStrategy(MultimodalFormatStrategy):
"""通义千问策略"""
async def format_image(self, url: str, content: bytes | None = None) -> Dict[str, Any]:
"""通义千问图片格式:{"type": "image", "image": "url"}"""
return {
"type": "image",
"image": url
}
async def format_document(self, file_name: str, text: str) -> Dict[str, Any]:
"""通义千问文档格式"""
return {
"type": "text",
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
}
async def format_audio(
self,
file_type: str,
url: str,
content: bytes | None = None,
transcription: Optional[str] = None
) -> Dict[str, Any]:
"""
通义千问音频格式
- 原生支持: qwen-audio 系列
- 其他模型: 需要转录为文本
"""
if transcription:
return {
"type": "text",
"text": f"<audio url=\"{url}\">\n{transcription}\n</audio>"
}
# 通义千问音频格式:{"type": "audio", "audio": "url"}
return {
"type": "audio",
"audio": url
}
async def format_video(self, url: str) -> Dict[str, Any]:
"""通义千问视频格式qwen-vl 系列原生支持)"""
return {
"type": "video",
"video": url
}
class BedrockFormatStrategy(MultimodalFormatStrategy):
"""Bedrock/Anthropic 策略"""
async def format_image(self, url: str, content: bytes | None = None) -> Dict[str, Any]:
"""
Bedrock/Anthropic 格式: base64 编码
{"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}}
"""
logger.info(f"下载并编码图片: {url}")
# 下载图片
if content is None:
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(url)
response.raise_for_status()
content = response.content
self.file.set_content(content)
# 确定 media type
content_type = magic.from_buffer(content, mime=True)
media_type = content_type if content_type.startswith("image/") else "image/jpeg"
base64_data = base64.b64encode(content).decode("utf-8")
logger.info(f"图片编码完成: media_type={media_type}, size={len(base64_data)}")
return {
"type": "image",
"source": {
"type": "base64",
"media_type": media_type,
"data": base64_data
}
}
async def format_document(self, file_name: str, text: str) -> Dict[str, Any]:
"""Bedrock/Anthropic 文档格式(需要 base64 编码)"""
# Bedrock 文档需要 base64 编码
text_bytes = text.encode('utf-8')
base64_text = base64.b64encode(text_bytes).decode('utf-8')
return {
"type": "document",
"source": {
"type": "base64",
"media_type": "text/plain",
"data": base64_text
}
}
async def format_audio(
self, file_type: str,
url: str,
content: bytes | None = None,
transcription: Optional[str] = None
) -> Dict[str, Any]:
"""
Bedrock/Anthropic 音频格式
不支持原生音频,必须转录为文本
"""
if transcription:
return {
"type": "text",
"text": f"[音频转录]\n{transcription}"
}
return {
"type": "text",
"text": "[音频文件Bedrock 不支持原生音频,请启用音频转文本功能]"
}
async def format_video(self, url: str) -> Dict[str, Any]:
"""Bedrock/Anthropic 视频格式"""
return {
"type": "text",
"text": f"<video url=\"{url}\">\n[视频文件,当前 provider 暂不支持]\n</video>"
}
class OpenAIFormatStrategy(MultimodalFormatStrategy):
"""OpenAI 策略"""
async def format_image(self, url: str, content: bytes | None = None) -> Dict[str, Any]:
"""OpenAI 格式: {"type": "image_url", "image_url": {"url": "..."}}"""
return {
"type": "image_url",
"image_url": {
"url": url
}
}
async def format_document(self, file_name: str, text: str) -> Dict[str, Any]:
"""OpenAI 文档格式"""
return {
"type": "text",
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
}
async def format_audio(
self,
file_type: str,
url: str,
content: bytes | None = None,
transcription: Optional[str] = None
) -> Dict[str, Any]:
"""
OpenAI 音频格式
- gpt-4o-audio 系列支持原生音频(需要 base64 编码)
- 其他模型使用转录文本
"""
if transcription:
return {
"type": "text",
"text": f"<audio url=\"{url}\">\n{transcription}\n</audio>"
}
# OpenAI 音频需要 base64 编码
try:
audio_data = content
if content is None:
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(url)
response.raise_for_status()
audio_data = response.content
self.file.set_content(audio_data)
base64_audio = base64.b64encode(audio_data).decode('utf-8')
# 1. 优先从 file_type (MIME) 取扩展名
file_ext = file_type.split('/')[-1] if file_type and '/' in file_type else None
# 2. 从响应头 content-type 取
if not file_ext:
content_type = magic.from_buffer(audio_data, mime=True)
file_ext = content_type.split('/')[-1].split(';')[0].strip() if '/' in content_type else None
# 3. 从 URL 路径取扩展名
if not file_ext:
file_ext = url.split('?')[0].rsplit('.', 1)[-1].lower() or None
# 4. 默认 wav
# supported_ext = {"wav", "mp3", "mp4", "ogg", "flac", "webm", "m4a", "wave", "x-m4a"}
file_ext = "wav" if not file_ext else file_ext
return {
"type": "input_audio",
"input_audio": {
"data": f"data:;base64,{base64_audio}",
"format": file_ext
}
}
except Exception as e:
logger.error(f"下载音频失败: {e}")
return {
"type": "text",
"text": f"[音频处理失败: {str(e)}]"
}
async def format_video(self, url: str) -> Dict[str, Any]:
"""OpenAI 视频格式"""
return {
"type": "video_url",
"video_url": {
"url": url
}
}
# Provider 到策略的映射
PROVIDER_STRATEGIES = {
"dashscope": DashScopeFormatStrategy,
"bedrock": BedrockFormatStrategy,
"anthropic": BedrockFormatStrategy,
"openai": OpenAIFormatStrategy,
}
class MultimodalService:
"""多模态文件处理服务"""
def __init__(self, db: Session, provider: str = "dashscope", api_key: Optional[str] = None,
enable_audio_transcription: bool = False, is_omni: bool = False):
"""
初始化多模态服务
Args:
db: 数据库会话
provider: 模型提供商dashscope, bedrock, anthropic, openai 等)
api_key: API 密钥(用于音频转文本)
enable_audio_transcription: 是否启用音频转文本
is_omni: 是否为 Omni 模型dashscope 的 omni 模型需要使用 OpenAI 兼容格式)
"""
self.db = db
self.provider = provider.lower()
self.api_key = api_key
self.enable_audio_transcription = enable_audio_transcription
self.is_omni = is_omni
async def process_files(
self,
files: Optional[List[FileInput]]
) -> List[Dict[str, Any]]:
"""
处理文件列表,返回 LLM 可用的格式
Args:
files: 文件输入列表
Returns:
List[Dict]: LLM 可用的内容格式列表(根据 provider 返回不同格式)
"""
if not files:
return []
# 获取对应的策略
# dashscope 的 omni 模型使用 OpenAI 兼容格式
if self.provider == "dashscope" and self.is_omni:
strategy_class = OpenAIFormatStrategy
else:
strategy_class = PROVIDER_STRATEGIES.get(self.provider)
if not strategy_class:
logger.warning(f"未找到 provider '{self.provider}' 的策略,使用默认策略")
strategy_class = DashScopeFormatStrategy
result = []
for idx, file in enumerate(files):
strategy = strategy_class(file)
try:
if file.type == FileType.IMAGE:
content = await self._process_image(file, strategy)
result.append(content)
elif file.type == FileType.DOCUMENT:
content = await self._process_document(file, strategy)
result.append(content)
elif file.type == FileType.AUDIO:
content = await self._process_audio(file, strategy)
result.append(content)
elif file.type == FileType.VIDEO:
content = await self._process_video(file, strategy)
result.append(content)
else:
logger.warning(f"不支持的文件类型: {file.type}")
except Exception as e:
logger.error(
f"处理文件失败",
extra={
"file_index": idx,
"file_type": file.type,
"error": str(e)
}
)
# 继续处理其他文件,不中断整个流程
result.append({
"type": "text",
"text": f"[文件处理失败: {str(e)}]"
})
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件provider={self.provider}")
return result
async def _process_image(self, file: FileInput, strategy) -> Dict[str, Any]:
"""
处理图片文件
Args:
file: 图片文件输入
strategy: 格式化策略
Returns:
Dict: 根据 provider 返回不同格式的图片内容
"""
try:
url = await self.get_file_url(file)
return await strategy.format_image(url, content=file.get_content())
except Exception as e:
logger.error(f"处理图片失败: {e}", exc_info=True)
return {
"type": "text",
"text": f"[图片处理失败: {str(e)}]"
}
@staticmethod
async def _download_and_encode_image(url: str) -> tuple[str, str]:
"""
下载图片并转换为 base64
Args:
url: 图片 URL
Returns:
tuple: (base64_data, media_type)
"""
from mimetypes import guess_type
# 下载图片
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(url)
response.raise_for_status()
# 获取图片数据
image_data = response.content
# 确定 media type
content_type = response.headers.get("content-type")
if content_type and content_type.startswith("image/"):
media_type = content_type
else:
# 从 URL 推断
guessed_type, _ = guess_type(url)
media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg"
# 转换为 base64
base64_data = base64.b64encode(image_data).decode("utf-8")
logger.debug(f"图片编码完成: media_type={media_type}, size={len(base64_data)}")
return base64_data, media_type
async def _process_document(self, file: FileInput, strategy) -> Dict[str, Any]:
"""
处理文档文件PDF、Word 等)
Args:
file: 文档文件输入
strategy: 格式化策略
Returns:
Dict: 根据 provider 返回不同格式的文档内容
"""
if file.transfer_method == TransferMethod.REMOTE_URL:
# 远程文档暂不支持提取
return {
"type": "text",
"text": f"<document url=\"{file.url}\">\n{await self._extract_document_text(file)}\n</document>"
}
else:
# 本地文件,提取文本内容
server_url = settings.FILE_LOCAL_SERVER_URL
file.url = f"{server_url}/storage/permanent/{file.upload_file_id}"
text = await self._extract_document_text(file)
file_metadata = self.db.query(FileMetadata).filter(
FileMetadata.id == file.upload_file_id
).first()
file_name = file_metadata.file_name if file_metadata else "unknown"
# 使用策略格式化文档
return await strategy.format_document(file_name, text)
async def _process_audio(self, file: FileInput, strategy) -> Dict[str, Any]:
"""
处理音频文件
Args:
file: 音频文件输入
strategy: 格式化策略
Returns:
Dict: 根据 provider 返回不同格式的音频内容
"""
try:
url = await self.get_file_url(file)
# 如果启用音频转文本且有 API Key
transcription = None
if self.enable_audio_transcription and self.api_key:
logger.info(f"开始音频转文本: {url}")
if self.provider == "dashscope":
transcription = await AudioTranscriptionService.transcribe_dashscope(url, self.api_key)
elif self.provider == "openai":
transcription = await AudioTranscriptionService.transcribe_openai(url, self.api_key)
else:
logger.warning(f"Provider {self.provider} 不支持音频转文本")
return await strategy.format_audio(file.file_type, url, file.get_content(), transcription)
except Exception as e:
logger.error(f"处理音频失败: {e}", exc_info=True)
return {
"type": "text",
"text": f"[音频处理失败: {str(e)}]"
}
async def _process_video(self, file: FileInput, strategy) -> Dict[str, Any]:
"""
处理视频文件
Args:
file: 视频文件输入
strategy: 格式化策略
Returns:
Dict: 根据 provider 返回不同格式的视频内容
"""
try:
url = await self.get_file_url(file)
return await strategy.format_video(url)
except Exception as e:
logger.error(f"处理视频失败: {e}", exc_info=True)
return {
"type": "text",
"text": f"[视频处理失败: {str(e)}]"
}
async def get_file_url(self, file: FileInput) -> str:
"""
获取文件的访问 URL
Args:
file: File Input Struct
Returns:
str: 文件访问 URL永久URL
Raises:
BusinessException: 文件不存在
"""
if file.transfer_method == TransferMethod.REMOTE_URL:
return file.url
else:
file_id = file.upload_file_id
# 查询 FileMetadata
file_metadata = self.db.query(FileMetadata).filter(
FileMetadata.id == file_id,
FileMetadata.status == "completed"
).first()
if not file_metadata:
raise BusinessException(
f"文件不存在或已删除: {file_id}",
BizCode.NOT_FOUND
)
# 返回永久URL
server_url = settings.FILE_LOCAL_SERVER_URL
return f"{server_url}/storage/permanent/{file_id}"
async def _extract_document_text(self, file: FileInput) -> str:
"""
提取文档文本内容
Args:
file: 文件输入
Returns:
str: 提取的文本内容
"""
try:
file_content = file.get_content()
if not file_content:
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(file.url)
response.raise_for_status()
file_content = response.content
file.set_content(file_content)
file_mime_type = magic.from_buffer(file_content, mime=True)
if file_mime_type in TEXT_MIME:
return file_content.decode("utf-8")
elif file_mime_type in PDF_MIME:
return await self._extract_pdf_text(file_content)
elif file_mime_type in DOC_MIME:
return await self._extract_word_text(file_content)
else:
return f"[Unsupported file type: {file_mime_type}]"
except Exception as e:
logger.error(f"Failed to load file. - {e}")
return "[Failed to load file.]"
@staticmethod
async def _extract_pdf_text(file_content: bytes) -> str:
"""提取 PDF 文本"""
try:
# 使用 BytesIO 读取 PDF
text_parts = []
pdf_file = io.BytesIO(file_content)
pdf_reader = PyPDF2.PdfReader(pdf_file)
for page in pdf_reader.pages:
text_parts.append(page.extract_text())
return '\n'.join(text_parts)
except Exception as e:
logger.error(f"提取 PDF 文本失败: {e}")
return f"[PDF 提取失败: {str(e)}]"
@staticmethod
async def _extract_word_text(file_content: bytes) -> str:
"""提取 Word 文档文本"""
try:
# 使用 BytesIO 读取 Word 文档
word_file = io.BytesIO(file_content)
doc = Document(word_file)
text_parts = [paragraph.text for paragraph in doc.paragraphs]
return '\n'.join(text_parts)
except Exception as e:
logger.error(f"提取 Word 文本失败: {e}")
return f"[Word 提取失败: {str(e)}]"
def get_multimodal_service(db: Session) -> MultimodalService:
"""获取多模态服务实例(依赖注入)"""
return MultimodalService(db)