feat(model and app):

1. Add support for vision models and multimodal (omni) models;
2. Applications and workflows can now accept multimodal file inputs such as images, documents, audio, and video.
This commit is contained in:
Timebomb2018
2026-03-05 09:55:54 +08:00
parent 23bfdcefef
commit 590ec3a446
26 changed files with 958 additions and 233 deletions

View File

@@ -157,6 +157,7 @@ class AppChatService:
api_key=api_key_obj.api_key,
provider=api_key_obj.provider,
api_base=api_key_obj.api_base,
is_omni=api_key_obj.is_omni,
temperature=model_parameters.get("temperature", 0.7),
max_tokens=model_parameters.get("max_tokens", 2000),
system_prompt=system_prompt,
@@ -180,7 +181,7 @@ class AppChatService:
# 处理多模态文件
processed_files = None
if files:
multimodal_service = MultimodalService(self.db)
multimodal_service = MultimodalService(self.db, api_key_obj.provider, is_omni=api_key_obj.is_omni)
processed_files = await multimodal_service.process_files(files)
logger.info(f"处理了 {len(processed_files)} 个文件")
@@ -343,6 +344,7 @@ class AppChatService:
api_key=api_key_obj.api_key,
provider=api_key_obj.provider,
api_base=api_key_obj.api_base,
is_omni=api_key_obj.is_omni,
temperature=model_parameters.get("temperature", 0.7),
max_tokens=model_parameters.get("max_tokens", 2000),
system_prompt=system_prompt,
@@ -366,7 +368,7 @@ class AppChatService:
# 处理多模态文件
processed_files = None
if files:
multimodal_service = MultimodalService(self.db)
multimodal_service = MultimodalService(self.db, api_key_obj.provider, is_omni=api_key_obj.is_omni)
processed_files = await multimodal_service.process_files(files)
logger.info(f"处理了 {len(processed_files)} 个文件")

View File

@@ -232,7 +232,7 @@ class AppService:
# 检查主 Agent 的模型配置
multi_agent_config.default_model_config_id = master_agent_release.default_model_config_id
model_api_key = ModelApiKeyService.get_a_api_key(self.db, multi_agent_config.default_model_config_id)
model_api_key = ModelApiKeyService.get_available_api_key(self.db, multi_agent_config.default_model_config_id)
if not model_api_key:
raise ResourceNotFoundException("模型配置", str(multi_agent_config.default_model_config_id))

View File

@@ -0,0 +1,101 @@
"""
音频转文本服务
支持的服务商:
- DashScope (阿里云通义千问)
- OpenAI Whisper
"""
import httpx
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class AudioTranscriptionService:
    """Speech-to-text helpers for the supported providers.

    Every public method is best-effort: it never raises, and on failure it
    returns a bracketed placeholder string so callers can still embed
    something in the prompt.
    """

    _DASHSCOPE_ASR_URL = "https://dashscope.aliyuncs.com/api/v1/services/audio/asr/transcription"
    _DASHSCOPE_TASK_URL = "https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}"
    # Polling budget for the asynchronous DashScope task (interval * polls seconds).
    _DASHSCOPE_POLL_INTERVAL = 1.0
    _DASHSCOPE_MAX_POLLS = 60

    # Extensions accepted for upload, mapped to the media type sent with them.
    _UPLOAD_MEDIA_TYPES = {
        "flac": "audio/flac",
        "m4a": "audio/mp4",
        "mp3": "audio/mpeg",
        "mp4": "video/mp4",
        "mpeg": "audio/mpeg",
        "mpga": "audio/mpeg",
        "ogg": "audio/ogg",
        "wav": "audio/wav",
        "webm": "audio/webm",
    }
    # Content-Type values seen on downloads, mapped back to an upload extension.
    _CONTENT_TYPE_EXTENSIONS = {
        "audio/mpeg": "mp3",
        "audio/mp3": "mp3",
        "audio/wav": "wav",
        "audio/x-wav": "wav",
        "audio/wave": "wav",
        "audio/flac": "flac",
        "audio/x-flac": "flac",
        "audio/ogg": "ogg",
        "audio/webm": "webm",
        "video/webm": "webm",
        "audio/mp4": "m4a",
        "audio/m4a": "m4a",
        "audio/x-m4a": "m4a",
        "video/mp4": "mp4",
    }

    @staticmethod
    def _guess_upload_name(audio_url: str, content_type: str = "") -> tuple[str, str]:
        """Pick the multipart filename and media type for an audio upload.

        The URL path extension wins; otherwise the download's Content-Type is
        consulted; otherwise the historical default ``audio.mp3`` /
        ``audio/mpeg`` is returned.

        Args:
            audio_url: URL the audio was downloaded from.
            content_type: Raw Content-Type header of the download, if any.

        Returns:
            tuple: ``(filename, media_type)``.
        """
        import posixpath
        from urllib.parse import unquote, urlparse

        cls = AudioTranscriptionService
        base_name = posixpath.basename(unquote(urlparse(audio_url).path))
        ext = base_name.rsplit(".", 1)[-1].lower() if "." in base_name else ""
        if ext not in cls._UPLOAD_MEDIA_TYPES:
            # Drop parameters such as "; charset=binary" before the lookup.
            bare_type = (content_type or "").split(";", 1)[0].strip().lower()
            ext = cls._CONTENT_TYPE_EXTENSIONS.get(bare_type, "mp3")
        return f"audio.{ext}", cls._UPLOAD_MEDIA_TYPES[ext]

    @staticmethod
    def _join_transcripts(payload: dict) -> str:
        """Join the per-channel texts of a DashScope transcription result file.

        Args:
            payload: Parsed JSON downloaded from a result's ``transcription_url``.

        Returns:
            str: Non-empty ``text`` fields of ``payload["transcripts"]`` joined
            by newlines; an empty string when there are none.
        """
        transcripts = payload.get("transcripts") or []
        return "\n".join(item["text"] for item in transcripts if item.get("text"))

    @staticmethod
    async def transcribe_dashscope(audio_url: str, api_key: str) -> str:
        """Transcribe audio with the DashScope Paraformer recorded-file API.

        The request is submitted with ``X-DashScope-Async: enable``, so the
        first response only carries a task id. The task is polled until it
        leaves the PENDING/RUNNING states (or the polling budget runs out),
        then the transcript is downloaded from the result's
        ``transcription_url``.

        Args:
            audio_url: Publicly reachable URL of the audio file.
            api_key: DashScope API key.

        Returns:
            str: The transcribed text, or a bracketed failure placeholder.
        """
        import asyncio

        cls = AudioTranscriptionService
        auth_headers = {"Authorization": f"Bearer {api_key}"}
        try:
            async with httpx.AsyncClient(timeout=60.0) as client:
                response = await client.post(
                    cls._DASHSCOPE_ASR_URL,
                    headers={
                        **auth_headers,
                        "Content-Type": "application/json",
                        "X-DashScope-Async": "enable",
                    },
                    json={
                        "model": "paraformer-v2",
                        "input": {
                            "file_urls": [audio_url]
                        },
                        "parameters": {
                            "language_hints": ["zh", "en", "ja", "yue", "ko", "de", "fr", "ru"]
                        }
                    }
                )
                response.raise_for_status()
                output = response.json().get("output") or {}

                # Asynchronous submission: wait for the task to finish.
                task_id = output.get("task_id")
                polls = 0
                while (
                    task_id
                    and output.get("task_status") in ("PENDING", "RUNNING")
                    and polls < cls._DASHSCOPE_MAX_POLLS
                ):
                    await asyncio.sleep(cls._DASHSCOPE_POLL_INTERVAL)
                    status_response = await client.get(
                        cls._DASHSCOPE_TASK_URL.format(task_id=task_id),
                        headers=auth_headers,
                    )
                    status_response.raise_for_status()
                    output = status_response.json().get("output") or {}
                    polls += 1

                results = output.get("results") or []
                if not results:
                    return "[音频转文本失败]"

                first = results[0]
                # Inline text is honoured if the service ever returns it;
                # normally the transcript lives behind transcription_url.
                text = first.get("transcription_text")
                if text is None and first.get("transcription_url"):
                    # Pre-signed result URL: fetched without the API key header.
                    detail_response = await client.get(first["transcription_url"])
                    detail_response.raise_for_status()
                    text = cls._join_transcripts(detail_response.json())
                if text is None:
                    return "[音频转文本失败]"

                logger.info(f"音频转文本成功: {len(text)} 字符")
                return text
        except Exception as e:
            logger.error(f"DashScope 音频转文本失败: {e}")
            return f"[音频转文本失败: {str(e)}]"

    @staticmethod
    async def transcribe_openai(audio_url: str, api_key: str) -> str:
        """Transcribe audio with OpenAI Whisper (``whisper-1``).

        The audio is downloaded first and then uploaded as multipart form
        data. The upload filename and media type follow the real audio format
        when it can be determined, because the service infers the format from
        the filename.

        Args:
            audio_url: URL of the audio file to download.
            api_key: OpenAI API key.

        Returns:
            str: The transcribed text, or a bracketed failure placeholder.
        """
        try:
            async with httpx.AsyncClient(timeout=60.0) as client:
                # Download the audio file.
                audio_response = await client.get(audio_url)
                audio_response.raise_for_status()
                audio_data = audio_response.content

                upload_name, media_type = AudioTranscriptionService._guess_upload_name(
                    audio_url, audio_response.headers.get("content-type", "")
                )

                # Call the Whisper API.
                files = {"file": (upload_name, audio_data, media_type)}
                data = {"model": "whisper-1"}
                response = await client.post(
                    "https://api.openai.com/v1/audio/transcriptions",
                    headers={"Authorization": f"Bearer {api_key}"},
                    files=files,
                    data=data
                )
                response.raise_for_status()
                result = response.json()
                text = result.get("text", "")
                logger.info(f"音频转文本成功: {len(text)} 字符")
                return text
        except Exception as e:
            logger.error(f"OpenAI Whisper 音频转文本失败: {e}")
            return f"[音频转文本失败: {str(e)}]"

View File

@@ -445,6 +445,7 @@ class CollaborativeOrchestrator:
"provider": api_key_config.provider,
"api_key": api_key_config.api_key,
"api_base": api_key_config.api_base,
"is_omni": api_key_config.is_omni,
"model_parameters": config_data.get("model_parameters", {}),
"api_key_id": api_key_config.id
}
@@ -511,6 +512,7 @@ class CollaborativeOrchestrator:
provider=agent_config["provider"],
api_key=agent_config["api_key"],
base_url=agent_config.get("api_base"),
is_omni=agent_config.get("is_omni", False),
extra_params=extra_params
)

View File

@@ -415,6 +415,7 @@ class DraftRunService:
api_key=api_key_config["api_key"],
provider=api_key_config.get("provider", "openai"),
api_base=api_key_config.get("api_base"),
is_omni=api_key_config.get("is_omni", False),
temperature=effective_params.get("temperature", 0.7),
max_tokens=effective_params.get("max_tokens", 2000),
system_prompt=system_prompt,
@@ -442,7 +443,7 @@ class DraftRunService:
if files:
# 获取 provider 信息
provider = api_key_config.get("provider", "openai")
multimodal_service = MultimodalService(self.db, provider=provider)
multimodal_service = MultimodalService(self.db, provider=provider, is_omni=api_key_config.get("is_omni", False))
processed_files = await multimodal_service.process_files(files)
logger.info(f"处理了 {len(processed_files)} 个文件provider={provider}")
@@ -683,6 +684,7 @@ class DraftRunService:
api_key=api_key_config["api_key"],
provider=api_key_config.get("provider", "openai"),
api_base=api_key_config.get("api_base"),
is_omni=api_key_config.get("is_omni", False),
temperature=effective_params.get("temperature", 0.7),
max_tokens=effective_params.get("max_tokens", 2000),
system_prompt=system_prompt,
@@ -711,7 +713,7 @@ class DraftRunService:
if files:
# 获取 provider 信息
provider = api_key_config.get("provider", "openai")
multimodal_service = MultimodalService(self.db, provider=provider)
multimodal_service = MultimodalService(self.db, provider=provider, is_omni=api_key_config.get("is_omni", False))
processed_files = await multimodal_service.process_files(files)
logger.info(f"处理了 {len(processed_files)} 个文件provider={provider}")
@@ -809,7 +811,7 @@ class DraftRunService:
"""
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
async def _get_api_key(self, model_config_id: uuid.UUID) -> Dict[str, str]:
async def _get_api_key(self, model_config_id: uuid.UUID) -> Dict:
"""获取模型的 API Key
Args:
@@ -846,7 +848,8 @@ class DraftRunService:
"provider": api_key.provider,
"api_key": api_key.api_key,
"api_base": api_key.api_base,
"api_key_id": api_key.id
"api_key_id": api_key.id,
"is_omni": api_key.is_omni
}
async def _ensure_conversation(

View File

@@ -544,6 +544,7 @@ def convert_multi_agent_config_to_handoffs(
provider=model_api_key.provider,
api_key=model_api_key.api_key,
base_url=model_api_key.api_base,
is_omni=model_api_key.is_omni,
extra_params={
"temperature": 0.7,
"max_tokens": 2000,

View File

@@ -414,6 +414,7 @@ class LLMRouter:
provider=api_key_config.provider,
api_key=api_key_config.api_key,
base_url=api_key_config.api_base,
is_omni=api_key_config.is_omni,
temperature=0.3,
max_tokens=500
)

View File

@@ -392,6 +392,7 @@ class MasterAgentRouter:
provider=api_key_config.provider,
api_key=api_key_config.api_key,
base_url=api_key_config.api_base,
is_omni=api_key_config.is_omni,
extra_params = extra_params
)

View File

@@ -90,7 +90,8 @@ class ModelConfigService:
api_key: str,
api_base: Optional[str] = None,
model_type: str = "llm",
test_message: str = "Hello"
test_message: str = "Hello",
is_omni: bool = False
) -> Dict[str, Any]:
"""验证模型配置是否有效
@@ -102,6 +103,7 @@ class ModelConfigService:
api_base: API基础URL
model_type: 模型类型 (llm/chat/embedding/rerank)
test_message: 测试消息
is_omni: 是否为Omni模型
Returns:
Dict: 验证结果
@@ -114,14 +116,27 @@ class ModelConfigService:
try:
start_time = time.time()
model_config = RedBearModelConfig(
model_name=model_name,
provider=provider,
api_key=api_key,
base_url=api_base,
temperature=0.7,
max_tokens=100
)
# dashscope 的 omni 模型需要使用 compatible-mode
if provider.lower() == ModelProvider.DASHSCOPE and is_omni:
if not api_base:
api_base = "https://dashscope.aliyuncs.com/compatible-mode/v1"
model_config = RedBearModelConfig(
model_name=model_name,
provider=ModelProvider.OPENAI,
api_key=api_key,
base_url=api_base,
temperature=0.7,
max_tokens=100
)
else:
model_config = RedBearModelConfig(
model_name=model_name,
provider=provider,
api_key=api_key,
base_url=api_base,
temperature=0.7,
max_tokens=100
)
# 根据模型类型选择不同的验证方式
model_type_lower = model_type.lower()
@@ -257,8 +272,9 @@ class ModelConfigService:
provider=model_data.provider,
api_key=api_key_data.api_key,
api_base=api_key_data.api_base,
model_type=model_data.type, # 传递模型类型
test_message="Hello"
model_type=model_data.type,
test_message="Hello",
is_omni=model_data.is_omni
)
if not validation_result["valid"]:
raise BusinessException(
@@ -279,6 +295,9 @@ class ModelConfigService:
for api_key_data in api_key_datas:
api_key_data.model_name = model_data.name
api_key_data.provider = model_data.provider
# 同步capability和is_omni
api_key_data.capability = model_data.capability
api_key_data.is_omni = model_data.is_omni
api_key_create_schema = ModelApiKeyCreate(
model_config_ids=[model.id],
**api_key_data.model_dump()
@@ -497,6 +516,8 @@ class ModelApiKeyService:
existing_key.config = data.config
existing_key.priority = data.priority
existing_key.model_name = model_name
existing_key.capability = data.capability
existing_key.is_omni = data.is_omni
# 检查是否已关联该模型配置
if model_config not in existing_key.model_configs:
@@ -513,7 +534,8 @@ class ModelApiKeyService:
api_key=data.api_key,
api_base=data.api_base,
model_type=model_config.type,
test_message="Hello"
test_message="Hello",
is_omni=data.is_omni
)
if not validation_result["valid"]:
# 记录验证失败的模型,但不抛出异常
@@ -528,6 +550,8 @@ class ModelApiKeyService:
provider=data.provider,
api_key=data.api_key,
api_base=data.api_base,
capability=data.capability if data.capability is not None else model_config.capability,
is_omni=data.is_omni if data.is_omni is not None else model_config.is_omni,
config=data.config,
is_active=data.is_active,
priority=data.priority
@@ -572,6 +596,8 @@ class ModelApiKeyService:
existing_key.config = api_key_data.config
existing_key.priority = api_key_data.priority
existing_key.model_name = api_key_data.model_name
existing_key.capability = api_key_data.capability
existing_key.is_omni = api_key_data.is_omni
# 检查是否已关联该模型配置
if model_config not in existing_key.model_configs:
@@ -589,7 +615,8 @@ class ModelApiKeyService:
api_key=api_key_data.api_key,
api_base=api_key_data.api_base,
model_type=model_config.type,
test_message="Hello"
test_message="Hello",
is_omni=model_config.is_omni
)
if not validation_result["valid"]:
raise BusinessException(
@@ -620,7 +647,8 @@ class ModelApiKeyService:
api_key=api_key_data.api_key or existing_api_key.api_key,
api_base=api_key_data.api_base or existing_api_key.api_base,
model_type=model_config.type,
test_message="Hello"
test_message="Hello",
is_omni=model_config.is_omni
)
if not validation_result["valid"]:
raise BusinessException(
@@ -755,6 +783,8 @@ class ModelBaseService:
"type": model_base.type,
"logo": model_base.logo,
"description": model_base.description,
"capability": model_base.capability,
"is_omni": model_base.is_omni,
"is_composite": False
}
model_config = ModelConfigRepository.create(db, model_config_data)

View File

@@ -2593,6 +2593,7 @@ class MultiAgentOrchestrator:
provider=api_key_config.provider,
api_key=api_key_config.api_key,
base_url=api_key_config.api_base,
is_omni=api_key_config.is_omni,
temperature=0.7, # 整合任务使用中等温度
max_tokens=2000
)
@@ -2758,6 +2759,7 @@ class MultiAgentOrchestrator:
provider=api_key_config.provider,
api_key=api_key_config.api_key,
base_url=api_key_config.api_base,
is_omni=api_key_config.is_omni,
temperature=0.7,
max_tokens=2000,
extra_params={"streaming": True} # 启用流式输出

View File

@@ -267,7 +267,7 @@ class MultiAgentService:
# 2. 验证模型配置(如果提供了)
if data.default_model_config_id:
model_api_key = ModelApiKeyService.get_a_api_key(self.db, data.default_model_config_id)
model_api_key = ModelApiKeyService.get_available_api_key(self.db, data.default_model_config_id)
if not model_api_key:
raise ResourceNotFoundException("模型配置", str(data.default_model_config_id))

View File

@@ -9,47 +9,100 @@
- OpenAI: 支持 URL 和 base64 格式
"""
import uuid
from typing import List, Dict, Any, Optional, Protocol
import httpx
import base64
from typing import List, Dict, Any, Optional
from abc import ABC, abstractmethod
from sqlalchemy.orm import Session
from docx import Document
import io
import PyPDF2
from app.core.logging_config import get_business_logger
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.schemas.app_schema import FileInput, FileType, TransferMethod
from app.models.generic_file_model import GenericFile
from app.models.file_metadata_model import FileMetadata
from app.core.config import settings
from app.services.audio_transcription_service import AudioTranscriptionService
logger = get_business_logger()
class ImageFormatStrategy(Protocol):
"""图片格式策略接口"""
class MultimodalFormatStrategy(ABC):
"""多模态格式策略基类"""
@abstractmethod
async def format_image(self, url: str) -> Dict[str, Any]:
"""格式化图片"""
pass
@abstractmethod
async def format_document(self, file_name: str, text: str) -> Dict[str, Any]:
"""格式化文档"""
pass
@abstractmethod
async def format_audio(self, file_type: str, url: str) -> Dict[str, Any]:
"""格式化音频"""
pass
@abstractmethod
async def format_video(self, url: str) -> Dict[str, Any]:
"""格式化视频"""
pass
class DashScopeFormatStrategy(MultimodalFormatStrategy):
"""通义千问策略"""
async def format_image(self, url: str) -> Dict[str, Any]:
"""将图片 URL 转换为特定 provider 的格式"""
...
class DashScopeImageStrategy:
"""通义千问图片格式策略"""
async def format_image(self, url: str) -> Dict[str, Any]:
"""通义千问格式: {"type": "image", "image": "url"}"""
"""通义千问图片格式:{"type": "image", "image": "url"}"""
return {
"type": "image",
"image": url
}
async def format_document(self, file_name: str, text: str) -> Dict[str, Any]:
"""通义千问文档格式"""
return {
"type": "text",
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
}
class BedrockImageStrategy:
"""Bedrock/Anthropic 图片格式策略"""
async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]:
"""
通义千问音频格式
- 原生支持: qwen-audio 系列
- 其他模型: 需要转录为文本
"""
if transcription:
return {
"type": "text",
"text": f"<audio url=\"{url}\">\n{transcription}\n</audio>"
}
# 通义千问音频格式:{"type": "audio", "audio": "url"}
return {
"type": "audio",
"audio": url
}
async def format_video(self, url: str) -> Dict[str, Any]:
"""通义千问视频格式qwen-vl 系列原生支持)"""
return {
"type": "video",
"video": url
}
class BedrockFormatStrategy(MultimodalFormatStrategy):
"""Bedrock/Anthropic 策略"""
async def format_image(self, url: str) -> Dict[str, Any]:
"""
Bedrock/Anthropic 格式: base64 编码
{"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}}
"""
import httpx
import base64
from mimetypes import guess_type
logger.info(f"下载并编码图片: {url}")
@@ -84,9 +137,46 @@ class BedrockImageStrategy:
}
}
async def format_document(self, file_name: str, text: str) -> Dict[str, Any]:
"""Bedrock/Anthropic 文档格式(需要 base64 编码)"""
# Bedrock 文档需要 base64 编码
text_bytes = text.encode('utf-8')
base64_text = base64.b64encode(text_bytes).decode('utf-8')
class OpenAIImageStrategy:
"""OpenAI 图片格式策略"""
return {
"type": "document",
"source": {
"type": "base64",
"media_type": "text/plain",
"data": base64_text
}
}
async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]:
"""
Bedrock/Anthropic 音频格式
不支持原生音频,必须转录为文本
"""
if transcription:
return {
"type": "text",
"text": f"[音频转录]\n{transcription}"
}
return {
"type": "text",
"text": "[音频文件Bedrock 不支持原生音频,请启用音频转文本功能]"
}
async def format_video(self, url: str) -> Dict[str, Any]:
"""Bedrock/Anthropic 视频格式"""
return {
"type": "text",
"text": f"<video url=\"{url}\">\n[视频文件,当前 provider 暂不支持]\n</video>"
}
class OpenAIFormatStrategy(MultimodalFormatStrategy):
"""OpenAI 策略"""
async def format_image(self, url: str) -> Dict[str, Any]:
"""OpenAI 格式: {"type": "image_url", "image_url": {"url": "..."}}"""
@@ -97,29 +187,97 @@ class OpenAIImageStrategy:
}
}
async def format_document(self, file_name: str, text: str) -> Dict[str, Any]:
"""OpenAI 文档格式"""
return {
"type": "text",
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
}
async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]:
"""
OpenAI 音频格式
- gpt-4o-audio 系列支持原生音频(需要 base64 编码)
- 其他模型使用转录文本
"""
if transcription:
return {
"type": "text",
"text": f"<audio url=\"{url}\">\n{transcription}\n</audio>"
}
# OpenAI 音频需要 base64 编码
try:
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(url)
response.raise_for_status()
audio_data = response.content
base64_audio = base64.b64encode(audio_data).decode('utf-8')
# 1. 优先从 file_type (MIME) 取扩展名
file_ext = file_type.split('/')[-1] if file_type and '/' in file_type else None
# 2. 从响应头 content-type 取
if not file_ext:
ct = response.headers.get("content-type", "")
file_ext = ct.split('/')[-1].split(';')[0].strip() if '/' in ct else None
# 3. 从 URL 路径取扩展名
if not file_ext:
file_ext = url.split('?')[0].rsplit('.', 1)[-1].lower() or None
# 4. 默认 wav
# supported_ext = {"wav", "mp3", "mp4", "ogg", "flac", "webm", "m4a", "wave", "x-m4a"}
file_ext = "wav" if not file_ext else file_ext
return {
"type": "input_audio",
"input_audio": {
"data": f"data:;base64,{base64_audio}",
"format": file_ext
}
}
except Exception as e:
logger.error(f"下载音频失败: {e}")
return {
"type": "text",
"text": f"[音频处理失败: {str(e)}]"
}
async def format_video(self, url: str) -> Dict[str, Any]:
"""OpenAI 视频格式"""
return {
"type": "video_url",
"video_url": {
"url": url
}
}
# Provider 到策略的映射
PROVIDER_STRATEGIES = {
"dashscope": DashScopeImageStrategy,
"bedrock": BedrockImageStrategy,
"anthropic": BedrockImageStrategy,
"openai": OpenAIImageStrategy,
"dashscope": DashScopeFormatStrategy,
"bedrock": BedrockFormatStrategy,
"anthropic": BedrockFormatStrategy,
"openai": OpenAIFormatStrategy,
}
class MultimodalService:
"""多模态文件处理服务"""
def __init__(self, db: Session, provider: str = "dashscope"):
def __init__(self, db: Session, provider: str = "dashscope", api_key: Optional[str] = None, enable_audio_transcription: bool = False, is_omni: bool = False):
"""
初始化多模态服务
Args:
db: 数据库会话
provider: 模型提供商dashscope, bedrock, anthropic 等)
provider: 模型提供商dashscope, bedrock, anthropic, openai 等)
api_key: API 密钥(用于音频转文本)
enable_audio_transcription: 是否启用音频转文本
is_omni: 是否为 Omni 模型dashscope 的 omni 模型需要使用 OpenAI 兼容格式)
"""
self.db = db
self.provider = provider.lower()
self.api_key = api_key
self.enable_audio_transcription = enable_audio_transcription
self.is_omni = is_omni
async def process_files(
self,
@@ -137,20 +295,32 @@ class MultimodalService:
if not files:
return []
# 获取对应的策略
# dashscope 的 omni 模型使用 OpenAI 兼容格式
if self.provider == "dashscope" and self.is_omni:
strategy_class = OpenAIFormatStrategy
else:
strategy_class = PROVIDER_STRATEGIES.get(self.provider)
if not strategy_class:
logger.warning(f"未找到 provider '{self.provider}' 的策略,使用默认策略")
strategy_class = DashScopeFormatStrategy
strategy = strategy_class()
result = []
for idx, file in enumerate(files):
try:
if file.type == FileType.IMAGE:
content = await self._process_image(file)
content = await self._process_image(file, strategy)
result.append(content)
elif file.type == FileType.DOCUMENT:
content = await self._process_document(file)
content = await self._process_document(file, strategy)
result.append(content)
elif file.type == FileType.AUDIO:
content = await self._process_audio(file)
content = await self._process_audio(file, strategy)
result.append(content)
elif file.type == FileType.VIDEO:
content = await self._process_video(file)
content = await self._process_video(file, strategy)
result.append(content)
else:
logger.warning(f"不支持的文件类型: {file.type}")
@@ -172,55 +342,29 @@ class MultimodalService:
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件provider={self.provider}")
return result
async def _process_image(self, file: FileInput) -> Dict[str, Any]:
async def _process_image(self, file: FileInput, strategy) -> Dict[str, Any]:
"""
处理图片文件
Args:
file: 图片文件输入
strategy: 格式化策略
Returns:
Dict: 根据 provider 返回不同格式
- Anthropic/Bedrock: {"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}}
- 通义千问: {"type": "image", "image": "url"}
Dict: 根据 provider 返回不同格式的图片内容
"""
url = await self.get_file_url(file)
logger.debug(f"处理图片: {url}, provider={self.provider}")
# 根据 provider 返回不同格式
if self.provider in ["bedrock", "anthropic"]:
# Anthropic/Bedrock 只支持 base64 格式,需要下载并转换
try:
logger.info(f"开始下载并编码图片: {url}")
base64_data, media_type = await self._download_and_encode_image(url)
result = {
"type": "image",
"source": {
"type": "base64",
"media_type": media_type,
"data": base64_data[:100] + "..." # 只记录前100个字符
}
}
logger.info(f"图片编码完成: media_type={media_type}, data_length={len(base64_data)}")
# 返回完整数据
result["source"]["data"] = base64_data
return result
except Exception as e:
logger.error(f"下载并编码图片失败: {e}", exc_info=True)
# 返回错误提示
return {
"type": "text",
"text": f"[图片加载失败: {str(e)}]"
}
else:
# 通义千问等其他格式支持 URL
try:
url = await self.get_file_url(file)
return await strategy.format_image(url)
except Exception as e:
logger.error(f"处理图片失败: {e}", exc_info=True)
return {
"type": "image",
"image": url
"type": "text",
"text": f"[图片处理失败: {str(e)}]"
}
async def _download_and_encode_image(self, url: str) -> tuple[str, str]:
@staticmethod
async def _download_and_encode_image(url: str) -> tuple[str, str]:
"""
下载图片并转换为 base64
@@ -230,8 +374,6 @@ class MultimodalService:
Returns:
tuple: (base64_data, media_type)
"""
import httpx
import base64
from mimetypes import guess_type
# 下载图片
@@ -258,15 +400,16 @@ class MultimodalService:
return base64_data, media_type
async def _process_document(self, file: FileInput) -> Dict[str, Any]:
async def _process_document(self, file: FileInput, strategy) -> Dict[str, Any]:
"""
处理文档文件PDF、Word 等)
Args:
file: 文档文件输入
strategy: 格式化策略
Returns:
Dict: text 格式的内容(包含提取的文本)
Dict: 根据 provider 返回不同格式的文档内容
"""
if file.transfer_method == TransferMethod.REMOTE_URL:
# 远程文档暂不支持提取
@@ -277,48 +420,68 @@ class MultimodalService:
else:
# 本地文件,提取文本内容
text = await self._extract_document_text(file.upload_file_id)
generic_file = self.db.query(GenericFile).filter(
GenericFile.id == file.upload_file_id
file_metadata = self.db.query(FileMetadata).filter(
FileMetadata.id == file.upload_file_id
).first()
file_name = generic_file.file_name if generic_file else "unknown"
file_name = file_metadata.file_name if file_metadata else "unknown"
return {
"type": "text",
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
}
# 使用策略格式化文档
return await strategy.format_document(file_name, text)
async def _process_audio(self, file: FileInput) -> Dict[str, Any]:
async def _process_audio(self, file: FileInput, strategy) -> Dict[str, Any]:
"""
处理音频文件
Args:
file: 音频文件输入
strategy: 格式化策略
Returns:
Dict: 音频内容(暂时返回占位符)
Dict: 根据 provider 返回不同格式的音频内容
"""
# TODO: 实现音频转文字功能
return {
"type": "text",
"text": "[音频文件,暂不支持处理]"
}
try:
url = await self.get_file_url(file)
async def _process_video(self, file: FileInput) -> Dict[str, Any]:
# 如果启用音频转文本且有 API Key
transcription = None
if self.enable_audio_transcription and self.api_key:
logger.info(f"开始音频转文本: {url}")
if self.provider == "dashscope":
transcription = await AudioTranscriptionService.transcribe_dashscope(url, self.api_key)
elif self.provider == "openai":
transcription = await AudioTranscriptionService.transcribe_openai(url, self.api_key)
else:
logger.warning(f"Provider {self.provider} 不支持音频转文本")
return await strategy.format_audio(file.file_type, url, transcription)
except Exception as e:
logger.error(f"处理音频失败: {e}", exc_info=True)
return {
"type": "text",
"text": f"[音频处理失败: {str(e)}]"
}
async def _process_video(self, file: FileInput, strategy) -> Dict[str, Any]:
"""
处理视频文件
Args:
file: 视频文件输入
strategy: 格式化策略
Returns:
Dict: 视频内容(暂时返回占位符)
Dict: 根据 provider 返回不同格式的视频内容
"""
# TODO: 实现视频处理功能
return {
"type": "text",
"text": "[视频文件,暂不支持处理]"
}
try:
url = await self.get_file_url(file)
return await strategy.format_video(url)
except Exception as e:
logger.error(f"处理视频失败: {e}", exc_info=True)
return {
"type": "text",
"text": f"[视频处理失败: {str(e)}]"
}
async def get_file_url(self, file: FileInput) -> str:
"""
@@ -336,26 +499,22 @@ class MultimodalService:
if file.transfer_method == TransferMethod.REMOTE_URL:
return file.url
else:
# 本地文件,通过 file_storage 系统获取永久访问 URL
from app.models.file_metadata_model import FileMetadata
from app.core.config import settings
file_id = file.upload_file_id
print("="*50)
print("file_id",file_id)
# 查询 FileMetadata
file_metadata = self.db.query(FileMetadata).filter(
FileMetadata.id == file_id,
FileMetadata.status == "completed"
).first()
if not file_metadata:
raise BusinessException(
f"文件不存在或已删除: {file_id}",
BizCode.NOT_FOUND
)
# 返回永久URL
server_url = settings.FILE_LOCAL_SERVER_URL
return f"{server_url}/storage/permanent/{file_id}"
@@ -370,58 +529,79 @@ class MultimodalService:
Returns:
str: 提取的文本内容
"""
generic_file = self.db.query(GenericFile).filter(
GenericFile.id == file_id,
GenericFile.status == "active"
file_metadata = self.db.query(FileMetadata).filter(
FileMetadata.id == file_id,
FileMetadata.status == "completed"
).first()
if not generic_file:
if not file_metadata:
raise BusinessException(
f"文件不存在或已删除: {file_id}",
BizCode.NOT_FOUND
)
# TODO: 根据文件类型提取文本
# - PDF: 使用 PyPDF2 或 pdfplumber
# - Word: 使用 python-docx
# - TXT/MD: 直接读取
file_ext = generic_file.file_ext.lower()
file_ext = file_metadata.file_ext.lower()
server_url = settings.FILE_LOCAL_SERVER_URL
file_url = f"{server_url}/storage/permanent/{file_id}"
if file_ext in ['.txt', '.md', '.markdown']:
return await self._read_text_file(generic_file.storage_path)
return await self._read_text_file(file_url)
elif file_ext == '.pdf':
return await self._extract_pdf_text(generic_file.storage_path)
return await self._extract_pdf_text(file_url)
elif file_ext in ['.doc', '.docx']:
return await self._extract_word_text(generic_file.storage_path)
return await self._extract_word_text(file_url)
else:
return f"[不支持的文档格式: {file_ext}]"
async def _read_text_file(self, storage_path: str) -> str:
@staticmethod
async def _read_text_file(file_url: str) -> str:
"""读取纯文本文件"""
try:
with open(storage_path, 'r', encoding='utf-8') as f:
return f.read()
# 下载文件
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(file_url)
response.raise_for_status()
return response.text
except Exception as e:
logger.error(f"读取文本文件失败: {e}")
return f"[文件读取失败: {str(e)}]"
async def _extract_pdf_text(self, storage_path: str) -> str:
@staticmethod
async def _extract_pdf_text(file_url: str) -> str:
"""提取 PDF 文本"""
try:
# TODO: 实现 PDF 文本提取
# import PyPDF2 或 pdfplumber
return "[PDF 文本提取功能待实现]"
# 下载 PDF 文
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(file_url)
response.raise_for_status()
pdf_data = response.content
# 使用 BytesIO 读取 PDF
text_parts = []
pdf_file = io.BytesIO(pdf_data)
pdf_reader = PyPDF2.PdfReader(pdf_file)
for page in pdf_reader.pages:
text_parts.append(page.extract_text())
return '\n'.join(text_parts)
except Exception as e:
logger.error(f"提取 PDF 文本失败: {e}")
return f"[PDF 提取失败: {str(e)}]"
async def _extract_word_text(self, storage_path: str) -> str:
@staticmethod
async def _extract_word_text(file_url: str) -> str:
"""提取 Word 文档文本"""
try:
# TODO: 实现 Word 文本提取
# import docx
return "[Word 文本提取功能待实现]"
# 下载 Word 文
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(file_url)
response.raise_for_status()
word_data = response.content
# 使用 BytesIO 读取 Word 文档
word_file = io.BytesIO(word_data)
doc = Document(word_file)
text_parts = [paragraph.text for paragraph in doc.paragraphs]
return '\n'.join(text_parts)
except Exception as e:
logger.error(f"提取 Word 文本失败: {e}")
return f"[Word 提取失败: {str(e)}]"

View File

@@ -184,7 +184,8 @@ class PromptOptimizerService:
model_name=api_config.model_name,
provider=api_config.provider,
api_key=api_config.api_key,
base_url=api_config.api_base
base_url=api_config.api_base,
is_omni=api_config.is_omni
), type=ModelType(model_config.type))
try:
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')

View File

@@ -247,6 +247,7 @@ class SharedChatService:
api_key=api_key_obj.api_key,
provider=api_key_obj.provider,
api_base=api_key_obj.api_base,
is_omni=api_key_obj.is_omni,
temperature=model_parameters.get("temperature", 0.7),
max_tokens=model_parameters.get("max_tokens", 2000),
system_prompt=system_prompt,
@@ -454,6 +455,7 @@ class SharedChatService:
api_key=api_key_obj.api_key,
provider=api_key_obj.provider,
api_base=api_key_obj.api_base,
is_omni=api_key_obj.is_omni,
temperature=model_parameters.get("temperature", 0.7),
max_tokens=model_parameters.get("max_tokens", 2000),
system_prompt=system_prompt,