Initial commit

This commit is contained in:
Ke Sun
2025-11-30 18:22:17 +08:00
commit aea2fe391e
449 changed files with 83030 additions and 0 deletions

View File

@@ -0,0 +1,261 @@
"""会话状态管理器 - 解决多轮对话路由错乱"""
import json
from typing import Optional, Dict, Any, List
from datetime import datetime
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class ConversationStateManager:
"""会话状态管理器
用于管理多轮对话中的会话状态,包括:
- 当前使用的 Agent
- 路由历史
- 主题追踪
- Agent 切换统计
"""
def __init__(self, storage_backend: Optional[Any] = None):
"""初始化状态管理器
Args:
storage_backend: 存储后端Redis/内存等)
"""
self.storage = storage_backend or InMemoryStorage()
self.ttl = 3600 # 1小时过期
def get_state(self, conversation_id: str) -> Dict[str, Any]:
"""获取会话状态
Args:
conversation_id: 会话 ID
Returns:
会话状态字典
"""
state = self.storage.get(f"conv_state:{conversation_id}")
if not state:
logger.info(f"创建新会话状态: {conversation_id}")
return self._create_new_state(conversation_id)
return state
def update_state(
self,
conversation_id: str,
agent_id: str,
message: str,
topic: Optional[str] = None,
confidence: float = 1.0
) -> Dict[str, Any]:
"""更新会话状态
Args:
conversation_id: 会话 ID
agent_id: 当前 Agent ID
message: 用户消息
topic: 消息主题
confidence: 路由置信度
Returns:
更新后的状态
"""
state = self.get_state(conversation_id)
# 检测 Agent 切换
agent_changed = False
if state["current_agent_id"] and state["current_agent_id"] != agent_id:
agent_changed = True
state["switch_count"] += 1
state["previous_agent_id"] = state["current_agent_id"]
state["same_agent_turns"] = 0
logger.info(
f"Agent 切换",
extra={
"conversation_id": conversation_id,
"from": state["current_agent_id"],
"to": agent_id,
"switch_count": state["switch_count"]
}
)
else:
state["same_agent_turns"] += 1
# 更新当前 Agent
state["current_agent_id"] = agent_id
state["last_message"] = message
state["last_topic"] = topic
state["updated_at"] = datetime.now().isoformat()
# 添加到历史
history_item = {
"message": message[:100], # 截断长消息
"agent_id": agent_id,
"topic": topic,
"confidence": confidence,
"agent_changed": agent_changed,
"timestamp": datetime.now().isoformat()
}
state["routing_history"].append(history_item)
# 保持最近 10 条历史
if len(state["routing_history"]) > 10:
state["routing_history"] = state["routing_history"][-10:]
# 保存状态
self.storage.set(
f"conv_state:{conversation_id}",
state,
ttl=self.ttl
)
return state
def clear_state(self, conversation_id: str) -> None:
"""清除会话状态
Args:
conversation_id: 会话 ID
"""
self.storage.delete(f"conv_state:{conversation_id}")
logger.info(f"清除会话状态: {conversation_id}")
def get_routing_history(
self,
conversation_id: str,
limit: int = 10
) -> List[Dict[str, Any]]:
"""获取路由历史
Args:
conversation_id: 会话 ID
limit: 返回数量限制
Returns:
路由历史列表
"""
state = self.get_state(conversation_id)
history = state.get("routing_history", [])
return history[-limit:] if history else []
def get_statistics(self, conversation_id: str) -> Dict[str, Any]:
"""获取会话统计信息
Args:
conversation_id: 会话 ID
Returns:
统计信息
"""
state = self.get_state(conversation_id)
history = state.get("routing_history", [])
# 统计各 Agent 使用次数
agent_usage = {}
for item in history:
agent_id = item["agent_id"]
agent_usage[agent_id] = agent_usage.get(agent_id, 0) + 1
# 统计主题分布
topic_distribution = {}
for item in history:
topic = item.get("topic", "未知")
topic_distribution[topic] = topic_distribution.get(topic, 0) + 1
return {
"conversation_id": conversation_id,
"total_turns": len(history),
"switch_count": state.get("switch_count", 0),
"current_agent_id": state.get("current_agent_id"),
"same_agent_turns": state.get("same_agent_turns", 0),
"agent_usage": agent_usage,
"topic_distribution": topic_distribution,
"created_at": state.get("created_at"),
"updated_at": state.get("updated_at")
}
def _create_new_state(self, conversation_id: str) -> Dict[str, Any]:
"""创建新的会话状态
Args:
conversation_id: 会话 ID
Returns:
新的状态字典
"""
state = {
"conversation_id": conversation_id,
"current_agent_id": None,
"previous_agent_id": None,
"routing_history": [],
"last_message": None,
"last_topic": None,
"switch_count": 0,
"same_agent_turns": 0,
"created_at": datetime.now().isoformat(),
"updated_at": datetime.now().isoformat()
}
# 保存初始状态
self.storage.set(
f"conv_state:{conversation_id}",
state,
ttl=self.ttl
)
return state
class InMemoryStorage:
"""内存存储后端(用于开发和测试)"""
def __init__(self):
self._storage: Dict[str, Dict[str, Any]] = {}
def get(self, key: str) -> Optional[Dict[str, Any]]:
"""获取数据"""
return self._storage.get(key)
def set(self, key: str, value: Dict[str, Any], ttl: int = 3600) -> None:
"""设置数据"""
self._storage[key] = value
def delete(self, key: str) -> None:
"""删除数据"""
if key in self._storage:
del self._storage[key]
def clear(self) -> None:
"""清空所有数据"""
self._storage.clear()
class RedisStorage:
"""Redis 存储后端(用于生产环境)"""
def __init__(self, redis_client):
"""初始化 Redis 存储
Args:
redis_client: Redis 客户端实例
"""
self.redis = redis_client
def get(self, key: str) -> Optional[Dict[str, Any]]:
"""获取数据"""
data = self.redis.get(key)
if data:
return json.loads(data)
return None
def set(self, key: str, value: Dict[str, Any], ttl: int = 3600) -> None:
"""设置数据"""
self.redis.setex(key, ttl, json.dumps(value))
def delete(self, key: str) -> None:
"""删除数据"""
self.redis.delete(key)