Files
MemoryBear/api/app/repositories/neo4j/neo4j_connector.py
lanceyq 34387e1f23 fix(memory): add coalesce defaults for activation_value and related node properties
- Add coalesce fallbacks for importance_score, activation_value, and
access_count in statement, entity, weak entity, strong entity, and
memory summary MERGE queries to prevent null values on new nodes
- Set activation_value default to coalesce(importance_score, 0.5) for
consistency with the forgetting/activation scoring logic
- Suppress Neo4j UNRECOGNIZED property key notifications in driver init
since missing keys like last_access_time and activation_value are
expected for newly created nodes
2026-04-17 13:54:12 +08:00

175 lines
5.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""Neo4j连接器模块
本模块提供Neo4j图数据库的连接和查询功能。
从 app/core/memory/src/database/neo4j_connector.py 迁移而来。
Classes:
Neo4jConnector: Neo4j数据库连接器提供异步查询接口
"""
from typing import Any, List, Dict
import logging
from neo4j import AsyncGraphDatabase, basic_auth
from neo4j.time import DateTime as Neo4jDateTime, Date as Neo4jDate, Time as Neo4jTime, Duration as Neo4jDuration
from app.core.config import settings
logger = logging.getLogger(__name__)
def _convert_neo4j_types(value: Any) -> Any:
"""递归将 neo4j 原生时间类型转为 Python 原生类型 / ISO 字符串,确保可被 json.dumps 序列化。"""
if isinstance(value, Neo4jDateTime):
return value.to_native().isoformat() if value.tzinfo else value.iso_format()
if isinstance(value, Neo4jDate):
return value.iso_format()
if isinstance(value, Neo4jTime):
return value.iso_format()
if isinstance(value, Neo4jDuration):
return str(value)
if isinstance(value, dict):
return {k: _convert_neo4j_types(v) for k, v in value.items()}
if isinstance(value, list):
return [_convert_neo4j_types(item) for item in value]
return value
class Neo4jConnector:
"""Neo4j数据库连接器
提供与Neo4j图数据库的连接和查询功能。
使用异步驱动程序以支持高并发操作。
Attributes:
driver: Neo4j异步驱动程序实例
Methods:
close: 关闭数据库连接
execute_query: 执行Cypher查询
delete_group: 删除指定组的所有数据
"""
def __init__(self):
"""初始化Neo4j连接器
从配置文件和环境变量中读取连接信息。
Raises:
RuntimeError: 如果NEO4J_PASSWORD环境变量未设置
"""
# 从全局配置和环境变量获取 Neo4j 配置
uri = settings.NEO4J_URI
username = settings.NEO4J_USERNAME
password = settings.NEO4J_PASSWORD
if not password:
raise RuntimeError(
"NEO4J_PASSWORD is not set. Create a .env with NEO4J_PASSWORD or export it before running."
)
self.driver = AsyncGraphDatabase.driver(
uri,
auth=basic_auth(username, password),
# 抑制属性键不存在的 UNRECOGNIZED 分类通知警告(如 01N52
# last_access_time 等属性在节点被检索命中后才写入,
# activation_value 在新节点创建后可能尚未被计算,
# 全新数据库或清空数据后这些属性键不存在是正常业务行为
notifications_disabled_classifications=["UNRECOGNIZED"],
)
async def close(self):
"""关闭数据库连接
释放数据库连接资源。应在应用程序关闭时调用。
"""
await self.driver.close()
async def execute_query(self, cypher: str, json_format=False, **kwargs: Any) -> List[Dict[str, Any]]:
"""执行Cypher查询
Args:
cypher: Cypher查询语句
json_format: json格式化
**kwargs: 查询参数将作为参数传递给Cypher查询
Returns:
List[Dict[str, Any]]: 查询结果列表,每个元素是一个字典
Example:
"""
result = await self.driver.execute_query(
cypher,
database="neo4j",
**kwargs
)
records, summary, keys = result
if json_format:
return [_convert_neo4j_types(record.data()) for record in records]
else:
return [record.data() for record in records]
async def execute_write_transaction(self, transaction_func, **kwargs: Any) -> Any:
"""在写事务中执行操作
提供显式事务支持,确保操作的原子性。
如果事务函数抛出异常,所有更改将自动回滚。
Args:
transaction_func: 事务函数,接收 tx 参数并执行查询
**kwargs: 传递给事务函数的额外参数
Returns:
Any: 事务函数的返回值
Example:
"""
async with self.driver.session(database="neo4j") as session:
return await session.execute_write(transaction_func, **kwargs)
async def execute_read_transaction(self, transaction_func, **kwargs: Any) -> Any:
"""在读事务中执行操作
提供显式事务支持用于读操作。
Args:
transaction_func: 事务函数,接收 tx 参数并执行查询
**kwargs: 传递给事务函数的额外参数
Returns:
Any: 事务函数的返回值
Example:
"""
async with self.driver.session(database="neo4j") as session:
return await session.execute_read(transaction_func, **kwargs)
async def delete_group(self, end_user_id: str):
"""删除指定组的所有数据
删除所有属于指定end_user_id的节点和边。
这是一个危险操作,会永久删除数据。
Args:
end_user_id: 要删除的组ID
Example:
Group group_123 deleted.
"""
# 删除节点DETACH DELETE会同时删除相关的边
await self.driver.execute_query(
"MATCH (n) WHERE n.end_user_id = $end_user_id DETACH DELETE n",
database="neo4j",
end_user_id=end_user_id
)
# 删除独立的边(如果有的话)
await self.driver.execute_query(
"MATCH ()-[r]->() WHERE r.end_user_id = $end_user_id DELETE r",
database="neo4j",
end_user_id=end_user_id
)
print(f"Group {end_user_id} deleted.")