diff --git a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py index 37376093..21257f2e 100644 --- a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py +++ b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py @@ -7,6 +7,7 @@ - 增量更新(incremental_update):新实体到达时,只处理新实体及其邻居 """ +import asyncio import logging import uuid from math import sqrt @@ -114,7 +115,7 @@ class LabelPropagationEngine: - 每批独立跑 MAX_ITERATIONS 轮 LPA,批次间通过 labels 传递社区信息 - 所有批次完成后统一 flush 和 merge """ - BATCH_SIZE = 2000 # 每批实体数,可按需调整 + BATCH_SIZE = 888 # 每批实体数,可按需调整 # 轻量查询:只获取总数和 ID 列表,不加载 embedding 等大字段 total_count = await self.repo.get_entity_count(end_user_id) @@ -203,8 +204,7 @@ class LabelPropagationEngine: if e.get("community_id") }) logger.info(f"[Clustering] 合并后实际存活社区数: {len(surviving_community_ids)}") - for cid in surviving_community_ids: - await self._generate_community_metadata(cid, end_user_id) + await self._generate_community_metadata(surviving_community_ids, end_user_id) async def incremental_update( self, new_entity_ids: List[str], end_user_id: str @@ -261,7 +261,7 @@ class LabelPropagationEngine: logger.debug( f"[Clustering] 新实体 {entity_id} 与 {len(neighbors)} 个无社区邻居 → 新社区 {new_cid}" ) - await self._generate_community_metadata(new_cid, end_user_id) + await self._generate_community_metadata([new_cid], end_user_id) else: # 加入得票最多的社区 await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id) @@ -273,7 +273,7 @@ class LabelPropagationEngine: await self._evaluate_merge( list(community_ids_in_neighbors), end_user_id ) - await self._generate_community_metadata(target_cid, end_user_id) + await self._generate_community_metadata([target_cid], end_user_id) async def _evaluate_merge( self, community_ids: List[str], end_user_id: str @@ -437,89 +437,122 @@ class LabelPropagationEngine: except Exception: return None + @staticmethod + def _build_entity_lines(members: List[Dict]) -> List[str]: + """将实体列表格式化为 prompt 行,包含 name、aliases、description。""" + lines = [] + for m in members: + m_name = m.get("name", "") + aliases = m.get("aliases") or [] + description = m.get("description") or "" + aliases_str = f"(别名:{'、'.join(aliases)})" if aliases else "" + desc_str = f":{description}" if description else "" + lines.append(f"- {m_name}{aliases_str}{desc_str}") + return lines + async def _generate_community_metadata( - self, community_id: str, end_user_id: str + self, community_ids: List[str], end_user_id: str ) -> None: """ - 为社区生成并写入元数据:名称、摘要、核心实体。 + 为一个或多个社区生成并写入元数据。 - - core_entities:按 activation_value 排序取 top-N 实体名称列表(无需 LLM) - - name / summary:若有 llm_model_id 则调用 LLM 生成,否则用实体名称拼接兜底 - NOTE: core_entities按照激活值高低排序,会造成对边缘信息检索返回消息质量不高。 + 流程: + 1. 逐个社区调 LLM 生成 name / summary(串行) + 2. 收集所有 summary,一次性批量 embed + 3. 单个社区用 update_community_metadata,多个用 batch_update_community_metadata """ - try: - members = await self.repo.get_community_members(community_id, end_user_id) - if not members: - return + if not community_ids: + return + + from app.db import get_db_context + from app.core.memory.utils.llm.llm_utils import MemoryClientFactory + + # --- 阶段1:并发调 LLM 生成每个社区的 name / summary --- + async def _build_one(cid: str): + members = await self.repo.get_community_members(cid, end_user_id) + if not members: + return None - # 核心实体:按 activation_value 降序取 top-N sorted_members = sorted( members, key=lambda m: m.get("activation_value") or 0, reverse=True, ) core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")] - all_names = [m["name"] for m in members if m.get("name")] - name = "、".join(core_entities[:3]) if core_entities else community_id[:8] - summary = f"包含实体:{', '.join(all_names)}" + entity_list_str = "\n".join(self._build_entity_lines(members)) + prompt = ( + f"以下是一组语义相关的实体:\n{entity_list_str}\n\n" + f"请为这组实体所代表的主题:\n" + f"1. 起一个简洁的中文名称(不超过10个字)\n" + f"2. 写一句话摘要(不超过50个字)\n\n" + f"严格按以下格式输出,不要有其他内容:\n" + f"名称:<名称>\n摘要:<摘要>" + ) + with get_db_context() as db: + llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id) + response = await llm_client.chat([{"role": "user", "content": prompt}]) + text = response.content if hasattr(response, "content") else str(response) - # 若有 LLM 配置,调用 LLM 生成更好的名称和摘要 - if self.llm_model_id: - try: - from app.db import get_db_context - from app.core.memory.utils.llm.llm_utils import MemoryClientFactory + name, summary = "", "" + for line in text.strip().splitlines(): + if line.startswith("名称:"): + name = line[3:].strip() + elif line.startswith("摘要:"): + summary = line[3:].strip() - entity_list_str = "、".join(all_names) - prompt = ( - f"以下是一组语义相关的实体:{entity_list_str}\n\n" - f"请为这组实体所代表的主题:\n" - f"1. 起一个简洁的中文名称(不超过10个字)\n" - f"2. 写一句话摘要(不超过50个字)\n\n" - f"严格按以下格式输出,不要有其他内容:\n" - f"名称:<名称>\n摘要:<摘要>" - ) - with get_db_context() as db: - factory = MemoryClientFactory(db) - llm_client = factory.get_llm_client(self.llm_model_id) - response = await llm_client.chat([{"role": "user", "content": prompt}]) - text = response.content if hasattr(response, "content") else str(response) + return { + "community_id": cid, + "end_user_id": end_user_id, + "name": name, + "summary": summary, + "core_entities": core_entities, + "summary_embedding": None, + } - for line in text.strip().splitlines(): - if line.startswith("名称:"): - name = line[3:].strip() - elif line.startswith("摘要:"): - summary = line[3:].strip() - except Exception as e: - logger.warning(f"[Clustering] LLM 生成社区元数据失败,使用兜底值: {e}") + results = await asyncio.gather( + *[_build_one(cid) for cid in community_ids], + return_exceptions=True, + ) + metadata_list = [] + for cid, res in zip(community_ids, results): + if isinstance(res, Exception): + logger.error(f"[Clustering] 社区 {cid} 元数据准备失败: {res}", exc_info=res) + elif res is not None: + metadata_list.append(res) - # 生成 summary_embedding - summary_embedding = None - if self.embedding_model_id and summary: - try: - from app.db import get_db_context - from app.core.memory.utils.llm.llm_utils import MemoryClientFactory - with get_db_context() as db: - embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id) - results = await embedder.response([summary]) - summary_embedding = results[0] if results else None - except Exception as e: - logger.warning(f"[Clustering] 社区 {community_id} 生成 summary_embedding 失败: {e}") + if not metadata_list: + return + # --- 阶段2:批量生成 summary_embedding --- + summaries = [m["summary"] for m in metadata_list] + with get_db_context() as db: + embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id) + embeddings = await embedder.response(summaries) + for i, meta in enumerate(metadata_list): + meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None + + # --- 阶段3:写入(单个 or 批量)--- + if len(metadata_list) == 1: + m = metadata_list[0] result = await self.repo.update_community_metadata( - community_id=community_id, - end_user_id=end_user_id, - name=name, - summary=summary, - core_entities=core_entities, - summary_embedding=summary_embedding, + community_id=m["community_id"], + end_user_id=m["end_user_id"], + name=m["name"], + summary=m["summary"], + core_entities=m["core_entities"], + summary_embedding=m["summary_embedding"], ) if result: - logger.info(f"[Clustering] 社区 {community_id} 元数据写入成功: name={name}, summary={summary[:30]}...") + logger.info(f"[Clustering] 社区 {m['community_id']} 元数据写入成功: name={m['name']}, summary={m['summary'][:30]}...") else: - logger.warning(f"[Clustering] 社区 {community_id} 元数据写入返回 False") - except Exception as e: - logger.error(f"[Clustering] _generate_community_metadata failed for {community_id}: {e}", exc_info=True) + logger.warning(f"[Clustering] 社区 {m['community_id']} 元数据写入返回 False") + else: + ok = await self.repo.batch_update_community_metadata(metadata_list) + if ok: + logger.info(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据成功") + else: + logger.warning(f"[Clustering] 批量写入社区元数据失败") @staticmethod def _new_community_id() -> str: diff --git a/api/app/main.py b/api/app/main.py index 5314b8b6..c794f48a 100644 --- a/api/app/main.py +++ b/api/app/main.py @@ -18,7 +18,6 @@ from app.core.logging_config import LoggingConfig, get_logger from app.core.response_utils import fail from app.core.models.scripts.loader import load_models from app.db import get_db_context -from app.repositories.neo4j.index_manager import ensure_indexes # Initialize logging system LoggingConfig.setup_logging() @@ -62,16 +61,6 @@ async def lifespan(app: FastAPI): else: logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)") - # 确保 Neo4j 索引存在(幂等,多环境安全) - try: - report = await ensure_indexes() - if report["errors"]: - logger.warning(f"Neo4j 索引部分创建失败: {report['errors']}") - else: - logger.info(f"Neo4j 索引检查完成 [{report['uri']}]") - except Exception as e: - logger.warning(f"Neo4j 索引检查跳过(连接失败): {e}") - logger.info("应用程序启动完成") yield logger.info("应用程序正在关闭") diff --git a/api/app/repositories/neo4j/community_repository.py b/api/app/repositories/neo4j/community_repository.py index 267ced4f..f9c4bd92 100644 --- a/api/app/repositories/neo4j/community_repository.py +++ b/api/app/repositories/neo4j/community_repository.py @@ -23,6 +23,7 @@ from app.repositories.neo4j.cypher_queries import ( CHECK_USER_HAS_COMMUNITIES, UPDATE_COMMUNITY_MEMBER_COUNT, UPDATE_COMMUNITY_METADATA, + BATCH_UPDATE_COMMUNITY_METADATA, ) logger = logging.getLogger(__name__) @@ -257,3 +258,25 @@ class CommunityRepository: except Exception as e: logger.error(f"update_community_metadata failed: {e}") return False + + async def batch_update_community_metadata( + self, + communities: List[Dict], + ) -> bool: + """批量更新多个社区的元数据。 + + Args: + communities: 每项包含 community_id, end_user_id, name, summary, + core_entities, summary_embedding + """ + if not communities: + return True + try: + await self.connector.execute_query( + BATCH_UPDATE_COMMUNITY_METADATA, + communities=communities, + ) + return True + except Exception as e: + logger.error(f"batch_update_community_metadata failed: {e}") + return False diff --git a/api/app/repositories/neo4j/create_indexes.py b/api/app/repositories/neo4j/create_indexes.py index 55dead1b..29f60fdd 100644 --- a/api/app/repositories/neo4j/create_indexes.py +++ b/api/app/repositories/neo4j/create_indexes.py @@ -42,6 +42,13 @@ async def create_fulltext_indexes(): OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } """) print("✓ Created: summariesFulltext") + + # 创建 Community 索引 + await connector.execute_query(""" + CREATE FULLTEXT INDEX communitiesFulltext IF NOT EXISTS FOR (c:Community) ON EACH [c.name, c.summary] + OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } + """) + print("✓ Created: communitiesFulltext") print("\nFull-text indexes created successfully with BM25 support.") except Exception as e: @@ -124,6 +131,18 @@ async def create_vector_indexes(): }} """) print("✓ Created: dialogue_embedding_index") + + # Community summary embedding index + await connector.execute_query(""" + CREATE VECTOR INDEX community_summary_embedding_index IF NOT EXISTS + FOR (c:Community) + ON c.summary_embedding + OPTIONS {indexConfig: { + `vector.dimensions`: 1024, + `vector.similarity_function`: 'cosine' + }} + """) + print("✓ Created: community_summary_embedding_index") print("\nVector indexes created successfully!") print("\nExpected performance improvement:") diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index 01f2fa6a..16a26b3b 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -1136,7 +1136,8 @@ GET_COMMUNITY_MEMBERS = """ MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community {community_id: $community_id}) RETURN e.id AS id, e.name AS name, e.entity_type AS entity_type, e.importance_score AS importance_score, e.activation_value AS activation_value, - e.name_embedding AS name_embedding + e.name_embedding AS name_embedding, + e.aliases AS aliases, e.description AS description ORDER BY coalesce(e.activation_value, 0) DESC """ @@ -1145,7 +1146,8 @@ MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->( RETURN c.community_id AS community_id, e.id AS id, e.name AS name, e.entity_type AS entity_type, e.importance_score AS importance_score, e.activation_value AS activation_value, - e.name_embedding AS name_embedding + e.name_embedding AS name_embedding, + e.aliases AS aliases, e.description AS description ORDER BY c.community_id, coalesce(e.activation_value, 0) DESC """ @@ -1171,6 +1173,17 @@ SET c.name = $name, RETURN c.community_id AS community_id """ +BATCH_UPDATE_COMMUNITY_METADATA = """ +UNWIND $communities AS row +MATCH (c:Community {community_id: row.community_id, end_user_id: row.end_user_id}) +SET c.name = row.name, + c.summary = row.summary, + c.core_entities = row.core_entities, + c.summary_embedding = row.summary_embedding, + c.updated_at = datetime() +RETURN c.community_id AS community_id +""" + GET_ENTITIES_PAGE = """ MATCH (e:ExtractedEntity {end_user_id: $end_user_id}) OPTIONAL MATCH (e)-[:BELONGS_TO_COMMUNITY]->(c:Community) diff --git a/api/app/repositories/neo4j/index_manager.py b/api/app/repositories/neo4j/index_manager.py deleted file mode 100644 index a1ab6689..00000000 --- a/api/app/repositories/neo4j/index_manager.py +++ /dev/null @@ -1,254 +0,0 @@ -# -*- coding: utf-8 -*- -"""Neo4j 索引管理模块 - -负责检查和创建 Neo4j 全文索引与向量索引。 -支持多环境(通过 .env 中的 NEO4J_URI/USERNAME/PASSWORD 区分)。 - -用法: - # 作为模块调用(应用启动时) - from app.repositories.neo4j.index_manager import ensure_indexes - await ensure_indexes() - - # 作为独立脚本执行(手动建索引) - python -m app.repositories.neo4j.index_manager -""" - -import asyncio -import logging -from dataclasses import dataclass -from typing import List - -from app.core.config import settings -from app.repositories.neo4j.neo4j_connector import Neo4jConnector - -logger = logging.getLogger(__name__) - - -# ───────────────────────────────────────────────────────────── -# 索引定义表 -# ───────────────────────────────────────────────────────────── - -@dataclass -class FulltextIndexDef: - name: str - label: str - properties: List[str] - - -@dataclass -class VectorIndexDef: - name: str - label: str - property: str - dimensions: int - similarity: str = "cosine" - - -# 全文索引清单(现有 + 新增 communities) -FULLTEXT_INDEXES: List[FulltextIndexDef] = [ - FulltextIndexDef("statementsFulltext", "Statement", ["statement"]), - FulltextIndexDef("entitiesFulltext", "ExtractedEntity", ["name"]), - FulltextIndexDef("chunksFulltext", "Chunk", ["content"]), - FulltextIndexDef("summariesFulltext", "MemorySummary", ["content"]), - FulltextIndexDef("communitiesFulltext", "Community", ["name", "summary"]), # 第五检索源 -] - -# 向量索引清单(预留 community 二期) -VECTOR_INDEXES: List[VectorIndexDef] = [ - VectorIndexDef("statement_embedding_index", "Statement", "statement_embedding", 1536), - VectorIndexDef("chunk_embedding_index", "Chunk", "chunk_embedding", 1536), - VectorIndexDef("entity_embedding_index", "ExtractedEntity","name_embedding", 1536), - VectorIndexDef("summary_embedding_index", "MemorySummary", "summary_embedding", 1536), - # 二期:社区向量索引 - VectorIndexDef("community_summary_embedding_index", "Community", "summary_embedding", 1536), -] - - -# ───────────────────────────────────────────────────────────── -# 核心检查 / 创建逻辑 -# ───────────────────────────────────────────────────────────── - -async def _get_existing_indexes(connector: Neo4jConnector) -> set: - """查询 Neo4j 中已存在的索引名称集合""" - rows = await connector.execute_query("SHOW INDEXES YIELD name RETURN name") - return {row["name"] for row in rows} - - -async def _ensure_fulltext_index( - connector: Neo4jConnector, - idx: FulltextIndexDef, - existing: set, -) -> str: - """检查并按需创建全文索引,返回操作状态描述""" - if idx.name in existing: - return f"[SKIP] 全文索引已存在: {idx.name}" - - props = ", ".join(f"n.{p}" for p in idx.properties) - cypher = ( - f'CREATE FULLTEXT INDEX {idx.name} IF NOT EXISTS ' - f'FOR (n:{idx.label}) ON EACH [{props}]' - ) - await connector.execute_query(cypher) - return f"[CREATE] 全文索引已创建: {idx.name} ({idx.label} → {idx.properties})" - - -async def _ensure_vector_index( - connector: Neo4jConnector, - idx: VectorIndexDef, - existing: set, -) -> str: - """检查并按需创建向量索引,返回操作状态描述""" - if idx.name in existing: - return f"[SKIP] 向量索引已存在: {idx.name}" - - cypher = ( - f"CREATE VECTOR INDEX {idx.name} IF NOT EXISTS " - f"FOR (n:{idx.label}) ON n.{idx.property} " - f"OPTIONS {{indexConfig: {{" - f"`vector.dimensions`: {idx.dimensions}, " - f"`vector.similarity_function`: '{idx.similarity}'" - f"}}}}" - ) - await connector.execute_query(cypher) - return ( - f"[CREATE] 向量索引已创建: {idx.name} " - f"({idx.label}.{idx.property}, dim={idx.dimensions})" - ) - - -async def ensure_indexes(connector: Neo4jConnector | None = None) -> dict: - """ - 检查并创建所有必要的 Neo4j 索引(幂等,可重复调用)。 - - Args: - connector: 可选,传入已有连接器;为 None 时自动创建。 - - Returns: - dict: { - "uri": 当前连接的 Neo4j URI, - "fulltext": [操作日志列表], - "vector": [操作日志列表], - "errors": [错误信息列表], - } - """ - own_connector = connector is None - if own_connector: - connector = Neo4jConnector() - - report = { - "uri": settings.NEO4J_URI, - "fulltext": [], - "vector": [], - "errors": [], - } - - try: - # 一次性拉取所有已有索引名 - existing = await _get_existing_indexes(connector) - logger.info(f"[IndexManager] 当前环境: {settings.NEO4J_URI}") - logger.info(f"[IndexManager] 已有索引数量: {len(existing)}") - - # 处理全文索引 - for idx in FULLTEXT_INDEXES: - try: - msg = await _ensure_fulltext_index(connector, idx, existing) - report["fulltext"].append(msg) - logger.info(f"[IndexManager] {msg}") - except Exception as e: - err = f"[ERROR] 全文索引 {idx.name} 创建失败: {e}" - report["errors"].append(err) - logger.error(f"[IndexManager] {err}") - - # 处理向量索引 - for idx in VECTOR_INDEXES: - try: - msg = await _ensure_vector_index(connector, idx, existing) - report["vector"].append(msg) - logger.info(f"[IndexManager] {msg}") - except Exception as e: - err = f"[ERROR] 向量索引 {idx.name} 创建失败: {e}" - report["errors"].append(err) - logger.error(f"[IndexManager] {err}") - - finally: - if own_connector: - await connector.close() - - return report - - -async def check_indexes(connector: Neo4jConnector | None = None) -> dict: - """ - 仅检查索引状态,不创建任何索引。 - - Returns: - dict: { - "uri": ..., - "present": [已存在的索引名], - "missing_fulltext": [缺失的全文索引名], - "missing_vector": [缺失的向量索引名], - } - """ - own_connector = connector is None - if own_connector: - connector = Neo4jConnector() - - try: - existing = await _get_existing_indexes(connector) - missing_ft = [i.name for i in FULLTEXT_INDEXES if i.name not in existing] - missing_vec = [i.name for i in VECTOR_INDEXES if i.name not in existing] - - return { - "uri": settings.NEO4J_URI, - "present": sorted(existing), - "missing_fulltext": missing_ft, - "missing_vector": missing_vec, - } - finally: - if own_connector: - await connector.close() - - -# ───────────────────────────────────────────────────────────── -# 独立脚本入口 -# ───────────────────────────────────────────────────────────── - -async def _main(): - import sys - - print(f"\n{'='*60}") - print(f"Neo4j 索引管理工具") - print(f"环境: {settings.NEO4J_URI}") - print(f"{'='*60}\n") - - # 先检查 - print(">>> 检查当前索引状态...\n") - status = await check_indexes() - print(f" 已存在索引数: {len(status['present'])}") - if status["missing_fulltext"]: - print(f" 缺失全文索引: {status['missing_fulltext']}") - if status["missing_vector"]: - print(f" 缺失向量索引: {status['missing_vector']}") - - if not status["missing_fulltext"] and not status["missing_vector"]: - print("\n 所有索引均已存在,无需操作。") - return - - # 再创建 - print("\n>>> 开始创建缺失索引...\n") - report = await ensure_indexes() - - for msg in report["fulltext"] + report["vector"]: - print(f" {msg}") - - if report["errors"]: - print("\n[!] 以下索引创建失败:") - for err in report["errors"]: - print(f" {err}") - sys.exit(1) - else: - print("\n 全部索引处理完成。") - - -if __name__ == "__main__": - asyncio.run(_main())