diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index 22030278..b3707083 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -165,7 +165,9 @@ async def write( statement_chunk_edges=all_statement_chunk_edges, statement_entity_edges=all_statement_entity_edges, entity_edges=all_entity_entity_edges, - connector=neo4j_connector + connector=neo4j_connector, + config_id=config_id, + llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None, ) if success: logger.info("Successfully saved all data to Neo4j") diff --git a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py index cb6e5804..251d4fea 100644 --- a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py +++ b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py @@ -19,6 +19,8 @@ logger = logging.getLogger(__name__) # 全量迭代最大轮数,防止不收敛 MAX_ITERATIONS = 10 +# 社区摘要核心实体数量 +CORE_ENTITY_LIMIT = 5 def _cosine_similarity(v1: Optional[List[float]], v2: Optional[List[float]]) -> float: @@ -62,9 +64,16 @@ def _weighted_vote( class LabelPropagationEngine: """标签传播聚类引擎""" - def __init__(self, connector: Neo4jConnector): + def __init__( + self, + connector: Neo4jConnector, + config_id: Optional[str] = None, + llm_model_id: Optional[str] = None, + ): self.connector = connector self.repo = CommunityRepository(connector) + self.config_id = config_id + self.llm_model_id = llm_model_id # ────────────────────────────────────────────────────────────────────────── # 公开接口 @@ -155,6 +164,10 @@ class LabelPropagationEngine: f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区," f"{len(labels)} 个实体" ) + # 为所有社区生成元数据 + unique_communities = list(set(labels.values())) + for cid in unique_communities: + await self._generate_community_metadata(cid, end_user_id) async def incremental_update( self, new_entity_ids: List[str], end_user_id: str @@ -211,6 +224,7 @@ class LabelPropagationEngine: logger.debug( f"[Clustering] 新实体 {entity_id} 与 {len(neighbors)} 个无社区邻居 → 新社区 {new_cid}" ) + await self._generate_community_metadata(new_cid, end_user_id) else: # 加入得票最多的社区 await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id) @@ -222,6 +236,7 @@ class LabelPropagationEngine: await self._evaluate_merge( list(community_ids_in_neighbors), end_user_id ) + await self._generate_community_metadata(target_cid, end_user_id) async def _evaluate_merge( self, community_ids: List[str], end_user_id: str @@ -354,6 +369,72 @@ class LabelPropagationEngine: except Exception: return None + async def _generate_community_metadata( + self, community_id: str, end_user_id: str + ) -> None: + """ + 为社区生成并写入元数据:名称、摘要、核心实体。 + + - core_entities:按 activation_value 排序取 top-N 实体名称列表(无需 LLM) + - name / summary:若有 llm_model_id 则调用 LLM 生成,否则用实体名称拼接兜底 + """ + try: + members = await self.repo.get_community_members(community_id, end_user_id) + if not members: + return + + # 核心实体:按 activation_value 降序取 top-N + sorted_members = sorted( + members, + key=lambda m: m.get("activation_value") or 0, + reverse=True, + ) + core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")] + all_names = [m["name"] for m in members if m.get("name")] + + name = "、".join(core_entities[:3]) if core_entities else community_id[:8] + summary = f"包含实体:{', '.join(all_names)}" + + # 若有 LLM 配置,调用 LLM 生成更好的名称和摘要 + if self.llm_model_id: + try: + from app.db import get_db_context + from app.core.memory.utils.llm.llm_utils import MemoryClientFactory + + entity_list_str = "、".join(all_names) + prompt = ( + f"以下是一组语义相关的实体:{entity_list_str}\n\n" + f"请为这组实体所代表的主题:\n" + f"1. 起一个简洁的中文名称(不超过10个字)\n" + f"2. 写一句话摘要(不超过50个字)\n\n" + f"严格按以下格式输出,不要有其他内容:\n" + f"名称:<名称>\n摘要:<摘要>" + ) + with get_db_context() as db: + factory = MemoryClientFactory(db) + llm_client = factory.get_llm_client(self.llm_model_id) + response = await llm_client.chat([{"role": "user", "content": prompt}]) + text = response.content if hasattr(response, "content") else str(response) + + for line in text.strip().splitlines(): + if line.startswith("名称:"): + name = line[3:].strip() + elif line.startswith("摘要:"): + summary = line[3:].strip() + except Exception as e: + logger.warning(f"[Clustering] LLM 生成社区元数据失败,使用兜底值: {e}") + + await self.repo.update_community_metadata( + community_id=community_id, + end_user_id=end_user_id, + name=name, + summary=summary, + core_entities=core_entities, + ) + logger.debug(f"[Clustering] 社区 {community_id} 元数据已更新: name={name}") + except Exception as e: + logger.error(f"[Clustering] _generate_community_metadata failed for {community_id}: {e}") + @staticmethod def _new_community_id() -> str: return str(uuid.uuid4()) diff --git a/api/app/repositories/neo4j/community_repository.py b/api/app/repositories/neo4j/community_repository.py index 2a1f4f2b..6c5c7618 100644 --- a/api/app/repositories/neo4j/community_repository.py +++ b/api/app/repositories/neo4j/community_repository.py @@ -17,6 +17,7 @@ from app.repositories.neo4j.cypher_queries import ( GET_ALL_COMMUNITY_MEMBERS_BATCH, CHECK_USER_HAS_COMMUNITIES, UPDATE_COMMUNITY_MEMBER_COUNT, + UPDATE_COMMUNITY_METADATA, ) logger = logging.getLogger(__name__) @@ -147,3 +148,26 @@ class CommunityRepository: except Exception as e: logger.error(f"refresh_member_count failed: {e}") return 0 + + async def update_community_metadata( + self, + community_id: str, + end_user_id: str, + name: str, + summary: str, + core_entities: List[str], + ) -> bool: + """更新社区的名称、摘要和核心实体列表。""" + try: + result = await self.connector.execute_query( + UPDATE_COMMUNITY_METADATA, + community_id=community_id, + end_user_id=end_user_id, + name=name, + summary=summary, + core_entities=core_entities, + ) + return bool(result) + except Exception as e: + logger.error(f"update_community_metadata failed: {e}") + return False diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index 84889d65..b270ed64 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -1150,3 +1150,12 @@ WITH c, count(e) AS cnt SET c.member_count = cnt RETURN c.community_id AS community_id, cnt AS member_count """ + +UPDATE_COMMUNITY_METADATA = """ +MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id}) +SET c.name = $name, + c.summary = $summary, + c.core_entities = $core_entities, + c.updated_at = datetime() +RETURN c.community_id AS community_id +""" diff --git a/api/app/repositories/neo4j/graph_saver.py b/api/app/repositories/neo4j/graph_saver.py index a94bc23b..2ef9bafc 100644 --- a/api/app/repositories/neo4j/graph_saver.py +++ b/api/app/repositories/neo4j/graph_saver.py @@ -1,5 +1,6 @@ import asyncio -from typing import List +import os +from typing import List, Optional # 使用新的仓储层 from app.repositories.neo4j.neo4j_connector import Neo4jConnector @@ -156,7 +157,9 @@ async def save_dialog_and_statements_to_neo4j( entity_edges: List[EntityEntityEdge], statement_chunk_edges: List[StatementChunkEdge], statement_entity_edges: List[StatementEntityEdge], - connector: Neo4jConnector + connector: Neo4jConnector, + config_id: Optional[str] = None, + llm_model_id: Optional[str] = None, ) -> bool: """Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models. @@ -290,12 +293,15 @@ async def save_dialog_and_statements_to_neo4j( logger.info("Transaction completed. Summary: %s", summary) logger.debug("Full transaction results: %r", results) - # 写入成功后,触发聚类 - if entity_nodes: + # 写入成功后,触发聚类(可通过环境变量 CLUSTERING_ENABLED=false 禁用,用于基准测试对比) + clustering_enabled = os.getenv("CLUSTERING_ENABLED", "true").lower() != "false" + if entity_nodes and clustering_enabled: end_user_id = entity_nodes[0].end_user_id new_entity_ids = [e.id for e in entity_nodes] logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}") - await _trigger_clustering(new_entity_ids, end_user_id) + asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, config_id=config_id, llm_model_id=llm_model_id)) + elif entity_nodes and not clustering_enabled: + logger.info("[Clustering] 聚类已禁用(CLUSTERING_ENABLED=false),跳过聚类触发") return True @@ -309,6 +315,8 @@ async def save_dialog_and_statements_to_neo4j( async def _trigger_clustering( new_entity_ids: List[str], end_user_id: str, + config_id: Optional[str] = None, + llm_model_id: Optional[str] = None, ) -> None: """ 聚类触发函数,自动判断全量初始化还是增量更新。 @@ -318,7 +326,7 @@ async def _trigger_clustering( from app.core.memory.storage_services.clustering_engine import LabelPropagationEngine logger.info(f"[Clustering] 开始聚类,end_user_id={end_user_id}, 实体数={len(new_entity_ids)}") connector = Neo4jConnector() - engine = LabelPropagationEngine(connector) + engine = LabelPropagationEngine(connector, config_id=config_id, llm_model_id=llm_model_id) await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids) logger.info(f"[Clustering] 聚类完成,end_user_id={end_user_id}") except Exception as e: