refactor(memory): redesign metadata extraction as async pipeline step

- Replace extract_user_metadata_task with entity-level extract_metadata_batch_task
- Add MetadataExtractionStep following ExtractionStep pattern with Jinja2 prompts
- Flatten MetadataExtractionResponse to 9-field schema (aliases, core_facts, etc.)
- Add Cypher queries for incremental metadata writeback and alias edge redirection
- Wire _extract_metadata into WritePipeline as Step 3.6 (fire-and-forget)
- Add pilot_write() to MemoryService; refactor pilot_run_service to use it
- Extract snapshot logic into WriteSnapshotRecorder
This commit is contained in:
lanceyq
2026-04-29 18:16:24 +08:00
parent 4af9b02815
commit d66d601e41
23 changed files with 1437 additions and 819 deletions

View File

@@ -1564,9 +1564,201 @@ def extract_emotion_batch_task(
_shutdown_loop_gracefully(loop)
@celery_app.task(
bind=True,
name="app.tasks.extract_metadata_batch",
max_retries=2,
default_retry_delay=30,
)
def extract_metadata_batch_task(
self,
user_entities: List[Dict[str, Any]],
llm_model_id: str,
language: str = "zh",
snapshot_dir: Optional[str] = None,
) -> Dict[str, Any]:
"""Celery task: 用户实体元数据提取 + Neo4j 回写。
在主写入流水线完成后异步执行。从用户实体的 description 中提取
结构化元数据core_facts、traits、relations 等),增量回写到 Neo4j。
Args:
user_entities: 用户实体列表,每项包含:
- entity_id: 实体 ID
- entity_name: 实体名称
- descriptions: description 文本列表
llm_model_id: LLM 模型 UUID 字符串
language: 语言 ("zh" / "en")
snapshot_dir: 可选的快照目录路径(调试模式下使用)
"""
task_id = self.request.id
total = len(user_entities)
logger.info(
f"[Metadata] 开始用户元数据提取: "
f"entities={total}, llm_model_id={llm_model_id}, "
f"language={language}, task_id={task_id}"
)
start_time = time.time()
if not user_entities:
return {"status": "SUCCESS", "total": 0, "extracted": 0, "failed": 0, "task_id": task_id}
async def _run() -> Dict[str, Any]:
from app.core.memory.models.variate_config import ExtractionPipelineConfig
from app.core.memory.storage_services.extraction_engine.steps.base import StepContext
from app.core.memory.storage_services.extraction_engine.steps.metadata_step import MetadataExtractionStep
from app.core.memory.storage_services.extraction_engine.steps.schema import (
MetadataStepInput,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.repositories.neo4j.cypher_queries import ENTITY_METADATA_UPDATE, ENTITY_METADATA_QUERY
# Build LLM client
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(llm_model_id)
pipeline_config = ExtractionPipelineConfig()
context = StepContext(
llm_client=llm_client,
language=language,
config=pipeline_config,
)
step = MetadataExtractionStep(context)
extracted = 0
failed = 0
snapshot_outputs: Dict[str, Any] = {} if snapshot_dir else None # type: ignore[assignment]
connector = Neo4jConnector()
try:
for entity_dict in user_entities:
entity_id = entity_dict["entity_id"]
entity_name = entity_dict.get("entity_name", "")
descriptions = entity_dict.get("descriptions", [])
if not descriptions:
logger.debug(f"[Metadata] 跳过无 description 的实体: {entity_id}")
continue
try:
# 查询已有元数据用于增量去重
existing_metadata = {}
try:
records = await connector.execute_query(
ENTITY_METADATA_QUERY, entity_id=entity_id
)
if records:
rec = records[0]
for field in (
"core_facts", "traits", "relations", "goals",
"interests", "beliefs_or_stances", "anchors", "events",
):
val = rec.get(field)
existing_metadata[field] = val if val else []
except Exception as e:
logger.warning(f"[Metadata] 查询已有元数据失败: {e}")
inp = MetadataStepInput(
entity_id=entity_id,
entity_name=entity_name,
descriptions=descriptions,
existing_metadata=existing_metadata,
)
result = await step.run(inp)
if result.has_any():
# 回写 Neo4j
await connector.execute_query(
ENTITY_METADATA_UPDATE,
entity_id=entity_id,
core_facts=result.core_facts,
traits=result.traits,
relations=result.relations,
goals=result.goals,
interests=result.interests,
beliefs_or_stances=result.beliefs_or_stances,
anchors=result.anchors,
events=result.events,
)
extracted += 1
logger.info(
f"[Metadata] 实体 {entity_name}({entity_id}) 元数据提取并回写成功"
)
else:
logger.debug(
f"[Metadata] 实体 {entity_name}({entity_id}) 无新增元数据"
)
if snapshot_outputs is not None:
snapshot_outputs[entity_id] = {
"entity_name": entity_name,
"descriptions": descriptions,
"extracted_metadata": result.model_dump(),
}
except Exception as e:
failed += 1
if snapshot_outputs is not None:
snapshot_outputs[entity_id] = {"error": str(e)}
logger.warning(
f"[Metadata] 实体 {entity_id} 元数据提取失败: {e}"
)
finally:
await connector.close()
# 快照落盘
if snapshot_outputs is not None and snapshot_dir:
try:
from pathlib import Path as _Path
import json as _json
_dir = _Path(snapshot_dir)
_dir.mkdir(parents=True, exist_ok=True)
_path = _dir / "8_metadata_outputs.json"
with open(_path, "w", encoding="utf-8") as _f:
_json.dump(snapshot_outputs, _f, ensure_ascii=False, indent=2, default=str)
logger.info(
f"[Metadata][Snapshot] 已落盘 {len(snapshot_outputs)} 条元数据结果 → {_path}"
)
except Exception as _e:
logger.warning(
f"[Metadata][Snapshot] 快照落盘失败(不影响主流程): {_e}"
)
return {"extracted": extracted, "failed": failed}
loop = None
try:
loop = set_asyncio_event_loop()
result = loop.run_until_complete(_run())
elapsed = time.time() - start_time
logger.info(
f"[Metadata] 任务完成: 提取={result['extracted']}, "
f"失败={result['failed']}, 耗时={elapsed:.2f}s, task_id={task_id}"
)
return {
"status": "SUCCESS",
"total": total,
**result,
"elapsed_time": elapsed,
"task_id": task_id,
}
except Exception as e:
elapsed = time.time() - start_time
logger.error(
f"[Metadata] 任务失败: {e}, 耗时={elapsed:.2f}s",
exc_info=True,
)
raise self.retry(exc=e)
finally:
if loop:
_shutdown_loop_gracefully(loop)
# unused task
# @celery_app.task(name="app.core.memory.agent.health.check_read_service")
# def check_read_service_task() -> Dict[str, str]:
# """Call read_service and write latest status to Redis.
# Returns status data dict that gets written to Redis.
@@ -3222,299 +3414,4 @@ def init_community_clustering_for_users(self, end_user_ids: List[str], workspace
# ─── User Metadata Extraction Task ───────────────────────────────────────────
def _update_timestamps(existing: dict, new: dict, updated_at: dict, now: str, prefix: str = "") -> None:
"""对比新旧元数据,更新变更字段的 _updated_at 时间戳。"""
for key, new_val in new.items():
if key == "_updated_at":
continue
path = f"{prefix}.{key}" if prefix else key
old_val = existing.get(key)
if isinstance(new_val, dict) and isinstance(old_val, dict):
_update_timestamps(old_val, new_val, updated_at, now, prefix=path)
elif old_val != new_val:
updated_at[path] = now
@celery_app.task(
bind=True,
name='app.tasks.extract_user_metadata',
ignore_result=False,
max_retries=0,
acks_late=True,
time_limit=300,
soft_time_limit=240,
)
def extract_user_metadata_task(
self,
end_user_id: str,
statements: List[str],
config_id: Optional[str] = None,
language: str = "zh",
) -> Dict[str, Any]:
"""异步提取用户元数据并写入数据库。
在去重消歧完成后由编排器触发,使用独立 LLM 调用提取元数据。
LLM 配置优先使用 config_id 对应的应用配置,失败时回退到工作空间默认配置。
Args:
end_user_id: 终端用户 ID
statements: 用户相关的 statement 文本列表
config_id: 应用配置 ID可选
language: 语言类型 ("zh" 中文, "en" 英文)
Returns:
包含任务执行结果的字典
"""
start_time = time.time()
logger.info(
f"[CELERY METADATA] Starting metadata extraction - end_user_id={end_user_id}, "
f"statements_count={len(statements)}, config_id={config_id}, language={language}"
)
async def _run() -> Dict[str, Any]:
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.metadata_extractor import MetadataExtractor
from app.repositories.end_user_info_repository import EndUserInfoRepository
from app.repositories.end_user_repository import EndUserRepository
from app.services.memory_config_service import MemoryConfigService
# 1. 获取 LLM 配置(应用配置 → 工作空间配置兜底)并创建 LLM client
with get_db_context() as db:
end_user_uuid = uuid.UUID(end_user_id)
# 获取 workspace_id from end_user
end_user = EndUserRepository(db).get_by_id(end_user_uuid)
if not end_user:
return {"status": "FAILURE", "error": f"End user not found: {end_user_id}"}
workspace_id = end_user.workspace_id
config_service = MemoryConfigService(db)
memory_config = config_service.get_config_with_fallback(
memory_config_id=uuid.UUID(config_id) if config_id else None,
workspace_id=workspace_id,
)
if not memory_config:
return {"status": "FAILURE", "error": "No LLM config available (app + workspace fallback failed)"}
# 2. 创建 LLM client
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
factory = MemoryClientFactory(db)
if not memory_config.llm_id:
return {"status": "FAILURE", "error": "Memory config has no LLM model configured"}
llm_client = factory.get_llm_client(memory_config.llm_id)
# 2.5 读取已有元数据和别名,传给 extractor 作为上下文
existing_metadata = None
existing_aliases = None
try:
info = EndUserInfoRepository(db).get_by_end_user_id(end_user_uuid)
if info:
if info.meta_data:
existing_metadata = info.meta_data
existing_aliases = info.aliases if info.aliases else []
logger.info(f"[CELERY METADATA] 已读取已有元数据和别名aliases={existing_aliases}")
except Exception as e:
logger.warning(f"[CELERY METADATA] 读取已有数据失败(继续无上下文提取): {e}")
# 3. 提取元数据和别名(传入已有数据作为上下文)
extractor = MetadataExtractor(llm_client=llm_client, language=language)
extract_result = await extractor.extract_metadata(
statements,
existing_metadata=existing_metadata,
existing_aliases=existing_aliases,
)
if not extract_result:
logger.info(f"[CELERY METADATA] No metadata extracted for end_user_id={end_user_id}")
return {"status": "SUCCESS", "result": "no_metadata_extracted"}
metadata_changes, aliases_to_add, aliases_to_remove = extract_result
logger.info(
f"[CELERY METADATA] LLM 元数据变更: {[c.model_dump() for c in metadata_changes]}, "
f"别名新增: {aliases_to_add}, 移除: {aliases_to_remove}"
)
from datetime import datetime as dt, timezone as tz
now = dt.now(tz.utc).isoformat()
# 过滤别名中的占位名称,执行增量增删
_PLACEHOLDER_NAMES = {"用户", "", "user", "i"}
def _filter_aliases(aliases_list):
seen = set()
result = []
for a in aliases_list:
a_stripped = a.strip()
if a_stripped and a_stripped.lower() not in _PLACEHOLDER_NAMES and a_stripped.lower() not in seen:
result.append(a_stripped)
seen.add(a_stripped.lower())
return result
filtered_add = _filter_aliases(aliases_to_add)
filtered_remove = _filter_aliases(aliases_to_remove)
remove_lower = {a.lower() for a in filtered_remove}
with get_db_context() as db:
end_user_uuid = uuid.UUID(end_user_id)
info = EndUserInfoRepository(db).get_by_end_user_id(end_user_uuid)
end_user = EndUserRepository(db).get_by_id(end_user_uuid)
if info:
# 4. 元数据增量更新(按 LLM 输出的变更操作逐条执行,所有字段均为列表类型)
if metadata_changes:
# 深拷贝,确保 SQLAlchemy 能检测到变更
import copy
existing_meta = copy.deepcopy(info.meta_data) if info.meta_data else {}
updated_at = dict(existing_meta.get("_updated_at", {}))
for change in metadata_changes:
field_path = change.field_path
action = change.action
value = change.value
if not value or not value.strip():
continue
# 定位到目标字段的父级节点
parts = field_path.split(".")
target = existing_meta
for part in parts[:-1]:
target = target.setdefault(part, {})
leaf = parts[-1]
current_list = target.get(leaf, [])
if action == "set":
if value not in current_list:
# 新值插入列表头部,保证按时间从新到旧排序
current_list.insert(0, value)
target[leaf] = current_list
logger.info(f"[CELERY METADATA] set {field_path} = {value}")
elif action == "remove":
if value in current_list:
current_list.remove(value)
target[leaf] = current_list
logger.info(f"[CELERY METADATA] remove {value} from {field_path}")
updated_at[field_path] = now
existing_meta["_updated_at"] = updated_at
# 赋值深拷贝后的新对象SQLAlchemy 会检测到字段变更并写入
info.meta_data = existing_meta
logger.info(f"[CELERY METADATA] 增量更新元数据完成: {json.dumps(existing_meta, ensure_ascii=False)}")
# 别名增量增删:(已有 - remove) + add
old_aliases = info.aliases if info.aliases else []
# 先移除
merged = [a for a in old_aliases if a.strip().lower() not in remove_lower]
# 再追加(去重)
existing_lower = {a.strip().lower() for a in merged}
for a in filtered_add:
if a.lower() not in existing_lower:
merged.append(a)
existing_lower.add(a.lower())
if merged != old_aliases:
info.aliases = merged
# other_name 更新逻辑
if merged and (
not info.other_name
or info.other_name.strip().lower() in _PLACEHOLDER_NAMES
or info.other_name.strip().lower() in remove_lower
):
info.other_name = merged[0]
if end_user and merged and (
not end_user.other_name
or end_user.other_name.strip().lower() in _PLACEHOLDER_NAMES
or end_user.other_name.strip().lower() in remove_lower
):
end_user.other_name = merged[0]
logger.info(
f"[CELERY METADATA] 别名增量更新: {old_aliases} - {filtered_remove} + {filtered_add}{merged}"
)
else:
# 没有 end_user_info 记录,创建一条
from app.models.end_user_info_model import EndUserInfo
initial_aliases = filtered_add # 新记录只有 add没有 remove
first_alias = initial_aliases[0] if initial_aliases else ""
# 从变更操作构建初始元数据(所有字段均为列表类型)
initial_meta = {}
for change in metadata_changes:
if change.action == "set" and change.value is not None and change.value.strip():
parts = change.field_path.split(".")
target = initial_meta
for part in parts[:-1]:
target = target.setdefault(part, {})
leaf = parts[-1]
current_list = target.get(leaf, [])
if change.value not in current_list:
# 新值插入列表头部,保证按时间从新到旧排序
current_list.insert(0, change.value)
target[leaf] = current_list
if first_alias or initial_meta:
new_info = EndUserInfo(
end_user_id=end_user_uuid,
other_name=first_alias or "",
aliases=initial_aliases,
meta_data=initial_meta if initial_meta else None,
)
db.add(new_info)
if end_user and first_alias and (
not end_user.other_name or end_user.other_name.strip().lower() in _PLACEHOLDER_NAMES
):
end_user.other_name = first_alias
logger.info(f"[CELERY METADATA] 创建 end_user_info: other_name={first_alias}, aliases={initial_aliases}")
else:
return {"status": "SUCCESS", "result": "no_data_to_write"}
db.commit()
# 同步 PgSQL aliases 到 Neo4j 用户实体PgSQL 为权威源)
final_aliases = info.aliases if info else initial_aliases
if final_aliases:
try:
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
neo4j_connector = Neo4jConnector()
cypher = """
MATCH (e:ExtractedEntity)
WHERE e.end_user_id = $end_user_id AND e.name IN ['用户', '', 'User', 'I']
SET e.aliases = $aliases
"""
await neo4j_connector.execute_query(
cypher, end_user_id=end_user_id, aliases=final_aliases
)
await neo4j_connector.close()
logger.info(f"[CELERY METADATA] Neo4j 用户实体 aliases 已同步: {final_aliases}")
except Exception as neo4j_err:
logger.warning(f"[CELERY METADATA] Neo4j aliases 同步失败(不影响主流程): {neo4j_err}")
return {"status": "SUCCESS", "result": "metadata_and_aliases_written"}
loop = None
try:
loop = set_asyncio_event_loop()
result = loop.run_until_complete(_run())
elapsed = time.time() - start_time
result["elapsed_time"] = elapsed
result["task_id"] = self.request.id
logger.info(f"[CELERY METADATA] Task completed - elapsed={elapsed:.2f}s, result={result.get('result')}")
return result
except Exception as e:
elapsed = time.time() - start_time
logger.error(f"[CELERY METADATA] Task failed - elapsed={elapsed:.2f}s, error={e}", exc_info=True)
return {
"status": "FAILURE",
"error": str(e),
"elapsed_time": elapsed,
"task_id": self.request.id,
}
finally:
if loop:
_shutdown_loop_gracefully(loop)
# unused task