feat(model and app):
1. Increase support for visual models and multimodal models; 2. The application and workflow can input various multimodal files such as images, documents, audio, and videos.
This commit is contained in:
@@ -157,6 +157,7 @@ class AppChatService:
|
||||
api_key=api_key_obj.api_key,
|
||||
provider=api_key_obj.provider,
|
||||
api_base=api_key_obj.api_base,
|
||||
is_omni=api_key_obj.is_omni,
|
||||
temperature=model_parameters.get("temperature", 0.7),
|
||||
max_tokens=model_parameters.get("max_tokens", 2000),
|
||||
system_prompt=system_prompt,
|
||||
@@ -180,7 +181,7 @@ class AppChatService:
|
||||
# 处理多模态文件
|
||||
processed_files = None
|
||||
if files:
|
||||
multimodal_service = MultimodalService(self.db)
|
||||
multimodal_service = MultimodalService(self.db, api_key_obj.provider, is_omni=api_key_obj.is_omni)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||
|
||||
@@ -343,6 +344,7 @@ class AppChatService:
|
||||
api_key=api_key_obj.api_key,
|
||||
provider=api_key_obj.provider,
|
||||
api_base=api_key_obj.api_base,
|
||||
is_omni=api_key_obj.is_omni,
|
||||
temperature=model_parameters.get("temperature", 0.7),
|
||||
max_tokens=model_parameters.get("max_tokens", 2000),
|
||||
system_prompt=system_prompt,
|
||||
@@ -366,7 +368,7 @@ class AppChatService:
|
||||
# 处理多模态文件
|
||||
processed_files = None
|
||||
if files:
|
||||
multimodal_service = MultimodalService(self.db)
|
||||
multimodal_service = MultimodalService(self.db, api_key_obj.provider, is_omni=api_key_obj.is_omni)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||
|
||||
|
||||
@@ -232,7 +232,7 @@ class AppService:
|
||||
# 检查主 Agent 的模型配置
|
||||
multi_agent_config.default_model_config_id = master_agent_release.default_model_config_id
|
||||
|
||||
model_api_key = ModelApiKeyService.get_a_api_key(self.db, multi_agent_config.default_model_config_id)
|
||||
model_api_key = ModelApiKeyService.get_available_api_key(self.db, multi_agent_config.default_model_config_id)
|
||||
if not model_api_key:
|
||||
raise ResourceNotFoundException("模型配置", str(multi_agent_config.default_model_config_id))
|
||||
|
||||
|
||||
101
api/app/services/audio_transcription_service.py
Normal file
101
api/app/services/audio_transcription_service.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
音频转文本服务
|
||||
|
||||
支持的服务商:
|
||||
- DashScope (阿里云通义千问)
|
||||
- OpenAI Whisper
|
||||
"""
|
||||
import httpx
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class AudioTranscriptionService:
|
||||
"""音频转文本服务"""
|
||||
|
||||
@staticmethod
|
||||
async def transcribe_dashscope(audio_url: str, api_key: str) -> str:
|
||||
"""
|
||||
使用阿里云通义千问语音识别服务转换音频为文本
|
||||
|
||||
Args:
|
||||
audio_url: 音频文件 URL
|
||||
api_key: DashScope API Key
|
||||
|
||||
Returns:
|
||||
str: 转录的文本
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
response = await client.post(
|
||||
"https://dashscope.aliyuncs.com/api/v1/services/audio/asr/transcription",
|
||||
headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"X-DashScope-Async": "enable",
|
||||
},
|
||||
json={
|
||||
"model": "paraformer-v2",
|
||||
"input": {
|
||||
"file_urls": [audio_url]
|
||||
},
|
||||
"parameters": {
|
||||
"language_hints": ["zh", "en", "ja", "yue", "ko", "de", "fr", "ru"]
|
||||
}
|
||||
}
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
if result.get("output", {}).get("results"):
|
||||
text = result["output"]["results"][0].get("transcription_text", "")
|
||||
logger.info(f"音频转文本成功: {len(text)} 字符")
|
||||
return text
|
||||
|
||||
return "[音频转文本失败]"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"DashScope 音频转文本失败: {e}")
|
||||
return f"[音频转文本失败: {str(e)}]"
|
||||
|
||||
@staticmethod
|
||||
async def transcribe_openai(audio_url: str, api_key: str) -> str:
|
||||
"""
|
||||
使用 OpenAI Whisper 转换音频为文本
|
||||
|
||||
Args:
|
||||
audio_url: 音频文件 URL
|
||||
api_key: OpenAI API Key
|
||||
|
||||
Returns:
|
||||
str: 转录的文本
|
||||
"""
|
||||
try:
|
||||
# 下载音频文件
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
audio_response = await client.get(audio_url)
|
||||
audio_response.raise_for_status()
|
||||
audio_data = audio_response.content
|
||||
|
||||
# 调用 Whisper API
|
||||
files = {"file": ("audio.mp3", audio_data, "audio/mpeg")}
|
||||
data = {"model": "whisper-1"}
|
||||
|
||||
response = await client.post(
|
||||
"https://api.openai.com/v1/audio/transcriptions",
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
files=files,
|
||||
data=data
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
text = result.get("text", "")
|
||||
logger.info(f"音频转文本成功: {len(text)} 字符")
|
||||
return text
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI Whisper 音频转文本失败: {e}")
|
||||
return f"[音频转文本失败: {str(e)}]"
|
||||
@@ -445,6 +445,7 @@ class CollaborativeOrchestrator:
|
||||
"provider": api_key_config.provider,
|
||||
"api_key": api_key_config.api_key,
|
||||
"api_base": api_key_config.api_base,
|
||||
"is_omni": api_key_config.is_omni,
|
||||
"model_parameters": config_data.get("model_parameters", {}),
|
||||
"api_key_id": api_key_config.id
|
||||
}
|
||||
@@ -511,6 +512,7 @@ class CollaborativeOrchestrator:
|
||||
provider=agent_config["provider"],
|
||||
api_key=agent_config["api_key"],
|
||||
base_url=agent_config.get("api_base"),
|
||||
is_omni=agent_config.get("is_omni", False),
|
||||
extra_params=extra_params
|
||||
)
|
||||
|
||||
|
||||
@@ -415,6 +415,7 @@ class DraftRunService:
|
||||
api_key=api_key_config["api_key"],
|
||||
provider=api_key_config.get("provider", "openai"),
|
||||
api_base=api_key_config.get("api_base"),
|
||||
is_omni=api_key_config.get("is_omni", False),
|
||||
temperature=effective_params.get("temperature", 0.7),
|
||||
max_tokens=effective_params.get("max_tokens", 2000),
|
||||
system_prompt=system_prompt,
|
||||
@@ -442,7 +443,7 @@ class DraftRunService:
|
||||
if files:
|
||||
# 获取 provider 信息
|
||||
provider = api_key_config.get("provider", "openai")
|
||||
multimodal_service = MultimodalService(self.db, provider=provider)
|
||||
multimodal_service = MultimodalService(self.db, provider=provider, is_omni=api_key_config.get("is_omni", False))
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
|
||||
|
||||
@@ -683,6 +684,7 @@ class DraftRunService:
|
||||
api_key=api_key_config["api_key"],
|
||||
provider=api_key_config.get("provider", "openai"),
|
||||
api_base=api_key_config.get("api_base"),
|
||||
is_omni=api_key_config.get("is_omni", False),
|
||||
temperature=effective_params.get("temperature", 0.7),
|
||||
max_tokens=effective_params.get("max_tokens", 2000),
|
||||
system_prompt=system_prompt,
|
||||
@@ -711,7 +713,7 @@ class DraftRunService:
|
||||
if files:
|
||||
# 获取 provider 信息
|
||||
provider = api_key_config.get("provider", "openai")
|
||||
multimodal_service = MultimodalService(self.db, provider=provider)
|
||||
multimodal_service = MultimodalService(self.db, provider=provider, is_omni=api_key_config.get("is_omni", False))
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
|
||||
|
||||
@@ -809,7 +811,7 @@ class DraftRunService:
|
||||
"""
|
||||
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
|
||||
|
||||
async def _get_api_key(self, model_config_id: uuid.UUID) -> Dict[str, str]:
|
||||
async def _get_api_key(self, model_config_id: uuid.UUID) -> Dict:
|
||||
"""获取模型的 API Key
|
||||
|
||||
Args:
|
||||
@@ -846,7 +848,8 @@ class DraftRunService:
|
||||
"provider": api_key.provider,
|
||||
"api_key": api_key.api_key,
|
||||
"api_base": api_key.api_base,
|
||||
"api_key_id": api_key.id
|
||||
"api_key_id": api_key.id,
|
||||
"is_omni": api_key.is_omni
|
||||
}
|
||||
|
||||
async def _ensure_conversation(
|
||||
|
||||
@@ -544,6 +544,7 @@ def convert_multi_agent_config_to_handoffs(
|
||||
provider=model_api_key.provider,
|
||||
api_key=model_api_key.api_key,
|
||||
base_url=model_api_key.api_base,
|
||||
is_omni=model_api_key.is_omni,
|
||||
extra_params={
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 2000,
|
||||
|
||||
@@ -414,6 +414,7 @@ class LLMRouter:
|
||||
provider=api_key_config.provider,
|
||||
api_key=api_key_config.api_key,
|
||||
base_url=api_key_config.api_base,
|
||||
is_omni=api_key_config.is_omni,
|
||||
temperature=0.3,
|
||||
max_tokens=500
|
||||
)
|
||||
|
||||
@@ -392,6 +392,7 @@ class MasterAgentRouter:
|
||||
provider=api_key_config.provider,
|
||||
api_key=api_key_config.api_key,
|
||||
base_url=api_key_config.api_base,
|
||||
is_omni=api_key_config.is_omni,
|
||||
extra_params = extra_params
|
||||
)
|
||||
|
||||
|
||||
@@ -90,7 +90,8 @@ class ModelConfigService:
|
||||
api_key: str,
|
||||
api_base: Optional[str] = None,
|
||||
model_type: str = "llm",
|
||||
test_message: str = "Hello"
|
||||
test_message: str = "Hello",
|
||||
is_omni: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""验证模型配置是否有效
|
||||
|
||||
@@ -102,6 +103,7 @@ class ModelConfigService:
|
||||
api_base: API基础URL
|
||||
model_type: 模型类型 (llm/chat/embedding/rerank)
|
||||
test_message: 测试消息
|
||||
is_omni: 是否为Omni模型
|
||||
|
||||
Returns:
|
||||
Dict: 验证结果
|
||||
@@ -114,14 +116,27 @@ class ModelConfigService:
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
model_config = RedBearModelConfig(
|
||||
model_name=model_name,
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
temperature=0.7,
|
||||
max_tokens=100
|
||||
)
|
||||
# dashscope 的 omni 模型需要使用 compatible-mode
|
||||
if provider.lower() == ModelProvider.DASHSCOPE and is_omni:
|
||||
if not api_base:
|
||||
api_base = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
model_config = RedBearModelConfig(
|
||||
model_name=model_name,
|
||||
provider=ModelProvider.OPENAI,
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
temperature=0.7,
|
||||
max_tokens=100
|
||||
)
|
||||
else:
|
||||
model_config = RedBearModelConfig(
|
||||
model_name=model_name,
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
temperature=0.7,
|
||||
max_tokens=100
|
||||
)
|
||||
|
||||
# 根据模型类型选择不同的验证方式
|
||||
model_type_lower = model_type.lower()
|
||||
@@ -257,8 +272,9 @@ class ModelConfigService:
|
||||
provider=model_data.provider,
|
||||
api_key=api_key_data.api_key,
|
||||
api_base=api_key_data.api_base,
|
||||
model_type=model_data.type, # 传递模型类型
|
||||
test_message="Hello"
|
||||
model_type=model_data.type,
|
||||
test_message="Hello",
|
||||
is_omni=model_data.is_omni
|
||||
)
|
||||
if not validation_result["valid"]:
|
||||
raise BusinessException(
|
||||
@@ -279,6 +295,9 @@ class ModelConfigService:
|
||||
for api_key_data in api_key_datas:
|
||||
api_key_data.model_name = model_data.name
|
||||
api_key_data.provider = model_data.provider
|
||||
# 同步capability和is_omni
|
||||
api_key_data.capability = model_data.capability
|
||||
api_key_data.is_omni = model_data.is_omni
|
||||
api_key_create_schema = ModelApiKeyCreate(
|
||||
model_config_ids=[model.id],
|
||||
**api_key_data.model_dump()
|
||||
@@ -497,6 +516,8 @@ class ModelApiKeyService:
|
||||
existing_key.config = data.config
|
||||
existing_key.priority = data.priority
|
||||
existing_key.model_name = model_name
|
||||
existing_key.capability = data.capability
|
||||
existing_key.is_omni = data.is_omni
|
||||
|
||||
# 检查是否已关联该模型配置
|
||||
if model_config not in existing_key.model_configs:
|
||||
@@ -513,7 +534,8 @@ class ModelApiKeyService:
|
||||
api_key=data.api_key,
|
||||
api_base=data.api_base,
|
||||
model_type=model_config.type,
|
||||
test_message="Hello"
|
||||
test_message="Hello",
|
||||
is_omni=data.is_omni
|
||||
)
|
||||
if not validation_result["valid"]:
|
||||
# 记录验证失败的模型,但不抛出异常
|
||||
@@ -528,6 +550,8 @@ class ModelApiKeyService:
|
||||
provider=data.provider,
|
||||
api_key=data.api_key,
|
||||
api_base=data.api_base,
|
||||
capability=data.capability if data.capability is not None else model_config.capability,
|
||||
is_omni=data.is_omni if data.is_omni is not None else model_config.is_omni,
|
||||
config=data.config,
|
||||
is_active=data.is_active,
|
||||
priority=data.priority
|
||||
@@ -572,6 +596,8 @@ class ModelApiKeyService:
|
||||
existing_key.config = api_key_data.config
|
||||
existing_key.priority = api_key_data.priority
|
||||
existing_key.model_name = api_key_data.model_name
|
||||
existing_key.capability = api_key_data.capability
|
||||
existing_key.is_omni = api_key_data.is_omni
|
||||
|
||||
# 检查是否已关联该模型配置
|
||||
if model_config not in existing_key.model_configs:
|
||||
@@ -589,7 +615,8 @@ class ModelApiKeyService:
|
||||
api_key=api_key_data.api_key,
|
||||
api_base=api_key_data.api_base,
|
||||
model_type=model_config.type,
|
||||
test_message="Hello"
|
||||
test_message="Hello",
|
||||
is_omni=model_config.is_omni
|
||||
)
|
||||
if not validation_result["valid"]:
|
||||
raise BusinessException(
|
||||
@@ -620,7 +647,8 @@ class ModelApiKeyService:
|
||||
api_key=api_key_data.api_key or existing_api_key.api_key,
|
||||
api_base=api_key_data.api_base or existing_api_key.api_base,
|
||||
model_type=model_config.type,
|
||||
test_message="Hello"
|
||||
test_message="Hello",
|
||||
is_omni=model_config.is_omni
|
||||
)
|
||||
if not validation_result["valid"]:
|
||||
raise BusinessException(
|
||||
@@ -755,6 +783,8 @@ class ModelBaseService:
|
||||
"type": model_base.type,
|
||||
"logo": model_base.logo,
|
||||
"description": model_base.description,
|
||||
"capability": model_base.capability,
|
||||
"is_omni": model_base.is_omni,
|
||||
"is_composite": False
|
||||
}
|
||||
model_config = ModelConfigRepository.create(db, model_config_data)
|
||||
|
||||
@@ -2593,6 +2593,7 @@ class MultiAgentOrchestrator:
|
||||
provider=api_key_config.provider,
|
||||
api_key=api_key_config.api_key,
|
||||
base_url=api_key_config.api_base,
|
||||
is_omni=api_key_config.is_omni,
|
||||
temperature=0.7, # 整合任务使用中等温度
|
||||
max_tokens=2000
|
||||
)
|
||||
@@ -2758,6 +2759,7 @@ class MultiAgentOrchestrator:
|
||||
provider=api_key_config.provider,
|
||||
api_key=api_key_config.api_key,
|
||||
base_url=api_key_config.api_base,
|
||||
is_omni=api_key_config.is_omni,
|
||||
temperature=0.7,
|
||||
max_tokens=2000,
|
||||
extra_params={"streaming": True} # 启用流式输出
|
||||
|
||||
@@ -267,7 +267,7 @@ class MultiAgentService:
|
||||
|
||||
# 2. 验证模型配置(如果提供了)
|
||||
if data.default_model_config_id:
|
||||
model_api_key = ModelApiKeyService.get_a_api_key(self.db, data.default_model_config_id)
|
||||
model_api_key = ModelApiKeyService.get_available_api_key(self.db, data.default_model_config_id)
|
||||
if not model_api_key:
|
||||
raise ResourceNotFoundException("模型配置", str(data.default_model_config_id))
|
||||
|
||||
|
||||
@@ -9,47 +9,100 @@
|
||||
- OpenAI: 支持 URL 和 base64 格式
|
||||
"""
|
||||
import uuid
|
||||
from typing import List, Dict, Any, Optional, Protocol
|
||||
import httpx
|
||||
import base64
|
||||
from typing import List, Dict, Any, Optional
|
||||
from abc import ABC, abstractmethod
|
||||
from sqlalchemy.orm import Session
|
||||
from docx import Document
|
||||
import io
|
||||
import PyPDF2
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.schemas.app_schema import FileInput, FileType, TransferMethod
|
||||
from app.models.generic_file_model import GenericFile
|
||||
from app.models.file_metadata_model import FileMetadata
|
||||
from app.core.config import settings
|
||||
from app.services.audio_transcription_service import AudioTranscriptionService
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class ImageFormatStrategy(Protocol):
|
||||
"""图片格式策略接口"""
|
||||
class MultimodalFormatStrategy(ABC):
|
||||
"""多模态格式策略基类"""
|
||||
|
||||
@abstractmethod
|
||||
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||
"""格式化图片"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def format_document(self, file_name: str, text: str) -> Dict[str, Any]:
|
||||
"""格式化文档"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def format_audio(self, file_type: str, url: str) -> Dict[str, Any]:
|
||||
"""格式化音频"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def format_video(self, url: str) -> Dict[str, Any]:
|
||||
"""格式化视频"""
|
||||
pass
|
||||
|
||||
|
||||
class DashScopeFormatStrategy(MultimodalFormatStrategy):
|
||||
"""通义千问策略"""
|
||||
|
||||
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||
"""将图片 URL 转换为特定 provider 的格式"""
|
||||
...
|
||||
|
||||
|
||||
class DashScopeImageStrategy:
|
||||
"""通义千问图片格式策略"""
|
||||
|
||||
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||
"""通义千问格式: {"type": "image", "image": "url"}"""
|
||||
"""通义千问图片格式:{"type": "image", "image": "url"}"""
|
||||
return {
|
||||
"type": "image",
|
||||
"image": url
|
||||
}
|
||||
|
||||
async def format_document(self, file_name: str, text: str) -> Dict[str, Any]:
|
||||
"""通义千问文档格式"""
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
|
||||
}
|
||||
|
||||
class BedrockImageStrategy:
|
||||
"""Bedrock/Anthropic 图片格式策略"""
|
||||
async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
通义千问音频格式
|
||||
- 原生支持: qwen-audio 系列
|
||||
- 其他模型: 需要转录为文本
|
||||
"""
|
||||
if transcription:
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"<audio url=\"{url}\">\n{transcription}\n</audio>"
|
||||
}
|
||||
# 通义千问音频格式:{"type": "audio", "audio": "url"}
|
||||
return {
|
||||
"type": "audio",
|
||||
"audio": url
|
||||
}
|
||||
|
||||
async def format_video(self, url: str) -> Dict[str, Any]:
|
||||
"""通义千问视频格式(qwen-vl 系列原生支持)"""
|
||||
return {
|
||||
"type": "video",
|
||||
"video": url
|
||||
}
|
||||
|
||||
|
||||
class BedrockFormatStrategy(MultimodalFormatStrategy):
|
||||
"""Bedrock/Anthropic 策略"""
|
||||
|
||||
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Bedrock/Anthropic 格式: base64 编码
|
||||
{"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}}
|
||||
"""
|
||||
import httpx
|
||||
import base64
|
||||
from mimetypes import guess_type
|
||||
|
||||
logger.info(f"下载并编码图片: {url}")
|
||||
@@ -84,9 +137,46 @@ class BedrockImageStrategy:
|
||||
}
|
||||
}
|
||||
|
||||
async def format_document(self, file_name: str, text: str) -> Dict[str, Any]:
|
||||
"""Bedrock/Anthropic 文档格式(需要 base64 编码)"""
|
||||
# Bedrock 文档需要 base64 编码
|
||||
text_bytes = text.encode('utf-8')
|
||||
base64_text = base64.b64encode(text_bytes).decode('utf-8')
|
||||
|
||||
class OpenAIImageStrategy:
|
||||
"""OpenAI 图片格式策略"""
|
||||
return {
|
||||
"type": "document",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "text/plain",
|
||||
"data": base64_text
|
||||
}
|
||||
}
|
||||
|
||||
async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Bedrock/Anthropic 音频格式
|
||||
不支持原生音频,必须转录为文本
|
||||
"""
|
||||
if transcription:
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"[音频转录]\n{transcription}"
|
||||
}
|
||||
return {
|
||||
"type": "text",
|
||||
"text": "[音频文件:Bedrock 不支持原生音频,请启用音频转文本功能]"
|
||||
}
|
||||
|
||||
async def format_video(self, url: str) -> Dict[str, Any]:
|
||||
"""Bedrock/Anthropic 视频格式"""
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"<video url=\"{url}\">\n[视频文件,当前 provider 暂不支持]\n</video>"
|
||||
}
|
||||
|
||||
|
||||
class OpenAIFormatStrategy(MultimodalFormatStrategy):
|
||||
"""OpenAI 策略"""
|
||||
|
||||
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||
"""OpenAI 格式: {"type": "image_url", "image_url": {"url": "..."}}"""
|
||||
@@ -97,29 +187,97 @@ class OpenAIImageStrategy:
|
||||
}
|
||||
}
|
||||
|
||||
async def format_document(self, file_name: str, text: str) -> Dict[str, Any]:
|
||||
"""OpenAI 文档格式"""
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
|
||||
}
|
||||
|
||||
async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
OpenAI 音频格式
|
||||
- gpt-4o-audio 系列支持原生音频(需要 base64 编码)
|
||||
- 其他模型使用转录文本
|
||||
"""
|
||||
if transcription:
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"<audio url=\"{url}\">\n{transcription}\n</audio>"
|
||||
}
|
||||
|
||||
# OpenAI 音频需要 base64 编码
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
audio_data = response.content
|
||||
base64_audio = base64.b64encode(audio_data).decode('utf-8')
|
||||
# 1. 优先从 file_type (MIME) 取扩展名
|
||||
file_ext = file_type.split('/')[-1] if file_type and '/' in file_type else None
|
||||
# 2. 从响应头 content-type 取
|
||||
if not file_ext:
|
||||
ct = response.headers.get("content-type", "")
|
||||
file_ext = ct.split('/')[-1].split(';')[0].strip() if '/' in ct else None
|
||||
# 3. 从 URL 路径取扩展名
|
||||
if not file_ext:
|
||||
file_ext = url.split('?')[0].rsplit('.', 1)[-1].lower() or None
|
||||
# 4. 默认 wav
|
||||
# supported_ext = {"wav", "mp3", "mp4", "ogg", "flac", "webm", "m4a", "wave", "x-m4a"}
|
||||
file_ext = "wav" if not file_ext else file_ext
|
||||
|
||||
return {
|
||||
"type": "input_audio",
|
||||
"input_audio": {
|
||||
"data": f"data:;base64,{base64_audio}",
|
||||
"format": file_ext
|
||||
}
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"下载音频失败: {e}")
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"[音频处理失败: {str(e)}]"
|
||||
}
|
||||
|
||||
async def format_video(self, url: str) -> Dict[str, Any]:
|
||||
"""OpenAI 视频格式"""
|
||||
return {
|
||||
"type": "video_url",
|
||||
"video_url": {
|
||||
"url": url
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# Provider 到策略的映射
|
||||
PROVIDER_STRATEGIES = {
|
||||
"dashscope": DashScopeImageStrategy,
|
||||
"bedrock": BedrockImageStrategy,
|
||||
"anthropic": BedrockImageStrategy,
|
||||
"openai": OpenAIImageStrategy,
|
||||
"dashscope": DashScopeFormatStrategy,
|
||||
"bedrock": BedrockFormatStrategy,
|
||||
"anthropic": BedrockFormatStrategy,
|
||||
"openai": OpenAIFormatStrategy,
|
||||
}
|
||||
|
||||
|
||||
class MultimodalService:
|
||||
"""多模态文件处理服务"""
|
||||
|
||||
def __init__(self, db: Session, provider: str = "dashscope"):
|
||||
def __init__(self, db: Session, provider: str = "dashscope", api_key: Optional[str] = None, enable_audio_transcription: bool = False, is_omni: bool = False):
|
||||
"""
|
||||
初始化多模态服务
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
provider: 模型提供商(dashscope, bedrock, anthropic 等)
|
||||
provider: 模型提供商(dashscope, bedrock, anthropic, openai 等)
|
||||
api_key: API 密钥(用于音频转文本)
|
||||
enable_audio_transcription: 是否启用音频转文本
|
||||
is_omni: 是否为 Omni 模型(dashscope 的 omni 模型需要使用 OpenAI 兼容格式)
|
||||
"""
|
||||
self.db = db
|
||||
self.provider = provider.lower()
|
||||
self.api_key = api_key
|
||||
self.enable_audio_transcription = enable_audio_transcription
|
||||
self.is_omni = is_omni
|
||||
|
||||
async def process_files(
|
||||
self,
|
||||
@@ -137,20 +295,32 @@ class MultimodalService:
|
||||
if not files:
|
||||
return []
|
||||
|
||||
# 获取对应的策略
|
||||
# dashscope 的 omni 模型使用 OpenAI 兼容格式
|
||||
if self.provider == "dashscope" and self.is_omni:
|
||||
strategy_class = OpenAIFormatStrategy
|
||||
else:
|
||||
strategy_class = PROVIDER_STRATEGIES.get(self.provider)
|
||||
if not strategy_class:
|
||||
logger.warning(f"未找到 provider '{self.provider}' 的策略,使用默认策略")
|
||||
strategy_class = DashScopeFormatStrategy
|
||||
|
||||
strategy = strategy_class()
|
||||
|
||||
result = []
|
||||
for idx, file in enumerate(files):
|
||||
try:
|
||||
if file.type == FileType.IMAGE:
|
||||
content = await self._process_image(file)
|
||||
content = await self._process_image(file, strategy)
|
||||
result.append(content)
|
||||
elif file.type == FileType.DOCUMENT:
|
||||
content = await self._process_document(file)
|
||||
content = await self._process_document(file, strategy)
|
||||
result.append(content)
|
||||
elif file.type == FileType.AUDIO:
|
||||
content = await self._process_audio(file)
|
||||
content = await self._process_audio(file, strategy)
|
||||
result.append(content)
|
||||
elif file.type == FileType.VIDEO:
|
||||
content = await self._process_video(file)
|
||||
content = await self._process_video(file, strategy)
|
||||
result.append(content)
|
||||
else:
|
||||
logger.warning(f"不支持的文件类型: {file.type}")
|
||||
@@ -172,55 +342,29 @@ class MultimodalService:
|
||||
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}")
|
||||
return result
|
||||
|
||||
async def _process_image(self, file: FileInput) -> Dict[str, Any]:
|
||||
async def _process_image(self, file: FileInput, strategy) -> Dict[str, Any]:
|
||||
"""
|
||||
处理图片文件
|
||||
|
||||
Args:
|
||||
file: 图片文件输入
|
||||
strategy: 格式化策略
|
||||
|
||||
Returns:
|
||||
Dict: 根据 provider 返回不同格式
|
||||
- Anthropic/Bedrock: {"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}}
|
||||
- 通义千问: {"type": "image", "image": "url"}
|
||||
Dict: 根据 provider 返回不同格式的图片内容
|
||||
"""
|
||||
url = await self.get_file_url(file)
|
||||
|
||||
logger.debug(f"处理图片: {url}, provider={self.provider}")
|
||||
|
||||
# 根据 provider 返回不同格式
|
||||
if self.provider in ["bedrock", "anthropic"]:
|
||||
# Anthropic/Bedrock 只支持 base64 格式,需要下载并转换
|
||||
try:
|
||||
logger.info(f"开始下载并编码图片: {url}")
|
||||
base64_data, media_type = await self._download_and_encode_image(url)
|
||||
result = {
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": media_type,
|
||||
"data": base64_data[:100] + "..." # 只记录前100个字符
|
||||
}
|
||||
}
|
||||
logger.info(f"图片编码完成: media_type={media_type}, data_length={len(base64_data)}")
|
||||
# 返回完整数据
|
||||
result["source"]["data"] = base64_data
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"下载并编码图片失败: {e}", exc_info=True)
|
||||
# 返回错误提示
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"[图片加载失败: {str(e)}]"
|
||||
}
|
||||
else:
|
||||
# 通义千问等其他格式支持 URL
|
||||
try:
|
||||
url = await self.get_file_url(file)
|
||||
return await strategy.format_image(url)
|
||||
except Exception as e:
|
||||
logger.error(f"处理图片失败: {e}", exc_info=True)
|
||||
return {
|
||||
"type": "image",
|
||||
"image": url
|
||||
"type": "text",
|
||||
"text": f"[图片处理失败: {str(e)}]"
|
||||
}
|
||||
|
||||
async def _download_and_encode_image(self, url: str) -> tuple[str, str]:
|
||||
@staticmethod
|
||||
async def _download_and_encode_image(url: str) -> tuple[str, str]:
|
||||
"""
|
||||
下载图片并转换为 base64
|
||||
|
||||
@@ -230,8 +374,6 @@ class MultimodalService:
|
||||
Returns:
|
||||
tuple: (base64_data, media_type)
|
||||
"""
|
||||
import httpx
|
||||
import base64
|
||||
from mimetypes import guess_type
|
||||
|
||||
# 下载图片
|
||||
@@ -258,15 +400,16 @@ class MultimodalService:
|
||||
|
||||
return base64_data, media_type
|
||||
|
||||
async def _process_document(self, file: FileInput) -> Dict[str, Any]:
|
||||
async def _process_document(self, file: FileInput, strategy) -> Dict[str, Any]:
|
||||
"""
|
||||
处理文档文件(PDF、Word 等)
|
||||
|
||||
Args:
|
||||
file: 文档文件输入
|
||||
strategy: 格式化策略
|
||||
|
||||
Returns:
|
||||
Dict: text 格式的内容(包含提取的文本)
|
||||
Dict: 根据 provider 返回不同格式的文档内容
|
||||
"""
|
||||
if file.transfer_method == TransferMethod.REMOTE_URL:
|
||||
# 远程文档暂不支持提取
|
||||
@@ -277,48 +420,68 @@ class MultimodalService:
|
||||
else:
|
||||
# 本地文件,提取文本内容
|
||||
text = await self._extract_document_text(file.upload_file_id)
|
||||
generic_file = self.db.query(GenericFile).filter(
|
||||
GenericFile.id == file.upload_file_id
|
||||
file_metadata = self.db.query(FileMetadata).filter(
|
||||
FileMetadata.id == file.upload_file_id
|
||||
).first()
|
||||
|
||||
file_name = generic_file.file_name if generic_file else "unknown"
|
||||
file_name = file_metadata.file_name if file_metadata else "unknown"
|
||||
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
|
||||
}
|
||||
# 使用策略格式化文档
|
||||
return await strategy.format_document(file_name, text)
|
||||
|
||||
async def _process_audio(self, file: FileInput) -> Dict[str, Any]:
|
||||
async def _process_audio(self, file: FileInput, strategy) -> Dict[str, Any]:
|
||||
"""
|
||||
处理音频文件
|
||||
|
||||
Args:
|
||||
file: 音频文件输入
|
||||
strategy: 格式化策略
|
||||
|
||||
Returns:
|
||||
Dict: 音频内容(暂时返回占位符)
|
||||
Dict: 根据 provider 返回不同格式的音频内容
|
||||
"""
|
||||
# TODO: 实现音频转文字功能
|
||||
return {
|
||||
"type": "text",
|
||||
"text": "[音频文件,暂不支持处理]"
|
||||
}
|
||||
try:
|
||||
url = await self.get_file_url(file)
|
||||
|
||||
async def _process_video(self, file: FileInput) -> Dict[str, Any]:
|
||||
# 如果启用音频转文本且有 API Key
|
||||
transcription = None
|
||||
if self.enable_audio_transcription and self.api_key:
|
||||
logger.info(f"开始音频转文本: {url}")
|
||||
if self.provider == "dashscope":
|
||||
transcription = await AudioTranscriptionService.transcribe_dashscope(url, self.api_key)
|
||||
elif self.provider == "openai":
|
||||
transcription = await AudioTranscriptionService.transcribe_openai(url, self.api_key)
|
||||
else:
|
||||
logger.warning(f"Provider {self.provider} 不支持音频转文本")
|
||||
|
||||
return await strategy.format_audio(file.file_type, url, transcription)
|
||||
except Exception as e:
|
||||
logger.error(f"处理音频失败: {e}", exc_info=True)
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"[音频处理失败: {str(e)}]"
|
||||
}
|
||||
|
||||
async def _process_video(self, file: FileInput, strategy) -> Dict[str, Any]:
|
||||
"""
|
||||
处理视频文件
|
||||
|
||||
Args:
|
||||
file: 视频文件输入
|
||||
strategy: 格式化策略
|
||||
|
||||
Returns:
|
||||
Dict: 视频内容(暂时返回占位符)
|
||||
Dict: 根据 provider 返回不同格式的视频内容
|
||||
"""
|
||||
# TODO: 实现视频处理功能
|
||||
return {
|
||||
"type": "text",
|
||||
"text": "[视频文件,暂不支持处理]"
|
||||
}
|
||||
try:
|
||||
url = await self.get_file_url(file)
|
||||
return await strategy.format_video(url)
|
||||
except Exception as e:
|
||||
logger.error(f"处理视频失败: {e}", exc_info=True)
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"[视频处理失败: {str(e)}]"
|
||||
}
|
||||
|
||||
async def get_file_url(self, file: FileInput) -> str:
|
||||
"""
|
||||
@@ -336,26 +499,22 @@ class MultimodalService:
|
||||
if file.transfer_method == TransferMethod.REMOTE_URL:
|
||||
return file.url
|
||||
else:
|
||||
# 本地文件,通过 file_storage 系统获取永久访问 URL
|
||||
from app.models.file_metadata_model import FileMetadata
|
||||
from app.core.config import settings
|
||||
|
||||
file_id = file.upload_file_id
|
||||
print("="*50)
|
||||
print("file_id",file_id)
|
||||
|
||||
|
||||
# 查询 FileMetadata
|
||||
file_metadata = self.db.query(FileMetadata).filter(
|
||||
FileMetadata.id == file_id,
|
||||
FileMetadata.status == "completed"
|
||||
).first()
|
||||
|
||||
|
||||
if not file_metadata:
|
||||
raise BusinessException(
|
||||
f"文件不存在或已删除: {file_id}",
|
||||
BizCode.NOT_FOUND
|
||||
)
|
||||
|
||||
|
||||
# 返回永久URL
|
||||
server_url = settings.FILE_LOCAL_SERVER_URL
|
||||
return f"{server_url}/storage/permanent/{file_id}"
|
||||
@@ -370,58 +529,79 @@ class MultimodalService:
|
||||
Returns:
|
||||
str: 提取的文本内容
|
||||
"""
|
||||
generic_file = self.db.query(GenericFile).filter(
|
||||
GenericFile.id == file_id,
|
||||
GenericFile.status == "active"
|
||||
file_metadata = self.db.query(FileMetadata).filter(
|
||||
FileMetadata.id == file_id,
|
||||
FileMetadata.status == "completed"
|
||||
).first()
|
||||
|
||||
if not generic_file:
|
||||
if not file_metadata:
|
||||
raise BusinessException(
|
||||
f"文件不存在或已删除: {file_id}",
|
||||
BizCode.NOT_FOUND
|
||||
)
|
||||
|
||||
# TODO: 根据文件类型提取文本
|
||||
# - PDF: 使用 PyPDF2 或 pdfplumber
|
||||
# - Word: 使用 python-docx
|
||||
# - TXT/MD: 直接读取
|
||||
|
||||
file_ext = generic_file.file_ext.lower()
|
||||
file_ext = file_metadata.file_ext.lower()
|
||||
server_url = settings.FILE_LOCAL_SERVER_URL
|
||||
file_url = f"{server_url}/storage/permanent/{file_id}"
|
||||
|
||||
if file_ext in ['.txt', '.md', '.markdown']:
|
||||
return await self._read_text_file(generic_file.storage_path)
|
||||
return await self._read_text_file(file_url)
|
||||
elif file_ext == '.pdf':
|
||||
return await self._extract_pdf_text(generic_file.storage_path)
|
||||
return await self._extract_pdf_text(file_url)
|
||||
elif file_ext in ['.doc', '.docx']:
|
||||
return await self._extract_word_text(generic_file.storage_path)
|
||||
return await self._extract_word_text(file_url)
|
||||
else:
|
||||
return f"[不支持的文档格式: {file_ext}]"
|
||||
|
||||
async def _read_text_file(self, storage_path: str) -> str:
|
||||
@staticmethod
|
||||
async def _read_text_file(file_url: str) -> str:
|
||||
"""读取纯文本文件"""
|
||||
try:
|
||||
with open(storage_path, 'r', encoding='utf-8') as f:
|
||||
return f.read()
|
||||
# 下载文件
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(file_url)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
except Exception as e:
|
||||
logger.error(f"读取文本文件失败: {e}")
|
||||
return f"[文件读取失败: {str(e)}]"
|
||||
|
||||
async def _extract_pdf_text(self, storage_path: str) -> str:
|
||||
@staticmethod
|
||||
async def _extract_pdf_text(file_url: str) -> str:
|
||||
"""提取 PDF 文本"""
|
||||
try:
|
||||
# TODO: 实现 PDF 文本提取
|
||||
# import PyPDF2 或 pdfplumber
|
||||
return "[PDF 文本提取功能待实现]"
|
||||
# 下载 PDF 文件
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(file_url)
|
||||
response.raise_for_status()
|
||||
pdf_data = response.content
|
||||
|
||||
# 使用 BytesIO 读取 PDF
|
||||
text_parts = []
|
||||
pdf_file = io.BytesIO(pdf_data)
|
||||
pdf_reader = PyPDF2.PdfReader(pdf_file)
|
||||
for page in pdf_reader.pages:
|
||||
text_parts.append(page.extract_text())
|
||||
return '\n'.join(text_parts)
|
||||
except Exception as e:
|
||||
logger.error(f"提取 PDF 文本失败: {e}")
|
||||
return f"[PDF 提取失败: {str(e)}]"
|
||||
|
||||
async def _extract_word_text(self, storage_path: str) -> str:
|
||||
@staticmethod
|
||||
async def _extract_word_text(file_url: str) -> str:
|
||||
"""提取 Word 文档文本"""
|
||||
try:
|
||||
# TODO: 实现 Word 文本提取
|
||||
# import docx
|
||||
return "[Word 文本提取功能待实现]"
|
||||
# 下载 Word 文件
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(file_url)
|
||||
response.raise_for_status()
|
||||
word_data = response.content
|
||||
|
||||
# 使用 BytesIO 读取 Word 文档
|
||||
word_file = io.BytesIO(word_data)
|
||||
doc = Document(word_file)
|
||||
text_parts = [paragraph.text for paragraph in doc.paragraphs]
|
||||
return '\n'.join(text_parts)
|
||||
except Exception as e:
|
||||
logger.error(f"提取 Word 文本失败: {e}")
|
||||
return f"[Word 提取失败: {str(e)}]"
|
||||
|
||||
@@ -184,7 +184,8 @@ class PromptOptimizerService:
|
||||
model_name=api_config.model_name,
|
||||
provider=api_config.provider,
|
||||
api_key=api_config.api_key,
|
||||
base_url=api_config.api_base
|
||||
base_url=api_config.api_base,
|
||||
is_omni=api_config.is_omni
|
||||
), type=ModelType(model_config.type))
|
||||
try:
|
||||
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
|
||||
|
||||
@@ -247,6 +247,7 @@ class SharedChatService:
|
||||
api_key=api_key_obj.api_key,
|
||||
provider=api_key_obj.provider,
|
||||
api_base=api_key_obj.api_base,
|
||||
is_omni=api_key_obj.is_omni,
|
||||
temperature=model_parameters.get("temperature", 0.7),
|
||||
max_tokens=model_parameters.get("max_tokens", 2000),
|
||||
system_prompt=system_prompt,
|
||||
@@ -454,6 +455,7 @@ class SharedChatService:
|
||||
api_key=api_key_obj.api_key,
|
||||
provider=api_key_obj.provider,
|
||||
api_base=api_key_obj.api_base,
|
||||
is_omni=api_key_obj.is_omni,
|
||||
temperature=model_parameters.get("temperature", 0.7),
|
||||
max_tokens=model_parameters.get("max_tokens", 2000),
|
||||
system_prompt=system_prompt,
|
||||
|
||||
Reference in New Issue
Block a user