[modify] handoffs test

This commit is contained in:
Mark
2026-01-07 15:51:12 +08:00
parent ba2220d7c8
commit cd76ccadc5
2 changed files with 270 additions and 165 deletions

View File

@@ -13,7 +13,7 @@ from app.core.response_utils import success
from app.schemas.response_schema import ApiResponse from app.schemas.response_schema import ApiResponse
from app.schemas.app_schema import AppChatRequest from app.schemas.app_schema import AppChatRequest
from app.services.model_service import ModelConfigService from app.services.model_service import ModelConfigService
from app.services.handoffs_service import get_handoffs_service, reset_default_service from app.services.handoffs_service import get_handoffs_service_for_app, reset_handoffs_service_cache
from app.services.conversation_service import ConversationService from app.services.conversation_service import ConversationService
from app.core.logging_config import get_api_logger from app.core.logging_config import get_api_logger
from app.dependencies import get_current_user from app.dependencies import get_current_user
@@ -139,7 +139,7 @@ async def test_handoffs(
演示 LangGraph 实现的多 Agent 协作和动态切换 演示 LangGraph 实现的多 Agent 协作和动态切换
- 默认从 sales_agent 开始 - 从数据库 multi_agent_config 获取 Agent 配置
- 根据用户问题自动切换到合适的 Agent - 根据用户问题自动切换到合适的 Agent
- 使用 conversation_id 保持会话状态 - 使用 conversation_id 保持会话状态
- 通过 stream 参数控制是否流式输出 - 通过 stream 参数控制是否流式输出
@@ -177,7 +177,7 @@ async def test_handoffs(
# 根据 stream 参数决定返回方式 # 根据 stream 参数决定返回方式
if request.stream: if request.stream:
# 流式返回 # 流式返回
service = get_handoffs_service(streaming=True) service = get_handoffs_service_for_app(app_id, db, streaming=True)
return StreamingResponse( return StreamingResponse(
service.chat_stream( service.chat_stream(
message=request.message, message=request.message,
@@ -192,13 +192,15 @@ async def test_handoffs(
) )
else: else:
# 非流式返回 # 非流式返回
service = get_handoffs_service(streaming=False) service = get_handoffs_service_for_app(app_id, db, streaming=False)
result = await service.chat( result = await service.chat(
message=request.message, message=request.message,
conversation_id=conversation_id conversation_id=conversation_id
) )
return success(data=result, msg="Handoffs 测试成功") return success(data=result, msg="Handoffs 测试成功")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
@@ -206,17 +208,29 @@ async def test_handoffs(
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.get("/handoffs/agents", response_model=ApiResponse) @router.get("/handoffs/{app_id}/agents", response_model=ApiResponse)
def get_handoff_agents(): def get_handoff_agents(
"""获取可用的 Handoff Agent 列表""" app_id: uuid.UUID = Path(..., description="应用 ID"),
service = get_handoffs_service() db: Session = Depends(get_db),
agents = service.get_agents() current_user=Depends(get_current_user)
):
return success(data={"agents": agents}, msg="获取 Agent 列表成功") """获取应用的 Handoff Agent 列表"""
try:
service = get_handoffs_service_for_app(app_id, db, streaming=False)
agents = service.get_agents()
return success(data={"agents": agents}, msg="获取 Agent 列表成功")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
api_logger.error(f"获取 Agent 列表失败: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@router.delete("/handoffs/reset") @router.delete("/handoffs/{app_id}/reset")
def reset_handoff_service(): def reset_handoff_service(
"""重置 Handoff 服务(清除所有会话状态)""" app_id: uuid.UUID = Path(..., description="应用 ID"),
reset_default_service() current_user=Depends(get_current_user)
):
"""重置指定应用的 Handoff 服务缓存"""
reset_handoffs_service_cache(app_id)
return success(msg="Handoff 服务已重置") return success(msg="Handoff 服务已重置")

View File

@@ -5,14 +5,17 @@ from typing import List, Dict, Any, Optional, AsyncGenerator
from typing_extensions import TypedDict from typing_extensions import TypedDict
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, START, END from langgraph.graph import StateGraph, START, END
from langgraph.types import Command from langgraph.types import Command
from langgraph.checkpoint.memory import MemorySaver from langgraph.checkpoint.memory import MemorySaver
from langchain_core.tools import tool from langchain_core.tools import tool
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.models_model import ModelType
from app.services.model_service import ModelApiKeyService
logger = get_business_logger() logger = get_business_logger()
@@ -32,31 +35,6 @@ class TransferInput(BaseModel):
reason: str = Field(description="转移原因") reason: str = Field(description="转移原因")
# ==================== 默认配置 ====================
DEFAULT_AGENT_CONFIGS = {
"sales_agent": {
"description": "转移到销售 Agent。当用户询问价格、购买或销售相关问题时使用。",
"system_prompt": """你是一个销售 Agent。帮助用户解答销售相关问题。
如果用户询问技术问题或需要技术支持,使用 transfer_to_support_agent 工具转移到支持 Agent。""",
"can_transfer_to": ["support_agent"]
},
"support_agent": {
"description": "转移到支持 Agent。当用户询问技术问题或需要帮助时使用。",
"system_prompt": """你是一个技术支持 Agent。帮助用户解决技术问题。
如果用户询问价格或购买相关问题,使用 transfer_to_sales_agent 工具转移到销售 Agent。""",
"can_transfer_to": ["sales_agent"]
}
}
DEFAULT_LLM_CONFIG = {
"api_key": "sk-8e9e40cd171749858ce2d3722ea75669",
"base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
"model": "qwen-plus",
"temperature": 0.7
}
# ==================== 工具创建 ==================== # ==================== 工具创建 ====================
def create_transfer_tool(target_agent: str, description: str): def create_transfer_tool(target_agent: str, description: str):
@@ -109,27 +87,13 @@ def create_tools_for_agent(agent_name: str, configs: Dict) -> List:
# ==================== Agent 节点创建 ==================== # ==================== Agent 节点创建 ====================
def create_agent_node(agent_name: str, system_prompt: str, tools: List, def create_agent_node(agent_name: str, system_prompt: str, tools: List,
api_key: str, base_url: str, model: str, temperature: float = 0.7): model_config: RedBearModelConfig):
"""创建 Agent 节点(非流式) """创建 Agent 节点(非流式)"""
llm = RedBearLLM(model_config, type=ModelType.CHAT)
Args: # 绑定工具
agent_name: Agent 名称 if tools:
system_prompt: 系统提示词 llm = llm.bind_tools(tools)
tools: 工具列表
api_key: API Key
base_url: API Base URL
model: 模型名称
temperature: 温度参数
Returns:
Agent 节点函数
"""
llm = ChatOpenAI(
model=model,
temperature=temperature,
api_key=api_key,
base_url=base_url
).bind_tools(tools)
async def agent_node(state: HandoffState) -> Dict[str, Any]: async def agent_node(state: HandoffState) -> Dict[str, Any]:
"""Agent 节点执行函数""" """Agent 节点执行函数"""
@@ -143,8 +107,14 @@ def create_agent_node(agent_name: str, system_prompt: str, tools: List,
# 检查工具调用 # 检查工具调用
if hasattr(response, 'tool_calls') and response.tool_calls: if hasattr(response, 'tool_calls') and response.tool_calls:
tool_call = response.tool_calls[0] tool_call = response.tool_calls[0]
tool_name = tool_call["name"] tool_name = tool_call["name"] if isinstance(tool_call, dict) else tool_call.name
tool_args = tool_call["args"] tool_args = tool_call["args"] if isinstance(tool_call, dict) else tool_call.args
if isinstance(tool_args, str):
try:
tool_args = json.loads(tool_args)
except (json.JSONDecodeError, ValueError):
tool_args = {}
if not tool_args.get("reason"): if not tool_args.get("reason"):
tool_args["reason"] = "用户请求转移" tool_args["reason"] = "用户请求转移"
@@ -162,28 +132,13 @@ def create_agent_node(agent_name: str, system_prompt: str, tools: List,
def create_streaming_agent_node(agent_name: str, system_prompt: str, tools: List, def create_streaming_agent_node(agent_name: str, system_prompt: str, tools: List,
api_key: str, base_url: str, model: str, temperature: float = 0.7): model_config: RedBearModelConfig):
"""创建支持流式输出的 Agent 节点 """创建支持流式输出的 Agent 节点"""
llm = RedBearLLM(model_config, type=ModelType.CHAT)
Args: # 绑定工具
agent_name: Agent 名称 if tools:
system_prompt: 系统提示词 llm = llm.bind_tools(tools)
tools: 工具列表
api_key: API Key
base_url: API Base URL
model: 模型名称
temperature: 温度参数
Returns:
Agent 节点函数
"""
llm = ChatOpenAI(
model=model,
temperature=temperature,
api_key=api_key,
base_url=base_url,
streaming=True
).bind_tools(tools)
async def agent_node(state: HandoffState): async def agent_node(state: HandoffState):
"""Agent 节点执行函数(流式)""" """Agent 节点执行函数(流式)"""
@@ -196,37 +151,48 @@ def create_streaming_agent_node(agent_name: str, system_prompt: str, tools: List
collected_tool_calls = {} collected_tool_calls = {}
async for chunk in llm.astream(full_messages): async for chunk in llm.astream(full_messages):
if chunk.content: if hasattr(chunk, 'content') and chunk.content:
full_content += chunk.content full_content += chunk.content
# 收集工具调用 # 收集工具调用
if hasattr(chunk, 'tool_calls') and chunk.tool_calls: if hasattr(chunk, 'tool_calls') and chunk.tool_calls:
for tc in chunk.tool_calls: for tc in chunk.tool_calls:
tc_id = tc.get("id") or "0" tc_id = tc.get("id") if isinstance(tc, dict) else getattr(tc, 'id', "0")
tc_id = tc_id or "0"
if tc_id not in collected_tool_calls: if tc_id not in collected_tool_calls:
collected_tool_calls[tc_id] = {"id": tc_id, "name": "", "args": ""} collected_tool_calls[tc_id] = {"id": tc_id, "name": "", "args": ""}
if tc.get("name"):
collected_tool_calls[tc_id]["name"] = tc["name"] tc_name = tc.get("name") if isinstance(tc, dict) else getattr(tc, 'name', None)
if tc.get("args"): tc_args = tc.get("args") if isinstance(tc, dict) else getattr(tc, 'args', None)
if isinstance(tc["args"], dict):
collected_tool_calls[tc_id]["args"] = tc["args"] if tc_name:
elif isinstance(tc["args"], str): collected_tool_calls[tc_id]["name"] = tc_name
if tc_args:
if isinstance(tc_args, dict):
collected_tool_calls[tc_id]["args"] = tc_args
elif isinstance(tc_args, str):
if isinstance(collected_tool_calls[tc_id]["args"], str): if isinstance(collected_tool_calls[tc_id]["args"], str):
collected_tool_calls[tc_id]["args"] += tc["args"] collected_tool_calls[tc_id]["args"] += tc_args
# 处理 tool_call_chunks # 处理 tool_call_chunks
if hasattr(chunk, 'tool_call_chunks') and chunk.tool_call_chunks: if hasattr(chunk, 'tool_call_chunks') and chunk.tool_call_chunks:
for tc_chunk in chunk.tool_call_chunks: for tc_chunk in chunk.tool_call_chunks:
idx = str(tc_chunk.get("index", 0)) idx = str(tc_chunk.get("index", 0) if isinstance(tc_chunk, dict) else getattr(tc_chunk, 'index', 0))
if idx not in collected_tool_calls: if idx not in collected_tool_calls:
collected_tool_calls[idx] = {"id": tc_chunk.get("id", idx), "name": "", "args": ""} tc_id = tc_chunk.get("id", idx) if isinstance(tc_chunk, dict) else getattr(tc_chunk, 'id', idx)
if tc_chunk.get("id"): collected_tool_calls[idx] = {"id": tc_id, "name": "", "args": ""}
collected_tool_calls[idx]["id"] = tc_chunk["id"]
if tc_chunk.get("name"): tc_id = tc_chunk.get("id") if isinstance(tc_chunk, dict) else getattr(tc_chunk, 'id', None)
collected_tool_calls[idx]["name"] = tc_chunk["name"] tc_name = tc_chunk.get("name") if isinstance(tc_chunk, dict) else getattr(tc_chunk, 'name', None)
if tc_chunk.get("args"): tc_args = tc_chunk.get("args") if isinstance(tc_chunk, dict) else getattr(tc_chunk, 'args', None)
if tc_id:
collected_tool_calls[idx]["id"] = tc_id
if tc_name:
collected_tool_calls[idx]["name"] = tc_name
if tc_args:
if isinstance(collected_tool_calls[idx]["args"], str): if isinstance(collected_tool_calls[idx]["args"], str):
collected_tool_calls[idx]["args"] += tc_chunk["args"] collected_tool_calls[idx]["args"] += tc_args
# 解析工具调用 # 解析工具调用
tool_calls_list = list(collected_tool_calls.values()) tool_calls_list = list(collected_tool_calls.values())
@@ -262,7 +228,7 @@ def create_streaming_agent_node(agent_name: str, system_prompt: str, tools: List
# ==================== 路由函数 ==================== # ==================== 路由函数 ====================
def create_route_initial(default_agent: str = "sales_agent"): def create_route_initial(default_agent: str):
"""创建初始路由函数""" """创建初始路由函数"""
def route_initial(state: HandoffState) -> str: def route_initial(state: HandoffState) -> str:
active = state.get("active_agent") active = state.get("active_agent")
@@ -279,7 +245,120 @@ def route_after_agent(state: HandoffState) -> str:
last_msg = messages[-1] last_msg = messages[-1]
if isinstance(last_msg, AIMessage) and not getattr(last_msg, 'tool_calls', None): if isinstance(last_msg, AIMessage) and not getattr(last_msg, 'tool_calls', None):
return END return END
return state.get("active_agent", "sales_agent") return state.get("active_agent", END)
# ==================== 配置转换 ====================
def convert_multi_agent_config_to_handoffs(
multi_agent_config: Dict,
db: Session
) -> tuple[Dict[str, Dict], RedBearModelConfig]:
"""将 multi_agent_config 转换为 handoffs 配置格式
Args:
multi_agent_config: 数据库中的多 Agent 配置
db: 数据库会话
Returns:
(agent_configs, model_config) 元组
"""
from app.models import AppRelease, App
sub_agents = multi_agent_config.get("sub_agents", [])
agent_configs = {}
agent_names = []
# 遍历子 Agent构建配置
for sub_agent in sub_agents:
agent_app_id = sub_agent.get("agent_id")
agent_name = sub_agent.get("name", f"agent_{agent_app_id[:8] if agent_app_id else 'unknown'}")
# 使用安全的 agent name去除特殊字符
safe_name = agent_name.replace(" ", "_").replace("-", "_").lower()
agent_names.append(safe_name)
# 从 AppRelease 获取 Agent 的系统提示词
system_prompt = f"你是 {agent_name}"
capabilities = sub_agent.get("capabilities", [])
if agent_app_id:
try:
agent_app_id_uuid = uuid.UUID(agent_app_id) if isinstance(agent_app_id, str) else agent_app_id
# 获取应用的当前发布版本
app = db.get(App, agent_app_id_uuid)
if app and app.current_release_id:
release = db.get(AppRelease, app.current_release_id)
if release and release.config:
config_data = release.config
# 从 release.config 获取 system_prompt
release_system_prompt = config_data.get("system_prompt")
if release_system_prompt:
system_prompt = release_system_prompt
logger.debug(f"从 AppRelease 获取 Agent {agent_name} 的系统提示词")
except Exception as e:
logger.warning(f"获取 Agent {agent_name} 的系统提示词失败: {str(e)}")
# 如果有 capabilities添加到系统提示词
if capabilities and not system_prompt.endswith(""):
system_prompt += f" 你的专长是: {', '.join(capabilities)}"
elif capabilities:
system_prompt += f" 你的专长是: {', '.join(capabilities)}"
agent_configs[safe_name] = {
"agent_id": agent_app_id,
"name": agent_name,
"description": f"转移到 {agent_name}{sub_agent.get('role') or ''}",
"system_prompt": system_prompt,
"capabilities": capabilities,
"can_transfer_to": [] # 稍后填充
}
# 设置每个 Agent 可以转移到的其他 Agent
for safe_name in agent_names:
agent_configs[safe_name]["can_transfer_to"] = [
name for name in agent_names if name != safe_name
]
# 更新系统提示词,添加转移说明
other_agents = agent_configs[safe_name]["can_transfer_to"]
if other_agents:
transfer_instructions = "\n如果用户的问题不在你的专长范围内,可以使用以下工具转移到其他 Agent"
for other_name in other_agents:
other_config = agent_configs[other_name]
transfer_instructions += f"\n- transfer_to_{other_name}: {other_config['description']}"
agent_configs[safe_name]["system_prompt"] += transfer_instructions
# 获取 LLM 配置
model_config = None
default_model_config_id = multi_agent_config.get("default_model_config_id")
if default_model_config_id:
model_api_key = ModelApiKeyService.get_a_api_key(db, default_model_config_id)
if model_api_key:
# 获取模型参数
model_parameters = multi_agent_config.get("model_parameters")
temperature = 0.7
max_tokens = 2000
if model_parameters:
if hasattr(model_parameters, 'temperature'):
temperature = model_parameters.temperature
max_tokens = model_parameters.max_tokens or 2000
elif isinstance(model_parameters, dict):
temperature = model_parameters.get("temperature", 0.7)
max_tokens = model_parameters.get("max_tokens", 2000)
model_config = RedBearModelConfig(
model_name=model_api_key.model_name,
provider=model_api_key.provider,
api_key=model_api_key.api_key,
base_url=model_api_key.api_base,
extra_params={
"temperature": temperature,
"max_tokens": max_tokens,
"streaming": True
}
)
return agent_configs, model_config
# ==================== Handoffs 服务类 ==================== # ==================== Handoffs 服务类 ====================
@@ -289,19 +368,19 @@ class HandoffsService:
def __init__( def __init__(
self, self,
agent_configs: Dict[str, Dict] = None, agent_configs: Dict[str, Dict],
llm_config: Dict[str, Any] = None, model_config: RedBearModelConfig,
streaming: bool = True streaming: bool = True
): ):
"""初始化 Handoffs 服务 """初始化 Handoffs 服务
Args: Args:
agent_configs: Agent 配置字典 agent_configs: Agent 配置字典
llm_config: LLM 配置 model_config: RedBearModelConfig 模型配置
streaming: 是否启用流式输出 streaming: 是否启用流式输出
""" """
self.agent_configs = agent_configs or DEFAULT_AGENT_CONFIGS self.agent_configs = agent_configs
self.llm_config = llm_config or DEFAULT_LLM_CONFIG self.model_config = model_config
self.streaming = streaming self.streaming = streaming
self._graph = None self._graph = None
@@ -312,6 +391,9 @@ class HandoffsService:
builder = StateGraph(HandoffState) builder = StateGraph(HandoffState)
agent_names = list(self.agent_configs.keys()) agent_names = list(self.agent_configs.keys())
if not agent_names:
raise ValueError("至少需要一个 Agent 配置")
for agent_name in agent_names: for agent_name in agent_names:
config = self.agent_configs[agent_name] config = self.agent_configs[agent_name]
tools = create_tools_for_agent(agent_name, self.agent_configs) tools = create_tools_for_agent(agent_name, self.agent_configs)
@@ -321,25 +403,19 @@ class HandoffsService:
agent_name=agent_name, agent_name=agent_name,
system_prompt=config.get("system_prompt", f"你是 {agent_name}"), system_prompt=config.get("system_prompt", f"你是 {agent_name}"),
tools=tools, tools=tools,
api_key=self.llm_config.get("api_key"), model_config=self.model_config
base_url=self.llm_config.get("base_url"),
model=self.llm_config.get("model"),
temperature=self.llm_config.get("temperature", 0.7)
) )
else: else:
agent_node = create_agent_node( agent_node = create_agent_node(
agent_name=agent_name, agent_name=agent_name,
system_prompt=config.get("system_prompt", f"你是 {agent_name}"), system_prompt=config.get("system_prompt", f"你是 {agent_name}"),
tools=tools, tools=tools,
api_key=self.llm_config.get("api_key"), model_config=self.model_config
base_url=self.llm_config.get("base_url"),
model=self.llm_config.get("model"),
temperature=self.llm_config.get("temperature", 0.7)
) )
builder.add_node(agent_name, agent_node) builder.add_node(agent_name, agent_node)
# 添加边 # 添加边
default_agent = agent_names[0] if agent_names else "sales_agent" default_agent = agent_names[0]
builder.add_conditional_edges(START, create_route_initial(default_agent), agent_names) builder.add_conditional_edges(START, create_route_initial(default_agent), agent_names)
for agent_name in agent_names: for agent_name in agent_names:
@@ -365,15 +441,7 @@ class HandoffsService:
message: str, message: str,
conversation_id: str = None conversation_id: str = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""非流式聊天 """非流式聊天"""
Args:
message: 用户消息
conversation_id: 会话 ID
Returns:
聊天结果
"""
conversation_id = conversation_id or f"conv-{uuid.uuid4().hex[:8]}" conversation_id = conversation_id or f"conv-{uuid.uuid4().hex[:8]}"
config = {"configurable": {"thread_id": str(conversation_id)}} config = {"configurable": {"thread_id": str(conversation_id)}}
@@ -402,15 +470,7 @@ class HandoffsService:
message: str, message: str,
conversation_id: str = None conversation_id: str = None
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
"""流式聊天 """流式聊天"""
Args:
message: 用户消息
conversation_id: 会话 ID
Yields:
SSE 格式的事件
"""
conversation_id = conversation_id or f"conv-{uuid.uuid4().hex[:8]}" conversation_id = conversation_id or f"conv-{uuid.uuid4().hex[:8]}"
config = {"configurable": {"thread_id": str(conversation_id)}} config = {"configurable": {"thread_id": str(conversation_id)}}
@@ -435,7 +495,8 @@ class HandoffsService:
if node_name in self.agent_configs: if node_name in self.agent_configs:
if current_agent != node_name: if current_agent != node_name:
current_agent = node_name current_agent = node_name
yield f"event: agent\ndata: {json.dumps({'agent': node_name}, ensure_ascii=False)}\n\n" agent_display_name = self.agent_configs[node_name].get("name", node_name)
yield f"event: agent\ndata: {json.dumps({'agent': node_name, 'agent_name': agent_display_name}, ensure_ascii=False)}\n\n"
# 捕获 LLM 流式输出 # 捕获 LLM 流式输出
elif kind == "on_chat_model_stream": elif kind == "on_chat_model_stream":
@@ -448,7 +509,8 @@ class HandoffsService:
tool_name = event.get("name", "") tool_name = event.get("name", "")
if tool_name.startswith("transfer_to_"): if tool_name.startswith("transfer_to_"):
target_agent = tool_name.replace("transfer_to_", "") target_agent = tool_name.replace("transfer_to_", "")
yield f"event: handoff\ndata: {json.dumps({'from': current_agent, 'to': target_agent}, ensure_ascii=False)}\n\n" target_name = self.agent_configs.get(target_agent, {}).get("name", target_agent)
yield f"event: handoff\ndata: {json.dumps({'from': current_agent, 'to': target_agent, 'to_name': target_name}, ensure_ascii=False)}\n\n"
# 发送结束事件 # 发送结束事件
yield f"event: end\ndata: {json.dumps({'conversation_id': str(conversation_id), 'final_agent': current_agent}, ensure_ascii=False)}\n\n" yield f"event: end\ndata: {json.dumps({'conversation_id': str(conversation_id), 'final_agent': current_agent}, ensure_ascii=False)}\n\n"
@@ -458,57 +520,86 @@ class HandoffsService:
yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n" yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
def get_agents(self) -> List[Dict[str, Any]]: def get_agents(self) -> List[Dict[str, Any]]:
"""获取可用的 Agent 列表 """获取可用的 Agent 列表"""
Returns:
Agent 列表
"""
agents = [] agents = []
for name, config in self.agent_configs.items(): for name, config in self.agent_configs.items():
agents.append({ agents.append({
"name": name, "id": name,
"name": config.get("name", name),
"description": config.get("description", ""), "description": config.get("description", ""),
"capabilities": config.get("capabilities", []),
"can_transfer_to": config.get("can_transfer_to", []) "can_transfer_to": config.get("can_transfer_to", [])
}) })
return agents return agents
# ==================== 全局实例 ==================== # ==================== 服务工厂 ====================
_default_service: Optional[HandoffsService] = None # 缓存服务实例(按 app_id
_service_cache: Dict[str, HandoffsService] = {}
def get_handoffs_service( def get_handoffs_service_for_app(
agent_configs: Dict[str, Dict] = None, app_id: uuid.UUID,
llm_config: Dict[str, Any] = None, db: Session,
streaming: bool = True streaming: bool = True
) -> HandoffsService: ) -> HandoffsService:
"""获取 Handoffs 服务实例 """根据 app_id 获取 Handoffs 服务实例
Args: Args:
agent_configs: Agent 配置(可选) app_id: 应用 ID
llm_config: LLM 配置(可选) db: 数据库会话
streaming: 是否流式 streaming: 是否流式
Returns: Returns:
HandoffsService 实例 HandoffsService 实例
""" """
global _default_service from app.services.multi_agent_service import MultiAgentService
# 如果有自定义配置,创建新实例 cache_key = f"{app_id}_{streaming}"
if agent_configs or llm_config:
return HandoffsService(agent_configs, llm_config, streaming)
# 否则使用默认实例 # 检查缓存
if _default_service is None: if cache_key in _service_cache:
_default_service = HandoffsService(streaming=streaming) return _service_cache[cache_key]
return _default_service # 获取多 Agent 配置
multi_agent_service = MultiAgentService(db)
multi_agent_config = multi_agent_service.get_multi_agent_configs(app_id)
if not multi_agent_config:
raise ValueError(f"应用 {app_id} 没有多 Agent 配置")
# 转换配置
agent_configs, model_config = convert_multi_agent_config_to_handoffs(multi_agent_config, db)
if not agent_configs:
raise ValueError(f"应用 {app_id} 没有配置子 Agent")
if not model_config:
raise ValueError(f"应用 {app_id} 没有配置模型")
# 创建服务
service = HandoffsService(agent_configs, model_config, streaming)
# 缓存
_service_cache[cache_key] = service
return service
def reset_default_service(): def reset_handoffs_service_cache(app_id: uuid.UUID = None):
"""重置默认服务实例""" """重置服务缓存
global _default_service
if _default_service: Args:
_default_service.reset() app_id: 应用 ID如果为 None 则清除所有缓存
_default_service = None """
global _service_cache
if app_id:
keys_to_remove = [k for k in _service_cache if k.startswith(str(app_id))]
for key in keys_to_remove:
del _service_cache[key]
else:
_service_cache = {}
logger.info(f"Handoffs 服务缓存已重置: app_id={app_id}")