262 lines
7.6 KiB
Python
262 lines
7.6 KiB
Python
"""会话状态管理器 - 解决多轮对话路由错乱"""
|
||
import copy
import json
import time
from collections import Counter
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Tuple

from app.core.logging_config import get_business_logger
|
||
|
||
# Module-level logger obtained from the project's logging configuration
# (app.core.logging_config.get_business_logger); used by the state manager
# below to log state creation, Agent switches and state clearing.
logger = get_business_logger()
|
||
|
||
|
||
class ConversationStateManager:
    """Conversation state manager.

    Manages per-conversation state across multi-turn dialogue, including:

    - the Agent currently handling the conversation
    - the routing history (bounded to the most recent ``max_history`` turns)
    - topic tracking
    - Agent switch statistics

    Each conversation's state is a plain dict persisted in the storage
    backend under the key ``conv_state:<conversation_id>``.
    """

    def __init__(
        self,
        storage_backend: Optional[Any] = None,
        ttl: int = 3600,
        max_history: int = 10,
    ):
        """Initialize the state manager.

        Args:
            storage_backend: Storage backend (Redis, in-memory, ...) exposing
                ``get(key)``, ``set(key, value, ttl=...)`` and
                ``delete(key)``. Defaults to a new ``InMemoryStorage``.
            ttl: Expiry in seconds passed to the backend on every write
                (default 3600, i.e. one hour).
            max_history: Maximum number of routing-history entries kept per
                conversation (default 10).
        """
        # Compare with None explicitly: a backend object that happens to be
        # falsy (e.g. defines __len__ and is empty) must not be replaced.
        self.storage = (
            storage_backend if storage_backend is not None else InMemoryStorage()
        )
        self.ttl = ttl
        self.max_history = max_history

    def _key(self, conversation_id: str) -> str:
        """Return the storage key under which *conversation_id* is kept."""
        return f"conv_state:{conversation_id}"

    def _save(self, conversation_id: str, state: Dict[str, Any]) -> None:
        """Persist *state* for *conversation_id* with the configured TTL."""
        self.storage.set(self._key(conversation_id), state, ttl=self.ttl)

    def get_state(self, conversation_id: str) -> Dict[str, Any]:
        """Return the state of a conversation, creating it if missing.

        A missing (or empty) stored state is replaced by a freshly created
        one, which is persisted immediately.

        Args:
            conversation_id: Conversation ID.

        Returns:
            The conversation state dict.
        """
        state = self.storage.get(self._key(conversation_id))
        if not state:
            logger.info(f"创建新会话状态: {conversation_id}")
            return self._create_new_state(conversation_id)
        return state

    def update_state(
        self,
        conversation_id: str,
        agent_id: str,
        message: str,
        topic: Optional[str] = None,
        confidence: float = 1.0
    ) -> Dict[str, Any]:
        """Record one routed turn and persist the updated state.

        Args:
            conversation_id: Conversation ID.
            agent_id: ID of the Agent handling this turn.
            message: User message for this turn.
            topic: Topic of the message, if known.
            confidence: Routing confidence for this turn.

        Returns:
            The updated state dict (also written back to storage).
        """
        state = self.get_state(conversation_id)
        previous_agent = state["current_agent_id"]

        # A switch only counts once some Agent has already handled the
        # conversation; the very first turn is never a switch.
        agent_changed = bool(previous_agent) and previous_agent != agent_id
        if agent_changed:
            state["switch_count"] += 1
            state["previous_agent_id"] = previous_agent
            # This turn is already the first one handled by the new Agent, so
            # the consecutive-turn counter restarts at 1 -- the same value a
            # conversation's very first turn produces.
            state["same_agent_turns"] = 1

            logger.info(
                "Agent 切换",
                extra={
                    "conversation_id": conversation_id,
                    "from": previous_agent,
                    "to": agent_id,
                    "switch_count": state["switch_count"]
                }
            )
        else:
            state["same_agent_turns"] += 1

        now = datetime.now().isoformat()

        # Update the current Agent and the latest-turn fields.
        state["current_agent_id"] = agent_id
        state["last_message"] = message
        state["last_topic"] = topic
        state["updated_at"] = now

        # Append to the routing history.
        state["routing_history"].append({
            "message": message[:100],  # truncate long messages
            "agent_id": agent_id,
            "topic": topic,
            "confidence": confidence,
            "agent_changed": agent_changed,
            "timestamp": now
        })

        # Keep only the most recent entries. Guard max_history <= 0 because
        # history[-0:] would return the whole list instead of nothing.
        history = state["routing_history"]
        if len(history) > self.max_history:
            state["routing_history"] = (
                history[-self.max_history:] if self.max_history > 0 else []
            )

        self._save(conversation_id, state)
        return state

    def clear_state(self, conversation_id: str) -> None:
        """Delete the stored state of a conversation.

        Args:
            conversation_id: Conversation ID.
        """
        self.storage.delete(self._key(conversation_id))
        logger.info(f"清除会话状态: {conversation_id}")

    def get_routing_history(
        self,
        conversation_id: str,
        limit: int = 10
    ) -> List[Dict[str, Any]]:
        """Return the most recent routing-history entries.

        Note that, like ``get_state``, this creates the conversation state if
        it does not exist yet.

        Args:
            conversation_id: Conversation ID.
            limit: Maximum number of entries to return; ``<= 0`` returns an
                empty list.

        Returns:
            Up to ``limit`` history entries, oldest first.
        """
        state = self.get_state(conversation_id)
        history = state.get("routing_history") or []
        # history[-0:] is the whole list, so non-positive limits need a guard.
        if limit <= 0:
            return []
        return history[-limit:]

    def get_statistics(self, conversation_id: str) -> Dict[str, Any]:
        """Return usage statistics for a conversation.

        Args:
            conversation_id: Conversation ID.

        Returns:
            Dict with turn/switch counters, per-Agent usage counts and the
            topic distribution over the retained routing history. Turns whose
            topic is missing or empty are counted under "未知".
        """
        state = self.get_state(conversation_id)
        history = state.get("routing_history", [])

        agent_usage = Counter(item["agent_id"] for item in history)
        # The "topic" key is always present (possibly None), so a .get()
        # default would never apply; fall back explicitly instead.
        topic_distribution = Counter(
            item.get("topic") or "未知" for item in history
        )

        return {
            "conversation_id": conversation_id,
            "total_turns": len(history),
            "switch_count": state.get("switch_count", 0),
            "current_agent_id": state.get("current_agent_id"),
            "same_agent_turns": state.get("same_agent_turns", 0),
            "agent_usage": dict(agent_usage),
            "topic_distribution": dict(topic_distribution),
            "created_at": state.get("created_at"),
            "updated_at": state.get("updated_at")
        }

    def _create_new_state(self, conversation_id: str) -> Dict[str, Any]:
        """Create, persist and return a fresh state for a conversation.

        Args:
            conversation_id: Conversation ID.

        Returns:
            The new state dict.
        """
        # One timestamp so created_at and updated_at agree on a new state.
        now = datetime.now().isoformat()
        state = {
            "conversation_id": conversation_id,
            "current_agent_id": None,
            "previous_agent_id": None,
            "routing_history": [],
            "last_message": None,
            "last_topic": None,
            "switch_count": 0,
            "same_agent_turns": 0,
            "created_at": now,
            "updated_at": now
        }

        self._save(conversation_id, state)
        return state
|
||
|
||
|
||
class InMemoryStorage:
    """In-memory storage backend (for development and testing).

    Follows the same contract as ``RedisStorage`` closely enough for local use:

    - ``ttl`` is honoured: an entry expires ``ttl`` seconds after it was
      written and is evicted lazily on the next read of that key. A ``ttl``
      of ``None`` or ``<= 0`` means the entry never expires.
    - Values are deep-copied on write and on read. Mutating a returned dict
      therefore does not silently change the stored state, just as a Redis
      JSON round-trip never shares objects with the caller.
    """

    def __init__(self, clock: Optional[Callable[[], float]] = None):
        """Create an empty store.

        Args:
            clock: Zero-argument callable returning the current time in
                seconds, used for TTL bookkeeping. Defaults to
                ``time.monotonic``; injectable for deterministic tests.
        """
        # key -> (absolute expiry time or None for "never", stored value)
        self._storage: Dict[str, Tuple[Optional[float], Dict[str, Any]]] = {}
        self._clock = clock if clock is not None else time.monotonic

    def get(self, key: str) -> Optional[Dict[str, Any]]:
        """Return a copy of the value at *key*, or ``None`` if absent/expired."""
        entry = self._storage.get(key)
        if entry is None:
            return None
        expires_at, value = entry
        if expires_at is not None and self._clock() >= expires_at:
            # Lazy eviction: expired entries are dropped on first access.
            del self._storage[key]
            return None
        return copy.deepcopy(value)

    def set(self, key: str, value: Dict[str, Any], ttl: int = 3600) -> None:
        """Store a copy of *value* at *key*, expiring after *ttl* seconds."""
        if ttl is not None and ttl > 0:
            expires_at: Optional[float] = self._clock() + ttl
        else:
            expires_at = None
        self._storage[key] = (expires_at, copy.deepcopy(value))

    def delete(self, key: str) -> None:
        """Remove *key* if present; missing keys are ignored."""
        self._storage.pop(key, None)

    def clear(self) -> None:
        """Remove all entries."""
        self._storage.clear()
|
||
|
||
|
||
class RedisStorage:
    """Redis storage backend (for production use).

    Values are serialized to JSON on write and parsed back on read.
    """

    def __init__(self, redis_client):
        """Initialize the Redis storage.

        Args:
            redis_client: Redis client instance providing redis-py compatible
                ``get``, ``setex``, ``set`` and ``delete`` methods.
        """
        self.redis = redis_client

    def get(self, key: str) -> Optional[Dict[str, Any]]:
        """Return the JSON-decoded value at *key*, or ``None`` if absent."""
        data = self.redis.get(key)
        # A missing key (None) and an empty payload both mean "nothing to
        # decode".
        if not data:
            return None
        return json.loads(data)

    def set(self, key: str, value: Dict[str, Any], ttl: int = 3600) -> None:
        """Store *value* as JSON at *key*.

        Args:
            key: Storage key.
            value: JSON-serializable dict to store.
            ttl: Expiry in seconds. ``None`` or a non-positive value stores
                the key without expiry. Redis ``SETEX`` rejects non-positive
                expiry times, so plain ``SET`` is used in that case.
        """
        payload = json.dumps(value)
        if ttl is None or ttl <= 0:
            self.redis.set(key, payload)
        else:
            self.redis.setex(key, ttl, payload)

    def delete(self, key: str) -> None:
        """Delete *key*."""
        self.redis.delete(key)
|