[Changes]

This commit is contained in:
lanceyq
2026-03-16 14:05:12 +08:00
parent 6d79db8ba3
commit f32d92b9d0
5 changed files with 49 additions and 25 deletions

View File

@@ -176,24 +176,24 @@ class SearchService:
r.get("id") for r in community_results if r.get("id")
]
if community_ids and end_user_id:
from app.repositories.neo4j.graph_search import search_graph_community_expand
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
expand_connector = Neo4jConnector()
try:
from app.repositories.neo4j.graph_search import search_graph_community_expand
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
connector = Neo4jConnector()
expand_result = await search_graph_community_expand(
connector=connector,
connector=expand_connector,
community_ids=community_ids,
end_user_id=end_user_id,
limit=10,
)
await connector.close()
expanded_stmts = expand_result.get("expanded_statements", [])
if expanded_stmts:
# 展开的 statements 插入 communities 之后、statements 之前
answer_list.extend(expanded_stmts)
logger.info(f"社区展开检索追加 {len(expanded_stmts)} 条 statements")
except Exception as e:
logger.warning(f"社区展开检索失败,跳过: {e}")
finally:
await expand_connector.close()
# Extract clean content from all results
content_list = [

View File

@@ -178,13 +178,6 @@ async def write(
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
)
# 写入成功后,异步触发聚类(不阻塞写入响应)
schedule_clustering_after_write(
all_entity_nodes,
config_id=config_id,
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
)
break
else:
logger.warning("Failed to save some data to Neo4j")

View File

@@ -116,23 +116,19 @@ class LabelPropagationEngine:
"""
BATCH_SIZE = 2000 # 每批实体数,可按需调整
# 先查总数,决定批次数
total_entities = await self.repo.get_all_entities(end_user_id)
if not total_entities:
# 轻量查询:只获取总数和 ID 列表,不加载 embedding 等大字段
total_count = await self.repo.get_entity_count(end_user_id)
if not total_count:
logger.info(f"[Clustering] 用户 {end_user_id} 无实体,跳过全量聚类")
return
total_count = len(total_entities)
all_entity_ids = await self.repo.get_all_entity_ids(end_user_id)
logger.info(f"[Clustering] 用户 {end_user_id}{total_count} 个实体,"
f"分批大小 {BATCH_SIZE},共 {(total_count + BATCH_SIZE - 1) // BATCH_SIZE}")
# labels 跨批次共享:先用全量数据初始化(只存 id内存极小
labels: Dict[str, str] = {e["id"]: e["id"] for e in total_entities}
# embeddings 也跨批次共享(每个向量 ~6KB10万实体约 600MB这是不可避免的
# 但只在当前批次的实体需要时才保留,其余批次的 embedding 不常驻
# 实际上 embeddings 只在 _weighted_vote 中用于计算 self_embedding
# 所以只需要当前批次实体的 embedding不需要全量
del total_entities # 释放全量列表,后续按批次加载
# labels 跨批次共享:只存 id→community_id内存极小
labels: Dict[str, str] = {eid: eid for eid in all_entity_ids}
del all_entity_ids # 释放 ID 列表,后续按批次加载完整数据
for batch_start in range(0, total_count, BATCH_SIZE):
batch_entities = await self.repo.get_entities_page(

View File

@@ -13,6 +13,8 @@ from app.repositories.neo4j.cypher_queries import (
ENTITY_LEAVE_ALL_COMMUNITIES,
GET_ENTITY_NEIGHBORS,
GET_ALL_ENTITIES_FOR_USER,
GET_ENTITY_COUNT_FOR_USER,
GET_ALL_ENTITY_IDS_FOR_USER,
GET_ENTITIES_PAGE,
GET_COMMUNITY_MEMBERS,
GET_ALL_COMMUNITY_MEMBERS_BATCH,
@@ -21,7 +23,6 @@ from app.repositories.neo4j.cypher_queries import (
CHECK_USER_HAS_COMMUNITIES,
UPDATE_COMMUNITY_MEMBER_COUNT,
UPDATE_COMMUNITY_METADATA,
UPDATE_COMMUNITY_METADATA,
)
logger = logging.getLogger(__name__)
@@ -113,6 +114,30 @@ class CommunityRepository:
logger.error(f"get_all_entities failed: {e}")
return []
async def get_entity_count(self, end_user_id: str) -> int:
    """Return the number of entities for a user without loading entity data.

    Args:
        end_user_id: Value bound to the ``end_user_id`` parameter of the
            count query.

    Returns:
        The ``entity_count`` field of the first result row, or ``0`` when
        the query yields no rows or raises (the error is logged).
    """
    try:
        rows = await self.connector.execute_query(
            GET_ENTITY_COUNT_FOR_USER, end_user_id=end_user_id
        )
        # An empty result is treated the same as a zero count.
        if not rows:
            return 0
        return rows[0]["entity_count"]
    except Exception as e:
        # Best-effort: log the failure and report "no entities".
        logger.error(f"get_entity_count failed: {e}")
        return 0
async def get_all_entity_ids(self, end_user_id: str) -> List[str]:
    """Return only the IDs of a user's entities, skipping large fields.

    Args:
        end_user_id: Value bound to the ``end_user_id`` parameter of the
            ID-listing query.

    Returns:
        The ``id`` field of every result row, in the order the query
        returned them; an empty list when the query raises (the error is
        logged).
    """
    try:
        rows = await self.connector.execute_query(
            GET_ALL_ENTITY_IDS_FOR_USER, end_user_id=end_user_id
        )
        entity_ids = []
        for row in rows:
            entity_ids.append(row["id"])
        return entity_ids
    except Exception as e:
        # Best-effort: log the failure and report an empty ID list.
        logger.error(f"get_all_entity_ids failed: {e}")
        return []
async def get_entities_page(
self, end_user_id: str, skip: int, limit: int
) -> List[Dict]:

View File

@@ -1122,6 +1122,16 @@ RETURN e.id AS id,
CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id
"""
# Count-only query: matches the ExtractedEntity nodes whose end_user_id equals
# $end_user_id and returns how many there are as `entity_count`.
GET_ENTITY_COUNT_FOR_USER = """
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})
RETURN count(e) AS entity_count
"""
# ID-only query: matches the same ExtractedEntity nodes and returns one row per
# node containing just `id`. No ORDER BY is specified, so the query itself does
# not fix a row order.
GET_ALL_ENTITY_IDS_FOR_USER = """
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})
RETURN e.id AS id
"""
GET_COMMUNITY_MEMBERS = """
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community {community_id: $community_id})
RETURN e.id AS id, e.name AS name, e.entity_type AS entity_type,