Files
MemoryBear/app/repositories/neo4j/base_neo4j_repository.py
2025-11-30 18:22:17 +08:00

176 lines
5.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""Neo4j仓储基类模块
本模块提供Neo4j仓储的基类实现封装了通用的Neo4j节点操作。
Classes:
BaseNeo4jRepository: Neo4j仓储基类实现通用的CRUD操作
"""
from typing import List, Optional, Dict, Any, TypeVar
from app.repositories.base_repository import BaseRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
T = TypeVar('T')
class BaseNeo4jRepository(BaseRepository[T]):
"""Neo4j仓储基类 - 实现通用的Neo4j节点操作
这个基类封装了Neo4j节点的通用CRUD操作子类只需要实现
特定的映射逻辑和业务查询方法。
Attributes:
connector: Neo4j连接器实例
node_label: 节点标签(如"Dialogue", "Statement"等)
Type Parameters:
T: 实体类型通常是Pydantic模型
"""
def __init__(self, connector: Neo4jConnector, node_label: str):
"""初始化Neo4j仓储
Args:
connector: Neo4j连接器实例
node_label: 节点标签用于Cypher查询
"""
self.connector = connector
self.node_label = node_label
async def create(self, entity: T) -> T:
"""创建节点
将实体对象转换为Neo4j节点并保存到数据库。
Args:
entity: 要创建的实体对象
Returns:
T: 创建后的实体对象
Example:
>>> dialog = DialogueNode(id="123", name="对话1", ...)
>>> created = await repository.create(dialog)
"""
query = f"""
CREATE (n:{self.node_label} $props)
RETURN n
"""
result = await self.connector.execute_query(
query,
props=entity.model_dump()
)
return entity
async def get_by_id(self, entity_id: str) -> Optional[T]:
"""根据ID获取节点
Args:
entity_id: 节点ID
Returns:
Optional[T]: 找到的实体对象如果不存在则返回None
"""
query = f"""
MATCH (n:{self.node_label} {{id: $id}})
RETURN n
"""
result = await self.connector.execute_query(query, id=entity_id)
if result:
return self._map_to_entity(result[0])
return None
async def update(self, entity: T) -> T:
"""更新节点
更新现有节点的属性。使用SET +=语法合并属性。
Args:
entity: 要更新的实体对象必须包含id字段
Returns:
T: 更新后的实体对象
"""
query = f"""
MATCH (n:{self.node_label} {{id: $id}})
SET n += $props
RETURN n
"""
await self.connector.execute_query(
query,
id=entity.id,
props=entity.model_dump()
)
return entity
async def delete(self, entity_id: str) -> bool:
"""删除节点
删除指定ID的节点。使用DETACH DELETE同时删除相关的边。
Args:
entity_id: 要删除的节点ID
Returns:
bool: 删除成功返回True否则返回False
"""
query = f"""
MATCH (n:{self.node_label} {{id: $id}})
DETACH DELETE n
RETURN count(n) as deleted
"""
result = await self.connector.execute_query(query, id=entity_id)
return result[0]['deleted'] > 0 if result else False
async def find(self, filters: Dict[str, Any], limit: int = 100) -> List[T]:
"""查询节点
根据过滤条件查询节点列表。
Args:
filters: 查询条件字典,键为属性名,值为期望的值
limit: 返回结果的最大数量
Returns:
List[T]: 符合条件的实体列表
Example:
>>> results = await repository.find(
... {"group_id": "group_123", "user_id": "user_456"},
... limit=50
... )
"""
# 构建查询条件
where_clauses = [f"n.{key} = ${key}" for key in filters.keys()]
where_str = " AND ".join(where_clauses) if where_clauses else "1=1"
query = f"""
MATCH (n:{self.node_label})
WHERE {where_str}
RETURN n
LIMIT $limit
"""
results = await self.connector.execute_query(
query,
limit=limit,
**filters
)
return [self._map_to_entity(r) for r in results]
def _map_to_entity(self, node_data: Dict) -> T:
"""将节点数据映射为实体对象
这是一个抽象方法,子类必须实现具体的映射逻辑。
Args:
node_data: 从Neo4j查询返回的节点数据字典
Returns:
T: 映射后的实体对象
Raises:
NotImplementedError: 如果子类未实现此方法
"""
raise NotImplementedError("Subclasses must implement _map_to_entity method")