Merge branch 'feature/multimodal' into develop
This commit is contained in:
@@ -454,7 +454,8 @@ async def draft_run(
|
|||||||
user_id=payload.user_id or str(current_user.id),
|
user_id=payload.user_id or str(current_user.id),
|
||||||
variables=payload.variables,
|
variables=payload.variables,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
|
files=payload.files # 传递多模态文件
|
||||||
):
|
):
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
@@ -475,7 +476,8 @@ async def draft_run(
|
|||||||
"app_id": str(app_id),
|
"app_id": str(app_id),
|
||||||
"message_length": len(payload.message),
|
"message_length": len(payload.message),
|
||||||
"has_conversation_id": bool(payload.conversation_id),
|
"has_conversation_id": bool(payload.conversation_id),
|
||||||
"has_variables": bool(payload.variables)
|
"has_variables": bool(payload.variables),
|
||||||
|
"has_files": bool(payload.files)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -490,7 +492,8 @@ async def draft_run(
|
|||||||
user_id=payload.user_id or str(current_user.id),
|
user_id=payload.user_id or str(current_user.id),
|
||||||
variables=payload.variables,
|
variables=payload.variables,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
|
files=payload.files # 传递多模态文件
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
|||||||
@@ -438,7 +438,8 @@ async def chat(
|
|||||||
memory=payload.memory,
|
memory=payload.memory,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
workspace_id=workspace_id
|
workspace_id=workspace_id,
|
||||||
|
files=payload.files # 传递多模态文件
|
||||||
):
|
):
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
@@ -475,7 +476,8 @@ async def chat(
|
|||||||
memory=payload.memory,
|
memory=payload.memory,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
workspace_id=workspace_id
|
workspace_id=workspace_id,
|
||||||
|
files=payload.files # 传递多模态文件
|
||||||
)
|
)
|
||||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||||
elif app_type == AppType.MULTI_AGENT:
|
elif app_type == AppType.MULTI_AGENT:
|
||||||
|
|||||||
@@ -155,7 +155,8 @@ async def chat(
|
|||||||
memory=memory,
|
memory=memory,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
workspace_id=workspace_id
|
workspace_id=workspace_id,
|
||||||
|
files=payload.files # 传递多模态文件
|
||||||
):
|
):
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
@@ -180,7 +181,8 @@ async def chat(
|
|||||||
memory=memory,
|
memory=memory,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
workspace_id=workspace_id
|
workspace_id=workspace_id,
|
||||||
|
files=payload.files # 传递多模态文件
|
||||||
)
|
)
|
||||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||||
elif app_type == AppType.MULTI_AGENT:
|
elif app_type == AppType.MULTI_AGENT:
|
||||||
|
|||||||
@@ -46,7 +46,9 @@ class LangChainAgent:
|
|||||||
max_tokens: int = 2000,
|
max_tokens: int = 2000,
|
||||||
system_prompt: Optional[str] = None,
|
system_prompt: Optional[str] = None,
|
||||||
tools: Optional[Sequence[BaseTool]] = None,
|
tools: Optional[Sequence[BaseTool]] = None,
|
||||||
streaming: bool = False
|
streaming: bool = False,
|
||||||
|
max_iterations: Optional[int] = None, # 最大迭代次数(None 表示自动计算)
|
||||||
|
max_tool_consecutive_calls: int = 3 # 单个工具最大连续调用次数
|
||||||
):
|
):
|
||||||
"""初始化 LangChain Agent
|
"""初始化 LangChain Agent
|
||||||
|
|
||||||
@@ -59,13 +61,36 @@ class LangChainAgent:
|
|||||||
max_tokens: 最大 token 数
|
max_tokens: 最大 token 数
|
||||||
system_prompt: 系统提示词
|
system_prompt: 系统提示词
|
||||||
tools: 工具列表(可选,框架自动走 ReAct 循环)
|
tools: 工具列表(可选,框架自动走 ReAct 循环)
|
||||||
streaming: 是否启用流式输出(默认 True)
|
streaming: 是否启用流式输出
|
||||||
|
max_iterations: 最大迭代次数(None 表示自动计算:基础 5 次 + 每个工具 2 次)
|
||||||
|
max_tool_consecutive_calls: 单个工具最大连续调用次数(默认 3 次)
|
||||||
"""
|
"""
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.provider = provider
|
self.provider = provider
|
||||||
self.system_prompt = system_prompt or "你是一个专业的AI助手"
|
|
||||||
self.tools = tools or []
|
self.tools = tools or []
|
||||||
self.streaming = streaming
|
self.streaming = streaming
|
||||||
|
self.max_tool_consecutive_calls = max_tool_consecutive_calls
|
||||||
|
|
||||||
|
# 工具调用计数器:记录每个工具的连续调用次数
|
||||||
|
self.tool_call_counter: Dict[str, int] = {}
|
||||||
|
self.last_tool_called: Optional[str] = None
|
||||||
|
|
||||||
|
# 根据工具数量动态调整最大迭代次数
|
||||||
|
# 基础值 + 每个工具额外的调用机会
|
||||||
|
if max_iterations is None:
|
||||||
|
# 自动计算:基础 5 次 + 每个工具 2 次额外机会
|
||||||
|
self.max_iterations = 5 + len(self.tools) * 2
|
||||||
|
else:
|
||||||
|
self.max_iterations = max_iterations
|
||||||
|
|
||||||
|
self.system_prompt = system_prompt or "你是一个专业的AI助手"
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Agent 迭代次数配置: max_iterations={self.max_iterations}, "
|
||||||
|
f"tool_count={len(self.tools)}, "
|
||||||
|
f"max_tool_consecutive_calls={self.max_tool_consecutive_calls}, "
|
||||||
|
f"auto_calculated={max_iterations is None}"
|
||||||
|
)
|
||||||
|
|
||||||
# 创建 RedBearLLM(支持多提供商)
|
# 创建 RedBearLLM(支持多提供商)
|
||||||
model_config = RedBearModelConfig(
|
model_config = RedBearModelConfig(
|
||||||
@@ -89,11 +114,14 @@ class LangChainAgent:
|
|||||||
if streaming and hasattr(self._underlying_llm, 'streaming'):
|
if streaming and hasattr(self._underlying_llm, 'streaming'):
|
||||||
self._underlying_llm.streaming = True
|
self._underlying_llm.streaming = True
|
||||||
|
|
||||||
|
# 包装工具以跟踪连续调用次数
|
||||||
|
wrapped_tools = self._wrap_tools_with_tracking(self.tools) if self.tools else None
|
||||||
|
|
||||||
# 使用 create_agent 创建 agent graph(LangChain 1.x 标准方式)
|
# 使用 create_agent 创建 agent graph(LangChain 1.x 标准方式)
|
||||||
# 无论是否有工具,都使用 agent 统一处理
|
# 无论是否有工具,都使用 agent 统一处理
|
||||||
self.agent = create_agent(
|
self.agent = create_agent(
|
||||||
model=self.llm,
|
model=self.llm,
|
||||||
tools=self.tools if self.tools else None,
|
tools=wrapped_tools,
|
||||||
system_prompt=self.system_prompt
|
system_prompt=self.system_prompt
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -105,17 +133,91 @@ class LangChainAgent:
|
|||||||
"has_api_base": bool(api_base),
|
"has_api_base": bool(api_base),
|
||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
"streaming": streaming,
|
"streaming": streaming,
|
||||||
|
"max_iterations": self.max_iterations,
|
||||||
|
"max_tool_consecutive_calls": self.max_tool_consecutive_calls,
|
||||||
"tool_count": len(self.tools),
|
"tool_count": len(self.tools),
|
||||||
"tool_names": [tool.name for tool in self.tools] if self.tools else [],
|
"tool_names": [tool.name for tool in self.tools] if self.tools else [],
|
||||||
# "tool_count": len(self.tools)
|
# "tool_count": len(self.tools)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _wrap_tools_with_tracking(self, tools: Sequence[BaseTool]) -> List[BaseTool]:
|
||||||
|
"""包装工具以跟踪连续调用次数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tools: 原始工具列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[BaseTool]: 包装后的工具列表
|
||||||
|
"""
|
||||||
|
from langchain_core.tools import StructuredTool
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
wrapped_tools = []
|
||||||
|
|
||||||
|
for original_tool in tools:
|
||||||
|
tool_name = original_tool.name
|
||||||
|
original_func = original_tool.func if hasattr(original_tool, 'func') else None
|
||||||
|
|
||||||
|
if not original_func:
|
||||||
|
# 如果无法获取原始函数,直接使用原工具
|
||||||
|
wrapped_tools.append(original_tool)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 创建包装函数
|
||||||
|
def make_wrapped_func(tool_name, original_func):
|
||||||
|
"""创建包装函数的工厂函数,避免闭包问题"""
|
||||||
|
@wraps(original_func)
|
||||||
|
def wrapped_func(*args, **kwargs):
|
||||||
|
"""包装后的工具函数,跟踪连续调用次数"""
|
||||||
|
# 检查是否是连续调用同一个工具
|
||||||
|
if self.last_tool_called == tool_name:
|
||||||
|
self.tool_call_counter[tool_name] = self.tool_call_counter.get(tool_name, 0) + 1
|
||||||
|
else:
|
||||||
|
# 切换到新工具,重置计数器
|
||||||
|
self.tool_call_counter[tool_name] = 1
|
||||||
|
self.last_tool_called = tool_name
|
||||||
|
|
||||||
|
current_count = self.tool_call_counter[tool_name]
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"工具调用: {tool_name}, 连续调用次数: {current_count}/{self.max_tool_consecutive_calls}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查是否超过最大连续调用次数
|
||||||
|
if current_count > self.max_tool_consecutive_calls:
|
||||||
|
logger.warning(
|
||||||
|
f"工具 '{tool_name}' 连续调用次数已达上限 ({self.max_tool_consecutive_calls}),"
|
||||||
|
f"返回提示信息"
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
f"工具 '{tool_name}' 已连续调用 {self.max_tool_consecutive_calls} 次,"
|
||||||
|
f"未找到有效结果。请尝试其他方法或直接回答用户的问题。"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 调用原始工具函数
|
||||||
|
return original_func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapped_func
|
||||||
|
|
||||||
|
# 使用 StructuredTool 创建新工具
|
||||||
|
wrapped_tool = StructuredTool(
|
||||||
|
name=original_tool.name,
|
||||||
|
description=original_tool.description,
|
||||||
|
func=make_wrapped_func(tool_name, original_func),
|
||||||
|
args_schema=original_tool.args_schema if hasattr(original_tool, 'args_schema') else None
|
||||||
|
)
|
||||||
|
|
||||||
|
wrapped_tools.append(wrapped_tool)
|
||||||
|
|
||||||
|
return wrapped_tools
|
||||||
|
|
||||||
def _prepare_messages(
|
def _prepare_messages(
|
||||||
self,
|
self,
|
||||||
message: str,
|
message: str,
|
||||||
history: Optional[List[Dict[str, str]]] = None,
|
history: Optional[List[Dict[str, str]]] = None,
|
||||||
context: Optional[str] = None
|
context: Optional[str] = None,
|
||||||
|
files: Optional[List[Dict[str, Any]]] = None
|
||||||
) -> List[BaseMessage]:
|
) -> List[BaseMessage]:
|
||||||
"""准备消息列表
|
"""准备消息列表
|
||||||
|
|
||||||
@@ -123,6 +225,7 @@ class LangChainAgent:
|
|||||||
message: 用户消息
|
message: 用户消息
|
||||||
history: 历史消息列表
|
history: 历史消息列表
|
||||||
context: 上下文信息
|
context: 上下文信息
|
||||||
|
files: 多模态文件内容列表(已处理)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[BaseMessage]: 消息列表
|
List[BaseMessage]: 消息列表
|
||||||
@@ -145,7 +248,47 @@ class LangChainAgent:
|
|||||||
if context:
|
if context:
|
||||||
user_content = f"参考信息:\n{context}\n\n用户问题:\n{user_content}"
|
user_content = f"参考信息:\n{context}\n\n用户问题:\n{user_content}"
|
||||||
|
|
||||||
messages.append(HumanMessage(content=user_content))
|
# 构建用户消息(支持多模态)
|
||||||
|
if files and len(files) > 0:
|
||||||
|
content_parts = self._build_multimodal_content(user_content, files)
|
||||||
|
messages.append(HumanMessage(content=content_parts))
|
||||||
|
else:
|
||||||
|
# 纯文本消息
|
||||||
|
messages.append(HumanMessage(content=user_content))
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
|
def _build_multimodal_content(self, text: str, files: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
构建多模态消息内容
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 文本内容
|
||||||
|
files: 文件列表(已由 MultimodalService 处理为对应 provider 的格式)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Dict]: 消息内容列表
|
||||||
|
"""
|
||||||
|
# 根据 provider 使用不同的文本格式
|
||||||
|
if self.provider.lower() in ["bedrock", "anthropic"]:
|
||||||
|
# Anthropic/Bedrock: {"type": "text", "text": "..."}
|
||||||
|
content_parts = [{"type": "text", "text": text}]
|
||||||
|
else:
|
||||||
|
# 通义千问等: {"text": "..."}
|
||||||
|
content_parts = [{"text": text}]
|
||||||
|
|
||||||
|
# 添加文件内容
|
||||||
|
# MultimodalService 已经根据 provider 返回了正确格式,直接使用
|
||||||
|
content_parts.extend(files)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"构建多模态消息: provider={self.provider}, "
|
||||||
|
f"parts={len(content_parts)}, "
|
||||||
|
f"files={len(files)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return content_parts
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
async def term_memory_save(self,long_term_messages,actual_config_id,end_user_id,type):
|
async def term_memory_save(self,long_term_messages,actual_config_id,end_user_id,type):
|
||||||
@@ -242,7 +385,8 @@ class LangChainAgent:
|
|||||||
config_id: Optional[str] = None, # 添加这个参数
|
config_id: Optional[str] = None, # 添加这个参数
|
||||||
storage_type: Optional[str] = None,
|
storage_type: Optional[str] = None,
|
||||||
user_rag_memory_id: Optional[str] = None,
|
user_rag_memory_id: Optional[str] = None,
|
||||||
memory_flag: Optional[bool] = True
|
memory_flag: Optional[bool] = True,
|
||||||
|
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""执行对话
|
"""执行对话
|
||||||
|
|
||||||
@@ -277,8 +421,8 @@ class LangChainAgent:
|
|||||||
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||||
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||||
try:
|
try:
|
||||||
# 准备消息列表
|
# 准备消息列表(支持多模态)
|
||||||
messages = self._prepare_messages(message, history, context)
|
messages = self._prepare_messages(message, history, context, files)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"准备调用 LangChain Agent",
|
"准备调用 LangChain Agent",
|
||||||
@@ -286,23 +430,81 @@ class LangChainAgent:
|
|||||||
"has_context": bool(context),
|
"has_context": bool(context),
|
||||||
"has_history": bool(history),
|
"has_history": bool(history),
|
||||||
"has_tools": bool(self.tools),
|
"has_tools": bool(self.tools),
|
||||||
"message_count": len(messages)
|
"has_files": bool(files),
|
||||||
|
"message_count": len(messages),
|
||||||
|
"max_iterations": self.max_iterations
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# 统一使用 agent.invoke 调用
|
# 统一使用 agent.invoke 调用
|
||||||
result = await self.agent.ainvoke({"messages": messages})
|
# 通过 recursion_limit 限制最大迭代次数,防止工具调用死循环
|
||||||
|
try:
|
||||||
|
result = await self.agent.ainvoke(
|
||||||
|
{"messages": messages},
|
||||||
|
config={"recursion_limit": self.max_iterations}
|
||||||
|
)
|
||||||
|
except RecursionError as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Agent 达到最大迭代次数限制 ({self.max_iterations}),可能存在工具调用循环",
|
||||||
|
extra={"error": str(e)}
|
||||||
|
)
|
||||||
|
# 返回一个友好的错误提示
|
||||||
|
return {
|
||||||
|
"content": f"抱歉,我在处理您的请求时遇到了问题。已达到最大处理步骤限制({self.max_iterations}次)。请尝试简化您的问题或稍后再试。",
|
||||||
|
"model": self.model_name,
|
||||||
|
"elapsed_time": time.time() - start_time,
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 0,
|
||||||
|
"completion_tokens": 0,
|
||||||
|
"total_tokens": 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
# 获取最后的 AI 消息
|
# 获取最后的 AI 消息
|
||||||
output_messages = result.get("messages", [])
|
output_messages = result.get("messages", [])
|
||||||
content = ""
|
content = ""
|
||||||
|
|
||||||
|
logger.debug(f"输出消息数量: {len(output_messages)}")
|
||||||
total_tokens = 0
|
total_tokens = 0
|
||||||
for msg in reversed(output_messages):
|
for msg in reversed(output_messages):
|
||||||
if isinstance(msg, AIMessage):
|
if isinstance(msg, AIMessage):
|
||||||
content = msg.content
|
logger.debug(f"找到 AI 消息,content 类型: {type(msg.content)}")
|
||||||
|
logger.debug(f"AI 消息内容: {msg.content}")
|
||||||
|
|
||||||
|
# 处理多模态响应:content 可能是字符串或列表
|
||||||
|
if isinstance(msg.content, str):
|
||||||
|
content = msg.content
|
||||||
|
logger.debug(f"提取字符串内容,长度: {len(content)}")
|
||||||
|
elif isinstance(msg.content, list):
|
||||||
|
# 多模态响应:提取文本部分
|
||||||
|
logger.debug(f"多模态响应,列表长度: {len(msg.content)}")
|
||||||
|
text_parts = []
|
||||||
|
for item in msg.content:
|
||||||
|
logger.debug(f"处理项: {item}")
|
||||||
|
if isinstance(item, dict):
|
||||||
|
# 通义千问格式: {"text": "..."}
|
||||||
|
if "text" in item:
|
||||||
|
text = item.get("text", "")
|
||||||
|
text_parts.append(text)
|
||||||
|
logger.debug(f"提取文本: {text[:100]}...")
|
||||||
|
# OpenAI 格式: {"type": "text", "text": "..."}
|
||||||
|
elif item.get("type") == "text":
|
||||||
|
text = item.get("text", "")
|
||||||
|
text_parts.append(text)
|
||||||
|
logger.debug(f"提取文本: {text[:100]}...")
|
||||||
|
elif isinstance(item, str):
|
||||||
|
text_parts.append(item)
|
||||||
|
logger.debug(f"提取字符串: {item[:100]}...")
|
||||||
|
content = "".join(text_parts)
|
||||||
|
logger.debug(f"合并后内容长度: {len(content)}")
|
||||||
|
else:
|
||||||
|
content = str(msg.content)
|
||||||
|
logger.debug(f"转换为字符串: {content[:100]}...")
|
||||||
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
||||||
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0
|
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0
|
||||||
break
|
break
|
||||||
|
|
||||||
|
logger.info(f"最终提取的内容长度: {len(content)}")
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
if memory_flag:
|
if memory_flag:
|
||||||
@@ -345,7 +547,8 @@ class LangChainAgent:
|
|||||||
config_id: Optional[str] = None,
|
config_id: Optional[str] = None,
|
||||||
storage_type:Optional[str] = None,
|
storage_type:Optional[str] = None,
|
||||||
user_rag_memory_id:Optional[str] = None,
|
user_rag_memory_id:Optional[str] = None,
|
||||||
memory_flag: Optional[bool] = True
|
memory_flag: Optional[bool] = True,
|
||||||
|
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""执行流式对话
|
"""执行流式对话
|
||||||
|
|
||||||
@@ -382,11 +585,11 @@ class LangChainAgent:
|
|||||||
|
|
||||||
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
||||||
try:
|
try:
|
||||||
# 准备消息列表
|
# 准备消息列表(支持多模态)
|
||||||
messages = self._prepare_messages(message, history, context)
|
messages = self._prepare_messages(message, history, context, files)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"准备流式调用,has_tools={bool(self.tools)}, message_count={len(messages)}"
|
f"准备流式调用,has_tools={bool(self.tools)}, has_files={bool(files)}, message_count={len(messages)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
@@ -398,7 +601,8 @@ class LangChainAgent:
|
|||||||
try:
|
try:
|
||||||
async for event in self.agent.astream_events(
|
async for event in self.agent.astream_events(
|
||||||
{"messages": messages},
|
{"messages": messages},
|
||||||
version="v2"
|
version="v2",
|
||||||
|
config={"recursion_limit": self.max_iterations}
|
||||||
):
|
):
|
||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
kind = event.get("event")
|
kind = event.get("event")
|
||||||
@@ -407,20 +611,70 @@ class LangChainAgent:
|
|||||||
if kind == "on_chat_model_stream":
|
if kind == "on_chat_model_stream":
|
||||||
# LLM 流式输出
|
# LLM 流式输出
|
||||||
chunk = event.get("data", {}).get("chunk")
|
chunk = event.get("data", {}).get("chunk")
|
||||||
full_content+=chunk.content
|
if chunk and hasattr(chunk, "content"):
|
||||||
if chunk and hasattr(chunk, "content") and chunk.content:
|
# 处理多模态响应:content 可能是字符串或列表
|
||||||
yield chunk.content
|
chunk_content = chunk.content
|
||||||
yielded_content = True
|
if isinstance(chunk_content, str) and chunk_content:
|
||||||
|
full_content += chunk_content
|
||||||
|
yield chunk_content
|
||||||
|
yielded_content = True
|
||||||
|
elif isinstance(chunk_content, list):
|
||||||
|
# 多模态响应:提取文本部分
|
||||||
|
for item in chunk_content:
|
||||||
|
if isinstance(item, dict):
|
||||||
|
# 通义千问格式: {"text": "..."}
|
||||||
|
if "text" in item:
|
||||||
|
text = item.get("text", "")
|
||||||
|
if text:
|
||||||
|
full_content += text
|
||||||
|
yield text
|
||||||
|
yielded_content = True
|
||||||
|
# OpenAI 格式: {"type": "text", "text": "..."}
|
||||||
|
elif item.get("type") == "text":
|
||||||
|
text = item.get("text", "")
|
||||||
|
if text:
|
||||||
|
full_content += text
|
||||||
|
yield text
|
||||||
|
yielded_content = True
|
||||||
|
elif isinstance(item, str):
|
||||||
|
full_content += item
|
||||||
|
yield item
|
||||||
|
yielded_content = True
|
||||||
|
|
||||||
elif kind == "on_llm_stream":
|
elif kind == "on_llm_stream":
|
||||||
# 另一种 LLM 流式事件
|
# 另一种 LLM 流式事件
|
||||||
chunk = event.get("data", {}).get("chunk")
|
chunk = event.get("data", {}).get("chunk")
|
||||||
if chunk:
|
if chunk:
|
||||||
if hasattr(chunk, "content") and chunk.content:
|
if hasattr(chunk, "content"):
|
||||||
full_content+=chunk.content
|
chunk_content = chunk.content
|
||||||
yield chunk.content
|
if isinstance(chunk_content, str) and chunk_content:
|
||||||
yielded_content = True
|
full_content += chunk_content
|
||||||
|
yield chunk_content
|
||||||
|
yielded_content = True
|
||||||
|
elif isinstance(chunk_content, list):
|
||||||
|
# 多模态响应:提取文本部分
|
||||||
|
for item in chunk_content:
|
||||||
|
if isinstance(item, dict):
|
||||||
|
# 通义千问格式: {"text": "..."}
|
||||||
|
if "text" in item:
|
||||||
|
text = item.get("text", "")
|
||||||
|
if text:
|
||||||
|
full_content += text
|
||||||
|
yield text
|
||||||
|
yielded_content = True
|
||||||
|
# OpenAI 格式: {"type": "text", "text": "..."}
|
||||||
|
elif item.get("type") == "text":
|
||||||
|
text = item.get("text", "")
|
||||||
|
if text:
|
||||||
|
full_content += text
|
||||||
|
yield text
|
||||||
|
yielded_content = True
|
||||||
|
elif isinstance(item, str):
|
||||||
|
full_content += item
|
||||||
|
yield item
|
||||||
|
yielded_content = True
|
||||||
elif isinstance(chunk, str):
|
elif isinstance(chunk, str):
|
||||||
|
full_content += chunk
|
||||||
yield chunk
|
yield chunk
|
||||||
yielded_content = True
|
yielded_content = True
|
||||||
|
|
||||||
|
|||||||
@@ -81,6 +81,8 @@ class RedBearModelFactory:
|
|||||||
# api_key 格式: "access_key_id:secret_access_key" 或只是 access_key_id
|
# api_key 格式: "access_key_id:secret_access_key" 或只是 access_key_id
|
||||||
# region 从 base_url 或 extra_params 获取
|
# region 从 base_url 或 extra_params 获取
|
||||||
from botocore.config import Config as BotoConfig
|
from botocore.config import Config as BotoConfig
|
||||||
|
from app.core.models.bedrock_model_mapper import normalize_bedrock_model_id
|
||||||
|
|
||||||
max_pool_connections = int(os.getenv("BEDROCK_MAX_POOL_CONNECTIONS", "50"))
|
max_pool_connections = int(os.getenv("BEDROCK_MAX_POOL_CONNECTIONS", "50"))
|
||||||
max_retries = int(os.getenv("BEDROCK_MAX_RETRIES", "2"))
|
max_retries = int(os.getenv("BEDROCK_MAX_RETRIES", "2"))
|
||||||
# Configure with increased connection pool
|
# Configure with increased connection pool
|
||||||
@@ -89,8 +91,11 @@ class RedBearModelFactory:
|
|||||||
retries={'max_attempts': max_retries, 'mode': 'adaptive'}
|
retries={'max_attempts': max_retries, 'mode': 'adaptive'}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 标准化模型 ID(自动转换简化名称为完整 Bedrock Model ID)
|
||||||
|
model_id = normalize_bedrock_model_id(config.model_name)
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"model_id": config.model_name,
|
"model_id": model_id,
|
||||||
"config": boto_config,
|
"config": boto_config,
|
||||||
**config.extra_params
|
**config.extra_params
|
||||||
}
|
}
|
||||||
|
|||||||
188
api/app/core/models/bedrock_model_mapper.py
Normal file
188
api/app/core/models/bedrock_model_mapper.py
Normal file
@@ -0,0 +1,188 @@
|
|||||||
|
"""
|
||||||
|
AWS Bedrock 模型名称映射器
|
||||||
|
|
||||||
|
将简化的模型名称自动转换为正确的 Bedrock Model ID
|
||||||
|
"""
|
||||||
|
from typing import Optional
|
||||||
|
from app.core.logging_config import get_business_logger
|
||||||
|
|
||||||
|
logger = get_business_logger()
|
||||||
|
|
||||||
|
# Bedrock 模型名称映射表
|
||||||
|
BEDROCK_MODEL_MAPPING = {
|
||||||
|
# Claude 3.5 系列
|
||||||
|
"claude-3.5-sonnet": "anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||||
|
"claude-3-5-sonnet": "anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||||
|
"claude-sonnet-3.5": "anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||||
|
"claude-sonnet-3-5": "anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||||
|
|
||||||
|
# Claude 3 系列
|
||||||
|
"claude-3-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0",
|
||||||
|
"claude-3-haiku": "anthropic.claude-3-haiku-20240307-v1:0",
|
||||||
|
"claude-3-opus": "anthropic.claude-3-opus-20240229-v1:0",
|
||||||
|
"claude-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0",
|
||||||
|
"claude-haiku": "anthropic.claude-3-haiku-20240307-v1:0",
|
||||||
|
"claude-opus": "anthropic.claude-3-opus-20240229-v1:0",
|
||||||
|
|
||||||
|
# Claude 2 系列
|
||||||
|
"claude-2": "anthropic.claude-v2",
|
||||||
|
"claude-2.1": "anthropic.claude-v2:1",
|
||||||
|
"claude-instant": "anthropic.claude-instant-v1",
|
||||||
|
|
||||||
|
# Amazon Titan 系列
|
||||||
|
"titan-text-express": "amazon.titan-text-express-v1",
|
||||||
|
"titan-text-lite": "amazon.titan-text-lite-v1",
|
||||||
|
"titan-embed-text": "amazon.titan-embed-text-v1",
|
||||||
|
"titan-embed-image": "amazon.titan-embed-image-v1",
|
||||||
|
|
||||||
|
# Meta Llama 系列
|
||||||
|
"llama3-70b": "meta.llama3-70b-instruct-v1:0",
|
||||||
|
"llama3-8b": "meta.llama3-8b-instruct-v1:0",
|
||||||
|
"llama2-70b": "meta.llama2-70b-chat-v1",
|
||||||
|
"llama2-13b": "meta.llama2-13b-chat-v1",
|
||||||
|
|
||||||
|
# Mistral 系列
|
||||||
|
"mistral-7b": "mistral.mistral-7b-instruct-v0:2",
|
||||||
|
"mixtral-8x7b": "mistral.mixtral-8x7b-instruct-v0:1",
|
||||||
|
"mistral-large": "mistral.mistral-large-2402-v1:0",
|
||||||
|
|
||||||
|
# 常见错误格式的映射
|
||||||
|
"claude-sonnet-4-5": "anthropic.claude-3-5-sonnet-20240620-v1:0", # 常见错误
|
||||||
|
"claude-4-5-sonnet": "anthropic.claude-3-5-sonnet-20240620-v1:0", # 常见错误
|
||||||
|
"claude-sonnet-4.5": "anthropic.claude-3-5-sonnet-20240620-v1:0", # 常见错误
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_bedrock_model_id(model_name: str, region: Optional[str] = None) -> str:
|
||||||
|
"""
|
||||||
|
标准化 Bedrock 模型 ID
|
||||||
|
|
||||||
|
将简化的模型名称转换为正确的 Bedrock Model ID 格式
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: 模型名称(可能是简化格式或完整格式)
|
||||||
|
region: AWS 区域(可选,如 "us", "eu", "apac")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 标准化的 Bedrock Model ID
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> normalize_bedrock_model_id("claude-sonnet-4-5")
|
||||||
|
'anthropic.claude-3-5-sonnet-20240620-v1:0'
|
||||||
|
|
||||||
|
>>> normalize_bedrock_model_id("claude-3.5-sonnet", region="eu")
|
||||||
|
'eu.anthropic.claude-3-5-sonnet-20240620-v1:0'
|
||||||
|
|
||||||
|
>>> normalize_bedrock_model_id("anthropic.claude-3-5-sonnet-20240620-v1:0")
|
||||||
|
'anthropic.claude-3-5-sonnet-20240620-v1:0'
|
||||||
|
"""
|
||||||
|
# 如果已经是正确的格式(包含 provider),直接返回
|
||||||
|
if "." in model_name and not model_name.startswith(("us.", "eu.", "apac.", "sa.", "amer.", "global.", "us-gov.")):
|
||||||
|
# 检查是否是有效的 provider
|
||||||
|
provider = model_name.split(".", 1)[0]
|
||||||
|
valid_providers = ["anthropic", "amazon", "meta", "mistral", "deepseek", "openai", "ai21", "cohere", "stability"]
|
||||||
|
if provider in valid_providers:
|
||||||
|
logger.debug(f"Model ID 已经是正确格式: {model_name}")
|
||||||
|
return model_name
|
||||||
|
|
||||||
|
# 移除区域前缀(如果存在)
|
||||||
|
original_model_name = model_name
|
||||||
|
region_prefix = None
|
||||||
|
if model_name.startswith(("us.", "eu.", "apac.", "sa.", "amer.", "global.", "us-gov.")):
|
||||||
|
parts = model_name.split(".", 1)
|
||||||
|
region_prefix = parts[0]
|
||||||
|
model_name = parts[1] if len(parts) > 1 else model_name
|
||||||
|
|
||||||
|
# 转换为小写进行匹配
|
||||||
|
model_name_lower = model_name.lower()
|
||||||
|
|
||||||
|
# 尝试从映射表中查找
|
||||||
|
if model_name_lower in BEDROCK_MODEL_MAPPING:
|
||||||
|
mapped_id = BEDROCK_MODEL_MAPPING[model_name_lower]
|
||||||
|
logger.info(f"映射模型名称: {original_model_name} -> {mapped_id}")
|
||||||
|
|
||||||
|
# 如果指定了区域或原始名称包含区域前缀,添加区域前缀
|
||||||
|
if region:
|
||||||
|
mapped_id = f"{region}.{mapped_id}"
|
||||||
|
elif region_prefix:
|
||||||
|
mapped_id = f"{region_prefix}.{mapped_id}"
|
||||||
|
|
||||||
|
return mapped_id
|
||||||
|
|
||||||
|
# 如果没有找到映射,返回原始名称并记录警告
|
||||||
|
logger.warning(
|
||||||
|
f"未找到模型名称映射: {original_model_name}。"
|
||||||
|
f"请确保使用正确的 Bedrock Model ID 格式,如 'anthropic.claude-3-5-sonnet-20240620-v1:0'"
|
||||||
|
)
|
||||||
|
return original_model_name
|
||||||
|
|
||||||
|
|
||||||
|
def is_bedrock_model_id(model_name: str) -> bool:
|
||||||
|
"""
|
||||||
|
检查是否是 Bedrock Model ID 格式
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: 模型名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否是 Bedrock Model ID 格式
|
||||||
|
"""
|
||||||
|
# 移除区域前缀
|
||||||
|
if model_name.startswith(("us.", "eu.", "apac.", "sa.", "amer.", "global.", "us-gov.")):
|
||||||
|
model_name = model_name.split(".", 1)[1]
|
||||||
|
|
||||||
|
# 检查是否包含 provider
|
||||||
|
if "." not in model_name:
|
||||||
|
return False
|
||||||
|
|
||||||
|
provider = model_name.split(".", 1)[0]
|
||||||
|
valid_providers = ["anthropic", "amazon", "meta", "mistral", "deepseek", "openai", "ai21", "cohere", "stability"]
|
||||||
|
return provider in valid_providers
|
||||||
|
|
||||||
|
|
||||||
|
def get_provider_from_model_id(model_id: str) -> str:
|
||||||
|
"""
|
||||||
|
从 Bedrock Model ID 中提取 provider
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_id: Bedrock Model ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Provider 名称
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> get_provider_from_model_id("anthropic.claude-3-5-sonnet-20240620-v1:0")
|
||||||
|
'anthropic'
|
||||||
|
|
||||||
|
>>> get_provider_from_model_id("eu.anthropic.claude-3-5-sonnet-20240620-v1:0")
|
||||||
|
'anthropic'
|
||||||
|
"""
|
||||||
|
# 移除区域前缀
|
||||||
|
if model_id.startswith(("us.", "eu.", "apac.", "sa.", "amer.", "global.", "us-gov.")):
|
||||||
|
parts = model_id.split(".", 2)
|
||||||
|
return parts[1] if len(parts) > 1 else model_id.split(".", 1)[0]
|
||||||
|
|
||||||
|
return model_id.split(".", 1)[0]
|
||||||
|
|
||||||
|
|
||||||
|
# 添加更多映射的辅助函数
|
||||||
|
def add_model_mapping(short_name: str, full_model_id: str) -> None:
|
||||||
|
"""
|
||||||
|
添加自定义模型名称映射
|
||||||
|
|
||||||
|
Args:
|
||||||
|
short_name: 简化的模型名称
|
||||||
|
full_model_id: 完整的 Bedrock Model ID
|
||||||
|
"""
|
||||||
|
BEDROCK_MODEL_MAPPING[short_name.lower()] = full_model_id
|
||||||
|
logger.info(f"添加模型映射: {short_name} -> {full_model_id}")
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_mappings() -> dict:
|
||||||
|
"""
|
||||||
|
获取所有模型名称映射
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 模型名称映射字典
|
||||||
|
"""
|
||||||
|
return BEDROCK_MODEL_MAPPING.copy()
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
provider: bedrock
|
provider: bedrock
|
||||||
enabled: true
|
enabled: false
|
||||||
models:
|
models:
|
||||||
- name: ai21
|
- name: ai21
|
||||||
type: llm
|
type: llm
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
provider: dashscope
|
provider: dashscope
|
||||||
enabled: true
|
enabled: false
|
||||||
models:
|
models:
|
||||||
- name: deepseek-r1-distill-qwen-14b
|
- name: deepseek-r1-distill-qwen-14b
|
||||||
type: llm
|
type: llm
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
provider: openai
|
provider: openai
|
||||||
enabled: true
|
enabled: false
|
||||||
models:
|
models:
|
||||||
- name: chatgpt-4o-latest
|
- name: chatgpt-4o-latest
|
||||||
type: llm
|
type: llm
|
||||||
|
|||||||
@@ -26,6 +26,9 @@ from .app_schema import (
|
|||||||
MemoryConfig,
|
MemoryConfig,
|
||||||
ToolConfig,
|
ToolConfig,
|
||||||
VariableDefinition,
|
VariableDefinition,
|
||||||
|
FileInput,
|
||||||
|
FileType,
|
||||||
|
TransferMethod,
|
||||||
)
|
)
|
||||||
from .conversation_schema import (
|
from .conversation_schema import (
|
||||||
Conversation,
|
Conversation,
|
||||||
@@ -94,6 +97,9 @@ __all__ = [
|
|||||||
"MemoryConfig",
|
"MemoryConfig",
|
||||||
"ToolConfig",
|
"ToolConfig",
|
||||||
"VariableDefinition",
|
"VariableDefinition",
|
||||||
|
"FileInput",
|
||||||
|
"FileType",
|
||||||
|
"TransferMethod",
|
||||||
"Conversation",
|
"Conversation",
|
||||||
"ConversationCreate",
|
"ConversationCreate",
|
||||||
"ConversationWithMessages",
|
"ConversationWithMessages",
|
||||||
|
|||||||
@@ -1,10 +1,51 @@
|
|||||||
import datetime
|
import datetime
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Optional, Any, List, Dict, Union
|
from typing import Optional, Any, List, Dict, Union
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator
|
from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- Multimodal File Support ----------
|
||||||
|
|
||||||
|
class FileType(str, Enum):
|
||||||
|
"""文件类型枚举"""
|
||||||
|
IMAGE = "image"
|
||||||
|
DOCUMENT = "document"
|
||||||
|
AUDIO = "audio"
|
||||||
|
VIDEO = "video"
|
||||||
|
|
||||||
|
|
||||||
|
class TransferMethod(str, Enum):
|
||||||
|
"""文件传输方式枚举"""
|
||||||
|
LOCAL_FILE = "local_file" # 已上传到系统的文件
|
||||||
|
REMOTE_URL = "remote_url" # 外部URL
|
||||||
|
|
||||||
|
|
||||||
|
class FileInput(BaseModel):
|
||||||
|
"""文件输入 Schema"""
|
||||||
|
type: FileType = Field(..., description="文件类型: image/document/audio/video")
|
||||||
|
transfer_method: TransferMethod = Field(..., description="传输方式: local_file/remote_url")
|
||||||
|
upload_file_id: Optional[uuid.UUID] = Field(None, description="已上传文件ID(local_file时必填)")
|
||||||
|
url: Optional[str] = Field(None, description="远程URL(remote_url时必填)")
|
||||||
|
|
||||||
|
@field_validator("upload_file_id")
|
||||||
|
@classmethod
|
||||||
|
def validate_local_file(cls, v, info):
|
||||||
|
"""验证 local_file 时必须提供 upload_file_id"""
|
||||||
|
if info.data.get("transfer_method") == TransferMethod.LOCAL_FILE and not v:
|
||||||
|
raise ValueError("transfer_method 为 local_file 时,upload_file_id 不能为空")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@field_validator("url")
|
||||||
|
@classmethod
|
||||||
|
def validate_remote_url(cls, v, info):
|
||||||
|
"""验证 remote_url 时必须提供 url"""
|
||||||
|
if info.data.get("transfer_method") == TransferMethod.REMOTE_URL and not v:
|
||||||
|
raise ValueError("transfer_method 为 remote_url 时,url 不能为空")
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
# ---------- Input Schemas ----------
|
# ---------- Input Schemas ----------
|
||||||
|
|
||||||
class KnowledgeBaseConfig(BaseModel):
|
class KnowledgeBaseConfig(BaseModel):
|
||||||
@@ -360,6 +401,7 @@ class AppChatRequest(BaseModel):
|
|||||||
user_id: Optional[str] = Field(default=None, description="用户ID(用于会话管理)")
|
user_id: Optional[str] = Field(default=None, description="用户ID(用于会话管理)")
|
||||||
variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值")
|
variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值")
|
||||||
stream: bool = Field(default=False, description="是否流式返回")
|
stream: bool = Field(default=False, description="是否流式返回")
|
||||||
|
files: Optional[List[FileInput]] = Field(default=None, description="附件列表(支持多文件)")
|
||||||
|
|
||||||
|
|
||||||
class DraftRunRequest(BaseModel):
|
class DraftRunRequest(BaseModel):
|
||||||
@@ -369,6 +411,7 @@ class DraftRunRequest(BaseModel):
|
|||||||
user_id: Optional[str] = Field(default=None, description="用户ID(用于会话管理)")
|
user_id: Optional[str] = Field(default=None, description="用户ID(用于会话管理)")
|
||||||
variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值")
|
variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值")
|
||||||
stream: bool = Field(default=False, description="是否流式返回")
|
stream: bool = Field(default=False, description="是否流式返回")
|
||||||
|
files: Optional[List[FileInput]] = Field(default=None, description="附件列表(支持多文件)")
|
||||||
|
|
||||||
|
|
||||||
class DraftRunResponse(BaseModel):
|
class DraftRunResponse(BaseModel):
|
||||||
|
|||||||
@@ -4,6 +4,9 @@ import datetime
|
|||||||
from typing import Optional, Dict, Any, List
|
from typing import Optional, Dict, Any, List
|
||||||
from pydantic import BaseModel, Field, ConfigDict, field_serializer
|
from pydantic import BaseModel, Field, ConfigDict, field_serializer
|
||||||
|
|
||||||
|
# 导入 FileInput(用于体验运行)
|
||||||
|
from app.schemas.app_schema import FileInput
|
||||||
|
|
||||||
|
|
||||||
# ---------- Input Schemas ----------
|
# ---------- Input Schemas ----------
|
||||||
|
|
||||||
@@ -28,6 +31,7 @@ class ChatRequest(BaseModel):
|
|||||||
stream: bool = Field(default=False, description="是否流式返回")
|
stream: bool = Field(default=False, description="是否流式返回")
|
||||||
web_search: bool = Field(default=False, description="是否启用网络搜索")
|
web_search: bool = Field(default=False, description="是否启用网络搜索")
|
||||||
memory: bool = Field(default=True, description="是否启用记忆功能")
|
memory: bool = Field(default=True, description="是否启用记忆功能")
|
||||||
|
files: Optional[List[FileInput]] = Field(default=None, description="附件列表(支持多文件)")
|
||||||
|
|
||||||
|
|
||||||
# ---------- Output Schemas ----------
|
# ---------- Output Schemas ----------
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import asyncio
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Optional, Dict, Any, AsyncGenerator, Annotated
|
from typing import Optional, Dict, Any, AsyncGenerator, Annotated, List
|
||||||
|
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@@ -15,6 +15,7 @@ from app.core.logging_config import get_business_logger
|
|||||||
from app.db import get_db, get_db_context
|
from app.db import get_db, get_db_context
|
||||||
from app.models import MultiAgentConfig, AgentConfig, WorkflowConfig
|
from app.models import MultiAgentConfig, AgentConfig, WorkflowConfig
|
||||||
from app.schemas import DraftRunRequest
|
from app.schemas import DraftRunRequest
|
||||||
|
from app.schemas.app_schema import FileInput
|
||||||
from app.services.tool_service import ToolService
|
from app.services.tool_service import ToolService
|
||||||
from app.repositories.tool_repository import ToolRepository
|
from app.repositories.tool_repository import ToolRepository
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
@@ -26,6 +27,7 @@ from app.services.draft_run_service import create_web_search_tool
|
|||||||
from app.services.model_service import ModelApiKeyService
|
from app.services.model_service import ModelApiKeyService
|
||||||
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
|
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
|
||||||
from app.services.workflow_service import WorkflowService
|
from app.services.workflow_service import WorkflowService
|
||||||
|
from app.services.multimodal_service import MultimodalService
|
||||||
|
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
|
|
||||||
@@ -48,7 +50,8 @@ class AppChatService:
|
|||||||
memory: bool = True,
|
memory: bool = True,
|
||||||
storage_type: Optional[str] = None,
|
storage_type: Optional[str] = None,
|
||||||
user_rag_memory_id: Optional[str] = None,
|
user_rag_memory_id: Optional[str] = None,
|
||||||
workspace_id: Optional[str] = None
|
workspace_id: Optional[str] = None,
|
||||||
|
files: Optional[List[FileInput]] = None # 新增:多模态文件
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""聊天(非流式)"""
|
"""聊天(非流式)"""
|
||||||
|
|
||||||
@@ -155,7 +158,14 @@ class AppChatService:
|
|||||||
for msg in messages
|
for msg in messages
|
||||||
]
|
]
|
||||||
|
|
||||||
# 调用 Agent
|
# 处理多模态文件
|
||||||
|
processed_files = None
|
||||||
|
if files:
|
||||||
|
multimodal_service = MultimodalService(self.db)
|
||||||
|
processed_files = await multimodal_service.process_files(files)
|
||||||
|
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||||
|
|
||||||
|
# 调用 Agent(支持多模态)
|
||||||
result = await agent.chat(
|
result = await agent.chat(
|
||||||
message=message,
|
message=message,
|
||||||
history=history,
|
history=history,
|
||||||
@@ -164,7 +174,8 @@ class AppChatService:
|
|||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
memory_flag=memory_flag
|
memory_flag=memory_flag,
|
||||||
|
files=processed_files # 传递处理后的文件
|
||||||
)
|
)
|
||||||
|
|
||||||
# 保存消息
|
# 保存消息
|
||||||
@@ -206,6 +217,7 @@ class AppChatService:
|
|||||||
storage_type: Optional[str] = None,
|
storage_type: Optional[str] = None,
|
||||||
user_rag_memory_id: Optional[str] = None,
|
user_rag_memory_id: Optional[str] = None,
|
||||||
workspace_id: Optional[str] = None,
|
workspace_id: Optional[str] = None,
|
||||||
|
files: Optional[List[FileInput]] = None # 新增:多模态文件
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""聊天(流式)"""
|
"""聊天(流式)"""
|
||||||
|
|
||||||
@@ -312,10 +324,17 @@ class AppChatService:
|
|||||||
for msg in messages
|
for msg in messages
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# 处理多模态文件
|
||||||
|
processed_files = None
|
||||||
|
if files:
|
||||||
|
multimodal_service = MultimodalService(self.db)
|
||||||
|
processed_files = await multimodal_service.process_files(files)
|
||||||
|
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||||
|
|
||||||
# 发送开始事件
|
# 发送开始事件
|
||||||
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"
|
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
# 流式调用 Agent
|
# 流式调用 Agent(支持多模态)
|
||||||
full_content = ""
|
full_content = ""
|
||||||
total_tokens = 0
|
total_tokens = 0
|
||||||
async for chunk in agent.chat_stream(
|
async for chunk in agent.chat_stream(
|
||||||
@@ -326,7 +345,8 @@ class AppChatService:
|
|||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
memory_flag=memory_flag
|
memory_flag=memory_flag,
|
||||||
|
files=processed_files # 传递处理后的文件
|
||||||
):
|
):
|
||||||
if isinstance(chunk, int):
|
if isinstance(chunk, int):
|
||||||
total_tokens = chunk
|
total_tokens = chunk
|
||||||
|
|||||||
@@ -19,11 +19,13 @@ from app.models import AgentConfig, ModelApiKey, ModelConfig
|
|||||||
from app.repositories.model_repository import ModelApiKeyRepository
|
from app.repositories.model_repository import ModelApiKeyRepository
|
||||||
from app.repositories.tool_repository import ToolRepository
|
from app.repositories.tool_repository import ToolRepository
|
||||||
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
|
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
|
||||||
|
from app.schemas.app_schema import FileInput
|
||||||
from app.services import task_service
|
from app.services import task_service
|
||||||
from app.services.langchain_tool_server import Search
|
from app.services.langchain_tool_server import Search
|
||||||
from app.services.memory_agent_service import MemoryAgentService
|
from app.services.memory_agent_service import MemoryAgentService
|
||||||
from app.services.model_parameter_merger import ModelParameterMerger
|
from app.services.model_parameter_merger import ModelParameterMerger
|
||||||
from app.services.tool_service import ToolService
|
from app.services.tool_service import ToolService
|
||||||
|
from app.services.multimodal_service import MultimodalService
|
||||||
from langchain.tools import tool
|
from langchain.tools import tool
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
@@ -62,26 +64,23 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
|
|||||||
@tool(args_schema=LongTermMemoryInput)
|
@tool(args_schema=LongTermMemoryInput)
|
||||||
def long_term_memory(question: str) -> str:
|
def long_term_memory(question: str) -> str:
|
||||||
"""
|
"""
|
||||||
从用户的历史记忆中检索相关信息。这是一个强大的工具,可以帮助你了解用户的背景、偏好和历史对话内容。
|
从用户的历史记忆中检索相关信息。用于了解用户的背景、偏好和历史对话内容。
|
||||||
|
|
||||||
以下场景不需要使用此工具:
|
**何时使用此工具:**
|
||||||
1. 情绪/社交问候场景(如"你好"、"谢谢"、"再见"等简单寒暄)
|
- 用户明确询问历史信息(如"我之前说过什么"、"上次我们聊了什么")
|
||||||
2. 纯任务性场景(如"帮我写代码"、"翻译这段文字"等不需要历史上下文的任务)
|
- 用户询问个人信息或偏好(如"我喜欢什么"、"我的习惯是什么")
|
||||||
3. 处理外部内容时(如用户提供的文本、代码、RAG数据等,这些内容本身已经包含所需信息)
|
- 需要基于历史上下文提供个性化建议
|
||||||
|
|
||||||
除上述场景外的所有其他情况都应该使用此工具,特别是:
|
**何时不使用此工具:**
|
||||||
- 用户询问个人信息或历史对话内容
|
- 简单问候(如"你好"、"谢谢"、"再见")
|
||||||
- 需要了解用户偏好、习惯或背景
|
- 纯任务性请求(如"写代码"、"翻译文字"、"分析图片")
|
||||||
- 用户提到"之前"、"上次"、"记得"等涉及历史的词汇
|
- 用户已提供完整信息(如提供了文本、图片、文档等内容)
|
||||||
- 需要个性化回复或基于历史上下文的建议
|
- 创作性任务(如"写诗"、"编故事"、"创作谜语")
|
||||||
- 用户询问关于自己的任何信息
|
|
||||||
|
**重要:如果用户的问题可以直接回答,不要调用此工具。只在确实需要历史信息时才使用。**
|
||||||
|
|
||||||
需要对question改写/优化:
|
|
||||||
需要重点关注一以下几点
|
|
||||||
- 相关的关键词,保持原问题的核心语义不变, 根据上下文,使问题更具体、更清晰,将模糊的表达转换为明确的搜索词
|
|
||||||
- 使用同义词或相关术语扩展查询
|
|
||||||
Args:
|
Args:
|
||||||
question: question改写之后的内容
|
question: 需要检索的问题(保持原问题的核心语义,使用清晰的关键词)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
检索到的历史记忆内容
|
检索到的历史记忆内容
|
||||||
@@ -124,6 +123,10 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 检查是否有有效内容
|
||||||
|
if not memory_content or str(memory_content).strip() == "" or "answer" in str(memory_content) and str(memory_content).count("''") > 0:
|
||||||
|
return "未找到相关的历史记忆。请直接回答用户的问题,不要再次调用此工具。"
|
||||||
|
|
||||||
return f"检索到以下历史记忆:\n\n{memory_content}"
|
return f"检索到以下历史记忆:\n\n{memory_content}"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("长期记忆检索失败", extra={"error": str(e), "error_type": type(e).__name__})
|
logger.error("长期记忆检索失败", extra={"error": str(e), "error_type": type(e).__name__})
|
||||||
@@ -246,7 +249,8 @@ class DraftRunService:
|
|||||||
user_rag_memory_id: Optional[str] = None,
|
user_rag_memory_id: Optional[str] = None,
|
||||||
web_search: bool = True,
|
web_search: bool = True,
|
||||||
memory: bool = True,
|
memory: bool = True,
|
||||||
sub_agent: bool = False
|
sub_agent: bool = False,
|
||||||
|
files: Optional[List[FileInput]] = None # 新增:多模态文件
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""执行试运行(使用 LangChain Agent)
|
"""执行试运行(使用 LangChain Agent)
|
||||||
|
|
||||||
@@ -406,7 +410,16 @@ class DraftRunService:
|
|||||||
max_history=agent_config.memory.get("max_history", 10)
|
max_history=agent_config.memory.get("max_history", 10)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 6. 知识库检索
|
# 6. 处理多模态文件
|
||||||
|
processed_files = None
|
||||||
|
if files:
|
||||||
|
# 获取 provider 信息
|
||||||
|
provider = api_key_config.get("provider", "openai")
|
||||||
|
multimodal_service = MultimodalService(self.db, provider=provider)
|
||||||
|
processed_files = await multimodal_service.process_files(files)
|
||||||
|
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
|
||||||
|
|
||||||
|
# 7. 知识库检索
|
||||||
context = None
|
context = None
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@@ -414,14 +427,15 @@ class DraftRunService:
|
|||||||
extra={
|
extra={
|
||||||
"model": api_key_config["model_name"],
|
"model": api_key_config["model_name"],
|
||||||
"has_history": bool(history),
|
"has_history": bool(history),
|
||||||
"has_context": bool(context)
|
"has_context": bool(context),
|
||||||
|
"has_files": bool(processed_files)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
memory_config_= agent_config.memory
|
memory_config_= agent_config.memory
|
||||||
config_id = memory_config_.get("memory_content") or memory_config_.get("memory_config",None)
|
config_id = memory_config_.get("memory_content") or memory_config_.get("memory_config",None)
|
||||||
|
|
||||||
# 7. 调用 Agent
|
# 8. 调用 Agent(支持多模态)
|
||||||
result = await agent.chat(
|
result = await agent.chat(
|
||||||
message=message,
|
message=message,
|
||||||
history=history,
|
history=history,
|
||||||
@@ -430,12 +444,13 @@ class DraftRunService:
|
|||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
memory_flag=memory_flag
|
memory_flag=memory_flag,
|
||||||
|
files=processed_files # 传递处理后的文件
|
||||||
)
|
)
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
# 8. 保存会话消息
|
# 9. 保存会话消息
|
||||||
if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"):
|
if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"):
|
||||||
await self._save_conversation_message(
|
await self._save_conversation_message(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
@@ -493,7 +508,8 @@ class DraftRunService:
|
|||||||
user_rag_memory_id: Optional[str] = None,
|
user_rag_memory_id: Optional[str] = None,
|
||||||
web_search: bool = True, # 布尔类型默认值
|
web_search: bool = True, # 布尔类型默认值
|
||||||
memory: bool = True, # 布尔类型默认值
|
memory: bool = True, # 布尔类型默认值
|
||||||
sub_agent: bool = False # 是否是作为子Agent运行
|
sub_agent: bool = False, # 是否是作为子Agent运行
|
||||||
|
files: Optional[List[FileInput]] = None # 新增:多模态文件
|
||||||
|
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""执行试运行(流式返回,使用 LangChain Agent)
|
"""执行试运行(流式返回,使用 LangChain Agent)
|
||||||
@@ -642,6 +658,15 @@ class DraftRunService:
|
|||||||
max_history=agent_config.memory.get("max_history", 10)
|
max_history=agent_config.memory.get("max_history", 10)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 6. 处理多模态文件
|
||||||
|
processed_files = None
|
||||||
|
if files:
|
||||||
|
# 获取 provider 信息
|
||||||
|
provider = api_key_config.get("provider", "openai")
|
||||||
|
multimodal_service = MultimodalService(self.db, provider=provider)
|
||||||
|
processed_files = await multimodal_service.process_files(files)
|
||||||
|
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
|
||||||
|
|
||||||
# 7. 知识库检索
|
# 7. 知识库检索
|
||||||
context = None
|
context = None
|
||||||
|
|
||||||
@@ -654,7 +679,7 @@ class DraftRunService:
|
|||||||
memory_config_ = agent_config.memory
|
memory_config_ = agent_config.memory
|
||||||
config_id = memory_config_.get("memory_content") or memory_config_.get("memory_config",None)
|
config_id = memory_config_.get("memory_content") or memory_config_.get("memory_config",None)
|
||||||
|
|
||||||
# 9. 流式调用 Agent
|
# 9. 流式调用 Agent(支持多模态)
|
||||||
full_content = ""
|
full_content = ""
|
||||||
total_tokens = 0
|
total_tokens = 0
|
||||||
async for chunk in agent.chat_stream(
|
async for chunk in agent.chat_stream(
|
||||||
@@ -665,7 +690,8 @@ class DraftRunService:
|
|||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
memory_flag=memory_flag
|
memory_flag=memory_flag,
|
||||||
|
files=processed_files # 传递处理后的文件
|
||||||
):
|
):
|
||||||
if isinstance(chunk, int):
|
if isinstance(chunk, int):
|
||||||
total_tokens = chunk
|
total_tokens = chunk
|
||||||
|
|||||||
429
api/app/services/multimodal_service.py
Normal file
429
api/app/services/multimodal_service.py
Normal file
@@ -0,0 +1,429 @@
|
|||||||
|
"""
|
||||||
|
多模态文件处理服务
|
||||||
|
|
||||||
|
处理图片、文档等多模态文件,转换为 LLM 可用的格式
|
||||||
|
|
||||||
|
支持的 Provider:
|
||||||
|
- DashScope (通义千问): 支持 URL 格式
|
||||||
|
- Bedrock/Anthropic: 仅支持 base64 格式
|
||||||
|
- OpenAI: 支持 URL 和 base64 格式
|
||||||
|
"""
|
||||||
|
import uuid
|
||||||
|
from typing import List, Dict, Any, Optional, Protocol
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.logging_config import get_business_logger
|
||||||
|
from app.core.exceptions import BusinessException
|
||||||
|
from app.core.error_codes import BizCode
|
||||||
|
from app.schemas.app_schema import FileInput, FileType, TransferMethod
|
||||||
|
from app.models.generic_file_model import GenericFile
|
||||||
|
|
||||||
|
logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class ImageFormatStrategy(Protocol):
|
||||||
|
"""图片格式策略接口"""
|
||||||
|
|
||||||
|
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||||
|
"""将图片 URL 转换为特定 provider 的格式"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class DashScopeImageStrategy:
|
||||||
|
"""通义千问图片格式策略"""
|
||||||
|
|
||||||
|
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||||
|
"""通义千问格式: {"type": "image", "image": "url"}"""
|
||||||
|
return {
|
||||||
|
"type": "image",
|
||||||
|
"image": url
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockImageStrategy:
|
||||||
|
"""Bedrock/Anthropic 图片格式策略"""
|
||||||
|
|
||||||
|
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Bedrock/Anthropic 格式: base64 编码
|
||||||
|
{"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}}
|
||||||
|
"""
|
||||||
|
import httpx
|
||||||
|
import base64
|
||||||
|
from mimetypes import guess_type
|
||||||
|
|
||||||
|
logger.info(f"下载并编码图片: {url}")
|
||||||
|
|
||||||
|
# 下载图片
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
response = await client.get(url)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
# 获取图片数据
|
||||||
|
image_data = response.content
|
||||||
|
|
||||||
|
# 确定 media type
|
||||||
|
content_type = response.headers.get("content-type")
|
||||||
|
if content_type and content_type.startswith("image/"):
|
||||||
|
media_type = content_type
|
||||||
|
else:
|
||||||
|
guessed_type, _ = guess_type(url)
|
||||||
|
media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg"
|
||||||
|
|
||||||
|
# 转换为 base64
|
||||||
|
base64_data = base64.b64encode(image_data).decode("utf-8")
|
||||||
|
|
||||||
|
logger.info(f"图片编码完成: media_type={media_type}, size={len(base64_data)}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"type": "image",
|
||||||
|
"source": {
|
||||||
|
"type": "base64",
|
||||||
|
"media_type": media_type,
|
||||||
|
"data": base64_data
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIImageStrategy:
|
||||||
|
"""OpenAI 图片格式策略"""
|
||||||
|
|
||||||
|
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||||
|
"""OpenAI 格式: {"type": "image_url", "image_url": {"url": "..."}}"""
|
||||||
|
return {
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": url
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Provider 到策略的映射
|
||||||
|
PROVIDER_STRATEGIES = {
|
||||||
|
"dashscope": DashScopeImageStrategy,
|
||||||
|
"bedrock": BedrockImageStrategy,
|
||||||
|
"anthropic": BedrockImageStrategy,
|
||||||
|
"openai": OpenAIImageStrategy,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class MultimodalService:
|
||||||
|
"""多模态文件处理服务"""
|
||||||
|
|
||||||
|
def __init__(self, db: Session, provider: str = "dashscope"):
|
||||||
|
"""
|
||||||
|
初始化多模态服务
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库会话
|
||||||
|
provider: 模型提供商(dashscope, bedrock, anthropic 等)
|
||||||
|
"""
|
||||||
|
self.db = db
|
||||||
|
self.provider = provider.lower()
|
||||||
|
|
||||||
|
async def process_files(
|
||||||
|
self,
|
||||||
|
files: Optional[List[FileInput]]
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
处理文件列表,返回 LLM 可用的格式
|
||||||
|
|
||||||
|
Args:
|
||||||
|
files: 文件输入列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Dict]: LLM 可用的内容格式列表(根据 provider 返回不同格式)
|
||||||
|
"""
|
||||||
|
if not files:
|
||||||
|
return []
|
||||||
|
|
||||||
|
result = []
|
||||||
|
for idx, file in enumerate(files):
|
||||||
|
try:
|
||||||
|
if file.type == FileType.IMAGE:
|
||||||
|
content = await self._process_image(file)
|
||||||
|
result.append(content)
|
||||||
|
elif file.type == FileType.DOCUMENT:
|
||||||
|
content = await self._process_document(file)
|
||||||
|
result.append(content)
|
||||||
|
elif file.type == FileType.AUDIO:
|
||||||
|
content = await self._process_audio(file)
|
||||||
|
result.append(content)
|
||||||
|
elif file.type == FileType.VIDEO:
|
||||||
|
content = await self._process_video(file)
|
||||||
|
result.append(content)
|
||||||
|
else:
|
||||||
|
logger.warning(f"不支持的文件类型: {file.type}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"处理文件失败",
|
||||||
|
extra={
|
||||||
|
"file_index": idx,
|
||||||
|
"file_type": file.type,
|
||||||
|
"error": str(e)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# 继续处理其他文件,不中断整个流程
|
||||||
|
result.append({
|
||||||
|
"type": "text",
|
||||||
|
"text": f"[文件处理失败: {str(e)}]"
|
||||||
|
})
|
||||||
|
|
||||||
|
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def _process_image(self, file: FileInput) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
处理图片文件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file: 图片文件输入
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: 根据 provider 返回不同格式
|
||||||
|
- Anthropic/Bedrock: {"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}}
|
||||||
|
- 通义千问: {"type": "image", "image": "url"}
|
||||||
|
"""
|
||||||
|
if file.transfer_method == TransferMethod.REMOTE_URL:
|
||||||
|
url = file.url
|
||||||
|
else:
|
||||||
|
# 本地文件,获取访问 URL
|
||||||
|
url = await self._get_file_url(file.upload_file_id)
|
||||||
|
|
||||||
|
logger.debug(f"处理图片: {url}, provider={self.provider}")
|
||||||
|
|
||||||
|
# 根据 provider 返回不同格式
|
||||||
|
if self.provider in ["bedrock", "anthropic"]:
|
||||||
|
# Anthropic/Bedrock 只支持 base64 格式,需要下载并转换
|
||||||
|
try:
|
||||||
|
logger.info(f"开始下载并编码图片: {url}")
|
||||||
|
base64_data, media_type = await self._download_and_encode_image(url)
|
||||||
|
result = {
|
||||||
|
"type": "image",
|
||||||
|
"source": {
|
||||||
|
"type": "base64",
|
||||||
|
"media_type": media_type,
|
||||||
|
"data": base64_data[:100] + "..." # 只记录前100个字符
|
||||||
|
}
|
||||||
|
}
|
||||||
|
logger.info(f"图片编码完成: media_type={media_type}, data_length={len(base64_data)}")
|
||||||
|
# 返回完整数据
|
||||||
|
result["source"]["data"] = base64_data
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"下载并编码图片失败: {e}", exc_info=True)
|
||||||
|
# 返回错误提示
|
||||||
|
return {
|
||||||
|
"type": "text",
|
||||||
|
"text": f"[图片加载失败: {str(e)}]"
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# 通义千问等其他格式支持 URL
|
||||||
|
return {
|
||||||
|
"type": "image",
|
||||||
|
"image": url
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _download_and_encode_image(self, url: str) -> tuple[str, str]:
|
||||||
|
"""
|
||||||
|
下载图片并转换为 base64
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: 图片 URL
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (base64_data, media_type)
|
||||||
|
"""
|
||||||
|
import httpx
|
||||||
|
import base64
|
||||||
|
from mimetypes import guess_type
|
||||||
|
|
||||||
|
# 下载图片
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
response = await client.get(url)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
# 获取图片数据
|
||||||
|
image_data = response.content
|
||||||
|
|
||||||
|
# 确定 media type
|
||||||
|
content_type = response.headers.get("content-type")
|
||||||
|
if content_type and content_type.startswith("image/"):
|
||||||
|
media_type = content_type
|
||||||
|
else:
|
||||||
|
# 从 URL 推断
|
||||||
|
guessed_type, _ = guess_type(url)
|
||||||
|
media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg"
|
||||||
|
|
||||||
|
# 转换为 base64
|
||||||
|
base64_data = base64.b64encode(image_data).decode("utf-8")
|
||||||
|
|
||||||
|
logger.debug(f"图片编码完成: media_type={media_type}, size={len(base64_data)}")
|
||||||
|
|
||||||
|
return base64_data, media_type
|
||||||
|
|
||||||
|
async def _process_document(self, file: FileInput) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
处理文档文件(PDF、Word 等)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file: 文档文件输入
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: text 格式的内容(包含提取的文本)
|
||||||
|
"""
|
||||||
|
if file.transfer_method == TransferMethod.REMOTE_URL:
|
||||||
|
# 远程文档暂不支持提取
|
||||||
|
return {
|
||||||
|
"type": "text",
|
||||||
|
"text": f"<document url=\"{file.url}\">\n[远程文档,暂不支持内容提取]\n</document>"
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# 本地文件,提取文本内容
|
||||||
|
text = await self._extract_document_text(file.upload_file_id)
|
||||||
|
generic_file = self.db.query(GenericFile).filter(
|
||||||
|
GenericFile.id == file.upload_file_id
|
||||||
|
).first()
|
||||||
|
|
||||||
|
file_name = generic_file.file_name if generic_file else "unknown"
|
||||||
|
|
||||||
|
return {
|
||||||
|
"type": "text",
|
||||||
|
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _process_audio(self, file: FileInput) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
处理音频文件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file: 音频文件输入
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: 音频内容(暂时返回占位符)
|
||||||
|
"""
|
||||||
|
# TODO: 实现音频转文字功能
|
||||||
|
return {
|
||||||
|
"type": "text",
|
||||||
|
"text": "[音频文件,暂不支持处理]"
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _process_video(self, file: FileInput) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
处理视频文件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file: 视频文件输入
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: 视频内容(暂时返回占位符)
|
||||||
|
"""
|
||||||
|
# TODO: 实现视频处理功能
|
||||||
|
return {
|
||||||
|
"type": "text",
|
||||||
|
"text": "[视频文件,暂不支持处理]"
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _get_file_url(self, file_id: uuid.UUID) -> str:
|
||||||
|
"""
|
||||||
|
获取文件的访问 URL
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_id: 文件ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 文件访问 URL
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
BusinessException: 文件不存在
|
||||||
|
"""
|
||||||
|
generic_file = self.db.query(GenericFile).filter(
|
||||||
|
GenericFile.id == file_id,
|
||||||
|
GenericFile.status == "active"
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not generic_file:
|
||||||
|
raise BusinessException(
|
||||||
|
f"文件不存在或已删除: {file_id}",
|
||||||
|
BizCode.NOT_FOUND
|
||||||
|
)
|
||||||
|
|
||||||
|
# 如果有 access_url,直接返回
|
||||||
|
if generic_file.access_url:
|
||||||
|
return generic_file.access_url
|
||||||
|
|
||||||
|
# 否则,根据 storage_path 生成 URL
|
||||||
|
# TODO: 根据实际存储方式生成 URL(本地存储、OSS 等)
|
||||||
|
# 这里暂时返回一个占位 URL
|
||||||
|
return f"/api/files/{file_id}/download"
|
||||||
|
|
||||||
|
async def _extract_document_text(self, file_id: uuid.UUID) -> str:
|
||||||
|
"""
|
||||||
|
提取文档文本内容
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_id: 文件ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 提取的文本内容
|
||||||
|
"""
|
||||||
|
generic_file = self.db.query(GenericFile).filter(
|
||||||
|
GenericFile.id == file_id,
|
||||||
|
GenericFile.status == "active"
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not generic_file:
|
||||||
|
raise BusinessException(
|
||||||
|
f"文件不存在或已删除: {file_id}",
|
||||||
|
BizCode.NOT_FOUND
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: 根据文件类型提取文本
|
||||||
|
# - PDF: 使用 PyPDF2 或 pdfplumber
|
||||||
|
# - Word: 使用 python-docx
|
||||||
|
# - TXT/MD: 直接读取
|
||||||
|
|
||||||
|
file_ext = generic_file.file_ext.lower()
|
||||||
|
|
||||||
|
if file_ext in ['.txt', '.md', '.markdown']:
|
||||||
|
return await self._read_text_file(generic_file.storage_path)
|
||||||
|
elif file_ext == '.pdf':
|
||||||
|
return await self._extract_pdf_text(generic_file.storage_path)
|
||||||
|
elif file_ext in ['.doc', '.docx']:
|
||||||
|
return await self._extract_word_text(generic_file.storage_path)
|
||||||
|
else:
|
||||||
|
return f"[不支持的文档格式: {file_ext}]"
|
||||||
|
|
||||||
|
async def _read_text_file(self, storage_path: str) -> str:
|
||||||
|
"""读取纯文本文件"""
|
||||||
|
try:
|
||||||
|
with open(storage_path, 'r', encoding='utf-8') as f:
|
||||||
|
return f.read()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"读取文本文件失败: {e}")
|
||||||
|
return f"[文件读取失败: {str(e)}]"
|
||||||
|
|
||||||
|
async def _extract_pdf_text(self, storage_path: str) -> str:
|
||||||
|
"""提取 PDF 文本"""
|
||||||
|
try:
|
||||||
|
# TODO: 实现 PDF 文本提取
|
||||||
|
# import PyPDF2 或 pdfplumber
|
||||||
|
return "[PDF 文本提取功能待实现]"
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"提取 PDF 文本失败: {e}")
|
||||||
|
return f"[PDF 提取失败: {str(e)}]"
|
||||||
|
|
||||||
|
async def _extract_word_text(self, storage_path: str) -> str:
|
||||||
|
"""提取 Word 文档文本"""
|
||||||
|
try:
|
||||||
|
# TODO: 实现 Word 文本提取
|
||||||
|
# import docx
|
||||||
|
return "[Word 文本提取功能待实现]"
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"提取 Word 文本失败: {e}")
|
||||||
|
return f"[Word 提取失败: {str(e)}]"
|
||||||
|
|
||||||
|
|
||||||
|
def get_multimodal_service(db: Session) -> MultimodalService:
|
||||||
|
"""获取多模态服务实例(依赖注入)"""
|
||||||
|
return MultimodalService(db)
|
||||||
3082
api/uv.lock
generated
3082
api/uv.lock
generated
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user