[add] multimodal
This commit is contained in:
@@ -454,7 +454,8 @@ async def draft_run(
|
||||
user_id=payload.user_id or str(current_user.id),
|
||||
variables=payload.variables,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
files=payload.files # 传递多模态文件
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -475,7 +476,8 @@ async def draft_run(
|
||||
"app_id": str(app_id),
|
||||
"message_length": len(payload.message),
|
||||
"has_conversation_id": bool(payload.conversation_id),
|
||||
"has_variables": bool(payload.variables)
|
||||
"has_variables": bool(payload.variables),
|
||||
"has_files": bool(payload.files)
|
||||
}
|
||||
)
|
||||
|
||||
@@ -490,7 +492,8 @@ async def draft_run(
|
||||
user_id=payload.user_id or str(current_user.id),
|
||||
variables=payload.variables,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
files=payload.files # 传递多模态文件
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
|
||||
@@ -438,7 +438,8 @@ async def chat(
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
workspace_id=workspace_id
|
||||
workspace_id=workspace_id,
|
||||
files=payload.files # 传递多模态文件
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -475,7 +476,8 @@ async def chat(
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
workspace_id=workspace_id
|
||||
workspace_id=workspace_id,
|
||||
files=payload.files # 传递多模态文件
|
||||
)
|
||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
elif app_type == AppType.MULTI_AGENT:
|
||||
|
||||
@@ -155,7 +155,8 @@ async def chat(
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
workspace_id=workspace_id
|
||||
workspace_id=workspace_id,
|
||||
files=payload.files # 传递多模态文件
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -180,7 +181,8 @@ async def chat(
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
workspace_id=workspace_id
|
||||
workspace_id=workspace_id,
|
||||
files=payload.files # 传递多模态文件
|
||||
)
|
||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
elif app_type == AppType.MULTI_AGENT:
|
||||
|
||||
@@ -45,7 +45,8 @@ class LangChainAgent:
|
||||
max_tokens: int = 2000,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[Sequence[BaseTool]] = None,
|
||||
streaming: bool = False
|
||||
streaming: bool = False,
|
||||
max_iterations: int = 3 # 新增:最大迭代次数
|
||||
):
|
||||
"""初始化 LangChain Agent
|
||||
|
||||
@@ -59,12 +60,14 @@ class LangChainAgent:
|
||||
system_prompt: 系统提示词
|
||||
tools: 工具列表(可选,框架自动走 ReAct 循环)
|
||||
streaming: 是否启用流式输出(默认 True)
|
||||
max_iterations: 最大迭代次数(默认 3 次,防止工具调用死循环)
|
||||
"""
|
||||
self.model_name = model_name
|
||||
self.provider = provider
|
||||
self.system_prompt = system_prompt or "你是一个专业的AI助手"
|
||||
self.tools = tools or []
|
||||
self.streaming = streaming
|
||||
self.max_iterations = max_iterations
|
||||
|
||||
# 创建 RedBearLLM(支持多提供商)
|
||||
model_config = RedBearModelConfig(
|
||||
@@ -104,9 +107,9 @@ class LangChainAgent:
|
||||
"has_api_base": bool(api_base),
|
||||
"temperature": temperature,
|
||||
"streaming": streaming,
|
||||
"max_iterations": max_iterations,
|
||||
"tool_count": len(self.tools),
|
||||
"tool_names": [tool.name for tool in self.tools] if self.tools else [],
|
||||
"tool_count": len(self.tools)
|
||||
"tool_names": [tool.name for tool in self.tools] if self.tools else []
|
||||
}
|
||||
)
|
||||
|
||||
@@ -114,7 +117,8 @@ class LangChainAgent:
|
||||
self,
|
||||
message: str,
|
||||
history: Optional[List[Dict[str, str]]] = None,
|
||||
context: Optional[str] = None
|
||||
context: Optional[str] = None,
|
||||
files: Optional[List[Dict[str, Any]]] = None
|
||||
) -> List[BaseMessage]:
|
||||
"""准备消息列表
|
||||
|
||||
@@ -122,6 +126,7 @@ class LangChainAgent:
|
||||
message: 用户消息
|
||||
history: 历史消息列表
|
||||
context: 上下文信息
|
||||
files: 多模态文件内容列表(已处理)
|
||||
|
||||
Returns:
|
||||
List[BaseMessage]: 消息列表
|
||||
@@ -144,7 +149,15 @@ class LangChainAgent:
|
||||
if context:
|
||||
user_content = f"参考信息:\n{context}\n\n用户问题:\n{user_content}"
|
||||
|
||||
messages.append(HumanMessage(content=user_content))
|
||||
# 如果有文件,构建多模态消息
|
||||
if files and len(files) > 0:
|
||||
# 多模态消息格式: [{"type": "text", "text": "..."}, {"type": "image_url", ...}]
|
||||
content_parts = [{"type": "text", "text": user_content}]
|
||||
content_parts.extend(files)
|
||||
messages.append(HumanMessage(content=content_parts))
|
||||
else:
|
||||
# 纯文本消息(向后兼容)
|
||||
messages.append(HumanMessage(content=user_content))
|
||||
|
||||
return messages
|
||||
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||
@@ -254,7 +267,8 @@ class LangChainAgent:
|
||||
config_id: Optional[str] = None, # 添加这个参数
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
memory_flag: Optional[bool] = True
|
||||
memory_flag: Optional[bool] = True,
|
||||
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
|
||||
) -> Dict[str, Any]:
|
||||
"""执行对话
|
||||
|
||||
@@ -313,8 +327,8 @@ class LangChainAgent:
|
||||
# await self.write(storage_type, actual_end_user_id, history_term_memory, "", user_rag_memory_id, actual_end_user_id, actual_config_id)
|
||||
# # 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
||||
try:
|
||||
# 准备消息列表
|
||||
messages = self._prepare_messages(message, history, context)
|
||||
# 准备消息列表(支持多模态)
|
||||
messages = self._prepare_messages(message, history, context, files)
|
||||
|
||||
logger.debug(
|
||||
"准备调用 LangChain Agent",
|
||||
@@ -322,20 +336,79 @@ class LangChainAgent:
|
||||
"has_context": bool(context),
|
||||
"has_history": bool(history),
|
||||
"has_tools": bool(self.tools),
|
||||
"message_count": len(messages)
|
||||
"has_files": bool(files),
|
||||
"message_count": len(messages),
|
||||
"max_iterations": self.max_iterations
|
||||
}
|
||||
)
|
||||
|
||||
# 统一使用 agent.invoke 调用
|
||||
result = await self.agent.ainvoke({"messages": messages})
|
||||
# 通过 recursion_limit 限制最大迭代次数,防止工具调用死循环
|
||||
try:
|
||||
result = await self.agent.ainvoke(
|
||||
{"messages": messages},
|
||||
config={"recursion_limit": self.max_iterations}
|
||||
)
|
||||
except RecursionError as e:
|
||||
logger.warning(
|
||||
f"Agent 达到最大迭代次数限制 ({self.max_iterations}),可能存在工具调用循环",
|
||||
extra={"error": str(e)}
|
||||
)
|
||||
# 返回一个友好的错误提示
|
||||
return {
|
||||
"content": f"抱歉,我在处理您的请求时遇到了问题。已达到最大处理步骤限制({self.max_iterations}次)。请尝试简化您的问题或稍后再试。",
|
||||
"model": self.model_name,
|
||||
"elapsed_time": time.time() - start_time,
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
}
|
||||
}
|
||||
|
||||
# 获取最后的 AI 消息
|
||||
output_messages = result.get("messages", [])
|
||||
content = ""
|
||||
|
||||
logger.debug(f"输出消息数量: {len(output_messages)}")
|
||||
|
||||
for msg in reversed(output_messages):
|
||||
if isinstance(msg, AIMessage):
|
||||
content = msg.content
|
||||
logger.debug(f"找到 AI 消息,content 类型: {type(msg.content)}")
|
||||
logger.debug(f"AI 消息内容: {msg.content}")
|
||||
|
||||
# 处理多模态响应:content 可能是字符串或列表
|
||||
if isinstance(msg.content, str):
|
||||
content = msg.content
|
||||
logger.debug(f"提取字符串内容,长度: {len(content)}")
|
||||
elif isinstance(msg.content, list):
|
||||
# 多模态响应:提取文本部分
|
||||
logger.debug(f"多模态响应,列表长度: {len(msg.content)}")
|
||||
text_parts = []
|
||||
for item in msg.content:
|
||||
logger.debug(f"处理项: {item}")
|
||||
if isinstance(item, dict):
|
||||
# 通义千问格式: {"text": "..."}
|
||||
if "text" in item:
|
||||
text = item.get("text", "")
|
||||
text_parts.append(text)
|
||||
logger.debug(f"提取文本: {text[:100]}...")
|
||||
# OpenAI 格式: {"type": "text", "text": "..."}
|
||||
elif item.get("type") == "text":
|
||||
text = item.get("text", "")
|
||||
text_parts.append(text)
|
||||
logger.debug(f"提取文本: {text[:100]}...")
|
||||
elif isinstance(item, str):
|
||||
text_parts.append(item)
|
||||
logger.debug(f"提取字符串: {item[:100]}...")
|
||||
content = "".join(text_parts)
|
||||
logger.debug(f"合并后内容长度: {len(content)}")
|
||||
else:
|
||||
content = str(msg.content)
|
||||
logger.debug(f"转换为字符串: {content[:100]}...")
|
||||
break
|
||||
|
||||
logger.info(f"最终提取的内容长度: {len(content)}")
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
if memory_flag:
|
||||
@@ -377,7 +450,8 @@ class LangChainAgent:
|
||||
config_id: Optional[str] = None,
|
||||
storage_type:Optional[str] = None,
|
||||
user_rag_memory_id:Optional[str] = None,
|
||||
memory_flag: Optional[bool] = True
|
||||
memory_flag: Optional[bool] = True,
|
||||
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""执行流式对话
|
||||
|
||||
@@ -432,11 +506,11 @@ class LangChainAgent:
|
||||
|
||||
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
||||
try:
|
||||
# 准备消息列表
|
||||
messages = self._prepare_messages(message, history, context)
|
||||
# 准备消息列表(支持多模态)
|
||||
messages = self._prepare_messages(message, history, context, files)
|
||||
|
||||
logger.debug(
|
||||
f"准备流式调用,has_tools={bool(self.tools)}, message_count={len(messages)}"
|
||||
f"准备流式调用,has_tools={bool(self.tools)}, has_files={bool(files)}, message_count={len(messages)}"
|
||||
)
|
||||
|
||||
chunk_count = 0
|
||||
@@ -448,7 +522,8 @@ class LangChainAgent:
|
||||
try:
|
||||
async for event in self.agent.astream_events(
|
||||
{"messages": messages},
|
||||
version="v2"
|
||||
version="v2",
|
||||
config={"recursion_limit": self.max_iterations}
|
||||
):
|
||||
chunk_count += 1
|
||||
kind = event.get("event")
|
||||
@@ -457,20 +532,70 @@ class LangChainAgent:
|
||||
if kind == "on_chat_model_stream":
|
||||
# LLM 流式输出
|
||||
chunk = event.get("data", {}).get("chunk")
|
||||
full_content+=chunk.content
|
||||
if chunk and hasattr(chunk, "content") and chunk.content:
|
||||
yield chunk.content
|
||||
yielded_content = True
|
||||
if chunk and hasattr(chunk, "content"):
|
||||
# 处理多模态响应:content 可能是字符串或列表
|
||||
chunk_content = chunk.content
|
||||
if isinstance(chunk_content, str) and chunk_content:
|
||||
full_content += chunk_content
|
||||
yield chunk_content
|
||||
yielded_content = True
|
||||
elif isinstance(chunk_content, list):
|
||||
# 多模态响应:提取文本部分
|
||||
for item in chunk_content:
|
||||
if isinstance(item, dict):
|
||||
# 通义千问格式: {"text": "..."}
|
||||
if "text" in item:
|
||||
text = item.get("text", "")
|
||||
if text:
|
||||
full_content += text
|
||||
yield text
|
||||
yielded_content = True
|
||||
# OpenAI 格式: {"type": "text", "text": "..."}
|
||||
elif item.get("type") == "text":
|
||||
text = item.get("text", "")
|
||||
if text:
|
||||
full_content += text
|
||||
yield text
|
||||
yielded_content = True
|
||||
elif isinstance(item, str):
|
||||
full_content += item
|
||||
yield item
|
||||
yielded_content = True
|
||||
|
||||
elif kind == "on_llm_stream":
|
||||
# 另一种 LLM 流式事件
|
||||
chunk = event.get("data", {}).get("chunk")
|
||||
if chunk:
|
||||
if hasattr(chunk, "content") and chunk.content:
|
||||
full_content+=chunk.content
|
||||
yield chunk.content
|
||||
yielded_content = True
|
||||
if hasattr(chunk, "content"):
|
||||
chunk_content = chunk.content
|
||||
if isinstance(chunk_content, str) and chunk_content:
|
||||
full_content += chunk_content
|
||||
yield chunk_content
|
||||
yielded_content = True
|
||||
elif isinstance(chunk_content, list):
|
||||
# 多模态响应:提取文本部分
|
||||
for item in chunk_content:
|
||||
if isinstance(item, dict):
|
||||
# 通义千问格式: {"text": "..."}
|
||||
if "text" in item:
|
||||
text = item.get("text", "")
|
||||
if text:
|
||||
full_content += text
|
||||
yield text
|
||||
yielded_content = True
|
||||
# OpenAI 格式: {"type": "text", "text": "..."}
|
||||
elif item.get("type") == "text":
|
||||
text = item.get("text", "")
|
||||
if text:
|
||||
full_content += text
|
||||
yield text
|
||||
yielded_content = True
|
||||
elif isinstance(item, str):
|
||||
full_content += item
|
||||
yield item
|
||||
yielded_content = True
|
||||
elif isinstance(chunk, str):
|
||||
full_content += chunk
|
||||
yield chunk
|
||||
yielded_content = True
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
provider: bedrock
|
||||
enabled: true
|
||||
enabled: false
|
||||
models:
|
||||
- name: ai21
|
||||
type: llm
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
provider: dashscope
|
||||
enabled: true
|
||||
enabled: false
|
||||
models:
|
||||
- name: deepseek-r1-distill-qwen-14b
|
||||
type: llm
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
provider: openai
|
||||
enabled: true
|
||||
enabled: false
|
||||
models:
|
||||
- name: chatgpt-4o-latest
|
||||
type: llm
|
||||
|
||||
@@ -26,6 +26,9 @@ from .app_schema import (
|
||||
MemoryConfig,
|
||||
ToolConfig,
|
||||
VariableDefinition,
|
||||
FileInput,
|
||||
FileType,
|
||||
TransferMethod,
|
||||
)
|
||||
from .conversation_schema import (
|
||||
Conversation,
|
||||
@@ -94,6 +97,9 @@ __all__ = [
|
||||
"MemoryConfig",
|
||||
"ToolConfig",
|
||||
"VariableDefinition",
|
||||
"FileInput",
|
||||
"FileType",
|
||||
"TransferMethod",
|
||||
"Conversation",
|
||||
"ConversationCreate",
|
||||
"ConversationWithMessages",
|
||||
|
||||
@@ -1,10 +1,51 @@
|
||||
import datetime
|
||||
import uuid
|
||||
from typing import Optional, Any, List, Dict, Union
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator
|
||||
|
||||
|
||||
# ---------- Multimodal File Support ----------
|
||||
|
||||
class FileType(str, Enum):
|
||||
"""文件类型枚举"""
|
||||
IMAGE = "image"
|
||||
DOCUMENT = "document"
|
||||
AUDIO = "audio"
|
||||
VIDEO = "video"
|
||||
|
||||
|
||||
class TransferMethod(str, Enum):
|
||||
"""文件传输方式枚举"""
|
||||
LOCAL_FILE = "local_file" # 已上传到系统的文件
|
||||
REMOTE_URL = "remote_url" # 外部URL
|
||||
|
||||
|
||||
class FileInput(BaseModel):
|
||||
"""文件输入 Schema"""
|
||||
type: FileType = Field(..., description="文件类型: image/document/audio/video")
|
||||
transfer_method: TransferMethod = Field(..., description="传输方式: local_file/remote_url")
|
||||
upload_file_id: Optional[uuid.UUID] = Field(None, description="已上传文件ID(local_file时必填)")
|
||||
url: Optional[str] = Field(None, description="远程URL(remote_url时必填)")
|
||||
|
||||
@field_validator("upload_file_id")
|
||||
@classmethod
|
||||
def validate_local_file(cls, v, info):
|
||||
"""验证 local_file 时必须提供 upload_file_id"""
|
||||
if info.data.get("transfer_method") == TransferMethod.LOCAL_FILE and not v:
|
||||
raise ValueError("transfer_method 为 local_file 时,upload_file_id 不能为空")
|
||||
return v
|
||||
|
||||
@field_validator("url")
|
||||
@classmethod
|
||||
def validate_remote_url(cls, v, info):
|
||||
"""验证 remote_url 时必须提供 url"""
|
||||
if info.data.get("transfer_method") == TransferMethod.REMOTE_URL and not v:
|
||||
raise ValueError("transfer_method 为 remote_url 时,url 不能为空")
|
||||
return v
|
||||
|
||||
|
||||
# ---------- Input Schemas ----------
|
||||
|
||||
class KnowledgeBaseConfig(BaseModel):
|
||||
@@ -360,6 +401,7 @@ class AppChatRequest(BaseModel):
|
||||
user_id: Optional[str] = Field(default=None, description="用户ID(用于会话管理)")
|
||||
variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值")
|
||||
stream: bool = Field(default=False, description="是否流式返回")
|
||||
files: Optional[List[FileInput]] = Field(default=None, description="附件列表(支持多文件)")
|
||||
|
||||
|
||||
class DraftRunRequest(BaseModel):
|
||||
@@ -369,6 +411,7 @@ class DraftRunRequest(BaseModel):
|
||||
user_id: Optional[str] = Field(default=None, description="用户ID(用于会话管理)")
|
||||
variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值")
|
||||
stream: bool = Field(default=False, description="是否流式返回")
|
||||
files: Optional[List[FileInput]] = Field(default=None, description="附件列表(支持多文件)")
|
||||
|
||||
|
||||
class DraftRunResponse(BaseModel):
|
||||
|
||||
@@ -4,6 +4,9 @@ import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
from pydantic import BaseModel, Field, ConfigDict, field_serializer
|
||||
|
||||
# 导入 FileInput(用于体验运行)
|
||||
from app.schemas.app_schema import FileInput
|
||||
|
||||
|
||||
# ---------- Input Schemas ----------
|
||||
|
||||
@@ -28,6 +31,7 @@ class ChatRequest(BaseModel):
|
||||
stream: bool = Field(default=False, description="是否流式返回")
|
||||
web_search: bool = Field(default=False, description="是否启用网络搜索")
|
||||
memory: bool = Field(default=True, description="是否启用记忆功能")
|
||||
files: Optional[List[FileInput]] = Field(default=None, description="附件列表(支持多文件)")
|
||||
|
||||
|
||||
# ---------- Output Schemas ----------
|
||||
|
||||
@@ -3,7 +3,7 @@ import asyncio
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional, Dict, Any, AsyncGenerator, Annotated
|
||||
from typing import Optional, Dict, Any, AsyncGenerator, Annotated, List
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -15,6 +15,7 @@ from app.core.logging_config import get_business_logger
|
||||
from app.db import get_db, get_db_context
|
||||
from app.models import MultiAgentConfig, AgentConfig, WorkflowConfig
|
||||
from app.schemas import DraftRunRequest
|
||||
from app.schemas.app_schema import FileInput
|
||||
from app.services.tool_service import ToolService
|
||||
from app.repositories.tool_repository import ToolRepository
|
||||
from app.db import get_db
|
||||
@@ -26,6 +27,7 @@ from app.services.draft_run_service import create_web_search_tool
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
|
||||
from app.services.workflow_service import WorkflowService
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
@@ -48,7 +50,8 @@ class AppChatService:
|
||||
memory: bool = True,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
workspace_id: Optional[str] = None
|
||||
workspace_id: Optional[str] = None,
|
||||
files: Optional[List[FileInput]] = None # 新增:多模态文件
|
||||
) -> Dict[str, Any]:
|
||||
"""聊天(非流式)"""
|
||||
|
||||
@@ -155,7 +158,14 @@ class AppChatService:
|
||||
for msg in messages
|
||||
]
|
||||
|
||||
# 调用 Agent
|
||||
# 处理多模态文件
|
||||
processed_files = None
|
||||
if files:
|
||||
multimodal_service = MultimodalService(self.db)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||
|
||||
# 调用 Agent(支持多模态)
|
||||
result = await agent.chat(
|
||||
message=message,
|
||||
history=history,
|
||||
@@ -164,7 +174,8 @@ class AppChatService:
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
config_id=config_id,
|
||||
memory_flag=memory_flag
|
||||
memory_flag=memory_flag,
|
||||
files=processed_files # 传递处理后的文件
|
||||
)
|
||||
|
||||
# 保存消息
|
||||
@@ -199,6 +210,7 @@ class AppChatService:
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
workspace_id: Optional[str] = None,
|
||||
files: Optional[List[FileInput]] = None # 新增:多模态文件
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""聊天(流式)"""
|
||||
|
||||
@@ -305,10 +317,17 @@ class AppChatService:
|
||||
for msg in messages
|
||||
]
|
||||
|
||||
# 处理多模态文件
|
||||
processed_files = None
|
||||
if files:
|
||||
multimodal_service = MultimodalService(self.db)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||
|
||||
# 发送开始事件
|
||||
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 流式调用 Agent
|
||||
# 流式调用 Agent(支持多模态)
|
||||
full_content = ""
|
||||
async for chunk in agent.chat_stream(
|
||||
message=message,
|
||||
@@ -318,7 +337,8 @@ class AppChatService:
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
config_id=config_id,
|
||||
memory_flag=memory_flag
|
||||
memory_flag=memory_flag,
|
||||
files=processed_files # 传递处理后的文件
|
||||
):
|
||||
full_content += chunk
|
||||
# 发送消息块事件
|
||||
|
||||
@@ -19,11 +19,13 @@ from app.models import AgentConfig, ModelApiKey, ModelConfig
|
||||
from app.repositories.model_repository import ModelApiKeyRepository
|
||||
from app.repositories.tool_repository import ToolRepository
|
||||
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
|
||||
from app.schemas.app_schema import FileInput
|
||||
from app.services import task_service
|
||||
from app.services.langchain_tool_server import Search
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.services.model_parameter_merger import ModelParameterMerger
|
||||
from app.services.tool_service import ToolService
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
from langchain.tools import tool
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
@@ -246,7 +248,8 @@ class DraftRunService:
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
web_search: bool = True,
|
||||
memory: bool = True,
|
||||
sub_agent: bool = False
|
||||
sub_agent: bool = False,
|
||||
files: Optional[List[FileInput]] = None # 新增:多模态文件
|
||||
) -> Dict[str, Any]:
|
||||
"""执行试运行(使用 LangChain Agent)
|
||||
|
||||
@@ -406,7 +409,14 @@ class DraftRunService:
|
||||
max_history=agent_config.memory.get("max_history", 10)
|
||||
)
|
||||
|
||||
# 6. 知识库检索
|
||||
# 6. 处理多模态文件
|
||||
processed_files = None
|
||||
if files:
|
||||
multimodal_service = MultimodalService(self.db)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||
|
||||
# 7. 知识库检索
|
||||
context = None
|
||||
|
||||
logger.debug(
|
||||
@@ -414,14 +424,15 @@ class DraftRunService:
|
||||
extra={
|
||||
"model": api_key_config["model_name"],
|
||||
"has_history": bool(history),
|
||||
"has_context": bool(context)
|
||||
"has_context": bool(context),
|
||||
"has_files": bool(processed_files)
|
||||
}
|
||||
)
|
||||
|
||||
memory_config_= agent_config.memory
|
||||
config_id = memory_config_.get("memory_content") or memory_config_.get("memory_config",None)
|
||||
|
||||
# 7. 调用 Agent
|
||||
# 8. 调用 Agent(支持多模态)
|
||||
result = await agent.chat(
|
||||
message=message,
|
||||
history=history,
|
||||
@@ -430,12 +441,13 @@ class DraftRunService:
|
||||
config_id=config_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_flag=memory_flag
|
||||
memory_flag=memory_flag,
|
||||
files=processed_files # 传递处理后的文件
|
||||
)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# 8. 保存会话消息
|
||||
# 9. 保存会话消息
|
||||
if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"):
|
||||
await self._save_conversation_message(
|
||||
conversation_id=conversation_id,
|
||||
@@ -486,7 +498,8 @@ class DraftRunService:
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
web_search: bool = True, # 布尔类型默认值
|
||||
memory: bool = True, # 布尔类型默认值
|
||||
sub_agent: bool = False # 是否是作为子Agent运行
|
||||
sub_agent: bool = False, # 是否是作为子Agent运行
|
||||
files: Optional[List[FileInput]] = None # 新增:多模态文件
|
||||
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""执行试运行(流式返回,使用 LangChain Agent)
|
||||
@@ -635,6 +648,13 @@ class DraftRunService:
|
||||
max_history=agent_config.memory.get("max_history", 10)
|
||||
)
|
||||
|
||||
# 6. 处理多模态文件
|
||||
processed_files = None
|
||||
if files:
|
||||
multimodal_service = MultimodalService(self.db)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||
|
||||
# 7. 知识库检索
|
||||
context = None
|
||||
|
||||
@@ -647,7 +667,7 @@ class DraftRunService:
|
||||
memory_config_ = agent_config.memory
|
||||
config_id = memory_config_.get("memory_content") or memory_config_.get("memory_config",None)
|
||||
|
||||
# 9. 流式调用 Agent
|
||||
# 9. 流式调用 Agent(支持多模态)
|
||||
full_content = ""
|
||||
async for chunk in agent.chat_stream(
|
||||
message=message,
|
||||
@@ -657,7 +677,8 @@ class DraftRunService:
|
||||
config_id=config_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_flag=memory_flag
|
||||
memory_flag=memory_flag,
|
||||
files=processed_files # 传递处理后的文件
|
||||
):
|
||||
full_content += chunk
|
||||
# 发送消息块事件
|
||||
|
||||
268
api/app/services/multimodal_service.py
Normal file
268
api/app/services/multimodal_service.py
Normal file
@@ -0,0 +1,268 @@
|
||||
"""
|
||||
多模态文件处理服务
|
||||
|
||||
处理图片、文档等多模态文件,转换为 LLM 可用的格式
|
||||
|
||||
格式说明:
|
||||
- 当前使用通义千问格式
|
||||
- 通义千问格式: {"type": "image", "image": "url"}
|
||||
"""
|
||||
import uuid
|
||||
from typing import List, Dict, Any, Optional
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.schemas.app_schema import FileInput, FileType, TransferMethod
|
||||
from app.models.generic_file_model import GenericFile
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class MultimodalService:
|
||||
"""多模态文件处理服务"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
async def process_files(
|
||||
self,
|
||||
files: Optional[List[FileInput]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
处理文件列表,返回 LLM 可用的格式
|
||||
|
||||
Args:
|
||||
files: 文件输入列表
|
||||
|
||||
Returns:
|
||||
List[Dict]: LLM 可用的内容格式列表
|
||||
"""
|
||||
if not files:
|
||||
return []
|
||||
|
||||
result = []
|
||||
for idx, file in enumerate(files):
|
||||
try:
|
||||
if file.type == FileType.IMAGE:
|
||||
content = await self._process_image(file)
|
||||
result.append(content)
|
||||
elif file.type == FileType.DOCUMENT:
|
||||
content = await self._process_document(file)
|
||||
result.append(content)
|
||||
elif file.type == FileType.AUDIO:
|
||||
content = await self._process_audio(file)
|
||||
result.append(content)
|
||||
elif file.type == FileType.VIDEO:
|
||||
content = await self._process_video(file)
|
||||
result.append(content)
|
||||
else:
|
||||
logger.warning(f"不支持的文件类型: {file.type}")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"处理文件失败",
|
||||
extra={
|
||||
"file_index": idx,
|
||||
"file_type": file.type,
|
||||
"error": str(e)
|
||||
}
|
||||
)
|
||||
# 继续处理其他文件,不中断整个流程
|
||||
result.append({
|
||||
"type": "text",
|
||||
"text": f"[文件处理失败: {str(e)}]"
|
||||
})
|
||||
|
||||
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件")
|
||||
return result
|
||||
|
||||
async def _process_image(self, file: FileInput) -> Dict[str, Any]:
|
||||
"""
|
||||
处理图片文件
|
||||
|
||||
Args:
|
||||
file: 图片文件输入
|
||||
|
||||
Returns:
|
||||
Dict: 通义千问格式
|
||||
"""
|
||||
if file.transfer_method == TransferMethod.REMOTE_URL:
|
||||
# 远程 URL,使用通义千问格式
|
||||
return {
|
||||
"type": "image",
|
||||
"image": file.url
|
||||
}
|
||||
else:
|
||||
# 本地文件,获取访问 URL
|
||||
url = await self._get_file_url(file.upload_file_id)
|
||||
return {
|
||||
"type": "image",
|
||||
"image": url
|
||||
}
|
||||
|
||||
async def _process_document(self, file: FileInput) -> Dict[str, Any]:
|
||||
"""
|
||||
处理文档文件(PDF、Word 等)
|
||||
|
||||
Args:
|
||||
file: 文档文件输入
|
||||
|
||||
Returns:
|
||||
Dict: text 格式的内容(包含提取的文本)
|
||||
"""
|
||||
if file.transfer_method == TransferMethod.REMOTE_URL:
|
||||
# 远程文档暂不支持提取
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"<document url=\"{file.url}\">\n[远程文档,暂不支持内容提取]\n</document>"
|
||||
}
|
||||
else:
|
||||
# 本地文件,提取文本内容
|
||||
text = await self._extract_document_text(file.upload_file_id)
|
||||
generic_file = self.db.query(GenericFile).filter(
|
||||
GenericFile.id == file.upload_file_id
|
||||
).first()
|
||||
|
||||
file_name = generic_file.file_name if generic_file else "unknown"
|
||||
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
|
||||
}
|
||||
|
||||
async def _process_audio(self, file: FileInput) -> Dict[str, Any]:
|
||||
"""
|
||||
处理音频文件
|
||||
|
||||
Args:
|
||||
file: 音频文件输入
|
||||
|
||||
Returns:
|
||||
Dict: 音频内容(暂时返回占位符)
|
||||
"""
|
||||
# TODO: 实现音频转文字功能
|
||||
return {
|
||||
"type": "text",
|
||||
"text": "[音频文件,暂不支持处理]"
|
||||
}
|
||||
|
||||
async def _process_video(self, file: FileInput) -> Dict[str, Any]:
|
||||
"""
|
||||
处理视频文件
|
||||
|
||||
Args:
|
||||
file: 视频文件输入
|
||||
|
||||
Returns:
|
||||
Dict: 视频内容(暂时返回占位符)
|
||||
"""
|
||||
# TODO: 实现视频处理功能
|
||||
return {
|
||||
"type": "text",
|
||||
"text": "[视频文件,暂不支持处理]"
|
||||
}
|
||||
|
||||
async def _get_file_url(self, file_id: uuid.UUID) -> str:
|
||||
"""
|
||||
获取文件的访问 URL
|
||||
|
||||
Args:
|
||||
file_id: 文件ID
|
||||
|
||||
Returns:
|
||||
str: 文件访问 URL
|
||||
|
||||
Raises:
|
||||
BusinessException: 文件不存在
|
||||
"""
|
||||
generic_file = self.db.query(GenericFile).filter(
|
||||
GenericFile.id == file_id,
|
||||
GenericFile.status == "active"
|
||||
).first()
|
||||
|
||||
if not generic_file:
|
||||
raise BusinessException(
|
||||
f"文件不存在或已删除: {file_id}",
|
||||
BizCode.NOT_FOUND
|
||||
)
|
||||
|
||||
# 如果有 access_url,直接返回
|
||||
if generic_file.access_url:
|
||||
return generic_file.access_url
|
||||
|
||||
# 否则,根据 storage_path 生成 URL
|
||||
# TODO: 根据实际存储方式生成 URL(本地存储、OSS 等)
|
||||
# 这里暂时返回一个占位 URL
|
||||
return f"/api/files/{file_id}/download"
|
||||
|
||||
async def _extract_document_text(self, file_id: uuid.UUID) -> str:
|
||||
"""
|
||||
提取文档文本内容
|
||||
|
||||
Args:
|
||||
file_id: 文件ID
|
||||
|
||||
Returns:
|
||||
str: 提取的文本内容
|
||||
"""
|
||||
generic_file = self.db.query(GenericFile).filter(
|
||||
GenericFile.id == file_id,
|
||||
GenericFile.status == "active"
|
||||
).first()
|
||||
|
||||
if not generic_file:
|
||||
raise BusinessException(
|
||||
f"文件不存在或已删除: {file_id}",
|
||||
BizCode.NOT_FOUND
|
||||
)
|
||||
|
||||
# TODO: 根据文件类型提取文本
|
||||
# - PDF: 使用 PyPDF2 或 pdfplumber
|
||||
# - Word: 使用 python-docx
|
||||
# - TXT/MD: 直接读取
|
||||
|
||||
file_ext = generic_file.file_ext.lower()
|
||||
|
||||
if file_ext in ['.txt', '.md', '.markdown']:
|
||||
return await self._read_text_file(generic_file.storage_path)
|
||||
elif file_ext == '.pdf':
|
||||
return await self._extract_pdf_text(generic_file.storage_path)
|
||||
elif file_ext in ['.doc', '.docx']:
|
||||
return await self._extract_word_text(generic_file.storage_path)
|
||||
else:
|
||||
return f"[不支持的文档格式: {file_ext}]"
|
||||
|
||||
async def _read_text_file(self, storage_path: str) -> str:
|
||||
"""读取纯文本文件"""
|
||||
try:
|
||||
with open(storage_path, 'r', encoding='utf-8') as f:
|
||||
return f.read()
|
||||
except Exception as e:
|
||||
logger.error(f"读取文本文件失败: {e}")
|
||||
return f"[文件读取失败: {str(e)}]"
|
||||
|
||||
async def _extract_pdf_text(self, storage_path: str) -> str:
|
||||
"""提取 PDF 文本"""
|
||||
try:
|
||||
# TODO: 实现 PDF 文本提取
|
||||
# import PyPDF2 或 pdfplumber
|
||||
return "[PDF 文本提取功能待实现]"
|
||||
except Exception as e:
|
||||
logger.error(f"提取 PDF 文本失败: {e}")
|
||||
return f"[PDF 提取失败: {str(e)}]"
|
||||
|
||||
async def _extract_word_text(self, storage_path: str) -> str:
|
||||
"""提取 Word 文档文本"""
|
||||
try:
|
||||
# TODO: 实现 Word 文本提取
|
||||
# import docx
|
||||
return "[Word 文本提取功能待实现]"
|
||||
except Exception as e:
|
||||
logger.error(f"提取 Word 文本失败: {e}")
|
||||
return f"[Word 提取失败: {str(e)}]"
|
||||
|
||||
|
||||
def get_multimodal_service(db: Session) -> MultimodalService:
|
||||
"""获取多模态服务实例(依赖注入)"""
|
||||
return MultimodalService(db)
|
||||
3082
api/uv.lock
generated
3082
api/uv.lock
generated
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user