diff --git a/api/app/core/memory/agent/utils/get_dialogs.py b/api/app/core/memory/agent/utils/get_dialogs.py index 4c667061..3cd1fa0a 100644 --- a/api/app/core/memory/agent/utils/get_dialogs.py +++ b/api/app/core/memory/agent/utils/get_dialogs.py @@ -1,7 +1,4 @@ -import os -import json from typing import List -from datetime import datetime from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import DialogueChunker from app.core.memory.models.message_models import DialogData, ConversationContext, ConversationMessage @@ -34,6 +31,7 @@ async def get_chunked_dialogs( conversation_messages = [] +# step1: 消息格式校验 role:user、assistant。content for idx, msg in enumerate(messages): if not isinstance(msg, dict) or 'role' not in msg or 'content' not in msg: raise ValueError(f"Message {idx} format error: must contain 'role' and 'content' fields") @@ -59,7 +57,7 @@ async def get_chunked_dialogs( config_id=config_id ) - # 语义剪枝步骤(在分块之前) +# step2: 语义剪枝步骤(在分块之前) try: from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import SemanticPruner from app.core.memory.models.config_models import PruningConfig @@ -116,6 +114,7 @@ async def get_chunked_dialogs( except Exception as e: logger.warning(f"[剪枝] 执行失败,跳过剪枝: {e}", exc_info=True) +# step3: 分块 chunker = DialogueChunker(chunker_strategy) extracted_chunks = await chunker.process_dialogue(dialog_data) dialog_data.chunks = extracted_chunks diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index 3b0ea1ee..473e9189 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -147,7 +147,85 @@ async def write( all_perceptual_edges, all_dedup_details, ) = await orchestrator.run(chunked_dialogs, is_pilot_run=False) + +# region TODO 乐力齐 重构流水线切换至生产环境稳定后,移除快照对比代码 + # ── Snapshot: 旧流水线萃取结果(按 phase2_step_io_schema_v1.md 格式) ── + from 
app.core.memory.utils.debug.pipeline_snapshot import PipelineSnapshot + snapshot = PipelineSnapshot("legacy") + # Statement 输出(从 dialog_data_list 中提取) + stmt_snapshot = [] + for d in all_dedup_details: + if not hasattr(d, "chunks"): + continue + for c in d.chunks: + for s in c.statements: + stmt_snapshot.append({ + "statement_id": s.id, + "statement_text": s.statement, + "statement_type": str(getattr(s, "stmt_type", "")), + "temporal_type": str(getattr(s, "temporal_info", "")), + "relevance": str(getattr(s, "relevence_info", "RELEVANT")), + "speaker": getattr(s, "speaker", "user") or "user", + "valid_at": s.temporal_validity.valid_at if s.temporal_validity else "NULL", + "invalid_at": s.temporal_validity.invalid_at if s.temporal_validity else "NULL", + }) + snapshot.save_stage("2_statement_outputs", stmt_snapshot) + + # Triplet 输出(从 dialog_data_list 中提取) + triplet_snapshot = {} + for d in all_dedup_details: + if not hasattr(d, "chunks"): + continue + for c in d.chunks: + for s in c.statements: + if s.triplet_extraction_info: + triplet_snapshot[s.id] = { + "entities": [ + { + "entity_idx": e.entity_idx, "name": e.name, + "type": e.type, "description": e.description, + "is_explicit_memory": getattr(e, "is_explicit_memory", False), + } + for e in s.triplet_extraction_info.entities + ], + "triplets": [ + { + "subject_name": t.subject_name, "subject_id": t.subject_id, + "predicate": t.predicate, + "object_name": t.object_name, "object_id": t.object_id, + } + for t in s.triplet_extraction_info.triplets + ], + } + snapshot.save_stage("3_triplet_outputs", triplet_snapshot) + + # 图节点和边(去重后) + snapshot.save_stage("6_nodes_edges_after_dedup", { + "dialogue_nodes_count": len(all_dialogue_nodes), + "chunk_nodes_count": len(all_chunk_nodes), + "statement_nodes_count": len(all_statement_nodes), + "entity_nodes": [ + {"id": e.id, "name": e.name, "entity_type": e.entity_type, "description": e.description} + for e in all_entity_nodes + ], + "entity_entity_edges": [ + { + "source": 
e.source, "target": e.target, + "relation_type": e.relation_type, "statement": e.statement, + } + for e in all_entity_entity_edges + ], + }) + snapshot.save_summary({ + "dialogue_count": len(all_dialogue_nodes), + "chunk_count": len(all_chunk_nodes), + "statement_count": len(all_statement_nodes), + "entity_count": len(all_entity_nodes), + "relation_count": len(all_entity_entity_edges), + }) +# endregion + log_time("Extraction Pipeline", time.time() - step_start, log_file) # Step 3: Save all data to Neo4j database diff --git a/api/app/core/memory/models/variate_config.py b/api/app/core/memory/models/variate_config.py index 24abd39c..a743a554 100644 --- a/api/app/core/memory/models/variate_config.py +++ b/api/app/core/memory/models/variate_config.py @@ -149,3 +149,5 @@ class ExtractionPipelineConfig(BaseModel): temporal_extraction: TemporalExtractionConfig = Field(default_factory=TemporalExtractionConfig) deduplication: DedupConfig = Field(default_factory=DedupConfig) forgetting_engine: ForgettingEngineConfig = Field(default_factory=ForgettingEngineConfig) + # 情绪引擎(旁路模块,SidecarStepFactory 通过此字段判断是否启用) + emotion_enabled: bool = Field(default=False, description="是否启用情绪提取旁路") diff --git a/api/app/core/memory/pipelines/write_pipeline.py b/api/app/core/memory/pipelines/write_pipeline.py index 194ecdeb..cc30df7d 100644 --- a/api/app/core/memory/pipelines/write_pipeline.py +++ b/api/app/core/memory/pipelines/write_pipeline.py @@ -12,20 +12,33 @@ WritePipeline — 记忆写入流水线 依赖方向:Facade → Pipeline → Engine → Repository(单向,不允许反向调用) """ + from __future__ import annotations import asyncio import logging import time import uuid -from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional +from pydantic import BaseModel, Field, ConfigDict + if TYPE_CHECKING: - from app.core.memory.models.graph_models import ExtractedEntityNode from app.core.memory.models.message_models import DialogData from app.schemas.memory_config_schema import 
MemoryConfig +from app.core.memory.models.graph_models import ( + ChunkNode, + DialogueNode, + EntityEntityEdge, + ExtractedEntityNode, + PerceptualEdge, + PerceptualNode, + StatementChunkEdge, + StatementEntityEdge, + StatementNode, +) + logger = logging.getLogger(__name__) @@ -34,36 +47,40 @@ logger = logging.getLogger(__name__) # ────────────────────────────────────────────── -@dataclass -class ExtractionResult: - """萃取步骤的结构化输出,替代 ExtractionOrchestrator.run() 返回的裸元组。 +class ExtractionResult(BaseModel): + """萃取 + 图构建 + 去重消歧后的结构化输出。 - 字段与 ExtractionOrchestrator.run() 的 10 元素返回值一一对应: - [0] dialogue_nodes → self.dialogue_nodes - [1] chunk_nodes → self.chunk_nodes - [2] statement_nodes → self.statement_nodes - [3] entity_nodes → self.entity_nodes - [4] perceptual_nodes → self.perceptual_nodes - [5] stmt_chunk_edges → self.stmt_chunk_edges - [6] stmt_entity_edges → self.stmt_entity_edges - [7] entity_entity_edges → self.entity_entity_edges - [8] perceptual_edges → self.perceptual_edges - [9] dialog_data_list → self.dialog_data_list + 作为 Pipeline 层的阶段间数据载体,确保下游步骤(_store、_cluster) + 接收到的图节点和边结构完整、类型正确。 - 注意:字段类型使用 List[Any] 而非具体的 graph_models 类型, - 避免在模块加载时触发循环依赖。Pipeline 只做数据传递,不检查具体类型。 + 字段对应 ExtractionOrchestrator 产出的图节点/边: + dialogue_nodes — 对话节点 + chunk_nodes — 分块节点 + statement_nodes — 陈述句节点 + entity_nodes — 实体节点(去重消歧后) + perceptual_nodes — 感知节点 + stmt_chunk_edges — 陈述句 → 分块 边 + stmt_entity_edges — 陈述句 → 实体 边 + entity_entity_edges — 实体 → 实体 边(去重消歧后) + perceptual_edges — 感知 → 分块 边 + dialog_data_list — 原始 DialogData(供摘要阶段使用) """ - dialogue_nodes: List[Any] - chunk_nodes: List[Any] - statement_nodes: List[Any] - entity_nodes: List[Any] - perceptual_nodes: List[Any] - stmt_chunk_edges: List[Any] - stmt_entity_edges: List[Any] - entity_entity_edges: List[Any] - perceptual_edges: List[Any] - dialog_data_list: List[Any] + model_config = ConfigDict(arbitrary_types_allowed=True) + + dialogue_nodes: List[DialogueNode] + chunk_nodes: List[ChunkNode] + statement_nodes: 
List[StatementNode] + entity_nodes: List[ExtractedEntityNode] + perceptual_nodes: List[PerceptualNode] + stmt_chunk_edges: List[StatementChunkEdge] + stmt_entity_edges: List[StatementEntityEdge] + entity_entity_edges: List[EntityEntityEdge] + perceptual_edges: List[PerceptualEdge] + dialog_data_list: List[Any] = Field( + default_factory=list, + description="原始 DialogData 列表,类型为 Any 以避免循环依赖", + ) @property def stats(self) -> Dict[str, int]: @@ -78,8 +95,7 @@ class ExtractionResult: } -@dataclass -class WriteResult: +class WriteResult(BaseModel): """写入流水线的最终输出,返回给 MemoryService / MemoryAgentService""" status: str # "success" | "pilot_complete" | "failed" @@ -114,7 +130,7 @@ class WritePipeline: memory_config: 不可变的记忆配置对象(从数据库加载) end_user_id: 终端用户 ID language: 语言 ("zh" | "en") - progress_callback: 可选的进度回调,签名 (stage, message, data?) -> Awaitable[None] + progress_callback: 可选的进度回调,签名 (stage, message, data?) -> Awaitable[None] 供pilot run使用 """ self.memory_config = memory_config self.end_user_id = end_user_id @@ -145,7 +161,7 @@ class WritePipeline: is_pilot_run: 试运行模式(只萃取不写入) Returns: - WriteResult 包含状态和统计信息 + WriteResult 包含状态和统计信息 """ if not ref_id: ref_id = uuid.uuid4().hex @@ -164,7 +180,7 @@ class WritePipeline: self._init_clients() self._init_neo4j_connector() - # Step 1: 预处理 - 消息分块 + # Step 1: 预处理 - 消息分块 + AI消息语义剪枝(暂无实现) step_start = time.time() chunked_dialogs = await self._preprocess(messages, ref_id) chunks_count = sum(len(d.chunks) for d in chunked_dialogs) @@ -175,9 +191,7 @@ class WritePipeline: # Step 2: 萃取 - 知识提取 step_start = time.time() - extraction_result = await self._extract( - chunked_dialogs, is_pilot_run - ) + extraction_result = await self._extract(chunked_dialogs, is_pilot_run) stats = extraction_result.stats logger.info( f"[WritePipeline] [2/5] 萃取:知识提取 " @@ -190,9 +204,7 @@ class WritePipeline: # 试运行模式到此结束 if is_pilot_run: elapsed = time.time() - pipeline_start - logger.info( - f"[WritePipeline] 完成(试运行) ✔ {elapsed:.2f}s" - ) + 
logger.info(f"[WritePipeline] 完成(试运行) ✔ {elapsed:.2f}s") return WriteResult( status="pilot_complete", extraction=extraction_result.stats, @@ -227,9 +239,7 @@ class WritePipeline: await self._update_stats_cache(extraction_result) elapsed = time.time() - pipeline_start - logger.info( - f"[WritePipeline] 完成 ✔ {elapsed:.2f}s" - ) + logger.info(f"[WritePipeline] 完成 ✔ {elapsed:.2f}s") return WriteResult( status="success", extraction=extraction_result.stats, @@ -251,16 +261,14 @@ class WritePipeline: # Step 1: 预处理 # ────────────────────────────────────────────── - async def _preprocess( - self, messages: List[dict], ref_id: str - ) -> List[DialogData]: + async def _preprocess(self, messages: List[dict], ref_id: str) -> List[DialogData]: """ - 预处理:消息校验 → 语义剪枝 → 对话分块。 + 预处理:消息校验 → AI消息语义剪枝(暂未实现) → 对话分块。 委托给 get_chunked_dialogs(),保持现有预处理逻辑不变。 get_dialogs.py 内部已包含: - 消息格式校验(role/content 必填) - - 语义剪枝(根据 config 中 pruning_enabled 决定) + - AI消息语义剪枝(根据 config 中 pruning_enabled 决定) - DialogueChunker 分块 """ from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs @@ -283,56 +291,187 @@ class WritePipeline: is_pilot_run: bool, ) -> ExtractionResult: """ - 萃取:初始化引擎 → 执行知识提取 → 返回结构化结果。 + 萃取:初始化引擎 → 执行知识提取 → 构建图节点/边 → 去重 → 返回结构化结果。 - ExtractionOrchestrator 作为萃取引擎被调用, - Pipeline 不关心引擎内部的并行策略和提取细节。 + 使用 NewExtractionOrchestrator(ExtractionStep 范式)完成 LLM 萃取, + 然后通过独立的 graph_build_step 和 dedup_step 完成图构建和去重, + 不依赖旧编排器 ExtractionOrchestrator。 + + 执行流程: + 1. NewExtractionOrchestrator.run() → 萃取并赋值到 DialogData + 2. build_graph_nodes_and_edges() → 从 DialogData 构建图节点和边 + 3. 
run_dedup() → 两阶段去重消歧 """ - from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ( - ExtractionOrchestrator, + from app.core.memory.storage_services.extraction_engine.dedup_step import ( + run_dedup, ) + from app.core.memory.storage_services.extraction_engine.steps.graph_build_step import ( + build_graph_nodes_and_edges, + ) + from app.core.memory.storage_services.extraction_engine.steps.extraction_pipeline_orchestrator import ( + NewExtractionOrchestrator, + ) + from app.core.memory.utils.config.config_utils import get_pipeline_config + from app.core.memory.utils.debug.pipeline_snapshot import PipelineSnapshot pipeline_config = get_pipeline_config(self.memory_config) ontology_types = self._load_ontology_types() - orchestrator = ExtractionOrchestrator( + snapshot = PipelineSnapshot("new") + + # ── 新编排器:LLM 萃取 + 数据赋值 ── + new_orchestrator = NewExtractionOrchestrator( llm_client=self._llm_client, embedder_client=self._embedder_client, - connector=self._neo4j_connector, config=pipeline_config, embedding_id=str(self.memory_config.embedding_model_id), - language=self.language, ontology_types=ontology_types, + language=self.language, + is_pilot_run=is_pilot_run, + progress_callback=self.progress_callback, + ) + # step1: 执行知识提取 + dialog_data_list = await new_orchestrator.run(chunked_dialogs) + + # ── Snapshot: 各阶段萃取结果 ── TODO 乐力齐 重构流水线切换生产环境稳定后修改 + stage_outputs = new_orchestrator.last_stage_outputs + if stage_outputs: + stmt_results = stage_outputs.get("statement_results", {}) + stmt_snapshot = [] + for _did, chunk_stmts in stmt_results.items(): + for _cid, stmts in chunk_stmts.items(): + for s in stmts: + stmt_snapshot.append(s.model_dump()) + snapshot.save_stage("2_statement_outputs", stmt_snapshot) + + triplet_results = stage_outputs.get("triplet_results", {}) + triplet_snapshot = {} + for _did, stmt_triplets in triplet_results.items(): + for stmt_id, t_out in stmt_triplets.items(): + triplet_snapshot[stmt_id] = t_out.model_dump() + 
snapshot.save_stage("3_triplet_outputs", triplet_snapshot) + + emotion_results = stage_outputs.get("emotion_results", {}) + emotion_snapshot = {} + for stmt_id, emo in emotion_results.items(): + if hasattr(emo, "model_dump"): + emotion_snapshot[stmt_id] = emo.model_dump() + snapshot.save_stage("4_emotion_outputs", emotion_snapshot) + + emb_output = stage_outputs.get("embedding_output") + if emb_output and hasattr(emb_output, "model_dump"): + emb_data = emb_output.model_dump() + for key in ( + "statement_embeddings", + "chunk_embeddings", + "entity_embeddings", + ): + if key in emb_data and isinstance(emb_data[key], dict): + emb_data[key] = { + k: v[:5] if isinstance(v, list) else v + for k, v in emb_data[key].items() + } + if "dialog_embeddings" in emb_data and isinstance( + emb_data["dialog_embeddings"], list + ): + emb_data["dialog_embeddings"] = [ + v[:5] if isinstance(v, list) else v + for v in emb_data["dialog_embeddings"] + ] + snapshot.save_stage("5_embedding_outputs", emb_data) + + # step2: 构建图节点和边 + graph = await build_graph_nodes_and_edges( + dialog_data_list=dialog_data_list, + embedder_client=self._embedder_client, progress_callback=self.progress_callback, ) - ( - dialogue_nodes, - chunk_nodes, - statement_nodes, - entity_nodes, - perceptual_nodes, - stmt_chunk_edges, - stmt_entity_edges, - entity_entity_edges, - perceptual_edges, - dialog_data_list, - ) = await orchestrator.run(chunked_dialogs, is_pilot_run=is_pilot_run) + # region Snapshot: 图节点和边(去重前)Snapshot有关的内容在重构流水线切换生产环境之后修改 + snapshot.save_stage( + "6_nodes_edges_before_dedup", + { + "dialogue_nodes_count": len(graph.dialogue_nodes), + "chunk_nodes_count": len(graph.chunk_nodes), + "statement_nodes_count": len(graph.statement_nodes), + "entity_nodes": [ + { + "id": e.id, + "name": e.name, + "entity_type": e.entity_type, + "description": e.description, + } + for e in graph.entity_nodes + ], + "entity_entity_edges": [ + { + "source": e.source, + "target": e.target, + "relation_type": 
e.relation_type, + "statement": e.statement, + } + for e in graph.entity_entity_edges + ], + "stmt_entity_edges_count": len(graph.stmt_entity_edges), + }, + ) - return ExtractionResult( - dialogue_nodes=dialogue_nodes, - chunk_nodes=chunk_nodes, - statement_nodes=statement_nodes, - entity_nodes=entity_nodes, - perceptual_nodes=perceptual_nodes, - stmt_chunk_edges=stmt_chunk_edges, - stmt_entity_edges=stmt_entity_edges, - entity_entity_edges=entity_entity_edges, - perceptual_edges=perceptual_edges, + # step3: 两阶段去重消歧 + dedup_result = await run_dedup( + entity_nodes=graph.entity_nodes, + statement_entity_edges=graph.stmt_entity_edges, + entity_entity_edges=graph.entity_entity_edges, + dialog_data_list=dialog_data_list, + pipeline_config=pipeline_config, + connector=self._neo4j_connector, + llm_client=self._llm_client, + is_pilot_run=is_pilot_run, + progress_callback=self.progress_callback, + ) + + # Snapshot: 去重后 + snapshot.save_stage( + "7_after_dedup", + { + "entity_nodes": [ + { + "id": e.id, + "name": e.name, + "entity_type": e.entity_type, + "description": e.description, + } + for e in dedup_result.entity_nodes + ], + "entity_entity_edges": [ + { + "source": e.source, + "target": e.target, + "relation_type": e.relation_type, + "statement": e.statement, + } + for e in dedup_result.entity_entity_edges + ], + }, + ) + + # step4: 构造最终结果 + result = ExtractionResult( + dialogue_nodes=graph.dialogue_nodes, + chunk_nodes=graph.chunk_nodes, + statement_nodes=graph.statement_nodes, + entity_nodes=dedup_result.entity_nodes, + perceptual_nodes=graph.perceptual_nodes, + stmt_chunk_edges=graph.stmt_chunk_edges, + stmt_entity_edges=dedup_result.statement_entity_edges, + entity_entity_edges=dedup_result.entity_entity_edges, + perceptual_edges=graph.perceptual_edges, dialog_data_list=dialog_data_list, ) + snapshot.save_summary(result.stats) # TODO 乐力齐 snapshot需要改 + return result + # ────────────────────────────────────────────── # Step 3: 存储 # 
────────────────────────────────────────────── @@ -379,14 +518,10 @@ class WritePipeline: ) await asyncio.sleep(1 * (attempt + 1)) else: - logger.error( - f"Neo4j 写入在 {max_retries} 次尝试后仍部分失败" - ) + logger.error(f"Neo4j 写入在 {max_retries} 次尝试后仍部分失败") except Exception as e: if self._is_deadlock(e) and attempt < max_retries - 1: - logger.warning( - f"Neo4j 死锁,重试 ({attempt + 2}/{max_retries})" - ) + logger.warning(f"Neo4j 死锁,重试 ({attempt + 2}/{max_retries})") await asyncio.sleep(1 * (attempt + 1)) else: raise @@ -401,6 +536,10 @@ class WritePipeline: 聚类不阻塞主写入流程,失败不影响写入结果。 通过 Celery 异步执行,由 LabelPropagationEngine 完成实际计算。 + + 注意:ExtractionResult.entity_nodes 已经是经过 _extract() 中 + 两阶段去重消歧(_run_dedup_and_write_summary)后的结果, + 聚类直接基于去重后的实体 ID 执行。 """ if not result.entity_nodes: return @@ -428,7 +567,9 @@ class WritePipeline: ) logger.info( f"[Clustering] 增量聚类任务已提交 - " - f"task_id={task.id}, entity_count={len(new_entity_ids)}" + f"task_id={task.id}, " + f"entity_count={len(new_entity_ids)}, " + f"source=dedup" ) except Exception as e: logger.error( @@ -438,9 +579,9 @@ class WritePipeline: # ────────────────────────────────────────────── # Step 5: 摘要 - # (+ entity_description) + # (+ entity_description)+ meta_data部分在此提取 # ────────────────────────────────────────────── - +# TODO 乐力齐 需要做成异步celery任务 async def _summarize(self, chunked_dialogs: List[DialogData]) -> None: """ 摘要:生成情景记忆摘要 → 写入 Neo4j。 @@ -467,9 +608,7 @@ class WritePipeline: ms_connector = Neo4jConnector() try: await add_memory_summary_nodes(summaries, ms_connector) - await add_memory_summary_statement_edges( - summaries, ms_connector - ) + await add_memory_summary_statement_edges(summaries, ms_connector) finally: try: await ms_connector.close() @@ -494,9 +633,7 @@ class WritePipeline: with get_db_context() as db: factory = MemoryClientFactory(db) - self._llm_client = factory.get_llm_client_from_config( - self.memory_config - ) + self._llm_client = factory.get_llm_client_from_config(self.memory_config) 
self._embedder_client = factory.get_embedder_client_from_config( self.memory_config ) @@ -564,10 +701,8 @@ class WritePipeline: if entity_nodes: eu_id = entity_nodes[0].end_user_id if eu_id: - neo4j_assistant_aliases = ( - await fetch_neo4j_assistant_aliases( - self._neo4j_connector, eu_id - ) + neo4j_assistant_aliases = await fetch_neo4j_assistant_aliases( + self._neo4j_connector, eu_id ) clean_cross_role_aliases( entity_nodes, @@ -586,9 +721,7 @@ class WritePipeline: msg = str(e).lower() return "deadlockdetected" in msg or "deadlock" in msg - async def _update_stats_cache( - self, result: ExtractionResult - ) -> None: + async def _update_stats_cache(self, result: ExtractionResult) -> None: """ 将提取统计写入 Redis 活动缓存,按 workspace_id 存储。 失败不中断主流程。 @@ -614,9 +747,7 @@ class WritePipeline: f"workspace_id={self.memory_config.workspace_id}" ) except Exception as e: - logger.warning( - f"写入活动统计缓存失败(不影响主流程): {e}" - ) + logger.warning(f"写入活动统计缓存失败(不影响主流程): {e}") async def _cleanup(self) -> None: """ @@ -634,16 +765,14 @@ class WritePipeline: # 防止 'RuntimeError: Event loop is closed' 在垃圾回收时触发 for client_obj in (self._llm_client, self._embedder_client): try: - underlying = getattr( - client_obj, "client", None - ) or getattr(client_obj, "model", None) + underlying = getattr(client_obj, "client", None) or getattr( + client_obj, "model", None + ) if underlying is None: continue inner = getattr(underlying, "_model", underlying) http_client = getattr(inner, "async_client", None) - if http_client is not None and hasattr( - http_client, "aclose" - ): + if http_client is not None and hasattr(http_client, "aclose"): await http_client.aclose() except Exception: pass diff --git a/api/app/core/memory/storage_services/extraction_engine/dedup_step.py b/api/app/core/memory/storage_services/extraction_engine/dedup_step.py new file mode 100644 index 00000000..69718803 --- /dev/null +++ b/api/app/core/memory/storage_services/extraction_engine/dedup_step.py @@ -0,0 +1,506 @@ +"""Independent 
deduplication module for the extraction pipeline. + +Extracts dedup logic from ExtractionOrchestrator into standalone functions +so the orchestrator stays thin and dedup can be tested/evolved independently. + +The module exposes: + - ``DedupResult`` — structured output of the dedup process + - ``run_dedup()`` — async entry point called by WritePipeline + - Helper functions migrated from ExtractionOrchestrator: + ``save_dedup_details``, ``analyze_entity_merges``, + ``analyze_entity_disambiguation``, ``send_dedup_progress_callback``, + ``parse_dedup_report`` +""" + +from __future__ import annotations + +import logging +import re +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional, Tuple + +from app.core.memory.models.graph_models import ( + EntityEntityEdge, + ExtractedEntityNode, + StatementEntityEdge, +) +from app.core.memory.models.message_models import DialogData +from app.core.memory.models.variate_config import ExtractionPipelineConfig +from app.repositories.neo4j.neo4j_connector import Neo4jConnector + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# DedupResult dataclass (Requirement 10.2) +# --------------------------------------------------------------------------- + +@dataclass +class DedupResult: + """Structured output of the two-stage entity deduplication process. + + Attributes: + entity_nodes: Deduplicated entity node list. + statement_entity_edges: Deduplicated statement-entity edges. + entity_entity_edges: Deduplicated entity-entity edges. + dedup_details: Raw detail dict returned by the first-layer dedup. + merge_records: Parsed merge records (exact / fuzzy / LLM). + disamb_records: Parsed disambiguation records. 
+ """ + + entity_nodes: List[ExtractedEntityNode] + statement_entity_edges: List[StatementEntityEdge] + entity_entity_edges: List[EntityEntityEdge] + dedup_details: Dict[str, Any] = field(default_factory=dict) + merge_records: List[Dict[str, Any]] = field(default_factory=list) + disamb_records: List[Dict[str, Any]] = field(default_factory=list) + + @property + def stats(self) -> Dict[str, int]: + """Summary statistics for the dedup run.""" + return { + "entity_count": len(self.entity_nodes), + "merge_count": len(self.merge_records), + "disamb_count": len(self.disamb_records), + } + + +# --------------------------------------------------------------------------- +# Migrated helpers (from ExtractionOrchestrator) — Requirement 10.4 +# --------------------------------------------------------------------------- + + +def save_dedup_details( + dedup_details: Dict[str, Any], + original_entities: List[ExtractedEntityNode], + final_entities: List[ExtractedEntityNode], +) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], Dict[str, str]]: + """Parse raw *dedup_details* into structured merge / disamb records. 
+ + Returns: + (merge_records, disamb_records, id_redirect_map) + """ + merge_records: List[Dict[str, Any]] = [] + disamb_records: List[Dict[str, Any]] = [] + id_redirect_map: Dict[str, str] = {} + + try: + id_redirect_map = dedup_details.get("id_redirect", {}) + + # --- exact-match merges --- + exact_merge_map = dedup_details.get("exact_merge_map", {}) + for _key, info in exact_merge_map.items(): + merged_ids = info.get("merged_ids", set()) + if merged_ids: + merge_records.append({ + "type": "精确匹配", + "canonical_id": info.get("canonical_id"), + "entity_name": info.get("name"), + "entity_type": info.get("entity_type"), + "merged_count": len(merged_ids), + "merged_ids": list(merged_ids), + }) + + # --- fuzzy-match merges --- + for record in dedup_details.get("fuzzy_merge_records", []): + try: + match = re.search( + r"规范实体 (\S+) \(([^|]+)\|([^|]+)\|([^)]+)\) <- 合并实体 (\S+)", + record, + ) + if match: + merge_records.append({ + "type": "模糊匹配", + "canonical_id": match.group(1), + "entity_name": match.group(3), + "entity_type": match.group(4), + "merged_count": 1, + "merged_ids": [match.group(5)], + }) + except Exception as e: + logger.debug("解析模糊匹配记录失败: %s, 错误: %s", record, e) + + # --- LLM-based merges --- + for record in dedup_details.get("llm_decision_records", []): + if "[LLM去重]" in str(record): + try: + match = re.search( + r"同名类型相似 ([^(]+)(([^)]+))\|([^(]+)(([^)]+))", + record, + ) + if match: + merge_records.append({ + "type": "LLM去重", + "entity_name": match.group(1), + "entity_type": f"{match.group(2)}|{match.group(4)}", + "merged_count": 1, + "merged_ids": [], + }) + except Exception as e: + logger.debug("解析LLM去重记录失败: %s, 错误: %s", record, e) + + # --- disambiguation records --- + for record in dedup_details.get("disamb_records", []): + if "[DISAMB阻断]" in str(record): + try: + content = str(record).replace("[DISAMB阻断]", "").strip() + match = re.search( + r"([^(]+)(([^)]+))\|([^(]+)(([^)]+))", content + ) + if match: + entity1_name = match.group(1).strip() + 
entity1_type = match.group(2) + entity2_type = match.group(4) + + conf_match = re.search(r"conf=([0-9.]+)", str(record)) + confidence = conf_match.group(1) if conf_match else "unknown" + + reason_match = re.search(r"reason=([^|]+)", str(record)) + reason = reason_match.group(1).strip() if reason_match else "" + + disamb_records.append({ + "entity_name": entity1_name, + "disamb_type": f"消歧阻断:{entity1_type} vs {entity2_type}", + "confidence": confidence, + "reason": (reason[:100] + "...") if len(reason) > 100 else reason, + }) + except Exception as e: + logger.debug("解析消歧记录失败: %s, 错误: %s", record, e) + + logger.info( + "保存去重消歧记录:%d 个合并记录,%d 个消歧记录", + len(merge_records), + len(disamb_records), + ) + except Exception as e: + logger.error("保存去重消歧详情失败: %s", e, exc_info=True) + + return merge_records, disamb_records, id_redirect_map + + +def analyze_entity_merges( + merge_records: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + """Return merge info sorted by merged_count (descending).""" + if not merge_records: + return [] + sorted_records = sorted( + merge_records, key=lambda x: x.get("merged_count", 0), reverse=True + ) + return [ + { + "main_entity_name": r.get("entity_name", "未知实体"), + "merged_count": r.get("merged_count", 1), + } + for r in sorted_records + ] + + +def analyze_entity_disambiguation( + disamb_records: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + """Return disambiguation records (pass-through).""" + return disamb_records if disamb_records else [] + + +def parse_dedup_report( + merge_records: List[Dict[str, Any]], + disamb_records: List[Dict[str, Any]], +) -> Dict[str, Any]: + """Build a summary report dict from parsed records.""" + try: + dedup_examples: List[Dict[str, Any]] = [] + disamb_examples: List[Dict[str, Any]] = [] + total_merges = 0 + total_disambiguations = 0 + + for record in merge_records: + merge_count = record.get("merged_count", 0) + total_merges += merge_count + dedup_examples.append({ + "type": record.get("type", "未知"), + 
"entity_name": record.get("entity_name", "未知实体"), + "entity_type": record.get("entity_type", "未知类型"), + "merge_count": merge_count, + "description": f"{record.get('entity_name', '未知实体')}实体去重合并{merge_count}个", + }) + + for record in disamb_records: + total_disambiguations += 1 + disamb_type = record.get("disamb_type", "") + entity_name = record.get("entity_name", "未知实体") + disamb_examples.append({ + "entity1_name": entity_name, + "entity1_type": ( + disamb_type.split("vs")[0].replace("消歧阻断:", "").strip() + if "vs" in disamb_type + else "未知" + ), + "entity2_name": entity_name, + "entity2_type": ( + disamb_type.split("vs")[1].strip() if "vs" in disamb_type else "未知" + ), + "description": f"{entity_name},消歧区分成功", + }) + + return { + "dedup_examples": dedup_examples[:5], + "disamb_examples": disamb_examples[:5], + "total_merges": total_merges, + "total_disambiguations": total_disambiguations, + } + except Exception as e: + logger.error("获取去重报告失败: %s", e, exc_info=True) + return { + "dedup_examples": [], + "disamb_examples": [], + "total_merges": 0, + "total_disambiguations": 0, + } + + +async def send_dedup_progress_callback( + progress_callback: Callable, + merge_records: List[Dict[str, Any]], + disamb_records: List[Dict[str, Any]], + original_entities: int, + final_entities: int, + original_stmt_edges: int, + final_stmt_edges: int, + original_ent_edges: int, + final_ent_edges: int, +) -> None: + """Send dedup completion progress via *progress_callback*.""" + try: + dedup_details = parse_dedup_report(merge_records, disamb_records) + + entities_reduced = original_entities - final_entities + stmt_edges_reduced = original_stmt_edges - final_stmt_edges + ent_edges_reduced = original_ent_edges - final_ent_edges + + dedup_stats = { + "entities": { + "original_count": original_entities, + "final_count": final_entities, + "reduced_count": entities_reduced, + "reduction_rate": ( + round(entities_reduced / original_entities * 100, 1) + if original_entities > 0 + else 0 + ), + }, 
+ "statement_entity_edges": { + "original_count": original_stmt_edges, + "final_count": final_stmt_edges, + "reduced_count": stmt_edges_reduced, + }, + "entity_entity_edges": { + "original_count": original_ent_edges, + "final_count": final_ent_edges, + "reduced_count": ent_edges_reduced, + }, + "dedup_examples": dedup_details.get("dedup_examples", []), + "disamb_examples": dedup_details.get("disamb_examples", []), + "summary": { + "total_merges": dedup_details.get("total_merges", 0), + "total_disambiguations": dedup_details.get("total_disambiguations", 0), + }, + } + + await progress_callback("dedup_disambiguation_complete", "去重消歧完成", dedup_stats) + except Exception as e: + logger.error("发送去重消歧进度回调失败: %s", e, exc_info=True) + try: + basic_stats = { + "entities": { + "original_count": original_entities, + "final_count": final_entities, + "reduced_count": original_entities - final_entities, + }, + "summary": f"实体去重合并{original_entities - final_entities}个", + } + await progress_callback("dedup_disambiguation_complete", "去重消歧完成", basic_stats) + except Exception as e2: + logger.error("发送基本去重统计失败: %s", e2, exc_info=True) + + +# --------------------------------------------------------------------------- +# run_dedup — main entry point (Requirements 10.1, 10.3) +# --------------------------------------------------------------------------- + + +async def run_dedup( + entity_nodes: List[ExtractedEntityNode], + statement_entity_edges: List[StatementEntityEdge], + entity_entity_edges: List[EntityEntityEdge], + dialog_data_list: List[DialogData], + pipeline_config: ExtractionPipelineConfig, + connector: Optional[Neo4jConnector] = None, + llm_client: Optional[Any] = None, + is_pilot_run: bool = False, + progress_callback: Optional[Callable] = None, +) -> DedupResult: + """Two-stage entity deduplication and disambiguation. 
+ + Full mode: + Layer 1 — exact / fuzzy / LLM matching + Layer 2 — Neo4j joint dedup + cross-role alias cleaning + + Pilot-run mode: + Layer 1 only (skip Neo4j layer 2 and alias cleaning). + + Args: + entity_nodes: Pre-dedup entity nodes. + statement_entity_edges: Pre-dedup statement-entity edges. + entity_entity_edges: Pre-dedup entity-entity edges. + dialog_data_list: Source dialogue data (used to detect end_user_id). + pipeline_config: Pipeline configuration (contains DedupConfig). + connector: Optional Neo4j connector for layer-2 dedup. + llm_client: Optional LLM client for LLM-based dedup decisions. + is_pilot_run: When True, only execute layer-1 dedup. + progress_callback: Optional async callable for progress reporting. + + Returns: + A ``DedupResult`` with deduplicated nodes, edges, and statistics. + """ + logger.info("开始两阶段实体去重和消歧") + + if progress_callback: + await progress_callback("deduplication", "正在去重消歧...") + + logger.info( + "去重前: %d 个实体节点, %d 条陈述句-实体边, %d 条实体-实体边", + len(entity_nodes), + len(statement_entity_edges), + len(entity_entity_edges), + ) + + original_entity_count = len(entity_nodes) + original_stmt_edge_count = len(statement_entity_edges) + original_ent_edge_count = len(entity_entity_edges) + + try: + if is_pilot_run: + # --- pilot run: layer 1 only --- + logger.info("试运行模式:仅执行第一层去重,跳过第二层数据库去重") + from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import ( + deduplicate_entities_and_edges, + ) + + ( + dedup_entity_nodes, + dedup_stmt_edges, + dedup_ent_edges, + raw_details, + ) = await deduplicate_entities_and_edges( + entity_nodes, + statement_entity_edges, + entity_entity_edges, + report_stage="第一层去重消歧(试运行)", + report_append=False, + dedup_config=pipeline_config.deduplication, + llm_client=llm_client, + ) + + final_entities = dedup_entity_nodes + final_stmt_edges = dedup_stmt_edges + final_ent_edges = dedup_ent_edges + else: + # --- full mode: two-stage dedup --- + from 
app.core.memory.storage_services.extraction_engine.deduplication.two_stage_dedup import ( + dedup_layers_and_merge_and_return, + ) + + ( + _dialogue_nodes, + _chunk_nodes, + _statement_nodes, + final_entities, + _statement_chunk_edges, + final_stmt_edges, + final_ent_edges, + raw_details, + ) = await dedup_layers_and_merge_and_return( + dialogue_nodes=[], + chunk_nodes=[], + statement_nodes=[], + entity_nodes=entity_nodes, + statement_chunk_edges=[], + statement_entity_edges=statement_entity_edges, + entity_entity_edges=entity_entity_edges, + dialog_data_list=dialog_data_list, + pipeline_config=pipeline_config, + connector=connector, + llm_client=llm_client, + ) + + # Parse raw details into structured records + merge_records, disamb_records, _id_redirect = save_dedup_details( + raw_details, entity_nodes, final_entities + ) + + logger.info( + "去重后: %d 个实体节点, %d 条陈述句-实体边, %d 条实体-实体边", + len(final_entities), + len(final_stmt_edges), + len(final_ent_edges), + ) + logger.info( + "去重效果: 实体减少 %d, 陈述句-实体边减少 %d, 实体-实体边减少 %d", + original_entity_count - len(final_entities), + original_stmt_edge_count - len(final_stmt_edges), + original_ent_edge_count - len(final_ent_edges), + ) + + # --- progress callbacks --- + if progress_callback: + merge_info = analyze_entity_merges(merge_records) + for i, detail in enumerate(merge_info[:5]): + dedup_result = { + "result_type": "entity_merge", + "merged_entity_name": detail["main_entity_name"], + "merged_count": detail["merged_count"], + "merge_progress": f"{i + 1}/{min(len(merge_info), 5)}", + "message": ( + f"{detail['main_entity_name']}合并{detail['merged_count']}个:相似实体已合并" + ), + } + await progress_callback("dedup_disambiguation_result", "实体去重中", dedup_result) + + disamb_info = analyze_entity_disambiguation(disamb_records) + for i, detail in enumerate(disamb_info[:5]): + disamb_result = { + "result_type": "entity_disambiguation", + "disambiguated_entity_name": detail["entity_name"], + "disambiguation_type": detail["disamb_type"], + 
"confidence": detail.get("confidence", "unknown"), + "reason": detail.get("reason", ""), + "disamb_progress": f"{i + 1}/{min(len(disamb_info), 5)}", + "message": f"{detail['entity_name']}消歧完成:{detail['disamb_type']}", + } + await progress_callback("dedup_disambiguation_result", "实体消歧中", disamb_result) + + await send_dedup_progress_callback( + progress_callback, + merge_records, + disamb_records, + original_entity_count, + len(final_entities), + original_stmt_edge_count, + len(final_stmt_edges), + original_ent_edge_count, + len(final_ent_edges), + ) + + return DedupResult( + entity_nodes=final_entities, + statement_entity_edges=final_stmt_edges, + entity_entity_edges=final_ent_edges, + dedup_details=raw_details, + merge_records=merge_records, + disamb_records=disamb_records, + ) + + except Exception as e: + logger.error("两阶段去重失败: %s", e, exc_info=True) + raise diff --git a/api/app/core/memory/storage_services/extraction_engine/steps/__init__.py b/api/app/core/memory/storage_services/extraction_engine/steps/__init__.py new file mode 100644 index 00000000..63a8ec77 --- /dev/null +++ b/api/app/core/memory/storage_services/extraction_engine/steps/__init__.py @@ -0,0 +1,16 @@ +"""Extraction pipeline steps — unified ExtractionStep paradigm. + +Importing this package triggers @register decorator self-registration +for all sidecar (non-critical) steps via SidecarStepFactory. +""" + +from .sidecar_factory import SidecarStepFactory, SidecarTiming # noqa: F401 + +# Step implementations — importing triggers @register self-registration. 
logger = logging.getLogger(__name__)

InputT = TypeVar("InputT")
OutputT = TypeVar("OutputT")


@dataclass
class StepContext:
    """Shared context injected into every ExtractionStep by the orchestrator.

    Attributes:
        llm_client: LLM client instance for generating completions.
        language: Target language code (e.g. "en", "zh").
        config: Pipeline configuration object (ExtractionPipelineConfig).
        is_pilot_run: When True, run in lightweight preview mode.
        progress_callback: Optional callable for reporting progress.
    """

    llm_client: Any
    language: str
    config: Any
    is_pilot_run: bool = False
    progress_callback: Optional[Any] = None


class ExtractionStep(ABC, Generic[InputT, OutputT]):
    """Abstract base class for all LLM extraction stages.

    Lifecycle:
        1. ``__init__(context)`` — receive shared context, bind config params
        2. ``should_skip()`` — check whether to skip (config-driven / pilot mode)
        3. ``run(input_data)`` — execute full flow (with retry for critical steps)
           Internally: render_prompt → call_llm → parse_response → post_process
        4. ``on_failure(error)`` — critical steps raise; sidecar steps return default

    Type Parameters:
        InputT: The Pydantic model type accepted by this step.
        OutputT: The Pydantic model type produced by this step.
    """

    def __init__(self, context: StepContext) -> None:
        self.context = context
        self.llm_client = context.llm_client
        self.language = context.language
        self.config = context.config

    # ── Subclasses must implement ──

    @property
    @abstractmethod
    def name(self) -> str:
        """Human-readable step name for logging."""
        ...

    @abstractmethod
    async def render_prompt(self, input_data: InputT) -> Any:
        """Build the prompt from *input_data* and bound config."""
        ...

    @abstractmethod
    async def call_llm(self, prompt: Any) -> Any:
        """Send *prompt* to the LLM and return the raw response."""
        ...

    @abstractmethod
    async def parse_response(self, raw_response: Any, input_data: InputT) -> OutputT:
        """Parse *raw_response* into a typed OutputT (Pydantic model)."""
        ...

    @abstractmethod
    def get_default_output(self) -> OutputT:
        """Return a safe default when the step is skipped or fails gracefully."""
        ...

    # ── Overridable properties ──

    @property
    def is_critical(self) -> bool:
        """``True`` = critical step (failure aborts pipeline).

        ``False`` = sidecar step (failure degrades gracefully).
        """
        return True

    @property
    def max_retries(self) -> int:
        """Maximum retry attempts (only effective for critical steps).

        Values below zero are treated as zero by ``run``: the step is always
        attempted at least once.
        """
        return 2

    @property
    def retry_backoff_base(self) -> float:
        """Backoff base in seconds. Actual wait = base × 2^attempt."""
        return 1.0

    # ── Overridable hooks ──

    def should_skip(self) -> bool:
        """Config-driven skip check. Subclasses may override."""
        return False

    async def post_process(self, parsed_data: OutputT, input_data: InputT) -> OutputT:
        """Post-processing hook. Default is identity (returns *parsed_data* unchanged)."""
        return parsed_data

    # ── Core execution logic ──

    async def run(self, input_data: InputT) -> OutputT:
        """Execute the full step lifecycle with retry logic.

        For critical steps (``is_critical=True``):
            Attempt up to ``max_retries + 1`` times (never fewer than once)
            with exponential backoff between attempts. If all attempts fail,
            delegate to ``on_failure`` which raises the last error.

        For sidecar steps (``is_critical=False``):
            Attempt exactly once. On failure, delegate to ``on_failure``
            which returns ``get_default_output()``.

        Args:
            input_data: Typed input for this step.

        Returns:
            The post-processed step output, or ``get_default_output()`` when
            the step is skipped or a sidecar step fails.

        Raises:
            Exception: The last error raised by a critical step once all
                attempts are exhausted.
        """
        if self.should_skip():
            logger.info("Step '%s' skipped", self.name)
            return self.get_default_output()

        last_error: Optional[Exception] = None
        # Clamp to one attempt: a negative ``max_retries`` override would
        # otherwise skip the loop entirely and reach ``on_failure(None)``,
        # which fails with an unrelated TypeError on ``raise None``.
        attempts = max(1, self.max_retries + 1) if self.is_critical else 1

        for attempt in range(attempts):
            try:
                prompt = await self.render_prompt(input_data)
                raw_response = await self.call_llm(prompt)
                parsed = await self.parse_response(raw_response, input_data)
                result = await self.post_process(parsed, input_data)
                return result
            except Exception as exc:
                last_error = exc
                logger.warning(
                    "Step '%s' attempt %d/%d failed: %s",
                    self.name,
                    attempt + 1,
                    attempts,
                    exc,
                )
                # Back off only when another attempt will follow.
                if attempt < attempts - 1:
                    wait = self.retry_backoff_base * (2 ** attempt)
                    logger.info(
                        "Step '%s' retrying in %.1fs …", self.name, wait
                    )
                    await asyncio.sleep(wait)

        # All attempts exhausted — delegate to failure handler.
        # ``last_error`` is always set here because at least one attempt ran.
        return self.on_failure(last_error)  # type: ignore[arg-type]

    def on_failure(self, error: Exception) -> OutputT:
        """Handle step failure.

        Critical steps: re-raise the exception to abort the pipeline.
        Sidecar steps: return ``get_default_output()`` for graceful degradation.

        Args:
            error: The exception raised by the last failed attempt.

        Returns:
            ``get_default_output()`` for sidecar steps.

        Raises:
            Exception: *error* itself, when the step is critical.
        """
        if self.is_critical:
            logger.error(
                "Critical step '%s' failed after retries: %s", self.name, error
            )
            raise error
        logger.warning(
            "Sidecar step '%s' failed, returning default output: %s",
            self.name,
            error,
        )
        return self.get_default_output()
+ """ + + def __init__( + self, + embedder_client: Any, + is_pilot_run: bool = False, + batch_size: int = 100, + ) -> None: + self.embedder_client = embedder_client + self.is_pilot_run = is_pilot_run + self.batch_size = batch_size + + @property + def name(self) -> str: + return "embedding_generation" + + @property + def is_critical(self) -> bool: + return False + + @property + def max_retries(self) -> int: + return 1 + + @property + def retry_backoff_base(self) -> float: + return 1.0 + + def should_skip(self) -> bool: + return self.is_pilot_run + + def get_default_output(self) -> EmbeddingStepOutput: + return EmbeddingStepOutput() + + # ── Core execution ── + + async def run(self, input_data: EmbeddingStepInput) -> EmbeddingStepOutput: + """Generate embeddings for all non-empty text fields in *input_data*.""" + if self.should_skip(): + logger.info("EmbeddingStep skipped (pilot run)") + return self.get_default_output() + + try: + stmt_emb, chunk_emb, dialog_emb, entity_emb = await asyncio.gather( + self._embed_dict(input_data.statement_texts), + self._embed_dict(input_data.chunk_texts), + self._embed_list(input_data.dialog_texts), + self._embed_dict(input_data.entity_names), + ) + return EmbeddingStepOutput( + statement_embeddings=stmt_emb, + chunk_embeddings=chunk_emb, + dialog_embeddings=dialog_emb, + entity_embeddings=entity_emb, + ) + except Exception as exc: + logger.warning("EmbeddingStep failed, returning empty output: %s", exc) + return self.get_default_output() + + # ── Internal helpers ── + + async def _embed_dict( + self, texts: Dict[str, str] + ) -> Dict[str, List[float]]: + """Embed a dict of ``{id: text}`` and return ``{id: embedding}``.""" + if not texts: + return {} + + ids = list(texts.keys()) + text_list = list(texts.values()) + embeddings = await self._batch_embed(text_list) + + return dict(zip(ids, embeddings)) + + async def _embed_list(self, texts: List[str]) -> List[List[float]]: + """Embed a plain list of texts.""" + if not texts: + return [] 
@SidecarStepFactory.register("emotion_enabled", SidecarTiming.AFTER_STATEMENT)
class EmotionExtractionStep(ExtractionStep[EmotionStepInput, EmotionStepOutput]):
    """Sidecar step that extracts emotion type, intensity and keywords.

    The step is non-critical: ``is_critical`` is ``False``, so a failure is
    handled by the base class by returning ``get_default_output()`` instead of
    aborting the pipeline.

    It registers itself with ``SidecarStepFactory`` under the config key
    ``emotion_enabled`` with timing ``AFTER_STATEMENT``.
    """

    def __init__(self, context: StepContext) -> None:
        super().__init__(context)
        cfg = self.config
        # Optional emotion flags; both fall back to a default when the config
        # object does not define them.
        self.extract_keywords = getattr(cfg, "emotion_extract_keywords", True)
        self.enable_subject = getattr(cfg, "emotion_enable_subject", False)

    # ── Identity ──

    @property
    def name(self) -> str:
        """Step name used in log messages."""
        return "emotion_extraction"

    @property
    def is_critical(self) -> bool:
        """Emotion extraction never aborts the pipeline."""
        return False

    # ── Config-driven skip ──

    def should_skip(self) -> bool:
        """Skip unless the config explicitly sets ``emotion_enabled`` truthy."""
        enabled = getattr(self.config, "emotion_enabled", False)
        return not enabled

    # ── Lifecycle ──

    async def render_prompt(self, input_data: EmotionStepInput) -> str:
        """Render the emotion prompt for the statement text of *input_data*."""
        prompt = await render_emotion_extraction_prompt(
            statement=input_data.statement_text,
            extract_keywords=self.extract_keywords,
            enable_subject=self.enable_subject,
            language=self.language,
        )
        return prompt

    async def call_llm(self, prompt: Any) -> Any:
        """Send *prompt* as a single user message and request an ``EmotionExtraction``."""
        return await self.llm_client.response_structured(
            [{"role": "user", "content": prompt}], EmotionExtraction
        )

    async def parse_response(
        self, raw_response: Any, input_data: EmotionStepInput
    ) -> EmotionStepOutput:
        """Copy the emotion fields of *raw_response* into an ``EmotionStepOutput``.

        Attributes missing on *raw_response* fall back to neutral values.
        """
        kind = getattr(raw_response, "emotion_type", "neutral")
        intensity = getattr(raw_response, "emotion_intensity", 0.0)
        keywords = getattr(raw_response, "emotion_keywords", [])
        return EmotionStepOutput(
            emotion_type=kind,
            emotion_intensity=intensity,
            emotion_keywords=keywords,
        )

    def get_default_output(self) -> EmotionStepOutput:
        """Return an ``EmotionStepOutput`` built from its own field defaults."""
        return EmotionStepOutput()
b/api/app/core/memory/storage_services/extraction_engine/steps/extraction_pipeline_orchestrator.py @@ -0,0 +1,906 @@ +"""Refactored ExtractionOrchestrator using the unified ExtractionStep paradigm. + +This module provides ``NewExtractionOrchestrator`` — a slimmed-down orchestrator +(~500 lines vs ~2500) that delegates extraction work to concrete ExtractionStep +instances and uses SidecarStepFactory for hot-pluggable sidecar modules. + +The new orchestrator coexists with the legacy ``ExtractionOrchestrator`` until +the team explicitly switches over. + +Execution phases: + 1. Statement extraction + concurrent chunk/dialog embedding + 2. Triplet extraction + concurrent after_statement sidecars + statement embedding + 3. Entity embedding + concurrent after_triplet sidecars + 4. Data assignment back to dialog_data_list +""" + +from __future__ import annotations + +import asyncio +import logging +from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple + +from app.core.memory.models.message_models import DialogData +from app.core.memory.models.variate_config import ExtractionPipelineConfig + +from .base import ExtractionStep, StepContext +from .embedding_step import EmbeddingStep +from .sidecar_factory import SidecarStepFactory, SidecarTiming +from .statement_step import StatementExtractionStep +from .triplet_step import TripletExtractionStep +from .schema import ( + EmbeddingStepInput, + EmbeddingStepOutput, + EmotionStepInput, + EmotionStepOutput, + MessageItem, + StatementStepInput, + StatementStepOutput, + SupportingContext, + TripletStepInput, + TripletStepOutput, +) + +logger = logging.getLogger(__name__) + + +class NewExtractionOrchestrator: + """Slimmed-down extraction orchestrator using the ExtractionStep paradigm. 
+ + Responsibilities: + * Initialise all steps and sidecar groups via ``SidecarStepFactory`` + * Route data between stages (``_convert_to_*`` helpers) + * Orchestrate concurrent execution (``_run_with_sidecars``) + * Assign extracted results back to ``DialogData`` objects + + The orchestrator does **not** own dedup, node/edge creation, or Neo4j writes. + Those remain in ``WritePipeline`` / ``dedup_step``. + """ + + def __init__( + self, + llm_client: Any, + embedder_client: Any, + config: Optional[ExtractionPipelineConfig] = None, + embedding_id: Optional[str] = None, + ontology_types: Any = None, + language: str = "zh", + is_pilot_run: bool = False, + progress_callback: Optional[ + Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]] + ] = None, + ) -> None: + self.config = config or ExtractionPipelineConfig() + self.is_pilot_run = is_pilot_run + self.embedding_id = embedding_id + self.progress_callback = progress_callback + + # Build shared context for all LLM-based steps + self.context = StepContext( + llm_client=llm_client, + language=language, + config=self.config, + is_pilot_run=is_pilot_run, + progress_callback=progress_callback, + ) + + # ── Critical (main-line) steps ── + self.statement_step = StatementExtractionStep(self.context) + self.triplet_step = TripletExtractionStep( + self.context, ontology_types=ontology_types + ) + + # ── Embedding step (non-LLM, separate client) ── + self.embedding_step = EmbeddingStep( + embedder_client=embedder_client, + is_pilot_run=is_pilot_run, + ) + + # ── Sidecar steps (auto-discovered via @register decorator) ── + sidecar_groups = SidecarStepFactory.create_sidecars(self.config, self.context) + self.after_statement_sidecars: List[ExtractionStep] = sidecar_groups[ + SidecarTiming.AFTER_STATEMENT + ] + self.after_triplet_sidecars: List[ExtractionStep] = sidecar_groups[ + SidecarTiming.AFTER_TRIPLET + ] + + logger.info( + "NewExtractionOrchestrator initialised — " + "after_statement sidecars: %d, after_triplet 
async def _run_with_sidecars(
    self,
    critical_coro: Any,
    sidecars: List[Tuple[ExtractionStep, Any]],
    extra_coros: Optional[List[Any]] = None,
) -> Tuple[Any, List[Any], List[Any]]:
    """Await one critical coroutine together with sidecars and extra coroutines.

    Everything is scheduled through a single ``asyncio.gather`` call with
    ``return_exceptions=True``, so all awaitables finish before any failure
    is examined.

    Args:
        critical_coro: Awaitable of the main-line step.
        sidecars: ``(step, input_data)`` pairs; each is wrapped with
            ``_run_sidecar_safe``.
        extra_coros: Further awaitables to run alongside (e.g. embedding
            generation). ``None`` means none.

    Returns:
        ``(critical_result, sidecar_results, extra_results)``. A sidecar slot
        that still holds an exception is replaced by that step's default
        output; a failed extra slot is replaced by ``None``.

    Raises:
        BaseException: Whatever the critical coroutine raised.
    """
    guarded = [self._run_sidecar_safe(step, payload) for step, payload in sidecars]
    extras = extra_coros or []

    # Result order mirrors submission order: critical, sidecars, extras.
    pending = [critical_coro, *guarded] + extras
    outcomes = await asyncio.gather(*pending, return_exceptions=True)

    boundary = 1 + len(guarded)
    main_outcome = outcomes[0]
    sidecar_outcomes = list(outcomes[1:boundary])
    extra_outcomes = list(outcomes[boundary:])

    # A failed main-line step aborts the phase.
    if isinstance(main_outcome, BaseException):
        raise main_outcome

    # _run_sidecar_safe should already have absorbed sidecar errors; this is
    # a second line of defence for anything that slipped through gather.
    for idx, outcome in enumerate(sidecar_outcomes):
        if not isinstance(outcome, BaseException):
            continue
        failed_step = sidecars[idx][0]
        logger.warning(
            "Sidecar '%s' unexpected exception in gather: %s",
            failed_step.name,
            outcome,
        )
        sidecar_outcomes[idx] = failed_step.get_default_output()

    # Extra coroutines are optional work: log the failure and report None.
    for idx, outcome in enumerate(extra_outcomes):
        if not isinstance(outcome, BaseException):
            continue
        logger.warning("Extra coroutine %d failed: %s", idx, outcome)
        extra_outcomes[idx] = None

    return main_outcome, sidecar_outcomes, extra_outcomes
@staticmethod
def _build_supporting_context(
    dialog: DialogData,
) -> SupportingContext:
    """Wrap a dialog's content in a ``SupportingContext``.

    A dialog with a non-empty ``content`` attribute yields one
    ``MessageItem`` with role ``"context"`` holding that content; otherwise
    the context has no messages.
    """
    raw_content = getattr(dialog, "content", None)
    if not raw_content:
        return SupportingContext(msgs=[])
    # The whole content is carried as a single context message.
    return SupportingContext(msgs=[MessageItem(role="context", msg=raw_content)])

@staticmethod
def _convert_to_triplet_input(
    stmt_out: StatementStepOutput,
    supporting_context: SupportingContext,
) -> TripletStepInput:
    """Build the triplet-step input for one extracted statement.

    Copies the statement's identity, text, type, temporal fields and speaker
    from *stmt_out* and attaches *supporting_context*.
    """
    carried_fields = (
        "statement_id",
        "statement_text",
        "statement_type",
        "temporal_type",
        "speaker",
        "valid_at",
        "invalid_at",
    )
    copied = {field: getattr(stmt_out, field) for field in carried_fields}
    return TripletStepInput(supporting_context=supporting_context, **copied)

@staticmethod
def _convert_to_emotion_input(
    stmt_out: StatementStepOutput,
) -> EmotionStepInput:
    """Build the emotion-step input (id, text, speaker) for one statement."""
    carried_fields = ("statement_id", "statement_text", "speaker")
    copied = {field: getattr(stmt_out, field) for field in carried_fields}
    return EmotionStepInput(**copied)
async def _run_pilot(
    self, dialog_data_list: List[DialogData]
) -> List[DialogData]:
    """Run the pilot pipeline: statements, then triplets, nothing else.

    No sidecar steps and no embedding generation are executed. Results are
    written back onto *dialog_data_list* via ``_assign_results`` and the raw
    step outputs are kept on ``self._last_stage_outputs``.

    Returns:
        The same *dialog_data_list* object that was passed in.
    """
    # Phase 1: statements, extracted per chunk in parallel.
    logger.info("Pilot phase 1/2: Statement extraction")
    statements = await self._extract_all_statements(dialog_data_list)

    # Phase 2: triplets, extracted per statement in parallel.
    logger.info("Pilot phase 2/2: Triplet extraction")
    triplets = await self._extract_all_triplets(dialog_data_list, statements)

    # Pilot mode has neither emotions nor embeddings to assign.
    self._assign_results(
        dialog_data_list,
        statements,
        triplets,
        emotion_results={},
        embedding_output=None,
    )

    # Keep the raw outputs around for snapshot/debugging.
    self._last_stage_outputs = dict(
        statement_results=statements,
        triplet_results=triplets,
        emotion_results={},
        embedding_output=None,
    )

    logger.info("Pilot extraction complete")
    return dialog_data_list
async def _run_full(
    self, dialog_data_list: List[DialogData]
) -> List[DialogData]:
    """Run all four phases with concurrent sidecars and embeddings.

    Phases:
        1. Statement extraction, concurrently with chunk/dialog embedding.
        2. Triplet extraction, concurrently with after_statement sidecars
           and statement embedding.
        3. Entity embedding (together with after_triplet sidecars, once any
           are wired in).
        4. Assignment of all results back onto *dialog_data_list*.

    The raw step outputs are kept on ``self._last_stage_outputs``.

    Args:
        dialog_data_list: Dialogs whose chunks should be processed.

    Returns:
        The same *dialog_data_list* object that was passed in.

    Raises:
        BaseException: If statement or triplet extraction fails as a whole.
            A failed chunk/dialog embedding is only logged.
    """

    # ── Phase 1: Statement extraction + chunk/dialog embedding ──
    logger.info("Phase 1/4: Statement extraction + chunk/dialog embedding")
    chunk_dialog_emb_input = self._build_chunk_dialog_embedding_input(
        dialog_data_list
    )

    stmt_outcome, chunk_dialog_outcome = await asyncio.gather(
        self._extract_all_statements(dialog_data_list),
        self.embedding_step.run(chunk_dialog_emb_input),
        return_exceptions=True,
    )

    # Statement extraction is the critical path: surface its failure as-is.
    if isinstance(stmt_outcome, BaseException):
        raise stmt_outcome
    all_stmt_results: Dict[str, Dict[str, List[StatementStepOutput]]] = stmt_outcome

    # Embeddings are best effort: continue without them on failure.
    chunk_dialog_emb: Optional[EmbeddingStepOutput] = None
    if isinstance(chunk_dialog_outcome, BaseException):
        logger.warning("Chunk/dialog embedding failed: %s", chunk_dialog_outcome)
    else:
        chunk_dialog_emb = chunk_dialog_outcome

    # ── Phase 2: Triplet extraction + after_statement sidecars + statement embedding ──
    logger.info(
        "Phase 2/4: Triplet extraction + sidecars + statement embedding"
    )
    stmt_emb_input = self._build_statement_embedding_input(
        dialog_data_list, all_stmt_results
    )

    # Build sidecar inputs for after_statement sidecars (e.g. emotion)
    sidecar_pairs = self._build_after_statement_sidecar_inputs(
        dialog_data_list, all_stmt_results
    )

    triplet_coro = self._extract_all_triplets(
        dialog_data_list, all_stmt_results
    )
    stmt_emb_coro = self.embedding_step.run(stmt_emb_input)

    all_triplet_results, sidecar_results, extra_results = (
        await self._run_with_sidecars(
            triplet_coro,
            sidecar_pairs,
            extra_coros=[stmt_emb_coro],
        )
    )
    # _run_with_sidecars reports a failed extra coroutine as None.
    stmt_emb: Optional[EmbeddingStepOutput] = (
        extra_results[0] if extra_results else None
    )

    # Collect sidecar outputs keyed by step name
    sidecar_steps = [step for step, _inp in sidecar_pairs]
    sidecar_output_map = self._collect_sidecar_outputs(
        sidecar_steps, sidecar_results
    )

    # ── Phase 3: Entity embedding + after_triplet sidecars ──
    logger.info("Phase 3/4: Entity embedding + after_triplet sidecars")
    entity_emb_input = self._build_entity_embedding_input(all_triplet_results)

    after_triplet_pairs: List[Tuple[ExtractionStep, Any]] = []
    # Future after_triplet sidecars would be wired here

    # The entity embedding occupies the "critical" slot of _run_with_sidecars,
    # so its result is the FIRST element of the returned tuple. Reading it
    # from the extras list (always empty here) would silently drop the
    # entity embeddings as soon as an after_triplet sidecar is configured.
    # NOTE(review): after_triplet sidecar results are not consumed yet.
    entity_emb, _after_triplet_results, _ = await self._run_with_sidecars(
        self.embedding_step.run(entity_emb_input),
        after_triplet_pairs,
    )

    # Merge all embedding outputs
    merged_emb = self._merge_embeddings(chunk_dialog_emb, stmt_emb, entity_emb)

    # ── Phase 4: Data assignment ──
    logger.info("Phase 4/4: Data assignment")
    emotion_results = sidecar_output_map.get("emotion_extraction", {})

    self._assign_results(
        dialog_data_list,
        all_stmt_results,
        all_triplet_results,
        emotion_results=emotion_results,
        embedding_output=merged_emb,
    )

    # Store raw step outputs for snapshot/debugging
    self._last_stage_outputs = {
        "statement_results": all_stmt_results,
        "triplet_results": all_triplet_results,
        "emotion_results": emotion_results,
        "embedding_output": merged_emb,
    }

    logger.info("Full extraction pipeline complete")
    return dialog_data_list
speaker, _ = task_meta[i] + if dialog_id not in stmt_map: + stmt_map[dialog_id] = {} + + if isinstance(result, BaseException): + logger.error("Statement extraction failed for chunk %s: %s", chunk_id, result) + stmt_map[dialog_id][chunk_id] = [] + else: + # Override speaker from chunk metadata + stmts: List[StatementStepOutput] = result if isinstance(result, list) else [] + for s in stmts: + s.speaker = speaker + stmt_map[dialog_id][chunk_id] = stmts + + return stmt_map + + async def _extract_all_triplets( + self, + dialog_data_list: List[DialogData], + all_stmt_results: Dict[str, Dict[str, List[StatementStepOutput]]], + ) -> Dict[str, Dict[str, TripletStepOutput]]: + """Extract triplets for every statement (statement-level parallel). + + Returns: + Nested dict: ``{dialog_id: {statement_id: TripletStepOutput}}`` + """ + tasks: List[Any] = [] + task_meta: List[Tuple[str, str]] = [] # (dialog_id, statement_id) + + for dialog in dialog_data_list: + ctx = self._build_supporting_context(dialog) + chunk_stmts = all_stmt_results.get(dialog.id, {}) + for _chunk_id, stmts in chunk_stmts.items(): + for stmt in stmts: + inp = self._convert_to_triplet_input(stmt, ctx) + tasks.append(self.triplet_step.run(inp)) + task_meta.append((dialog.id, stmt.statement_id)) + + results = await asyncio.gather(*tasks, return_exceptions=True) + + triplet_map: Dict[str, Dict[str, TripletStepOutput]] = {} + for i, result in enumerate(results): + dialog_id, stmt_id = task_meta[i] + if dialog_id not in triplet_map: + triplet_map[dialog_id] = {} + + if isinstance(result, BaseException): + logger.error( + "Triplet extraction failed for statement %s: %s", + stmt_id, + result, + ) + triplet_map[dialog_id][stmt_id] = self.triplet_step.get_default_output() + else: + triplet_map[dialog_id][stmt_id] = result + + return triplet_map + + # ────────────────────────────────────────────── + # 5. 
Embedding 输入构建器 + # 为不同阶段构建 EmbeddingStepInput(chunk/statement/entity) + # ────────────────────────────────────────────── + + @staticmethod + def _build_chunk_dialog_embedding_input( + dialog_data_list: List[DialogData], + ) -> EmbeddingStepInput: + """Build embedding input for chunks and dialogs (phase 1).""" + chunk_texts: Dict[str, str] = {} + dialog_texts: List[str] = [] + + for dialog in dialog_data_list: + if hasattr(dialog, "content") and dialog.content: + dialog_texts.append(dialog.content) + for chunk in dialog.chunks: + chunk_texts[chunk.id] = chunk.content + + return EmbeddingStepInput( + chunk_texts=chunk_texts, + dialog_texts=dialog_texts, + ) + + @staticmethod + def _build_statement_embedding_input( + dialog_data_list: List[DialogData], + all_stmt_results: Dict[str, Dict[str, List[StatementStepOutput]]], + ) -> EmbeddingStepInput: + """Build embedding input for statements (phase 2).""" + stmt_texts: Dict[str, str] = {} + for _dialog_id, chunk_stmts in all_stmt_results.items(): + for _chunk_id, stmts in chunk_stmts.items(): + for s in stmts: + stmt_texts[s.statement_id] = s.statement_text + return EmbeddingStepInput(statement_texts=stmt_texts) + + @staticmethod + def _build_entity_embedding_input( + all_triplet_results: Dict[str, Dict[str, TripletStepOutput]], + ) -> EmbeddingStepInput: + """Build embedding input for entities (phase 3).""" + entity_names: Dict[str, str] = {} + entity_descs: Dict[str, str] = {} + seen: set = set() + + for _dialog_id, stmt_triplets in all_triplet_results.items(): + for _stmt_id, triplet_out in stmt_triplets.items(): + for ent in triplet_out.entities: + key = f"{ent.entity_idx}_{ent.name}" + if key not in seen: + seen.add(key) + entity_names[key] = ent.name + entity_descs[key] = ent.description + + return EmbeddingStepInput( + entity_names=entity_names, + entity_descriptions=entity_descs, + ) + + # ────────────────────────────────────────────── + # 6. 
旁路输入构建与结果收集 + # 为 after_statement / after_triplet 旁路构建输入,合并 embedding 输出 + # ────────────────────────────────────────────── + + def _build_after_statement_sidecar_inputs( + self, + dialog_data_list: List[DialogData], + all_stmt_results: Dict[str, Dict[str, List[StatementStepOutput]]], + ) -> List[Tuple[ExtractionStep, Any]]: + """Build (step, input) pairs for after_statement sidecars. + + For emotion extraction, we create a batch wrapper that runs the sidecar + on every user statement concurrently and returns a dict of results. + """ + if not self.after_statement_sidecars: + return [] + + # Collect all user statements for sidecar processing + all_user_stmts: List[StatementStepOutput] = [] + for _dialog_id, chunk_stmts in all_stmt_results.items(): + for _chunk_id, stmts in chunk_stmts.items(): + for s in stmts: + if s.speaker == "user": + all_user_stmts.append(s) + + pairs: List[Tuple[ExtractionStep, Any]] = [] + for sidecar in self.after_statement_sidecars: + if sidecar.name == "emotion_extraction": + # Emotion sidecar: wrap as batch coroutine via a sentinel input + # The actual per-statement calls happen inside _run_emotion_batch + pairs.append(( + _EmotionBatchWrapper(sidecar, all_user_stmts), + EmotionStepInput( + statement_id="__batch__", + statement_text="", + speaker="", + ), + )) + else: + # Generic sidecar: pass first statement as representative input + if all_user_stmts: + inp = self._convert_to_emotion_input(all_user_stmts[0]) + pairs.append((sidecar, inp)) + + return pairs + + @staticmethod + def _collect_sidecar_outputs( + sidecars: List[ExtractionStep], + results: List[Any], + ) -> Dict[str, Any]: + """Map sidecar results by step name.""" + output: Dict[str, Any] = {} + for i, sidecar in enumerate(sidecars): + if i < len(results): + output[sidecar.name] = results[i] + return output + + @staticmethod + def _merge_embeddings( + chunk_dialog: Optional[EmbeddingStepOutput], + statement: Optional[EmbeddingStepOutput], + entity: Optional[Any], + ) -> 
Optional[EmbeddingStepOutput]: + """Merge partial embedding outputs into a single EmbeddingStepOutput.""" + merged = EmbeddingStepOutput() + if chunk_dialog: + merged.chunk_embeddings = chunk_dialog.chunk_embeddings + merged.dialog_embeddings = chunk_dialog.dialog_embeddings + if statement: + merged.statement_embeddings = statement.statement_embeddings + if entity and isinstance(entity, EmbeddingStepOutput): + merged.entity_embeddings = entity.entity_embeddings + return merged + + # ────────────────────────────────────────────── + # 7. 数据赋值 + # 将各阶段 StepOutput 组装为 Statement 对象,替换 chunk.statements + # ────────────────────────────────────────────── + + def _assign_results( + self, + dialog_data_list: List[DialogData], + all_stmt_results: Dict[str, Dict[str, List[StatementStepOutput]]], + all_triplet_results: Dict[str, Dict[str, TripletStepOutput]], + emotion_results: Dict[str, EmotionStepOutput], + embedding_output: Optional[EmbeddingStepOutput], + ) -> None: + """Assign extraction results back to dialog_data_list in-place. + + Replaces chunk.statements with new Statement objects built from step + outputs, because the new orchestrator generates its own statement IDs + that don't match the original chunk statement IDs. 
+ """ + from app.core.memory.models.message_models import ( + Statement, + TemporalValidityRange, + ) + from app.core.memory.models.triplet_models import ( + TripletExtractionResponse, + Entity as TripletEntity, + Triplet as TripletRelation, + ) + from app.core.memory.utils.data.ontology import ( + RelevenceInfo, + StatementType, + TemporalInfo, + ) + + # Map string values to enums + _STMT_TYPE_MAP = { + "FACT": StatementType.FACT, + "OPINION": StatementType.OPINION, + "PREDICTION": StatementType.PREDICTION, + "SUGGESTION": StatementType.SUGGESTION, + } + _TEMPORAL_MAP = { + "STATIC": TemporalInfo.STATIC, + "DYNAMIC": TemporalInfo.DYNAMIC, + "ATEMPORAL": TemporalInfo.ATEMPORAL, + } + + total_stmts = 0 + assigned_triplets = 0 + assigned_emotions = 0 + assigned_stmt_emb = 0 + assigned_chunk_emb = 0 + assigned_dialog_emb = 0 + + for dialog in dialog_data_list: + dialog_stmts = all_stmt_results.get(dialog.id, {}) + dialog_triplets = all_triplet_results.get(dialog.id, {}) + + # Assign dialog embedding + if embedding_output and embedding_output.dialog_embeddings: + idx = dialog_data_list.index(dialog) + if idx < len(embedding_output.dialog_embeddings): + dialog.dialog_embedding = embedding_output.dialog_embeddings[idx] + assigned_dialog_emb += 1 + + for chunk in dialog.chunks: + # Assign chunk embedding + if embedding_output and chunk.id in embedding_output.chunk_embeddings: + chunk.chunk_embedding = embedding_output.chunk_embeddings[chunk.id] + assigned_chunk_emb += 1 + + # Build new Statement objects from step outputs + chunk_stmt_outputs = dialog_stmts.get(chunk.id, []) + new_statements = [] + + for stmt_out in chunk_stmt_outputs: + total_stmts += 1 + + # Temporal validity + valid_at = stmt_out.valid_at if stmt_out.valid_at != "NULL" else None + invalid_at = stmt_out.invalid_at if stmt_out.invalid_at != "NULL" else None + + # Triplet info + triplet_info = None + triplet_out = dialog_triplets.get(stmt_out.statement_id) + if triplet_out and (triplet_out.entities or 
triplet_out.triplets): + entities = [ + TripletEntity( + entity_idx=e.entity_idx, + name=e.name, + type=e.type, + description=e.description, + is_explicit_memory=e.is_explicit_memory, + ) + for e in triplet_out.entities + ] + triplets = [ + TripletRelation( + subject_name=t.subject_name, + subject_id=t.subject_id, + predicate=t.predicate, + object_name=t.object_name, + object_id=t.object_id, + ) + for t in triplet_out.triplets + ] + triplet_info = TripletExtractionResponse( + entities=entities, triplets=triplets, + ) + assigned_triplets += 1 + + # Emotion info + emo = emotion_results.get(stmt_out.statement_id) + emotion_kwargs = {} + if emo: + emotion_kwargs = { + "emotion_type": emo.emotion_type, + "emotion_intensity": emo.emotion_intensity, + "emotion_keywords": emo.emotion_keywords, + } + assigned_emotions += 1 + + # Statement embedding + stmt_embedding = None + if ( + embedding_output + and stmt_out.statement_id in embedding_output.statement_embeddings + ): + stmt_embedding = embedding_output.statement_embeddings[stmt_out.statement_id] + assigned_stmt_emb += 1 + + # Build the Statement object that _create_nodes_and_edges expects + stmt = Statement( + id=stmt_out.statement_id, + chunk_id=chunk.id, + end_user_id=dialog.end_user_id, + statement=stmt_out.statement_text, + speaker=stmt_out.speaker, + stmt_type=_STMT_TYPE_MAP.get(stmt_out.statement_type, StatementType.FACT), + temporal_info=_TEMPORAL_MAP.get(stmt_out.temporal_type, TemporalInfo.ATEMPORAL), + relevence_info=RelevenceInfo.RELEVANT if stmt_out.relevance == "RELEVANT" else RelevenceInfo.IRRELEVANT, + temporal_validity=TemporalValidityRange(valid_at=valid_at, invalid_at=invalid_at), + triplet_extraction_info=triplet_info, + statement_embedding=stmt_embedding, + **emotion_kwargs, + ) + new_statements.append(stmt) + + # Replace chunk.statements with newly built objects + chunk.statements = new_statements + + logger.info( + "Data assignment complete — statements: %d, triplets: %d, " + "emotions: %d, 
stmt_emb: %d, chunk_emb: %d, dialog_emb: %d", + total_stmts, + assigned_triplets, + assigned_emotions, + assigned_stmt_emb, + assigned_chunk_emb, + assigned_dialog_emb, + ) + + +class _EmotionBatchWrapper(ExtractionStep): + """情绪批量提取包装器。再考虑一下用法,这是子类? + + 将单条情绪旁路 Step 包装为批量并发执行,适配 ``_run_with_sidecars`` 接口。 + 编排器传入一个 sentinel input,``run()`` 忽略它,转而对预收集的 statement 列表 + 逐条并发调用内部 Step,返回 ``{statement_id: EmotionStepOutput}`` 字典。 + """ + + # ── 初始化 ── + + def __init__( + self, + inner_step: ExtractionStep, + statements: List[StatementStepOutput], + ) -> None: + # 不调用 super().__init__() — 本类是薄包装,不需要 StepContext + self._inner = inner_step + self._statements = statements + + # ── Step 身份属性(满足 ExtractionStep 抽象接口) ── + + @property + def name(self) -> str: + return self._inner.name + + @property + def is_critical(self) -> bool: + return False + + def get_default_output(self) -> Dict[str, EmotionStepOutput]: + return {} + + # ── 未使用的生命周期方法(批量包装器不走 render→call→parse 流程) ── + + async def render_prompt(self, input_data: Any) -> Any: + raise NotImplementedError + + async def call_llm(self, prompt: Any) -> Any: + raise NotImplementedError + + async def parse_response(self, raw_response: Any, input_data: Any) -> Any: + raise NotImplementedError + + # ── 批量执行入口 ── + + async def run(self, input_data: Any) -> Dict[str, EmotionStepOutput]: + """对所有预收集的 statement 并发执行情绪提取,单条失败返回默认值。""" + if not self._statements: + return {} + + async def _extract_one(stmt: StatementStepOutput) -> Tuple[str, EmotionStepOutput]: + inp = EmotionStepInput( + statement_id=stmt.statement_id, + statement_text=stmt.statement_text, + speaker=stmt.speaker, + ) + try: + result = await self._inner.run(inp) + return stmt.statement_id, result + except Exception: + return stmt.statement_id, self._inner.get_default_output() + + pairs = await asyncio.gather( + *[_extract_one(s) for s in self._statements] + ) + return dict(pairs) diff --git 
a/api/app/core/memory/storage_services/extraction_engine/steps/graph_build_step.py b/api/app/core/memory/storage_services/extraction_engine/steps/graph_build_step.py new file mode 100644 index 00000000..5da5f5ab --- /dev/null +++ b/api/app/core/memory/storage_services/extraction_engine/steps/graph_build_step.py @@ -0,0 +1,366 @@ +""" +GraphBuildStep — 从 DialogData 构建 Neo4j 图节点和边。 + +职责: +- 遍历 DialogData 列表,构建 DialogueNode、ChunkNode、StatementNode、 + ExtractedEntityNode、PerceptualNode 及各类 Edge +- 不涉及 LLM 调用、去重、Neo4j 写入 + +依赖: +- embedder_client(可选):为 PerceptualNode 生成 summary embedding +- progress_callback(可选):流式输出关系创建进度 + +从 ExtractionOrchestrator._create_nodes_and_edges() 提取而来, +旧编排器保留原方法不变,新旧流水线完全隔离。 +""" +from __future__ import annotations + +import logging +from typing import Any, Awaitable, Callable, Dict, List, Optional + +from app.core.memory.models.graph_models import ( + ChunkNode, + DialogueNode, + EntityEntityEdge, + ExtractedEntityNode, + PerceptualEdge, + PerceptualNode, + StatementChunkEdge, + StatementEntityEdge, + StatementNode, +) +from app.core.memory.models.message_models import DialogData, TemporalInfo + +logger = logging.getLogger(__name__) + + +class GraphBuildResult: + """图构建步骤的输出。""" + + __slots__ = ( + "dialogue_nodes", + "chunk_nodes", + "statement_nodes", + "entity_nodes", + "perceptual_nodes", + "stmt_chunk_edges", + "stmt_entity_edges", + "entity_entity_edges", + "perceptual_edges", + ) + + def __init__( + self, + dialogue_nodes: List[DialogueNode], + chunk_nodes: List[ChunkNode], + statement_nodes: List[StatementNode], + entity_nodes: List[ExtractedEntityNode], + perceptual_nodes: List[PerceptualNode], + stmt_chunk_edges: List[StatementChunkEdge], + stmt_entity_edges: List[StatementEntityEdge], + entity_entity_edges: List[EntityEntityEdge], + perceptual_edges: List[PerceptualEdge], + ): + self.dialogue_nodes = dialogue_nodes + self.chunk_nodes = chunk_nodes + self.statement_nodes = statement_nodes + self.entity_nodes = entity_nodes + 
self.perceptual_nodes = perceptual_nodes + self.stmt_chunk_edges = stmt_chunk_edges + self.stmt_entity_edges = stmt_entity_edges + self.entity_entity_edges = entity_entity_edges + self.perceptual_edges = perceptual_edges + + +async def build_graph_nodes_and_edges( + dialog_data_list: List[DialogData], + embedder_client: Any = None, + progress_callback: Optional[ + Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]] + ] = None, +) -> GraphBuildResult: + """ + 从 DialogData 列表构建完整的图节点和边。 + + Args: + dialog_data_list: 经过萃取和数据赋值后的 DialogData 列表 + embedder_client: 可选的嵌入客户端,用于 PerceptualNode summary embedding + progress_callback: 可选的进度回调 + + Returns: + GraphBuildResult 包含所有节点和边 + """ + logger.info("开始创建节点和边") + + dialogue_nodes: List[DialogueNode] = [] + chunk_nodes: List[ChunkNode] = [] + statement_nodes: List[StatementNode] = [] + entity_nodes: List[ExtractedEntityNode] = [] + perceptual_nodes: List[PerceptualNode] = [] + stmt_chunk_edges: List[StatementChunkEdge] = [] + stmt_entity_edges: List[StatementEntityEdge] = [] + entity_entity_edges: List[EntityEntityEdge] = [] + perceptual_edges: List[PerceptualEdge] = [] + + entity_id_set: set = set() + total_dialogs = len(dialog_data_list) + processed_dialogs = 0 + + for dialog_data in dialog_data_list: + processed_dialogs += 1 +# region TODO 乐力齐 重构流水线切换生产环境稳定后修改 + # ── 对话节点 ── + dialogue_node = DialogueNode( + id=dialog_data.id, + name=f"Dialog_{dialog_data.id}", + ref_id=dialog_data.ref_id, + end_user_id=dialog_data.end_user_id, + run_id=dialog_data.run_id, + content=dialog_data.context.content if dialog_data.context else "", + dialog_embedding=dialog_data.dialog_embedding if hasattr(dialog_data, "dialog_embedding") else None, + created_at=dialog_data.created_at, + expired_at=dialog_data.expired_at, + metadata=dialog_data.metadata, + config_id=dialog_data.config_id if hasattr(dialog_data, "config_id") else None, + ) + dialogue_nodes.append(dialogue_node) + + # ── 分块节点 ── + for chunk_idx, chunk in 
enumerate(dialog_data.chunks): + chunk_node = ChunkNode( + id=chunk.id, + name=f"Chunk_{chunk.id}", + dialog_id=dialog_data.id, + end_user_id=dialog_data.end_user_id, + run_id=dialog_data.run_id, + content=chunk.content, + speaker=getattr(chunk, "speaker", None), + chunk_embedding=chunk.chunk_embedding, + sequence_number=chunk_idx, + created_at=dialog_data.created_at, + expired_at=dialog_data.expired_at, + metadata=chunk.metadata, + ) + chunk_nodes.append(chunk_node) + + # ── 感知节点 ── + for p, file_type in chunk.files: + meta = p.meta_data or {} + content_meta = meta.get("content", {}) + + summary_embedding = None + if embedder_client and p.summary: + try: + summary_embedding = (await embedder_client.response([p.summary]))[0] + except Exception as emb_err: + logger.warning(f"Failed to embed perceptual summary: {emb_err}") + + perceptual = PerceptualNode( + name=f"Perceptual_{p.id}", + id=str(p.id), + end_user_id=str(p.end_user_id), + perceptual_type=p.perceptual_type, + file_path=p.file_path or "", + file_name=p.file_name or "", + file_ext=p.file_ext or "", + summary=p.summary or "", + keywords=content_meta.get("keywords", []), + topic=content_meta.get("topic", ""), + domain=content_meta.get("domain", ""), + created_at=p.created_time.isoformat() if p.created_time else None, + file_type=file_type, + summary_embedding=summary_embedding, + ) + perceptual_nodes.append(perceptual) + perceptual_edges.append( + PerceptualEdge( + source=perceptual.id, + target=chunk.id, + end_user_id=dialog_data.end_user_id, + run_id=dialog_data.run_id, + created_at=dialog_data.created_at, + ) + ) + + # ── 陈述句节点 + 边 ── + for statement in chunk.statements: + statement_node = StatementNode( + id=statement.id, + name=f"Statement_{statement.id}", + chunk_id=chunk.id, + stmt_type=getattr(statement, "stmt_type", "general"), + temporal_info=getattr(statement, "temporal_info", TemporalInfo.ATEMPORAL), + connect_strength=( + statement.connect_strength + if statement.connect_strength is not None + 
else "Strong" + ), + end_user_id=dialog_data.end_user_id, + run_id=dialog_data.run_id, + statement=statement.statement, + speaker=getattr(statement, "speaker", None), + statement_embedding=statement.statement_embedding, + valid_at=( + statement.temporal_validity.valid_at + if hasattr(statement, "temporal_validity") and statement.temporal_validity + else None + ), + invalid_at=( + statement.temporal_validity.invalid_at + if hasattr(statement, "temporal_validity") and statement.temporal_validity + else None + ), + created_at=dialog_data.created_at, + expired_at=dialog_data.expired_at, + config_id=dialog_data.config_id if hasattr(dialog_data, "config_id") else None, + emotion_type=getattr(statement, "emotion_type", None), + emotion_intensity=getattr(statement, "emotion_intensity", None), + emotion_keywords=getattr(statement, "emotion_keywords", None), + emotion_subject=getattr(statement, "emotion_subject", None), + emotion_target=getattr(statement, "emotion_target", None), + ) + statement_nodes.append(statement_node) + + stmt_chunk_edges.append( + StatementChunkEdge( + source=statement.id, + target=chunk.id, + end_user_id=dialog_data.end_user_id, + run_id=dialog_data.run_id, + created_at=dialog_data.created_at, + ) + ) + + # ── 三元组 → 实体节点 + 边 ── + if not statement.triplet_extraction_info: + continue + + triplet_info = statement.triplet_extraction_info + entity_idx_to_id: Dict[int, str] = {} + + for entity_idx, entity in enumerate(triplet_info.entities): + entity_idx_to_id[entity.entity_idx] = entity.id + entity_idx_to_id[entity_idx] = entity.id + + if entity.id not in entity_id_set: + entity_connect_strength = getattr(entity, "connect_strength", "Strong") + entity_node = ExtractedEntityNode( + id=entity.id, + name=getattr(entity, "name", f"Entity_{entity.id}"), + entity_idx=entity.entity_idx, + statement_id=statement.id, + entity_type=getattr(entity, "type", "unknown"), + description=getattr(entity, "description", ""), + example=getattr(entity, "example", ""), + 
connect_strength=( + entity_connect_strength + if entity_connect_strength is not None + else "Strong" + ), + aliases=getattr(entity, "aliases", []) or [], + name_embedding=getattr(entity, "name_embedding", None), + is_explicit_memory=getattr(entity, "is_explicit_memory", False), + end_user_id=dialog_data.end_user_id, + run_id=dialog_data.run_id, + created_at=dialog_data.created_at, + expired_at=dialog_data.expired_at, + config_id=dialog_data.config_id if hasattr(dialog_data, "config_id") else None, + ) + entity_nodes.append(entity_node) + entity_id_set.add(entity.id) + + entity_connect_strength = getattr(entity, "connect_strength", "Strong") + stmt_entity_edges.append( + StatementEntityEdge( + source=statement.id, + target=entity.id, + connect_strength=( + entity_connect_strength + if entity_connect_strength is not None + else "Strong" + ), + end_user_id=dialog_data.end_user_id, + run_id=dialog_data.run_id, + created_at=dialog_data.created_at, + ) + ) +# endregion + + for triplet in triplet_info.triplets: + subject_entity_id = entity_idx_to_id.get(triplet.subject_id) + object_entity_id = entity_idx_to_id.get(triplet.object_id) + + if subject_entity_id and object_entity_id: + entity_entity_edges.append( + EntityEntityEdge( + source=subject_entity_id, + target=object_entity_id, + relation_type=triplet.predicate, + statement=statement.statement, + source_statement_id=statement.id, + end_user_id=dialog_data.end_user_id, + run_id=dialog_data.run_id, + created_at=dialog_data.created_at, + expired_at=dialog_data.expired_at, + ) + ) + + if progress_callback and len(entity_entity_edges) <= 10: + relationship_result = { + "result_type": "relationship_creation", + "relationship_index": len(entity_entity_edges), + "source_entity": triplet.subject_name, + "relation_type": triplet.predicate, + "target_entity": triplet.object_name, + "relationship_text": f"{triplet.subject_name} -[{triplet.predicate}]-> {triplet.object_name}", + "dialog_progress": 
f"{processed_dialogs}/{total_dialogs}", + } + await progress_callback( + "creating_nodes_edges_result", + f"关系创建中 ({processed_dialogs}/{total_dialogs})", + relationship_result, + ) + else: + missing_subject = "subject" if not subject_entity_id else "" + missing_object = "object" if not object_entity_id else "" + missing_both = " and " if (not subject_entity_id and not object_entity_id) else "" + logger.debug( + f"跳过三元组 - 无法找到{missing_subject}{missing_both}{missing_object}实体ID: " + f"subject_id={triplet.subject_id} ({triplet.subject_name}), " + f"object_id={triplet.object_id} ({triplet.object_name}), " + f"predicate={triplet.predicate}, " + f"statement_id={statement.id}, " + f"available_indices={sorted(entity_idx_to_id.keys())}" + ) + + logger.info( + f"节点和边创建完成 - 对话节点: {len(dialogue_nodes)}, " + f"分块节点: {len(chunk_nodes)}, 陈述句节点: {len(statement_nodes)}, " + f"实体节点: {len(entity_nodes)}, 陈述句-分块边: {len(stmt_chunk_edges)}, " + f"陈述句-实体边: {len(stmt_entity_edges)}, " + f"实体-实体边: {len(entity_entity_edges)}" + ) + + if progress_callback: + nodes_edges_stats = { + "dialogue_nodes_count": len(dialogue_nodes), + "chunk_nodes_count": len(chunk_nodes), + "statement_nodes_count": len(statement_nodes), + "entity_nodes_count": len(entity_nodes), + "statement_chunk_edges_count": len(stmt_chunk_edges), + "statement_entity_edges_count": len(stmt_entity_edges), + "entity_entity_edges_count": len(entity_entity_edges), + } + await progress_callback("creating_nodes_edges_complete", "创建节点和边完成", nodes_edges_stats) + + return GraphBuildResult( + dialogue_nodes=dialogue_nodes, + chunk_nodes=chunk_nodes, + statement_nodes=statement_nodes, + entity_nodes=entity_nodes, + perceptual_nodes=perceptual_nodes, + stmt_chunk_edges=stmt_chunk_edges, + stmt_entity_edges=stmt_entity_edges, + entity_entity_edges=entity_entity_edges, + perceptual_edges=perceptual_edges, + ) diff --git a/api/app/core/memory/storage_services/extraction_engine/steps/schema/__init__.py 
b/api/app/core/memory/storage_services/extraction_engine/steps/schema/__init__.py new file mode 100644 index 00000000..0223b860 --- /dev/null +++ b/api/app/core/memory/storage_services/extraction_engine/steps/schema/__init__.py @@ -0,0 +1,42 @@ +"""Schema package for ExtractionStep inputs and outputs. + +Re-exports all models for convenient access: + from .schema import StatementStepInput, EmotionStepOutput, ... +""" + +from .extraction_step_schema import ( + EmbeddingStepInput, + EmbeddingStepOutput, + EntityItem, + MessageItem, + StatementStepInput, + StatementStepOutput, + SupportingContext, + TripletItem, + TripletStepInput, + TripletStepOutput, +) +from .sidecar_step_schema import ( + EmotionStepInput, + EmotionStepOutput, +) + +__all__ = [ + # Shared + "MessageItem", + "SupportingContext", + # Statement + "StatementStepInput", + "StatementStepOutput", + # Triplet + "TripletStepInput", + "TripletStepOutput", + "EntityItem", + "TripletItem", + # Embedding + "EmbeddingStepInput", + "EmbeddingStepOutput", + # Sidecar — Emotion + "EmotionStepInput", + "EmotionStepOutput", +] diff --git a/api/app/core/memory/storage_services/extraction_engine/steps/schema/extraction_step_schema.py b/api/app/core/memory/storage_services/extraction_engine/steps/schema/extraction_step_schema.py new file mode 100644 index 00000000..a4dad6d5 --- /dev/null +++ b/api/app/core/memory/storage_services/extraction_engine/steps/schema/extraction_step_schema.py @@ -0,0 +1,115 @@ +"""Pydantic models for base extraction pipeline inputs and outputs. + +Covers the core (critical) stages: Statement extraction, Triplet extraction, +Embedding generation, and shared types used across stages. + +Malformed LLM JSON will raise ``ValidationError`` and trigger stage-level retry. 
+""" + +from typing import Dict, List +from pydantic import BaseModel, Field + + +# ── Shared types ── + + +class MessageItem(BaseModel): + """Single conversation message.""" + + role: str # "User" / "Assistant" + msg: str + + +class SupportingContext(BaseModel): + """Dialogue context window (used for pronoun resolution, etc.).""" + + msgs: List[MessageItem] = Field(default_factory=list) + + +# ── Statement extraction ── +class StatementStepInput(BaseModel): + """Input for StatementExtractionStep.""" + + chunk_id: str + end_user_id: str + target_content: str + target_message_date: str + supporting_context: SupportingContext + + +class StatementStepOutput(BaseModel): + """Single extracted statement (including temporal info).""" + + statement_id: str + statement_text: str + statement_type: str # FACT / OPINION / PREDICTION / SUGGESTION + temporal_type: str # STATIC / DYNAMIC / ATEMPORAL + relevance: str # RELEVANT / IRRELEVANT + speaker: str # "user" / "assistant" + valid_at: str # ISO 8601 or "NULL" + invalid_at: str # ISO 8601 or "NULL" + + +# ── Triplet extraction ── +class TripletStepInput(BaseModel): + """Input for TripletExtractionStep.""" + + statement_id: str + statement_text: str + statement_type: str + temporal_type: str + supporting_context: SupportingContext + speaker: str + valid_at: str + invalid_at: str + + +class EntityItem(BaseModel): + """Single entity extracted during triplet extraction.""" + + entity_idx: int + name: str + type: str + description: str + is_explicit_memory: bool = False + + +class TripletItem(BaseModel): + """Single triplet (subject-predicate-object) relationship.""" + + subject_name: str + subject_id: int + predicate: str + object_name: str + object_id: int + + +class TripletStepOutput(BaseModel): + """Output of TripletExtractionStep.""" + + entities: List[EntityItem] = Field(default_factory=list) + triplets: List[TripletItem] = Field(default_factory=list) + + +# ── Embedding generation ── +class EmbeddingStepInput(BaseModel): + 
"""Input for EmbeddingStep. + + Each dict maps an ID to the text that should be embedded. + Fields can be left empty for partial embedding runs. + """ + + statement_texts: Dict[str, str] = Field(default_factory=dict) + chunk_texts: Dict[str, str] = Field(default_factory=dict) + dialog_texts: List[str] = Field(default_factory=list) + entity_names: Dict[str, str] = Field(default_factory=dict) + entity_descriptions: Dict[str, str] = Field(default_factory=dict) + + +class EmbeddingStepOutput(BaseModel): + """Output of EmbeddingStep.""" + + statement_embeddings: Dict[str, List[float]] = Field(default_factory=dict) + chunk_embeddings: Dict[str, List[float]] = Field(default_factory=dict) + dialog_embeddings: List[List[float]] = Field(default_factory=list) + entity_embeddings: Dict[str, List[float]] = Field(default_factory=dict) diff --git a/api/app/core/memory/storage_services/extraction_engine/steps/schema/sidecar_step_schema.py b/api/app/core/memory/storage_services/extraction_engine/steps/schema/sidecar_step_schema.py new file mode 100644 index 00000000..78cb0982 --- /dev/null +++ b/api/app/core/memory/storage_services/extraction_engine/steps/schema/sidecar_step_schema.py @@ -0,0 +1,26 @@ +"""Pydantic models for hot-pluggable sidecar step inputs and outputs. + +Sidecar steps are non-critical (is_critical=False) modules registered via +``@SidecarStepFactory.register`` that run concurrently alongside the main +extraction pipeline. Failures degrade gracefully to default outputs. 
logger = logging.getLogger(__name__)


class SidecarTiming(str, Enum):
    """Declares when a sidecar step runs relative to the main pipeline."""

    AFTER_STATEMENT = "after_statement"
    AFTER_TRIPLET = "after_triplet"


class SidecarStepFactory:
    """Factory that manages sidecar step registration and creation.

    Registry maps ``config_key`` → ``(step_class, timing)``.
    Adding a new sidecar only requires the ``@register`` decorator on the
    step class — no orchestrator modifications needed.

    The registry is a class-level dict shared by every user of the factory;
    use :meth:`clear_registry` to reset it (e.g. between tests).
    """

    _registry: Dict[str, Tuple[Type[ExtractionStep], SidecarTiming]] = {}

    @classmethod
    def register(cls, config_key: str, timing: SidecarTiming):
        """Class decorator that registers a sidecar step.

        Args:
            config_key: Configuration flag name (e.g. ``"emotion_enabled"``).
                The step is instantiated only when this flag is truthy.
            timing: When the sidecar runs relative to the main pipeline.
                Either a :class:`SidecarTiming` member or its string value.

        Returns:
            A decorator that records the class and returns it unmodified.

        Raises:
            ValueError: If ``timing`` is not a valid :class:`SidecarTiming`.
        """
        # Validate before touching the registry: a bad timing must fail at
        # decoration time instead of leaving an entry that breaks
        # ``create_sidecars`` later.
        timing = SidecarTiming(timing)

        def decorator(step_class: Type[ExtractionStep]):
            previous = cls._registry.get(config_key)
            if previous is not None and previous[0] is not step_class:
                # Last registration still wins, but no longer silently.
                logger.warning(
                    "Sidecar config_key '%s' re-registered: '%s' replaces '%s'",
                    config_key,
                    step_class.__name__,
                    previous[0].__name__,
                )
            cls._registry[config_key] = (step_class, timing)
            logger.debug(
                "Registered sidecar '%s' (config_key=%s, timing=%s)",
                step_class.__name__,
                config_key,
                timing.value,
            )
            return step_class

        return decorator

    @classmethod
    def create_sidecars(
        cls, config: Any, context: StepContext
    ) -> Dict[SidecarTiming, List[ExtractionStep]]:
        """Instantiate enabled sidecar steps, grouped by timing.

        Args:
            config: Pipeline configuration object. Each registered
                ``config_key`` is looked up via ``getattr(config, key, False)``.
            context: Shared :class:`StepContext` injected into every step.

        Returns:
            A dict with one key per :class:`SidecarTiming` member (always
            present, possibly mapping to an empty list), each value holding
            the instantiated sidecar steps whose config flag is truthy.
        """
        grouped: Dict[SidecarTiming, List[ExtractionStep]] = {
            timing: [] for timing in SidecarTiming
        }
        for config_key, (step_class, timing) in cls._registry.items():
            if not getattr(config, config_key, False):
                logger.debug(
                    "Skipped sidecar '%s' (config_key=%s is disabled)",
                    step_class.__name__,
                    config_key,
                )
                continue
            grouped[timing].append(step_class(context))
            logger.debug(
                "Created sidecar '%s' (timing=%s)",
                step_class.__name__,
                timing.value,
            )
        return grouped

    @classmethod
    def clear_registry(cls) -> None:
        """Remove all registered sidecars. Useful for testing."""
        cls._registry.clear()
class StatementExtractionStep(ExtractionStep[StatementStepInput, List[StatementStepOutput]]):
    """Extract atomic statements (with temporal info) from a dialogue chunk.

    Marked as a critical step (``is_critical`` is ``True``).

    Attributes bound at init:
        * ``definitions`` — the ``LABEL_DEFINITIONS`` ontology constant
        * ``json_schema`` — JSON schema of ``_ExtractedStatement``
        * ``granularity`` — ``config.statement_extraction.statement_granularity``
          (``None`` when absent)
        * ``include_dialogue_context`` — from the same config section,
          default ``True``
        * ``max_dialogue_context_chars`` — from the same config section,
          default ``2000``
    """

    def __init__(self, context: StepContext) -> None:
        super().__init__(context)
        # The config section may be missing entirely; getattr on None then
        # simply yields each default below.
        section = getattr(self.config, "statement_extraction", None)
        self.definitions = LABEL_DEFINITIONS
        # NOTE(review): this is the schema of a single statement, while the
        # LLM reply is parsed as _StatementExtractionResponse — confirm the
        # prompt template expects the per-statement schema.
        self.json_schema = _ExtractedStatement.model_json_schema()
        self.granularity = getattr(section, "statement_granularity", None)
        self.include_dialogue_context = getattr(section, "include_dialogue_context", True)
        self.max_dialogue_context_chars = getattr(section, "max_dialogue_context_chars", 2000)

    # ── Identity ──

    @property
    def name(self) -> str:
        """Stable identifier of this step."""
        return "statement_extraction"

    @property
    def is_critical(self) -> bool:
        """Always ``True`` for statement extraction."""
        return True

    # ── Lifecycle ──

    async def render_prompt(self, input_data: StatementStepInput) -> str:
        """Render the statement-extraction prompt for one chunk.

        Dialogue context is passed as ``"role: msg"`` lines only when it is
        enabled and the supporting context is non-empty; otherwise ``None``.
        """
        if self.include_dialogue_context and input_data.supporting_context.msgs:
            context_lines = [
                f"{m.role}: {m.msg}" for m in input_data.supporting_context.msgs
            ]
            dialogue_content = "\n".join(context_lines)
        else:
            dialogue_content = None

        prompt = await render_statement_extraction_prompt(
            chunk_content=input_data.target_content,
            definitions=self.definitions,
            json_schema=self.json_schema,
            granularity=self.granularity,
            include_dialogue_context=self.include_dialogue_context,
            dialogue_content=dialogue_content,
            max_dialogue_chars=self.max_dialogue_context_chars,
            language=self.language,
        )
        return prompt

    async def call_llm(self, prompt: Any) -> Any:
        """Send the prompt with a fixed system message; parse a structured reply."""
        system_message = {
            "role": "system",
            "content": (
                "You are an expert at extracting and labeling atomic statements "
                "from conversational text. Return valid JSON conforming to the schema."
            ),
        }
        user_message = {"role": "user", "content": prompt}
        return await self.llm_client.response_structured(
            [system_message, user_message], _StatementExtractionResponse
        )

    async def parse_response(
        self, raw_response: Any, input_data: StatementStepInput
    ) -> List[StatementStepOutput]:
        """Convert the structured LLM reply into ``StatementStepOutput`` items.

        Returns an empty list when the reply has no ``statements`` attribute
        or it is ``None``. Each statement gets a fresh UUID hex id, and its
        label fields are stripped and upper-cased.
        """
        extracted = getattr(raw_response, "statements", None)
        if extracted is None:
            return []

        return [
            StatementStepOutput(
                statement_id=uuid.uuid4().hex,
                statement_text=item.statement,
                statement_type=item.statement_type.strip().upper(),
                temporal_type=item.temporal_type.strip().upper(),
                relevance=item.relevance.strip().upper(),
                speaker="user",  # default; orchestrator overrides from chunk metadata
                valid_at=item.valid_at or "NULL",
                invalid_at=item.invalid_at or "NULL",
            )
            for item in extracted
        ]

    def get_default_output(self) -> List[StatementStepOutput]:
        """Default output of this step: no statements."""
        return []
class TripletExtractionStep(ExtractionStep[TripletStepInput, TripletStepOutput]):
    """Extract knowledge triplets and entities from a single statement.

    Marked as a critical step (``is_critical`` is ``True``).

    Attributes bound at init:
        * ``ontology_types`` — constructor argument, forwarded to the prompt
        * ``predicate_instructions`` — the ``PREDICATE_DEFINITIONS`` constant
        * ``json_schema`` — JSON schema of ``TripletExtractionResponse``
    """

    def __init__(
        self,
        context: StepContext,
        ontology_types: Any = None,
    ) -> None:
        super().__init__(context)
        self.ontology_types = ontology_types
        self.predicate_instructions = PREDICATE_DEFINITIONS
        self.json_schema = TripletExtractionResponse.model_json_schema()
        # Built once so the whitelist check in parse_response is a set lookup.
        self._allowed_predicates = {p.value for p in Predicate}

    # ── Identity ──

    @property
    def name(self) -> str:
        """Stable identifier of this step."""
        return "triplet_extraction"

    @property
    def is_critical(self) -> bool:
        """Always ``True`` for triplet extraction."""
        return True

    # ── Lifecycle ──

    async def render_prompt(self, input_data: TripletStepInput) -> str:
        """Render the triplet-extraction prompt for one statement."""
        # Build chunk_content from supporting_context for pronoun resolution;
        # an empty context yields an empty string.
        chunk_content = "\n".join(
            f"{m.role}: {m.msg}" for m in input_data.supporting_context.msgs
        ) if input_data.supporting_context.msgs else ""

        return await render_triplet_extraction_prompt(
            statement=input_data.statement_text,
            chunk_content=chunk_content,
            json_schema=self.json_schema,
            predicate_instructions=self.predicate_instructions,
            language=self.language,
            ontology_types=self.ontology_types,
            speaker=input_data.speaker,
        )

    async def call_llm(self, prompt: Any) -> Any:
        """Send the prompt with a fixed system message; parse a structured reply."""
        messages = [
            {
                "role": "system",
                "content": (
                    "You are an expert at extracting knowledge triplets and entities "
                    "from text. Follow the provided instructions carefully and return valid JSON."
                ),
            },
            {"role": "user", "content": prompt},
        ]
        return await self.llm_client.response_structured(
            messages, TripletExtractionResponse
        )

    async def parse_response(
        self, raw_response: Any, input_data: TripletStepInput
    ) -> TripletStepOutput:
        """Convert the structured LLM reply into a ``TripletStepOutput``.

        Triplets whose predicate is not in the ontology whitelist are dropped.
        A reply without a ``triplets`` attribute yields the default (empty)
        output; ``None`` or missing ``triplets``/``entities`` values are
        treated as empty lists.
        """
        if not hasattr(raw_response, "triplets"):
            return self.get_default_output()

        # Guard against None the same way entities are guarded below.
        raw_triplets = raw_response.triplets or []

        # Filter triplets to allowed predicates from ontology whitelist
        allowed = [
            t for t in raw_triplets
            if getattr(t, "predicate", "") in self._allowed_predicates
        ]
        dropped = len(raw_triplets) - len(allowed)
        if dropped:
            # Keep the filtering observable without raising the log level.
            logger.debug(
                "Dropped %d triplet(s) with non-whitelisted predicates (statement_id=%s)",
                dropped,
                input_data.statement_id,
            )

        filtered_triplets = [
            TripletItem(
                subject_name=t.subject_name,
                subject_id=t.subject_id,
                predicate=t.predicate,
                object_name=t.object_name,
                object_id=t.object_id,
            )
            for t in allowed
        ]

        entities = [
            EntityItem(
                entity_idx=e.entity_idx,
                name=e.name,
                type=e.type,
                description=e.description,
                is_explicit_memory=getattr(e, "is_explicit_memory", False),
            )
            for e in (getattr(raw_response, "entities", None) or [])
        ]

        return TripletStepOutput(entities=entities, triplets=filtered_triplets)

    def get_default_output(self) -> TripletStepOutput:
        """Default output of this step: no entities and no triplets."""
        return TripletStepOutput(entities=[], triplets=[])
a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index a4752ba9..9f4875ed 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -360,40 +360,64 @@ class MemoryAgentService: await write_rag(end_user_id, message_text, user_rag_memory_id) return "success" else: - await write_neo4j( - end_user_id=end_user_id, - messages=messages, - memory_config=memory_config, - ref_id='', - language=language - ) - - # ── 影子运行:新流水线静默执行,只记录日志不影响主流程 ── + # TODO 乐力齐 重构流水线切换至生产环境后,更改如下代码 import os - if os.getenv("SHADOW_PIPELINE_ENABLED", "false").lower() == "true": - try: - from app.core.memory.memory_service import MemoryService - import copy + use_new_pipeline = os.getenv("NEW_PIPELINE_ENABLED", "false").lower() == "true" - shadow_messages = copy.deepcopy(messages) - shadow_service = MemoryService( - memory_config=memory_config, - end_user_id=end_user_id, - ) - shadow_result = await shadow_service.write( - messages=shadow_messages, - language=language, - ref_id='', - is_pilot_run=True, # 试运行模式:只萃取不写入,避免重复写入 Neo4j - ) - logger.info( - f"[Shadow] 新流水线影子运行完成: status={shadow_result.status}, " - f"elapsed={shadow_result.elapsed_seconds:.2f}s, " - f"extraction={shadow_result.extraction}" - ) - except Exception as shadow_err: - logger.warning(f"[Shadow] 新流水线影子运行失败(不影响主流程): {shadow_err}") - # ── 影子运行结束 ── + if use_new_pipeline: + # ── 新流水线:WritePipeline + NewExtractionOrchestrator ── + from app.core.memory.memory_service import MemoryService + + service = MemoryService( + memory_config=memory_config, + end_user_id=end_user_id, + ) + result = await service.write( + messages=messages, + language=language, + ref_id='', + is_pilot_run=False, + ) + logger.info( + f"[NewPipeline] 完成: status={result.status}, " + f"elapsed={result.elapsed_seconds:.2f}s, " + f"extraction={result.extraction}" + ) + else: + # ── 旧流水线:write_tools.write() + ExtractionOrchestrator ── + await write_neo4j( + end_user_id=end_user_id, + 
messages=messages, + memory_config=memory_config, + ref_id='', + language=language + ) + + # ── 影子运行:新流水线静默执行,只记录日志不影响主流程 ── + if os.getenv("SHADOW_PIPELINE_ENABLED", "false").lower() == "true": + try: + from app.core.memory.memory_service import MemoryService + import copy + + shadow_messages = copy.deepcopy(messages) + shadow_service = MemoryService( + memory_config=memory_config, + end_user_id=end_user_id, + ) + shadow_result = await shadow_service.write( + messages=shadow_messages, + language=language, + ref_id='', + is_pilot_run=True, + ) + logger.info( + f"[Shadow] 新流水线影子运行完成: status={shadow_result.status}, " + f"elapsed={shadow_result.elapsed_seconds:.2f}s, " + f"extraction={shadow_result.extraction}" + ) + except Exception as shadow_err: + logger.warning(f"[Shadow] 新流水线影子运行失败(不影响主流程): {shadow_err}") + # ── 影子运行结束 ── for lang in ["zh", "en"]: deleted = await InterestMemoryCache.delete_interest_distribution( end_user_id, lang diff --git a/api/app/services/memory_config_service.py b/api/app/services/memory_config_service.py index 4e80383c..921ee03f 100644 --- a/api/app/services/memory_config_service.py +++ b/api/app/services/memory_config_service.py @@ -418,6 +418,9 @@ class MemoryConfigService: pruning_scene=memory_config.pruning_scene or "education", pruning_threshold=float( memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5, + # Pipeline config: Emotion extraction + emotion_enabled=bool( + memory_config.emotion_enabled) if memory_config.emotion_enabled is not None else False, # Ontology scene association scene_id=memory_config.scene_id, ontology_class_infos=_load_ontology_class_infos(self.db, memory_config.scene_id), @@ -573,6 +576,7 @@ class MemoryConfigService: statement_extraction=stmt_config, deduplication=dedup_config, forgetting_engine=forget_config, + emotion_enabled=getattr(memory_config, "emotion_enabled", False), ) @staticmethod