[add ] handoffs service and test
This commit is contained in:
@@ -1,23 +1,22 @@
|
||||
from fastapi import APIRouter, Depends, status, Query, HTTPException
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from fastapi import APIRouter, Depends, status, HTTPException, Body, Path
|
||||
from fastapi.responses import StreamingResponse
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
import uuid
|
||||
|
||||
|
||||
from app.core.models import RedBearLLM, RedBearRerank
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.models.embedding import RedBearEmbeddings
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.models_model import ModelApiKey, ModelProvider, ModelType
|
||||
from app.models.user_model import User
|
||||
from app.schemas import model_schema
|
||||
from app.models.models_model import ModelApiKey
|
||||
from app.core.response_utils import success
|
||||
from app.schemas.response_schema import ApiResponse, PageData
|
||||
from app.services.model_service import ModelConfigService, ModelApiKeyService
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.app_schema import AppChatRequest
|
||||
from app.services.model_service import ModelConfigService
|
||||
from app.services.handoffs_service import get_handoffs_service, reset_default_service
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.dependencies import get_current_user
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
@@ -28,6 +27,8 @@ router = APIRouter(
|
||||
)
|
||||
|
||||
|
||||
# ==================== 原有测试接口 ====================
|
||||
|
||||
@router.get("/llm/{model_id}", response_model=ApiResponse)
|
||||
def test_llm(
|
||||
model_id: uuid.UUID,
|
||||
@@ -50,7 +51,6 @@ def test_llm(
|
||||
template = """Question: {question}
|
||||
|
||||
Answer: Let's think step by step."""
|
||||
# ChatPromptTemplate
|
||||
prompt = ChatPromptTemplate.from_template(template)
|
||||
chain = prompt | llm
|
||||
answer = chain.invoke({"question": "What is LangChain?"})
|
||||
@@ -80,13 +80,13 @@ def test_embedding(
|
||||
base_url=apiConfig.api_base
|
||||
))
|
||||
|
||||
data = [
|
||||
"最近哪家咖啡店评价最好?",
|
||||
"附近有没有推荐的咖啡厅?",
|
||||
"明天天气预报说会下雨。",
|
||||
"北京是中国的首都。",
|
||||
"我想找一个适合学习的地方。"
|
||||
]
|
||||
data = [
|
||||
"最近哪家咖啡店评价最好?",
|
||||
"附近有没有推荐的咖啡厅?",
|
||||
"明天天气预报说会下雨。",
|
||||
"北京是中国的首都。",
|
||||
"我想找一个适合学习的地方。"
|
||||
]
|
||||
embeddings = model.embed_documents(data)
|
||||
print(embeddings)
|
||||
query = "我想找一个适合学习的地方。"
|
||||
@@ -114,13 +114,109 @@ def test_rerank(
|
||||
base_url=apiConfig.api_base
|
||||
))
|
||||
query = "最近哪家咖啡店评价最好?"
|
||||
data = [
|
||||
"最近哪家咖啡店评价最好?",
|
||||
"附近有没有推荐的咖啡厅?",
|
||||
"明天天气预报说会下雨。",
|
||||
"北京是中国的首都。",
|
||||
"我想找一个适合学习的地方。"
|
||||
]
|
||||
data = [
|
||||
"最近哪家咖啡店评价最好?",
|
||||
"附近有没有推荐的咖啡厅?",
|
||||
"明天天气预报说会下雨。",
|
||||
"北京是中国的首都。",
|
||||
"我想找一个适合学习的地方。"
|
||||
]
|
||||
scores = model.rerank(query=query, documents=data, top_n=3)
|
||||
print(scores)
|
||||
return success(msg="测试Rerank成功", data={"query": query, "documents": data, "scores": scores})
|
||||
|
||||
|
||||
# ==================== Handoffs 测试接口 ====================
|
||||
|
||||
@router.post("/handoffs/{app_id}")
|
||||
async def test_handoffs(
|
||||
app_id: uuid.UUID = Path(..., description="应用 ID"),
|
||||
request: AppChatRequest = Body(...),
|
||||
current_user=Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""测试 Agent Handoffs 功能
|
||||
|
||||
演示 LangGraph 实现的多 Agent 协作和动态切换
|
||||
|
||||
- 默认从 sales_agent 开始
|
||||
- 根据用户问题自动切换到合适的 Agent
|
||||
- 使用 conversation_id 保持会话状态
|
||||
- 通过 stream 参数控制是否流式输出
|
||||
|
||||
事件类型(流式):
|
||||
- start: 开始执行
|
||||
- agent: 当前 Agent 信息
|
||||
- message: 流式消息内容
|
||||
- handoff: Agent 切换事件
|
||||
- end: 执行结束
|
||||
- error: 错误信息
|
||||
"""
|
||||
try:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 获取或创建会话
|
||||
conversation_service = ConversationService(db)
|
||||
|
||||
if request.conversation_id:
|
||||
# 验证会话存在
|
||||
conversation = conversation_service.get_conversation(uuid.UUID(request.conversation_id))
|
||||
if not conversation:
|
||||
raise HTTPException(status_code=404, detail="会话不存在")
|
||||
conversation_id = str(conversation.id)
|
||||
else:
|
||||
# 创建新会话
|
||||
conversation = conversation_service.create_or_get_conversation(
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=request.user_id,
|
||||
is_draft=True
|
||||
)
|
||||
conversation_id = str(conversation.id)
|
||||
|
||||
# 根据 stream 参数决定返回方式
|
||||
if request.stream:
|
||||
# 流式返回
|
||||
service = get_handoffs_service(streaming=True)
|
||||
return StreamingResponse(
|
||||
service.chat_stream(
|
||||
message=request.message,
|
||||
conversation_id=conversation_id
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
else:
|
||||
# 非流式返回
|
||||
service = get_handoffs_service(streaming=False)
|
||||
result = await service.chat(
|
||||
message=request.message,
|
||||
conversation_id=conversation_id
|
||||
)
|
||||
return success(data=result, msg="Handoffs 测试成功")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"Handoffs 测试失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/handoffs/agents", response_model=ApiResponse)
|
||||
def get_handoff_agents():
|
||||
"""获取可用的 Handoff Agent 列表"""
|
||||
service = get_handoffs_service()
|
||||
agents = service.get_agents()
|
||||
|
||||
return success(data={"agents": agents}, msg="获取 Agent 列表成功")
|
||||
|
||||
|
||||
@router.delete("/handoffs/reset")
|
||||
def reset_handoff_service():
|
||||
"""重置 Handoff 服务(清除所有会话状态)"""
|
||||
reset_default_service()
|
||||
return success(msg="Handoff 服务已重置")
|
||||
|
||||
514
api/app/services/handoffs_service.py
Normal file
514
api/app/services/handoffs_service.py
Normal file
@@ -0,0 +1,514 @@
|
||||
"""Handoffs 服务 - 基于 LangGraph 的多 Agent 协作"""
|
||||
import json
|
||||
import uuid
|
||||
from typing import List, Dict, Any, Optional, AsyncGenerator
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
from langgraph.types import Command
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
from langchain_core.tools import tool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
# ==================== 状态定义 ====================
|
||||
|
||||
class HandoffState(TypedDict):
|
||||
"""Handoff 状态"""
|
||||
messages: List[BaseMessage]
|
||||
active_agent: Optional[str]
|
||||
|
||||
|
||||
# ==================== 工具输入模型 ====================
|
||||
|
||||
class TransferInput(BaseModel):
|
||||
"""转移工具的输入参数"""
|
||||
reason: str = Field(description="转移原因")
|
||||
|
||||
|
||||
# ==================== 默认配置 ====================
|
||||
|
||||
DEFAULT_AGENT_CONFIGS = {
|
||||
"sales_agent": {
|
||||
"description": "转移到销售 Agent。当用户询问价格、购买或销售相关问题时使用。",
|
||||
"system_prompt": """你是一个销售 Agent。帮助用户解答销售相关问题。
|
||||
如果用户询问技术问题或需要技术支持,使用 transfer_to_support_agent 工具转移到支持 Agent。""",
|
||||
"can_transfer_to": ["support_agent"]
|
||||
},
|
||||
"support_agent": {
|
||||
"description": "转移到支持 Agent。当用户询问技术问题或需要帮助时使用。",
|
||||
"system_prompt": """你是一个技术支持 Agent。帮助用户解决技术问题。
|
||||
如果用户询问价格或购买相关问题,使用 transfer_to_sales_agent 工具转移到销售 Agent。""",
|
||||
"can_transfer_to": ["sales_agent"]
|
||||
}
|
||||
}
|
||||
|
||||
DEFAULT_LLM_CONFIG = {
|
||||
"api_key": "sk-8e9e40cd171749858ce2d3722ea75669",
|
||||
"base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
"model": "qwen-plus",
|
||||
"temperature": 0.7
|
||||
}
|
||||
|
||||
|
||||
# ==================== 工具创建 ====================
|
||||
|
||||
def create_transfer_tool(target_agent: str, description: str):
|
||||
"""动态创建转移工具
|
||||
|
||||
Args:
|
||||
target_agent: 目标 Agent 名称
|
||||
description: 工具描述
|
||||
|
||||
Returns:
|
||||
转移工具函数
|
||||
"""
|
||||
tool_name = f"transfer_to_{target_agent}"
|
||||
|
||||
@tool(tool_name, args_schema=TransferInput)
|
||||
def transfer_tool(reason: str) -> Command:
|
||||
"""动态生成的转移工具"""
|
||||
return Command(
|
||||
goto=target_agent,
|
||||
update={"active_agent": target_agent},
|
||||
)
|
||||
|
||||
transfer_tool.__doc__ = description
|
||||
transfer_tool.description = description
|
||||
return transfer_tool
|
||||
|
||||
|
||||
def create_tools_for_agent(agent_name: str, configs: Dict) -> List:
|
||||
"""根据 Agent 配置动态创建其可用的转移工具
|
||||
|
||||
Args:
|
||||
agent_name: 当前 Agent 名称
|
||||
configs: Agent 配置字典
|
||||
|
||||
Returns:
|
||||
该 Agent 可用的工具列表
|
||||
"""
|
||||
config = configs.get(agent_name, {})
|
||||
can_transfer_to = config.get("can_transfer_to", [])
|
||||
|
||||
tools = []
|
||||
for target_agent in can_transfer_to:
|
||||
target_config = configs.get(target_agent, {})
|
||||
description = target_config.get("description", f"转移到 {target_agent}")
|
||||
tools.append(create_transfer_tool(target_agent, description))
|
||||
|
||||
return tools
|
||||
|
||||
|
||||
# ==================== Agent 节点创建 ====================
|
||||
|
||||
def create_agent_node(agent_name: str, system_prompt: str, tools: List,
|
||||
api_key: str, base_url: str, model: str, temperature: float = 0.7):
|
||||
"""创建 Agent 节点(非流式)
|
||||
|
||||
Args:
|
||||
agent_name: Agent 名称
|
||||
system_prompt: 系统提示词
|
||||
tools: 工具列表
|
||||
api_key: API Key
|
||||
base_url: API Base URL
|
||||
model: 模型名称
|
||||
temperature: 温度参数
|
||||
|
||||
Returns:
|
||||
Agent 节点函数
|
||||
"""
|
||||
llm = ChatOpenAI(
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
api_key=api_key,
|
||||
base_url=base_url
|
||||
).bind_tools(tools)
|
||||
|
||||
async def agent_node(state: HandoffState) -> Dict[str, Any]:
|
||||
"""Agent 节点执行函数"""
|
||||
logger.debug(f"Agent {agent_name} 执行, active_agent: {state.get('active_agent')}")
|
||||
|
||||
messages = state.get("messages", [])
|
||||
full_messages = [{"role": "system", "content": system_prompt}] + messages
|
||||
|
||||
response = await llm.ainvoke(full_messages)
|
||||
|
||||
# 检查工具调用
|
||||
if hasattr(response, 'tool_calls') and response.tool_calls:
|
||||
tool_call = response.tool_calls[0]
|
||||
tool_name = tool_call["name"]
|
||||
tool_args = tool_call["args"]
|
||||
|
||||
if not tool_args.get("reason"):
|
||||
tool_args["reason"] = "用户请求转移"
|
||||
|
||||
for t in tools:
|
||||
if t.name == tool_name:
|
||||
logger.info(f"Agent {agent_name} 调用工具: {tool_name}")
|
||||
result = t.invoke(tool_args)
|
||||
if isinstance(result, Command):
|
||||
return result
|
||||
|
||||
return {"messages": [response]}
|
||||
|
||||
return agent_node
|
||||
|
||||
|
||||
def create_streaming_agent_node(agent_name: str, system_prompt: str, tools: List,
|
||||
api_key: str, base_url: str, model: str, temperature: float = 0.7):
|
||||
"""创建支持流式输出的 Agent 节点
|
||||
|
||||
Args:
|
||||
agent_name: Agent 名称
|
||||
system_prompt: 系统提示词
|
||||
tools: 工具列表
|
||||
api_key: API Key
|
||||
base_url: API Base URL
|
||||
model: 模型名称
|
||||
temperature: 温度参数
|
||||
|
||||
Returns:
|
||||
Agent 节点函数
|
||||
"""
|
||||
llm = ChatOpenAI(
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
streaming=True
|
||||
).bind_tools(tools)
|
||||
|
||||
async def agent_node(state: HandoffState):
|
||||
"""Agent 节点执行函数(流式)"""
|
||||
logger.debug(f"Agent {agent_name} 流式执行, active_agent: {state.get('active_agent')}")
|
||||
|
||||
messages = state.get("messages", [])
|
||||
full_messages = [{"role": "system", "content": system_prompt}] + messages
|
||||
|
||||
full_content = ""
|
||||
collected_tool_calls = {}
|
||||
|
||||
async for chunk in llm.astream(full_messages):
|
||||
if chunk.content:
|
||||
full_content += chunk.content
|
||||
|
||||
# 收集工具调用
|
||||
if hasattr(chunk, 'tool_calls') and chunk.tool_calls:
|
||||
for tc in chunk.tool_calls:
|
||||
tc_id = tc.get("id") or "0"
|
||||
if tc_id not in collected_tool_calls:
|
||||
collected_tool_calls[tc_id] = {"id": tc_id, "name": "", "args": ""}
|
||||
if tc.get("name"):
|
||||
collected_tool_calls[tc_id]["name"] = tc["name"]
|
||||
if tc.get("args"):
|
||||
if isinstance(tc["args"], dict):
|
||||
collected_tool_calls[tc_id]["args"] = tc["args"]
|
||||
elif isinstance(tc["args"], str):
|
||||
if isinstance(collected_tool_calls[tc_id]["args"], str):
|
||||
collected_tool_calls[tc_id]["args"] += tc["args"]
|
||||
|
||||
# 处理 tool_call_chunks
|
||||
if hasattr(chunk, 'tool_call_chunks') and chunk.tool_call_chunks:
|
||||
for tc_chunk in chunk.tool_call_chunks:
|
||||
idx = str(tc_chunk.get("index", 0))
|
||||
if idx not in collected_tool_calls:
|
||||
collected_tool_calls[idx] = {"id": tc_chunk.get("id", idx), "name": "", "args": ""}
|
||||
if tc_chunk.get("id"):
|
||||
collected_tool_calls[idx]["id"] = tc_chunk["id"]
|
||||
if tc_chunk.get("name"):
|
||||
collected_tool_calls[idx]["name"] = tc_chunk["name"]
|
||||
if tc_chunk.get("args"):
|
||||
if isinstance(collected_tool_calls[idx]["args"], str):
|
||||
collected_tool_calls[idx]["args"] += tc_chunk["args"]
|
||||
|
||||
# 解析工具调用
|
||||
tool_calls_list = list(collected_tool_calls.values())
|
||||
for tc in tool_calls_list:
|
||||
if isinstance(tc.get("args"), str) and tc["args"]:
|
||||
try:
|
||||
tc["args"] = json.loads(tc["args"])
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
tc["args"] = {}
|
||||
elif not tc.get("args"):
|
||||
tc["args"] = {}
|
||||
|
||||
# 执行工具调用
|
||||
if tool_calls_list and tool_calls_list[0].get("name"):
|
||||
tool_call = tool_calls_list[0]
|
||||
tool_name = tool_call.get("name", "")
|
||||
tool_args = tool_call.get("args", {})
|
||||
|
||||
if not tool_args.get("reason"):
|
||||
tool_args["reason"] = "用户请求转移"
|
||||
|
||||
for t in tools:
|
||||
if t.name == tool_name:
|
||||
logger.info(f"Agent {agent_name} 调用工具: {tool_name}")
|
||||
result = t.invoke(tool_args)
|
||||
if isinstance(result, Command):
|
||||
return result
|
||||
|
||||
return {"messages": [AIMessage(content=full_content)]}
|
||||
|
||||
return agent_node
|
||||
|
||||
|
||||
# ==================== 路由函数 ====================
|
||||
|
||||
def create_route_initial(default_agent: str = "sales_agent"):
|
||||
"""创建初始路由函数"""
|
||||
def route_initial(state: HandoffState) -> str:
|
||||
active = state.get("active_agent")
|
||||
if active:
|
||||
return active
|
||||
return default_agent
|
||||
return route_initial
|
||||
|
||||
|
||||
def route_after_agent(state: HandoffState) -> str:
|
||||
"""Agent 执行后的路由"""
|
||||
messages = state.get("messages", [])
|
||||
if messages:
|
||||
last_msg = messages[-1]
|
||||
if isinstance(last_msg, AIMessage) and not getattr(last_msg, 'tool_calls', None):
|
||||
return END
|
||||
return state.get("active_agent", "sales_agent")
|
||||
|
||||
|
||||
# ==================== Handoffs 服务类 ====================
|
||||
|
||||
class HandoffsService:
|
||||
"""Handoffs 服务 - 管理多 Agent 协作"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_configs: Dict[str, Dict] = None,
|
||||
llm_config: Dict[str, Any] = None,
|
||||
streaming: bool = True
|
||||
):
|
||||
"""初始化 Handoffs 服务
|
||||
|
||||
Args:
|
||||
agent_configs: Agent 配置字典
|
||||
llm_config: LLM 配置
|
||||
streaming: 是否启用流式输出
|
||||
"""
|
||||
self.agent_configs = agent_configs or DEFAULT_AGENT_CONFIGS
|
||||
self.llm_config = llm_config or DEFAULT_LLM_CONFIG
|
||||
self.streaming = streaming
|
||||
self._graph = None
|
||||
|
||||
logger.info(f"HandoffsService 初始化, agents: {list(self.agent_configs.keys())}")
|
||||
|
||||
def _build_graph(self):
|
||||
"""构建 LangGraph 图"""
|
||||
builder = StateGraph(HandoffState)
|
||||
agent_names = list(self.agent_configs.keys())
|
||||
|
||||
for agent_name in agent_names:
|
||||
config = self.agent_configs[agent_name]
|
||||
tools = create_tools_for_agent(agent_name, self.agent_configs)
|
||||
|
||||
if self.streaming:
|
||||
agent_node = create_streaming_agent_node(
|
||||
agent_name=agent_name,
|
||||
system_prompt=config.get("system_prompt", f"你是 {agent_name}"),
|
||||
tools=tools,
|
||||
api_key=self.llm_config.get("api_key"),
|
||||
base_url=self.llm_config.get("base_url"),
|
||||
model=self.llm_config.get("model"),
|
||||
temperature=self.llm_config.get("temperature", 0.7)
|
||||
)
|
||||
else:
|
||||
agent_node = create_agent_node(
|
||||
agent_name=agent_name,
|
||||
system_prompt=config.get("system_prompt", f"你是 {agent_name}"),
|
||||
tools=tools,
|
||||
api_key=self.llm_config.get("api_key"),
|
||||
base_url=self.llm_config.get("base_url"),
|
||||
model=self.llm_config.get("model"),
|
||||
temperature=self.llm_config.get("temperature", 0.7)
|
||||
)
|
||||
builder.add_node(agent_name, agent_node)
|
||||
|
||||
# 添加边
|
||||
default_agent = agent_names[0] if agent_names else "sales_agent"
|
||||
builder.add_conditional_edges(START, create_route_initial(default_agent), agent_names)
|
||||
|
||||
for agent_name in agent_names:
|
||||
builder.add_conditional_edges(agent_name, route_after_agent, agent_names + [END])
|
||||
|
||||
memory = MemorySaver()
|
||||
return builder.compile(checkpointer=memory)
|
||||
|
||||
@property
|
||||
def graph(self):
|
||||
"""获取图实例(懒加载)"""
|
||||
if self._graph is None:
|
||||
self._graph = self._build_graph()
|
||||
return self._graph
|
||||
|
||||
def reset(self):
|
||||
"""重置图实例"""
|
||||
self._graph = None
|
||||
logger.info("HandoffsService 图已重置")
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
message: str,
|
||||
conversation_id: str = None
|
||||
) -> Dict[str, Any]:
|
||||
"""非流式聊天
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
conversation_id: 会话 ID
|
||||
|
||||
Returns:
|
||||
聊天结果
|
||||
"""
|
||||
conversation_id = conversation_id or f"conv-{uuid.uuid4().hex[:8]}"
|
||||
config = {"configurable": {"thread_id": str(conversation_id)}}
|
||||
|
||||
logger.info(f"Handoffs chat: conversation_id={conversation_id}, message={message[:50]}...")
|
||||
|
||||
result = await self.graph.ainvoke({
|
||||
"messages": [HumanMessage(content=message)]
|
||||
}, config=config)
|
||||
|
||||
# 提取响应
|
||||
response_content = ""
|
||||
for msg in result.get("messages", []):
|
||||
if isinstance(msg, AIMessage):
|
||||
response_content = msg.content
|
||||
break
|
||||
|
||||
return {
|
||||
"conversation_id": str(conversation_id),
|
||||
"active_agent": result.get("active_agent"),
|
||||
"response": response_content,
|
||||
"message_count": len(result.get("messages", []))
|
||||
}
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
message: str,
|
||||
conversation_id: str = None
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""流式聊天
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
conversation_id: 会话 ID
|
||||
|
||||
Yields:
|
||||
SSE 格式的事件
|
||||
"""
|
||||
conversation_id = conversation_id or f"conv-{uuid.uuid4().hex[:8]}"
|
||||
config = {"configurable": {"thread_id": str(conversation_id)}}
|
||||
|
||||
logger.info(f"Handoffs stream chat: conversation_id={conversation_id}, message={message[:50]}...")
|
||||
|
||||
# 发送开始事件
|
||||
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"
|
||||
|
||||
current_agent = None
|
||||
|
||||
try:
|
||||
async for event in self.graph.astream_events(
|
||||
{"messages": [HumanMessage(content=message)]},
|
||||
config=config,
|
||||
version="v2"
|
||||
):
|
||||
kind = event["event"]
|
||||
|
||||
# 捕获节点开始(Agent 切换)
|
||||
if kind == "on_chain_start":
|
||||
node_name = event.get("name", "")
|
||||
if node_name in self.agent_configs:
|
||||
if current_agent != node_name:
|
||||
current_agent = node_name
|
||||
yield f"event: agent\ndata: {json.dumps({'agent': node_name}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 捕获 LLM 流式输出
|
||||
elif kind == "on_chat_model_stream":
|
||||
content = event["data"]["chunk"].content
|
||||
if content:
|
||||
yield f"event: message\ndata: {json.dumps({'content': content}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 捕获工具调用(Handoff)
|
||||
elif kind == "on_tool_start":
|
||||
tool_name = event.get("name", "")
|
||||
if tool_name.startswith("transfer_to_"):
|
||||
target_agent = tool_name.replace("transfer_to_", "")
|
||||
yield f"event: handoff\ndata: {json.dumps({'from': current_agent, 'to': target_agent}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 发送结束事件
|
||||
yield f"event: end\ndata: {json.dumps({'conversation_id': str(conversation_id), 'final_agent': current_agent}, ensure_ascii=False)}\n\n"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Handoffs stream error: {str(e)}")
|
||||
yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
|
||||
|
||||
def get_agents(self) -> List[Dict[str, Any]]:
|
||||
"""获取可用的 Agent 列表
|
||||
|
||||
Returns:
|
||||
Agent 列表
|
||||
"""
|
||||
agents = []
|
||||
for name, config in self.agent_configs.items():
|
||||
agents.append({
|
||||
"name": name,
|
||||
"description": config.get("description", ""),
|
||||
"can_transfer_to": config.get("can_transfer_to", [])
|
||||
})
|
||||
return agents
|
||||
|
||||
|
||||
# ==================== 全局实例 ====================
|
||||
|
||||
_default_service: Optional[HandoffsService] = None
|
||||
|
||||
|
||||
def get_handoffs_service(
|
||||
agent_configs: Dict[str, Dict] = None,
|
||||
llm_config: Dict[str, Any] = None,
|
||||
streaming: bool = True
|
||||
) -> HandoffsService:
|
||||
"""获取 Handoffs 服务实例
|
||||
|
||||
Args:
|
||||
agent_configs: Agent 配置(可选)
|
||||
llm_config: LLM 配置(可选)
|
||||
streaming: 是否流式
|
||||
|
||||
Returns:
|
||||
HandoffsService 实例
|
||||
"""
|
||||
global _default_service
|
||||
|
||||
# 如果有自定义配置,创建新实例
|
||||
if agent_configs or llm_config:
|
||||
return HandoffsService(agent_configs, llm_config, streaming)
|
||||
|
||||
# 否则使用默认实例
|
||||
if _default_service is None:
|
||||
_default_service = HandoffsService(streaming=streaming)
|
||||
|
||||
return _default_service
|
||||
|
||||
|
||||
def reset_default_service():
|
||||
"""重置默认服务实例"""
|
||||
global _default_service
|
||||
if _default_service:
|
||||
_default_service.reset()
|
||||
_default_service = None
|
||||
Reference in New Issue
Block a user