From ba2220d7c85a68e854617fce499e0a525bcaea6c Mon Sep 17 00:00:00 2001 From: Mark Date: Wed, 7 Jan 2026 14:58:23 +0800 Subject: [PATCH] [add ] handoffs service and test --- api/app/controllers/test_controller.py | 146 +++++-- api/app/services/handoffs_service.py | 514 +++++++++++++++++++++++++ 2 files changed, 635 insertions(+), 25 deletions(-) create mode 100644 api/app/services/handoffs_service.py diff --git a/api/app/controllers/test_controller.py b/api/app/controllers/test_controller.py index 98cbe26e..a8a30b13 100644 --- a/api/app/controllers/test_controller.py +++ b/api/app/controllers/test_controller.py @@ -1,23 +1,22 @@ -from fastapi import APIRouter, Depends, status, Query, HTTPException -from langchain_core.messages import HumanMessage, SystemMessage +from fastapi import APIRouter, Depends, status, HTTPException, Body, Path +from fastapi.responses import StreamingResponse from langchain_core.prompts import ChatPromptTemplate from sqlalchemy.orm import Session -from typing import List, Optional import uuid - from app.core.models import RedBearLLM, RedBearRerank from app.core.models.base import RedBearModelConfig from app.core.models.embedding import RedBearEmbeddings from app.db import get_db -from app.dependencies import get_current_user -from app.models.models_model import ModelApiKey, ModelProvider, ModelType -from app.models.user_model import User -from app.schemas import model_schema +from app.models.models_model import ModelApiKey from app.core.response_utils import success -from app.schemas.response_schema import ApiResponse, PageData -from app.services.model_service import ModelConfigService, ModelApiKeyService +from app.schemas.response_schema import ApiResponse +from app.schemas.app_schema import AppChatRequest +from app.services.model_service import ModelConfigService +from app.services.handoffs_service import get_handoffs_service, reset_default_service +from app.services.conversation_service import ConversationService from app.core.logging_config import get_api_logger +from app.dependencies import get_current_user # 获取API专用日志器 api_logger = get_api_logger() @@ -28,6 +27,8 @@ router = APIRouter( ) +# ==================== 原有测试接口 ==================== + @router.get("/llm/{model_id}", response_model=ApiResponse) def test_llm( model_id: uuid.UUID, @@ -50,7 +51,6 @@ def test_llm( template = """Question: {question} Answer: Let's think step by step.""" - # ChatPromptTemplate prompt = ChatPromptTemplate.from_template(template) chain = prompt | llm answer = chain.invoke({"question": "What is LangChain?"}) @@ -80,13 +80,13 @@ def test_embedding( base_url=apiConfig.api_base )) - data = [ - "最近哪家咖啡店评价最好?", - "附近有没有推荐的咖啡厅?", - "明天天气预报说会下雨。", - "北京是中国的首都。", - "我想找一个适合学习的地方。" - ] + data = [ + "最近哪家咖啡店评价最好?", + "附近有没有推荐的咖啡厅?", + "明天天气预报说会下雨。", + "北京是中国的首都。", + "我想找一个适合学习的地方。" + ] embeddings = model.embed_documents(data) print(embeddings) query = "我想找一个适合学习的地方。" @@ -114,13 +114,109 @@ def test_rerank( base_url=apiConfig.api_base )) query = "最近哪家咖啡店评价最好?" - data = [ - "最近哪家咖啡店评价最好?", - "附近有没有推荐的咖啡厅?", - "明天天气预报说会下雨。", - "北京是中国的首都。", - "我想找一个适合学习的地方。" - ] + data = [ + "最近哪家咖啡店评价最好?", + "附近有没有推荐的咖啡厅?", + "明天天气预报说会下雨。", + "北京是中国的首都。", + "我想找一个适合学习的地方。" + ] scores = model.rerank(query=query, documents=data, top_n=3) print(scores) return success(msg="测试Rerank成功", data={"query": query, "documents": data, "scores": scores}) + + +# ==================== Handoffs 测试接口 ==================== + +@router.post("/handoffs/{app_id}") +async def test_handoffs( + app_id: uuid.UUID = Path(..., description="应用 ID"), + request: AppChatRequest = Body(...), + current_user=Depends(get_current_user), + db: Session = Depends(get_db) +): + """测试 Agent Handoffs 功能 + + 演示 LangGraph 实现的多 Agent 协作和动态切换 + + - 默认从 sales_agent 开始 + - 根据用户问题自动切换到合适的 Agent + - 使用 conversation_id 保持会话状态 + - 通过 stream 参数控制是否流式输出 + + 事件类型(流式): + - start: 开始执行 + - agent: 当前 Agent 信息 + - message: 流式消息内容 + - handoff: Agent 切换事件 + - end: 执行结束 + - error: 错误信息 + """ + try: + workspace_id = current_user.current_workspace_id + + # 获取或创建会话 + conversation_service = ConversationService(db) + + if request.conversation_id: + # 验证会话存在 + conversation = conversation_service.get_conversation(uuid.UUID(request.conversation_id)) + if not conversation: + raise HTTPException(status_code=404, detail="会话不存在") + conversation_id = str(conversation.id) + else: + # 创建新会话 + conversation = conversation_service.create_or_get_conversation( + app_id=app_id, + workspace_id=workspace_id, + user_id=request.user_id, + is_draft=True + ) + conversation_id = str(conversation.id) + + # 根据 stream 参数决定返回方式 + if request.stream: + # 流式返回 + service = get_handoffs_service(streaming=True) + return StreamingResponse( + service.chat_stream( + message=request.message, + conversation_id=conversation_id + ), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no" + } + ) + else: + # 非流式返回 + service = get_handoffs_service(streaming=False) + result = await service.chat( + message=request.message, + conversation_id=conversation_id + ) + return success(data=result, msg="Handoffs 测试成功") + + except HTTPException: + raise + except Exception as e: + api_logger.error(f"Handoffs 测试失败: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/handoffs/agents", response_model=ApiResponse) +def get_handoff_agents(): + """获取可用的 Handoff Agent 列表""" + service = get_handoffs_service() + agents = service.get_agents() + + return success(data={"agents": agents}, msg="获取 Agent 列表成功") + + +@router.delete("/handoffs/reset") +def reset_handoff_service(): + """重置 Handoff 服务(清除所有会话状态)""" + reset_default_service() + return success(msg="Handoff 服务已重置") diff --git a/api/app/services/handoffs_service.py b/api/app/services/handoffs_service.py new file mode 100644 index 00000000..64432299 --- /dev/null +++ b/api/app/services/handoffs_service.py @@ -0,0 +1,514 @@ +"""Handoffs 服务 - 基于 LangGraph 的多 Agent 协作""" +import json +import uuid +from typing import List, Dict, Any, Optional, AsyncGenerator +from typing_extensions import TypedDict + +from langchain_core.messages import HumanMessage, AIMessage, BaseMessage +from langchain_openai import ChatOpenAI +from langgraph.graph import StateGraph, START, END +from langgraph.types import Command +from langgraph.checkpoint.memory import MemorySaver +from langchain_core.tools import tool +from pydantic import BaseModel, Field + +from app.core.logging_config import get_business_logger + +logger = get_business_logger() + + +# ==================== 状态定义 ==================== + +class HandoffState(TypedDict): + """Handoff 状态""" + messages: List[BaseMessage] + active_agent: Optional[str] + + +# ==================== 工具输入模型 ==================== + +class TransferInput(BaseModel): + """转移工具的输入参数""" + reason: str = Field(description="转移原因") + + +# ==================== 默认配置 ==================== + +DEFAULT_AGENT_CONFIGS = { + "sales_agent": { + "description": "转移到销售 Agent。当用户询问价格、购买或销售相关问题时使用。", + "system_prompt": """你是一个销售 Agent。帮助用户解答销售相关问题。 +如果用户询问技术问题或需要技术支持,使用 transfer_to_support_agent 工具转移到支持 Agent。""", + "can_transfer_to": ["support_agent"] + }, + "support_agent": { + "description": "转移到支持 Agent。当用户询问技术问题或需要帮助时使用。", + "system_prompt": """你是一个技术支持 Agent。帮助用户解决技术问题。 +如果用户询问价格或购买相关问题,使用 transfer_to_sales_agent 工具转移到销售 Agent。""", + "can_transfer_to": ["sales_agent"] + } +} + +DEFAULT_LLM_CONFIG = { + "api_key": "sk-8e9e40cd171749858ce2d3722ea75669", + "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1", + "model": "qwen-plus", + "temperature": 0.7 +} + + +# ==================== 工具创建 ==================== + +def create_transfer_tool(target_agent: str, description: str): + """动态创建转移工具 + + Args: + target_agent: 目标 Agent 名称 + description: 工具描述 + + Returns: + 转移工具函数 + """ + tool_name = f"transfer_to_{target_agent}" + + @tool(tool_name, args_schema=TransferInput) + def transfer_tool(reason: str) -> Command: + """动态生成的转移工具""" + return Command( + goto=target_agent, + update={"active_agent": target_agent}, + ) + + transfer_tool.__doc__ = description + transfer_tool.description = description + return transfer_tool + + +def create_tools_for_agent(agent_name: str, configs: Dict) -> List: + """根据 Agent 配置动态创建其可用的转移工具 + + Args: + agent_name: 当前 Agent 名称 + configs: Agent 配置字典 + + Returns: + 该 Agent 可用的工具列表 + """ + config = configs.get(agent_name, {}) + can_transfer_to = config.get("can_transfer_to", []) + + tools = [] + for target_agent in can_transfer_to: + target_config = configs.get(target_agent, {}) + description = target_config.get("description", f"转移到 {target_agent}") + tools.append(create_transfer_tool(target_agent, description)) + + return tools + + +# ==================== Agent 节点创建 ==================== + +def create_agent_node(agent_name: str, system_prompt: str, tools: List, + api_key: str, base_url: str, model: str, temperature: float = 0.7): + """创建 Agent 节点(非流式) + + Args: + agent_name: Agent 名称 + system_prompt: 系统提示词 + tools: 工具列表 + api_key: API Key + base_url: API Base URL + model: 模型名称 + temperature: 温度参数 + + Returns: + Agent 节点函数 + """ + llm = ChatOpenAI( + model=model, + temperature=temperature, + api_key=api_key, + base_url=base_url + ).bind_tools(tools) + + async def agent_node(state: HandoffState) -> Dict[str, Any]: + """Agent 节点执行函数""" + logger.debug(f"Agent {agent_name} 执行, active_agent: {state.get('active_agent')}") + + messages = state.get("messages", []) + full_messages = [{"role": "system", "content": system_prompt}] + messages + + response = await llm.ainvoke(full_messages) + + # 检查工具调用 + if hasattr(response, 'tool_calls') and response.tool_calls: + tool_call = response.tool_calls[0] + tool_name = tool_call["name"] + tool_args = tool_call["args"] + + if not tool_args.get("reason"): + tool_args["reason"] = "用户请求转移" + + for t in tools: + if t.name == tool_name: + logger.info(f"Agent {agent_name} 调用工具: {tool_name}") + result = t.invoke(tool_args) + if isinstance(result, Command): + return result + + return {"messages": [response]} + + return agent_node + + +def create_streaming_agent_node(agent_name: str, system_prompt: str, tools: List, + api_key: str, base_url: str, model: str, temperature: float = 0.7): + """创建支持流式输出的 Agent 节点 + + Args: + agent_name: Agent 名称 + system_prompt: 系统提示词 + tools: 工具列表 + api_key: API Key + base_url: API Base URL + model: 模型名称 + temperature: 温度参数 + + Returns: + Agent 节点函数 + """ + llm = ChatOpenAI( + model=model, + temperature=temperature, + api_key=api_key, + base_url=base_url, + streaming=True + ).bind_tools(tools) + + async def agent_node(state: HandoffState): + """Agent 节点执行函数(流式)""" + logger.debug(f"Agent {agent_name} 流式执行, active_agent: {state.get('active_agent')}") + + messages = state.get("messages", []) + full_messages = [{"role": "system", "content": system_prompt}] + messages + + full_content = "" + collected_tool_calls = {} + + async for chunk in llm.astream(full_messages): + if chunk.content: + full_content += chunk.content + + # 收集工具调用 + if hasattr(chunk, 'tool_calls') and chunk.tool_calls: + for tc in chunk.tool_calls: + tc_id = tc.get("id") or "0" + if tc_id not in collected_tool_calls: + collected_tool_calls[tc_id] = {"id": tc_id, "name": "", "args": ""} + if tc.get("name"): + collected_tool_calls[tc_id]["name"] = tc["name"] + if tc.get("args"): + if isinstance(tc["args"], dict): + collected_tool_calls[tc_id]["args"] = tc["args"] + elif isinstance(tc["args"], str): + if isinstance(collected_tool_calls[tc_id]["args"], str): + collected_tool_calls[tc_id]["args"] += tc["args"] + + # 处理 tool_call_chunks + if hasattr(chunk, 'tool_call_chunks') and chunk.tool_call_chunks: + for tc_chunk in chunk.tool_call_chunks: + idx = str(tc_chunk.get("index", 0)) + if idx not in collected_tool_calls: + collected_tool_calls[idx] = {"id": tc_chunk.get("id", idx), "name": "", "args": ""} + if tc_chunk.get("id"): + collected_tool_calls[idx]["id"] = tc_chunk["id"] + if tc_chunk.get("name"): + collected_tool_calls[idx]["name"] = tc_chunk["name"] + if tc_chunk.get("args"): + if isinstance(collected_tool_calls[idx]["args"], str): + collected_tool_calls[idx]["args"] += tc_chunk["args"] + + # 解析工具调用 + tool_calls_list = list(collected_tool_calls.values()) + for tc in tool_calls_list: + if isinstance(tc.get("args"), str) and tc["args"]: + try: + tc["args"] = json.loads(tc["args"]) + except (json.JSONDecodeError, ValueError): + tc["args"] = {} + elif not tc.get("args"): + tc["args"] = {} + + # 执行工具调用 + if tool_calls_list and tool_calls_list[0].get("name"): + tool_call = tool_calls_list[0] + tool_name = tool_call.get("name", "") + tool_args = tool_call.get("args", {}) + + if not tool_args.get("reason"): + tool_args["reason"] = "用户请求转移" + + for t in tools: + if t.name == tool_name: + logger.info(f"Agent {agent_name} 调用工具: {tool_name}") + result = t.invoke(tool_args) + if isinstance(result, Command): + return result + + return {"messages": [AIMessage(content=full_content)]} + + return agent_node + + +# ==================== 路由函数 ==================== + +def create_route_initial(default_agent: str = "sales_agent"): + """创建初始路由函数""" + def route_initial(state: HandoffState) -> str: + active = state.get("active_agent") + if active: + return active + return default_agent + return route_initial + + +def route_after_agent(state: HandoffState) -> str: + """Agent 执行后的路由""" + messages = state.get("messages", []) + if messages: + last_msg = messages[-1] + if isinstance(last_msg, AIMessage) and not getattr(last_msg, 'tool_calls', None): + return END + return state.get("active_agent", "sales_agent") + + +# ==================== Handoffs 服务类 ==================== + +class HandoffsService: + """Handoffs 服务 - 管理多 Agent 协作""" + + def __init__( + self, + agent_configs: Dict[str, Dict] = None, + llm_config: Dict[str, Any] = None, + streaming: bool = True + ): + """初始化 Handoffs 服务 + + Args: + agent_configs: Agent 配置字典 + llm_config: LLM 配置 + streaming: 是否启用流式输出 + """ + self.agent_configs = agent_configs or DEFAULT_AGENT_CONFIGS + self.llm_config = llm_config or DEFAULT_LLM_CONFIG + self.streaming = streaming + self._graph = None + + logger.info(f"HandoffsService 初始化, agents: {list(self.agent_configs.keys())}") + + def _build_graph(self): + """构建 LangGraph 图""" + builder = StateGraph(HandoffState) + agent_names = list(self.agent_configs.keys()) + + for agent_name in agent_names: + config = self.agent_configs[agent_name] + tools = create_tools_for_agent(agent_name, self.agent_configs) + + if self.streaming: + agent_node = create_streaming_agent_node( + agent_name=agent_name, + system_prompt=config.get("system_prompt", f"你是 {agent_name}"), + tools=tools, + api_key=self.llm_config.get("api_key"), + base_url=self.llm_config.get("base_url"), + model=self.llm_config.get("model"), + temperature=self.llm_config.get("temperature", 0.7) + ) + else: + agent_node = create_agent_node( + agent_name=agent_name, + system_prompt=config.get("system_prompt", f"你是 {agent_name}"), + tools=tools, + api_key=self.llm_config.get("api_key"), + base_url=self.llm_config.get("base_url"), + model=self.llm_config.get("model"), + temperature=self.llm_config.get("temperature", 0.7) + ) + builder.add_node(agent_name, agent_node) + + # 添加边 + default_agent = agent_names[0] if agent_names else "sales_agent" + builder.add_conditional_edges(START, create_route_initial(default_agent), agent_names) + + for agent_name in agent_names: + builder.add_conditional_edges(agent_name, route_after_agent, agent_names + [END]) + + memory = MemorySaver() + return builder.compile(checkpointer=memory) + + @property + def graph(self): + """获取图实例(懒加载)""" + if self._graph is None: + self._graph = self._build_graph() + return self._graph + + def reset(self): + """重置图实例""" + self._graph = None + logger.info("HandoffsService 图已重置") + + async def chat( + self, + message: str, + conversation_id: str = None + ) -> Dict[str, Any]: + """非流式聊天 + + Args: + message: 用户消息 + conversation_id: 会话 ID + + Returns: + 聊天结果 + """ + conversation_id = conversation_id or f"conv-{uuid.uuid4().hex[:8]}" + config = {"configurable": {"thread_id": str(conversation_id)}} + + logger.info(f"Handoffs chat: conversation_id={conversation_id}, message={message[:50]}...") + + result = await self.graph.ainvoke({ + "messages": [HumanMessage(content=message)] + }, config=config) + + # 提取响应 + response_content = "" + for msg in result.get("messages", []): + if isinstance(msg, AIMessage): + response_content = msg.content + break + + return { + "conversation_id": str(conversation_id), + "active_agent": result.get("active_agent"), + "response": response_content, + "message_count": len(result.get("messages", [])) + } + + async def chat_stream( + self, + message: str, + conversation_id: str = None + ) -> AsyncGenerator[str, None]: + """流式聊天 + + Args: + message: 用户消息 + conversation_id: 会话 ID + + Yields: + SSE 格式的事件 + """ + conversation_id = conversation_id or f"conv-{uuid.uuid4().hex[:8]}" + config = {"configurable": {"thread_id": str(conversation_id)}} + + logger.info(f"Handoffs stream chat: conversation_id={conversation_id}, message={message[:50]}...") + + # 发送开始事件 + yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n" + + current_agent = None + + try: + async for event in self.graph.astream_events( + {"messages": [HumanMessage(content=message)]}, + config=config, + version="v2" + ): + kind = event["event"] + + # 捕获节点开始(Agent 切换) + if kind == "on_chain_start": + node_name = event.get("name", "") + if node_name in self.agent_configs: + if current_agent != node_name: + current_agent = node_name + yield f"event: agent\ndata: {json.dumps({'agent': node_name}, ensure_ascii=False)}\n\n" + + # 捕获 LLM 流式输出 + elif kind == "on_chat_model_stream": + content = event["data"]["chunk"].content + if content: + yield f"event: message\ndata: {json.dumps({'content': content}, ensure_ascii=False)}\n\n" + + # 捕获工具调用(Handoff) + elif kind == "on_tool_start": + tool_name = event.get("name", "") + if tool_name.startswith("transfer_to_"): + target_agent = tool_name.replace("transfer_to_", "") + yield f"event: handoff\ndata: {json.dumps({'from': current_agent, 'to': target_agent}, ensure_ascii=False)}\n\n" + + # 发送结束事件 + yield f"event: end\ndata: {json.dumps({'conversation_id': str(conversation_id), 'final_agent': current_agent}, ensure_ascii=False)}\n\n" + + except Exception as e: + logger.error(f"Handoffs stream error: {str(e)}") + yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n" + + def get_agents(self) -> List[Dict[str, Any]]: + """获取可用的 Agent 列表 + + Returns: + Agent 列表 + """ + agents = [] + for name, config in self.agent_configs.items(): + agents.append({ + "name": name, + "description": config.get("description", ""), + "can_transfer_to": config.get("can_transfer_to", []) + }) + return agents + + +# ==================== 全局实例 ==================== + +_default_service: Optional[HandoffsService] = None + + +def get_handoffs_service( + agent_configs: Dict[str, Dict] = None, + llm_config: Dict[str, Any] = None, + streaming: bool = True +) -> HandoffsService: + """获取 Handoffs 服务实例 + + Args: + agent_configs: Agent 配置(可选) + llm_config: LLM 配置(可选) + streaming: 是否流式 + + Returns: + HandoffsService 实例 + """ + global _default_service + + # 如果有自定义配置,创建新实例 + if agent_configs or llm_config: + return HandoffsService(agent_configs, llm_config, streaming) + + # 否则使用默认实例 + if _default_service is None: + _default_service = HandoffsService(streaming=streaming) + + return _default_service + + +def reset_default_service(): + """重置默认服务实例""" + global _default_service + if _default_service: + _default_service.reset() + _default_service = None