feat(memory): implement step-based extraction pipeline architecture
Introduce ExtractionStep abstraction with modular pipeline stages: - Add base ExtractionStep class with render/call/parse lifecycle - Implement StatementExtractionStep, TripletExtractionStep, EmbeddingStep, EmotionStep, GraphBuildStep, and DedupStep - Add SidecarStepFactory for hot-pluggable non-critical steps - Define Pydantic I/O schemas for all pipeline stages - Refactor WritePipeline to orchestrate new step-based flow - Add NEW_PIPELINE_ENABLED env switch for old/new pipeline routing - Add emotion_enabled config flag to MemoryConfig - Fix workspace_id reference in get_end_user_connected_config
This commit is contained in:
@@ -1,7 +1,4 @@
|
|||||||
import os
|
|
||||||
import json
|
|
||||||
from typing import List
|
from typing import List
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import DialogueChunker
|
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import DialogueChunker
|
||||||
from app.core.memory.models.message_models import DialogData, ConversationContext, ConversationMessage
|
from app.core.memory.models.message_models import DialogData, ConversationContext, ConversationMessage
|
||||||
@@ -34,6 +31,7 @@ async def get_chunked_dialogs(
|
|||||||
|
|
||||||
conversation_messages = []
|
conversation_messages = []
|
||||||
|
|
||||||
|
# step1: 消息格式校验 role:user、assistant。content
|
||||||
for idx, msg in enumerate(messages):
|
for idx, msg in enumerate(messages):
|
||||||
if not isinstance(msg, dict) or 'role' not in msg or 'content' not in msg:
|
if not isinstance(msg, dict) or 'role' not in msg or 'content' not in msg:
|
||||||
raise ValueError(f"Message {idx} format error: must contain 'role' and 'content' fields")
|
raise ValueError(f"Message {idx} format error: must contain 'role' and 'content' fields")
|
||||||
@@ -59,7 +57,7 @@ async def get_chunked_dialogs(
|
|||||||
config_id=config_id
|
config_id=config_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# 语义剪枝步骤(在分块之前)
|
# step2: 语义剪枝步骤(在分块之前)
|
||||||
try:
|
try:
|
||||||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import SemanticPruner
|
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import SemanticPruner
|
||||||
from app.core.memory.models.config_models import PruningConfig
|
from app.core.memory.models.config_models import PruningConfig
|
||||||
@@ -116,6 +114,7 @@ async def get_chunked_dialogs(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"[剪枝] 执行失败,跳过剪枝: {e}", exc_info=True)
|
logger.warning(f"[剪枝] 执行失败,跳过剪枝: {e}", exc_info=True)
|
||||||
|
|
||||||
|
# step3: 分块
|
||||||
chunker = DialogueChunker(chunker_strategy)
|
chunker = DialogueChunker(chunker_strategy)
|
||||||
extracted_chunks = await chunker.process_dialogue(dialog_data)
|
extracted_chunks = await chunker.process_dialogue(dialog_data)
|
||||||
dialog_data.chunks = extracted_chunks
|
dialog_data.chunks = extracted_chunks
|
||||||
|
|||||||
@@ -148,6 +148,84 @@ async def write(
|
|||||||
all_dedup_details,
|
all_dedup_details,
|
||||||
) = await orchestrator.run(chunked_dialogs, is_pilot_run=False)
|
) = await orchestrator.run(chunked_dialogs, is_pilot_run=False)
|
||||||
|
|
||||||
|
# region TODO 乐力齐 重构流水线切换至生产环境稳定后,移除快照对比代码
|
||||||
|
# ── Snapshot: 旧流水线萃取结果(按 phase2_step_io_schema_v1.md 格式) ──
|
||||||
|
from app.core.memory.utils.debug.pipeline_snapshot import PipelineSnapshot
|
||||||
|
snapshot = PipelineSnapshot("legacy")
|
||||||
|
|
||||||
|
# Statement 输出(从 dialog_data_list 中提取)
|
||||||
|
stmt_snapshot = []
|
||||||
|
for d in all_dedup_details:
|
||||||
|
if not hasattr(d, "chunks"):
|
||||||
|
continue
|
||||||
|
for c in d.chunks:
|
||||||
|
for s in c.statements:
|
||||||
|
stmt_snapshot.append({
|
||||||
|
"statement_id": s.id,
|
||||||
|
"statement_text": s.statement,
|
||||||
|
"statement_type": str(getattr(s, "stmt_type", "")),
|
||||||
|
"temporal_type": str(getattr(s, "temporal_info", "")),
|
||||||
|
"relevance": str(getattr(s, "relevence_info", "RELEVANT")),
|
||||||
|
"speaker": getattr(s, "speaker", "user") or "user",
|
||||||
|
"valid_at": s.temporal_validity.valid_at if s.temporal_validity else "NULL",
|
||||||
|
"invalid_at": s.temporal_validity.invalid_at if s.temporal_validity else "NULL",
|
||||||
|
})
|
||||||
|
snapshot.save_stage("2_statement_outputs", stmt_snapshot)
|
||||||
|
|
||||||
|
# Triplet 输出(从 dialog_data_list 中提取)
|
||||||
|
triplet_snapshot = {}
|
||||||
|
for d in all_dedup_details:
|
||||||
|
if not hasattr(d, "chunks"):
|
||||||
|
continue
|
||||||
|
for c in d.chunks:
|
||||||
|
for s in c.statements:
|
||||||
|
if s.triplet_extraction_info:
|
||||||
|
triplet_snapshot[s.id] = {
|
||||||
|
"entities": [
|
||||||
|
{
|
||||||
|
"entity_idx": e.entity_idx, "name": e.name,
|
||||||
|
"type": e.type, "description": e.description,
|
||||||
|
"is_explicit_memory": getattr(e, "is_explicit_memory", False),
|
||||||
|
}
|
||||||
|
for e in s.triplet_extraction_info.entities
|
||||||
|
],
|
||||||
|
"triplets": [
|
||||||
|
{
|
||||||
|
"subject_name": t.subject_name, "subject_id": t.subject_id,
|
||||||
|
"predicate": t.predicate,
|
||||||
|
"object_name": t.object_name, "object_id": t.object_id,
|
||||||
|
}
|
||||||
|
for t in s.triplet_extraction_info.triplets
|
||||||
|
],
|
||||||
|
}
|
||||||
|
snapshot.save_stage("3_triplet_outputs", triplet_snapshot)
|
||||||
|
|
||||||
|
# 图节点和边(去重后)
|
||||||
|
snapshot.save_stage("6_nodes_edges_after_dedup", {
|
||||||
|
"dialogue_nodes_count": len(all_dialogue_nodes),
|
||||||
|
"chunk_nodes_count": len(all_chunk_nodes),
|
||||||
|
"statement_nodes_count": len(all_statement_nodes),
|
||||||
|
"entity_nodes": [
|
||||||
|
{"id": e.id, "name": e.name, "entity_type": e.entity_type, "description": e.description}
|
||||||
|
for e in all_entity_nodes
|
||||||
|
],
|
||||||
|
"entity_entity_edges": [
|
||||||
|
{
|
||||||
|
"source": e.source, "target": e.target,
|
||||||
|
"relation_type": e.relation_type, "statement": e.statement,
|
||||||
|
}
|
||||||
|
for e in all_entity_entity_edges
|
||||||
|
],
|
||||||
|
})
|
||||||
|
snapshot.save_summary({
|
||||||
|
"dialogue_count": len(all_dialogue_nodes),
|
||||||
|
"chunk_count": len(all_chunk_nodes),
|
||||||
|
"statement_count": len(all_statement_nodes),
|
||||||
|
"entity_count": len(all_entity_nodes),
|
||||||
|
"relation_count": len(all_entity_entity_edges),
|
||||||
|
})
|
||||||
|
# endregion
|
||||||
|
|
||||||
log_time("Extraction Pipeline", time.time() - step_start, log_file)
|
log_time("Extraction Pipeline", time.time() - step_start, log_file)
|
||||||
|
|
||||||
# Step 3: Save all data to Neo4j database
|
# Step 3: Save all data to Neo4j database
|
||||||
|
|||||||
@@ -149,3 +149,5 @@ class ExtractionPipelineConfig(BaseModel):
|
|||||||
temporal_extraction: TemporalExtractionConfig = Field(default_factory=TemporalExtractionConfig)
|
temporal_extraction: TemporalExtractionConfig = Field(default_factory=TemporalExtractionConfig)
|
||||||
deduplication: DedupConfig = Field(default_factory=DedupConfig)
|
deduplication: DedupConfig = Field(default_factory=DedupConfig)
|
||||||
forgetting_engine: ForgettingEngineConfig = Field(default_factory=ForgettingEngineConfig)
|
forgetting_engine: ForgettingEngineConfig = Field(default_factory=ForgettingEngineConfig)
|
||||||
|
# 情绪引擎(旁路模块,SidecarStepFactory 通过此字段判断是否启用)
|
||||||
|
emotion_enabled: bool = Field(default=False, description="是否启用情绪提取旁路")
|
||||||
|
|||||||
@@ -12,20 +12,33 @@ WritePipeline — 记忆写入流水线
|
|||||||
|
|
||||||
依赖方向:Facade → Pipeline → Engine → Repository(单向,不允许反向调用)
|
依赖方向:Facade → Pipeline → Engine → Repository(单向,不允许反向调用)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional
|
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, ConfigDict
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from app.core.memory.models.graph_models import ExtractedEntityNode
|
|
||||||
from app.core.memory.models.message_models import DialogData
|
from app.core.memory.models.message_models import DialogData
|
||||||
from app.schemas.memory_config_schema import MemoryConfig
|
from app.schemas.memory_config_schema import MemoryConfig
|
||||||
|
|
||||||
|
from app.core.memory.models.graph_models import (
|
||||||
|
ChunkNode,
|
||||||
|
DialogueNode,
|
||||||
|
EntityEntityEdge,
|
||||||
|
ExtractedEntityNode,
|
||||||
|
PerceptualEdge,
|
||||||
|
PerceptualNode,
|
||||||
|
StatementChunkEdge,
|
||||||
|
StatementEntityEdge,
|
||||||
|
StatementNode,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -34,36 +47,40 @@ logger = logging.getLogger(__name__)
|
|||||||
# ──────────────────────────────────────────────
|
# ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
class ExtractionResult(BaseModel):
|
||||||
class ExtractionResult:
|
"""萃取 + 图构建 + 去重消歧后的结构化输出。
|
||||||
"""萃取步骤的结构化输出,替代 ExtractionOrchestrator.run() 返回的裸元组。
|
|
||||||
|
|
||||||
字段与 ExtractionOrchestrator.run() 的 10 元素返回值一一对应:
|
作为 Pipeline 层的阶段间数据载体,确保下游步骤(_store、_cluster)
|
||||||
[0] dialogue_nodes → self.dialogue_nodes
|
接收到的图节点和边结构完整、类型正确。
|
||||||
[1] chunk_nodes → self.chunk_nodes
|
|
||||||
[2] statement_nodes → self.statement_nodes
|
|
||||||
[3] entity_nodes → self.entity_nodes
|
|
||||||
[4] perceptual_nodes → self.perceptual_nodes
|
|
||||||
[5] stmt_chunk_edges → self.stmt_chunk_edges
|
|
||||||
[6] stmt_entity_edges → self.stmt_entity_edges
|
|
||||||
[7] entity_entity_edges → self.entity_entity_edges
|
|
||||||
[8] perceptual_edges → self.perceptual_edges
|
|
||||||
[9] dialog_data_list → self.dialog_data_list
|
|
||||||
|
|
||||||
注意:字段类型使用 List[Any] 而非具体的 graph_models 类型,
|
字段对应 ExtractionOrchestrator 产出的图节点/边:
|
||||||
避免在模块加载时触发循环依赖。Pipeline 只做数据传递,不检查具体类型。
|
dialogue_nodes — 对话节点
|
||||||
|
chunk_nodes — 分块节点
|
||||||
|
statement_nodes — 陈述句节点
|
||||||
|
entity_nodes — 实体节点(去重消歧后)
|
||||||
|
perceptual_nodes — 感知节点
|
||||||
|
stmt_chunk_edges — 陈述句 → 分块 边
|
||||||
|
stmt_entity_edges — 陈述句 → 实体 边
|
||||||
|
entity_entity_edges — 实体 → 实体 边(去重消歧后)
|
||||||
|
perceptual_edges — 感知 → 分块 边
|
||||||
|
dialog_data_list — 原始 DialogData(供摘要阶段使用)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
dialogue_nodes: List[Any]
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
chunk_nodes: List[Any]
|
|
||||||
statement_nodes: List[Any]
|
dialogue_nodes: List[DialogueNode]
|
||||||
entity_nodes: List[Any]
|
chunk_nodes: List[ChunkNode]
|
||||||
perceptual_nodes: List[Any]
|
statement_nodes: List[StatementNode]
|
||||||
stmt_chunk_edges: List[Any]
|
entity_nodes: List[ExtractedEntityNode]
|
||||||
stmt_entity_edges: List[Any]
|
perceptual_nodes: List[PerceptualNode]
|
||||||
entity_entity_edges: List[Any]
|
stmt_chunk_edges: List[StatementChunkEdge]
|
||||||
perceptual_edges: List[Any]
|
stmt_entity_edges: List[StatementEntityEdge]
|
||||||
dialog_data_list: List[Any]
|
entity_entity_edges: List[EntityEntityEdge]
|
||||||
|
perceptual_edges: List[PerceptualEdge]
|
||||||
|
dialog_data_list: List[Any] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="原始 DialogData 列表,类型为 Any 以避免循环依赖",
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def stats(self) -> Dict[str, int]:
|
def stats(self) -> Dict[str, int]:
|
||||||
@@ -78,8 +95,7 @@ class ExtractionResult:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
class WriteResult(BaseModel):
|
||||||
class WriteResult:
|
|
||||||
"""写入流水线的最终输出,返回给 MemoryService / MemoryAgentService"""
|
"""写入流水线的最终输出,返回给 MemoryService / MemoryAgentService"""
|
||||||
|
|
||||||
status: str # "success" | "pilot_complete" | "failed"
|
status: str # "success" | "pilot_complete" | "failed"
|
||||||
@@ -114,7 +130,7 @@ class WritePipeline:
|
|||||||
memory_config: 不可变的记忆配置对象(从数据库加载)
|
memory_config: 不可变的记忆配置对象(从数据库加载)
|
||||||
end_user_id: 终端用户 ID
|
end_user_id: 终端用户 ID
|
||||||
language: 语言 ("zh" | "en")
|
language: 语言 ("zh" | "en")
|
||||||
progress_callback: 可选的进度回调,签名 (stage, message, data?) -> Awaitable[None]
|
progress_callback: 可选的进度回调,签名 (stage, message, data?) -> Awaitable[None] 供pilot run使用
|
||||||
"""
|
"""
|
||||||
self.memory_config = memory_config
|
self.memory_config = memory_config
|
||||||
self.end_user_id = end_user_id
|
self.end_user_id = end_user_id
|
||||||
@@ -164,7 +180,7 @@ class WritePipeline:
|
|||||||
self._init_clients()
|
self._init_clients()
|
||||||
self._init_neo4j_connector()
|
self._init_neo4j_connector()
|
||||||
|
|
||||||
# Step 1: 预处理 - 消息分块
|
# Step 1: 预处理 - 消息分块 + AI消息语义剪枝(暂无实现)
|
||||||
step_start = time.time()
|
step_start = time.time()
|
||||||
chunked_dialogs = await self._preprocess(messages, ref_id)
|
chunked_dialogs = await self._preprocess(messages, ref_id)
|
||||||
chunks_count = sum(len(d.chunks) for d in chunked_dialogs)
|
chunks_count = sum(len(d.chunks) for d in chunked_dialogs)
|
||||||
@@ -175,9 +191,7 @@ class WritePipeline:
|
|||||||
|
|
||||||
# Step 2: 萃取 - 知识提取
|
# Step 2: 萃取 - 知识提取
|
||||||
step_start = time.time()
|
step_start = time.time()
|
||||||
extraction_result = await self._extract(
|
extraction_result = await self._extract(chunked_dialogs, is_pilot_run)
|
||||||
chunked_dialogs, is_pilot_run
|
|
||||||
)
|
|
||||||
stats = extraction_result.stats
|
stats = extraction_result.stats
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[WritePipeline] [2/5] 萃取:知识提取 "
|
f"[WritePipeline] [2/5] 萃取:知识提取 "
|
||||||
@@ -190,9 +204,7 @@ class WritePipeline:
|
|||||||
# 试运行模式到此结束
|
# 试运行模式到此结束
|
||||||
if is_pilot_run:
|
if is_pilot_run:
|
||||||
elapsed = time.time() - pipeline_start
|
elapsed = time.time() - pipeline_start
|
||||||
logger.info(
|
logger.info(f"[WritePipeline] 完成(试运行) ✔ {elapsed:.2f}s")
|
||||||
f"[WritePipeline] 完成(试运行) ✔ {elapsed:.2f}s"
|
|
||||||
)
|
|
||||||
return WriteResult(
|
return WriteResult(
|
||||||
status="pilot_complete",
|
status="pilot_complete",
|
||||||
extraction=extraction_result.stats,
|
extraction=extraction_result.stats,
|
||||||
@@ -227,9 +239,7 @@ class WritePipeline:
|
|||||||
await self._update_stats_cache(extraction_result)
|
await self._update_stats_cache(extraction_result)
|
||||||
|
|
||||||
elapsed = time.time() - pipeline_start
|
elapsed = time.time() - pipeline_start
|
||||||
logger.info(
|
logger.info(f"[WritePipeline] 完成 ✔ {elapsed:.2f}s")
|
||||||
f"[WritePipeline] 完成 ✔ {elapsed:.2f}s"
|
|
||||||
)
|
|
||||||
return WriteResult(
|
return WriteResult(
|
||||||
status="success",
|
status="success",
|
||||||
extraction=extraction_result.stats,
|
extraction=extraction_result.stats,
|
||||||
@@ -251,16 +261,14 @@ class WritePipeline:
|
|||||||
# Step 1: 预处理
|
# Step 1: 预处理
|
||||||
# ──────────────────────────────────────────────
|
# ──────────────────────────────────────────────
|
||||||
|
|
||||||
async def _preprocess(
|
async def _preprocess(self, messages: List[dict], ref_id: str) -> List[DialogData]:
|
||||||
self, messages: List[dict], ref_id: str
|
|
||||||
) -> List[DialogData]:
|
|
||||||
"""
|
"""
|
||||||
预处理:消息校验 → 语义剪枝 → 对话分块。
|
预处理:消息校验 → AI消息语义剪枝(暂未实现) → 对话分块。
|
||||||
|
|
||||||
委托给 get_chunked_dialogs(),保持现有预处理逻辑不变。
|
委托给 get_chunked_dialogs(),保持现有预处理逻辑不变。
|
||||||
get_dialogs.py 内部已包含:
|
get_dialogs.py 内部已包含:
|
||||||
- 消息格式校验(role/content 必填)
|
- 消息格式校验(role/content 必填)
|
||||||
- 语义剪枝(根据 config 中 pruning_enabled 决定)
|
- AI消息语义剪枝(根据 config 中 pruning_enabled 决定)
|
||||||
- DialogueChunker 分块
|
- DialogueChunker 分块
|
||||||
"""
|
"""
|
||||||
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
|
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
|
||||||
@@ -283,56 +291,187 @@ class WritePipeline:
|
|||||||
is_pilot_run: bool,
|
is_pilot_run: bool,
|
||||||
) -> ExtractionResult:
|
) -> ExtractionResult:
|
||||||
"""
|
"""
|
||||||
萃取:初始化引擎 → 执行知识提取 → 返回结构化结果。
|
萃取:初始化引擎 → 执行知识提取 → 构建图节点/边 → 去重 → 返回结构化结果。
|
||||||
|
|
||||||
ExtractionOrchestrator 作为萃取引擎被调用,
|
使用 NewExtractionOrchestrator(ExtractionStep 范式)完成 LLM 萃取,
|
||||||
Pipeline 不关心引擎内部的并行策略和提取细节。
|
然后通过独立的 graph_build_step 和 dedup_step 完成图构建和去重,
|
||||||
|
不依赖旧编排器 ExtractionOrchestrator。
|
||||||
|
|
||||||
|
执行流程:
|
||||||
|
1. NewExtractionOrchestrator.run() → 萃取并赋值到 DialogData
|
||||||
|
2. build_graph_nodes_and_edges() → 从 DialogData 构建图节点和边
|
||||||
|
3. run_dedup() → 两阶段去重消歧
|
||||||
"""
|
"""
|
||||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
|
from app.core.memory.storage_services.extraction_engine.dedup_step import (
|
||||||
ExtractionOrchestrator,
|
run_dedup,
|
||||||
)
|
)
|
||||||
|
from app.core.memory.storage_services.extraction_engine.steps.graph_build_step import (
|
||||||
|
build_graph_nodes_and_edges,
|
||||||
|
)
|
||||||
|
from app.core.memory.storage_services.extraction_engine.steps.extraction_pipeline_orchestrator import (
|
||||||
|
NewExtractionOrchestrator,
|
||||||
|
)
|
||||||
|
|
||||||
from app.core.memory.utils.config.config_utils import get_pipeline_config
|
from app.core.memory.utils.config.config_utils import get_pipeline_config
|
||||||
|
from app.core.memory.utils.debug.pipeline_snapshot import PipelineSnapshot
|
||||||
|
|
||||||
pipeline_config = get_pipeline_config(self.memory_config)
|
pipeline_config = get_pipeline_config(self.memory_config)
|
||||||
ontology_types = self._load_ontology_types()
|
ontology_types = self._load_ontology_types()
|
||||||
|
|
||||||
orchestrator = ExtractionOrchestrator(
|
snapshot = PipelineSnapshot("new")
|
||||||
|
|
||||||
|
# ── 新编排器:LLM 萃取 + 数据赋值 ──
|
||||||
|
new_orchestrator = NewExtractionOrchestrator(
|
||||||
llm_client=self._llm_client,
|
llm_client=self._llm_client,
|
||||||
embedder_client=self._embedder_client,
|
embedder_client=self._embedder_client,
|
||||||
connector=self._neo4j_connector,
|
|
||||||
config=pipeline_config,
|
config=pipeline_config,
|
||||||
embedding_id=str(self.memory_config.embedding_model_id),
|
embedding_id=str(self.memory_config.embedding_model_id),
|
||||||
language=self.language,
|
|
||||||
ontology_types=ontology_types,
|
ontology_types=ontology_types,
|
||||||
|
language=self.language,
|
||||||
|
is_pilot_run=is_pilot_run,
|
||||||
|
progress_callback=self.progress_callback,
|
||||||
|
)
|
||||||
|
# step1: 执行知识提取
|
||||||
|
dialog_data_list = await new_orchestrator.run(chunked_dialogs)
|
||||||
|
|
||||||
|
# ── Snapshot: 各阶段萃取结果 ── TODO 乐力齐 重构流水线切换生产环境稳定后修改
|
||||||
|
stage_outputs = new_orchestrator.last_stage_outputs
|
||||||
|
if stage_outputs:
|
||||||
|
stmt_results = stage_outputs.get("statement_results", {})
|
||||||
|
stmt_snapshot = []
|
||||||
|
for _did, chunk_stmts in stmt_results.items():
|
||||||
|
for _cid, stmts in chunk_stmts.items():
|
||||||
|
for s in stmts:
|
||||||
|
stmt_snapshot.append(s.model_dump())
|
||||||
|
snapshot.save_stage("2_statement_outputs", stmt_snapshot)
|
||||||
|
|
||||||
|
triplet_results = stage_outputs.get("triplet_results", {})
|
||||||
|
triplet_snapshot = {}
|
||||||
|
for _did, stmt_triplets in triplet_results.items():
|
||||||
|
for stmt_id, t_out in stmt_triplets.items():
|
||||||
|
triplet_snapshot[stmt_id] = t_out.model_dump()
|
||||||
|
snapshot.save_stage("3_triplet_outputs", triplet_snapshot)
|
||||||
|
|
||||||
|
emotion_results = stage_outputs.get("emotion_results", {})
|
||||||
|
emotion_snapshot = {}
|
||||||
|
for stmt_id, emo in emotion_results.items():
|
||||||
|
if hasattr(emo, "model_dump"):
|
||||||
|
emotion_snapshot[stmt_id] = emo.model_dump()
|
||||||
|
snapshot.save_stage("4_emotion_outputs", emotion_snapshot)
|
||||||
|
|
||||||
|
emb_output = stage_outputs.get("embedding_output")
|
||||||
|
if emb_output and hasattr(emb_output, "model_dump"):
|
||||||
|
emb_data = emb_output.model_dump()
|
||||||
|
for key in (
|
||||||
|
"statement_embeddings",
|
||||||
|
"chunk_embeddings",
|
||||||
|
"entity_embeddings",
|
||||||
|
):
|
||||||
|
if key in emb_data and isinstance(emb_data[key], dict):
|
||||||
|
emb_data[key] = {
|
||||||
|
k: v[:5] if isinstance(v, list) else v
|
||||||
|
for k, v in emb_data[key].items()
|
||||||
|
}
|
||||||
|
if "dialog_embeddings" in emb_data and isinstance(
|
||||||
|
emb_data["dialog_embeddings"], list
|
||||||
|
):
|
||||||
|
emb_data["dialog_embeddings"] = [
|
||||||
|
v[:5] if isinstance(v, list) else v
|
||||||
|
for v in emb_data["dialog_embeddings"]
|
||||||
|
]
|
||||||
|
snapshot.save_stage("5_embedding_outputs", emb_data)
|
||||||
|
|
||||||
|
# step2: 构建图节点和边
|
||||||
|
graph = await build_graph_nodes_and_edges(
|
||||||
|
dialog_data_list=dialog_data_list,
|
||||||
|
embedder_client=self._embedder_client,
|
||||||
progress_callback=self.progress_callback,
|
progress_callback=self.progress_callback,
|
||||||
)
|
)
|
||||||
|
|
||||||
(
|
# region Snapshot: 图节点和边(去重前)Snapshot有关的内容在重构流水线切换生产环境之后修改
|
||||||
dialogue_nodes,
|
snapshot.save_stage(
|
||||||
chunk_nodes,
|
"6_nodes_edges_before_dedup",
|
||||||
statement_nodes,
|
{
|
||||||
entity_nodes,
|
"dialogue_nodes_count": len(graph.dialogue_nodes),
|
||||||
perceptual_nodes,
|
"chunk_nodes_count": len(graph.chunk_nodes),
|
||||||
stmt_chunk_edges,
|
"statement_nodes_count": len(graph.statement_nodes),
|
||||||
stmt_entity_edges,
|
"entity_nodes": [
|
||||||
entity_entity_edges,
|
{
|
||||||
perceptual_edges,
|
"id": e.id,
|
||||||
dialog_data_list,
|
"name": e.name,
|
||||||
) = await orchestrator.run(chunked_dialogs, is_pilot_run=is_pilot_run)
|
"entity_type": e.entity_type,
|
||||||
|
"description": e.description,
|
||||||
|
}
|
||||||
|
for e in graph.entity_nodes
|
||||||
|
],
|
||||||
|
"entity_entity_edges": [
|
||||||
|
{
|
||||||
|
"source": e.source,
|
||||||
|
"target": e.target,
|
||||||
|
"relation_type": e.relation_type,
|
||||||
|
"statement": e.statement,
|
||||||
|
}
|
||||||
|
for e in graph.entity_entity_edges
|
||||||
|
],
|
||||||
|
"stmt_entity_edges_count": len(graph.stmt_entity_edges),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
return ExtractionResult(
|
# step3: 两阶段去重消歧
|
||||||
dialogue_nodes=dialogue_nodes,
|
dedup_result = await run_dedup(
|
||||||
chunk_nodes=chunk_nodes,
|
entity_nodes=graph.entity_nodes,
|
||||||
statement_nodes=statement_nodes,
|
statement_entity_edges=graph.stmt_entity_edges,
|
||||||
entity_nodes=entity_nodes,
|
entity_entity_edges=graph.entity_entity_edges,
|
||||||
perceptual_nodes=perceptual_nodes,
|
dialog_data_list=dialog_data_list,
|
||||||
stmt_chunk_edges=stmt_chunk_edges,
|
pipeline_config=pipeline_config,
|
||||||
stmt_entity_edges=stmt_entity_edges,
|
connector=self._neo4j_connector,
|
||||||
entity_entity_edges=entity_entity_edges,
|
llm_client=self._llm_client,
|
||||||
perceptual_edges=perceptual_edges,
|
is_pilot_run=is_pilot_run,
|
||||||
|
progress_callback=self.progress_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Snapshot: 去重后
|
||||||
|
snapshot.save_stage(
|
||||||
|
"7_after_dedup",
|
||||||
|
{
|
||||||
|
"entity_nodes": [
|
||||||
|
{
|
||||||
|
"id": e.id,
|
||||||
|
"name": e.name,
|
||||||
|
"entity_type": e.entity_type,
|
||||||
|
"description": e.description,
|
||||||
|
}
|
||||||
|
for e in dedup_result.entity_nodes
|
||||||
|
],
|
||||||
|
"entity_entity_edges": [
|
||||||
|
{
|
||||||
|
"source": e.source,
|
||||||
|
"target": e.target,
|
||||||
|
"relation_type": e.relation_type,
|
||||||
|
"statement": e.statement,
|
||||||
|
}
|
||||||
|
for e in dedup_result.entity_entity_edges
|
||||||
|
],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# step4: 构造最终结果
|
||||||
|
result = ExtractionResult(
|
||||||
|
dialogue_nodes=graph.dialogue_nodes,
|
||||||
|
chunk_nodes=graph.chunk_nodes,
|
||||||
|
statement_nodes=graph.statement_nodes,
|
||||||
|
entity_nodes=dedup_result.entity_nodes,
|
||||||
|
perceptual_nodes=graph.perceptual_nodes,
|
||||||
|
stmt_chunk_edges=graph.stmt_chunk_edges,
|
||||||
|
stmt_entity_edges=dedup_result.statement_entity_edges,
|
||||||
|
entity_entity_edges=dedup_result.entity_entity_edges,
|
||||||
|
perceptual_edges=graph.perceptual_edges,
|
||||||
dialog_data_list=dialog_data_list,
|
dialog_data_list=dialog_data_list,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
snapshot.save_summary(result.stats) # TODO 乐力齐 snapshot需要改
|
||||||
|
return result
|
||||||
|
|
||||||
# ──────────────────────────────────────────────
|
# ──────────────────────────────────────────────
|
||||||
# Step 3: 存储
|
# Step 3: 存储
|
||||||
# ──────────────────────────────────────────────
|
# ──────────────────────────────────────────────
|
||||||
@@ -379,14 +518,10 @@ class WritePipeline:
|
|||||||
)
|
)
|
||||||
await asyncio.sleep(1 * (attempt + 1))
|
await asyncio.sleep(1 * (attempt + 1))
|
||||||
else:
|
else:
|
||||||
logger.error(
|
logger.error(f"Neo4j 写入在 {max_retries} 次尝试后仍部分失败")
|
||||||
f"Neo4j 写入在 {max_retries} 次尝试后仍部分失败"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if self._is_deadlock(e) and attempt < max_retries - 1:
|
if self._is_deadlock(e) and attempt < max_retries - 1:
|
||||||
logger.warning(
|
logger.warning(f"Neo4j 死锁,重试 ({attempt + 2}/{max_retries})")
|
||||||
f"Neo4j 死锁,重试 ({attempt + 2}/{max_retries})"
|
|
||||||
)
|
|
||||||
await asyncio.sleep(1 * (attempt + 1))
|
await asyncio.sleep(1 * (attempt + 1))
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
@@ -401,6 +536,10 @@ class WritePipeline:
|
|||||||
|
|
||||||
聚类不阻塞主写入流程,失败不影响写入结果。
|
聚类不阻塞主写入流程,失败不影响写入结果。
|
||||||
通过 Celery 异步执行,由 LabelPropagationEngine 完成实际计算。
|
通过 Celery 异步执行,由 LabelPropagationEngine 完成实际计算。
|
||||||
|
|
||||||
|
注意:ExtractionResult.entity_nodes 已经是经过 _extract() 中
|
||||||
|
两阶段去重消歧(_run_dedup_and_write_summary)后的结果,
|
||||||
|
聚类直接基于去重后的实体 ID 执行。
|
||||||
"""
|
"""
|
||||||
if not result.entity_nodes:
|
if not result.entity_nodes:
|
||||||
return
|
return
|
||||||
@@ -428,7 +567,9 @@ class WritePipeline:
|
|||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[Clustering] 增量聚类任务已提交 - "
|
f"[Clustering] 增量聚类任务已提交 - "
|
||||||
f"task_id={task.id}, entity_count={len(new_entity_ids)}"
|
f"task_id={task.id}, "
|
||||||
|
f"entity_count={len(new_entity_ids)}, "
|
||||||
|
f"source=dedup"
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
@@ -438,9 +579,9 @@ class WritePipeline:
|
|||||||
|
|
||||||
# ──────────────────────────────────────────────
|
# ──────────────────────────────────────────────
|
||||||
# Step 5: 摘要
|
# Step 5: 摘要
|
||||||
# (+ entity_description)
|
# (+ entity_description)+ meta_data部分在此提取
|
||||||
# ──────────────────────────────────────────────
|
# ──────────────────────────────────────────────
|
||||||
|
# TODO 乐力齐 需要做成异步celery任务
|
||||||
async def _summarize(self, chunked_dialogs: List[DialogData]) -> None:
|
async def _summarize(self, chunked_dialogs: List[DialogData]) -> None:
|
||||||
"""
|
"""
|
||||||
摘要:生成情景记忆摘要 → 写入 Neo4j。
|
摘要:生成情景记忆摘要 → 写入 Neo4j。
|
||||||
@@ -467,9 +608,7 @@ class WritePipeline:
|
|||||||
ms_connector = Neo4jConnector()
|
ms_connector = Neo4jConnector()
|
||||||
try:
|
try:
|
||||||
await add_memory_summary_nodes(summaries, ms_connector)
|
await add_memory_summary_nodes(summaries, ms_connector)
|
||||||
await add_memory_summary_statement_edges(
|
await add_memory_summary_statement_edges(summaries, ms_connector)
|
||||||
summaries, ms_connector
|
|
||||||
)
|
|
||||||
finally:
|
finally:
|
||||||
try:
|
try:
|
||||||
await ms_connector.close()
|
await ms_connector.close()
|
||||||
@@ -494,9 +633,7 @@ class WritePipeline:
|
|||||||
|
|
||||||
with get_db_context() as db:
|
with get_db_context() as db:
|
||||||
factory = MemoryClientFactory(db)
|
factory = MemoryClientFactory(db)
|
||||||
self._llm_client = factory.get_llm_client_from_config(
|
self._llm_client = factory.get_llm_client_from_config(self.memory_config)
|
||||||
self.memory_config
|
|
||||||
)
|
|
||||||
self._embedder_client = factory.get_embedder_client_from_config(
|
self._embedder_client = factory.get_embedder_client_from_config(
|
||||||
self.memory_config
|
self.memory_config
|
||||||
)
|
)
|
||||||
@@ -564,10 +701,8 @@ class WritePipeline:
|
|||||||
if entity_nodes:
|
if entity_nodes:
|
||||||
eu_id = entity_nodes[0].end_user_id
|
eu_id = entity_nodes[0].end_user_id
|
||||||
if eu_id:
|
if eu_id:
|
||||||
neo4j_assistant_aliases = (
|
neo4j_assistant_aliases = await fetch_neo4j_assistant_aliases(
|
||||||
await fetch_neo4j_assistant_aliases(
|
self._neo4j_connector, eu_id
|
||||||
self._neo4j_connector, eu_id
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
clean_cross_role_aliases(
|
clean_cross_role_aliases(
|
||||||
entity_nodes,
|
entity_nodes,
|
||||||
@@ -586,9 +721,7 @@ class WritePipeline:
|
|||||||
msg = str(e).lower()
|
msg = str(e).lower()
|
||||||
return "deadlockdetected" in msg or "deadlock" in msg
|
return "deadlockdetected" in msg or "deadlock" in msg
|
||||||
|
|
||||||
async def _update_stats_cache(
|
async def _update_stats_cache(self, result: ExtractionResult) -> None:
|
||||||
self, result: ExtractionResult
|
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
将提取统计写入 Redis 活动缓存,按 workspace_id 存储。
|
将提取统计写入 Redis 活动缓存,按 workspace_id 存储。
|
||||||
失败不中断主流程。
|
失败不中断主流程。
|
||||||
@@ -614,9 +747,7 @@ class WritePipeline:
|
|||||||
f"workspace_id={self.memory_config.workspace_id}"
|
f"workspace_id={self.memory_config.workspace_id}"
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(
|
logger.warning(f"写入活动统计缓存失败(不影响主流程): {e}")
|
||||||
f"写入活动统计缓存失败(不影响主流程): {e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _cleanup(self) -> None:
|
async def _cleanup(self) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -634,16 +765,14 @@ class WritePipeline:
|
|||||||
# 防止 'RuntimeError: Event loop is closed' 在垃圾回收时触发
|
# 防止 'RuntimeError: Event loop is closed' 在垃圾回收时触发
|
||||||
for client_obj in (self._llm_client, self._embedder_client):
|
for client_obj in (self._llm_client, self._embedder_client):
|
||||||
try:
|
try:
|
||||||
underlying = getattr(
|
underlying = getattr(client_obj, "client", None) or getattr(
|
||||||
client_obj, "client", None
|
client_obj, "model", None
|
||||||
) or getattr(client_obj, "model", None)
|
)
|
||||||
if underlying is None:
|
if underlying is None:
|
||||||
continue
|
continue
|
||||||
inner = getattr(underlying, "_model", underlying)
|
inner = getattr(underlying, "_model", underlying)
|
||||||
http_client = getattr(inner, "async_client", None)
|
http_client = getattr(inner, "async_client", None)
|
||||||
if http_client is not None and hasattr(
|
if http_client is not None and hasattr(http_client, "aclose"):
|
||||||
http_client, "aclose"
|
|
||||||
):
|
|
||||||
await http_client.aclose()
|
await http_client.aclose()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -0,0 +1,506 @@
|
|||||||
|
"""Independent deduplication module for the extraction pipeline.
|
||||||
|
|
||||||
|
Extracts dedup logic from ExtractionOrchestrator into standalone functions
|
||||||
|
so the orchestrator stays thin and dedup can be tested/evolved independently.
|
||||||
|
|
||||||
|
The module exposes:
|
||||||
|
- ``DedupResult`` — structured output of the dedup process
|
||||||
|
- ``run_dedup()`` — async entry point called by WritePipeline
|
||||||
|
- Helper functions migrated from ExtractionOrchestrator:
|
||||||
|
``save_dedup_details``, ``analyze_entity_merges``,
|
||||||
|
``analyze_entity_disambiguation``, ``send_dedup_progress_callback``,
|
||||||
|
``parse_dedup_report``
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
from app.core.memory.models.graph_models import (
|
||||||
|
EntityEntityEdge,
|
||||||
|
ExtractedEntityNode,
|
||||||
|
StatementEntityEdge,
|
||||||
|
)
|
||||||
|
from app.core.memory.models.message_models import DialogData
|
||||||
|
from app.core.memory.models.variate_config import ExtractionPipelineConfig
|
||||||
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# DedupResult dataclass (Requirement 10.2)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DedupResult:
|
||||||
|
"""Structured output of the two-stage entity deduplication process.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
entity_nodes: Deduplicated entity node list.
|
||||||
|
statement_entity_edges: Deduplicated statement-entity edges.
|
||||||
|
entity_entity_edges: Deduplicated entity-entity edges.
|
||||||
|
dedup_details: Raw detail dict returned by the first-layer dedup.
|
||||||
|
merge_records: Parsed merge records (exact / fuzzy / LLM).
|
||||||
|
disamb_records: Parsed disambiguation records.
|
||||||
|
"""
|
||||||
|
|
||||||
|
entity_nodes: List[ExtractedEntityNode]
|
||||||
|
statement_entity_edges: List[StatementEntityEdge]
|
||||||
|
entity_entity_edges: List[EntityEntityEdge]
|
||||||
|
dedup_details: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
merge_records: List[Dict[str, Any]] = field(default_factory=list)
|
||||||
|
disamb_records: List[Dict[str, Any]] = field(default_factory=list)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def stats(self) -> Dict[str, int]:
|
||||||
|
"""Summary statistics for the dedup run."""
|
||||||
|
return {
|
||||||
|
"entity_count": len(self.entity_nodes),
|
||||||
|
"merge_count": len(self.merge_records),
|
||||||
|
"disamb_count": len(self.disamb_records),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Migrated helpers (from ExtractionOrchestrator) — Requirement 10.4
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def save_dedup_details(
|
||||||
|
dedup_details: Dict[str, Any],
|
||||||
|
original_entities: List[ExtractedEntityNode],
|
||||||
|
final_entities: List[ExtractedEntityNode],
|
||||||
|
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], Dict[str, str]]:
|
||||||
|
"""Parse raw *dedup_details* into structured merge / disamb records.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(merge_records, disamb_records, id_redirect_map)
|
||||||
|
"""
|
||||||
|
merge_records: List[Dict[str, Any]] = []
|
||||||
|
disamb_records: List[Dict[str, Any]] = []
|
||||||
|
id_redirect_map: Dict[str, str] = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
id_redirect_map = dedup_details.get("id_redirect", {})
|
||||||
|
|
||||||
|
# --- exact-match merges ---
|
||||||
|
exact_merge_map = dedup_details.get("exact_merge_map", {})
|
||||||
|
for _key, info in exact_merge_map.items():
|
||||||
|
merged_ids = info.get("merged_ids", set())
|
||||||
|
if merged_ids:
|
||||||
|
merge_records.append({
|
||||||
|
"type": "精确匹配",
|
||||||
|
"canonical_id": info.get("canonical_id"),
|
||||||
|
"entity_name": info.get("name"),
|
||||||
|
"entity_type": info.get("entity_type"),
|
||||||
|
"merged_count": len(merged_ids),
|
||||||
|
"merged_ids": list(merged_ids),
|
||||||
|
})
|
||||||
|
|
||||||
|
# --- fuzzy-match merges ---
|
||||||
|
for record in dedup_details.get("fuzzy_merge_records", []):
|
||||||
|
try:
|
||||||
|
match = re.search(
|
||||||
|
r"规范实体 (\S+) \(([^|]+)\|([^|]+)\|([^)]+)\) <- 合并实体 (\S+)",
|
||||||
|
record,
|
||||||
|
)
|
||||||
|
if match:
|
||||||
|
merge_records.append({
|
||||||
|
"type": "模糊匹配",
|
||||||
|
"canonical_id": match.group(1),
|
||||||
|
"entity_name": match.group(3),
|
||||||
|
"entity_type": match.group(4),
|
||||||
|
"merged_count": 1,
|
||||||
|
"merged_ids": [match.group(5)],
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("解析模糊匹配记录失败: %s, 错误: %s", record, e)
|
||||||
|
|
||||||
|
# --- LLM-based merges ---
|
||||||
|
for record in dedup_details.get("llm_decision_records", []):
|
||||||
|
if "[LLM去重]" in str(record):
|
||||||
|
try:
|
||||||
|
match = re.search(
|
||||||
|
r"同名类型相似 ([^(]+)(([^)]+))\|([^(]+)(([^)]+))",
|
||||||
|
record,
|
||||||
|
)
|
||||||
|
if match:
|
||||||
|
merge_records.append({
|
||||||
|
"type": "LLM去重",
|
||||||
|
"entity_name": match.group(1),
|
||||||
|
"entity_type": f"{match.group(2)}|{match.group(4)}",
|
||||||
|
"merged_count": 1,
|
||||||
|
"merged_ids": [],
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("解析LLM去重记录失败: %s, 错误: %s", record, e)
|
||||||
|
|
||||||
|
# --- disambiguation records ---
|
||||||
|
for record in dedup_details.get("disamb_records", []):
|
||||||
|
if "[DISAMB阻断]" in str(record):
|
||||||
|
try:
|
||||||
|
content = str(record).replace("[DISAMB阻断]", "").strip()
|
||||||
|
match = re.search(
|
||||||
|
r"([^(]+)(([^)]+))\|([^(]+)(([^)]+))", content
|
||||||
|
)
|
||||||
|
if match:
|
||||||
|
entity1_name = match.group(1).strip()
|
||||||
|
entity1_type = match.group(2)
|
||||||
|
entity2_type = match.group(4)
|
||||||
|
|
||||||
|
conf_match = re.search(r"conf=([0-9.]+)", str(record))
|
||||||
|
confidence = conf_match.group(1) if conf_match else "unknown"
|
||||||
|
|
||||||
|
reason_match = re.search(r"reason=([^|]+)", str(record))
|
||||||
|
reason = reason_match.group(1).strip() if reason_match else ""
|
||||||
|
|
||||||
|
disamb_records.append({
|
||||||
|
"entity_name": entity1_name,
|
||||||
|
"disamb_type": f"消歧阻断:{entity1_type} vs {entity2_type}",
|
||||||
|
"confidence": confidence,
|
||||||
|
"reason": (reason[:100] + "...") if len(reason) > 100 else reason,
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("解析消歧记录失败: %s, 错误: %s", record, e)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"保存去重消歧记录:%d 个合并记录,%d 个消歧记录",
|
||||||
|
len(merge_records),
|
||||||
|
len(disamb_records),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("保存去重消歧详情失败: %s", e, exc_info=True)
|
||||||
|
|
||||||
|
return merge_records, disamb_records, id_redirect_map
|
||||||
|
|
||||||
|
|
||||||
|
def analyze_entity_merges(
|
||||||
|
merge_records: List[Dict[str, Any]],
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""Return merge info sorted by merged_count (descending)."""
|
||||||
|
if not merge_records:
|
||||||
|
return []
|
||||||
|
sorted_records = sorted(
|
||||||
|
merge_records, key=lambda x: x.get("merged_count", 0), reverse=True
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"main_entity_name": r.get("entity_name", "未知实体"),
|
||||||
|
"merged_count": r.get("merged_count", 1),
|
||||||
|
}
|
||||||
|
for r in sorted_records
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def analyze_entity_disambiguation(
|
||||||
|
disamb_records: List[Dict[str, Any]],
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""Return disambiguation records (pass-through)."""
|
||||||
|
return disamb_records if disamb_records else []
|
||||||
|
|
||||||
|
|
||||||
|
def parse_dedup_report(
|
||||||
|
merge_records: List[Dict[str, Any]],
|
||||||
|
disamb_records: List[Dict[str, Any]],
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Build a summary report dict from parsed records."""
|
||||||
|
try:
|
||||||
|
dedup_examples: List[Dict[str, Any]] = []
|
||||||
|
disamb_examples: List[Dict[str, Any]] = []
|
||||||
|
total_merges = 0
|
||||||
|
total_disambiguations = 0
|
||||||
|
|
||||||
|
for record in merge_records:
|
||||||
|
merge_count = record.get("merged_count", 0)
|
||||||
|
total_merges += merge_count
|
||||||
|
dedup_examples.append({
|
||||||
|
"type": record.get("type", "未知"),
|
||||||
|
"entity_name": record.get("entity_name", "未知实体"),
|
||||||
|
"entity_type": record.get("entity_type", "未知类型"),
|
||||||
|
"merge_count": merge_count,
|
||||||
|
"description": f"{record.get('entity_name', '未知实体')}实体去重合并{merge_count}个",
|
||||||
|
})
|
||||||
|
|
||||||
|
for record in disamb_records:
|
||||||
|
total_disambiguations += 1
|
||||||
|
disamb_type = record.get("disamb_type", "")
|
||||||
|
entity_name = record.get("entity_name", "未知实体")
|
||||||
|
disamb_examples.append({
|
||||||
|
"entity1_name": entity_name,
|
||||||
|
"entity1_type": (
|
||||||
|
disamb_type.split("vs")[0].replace("消歧阻断:", "").strip()
|
||||||
|
if "vs" in disamb_type
|
||||||
|
else "未知"
|
||||||
|
),
|
||||||
|
"entity2_name": entity_name,
|
||||||
|
"entity2_type": (
|
||||||
|
disamb_type.split("vs")[1].strip() if "vs" in disamb_type else "未知"
|
||||||
|
),
|
||||||
|
"description": f"{entity_name},消歧区分成功",
|
||||||
|
})
|
||||||
|
|
||||||
|
return {
|
||||||
|
"dedup_examples": dedup_examples[:5],
|
||||||
|
"disamb_examples": disamb_examples[:5],
|
||||||
|
"total_merges": total_merges,
|
||||||
|
"total_disambiguations": total_disambiguations,
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("获取去重报告失败: %s", e, exc_info=True)
|
||||||
|
return {
|
||||||
|
"dedup_examples": [],
|
||||||
|
"disamb_examples": [],
|
||||||
|
"total_merges": 0,
|
||||||
|
"total_disambiguations": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def send_dedup_progress_callback(
|
||||||
|
progress_callback: Callable,
|
||||||
|
merge_records: List[Dict[str, Any]],
|
||||||
|
disamb_records: List[Dict[str, Any]],
|
||||||
|
original_entities: int,
|
||||||
|
final_entities: int,
|
||||||
|
original_stmt_edges: int,
|
||||||
|
final_stmt_edges: int,
|
||||||
|
original_ent_edges: int,
|
||||||
|
final_ent_edges: int,
|
||||||
|
) -> None:
|
||||||
|
"""Send dedup completion progress via *progress_callback*."""
|
||||||
|
try:
|
||||||
|
dedup_details = parse_dedup_report(merge_records, disamb_records)
|
||||||
|
|
||||||
|
entities_reduced = original_entities - final_entities
|
||||||
|
stmt_edges_reduced = original_stmt_edges - final_stmt_edges
|
||||||
|
ent_edges_reduced = original_ent_edges - final_ent_edges
|
||||||
|
|
||||||
|
dedup_stats = {
|
||||||
|
"entities": {
|
||||||
|
"original_count": original_entities,
|
||||||
|
"final_count": final_entities,
|
||||||
|
"reduced_count": entities_reduced,
|
||||||
|
"reduction_rate": (
|
||||||
|
round(entities_reduced / original_entities * 100, 1)
|
||||||
|
if original_entities > 0
|
||||||
|
else 0
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"statement_entity_edges": {
|
||||||
|
"original_count": original_stmt_edges,
|
||||||
|
"final_count": final_stmt_edges,
|
||||||
|
"reduced_count": stmt_edges_reduced,
|
||||||
|
},
|
||||||
|
"entity_entity_edges": {
|
||||||
|
"original_count": original_ent_edges,
|
||||||
|
"final_count": final_ent_edges,
|
||||||
|
"reduced_count": ent_edges_reduced,
|
||||||
|
},
|
||||||
|
"dedup_examples": dedup_details.get("dedup_examples", []),
|
||||||
|
"disamb_examples": dedup_details.get("disamb_examples", []),
|
||||||
|
"summary": {
|
||||||
|
"total_merges": dedup_details.get("total_merges", 0),
|
||||||
|
"total_disambiguations": dedup_details.get("total_disambiguations", 0),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
await progress_callback("dedup_disambiguation_complete", "去重消歧完成", dedup_stats)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("发送去重消歧进度回调失败: %s", e, exc_info=True)
|
||||||
|
try:
|
||||||
|
basic_stats = {
|
||||||
|
"entities": {
|
||||||
|
"original_count": original_entities,
|
||||||
|
"final_count": final_entities,
|
||||||
|
"reduced_count": original_entities - final_entities,
|
||||||
|
},
|
||||||
|
"summary": f"实体去重合并{original_entities - final_entities}个",
|
||||||
|
}
|
||||||
|
await progress_callback("dedup_disambiguation_complete", "去重消歧完成", basic_stats)
|
||||||
|
except Exception as e2:
|
||||||
|
logger.error("发送基本去重统计失败: %s", e2, exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# run_dedup — main entry point (Requirements 10.1, 10.3)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def run_dedup(
|
||||||
|
entity_nodes: List[ExtractedEntityNode],
|
||||||
|
statement_entity_edges: List[StatementEntityEdge],
|
||||||
|
entity_entity_edges: List[EntityEntityEdge],
|
||||||
|
dialog_data_list: List[DialogData],
|
||||||
|
pipeline_config: ExtractionPipelineConfig,
|
||||||
|
connector: Optional[Neo4jConnector] = None,
|
||||||
|
llm_client: Optional[Any] = None,
|
||||||
|
is_pilot_run: bool = False,
|
||||||
|
progress_callback: Optional[Callable] = None,
|
||||||
|
) -> DedupResult:
|
||||||
|
"""Two-stage entity deduplication and disambiguation.
|
||||||
|
|
||||||
|
Full mode:
|
||||||
|
Layer 1 — exact / fuzzy / LLM matching
|
||||||
|
Layer 2 — Neo4j joint dedup + cross-role alias cleaning
|
||||||
|
|
||||||
|
Pilot-run mode:
|
||||||
|
Layer 1 only (skip Neo4j layer 2 and alias cleaning).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
entity_nodes: Pre-dedup entity nodes.
|
||||||
|
statement_entity_edges: Pre-dedup statement-entity edges.
|
||||||
|
entity_entity_edges: Pre-dedup entity-entity edges.
|
||||||
|
dialog_data_list: Source dialogue data (used to detect end_user_id).
|
||||||
|
pipeline_config: Pipeline configuration (contains DedupConfig).
|
||||||
|
connector: Optional Neo4j connector for layer-2 dedup.
|
||||||
|
llm_client: Optional LLM client for LLM-based dedup decisions.
|
||||||
|
is_pilot_run: When True, only execute layer-1 dedup.
|
||||||
|
progress_callback: Optional async callable for progress reporting.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A ``DedupResult`` with deduplicated nodes, edges, and statistics.
|
||||||
|
"""
|
||||||
|
logger.info("开始两阶段实体去重和消歧")
|
||||||
|
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback("deduplication", "正在去重消歧...")
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"去重前: %d 个实体节点, %d 条陈述句-实体边, %d 条实体-实体边",
|
||||||
|
len(entity_nodes),
|
||||||
|
len(statement_entity_edges),
|
||||||
|
len(entity_entity_edges),
|
||||||
|
)
|
||||||
|
|
||||||
|
original_entity_count = len(entity_nodes)
|
||||||
|
original_stmt_edge_count = len(statement_entity_edges)
|
||||||
|
original_ent_edge_count = len(entity_entity_edges)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if is_pilot_run:
|
||||||
|
# --- pilot run: layer 1 only ---
|
||||||
|
logger.info("试运行模式:仅执行第一层去重,跳过第二层数据库去重")
|
||||||
|
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
|
||||||
|
deduplicate_entities_and_edges,
|
||||||
|
)
|
||||||
|
|
||||||
|
(
|
||||||
|
dedup_entity_nodes,
|
||||||
|
dedup_stmt_edges,
|
||||||
|
dedup_ent_edges,
|
||||||
|
raw_details,
|
||||||
|
) = await deduplicate_entities_and_edges(
|
||||||
|
entity_nodes,
|
||||||
|
statement_entity_edges,
|
||||||
|
entity_entity_edges,
|
||||||
|
report_stage="第一层去重消歧(试运行)",
|
||||||
|
report_append=False,
|
||||||
|
dedup_config=pipeline_config.deduplication,
|
||||||
|
llm_client=llm_client,
|
||||||
|
)
|
||||||
|
|
||||||
|
final_entities = dedup_entity_nodes
|
||||||
|
final_stmt_edges = dedup_stmt_edges
|
||||||
|
final_ent_edges = dedup_ent_edges
|
||||||
|
else:
|
||||||
|
# --- full mode: two-stage dedup ---
|
||||||
|
from app.core.memory.storage_services.extraction_engine.deduplication.two_stage_dedup import (
|
||||||
|
dedup_layers_and_merge_and_return,
|
||||||
|
)
|
||||||
|
|
||||||
|
(
|
||||||
|
_dialogue_nodes,
|
||||||
|
_chunk_nodes,
|
||||||
|
_statement_nodes,
|
||||||
|
final_entities,
|
||||||
|
_statement_chunk_edges,
|
||||||
|
final_stmt_edges,
|
||||||
|
final_ent_edges,
|
||||||
|
raw_details,
|
||||||
|
) = await dedup_layers_and_merge_and_return(
|
||||||
|
dialogue_nodes=[],
|
||||||
|
chunk_nodes=[],
|
||||||
|
statement_nodes=[],
|
||||||
|
entity_nodes=entity_nodes,
|
||||||
|
statement_chunk_edges=[],
|
||||||
|
statement_entity_edges=statement_entity_edges,
|
||||||
|
entity_entity_edges=entity_entity_edges,
|
||||||
|
dialog_data_list=dialog_data_list,
|
||||||
|
pipeline_config=pipeline_config,
|
||||||
|
connector=connector,
|
||||||
|
llm_client=llm_client,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse raw details into structured records
|
||||||
|
merge_records, disamb_records, _id_redirect = save_dedup_details(
|
||||||
|
raw_details, entity_nodes, final_entities
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"去重后: %d 个实体节点, %d 条陈述句-实体边, %d 条实体-实体边",
|
||||||
|
len(final_entities),
|
||||||
|
len(final_stmt_edges),
|
||||||
|
len(final_ent_edges),
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"去重效果: 实体减少 %d, 陈述句-实体边减少 %d, 实体-实体边减少 %d",
|
||||||
|
original_entity_count - len(final_entities),
|
||||||
|
original_stmt_edge_count - len(final_stmt_edges),
|
||||||
|
original_ent_edge_count - len(final_ent_edges),
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- progress callbacks ---
|
||||||
|
if progress_callback:
|
||||||
|
merge_info = analyze_entity_merges(merge_records)
|
||||||
|
for i, detail in enumerate(merge_info[:5]):
|
||||||
|
dedup_result = {
|
||||||
|
"result_type": "entity_merge",
|
||||||
|
"merged_entity_name": detail["main_entity_name"],
|
||||||
|
"merged_count": detail["merged_count"],
|
||||||
|
"merge_progress": f"{i + 1}/{min(len(merge_info), 5)}",
|
||||||
|
"message": (
|
||||||
|
f"{detail['main_entity_name']}合并{detail['merged_count']}个:相似实体已合并"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
await progress_callback("dedup_disambiguation_result", "实体去重中", dedup_result)
|
||||||
|
|
||||||
|
disamb_info = analyze_entity_disambiguation(disamb_records)
|
||||||
|
for i, detail in enumerate(disamb_info[:5]):
|
||||||
|
disamb_result = {
|
||||||
|
"result_type": "entity_disambiguation",
|
||||||
|
"disambiguated_entity_name": detail["entity_name"],
|
||||||
|
"disambiguation_type": detail["disamb_type"],
|
||||||
|
"confidence": detail.get("confidence", "unknown"),
|
||||||
|
"reason": detail.get("reason", ""),
|
||||||
|
"disamb_progress": f"{i + 1}/{min(len(disamb_info), 5)}",
|
||||||
|
"message": f"{detail['entity_name']}消歧完成:{detail['disamb_type']}",
|
||||||
|
}
|
||||||
|
await progress_callback("dedup_disambiguation_result", "实体消歧中", disamb_result)
|
||||||
|
|
||||||
|
await send_dedup_progress_callback(
|
||||||
|
progress_callback,
|
||||||
|
merge_records,
|
||||||
|
disamb_records,
|
||||||
|
original_entity_count,
|
||||||
|
len(final_entities),
|
||||||
|
original_stmt_edge_count,
|
||||||
|
len(final_stmt_edges),
|
||||||
|
original_ent_edge_count,
|
||||||
|
len(final_ent_edges),
|
||||||
|
)
|
||||||
|
|
||||||
|
return DedupResult(
|
||||||
|
entity_nodes=final_entities,
|
||||||
|
statement_entity_edges=final_stmt_edges,
|
||||||
|
entity_entity_edges=final_ent_edges,
|
||||||
|
dedup_details=raw_details,
|
||||||
|
merge_records=merge_records,
|
||||||
|
disamb_records=disamb_records,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("两阶段去重失败: %s", e, exc_info=True)
|
||||||
|
raise
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
"""Extraction pipeline steps — unified ExtractionStep paradigm.
|
||||||
|
|
||||||
|
Importing this package triggers @register decorator self-registration
|
||||||
|
for all sidecar (non-critical) steps via SidecarStepFactory.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .sidecar_factory import SidecarStepFactory, SidecarTiming # noqa: F401
|
||||||
|
|
||||||
|
# Step implementations — importing triggers @register self-registration.
|
||||||
|
from .statement_step import StatementExtractionStep # noqa: F401
|
||||||
|
from .triplet_step import TripletExtractionStep # noqa: F401
|
||||||
|
from .emotion_step import EmotionExtractionStep # noqa: F401
|
||||||
|
from .embedding_step import EmbeddingStep # noqa: F401
|
||||||
|
|
||||||
|
# Refactored orchestrator
|
||||||
|
from .extraction_pipeline_orchestrator import NewExtractionOrchestrator # noqa: F401
|
||||||
@@ -0,0 +1,182 @@
|
|||||||
|
"""ExtractionStep abstract base class and StepContext.
|
||||||
|
|
||||||
|
Provides the unified paradigm for all LLM extraction stages:
|
||||||
|
render_prompt → call_llm → parse_response → post_process
|
||||||
|
|
||||||
|
Critical steps retry on failure with exponential backoff.
|
||||||
|
Sidecar (non-critical) steps return a default output on failure without retry.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Generic, Optional, TypeVar
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
InputT = TypeVar("InputT")
|
||||||
|
OutputT = TypeVar("OutputT")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StepContext:
|
||||||
|
"""Shared context injected into every ExtractionStep by the orchestrator.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
llm_client: LLM client instance for generating completions.
|
||||||
|
language: Target language code (e.g. "en", "zh").
|
||||||
|
config: Pipeline configuration object (ExtractionPipelineConfig).
|
||||||
|
is_pilot_run: When True, run in lightweight preview mode.
|
||||||
|
progress_callback: Optional callable for reporting progress.
|
||||||
|
"""
|
||||||
|
|
||||||
|
llm_client: Any
|
||||||
|
language: str
|
||||||
|
config: Any
|
||||||
|
is_pilot_run: bool = False
|
||||||
|
progress_callback: Optional[Any] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ExtractionStep(ABC, Generic[InputT, OutputT]):
|
||||||
|
"""Abstract base class for all LLM extraction stages.
|
||||||
|
|
||||||
|
Lifecycle:
|
||||||
|
1. ``__init__(context)`` — receive shared context, bind config params
|
||||||
|
2. ``should_skip()`` — check whether to skip (config-driven / pilot mode)
|
||||||
|
3. ``run(input_data)`` — execute full flow (with retry for critical steps)
|
||||||
|
Internally: render_prompt → call_llm → parse_response → post_process
|
||||||
|
4. ``on_failure(error)`` — critical steps raise; sidecar steps return default
|
||||||
|
|
||||||
|
Type Parameters:
|
||||||
|
InputT: The Pydantic model type accepted by this step.
|
||||||
|
OutputT: The Pydantic model type produced by this step.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, context: StepContext) -> None:
|
||||||
|
self.context = context
|
||||||
|
self.llm_client = context.llm_client
|
||||||
|
self.language = context.language
|
||||||
|
self.config = context.config
|
||||||
|
|
||||||
|
# ── Subclasses must implement ──
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Human-readable step name for logging."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def render_prompt(self, input_data: InputT) -> Any:
|
||||||
|
"""Build the prompt from *input_data* and bound config."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def call_llm(self, prompt: Any) -> Any:
|
||||||
|
"""Send *prompt* to the LLM and return the raw response."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def parse_response(self, raw_response: Any, input_data: InputT) -> OutputT:
|
||||||
|
"""Parse *raw_response* into a typed OutputT (Pydantic model)."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_default_output(self) -> OutputT:
|
||||||
|
"""Return a safe default when the step is skipped or fails gracefully."""
|
||||||
|
...
|
||||||
|
|
||||||
|
# ── Overridable properties ──
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_critical(self) -> bool:
|
||||||
|
"""``True`` = critical step (failure aborts pipeline).
|
||||||
|
|
||||||
|
``False`` = sidecar step (failure degrades gracefully).
|
||||||
|
"""
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_retries(self) -> int:
|
||||||
|
"""Maximum retry attempts (only effective for critical steps)."""
|
||||||
|
return 2
|
||||||
|
|
||||||
|
@property
|
||||||
|
def retry_backoff_base(self) -> float:
|
||||||
|
"""Backoff base in seconds. Actual wait = base × 2^attempt."""
|
||||||
|
return 1.0
|
||||||
|
|
||||||
|
# ── Overridable hooks ──
|
||||||
|
|
||||||
|
def should_skip(self) -> bool:
|
||||||
|
"""Config-driven skip check. Subclasses may override."""
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def post_process(self, parsed_data: OutputT, input_data: InputT) -> OutputT:
|
||||||
|
"""Post-processing hook. Default is identity (returns *parsed_data* unchanged)."""
|
||||||
|
return parsed_data
|
||||||
|
|
||||||
|
# ── Core execution logic ──
|
||||||
|
|
||||||
|
async def run(self, input_data: InputT) -> OutputT:
|
||||||
|
"""Execute the full step lifecycle with retry logic.
|
||||||
|
|
||||||
|
For critical steps (``is_critical=True``):
|
||||||
|
Attempt up to ``max_retries + 1`` times with exponential backoff.
|
||||||
|
If all attempts fail, delegate to ``on_failure`` which raises.
|
||||||
|
|
||||||
|
For sidecar steps (``is_critical=False``):
|
||||||
|
Attempt exactly once. On failure, delegate to ``on_failure``
|
||||||
|
which returns ``get_default_output()``.
|
||||||
|
"""
|
||||||
|
if self.should_skip():
|
||||||
|
logger.info("Step '%s' skipped", self.name)
|
||||||
|
return self.get_default_output()
|
||||||
|
|
||||||
|
last_error: Optional[Exception] = None
|
||||||
|
attempts = self.max_retries + 1 if self.is_critical else 1
|
||||||
|
|
||||||
|
for attempt in range(attempts):
|
||||||
|
try:
|
||||||
|
prompt = await self.render_prompt(input_data)
|
||||||
|
raw_response = await self.call_llm(prompt)
|
||||||
|
parsed = await self.parse_response(raw_response, input_data)
|
||||||
|
result = await self.post_process(parsed, input_data)
|
||||||
|
return result
|
||||||
|
except Exception as exc:
|
||||||
|
last_error = exc
|
||||||
|
logger.warning(
|
||||||
|
"Step '%s' attempt %d/%d failed: %s",
|
||||||
|
self.name,
|
||||||
|
attempt + 1,
|
||||||
|
attempts,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
if attempt < attempts - 1:
|
||||||
|
wait = self.retry_backoff_base * (2 ** attempt)
|
||||||
|
logger.info(
|
||||||
|
"Step '%s' retrying in %.1fs …", self.name, wait
|
||||||
|
)
|
||||||
|
await asyncio.sleep(wait)
|
||||||
|
|
||||||
|
# All attempts exhausted — delegate to failure handler
|
||||||
|
return self.on_failure(last_error) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
def on_failure(self, error: Exception) -> OutputT:
|
||||||
|
"""Handle step failure.
|
||||||
|
|
||||||
|
Critical steps: re-raise the exception to abort the pipeline.
|
||||||
|
Sidecar steps: return ``get_default_output()`` for graceful degradation.
|
||||||
|
"""
|
||||||
|
if self.is_critical:
|
||||||
|
logger.error(
|
||||||
|
"Critical step '%s' failed after retries: %s", self.name, error
|
||||||
|
)
|
||||||
|
raise error
|
||||||
|
logger.warning(
|
||||||
|
"Sidecar step '%s' failed, returning default output: %s",
|
||||||
|
self.name,
|
||||||
|
error,
|
||||||
|
)
|
||||||
|
return self.get_default_output()
|
||||||
@@ -0,0 +1,124 @@
|
|||||||
|
"""EmbeddingStep — generates vector embeddings for statements, chunks, dialogs, and entities.
|
||||||
|
|
||||||
|
Unlike the LLM-based ExtractionSteps, EmbeddingStep calls an embedder client
|
||||||
|
rather than an LLM. It still follows the ``should_skip`` / ``run`` /
|
||||||
|
``get_default_output`` contract so the orchestrator can treat it uniformly.
|
||||||
|
|
||||||
|
Supports **partial** embedding runs — the caller can populate only the fields
|
||||||
|
it needs (e.g. only ``statement_texts``) and leave the rest empty.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from .schema import EmbeddingStepInput, EmbeddingStepOutput
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingStep:
|
||||||
|
"""Generate vector embeddings for text inputs.
|
||||||
|
|
||||||
|
This step does **not** inherit from ``ExtractionStep`` because it does not
|
||||||
|
follow the render_prompt → call_llm → parse_response lifecycle. It does,
|
||||||
|
however, expose the same ``run`` / ``should_skip`` / ``get_default_output``
|
||||||
|
interface so the orchestrator can use it interchangeably.
|
||||||
|
|
||||||
|
Pilot-run mode skips execution entirely and returns empty dicts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embedder_client: Any,
|
||||||
|
is_pilot_run: bool = False,
|
||||||
|
batch_size: int = 100,
|
||||||
|
) -> None:
|
||||||
|
self.embedder_client = embedder_client
|
||||||
|
self.is_pilot_run = is_pilot_run
|
||||||
|
self.batch_size = batch_size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "embedding_generation"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_critical(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_retries(self) -> int:
|
||||||
|
return 1
|
||||||
|
|
||||||
|
@property
|
||||||
|
def retry_backoff_base(self) -> float:
|
||||||
|
return 1.0
|
||||||
|
|
||||||
|
def should_skip(self) -> bool:
|
||||||
|
return self.is_pilot_run
|
||||||
|
|
||||||
|
def get_default_output(self) -> EmbeddingStepOutput:
|
||||||
|
return EmbeddingStepOutput()
|
||||||
|
|
||||||
|
# ── Core execution ──
|
||||||
|
|
||||||
|
async def run(self, input_data: EmbeddingStepInput) -> EmbeddingStepOutput:
|
||||||
|
"""Generate embeddings for all non-empty text fields in *input_data*."""
|
||||||
|
if self.should_skip():
|
||||||
|
logger.info("EmbeddingStep skipped (pilot run)")
|
||||||
|
return self.get_default_output()
|
||||||
|
|
||||||
|
try:
|
||||||
|
stmt_emb, chunk_emb, dialog_emb, entity_emb = await asyncio.gather(
|
||||||
|
self._embed_dict(input_data.statement_texts),
|
||||||
|
self._embed_dict(input_data.chunk_texts),
|
||||||
|
self._embed_list(input_data.dialog_texts),
|
||||||
|
self._embed_dict(input_data.entity_names),
|
||||||
|
)
|
||||||
|
return EmbeddingStepOutput(
|
||||||
|
statement_embeddings=stmt_emb,
|
||||||
|
chunk_embeddings=chunk_emb,
|
||||||
|
dialog_embeddings=dialog_emb,
|
||||||
|
entity_embeddings=entity_emb,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("EmbeddingStep failed, returning empty output: %s", exc)
|
||||||
|
return self.get_default_output()
|
||||||
|
|
||||||
|
# ── Internal helpers ──
|
||||||
|
|
||||||
|
async def _embed_dict(
|
||||||
|
self, texts: Dict[str, str]
|
||||||
|
) -> Dict[str, List[float]]:
|
||||||
|
"""Embed a dict of ``{id: text}`` and return ``{id: embedding}``."""
|
||||||
|
if not texts:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
ids = list(texts.keys())
|
||||||
|
text_list = list(texts.values())
|
||||||
|
embeddings = await self._batch_embed(text_list)
|
||||||
|
|
||||||
|
return dict(zip(ids, embeddings))
|
||||||
|
|
||||||
|
async def _embed_list(self, texts: List[str]) -> List[List[float]]:
|
||||||
|
"""Embed a plain list of texts."""
|
||||||
|
if not texts:
|
||||||
|
return []
|
||||||
|
return await self._batch_embed(texts)
|
||||||
|
|
||||||
|
async def _batch_embed(self, texts: List[str]) -> List[List[float]]:
|
||||||
|
"""Call the embedder in batches of ``self.batch_size``."""
|
||||||
|
if len(texts) <= self.batch_size:
|
||||||
|
return await self.embedder_client.response(texts)
|
||||||
|
|
||||||
|
batches = [
|
||||||
|
texts[i : i + self.batch_size]
|
||||||
|
for i in range(0, len(texts), self.batch_size)
|
||||||
|
]
|
||||||
|
batch_results = await asyncio.gather(
|
||||||
|
*(self.embedder_client.response(b) for b in batches)
|
||||||
|
)
|
||||||
|
embeddings: List[List[float]] = []
|
||||||
|
for result in batch_results:
|
||||||
|
embeddings.extend(result)
|
||||||
|
return embeddings
|
||||||
@@ -0,0 +1,80 @@
|
|||||||
|
"""EmotionExtractionStep — sidecar step for extracting emotion from statements.
|
||||||
|
|
||||||
|
Replaces the legacy ``EmotionExtractionService`` with the unified ExtractionStep
|
||||||
|
paradigm. Registered via ``@SidecarStepFactory.register`` so the orchestrator
|
||||||
|
picks it up automatically when ``emotion_enabled`` is ``True``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.core.memory.models.emotion_models import EmotionExtraction
|
||||||
|
from app.core.memory.utils.prompt.prompt_utils import render_emotion_extraction_prompt
|
||||||
|
|
||||||
|
from .base import ExtractionStep, StepContext
|
||||||
|
from .sidecar_factory import SidecarStepFactory, SidecarTiming
|
||||||
|
from .schema import EmotionStepInput, EmotionStepOutput
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@SidecarStepFactory.register("emotion_enabled", SidecarTiming.AFTER_STATEMENT)
|
||||||
|
class EmotionExtractionStep(ExtractionStep[EmotionStepInput, EmotionStepOutput]):
|
||||||
|
"""Extract emotion type, intensity, and keywords from a statement.
|
||||||
|
|
||||||
|
This is a **sidecar** (non-critical) step — failure returns a neutral
|
||||||
|
default without aborting the pipeline.
|
||||||
|
|
||||||
|
The step self-registers with ``SidecarStepFactory`` under the config key
|
||||||
|
``emotion_enabled`` and timing ``AFTER_STATEMENT``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, context: StepContext) -> None:
|
||||||
|
super().__init__(context)
|
||||||
|
# Emotion-specific config flags (may live on a MemoryConfig object
|
||||||
|
# attached to context.config or as top-level attributes).
|
||||||
|
self.extract_keywords = getattr(self.config, "emotion_extract_keywords", True)
|
||||||
|
self.enable_subject = getattr(self.config, "emotion_enable_subject", False)
|
||||||
|
|
||||||
|
# ── Identity ──
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "emotion_extraction"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_critical(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# ── Config-driven skip ──
|
||||||
|
|
||||||
|
def should_skip(self) -> bool:
|
||||||
|
return not getattr(self.config, "emotion_enabled", False)
|
||||||
|
|
||||||
|
# ── Lifecycle ──
|
||||||
|
|
||||||
|
async def render_prompt(self, input_data: EmotionStepInput) -> str:
|
||||||
|
return await render_emotion_extraction_prompt(
|
||||||
|
statement=input_data.statement_text,
|
||||||
|
extract_keywords=self.extract_keywords,
|
||||||
|
enable_subject=self.enable_subject,
|
||||||
|
language=self.language,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def call_llm(self, prompt: Any) -> Any:
|
||||||
|
messages = [{"role": "user", "content": prompt}]
|
||||||
|
return await self.llm_client.response_structured(
|
||||||
|
messages, EmotionExtraction
|
||||||
|
)
|
||||||
|
|
||||||
|
async def parse_response(
|
||||||
|
self, raw_response: Any, input_data: EmotionStepInput
|
||||||
|
) -> EmotionStepOutput:
|
||||||
|
return EmotionStepOutput(
|
||||||
|
emotion_type=getattr(raw_response, "emotion_type", "neutral"),
|
||||||
|
emotion_intensity=getattr(raw_response, "emotion_intensity", 0.0),
|
||||||
|
emotion_keywords=getattr(raw_response, "emotion_keywords", []),
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_default_output(self) -> EmotionStepOutput:
|
||||||
|
return EmotionStepOutput()
|
||||||
@@ -0,0 +1,906 @@
|
|||||||
|
"""Refactored ExtractionOrchestrator using the unified ExtractionStep paradigm.
|
||||||
|
|
||||||
|
This module provides ``NewExtractionOrchestrator`` — a slimmed-down orchestrator
|
||||||
|
(~500 lines vs ~2500) that delegates extraction work to concrete ExtractionStep
|
||||||
|
instances and uses SidecarStepFactory for hot-pluggable sidecar modules.
|
||||||
|
|
||||||
|
The new orchestrator coexists with the legacy ``ExtractionOrchestrator`` until
|
||||||
|
the team explicitly switches over.
|
||||||
|
|
||||||
|
Execution phases:
|
||||||
|
1. Statement extraction + concurrent chunk/dialog embedding
|
||||||
|
2. Triplet extraction + concurrent after_statement sidecars + statement embedding
|
||||||
|
3. Entity embedding + concurrent after_triplet sidecars
|
||||||
|
4. Data assignment back to dialog_data_list
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
from app.core.memory.models.message_models import DialogData
|
||||||
|
from app.core.memory.models.variate_config import ExtractionPipelineConfig
|
||||||
|
|
||||||
|
from .base import ExtractionStep, StepContext
|
||||||
|
from .embedding_step import EmbeddingStep
|
||||||
|
from .sidecar_factory import SidecarStepFactory, SidecarTiming
|
||||||
|
from .statement_step import StatementExtractionStep
|
||||||
|
from .triplet_step import TripletExtractionStep
|
||||||
|
from .schema import (
|
||||||
|
EmbeddingStepInput,
|
||||||
|
EmbeddingStepOutput,
|
||||||
|
EmotionStepInput,
|
||||||
|
EmotionStepOutput,
|
||||||
|
MessageItem,
|
||||||
|
StatementStepInput,
|
||||||
|
StatementStepOutput,
|
||||||
|
SupportingContext,
|
||||||
|
TripletStepInput,
|
||||||
|
TripletStepOutput,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class NewExtractionOrchestrator:
|
||||||
|
"""Slimmed-down extraction orchestrator using the ExtractionStep paradigm.
|
||||||
|
|
||||||
|
Responsibilities:
|
||||||
|
* Initialise all steps and sidecar groups via ``SidecarStepFactory``
|
||||||
|
* Route data between stages (``_convert_to_*`` helpers)
|
||||||
|
* Orchestrate concurrent execution (``_run_with_sidecars``)
|
||||||
|
* Assign extracted results back to ``DialogData`` objects
|
||||||
|
|
||||||
|
The orchestrator does **not** own dedup, node/edge creation, or Neo4j writes.
|
||||||
|
Those remain in ``WritePipeline`` / ``dedup_step``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
llm_client: Any,
|
||||||
|
embedder_client: Any,
|
||||||
|
config: Optional[ExtractionPipelineConfig] = None,
|
||||||
|
embedding_id: Optional[str] = None,
|
||||||
|
ontology_types: Any = None,
|
||||||
|
language: str = "zh",
|
||||||
|
is_pilot_run: bool = False,
|
||||||
|
progress_callback: Optional[
|
||||||
|
Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]
|
||||||
|
] = None,
|
||||||
|
) -> None:
|
||||||
|
self.config = config or ExtractionPipelineConfig()
|
||||||
|
self.is_pilot_run = is_pilot_run
|
||||||
|
self.embedding_id = embedding_id
|
||||||
|
self.progress_callback = progress_callback
|
||||||
|
|
||||||
|
# Build shared context for all LLM-based steps
|
||||||
|
self.context = StepContext(
|
||||||
|
llm_client=llm_client,
|
||||||
|
language=language,
|
||||||
|
config=self.config,
|
||||||
|
is_pilot_run=is_pilot_run,
|
||||||
|
progress_callback=progress_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Critical (main-line) steps ──
|
||||||
|
self.statement_step = StatementExtractionStep(self.context)
|
||||||
|
self.triplet_step = TripletExtractionStep(
|
||||||
|
self.context, ontology_types=ontology_types
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Embedding step (non-LLM, separate client) ──
|
||||||
|
self.embedding_step = EmbeddingStep(
|
||||||
|
embedder_client=embedder_client,
|
||||||
|
is_pilot_run=is_pilot_run,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Sidecar steps (auto-discovered via @register decorator) ──
|
||||||
|
sidecar_groups = SidecarStepFactory.create_sidecars(self.config, self.context)
|
||||||
|
self.after_statement_sidecars: List[ExtractionStep] = sidecar_groups[
|
||||||
|
SidecarTiming.AFTER_STATEMENT
|
||||||
|
]
|
||||||
|
self.after_triplet_sidecars: List[ExtractionStep] = sidecar_groups[
|
||||||
|
SidecarTiming.AFTER_TRIPLET
|
||||||
|
]
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"NewExtractionOrchestrator initialised — "
|
||||||
|
"after_statement sidecars: %d, after_triplet sidecars: %d",
|
||||||
|
len(self.after_statement_sidecars),
|
||||||
|
len(self.after_triplet_sidecars),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────
|
||||||
|
# 1. 并发执行引擎
|
||||||
|
# 负责主线路 + 旁路的安全并发调度
|
||||||
|
# ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _run_sidecar_safe(
|
||||||
|
step: ExtractionStep, input_data: Any
|
||||||
|
) -> Any:
|
||||||
|
"""Run a sidecar step, returning its default output on failure."""
|
||||||
|
try:
|
||||||
|
return await step.run(input_data)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"Sidecar '%s' raised during gather — using default output: %s",
|
||||||
|
step.name,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
return step.get_default_output()
|
||||||
|
|
||||||
|
async def _run_with_sidecars(
|
||||||
|
self,
|
||||||
|
critical_coro: Any,
|
||||||
|
sidecars: List[Tuple[ExtractionStep, Any]],
|
||||||
|
extra_coros: Optional[List[Any]] = None,
|
||||||
|
) -> Tuple[Any, List[Any], List[Any]]:
|
||||||
|
"""Run a critical coroutine concurrently with sidecar steps.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
critical_coro: The awaitable for the critical (main-line) step.
|
||||||
|
sidecars: List of ``(step, input_data)`` pairs for sidecar steps.
|
||||||
|
extra_coros: Additional non-sidecar coroutines to run concurrently
|
||||||
|
(e.g. embedding generation).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A 3-tuple of:
|
||||||
|
* The critical step result (exception propagated if it fails).
|
||||||
|
* A list of sidecar results (default outputs on failure).
|
||||||
|
* A list of extra coroutine results (empty list if none).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If the critical coroutine fails, the exception propagates.
|
||||||
|
"""
|
||||||
|
sidecar_coros = [
|
||||||
|
self._run_sidecar_safe(step, inp) for step, inp in sidecars
|
||||||
|
]
|
||||||
|
extra = extra_coros or []
|
||||||
|
|
||||||
|
# Gather everything concurrently
|
||||||
|
all_coros = [critical_coro] + sidecar_coros + extra
|
||||||
|
results = await asyncio.gather(*all_coros, return_exceptions=True)
|
||||||
|
|
||||||
|
# Unpack: first result is critical, then sidecars, then extras
|
||||||
|
critical_result = results[0]
|
||||||
|
n_sidecars = len(sidecar_coros)
|
||||||
|
sidecar_results = list(results[1 : 1 + n_sidecars])
|
||||||
|
extra_results = list(results[1 + n_sidecars :])
|
||||||
|
|
||||||
|
# Critical step failure → propagate
|
||||||
|
if isinstance(critical_result, BaseException):
|
||||||
|
raise critical_result
|
||||||
|
|
||||||
|
# Sidecar failures should already be handled by _run_sidecar_safe,
|
||||||
|
# but guard against unexpected exceptions from gather
|
||||||
|
for i, res in enumerate(sidecar_results):
|
||||||
|
if isinstance(res, BaseException):
|
||||||
|
step = sidecars[i][0]
|
||||||
|
logger.warning(
|
||||||
|
"Sidecar '%s' unexpected exception in gather: %s",
|
||||||
|
step.name,
|
||||||
|
res,
|
||||||
|
)
|
||||||
|
sidecar_results[i] = step.get_default_output()
|
||||||
|
|
||||||
|
# Extra coroutine failures → log and replace with None
|
||||||
|
for i, res in enumerate(extra_results):
|
||||||
|
if isinstance(res, BaseException):
|
||||||
|
logger.warning("Extra coroutine %d failed: %s", i, res)
|
||||||
|
extra_results[i] = None
|
||||||
|
|
||||||
|
return critical_result, sidecar_results, extra_results
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────
|
||||||
|
# 2. 阶段间数据转换
|
||||||
|
# 将上一阶段的 StepOutput 转换为下一阶段的 StepInput
|
||||||
|
# ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_supporting_context(
|
||||||
|
dialog: DialogData,
|
||||||
|
) -> SupportingContext:
|
||||||
|
"""Build a SupportingContext from a dialog's content for pronoun resolution."""
|
||||||
|
msgs: List[MessageItem] = []
|
||||||
|
if hasattr(dialog, "content") and dialog.content:
|
||||||
|
# dialog.content is the raw conversation string; wrap as single msg
|
||||||
|
msgs.append(MessageItem(role="context", msg=dialog.content))
|
||||||
|
return SupportingContext(msgs=msgs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _convert_to_triplet_input(
|
||||||
|
stmt_out: StatementStepOutput,
|
||||||
|
supporting_context: SupportingContext,
|
||||||
|
) -> TripletStepInput:
|
||||||
|
"""Convert a StatementStepOutput into a TripletStepInput."""
|
||||||
|
return TripletStepInput(
|
||||||
|
statement_id=stmt_out.statement_id,
|
||||||
|
statement_text=stmt_out.statement_text,
|
||||||
|
statement_type=stmt_out.statement_type,
|
||||||
|
temporal_type=stmt_out.temporal_type,
|
||||||
|
supporting_context=supporting_context,
|
||||||
|
speaker=stmt_out.speaker,
|
||||||
|
valid_at=stmt_out.valid_at,
|
||||||
|
invalid_at=stmt_out.invalid_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _convert_to_emotion_input(
|
||||||
|
stmt_out: StatementStepOutput,
|
||||||
|
) -> EmotionStepInput:
|
||||||
|
"""Convert a StatementStepOutput into an EmotionStepInput."""
|
||||||
|
return EmotionStepInput(
|
||||||
|
statement_id=stmt_out.statement_id,
|
||||||
|
statement_text=stmt_out.statement_text,
|
||||||
|
speaker=stmt_out.speaker,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────
|
||||||
|
# 3. 流水线执行入口
|
||||||
|
# 公开接口 run() → 分发到 pilot / full 模式
|
||||||
|
# ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
dialog_data_list: List[DialogData],
|
||||||
|
) -> List[DialogData]:
|
||||||
|
"""Run the full extraction pipeline on *dialog_data_list*.
|
||||||
|
|
||||||
|
Returns the mutated *dialog_data_list* with extracted data assigned
|
||||||
|
to each statement (triplets, temporal info, emotions, embeddings).
|
||||||
|
|
||||||
|
The orchestrator does NOT create graph nodes/edges or run dedup —
|
||||||
|
those responsibilities remain in WritePipeline.
|
||||||
|
"""
|
||||||
|
mode = "pilot" if self.is_pilot_run else "full"
|
||||||
|
logger.info(
|
||||||
|
"Starting extraction pipeline (%s mode), %d dialogs",
|
||||||
|
mode,
|
||||||
|
len(dialog_data_list),
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.is_pilot_run:
|
||||||
|
return await self._run_pilot(dialog_data_list)
|
||||||
|
return await self._run_full(dialog_data_list)
|
||||||
|
|
||||||
|
# ── 3a. 试运行模式:仅 statement + triplet,不生成 embedding 和旁路 ──
|
||||||
|
|
||||||
|
async def _run_pilot(
|
||||||
|
self, dialog_data_list: List[DialogData]
|
||||||
|
) -> List[DialogData]:
|
||||||
|
"""Pilot mode: statement + triplet extraction only, no sidecars or embeddings."""
|
||||||
|
# Phase 1: Statement extraction (chunk-level parallel)
|
||||||
|
logger.info("Pilot phase 1/2: Statement extraction")
|
||||||
|
all_stmt_results = await self._extract_all_statements(dialog_data_list)
|
||||||
|
|
||||||
|
# Phase 2: Triplet extraction (statement-level parallel)
|
||||||
|
logger.info("Pilot phase 2/2: Triplet extraction")
|
||||||
|
all_triplet_results = await self._extract_all_triplets(
|
||||||
|
dialog_data_list, all_stmt_results
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assign results back to dialog_data_list
|
||||||
|
self._assign_results(
|
||||||
|
dialog_data_list,
|
||||||
|
all_stmt_results,
|
||||||
|
all_triplet_results,
|
||||||
|
emotion_results={},
|
||||||
|
embedding_output=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store raw step outputs for snapshot/debugging
|
||||||
|
self._last_stage_outputs = {
|
||||||
|
"statement_results": all_stmt_results,
|
||||||
|
"triplet_results": all_triplet_results,
|
||||||
|
"emotion_results": {},
|
||||||
|
"embedding_output": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info("Pilot extraction complete")
|
||||||
|
return dialog_data_list
|
||||||
|
|
||||||
|
# ── 3b. 正式模式:四阶段并发执行 ──
|
||||||
|
|
||||||
|
async def _run_full(
|
||||||
|
self, dialog_data_list: List[DialogData]
|
||||||
|
) -> List[DialogData]:
|
||||||
|
"""Full mode: all four phases with concurrent sidecars and embeddings."""
|
||||||
|
|
||||||
|
# ── Phase 1: Statement extraction + chunk/dialog embedding ──
|
||||||
|
logger.info("Phase 1/4: Statement extraction + chunk/dialog embedding")
|
||||||
|
chunk_dialog_emb_input = self._build_chunk_dialog_embedding_input(
|
||||||
|
dialog_data_list
|
||||||
|
)
|
||||||
|
|
||||||
|
stmt_coro = self._extract_all_statements(dialog_data_list)
|
||||||
|
emb_coro = self.embedding_step.run(chunk_dialog_emb_input)
|
||||||
|
|
||||||
|
phase1_results = await asyncio.gather(
|
||||||
|
stmt_coro, emb_coro, return_exceptions=True
|
||||||
|
)
|
||||||
|
|
||||||
|
all_stmt_results: Dict[str, Dict[str, List[StatementStepOutput]]] = (
|
||||||
|
phase1_results[0]
|
||||||
|
if not isinstance(phase1_results[0], BaseException)
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
if isinstance(phase1_results[0], BaseException):
|
||||||
|
raise phase1_results[0]
|
||||||
|
|
||||||
|
chunk_dialog_emb: Optional[EmbeddingStepOutput] = (
|
||||||
|
phase1_results[1]
|
||||||
|
if not isinstance(phase1_results[1], BaseException)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
if isinstance(phase1_results[1], BaseException):
|
||||||
|
logger.warning("Chunk/dialog embedding failed: %s", phase1_results[1])
|
||||||
|
|
||||||
|
# ── Phase 2: Triplet extraction + after_statement sidecars + statement embedding ──
|
||||||
|
logger.info(
|
||||||
|
"Phase 2/4: Triplet extraction + sidecars + statement embedding"
|
||||||
|
)
|
||||||
|
stmt_emb_input = self._build_statement_embedding_input(
|
||||||
|
dialog_data_list, all_stmt_results
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build sidecar inputs for after_statement sidecars (e.g. emotion)
|
||||||
|
sidecar_pairs = self._build_after_statement_sidecar_inputs(
|
||||||
|
dialog_data_list, all_stmt_results
|
||||||
|
)
|
||||||
|
|
||||||
|
triplet_coro = self._extract_all_triplets(
|
||||||
|
dialog_data_list, all_stmt_results
|
||||||
|
)
|
||||||
|
stmt_emb_coro = self.embedding_step.run(stmt_emb_input)
|
||||||
|
|
||||||
|
triplet_results, sidecar_results, extra_results = (
|
||||||
|
await self._run_with_sidecars(
|
||||||
|
triplet_coro,
|
||||||
|
sidecar_pairs,
|
||||||
|
extra_coros=[stmt_emb_coro],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
all_triplet_results = triplet_results
|
||||||
|
stmt_emb: Optional[EmbeddingStepOutput] = (
|
||||||
|
extra_results[0] if extra_results else None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Collect sidecar outputs keyed by step name
|
||||||
|
sidecar_steps = [step for step, _inp in sidecar_pairs]
|
||||||
|
sidecar_output_map = self._collect_sidecar_outputs(
|
||||||
|
sidecar_steps, sidecar_results
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Phase 3: Entity embedding + after_triplet sidecars ──
|
||||||
|
logger.info("Phase 3/4: Entity embedding + after_triplet sidecars")
|
||||||
|
entity_emb_input = self._build_entity_embedding_input(all_triplet_results)
|
||||||
|
|
||||||
|
after_triplet_pairs: List[Tuple[ExtractionStep, Any]] = []
|
||||||
|
# Future after_triplet sidecars would be wired here
|
||||||
|
|
||||||
|
entity_emb_coro = self.embedding_step.run(entity_emb_input)
|
||||||
|
|
||||||
|
if after_triplet_pairs:
|
||||||
|
_, at_sidecar_results, at_extra = await self._run_with_sidecars(
|
||||||
|
entity_emb_coro,
|
||||||
|
after_triplet_pairs,
|
||||||
|
)
|
||||||
|
entity_emb = at_extra[0] if at_extra else None
|
||||||
|
else:
|
||||||
|
# No after_triplet sidecars — just run embedding
|
||||||
|
entity_emb_result = await entity_emb_coro
|
||||||
|
entity_emb = (
|
||||||
|
entity_emb_result
|
||||||
|
if not isinstance(entity_emb_result, BaseException)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Merge all embedding outputs
|
||||||
|
merged_emb = self._merge_embeddings(chunk_dialog_emb, stmt_emb, entity_emb)
|
||||||
|
|
||||||
|
# ── Phase 4: Data assignment ──
|
||||||
|
logger.info("Phase 4/4: Data assignment")
|
||||||
|
emotion_results = sidecar_output_map.get("emotion_extraction", {})
|
||||||
|
|
||||||
|
self._assign_results(
|
||||||
|
dialog_data_list,
|
||||||
|
all_stmt_results,
|
||||||
|
all_triplet_results,
|
||||||
|
emotion_results=emotion_results,
|
||||||
|
embedding_output=merged_emb,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store raw step outputs for snapshot/debugging
|
||||||
|
self._last_stage_outputs = {
|
||||||
|
"statement_results": all_stmt_results,
|
||||||
|
"triplet_results": all_triplet_results,
|
||||||
|
"emotion_results": emotion_results,
|
||||||
|
"embedding_output": merged_emb,
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info("Full extraction pipeline complete")
|
||||||
|
return dialog_data_list
|
||||||
|
|
||||||
|
@property
|
||||||
|
def last_stage_outputs(self) -> Dict[str, Any]:
|
||||||
|
"""Return the raw step outputs from the last run for snapshot/debugging."""
|
||||||
|
return getattr(self, "_last_stage_outputs", {})
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────
|
||||||
|
# 4. 萃取执行器
|
||||||
|
# chunk 级并行 statement 提取、statement 级并行 triplet 提取
|
||||||
|
# ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _extract_all_statements(
|
||||||
|
self,
|
||||||
|
dialog_data_list: List[DialogData],
|
||||||
|
) -> Dict[str, Dict[str, List[StatementStepOutput]]]:
|
||||||
|
"""Extract statements from all chunks across all dialogs (chunk-level parallel).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Nested dict: ``{dialog_id: {chunk_id: [StatementStepOutput, ...]}}``
|
||||||
|
"""
|
||||||
|
# Collect all (chunk, metadata) pairs
|
||||||
|
tasks: List[Any] = []
|
||||||
|
task_meta: List[Tuple[str, str, str, SupportingContext]] = []
|
||||||
|
|
||||||
|
for dialog in dialog_data_list:
|
||||||
|
ctx = self._build_supporting_context(dialog)
|
||||||
|
dialogue_content = (
|
||||||
|
dialog.content
|
||||||
|
if getattr(
|
||||||
|
self.config, "statement_extraction", None
|
||||||
|
)
|
||||||
|
and getattr(
|
||||||
|
self.config.statement_extraction,
|
||||||
|
"include_dialogue_context",
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
for chunk in dialog.chunks:
|
||||||
|
inp = StatementStepInput(
|
||||||
|
chunk_id=chunk.id,
|
||||||
|
end_user_id=dialog.end_user_id,
|
||||||
|
target_content=chunk.content,
|
||||||
|
target_message_date=str(
|
||||||
|
getattr(dialog, "created_at", "") or ""
|
||||||
|
),
|
||||||
|
supporting_context=ctx,
|
||||||
|
)
|
||||||
|
tasks.append(self.statement_step.run(inp))
|
||||||
|
task_meta.append(
|
||||||
|
(dialog.id, chunk.id, getattr(chunk, "speaker", "user"), ctx)
|
||||||
|
)
|
||||||
|
|
||||||
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
# Organise into nested dict
|
||||||
|
stmt_map: Dict[str, Dict[str, List[StatementStepOutput]]] = {}
|
||||||
|
for i, result in enumerate(results):
|
||||||
|
dialog_id, chunk_id, speaker, _ = task_meta[i]
|
||||||
|
if dialog_id not in stmt_map:
|
||||||
|
stmt_map[dialog_id] = {}
|
||||||
|
|
||||||
|
if isinstance(result, BaseException):
|
||||||
|
logger.error("Statement extraction failed for chunk %s: %s", chunk_id, result)
|
||||||
|
stmt_map[dialog_id][chunk_id] = []
|
||||||
|
else:
|
||||||
|
# Override speaker from chunk metadata
|
||||||
|
stmts: List[StatementStepOutput] = result if isinstance(result, list) else []
|
||||||
|
for s in stmts:
|
||||||
|
s.speaker = speaker
|
||||||
|
stmt_map[dialog_id][chunk_id] = stmts
|
||||||
|
|
||||||
|
return stmt_map
|
||||||
|
|
||||||
|
async def _extract_all_triplets(
|
||||||
|
self,
|
||||||
|
dialog_data_list: List[DialogData],
|
||||||
|
all_stmt_results: Dict[str, Dict[str, List[StatementStepOutput]]],
|
||||||
|
) -> Dict[str, Dict[str, TripletStepOutput]]:
|
||||||
|
"""Extract triplets for every statement (statement-level parallel).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Nested dict: ``{dialog_id: {statement_id: TripletStepOutput}}``
|
||||||
|
"""
|
||||||
|
tasks: List[Any] = []
|
||||||
|
task_meta: List[Tuple[str, str]] = [] # (dialog_id, statement_id)
|
||||||
|
|
||||||
|
for dialog in dialog_data_list:
|
||||||
|
ctx = self._build_supporting_context(dialog)
|
||||||
|
chunk_stmts = all_stmt_results.get(dialog.id, {})
|
||||||
|
for _chunk_id, stmts in chunk_stmts.items():
|
||||||
|
for stmt in stmts:
|
||||||
|
inp = self._convert_to_triplet_input(stmt, ctx)
|
||||||
|
tasks.append(self.triplet_step.run(inp))
|
||||||
|
task_meta.append((dialog.id, stmt.statement_id))
|
||||||
|
|
||||||
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
triplet_map: Dict[str, Dict[str, TripletStepOutput]] = {}
|
||||||
|
for i, result in enumerate(results):
|
||||||
|
dialog_id, stmt_id = task_meta[i]
|
||||||
|
if dialog_id not in triplet_map:
|
||||||
|
triplet_map[dialog_id] = {}
|
||||||
|
|
||||||
|
if isinstance(result, BaseException):
|
||||||
|
logger.error(
|
||||||
|
"Triplet extraction failed for statement %s: %s",
|
||||||
|
stmt_id,
|
||||||
|
result,
|
||||||
|
)
|
||||||
|
triplet_map[dialog_id][stmt_id] = self.triplet_step.get_default_output()
|
||||||
|
else:
|
||||||
|
triplet_map[dialog_id][stmt_id] = result
|
||||||
|
|
||||||
|
return triplet_map
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────
|
||||||
|
# 5. Embedding 输入构建器
|
||||||
|
# 为不同阶段构建 EmbeddingStepInput(chunk/statement/entity)
|
||||||
|
# ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_chunk_dialog_embedding_input(
|
||||||
|
dialog_data_list: List[DialogData],
|
||||||
|
) -> EmbeddingStepInput:
|
||||||
|
"""Build embedding input for chunks and dialogs (phase 1)."""
|
||||||
|
chunk_texts: Dict[str, str] = {}
|
||||||
|
dialog_texts: List[str] = []
|
||||||
|
|
||||||
|
for dialog in dialog_data_list:
|
||||||
|
if hasattr(dialog, "content") and dialog.content:
|
||||||
|
dialog_texts.append(dialog.content)
|
||||||
|
for chunk in dialog.chunks:
|
||||||
|
chunk_texts[chunk.id] = chunk.content
|
||||||
|
|
||||||
|
return EmbeddingStepInput(
|
||||||
|
chunk_texts=chunk_texts,
|
||||||
|
dialog_texts=dialog_texts,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_statement_embedding_input(
|
||||||
|
dialog_data_list: List[DialogData],
|
||||||
|
all_stmt_results: Dict[str, Dict[str, List[StatementStepOutput]]],
|
||||||
|
) -> EmbeddingStepInput:
|
||||||
|
"""Build embedding input for statements (phase 2)."""
|
||||||
|
stmt_texts: Dict[str, str] = {}
|
||||||
|
for _dialog_id, chunk_stmts in all_stmt_results.items():
|
||||||
|
for _chunk_id, stmts in chunk_stmts.items():
|
||||||
|
for s in stmts:
|
||||||
|
stmt_texts[s.statement_id] = s.statement_text
|
||||||
|
return EmbeddingStepInput(statement_texts=stmt_texts)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_entity_embedding_input(
|
||||||
|
all_triplet_results: Dict[str, Dict[str, TripletStepOutput]],
|
||||||
|
) -> EmbeddingStepInput:
|
||||||
|
"""Build embedding input for entities (phase 3)."""
|
||||||
|
entity_names: Dict[str, str] = {}
|
||||||
|
entity_descs: Dict[str, str] = {}
|
||||||
|
seen: set = set()
|
||||||
|
|
||||||
|
for _dialog_id, stmt_triplets in all_triplet_results.items():
|
||||||
|
for _stmt_id, triplet_out in stmt_triplets.items():
|
||||||
|
for ent in triplet_out.entities:
|
||||||
|
key = f"{ent.entity_idx}_{ent.name}"
|
||||||
|
if key not in seen:
|
||||||
|
seen.add(key)
|
||||||
|
entity_names[key] = ent.name
|
||||||
|
entity_descs[key] = ent.description
|
||||||
|
|
||||||
|
return EmbeddingStepInput(
|
||||||
|
entity_names=entity_names,
|
||||||
|
entity_descriptions=entity_descs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────
|
||||||
|
# 6. 旁路输入构建与结果收集
|
||||||
|
# 为 after_statement / after_triplet 旁路构建输入,合并 embedding 输出
|
||||||
|
# ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _build_after_statement_sidecar_inputs(
|
||||||
|
self,
|
||||||
|
dialog_data_list: List[DialogData],
|
||||||
|
all_stmt_results: Dict[str, Dict[str, List[StatementStepOutput]]],
|
||||||
|
) -> List[Tuple[ExtractionStep, Any]]:
|
||||||
|
"""Build (step, input) pairs for after_statement sidecars.
|
||||||
|
|
||||||
|
For emotion extraction, we create a batch wrapper that runs the sidecar
|
||||||
|
on every user statement concurrently and returns a dict of results.
|
||||||
|
"""
|
||||||
|
if not self.after_statement_sidecars:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Collect all user statements for sidecar processing
|
||||||
|
all_user_stmts: List[StatementStepOutput] = []
|
||||||
|
for _dialog_id, chunk_stmts in all_stmt_results.items():
|
||||||
|
for _chunk_id, stmts in chunk_stmts.items():
|
||||||
|
for s in stmts:
|
||||||
|
if s.speaker == "user":
|
||||||
|
all_user_stmts.append(s)
|
||||||
|
|
||||||
|
pairs: List[Tuple[ExtractionStep, Any]] = []
|
||||||
|
for sidecar in self.after_statement_sidecars:
|
||||||
|
if sidecar.name == "emotion_extraction":
|
||||||
|
# Emotion sidecar: wrap as batch coroutine via a sentinel input
|
||||||
|
# The actual per-statement calls happen inside _run_emotion_batch
|
||||||
|
pairs.append((
|
||||||
|
_EmotionBatchWrapper(sidecar, all_user_stmts),
|
||||||
|
EmotionStepInput(
|
||||||
|
statement_id="__batch__",
|
||||||
|
statement_text="",
|
||||||
|
speaker="",
|
||||||
|
),
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
# Generic sidecar: pass first statement as representative input
|
||||||
|
if all_user_stmts:
|
||||||
|
inp = self._convert_to_emotion_input(all_user_stmts[0])
|
||||||
|
pairs.append((sidecar, inp))
|
||||||
|
|
||||||
|
return pairs
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _collect_sidecar_outputs(
|
||||||
|
sidecars: List[ExtractionStep],
|
||||||
|
results: List[Any],
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Map sidecar results by step name."""
|
||||||
|
output: Dict[str, Any] = {}
|
||||||
|
for i, sidecar in enumerate(sidecars):
|
||||||
|
if i < len(results):
|
||||||
|
output[sidecar.name] = results[i]
|
||||||
|
return output
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _merge_embeddings(
|
||||||
|
chunk_dialog: Optional[EmbeddingStepOutput],
|
||||||
|
statement: Optional[EmbeddingStepOutput],
|
||||||
|
entity: Optional[Any],
|
||||||
|
) -> Optional[EmbeddingStepOutput]:
|
||||||
|
"""Merge partial embedding outputs into a single EmbeddingStepOutput."""
|
||||||
|
merged = EmbeddingStepOutput()
|
||||||
|
if chunk_dialog:
|
||||||
|
merged.chunk_embeddings = chunk_dialog.chunk_embeddings
|
||||||
|
merged.dialog_embeddings = chunk_dialog.dialog_embeddings
|
||||||
|
if statement:
|
||||||
|
merged.statement_embeddings = statement.statement_embeddings
|
||||||
|
if entity and isinstance(entity, EmbeddingStepOutput):
|
||||||
|
merged.entity_embeddings = entity.entity_embeddings
|
||||||
|
return merged
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────
|
||||||
|
# 7. 数据赋值
|
||||||
|
# 将各阶段 StepOutput 组装为 Statement 对象,替换 chunk.statements
|
||||||
|
# ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _assign_results(
|
||||||
|
self,
|
||||||
|
dialog_data_list: List[DialogData],
|
||||||
|
all_stmt_results: Dict[str, Dict[str, List[StatementStepOutput]]],
|
||||||
|
all_triplet_results: Dict[str, Dict[str, TripletStepOutput]],
|
||||||
|
emotion_results: Dict[str, EmotionStepOutput],
|
||||||
|
embedding_output: Optional[EmbeddingStepOutput],
|
||||||
|
) -> None:
|
||||||
|
"""Assign extraction results back to dialog_data_list in-place.
|
||||||
|
|
||||||
|
Replaces chunk.statements with new Statement objects built from step
|
||||||
|
outputs, because the new orchestrator generates its own statement IDs
|
||||||
|
that don't match the original chunk statement IDs.
|
||||||
|
"""
|
||||||
|
from app.core.memory.models.message_models import (
|
||||||
|
Statement,
|
||||||
|
TemporalValidityRange,
|
||||||
|
)
|
||||||
|
from app.core.memory.models.triplet_models import (
|
||||||
|
TripletExtractionResponse,
|
||||||
|
Entity as TripletEntity,
|
||||||
|
Triplet as TripletRelation,
|
||||||
|
)
|
||||||
|
from app.core.memory.utils.data.ontology import (
|
||||||
|
RelevenceInfo,
|
||||||
|
StatementType,
|
||||||
|
TemporalInfo,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Map string values to enums
|
||||||
|
_STMT_TYPE_MAP = {
|
||||||
|
"FACT": StatementType.FACT,
|
||||||
|
"OPINION": StatementType.OPINION,
|
||||||
|
"PREDICTION": StatementType.PREDICTION,
|
||||||
|
"SUGGESTION": StatementType.SUGGESTION,
|
||||||
|
}
|
||||||
|
_TEMPORAL_MAP = {
|
||||||
|
"STATIC": TemporalInfo.STATIC,
|
||||||
|
"DYNAMIC": TemporalInfo.DYNAMIC,
|
||||||
|
"ATEMPORAL": TemporalInfo.ATEMPORAL,
|
||||||
|
}
|
||||||
|
|
||||||
|
total_stmts = 0
|
||||||
|
assigned_triplets = 0
|
||||||
|
assigned_emotions = 0
|
||||||
|
assigned_stmt_emb = 0
|
||||||
|
assigned_chunk_emb = 0
|
||||||
|
assigned_dialog_emb = 0
|
||||||
|
|
||||||
|
for dialog in dialog_data_list:
|
||||||
|
dialog_stmts = all_stmt_results.get(dialog.id, {})
|
||||||
|
dialog_triplets = all_triplet_results.get(dialog.id, {})
|
||||||
|
|
||||||
|
# Assign dialog embedding
|
||||||
|
if embedding_output and embedding_output.dialog_embeddings:
|
||||||
|
idx = dialog_data_list.index(dialog)
|
||||||
|
if idx < len(embedding_output.dialog_embeddings):
|
||||||
|
dialog.dialog_embedding = embedding_output.dialog_embeddings[idx]
|
||||||
|
assigned_dialog_emb += 1
|
||||||
|
|
||||||
|
for chunk in dialog.chunks:
|
||||||
|
# Assign chunk embedding
|
||||||
|
if embedding_output and chunk.id in embedding_output.chunk_embeddings:
|
||||||
|
chunk.chunk_embedding = embedding_output.chunk_embeddings[chunk.id]
|
||||||
|
assigned_chunk_emb += 1
|
||||||
|
|
||||||
|
# Build new Statement objects from step outputs
|
||||||
|
chunk_stmt_outputs = dialog_stmts.get(chunk.id, [])
|
||||||
|
new_statements = []
|
||||||
|
|
||||||
|
for stmt_out in chunk_stmt_outputs:
|
||||||
|
total_stmts += 1
|
||||||
|
|
||||||
|
# Temporal validity
|
||||||
|
valid_at = stmt_out.valid_at if stmt_out.valid_at != "NULL" else None
|
||||||
|
invalid_at = stmt_out.invalid_at if stmt_out.invalid_at != "NULL" else None
|
||||||
|
|
||||||
|
# Triplet info
|
||||||
|
triplet_info = None
|
||||||
|
triplet_out = dialog_triplets.get(stmt_out.statement_id)
|
||||||
|
if triplet_out and (triplet_out.entities or triplet_out.triplets):
|
||||||
|
entities = [
|
||||||
|
TripletEntity(
|
||||||
|
entity_idx=e.entity_idx,
|
||||||
|
name=e.name,
|
||||||
|
type=e.type,
|
||||||
|
description=e.description,
|
||||||
|
is_explicit_memory=e.is_explicit_memory,
|
||||||
|
)
|
||||||
|
for e in triplet_out.entities
|
||||||
|
]
|
||||||
|
triplets = [
|
||||||
|
TripletRelation(
|
||||||
|
subject_name=t.subject_name,
|
||||||
|
subject_id=t.subject_id,
|
||||||
|
predicate=t.predicate,
|
||||||
|
object_name=t.object_name,
|
||||||
|
object_id=t.object_id,
|
||||||
|
)
|
||||||
|
for t in triplet_out.triplets
|
||||||
|
]
|
||||||
|
triplet_info = TripletExtractionResponse(
|
||||||
|
entities=entities, triplets=triplets,
|
||||||
|
)
|
||||||
|
assigned_triplets += 1
|
||||||
|
|
||||||
|
# Emotion info
|
||||||
|
emo = emotion_results.get(stmt_out.statement_id)
|
||||||
|
emotion_kwargs = {}
|
||||||
|
if emo:
|
||||||
|
emotion_kwargs = {
|
||||||
|
"emotion_type": emo.emotion_type,
|
||||||
|
"emotion_intensity": emo.emotion_intensity,
|
||||||
|
"emotion_keywords": emo.emotion_keywords,
|
||||||
|
}
|
||||||
|
assigned_emotions += 1
|
||||||
|
|
||||||
|
# Statement embedding
|
||||||
|
stmt_embedding = None
|
||||||
|
if (
|
||||||
|
embedding_output
|
||||||
|
and stmt_out.statement_id in embedding_output.statement_embeddings
|
||||||
|
):
|
||||||
|
stmt_embedding = embedding_output.statement_embeddings[stmt_out.statement_id]
|
||||||
|
assigned_stmt_emb += 1
|
||||||
|
|
||||||
|
# Build the Statement object that _create_nodes_and_edges expects
|
||||||
|
stmt = Statement(
|
||||||
|
id=stmt_out.statement_id,
|
||||||
|
chunk_id=chunk.id,
|
||||||
|
end_user_id=dialog.end_user_id,
|
||||||
|
statement=stmt_out.statement_text,
|
||||||
|
speaker=stmt_out.speaker,
|
||||||
|
stmt_type=_STMT_TYPE_MAP.get(stmt_out.statement_type, StatementType.FACT),
|
||||||
|
temporal_info=_TEMPORAL_MAP.get(stmt_out.temporal_type, TemporalInfo.ATEMPORAL),
|
||||||
|
relevence_info=RelevenceInfo.RELEVANT if stmt_out.relevance == "RELEVANT" else RelevenceInfo.IRRELEVANT,
|
||||||
|
temporal_validity=TemporalValidityRange(valid_at=valid_at, invalid_at=invalid_at),
|
||||||
|
triplet_extraction_info=triplet_info,
|
||||||
|
statement_embedding=stmt_embedding,
|
||||||
|
**emotion_kwargs,
|
||||||
|
)
|
||||||
|
new_statements.append(stmt)
|
||||||
|
|
||||||
|
# Replace chunk.statements with newly built objects
|
||||||
|
chunk.statements = new_statements
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Data assignment complete — statements: %d, triplets: %d, "
|
||||||
|
"emotions: %d, stmt_emb: %d, chunk_emb: %d, dialog_emb: %d",
|
||||||
|
total_stmts,
|
||||||
|
assigned_triplets,
|
||||||
|
assigned_emotions,
|
||||||
|
assigned_stmt_emb,
|
||||||
|
assigned_chunk_emb,
|
||||||
|
assigned_dialog_emb,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _EmotionBatchWrapper(ExtractionStep):
|
||||||
|
"""情绪批量提取包装器。再考虑一下用法,这是子类?
|
||||||
|
|
||||||
|
将单条情绪旁路 Step 包装为批量并发执行,适配 ``_run_with_sidecars`` 接口。
|
||||||
|
编排器传入一个 sentinel input,``run()`` 忽略它,转而对预收集的 statement 列表
|
||||||
|
逐条并发调用内部 Step,返回 ``{statement_id: EmotionStepOutput}`` 字典。
|
||||||
|
"""
|
||||||
|
|
||||||
|
# ── 初始化 ──
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
inner_step: ExtractionStep,
|
||||||
|
statements: List[StatementStepOutput],
|
||||||
|
) -> None:
|
||||||
|
# 不调用 super().__init__() — 本类是薄包装,不需要 StepContext
|
||||||
|
self._inner = inner_step
|
||||||
|
self._statements = statements
|
||||||
|
|
||||||
|
# ── Step 身份属性(满足 ExtractionStep 抽象接口) ──
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return self._inner.name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_critical(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_default_output(self) -> Dict[str, EmotionStepOutput]:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# ── 未使用的生命周期方法(批量包装器不走 render→call→parse 流程) ──
|
||||||
|
|
||||||
|
async def render_prompt(self, input_data: Any) -> Any:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def call_llm(self, prompt: Any) -> Any:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def parse_response(self, raw_response: Any, input_data: Any) -> Any:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
# ── 批量执行入口 ──
|
||||||
|
|
||||||
|
async def run(self, input_data: Any) -> Dict[str, EmotionStepOutput]:
|
||||||
|
"""对所有预收集的 statement 并发执行情绪提取,单条失败返回默认值。"""
|
||||||
|
if not self._statements:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def _extract_one(stmt: StatementStepOutput) -> Tuple[str, EmotionStepOutput]:
|
||||||
|
inp = EmotionStepInput(
|
||||||
|
statement_id=stmt.statement_id,
|
||||||
|
statement_text=stmt.statement_text,
|
||||||
|
speaker=stmt.speaker,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
result = await self._inner.run(inp)
|
||||||
|
return stmt.statement_id, result
|
||||||
|
except Exception:
|
||||||
|
return stmt.statement_id, self._inner.get_default_output()
|
||||||
|
|
||||||
|
pairs = await asyncio.gather(
|
||||||
|
*[_extract_one(s) for s in self._statements]
|
||||||
|
)
|
||||||
|
return dict(pairs)
|
||||||
@@ -0,0 +1,366 @@
|
|||||||
|
"""
|
||||||
|
GraphBuildStep — 从 DialogData 构建 Neo4j 图节点和边。
|
||||||
|
|
||||||
|
职责:
|
||||||
|
- 遍历 DialogData 列表,构建 DialogueNode、ChunkNode、StatementNode、
|
||||||
|
ExtractedEntityNode、PerceptualNode 及各类 Edge
|
||||||
|
- 不涉及 LLM 调用、去重、Neo4j 写入
|
||||||
|
|
||||||
|
依赖:
|
||||||
|
- embedder_client(可选):为 PerceptualNode 生成 summary embedding
|
||||||
|
- progress_callback(可选):流式输出关系创建进度
|
||||||
|
|
||||||
|
从 ExtractionOrchestrator._create_nodes_and_edges() 提取而来,
|
||||||
|
旧编排器保留原方法不变,新旧流水线完全隔离。
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
||||||
|
|
||||||
|
from app.core.memory.models.graph_models import (
|
||||||
|
ChunkNode,
|
||||||
|
DialogueNode,
|
||||||
|
EntityEntityEdge,
|
||||||
|
ExtractedEntityNode,
|
||||||
|
PerceptualEdge,
|
||||||
|
PerceptualNode,
|
||||||
|
StatementChunkEdge,
|
||||||
|
StatementEntityEdge,
|
||||||
|
StatementNode,
|
||||||
|
)
|
||||||
|
from app.core.memory.models.message_models import DialogData, TemporalInfo
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class GraphBuildResult:
|
||||||
|
"""图构建步骤的输出。"""
|
||||||
|
|
||||||
|
__slots__ = (
|
||||||
|
"dialogue_nodes",
|
||||||
|
"chunk_nodes",
|
||||||
|
"statement_nodes",
|
||||||
|
"entity_nodes",
|
||||||
|
"perceptual_nodes",
|
||||||
|
"stmt_chunk_edges",
|
||||||
|
"stmt_entity_edges",
|
||||||
|
"entity_entity_edges",
|
||||||
|
"perceptual_edges",
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dialogue_nodes: List[DialogueNode],
|
||||||
|
chunk_nodes: List[ChunkNode],
|
||||||
|
statement_nodes: List[StatementNode],
|
||||||
|
entity_nodes: List[ExtractedEntityNode],
|
||||||
|
perceptual_nodes: List[PerceptualNode],
|
||||||
|
stmt_chunk_edges: List[StatementChunkEdge],
|
||||||
|
stmt_entity_edges: List[StatementEntityEdge],
|
||||||
|
entity_entity_edges: List[EntityEntityEdge],
|
||||||
|
perceptual_edges: List[PerceptualEdge],
|
||||||
|
):
|
||||||
|
self.dialogue_nodes = dialogue_nodes
|
||||||
|
self.chunk_nodes = chunk_nodes
|
||||||
|
self.statement_nodes = statement_nodes
|
||||||
|
self.entity_nodes = entity_nodes
|
||||||
|
self.perceptual_nodes = perceptual_nodes
|
||||||
|
self.stmt_chunk_edges = stmt_chunk_edges
|
||||||
|
self.stmt_entity_edges = stmt_entity_edges
|
||||||
|
self.entity_entity_edges = entity_entity_edges
|
||||||
|
self.perceptual_edges = perceptual_edges
|
||||||
|
|
||||||
|
|
||||||
|
async def build_graph_nodes_and_edges(
|
||||||
|
dialog_data_list: List[DialogData],
|
||||||
|
embedder_client: Any = None,
|
||||||
|
progress_callback: Optional[
|
||||||
|
Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]
|
||||||
|
] = None,
|
||||||
|
) -> GraphBuildResult:
|
||||||
|
"""
|
||||||
|
从 DialogData 列表构建完整的图节点和边。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dialog_data_list: 经过萃取和数据赋值后的 DialogData 列表
|
||||||
|
embedder_client: 可选的嵌入客户端,用于 PerceptualNode summary embedding
|
||||||
|
progress_callback: 可选的进度回调
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
GraphBuildResult 包含所有节点和边
|
||||||
|
"""
|
||||||
|
logger.info("开始创建节点和边")
|
||||||
|
|
||||||
|
dialogue_nodes: List[DialogueNode] = []
|
||||||
|
chunk_nodes: List[ChunkNode] = []
|
||||||
|
statement_nodes: List[StatementNode] = []
|
||||||
|
entity_nodes: List[ExtractedEntityNode] = []
|
||||||
|
perceptual_nodes: List[PerceptualNode] = []
|
||||||
|
stmt_chunk_edges: List[StatementChunkEdge] = []
|
||||||
|
stmt_entity_edges: List[StatementEntityEdge] = []
|
||||||
|
entity_entity_edges: List[EntityEntityEdge] = []
|
||||||
|
perceptual_edges: List[PerceptualEdge] = []
|
||||||
|
|
||||||
|
entity_id_set: set = set()
|
||||||
|
total_dialogs = len(dialog_data_list)
|
||||||
|
processed_dialogs = 0
|
||||||
|
|
||||||
|
for dialog_data in dialog_data_list:
|
||||||
|
processed_dialogs += 1
|
||||||
|
# region TODO 乐力齐 重构流水线切换生产环境稳定后修改
|
||||||
|
# ── 对话节点 ──
|
||||||
|
dialogue_node = DialogueNode(
|
||||||
|
id=dialog_data.id,
|
||||||
|
name=f"Dialog_{dialog_data.id}",
|
||||||
|
ref_id=dialog_data.ref_id,
|
||||||
|
end_user_id=dialog_data.end_user_id,
|
||||||
|
run_id=dialog_data.run_id,
|
||||||
|
content=dialog_data.context.content if dialog_data.context else "",
|
||||||
|
dialog_embedding=dialog_data.dialog_embedding if hasattr(dialog_data, "dialog_embedding") else None,
|
||||||
|
created_at=dialog_data.created_at,
|
||||||
|
expired_at=dialog_data.expired_at,
|
||||||
|
metadata=dialog_data.metadata,
|
||||||
|
config_id=dialog_data.config_id if hasattr(dialog_data, "config_id") else None,
|
||||||
|
)
|
||||||
|
dialogue_nodes.append(dialogue_node)
|
||||||
|
|
||||||
|
# ── 分块节点 ──
|
||||||
|
for chunk_idx, chunk in enumerate(dialog_data.chunks):
|
||||||
|
chunk_node = ChunkNode(
|
||||||
|
id=chunk.id,
|
||||||
|
name=f"Chunk_{chunk.id}",
|
||||||
|
dialog_id=dialog_data.id,
|
||||||
|
end_user_id=dialog_data.end_user_id,
|
||||||
|
run_id=dialog_data.run_id,
|
||||||
|
content=chunk.content,
|
||||||
|
speaker=getattr(chunk, "speaker", None),
|
||||||
|
chunk_embedding=chunk.chunk_embedding,
|
||||||
|
sequence_number=chunk_idx,
|
||||||
|
created_at=dialog_data.created_at,
|
||||||
|
expired_at=dialog_data.expired_at,
|
||||||
|
metadata=chunk.metadata,
|
||||||
|
)
|
||||||
|
chunk_nodes.append(chunk_node)
|
||||||
|
|
||||||
|
# ── 感知节点 ──
|
||||||
|
for p, file_type in chunk.files:
|
||||||
|
meta = p.meta_data or {}
|
||||||
|
content_meta = meta.get("content", {})
|
||||||
|
|
||||||
|
summary_embedding = None
|
||||||
|
if embedder_client and p.summary:
|
||||||
|
try:
|
||||||
|
summary_embedding = (await embedder_client.response([p.summary]))[0]
|
||||||
|
except Exception as emb_err:
|
||||||
|
logger.warning(f"Failed to embed perceptual summary: {emb_err}")
|
||||||
|
|
||||||
|
perceptual = PerceptualNode(
|
||||||
|
name=f"Perceptual_{p.id}",
|
||||||
|
id=str(p.id),
|
||||||
|
end_user_id=str(p.end_user_id),
|
||||||
|
perceptual_type=p.perceptual_type,
|
||||||
|
file_path=p.file_path or "",
|
||||||
|
file_name=p.file_name or "",
|
||||||
|
file_ext=p.file_ext or "",
|
||||||
|
summary=p.summary or "",
|
||||||
|
keywords=content_meta.get("keywords", []),
|
||||||
|
topic=content_meta.get("topic", ""),
|
||||||
|
domain=content_meta.get("domain", ""),
|
||||||
|
created_at=p.created_time.isoformat() if p.created_time else None,
|
||||||
|
file_type=file_type,
|
||||||
|
summary_embedding=summary_embedding,
|
||||||
|
)
|
||||||
|
perceptual_nodes.append(perceptual)
|
||||||
|
perceptual_edges.append(
|
||||||
|
PerceptualEdge(
|
||||||
|
source=perceptual.id,
|
||||||
|
target=chunk.id,
|
||||||
|
end_user_id=dialog_data.end_user_id,
|
||||||
|
run_id=dialog_data.run_id,
|
||||||
|
created_at=dialog_data.created_at,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── 陈述句节点 + 边 ──
|
||||||
|
for statement in chunk.statements:
|
||||||
|
statement_node = StatementNode(
|
||||||
|
id=statement.id,
|
||||||
|
name=f"Statement_{statement.id}",
|
||||||
|
chunk_id=chunk.id,
|
||||||
|
stmt_type=getattr(statement, "stmt_type", "general"),
|
||||||
|
temporal_info=getattr(statement, "temporal_info", TemporalInfo.ATEMPORAL),
|
||||||
|
connect_strength=(
|
||||||
|
statement.connect_strength
|
||||||
|
if statement.connect_strength is not None
|
||||||
|
else "Strong"
|
||||||
|
),
|
||||||
|
end_user_id=dialog_data.end_user_id,
|
||||||
|
run_id=dialog_data.run_id,
|
||||||
|
statement=statement.statement,
|
||||||
|
speaker=getattr(statement, "speaker", None),
|
||||||
|
statement_embedding=statement.statement_embedding,
|
||||||
|
valid_at=(
|
||||||
|
statement.temporal_validity.valid_at
|
||||||
|
if hasattr(statement, "temporal_validity") and statement.temporal_validity
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
invalid_at=(
|
||||||
|
statement.temporal_validity.invalid_at
|
||||||
|
if hasattr(statement, "temporal_validity") and statement.temporal_validity
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
created_at=dialog_data.created_at,
|
||||||
|
expired_at=dialog_data.expired_at,
|
||||||
|
config_id=dialog_data.config_id if hasattr(dialog_data, "config_id") else None,
|
||||||
|
emotion_type=getattr(statement, "emotion_type", None),
|
||||||
|
emotion_intensity=getattr(statement, "emotion_intensity", None),
|
||||||
|
emotion_keywords=getattr(statement, "emotion_keywords", None),
|
||||||
|
emotion_subject=getattr(statement, "emotion_subject", None),
|
||||||
|
emotion_target=getattr(statement, "emotion_target", None),
|
||||||
|
)
|
||||||
|
statement_nodes.append(statement_node)
|
||||||
|
|
||||||
|
stmt_chunk_edges.append(
|
||||||
|
StatementChunkEdge(
|
||||||
|
source=statement.id,
|
||||||
|
target=chunk.id,
|
||||||
|
end_user_id=dialog_data.end_user_id,
|
||||||
|
run_id=dialog_data.run_id,
|
||||||
|
created_at=dialog_data.created_at,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── 三元组 → 实体节点 + 边 ──
|
||||||
|
if not statement.triplet_extraction_info:
|
||||||
|
continue
|
||||||
|
|
||||||
|
triplet_info = statement.triplet_extraction_info
|
||||||
|
entity_idx_to_id: Dict[int, str] = {}
|
||||||
|
|
||||||
|
for entity_idx, entity in enumerate(triplet_info.entities):
|
||||||
|
entity_idx_to_id[entity.entity_idx] = entity.id
|
||||||
|
entity_idx_to_id[entity_idx] = entity.id
|
||||||
|
|
||||||
|
if entity.id not in entity_id_set:
|
||||||
|
entity_connect_strength = getattr(entity, "connect_strength", "Strong")
|
||||||
|
entity_node = ExtractedEntityNode(
|
||||||
|
id=entity.id,
|
||||||
|
name=getattr(entity, "name", f"Entity_{entity.id}"),
|
||||||
|
entity_idx=entity.entity_idx,
|
||||||
|
statement_id=statement.id,
|
||||||
|
entity_type=getattr(entity, "type", "unknown"),
|
||||||
|
description=getattr(entity, "description", ""),
|
||||||
|
example=getattr(entity, "example", ""),
|
||||||
|
connect_strength=(
|
||||||
|
entity_connect_strength
|
||||||
|
if entity_connect_strength is not None
|
||||||
|
else "Strong"
|
||||||
|
),
|
||||||
|
aliases=getattr(entity, "aliases", []) or [],
|
||||||
|
name_embedding=getattr(entity, "name_embedding", None),
|
||||||
|
is_explicit_memory=getattr(entity, "is_explicit_memory", False),
|
||||||
|
end_user_id=dialog_data.end_user_id,
|
||||||
|
run_id=dialog_data.run_id,
|
||||||
|
created_at=dialog_data.created_at,
|
||||||
|
expired_at=dialog_data.expired_at,
|
||||||
|
config_id=dialog_data.config_id if hasattr(dialog_data, "config_id") else None,
|
||||||
|
)
|
||||||
|
entity_nodes.append(entity_node)
|
||||||
|
entity_id_set.add(entity.id)
|
||||||
|
|
||||||
|
entity_connect_strength = getattr(entity, "connect_strength", "Strong")
|
||||||
|
stmt_entity_edges.append(
|
||||||
|
StatementEntityEdge(
|
||||||
|
source=statement.id,
|
||||||
|
target=entity.id,
|
||||||
|
connect_strength=(
|
||||||
|
entity_connect_strength
|
||||||
|
if entity_connect_strength is not None
|
||||||
|
else "Strong"
|
||||||
|
),
|
||||||
|
end_user_id=dialog_data.end_user_id,
|
||||||
|
run_id=dialog_data.run_id,
|
||||||
|
created_at=dialog_data.created_at,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# endregion
|
||||||
|
|
||||||
|
for triplet in triplet_info.triplets:
|
||||||
|
subject_entity_id = entity_idx_to_id.get(triplet.subject_id)
|
||||||
|
object_entity_id = entity_idx_to_id.get(triplet.object_id)
|
||||||
|
|
||||||
|
if subject_entity_id and object_entity_id:
|
||||||
|
entity_entity_edges.append(
|
||||||
|
EntityEntityEdge(
|
||||||
|
source=subject_entity_id,
|
||||||
|
target=object_entity_id,
|
||||||
|
relation_type=triplet.predicate,
|
||||||
|
statement=statement.statement,
|
||||||
|
source_statement_id=statement.id,
|
||||||
|
end_user_id=dialog_data.end_user_id,
|
||||||
|
run_id=dialog_data.run_id,
|
||||||
|
created_at=dialog_data.created_at,
|
||||||
|
expired_at=dialog_data.expired_at,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if progress_callback and len(entity_entity_edges) <= 10:
|
||||||
|
relationship_result = {
|
||||||
|
"result_type": "relationship_creation",
|
||||||
|
"relationship_index": len(entity_entity_edges),
|
||||||
|
"source_entity": triplet.subject_name,
|
||||||
|
"relation_type": triplet.predicate,
|
||||||
|
"target_entity": triplet.object_name,
|
||||||
|
"relationship_text": f"{triplet.subject_name} -[{triplet.predicate}]-> {triplet.object_name}",
|
||||||
|
"dialog_progress": f"{processed_dialogs}/{total_dialogs}",
|
||||||
|
}
|
||||||
|
await progress_callback(
|
||||||
|
"creating_nodes_edges_result",
|
||||||
|
f"关系创建中 ({processed_dialogs}/{total_dialogs})",
|
||||||
|
relationship_result,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
missing_subject = "subject" if not subject_entity_id else ""
|
||||||
|
missing_object = "object" if not object_entity_id else ""
|
||||||
|
missing_both = " and " if (not subject_entity_id and not object_entity_id) else ""
|
||||||
|
logger.debug(
|
||||||
|
f"跳过三元组 - 无法找到{missing_subject}{missing_both}{missing_object}实体ID: "
|
||||||
|
f"subject_id={triplet.subject_id} ({triplet.subject_name}), "
|
||||||
|
f"object_id={triplet.object_id} ({triplet.object_name}), "
|
||||||
|
f"predicate={triplet.predicate}, "
|
||||||
|
f"statement_id={statement.id}, "
|
||||||
|
f"available_indices={sorted(entity_idx_to_id.keys())}"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"节点和边创建完成 - 对话节点: {len(dialogue_nodes)}, "
|
||||||
|
f"分块节点: {len(chunk_nodes)}, 陈述句节点: {len(statement_nodes)}, "
|
||||||
|
f"实体节点: {len(entity_nodes)}, 陈述句-分块边: {len(stmt_chunk_edges)}, "
|
||||||
|
f"陈述句-实体边: {len(stmt_entity_edges)}, "
|
||||||
|
f"实体-实体边: {len(entity_entity_edges)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if progress_callback:
|
||||||
|
nodes_edges_stats = {
|
||||||
|
"dialogue_nodes_count": len(dialogue_nodes),
|
||||||
|
"chunk_nodes_count": len(chunk_nodes),
|
||||||
|
"statement_nodes_count": len(statement_nodes),
|
||||||
|
"entity_nodes_count": len(entity_nodes),
|
||||||
|
"statement_chunk_edges_count": len(stmt_chunk_edges),
|
||||||
|
"statement_entity_edges_count": len(stmt_entity_edges),
|
||||||
|
"entity_entity_edges_count": len(entity_entity_edges),
|
||||||
|
}
|
||||||
|
await progress_callback("creating_nodes_edges_complete", "创建节点和边完成", nodes_edges_stats)
|
||||||
|
|
||||||
|
return GraphBuildResult(
|
||||||
|
dialogue_nodes=dialogue_nodes,
|
||||||
|
chunk_nodes=chunk_nodes,
|
||||||
|
statement_nodes=statement_nodes,
|
||||||
|
entity_nodes=entity_nodes,
|
||||||
|
perceptual_nodes=perceptual_nodes,
|
||||||
|
stmt_chunk_edges=stmt_chunk_edges,
|
||||||
|
stmt_entity_edges=stmt_entity_edges,
|
||||||
|
entity_entity_edges=entity_entity_edges,
|
||||||
|
perceptual_edges=perceptual_edges,
|
||||||
|
)
|
||||||
@@ -0,0 +1,42 @@
|
|||||||
|
"""Schema package for ExtractionStep inputs and outputs.
|
||||||
|
|
||||||
|
Re-exports all models for convenient access:
|
||||||
|
from .schema import StatementStepInput, EmotionStepOutput, ...
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .extraction_step_schema import (
|
||||||
|
EmbeddingStepInput,
|
||||||
|
EmbeddingStepOutput,
|
||||||
|
EntityItem,
|
||||||
|
MessageItem,
|
||||||
|
StatementStepInput,
|
||||||
|
StatementStepOutput,
|
||||||
|
SupportingContext,
|
||||||
|
TripletItem,
|
||||||
|
TripletStepInput,
|
||||||
|
TripletStepOutput,
|
||||||
|
)
|
||||||
|
from .sidecar_step_schema import (
|
||||||
|
EmotionStepInput,
|
||||||
|
EmotionStepOutput,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# Shared
|
||||||
|
"MessageItem",
|
||||||
|
"SupportingContext",
|
||||||
|
# Statement
|
||||||
|
"StatementStepInput",
|
||||||
|
"StatementStepOutput",
|
||||||
|
# Triplet
|
||||||
|
"TripletStepInput",
|
||||||
|
"TripletStepOutput",
|
||||||
|
"EntityItem",
|
||||||
|
"TripletItem",
|
||||||
|
# Embedding
|
||||||
|
"EmbeddingStepInput",
|
||||||
|
"EmbeddingStepOutput",
|
||||||
|
# Sidecar — Emotion
|
||||||
|
"EmotionStepInput",
|
||||||
|
"EmotionStepOutput",
|
||||||
|
]
|
||||||
@@ -0,0 +1,115 @@
|
|||||||
|
"""Pydantic models for base extraction pipeline inputs and outputs.
|
||||||
|
|
||||||
|
Covers the core (critical) stages: Statement extraction, Triplet extraction,
|
||||||
|
Embedding generation, and shared types used across stages.
|
||||||
|
|
||||||
|
Malformed LLM JSON will raise ``ValidationError`` and trigger stage-level retry.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, List
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
# ── Shared types ──
|
||||||
|
|
||||||
|
|
||||||
|
class MessageItem(BaseModel):
|
||||||
|
"""Single conversation message."""
|
||||||
|
|
||||||
|
role: str # "User" / "Assistant"
|
||||||
|
msg: str
|
||||||
|
|
||||||
|
|
||||||
|
class SupportingContext(BaseModel):
|
||||||
|
"""Dialogue context window (used for pronoun resolution, etc.)."""
|
||||||
|
|
||||||
|
msgs: List[MessageItem] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Statement extraction ──
|
||||||
|
class StatementStepInput(BaseModel):
|
||||||
|
"""Input for StatementExtractionStep."""
|
||||||
|
|
||||||
|
chunk_id: str
|
||||||
|
end_user_id: str
|
||||||
|
target_content: str
|
||||||
|
target_message_date: str
|
||||||
|
supporting_context: SupportingContext
|
||||||
|
|
||||||
|
|
||||||
|
class StatementStepOutput(BaseModel):
|
||||||
|
"""Single extracted statement (including temporal info)."""
|
||||||
|
|
||||||
|
statement_id: str
|
||||||
|
statement_text: str
|
||||||
|
statement_type: str # FACT / OPINION / PREDICTION / SUGGESTION
|
||||||
|
temporal_type: str # STATIC / DYNAMIC / ATEMPORAL
|
||||||
|
relevance: str # RELEVANT / IRRELEVANT
|
||||||
|
speaker: str # "user" / "assistant"
|
||||||
|
valid_at: str # ISO 8601 or "NULL"
|
||||||
|
invalid_at: str # ISO 8601 or "NULL"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Triplet extraction ──
|
||||||
|
class TripletStepInput(BaseModel):
|
||||||
|
"""Input for TripletExtractionStep."""
|
||||||
|
|
||||||
|
statement_id: str
|
||||||
|
statement_text: str
|
||||||
|
statement_type: str
|
||||||
|
temporal_type: str
|
||||||
|
supporting_context: SupportingContext
|
||||||
|
speaker: str
|
||||||
|
valid_at: str
|
||||||
|
invalid_at: str
|
||||||
|
|
||||||
|
|
||||||
|
class EntityItem(BaseModel):
|
||||||
|
"""Single entity extracted during triplet extraction."""
|
||||||
|
|
||||||
|
entity_idx: int
|
||||||
|
name: str
|
||||||
|
type: str
|
||||||
|
description: str
|
||||||
|
is_explicit_memory: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class TripletItem(BaseModel):
|
||||||
|
"""Single triplet (subject-predicate-object) relationship."""
|
||||||
|
|
||||||
|
subject_name: str
|
||||||
|
subject_id: int
|
||||||
|
predicate: str
|
||||||
|
object_name: str
|
||||||
|
object_id: int
|
||||||
|
|
||||||
|
|
||||||
|
class TripletStepOutput(BaseModel):
|
||||||
|
"""Output of TripletExtractionStep."""
|
||||||
|
|
||||||
|
entities: List[EntityItem] = Field(default_factory=list)
|
||||||
|
triplets: List[TripletItem] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Embedding generation ──
|
||||||
|
class EmbeddingStepInput(BaseModel):
|
||||||
|
"""Input for EmbeddingStep.
|
||||||
|
|
||||||
|
Each dict maps an ID to the text that should be embedded.
|
||||||
|
Fields can be left empty for partial embedding runs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
statement_texts: Dict[str, str] = Field(default_factory=dict)
|
||||||
|
chunk_texts: Dict[str, str] = Field(default_factory=dict)
|
||||||
|
dialog_texts: List[str] = Field(default_factory=list)
|
||||||
|
entity_names: Dict[str, str] = Field(default_factory=dict)
|
||||||
|
entity_descriptions: Dict[str, str] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingStepOutput(BaseModel):
|
||||||
|
"""Output of EmbeddingStep."""
|
||||||
|
|
||||||
|
statement_embeddings: Dict[str, List[float]] = Field(default_factory=dict)
|
||||||
|
chunk_embeddings: Dict[str, List[float]] = Field(default_factory=dict)
|
||||||
|
dialog_embeddings: List[List[float]] = Field(default_factory=list)
|
||||||
|
entity_embeddings: Dict[str, List[float]] = Field(default_factory=dict)
|
||||||
@@ -0,0 +1,26 @@
|
|||||||
|
"""Pydantic models for hot-pluggable sidecar step inputs and outputs.
|
||||||
|
|
||||||
|
Sidecar steps are non-critical (is_critical=False) modules registered via
|
||||||
|
``@SidecarStepFactory.register`` that run concurrently alongside the main
|
||||||
|
extraction pipeline. Failures degrade gracefully to default outputs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
# ── Emotion extraction (sidecar) ──
|
||||||
|
class EmotionStepInput(BaseModel):
|
||||||
|
"""Input for EmotionExtractionStep."""
|
||||||
|
|
||||||
|
statement_id: str
|
||||||
|
statement_text: str
|
||||||
|
speaker: str
|
||||||
|
|
||||||
|
|
||||||
|
class EmotionStepOutput(BaseModel):
|
||||||
|
"""Output of EmotionExtractionStep."""
|
||||||
|
|
||||||
|
emotion_type: str = "neutral"
|
||||||
|
emotion_intensity: float = 0.0
|
||||||
|
emotion_keywords: List[str] = Field(default_factory=list)
|
||||||
@@ -0,0 +1,97 @@
|
|||||||
|
"""SidecarStepFactory — decorator-based registry for sidecar (non-critical) steps.
|
||||||
|
|
||||||
|
New sidecar modules self-register via ``@SidecarStepFactory.register`` and are
|
||||||
|
automatically discovered and instantiated by the orchestrator without any
|
||||||
|
changes to orchestrator code.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Dict, List, Tuple, Type
|
||||||
|
|
||||||
|
from .base import ExtractionStep, StepContext
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SidecarTiming(str, Enum):
|
||||||
|
"""Declares when a sidecar step runs relative to the main pipeline."""
|
||||||
|
|
||||||
|
AFTER_STATEMENT = "after_statement"
|
||||||
|
AFTER_TRIPLET = "after_triplet"
|
||||||
|
|
||||||
|
|
||||||
|
class SidecarStepFactory:
|
||||||
|
"""Factory that manages sidecar step registration and creation.
|
||||||
|
|
||||||
|
Registry maps ``config_key`` → ``(step_class, timing)``.
|
||||||
|
Adding a new sidecar only requires the ``@register`` decorator on the
|
||||||
|
step class — no orchestrator modifications needed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_registry: Dict[str, Tuple[Type[ExtractionStep], SidecarTiming]] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register(cls, config_key: str, timing: SidecarTiming):
|
||||||
|
"""Class decorator that registers a sidecar step.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_key: Configuration flag name (e.g. ``"emotion_enabled"``).
|
||||||
|
The step is instantiated only when this flag is ``True``.
|
||||||
|
timing: When the sidecar runs relative to the main pipeline.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The original class, unmodified.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(step_class: Type[ExtractionStep]):
|
||||||
|
cls._registry[config_key] = (step_class, timing)
|
||||||
|
logger.debug(
|
||||||
|
"Registered sidecar '%s' (config_key=%s, timing=%s)",
|
||||||
|
step_class.__name__,
|
||||||
|
config_key,
|
||||||
|
timing.value,
|
||||||
|
)
|
||||||
|
return step_class
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_sidecars(
|
||||||
|
cls, config: Any, context: StepContext
|
||||||
|
) -> Dict[SidecarTiming, List[ExtractionStep]]:
|
||||||
|
"""Instantiate enabled sidecar steps, grouped by timing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Pipeline configuration object. Each registered
|
||||||
|
``config_key`` is looked up via ``getattr(config, key, False)``.
|
||||||
|
context: Shared :class:`StepContext` injected into every step.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dict keyed by :class:`SidecarTiming`, each value a list of
|
||||||
|
instantiated sidecar steps whose config flag is ``True``.
|
||||||
|
"""
|
||||||
|
result: Dict[SidecarTiming, List[ExtractionStep]] = {
|
||||||
|
timing: [] for timing in SidecarTiming
|
||||||
|
}
|
||||||
|
for config_key, (step_class, timing) in cls._registry.items():
|
||||||
|
if getattr(config, config_key, False):
|
||||||
|
step = step_class(context)
|
||||||
|
result[timing].append(step)
|
||||||
|
logger.debug(
|
||||||
|
"Created sidecar '%s' (timing=%s)",
|
||||||
|
step_class.__name__,
|
||||||
|
timing.value,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.debug(
|
||||||
|
"Skipped sidecar '%s' (config_key=%s is disabled)",
|
||||||
|
step_class.__name__,
|
||||||
|
config_key,
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def clear_registry(cls) -> None:
|
||||||
|
"""Remove all registered sidecars. Useful for testing."""
|
||||||
|
cls._registry.clear()
|
||||||
@@ -0,0 +1,141 @@
|
|||||||
|
"""StatementExtractionStep — critical step for extracting statements from chunks.
|
||||||
|
|
||||||
|
Replaces the legacy ``StatementExtractor`` with the unified ExtractionStep paradigm.
|
||||||
|
Temporal extraction logic (valid_at / invalid_at) is merged into this step,
|
||||||
|
eliminating the need for a separate ``TemporalExtractor`` call.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from typing import Any, List
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
|
from app.core.memory.utils.data.ontology import LABEL_DEFINITIONS
|
||||||
|
from app.core.memory.utils.prompt.prompt_utils import render_statement_extraction_prompt
|
||||||
|
|
||||||
|
from .base import ExtractionStep, StepContext
|
||||||
|
from .schema import StatementStepInput, StatementStepOutput
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ── LLM response schemas (internal) ──
|
||||||
|
|
||||||
|
|
||||||
|
class _ExtractedStatement(BaseModel):
|
||||||
|
"""Raw statement returned by the LLM (before enrichment)."""
|
||||||
|
|
||||||
|
statement: str = Field(..., description="The extracted statement text")
|
||||||
|
statement_type: str = Field(..., description="FACT / OPINION / SUGGESTION / PREDICTION")
|
||||||
|
temporal_type: str = Field(..., description="STATIC / DYNAMIC / ATEMPORAL")
|
||||||
|
relevance: str = Field("RELEVANT", description="RELEVANT / IRRELEVANT")
|
||||||
|
valid_at: str = Field("NULL", description="ISO 8601 or NULL")
|
||||||
|
invalid_at: str = Field("NULL", description="ISO 8601 or NULL")
|
||||||
|
|
||||||
|
|
||||||
|
class _StatementExtractionResponse(BaseModel):
|
||||||
|
"""Structured LLM response containing a list of extracted statements."""
|
||||||
|
|
||||||
|
statements: List[_ExtractedStatement] = Field(default_factory=list)
|
||||||
|
|
||||||
|
@field_validator("statements", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def filter_empty(cls, v: Any) -> Any:
|
||||||
|
"""Drop empty / malformed dicts that the LLM occasionally produces."""
|
||||||
|
if isinstance(v, list):
|
||||||
|
return [s for s in v if isinstance(s, dict) and s.get("statement")]
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
class StatementExtractionStep(ExtractionStep[StatementStepInput, List[StatementStepOutput]]):
|
||||||
|
"""Extract atomic statements (with temporal info) from a dialogue chunk.
|
||||||
|
|
||||||
|
This is a **critical** step — failure aborts the pipeline after retries.
|
||||||
|
|
||||||
|
Config params bound at init (from ``StepContext.config.statement_extraction``):
|
||||||
|
* ``definitions`` — label definitions for statement classification
|
||||||
|
* ``json_schema`` — JSON schema for the expected LLM output
|
||||||
|
* ``granularity`` — extraction granularity level (1-3)
|
||||||
|
* ``include_dialogue_context`` — whether to include full dialogue context
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, context: StepContext) -> None:
|
||||||
|
super().__init__(context)
|
||||||
|
stmt_cfg = getattr(self.config, "statement_extraction", None)
|
||||||
|
self.definitions = LABEL_DEFINITIONS
|
||||||
|
self.json_schema = _ExtractedStatement.model_json_schema()
|
||||||
|
self.granularity = getattr(stmt_cfg, "statement_granularity", None)
|
||||||
|
self.include_dialogue_context = getattr(stmt_cfg, "include_dialogue_context", True)
|
||||||
|
self.max_dialogue_context_chars = getattr(stmt_cfg, "max_dialogue_context_chars", 2000)
|
||||||
|
|
||||||
|
# ── Identity ──
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "statement_extraction"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_critical(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# ── Lifecycle ──
|
||||||
|
|
||||||
|
async def render_prompt(self, input_data: StatementStepInput) -> str:
|
||||||
|
# Build optional dialogue context from supporting_context messages
|
||||||
|
dialogue_content = None
|
||||||
|
if self.include_dialogue_context and input_data.supporting_context.msgs:
|
||||||
|
dialogue_content = "\n".join(
|
||||||
|
f"{m.role}: {m.msg}" for m in input_data.supporting_context.msgs
|
||||||
|
)
|
||||||
|
|
||||||
|
return await render_statement_extraction_prompt(
|
||||||
|
chunk_content=input_data.target_content,
|
||||||
|
definitions=self.definitions,
|
||||||
|
json_schema=self.json_schema,
|
||||||
|
granularity=self.granularity,
|
||||||
|
include_dialogue_context=self.include_dialogue_context,
|
||||||
|
dialogue_content=dialogue_content,
|
||||||
|
max_dialogue_chars=self.max_dialogue_context_chars,
|
||||||
|
language=self.language,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def call_llm(self, prompt: Any) -> Any:
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": (
|
||||||
|
"You are an expert at extracting and labeling atomic statements "
|
||||||
|
"from conversational text. Return valid JSON conforming to the schema."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{"role": "user", "content": prompt},
|
||||||
|
]
|
||||||
|
return await self.llm_client.response_structured(
|
||||||
|
messages, _StatementExtractionResponse
|
||||||
|
)
|
||||||
|
|
||||||
|
async def parse_response(
|
||||||
|
self, raw_response: Any, input_data: StatementStepInput
|
||||||
|
) -> List[StatementStepOutput]:
|
||||||
|
if not hasattr(raw_response, "statements") or raw_response.statements is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
results: List[StatementStepOutput] = []
|
||||||
|
for stmt in raw_response.statements:
|
||||||
|
results.append(
|
||||||
|
StatementStepOutput(
|
||||||
|
statement_id=uuid.uuid4().hex,
|
||||||
|
statement_text=stmt.statement,
|
||||||
|
statement_type=stmt.statement_type.strip().upper(),
|
||||||
|
temporal_type=stmt.temporal_type.strip().upper(),
|
||||||
|
relevance=stmt.relevance.strip().upper(),
|
||||||
|
speaker="user", # default; orchestrator overrides from chunk metadata
|
||||||
|
valid_at=stmt.valid_at or "NULL",
|
||||||
|
invalid_at=stmt.invalid_at or "NULL",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return results
|
||||||
|
|
||||||
|
def get_default_output(self) -> List[StatementStepOutput]:
|
||||||
|
return []
|
||||||
@@ -0,0 +1,118 @@
|
|||||||
|
"""TripletExtractionStep — critical step for extracting entities and triplets.
|
||||||
|
|
||||||
|
Replaces the legacy ``TripletExtractor`` with the unified ExtractionStep paradigm.
|
||||||
|
Predicate filtering against the ontology whitelist is performed in ``parse_response``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.core.memory.models.triplet_models import TripletExtractionResponse
|
||||||
|
from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate
|
||||||
|
from app.core.memory.utils.prompt.prompt_utils import render_triplet_extraction_prompt
|
||||||
|
|
||||||
|
from .base import ExtractionStep, StepContext
|
||||||
|
from .schema import EntityItem, TripletItem, TripletStepInput, TripletStepOutput
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TripletExtractionStep(ExtractionStep[TripletStepInput, TripletStepOutput]):
|
||||||
|
"""Extract knowledge triplets and entities from a single statement.
|
||||||
|
|
||||||
|
This is a **critical** step — failure aborts the pipeline after retries.
|
||||||
|
|
||||||
|
Config params bound at init (from ``StepContext.config``):
|
||||||
|
* ``ontology_types`` — predefined ontology types for entity classification
|
||||||
|
* ``predicate_instructions`` — predicate definition guidance for the LLM
|
||||||
|
* ``json_schema`` — JSON schema for the expected LLM output
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
context: StepContext,
|
||||||
|
ontology_types: Any = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(context)
|
||||||
|
self.ontology_types = ontology_types
|
||||||
|
self.predicate_instructions = PREDICATE_DEFINITIONS
|
||||||
|
self.json_schema = TripletExtractionResponse.model_json_schema()
|
||||||
|
self._allowed_predicates = {p.value for p in Predicate}
|
||||||
|
|
||||||
|
# ── Identity ──
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "triplet_extraction"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_critical(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# ── Lifecycle ──
|
||||||
|
|
||||||
|
async def render_prompt(self, input_data: TripletStepInput) -> str:
|
||||||
|
# Build chunk_content from supporting_context for pronoun resolution
|
||||||
|
chunk_content = "\n".join(
|
||||||
|
f"{m.role}: {m.msg}" for m in input_data.supporting_context.msgs
|
||||||
|
) if input_data.supporting_context.msgs else ""
|
||||||
|
|
||||||
|
return await render_triplet_extraction_prompt(
|
||||||
|
statement=input_data.statement_text,
|
||||||
|
chunk_content=chunk_content,
|
||||||
|
json_schema=self.json_schema,
|
||||||
|
predicate_instructions=self.predicate_instructions,
|
||||||
|
language=self.language,
|
||||||
|
ontology_types=self.ontology_types,
|
||||||
|
speaker=input_data.speaker,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def call_llm(self, prompt: Any) -> Any:
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": (
|
||||||
|
"You are an expert at extracting knowledge triplets and entities "
|
||||||
|
"from text. Follow the provided instructions carefully and return valid JSON."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{"role": "user", "content": prompt},
|
||||||
|
]
|
||||||
|
return await self.llm_client.response_structured(
|
||||||
|
messages, TripletExtractionResponse
|
||||||
|
)
|
||||||
|
|
||||||
|
async def parse_response(
|
||||||
|
self, raw_response: Any, input_data: TripletStepInput
|
||||||
|
) -> TripletStepOutput:
|
||||||
|
if not hasattr(raw_response, "triplets"):
|
||||||
|
return self.get_default_output()
|
||||||
|
|
||||||
|
# Filter triplets to allowed predicates from ontology whitelist
|
||||||
|
filtered_triplets = [
|
||||||
|
TripletItem(
|
||||||
|
subject_name=t.subject_name,
|
||||||
|
subject_id=t.subject_id,
|
||||||
|
predicate=t.predicate,
|
||||||
|
object_name=t.object_name,
|
||||||
|
object_id=t.object_id,
|
||||||
|
)
|
||||||
|
for t in raw_response.triplets
|
||||||
|
if getattr(t, "predicate", "") in self._allowed_predicates
|
||||||
|
]
|
||||||
|
|
||||||
|
entities = [
|
||||||
|
EntityItem(
|
||||||
|
entity_idx=e.entity_idx,
|
||||||
|
name=e.name,
|
||||||
|
type=e.type,
|
||||||
|
description=e.description,
|
||||||
|
is_explicit_memory=getattr(e, "is_explicit_memory", False),
|
||||||
|
)
|
||||||
|
for e in (raw_response.entities or [])
|
||||||
|
]
|
||||||
|
|
||||||
|
return TripletStepOutput(entities=entities, triplets=filtered_triplets)
|
||||||
|
|
||||||
|
def get_default_output(self) -> TripletStepOutput:
|
||||||
|
return TripletStepOutput(entities=[], triplets=[])
|
||||||
@@ -421,6 +421,9 @@ class MemoryConfig:
|
|||||||
pruning_scene: Optional[str] = "education"
|
pruning_scene: Optional[str] = "education"
|
||||||
pruning_threshold: float = 0.5
|
pruning_threshold: float = 0.5
|
||||||
|
|
||||||
|
# Pipeline config: Emotion extraction
|
||||||
|
emotion_enabled: bool = False
|
||||||
|
|
||||||
# Ontology scene association
|
# Ontology scene association
|
||||||
scene_id: Optional[UUID] = None
|
scene_id: Optional[UUID] = None
|
||||||
ontology_class_infos: list[dict] = field(default_factory=list)
|
ontology_class_infos: list[dict] = field(default_factory=list)
|
||||||
|
|||||||
@@ -360,40 +360,64 @@ class MemoryAgentService:
|
|||||||
await write_rag(end_user_id, message_text, user_rag_memory_id)
|
await write_rag(end_user_id, message_text, user_rag_memory_id)
|
||||||
return "success"
|
return "success"
|
||||||
else:
|
else:
|
||||||
await write_neo4j(
|
# TODO 乐力齐 重构流水线切换至生产环境后,更改如下代码
|
||||||
end_user_id=end_user_id,
|
|
||||||
messages=messages,
|
|
||||||
memory_config=memory_config,
|
|
||||||
ref_id='',
|
|
||||||
language=language
|
|
||||||
)
|
|
||||||
|
|
||||||
# ── 影子运行:新流水线静默执行,只记录日志不影响主流程 ──
|
|
||||||
import os
|
import os
|
||||||
if os.getenv("SHADOW_PIPELINE_ENABLED", "false").lower() == "true":
|
use_new_pipeline = os.getenv("NEW_PIPELINE_ENABLED", "false").lower() == "true"
|
||||||
try:
|
|
||||||
from app.core.memory.memory_service import MemoryService
|
|
||||||
import copy
|
|
||||||
|
|
||||||
shadow_messages = copy.deepcopy(messages)
|
if use_new_pipeline:
|
||||||
shadow_service = MemoryService(
|
# ── 新流水线:WritePipeline + NewExtractionOrchestrator ──
|
||||||
memory_config=memory_config,
|
from app.core.memory.memory_service import MemoryService
|
||||||
end_user_id=end_user_id,
|
|
||||||
)
|
service = MemoryService(
|
||||||
shadow_result = await shadow_service.write(
|
memory_config=memory_config,
|
||||||
messages=shadow_messages,
|
end_user_id=end_user_id,
|
||||||
language=language,
|
)
|
||||||
ref_id='',
|
result = await service.write(
|
||||||
is_pilot_run=True, # 试运行模式:只萃取不写入,避免重复写入 Neo4j
|
messages=messages,
|
||||||
)
|
language=language,
|
||||||
logger.info(
|
ref_id='',
|
||||||
f"[Shadow] 新流水线影子运行完成: status={shadow_result.status}, "
|
is_pilot_run=False,
|
||||||
f"elapsed={shadow_result.elapsed_seconds:.2f}s, "
|
)
|
||||||
f"extraction={shadow_result.extraction}"
|
logger.info(
|
||||||
)
|
f"[NewPipeline] 完成: status={result.status}, "
|
||||||
except Exception as shadow_err:
|
f"elapsed={result.elapsed_seconds:.2f}s, "
|
||||||
logger.warning(f"[Shadow] 新流水线影子运行失败(不影响主流程): {shadow_err}")
|
f"extraction={result.extraction}"
|
||||||
# ── 影子运行结束 ──
|
)
|
||||||
|
else:
|
||||||
|
# ── 旧流水线:write_tools.write() + ExtractionOrchestrator ──
|
||||||
|
await write_neo4j(
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
messages=messages,
|
||||||
|
memory_config=memory_config,
|
||||||
|
ref_id='',
|
||||||
|
language=language
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── 影子运行:新流水线静默执行,只记录日志不影响主流程 ──
|
||||||
|
if os.getenv("SHADOW_PIPELINE_ENABLED", "false").lower() == "true":
|
||||||
|
try:
|
||||||
|
from app.core.memory.memory_service import MemoryService
|
||||||
|
import copy
|
||||||
|
|
||||||
|
shadow_messages = copy.deepcopy(messages)
|
||||||
|
shadow_service = MemoryService(
|
||||||
|
memory_config=memory_config,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
)
|
||||||
|
shadow_result = await shadow_service.write(
|
||||||
|
messages=shadow_messages,
|
||||||
|
language=language,
|
||||||
|
ref_id='',
|
||||||
|
is_pilot_run=True,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"[Shadow] 新流水线影子运行完成: status={shadow_result.status}, "
|
||||||
|
f"elapsed={shadow_result.elapsed_seconds:.2f}s, "
|
||||||
|
f"extraction={shadow_result.extraction}"
|
||||||
|
)
|
||||||
|
except Exception as shadow_err:
|
||||||
|
logger.warning(f"[Shadow] 新流水线影子运行失败(不影响主流程): {shadow_err}")
|
||||||
|
# ── 影子运行结束 ──
|
||||||
for lang in ["zh", "en"]:
|
for lang in ["zh", "en"]:
|
||||||
deleted = await InterestMemoryCache.delete_interest_distribution(
|
deleted = await InterestMemoryCache.delete_interest_distribution(
|
||||||
end_user_id, lang
|
end_user_id, lang
|
||||||
|
|||||||
@@ -418,6 +418,9 @@ class MemoryConfigService:
|
|||||||
pruning_scene=memory_config.pruning_scene or "education",
|
pruning_scene=memory_config.pruning_scene or "education",
|
||||||
pruning_threshold=float(
|
pruning_threshold=float(
|
||||||
memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
|
memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
|
||||||
|
# Pipeline config: Emotion extraction
|
||||||
|
emotion_enabled=bool(
|
||||||
|
memory_config.emotion_enabled) if memory_config.emotion_enabled is not None else False,
|
||||||
# Ontology scene association
|
# Ontology scene association
|
||||||
scene_id=memory_config.scene_id,
|
scene_id=memory_config.scene_id,
|
||||||
ontology_class_infos=_load_ontology_class_infos(self.db, memory_config.scene_id),
|
ontology_class_infos=_load_ontology_class_infos(self.db, memory_config.scene_id),
|
||||||
@@ -573,6 +576,7 @@ class MemoryConfigService:
|
|||||||
statement_extraction=stmt_config,
|
statement_extraction=stmt_config,
|
||||||
deduplication=dedup_config,
|
deduplication=dedup_config,
|
||||||
forgetting_engine=forget_config,
|
forgetting_engine=forget_config,
|
||||||
|
emotion_enabled=getattr(memory_config, "emotion_enabled", False),
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
Reference in New Issue
Block a user