[modify] manage multi agent logic
This commit is contained in:
@@ -5,63 +5,63 @@ import uuid
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.schemas import ModelParameters
|
||||
from app.services.conversation_state_manager import ConversationStateManager
|
||||
from app.models import ModelConfig, AgentConfig
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class MasterAgentRouter:
|
||||
"""Master Agent 路由器
|
||||
|
||||
|
||||
让 Master Agent 作为"大脑",负责:
|
||||
1. 分析用户意图
|
||||
2. 选择最合适的 Sub Agent
|
||||
3. 决定是否需要多 Agent 协作
|
||||
4. 管理会话上下文
|
||||
|
||||
|
||||
优势:
|
||||
- 更智能的决策(基于完整上下文)
|
||||
- 减少 LLM 调用次数
|
||||
- 架构更清晰(Master Agent 真正起作用)
|
||||
"""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db: Session,
|
||||
master_agent_config: AgentConfig,
|
||||
master_model_config: ModelConfig,
|
||||
model_parameters: ModelParameters,
|
||||
sub_agents: Dict[str, Any],
|
||||
state_manager: ConversationStateManager,
|
||||
enable_rule_fast_path: bool = True
|
||||
):
|
||||
"""初始化 Master Agent 路由器
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
master_agent_config: Master Agent 配置
|
||||
master_model_config: Master Agent 使用的模型配置
|
||||
sub_agents: 子 Agent 配置字典
|
||||
state_manager: 会话状态管理器
|
||||
enable_rule_fast_path: 是否启用规则快速路径(性能优化)
|
||||
"""
|
||||
self.db = db
|
||||
self.master_agent_config = master_agent_config
|
||||
self.master_model_config = master_model_config
|
||||
self.model_parameters = model_parameters
|
||||
self.sub_agents = sub_agents
|
||||
self.state_manager = state_manager
|
||||
self.enable_rule_fast_path = enable_rule_fast_path
|
||||
|
||||
|
||||
logger.info(
|
||||
"Master Agent 路由器初始化",
|
||||
extra={
|
||||
"master_agent": master_agent_config.name,
|
||||
"sub_agent_count": len(sub_agents),
|
||||
"enable_rule_fast_path": enable_rule_fast_path
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def route(
|
||||
self,
|
||||
message: str,
|
||||
@@ -69,12 +69,12 @@ class MasterAgentRouter:
|
||||
variables: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""智能路由决策
|
||||
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
conversation_id: 会话 ID
|
||||
variables: 变量参数
|
||||
|
||||
|
||||
Returns:
|
||||
路由决策结果
|
||||
"""
|
||||
@@ -85,12 +85,12 @@ class MasterAgentRouter:
|
||||
"conversation_id": conversation_id
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# 1. 获取会话状态
|
||||
state = None
|
||||
if conversation_id:
|
||||
state = self.state_manager.get_state(conversation_id)
|
||||
|
||||
|
||||
# 2. 尝试规则快速路径(可选的性能优化)
|
||||
if self.enable_rule_fast_path:
|
||||
rule_result = self._try_rule_fast_path(message, state)
|
||||
@@ -102,7 +102,7 @@ class MasterAgentRouter:
|
||||
"confidence": rule_result["confidence"]
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# 更新会话状态
|
||||
if conversation_id:
|
||||
self.state_manager.update_state(
|
||||
@@ -112,12 +112,12 @@ class MasterAgentRouter:
|
||||
rule_result.get("topic"),
|
||||
rule_result["confidence"]
|
||||
)
|
||||
|
||||
|
||||
return rule_result
|
||||
|
||||
|
||||
# 3. 调用 Master Agent 做决策
|
||||
decision = await self._master_agent_decide(message, state, variables)
|
||||
|
||||
|
||||
# 4. 更新会话状态
|
||||
if conversation_id:
|
||||
self.state_manager.update_state(
|
||||
@@ -127,7 +127,7 @@ class MasterAgentRouter:
|
||||
decision.get("topic"),
|
||||
decision["confidence"]
|
||||
)
|
||||
|
||||
|
||||
logger.info(
|
||||
"Master Agent 路由完成",
|
||||
extra={
|
||||
@@ -136,22 +136,22 @@ class MasterAgentRouter:
|
||||
"confidence": decision["confidence"]
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
return decision
|
||||
|
||||
|
||||
def _try_rule_fast_path(
|
||||
self,
|
||||
message: str,
|
||||
state: Optional[Dict[str, Any]]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""尝试规则快速路径(性能优化)
|
||||
|
||||
|
||||
对于明确的关键词匹配,直接返回结果,不调用 Master Agent
|
||||
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
state: 会话状态
|
||||
|
||||
|
||||
Returns:
|
||||
如果命中规则返回决策结果,否则返回 None
|
||||
"""
|
||||
@@ -178,15 +178,15 @@ class MasterAgentRouter:
|
||||
"confidence_threshold": 0.9
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
message_lower = message.lower()
|
||||
|
||||
|
||||
for rule in high_confidence_rules:
|
||||
matched_keywords = [kw for kw in rule["keywords"] if kw in message_lower]
|
||||
|
||||
|
||||
if matched_keywords:
|
||||
confidence = len(matched_keywords) / len(rule["keywords"])
|
||||
|
||||
|
||||
if confidence >= rule["confidence_threshold"]:
|
||||
# 查找对应的 agent
|
||||
for agent_id, agent_data in self.sub_agents.items():
|
||||
@@ -201,9 +201,9 @@ class MasterAgentRouter:
|
||||
"need_collaboration": False,
|
||||
"routing_method": "rule"
|
||||
}
|
||||
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def _master_agent_decide(
|
||||
self,
|
||||
message: str,
|
||||
@@ -211,35 +211,35 @@ class MasterAgentRouter:
|
||||
variables: Optional[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""让 Master Agent 做路由决策
|
||||
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
state: 会话状态
|
||||
variables: 变量参数
|
||||
|
||||
|
||||
Returns:
|
||||
决策结果
|
||||
"""
|
||||
# 1. 构建决策 prompt
|
||||
prompt = self._build_decision_prompt(message, state, variables)
|
||||
|
||||
|
||||
# 2. 调用 Master Agent 的 LLM
|
||||
try:
|
||||
response = await self._call_master_agent_llm(prompt)
|
||||
|
||||
|
||||
# 3. 解析决策
|
||||
decision = self._parse_decision(response)
|
||||
|
||||
|
||||
# 4. 验证决策
|
||||
decision = self._validate_decision(decision)
|
||||
|
||||
|
||||
return decision
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Master Agent 决策失败: {str(e)}")
|
||||
# 降级到默认 agent
|
||||
return self._get_fallback_decision(message)
|
||||
|
||||
|
||||
def _build_decision_prompt(
|
||||
self,
|
||||
message: str,
|
||||
@@ -247,12 +247,12 @@ class MasterAgentRouter:
|
||||
variables: Optional[Dict[str, Any]]
|
||||
) -> str:
|
||||
"""构建 Master Agent 的决策 prompt
|
||||
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
state: 会话状态
|
||||
variables: 变量参数
|
||||
|
||||
|
||||
Returns:
|
||||
prompt 字符串
|
||||
"""
|
||||
@@ -260,29 +260,29 @@ class MasterAgentRouter:
|
||||
agent_descriptions = []
|
||||
for agent_id, agent_data in self.sub_agents.items():
|
||||
agent_info = agent_data.get("info", {})
|
||||
|
||||
|
||||
name = agent_info.get("name", "未命名")
|
||||
role = agent_info.get("role", "")
|
||||
capabilities = agent_info.get("capabilities", [])
|
||||
|
||||
|
||||
# 简化格式:一行描述
|
||||
desc = f"- {agent_id}: {name}"
|
||||
if role:
|
||||
desc += f" ({role})"
|
||||
if capabilities:
|
||||
desc += f" - {', '.join(capabilities[:3])}" # 只取前3个能力
|
||||
|
||||
|
||||
agent_descriptions.append(desc)
|
||||
|
||||
|
||||
agents_text = "\n".join(agent_descriptions)
|
||||
|
||||
|
||||
# 2. 构建会话上下文
|
||||
context_text = ""
|
||||
if state:
|
||||
current_agent = state.get("current_agent_id")
|
||||
last_topic = state.get("last_topic")
|
||||
same_turns = state.get("same_agent_turns", 0)
|
||||
|
||||
|
||||
if current_agent:
|
||||
context_text = f"""
|
||||
当前会话上下文:
|
||||
@@ -290,10 +290,10 @@ class MasterAgentRouter:
|
||||
- 上一个主题: {last_topic}
|
||||
- 连续使用轮数: {same_turns}
|
||||
"""
|
||||
|
||||
|
||||
# 获取第一个可用的 agent_id 作为示例
|
||||
example_agent_id = next(iter(self.sub_agents.keys())) if self.sub_agents else "agent_id"
|
||||
|
||||
|
||||
# 3. 构建完整 prompt(简化版,提升性能)
|
||||
prompt = f"""路由任务:分析问题并选择合适的 Agent。
|
||||
|
||||
@@ -331,15 +331,15 @@ class MasterAgentRouter:
|
||||
6. 只做路由决策,不要回答问题内容
|
||||
|
||||
请返回 JSON:"""
|
||||
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
async def _call_master_agent_llm(self, prompt: str) -> str:
|
||||
"""调用 Master Agent 的 LLM
|
||||
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
|
||||
|
||||
Returns:
|
||||
LLM 响应
|
||||
"""
|
||||
@@ -347,16 +347,13 @@ class MasterAgentRouter:
|
||||
from app.core.models import RedBearLLM
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.models import ModelApiKey, ModelType
|
||||
|
||||
|
||||
# 获取 API Key 配置
|
||||
api_key_config = self.db.query(ModelApiKey).filter(
|
||||
ModelApiKey.model_config_id == self.master_model_config.id,
|
||||
ModelApiKey.is_active == True
|
||||
).first()
|
||||
|
||||
api_key_config = ModelApiKeyService.get_a_api_key(self.db, self.master_model_config.id)
|
||||
|
||||
if not api_key_config:
|
||||
raise Exception("Master Agent 模型没有可用的 API Key")
|
||||
|
||||
|
||||
logger.info(
|
||||
"调用 Master Agent LLM",
|
||||
extra={
|
||||
@@ -364,39 +361,44 @@ class MasterAgentRouter:
|
||||
"model_name": api_key_config.model_name
|
||||
}
|
||||
)
|
||||
|
||||
temperature = 0.3 # 决策任务使用较低温度
|
||||
max_tokens = 1000
|
||||
extra_params = None
|
||||
if self.model_parameters:
|
||||
extra_params = {"temperature": temperature,
|
||||
"max_tokens":max_tokens
|
||||
}
|
||||
# 创建 RedBearModelConfig
|
||||
model_config = RedBearModelConfig(
|
||||
model_name=api_key_config.model_name,
|
||||
provider=api_key_config.provider,
|
||||
api_key=api_key_config.api_key,
|
||||
base_url=api_key_config.api_base,
|
||||
temperature=0.3, # 决策任务使用较低温度
|
||||
max_tokens=1000
|
||||
extra_params = extra_params
|
||||
)
|
||||
|
||||
|
||||
# 创建 LLM 实例
|
||||
llm = RedBearLLM(model_config, type=ModelType.CHAT)
|
||||
|
||||
|
||||
# 调用模型
|
||||
response = await llm.ainvoke(prompt)
|
||||
|
||||
|
||||
# 提取响应内容
|
||||
if hasattr(response, 'content'):
|
||||
return response.content
|
||||
else:
|
||||
return str(response)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Master Agent LLM 调用失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def _parse_decision(self, response: str) -> Dict[str, Any]:
|
||||
"""解析 Master Agent 的决策
|
||||
|
||||
|
||||
Args:
|
||||
response: LLM 响应
|
||||
|
||||
|
||||
Returns:
|
||||
决策字典
|
||||
"""
|
||||
@@ -405,29 +407,29 @@ class MasterAgentRouter:
|
||||
json_match = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', response, re.DOTALL)
|
||||
if json_match:
|
||||
decision = json.loads(json_match.group())
|
||||
|
||||
|
||||
# 添加默认值
|
||||
decision.setdefault("confidence", 0.8)
|
||||
decision.setdefault("strategy", "master_agent")
|
||||
decision.setdefault("routing_method", "master_agent")
|
||||
decision.setdefault("need_collaboration", False)
|
||||
decision.setdefault("collaboration_agents", [])
|
||||
|
||||
|
||||
return decision
|
||||
else:
|
||||
raise ValueError("无法从响应中提取 JSON")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解析 Master Agent 决策失败: {str(e)}")
|
||||
logger.debug(f"原始响应: {response}")
|
||||
raise
|
||||
|
||||
|
||||
def _validate_decision(self, decision: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""验证决策的有效性
|
||||
|
||||
|
||||
Args:
|
||||
decision: 决策字典
|
||||
|
||||
|
||||
Returns:
|
||||
验证后的决策
|
||||
"""
|
||||
@@ -439,26 +441,26 @@ class MasterAgentRouter:
|
||||
decision["selected_agent_id"] = self._get_default_agent_id()
|
||||
decision["confidence"] = 0.5
|
||||
decision["reasoning"] = "原始选择无效,使用默认 Agent"
|
||||
|
||||
|
||||
# 验证 confidence
|
||||
confidence = decision.get("confidence", 0.8)
|
||||
if not isinstance(confidence, (int, float)) or confidence < 0 or confidence > 1:
|
||||
decision["confidence"] = 0.8
|
||||
|
||||
|
||||
# 验证协作 agents
|
||||
if decision.get("need_collaboration"):
|
||||
# 检查是否是问题拆分模式
|
||||
if decision.get("need_decomposition") or decision.get("sub_questions"):
|
||||
# 问题拆分模式
|
||||
sub_questions = decision.get("sub_questions", [])
|
||||
|
||||
|
||||
# 验证每个子问题
|
||||
valid_sub_questions = []
|
||||
for sub_q in sub_questions:
|
||||
if isinstance(sub_q, dict):
|
||||
agent_id = sub_q.get("agent_id")
|
||||
question = sub_q.get("question")
|
||||
|
||||
|
||||
if agent_id in self.sub_agents and question:
|
||||
# 确保有必要的字段
|
||||
sub_q.setdefault("order", len(valid_sub_questions) + 1)
|
||||
@@ -475,9 +477,9 @@ class MasterAgentRouter:
|
||||
"available_agents": list(self.sub_agents.keys())
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
decision["sub_questions"] = valid_sub_questions
|
||||
|
||||
|
||||
# 如果所有子问题都验证失败,降级处理
|
||||
if not valid_sub_questions and sub_questions:
|
||||
logger.warning(
|
||||
@@ -496,10 +498,10 @@ class MasterAgentRouter:
|
||||
first_agent_id = next(iter(self.sub_agents.keys()))
|
||||
decision["selected_agent_id"] = first_agent_id
|
||||
logger.info(f"降级使用默认 Agent: {first_agent_id}")
|
||||
|
||||
|
||||
# 设置协作策略为 decomposition
|
||||
decision["collaboration_strategy"] = "decomposition"
|
||||
|
||||
|
||||
logger.info(
|
||||
"问题拆分决策验证完成",
|
||||
extra={
|
||||
@@ -510,7 +512,7 @@ class MasterAgentRouter:
|
||||
else:
|
||||
# 普通协作模式
|
||||
collaboration_agents = decision.get("collaboration_agents", [])
|
||||
|
||||
|
||||
# 如果是简单列表格式,转换为详细格式
|
||||
if collaboration_agents and isinstance(collaboration_agents[0], str):
|
||||
collaboration_agents = [
|
||||
@@ -522,7 +524,7 @@ class MasterAgentRouter:
|
||||
}
|
||||
for i, agent_id in enumerate(collaboration_agents)
|
||||
]
|
||||
|
||||
|
||||
# 验证每个协作 agent
|
||||
valid_agents = []
|
||||
for agent_info in collaboration_agents:
|
||||
@@ -541,13 +543,13 @@ class MasterAgentRouter:
|
||||
"task": "协作处理",
|
||||
"order": len(valid_agents) + 1
|
||||
})
|
||||
|
||||
|
||||
decision["collaboration_agents"] = valid_agents
|
||||
|
||||
|
||||
# 设置默认协作策略
|
||||
if not decision.get("collaboration_strategy"):
|
||||
decision["collaboration_strategy"] = "sequential"
|
||||
|
||||
|
||||
logger.info(
|
||||
"协作决策验证完成",
|
||||
extra={
|
||||
@@ -555,20 +557,20 @@ class MasterAgentRouter:
|
||||
"strategy": decision.get("collaboration_strategy")
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
return decision
|
||||
|
||||
|
||||
def _get_fallback_decision(self, message: str) -> Dict[str, Any]:
|
||||
"""获取降级决策(当 Master Agent 失败时)
|
||||
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
|
||||
|
||||
Returns:
|
||||
降级决策
|
||||
"""
|
||||
default_agent_id = self._get_default_agent_id()
|
||||
|
||||
|
||||
return {
|
||||
"selected_agent_id": default_agent_id,
|
||||
"confidence": 0.5,
|
||||
@@ -579,15 +581,15 @@ class MasterAgentRouter:
|
||||
"collaboration_agents": [],
|
||||
"routing_method": "fallback"
|
||||
}
|
||||
|
||||
|
||||
def _get_default_agent_id(self) -> str:
|
||||
"""获取默认 Agent ID
|
||||
|
||||
|
||||
Returns:
|
||||
默认 Agent ID
|
||||
"""
|
||||
if self.sub_agents:
|
||||
# 返回第一个 agent
|
||||
return next(iter(self.sub_agents.keys()))
|
||||
|
||||
|
||||
return "default-agent"
|
||||
|
||||
Reference in New Issue
Block a user