- Delete ExtractionOrchestrator (~2500 lines) and the write_tools legacy path; MemoryService/WritePipeline is now the sole write path
- Remove the NEW_PIPELINE_ENABLED feature flag from memory_agent_service
- Simplify pilot_run_service to always use PilotWritePipeline
- Add a dialog_at field to the statement and triplet extraction prompts as the primary reference time for resolving relative temporal expressions
- Rewrite relative time phrases (e.g. 昨天 "yesterday", 下周 "next week") into concrete dates directly in statement_text when they can be stably resolved from dialog_at (illustrated in the sketch below)
- Rename extracat_Pruning.jinja2 to extracat_pruning.jinja2; expand the few-shot examples and update the memory type enum (drop NULL, add agreement/repetition/other)
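
To illustrate the intended dialog_at behavior: the actual rewriting happens inside the statement/triplet extraction prompts (the LLM resolves the phrases), not in Python code; resolve_relative_dates below is a hypothetical helper, shown only to make the anchoring rule concrete.

from datetime import date, timedelta

def resolve_relative_dates(statement_text: str, dialog_at: date) -> str:
    """Rewrite a few stably resolvable relative phrases into concrete dates
    anchored on dialog_at; ambiguous spans such as 下周 (next week) denote a
    range rather than a single day and are left untouched here."""
    replacements = {
        "昨天": (dialog_at - timedelta(days=1)).isoformat(),  # yesterday
        "今天": dialog_at.isoformat(),                        # today
        "明天": (dialog_at + timedelta(days=1)).isoformat(),  # tomorrow
    }
    for phrase, concrete in replacements.items():
        statement_text = statement_text.replace(phrase, concrete)
    return statement_text

# With dialog_at = 2024-05-12, "用户昨天去了医院" becomes "用户2024-05-11去了医院".
print(resolve_relative_dates("用户昨天去了医院", date(2024, 5, 12)))
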
"""
|
||
Pilot Run Service - 试运行服务
|
||
|
||
用于执行记忆系统的试运行流程,不保存到 Neo4j。
|
||
|
||
职责边界:
|
||
- 文本解析、语义剪枝、语义分块(预处理)
|
||
- 调用 PilotWritePipeline 执行萃取链路
|
||
- 输出结果文件
|
||
"""

import os
import re
import time
from datetime import datetime
from typing import Awaitable, Callable, Optional

from sqlalchemy.orm import Session

from app.core.config import settings
from app.core.logging_config import get_memory_logger, log_time
from app.core.memory.models.message_models import (
    ConversationContext,
    ConversationMessage,
    DialogData,
)
from app.core.memory.storage_services.extraction_engine.pipeline_help import (
    _write_extracted_result_summary,
    export_test_input_doc,
)
from app.schemas.memory_config_schema import MemoryConfig

logger = get_memory_logger(__name__)


def _save_triplets_from_dialogs(dialog_data_list: list[DialogData], output_path: str) -> None:
    """Write triplet/entity text report compatible with pipeline_help parsers."""
    all_triplets = []
    all_entities = []

    for dialog in dialog_data_list:
        for chunk in getattr(dialog, "chunks", []) or []:
            for statement in getattr(chunk, "statements", []) or []:
                triplet_info = getattr(statement, "triplet_extraction_info", None)
                if not triplet_info:
                    continue
                all_triplets.extend(getattr(triplet_info, "triplets", []) or [])
                all_entities.extend(getattr(triplet_info, "entities", []) or [])

    with open(output_path, "w", encoding="utf-8") as f:
        f.write(f"=== EXTRACTED TRIPLETS ({len(all_triplets)} total) ===\n\n")
        for i, triplet in enumerate(all_triplets, 1):
            f.write(f"Triplet {i}:\n")
            f.write(f" Subject: {triplet.subject_name} (ID: {triplet.subject_id})\n")
            f.write(f" Predicate: {triplet.predicate}\n")
            f.write(f" Object: {triplet.object_name} (ID: {triplet.object_id})\n")
            value = getattr(triplet, "value", None)
            if value:
                f.write(f" Value: {value}\n")
            f.write("\n")

        f.write(f"\n=== EXTRACTED ENTITIES ({len(all_entities)} total) ===\n\n")
        for i, entity in enumerate(all_entities, 1):
            f.write(f"Entity {i}:\n")
            f.write(f" ID: {entity.entity_idx}\n")
            f.write(f" Name: {entity.name}\n")
            f.write(f" Type: {entity.type}\n")
            f.write(f" Description: {entity.description}\n")
            f.write("\n")


async def run_pilot_extraction(
    memory_config: MemoryConfig,
    dialogue_text: str,
    db: Session,
    progress_callback: Optional[Callable[[str, str, Optional[dict]], Awaitable[None]]] = None,
    language: str = "zh",
) -> None:
    """Run the knowledge extraction pipeline in pilot (dry-run) mode.

    Responsibilities:
    1. Text parsing → semantic pruning → semantic chunking (preprocessing; requires llm_client)
    2. Invoke PilotWritePipeline to run the extraction chain (the pipeline manages its own clients)
    3. Write the extraction results to output files

    Args:
        memory_config: Memory configuration object loaded from the database
        dialogue_text: Input dialogue text
        db: Database session (used to initialize the LLM client needed for preprocessing)
        progress_callback: Optional progress callback (stage, message, data)
        language: Language type ("zh" | "en")
    """
    log_file = "logs/time.log"
    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(f"\n=== Pilot Run Started: {timestamp} ===\n")

    pipeline_start = time.time()

    try:
        # ── Step 1: Initialize the LLM client needed for preprocessing ─────────
        # Used only for semantic pruning and chunking; PilotWritePipeline
        # initializes its own extraction clients internally.
        step_start = time.time()
        from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
        factory = MemoryClientFactory(db)
        llm_client = factory.get_llm_client(str(memory_config.llm_model_id))
        log_time("Client Initialization", time.time() - step_start, log_file)

        # ── Step 2: Text parsing ────────────────────────────────────────────────
        step_start = time.time()
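        # Split the raw dialogue into speaker turns: each turn starts with "用户"
        # or "AI" followed by a full-width or ASCII colon, and its body runs until
        # the next speaker marker; role and message text are captured as a pair.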
pattern = r"(用户|AI)[::]\s*([^\n]+(?:\n(?!(?:用户|AI)[::])[^\n]*)*?)"
|
||
matches = re.findall(pattern, dialogue_text, re.MULTILINE | re.DOTALL)
|
||
messages = [
|
||
ConversationMessage(role=r, msg=c.strip())
|
||
for r, c in matches
|
||
if c.strip()
|
||
]
|
||
if not messages:
|
||
messages = [ConversationMessage(role="用户", msg=dialogue_text.strip())]
|
||
|
||
dialog = DialogData(
|
||
context=ConversationContext(msgs=messages),
|
||
ref_id="pilot_dialog_1",
|
||
end_user_id=str(memory_config.workspace_id),
|
||
user_id=str(memory_config.tenant_id),
|
||
apply_id=str(memory_config.config_id),
|
||
metadata={"source": "pilot_run", "input_type": "frontend_text"},
|
||
)
|
||
|
||
if progress_callback:
|
||
await progress_callback("text_preprocessing", "开始预处理文本(语义剪枝 + 语义分块)...")
|
||
|
||
# ── 步骤 2.1: 语义剪枝 ─────────────────────────────────────────────
|
||
pruned_dialogs = [dialog]
|
||
pruning_stats: dict = {"enabled": False}
|
||
|
||
if memory_config.pruning_enabled:
|
||
try:
|
||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import (
|
||
SemanticPruner,
|
||
)
|
||
from app.core.memory.models.config_models import PruningConfig
|
||
|
||
config = PruningConfig(
|
||
pruning_switch=memory_config.pruning_enabled,
|
||
pruning_scene=memory_config.pruning_scene,
|
||
pruning_threshold=memory_config.pruning_threshold,
|
||
scene_id=str(memory_config.scene_id) if memory_config.scene_id else None,
|
||
ontology_class_infos=memory_config.ontology_class_infos,
|
||
)
|
||
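                # Snapshot the pre-pruning messages so the pruned result can be
                # diffed against them to report which turns were removed.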
                original_msgs = [{"role": m.role, "content": m.msg} for m in dialog.context.msgs]
                pruned_dialogs = await SemanticPruner(config=config, llm_client=llm_client).prune_dataset([dialog])

                if pruned_dialogs and pruned_dialogs[0].context:
                    remaining = [{"role": m.role, "content": m.msg} for m in pruned_dialogs[0].context.msgs]
                    # Identify which messages were deleted (sequential matching)
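                    # Greedy in-order match: an original message counts as kept when
                    # it equals the next unmatched surviving message; indices that
                    # are never matched are reported as deleted.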
                    kept_indices: list[int] = []
                    ri = 0
                    for oi, om in enumerate(original_msgs):
                        if ri < len(remaining) and om == remaining[ri]:
                            kept_indices.append(oi)
                            ri += 1
                    deleted_messages = [
                        {"index": i, "role": m["role"], "content": m["content"]}
                        for i, m in enumerate(original_msgs)
                        if i not in kept_indices
                    ]
                    pruning_stats = {
                        "enabled": True,
                        "scene": config.pruning_scene,
                        "threshold": config.pruning_threshold,
                        "deleted_count": len(deleted_messages),
                    }
                    logger.info(
                        f"[PILOT_RUN] 语义剪枝完成: {len(original_msgs)} → {len(remaining)} 条"
                        f"(删除 {len(deleted_messages)} 条)"
                    )
                    if progress_callback:
                        await progress_callback(
                            "text_preprocessing_result", "语义剪枝完成",
                            {"type": "pruning", "deleted_messages": deleted_messages},
                        )
                else:
                    logger.warning("[PILOT_RUN] 剪枝后对话为空,使用原始对话")
                    pruned_dialogs = [dialog]

            except Exception as e:
                logger.error(f"[PILOT_RUN] 语义剪枝失败,使用原始对话: {e}", exc_info=True)
                pruned_dialogs = [dialog]
                if progress_callback:
                    await progress_callback(
                        "text_preprocessing_result", "语义剪枝失败",
                        {"type": "pruning", "error": str(e), "fallback": "使用原始对话"},
                    )

        # ── Step 2.2: Semantic chunking ─────────────────────────────────────────
        from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import (
            DialogueChunker,
        )
        chunked_dialogs = []
        for dlg in pruned_dialogs:
            dlg.chunks = await DialogueChunker(memory_config.chunker_strategy, llm_client=llm_client).process_dialogue(dlg)
            chunked_dialogs.append(dlg)

        if progress_callback:
            for dlg in chunked_dialogs:
                for i, chunk in enumerate(dlg.chunks or []):
                    await progress_callback(
                        "text_preprocessing_result", f"分块 {i + 1} 处理完成",
                        {
                            "type": "chunking",
                            "chunk_index": i + 1,
                            "content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
                            "full_length": len(chunk.content),
                            "dialog_id": dlg.id,
                            "chunker_strategy": memory_config.chunker_strategy,
                        },
                    )
            await progress_callback(
                "text_preprocessing_complete", "预处理文本完成(剪枝 + 分块)",
                {
                    "total_chunks": sum(len(dlg.chunks or []) for dlg in chunked_dialogs),
                    "total_dialogs": len(chunked_dialogs),
                    "chunker_strategy": memory_config.chunker_strategy,
                    "pruning": pruning_stats,
                },
            )

        log_time("Data Loading & Chunking", time.time() - step_start, log_file)

        # ── Step 3: Extraction (PilotWritePipeline manages its own clients and ontology loading) ──
        step_start = time.time()
        logger.info("Running pilot extraction pipeline...")

        if progress_callback:
            await progress_callback("knowledge_extraction", "正在知识抽取...")

        from app.core.memory.pipelines.pilot_write_pipeline import PilotWritePipeline
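
        # The pilot result carries the in-memory extraction graph and the
        # per-dialog extraction data; both are only written to report files in
        # step 4 and are never persisted to Neo4j.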
        pilot_result = await PilotWritePipeline(
            memory_config=memory_config,
            end_user_id=str(memory_config.workspace_id),
            language=language,
            progress_callback=progress_callback,
        ).run(chunked_dialogs)

        log_time("Extraction Pipeline", time.time() - step_start, log_file)

        # ── Step 4: Write result files ──────────────────────────────────────────
        if progress_callback:
            await progress_callback("generating_results", "正在生成结果...")

        graph = pilot_result.graph
        settings.ensure_memory_output_dir()
        export_test_input_doc(
            entity_nodes=graph.entity_nodes,
            statement_entity_edges=graph.stmt_entity_edges,
            entity_entity_edges=graph.entity_entity_edges,
        )
        _save_triplets_from_dialogs(
            dialog_data_list=pilot_result.dialog_data_list,
            output_path=settings.get_memory_output_path("extracted_triplets.txt"),
        )
        _write_extracted_result_summary(
            chunk_nodes=graph.chunk_nodes,
            pipeline_output_dir=settings.get_memory_output_path(),
        )
        logger.info("Pilot run completed: stop after layer-1 dedup (no Neo4j write)")

    except Exception as e:
        logger.error(f"Pilot run failed: {e}", exc_info=True)
        raise

    total_time = time.time() - pipeline_start
    log_time("TOTAL PILOT RUN TIME", total_time, log_file)
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(f"=== Pilot Run Completed: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ===\n\n")
    logger.info(f"Pilot run complete. Total time: {total_time:.2f}s")
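
For reference, a minimal caller sketch of run_pilot_extraction. It is not part of the module above: the import path, SessionLocal, and load_memory_config are placeholders for whatever the service layer actually provides, and only the call signature follows the function definition.

import asyncio
from typing import Optional

from app.services.pilot_run_service import run_pilot_extraction  # import path assumed

async def _demo_pilot_run() -> None:
    db = SessionLocal()                                   # placeholder session factory
    memory_config = load_memory_config(db, config_id="…")  # placeholder config loader

    async def on_progress(stage: str, message: str, data: Optional[dict] = None) -> None:
        # Progress events arrive as (stage, message, optional payload).
        print(stage, message)

    await run_pilot_extraction(
        memory_config=memory_config,
        dialogue_text="用户:昨天我去了医院。\nAI:好的,已经记录。",
        db=db,
        progress_callback=on_progress,
        language="zh",
    )

asyncio.run(_demo_pilot_run())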