diff --git a/api/app/core/memory/agent/services/search_service.py b/api/app/core/memory/agent/services/search_service.py index 2be18c97..c9346c16 100644 --- a/api/app/core/memory/agent/services/search_service.py +++ b/api/app/core/memory/agent/services/search_service.py @@ -176,24 +176,24 @@ class SearchService: r.get("id") for r in community_results if r.get("id") ] if community_ids and end_user_id: + from app.repositories.neo4j.graph_search import search_graph_community_expand + from app.repositories.neo4j.neo4j_connector import Neo4jConnector + expand_connector = Neo4jConnector() try: - from app.repositories.neo4j.graph_search import search_graph_community_expand - from app.repositories.neo4j.neo4j_connector import Neo4jConnector - connector = Neo4jConnector() expand_result = await search_graph_community_expand( - connector=connector, + connector=expand_connector, community_ids=community_ids, end_user_id=end_user_id, limit=10, ) - await connector.close() expanded_stmts = expand_result.get("expanded_statements", []) if expanded_stmts: - # 展开的 statements 插入 communities 之后、statements 之前 answer_list.extend(expanded_stmts) logger.info(f"社区展开检索追加 {len(expanded_stmts)} 条 statements") except Exception as e: logger.warning(f"社区展开检索失败,跳过: {e}") + finally: + await expand_connector.close() # Extract clean content from all results content_list = [ diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index df8752d6..02aa1b44 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -178,13 +178,6 @@ async def write( llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None, embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None, ) - # 写入成功后,异步触发聚类(不阻塞写入响应) - schedule_clustering_after_write( - all_entity_nodes, - config_id=config_id, - llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None, - embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None, - ) break else: logger.warning("Failed to save some data to Neo4j") diff --git a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py index b4a16734..46a7b8f3 100644 --- a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py +++ b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py @@ -116,23 +116,19 @@ class LabelPropagationEngine: """ BATCH_SIZE = 2000 # 每批实体数,可按需调整 - # 先查总数,决定批次数 - total_entities = await self.repo.get_all_entities(end_user_id) - if not total_entities: + # 轻量查询:只获取总数和 ID 列表,不加载 embedding 等大字段 + total_count = await self.repo.get_entity_count(end_user_id) + if not total_count: logger.info(f"[Clustering] 用户 {end_user_id} 无实体,跳过全量聚类") return - total_count = len(total_entities) + all_entity_ids = await self.repo.get_all_entity_ids(end_user_id) logger.info(f"[Clustering] 用户 {end_user_id} 共 {total_count} 个实体," f"分批大小 {BATCH_SIZE},共 {(total_count + BATCH_SIZE - 1) // BATCH_SIZE} 批") - # labels 跨批次共享:先用全量数据初始化(只存 id,内存极小) - labels: Dict[str, str] = {e["id"]: e["id"] for e in total_entities} - # embeddings 也跨批次共享(每个向量 ~6KB,10万实体约 600MB,这是不可避免的) - # 但只在当前批次的实体需要时才保留,其余批次的 embedding 不常驻 - # 实际上 embeddings 只在 _weighted_vote 中用于计算 self_embedding, - # 所以只需要当前批次实体的 embedding,不需要全量 - del total_entities # 释放全量列表,后续按批次加载 + # labels 跨批次共享:只存 id→community_id,内存极小 + labels: Dict[str, str] = {eid: eid for eid in all_entity_ids} + del all_entity_ids # 释放 ID 列表,后续按批次加载完整数据 for batch_start in range(0, total_count, BATCH_SIZE): batch_entities = await self.repo.get_entities_page( diff --git a/api/app/repositories/neo4j/community_repository.py b/api/app/repositories/neo4j/community_repository.py index bf7fde1d..267ced4f 100644 --- a/api/app/repositories/neo4j/community_repository.py +++ b/api/app/repositories/neo4j/community_repository.py @@ -13,6 +13,8 @@ from app.repositories.neo4j.cypher_queries import ( ENTITY_LEAVE_ALL_COMMUNITIES, GET_ENTITY_NEIGHBORS, GET_ALL_ENTITIES_FOR_USER, + GET_ENTITY_COUNT_FOR_USER, + GET_ALL_ENTITY_IDS_FOR_USER, GET_ENTITIES_PAGE, GET_COMMUNITY_MEMBERS, GET_ALL_COMMUNITY_MEMBERS_BATCH, @@ -21,7 +23,6 @@ from app.repositories.neo4j.cypher_queries import ( CHECK_USER_HAS_COMMUNITIES, UPDATE_COMMUNITY_MEMBER_COUNT, UPDATE_COMMUNITY_METADATA, - UPDATE_COMMUNITY_METADATA, ) logger = logging.getLogger(__name__) @@ -113,6 +114,30 @@ class CommunityRepository: logger.error(f"get_all_entities failed: {e}") return [] + async def get_entity_count(self, end_user_id: str) -> int: + """仅返回用户实体总数,不加载实体数据。""" + try: + result = await self.connector.execute_query( + GET_ENTITY_COUNT_FOR_USER, + end_user_id=end_user_id, + ) + return result[0]["entity_count"] if result else 0 + except Exception as e: + logger.error(f"get_entity_count failed: {e}") + return 0 + + async def get_all_entity_ids(self, end_user_id: str) -> List[str]: + """仅返回用户所有实体 ID 列表,不加载 embedding 等大字段。""" + try: + result = await self.connector.execute_query( + GET_ALL_ENTITY_IDS_FOR_USER, + end_user_id=end_user_id, + ) + return [r["id"] for r in result] + except Exception as e: + logger.error(f"get_all_entity_ids failed: {e}") + return [] + async def get_entities_page( self, end_user_id: str, skip: int, limit: int ) -> List[Dict]: diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index 339adb43..01f2fa6a 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -1122,6 +1122,16 @@ RETURN e.id AS id, CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id """ +GET_ENTITY_COUNT_FOR_USER = """ +MATCH (e:ExtractedEntity {end_user_id: $end_user_id}) +RETURN count(e) AS entity_count +""" + +GET_ALL_ENTITY_IDS_FOR_USER = """ +MATCH (e:ExtractedEntity {end_user_id: $end_user_id}) +RETURN e.id AS id +""" + GET_COMMUNITY_MEMBERS = """ MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community {community_id: $community_id}) RETURN e.id AS id, e.name AS name, e.entity_type AS entity_type,