[fix] handoff limit error
This commit is contained in:
@@ -61,12 +61,12 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
|
||||
def long_term_memory(question: str) -> str:
|
||||
"""
|
||||
从用户的历史记忆中检索相关信息。这是一个强大的工具,可以帮助你了解用户的背景、偏好和历史对话内容。
|
||||
|
||||
|
||||
以下场景不需要使用此工具:
|
||||
1. 情绪/社交问候场景(如"你好"、"谢谢"、"再见"等简单寒暄)
|
||||
2. 纯任务性场景(如"帮我写代码"、"翻译这段文字"等不需要历史上下文的任务)
|
||||
3. 处理外部内容时(如用户提供的文本、代码、RAG数据等,这些内容本身已经包含所需信息)
|
||||
|
||||
|
||||
除上述场景外的所有其他情况都应该使用此工具,特别是:
|
||||
- 用户询问个人信息或历史对话内容
|
||||
- 需要了解用户偏好、习惯或背景
|
||||
@@ -528,9 +528,9 @@ class DraftRunService:
|
||||
# 从配置中获取启用的工具
|
||||
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
|
||||
for tool_config in agent_config.tools:
|
||||
print("+"*50)
|
||||
print(f"agent_config:{agent_config}")
|
||||
print(f"tool_config:{tool_config}")
|
||||
# print("+"*50)
|
||||
# print(f"agent_config:{agent_config}")
|
||||
# print(f"tool_config:{tool_config}")
|
||||
if tool_config.get("enabled", False):
|
||||
# 根据工具名称查找工具实例
|
||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Handoffs 服务 - 基于 LangGraph 的多 Agent 协作"""
|
||||
import json
|
||||
import uuid
|
||||
from typing import List, Dict, Any, Optional, AsyncGenerator
|
||||
from typing import List, Dict, Any, Optional, AsyncGenerator, Annotated
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
|
||||
@@ -11,6 +11,7 @@ from langgraph.checkpoint.memory import MemorySaver
|
||||
from langchain_core.tools import tool
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
import operator
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
@@ -20,19 +21,42 @@ from app.services.model_service import ModelApiKeyService
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
# ==================== Reducer 函数 ====================
|
||||
|
||||
def replace_value(current, new):
|
||||
"""替换值的 reducer - 总是使用新值"""
|
||||
return new
|
||||
|
||||
|
||||
# ==================== 状态定义 ====================
|
||||
|
||||
class HandoffState(TypedDict):
|
||||
"""Handoff 状态"""
|
||||
messages: List[BaseMessage]
|
||||
active_agent: Optional[str]
|
||||
messages: Annotated[List[BaseMessage], operator.add] # 消息列表追加
|
||||
active_agent: Annotated[Optional[str], replace_value]
|
||||
handoff_count: Annotated[int, replace_value]
|
||||
handoff_history: Annotated[List[str], replace_value]
|
||||
pending_question: Annotated[Optional[str], replace_value]
|
||||
previous_answer: Annotated[Optional[str], replace_value]
|
||||
|
||||
|
||||
# ==================== 常量 ====================
|
||||
|
||||
MAX_HANDOFFS = 5 # 最大 handoff 次数
|
||||
|
||||
|
||||
# ==================== 工具输入模型 ====================
|
||||
|
||||
class TransferInput(BaseModel):
|
||||
"""转移工具的输入参数"""
|
||||
reason: str = Field(description="转移原因")
|
||||
reason: str = Field(description="转移原因,说明为什么需要转交")
|
||||
unhandled_question: str = Field(
|
||||
description="需要转交给其他专家处理的具体问题。注意:只转交你无法回答的部分,不要转交整个原始问题"
|
||||
)
|
||||
your_answer: str = Field(
|
||||
default="",
|
||||
description="你已经回答的内容摘要(如果有的话)。如果你已经回答了部分问题,在这里简要说明"
|
||||
)
|
||||
|
||||
|
||||
# ==================== 工具创建 ====================
|
||||
@@ -50,11 +74,22 @@ def create_transfer_tool(target_agent: str, description: str):
|
||||
tool_name = f"transfer_to_{target_agent}"
|
||||
|
||||
@tool(tool_name, args_schema=TransferInput)
|
||||
def transfer_tool(reason: str) -> Command:
|
||||
"""动态生成的转移工具"""
|
||||
def transfer_tool(reason: str, unhandled_question: str, your_answer: str = "") -> Command:
|
||||
"""动态生成的转移工具
|
||||
|
||||
Args:
|
||||
reason: 转移原因
|
||||
unhandled_question: 需要转交的具体问题(只转交未处理的部分)
|
||||
your_answer: 你已经回答的内容摘要
|
||||
"""
|
||||
return Command(
|
||||
goto=target_agent,
|
||||
update={"active_agent": target_agent},
|
||||
update={
|
||||
"active_agent": target_agent,
|
||||
"pending_question": unhandled_question, # 存储要转交的具体问题
|
||||
"previous_answer": your_answer, # 存储之前的回答
|
||||
# handoff_count 和 handoff_history 在 agent_node 中更新
|
||||
},
|
||||
)
|
||||
|
||||
transfer_tool.__doc__ = description
|
||||
@@ -99,8 +134,42 @@ def create_agent_node(agent_name: str, system_prompt: str, tools: List,
|
||||
"""Agent 节点执行函数"""
|
||||
logger.debug(f"Agent {agent_name} 执行, active_agent: {state.get('active_agent')}")
|
||||
|
||||
# 获取当前 handoff 状态
|
||||
handoff_count = state.get("handoff_count", 0)
|
||||
handoff_history = state.get("handoff_history", [])
|
||||
pending_question = state.get("pending_question")
|
||||
previous_answer = state.get("previous_answer", "")
|
||||
|
||||
# 检查是否达到最大 handoff 次数
|
||||
if handoff_count >= MAX_HANDOFFS:
|
||||
logger.warning(f"Agent {agent_name}: 达到最大 handoff 次数,直接回复")
|
||||
return {
|
||||
"messages": [AIMessage(content="抱歉,我无法继续处理这个请求。请尝试重新提问。")],
|
||||
"handoff_count": handoff_count,
|
||||
"handoff_history": handoff_history,
|
||||
"pending_question": None,
|
||||
"previous_answer": ""
|
||||
}
|
||||
|
||||
messages = state.get("messages", [])
|
||||
full_messages = [{"role": "system", "content": system_prompt}] + messages
|
||||
|
||||
# 如果有 pending_question,构建新的消息上下文
|
||||
if pending_question and handoff_count > 0:
|
||||
# 构建包含上下文的消息
|
||||
context_msg = f"【来自其他专家的转交】\n"
|
||||
if previous_answer:
|
||||
context_msg += f"之前的专家已经回答了: {previous_answer}\n\n"
|
||||
context_msg += f"现在需要你回答的问题是: {pending_question}\n\n"
|
||||
if handoff_history:
|
||||
context_msg += f"【注意】以下专家已经处理过这个问题,不能再转交给他们: {', '.join(handoff_history)}"
|
||||
|
||||
# 使用转交的具体问题,而不是原始消息
|
||||
effective_messages = [HumanMessage(content=context_msg)]
|
||||
logger.info(f"Agent {agent_name} 收到转交问题(非流式): {pending_question[:100]}...")
|
||||
else:
|
||||
effective_messages = messages
|
||||
|
||||
full_messages = [{"role": "system", "content": system_prompt}] + effective_messages
|
||||
|
||||
response = await llm.ainvoke(full_messages)
|
||||
|
||||
@@ -116,17 +185,62 @@ def create_agent_node(agent_name: str, system_prompt: str, tools: List,
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
tool_args = {}
|
||||
|
||||
# 确保必要的参数存在
|
||||
if not tool_args.get("reason"):
|
||||
tool_args["reason"] = "用户请求转移"
|
||||
|
||||
# 获取 LLM 提供的 unhandled_question
|
||||
llm_unhandled_question = tool_args.get("unhandled_question", "")
|
||||
|
||||
# 提取目标 agent
|
||||
target_agent = tool_name.replace("transfer_to_", "")
|
||||
|
||||
# 检查是否会形成循环:目标 Agent 是否已经在 handoff_history 中
|
||||
if target_agent in handoff_history:
|
||||
logger.warning(f"Agent {agent_name} 尝试移交给已处理过的 {target_agent},强制直接回复")
|
||||
return {
|
||||
"messages": [AIMessage(content="抱歉,这个问题超出了我的专业范围,我无法回答。")],
|
||||
"handoff_count": handoff_count,
|
||||
"handoff_history": handoff_history,
|
||||
"pending_question": None,
|
||||
"previous_answer": ""
|
||||
}
|
||||
|
||||
# 第一次转交,检查是否提供了 unhandled_question
|
||||
if not llm_unhandled_question:
|
||||
# 使用原始消息
|
||||
last_human_msg = next((m for m in reversed(messages) if isinstance(m, HumanMessage)), None)
|
||||
llm_unhandled_question = last_human_msg.content if last_human_msg else ""
|
||||
|
||||
tool_args["unhandled_question"] = llm_unhandled_question
|
||||
|
||||
for t in tools:
|
||||
if t.name == tool_name:
|
||||
logger.info(f"Agent {agent_name} 调用工具: {tool_name}")
|
||||
result = t.invoke(tool_args)
|
||||
if isinstance(result, Command):
|
||||
return result
|
||||
# 提取目标 agent
|
||||
target_agent = tool_name.replace("transfer_to_", "")
|
||||
new_history = handoff_history + [agent_name]
|
||||
|
||||
logger.info(f"Agent {agent_name} handoff 到 {target_agent} (count: {handoff_count + 1}), 转交问题: {tool_args.get('unhandled_question', '')[:50]}...")
|
||||
|
||||
# 返回 Command 并更新 handoff 状态
|
||||
return Command(
|
||||
goto=target_agent,
|
||||
update={
|
||||
"active_agent": target_agent,
|
||||
"handoff_count": handoff_count + 1,
|
||||
"handoff_history": new_history,
|
||||
"pending_question": tool_args.get("unhandled_question", ""),
|
||||
"previous_answer": tool_args.get("your_answer", "")
|
||||
}
|
||||
)
|
||||
|
||||
return {"messages": [response]}
|
||||
return {
|
||||
"messages": [response],
|
||||
"handoff_count": handoff_count,
|
||||
"handoff_history": handoff_history,
|
||||
"pending_question": None, # 清除 pending_question
|
||||
"previous_answer": ""
|
||||
}
|
||||
|
||||
return agent_node
|
||||
|
||||
@@ -144,8 +258,44 @@ def create_streaming_agent_node(agent_name: str, system_prompt: str, tools: List
|
||||
"""Agent 节点执行函数(流式)"""
|
||||
logger.debug(f"Agent {agent_name} 流式执行, active_agent: {state.get('active_agent')}")
|
||||
|
||||
# 获取当前 handoff 状态
|
||||
handoff_count = state.get("handoff_count", 0)
|
||||
handoff_history = state.get("handoff_history", [])
|
||||
pending_question = state.get("pending_question")
|
||||
previous_answer = state.get("previous_answer", "")
|
||||
|
||||
logger.info(f"Agent {agent_name} 状态: handoff_count={handoff_count}, pending_question={pending_question}, previous_answer={previous_answer[:50] if previous_answer else ''}")
|
||||
|
||||
# 检查是否达到最大 handoff 次数
|
||||
if handoff_count >= MAX_HANDOFFS:
|
||||
logger.warning(f"Agent {agent_name}: 达到最大 handoff 次数,直接回复")
|
||||
return {
|
||||
"messages": [AIMessage(content="抱歉,我无法继续处理这个请求。请尝试重新提问。")],
|
||||
"handoff_count": handoff_count,
|
||||
"handoff_history": handoff_history,
|
||||
"pending_question": None,
|
||||
"previous_answer": ""
|
||||
}
|
||||
|
||||
messages = state.get("messages", [])
|
||||
full_messages = [{"role": "system", "content": system_prompt}] + messages
|
||||
|
||||
# 如果有 pending_question,构建新的消息上下文
|
||||
if pending_question and handoff_count > 0:
|
||||
# 构建包含上下文的消息
|
||||
context_msg = f"【来自其他专家的转交】\n"
|
||||
if previous_answer:
|
||||
context_msg += f"之前的专家已经回答了: {previous_answer}\n\n"
|
||||
context_msg += f"现在需要你回答的问题是: {pending_question}\n\n"
|
||||
if handoff_history:
|
||||
context_msg += f"【注意】以下专家已经处理过这个问题,不能再转交给他们: {', '.join(handoff_history)}"
|
||||
|
||||
# 使用转交的具体问题,而不是原始消息
|
||||
effective_messages = [HumanMessage(content=context_msg)]
|
||||
logger.info(f"Agent {agent_name} 收到转交问题(流式): {pending_question[:100]}...")
|
||||
else:
|
||||
effective_messages = messages
|
||||
|
||||
full_messages = [{"role": "system", "content": system_prompt}] + effective_messages
|
||||
|
||||
full_content = ""
|
||||
collected_tool_calls = {}
|
||||
@@ -205,23 +355,80 @@ def create_streaming_agent_node(agent_name: str, system_prompt: str, tools: List
|
||||
elif not tc.get("args"):
|
||||
tc["args"] = {}
|
||||
|
||||
# 执行工具调用
|
||||
if tool_calls_list and tool_calls_list[0].get("name"):
|
||||
tool_call = tool_calls_list[0]
|
||||
tool_name = tool_call.get("name", "")
|
||||
tool_args = tool_call.get("args", {})
|
||||
# 执行工具调用 - 选择参数最完整的工具调用
|
||||
if tool_calls_list:
|
||||
# 找到参数最完整的 transfer 工具调用
|
||||
best_tool_call = None
|
||||
best_args_len = -1
|
||||
|
||||
if not tool_args.get("reason"):
|
||||
tool_args["reason"] = "用户请求转移"
|
||||
for tc in tool_calls_list:
|
||||
tc_name = tc.get("name", "")
|
||||
if tc_name.startswith("transfer_to_"):
|
||||
tc_args = tc.get("args", {})
|
||||
if isinstance(tc_args, str):
|
||||
try:
|
||||
tc_args = json.loads(tc_args)
|
||||
except:
|
||||
tc_args = {}
|
||||
# 计算参数完整度
|
||||
args_len = len(str(tc_args.get("unhandled_question", ""))) + len(str(tc_args.get("reason", "")))
|
||||
if args_len > best_args_len:
|
||||
best_args_len = args_len
|
||||
best_tool_call = (tc_name, tc_args)
|
||||
|
||||
if best_tool_call:
|
||||
tool_name, tool_args = best_tool_call
|
||||
|
||||
# 确保必要的参数存在
|
||||
if not tool_args.get("reason"):
|
||||
tool_args["reason"] = "用户请求转移"
|
||||
|
||||
# 获取 LLM 提供的 unhandled_question
|
||||
llm_unhandled_question = tool_args.get("unhandled_question", "")
|
||||
|
||||
# 提取目标 agent
|
||||
target_agent = tool_name.replace("transfer_to_", "")
|
||||
|
||||
# 检查是否会形成循环:目标 Agent 是否已经在 handoff_history 中
|
||||
if target_agent in handoff_history:
|
||||
logger.warning(f"Agent {agent_name} 尝试移交给已处理过的 {target_agent},强制直接回复")
|
||||
return {
|
||||
"messages": [AIMessage(content=full_content if full_content else "抱歉,这个问题超出了我的专业范围,我无法回答。")],
|
||||
"handoff_count": handoff_count,
|
||||
"handoff_history": handoff_history,
|
||||
"pending_question": None,
|
||||
"previous_answer": ""
|
||||
}
|
||||
|
||||
# 检查是否提供了 unhandled_question
|
||||
if not llm_unhandled_question:
|
||||
# 使用原始消息
|
||||
last_human_msg = next((m for m in reversed(messages) if isinstance(m, HumanMessage)), None)
|
||||
llm_unhandled_question = last_human_msg.content if last_human_msg else ""
|
||||
|
||||
new_history = handoff_history + [agent_name]
|
||||
|
||||
logger.info(f"Agent {agent_name} handoff 到 {target_agent} (count: {handoff_count + 1}), 转交问题: {llm_unhandled_question[:100]}...")
|
||||
|
||||
# 返回 Command 并更新 handoff 状态
|
||||
return Command(
|
||||
goto=target_agent,
|
||||
update={
|
||||
"active_agent": target_agent,
|
||||
"handoff_count": handoff_count + 1,
|
||||
"handoff_history": new_history,
|
||||
"pending_question": llm_unhandled_question,
|
||||
"previous_answer": tool_args.get("your_answer", "")
|
||||
}
|
||||
)
|
||||
|
||||
for t in tools:
|
||||
if t.name == tool_name:
|
||||
logger.info(f"Agent {agent_name} 调用工具: {tool_name}")
|
||||
result = t.invoke(tool_args)
|
||||
if isinstance(result, Command):
|
||||
return result
|
||||
|
||||
return {"messages": [AIMessage(content=full_content)]}
|
||||
return {
|
||||
"messages": [AIMessage(content=full_content)],
|
||||
"handoff_count": handoff_count,
|
||||
"handoff_history": handoff_history,
|
||||
"pending_question": None, # 清除 pending_question
|
||||
"previous_answer": ""
|
||||
}
|
||||
|
||||
return agent_node
|
||||
|
||||
@@ -239,12 +446,36 @@ def create_route_initial(default_agent: str):
|
||||
|
||||
|
||||
def route_after_agent(state: HandoffState) -> str:
|
||||
"""Agent 执行后的路由"""
|
||||
"""Agent 执行后的路由
|
||||
|
||||
检查:
|
||||
1. 是否达到最大 handoff 次数
|
||||
2. 是否形成循环(连续在两个 Agent 之间切换)
|
||||
3. 最后一条消息是否有 tool_calls
|
||||
"""
|
||||
messages = state.get("messages", [])
|
||||
handoff_count = state.get("handoff_count", 0)
|
||||
handoff_history = state.get("handoff_history", [])
|
||||
|
||||
# 检查是否达到最大 handoff 次数
|
||||
if handoff_count >= MAX_HANDOFFS:
|
||||
logger.warning(f"达到最大 handoff 次数 ({MAX_HANDOFFS}),强制结束")
|
||||
return END
|
||||
|
||||
# 检查是否形成循环(A -> B -> A -> B 模式)
|
||||
if len(handoff_history) >= 4:
|
||||
# 检查最近 4 次是否形成 A-B-A-B 循环
|
||||
recent = handoff_history[-4:]
|
||||
if recent[0] == recent[2] and recent[1] == recent[3] and recent[0] != recent[1]:
|
||||
logger.warning(f"检测到循环 handoff: {recent},强制结束")
|
||||
return END
|
||||
|
||||
# 检查最后一条消息
|
||||
if messages:
|
||||
last_msg = messages[-1]
|
||||
if isinstance(last_msg, AIMessage) and not getattr(last_msg, 'tool_calls', None):
|
||||
return END
|
||||
|
||||
return state.get("active_agent", END)
|
||||
|
||||
|
||||
@@ -352,11 +583,43 @@ def convert_multi_agent_config_to_handoffs(
|
||||
]
|
||||
# 更新系统提示词,添加转移说明
|
||||
other_agents = agent_configs[safe_name]["can_transfer_to"]
|
||||
current_capabilities = agent_configs[safe_name].get("capabilities", [])
|
||||
|
||||
if other_agents:
|
||||
transfer_instructions = "\n如果用户的问题不在你的专长范围内,可以使用以下工具转移到其他 Agent:"
|
||||
# 构建其他 Agent 的专长信息
|
||||
other_agents_info = []
|
||||
for other_name in other_agents:
|
||||
other_config = agent_configs[other_name]
|
||||
other_caps = other_config.get("capabilities", [])
|
||||
if other_caps:
|
||||
other_agents_info.append(f"- {other_config['name']}: 专长 {', '.join(other_caps)}")
|
||||
else:
|
||||
other_agents_info.append(f"- {other_config['name']}")
|
||||
|
||||
transfer_instructions = f"""
|
||||
|
||||
【重要工作原则】
|
||||
1. 你必须先输出你对专长范围({', '.join(current_capabilities) if current_capabilities else '你的领域'})问题的完整回答
|
||||
2. 回答完成后,如果还有其他部分需要其他专家处理,再调用转移工具
|
||||
3. 不能转移给已经处理过这个问题的专家
|
||||
|
||||
【回答流程】
|
||||
1. 先直接输出你的回答内容(不要放在工具参数里)
|
||||
2. 输出完成后,调用转移工具转交剩余问题
|
||||
|
||||
【其他可用的专家】
|
||||
{chr(10).join(other_agents_info)}
|
||||
|
||||
【转移工具参数】
|
||||
- reason: 转移原因
|
||||
- unhandled_question: 需要其他专家回答的具体问题
|
||||
- your_answer: 简要说明你回答了什么(摘要即可)
|
||||
|
||||
【转移工具】"""
|
||||
for other_name in other_agents:
|
||||
other_config = agent_configs[other_name]
|
||||
transfer_instructions += f"\n- transfer_to_{other_name}: {other_config['description']}"
|
||||
|
||||
agent_configs[safe_name]["system_prompt"] += transfer_instructions
|
||||
|
||||
return agent_configs
|
||||
@@ -455,7 +718,11 @@ class HandoffsService:
|
||||
logger.info(f"Handoffs chat: conversation_id={conversation_id}, message={message[:50]}...")
|
||||
|
||||
result = await self.graph.ainvoke({
|
||||
"messages": [HumanMessage(content=message)]
|
||||
"messages": [HumanMessage(content=message)],
|
||||
"handoff_count": 0,
|
||||
"handoff_history": [],
|
||||
"pending_question": None,
|
||||
"previous_answer": ""
|
||||
}, config=config)
|
||||
|
||||
# 提取响应
|
||||
@@ -469,7 +736,8 @@ class HandoffsService:
|
||||
"conversation_id": str(conversation_id),
|
||||
"active_agent": result.get("active_agent"),
|
||||
"response": response_content,
|
||||
"message_count": len(result.get("messages", []))
|
||||
"message_count": len(result.get("messages", [])),
|
||||
"handoff_count": result.get("handoff_count", 0)
|
||||
}
|
||||
|
||||
async def chat_stream(
|
||||
@@ -487,10 +755,18 @@ class HandoffsService:
|
||||
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"
|
||||
|
||||
current_agent = None
|
||||
handoff_count = 0
|
||||
collected_tool_calls = {} # 收集工具调用信息
|
||||
|
||||
try:
|
||||
async for event in self.graph.astream_events(
|
||||
{"messages": [HumanMessage(content=message)]},
|
||||
{
|
||||
"messages": [HumanMessage(content=message)],
|
||||
"handoff_count": 0,
|
||||
"handoff_history": [],
|
||||
"pending_question": None,
|
||||
"previous_answer": ""
|
||||
},
|
||||
config=config,
|
||||
version="v2"
|
||||
):
|
||||
@@ -507,20 +783,81 @@ class HandoffsService:
|
||||
|
||||
# 捕获 LLM 流式输出
|
||||
elif kind == "on_chat_model_stream":
|
||||
content = event["data"]["chunk"].content
|
||||
chunk = event["data"]["chunk"]
|
||||
content = chunk.content if hasattr(chunk, 'content') else ""
|
||||
if content:
|
||||
yield f"event: message\ndata: {json.dumps({'content': content}, ensure_ascii=False)}\n\n"
|
||||
yield f"event: message\ndata: {json.dumps({'content': content, 'agent': current_agent}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 收集工具调用信息
|
||||
if hasattr(chunk, 'tool_calls') and chunk.tool_calls:
|
||||
for tc in chunk.tool_calls:
|
||||
tc_id = tc.get("id") if isinstance(tc, dict) else getattr(tc, 'id', "0")
|
||||
tc_id = tc_id or "0"
|
||||
if tc_id not in collected_tool_calls:
|
||||
collected_tool_calls[tc_id] = {"id": tc_id, "name": "", "args": ""}
|
||||
|
||||
tc_name = tc.get("name") if isinstance(tc, dict) else getattr(tc, 'name', None)
|
||||
tc_args = tc.get("args") if isinstance(tc, dict) else getattr(tc, 'args', None)
|
||||
|
||||
if tc_name:
|
||||
collected_tool_calls[tc_id]["name"] = tc_name
|
||||
if tc_args:
|
||||
if isinstance(tc_args, dict):
|
||||
collected_tool_calls[tc_id]["args"] = tc_args
|
||||
elif isinstance(tc_args, str):
|
||||
if isinstance(collected_tool_calls[tc_id]["args"], str):
|
||||
collected_tool_calls[tc_id]["args"] += tc_args
|
||||
|
||||
# 处理 tool_call_chunks
|
||||
if hasattr(chunk, 'tool_call_chunks') and chunk.tool_call_chunks:
|
||||
for tc_chunk in chunk.tool_call_chunks:
|
||||
idx = str(tc_chunk.get("index", 0) if isinstance(tc_chunk, dict) else getattr(tc_chunk, 'index', 0))
|
||||
if idx not in collected_tool_calls:
|
||||
tc_id = tc_chunk.get("id", idx) if isinstance(tc_chunk, dict) else getattr(tc_chunk, 'id', idx)
|
||||
collected_tool_calls[idx] = {"id": tc_id, "name": "", "args": ""}
|
||||
|
||||
tc_id = tc_chunk.get("id") if isinstance(tc_chunk, dict) else getattr(tc_chunk, 'id', None)
|
||||
tc_name = tc_chunk.get("name") if isinstance(tc_chunk, dict) else getattr(tc_chunk, 'name', None)
|
||||
tc_args = tc_chunk.get("args") if isinstance(tc_chunk, dict) else getattr(tc_chunk, 'args', None)
|
||||
|
||||
if tc_id:
|
||||
collected_tool_calls[idx]["id"] = tc_id
|
||||
if tc_name:
|
||||
collected_tool_calls[idx]["name"] = tc_name
|
||||
if tc_args:
|
||||
if isinstance(collected_tool_calls[idx]["args"], str):
|
||||
collected_tool_calls[idx]["args"] += tc_args
|
||||
|
||||
# 捕获工具调用(Handoff)
|
||||
elif kind == "on_tool_start":
|
||||
tool_name = event.get("name", "")
|
||||
if tool_name.startswith("transfer_to_"):
|
||||
target_agent = tool_name.replace("transfer_to_", "")
|
||||
target_name = self.agent_configs.get(target_agent, {}).get("name", target_agent)
|
||||
yield f"event: handoff\ndata: {json.dumps({'from': current_agent, 'to': target_agent, 'to_name': target_name}, ensure_ascii=False)}\n\n"
|
||||
# 捕获 LLM 结束事件,输出收集到的工具调用
|
||||
elif kind == "on_chat_model_end":
|
||||
if collected_tool_calls:
|
||||
# 找到参数最完整的 transfer 工具调用
|
||||
best_tc = None
|
||||
best_args_len = -1
|
||||
for tc_id, tc_info in collected_tool_calls.items():
|
||||
if tc_info.get("name", "").startswith("transfer_to_"):
|
||||
args = tc_info.get("args", {})
|
||||
if isinstance(args, str):
|
||||
try:
|
||||
args = json.loads(args)
|
||||
except:
|
||||
args = {}
|
||||
# 计算参数完整度(有 unhandled_question 的优先)
|
||||
args_len = len(str(args.get("unhandled_question", ""))) + len(str(args.get("reason", "")))
|
||||
if args_len > best_args_len:
|
||||
best_args_len = args_len
|
||||
best_tc = (tc_info, args)
|
||||
|
||||
if best_tc:
|
||||
tc_info, args = best_tc
|
||||
handoff_count += 1
|
||||
target_agent = tc_info["name"].replace("transfer_to_", "")
|
||||
target_name = self.agent_configs.get(target_agent, {}).get("name", target_agent)
|
||||
yield f"event: handoff\ndata: {json.dumps({'from': current_agent, 'to': target_agent, 'to_name': target_name, 'handoff_count': handoff_count, 'reason': args.get('reason', ''), 'unhandled_question': args.get('unhandled_question', ''), 'your_answer': args.get('your_answer', '')}, ensure_ascii=False)}\n\n"
|
||||
collected_tool_calls = {} # 清空,准备收集下一个 Agent 的工具调用
|
||||
|
||||
# 发送结束事件
|
||||
yield f"event: end\ndata: {json.dumps({'conversation_id': str(conversation_id), 'final_agent': current_agent}, ensure_ascii=False)}\n\n"
|
||||
yield f"event: end\ndata: {json.dumps({'conversation_id': str(conversation_id), 'final_agent': current_agent, 'total_handoffs': handoff_count}, ensure_ascii=False)}\n\n"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Handoffs stream error: {str(e)}")
|
||||
|
||||
@@ -57,7 +57,7 @@ class MultiAgentOrchestrator:
|
||||
# 只有 supervisor 模式才需要 default_model_config_id 和 router
|
||||
self.master_model_config = None
|
||||
self.router = None
|
||||
|
||||
|
||||
if self._normalized_mode == OrchestrationMode.SUPERVISOR:
|
||||
# 获取 Master Agent 的模型配置
|
||||
if not self.default_model_config_id:
|
||||
@@ -89,10 +89,10 @@ class MultiAgentOrchestrator:
|
||||
|
||||
def _normalize_orchestration_mode(self, mode: str) -> str:
|
||||
"""标准化 orchestration_mode,兼容旧值
|
||||
|
||||
|
||||
Args:
|
||||
mode: 原始的 orchestration_mode 值
|
||||
|
||||
|
||||
Returns:
|
||||
标准化后的模式:collaboration 或 supervisor
|
||||
"""
|
||||
@@ -162,8 +162,8 @@ class MultiAgentOrchestrator:
|
||||
# 1. 主 Agent 分析任务
|
||||
task_analysis = await self._analyze_task(message, variables)
|
||||
task_analysis["use_llm_routing"] = use_llm_routing
|
||||
|
||||
async for event in self._execute_conditional_stream(
|
||||
|
||||
async for event in self._execute_supervisor_stream(
|
||||
task_analysis,
|
||||
conversation_id,
|
||||
user_id,
|
||||
@@ -247,7 +247,7 @@ class MultiAgentOrchestrator:
|
||||
user_id,
|
||||
variables
|
||||
)
|
||||
|
||||
|
||||
# Supervisor 模式:由 Master Agent 统一调度
|
||||
# 1. Master Agent 分析任务并做出决策
|
||||
task_analysis = await self._analyze_task(message, variables)
|
||||
@@ -920,7 +920,7 @@ class MultiAgentOrchestrator:
|
||||
"content": final_response
|
||||
})
|
||||
|
||||
async def _execute_conditional_stream(
|
||||
async def _execute_supervisor_stream(
|
||||
self,
|
||||
task_analysis: Dict[str, Any],
|
||||
conversation_id: Optional[uuid.UUID],
|
||||
@@ -945,6 +945,9 @@ class MultiAgentOrchestrator:
|
||||
|
||||
message = task_analysis.get("message", "")
|
||||
routing_decision = task_analysis.get("routing_decision")
|
||||
yield self._format_sse_event("routing_decision", {
|
||||
"routing_decision": routing_decision
|
||||
})
|
||||
|
||||
# 1. 检查是否需要协作
|
||||
if routing_decision and routing_decision.get("need_collaboration"):
|
||||
@@ -1426,9 +1429,9 @@ class MultiAgentOrchestrator:
|
||||
user_rag_memory_id: str = ''
|
||||
):
|
||||
"""Collaboration 模式流式执行 - Agent 之间可以相互 handoff
|
||||
|
||||
|
||||
使用 handoffs_service 实现 Agent 之间的动态切换
|
||||
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
conversation_id: 会话 ID
|
||||
@@ -1437,7 +1440,7 @@ class MultiAgentOrchestrator:
|
||||
memory: 是否启用记忆
|
||||
storage_type: 存储类型
|
||||
user_rag_memory_id: RAG 记忆 ID
|
||||
|
||||
|
||||
Yields:
|
||||
SSE 格式的事件流
|
||||
"""
|
||||
@@ -1445,39 +1448,39 @@ class MultiAgentOrchestrator:
|
||||
convert_multi_agent_config_to_handoffs,
|
||||
HandoffsService
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
# 1. 构建 multi_agent_config 字典
|
||||
multi_agent_config = {
|
||||
"sub_agents": self.config.sub_agents,
|
||||
"orchestration_mode": self.config.orchestration_mode
|
||||
}
|
||||
|
||||
|
||||
# 2. 转换配置(每个 Agent 包含自己的 model_config)
|
||||
agent_configs = convert_multi_agent_config_to_handoffs(
|
||||
multi_agent_config,
|
||||
multi_agent_config,
|
||||
self.db
|
||||
)
|
||||
|
||||
|
||||
if not agent_configs:
|
||||
raise BusinessException("没有可用的子 Agent", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
|
||||
# 3. 创建 HandoffsService
|
||||
handoffs_service = HandoffsService(
|
||||
agent_configs=agent_configs,
|
||||
streaming=True
|
||||
)
|
||||
|
||||
|
||||
# 4. 使用 handoffs_service 的流式聊天
|
||||
conv_id = str(conversation_id) if conversation_id else None
|
||||
|
||||
|
||||
async for event in handoffs_service.chat_stream(
|
||||
message=message,
|
||||
conversation_id=conv_id
|
||||
):
|
||||
# handoffs_service 返回的已经是 SSE 格式,直接 yield
|
||||
yield event
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Collaboration 模式执行失败: {str(e)}", exc_info=True)
|
||||
yield self._format_sse_event("error", {
|
||||
@@ -1493,15 +1496,15 @@ class MultiAgentOrchestrator:
|
||||
variables: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Collaboration 模式非流式执行 - Agent 之间可以相互 handoff
|
||||
|
||||
|
||||
使用 handoffs_service 实现 Agent 之间的动态切换
|
||||
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
conversation_id: 会话 ID
|
||||
user_id: 用户 ID
|
||||
variables: 变量参数
|
||||
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
@@ -1509,41 +1512,41 @@ class MultiAgentOrchestrator:
|
||||
convert_multi_agent_config_to_handoffs,
|
||||
HandoffsService
|
||||
)
|
||||
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
try:
|
||||
# 1. 构建 multi_agent_config 字典
|
||||
multi_agent_config = {
|
||||
"sub_agents": self.config.sub_agents,
|
||||
"orchestration_mode": self.config.orchestration_mode
|
||||
}
|
||||
|
||||
|
||||
# 2. 转换配置(每个 Agent 包含自己的 model_config)
|
||||
agent_configs = convert_multi_agent_config_to_handoffs(
|
||||
multi_agent_config,
|
||||
multi_agent_config,
|
||||
self.db
|
||||
)
|
||||
|
||||
|
||||
if not agent_configs:
|
||||
raise BusinessException("没有可用的子 Agent", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
|
||||
# 3. 创建 HandoffsService
|
||||
handoffs_service = HandoffsService(
|
||||
agent_configs=agent_configs,
|
||||
streaming=False
|
||||
)
|
||||
|
||||
|
||||
# 4. 使用 handoffs_service 的非流式聊天
|
||||
conv_id = str(conversation_id) if conversation_id else None
|
||||
|
||||
|
||||
result = await handoffs_service.chat(
|
||||
message=message,
|
||||
conversation_id=conv_id
|
||||
)
|
||||
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
|
||||
return {
|
||||
"message": result.get("response", ""),
|
||||
"conversation_id": result.get("conversation_id"),
|
||||
@@ -1552,7 +1555,7 @@ class MultiAgentOrchestrator:
|
||||
"active_agent": result.get("active_agent"),
|
||||
"sub_results": result
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Collaboration 模式执行失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
@@ -2352,7 +2355,7 @@ class MultiAgentOrchestrator:
|
||||
merged_parts.append(f"**{sub_question}**\n{message}")
|
||||
else:
|
||||
merged_parts.append(message)
|
||||
|
||||
|
||||
if merged_parts:
|
||||
return "\n\n".join(merged_parts)
|
||||
return ""
|
||||
|
||||
Reference in New Issue
Block a user