[add] Create the attribute values of the community nodes

This commit is contained in:
lanceyq
2026-03-12 20:27:50 +08:00
parent 744ba31ba6
commit 6d8b1aede4
5 changed files with 132 additions and 8 deletions

View File

@@ -165,7 +165,9 @@ async def write(
statement_chunk_edges=all_statement_chunk_edges, statement_chunk_edges=all_statement_chunk_edges,
statement_entity_edges=all_statement_entity_edges, statement_entity_edges=all_statement_entity_edges,
entity_edges=all_entity_entity_edges, entity_edges=all_entity_entity_edges,
connector=neo4j_connector connector=neo4j_connector,
config_id=config_id,
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
) )
if success: if success:
logger.info("Successfully saved all data to Neo4j") logger.info("Successfully saved all data to Neo4j")

View File

@@ -19,6 +19,8 @@ logger = logging.getLogger(__name__)
# 全量迭代最大轮数,防止不收敛 # 全量迭代最大轮数,防止不收敛
MAX_ITERATIONS = 10 MAX_ITERATIONS = 10
# 社区摘要核心实体数量
CORE_ENTITY_LIMIT = 5
def _cosine_similarity(v1: Optional[List[float]], v2: Optional[List[float]]) -> float: def _cosine_similarity(v1: Optional[List[float]], v2: Optional[List[float]]) -> float:
@@ -62,9 +64,16 @@ def _weighted_vote(
class LabelPropagationEngine: class LabelPropagationEngine:
"""标签传播聚类引擎""" """标签传播聚类引擎"""
def __init__(self, connector: Neo4jConnector): def __init__(
self,
connector: Neo4jConnector,
config_id: Optional[str] = None,
llm_model_id: Optional[str] = None,
):
self.connector = connector self.connector = connector
self.repo = CommunityRepository(connector) self.repo = CommunityRepository(connector)
self.config_id = config_id
self.llm_model_id = llm_model_id
# ────────────────────────────────────────────────────────────────────────── # ──────────────────────────────────────────────────────────────────────────
# 公开接口 # 公开接口
@@ -155,6 +164,10 @@ class LabelPropagationEngine:
f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区," f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区,"
f"{len(labels)} 个实体" f"{len(labels)} 个实体"
) )
# 为所有社区生成元数据
unique_communities = list(set(labels.values()))
for cid in unique_communities:
await self._generate_community_metadata(cid, end_user_id)
async def incremental_update( async def incremental_update(
self, new_entity_ids: List[str], end_user_id: str self, new_entity_ids: List[str], end_user_id: str
@@ -211,6 +224,7 @@ class LabelPropagationEngine:
logger.debug( logger.debug(
f"[Clustering] 新实体 {entity_id}{len(neighbors)} 个无社区邻居 → 新社区 {new_cid}" f"[Clustering] 新实体 {entity_id}{len(neighbors)} 个无社区邻居 → 新社区 {new_cid}"
) )
await self._generate_community_metadata(new_cid, end_user_id)
else: else:
# 加入得票最多的社区 # 加入得票最多的社区
await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id) await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id)
@@ -222,6 +236,7 @@ class LabelPropagationEngine:
await self._evaluate_merge( await self._evaluate_merge(
list(community_ids_in_neighbors), end_user_id list(community_ids_in_neighbors), end_user_id
) )
await self._generate_community_metadata(target_cid, end_user_id)
async def _evaluate_merge( async def _evaluate_merge(
self, community_ids: List[str], end_user_id: str self, community_ids: List[str], end_user_id: str
@@ -354,6 +369,72 @@ class LabelPropagationEngine:
except Exception: except Exception:
return None return None
async def _generate_community_metadata(
self, community_id: str, end_user_id: str
) -> None:
"""
为社区生成并写入元数据:名称、摘要、核心实体。
- core_entities按 activation_value 排序取 top-N 实体名称列表(无需 LLM
- name / summary若有 llm_model_id 则调用 LLM 生成,否则用实体名称拼接兜底
"""
try:
members = await self.repo.get_community_members(community_id, end_user_id)
if not members:
return
# 核心实体:按 activation_value 降序取 top-N
sorted_members = sorted(
members,
key=lambda m: m.get("activation_value") or 0,
reverse=True,
)
core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
all_names = [m["name"] for m in members if m.get("name")]
name = "".join(core_entities[:3]) if core_entities else community_id[:8]
summary = f"包含实体:{', '.join(all_names)}"
# 若有 LLM 配置,调用 LLM 生成更好的名称和摘要
if self.llm_model_id:
try:
from app.db import get_db_context
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
entity_list_str = "".join(all_names)
prompt = (
f"以下是一组语义相关的实体:{entity_list_str}\n\n"
f"请为这组实体所代表的主题:\n"
f"1. 起一个简洁的中文名称不超过10个字\n"
f"2. 写一句话摘要不超过50个字\n\n"
f"严格按以下格式输出,不要有其他内容:\n"
f"名称:<名称>\n摘要:<摘要>"
)
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(self.llm_model_id)
response = await llm_client.chat([{"role": "user", "content": prompt}])
text = response.content if hasattr(response, "content") else str(response)
for line in text.strip().splitlines():
if line.startswith("名称:"):
name = line[3:].strip()
elif line.startswith("摘要:"):
summary = line[3:].strip()
except Exception as e:
logger.warning(f"[Clustering] LLM 生成社区元数据失败,使用兜底值: {e}")
await self.repo.update_community_metadata(
community_id=community_id,
end_user_id=end_user_id,
name=name,
summary=summary,
core_entities=core_entities,
)
logger.debug(f"[Clustering] 社区 {community_id} 元数据已更新: name={name}")
except Exception as e:
logger.error(f"[Clustering] _generate_community_metadata failed for {community_id}: {e}")
@staticmethod @staticmethod
def _new_community_id() -> str: def _new_community_id() -> str:
return str(uuid.uuid4()) return str(uuid.uuid4())

View File

@@ -17,6 +17,7 @@ from app.repositories.neo4j.cypher_queries import (
GET_ALL_COMMUNITY_MEMBERS_BATCH, GET_ALL_COMMUNITY_MEMBERS_BATCH,
CHECK_USER_HAS_COMMUNITIES, CHECK_USER_HAS_COMMUNITIES,
UPDATE_COMMUNITY_MEMBER_COUNT, UPDATE_COMMUNITY_MEMBER_COUNT,
UPDATE_COMMUNITY_METADATA,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -147,3 +148,26 @@ class CommunityRepository:
except Exception as e: except Exception as e:
logger.error(f"refresh_member_count failed: {e}") logger.error(f"refresh_member_count failed: {e}")
return 0 return 0
async def update_community_metadata(
self,
community_id: str,
end_user_id: str,
name: str,
summary: str,
core_entities: List[str],
) -> bool:
"""更新社区的名称、摘要和核心实体列表。"""
try:
result = await self.connector.execute_query(
UPDATE_COMMUNITY_METADATA,
community_id=community_id,
end_user_id=end_user_id,
name=name,
summary=summary,
core_entities=core_entities,
)
return bool(result)
except Exception as e:
logger.error(f"update_community_metadata failed: {e}")
return False

View File

@@ -1150,3 +1150,12 @@ WITH c, count(e) AS cnt
SET c.member_count = cnt SET c.member_count = cnt
RETURN c.community_id AS community_id, cnt AS member_count RETURN c.community_id AS community_id, cnt AS member_count
""" """
UPDATE_COMMUNITY_METADATA = """
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
SET c.name = $name,
c.summary = $summary,
c.core_entities = $core_entities,
c.updated_at = datetime()
RETURN c.community_id AS community_id
"""

View File

@@ -1,5 +1,6 @@
import asyncio import asyncio
from typing import List import os
from typing import List, Optional
# 使用新的仓储层 # 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
@@ -156,7 +157,9 @@ async def save_dialog_and_statements_to_neo4j(
entity_edges: List[EntityEntityEdge], entity_edges: List[EntityEntityEdge],
statement_chunk_edges: List[StatementChunkEdge], statement_chunk_edges: List[StatementChunkEdge],
statement_entity_edges: List[StatementEntityEdge], statement_entity_edges: List[StatementEntityEdge],
connector: Neo4jConnector connector: Neo4jConnector,
config_id: Optional[str] = None,
llm_model_id: Optional[str] = None,
) -> bool: ) -> bool:
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models. """Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
@@ -290,12 +293,15 @@ async def save_dialog_and_statements_to_neo4j(
logger.info("Transaction completed. Summary: %s", summary) logger.info("Transaction completed. Summary: %s", summary)
logger.debug("Full transaction results: %r", results) logger.debug("Full transaction results: %r", results)
# 写入成功后,触发聚类 # 写入成功后,触发聚类(可通过环境变量 CLUSTERING_ENABLED=false 禁用,用于基准测试对比)
if entity_nodes: clustering_enabled = os.getenv("CLUSTERING_ENABLED", "true").lower() != "false"
if entity_nodes and clustering_enabled:
end_user_id = entity_nodes[0].end_user_id end_user_id = entity_nodes[0].end_user_id
new_entity_ids = [e.id for e in entity_nodes] new_entity_ids = [e.id for e in entity_nodes]
logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}") logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")
await _trigger_clustering(new_entity_ids, end_user_id) asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, config_id=config_id, llm_model_id=llm_model_id))
elif entity_nodes and not clustering_enabled:
logger.info("[Clustering] 聚类已禁用CLUSTERING_ENABLED=false跳过聚类触发")
return True return True
@@ -309,6 +315,8 @@ async def save_dialog_and_statements_to_neo4j(
async def _trigger_clustering( async def _trigger_clustering(
new_entity_ids: List[str], new_entity_ids: List[str],
end_user_id: str, end_user_id: str,
config_id: Optional[str] = None,
llm_model_id: Optional[str] = None,
) -> None: ) -> None:
""" """
聚类触发函数,自动判断全量初始化还是增量更新。 聚类触发函数,自动判断全量初始化还是增量更新。
@@ -318,7 +326,7 @@ async def _trigger_clustering(
from app.core.memory.storage_services.clustering_engine import LabelPropagationEngine from app.core.memory.storage_services.clustering_engine import LabelPropagationEngine
logger.info(f"[Clustering] 开始聚类end_user_id={end_user_id}, 实体数={len(new_entity_ids)}") logger.info(f"[Clustering] 开始聚类end_user_id={end_user_id}, 实体数={len(new_entity_ids)}")
connector = Neo4jConnector() connector = Neo4jConnector()
engine = LabelPropagationEngine(connector) engine = LabelPropagationEngine(connector, config_id=config_id, llm_model_id=llm_model_id)
await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids) await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids)
logger.info(f"[Clustering] 聚类完成end_user_id={end_user_id}") logger.info(f"[Clustering] 聚类完成end_user_id={end_user_id}")
except Exception as e: except Exception as e: