refactor(memory): add PilotWritePipeline and enrich extraction schema

- Add dedicated PilotWritePipeline (statement → triplet → graph_build → layer-1 dedup, no Neo4j write)
- Add type_description/predicate_description fields across entity and triplet models, Cypher queries, and graph builders
- Refactor data_pruning with LRU cache and snapshot support; skip assistant chunks in extraction
- Remove strict Predicate enum whitelist; support statement_text alias in legacy extractor
- Wire PipelineSnapshot through preprocessing and emotion extraction for debug tracing
- Add PILOT_RUN_USE_REFACTORED_PIPELINE env toggle for pipeline selection
This commit is contained in:
lanceyq
2026-04-27 18:15:46 +08:00
parent b0ddd12cc6
commit 2355536b44
23 changed files with 806 additions and 1070 deletions

View File

@@ -441,21 +441,12 @@ class DataConfigService: # 数据配置服务类PostgreSQL
with open(result_path, "r", encoding="utf-8") as rf:
extracted_result = json.load(rf)
# 步骤 6: 计算本体覆盖率并合并到结果中
# 步骤 6: 组装结果(试运行不做额外覆盖率后处理)
result_data = {
"config_id": cid,
"time_log": os.path.join(project_root, "logs", "time.log"),
"extracted_result": extracted_result,
}
try:
ontology_coverage = await self._compute_ontology_coverage(
extracted_result=extracted_result,
memory_config=memory_config,
)
if ontology_coverage:
result_data["ontology_coverage"] = ontology_coverage
except Exception as cov_err:
logger.warning(f"[PILOT_RUN_STREAM] Ontology coverage computation failed: {cov_err}", exc_info=True)
yield format_sse_message("result", result_data)
@@ -479,100 +470,6 @@ class DataConfigService: # 数据配置服务类PostgreSQL
"time": int(time.time() * 1000)
})
async def _compute_ontology_coverage(
self,
extracted_result: Dict[str, Any],
memory_config,
) -> Optional[Dict[str, Any]]:
"""根据提取结果中的实体类型,与场景/通用本体类型做互斥分类统计。
分类规则(互斥):场景类型优先 > 通用类型 > 未匹配
确保: 场景实体数 + 通用实体数 + 未匹配数 = 总实体数
Returns:
包含三部分统计的字典,或 None无实体数据时
"""
core_entities = extracted_result.get("core_entities", [])
if not core_entities:
return None
# 1. 加载场景本体类型集合
scene_ontology_types: set = set()
try:
from app.repositories.ontology_class_repository import OntologyClassRepository
if memory_config.scene_id:
class_repo = OntologyClassRepository(self.db)
ontology_classes = class_repo.get_classes_by_scene(memory_config.scene_id)
scene_ontology_types = {oc.class_name for oc in ontology_classes}
except Exception as e:
logger.warning(f"Failed to load scene ontology types: {e}")
# 2. 加载通用本体类型集合
general_ontology_types: set = set()
try:
from app.core.memory.ontology_services.ontology_type_loader import (
get_general_ontology_registry,
is_general_ontology_enabled,
)
if is_general_ontology_enabled():
registry = get_general_ontology_registry()
if registry:
general_ontology_types = set(registry.types.keys())
except Exception as e:
logger.warning(f"Failed to load general ontology types: {e}")
# 3. 互斥分类:场景优先 > 通用 > 未匹配
scene_distribution: list = []
general_distribution: list = []
unmatched_distribution: list = []
scene_total = 0
general_total = 0
unmatched_total = 0
for item in core_entities:
entity_type = item.get("type", "")
count = item.get("count", 0)
if entity_type in scene_ontology_types:
scene_distribution.append({"type": entity_type, "count": count})
scene_total += count
elif entity_type in general_ontology_types:
general_distribution.append({"type": entity_type, "count": count})
general_total += count
else:
unmatched_distribution.append({"type": entity_type, "count": count})
unmatched_total += count
# 按数量降序排列
scene_distribution.sort(key=lambda x: x["count"], reverse=True)
general_distribution.sort(key=lambda x: x["count"], reverse=True)
unmatched_distribution.sort(key=lambda x: x["count"], reverse=True)
total_entities = scene_total + general_total + unmatched_total
return {
"scene_type_distribution": {
"type_count": len(scene_distribution),
"entity_total": scene_total,
"types": scene_distribution,
},
"general_type_distribution": {
"type_count": len(general_distribution),
"entity_total": general_total,
"types": general_distribution,
},
"unmatched": {
"type_count": len(unmatched_distribution),
"entity_total": unmatched_total,
"types": unmatched_distribution,
},
"total_entities": total_entities,
"time": int(time.time() * 1000),
}
# -------------------- Neo4j Search & Analytics (fused from data_search_service.py) --------------------
# Ensure env for connector (e.g., NEO4J_PASSWORD)

View File

@@ -10,7 +10,9 @@ import time
from datetime import datetime
from typing import Awaitable, Callable, Optional
from app.core.config import settings
from app.core.logging_config import get_memory_logger, log_time
from app.core.memory.pipelines.pilot_write_pipeline import PilotWritePipeline
from app.core.memory.models.message_models import (
ConversationContext,
ConversationMessage,
@@ -20,9 +22,11 @@ from app.core.memory.storage_services.extraction_engine.extraction_orchestrator
ExtractionOrchestrator,
get_chunked_dialogs_from_preprocessed,
)
from app.core.memory.utils.config.config_utils import (
get_pipeline_config,
from app.core.memory.storage_services.extraction_engine.pipeline_help import (
_write_extracted_result_summary,
export_test_input_doc,
)
from app.core.memory.utils.config.config_utils import get_pipeline_config
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import MemoryConfig
@@ -31,6 +35,42 @@ from sqlalchemy.orm import Session
logger = get_memory_logger(__name__)
def _save_triplets_from_dialogs(dialog_data_list: list[DialogData], output_path: str) -> None:
"""Write triplet/entity text report compatible with pipeline_help parsers."""
all_triplets = []
all_entities = []
for dialog in dialog_data_list:
for chunk in getattr(dialog, "chunks", []) or []:
for statement in getattr(chunk, "statements", []) or []:
triplet_info = getattr(statement, "triplet_extraction_info", None)
if not triplet_info:
continue
all_triplets.extend(getattr(triplet_info, "triplets", []) or [])
all_entities.extend(getattr(triplet_info, "entities", []) or [])
with open(output_path, "w", encoding="utf-8") as f:
f.write(f"=== EXTRACTED TRIPLETS ({len(all_triplets)} total) ===\n\n")
for i, triplet in enumerate(all_triplets, 1):
f.write(f"Triplet {i}:\n")
f.write(f" Subject: {triplet.subject_name} (ID: {triplet.subject_id})\n")
f.write(f" Predicate: {triplet.predicate}\n")
f.write(f" Object: {triplet.object_name} (ID: {triplet.object_id})\n")
value = getattr(triplet, "value", None)
if value:
f.write(f" Value: {value}\n")
f.write("\n")
f.write(f"\n=== EXTRACTED ENTITIES ({len(all_entities)} total) ===\n\n")
for i, entity in enumerate(all_entities, 1):
f.write(f"Entity {i}:\n")
f.write(f" ID: {entity.entity_idx}\n")
f.write(f" Name: {entity.name}\n")
f.write(f" Type: {entity.type}\n")
f.write(f" Description: {entity.description}\n")
f.write("\n")
async def run_pilot_extraction(
memory_config: MemoryConfig,
dialogue_text: str,
@@ -58,7 +98,6 @@ async def run_pilot_extraction(
f.write(f"\n=== Pilot Run Started: {timestamp} ===\n")
pipeline_start = time.time()
neo4j_connector = None
try:
# 步骤 1: 初始化客户端
@@ -69,8 +108,6 @@ async def run_pilot_extraction(
llm_client = client_factory.get_llm_client(str(memory_config.llm_model_id))
embedder_client = client_factory.get_embedder_client(str(memory_config.embedding_model_id))
neo4j_connector = Neo4jConnector()
log_time("Client Initialization", time.time() - step_start, log_file)
# 步骤 2: 解析对话文本
@@ -242,15 +279,17 @@ async def run_pilot_extraction(
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
# 步骤 3: 初始化流水线编排器
logger.info("Initializing extraction orchestrator...")
step_start = time.time()
config = get_pipeline_config(memory_config)
# 步骤 3: 初始化并选择试运行流水线(环境变量可切换)
use_refactored = bool(settings.PILOT_RUN_USE_REFACTORED_PIPELINE)
logger.info(
f"Pipeline config loaded: enable_llm_dedup_blockwise={config.deduplication.enable_llm_dedup_blockwise}, "
f"enable_llm_disambiguation={config.deduplication.enable_llm_disambiguation}"
"Selecting pilot pipeline by env: PILOT_RUN_USE_REFACTORED_PIPELINE=%s",
use_refactored,
)
logger.info(
"Initializing %s pilot pipeline...",
"refactored" if use_refactored else "legacy",
)
step_start = time.time()
# 加载本体类型(如果配置了 scene_id支持通用类型回退
ontology_types = None
@@ -266,100 +305,105 @@ async def run_pilot_extraction(
except Exception as e:
logger.warning(f"Failed to load ontology types: {e}", exc_info=True)
orchestrator = ExtractionOrchestrator(
llm_client=llm_client,
embedder_client=embedder_client,
connector=neo4j_connector,
config=config,
progress_callback=progress_callback,
embedding_id=str(memory_config.embedding_model_id),
language=language,
ontology_types=ontology_types,
)
if use_refactored:
pilot_pipeline = PilotWritePipeline(
llm_client=llm_client,
embedder_client=embedder_client,
pipeline_config=get_pipeline_config(memory_config),
progress_callback=progress_callback,
embedding_id=str(memory_config.embedding_model_id),
language=language,
ontology_types=ontology_types,
)
log_time("Pilot Pipeline Initialization", time.time() - step_start, log_file)
log_time("Orchestrator Initialization", time.time() - step_start, log_file)
# 步骤 4a: 执行重构后试运行短链路
# statement -> triplet -> graph_build -> 第一层去重消歧(结束)
logger.info("Running refactored pilot extraction short pipeline...")
step_start = time.time()
# 步骤 4: 执行知识提取流水线
logger.info("Running extraction pipeline...")
step_start = time.time()
if progress_callback:
await progress_callback("knowledge_extraction", "正在知识抽取...")
if progress_callback:
await progress_callback("knowledge_extraction", "正在知识抽取...")
pilot_result = await pilot_pipeline.run(chunked_dialogs)
dialog_data_list = pilot_result.dialog_data_list
graph = pilot_result.graph
chunk_nodes = graph.chunk_nodes
export_entity_nodes = graph.entity_nodes
export_stmt_entity_edges = graph.stmt_entity_edges
export_entity_edges = graph.entity_entity_edges
else:
# 步骤 4b: 执行旧试运行流水线
logger.info("Running legacy pilot extraction pipeline...")
step_start = time.time()
extraction_result = await orchestrator.run(
dialog_data_list=chunked_dialogs,
is_pilot_run=True,
)
if progress_callback:
await progress_callback("knowledge_extraction", "正在知识抽取...")
# 解包 extraction_result tuple (与 main.py 保持一致)
(
dialogue_nodes,
chunk_nodes,
statement_nodes,
entity_nodes,
_,
statement_chunk_edges,
statement_entity_edges,
entity_edges,
_,
_
) = extraction_result
neo4j_connector = Neo4jConnector()
try:
legacy_orchestrator = ExtractionOrchestrator(
llm_client=llm_client,
embedder_client=embedder_client,
connector=neo4j_connector,
config=get_pipeline_config(memory_config),
progress_callback=progress_callback,
embedding_id=str(memory_config.embedding_model_id),
language=language,
ontology_types=ontology_types,
)
extraction_result = await legacy_orchestrator.run(
dialog_data_list=chunked_dialogs,
is_pilot_run=True,
)
(
_dialogue_nodes,
chunk_nodes,
_statement_nodes,
entity_nodes,
_perceptual_nodes,
_statement_chunk_edges,
statement_entity_edges,
entity_edges,
_perceptual_edges,
_last_created_at,
) = extraction_result
dialog_data_list = chunked_dialogs
export_entity_nodes = entity_nodes
export_stmt_entity_edges = statement_entity_edges
export_entity_edges = entity_edges
finally:
try:
await neo4j_connector.close()
except Exception:
pass
log_time("Extraction Pipeline", time.time() - step_start, log_file)
if progress_callback:
await progress_callback("generating_results", "正在生成结果...")
# 步骤 5: 生成记忆摘要(与 main.py 保持一致
try:
logger.info("Generating memory summaries...")
step_start = time.time()
# 步骤 5: 输出试运行结果文件(保持 /pilot_run 返回契约
settings.ensure_memory_output_dir()
export_test_input_doc(
entity_nodes=export_entity_nodes,
statement_entity_edges=export_stmt_entity_edges,
entity_entity_edges=export_entity_edges,
)
_save_triplets_from_dialogs(
dialog_data_list=dialog_data_list,
output_path=settings.get_memory_output_path("extracted_triplets.txt"),
)
_write_extracted_result_summary(
chunk_nodes=chunk_nodes,
pipeline_output_dir=settings.get_memory_output_path(),
)
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
memory_summary_generation,
)
summaries = await memory_summary_generation(
chunked_dialogs,
llm_client=llm_client,
embedder_client=embedder_client,
language=language,
)
log_time("Memory Summary Generation", time.time() - step_start, log_file)
except Exception as e:
logger.error(f"Memory summary step failed: {e}", exc_info=True)
logger.info("Pilot run completed: Skipping Neo4j save")
# 将提取统计写入 Redis按 workspace_id 存储
try:
from app.cache.memory.activity_stats_cache import ActivityStatsCache
stats_to_cache = {
"chunk_count": len(chunk_nodes) if chunk_nodes else 0,
"statements_count": len(statement_nodes) if statement_nodes else 0,
"triplet_entities_count": len(entity_nodes) if entity_nodes else 0,
"triplet_relations_count": len(entity_edges) if entity_edges else 0,
"temporal_count": 0, # temporal 数据在日志中此处暂置0
}
await ActivityStatsCache.set_activity_stats(
workspace_id=str(memory_config.workspace_id),
stats=stats_to_cache,
)
logger.info(f"[PILOT_RUN] 活动统计已写入 Redis: workspace_id={memory_config.workspace_id}")
except Exception as cache_err:
logger.warning(f"[PILOT_RUN] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
logger.info("Pilot run completed: stop after layer-1 dedup (no layer-2 / no Neo4j write)")
except Exception as e:
logger.error(f"Pilot run failed: {e}", exc_info=True)
raise
finally:
if neo4j_connector:
try:
await neo4j_connector.close()
except Exception:
pass
total_time = time.time() - pipeline_start
log_time("TOTAL PILOT RUN TIME", total_time, log_file)