refactor(memory): add PilotWritePipeline and enrich extraction schema
- Add dedicated PilotWritePipeline (statement → triplet → graph_build → layer-1 dedup, no Neo4j write)
- Add type_description/predicate_description fields across entity and triplet models, Cypher queries, and graph builders
- Refactor data_pruning with LRU cache and snapshot support; skip assistant chunks in extraction
- Remove strict Predicate enum whitelist; support statement_text alias in legacy extractor
- Wire PipelineSnapshot through preprocessing and emotion extraction for debug tracing
- Add PILOT_RUN_USE_REFACTORED_PIPELINE env toggle for pipeline selection
This commit is contained in:
@@ -441,21 +441,12 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
with open(result_path, "r", encoding="utf-8") as rf:
|
||||
extracted_result = json.load(rf)
|
||||
|
||||
# 步骤 6: 计算本体覆盖率并合并到结果中
|
||||
# 步骤 6: 组装结果(试运行不做额外覆盖率后处理)
|
||||
result_data = {
|
||||
"config_id": cid,
|
||||
"time_log": os.path.join(project_root, "logs", "time.log"),
|
||||
"extracted_result": extracted_result,
|
||||
}
|
||||
try:
|
||||
ontology_coverage = await self._compute_ontology_coverage(
|
||||
extracted_result=extracted_result,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
if ontology_coverage:
|
||||
result_data["ontology_coverage"] = ontology_coverage
|
||||
except Exception as cov_err:
|
||||
logger.warning(f"[PILOT_RUN_STREAM] Ontology coverage computation failed: {cov_err}", exc_info=True)
|
||||
|
||||
yield format_sse_message("result", result_data)
|
||||
|
||||
@@ -479,100 +470,6 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
"time": int(time.time() * 1000)
|
||||
})
|
||||
|
||||
async def _compute_ontology_coverage(
    self,
    extracted_result: Dict[str, Any],
    memory_config,
) -> Optional[Dict[str, Any]]:
    """Split extracted entity types into scene / general / unmatched groups.

    The grouping is mutually exclusive and applied in priority order:
    a type found in the scene ontology is counted there, otherwise it is
    looked up in the general ontology, otherwise it is "unmatched". The
    three group totals therefore add up to ``total_entities``.

    Args:
        extracted_result: Extraction output; only its ``"core_entities"``
            list (items carrying ``"type"`` and ``"count"``) is read.
        memory_config: Configuration object; only ``scene_id`` is read.

    Returns:
        A dict with ``scene_type_distribution``, ``general_type_distribution``
        and ``unmatched`` sections (each holding ``type_count``,
        ``entity_total`` and ``types`` sorted by count, descending), plus
        ``total_entities`` and a ``time`` stamp in epoch milliseconds.
        ``None`` when ``"core_entities"`` is missing or empty.
    """
    entity_rows = extracted_result.get("core_entities", [])
    if not entity_rows:
        return None

    # 1. Scene ontology type names. Any failure is logged and leaves the
    #    set empty, so the lookup below simply never matches.
    scene_names: set = set()
    try:
        from app.repositories.ontology_class_repository import OntologyClassRepository

        if memory_config.scene_id:
            scene_classes = OntologyClassRepository(self.db).get_classes_by_scene(
                memory_config.scene_id
            )
            scene_names = {cls_row.class_name for cls_row in scene_classes}
    except Exception as e:
        logger.warning(f"Failed to load scene ontology types: {e}")

    # 2. General ontology type names, loaded with the same best-effort policy.
    general_names: set = set()
    try:
        from app.core.memory.ontology_services.ontology_type_loader import (
            get_general_ontology_registry,
            is_general_ontology_enabled,
        )

        if is_general_ontology_enabled():
            general_registry = get_general_ontology_registry()
            if general_registry:
                general_names = set(general_registry.types.keys())
    except Exception as e:
        logger.warning(f"Failed to load general ontology types: {e}")

    # 3. Mutually exclusive classification: scene first, then general,
    #    then unmatched. Insertion order of the dict fixes processing order.
    buckets: Dict[str, list] = {"scene": [], "general": [], "unmatched": []}
    totals: Dict[str, Any] = {"scene": 0, "general": 0, "unmatched": 0}

    for row in entity_rows:
        type_name = row.get("type", "")
        amount = row.get("count", 0)

        if type_name in scene_names:
            group = "scene"
        elif type_name in general_names:
            group = "general"
        else:
            group = "unmatched"

        buckets[group].append({"type": type_name, "count": amount})
        totals[group] += amount

    # Largest counts first within every group.
    for group_rows in buckets.values():
        group_rows.sort(key=lambda entry: entry["count"], reverse=True)

    grand_total = totals["scene"] + totals["general"] + totals["unmatched"]

    def _section(group: str) -> Dict[str, Any]:
        # One report section: number of distinct rows, summed count, rows.
        return {
            "type_count": len(buckets[group]),
            "entity_total": totals[group],
            "types": buckets[group],
        }

    return {
        "scene_type_distribution": _section("scene"),
        "general_type_distribution": _section("general"),
        "unmatched": _section("unmatched"),
        "total_entities": grand_total,
        "time": int(time.time() * 1000),
    }
|
||||
|
||||
|
||||
# -------------------- Neo4j Search & Analytics (fused from data_search_service.py) --------------------
|
||||
# Ensure env for connector (e.g., NEO4J_PASSWORD)
|
||||
|
||||
|
||||
@@ -10,7 +10,9 @@ import time
|
||||
from datetime import datetime
|
||||
from typing import Awaitable, Callable, Optional
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_memory_logger, log_time
|
||||
from app.core.memory.pipelines.pilot_write_pipeline import PilotWritePipeline
|
||||
from app.core.memory.models.message_models import (
|
||||
ConversationContext,
|
||||
ConversationMessage,
|
||||
@@ -20,9 +22,11 @@ from app.core.memory.storage_services.extraction_engine.extraction_orchestrator
|
||||
ExtractionOrchestrator,
|
||||
get_chunked_dialogs_from_preprocessed,
|
||||
)
|
||||
from app.core.memory.utils.config.config_utils import (
|
||||
get_pipeline_config,
|
||||
from app.core.memory.storage_services.extraction_engine.pipeline_help import (
|
||||
_write_extracted_result_summary,
|
||||
export_test_input_doc,
|
||||
)
|
||||
from app.core.memory.utils.config.config_utils import get_pipeline_config
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
@@ -31,6 +35,42 @@ from sqlalchemy.orm import Session
|
||||
logger = get_memory_logger(__name__)
|
||||
|
||||
|
||||
def _save_triplets_from_dialogs(dialog_data_list: list[DialogData], output_path: str) -> None:
    """Dump every extracted triplet and entity to a plain-text report.

    Walks dialog -> chunks -> statements, gathers the ``triplets`` and
    ``entities`` found on each statement's ``triplet_extraction_info``
    (statements without one are skipped), and writes them to
    ``output_path`` as UTF-8 text in two sections: triplets first, then
    entities. The layout is intended to stay readable by the
    ``pipeline_help`` parsers, per the original author's note.

    Args:
        dialog_data_list: Dialogs whose chunks carry extracted statements.
        output_path: Destination file; it is overwritten.
    """
    triplets: list = []
    entities: list = []

    # Missing or None containers are treated as empty at every level.
    for dialog in dialog_data_list:
        for chunk in getattr(dialog, "chunks", []) or []:
            for statement in getattr(chunk, "statements", []) or []:
                info = getattr(statement, "triplet_extraction_info", None)
                if info:
                    triplets += getattr(info, "triplets", []) or []
                    entities += getattr(info, "entities", []) or []

    def _triplet_lines():
        # Lines are produced lazily so each one is written in the same
        # order (and at the same point) as a sequence of direct writes.
        yield f"=== EXTRACTED TRIPLETS ({len(triplets)} total) ===\n\n"
        for number, triplet in enumerate(triplets, 1):
            yield f"Triplet {number}:\n"
            yield f"  Subject: {triplet.subject_name} (ID: {triplet.subject_id})\n"
            yield f"  Predicate: {triplet.predicate}\n"
            yield f"  Object: {triplet.object_name} (ID: {triplet.object_id})\n"
            value = getattr(triplet, "value", None)
            if value:
                # The value line is optional and only emitted when truthy.
                yield f"  Value: {value}\n"
            yield "\n"

    def _entity_lines():
        yield f"\n=== EXTRACTED ENTITIES ({len(entities)} total) ===\n\n"
        for number, entity in enumerate(entities, 1):
            yield f"Entity {number}:\n"
            yield f"  ID: {entity.entity_idx}\n"
            yield f"  Name: {entity.name}\n"
            yield f"  Type: {entity.type}\n"
            yield f"  Description: {entity.description}\n"
            yield "\n"

    with open(output_path, "w", encoding="utf-8") as report:
        report.writelines(_triplet_lines())
        report.writelines(_entity_lines())
|
||||
|
||||
|
||||
async def run_pilot_extraction(
|
||||
memory_config: MemoryConfig,
|
||||
dialogue_text: str,
|
||||
@@ -58,7 +98,6 @@ async def run_pilot_extraction(
|
||||
f.write(f"\n=== Pilot Run Started: {timestamp} ===\n")
|
||||
|
||||
pipeline_start = time.time()
|
||||
neo4j_connector = None
|
||||
|
||||
try:
|
||||
# 步骤 1: 初始化客户端
|
||||
@@ -69,8 +108,6 @@ async def run_pilot_extraction(
|
||||
llm_client = client_factory.get_llm_client(str(memory_config.llm_model_id))
|
||||
embedder_client = client_factory.get_embedder_client(str(memory_config.embedding_model_id))
|
||||
|
||||
neo4j_connector = Neo4jConnector()
|
||||
|
||||
log_time("Client Initialization", time.time() - step_start, log_file)
|
||||
|
||||
# 步骤 2: 解析对话文本
|
||||
@@ -242,15 +279,17 @@ async def run_pilot_extraction(
|
||||
|
||||
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
|
||||
|
||||
# 步骤 3: 初始化流水线编排器
|
||||
logger.info("Initializing extraction orchestrator...")
|
||||
step_start = time.time()
|
||||
|
||||
config = get_pipeline_config(memory_config)
|
||||
# 步骤 3: 初始化并选择试运行流水线(环境变量可切换)
|
||||
use_refactored = bool(settings.PILOT_RUN_USE_REFACTORED_PIPELINE)
|
||||
logger.info(
|
||||
f"Pipeline config loaded: enable_llm_dedup_blockwise={config.deduplication.enable_llm_dedup_blockwise}, "
|
||||
f"enable_llm_disambiguation={config.deduplication.enable_llm_disambiguation}"
|
||||
"Selecting pilot pipeline by env: PILOT_RUN_USE_REFACTORED_PIPELINE=%s",
|
||||
use_refactored,
|
||||
)
|
||||
logger.info(
|
||||
"Initializing %s pilot pipeline...",
|
||||
"refactored" if use_refactored else "legacy",
|
||||
)
|
||||
step_start = time.time()
|
||||
|
||||
# 加载本体类型(如果配置了 scene_id),支持通用类型回退
|
||||
ontology_types = None
|
||||
@@ -266,100 +305,105 @@ async def run_pilot_extraction(
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load ontology types: {e}", exc_info=True)
|
||||
|
||||
orchestrator = ExtractionOrchestrator(
|
||||
llm_client=llm_client,
|
||||
embedder_client=embedder_client,
|
||||
connector=neo4j_connector,
|
||||
config=config,
|
||||
progress_callback=progress_callback,
|
||||
embedding_id=str(memory_config.embedding_model_id),
|
||||
language=language,
|
||||
ontology_types=ontology_types,
|
||||
)
|
||||
if use_refactored:
|
||||
pilot_pipeline = PilotWritePipeline(
|
||||
llm_client=llm_client,
|
||||
embedder_client=embedder_client,
|
||||
pipeline_config=get_pipeline_config(memory_config),
|
||||
progress_callback=progress_callback,
|
||||
embedding_id=str(memory_config.embedding_model_id),
|
||||
language=language,
|
||||
ontology_types=ontology_types,
|
||||
)
|
||||
log_time("Pilot Pipeline Initialization", time.time() - step_start, log_file)
|
||||
|
||||
log_time("Orchestrator Initialization", time.time() - step_start, log_file)
|
||||
# 步骤 4a: 执行重构后试运行短链路
|
||||
# statement -> triplet -> graph_build -> 第一层去重消歧(结束)
|
||||
logger.info("Running refactored pilot extraction short pipeline...")
|
||||
step_start = time.time()
|
||||
|
||||
# 步骤 4: 执行知识提取流水线
|
||||
logger.info("Running extraction pipeline...")
|
||||
step_start = time.time()
|
||||
if progress_callback:
|
||||
await progress_callback("knowledge_extraction", "正在知识抽取...")
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("knowledge_extraction", "正在知识抽取...")
|
||||
pilot_result = await pilot_pipeline.run(chunked_dialogs)
|
||||
dialog_data_list = pilot_result.dialog_data_list
|
||||
graph = pilot_result.graph
|
||||
chunk_nodes = graph.chunk_nodes
|
||||
export_entity_nodes = graph.entity_nodes
|
||||
export_stmt_entity_edges = graph.stmt_entity_edges
|
||||
export_entity_edges = graph.entity_entity_edges
|
||||
else:
|
||||
# 步骤 4b: 执行旧试运行流水线
|
||||
logger.info("Running legacy pilot extraction pipeline...")
|
||||
step_start = time.time()
|
||||
|
||||
extraction_result = await orchestrator.run(
|
||||
dialog_data_list=chunked_dialogs,
|
||||
is_pilot_run=True,
|
||||
)
|
||||
if progress_callback:
|
||||
await progress_callback("knowledge_extraction", "正在知识抽取...")
|
||||
|
||||
# 解包 extraction_result tuple (与 main.py 保持一致)
|
||||
(
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
entity_nodes,
|
||||
_,
|
||||
statement_chunk_edges,
|
||||
statement_entity_edges,
|
||||
entity_edges,
|
||||
_,
|
||||
_
|
||||
) = extraction_result
|
||||
neo4j_connector = Neo4jConnector()
|
||||
try:
|
||||
legacy_orchestrator = ExtractionOrchestrator(
|
||||
llm_client=llm_client,
|
||||
embedder_client=embedder_client,
|
||||
connector=neo4j_connector,
|
||||
config=get_pipeline_config(memory_config),
|
||||
progress_callback=progress_callback,
|
||||
embedding_id=str(memory_config.embedding_model_id),
|
||||
language=language,
|
||||
ontology_types=ontology_types,
|
||||
)
|
||||
extraction_result = await legacy_orchestrator.run(
|
||||
dialog_data_list=chunked_dialogs,
|
||||
is_pilot_run=True,
|
||||
)
|
||||
(
|
||||
_dialogue_nodes,
|
||||
chunk_nodes,
|
||||
_statement_nodes,
|
||||
entity_nodes,
|
||||
_perceptual_nodes,
|
||||
_statement_chunk_edges,
|
||||
statement_entity_edges,
|
||||
entity_edges,
|
||||
_perceptual_edges,
|
||||
_last_created_at,
|
||||
) = extraction_result
|
||||
dialog_data_list = chunked_dialogs
|
||||
export_entity_nodes = entity_nodes
|
||||
export_stmt_entity_edges = statement_entity_edges
|
||||
export_entity_edges = entity_edges
|
||||
finally:
|
||||
try:
|
||||
await neo4j_connector.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
log_time("Extraction Pipeline", time.time() - step_start, log_file)
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("generating_results", "正在生成结果...")
|
||||
|
||||
# 步骤 5: 生成记忆摘要(与 main.py 保持一致)
|
||||
try:
|
||||
logger.info("Generating memory summaries...")
|
||||
step_start = time.time()
|
||||
# 步骤 5: 输出试运行结果文件(保持 /pilot_run 返回契约)
|
||||
settings.ensure_memory_output_dir()
|
||||
export_test_input_doc(
|
||||
entity_nodes=export_entity_nodes,
|
||||
statement_entity_edges=export_stmt_entity_edges,
|
||||
entity_entity_edges=export_entity_edges,
|
||||
)
|
||||
_save_triplets_from_dialogs(
|
||||
dialog_data_list=dialog_data_list,
|
||||
output_path=settings.get_memory_output_path("extracted_triplets.txt"),
|
||||
)
|
||||
_write_extracted_result_summary(
|
||||
chunk_nodes=chunk_nodes,
|
||||
pipeline_output_dir=settings.get_memory_output_path(),
|
||||
)
|
||||
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
|
||||
memory_summary_generation,
|
||||
)
|
||||
|
||||
summaries = await memory_summary_generation(
|
||||
chunked_dialogs,
|
||||
llm_client=llm_client,
|
||||
embedder_client=embedder_client,
|
||||
language=language,
|
||||
)
|
||||
|
||||
log_time("Memory Summary Generation", time.time() - step_start, log_file)
|
||||
except Exception as e:
|
||||
logger.error(f"Memory summary step failed: {e}", exc_info=True)
|
||||
|
||||
logger.info("Pilot run completed: Skipping Neo4j save")
|
||||
|
||||
# 将提取统计写入 Redis,按 workspace_id 存储
|
||||
try:
|
||||
from app.cache.memory.activity_stats_cache import ActivityStatsCache
|
||||
|
||||
stats_to_cache = {
|
||||
"chunk_count": len(chunk_nodes) if chunk_nodes else 0,
|
||||
"statements_count": len(statement_nodes) if statement_nodes else 0,
|
||||
"triplet_entities_count": len(entity_nodes) if entity_nodes else 0,
|
||||
"triplet_relations_count": len(entity_edges) if entity_edges else 0,
|
||||
"temporal_count": 0, # temporal 数据在日志中,此处暂置0
|
||||
}
|
||||
await ActivityStatsCache.set_activity_stats(
|
||||
workspace_id=str(memory_config.workspace_id),
|
||||
stats=stats_to_cache,
|
||||
)
|
||||
logger.info(f"[PILOT_RUN] 活动统计已写入 Redis: workspace_id={memory_config.workspace_id}")
|
||||
except Exception as cache_err:
|
||||
logger.warning(f"[PILOT_RUN] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
|
||||
logger.info("Pilot run completed: stop after layer-1 dedup (no layer-2 / no Neo4j write)")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Pilot run failed: {e}", exc_info=True)
|
||||
raise
|
||||
finally:
|
||||
if neo4j_connector:
|
||||
try:
|
||||
await neo4j_connector.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
total_time = time.time() - pipeline_start
|
||||
log_time("TOTAL PILOT RUN TIME", total_time, log_file)
|
||||
|
||||
Reference in New Issue
Block a user