feat(memory): implement step-based extraction pipeline architecture
Introduce ExtractionStep abstraction with modular pipeline stages:
- Add base ExtractionStep class with render/call/parse lifecycle
- Implement StatementExtractionStep, TripletExtractionStep, EmbeddingStep, EmotionStep, GraphBuildStep, and DedupStep
- Add SidecarStepFactory for hot-pluggable non-critical steps
- Define Pydantic I/O schemas for all pipeline stages
- Refactor WritePipeline to orchestrate new step-based flow
- Add NEW_PIPELINE_ENABLED env switch for old/new pipeline routing
- Add emotion_enabled config flag to MemoryConfig
- Fix workspace_id reference in get_end_user_connected_config
This commit is contained in:
@@ -1,7 +1,4 @@
|
||||
import os
|
||||
import json
|
||||
from typing import List
|
||||
from datetime import datetime
|
||||
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import DialogueChunker
|
||||
from app.core.memory.models.message_models import DialogData, ConversationContext, ConversationMessage
|
||||
@@ -34,6 +31,7 @@ async def get_chunked_dialogs(
|
||||
|
||||
conversation_messages = []
|
||||
|
||||
# step1: 消息格式校验 role:user、assistant。content
|
||||
for idx, msg in enumerate(messages):
|
||||
if not isinstance(msg, dict) or 'role' not in msg or 'content' not in msg:
|
||||
raise ValueError(f"Message {idx} format error: must contain 'role' and 'content' fields")
|
||||
@@ -59,7 +57,7 @@ async def get_chunked_dialogs(
|
||||
config_id=config_id
|
||||
)
|
||||
|
||||
# 语义剪枝步骤(在分块之前)
|
||||
# step2: 语义剪枝步骤(在分块之前)
|
||||
try:
|
||||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import SemanticPruner
|
||||
from app.core.memory.models.config_models import PruningConfig
|
||||
@@ -116,6 +114,7 @@ async def get_chunked_dialogs(
|
||||
except Exception as e:
|
||||
logger.warning(f"[剪枝] 执行失败,跳过剪枝: {e}", exc_info=True)
|
||||
|
||||
# step3: 分块
|
||||
chunker = DialogueChunker(chunker_strategy)
|
||||
extracted_chunks = await chunker.process_dialogue(dialog_data)
|
||||
dialog_data.chunks = extracted_chunks
|
||||
|
||||
@@ -147,7 +147,85 @@ async def write(
|
||||
all_perceptual_edges,
|
||||
all_dedup_details,
|
||||
) = await orchestrator.run(chunked_dialogs, is_pilot_run=False)
|
||||
|
||||
# region TODO 乐力齐 重构流水线切换至生产环境稳定后,移除快照对比代码
|
||||
# ── Snapshot: 旧流水线萃取结果(按 phase2_step_io_schema_v1.md 格式) ──
|
||||
from app.core.memory.utils.debug.pipeline_snapshot import PipelineSnapshot
|
||||
snapshot = PipelineSnapshot("legacy")
|
||||
|
||||
# Statement 输出(从 dialog_data_list 中提取)
|
||||
stmt_snapshot = []
|
||||
for d in all_dedup_details:
|
||||
if not hasattr(d, "chunks"):
|
||||
continue
|
||||
for c in d.chunks:
|
||||
for s in c.statements:
|
||||
stmt_snapshot.append({
|
||||
"statement_id": s.id,
|
||||
"statement_text": s.statement,
|
||||
"statement_type": str(getattr(s, "stmt_type", "")),
|
||||
"temporal_type": str(getattr(s, "temporal_info", "")),
|
||||
"relevance": str(getattr(s, "relevence_info", "RELEVANT")),
|
||||
"speaker": getattr(s, "speaker", "user") or "user",
|
||||
"valid_at": s.temporal_validity.valid_at if s.temporal_validity else "NULL",
|
||||
"invalid_at": s.temporal_validity.invalid_at if s.temporal_validity else "NULL",
|
||||
})
|
||||
snapshot.save_stage("2_statement_outputs", stmt_snapshot)
|
||||
|
||||
# Triplet 输出(从 dialog_data_list 中提取)
|
||||
triplet_snapshot = {}
|
||||
for d in all_dedup_details:
|
||||
if not hasattr(d, "chunks"):
|
||||
continue
|
||||
for c in d.chunks:
|
||||
for s in c.statements:
|
||||
if s.triplet_extraction_info:
|
||||
triplet_snapshot[s.id] = {
|
||||
"entities": [
|
||||
{
|
||||
"entity_idx": e.entity_idx, "name": e.name,
|
||||
"type": e.type, "description": e.description,
|
||||
"is_explicit_memory": getattr(e, "is_explicit_memory", False),
|
||||
}
|
||||
for e in s.triplet_extraction_info.entities
|
||||
],
|
||||
"triplets": [
|
||||
{
|
||||
"subject_name": t.subject_name, "subject_id": t.subject_id,
|
||||
"predicate": t.predicate,
|
||||
"object_name": t.object_name, "object_id": t.object_id,
|
||||
}
|
||||
for t in s.triplet_extraction_info.triplets
|
||||
],
|
||||
}
|
||||
snapshot.save_stage("3_triplet_outputs", triplet_snapshot)
|
||||
|
||||
# 图节点和边(去重后)
|
||||
snapshot.save_stage("6_nodes_edges_after_dedup", {
|
||||
"dialogue_nodes_count": len(all_dialogue_nodes),
|
||||
"chunk_nodes_count": len(all_chunk_nodes),
|
||||
"statement_nodes_count": len(all_statement_nodes),
|
||||
"entity_nodes": [
|
||||
{"id": e.id, "name": e.name, "entity_type": e.entity_type, "description": e.description}
|
||||
for e in all_entity_nodes
|
||||
],
|
||||
"entity_entity_edges": [
|
||||
{
|
||||
"source": e.source, "target": e.target,
|
||||
"relation_type": e.relation_type, "statement": e.statement,
|
||||
}
|
||||
for e in all_entity_entity_edges
|
||||
],
|
||||
})
|
||||
snapshot.save_summary({
|
||||
"dialogue_count": len(all_dialogue_nodes),
|
||||
"chunk_count": len(all_chunk_nodes),
|
||||
"statement_count": len(all_statement_nodes),
|
||||
"entity_count": len(all_entity_nodes),
|
||||
"relation_count": len(all_entity_entity_edges),
|
||||
})
|
||||
# endregion
|
||||
|
||||
log_time("Extraction Pipeline", time.time() - step_start, log_file)
|
||||
|
||||
# Step 3: Save all data to Neo4j database
|
||||
|
||||
@@ -149,3 +149,5 @@ class ExtractionPipelineConfig(BaseModel):
|
||||
temporal_extraction: TemporalExtractionConfig = Field(default_factory=TemporalExtractionConfig)
|
||||
deduplication: DedupConfig = Field(default_factory=DedupConfig)
|
||||
forgetting_engine: ForgettingEngineConfig = Field(default_factory=ForgettingEngineConfig)
|
||||
# 情绪引擎(旁路模块,SidecarStepFactory 通过此字段判断是否启用)
|
||||
emotion_enabled: bool = Field(default=False, description="是否启用情绪提取旁路")
|
||||
|
||||
@@ -12,20 +12,33 @@ WritePipeline — 记忆写入流水线
|
||||
|
||||
依赖方向:Facade → Pipeline → Engine → Repository(单向,不允许反向调用)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.core.memory.models.graph_models import ExtractedEntityNode
|
||||
from app.core.memory.models.message_models import DialogData
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
from app.core.memory.models.graph_models import (
|
||||
ChunkNode,
|
||||
DialogueNode,
|
||||
EntityEntityEdge,
|
||||
ExtractedEntityNode,
|
||||
PerceptualEdge,
|
||||
PerceptualNode,
|
||||
StatementChunkEdge,
|
||||
StatementEntityEdge,
|
||||
StatementNode,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -34,36 +47,40 @@ logger = logging.getLogger(__name__)
|
||||
# ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractionResult:
|
||||
"""萃取步骤的结构化输出,替代 ExtractionOrchestrator.run() 返回的裸元组。
|
||||
class ExtractionResult(BaseModel):
|
||||
"""萃取 + 图构建 + 去重消歧后的结构化输出。
|
||||
|
||||
字段与 ExtractionOrchestrator.run() 的 10 元素返回值一一对应:
|
||||
[0] dialogue_nodes → self.dialogue_nodes
|
||||
[1] chunk_nodes → self.chunk_nodes
|
||||
[2] statement_nodes → self.statement_nodes
|
||||
[3] entity_nodes → self.entity_nodes
|
||||
[4] perceptual_nodes → self.perceptual_nodes
|
||||
[5] stmt_chunk_edges → self.stmt_chunk_edges
|
||||
[6] stmt_entity_edges → self.stmt_entity_edges
|
||||
[7] entity_entity_edges → self.entity_entity_edges
|
||||
[8] perceptual_edges → self.perceptual_edges
|
||||
[9] dialog_data_list → self.dialog_data_list
|
||||
作为 Pipeline 层的阶段间数据载体,确保下游步骤(_store、_cluster)
|
||||
接收到的图节点和边结构完整、类型正确。
|
||||
|
||||
注意:字段类型使用 List[Any] 而非具体的 graph_models 类型,
|
||||
避免在模块加载时触发循环依赖。Pipeline 只做数据传递,不检查具体类型。
|
||||
字段对应 ExtractionOrchestrator 产出的图节点/边:
|
||||
dialogue_nodes — 对话节点
|
||||
chunk_nodes — 分块节点
|
||||
statement_nodes — 陈述句节点
|
||||
entity_nodes — 实体节点(去重消歧后)
|
||||
perceptual_nodes — 感知节点
|
||||
stmt_chunk_edges — 陈述句 → 分块 边
|
||||
stmt_entity_edges — 陈述句 → 实体 边
|
||||
entity_entity_edges — 实体 → 实体 边(去重消歧后)
|
||||
perceptual_edges — 感知 → 分块 边
|
||||
dialog_data_list — 原始 DialogData(供摘要阶段使用)
|
||||
"""
|
||||
|
||||
dialogue_nodes: List[Any]
|
||||
chunk_nodes: List[Any]
|
||||
statement_nodes: List[Any]
|
||||
entity_nodes: List[Any]
|
||||
perceptual_nodes: List[Any]
|
||||
stmt_chunk_edges: List[Any]
|
||||
stmt_entity_edges: List[Any]
|
||||
entity_entity_edges: List[Any]
|
||||
perceptual_edges: List[Any]
|
||||
dialog_data_list: List[Any]
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
dialogue_nodes: List[DialogueNode]
|
||||
chunk_nodes: List[ChunkNode]
|
||||
statement_nodes: List[StatementNode]
|
||||
entity_nodes: List[ExtractedEntityNode]
|
||||
perceptual_nodes: List[PerceptualNode]
|
||||
stmt_chunk_edges: List[StatementChunkEdge]
|
||||
stmt_entity_edges: List[StatementEntityEdge]
|
||||
entity_entity_edges: List[EntityEntityEdge]
|
||||
perceptual_edges: List[PerceptualEdge]
|
||||
dialog_data_list: List[Any] = Field(
|
||||
default_factory=list,
|
||||
description="原始 DialogData 列表,类型为 Any 以避免循环依赖",
|
||||
)
|
||||
|
||||
@property
|
||||
def stats(self) -> Dict[str, int]:
|
||||
@@ -78,8 +95,7 @@ class ExtractionResult:
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class WriteResult:
|
||||
class WriteResult(BaseModel):
|
||||
"""写入流水线的最终输出,返回给 MemoryService / MemoryAgentService"""
|
||||
|
||||
status: str # "success" | "pilot_complete" | "failed"
|
||||
@@ -114,7 +130,7 @@ class WritePipeline:
|
||||
memory_config: 不可变的记忆配置对象(从数据库加载)
|
||||
end_user_id: 终端用户 ID
|
||||
language: 语言 ("zh" | "en")
|
||||
progress_callback: 可选的进度回调,签名 (stage, message, data?) -> Awaitable[None]
|
||||
progress_callback: 可选的进度回调,签名 (stage, message, data?) -> Awaitable[None] 供pilot run使用
|
||||
"""
|
||||
self.memory_config = memory_config
|
||||
self.end_user_id = end_user_id
|
||||
@@ -145,7 +161,7 @@ class WritePipeline:
|
||||
is_pilot_run: 试运行模式(只萃取不写入)
|
||||
|
||||
Returns:
|
||||
WriteResult 包含状态和统计信息
|
||||
WriteResult 包含状态和统计信息
|
||||
"""
|
||||
if not ref_id:
|
||||
ref_id = uuid.uuid4().hex
|
||||
@@ -164,7 +180,7 @@ class WritePipeline:
|
||||
self._init_clients()
|
||||
self._init_neo4j_connector()
|
||||
|
||||
# Step 1: 预处理 - 消息分块
|
||||
# Step 1: 预处理 - 消息分块 + AI消息语义剪枝(暂无实现)
|
||||
step_start = time.time()
|
||||
chunked_dialogs = await self._preprocess(messages, ref_id)
|
||||
chunks_count = sum(len(d.chunks) for d in chunked_dialogs)
|
||||
@@ -175,9 +191,7 @@ class WritePipeline:
|
||||
|
||||
# Step 2: 萃取 - 知识提取
|
||||
step_start = time.time()
|
||||
extraction_result = await self._extract(
|
||||
chunked_dialogs, is_pilot_run
|
||||
)
|
||||
extraction_result = await self._extract(chunked_dialogs, is_pilot_run)
|
||||
stats = extraction_result.stats
|
||||
logger.info(
|
||||
f"[WritePipeline] [2/5] 萃取:知识提取 "
|
||||
@@ -190,9 +204,7 @@ class WritePipeline:
|
||||
# 试运行模式到此结束
|
||||
if is_pilot_run:
|
||||
elapsed = time.time() - pipeline_start
|
||||
logger.info(
|
||||
f"[WritePipeline] 完成(试运行) ✔ {elapsed:.2f}s"
|
||||
)
|
||||
logger.info(f"[WritePipeline] 完成(试运行) ✔ {elapsed:.2f}s")
|
||||
return WriteResult(
|
||||
status="pilot_complete",
|
||||
extraction=extraction_result.stats,
|
||||
@@ -227,9 +239,7 @@ class WritePipeline:
|
||||
await self._update_stats_cache(extraction_result)
|
||||
|
||||
elapsed = time.time() - pipeline_start
|
||||
logger.info(
|
||||
f"[WritePipeline] 完成 ✔ {elapsed:.2f}s"
|
||||
)
|
||||
logger.info(f"[WritePipeline] 完成 ✔ {elapsed:.2f}s")
|
||||
return WriteResult(
|
||||
status="success",
|
||||
extraction=extraction_result.stats,
|
||||
@@ -251,16 +261,14 @@ class WritePipeline:
|
||||
# Step 1: 预处理
|
||||
# ──────────────────────────────────────────────
|
||||
|
||||
async def _preprocess(
|
||||
self, messages: List[dict], ref_id: str
|
||||
) -> List[DialogData]:
|
||||
async def _preprocess(self, messages: List[dict], ref_id: str) -> List[DialogData]:
|
||||
"""
|
||||
预处理:消息校验 → 语义剪枝 → 对话分块。
|
||||
预处理:消息校验 → AI消息语义剪枝(暂未实现) → 对话分块。
|
||||
|
||||
委托给 get_chunked_dialogs(),保持现有预处理逻辑不变。
|
||||
get_dialogs.py 内部已包含:
|
||||
- 消息格式校验(role/content 必填)
|
||||
- 语义剪枝(根据 config 中 pruning_enabled 决定)
|
||||
- AI消息语义剪枝(根据 config 中 pruning_enabled 决定)
|
||||
- DialogueChunker 分块
|
||||
"""
|
||||
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
|
||||
@@ -283,56 +291,187 @@ class WritePipeline:
|
||||
is_pilot_run: bool,
|
||||
) -> ExtractionResult:
|
||||
"""
|
||||
萃取:初始化引擎 → 执行知识提取 → 返回结构化结果。
|
||||
萃取:初始化引擎 → 执行知识提取 → 构建图节点/边 → 去重 → 返回结构化结果。
|
||||
|
||||
ExtractionOrchestrator 作为萃取引擎被调用,
|
||||
Pipeline 不关心引擎内部的并行策略和提取细节。
|
||||
使用 NewExtractionOrchestrator(ExtractionStep 范式)完成 LLM 萃取,
|
||||
然后通过独立的 graph_build_step 和 dedup_step 完成图构建和去重,
|
||||
不依赖旧编排器 ExtractionOrchestrator。
|
||||
|
||||
执行流程:
|
||||
1. NewExtractionOrchestrator.run() → 萃取并赋值到 DialogData
|
||||
2. build_graph_nodes_and_edges() → 从 DialogData 构建图节点和边
|
||||
3. run_dedup() → 两阶段去重消歧
|
||||
"""
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
|
||||
ExtractionOrchestrator,
|
||||
from app.core.memory.storage_services.extraction_engine.dedup_step import (
|
||||
run_dedup,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.steps.graph_build_step import (
|
||||
build_graph_nodes_and_edges,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.steps.extraction_pipeline_orchestrator import (
|
||||
NewExtractionOrchestrator,
|
||||
)
|
||||
|
||||
from app.core.memory.utils.config.config_utils import get_pipeline_config
|
||||
from app.core.memory.utils.debug.pipeline_snapshot import PipelineSnapshot
|
||||
|
||||
pipeline_config = get_pipeline_config(self.memory_config)
|
||||
ontology_types = self._load_ontology_types()
|
||||
|
||||
orchestrator = ExtractionOrchestrator(
|
||||
snapshot = PipelineSnapshot("new")
|
||||
|
||||
# ── 新编排器:LLM 萃取 + 数据赋值 ──
|
||||
new_orchestrator = NewExtractionOrchestrator(
|
||||
llm_client=self._llm_client,
|
||||
embedder_client=self._embedder_client,
|
||||
connector=self._neo4j_connector,
|
||||
config=pipeline_config,
|
||||
embedding_id=str(self.memory_config.embedding_model_id),
|
||||
language=self.language,
|
||||
ontology_types=ontology_types,
|
||||
language=self.language,
|
||||
is_pilot_run=is_pilot_run,
|
||||
progress_callback=self.progress_callback,
|
||||
)
|
||||
# step1: 执行知识提取
|
||||
dialog_data_list = await new_orchestrator.run(chunked_dialogs)
|
||||
|
||||
# ── Snapshot: 各阶段萃取结果 ── TODO 乐力齐 重构流水线切换生产环境稳定后修改
|
||||
stage_outputs = new_orchestrator.last_stage_outputs
|
||||
if stage_outputs:
|
||||
stmt_results = stage_outputs.get("statement_results", {})
|
||||
stmt_snapshot = []
|
||||
for _did, chunk_stmts in stmt_results.items():
|
||||
for _cid, stmts in chunk_stmts.items():
|
||||
for s in stmts:
|
||||
stmt_snapshot.append(s.model_dump())
|
||||
snapshot.save_stage("2_statement_outputs", stmt_snapshot)
|
||||
|
||||
triplet_results = stage_outputs.get("triplet_results", {})
|
||||
triplet_snapshot = {}
|
||||
for _did, stmt_triplets in triplet_results.items():
|
||||
for stmt_id, t_out in stmt_triplets.items():
|
||||
triplet_snapshot[stmt_id] = t_out.model_dump()
|
||||
snapshot.save_stage("3_triplet_outputs", triplet_snapshot)
|
||||
|
||||
emotion_results = stage_outputs.get("emotion_results", {})
|
||||
emotion_snapshot = {}
|
||||
for stmt_id, emo in emotion_results.items():
|
||||
if hasattr(emo, "model_dump"):
|
||||
emotion_snapshot[stmt_id] = emo.model_dump()
|
||||
snapshot.save_stage("4_emotion_outputs", emotion_snapshot)
|
||||
|
||||
emb_output = stage_outputs.get("embedding_output")
|
||||
if emb_output and hasattr(emb_output, "model_dump"):
|
||||
emb_data = emb_output.model_dump()
|
||||
for key in (
|
||||
"statement_embeddings",
|
||||
"chunk_embeddings",
|
||||
"entity_embeddings",
|
||||
):
|
||||
if key in emb_data and isinstance(emb_data[key], dict):
|
||||
emb_data[key] = {
|
||||
k: v[:5] if isinstance(v, list) else v
|
||||
for k, v in emb_data[key].items()
|
||||
}
|
||||
if "dialog_embeddings" in emb_data and isinstance(
|
||||
emb_data["dialog_embeddings"], list
|
||||
):
|
||||
emb_data["dialog_embeddings"] = [
|
||||
v[:5] if isinstance(v, list) else v
|
||||
for v in emb_data["dialog_embeddings"]
|
||||
]
|
||||
snapshot.save_stage("5_embedding_outputs", emb_data)
|
||||
|
||||
# step2: 构建图节点和边
|
||||
graph = await build_graph_nodes_and_edges(
|
||||
dialog_data_list=dialog_data_list,
|
||||
embedder_client=self._embedder_client,
|
||||
progress_callback=self.progress_callback,
|
||||
)
|
||||
|
||||
(
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
entity_nodes,
|
||||
perceptual_nodes,
|
||||
stmt_chunk_edges,
|
||||
stmt_entity_edges,
|
||||
entity_entity_edges,
|
||||
perceptual_edges,
|
||||
dialog_data_list,
|
||||
) = await orchestrator.run(chunked_dialogs, is_pilot_run=is_pilot_run)
|
||||
# region Snapshot: 图节点和边(去重前)Snapshot有关的内容在重构流水线切换生产环境之后修改
|
||||
snapshot.save_stage(
|
||||
"6_nodes_edges_before_dedup",
|
||||
{
|
||||
"dialogue_nodes_count": len(graph.dialogue_nodes),
|
||||
"chunk_nodes_count": len(graph.chunk_nodes),
|
||||
"statement_nodes_count": len(graph.statement_nodes),
|
||||
"entity_nodes": [
|
||||
{
|
||||
"id": e.id,
|
||||
"name": e.name,
|
||||
"entity_type": e.entity_type,
|
||||
"description": e.description,
|
||||
}
|
||||
for e in graph.entity_nodes
|
||||
],
|
||||
"entity_entity_edges": [
|
||||
{
|
||||
"source": e.source,
|
||||
"target": e.target,
|
||||
"relation_type": e.relation_type,
|
||||
"statement": e.statement,
|
||||
}
|
||||
for e in graph.entity_entity_edges
|
||||
],
|
||||
"stmt_entity_edges_count": len(graph.stmt_entity_edges),
|
||||
},
|
||||
)
|
||||
|
||||
return ExtractionResult(
|
||||
dialogue_nodes=dialogue_nodes,
|
||||
chunk_nodes=chunk_nodes,
|
||||
statement_nodes=statement_nodes,
|
||||
entity_nodes=entity_nodes,
|
||||
perceptual_nodes=perceptual_nodes,
|
||||
stmt_chunk_edges=stmt_chunk_edges,
|
||||
stmt_entity_edges=stmt_entity_edges,
|
||||
entity_entity_edges=entity_entity_edges,
|
||||
perceptual_edges=perceptual_edges,
|
||||
# step3: 两阶段去重消歧
|
||||
dedup_result = await run_dedup(
|
||||
entity_nodes=graph.entity_nodes,
|
||||
statement_entity_edges=graph.stmt_entity_edges,
|
||||
entity_entity_edges=graph.entity_entity_edges,
|
||||
dialog_data_list=dialog_data_list,
|
||||
pipeline_config=pipeline_config,
|
||||
connector=self._neo4j_connector,
|
||||
llm_client=self._llm_client,
|
||||
is_pilot_run=is_pilot_run,
|
||||
progress_callback=self.progress_callback,
|
||||
)
|
||||
|
||||
# Snapshot: 去重后
|
||||
snapshot.save_stage(
|
||||
"7_after_dedup",
|
||||
{
|
||||
"entity_nodes": [
|
||||
{
|
||||
"id": e.id,
|
||||
"name": e.name,
|
||||
"entity_type": e.entity_type,
|
||||
"description": e.description,
|
||||
}
|
||||
for e in dedup_result.entity_nodes
|
||||
],
|
||||
"entity_entity_edges": [
|
||||
{
|
||||
"source": e.source,
|
||||
"target": e.target,
|
||||
"relation_type": e.relation_type,
|
||||
"statement": e.statement,
|
||||
}
|
||||
for e in dedup_result.entity_entity_edges
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
# step4: 构造最终结果
|
||||
result = ExtractionResult(
|
||||
dialogue_nodes=graph.dialogue_nodes,
|
||||
chunk_nodes=graph.chunk_nodes,
|
||||
statement_nodes=graph.statement_nodes,
|
||||
entity_nodes=dedup_result.entity_nodes,
|
||||
perceptual_nodes=graph.perceptual_nodes,
|
||||
stmt_chunk_edges=graph.stmt_chunk_edges,
|
||||
stmt_entity_edges=dedup_result.statement_entity_edges,
|
||||
entity_entity_edges=dedup_result.entity_entity_edges,
|
||||
perceptual_edges=graph.perceptual_edges,
|
||||
dialog_data_list=dialog_data_list,
|
||||
)
|
||||
|
||||
snapshot.save_summary(result.stats) # TODO 乐力齐 snapshot需要改
|
||||
return result
|
||||
|
||||
# ──────────────────────────────────────────────
|
||||
# Step 3: 存储
|
||||
# ──────────────────────────────────────────────
|
||||
@@ -379,14 +518,10 @@ class WritePipeline:
|
||||
)
|
||||
await asyncio.sleep(1 * (attempt + 1))
|
||||
else:
|
||||
logger.error(
|
||||
f"Neo4j 写入在 {max_retries} 次尝试后仍部分失败"
|
||||
)
|
||||
logger.error(f"Neo4j 写入在 {max_retries} 次尝试后仍部分失败")
|
||||
except Exception as e:
|
||||
if self._is_deadlock(e) and attempt < max_retries - 1:
|
||||
logger.warning(
|
||||
f"Neo4j 死锁,重试 ({attempt + 2}/{max_retries})"
|
||||
)
|
||||
logger.warning(f"Neo4j 死锁,重试 ({attempt + 2}/{max_retries})")
|
||||
await asyncio.sleep(1 * (attempt + 1))
|
||||
else:
|
||||
raise
|
||||
@@ -401,6 +536,10 @@ class WritePipeline:
|
||||
|
||||
聚类不阻塞主写入流程,失败不影响写入结果。
|
||||
通过 Celery 异步执行,由 LabelPropagationEngine 完成实际计算。
|
||||
|
||||
注意:ExtractionResult.entity_nodes 已经是经过 _extract() 中
|
||||
两阶段去重消歧(_run_dedup_and_write_summary)后的结果,
|
||||
聚类直接基于去重后的实体 ID 执行。
|
||||
"""
|
||||
if not result.entity_nodes:
|
||||
return
|
||||
@@ -428,7 +567,9 @@ class WritePipeline:
|
||||
)
|
||||
logger.info(
|
||||
f"[Clustering] 增量聚类任务已提交 - "
|
||||
f"task_id={task.id}, entity_count={len(new_entity_ids)}"
|
||||
f"task_id={task.id}, "
|
||||
f"entity_count={len(new_entity_ids)}, "
|
||||
f"source=dedup"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
@@ -438,9 +579,9 @@ class WritePipeline:
|
||||
|
||||
# ──────────────────────────────────────────────
|
||||
# Step 5: 摘要
|
||||
# (+ entity_description)
|
||||
# (+ entity_description)+ meta_data部分在此提取
|
||||
# ──────────────────────────────────────────────
|
||||
|
||||
# TODO 乐力齐 需要做成异步celery任务
|
||||
async def _summarize(self, chunked_dialogs: List[DialogData]) -> None:
|
||||
"""
|
||||
摘要:生成情景记忆摘要 → 写入 Neo4j。
|
||||
@@ -467,9 +608,7 @@ class WritePipeline:
|
||||
ms_connector = Neo4jConnector()
|
||||
try:
|
||||
await add_memory_summary_nodes(summaries, ms_connector)
|
||||
await add_memory_summary_statement_edges(
|
||||
summaries, ms_connector
|
||||
)
|
||||
await add_memory_summary_statement_edges(summaries, ms_connector)
|
||||
finally:
|
||||
try:
|
||||
await ms_connector.close()
|
||||
@@ -494,9 +633,7 @@ class WritePipeline:
|
||||
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
self._llm_client = factory.get_llm_client_from_config(
|
||||
self.memory_config
|
||||
)
|
||||
self._llm_client = factory.get_llm_client_from_config(self.memory_config)
|
||||
self._embedder_client = factory.get_embedder_client_from_config(
|
||||
self.memory_config
|
||||
)
|
||||
@@ -564,10 +701,8 @@ class WritePipeline:
|
||||
if entity_nodes:
|
||||
eu_id = entity_nodes[0].end_user_id
|
||||
if eu_id:
|
||||
neo4j_assistant_aliases = (
|
||||
await fetch_neo4j_assistant_aliases(
|
||||
self._neo4j_connector, eu_id
|
||||
)
|
||||
neo4j_assistant_aliases = await fetch_neo4j_assistant_aliases(
|
||||
self._neo4j_connector, eu_id
|
||||
)
|
||||
clean_cross_role_aliases(
|
||||
entity_nodes,
|
||||
@@ -586,9 +721,7 @@ class WritePipeline:
|
||||
msg = str(e).lower()
|
||||
return "deadlockdetected" in msg or "deadlock" in msg
|
||||
|
||||
async def _update_stats_cache(
|
||||
self, result: ExtractionResult
|
||||
) -> None:
|
||||
async def _update_stats_cache(self, result: ExtractionResult) -> None:
|
||||
"""
|
||||
将提取统计写入 Redis 活动缓存,按 workspace_id 存储。
|
||||
失败不中断主流程。
|
||||
@@ -614,9 +747,7 @@ class WritePipeline:
|
||||
f"workspace_id={self.memory_config.workspace_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"写入活动统计缓存失败(不影响主流程): {e}"
|
||||
)
|
||||
logger.warning(f"写入活动统计缓存失败(不影响主流程): {e}")
|
||||
|
||||
async def _cleanup(self) -> None:
|
||||
"""
|
||||
@@ -634,16 +765,14 @@ class WritePipeline:
|
||||
# 防止 'RuntimeError: Event loop is closed' 在垃圾回收时触发
|
||||
for client_obj in (self._llm_client, self._embedder_client):
|
||||
try:
|
||||
underlying = getattr(
|
||||
client_obj, "client", None
|
||||
) or getattr(client_obj, "model", None)
|
||||
underlying = getattr(client_obj, "client", None) or getattr(
|
||||
client_obj, "model", None
|
||||
)
|
||||
if underlying is None:
|
||||
continue
|
||||
inner = getattr(underlying, "_model", underlying)
|
||||
http_client = getattr(inner, "async_client", None)
|
||||
if http_client is not None and hasattr(
|
||||
http_client, "aclose"
|
||||
):
|
||||
if http_client is not None and hasattr(http_client, "aclose"):
|
||||
await http_client.aclose()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -0,0 +1,506 @@
|
||||
"""Independent deduplication module for the extraction pipeline.
|
||||
|
||||
Extracts dedup logic from ExtractionOrchestrator into standalone functions
|
||||
so the orchestrator stays thin and dedup can be tested/evolved independently.
|
||||
|
||||
The module exposes:
|
||||
- ``DedupResult`` — structured output of the dedup process
|
||||
- ``run_dedup()`` — async entry point called by WritePipeline
|
||||
- Helper functions migrated from ExtractionOrchestrator:
|
||||
``save_dedup_details``, ``analyze_entity_merges``,
|
||||
``analyze_entity_disambiguation``, ``send_dedup_progress_callback``,
|
||||
``parse_dedup_report``
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from app.core.memory.models.graph_models import (
|
||||
EntityEntityEdge,
|
||||
ExtractedEntityNode,
|
||||
StatementEntityEdge,
|
||||
)
|
||||
from app.core.memory.models.message_models import DialogData
|
||||
from app.core.memory.models.variate_config import ExtractionPipelineConfig
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DedupResult dataclass (Requirement 10.2)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
class DedupResult:
    """Container for everything the entity deduplication pass produces.

    The three graph collections are required; the bookkeeping fields default
    to fresh empty containers so instances never share mutable state.
    """

    # Entity nodes remaining after deduplication.
    entity_nodes: List[ExtractedEntityNode]
    # Statement -> entity edges after deduplication.
    statement_entity_edges: List[StatementEntityEdge]
    # Entity -> entity edges after deduplication.
    entity_entity_edges: List[EntityEntityEdge]
    # Raw detail dict as returned by the dedup implementation.
    dedup_details: Dict[str, Any] = field(default_factory=dict)
    # Parsed merge records (exact / fuzzy / LLM).
    merge_records: List[Dict[str, Any]] = field(default_factory=list)
    # Parsed disambiguation records.
    disamb_records: List[Dict[str, Any]] = field(default_factory=list)

    @property
    def stats(self) -> Dict[str, int]:
        """Return element counts for entities, merges and disambiguations."""
        counted = (
            ("entity_count", self.entity_nodes),
            ("merge_count", self.merge_records),
            ("disamb_count", self.disamb_records),
        )
        return {label: len(items) for label, items in counted}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Migrated helpers (from ExtractionOrchestrator) — Requirement 10.4
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def save_dedup_details(
    dedup_details: Dict[str, Any],
    original_entities: List[ExtractedEntityNode],
    final_entities: List[ExtractedEntityNode],
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], Dict[str, str]]:
    """Parse raw *dedup_details* into structured merge / disambiguation records.

    Parsing is best-effort: a record that cannot be parsed is skipped, and any
    unexpected failure is logged while whatever was collected so far is
    still returned.

    Args:
        dedup_details: Raw detail dict. The keys read here are
            ``id_redirect``, ``exact_merge_map``, ``fuzzy_merge_records``,
            ``llm_decision_records`` and ``disamb_records``.
        original_entities: Not used by this function; kept so the signature
            stays compatible with existing callers.
        final_entities: Not used by this function; kept for the same reason.

    Returns:
        ``(merge_records, disamb_records, id_redirect_map)``.
    """
    merge_records: List[Dict[str, Any]] = []
    disamb_records: List[Dict[str, Any]] = []
    id_redirect_map: Dict[str, str] = {}

    # "name(type)|name(type)" with four groups: name1, type1, name2, type2.
    # Parentheses are matched literally and may be ASCII or full-width, so
    # group(2) / group(4) really are the two entity types.
    pair_pattern = (
        r"([^(（]+)[(（]([^)）]+)[)）]"
        r"\|"
        r"([^(（]+)[(（]([^)）]+)[)）]"
    )

    try:
        id_redirect_map = dedup_details.get("id_redirect", {})

        # --- exact-match merges ---
        for info in dedup_details.get("exact_merge_map", {}).values():
            merged_ids = info.get("merged_ids", set())
            if merged_ids:
                merge_records.append({
                    "type": "精确匹配",
                    "canonical_id": info.get("canonical_id"),
                    "entity_name": info.get("name"),
                    "entity_type": info.get("entity_type"),
                    "merged_count": len(merged_ids),
                    "merged_ids": list(merged_ids),
                })

        # --- fuzzy-match merges ---
        for record in dedup_details.get("fuzzy_merge_records", []):
            try:
                match = re.search(
                    r"规范实体 (\S+) \(([^|]+)\|([^|]+)\|([^)]+)\) <- 合并实体 (\S+)",
                    str(record),
                )
                if match:
                    merge_records.append({
                        "type": "模糊匹配",
                        "canonical_id": match.group(1),
                        "entity_name": match.group(3),
                        "entity_type": match.group(4),
                        "merged_count": 1,
                        "merged_ids": [match.group(5)],
                    })
            except Exception as e:
                logger.debug("解析模糊匹配记录失败: %s, 错误: %s", record, e)

        # --- LLM-based merges ---
        for record in dedup_details.get("llm_decision_records", []):
            text = str(record)
            if "[LLM去重]" not in text:
                continue
            try:
                match = re.search(r"同名类型相似 " + pair_pattern, text)
                if match:
                    merge_records.append({
                        "type": "LLM去重",
                        "entity_name": match.group(1),
                        "entity_type": f"{match.group(2)}|{match.group(4)}",
                        "merged_count": 1,
                        "merged_ids": [],
                    })
            except Exception as e:
                logger.debug("解析LLM去重记录失败: %s, 错误: %s", record, e)

        # --- disambiguation records ---
        for record in dedup_details.get("disamb_records", []):
            text = str(record)
            if "[DISAMB阻断]" not in text:
                continue
            try:
                content = text.replace("[DISAMB阻断]", "").strip()
                match = re.search(pair_pattern, content)
                if match:
                    entity1_name = match.group(1).strip()
                    entity1_type = match.group(2)
                    entity2_type = match.group(4)

                    conf_match = re.search(r"conf=([0-9.]+)", text)
                    confidence = conf_match.group(1) if conf_match else "unknown"

                    reason_match = re.search(r"reason=([^|]+)", text)
                    reason = reason_match.group(1).strip() if reason_match else ""

                    disamb_records.append({
                        "entity_name": entity1_name,
                        "disamb_type": f"消歧阻断:{entity1_type} vs {entity2_type}",
                        "confidence": confidence,
                        # Keep long reasons readable in reports.
                        "reason": (reason[:100] + "...") if len(reason) > 100 else reason,
                    })
            except Exception as e:
                logger.debug("解析消歧记录失败: %s, 错误: %s", record, e)

        logger.info(
            "保存去重消歧记录:%d 个合并记录,%d 个消歧记录",
            len(merge_records),
            len(disamb_records),
        )
    except Exception as e:
        # Best-effort: never let report parsing break the pipeline.
        logger.error("保存去重消歧详情失败: %s", e, exc_info=True)

    return merge_records, disamb_records, id_redirect_map
|
||||
|
||||
|
||||
def analyze_entity_merges(
    merge_records: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """Summarise merge records, ordered by merged count (descending).

    Args:
        merge_records: Parsed merge records; each may carry ``entity_name``
            and ``merged_count``.

    Returns:
        One ``{"main_entity_name", "merged_count"}`` dict per record. A record
        without ``merged_count`` counts as 1 both for ordering and in the
        output, so the result is always sorted by the values it reports.
        Records with equal counts keep their input order.
    """
    if not merge_records:
        return []
    summaries = [
        {
            "main_entity_name": record.get("entity_name", "未知实体"),
            "merged_count": record.get("merged_count", 1),
        }
        for record in merge_records
    ]
    # Sort on the reported value so ordering and output cannot disagree.
    summaries.sort(key=lambda item: item["merged_count"], reverse=True)
    return summaries
|
||||
|
||||
|
||||
def analyze_entity_disambiguation(
    disamb_records: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """Hand back the disambiguation records unchanged.

    A falsy input (``None`` or an empty list) is normalised to a new empty
    list; otherwise the very same list object is returned.
    """
    if disamb_records:
        return disamb_records
    return []
|
||||
|
||||
|
||||
def _split_disamb_types(disamb_type: str) -> tuple:
    """Split a ``"消歧阻断:<type1> vs <type2>"`` label into ``(type1, type2)``.

    The space-delimited separator is tried first so that entity types which
    merely contain the letters "vs" (e.g. ``devs``) are not cut in half.
    Labels that only contain a bare ``vs`` keep the legacy split behaviour.
    Returns ``("未知", "未知")`` when no separator is present.
    """
    left, separator, right = disamb_type.partition(" vs ")
    if separator:
        return left.replace("消歧阻断:", "").strip(), right.strip()
    if "vs" in disamb_type:
        # Legacy fallback: label written without surrounding spaces.
        parts = disamb_type.split("vs")
        return parts[0].replace("消歧阻断:", "").strip(), parts[1].strip()
    return "未知", "未知"


def parse_dedup_report(
    merge_records: List[Dict[str, Any]],
    disamb_records: List[Dict[str, Any]],
) -> Dict[str, Any]:
    """Build a summary report dict from parsed merge / disambiguation records.

    Args:
        merge_records: Parsed merge records (``type``, ``entity_name``,
            ``entity_type``, ``merged_count``).
        disamb_records: Parsed disambiguation records (``entity_name``,
            ``disamb_type``).

    Returns:
        A dict with ``dedup_examples`` and ``disamb_examples`` (at most five
        entries each), ``total_merges`` (sum of all ``merged_count`` values)
        and ``total_disambiguations`` (number of disambiguation records).
        If anything goes wrong while building the report, the error is logged
        and an all-empty report is returned instead of raising.
    """
    try:
        dedup_examples: List[Dict[str, Any]] = []
        disamb_examples: List[Dict[str, Any]] = []
        total_merges = 0
        total_disambiguations = 0

        for record in merge_records:
            merge_count = record.get("merged_count", 0)
            total_merges += merge_count
            dedup_examples.append({
                "type": record.get("type", "未知"),
                "entity_name": record.get("entity_name", "未知实体"),
                "entity_type": record.get("entity_type", "未知类型"),
                "merge_count": merge_count,
                "description": f"{record.get('entity_name', '未知实体')}实体去重合并{merge_count}个",
            })

        for record in disamb_records:
            total_disambiguations += 1
            disamb_type = record.get("disamb_type", "")
            entity_name = record.get("entity_name", "未知实体")
            entity1_type, entity2_type = _split_disamb_types(disamb_type)
            disamb_examples.append({
                # Both sides share one name: the record describes a single
                # entity name that was kept apart under two different types.
                "entity1_name": entity_name,
                "entity1_type": entity1_type,
                "entity2_name": entity_name,
                "entity2_type": entity2_type,
                "description": f"{entity_name},消歧区分成功",
            })

        return {
            # Totals cover every record; only the example lists are capped.
            "dedup_examples": dedup_examples[:5],
            "disamb_examples": disamb_examples[:5],
            "total_merges": total_merges,
            "total_disambiguations": total_disambiguations,
        }
    except Exception as e:
        logger.error("获取去重报告失败: %s", e, exc_info=True)
        return {
            "dedup_examples": [],
            "disamb_examples": [],
            "total_merges": 0,
            "total_disambiguations": 0,
        }
|
||||
|
||||
|
||||
async def send_dedup_progress_callback(
    progress_callback: Callable,
    merge_records: List[Dict[str, Any]],
    disamb_records: List[Dict[str, Any]],
    original_entities: int,
    final_entities: int,
    original_stmt_edges: int,
    final_stmt_edges: int,
    original_ent_edges: int,
    final_ent_edges: int,
) -> None:
    """Report the final dedup statistics through *progress_callback*.

    The full payload contains before/after counts for entities and both edge
    kinds, the example lists and totals from ``parse_dedup_report``. If
    building or sending it fails, the error is logged and a reduced payload
    (entity counts plus a one-line summary) is sent instead; a failure of
    that second attempt is logged and swallowed.

    Args:
        progress_callback: Awaitable callable invoked as
            ``(stage, message, payload)``.
        merge_records: Parsed merge records.
        disamb_records: Parsed disambiguation records.
        original_entities / final_entities: Entity counts before / after.
        original_stmt_edges / final_stmt_edges: Statement-entity edge counts.
        original_ent_edges / final_ent_edges: Entity-entity edge counts.
    """
    stage = "dedup_disambiguation_complete"
    message = "去重消歧完成"

    try:
        report = parse_dedup_report(merge_records, disamb_records)

        removed_entities = original_entities - final_entities
        removed_stmt_edges = original_stmt_edges - final_stmt_edges
        removed_ent_edges = original_ent_edges - final_ent_edges

        # Percentage of entities removed; guarded against division by zero.
        if original_entities > 0:
            reduction_rate = round(removed_entities / original_entities * 100, 1)
        else:
            reduction_rate = 0

        entity_stats = {
            "original_count": original_entities,
            "final_count": final_entities,
            "reduced_count": removed_entities,
            "reduction_rate": reduction_rate,
        }
        stmt_edge_stats = {
            "original_count": original_stmt_edges,
            "final_count": final_stmt_edges,
            "reduced_count": removed_stmt_edges,
        }
        ent_edge_stats = {
            "original_count": original_ent_edges,
            "final_count": final_ent_edges,
            "reduced_count": removed_ent_edges,
        }
        payload = {
            "entities": entity_stats,
            "statement_entity_edges": stmt_edge_stats,
            "entity_entity_edges": ent_edge_stats,
            "dedup_examples": report.get("dedup_examples", []),
            "disamb_examples": report.get("disamb_examples", []),
            "summary": {
                "total_merges": report.get("total_merges", 0),
                "total_disambiguations": report.get("total_disambiguations", 0),
            },
        }

        await progress_callback(stage, message, payload)
    except Exception as e:
        logger.error("发送去重消歧进度回调失败: %s", e, exc_info=True)
        # Best effort: retry with a minimal payload.
        try:
            fallback_payload = {
                "entities": {
                    "original_count": original_entities,
                    "final_count": final_entities,
                    "reduced_count": original_entities - final_entities,
                },
                "summary": f"实体去重合并{original_entities - final_entities}个",
            }
            await progress_callback(stage, message, fallback_payload)
        except Exception as e2:
            logger.error("发送基本去重统计失败: %s", e2, exc_info=True)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# run_dedup — main entry point (Requirements 10.1, 10.3)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def run_dedup(
    entity_nodes: List[ExtractedEntityNode],
    statement_entity_edges: List[StatementEntityEdge],
    entity_entity_edges: List[EntityEntityEdge],
    dialog_data_list: List[DialogData],
    pipeline_config: ExtractionPipelineConfig,
    connector: Optional[Neo4jConnector] = None,
    llm_client: Optional[Any] = None,
    is_pilot_run: bool = False,
    progress_callback: Optional[Callable] = None,
) -> DedupResult:
    """Deduplicate and disambiguate entities, then report the outcome.

    Two execution modes:

    * Pilot run (``is_pilot_run=True``): only ``deduplicate_entities_and_edges``
      is executed; *connector* and *dialog_data_list* are not used.
    * Full run: ``dedup_layers_and_merge_and_return`` is executed with the
      connector, dialogue data and pipeline config.

    Args:
        entity_nodes: Entity nodes before dedup.
        statement_entity_edges: Statement-entity edges before dedup.
        entity_entity_edges: Entity-entity edges before dedup.
        dialog_data_list: Source dialogue data, forwarded in full mode.
        pipeline_config: Pipeline configuration; ``.deduplication`` is passed
            to the pilot-run dedup.
        connector: Optional Neo4j connector, forwarded in full mode.
        llm_client: Optional LLM client, forwarded in both modes.
        is_pilot_run: Selects the pilot-run branch when True.
        progress_callback: Optional awaitable callable for progress events.

    Returns:
        A ``DedupResult`` holding the final nodes and edges, the raw dedup
        details, and the parsed merge / disambiguation records.

    Raises:
        Exception: Any failure is logged and re-raised unchanged.
    """
    logger.info("开始两阶段实体去重和消歧")

    if progress_callback:
        await progress_callback("deduplication", "正在去重消歧...")

    # Snapshot the sizes up front so reductions can be reported afterwards.
    original_entity_count = len(entity_nodes)
    original_stmt_edge_count = len(statement_entity_edges)
    original_ent_edge_count = len(entity_entity_edges)

    logger.info(
        "去重前: %d 个实体节点, %d 条陈述句-实体边, %d 条实体-实体边",
        original_entity_count,
        original_stmt_edge_count,
        original_ent_edge_count,
    )

    try:
        if is_pilot_run:
            logger.info("试运行模式:仅执行第一层去重,跳过第二层数据库去重")
            # Imported lazily, only on the branch that needs it.
            from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
                deduplicate_entities_and_edges,
            )

            (
                final_entities,
                final_stmt_edges,
                final_ent_edges,
                raw_details,
            ) = await deduplicate_entities_and_edges(
                entity_nodes,
                statement_entity_edges,
                entity_entity_edges,
                report_stage="第一层去重消歧(试运行)",
                report_append=False,
                dedup_config=pipeline_config.deduplication,
                llm_client=llm_client,
            )
        else:
            from app.core.memory.storage_services.extraction_engine.deduplication.two_stage_dedup import (
                dedup_layers_and_merge_and_return,
            )

            # Only entities and their edges are supplied; the remaining node
            # and edge collections are passed empty and their results ignored.
            (
                _dialogue_nodes,
                _chunk_nodes,
                _statement_nodes,
                final_entities,
                _statement_chunk_edges,
                final_stmt_edges,
                final_ent_edges,
                raw_details,
            ) = await dedup_layers_and_merge_and_return(
                dialogue_nodes=[],
                chunk_nodes=[],
                statement_nodes=[],
                entity_nodes=entity_nodes,
                statement_chunk_edges=[],
                statement_entity_edges=statement_entity_edges,
                entity_entity_edges=entity_entity_edges,
                dialog_data_list=dialog_data_list,
                pipeline_config=pipeline_config,
                connector=connector,
                llm_client=llm_client,
            )

        # Turn the raw detail payload into structured records.
        merge_records, disamb_records, _id_redirect = save_dedup_details(
            raw_details, entity_nodes, final_entities
        )

        final_entity_count = len(final_entities)
        final_stmt_edge_count = len(final_stmt_edges)
        final_ent_edge_count = len(final_ent_edges)

        logger.info(
            "去重后: %d 个实体节点, %d 条陈述句-实体边, %d 条实体-实体边",
            final_entity_count,
            final_stmt_edge_count,
            final_ent_edge_count,
        )
        logger.info(
            "去重效果: 实体减少 %d, 陈述句-实体边减少 %d, 实体-实体边减少 %d",
            original_entity_count - final_entity_count,
            original_stmt_edge_count - final_stmt_edge_count,
            original_ent_edge_count - final_ent_edge_count,
        )

        if progress_callback:
            # Stream at most five merge examples ...
            top_merges = analyze_entity_merges(merge_records)[:5]
            merge_total = len(top_merges)
            for position, merge in enumerate(top_merges, start=1):
                merged_name = merge["main_entity_name"]
                merged_count = merge["merged_count"]
                await progress_callback(
                    "dedup_disambiguation_result",
                    "实体去重中",
                    {
                        "result_type": "entity_merge",
                        "merged_entity_name": merged_name,
                        "merged_count": merged_count,
                        "merge_progress": f"{position}/{merge_total}",
                        "message": f"{merged_name}合并{merged_count}个:相似实体已合并",
                    },
                )

            # ... then at most five disambiguation examples.
            top_disambs = analyze_entity_disambiguation(disamb_records)[:5]
            disamb_total = len(top_disambs)
            for position, disamb in enumerate(top_disambs, start=1):
                disamb_name = disamb["entity_name"]
                disamb_kind = disamb["disamb_type"]
                await progress_callback(
                    "dedup_disambiguation_result",
                    "实体消歧中",
                    {
                        "result_type": "entity_disambiguation",
                        "disambiguated_entity_name": disamb_name,
                        "disambiguation_type": disamb_kind,
                        "confidence": disamb.get("confidence", "unknown"),
                        "reason": disamb.get("reason", ""),
                        "disamb_progress": f"{position}/{disamb_total}",
                        "message": f"{disamb_name}消歧完成:{disamb_kind}",
                    },
                )

            await send_dedup_progress_callback(
                progress_callback,
                merge_records,
                disamb_records,
                original_entity_count,
                final_entity_count,
                original_stmt_edge_count,
                final_stmt_edge_count,
                original_ent_edge_count,
                final_ent_edge_count,
            )

        return DedupResult(
            entity_nodes=final_entities,
            statement_entity_edges=final_stmt_edges,
            entity_entity_edges=final_ent_edges,
            dedup_details=raw_details,
            merge_records=merge_records,
            disamb_records=disamb_records,
        )

    except Exception as e:
        logger.error("两阶段去重失败: %s", e, exc_info=True)
        raise
|
||||
@@ -0,0 +1,16 @@
|
||||
"""Extraction pipeline steps — unified ExtractionStep paradigm.
|
||||
|
||||
Importing this package triggers @register decorator self-registration
|
||||
for all sidecar (non-critical) steps via SidecarStepFactory.
|
||||
"""
|
||||
|
||||
from .sidecar_factory import SidecarStepFactory, SidecarTiming # noqa: F401
|
||||
|
||||
# Step implementations — importing triggers @register self-registration.
|
||||
from .statement_step import StatementExtractionStep # noqa: F401
|
||||
from .triplet_step import TripletExtractionStep # noqa: F401
|
||||
from .emotion_step import EmotionExtractionStep # noqa: F401
|
||||
from .embedding_step import EmbeddingStep # noqa: F401
|
||||
|
||||
# Refactored orchestrator
|
||||
from .extraction_pipeline_orchestrator import NewExtractionOrchestrator # noqa: F401
|
||||
@@ -0,0 +1,182 @@
|
||||
"""ExtractionStep abstract base class and StepContext.
|
||||
|
||||
Provides the unified paradigm for all LLM extraction stages:
|
||||
render_prompt → call_llm → parse_response → post_process
|
||||
|
||||
Critical steps retry on failure with exponential backoff.
|
||||
Sidecar (non-critical) steps return a default output on failure without retry.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, Optional, TypeVar
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
InputT = TypeVar("InputT")
|
||||
OutputT = TypeVar("OutputT")
|
||||
|
||||
|
||||
@dataclass
class StepContext:
    """Bundle of shared dependencies the orchestrator hands to each step.

    ``llm_client``, ``language`` and ``config`` are required; the remaining
    fields are optional and default to a normal (non-pilot) run without
    progress reporting.
    """

    # LLM client instance used to generate completions.
    llm_client: Any
    # Target language code, e.g. "en" or "zh".
    language: str
    # Pipeline configuration object (ExtractionPipelineConfig).
    config: Any
    # True when running in lightweight preview mode.
    is_pilot_run: bool = False
    # Optional callable used to report progress.
    progress_callback: Optional[Any] = None
|
||||
|
||||
|
||||
class ExtractionStep(ABC, Generic[InputT, OutputT]):
    """Abstract base class for all LLM extraction stages.

    Lifecycle:
        1. ``__init__(context)`` — receive shared context, bind config params
        2. ``should_skip()`` — check whether to skip (config-driven / pilot mode)
        3. ``run(input_data)`` — execute full flow (with retry for critical steps)
           Internally: render_prompt → call_llm → parse_response → post_process
        4. ``on_failure(error)`` — critical steps raise; sidecar steps return default

    Type Parameters:
        InputT: The Pydantic model type accepted by this step.
        OutputT: The Pydantic model type produced by this step.
    """

    def __init__(self, context: StepContext) -> None:
        """Store *context* and expose its client, language and config directly."""
        self.context = context
        self.llm_client = context.llm_client
        self.language = context.language
        self.config = context.config

    # ── Subclasses must implement ──

    @property
    @abstractmethod
    def name(self) -> str:
        """Human-readable step name for logging."""
        ...

    @abstractmethod
    async def render_prompt(self, input_data: InputT) -> Any:
        """Build the prompt from *input_data* and bound config."""
        ...

    @abstractmethod
    async def call_llm(self, prompt: Any) -> Any:
        """Send *prompt* to the LLM and return the raw response."""
        ...

    @abstractmethod
    async def parse_response(self, raw_response: Any, input_data: InputT) -> OutputT:
        """Parse *raw_response* into a typed OutputT (Pydantic model)."""
        ...

    @abstractmethod
    def get_default_output(self) -> OutputT:
        """Return a safe default when the step is skipped or fails gracefully."""
        ...

    # ── Overridable properties ──

    @property
    def is_critical(self) -> bool:
        """``True`` = critical step (failure aborts pipeline).

        ``False`` = sidecar step (failure degrades gracefully).
        """
        return True

    @property
    def max_retries(self) -> int:
        """Maximum retry attempts (only effective for critical steps).

        Values below zero are treated as zero: the step is always attempted
        at least once.
        """
        return 2

    @property
    def retry_backoff_base(self) -> float:
        """Backoff base in seconds. Actual wait = base × 2^attempt."""
        return 1.0

    # ── Overridable hooks ──

    def should_skip(self) -> bool:
        """Config-driven skip check. Subclasses may override."""
        return False

    async def post_process(self, parsed_data: OutputT, input_data: InputT) -> OutputT:
        """Post-processing hook. Default is identity (returns *parsed_data* unchanged)."""
        return parsed_data

    # ── Core execution logic ──

    async def run(self, input_data: InputT) -> OutputT:
        """Execute the full step lifecycle with retry logic.

        For critical steps (``is_critical=True``):
            Attempt up to ``max_retries + 1`` times (never fewer than once)
            with exponential backoff between attempts. If all attempts fail,
            delegate to ``on_failure`` which raises the last error.

        For sidecar steps (``is_critical=False``):
            Attempt exactly once. On failure, delegate to ``on_failure``
            which returns ``get_default_output()``.

        A skipped step (``should_skip()`` is true) returns
        ``get_default_output()`` without calling the LLM.
        """
        if self.should_skip():
            logger.info("Step '%s' skipped", self.name)
            return self.get_default_output()

        last_error: Optional[Exception] = None
        # Clamp to one attempt: a negative ``max_retries`` would otherwise
        # skip the loop entirely and reach ``on_failure`` with ``None``,
        # which cannot be raised.
        attempts = max(1, self.max_retries + 1) if self.is_critical else 1

        for attempt in range(attempts):
            try:
                prompt = await self.render_prompt(input_data)
                raw_response = await self.call_llm(prompt)
                parsed = await self.parse_response(raw_response, input_data)
                result = await self.post_process(parsed, input_data)
                return result
            except Exception as exc:
                last_error = exc
                logger.warning(
                    "Step '%s' attempt %d/%d failed: %s",
                    self.name,
                    attempt + 1,
                    attempts,
                    exc,
                )
                # Sleep only when another attempt will follow.
                if attempt < attempts - 1:
                    wait = self.retry_backoff_base * (2 ** attempt)
                    logger.info(
                        "Step '%s' retrying in %.1fs …", self.name, wait
                    )
                    await asyncio.sleep(wait)

        # All attempts exhausted — delegate to failure handler
        return self.on_failure(last_error)  # type: ignore[arg-type]

    def on_failure(self, error: Exception) -> OutputT:
        """Handle step failure.

        Critical steps: re-raise the exception to abort the pipeline.
        Sidecar steps: return ``get_default_output()`` for graceful degradation.
        """
        if self.is_critical:
            logger.error(
                "Critical step '%s' failed after retries: %s", self.name, error
            )
            raise error
        logger.warning(
            "Sidecar step '%s' failed, returning default output: %s",
            self.name,
            error,
        )
        return self.get_default_output()
|
||||
@@ -0,0 +1,124 @@
|
||||
"""EmbeddingStep — generates vector embeddings for statements, chunks, dialogs, and entities.
|
||||
|
||||
Unlike the LLM-based ExtractionSteps, EmbeddingStep calls an embedder client
|
||||
rather than an LLM. It still follows the ``should_skip`` / ``run`` /
|
||||
``get_default_output`` contract so the orchestrator can treat it uniformly.
|
||||
|
||||
Supports **partial** embedding runs — the caller can populate only the fields
|
||||
it needs (e.g. only ``statement_texts``) and leave the rest empty.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from .schema import EmbeddingStepInput, EmbeddingStepOutput
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmbeddingStep:
    """Generate vector embeddings for text inputs.

    This step does **not** inherit from ``ExtractionStep`` because it does not
    follow the render_prompt → call_llm → parse_response lifecycle. It does,
    however, expose the same ``run`` / ``should_skip`` / ``get_default_output``
    interface so the orchestrator can use it interchangeably.

    Pilot-run mode skips execution entirely and returns the default
    (empty) output. Any failure during embedding is logged and also yields
    the default output.
    """

    def __init__(
        self,
        embedder_client: Any,
        is_pilot_run: bool = False,
        batch_size: int = 100,
    ) -> None:
        """Bind the embedder client and batching parameters.

        Args:
            embedder_client: Object exposing an awaitable ``response(texts)``.
            is_pilot_run: When True, ``run`` is skipped.
            batch_size: Maximum number of texts sent per embedder call.
        """
        self.embedder_client = embedder_client
        self.is_pilot_run = is_pilot_run
        self.batch_size = batch_size

    @property
    def name(self) -> str:
        return "embedding_generation"

    @property
    def is_critical(self) -> bool:
        return False

    @property
    def max_retries(self) -> int:
        return 1

    @property
    def retry_backoff_base(self) -> float:
        return 1.0

    def should_skip(self) -> bool:
        return self.is_pilot_run

    def get_default_output(self) -> EmbeddingStepOutput:
        return EmbeddingStepOutput()

    # ── Core execution ──

    async def run(self, input_data: EmbeddingStepInput) -> EmbeddingStepOutput:
        """Generate embeddings for all non-empty text fields in *input_data*.

        The four text groups are embedded concurrently. If any of them fails
        (including an embedder reply of unexpected length), a warning is
        logged and the default output is returned.
        """
        if self.should_skip():
            logger.info("EmbeddingStep skipped (pilot run)")
            return self.get_default_output()

        try:
            stmt_emb, chunk_emb, dialog_emb, entity_emb = await asyncio.gather(
                self._embed_dict(input_data.statement_texts),
                self._embed_dict(input_data.chunk_texts),
                self._embed_list(input_data.dialog_texts),
                self._embed_dict(input_data.entity_names),
            )
            return EmbeddingStepOutput(
                statement_embeddings=stmt_emb,
                chunk_embeddings=chunk_emb,
                dialog_embeddings=dialog_emb,
                entity_embeddings=entity_emb,
            )
        except Exception as exc:
            logger.warning("EmbeddingStep failed, returning empty output: %s", exc)
            return self.get_default_output()

    # ── Internal helpers ──

    async def _embed_dict(
        self, texts: Dict[str, str]
    ) -> Dict[str, List[float]]:
        """Embed a dict of ``{id: text}`` and return ``{id: embedding}``."""
        if not texts:
            return {}

        ids = list(texts.keys())
        text_list = list(texts.values())
        embeddings = await self._batch_embed(text_list)

        # ``_batch_embed`` guarantees one vector per text, so the positional
        # pairing below cannot silently drop or shift ids.
        return dict(zip(ids, embeddings))

    async def _embed_list(self, texts: List[str]) -> List[List[float]]:
        """Embed a plain list of texts."""
        if not texts:
            return []
        return await self._batch_embed(texts)

    async def _embed_batch(self, batch: List[str]) -> List[List[float]]:
        """Embed one batch and verify the embedder returned one vector per text.

        Raises:
            ValueError: If the number of returned vectors differs from the
                number of input texts. Without this check a short reply in
                one batch would shift every later vector onto the wrong id.
        """
        # NOTE(review): assumes the embedder reply is a sized sequence — confirm.
        vectors = await self.embedder_client.response(batch)
        if len(vectors) != len(batch):
            raise ValueError(
                f"Embedder returned {len(vectors)} embeddings for {len(batch)} texts"
            )
        return vectors

    async def _batch_embed(self, texts: List[str]) -> List[List[float]]:
        """Call the embedder in batches of ``self.batch_size``.

        Inputs that fit in one batch are sent in a single call; larger inputs
        are split and the batches are embedded concurrently, then re-joined
        in input order.
        """
        if len(texts) <= self.batch_size:
            return await self._embed_batch(texts)

        batches = [
            texts[i : i + self.batch_size]
            for i in range(0, len(texts), self.batch_size)
        ]
        batch_results = await asyncio.gather(
            *(self._embed_batch(b) for b in batches)
        )
        embeddings: List[List[float]] = []
        for result in batch_results:
            embeddings.extend(result)
        return embeddings
|
||||
@@ -0,0 +1,80 @@
|
||||
"""EmotionExtractionStep — sidecar step for extracting emotion from statements.
|
||||
|
||||
Replaces the legacy ``EmotionExtractionService`` with the unified ExtractionStep
|
||||
paradigm. Registered via ``@SidecarStepFactory.register`` so the orchestrator
|
||||
picks it up automatically when ``emotion_enabled`` is ``True``.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.memory.models.emotion_models import EmotionExtraction
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_emotion_extraction_prompt
|
||||
|
||||
from .base import ExtractionStep, StepContext
|
||||
from .sidecar_factory import SidecarStepFactory, SidecarTiming
|
||||
from .schema import EmotionStepInput, EmotionStepOutput
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@SidecarStepFactory.register("emotion_enabled", SidecarTiming.AFTER_STATEMENT)
class EmotionExtractionStep(ExtractionStep[EmotionStepInput, EmotionStepOutput]):
    """Sidecar step that derives emotion type, intensity and keywords.

    Being non-critical, a failure here falls back to the default output
    instead of stopping the pipeline. The class registers itself with
    ``SidecarStepFactory`` under the ``emotion_enabled`` config key, timed
    ``AFTER_STATEMENT``.
    """

    def __init__(self, context: StepContext) -> None:
        super().__init__(context)
        # Optional emotion flags are read from the config object with
        # defaults, so a config lacking them still works.
        cfg = self.config
        self.extract_keywords = getattr(cfg, "emotion_extract_keywords", True)
        self.enable_subject = getattr(cfg, "emotion_enable_subject", False)

    # ── Identity ──

    @property
    def name(self) -> str:
        return "emotion_extraction"

    @property
    def is_critical(self) -> bool:
        return False

    # ── Config-driven skip ──

    def should_skip(self) -> bool:
        """Skip unless the config explicitly enables emotion extraction."""
        emotion_enabled = getattr(self.config, "emotion_enabled", False)
        return not emotion_enabled

    # ── Lifecycle ──

    async def render_prompt(self, input_data: EmotionStepInput) -> str:
        """Render the emotion prompt for the statement text in *input_data*."""
        prompt = await render_emotion_extraction_prompt(
            statement=input_data.statement_text,
            extract_keywords=self.extract_keywords,
            enable_subject=self.enable_subject,
            language=self.language,
        )
        return prompt

    async def call_llm(self, prompt: Any) -> Any:
        """Send *prompt* as a single user message, requesting ``EmotionExtraction``."""
        return await self.llm_client.response_structured(
            [{"role": "user", "content": prompt}],
            EmotionExtraction,
        )

    async def parse_response(
        self, raw_response: Any, input_data: EmotionStepInput
    ) -> EmotionStepOutput:
        """Copy the emotion fields off *raw_response*, defaulting missing ones."""
        emotion_type = getattr(raw_response, "emotion_type", "neutral")
        emotion_intensity = getattr(raw_response, "emotion_intensity", 0.0)
        emotion_keywords = getattr(raw_response, "emotion_keywords", [])
        return EmotionStepOutput(
            emotion_type=emotion_type,
            emotion_intensity=emotion_intensity,
            emotion_keywords=emotion_keywords,
        )

    def get_default_output(self) -> EmotionStepOutput:
        return EmotionStepOutput()
|
||||
@@ -0,0 +1,906 @@
|
||||
"""Refactored ExtractionOrchestrator using the unified ExtractionStep paradigm.
|
||||
|
||||
This module provides ``NewExtractionOrchestrator`` — a slimmed-down orchestrator
|
||||
(~900 lines vs ~2500) that delegates extraction work to concrete ExtractionStep
|
||||
instances and uses SidecarStepFactory for hot-pluggable sidecar modules.
|
||||
|
||||
The new orchestrator coexists with the legacy ``ExtractionOrchestrator`` until
|
||||
the team explicitly switches over.
|
||||
|
||||
Execution phases:
|
||||
1. Statement extraction + concurrent chunk/dialog embedding
|
||||
2. Triplet extraction + concurrent after_statement sidecars + statement embedding
|
||||
3. Entity embedding + concurrent after_triplet sidecars
|
||||
4. Data assignment back to dialog_data_list
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from app.core.memory.models.message_models import DialogData
|
||||
from app.core.memory.models.variate_config import ExtractionPipelineConfig
|
||||
|
||||
from .base import ExtractionStep, StepContext
|
||||
from .embedding_step import EmbeddingStep
|
||||
from .sidecar_factory import SidecarStepFactory, SidecarTiming
|
||||
from .statement_step import StatementExtractionStep
|
||||
from .triplet_step import TripletExtractionStep
|
||||
from .schema import (
|
||||
EmbeddingStepInput,
|
||||
EmbeddingStepOutput,
|
||||
EmotionStepInput,
|
||||
EmotionStepOutput,
|
||||
MessageItem,
|
||||
StatementStepInput,
|
||||
StatementStepOutput,
|
||||
SupportingContext,
|
||||
TripletStepInput,
|
||||
TripletStepOutput,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NewExtractionOrchestrator:
|
||||
"""Slimmed-down extraction orchestrator using the ExtractionStep paradigm.
|
||||
|
||||
Responsibilities:
|
||||
* Initialise all steps and sidecar groups via ``SidecarStepFactory``
|
||||
* Route data between stages (``_convert_to_*`` helpers)
|
||||
* Orchestrate concurrent execution (``_run_with_sidecars``)
|
||||
* Assign extracted results back to ``DialogData`` objects
|
||||
|
||||
The orchestrator does **not** own dedup, node/edge creation, or Neo4j writes.
|
||||
Those remain in ``WritePipeline`` / ``dedup_step``.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    llm_client: Any,
    embedder_client: Any,
    config: Optional[ExtractionPipelineConfig] = None,
    embedding_id: Optional[str] = None,
    ontology_types: Any = None,
    language: str = "zh",
    is_pilot_run: bool = False,
    progress_callback: Optional[
        Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]
    ] = None,
) -> None:
    """Wire up the shared context, the main-line steps and the sidecars.

    Args:
        llm_client: Client placed on the shared ``StepContext``.
        embedder_client: Client handed to ``EmbeddingStep``.
        config: Pipeline config; a default ``ExtractionPipelineConfig`` is
            created when falsy.
        embedding_id: Stored on the instance as-is.
        ontology_types: Forwarded to ``TripletExtractionStep``.
        language: Language code placed on the shared context.
        is_pilot_run: Forwarded to the context and to ``EmbeddingStep``.
        progress_callback: Optional async progress reporter.
    """
    effective_config = config or ExtractionPipelineConfig()
    self.config = effective_config
    self.is_pilot_run = is_pilot_run
    self.embedding_id = embedding_id
    self.progress_callback = progress_callback

    # One context object is shared by every LLM-based step.
    shared_context = StepContext(
        llm_client=llm_client,
        language=language,
        config=effective_config,
        is_pilot_run=is_pilot_run,
        progress_callback=progress_callback,
    )
    self.context = shared_context

    # Main-line steps.
    self.statement_step = StatementExtractionStep(shared_context)
    self.triplet_step = TripletExtractionStep(
        shared_context, ontology_types=ontology_types
    )

    # Embedding uses its own client rather than the LLM client.
    self.embedding_step = EmbeddingStep(
        embedder_client=embedder_client,
        is_pilot_run=is_pilot_run,
    )

    # Sidecars come from the factory, grouped by when they run.
    sidecars_by_timing = SidecarStepFactory.create_sidecars(
        effective_config, shared_context
    )
    self.after_statement_sidecars: List[ExtractionStep] = sidecars_by_timing[
        SidecarTiming.AFTER_STATEMENT
    ]
    self.after_triplet_sidecars: List[ExtractionStep] = sidecars_by_timing[
        SidecarTiming.AFTER_TRIPLET
    ]

    n_after_statement = len(self.after_statement_sidecars)
    n_after_triplet = len(self.after_triplet_sidecars)
    logger.info(
        "NewExtractionOrchestrator initialised — "
        "after_statement sidecars: %d, after_triplet sidecars: %d",
        n_after_statement,
        n_after_triplet,
    )
|
||||
|
||||
# ──────────────────────────────────────────────
|
||||
# 1. 并发执行引擎
|
||||
# 负责主线路 + 旁路的安全并发调度
|
||||
# ──────────────────────────────────────────────
|
||||
|
||||
@staticmethod
async def _run_sidecar_safe(
    step: ExtractionStep, input_data: Any
) -> Any:
    """Await ``step.run``; on any exception log it and return the step's default output."""
    try:
        outcome = await step.run(input_data)
    except Exception as exc:
        logger.warning(
            "Sidecar '%s' raised during gather — using default output: %s",
            step.name,
            exc,
        )
        return step.get_default_output()
    return outcome
|
||||
|
||||
async def _run_with_sidecars(
    self,
    critical_coro: Any,
    sidecars: List[Tuple[ExtractionStep, Any]],
    extra_coros: Optional[List[Any]] = None,
) -> Tuple[Any, List[Any], List[Any]]:
    """Await the critical coroutine together with sidecars and extras.

    Everything is scheduled in a single ``asyncio.gather`` call with
    ``return_exceptions=True``, so every awaitable finishes before any
    failure is inspected.

    Args:
        critical_coro: Awaitable of the main-line step.
        sidecars: ``(step, input_data)`` pairs; each is run through
            ``_run_sidecar_safe``.
        extra_coros: Optional additional awaitables (e.g. embedding
            generation) that are neither critical nor sidecars.

    Returns:
        ``(critical_result, sidecar_results, extra_results)``. Failed
        sidecars are replaced by their step's default output; failed
        extras are replaced by ``None``.

    Raises:
        BaseException: Whatever the critical coroutine raised, re-raised
            after all awaitables have completed.
    """
    guarded = [
        self._run_sidecar_safe(step, payload) for step, payload in sidecars
    ]
    extras = extra_coros or []

    gathered = await asyncio.gather(
        *([critical_coro] + guarded + extras), return_exceptions=True
    )

    # Layout of ``gathered``: [critical, *sidecars, *extras]
    boundary = 1 + len(guarded)
    critical_result = gathered[0]
    sidecar_results = list(gathered[1:boundary])
    extra_results = list(gathered[boundary:])

    if isinstance(critical_result, BaseException):
        raise critical_result

    # _run_sidecar_safe already absorbs Exception subclasses; this is a
    # second line of defence for anything gather still hands back as an
    # exception object.
    for pos, (step, _payload) in enumerate(sidecars):
        outcome = sidecar_results[pos]
        if not isinstance(outcome, BaseException):
            continue
        logger.warning(
            "Sidecar '%s' unexpected exception in gather: %s",
            step.name,
            outcome,
        )
        sidecar_results[pos] = step.get_default_output()

    for pos, outcome in enumerate(extra_results):
        if isinstance(outcome, BaseException):
            logger.warning("Extra coroutine %d failed: %s", pos, outcome)
            extra_results[pos] = None

    return critical_result, sidecar_results, extra_results
|
||||
|
||||
# ──────────────────────────────────────────────
|
||||
# 2. 阶段间数据转换
|
||||
# 将上一阶段的 StepOutput 转换为下一阶段的 StepInput
|
||||
# ──────────────────────────────────────────────
|
||||
|
||||
@staticmethod
def _build_supporting_context(
    dialog: DialogData,
) -> SupportingContext:
    """Wrap a dialog's content in a SupportingContext.

    The dialog's ``content`` is treated as one raw conversation string and
    becomes a single ``MessageItem`` with role ``"context"``. A dialog with
    no ``content`` attribute or falsy content yields an empty message list.
    """
    raw_text = getattr(dialog, "content", None)
    items: List[MessageItem] = (
        [MessageItem(role="context", msg=raw_text)] if raw_text else []
    )
    return SupportingContext(msgs=items)
|
||||
|
||||
@staticmethod
def _convert_to_triplet_input(
    stmt_out: StatementStepOutput,
    supporting_context: SupportingContext,
) -> TripletStepInput:
    """Build the triplet-step input for one extracted statement.

    Copies the statement's identity, text, type, temporal fields and
    speaker verbatim and attaches the given supporting context.
    """
    copied_fields = (
        "statement_id",
        "statement_text",
        "statement_type",
        "temporal_type",
        "speaker",
        "valid_at",
        "invalid_at",
    )
    payload = {attr: getattr(stmt_out, attr) for attr in copied_fields}
    return TripletStepInput(supporting_context=supporting_context, **payload)
|
||||
|
||||
@staticmethod
def _convert_to_emotion_input(
    stmt_out: StatementStepOutput,
) -> EmotionStepInput:
    """Build the emotion-step input for one extracted statement.

    Only the statement id, text and speaker are carried over.
    """
    payload = {
        attr: getattr(stmt_out, attr)
        for attr in ("statement_id", "statement_text", "speaker")
    }
    return EmotionStepInput(**payload)
|
||||
|
||||
# ──────────────────────────────────────────────
|
||||
# 3. 流水线执行入口
|
||||
# 公开接口 run() → 分发到 pilot / full 模式
|
||||
# ──────────────────────────────────────────────
|
||||
|
||||
async def run(
    self,
    dialog_data_list: List[DialogData],
) -> List[DialogData]:
    """Public entry point: run extraction over *dialog_data_list*.

    Dispatches to the pilot or the full variant depending on
    ``self.is_pilot_run`` and returns whatever that variant returns (the
    same list, mutated in place with the extracted data).

    Graph node/edge creation and deduplication are not performed here.
    """
    pilot = self.is_pilot_run
    logger.info(
        "Starting extraction pipeline (%s mode), %d dialogs",
        "pilot" if pilot else "full",
        len(dialog_data_list),
    )

    runner = self._run_pilot if pilot else self._run_full
    return await runner(dialog_data_list)
|
||||
|
||||
# ── 3a. 试运行模式:仅 statement + triplet,不生成 embedding 和旁路 ──
|
||||
|
||||
async def _run_pilot(
    self, dialog_data_list: List[DialogData]
) -> List[DialogData]:
    """Pilot variant: statements then triplets, nothing else.

    No sidecars run and no embeddings are generated; the assignment step
    receives an empty emotion mapping and ``None`` for embeddings.
    """
    # Phase 1 — statements (one task per chunk)
    logger.info("Pilot phase 1/2: Statement extraction")
    statements = await self._extract_all_statements(dialog_data_list)

    # Phase 2 — triplets (one task per statement)
    logger.info("Pilot phase 2/2: Triplet extraction")
    triplets = await self._extract_all_triplets(dialog_data_list, statements)

    # Write the results back onto the dialogs.
    self._assign_results(
        dialog_data_list,
        statements,
        triplets,
        emotion_results={},
        embedding_output=None,
    )

    # Keep the raw step outputs around for snapshot/debugging.
    self._last_stage_outputs = dict(
        statement_results=statements,
        triplet_results=triplets,
        emotion_results={},
        embedding_output=None,
    )

    logger.info("Pilot extraction complete")
    return dialog_data_list
|
||||
|
||||
# ── 3b. 正式模式:四阶段并发执行 ──
|
||||
|
||||
async def _run_full(
    self, dialog_data_list: List[DialogData]
) -> List[DialogData]:
    """Full mode: all four phases with concurrent sidecars and embeddings.

    Phases:
        1. Statement extraction concurrently with chunk/dialog embedding.
        2. Triplet extraction concurrently with after_statement sidecars
           and statement embedding.
        3. Entity embedding concurrently with after_triplet sidecars
           (none are wired at present).
        4. Assignment of all results back onto *dialog_data_list*.

    Statement and triplet extraction are critical: their failure is
    re-raised. Every embedding call is best-effort: a failure is logged
    and that part of the merged embedding output is simply left unset.

    Returns:
        The same *dialog_data_list*, mutated in place.
    """

    # ── Phase 1: Statement extraction + chunk/dialog embedding ──
    logger.info("Phase 1/4: Statement extraction + chunk/dialog embedding")
    chunk_dialog_emb_input = self._build_chunk_dialog_embedding_input(
        dialog_data_list
    )

    stmt_result, chunk_emb_result = await asyncio.gather(
        self._extract_all_statements(dialog_data_list),
        self.embedding_step.run(chunk_dialog_emb_input),
        return_exceptions=True,
    )

    # Statement extraction is critical — propagate its failure.
    if isinstance(stmt_result, BaseException):
        raise stmt_result
    all_stmt_results: Dict[str, Dict[str, List[StatementStepOutput]]] = (
        stmt_result
    )

    chunk_dialog_emb: Optional[EmbeddingStepOutput] = None
    if isinstance(chunk_emb_result, BaseException):
        logger.warning("Chunk/dialog embedding failed: %s", chunk_emb_result)
    else:
        chunk_dialog_emb = chunk_emb_result

    # ── Phase 2: Triplet extraction + after_statement sidecars + statement embedding ──
    logger.info(
        "Phase 2/4: Triplet extraction + sidecars + statement embedding"
    )
    stmt_emb_input = self._build_statement_embedding_input(
        dialog_data_list, all_stmt_results
    )

    # Build sidecar inputs for after_statement sidecars (e.g. emotion)
    sidecar_pairs = self._build_after_statement_sidecar_inputs(
        dialog_data_list, all_stmt_results
    )

    triplet_coro = self._extract_all_triplets(
        dialog_data_list, all_stmt_results
    )
    stmt_emb_coro = self.embedding_step.run(stmt_emb_input)

    all_triplet_results, sidecar_results, extra_results = (
        await self._run_with_sidecars(
            triplet_coro,
            sidecar_pairs,
            extra_coros=[stmt_emb_coro],
        )
    )
    stmt_emb: Optional[EmbeddingStepOutput] = (
        extra_results[0] if extra_results else None
    )

    # Collect sidecar outputs keyed by step name
    sidecar_steps = [step for step, _inp in sidecar_pairs]
    sidecar_output_map = self._collect_sidecar_outputs(
        sidecar_steps, sidecar_results
    )

    # ── Phase 3: Entity embedding + after_triplet sidecars ──
    logger.info("Phase 3/4: Entity embedding + after_triplet sidecars")
    entity_emb_input = self._build_entity_embedding_input(all_triplet_results)

    after_triplet_pairs: List[Tuple[ExtractionStep, Any]] = []
    # Future after_triplet sidecars would be wired here

    # Gather with return_exceptions=True so an embedding failure comes back
    # as a value (as in phase 1) instead of aborting the pipeline, and so
    # the embedding result is kept whether or not sidecars are present.
    phase3_results = await asyncio.gather(
        self.embedding_step.run(entity_emb_input),
        *(
            self._run_sidecar_safe(step, inp)
            for step, inp in after_triplet_pairs
        ),
        return_exceptions=True,
    )
    entity_emb_result = phase3_results[0]
    if isinstance(entity_emb_result, BaseException):
        logger.warning("Entity embedding failed: %s", entity_emb_result)
        entity_emb = None
    else:
        entity_emb = entity_emb_result

    # Merge all embedding outputs
    merged_emb = self._merge_embeddings(chunk_dialog_emb, stmt_emb, entity_emb)

    # ── Phase 4: Data assignment ──
    logger.info("Phase 4/4: Data assignment")
    emotion_results = sidecar_output_map.get("emotion_extraction", {})

    self._assign_results(
        dialog_data_list,
        all_stmt_results,
        all_triplet_results,
        emotion_results=emotion_results,
        embedding_output=merged_emb,
    )

    # Store raw step outputs for snapshot/debugging
    self._last_stage_outputs = {
        "statement_results": all_stmt_results,
        "triplet_results": all_triplet_results,
        "emotion_results": emotion_results,
        "embedding_output": merged_emb,
    }

    logger.info("Full extraction pipeline complete")
    return dialog_data_list
|
||||
|
||||
@property
def last_stage_outputs(self) -> Dict[str, Any]:
    """Raw step outputs recorded by the most recent run.

    Returns an empty dict when no run has stored anything yet.
    """
    try:
        return self._last_stage_outputs
    except AttributeError:
        return {}
|
||||
|
||||
# ──────────────────────────────────────────────
|
||||
# 4. 萃取执行器
|
||||
# chunk 级并行 statement 提取、statement 级并行 triplet 提取
|
||||
# ──────────────────────────────────────────────
|
||||
|
||||
async def _extract_all_statements(
    self,
    dialog_data_list: List[DialogData],
) -> Dict[str, Dict[str, List[StatementStepOutput]]]:
    """Extract statements from all chunks across all dialogs (chunk-level parallel).

    One ``statement_step.run`` task is created per chunk and all tasks are
    awaited together. A failed chunk is logged and mapped to an empty list
    rather than failing the whole batch. The ``speaker`` of every returned
    statement is overwritten with the chunk's ``speaker`` attribute
    (``"user"`` when the chunk has none).

    Returns:
        Nested dict: ``{dialog_id: {chunk_id: [StatementStepOutput, ...]}}``.
        Dialogs without chunks do not appear in the result.
    """
    tasks: List[Any] = []
    # One (dialog_id, chunk_id, speaker) entry per task, in task order.
    task_meta: List[Tuple[str, str, str]] = []

    for dialog in dialog_data_list:
        ctx = self._build_supporting_context(dialog)
        message_date = str(getattr(dialog, "created_at", "") or "")
        for chunk in dialog.chunks:
            inp = StatementStepInput(
                chunk_id=chunk.id,
                end_user_id=dialog.end_user_id,
                target_content=chunk.content,
                target_message_date=message_date,
                supporting_context=ctx,
            )
            tasks.append(self.statement_step.run(inp))
            task_meta.append(
                (dialog.id, chunk.id, getattr(chunk, "speaker", "user"))
            )

    results = await asyncio.gather(*tasks, return_exceptions=True)

    # Organise into nested dict
    stmt_map: Dict[str, Dict[str, List[StatementStepOutput]]] = {}
    for (dialog_id, chunk_id, speaker), result in zip(task_meta, results):
        per_dialog = stmt_map.setdefault(dialog_id, {})

        if isinstance(result, BaseException):
            logger.error("Statement extraction failed for chunk %s: %s", chunk_id, result)
            per_dialog[chunk_id] = []
            continue

        # Anything that is not a list is treated as "no statements".
        stmts: List[StatementStepOutput] = result if isinstance(result, list) else []
        # Override speaker from chunk metadata
        for s in stmts:
            s.speaker = speaker
        per_dialog[chunk_id] = stmts

    return stmt_map
|
||||
|
||||
async def _extract_all_triplets(
    self,
    dialog_data_list: List[DialogData],
    all_stmt_results: Dict[str, Dict[str, List[StatementStepOutput]]],
) -> Dict[str, Dict[str, TripletStepOutput]]:
    """Run triplet extraction once per statement, all concurrently.

    A statement whose extraction fails is logged and mapped to the triplet
    step's default output.

    Returns:
        ``{dialog_id: {statement_id: TripletStepOutput}}``
    """
    pending: List[Any] = []
    owners: List[Tuple[str, str]] = []  # (dialog_id, statement_id) per task

    for dialog in dialog_data_list:
        context = self._build_supporting_context(dialog)
        for stmts in all_stmt_results.get(dialog.id, {}).values():
            for stmt in stmts:
                step_input = self._convert_to_triplet_input(stmt, context)
                pending.append(self.triplet_step.run(step_input))
                owners.append((dialog.id, stmt.statement_id))

    outcomes = await asyncio.gather(*pending, return_exceptions=True)

    triplet_map: Dict[str, Dict[str, TripletStepOutput]] = {}
    for (dialog_id, stmt_id), outcome in zip(owners, outcomes):
        bucket = triplet_map.setdefault(dialog_id, {})
        if isinstance(outcome, BaseException):
            logger.error(
                "Triplet extraction failed for statement %s: %s",
                stmt_id,
                outcome,
            )
            outcome = self.triplet_step.get_default_output()
        bucket[stmt_id] = outcome

    return triplet_map
|
||||
|
||||
# ──────────────────────────────────────────────
|
||||
# 5. Embedding 输入构建器
|
||||
# 为不同阶段构建 EmbeddingStepInput(chunk/statement/entity)
|
||||
# ──────────────────────────────────────────────
|
||||
|
||||
@staticmethod
def _build_chunk_dialog_embedding_input(
    dialog_data_list: List[DialogData],
) -> EmbeddingStepInput:
    """Collect the texts to embed in phase 1.

    ``dialog_texts`` holds the content of every dialog that has non-empty
    content, in list order (dialogs without content are omitted).
    ``chunk_texts`` maps each chunk id to that chunk's content.
    """
    dialog_texts: List[str] = [
        dialog.content
        for dialog in dialog_data_list
        if hasattr(dialog, "content") and dialog.content
    ]
    chunk_texts: Dict[str, str] = {
        chunk.id: chunk.content
        for dialog in dialog_data_list
        for chunk in dialog.chunks
    }
    return EmbeddingStepInput(
        chunk_texts=chunk_texts,
        dialog_texts=dialog_texts,
    )
|
||||
|
||||
@staticmethod
def _build_statement_embedding_input(
    dialog_data_list: List[DialogData],
    all_stmt_results: Dict[str, Dict[str, List[StatementStepOutput]]],
) -> EmbeddingStepInput:
    """Collect the statement texts to embed in phase 2.

    Produces a ``statement_id -> statement_text`` mapping over every
    statement in *all_stmt_results*. *dialog_data_list* is not read.
    """
    texts_by_id: Dict[str, str] = {}
    for chunk_stmts in all_stmt_results.values():
        for stmts in chunk_stmts.values():
            texts_by_id.update(
                (stmt.statement_id, stmt.statement_text) for stmt in stmts
            )
    return EmbeddingStepInput(statement_texts=texts_by_id)
|
||||
|
||||
@staticmethod
def _build_entity_embedding_input(
    all_triplet_results: Dict[str, Dict[str, TripletStepOutput]],
) -> EmbeddingStepInput:
    """Collect the entity names and descriptions to embed in phase 3.

    Entities are keyed by ``"{entity_idx}_{name}"``; when the same key is
    seen again only the first occurrence is kept.
    """
    names: Dict[str, str] = {}
    descriptions: Dict[str, str] = {}

    for stmt_triplets in all_triplet_results.values():
        for triplet_out in stmt_triplets.values():
            for entity in triplet_out.entities:
                dedup_key = f"{entity.entity_idx}_{entity.name}"
                # ``names`` doubles as the record of keys already taken.
                if dedup_key in names:
                    continue
                names[dedup_key] = entity.name
                descriptions[dedup_key] = entity.description

    return EmbeddingStepInput(
        entity_names=names,
        entity_descriptions=descriptions,
    )
|
||||
|
||||
# ──────────────────────────────────────────────
|
||||
# 6. 旁路输入构建与结果收集
|
||||
# 为 after_statement / after_triplet 旁路构建输入,合并 embedding 输出
|
||||
# ──────────────────────────────────────────────
|
||||
|
||||
def _build_after_statement_sidecar_inputs(
    self,
    dialog_data_list: List[DialogData],
    all_stmt_results: Dict[str, Dict[str, List[StatementStepOutput]]],
) -> List[Tuple[ExtractionStep, Any]]:
    """Pair every after_statement sidecar with the input it should run on.

    Only statements whose speaker is ``"user"`` are considered.

    * The sidecar named ``"emotion_extraction"`` is wrapped in an
      ``_EmotionBatchWrapper`` holding all user statements and paired with
      a sentinel input (``statement_id="__batch__"``).
    * Any other sidecar is paired with an input built from the first user
      statement, and is left out when there are no user statements.

    Returns an empty list when no after_statement sidecars are configured.
    *dialog_data_list* is not read.
    """
    sidecars = self.after_statement_sidecars
    if not sidecars:
        return []

    user_stmts: List[StatementStepOutput] = [
        stmt
        for chunk_stmts in all_stmt_results.values()
        for stmts in chunk_stmts.values()
        for stmt in stmts
        if stmt.speaker == "user"
    ]

    pairs: List[Tuple[ExtractionStep, Any]] = []
    for sidecar in sidecars:
        if sidecar.name != "emotion_extraction":
            # Generic sidecar: the first user statement stands in as input.
            if user_stmts:
                representative = self._convert_to_emotion_input(user_stmts[0])
                pairs.append((sidecar, representative))
            continue

        # Emotion sidecar: the wrapper fans out over the statements itself,
        # so the paired input is only a placeholder.
        batch_step = _EmotionBatchWrapper(sidecar, user_stmts)
        sentinel = EmotionStepInput(
            statement_id="__batch__",
            statement_text="",
            speaker="",
        )
        pairs.append((batch_step, sentinel))

    return pairs
|
||||
|
||||
@staticmethod
def _collect_sidecar_outputs(
    sidecars: List[ExtractionStep],
    results: List[Any],
) -> Dict[str, Any]:
    """Key each sidecar result by the name of the step that produced it.

    Steps and results are matched by position; surplus entries on either
    side are ignored. If two steps share a name, the later one wins.
    """
    return {step.name: result for step, result in zip(sidecars, results)}
|
||||
|
||||
@staticmethod
def _merge_embeddings(
    chunk_dialog: Optional[EmbeddingStepOutput],
    statement: Optional[EmbeddingStepOutput],
    entity: Optional[Any],
) -> Optional[EmbeddingStepOutput]:
    """Combine the per-phase embedding outputs into one object.

    A fresh ``EmbeddingStepOutput`` is always returned. Each falsy source
    is skipped, leaving the corresponding fields at their defaults;
    *entity* is additionally ignored unless it is an
    ``EmbeddingStepOutput`` instance.
    """
    combined = EmbeddingStepOutput()
    sources = (
        (chunk_dialog, ("chunk_embeddings", "dialog_embeddings")),
        (statement, ("statement_embeddings",)),
        (
            entity if isinstance(entity, EmbeddingStepOutput) else None,
            ("entity_embeddings",),
        ),
    )
    for source, field_names in sources:
        if not source:
            continue
        for field_name in field_names:
            setattr(combined, field_name, getattr(source, field_name))
    return combined
|
||||
|
||||
# ──────────────────────────────────────────────
|
||||
# 7. 数据赋值
|
||||
# 将各阶段 StepOutput 组装为 Statement 对象,替换 chunk.statements
|
||||
# ──────────────────────────────────────────────
|
||||
|
||||
def _assign_results(
    self,
    dialog_data_list: List[DialogData],
    all_stmt_results: Dict[str, Dict[str, List[StatementStepOutput]]],
    all_triplet_results: Dict[str, Dict[str, TripletStepOutput]],
    emotion_results: Dict[str, EmotionStepOutput],
    embedding_output: Optional[EmbeddingStepOutput],
) -> None:
    """Assign extraction results back to dialog_data_list in-place.

    Replaces chunk.statements with new Statement objects built from step
    outputs, because the new orchestrator generates its own statement IDs
    that don't match the original chunk statement IDs.

    Args:
        dialog_data_list: Dialogs to mutate.
        all_stmt_results: ``{dialog_id: {chunk_id: [StatementStepOutput]}}``.
        all_triplet_results: ``{dialog_id: {statement_id: TripletStepOutput}}``.
        emotion_results: ``{statement_id: EmotionStepOutput}``; may be empty.
        embedding_output: Merged embeddings, or ``None`` to skip all
            embedding assignment.
    """
    from app.core.memory.models.message_models import (
        Statement,
        TemporalValidityRange,
    )
    from app.core.memory.models.triplet_models import (
        TripletExtractionResponse,
        Entity as TripletEntity,
        Triplet as TripletRelation,
    )
    from app.core.memory.utils.data.ontology import (
        RelevenceInfo,
        StatementType,
        TemporalInfo,
    )

    # Map string values to enums
    _STMT_TYPE_MAP = {
        "FACT": StatementType.FACT,
        "OPINION": StatementType.OPINION,
        "PREDICTION": StatementType.PREDICTION,
        "SUGGESTION": StatementType.SUGGESTION,
    }
    _TEMPORAL_MAP = {
        "STATIC": TemporalInfo.STATIC,
        "DYNAMIC": TemporalInfo.DYNAMIC,
        "ATEMPORAL": TemporalInfo.ATEMPORAL,
    }

    total_stmts = 0
    assigned_triplets = 0
    assigned_emotions = 0
    assigned_stmt_emb = 0
    assigned_chunk_emb = 0
    assigned_dialog_emb = 0

    dialog_embeddings = (
        embedding_output.dialog_embeddings if embedding_output else None
    )
    # Position of the next unused dialog embedding. The embedding input is
    # built only from dialogs that have content, so the cursor advances only
    # for those dialogs instead of using the dialog's list position.
    # NOTE(review): assumes the embedding step returns one vector per input
    # text, in input order — confirm against EmbeddingStep.
    dialog_emb_cursor = 0

    for dialog in dialog_data_list:
        dialog_stmts = all_stmt_results.get(dialog.id, {})
        dialog_triplets = all_triplet_results.get(dialog.id, {})

        # Assign dialog embedding
        has_content = bool(hasattr(dialog, "content") and dialog.content)
        if dialog_embeddings and has_content:
            if dialog_emb_cursor < len(dialog_embeddings):
                dialog.dialog_embedding = dialog_embeddings[dialog_emb_cursor]
                assigned_dialog_emb += 1
            dialog_emb_cursor += 1

        for chunk in dialog.chunks:
            # Assign chunk embedding
            if embedding_output and chunk.id in embedding_output.chunk_embeddings:
                chunk.chunk_embedding = embedding_output.chunk_embeddings[chunk.id]
                assigned_chunk_emb += 1

            # Build new Statement objects from step outputs
            chunk_stmt_outputs = dialog_stmts.get(chunk.id, [])
            new_statements = []

            for stmt_out in chunk_stmt_outputs:
                total_stmts += 1

                # Temporal validity: the literal string "NULL" means "unset"
                valid_at = stmt_out.valid_at if stmt_out.valid_at != "NULL" else None
                invalid_at = stmt_out.invalid_at if stmt_out.invalid_at != "NULL" else None

                # Triplet info (only when something was actually extracted)
                triplet_info = None
                triplet_out = dialog_triplets.get(stmt_out.statement_id)
                if triplet_out and (triplet_out.entities or triplet_out.triplets):
                    entities = [
                        TripletEntity(
                            entity_idx=e.entity_idx,
                            name=e.name,
                            type=e.type,
                            description=e.description,
                            is_explicit_memory=e.is_explicit_memory,
                        )
                        for e in triplet_out.entities
                    ]
                    triplets = [
                        TripletRelation(
                            subject_name=t.subject_name,
                            subject_id=t.subject_id,
                            predicate=t.predicate,
                            object_name=t.object_name,
                            object_id=t.object_id,
                        )
                        for t in triplet_out.triplets
                    ]
                    triplet_info = TripletExtractionResponse(
                        entities=entities, triplets=triplets,
                    )
                    assigned_triplets += 1

                # Emotion info
                emo = emotion_results.get(stmt_out.statement_id)
                emotion_kwargs = {}
                if emo:
                    emotion_kwargs = {
                        "emotion_type": emo.emotion_type,
                        "emotion_intensity": emo.emotion_intensity,
                        "emotion_keywords": emo.emotion_keywords,
                    }
                    assigned_emotions += 1

                # Statement embedding
                stmt_embedding = None
                if (
                    embedding_output
                    and stmt_out.statement_id in embedding_output.statement_embeddings
                ):
                    stmt_embedding = embedding_output.statement_embeddings[stmt_out.statement_id]
                    assigned_stmt_emb += 1

                # Build the Statement object that _create_nodes_and_edges expects.
                # Unknown type strings fall back to FACT / ATEMPORAL.
                stmt = Statement(
                    id=stmt_out.statement_id,
                    chunk_id=chunk.id,
                    end_user_id=dialog.end_user_id,
                    statement=stmt_out.statement_text,
                    speaker=stmt_out.speaker,
                    stmt_type=_STMT_TYPE_MAP.get(stmt_out.statement_type, StatementType.FACT),
                    temporal_info=_TEMPORAL_MAP.get(stmt_out.temporal_type, TemporalInfo.ATEMPORAL),
                    relevence_info=RelevenceInfo.RELEVANT if stmt_out.relevance == "RELEVANT" else RelevenceInfo.IRRELEVANT,
                    temporal_validity=TemporalValidityRange(valid_at=valid_at, invalid_at=invalid_at),
                    triplet_extraction_info=triplet_info,
                    statement_embedding=stmt_embedding,
                    **emotion_kwargs,
                )
                new_statements.append(stmt)

            # Replace chunk.statements with newly built objects
            chunk.statements = new_statements

    logger.info(
        "Data assignment complete — statements: %d, triplets: %d, "
        "emotions: %d, stmt_emb: %d, chunk_emb: %d, dialog_emb: %d",
        total_stmts,
        assigned_triplets,
        assigned_emotions,
        assigned_stmt_emb,
        assigned_chunk_emb,
        assigned_dialog_emb,
    )
|
||||
|
||||
|
||||
class _EmotionBatchWrapper(ExtractionStep):
    """Batch adapter around a single-statement emotion sidecar step.

    Adapts the wrapped step to the ``_run_with_sidecars`` interface: the
    orchestrator passes a sentinel input, ``run()`` ignores it and instead
    calls the inner step once per pre-collected statement, concurrently,
    returning a ``{statement_id: EmotionStepOutput}`` dict.

    TODO: revisit whether this should subclass ``ExtractionStep`` at all
    (open question carried over from the original author's note).
    """

    # ── Initialisation ──

    def __init__(
        self,
        inner_step: ExtractionStep,
        statements: List[StatementStepOutput],
    ) -> None:
        # super().__init__() is deliberately not called: this class is a thin
        # wrapper and does not need a StepContext.
        self._inner = inner_step
        self._statements = statements

    # ── Step identity (satisfies the ExtractionStep interface) ──

    @property
    def name(self) -> str:
        """Name of the wrapped step, so results are keyed under it."""
        return self._inner.name

    @property
    def is_critical(self) -> bool:
        """Always ``False``: the wrapper is a sidecar."""
        return False

    def get_default_output(self) -> Dict[str, EmotionStepOutput]:
        """Return an empty mapping, i.e. no emotions for any statement."""
        return {}

    # ── Unused lifecycle hooks (the batch wrapper bypasses render → call → parse) ──

    async def render_prompt(self, input_data: Any) -> Any:
        raise NotImplementedError

    async def call_llm(self, prompt: Any) -> Any:
        raise NotImplementedError

    async def parse_response(self, raw_response: Any, input_data: Any) -> Any:
        raise NotImplementedError

    # ── Batch entry point ──

    async def run(self, input_data: Any) -> Dict[str, EmotionStepOutput]:
        """Run emotion extraction for every pre-collected statement concurrently.

        Args:
            input_data: Sentinel supplied by the orchestrator; ignored.

        Returns:
            ``{statement_id: EmotionStepOutput}``. A statement whose
            extraction raises gets the inner step's default output, and
            the failure is logged.
        """
        if not self._statements:
            return {}

        async def _extract_one(stmt: StatementStepOutput) -> Tuple[str, EmotionStepOutput]:
            inp = EmotionStepInput(
                statement_id=stmt.statement_id,
                statement_text=stmt.statement_text,
                speaker=stmt.speaker,
            )
            try:
                result = await self._inner.run(inp)
            except Exception as exc:
                # Best-effort per statement: keep the batch alive, but leave
                # a trace so silent degradation can be diagnosed.
                logger.warning(
                    "Sidecar '%s' failed for statement %s — using default output: %s",
                    self._inner.name,
                    stmt.statement_id,
                    exc,
                )
                result = self._inner.get_default_output()
            return stmt.statement_id, result

        pairs = await asyncio.gather(
            *[_extract_one(s) for s in self._statements]
        )
        return dict(pairs)
|
||||
@@ -0,0 +1,366 @@
|
||||
"""
|
||||
GraphBuildStep — 从 DialogData 构建 Neo4j 图节点和边。
|
||||
|
||||
职责:
|
||||
- 遍历 DialogData 列表,构建 DialogueNode、ChunkNode、StatementNode、
|
||||
ExtractedEntityNode、PerceptualNode 及各类 Edge
|
||||
- 不涉及 LLM 调用、去重、Neo4j 写入
|
||||
|
||||
依赖:
|
||||
- embedder_client(可选):为 PerceptualNode 生成 summary embedding
|
||||
- progress_callback(可选):流式输出关系创建进度
|
||||
|
||||
从 ExtractionOrchestrator._create_nodes_and_edges() 提取而来,
|
||||
旧编排器保留原方法不变,新旧流水线完全隔离。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
||||
|
||||
from app.core.memory.models.graph_models import (
|
||||
ChunkNode,
|
||||
DialogueNode,
|
||||
EntityEntityEdge,
|
||||
ExtractedEntityNode,
|
||||
PerceptualEdge,
|
||||
PerceptualNode,
|
||||
StatementChunkEdge,
|
||||
StatementEntityEdge,
|
||||
StatementNode,
|
||||
)
|
||||
from app.core.memory.models.message_models import DialogData, TemporalInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GraphBuildResult:
    """Container for everything the graph-build step produces.

    Holds five node lists and four edge lists. Instances use ``__slots__``,
    so no attributes beyond those listed can be set.
    """

    __slots__ = (
        # nodes
        "dialogue_nodes",
        "chunk_nodes",
        "statement_nodes",
        "entity_nodes",
        "perceptual_nodes",
        # edges
        "stmt_chunk_edges",
        "stmt_entity_edges",
        "entity_entity_edges",
        "perceptual_edges",
    )

    def __init__(
        self,
        dialogue_nodes: List[DialogueNode],
        chunk_nodes: List[ChunkNode],
        statement_nodes: List[StatementNode],
        entity_nodes: List[ExtractedEntityNode],
        perceptual_nodes: List[PerceptualNode],
        stmt_chunk_edges: List[StatementChunkEdge],
        stmt_entity_edges: List[StatementEntityEdge],
        entity_entity_edges: List[EntityEntityEdge],
        perceptual_edges: List[PerceptualEdge],
    ):
        """Store the given node and edge lists as-is (no copying)."""
        # Edge collections
        self.perceptual_edges = perceptual_edges
        self.entity_entity_edges = entity_entity_edges
        self.stmt_entity_edges = stmt_entity_edges
        self.stmt_chunk_edges = stmt_chunk_edges

        # Node collections
        self.perceptual_nodes = perceptual_nodes
        self.entity_nodes = entity_nodes
        self.statement_nodes = statement_nodes
        self.chunk_nodes = chunk_nodes
        self.dialogue_nodes = dialogue_nodes
|
||||
|
||||
|
||||
async def build_graph_nodes_and_edges(
    dialog_data_list: List[DialogData],
    embedder_client: Any = None,
    progress_callback: Optional[
        Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]
    ] = None,
) -> GraphBuildResult:
    """Build all graph nodes and edges from a list of ``DialogData``.

    For every dialog this creates one ``DialogueNode``; for every chunk a
    ``ChunkNode``; for every file attached to a chunk a ``PerceptualNode``
    plus a ``PerceptualEdge`` to the chunk; for every statement a
    ``StatementNode`` plus a ``StatementChunkEdge``; and, when the statement
    carries triplet extraction info, ``ExtractedEntityNode`` objects
    (de-duplicated by entity id across the whole call),
    ``StatementEntityEdge`` objects and ``EntityEntityEdge`` objects.

    Args:
        dialog_data_list: DialogData objects after extraction and data
            assignment.
        embedder_client: Optional embedding client. When given,
            ``await embedder_client.response([summary])`` is called for each
            perceptual item with a summary and the first returned element is
            stored as ``summary_embedding``. Embedding failures are logged as
            warnings and leave the embedding as ``None``.
        progress_callback: Optional async callback invoked as
            ``(stage, message, payload)``. It is called for each of the first
            ten entity-entity edges and once more when building completes.

    Returns:
        A ``GraphBuildResult`` holding every node and edge list.
    """
    logger.info("开始创建节点和边")

    dialogue_nodes: List[DialogueNode] = []
    chunk_nodes: List[ChunkNode] = []
    statement_nodes: List[StatementNode] = []
    entity_nodes: List[ExtractedEntityNode] = []
    perceptual_nodes: List[PerceptualNode] = []
    stmt_chunk_edges: List[StatementChunkEdge] = []
    stmt_entity_edges: List[StatementEntityEdge] = []
    entity_entity_edges: List[EntityEntityEdge] = []
    perceptual_edges: List[PerceptualEdge] = []

    # Entity ids already turned into nodes; shared across all dialogs so an
    # entity referenced by several statements yields a single node.
    entity_id_set: set = set()
    total_dialogs = len(dialog_data_list)

    for processed_dialogs, dialog_data in enumerate(dialog_data_list, start=1):
        # TODO(乐力齐): revisit the node/edge construction below once the
        # refactored pipeline has been switched to production and is stable.

        # config_id is optional on DialogData; resolve it once per dialog.
        config_id = getattr(dialog_data, "config_id", None)

        # ── Dialogue node ──
        dialogue_nodes.append(
            DialogueNode(
                id=dialog_data.id,
                name=f"Dialog_{dialog_data.id}",
                ref_id=dialog_data.ref_id,
                end_user_id=dialog_data.end_user_id,
                run_id=dialog_data.run_id,
                content=dialog_data.context.content if dialog_data.context else "",
                dialog_embedding=getattr(dialog_data, "dialog_embedding", None),
                created_at=dialog_data.created_at,
                expired_at=dialog_data.expired_at,
                metadata=dialog_data.metadata,
                config_id=config_id,
            )
        )

        # ── Chunk nodes ──
        for chunk_idx, chunk in enumerate(dialog_data.chunks):
            chunk_nodes.append(
                ChunkNode(
                    id=chunk.id,
                    name=f"Chunk_{chunk.id}",
                    dialog_id=dialog_data.id,
                    end_user_id=dialog_data.end_user_id,
                    run_id=dialog_data.run_id,
                    content=chunk.content,
                    speaker=getattr(chunk, "speaker", None),
                    chunk_embedding=chunk.chunk_embedding,
                    sequence_number=chunk_idx,
                    created_at=dialog_data.created_at,
                    expired_at=dialog_data.expired_at,
                    metadata=chunk.metadata,
                )
            )

            # ── Perceptual nodes (files attached to the chunk) ──
            for p, file_type in chunk.files:
                meta = p.meta_data or {}
                content_meta = meta.get("content", {})

                summary_embedding = None
                if embedder_client and p.summary:
                    # Best effort: a failed embedding must not abort the build.
                    try:
                        summary_embedding = (await embedder_client.response([p.summary]))[0]
                    except Exception as emb_err:
                        logger.warning(f"Failed to embed perceptual summary: {emb_err}")

                perceptual = PerceptualNode(
                    name=f"Perceptual_{p.id}",
                    id=str(p.id),
                    end_user_id=str(p.end_user_id),
                    perceptual_type=p.perceptual_type,
                    file_path=p.file_path or "",
                    file_name=p.file_name or "",
                    file_ext=p.file_ext or "",
                    summary=p.summary or "",
                    keywords=content_meta.get("keywords", []),
                    topic=content_meta.get("topic", ""),
                    domain=content_meta.get("domain", ""),
                    created_at=p.created_time.isoformat() if p.created_time else None,
                    file_type=file_type,
                    summary_embedding=summary_embedding,
                )
                perceptual_nodes.append(perceptual)
                perceptual_edges.append(
                    PerceptualEdge(
                        source=perceptual.id,
                        target=chunk.id,
                        end_user_id=dialog_data.end_user_id,
                        run_id=dialog_data.run_id,
                        created_at=dialog_data.created_at,
                    )
                )

            # ── Statement nodes + statement→chunk edges ──
            for statement in chunk.statements:
                temporal_validity = getattr(statement, "temporal_validity", None)
                statement_nodes.append(
                    StatementNode(
                        id=statement.id,
                        name=f"Statement_{statement.id}",
                        chunk_id=chunk.id,
                        stmt_type=getattr(statement, "stmt_type", "general"),
                        temporal_info=getattr(statement, "temporal_info", TemporalInfo.ATEMPORAL),
                        connect_strength=(
                            statement.connect_strength
                            if statement.connect_strength is not None
                            else "Strong"
                        ),
                        end_user_id=dialog_data.end_user_id,
                        run_id=dialog_data.run_id,
                        statement=statement.statement,
                        speaker=getattr(statement, "speaker", None),
                        statement_embedding=statement.statement_embedding,
                        valid_at=temporal_validity.valid_at if temporal_validity else None,
                        invalid_at=temporal_validity.invalid_at if temporal_validity else None,
                        created_at=dialog_data.created_at,
                        expired_at=dialog_data.expired_at,
                        config_id=config_id,
                        emotion_type=getattr(statement, "emotion_type", None),
                        emotion_intensity=getattr(statement, "emotion_intensity", None),
                        emotion_keywords=getattr(statement, "emotion_keywords", None),
                        emotion_subject=getattr(statement, "emotion_subject", None),
                        emotion_target=getattr(statement, "emotion_target", None),
                    )
                )

                stmt_chunk_edges.append(
                    StatementChunkEdge(
                        source=statement.id,
                        target=chunk.id,
                        end_user_id=dialog_data.end_user_id,
                        run_id=dialog_data.run_id,
                        created_at=dialog_data.created_at,
                    )
                )

                # ── Triplets → entity nodes + edges ──
                if not statement.triplet_extraction_info:
                    continue

                triplet_info = statement.triplet_extraction_info
                entities = triplet_info.entities

                # Map triplet subject/object indices to entity ids. The list
                # position is only a fallback; the explicit ``entity_idx``
                # always wins. (Previously the positional assignment ran after
                # the explicit one inside the same loop and could overwrite
                # another entity's explicit index, e.g. with 1-based or sparse
                # indices, linking triplets to the wrong entities.)
                entity_idx_to_id: Dict[int, str] = {
                    position: entity.id for position, entity in enumerate(entities)
                }
                entity_idx_to_id.update(
                    {entity.entity_idx: entity.id for entity in entities}
                )

                for entity in entities:
                    # Missing or None connect_strength both default to "Strong".
                    entity_connect_strength = getattr(entity, "connect_strength", None)
                    if entity_connect_strength is None:
                        entity_connect_strength = "Strong"

                    if entity.id not in entity_id_set:
                        entity_nodes.append(
                            ExtractedEntityNode(
                                id=entity.id,
                                name=getattr(entity, "name", f"Entity_{entity.id}"),
                                entity_idx=entity.entity_idx,
                                statement_id=statement.id,
                                entity_type=getattr(entity, "type", "unknown"),
                                description=getattr(entity, "description", ""),
                                example=getattr(entity, "example", ""),
                                connect_strength=entity_connect_strength,
                                aliases=getattr(entity, "aliases", []) or [],
                                name_embedding=getattr(entity, "name_embedding", None),
                                is_explicit_memory=getattr(entity, "is_explicit_memory", False),
                                end_user_id=dialog_data.end_user_id,
                                run_id=dialog_data.run_id,
                                created_at=dialog_data.created_at,
                                expired_at=dialog_data.expired_at,
                                config_id=config_id,
                            )
                        )
                        entity_id_set.add(entity.id)

                    # One statement→entity edge per (statement, entity) pair,
                    # even when the entity node already existed.
                    stmt_entity_edges.append(
                        StatementEntityEdge(
                            source=statement.id,
                            target=entity.id,
                            connect_strength=entity_connect_strength,
                            end_user_id=dialog_data.end_user_id,
                            run_id=dialog_data.run_id,
                            created_at=dialog_data.created_at,
                        )
                    )

                for triplet in triplet_info.triplets:
                    subject_entity_id = entity_idx_to_id.get(triplet.subject_id)
                    object_entity_id = entity_idx_to_id.get(triplet.object_id)

                    if subject_entity_id and object_entity_id:
                        entity_entity_edges.append(
                            EntityEntityEdge(
                                source=subject_entity_id,
                                target=object_entity_id,
                                relation_type=triplet.predicate,
                                statement=statement.statement,
                                source_statement_id=statement.id,
                                end_user_id=dialog_data.end_user_id,
                                run_id=dialog_data.run_id,
                                created_at=dialog_data.created_at,
                                expired_at=dialog_data.expired_at,
                            )
                        )

                        # Stream only the first ten relationships to the caller.
                        if progress_callback and len(entity_entity_edges) <= 10:
                            relationship_result = {
                                "result_type": "relationship_creation",
                                "relationship_index": len(entity_entity_edges),
                                "source_entity": triplet.subject_name,
                                "relation_type": triplet.predicate,
                                "target_entity": triplet.object_name,
                                "relationship_text": f"{triplet.subject_name} -[{triplet.predicate}]-> {triplet.object_name}",
                                "dialog_progress": f"{processed_dialogs}/{total_dialogs}",
                            }
                            await progress_callback(
                                "creating_nodes_edges_result",
                                f"关系创建中 ({processed_dialogs}/{total_dialogs})",
                                relationship_result,
                            )
                    else:
                        # The triplet references an index with no known entity.
                        missing_subject = "subject" if not subject_entity_id else ""
                        missing_object = "object" if not object_entity_id else ""
                        missing_both = " and " if (not subject_entity_id and not object_entity_id) else ""
                        logger.debug(
                            f"跳过三元组 - 无法找到{missing_subject}{missing_both}{missing_object}实体ID: "
                            f"subject_id={triplet.subject_id} ({triplet.subject_name}), "
                            f"object_id={triplet.object_id} ({triplet.object_name}), "
                            f"predicate={triplet.predicate}, "
                            f"statement_id={statement.id}, "
                            f"available_indices={sorted(entity_idx_to_id.keys())}"
                        )

    logger.info(
        f"节点和边创建完成 - 对话节点: {len(dialogue_nodes)}, "
        f"分块节点: {len(chunk_nodes)}, 陈述句节点: {len(statement_nodes)}, "
        f"实体节点: {len(entity_nodes)}, 陈述句-分块边: {len(stmt_chunk_edges)}, "
        f"陈述句-实体边: {len(stmt_entity_edges)}, "
        f"实体-实体边: {len(entity_entity_edges)}"
    )

    if progress_callback:
        nodes_edges_stats = {
            "dialogue_nodes_count": len(dialogue_nodes),
            "chunk_nodes_count": len(chunk_nodes),
            "statement_nodes_count": len(statement_nodes),
            "entity_nodes_count": len(entity_nodes),
            "statement_chunk_edges_count": len(stmt_chunk_edges),
            "statement_entity_edges_count": len(stmt_entity_edges),
            "entity_entity_edges_count": len(entity_entity_edges),
        }
        await progress_callback("creating_nodes_edges_complete", "创建节点和边完成", nodes_edges_stats)

    return GraphBuildResult(
        dialogue_nodes=dialogue_nodes,
        chunk_nodes=chunk_nodes,
        statement_nodes=statement_nodes,
        entity_nodes=entity_nodes,
        perceptual_nodes=perceptual_nodes,
        stmt_chunk_edges=stmt_chunk_edges,
        stmt_entity_edges=stmt_entity_edges,
        entity_entity_edges=entity_entity_edges,
        perceptual_edges=perceptual_edges,
    )
|
||||
@@ -0,0 +1,42 @@
|
||||
"""Schema package for ExtractionStep inputs and outputs.
|
||||
|
||||
Re-exports all models for convenient access:
|
||||
from .schema import StatementStepInput, EmotionStepOutput, ...
|
||||
"""
|
||||
|
||||
from .extraction_step_schema import (
|
||||
EmbeddingStepInput,
|
||||
EmbeddingStepOutput,
|
||||
EntityItem,
|
||||
MessageItem,
|
||||
StatementStepInput,
|
||||
StatementStepOutput,
|
||||
SupportingContext,
|
||||
TripletItem,
|
||||
TripletStepInput,
|
||||
TripletStepOutput,
|
||||
)
|
||||
from .sidecar_step_schema import (
|
||||
EmotionStepInput,
|
||||
EmotionStepOutput,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Shared
|
||||
"MessageItem",
|
||||
"SupportingContext",
|
||||
# Statement
|
||||
"StatementStepInput",
|
||||
"StatementStepOutput",
|
||||
# Triplet
|
||||
"TripletStepInput",
|
||||
"TripletStepOutput",
|
||||
"EntityItem",
|
||||
"TripletItem",
|
||||
# Embedding
|
||||
"EmbeddingStepInput",
|
||||
"EmbeddingStepOutput",
|
||||
# Sidecar — Emotion
|
||||
"EmotionStepInput",
|
||||
"EmotionStepOutput",
|
||||
]
|
||||
@@ -0,0 +1,115 @@
|
||||
"""Pydantic models for base extraction pipeline inputs and outputs.
|
||||
|
||||
Covers the core (critical) stages: Statement extraction, Triplet extraction,
|
||||
Embedding generation, and shared types used across stages.
|
||||
|
||||
Malformed LLM JSON will raise ``ValidationError`` and trigger stage-level retry.
|
||||
"""
|
||||
|
||||
from typing import Dict, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ── Shared types ──
|
||||
|
||||
|
||||
class MessageItem(BaseModel):
    """Single conversation message."""

    # Speaker label of the message.
    role: str  # "User" / "Assistant"
    # Text content of the message.
    msg: str
|
||||
|
||||
|
||||
class SupportingContext(BaseModel):
    """Dialogue context window (used for pronoun resolution, etc.)."""

    # Messages forming the context window; defaults to a fresh empty list.
    msgs: List[MessageItem] = Field(default_factory=list)
|
||||
|
||||
|
||||
# ── Statement extraction ──
|
||||
class StatementStepInput(BaseModel):
    """Input for StatementExtractionStep."""

    # Identifier of the chunk the statements are extracted from.
    chunk_id: str
    # Identifier of the end user the chunk belongs to.
    end_user_id: str
    # Text to extract statements from.
    target_content: str
    # Date of the target message, carried as a string.
    # NOTE(review): exact format is not defined here — confirm against the caller.
    target_message_date: str
    # Surrounding dialogue messages supplied as context.
    supporting_context: SupportingContext
|
||||
|
||||
|
||||
class StatementStepOutput(BaseModel):
    """Single extracted statement (including temporal info)."""

    # Unique identifier of the statement.
    statement_id: str
    # The statement text itself.
    statement_text: str
    statement_type: str  # FACT / OPINION / PREDICTION / SUGGESTION
    temporal_type: str  # STATIC / DYNAMIC / ATEMPORAL
    relevance: str  # RELEVANT / IRRELEVANT
    speaker: str  # "user" / "assistant"
    # Validity window; the string "NULL" stands for "no value".
    valid_at: str  # ISO 8601 or "NULL"
    invalid_at: str  # ISO 8601 or "NULL"
|
||||
|
||||
|
||||
# ── Triplet extraction ──
|
||||
class TripletStepInput(BaseModel):
    """Input for TripletExtractionStep."""

    # The fields below mirror StatementStepOutput (all except ``relevance``),
    # plus the dialogue context used when rendering the prompt.
    statement_id: str
    statement_text: str
    statement_type: str
    temporal_type: str
    # Surrounding dialogue messages supplied as context.
    supporting_context: SupportingContext
    speaker: str
    valid_at: str
    invalid_at: str
|
||||
|
||||
|
||||
class EntityItem(BaseModel):
    """Single entity extracted during triplet extraction."""

    # Index of the entity; TripletItem.subject_id / object_id are also ints.
    # NOTE(review): presumably those ids refer to this index — confirm.
    entity_idx: int
    # Entity name.
    name: str
    # Entity type label.
    type: str
    # Free-text description of the entity.
    description: str
    # Flag marking the entity as explicit memory; defaults to False.
    is_explicit_memory: bool = False
|
||||
|
||||
|
||||
class TripletItem(BaseModel):
    """Single triplet (subject-predicate-object) relationship."""

    # Subject side: display name and integer id.
    subject_name: str
    subject_id: int
    # Relation between subject and object.
    predicate: str
    # Object side: display name and integer id.
    object_name: str
    object_id: int
|
||||
|
||||
|
||||
class TripletStepOutput(BaseModel):
    """Output of TripletExtractionStep."""

    # Both lists default to fresh empty lists.
    entities: List[EntityItem] = Field(default_factory=list)
    triplets: List[TripletItem] = Field(default_factory=list)
|
||||
|
||||
|
||||
# ── Embedding generation ──
|
||||
class EmbeddingStepInput(BaseModel):
    """Input for EmbeddingStep.

    Each dict maps an ID to the text that should be embedded.
    Fields can be left empty for partial embedding runs.
    """

    # Every field defaults to an empty container.
    statement_texts: Dict[str, str] = Field(default_factory=dict)
    chunk_texts: Dict[str, str] = Field(default_factory=dict)
    # Unlike the other fields, dialog texts are a plain list without IDs.
    dialog_texts: List[str] = Field(default_factory=list)
    entity_names: Dict[str, str] = Field(default_factory=dict)
    entity_descriptions: Dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class EmbeddingStepOutput(BaseModel):
    """Output of EmbeddingStep."""

    # Dict fields map a string key to an embedding vector (list of floats);
    # every field defaults to an empty container.
    statement_embeddings: Dict[str, List[float]] = Field(default_factory=dict)
    chunk_embeddings: Dict[str, List[float]] = Field(default_factory=dict)
    # Dialog embeddings are a plain list of vectors.
    # NOTE(review): presumably in the same order as
    # EmbeddingStepInput.dialog_texts — confirm against EmbeddingStep.
    dialog_embeddings: List[List[float]] = Field(default_factory=list)
    entity_embeddings: Dict[str, List[float]] = Field(default_factory=dict)
|
||||
@@ -0,0 +1,26 @@
|
||||
"""Pydantic models for hot-pluggable sidecar step inputs and outputs.
|
||||
|
||||
Sidecar steps are non-critical (is_critical=False) modules registered via
|
||||
``@SidecarStepFactory.register`` that run concurrently alongside the main
|
||||
extraction pipeline. Failures degrade gracefully to default outputs.
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ── Emotion extraction (sidecar) ──
|
||||
class EmotionStepInput(BaseModel):
    """Input for EmotionExtractionStep."""

    # Identifier of the statement to analyse.
    statement_id: str
    # Text of the statement.
    statement_text: str
    # Speaker of the statement.
    speaker: str
|
||||
|
||||
|
||||
class EmotionStepOutput(BaseModel):
    """Output of EmotionExtractionStep."""

    # An instance built with no arguments is a neutral result:
    # type "neutral", intensity 0.0, no keywords.
    emotion_type: str = "neutral"
    emotion_intensity: float = 0.0
    emotion_keywords: List[str] = Field(default_factory=list)
|
||||
@@ -0,0 +1,97 @@
|
||||
"""SidecarStepFactory — decorator-based registry for sidecar (non-critical) steps.
|
||||
|
||||
New sidecar modules self-register via ``@SidecarStepFactory.register`` and are
|
||||
automatically discovered and instantiated by the orchestrator without any
|
||||
changes to orchestrator code.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Tuple, Type
|
||||
|
||||
from .base import ExtractionStep, StepContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SidecarTiming(str, Enum):
    """Point in the main pipeline after which a sidecar step is run.

    Members are also ``str`` instances, so they compare equal to their
    plain string values.
    """

    # Run after the statement extraction stage.
    AFTER_STATEMENT = "after_statement"
    # Run after the triplet extraction stage.
    AFTER_TRIPLET = "after_triplet"
|
||||
|
||||
|
||||
class SidecarStepFactory:
    """Factory that manages sidecar step registration and creation.

    The class-level registry maps ``config_key`` → ``(step_class, timing)``.
    Adding a new sidecar only requires the ``@register`` decorator on the
    step class — no orchestrator modifications needed.

    Each ``config_key`` holds at most one sidecar: registering a second class
    under the same key replaces the first (a warning is logged).
    """

    # Shared by all users of the factory; filled when decorated classes are defined.
    _registry: Dict[str, Tuple[Type[ExtractionStep], SidecarTiming]] = {}

    @classmethod
    def register(cls, config_key: str, timing: SidecarTiming):
        """Class decorator that registers a sidecar step.

        Args:
            config_key: Configuration flag name (e.g. ``"emotion_enabled"``).
                The step is instantiated only when this flag is truthy.
            timing: When the sidecar runs relative to the main pipeline.
                A plain string equal to a member value is accepted as well.

        Returns:
            A decorator that returns the original class, unmodified.

        Raises:
            ValueError: If ``timing`` is not a valid ``SidecarTiming`` value.
        """
        # Normalise up front so an invalid timing fails at decoration time
        # instead of surfacing later inside the decorator or the orchestrator.
        resolved_timing = SidecarTiming(timing)

        def decorator(step_class: Type[ExtractionStep]):
            previous = cls._registry.get(config_key)
            # Re-registering the same class is harmless; only a different
            # class taking over the key is worth a warning.
            if previous is not None and previous[0] is not step_class:
                logger.warning(
                    "Sidecar config_key '%s' re-registered: '%s' replaces '%s'",
                    config_key,
                    step_class.__name__,
                    previous[0].__name__,
                )
            cls._registry[config_key] = (step_class, resolved_timing)
            logger.debug(
                "Registered sidecar '%s' (config_key=%s, timing=%s)",
                step_class.__name__,
                config_key,
                resolved_timing.value,
            )
            return step_class

        return decorator

    @classmethod
    def create_sidecars(
        cls, config: Any, context: StepContext
    ) -> Dict[SidecarTiming, List[ExtractionStep]]:
        """Instantiate enabled sidecar steps, grouped by timing.

        Args:
            config: Pipeline configuration object. Each registered
                ``config_key`` is looked up via ``getattr(config, key, False)``.
            context: Shared :class:`StepContext` passed to every step's
                constructor.

        Returns:
            A dict keyed by :class:`SidecarTiming` (every member is present,
            possibly with an empty list), each value a list of instantiated
            sidecar steps whose config flag is truthy.
        """
        result: Dict[SidecarTiming, List[ExtractionStep]] = {
            timing: [] for timing in SidecarTiming
        }
        for config_key, (step_class, timing) in cls._registry.items():
            if getattr(config, config_key, False):
                result[timing].append(step_class(context))
                logger.debug(
                    "Created sidecar '%s' (timing=%s)",
                    step_class.__name__,
                    timing.value,
                )
            else:
                logger.debug(
                    "Skipped sidecar '%s' (config_key=%s is disabled)",
                    step_class.__name__,
                    config_key,
                )
        return result

    @classmethod
    def clear_registry(cls) -> None:
        """Remove all registered sidecars. Useful for testing."""
        cls._registry.clear()
|
||||
@@ -0,0 +1,141 @@
|
||||
"""StatementExtractionStep — critical step for extracting statements from chunks.
|
||||
|
||||
Replaces the legacy ``StatementExtractor`` with the unified ExtractionStep paradigm.
|
||||
Temporal extraction logic (valid_at / invalid_at) is merged into this step,
|
||||
eliminating the need for a separate ``TemporalExtractor`` call.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, List
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from app.core.memory.utils.data.ontology import LABEL_DEFINITIONS
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_statement_extraction_prompt
|
||||
|
||||
from .base import ExtractionStep, StepContext
|
||||
from .schema import StatementStepInput, StatementStepOutput
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ── LLM response schemas (internal) ──
|
||||
|
||||
|
||||
class _ExtractedStatement(BaseModel):
    """Raw statement returned by the LLM (before enrichment)."""

    # NOTE: StatementExtractionStep passes this model's JSON schema to the
    # prompt renderer, so the docstring, field names and descriptions here
    # are part of what is rendered — change them deliberately.

    # Required fields (``...`` means no default).
    statement: str = Field(..., description="The extracted statement text")
    statement_type: str = Field(..., description="FACT / OPINION / SUGGESTION / PREDICTION")
    temporal_type: str = Field(..., description="STATIC / DYNAMIC / ATEMPORAL")
    # Optional fields with defaults used when the LLM omits them.
    relevance: str = Field("RELEVANT", description="RELEVANT / IRRELEVANT")
    valid_at: str = Field("NULL", description="ISO 8601 or NULL")
    invalid_at: str = Field("NULL", description="ISO 8601 or NULL")
|
||||
|
||||
|
||||
class _StatementExtractionResponse(BaseModel):
    """Structured LLM response containing a list of extracted statements."""

    statements: List[_ExtractedStatement] = Field(default_factory=list)

    @field_validator("statements", mode="before")
    @classmethod
    def filter_empty(cls, v: Any) -> Any:
        """Drop empty / malformed dicts that the LLM occasionally produces.

        Runs before field validation. Dict items are kept only when their
        ``"statement"`` value is truthy. Already-built ``_ExtractedStatement``
        instances with a non-empty statement are kept too — previously every
        non-dict item was discarded, so constructing the response from model
        instances silently produced an empty list. Anything else is dropped.
        Non-list input is returned unchanged for normal validation to handle.
        """
        if not isinstance(v, list):
            return v

        kept = []
        for item in v:
            if isinstance(item, dict):
                if item.get("statement"):
                    kept.append(item)
            elif isinstance(item, _ExtractedStatement) and item.statement:
                kept.append(item)
        return kept
|
||||
|
||||
|
||||
class StatementExtractionStep(ExtractionStep[StatementStepInput, List[StatementStepOutput]]):
    """Critical step that turns one dialogue chunk into atomic statements.

    Each statement carries its temporal fields (``valid_at`` /
    ``invalid_at``) straight from the LLM response. ``is_critical`` is
    ``True`` for this step.

    Attributes bound at init:
        * ``definitions`` — the ``LABEL_DEFINITIONS`` ontology constant
        * ``json_schema`` — JSON schema of ``_ExtractedStatement``
        * ``granularity`` — ``statement_granularity`` read from
          ``config.statement_extraction`` (``None`` when absent)
        * ``include_dialogue_context`` — read from
          ``config.statement_extraction`` (default ``True``)
        * ``max_dialogue_context_chars`` — read from
          ``config.statement_extraction`` (default ``2000``)
    """

    def __init__(self, context: StepContext) -> None:
        """Bind prompt parameters from constants and the step configuration."""
        super().__init__(context)

        # Fixed inputs that do not depend on configuration.
        self.definitions = LABEL_DEFINITIONS
        self.json_schema = _ExtractedStatement.model_json_schema()

        # Tunables; getattr on a missing (None) section yields the defaults.
        statement_config = getattr(self.config, "statement_extraction", None)
        self.granularity = getattr(statement_config, "statement_granularity", None)
        self.include_dialogue_context = getattr(
            statement_config, "include_dialogue_context", True
        )
        self.max_dialogue_context_chars = getattr(
            statement_config, "max_dialogue_context_chars", 2000
        )

    # ── Identity ──

    @property
    def name(self) -> str:
        """Step identifier."""
        return "statement_extraction"

    @property
    def is_critical(self) -> bool:
        """Always ``True``: this step is critical."""
        return True

    # ── Lifecycle ──

    async def render_prompt(self, input_data: StatementStepInput) -> str:
        """Render the statement-extraction prompt for ``input_data``.

        The supporting-context messages are joined as ``"role: msg"`` lines
        and passed as dialogue content only when dialogue context is enabled
        and at least one message exists; otherwise ``None`` is passed.
        """
        context_messages = input_data.supporting_context.msgs
        if self.include_dialogue_context and context_messages:
            lines = [f"{m.role}: {m.msg}" for m in context_messages]
            dialogue_content = "\n".join(lines)
        else:
            dialogue_content = None

        return await render_statement_extraction_prompt(
            chunk_content=input_data.target_content,
            definitions=self.definitions,
            json_schema=self.json_schema,
            granularity=self.granularity,
            include_dialogue_context=self.include_dialogue_context,
            dialogue_content=dialogue_content,
            max_dialogue_chars=self.max_dialogue_context_chars,
            language=self.language,
        )

    async def call_llm(self, prompt: Any) -> Any:
        """Send the prompt to the LLM and request a structured response."""
        system_message = {
            "role": "system",
            "content": (
                "You are an expert at extracting and labeling atomic statements "
                "from conversational text. Return valid JSON conforming to the schema."
            ),
        }
        user_message = {"role": "user", "content": prompt}
        return await self.llm_client.response_structured(
            [system_message, user_message], _StatementExtractionResponse
        )

    async def parse_response(
        self, raw_response: Any, input_data: StatementStepInput
    ) -> List[StatementStepOutput]:
        """Convert the structured LLM response into step outputs.

        Returns an empty list when the response has no ``statements``
        attribute or it is ``None``. Each statement gets a fresh UUID hex id;
        type, temporal type and relevance are stripped and upper-cased.
        """
        statements = getattr(raw_response, "statements", None)
        if statements is None:
            return []

        return [
            StatementStepOutput(
                statement_id=uuid.uuid4().hex,
                statement_text=stmt.statement,
                statement_type=stmt.statement_type.strip().upper(),
                temporal_type=stmt.temporal_type.strip().upper(),
                relevance=stmt.relevance.strip().upper(),
                speaker="user",  # default; orchestrator overrides from chunk metadata
                valid_at=stmt.valid_at or "NULL",
                invalid_at=stmt.invalid_at or "NULL",
            )
            for stmt in statements
        ]

    def get_default_output(self) -> List[StatementStepOutput]:
        """Return the default output: an empty list."""
        return []
|
||||
@@ -0,0 +1,118 @@
|
||||
"""TripletExtractionStep — critical step for extracting entities and triplets.
|
||||
|
||||
Replaces the legacy ``TripletExtractor`` with the unified ExtractionStep paradigm.
|
||||
Predicate filtering against the ontology whitelist is performed in ``parse_response``.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.memory.models.triplet_models import TripletExtractionResponse
|
||||
from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_triplet_extraction_prompt
|
||||
|
||||
from .base import ExtractionStep, StepContext
|
||||
from .schema import EntityItem, TripletItem, TripletStepInput, TripletStepOutput
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TripletExtractionStep(ExtractionStep[TripletStepInput, TripletStepOutput]):
    """Extract knowledge triplets and entities from a single statement.

    ``is_critical`` is ``True`` for this step.

    Attributes bound at init:
        * ``ontology_types`` — constructor argument forwarded to the prompt
          renderer (``None`` by default)
        * ``predicate_instructions`` — the ``PREDICATE_DEFINITIONS`` constant
        * ``json_schema`` — JSON schema of ``TripletExtractionResponse``
        * ``_allowed_predicates`` — the values of the ``Predicate`` enum,
          used as a whitelist in ``parse_response``
    """

    def __init__(
        self,
        context: StepContext,
        ontology_types: Any = None,
    ) -> None:
        """Bind prompt parameters and build the predicate whitelist.

        Args:
            context: Shared step context passed to the base class.
            ontology_types: Optional ontology types forwarded to the prompt.
        """
        super().__init__(context)
        self.ontology_types = ontology_types
        self.predicate_instructions = PREDICATE_DEFINITIONS
        self.json_schema = TripletExtractionResponse.model_json_schema()
        self._allowed_predicates = {p.value for p in Predicate}

    # ── Identity ──

    @property
    def name(self) -> str:
        """Step identifier."""
        return "triplet_extraction"

    @property
    def is_critical(self) -> bool:
        """Always ``True``: this step is critical."""
        return True

    # ── Lifecycle ──

    async def render_prompt(self, input_data: TripletStepInput) -> str:
        """Render the triplet-extraction prompt for one statement.

        The supporting-context messages are joined as ``"role: msg"`` lines
        and passed as ``chunk_content`` (an empty string when there are none).
        """
        # Build chunk_content from supporting_context for pronoun resolution
        chunk_content = "\n".join(
            f"{m.role}: {m.msg}" for m in input_data.supporting_context.msgs
        ) if input_data.supporting_context.msgs else ""

        return await render_triplet_extraction_prompt(
            statement=input_data.statement_text,
            chunk_content=chunk_content,
            json_schema=self.json_schema,
            predicate_instructions=self.predicate_instructions,
            language=self.language,
            ontology_types=self.ontology_types,
            speaker=input_data.speaker,
        )

    async def call_llm(self, prompt: Any) -> Any:
        """Send the prompt to the LLM and request a structured response."""
        messages = [
            {
                "role": "system",
                "content": (
                    "You are an expert at extracting knowledge triplets and entities "
                    "from text. Follow the provided instructions carefully and return valid JSON."
                ),
            },
            {"role": "user", "content": prompt},
        ]
        return await self.llm_client.response_structured(
            messages, TripletExtractionResponse
        )

    async def parse_response(
        self, raw_response: Any, input_data: TripletStepInput
    ) -> TripletStepOutput:
        """Convert the structured LLM response into a ``TripletStepOutput``.

        Triplets whose predicate is not in the ontology whitelist are dropped.
        A response without a ``triplets`` attribute yields the default (empty)
        output. A ``None`` ``triplets`` value or a missing/``None``
        ``entities`` value is treated as an empty list rather than raising.
        """
        if not hasattr(raw_response, "triplets"):
            return self.get_default_output()

        # Guard both collections the same way; previously only ``entities``
        # tolerated None and a None ``triplets`` raised TypeError.
        raw_triplets = raw_response.triplets or []
        raw_entities = getattr(raw_response, "entities", None) or []

        # Filter triplets to allowed predicates from ontology whitelist
        filtered_triplets = [
            TripletItem(
                subject_name=t.subject_name,
                subject_id=t.subject_id,
                predicate=t.predicate,
                object_name=t.object_name,
                object_id=t.object_id,
            )
            for t in raw_triplets
            if getattr(t, "predicate", "") in self._allowed_predicates
        ]

        dropped = len(raw_triplets) - len(filtered_triplets)
        if dropped:
            logger.debug(
                "Dropped %d triplet(s) with predicates outside the ontology whitelist "
                "(statement_id=%s)",
                dropped,
                input_data.statement_id,
            )

        entities = [
            EntityItem(
                entity_idx=e.entity_idx,
                name=e.name,
                type=e.type,
                description=e.description,
                is_explicit_memory=getattr(e, "is_explicit_memory", False),
            )
            for e in raw_entities
        ]

        return TripletStepOutput(entities=entities, triplets=filtered_triplets)

    def get_default_output(self) -> TripletStepOutput:
        """Return the default output: no entities and no triplets."""
        return TripletStepOutput(entities=[], triplets=[])
|
||||
@@ -421,6 +421,9 @@ class MemoryConfig:
|
||||
pruning_scene: Optional[str] = "education"
|
||||
pruning_threshold: float = 0.5
|
||||
|
||||
# Pipeline config: Emotion extraction
|
||||
emotion_enabled: bool = False
|
||||
|
||||
# Ontology scene association
|
||||
scene_id: Optional[UUID] = None
|
||||
ontology_class_infos: list[dict] = field(default_factory=list)
|
||||
|
||||
@@ -360,40 +360,64 @@ class MemoryAgentService:
|
||||
await write_rag(end_user_id, message_text, user_rag_memory_id)
|
||||
return "success"
|
||||
else:
|
||||
await write_neo4j(
|
||||
end_user_id=end_user_id,
|
||||
messages=messages,
|
||||
memory_config=memory_config,
|
||||
ref_id='',
|
||||
language=language
|
||||
)
|
||||
|
||||
# ── 影子运行:新流水线静默执行,只记录日志不影响主流程 ──
|
||||
# TODO 乐力齐 重构流水线切换至生产环境后,更改如下代码
|
||||
import os
|
||||
if os.getenv("SHADOW_PIPELINE_ENABLED", "false").lower() == "true":
|
||||
try:
|
||||
from app.core.memory.memory_service import MemoryService
|
||||
import copy
|
||||
use_new_pipeline = os.getenv("NEW_PIPELINE_ENABLED", "false").lower() == "true"
|
||||
|
||||
shadow_messages = copy.deepcopy(messages)
|
||||
shadow_service = MemoryService(
|
||||
memory_config=memory_config,
|
||||
end_user_id=end_user_id,
|
||||
)
|
||||
shadow_result = await shadow_service.write(
|
||||
messages=shadow_messages,
|
||||
language=language,
|
||||
ref_id='',
|
||||
is_pilot_run=True, # 试运行模式:只萃取不写入,避免重复写入 Neo4j
|
||||
)
|
||||
logger.info(
|
||||
f"[Shadow] 新流水线影子运行完成: status={shadow_result.status}, "
|
||||
f"elapsed={shadow_result.elapsed_seconds:.2f}s, "
|
||||
f"extraction={shadow_result.extraction}"
|
||||
)
|
||||
except Exception as shadow_err:
|
||||
logger.warning(f"[Shadow] 新流水线影子运行失败(不影响主流程): {shadow_err}")
|
||||
# ── 影子运行结束 ──
|
||||
if use_new_pipeline:
|
||||
# ── 新流水线:WritePipeline + NewExtractionOrchestrator ──
|
||||
from app.core.memory.memory_service import MemoryService
|
||||
|
||||
service = MemoryService(
|
||||
memory_config=memory_config,
|
||||
end_user_id=end_user_id,
|
||||
)
|
||||
result = await service.write(
|
||||
messages=messages,
|
||||
language=language,
|
||||
ref_id='',
|
||||
is_pilot_run=False,
|
||||
)
|
||||
logger.info(
|
||||
f"[NewPipeline] 完成: status={result.status}, "
|
||||
f"elapsed={result.elapsed_seconds:.2f}s, "
|
||||
f"extraction={result.extraction}"
|
||||
)
|
||||
else:
|
||||
# ── 旧流水线:write_tools.write() + ExtractionOrchestrator ──
|
||||
await write_neo4j(
|
||||
end_user_id=end_user_id,
|
||||
messages=messages,
|
||||
memory_config=memory_config,
|
||||
ref_id='',
|
||||
language=language
|
||||
)
|
||||
|
||||
# ── 影子运行:新流水线静默执行,只记录日志不影响主流程 ──
|
||||
if os.getenv("SHADOW_PIPELINE_ENABLED", "false").lower() == "true":
|
||||
try:
|
||||
from app.core.memory.memory_service import MemoryService
|
||||
import copy
|
||||
|
||||
shadow_messages = copy.deepcopy(messages)
|
||||
shadow_service = MemoryService(
|
||||
memory_config=memory_config,
|
||||
end_user_id=end_user_id,
|
||||
)
|
||||
shadow_result = await shadow_service.write(
|
||||
messages=shadow_messages,
|
||||
language=language,
|
||||
ref_id='',
|
||||
is_pilot_run=True,
|
||||
)
|
||||
logger.info(
|
||||
f"[Shadow] 新流水线影子运行完成: status={shadow_result.status}, "
|
||||
f"elapsed={shadow_result.elapsed_seconds:.2f}s, "
|
||||
f"extraction={shadow_result.extraction}"
|
||||
)
|
||||
except Exception as shadow_err:
|
||||
logger.warning(f"[Shadow] 新流水线影子运行失败(不影响主流程): {shadow_err}")
|
||||
# ── 影子运行结束 ──
|
||||
for lang in ["zh", "en"]:
|
||||
deleted = await InterestMemoryCache.delete_interest_distribution(
|
||||
end_user_id, lang
|
||||
|
||||
@@ -418,6 +418,9 @@ class MemoryConfigService:
|
||||
pruning_scene=memory_config.pruning_scene or "education",
|
||||
pruning_threshold=float(
|
||||
memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
|
||||
# Pipeline config: Emotion extraction
|
||||
emotion_enabled=bool(
|
||||
memory_config.emotion_enabled) if memory_config.emotion_enabled is not None else False,
|
||||
# Ontology scene association
|
||||
scene_id=memory_config.scene_id,
|
||||
ontology_class_infos=_load_ontology_class_infos(self.db, memory_config.scene_id),
|
||||
@@ -573,6 +576,7 @@ class MemoryConfigService:
|
||||
statement_extraction=stmt_config,
|
||||
deduplication=dedup_config,
|
||||
forgetting_engine=forget_config,
|
||||
emotion_enabled=getattr(memory_config, "emotion_enabled", False),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user