diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index 9b0be9c8..1dcc73b2 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -20,6 +20,7 @@ from app.core.memory.storage_services.extraction_engine.knowledge_extraction.mem memory_summary_generation from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.log.logging_utils import log_time +from app.core.memory.utils.memory_count_utils import sync_end_user_memory_count_from_neo4j from app.db import get_db_context from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges from app.repositories.neo4j.add_nodes import add_memory_summary_nodes @@ -313,8 +314,27 @@ async def write( except Exception as cache_err: logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True) - #同步neo4j记忆节点总数到pgsql,end_user表的memory_count字段 - await _sync_memory_count_after_write(end_user_id) + # 同步 Neo4j 记忆节点总数到 PostgreSQL end_users.memory_count + if end_user_id: + try: + memory_count_connector = Neo4jConnector() + try: + node_count = await sync_end_user_memory_count_from_neo4j( + end_user_id, + memory_count_connector, + ) + finally: + await memory_count_connector.close() + + logger.info( + f"[MemoryCount] 写入后同步 memory_count: " + f"end_user_id={end_user_id}, count={node_count}" + ) + except Exception as e: + logger.warning( + f"[MemoryCount] 写入后同步 memory_count 失败(不影响主流程): {e}", + exc_info=True, + ) # Close LLM/Embedder underlying httpx clients to prevent # 'RuntimeError: Event loop is closed' during garbage collection @@ -335,48 +355,3 @@ async def write( logger.info("=== Pipeline Complete ===") logger.info(f"Total execution time: {total_time:.2f} seconds") - -async def _sync_memory_count_after_write(end_user_id: str) -> None: - """ - 记忆写入完成后,查 Neo4j 全量节点数,绝对值同步到 PostgreSQL end_user 表的 memory_count 字段 - - 不使用增量累加: - - Neo4j 写入使用 MERGE 语义,节点列表长度不等于新增节点数。 - - 重试或重复写入可能匹配已有节点。 - - 绝对值覆盖可以避免越加越大的计数漂移。 - """ - if not end_user_id: - return - - try: - from app.models.end_user_model import EndUser - from app.repositories.memory_config_repository import MemoryConfigRepository - - connector = Neo4jConnector() - try: - result = await connector.execute_query( - MemoryConfigRepository.SEARCH_FOR_ALL_BATCH, - end_user_ids=[end_user_id], - ) - node_count = int(result[0]["total"]) if result else 0 - finally: - await connector.close() - - with get_db_context() as db: - db.query(EndUser).filter( - EndUser.id == uuid.UUID(end_user_id) - ).update( - {"memory_count": node_count}, - synchronize_session=False, - ) - db.commit() - - logger.info( - f"[MemoryCount] 写入后同步 memory_count: " - f"end_user_id={end_user_id}, count={node_count}" - ) - except Exception as e: - logger.warning( - f"[MemoryCount] 写入后同步 memory_count 失败(不影响主流程): {e}", - exc_info=True, - ) \ No newline at end of file diff --git a/api/app/core/memory/storage_services/forgetting_engine/forgetting_scheduler.py b/api/app/core/memory/storage_services/forgetting_engine/forgetting_scheduler.py index cad6a4db..39c9eed6 100644 --- a/api/app/core/memory/storage_services/forgetting_engine/forgetting_scheduler.py +++ b/api/app/core/memory/storage_services/forgetting_engine/forgetting_scheduler.py @@ -20,6 +20,7 @@ from uuid import UUID from datetime import datetime from app.core.memory.storage_services.forgetting_engine.forgetting_strategy import ForgettingStrategy +from app.core.memory.utils.memory_count_utils import sync_end_user_memory_count_from_neo4j from app.repositories.neo4j.neo4j_connector import Neo4jConnector @@ -145,8 +146,22 @@ class ForgettingScheduler: } logger.info("没有可遗忘的节点对,遗忘周期结束") - # 同步 Neo4j 记忆节点总数到 PostgreSQL的 end_user 表的 memory_count 字段 - await self._sync_memory_count_to_db(end_user_id) + # 同步 Neo4j 记忆节点总数到 PostgreSQL 的 end_users.memory_count + if end_user_id: + try: + node_count = await sync_end_user_memory_count_from_neo4j( + end_user_id, + self.connector, + ) + logger.info( + f"[MemoryCount] 遗忘后同步 memory_count: " + f"end_user_id={end_user_id}, count={node_count}" + ) + except Exception as e: + logger.warning( + f"[MemoryCount] 遗忘后同步 memory_count 失败(不影响主流程): {e}", + exc_info=True, + ) return report # 步骤3:按激活值排序(激活值最低的优先) @@ -303,8 +318,22 @@ class ForgettingScheduler: f"({reduction_rate:.2%}), " f"耗时 {duration:.2f} 秒" ) - # 同步 Neo4j 记忆节点总数到 PostgreSQL的 end_user 表的 memory_count 字段 - await self._sync_memory_count_to_db(end_user_id) + # 同步 Neo4j 记忆节点总数到 PostgreSQL 的 end_users.memory_count + if end_user_id: + try: + node_count = await sync_end_user_memory_count_from_neo4j( + end_user_id, + self.connector, + ) + logger.info( + f"[MemoryCount] 遗忘后同步 memory_count: " + f"end_user_id={end_user_id}, count={node_count}" + ) + except Exception as e: + logger.warning( + f"[MemoryCount] 遗忘后同步 memory_count 失败(不影响主流程): {e}", + exc_info=True, + ) return report except Exception as e: @@ -352,48 +381,3 @@ class ForgettingScheduler: if results: return results[0]['total'] return 0 - - async def _sync_memory_count_to_db( - self, - end_user_id: Optional[str] = None, - ) -> None: - """ - 遗忘周期结束后,用 SEARCH_FOR_ALL_BATCH 口径查全量节点数, - 同步到 PostgreSQL end_users.memory_count。 - - 不复用 _count_knowledge_nodes: - - _count_knowledge_nodes 只统计 Statement、ExtractedEntity、MemorySummary。 - - 宿主列表需要统计该 end_user_id 下全部 Neo4j 节点。 - """ - if not end_user_id: - return - - try: - from app.db import get_db_context - from app.models.end_user_model import EndUser - from app.repositories.memory_config_repository import MemoryConfigRepository - - result = await self.connector.execute_query( - MemoryConfigRepository.SEARCH_FOR_ALL_BATCH, - end_user_ids=[end_user_id], - ) - node_count = int(result[0]["total"]) if result else 0 - - with get_db_context() as db: - db.query(EndUser).filter( - EndUser.id == UUID(end_user_id) - ).update( - {"memory_count": node_count}, - synchronize_session=False, - ) - db.commit() - - logger.info( - f"[MemoryCount] 遗忘后同步 memory_count: " - f"end_user_id={end_user_id}, count={node_count}" - ) - except Exception as e: - logger.warning( - f"[MemoryCount] 遗忘后同步 memory_count 失败(不影响主流程): {e}", - exc_info=True, - ) diff --git a/api/app/core/memory/utils/memory_count_utils.py b/api/app/core/memory/utils/memory_count_utils.py new file mode 100644 index 00000000..316cb635 --- /dev/null +++ b/api/app/core/memory/utils/memory_count_utils.py @@ -0,0 +1,36 @@ +from uuid import UUID + +from app.db import get_db_context +from app.models.end_user_model import EndUser +from app.repositories.memory_config_repository import MemoryConfigRepository +from app.repositories.neo4j.neo4j_connector import Neo4jConnector + + +async def sync_end_user_memory_count_from_neo4j( + end_user_id: str, + connector: Neo4jConnector, +) -> int: + """ + Sync one end user's Neo4j memory node count to PostgreSQL. + + The caller owns the Neo4j connector lifecycle. + """ + if not end_user_id: + return 0 + + result = await connector.execute_query( + MemoryConfigRepository.SEARCH_FOR_ALL_BATCH, + end_user_ids=[end_user_id], + ) + node_count = int(result[0]["total"]) if result else 0 + + with get_db_context() as db: + db.query(EndUser).filter( + EndUser.id == UUID(end_user_id) + ).update( + {"memory_count": node_count}, + synchronize_session=False, + ) + db.commit() + + return node_count