[modify] handoffs test

This commit is contained in:
Mark
2026-01-07 15:51:12 +08:00
parent ba2220d7c8
commit cd76ccadc5
2 changed files with 270 additions and 165 deletions

View File

@@ -13,7 +13,7 @@ from app.core.response_utils import success
from app.schemas.response_schema import ApiResponse
from app.schemas.app_schema import AppChatRequest
from app.services.model_service import ModelConfigService
from app.services.handoffs_service import get_handoffs_service, reset_default_service
from app.services.handoffs_service import get_handoffs_service_for_app, reset_handoffs_service_cache
from app.services.conversation_service import ConversationService
from app.core.logging_config import get_api_logger
from app.dependencies import get_current_user
@@ -139,7 +139,7 @@ async def test_handoffs(
演示 LangGraph 实现的多 Agent 协作和动态切换
- 默认从 sales_agent 开始
- 从数据库 multi_agent_config 获取 Agent 配置
- 根据用户问题自动切换到合适的 Agent
- 使用 conversation_id 保持会话状态
- 通过 stream 参数控制是否流式输出
@@ -177,7 +177,7 @@ async def test_handoffs(
# 根据 stream 参数决定返回方式
if request.stream:
# 流式返回
service = get_handoffs_service(streaming=True)
service = get_handoffs_service_for_app(app_id, db, streaming=True)
return StreamingResponse(
service.chat_stream(
message=request.message,
@@ -192,13 +192,15 @@ async def test_handoffs(
)
else:
# 非流式返回
service = get_handoffs_service(streaming=False)
service = get_handoffs_service_for_app(app_id, db, streaming=False)
result = await service.chat(
message=request.message,
conversation_id=conversation_id
)
return success(data=result, msg="Handoffs 测试成功")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except HTTPException:
raise
except Exception as e:
@@ -206,17 +208,29 @@ async def test_handoffs(
raise HTTPException(status_code=500, detail=str(e))
@router.get("/handoffs/agents", response_model=ApiResponse)
def get_handoff_agents():
"""获取可用的 Handoff Agent 列表"""
service = get_handoffs_service()
agents = service.get_agents()
return success(data={"agents": agents}, msg="获取 Agent 列表成功")
@router.get("/handoffs/{app_id}/agents", response_model=ApiResponse)
def get_handoff_agents(
app_id: uuid.UUID = Path(..., description="应用 ID"),
db: Session = Depends(get_db),
current_user=Depends(get_current_user)
):
"""获取应用的 Handoff Agent 列表"""
try:
service = get_handoffs_service_for_app(app_id, db, streaming=False)
agents = service.get_agents()
return success(data={"agents": agents}, msg="获取 Agent 列表成功")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
api_logger.error(f"获取 Agent 列表失败: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@router.delete("/handoffs/reset")
def reset_handoff_service():
"""重置 Handoff 服务(清除所有会话状态)"""
reset_default_service()
@router.delete("/handoffs/{app_id}/reset")
def reset_handoff_service(
app_id: uuid.UUID = Path(..., description="应用 ID"),
current_user=Depends(get_current_user)
):
"""重置指定应用的 Handoff 服务缓存"""
reset_handoffs_service_cache(app_id)
return success(msg="Handoff 服务已重置")

View File

@@ -5,14 +5,17 @@ from typing import List, Dict, Any, Optional, AsyncGenerator
from typing_extensions import TypedDict
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, START, END
from langgraph.types import Command
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from app.core.logging_config import get_business_logger
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.models_model import ModelType
from app.services.model_service import ModelApiKeyService
logger = get_business_logger()
@@ -32,31 +35,6 @@ class TransferInput(BaseModel):
reason: str = Field(description="转移原因")
# ==================== 默认配置 ====================
DEFAULT_AGENT_CONFIGS = {
"sales_agent": {
"description": "转移到销售 Agent。当用户询问价格、购买或销售相关问题时使用。",
"system_prompt": """你是一个销售 Agent。帮助用户解答销售相关问题。
如果用户询问技术问题或需要技术支持,使用 transfer_to_support_agent 工具转移到支持 Agent。""",
"can_transfer_to": ["support_agent"]
},
"support_agent": {
"description": "转移到支持 Agent。当用户询问技术问题或需要帮助时使用。",
"system_prompt": """你是一个技术支持 Agent。帮助用户解决技术问题。
如果用户询问价格或购买相关问题,使用 transfer_to_sales_agent 工具转移到销售 Agent。""",
"can_transfer_to": ["sales_agent"]
}
}
DEFAULT_LLM_CONFIG = {
"api_key": "sk-8e9e40cd171749858ce2d3722ea75669",
"base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
"model": "qwen-plus",
"temperature": 0.7
}
# ==================== 工具创建 ====================
def create_transfer_tool(target_agent: str, description: str):
@@ -109,27 +87,13 @@ def create_tools_for_agent(agent_name: str, configs: Dict) -> List:
# ==================== Agent 节点创建 ====================
def create_agent_node(agent_name: str, system_prompt: str, tools: List,
api_key: str, base_url: str, model: str, temperature: float = 0.7):
"""创建 Agent 节点(非流式)
model_config: RedBearModelConfig):
"""创建 Agent 节点(非流式)"""
llm = RedBearLLM(model_config, type=ModelType.CHAT)
Args:
agent_name: Agent 名称
system_prompt: 系统提示词
tools: 工具列表
api_key: API Key
base_url: API Base URL
model: 模型名称
temperature: 温度参数
Returns:
Agent 节点函数
"""
llm = ChatOpenAI(
model=model,
temperature=temperature,
api_key=api_key,
base_url=base_url
).bind_tools(tools)
# 绑定工具
if tools:
llm = llm.bind_tools(tools)
async def agent_node(state: HandoffState) -> Dict[str, Any]:
"""Agent 节点执行函数"""
@@ -143,8 +107,14 @@ def create_agent_node(agent_name: str, system_prompt: str, tools: List,
# 检查工具调用
if hasattr(response, 'tool_calls') and response.tool_calls:
tool_call = response.tool_calls[0]
tool_name = tool_call["name"]
tool_args = tool_call["args"]
tool_name = tool_call["name"] if isinstance(tool_call, dict) else tool_call.name
tool_args = tool_call["args"] if isinstance(tool_call, dict) else tool_call.args
if isinstance(tool_args, str):
try:
tool_args = json.loads(tool_args)
except (json.JSONDecodeError, ValueError):
tool_args = {}
if not tool_args.get("reason"):
tool_args["reason"] = "用户请求转移"
@@ -162,28 +132,13 @@ def create_agent_node(agent_name: str, system_prompt: str, tools: List,
def create_streaming_agent_node(agent_name: str, system_prompt: str, tools: List,
api_key: str, base_url: str, model: str, temperature: float = 0.7):
"""创建支持流式输出的 Agent 节点
model_config: RedBearModelConfig):
"""创建支持流式输出的 Agent 节点"""
llm = RedBearLLM(model_config, type=ModelType.CHAT)
Args:
agent_name: Agent 名称
system_prompt: 系统提示词
tools: 工具列表
api_key: API Key
base_url: API Base URL
model: 模型名称
temperature: 温度参数
Returns:
Agent 节点函数
"""
llm = ChatOpenAI(
model=model,
temperature=temperature,
api_key=api_key,
base_url=base_url,
streaming=True
).bind_tools(tools)
# 绑定工具
if tools:
llm = llm.bind_tools(tools)
async def agent_node(state: HandoffState):
"""Agent 节点执行函数(流式)"""
@@ -196,37 +151,48 @@ def create_streaming_agent_node(agent_name: str, system_prompt: str, tools: List
collected_tool_calls = {}
async for chunk in llm.astream(full_messages):
if chunk.content:
if hasattr(chunk, 'content') and chunk.content:
full_content += chunk.content
# 收集工具调用
if hasattr(chunk, 'tool_calls') and chunk.tool_calls:
for tc in chunk.tool_calls:
tc_id = tc.get("id") or "0"
tc_id = tc.get("id") if isinstance(tc, dict) else getattr(tc, 'id', "0")
tc_id = tc_id or "0"
if tc_id not in collected_tool_calls:
collected_tool_calls[tc_id] = {"id": tc_id, "name": "", "args": ""}
if tc.get("name"):
collected_tool_calls[tc_id]["name"] = tc["name"]
if tc.get("args"):
if isinstance(tc["args"], dict):
collected_tool_calls[tc_id]["args"] = tc["args"]
elif isinstance(tc["args"], str):
tc_name = tc.get("name") if isinstance(tc, dict) else getattr(tc, 'name', None)
tc_args = tc.get("args") if isinstance(tc, dict) else getattr(tc, 'args', None)
if tc_name:
collected_tool_calls[tc_id]["name"] = tc_name
if tc_args:
if isinstance(tc_args, dict):
collected_tool_calls[tc_id]["args"] = tc_args
elif isinstance(tc_args, str):
if isinstance(collected_tool_calls[tc_id]["args"], str):
collected_tool_calls[tc_id]["args"] += tc["args"]
collected_tool_calls[tc_id]["args"] += tc_args
# 处理 tool_call_chunks
if hasattr(chunk, 'tool_call_chunks') and chunk.tool_call_chunks:
for tc_chunk in chunk.tool_call_chunks:
idx = str(tc_chunk.get("index", 0))
idx = str(tc_chunk.get("index", 0) if isinstance(tc_chunk, dict) else getattr(tc_chunk, 'index', 0))
if idx not in collected_tool_calls:
collected_tool_calls[idx] = {"id": tc_chunk.get("id", idx), "name": "", "args": ""}
if tc_chunk.get("id"):
collected_tool_calls[idx]["id"] = tc_chunk["id"]
if tc_chunk.get("name"):
collected_tool_calls[idx]["name"] = tc_chunk["name"]
if tc_chunk.get("args"):
tc_id = tc_chunk.get("id", idx) if isinstance(tc_chunk, dict) else getattr(tc_chunk, 'id', idx)
collected_tool_calls[idx] = {"id": tc_id, "name": "", "args": ""}
tc_id = tc_chunk.get("id") if isinstance(tc_chunk, dict) else getattr(tc_chunk, 'id', None)
tc_name = tc_chunk.get("name") if isinstance(tc_chunk, dict) else getattr(tc_chunk, 'name', None)
tc_args = tc_chunk.get("args") if isinstance(tc_chunk, dict) else getattr(tc_chunk, 'args', None)
if tc_id:
collected_tool_calls[idx]["id"] = tc_id
if tc_name:
collected_tool_calls[idx]["name"] = tc_name
if tc_args:
if isinstance(collected_tool_calls[idx]["args"], str):
collected_tool_calls[idx]["args"] += tc_chunk["args"]
collected_tool_calls[idx]["args"] += tc_args
# 解析工具调用
tool_calls_list = list(collected_tool_calls.values())
@@ -262,7 +228,7 @@ def create_streaming_agent_node(agent_name: str, system_prompt: str, tools: List
# ==================== 路由函数 ====================
def create_route_initial(default_agent: str = "sales_agent"):
def create_route_initial(default_agent: str):
"""创建初始路由函数"""
def route_initial(state: HandoffState) -> str:
active = state.get("active_agent")
@@ -279,7 +245,120 @@ def route_after_agent(state: HandoffState) -> str:
last_msg = messages[-1]
if isinstance(last_msg, AIMessage) and not getattr(last_msg, 'tool_calls', None):
return END
return state.get("active_agent", "sales_agent")
return state.get("active_agent", END)
# ==================== 配置转换 ====================
def convert_multi_agent_config_to_handoffs(
multi_agent_config: Dict,
db: Session
) -> tuple[Dict[str, Dict], RedBearModelConfig]:
"""将 multi_agent_config 转换为 handoffs 配置格式
Args:
multi_agent_config: 数据库中的多 Agent 配置
db: 数据库会话
Returns:
(agent_configs, model_config) 元组
"""
from app.models import AppRelease, App
sub_agents = multi_agent_config.get("sub_agents", [])
agent_configs = {}
agent_names = []
# 遍历子 Agent构建配置
for sub_agent in sub_agents:
agent_app_id = sub_agent.get("agent_id")
agent_name = sub_agent.get("name", f"agent_{agent_app_id[:8] if agent_app_id else 'unknown'}")
# 使用安全的 agent name去除特殊字符
safe_name = agent_name.replace(" ", "_").replace("-", "_").lower()
agent_names.append(safe_name)
# 从 AppRelease 获取 Agent 的系统提示词
system_prompt = f"你是 {agent_name}"
capabilities = sub_agent.get("capabilities", [])
if agent_app_id:
try:
agent_app_id_uuid = uuid.UUID(agent_app_id) if isinstance(agent_app_id, str) else agent_app_id
# 获取应用的当前发布版本
app = db.get(App, agent_app_id_uuid)
if app and app.current_release_id:
release = db.get(AppRelease, app.current_release_id)
if release and release.config:
config_data = release.config
# 从 release.config 获取 system_prompt
release_system_prompt = config_data.get("system_prompt")
if release_system_prompt:
system_prompt = release_system_prompt
logger.debug(f"从 AppRelease 获取 Agent {agent_name} 的系统提示词")
except Exception as e:
logger.warning(f"获取 Agent {agent_name} 的系统提示词失败: {str(e)}")
# 如果有 capabilities添加到系统提示词
if capabilities and not system_prompt.endswith(""):
system_prompt += f" 你的专长是: {', '.join(capabilities)}"
elif capabilities:
system_prompt += f" 你的专长是: {', '.join(capabilities)}"
agent_configs[safe_name] = {
"agent_id": agent_app_id,
"name": agent_name,
"description": f"转移到 {agent_name}{sub_agent.get('role') or ''}",
"system_prompt": system_prompt,
"capabilities": capabilities,
"can_transfer_to": [] # 稍后填充
}
# 设置每个 Agent 可以转移到的其他 Agent
for safe_name in agent_names:
agent_configs[safe_name]["can_transfer_to"] = [
name for name in agent_names if name != safe_name
]
# 更新系统提示词,添加转移说明
other_agents = agent_configs[safe_name]["can_transfer_to"]
if other_agents:
transfer_instructions = "\n如果用户的问题不在你的专长范围内,可以使用以下工具转移到其他 Agent"
for other_name in other_agents:
other_config = agent_configs[other_name]
transfer_instructions += f"\n- transfer_to_{other_name}: {other_config['description']}"
agent_configs[safe_name]["system_prompt"] += transfer_instructions
# 获取 LLM 配置
model_config = None
default_model_config_id = multi_agent_config.get("default_model_config_id")
if default_model_config_id:
model_api_key = ModelApiKeyService.get_a_api_key(db, default_model_config_id)
if model_api_key:
# 获取模型参数
model_parameters = multi_agent_config.get("model_parameters")
temperature = 0.7
max_tokens = 2000
if model_parameters:
if hasattr(model_parameters, 'temperature'):
temperature = model_parameters.temperature
max_tokens = model_parameters.max_tokens or 2000
elif isinstance(model_parameters, dict):
temperature = model_parameters.get("temperature", 0.7)
max_tokens = model_parameters.get("max_tokens", 2000)
model_config = RedBearModelConfig(
model_name=model_api_key.model_name,
provider=model_api_key.provider,
api_key=model_api_key.api_key,
base_url=model_api_key.api_base,
extra_params={
"temperature": temperature,
"max_tokens": max_tokens,
"streaming": True
}
)
return agent_configs, model_config
# ==================== Handoffs 服务类 ====================
@@ -289,19 +368,19 @@ class HandoffsService:
def __init__(
self,
agent_configs: Dict[str, Dict] = None,
llm_config: Dict[str, Any] = None,
agent_configs: Dict[str, Dict],
model_config: RedBearModelConfig,
streaming: bool = True
):
"""初始化 Handoffs 服务
Args:
agent_configs: Agent 配置字典
llm_config: LLM 配置
model_config: RedBearModelConfig 模型配置
streaming: 是否启用流式输出
"""
self.agent_configs = agent_configs or DEFAULT_AGENT_CONFIGS
self.llm_config = llm_config or DEFAULT_LLM_CONFIG
self.agent_configs = agent_configs
self.model_config = model_config
self.streaming = streaming
self._graph = None
@@ -312,6 +391,9 @@ class HandoffsService:
builder = StateGraph(HandoffState)
agent_names = list(self.agent_configs.keys())
if not agent_names:
raise ValueError("至少需要一个 Agent 配置")
for agent_name in agent_names:
config = self.agent_configs[agent_name]
tools = create_tools_for_agent(agent_name, self.agent_configs)
@@ -321,25 +403,19 @@ class HandoffsService:
agent_name=agent_name,
system_prompt=config.get("system_prompt", f"你是 {agent_name}"),
tools=tools,
api_key=self.llm_config.get("api_key"),
base_url=self.llm_config.get("base_url"),
model=self.llm_config.get("model"),
temperature=self.llm_config.get("temperature", 0.7)
model_config=self.model_config
)
else:
agent_node = create_agent_node(
agent_name=agent_name,
system_prompt=config.get("system_prompt", f"你是 {agent_name}"),
tools=tools,
api_key=self.llm_config.get("api_key"),
base_url=self.llm_config.get("base_url"),
model=self.llm_config.get("model"),
temperature=self.llm_config.get("temperature", 0.7)
model_config=self.model_config
)
builder.add_node(agent_name, agent_node)
# 添加边
default_agent = agent_names[0] if agent_names else "sales_agent"
default_agent = agent_names[0]
builder.add_conditional_edges(START, create_route_initial(default_agent), agent_names)
for agent_name in agent_names:
@@ -365,15 +441,7 @@ class HandoffsService:
message: str,
conversation_id: str = None
) -> Dict[str, Any]:
"""非流式聊天
Args:
message: 用户消息
conversation_id: 会话 ID
Returns:
聊天结果
"""
"""非流式聊天"""
conversation_id = conversation_id or f"conv-{uuid.uuid4().hex[:8]}"
config = {"configurable": {"thread_id": str(conversation_id)}}
@@ -402,15 +470,7 @@ class HandoffsService:
message: str,
conversation_id: str = None
) -> AsyncGenerator[str, None]:
"""流式聊天
Args:
message: 用户消息
conversation_id: 会话 ID
Yields:
SSE 格式的事件
"""
"""流式聊天"""
conversation_id = conversation_id or f"conv-{uuid.uuid4().hex[:8]}"
config = {"configurable": {"thread_id": str(conversation_id)}}
@@ -435,7 +495,8 @@ class HandoffsService:
if node_name in self.agent_configs:
if current_agent != node_name:
current_agent = node_name
yield f"event: agent\ndata: {json.dumps({'agent': node_name}, ensure_ascii=False)}\n\n"
agent_display_name = self.agent_configs[node_name].get("name", node_name)
yield f"event: agent\ndata: {json.dumps({'agent': node_name, 'agent_name': agent_display_name}, ensure_ascii=False)}\n\n"
# 捕获 LLM 流式输出
elif kind == "on_chat_model_stream":
@@ -448,7 +509,8 @@ class HandoffsService:
tool_name = event.get("name", "")
if tool_name.startswith("transfer_to_"):
target_agent = tool_name.replace("transfer_to_", "")
yield f"event: handoff\ndata: {json.dumps({'from': current_agent, 'to': target_agent}, ensure_ascii=False)}\n\n"
target_name = self.agent_configs.get(target_agent, {}).get("name", target_agent)
yield f"event: handoff\ndata: {json.dumps({'from': current_agent, 'to': target_agent, 'to_name': target_name}, ensure_ascii=False)}\n\n"
# 发送结束事件
yield f"event: end\ndata: {json.dumps({'conversation_id': str(conversation_id), 'final_agent': current_agent}, ensure_ascii=False)}\n\n"
@@ -458,57 +520,86 @@ class HandoffsService:
yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
def get_agents(self) -> List[Dict[str, Any]]:
"""获取可用的 Agent 列表
Returns:
Agent 列表
"""
"""获取可用的 Agent 列表"""
agents = []
for name, config in self.agent_configs.items():
agents.append({
"name": name,
"id": name,
"name": config.get("name", name),
"description": config.get("description", ""),
"capabilities": config.get("capabilities", []),
"can_transfer_to": config.get("can_transfer_to", [])
})
return agents
# ==================== 全局实例 ====================
# ==================== 服务工厂 ====================
_default_service: Optional[HandoffsService] = None
# 缓存服务实例(按 app_id
_service_cache: Dict[str, HandoffsService] = {}
def get_handoffs_service(
agent_configs: Dict[str, Dict] = None,
llm_config: Dict[str, Any] = None,
def get_handoffs_service_for_app(
app_id: uuid.UUID,
db: Session,
streaming: bool = True
) -> HandoffsService:
"""获取 Handoffs 服务实例
"""根据 app_id 获取 Handoffs 服务实例
Args:
agent_configs: Agent 配置(可选)
llm_config: LLM 配置(可选)
app_id: 应用 ID
db: 数据库会话
streaming: 是否流式
Returns:
HandoffsService 实例
"""
global _default_service
from app.services.multi_agent_service import MultiAgentService
# 如果有自定义配置,创建新实例
if agent_configs or llm_config:
return HandoffsService(agent_configs, llm_config, streaming)
cache_key = f"{app_id}_{streaming}"
# 否则使用默认实例
if _default_service is None:
_default_service = HandoffsService(streaming=streaming)
# 检查缓存
if cache_key in _service_cache:
return _service_cache[cache_key]
return _default_service
# 获取多 Agent 配置
multi_agent_service = MultiAgentService(db)
multi_agent_config = multi_agent_service.get_multi_agent_configs(app_id)
if not multi_agent_config:
raise ValueError(f"应用 {app_id} 没有多 Agent 配置")
# 转换配置
agent_configs, model_config = convert_multi_agent_config_to_handoffs(multi_agent_config, db)
if not agent_configs:
raise ValueError(f"应用 {app_id} 没有配置子 Agent")
if not model_config:
raise ValueError(f"应用 {app_id} 没有配置模型")
# 创建服务
service = HandoffsService(agent_configs, model_config, streaming)
# 缓存
_service_cache[cache_key] = service
return service
def reset_default_service():
"""重置默认服务实例"""
global _default_service
if _default_service:
_default_service.reset()
_default_service = None
def reset_handoffs_service_cache(app_id: uuid.UUID = None):
"""重置服务缓存
Args:
app_id: 应用 ID如果为 None 则清除所有缓存
"""
global _service_cache
if app_id:
keys_to_remove = [k for k in _service_cache if k.startswith(str(app_id))]
for key in keys_to_remove:
del _service_cache[key]
else:
_service_cache = {}
logger.info(f"Handoffs 服务缓存已重置: app_id={app_id}")