Files
MemoryBear/api/app/services/conversation_state_manager.py
2025-12-15 14:09:43 +08:00

262 lines
7.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""会话状态管理器 - 解决多轮对话路由错乱"""
import json
from typing import Optional, Dict, Any, List
from datetime import datetime
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class ConversationStateManager:
    """Track per-conversation routing state across a multi-turn dialogue.

    The state kept for each conversation records:
    - the agent currently handling the conversation
    - a bounded routing history
    - the last message and topic
    - agent-switch statistics
    """

    # Namespace prefix of every key written to the storage backend.
    _KEY_PREFIX = "conv_state:"
    # Number of most-recent routing-history entries retained per conversation.
    _MAX_HISTORY = 10
    # History entries store at most this many characters of the message.
    _MAX_MESSAGE_CHARS = 100
    # Label used in statistics for turns recorded without a topic.
    _UNKNOWN_TOPIC = "未知"

    def __init__(self, storage_backend: Optional[Any] = None):
        """Initialise the state manager.

        Args:
            storage_backend: Object exposing ``get(key)``,
                ``set(key, value, ttl=...)`` and ``delete(key)``.
                When ``None``, an ``InMemoryStorage`` is created.
        """
        # Compare against None explicitly: a valid backend may evaluate as
        # falsy (e.g. an empty container-like store) and must not be
        # silently replaced by a fresh in-memory one.
        self.storage = (
            storage_backend if storage_backend is not None else InMemoryStorage()
        )
        self.ttl = 3600  # seconds (1 hour); passed as ``ttl`` to storage.set

    def _state_key(self, conversation_id: str) -> str:
        """Return the storage key under which a conversation's state lives."""
        return f"{self._KEY_PREFIX}{conversation_id}"

    def get_state(self, conversation_id: str) -> Dict[str, Any]:
        """Return the state of a conversation, creating it if absent.

        Args:
            conversation_id: Conversation ID.

        Returns:
            The state dictionary. A freshly created state is also persisted.
        """
        state = self.storage.get(self._state_key(conversation_id))
        if not state:
            logger.info(f"创建新会话状态: {conversation_id}")
            return self._create_new_state(conversation_id)
        return state

    def update_state(
        self,
        conversation_id: str,
        agent_id: str,
        message: str,
        topic: Optional[str] = None,
        confidence: float = 1.0
    ) -> Dict[str, Any]:
        """Record one routed turn and persist the updated state.

        Args:
            conversation_id: Conversation ID.
            agent_id: ID of the agent handling this turn.
            message: User message.
            topic: Topic of the message, if known.
            confidence: Routing confidence.

        Returns:
            The updated state dictionary.
        """
        state = self.get_state(conversation_id)

        # An agent switch only counts when a previous agent exists and it
        # differs from the one handling this turn.
        previous_agent = state["current_agent_id"]
        agent_changed = bool(previous_agent) and previous_agent != agent_id
        if agent_changed:
            state["switch_count"] += 1
            state["previous_agent_id"] = previous_agent
            state["same_agent_turns"] = 0
            logger.info(
                "Agent 切换",
                extra={
                    "conversation_id": conversation_id,
                    "from": previous_agent,
                    "to": agent_id,
                    "switch_count": state["switch_count"]
                }
            )
        else:
            state["same_agent_turns"] += 1

        # One clock reading so ``updated_at`` and the history entry agree.
        now = datetime.now().isoformat()

        state["current_agent_id"] = agent_id
        state["last_message"] = message
        state["last_topic"] = topic
        state["updated_at"] = now

        state["routing_history"].append({
            "message": message[:self._MAX_MESSAGE_CHARS],  # truncate long messages
            "agent_id": agent_id,
            "topic": topic,
            "confidence": confidence,
            "agent_changed": agent_changed,
            "timestamp": now
        })
        # Keep only the most recent entries.
        if len(state["routing_history"]) > self._MAX_HISTORY:
            state["routing_history"] = state["routing_history"][-self._MAX_HISTORY:]

        self.storage.set(
            self._state_key(conversation_id),
            state,
            ttl=self.ttl
        )
        return state

    def clear_state(self, conversation_id: str) -> None:
        """Delete the stored state of a conversation.

        Args:
            conversation_id: Conversation ID.
        """
        self.storage.delete(self._state_key(conversation_id))
        logger.info(f"清除会话状态: {conversation_id}")

    def get_routing_history(
        self,
        conversation_id: str,
        limit: int = 10
    ) -> List[Dict[str, Any]]:
        """Return the most recent routing-history entries.

        Args:
            conversation_id: Conversation ID.
            limit: Maximum number of entries to return. A value of zero or
                less yields an empty list.

        Returns:
            Up to ``limit`` entries, oldest first.
        """
        # ``history[-0:]`` is ``history[0:]`` (the whole list), so a
        # non-positive limit has to be handled before slicing.
        if limit <= 0:
            return []
        state = self.get_state(conversation_id)
        history = state.get("routing_history", [])
        return history[-limit:] if history else []

    def get_statistics(self, conversation_id: str) -> Dict[str, Any]:
        """Summarise a conversation's routing behaviour.

        Args:
            conversation_id: Conversation ID.

        Returns:
            A dictionary with turn count, switch count, current agent,
            per-agent usage, per-topic distribution and timestamps.
        """
        state = self.get_state(conversation_id)
        history = state.get("routing_history", [])

        agent_usage: Dict[str, int] = {}
        topic_distribution: Dict[str, int] = {}
        for item in history:
            agent_id = item["agent_id"]
            agent_usage[agent_id] = agent_usage.get(agent_id, 0) + 1

            # History entries always carry a "topic" key whose value is
            # None when no topic was supplied, so a ``.get`` default would
            # never fire; fall back on the value instead.
            topic = item.get("topic") or self._UNKNOWN_TOPIC
            topic_distribution[topic] = topic_distribution.get(topic, 0) + 1

        return {
            "conversation_id": conversation_id,
            "total_turns": len(history),
            "switch_count": state.get("switch_count", 0),
            "current_agent_id": state.get("current_agent_id"),
            "same_agent_turns": state.get("same_agent_turns", 0),
            "agent_usage": agent_usage,
            "topic_distribution": topic_distribution,
            "created_at": state.get("created_at"),
            "updated_at": state.get("updated_at")
        }

    def _create_new_state(self, conversation_id: str) -> Dict[str, Any]:
        """Build, persist and return an empty state for a conversation.

        Args:
            conversation_id: Conversation ID.

        Returns:
            The new state dictionary.
        """
        now = datetime.now().isoformat()
        state = {
            "conversation_id": conversation_id,
            "current_agent_id": None,
            "previous_agent_id": None,
            "routing_history": [],
            "last_message": None,
            "last_topic": None,
            "switch_count": 0,
            "same_agent_turns": 0,
            "created_at": now,
            "updated_at": now
        }
        # Persist the initial state.
        self.storage.set(
            self._state_key(conversation_id),
            state,
            ttl=self.ttl
        )
        return state
class InMemoryStorage:
    """Dictionary-backed storage backend (for development and testing)."""

    def __init__(self):
        # Maps storage key -> stored value; values are kept by reference.
        self._storage: Dict[str, Dict[str, Any]] = {}

    def get(self, key: str) -> Optional[Dict[str, Any]]:
        """Return the value stored under ``key``, or ``None`` if absent."""
        return self._storage.get(key)

    def set(self, key: str, value: Dict[str, Any], ttl: int = 3600) -> None:
        """Store ``value`` under ``key``.

        ``ttl`` is accepted but not used here: entries stay until they are
        deleted or the store is cleared.
        """
        self._storage.update({key: value})

    def delete(self, key: str) -> None:
        """Remove ``key`` if present; a missing key is ignored."""
        self._storage.pop(key, None)

    def clear(self) -> None:
        """Remove every stored entry."""
        self._storage.clear()
class RedisStorage:
    """Redis storage backend (for production use).

    Values are serialised to JSON on write and parsed back on read.
    """

    def __init__(self, redis_client):
        """Wrap a Redis client.

        Args:
            redis_client: Client instance providing ``get``, ``setex`` and
                ``delete``.
        """
        self.redis = redis_client

    def get(self, key: str) -> Optional[Dict[str, Any]]:
        """Return the decoded value for ``key``, or ``None`` when the client
        returns nothing (or an empty payload) for it."""
        raw = self.redis.get(key)
        return json.loads(raw) if raw else None

    def set(self, key: str, value: Dict[str, Any], ttl: int = 3600) -> None:
        """Store ``value`` as JSON under ``key`` with a ``ttl``-second expiry."""
        payload = json.dumps(value)
        self.redis.setex(key, ttl, payload)

    def delete(self, key: str) -> None:
        """Delete ``key`` through the client."""
        self.redis.delete(key)