Merge branch 'develop' into release/v0.2.7
This commit is contained in:
@@ -2,7 +2,7 @@ import uuid
|
||||
from datetime import datetime
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
from sqlalchemy import and_, desc
|
||||
from sqlalchemy import and_, desc, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_db_logger
|
||||
@@ -127,6 +127,17 @@ class MemoryPerceptualRepository:
|
||||
db_logger.error(f"Failed to query perceptual memory timeline: end_user_id={end_user_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
def get_by_url(
|
||||
self,
|
||||
file_url: str
|
||||
) -> list[MemoryPerceptualModel]:
|
||||
try:
|
||||
stmt = select(MemoryPerceptualModel).where(MemoryPerceptualModel.file_path == file_url)
|
||||
return list(self.db.execute(stmt).scalars())
|
||||
except Exception:
|
||||
db_logger.error(f"Failed to query perceptual memories by file_url: file_url={file_url}")
|
||||
raise
|
||||
|
||||
def get_by_type(
|
||||
self,
|
||||
end_user_id: uuid.UUID,
|
||||
|
||||
194
api/app/repositories/neo4j/community_repository.py
Normal file
194
api/app/repositories/neo4j/community_repository.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""Community 节点仓库
|
||||
|
||||
管理 Neo4j 中 Community 节点及 BELONGS_TO_COMMUNITY 边的 CRUD 操作。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.repositories.neo4j.cypher_queries import (
|
||||
COMMUNITY_NODE_UPSERT,
|
||||
ENTITY_JOIN_COMMUNITY,
|
||||
ENTITY_LEAVE_ALL_COMMUNITIES,
|
||||
GET_ENTITY_NEIGHBORS,
|
||||
GET_ALL_ENTITIES_FOR_USER,
|
||||
GET_COMMUNITY_MEMBERS,
|
||||
GET_ALL_COMMUNITY_MEMBERS_BATCH,
|
||||
GET_ALL_ENTITY_NEIGHBORS_BATCH,
|
||||
CHECK_USER_HAS_COMMUNITIES,
|
||||
UPDATE_COMMUNITY_MEMBER_COUNT,
|
||||
UPDATE_COMMUNITY_METADATA,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CommunityRepository:
|
||||
def __init__(self, connector: Neo4jConnector):
|
||||
self.connector = connector
|
||||
|
||||
async def upsert_community(
|
||||
self, community_id: str, end_user_id: str, member_count: int = 0
|
||||
) -> Optional[str]:
|
||||
"""创建或更新 Community 节点,返回 community_id。"""
|
||||
try:
|
||||
result = await self.connector.execute_query(
|
||||
COMMUNITY_NODE_UPSERT,
|
||||
community_id=community_id,
|
||||
end_user_id=end_user_id,
|
||||
member_count=member_count,
|
||||
)
|
||||
return result[0]["community_id"] if result else None
|
||||
except Exception as e:
|
||||
logger.error(f"upsert_community failed: {e}")
|
||||
return None
|
||||
|
||||
async def assign_entity_to_community(
|
||||
self, entity_id: str, community_id: str, end_user_id: str
|
||||
) -> bool:
|
||||
"""将实体关联到社区(先解除旧关联,再建立新关联)。"""
|
||||
try:
|
||||
await self.connector.execute_query(
|
||||
ENTITY_LEAVE_ALL_COMMUNITIES,
|
||||
entity_id=entity_id,
|
||||
end_user_id=end_user_id,
|
||||
)
|
||||
result = await self.connector.execute_query(
|
||||
ENTITY_JOIN_COMMUNITY,
|
||||
entity_id=entity_id,
|
||||
community_id=community_id,
|
||||
end_user_id=end_user_id,
|
||||
)
|
||||
return bool(result)
|
||||
except Exception as e:
|
||||
logger.error(f"assign_entity_to_community failed: {e}")
|
||||
return False
|
||||
|
||||
async def get_entity_neighbors(
|
||||
self, entity_id: str, end_user_id: str
|
||||
) -> List[Dict]:
|
||||
"""查询实体的直接邻居及其社区归属。"""
|
||||
try:
|
||||
return await self.connector.execute_query(
|
||||
GET_ENTITY_NEIGHBORS,
|
||||
entity_id=entity_id,
|
||||
end_user_id=end_user_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"get_entity_neighbors failed: {e}")
|
||||
return []
|
||||
|
||||
async def get_all_entity_neighbors_batch(
|
||||
self, end_user_id: str
|
||||
) -> Dict[str, List[Dict]]:
|
||||
"""一次性批量拉取该用户下所有实体的邻居,返回 {entity_id: [neighbors]} 字典。
|
||||
用于全量聚类预加载,避免每个实体单独查询。"""
|
||||
try:
|
||||
rows = await self.connector.execute_query(
|
||||
GET_ALL_ENTITY_NEIGHBORS_BATCH,
|
||||
end_user_id=end_user_id,
|
||||
)
|
||||
result: Dict[str, List[Dict]] = {}
|
||||
for row in rows:
|
||||
eid = row["entity_id"]
|
||||
neighbor = {k: v for k, v in row.items() if k != "entity_id"}
|
||||
result.setdefault(eid, []).append(neighbor)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"get_all_entity_neighbors_batch failed: {e}")
|
||||
return {}
|
||||
|
||||
async def get_all_entities(self, end_user_id: str) -> List[Dict]:
|
||||
"""拉取某用户下所有实体及其当前社区归属。"""
|
||||
try:
|
||||
return await self.connector.execute_query(
|
||||
GET_ALL_ENTITIES_FOR_USER,
|
||||
end_user_id=end_user_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"get_all_entities failed: {e}")
|
||||
return []
|
||||
|
||||
async def get_community_members(
|
||||
self, community_id: str, end_user_id: str
|
||||
) -> List[Dict]:
|
||||
"""查询社区成员列表。"""
|
||||
try:
|
||||
return await self.connector.execute_query(
|
||||
GET_COMMUNITY_MEMBERS,
|
||||
community_id=community_id,
|
||||
end_user_id=end_user_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"get_community_members failed: {e}")
|
||||
return []
|
||||
|
||||
async def get_all_community_members_batch(
|
||||
self, community_ids: List[str], end_user_id: str
|
||||
) -> Dict[str, List[Dict]]:
|
||||
"""批量查询多个社区的成员,返回 {community_id: [members]} 字典。"""
|
||||
try:
|
||||
rows = await self.connector.execute_query(
|
||||
GET_ALL_COMMUNITY_MEMBERS_BATCH,
|
||||
community_ids=community_ids,
|
||||
end_user_id=end_user_id,
|
||||
)
|
||||
result: Dict[str, List[Dict]] = {}
|
||||
for row in rows:
|
||||
cid = row["community_id"]
|
||||
result.setdefault(cid, []).append(row)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"get_all_community_members_batch failed: {e}")
|
||||
return {}
|
||||
|
||||
async def has_communities(self, end_user_id: str) -> bool:
|
||||
"""检查该用户是否已有 Community 节点(用于判断全量 vs 增量)。"""
|
||||
try:
|
||||
result = await self.connector.execute_query(
|
||||
CHECK_USER_HAS_COMMUNITIES,
|
||||
end_user_id=end_user_id,
|
||||
)
|
||||
return result[0]["community_count"] > 0 if result else False
|
||||
except Exception as e:
|
||||
logger.error(f"has_communities failed: {e}")
|
||||
return False
|
||||
|
||||
async def refresh_member_count(
|
||||
self, community_id: str, end_user_id: str
|
||||
) -> int:
|
||||
"""重新统计并更新社区成员数,返回最新数量。"""
|
||||
try:
|
||||
result = await self.connector.execute_query(
|
||||
UPDATE_COMMUNITY_MEMBER_COUNT,
|
||||
community_id=community_id,
|
||||
end_user_id=end_user_id,
|
||||
)
|
||||
return result[0]["member_count"] if result else 0
|
||||
except Exception as e:
|
||||
logger.error(f"refresh_member_count failed: {e}")
|
||||
return 0
|
||||
|
||||
async def update_community_metadata(
|
||||
self,
|
||||
community_id: str,
|
||||
end_user_id: str,
|
||||
name: str,
|
||||
summary: str,
|
||||
core_entities: List[str],
|
||||
) -> bool:
|
||||
"""更新社区的名称、摘要和核心实体列表。"""
|
||||
try:
|
||||
result = await self.connector.execute_query(
|
||||
UPDATE_COMMUNITY_METADATA,
|
||||
community_id=community_id,
|
||||
end_user_id=end_user_id,
|
||||
name=name,
|
||||
summary=summary,
|
||||
core_entities=core_entities,
|
||||
)
|
||||
return bool(result)
|
||||
except Exception as e:
|
||||
logger.error(f"update_community_metadata failed: {e}")
|
||||
return False
|
||||
@@ -1058,4 +1058,147 @@ Graph_Node_query = """
|
||||
3 AS priority
|
||||
LIMIT $limit
|
||||
|
||||
"""
|
||||
"""
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Community 节点 & BELONGS_TO_COMMUNITY 边
|
||||
# ============================================================
|
||||
|
||||
# ─── Community 聚类相关 Cypher 模板 ───────────────────────────────────────────
|
||||
|
||||
COMMUNITY_NODE_UPSERT = """
|
||||
MERGE (c:Community {community_id: $community_id})
|
||||
SET c.end_user_id = $end_user_id,
|
||||
c.member_count = $member_count,
|
||||
c.updated_at = datetime()
|
||||
RETURN c.community_id AS community_id
|
||||
"""
|
||||
|
||||
ENTITY_JOIN_COMMUNITY = """
|
||||
MATCH (e:ExtractedEntity {id: $entity_id, end_user_id: $end_user_id})
|
||||
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
|
||||
MERGE (e)-[:BELONGS_TO_COMMUNITY]->(c)
|
||||
SET c.updated_at = datetime()
|
||||
RETURN e.id AS entity_id, c.community_id AS community_id
|
||||
"""
|
||||
|
||||
ENTITY_LEAVE_ALL_COMMUNITIES = """
|
||||
MATCH (e:ExtractedEntity {id: $entity_id, end_user_id: $end_user_id})
|
||||
MATCH (e)-[r:BELONGS_TO_COMMUNITY]->(:Community)
|
||||
DELETE r
|
||||
"""
|
||||
|
||||
GET_ENTITY_NEIGHBORS = """
|
||||
MATCH (e:ExtractedEntity {id: $entity_id, end_user_id: $end_user_id})
|
||||
|
||||
// 来源一:直接关系邻居(EXTRACTED_RELATIONSHIP 边)
|
||||
OPTIONAL MATCH (e)-[:EXTRACTED_RELATIONSHIP]-(nb1:ExtractedEntity {end_user_id: $end_user_id})
|
||||
|
||||
// 来源二:同 Statement 共现邻居(REFERENCES_ENTITY 边)
|
||||
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
|
||||
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(nb2:ExtractedEntity {end_user_id: $end_user_id})
|
||||
WHERE nb2.id <> e.id
|
||||
|
||||
WITH collect(DISTINCT nb1) + collect(DISTINCT nb2) AS all_neighbors
|
||||
UNWIND all_neighbors AS nb
|
||||
WITH nb WHERE nb IS NOT NULL
|
||||
OPTIONAL MATCH (nb)-[:BELONGS_TO_COMMUNITY]->(c:Community)
|
||||
RETURN DISTINCT
|
||||
nb.id AS id,
|
||||
nb.name AS name,
|
||||
nb.name_embedding AS name_embedding,
|
||||
nb.activation_value AS activation_value,
|
||||
CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id
|
||||
"""
|
||||
|
||||
GET_ALL_ENTITIES_FOR_USER = """
|
||||
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})
|
||||
OPTIONAL MATCH (e)-[:BELONGS_TO_COMMUNITY]->(c:Community)
|
||||
RETURN e.id AS id,
|
||||
e.name AS name,
|
||||
e.name_embedding AS name_embedding,
|
||||
e.activation_value AS activation_value,
|
||||
CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id
|
||||
"""
|
||||
|
||||
GET_COMMUNITY_MEMBERS = """
|
||||
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community {community_id: $community_id})
|
||||
RETURN e.id AS id, e.name AS name, e.entity_type AS entity_type,
|
||||
e.importance_score AS importance_score, e.activation_value AS activation_value,
|
||||
e.name_embedding AS name_embedding
|
||||
ORDER BY coalesce(e.activation_value, 0) DESC
|
||||
"""
|
||||
|
||||
GET_ALL_COMMUNITY_MEMBERS_BATCH = """
|
||||
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community)
|
||||
WHERE c.community_id IN $community_ids
|
||||
RETURN c.community_id AS community_id,
|
||||
e.id AS id,
|
||||
e.name_embedding AS name_embedding,
|
||||
e.activation_value AS activation_value
|
||||
"""
|
||||
|
||||
CHECK_USER_HAS_COMMUNITIES = """
|
||||
MATCH (c:Community {end_user_id: $end_user_id})
|
||||
RETURN count(c) AS community_count
|
||||
"""
|
||||
|
||||
UPDATE_COMMUNITY_MEMBER_COUNT = """
|
||||
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community {community_id: $community_id})
|
||||
WITH c, count(e) AS cnt
|
||||
SET c.member_count = cnt
|
||||
RETURN c.community_id AS community_id, cnt AS member_count
|
||||
"""
|
||||
|
||||
UPDATE_COMMUNITY_METADATA = """
|
||||
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
|
||||
SET c.name = $name,
|
||||
c.summary = $summary,
|
||||
c.core_entities = $core_entities,
|
||||
c.updated_at = datetime()
|
||||
RETURN c.community_id AS community_id
|
||||
"""
|
||||
|
||||
GET_ALL_ENTITY_NEIGHBORS_BATCH = """
|
||||
// 批量拉取某用户下所有实体的邻居(用于全量聚类预加载)
|
||||
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})
|
||||
|
||||
// 来源一:直接关系邻居
|
||||
OPTIONAL MATCH (e)-[:EXTRACTED_RELATIONSHIP]-(nb1:ExtractedEntity {end_user_id: $end_user_id})
|
||||
|
||||
// 来源二:同 Statement 共现邻居
|
||||
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
|
||||
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(nb2:ExtractedEntity {end_user_id: $end_user_id})
|
||||
WHERE nb2.id <> e.id
|
||||
|
||||
WITH e, collect(DISTINCT nb1) + collect(DISTINCT nb2) AS all_neighbors
|
||||
UNWIND all_neighbors AS nb
|
||||
WITH e, nb WHERE nb IS NOT NULL
|
||||
OPTIONAL MATCH (nb)-[:BELONGS_TO_COMMUNITY]->(c:Community)
|
||||
RETURN DISTINCT
|
||||
e.id AS entity_id,
|
||||
nb.id AS id,
|
||||
nb.name AS name,
|
||||
nb.name_embedding AS name_embedding,
|
||||
nb.activation_value AS activation_value,
|
||||
CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id
|
||||
"""
|
||||
|
||||
GET_COMMUNITY_GRAPH_DATA = """
|
||||
MATCH (c:Community {end_user_id: $end_user_id})
|
||||
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[b:BELONGS_TO_COMMUNITY]->(c)
|
||||
OPTIONAL MATCH (e)-[r:EXTRACTED_RELATIONSHIP]-(e2:ExtractedEntity {end_user_id: $end_user_id})
|
||||
RETURN
|
||||
elementId(c) AS c_id,
|
||||
properties(c) AS c_props,
|
||||
elementId(e) AS e_id,
|
||||
properties(e) AS e_props,
|
||||
elementId(b) AS b_id,
|
||||
elementId(e2) AS e2_id,
|
||||
properties(e2) AS e2_props,
|
||||
elementId(r) AS r_id,
|
||||
type(r) AS r_type,
|
||||
properties(r) AS r_props,
|
||||
startNode(r) = e AS r_from_e
|
||||
"""
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from typing import List
|
||||
import asyncio
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
@@ -155,7 +157,9 @@ async def save_dialog_and_statements_to_neo4j(
|
||||
entity_edges: List[EntityEntityEdge],
|
||||
statement_chunk_edges: List[StatementChunkEdge],
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
connector: Neo4jConnector
|
||||
connector: Neo4jConnector,
|
||||
config_id: Optional[str] = None,
|
||||
llm_model_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
|
||||
|
||||
@@ -288,6 +292,10 @@ async def save_dialog_and_statements_to_neo4j(
|
||||
}
|
||||
logger.info("Transaction completed. Summary: %s", summary)
|
||||
logger.debug("Full transaction results: %r", results)
|
||||
|
||||
# 写入成功后,异步触发聚类(不阻塞写入响应)
|
||||
schedule_clustering_after_write(entity_nodes, config_id=config_id, llm_model_id=llm_model_id)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
@@ -295,3 +303,55 @@ async def save_dialog_and_statements_to_neo4j(
|
||||
print(f"Neo4j integration error: {e}")
|
||||
print("Continuing without database storage...")
|
||||
return False
|
||||
|
||||
|
||||
def schedule_clustering_after_write(
|
||||
entity_nodes: List,
|
||||
config_id: Optional[str] = None,
|
||||
llm_model_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
写入 Neo4j 成功后,调度后台聚类任务。
|
||||
|
||||
可通过环境变量 CLUSTERING_ENABLED=false 禁用(用于基准测试对比)。
|
||||
使用 asyncio.create_task 异步触发,不阻塞写入响应。
|
||||
"""
|
||||
if not entity_nodes:
|
||||
return
|
||||
|
||||
clustering_enabled = os.getenv("CLUSTERING_ENABLED", "true").lower() != "false"
|
||||
if not clustering_enabled:
|
||||
logger.info("[Clustering] 聚类已禁用(CLUSTERING_ENABLED=false),跳过聚类触发")
|
||||
return
|
||||
|
||||
end_user_id = entity_nodes[0].end_user_id
|
||||
new_entity_ids = [e.id for e in entity_nodes]
|
||||
logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")
|
||||
asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, config_id=config_id, llm_model_id=llm_model_id))
|
||||
|
||||
|
||||
async def _trigger_clustering(
|
||||
new_entity_ids: List[str],
|
||||
end_user_id: str,
|
||||
config_id: Optional[str] = None,
|
||||
llm_model_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
聚类触发函数,自动判断全量初始化还是增量更新。
|
||||
"""
|
||||
connector = None
|
||||
try:
|
||||
from app.core.memory.storage_services.clustering_engine import LabelPropagationEngine
|
||||
logger.info(f"[Clustering] 开始聚类,end_user_id={end_user_id}, 实体数={len(new_entity_ids)}")
|
||||
connector = Neo4jConnector()
|
||||
engine = LabelPropagationEngine(connector, config_id=config_id, llm_model_id=llm_model_id)
|
||||
await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids)
|
||||
logger.info(f"[Clustering] 聚类完成,end_user_id={end_user_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Clustering] 聚类触发失败: {e}", exc_info=True)
|
||||
finally:
|
||||
if connector:
|
||||
try:
|
||||
await connector.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user