[add ] handoffs service and test

This commit is contained in:
Mark
2026-01-07 14:58:23 +08:00
parent 957f8f83ff
commit ba2220d7c8
2 changed files with 635 additions and 25 deletions

View File

@@ -1,23 +1,22 @@
from fastapi import APIRouter, Depends, status, Query, HTTPException
from langchain_core.messages import HumanMessage, SystemMessage
from fastapi import APIRouter, Depends, status, HTTPException, Body, Path
from fastapi.responses import StreamingResponse
from langchain_core.prompts import ChatPromptTemplate
from sqlalchemy.orm import Session
from typing import List, Optional
import uuid
from app.core.models import RedBearLLM, RedBearRerank
from app.core.models.base import RedBearModelConfig
from app.core.models.embedding import RedBearEmbeddings
from app.db import get_db
from app.dependencies import get_current_user
from app.models.models_model import ModelApiKey, ModelProvider, ModelType
from app.models.user_model import User
from app.schemas import model_schema
from app.models.models_model import ModelApiKey
from app.core.response_utils import success
from app.schemas.response_schema import ApiResponse, PageData
from app.services.model_service import ModelConfigService, ModelApiKeyService
from app.schemas.response_schema import ApiResponse
from app.schemas.app_schema import AppChatRequest
from app.services.model_service import ModelConfigService
from app.services.handoffs_service import get_handoffs_service, reset_default_service
from app.services.conversation_service import ConversationService
from app.core.logging_config import get_api_logger
from app.dependencies import get_current_user
# 获取API专用日志器
api_logger = get_api_logger()
@@ -28,6 +27,8 @@ router = APIRouter(
)
# ==================== 原有测试接口 ====================
@router.get("/llm/{model_id}", response_model=ApiResponse)
def test_llm(
model_id: uuid.UUID,
@@ -50,7 +51,6 @@ def test_llm(
template = """Question: {question}
Answer: Let's think step by step."""
# ChatPromptTemplate
prompt = ChatPromptTemplate.from_template(template)
chain = prompt | llm
answer = chain.invoke({"question": "What is LangChain?"})
@@ -80,13 +80,13 @@ def test_embedding(
base_url=apiConfig.api_base
))
data = [
"最近哪家咖啡店评价最好?",
"附近有没有推荐的咖啡厅?",
"明天天气预报说会下雨。",
"北京是中国的首都。",
"我想找一个适合学习的地方。"
]
data = [
"最近哪家咖啡店评价最好?",
"附近有没有推荐的咖啡厅?",
"明天天气预报说会下雨。",
"北京是中国的首都。",
"我想找一个适合学习的地方。"
]
embeddings = model.embed_documents(data)
print(embeddings)
query = "我想找一个适合学习的地方。"
@@ -114,13 +114,109 @@ def test_rerank(
base_url=apiConfig.api_base
))
query = "最近哪家咖啡店评价最好?"
data = [
"最近哪家咖啡店评价最好?",
"附近有没有推荐的咖啡厅?",
"明天天气预报说会下雨。",
"北京是中国的首都。",
"我想找一个适合学习的地方。"
]
data = [
"最近哪家咖啡店评价最好?",
"附近有没有推荐的咖啡厅?",
"明天天气预报说会下雨。",
"北京是中国的首都。",
"我想找一个适合学习的地方。"
]
scores = model.rerank(query=query, documents=data, top_n=3)
print(scores)
return success(msg="测试Rerank成功", data={"query": query, "documents": data, "scores": scores})
# ==================== Handoffs 测试接口 ====================
@router.post("/handoffs/{app_id}")
async def test_handoffs(
app_id: uuid.UUID = Path(..., description="应用 ID"),
request: AppChatRequest = Body(...),
current_user=Depends(get_current_user),
db: Session = Depends(get_db)
):
"""测试 Agent Handoffs 功能
演示 LangGraph 实现的多 Agent 协作和动态切换
- 默认从 sales_agent 开始
- 根据用户问题自动切换到合适的 Agent
- 使用 conversation_id 保持会话状态
- 通过 stream 参数控制是否流式输出
事件类型(流式):
- start: 开始执行
- agent: 当前 Agent 信息
- message: 流式消息内容
- handoff: Agent 切换事件
- end: 执行结束
- error: 错误信息
"""
try:
workspace_id = current_user.current_workspace_id
# 获取或创建会话
conversation_service = ConversationService(db)
if request.conversation_id:
# 验证会话存在
conversation = conversation_service.get_conversation(uuid.UUID(request.conversation_id))
if not conversation:
raise HTTPException(status_code=404, detail="会话不存在")
conversation_id = str(conversation.id)
else:
# 创建新会话
conversation = conversation_service.create_or_get_conversation(
app_id=app_id,
workspace_id=workspace_id,
user_id=request.user_id,
is_draft=True
)
conversation_id = str(conversation.id)
# 根据 stream 参数决定返回方式
if request.stream:
# 流式返回
service = get_handoffs_service(streaming=True)
return StreamingResponse(
service.chat_stream(
message=request.message,
conversation_id=conversation_id
),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
else:
# 非流式返回
service = get_handoffs_service(streaming=False)
result = await service.chat(
message=request.message,
conversation_id=conversation_id
)
return success(data=result, msg="Handoffs 测试成功")
except HTTPException:
raise
except Exception as e:
api_logger.error(f"Handoffs 测试失败: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/handoffs/agents", response_model=ApiResponse)
def get_handoff_agents():
"""获取可用的 Handoff Agent 列表"""
service = get_handoffs_service()
agents = service.get_agents()
return success(data={"agents": agents}, msg="获取 Agent 列表成功")
@router.delete("/handoffs/reset")
def reset_handoff_service():
"""重置 Handoff 服务(清除所有会话状态)"""
reset_default_service()
return success(msg="Handoff 服务已重置")

View File

@@ -0,0 +1,514 @@
"""Handoffs 服务 - 基于 LangGraph 的多 Agent 协作"""
import json
import uuid
from typing import List, Dict, Any, Optional, AsyncGenerator
from typing_extensions import TypedDict
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, START, END
from langgraph.types import Command
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.tools import tool
from pydantic import BaseModel, Field
from app.core.logging_config import get_business_logger
logger = get_business_logger()
# ==================== 状态定义 ====================
class HandoffState(TypedDict):
"""Handoff 状态"""
messages: List[BaseMessage]
active_agent: Optional[str]
# ==================== 工具输入模型 ====================
class TransferInput(BaseModel):
"""转移工具的输入参数"""
reason: str = Field(description="转移原因")
# ==================== 默认配置 ====================
DEFAULT_AGENT_CONFIGS = {
"sales_agent": {
"description": "转移到销售 Agent。当用户询问价格、购买或销售相关问题时使用。",
"system_prompt": """你是一个销售 Agent。帮助用户解答销售相关问题。
如果用户询问技术问题或需要技术支持,使用 transfer_to_support_agent 工具转移到支持 Agent。""",
"can_transfer_to": ["support_agent"]
},
"support_agent": {
"description": "转移到支持 Agent。当用户询问技术问题或需要帮助时使用。",
"system_prompt": """你是一个技术支持 Agent。帮助用户解决技术问题。
如果用户询问价格或购买相关问题,使用 transfer_to_sales_agent 工具转移到销售 Agent。""",
"can_transfer_to": ["sales_agent"]
}
}
DEFAULT_LLM_CONFIG = {
"api_key": "sk-8e9e40cd171749858ce2d3722ea75669",
"base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
"model": "qwen-plus",
"temperature": 0.7
}
# ==================== 工具创建 ====================
def create_transfer_tool(target_agent: str, description: str):
"""动态创建转移工具
Args:
target_agent: 目标 Agent 名称
description: 工具描述
Returns:
转移工具函数
"""
tool_name = f"transfer_to_{target_agent}"
@tool(tool_name, args_schema=TransferInput)
def transfer_tool(reason: str) -> Command:
"""动态生成的转移工具"""
return Command(
goto=target_agent,
update={"active_agent": target_agent},
)
transfer_tool.__doc__ = description
transfer_tool.description = description
return transfer_tool
def create_tools_for_agent(agent_name: str, configs: Dict) -> List:
"""根据 Agent 配置动态创建其可用的转移工具
Args:
agent_name: 当前 Agent 名称
configs: Agent 配置字典
Returns:
该 Agent 可用的工具列表
"""
config = configs.get(agent_name, {})
can_transfer_to = config.get("can_transfer_to", [])
tools = []
for target_agent in can_transfer_to:
target_config = configs.get(target_agent, {})
description = target_config.get("description", f"转移到 {target_agent}")
tools.append(create_transfer_tool(target_agent, description))
return tools
# ==================== Agent 节点创建 ====================
def create_agent_node(agent_name: str, system_prompt: str, tools: List,
api_key: str, base_url: str, model: str, temperature: float = 0.7):
"""创建 Agent 节点(非流式)
Args:
agent_name: Agent 名称
system_prompt: 系统提示词
tools: 工具列表
api_key: API Key
base_url: API Base URL
model: 模型名称
temperature: 温度参数
Returns:
Agent 节点函数
"""
llm = ChatOpenAI(
model=model,
temperature=temperature,
api_key=api_key,
base_url=base_url
).bind_tools(tools)
async def agent_node(state: HandoffState) -> Dict[str, Any]:
"""Agent 节点执行函数"""
logger.debug(f"Agent {agent_name} 执行, active_agent: {state.get('active_agent')}")
messages = state.get("messages", [])
full_messages = [{"role": "system", "content": system_prompt}] + messages
response = await llm.ainvoke(full_messages)
# 检查工具调用
if hasattr(response, 'tool_calls') and response.tool_calls:
tool_call = response.tool_calls[0]
tool_name = tool_call["name"]
tool_args = tool_call["args"]
if not tool_args.get("reason"):
tool_args["reason"] = "用户请求转移"
for t in tools:
if t.name == tool_name:
logger.info(f"Agent {agent_name} 调用工具: {tool_name}")
result = t.invoke(tool_args)
if isinstance(result, Command):
return result
return {"messages": [response]}
return agent_node
def create_streaming_agent_node(agent_name: str, system_prompt: str, tools: List,
api_key: str, base_url: str, model: str, temperature: float = 0.7):
"""创建支持流式输出的 Agent 节点
Args:
agent_name: Agent 名称
system_prompt: 系统提示词
tools: 工具列表
api_key: API Key
base_url: API Base URL
model: 模型名称
temperature: 温度参数
Returns:
Agent 节点函数
"""
llm = ChatOpenAI(
model=model,
temperature=temperature,
api_key=api_key,
base_url=base_url,
streaming=True
).bind_tools(tools)
async def agent_node(state: HandoffState):
"""Agent 节点执行函数(流式)"""
logger.debug(f"Agent {agent_name} 流式执行, active_agent: {state.get('active_agent')}")
messages = state.get("messages", [])
full_messages = [{"role": "system", "content": system_prompt}] + messages
full_content = ""
collected_tool_calls = {}
async for chunk in llm.astream(full_messages):
if chunk.content:
full_content += chunk.content
# 收集工具调用
if hasattr(chunk, 'tool_calls') and chunk.tool_calls:
for tc in chunk.tool_calls:
tc_id = tc.get("id") or "0"
if tc_id not in collected_tool_calls:
collected_tool_calls[tc_id] = {"id": tc_id, "name": "", "args": ""}
if tc.get("name"):
collected_tool_calls[tc_id]["name"] = tc["name"]
if tc.get("args"):
if isinstance(tc["args"], dict):
collected_tool_calls[tc_id]["args"] = tc["args"]
elif isinstance(tc["args"], str):
if isinstance(collected_tool_calls[tc_id]["args"], str):
collected_tool_calls[tc_id]["args"] += tc["args"]
# 处理 tool_call_chunks
if hasattr(chunk, 'tool_call_chunks') and chunk.tool_call_chunks:
for tc_chunk in chunk.tool_call_chunks:
idx = str(tc_chunk.get("index", 0))
if idx not in collected_tool_calls:
collected_tool_calls[idx] = {"id": tc_chunk.get("id", idx), "name": "", "args": ""}
if tc_chunk.get("id"):
collected_tool_calls[idx]["id"] = tc_chunk["id"]
if tc_chunk.get("name"):
collected_tool_calls[idx]["name"] = tc_chunk["name"]
if tc_chunk.get("args"):
if isinstance(collected_tool_calls[idx]["args"], str):
collected_tool_calls[idx]["args"] += tc_chunk["args"]
# 解析工具调用
tool_calls_list = list(collected_tool_calls.values())
for tc in tool_calls_list:
if isinstance(tc.get("args"), str) and tc["args"]:
try:
tc["args"] = json.loads(tc["args"])
except (json.JSONDecodeError, ValueError):
tc["args"] = {}
elif not tc.get("args"):
tc["args"] = {}
# 执行工具调用
if tool_calls_list and tool_calls_list[0].get("name"):
tool_call = tool_calls_list[0]
tool_name = tool_call.get("name", "")
tool_args = tool_call.get("args", {})
if not tool_args.get("reason"):
tool_args["reason"] = "用户请求转移"
for t in tools:
if t.name == tool_name:
logger.info(f"Agent {agent_name} 调用工具: {tool_name}")
result = t.invoke(tool_args)
if isinstance(result, Command):
return result
return {"messages": [AIMessage(content=full_content)]}
return agent_node
# ==================== 路由函数 ====================
def create_route_initial(default_agent: str = "sales_agent"):
"""创建初始路由函数"""
def route_initial(state: HandoffState) -> str:
active = state.get("active_agent")
if active:
return active
return default_agent
return route_initial
def route_after_agent(state: HandoffState) -> str:
"""Agent 执行后的路由"""
messages = state.get("messages", [])
if messages:
last_msg = messages[-1]
if isinstance(last_msg, AIMessage) and not getattr(last_msg, 'tool_calls', None):
return END
return state.get("active_agent", "sales_agent")
# ==================== Handoffs 服务类 ====================
class HandoffsService:
"""Handoffs 服务 - 管理多 Agent 协作"""
def __init__(
self,
agent_configs: Dict[str, Dict] = None,
llm_config: Dict[str, Any] = None,
streaming: bool = True
):
"""初始化 Handoffs 服务
Args:
agent_configs: Agent 配置字典
llm_config: LLM 配置
streaming: 是否启用流式输出
"""
self.agent_configs = agent_configs or DEFAULT_AGENT_CONFIGS
self.llm_config = llm_config or DEFAULT_LLM_CONFIG
self.streaming = streaming
self._graph = None
logger.info(f"HandoffsService 初始化, agents: {list(self.agent_configs.keys())}")
def _build_graph(self):
"""构建 LangGraph 图"""
builder = StateGraph(HandoffState)
agent_names = list(self.agent_configs.keys())
for agent_name in agent_names:
config = self.agent_configs[agent_name]
tools = create_tools_for_agent(agent_name, self.agent_configs)
if self.streaming:
agent_node = create_streaming_agent_node(
agent_name=agent_name,
system_prompt=config.get("system_prompt", f"你是 {agent_name}"),
tools=tools,
api_key=self.llm_config.get("api_key"),
base_url=self.llm_config.get("base_url"),
model=self.llm_config.get("model"),
temperature=self.llm_config.get("temperature", 0.7)
)
else:
agent_node = create_agent_node(
agent_name=agent_name,
system_prompt=config.get("system_prompt", f"你是 {agent_name}"),
tools=tools,
api_key=self.llm_config.get("api_key"),
base_url=self.llm_config.get("base_url"),
model=self.llm_config.get("model"),
temperature=self.llm_config.get("temperature", 0.7)
)
builder.add_node(agent_name, agent_node)
# 添加边
default_agent = agent_names[0] if agent_names else "sales_agent"
builder.add_conditional_edges(START, create_route_initial(default_agent), agent_names)
for agent_name in agent_names:
builder.add_conditional_edges(agent_name, route_after_agent, agent_names + [END])
memory = MemorySaver()
return builder.compile(checkpointer=memory)
@property
def graph(self):
"""获取图实例(懒加载)"""
if self._graph is None:
self._graph = self._build_graph()
return self._graph
def reset(self):
"""重置图实例"""
self._graph = None
logger.info("HandoffsService 图已重置")
async def chat(
self,
message: str,
conversation_id: str = None
) -> Dict[str, Any]:
"""非流式聊天
Args:
message: 用户消息
conversation_id: 会话 ID
Returns:
聊天结果
"""
conversation_id = conversation_id or f"conv-{uuid.uuid4().hex[:8]}"
config = {"configurable": {"thread_id": str(conversation_id)}}
logger.info(f"Handoffs chat: conversation_id={conversation_id}, message={message[:50]}...")
result = await self.graph.ainvoke({
"messages": [HumanMessage(content=message)]
}, config=config)
# 提取响应
response_content = ""
for msg in result.get("messages", []):
if isinstance(msg, AIMessage):
response_content = msg.content
break
return {
"conversation_id": str(conversation_id),
"active_agent": result.get("active_agent"),
"response": response_content,
"message_count": len(result.get("messages", []))
}
async def chat_stream(
self,
message: str,
conversation_id: str = None
) -> AsyncGenerator[str, None]:
"""流式聊天
Args:
message: 用户消息
conversation_id: 会话 ID
Yields:
SSE 格式的事件
"""
conversation_id = conversation_id or f"conv-{uuid.uuid4().hex[:8]}"
config = {"configurable": {"thread_id": str(conversation_id)}}
logger.info(f"Handoffs stream chat: conversation_id={conversation_id}, message={message[:50]}...")
# 发送开始事件
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"
current_agent = None
try:
async for event in self.graph.astream_events(
{"messages": [HumanMessage(content=message)]},
config=config,
version="v2"
):
kind = event["event"]
# 捕获节点开始Agent 切换)
if kind == "on_chain_start":
node_name = event.get("name", "")
if node_name in self.agent_configs:
if current_agent != node_name:
current_agent = node_name
yield f"event: agent\ndata: {json.dumps({'agent': node_name}, ensure_ascii=False)}\n\n"
# 捕获 LLM 流式输出
elif kind == "on_chat_model_stream":
content = event["data"]["chunk"].content
if content:
yield f"event: message\ndata: {json.dumps({'content': content}, ensure_ascii=False)}\n\n"
# 捕获工具调用Handoff
elif kind == "on_tool_start":
tool_name = event.get("name", "")
if tool_name.startswith("transfer_to_"):
target_agent = tool_name.replace("transfer_to_", "")
yield f"event: handoff\ndata: {json.dumps({'from': current_agent, 'to': target_agent}, ensure_ascii=False)}\n\n"
# 发送结束事件
yield f"event: end\ndata: {json.dumps({'conversation_id': str(conversation_id), 'final_agent': current_agent}, ensure_ascii=False)}\n\n"
except Exception as e:
logger.error(f"Handoffs stream error: {str(e)}")
yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
def get_agents(self) -> List[Dict[str, Any]]:
"""获取可用的 Agent 列表
Returns:
Agent 列表
"""
agents = []
for name, config in self.agent_configs.items():
agents.append({
"name": name,
"description": config.get("description", ""),
"can_transfer_to": config.get("can_transfer_to", [])
})
return agents
# ==================== 全局实例 ====================
_default_service: Optional[HandoffsService] = None
def get_handoffs_service(
agent_configs: Dict[str, Dict] = None,
llm_config: Dict[str, Any] = None,
streaming: bool = True
) -> HandoffsService:
"""获取 Handoffs 服务实例
Args:
agent_configs: Agent 配置(可选)
llm_config: LLM 配置(可选)
streaming: 是否流式
Returns:
HandoffsService 实例
"""
global _default_service
# 如果有自定义配置,创建新实例
if agent_configs or llm_config:
return HandoffsService(agent_configs, llm_config, streaming)
# 否则使用默认实例
if _default_service is None:
_default_service = HandoffsService(streaming=streaming)
return _default_service
def reset_default_service():
"""重置默认服务实例"""
global _default_service
if _default_service:
_default_service.reset()
_default_service = None