[add] multimodal
This commit is contained in:
@@ -3,7 +3,7 @@ import asyncio
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional, Dict, Any, AsyncGenerator, Annotated
|
||||
from typing import Optional, Dict, Any, AsyncGenerator, Annotated, List
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -15,6 +15,7 @@ from app.core.logging_config import get_business_logger
|
||||
from app.db import get_db, get_db_context
|
||||
from app.models import MultiAgentConfig, AgentConfig, WorkflowConfig
|
||||
from app.schemas import DraftRunRequest
|
||||
from app.schemas.app_schema import FileInput
|
||||
from app.services.tool_service import ToolService
|
||||
from app.repositories.tool_repository import ToolRepository
|
||||
from app.db import get_db
|
||||
@@ -26,6 +27,7 @@ from app.services.draft_run_service import create_web_search_tool
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
|
||||
from app.services.workflow_service import WorkflowService
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
@@ -48,7 +50,8 @@ class AppChatService:
|
||||
memory: bool = True,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
workspace_id: Optional[str] = None
|
||||
workspace_id: Optional[str] = None,
|
||||
files: Optional[List[FileInput]] = None # 新增:多模态文件
|
||||
) -> Dict[str, Any]:
|
||||
"""聊天(非流式)"""
|
||||
|
||||
@@ -155,7 +158,14 @@ class AppChatService:
|
||||
for msg in messages
|
||||
]
|
||||
|
||||
# 调用 Agent
|
||||
# 处理多模态文件
|
||||
processed_files = None
|
||||
if files:
|
||||
multimodal_service = MultimodalService(self.db)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||
|
||||
# 调用 Agent(支持多模态)
|
||||
result = await agent.chat(
|
||||
message=message,
|
||||
history=history,
|
||||
@@ -164,7 +174,8 @@ class AppChatService:
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
config_id=config_id,
|
||||
memory_flag=memory_flag
|
||||
memory_flag=memory_flag,
|
||||
files=processed_files # 传递处理后的文件
|
||||
)
|
||||
|
||||
# 保存消息
|
||||
@@ -199,6 +210,7 @@ class AppChatService:
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
workspace_id: Optional[str] = None,
|
||||
files: Optional[List[FileInput]] = None # 新增:多模态文件
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""聊天(流式)"""
|
||||
|
||||
@@ -305,10 +317,17 @@ class AppChatService:
|
||||
for msg in messages
|
||||
]
|
||||
|
||||
# 处理多模态文件
|
||||
processed_files = None
|
||||
if files:
|
||||
multimodal_service = MultimodalService(self.db)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||
|
||||
# 发送开始事件
|
||||
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 流式调用 Agent
|
||||
# 流式调用 Agent(支持多模态)
|
||||
full_content = ""
|
||||
async for chunk in agent.chat_stream(
|
||||
message=message,
|
||||
@@ -318,7 +337,8 @@ class AppChatService:
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
config_id=config_id,
|
||||
memory_flag=memory_flag
|
||||
memory_flag=memory_flag,
|
||||
files=processed_files # 传递处理后的文件
|
||||
):
|
||||
full_content += chunk
|
||||
# 发送消息块事件
|
||||
|
||||
@@ -19,11 +19,13 @@ from app.models import AgentConfig, ModelApiKey, ModelConfig
|
||||
from app.repositories.model_repository import ModelApiKeyRepository
|
||||
from app.repositories.tool_repository import ToolRepository
|
||||
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
|
||||
from app.schemas.app_schema import FileInput
|
||||
from app.services import task_service
|
||||
from app.services.langchain_tool_server import Search
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.services.model_parameter_merger import ModelParameterMerger
|
||||
from app.services.tool_service import ToolService
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
from langchain.tools import tool
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
@@ -246,7 +248,8 @@ class DraftRunService:
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
web_search: bool = True,
|
||||
memory: bool = True,
|
||||
sub_agent: bool = False
|
||||
sub_agent: bool = False,
|
||||
files: Optional[List[FileInput]] = None # 新增:多模态文件
|
||||
) -> Dict[str, Any]:
|
||||
"""执行试运行(使用 LangChain Agent)
|
||||
|
||||
@@ -406,7 +409,14 @@ class DraftRunService:
|
||||
max_history=agent_config.memory.get("max_history", 10)
|
||||
)
|
||||
|
||||
# 6. 知识库检索
|
||||
# 6. 处理多模态文件
|
||||
processed_files = None
|
||||
if files:
|
||||
multimodal_service = MultimodalService(self.db)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||
|
||||
# 7. 知识库检索
|
||||
context = None
|
||||
|
||||
logger.debug(
|
||||
@@ -414,14 +424,15 @@ class DraftRunService:
|
||||
extra={
|
||||
"model": api_key_config["model_name"],
|
||||
"has_history": bool(history),
|
||||
"has_context": bool(context)
|
||||
"has_context": bool(context),
|
||||
"has_files": bool(processed_files)
|
||||
}
|
||||
)
|
||||
|
||||
memory_config_= agent_config.memory
|
||||
config_id = memory_config_.get("memory_content") or memory_config_.get("memory_config",None)
|
||||
|
||||
# 7. 调用 Agent
|
||||
# 8. 调用 Agent(支持多模态)
|
||||
result = await agent.chat(
|
||||
message=message,
|
||||
history=history,
|
||||
@@ -430,12 +441,13 @@ class DraftRunService:
|
||||
config_id=config_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_flag=memory_flag
|
||||
memory_flag=memory_flag,
|
||||
files=processed_files # 传递处理后的文件
|
||||
)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# 8. 保存会话消息
|
||||
# 9. 保存会话消息
|
||||
if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"):
|
||||
await self._save_conversation_message(
|
||||
conversation_id=conversation_id,
|
||||
@@ -486,7 +498,8 @@ class DraftRunService:
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
web_search: bool = True, # 布尔类型默认值
|
||||
memory: bool = True, # 布尔类型默认值
|
||||
sub_agent: bool = False # 是否是作为子Agent运行
|
||||
sub_agent: bool = False, # 是否是作为子Agent运行
|
||||
files: Optional[List[FileInput]] = None # 新增:多模态文件
|
||||
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""执行试运行(流式返回,使用 LangChain Agent)
|
||||
@@ -635,6 +648,13 @@ class DraftRunService:
|
||||
max_history=agent_config.memory.get("max_history", 10)
|
||||
)
|
||||
|
||||
# 6. 处理多模态文件
|
||||
processed_files = None
|
||||
if files:
|
||||
multimodal_service = MultimodalService(self.db)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||
|
||||
# 7. 知识库检索
|
||||
context = None
|
||||
|
||||
@@ -647,7 +667,7 @@ class DraftRunService:
|
||||
memory_config_ = agent_config.memory
|
||||
config_id = memory_config_.get("memory_content") or memory_config_.get("memory_config",None)
|
||||
|
||||
# 9. 流式调用 Agent
|
||||
# 9. 流式调用 Agent(支持多模态)
|
||||
full_content = ""
|
||||
async for chunk in agent.chat_stream(
|
||||
message=message,
|
||||
@@ -657,7 +677,8 @@ class DraftRunService:
|
||||
config_id=config_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_flag=memory_flag
|
||||
memory_flag=memory_flag,
|
||||
files=processed_files # 传递处理后的文件
|
||||
):
|
||||
full_content += chunk
|
||||
# 发送消息块事件
|
||||
|
||||
268
api/app/services/multimodal_service.py
Normal file
268
api/app/services/multimodal_service.py
Normal file
@@ -0,0 +1,268 @@
|
||||
"""
|
||||
多模态文件处理服务
|
||||
|
||||
处理图片、文档等多模态文件,转换为 LLM 可用的格式
|
||||
|
||||
格式说明:
|
||||
- 当前使用通义千问格式
|
||||
- 通义千问格式: {"type": "image", "image": "url"}
|
||||
"""
|
||||
import uuid
|
||||
from typing import List, Dict, Any, Optional
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.schemas.app_schema import FileInput, FileType, TransferMethod
|
||||
from app.models.generic_file_model import GenericFile
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class MultimodalService:
|
||||
"""多模态文件处理服务"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
async def process_files(
|
||||
self,
|
||||
files: Optional[List[FileInput]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
处理文件列表,返回 LLM 可用的格式
|
||||
|
||||
Args:
|
||||
files: 文件输入列表
|
||||
|
||||
Returns:
|
||||
List[Dict]: LLM 可用的内容格式列表
|
||||
"""
|
||||
if not files:
|
||||
return []
|
||||
|
||||
result = []
|
||||
for idx, file in enumerate(files):
|
||||
try:
|
||||
if file.type == FileType.IMAGE:
|
||||
content = await self._process_image(file)
|
||||
result.append(content)
|
||||
elif file.type == FileType.DOCUMENT:
|
||||
content = await self._process_document(file)
|
||||
result.append(content)
|
||||
elif file.type == FileType.AUDIO:
|
||||
content = await self._process_audio(file)
|
||||
result.append(content)
|
||||
elif file.type == FileType.VIDEO:
|
||||
content = await self._process_video(file)
|
||||
result.append(content)
|
||||
else:
|
||||
logger.warning(f"不支持的文件类型: {file.type}")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"处理文件失败",
|
||||
extra={
|
||||
"file_index": idx,
|
||||
"file_type": file.type,
|
||||
"error": str(e)
|
||||
}
|
||||
)
|
||||
# 继续处理其他文件,不中断整个流程
|
||||
result.append({
|
||||
"type": "text",
|
||||
"text": f"[文件处理失败: {str(e)}]"
|
||||
})
|
||||
|
||||
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件")
|
||||
return result
|
||||
|
||||
async def _process_image(self, file: FileInput) -> Dict[str, Any]:
|
||||
"""
|
||||
处理图片文件
|
||||
|
||||
Args:
|
||||
file: 图片文件输入
|
||||
|
||||
Returns:
|
||||
Dict: 通义千问格式
|
||||
"""
|
||||
if file.transfer_method == TransferMethod.REMOTE_URL:
|
||||
# 远程 URL,使用通义千问格式
|
||||
return {
|
||||
"type": "image",
|
||||
"image": file.url
|
||||
}
|
||||
else:
|
||||
# 本地文件,获取访问 URL
|
||||
url = await self._get_file_url(file.upload_file_id)
|
||||
return {
|
||||
"type": "image",
|
||||
"image": url
|
||||
}
|
||||
|
||||
async def _process_document(self, file: FileInput) -> Dict[str, Any]:
|
||||
"""
|
||||
处理文档文件(PDF、Word 等)
|
||||
|
||||
Args:
|
||||
file: 文档文件输入
|
||||
|
||||
Returns:
|
||||
Dict: text 格式的内容(包含提取的文本)
|
||||
"""
|
||||
if file.transfer_method == TransferMethod.REMOTE_URL:
|
||||
# 远程文档暂不支持提取
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"<document url=\"{file.url}\">\n[远程文档,暂不支持内容提取]\n</document>"
|
||||
}
|
||||
else:
|
||||
# 本地文件,提取文本内容
|
||||
text = await self._extract_document_text(file.upload_file_id)
|
||||
generic_file = self.db.query(GenericFile).filter(
|
||||
GenericFile.id == file.upload_file_id
|
||||
).first()
|
||||
|
||||
file_name = generic_file.file_name if generic_file else "unknown"
|
||||
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
|
||||
}
|
||||
|
||||
async def _process_audio(self, file: FileInput) -> Dict[str, Any]:
|
||||
"""
|
||||
处理音频文件
|
||||
|
||||
Args:
|
||||
file: 音频文件输入
|
||||
|
||||
Returns:
|
||||
Dict: 音频内容(暂时返回占位符)
|
||||
"""
|
||||
# TODO: 实现音频转文字功能
|
||||
return {
|
||||
"type": "text",
|
||||
"text": "[音频文件,暂不支持处理]"
|
||||
}
|
||||
|
||||
async def _process_video(self, file: FileInput) -> Dict[str, Any]:
|
||||
"""
|
||||
处理视频文件
|
||||
|
||||
Args:
|
||||
file: 视频文件输入
|
||||
|
||||
Returns:
|
||||
Dict: 视频内容(暂时返回占位符)
|
||||
"""
|
||||
# TODO: 实现视频处理功能
|
||||
return {
|
||||
"type": "text",
|
||||
"text": "[视频文件,暂不支持处理]"
|
||||
}
|
||||
|
||||
async def _get_file_url(self, file_id: uuid.UUID) -> str:
|
||||
"""
|
||||
获取文件的访问 URL
|
||||
|
||||
Args:
|
||||
file_id: 文件ID
|
||||
|
||||
Returns:
|
||||
str: 文件访问 URL
|
||||
|
||||
Raises:
|
||||
BusinessException: 文件不存在
|
||||
"""
|
||||
generic_file = self.db.query(GenericFile).filter(
|
||||
GenericFile.id == file_id,
|
||||
GenericFile.status == "active"
|
||||
).first()
|
||||
|
||||
if not generic_file:
|
||||
raise BusinessException(
|
||||
f"文件不存在或已删除: {file_id}",
|
||||
BizCode.NOT_FOUND
|
||||
)
|
||||
|
||||
# 如果有 access_url,直接返回
|
||||
if generic_file.access_url:
|
||||
return generic_file.access_url
|
||||
|
||||
# 否则,根据 storage_path 生成 URL
|
||||
# TODO: 根据实际存储方式生成 URL(本地存储、OSS 等)
|
||||
# 这里暂时返回一个占位 URL
|
||||
return f"/api/files/{file_id}/download"
|
||||
|
||||
async def _extract_document_text(self, file_id: uuid.UUID) -> str:
|
||||
"""
|
||||
提取文档文本内容
|
||||
|
||||
Args:
|
||||
file_id: 文件ID
|
||||
|
||||
Returns:
|
||||
str: 提取的文本内容
|
||||
"""
|
||||
generic_file = self.db.query(GenericFile).filter(
|
||||
GenericFile.id == file_id,
|
||||
GenericFile.status == "active"
|
||||
).first()
|
||||
|
||||
if not generic_file:
|
||||
raise BusinessException(
|
||||
f"文件不存在或已删除: {file_id}",
|
||||
BizCode.NOT_FOUND
|
||||
)
|
||||
|
||||
# TODO: 根据文件类型提取文本
|
||||
# - PDF: 使用 PyPDF2 或 pdfplumber
|
||||
# - Word: 使用 python-docx
|
||||
# - TXT/MD: 直接读取
|
||||
|
||||
file_ext = generic_file.file_ext.lower()
|
||||
|
||||
if file_ext in ['.txt', '.md', '.markdown']:
|
||||
return await self._read_text_file(generic_file.storage_path)
|
||||
elif file_ext == '.pdf':
|
||||
return await self._extract_pdf_text(generic_file.storage_path)
|
||||
elif file_ext in ['.doc', '.docx']:
|
||||
return await self._extract_word_text(generic_file.storage_path)
|
||||
else:
|
||||
return f"[不支持的文档格式: {file_ext}]"
|
||||
|
||||
async def _read_text_file(self, storage_path: str) -> str:
|
||||
"""读取纯文本文件"""
|
||||
try:
|
||||
with open(storage_path, 'r', encoding='utf-8') as f:
|
||||
return f.read()
|
||||
except Exception as e:
|
||||
logger.error(f"读取文本文件失败: {e}")
|
||||
return f"[文件读取失败: {str(e)}]"
|
||||
|
||||
async def _extract_pdf_text(self, storage_path: str) -> str:
|
||||
"""提取 PDF 文本"""
|
||||
try:
|
||||
# TODO: 实现 PDF 文本提取
|
||||
# import PyPDF2 或 pdfplumber
|
||||
return "[PDF 文本提取功能待实现]"
|
||||
except Exception as e:
|
||||
logger.error(f"提取 PDF 文本失败: {e}")
|
||||
return f"[PDF 提取失败: {str(e)}]"
|
||||
|
||||
async def _extract_word_text(self, storage_path: str) -> str:
|
||||
"""提取 Word 文档文本"""
|
||||
try:
|
||||
# TODO: 实现 Word 文本提取
|
||||
# import docx
|
||||
return "[Word 文本提取功能待实现]"
|
||||
except Exception as e:
|
||||
logger.error(f"提取 Word 文本失败: {e}")
|
||||
return f"[Word 提取失败: {str(e)}]"
|
||||
|
||||
|
||||
def get_multimodal_service(db: Session) -> MultimodalService:
|
||||
"""获取多模态服务实例(依赖注入)"""
|
||||
return MultimodalService(db)
|
||||
Reference in New Issue
Block a user