[changes] Community Clustering Retrieval Module
This commit is contained in:
@@ -120,7 +120,7 @@ class SearchService:
|
||||
raw_results is None if return_raw_results=False
|
||||
"""
|
||||
if include is None:
|
||||
include = ["statements", "chunks", "entities", "summaries"]
|
||||
include = ["statements", "chunks", "entities", "summaries", "communities"]
|
||||
|
||||
# Clean query
|
||||
cleaned_query = self.clean_query(question)
|
||||
@@ -146,8 +146,8 @@ class SearchService:
|
||||
if search_type == "hybrid":
|
||||
reranked_results = answer.get('reranked_results', {})
|
||||
|
||||
# Priority order: summaries first (most contextual), then statements, chunks, entities
|
||||
priority_order = ['summaries', 'statements', 'chunks', 'entities']
|
||||
# Priority order: summaries first (most contextual), then communities, statements, chunks, entities
|
||||
priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities']
|
||||
|
||||
for category in priority_order:
|
||||
if category in include and category in reranked_results:
|
||||
@@ -157,13 +157,43 @@ class SearchService:
|
||||
else:
|
||||
# For keyword or embedding search, results are directly in answer dict
|
||||
# Apply same priority order
|
||||
priority_order = ['summaries', 'statements', 'chunks', 'entities']
|
||||
priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities']
|
||||
|
||||
for category in priority_order:
|
||||
if category in include and category in answer:
|
||||
category_results = answer[category]
|
||||
if isinstance(category_results, list):
|
||||
answer_list.extend(category_results)
|
||||
|
||||
# 对命中的 community 节点展开其成员 statements
|
||||
if "communities" in include:
|
||||
community_results = (
|
||||
answer.get('reranked_results', {}).get('communities', [])
|
||||
if search_type == "hybrid"
|
||||
else answer.get('communities', [])
|
||||
)
|
||||
community_ids = [
|
||||
r.get("id") for r in community_results if r.get("id")
|
||||
]
|
||||
if community_ids and end_user_id:
|
||||
try:
|
||||
from app.repositories.neo4j.graph_search import search_graph_community_expand
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
connector = Neo4jConnector()
|
||||
expand_result = await search_graph_community_expand(
|
||||
connector=connector,
|
||||
community_ids=community_ids,
|
||||
end_user_id=end_user_id,
|
||||
limit=10,
|
||||
)
|
||||
await connector.close()
|
||||
expanded_stmts = expand_result.get("expanded_statements", [])
|
||||
if expanded_stmts:
|
||||
# 展开的 statements 插入 communities 之后、statements 之前
|
||||
answer_list.extend(expanded_stmts)
|
||||
logger.info(f"社区展开检索追加 {len(expanded_stmts)} 条 statements")
|
||||
except Exception as e:
|
||||
logger.warning(f"社区展开检索失败,跳过: {e}")
|
||||
|
||||
# Extract clean content from all results
|
||||
content_list = [
|
||||
|
||||
@@ -19,7 +19,7 @@ from app.core.memory.utils.log.logging_utils import log_time
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
|
||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, schedule_clustering_after_write
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
@@ -165,10 +165,17 @@ async def write(
|
||||
statement_chunk_edges=all_statement_chunk_edges,
|
||||
statement_entity_edges=all_statement_entity_edges,
|
||||
entity_edges=all_entity_entity_edges,
|
||||
connector=neo4j_connector
|
||||
connector=neo4j_connector,
|
||||
)
|
||||
if success:
|
||||
logger.info("Successfully saved all data to Neo4j")
|
||||
# 写入成功后,异步触发聚类(不阻塞写入响应)
|
||||
schedule_clustering_after_write(
|
||||
all_entity_nodes,
|
||||
config_id=config_id,
|
||||
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
||||
embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
|
||||
)
|
||||
break
|
||||
else:
|
||||
logger.warning("Failed to save some data to Neo4j")
|
||||
|
||||
@@ -238,7 +238,7 @@ def rerank_with_activation(
|
||||
|
||||
reranked: Dict[str, List[Dict[str, Any]]] = {}
|
||||
|
||||
for category in ["statements", "chunks", "entities", "summaries"]:
|
||||
for category in ["statements", "chunks", "entities", "summaries", "communities"]:
|
||||
keyword_items = keyword_results.get(category, [])
|
||||
embedding_items = embedding_results.get(category, [])
|
||||
|
||||
@@ -281,21 +281,23 @@ def rerank_with_activation(
|
||||
for item in items_list:
|
||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
if item_id and item_id in combined_items:
|
||||
combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value", 0)
|
||||
combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value")
|
||||
|
||||
# 步骤 4: 计算基础分数和最终分数
|
||||
for item_id, item in combined_items.items():
|
||||
bm25_norm = float(item.get("bm25_score", 0) or 0)
|
||||
emb_norm = float(item.get("embedding_score", 0) or 0)
|
||||
act_norm = float(item.get("normalized_activation_value", 0) or 0)
|
||||
# normalized_activation_value 为 None 表示该节点无激活值,保留 None 语义
|
||||
raw_act_norm = item.get("normalized_activation_value")
|
||||
act_norm = float(raw_act_norm) if raw_act_norm is not None else None
|
||||
|
||||
# 第一阶段:只考虑内容相关性(BM25 + Embedding)
|
||||
# alpha 控制 BM25 权重,(1-alpha) 控制 Embedding 权重
|
||||
content_score = alpha * bm25_norm + (1 - alpha) * emb_norm
|
||||
base_score = content_score # 第一阶段用内容分数
|
||||
|
||||
# 存储激活度分数供第二阶段使用
|
||||
item["activation_score"] = act_norm
|
||||
# 存储激活度分数供第二阶段使用(None 表示无激活值,不参与激活值排序)
|
||||
item["activation_score"] = act_norm # 可能为 None
|
||||
item["content_score"] = content_score
|
||||
item["base_score"] = base_score
|
||||
|
||||
|
||||
@@ -20,6 +20,9 @@ logger = logging.getLogger(__name__)
|
||||
# Maximum number of label-propagation rounds, guarding against non-convergence.
MAX_ITERATIONS = 10

# Number of top-ranked member entities kept as a community's core entities.
CORE_ENTITY_LIMIT = 10
|
||||
|
||||
|
||||
def _cosine_similarity(v1: Optional[List[float]], v2: Optional[List[float]]) -> float:
|
||||
"""计算两个向量的余弦相似度,任一为空则返回 0。"""
|
||||
@@ -62,9 +65,18 @@ def _weighted_vote(
|
||||
class LabelPropagationEngine:
|
||||
"""标签传播聚类引擎"""
|
||||
|
||||
def __init__(self, connector: Neo4jConnector):
|
||||
def __init__(
    self,
    connector: Neo4jConnector,
    config_id: Optional[str] = None,
    llm_model_id: Optional[str] = None,
    embedding_model_id: Optional[str] = None,
):
    """Bind the engine to a Neo4j connector and optional model settings.

    Args:
        connector: Neo4j connector; also used to build the community
            repository stored on ``self.repo``.
        config_id: Optional configuration identifier, stored as-is.
        llm_model_id: Optional LLM model id; when set, community metadata
            generation asks the LLM for a name and summary.
        embedding_model_id: Optional embedding model id; when set, the
            community summary is embedded.
    """
    # Plain configuration values first, then the connector-backed objects.
    self.config_id = config_id
    self.llm_model_id = llm_model_id
    self.embedding_model_id = embedding_model_id
    self.connector = connector
    self.repo = CommunityRepository(connector)
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# 公开接口
|
||||
@@ -94,58 +106,110 @@ class LabelPropagationEngine:
|
||||
|
||||
async def full_clustering(self, end_user_id: str, batch_size: int = 2000) -> None:
    """Run full label-propagation clustering for one user, batch by batch.

    Every entity starts in its own community (label == entity id). Entities
    are then processed page by page: for each page the neighbour graph is
    preloaded and up to ``MAX_ITERATIONS`` voting rounds are run. The
    ``labels`` map (entity id -> community id) is shared by all batches, so
    a neighbour that belongs to another batch contributes its latest
    in-memory label. Once all batches are done the labels are handed to
    ``_flush_labels``, communities are passed to ``_evaluate_merge``, and
    metadata is generated for every community that still has members.

    Args:
        end_user_id: User whose entities are clustered.
        batch_size: Number of entities loaded per batch (default 2000).

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    # NOTE(review): this seed query still fetches every entity up front just
    # to count them and collect ids; a count/ids-only query would lower the
    # memory peak further - confirm what GET_ALL_ENTITIES_FOR_USER returns.
    total_entities = await self.repo.get_all_entities(end_user_id)
    if not total_entities:
        logger.info(f"[Clustering] 用户 {end_user_id} 无实体,跳过全量聚类")
        return

    total_count = len(total_entities)
    batch_total = (total_count + batch_size - 1) // batch_size
    logger.info(f"[Clustering] 用户 {end_user_id} 共 {total_count} 个实体,"
                f"分批大小 {batch_size},共 {batch_total} 批")

    # Shared across batches; only ids are kept, so the map stays small.
    labels: Dict[str, str] = {e["id"]: e["id"] for e in total_entities}
    # Drop the full list; entities are re-loaded page by page below.
    del total_entities

    for batch_start in range(0, total_count, batch_size):
        batch_no = batch_start // batch_size + 1
        batch_entities = await self.repo.get_entities_page(
            end_user_id, skip=batch_start, limit=batch_size
        )
        if not batch_entities:
            break

        batch_ids = [e["id"] for e in batch_entities]
        # Only the current batch's embeddings are kept in memory.
        batch_embeddings: Dict[str, Optional[List[float]]] = {
            e["id"]: e.get("name_embedding") for e in batch_entities
        }
        # Entities written after the seed query can show up in a page; give
        # them their own label instead of failing with KeyError below.
        for eid in batch_ids:
            labels.setdefault(eid, eid)

        logger.info(
            f"[Clustering] 批次 {batch_no}:"
            f"加载 {len(batch_entities)} 个实体的邻居图..."
        )
        neighbors_cache = await self.repo.get_entity_neighbors_for_ids(
            batch_ids, end_user_id
        )
        logger.info(f"[Clustering] 邻居预加载完成,覆盖实体数: {len(neighbors_cache)}")

        for iteration in range(MAX_ITERATIONS):
            changed = 0
            for eid in batch_ids:
                neighbors = neighbors_cache.get(eid, [])

                # Inject the newest in-memory label of each neighbour (it may
                # belong to another batch); fall back to the stored value.
                enriched = []
                for nb in neighbors:
                    nb_copy = dict(nb)
                    nb_copy["community_id"] = labels.get(nb["id"], nb.get("community_id"))
                    enriched.append(nb_copy)

                new_label = _weighted_vote(enriched, batch_embeddings.get(eid))
                if new_label and new_label != labels[eid]:
                    labels[eid] = new_label
                    changed += 1

            logger.info(
                f"[Clustering] 批次 {batch_no} "
                f"迭代 {iteration + 1}/{MAX_ITERATIONS},标签变化数: {changed}"
            )
            if changed == 0:
                logger.info("[Clustering] 标签已收敛,提前结束本批迭代")
                break

        # Release this batch's large objects before loading the next page.
        del neighbors_cache, batch_embeddings, batch_entities

    # All batches done: persist the labels in one go.
    await self._flush_labels(labels, end_user_id)

    all_community_ids = list(set(labels.values()))
    pre_merge_count = len(all_community_ids)
    logger.info(
        f"[Clustering] 全量迭代完成,共 {pre_merge_count} 个社区,"
        f"{len(labels)} 个实体,开始后处理合并"
    )

    await self._evaluate_merge(all_community_ids, end_user_id)

    logger.info(
        f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区,"
        f"{len(labels)} 个实体"
    )

    # Re-read entities to find the communities that still have members,
    # then generate metadata for each of them.
    surviving_communities = await self.repo.get_all_entities(end_user_id)
    surviving_community_ids = list({
        e.get("community_id") for e in surviving_communities
        if e.get("community_id")
    })
    logger.info(f"[Clustering] 合并后实际存活社区数: {len(surviving_community_ids)}")
    for cid in surviving_community_ids:
        await self._generate_community_metadata(cid, end_user_id)
|
||||
|
||||
async def incremental_update(
|
||||
self, new_entity_ids: List[str], end_user_id: str
|
||||
) -> None:
|
||||
@@ -306,6 +370,90 @@ class LabelPropagationEngine:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def _generate_community_metadata(
    self, community_id: str, end_user_id: str
) -> None:
    """Generate and store a community's name, summary and core entities.

    - ``core_entities``: names of the top ``CORE_ENTITY_LIMIT`` members by
      ``activation_value`` (no LLM needed).
    - ``name`` / ``summary``: produced by the LLM when ``self.llm_model_id``
      is set; otherwise (or when the LLM call fails or yields nothing
      usable) built from the member names.
    - ``summary_embedding``: computed when ``self.embedding_model_id`` is
      set; left as ``None`` on failure.

    NOTE: ranking core entities by activation value can lower result quality
    for queries about peripheral information.

    All errors are logged and swallowed; nothing is raised to the caller.

    Args:
        community_id: Community to describe.
        end_user_id: Owner of the community.
    """
    try:
        members = await self.repo.get_community_members(community_id, end_user_id)
        if not members:
            return

        # Core entities: top-N members by activation value, highest first.
        sorted_members = sorted(
            members,
            key=lambda m: m.get("activation_value") or 0,
            reverse=True,
        )
        core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
        all_names = [m["name"] for m in members if m.get("name")]

        # Fallback values, used unless the LLM returns something better.
        name = "、".join(core_entities[:3]) if core_entities else community_id[:8]
        summary = f"包含实体:{', '.join(all_names)}"

        # With an LLM configured, ask it for a better name and summary.
        if self.llm_model_id:
            try:
                from app.db import get_db_context
                from app.core.memory.utils.llm.llm_utils import MemoryClientFactory

                entity_list_str = "、".join(all_names)
                prompt = (
                    f"以下是一组语义相关的实体:{entity_list_str}\n\n"
                    f"请为这组实体所代表的主题:\n"
                    f"1. 起一个简洁的中文名称(不超过10个字)\n"
                    f"2. 写一句话摘要(不超过50个字)\n\n"
                    f"严格按以下格式输出,不要有其他内容:\n"
                    f"名称:<名称>\n摘要:<摘要>"
                )
                with get_db_context() as db:
                    factory = MemoryClientFactory(db)
                    llm_client = factory.get_llm_client(self.llm_model_id)
                    response = await llm_client.chat([{"role": "user", "content": prompt}])
                    text = response.content if hasattr(response, "content") else str(response)

                name, summary = self._parse_llm_metadata(text, name, summary)
            except Exception as e:
                logger.warning(f"[Clustering] LLM 生成社区元数据失败,使用兜底值: {e}")

        # Embed the summary so the community can be found by vector search.
        summary_embedding = None
        if self.embedding_model_id and summary:
            try:
                from app.db import get_db_context
                from app.core.memory.utils.llm.llm_utils import MemoryClientFactory

                with get_db_context() as db:
                    embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id)
                    results = await embedder.response([summary])
                    summary_embedding = results[0] if results else None
            except Exception as e:
                logger.warning(f"[Clustering] 社区 {community_id} 生成 summary_embedding 失败: {e}")

        result = await self.repo.update_community_metadata(
            community_id=community_id,
            end_user_id=end_user_id,
            name=name,
            summary=summary,
            core_entities=core_entities,
            summary_embedding=summary_embedding,
        )
        if result:
            logger.info(f"[Clustering] 社区 {community_id} 元数据写入成功: name={name}, summary={summary[:30]}...")
        else:
            logger.warning(f"[Clustering] 社区 {community_id} 元数据写入返回 False")
    except Exception as e:
        logger.error(f"[Clustering] _generate_community_metadata failed for {community_id}: {e}", exc_info=True)

@staticmethod
def _parse_llm_metadata(text: str, name: str, summary: str) -> tuple:
    """Pull the name and summary lines out of an LLM reply.

    Lines are stripped before matching and both the full-width and the ASCII
    colon are accepted after the ``名称`` / ``摘要`` labels. An empty value
    never replaces the fallback passed in.

    Args:
        text: Raw LLM reply.
        name: Fallback name, returned when no usable name line is found.
        summary: Fallback summary, returned likewise.

    Returns:
        A ``(name, summary)`` tuple.
    """
    for raw_line in (text or "").strip().splitlines():
        line = raw_line.strip()
        # Every accepted prefix is exactly three characters long.
        if line.startswith(("名称:", "名称:")):
            value = line[3:].strip()
            if value:
                name = value
        elif line.startswith(("摘要:", "摘要:")):
            value = line[3:].strip()
            if value:
                summary = value
    return name, summary
|
||||
|
||||
@staticmethod
def _new_community_id() -> str:
    """Return a new random UUID (version 4) rendered as a string."""
    return str(uuid.uuid4())
|
||||
|
||||
@@ -18,6 +18,7 @@ from app.core.logging_config import LoggingConfig, get_logger
|
||||
from app.core.response_utils import fail
|
||||
from app.core.models.scripts.loader import load_models
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.index_manager import ensure_indexes
|
||||
|
||||
# Initialize logging system
|
||||
LoggingConfig.setup_logging()
|
||||
@@ -61,9 +62,18 @@ async def lifespan(app: FastAPI):
|
||||
else:
|
||||
logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)")
|
||||
|
||||
# 确保 Neo4j 索引存在(幂等,多环境安全)
|
||||
try:
|
||||
report = await ensure_indexes()
|
||||
if report["errors"]:
|
||||
logger.warning(f"Neo4j 索引部分创建失败: {report['errors']}")
|
||||
else:
|
||||
logger.info(f"Neo4j 索引检查完成 [{report['uri']}]")
|
||||
except Exception as e:
|
||||
logger.warning(f"Neo4j 索引检查跳过(连接失败): {e}")
|
||||
|
||||
logger.info("应用程序启动完成")
|
||||
yield
|
||||
# 应用关闭事件
|
||||
logger.info("应用程序正在关闭")
|
||||
|
||||
|
||||
|
||||
@@ -13,9 +13,14 @@ from app.repositories.neo4j.cypher_queries import (
|
||||
ENTITY_LEAVE_ALL_COMMUNITIES,
|
||||
GET_ENTITY_NEIGHBORS,
|
||||
GET_ALL_ENTITIES_FOR_USER,
|
||||
GET_ENTITIES_PAGE,
|
||||
GET_COMMUNITY_MEMBERS,
|
||||
GET_ALL_COMMUNITY_MEMBERS_BATCH,
|
||||
GET_ALL_ENTITY_NEIGHBORS_BATCH,
|
||||
GET_ENTITY_NEIGHBORS_BATCH_FOR_IDS,
|
||||
CHECK_USER_HAS_COMMUNITIES,
|
||||
UPDATE_COMMUNITY_MEMBER_COUNT,
|
||||
UPDATE_COMMUNITY_METADATA,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -87,6 +92,41 @@ class CommunityRepository:
|
||||
logger.error(f"get_all_entities failed: {e}")
|
||||
return []
|
||||
|
||||
async def get_entities_page(
    self, end_user_id: str, skip: int, limit: int
) -> List[Dict]:
    """Fetch one page of a user's entities for batched full clustering.

    Runs ``GET_ENTITIES_PAGE`` with the given offset and page size. Any
    failure is logged and reported as an empty list.

    Args:
        end_user_id: Owner of the entities.
        skip: Number of entities to skip.
        limit: Maximum number of entities to return.

    Returns:
        The query result rows, or ``[]`` when the query fails.
    """
    query_params = {"end_user_id": end_user_id, "skip": skip, "limit": limit}
    try:
        page = await self.connector.execute_query(GET_ENTITIES_PAGE, **query_params)
    except Exception as e:
        logger.error(f"get_entities_page failed: {e}")
        return []
    return page
|
||||
|
||||
async def get_entity_neighbors_for_ids(
    self, entity_ids: List[str], end_user_id: str
) -> Dict[str, List[Dict]]:
    """Load the neighbours of several entities with a single query.

    Runs ``GET_ENTITY_NEIGHBORS_BATCH_FOR_IDS`` and groups the rows by their
    ``entity_id`` column; every other column of a row forms the neighbour
    record. Any failure is logged and reported as an empty dict.

    Args:
        entity_ids: Ids of the entities whose neighbours are wanted.
        end_user_id: Owner of the entities.

    Returns:
        ``{entity_id: [neighbour, ...]}``, or ``{}`` when the query fails.
    """
    grouped: Dict[str, List[Dict]] = {}
    try:
        records = await self.connector.execute_query(
            GET_ENTITY_NEIGHBORS_BATCH_FOR_IDS,
            entity_ids=entity_ids,
            end_user_id=end_user_id,
        )
        for record in records:
            owner_id = record["entity_id"]
            neighbor_fields = {}
            for column, value in record.items():
                if column == "entity_id":
                    continue
                neighbor_fields[column] = value
            grouped.setdefault(owner_id, []).append(neighbor_fields)
    except Exception as e:
        logger.error(f"get_entity_neighbors_for_ids failed: {e}")
        return {}
    return grouped
|
||||
|
||||
async def get_community_members(
|
||||
self, community_id: str, end_user_id: str
|
||||
) -> List[Dict]:
|
||||
@@ -127,3 +167,28 @@ class CommunityRepository:
|
||||
except Exception as e:
|
||||
logger.error(f"refresh_member_count failed: {e}")
|
||||
return 0
|
||||
|
||||
async def update_community_metadata(
    self,
    community_id: str,
    end_user_id: str,
    name: str,
    summary: str,
    core_entities: List[str],
    summary_embedding: Optional[List[float]] = None,
) -> bool:
    """Write a community's name, summary, core entities and summary vector.

    Runs ``UPDATE_COMMUNITY_METADATA`` with the given values. Any failure is
    logged and reported as ``False``.

    Returns:
        ``True`` when the query returned a non-empty result, else ``False``.
    """
    metadata = {
        "community_id": community_id,
        "end_user_id": end_user_id,
        "name": name,
        "summary": summary,
        "core_entities": core_entities,
        "summary_embedding": summary_embedding,
    }
    try:
        updated = bool(
            await self.connector.execute_query(UPDATE_COMMUNITY_METADATA, **metadata)
        )
    except Exception as e:
        logger.error(f"update_community_metadata failed: {e}")
        return False
    return updated
|
||||
|
||||
@@ -1139,6 +1139,15 @@ RETURN e.id AS id, e.name AS name, e.entity_type AS entity_type,
|
||||
ORDER BY coalesce(e.activation_value, 0) DESC
|
||||
"""
|
||||
|
||||
GET_ALL_COMMUNITY_MEMBERS_BATCH = """
|
||||
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community)
|
||||
RETURN c.community_id AS community_id,
|
||||
e.id AS id, e.name AS name, e.entity_type AS entity_type,
|
||||
e.importance_score AS importance_score, e.activation_value AS activation_value,
|
||||
e.name_embedding AS name_embedding
|
||||
ORDER BY c.community_id, coalesce(e.activation_value, 0) DESC
|
||||
"""
|
||||
|
||||
CHECK_USER_HAS_COMMUNITIES = """
|
||||
MATCH (c:Community {end_user_id: $end_user_id})
|
||||
RETURN count(c) AS community_count
|
||||
@@ -1150,3 +1159,128 @@ WITH c, count(e) AS cnt
|
||||
SET c.member_count = cnt
|
||||
RETURN c.community_id AS community_id, cnt AS member_count
|
||||
"""
|
||||
|
||||
UPDATE_COMMUNITY_METADATA = """
|
||||
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
|
||||
SET c.name = $name,
|
||||
c.summary = $summary,
|
||||
c.core_entities = $core_entities,
|
||||
c.summary_embedding = $summary_embedding,
|
||||
c.updated_at = datetime()
|
||||
RETURN c.community_id AS community_id
|
||||
"""
|
||||
|
||||
GET_ENTITIES_PAGE = """
|
||||
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})
|
||||
OPTIONAL MATCH (e)-[:BELONGS_TO_COMMUNITY]->(c:Community)
|
||||
RETURN e.id AS id,
|
||||
e.name AS name,
|
||||
e.name_embedding AS name_embedding,
|
||||
e.activation_value AS activation_value,
|
||||
CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id
|
||||
ORDER BY e.id
|
||||
SKIP $skip LIMIT $limit
|
||||
"""
|
||||
|
||||
GET_ENTITY_NEIGHBORS_BATCH_FOR_IDS = """
|
||||
// 批量拉取指定实体列表的邻居(用于分批全量聚类)
|
||||
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})
|
||||
WHERE e.id IN $entity_ids
|
||||
OPTIONAL MATCH (e)-[:EXTRACTED_RELATIONSHIP]-(nb1:ExtractedEntity {end_user_id: $end_user_id})
|
||||
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
|
||||
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(nb2:ExtractedEntity {end_user_id: $end_user_id})
|
||||
WHERE nb2.id <> e.id
|
||||
WITH e, collect(DISTINCT nb1) + collect(DISTINCT nb2) AS all_neighbors
|
||||
UNWIND all_neighbors AS nb
|
||||
WITH e, nb WHERE nb IS NOT NULL
|
||||
OPTIONAL MATCH (nb)-[:BELONGS_TO_COMMUNITY]->(c:Community)
|
||||
RETURN DISTINCT
|
||||
e.id AS entity_id,
|
||||
nb.id AS id,
|
||||
nb.name AS name,
|
||||
nb.name_embedding AS name_embedding,
|
||||
nb.activation_value AS activation_value,
|
||||
CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id
|
||||
"""
|
||||
|
||||
GET_ALL_ENTITY_NEIGHBORS_BATCH = """
|
||||
// 批量拉取某用户下所有实体的邻居(用于全量聚类预加载)
|
||||
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})
|
||||
|
||||
// 来源一:直接关系邻居
|
||||
OPTIONAL MATCH (e)-[:EXTRACTED_RELATIONSHIP]-(nb1:ExtractedEntity {end_user_id: $end_user_id})
|
||||
|
||||
// 来源二:同 Statement 共现邻居
|
||||
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
|
||||
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(nb2:ExtractedEntity {end_user_id: $end_user_id})
|
||||
WHERE nb2.id <> e.id
|
||||
|
||||
WITH e, collect(DISTINCT nb1) + collect(DISTINCT nb2) AS all_neighbors
|
||||
UNWIND all_neighbors AS nb
|
||||
WITH e, nb WHERE nb IS NOT NULL
|
||||
OPTIONAL MATCH (nb)-[:BELONGS_TO_COMMUNITY]->(c:Community)
|
||||
RETURN DISTINCT
|
||||
e.id AS entity_id,
|
||||
nb.id AS id,
|
||||
nb.name AS name,
|
||||
nb.name_embedding AS name_embedding,
|
||||
nb.activation_value AS activation_value,
|
||||
CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id
|
||||
"""
|
||||
|
||||
|
||||
# Community keyword search: queries the "communitiesFulltext" full-text index
# with $q, optionally restricted to one user ($end_user_id may be null), and
# returns the top $limit communities by full-text score. The community
# summary is exposed under the "content" column.
SEARCH_COMMUNITIES_BY_KEYWORD = """
CALL db.index.fulltext.queryNodes("communitiesFulltext", $q) YIELD node AS c, score
WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
RETURN c.community_id AS id,
c.name AS name,
c.summary AS content,
c.core_entities AS core_entities,
c.member_count AS member_count,
c.end_user_id AS end_user_id,
c.updated_at AS updated_at,
score
ORDER BY score DESC
LIMIT $limit
"""
|
||||
|
||||
# Community vector search.
# Embedding-based search over Community.summary_embedding through the
# 'community_summary_embedding_index' vector index. The index is asked for
# $limit * 100 candidates; the user filter is applied afterwards and the top
# $limit rows by score are returned.
COMMUNITY_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('community_summary_embedding_index', $limit * 100, $embedding)
YIELD node AS c, score
WHERE c.summary_embedding IS NOT NULL
AND ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
RETURN c.community_id AS id,
c.name AS name,
c.summary AS content,
c.core_entities AS core_entities,
c.member_count AS member_count,
c.end_user_id AS end_user_id,
c.updated_at AS updated_at,
score
ORDER BY score DESC
LIMIT $limit
"""
|
||||
|
||||
# Community expansion search.
# Once a community is hit, pull the Statement nodes referenced by its member
# entities (topic -> detail, two-level retrieval).
# - The community and its member entities are scoped to $end_user_id, like
#   the other community queries, so a community id alone never crosses users.
# - Rows are grouped per statement, so a statement that references several
#   member entities is returned once and LIMIT counts distinct statements;
#   source_entity is one of the referencing member names.
EXPAND_COMMUNITY_STATEMENTS = """
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c)
MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
WHERE s.end_user_id = $end_user_id
WITH s, c, collect(DISTINCT e.name) AS source_entities
RETURN s.statement AS statement,
s.id AS id,
s.end_user_id AS end_user_id,
s.created_at AS created_at,
s.valid_at AS valid_at,
s.invalid_at AS invalid_at,
COALESCE(s.activation_value, s.importance_score, 0.5) AS activation_value,
COALESCE(s.importance_score, 0.5) AS importance_score,
source_entities[0] AS source_entity,
c.name AS community_name
ORDER BY COALESCE(s.activation_value, 0) DESC
LIMIT $limit
"""
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import asyncio
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
@@ -156,10 +156,13 @@ async def save_dialog_and_statements_to_neo4j(
|
||||
entity_edges: List[EntityEntityEdge],
|
||||
statement_chunk_edges: List[StatementChunkEdge],
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
connector: Neo4jConnector
|
||||
connector: Neo4jConnector,
|
||||
) -> bool:
|
||||
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
|
||||
|
||||
只负责数据写入,不触发聚类。聚类由调用方在写入成功后通过
|
||||
schedule_clustering_after_write() 显式触发。
|
||||
|
||||
Args:
|
||||
dialogue_nodes: List of DialogueNode objects to save
|
||||
chunk_nodes: List of ChunkNode objects to save
|
||||
@@ -290,13 +293,6 @@ async def save_dialog_and_statements_to_neo4j(
|
||||
logger.info("Transaction completed. Summary: %s", summary)
|
||||
logger.debug("Full transaction results: %r", results)
|
||||
|
||||
# 写入成功后,触发聚类
|
||||
if entity_nodes:
|
||||
end_user_id = entity_nodes[0].end_user_id
|
||||
new_entity_ids = [e.id for e in entity_nodes]
|
||||
logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")
|
||||
await _trigger_clustering(new_entity_ids, end_user_id)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
@@ -306,9 +302,38 @@ async def save_dialog_and_statements_to_neo4j(
|
||||
return False
|
||||
|
||||
|
||||
# Strong references to in-flight clustering tasks. The event loop only keeps
# weak references to tasks, so a task nobody else references may be
# garbage-collected before it finishes.
_PENDING_CLUSTERING_TASKS = set()


def schedule_clustering_after_write(
    entity_nodes: List,
    config_id: Optional[str] = None,
    llm_model_id: Optional[str] = None,
    embedding_model_id: Optional[str] = None,
) -> None:
    """Schedule the background clustering task after a successful Neo4j write.

    Clustering can be switched off with the environment variable
    ``CLUSTERING_ENABLED=false`` (useful for benchmark comparisons). The task
    is created on the running event loop and is not awaited, so the write
    response is not blocked. If no event loop is running, a warning is logged
    and clustering is skipped.

    Args:
        entity_nodes: Entities just written; the first one supplies the
            ``end_user_id`` and all of them supply the new entity ids.
        config_id: Optional configuration id forwarded to the clustering run.
        llm_model_id: Optional LLM model id forwarded to the clustering run.
        embedding_model_id: Optional embedding model id forwarded likewise.
    """
    # Local import keeps this function self-contained regardless of the
    # module's top-level imports.
    import os

    if not entity_nodes:
        return

    clustering_enabled = os.getenv("CLUSTERING_ENABLED", "true").lower() != "false"
    if not clustering_enabled:
        logger.info("[Clustering] 聚类已禁用(CLUSTERING_ENABLED=false),跳过聚类触发")
        return

    end_user_id = entity_nodes[0].end_user_id
    new_entity_ids = [e.id for e in entity_nodes]
    logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")

    # Look the loop up before building the coroutine, so a missing loop does
    # not leave a never-awaited coroutine behind.
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        logger.warning("[Clustering] no running event loop, clustering was not scheduled")
        return

    task = loop.create_task(
        _trigger_clustering(
            new_entity_ids,
            end_user_id,
            config_id=config_id,
            llm_model_id=llm_model_id,
            embedding_model_id=embedding_model_id,
        )
    )
    # Hold a reference until the task completes, then let it go.
    _PENDING_CLUSTERING_TASKS.add(task)
    task.add_done_callback(_PENDING_CLUSTERING_TASKS.discard)
|
||||
|
||||
|
||||
async def _trigger_clustering(
|
||||
new_entity_ids: List[str],
|
||||
end_user_id: str,
|
||||
config_id: Optional[str] = None,
|
||||
llm_model_id: Optional[str] = None,
|
||||
embedding_model_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
聚类触发函数,自动判断全量初始化还是增量更新。
|
||||
@@ -318,7 +343,7 @@ async def _trigger_clustering(
|
||||
from app.core.memory.storage_services.clustering_engine import LabelPropagationEngine
|
||||
logger.info(f"[Clustering] 开始聚类,end_user_id={end_user_id}, 实体数={len(new_entity_ids)}")
|
||||
connector = Neo4jConnector()
|
||||
engine = LabelPropagationEngine(connector)
|
||||
engine = LabelPropagationEngine(connector, config_id=config_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id)
|
||||
await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids)
|
||||
logger.info(f"[Clustering] 聚类完成,end_user_id={end_user_id}")
|
||||
except Exception as e:
|
||||
|
||||
@@ -4,10 +4,13 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.repositories.neo4j.cypher_queries import (
|
||||
CHUNK_EMBEDDING_SEARCH,
|
||||
COMMUNITY_EMBEDDING_SEARCH,
|
||||
ENTITY_EMBEDDING_SEARCH,
|
||||
EXPAND_COMMUNITY_STATEMENTS,
|
||||
MEMORY_SUMMARY_EMBEDDING_SEARCH,
|
||||
SEARCH_CHUNK_BY_CHUNK_ID,
|
||||
SEARCH_CHUNKS_BY_CONTENT,
|
||||
SEARCH_COMMUNITIES_BY_KEYWORD,
|
||||
SEARCH_DIALOGUE_BY_DIALOG_ID,
|
||||
SEARCH_ENTITIES_BY_NAME,
|
||||
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
|
||||
@@ -285,6 +288,15 @@ async def search_graph(
|
||||
limit=limit,
|
||||
))
|
||||
task_keys.append("summaries")
|
||||
|
||||
if "communities" in include:
|
||||
tasks.append(connector.execute_query(
|
||||
SEARCH_COMMUNITIES_BY_KEYWORD,
|
||||
q=q,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
))
|
||||
task_keys.append("communities")
|
||||
|
||||
# Execute all queries in parallel
|
||||
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
@@ -396,6 +408,16 @@ async def search_graph_by_embedding(
|
||||
))
|
||||
task_keys.append("summaries")
|
||||
|
||||
# Communities (向量语义匹配)
|
||||
if "communities" in include:
|
||||
tasks.append(connector.execute_query(
|
||||
COMMUNITY_EMBEDDING_SEARCH,
|
||||
embedding=embedding,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
))
|
||||
task_keys.append("communities")
|
||||
|
||||
# Execute all queries in parallel
|
||||
query_start = time.time()
|
||||
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
@@ -408,6 +430,7 @@ async def search_graph_by_embedding(
|
||||
"chunks": [],
|
||||
"entities": [],
|
||||
"summaries": [],
|
||||
"communities": [],
|
||||
}
|
||||
|
||||
for key, result in zip(task_keys, task_results):
|
||||
@@ -661,6 +684,62 @@ async def search_graph_by_chunk_id(
|
||||
return {"chunks": chunks}
|
||||
|
||||
|
||||
async def search_graph_community_expand(
    connector: Neo4jConnector,
    community_ids: List[str],
    end_user_id: str,
    limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]:
    """Phase 3: community expansion search (topic -> detail retrieval).

    For every hit community, ``EXPAND_COMMUNITY_STATEMENTS`` is run to pull
    the statements linked to its member entities. The queries run
    concurrently; a failing query is logged and skipped. The combined rows
    are sorted by ``activation_value`` (highest first) and then passed
    through ``_deduplicate_results``.

    Args:
        connector: Neo4j connector.
        community_ids: Ids of the communities that were hit.
        end_user_id: User id, used for data isolation.
        limit: Maximum number of statements fetched per community.

    Returns:
        ``{"expanded_statements": [...]}``; the list is empty when
        ``community_ids`` or ``end_user_id`` is empty.
    """
    if not (community_ids and end_user_id):
        return {"expanded_statements": []}

    # One query per community, awaited together.
    pending_queries = []
    for community_id in community_ids:
        pending_queries.append(
            connector.execute_query(
                EXPAND_COMMUNITY_STATEMENTS,
                community_id=community_id,
                end_user_id=end_user_id,
                limit=limit,
            )
        )
    outcomes = await asyncio.gather(*pending_queries, return_exceptions=True)

    collected: List[Dict[str, Any]] = []
    for community_id, outcome in zip(community_ids, outcomes):
        if not isinstance(outcome, Exception):
            collected.extend(outcome)
            continue
        logger.warning(f"社区展开检索失败 community_id={community_id}: {outcome}")

    # Sort globally by activation value, then drop duplicates.
    from app.core.memory.src.search import _deduplicate_results

    def _activation(statement: Dict[str, Any]) -> float:
        return float(statement.get("activation_value") or 0)

    ranked = sorted(collected, key=_activation, reverse=True)
    ranked = _deduplicate_results(ranked)

    logger.info(f"社区展开检索完成: community_ids={community_ids}, 展开 statements={len(ranked)}")
    return {"expanded_statements": ranked}
|
||||
|
||||
|
||||
async def search_graph_by_created_at(
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
254
api/app/repositories/neo4j/index_manager.py
Normal file
254
api/app/repositories/neo4j/index_manager.py
Normal file
@@ -0,0 +1,254 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Neo4j 索引管理模块
|
||||
|
||||
负责检查和创建 Neo4j 全文索引与向量索引。
|
||||
支持多环境(通过 .env 中的 NEO4J_URI/USERNAME/PASSWORD 区分)。
|
||||
|
||||
用法:
|
||||
# 作为模块调用(应用启动时)
|
||||
from app.repositories.neo4j.index_manager import ensure_indexes
|
||||
await ensure_indexes()
|
||||
|
||||
# 作为独立脚本执行(手动建索引)
|
||||
python -m app.repositories.neo4j.index_manager
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
from app.core.config import settings
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
# 索引定义表
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
|
||||
@dataclass
class FulltextIndexDef:
    """Declarative description of one Neo4j full-text index."""

    # Index name; also the key compared against the names returned by SHOW INDEXES.
    name: str
    # Node label the index is created for.
    label: str
    # Node properties covered by the index.
    properties: List[str]
|
||||
|
||||
|
||||
@dataclass
class VectorIndexDef:
    """Declarative description of one Neo4j vector index."""

    # Index name; also the key compared against the names returned by SHOW INDEXES.
    name: str
    # Node label the index is created for.
    label: str
    # Node property holding the embedding (field name kept for compatibility,
    # even though it shadows the ``property`` builtin inside the class body).
    property: str
    # Value emitted as `vector.dimensions` in the index configuration.
    dimensions: int
    # Value emitted as `vector.similarity_function` in the index configuration.
    similarity: str = "cosine"
|
||||
|
||||
|
||||
# Full-text index catalogue: the four pre-existing indexes plus the new
# community index (the fifth retrieval source).
FULLTEXT_INDEXES: List[FulltextIndexDef] = [
    FulltextIndexDef(name="statementsFulltext", label="Statement", properties=["statement"]),
    FulltextIndexDef(name="entitiesFulltext", label="ExtractedEntity", properties=["name"]),
    FulltextIndexDef(name="chunksFulltext", label="Chunk", properties=["content"]),
    FulltextIndexDef(name="summariesFulltext", label="MemorySummary", properties=["content"]),
    FulltextIndexDef(name="communitiesFulltext", label="Community", properties=["name", "summary"]),
]
|
||||
|
||||
# Vector index catalogue; the last entry is the community summary index
# reserved for phase two.
VECTOR_INDEXES: List[VectorIndexDef] = [
    VectorIndexDef(
        name="statement_embedding_index",
        label="Statement",
        property="statement_embedding",
        dimensions=1536,
    ),
    VectorIndexDef(
        name="chunk_embedding_index",
        label="Chunk",
        property="chunk_embedding",
        dimensions=1536,
    ),
    VectorIndexDef(
        name="entity_embedding_index",
        label="ExtractedEntity",
        property="name_embedding",
        dimensions=1536,
    ),
    VectorIndexDef(
        name="summary_embedding_index",
        label="MemorySummary",
        property="summary_embedding",
        dimensions=1536,
    ),
    # Phase two: community vector index.
    VectorIndexDef(
        name="community_summary_embedding_index",
        label="Community",
        property="summary_embedding",
        dimensions=1536,
    ),
]
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
# 核心检查 / 创建逻辑
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
|
||||
async def _get_existing_indexes(connector: Neo4jConnector) -> set:
|
||||
"""查询 Neo4j 中已存在的索引名称集合"""
|
||||
rows = await connector.execute_query("SHOW INDEXES YIELD name RETURN name")
|
||||
return {row["name"] for row in rows}
|
||||
|
||||
|
||||
async def _ensure_fulltext_index(
|
||||
connector: Neo4jConnector,
|
||||
idx: FulltextIndexDef,
|
||||
existing: set,
|
||||
) -> str:
|
||||
"""检查并按需创建全文索引,返回操作状态描述"""
|
||||
if idx.name in existing:
|
||||
return f"[SKIP] 全文索引已存在: {idx.name}"
|
||||
|
||||
props = ", ".join(f"n.{p}" for p in idx.properties)
|
||||
cypher = (
|
||||
f'CREATE FULLTEXT INDEX {idx.name} IF NOT EXISTS '
|
||||
f'FOR (n:{idx.label}) ON EACH [{props}]'
|
||||
)
|
||||
await connector.execute_query(cypher)
|
||||
return f"[CREATE] 全文索引已创建: {idx.name} ({idx.label} → {idx.properties})"
|
||||
|
||||
|
||||
async def _ensure_vector_index(
|
||||
connector: Neo4jConnector,
|
||||
idx: VectorIndexDef,
|
||||
existing: set,
|
||||
) -> str:
|
||||
"""检查并按需创建向量索引,返回操作状态描述"""
|
||||
if idx.name in existing:
|
||||
return f"[SKIP] 向量索引已存在: {idx.name}"
|
||||
|
||||
cypher = (
|
||||
f"CREATE VECTOR INDEX {idx.name} IF NOT EXISTS "
|
||||
f"FOR (n:{idx.label}) ON n.{idx.property} "
|
||||
f"OPTIONS {{indexConfig: {{"
|
||||
f"`vector.dimensions`: {idx.dimensions}, "
|
||||
f"`vector.similarity_function`: '{idx.similarity}'"
|
||||
f"}}}}"
|
||||
)
|
||||
await connector.execute_query(cypher)
|
||||
return (
|
||||
f"[CREATE] 向量索引已创建: {idx.name} "
|
||||
f"({idx.label}.{idx.property}, dim={idx.dimensions})"
|
||||
)
|
||||
|
||||
|
||||
async def ensure_indexes(connector: Neo4jConnector | None = None) -> dict:
    """Check for, and create where missing, every index in the two catalogues.

    Existing index names are fetched once up front; each catalogue entry whose
    name is absent is created with an ``IF NOT EXISTS`` statement, so the
    function can be called repeatedly. A failure on one index is recorded in
    the report and does not stop the remaining ones. A failure while listing
    the existing indexes is not caught and propagates to the caller.

    Args:
        connector: Optional existing connector. When ``None`` a connector is
            created here and closed before returning; a connector passed in by
            the caller is left open.

    Returns:
        dict with the keys:
            "uri":      the configured Neo4j URI,
            "fulltext": status lines for the full-text indexes,
            "vector":   status lines for the vector indexes,
            "errors":   error lines for indexes that could not be created.
    """
    owns_connector = connector is None
    if owns_connector:
        connector = Neo4jConnector()

    report = {
        "uri": settings.NEO4J_URI,
        "fulltext": [],
        "vector": [],
        "errors": [],
    }

    # (report bucket, label used in error lines, catalogue, ensure function)
    work_plan = (
        ("fulltext", "全文索引", FULLTEXT_INDEXES, _ensure_fulltext_index),
        ("vector", "向量索引", VECTOR_INDEXES, _ensure_vector_index),
    )

    try:
        # Fetch all existing index names in a single round trip.
        known_names = await _get_existing_indexes(connector)
        logger.info(f"[IndexManager] 当前环境: {settings.NEO4J_URI}")
        logger.info(f"[IndexManager] 已有索引数量: {len(known_names)}")

        for bucket, kind_label, catalogue, ensure_one in work_plan:
            for definition in catalogue:
                try:
                    outcome = await ensure_one(connector, definition, known_names)
                    report[bucket].append(outcome)
                    logger.info(f"[IndexManager] {outcome}")
                except Exception as exc:
                    failure = f"[ERROR] {kind_label} {definition.name} 创建失败: {exc}"
                    report["errors"].append(failure)
                    logger.error(f"[IndexManager] {failure}")
    finally:
        if owns_connector:
            await connector.close()

    return report
|
||||
|
||||
|
||||
async def check_indexes(connector: Neo4jConnector | None = None) -> dict:
    """Report which catalogue indexes are present or missing, without creating any.

    Args:
        connector: Optional existing connector. When ``None`` a connector is
            created here and closed before returning.

    Returns:
        dict with the keys:
            "uri":              the configured Neo4j URI,
            "present":          sorted names of all existing indexes,
            "missing_fulltext": names of catalogue full-text indexes not found,
            "missing_vector":   names of catalogue vector indexes not found.
    """
    owns_connector = connector is None
    if owns_connector:
        connector = Neo4jConnector()

    try:
        present = await _get_existing_indexes(connector)

        absent_fulltext = []
        for definition in FULLTEXT_INDEXES:
            if definition.name not in present:
                absent_fulltext.append(definition.name)

        absent_vector = []
        for definition in VECTOR_INDEXES:
            if definition.name not in present:
                absent_vector.append(definition.name)

        return {
            "uri": settings.NEO4J_URI,
            "present": sorted(present),
            "missing_fulltext": absent_fulltext,
            "missing_vector": absent_vector,
        }
    finally:
        if owns_connector:
            await connector.close()
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
# 独立脚本入口
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
|
||||
async def _main():
    """Command-line entry point: print the index status, then create what is missing.

    Returns early when nothing is missing. Exits the process with status 1
    when at least one index could not be created.
    """
    import sys

    rule = "=" * 60
    print(f"\n{rule}")
    print("Neo4j 索引管理工具")
    print(f"环境: {settings.NEO4J_URI}")
    print(f"{rule}\n")

    # Step 1: inspect only.
    print(">>> 检查当前索引状态...\n")
    status = await check_indexes()
    missing_fulltext = status["missing_fulltext"]
    missing_vector = status["missing_vector"]

    print(f" 已存在索引数: {len(status['present'])}")
    if missing_fulltext:
        print(f" 缺失全文索引: {missing_fulltext}")
    if missing_vector:
        print(f" 缺失向量索引: {missing_vector}")

    if not (missing_fulltext or missing_vector):
        print("\n 所有索引均已存在,无需操作。")
        return

    # Step 2: create whatever is missing.
    print("\n>>> 开始创建缺失索引...\n")
    report = await ensure_indexes()

    for line in [*report["fulltext"], *report["vector"]]:
        print(f" {line}")

    if not report["errors"]:
        print("\n 全部索引处理完成。")
        return

    print("\n[!] 以下索引创建失败:")
    for failure in report["errors"]:
        print(f" {failure}")
    sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(_main())
|
||||
388
api/app/tasks.py
388
api/app/tasks.py
@@ -2416,3 +2416,391 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
||||
"elapsed_time": elapsed_time,
|
||||
"task_id": self.request.id
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
||||
@celery_app.task(
    name="app.tasks.init_implicit_emotions_for_users",
    bind=True,
    ignore_result=True,
    max_retries=0,
    acks_late=False,
    time_limit=3600,
    soft_time_limit=3300,
    # Marks this as a triggered task, as opposed to the scheduled tasks in
    # the periodic_tasks queue.
    triggered=True,
)
def init_implicit_emotions_for_users(self, end_user_ids: List[str]) -> Dict[str, Any]:
    """Triggered task: initialise implicit-memory and emotion data for users that have none.

    For every user ID the storage repository is queried first; users that
    already have a record are skipped. For the others, the implicit-memory
    profile and the emotion suggestions are generated and cached
    independently of each other. A user counts as initialised when at least
    one of the two steps succeeded, and as failed when both failed.

    According to the original notes this task is fired by the
    /dashboard/end_users endpoint, while refreshing existing users is left to
    the scheduled task update_implicit_emotions_storage.

    Args:
        end_user_ids: IDs of the users to check.

    Returns:
        On success a dict with "status" ("SUCCESS"), "initialized", "skipped",
        "failed", "elapsed_time" and "task_id". On an unexpected error a dict
        with "status" ("FAILURE"), "error", "elapsed_time" and "task_id".
    """
    start_time = time.time()

    async def _run() -> Dict[str, Any]:
        from app.repositories.implicit_emotions_storage_repository import (
            ImplicitEmotionsStorageRepository,
        )
        from app.services.emotion_analytics_service import EmotionAnalyticsService
        from app.services.implicit_memory_service import ImplicitMemoryService

        logger.info(f"开始按需初始化隐性记忆/情绪数据,候选用户数: {len(end_user_ids)}")

        initialized = 0
        failed = 0
        skipped = 0

        with get_db_context() as db:
            repo = ImplicitEmotionsStorageRepository(db)

            for end_user_id in end_user_ids:
                existing = repo.get_by_end_user_id(end_user_id)
                if existing is not None:
                    skipped += 1
                    continue

                logger.info(f"用户 {end_user_id} 无记录,开始初始化")
                implicit_ok = False
                emotion_ok = False
                try:
                    # Step 1: implicit-memory profile.
                    try:
                        implicit_service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
                        profile_data = await implicit_service.generate_complete_profile(user_id=end_user_id)
                        await implicit_service.save_profile_cache(
                            end_user_id=end_user_id, profile_data=profile_data, db=db
                        )
                        implicit_ok = True
                    except Exception as e:
                        logger.error(f"用户 {end_user_id} 隐性记忆初始化失败: {e}", exc_info=True)

                    # Step 2: emotion suggestions; attempted even if step 1 failed.
                    try:
                        emotion_service = EmotionAnalyticsService()
                        suggestions_data = await emotion_service.generate_emotion_suggestions(
                            end_user_id=end_user_id, db=db, language="zh"
                        )
                        await emotion_service.save_suggestions_cache(
                            end_user_id=end_user_id, suggestions_data=suggestions_data, db=db
                        )
                        emotion_ok = True
                    except Exception as e:
                        logger.error(f"用户 {end_user_id} 情绪建议初始化失败: {e}", exc_info=True)

                    if implicit_ok or emotion_ok:
                        initialized += 1
                    else:
                        failed += 1
                except Exception as e:
                    failed += 1
                    logger.error(f"用户 {end_user_id} 初始化异常: {e}", exc_info=True)

        logger.info(f"按需初始化完成: 初始化={initialized}, 跳过={skipped}, 失败={failed}")
        return {
            "status": "SUCCESS",
            "initialized": initialized,
            "skipped": skipped,
            "failed": failed,
        }

    try:
        loop = set_asyncio_event_loop()

        result = loop.run_until_complete(_run())
        result["elapsed_time"] = time.time() - start_time
        result["task_id"] = self.request.id
        return result
    except Exception as e:
        # The task is declared with ignore_result=True, so the returned dict
        # is not stored; the log entry is the only trace of this failure.
        logger.error(f"按需初始化隐性记忆/情绪数据任务失败: {e}", exc_info=True)
        return {
            "status": "FAILURE",
            "error": str(e),
            "elapsed_time": time.time() - start_time,
            "task_id": self.request.id,
        }
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
||||
@celery_app.task(
    name="app.tasks.init_interest_distribution_for_users",
    bind=True,
    ignore_result=True,
    max_retries=0,
    acks_late=False,
    time_limit=3600,
    soft_time_limit=3300,
)
def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[str, Any]:
    """Triggered task: build the interest-distribution cache for users that have none.

    For every user ID the cache is read first; users with a cached value are
    skipped. For the others the distribution is computed (language "zh",
    limit 5) and written to the cache with INTEREST_CACHE_EXPIRE as expiry.
    A failure for one user is logged and counted, and the loop continues.

    According to the original notes this task is fired by the
    /dashboard/end_users endpoint and the cache lives in Redis.

    Args:
        self: Bound Celery task instance.
        end_user_ids: IDs of the users to check.

    Returns:
        On success a dict with "status" ("SUCCESS"), "initialized", "skipped",
        "failed", "elapsed_time" and "task_id". On an unexpected error a dict
        with "status" ("FAILURE"), "error", "elapsed_time" and "task_id".
    """
    start_time = time.time()

    async def _run() -> Dict[str, Any]:
        from app.cache.memory.interest_memory import InterestMemoryCache, INTEREST_CACHE_EXPIRE
        from app.services.memory_agent_service import MemoryAgentService

        logger.info(f"开始按需初始化兴趣分布缓存,候选用户数: {len(end_user_ids)}")

        initialized = 0
        failed = 0
        skipped = 0
        language = "zh"

        service = MemoryAgentService()

        # NOTE(review): `db` is not referenced inside this block; confirm
        # whether the session is still required before removing it.
        with get_db_context() as db:
            for end_user_id in end_user_ids:
                # Existence check: skip users that already have cached data.
                cached = await InterestMemoryCache.get_interest_distribution(
                    end_user_id=end_user_id,
                    language=language,
                )
                if cached is not None:
                    skipped += 1
                    continue

                logger.info(f"用户 {end_user_id} 无兴趣分布缓存,开始生成")
                try:
                    result = await service.get_interest_distribution_by_user(
                        end_user_id=end_user_id,
                        limit=5,
                        language=language,
                    )
                    await InterestMemoryCache.set_interest_distribution(
                        end_user_id=end_user_id,
                        language=language,
                        data=result,
                        expire=INTEREST_CACHE_EXPIRE,
                    )
                    initialized += 1
                    logger.info(f"用户 {end_user_id} 兴趣分布缓存生成成功")
                except Exception as e:
                    failed += 1
                    logger.error(f"用户 {end_user_id} 兴趣分布缓存生成失败: {e}", exc_info=True)

        logger.info(f"兴趣分布按需初始化完成: 初始化={initialized}, 跳过={skipped}, 失败={failed}")
        return {
            "status": "SUCCESS",
            "initialized": initialized,
            "skipped": skipped,
            "failed": failed,
        }

    try:
        loop = set_asyncio_event_loop()

        result = loop.run_until_complete(_run())
        result["elapsed_time"] = time.time() - start_time
        result["task_id"] = self.request.id
        return result
    except Exception as e:
        # The task is declared with ignore_result=True, so the returned dict
        # is not stored; the log entry is the only trace of this failure.
        logger.error(f"兴趣分布按需初始化任务失败: {e}", exc_info=True)
        return {
            "status": "FAILURE",
            "error": str(e),
            "elapsed_time": time.time() - start_time,
            "task_id": self.request.id,
        }
|
||||
|
||||
|
||||
@celery_app.task(
    name="app.tasks.write_perceptual_memory",
    bind=True,
    ignore_result=True,
    max_retries=0,
    acks_late=False,
    time_limit=3600,
    soft_time_limit=3300,
)
def write_perceptual_memory(
    self,
    end_user_id: str,
    model_api_config: dict,
    file_type: str,
    file_url: str,
    file_message: dict
):
    """Generate perceptual memory for one file of one user.

    The work runs while holding a Redis lock whose key is derived from the
    MD5 digest of ``file_url``. The actual generation is delegated to
    ``MemoryPerceptualService.generate_perceptual_memory``; per the original
    description the result is persisted to PostgreSQL and Neo4j.

    Args:
        end_user_id: Identifier of the end user.
        model_api_config: Keyword arguments used to build a ``ModelInfo``.
        file_type: Type of the file.
        file_url: URL of the file; also determines the lock key.
        file_message: Details about the file to be processed.

    Returns:
        Whatever ``generate_perceptual_memory`` returns.
    """
    url_digest = hashlib.md5(file_url.encode("utf-8")).hexdigest()
    lock_key = f"perceptual:{url_digest}"

    set_asyncio_event_loop()

    with RedisLock(lock_key, redis_client=get_sync_redis_client()):
        model_info = ModelInfo(**model_api_config)
        with get_db_context() as db:
            perceptual_service = MemoryPerceptualService(db)
            generation = perceptual_service.generate_perceptual_memory(
                end_user_id,
                model_info,
                file_type,
                file_url,
                file_message,
            )
            return asyncio.run(generation)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 社区聚类补全任务(触发型)
|
||||
# =============================================================================
|
||||
|
||||
@celery_app.task(
    name="app.tasks.init_community_clustering_for_users",
    bind=True,
    ignore_result=False,
    max_retries=0,
    acks_late=False,
    time_limit=7200,  # hard timeout: 2 hours
    soft_time_limit=6900,
)
def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[str, Any]:
    """Triggered task: run full community clustering for users that have entities but no communities.

    Model IDs (LLM and embedding) are prefetched for all users in one batch;
    any user whose configuration cannot be loaded falls back to ``None`` for
    both. Each user is then processed independently: users that already have
    communities, or that have no entities, are skipped; the rest get a full
    clustering run. A failure for one user is logged and counted, and the
    loop continues.

    According to the original notes this task is fired by the
    /dashboard/end_users endpoint.

    Args:
        end_user_ids: IDs of the users to check.

    Returns:
        On success a dict with "status" ("SUCCESS"), "initialized", "skipped",
        "failed", "elapsed_time" and "task_id". On an unexpected error a dict
        with "status" ("FAILURE"), "error", "elapsed_time" and "task_id".
    """
    start_time = time.time()

    async def _run() -> Dict[str, Any]:
        from app.core.logging_config import get_logger
        from app.repositories.neo4j.community_repository import CommunityRepository
        from app.repositories.neo4j.neo4j_connector import Neo4jConnector
        from app.core.memory.storage_services.clustering_engine.label_propagation import LabelPropagationEngine

        logger = get_logger(__name__)
        logger.info(f"[CommunityCluster] 开始社区聚类补全任务,候选用户数: {len(end_user_ids)}")

        initialized = 0
        skipped = 0
        failed = 0

        connector = Neo4jConnector()
        try:
            repo = CommunityRepository(connector)

            # Prefetch the model configuration of every user in one batch.
            # Users missing from these maps resolve to None further down.
            user_llm_map: Dict[str, Optional[str]] = {}
            user_embedding_map: Dict[str, Optional[str]] = {}
            try:
                with get_db_context() as db:
                    from app.services.memory_agent_service import get_end_users_connected_configs_batch
                    from app.services.memory_config_service import MemoryConfigService

                    batch_configs = get_end_users_connected_configs_batch(end_user_ids, db)
                    for uid, cfg_info in batch_configs.items():
                        config_id = cfg_info.get("memory_config_id")
                        # Default both IDs to None; they are overwritten only
                        # when the configuration loads cleanly.
                        user_llm_map[uid] = None
                        user_embedding_map[uid] = None
                        if not config_id:
                            continue
                        try:
                            cfg = MemoryConfigService(db).load_memory_config(config_id=config_id)
                            llm_id = str(cfg.llm_model_id) if cfg.llm_model_id else None
                            embedding_id = str(cfg.embedding_model_id) if cfg.embedding_model_id else None
                        except Exception as e:
                            logger.warning(f"[CommunityCluster] 用户 {uid} 加载 LLM 配置失败,将使用 None: {e}")
                            continue
                        user_llm_map[uid] = llm_id
                        user_embedding_map[uid] = embedding_id
            except Exception as e:
                logger.warning(f"[CommunityCluster] 批量获取 LLM 配置失败,所有用户将使用 None: {e}")

            for end_user_id in end_user_ids:
                try:
                    # Skip users that already have community nodes.
                    has_communities = await repo.has_communities(end_user_id)
                    if has_communities:
                        skipped += 1
                        logger.debug(f"[CommunityCluster] 用户 {end_user_id} 已有社区节点,跳过")
                        continue

                    # Skip users without any entities to cluster.
                    entities = await repo.get_all_entities(end_user_id)
                    if not entities:
                        skipped += 1
                        logger.debug(f"[CommunityCluster] 用户 {end_user_id} 无实体节点,跳过")
                        continue

                    # Each user is clustered with their own model IDs.
                    llm_model_id = user_llm_map.get(end_user_id)
                    embedding_model_id = user_embedding_map.get(end_user_id)
                    engine = LabelPropagationEngine(
                        connector=connector,
                        llm_model_id=llm_model_id,
                        embedding_model_id=embedding_model_id,
                    )

                    logger.info(f"[CommunityCluster] 用户 {end_user_id} 有 {len(entities)} 个实体,开始全量聚类,llm_model_id={llm_model_id}")
                    await engine.full_clustering(end_user_id)
                    initialized += 1
                    logger.info(f"[CommunityCluster] 用户 {end_user_id} 聚类完成")

                except Exception as e:
                    failed += 1
                    logger.error(f"[CommunityCluster] 用户 {end_user_id} 聚类失败: {e}", exc_info=True)

        finally:
            await connector.close()

        logger.info(
            f"[CommunityCluster] 任务完成: 初始化={initialized}, 跳过={skipped}, 失败={failed}"
        )
        return {
            "status": "SUCCESS",
            "initialized": initialized,
            "skipped": skipped,
            "failed": failed,
        }

    try:
        # nest_asyncio is optional; it is applied when installed.
        try:
            import nest_asyncio
            nest_asyncio.apply()
        except ImportError:
            pass

        loop = set_asyncio_event_loop()
        result = loop.run_until_complete(_run())
        result["elapsed_time"] = time.time() - start_time
        result["task_id"] = self.request.id
        return result

    except Exception as e:
        # Log the failure with its traceback instead of only returning it.
        logger.error(f"[CommunityCluster] 任务执行失败: {e}", exc_info=True)
        return {
            "status": "FAILURE",
            "error": str(e),
            "elapsed_time": time.time() - start_time,
            "task_id": self.request.id,
        }
|
||||
|
||||
Submodule redbear-mem-benchmark updated: 8494e82498...89053e48e9
Reference in New Issue
Block a user