""" Pilot Run Service - 试运行服务 用于执行记忆系统的试运行流程,不保存到 Neo4j。 """ import os import re import time from datetime import datetime from typing import Awaitable, Callable, Optional from app.core.logging_config import get_memory_logger, log_time from app.core.memory.models.message_models import ( ConversationContext, ConversationMessage, DialogData, ) from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ( ExtractionOrchestrator, get_chunked_dialogs_from_preprocessed, ) from app.core.memory.utils.config.config_utils import ( get_pipeline_config, ) from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_config_schema import MemoryConfig from sqlalchemy.orm import Session logger = get_memory_logger(__name__) async def run_pilot_extraction( memory_config: MemoryConfig, dialogue_text: str, db: Session, progress_callback: Optional[Callable[[str, str, Optional[dict]], Awaitable[None]]] = None, language: str = "zh", ) -> None: """ 执行试运行模式的知识提取流水线。 Args: memory_config: 从数据库加载的内存配置对象 dialogue_text: 输入的对话文本 db: 数据库会话 progress_callback: 可选的进度回调函数 - 参数1 (stage): 当前处理阶段标识符 - 参数2 (message): 人类可读的进度消息 - 参数3 (data): 可选的附加数据字典 language: 语言类型 ("zh" 中文, "en" 英文),默认中文 """ log_file = "logs/time.log" os.makedirs(os.path.dirname(log_file), exist_ok=True) timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") with open(log_file, "a", encoding="utf-8") as f: f.write(f"\n=== Pilot Run Started: {timestamp} ===\n") pipeline_start = time.time() neo4j_connector = None try: # 步骤 1: 初始化客户端 logger.info("Initializing clients...") step_start = time.time() client_factory = MemoryClientFactory(db) llm_client = client_factory.get_llm_client(str(memory_config.llm_model_id)) embedder_client = client_factory.get_embedder_client(str(memory_config.embedding_model_id)) neo4j_connector = Neo4jConnector() log_time("Client Initialization", time.time() - step_start, log_file) # 步骤 2: 解析对话文本 logger.info("Parsing dialogue text...") step_start = time.time() # 解析对话文本,支持 "用户:" 和 "AI:" 格式 pattern = r"(用户|AI)[::]\s*([^\n]+(?:\n(?!(?:用户|AI)[::])[^\n]*)*?)" matches = re.findall(pattern, dialogue_text, re.MULTILINE | re.DOTALL) messages = [ ConversationMessage(role=r, msg=c.strip()) for r, c in matches if c.strip() ] # 如果没有匹配到格式化的对话,将整个文本作为用户消息 if not messages: messages = [ConversationMessage(role="用户", msg=dialogue_text.strip())] context = ConversationContext(msgs=messages) dialog = DialogData( context=context, ref_id="pilot_dialog_1", end_user_id=str(memory_config.workspace_id), user_id=str(memory_config.tenant_id), apply_id=str(memory_config.config_id), metadata={"source": "pilot_run", "input_type": "frontend_text"}, ) if progress_callback: await progress_callback("text_preprocessing", "开始预处理文本(语义剪枝 + 语义分块)...") # ========== 步骤 2.1: 语义剪枝 ========== pruned_dialogs = [dialog] deleted_messages = [] # 记录被删除的消息 pruning_stats = None # 保存剪枝统计信息,用于最终汇总 if memory_config.pruning_enabled: try: from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import ( SemanticPruner, ) from app.core.memory.models.config_models import PruningConfig # 构建剪枝配置 pruning_config_dict = { "pruning_switch": memory_config.pruning_enabled, "pruning_scene": memory_config.pruning_scene, "pruning_threshold": memory_config.pruning_threshold, "scene_id": str(memory_config.scene_id) if memory_config.scene_id else None, "ontology_class_infos": memory_config.ontology_class_infos, } config = PruningConfig(**pruning_config_dict) logger.info(f"[PILOT_RUN] 开始语义剪枝: scene={config.pruning_scene}, threshold={config.pruning_threshold}") # 记录剪枝前的消息(用于对比) original_messages = [{"role": msg.role, "content": msg.msg} for msg in dialog.context.msgs] original_msg_count = len(original_messages) # 执行剪枝 pruner = SemanticPruner(config=config, llm_client=llm_client) pruned_dialogs = await pruner.prune_dataset([dialog]) # 计算剪枝结果并找出被删除的消息 if pruned_dialogs and pruned_dialogs[0].context: remaining_messages = [{"role": msg.role, "content": msg.msg} for msg in pruned_dialogs[0].context.msgs] remaining_msg_count = len(remaining_messages) deleted_msg_count = original_msg_count - remaining_msg_count # 找出被删除的消息(基于索引精确匹配) # 为剩余消息创建带索引的列表,用于精确追踪 remaining_with_index = [] remaining_idx = 0 for orig_idx, orig_msg in enumerate(original_messages): if remaining_idx < len(remaining_messages) and \ orig_msg["role"] == remaining_messages[remaining_idx]["role"] and \ orig_msg["content"] == remaining_messages[remaining_idx]["content"]: remaining_with_index.append(orig_idx) remaining_idx += 1 # 找出未在保留列表中的消息索引 deleted_messages = [ {"index": idx, "role": msg["role"], "content": msg["content"]} for idx, msg in enumerate(original_messages) if idx not in remaining_with_index ] # 保存剪枝统计信息(用于最终汇总,只保留deleted_count) pruning_stats = { "enabled": True, "scene": config.pruning_scene, "threshold": config.pruning_threshold, "deleted_count": deleted_msg_count, } # 输出剪枝结果(显示删除的消息详情) pruning_result = { "type": "pruning", "deleted_messages": deleted_messages, } logger.info( f"[PILOT_RUN] 语义剪枝完成: 原始{original_msg_count}条 -> " f"保留{remaining_msg_count}条 (删除{deleted_msg_count}条)" ) if progress_callback: await progress_callback("text_preprocessing_result", "语义剪枝完成", pruning_result) else: logger.warning("[PILOT_RUN] 剪枝后对话为空,使用原始对话") pruned_dialogs = [dialog] except Exception as e: logger.error(f"[PILOT_RUN] 语义剪枝失败,使用原始对话: {e}", exc_info=True) pruned_dialogs = [dialog] if progress_callback: error_result = { "type": "pruning", "error": str(e), "fallback": "使用原始对话" } await progress_callback("text_preprocessing_result", "语义剪枝失败", error_result) else: logger.info("[PILOT_RUN] 语义剪枝已关闭,跳过") pruning_stats = { "enabled": False, } # ========== 步骤 2.2: 语义分块 ========== chunked_dialogs = await get_chunked_dialogs_from_preprocessed( data=pruned_dialogs, chunker_strategy=memory_config.chunker_strategy, llm_client=llm_client, ) remaining_msg_count = len(pruned_dialogs[0].context.msgs) if pruned_dialogs and pruned_dialogs[0].context else 0 logger.info(f"Processed dialogue text: {remaining_msg_count} messages after pruning") # 进度回调:输出每个分块的结果 if progress_callback: for dlg in chunked_dialogs: if hasattr(dlg, 'chunks') and dlg.chunks: for i, chunk in enumerate(dlg.chunks): chunk_result = { "type": "chunking", "chunk_index": i + 1, "content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content, "full_length": len(chunk.content), "dialog_id": dlg.id, "chunker_strategy": memory_config.chunker_strategy, } await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result) # 构建预处理完成总结(包含剪枝统计) preprocessing_summary = { "total_chunks": sum(len(dlg.chunks) for dlg in chunked_dialogs if hasattr(dlg, 'chunks') and dlg.chunks), "total_dialogs": len(chunked_dialogs), "chunker_strategy": memory_config.chunker_strategy, } # 添加剪枝统计信息(始终包含 pruning 字段,确保前端不会因字段缺失报错) preprocessing_summary["pruning"] = pruning_stats if pruning_stats else { "enabled": memory_config.pruning_enabled, "deleted_count": 0, } await progress_callback("text_preprocessing_complete", "预处理文本完成(剪枝 + 分块)", preprocessing_summary) log_time("Data Loading & Chunking", time.time() - step_start, log_file) # 步骤 3: 初始化流水线编排器 logger.info("Initializing extraction orchestrator...") step_start = time.time() config = get_pipeline_config(memory_config) logger.info( f"Pipeline config loaded: enable_llm_dedup_blockwise={config.deduplication.enable_llm_dedup_blockwise}, " f"enable_llm_disambiguation={config.deduplication.enable_llm_disambiguation}" ) # 加载本体类型(如果配置了 scene_id),支持通用类型回退 ontology_types = None try: from app.core.memory.ontology_services.ontology_type_loader import load_ontology_types_with_fallback ontology_types = load_ontology_types_with_fallback( scene_id=memory_config.scene_id, workspace_id=memory_config.workspace_id, db=db, enable_general_fallback=True ) except Exception as e: logger.warning(f"Failed to load ontology types: {e}", exc_info=True) orchestrator = ExtractionOrchestrator( llm_client=llm_client, embedder_client=embedder_client, connector=neo4j_connector, config=config, progress_callback=progress_callback, embedding_id=str(memory_config.embedding_model_id), language=language, ontology_types=ontology_types, ) log_time("Orchestrator Initialization", time.time() - step_start, log_file) # 步骤 4: 执行知识提取流水线 logger.info("Running extraction pipeline...") step_start = time.time() if progress_callback: await progress_callback("knowledge_extraction", "正在知识抽取...") extraction_result = await orchestrator.run( dialog_data_list=chunked_dialogs, is_pilot_run=True, ) # 解包 extraction_result tuple (与 main.py 保持一致) ( dialogue_nodes, chunk_nodes, statement_nodes, entity_nodes, statement_chunk_edges, statement_entity_edges, entity_edges, ) = extraction_result log_time("Extraction Pipeline", time.time() - step_start, log_file) if progress_callback: await progress_callback("generating_results", "正在生成结果...") # 步骤 5: 生成记忆摘要(与 main.py 保持一致) try: logger.info("Generating memory summaries...") step_start = time.time() from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import ( memory_summary_generation, ) summaries = await memory_summary_generation( chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client, language=language, ) log_time("Memory Summary Generation", time.time() - step_start, log_file) except Exception as e: logger.error(f"Memory summary step failed: {e}", exc_info=True) logger.info("Pilot run completed: Skipping Neo4j save") # 将提取统计写入 Redis,按 workspace_id 存储 try: from app.cache.memory.activity_stats_cache import ActivityStatsCache stats_to_cache = { "chunk_count": len(chunk_nodes) if chunk_nodes else 0, "statements_count": len(statement_nodes) if statement_nodes else 0, "triplet_entities_count": len(entity_nodes) if entity_nodes else 0, "triplet_relations_count": len(entity_edges) if entity_edges else 0, "temporal_count": 0, # temporal 数据在日志中,此处暂置0 } await ActivityStatsCache.set_activity_stats( workspace_id=str(memory_config.workspace_id), stats=stats_to_cache, ) logger.info(f"[PILOT_RUN] 活动统计已写入 Redis: workspace_id={memory_config.workspace_id}") except Exception as cache_err: logger.warning(f"[PILOT_RUN] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True) except Exception as e: logger.error(f"Pilot run failed: {e}", exc_info=True) raise finally: if neo4j_connector: try: await neo4j_connector.close() except Exception: pass total_time = time.time() - pipeline_start log_time("TOTAL PILOT RUN TIME", total_time, log_file) timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") with open(log_file, "a", encoding="utf-8") as f: f.write(f"=== Pilot Run Completed: {timestamp} ===\n\n") logger.info(f"Pilot run complete. Total time: {total_time:.2f}s")