[add] Create the attribute values of the community nodes

This commit is contained in:
lanceyq
2026-03-12 20:27:50 +08:00
parent 744ba31ba6
commit 6d8b1aede4
5 changed files with 132 additions and 8 deletions

View File

@@ -165,7 +165,9 @@ async def write(
statement_chunk_edges=all_statement_chunk_edges,
statement_entity_edges=all_statement_entity_edges,
entity_edges=all_entity_entity_edges,
connector=neo4j_connector
connector=neo4j_connector,
config_id=config_id,
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
)
if success:
logger.info("Successfully saved all data to Neo4j")

View File

@@ -19,6 +19,8 @@ logger = logging.getLogger(__name__)
# 全量迭代最大轮数,防止不收敛
MAX_ITERATIONS = 10
# 社区摘要核心实体数量
CORE_ENTITY_LIMIT = 5
def _cosine_similarity(v1: Optional[List[float]], v2: Optional[List[float]]) -> float:
@@ -62,9 +64,16 @@ def _weighted_vote(
class LabelPropagationEngine:
"""标签传播聚类引擎"""
def __init__(self, connector: Neo4jConnector):
def __init__(
self,
connector: Neo4jConnector,
config_id: Optional[str] = None,
llm_model_id: Optional[str] = None,
):
self.connector = connector
self.repo = CommunityRepository(connector)
self.config_id = config_id
self.llm_model_id = llm_model_id
# ──────────────────────────────────────────────────────────────────────────
# 公开接口
@@ -155,6 +164,10 @@ class LabelPropagationEngine:
f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区,"
f"{len(labels)} 个实体"
)
# 为所有社区生成元数据
unique_communities = list(set(labels.values()))
for cid in unique_communities:
await self._generate_community_metadata(cid, end_user_id)
async def incremental_update(
self, new_entity_ids: List[str], end_user_id: str
@@ -211,6 +224,7 @@ class LabelPropagationEngine:
logger.debug(
f"[Clustering] 新实体 {entity_id}{len(neighbors)} 个无社区邻居 → 新社区 {new_cid}"
)
await self._generate_community_metadata(new_cid, end_user_id)
else:
# 加入得票最多的社区
await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id)
@@ -222,6 +236,7 @@ class LabelPropagationEngine:
await self._evaluate_merge(
list(community_ids_in_neighbors), end_user_id
)
await self._generate_community_metadata(target_cid, end_user_id)
async def _evaluate_merge(
self, community_ids: List[str], end_user_id: str
@@ -354,6 +369,72 @@ class LabelPropagationEngine:
except Exception:
return None
async def _generate_community_metadata(
self, community_id: str, end_user_id: str
) -> None:
"""
为社区生成并写入元数据:名称、摘要、核心实体。
- core_entities按 activation_value 排序取 top-N 实体名称列表(无需 LLM
- name / summary若有 llm_model_id 则调用 LLM 生成,否则用实体名称拼接兜底
"""
try:
members = await self.repo.get_community_members(community_id, end_user_id)
if not members:
return
# 核心实体:按 activation_value 降序取 top-N
sorted_members = sorted(
members,
key=lambda m: m.get("activation_value") or 0,
reverse=True,
)
core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
all_names = [m["name"] for m in members if m.get("name")]
name = "".join(core_entities[:3]) if core_entities else community_id[:8]
summary = f"包含实体:{', '.join(all_names)}"
# 若有 LLM 配置,调用 LLM 生成更好的名称和摘要
if self.llm_model_id:
try:
from app.db import get_db_context
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
entity_list_str = "".join(all_names)
prompt = (
f"以下是一组语义相关的实体:{entity_list_str}\n\n"
f"请为这组实体所代表的主题:\n"
f"1. 起一个简洁的中文名称不超过10个字\n"
f"2. 写一句话摘要不超过50个字\n\n"
f"严格按以下格式输出,不要有其他内容:\n"
f"名称:<名称>\n摘要:<摘要>"
)
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(self.llm_model_id)
response = await llm_client.chat([{"role": "user", "content": prompt}])
text = response.content if hasattr(response, "content") else str(response)
for line in text.strip().splitlines():
if line.startswith("名称:"):
name = line[3:].strip()
elif line.startswith("摘要:"):
summary = line[3:].strip()
except Exception as e:
logger.warning(f"[Clustering] LLM 生成社区元数据失败,使用兜底值: {e}")
await self.repo.update_community_metadata(
community_id=community_id,
end_user_id=end_user_id,
name=name,
summary=summary,
core_entities=core_entities,
)
logger.debug(f"[Clustering] 社区 {community_id} 元数据已更新: name={name}")
except Exception as e:
logger.error(f"[Clustering] _generate_community_metadata failed for {community_id}: {e}")
@staticmethod
def _new_community_id() -> str:
return str(uuid.uuid4())

View File

@@ -17,6 +17,7 @@ from app.repositories.neo4j.cypher_queries import (
GET_ALL_COMMUNITY_MEMBERS_BATCH,
CHECK_USER_HAS_COMMUNITIES,
UPDATE_COMMUNITY_MEMBER_COUNT,
UPDATE_COMMUNITY_METADATA,
)
logger = logging.getLogger(__name__)
@@ -147,3 +148,26 @@ class CommunityRepository:
except Exception as e:
logger.error(f"refresh_member_count failed: {e}")
return 0
async def update_community_metadata(
self,
community_id: str,
end_user_id: str,
name: str,
summary: str,
core_entities: List[str],
) -> bool:
"""更新社区的名称、摘要和核心实体列表。"""
try:
result = await self.connector.execute_query(
UPDATE_COMMUNITY_METADATA,
community_id=community_id,
end_user_id=end_user_id,
name=name,
summary=summary,
core_entities=core_entities,
)
return bool(result)
except Exception as e:
logger.error(f"update_community_metadata failed: {e}")
return False

View File

@@ -1150,3 +1150,12 @@ WITH c, count(e) AS cnt
SET c.member_count = cnt
RETURN c.community_id AS community_id, cnt AS member_count
"""
UPDATE_COMMUNITY_METADATA = """
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
SET c.name = $name,
c.summary = $summary,
c.core_entities = $core_entities,
c.updated_at = datetime()
RETURN c.community_id AS community_id
"""

View File

@@ -1,5 +1,6 @@
import asyncio
from typing import List
import os
from typing import List, Optional
# 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
@@ -156,7 +157,9 @@ async def save_dialog_and_statements_to_neo4j(
entity_edges: List[EntityEntityEdge],
statement_chunk_edges: List[StatementChunkEdge],
statement_entity_edges: List[StatementEntityEdge],
connector: Neo4jConnector
connector: Neo4jConnector,
config_id: Optional[str] = None,
llm_model_id: Optional[str] = None,
) -> bool:
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
@@ -290,12 +293,15 @@ async def save_dialog_and_statements_to_neo4j(
logger.info("Transaction completed. Summary: %s", summary)
logger.debug("Full transaction results: %r", results)
# 写入成功后,触发聚类
if entity_nodes:
# 写入成功后,触发聚类(可通过环境变量 CLUSTERING_ENABLED=false 禁用,用于基准测试对比)
clustering_enabled = os.getenv("CLUSTERING_ENABLED", "true").lower() != "false"
if entity_nodes and clustering_enabled:
end_user_id = entity_nodes[0].end_user_id
new_entity_ids = [e.id for e in entity_nodes]
logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")
await _trigger_clustering(new_entity_ids, end_user_id)
asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, config_id=config_id, llm_model_id=llm_model_id))
elif entity_nodes and not clustering_enabled:
logger.info("[Clustering] 聚类已禁用CLUSTERING_ENABLED=false跳过聚类触发")
return True
@@ -309,6 +315,8 @@ async def save_dialog_and_statements_to_neo4j(
async def _trigger_clustering(
new_entity_ids: List[str],
end_user_id: str,
config_id: Optional[str] = None,
llm_model_id: Optional[str] = None,
) -> None:
"""
聚类触发函数,自动判断全量初始化还是增量更新。
@@ -318,7 +326,7 @@ async def _trigger_clustering(
from app.core.memory.storage_services.clustering_engine import LabelPropagationEngine
logger.info(f"[Clustering] 开始聚类end_user_id={end_user_id}, 实体数={len(new_entity_ids)}")
connector = Neo4jConnector()
engine = LabelPropagationEngine(connector)
engine = LabelPropagationEngine(connector, config_id=config_id, llm_model_id=llm_model_id)
await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids)
logger.info(f"[Clustering] 聚类完成end_user_id={end_user_id}")
except Exception as e: