refactor(memory): remove legacy extraction pipeline and add dialog_at temporal grounding

- Delete ExtractionOrchestrator (~2500 lines) and write_tools legacy path;
  MemoryService/WritePipeline is now the sole write path
- Remove NEW_PIPELINE_ENABLED feature flag from memory_agent_service
- Simplify pilot_run_service to always use PilotWritePipeline
- Add dialog_at field to statement and triplet extraction prompts as the
  primary reference time for resolving relative temporal expressions
- Rewrite relative time phrases (e.g. 昨天, 下周) into concrete dates
  directly in statement_text when stably resolvable from dialog_at
- Rename extracat_Pruning.jinja2 to extracat_pruning.jinja2; expand
  few-shot examples and update memory type enum (drop NULL, add
  agreement/repetition/other)
This commit is contained in:
lanceyq
2026-04-30 15:58:08 +08:00
parent cf389bb978
commit 9dc9b7aee7
16 changed files with 386 additions and 3327 deletions

View File

@@ -34,7 +34,6 @@ from app.core.memory.agent.utils.messages_tools import (
reorder_output_results,
)
from app.core.memory.agent.utils.type_classifier import status_typle
from app.core.memory.agent.utils.write_tools import write as write_neo4j
from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
from app.core.memory.memory_service import MemoryService
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
@@ -447,32 +446,20 @@ class MemoryAgentService:
memory_config,
language: Language | str,
) -> None:
"""根据 NEW_PIPELINE_ENABLED 选择新旧流水线写入 Neo4j。"""
# 统一转换为 dict下游流水线期望 list[dict]
"""使用新流水线MemoryService → WritePipeline写入 Neo4j。"""
messages_dict = [
msg if isinstance(msg, dict) else msg.model_dump(exclude_none=True)
for msg in messages
]
use_new_pipeline = os.getenv("NEW_PIPELINE_ENABLED", "false").lower() == "true"
if use_new_pipeline:
service = MemoryService(memory_config=memory_config, end_user_id=end_user_id)
result = await service.write(
messages=messages_dict, language=language, ref_id='',
)
logger.info(
f"[NewPipeline] 完成: status={result.status}, "
f"elapsed={result.elapsed_seconds:.2f}s, "
f"extraction={result.extraction}"
)
else:
await write_neo4j(
end_user_id=end_user_id,
messages=messages_dict,
memory_config=memory_config,
ref_id='',
language=language,
)
service = MemoryService(memory_config=memory_config, end_user_id=end_user_id)
result = await service.write(
messages=messages_dict, language=language, ref_id='',
)
logger.info(
f"[WritePipeline] 完成: status={result.status}, "
f"elapsed={result.elapsed_seconds:.2f}s, "
f"extraction={result.extraction}"
)
async def _invalidate_interest_cache(self, end_user_id: str) -> None:
"""写入完成后失效兴趣分布缓存。"""

View File

@@ -2,6 +2,11 @@
Pilot Run Service - 试运行服务
用于执行记忆系统的试运行流程,不保存到 Neo4j。
职责边界:
- 文本解析、语义剪枝、语义分块(预处理)
- 调用 PilotWritePipeline 执行萃取链路
- 输出结果文件
"""
import os
@@ -17,17 +22,10 @@ from app.core.memory.models.message_models import (
ConversationMessage,
DialogData,
)
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
ExtractionOrchestrator,
get_chunked_dialogs_from_preprocessed,
)
from app.core.memory.storage_services.extraction_engine.pipeline_help import (
_write_extracted_result_summary,
export_test_input_doc,
)
from app.core.memory.utils.config.config_utils import get_pipeline_config
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import MemoryConfig
from sqlalchemy.orm import Session
@@ -77,18 +75,19 @@ async def run_pilot_extraction(
progress_callback: Optional[Callable[[str, str, Optional[dict]], Awaitable[None]]] = None,
language: str = "zh",
) -> None:
"""
执行试运行模式的知识提取流水线。
"""执行试运行模式的知识提取流水线。
职责:
1. 文本解析 → 语义剪枝 → 语义分块(预处理,需要 llm_client
2. 调用 PilotWritePipeline 执行萃取链路Pipeline 自行管理客户端)
3. 将萃取结果写入输出文件
Args:
memory_config: 从数据库加载的内存配置对象
dialogue_text: 输入的对话文本
db: 数据库会话
progress_callback: 可选的进度回调函数
- 参数1 (stage): 当前处理阶段标识符
- 参数2 (message): 人类可读的进度消息
- 参数3 (data): 可选的附加数据字典
language: 语言类型 ("zh" 中文, "en" 英文),默认中文
db: 数据库会话(用于初始化预处理所需的 LLM 客户端)
progress_callback: 可选的进度回调 (stage, message, data)
language: 语言类型 ("zh" | "en")
"""
log_file = "logs/time.log"
os.makedirs(os.path.dirname(log_file), exist_ok=True)
@@ -99,21 +98,16 @@ async def run_pilot_extraction(
pipeline_start = time.time()
try:
# 步骤 1: 初始化客户端
logger.info("Initializing clients...")
# ── 步骤 1: 初始化预处理所需的 LLM 客户端 ──────────────────────────
# 只用于语义剪枝和分块PilotWritePipeline 内部会自行初始化萃取客户端
step_start = time.time()
client_factory = MemoryClientFactory(db)
llm_client = client_factory.get_llm_client(str(memory_config.llm_model_id))
embedder_client = client_factory.get_embedder_client(str(memory_config.embedding_model_id))
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(str(memory_config.llm_model_id))
log_time("Client Initialization", time.time() - step_start, log_file)
# 步骤 2: 解析对话文本
logger.info("Parsing dialogue text...")
# ── 步骤 2: 文本解析 ────────────────────────────────────────────────
step_start = time.time()
# 解析对话文本,支持 "用户:" 和 "AI:" 格式
pattern = r"(用户|AI)[:]\s*([^\n]+(?:\n(?!(?:用户|AI)[:])[^\n]*)*?)"
matches = re.findall(pattern, dialogue_text, re.MULTILINE | re.DOTALL)
messages = [
@@ -121,14 +115,11 @@ async def run_pilot_extraction(
for r, c in matches
if c.strip()
]
# 如果没有匹配到格式化的对话,将整个文本作为用户消息
if not messages:
messages = [ConversationMessage(role="用户", msg=dialogue_text.strip())]
context = ConversationContext(msgs=messages)
dialog = DialogData(
context=context,
context=ConversationContext(msgs=messages),
ref_id="pilot_dialog_1",
end_user_id=str(memory_config.workspace_id),
user_id=str(memory_config.tenant_id),
@@ -139,267 +130,142 @@ async def run_pilot_extraction(
if progress_callback:
await progress_callback("text_preprocessing", "开始预处理文本(语义剪枝 + 语义分块)...")
# ========== 步骤 2.1: 语义剪枝 ==========
# ── 步骤 2.1: 语义剪枝 ─────────────────────────────────────────────
pruned_dialogs = [dialog]
deleted_messages = [] # 记录被删除的消息
pruning_stats = None # 保存剪枝统计信息,用于最终汇总
pruning_stats: dict = {"enabled": False}
if memory_config.pruning_enabled:
try:
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import (
SemanticPruner,
)
from app.core.memory.models.config_models import PruningConfig
# 构建剪枝配置
pruning_config_dict = {
"pruning_switch": memory_config.pruning_enabled,
"pruning_scene": memory_config.pruning_scene,
"pruning_threshold": memory_config.pruning_threshold,
"scene_id": str(memory_config.scene_id) if memory_config.scene_id else None,
"ontology_class_infos": memory_config.ontology_class_infos,
}
config = PruningConfig(**pruning_config_dict)
logger.info(f"[PILOT_RUN] 开始语义剪枝: scene={config.pruning_scene}, threshold={config.pruning_threshold}")
# 记录剪枝前的消息(用于对比)
original_messages = [{"role": msg.role, "content": msg.msg} for msg in dialog.context.msgs]
original_msg_count = len(original_messages)
# 执行剪枝
pruner = SemanticPruner(config=config, llm_client=llm_client)
pruned_dialogs = await pruner.prune_dataset([dialog])
# 计算剪枝结果并找出被删除的消息
config = PruningConfig(
pruning_switch=memory_config.pruning_enabled,
pruning_scene=memory_config.pruning_scene,
pruning_threshold=memory_config.pruning_threshold,
scene_id=str(memory_config.scene_id) if memory_config.scene_id else None,
ontology_class_infos=memory_config.ontology_class_infos,
)
original_msgs = [{"role": m.role, "content": m.msg} for m in dialog.context.msgs]
pruned_dialogs = await SemanticPruner(config=config, llm_client=llm_client).prune_dataset([dialog])
if pruned_dialogs and pruned_dialogs[0].context:
remaining_messages = [{"role": msg.role, "content": msg.msg} for msg in pruned_dialogs[0].context.msgs]
remaining_msg_count = len(remaining_messages)
deleted_msg_count = original_msg_count - remaining_msg_count
# 找出被删除的消息(基于索引精确匹配)
# 为剩余消息创建带索引的列表,用于精确追踪
remaining_with_index = []
remaining_idx = 0
for orig_idx, orig_msg in enumerate(original_messages):
if remaining_idx < len(remaining_messages) and \
orig_msg["role"] == remaining_messages[remaining_idx]["role"] and \
orig_msg["content"] == remaining_messages[remaining_idx]["content"]:
remaining_with_index.append(orig_idx)
remaining_idx += 1
# 找出未在保留列表中的消息索引
remaining = [{"role": m.role, "content": m.msg} for m in pruned_dialogs[0].context.msgs]
# 找出被删除的消息(顺序匹配)
kept_indices: list[int] = []
ri = 0
for oi, om in enumerate(original_msgs):
if ri < len(remaining) and om == remaining[ri]:
kept_indices.append(oi)
ri += 1
deleted_messages = [
{"index": idx, "role": msg["role"], "content": msg["content"]}
for idx, msg in enumerate(original_messages)
if idx not in remaining_with_index
{"index": i, "role": m["role"], "content": m["content"]}
for i, m in enumerate(original_msgs)
if i not in kept_indices
]
# 保存剪枝统计信息用于最终汇总只保留deleted_count
pruning_stats = {
"enabled": True,
"scene": config.pruning_scene,
"threshold": config.pruning_threshold,
"deleted_count": deleted_msg_count,
"deleted_count": len(deleted_messages),
}
# 输出剪枝结果(显示删除的消息详情)
pruning_result = {
"type": "pruning",
"deleted_messages": deleted_messages,
}
logger.info(
f"[PILOT_RUN] 语义剪枝完成: 原始{original_msg_count}条 -> "
f"保留{remaining_msg_count}条 (删除{deleted_msg_count}条)"
f"[PILOT_RUN] 语义剪枝完成: {len(original_msgs)}{len(remaining)}"
f"(删除 {len(deleted_messages)} 条)"
)
if progress_callback:
await progress_callback("text_preprocessing_result", "语义剪枝完成", pruning_result)
await progress_callback(
"text_preprocessing_result", "语义剪枝完成",
{"type": "pruning", "deleted_messages": deleted_messages},
)
else:
logger.warning("[PILOT_RUN] 剪枝后对话为空,使用原始对话")
pruned_dialogs = [dialog]
except Exception as e:
logger.error(f"[PILOT_RUN] 语义剪枝失败,使用原始对话: {e}", exc_info=True)
pruned_dialogs = [dialog]
if progress_callback:
error_result = {
"type": "pruning",
"error": str(e),
"fallback": "使用原始对话"
}
await progress_callback("text_preprocessing_result", "语义剪枝失败", error_result)
else:
logger.info("[PILOT_RUN] 语义剪枝已关闭,跳过")
pruning_stats = {
"enabled": False,
}
await progress_callback(
"text_preprocessing_result", "语义剪枝失败",
{"type": "pruning", "error": str(e), "fallback": "使用原始对话"},
)
# ========== 步骤 2.2: 语义分块 ==========
chunked_dialogs = await get_chunked_dialogs_from_preprocessed(
data=pruned_dialogs,
chunker_strategy=memory_config.chunker_strategy,
llm_client=llm_client,
# ── 步骤 2.2: 语义分块 ─────────────────────────────────────────────
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import (
DialogueChunker,
)
remaining_msg_count = len(pruned_dialogs[0].context.msgs) if pruned_dialogs and pruned_dialogs[0].context else 0
logger.info(f"Processed dialogue text: {remaining_msg_count} messages after pruning")
chunked_dialogs = []
for dlg in pruned_dialogs:
dlg.chunks = await DialogueChunker(memory_config.chunker_strategy, llm_client=llm_client).process_dialogue(dlg)
chunked_dialogs.append(dlg)
# 进度回调:输出每个分块的结果
if progress_callback:
for dlg in chunked_dialogs:
if hasattr(dlg, 'chunks') and dlg.chunks:
for i, chunk in enumerate(dlg.chunks):
chunk_result = {
for i, chunk in enumerate(dlg.chunks or []):
await progress_callback(
"text_preprocessing_result", f"分块 {i + 1} 处理完成",
{
"type": "chunking",
"chunk_index": i + 1,
"content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
"full_length": len(chunk.content),
"dialog_id": dlg.id,
"chunker_strategy": memory_config.chunker_strategy,
}
await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)
# 构建预处理完成总结(包含剪枝统计)
preprocessing_summary = {
"total_chunks": sum(len(dlg.chunks) for dlg in chunked_dialogs if hasattr(dlg, 'chunks') and dlg.chunks),
"total_dialogs": len(chunked_dialogs),
"chunker_strategy": memory_config.chunker_strategy,
}
# 添加剪枝统计信息(始终包含 pruning 字段,确保前端不会因字段缺失报错)
preprocessing_summary["pruning"] = pruning_stats if pruning_stats else {
"enabled": memory_config.pruning_enabled,
"deleted_count": 0,
}
await progress_callback("text_preprocessing_complete", "预处理文本完成(剪枝 + 分块)", preprocessing_summary)
},
)
await progress_callback(
"text_preprocessing_complete", "预处理文本完成(剪枝 + 分块)",
{
"total_chunks": sum(len(dlg.chunks or []) for dlg in chunked_dialogs),
"total_dialogs": len(chunked_dialogs),
"chunker_strategy": memory_config.chunker_strategy,
"pruning": pruning_stats,
},
)
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
# 步骤 3: 初始化并选择试运行流水线(环境变量可切换)
use_refactored = bool(settings.PILOT_RUN_USE_REFACTORED_PIPELINE)
logger.info(
"Selecting pilot pipeline by env: PILOT_RUN_USE_REFACTORED_PIPELINE=%s",
use_refactored,
)
logger.info(
"Initializing %s pilot pipeline...",
"refactored" if use_refactored else "legacy",
)
# ── 步骤 3: 萃取PilotWritePipeline 自行管理客户端和本体加载)──────
step_start = time.time()
logger.info("Running pilot extraction pipeline...")
# 加载本体类型(如果配置了 scene_id支持通用类型回退
ontology_types = None
try:
from app.core.memory.ontology_services.ontology_type_loader import load_ontology_types_with_fallback
ontology_types = load_ontology_types_with_fallback(
scene_id=memory_config.scene_id,
workspace_id=memory_config.workspace_id,
db=db,
enable_general_fallback=True
)
except Exception as e:
logger.warning(f"Failed to load ontology types: {e}", exc_info=True)
if progress_callback:
await progress_callback("knowledge_extraction", "正在知识抽取...")
if use_refactored:
from app.core.memory.memory_service import MemoryService
from app.core.memory.pipelines.pilot_write_pipeline import PilotWritePipeline
memory_service = MemoryService(
memory_config=memory_config,
end_user_id=str(memory_config.workspace_id),
)
log_time("Pilot Pipeline Initialization", time.time() - step_start, log_file)
# 步骤 4a: 执行重构后试运行短链路
# statement -> triplet -> graph_build -> 第一层去重消歧(结束)
logger.info("Running refactored pilot extraction short pipeline...")
step_start = time.time()
if progress_callback:
await progress_callback("knowledge_extraction", "正在知识抽取...")
pilot_result = await memory_service.pilot_write(
chunked_dialogs=chunked_dialogs,
language=language,
progress_callback=progress_callback,
)
dialog_data_list = pilot_result.dialog_data_list
graph = pilot_result.graph
chunk_nodes = graph.chunk_nodes
export_entity_nodes = graph.entity_nodes
export_stmt_entity_edges = graph.stmt_entity_edges
export_entity_edges = graph.entity_entity_edges
else:
# 步骤 4b: 执行旧试运行流水线
logger.info("Running legacy pilot extraction pipeline...")
step_start = time.time()
if progress_callback:
await progress_callback("knowledge_extraction", "正在知识抽取...")
neo4j_connector = Neo4jConnector()
try:
legacy_orchestrator = ExtractionOrchestrator(
llm_client=llm_client,
embedder_client=embedder_client,
connector=neo4j_connector,
config=get_pipeline_config(memory_config),
progress_callback=progress_callback,
embedding_id=str(memory_config.embedding_model_id),
language=language,
ontology_types=ontology_types,
)
extraction_result = await legacy_orchestrator.run(
dialog_data_list=chunked_dialogs,
is_pilot_run=True,
)
(
_dialogue_nodes,
chunk_nodes,
_statement_nodes,
entity_nodes,
_perceptual_nodes,
_statement_chunk_edges,
statement_entity_edges,
entity_edges,
_perceptual_edges,
_last_created_at,
) = extraction_result
dialog_data_list = chunked_dialogs
export_entity_nodes = entity_nodes
export_stmt_entity_edges = statement_entity_edges
export_entity_edges = entity_edges
finally:
try:
await neo4j_connector.close()
except Exception:
pass
pilot_result = await PilotWritePipeline(
memory_config=memory_config,
end_user_id=str(memory_config.workspace_id),
language=language,
progress_callback=progress_callback,
).run(chunked_dialogs)
log_time("Extraction Pipeline", time.time() - step_start, log_file)
# ── 步骤 4: 输出结果文件 ────────────────────────────────────────────
if progress_callback:
await progress_callback("generating_results", "正在生成结果...")
# 步骤 5: 输出试运行结果文件(保持 /pilot_run 返回契约)
graph = pilot_result.graph
settings.ensure_memory_output_dir()
export_test_input_doc(
entity_nodes=export_entity_nodes,
statement_entity_edges=export_stmt_entity_edges,
entity_entity_edges=export_entity_edges,
entity_nodes=graph.entity_nodes,
statement_entity_edges=graph.stmt_entity_edges,
entity_entity_edges=graph.entity_entity_edges,
)
_save_triplets_from_dialogs(
dialog_data_list=dialog_data_list,
dialog_data_list=pilot_result.dialog_data_list,
output_path=settings.get_memory_output_path("extracted_triplets.txt"),
)
_write_extracted_result_summary(
chunk_nodes=chunk_nodes,
chunk_nodes=graph.chunk_nodes,
pipeline_output_dir=settings.get_memory_output_path(),
)
logger.info("Pilot run completed: stop after layer-1 dedup (no layer-2 / no Neo4j write)")
logger.info("Pilot run completed: stop after layer-1 dedup (no Neo4j write)")
except Exception as e:
logger.error(f"Pilot run failed: {e}", exc_info=True)
@@ -407,9 +273,6 @@ async def run_pilot_extraction(
total_time = time.time() - pipeline_start
log_time("TOTAL PILOT RUN TIME", total_time, log_file)
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
with open(log_file, "a", encoding="utf-8") as f:
f.write(f"=== Pilot Run Completed: {timestamp} ===\n\n")
f.write(f"=== Pilot Run Completed: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ===\n\n")
logger.info(f"Pilot run complete. Total time: {total_time:.2f}s")