diff --git a/api/app/core/config.py b/api/app/core/config.py index 56a07f3f..615f5d98 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -272,6 +272,12 @@ class Settings: MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output") MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory") + # Pilot run pipeline switch: + # true -> use refactored PilotWritePipeline + # false -> use legacy ExtractionOrchestrator pipeline + PILOT_RUN_USE_REFACTORED_PIPELINE: bool = ( + os.getenv("PILOT_RUN_USE_REFACTORED_PIPELINE", "true").lower() == "true" + ) # Tool Management Configuration TOOL_CONFIG_DIR: str = os.getenv("TOOL_CONFIG_DIR", "app/core/tools") diff --git a/api/app/core/memory/agent/utils/get_dialogs.py b/api/app/core/memory/agent/utils/get_dialogs.py index 3cd1fa0a..1180f367 100644 --- a/api/app/core/memory/agent/utils/get_dialogs.py +++ b/api/app/core/memory/agent/utils/get_dialogs.py @@ -9,7 +9,8 @@ async def get_chunked_dialogs( end_user_id: str = "group_1", messages: list = None, ref_id: str = "", - config_id: str = None + config_id: str = None, + snapshot=None, ) -> List[DialogData]: """Generate chunks from structured messages using the specified chunker strategy. @@ -19,6 +20,7 @@ async def get_chunked_dialogs( messages: Structured message list [{"role": "user", "content": "..."}, ...] 
ref_id: Reference identifier config_id: Configuration ID for processing (used to load pruning config) + snapshot: Optional PipelineSnapshot instance for saving pruning output Returns: List of DialogData objects with generated chunks @@ -93,7 +95,7 @@ async def get_chunked_dialogs( llm_client = factory.get_llm_client_from_config(memory_config) # 执行剪枝 - 使用 prune_dataset 支持消息级剪枝 - pruner = SemanticPruner(config=pruning_config, llm_client=llm_client) + pruner = SemanticPruner(config=pruning_config, llm_client=llm_client, snapshot=snapshot) original_msg_count = len(dialog_data.context.msgs) # 使用 prune_dataset 而不是 prune_dialog diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index 473e9189..d4eff79e 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -184,7 +184,8 @@ async def write( "entities": [ { "entity_idx": e.entity_idx, "name": e.name, - "type": e.type, "description": e.description, + "type": e.type, "type_description": getattr(e, "type_description", ""), + "description": e.description, "is_explicit_memory": getattr(e, "is_explicit_memory", False), } for e in s.triplet_extraction_info.entities @@ -193,6 +194,7 @@ async def write( { "subject_name": t.subject_name, "subject_id": t.subject_id, "predicate": t.predicate, + "predicate_description": getattr(t, "predicate_description", ""), "object_name": t.object_name, "object_id": t.object_id, } for t in s.triplet_extraction_info.triplets @@ -206,13 +208,13 @@ async def write( "chunk_nodes_count": len(all_chunk_nodes), "statement_nodes_count": len(all_statement_nodes), "entity_nodes": [ - {"id": e.id, "name": e.name, "entity_type": e.entity_type, "description": e.description} + {"id": e.id, "name": e.name, "entity_type": e.entity_type, "type_description": e.type_description, "description": e.description} for e in all_entity_nodes ], "entity_entity_edges": [ { "source": e.source, "target": e.target, 
- "relation_type": e.relation_type, "statement": e.statement, + "relation_type": e.relation_type, "relation_type_description": e.relation_type_description, "statement": e.statement, } for e in all_entity_entity_edges ], diff --git a/api/app/core/memory/models/graph_models.py b/api/app/core/memory/models/graph_models.py index 6e34421c..2248ce05 100644 --- a/api/app/core/memory/models/graph_models.py +++ b/api/app/core/memory/models/graph_models.py @@ -162,6 +162,7 @@ class EntityEntityEdge(Edge): invalid_at: Optional end date of temporal validity """ relation_type: str = Field(..., description="Relation type as defined in ontology") + relation_type_description: str = Field(default="", description="Chinese definition of the relation type from ontology") relation_value: Optional[str] = Field(None, description="Value of the relation") statement: str = Field(..., description='The statement of the edge.') source_statement_id: str = Field(..., description="Statement where this relationship was extracted") @@ -413,6 +414,7 @@ class ExtractedEntityNode(Node): entity_idx: int = Field(..., description="Unique identifier for the entity") statement_id: str = Field(..., description="Statement this entity was extracted from") entity_type: str = Field(..., description="Type of the entity") + type_description: str = Field(default="", description="Chinese definition of the entity type from ontology") description: str = Field(..., description="Entity description") example: str = Field( default="", diff --git a/api/app/core/memory/models/message_models.py b/api/app/core/memory/models/message_models.py index 67d274c7..64d25601 100644 --- a/api/app/core/memory/models/message_models.py +++ b/api/app/core/memory/models/message_models.py @@ -96,6 +96,10 @@ class Statement(BaseModel): emotion_target: Optional[str] = Field(None, description="Emotion target: person or object name") # Reference resolution has_unsolved_reference: bool = Field(False, description="Whether the statement has 
unresolved references") + has_emotional_state: bool = Field( + False, + description="Whether the statement reflects user's emotional state", + ) class ConversationContext(BaseModel): diff --git a/api/app/core/memory/models/triplet_models.py b/api/app/core/memory/models/triplet_models.py index df7ee14b..fbedc978 100644 --- a/api/app/core/memory/models/triplet_models.py +++ b/api/app/core/memory/models/triplet_models.py @@ -37,6 +37,7 @@ class Entity(BaseModel): name: str = Field(..., description="Name of the entity") name_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the entity name") type: str = Field(..., description="Type/category of the entity") + type_description: str = Field(default="", description="Chinese definition of the entity type from ontology") description: str = Field(..., description="Description of the entity") example: str = Field( default="", @@ -79,6 +80,7 @@ class Triplet(BaseModel): subject_name: str = Field(..., description="Name of the subject entity") subject_id: int = Field(..., description="ID of the subject entity") predicate: str = Field(..., description="Relationship/predicate between subject and object") + predicate_description: str = Field(default="", description="Chinese definition of the predicate from ontology") object_name: str = Field(..., description="Name of the object entity") object_id: int = Field(..., description="ID of the object entity") value: Optional[str] = Field(None, description="Additional value or context") diff --git a/api/app/core/memory/pipelines/__init__.py b/api/app/core/memory/pipelines/__init__.py index 8da9b28d..6471d9c1 100644 --- a/api/app/core/memory/pipelines/__init__.py +++ b/api/app/core/memory/pipelines/__init__.py @@ -14,13 +14,31 @@ def __getattr__(name): WritePipeline, WriteResult, ) + _exports = { "WritePipeline": WritePipeline, "ExtractionResult": ExtractionResult, "WriteResult": WriteResult, } return _exports[name] + if name in ("PilotWritePipeline", 
"PilotWriteResult"): + from app.core.memory.pipelines.pilot_write_pipeline import ( + PilotWritePipeline, + PilotWriteResult, + ) + + _exports = { + "PilotWritePipeline": PilotWritePipeline, + "PilotWriteResult": PilotWriteResult, + } + return _exports[name] raise AttributeError(f"module {__name__!r} has no attribute {name!r}") -__all__ = ["WritePipeline", "ExtractionResult", "WriteResult"] +__all__ = [ + "WritePipeline", + "ExtractionResult", + "WriteResult", + "PilotWritePipeline", + "PilotWriteResult", +] diff --git a/api/app/core/memory/pipelines/pilot_write_pipeline.py b/api/app/core/memory/pipelines/pilot_write_pipeline.py new file mode 100644 index 00000000..0465b66e --- /dev/null +++ b/api/app/core/memory/pipelines/pilot_write_pipeline.py @@ -0,0 +1,108 @@ +"""PilotWritePipeline — 试运行专用萃取流水线。 + +职责边界: +- 只执行“萃取相关”链路:statement -> triplet -> graph_build -> 第一层去重消歧 +- 不负责 Neo4j 写入、聚类、摘要、缓存更新 +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Awaitable, Callable, Dict, List, Optional + +from app.core.memory.models.message_models import DialogData +from app.core.memory.models.variate_config import ExtractionPipelineConfig +from app.core.memory.storage_services.extraction_engine.dedup_step import ( + DedupResult, + run_dedup, +) +from app.core.memory.storage_services.extraction_engine.steps.extraction_pipeline_orchestrator import ( + NewExtractionOrchestrator, +) +from app.core.memory.storage_services.extraction_engine.steps.graph_build_step import ( + GraphBuildResult, + build_graph_nodes_and_edges, +) + + +@dataclass +class PilotWriteResult: + """试运行流水线输出。""" + + dialog_data_list: List[DialogData] + graph: GraphBuildResult + dedup: DedupResult + + @property + def stats(self) -> Dict[str, int]: + return { + "chunk_count": len(self.graph.chunk_nodes), + "statement_count": len(self.graph.statement_nodes), + "entity_count_before_dedup": len(self.graph.entity_nodes), + "entity_count_after_dedup": 
len(self.dedup.entity_nodes), + "relation_count_before_dedup": len(self.graph.entity_entity_edges), + "relation_count_after_dedup": len(self.dedup.entity_entity_edges), + } + + +class PilotWritePipeline: + """重构后试运行专用流水线。""" + + def __init__( + self, + llm_client: Any, + embedder_client: Any, + pipeline_config: ExtractionPipelineConfig, + embedding_id: Optional[str], + language: str = "zh", + ontology_types: Any = None, + progress_callback: Optional[ + Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]] + ] = None, + ) -> None: + self.llm_client = llm_client + self.embedder_client = embedder_client + self.pipeline_config = pipeline_config + self.embedding_id = embedding_id + self.language = language + self.ontology_types = ontology_types + self.progress_callback = progress_callback + + async def run(self, dialog_data_list: List[DialogData]) -> PilotWriteResult: + """执行试运行萃取链路。""" + orchestrator = NewExtractionOrchestrator( + llm_client=self.llm_client, + embedder_client=self.embedder_client, + config=self.pipeline_config, + embedding_id=self.embedding_id, + ontology_types=self.ontology_types, + language=self.language, + is_pilot_run=True, + progress_callback=self.progress_callback, + ) + extracted_dialogs = await orchestrator.run(dialog_data_list) + + graph = await build_graph_nodes_and_edges( + dialog_data_list=extracted_dialogs, + embedder_client=self.embedder_client, + progress_callback=self.progress_callback, + ) + + dedup = await run_dedup( + entity_nodes=graph.entity_nodes, + statement_entity_edges=graph.stmt_entity_edges, + entity_entity_edges=graph.entity_entity_edges, + dialog_data_list=extracted_dialogs, + pipeline_config=self.pipeline_config, + connector=None, # pilot: no layer-2 db dedup + llm_client=self.llm_client, + is_pilot_run=True, + progress_callback=self.progress_callback, + ) + + return PilotWriteResult( + dialog_data_list=extracted_dialogs, + graph=graph, + dedup=dedup, + ) + diff --git 
a/api/app/core/memory/pipelines/write_pipeline.py b/api/app/core/memory/pipelines/write_pipeline.py index 180a70cf..a68798db 100644 --- a/api/app/core/memory/pipelines/write_pipeline.py +++ b/api/app/core/memory/pipelines/write_pipeline.py @@ -180,7 +180,11 @@ class WritePipeline: self._init_clients() self._init_neo4j_connector() - # Step 1: 预处理 - 消息分块 + AI消息语义剪枝(暂无实现) + # 初始化 Snapshot(提前创建,供预处理阶段的剪枝使用) + from app.core.memory.utils.debug.pipeline_snapshot import PipelineSnapshot + self._snapshot = PipelineSnapshot("new") + + # Step 1: 预处理 - 消息分块 + AI消息语义剪枝 step_start = time.time() chunked_dialogs = await self._preprocess(messages, ref_id) chunks_count = sum(len(d.chunks) for d in chunked_dialogs) @@ -220,7 +224,7 @@ class WritePipeline: ) # Step 3.5: 异步情绪提取(fire-and-forget,需在 _store 之后确保 Statement 节点已存在) - self._extract_emotion(getattr(self, "_emotion_statements", [])) + await self._extract_emotion(getattr(self, "_emotion_statements", [])) # Step 4: 聚类 - 增量更新社区(异步,不阻塞) step_start = time.time() @@ -266,7 +270,7 @@ class WritePipeline: async def _preprocess(self, messages: List[dict], ref_id: str) -> List[DialogData]: """ - 预处理:消息校验 → AI消息语义剪枝(暂未实现) → 对话分块。 + 预处理:消息校验 → AI消息语义剪枝 → 对话分块。 委托给 get_chunked_dialogs(),保持现有预处理逻辑不变。 get_dialogs.py 内部已包含: @@ -276,12 +280,15 @@ class WritePipeline: """ from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs + snapshot = getattr(self, "_snapshot", None) + return await get_chunked_dialogs( chunker_strategy=self.memory_config.chunker_strategy, end_user_id=self.end_user_id, messages=messages, ref_id=ref_id, config_id=str(self.memory_config.config_id), + snapshot=snapshot, ) # ────────────────────────────────────────────── @@ -321,7 +328,9 @@ class WritePipeline: pipeline_config = get_pipeline_config(self.memory_config) ontology_types = self._load_ontology_types() - snapshot = PipelineSnapshot("new") + # 复用 run() 中已创建的 snapshot(剪枝阶段已使用同一实例) + snapshot = getattr(self, "_snapshot", None) or PipelineSnapshot("new") + 
self._snapshot = snapshot # ── 新编排器:LLM 萃取 + 数据赋值 ── new_orchestrator = NewExtractionOrchestrator( @@ -589,11 +598,15 @@ class WritePipeline: # fire-and-forget 提交 Celery 任务,不阻塞主流程 # ────────────────────────────────────────────── - def _extract_emotion(self, emotion_statements: list) -> None: + async def _extract_emotion(self, emotion_statements: list) -> None: """提交异步情绪提取 Celery 任务。 从编排器收集的 user statement 列表中提取情绪, 异步回写到 Neo4j Statement 节点。失败不影响主流程。 + + 在 PIPELINE_SNAPSHOT_ENABLED=true 时,会把当前运行的快照目录路径 + 通过 snapshot_dir 透传给 Celery 任务;worker 端在完成 LLM 抽取后, + 将结果落盘到 /4_emotion_outputs.json,避免主进程重复调用 LLM。 """ if not emotion_statements: return @@ -607,6 +620,14 @@ class WritePipeline: logger.warning("[Emotion] 无法提交情绪提取任务:llm_model_id 为空") return + # 快照目录:仅在 PIPELINE_SNAPSHOT_ENABLED=true 时非空,供 worker 端落盘 + snapshot = getattr(self, "_snapshot", None) + snapshot_dir = ( + snapshot.directory + if snapshot is not None and getattr(snapshot, "enabled", False) + else None + ) + try: from app.celery_app import celery_app @@ -616,12 +637,14 @@ class WritePipeline: "statements": emotion_statements, "llm_model_id": llm_model_id, "language": self.language, + "snapshot_dir": snapshot_dir, }, ) logger.info( f"[Emotion] 异步情绪提取任务已提交 - " f"task_id={result.id}, " f"statement_count={len(emotion_statements)}, " + f"snapshot_dir={snapshot_dir}, " f"source=async" ) except Exception as e: @@ -629,6 +652,7 @@ class WritePipeline: f"[Emotion] 提交情绪提取任务失败(不影响主流程): {e}", exc_info=True, ) + # ────────────────────────────────────────────── # Step 5: 摘要 # (+ entity_description)+ meta_data部分在此提取 diff --git a/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py index 5390197a..4933c286 100644 --- a/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py +++ 
b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py @@ -1,952 +1,450 @@ """ -语义剪枝器 - 在预处理与分块之间过滤与场景不相关内容 +Assistant 消息语义剪枝器 功能: -- 对话级一次性抽取判定相关性 -- 仅对"不相关对话"的消息按比例删除 -- 重要信息(时间、编号、金额、联系方式、地址等)优先保留 -- 改进版:增强重要性判断、智能填充消息识别、问答对保护、并发优化 +- 将对话拆分为 User-Assistant 消息对 +- 对每个消息对,调用 LLM 从 Assistant 消息中提取记忆摘要 +- 若 Assistant 消息无记忆价值(hint=NULL),则删除该 Assistant 消息 +- 若有记忆价值,用压缩后的 assistant_memory_hint 替换原始冗长回复 +- User 消息始终保留,不做任何修改 +- 支持并发 LLM 调用、LRU 缓存、重试与降级 """ import asyncio -import logging -import os import hashlib import json -import re +import logging from collections import OrderedDict -from datetime import datetime -from typing import List, Optional, Dict, Tuple, Set +from typing import List, Optional, Dict + from pydantic import BaseModel, Field -from app.core.memory.models.message_models import DialogData, ConversationMessage, ConversationContext +from app.core.memory.models.message_models import ( + DialogData, + ConversationMessage, + ConversationContext, +) from app.core.memory.models.config_models import PruningConfig -from app.core.memory.utils.prompt.prompt_utils import prompt_env, log_prompt_rendering, log_template_rendering -from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene_config import ( - SceneConfigRegistry, - ScenePatterns +from app.core.memory.utils.prompt.prompt_utils import ( + prompt_env, + log_prompt_rendering, + log_template_rendering, ) logger = logging.getLogger(__name__) def message_has_files(message: "ConversationMessage") -> bool: - """检查消息是否包含文件。 - - Args: - message: 待检查的消息对象 - - Returns: - bool: 如果消息包含文件则返回 True,否则返回 False - """ + """检查消息是否包含文件。""" return message.files and len(message.files) > 0 -class DialogExtractionResponse(BaseModel): - """对话级一次性抽取的结构化返回,用于加速剪枝。 +class AssistantPruningResponse(BaseModel): + """LLM 对单个 User-Assistant 消息对的剪枝结果。 - - is_related:对话与场景的相关性判定。 - - times / ids / amounts / contacts / addresses / keywords:重要信息片段,用来在不相关对话中保留关键消息。 - - 
preserve_keywords:情绪/兴趣/爱好/个人观点相关词,包含这些词的消息必须强制保留。 - - scene_unrelated_snippets:与当前场景无关且无语义关联的消息片段(原文截取), - 用于高阈值阶段精准删除跨场景内容。 + - assistant_memory_hint: 从 Assistant 消息中提取的极短辅助摘要,无价值时为 "NULL" + - assistant_memory_type: 摘要类型枚举,无价值时为 "NULL" """ - is_related: bool = Field(...) - times: List[str] = Field(default_factory=list) - ids: List[str] = Field(default_factory=list) - amounts: List[str] = Field(default_factory=list) - contacts: List[str] = Field(default_factory=list) - addresses: List[str] = Field(default_factory=list) - keywords: List[str] = Field(default_factory=list) - preserve_keywords: List[str] = Field(default_factory=list, description="情绪/兴趣/爱好/个人观点相关词,包含这些词的消息强制保留") - scene_unrelated_snippets: List[str] = Field(default_factory=list,description="与当前场景无关且无语义关联的消息原文片段,高阈值阶段用于精准删除跨场景内容") - -class MessageImportanceResponse(BaseModel): - """消息重要性批量判断的结构化返回(用于LLM语义判断)。 - - - importance_scores: 消息索引到重要性分数的映射 (0-10分) - - reasons: 可选的判断理由 - """ - importance_scores: Dict[int, int] = Field(default_factory=dict, description="消息索引到重要性分数(0-10)的映射") - reasons: Optional[Dict[int, str]] = Field(default_factory=dict, description="可选的判断理由") - - -class QAPair(BaseModel): - """问答对模型,用于识别和保护对话中的问答结构。""" - question_idx: int = Field(..., description="问题消息的索引") - answer_idx: int = Field(..., description="答案消息的索引") - confidence: float = Field(default=1.0, description="问答对的置信度(0-1)") + assistant_memory_hint: str = Field( + ..., description="从 Assistant 消息提取的记忆摘要,或 'NULL'" + ) + assistant_memory_type: str = Field( + ..., + description="comfort | suggestion | recommendation | warning | instruction | NULL", + ) class SemanticPruner: - """语义剪枝:在预处理与分块之间过滤与场景不相关内容。 + """Assistant 消息语义剪枝器。 - 采用对话级一次性抽取判定相关性;仅对"不相关对话"的消息按比例删除, - 重要信息(时间、编号、金额、联系方式、地址等)优先保留。 + 将对话拆分为 User-Assistant 消息对,通过 LLM 判断 Assistant 消息的记忆价值: + - 有价值:用压缩摘要替换原始 Assistant 消息 + - 无价值(NULL):删除该 Assistant 消息 + - User 消息始终保留 """ - def __init__(self, config: Optional[PruningConfig] = None, llm_client=None, language: str = 
"zh", max_concurrent: int = 5): - # 如果没有提供config,使用默认配置 + def __init__( + self, + config: Optional[PruningConfig] = None, + llm_client=None, + language: str = "zh", + max_concurrent: int = 5, + snapshot=None, + ): if config is None: - # 使用默认的剪枝配置 config = PruningConfig( - pruning_switch=False, # 默认关闭剪枝,保持向后兼容 + pruning_switch=False, pruning_scene="education", - pruning_threshold=0.5 + pruning_threshold=0.5, ) - + self.config = config self.llm_client = llm_client - self.language = language # 保存语言配置 - self.max_concurrent = max_concurrent # 新增:最大并发数 - - # 详细日志配置:限制逐条消息日志的数量 - self._detailed_prune_logging = True # 是否启用详细日志 - self._max_debug_msgs_per_dialog = 20 # 每个对话最多记录前N条消息的详细日志 - - # 加载统一填充词库 - self.scene_config: ScenePatterns = SceneConfigRegistry.get_config(self.config.pruning_scene) - - # 本体类型列表:直接使用 ontology_class_infos(name + description) - self._ontology_class_infos = getattr(self.config, "ontology_class_infos", None) or [] - # _ontology_classes 仅用于日志统计 - self._ontology_classes = [info.class_name for info in self._ontology_class_infos] - - self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene}") - if self._ontology_class_infos: - self._log(f"[剪枝-初始化] 注入本体类型({len(self._ontology_class_infos)}个): {self._ontology_classes}") - else: - self._log(f"[剪枝-初始化] 未找到本体类型,将使用通用提示词") - - # Load Jinja2 template + self.language = language + self.max_concurrent = max_concurrent + self._snapshot = snapshot # PipelineSnapshot 实例,用于输出剪枝快照 + + # 加载 Jinja2 模板 self.template = prompt_env.get_template("extracat_Pruning.jinja2") - - # 对话抽取缓存:使用 OrderedDict 实现 LRU 缓存 - self._dialog_extract_cache: OrderedDict[str, DialogExtractionResponse] = OrderedDict() - self._cache_max_size = 1000 # 缓存大小限制 - - # 运行日志:收集关键终端输出,便于写入 JSON + + # LRU 缓存:避免对相同消息对重复调用 LLM + self._cache: OrderedDict[str, AssistantPruningResponse] = OrderedDict() + self._cache_max_size = 1000 + + # Snapshot 数据收集:每个消息对的 input + gold + self._snapshot_records: List[Dict] = [] + + # 运行日志 self.run_logs: List[str] = [] - # 
_is_important_message 和 _importance_score 已移除: - # 重要性判断完全由 extracat_Pruning.jinja2 提示词 + LLM 的 preserve_tokens 机制承担。 - # LLM 根据注入的本体工程类型语义识别需要保护的内容,无需硬编码正则规则。 - - def _is_filler_message(self, message: ConversationMessage) -> bool: - """检测典型寒暄/口头禅/确认类短消息。 - - 判断顺序: - 1. 空消息 - 2. 场景特定填充词库精确匹配 - 3. 常见寒暄精确匹配 - 4. 组合寒暄模式(前缀 + 后缀组合,如"好的谢谢"、"同学你好"、"明白了") - 5. 纯表情/标点 - """ - t = message.msg.strip() - if not t: - return True - - # 检查是否在场景特定填充词库中(精确匹配) - if t in self.scene_config.filler_phrases: - return True - - # 常见寒暄和问候(精确匹配,避免误删) - common_greetings = { - "在吗", "在不在", "在呢", "在的", - "你好", "您好", "hello", "hi", - "拜拜", "再见", "拜", "88", "bye", - "好的", "好", "行", "可以", "嗯", "哦", "啊", - "是的", "对", "对的", "没错", "是啊", - "哈哈", "呵呵", "嘿嘿", "嗯嗯" - } - if t in common_greetings: - return True - - # 组合寒暄模式:短消息(≤15字)且完全由寒暄成分构成 - # 策略:将消息拆分后,每个片段都能在填充词库或常见寒暄中找到,则整体为填充 - if len(t) <= 15: - # 确认+称呼/感谢组合,如"好的谢谢"、"明白了"、"知道了谢谢" - _confirm_prefixes = {"好的", "好", "嗯", "嗯嗯", "哦", "明白", "明白了", "知道了", "了解", "收到", "没问题"} - _thanks_suffixes = {"谢谢", "谢谢你", "谢谢您", "多谢", "感谢", "谢了"} - _greeting_suffixes = {"你好", "您好", "老师好", "同学好", "大家好"} - _greeting_prefixes = {"同学", "老师", "您好", "你好"} - _close_patterns = { - "没有了", "没事了", "没问题了", "好了", "行了", "可以了", - "不用了", "不需要了", "就这样", "就这样吧", "那就这样", - } - _polite_responses = { - "不客气", "不用谢", "没关系", "没事", "应该的", "这是我应该做的", - } - - # 规则1:确认词 + 感谢词(如"好的谢谢"、"嗯谢谢") - for cp in _confirm_prefixes: - for ts in _thanks_suffixes: - if t == cp + ts or t == cp + "," + ts or t == cp + "," + ts: - return True - - # 规则2:称呼前缀 + 问候(如"同学你好"、"老师好") - for gp in _greeting_prefixes: - for gs in _greeting_suffixes: - if t == gp + gs or t.startswith(gp) and t.endswith("好"): - return True - - # 规则3:结束语 + 感谢(如"没有了,谢谢老师"、"没有了谢谢") - for cp in _close_patterns: - if t.startswith(cp): - remainder = t[len(cp):].lstrip(",,、 ") - if not remainder or any(remainder.startswith(ts) for ts in _thanks_suffixes): - return True - - # 规则4:礼貌回应(如"不客气,祝你考试顺利"——前缀是礼貌词,后半是祝福套话) - for pr in _polite_responses: 
- if t.startswith(pr): - remainder = t[len(pr):].lstrip(",,、 ") - # 后半是祝福/套话(不含实质信息) - if not remainder or re.match(r"^(祝|希望|期待|加油|顺利|好好|保重)", remainder): - return True - - # 规则5:纯确认词加"了"后缀(如"明白了"、"知道了"、"好了") - _confirm_base = {"明白", "知道", "了解", "收到", "好", "行", "可以", "没问题"} - for cb in _confirm_base: - if t == cb + "了" or t == cb + "了。" or t == cb + "了!": - return True - - # 检查是否为纯表情符号(方括号包裹) - if re.fullmatch(r"(\[[^\]]+\])+", t): - return True - - # 纯标点符号 - if re.fullmatch(r"[。!?,.!?…·\s]+", t): - return True - - return False - - async def _batch_evaluate_importance_with_llm( - self, - messages: List[ConversationMessage], - context: str = "" - ) -> Dict[int, int]: - """使用LLM批量评估消息的重要性(语义层面)。 - - Args: - messages: 消息列表 - context: 对话上下文(可选) - - Returns: - 消息索引到重要性分数(0-10)的映射 - """ - if not self.llm_client or not messages: - return {} - - # 构建批量评估的提示词 - msg_list = [] - for idx, msg in enumerate(messages): - msg_list.append(f"{idx}. {msg.msg}") - - msg_text = "\n".join(msg_list) - - prompt = f"""请评估以下消息的重要性,给每条消息打分(0-10分): -- 0-2分:无意义的寒暄、口头禅、纯表情 -- 3-5分:一般性对话,有一定信息量但不关键 -- 6-8分:包含重要信息(时间、地点、人物、事件等) -- 9-10分:关键决策、承诺、重要数据 - -对话上下文: -{context if context else "无"} - -待评估的消息: -{msg_text} - -请以JSON格式返回,格式为: -{{ - "importance_scores": {{ - "0": 分数, - "1": 分数, - ... 
- }} -}} -""" - - try: - messages_for_llm = [ - {"role": "system", "content": "你是一个专业的对话分析助手,擅长评估消息的重要性。"}, - {"role": "user", "content": prompt} - ] - - response = await self.llm_client.response_structured( - messages_for_llm, - MessageImportanceResponse - ) - - # 转换字符串键为整数键 - return {int(k): v for k, v in response.importance_scores.items()} - except Exception as e: - self._log(f"[剪枝-LLM] 批量重要性评估失败: {str(e)[:100]}") - return {} - - def _identify_qa_pairs(self, messages: List[ConversationMessage]) -> List[QAPair]: - """识别对话中的问答对,用于保护问答结构的完整性。 - - 改进版:使用场景特定的问句关键词,并排除寒暄类问句 - - Args: - messages: 消息列表 - - Returns: - 问答对列表 - """ - qa_pairs = [] - - # 寒暄类问句,不应该被保护(这些不是真正的问答) - greeting_questions = { - "在吗", "在不在", "你好吗", "怎么样", "好吗", - "有空吗", "忙吗", "睡了吗", "起床了吗" - } - - for i in range(len(messages) - 1): - current_msg = messages[i].msg.strip() - next_msg = messages[i + 1].msg.strip() - - # 排除寒暄类问句 - if current_msg in greeting_questions: - continue - - # 使用场景特定的问句关键词,但要求更严格 - is_question = False - - # 1. 以问号结尾 - if current_msg.endswith("?") or current_msg.endswith("?"): - is_question = True - # 2. 
包含实质性问句关键词(排除"吗"这种太宽泛的) - elif any(word in current_msg for word in ["什么", "为什么", "怎么", "如何", "哪里", "哪个", "谁", "多少", "几点", "何时"]): - is_question = True - - if is_question and next_msg: - # 检查下一条消息是否像答案(不是另一个问句,也不是寒暄) - is_answer = not (next_msg.endswith("?") or next_msg.endswith("?")) - - # 排除寒暄类回复 - greeting_answers = {"你好", "您好", "在呢", "在的", "嗯", "哦", "好的"} - if next_msg in greeting_answers: - is_answer = False - - if is_answer: - qa_pairs.append(QAPair( - question_idx=i, - answer_idx=i + 1, - confidence=0.8 # 基于规则的置信度 - )) - - return qa_pairs - - def _get_protected_indices( - self, - messages: List[ConversationMessage], - qa_pairs: List[QAPair], - window_size: int = 2 - ) -> Set[int]: - """获取需要保护的消息索引集合(问答对+上下文窗口)。 - - Args: - messages: 消息列表 - qa_pairs: 问答对列表 - window_size: 上下文窗口大小(前后各保留几条消息) - - Returns: - 需要保护的消息索引集合 - """ - protected = set() - - for qa_pair in qa_pairs: - # 保护问答对本身 - protected.add(qa_pair.question_idx) - protected.add(qa_pair.answer_idx) - - # 保护上下文窗口 - for offset in range(-window_size, window_size + 1): - q_idx = qa_pair.question_idx + offset - a_idx = qa_pair.answer_idx + offset - - if 0 <= q_idx < len(messages): - protected.add(q_idx) - if 0 <= a_idx < len(messages): - protected.add(a_idx) - - return protected - - async def _extract_dialog_important(self, dialog_text: str) -> DialogExtractionResponse: - """对话级一次性抽取:从整段对话中提取重要信息并判定相关性。 - - 改进版: - - LRU缓存管理 - - 重试机制 - - 降级策略 - """ - # 缓存命中则直接返回(场景+内容作为键) - cache_key = f"{self.config.pruning_scene}:" + hashlib.sha1(dialog_text.encode("utf-8")).hexdigest() - - # LRU缓存:如果命中,移到末尾(最近使用) - if cache_key in self._dialog_extract_cache: - self._dialog_extract_cache.move_to_end(cache_key) - return self._dialog_extract_cache[cache_key] - - # LRU缓存大小限制:超过限制时删除最旧的条目 - if len(self._dialog_extract_cache) >= self._cache_max_size: - # 删除最旧的条目(OrderedDict的第一个) - oldest_key = next(iter(self._dialog_extract_cache)) - del self._dialog_extract_cache[oldest_key] - self._log(f"[剪枝-缓存] LRU缓存已满,删除最旧条目") - - rendered = 
self.template.render( - pruning_scene=self.config.pruning_scene, - ontology_class_infos=self._ontology_class_infos, - dialog_text=dialog_text, - language=self.language - ) - log_template_rendering("extracat_Pruning.jinja2", { - "pruning_scene": self.config.pruning_scene, - "ontology_class_infos_count": len(self._ontology_class_infos), - "language": self.language - }) - log_prompt_rendering("pruning-extract", rendered) - - # 强制使用 LLM - if not self.llm_client: - raise RuntimeError("llm_client 未配置;请配置 LLM 以进行结构化抽取。") - - messages = [ - {"role": "system", "content": "你是一个严谨的场景抽取助手,只输出严格 JSON。"}, - {"role": "user", "content": rendered}, - ] - - # 重试机制 - max_retries = 3 - for attempt in range(max_retries): - try: - ex = await self.llm_client.response_structured(messages, DialogExtractionResponse) - self._dialog_extract_cache[cache_key] = ex - return ex - except Exception as e: - if attempt < max_retries - 1: - self._log(f"[剪枝-LLM] 第 {attempt + 1} 次尝试失败,重试中... 错误: {str(e)[:100]}") - await asyncio.sleep(0.5 * (attempt + 1)) # 指数退避 - continue - else: - # 降级策略:标记为相关,避免误删 - self._log(f"[剪枝-LLM] LLM 调用失败 {max_retries} 次,使用降级策略(标记为相关)") - fallback_response = DialogExtractionResponse( - is_related=True, - times=[], - ids=[], - amounts=[], - contacts=[], - addresses=[], - keywords=[] - ) - return fallback_response - - def _get_pruning_mode(self) -> str: - """根据 pruning_threshold 返回当前剪枝阶段。 - - - 低阈值 [0.0, 0.3):conservative 只删填充,保留所有实质内容 - - 中阈值 [0.3, 0.6):semantic 保留场景相关 + 有语义关联的内容,删除无关联内容 - - 高阈值 [0.6, 0.9]:strict 只保留场景相关内容,跨场景内容可被删除 - """ - t = float(self.config.pruning_threshold) - if t < 0.3: - return "conservative" - elif t < 0.6: - return "semantic" - else: - return "strict" - - def _apply_related_dialog_pruning( - self, - msgs: List[ConversationMessage], - extraction: "DialogExtractionResponse", - dialog_label: str, - pruning_mode: str, - ) -> List[ConversationMessage]: - """相关对话统一剪枝入口,消除 prune_dialog / prune_dataset 中的重复逻辑。 - - - conservative:只删填充 - - semantic / 
strict:场景感知剪枝 - """ - if pruning_mode == "conservative": - preserve_tokens = self._build_preserve_tokens(extraction) - return self._prune_fillers_only(msgs, preserve_tokens, dialog_label) - else: - return self._prune_with_scene_filter(msgs, extraction, dialog_label, pruning_mode) - - def _prune_fillers_only( - self, - msgs: List[ConversationMessage], - preserve_tokens: List[str], - dialog_label: str, - ) -> List[ConversationMessage]: - """相关对话专用:只删填充消息,LLM 保护消息和实质内容一律保留。 - - 不受 pruning_threshold 约束,删多少算多少(填充有多少删多少)。 - 至少保留 1 条消息。 - 注意:填充检测优先于 preserve_tokens 保护——填充消息本身无信息价值, - 即使 LLM 误将其关键词放入 preserve_tokens 也应删除。 - """ - to_delete_ids: set = set() - for m in msgs: - # 最高优先级保护:带有文件的消息一律保留,不参与任何剪枝判断 - if message_has_files(m): - self._log(f" [保护] 带文件的消息(不参与剪枝):'{m.msg[:40]}',文件数={len(m.files)}") - continue - - # 填充检测优先:先判断是否为填充,再看 LLM 保护 - if self._is_filler_message(m): - to_delete_ids.add(id(m)) - self._log(f" [填充] '{m.msg[:40]}' → 删除") - continue - if self._msg_matches_tokens(m, preserve_tokens): - self._log(f" [保护] '{m.msg[:40]}' → LLM保护,跳过") - - kept = [m for m in msgs if id(m) not in to_delete_ids] - if not kept and msgs: - kept = [msgs[0]] - - deleted = len(msgs) - len(kept) self._log( - f"[剪枝-相关] {dialog_label} 总消息={len(msgs)} " - f"填充删除={deleted} 保留={len(kept)}" - ) - return kept - - def _prune_with_scene_filter( - self, - msgs: List[ConversationMessage], - extraction: "DialogExtractionResponse", - dialog_label: str, - mode: str, - ) -> List[ConversationMessage]: - """场景感知剪枝,供 semantic / strict 两个阈值档位调用。 - - 本函数体现剪枝系统的三层递进逻辑: - - 第一层(conservative,阈值 < 0.3): - 不进入本函数,由 _prune_fillers_only 处理。 - 保留标准:只问"有没有信息量",填充消息(嗯/好的/哈哈等)删除,其余一律保留。 - - 第二层(semantic,阈值 [0.3, 0.6)): - 保留标准:内容价值优先,场景相关性是参考而非唯一标准。 - - 填充消息 → 删除(最高优先级) - - 场景相关消息 → 保留 - - 场景无关消息 → 有两次豁免机会: - 1. 命中 scene_preserve_tokens(LLM 标记的关键词/时间/金额等)→ 保留 - 2. 含情感词(感觉/压力/开心等)→ 保留(情感内容有记忆价值) - 3. 
两次豁免均未命中 → 删除 - - 第三层(strict,阈值 [0.6, 0.9]): - 保留标准:场景相关性优先,无任何豁免。 - - 填充消息 → 删除(最高优先级) - - 场景相关消息 → 保留 - - 场景无关消息 → 直接删除,preserve_keywords 和情感词在此模式下均不生效 - - 至少保留 1 条消息(兜底取第一条)。 - """ - # strict 模式收窄保护范围:只保护结构化关键信息(时间/编号/金额/联系方式/地址), - # 不保护 keywords / preserve_keywords,让场景过滤能删掉更多内容。 - # semantic 模式完整保护:包含 LLM 抽取的所有重要片段(含 keywords 和 preserve_keywords)。 - if mode == "strict": - scene_preserve_tokens = ( - extraction.times + extraction.ids + extraction.amounts + - extraction.contacts + extraction.addresses - ) - else: - scene_preserve_tokens = self._build_preserve_tokens(extraction) - - unrelated_snippets = extraction.scene_unrelated_snippets or [] - - to_delete_ids: set = set() - for m in msgs: - msg_text = m.msg.strip() - - # 最高优先级保护:带有文件的消息一律保留,不参与任何剪枝判断 - if message_has_files(m): - self._log(f" [保护] 带文件的消息(不参与剪枝):'{msg_text[:40]}',文件数={len(m.files)}") - continue - - # 第一优先级:填充消息无论模式直接删除,不参与后续场景判断 - if self._is_filler_message(m): - to_delete_ids.add(id(m)) - self._log(f" [填充] '{msg_text[:40]}' → 删除") - continue - - # 双向包含匹配:处理 LLM 返回片段与原始消息文本长度不完全一致的情况 - is_scene_unrelated = any( - snip and (snip in msg_text or msg_text in snip) - for snip in unrelated_snippets - ) - - if is_scene_unrelated: - if mode == "strict": - # strict:场景无关直接删除,不做任何豁免 - # 场景相关性是唯一裁决标准,preserve_keywords 在此模式下不生效 - to_delete_ids.add(id(m)) - self._log(f" [场景无关-严格] '{msg_text[:40]}' → 删除") - elif mode == "semantic": - # semantic:场景无关但有内容价值 → 保留 - # 豁免第一层:命中 scene_preserve_tokens(关键词/结构化信息保护) - if self._msg_matches_tokens(m, scene_preserve_tokens): - self._log(f" [保护] '{msg_text[:40]}' → 场景关键词保护,保留") - else: - # 豁免第二层:含情感词,认为有情境记忆价值,即使场景无关也保留 - has_contextual_emotion = any( - word in msg_text - for word in ["感觉", "觉得", "心情", "开心", "难过", "高兴", "沮丧", - "喜欢", "讨厌", "爱", "恨", "担心", "害怕", "兴奋", - "压力", "累", "疲惫", "烦", "焦虑", "委屈", "感动"] - ) - if not has_contextual_emotion: - to_delete_ids.add(id(m)) - self._log(f" [场景无关-语义] '{msg_text[:40]}' → 删除(无情感关联)") - else: - self._log(f" [场景关联-保留] 
'{msg_text[:40]}' → 有情感关联,保留") - else: - # 不在 scene_unrelated_snippets 中 → 场景相关,直接保留 - if self._msg_matches_tokens(m, scene_preserve_tokens): - self._log(f" [保护] '{msg_text[:40]}' → LLM保护,跳过") - # else: 普通场景相关消息,保留,不输出日志 - - kept = [m for m in msgs if id(m) not in to_delete_ids] - if not kept and msgs: - kept = [msgs[0]] - - deleted = len(msgs) - len(kept) - self._log( - f"[剪枝-{mode}] {dialog_label} 总消息={len(msgs)} " - f"删除={deleted} 保留={len(kept)}" - ) - return kept - - def _build_preserve_tokens(self, extraction: "DialogExtractionResponse") -> List[str]: - """统一构建 preserve_tokens,合并 LLM 抽取的所有重要片段。""" - return ( - extraction.times + extraction.ids + extraction.amounts + - extraction.contacts + extraction.addresses + extraction.keywords + - extraction.preserve_keywords + f"[剪枝-初始化] 场景={self.config.pruning_scene}, " + f"语言={self.language}, 开关={self.config.pruning_switch}" ) - def _msg_matches_tokens(self, message: ConversationMessage, tokens: List[str]) -> bool: - """判断消息是否包含任意抽取到的重要片段。""" - if not tokens: - return False - t = message.msg - return any(tok and (tok in t) for tok in tokens) + # ────────────────────────────────────────────── + # 公开接口(保持与旧版兼容) + # ────────────────────────────────────────────── async def prune_dialog(self, dialog: DialogData) -> DialogData: - """单对话剪枝:使用一次性对话抽取,避免逐条消息 LLM 调用。 - - 流程: - - 对整段对话进行抽取与相关性判定;若相关则不剪; - - 若不相关:用抽取到的重要片段 + 简单启发识别重要消息,按比例删除不相关消息,优先删除不重要,再删除重要(但重要最多按比例)。 - - 删除策略:不重要消息按出现顺序删除(确定性、无随机)。 - """ + """单对话剪枝入口。""" if not self.config.pruning_switch: return dialog - proportion = float(self.config.pruning_threshold) - extraction = await self._extract_dialog_important(dialog.content) - pruning_mode = self._get_pruning_mode() - self._log(f"[剪枝-模式] 阈值={proportion} → 模式={pruning_mode}") - - if extraction.is_related: - kept = self._apply_related_dialog_pruning( - dialog.context.msgs, extraction, f"对话ID={dialog.id}", pruning_mode - ) - dialog.context = ConversationContext(msgs=kept) - return dialog - - # 在不相关对话中,LLM 已通过 
preserve_tokens 标记需要保护的内容 - preserve_tokens = self._build_preserve_tokens(extraction) msgs = dialog.context.msgs + kept = await self._prune_messages(msgs, f"对话ID={dialog.id}") + dialog.context = ConversationContext(msgs=kept) - # 分类:填充 / 其他可删(LLM保护消息通过不加入任何桶来隐式保护) - filler_ids: set = set() - deletable: List[ConversationMessage] = [] + # 保存剪枝快照 + self._save_snapshot() - for m in msgs: - if self._msg_matches_tokens(m, preserve_tokens): - pass # 保护消息:不加入任何桶,不会被删除 - elif self._is_filler_message(m): - filler_ids.add(id(m)) - else: - deletable.append(m) - - # 计算删除目标 - total_unrel = len(msgs) - delete_target = int(total_unrel * proportion) - if proportion > 0 and total_unrel > 0 and delete_target == 0: - delete_target = 1 - max_deletable = min(len(filler_ids) + len(deletable), max(0, total_unrel - 1)) - delete_target = min(delete_target, max_deletable) - - # 优先删填充,再删其他可删消息(按出现顺序) - to_delete_ids: set = set() - for m in msgs: - if len(to_delete_ids) >= delete_target: - break - if id(m) in filler_ids: - to_delete_ids.add(id(m)) - for m in deletable: - if len(to_delete_ids) >= delete_target: - break - to_delete_ids.add(id(m)) - - kept_msgs = [m for m in msgs if id(m) not in to_delete_ids] - if not kept_msgs and msgs: - kept_msgs = [msgs[0]] - - deleted_total = len(msgs) - len(kept_msgs) - protected_count = len(msgs) - len(filler_ids) - len(deletable) - self._log( - f"[剪枝-对话] 对话ID={dialog.id} 总消息={len(msgs)} " - f"(保护={protected_count} 填充={len(filler_ids)} 可删={len(deletable)}) " - f"删除目标={delete_target} 实删={deleted_total} 保留={len(kept_msgs)}" - ) - - dialog.context = ConversationContext(msgs=kept_msgs) return dialog async def prune_dataset(self, dialogs: List[DialogData]) -> List[DialogData]: - """数据集层面:全局消息级剪枝,保留所有对话。 - - 改进版: - - 消息级独立判断,每条消息根据场景规则独立评估 - - 问答对保护已注释(暂不启用,留作观察) - - 优化删除策略:填充消息 → 不重要消息 → 低分重要消息 - - 只删除"不重要的不相关消息",重要信息(时间、编号等)强制保留 - - 保证每段对话至少保留1条消息,不会删除整段对话 - """ - # 如果剪枝功能关闭,直接返回原始数据集 + """数据集层面剪枝入口,逐对话处理。""" if not self.config.pruning_switch: return dialogs 
- # 阈值保护:最高0.9 - proportion = float(self.config.pruning_threshold) - if proportion > 0.9: - logger.warning(f"[剪枝-数据集] 阈值{proportion}超过上限0.9,已自动调整为0.9") - proportion = 0.9 - if proportion < 0.0: - proportion = 0.0 - self._log( - f"[剪枝-数据集] 对话总数={len(dialogs)} 场景={self.config.pruning_scene} 删除比例={proportion} 开关={self.config.pruning_switch} 模式=消息级独立判断" + f"[剪枝-数据集] 对话总数={len(dialogs)}, " + f"场景={self.config.pruning_scene}, " + f"开关={self.config.pruning_switch}" ) - pruning_mode = self._get_pruning_mode() - self._log(f"[剪枝-数据集] 阈值={proportion} → 剪枝阶段={pruning_mode}") - result: List[DialogData] = [] - total_original_msgs = 0 - total_deleted_msgs = 0 + total_original = 0 + total_deleted = 0 - # 统计对象:直接收集结构化数据,无需事后正则解析 stats = { "scene": self.config.pruning_scene, "dialog_total": len(dialogs), - "deletion_ratio": proportion, "enabled": self.config.pruning_switch, - "pruning_mode": pruning_mode, - "related_count": 0, - "unrelated_count": 0, - "related_indices": [], - "unrelated_indices": [], "total_deleted_messages": 0, "remaining_dialogs": 0, "dialogs": [], } - # 并发执行所有对话的 LLM 抽取(获取 preserve_keywords 等保护信息) - semaphore = asyncio.Semaphore(self.max_concurrent) - - async def extract_with_semaphore(dd: DialogData) -> DialogExtractionResponse: - async with semaphore: - try: - return await self._extract_dialog_important(dd.content) - except Exception as e: - self._log(f"[剪枝-LLM] 对话抽取失败,使用降级策略: {str(e)[:100]}") - return DialogExtractionResponse(is_related=True) - - extraction_tasks = [extract_with_semaphore(dd) for dd in dialogs] - extraction_results: List[DialogExtractionResponse] = await asyncio.gather(*extraction_tasks) - - for d_idx, (dd, extraction) in enumerate(zip(dialogs, extraction_results)): + for d_idx, dd in enumerate(dialogs): msgs = dd.context.msgs original_count = len(msgs) - total_original_msgs += original_count + total_original += original_count - # 相关对话:根据阶段决定处理力度 - if extraction.is_related: - stats["related_count"] += 1 - stats["related_indices"].append(d_idx 
+ 1) - kept = self._apply_related_dialog_pruning( - msgs, extraction, f"对话 {d_idx+1}", pruning_mode - ) - deleted_count = original_count - len(kept) - total_deleted_msgs += deleted_count - dd.context.msgs = kept - result.append(dd) - stats["dialogs"].append({ - "index": d_idx + 1, - "is_related": True, - "total_messages": original_count, - "deleted": deleted_count, - "kept": len(kept), - }) - continue + kept = await self._prune_messages(msgs, f"对话 {d_idx + 1}") - stats["unrelated_count"] += 1 - stats["unrelated_indices"].append(d_idx + 1) + deleted_count = original_count - len(kept) + total_deleted += deleted_count - # 从 LLM 抽取结果中获取所有需要保留的 token - preserve_tokens = self._build_preserve_tokens(extraction) - - # 判断是否需要详细日志 - should_log_details = self._detailed_prune_logging and original_count <= self._max_debug_msgs_per_dialog - if self._detailed_prune_logging and original_count > self._max_debug_msgs_per_dialog: - self._log(f" 对话[{d_idx}]消息数={original_count},仅采样前{self._max_debug_msgs_per_dialog}条进行详细日志") - - if extraction.preserve_keywords: - self._log(f" 对话[{d_idx}] LLM抽取到情绪/兴趣保护词: {extraction.preserve_keywords}") - - # 消息级分类:LLM保护 / 填充 / 其他可删 - llm_protected_msgs = [] # LLM 保护消息(preserve_tokens 命中):绝对不可删除 - filler_msgs = [] # 填充消息(优先删除) - deletable_msgs = [] # 其余消息(按比例删除) - - for idx, m in enumerate(msgs): - msg_text = m.msg.strip() - - # 最高优先级保护:带有文件的消息一律保留,不参与分类 - if message_has_files(m): - self._log(f" [保护] 带文件的消息(不参与分类,直接保留):索引{idx}, '{msg_text[:40]}', 文件数={len(m.files)}") - llm_protected_msgs.append((idx, m)) # 放入保护列表 - continue - - if self._msg_matches_tokens(m, preserve_tokens): - llm_protected_msgs.append((idx, m)) - if should_log_details or idx < self._max_debug_msgs_per_dialog: - self._log(f" [{idx}] '{msg_text[:30]}...' → 保护(LLM,不可删)") - elif self._is_filler_message(m): - filler_msgs.append((idx, m)) - if should_log_details or idx < self._max_debug_msgs_per_dialog: - self._log(f" [{idx}] '{msg_text[:30]}...' 
→ 填充") - else: - deletable_msgs.append((idx, m)) - if should_log_details or idx < self._max_debug_msgs_per_dialog: - self._log(f" [{idx}] '{msg_text[:30]}...' → 可删") - - # important_msgs 仅用于日志统计 - important_msgs = llm_protected_msgs - - # 计算删除配额 - delete_target = int(original_count * proportion) - if proportion > 0 and original_count > 0 and delete_target == 0: - delete_target = 1 - - # 确保至少保留1条消息 - max_deletable = max(0, original_count - 1) - delete_target = min(delete_target, max_deletable) - - # 删除策略:优先删填充消息,再按出现顺序删其余可删消息 - to_delete_indices = set() - deleted_details = [] - - # 第一步:删除填充消息 - for idx, msg in filler_msgs: - if len(to_delete_indices) >= delete_target: - break - to_delete_indices.add(idx) - deleted_details.append(f"[{idx}] 填充: '{msg.msg[:50]}'") - - # 第二步:如果还需要删除,按出现顺序删可删消息 - for idx, msg in deletable_msgs: - if len(to_delete_indices) >= delete_target: - break - to_delete_indices.add(idx) - deleted_details.append(f"[{idx}] 可删: '{msg.msg[:50]}'") - - # 执行删除 - kept_msgs = [] - for idx, m in enumerate(msgs): - if idx not in to_delete_indices: - kept_msgs.append(m) - - # 确保至少保留1条 - if not kept_msgs and msgs: - kept_msgs = [msgs[0]] - - dd.context.msgs = kept_msgs - deleted_count = original_count - len(kept_msgs) - total_deleted_msgs += deleted_count - - # 输出删除详情 - if deleted_details: - self._log(f"[剪枝-删除详情] 对话 {d_idx+1} 删除了以下消息:") - for detail in deleted_details: - self._log(f" {detail}") - - # ========== 问答对统计(已注释) ========== - # qa_info = f",问答对={len(qa_pairs)}" if qa_pairs else "" - # ======================================== - - self._log( - f"[剪枝-对话] 对话 {d_idx+1} 总消息={original_count} " - f"(保护={len(important_msgs)} 填充={len(filler_msgs)} 可删={len(deletable_msgs)}) " - f"删除={deleted_count} 保留={len(kept_msgs)}" - ) + dd.context = ConversationContext(msgs=kept) + result.append(dd) stats["dialogs"].append({ "index": d_idx + 1, - "is_related": False, "total_messages": original_count, - "protected": len(important_msgs), - "fillers": len(filler_msgs), - 
"deletable": len(deletable_msgs), "deleted": deleted_count, - "kept": len(kept_msgs), + "kept": len(kept), }) - result.append(dd) - - # 补全统计对象 - stats["total_deleted_messages"] = total_deleted_msgs + stats["total_deleted_messages"] = total_deleted stats["remaining_dialogs"] = len(result) - self._log(f"[剪枝-数据集] 剩余对话数={len(result)}") - self._log(f"[剪枝-数据集] 相关对话数={stats['related_count']} 不相关对话数={stats['unrelated_count']}") - self._log(f"[剪枝-数据集] 总删除 {total_deleted_msgs} 条") + self._log(f"[剪枝-数据集] 总消息={total_original}, 删除={total_deleted}") - # 直接序列化统计对象,无需正则解析 + # 保存统计日志 + self._save_stats(stats) + + # 保存剪枝快照到 PipelineSnapshot + self._save_snapshot() + + if not result: + logger.warning("语义剪枝后数据集为空,已回退为未剪枝数据") + return dialogs + + return result + + # ────────────────────────────────────────────── + # 核心剪枝逻辑 + # ────────────────────────────────────────────── + + async def _prune_messages( + self, msgs: List[ConversationMessage], label: str + ) -> List[ConversationMessage]: + """对消息列表执行 Assistant 剪枝。 + + 流程: + 1. 扫描消息,配对 User-Assistant 对 + 2. 对每个消息对并发调用 LLM 提取 assistant_memory_hint + 3. hint="NULL" → 删除 Assistant 消息 + 4. hint 非 NULL → 用压缩摘要替换 Assistant 原始消息 + 5. 
User 消息、带文件的消息、非配对消息原样保留 + """ + if not msgs: + return msgs + + # 第一步:识别 User-Assistant 消息对 + pairs = self._pair_user_assistant(msgs) + # pairs: List[(user_idx, assistant_idx)] + + # 第二步:并发调用 LLM 处理每个消息对 + semaphore = asyncio.Semaphore(self.max_concurrent) + + async def process_pair(user_idx: int, asst_idx: int): + async with semaphore: + user_msg = msgs[user_idx] + asst_msg = msgs[asst_idx] + + # 构建 snapshot 的 input 部分 + input_record = { + "msgs": [ + {"role": "User", "msg": user_msg.msg}, + {"role": "Assistant", "msg": asst_msg.msg}, + ] + } + + # 带文件的 Assistant 消息不剪枝 + if message_has_files(asst_msg): + self._log( + f" [{label}] 索引{asst_idx} 带文件,跳过剪枝" + ) + self._snapshot_records.append({ + "input": input_record, + "gold": { + "assistant_memory_hint": asst_msg.msg, + "assistant_memory_type": "skipped (has files)", + }, + }) + return asst_idx, asst_msg.msg, False + + result = await self._extract_assistant_hint(user_msg, asst_msg) + + # 收集 snapshot 记录 + self._snapshot_records.append({ + "input": input_record, + "gold": { + "assistant_memory_hint": result.assistant_memory_hint, + "assistant_memory_type": result.assistant_memory_type, + }, + }) + + if result.assistant_memory_hint == "NULL": + self._log( + f" [{label}] 索引{asst_idx} → NULL,删除 " + f"('{asst_msg.msg[:40]}')" + ) + return asst_idx, None, True # 标记删除 + else: + self._log( + f" [{label}] 索引{asst_idx} → " + f"type={result.assistant_memory_type}, " + f"hint='{result.assistant_memory_hint[:50]}'" + ) + return asst_idx, result.assistant_memory_hint, False + + tasks = [process_pair(u, a) for u, a in pairs] + pair_results = await asyncio.gather(*tasks) + + # 构建替换/删除映射 + # asst_idx → (new_msg_text | None) + asst_actions: Dict[int, Optional[str]] = {} + for asst_idx, new_text, should_delete in pair_results: + if should_delete: + asst_actions[asst_idx] = None + else: + asst_actions[asst_idx] = new_text + + # 第三步:构建最终消息列表 + kept: List[ConversationMessage] = [] + for idx, m in enumerate(msgs): + if idx in asst_actions: 
+ new_text = asst_actions[idx] + if new_text is None: + # 删除该 Assistant 消息 + continue + else: + # 用压缩摘要替换原始消息 + kept.append(ConversationMessage( + role=m.role, + msg=new_text, + files=m.files, + )) + else: + # User 消息、未配对的消息原样保留 + kept.append(m) + + # 兜底:至少保留 1 条消息 + if not kept and msgs: + kept = [msgs[0]] + + deleted = len(msgs) - len(kept) + self._log( + f"[剪枝] {label} 总消息={len(msgs)}, " + f"配对数={len(pairs)}, 删除={deleted}, 保留={len(kept)}" + ) + return kept + + def _pair_user_assistant( + self, msgs: List[ConversationMessage] + ) -> List[tuple]: + """将消息列表中相邻的 User-Assistant 配对。 + + 规则: + - 遍历消息,遇到 role=user 时记录索引 + - 紧接着的 role=assistant 消息与之配对 + - 连续多条 user 消息只取最后一条作为上下文 + - 未配对的 assistant 消息(如对话开头就是 assistant)不处理 + """ + pairs = [] + last_user_idx = None + + for idx, m in enumerate(msgs): + if m.role == "user": + last_user_idx = idx + elif m.role == "assistant" and last_user_idx is not None: + pairs.append((last_user_idx, idx)) + last_user_idx = None # 一个 user 只配一个 assistant + + return pairs + + # ────────────────────────────────────────────── + # LLM 调用 + # ────────────────────────────────────────────── + + async def _extract_assistant_hint( + self, + user_msg: ConversationMessage, + asst_msg: ConversationMessage, + ) -> AssistantPruningResponse: + """调用 LLM 从 User-Assistant 消息对中提取 Assistant 记忆摘要。 + + 使用 extracat_Pruning.jinja2 模板,输入格式: + {"msgs": [{"role": "User", "msg": "..."}, {"role": "Assistant", "msg": "..."}]} + """ + # 构建模板输入 + dialog_text = json.dumps( + { + "msgs": [ + {"role": "User", "msg": user_msg.msg}, + {"role": "Assistant", "msg": asst_msg.msg}, + ] + }, + ensure_ascii=False, + ) + + # 缓存检查 + cache_key = hashlib.sha1(dialog_text.encode("utf-8")).hexdigest() + if cache_key in self._cache: + self._cache.move_to_end(cache_key) + return self._cache[cache_key] + + # LRU 淘汰 + if len(self._cache) >= self._cache_max_size: + oldest = next(iter(self._cache)) + del self._cache[oldest] + + # 渲染模板 + rendered = self.template.render(dialog_text=dialog_text) 
+ log_template_rendering("extracat_Pruning.jinja2", { + "language": self.language, + }) + log_prompt_rendering("pruning-assistant-hint", rendered) + + if not self.llm_client: + raise RuntimeError("llm_client 未配置;请配置 LLM 以进行 Assistant 剪枝。") + + messages = [ + { + "role": "system", + "content": "你是一个面向记忆存储的辅助信息提取器,只输出严格 JSON。", + }, + {"role": "user", "content": rendered}, + ] + + # 重试机制 + max_retries = 3 + for attempt in range(max_retries): + try: + result = await self.llm_client.response_structured( + messages, AssistantPruningResponse + ) + self._cache[cache_key] = result + return result + except Exception as e: + if attempt < max_retries - 1: + self._log( + f"[剪枝-LLM] 第 {attempt + 1} 次尝试失败,重试: " + f"{str(e)[:100]}" + ) + await asyncio.sleep(0.5 * (attempt + 1)) + else: + # 降级:保留原始消息,不剪枝 + self._log( + f"[剪枝-LLM] {max_retries} 次失败,降级保留原始消息" + ) + return AssistantPruningResponse( + assistant_memory_hint=asst_msg.msg, + assistant_memory_type="NULL", + ) + + # ────────────────────────────────────────────── + # 工具方法 + # ────────────────────────────────────────────── + + def _save_stats(self, stats: dict) -> None: + """保存剪枝统计到文件。""" try: from app.core.config import settings + settings.ensure_memory_output_dir() log_output_path = settings.get_memory_output_path("pruned_terminal.json") with open(log_output_path, "w", encoding="utf-8") as f: json.dump(stats, f, ensure_ascii=False, indent=2) except Exception as e: - self._log(f"[剪枝-数据集] 保存终端输出日志失败:{e}") + self._log(f"[剪枝] 保存统计日志失败:{e}") - # Safety: avoid empty dataset - if not result: - logger.warning("语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断") - return dialogs + def _save_snapshot(self) -> None: + """将剪枝结果保存到 PipelineSnapshot(1_assistant_pruning.json)。 - return result + 输出格式:每个 User-Assistant 消息对一条记录,包含: + - input.msgs: 原始消息对 [{role, msg}, {role, msg}] + - gold.assistant_memory_hint: LLM 提取的记忆摘要 + - gold.assistant_memory_type: 摘要类型枚举 + """ + if not self._snapshot or not self._snapshot_records: + return + + try: + 
self._snapshot.save_stage("1_assistant_pruning", self._snapshot_records) + self._log( + f"[剪枝-快照] 已保存 {len(self._snapshot_records)} 条记录 " + f"到 1_assistant_pruning.json" + ) + except Exception as e: + self._log(f"[剪枝-快照] 保存失败: {e}") def _log(self, msg: str) -> None: - """记录日志并打印到终端。""" + """记录日志。""" try: self.run_logs.append(msg) except Exception: pass logger.debug(msg) - - diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index 75fc87d2..049c265f 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -1264,6 +1264,7 @@ class ExtractionOrchestrator: entity_idx=entity.entity_idx, # 使用实体自己的 entity_idx statement_id=statement.id, # 添加必需的 statement_id 字段 entity_type=getattr(entity, 'type', 'unknown'), # 使用 type 而不是 entity_type + type_description=getattr(entity, 'type_description', ''), description=getattr(entity, 'description', ''), # 添加必需的 description 字段 example=getattr(entity, 'example', ''), # 新增:传递示例字段 # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 @@ -1306,6 +1307,7 @@ class ExtractionOrchestrator: source=subject_entity_id, target=object_entity_id, relation_type=triplet.predicate, + relation_type_description=getattr(triplet, 'predicate_description', ''), statement=statement.statement, source_statement_id=statement.id, end_user_id=dialog_data.end_user_id, diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/statement_extraction.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/statement_extraction.py index 76a48c58..5c1a9125 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/statement_extraction.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/statement_extraction.py @@ -12,16 
+12,21 @@ from app.core.memory.utils.data.ontology import ( TemporalInfo, ) from app.core.memory.utils.prompt.prompt_utils import render_statement_extraction_prompt -from pydantic import BaseModel, Field, field_validator +from pydantic import AliasChoices, BaseModel, Field, field_validator logger = logging.getLogger(__name__) class ExtractedStatement(BaseModel): """Schema for extracted statement from LLM""" - statement: str = Field(..., description="The extracted statement text") + statement: str = Field( + ..., + validation_alias=AliasChoices("statement", "statement_text"), + description="The extracted statement text", + ) statement_type: str = Field(..., description="FACT, OPINION, SUGGESTION or PREDICTION") temporal_type: str = Field(..., description="STATIC, DYNAMIC, ATEMPORAL") - relevence: str = Field(..., description="RELEVANT or IRRELEVANT") + # New prompt no longer outputs relevence; keep backward-compatible default. + relevence: str = Field("RELEVANT", description="RELEVANT or IRRELEVANT") has_unsolved_reference: bool = Field(False, description="Whether the statement has unresolved references") class StatementExtractionResponse(BaseModel): @@ -41,7 +46,7 @@ class StatementExtractionResponse(BaseModel): valid_statements = [] filtered_count = 0 for i, stmt in enumerate(v): - if isinstance(stmt, dict) and stmt.get('statement'): + if isinstance(stmt, dict) and (stmt.get("statement") or stmt.get("statement_text")): valid_statements.append(stmt) elif isinstance(stmt, dict): # Log which statement was filtered @@ -96,6 +101,11 @@ class StatementExtractor: """ chunk_content = chunk.content chunk_speaker = self._get_speaker_from_chunk(chunk) + logger.info( + "[LegacyStatementExtractor] chunk_id=%s content_len=%d", + getattr(chunk, "id", ""), + len(chunk_content or ""), + ) if not chunk_content or len(chunk_content.strip()) < 5: logger.warning(f"Chunk {chunk.id} content too short or empty, skipping") @@ -108,7 +118,18 @@ class StatementExtractor: 
granularity=self.config.statement_granularity, include_dialogue_context=self.config.include_dialogue_context, dialogue_content=dialogue_content, - max_dialogue_chars=self.config.max_dialogue_context_chars + max_dialogue_chars=self.config.max_dialogue_context_chars, + input_json={ + "chunk_id": getattr(chunk, "id", ""), + "end_user_id": end_user_id or "", + "target_content": chunk_content, + "target_message_date": datetime.now().isoformat(), + "supporting_context": { + "msgs": [ + {"role": "context", "msg": dialogue_content} + ] if dialogue_content else [] + }, + }, ) # Simple system message diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py index ea355ca1..8eb01cff 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py @@ -4,7 +4,7 @@ from typing import List, Dict, Optional from app.core.logging_config import get_memory_logger from app.core.memory.llm_tools.openai_client import OpenAIClient from app.core.memory.utils.prompt.prompt_utils import render_triplet_extraction_prompt -from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate # 引入枚举 Predicate 白名单过滤 +from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS from app.core.memory.models.triplet_models import TripletExtractionResponse from app.core.memory.models.message_models import DialogData, Statement from app.core.memory.models.ontology_extraction_models import OntologyTypeList @@ -73,15 +73,9 @@ class TripletExtractor: try: # Get structured response from LLM response = await self.llm_client.response_structured(messages, TripletExtractionResponse) - # Filter triplets to only allowed predicates from ontology - # 这里过滤掉了不在 Predicate 枚举中的谓语 
但是容易造成谓语太严格,有点语句的谓语没有在枚举中,就被判断为弱关系 - allowed_predicates = {p.value for p in Predicate} - filtered_triplets = [t for t in response.triplets if getattr(t, "predicate", "") in allowed_predicates] - # 仅保留predicate ∈ Predicate 的三元组,其余全部剔除 - # Create new triplets with statement_id set during creation updated_triplets = [] - for triplet in filtered_triplets: # 仅保留 predicate ∈ Predicate 的三元组 + for triplet in response.triplets: updated_triplet = triplet.model_copy(update={"statement_id": statement.id}) updated_triplets.append(updated_triplet) diff --git a/api/app/core/memory/storage_services/extraction_engine/steps/extraction_pipeline_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/steps/extraction_pipeline_orchestrator.py index ea8c2812..4649a17e 100644 --- a/api/app/core/memory/storage_services/extraction_engine/steps/extraction_pipeline_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/steps/extraction_pipeline_orchestrator.py @@ -300,6 +300,33 @@ class NewExtractionOrchestrator: "embedding_output": None, } + if self.progress_callback: + statements_count = sum( + len(stmts) + for chunk_stmts in all_stmt_results.values() + for stmts in chunk_stmts.values() + ) + entities_count = sum( + len(t_out.entities) + for stmt_triplets in all_triplet_results.values() + for t_out in stmt_triplets.values() + ) + triplets_count = sum( + len(t_out.triplets) + for stmt_triplets in all_triplet_results.values() + for t_out in stmt_triplets.values() + ) + await self.progress_callback( + "knowledge_extraction_complete", + "知识抽取完成", + { + "entities_count": entities_count, + "statements_count": statements_count, + "temporal_ranges_count": 0, + "triplets_count": triplets_count, + }, + ) + logger.info("Pilot extraction complete") return dialog_data_list @@ -467,6 +494,11 @@ class NewExtractionOrchestrator: else None ) for chunk in dialog.chunks: + # 仅对 speaker="user" 的 chunk 进行陈述句抽取;assistant 内容交给 + # 上游预处理/剪枝阶段处理,避免浪费 LLM 调用。 + chunk_speaker = 
getattr(chunk, "speaker", "user") + if chunk_speaker != "user": + continue inp = StatementStepInput( chunk_id=chunk.id, end_user_id=dialog.end_user_id, @@ -478,7 +510,7 @@ class NewExtractionOrchestrator: ) tasks.append(self.statement_step.run(inp)) task_meta.append( - (dialog.id, chunk.id, getattr(chunk, "speaker", "user"), ctx) + (dialog.id, chunk.id, chunk_speaker, ctx) ) results = await asyncio.gather(*tasks, return_exceptions=True) @@ -499,6 +531,15 @@ class NewExtractionOrchestrator: for s in stmts: s.speaker = speaker stmt_map[dialog_id][chunk_id] = stmts + if self.progress_callback: + # Frontend consumes knowledge_extraction_result with data.statement. + # Emit one event per statement to keep payload contract simple. + for s in stmts: + await self.progress_callback( + "knowledge_extraction_result", + "知识抽取中", + {"statement": s.statement_text}, + ) return stmt_map @@ -520,6 +561,11 @@ class NewExtractionOrchestrator: chunk_stmts = all_stmt_results.get(dialog.id, {}) for _chunk_id, stmts in chunk_stmts.items(): for stmt in stmts: + # 防御性过滤:三元组抽取仅针对 user statement。 + # 上游 _extract_all_statements 已过滤 chunk.speaker,此处再做 + # 一次 statement.speaker 的二次校验,防止外部注入或 legacy 数据脱漏。 + if getattr(stmt, "speaker", "user") != "user": + continue inp = self._convert_to_triplet_input(stmt, ctx) tasks.append(self.triplet_step.run(inp)) task_meta.append((dialog.id, stmt.statement_id)) @@ -541,6 +587,24 @@ class NewExtractionOrchestrator: triplet_map[dialog_id][stmt_id] = self.triplet_step.get_default_output() else: triplet_map[dialog_id][stmt_id] = result + if self.progress_callback: + await self.progress_callback( + "extract_triplet_result", + f"statement {stmt_id} 提取完成", + { + "statement_id": stmt_id, + "triplet_count": len(result.triplets), + "entity_count": len(result.entities), + "triplets": [ + { + "subject_name": t.subject_name, + "predicate": t.predicate, + "object_name": t.object_name, + } + for t in result.triplets[:5] + ], + }, + ) return triplet_map @@ -842,6 +906,8 @@ 
class NewExtractionOrchestrator: temporal_info=_TEMPORAL_MAP.get(stmt_out.temporal_type, TemporalInfo.ATEMPORAL), # relevence_info=RelevenceInfo.RELEVANT if stmt_out.relevance == "RELEVANT" else RelevenceInfo.IRRELEVANT, temporal_validity=TemporalValidityRange(valid_at=valid_at, invalid_at=invalid_at), + has_unsolved_reference=stmt_out.has_unsolved_reference, + has_emotional_state=stmt_out.has_emotional_state, triplet_extraction_info=triplet_info, statement_embedding=stmt_embedding, **emotion_kwargs, diff --git a/api/app/core/memory/storage_services/extraction_engine/steps/graph_build_step.py b/api/app/core/memory/storage_services/extraction_engine/steps/graph_build_step.py index 5da5f5ab..f329c98d 100644 --- a/api/app/core/memory/storage_services/extraction_engine/steps/graph_build_step.py +++ b/api/app/core/memory/storage_services/extraction_engine/steps/graph_build_step.py @@ -250,6 +250,7 @@ async def build_graph_nodes_and_edges( entity_idx=entity.entity_idx, statement_id=statement.id, entity_type=getattr(entity, "type", "unknown"), + type_description=getattr(entity, "type_description", ""), description=getattr(entity, "description", ""), example=getattr(entity, "example", ""), connect_strength=( @@ -296,6 +297,7 @@ async def build_graph_nodes_and_edges( source=subject_entity_id, target=object_entity_id, relation_type=triplet.predicate, + relation_type_description=getattr(triplet, "predicate_description", ""), statement=statement.statement, source_statement_id=statement.id, end_user_id=dialog_data.end_user_id, diff --git a/api/app/core/memory/storage_services/extraction_engine/steps/schema/extraction_step_schema.py b/api/app/core/memory/storage_services/extraction_engine/steps/schema/extraction_step_schema.py index 8b0ae643..eacab0b6 100644 --- a/api/app/core/memory/storage_services/extraction_engine/steps/schema/extraction_step_schema.py +++ b/api/app/core/memory/storage_services/extraction_engine/steps/schema/extraction_step_schema.py @@ -46,6 +46,7 @@ class 
StatementStepOutput(BaseModel): temporal_type: str # STATIC / DYNAMIC / ATEMPORAL # relevance: str # RELEVANT / IRRELEVANT speaker: str # "user" / "assistant" + has_emotional_state: bool = False # Whether statement reflects user's emotional state valid_at: str # ISO 8601 or "NULL" invalid_at: str # ISO 8601 or "NULL" has_unsolved_reference: bool = False # Whether the statement has unresolved references @@ -72,6 +73,7 @@ class EntityItem(BaseModel): entity_idx: int name: str type: str + type_description: str = "" description: str is_explicit_memory: bool = False @@ -82,6 +84,7 @@ class TripletItem(BaseModel): subject_name: str subject_id: int predicate: str + predicate_description: str = "" object_name: str object_id: int diff --git a/api/app/core/memory/storage_services/extraction_engine/steps/statement_step.py b/api/app/core/memory/storage_services/extraction_engine/steps/statement_step.py index a0c76b68..25f13e24 100644 --- a/api/app/core/memory/storage_services/extraction_engine/steps/statement_step.py +++ b/api/app/core/memory/storage_services/extraction_engine/steps/statement_step.py @@ -34,6 +34,10 @@ class _ExtractedStatement(BaseModel): statement_type: str = Field(..., description="FACT / OPINION / OTHER") temporal_type: str = Field(..., description="STATIC / DYNAMIC / ATEMPORAL") # relevance: str = Field("RELEVANT", description="RELEVANT / IRRELEVANT") + has_emotional_state: bool = Field( + False, + description="Whether the statement reflects user's emotional state", + ) valid_at: str = Field("NULL", description="ISO 8601 or NULL") invalid_at: str = Field("NULL", description="ISO 8601 or NULL") has_unsolved_reference: bool = Field(False, description="Whether the statement has unresolved references") @@ -155,6 +159,7 @@ class StatementExtractionStep(ExtractionStep[StatementStepInput, List[StatementS temporal_type=stmt.temporal_type.strip().upper(), # relevance=stmt.relevance.strip().upper(), speaker="user", # default; orchestrator overrides from chunk 
metadata + has_emotional_state=getattr(stmt, "has_emotional_state", False), valid_at=stmt.valid_at or "NULL", invalid_at=stmt.invalid_at or "NULL", has_unsolved_reference=getattr(stmt, "has_unsolved_reference", False), diff --git a/api/app/core/memory/storage_services/extraction_engine/steps/triplet_step.py b/api/app/core/memory/storage_services/extraction_engine/steps/triplet_step.py index af143a62..9f8953b8 100644 --- a/api/app/core/memory/storage_services/extraction_engine/steps/triplet_step.py +++ b/api/app/core/memory/storage_services/extraction_engine/steps/triplet_step.py @@ -112,6 +112,7 @@ class TripletExtractionStep(ExtractionStep[TripletStepInput, TripletStepOutput]) subject_name=t.subject_name, subject_id=t.subject_id, predicate=t.predicate, + predicate_description=getattr(t, "predicate_description", ""), object_name=t.object_name, object_id=t.object_id, ) @@ -123,6 +124,7 @@ class TripletExtractionStep(ExtractionStep[TripletStepInput, TripletStepOutput]) entity_idx=e.entity_idx, name=e.name, type=e.type, + type_description=getattr(e, "type_description", ""), description=e.description, is_explicit_memory=getattr(e, "is_explicit_memory", False), ) diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index f5c58dbe..247da0a9 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -92,6 +92,7 @@ SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity THEN entity.expired_at ELSE e.expired_at END, e.entity_idx = CASE WHEN e.entity_idx IS NULL OR e.entity_idx = 0 THEN entity.entity_idx ELSE e.entity_idx END, e.entity_type = CASE WHEN entity.entity_type IS NOT NULL AND entity.entity_type <> '' THEN entity.entity_type ELSE e.entity_type END, + e.type_description = CASE WHEN entity.type_description IS NOT NULL AND entity.type_description <> '' THEN entity.type_description ELSE coalesce(e.type_description, '') END, e.description = 
CASE WHEN entity.description IS NOT NULL AND entity.description <> '' AND (e.description IS NULL OR size(e.description) = 0 OR size(entity.description) > size(e.description)) @@ -147,6 +148,7 @@ MATCH (object:ExtractedEntity {id: rel.target_id, end_user_id: rel.end_user_id}) // Avoid duplicate edges across runs for the same endpoints MERGE (subject)-[r:EXTRACTED_RELATIONSHIP]->(object) SET r.predicate = rel.predicate, + r.predicate_description = rel.predicate_description, r.statement_id = rel.statement_id, r.value = rel.value, r.statement = rel.statement, diff --git a/api/app/repositories/neo4j/graph_saver.py b/api/app/repositories/neo4j/graph_saver.py index 56feece2..6f0e03a5 100644 --- a/api/app/repositories/neo4j/graph_saver.py +++ b/api/app/repositories/neo4j/graph_saver.py @@ -44,6 +44,7 @@ async def save_entities_and_relationships( 'source_id': edge.source, 'target_id': edge.target, 'predicate': edge.relation_type, + 'predicate_description': edge.relation_type_description, 'statement_id': edge.source_statement_id, 'value': edge.relation_value, 'statement': edge.statement, @@ -297,6 +298,7 @@ async def save_dialog_and_statements_to_neo4j( 'source_id': edge.source, 'target_id': edge.target, 'predicate': edge.relation_type, + 'predicate_description': edge.relation_type_description, 'statement_id': edge.source_statement_id, 'value': edge.relation_value, 'statement': edge.statement, diff --git a/api/app/services/memory_storage_service.py b/api/app/services/memory_storage_service.py index 132370b6..0282fa5a 100644 --- a/api/app/services/memory_storage_service.py +++ b/api/app/services/memory_storage_service.py @@ -441,21 +441,12 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) with open(result_path, "r", encoding="utf-8") as rf: extracted_result = json.load(rf) - # 步骤 6: 计算本体覆盖率并合并到结果中 + # 步骤 6: 组装结果(试运行不做额外覆盖率后处理) result_data = { "config_id": cid, "time_log": os.path.join(project_root, "logs", "time.log"), "extracted_result": extracted_result, } - try: - 
ontology_coverage = await self._compute_ontology_coverage( - extracted_result=extracted_result, - memory_config=memory_config, - ) - if ontology_coverage: - result_data["ontology_coverage"] = ontology_coverage - except Exception as cov_err: - logger.warning(f"[PILOT_RUN_STREAM] Ontology coverage computation failed: {cov_err}", exc_info=True) yield format_sse_message("result", result_data) @@ -479,100 +470,6 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) "time": int(time.time() * 1000) }) - async def _compute_ontology_coverage( - self, - extracted_result: Dict[str, Any], - memory_config, - ) -> Optional[Dict[str, Any]]: - """根据提取结果中的实体类型,与场景/通用本体类型做互斥分类统计。 - - 分类规则(互斥):场景类型优先 > 通用类型 > 未匹配 - 确保: 场景实体数 + 通用实体数 + 未匹配数 = 总实体数 - - Returns: - 包含三部分统计的字典,或 None(无实体数据时) - """ - core_entities = extracted_result.get("core_entities", []) - if not core_entities: - return None - - # 1. 加载场景本体类型集合 - scene_ontology_types: set = set() - try: - from app.repositories.ontology_class_repository import OntologyClassRepository - - if memory_config.scene_id: - class_repo = OntologyClassRepository(self.db) - ontology_classes = class_repo.get_classes_by_scene(memory_config.scene_id) - scene_ontology_types = {oc.class_name for oc in ontology_classes} - except Exception as e: - logger.warning(f"Failed to load scene ontology types: {e}") - - # 2. 加载通用本体类型集合 - general_ontology_types: set = set() - try: - from app.core.memory.ontology_services.ontology_type_loader import ( - get_general_ontology_registry, - is_general_ontology_enabled, - ) - - if is_general_ontology_enabled(): - registry = get_general_ontology_registry() - if registry: - general_ontology_types = set(registry.types.keys()) - except Exception as e: - logger.warning(f"Failed to load general ontology types: {e}") - - # 3. 
互斥分类:场景优先 > 通用 > 未匹配 - scene_distribution: list = [] - general_distribution: list = [] - unmatched_distribution: list = [] - scene_total = 0 - general_total = 0 - unmatched_total = 0 - - for item in core_entities: - entity_type = item.get("type", "") - count = item.get("count", 0) - - if entity_type in scene_ontology_types: - scene_distribution.append({"type": entity_type, "count": count}) - scene_total += count - elif entity_type in general_ontology_types: - general_distribution.append({"type": entity_type, "count": count}) - general_total += count - else: - unmatched_distribution.append({"type": entity_type, "count": count}) - unmatched_total += count - - # 按数量降序排列 - scene_distribution.sort(key=lambda x: x["count"], reverse=True) - general_distribution.sort(key=lambda x: x["count"], reverse=True) - unmatched_distribution.sort(key=lambda x: x["count"], reverse=True) - - total_entities = scene_total + general_total + unmatched_total - - return { - "scene_type_distribution": { - "type_count": len(scene_distribution), - "entity_total": scene_total, - "types": scene_distribution, - }, - "general_type_distribution": { - "type_count": len(general_distribution), - "entity_total": general_total, - "types": general_distribution, - }, - "unmatched": { - "type_count": len(unmatched_distribution), - "entity_total": unmatched_total, - "types": unmatched_distribution, - }, - "total_entities": total_entities, - "time": int(time.time() * 1000), - } - - # -------------------- Neo4j Search & Analytics (fused from data_search_service.py) -------------------- # Ensure env for connector (e.g., NEO4J_PASSWORD) diff --git a/api/app/services/pilot_run_service.py b/api/app/services/pilot_run_service.py index 4617946b..5c7da40e 100644 --- a/api/app/services/pilot_run_service.py +++ b/api/app/services/pilot_run_service.py @@ -10,7 +10,9 @@ import time from datetime import datetime from typing import Awaitable, Callable, Optional +from app.core.config import settings from 
app.core.logging_config import get_memory_logger, log_time +from app.core.memory.pipelines.pilot_write_pipeline import PilotWritePipeline from app.core.memory.models.message_models import ( ConversationContext, ConversationMessage, @@ -20,9 +22,11 @@ from app.core.memory.storage_services.extraction_engine.extraction_orchestrator ExtractionOrchestrator, get_chunked_dialogs_from_preprocessed, ) -from app.core.memory.utils.config.config_utils import ( - get_pipeline_config, +from app.core.memory.storage_services.extraction_engine.pipeline_help import ( + _write_extracted_result_summary, + export_test_input_doc, ) +from app.core.memory.utils.config.config_utils import get_pipeline_config from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_config_schema import MemoryConfig @@ -31,6 +35,42 @@ from sqlalchemy.orm import Session logger = get_memory_logger(__name__) +def _save_triplets_from_dialogs(dialog_data_list: list[DialogData], output_path: str) -> None: + """Write triplet/entity text report compatible with pipeline_help parsers.""" + all_triplets = [] + all_entities = [] + + for dialog in dialog_data_list: + for chunk in getattr(dialog, "chunks", []) or []: + for statement in getattr(chunk, "statements", []) or []: + triplet_info = getattr(statement, "triplet_extraction_info", None) + if not triplet_info: + continue + all_triplets.extend(getattr(triplet_info, "triplets", []) or []) + all_entities.extend(getattr(triplet_info, "entities", []) or []) + + with open(output_path, "w", encoding="utf-8") as f: + f.write(f"=== EXTRACTED TRIPLETS ({len(all_triplets)} total) ===\n\n") + for i, triplet in enumerate(all_triplets, 1): + f.write(f"Triplet {i}:\n") + f.write(f" Subject: {triplet.subject_name} (ID: {triplet.subject_id})\n") + f.write(f" Predicate: {triplet.predicate}\n") + f.write(f" Object: {triplet.object_name} (ID: {triplet.object_id})\n") + value = 
getattr(triplet, "value", None) + if value: + f.write(f" Value: {value}\n") + f.write("\n") + + f.write(f"\n=== EXTRACTED ENTITIES ({len(all_entities)} total) ===\n\n") + for i, entity in enumerate(all_entities, 1): + f.write(f"Entity {i}:\n") + f.write(f" ID: {entity.entity_idx}\n") + f.write(f" Name: {entity.name}\n") + f.write(f" Type: {entity.type}\n") + f.write(f" Description: {entity.description}\n") + f.write("\n") + + async def run_pilot_extraction( memory_config: MemoryConfig, dialogue_text: str, @@ -58,7 +98,6 @@ async def run_pilot_extraction( f.write(f"\n=== Pilot Run Started: {timestamp} ===\n") pipeline_start = time.time() - neo4j_connector = None try: # 步骤 1: 初始化客户端 @@ -69,8 +108,6 @@ async def run_pilot_extraction( llm_client = client_factory.get_llm_client(str(memory_config.llm_model_id)) embedder_client = client_factory.get_embedder_client(str(memory_config.embedding_model_id)) - neo4j_connector = Neo4jConnector() - log_time("Client Initialization", time.time() - step_start, log_file) # 步骤 2: 解析对话文本 @@ -242,15 +279,17 @@ async def run_pilot_extraction( log_time("Data Loading & Chunking", time.time() - step_start, log_file) - # 步骤 3: 初始化流水线编排器 - logger.info("Initializing extraction orchestrator...") - step_start = time.time() - - config = get_pipeline_config(memory_config) + # 步骤 3: 初始化并选择试运行流水线(环境变量可切换) + use_refactored = bool(settings.PILOT_RUN_USE_REFACTORED_PIPELINE) logger.info( - f"Pipeline config loaded: enable_llm_dedup_blockwise={config.deduplication.enable_llm_dedup_blockwise}, " - f"enable_llm_disambiguation={config.deduplication.enable_llm_disambiguation}" + "Selecting pilot pipeline by env: PILOT_RUN_USE_REFACTORED_PIPELINE=%s", + use_refactored, ) + logger.info( + "Initializing %s pilot pipeline...", + "refactored" if use_refactored else "legacy", + ) + step_start = time.time() # 加载本体类型(如果配置了 scene_id),支持通用类型回退 ontology_types = None @@ -266,100 +305,105 @@ async def run_pilot_extraction( except Exception as e: logger.warning(f"Failed 
to load ontology types: {e}", exc_info=True) - orchestrator = ExtractionOrchestrator( - llm_client=llm_client, - embedder_client=embedder_client, - connector=neo4j_connector, - config=config, - progress_callback=progress_callback, - embedding_id=str(memory_config.embedding_model_id), - language=language, - ontology_types=ontology_types, - ) + if use_refactored: + pilot_pipeline = PilotWritePipeline( + llm_client=llm_client, + embedder_client=embedder_client, + pipeline_config=get_pipeline_config(memory_config), + progress_callback=progress_callback, + embedding_id=str(memory_config.embedding_model_id), + language=language, + ontology_types=ontology_types, + ) + log_time("Pilot Pipeline Initialization", time.time() - step_start, log_file) - log_time("Orchestrator Initialization", time.time() - step_start, log_file) + # 步骤 4a: 执行重构后试运行短链路 + # statement -> triplet -> graph_build -> 第一层去重消歧(结束) + logger.info("Running refactored pilot extraction short pipeline...") + step_start = time.time() - # 步骤 4: 执行知识提取流水线 - logger.info("Running extraction pipeline...") - step_start = time.time() + if progress_callback: + await progress_callback("knowledge_extraction", "正在知识抽取...") - if progress_callback: - await progress_callback("knowledge_extraction", "正在知识抽取...") + pilot_result = await pilot_pipeline.run(chunked_dialogs) + dialog_data_list = pilot_result.dialog_data_list + graph = pilot_result.graph + chunk_nodes = graph.chunk_nodes + export_entity_nodes = graph.entity_nodes + export_stmt_entity_edges = graph.stmt_entity_edges + export_entity_edges = graph.entity_entity_edges + else: + # 步骤 4b: 执行旧试运行流水线 + logger.info("Running legacy pilot extraction pipeline...") + step_start = time.time() - extraction_result = await orchestrator.run( - dialog_data_list=chunked_dialogs, - is_pilot_run=True, - ) + if progress_callback: + await progress_callback("knowledge_extraction", "正在知识抽取...") - # 解包 extraction_result tuple (与 main.py 保持一致) - ( - dialogue_nodes, - chunk_nodes, - 
statement_nodes, - entity_nodes, - _, - statement_chunk_edges, - statement_entity_edges, - entity_edges, - _, - _ - ) = extraction_result + neo4j_connector = Neo4jConnector() + try: + legacy_orchestrator = ExtractionOrchestrator( + llm_client=llm_client, + embedder_client=embedder_client, + connector=neo4j_connector, + config=get_pipeline_config(memory_config), + progress_callback=progress_callback, + embedding_id=str(memory_config.embedding_model_id), + language=language, + ontology_types=ontology_types, + ) + extraction_result = await legacy_orchestrator.run( + dialog_data_list=chunked_dialogs, + is_pilot_run=True, + ) + ( + _dialogue_nodes, + chunk_nodes, + _statement_nodes, + entity_nodes, + _perceptual_nodes, + _statement_chunk_edges, + statement_entity_edges, + entity_edges, + _perceptual_edges, + _last_created_at, + ) = extraction_result + dialog_data_list = chunked_dialogs + export_entity_nodes = entity_nodes + export_stmt_entity_edges = statement_entity_edges + export_entity_edges = entity_edges + finally: + try: + await neo4j_connector.close() + except Exception: + pass log_time("Extraction Pipeline", time.time() - step_start, log_file) if progress_callback: await progress_callback("generating_results", "正在生成结果...") - # 步骤 5: 生成记忆摘要(与 main.py 保持一致) - try: - logger.info("Generating memory summaries...") - step_start = time.time() + # 步骤 5: 输出试运行结果文件(保持 /pilot_run 返回契约) + settings.ensure_memory_output_dir() + export_test_input_doc( + entity_nodes=export_entity_nodes, + statement_entity_edges=export_stmt_entity_edges, + entity_entity_edges=export_entity_edges, + ) + _save_triplets_from_dialogs( + dialog_data_list=dialog_data_list, + output_path=settings.get_memory_output_path("extracted_triplets.txt"), + ) + _write_extracted_result_summary( + chunk_nodes=chunk_nodes, + pipeline_output_dir=settings.get_memory_output_path(), + ) - from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import ( - memory_summary_generation, - 
) - - summaries = await memory_summary_generation( - chunked_dialogs, - llm_client=llm_client, - embedder_client=embedder_client, - language=language, - ) - - log_time("Memory Summary Generation", time.time() - step_start, log_file) - except Exception as e: - logger.error(f"Memory summary step failed: {e}", exc_info=True) - - logger.info("Pilot run completed: Skipping Neo4j save") - - # 将提取统计写入 Redis,按 workspace_id 存储 - try: - from app.cache.memory.activity_stats_cache import ActivityStatsCache - - stats_to_cache = { - "chunk_count": len(chunk_nodes) if chunk_nodes else 0, - "statements_count": len(statement_nodes) if statement_nodes else 0, - "triplet_entities_count": len(entity_nodes) if entity_nodes else 0, - "triplet_relations_count": len(entity_edges) if entity_edges else 0, - "temporal_count": 0, # temporal 数据在日志中,此处暂置0 - } - await ActivityStatsCache.set_activity_stats( - workspace_id=str(memory_config.workspace_id), - stats=stats_to_cache, - ) - logger.info(f"[PILOT_RUN] 活动统计已写入 Redis: workspace_id={memory_config.workspace_id}") - except Exception as cache_err: - logger.warning(f"[PILOT_RUN] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True) + logger.info("Pilot run completed: stop after layer-1 dedup (no layer-2 / no Neo4j write)") except Exception as e: logger.error(f"Pilot run failed: {e}", exc_info=True) raise - finally: - if neo4j_connector: - try: - await neo4j_connector.close() - except Exception: - pass total_time = time.time() - pipeline_start log_time("TOTAL PILOT RUN TIME", total_time, log_file) diff --git a/api/app/tasks.py b/api/app/tasks.py index b7de2fd2..54ebe80f 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1382,6 +1382,7 @@ def extract_emotion_batch_task( llm_model_id: str, language: str = "zh", emotion_config: Optional[Dict[str, Any]] = None, + snapshot_dir: Optional[str] = None, ) -> Dict[str, Any]: """Celery task: batch emotion extraction + Neo4j backfill. 
@@ -1395,6 +1396,10 @@ def extract_emotion_batch_task( language: Language code ("zh" / "en"). emotion_config: Optional dict with emotion step config overrides (emotion_extract_keywords, emotion_enable_subject). + snapshot_dir: Optional absolute path of the current run's snapshot directory. + When provided (only in debug mode), emotion outputs will be + dumped to {snapshot_dir}/4_emotion_outputs.json for offline + comparison between the legacy / new pipelines. """ task_id = self.request.id total = len(statements) @@ -1445,6 +1450,8 @@ def extract_emotion_batch_task( extracted = 0 failed = 0 update_items = [] + # 快照用:收集每条 statement 的 EmotionStepOutput(仅当 snapshot_dir 非空时使用) + snapshot_outputs: Dict[str, Any] = {} if snapshot_dir else None # type: ignore[assignment] async def _extract_one(stmt_dict: Dict[str, str]): nonlocal extracted, failed @@ -1461,6 +1468,8 @@ def extract_emotion_batch_task( "emotion_intensity": result.emotion_intensity, "emotion_keywords": result.emotion_keywords, }) + if snapshot_outputs is not None: + snapshot_outputs[stmt_dict["statement_id"]] = result.model_dump() extracted += 1 logger.debug( f"[Emotion] 单条提取完成: stmt={stmt_dict['statement_id']}, " @@ -1468,12 +1477,33 @@ def extract_emotion_batch_task( ) except Exception as e: failed += 1 + if snapshot_outputs is not None: + snapshot_outputs[stmt_dict["statement_id"]] = {"error": str(e)} logger.warning( f"[Emotion] 单条提取失败 stmt={stmt_dict['statement_id']}: {e}" ) await asyncio.gather(*[_extract_one(s) for s in statements]) + # 快照落盘(worker 端):不影响 Neo4j 写入流程,失败只打日志 + if snapshot_outputs is not None: + try: + from pathlib import Path as _Path + import json as _json + + _dir = _Path(snapshot_dir) + _dir.mkdir(parents=True, exist_ok=True) + _path = _dir / "4_emotion_outputs.json" + with open(_path, "w", encoding="utf-8") as _f: + _json.dump(snapshot_outputs, _f, ensure_ascii=False, indent=2, default=str) + logger.info( + f"[Emotion][Snapshot] 已落盘 {len(snapshot_outputs)} 条情绪结果 → {_path}" + ) + except Exception as
_e: + logger.warning( + f"[Emotion][Snapshot] 快照落盘失败(不影响主流程): {_e}" + ) + # Batch update Neo4j via write transaction if update_items: connector = Neo4jConnector()