Files
MemoryBear/api/app/services/handoffs_service.py
2026-01-07 14:58:23 +08:00

515 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Handoffs 服务 - 基于 LangGraph 的多 Agent 协作"""
import json
import os
import uuid
from typing import List, Dict, Any, Optional, AsyncGenerator

from typing_extensions import TypedDict
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, START, END
from langgraph.types import Command
from pydantic import BaseModel, Field

from app.core.logging_config import get_business_logger
# Module-wide logger obtained from the project's logging configuration; used by
# the agent nodes and HandoffsService below.
logger = get_business_logger()
# ==================== 状态定义 ====================
class HandoffState(TypedDict):
    """Graph state shared by every agent node of the handoff graph."""
    # Conversation messages; each agent node prepends its own system prompt
    # before sending them to the LLM.
    # NOTE(review): no reducer (e.g. ``Annotated[List[BaseMessage], add_messages]``)
    # is declared, so LangGraph replaces this list on every update instead of
    # appending to it. Each new user turn therefore overwrites earlier history —
    # confirm this is intended before relying on multi-turn context.
    messages: List[BaseMessage]
    # Name of the agent that currently owns the conversation. Transfer tools set
    # it via ``Command(update={"active_agent": ...})``; the routers read it.
    active_agent: Optional[str]
# ==================== 工具输入模型 ====================
class TransferInput(BaseModel):
    """转移工具的输入参数"""
    # English: "Input parameters of the transfer tool". This is the argument
    # schema of every ``transfer_to_<agent>`` tool built by create_transfer_tool.
    # The docstring above is intentionally left verbatim, because pydantic copies
    # a model's docstring into its JSON schema, which may be shown to the LLM.

    # Free-text justification for the handoff ("转移原因" = "reason for the
    # transfer"). It is required by the schema, but the transfer tool itself
    # ignores the value.
    reason: str = Field(description="转移原因")
# ==================== 默认配置 ====================
# Built-in two-agent setup (sales <-> support). HandoffsService uses it when it
# is given no agent_configs, or an empty dict.
# Each key is an agent name, which is also its graph node name. Each value holds:
#   description     - description of the ``transfer_to_<name>`` tool that other
#                     agents receive for handing off *to* this agent
#   system_prompt   - system message prepended to this agent's LLM calls
#   can_transfer_to - agents this one may hand off to; one transfer tool is
#                     created per entry
# The prompts name the transfer tools literally, so they must stay in sync with
# the ``transfer_to_<name>`` naming used when the tools are created.
DEFAULT_AGENT_CONFIGS = {
    "sales_agent": {
        "description": "转移到销售 Agent。当用户询问价格、购买或销售相关问题时使用。",
        "system_prompt": """你是一个销售 Agent。帮助用户解答销售相关问题。
如果用户询问技术问题或需要技术支持,使用 transfer_to_support_agent 工具转移到支持 Agent。""",
        "can_transfer_to": ["support_agent"]
    },
    "support_agent": {
        "description": "转移到支持 Agent。当用户询问技术问题或需要帮助时使用。",
        "system_prompt": """你是一个技术支持 Agent。帮助用户解决技术问题。
如果用户询问价格或购买相关问题,使用 transfer_to_sales_agent 工具转移到销售 Agent。""",
        "can_transfer_to": ["sales_agent"]
    }
}
# Default LLM connection settings (DashScope's OpenAI-compatible endpoint). Used
# by HandoffsService when no llm_config is supplied.
#
# The API key is read from the DASHSCOPE_API_KEY environment variable instead of
# being hard-coded. A secret committed to source control is readable by anyone
# with repository access. The key that used to be embedded here must be treated
# as leaked and rotated.
# If the variable is unset, "api_key" is None. Presumably the OpenAI client then
# falls back to its own configuration (e.g. OPENAI_API_KEY) or raises when the
# graph is first built — verify in the deployment environment.
DEFAULT_LLM_CONFIG = {
    "api_key": os.getenv("DASHSCOPE_API_KEY"),
    "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
    "model": "qwen-plus",
    "temperature": 0.7
}
# ==================== 工具创建 ====================
def create_transfer_tool(target_agent: str, description: str):
    """Build the LangChain tool an agent calls to hand off to ``target_agent``.

    The tool is named ``transfer_to_<target_agent>`` and takes the arguments
    described by ``TransferInput``. Invoking it returns a LangGraph ``Command``.
    That command routes to the ``target_agent`` node and records it as
    ``active_agent`` in the graph state.

    Args:
        target_agent: Name of the graph node to hand the conversation to.
        description: Tool description presented to the LLM.

    Returns:
        The tool object produced by the ``tool`` decorator.
    """
    def _handoff(reason: str) -> Command:
        """动态生成的转移工具"""
        # The docstring above is required by ``tool`` (it seeds the description)
        # and is replaced by ``description`` right after construction.
        # ``reason`` is part of the schema only; the handoff does not use it.
        return Command(update={"active_agent": target_agent}, goto=target_agent)

    make_tool = tool(f"transfer_to_{target_agent}", args_schema=TransferInput)
    handoff_tool = make_tool(_handoff)
    # Swap the placeholder description for the caller-supplied one.
    handoff_tool.__doc__ = description
    handoff_tool.description = description
    return handoff_tool
def create_tools_for_agent(agent_name: str, configs: Dict) -> List:
    """Build the outgoing handoff tools for ``agent_name``.

    One transfer tool is created per entry of the agent's ``can_transfer_to``
    list, in list order. Each tool's description is the *target* agent's
    ``description``. If the target has no config, or no description, the
    generic text "转移到 <target>" is used instead.

    Args:
        agent_name: Agent whose handoff tools are being built.
        configs: Mapping of agent name to agent config dict.

    Returns:
        List of tools; empty when the agent is unknown or has no targets.
    """
    transfer_targets = configs.get(agent_name, {}).get("can_transfer_to", [])
    return [
        create_transfer_tool(
            target,
            configs.get(target, {}).get("description", f"转移到 {target}"),
        )
        for target in transfer_targets
    ]
# ==================== Agent 节点创建 ====================
def create_agent_node(agent_name: str, system_prompt: str, tools: List,
                      api_key: str, base_url: str, model: str, temperature: float = 0.7):
    """Create a non-streaming agent node for the handoff graph.

    The returned coroutine prepends ``system_prompt`` to the conversation and
    makes one call to the tool-bound chat model. Only the first tool call of the
    reply is considered. If it names one of ``tools`` and that tool returns a
    ``Command``, the command (the handoff) is returned. Otherwise the node
    returns ``{"messages": [reply]}``.

    Args:
        agent_name: Agent (graph node) name, used in log messages.
        system_prompt: System prompt prepended to every LLM call.
        tools: Tools bound to the chat model and dispatched by name.
        api_key: API key for the OpenAI-compatible endpoint.
        base_url: Base URL of the OpenAI-compatible endpoint.
        model: Model name.
        temperature: Sampling temperature.

    Returns:
        Async node function taking the graph state.
    """
    chat_model = ChatOpenAI(
        model=model,
        temperature=temperature,
        api_key=api_key,
        base_url=base_url,
    ).bind_tools(tools)

    async def agent_node(state: HandoffState) -> Dict[str, Any]:
        """Run one (non-streamed) LLM turn for this agent."""
        logger.debug(f"Agent {agent_name} 执行, active_agent: {state.get('active_agent')}")
        history = state.get("messages", [])
        prompt = [{"role": "system", "content": system_prompt}] + history
        reply = await chat_model.ainvoke(prompt)

        pending_calls = getattr(reply, "tool_calls", None)
        if not pending_calls:
            return {"messages": [reply]}

        first_call = pending_calls[0]
        call_name = first_call["name"]
        call_args = first_call["args"]
        # The transfer schema requires a reason; supply one if the model left it empty.
        if not call_args.get("reason"):
            call_args["reason"] = "用户请求转移"
        for candidate in (t for t in tools if t.name == call_name):
            logger.info(f"Agent {agent_name} 调用工具: {call_name}")
            outcome = candidate.invoke(call_args)
            if isinstance(outcome, Command):
                return outcome
        # No matching tool produced a Command: hand back the raw model reply.
        return {"messages": [reply]}

    return agent_node
def create_streaming_agent_node(agent_name: str, system_prompt: str, tools: List,
                                api_key: str, base_url: str, model: str, temperature: float = 0.7):
    """Create an agent node that streams the LLM reply.

    The node prepends ``system_prompt`` to the conversation and streams the
    tool-bound chat model (``streaming=True``). Text chunks are concatenated,
    and tool-call fragments are reassembled. If the first reassembled tool call
    names one of ``tools`` and that tool returns a ``Command``, the command (the
    handoff) is returned. Otherwise the node returns the streamed text as a
    single ``AIMessage`` with no tool calls.

    Args:
        agent_name: Agent (graph node) name, used in log messages.
        system_prompt: System prompt prepended to every LLM call.
        tools: Tools bound to the chat model and dispatched by name.
        api_key: API key for the OpenAI-compatible endpoint.
        base_url: Base URL of the OpenAI-compatible endpoint.
        model: Model name.
        temperature: Sampling temperature.

    Returns:
        Async node function taking the graph state.
    """
    llm = ChatOpenAI(
        model=model,
        temperature=temperature,
        api_key=api_key,
        base_url=base_url,
        streaming=True
    ).bind_tools(tools)

    async def agent_node(state: HandoffState):
        """Run one streamed LLM turn for this agent."""
        logger.debug(f"Agent {agent_name} 流式执行, active_agent: {state.get('active_agent')}")
        messages = state.get("messages", [])
        full_messages = [{"role": "system", "content": system_prompt}] + messages

        full_content = ""
        # Partial tool calls, keyed by each call's position in the stream.
        collected_tool_calls: Dict[str, Dict[str, Any]] = {}
        async for chunk in llm.astream(full_messages):
            if chunk.content:
                full_content += chunk.content

            call_chunks = getattr(chunk, "tool_call_chunks", None)
            if call_chunks:
                # Raw fragments are the only source used here. The name and id
                # arrive once, while ``args`` arrives as JSON text split across
                # chunks. Merging the per-chunk *parsed* ``tool_calls`` as well (their
                # id is None after the first fragment) used to collide with these
                # entries. It also replaced the accumulated args text with a dict,
                # which silently dropped later fragments.
                for tc_chunk in call_chunks:
                    index = tc_chunk.get("index")
                    key = str(index) if index is not None else (tc_chunk.get("id") or "0")
                    entry = collected_tool_calls.setdefault(
                        key, {"id": tc_chunk.get("id") or key, "name": "", "args": ""}
                    )
                    if tc_chunk.get("id"):
                        entry["id"] = tc_chunk["id"]
                    if tc_chunk.get("name"):
                        entry["name"] = tc_chunk["name"]
                    if tc_chunk.get("args") and isinstance(entry["args"], str):
                        entry["args"] += tc_chunk["args"]
            elif getattr(chunk, "tool_calls", None):
                # Already-parsed calls with no raw fragments (e.g. a non-chunked message).
                for position, tc in enumerate(chunk.tool_calls):
                    key = tc.get("id") or f"parsed-{position}"
                    collected_tool_calls[key] = {
                        "id": key,
                        "name": tc.get("name") or "",
                        "args": tc.get("args") or {},
                    }

        # Turn the accumulated JSON text into argument dicts.
        tool_calls_list = []
        for tc in collected_tool_calls.values():
            parsed_args = tc["args"]
            if isinstance(parsed_args, str):
                try:
                    parsed_args = json.loads(parsed_args) if parsed_args else {}
                except ValueError:  # json.JSONDecodeError is a ValueError subclass
                    parsed_args = {}
            # Non-object JSON (a list, null, ...) cannot carry a "reason" field.
            if not isinstance(parsed_args, dict):
                parsed_args = {}
            tool_calls_list.append({**tc, "args": parsed_args})

        # Only the first tool call is considered, matching the non-streaming node.
        if tool_calls_list and tool_calls_list[0].get("name"):
            tool_call = tool_calls_list[0]
            tool_name = tool_call["name"]
            tool_args = tool_call["args"]
            if not tool_args.get("reason"):
                tool_args["reason"] = "用户请求转移"
            for t in tools:
                if t.name == tool_name:
                    logger.info(f"Agent {agent_name} 调用工具: {tool_name}")
                    result = t.invoke(tool_args)
                    if isinstance(result, Command):
                        return result
        # Plain reply (or unmatched tool call): keep only the text, so that
        # route_after_agent sees a tool-call-free AIMessage and ends the run.
        return {"messages": [AIMessage(content=full_content)]}

    return agent_node
# ==================== 路由函数 ====================
def create_route_initial(default_agent: str = "sales_agent"):
    """Build the router used on the START edge of the graph.

    The router continues with ``state["active_agent"]`` (e.g. an agent chosen by
    an earlier handoff). It falls back to ``default_agent`` when that value is
    missing or falsy.
    """
    def route_initial(state: HandoffState) -> str:
        return state.get("active_agent") or default_agent

    return route_initial
def route_after_agent(state: HandoffState) -> str:
    """Decide where to go after an agent node has run.

    The run ends when the last message is an ``AIMessage`` with no pending tool
    calls. Otherwise control goes to ``state["active_agent"]``. This covers a
    handoff ``Command``, which updates only ``active_agent`` and leaves the last
    message as it was. The value defaults to ``"sales_agent"`` when the key is
    absent.
    """
    history = state.get("messages", [])
    last = history[-1] if history else None
    finished = isinstance(last, AIMessage) and not getattr(last, "tool_calls", None)
    if finished:
        return END
    return state.get("active_agent", "sales_agent")
# ==================== Handoffs 服务类 ====================
class HandoffsService:
    """Multi-agent "handoff" chat service built on a LangGraph state graph.

    Each configured agent becomes a graph node. Agents pass the conversation to
    each other through ``transfer_to_<agent>`` tools. The active agent is kept in
    the graph state, which is checkpointed in memory per conversation
    (``thread_id``).
    """

    # Name prefix of the generated handoff tools (``transfer_to_<agent>``).
    _TRANSFER_PREFIX = "transfer_to_"

    def __init__(
        self,
        agent_configs: Optional[Dict[str, Dict]] = None,
        llm_config: Optional[Dict[str, Any]] = None,
        streaming: bool = True
    ):
        """Initialise the service; the graph itself is built lazily.

        Args:
            agent_configs: Mapping of agent name -> config (``description``,
                ``system_prompt``, ``can_transfer_to``). ``None`` or an empty
                dict selects ``DEFAULT_AGENT_CONFIGS``.
            llm_config: LLM settings (``api_key``, ``base_url``, ``model``,
                ``temperature``). ``None`` or an empty dict selects
                ``DEFAULT_LLM_CONFIG``.
            streaming: Build agent nodes that stream the LLM reply.
        """
        self.agent_configs = agent_configs or DEFAULT_AGENT_CONFIGS
        self.llm_config = llm_config or DEFAULT_LLM_CONFIG
        self.streaming = streaming
        self._graph = None
        logger.info(f"HandoffsService 初始化, agents: {list(self.agent_configs.keys())}")

    def _build_graph(self):
        """Compile the graph: one node per agent plus routing edges.

        START routes to the stored ``active_agent``, or to the first configured
        agent if none is stored. After each agent node, ``route_after_agent``
        either ends the run or moves to the active agent. State is checkpointed
        in a ``MemorySaver`` owned by the compiled graph.
        """
        builder = StateGraph(HandoffState)
        agent_names = list(self.agent_configs.keys())
        node_factory = create_streaming_agent_node if self.streaming else create_agent_node

        for agent_name in agent_names:
            config = self.agent_configs[agent_name]
            agent_node = node_factory(
                agent_name=agent_name,
                system_prompt=config.get("system_prompt", f"你是 {agent_name}"),
                tools=create_tools_for_agent(agent_name, self.agent_configs),
                api_key=self.llm_config.get("api_key"),
                base_url=self.llm_config.get("base_url"),
                model=self.llm_config.get("model"),
                temperature=self.llm_config.get("temperature", 0.7),
            )
            builder.add_node(agent_name, agent_node)

        default_agent = agent_names[0] if agent_names else "sales_agent"
        builder.add_conditional_edges(START, create_route_initial(default_agent), agent_names)
        for agent_name in agent_names:
            builder.add_conditional_edges(agent_name, route_after_agent, agent_names + [END])

        return builder.compile(checkpointer=MemorySaver())

    @property
    def graph(self):
        """Compiled graph, built on first access and cached afterwards."""
        if self._graph is None:
            self._graph = self._build_graph()
        return self._graph

    def reset(self):
        """Discard the compiled graph, and with it its in-memory checkpoints.

        The graph is rebuilt on the next access.
        """
        self._graph = None
        logger.info("HandoffsService 图已重置")

    @staticmethod
    def _sse(event: str, payload: Dict[str, Any]) -> str:
        """Format one Server-Sent Events frame with a JSON ``data`` line."""
        return f"event: {event}\ndata: {json.dumps(payload, ensure_ascii=False)}\n\n"

    @staticmethod
    def _handoff_target(tool_name: str) -> Optional[str]:
        """Return the target agent of a ``transfer_to_<agent>`` tool name, else None."""
        prefix = HandoffsService._TRANSFER_PREFIX
        if not tool_name.startswith(prefix):
            return None
        # Strip only the leading prefix. ``str.replace`` would also remove any
        # later occurrence inside the agent name itself.
        return tool_name[len(prefix):]

    async def chat(
        self,
        message: str,
        conversation_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """Run one non-streaming chat turn.

        Args:
            message: User message text.
            conversation_id: Conversation (thread) id. A random
                ``conv-xxxxxxxx`` id is generated when omitted.

        Returns:
            Dict with ``conversation_id``, ``active_agent``, ``response`` (text of
            the latest AI message, ``""`` if there is none) and ``message_count``.
        """
        conversation_id = conversation_id or f"conv-{uuid.uuid4().hex[:8]}"
        config = {"configurable": {"thread_id": str(conversation_id)}}
        logger.info(f"Handoffs chat: conversation_id={conversation_id}, message={message[:50]}...")

        result = await self.graph.ainvoke({
            "messages": [HumanMessage(content=message)]
        }, config=config)

        messages = result.get("messages", [])
        # Scan from the end: this turn's answer is the *latest* AI message.
        # Taking the first one returns a stale answer as soon as the state holds
        # more than one AI message (e.g. with an appending reducer).
        response_content = next(
            (msg.content for msg in reversed(messages) if isinstance(msg, AIMessage)),
            "",
        )
        return {
            "conversation_id": str(conversation_id),
            "active_agent": result.get("active_agent"),
            "response": response_content,
            "message_count": len(messages)
        }

    async def chat_stream(
        self,
        message: str,
        conversation_id: Optional[str] = None
    ) -> AsyncGenerator[str, None]:
        """Run one chat turn and stream it as Server-Sent Events.

        Emitted events:
          * ``start``   - ``{"conversation_id"}``
          * ``agent``   - ``{"agent"}`` whenever a different agent node starts
          * ``message`` - ``{"content"}`` per streamed LLM text chunk
          * ``handoff`` - ``{"from", "to"}`` when a transfer tool starts
          * ``end``     - ``{"conversation_id", "final_agent"}`` on success
          * ``error``   - ``{"error"}`` if the graph raised

        Args:
            message: User message text.
            conversation_id: Conversation (thread) id. A random id is generated
                when omitted.

        Yields:
            SSE-formatted strings.
        """
        conversation_id = conversation_id or f"conv-{uuid.uuid4().hex[:8]}"
        config = {"configurable": {"thread_id": str(conversation_id)}}
        logger.info(f"Handoffs stream chat: conversation_id={conversation_id}, message={message[:50]}...")

        yield self._sse("start", {"conversation_id": str(conversation_id)})

        current_agent = None
        try:
            async for event in self.graph.astream_events(
                {"messages": [HumanMessage(content=message)]},
                config=config,
                version="v2"
            ):
                kind = event["event"]
                if kind == "on_chain_start":
                    # Graph nodes are named after their agent; announce each switch once.
                    node_name = event.get("name", "")
                    if node_name in self.agent_configs and node_name != current_agent:
                        current_agent = node_name
                        yield self._sse("agent", {"agent": node_name})
                elif kind == "on_chat_model_stream":
                    content = event["data"]["chunk"].content
                    if content:
                        yield self._sse("message", {"content": content})
                elif kind == "on_tool_start":
                    target_agent = self._handoff_target(event.get("name", ""))
                    if target_agent is not None:
                        yield self._sse("handoff", {"from": current_agent, "to": target_agent})

            yield self._sse("end", {"conversation_id": str(conversation_id), "final_agent": current_agent})
        except Exception as e:
            # Boundary handler: report the failure to the client as an SSE event.
            logger.error(f"Handoffs stream error: {str(e)}")
            yield self._sse("error", {"error": str(e)})

    def get_agents(self) -> List[Dict[str, Any]]:
        """List the configured agents with their description and handoff targets."""
        return [
            {
                "name": name,
                "description": config.get("description", ""),
                "can_transfer_to": config.get("can_transfer_to", []),
            }
            for name, config in self.agent_configs.items()
        ]
# ==================== 全局实例 ====================
_default_service: Optional[HandoffsService] = None


def get_handoffs_service(
    agent_configs: Dict[str, Dict] = None,
    llm_config: Dict[str, Any] = None,
    streaming: bool = True
) -> HandoffsService:
    """Return a HandoffsService instance.

    With any (truthy) custom ``agent_configs`` or ``llm_config``, a brand-new,
    uncached instance is returned on every call. Otherwise a module-level
    default instance is created on first use and reused afterwards. Note that
    ``streaming`` only takes effect when that default instance is first created;
    later calls return the cached instance unchanged.

    Args:
        agent_configs: Optional custom agent configuration.
        llm_config: Optional custom LLM configuration.
        streaming: Whether the (new) instance builds streaming agent nodes.

    Returns:
        A HandoffsService instance.
    """
    global _default_service

    wants_custom = bool(agent_configs) or bool(llm_config)
    if wants_custom:
        return HandoffsService(agent_configs, llm_config, streaming)

    if _default_service is None:
        _default_service = HandoffsService(streaming=streaming)
    return _default_service
def reset_default_service():
    """Forget the cached default service.

    The next ``get_handoffs_service()`` call without custom configuration builds
    a fresh instance. The old instance's graph (and the in-memory checkpointer
    compiled into it) is reset first.
    """
    global _default_service
    stale = _default_service
    if stale:
        stale.reset()
    _default_service = None