[add] Create the attribute values of the community nodes
This commit is contained in:
@@ -165,7 +165,9 @@ async def write(
|
||||
statement_chunk_edges=all_statement_chunk_edges,
|
||||
statement_entity_edges=all_statement_entity_edges,
|
||||
entity_edges=all_entity_entity_edges,
|
||||
connector=neo4j_connector
|
||||
connector=neo4j_connector,
|
||||
config_id=config_id,
|
||||
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
||||
)
|
||||
if success:
|
||||
logger.info("Successfully saved all data to Neo4j")
|
||||
|
||||
@@ -19,6 +19,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# 全量迭代最大轮数,防止不收敛
|
||||
MAX_ITERATIONS = 10
|
||||
# 社区摘要核心实体数量
|
||||
CORE_ENTITY_LIMIT = 5
|
||||
|
||||
|
||||
def _cosine_similarity(v1: Optional[List[float]], v2: Optional[List[float]]) -> float:
|
||||
@@ -62,9 +64,16 @@ def _weighted_vote(
|
||||
class LabelPropagationEngine:
|
||||
"""标签传播聚类引擎"""
|
||||
|
||||
def __init__(self, connector: Neo4jConnector):
|
||||
def __init__(
|
||||
self,
|
||||
connector: Neo4jConnector,
|
||||
config_id: Optional[str] = None,
|
||||
llm_model_id: Optional[str] = None,
|
||||
):
|
||||
self.connector = connector
|
||||
self.repo = CommunityRepository(connector)
|
||||
self.config_id = config_id
|
||||
self.llm_model_id = llm_model_id
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# 公开接口
|
||||
@@ -155,6 +164,10 @@ class LabelPropagationEngine:
|
||||
f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区,"
|
||||
f"{len(labels)} 个实体"
|
||||
)
|
||||
# 为所有社区生成元数据
|
||||
unique_communities = list(set(labels.values()))
|
||||
for cid in unique_communities:
|
||||
await self._generate_community_metadata(cid, end_user_id)
|
||||
|
||||
async def incremental_update(
|
||||
self, new_entity_ids: List[str], end_user_id: str
|
||||
@@ -211,6 +224,7 @@ class LabelPropagationEngine:
|
||||
logger.debug(
|
||||
f"[Clustering] 新实体 {entity_id} 与 {len(neighbors)} 个无社区邻居 → 新社区 {new_cid}"
|
||||
)
|
||||
await self._generate_community_metadata(new_cid, end_user_id)
|
||||
else:
|
||||
# 加入得票最多的社区
|
||||
await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id)
|
||||
@@ -222,6 +236,7 @@ class LabelPropagationEngine:
|
||||
await self._evaluate_merge(
|
||||
list(community_ids_in_neighbors), end_user_id
|
||||
)
|
||||
await self._generate_community_metadata(target_cid, end_user_id)
|
||||
|
||||
async def _evaluate_merge(
|
||||
self, community_ids: List[str], end_user_id: str
|
||||
@@ -354,6 +369,72 @@ class LabelPropagationEngine:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def _generate_community_metadata(
|
||||
self, community_id: str, end_user_id: str
|
||||
) -> None:
|
||||
"""
|
||||
为社区生成并写入元数据:名称、摘要、核心实体。
|
||||
|
||||
- core_entities:按 activation_value 排序取 top-N 实体名称列表(无需 LLM)
|
||||
- name / summary:若有 llm_model_id 则调用 LLM 生成,否则用实体名称拼接兜底
|
||||
"""
|
||||
try:
|
||||
members = await self.repo.get_community_members(community_id, end_user_id)
|
||||
if not members:
|
||||
return
|
||||
|
||||
# 核心实体:按 activation_value 降序取 top-N
|
||||
sorted_members = sorted(
|
||||
members,
|
||||
key=lambda m: m.get("activation_value") or 0,
|
||||
reverse=True,
|
||||
)
|
||||
core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
|
||||
all_names = [m["name"] for m in members if m.get("name")]
|
||||
|
||||
name = "、".join(core_entities[:3]) if core_entities else community_id[:8]
|
||||
summary = f"包含实体:{', '.join(all_names)}"
|
||||
|
||||
# 若有 LLM 配置,调用 LLM 生成更好的名称和摘要
|
||||
if self.llm_model_id:
|
||||
try:
|
||||
from app.db import get_db_context
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
|
||||
entity_list_str = "、".join(all_names)
|
||||
prompt = (
|
||||
f"以下是一组语义相关的实体:{entity_list_str}\n\n"
|
||||
f"请为这组实体所代表的主题:\n"
|
||||
f"1. 起一个简洁的中文名称(不超过10个字)\n"
|
||||
f"2. 写一句话摘要(不超过50个字)\n\n"
|
||||
f"严格按以下格式输出,不要有其他内容:\n"
|
||||
f"名称:<名称>\n摘要:<摘要>"
|
||||
)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(self.llm_model_id)
|
||||
response = await llm_client.chat([{"role": "user", "content": prompt}])
|
||||
text = response.content if hasattr(response, "content") else str(response)
|
||||
|
||||
for line in text.strip().splitlines():
|
||||
if line.startswith("名称:"):
|
||||
name = line[3:].strip()
|
||||
elif line.startswith("摘要:"):
|
||||
summary = line[3:].strip()
|
||||
except Exception as e:
|
||||
logger.warning(f"[Clustering] LLM 生成社区元数据失败,使用兜底值: {e}")
|
||||
|
||||
await self.repo.update_community_metadata(
|
||||
community_id=community_id,
|
||||
end_user_id=end_user_id,
|
||||
name=name,
|
||||
summary=summary,
|
||||
core_entities=core_entities,
|
||||
)
|
||||
logger.debug(f"[Clustering] 社区 {community_id} 元数据已更新: name={name}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Clustering] _generate_community_metadata failed for {community_id}: {e}")
|
||||
|
||||
@staticmethod
|
||||
def _new_community_id() -> str:
|
||||
return str(uuid.uuid4())
|
||||
|
||||
Reference in New Issue
Block a user