[add] Create the attribute values of the community nodes
This commit is contained in:
@@ -165,7 +165,9 @@ async def write(
|
|||||||
statement_chunk_edges=all_statement_chunk_edges,
|
statement_chunk_edges=all_statement_chunk_edges,
|
||||||
statement_entity_edges=all_statement_entity_edges,
|
statement_entity_edges=all_statement_entity_edges,
|
||||||
entity_edges=all_entity_entity_edges,
|
entity_edges=all_entity_entity_edges,
|
||||||
connector=neo4j_connector
|
connector=neo4j_connector,
|
||||||
|
config_id=config_id,
|
||||||
|
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
||||||
)
|
)
|
||||||
if success:
|
if success:
|
||||||
logger.info("Successfully saved all data to Neo4j")
|
logger.info("Successfully saved all data to Neo4j")
|
||||||
|
|||||||
@@ -19,6 +19,8 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
# 全量迭代最大轮数,防止不收敛
|
# 全量迭代最大轮数,防止不收敛
|
||||||
MAX_ITERATIONS = 10
|
MAX_ITERATIONS = 10
|
||||||
|
# 社区摘要核心实体数量
|
||||||
|
CORE_ENTITY_LIMIT = 5
|
||||||
|
|
||||||
|
|
||||||
def _cosine_similarity(v1: Optional[List[float]], v2: Optional[List[float]]) -> float:
|
def _cosine_similarity(v1: Optional[List[float]], v2: Optional[List[float]]) -> float:
|
||||||
@@ -62,9 +64,16 @@ def _weighted_vote(
|
|||||||
class LabelPropagationEngine:
|
class LabelPropagationEngine:
|
||||||
"""标签传播聚类引擎"""
|
"""标签传播聚类引擎"""
|
||||||
|
|
||||||
def __init__(
    self,
    connector: Neo4jConnector,
    config_id: Optional[str] = None,
    llm_model_id: Optional[str] = None,
):
    """Bind the engine to a Neo4j connector and optional LLM settings.

    Args:
        connector: Neo4j connector; also used to build the community
            repository held in ``self.repo``.
        config_id: Optional configuration identifier, stored unchanged.
        llm_model_id: Optional LLM model identifier, stored unchanged.
    """
    # Plain settings first, then the objects derived from the connector.
    self.config_id = config_id
    self.llm_model_id = llm_model_id
    self.connector = connector
    self.repo = CommunityRepository(connector)
|
||||||
|
|
||||||
# ──────────────────────────────────────────────────────────────────────────
|
# ──────────────────────────────────────────────────────────────────────────
|
||||||
# 公开接口
|
# 公开接口
|
||||||
@@ -155,6 +164,10 @@ class LabelPropagationEngine:
|
|||||||
f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区,"
|
f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区,"
|
||||||
f"{len(labels)} 个实体"
|
f"{len(labels)} 个实体"
|
||||||
)
|
)
|
||||||
|
# 为所有社区生成元数据
|
||||||
|
unique_communities = list(set(labels.values()))
|
||||||
|
for cid in unique_communities:
|
||||||
|
await self._generate_community_metadata(cid, end_user_id)
|
||||||
|
|
||||||
async def incremental_update(
|
async def incremental_update(
|
||||||
self, new_entity_ids: List[str], end_user_id: str
|
self, new_entity_ids: List[str], end_user_id: str
|
||||||
@@ -211,6 +224,7 @@ class LabelPropagationEngine:
|
|||||||
logger.debug(
|
logger.debug(
|
||||||
f"[Clustering] 新实体 {entity_id} 与 {len(neighbors)} 个无社区邻居 → 新社区 {new_cid}"
|
f"[Clustering] 新实体 {entity_id} 与 {len(neighbors)} 个无社区邻居 → 新社区 {new_cid}"
|
||||||
)
|
)
|
||||||
|
await self._generate_community_metadata(new_cid, end_user_id)
|
||||||
else:
|
else:
|
||||||
# 加入得票最多的社区
|
# 加入得票最多的社区
|
||||||
await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id)
|
await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id)
|
||||||
@@ -222,6 +236,7 @@ class LabelPropagationEngine:
|
|||||||
await self._evaluate_merge(
|
await self._evaluate_merge(
|
||||||
list(community_ids_in_neighbors), end_user_id
|
list(community_ids_in_neighbors), end_user_id
|
||||||
)
|
)
|
||||||
|
await self._generate_community_metadata(target_cid, end_user_id)
|
||||||
|
|
||||||
async def _evaluate_merge(
|
async def _evaluate_merge(
|
||||||
self, community_ids: List[str], end_user_id: str
|
self, community_ids: List[str], end_user_id: str
|
||||||
@@ -354,6 +369,72 @@ class LabelPropagationEngine:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def _generate_community_metadata(
    self, community_id: str, end_user_id: str
) -> None:
    """Generate and persist a community's name, summary and core entities.

    - ``core_entities``: names of the top ``CORE_ENTITY_LIMIT`` members
      ranked by ``activation_value`` in descending order (no LLM involved).
    - ``name`` / ``summary``: produced by the LLM when ``self.llm_model_id``
      is set. If the LLM call fails, or a field is missing or empty in the
      reply, the fallback built from member names is kept for that field.

    This is best-effort: every failure is logged and swallowed, so the
    caller is never interrupted by metadata generation.

    Args:
        community_id: Identifier of the community to describe.
        end_user_id: Identifier of the end user owning the community.
    """
    try:
        members = await self.repo.get_community_members(community_id, end_user_id)
        if not members:
            return

        # Core entities: top-N members by activation_value, highest first.
        # A missing or None activation_value ranks as 0.
        ranked_members = sorted(
            members,
            key=lambda m: m.get("activation_value") or 0,
            reverse=True,
        )
        core_entities = [
            m["name"] for m in ranked_members[:CORE_ENTITY_LIMIT] if m.get("name")
        ]
        all_names = [m["name"] for m in members if m.get("name")]

        # Fallback values, used whenever the LLM is unavailable or unhelpful.
        name = "、".join(core_entities[:3]) if core_entities else community_id[:8]
        summary = f"包含实体:{', '.join(all_names)}"

        # With an LLM configured, try to obtain a better name and summary.
        if self.llm_model_id:
            try:
                from app.db import get_db_context
                from app.core.memory.utils.llm.llm_utils import MemoryClientFactory

                entity_list_str = "、".join(all_names)
                prompt = (
                    f"以下是一组语义相关的实体:{entity_list_str}\n\n"
                    "请为这组实体所代表的主题:\n"
                    "1. 起一个简洁的中文名称(不超过10个字)\n"
                    "2. 写一句话摘要(不超过50个字)\n\n"
                    "严格按以下格式输出,不要有其他内容:\n"
                    "名称:<名称>\n摘要:<摘要>"
                )
                with get_db_context() as db:
                    factory = MemoryClientFactory(db)
                    llm_client = factory.get_llm_client(self.llm_model_id)
                    response = await llm_client.chat([{"role": "user", "content": prompt}])
                text = response.content if hasattr(response, "content") else str(response)

                # Models do not always follow the format exactly: tolerate
                # leading whitespace and an ASCII colon, and never let an
                # empty value overwrite the fallback.
                for raw_line in (text or "").splitlines():
                    line = raw_line.strip()
                    if line.startswith(("名称:", "名称:")):
                        parsed_name = line[3:].strip()
                        if parsed_name:
                            name = parsed_name
                    elif line.startswith(("摘要:", "摘要:")):
                        parsed_summary = line[3:].strip()
                        if parsed_summary:
                            summary = parsed_summary
            except Exception as e:
                logger.warning(f"[Clustering] LLM 生成社区元数据失败,使用兜底值: {e}")

        await self.repo.update_community_metadata(
            community_id=community_id,
            end_user_id=end_user_id,
            name=name,
            summary=summary,
            core_entities=core_entities,
        )
        logger.debug(f"[Clustering] 社区 {community_id} 元数据已更新: name={name}")
    except Exception as e:
        # Keep the traceback: this handler hides every failure from the caller.
        logger.error(
            f"[Clustering] _generate_community_metadata failed for {community_id}: {e}",
            exc_info=True,
        )
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _new_community_id() -> str:
|
def _new_community_id() -> str:
|
||||||
return str(uuid.uuid4())
|
return str(uuid.uuid4())
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from app.repositories.neo4j.cypher_queries import (
|
|||||||
GET_ALL_COMMUNITY_MEMBERS_BATCH,
|
GET_ALL_COMMUNITY_MEMBERS_BATCH,
|
||||||
CHECK_USER_HAS_COMMUNITIES,
|
CHECK_USER_HAS_COMMUNITIES,
|
||||||
UPDATE_COMMUNITY_MEMBER_COUNT,
|
UPDATE_COMMUNITY_MEMBER_COUNT,
|
||||||
|
UPDATE_COMMUNITY_METADATA,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -147,3 +148,26 @@ class CommunityRepository:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"refresh_member_count failed: {e}")
|
logger.error(f"refresh_member_count failed: {e}")
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
async def update_community_metadata(
    self,
    community_id: str,
    end_user_id: str,
    name: str,
    summary: str,
    core_entities: List[str],
) -> bool:
    """Write a community's name, summary and core-entity list.

    Runs ``UPDATE_COMMUNITY_METADATA`` through the connector.

    Args:
        community_id: Identifier of the community to update.
        end_user_id: Identifier of the end user owning the community.
        name: New community name.
        summary: New community summary.
        core_entities: Names of the community's core entities.

    Returns:
        The truthiness of the query result; ``False`` when the query raised
        (the error is logged, not propagated).
    """
    query_params = {
        "community_id": community_id,
        "end_user_id": end_user_id,
        "name": name,
        "summary": summary,
        "core_entities": core_entities,
    }
    try:
        rows = await self.connector.execute_query(
            UPDATE_COMMUNITY_METADATA, **query_params
        )
        return bool(rows)
    except Exception as e:
        logger.error(f"update_community_metadata failed: {e}")
        return False
|
||||||
|
|||||||
@@ -1150,3 +1150,12 @@ WITH c, count(e) AS cnt
|
|||||||
SET c.member_count = cnt
|
SET c.member_count = cnt
|
||||||
RETURN c.community_id AS community_id, cnt AS member_count
|
RETURN c.community_id AS community_id, cnt AS member_count
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
UPDATE_COMMUNITY_METADATA = """
|
||||||
|
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
|
||||||
|
SET c.name = $name,
|
||||||
|
c.summary = $summary,
|
||||||
|
c.core_entities = $core_entities,
|
||||||
|
c.updated_at = datetime()
|
||||||
|
RETURN c.community_id AS community_id
|
||||||
|
"""
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from typing import List
|
import os
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
# 使用新的仓储层
|
# 使用新的仓储层
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
@@ -156,7 +157,9 @@ async def save_dialog_and_statements_to_neo4j(
|
|||||||
entity_edges: List[EntityEntityEdge],
|
entity_edges: List[EntityEntityEdge],
|
||||||
statement_chunk_edges: List[StatementChunkEdge],
|
statement_chunk_edges: List[StatementChunkEdge],
|
||||||
statement_entity_edges: List[StatementEntityEdge],
|
statement_entity_edges: List[StatementEntityEdge],
|
||||||
connector: Neo4jConnector
|
connector: Neo4jConnector,
|
||||||
|
config_id: Optional[str] = None,
|
||||||
|
llm_model_id: Optional[str] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
|
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
|
||||||
|
|
||||||
@@ -290,12 +293,15 @@ async def save_dialog_and_statements_to_neo4j(
|
|||||||
logger.info("Transaction completed. Summary: %s", summary)
|
logger.info("Transaction completed. Summary: %s", summary)
|
||||||
logger.debug("Full transaction results: %r", results)
|
logger.debug("Full transaction results: %r", results)
|
||||||
|
|
||||||
# 写入成功后,触发聚类
|
# 写入成功后,触发聚类(可通过环境变量 CLUSTERING_ENABLED=false 禁用,用于基准测试对比)
|
||||||
if entity_nodes:
|
clustering_enabled = os.getenv("CLUSTERING_ENABLED", "true").lower() != "false"
|
||||||
|
if entity_nodes and clustering_enabled:
|
||||||
end_user_id = entity_nodes[0].end_user_id
|
end_user_id = entity_nodes[0].end_user_id
|
||||||
new_entity_ids = [e.id for e in entity_nodes]
|
new_entity_ids = [e.id for e in entity_nodes]
|
||||||
logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")
|
logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")
|
||||||
await _trigger_clustering(new_entity_ids, end_user_id)
|
asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, config_id=config_id, llm_model_id=llm_model_id))
|
||||||
|
elif entity_nodes and not clustering_enabled:
|
||||||
|
logger.info("[Clustering] 聚类已禁用(CLUSTERING_ENABLED=false),跳过聚类触发")
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -309,6 +315,8 @@ async def save_dialog_and_statements_to_neo4j(
|
|||||||
async def _trigger_clustering(
|
async def _trigger_clustering(
|
||||||
new_entity_ids: List[str],
|
new_entity_ids: List[str],
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
|
config_id: Optional[str] = None,
|
||||||
|
llm_model_id: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
聚类触发函数,自动判断全量初始化还是增量更新。
|
聚类触发函数,自动判断全量初始化还是增量更新。
|
||||||
@@ -318,7 +326,7 @@ async def _trigger_clustering(
|
|||||||
from app.core.memory.storage_services.clustering_engine import LabelPropagationEngine
|
from app.core.memory.storage_services.clustering_engine import LabelPropagationEngine
|
||||||
logger.info(f"[Clustering] 开始聚类,end_user_id={end_user_id}, 实体数={len(new_entity_ids)}")
|
logger.info(f"[Clustering] 开始聚类,end_user_id={end_user_id}, 实体数={len(new_entity_ids)}")
|
||||||
connector = Neo4jConnector()
|
connector = Neo4jConnector()
|
||||||
engine = LabelPropagationEngine(connector)
|
engine = LabelPropagationEngine(connector, config_id=config_id, llm_model_id=llm_model_id)
|
||||||
await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids)
|
await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids)
|
||||||
logger.info(f"[Clustering] 聚类完成,end_user_id={end_user_id}")
|
logger.info(f"[Clustering] 聚类完成,end_user_id={end_user_id}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
Reference in New Issue
Block a user