Initial commit
This commit is contained in:
175
app/repositories/neo4j/base_neo4j_repository.py
Normal file
175
app/repositories/neo4j/base_neo4j_repository.py
Normal file
@@ -0,0 +1,175 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Neo4j仓储基类模块
|
||||
|
||||
本模块提供Neo4j仓储的基类实现,封装了通用的Neo4j节点操作。
|
||||
|
||||
Classes:
|
||||
BaseNeo4jRepository: Neo4j仓储基类,实现通用的CRUD操作
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Dict, Any, TypeVar
|
||||
from app.repositories.base_repository import BaseRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
class BaseNeo4jRepository(BaseRepository[T]):
|
||||
"""Neo4j仓储基类 - 实现通用的Neo4j节点操作
|
||||
|
||||
这个基类封装了Neo4j节点的通用CRUD操作,子类只需要实现
|
||||
特定的映射逻辑和业务查询方法。
|
||||
|
||||
Attributes:
|
||||
connector: Neo4j连接器实例
|
||||
node_label: 节点标签(如"Dialogue", "Statement"等)
|
||||
|
||||
Type Parameters:
|
||||
T: 实体类型,通常是Pydantic模型
|
||||
"""
|
||||
|
||||
def __init__(self, connector: Neo4jConnector, node_label: str):
|
||||
"""初始化Neo4j仓储
|
||||
|
||||
Args:
|
||||
connector: Neo4j连接器实例
|
||||
node_label: 节点标签,用于Cypher查询
|
||||
"""
|
||||
self.connector = connector
|
||||
self.node_label = node_label
|
||||
|
||||
async def create(self, entity: T) -> T:
|
||||
"""创建节点
|
||||
|
||||
将实体对象转换为Neo4j节点并保存到数据库。
|
||||
|
||||
Args:
|
||||
entity: 要创建的实体对象
|
||||
|
||||
Returns:
|
||||
T: 创建后的实体对象
|
||||
|
||||
Example:
|
||||
>>> dialog = DialogueNode(id="123", name="对话1", ...)
|
||||
>>> created = await repository.create(dialog)
|
||||
"""
|
||||
query = f"""
|
||||
CREATE (n:{self.node_label} $props)
|
||||
RETURN n
|
||||
"""
|
||||
result = await self.connector.execute_query(
|
||||
query,
|
||||
props=entity.model_dump()
|
||||
)
|
||||
return entity
|
||||
|
||||
async def get_by_id(self, entity_id: str) -> Optional[T]:
|
||||
"""根据ID获取节点
|
||||
|
||||
Args:
|
||||
entity_id: 节点ID
|
||||
|
||||
Returns:
|
||||
Optional[T]: 找到的实体对象,如果不存在则返回None
|
||||
"""
|
||||
query = f"""
|
||||
MATCH (n:{self.node_label} {{id: $id}})
|
||||
RETURN n
|
||||
"""
|
||||
result = await self.connector.execute_query(query, id=entity_id)
|
||||
if result:
|
||||
return self._map_to_entity(result[0])
|
||||
return None
|
||||
|
||||
async def update(self, entity: T) -> T:
|
||||
"""更新节点
|
||||
|
||||
更新现有节点的属性。使用SET +=语法合并属性。
|
||||
|
||||
Args:
|
||||
entity: 要更新的实体对象(必须包含id字段)
|
||||
|
||||
Returns:
|
||||
T: 更新后的实体对象
|
||||
"""
|
||||
query = f"""
|
||||
MATCH (n:{self.node_label} {{id: $id}})
|
||||
SET n += $props
|
||||
RETURN n
|
||||
"""
|
||||
await self.connector.execute_query(
|
||||
query,
|
||||
id=entity.id,
|
||||
props=entity.model_dump()
|
||||
)
|
||||
return entity
|
||||
|
||||
async def delete(self, entity_id: str) -> bool:
|
||||
"""删除节点
|
||||
|
||||
删除指定ID的节点。使用DETACH DELETE同时删除相关的边。
|
||||
|
||||
Args:
|
||||
entity_id: 要删除的节点ID
|
||||
|
||||
Returns:
|
||||
bool: 删除成功返回True,否则返回False
|
||||
"""
|
||||
query = f"""
|
||||
MATCH (n:{self.node_label} {{id: $id}})
|
||||
DETACH DELETE n
|
||||
RETURN count(n) as deleted
|
||||
"""
|
||||
result = await self.connector.execute_query(query, id=entity_id)
|
||||
return result[0]['deleted'] > 0 if result else False
|
||||
|
||||
async def find(self, filters: Dict[str, Any], limit: int = 100) -> List[T]:
|
||||
"""查询节点
|
||||
|
||||
根据过滤条件查询节点列表。
|
||||
|
||||
Args:
|
||||
filters: 查询条件字典,键为属性名,值为期望的值
|
||||
limit: 返回结果的最大数量
|
||||
|
||||
Returns:
|
||||
List[T]: 符合条件的实体列表
|
||||
|
||||
Example:
|
||||
>>> results = await repository.find(
|
||||
... {"group_id": "group_123", "user_id": "user_456"},
|
||||
... limit=50
|
||||
... )
|
||||
"""
|
||||
# 构建查询条件
|
||||
where_clauses = [f"n.{key} = ${key}" for key in filters.keys()]
|
||||
where_str = " AND ".join(where_clauses) if where_clauses else "1=1"
|
||||
|
||||
query = f"""
|
||||
MATCH (n:{self.node_label})
|
||||
WHERE {where_str}
|
||||
RETURN n
|
||||
LIMIT $limit
|
||||
"""
|
||||
results = await self.connector.execute_query(
|
||||
query,
|
||||
limit=limit,
|
||||
**filters
|
||||
)
|
||||
return [self._map_to_entity(r) for r in results]
|
||||
|
||||
def _map_to_entity(self, node_data: Dict) -> T:
|
||||
"""将节点数据映射为实体对象
|
||||
|
||||
这是一个抽象方法,子类必须实现具体的映射逻辑。
|
||||
|
||||
Args:
|
||||
node_data: 从Neo4j查询返回的节点数据字典
|
||||
|
||||
Returns:
|
||||
T: 映射后的实体对象
|
||||
|
||||
Raises:
|
||||
NotImplementedError: 如果子类未实现此方法
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement _map_to_entity method")
|
||||
Reference in New Issue
Block a user