refactor(memory): extract memory count sync utility

- Add shared utility for syncing end user memory_count from Neo4j
This commit is contained in:
miao
2026-04-29 18:35:49 +08:00
parent f86c023477
commit 80902eb79a
3 changed files with 91 additions and 96 deletions

View File

@@ -20,6 +20,7 @@ from app.core.memory.storage_services.extraction_engine.knowledge_extraction.mem
memory_summary_generation
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.memory.utils.log.logging_utils import log_time
from app.core.memory.utils.memory_count_utils import sync_end_user_memory_count_from_neo4j
from app.db import get_db_context
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
@@ -313,8 +314,27 @@ async def write(
except Exception as cache_err:
logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
#同步neo4j记忆节点总数到pgsqlend_user表的memory_count字段
await _sync_memory_count_after_write(end_user_id)
# 同步 Neo4j 记忆节点总数到 PostgreSQL end_users.memory_count
if end_user_id:
try:
memory_count_connector = Neo4jConnector()
try:
node_count = await sync_end_user_memory_count_from_neo4j(
end_user_id,
memory_count_connector,
)
finally:
await memory_count_connector.close()
logger.info(
f"[MemoryCount] 写入后同步 memory_count: "
f"end_user_id={end_user_id}, count={node_count}"
)
except Exception as e:
logger.warning(
f"[MemoryCount] 写入后同步 memory_count 失败(不影响主流程): {e}",
exc_info=True,
)
# Close LLM/Embedder underlying httpx clients to prevent
# 'RuntimeError: Event loop is closed' during garbage collection
@@ -335,48 +355,3 @@ async def write(
logger.info("=== Pipeline Complete ===")
logger.info(f"Total execution time: {total_time:.2f} seconds")
async def _sync_memory_count_after_write(end_user_id: str) -> None:
"""
记忆写入完成后,查 Neo4j 全量节点数,绝对值同步到 PostgreSQL end_user 表的 memory_count 字段
不使用增量累加:
- Neo4j 写入使用 MERGE 语义,节点列表长度不等于新增节点数。
- 重试或重复写入可能匹配已有节点。
- 绝对值覆盖可以避免越加越大的计数漂移。
"""
if not end_user_id:
return
try:
from app.models.end_user_model import EndUser
from app.repositories.memory_config_repository import MemoryConfigRepository
connector = Neo4jConnector()
try:
result = await connector.execute_query(
MemoryConfigRepository.SEARCH_FOR_ALL_BATCH,
end_user_ids=[end_user_id],
)
node_count = int(result[0]["total"]) if result else 0
finally:
await connector.close()
with get_db_context() as db:
db.query(EndUser).filter(
EndUser.id == uuid.UUID(end_user_id)
).update(
{"memory_count": node_count},
synchronize_session=False,
)
db.commit()
logger.info(
f"[MemoryCount] 写入后同步 memory_count: "
f"end_user_id={end_user_id}, count={node_count}"
)
except Exception as e:
logger.warning(
f"[MemoryCount] 写入后同步 memory_count 失败(不影响主流程): {e}",
exc_info=True,
)

View File

@@ -20,6 +20,7 @@ from uuid import UUID
from datetime import datetime
from app.core.memory.storage_services.forgetting_engine.forgetting_strategy import ForgettingStrategy
from app.core.memory.utils.memory_count_utils import sync_end_user_memory_count_from_neo4j
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
@@ -145,8 +146,22 @@ class ForgettingScheduler:
}
logger.info("没有可遗忘的节点对,遗忘周期结束")
# 同步 Neo4j 记忆节点总数到 PostgreSQL的 end_user 表的 memory_count 字段
await self._sync_memory_count_to_db(end_user_id)
# 同步 Neo4j 记忆节点总数到 PostgreSQL 的 end_users.memory_count
if end_user_id:
try:
node_count = await sync_end_user_memory_count_from_neo4j(
end_user_id,
self.connector,
)
logger.info(
f"[MemoryCount] 遗忘后同步 memory_count: "
f"end_user_id={end_user_id}, count={node_count}"
)
except Exception as e:
logger.warning(
f"[MemoryCount] 遗忘后同步 memory_count 失败(不影响主流程): {e}",
exc_info=True,
)
return report
# 步骤3按激活值排序激活值最低的优先
@@ -303,8 +318,22 @@ class ForgettingScheduler:
f"({reduction_rate:.2%}), "
f"耗时 {duration:.2f}"
)
# 同步 Neo4j 记忆节点总数到 PostgreSQL的 end_user 表的 memory_count 字段
await self._sync_memory_count_to_db(end_user_id)
# 同步 Neo4j 记忆节点总数到 PostgreSQL 的 end_users.memory_count
if end_user_id:
try:
node_count = await sync_end_user_memory_count_from_neo4j(
end_user_id,
self.connector,
)
logger.info(
f"[MemoryCount] 遗忘后同步 memory_count: "
f"end_user_id={end_user_id}, count={node_count}"
)
except Exception as e:
logger.warning(
f"[MemoryCount] 遗忘后同步 memory_count 失败(不影响主流程): {e}",
exc_info=True,
)
return report
except Exception as e:
@@ -352,48 +381,3 @@ class ForgettingScheduler:
if results:
return results[0]['total']
return 0
async def _sync_memory_count_to_db(
self,
end_user_id: Optional[str] = None,
) -> None:
"""
遗忘周期结束后,用 SEARCH_FOR_ALL_BATCH 口径查全量节点数,
同步到 PostgreSQL end_users.memory_count。
不复用 _count_knowledge_nodes
- _count_knowledge_nodes 只统计 Statement、ExtractedEntity、MemorySummary。
- 宿主列表需要统计该 end_user_id 下全部 Neo4j 节点。
"""
if not end_user_id:
return
try:
from app.db import get_db_context
from app.models.end_user_model import EndUser
from app.repositories.memory_config_repository import MemoryConfigRepository
result = await self.connector.execute_query(
MemoryConfigRepository.SEARCH_FOR_ALL_BATCH,
end_user_ids=[end_user_id],
)
node_count = int(result[0]["total"]) if result else 0
with get_db_context() as db:
db.query(EndUser).filter(
EndUser.id == UUID(end_user_id)
).update(
{"memory_count": node_count},
synchronize_session=False,
)
db.commit()
logger.info(
f"[MemoryCount] 遗忘后同步 memory_count: "
f"end_user_id={end_user_id}, count={node_count}"
)
except Exception as e:
logger.warning(
f"[MemoryCount] 遗忘后同步 memory_count 失败(不影响主流程): {e}",
exc_info=True,
)

View File

@@ -0,0 +1,36 @@
from uuid import UUID
from app.db import get_db_context
from app.models.end_user_model import EndUser
from app.repositories.memory_config_repository import MemoryConfigRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
async def sync_end_user_memory_count_from_neo4j(
end_user_id: str,
connector: Neo4jConnector,
) -> int:
"""
Sync one end user's Neo4j memory node count to PostgreSQL.
The caller owns the Neo4j connector lifecycle.
"""
if not end_user_id:
return 0
result = await connector.execute_query(
MemoryConfigRepository.SEARCH_FOR_ALL_BATCH,
end_user_ids=[end_user_id],
)
node_count = int(result[0]["total"]) if result else 0
with get_db_context() as db:
db.query(EndUser).filter(
EndUser.id == UUID(end_user_id)
).update(
{"memory_count": node_count},
synchronize_session=False,
)
db.commit()
return node_count