refactor(memory): add PilotWritePipeline and enrich extraction schema

- Add dedicated PilotWritePipeline (statement → triplet → graph_build → layer-1 dedup, no Neo4j write)
- Add type_description/predicate_description fields across entity and triplet models, Cypher queries, and graph builders
- Refactor data_pruning with LRU cache and snapshot support; skip assistant chunks in extraction
- Remove strict Predicate enum whitelist; support statement_text alias in legacy extractor
- Wire PipelineSnapshot through preprocessing and emotion extraction for debug tracing
- Add PILOT_RUN_USE_REFACTORED_PIPELINE env toggle for pipeline selection
This commit is contained in:
lanceyq
2026-04-27 18:15:46 +08:00
parent b0ddd12cc6
commit 2355536b44
23 changed files with 806 additions and 1070 deletions

View File

@@ -10,7 +10,9 @@ import time
from datetime import datetime
from typing import Awaitable, Callable, Optional
from app.core.config import settings
from app.core.logging_config import get_memory_logger, log_time
from app.core.memory.pipelines.pilot_write_pipeline import PilotWritePipeline
from app.core.memory.models.message_models import (
ConversationContext,
ConversationMessage,
@@ -20,9 +22,11 @@ from app.core.memory.storage_services.extraction_engine.extraction_orchestrator
ExtractionOrchestrator,
get_chunked_dialogs_from_preprocessed,
)
from app.core.memory.utils.config.config_utils import (
get_pipeline_config,
from app.core.memory.storage_services.extraction_engine.pipeline_help import (
_write_extracted_result_summary,
export_test_input_doc,
)
from app.core.memory.utils.config.config_utils import get_pipeline_config
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import MemoryConfig
@@ -31,6 +35,42 @@ from sqlalchemy.orm import Session
logger = get_memory_logger(__name__)
def _save_triplets_from_dialogs(dialog_data_list: list[DialogData], output_path: str) -> None:
"""Write triplet/entity text report compatible with pipeline_help parsers."""
all_triplets = []
all_entities = []
for dialog in dialog_data_list:
for chunk in getattr(dialog, "chunks", []) or []:
for statement in getattr(chunk, "statements", []) or []:
triplet_info = getattr(statement, "triplet_extraction_info", None)
if not triplet_info:
continue
all_triplets.extend(getattr(triplet_info, "triplets", []) or [])
all_entities.extend(getattr(triplet_info, "entities", []) or [])
with open(output_path, "w", encoding="utf-8") as f:
f.write(f"=== EXTRACTED TRIPLETS ({len(all_triplets)} total) ===\n\n")
for i, triplet in enumerate(all_triplets, 1):
f.write(f"Triplet {i}:\n")
f.write(f" Subject: {triplet.subject_name} (ID: {triplet.subject_id})\n")
f.write(f" Predicate: {triplet.predicate}\n")
f.write(f" Object: {triplet.object_name} (ID: {triplet.object_id})\n")
value = getattr(triplet, "value", None)
if value:
f.write(f" Value: {value}\n")
f.write("\n")
f.write(f"\n=== EXTRACTED ENTITIES ({len(all_entities)} total) ===\n\n")
for i, entity in enumerate(all_entities, 1):
f.write(f"Entity {i}:\n")
f.write(f" ID: {entity.entity_idx}\n")
f.write(f" Name: {entity.name}\n")
f.write(f" Type: {entity.type}\n")
f.write(f" Description: {entity.description}\n")
f.write("\n")
async def run_pilot_extraction(
memory_config: MemoryConfig,
dialogue_text: str,
@@ -58,7 +98,6 @@ async def run_pilot_extraction(
f.write(f"\n=== Pilot Run Started: {timestamp} ===\n")
pipeline_start = time.time()
neo4j_connector = None
try:
# 步骤 1: 初始化客户端
@@ -69,8 +108,6 @@ async def run_pilot_extraction(
llm_client = client_factory.get_llm_client(str(memory_config.llm_model_id))
embedder_client = client_factory.get_embedder_client(str(memory_config.embedding_model_id))
neo4j_connector = Neo4jConnector()
log_time("Client Initialization", time.time() - step_start, log_file)
# 步骤 2: 解析对话文本
@@ -242,15 +279,17 @@ async def run_pilot_extraction(
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
# 步骤 3: 初始化流水线编排器
logger.info("Initializing extraction orchestrator...")
step_start = time.time()
config = get_pipeline_config(memory_config)
# 步骤 3: 初始化并选择试运行流水线(环境变量可切换)
use_refactored = bool(settings.PILOT_RUN_USE_REFACTORED_PIPELINE)
logger.info(
f"Pipeline config loaded: enable_llm_dedup_blockwise={config.deduplication.enable_llm_dedup_blockwise}, "
f"enable_llm_disambiguation={config.deduplication.enable_llm_disambiguation}"
"Selecting pilot pipeline by env: PILOT_RUN_USE_REFACTORED_PIPELINE=%s",
use_refactored,
)
logger.info(
"Initializing %s pilot pipeline...",
"refactored" if use_refactored else "legacy",
)
step_start = time.time()
# 加载本体类型(如果配置了 scene_id支持通用类型回退
ontology_types = None
@@ -266,100 +305,105 @@ async def run_pilot_extraction(
except Exception as e:
logger.warning(f"Failed to load ontology types: {e}", exc_info=True)
orchestrator = ExtractionOrchestrator(
llm_client=llm_client,
embedder_client=embedder_client,
connector=neo4j_connector,
config=config,
progress_callback=progress_callback,
embedding_id=str(memory_config.embedding_model_id),
language=language,
ontology_types=ontology_types,
)
if use_refactored:
pilot_pipeline = PilotWritePipeline(
llm_client=llm_client,
embedder_client=embedder_client,
pipeline_config=get_pipeline_config(memory_config),
progress_callback=progress_callback,
embedding_id=str(memory_config.embedding_model_id),
language=language,
ontology_types=ontology_types,
)
log_time("Pilot Pipeline Initialization", time.time() - step_start, log_file)
log_time("Orchestrator Initialization", time.time() - step_start, log_file)
# 步骤 4a: 执行重构后试运行短链路
# statement -> triplet -> graph_build -> 第一层去重消歧(结束)
logger.info("Running refactored pilot extraction short pipeline...")
step_start = time.time()
# 步骤 4: 执行知识提取流水线
logger.info("Running extraction pipeline...")
step_start = time.time()
if progress_callback:
await progress_callback("knowledge_extraction", "正在知识抽取...")
if progress_callback:
await progress_callback("knowledge_extraction", "正在知识抽取...")
pilot_result = await pilot_pipeline.run(chunked_dialogs)
dialog_data_list = pilot_result.dialog_data_list
graph = pilot_result.graph
chunk_nodes = graph.chunk_nodes
export_entity_nodes = graph.entity_nodes
export_stmt_entity_edges = graph.stmt_entity_edges
export_entity_edges = graph.entity_entity_edges
else:
# 步骤 4b: 执行旧试运行流水线
logger.info("Running legacy pilot extraction pipeline...")
step_start = time.time()
extraction_result = await orchestrator.run(
dialog_data_list=chunked_dialogs,
is_pilot_run=True,
)
if progress_callback:
await progress_callback("knowledge_extraction", "正在知识抽取...")
# 解包 extraction_result tuple (与 main.py 保持一致)
(
dialogue_nodes,
chunk_nodes,
statement_nodes,
entity_nodes,
_,
statement_chunk_edges,
statement_entity_edges,
entity_edges,
_,
_
) = extraction_result
neo4j_connector = Neo4jConnector()
try:
legacy_orchestrator = ExtractionOrchestrator(
llm_client=llm_client,
embedder_client=embedder_client,
connector=neo4j_connector,
config=get_pipeline_config(memory_config),
progress_callback=progress_callback,
embedding_id=str(memory_config.embedding_model_id),
language=language,
ontology_types=ontology_types,
)
extraction_result = await legacy_orchestrator.run(
dialog_data_list=chunked_dialogs,
is_pilot_run=True,
)
(
_dialogue_nodes,
chunk_nodes,
_statement_nodes,
entity_nodes,
_perceptual_nodes,
_statement_chunk_edges,
statement_entity_edges,
entity_edges,
_perceptual_edges,
_last_created_at,
) = extraction_result
dialog_data_list = chunked_dialogs
export_entity_nodes = entity_nodes
export_stmt_entity_edges = statement_entity_edges
export_entity_edges = entity_edges
finally:
try:
await neo4j_connector.close()
except Exception:
pass
log_time("Extraction Pipeline", time.time() - step_start, log_file)
if progress_callback:
await progress_callback("generating_results", "正在生成结果...")
# 步骤 5: 生成记忆摘要(与 main.py 保持一致
try:
logger.info("Generating memory summaries...")
step_start = time.time()
# 步骤 5: 输出试运行结果文件(保持 /pilot_run 返回契约
settings.ensure_memory_output_dir()
export_test_input_doc(
entity_nodes=export_entity_nodes,
statement_entity_edges=export_stmt_entity_edges,
entity_entity_edges=export_entity_edges,
)
_save_triplets_from_dialogs(
dialog_data_list=dialog_data_list,
output_path=settings.get_memory_output_path("extracted_triplets.txt"),
)
_write_extracted_result_summary(
chunk_nodes=chunk_nodes,
pipeline_output_dir=settings.get_memory_output_path(),
)
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
memory_summary_generation,
)
summaries = await memory_summary_generation(
chunked_dialogs,
llm_client=llm_client,
embedder_client=embedder_client,
language=language,
)
log_time("Memory Summary Generation", time.time() - step_start, log_file)
except Exception as e:
logger.error(f"Memory summary step failed: {e}", exc_info=True)
logger.info("Pilot run completed: Skipping Neo4j save")
# 将提取统计写入 Redis按 workspace_id 存储
try:
from app.cache.memory.activity_stats_cache import ActivityStatsCache
stats_to_cache = {
"chunk_count": len(chunk_nodes) if chunk_nodes else 0,
"statements_count": len(statement_nodes) if statement_nodes else 0,
"triplet_entities_count": len(entity_nodes) if entity_nodes else 0,
"triplet_relations_count": len(entity_edges) if entity_edges else 0,
"temporal_count": 0, # temporal 数据在日志中此处暂置0
}
await ActivityStatsCache.set_activity_stats(
workspace_id=str(memory_config.workspace_id),
stats=stats_to_cache,
)
logger.info(f"[PILOT_RUN] 活动统计已写入 Redis: workspace_id={memory_config.workspace_id}")
except Exception as cache_err:
logger.warning(f"[PILOT_RUN] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
logger.info("Pilot run completed: stop after layer-1 dedup (no layer-2 / no Neo4j write)")
except Exception as e:
logger.error(f"Pilot run failed: {e}", exc_info=True)
raise
finally:
if neo4j_connector:
try:
await neo4j_connector.close()
except Exception:
pass
total_time = time.time() - pipeline_start
log_time("TOTAL PILOT RUN TIME", total_time, log_file)