[modify] agent call tools strategy
This commit is contained in:
@@ -46,7 +46,8 @@ class LangChainAgent:
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[Sequence[BaseTool]] = None,
|
||||
streaming: bool = False,
|
||||
max_iterations: int = 3 # 新增:最大迭代次数
|
||||
max_iterations: Optional[int] = None, # 最大迭代次数(None 表示自动计算)
|
||||
max_tool_consecutive_calls: int = 3 # 单个工具最大连续调用次数
|
||||
):
|
||||
"""初始化 LangChain Agent
|
||||
|
||||
@@ -59,15 +60,36 @@ class LangChainAgent:
|
||||
max_tokens: 最大 token 数
|
||||
system_prompt: 系统提示词
|
||||
tools: 工具列表(可选,框架自动走 ReAct 循环)
|
||||
streaming: 是否启用流式输出(默认 True)
|
||||
max_iterations: 最大迭代次数(默认 3 次,防止工具调用死循环)
|
||||
streaming: 是否启用流式输出
|
||||
max_iterations: 最大迭代次数(None 表示自动计算:基础 5 次 + 每个工具 2 次)
|
||||
max_tool_consecutive_calls: 单个工具最大连续调用次数(默认 3 次)
|
||||
"""
|
||||
self.model_name = model_name
|
||||
self.provider = provider
|
||||
self.system_prompt = system_prompt or "你是一个专业的AI助手"
|
||||
self.tools = tools or []
|
||||
self.streaming = streaming
|
||||
self.max_iterations = max_iterations
|
||||
self.max_tool_consecutive_calls = max_tool_consecutive_calls
|
||||
|
||||
# 工具调用计数器:记录每个工具的连续调用次数
|
||||
self.tool_call_counter: Dict[str, int] = {}
|
||||
self.last_tool_called: Optional[str] = None
|
||||
|
||||
# 根据工具数量动态调整最大迭代次数
|
||||
# 基础值 + 每个工具额外的调用机会
|
||||
if max_iterations is None:
|
||||
# 自动计算:基础 5 次 + 每个工具 2 次额外机会
|
||||
self.max_iterations = 5 + len(self.tools) * 2
|
||||
else:
|
||||
self.max_iterations = max_iterations
|
||||
|
||||
self.system_prompt = system_prompt or "你是一个专业的AI助手"
|
||||
|
||||
logger.debug(
|
||||
f"Agent 迭代次数配置: max_iterations={self.max_iterations}, "
|
||||
f"tool_count={len(self.tools)}, "
|
||||
f"max_tool_consecutive_calls={self.max_tool_consecutive_calls}, "
|
||||
f"auto_calculated={max_iterations is None}"
|
||||
)
|
||||
|
||||
# 创建 RedBearLLM(支持多提供商)
|
||||
model_config = RedBearModelConfig(
|
||||
@@ -91,11 +113,14 @@ class LangChainAgent:
|
||||
if streaming and hasattr(self._underlying_llm, 'streaming'):
|
||||
self._underlying_llm.streaming = True
|
||||
|
||||
# 包装工具以跟踪连续调用次数
|
||||
wrapped_tools = self._wrap_tools_with_tracking(self.tools) if self.tools else None
|
||||
|
||||
# 使用 create_agent 创建 agent graph(LangChain 1.x 标准方式)
|
||||
# 无论是否有工具,都使用 agent 统一处理
|
||||
self.agent = create_agent(
|
||||
model=self.llm,
|
||||
tools=self.tools if self.tools else None,
|
||||
tools=wrapped_tools,
|
||||
system_prompt=self.system_prompt
|
||||
)
|
||||
|
||||
@@ -107,12 +132,84 @@ class LangChainAgent:
|
||||
"has_api_base": bool(api_base),
|
||||
"temperature": temperature,
|
||||
"streaming": streaming,
|
||||
"max_iterations": max_iterations,
|
||||
"max_iterations": self.max_iterations,
|
||||
"max_tool_consecutive_calls": self.max_tool_consecutive_calls,
|
||||
"tool_count": len(self.tools),
|
||||
"tool_names": [tool.name for tool in self.tools] if self.tools else []
|
||||
}
|
||||
)
|
||||
|
||||
def _wrap_tools_with_tracking(self, tools: Sequence[BaseTool]) -> List[BaseTool]:
|
||||
"""包装工具以跟踪连续调用次数
|
||||
|
||||
Args:
|
||||
tools: 原始工具列表
|
||||
|
||||
Returns:
|
||||
List[BaseTool]: 包装后的工具列表
|
||||
"""
|
||||
from langchain_core.tools import StructuredTool
|
||||
from functools import wraps
|
||||
|
||||
wrapped_tools = []
|
||||
|
||||
for original_tool in tools:
|
||||
tool_name = original_tool.name
|
||||
original_func = original_tool.func if hasattr(original_tool, 'func') else None
|
||||
|
||||
if not original_func:
|
||||
# 如果无法获取原始函数,直接使用原工具
|
||||
wrapped_tools.append(original_tool)
|
||||
continue
|
||||
|
||||
# 创建包装函数
|
||||
def make_wrapped_func(tool_name, original_func):
|
||||
"""创建包装函数的工厂函数,避免闭包问题"""
|
||||
@wraps(original_func)
|
||||
def wrapped_func(*args, **kwargs):
|
||||
"""包装后的工具函数,跟踪连续调用次数"""
|
||||
# 检查是否是连续调用同一个工具
|
||||
if self.last_tool_called == tool_name:
|
||||
self.tool_call_counter[tool_name] = self.tool_call_counter.get(tool_name, 0) + 1
|
||||
else:
|
||||
# 切换到新工具,重置计数器
|
||||
self.tool_call_counter[tool_name] = 1
|
||||
self.last_tool_called = tool_name
|
||||
|
||||
current_count = self.tool_call_counter[tool_name]
|
||||
|
||||
logger.debug(
|
||||
f"工具调用: {tool_name}, 连续调用次数: {current_count}/{self.max_tool_consecutive_calls}"
|
||||
)
|
||||
|
||||
# 检查是否超过最大连续调用次数
|
||||
if current_count > self.max_tool_consecutive_calls:
|
||||
logger.warning(
|
||||
f"工具 '{tool_name}' 连续调用次数已达上限 ({self.max_tool_consecutive_calls}),"
|
||||
f"返回提示信息"
|
||||
)
|
||||
return (
|
||||
f"工具 '{tool_name}' 已连续调用 {self.max_tool_consecutive_calls} 次,"
|
||||
f"未找到有效结果。请尝试其他方法或直接回答用户的问题。"
|
||||
)
|
||||
|
||||
# 调用原始工具函数
|
||||
return original_func(*args, **kwargs)
|
||||
|
||||
return wrapped_func
|
||||
|
||||
# 使用 StructuredTool 创建新工具
|
||||
wrapped_tool = StructuredTool(
|
||||
name=original_tool.name,
|
||||
description=original_tool.description,
|
||||
func=make_wrapped_func(tool_name, original_func),
|
||||
args_schema=original_tool.args_schema if hasattr(original_tool, 'args_schema') else None
|
||||
)
|
||||
|
||||
wrapped_tools.append(wrapped_tool)
|
||||
|
||||
return wrapped_tools
|
||||
|
||||
def _prepare_messages(
|
||||
self,
|
||||
message: str,
|
||||
@@ -149,11 +246,22 @@ class LangChainAgent:
|
||||
if context:
|
||||
user_content = f"参考信息:\n{context}\n\n用户问题:\n{user_content}"
|
||||
|
||||
# 如果有文件,构建多模态消息
|
||||
# 如果有文件,构建多模态消息(使用通义千问原生格式)
|
||||
if files and len(files) > 0:
|
||||
# 多模态消息格式: [{"type": "text", "text": "..."}, {"type": "image_url", ...}]
|
||||
content_parts = [{"type": "text", "text": user_content}]
|
||||
content_parts.extend(files)
|
||||
# 通义千问多模态格式: [{"text": "..."}, {"image": "url"}]
|
||||
# 注意:不使用 LangChain 的标准格式,因为它会转换为 OpenAI 格式
|
||||
content_parts = [{"text": user_content}]
|
||||
|
||||
# 添加文件内容(已经是通义千问格式)
|
||||
for file_item in files:
|
||||
if file_item.get("type") == "image":
|
||||
# 通义千问图片格式: {"image": "url"}
|
||||
content_parts.append({"image": file_item["image"]})
|
||||
elif file_item.get("type") == "text":
|
||||
# 文本内容
|
||||
content_parts.append({"text": file_item["text"]})
|
||||
|
||||
logger.debug(f"构建多模态消息,content_parts: {content_parts}")
|
||||
messages.append(HumanMessage(content=content_parts))
|
||||
else:
|
||||
# 纯文本消息(向后兼容)
|
||||
|
||||
@@ -64,26 +64,23 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
|
||||
@tool(args_schema=LongTermMemoryInput)
|
||||
def long_term_memory(question: str) -> str:
|
||||
"""
|
||||
从用户的历史记忆中检索相关信息。这是一个强大的工具,可以帮助你了解用户的背景、偏好和历史对话内容。
|
||||
从用户的历史记忆中检索相关信息。用于了解用户的背景、偏好和历史对话内容。
|
||||
|
||||
以下场景不需要使用此工具:
|
||||
1. 情绪/社交问候场景(如"你好"、"谢谢"、"再见"等简单寒暄)
|
||||
2. 纯任务性场景(如"帮我写代码"、"翻译这段文字"等不需要历史上下文的任务)
|
||||
3. 处理外部内容时(如用户提供的文本、代码、RAG数据等,这些内容本身已经包含所需信息)
|
||||
**何时使用此工具:**
|
||||
- 用户明确询问历史信息(如"我之前说过什么"、"上次我们聊了什么")
|
||||
- 用户询问个人信息或偏好(如"我喜欢什么"、"我的习惯是什么")
|
||||
- 需要基于历史上下文提供个性化建议
|
||||
|
||||
除上述场景外的所有其他情况都应该使用此工具,特别是:
|
||||
- 用户询问个人信息或历史对话内容
|
||||
- 需要了解用户偏好、习惯或背景
|
||||
- 用户提到"之前"、"上次"、"记得"等涉及历史的词汇
|
||||
- 需要个性化回复或基于历史上下文的建议
|
||||
- 用户询问关于自己的任何信息
|
||||
**何时不使用此工具:**
|
||||
- 简单问候(如"你好"、"谢谢"、"再见")
|
||||
- 纯任务性请求(如"写代码"、"翻译文字"、"分析图片")
|
||||
- 用户已提供完整信息(如提供了文本、图片、文档等内容)
|
||||
- 创作性任务(如"写诗"、"编故事"、"创作谜语")
|
||||
|
||||
**重要:如果用户的问题可以直接回答,不要调用此工具。只在确实需要历史信息时才使用。**
|
||||
|
||||
需要对question改写/优化:
|
||||
需要重点关注一以下几点
|
||||
- 相关的关键词,保持原问题的核心语义不变, 根据上下文,使问题更具体、更清晰,将模糊的表达转换为明确的搜索词
|
||||
- 使用同义词或相关术语扩展查询
|
||||
Args:
|
||||
question: question改写之后的内容
|
||||
question: 需要检索的问题(保持原问题的核心语义,使用清晰的关键词)
|
||||
|
||||
Returns:
|
||||
检索到的历史记忆内容
|
||||
@@ -126,6 +123,10 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
|
||||
}
|
||||
)
|
||||
|
||||
# 检查是否有有效内容
|
||||
if not memory_content or str(memory_content).strip() == "" or "answer" in str(memory_content) and str(memory_content).count("''") > 0:
|
||||
return "未找到相关的历史记忆。请直接回答用户的问题,不要再次调用此工具。"
|
||||
|
||||
return f"检索到以下历史记忆:\n\n{memory_content}"
|
||||
except Exception as e:
|
||||
logger.error("长期记忆检索失败", extra={"error": str(e), "error_type": type(e).__name__})
|
||||
|
||||
@@ -85,10 +85,11 @@ class MultimodalService:
|
||||
file: 图片文件输入
|
||||
|
||||
Returns:
|
||||
Dict: 通义千问格式
|
||||
Dict: 通义千问格式 {"type": "image", "image": "url"}
|
||||
"""
|
||||
if file.transfer_method == TransferMethod.REMOTE_URL:
|
||||
# 远程 URL,使用通义千问格式
|
||||
logger.debug(f"处理远程图片: {file.url}")
|
||||
return {
|
||||
"type": "image",
|
||||
"image": file.url
|
||||
@@ -96,6 +97,7 @@ class MultimodalService:
|
||||
else:
|
||||
# 本地文件,获取访问 URL
|
||||
url = await self._get_file_url(file.upload_file_id)
|
||||
logger.debug(f"处理本地图片: {url}")
|
||||
return {
|
||||
"type": "image",
|
||||
"image": url
|
||||
|
||||
Reference in New Issue
Block a user