Initial commit

This commit is contained in:
Ke Sun
2025-11-30 18:22:17 +08:00
commit aea2fe391e
449 changed files with 83030 additions and 0 deletions

View File

@@ -0,0 +1,185 @@
# -*- coding: utf-8 -*-
"""对话仓储模块
本模块提供对话节点的数据访问功能。
Classes:
DialogRepository: 对话仓储管理DialogueNode的CRUD操作
"""
from typing import List, Optional, Dict
from datetime import datetime
from app.repositories.neo4j.base_neo4j_repository import BaseNeo4jRepository
from app.core.memory.models.graph_models import DialogueNode
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
class DialogRepository(BaseNeo4jRepository[DialogueNode]):
"""对话仓储
管理对话节点的创建、查询、更新和删除操作。
提供按group_id、user_id、ref_id等条件查询对话的方法。
Attributes:
connector: Neo4j连接器实例
node_label: 节点标签,固定为"Dialogue"
"""
def __init__(self, connector: Neo4jConnector):
"""初始化对话仓储
Args:
connector: Neo4j连接器实例
"""
super().__init__(connector, "Dialogue")
def _map_to_entity(self, node_data: Dict) -> DialogueNode:
"""将节点数据映射为对话实体
Args:
node_data: 从Neo4j查询返回的节点数据字典
Returns:
DialogueNode: 对话实体对象
"""
# 从查询结果中提取节点数据
n = node_data.get('n', node_data)
# 处理datetime字段
if isinstance(n.get('created_at'), str):
n['created_at'] = datetime.fromisoformat(n['created_at'])
if n.get('expired_at') and isinstance(n['expired_at'], str):
n['expired_at'] = datetime.fromisoformat(n['expired_at'])
return DialogueNode(**n)
async def find_by_group_id(self, group_id: str, limit: int = 100) -> List[DialogueNode]:
"""根据group_id查询对话
Args:
group_id: 组ID
limit: 返回结果的最大数量
Returns:
List[DialogueNode]: 对话列表
"""
return await self.find({"group_id": group_id}, limit=limit)
async def find_by_user_id(self, user_id: str, limit: int = 100) -> List[DialogueNode]:
"""根据user_id查询对话
Args:
user_id: 用户ID
limit: 返回结果的最大数量
Returns:
List[DialogueNode]: 对话列表
"""
return await self.find({"user_id": user_id}, limit=limit)
async def find_by_ref_id(self, ref_id: str) -> Optional[DialogueNode]:
"""根据ref_id查询对话
ref_id是外部对话系统的引用ID通常是唯一的。
Args:
ref_id: 引用ID
Returns:
Optional[DialogueNode]: 找到的对话如果不存在则返回None
"""
results = await self.find({"ref_id": ref_id}, limit=1)
return results[0] if results else None
async def find_by_group_and_user(
self,
group_id: str,
user_id: str,
limit: int = 100
) -> List[DialogueNode]:
"""根据group_id和user_id查询对话
Args:
group_id: 组ID
user_id: 用户ID
limit: 返回结果的最大数量
Returns:
List[DialogueNode]: 对话列表
"""
return await self.find(
{"group_id": group_id, "user_id": user_id},
limit=limit
)
async def find_recent_dialogs(
self,
group_id: str,
days: int = 7,
limit: int = 100
) -> List[DialogueNode]:
"""查询最近的对话
Args:
group_id: 组ID
days: 查询最近多少天的对话
limit: 返回结果的最大数量
Returns:
List[DialogueNode]: 对话列表,按创建时间倒序排列
"""
query = f"""
MATCH (n:{self.node_label})
WHERE n.group_id = $group_id
AND n.created_at >= datetime() - duration({{days: $days}})
RETURN n
ORDER BY n.created_at DESC
LIMIT $limit
"""
results = await self.connector.execute_query(
query,
group_id=group_id,
days=days,
limit=limit
)
return [self._map_to_entity(r) for r in results]
async def find_by_config_id(
self,
config_id: str,
limit: int = 100
) -> List[DialogueNode]:
"""根据config_id查询对话
Args:
config_id: 配置ID
limit: 返回结果的最大数量
Returns:
List[DialogueNode]: 对话列表
"""
return await self.find({"config_id": config_id}, limit=limit)
async def find_by_config_and_group(
self,
config_id: str,
group_id: str,
limit: int = 100
) -> List[DialogueNode]:
"""根据config_id和group_id查询对话
支持按配置ID和组ID同时过滤,确保只返回使用特定配置处理的对话。
Args:
config_id: 配置ID
group_id: 组ID
limit: 返回结果的最大数量
Returns:
List[DialogueNode]: 对话列表
"""
return await self.find(
{"config_id": config_id, "group_id": group_id},
limit=limit
)