[Changes]

This commit is contained in:
lanceyq
2026-03-16 14:05:12 +08:00
parent 6d79db8ba3
commit f32d92b9d0
5 changed files with 49 additions and 25 deletions

View File

@@ -176,24 +176,24 @@ class SearchService:
r.get("id") for r in community_results if r.get("id") r.get("id") for r in community_results if r.get("id")
] ]
if community_ids and end_user_id: if community_ids and end_user_id:
from app.repositories.neo4j.graph_search import search_graph_community_expand
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
expand_connector = Neo4jConnector()
try: try:
from app.repositories.neo4j.graph_search import search_graph_community_expand
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
connector = Neo4jConnector()
expand_result = await search_graph_community_expand( expand_result = await search_graph_community_expand(
connector=connector, connector=expand_connector,
community_ids=community_ids, community_ids=community_ids,
end_user_id=end_user_id, end_user_id=end_user_id,
limit=10, limit=10,
) )
await connector.close()
expanded_stmts = expand_result.get("expanded_statements", []) expanded_stmts = expand_result.get("expanded_statements", [])
if expanded_stmts: if expanded_stmts:
# 展开的 statements 插入 communities 之后、statements 之前
answer_list.extend(expanded_stmts) answer_list.extend(expanded_stmts)
logger.info(f"社区展开检索追加 {len(expanded_stmts)} 条 statements") logger.info(f"社区展开检索追加 {len(expanded_stmts)} 条 statements")
except Exception as e: except Exception as e:
logger.warning(f"社区展开检索失败,跳过: {e}") logger.warning(f"社区展开检索失败,跳过: {e}")
finally:
await expand_connector.close()
# Extract clean content from all results # Extract clean content from all results
content_list = [ content_list = [

View File

@@ -178,13 +178,6 @@ async def write(
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None, llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None, embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
) )
# 写入成功后,异步触发聚类(不阻塞写入响应)
schedule_clustering_after_write(
all_entity_nodes,
config_id=config_id,
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
)
break break
else: else:
logger.warning("Failed to save some data to Neo4j") logger.warning("Failed to save some data to Neo4j")

View File

@@ -116,23 +116,19 @@ class LabelPropagationEngine:
""" """
BATCH_SIZE = 2000 # 每批实体数,可按需调整 BATCH_SIZE = 2000 # 每批实体数,可按需调整
# 先查总数,决定批次数 # 轻量查询:只获取总数和 ID 列表,不加载 embedding 等大字段
total_entities = await self.repo.get_all_entities(end_user_id) total_count = await self.repo.get_entity_count(end_user_id)
if not total_entities: if not total_count:
logger.info(f"[Clustering] 用户 {end_user_id} 无实体,跳过全量聚类") logger.info(f"[Clustering] 用户 {end_user_id} 无实体,跳过全量聚类")
return return
total_count = len(total_entities) all_entity_ids = await self.repo.get_all_entity_ids(end_user_id)
logger.info(f"[Clustering] 用户 {end_user_id}{total_count} 个实体," logger.info(f"[Clustering] 用户 {end_user_id}{total_count} 个实体,"
f"分批大小 {BATCH_SIZE},共 {(total_count + BATCH_SIZE - 1) // BATCH_SIZE}") f"分批大小 {BATCH_SIZE},共 {(total_count + BATCH_SIZE - 1) // BATCH_SIZE}")
# labels 跨批次共享:先用全量数据初始化(只存 id内存极小 # labels 跨批次共享:只存 id→community_id内存极小
labels: Dict[str, str] = {e["id"]: e["id"] for e in total_entities} labels: Dict[str, str] = {eid: eid for eid in all_entity_ids}
# embeddings 也跨批次共享(每个向量 ~6KB10万实体约 600MB这是不可避免的 del all_entity_ids # 释放 ID 列表,后续按批次加载完整数据
# 但只在当前批次的实体需要时才保留,其余批次的 embedding 不常驻
# 实际上 embeddings 只在 _weighted_vote 中用于计算 self_embedding
# 所以只需要当前批次实体的 embedding不需要全量
del total_entities # 释放全量列表,后续按批次加载
for batch_start in range(0, total_count, BATCH_SIZE): for batch_start in range(0, total_count, BATCH_SIZE):
batch_entities = await self.repo.get_entities_page( batch_entities = await self.repo.get_entities_page(

View File

@@ -13,6 +13,8 @@ from app.repositories.neo4j.cypher_queries import (
ENTITY_LEAVE_ALL_COMMUNITIES, ENTITY_LEAVE_ALL_COMMUNITIES,
GET_ENTITY_NEIGHBORS, GET_ENTITY_NEIGHBORS,
GET_ALL_ENTITIES_FOR_USER, GET_ALL_ENTITIES_FOR_USER,
GET_ENTITY_COUNT_FOR_USER,
GET_ALL_ENTITY_IDS_FOR_USER,
GET_ENTITIES_PAGE, GET_ENTITIES_PAGE,
GET_COMMUNITY_MEMBERS, GET_COMMUNITY_MEMBERS,
GET_ALL_COMMUNITY_MEMBERS_BATCH, GET_ALL_COMMUNITY_MEMBERS_BATCH,
@@ -21,7 +23,6 @@ from app.repositories.neo4j.cypher_queries import (
CHECK_USER_HAS_COMMUNITIES, CHECK_USER_HAS_COMMUNITIES,
UPDATE_COMMUNITY_MEMBER_COUNT, UPDATE_COMMUNITY_MEMBER_COUNT,
UPDATE_COMMUNITY_METADATA, UPDATE_COMMUNITY_METADATA,
UPDATE_COMMUNITY_METADATA,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -113,6 +114,30 @@ class CommunityRepository:
logger.error(f"get_all_entities failed: {e}") logger.error(f"get_all_entities failed: {e}")
return [] return []
async def get_entity_count(self, end_user_id: str) -> int:
"""仅返回用户实体总数,不加载实体数据。"""
try:
result = await self.connector.execute_query(
GET_ENTITY_COUNT_FOR_USER,
end_user_id=end_user_id,
)
return result[0]["entity_count"] if result else 0
except Exception as e:
logger.error(f"get_entity_count failed: {e}")
return 0
async def get_all_entity_ids(self, end_user_id: str) -> List[str]:
"""仅返回用户所有实体 ID 列表,不加载 embedding 等大字段。"""
try:
result = await self.connector.execute_query(
GET_ALL_ENTITY_IDS_FOR_USER,
end_user_id=end_user_id,
)
return [r["id"] for r in result]
except Exception as e:
logger.error(f"get_all_entity_ids failed: {e}")
return []
async def get_entities_page( async def get_entities_page(
self, end_user_id: str, skip: int, limit: int self, end_user_id: str, skip: int, limit: int
) -> List[Dict]: ) -> List[Dict]:

View File

@@ -1122,6 +1122,16 @@ RETURN e.id AS id,
CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id
""" """
GET_ENTITY_COUNT_FOR_USER = """
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})
RETURN count(e) AS entity_count
"""
GET_ALL_ENTITY_IDS_FOR_USER = """
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})
RETURN e.id AS id
"""
GET_COMMUNITY_MEMBERS = """ GET_COMMUNITY_MEMBERS = """
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community {community_id: $community_id}) MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community {community_id: $community_id})
RETURN e.id AS id, e.name AS name, e.entity_type AS entity_type, RETURN e.id AS id, e.name AS name, e.entity_type AS entity_type,