MemoryBear/api/app/services/pilot_run_service.py
lanceyq 9dc9b7aee7 refactor(memory): remove legacy extraction pipeline and add dialog_at temporal grounding
- Delete ExtractionOrchestrator (~2500 lines) and write_tools legacy path;
  MemoryService/WritePipeline is now the sole write path
- Remove NEW_PIPELINE_ENABLED feature flag from memory_agent_service
- Simplify pilot_run_service to always use PilotWritePipeline
- Add dialog_at field to statement and triplet extraction prompts as the
  primary reference time for resolving relative temporal expressions
- Rewrite relative time phrases (e.g. 昨天, 下周) into concrete dates
  directly in statement_text when stably resolvable from dialog_at
- Rename extracat_Pruning.jinja2 to extracat_pruning.jinja2; expand
  few-shot examples and update memory type enum (drop NULL, add
  agreement/repetition/other)
2026-05-08 11:28:24 +08:00

"""
Pilot Run Service - 试运行服务
用于执行记忆系统的试运行流程,不保存到 Neo4j。
职责边界:
- 文本解析、语义剪枝、语义分块(预处理)
- 调用 PilotWritePipeline 执行萃取链路
- 输出结果文件
"""
import os
import re
import time
from datetime import datetime
from typing import Awaitable, Callable, Optional

from sqlalchemy.orm import Session

from app.core.config import settings
from app.core.logging_config import get_memory_logger, log_time
from app.core.memory.models.message_models import (
    ConversationContext,
    ConversationMessage,
    DialogData,
)
from app.core.memory.storage_services.extraction_engine.pipeline_help import (
    _write_extracted_result_summary,
    export_test_input_doc,
)
from app.schemas.memory_config_schema import MemoryConfig

logger = get_memory_logger(__name__)


def _save_triplets_from_dialogs(dialog_data_list: list[DialogData], output_path: str) -> None:
    """Write triplet/entity text report compatible with pipeline_help parsers."""
    all_triplets = []
    all_entities = []
    for dialog in dialog_data_list:
        for chunk in getattr(dialog, "chunks", []) or []:
            for statement in getattr(chunk, "statements", []) or []:
                triplet_info = getattr(statement, "triplet_extraction_info", None)
                if not triplet_info:
                    continue
                all_triplets.extend(getattr(triplet_info, "triplets", []) or [])
                all_entities.extend(getattr(triplet_info, "entities", []) or [])

    with open(output_path, "w", encoding="utf-8") as f:
        f.write(f"=== EXTRACTED TRIPLETS ({len(all_triplets)} total) ===\n\n")
        for i, triplet in enumerate(all_triplets, 1):
            f.write(f"Triplet {i}:\n")
            f.write(f" Subject: {triplet.subject_name} (ID: {triplet.subject_id})\n")
            f.write(f" Predicate: {triplet.predicate}\n")
            f.write(f" Object: {triplet.object_name} (ID: {triplet.object_id})\n")
            value = getattr(triplet, "value", None)
            if value:
                f.write(f" Value: {value}\n")
            f.write("\n")

        f.write(f"\n=== EXTRACTED ENTITIES ({len(all_entities)} total) ===\n\n")
        for i, entity in enumerate(all_entities, 1):
            f.write(f"Entity {i}:\n")
            f.write(f" ID: {entity.entity_idx}\n")
            f.write(f" Name: {entity.name}\n")
            f.write(f" Type: {entity.type}\n")
            f.write(f" Description: {entity.description}\n")
            f.write("\n")


async def run_pilot_extraction(
    memory_config: MemoryConfig,
    dialogue_text: str,
    db: Session,
    progress_callback: Optional[Callable[[str, str, Optional[dict]], Awaitable[None]]] = None,
    language: str = "zh",
) -> None:
    """Run the knowledge-extraction pipeline in pilot (dry-run) mode.

    Responsibilities:
        1. Text parsing → semantic pruning → semantic chunking (preprocessing; requires an llm_client)
        2. Invoke PilotWritePipeline to run the extraction chain (the pipeline manages its own clients)
        3. Write the extraction results to output files

    Args:
        memory_config: Memory configuration loaded from the database
        dialogue_text: Input dialogue text
        db: Database session (used to initialize the LLM client needed for preprocessing)
        progress_callback: Optional progress callback (stage, message, data)
        language: Language of the dialogue ("zh" | "en")
    """
log_file = "logs/time.log"
os.makedirs(os.path.dirname(log_file), exist_ok=True)
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
with open(log_file, "a", encoding="utf-8") as f:
f.write(f"\n=== Pilot Run Started: {timestamp} ===\n")
pipeline_start = time.time()
    try:
        # ── Step 1: initialize the LLM client needed for preprocessing ─────
        # Used only for semantic pruning and chunking; PilotWritePipeline
        # initializes its own extraction clients internally.
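        # Assumption: MemoryClientFactory resolves llm_model_id against model
        # settings reachable through the db session, which is why `db` is
        # required even though the pilot run itself does not persist results.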
        step_start = time.time()
        from app.core.memory.utils.llm.llm_utils import MemoryClientFactory

        factory = MemoryClientFactory(db)
        llm_client = factory.get_llm_client(str(memory_config.llm_model_id))
        log_time("Client Initialization", time.time() - step_start, log_file)
        # ── Step 2: text parsing ────────────────────────────────────────────
        step_start = time.time()
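        # Split the raw input into speaker turns: lines tagged "用户" or "AI"
        # followed by a colon, where a turn may continue over several lines.
        # The negative lookahead keeps consuming lines until the next speaker
        # tag. The continuation group must be greedy (`*`): a lazy `*?` at the
        # very end of the pattern matches zero repetitions, silently
        # truncating multi-line turns to their first line. If nothing matches,
        # the whole input is treated as a single 用户 turn below.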
pattern = r"(用户|AI)[:]\s*([^\n]+(?:\n(?!(?:用户|AI)[:])[^\n]*)*?)"
matches = re.findall(pattern, dialogue_text, re.MULTILINE | re.DOTALL)
messages = [
ConversationMessage(role=r, msg=c.strip())
for r, c in matches
if c.strip()
]
if not messages:
messages = [ConversationMessage(role="用户", msg=dialogue_text.strip())]
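        # A pilot run has no persisted conversation, so identifiers are
        # synthesized from the memory config: workspace_id stands in for the
        # end user, tenant_id for the user, and config_id for the apply id.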
        dialog = DialogData(
            context=ConversationContext(msgs=messages),
            ref_id="pilot_dialog_1",
            end_user_id=str(memory_config.workspace_id),
            user_id=str(memory_config.tenant_id),
            apply_id=str(memory_config.config_id),
            metadata={"source": "pilot_run", "input_type": "frontend_text"},
        )
        if progress_callback:
            await progress_callback("text_preprocessing", "Preprocessing text (semantic pruning + semantic chunking)...")

        # ── Step 2.1: semantic pruning ──────────────────────────────────────
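        # Pruning is opt-in via memory_config.pruning_enabled, and any failure
        # below falls back to the original dialog, so the pilot run never
        # aborts because of this step.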
        pruned_dialogs = [dialog]
        pruning_stats: dict = {"enabled": False}
        if memory_config.pruning_enabled:
            try:
                from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import (
                    SemanticPruner,
                )
                from app.core.memory.models.config_models import PruningConfig

                config = PruningConfig(
                    pruning_switch=memory_config.pruning_enabled,
                    pruning_scene=memory_config.pruning_scene,
                    pruning_threshold=memory_config.pruning_threshold,
                    scene_id=str(memory_config.scene_id) if memory_config.scene_id else None,
                    ontology_class_infos=memory_config.ontology_class_infos,
                )
                original_msgs = [{"role": m.role, "content": m.msg} for m in dialog.context.msgs]
                pruned_dialogs = await SemanticPruner(config=config, llm_client=llm_client).prune_dataset([dialog])
                if pruned_dialogs and pruned_dialogs[0].context:
                    remaining = [{"role": m.role, "content": m.msg} for m in pruned_dialogs[0].context.msgs]
                    # Identify the deleted messages (matched in order)
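                    # Assumption: the pruner only drops whole messages and
                    # preserves their order, so one forward pass is enough to
                    # align the remaining messages with the originals.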
                    kept_indices: list[int] = []
                    ri = 0
                    for oi, om in enumerate(original_msgs):
                        if ri < len(remaining) and om == remaining[ri]:
                            kept_indices.append(oi)
                            ri += 1
                    deleted_messages = [
                        {"index": i, "role": m["role"], "content": m["content"]}
                        for i, m in enumerate(original_msgs)
                        if i not in kept_indices
                    ]
                    pruning_stats = {
                        "enabled": True,
                        "scene": config.pruning_scene,
                        "threshold": config.pruning_threshold,
                        "deleted_count": len(deleted_messages),
                    }
                    logger.info(
                        f"[PILOT_RUN] Semantic pruning done: {len(original_msgs)} → {len(remaining)} messages "
                        f"({len(deleted_messages)} deleted)"
                    )
                    if progress_callback:
                        await progress_callback(
                            "text_preprocessing_result", "Semantic pruning complete",
                            {"type": "pruning", "deleted_messages": deleted_messages},
                        )
                else:
                    logger.warning("[PILOT_RUN] Dialog empty after pruning; using the original dialog")
                    pruned_dialogs = [dialog]
            except Exception as e:
                logger.error(f"[PILOT_RUN] Semantic pruning failed; using the original dialog: {e}", exc_info=True)
                pruned_dialogs = [dialog]
                if progress_callback:
                    await progress_callback(
                        "text_preprocessing_result", "Semantic pruning failed",
                        {"type": "pruning", "error": str(e), "fallback": "using the original dialog"},
                    )
        # ── Step 2.2: semantic chunking ─────────────────────────────────────
        from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import (
            DialogueChunker,
        )

        chunked_dialogs = []
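        # Note: a DialogueChunker is built per dialog; assuming the chunker
        # keeps no cross-dialog state, a single shared instance would behave
        # identically.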
        for dlg in pruned_dialogs:
            dlg.chunks = await DialogueChunker(memory_config.chunker_strategy, llm_client=llm_client).process_dialogue(dlg)
            chunked_dialogs.append(dlg)
        if progress_callback:
            for dlg in chunked_dialogs:
                for i, chunk in enumerate(dlg.chunks or []):
                    await progress_callback(
                        "text_preprocessing_result", f"Chunk {i + 1} processed",
                        {
                            "type": "chunking",
                            "chunk_index": i + 1,
                            "content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
                            "full_length": len(chunk.content),
                            "dialog_id": dlg.id,
                            "chunker_strategy": memory_config.chunker_strategy,
                        },
                    )
            await progress_callback(
                "text_preprocessing_complete", "Text preprocessing complete (pruning + chunking)",
                {
                    "total_chunks": sum(len(dlg.chunks or []) for dlg in chunked_dialogs),
                    "total_dialogs": len(chunked_dialogs),
                    "chunker_strategy": memory_config.chunker_strategy,
                    "pruning": pruning_stats,
                },
            )
        log_time("Data Loading & Chunking", time.time() - step_start, log_file)
        # ── Step 3: extraction (PilotWritePipeline manages its own clients and ontology loading) ──
        step_start = time.time()
        logger.info("Running pilot extraction pipeline...")
        if progress_callback:
            await progress_callback("knowledge_extraction", "Extracting knowledge...")
        from app.core.memory.pipelines.pilot_write_pipeline import PilotWritePipeline
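        # The pipeline gets the same workspace-derived end_user_id as the
        # DialogData above, presumably keeping extracted nodes scoped to this
        # pilot workspace.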
        pilot_result = await PilotWritePipeline(
            memory_config=memory_config,
            end_user_id=str(memory_config.workspace_id),
            language=language,
            progress_callback=progress_callback,
        ).run(chunked_dialogs)
        log_time("Extraction Pipeline", time.time() - step_start, log_file)
        # ── Step 4: write result files ──────────────────────────────────────
        if progress_callback:
            await progress_callback("generating_results", "Generating results...")
        graph = pilot_result.graph
        settings.ensure_memory_output_dir()
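        # Three artifacts land in the memory output directory: the test input
        # doc (entity nodes + edges), extracted_triplets.txt, and the result
        # summary built from the chunk nodes.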
        export_test_input_doc(
            entity_nodes=graph.entity_nodes,
            statement_entity_edges=graph.stmt_entity_edges,
            entity_entity_edges=graph.entity_entity_edges,
        )
        _save_triplets_from_dialogs(
            dialog_data_list=pilot_result.dialog_data_list,
            output_path=settings.get_memory_output_path("extracted_triplets.txt"),
        )
        _write_extracted_result_summary(
            chunk_nodes=graph.chunk_nodes,
            pipeline_output_dir=settings.get_memory_output_path(),
        )
        logger.info("Pilot run completed: stop after layer-1 dedup (no Neo4j write)")
    except Exception as e:
        logger.error(f"Pilot run failed: {e}", exc_info=True)
        raise
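    # On failure the exception re-raises above, so the totals below are only
    # logged for successful runs.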
    total_time = time.time() - pipeline_start
    log_time("TOTAL PILOT RUN TIME", total_time, log_file)
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(f"=== Pilot Run Completed: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ===\n\n")
    logger.info(f"Pilot run complete. Total time: {total_time:.2f}s")