- Replace extract_user_metadata_task with entity-level extract_metadata_batch_task
- Add MetadataExtractionStep following ExtractionStep pattern with Jinja2 prompts
- Flatten MetadataExtractionResponse to 9-field schema (aliases, core_facts, etc.)
- Add Cypher queries for incremental metadata writeback and alias edge redirection
- Wire _extract_metadata into WritePipeline as Step 3.6 (fire-and-forget)
- Add pilot_write() to MemoryService; refactor pilot_run_service to use it
- Extract snapshot logic into WriteSnapshotRecorder
416 lines
18 KiB
Python
416 lines
18 KiB
Python
"""
|
||
Pilot Run Service - 试运行服务
|
||
|
||
用于执行记忆系统的试运行流程,不保存到 Neo4j。
|
||
"""
|
||
|
||
import os
|
||
import re
|
||
import time
|
||
from datetime import datetime
|
||
from typing import Awaitable, Callable, Optional
|
||
|
||
from app.core.config import settings
|
||
from app.core.logging_config import get_memory_logger, log_time
|
||
from app.core.memory.models.message_models import (
|
||
ConversationContext,
|
||
ConversationMessage,
|
||
DialogData,
|
||
)
|
||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
|
||
ExtractionOrchestrator,
|
||
get_chunked_dialogs_from_preprocessed,
|
||
)
|
||
from app.core.memory.storage_services.extraction_engine.pipeline_help import (
|
||
_write_extracted_result_summary,
|
||
export_test_input_doc,
|
||
)
|
||
from app.core.memory.utils.config.config_utils import get_pipeline_config
|
||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||
from app.schemas.memory_config_schema import MemoryConfig
|
||
from sqlalchemy.orm import Session
|
||
|
||
logger = get_memory_logger(__name__)
|
||
|
||
|
||
def _save_triplets_from_dialogs(dialog_data_list: list[DialogData], output_path: str) -> None:
    """Dump all extracted triplets and entities to a plain-text report file.

    Walks every dialog -> chunk -> statement, gathers the ``triplets`` and
    ``entities`` held by each statement's ``triplet_extraction_info`` and
    writes them to ``output_path`` (UTF-8, overwriting any existing file) as
    two numbered sections. Missing or empty ``chunks`` / ``statements`` /
    ``triplet_extraction_info`` attributes are skipped silently.

    NOTE(review): the section headers and field labels are presumably consumed
    by parsers in ``pipeline_help`` (the original docstring says the report is
    "compatible with pipeline_help parsers"), so the text layout should not be
    changed without checking those parsers.

    Args:
        dialog_data_list: Dialogs whose chunks carry extracted statements.
        output_path: Destination path of the text report.
    """
    collected_triplets = []
    collected_entities = []

    # Flatten the dialog -> chunk -> statement hierarchy; every level may be
    # absent or None, hence the "or []" fallbacks.
    for dlg in dialog_data_list:
        dlg_chunks = getattr(dlg, "chunks", None) or []
        for chk in dlg_chunks:
            chk_statements = getattr(chk, "statements", None) or []
            for stmt in chk_statements:
                info = getattr(stmt, "triplet_extraction_info", None)
                if info:
                    collected_triplets += getattr(info, "triplets", None) or []
                    collected_entities += getattr(info, "entities", None) or []

    with open(output_path, "w", encoding="utf-8") as report:
        emit = report.write

        # Section 1: triplets, numbered from 1.
        emit(f"=== EXTRACTED TRIPLETS ({len(collected_triplets)} total) ===\n\n")
        for number, trip in enumerate(collected_triplets, start=1):
            emit(f"Triplet {number}:\n")
            emit(f" Subject: {trip.subject_name} (ID: {trip.subject_id})\n")
            emit(f" Predicate: {trip.predicate}\n")
            emit(f" Object: {trip.object_name} (ID: {trip.object_id})\n")
            # The value line is optional and omitted when absent or falsy.
            optional_value = getattr(trip, "value", None)
            if optional_value:
                emit(f" Value: {optional_value}\n")
            emit("\n")

        # Section 2: entities, numbered from 1.
        emit(f"\n=== EXTRACTED ENTITIES ({len(collected_entities)} total) ===\n\n")
        for number, ent in enumerate(collected_entities, start=1):
            emit(f"Entity {number}:\n")
            emit(f" ID: {ent.entity_idx}\n")
            emit(f" Name: {ent.name}\n")
            emit(f" Type: {ent.type}\n")
            emit(f" Description: {ent.description}\n")
            emit("\n")
|
||
|
||
|
||
# One dialogue turn: a role label ("用户" or "AI"), a colon (ASCII ":" or
# full-width ":"), then the message text. The message runs over following
# lines until a line that itself starts (optionally after horizontal
# whitespace) with a role label. The continuation group must be greedy: it is
# the last element of the pattern, so a lazy quantifier would always match
# zero lines and truncate every multi-line message to its first line.
_DIALOGUE_TURN_PATTERN = re.compile(
    r"(用户|AI)[::]\s*([^\n]+(?:\n(?![^\S\n]*(?:用户|AI)[::])[^\n]*)*)"
)


def _parse_dialogue_turns(dialogue_text: str) -> list[tuple[str, str]]:
    """Split raw dialogue text into ``(role, content)`` turns.

    Turns are recognised by a leading "用户" / "AI" label followed by a colon.
    Content is stripped, and turns whose content is empty after stripping are
    dropped. When no labelled turn is found, the whole stripped text is
    returned as a single "用户" turn.

    Args:
        dialogue_text: Raw text, e.g. ``"用户: hi\\nAI: hello"``.

    Returns:
        A non-empty list of ``(role, content)`` tuples in input order.
    """
    turns = [
        (role, content.strip())
        for role, content in _DIALOGUE_TURN_PATTERN.findall(dialogue_text)
        if content.strip()
    ]
    if not turns:
        # Unformatted input: treat everything as one user message.
        turns = [("用户", dialogue_text.strip())]
    return turns


def _find_deleted_messages(original_messages: list[dict], remaining_messages: list[dict]) -> list[dict]:
    """Return the original messages that are missing from the pruned list.

    Both lists hold ``{"role": ..., "content": ...}`` dicts. The remaining
    messages are aligned against the originals in order: an original message
    counts as kept when it equals the next not-yet-matched remaining message.
    Every original message that is not kept is reported.

    Args:
        original_messages: Messages before pruning, in order.
        remaining_messages: Messages after pruning, in order.

    Returns:
        A list of ``{"index", "role", "content"}`` dicts, where ``index`` is
        the position of the deleted message in ``original_messages``.
    """
    kept_indices = set()  # set, so the membership test below stays O(1)
    cursor = 0
    for orig_idx, orig_msg in enumerate(original_messages):
        if cursor >= len(remaining_messages):
            break
        candidate = remaining_messages[cursor]
        if orig_msg["role"] == candidate["role"] and orig_msg["content"] == candidate["content"]:
            kept_indices.add(orig_idx)
            cursor += 1

    return [
        {"index": idx, "role": msg["role"], "content": msg["content"]}
        for idx, msg in enumerate(original_messages)
        if idx not in kept_indices
    ]


async def run_pilot_extraction(
    memory_config: MemoryConfig,
    dialogue_text: str,
    db: Session,
    progress_callback: Optional[Callable[[str, str, Optional[dict]], Awaitable[None]]] = None,
    language: str = "zh",
) -> None:
    """
    Run the knowledge-extraction pipeline in pilot (trial-run) mode.

    Steps: initialise the LLM / embedder clients, parse ``dialogue_text`` into
    messages, optionally prune them semantically, chunk them, run either the
    refactored or the legacy extraction pipeline (chosen by
    ``settings.PILOT_RUN_USE_REFACTORED_PIPELINE``), then write the result
    files via ``export_test_input_doc``, ``_save_triplets_from_dialogs`` and
    ``_write_extracted_result_summary``. Step timings and start/completion
    markers are appended to ``logs/time.log``.

    A pruning failure is logged and the original dialog is used instead; a
    failure to load ontology types is logged and ``None`` is used.

    Args:
        memory_config: Memory configuration object loaded from the database.
        dialogue_text: Input dialogue text.
        db: Database session.
        progress_callback: Optional async progress callback, invoked with
            - arg 1 (stage): identifier of the current processing stage
            - arg 2 (message): human-readable progress message
            - arg 3 (data): optional dict with additional data (omitted by
              some calls, so the callback needs a default for it)
        language: Language type ("zh" Chinese, "en" English); defaults to "zh".

    Raises:
        Exception: Any error outside the tolerated cases above is logged and
            re-raised unchanged; the completion marker is not written then.
    """
    log_file = "logs/time.log"
    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(f"\n=== Pilot Run Started: {timestamp} ===\n")

    pipeline_start = time.time()

    try:
        # Step 1: initialise clients
        logger.info("Initializing clients...")
        step_start = time.time()

        client_factory = MemoryClientFactory(db)
        llm_client = client_factory.get_llm_client(str(memory_config.llm_model_id))
        embedder_client = client_factory.get_embedder_client(str(memory_config.embedding_model_id))

        log_time("Client Initialization", time.time() - step_start, log_file)

        # Step 2: parse the dialogue text ("用户:" / "AI:" labelled turns; the
        # whole text becomes one user message when no label is found)
        logger.info("Parsing dialogue text...")
        step_start = time.time()

        messages = [
            ConversationMessage(role=role, msg=content)
            for role, content in _parse_dialogue_turns(dialogue_text)
        ]

        context = ConversationContext(msgs=messages)
        dialog = DialogData(
            context=context,
            ref_id="pilot_dialog_1",
            end_user_id=str(memory_config.workspace_id),
            user_id=str(memory_config.tenant_id),
            apply_id=str(memory_config.config_id),
            metadata={"source": "pilot_run", "input_type": "frontend_text"},
        )

        if progress_callback:
            await progress_callback("text_preprocessing", "开始预处理文本(语义剪枝 + 语义分块)...")

        # ========== Step 2.1: semantic pruning ==========
        pruned_dialogs = [dialog]
        pruning_stats = None  # pruning statistics, reused in the final summary

        if memory_config.pruning_enabled:
            try:
                from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import (
                    SemanticPruner,
                )
                from app.core.memory.models.config_models import PruningConfig

                # Build the pruning configuration
                pruning_config_dict = {
                    "pruning_switch": memory_config.pruning_enabled,
                    "pruning_scene": memory_config.pruning_scene,
                    "pruning_threshold": memory_config.pruning_threshold,
                    "scene_id": str(memory_config.scene_id) if memory_config.scene_id else None,
                    "ontology_class_infos": memory_config.ontology_class_infos,
                }
                config = PruningConfig(**pruning_config_dict)

                logger.info(f"[PILOT_RUN] 开始语义剪枝: scene={config.pruning_scene}, threshold={config.pruning_threshold}")

                # Snapshot the messages before pruning (for the comparison below)
                original_messages = [{"role": msg.role, "content": msg.msg} for msg in dialog.context.msgs]
                original_msg_count = len(original_messages)

                # Run the pruner
                pruner = SemanticPruner(config=config, llm_client=llm_client)
                pruned_dialogs = await pruner.prune_dataset([dialog])

                # Work out what was removed
                if pruned_dialogs and pruned_dialogs[0].context:
                    remaining_messages = [{"role": msg.role, "content": msg.msg} for msg in pruned_dialogs[0].context.msgs]
                    remaining_msg_count = len(remaining_messages)
                    deleted_msg_count = original_msg_count - remaining_msg_count

                    deleted_messages = _find_deleted_messages(original_messages, remaining_messages)

                    # Statistics kept for the final summary (only deleted_count)
                    pruning_stats = {
                        "enabled": True,
                        "scene": config.pruning_scene,
                        "threshold": config.pruning_threshold,
                        "deleted_count": deleted_msg_count,
                    }

                    # Pruning result reported to the caller (deleted message details)
                    pruning_result = {
                        "type": "pruning",
                        "deleted_messages": deleted_messages,
                    }

                    logger.info(
                        f"[PILOT_RUN] 语义剪枝完成: 原始{original_msg_count}条 -> "
                        f"保留{remaining_msg_count}条 (删除{deleted_msg_count}条)"
                    )

                    if progress_callback:
                        await progress_callback("text_preprocessing_result", "语义剪枝完成", pruning_result)
                else:
                    logger.warning("[PILOT_RUN] 剪枝后对话为空,使用原始对话")
                    pruned_dialogs = [dialog]

            except Exception as e:
                # Pruning is optional: fall back to the unpruned dialog.
                logger.error(f"[PILOT_RUN] 语义剪枝失败,使用原始对话: {e}", exc_info=True)
                pruned_dialogs = [dialog]
                if progress_callback:
                    error_result = {
                        "type": "pruning",
                        "error": str(e),
                        "fallback": "使用原始对话"
                    }
                    await progress_callback("text_preprocessing_result", "语义剪枝失败", error_result)
        else:
            logger.info("[PILOT_RUN] 语义剪枝已关闭,跳过")
            pruning_stats = {
                "enabled": False,
            }

        # ========== Step 2.2: semantic chunking ==========
        chunked_dialogs = await get_chunked_dialogs_from_preprocessed(
            data=pruned_dialogs,
            chunker_strategy=memory_config.chunker_strategy,
            llm_client=llm_client,
        )

        remaining_msg_count = len(pruned_dialogs[0].context.msgs) if pruned_dialogs and pruned_dialogs[0].context else 0
        logger.info(f"Processed dialogue text: {remaining_msg_count} messages after pruning")

        # Progress callback: report every chunk, then the preprocessing summary
        if progress_callback:
            for dlg in chunked_dialogs:
                if hasattr(dlg, 'chunks') and dlg.chunks:
                    for i, chunk in enumerate(dlg.chunks):
                        chunk_result = {
                            "type": "chunking",
                            "chunk_index": i + 1,
                            # Preview only: content longer than 200 chars is truncated
                            "content": (chunk.content[:200] + "...") if len(chunk.content) > 200 else chunk.content,
                            "full_length": len(chunk.content),
                            "dialog_id": dlg.id,
                            "chunker_strategy": memory_config.chunker_strategy,
                        }
                        await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)

            # Preprocessing summary (includes pruning statistics)
            preprocessing_summary = {
                "total_chunks": sum(len(dlg.chunks) for dlg in chunked_dialogs if hasattr(dlg, 'chunks') and dlg.chunks),
                "total_dialogs": len(chunked_dialogs),
                "chunker_strategy": memory_config.chunker_strategy,
            }

            # Always include a "pruning" field so the frontend never hits a missing key
            preprocessing_summary["pruning"] = pruning_stats if pruning_stats else {
                "enabled": memory_config.pruning_enabled,
                "deleted_count": 0,
            }

            await progress_callback("text_preprocessing_complete", "预处理文本完成(剪枝 + 分块)", preprocessing_summary)

        log_time("Data Loading & Chunking", time.time() - step_start, log_file)

        # Step 3: pick and initialise the pilot pipeline (switchable via settings)
        use_refactored = bool(settings.PILOT_RUN_USE_REFACTORED_PIPELINE)
        logger.info(
            "Selecting pilot pipeline by env: PILOT_RUN_USE_REFACTORED_PIPELINE=%s",
            use_refactored,
        )
        logger.info(
            "Initializing %s pilot pipeline...",
            "refactored" if use_refactored else "legacy",
        )
        step_start = time.time()

        # Load ontology types (with general-type fallback); a failure here is
        # tolerated and leaves ontology_types as None.
        ontology_types = None
        try:
            from app.core.memory.ontology_services.ontology_type_loader import load_ontology_types_with_fallback

            ontology_types = load_ontology_types_with_fallback(
                scene_id=memory_config.scene_id,
                workspace_id=memory_config.workspace_id,
                db=db,
                enable_general_fallback=True
            )
        except Exception as e:
            logger.warning(f"Failed to load ontology types: {e}", exc_info=True)

        if use_refactored:
            from app.core.memory.memory_service import MemoryService

            memory_service = MemoryService(
                memory_config=memory_config,
                end_user_id=str(memory_config.workspace_id),
            )
            log_time("Pilot Pipeline Initialization", time.time() - step_start, log_file)

            # Step 4a: run the refactored short pilot pipeline
            # statement -> triplet -> graph_build -> layer-1 dedup/disambiguation (end)
            logger.info("Running refactored pilot extraction short pipeline...")
            step_start = time.time()

            if progress_callback:
                await progress_callback("knowledge_extraction", "正在知识抽取...")

            pilot_result = await memory_service.pilot_write(
                chunked_dialogs=chunked_dialogs,
                language=language,
                progress_callback=progress_callback,
            )
            dialog_data_list = pilot_result.dialog_data_list
            graph = pilot_result.graph
            chunk_nodes = graph.chunk_nodes
            export_entity_nodes = graph.entity_nodes
            export_stmt_entity_edges = graph.stmt_entity_edges
            export_entity_edges = graph.entity_entity_edges
        else:
            # Step 4b: run the legacy pilot pipeline
            logger.info("Running legacy pilot extraction pipeline...")
            step_start = time.time()

            if progress_callback:
                await progress_callback("knowledge_extraction", "正在知识抽取...")

            neo4j_connector = Neo4jConnector()
            try:
                legacy_orchestrator = ExtractionOrchestrator(
                    llm_client=llm_client,
                    embedder_client=embedder_client,
                    connector=neo4j_connector,
                    config=get_pipeline_config(memory_config),
                    progress_callback=progress_callback,
                    embedding_id=str(memory_config.embedding_model_id),
                    language=language,
                    ontology_types=ontology_types,
                )
                extraction_result = await legacy_orchestrator.run(
                    dialog_data_list=chunked_dialogs,
                    is_pilot_run=True,
                )
                # Only the chunk nodes, entity nodes and the two entity-related
                # edge lists are exported; the rest of the tuple is unused here.
                (
                    _dialogue_nodes,
                    chunk_nodes,
                    _statement_nodes,
                    export_entity_nodes,
                    _perceptual_nodes,
                    _statement_chunk_edges,
                    export_stmt_entity_edges,
                    export_entity_edges,
                    _perceptual_edges,
                    _last_created_at,
                ) = extraction_result
                dialog_data_list = chunked_dialogs
            finally:
                try:
                    await neo4j_connector.close()
                except Exception as close_err:
                    # Best-effort cleanup: never mask the pipeline outcome, but
                    # leave a trace instead of swallowing the error silently.
                    logger.warning(
                        f"Failed to close Neo4j connector after pilot run: {close_err}",
                        exc_info=True,
                    )

        log_time("Extraction Pipeline", time.time() - step_start, log_file)

        if progress_callback:
            await progress_callback("generating_results", "正在生成结果...")

        # Step 5: write the pilot-run result files (keeps the /pilot_run return contract)
        settings.ensure_memory_output_dir()
        export_test_input_doc(
            entity_nodes=export_entity_nodes,
            statement_entity_edges=export_stmt_entity_edges,
            entity_entity_edges=export_entity_edges,
        )
        _save_triplets_from_dialogs(
            dialog_data_list=dialog_data_list,
            output_path=settings.get_memory_output_path("extracted_triplets.txt"),
        )
        _write_extracted_result_summary(
            chunk_nodes=chunk_nodes,
            pipeline_output_dir=settings.get_memory_output_path(),
        )

        logger.info("Pilot run completed: stop after layer-1 dedup (no layer-2 / no Neo4j write)")

    except Exception as e:
        logger.error(f"Pilot run failed: {e}", exc_info=True)
        raise

    total_time = time.time() - pipeline_start
    log_time("TOTAL PILOT RUN TIME", total_time, log_file)

    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(f"=== Pilot Run Completed: {timestamp} ===\n\n")

    logger.info(f"Pilot run complete. Total time: {total_time:.2f}s")
|