[add] Create the attribute values of the community nodes

This commit is contained in:
lanceyq
2026-03-12 20:27:50 +08:00
parent 744ba31ba6
commit 6d8b1aede4
5 changed files with 132 additions and 8 deletions

View File

@@ -165,7 +165,9 @@ async def write(
statement_chunk_edges=all_statement_chunk_edges,
statement_entity_edges=all_statement_entity_edges,
entity_edges=all_entity_entity_edges,
connector=neo4j_connector
connector=neo4j_connector,
config_id=config_id,
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
)
if success:
logger.info("Successfully saved all data to Neo4j")

View File

@@ -19,6 +19,8 @@ logger = logging.getLogger(__name__)
# 全量迭代最大轮数,防止不收敛
MAX_ITERATIONS = 10
# 社区摘要核心实体数量
CORE_ENTITY_LIMIT = 5
def _cosine_similarity(v1: Optional[List[float]], v2: Optional[List[float]]) -> float:
@@ -62,9 +64,16 @@ def _weighted_vote(
class LabelPropagationEngine:
"""标签传播聚类引擎"""
def __init__(self, connector: Neo4jConnector):
def __init__(
self,
connector: Neo4jConnector,
config_id: Optional[str] = None,
llm_model_id: Optional[str] = None,
):
self.connector = connector
self.repo = CommunityRepository(connector)
self.config_id = config_id
self.llm_model_id = llm_model_id
# ──────────────────────────────────────────────────────────────────────────
# 公开接口
@@ -155,6 +164,10 @@ class LabelPropagationEngine:
f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区,"
f"{len(labels)} 个实体"
)
# 为所有社区生成元数据
unique_communities = list(set(labels.values()))
for cid in unique_communities:
await self._generate_community_metadata(cid, end_user_id)
async def incremental_update(
self, new_entity_ids: List[str], end_user_id: str
@@ -211,6 +224,7 @@ class LabelPropagationEngine:
logger.debug(
f"[Clustering] 新实体 {entity_id}{len(neighbors)} 个无社区邻居 → 新社区 {new_cid}"
)
await self._generate_community_metadata(new_cid, end_user_id)
else:
# 加入得票最多的社区
await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id)
@@ -222,6 +236,7 @@ class LabelPropagationEngine:
await self._evaluate_merge(
list(community_ids_in_neighbors), end_user_id
)
await self._generate_community_metadata(target_cid, end_user_id)
async def _evaluate_merge(
self, community_ids: List[str], end_user_id: str
@@ -354,6 +369,72 @@ class LabelPropagationEngine:
except Exception:
return None
async def _generate_community_metadata(
self, community_id: str, end_user_id: str
) -> None:
"""
为社区生成并写入元数据:名称、摘要、核心实体。
- core_entities按 activation_value 排序取 top-N 实体名称列表(无需 LLM
- name / summary若有 llm_model_id 则调用 LLM 生成,否则用实体名称拼接兜底
"""
try:
members = await self.repo.get_community_members(community_id, end_user_id)
if not members:
return
# 核心实体:按 activation_value 降序取 top-N
sorted_members = sorted(
members,
key=lambda m: m.get("activation_value") or 0,
reverse=True,
)
core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
all_names = [m["name"] for m in members if m.get("name")]
name = "".join(core_entities[:3]) if core_entities else community_id[:8]
summary = f"包含实体:{', '.join(all_names)}"
# 若有 LLM 配置,调用 LLM 生成更好的名称和摘要
if self.llm_model_id:
try:
from app.db import get_db_context
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
entity_list_str = "".join(all_names)
prompt = (
f"以下是一组语义相关的实体:{entity_list_str}\n\n"
f"请为这组实体所代表的主题:\n"
f"1. 起一个简洁的中文名称不超过10个字\n"
f"2. 写一句话摘要不超过50个字\n\n"
f"严格按以下格式输出,不要有其他内容:\n"
f"名称:<名称>\n摘要:<摘要>"
)
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(self.llm_model_id)
response = await llm_client.chat([{"role": "user", "content": prompt}])
text = response.content if hasattr(response, "content") else str(response)
for line in text.strip().splitlines():
if line.startswith("名称:"):
name = line[3:].strip()
elif line.startswith("摘要:"):
summary = line[3:].strip()
except Exception as e:
logger.warning(f"[Clustering] LLM 生成社区元数据失败,使用兜底值: {e}")
await self.repo.update_community_metadata(
community_id=community_id,
end_user_id=end_user_id,
name=name,
summary=summary,
core_entities=core_entities,
)
logger.debug(f"[Clustering] 社区 {community_id} 元数据已更新: name={name}")
except Exception as e:
logger.error(f"[Clustering] _generate_community_metadata failed for {community_id}: {e}")
@staticmethod
def _new_community_id() -> str:
return str(uuid.uuid4())