Files
MemoryBear/api/app/services/pilot_run_service.py
lanceyq d66d601e41 refactor(memory): redesign metadata extraction as async pipeline step
- Replace extract_user_metadata_task with entity-level extract_metadata_batch_task
- Add MetadataExtractionStep following ExtractionStep pattern with Jinja2 prompts
- Flatten MetadataExtractionResponse to 9-field schema (aliases, core_facts, etc.)
- Add Cypher queries for incremental metadata writeback and alias edge redirection
- Wire _extract_metadata into WritePipeline as Step 3.6 (fire-and-forget)
- Add pilot_write() to MemoryService; refactor pilot_run_service to use it
- Extract snapshot logic into WriteSnapshotRecorder
2026-05-08 11:27:51 +08:00

416 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Pilot Run Service - 试运行服务
用于执行记忆系统的试运行流程,不保存到 Neo4j。
"""
import os
import re
import time
from datetime import datetime
from typing import Awaitable, Callable, Optional
from app.core.config import settings
from app.core.logging_config import get_memory_logger, log_time
from app.core.memory.models.message_models import (
ConversationContext,
ConversationMessage,
DialogData,
)
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
ExtractionOrchestrator,
get_chunked_dialogs_from_preprocessed,
)
from app.core.memory.storage_services.extraction_engine.pipeline_help import (
_write_extracted_result_summary,
export_test_input_doc,
)
from app.core.memory.utils.config.config_utils import get_pipeline_config
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import MemoryConfig
from sqlalchemy.orm import Session
# Module-level logger, obtained from the project's memory logging helper.
logger = get_memory_logger(__name__)
def _save_triplets_from_dialogs(dialog_data_list: list[DialogData], output_path: str) -> None:
    """Write a plain-text triplet/entity report for the given dialogs.

    Walks dialog -> chunks -> statements, gathers the triplets and entities
    attached to each statement's ``triplet_extraction_info`` and writes them
    to ``output_path`` as UTF-8, overwriting any existing file.  The line
    layout is kept exactly as before because, per the original note, the
    report is meant to stay compatible with the ``pipeline_help`` parsers.

    Args:
        dialog_data_list: Dialogs to scan.  ``None`` is treated like an empty
            list, matching how missing/``None`` ``chunks``, ``statements``,
            ``triplets`` and ``entities`` are already tolerated.
        output_path: Destination file path.
    """
    triplets = []
    entities = []
    for dialog in dialog_data_list or []:
        for chunk in getattr(dialog, "chunks", []) or []:
            for statement in getattr(chunk, "statements", []) or []:
                info = getattr(statement, "triplet_extraction_info", None)
                if not info:
                    # Statement without extraction output: nothing to report.
                    continue
                triplets.extend(getattr(info, "triplets", []) or [])
                entities.extend(getattr(info, "entities", []) or [])

    def report_lines():
        """Yield the report text piece by piece, in output order."""
        yield f"=== EXTRACTED TRIPLETS ({len(triplets)} total) ===\n\n"
        for number, triplet in enumerate(triplets, 1):
            yield f"Triplet {number}:\n"
            yield f" Subject: {triplet.subject_name} (ID: {triplet.subject_id})\n"
            yield f" Predicate: {triplet.predicate}\n"
            yield f" Object: {triplet.object_name} (ID: {triplet.object_id})\n"
            # "value" is optional; the line is emitted only for truthy values.
            value = getattr(triplet, "value", None)
            if value:
                yield f" Value: {value}\n"
            yield "\n"
        yield f"\n=== EXTRACTED ENTITIES ({len(entities)} total) ===\n\n"
        for number, entity in enumerate(entities, 1):
            yield f"Entity {number}:\n"
            yield f" ID: {entity.entity_idx}\n"
            yield f" Name: {entity.name}\n"
            yield f" Type: {entity.type}\n"
            yield f" Description: {entity.description}\n"
            yield "\n"

    with open(output_path, "w", encoding="utf-8") as f:
        # One batched call instead of many small writes.
        f.writelines(report_lines())
# One speaker turn: a role tag ("用户" or "AI"), an ASCII or fullwidth colon,
# the rest of that line, then every following line that does not itself open
# (after optional indentation) with a role tag.  The continuation group is
# greedy on purpose: a lazy quantifier at the very end of a pattern never
# expands, which silently cut every turn down to its first line.
_PILOT_DIALOGUE_TURN_RE = re.compile(
    r"(用户|AI)[:\uff1a]\s*([^\n]+(?:\n(?![ \t]*(?:用户|AI)[:\uff1a])[^\n]*)*)"
)


def _pilot_split_dialogue_turns(dialogue_text: str) -> list[tuple[str, str]]:
    """Split raw dialogue text into ``(role, content)`` turns, dropping empty ones."""
    turns = []
    for role, content in _PILOT_DIALOGUE_TURN_RE.findall(dialogue_text):
        content = content.strip()
        if content:
            turns.append((role, content))
    return turns


def _pilot_find_deleted_messages(
    original_messages: list[dict],
    remaining_messages: list[dict],
) -> list[dict]:
    """Return the original messages (with their index) that pruning removed.

    Both lists hold ``{"role", "content"}`` dicts.  The remaining messages are
    matched against the originals in order with a single forward cursor; every
    original that is not consumed by that cursor is reported as deleted.
    """
    deleted = []
    cursor = 0
    for idx, msg in enumerate(original_messages):
        if (
            cursor < len(remaining_messages)
            and msg["role"] == remaining_messages[cursor]["role"]
            and msg["content"] == remaining_messages[cursor]["content"]
        ):
            cursor += 1
        else:
            deleted.append({"index": idx, "role": msg["role"], "content": msg["content"]})
    return deleted


async def run_pilot_extraction(
    memory_config: MemoryConfig,
    dialogue_text: str,
    db: Session,
    progress_callback: Optional[Callable[[str, str, Optional[dict]], Awaitable[None]]] = None,
    language: str = "zh",
) -> None:
    """Run the pilot-run (trial) knowledge-extraction pipeline.

    Steps: initialise clients, parse ``dialogue_text`` into messages,
    optionally prune them semantically, chunk them, run either the refactored
    or the legacy extraction pipeline (chosen by
    ``settings.PILOT_RUN_USE_REFACTORED_PIPELINE``) and write the result files
    into the memory output directory.  Step timings are appended to
    ``logs/time.log``.

    Args:
        memory_config: Memory configuration object (model ids, pruning and
            chunking settings, workspace/tenant/config ids).
        dialogue_text: Input dialogue text.  Turns written as ``用户: ...`` /
            ``AI: ...`` (ASCII or fullwidth colon, possibly spanning several
            lines) are split per speaker; otherwise the whole text becomes a
            single "用户" message.
        db: Database session handed to the client factory and ontology loader.
        progress_callback: Optional async callback invoked as
            ``(stage, message[, data])``; note that some calls pass only the
            first two arguments.
        language: Language code forwarded to the extraction pipeline
            ("zh" or "en"); defaults to "zh".

    Raises:
        Exception: Any failure outside the best-effort sections (pruning,
            ontology loading, connector close) is logged and re-raised.
    """
    log_file = "logs/time.log"
    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(f"\n=== Pilot Run Started: {timestamp} ===\n")

    pipeline_start = time.time()

    try:
        # Step 1: initialise the LLM and embedder clients.
        logger.info("Initializing clients...")
        step_start = time.time()
        client_factory = MemoryClientFactory(db)
        llm_client = client_factory.get_llm_client(str(memory_config.llm_model_id))
        embedder_client = client_factory.get_embedder_client(str(memory_config.embedding_model_id))
        log_time("Client Initialization", time.time() - step_start, log_file)

        # Step 2: parse the dialogue text into messages.
        logger.info("Parsing dialogue text...")
        step_start = time.time()
        messages = [
            ConversationMessage(role=role, msg=content)
            for role, content in _pilot_split_dialogue_turns(dialogue_text)
        ]
        # No formatted turns found: treat the whole text as one user message.
        if not messages:
            messages = [ConversationMessage(role="用户", msg=dialogue_text.strip())]

        context = ConversationContext(msgs=messages)
        dialog = DialogData(
            context=context,
            ref_id="pilot_dialog_1",
            end_user_id=str(memory_config.workspace_id),
            user_id=str(memory_config.tenant_id),
            apply_id=str(memory_config.config_id),
            metadata={"source": "pilot_run", "input_type": "frontend_text"},
        )

        if progress_callback:
            await progress_callback("text_preprocessing", "开始预处理文本(语义剪枝 + 语义分块)...")

        # ========== Step 2.1: semantic pruning ==========
        pruned_dialogs = [dialog]
        deleted_messages = []  # messages removed by pruning
        pruning_stats = None  # pruning statistics for the final summary

        if memory_config.pruning_enabled:
            try:
                from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import (
                    SemanticPruner,
                )
                from app.core.memory.models.config_models import PruningConfig

                # Build the pruning configuration.
                pruning_config_dict = {
                    "pruning_switch": memory_config.pruning_enabled,
                    "pruning_scene": memory_config.pruning_scene,
                    "pruning_threshold": memory_config.pruning_threshold,
                    "scene_id": str(memory_config.scene_id) if memory_config.scene_id else None,
                    "ontology_class_infos": memory_config.ontology_class_infos,
                }
                config = PruningConfig(**pruning_config_dict)
                logger.info(f"[PILOT_RUN] 开始语义剪枝: scene={config.pruning_scene}, threshold={config.pruning_threshold}")

                # Snapshot the messages before pruning so the diff can be reported.
                original_messages = [{"role": msg.role, "content": msg.msg} for msg in dialog.context.msgs]
                original_msg_count = len(original_messages)

                # Run the pruner.
                pruner = SemanticPruner(config=config, llm_client=llm_client)
                pruned_dialogs = await pruner.prune_dataset([dialog])

                if pruned_dialogs and pruned_dialogs[0].context:
                    remaining_messages = [
                        {"role": msg.role, "content": msg.msg}
                        for msg in pruned_dialogs[0].context.msgs
                    ]
                    remaining_msg_count = len(remaining_messages)
                    deleted_msg_count = original_msg_count - remaining_msg_count

                    # Work out which original messages were dropped.
                    deleted_messages = _pilot_find_deleted_messages(original_messages, remaining_messages)

                    # Statistics kept for the final preprocessing summary.
                    pruning_stats = {
                        "enabled": True,
                        "scene": config.pruning_scene,
                        "threshold": config.pruning_threshold,
                        "deleted_count": deleted_msg_count,
                    }

                    # Per-step result carrying the deleted message details.
                    pruning_result = {
                        "type": "pruning",
                        "deleted_messages": deleted_messages,
                    }
                    logger.info(
                        f"[PILOT_RUN] 语义剪枝完成: 原始{original_msg_count}条 -> "
                        f"保留{remaining_msg_count}条 (删除{deleted_msg_count}条)"
                    )
                    if progress_callback:
                        await progress_callback("text_preprocessing_result", "语义剪枝完成", pruning_result)
                else:
                    # Pruning returned nothing usable: fall back to the original dialog.
                    logger.warning("[PILOT_RUN] 剪枝后对话为空,使用原始对话")
                    pruned_dialogs = [dialog]
            except Exception as e:
                # Pruning is best-effort: report the error and continue unpruned.
                logger.error(f"[PILOT_RUN] 语义剪枝失败,使用原始对话: {e}", exc_info=True)
                pruned_dialogs = [dialog]
                if progress_callback:
                    error_result = {
                        "type": "pruning",
                        "error": str(e),
                        "fallback": "使用原始对话"
                    }
                    await progress_callback("text_preprocessing_result", "语义剪枝失败", error_result)
        else:
            logger.info("[PILOT_RUN] 语义剪枝已关闭,跳过")
            pruning_stats = {
                "enabled": False,
            }

        # ========== Step 2.2: semantic chunking ==========
        chunked_dialogs = await get_chunked_dialogs_from_preprocessed(
            data=pruned_dialogs,
            chunker_strategy=memory_config.chunker_strategy,
            llm_client=llm_client,
        )
        remaining_msg_count = len(pruned_dialogs[0].context.msgs) if pruned_dialogs and pruned_dialogs[0].context else 0
        logger.info(f"Processed dialogue text: {remaining_msg_count} messages after pruning")

        # Progress callback: report every chunk, then the preprocessing summary.
        if progress_callback:
            for dlg in chunked_dialogs:
                if hasattr(dlg, 'chunks') and dlg.chunks:
                    for i, chunk in enumerate(dlg.chunks):
                        chunk_result = {
                            "type": "chunking",
                            "chunk_index": i + 1,
                            # Preview is capped at 200 characters.
                            "content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
                            "full_length": len(chunk.content),
                            "dialog_id": dlg.id,
                            "chunker_strategy": memory_config.chunker_strategy,
                        }
                        await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)

            # Preprocessing summary (includes pruning statistics).
            preprocessing_summary = {
                "total_chunks": sum(len(dlg.chunks) for dlg in chunked_dialogs if hasattr(dlg, 'chunks') and dlg.chunks),
                "total_dialogs": len(chunked_dialogs),
                "chunker_strategy": memory_config.chunker_strategy,
            }
            # Always include a "pruning" field so the frontend never sees it missing.
            preprocessing_summary["pruning"] = pruning_stats if pruning_stats else {
                "enabled": memory_config.pruning_enabled,
                "deleted_count": 0,
            }
            await progress_callback("text_preprocessing_complete", "预处理文本完成(剪枝 + 分块)", preprocessing_summary)

        log_time("Data Loading & Chunking", time.time() - step_start, log_file)

        # Step 3: choose and initialise the pilot pipeline (switchable via settings).
        use_refactored = bool(settings.PILOT_RUN_USE_REFACTORED_PIPELINE)
        logger.info(
            "Selecting pilot pipeline by env: PILOT_RUN_USE_REFACTORED_PIPELINE=%s",
            use_refactored,
        )
        logger.info(
            "Initializing %s pilot pipeline...",
            "refactored" if use_refactored else "legacy",
        )
        step_start = time.time()

        # Load ontology types (with general-type fallback); failure is non-fatal.
        ontology_types = None
        try:
            from app.core.memory.ontology_services.ontology_type_loader import load_ontology_types_with_fallback

            ontology_types = load_ontology_types_with_fallback(
                scene_id=memory_config.scene_id,
                workspace_id=memory_config.workspace_id,
                db=db,
                enable_general_fallback=True
            )
        except Exception as e:
            logger.warning(f"Failed to load ontology types: {e}", exc_info=True)

        if use_refactored:
            from app.core.memory.memory_service import MemoryService

            memory_service = MemoryService(
                memory_config=memory_config,
                end_user_id=str(memory_config.workspace_id),
            )
            log_time("Pilot Pipeline Initialization", time.time() - step_start, log_file)

            # Step 4a: refactored short pipeline
            # statement -> triplet -> graph_build -> layer-1 dedup/disambiguation (end)
            logger.info("Running refactored pilot extraction short pipeline...")
            step_start = time.time()
            if progress_callback:
                await progress_callback("knowledge_extraction", "正在知识抽取...")
            pilot_result = await memory_service.pilot_write(
                chunked_dialogs=chunked_dialogs,
                language=language,
                progress_callback=progress_callback,
            )
            dialog_data_list = pilot_result.dialog_data_list
            graph = pilot_result.graph
            chunk_nodes = graph.chunk_nodes
            export_entity_nodes = graph.entity_nodes
            export_stmt_entity_edges = graph.stmt_entity_edges
            export_entity_edges = graph.entity_entity_edges
        else:
            # Step 4b: legacy pilot pipeline.
            logger.info("Running legacy pilot extraction pipeline...")
            step_start = time.time()
            if progress_callback:
                await progress_callback("knowledge_extraction", "正在知识抽取...")
            neo4j_connector = Neo4jConnector()
            try:
                legacy_orchestrator = ExtractionOrchestrator(
                    llm_client=llm_client,
                    embedder_client=embedder_client,
                    connector=neo4j_connector,
                    config=get_pipeline_config(memory_config),
                    progress_callback=progress_callback,
                    embedding_id=str(memory_config.embedding_model_id),
                    language=language,
                    ontology_types=ontology_types,
                )
                extraction_result = await legacy_orchestrator.run(
                    dialog_data_list=chunked_dialogs,
                    is_pilot_run=True,
                )
                # Only the chunk/entity nodes and entity edges are exported below.
                (
                    _dialogue_nodes,
                    chunk_nodes,
                    _statement_nodes,
                    entity_nodes,
                    _perceptual_nodes,
                    _statement_chunk_edges,
                    statement_entity_edges,
                    entity_edges,
                    _perceptual_edges,
                    _last_created_at,
                ) = extraction_result
                dialog_data_list = chunked_dialogs
                export_entity_nodes = entity_nodes
                export_stmt_entity_edges = statement_entity_edges
                export_entity_edges = entity_edges
            finally:
                try:
                    await neo4j_connector.close()
                except Exception:
                    # Closing stays best-effort, but the failure is no longer hidden.
                    logger.warning("Failed to close Neo4j connector after pilot run", exc_info=True)

        log_time("Extraction Pipeline", time.time() - step_start, log_file)

        if progress_callback:
            await progress_callback("generating_results", "正在生成结果...")

        # Step 5: write the pilot-run result files (keeps the /pilot_run return contract).
        settings.ensure_memory_output_dir()
        export_test_input_doc(
            entity_nodes=export_entity_nodes,
            statement_entity_edges=export_stmt_entity_edges,
            entity_entity_edges=export_entity_edges,
        )
        _save_triplets_from_dialogs(
            dialog_data_list=dialog_data_list,
            output_path=settings.get_memory_output_path("extracted_triplets.txt"),
        )
        _write_extracted_result_summary(
            chunk_nodes=chunk_nodes,
            pipeline_output_dir=settings.get_memory_output_path(),
        )
        logger.info("Pilot run completed: stop after layer-1 dedup (no layer-2 / no Neo4j write)")
    except Exception as e:
        logger.error(f"Pilot run failed: {e}", exc_info=True)
        raise

    # Reached only on success: record the total time and the completion marker.
    total_time = time.time() - pipeline_start
    log_time("TOTAL PILOT RUN TIME", total_time, log_file)
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(f"=== Pilot Run Completed: {timestamp} ===\n\n")
    logger.info(f"Pilot run complete. Total time: {total_time:.2f}s")