[modify] agent call tools strategy
This commit is contained in:
@@ -46,7 +46,8 @@ class LangChainAgent:
|
|||||||
system_prompt: Optional[str] = None,
|
system_prompt: Optional[str] = None,
|
||||||
tools: Optional[Sequence[BaseTool]] = None,
|
tools: Optional[Sequence[BaseTool]] = None,
|
||||||
streaming: bool = False,
|
streaming: bool = False,
|
||||||
max_iterations: int = 3 # 新增:最大迭代次数
|
max_iterations: Optional[int] = None, # 最大迭代次数(None 表示自动计算)
|
||||||
|
max_tool_consecutive_calls: int = 3 # 单个工具最大连续调用次数
|
||||||
):
|
):
|
||||||
"""初始化 LangChain Agent
|
"""初始化 LangChain Agent
|
||||||
|
|
||||||
@@ -59,15 +60,36 @@ class LangChainAgent:
|
|||||||
max_tokens: 最大 token 数
|
max_tokens: 最大 token 数
|
||||||
system_prompt: 系统提示词
|
system_prompt: 系统提示词
|
||||||
tools: 工具列表(可选,框架自动走 ReAct 循环)
|
tools: 工具列表(可选,框架自动走 ReAct 循环)
|
||||||
streaming: 是否启用流式输出(默认 True)
|
streaming: 是否启用流式输出
|
||||||
max_iterations: 最大迭代次数(默认 3 次,防止工具调用死循环)
|
max_iterations: 最大迭代次数(None 表示自动计算:基础 5 次 + 每个工具 2 次)
|
||||||
|
max_tool_consecutive_calls: 单个工具最大连续调用次数(默认 3 次)
|
||||||
"""
|
"""
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.provider = provider
|
self.provider = provider
|
||||||
self.system_prompt = system_prompt or "你是一个专业的AI助手"
|
|
||||||
self.tools = tools or []
|
self.tools = tools or []
|
||||||
self.streaming = streaming
|
self.streaming = streaming
|
||||||
self.max_iterations = max_iterations
|
self.max_tool_consecutive_calls = max_tool_consecutive_calls
|
||||||
|
|
||||||
|
# 工具调用计数器:记录每个工具的连续调用次数
|
||||||
|
self.tool_call_counter: Dict[str, int] = {}
|
||||||
|
self.last_tool_called: Optional[str] = None
|
||||||
|
|
||||||
|
# 根据工具数量动态调整最大迭代次数
|
||||||
|
# 基础值 + 每个工具额外的调用机会
|
||||||
|
if max_iterations is None:
|
||||||
|
# 自动计算:基础 5 次 + 每个工具 2 次额外机会
|
||||||
|
self.max_iterations = 5 + len(self.tools) * 2
|
||||||
|
else:
|
||||||
|
self.max_iterations = max_iterations
|
||||||
|
|
||||||
|
self.system_prompt = system_prompt or "你是一个专业的AI助手"
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Agent 迭代次数配置: max_iterations={self.max_iterations}, "
|
||||||
|
f"tool_count={len(self.tools)}, "
|
||||||
|
f"max_tool_consecutive_calls={self.max_tool_consecutive_calls}, "
|
||||||
|
f"auto_calculated={max_iterations is None}"
|
||||||
|
)
|
||||||
|
|
||||||
# 创建 RedBearLLM(支持多提供商)
|
# 创建 RedBearLLM(支持多提供商)
|
||||||
model_config = RedBearModelConfig(
|
model_config = RedBearModelConfig(
|
||||||
@@ -91,11 +113,14 @@ class LangChainAgent:
|
|||||||
if streaming and hasattr(self._underlying_llm, 'streaming'):
|
if streaming and hasattr(self._underlying_llm, 'streaming'):
|
||||||
self._underlying_llm.streaming = True
|
self._underlying_llm.streaming = True
|
||||||
|
|
||||||
|
# 包装工具以跟踪连续调用次数
|
||||||
|
wrapped_tools = self._wrap_tools_with_tracking(self.tools) if self.tools else None
|
||||||
|
|
||||||
# 使用 create_agent 创建 agent graph(LangChain 1.x 标准方式)
|
# 使用 create_agent 创建 agent graph(LangChain 1.x 标准方式)
|
||||||
# 无论是否有工具,都使用 agent 统一处理
|
# 无论是否有工具,都使用 agent 统一处理
|
||||||
self.agent = create_agent(
|
self.agent = create_agent(
|
||||||
model=self.llm,
|
model=self.llm,
|
||||||
tools=self.tools if self.tools else None,
|
tools=wrapped_tools,
|
||||||
system_prompt=self.system_prompt
|
system_prompt=self.system_prompt
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -107,12 +132,84 @@ class LangChainAgent:
|
|||||||
"has_api_base": bool(api_base),
|
"has_api_base": bool(api_base),
|
||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
"streaming": streaming,
|
"streaming": streaming,
|
||||||
"max_iterations": max_iterations,
|
"max_iterations": self.max_iterations,
|
||||||
|
"max_tool_consecutive_calls": self.max_tool_consecutive_calls,
|
||||||
"tool_count": len(self.tools),
|
"tool_count": len(self.tools),
|
||||||
"tool_names": [tool.name for tool in self.tools] if self.tools else []
|
"tool_names": [tool.name for tool in self.tools] if self.tools else []
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _wrap_tools_with_tracking(self, tools: Sequence[BaseTool]) -> List[BaseTool]:
|
||||||
|
"""包装工具以跟踪连续调用次数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tools: 原始工具列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[BaseTool]: 包装后的工具列表
|
||||||
|
"""
|
||||||
|
from langchain_core.tools import StructuredTool
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
wrapped_tools = []
|
||||||
|
|
||||||
|
for original_tool in tools:
|
||||||
|
tool_name = original_tool.name
|
||||||
|
original_func = original_tool.func if hasattr(original_tool, 'func') else None
|
||||||
|
|
||||||
|
if not original_func:
|
||||||
|
# 如果无法获取原始函数,直接使用原工具
|
||||||
|
wrapped_tools.append(original_tool)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 创建包装函数
|
||||||
|
def make_wrapped_func(tool_name, original_func):
|
||||||
|
"""创建包装函数的工厂函数,避免闭包问题"""
|
||||||
|
@wraps(original_func)
|
||||||
|
def wrapped_func(*args, **kwargs):
|
||||||
|
"""包装后的工具函数,跟踪连续调用次数"""
|
||||||
|
# 检查是否是连续调用同一个工具
|
||||||
|
if self.last_tool_called == tool_name:
|
||||||
|
self.tool_call_counter[tool_name] = self.tool_call_counter.get(tool_name, 0) + 1
|
||||||
|
else:
|
||||||
|
# 切换到新工具,重置计数器
|
||||||
|
self.tool_call_counter[tool_name] = 1
|
||||||
|
self.last_tool_called = tool_name
|
||||||
|
|
||||||
|
current_count = self.tool_call_counter[tool_name]
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"工具调用: {tool_name}, 连续调用次数: {current_count}/{self.max_tool_consecutive_calls}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查是否超过最大连续调用次数
|
||||||
|
if current_count > self.max_tool_consecutive_calls:
|
||||||
|
logger.warning(
|
||||||
|
f"工具 '{tool_name}' 连续调用次数已达上限 ({self.max_tool_consecutive_calls}),"
|
||||||
|
f"返回提示信息"
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
f"工具 '{tool_name}' 已连续调用 {self.max_tool_consecutive_calls} 次,"
|
||||||
|
f"未找到有效结果。请尝试其他方法或直接回答用户的问题。"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 调用原始工具函数
|
||||||
|
return original_func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapped_func
|
||||||
|
|
||||||
|
# 使用 StructuredTool 创建新工具
|
||||||
|
wrapped_tool = StructuredTool(
|
||||||
|
name=original_tool.name,
|
||||||
|
description=original_tool.description,
|
||||||
|
func=make_wrapped_func(tool_name, original_func),
|
||||||
|
args_schema=original_tool.args_schema if hasattr(original_tool, 'args_schema') else None
|
||||||
|
)
|
||||||
|
|
||||||
|
wrapped_tools.append(wrapped_tool)
|
||||||
|
|
||||||
|
return wrapped_tools
|
||||||
|
|
||||||
def _prepare_messages(
|
def _prepare_messages(
|
||||||
self,
|
self,
|
||||||
message: str,
|
message: str,
|
||||||
@@ -149,11 +246,22 @@ class LangChainAgent:
|
|||||||
if context:
|
if context:
|
||||||
user_content = f"参考信息:\n{context}\n\n用户问题:\n{user_content}"
|
user_content = f"参考信息:\n{context}\n\n用户问题:\n{user_content}"
|
||||||
|
|
||||||
# 如果有文件,构建多模态消息
|
# 如果有文件,构建多模态消息(使用通义千问原生格式)
|
||||||
if files and len(files) > 0:
|
if files and len(files) > 0:
|
||||||
# 多模态消息格式: [{"type": "text", "text": "..."}, {"type": "image_url", ...}]
|
# 通义千问多模态格式: [{"text": "..."}, {"image": "url"}]
|
||||||
content_parts = [{"type": "text", "text": user_content}]
|
# 注意:不使用 LangChain 的标准格式,因为它会转换为 OpenAI 格式
|
||||||
content_parts.extend(files)
|
content_parts = [{"text": user_content}]
|
||||||
|
|
||||||
|
# 添加文件内容(已经是通义千问格式)
|
||||||
|
for file_item in files:
|
||||||
|
if file_item.get("type") == "image":
|
||||||
|
# 通义千问图片格式: {"image": "url"}
|
||||||
|
content_parts.append({"image": file_item["image"]})
|
||||||
|
elif file_item.get("type") == "text":
|
||||||
|
# 文本内容
|
||||||
|
content_parts.append({"text": file_item["text"]})
|
||||||
|
|
||||||
|
logger.debug(f"构建多模态消息,content_parts: {content_parts}")
|
||||||
messages.append(HumanMessage(content=content_parts))
|
messages.append(HumanMessage(content=content_parts))
|
||||||
else:
|
else:
|
||||||
# 纯文本消息(向后兼容)
|
# 纯文本消息(向后兼容)
|
||||||
|
|||||||
@@ -64,26 +64,23 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
|
|||||||
@tool(args_schema=LongTermMemoryInput)
|
@tool(args_schema=LongTermMemoryInput)
|
||||||
def long_term_memory(question: str) -> str:
|
def long_term_memory(question: str) -> str:
|
||||||
"""
|
"""
|
||||||
从用户的历史记忆中检索相关信息。这是一个强大的工具,可以帮助你了解用户的背景、偏好和历史对话内容。
|
从用户的历史记忆中检索相关信息。用于了解用户的背景、偏好和历史对话内容。
|
||||||
|
|
||||||
以下场景不需要使用此工具:
|
**何时使用此工具:**
|
||||||
1. 情绪/社交问候场景(如"你好"、"谢谢"、"再见"等简单寒暄)
|
- 用户明确询问历史信息(如"我之前说过什么"、"上次我们聊了什么")
|
||||||
2. 纯任务性场景(如"帮我写代码"、"翻译这段文字"等不需要历史上下文的任务)
|
- 用户询问个人信息或偏好(如"我喜欢什么"、"我的习惯是什么")
|
||||||
3. 处理外部内容时(如用户提供的文本、代码、RAG数据等,这些内容本身已经包含所需信息)
|
- 需要基于历史上下文提供个性化建议
|
||||||
|
|
||||||
除上述场景外的所有其他情况都应该使用此工具,特别是:
|
**何时不使用此工具:**
|
||||||
- 用户询问个人信息或历史对话内容
|
- 简单问候(如"你好"、"谢谢"、"再见")
|
||||||
- 需要了解用户偏好、习惯或背景
|
- 纯任务性请求(如"写代码"、"翻译文字"、"分析图片")
|
||||||
- 用户提到"之前"、"上次"、"记得"等涉及历史的词汇
|
- 用户已提供完整信息(如提供了文本、图片、文档等内容)
|
||||||
- 需要个性化回复或基于历史上下文的建议
|
- 创作性任务(如"写诗"、"编故事"、"创作谜语")
|
||||||
- 用户询问关于自己的任何信息
|
|
||||||
|
**重要:如果用户的问题可以直接回答,不要调用此工具。只在确实需要历史信息时才使用。**
|
||||||
|
|
||||||
需要对question改写/优化:
|
|
||||||
需要重点关注一以下几点
|
|
||||||
- 相关的关键词,保持原问题的核心语义不变, 根据上下文,使问题更具体、更清晰,将模糊的表达转换为明确的搜索词
|
|
||||||
- 使用同义词或相关术语扩展查询
|
|
||||||
Args:
|
Args:
|
||||||
question: question改写之后的内容
|
question: 需要检索的问题(保持原问题的核心语义,使用清晰的关键词)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
检索到的历史记忆内容
|
检索到的历史记忆内容
|
||||||
@@ -126,6 +123,10 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 检查是否有有效内容
|
||||||
|
if not memory_content or str(memory_content).strip() == "" or "answer" in str(memory_content) and str(memory_content).count("''") > 0:
|
||||||
|
return "未找到相关的历史记忆。请直接回答用户的问题,不要再次调用此工具。"
|
||||||
|
|
||||||
return f"检索到以下历史记忆:\n\n{memory_content}"
|
return f"检索到以下历史记忆:\n\n{memory_content}"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("长期记忆检索失败", extra={"error": str(e), "error_type": type(e).__name__})
|
logger.error("长期记忆检索失败", extra={"error": str(e), "error_type": type(e).__name__})
|
||||||
|
|||||||
@@ -85,10 +85,11 @@ class MultimodalService:
|
|||||||
file: 图片文件输入
|
file: 图片文件输入
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict: 通义千问格式
|
Dict: 通义千问格式 {"type": "image", "image": "url"}
|
||||||
"""
|
"""
|
||||||
if file.transfer_method == TransferMethod.REMOTE_URL:
|
if file.transfer_method == TransferMethod.REMOTE_URL:
|
||||||
# 远程 URL,使用通义千问格式
|
# 远程 URL,使用通义千问格式
|
||||||
|
logger.debug(f"处理远程图片: {file.url}")
|
||||||
return {
|
return {
|
||||||
"type": "image",
|
"type": "image",
|
||||||
"image": file.url
|
"image": file.url
|
||||||
@@ -96,6 +97,7 @@ class MultimodalService:
|
|||||||
else:
|
else:
|
||||||
# 本地文件,获取访问 URL
|
# 本地文件,获取访问 URL
|
||||||
url = await self._get_file_url(file.upload_file_id)
|
url = await self._get_file_url(file.upload_file_id)
|
||||||
|
logger.debug(f"处理本地图片: {url}")
|
||||||
return {
|
return {
|
||||||
"type": "image",
|
"type": "image",
|
||||||
"image": url
|
"image": url
|
||||||
|
|||||||
Reference in New Issue
Block a user