refactor(memory): extract memory count sync utility
- Add shared utility for syncing end user memory_count from Neo4j
This commit is contained in:
@@ -20,6 +20,7 @@ from app.core.memory.storage_services.extraction_engine.knowledge_extraction.mem
|
|||||||
memory_summary_generation
|
memory_summary_generation
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
from app.core.memory.utils.log.logging_utils import log_time
|
from app.core.memory.utils.log.logging_utils import log_time
|
||||||
|
from app.core.memory.utils.memory_count_utils import sync_end_user_memory_count_from_neo4j
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
||||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||||
@@ -313,8 +314,27 @@ async def write(
|
|||||||
except Exception as cache_err:
|
except Exception as cache_err:
|
||||||
logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
|
logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
|
||||||
|
|
||||||
#同步neo4j记忆节点总数到pgsql,end_user表的memory_count字段
|
# 同步 Neo4j 记忆节点总数到 PostgreSQL end_users.memory_count
|
||||||
await _sync_memory_count_after_write(end_user_id)
|
if end_user_id:
|
||||||
|
try:
|
||||||
|
memory_count_connector = Neo4jConnector()
|
||||||
|
try:
|
||||||
|
node_count = await sync_end_user_memory_count_from_neo4j(
|
||||||
|
end_user_id,
|
||||||
|
memory_count_connector,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
await memory_count_connector.close()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[MemoryCount] 写入后同步 memory_count: "
|
||||||
|
f"end_user_id={end_user_id}, count={node_count}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"[MemoryCount] 写入后同步 memory_count 失败(不影响主流程): {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
# Close LLM/Embedder underlying httpx clients to prevent
|
# Close LLM/Embedder underlying httpx clients to prevent
|
||||||
# 'RuntimeError: Event loop is closed' during garbage collection
|
# 'RuntimeError: Event loop is closed' during garbage collection
|
||||||
@@ -335,48 +355,3 @@ async def write(
|
|||||||
logger.info("=== Pipeline Complete ===")
|
logger.info("=== Pipeline Complete ===")
|
||||||
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
||||||
|
|
||||||
|
|
||||||
async def _sync_memory_count_after_write(end_user_id: str) -> None:
    """Overwrite an end user's ``memory_count`` with the Neo4j total after a write.

    Runs ``MemoryConfigRepository.SEARCH_FOR_ALL_BATCH`` for the single
    ``end_user_id`` on a connector created just for this call, then stores the
    ``total`` of the first result row (0 when the query returns nothing) in
    the ``memory_count`` field of the ``EndUser`` row whose id matches.

    The count is written as an absolute value rather than incremented.
    Rationale carried over from the original author's notes:
    - Neo4j writes use MERGE semantics, so the length of the written node
      list is not the number of newly created nodes.
    - Retries or repeated writes may match nodes that already exist.
    - Overwriting with the absolute value avoids a counter that keeps
      drifting upward.

    Args:
        end_user_id: End user id as a UUID string. A falsy value makes the
            call a no-op.

    Returns:
        None. Any exception raised along the way is logged as a warning and
        swallowed, so it does not propagate to the caller.
    """
    if not end_user_id:
        return

    try:
        # Imported inside the function, exactly as the original code does.
        from app.models.end_user_model import EndUser
        from app.repositories.memory_config_repository import MemoryConfigRepository

        neo4j = Neo4jConnector()
        try:
            rows = await neo4j.execute_query(
                MemoryConfigRepository.SEARCH_FOR_ALL_BATCH,
                end_user_ids=[end_user_id],
            )
            if rows:
                node_count = int(rows[0]["total"])
            else:
                node_count = 0
        finally:
            # The connector is closed whether or not the query succeeded.
            await neo4j.close()

        new_values = {"memory_count": node_count}
        with get_db_context() as db:
            matching_user = db.query(EndUser).filter(
                EndUser.id == uuid.UUID(end_user_id)
            )
            matching_user.update(new_values, synchronize_session=False)
            db.commit()

        logger.info(
            f"[MemoryCount] 写入后同步 memory_count: "
            f"end_user_id={end_user_id}, count={node_count}"
        )
    except Exception as e:
        # Best effort: a failed sync is reported but never raised.
        logger.warning(
            f"[MemoryCount] 写入后同步 memory_count 失败(不影响主流程): {e}",
            exc_info=True,
        )
|
|
||||||
@@ -20,6 +20,7 @@ from uuid import UUID
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from app.core.memory.storage_services.forgetting_engine.forgetting_strategy import ForgettingStrategy
|
from app.core.memory.storage_services.forgetting_engine.forgetting_strategy import ForgettingStrategy
|
||||||
|
from app.core.memory.utils.memory_count_utils import sync_end_user_memory_count_from_neo4j
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
|
|
||||||
|
|
||||||
@@ -145,8 +146,22 @@ class ForgettingScheduler:
|
|||||||
}
|
}
|
||||||
|
|
||||||
logger.info("没有可遗忘的节点对,遗忘周期结束")
|
logger.info("没有可遗忘的节点对,遗忘周期结束")
|
||||||
# 同步 Neo4j 记忆节点总数到 PostgreSQL的 end_user 表的 memory_count 字段
|
# 同步 Neo4j 记忆节点总数到 PostgreSQL 的 end_users.memory_count
|
||||||
await self._sync_memory_count_to_db(end_user_id)
|
if end_user_id:
|
||||||
|
try:
|
||||||
|
node_count = await sync_end_user_memory_count_from_neo4j(
|
||||||
|
end_user_id,
|
||||||
|
self.connector,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"[MemoryCount] 遗忘后同步 memory_count: "
|
||||||
|
f"end_user_id={end_user_id}, count={node_count}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"[MemoryCount] 遗忘后同步 memory_count 失败(不影响主流程): {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
return report
|
return report
|
||||||
|
|
||||||
# 步骤3:按激活值排序(激活值最低的优先)
|
# 步骤3:按激活值排序(激活值最低的优先)
|
||||||
@@ -303,8 +318,22 @@ class ForgettingScheduler:
|
|||||||
f"({reduction_rate:.2%}), "
|
f"({reduction_rate:.2%}), "
|
||||||
f"耗时 {duration:.2f} 秒"
|
f"耗时 {duration:.2f} 秒"
|
||||||
)
|
)
|
||||||
# 同步 Neo4j 记忆节点总数到 PostgreSQL的 end_user 表的 memory_count 字段
|
# 同步 Neo4j 记忆节点总数到 PostgreSQL 的 end_users.memory_count
|
||||||
await self._sync_memory_count_to_db(end_user_id)
|
if end_user_id:
|
||||||
|
try:
|
||||||
|
node_count = await sync_end_user_memory_count_from_neo4j(
|
||||||
|
end_user_id,
|
||||||
|
self.connector,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"[MemoryCount] 遗忘后同步 memory_count: "
|
||||||
|
f"end_user_id={end_user_id}, count={node_count}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"[MemoryCount] 遗忘后同步 memory_count 失败(不影响主流程): {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
return report
|
return report
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -352,48 +381,3 @@ class ForgettingScheduler:
|
|||||||
if results:
|
if results:
|
||||||
return results[0]['total']
|
return results[0]['total']
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
async def _sync_memory_count_to_db(
    self,
    end_user_id: Optional[str] = None,
) -> None:
    """Overwrite an end user's ``memory_count`` with the Neo4j total after forgetting.

    Runs ``MemoryConfigRepository.SEARCH_FOR_ALL_BATCH`` for the single
    ``end_user_id`` through ``self.connector``, then stores the ``total`` of
    the first result row (0 when the query returns nothing) in the
    ``memory_count`` field of the ``EndUser`` row whose id matches.

    ``_count_knowledge_nodes`` is deliberately not reused. Rationale carried
    over from the original author's notes:
    - ``_count_knowledge_nodes`` only counts Statement, ExtractedEntity and
      MemorySummary nodes.
    - The host list needs the count of all Neo4j nodes under this
      ``end_user_id``.

    Args:
        end_user_id: End user id as a UUID string. ``None`` or any other
            falsy value makes the call a no-op.

    Returns:
        None. Any exception raised along the way is logged as a warning and
        swallowed, so it does not propagate to the caller.
    """
    if not end_user_id:
        return

    try:
        # Imported inside the method, exactly as the original code does.
        from app.db import get_db_context
        from app.models.end_user_model import EndUser
        from app.repositories.memory_config_repository import MemoryConfigRepository

        rows = await self.connector.execute_query(
            MemoryConfigRepository.SEARCH_FOR_ALL_BATCH,
            end_user_ids=[end_user_id],
        )
        if rows:
            node_count = int(rows[0]["total"])
        else:
            node_count = 0

        new_values = {"memory_count": node_count}
        with get_db_context() as db:
            matching_user = db.query(EndUser).filter(
                EndUser.id == UUID(end_user_id)
            )
            matching_user.update(new_values, synchronize_session=False)
            db.commit()

        logger.info(
            f"[MemoryCount] 遗忘后同步 memory_count: "
            f"end_user_id={end_user_id}, count={node_count}"
        )
    except Exception as e:
        # Best effort: a failed sync is reported but never raised.
        logger.warning(
            f"[MemoryCount] 遗忘后同步 memory_count 失败(不影响主流程): {e}",
            exc_info=True,
        )
|
|
||||||
|
|||||||
36
api/app/core/memory/utils/memory_count_utils.py
Normal file
36
api/app/core/memory/utils/memory_count_utils.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from app.db import get_db_context
|
||||||
|
from app.models.end_user_model import EndUser
|
||||||
|
from app.repositories.memory_config_repository import MemoryConfigRepository
|
||||||
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
|
|
||||||
|
|
||||||
|
async def sync_end_user_memory_count_from_neo4j(
    end_user_id: str,
    connector: Neo4jConnector,
) -> int:
    """Copy one end user's Neo4j memory node count into PostgreSQL.

    Runs ``MemoryConfigRepository.SEARCH_FOR_ALL_BATCH`` for the single
    ``end_user_id`` and writes the ``total`` of the first result row (0 when
    the query returns nothing) to the ``memory_count`` field of the
    ``EndUser`` row whose id matches.

    The connector is only used here, never closed: the caller owns the Neo4j
    connector lifecycle. Nothing is caught in this function, so query, UUID
    parsing and database errors propagate to the caller.

    Args:
        end_user_id: End user id as a UUID string. A falsy value
            short-circuits before any query is issued.
        connector: Neo4j connector used to execute the count query.

    Returns:
        The node count that was written, or 0 when ``end_user_id`` is falsy.
    """
    if not end_user_id:
        return 0

    rows = await connector.execute_query(
        MemoryConfigRepository.SEARCH_FOR_ALL_BATCH,
        end_user_ids=[end_user_id],
    )
    if rows:
        total = int(rows[0]["total"])
    else:
        total = 0

    with get_db_context() as db:
        matching_user = db.query(EndUser).filter(
            EndUser.id == UUID(end_user_id)
        )
        matching_user.update({"memory_count": total}, synchronize_session=False)
        db.commit()

    return total
|
||||||
Reference in New Issue
Block a user