[changes] Batch mode for metadata creation and unified management of indexes
This commit is contained in:
@@ -7,6 +7,7 @@
|
||||
- 增量更新(incremental_update):新实体到达时,只处理新实体及其邻居
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from math import sqrt
|
||||
@@ -114,7 +115,7 @@ class LabelPropagationEngine:
|
||||
- 每批独立跑 MAX_ITERATIONS 轮 LPA,批次间通过 labels 传递社区信息
|
||||
- 所有批次完成后统一 flush 和 merge
|
||||
"""
|
||||
BATCH_SIZE = 2000 # 每批实体数,可按需调整
|
||||
BATCH_SIZE = 888 # 每批实体数,可按需调整
|
||||
|
||||
# 轻量查询:只获取总数和 ID 列表,不加载 embedding 等大字段
|
||||
total_count = await self.repo.get_entity_count(end_user_id)
|
||||
@@ -203,8 +204,7 @@ class LabelPropagationEngine:
|
||||
if e.get("community_id")
|
||||
})
|
||||
logger.info(f"[Clustering] 合并后实际存活社区数: {len(surviving_community_ids)}")
|
||||
for cid in surviving_community_ids:
|
||||
await self._generate_community_metadata(cid, end_user_id)
|
||||
await self._generate_community_metadata(surviving_community_ids, end_user_id)
|
||||
|
||||
async def incremental_update(
|
||||
self, new_entity_ids: List[str], end_user_id: str
|
||||
@@ -261,7 +261,7 @@ class LabelPropagationEngine:
|
||||
logger.debug(
|
||||
f"[Clustering] 新实体 {entity_id} 与 {len(neighbors)} 个无社区邻居 → 新社区 {new_cid}"
|
||||
)
|
||||
await self._generate_community_metadata(new_cid, end_user_id)
|
||||
await self._generate_community_metadata([new_cid], end_user_id)
|
||||
else:
|
||||
# 加入得票最多的社区
|
||||
await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id)
|
||||
@@ -273,7 +273,7 @@ class LabelPropagationEngine:
|
||||
await self._evaluate_merge(
|
||||
list(community_ids_in_neighbors), end_user_id
|
||||
)
|
||||
await self._generate_community_metadata(target_cid, end_user_id)
|
||||
await self._generate_community_metadata([target_cid], end_user_id)
|
||||
|
||||
async def _evaluate_merge(
|
||||
self, community_ids: List[str], end_user_id: str
|
||||
@@ -437,89 +437,122 @@ class LabelPropagationEngine:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _build_entity_lines(members: List[Dict]) -> List[str]:
|
||||
"""将实体列表格式化为 prompt 行,包含 name、aliases、description。"""
|
||||
lines = []
|
||||
for m in members:
|
||||
m_name = m.get("name", "")
|
||||
aliases = m.get("aliases") or []
|
||||
description = m.get("description") or ""
|
||||
aliases_str = f"(别名:{'、'.join(aliases)})" if aliases else ""
|
||||
desc_str = f":{description}" if description else ""
|
||||
lines.append(f"- {m_name}{aliases_str}{desc_str}")
|
||||
return lines
|
||||
|
||||
async def _generate_community_metadata(
|
||||
self, community_id: str, end_user_id: str
|
||||
self, community_ids: List[str], end_user_id: str
|
||||
) -> None:
|
||||
"""
|
||||
为社区生成并写入元数据:名称、摘要、核心实体。
|
||||
为一个或多个社区生成并写入元数据。
|
||||
|
||||
- core_entities:按 activation_value 排序取 top-N 实体名称列表(无需 LLM)
|
||||
- name / summary:若有 llm_model_id 则调用 LLM 生成,否则用实体名称拼接兜底
|
||||
NOTE: core_entities按照激活值高低排序,会造成对边缘信息检索返回消息质量不高。
|
||||
流程:
|
||||
1. 逐个社区调 LLM 生成 name / summary(串行)
|
||||
2. 收集所有 summary,一次性批量 embed
|
||||
3. 单个社区用 update_community_metadata,多个用 batch_update_community_metadata
|
||||
"""
|
||||
try:
|
||||
members = await self.repo.get_community_members(community_id, end_user_id)
|
||||
if not members:
|
||||
return
|
||||
if not community_ids:
|
||||
return
|
||||
|
||||
from app.db import get_db_context
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
|
||||
# --- 阶段1:并发调 LLM 生成每个社区的 name / summary ---
|
||||
async def _build_one(cid: str):
|
||||
members = await self.repo.get_community_members(cid, end_user_id)
|
||||
if not members:
|
||||
return None
|
||||
|
||||
# 核心实体:按 activation_value 降序取 top-N
|
||||
sorted_members = sorted(
|
||||
members,
|
||||
key=lambda m: m.get("activation_value") or 0,
|
||||
reverse=True,
|
||||
)
|
||||
core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
|
||||
all_names = [m["name"] for m in members if m.get("name")]
|
||||
|
||||
name = "、".join(core_entities[:3]) if core_entities else community_id[:8]
|
||||
summary = f"包含实体:{', '.join(all_names)}"
|
||||
entity_list_str = "\n".join(self._build_entity_lines(members))
|
||||
prompt = (
|
||||
f"以下是一组语义相关的实体:\n{entity_list_str}\n\n"
|
||||
f"请为这组实体所代表的主题:\n"
|
||||
f"1. 起一个简洁的中文名称(不超过10个字)\n"
|
||||
f"2. 写一句话摘要(不超过50个字)\n\n"
|
||||
f"严格按以下格式输出,不要有其他内容:\n"
|
||||
f"名称:<名称>\n摘要:<摘要>"
|
||||
)
|
||||
with get_db_context() as db:
|
||||
llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id)
|
||||
response = await llm_client.chat([{"role": "user", "content": prompt}])
|
||||
text = response.content if hasattr(response, "content") else str(response)
|
||||
|
||||
# 若有 LLM 配置,调用 LLM 生成更好的名称和摘要
|
||||
if self.llm_model_id:
|
||||
try:
|
||||
from app.db import get_db_context
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
name, summary = "", ""
|
||||
for line in text.strip().splitlines():
|
||||
if line.startswith("名称:"):
|
||||
name = line[3:].strip()
|
||||
elif line.startswith("摘要:"):
|
||||
summary = line[3:].strip()
|
||||
|
||||
entity_list_str = "、".join(all_names)
|
||||
prompt = (
|
||||
f"以下是一组语义相关的实体:{entity_list_str}\n\n"
|
||||
f"请为这组实体所代表的主题:\n"
|
||||
f"1. 起一个简洁的中文名称(不超过10个字)\n"
|
||||
f"2. 写一句话摘要(不超过50个字)\n\n"
|
||||
f"严格按以下格式输出,不要有其他内容:\n"
|
||||
f"名称:<名称>\n摘要:<摘要>"
|
||||
)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(self.llm_model_id)
|
||||
response = await llm_client.chat([{"role": "user", "content": prompt}])
|
||||
text = response.content if hasattr(response, "content") else str(response)
|
||||
return {
|
||||
"community_id": cid,
|
||||
"end_user_id": end_user_id,
|
||||
"name": name,
|
||||
"summary": summary,
|
||||
"core_entities": core_entities,
|
||||
"summary_embedding": None,
|
||||
}
|
||||
|
||||
for line in text.strip().splitlines():
|
||||
if line.startswith("名称:"):
|
||||
name = line[3:].strip()
|
||||
elif line.startswith("摘要:"):
|
||||
summary = line[3:].strip()
|
||||
except Exception as e:
|
||||
logger.warning(f"[Clustering] LLM 生成社区元数据失败,使用兜底值: {e}")
|
||||
results = await asyncio.gather(
|
||||
*[_build_one(cid) for cid in community_ids],
|
||||
return_exceptions=True,
|
||||
)
|
||||
metadata_list = []
|
||||
for cid, res in zip(community_ids, results):
|
||||
if isinstance(res, Exception):
|
||||
logger.error(f"[Clustering] 社区 {cid} 元数据准备失败: {res}", exc_info=res)
|
||||
elif res is not None:
|
||||
metadata_list.append(res)
|
||||
|
||||
# 生成 summary_embedding
|
||||
summary_embedding = None
|
||||
if self.embedding_model_id and summary:
|
||||
try:
|
||||
from app.db import get_db_context
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
with get_db_context() as db:
|
||||
embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id)
|
||||
results = await embedder.response([summary])
|
||||
summary_embedding = results[0] if results else None
|
||||
except Exception as e:
|
||||
logger.warning(f"[Clustering] 社区 {community_id} 生成 summary_embedding 失败: {e}")
|
||||
if not metadata_list:
|
||||
return
|
||||
|
||||
# --- 阶段2:批量生成 summary_embedding ---
|
||||
summaries = [m["summary"] for m in metadata_list]
|
||||
with get_db_context() as db:
|
||||
embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id)
|
||||
embeddings = await embedder.response(summaries)
|
||||
for i, meta in enumerate(metadata_list):
|
||||
meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None
|
||||
|
||||
# --- 阶段3:写入(单个 or 批量)---
|
||||
if len(metadata_list) == 1:
|
||||
m = metadata_list[0]
|
||||
result = await self.repo.update_community_metadata(
|
||||
community_id=community_id,
|
||||
end_user_id=end_user_id,
|
||||
name=name,
|
||||
summary=summary,
|
||||
core_entities=core_entities,
|
||||
summary_embedding=summary_embedding,
|
||||
community_id=m["community_id"],
|
||||
end_user_id=m["end_user_id"],
|
||||
name=m["name"],
|
||||
summary=m["summary"],
|
||||
core_entities=m["core_entities"],
|
||||
summary_embedding=m["summary_embedding"],
|
||||
)
|
||||
if result:
|
||||
logger.info(f"[Clustering] 社区 {community_id} 元数据写入成功: name={name}, summary={summary[:30]}...")
|
||||
logger.info(f"[Clustering] 社区 {m['community_id']} 元数据写入成功: name={m['name']}, summary={m['summary'][:30]}...")
|
||||
else:
|
||||
logger.warning(f"[Clustering] 社区 {community_id} 元数据写入返回 False")
|
||||
except Exception as e:
|
||||
logger.error(f"[Clustering] _generate_community_metadata failed for {community_id}: {e}", exc_info=True)
|
||||
logger.warning(f"[Clustering] 社区 {m['community_id']} 元数据写入返回 False")
|
||||
else:
|
||||
ok = await self.repo.batch_update_community_metadata(metadata_list)
|
||||
if ok:
|
||||
logger.info(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据成功")
|
||||
else:
|
||||
logger.warning(f"[Clustering] 批量写入社区元数据失败")
|
||||
|
||||
@staticmethod
|
||||
def _new_community_id() -> str:
|
||||
|
||||
@@ -18,7 +18,6 @@ from app.core.logging_config import LoggingConfig, get_logger
|
||||
from app.core.response_utils import fail
|
||||
from app.core.models.scripts.loader import load_models
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.index_manager import ensure_indexes
|
||||
|
||||
# Initialize logging system
|
||||
LoggingConfig.setup_logging()
|
||||
@@ -62,16 +61,6 @@ async def lifespan(app: FastAPI):
|
||||
else:
|
||||
logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)")
|
||||
|
||||
# 确保 Neo4j 索引存在(幂等,多环境安全)
|
||||
try:
|
||||
report = await ensure_indexes()
|
||||
if report["errors"]:
|
||||
logger.warning(f"Neo4j 索引部分创建失败: {report['errors']}")
|
||||
else:
|
||||
logger.info(f"Neo4j 索引检查完成 [{report['uri']}]")
|
||||
except Exception as e:
|
||||
logger.warning(f"Neo4j 索引检查跳过(连接失败): {e}")
|
||||
|
||||
logger.info("应用程序启动完成")
|
||||
yield
|
||||
logger.info("应用程序正在关闭")
|
||||
|
||||
@@ -23,6 +23,7 @@ from app.repositories.neo4j.cypher_queries import (
|
||||
CHECK_USER_HAS_COMMUNITIES,
|
||||
UPDATE_COMMUNITY_MEMBER_COUNT,
|
||||
UPDATE_COMMUNITY_METADATA,
|
||||
BATCH_UPDATE_COMMUNITY_METADATA,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -257,3 +258,25 @@ class CommunityRepository:
|
||||
except Exception as e:
|
||||
logger.error(f"update_community_metadata failed: {e}")
|
||||
return False
|
||||
|
||||
async def batch_update_community_metadata(
|
||||
self,
|
||||
communities: List[Dict],
|
||||
) -> bool:
|
||||
"""批量更新多个社区的元数据。
|
||||
|
||||
Args:
|
||||
communities: 每项包含 community_id, end_user_id, name, summary,
|
||||
core_entities, summary_embedding
|
||||
"""
|
||||
if not communities:
|
||||
return True
|
||||
try:
|
||||
await self.connector.execute_query(
|
||||
BATCH_UPDATE_COMMUNITY_METADATA,
|
||||
communities=communities,
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"batch_update_community_metadata failed: {e}")
|
||||
return False
|
||||
|
||||
@@ -42,6 +42,13 @@ async def create_fulltext_indexes():
|
||||
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||
""")
|
||||
print("✓ Created: summariesFulltext")
|
||||
|
||||
# 创建 Community 索引
|
||||
await connector.execute_query("""
|
||||
CREATE FULLTEXT INDEX communitiesFulltext IF NOT EXISTS FOR (c:Community) ON EACH [c.name, c.summary]
|
||||
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||
""")
|
||||
print("✓ Created: communitiesFulltext")
|
||||
|
||||
print("\nFull-text indexes created successfully with BM25 support.")
|
||||
except Exception as e:
|
||||
@@ -124,6 +131,18 @@ async def create_vector_indexes():
|
||||
}}
|
||||
""")
|
||||
print("✓ Created: dialogue_embedding_index")
|
||||
|
||||
# Community summary embedding index
|
||||
await connector.execute_query("""
|
||||
CREATE VECTOR INDEX community_summary_embedding_index IF NOT EXISTS
|
||||
FOR (c:Community)
|
||||
ON c.summary_embedding
|
||||
OPTIONS {indexConfig: {
|
||||
`vector.dimensions`: 1024,
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""")
|
||||
print("✓ Created: community_summary_embedding_index")
|
||||
|
||||
print("\nVector indexes created successfully!")
|
||||
print("\nExpected performance improvement:")
|
||||
|
||||
@@ -1136,7 +1136,8 @@ GET_COMMUNITY_MEMBERS = """
|
||||
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community {community_id: $community_id})
|
||||
RETURN e.id AS id, e.name AS name, e.entity_type AS entity_type,
|
||||
e.importance_score AS importance_score, e.activation_value AS activation_value,
|
||||
e.name_embedding AS name_embedding
|
||||
e.name_embedding AS name_embedding,
|
||||
e.aliases AS aliases, e.description AS description
|
||||
ORDER BY coalesce(e.activation_value, 0) DESC
|
||||
"""
|
||||
|
||||
@@ -1145,7 +1146,8 @@ MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(
|
||||
RETURN c.community_id AS community_id,
|
||||
e.id AS id, e.name AS name, e.entity_type AS entity_type,
|
||||
e.importance_score AS importance_score, e.activation_value AS activation_value,
|
||||
e.name_embedding AS name_embedding
|
||||
e.name_embedding AS name_embedding,
|
||||
e.aliases AS aliases, e.description AS description
|
||||
ORDER BY c.community_id, coalesce(e.activation_value, 0) DESC
|
||||
"""
|
||||
|
||||
@@ -1171,6 +1173,17 @@ SET c.name = $name,
|
||||
RETURN c.community_id AS community_id
|
||||
"""
|
||||
|
||||
BATCH_UPDATE_COMMUNITY_METADATA = """
|
||||
UNWIND $communities AS row
|
||||
MATCH (c:Community {community_id: row.community_id, end_user_id: row.end_user_id})
|
||||
SET c.name = row.name,
|
||||
c.summary = row.summary,
|
||||
c.core_entities = row.core_entities,
|
||||
c.summary_embedding = row.summary_embedding,
|
||||
c.updated_at = datetime()
|
||||
RETURN c.community_id AS community_id
|
||||
"""
|
||||
|
||||
GET_ENTITIES_PAGE = """
|
||||
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})
|
||||
OPTIONAL MATCH (e)-[:BELONGS_TO_COMMUNITY]->(c:Community)
|
||||
|
||||
@@ -1,254 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Neo4j 索引管理模块
|
||||
|
||||
负责检查和创建 Neo4j 全文索引与向量索引。
|
||||
支持多环境(通过 .env 中的 NEO4J_URI/USERNAME/PASSWORD 区分)。
|
||||
|
||||
用法:
|
||||
# 作为模块调用(应用启动时)
|
||||
from app.repositories.neo4j.index_manager import ensure_indexes
|
||||
await ensure_indexes()
|
||||
|
||||
# 作为独立脚本执行(手动建索引)
|
||||
python -m app.repositories.neo4j.index_manager
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
from app.core.config import settings
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
# 索引定义表
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
|
||||
@dataclass
|
||||
class FulltextIndexDef:
|
||||
name: str
|
||||
label: str
|
||||
properties: List[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class VectorIndexDef:
|
||||
name: str
|
||||
label: str
|
||||
property: str
|
||||
dimensions: int
|
||||
similarity: str = "cosine"
|
||||
|
||||
|
||||
# 全文索引清单(现有 + 新增 communities)
|
||||
FULLTEXT_INDEXES: List[FulltextIndexDef] = [
|
||||
FulltextIndexDef("statementsFulltext", "Statement", ["statement"]),
|
||||
FulltextIndexDef("entitiesFulltext", "ExtractedEntity", ["name"]),
|
||||
FulltextIndexDef("chunksFulltext", "Chunk", ["content"]),
|
||||
FulltextIndexDef("summariesFulltext", "MemorySummary", ["content"]),
|
||||
FulltextIndexDef("communitiesFulltext", "Community", ["name", "summary"]), # 第五检索源
|
||||
]
|
||||
|
||||
# 向量索引清单(预留 community 二期)
|
||||
VECTOR_INDEXES: List[VectorIndexDef] = [
|
||||
VectorIndexDef("statement_embedding_index", "Statement", "statement_embedding", 1536),
|
||||
VectorIndexDef("chunk_embedding_index", "Chunk", "chunk_embedding", 1536),
|
||||
VectorIndexDef("entity_embedding_index", "ExtractedEntity","name_embedding", 1536),
|
||||
VectorIndexDef("summary_embedding_index", "MemorySummary", "summary_embedding", 1536),
|
||||
# 二期:社区向量索引
|
||||
VectorIndexDef("community_summary_embedding_index", "Community", "summary_embedding", 1536),
|
||||
]
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
# 核心检查 / 创建逻辑
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
|
||||
async def _get_existing_indexes(connector: Neo4jConnector) -> set:
|
||||
"""查询 Neo4j 中已存在的索引名称集合"""
|
||||
rows = await connector.execute_query("SHOW INDEXES YIELD name RETURN name")
|
||||
return {row["name"] for row in rows}
|
||||
|
||||
|
||||
async def _ensure_fulltext_index(
|
||||
connector: Neo4jConnector,
|
||||
idx: FulltextIndexDef,
|
||||
existing: set,
|
||||
) -> str:
|
||||
"""检查并按需创建全文索引,返回操作状态描述"""
|
||||
if idx.name in existing:
|
||||
return f"[SKIP] 全文索引已存在: {idx.name}"
|
||||
|
||||
props = ", ".join(f"n.{p}" for p in idx.properties)
|
||||
cypher = (
|
||||
f'CREATE FULLTEXT INDEX {idx.name} IF NOT EXISTS '
|
||||
f'FOR (n:{idx.label}) ON EACH [{props}]'
|
||||
)
|
||||
await connector.execute_query(cypher)
|
||||
return f"[CREATE] 全文索引已创建: {idx.name} ({idx.label} → {idx.properties})"
|
||||
|
||||
|
||||
async def _ensure_vector_index(
|
||||
connector: Neo4jConnector,
|
||||
idx: VectorIndexDef,
|
||||
existing: set,
|
||||
) -> str:
|
||||
"""检查并按需创建向量索引,返回操作状态描述"""
|
||||
if idx.name in existing:
|
||||
return f"[SKIP] 向量索引已存在: {idx.name}"
|
||||
|
||||
cypher = (
|
||||
f"CREATE VECTOR INDEX {idx.name} IF NOT EXISTS "
|
||||
f"FOR (n:{idx.label}) ON n.{idx.property} "
|
||||
f"OPTIONS {{indexConfig: {{"
|
||||
f"`vector.dimensions`: {idx.dimensions}, "
|
||||
f"`vector.similarity_function`: '{idx.similarity}'"
|
||||
f"}}}}"
|
||||
)
|
||||
await connector.execute_query(cypher)
|
||||
return (
|
||||
f"[CREATE] 向量索引已创建: {idx.name} "
|
||||
f"({idx.label}.{idx.property}, dim={idx.dimensions})"
|
||||
)
|
||||
|
||||
|
||||
async def ensure_indexes(connector: Neo4jConnector | None = None) -> dict:
|
||||
"""
|
||||
检查并创建所有必要的 Neo4j 索引(幂等,可重复调用)。
|
||||
|
||||
Args:
|
||||
connector: 可选,传入已有连接器;为 None 时自动创建。
|
||||
|
||||
Returns:
|
||||
dict: {
|
||||
"uri": 当前连接的 Neo4j URI,
|
||||
"fulltext": [操作日志列表],
|
||||
"vector": [操作日志列表],
|
||||
"errors": [错误信息列表],
|
||||
}
|
||||
"""
|
||||
own_connector = connector is None
|
||||
if own_connector:
|
||||
connector = Neo4jConnector()
|
||||
|
||||
report = {
|
||||
"uri": settings.NEO4J_URI,
|
||||
"fulltext": [],
|
||||
"vector": [],
|
||||
"errors": [],
|
||||
}
|
||||
|
||||
try:
|
||||
# 一次性拉取所有已有索引名
|
||||
existing = await _get_existing_indexes(connector)
|
||||
logger.info(f"[IndexManager] 当前环境: {settings.NEO4J_URI}")
|
||||
logger.info(f"[IndexManager] 已有索引数量: {len(existing)}")
|
||||
|
||||
# 处理全文索引
|
||||
for idx in FULLTEXT_INDEXES:
|
||||
try:
|
||||
msg = await _ensure_fulltext_index(connector, idx, existing)
|
||||
report["fulltext"].append(msg)
|
||||
logger.info(f"[IndexManager] {msg}")
|
||||
except Exception as e:
|
||||
err = f"[ERROR] 全文索引 {idx.name} 创建失败: {e}"
|
||||
report["errors"].append(err)
|
||||
logger.error(f"[IndexManager] {err}")
|
||||
|
||||
# 处理向量索引
|
||||
for idx in VECTOR_INDEXES:
|
||||
try:
|
||||
msg = await _ensure_vector_index(connector, idx, existing)
|
||||
report["vector"].append(msg)
|
||||
logger.info(f"[IndexManager] {msg}")
|
||||
except Exception as e:
|
||||
err = f"[ERROR] 向量索引 {idx.name} 创建失败: {e}"
|
||||
report["errors"].append(err)
|
||||
logger.error(f"[IndexManager] {err}")
|
||||
|
||||
finally:
|
||||
if own_connector:
|
||||
await connector.close()
|
||||
|
||||
return report
|
||||
|
||||
|
||||
async def check_indexes(connector: Neo4jConnector | None = None) -> dict:
|
||||
"""
|
||||
仅检查索引状态,不创建任何索引。
|
||||
|
||||
Returns:
|
||||
dict: {
|
||||
"uri": ...,
|
||||
"present": [已存在的索引名],
|
||||
"missing_fulltext": [缺失的全文索引名],
|
||||
"missing_vector": [缺失的向量索引名],
|
||||
}
|
||||
"""
|
||||
own_connector = connector is None
|
||||
if own_connector:
|
||||
connector = Neo4jConnector()
|
||||
|
||||
try:
|
||||
existing = await _get_existing_indexes(connector)
|
||||
missing_ft = [i.name for i in FULLTEXT_INDEXES if i.name not in existing]
|
||||
missing_vec = [i.name for i in VECTOR_INDEXES if i.name not in existing]
|
||||
|
||||
return {
|
||||
"uri": settings.NEO4J_URI,
|
||||
"present": sorted(existing),
|
||||
"missing_fulltext": missing_ft,
|
||||
"missing_vector": missing_vec,
|
||||
}
|
||||
finally:
|
||||
if own_connector:
|
||||
await connector.close()
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
# 独立脚本入口
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
|
||||
async def _main():
|
||||
import sys
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Neo4j 索引管理工具")
|
||||
print(f"环境: {settings.NEO4J_URI}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# 先检查
|
||||
print(">>> 检查当前索引状态...\n")
|
||||
status = await check_indexes()
|
||||
print(f" 已存在索引数: {len(status['present'])}")
|
||||
if status["missing_fulltext"]:
|
||||
print(f" 缺失全文索引: {status['missing_fulltext']}")
|
||||
if status["missing_vector"]:
|
||||
print(f" 缺失向量索引: {status['missing_vector']}")
|
||||
|
||||
if not status["missing_fulltext"] and not status["missing_vector"]:
|
||||
print("\n 所有索引均已存在,无需操作。")
|
||||
return
|
||||
|
||||
# 再创建
|
||||
print("\n>>> 开始创建缺失索引...\n")
|
||||
report = await ensure_indexes()
|
||||
|
||||
for msg in report["fulltext"] + report["vector"]:
|
||||
print(f" {msg}")
|
||||
|
||||
if report["errors"]:
|
||||
print("\n[!] 以下索引创建失败:")
|
||||
for err in report["errors"]:
|
||||
print(f" {err}")
|
||||
sys.exit(1)
|
||||
else:
|
||||
print("\n 全部索引处理完成。")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(_main())
|
||||
Reference in New Issue
Block a user