[add] multimodal
This commit is contained in:
@@ -454,7 +454,8 @@ async def draft_run(
|
|||||||
user_id=payload.user_id or str(current_user.id),
|
user_id=payload.user_id or str(current_user.id),
|
||||||
variables=payload.variables,
|
variables=payload.variables,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
|
files=payload.files # 传递多模态文件
|
||||||
):
|
):
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
@@ -475,7 +476,8 @@ async def draft_run(
|
|||||||
"app_id": str(app_id),
|
"app_id": str(app_id),
|
||||||
"message_length": len(payload.message),
|
"message_length": len(payload.message),
|
||||||
"has_conversation_id": bool(payload.conversation_id),
|
"has_conversation_id": bool(payload.conversation_id),
|
||||||
"has_variables": bool(payload.variables)
|
"has_variables": bool(payload.variables),
|
||||||
|
"has_files": bool(payload.files)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -490,7 +492,8 @@ async def draft_run(
|
|||||||
user_id=payload.user_id or str(current_user.id),
|
user_id=payload.user_id or str(current_user.id),
|
||||||
variables=payload.variables,
|
variables=payload.variables,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
|
files=payload.files # 传递多模态文件
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
|||||||
@@ -438,7 +438,8 @@ async def chat(
|
|||||||
memory=payload.memory,
|
memory=payload.memory,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
workspace_id=workspace_id
|
workspace_id=workspace_id,
|
||||||
|
files=payload.files # 传递多模态文件
|
||||||
):
|
):
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
@@ -475,7 +476,8 @@ async def chat(
|
|||||||
memory=payload.memory,
|
memory=payload.memory,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
workspace_id=workspace_id
|
workspace_id=workspace_id,
|
||||||
|
files=payload.files # 传递多模态文件
|
||||||
)
|
)
|
||||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||||
elif app_type == AppType.MULTI_AGENT:
|
elif app_type == AppType.MULTI_AGENT:
|
||||||
|
|||||||
@@ -155,7 +155,8 @@ async def chat(
|
|||||||
memory=memory,
|
memory=memory,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
workspace_id=workspace_id
|
workspace_id=workspace_id,
|
||||||
|
files=payload.files # 传递多模态文件
|
||||||
):
|
):
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
@@ -180,7 +181,8 @@ async def chat(
|
|||||||
memory=memory,
|
memory=memory,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
workspace_id=workspace_id
|
workspace_id=workspace_id,
|
||||||
|
files=payload.files # 传递多模态文件
|
||||||
)
|
)
|
||||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||||
elif app_type == AppType.MULTI_AGENT:
|
elif app_type == AppType.MULTI_AGENT:
|
||||||
|
|||||||
@@ -45,7 +45,8 @@ class LangChainAgent:
|
|||||||
max_tokens: int = 2000,
|
max_tokens: int = 2000,
|
||||||
system_prompt: Optional[str] = None,
|
system_prompt: Optional[str] = None,
|
||||||
tools: Optional[Sequence[BaseTool]] = None,
|
tools: Optional[Sequence[BaseTool]] = None,
|
||||||
streaming: bool = False
|
streaming: bool = False,
|
||||||
|
max_iterations: int = 3 # 新增:最大迭代次数
|
||||||
):
|
):
|
||||||
"""初始化 LangChain Agent
|
"""初始化 LangChain Agent
|
||||||
|
|
||||||
@@ -59,12 +60,14 @@ class LangChainAgent:
|
|||||||
system_prompt: 系统提示词
|
system_prompt: 系统提示词
|
||||||
tools: 工具列表(可选,框架自动走 ReAct 循环)
|
tools: 工具列表(可选,框架自动走 ReAct 循环)
|
||||||
streaming: 是否启用流式输出(默认 True)
|
streaming: 是否启用流式输出(默认 True)
|
||||||
|
max_iterations: 最大迭代次数(默认 3 次,防止工具调用死循环)
|
||||||
"""
|
"""
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.provider = provider
|
self.provider = provider
|
||||||
self.system_prompt = system_prompt or "你是一个专业的AI助手"
|
self.system_prompt = system_prompt or "你是一个专业的AI助手"
|
||||||
self.tools = tools or []
|
self.tools = tools or []
|
||||||
self.streaming = streaming
|
self.streaming = streaming
|
||||||
|
self.max_iterations = max_iterations
|
||||||
|
|
||||||
# 创建 RedBearLLM(支持多提供商)
|
# 创建 RedBearLLM(支持多提供商)
|
||||||
model_config = RedBearModelConfig(
|
model_config = RedBearModelConfig(
|
||||||
@@ -104,9 +107,9 @@ class LangChainAgent:
|
|||||||
"has_api_base": bool(api_base),
|
"has_api_base": bool(api_base),
|
||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
"streaming": streaming,
|
"streaming": streaming,
|
||||||
|
"max_iterations": max_iterations,
|
||||||
"tool_count": len(self.tools),
|
"tool_count": len(self.tools),
|
||||||
"tool_names": [tool.name for tool in self.tools] if self.tools else [],
|
"tool_names": [tool.name for tool in self.tools] if self.tools else []
|
||||||
"tool_count": len(self.tools)
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -114,7 +117,8 @@ class LangChainAgent:
|
|||||||
self,
|
self,
|
||||||
message: str,
|
message: str,
|
||||||
history: Optional[List[Dict[str, str]]] = None,
|
history: Optional[List[Dict[str, str]]] = None,
|
||||||
context: Optional[str] = None
|
context: Optional[str] = None,
|
||||||
|
files: Optional[List[Dict[str, Any]]] = None
|
||||||
) -> List[BaseMessage]:
|
) -> List[BaseMessage]:
|
||||||
"""准备消息列表
|
"""准备消息列表
|
||||||
|
|
||||||
@@ -122,6 +126,7 @@ class LangChainAgent:
|
|||||||
message: 用户消息
|
message: 用户消息
|
||||||
history: 历史消息列表
|
history: 历史消息列表
|
||||||
context: 上下文信息
|
context: 上下文信息
|
||||||
|
files: 多模态文件内容列表(已处理)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[BaseMessage]: 消息列表
|
List[BaseMessage]: 消息列表
|
||||||
@@ -144,7 +149,15 @@ class LangChainAgent:
|
|||||||
if context:
|
if context:
|
||||||
user_content = f"参考信息:\n{context}\n\n用户问题:\n{user_content}"
|
user_content = f"参考信息:\n{context}\n\n用户问题:\n{user_content}"
|
||||||
|
|
||||||
messages.append(HumanMessage(content=user_content))
|
# 如果有文件,构建多模态消息
|
||||||
|
if files and len(files) > 0:
|
||||||
|
# 多模态消息格式: [{"type": "text", "text": "..."}, {"type": "image_url", ...}]
|
||||||
|
content_parts = [{"type": "text", "text": user_content}]
|
||||||
|
content_parts.extend(files)
|
||||||
|
messages.append(HumanMessage(content=content_parts))
|
||||||
|
else:
|
||||||
|
# 纯文本消息(向后兼容)
|
||||||
|
messages.append(HumanMessage(content=user_content))
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||||
@@ -254,7 +267,8 @@ class LangChainAgent:
|
|||||||
config_id: Optional[str] = None, # 添加这个参数
|
config_id: Optional[str] = None, # 添加这个参数
|
||||||
storage_type: Optional[str] = None,
|
storage_type: Optional[str] = None,
|
||||||
user_rag_memory_id: Optional[str] = None,
|
user_rag_memory_id: Optional[str] = None,
|
||||||
memory_flag: Optional[bool] = True
|
memory_flag: Optional[bool] = True,
|
||||||
|
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""执行对话
|
"""执行对话
|
||||||
|
|
||||||
@@ -313,8 +327,8 @@ class LangChainAgent:
|
|||||||
# await self.write(storage_type, actual_end_user_id, history_term_memory, "", user_rag_memory_id, actual_end_user_id, actual_config_id)
|
# await self.write(storage_type, actual_end_user_id, history_term_memory, "", user_rag_memory_id, actual_end_user_id, actual_config_id)
|
||||||
# # 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
# # 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
||||||
try:
|
try:
|
||||||
# 准备消息列表
|
# 准备消息列表(支持多模态)
|
||||||
messages = self._prepare_messages(message, history, context)
|
messages = self._prepare_messages(message, history, context, files)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"准备调用 LangChain Agent",
|
"准备调用 LangChain Agent",
|
||||||
@@ -322,20 +336,79 @@ class LangChainAgent:
|
|||||||
"has_context": bool(context),
|
"has_context": bool(context),
|
||||||
"has_history": bool(history),
|
"has_history": bool(history),
|
||||||
"has_tools": bool(self.tools),
|
"has_tools": bool(self.tools),
|
||||||
"message_count": len(messages)
|
"has_files": bool(files),
|
||||||
|
"message_count": len(messages),
|
||||||
|
"max_iterations": self.max_iterations
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# 统一使用 agent.invoke 调用
|
# 统一使用 agent.invoke 调用
|
||||||
result = await self.agent.ainvoke({"messages": messages})
|
# 通过 recursion_limit 限制最大迭代次数,防止工具调用死循环
|
||||||
|
try:
|
||||||
|
result = await self.agent.ainvoke(
|
||||||
|
{"messages": messages},
|
||||||
|
config={"recursion_limit": self.max_iterations}
|
||||||
|
)
|
||||||
|
except RecursionError as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Agent 达到最大迭代次数限制 ({self.max_iterations}),可能存在工具调用循环",
|
||||||
|
extra={"error": str(e)}
|
||||||
|
)
|
||||||
|
# 返回一个友好的错误提示
|
||||||
|
return {
|
||||||
|
"content": f"抱歉,我在处理您的请求时遇到了问题。已达到最大处理步骤限制({self.max_iterations}次)。请尝试简化您的问题或稍后再试。",
|
||||||
|
"model": self.model_name,
|
||||||
|
"elapsed_time": time.time() - start_time,
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 0,
|
||||||
|
"completion_tokens": 0,
|
||||||
|
"total_tokens": 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
# 获取最后的 AI 消息
|
# 获取最后的 AI 消息
|
||||||
output_messages = result.get("messages", [])
|
output_messages = result.get("messages", [])
|
||||||
content = ""
|
content = ""
|
||||||
|
|
||||||
|
logger.debug(f"输出消息数量: {len(output_messages)}")
|
||||||
|
|
||||||
for msg in reversed(output_messages):
|
for msg in reversed(output_messages):
|
||||||
if isinstance(msg, AIMessage):
|
if isinstance(msg, AIMessage):
|
||||||
content = msg.content
|
logger.debug(f"找到 AI 消息,content 类型: {type(msg.content)}")
|
||||||
|
logger.debug(f"AI 消息内容: {msg.content}")
|
||||||
|
|
||||||
|
# 处理多模态响应:content 可能是字符串或列表
|
||||||
|
if isinstance(msg.content, str):
|
||||||
|
content = msg.content
|
||||||
|
logger.debug(f"提取字符串内容,长度: {len(content)}")
|
||||||
|
elif isinstance(msg.content, list):
|
||||||
|
# 多模态响应:提取文本部分
|
||||||
|
logger.debug(f"多模态响应,列表长度: {len(msg.content)}")
|
||||||
|
text_parts = []
|
||||||
|
for item in msg.content:
|
||||||
|
logger.debug(f"处理项: {item}")
|
||||||
|
if isinstance(item, dict):
|
||||||
|
# 通义千问格式: {"text": "..."}
|
||||||
|
if "text" in item:
|
||||||
|
text = item.get("text", "")
|
||||||
|
text_parts.append(text)
|
||||||
|
logger.debug(f"提取文本: {text[:100]}...")
|
||||||
|
# OpenAI 格式: {"type": "text", "text": "..."}
|
||||||
|
elif item.get("type") == "text":
|
||||||
|
text = item.get("text", "")
|
||||||
|
text_parts.append(text)
|
||||||
|
logger.debug(f"提取文本: {text[:100]}...")
|
||||||
|
elif isinstance(item, str):
|
||||||
|
text_parts.append(item)
|
||||||
|
logger.debug(f"提取字符串: {item[:100]}...")
|
||||||
|
content = "".join(text_parts)
|
||||||
|
logger.debug(f"合并后内容长度: {len(content)}")
|
||||||
|
else:
|
||||||
|
content = str(msg.content)
|
||||||
|
logger.debug(f"转换为字符串: {content[:100]}...")
|
||||||
break
|
break
|
||||||
|
|
||||||
|
logger.info(f"最终提取的内容长度: {len(content)}")
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
if memory_flag:
|
if memory_flag:
|
||||||
@@ -377,7 +450,8 @@ class LangChainAgent:
|
|||||||
config_id: Optional[str] = None,
|
config_id: Optional[str] = None,
|
||||||
storage_type:Optional[str] = None,
|
storage_type:Optional[str] = None,
|
||||||
user_rag_memory_id:Optional[str] = None,
|
user_rag_memory_id:Optional[str] = None,
|
||||||
memory_flag: Optional[bool] = True
|
memory_flag: Optional[bool] = True,
|
||||||
|
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""执行流式对话
|
"""执行流式对话
|
||||||
|
|
||||||
@@ -432,11 +506,11 @@ class LangChainAgent:
|
|||||||
|
|
||||||
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
||||||
try:
|
try:
|
||||||
# 准备消息列表
|
# 准备消息列表(支持多模态)
|
||||||
messages = self._prepare_messages(message, history, context)
|
messages = self._prepare_messages(message, history, context, files)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"准备流式调用,has_tools={bool(self.tools)}, message_count={len(messages)}"
|
f"准备流式调用,has_tools={bool(self.tools)}, has_files={bool(files)}, message_count={len(messages)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
@@ -448,7 +522,8 @@ class LangChainAgent:
|
|||||||
try:
|
try:
|
||||||
async for event in self.agent.astream_events(
|
async for event in self.agent.astream_events(
|
||||||
{"messages": messages},
|
{"messages": messages},
|
||||||
version="v2"
|
version="v2",
|
||||||
|
config={"recursion_limit": self.max_iterations}
|
||||||
):
|
):
|
||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
kind = event.get("event")
|
kind = event.get("event")
|
||||||
@@ -457,20 +532,70 @@ class LangChainAgent:
|
|||||||
if kind == "on_chat_model_stream":
|
if kind == "on_chat_model_stream":
|
||||||
# LLM 流式输出
|
# LLM 流式输出
|
||||||
chunk = event.get("data", {}).get("chunk")
|
chunk = event.get("data", {}).get("chunk")
|
||||||
full_content+=chunk.content
|
if chunk and hasattr(chunk, "content"):
|
||||||
if chunk and hasattr(chunk, "content") and chunk.content:
|
# 处理多模态响应:content 可能是字符串或列表
|
||||||
yield chunk.content
|
chunk_content = chunk.content
|
||||||
yielded_content = True
|
if isinstance(chunk_content, str) and chunk_content:
|
||||||
|
full_content += chunk_content
|
||||||
|
yield chunk_content
|
||||||
|
yielded_content = True
|
||||||
|
elif isinstance(chunk_content, list):
|
||||||
|
# 多模态响应:提取文本部分
|
||||||
|
for item in chunk_content:
|
||||||
|
if isinstance(item, dict):
|
||||||
|
# 通义千问格式: {"text": "..."}
|
||||||
|
if "text" in item:
|
||||||
|
text = item.get("text", "")
|
||||||
|
if text:
|
||||||
|
full_content += text
|
||||||
|
yield text
|
||||||
|
yielded_content = True
|
||||||
|
# OpenAI 格式: {"type": "text", "text": "..."}
|
||||||
|
elif item.get("type") == "text":
|
||||||
|
text = item.get("text", "")
|
||||||
|
if text:
|
||||||
|
full_content += text
|
||||||
|
yield text
|
||||||
|
yielded_content = True
|
||||||
|
elif isinstance(item, str):
|
||||||
|
full_content += item
|
||||||
|
yield item
|
||||||
|
yielded_content = True
|
||||||
|
|
||||||
elif kind == "on_llm_stream":
|
elif kind == "on_llm_stream":
|
||||||
# 另一种 LLM 流式事件
|
# 另一种 LLM 流式事件
|
||||||
chunk = event.get("data", {}).get("chunk")
|
chunk = event.get("data", {}).get("chunk")
|
||||||
if chunk:
|
if chunk:
|
||||||
if hasattr(chunk, "content") and chunk.content:
|
if hasattr(chunk, "content"):
|
||||||
full_content+=chunk.content
|
chunk_content = chunk.content
|
||||||
yield chunk.content
|
if isinstance(chunk_content, str) and chunk_content:
|
||||||
yielded_content = True
|
full_content += chunk_content
|
||||||
|
yield chunk_content
|
||||||
|
yielded_content = True
|
||||||
|
elif isinstance(chunk_content, list):
|
||||||
|
# 多模态响应:提取文本部分
|
||||||
|
for item in chunk_content:
|
||||||
|
if isinstance(item, dict):
|
||||||
|
# 通义千问格式: {"text": "..."}
|
||||||
|
if "text" in item:
|
||||||
|
text = item.get("text", "")
|
||||||
|
if text:
|
||||||
|
full_content += text
|
||||||
|
yield text
|
||||||
|
yielded_content = True
|
||||||
|
# OpenAI 格式: {"type": "text", "text": "..."}
|
||||||
|
elif item.get("type") == "text":
|
||||||
|
text = item.get("text", "")
|
||||||
|
if text:
|
||||||
|
full_content += text
|
||||||
|
yield text
|
||||||
|
yielded_content = True
|
||||||
|
elif isinstance(item, str):
|
||||||
|
full_content += item
|
||||||
|
yield item
|
||||||
|
yielded_content = True
|
||||||
elif isinstance(chunk, str):
|
elif isinstance(chunk, str):
|
||||||
|
full_content += chunk
|
||||||
yield chunk
|
yield chunk
|
||||||
yielded_content = True
|
yielded_content = True
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
provider: bedrock
|
provider: bedrock
|
||||||
enabled: true
|
enabled: false
|
||||||
models:
|
models:
|
||||||
- name: ai21
|
- name: ai21
|
||||||
type: llm
|
type: llm
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
provider: dashscope
|
provider: dashscope
|
||||||
enabled: true
|
enabled: false
|
||||||
models:
|
models:
|
||||||
- name: deepseek-r1-distill-qwen-14b
|
- name: deepseek-r1-distill-qwen-14b
|
||||||
type: llm
|
type: llm
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
provider: openai
|
provider: openai
|
||||||
enabled: true
|
enabled: false
|
||||||
models:
|
models:
|
||||||
- name: chatgpt-4o-latest
|
- name: chatgpt-4o-latest
|
||||||
type: llm
|
type: llm
|
||||||
|
|||||||
@@ -26,6 +26,9 @@ from .app_schema import (
|
|||||||
MemoryConfig,
|
MemoryConfig,
|
||||||
ToolConfig,
|
ToolConfig,
|
||||||
VariableDefinition,
|
VariableDefinition,
|
||||||
|
FileInput,
|
||||||
|
FileType,
|
||||||
|
TransferMethod,
|
||||||
)
|
)
|
||||||
from .conversation_schema import (
|
from .conversation_schema import (
|
||||||
Conversation,
|
Conversation,
|
||||||
@@ -94,6 +97,9 @@ __all__ = [
|
|||||||
"MemoryConfig",
|
"MemoryConfig",
|
||||||
"ToolConfig",
|
"ToolConfig",
|
||||||
"VariableDefinition",
|
"VariableDefinition",
|
||||||
|
"FileInput",
|
||||||
|
"FileType",
|
||||||
|
"TransferMethod",
|
||||||
"Conversation",
|
"Conversation",
|
||||||
"ConversationCreate",
|
"ConversationCreate",
|
||||||
"ConversationWithMessages",
|
"ConversationWithMessages",
|
||||||
|
|||||||
@@ -1,10 +1,51 @@
|
|||||||
import datetime
|
import datetime
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Optional, Any, List, Dict, Union
|
from typing import Optional, Any, List, Dict, Union
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator
|
from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- Multimodal File Support ----------
|
||||||
|
|
||||||
|
class FileType(str, Enum):
|
||||||
|
"""文件类型枚举"""
|
||||||
|
IMAGE = "image"
|
||||||
|
DOCUMENT = "document"
|
||||||
|
AUDIO = "audio"
|
||||||
|
VIDEO = "video"
|
||||||
|
|
||||||
|
|
||||||
|
class TransferMethod(str, Enum):
|
||||||
|
"""文件传输方式枚举"""
|
||||||
|
LOCAL_FILE = "local_file" # 已上传到系统的文件
|
||||||
|
REMOTE_URL = "remote_url" # 外部URL
|
||||||
|
|
||||||
|
|
||||||
|
class FileInput(BaseModel):
|
||||||
|
"""文件输入 Schema"""
|
||||||
|
type: FileType = Field(..., description="文件类型: image/document/audio/video")
|
||||||
|
transfer_method: TransferMethod = Field(..., description="传输方式: local_file/remote_url")
|
||||||
|
upload_file_id: Optional[uuid.UUID] = Field(None, description="已上传文件ID(local_file时必填)")
|
||||||
|
url: Optional[str] = Field(None, description="远程URL(remote_url时必填)")
|
||||||
|
|
||||||
|
@field_validator("upload_file_id")
|
||||||
|
@classmethod
|
||||||
|
def validate_local_file(cls, v, info):
|
||||||
|
"""验证 local_file 时必须提供 upload_file_id"""
|
||||||
|
if info.data.get("transfer_method") == TransferMethod.LOCAL_FILE and not v:
|
||||||
|
raise ValueError("transfer_method 为 local_file 时,upload_file_id 不能为空")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@field_validator("url")
|
||||||
|
@classmethod
|
||||||
|
def validate_remote_url(cls, v, info):
|
||||||
|
"""验证 remote_url 时必须提供 url"""
|
||||||
|
if info.data.get("transfer_method") == TransferMethod.REMOTE_URL and not v:
|
||||||
|
raise ValueError("transfer_method 为 remote_url 时,url 不能为空")
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
# ---------- Input Schemas ----------
|
# ---------- Input Schemas ----------
|
||||||
|
|
||||||
class KnowledgeBaseConfig(BaseModel):
|
class KnowledgeBaseConfig(BaseModel):
|
||||||
@@ -360,6 +401,7 @@ class AppChatRequest(BaseModel):
|
|||||||
user_id: Optional[str] = Field(default=None, description="用户ID(用于会话管理)")
|
user_id: Optional[str] = Field(default=None, description="用户ID(用于会话管理)")
|
||||||
variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值")
|
variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值")
|
||||||
stream: bool = Field(default=False, description="是否流式返回")
|
stream: bool = Field(default=False, description="是否流式返回")
|
||||||
|
files: Optional[List[FileInput]] = Field(default=None, description="附件列表(支持多文件)")
|
||||||
|
|
||||||
|
|
||||||
class DraftRunRequest(BaseModel):
|
class DraftRunRequest(BaseModel):
|
||||||
@@ -369,6 +411,7 @@ class DraftRunRequest(BaseModel):
|
|||||||
user_id: Optional[str] = Field(default=None, description="用户ID(用于会话管理)")
|
user_id: Optional[str] = Field(default=None, description="用户ID(用于会话管理)")
|
||||||
variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值")
|
variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值")
|
||||||
stream: bool = Field(default=False, description="是否流式返回")
|
stream: bool = Field(default=False, description="是否流式返回")
|
||||||
|
files: Optional[List[FileInput]] = Field(default=None, description="附件列表(支持多文件)")
|
||||||
|
|
||||||
|
|
||||||
class DraftRunResponse(BaseModel):
|
class DraftRunResponse(BaseModel):
|
||||||
|
|||||||
@@ -4,6 +4,9 @@ import datetime
|
|||||||
from typing import Optional, Dict, Any, List
|
from typing import Optional, Dict, Any, List
|
||||||
from pydantic import BaseModel, Field, ConfigDict, field_serializer
|
from pydantic import BaseModel, Field, ConfigDict, field_serializer
|
||||||
|
|
||||||
|
# 导入 FileInput(用于体验运行)
|
||||||
|
from app.schemas.app_schema import FileInput
|
||||||
|
|
||||||
|
|
||||||
# ---------- Input Schemas ----------
|
# ---------- Input Schemas ----------
|
||||||
|
|
||||||
@@ -28,6 +31,7 @@ class ChatRequest(BaseModel):
|
|||||||
stream: bool = Field(default=False, description="是否流式返回")
|
stream: bool = Field(default=False, description="是否流式返回")
|
||||||
web_search: bool = Field(default=False, description="是否启用网络搜索")
|
web_search: bool = Field(default=False, description="是否启用网络搜索")
|
||||||
memory: bool = Field(default=True, description="是否启用记忆功能")
|
memory: bool = Field(default=True, description="是否启用记忆功能")
|
||||||
|
files: Optional[List[FileInput]] = Field(default=None, description="附件列表(支持多文件)")
|
||||||
|
|
||||||
|
|
||||||
# ---------- Output Schemas ----------
|
# ---------- Output Schemas ----------
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import asyncio
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Optional, Dict, Any, AsyncGenerator, Annotated
|
from typing import Optional, Dict, Any, AsyncGenerator, Annotated, List
|
||||||
|
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@@ -15,6 +15,7 @@ from app.core.logging_config import get_business_logger
|
|||||||
from app.db import get_db, get_db_context
|
from app.db import get_db, get_db_context
|
||||||
from app.models import MultiAgentConfig, AgentConfig, WorkflowConfig
|
from app.models import MultiAgentConfig, AgentConfig, WorkflowConfig
|
||||||
from app.schemas import DraftRunRequest
|
from app.schemas import DraftRunRequest
|
||||||
|
from app.schemas.app_schema import FileInput
|
||||||
from app.services.tool_service import ToolService
|
from app.services.tool_service import ToolService
|
||||||
from app.repositories.tool_repository import ToolRepository
|
from app.repositories.tool_repository import ToolRepository
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
@@ -26,6 +27,7 @@ from app.services.draft_run_service import create_web_search_tool
|
|||||||
from app.services.model_service import ModelApiKeyService
|
from app.services.model_service import ModelApiKeyService
|
||||||
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
|
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
|
||||||
from app.services.workflow_service import WorkflowService
|
from app.services.workflow_service import WorkflowService
|
||||||
|
from app.services.multimodal_service import MultimodalService
|
||||||
|
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
|
|
||||||
@@ -48,7 +50,8 @@ class AppChatService:
|
|||||||
memory: bool = True,
|
memory: bool = True,
|
||||||
storage_type: Optional[str] = None,
|
storage_type: Optional[str] = None,
|
||||||
user_rag_memory_id: Optional[str] = None,
|
user_rag_memory_id: Optional[str] = None,
|
||||||
workspace_id: Optional[str] = None
|
workspace_id: Optional[str] = None,
|
||||||
|
files: Optional[List[FileInput]] = None # 新增:多模态文件
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""聊天(非流式)"""
|
"""聊天(非流式)"""
|
||||||
|
|
||||||
@@ -155,7 +158,14 @@ class AppChatService:
|
|||||||
for msg in messages
|
for msg in messages
|
||||||
]
|
]
|
||||||
|
|
||||||
# 调用 Agent
|
# 处理多模态文件
|
||||||
|
processed_files = None
|
||||||
|
if files:
|
||||||
|
multimodal_service = MultimodalService(self.db)
|
||||||
|
processed_files = await multimodal_service.process_files(files)
|
||||||
|
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||||
|
|
||||||
|
# 调用 Agent(支持多模态)
|
||||||
result = await agent.chat(
|
result = await agent.chat(
|
||||||
message=message,
|
message=message,
|
||||||
history=history,
|
history=history,
|
||||||
@@ -164,7 +174,8 @@ class AppChatService:
|
|||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
memory_flag=memory_flag
|
memory_flag=memory_flag,
|
||||||
|
files=processed_files # 传递处理后的文件
|
||||||
)
|
)
|
||||||
|
|
||||||
# 保存消息
|
# 保存消息
|
||||||
@@ -199,6 +210,7 @@ class AppChatService:
|
|||||||
storage_type: Optional[str] = None,
|
storage_type: Optional[str] = None,
|
||||||
user_rag_memory_id: Optional[str] = None,
|
user_rag_memory_id: Optional[str] = None,
|
||||||
workspace_id: Optional[str] = None,
|
workspace_id: Optional[str] = None,
|
||||||
|
files: Optional[List[FileInput]] = None # 新增:多模态文件
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""聊天(流式)"""
|
"""聊天(流式)"""
|
||||||
|
|
||||||
@@ -305,10 +317,17 @@ class AppChatService:
|
|||||||
for msg in messages
|
for msg in messages
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# 处理多模态文件
|
||||||
|
processed_files = None
|
||||||
|
if files:
|
||||||
|
multimodal_service = MultimodalService(self.db)
|
||||||
|
processed_files = await multimodal_service.process_files(files)
|
||||||
|
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||||
|
|
||||||
# 发送开始事件
|
# 发送开始事件
|
||||||
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"
|
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
# 流式调用 Agent
|
# 流式调用 Agent(支持多模态)
|
||||||
full_content = ""
|
full_content = ""
|
||||||
async for chunk in agent.chat_stream(
|
async for chunk in agent.chat_stream(
|
||||||
message=message,
|
message=message,
|
||||||
@@ -318,7 +337,8 @@ class AppChatService:
|
|||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
memory_flag=memory_flag
|
memory_flag=memory_flag,
|
||||||
|
files=processed_files # 传递处理后的文件
|
||||||
):
|
):
|
||||||
full_content += chunk
|
full_content += chunk
|
||||||
# 发送消息块事件
|
# 发送消息块事件
|
||||||
|
|||||||
@@ -19,11 +19,13 @@ from app.models import AgentConfig, ModelApiKey, ModelConfig
|
|||||||
from app.repositories.model_repository import ModelApiKeyRepository
|
from app.repositories.model_repository import ModelApiKeyRepository
|
||||||
from app.repositories.tool_repository import ToolRepository
|
from app.repositories.tool_repository import ToolRepository
|
||||||
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
|
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
|
||||||
|
from app.schemas.app_schema import FileInput
|
||||||
from app.services import task_service
|
from app.services import task_service
|
||||||
from app.services.langchain_tool_server import Search
|
from app.services.langchain_tool_server import Search
|
||||||
from app.services.memory_agent_service import MemoryAgentService
|
from app.services.memory_agent_service import MemoryAgentService
|
||||||
from app.services.model_parameter_merger import ModelParameterMerger
|
from app.services.model_parameter_merger import ModelParameterMerger
|
||||||
from app.services.tool_service import ToolService
|
from app.services.tool_service import ToolService
|
||||||
|
from app.services.multimodal_service import MultimodalService
|
||||||
from langchain.tools import tool
|
from langchain.tools import tool
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
@@ -246,7 +248,8 @@ class DraftRunService:
|
|||||||
user_rag_memory_id: Optional[str] = None,
|
user_rag_memory_id: Optional[str] = None,
|
||||||
web_search: bool = True,
|
web_search: bool = True,
|
||||||
memory: bool = True,
|
memory: bool = True,
|
||||||
sub_agent: bool = False
|
sub_agent: bool = False,
|
||||||
|
files: Optional[List[FileInput]] = None # 新增:多模态文件
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""执行试运行(使用 LangChain Agent)
|
"""执行试运行(使用 LangChain Agent)
|
||||||
|
|
||||||
@@ -406,7 +409,14 @@ class DraftRunService:
|
|||||||
max_history=agent_config.memory.get("max_history", 10)
|
max_history=agent_config.memory.get("max_history", 10)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 6. 知识库检索
|
# 6. 处理多模态文件
|
||||||
|
processed_files = None
|
||||||
|
if files:
|
||||||
|
multimodal_service = MultimodalService(self.db)
|
||||||
|
processed_files = await multimodal_service.process_files(files)
|
||||||
|
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||||
|
|
||||||
|
# 7. 知识库检索
|
||||||
context = None
|
context = None
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@@ -414,14 +424,15 @@ class DraftRunService:
|
|||||||
extra={
|
extra={
|
||||||
"model": api_key_config["model_name"],
|
"model": api_key_config["model_name"],
|
||||||
"has_history": bool(history),
|
"has_history": bool(history),
|
||||||
"has_context": bool(context)
|
"has_context": bool(context),
|
||||||
|
"has_files": bool(processed_files)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
memory_config_= agent_config.memory
|
memory_config_= agent_config.memory
|
||||||
config_id = memory_config_.get("memory_content") or memory_config_.get("memory_config",None)
|
config_id = memory_config_.get("memory_content") or memory_config_.get("memory_config",None)
|
||||||
|
|
||||||
# 7. 调用 Agent
|
# 8. 调用 Agent(支持多模态)
|
||||||
result = await agent.chat(
|
result = await agent.chat(
|
||||||
message=message,
|
message=message,
|
||||||
history=history,
|
history=history,
|
||||||
@@ -430,12 +441,13 @@ class DraftRunService:
|
|||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
memory_flag=memory_flag
|
memory_flag=memory_flag,
|
||||||
|
files=processed_files # 传递处理后的文件
|
||||||
)
|
)
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
# 8. 保存会话消息
|
# 9. 保存会话消息
|
||||||
if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"):
|
if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"):
|
||||||
await self._save_conversation_message(
|
await self._save_conversation_message(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
@@ -486,7 +498,8 @@ class DraftRunService:
|
|||||||
user_rag_memory_id: Optional[str] = None,
|
user_rag_memory_id: Optional[str] = None,
|
||||||
web_search: bool = True, # 布尔类型默认值
|
web_search: bool = True, # 布尔类型默认值
|
||||||
memory: bool = True, # 布尔类型默认值
|
memory: bool = True, # 布尔类型默认值
|
||||||
sub_agent: bool = False # 是否是作为子Agent运行
|
sub_agent: bool = False, # 是否是作为子Agent运行
|
||||||
|
files: Optional[List[FileInput]] = None # 新增:多模态文件
|
||||||
|
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""执行试运行(流式返回,使用 LangChain Agent)
|
"""执行试运行(流式返回,使用 LangChain Agent)
|
||||||
@@ -635,6 +648,13 @@ class DraftRunService:
|
|||||||
max_history=agent_config.memory.get("max_history", 10)
|
max_history=agent_config.memory.get("max_history", 10)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 6. 处理多模态文件
|
||||||
|
processed_files = None
|
||||||
|
if files:
|
||||||
|
multimodal_service = MultimodalService(self.db)
|
||||||
|
processed_files = await multimodal_service.process_files(files)
|
||||||
|
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||||
|
|
||||||
# 7. 知识库检索
|
# 7. 知识库检索
|
||||||
context = None
|
context = None
|
||||||
|
|
||||||
@@ -647,7 +667,7 @@ class DraftRunService:
|
|||||||
memory_config_ = agent_config.memory
|
memory_config_ = agent_config.memory
|
||||||
config_id = memory_config_.get("memory_content") or memory_config_.get("memory_config",None)
|
config_id = memory_config_.get("memory_content") or memory_config_.get("memory_config",None)
|
||||||
|
|
||||||
# 9. 流式调用 Agent
|
# 9. 流式调用 Agent(支持多模态)
|
||||||
full_content = ""
|
full_content = ""
|
||||||
async for chunk in agent.chat_stream(
|
async for chunk in agent.chat_stream(
|
||||||
message=message,
|
message=message,
|
||||||
@@ -657,7 +677,8 @@ class DraftRunService:
|
|||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
memory_flag=memory_flag
|
memory_flag=memory_flag,
|
||||||
|
files=processed_files # 传递处理后的文件
|
||||||
):
|
):
|
||||||
full_content += chunk
|
full_content += chunk
|
||||||
# 发送消息块事件
|
# 发送消息块事件
|
||||||
|
|||||||
268
api/app/services/multimodal_service.py
Normal file
268
api/app/services/multimodal_service.py
Normal file
@@ -0,0 +1,268 @@
|
|||||||
|
"""
|
||||||
|
多模态文件处理服务
|
||||||
|
|
||||||
|
处理图片、文档等多模态文件,转换为 LLM 可用的格式
|
||||||
|
|
||||||
|
格式说明:
|
||||||
|
- 当前使用通义千问格式
|
||||||
|
- 通义千问格式: {"type": "image", "image": "url"}
|
||||||
|
"""
|
||||||
|
import uuid
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.logging_config import get_business_logger
|
||||||
|
from app.core.exceptions import BusinessException
|
||||||
|
from app.core.error_codes import BizCode
|
||||||
|
from app.schemas.app_schema import FileInput, FileType, TransferMethod
|
||||||
|
from app.models.generic_file_model import GenericFile
|
||||||
|
|
||||||
|
logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class MultimodalService:
|
||||||
|
"""多模态文件处理服务"""
|
||||||
|
|
||||||
|
def __init__(self, db: Session):
|
||||||
|
self.db = db
|
||||||
|
|
||||||
|
async def process_files(
|
||||||
|
self,
|
||||||
|
files: Optional[List[FileInput]]
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
处理文件列表,返回 LLM 可用的格式
|
||||||
|
|
||||||
|
Args:
|
||||||
|
files: 文件输入列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Dict]: LLM 可用的内容格式列表
|
||||||
|
"""
|
||||||
|
if not files:
|
||||||
|
return []
|
||||||
|
|
||||||
|
result = []
|
||||||
|
for idx, file in enumerate(files):
|
||||||
|
try:
|
||||||
|
if file.type == FileType.IMAGE:
|
||||||
|
content = await self._process_image(file)
|
||||||
|
result.append(content)
|
||||||
|
elif file.type == FileType.DOCUMENT:
|
||||||
|
content = await self._process_document(file)
|
||||||
|
result.append(content)
|
||||||
|
elif file.type == FileType.AUDIO:
|
||||||
|
content = await self._process_audio(file)
|
||||||
|
result.append(content)
|
||||||
|
elif file.type == FileType.VIDEO:
|
||||||
|
content = await self._process_video(file)
|
||||||
|
result.append(content)
|
||||||
|
else:
|
||||||
|
logger.warning(f"不支持的文件类型: {file.type}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"处理文件失败",
|
||||||
|
extra={
|
||||||
|
"file_index": idx,
|
||||||
|
"file_type": file.type,
|
||||||
|
"error": str(e)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# 继续处理其他文件,不中断整个流程
|
||||||
|
result.append({
|
||||||
|
"type": "text",
|
||||||
|
"text": f"[文件处理失败: {str(e)}]"
|
||||||
|
})
|
||||||
|
|
||||||
|
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件")
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def _process_image(self, file: FileInput) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
处理图片文件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file: 图片文件输入
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: 通义千问格式
|
||||||
|
"""
|
||||||
|
if file.transfer_method == TransferMethod.REMOTE_URL:
|
||||||
|
# 远程 URL,使用通义千问格式
|
||||||
|
return {
|
||||||
|
"type": "image",
|
||||||
|
"image": file.url
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# 本地文件,获取访问 URL
|
||||||
|
url = await self._get_file_url(file.upload_file_id)
|
||||||
|
return {
|
||||||
|
"type": "image",
|
||||||
|
"image": url
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _process_document(self, file: FileInput) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
处理文档文件(PDF、Word 等)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file: 文档文件输入
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: text 格式的内容(包含提取的文本)
|
||||||
|
"""
|
||||||
|
if file.transfer_method == TransferMethod.REMOTE_URL:
|
||||||
|
# 远程文档暂不支持提取
|
||||||
|
return {
|
||||||
|
"type": "text",
|
||||||
|
"text": f"<document url=\"{file.url}\">\n[远程文档,暂不支持内容提取]\n</document>"
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# 本地文件,提取文本内容
|
||||||
|
text = await self._extract_document_text(file.upload_file_id)
|
||||||
|
generic_file = self.db.query(GenericFile).filter(
|
||||||
|
GenericFile.id == file.upload_file_id
|
||||||
|
).first()
|
||||||
|
|
||||||
|
file_name = generic_file.file_name if generic_file else "unknown"
|
||||||
|
|
||||||
|
return {
|
||||||
|
"type": "text",
|
||||||
|
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _process_audio(self, file: FileInput) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
处理音频文件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file: 音频文件输入
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: 音频内容(暂时返回占位符)
|
||||||
|
"""
|
||||||
|
# TODO: 实现音频转文字功能
|
||||||
|
return {
|
||||||
|
"type": "text",
|
||||||
|
"text": "[音频文件,暂不支持处理]"
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _process_video(self, file: FileInput) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
处理视频文件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file: 视频文件输入
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: 视频内容(暂时返回占位符)
|
||||||
|
"""
|
||||||
|
# TODO: 实现视频处理功能
|
||||||
|
return {
|
||||||
|
"type": "text",
|
||||||
|
"text": "[视频文件,暂不支持处理]"
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _get_file_url(self, file_id: uuid.UUID) -> str:
|
||||||
|
"""
|
||||||
|
获取文件的访问 URL
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_id: 文件ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 文件访问 URL
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
BusinessException: 文件不存在
|
||||||
|
"""
|
||||||
|
generic_file = self.db.query(GenericFile).filter(
|
||||||
|
GenericFile.id == file_id,
|
||||||
|
GenericFile.status == "active"
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not generic_file:
|
||||||
|
raise BusinessException(
|
||||||
|
f"文件不存在或已删除: {file_id}",
|
||||||
|
BizCode.NOT_FOUND
|
||||||
|
)
|
||||||
|
|
||||||
|
# 如果有 access_url,直接返回
|
||||||
|
if generic_file.access_url:
|
||||||
|
return generic_file.access_url
|
||||||
|
|
||||||
|
# 否则,根据 storage_path 生成 URL
|
||||||
|
# TODO: 根据实际存储方式生成 URL(本地存储、OSS 等)
|
||||||
|
# 这里暂时返回一个占位 URL
|
||||||
|
return f"/api/files/{file_id}/download"
|
||||||
|
|
||||||
|
async def _extract_document_text(self, file_id: uuid.UUID) -> str:
|
||||||
|
"""
|
||||||
|
提取文档文本内容
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_id: 文件ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 提取的文本内容
|
||||||
|
"""
|
||||||
|
generic_file = self.db.query(GenericFile).filter(
|
||||||
|
GenericFile.id == file_id,
|
||||||
|
GenericFile.status == "active"
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not generic_file:
|
||||||
|
raise BusinessException(
|
||||||
|
f"文件不存在或已删除: {file_id}",
|
||||||
|
BizCode.NOT_FOUND
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: 根据文件类型提取文本
|
||||||
|
# - PDF: 使用 PyPDF2 或 pdfplumber
|
||||||
|
# - Word: 使用 python-docx
|
||||||
|
# - TXT/MD: 直接读取
|
||||||
|
|
||||||
|
file_ext = generic_file.file_ext.lower()
|
||||||
|
|
||||||
|
if file_ext in ['.txt', '.md', '.markdown']:
|
||||||
|
return await self._read_text_file(generic_file.storage_path)
|
||||||
|
elif file_ext == '.pdf':
|
||||||
|
return await self._extract_pdf_text(generic_file.storage_path)
|
||||||
|
elif file_ext in ['.doc', '.docx']:
|
||||||
|
return await self._extract_word_text(generic_file.storage_path)
|
||||||
|
else:
|
||||||
|
return f"[不支持的文档格式: {file_ext}]"
|
||||||
|
|
||||||
|
async def _read_text_file(self, storage_path: str) -> str:
|
||||||
|
"""读取纯文本文件"""
|
||||||
|
try:
|
||||||
|
with open(storage_path, 'r', encoding='utf-8') as f:
|
||||||
|
return f.read()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"读取文本文件失败: {e}")
|
||||||
|
return f"[文件读取失败: {str(e)}]"
|
||||||
|
|
||||||
|
async def _extract_pdf_text(self, storage_path: str) -> str:
|
||||||
|
"""提取 PDF 文本"""
|
||||||
|
try:
|
||||||
|
# TODO: 实现 PDF 文本提取
|
||||||
|
# import PyPDF2 或 pdfplumber
|
||||||
|
return "[PDF 文本提取功能待实现]"
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"提取 PDF 文本失败: {e}")
|
||||||
|
return f"[PDF 提取失败: {str(e)}]"
|
||||||
|
|
||||||
|
async def _extract_word_text(self, storage_path: str) -> str:
|
||||||
|
"""提取 Word 文档文本"""
|
||||||
|
try:
|
||||||
|
# TODO: 实现 Word 文本提取
|
||||||
|
# import docx
|
||||||
|
return "[Word 文本提取功能待实现]"
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"提取 Word 文本失败: {e}")
|
||||||
|
return f"[Word 提取失败: {str(e)}]"
|
||||||
|
|
||||||
|
|
||||||
|
def get_multimodal_service(db: Session) -> MultimodalService:
|
||||||
|
"""获取多模态服务实例(依赖注入)"""
|
||||||
|
return MultimodalService(db)
|
||||||
3082
api/uv.lock
generated
3082
api/uv.lock
generated
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user