From cd76ccadc5221eb35157c8565e2cc98f3961067b Mon Sep 17 00:00:00 2001 From: Mark Date: Wed, 7 Jan 2026 15:51:12 +0800 Subject: [PATCH] [modify] handoffs test --- api/app/controllers/test_controller.py | 44 ++- api/app/services/handoffs_service.py | 391 +++++++++++++++---------- 2 files changed, 270 insertions(+), 165 deletions(-) diff --git a/api/app/controllers/test_controller.py b/api/app/controllers/test_controller.py index a8a30b13..5746405a 100644 --- a/api/app/controllers/test_controller.py +++ b/api/app/controllers/test_controller.py @@ -13,7 +13,7 @@ from app.core.response_utils import success from app.schemas.response_schema import ApiResponse from app.schemas.app_schema import AppChatRequest from app.services.model_service import ModelConfigService -from app.services.handoffs_service import get_handoffs_service, reset_default_service +from app.services.handoffs_service import get_handoffs_service_for_app, reset_handoffs_service_cache from app.services.conversation_service import ConversationService from app.core.logging_config import get_api_logger from app.dependencies import get_current_user @@ -139,7 +139,7 @@ async def test_handoffs( 演示 LangGraph 实现的多 Agent 协作和动态切换 - - 默认从 sales_agent 开始 + - 从数据库 multi_agent_config 获取 Agent 配置 - 根据用户问题自动切换到合适的 Agent - 使用 conversation_id 保持会话状态 - 通过 stream 参数控制是否流式输出 @@ -177,7 +177,7 @@ async def test_handoffs( # 根据 stream 参数决定返回方式 if request.stream: # 流式返回 - service = get_handoffs_service(streaming=True) + service = get_handoffs_service_for_app(app_id, db, streaming=True) return StreamingResponse( service.chat_stream( message=request.message, @@ -192,13 +192,15 @@ async def test_handoffs( ) else: # 非流式返回 - service = get_handoffs_service(streaming=False) + service = get_handoffs_service_for_app(app_id, db, streaming=False) result = await service.chat( message=request.message, conversation_id=conversation_id ) return success(data=result, msg="Handoffs 测试成功") + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) except HTTPException: raise except Exception as e: @@ -206,17 +208,29 @@ async def test_handoffs( raise HTTPException(status_code=500, detail=str(e)) -@router.get("/handoffs/agents", response_model=ApiResponse) -def get_handoff_agents(): - """获取可用的 Handoff Agent 列表""" - service = get_handoffs_service() - agents = service.get_agents() - - return success(data={"agents": agents}, msg="获取 Agent 列表成功") +@router.get("/handoffs/{app_id}/agents", response_model=ApiResponse) +def get_handoff_agents( + app_id: uuid.UUID = Path(..., description="应用 ID"), + db: Session = Depends(get_db), + current_user=Depends(get_current_user) +): + """获取应用的 Handoff Agent 列表""" + try: + service = get_handoffs_service_for_app(app_id, db, streaming=False) + agents = service.get_agents() + return success(data={"agents": agents}, msg="获取 Agent 列表成功") + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + api_logger.error(f"获取 Agent 列表失败: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) -@router.delete("/handoffs/reset") -def reset_handoff_service(): - """重置 Handoff 服务(清除所有会话状态)""" - reset_default_service() +@router.delete("/handoffs/{app_id}/reset") +def reset_handoff_service( + app_id: uuid.UUID = Path(..., description="应用 ID"), + current_user=Depends(get_current_user) +): + """重置指定应用的 Handoff 服务缓存""" + reset_handoffs_service_cache(app_id) return success(msg="Handoff 服务已重置") diff --git a/api/app/services/handoffs_service.py b/api/app/services/handoffs_service.py index 64432299..45232339 100644 --- a/api/app/services/handoffs_service.py +++ b/api/app/services/handoffs_service.py @@ -5,14 +5,17 @@ from typing import List, Dict, Any, Optional, AsyncGenerator from typing_extensions import TypedDict from langchain_core.messages import HumanMessage, AIMessage, BaseMessage -from langchain_openai import ChatOpenAI from langgraph.graph import StateGraph, START, END from langgraph.types import Command from langgraph.checkpoint.memory import MemorySaver from langchain_core.tools import tool from pydantic import BaseModel, Field +from sqlalchemy.orm import Session from app.core.logging_config import get_business_logger +from app.core.models import RedBearLLM, RedBearModelConfig +from app.models.models_model import ModelType +from app.services.model_service import ModelApiKeyService logger = get_business_logger() @@ -32,31 +35,6 @@ class TransferInput(BaseModel): reason: str = Field(description="转移原因") -# ==================== 默认配置 ==================== - -DEFAULT_AGENT_CONFIGS = { - "sales_agent": { - "description": "转移到销售 Agent。当用户询问价格、购买或销售相关问题时使用。", - "system_prompt": """你是一个销售 Agent。帮助用户解答销售相关问题。 -如果用户询问技术问题或需要技术支持,使用 transfer_to_support_agent 工具转移到支持 Agent。""", - "can_transfer_to": ["support_agent"] - }, - "support_agent": { - "description": "转移到支持 Agent。当用户询问技术问题或需要帮助时使用。", - "system_prompt": """你是一个技术支持 Agent。帮助用户解决技术问题。 -如果用户询问价格或购买相关问题,使用 transfer_to_sales_agent 工具转移到销售 Agent。""", - "can_transfer_to": ["sales_agent"] - } -} - -DEFAULT_LLM_CONFIG = { - "api_key": "sk-8e9e40cd171749858ce2d3722ea75669", - "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1", - "model": "qwen-plus", - "temperature": 0.7 -} - - # ==================== 工具创建 ==================== def create_transfer_tool(target_agent: str, description: str): @@ -109,27 +87,13 @@ def create_tools_for_agent(agent_name: str, configs: Dict) -> List: # ==================== Agent 节点创建 ==================== def create_agent_node(agent_name: str, system_prompt: str, tools: List, - api_key: str, base_url: str, model: str, temperature: float = 0.7): - """创建 Agent 节点(非流式) + model_config: RedBearModelConfig): + """创建 Agent 节点(非流式)""" + llm = RedBearLLM(model_config, type=ModelType.CHAT) - Args: - agent_name: Agent 名称 - system_prompt: 系统提示词 - tools: 工具列表 - api_key: API Key - base_url: API Base URL - model: 模型名称 - temperature: 温度参数 - - Returns: - Agent 节点函数 - """ - llm = ChatOpenAI( - model=model, - temperature=temperature, - api_key=api_key, - base_url=base_url - ).bind_tools(tools) + # 绑定工具 + if tools: + llm = llm.bind_tools(tools) async def agent_node(state: HandoffState) -> Dict[str, Any]: """Agent 节点执行函数""" @@ -143,8 +107,14 @@ def create_agent_node(agent_name: str, system_prompt: str, tools: List, # 检查工具调用 if hasattr(response, 'tool_calls') and response.tool_calls: tool_call = response.tool_calls[0] - tool_name = tool_call["name"] - tool_args = tool_call["args"] + tool_name = tool_call["name"] if isinstance(tool_call, dict) else tool_call.name + tool_args = tool_call["args"] if isinstance(tool_call, dict) else tool_call.args + + if isinstance(tool_args, str): + try: + tool_args = json.loads(tool_args) + except (json.JSONDecodeError, ValueError): + tool_args = {} if not tool_args.get("reason"): tool_args["reason"] = "用户请求转移" @@ -162,28 +132,13 @@ def create_agent_node(agent_name: str, system_prompt: str, tools: List, def create_streaming_agent_node(agent_name: str, system_prompt: str, tools: List, - api_key: str, base_url: str, model: str, temperature: float = 0.7): - """创建支持流式输出的 Agent 节点 + model_config: RedBearModelConfig): + """创建支持流式输出的 Agent 节点""" + llm = RedBearLLM(model_config, type=ModelType.CHAT) - Args: - agent_name: Agent 名称 - system_prompt: 系统提示词 - tools: 工具列表 - api_key: API Key - base_url: API Base URL - model: 模型名称 - temperature: 温度参数 - - Returns: - Agent 节点函数 - """ - llm = ChatOpenAI( - model=model, - temperature=temperature, - api_key=api_key, - base_url=base_url, - streaming=True - ).bind_tools(tools) + # 绑定工具 + if tools: + llm = llm.bind_tools(tools) async def agent_node(state: HandoffState): """Agent 节点执行函数(流式)""" @@ -196,37 +151,48 @@ def create_streaming_agent_node(agent_name: str, system_prompt: str, tools: List collected_tool_calls = {} async for chunk in llm.astream(full_messages): - if chunk.content: + if hasattr(chunk, 'content') and chunk.content: full_content += chunk.content # 收集工具调用 if hasattr(chunk, 'tool_calls') and chunk.tool_calls: for tc in chunk.tool_calls: - tc_id = tc.get("id") or "0" + tc_id = tc.get("id") if isinstance(tc, dict) else getattr(tc, 'id', "0") + tc_id = tc_id or "0" if tc_id not in collected_tool_calls: collected_tool_calls[tc_id] = {"id": tc_id, "name": "", "args": ""} - if tc.get("name"): - collected_tool_calls[tc_id]["name"] = tc["name"] - if tc.get("args"): - if isinstance(tc["args"], dict): - collected_tool_calls[tc_id]["args"] = tc["args"] - elif isinstance(tc["args"], str): + + tc_name = tc.get("name") if isinstance(tc, dict) else getattr(tc, 'name', None) + tc_args = tc.get("args") if isinstance(tc, dict) else getattr(tc, 'args', None) + + if tc_name: + collected_tool_calls[tc_id]["name"] = tc_name + if tc_args: + if isinstance(tc_args, dict): + collected_tool_calls[tc_id]["args"] = tc_args + elif isinstance(tc_args, str): if isinstance(collected_tool_calls[tc_id]["args"], str): - collected_tool_calls[tc_id]["args"] += tc["args"] + collected_tool_calls[tc_id]["args"] += tc_args # 处理 tool_call_chunks if hasattr(chunk, 'tool_call_chunks') and chunk.tool_call_chunks: for tc_chunk in chunk.tool_call_chunks: - idx = str(tc_chunk.get("index", 0)) + idx = str(tc_chunk.get("index", 0) if isinstance(tc_chunk, dict) else getattr(tc_chunk, 'index', 0)) if idx not in collected_tool_calls: - collected_tool_calls[idx] = {"id": tc_chunk.get("id", idx), "name": "", "args": ""} - if tc_chunk.get("id"): - collected_tool_calls[idx]["id"] = tc_chunk["id"] - if tc_chunk.get("name"): - collected_tool_calls[idx]["name"] = tc_chunk["name"] - if tc_chunk.get("args"): + tc_id = tc_chunk.get("id", idx) if isinstance(tc_chunk, dict) else getattr(tc_chunk, 'id', idx) + collected_tool_calls[idx] = {"id": tc_id, "name": "", "args": ""} + + tc_id = tc_chunk.get("id") if isinstance(tc_chunk, dict) else getattr(tc_chunk, 'id', None) + tc_name = tc_chunk.get("name") if isinstance(tc_chunk, dict) else getattr(tc_chunk, 'name', None) + tc_args = tc_chunk.get("args") if isinstance(tc_chunk, dict) else getattr(tc_chunk, 'args', None) + + if tc_id: + collected_tool_calls[idx]["id"] = tc_id + if tc_name: + collected_tool_calls[idx]["name"] = tc_name + if tc_args: if isinstance(collected_tool_calls[idx]["args"], str): - collected_tool_calls[idx]["args"] += tc_chunk["args"] + collected_tool_calls[idx]["args"] += tc_args # 解析工具调用 tool_calls_list = list(collected_tool_calls.values()) @@ -262,7 +228,7 @@ def create_streaming_agent_node(agent_name: str, system_prompt: str, tools: List # ==================== 路由函数 ==================== -def create_route_initial(default_agent: str = "sales_agent"): +def create_route_initial(default_agent: str): """创建初始路由函数""" def route_initial(state: HandoffState) -> str: active = state.get("active_agent") @@ -279,7 +245,120 @@ def route_after_agent(state: HandoffState) -> str: last_msg = messages[-1] if isinstance(last_msg, AIMessage) and not getattr(last_msg, 'tool_calls', None): return END - return state.get("active_agent", "sales_agent") + return state.get("active_agent", END) + + +# ==================== 配置转换 ==================== + +def convert_multi_agent_config_to_handoffs( + multi_agent_config: Dict, + db: Session +) -> tuple[Dict[str, Dict], RedBearModelConfig]: + """将 multi_agent_config 转换为 handoffs 配置格式 + + Args: + multi_agent_config: 数据库中的多 Agent 配置 + db: 数据库会话 + + Returns: + (agent_configs, model_config) 元组 + """ + from app.models import AppRelease, App + + sub_agents = multi_agent_config.get("sub_agents", []) + agent_configs = {} + agent_names = [] + + # 遍历子 Agent,构建配置 + for sub_agent in sub_agents: + agent_app_id = sub_agent.get("agent_id") + agent_name = sub_agent.get("name", f"agent_{agent_app_id[:8] if agent_app_id else 'unknown'}") + # 使用安全的 agent name(去除特殊字符) + safe_name = agent_name.replace(" ", "_").replace("-", "_").lower() + agent_names.append(safe_name) + + # 从 AppRelease 获取 Agent 的系统提示词 + system_prompt = f"你是 {agent_name}。" + capabilities = sub_agent.get("capabilities", []) + + if agent_app_id: + try: + agent_app_id_uuid = uuid.UUID(agent_app_id) if isinstance(agent_app_id, str) else agent_app_id + # 获取应用的当前发布版本 + app = db.get(App, agent_app_id_uuid) + if app and app.current_release_id: + release = db.get(AppRelease, app.current_release_id) + if release and release.config: + config_data = release.config + # 从 release.config 获取 system_prompt + release_system_prompt = config_data.get("system_prompt") + if release_system_prompt: + system_prompt = release_system_prompt + logger.debug(f"从 AppRelease 获取 Agent {agent_name} 的系统提示词") + except Exception as e: + logger.warning(f"获取 Agent {agent_name} 的系统提示词失败: {str(e)}") + + # 如果有 capabilities,添加到系统提示词 + if capabilities and not system_prompt.endswith("。"): + system_prompt += f" 你的专长是: {', '.join(capabilities)}。" + elif capabilities: + system_prompt += f" 你的专长是: {', '.join(capabilities)}。" + + agent_configs[safe_name] = { + "agent_id": agent_app_id, + "name": agent_name, + "description": f"转移到 {agent_name}。{sub_agent.get('role') or ''}", + "system_prompt": system_prompt, + "capabilities": capabilities, + "can_transfer_to": [] # 稍后填充 + } + + # 设置每个 Agent 可以转移到的其他 Agent + for safe_name in agent_names: + agent_configs[safe_name]["can_transfer_to"] = [ + name for name in agent_names if name != safe_name + ] + # 更新系统提示词,添加转移说明 + other_agents = agent_configs[safe_name]["can_transfer_to"] + if other_agents: + transfer_instructions = "\n如果用户的问题不在你的专长范围内,可以使用以下工具转移到其他 Agent:" + for other_name in other_agents: + other_config = agent_configs[other_name] + transfer_instructions += f"\n- transfer_to_{other_name}: {other_config['description']}" + agent_configs[safe_name]["system_prompt"] += transfer_instructions + + # 获取 LLM 配置 + model_config = None + default_model_config_id = multi_agent_config.get("default_model_config_id") + if default_model_config_id: + model_api_key = ModelApiKeyService.get_a_api_key(db, default_model_config_id) + if model_api_key: + # 获取模型参数 + model_parameters = multi_agent_config.get("model_parameters") + temperature = 0.7 + max_tokens = 2000 + + if model_parameters: + if hasattr(model_parameters, 'temperature'): + temperature = model_parameters.temperature + max_tokens = model_parameters.max_tokens or 2000 + elif isinstance(model_parameters, dict): + temperature = model_parameters.get("temperature", 0.7) + max_tokens = model_parameters.get("max_tokens", 2000) + + model_config = RedBearModelConfig( + model_name=model_api_key.model_name, + provider=model_api_key.provider, + api_key=model_api_key.api_key, + base_url=model_api_key.api_base, + extra_params={ + "temperature": temperature, + "max_tokens": max_tokens, + "streaming": True + } + ) + + return agent_configs, model_config # ==================== Handoffs 服务类 ==================== @@ -289,19 +368,19 @@ class HandoffsService: def __init__( self, - agent_configs: Dict[str, Dict] = None, - llm_config: Dict[str, Any] = None, + agent_configs: Dict[str, Dict], + model_config: RedBearModelConfig, streaming: bool = True ): """初始化 Handoffs 服务 Args: agent_configs: Agent 配置字典 - llm_config: LLM 配置 + model_config: RedBearModelConfig 模型配置 streaming: 是否启用流式输出 """ - self.agent_configs = agent_configs or DEFAULT_AGENT_CONFIGS - self.llm_config = llm_config or DEFAULT_LLM_CONFIG + self.agent_configs = agent_configs + self.model_config = model_config self.streaming = streaming self._graph = None @@ -312,6 +391,9 @@ class HandoffsService: builder = StateGraph(HandoffState) agent_names = list(self.agent_configs.keys()) + if not agent_names: + raise ValueError("至少需要一个 Agent 配置") + for agent_name in agent_names: config = self.agent_configs[agent_name] tools = create_tools_for_agent(agent_name, self.agent_configs) @@ -321,25 +403,19 @@ class HandoffsService: agent_name=agent_name, system_prompt=config.get("system_prompt", f"你是 {agent_name}"), tools=tools, - api_key=self.llm_config.get("api_key"), - base_url=self.llm_config.get("base_url"), - model=self.llm_config.get("model"), - temperature=self.llm_config.get("temperature", 0.7) + model_config=self.model_config ) else: agent_node = create_agent_node( agent_name=agent_name, system_prompt=config.get("system_prompt", f"你是 {agent_name}"), tools=tools, - api_key=self.llm_config.get("api_key"), - base_url=self.llm_config.get("base_url"), - model=self.llm_config.get("model"), - temperature=self.llm_config.get("temperature", 0.7) + model_config=self.model_config ) builder.add_node(agent_name, agent_node) # 添加边 - default_agent = agent_names[0] if agent_names else "sales_agent" + default_agent = agent_names[0] builder.add_conditional_edges(START, create_route_initial(default_agent), agent_names) for agent_name in agent_names: @@ -365,15 +441,7 @@ class HandoffsService: message: str, conversation_id: str = None ) -> Dict[str, Any]: - """非流式聊天 - - Args: - message: 用户消息 - conversation_id: 会话 ID - - Returns: - 聊天结果 - """ + """非流式聊天""" conversation_id = conversation_id or f"conv-{uuid.uuid4().hex[:8]}" config = {"configurable": {"thread_id": str(conversation_id)}} @@ -402,15 +470,7 @@ class HandoffsService: message: str, conversation_id: str = None ) -> AsyncGenerator[str, None]: - """流式聊天 - - Args: - message: 用户消息 - conversation_id: 会话 ID - - Yields: - SSE 格式的事件 - """ + """流式聊天""" conversation_id = conversation_id or f"conv-{uuid.uuid4().hex[:8]}" config = {"configurable": {"thread_id": str(conversation_id)}} @@ -435,7 +495,8 @@ class HandoffsService: if node_name in self.agent_configs: if current_agent != node_name: current_agent = node_name - yield f"event: agent\ndata: {json.dumps({'agent': node_name}, ensure_ascii=False)}\n\n" + agent_display_name = self.agent_configs[node_name].get("name", node_name) + yield f"event: agent\ndata: {json.dumps({'agent': node_name, 'agent_name': agent_display_name}, ensure_ascii=False)}\n\n" # 捕获 LLM 流式输出 elif kind == "on_chat_model_stream": @@ -448,7 +509,8 @@ class HandoffsService: tool_name = event.get("name", "") if tool_name.startswith("transfer_to_"): target_agent = tool_name.replace("transfer_to_", "") - yield f"event: handoff\ndata: {json.dumps({'from': current_agent, 'to': target_agent}, ensure_ascii=False)}\n\n" + target_name = self.agent_configs.get(target_agent, {}).get("name", target_agent) + yield f"event: handoff\ndata: {json.dumps({'from': current_agent, 'to': target_agent, 'to_name': target_name}, ensure_ascii=False)}\n\n" # 发送结束事件 yield f"event: end\ndata: {json.dumps({'conversation_id': str(conversation_id), 'final_agent': current_agent}, ensure_ascii=False)}\n\n" @@ -458,57 +520,86 @@ class HandoffsService: yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n" def get_agents(self) -> List[Dict[str, Any]]: - """获取可用的 Agent 列表 - - Returns: - Agent 列表 - """ + """获取可用的 Agent 列表""" agents = [] for name, config in self.agent_configs.items(): agents.append({ - "name": name, + "id": name, + "name": config.get("name", name), "description": config.get("description", ""), + "capabilities": config.get("capabilities", []), "can_transfer_to": config.get("can_transfer_to", []) }) return agents -# ==================== 全局实例 ==================== +# ==================== 服务工厂 ==================== -_default_service: Optional[HandoffsService] = None +# 缓存服务实例(按 app_id) +_service_cache: Dict[str, HandoffsService] = {} -def get_handoffs_service( - agent_configs: Dict[str, Dict] = None, - llm_config: Dict[str, Any] = None, +def get_handoffs_service_for_app( + app_id: uuid.UUID, + db: Session, streaming: bool = True ) -> HandoffsService: - """获取 Handoffs 服务实例 + """根据 app_id 获取 Handoffs 服务实例 Args: - agent_configs: Agent 配置(可选) - llm_config: LLM 配置(可选) + app_id: 应用 ID + db: 数据库会话 streaming: 是否流式 Returns: HandoffsService 实例 """ - global _default_service + from app.services.multi_agent_service import MultiAgentService - # 如果有自定义配置,创建新实例 - if agent_configs or llm_config: - return HandoffsService(agent_configs, llm_config, streaming) + cache_key = f"{app_id}_{streaming}" - # 否则使用默认实例 - if _default_service is None: - _default_service = HandoffsService(streaming=streaming) + # 检查缓存 + if cache_key in _service_cache: + return _service_cache[cache_key] - return _default_service + # 获取多 Agent 配置 + multi_agent_service = MultiAgentService(db) + multi_agent_config = multi_agent_service.get_multi_agent_configs(app_id) + + if not multi_agent_config: + raise ValueError(f"应用 {app_id} 没有多 Agent 配置") + + # 转换配置 + agent_configs, model_config = convert_multi_agent_config_to_handoffs(multi_agent_config, db) + + if not agent_configs: + raise ValueError(f"应用 {app_id} 没有配置子 Agent") + + if not model_config: + raise ValueError(f"应用 {app_id} 没有配置模型") + + # 创建服务 + service = HandoffsService(agent_configs, model_config, streaming) + + # 缓存 + _service_cache[cache_key] = service + + return service -def reset_default_service(): - """重置默认服务实例""" - global _default_service - if _default_service: - _default_service.reset() - _default_service = None +def reset_handoffs_service_cache(app_id: uuid.UUID = None): + """重置服务缓存 + + Args: + app_id: 应用 ID,如果为 None 则清除所有缓存 + """ + global _service_cache + + if app_id: + keys_to_remove = [k for k in _service_cache if k.startswith(str(app_id))] + for key in keys_to_remove: + del _service_cache[key] + else: + _service_cache = {} + + logger.info(f"Handoffs 服务缓存已重置: app_id={app_id}")