427 lines
14 KiB
Python
427 lines
14 KiB
Python
"""智能路由器 - 解决多轮对话路由错乱"""
|
|
import re
|
|
from typing import Dict, Any, List, Optional, Tuple
|
|
from app.services.conversation_state_manager import ConversationStateManager
|
|
from app.core.logging_config import get_business_logger
|
|
|
|
logger = get_business_logger()
|
|
|
|
|
|
class SmartRouter:
|
|
"""智能路由器
|
|
|
|
核心功能:
|
|
1. 检测主题切换
|
|
2. 判断是否应该继续使用当前 Agent
|
|
3. 智能选择最合适的 Agent
|
|
4. 支持强制重新路由
|
|
"""
|
|
|
|
# 主题切换信号
|
|
SWITCH_SIGNALS = [
|
|
"换个话题", "另外", "还有", "对了",
|
|
"那这个呢", "再问一个", "顺便问下",
|
|
"我想问", "帮我", "请问", "换一个"
|
|
]
|
|
|
|
# 延续信号
|
|
CONTINUATION_SIGNALS = [
|
|
"继续", "还是", "也", "同样", "类似",
|
|
"这个", "那个", "它", "他", "她", "呢"
|
|
]
|
|
|
|
def __init__(
|
|
self,
|
|
state_manager: ConversationStateManager,
|
|
routing_rules: List[Dict[str, Any]],
|
|
sub_agents: Dict[str, Any]
|
|
):
|
|
"""初始化智能路由器
|
|
|
|
Args:
|
|
state_manager: 会话状态管理器
|
|
routing_rules: 路由规则列表
|
|
sub_agents: 子 Agent 配置字典
|
|
"""
|
|
self.state_manager = state_manager
|
|
self.routing_rules = routing_rules
|
|
self.sub_agents = sub_agents
|
|
|
|
# 配置参数
|
|
self.min_confidence_for_switch = 0.7 # 切换 Agent 的最小置信度
|
|
self.max_same_agent_turns = 10 # 同一 Agent 最大连续轮数
|
|
|
|
async def route(
|
|
self,
|
|
message: str,
|
|
conversation_id: Optional[str] = None,
|
|
force_new: bool = False
|
|
) -> Dict[str, Any]:
|
|
"""智能路由
|
|
|
|
Args:
|
|
message: 用户消息
|
|
conversation_id: 会话 ID
|
|
force_new: 是否强制重新路由(忽略历史)
|
|
|
|
Returns:
|
|
路由结果 {
|
|
"agent_id": str,
|
|
"confidence": float,
|
|
"strategy": str,
|
|
"topic": str,
|
|
"topic_changed": bool,
|
|
"reason": str
|
|
}
|
|
"""
|
|
logger.info(
|
|
"开始智能路由",
|
|
extra={
|
|
"message_length": len(message),
|
|
"conversation_id": conversation_id,
|
|
"force_new": force_new
|
|
}
|
|
)
|
|
|
|
# 1. 获取会话状态
|
|
state = None
|
|
if conversation_id and not force_new:
|
|
state = self.state_manager.get_state(conversation_id)
|
|
|
|
# 2. 检测主题切换
|
|
topic_changed = self._detect_topic_change(message, state)
|
|
|
|
# 3. 提取当前主题
|
|
topic = self._extract_topic(message)
|
|
|
|
# 4. 选择路由策略
|
|
if force_new:
|
|
# 强制重新路由
|
|
agent_id, confidence = self._route_from_scratch(message)
|
|
strategy = "force_new"
|
|
reason = "用户强制重新路由"
|
|
|
|
elif not state or not state.get("current_agent_id"):
|
|
# 新会话,从头路由
|
|
agent_id, confidence = self._route_from_scratch(message)
|
|
strategy = "new_conversation"
|
|
reason = "新会话,首次路由"
|
|
|
|
elif topic_changed:
|
|
# 主题切换,重新路由
|
|
agent_id, confidence = self._route_from_scratch(message)
|
|
strategy = "topic_changed"
|
|
reason = f"检测到主题切换: {state.get('last_topic')} -> {topic}"
|
|
|
|
elif state.get("same_agent_turns", 0) >= self.max_same_agent_turns:
|
|
# 同一 Agent 使用太久,强制重新评估
|
|
agent_id, confidence = self._route_from_scratch(message)
|
|
strategy = "max_turns_reached"
|
|
reason = f"同一 Agent 已使用 {state['same_agent_turns']} 轮"
|
|
|
|
else:
|
|
# 检查是否应该继续使用当前 Agent
|
|
current_agent_id = state["current_agent_id"]
|
|
should_continue, continue_confidence = self._should_continue_current_agent(
|
|
message,
|
|
current_agent_id
|
|
)
|
|
|
|
if should_continue:
|
|
# 继续使用当前 Agent
|
|
agent_id = current_agent_id
|
|
confidence = continue_confidence
|
|
strategy = "continue_current"
|
|
reason = "消息在当前 Agent 能力范围内"
|
|
else:
|
|
# 重新路由
|
|
new_agent_id, new_confidence = self._route_from_scratch(message)
|
|
|
|
# 只有新 Agent 的置信度明显更高时才切换
|
|
if new_confidence > continue_confidence + self.min_confidence_for_switch:
|
|
agent_id = new_agent_id
|
|
confidence = new_confidence
|
|
strategy = "switch_agent"
|
|
reason = f"新 Agent 置信度更高: {new_confidence:.2f} vs {continue_confidence:.2f}"
|
|
else:
|
|
# 置信度差距不大,继续使用当前 Agent
|
|
agent_id = current_agent_id
|
|
confidence = continue_confidence
|
|
strategy = "keep_current"
|
|
reason = "置信度差距不足以切换 Agent"
|
|
|
|
# 5. 更新会话状态
|
|
if conversation_id:
|
|
self.state_manager.update_state(
|
|
conversation_id,
|
|
agent_id,
|
|
message,
|
|
topic,
|
|
confidence
|
|
)
|
|
|
|
result = {
|
|
"agent_id": agent_id,
|
|
"confidence": confidence,
|
|
"strategy": strategy,
|
|
"topic": topic,
|
|
"topic_changed": topic_changed,
|
|
"reason": reason
|
|
}
|
|
|
|
logger.info(
|
|
"路由完成",
|
|
extra={
|
|
"agent_id": agent_id,
|
|
"strategy": strategy,
|
|
"confidence": confidence,
|
|
"topic": topic
|
|
}
|
|
)
|
|
|
|
return result
|
|
|
|
def _detect_topic_change(
|
|
self,
|
|
message: str,
|
|
state: Optional[Dict[str, Any]]
|
|
) -> bool:
|
|
"""检测主题是否切换
|
|
|
|
Args:
|
|
message: 用户消息
|
|
state: 会话状态
|
|
|
|
Returns:
|
|
是否切换主题
|
|
"""
|
|
if not state or not state.get("last_topic"):
|
|
return False
|
|
|
|
# 检查明确的切换信号
|
|
for signal in self.SWITCH_SIGNALS:
|
|
if signal in message:
|
|
logger.info(f"检测到主题切换信号: {signal}")
|
|
return True
|
|
|
|
# 比较主题
|
|
current_topic = self._extract_topic(message)
|
|
last_topic = state.get("last_topic")
|
|
|
|
if current_topic != last_topic and current_topic != "其他":
|
|
logger.info(f"主题变化: {last_topic} -> {current_topic}")
|
|
return True
|
|
|
|
return False
|
|
|
|
def _should_continue_current_agent(
|
|
self,
|
|
message: str,
|
|
current_agent_id: str
|
|
) -> Tuple[bool, float]:
|
|
"""判断是否应该继续使用当前 Agent
|
|
|
|
Args:
|
|
message: 用户消息
|
|
current_agent_id: 当前 Agent ID
|
|
|
|
Returns:
|
|
(是否继续, 置信度)
|
|
"""
|
|
# 检查延续信号
|
|
has_continuation_signal = any(
|
|
signal in message
|
|
for signal in self.CONTINUATION_SIGNALS
|
|
)
|
|
|
|
# 计算当前 Agent 对消息的匹配度
|
|
current_score = self._calculate_agent_score(message, current_agent_id)
|
|
|
|
# 如果有延续信号且匹配度不太低,继续使用
|
|
if has_continuation_signal and current_score > 0.3:
|
|
return True, min(current_score + 0.2, 1.0)
|
|
|
|
# 如果匹配度高,继续使用
|
|
if current_score > 0.6:
|
|
return True, current_score
|
|
|
|
return False, current_score
|
|
|
|
def _route_from_scratch(self, message: str) -> Tuple[str, float]:
|
|
"""从头开始路由(不考虑历史)
|
|
|
|
Args:
|
|
message: 用户消息
|
|
|
|
Returns:
|
|
(Agent ID, 置信度)
|
|
"""
|
|
best_agent_id = None
|
|
best_score = 0.0
|
|
|
|
# 遍历所有路由规则
|
|
for rule in self.routing_rules:
|
|
score = self._calculate_rule_score(message, rule)
|
|
|
|
if score > best_score:
|
|
best_score = score
|
|
best_agent_id = rule.get("target_agent_id")
|
|
|
|
# 如果没有匹配的规则,使用默认 Agent
|
|
if not best_agent_id or best_score < 0.3:
|
|
best_agent_id = self._get_default_agent_id()
|
|
best_score = 0.5
|
|
logger.warning(f"未找到匹配规则,使用默认 Agent: {best_agent_id}")
|
|
|
|
return best_agent_id, best_score
|
|
|
|
def _calculate_rule_score(
|
|
self,
|
|
message: str,
|
|
rule: Dict[str, Any]
|
|
) -> float:
|
|
"""计算规则匹配分数
|
|
|
|
Args:
|
|
message: 用户消息
|
|
rule: 路由规则
|
|
|
|
Returns:
|
|
匹配分数 (0-1)
|
|
"""
|
|
score = 0.0
|
|
message_lower = message.lower()
|
|
|
|
# 1. 关键词匹配 (权重 0.6)
|
|
keywords = rule.get("keywords", [])
|
|
if keywords:
|
|
matched_keywords = sum(
|
|
1 for keyword in keywords
|
|
if keyword.lower() in message_lower
|
|
)
|
|
keyword_score = matched_keywords / len(keywords)
|
|
score += keyword_score * 0.6
|
|
|
|
# 2. 正则匹配 (权重 0.3)
|
|
patterns = rule.get("patterns", [])
|
|
if patterns:
|
|
matched_patterns = sum(
|
|
1 for pattern in patterns
|
|
if re.search(pattern, message, re.IGNORECASE)
|
|
)
|
|
pattern_score = matched_patterns / len(patterns)
|
|
score += pattern_score * 0.3
|
|
|
|
# 3. 排除关键词 (负分)
|
|
exclude_keywords = rule.get("exclude_keywords", [])
|
|
if exclude_keywords:
|
|
has_exclude = any(
|
|
keyword.lower() in message_lower
|
|
for keyword in exclude_keywords
|
|
)
|
|
if has_exclude:
|
|
score *= 0.5 # 减半
|
|
|
|
# 4. 最小关键词数量要求
|
|
min_keyword_count = rule.get("min_keyword_count", 0)
|
|
if keywords and min_keyword_count > 0:
|
|
matched_count = sum(
|
|
1 for keyword in keywords
|
|
if keyword.lower() in message_lower
|
|
)
|
|
if matched_count < min_keyword_count:
|
|
score *= 0.7 # 惩罚
|
|
|
|
return min(score, 1.0)
|
|
|
|
def _calculate_agent_score(
|
|
self,
|
|
message: str,
|
|
agent_id: str
|
|
) -> float:
|
|
"""计算 Agent 对消息的匹配分数
|
|
|
|
Args:
|
|
message: 用户消息
|
|
agent_id: Agent ID
|
|
|
|
Returns:
|
|
匹配分数 (0-1)
|
|
"""
|
|
# 找到该 Agent 对应的所有规则
|
|
agent_rules = [
|
|
rule for rule in self.routing_rules
|
|
if rule.get("target_agent_id") == agent_id
|
|
]
|
|
|
|
if not agent_rules:
|
|
return 0.0
|
|
|
|
# 返回最高分数
|
|
max_score = max(
|
|
self._calculate_rule_score(message, rule)
|
|
for rule in agent_rules
|
|
)
|
|
|
|
return max_score
|
|
|
|
def _extract_topic(self, message: str) -> str:
|
|
"""提取消息主题
|
|
|
|
Args:
|
|
message: 用户消息
|
|
|
|
Returns:
|
|
主题名称
|
|
"""
|
|
# 主题关键词映射
|
|
topic_keywords = {
|
|
"数学": ["数学", "方程", "计算", "求解", "x", "y", "函数", "几何"],
|
|
"物理": ["物理", "力", "速度", "加速度", "能量", "功率", "电路"],
|
|
"化学": ["化学", "方程式", "反应", "元素", "分子", "原子", "化合物"],
|
|
"语文": ["语文", "古诗", "作文", "阅读", "文言文", "诗词"],
|
|
"英语": ["英语", "单词", "语法", "翻译", "时态", "句型"],
|
|
"历史": ["历史", "朝代", "事件", "人物", "战争", "革命"],
|
|
"作业": ["作业", "批改", "检查", "评分", "反馈"],
|
|
"学习规划": ["计划", "规划", "方法", "技巧", "时间", "安排"],
|
|
"订单": ["订单", "发货", "物流", "配送", "快递"],
|
|
"退款": ["退款", "退货", "售后", "换货", "维修"],
|
|
"账户": ["账户", "密码", "登录", "注册", "绑定"],
|
|
"支付": ["支付", "付款", "充值", "余额", "优惠券"]
|
|
}
|
|
|
|
message_lower = message.lower()
|
|
|
|
# 统计每个主题的匹配度
|
|
topic_scores = {}
|
|
for topic, keywords in topic_keywords.items():
|
|
matched = sum(
|
|
1 for keyword in keywords
|
|
if keyword in message_lower
|
|
)
|
|
if matched > 0:
|
|
topic_scores[topic] = matched
|
|
|
|
# 返回匹配度最高的主题
|
|
if topic_scores:
|
|
best_topic = max(topic_scores.items(), key=lambda x: x[1])[0]
|
|
return best_topic
|
|
|
|
return "其他"
|
|
|
|
def _get_default_agent_id(self) -> str:
|
|
"""获取默认 Agent ID
|
|
|
|
Returns:
|
|
默认 Agent ID
|
|
"""
|
|
# 优先使用第一个路由规则的 Agent
|
|
if self.routing_rules:
|
|
return self.routing_rules[0].get("target_agent_id")
|
|
|
|
# 否则使用第一个子 Agent
|
|
if self.sub_agents:
|
|
return next(iter(self.sub_agents.keys()))
|
|
|
|
return "default-agent"
|