feat(memory): implement step-based extraction pipeline architecture

Introduce ExtractionStep abstraction with modular pipeline stages:
- Add base ExtractionStep class with render/call/parse lifecycle
- Implement StatementExtractionStep, TripletExtractionStep,
  EmbeddingStep, EmotionStep, GraphBuildStep, and DedupStep
- Add SidecarStepFactory for hot-pluggable non-critical steps
- Define Pydantic I/O schemas for all pipeline stages
- Refactor WritePipeline to orchestrate new step-based flow
- Add NEW_PIPELINE_ENABLED env switch for old/new pipeline routing
- Add emotion_enabled config flag to MemoryConfig
- Fix workspace_id reference in get_end_user_connected_config
This commit is contained in:
lanceyq
2026-04-23 15:47:46 +08:00
parent 41535c34e6
commit a98011fc8a
20 changed files with 3102 additions and 144 deletions

View File

@@ -1,7 +1,4 @@
import os
import json
from typing import List
from datetime import datetime
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import DialogueChunker
from app.core.memory.models.message_models import DialogData, ConversationContext, ConversationMessage
@@ -34,6 +31,7 @@ async def get_chunked_dialogs(
conversation_messages = []
# step1: 消息格式校验 roleuser、assistant。content
for idx, msg in enumerate(messages):
if not isinstance(msg, dict) or 'role' not in msg or 'content' not in msg:
raise ValueError(f"Message {idx} format error: must contain 'role' and 'content' fields")
@@ -59,7 +57,7 @@ async def get_chunked_dialogs(
config_id=config_id
)
# 语义剪枝步骤(在分块之前)
# step2: 语义剪枝步骤(在分块之前)
try:
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import SemanticPruner
from app.core.memory.models.config_models import PruningConfig
@@ -116,6 +114,7 @@ async def get_chunked_dialogs(
except Exception as e:
logger.warning(f"[剪枝] 执行失败,跳过剪枝: {e}", exc_info=True)
# step3 分块
chunker = DialogueChunker(chunker_strategy)
extracted_chunks = await chunker.process_dialogue(dialog_data)
dialog_data.chunks = extracted_chunks

View File

@@ -147,7 +147,85 @@ async def write(
all_perceptual_edges,
all_dedup_details,
) = await orchestrator.run(chunked_dialogs, is_pilot_run=False)
# region TODO 乐力齐 重构流水线切换至生产环境稳定后,移除快照对比代码
# ── Snapshot: 旧流水线萃取结果(按 phase2_step_io_schema_v1.md 格式) ──
from app.core.memory.utils.debug.pipeline_snapshot import PipelineSnapshot
snapshot = PipelineSnapshot("legacy")
# Statement 输出(从 dialog_data_list 中提取)
stmt_snapshot = []
for d in all_dedup_details:
if not hasattr(d, "chunks"):
continue
for c in d.chunks:
for s in c.statements:
stmt_snapshot.append({
"statement_id": s.id,
"statement_text": s.statement,
"statement_type": str(getattr(s, "stmt_type", "")),
"temporal_type": str(getattr(s, "temporal_info", "")),
"relevance": str(getattr(s, "relevence_info", "RELEVANT")),
"speaker": getattr(s, "speaker", "user") or "user",
"valid_at": s.temporal_validity.valid_at if s.temporal_validity else "NULL",
"invalid_at": s.temporal_validity.invalid_at if s.temporal_validity else "NULL",
})
snapshot.save_stage("2_statement_outputs", stmt_snapshot)
# Triplet 输出(从 dialog_data_list 中提取)
triplet_snapshot = {}
for d in all_dedup_details:
if not hasattr(d, "chunks"):
continue
for c in d.chunks:
for s in c.statements:
if s.triplet_extraction_info:
triplet_snapshot[s.id] = {
"entities": [
{
"entity_idx": e.entity_idx, "name": e.name,
"type": e.type, "description": e.description,
"is_explicit_memory": getattr(e, "is_explicit_memory", False),
}
for e in s.triplet_extraction_info.entities
],
"triplets": [
{
"subject_name": t.subject_name, "subject_id": t.subject_id,
"predicate": t.predicate,
"object_name": t.object_name, "object_id": t.object_id,
}
for t in s.triplet_extraction_info.triplets
],
}
snapshot.save_stage("3_triplet_outputs", triplet_snapshot)
# 图节点和边(去重后)
snapshot.save_stage("6_nodes_edges_after_dedup", {
"dialogue_nodes_count": len(all_dialogue_nodes),
"chunk_nodes_count": len(all_chunk_nodes),
"statement_nodes_count": len(all_statement_nodes),
"entity_nodes": [
{"id": e.id, "name": e.name, "entity_type": e.entity_type, "description": e.description}
for e in all_entity_nodes
],
"entity_entity_edges": [
{
"source": e.source, "target": e.target,
"relation_type": e.relation_type, "statement": e.statement,
}
for e in all_entity_entity_edges
],
})
snapshot.save_summary({
"dialogue_count": len(all_dialogue_nodes),
"chunk_count": len(all_chunk_nodes),
"statement_count": len(all_statement_nodes),
"entity_count": len(all_entity_nodes),
"relation_count": len(all_entity_entity_edges),
})
# endregion
log_time("Extraction Pipeline", time.time() - step_start, log_file)
# Step 3: Save all data to Neo4j database

View File

@@ -149,3 +149,5 @@ class ExtractionPipelineConfig(BaseModel):
temporal_extraction: TemporalExtractionConfig = Field(default_factory=TemporalExtractionConfig)
deduplication: DedupConfig = Field(default_factory=DedupConfig)
forgetting_engine: ForgettingEngineConfig = Field(default_factory=ForgettingEngineConfig)
# 情绪引擎旁路模块SidecarStepFactory 通过此字段判断是否启用)
emotion_enabled: bool = Field(default=False, description="是否启用情绪提取旁路")

View File

@@ -12,20 +12,33 @@ WritePipeline — 记忆写入流水线
依赖方向Facade → Pipeline → Engine → Repository单向不允许反向调用
"""
from __future__ import annotations
import asyncio
import logging
import time
import uuid
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional
from pydantic import BaseModel, Field, ConfigDict
if TYPE_CHECKING:
from app.core.memory.models.graph_models import ExtractedEntityNode
from app.core.memory.models.message_models import DialogData
from app.schemas.memory_config_schema import MemoryConfig
from app.core.memory.models.graph_models import (
ChunkNode,
DialogueNode,
EntityEntityEdge,
ExtractedEntityNode,
PerceptualEdge,
PerceptualNode,
StatementChunkEdge,
StatementEntityEdge,
StatementNode,
)
logger = logging.getLogger(__name__)
@@ -34,36 +47,40 @@ logger = logging.getLogger(__name__)
# ──────────────────────────────────────────────
@dataclass
class ExtractionResult:
"""萃取步骤的结构化输出,替代 ExtractionOrchestrator.run() 返回的裸元组。
class ExtractionResult(BaseModel):
"""萃取 + 图构建 + 去重消歧后的结构化输出。
字段与 ExtractionOrchestrator.run() 的 10 元素返回值一一对应:
[0] dialogue_nodes → self.dialogue_nodes
[1] chunk_nodes → self.chunk_nodes
[2] statement_nodes → self.statement_nodes
[3] entity_nodes → self.entity_nodes
[4] perceptual_nodes → self.perceptual_nodes
[5] stmt_chunk_edges → self.stmt_chunk_edges
[6] stmt_entity_edges → self.stmt_entity_edges
[7] entity_entity_edges → self.entity_entity_edges
[8] perceptual_edges → self.perceptual_edges
[9] dialog_data_list → self.dialog_data_list
作为 Pipeline 层的阶段间数据载体确保下游步骤_store、_cluster
接收到的图节点和边结构完整、类型正确。
注意:字段类型使用 List[Any] 而非具体的 graph_models 类型,
避免在模块加载时触发循环依赖。Pipeline 只做数据传递,不检查具体类型。
字段对应 ExtractionOrchestrator 产出的图节点/边:
dialogue_nodes — 对话节点
chunk_nodes — 分块节点
statement_nodes — 陈述句节点
entity_nodes — 实体节点(去重消歧后)
perceptual_nodes — 感知节点
stmt_chunk_edges — 陈述句 → 分块 边
stmt_entity_edges — 陈述句 → 实体 边
entity_entity_edges — 实体 → 实体 边(去重消歧后)
perceptual_edges — 感知 → 分块 边
dialog_data_list — 原始 DialogData供摘要阶段使用
"""
dialogue_nodes: List[Any]
chunk_nodes: List[Any]
statement_nodes: List[Any]
entity_nodes: List[Any]
perceptual_nodes: List[Any]
stmt_chunk_edges: List[Any]
stmt_entity_edges: List[Any]
entity_entity_edges: List[Any]
perceptual_edges: List[Any]
dialog_data_list: List[Any]
model_config = ConfigDict(arbitrary_types_allowed=True)
dialogue_nodes: List[DialogueNode]
chunk_nodes: List[ChunkNode]
statement_nodes: List[StatementNode]
entity_nodes: List[ExtractedEntityNode]
perceptual_nodes: List[PerceptualNode]
stmt_chunk_edges: List[StatementChunkEdge]
stmt_entity_edges: List[StatementEntityEdge]
entity_entity_edges: List[EntityEntityEdge]
perceptual_edges: List[PerceptualEdge]
dialog_data_list: List[Any] = Field(
default_factory=list,
description="原始 DialogData 列表,类型为 Any 以避免循环依赖",
)
@property
def stats(self) -> Dict[str, int]:
@@ -78,8 +95,7 @@ class ExtractionResult:
}
@dataclass
class WriteResult:
class WriteResult(BaseModel):
"""写入流水线的最终输出,返回给 MemoryService / MemoryAgentService"""
status: str # "success" | "pilot_complete" | "failed"
@@ -114,7 +130,7 @@ class WritePipeline:
memory_config: 不可变的记忆配置对象(从数据库加载)
end_user_id: 终端用户 ID
language: 语言 ("zh" | "en")
progress_callback: 可选的进度回调,签名 (stage, message, data?) -> Awaitable[None]
progress_callback: 可选的进度回调,签名 (stage, message, data?) -> Awaitable[None] 供pilot run使用
"""
self.memory_config = memory_config
self.end_user_id = end_user_id
@@ -145,7 +161,7 @@ class WritePipeline:
is_pilot_run: 试运行模式(只萃取不写入)
Returns:
WriteResult 包含状态和统计信息
WriteResult 包含状态和统计信息
"""
if not ref_id:
ref_id = uuid.uuid4().hex
@@ -164,7 +180,7 @@ class WritePipeline:
self._init_clients()
self._init_neo4j_connector()
# Step 1: 预处理 - 消息分块
# Step 1: 预处理 - 消息分块 + AI消息语义剪枝暂无实现
step_start = time.time()
chunked_dialogs = await self._preprocess(messages, ref_id)
chunks_count = sum(len(d.chunks) for d in chunked_dialogs)
@@ -175,9 +191,7 @@ class WritePipeline:
# Step 2: 萃取 - 知识提取
step_start = time.time()
extraction_result = await self._extract(
chunked_dialogs, is_pilot_run
)
extraction_result = await self._extract(chunked_dialogs, is_pilot_run)
stats = extraction_result.stats
logger.info(
f"[WritePipeline] [2/5] 萃取:知识提取 "
@@ -190,9 +204,7 @@ class WritePipeline:
# 试运行模式到此结束
if is_pilot_run:
elapsed = time.time() - pipeline_start
logger.info(
f"[WritePipeline] 完成(试运行) ✔ {elapsed:.2f}s"
)
logger.info(f"[WritePipeline] 完成(试运行) ✔ {elapsed:.2f}s")
return WriteResult(
status="pilot_complete",
extraction=extraction_result.stats,
@@ -227,9 +239,7 @@ class WritePipeline:
await self._update_stats_cache(extraction_result)
elapsed = time.time() - pipeline_start
logger.info(
f"[WritePipeline] 完成 ✔ {elapsed:.2f}s"
)
logger.info(f"[WritePipeline] 完成 ✔ {elapsed:.2f}s")
return WriteResult(
status="success",
extraction=extraction_result.stats,
@@ -251,16 +261,14 @@ class WritePipeline:
# Step 1: 预处理
# ──────────────────────────────────────────────
async def _preprocess(
self, messages: List[dict], ref_id: str
) -> List[DialogData]:
async def _preprocess(self, messages: List[dict], ref_id: str) -> List[DialogData]:
"""
预处理:消息校验 → 语义剪枝 → 对话分块。
预处理:消息校验 → AI消息语义剪枝暂未实现 → 对话分块。
委托给 get_chunked_dialogs(),保持现有预处理逻辑不变。
get_dialogs.py 内部已包含:
- 消息格式校验role/content 必填)
- 语义剪枝(根据 config 中 pruning_enabled 决定)
- AI消息语义剪枝(根据 config 中 pruning_enabled 决定)
- DialogueChunker 分块
"""
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
@@ -283,56 +291,187 @@ class WritePipeline:
is_pilot_run: bool,
) -> ExtractionResult:
"""
萃取:初始化引擎 → 执行知识提取 → 返回结构化结果。
萃取:初始化引擎 → 执行知识提取 → 构建图节点/边 → 去重 → 返回结构化结果。
ExtractionOrchestrator 作为萃取引擎被调用
Pipeline 不关心引擎内部的并行策略和提取细节。
使用 NewExtractionOrchestratorExtractionStep 范式)完成 LLM 萃取
然后通过独立的 graph_build_step 和 dedup_step 完成图构建和去重,
不依赖旧编排器 ExtractionOrchestrator。
执行流程:
1. NewExtractionOrchestrator.run() → 萃取并赋值到 DialogData
2. build_graph_nodes_and_edges() → 从 DialogData 构建图节点和边
3. run_dedup() → 两阶段去重消歧
"""
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
ExtractionOrchestrator,
from app.core.memory.storage_services.extraction_engine.dedup_step import (
run_dedup,
)
from app.core.memory.storage_services.extraction_engine.steps.graph_build_step import (
build_graph_nodes_and_edges,
)
from app.core.memory.storage_services.extraction_engine.steps.extraction_pipeline_orchestrator import (
NewExtractionOrchestrator,
)
from app.core.memory.utils.config.config_utils import get_pipeline_config
from app.core.memory.utils.debug.pipeline_snapshot import PipelineSnapshot
pipeline_config = get_pipeline_config(self.memory_config)
ontology_types = self._load_ontology_types()
orchestrator = ExtractionOrchestrator(
snapshot = PipelineSnapshot("new")
# ── 新编排器LLM 萃取 + 数据赋值 ──
new_orchestrator = NewExtractionOrchestrator(
llm_client=self._llm_client,
embedder_client=self._embedder_client,
connector=self._neo4j_connector,
config=pipeline_config,
embedding_id=str(self.memory_config.embedding_model_id),
language=self.language,
ontology_types=ontology_types,
language=self.language,
is_pilot_run=is_pilot_run,
progress_callback=self.progress_callback,
)
# step1: 执行知识提取
dialog_data_list = await new_orchestrator.run(chunked_dialogs)
# ── Snapshot: 各阶段萃取结果 ── TODO 乐力齐 重构流水线切换生产环境稳定后修改
stage_outputs = new_orchestrator.last_stage_outputs
if stage_outputs:
stmt_results = stage_outputs.get("statement_results", {})
stmt_snapshot = []
for _did, chunk_stmts in stmt_results.items():
for _cid, stmts in chunk_stmts.items():
for s in stmts:
stmt_snapshot.append(s.model_dump())
snapshot.save_stage("2_statement_outputs", stmt_snapshot)
triplet_results = stage_outputs.get("triplet_results", {})
triplet_snapshot = {}
for _did, stmt_triplets in triplet_results.items():
for stmt_id, t_out in stmt_triplets.items():
triplet_snapshot[stmt_id] = t_out.model_dump()
snapshot.save_stage("3_triplet_outputs", triplet_snapshot)
emotion_results = stage_outputs.get("emotion_results", {})
emotion_snapshot = {}
for stmt_id, emo in emotion_results.items():
if hasattr(emo, "model_dump"):
emotion_snapshot[stmt_id] = emo.model_dump()
snapshot.save_stage("4_emotion_outputs", emotion_snapshot)
emb_output = stage_outputs.get("embedding_output")
if emb_output and hasattr(emb_output, "model_dump"):
emb_data = emb_output.model_dump()
for key in (
"statement_embeddings",
"chunk_embeddings",
"entity_embeddings",
):
if key in emb_data and isinstance(emb_data[key], dict):
emb_data[key] = {
k: v[:5] if isinstance(v, list) else v
for k, v in emb_data[key].items()
}
if "dialog_embeddings" in emb_data and isinstance(
emb_data["dialog_embeddings"], list
):
emb_data["dialog_embeddings"] = [
v[:5] if isinstance(v, list) else v
for v in emb_data["dialog_embeddings"]
]
snapshot.save_stage("5_embedding_outputs", emb_data)
# step2: 构建图节点和边
graph = await build_graph_nodes_and_edges(
dialog_data_list=dialog_data_list,
embedder_client=self._embedder_client,
progress_callback=self.progress_callback,
)
(
dialogue_nodes,
chunk_nodes,
statement_nodes,
entity_nodes,
perceptual_nodes,
stmt_chunk_edges,
stmt_entity_edges,
entity_entity_edges,
perceptual_edges,
dialog_data_list,
) = await orchestrator.run(chunked_dialogs, is_pilot_run=is_pilot_run)
# region Snapshot: 图节点和边去重前Snapshot有关的内容在重构流水线切换生产环境之后修改
snapshot.save_stage(
"6_nodes_edges_before_dedup",
{
"dialogue_nodes_count": len(graph.dialogue_nodes),
"chunk_nodes_count": len(graph.chunk_nodes),
"statement_nodes_count": len(graph.statement_nodes),
"entity_nodes": [
{
"id": e.id,
"name": e.name,
"entity_type": e.entity_type,
"description": e.description,
}
for e in graph.entity_nodes
],
"entity_entity_edges": [
{
"source": e.source,
"target": e.target,
"relation_type": e.relation_type,
"statement": e.statement,
}
for e in graph.entity_entity_edges
],
"stmt_entity_edges_count": len(graph.stmt_entity_edges),
},
)
return ExtractionResult(
dialogue_nodes=dialogue_nodes,
chunk_nodes=chunk_nodes,
statement_nodes=statement_nodes,
entity_nodes=entity_nodes,
perceptual_nodes=perceptual_nodes,
stmt_chunk_edges=stmt_chunk_edges,
stmt_entity_edges=stmt_entity_edges,
entity_entity_edges=entity_entity_edges,
perceptual_edges=perceptual_edges,
# step3: 两阶段去重消歧
dedup_result = await run_dedup(
entity_nodes=graph.entity_nodes,
statement_entity_edges=graph.stmt_entity_edges,
entity_entity_edges=graph.entity_entity_edges,
dialog_data_list=dialog_data_list,
pipeline_config=pipeline_config,
connector=self._neo4j_connector,
llm_client=self._llm_client,
is_pilot_run=is_pilot_run,
progress_callback=self.progress_callback,
)
# Snapshot: 去重后
snapshot.save_stage(
"7_after_dedup",
{
"entity_nodes": [
{
"id": e.id,
"name": e.name,
"entity_type": e.entity_type,
"description": e.description,
}
for e in dedup_result.entity_nodes
],
"entity_entity_edges": [
{
"source": e.source,
"target": e.target,
"relation_type": e.relation_type,
"statement": e.statement,
}
for e in dedup_result.entity_entity_edges
],
},
)
# step4: 构造最终结果
result = ExtractionResult(
dialogue_nodes=graph.dialogue_nodes,
chunk_nodes=graph.chunk_nodes,
statement_nodes=graph.statement_nodes,
entity_nodes=dedup_result.entity_nodes,
perceptual_nodes=graph.perceptual_nodes,
stmt_chunk_edges=graph.stmt_chunk_edges,
stmt_entity_edges=dedup_result.statement_entity_edges,
entity_entity_edges=dedup_result.entity_entity_edges,
perceptual_edges=graph.perceptual_edges,
dialog_data_list=dialog_data_list,
)
snapshot.save_summary(result.stats) # TODO 乐力齐 snapshot需要改
return result
# ──────────────────────────────────────────────
# Step 3: 存储
# ──────────────────────────────────────────────
@@ -379,14 +518,10 @@ class WritePipeline:
)
await asyncio.sleep(1 * (attempt + 1))
else:
logger.error(
f"Neo4j 写入在 {max_retries} 次尝试后仍部分失败"
)
logger.error(f"Neo4j 写入在 {max_retries} 次尝试后仍部分失败")
except Exception as e:
if self._is_deadlock(e) and attempt < max_retries - 1:
logger.warning(
f"Neo4j 死锁,重试 ({attempt + 2}/{max_retries})"
)
logger.warning(f"Neo4j 死锁,重试 ({attempt + 2}/{max_retries})")
await asyncio.sleep(1 * (attempt + 1))
else:
raise
@@ -401,6 +536,10 @@ class WritePipeline:
聚类不阻塞主写入流程,失败不影响写入结果。
通过 Celery 异步执行,由 LabelPropagationEngine 完成实际计算。
注意ExtractionResult.entity_nodes 已经是经过 _extract() 中
两阶段去重消歧_run_dedup_and_write_summary后的结果
聚类直接基于去重后的实体 ID 执行。
"""
if not result.entity_nodes:
return
@@ -428,7 +567,9 @@ class WritePipeline:
)
logger.info(
f"[Clustering] 增量聚类任务已提交 - "
f"task_id={task.id}, entity_count={len(new_entity_ids)}"
f"task_id={task.id}, "
f"entity_count={len(new_entity_ids)}, "
f"source=dedup"
)
except Exception as e:
logger.error(
@@ -438,9 +579,9 @@ class WritePipeline:
# ──────────────────────────────────────────────
# Step 5: 摘要
# + entity_description
# + entity_description+ meta_data部分在此提取
# ──────────────────────────────────────────────
# TODO 乐力齐 需要做成异步celery任务
async def _summarize(self, chunked_dialogs: List[DialogData]) -> None:
"""
摘要:生成情景记忆摘要 → 写入 Neo4j。
@@ -467,9 +608,7 @@ class WritePipeline:
ms_connector = Neo4jConnector()
try:
await add_memory_summary_nodes(summaries, ms_connector)
await add_memory_summary_statement_edges(
summaries, ms_connector
)
await add_memory_summary_statement_edges(summaries, ms_connector)
finally:
try:
await ms_connector.close()
@@ -494,9 +633,7 @@ class WritePipeline:
with get_db_context() as db:
factory = MemoryClientFactory(db)
self._llm_client = factory.get_llm_client_from_config(
self.memory_config
)
self._llm_client = factory.get_llm_client_from_config(self.memory_config)
self._embedder_client = factory.get_embedder_client_from_config(
self.memory_config
)
@@ -564,10 +701,8 @@ class WritePipeline:
if entity_nodes:
eu_id = entity_nodes[0].end_user_id
if eu_id:
neo4j_assistant_aliases = (
await fetch_neo4j_assistant_aliases(
self._neo4j_connector, eu_id
)
neo4j_assistant_aliases = await fetch_neo4j_assistant_aliases(
self._neo4j_connector, eu_id
)
clean_cross_role_aliases(
entity_nodes,
@@ -586,9 +721,7 @@ class WritePipeline:
msg = str(e).lower()
return "deadlockdetected" in msg or "deadlock" in msg
async def _update_stats_cache(
self, result: ExtractionResult
) -> None:
async def _update_stats_cache(self, result: ExtractionResult) -> None:
"""
将提取统计写入 Redis 活动缓存,按 workspace_id 存储。
失败不中断主流程。
@@ -614,9 +747,7 @@ class WritePipeline:
f"workspace_id={self.memory_config.workspace_id}"
)
except Exception as e:
logger.warning(
f"写入活动统计缓存失败(不影响主流程): {e}"
)
logger.warning(f"写入活动统计缓存失败(不影响主流程): {e}")
async def _cleanup(self) -> None:
"""
@@ -634,16 +765,14 @@ class WritePipeline:
# 防止 'RuntimeError: Event loop is closed' 在垃圾回收时触发
for client_obj in (self._llm_client, self._embedder_client):
try:
underlying = getattr(
client_obj, "client", None
) or getattr(client_obj, "model", None)
underlying = getattr(client_obj, "client", None) or getattr(
client_obj, "model", None
)
if underlying is None:
continue
inner = getattr(underlying, "_model", underlying)
http_client = getattr(inner, "async_client", None)
if http_client is not None and hasattr(
http_client, "aclose"
):
if http_client is not None and hasattr(http_client, "aclose"):
await http_client.aclose()
except Exception:
pass

View File

@@ -0,0 +1,506 @@
"""Independent deduplication module for the extraction pipeline.
Extracts dedup logic from ExtractionOrchestrator into standalone functions
so the orchestrator stays thin and dedup can be tested/evolved independently.
The module exposes:
- ``DedupResult`` — structured output of the dedup process
- ``run_dedup()`` — async entry point called by WritePipeline
- Helper functions migrated from ExtractionOrchestrator:
``save_dedup_details``, ``analyze_entity_merges``,
``analyze_entity_disambiguation``, ``send_dedup_progress_callback``,
``parse_dedup_report``
"""
from __future__ import annotations
import logging
import re
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple
from app.core.memory.models.graph_models import (
EntityEntityEdge,
ExtractedEntityNode,
StatementEntityEdge,
)
from app.core.memory.models.message_models import DialogData
from app.core.memory.models.variate_config import ExtractionPipelineConfig
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# DedupResult dataclass (Requirement 10.2)
# ---------------------------------------------------------------------------
@dataclass
class DedupResult:
"""Structured output of the two-stage entity deduplication process.
Attributes:
entity_nodes: Deduplicated entity node list.
statement_entity_edges: Deduplicated statement-entity edges.
entity_entity_edges: Deduplicated entity-entity edges.
dedup_details: Raw detail dict returned by the first-layer dedup.
merge_records: Parsed merge records (exact / fuzzy / LLM).
disamb_records: Parsed disambiguation records.
"""
entity_nodes: List[ExtractedEntityNode]
statement_entity_edges: List[StatementEntityEdge]
entity_entity_edges: List[EntityEntityEdge]
dedup_details: Dict[str, Any] = field(default_factory=dict)
merge_records: List[Dict[str, Any]] = field(default_factory=list)
disamb_records: List[Dict[str, Any]] = field(default_factory=list)
@property
def stats(self) -> Dict[str, int]:
"""Summary statistics for the dedup run."""
return {
"entity_count": len(self.entity_nodes),
"merge_count": len(self.merge_records),
"disamb_count": len(self.disamb_records),
}
# ---------------------------------------------------------------------------
# Migrated helpers (from ExtractionOrchestrator) — Requirement 10.4
# ---------------------------------------------------------------------------
def save_dedup_details(
dedup_details: Dict[str, Any],
original_entities: List[ExtractedEntityNode],
final_entities: List[ExtractedEntityNode],
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], Dict[str, str]]:
"""Parse raw *dedup_details* into structured merge / disamb records.
Returns:
(merge_records, disamb_records, id_redirect_map)
"""
merge_records: List[Dict[str, Any]] = []
disamb_records: List[Dict[str, Any]] = []
id_redirect_map: Dict[str, str] = {}
try:
id_redirect_map = dedup_details.get("id_redirect", {})
# --- exact-match merges ---
exact_merge_map = dedup_details.get("exact_merge_map", {})
for _key, info in exact_merge_map.items():
merged_ids = info.get("merged_ids", set())
if merged_ids:
merge_records.append({
"type": "精确匹配",
"canonical_id": info.get("canonical_id"),
"entity_name": info.get("name"),
"entity_type": info.get("entity_type"),
"merged_count": len(merged_ids),
"merged_ids": list(merged_ids),
})
# --- fuzzy-match merges ---
for record in dedup_details.get("fuzzy_merge_records", []):
try:
match = re.search(
r"规范实体 (\S+) \(([^|]+)\|([^|]+)\|([^)]+)\) <- 合并实体 (\S+)",
record,
)
if match:
merge_records.append({
"type": "模糊匹配",
"canonical_id": match.group(1),
"entity_name": match.group(3),
"entity_type": match.group(4),
"merged_count": 1,
"merged_ids": [match.group(5)],
})
except Exception as e:
logger.debug("解析模糊匹配记录失败: %s, 错误: %s", record, e)
# --- LLM-based merges ---
for record in dedup_details.get("llm_decision_records", []):
if "[LLM去重]" in str(record):
try:
match = re.search(
r"同名类型相似 ([^]+)([^]+)\|([^]+)([^]+)",
record,
)
if match:
merge_records.append({
"type": "LLM去重",
"entity_name": match.group(1),
"entity_type": f"{match.group(2)}|{match.group(4)}",
"merged_count": 1,
"merged_ids": [],
})
except Exception as e:
logger.debug("解析LLM去重记录失败: %s, 错误: %s", record, e)
# --- disambiguation records ---
for record in dedup_details.get("disamb_records", []):
if "[DISAMB阻断]" in str(record):
try:
content = str(record).replace("[DISAMB阻断]", "").strip()
match = re.search(
r"([^]+)([^]+)\|([^]+)([^]+)", content
)
if match:
entity1_name = match.group(1).strip()
entity1_type = match.group(2)
entity2_type = match.group(4)
conf_match = re.search(r"conf=([0-9.]+)", str(record))
confidence = conf_match.group(1) if conf_match else "unknown"
reason_match = re.search(r"reason=([^|]+)", str(record))
reason = reason_match.group(1).strip() if reason_match else ""
disamb_records.append({
"entity_name": entity1_name,
"disamb_type": f"消歧阻断:{entity1_type} vs {entity2_type}",
"confidence": confidence,
"reason": (reason[:100] + "...") if len(reason) > 100 else reason,
})
except Exception as e:
logger.debug("解析消歧记录失败: %s, 错误: %s", record, e)
logger.info(
"保存去重消歧记录:%d 个合并记录,%d 个消歧记录",
len(merge_records),
len(disamb_records),
)
except Exception as e:
logger.error("保存去重消歧详情失败: %s", e, exc_info=True)
return merge_records, disamb_records, id_redirect_map
def analyze_entity_merges(
merge_records: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
"""Return merge info sorted by merged_count (descending)."""
if not merge_records:
return []
sorted_records = sorted(
merge_records, key=lambda x: x.get("merged_count", 0), reverse=True
)
return [
{
"main_entity_name": r.get("entity_name", "未知实体"),
"merged_count": r.get("merged_count", 1),
}
for r in sorted_records
]
def analyze_entity_disambiguation(
disamb_records: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
"""Return disambiguation records (pass-through)."""
return disamb_records if disamb_records else []
def parse_dedup_report(
merge_records: List[Dict[str, Any]],
disamb_records: List[Dict[str, Any]],
) -> Dict[str, Any]:
"""Build a summary report dict from parsed records."""
try:
dedup_examples: List[Dict[str, Any]] = []
disamb_examples: List[Dict[str, Any]] = []
total_merges = 0
total_disambiguations = 0
for record in merge_records:
merge_count = record.get("merged_count", 0)
total_merges += merge_count
dedup_examples.append({
"type": record.get("type", "未知"),
"entity_name": record.get("entity_name", "未知实体"),
"entity_type": record.get("entity_type", "未知类型"),
"merge_count": merge_count,
"description": f"{record.get('entity_name', '未知实体')}实体去重合并{merge_count}",
})
for record in disamb_records:
total_disambiguations += 1
disamb_type = record.get("disamb_type", "")
entity_name = record.get("entity_name", "未知实体")
disamb_examples.append({
"entity1_name": entity_name,
"entity1_type": (
disamb_type.split("vs")[0].replace("消歧阻断:", "").strip()
if "vs" in disamb_type
else "未知"
),
"entity2_name": entity_name,
"entity2_type": (
disamb_type.split("vs")[1].strip() if "vs" in disamb_type else "未知"
),
"description": f"{entity_name},消歧区分成功",
})
return {
"dedup_examples": dedup_examples[:5],
"disamb_examples": disamb_examples[:5],
"total_merges": total_merges,
"total_disambiguations": total_disambiguations,
}
except Exception as e:
logger.error("获取去重报告失败: %s", e, exc_info=True)
return {
"dedup_examples": [],
"disamb_examples": [],
"total_merges": 0,
"total_disambiguations": 0,
}
async def send_dedup_progress_callback(
progress_callback: Callable,
merge_records: List[Dict[str, Any]],
disamb_records: List[Dict[str, Any]],
original_entities: int,
final_entities: int,
original_stmt_edges: int,
final_stmt_edges: int,
original_ent_edges: int,
final_ent_edges: int,
) -> None:
"""Send dedup completion progress via *progress_callback*."""
try:
dedup_details = parse_dedup_report(merge_records, disamb_records)
entities_reduced = original_entities - final_entities
stmt_edges_reduced = original_stmt_edges - final_stmt_edges
ent_edges_reduced = original_ent_edges - final_ent_edges
dedup_stats = {
"entities": {
"original_count": original_entities,
"final_count": final_entities,
"reduced_count": entities_reduced,
"reduction_rate": (
round(entities_reduced / original_entities * 100, 1)
if original_entities > 0
else 0
),
},
"statement_entity_edges": {
"original_count": original_stmt_edges,
"final_count": final_stmt_edges,
"reduced_count": stmt_edges_reduced,
},
"entity_entity_edges": {
"original_count": original_ent_edges,
"final_count": final_ent_edges,
"reduced_count": ent_edges_reduced,
},
"dedup_examples": dedup_details.get("dedup_examples", []),
"disamb_examples": dedup_details.get("disamb_examples", []),
"summary": {
"total_merges": dedup_details.get("total_merges", 0),
"total_disambiguations": dedup_details.get("total_disambiguations", 0),
},
}
await progress_callback("dedup_disambiguation_complete", "去重消歧完成", dedup_stats)
except Exception as e:
logger.error("发送去重消歧进度回调失败: %s", e, exc_info=True)
try:
basic_stats = {
"entities": {
"original_count": original_entities,
"final_count": final_entities,
"reduced_count": original_entities - final_entities,
},
"summary": f"实体去重合并{original_entities - final_entities}",
}
await progress_callback("dedup_disambiguation_complete", "去重消歧完成", basic_stats)
except Exception as e2:
logger.error("发送基本去重统计失败: %s", e2, exc_info=True)
# ---------------------------------------------------------------------------
# run_dedup — main entry point (Requirements 10.1, 10.3)
# ---------------------------------------------------------------------------
async def run_dedup(
entity_nodes: List[ExtractedEntityNode],
statement_entity_edges: List[StatementEntityEdge],
entity_entity_edges: List[EntityEntityEdge],
dialog_data_list: List[DialogData],
pipeline_config: ExtractionPipelineConfig,
connector: Optional[Neo4jConnector] = None,
llm_client: Optional[Any] = None,
is_pilot_run: bool = False,
progress_callback: Optional[Callable] = None,
) -> DedupResult:
"""Two-stage entity deduplication and disambiguation.
Full mode:
Layer 1 — exact / fuzzy / LLM matching
Layer 2 — Neo4j joint dedup + cross-role alias cleaning
Pilot-run mode:
Layer 1 only (skip Neo4j layer 2 and alias cleaning).
Args:
entity_nodes: Pre-dedup entity nodes.
statement_entity_edges: Pre-dedup statement-entity edges.
entity_entity_edges: Pre-dedup entity-entity edges.
dialog_data_list: Source dialogue data (used to detect end_user_id).
pipeline_config: Pipeline configuration (contains DedupConfig).
connector: Optional Neo4j connector for layer-2 dedup.
llm_client: Optional LLM client for LLM-based dedup decisions.
is_pilot_run: When True, only execute layer-1 dedup.
progress_callback: Optional async callable for progress reporting.
Returns:
A ``DedupResult`` with deduplicated nodes, edges, and statistics.
"""
logger.info("开始两阶段实体去重和消歧")
if progress_callback:
await progress_callback("deduplication", "正在去重消歧...")
logger.info(
"去重前: %d 个实体节点, %d 条陈述句-实体边, %d 条实体-实体边",
len(entity_nodes),
len(statement_entity_edges),
len(entity_entity_edges),
)
original_entity_count = len(entity_nodes)
original_stmt_edge_count = len(statement_entity_edges)
original_ent_edge_count = len(entity_entity_edges)
try:
if is_pilot_run:
# --- pilot run: layer 1 only ---
logger.info("试运行模式:仅执行第一层去重,跳过第二层数据库去重")
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
deduplicate_entities_and_edges,
)
(
dedup_entity_nodes,
dedup_stmt_edges,
dedup_ent_edges,
raw_details,
) = await deduplicate_entities_and_edges(
entity_nodes,
statement_entity_edges,
entity_entity_edges,
report_stage="第一层去重消歧(试运行)",
report_append=False,
dedup_config=pipeline_config.deduplication,
llm_client=llm_client,
)
final_entities = dedup_entity_nodes
final_stmt_edges = dedup_stmt_edges
final_ent_edges = dedup_ent_edges
else:
# --- full mode: two-stage dedup ---
from app.core.memory.storage_services.extraction_engine.deduplication.two_stage_dedup import (
dedup_layers_and_merge_and_return,
)
(
_dialogue_nodes,
_chunk_nodes,
_statement_nodes,
final_entities,
_statement_chunk_edges,
final_stmt_edges,
final_ent_edges,
raw_details,
) = await dedup_layers_and_merge_and_return(
dialogue_nodes=[],
chunk_nodes=[],
statement_nodes=[],
entity_nodes=entity_nodes,
statement_chunk_edges=[],
statement_entity_edges=statement_entity_edges,
entity_entity_edges=entity_entity_edges,
dialog_data_list=dialog_data_list,
pipeline_config=pipeline_config,
connector=connector,
llm_client=llm_client,
)
# Parse raw details into structured records
merge_records, disamb_records, _id_redirect = save_dedup_details(
raw_details, entity_nodes, final_entities
)
logger.info(
"去重后: %d 个实体节点, %d 条陈述句-实体边, %d 条实体-实体边",
len(final_entities),
len(final_stmt_edges),
len(final_ent_edges),
)
logger.info(
"去重效果: 实体减少 %d, 陈述句-实体边减少 %d, 实体-实体边减少 %d",
original_entity_count - len(final_entities),
original_stmt_edge_count - len(final_stmt_edges),
original_ent_edge_count - len(final_ent_edges),
)
# --- progress callbacks ---
if progress_callback:
merge_info = analyze_entity_merges(merge_records)
for i, detail in enumerate(merge_info[:5]):
dedup_result = {
"result_type": "entity_merge",
"merged_entity_name": detail["main_entity_name"],
"merged_count": detail["merged_count"],
"merge_progress": f"{i + 1}/{min(len(merge_info), 5)}",
"message": (
f"{detail['main_entity_name']}合并{detail['merged_count']}个:相似实体已合并"
),
}
await progress_callback("dedup_disambiguation_result", "实体去重中", dedup_result)
disamb_info = analyze_entity_disambiguation(disamb_records)
for i, detail in enumerate(disamb_info[:5]):
disamb_result = {
"result_type": "entity_disambiguation",
"disambiguated_entity_name": detail["entity_name"],
"disambiguation_type": detail["disamb_type"],
"confidence": detail.get("confidence", "unknown"),
"reason": detail.get("reason", ""),
"disamb_progress": f"{i + 1}/{min(len(disamb_info), 5)}",
"message": f"{detail['entity_name']}消歧完成:{detail['disamb_type']}",
}
await progress_callback("dedup_disambiguation_result", "实体消歧中", disamb_result)
await send_dedup_progress_callback(
progress_callback,
merge_records,
disamb_records,
original_entity_count,
len(final_entities),
original_stmt_edge_count,
len(final_stmt_edges),
original_ent_edge_count,
len(final_ent_edges),
)
return DedupResult(
entity_nodes=final_entities,
statement_entity_edges=final_stmt_edges,
entity_entity_edges=final_ent_edges,
dedup_details=raw_details,
merge_records=merge_records,
disamb_records=disamb_records,
)
except Exception as e:
logger.error("两阶段去重失败: %s", e, exc_info=True)
raise

View File

@@ -0,0 +1,16 @@
"""Extraction pipeline steps — unified ExtractionStep paradigm.
Importing this package triggers @register decorator self-registration
for all sidecar (non-critical) steps via SidecarStepFactory.
"""
from .sidecar_factory import SidecarStepFactory, SidecarTiming # noqa: F401
# Step implementations — importing triggers @register self-registration.
from .statement_step import StatementExtractionStep # noqa: F401
from .triplet_step import TripletExtractionStep # noqa: F401
from .emotion_step import EmotionExtractionStep # noqa: F401
from .embedding_step import EmbeddingStep # noqa: F401
# Refactored orchestrator
from .extraction_pipeline_orchestrator import NewExtractionOrchestrator # noqa: F401

View File

@@ -0,0 +1,182 @@
"""ExtractionStep abstract base class and StepContext.
Provides the unified paradigm for all LLM extraction stages:
render_prompt → call_llm → parse_response → post_process
Critical steps retry on failure with exponential backoff.
Sidecar (non-critical) steps return a default output on failure without retry.
"""
import asyncio
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Generic, Optional, TypeVar
logger = logging.getLogger(__name__)
InputT = TypeVar("InputT")
OutputT = TypeVar("OutputT")
@dataclass
class StepContext:
"""Shared context injected into every ExtractionStep by the orchestrator.
Attributes:
llm_client: LLM client instance for generating completions.
language: Target language code (e.g. "en", "zh").
config: Pipeline configuration object (ExtractionPipelineConfig).
is_pilot_run: When True, run in lightweight preview mode.
progress_callback: Optional callable for reporting progress.
"""
llm_client: Any
language: str
config: Any
is_pilot_run: bool = False
progress_callback: Optional[Any] = None
class ExtractionStep(ABC, Generic[InputT, OutputT]):
"""Abstract base class for all LLM extraction stages.
Lifecycle:
1. ``__init__(context)`` — receive shared context, bind config params
2. ``should_skip()`` — check whether to skip (config-driven / pilot mode)
3. ``run(input_data)`` — execute full flow (with retry for critical steps)
Internally: render_prompt → call_llm → parse_response → post_process
4. ``on_failure(error)`` — critical steps raise; sidecar steps return default
Type Parameters:
InputT: The Pydantic model type accepted by this step.
OutputT: The Pydantic model type produced by this step.
"""
def __init__(self, context: StepContext) -> None:
self.context = context
self.llm_client = context.llm_client
self.language = context.language
self.config = context.config
# ── Subclasses must implement ──
@property
@abstractmethod
def name(self) -> str:
"""Human-readable step name for logging."""
...
@abstractmethod
async def render_prompt(self, input_data: InputT) -> Any:
"""Build the prompt from *input_data* and bound config."""
...
@abstractmethod
async def call_llm(self, prompt: Any) -> Any:
"""Send *prompt* to the LLM and return the raw response."""
...
@abstractmethod
async def parse_response(self, raw_response: Any, input_data: InputT) -> OutputT:
"""Parse *raw_response* into a typed OutputT (Pydantic model)."""
...
@abstractmethod
def get_default_output(self) -> OutputT:
"""Return a safe default when the step is skipped or fails gracefully."""
...
# ── Overridable properties ──
@property
def is_critical(self) -> bool:
"""``True`` = critical step (failure aborts pipeline).
``False`` = sidecar step (failure degrades gracefully).
"""
return True
@property
def max_retries(self) -> int:
"""Maximum retry attempts (only effective for critical steps)."""
return 2
@property
def retry_backoff_base(self) -> float:
"""Backoff base in seconds. Actual wait = base × 2^attempt."""
return 1.0
# ── Overridable hooks ──
def should_skip(self) -> bool:
"""Config-driven skip check. Subclasses may override."""
return False
async def post_process(self, parsed_data: OutputT, input_data: InputT) -> OutputT:
"""Post-processing hook. Default is identity (returns *parsed_data* unchanged)."""
return parsed_data
# ── Core execution logic ──
async def run(self, input_data: InputT) -> OutputT:
"""Execute the full step lifecycle with retry logic.
For critical steps (``is_critical=True``):
Attempt up to ``max_retries + 1`` times with exponential backoff.
If all attempts fail, delegate to ``on_failure`` which raises.
For sidecar steps (``is_critical=False``):
Attempt exactly once. On failure, delegate to ``on_failure``
which returns ``get_default_output()``.
"""
if self.should_skip():
logger.info("Step '%s' skipped", self.name)
return self.get_default_output()
last_error: Optional[Exception] = None
attempts = self.max_retries + 1 if self.is_critical else 1
for attempt in range(attempts):
try:
prompt = await self.render_prompt(input_data)
raw_response = await self.call_llm(prompt)
parsed = await self.parse_response(raw_response, input_data)
result = await self.post_process(parsed, input_data)
return result
except Exception as exc:
last_error = exc
logger.warning(
"Step '%s' attempt %d/%d failed: %s",
self.name,
attempt + 1,
attempts,
exc,
)
if attempt < attempts - 1:
wait = self.retry_backoff_base * (2 ** attempt)
logger.info(
"Step '%s' retrying in %.1fs …", self.name, wait
)
await asyncio.sleep(wait)
# All attempts exhausted — delegate to failure handler
return self.on_failure(last_error) # type: ignore[arg-type]
def on_failure(self, error: Exception) -> OutputT:
"""Handle step failure.
Critical steps: re-raise the exception to abort the pipeline.
Sidecar steps: return ``get_default_output()`` for graceful degradation.
"""
if self.is_critical:
logger.error(
"Critical step '%s' failed after retries: %s", self.name, error
)
raise error
logger.warning(
"Sidecar step '%s' failed, returning default output: %s",
self.name,
error,
)
return self.get_default_output()

View File

@@ -0,0 +1,124 @@
"""EmbeddingStep — generates vector embeddings for statements, chunks, dialogs, and entities.
Unlike the LLM-based ExtractionSteps, EmbeddingStep calls an embedder client
rather than an LLM. It still follows the ``should_skip`` / ``run`` /
``get_default_output`` contract so the orchestrator can treat it uniformly.
Supports **partial** embedding runs — the caller can populate only the fields
it needs (e.g. only ``statement_texts``) and leave the rest empty.
"""
import asyncio
import logging
from typing import Any, Dict, List
from .schema import EmbeddingStepInput, EmbeddingStepOutput
logger = logging.getLogger(__name__)
class EmbeddingStep:
"""Generate vector embeddings for text inputs.
This step does **not** inherit from ``ExtractionStep`` because it does not
follow the render_prompt → call_llm → parse_response lifecycle. It does,
however, expose the same ``run`` / ``should_skip`` / ``get_default_output``
interface so the orchestrator can use it interchangeably.
Pilot-run mode skips execution entirely and returns empty dicts.
"""
def __init__(
self,
embedder_client: Any,
is_pilot_run: bool = False,
batch_size: int = 100,
) -> None:
self.embedder_client = embedder_client
self.is_pilot_run = is_pilot_run
self.batch_size = batch_size
@property
def name(self) -> str:
return "embedding_generation"
@property
def is_critical(self) -> bool:
return False
@property
def max_retries(self) -> int:
return 1
@property
def retry_backoff_base(self) -> float:
return 1.0
def should_skip(self) -> bool:
return self.is_pilot_run
def get_default_output(self) -> EmbeddingStepOutput:
return EmbeddingStepOutput()
# ── Core execution ──
async def run(self, input_data: EmbeddingStepInput) -> EmbeddingStepOutput:
"""Generate embeddings for all non-empty text fields in *input_data*."""
if self.should_skip():
logger.info("EmbeddingStep skipped (pilot run)")
return self.get_default_output()
try:
stmt_emb, chunk_emb, dialog_emb, entity_emb = await asyncio.gather(
self._embed_dict(input_data.statement_texts),
self._embed_dict(input_data.chunk_texts),
self._embed_list(input_data.dialog_texts),
self._embed_dict(input_data.entity_names),
)
return EmbeddingStepOutput(
statement_embeddings=stmt_emb,
chunk_embeddings=chunk_emb,
dialog_embeddings=dialog_emb,
entity_embeddings=entity_emb,
)
except Exception as exc:
logger.warning("EmbeddingStep failed, returning empty output: %s", exc)
return self.get_default_output()
# ── Internal helpers ──
async def _embed_dict(
self, texts: Dict[str, str]
) -> Dict[str, List[float]]:
"""Embed a dict of ``{id: text}`` and return ``{id: embedding}``."""
if not texts:
return {}
ids = list(texts.keys())
text_list = list(texts.values())
embeddings = await self._batch_embed(text_list)
return dict(zip(ids, embeddings))
async def _embed_list(self, texts: List[str]) -> List[List[float]]:
"""Embed a plain list of texts."""
if not texts:
return []
return await self._batch_embed(texts)
async def _batch_embed(self, texts: List[str]) -> List[List[float]]:
"""Call the embedder in batches of ``self.batch_size``."""
if len(texts) <= self.batch_size:
return await self.embedder_client.response(texts)
batches = [
texts[i : i + self.batch_size]
for i in range(0, len(texts), self.batch_size)
]
batch_results = await asyncio.gather(
*(self.embedder_client.response(b) for b in batches)
)
embeddings: List[List[float]] = []
for result in batch_results:
embeddings.extend(result)
return embeddings

View File

@@ -0,0 +1,80 @@
"""EmotionExtractionStep — sidecar step for extracting emotion from statements.
Replaces the legacy ``EmotionExtractionService`` with the unified ExtractionStep
paradigm. Registered via ``@SidecarStepFactory.register`` so the orchestrator
picks it up automatically when ``emotion_enabled`` is ``True``.
"""
import logging
from typing import Any
from app.core.memory.models.emotion_models import EmotionExtraction
from app.core.memory.utils.prompt.prompt_utils import render_emotion_extraction_prompt
from .base import ExtractionStep, StepContext
from .sidecar_factory import SidecarStepFactory, SidecarTiming
from .schema import EmotionStepInput, EmotionStepOutput
logger = logging.getLogger(__name__)
@SidecarStepFactory.register("emotion_enabled", SidecarTiming.AFTER_STATEMENT)
class EmotionExtractionStep(ExtractionStep[EmotionStepInput, EmotionStepOutput]):
"""Extract emotion type, intensity, and keywords from a statement.
This is a **sidecar** (non-critical) step — failure returns a neutral
default without aborting the pipeline.
The step self-registers with ``SidecarStepFactory`` under the config key
``emotion_enabled`` and timing ``AFTER_STATEMENT``.
"""
def __init__(self, context: StepContext) -> None:
super().__init__(context)
# Emotion-specific config flags (may live on a MemoryConfig object
# attached to context.config or as top-level attributes).
self.extract_keywords = getattr(self.config, "emotion_extract_keywords", True)
self.enable_subject = getattr(self.config, "emotion_enable_subject", False)
# ── Identity ──
@property
def name(self) -> str:
return "emotion_extraction"
@property
def is_critical(self) -> bool:
return False
# ── Config-driven skip ──
def should_skip(self) -> bool:
return not getattr(self.config, "emotion_enabled", False)
# ── Lifecycle ──
async def render_prompt(self, input_data: EmotionStepInput) -> str:
return await render_emotion_extraction_prompt(
statement=input_data.statement_text,
extract_keywords=self.extract_keywords,
enable_subject=self.enable_subject,
language=self.language,
)
async def call_llm(self, prompt: Any) -> Any:
messages = [{"role": "user", "content": prompt}]
return await self.llm_client.response_structured(
messages, EmotionExtraction
)
async def parse_response(
self, raw_response: Any, input_data: EmotionStepInput
) -> EmotionStepOutput:
return EmotionStepOutput(
emotion_type=getattr(raw_response, "emotion_type", "neutral"),
emotion_intensity=getattr(raw_response, "emotion_intensity", 0.0),
emotion_keywords=getattr(raw_response, "emotion_keywords", []),
)
def get_default_output(self) -> EmotionStepOutput:
return EmotionStepOutput()

View File

@@ -0,0 +1,906 @@
"""Refactored ExtractionOrchestrator using the unified ExtractionStep paradigm.
This module provides ``NewExtractionOrchestrator`` — a slimmed-down orchestrator
(~500 lines vs ~2500) that delegates extraction work to concrete ExtractionStep
instances and uses SidecarStepFactory for hot-pluggable sidecar modules.
The new orchestrator coexists with the legacy ``ExtractionOrchestrator`` until
the team explicitly switches over.
Execution phases:
1. Statement extraction + concurrent chunk/dialog embedding
2. Triplet extraction + concurrent after_statement sidecars + statement embedding
3. Entity embedding + concurrent after_triplet sidecars
4. Data assignment back to dialog_data_list
"""
from __future__ import annotations
import asyncio
import logging
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
from app.core.memory.models.message_models import DialogData
from app.core.memory.models.variate_config import ExtractionPipelineConfig
from .base import ExtractionStep, StepContext
from .embedding_step import EmbeddingStep
from .sidecar_factory import SidecarStepFactory, SidecarTiming
from .statement_step import StatementExtractionStep
from .triplet_step import TripletExtractionStep
from .schema import (
EmbeddingStepInput,
EmbeddingStepOutput,
EmotionStepInput,
EmotionStepOutput,
MessageItem,
StatementStepInput,
StatementStepOutput,
SupportingContext,
TripletStepInput,
TripletStepOutput,
)
logger = logging.getLogger(__name__)
class NewExtractionOrchestrator:
"""Slimmed-down extraction orchestrator using the ExtractionStep paradigm.
Responsibilities:
* Initialise all steps and sidecar groups via ``SidecarStepFactory``
* Route data between stages (``_convert_to_*`` helpers)
* Orchestrate concurrent execution (``_run_with_sidecars``)
* Assign extracted results back to ``DialogData`` objects
The orchestrator does **not** own dedup, node/edge creation, or Neo4j writes.
Those remain in ``WritePipeline`` / ``dedup_step``.
"""
def __init__(
self,
llm_client: Any,
embedder_client: Any,
config: Optional[ExtractionPipelineConfig] = None,
embedding_id: Optional[str] = None,
ontology_types: Any = None,
language: str = "zh",
is_pilot_run: bool = False,
progress_callback: Optional[
Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]
] = None,
) -> None:
self.config = config or ExtractionPipelineConfig()
self.is_pilot_run = is_pilot_run
self.embedding_id = embedding_id
self.progress_callback = progress_callback
# Build shared context for all LLM-based steps
self.context = StepContext(
llm_client=llm_client,
language=language,
config=self.config,
is_pilot_run=is_pilot_run,
progress_callback=progress_callback,
)
# ── Critical (main-line) steps ──
self.statement_step = StatementExtractionStep(self.context)
self.triplet_step = TripletExtractionStep(
self.context, ontology_types=ontology_types
)
# ── Embedding step (non-LLM, separate client) ──
self.embedding_step = EmbeddingStep(
embedder_client=embedder_client,
is_pilot_run=is_pilot_run,
)
# ── Sidecar steps (auto-discovered via @register decorator) ──
sidecar_groups = SidecarStepFactory.create_sidecars(self.config, self.context)
self.after_statement_sidecars: List[ExtractionStep] = sidecar_groups[
SidecarTiming.AFTER_STATEMENT
]
self.after_triplet_sidecars: List[ExtractionStep] = sidecar_groups[
SidecarTiming.AFTER_TRIPLET
]
logger.info(
"NewExtractionOrchestrator initialised — "
"after_statement sidecars: %d, after_triplet sidecars: %d",
len(self.after_statement_sidecars),
len(self.after_triplet_sidecars),
)
# ──────────────────────────────────────────────
# 1. 并发执行引擎
# 负责主线路 + 旁路的安全并发调度
# ──────────────────────────────────────────────
@staticmethod
async def _run_sidecar_safe(
step: ExtractionStep, input_data: Any
) -> Any:
"""Run a sidecar step, returning its default output on failure."""
try:
return await step.run(input_data)
except Exception as exc:
logger.warning(
"Sidecar '%s' raised during gather — using default output: %s",
step.name,
exc,
)
return step.get_default_output()
async def _run_with_sidecars(
self,
critical_coro: Any,
sidecars: List[Tuple[ExtractionStep, Any]],
extra_coros: Optional[List[Any]] = None,
) -> Tuple[Any, List[Any], List[Any]]:
"""Run a critical coroutine concurrently with sidecar steps.
Args:
critical_coro: The awaitable for the critical (main-line) step.
sidecars: List of ``(step, input_data)`` pairs for sidecar steps.
extra_coros: Additional non-sidecar coroutines to run concurrently
(e.g. embedding generation).
Returns:
A 3-tuple of:
* The critical step result (exception propagated if it fails).
* A list of sidecar results (default outputs on failure).
* A list of extra coroutine results (empty list if none).
Raises:
Exception: If the critical coroutine fails, the exception propagates.
"""
sidecar_coros = [
self._run_sidecar_safe(step, inp) for step, inp in sidecars
]
extra = extra_coros or []
# Gather everything concurrently
all_coros = [critical_coro] + sidecar_coros + extra
results = await asyncio.gather(*all_coros, return_exceptions=True)
# Unpack: first result is critical, then sidecars, then extras
critical_result = results[0]
n_sidecars = len(sidecar_coros)
sidecar_results = list(results[1 : 1 + n_sidecars])
extra_results = list(results[1 + n_sidecars :])
# Critical step failure → propagate
if isinstance(critical_result, BaseException):
raise critical_result
# Sidecar failures should already be handled by _run_sidecar_safe,
# but guard against unexpected exceptions from gather
for i, res in enumerate(sidecar_results):
if isinstance(res, BaseException):
step = sidecars[i][0]
logger.warning(
"Sidecar '%s' unexpected exception in gather: %s",
step.name,
res,
)
sidecar_results[i] = step.get_default_output()
# Extra coroutine failures → log and replace with None
for i, res in enumerate(extra_results):
if isinstance(res, BaseException):
logger.warning("Extra coroutine %d failed: %s", i, res)
extra_results[i] = None
return critical_result, sidecar_results, extra_results
# ──────────────────────────────────────────────
# 2. 阶段间数据转换
# 将上一阶段的 StepOutput 转换为下一阶段的 StepInput
# ──────────────────────────────────────────────
@staticmethod
def _build_supporting_context(
dialog: DialogData,
) -> SupportingContext:
"""Build a SupportingContext from a dialog's content for pronoun resolution."""
msgs: List[MessageItem] = []
if hasattr(dialog, "content") and dialog.content:
# dialog.content is the raw conversation string; wrap as single msg
msgs.append(MessageItem(role="context", msg=dialog.content))
return SupportingContext(msgs=msgs)
@staticmethod
def _convert_to_triplet_input(
stmt_out: StatementStepOutput,
supporting_context: SupportingContext,
) -> TripletStepInput:
"""Convert a StatementStepOutput into a TripletStepInput."""
return TripletStepInput(
statement_id=stmt_out.statement_id,
statement_text=stmt_out.statement_text,
statement_type=stmt_out.statement_type,
temporal_type=stmt_out.temporal_type,
supporting_context=supporting_context,
speaker=stmt_out.speaker,
valid_at=stmt_out.valid_at,
invalid_at=stmt_out.invalid_at,
)
@staticmethod
def _convert_to_emotion_input(
stmt_out: StatementStepOutput,
) -> EmotionStepInput:
"""Convert a StatementStepOutput into an EmotionStepInput."""
return EmotionStepInput(
statement_id=stmt_out.statement_id,
statement_text=stmt_out.statement_text,
speaker=stmt_out.speaker,
)
# ──────────────────────────────────────────────
# 3. 流水线执行入口
# 公开接口 run() → 分发到 pilot / full 模式
# ──────────────────────────────────────────────
async def run(
self,
dialog_data_list: List[DialogData],
) -> List[DialogData]:
"""Run the full extraction pipeline on *dialog_data_list*.
Returns the mutated *dialog_data_list* with extracted data assigned
to each statement (triplets, temporal info, emotions, embeddings).
The orchestrator does NOT create graph nodes/edges or run dedup —
those responsibilities remain in WritePipeline.
"""
mode = "pilot" if self.is_pilot_run else "full"
logger.info(
"Starting extraction pipeline (%s mode), %d dialogs",
mode,
len(dialog_data_list),
)
if self.is_pilot_run:
return await self._run_pilot(dialog_data_list)
return await self._run_full(dialog_data_list)
# ── 3a. 试运行模式:仅 statement + triplet不生成 embedding 和旁路 ──
async def _run_pilot(
self, dialog_data_list: List[DialogData]
) -> List[DialogData]:
"""Pilot mode: statement + triplet extraction only, no sidecars or embeddings."""
# Phase 1: Statement extraction (chunk-level parallel)
logger.info("Pilot phase 1/2: Statement extraction")
all_stmt_results = await self._extract_all_statements(dialog_data_list)
# Phase 2: Triplet extraction (statement-level parallel)
logger.info("Pilot phase 2/2: Triplet extraction")
all_triplet_results = await self._extract_all_triplets(
dialog_data_list, all_stmt_results
)
# Assign results back to dialog_data_list
self._assign_results(
dialog_data_list,
all_stmt_results,
all_triplet_results,
emotion_results={},
embedding_output=None,
)
# Store raw step outputs for snapshot/debugging
self._last_stage_outputs = {
"statement_results": all_stmt_results,
"triplet_results": all_triplet_results,
"emotion_results": {},
"embedding_output": None,
}
logger.info("Pilot extraction complete")
return dialog_data_list
# ── 3b. 正式模式:四阶段并发执行 ──
async def _run_full(
self, dialog_data_list: List[DialogData]
) -> List[DialogData]:
"""Full mode: all four phases with concurrent sidecars and embeddings."""
# ── Phase 1: Statement extraction + chunk/dialog embedding ──
logger.info("Phase 1/4: Statement extraction + chunk/dialog embedding")
chunk_dialog_emb_input = self._build_chunk_dialog_embedding_input(
dialog_data_list
)
stmt_coro = self._extract_all_statements(dialog_data_list)
emb_coro = self.embedding_step.run(chunk_dialog_emb_input)
phase1_results = await asyncio.gather(
stmt_coro, emb_coro, return_exceptions=True
)
all_stmt_results: Dict[str, Dict[str, List[StatementStepOutput]]] = (
phase1_results[0]
if not isinstance(phase1_results[0], BaseException)
else {}
)
if isinstance(phase1_results[0], BaseException):
raise phase1_results[0]
chunk_dialog_emb: Optional[EmbeddingStepOutput] = (
phase1_results[1]
if not isinstance(phase1_results[1], BaseException)
else None
)
if isinstance(phase1_results[1], BaseException):
logger.warning("Chunk/dialog embedding failed: %s", phase1_results[1])
# ── Phase 2: Triplet extraction + after_statement sidecars + statement embedding ──
logger.info(
"Phase 2/4: Triplet extraction + sidecars + statement embedding"
)
stmt_emb_input = self._build_statement_embedding_input(
dialog_data_list, all_stmt_results
)
# Build sidecar inputs for after_statement sidecars (e.g. emotion)
sidecar_pairs = self._build_after_statement_sidecar_inputs(
dialog_data_list, all_stmt_results
)
triplet_coro = self._extract_all_triplets(
dialog_data_list, all_stmt_results
)
stmt_emb_coro = self.embedding_step.run(stmt_emb_input)
triplet_results, sidecar_results, extra_results = (
await self._run_with_sidecars(
triplet_coro,
sidecar_pairs,
extra_coros=[stmt_emb_coro],
)
)
all_triplet_results = triplet_results
stmt_emb: Optional[EmbeddingStepOutput] = (
extra_results[0] if extra_results else None
)
# Collect sidecar outputs keyed by step name
sidecar_steps = [step for step, _inp in sidecar_pairs]
sidecar_output_map = self._collect_sidecar_outputs(
sidecar_steps, sidecar_results
)
# ── Phase 3: Entity embedding + after_triplet sidecars ──
logger.info("Phase 3/4: Entity embedding + after_triplet sidecars")
entity_emb_input = self._build_entity_embedding_input(all_triplet_results)
after_triplet_pairs: List[Tuple[ExtractionStep, Any]] = []
# Future after_triplet sidecars would be wired here
entity_emb_coro = self.embedding_step.run(entity_emb_input)
if after_triplet_pairs:
_, at_sidecar_results, at_extra = await self._run_with_sidecars(
entity_emb_coro,
after_triplet_pairs,
)
entity_emb = at_extra[0] if at_extra else None
else:
# No after_triplet sidecars — just run embedding
entity_emb_result = await entity_emb_coro
entity_emb = (
entity_emb_result
if not isinstance(entity_emb_result, BaseException)
else None
)
# Merge all embedding outputs
merged_emb = self._merge_embeddings(chunk_dialog_emb, stmt_emb, entity_emb)
# ── Phase 4: Data assignment ──
logger.info("Phase 4/4: Data assignment")
emotion_results = sidecar_output_map.get("emotion_extraction", {})
self._assign_results(
dialog_data_list,
all_stmt_results,
all_triplet_results,
emotion_results=emotion_results,
embedding_output=merged_emb,
)
# Store raw step outputs for snapshot/debugging
self._last_stage_outputs = {
"statement_results": all_stmt_results,
"triplet_results": all_triplet_results,
"emotion_results": emotion_results,
"embedding_output": merged_emb,
}
logger.info("Full extraction pipeline complete")
return dialog_data_list
@property
def last_stage_outputs(self) -> Dict[str, Any]:
"""Return the raw step outputs from the last run for snapshot/debugging."""
return getattr(self, "_last_stage_outputs", {})
# ──────────────────────────────────────────────
# 4. 萃取执行器
# chunk 级并行 statement 提取、statement 级并行 triplet 提取
# ──────────────────────────────────────────────
async def _extract_all_statements(
self,
dialog_data_list: List[DialogData],
) -> Dict[str, Dict[str, List[StatementStepOutput]]]:
"""Extract statements from all chunks across all dialogs (chunk-level parallel).
Returns:
Nested dict: ``{dialog_id: {chunk_id: [StatementStepOutput, ...]}}``
"""
# Collect all (chunk, metadata) pairs
tasks: List[Any] = []
task_meta: List[Tuple[str, str, str, SupportingContext]] = []
for dialog in dialog_data_list:
ctx = self._build_supporting_context(dialog)
dialogue_content = (
dialog.content
if getattr(
self.config, "statement_extraction", None
)
and getattr(
self.config.statement_extraction,
"include_dialogue_context",
True,
)
else None
)
for chunk in dialog.chunks:
inp = StatementStepInput(
chunk_id=chunk.id,
end_user_id=dialog.end_user_id,
target_content=chunk.content,
target_message_date=str(
getattr(dialog, "created_at", "") or ""
),
supporting_context=ctx,
)
tasks.append(self.statement_step.run(inp))
task_meta.append(
(dialog.id, chunk.id, getattr(chunk, "speaker", "user"), ctx)
)
results = await asyncio.gather(*tasks, return_exceptions=True)
# Organise into nested dict
stmt_map: Dict[str, Dict[str, List[StatementStepOutput]]] = {}
for i, result in enumerate(results):
dialog_id, chunk_id, speaker, _ = task_meta[i]
if dialog_id not in stmt_map:
stmt_map[dialog_id] = {}
if isinstance(result, BaseException):
logger.error("Statement extraction failed for chunk %s: %s", chunk_id, result)
stmt_map[dialog_id][chunk_id] = []
else:
# Override speaker from chunk metadata
stmts: List[StatementStepOutput] = result if isinstance(result, list) else []
for s in stmts:
s.speaker = speaker
stmt_map[dialog_id][chunk_id] = stmts
return stmt_map
async def _extract_all_triplets(
self,
dialog_data_list: List[DialogData],
all_stmt_results: Dict[str, Dict[str, List[StatementStepOutput]]],
) -> Dict[str, Dict[str, TripletStepOutput]]:
"""Extract triplets for every statement (statement-level parallel).
Returns:
Nested dict: ``{dialog_id: {statement_id: TripletStepOutput}}``
"""
tasks: List[Any] = []
task_meta: List[Tuple[str, str]] = [] # (dialog_id, statement_id)
for dialog in dialog_data_list:
ctx = self._build_supporting_context(dialog)
chunk_stmts = all_stmt_results.get(dialog.id, {})
for _chunk_id, stmts in chunk_stmts.items():
for stmt in stmts:
inp = self._convert_to_triplet_input(stmt, ctx)
tasks.append(self.triplet_step.run(inp))
task_meta.append((dialog.id, stmt.statement_id))
results = await asyncio.gather(*tasks, return_exceptions=True)
triplet_map: Dict[str, Dict[str, TripletStepOutput]] = {}
for i, result in enumerate(results):
dialog_id, stmt_id = task_meta[i]
if dialog_id not in triplet_map:
triplet_map[dialog_id] = {}
if isinstance(result, BaseException):
logger.error(
"Triplet extraction failed for statement %s: %s",
stmt_id,
result,
)
triplet_map[dialog_id][stmt_id] = self.triplet_step.get_default_output()
else:
triplet_map[dialog_id][stmt_id] = result
return triplet_map
# ──────────────────────────────────────────────
# 5. Embedding 输入构建器
# 为不同阶段构建 EmbeddingStepInputchunk/statement/entity
# ──────────────────────────────────────────────
@staticmethod
def _build_chunk_dialog_embedding_input(
dialog_data_list: List[DialogData],
) -> EmbeddingStepInput:
"""Build embedding input for chunks and dialogs (phase 1)."""
chunk_texts: Dict[str, str] = {}
dialog_texts: List[str] = []
for dialog in dialog_data_list:
if hasattr(dialog, "content") and dialog.content:
dialog_texts.append(dialog.content)
for chunk in dialog.chunks:
chunk_texts[chunk.id] = chunk.content
return EmbeddingStepInput(
chunk_texts=chunk_texts,
dialog_texts=dialog_texts,
)
@staticmethod
def _build_statement_embedding_input(
dialog_data_list: List[DialogData],
all_stmt_results: Dict[str, Dict[str, List[StatementStepOutput]]],
) -> EmbeddingStepInput:
"""Build embedding input for statements (phase 2)."""
stmt_texts: Dict[str, str] = {}
for _dialog_id, chunk_stmts in all_stmt_results.items():
for _chunk_id, stmts in chunk_stmts.items():
for s in stmts:
stmt_texts[s.statement_id] = s.statement_text
return EmbeddingStepInput(statement_texts=stmt_texts)
@staticmethod
def _build_entity_embedding_input(
all_triplet_results: Dict[str, Dict[str, TripletStepOutput]],
) -> EmbeddingStepInput:
"""Build embedding input for entities (phase 3)."""
entity_names: Dict[str, str] = {}
entity_descs: Dict[str, str] = {}
seen: set = set()
for _dialog_id, stmt_triplets in all_triplet_results.items():
for _stmt_id, triplet_out in stmt_triplets.items():
for ent in triplet_out.entities:
key = f"{ent.entity_idx}_{ent.name}"
if key not in seen:
seen.add(key)
entity_names[key] = ent.name
entity_descs[key] = ent.description
return EmbeddingStepInput(
entity_names=entity_names,
entity_descriptions=entity_descs,
)
# ──────────────────────────────────────────────
# 6. 旁路输入构建与结果收集
# 为 after_statement / after_triplet 旁路构建输入,合并 embedding 输出
# ──────────────────────────────────────────────
def _build_after_statement_sidecar_inputs(
self,
dialog_data_list: List[DialogData],
all_stmt_results: Dict[str, Dict[str, List[StatementStepOutput]]],
) -> List[Tuple[ExtractionStep, Any]]:
"""Build (step, input) pairs for after_statement sidecars.
For emotion extraction, we create a batch wrapper that runs the sidecar
on every user statement concurrently and returns a dict of results.
"""
if not self.after_statement_sidecars:
return []
# Collect all user statements for sidecar processing
all_user_stmts: List[StatementStepOutput] = []
for _dialog_id, chunk_stmts in all_stmt_results.items():
for _chunk_id, stmts in chunk_stmts.items():
for s in stmts:
if s.speaker == "user":
all_user_stmts.append(s)
pairs: List[Tuple[ExtractionStep, Any]] = []
for sidecar in self.after_statement_sidecars:
if sidecar.name == "emotion_extraction":
# Emotion sidecar: wrap as batch coroutine via a sentinel input
# The actual per-statement calls happen inside _run_emotion_batch
pairs.append((
_EmotionBatchWrapper(sidecar, all_user_stmts),
EmotionStepInput(
statement_id="__batch__",
statement_text="",
speaker="",
),
))
else:
# Generic sidecar: pass first statement as representative input
if all_user_stmts:
inp = self._convert_to_emotion_input(all_user_stmts[0])
pairs.append((sidecar, inp))
return pairs
@staticmethod
def _collect_sidecar_outputs(
sidecars: List[ExtractionStep],
results: List[Any],
) -> Dict[str, Any]:
"""Map sidecar results by step name."""
output: Dict[str, Any] = {}
for i, sidecar in enumerate(sidecars):
if i < len(results):
output[sidecar.name] = results[i]
return output
@staticmethod
def _merge_embeddings(
chunk_dialog: Optional[EmbeddingStepOutput],
statement: Optional[EmbeddingStepOutput],
entity: Optional[Any],
) -> Optional[EmbeddingStepOutput]:
"""Merge partial embedding outputs into a single EmbeddingStepOutput."""
merged = EmbeddingStepOutput()
if chunk_dialog:
merged.chunk_embeddings = chunk_dialog.chunk_embeddings
merged.dialog_embeddings = chunk_dialog.dialog_embeddings
if statement:
merged.statement_embeddings = statement.statement_embeddings
if entity and isinstance(entity, EmbeddingStepOutput):
merged.entity_embeddings = entity.entity_embeddings
return merged
# ──────────────────────────────────────────────
# 7. 数据赋值
# 将各阶段 StepOutput 组装为 Statement 对象,替换 chunk.statements
# ──────────────────────────────────────────────
def _assign_results(
self,
dialog_data_list: List[DialogData],
all_stmt_results: Dict[str, Dict[str, List[StatementStepOutput]]],
all_triplet_results: Dict[str, Dict[str, TripletStepOutput]],
emotion_results: Dict[str, EmotionStepOutput],
embedding_output: Optional[EmbeddingStepOutput],
) -> None:
"""Assign extraction results back to dialog_data_list in-place.
Replaces chunk.statements with new Statement objects built from step
outputs, because the new orchestrator generates its own statement IDs
that don't match the original chunk statement IDs.
"""
from app.core.memory.models.message_models import (
Statement,
TemporalValidityRange,
)
from app.core.memory.models.triplet_models import (
TripletExtractionResponse,
Entity as TripletEntity,
Triplet as TripletRelation,
)
from app.core.memory.utils.data.ontology import (
RelevenceInfo,
StatementType,
TemporalInfo,
)
# Map string values to enums
_STMT_TYPE_MAP = {
"FACT": StatementType.FACT,
"OPINION": StatementType.OPINION,
"PREDICTION": StatementType.PREDICTION,
"SUGGESTION": StatementType.SUGGESTION,
}
_TEMPORAL_MAP = {
"STATIC": TemporalInfo.STATIC,
"DYNAMIC": TemporalInfo.DYNAMIC,
"ATEMPORAL": TemporalInfo.ATEMPORAL,
}
total_stmts = 0
assigned_triplets = 0
assigned_emotions = 0
assigned_stmt_emb = 0
assigned_chunk_emb = 0
assigned_dialog_emb = 0
for dialog in dialog_data_list:
dialog_stmts = all_stmt_results.get(dialog.id, {})
dialog_triplets = all_triplet_results.get(dialog.id, {})
# Assign dialog embedding
if embedding_output and embedding_output.dialog_embeddings:
idx = dialog_data_list.index(dialog)
if idx < len(embedding_output.dialog_embeddings):
dialog.dialog_embedding = embedding_output.dialog_embeddings[idx]
assigned_dialog_emb += 1
for chunk in dialog.chunks:
# Assign chunk embedding
if embedding_output and chunk.id in embedding_output.chunk_embeddings:
chunk.chunk_embedding = embedding_output.chunk_embeddings[chunk.id]
assigned_chunk_emb += 1
# Build new Statement objects from step outputs
chunk_stmt_outputs = dialog_stmts.get(chunk.id, [])
new_statements = []
for stmt_out in chunk_stmt_outputs:
total_stmts += 1
# Temporal validity
valid_at = stmt_out.valid_at if stmt_out.valid_at != "NULL" else None
invalid_at = stmt_out.invalid_at if stmt_out.invalid_at != "NULL" else None
# Triplet info
triplet_info = None
triplet_out = dialog_triplets.get(stmt_out.statement_id)
if triplet_out and (triplet_out.entities or triplet_out.triplets):
entities = [
TripletEntity(
entity_idx=e.entity_idx,
name=e.name,
type=e.type,
description=e.description,
is_explicit_memory=e.is_explicit_memory,
)
for e in triplet_out.entities
]
triplets = [
TripletRelation(
subject_name=t.subject_name,
subject_id=t.subject_id,
predicate=t.predicate,
object_name=t.object_name,
object_id=t.object_id,
)
for t in triplet_out.triplets
]
triplet_info = TripletExtractionResponse(
entities=entities, triplets=triplets,
)
assigned_triplets += 1
# Emotion info
emo = emotion_results.get(stmt_out.statement_id)
emotion_kwargs = {}
if emo:
emotion_kwargs = {
"emotion_type": emo.emotion_type,
"emotion_intensity": emo.emotion_intensity,
"emotion_keywords": emo.emotion_keywords,
}
assigned_emotions += 1
# Statement embedding
stmt_embedding = None
if (
embedding_output
and stmt_out.statement_id in embedding_output.statement_embeddings
):
stmt_embedding = embedding_output.statement_embeddings[stmt_out.statement_id]
assigned_stmt_emb += 1
# Build the Statement object that _create_nodes_and_edges expects
stmt = Statement(
id=stmt_out.statement_id,
chunk_id=chunk.id,
end_user_id=dialog.end_user_id,
statement=stmt_out.statement_text,
speaker=stmt_out.speaker,
stmt_type=_STMT_TYPE_MAP.get(stmt_out.statement_type, StatementType.FACT),
temporal_info=_TEMPORAL_MAP.get(stmt_out.temporal_type, TemporalInfo.ATEMPORAL),
relevence_info=RelevenceInfo.RELEVANT if stmt_out.relevance == "RELEVANT" else RelevenceInfo.IRRELEVANT,
temporal_validity=TemporalValidityRange(valid_at=valid_at, invalid_at=invalid_at),
triplet_extraction_info=triplet_info,
statement_embedding=stmt_embedding,
**emotion_kwargs,
)
new_statements.append(stmt)
# Replace chunk.statements with newly built objects
chunk.statements = new_statements
logger.info(
"Data assignment complete — statements: %d, triplets: %d, "
"emotions: %d, stmt_emb: %d, chunk_emb: %d, dialog_emb: %d",
total_stmts,
assigned_triplets,
assigned_emotions,
assigned_stmt_emb,
assigned_chunk_emb,
assigned_dialog_emb,
)
class _EmotionBatchWrapper(ExtractionStep):
"""情绪批量提取包装器。再考虑一下用法,这是子类?
将单条情绪旁路 Step 包装为批量并发执行,适配 ``_run_with_sidecars`` 接口。
编排器传入一个 sentinel input``run()`` 忽略它,转而对预收集的 statement 列表
逐条并发调用内部 Step返回 ``{statement_id: EmotionStepOutput}`` 字典。
"""
# ── 初始化 ──
def __init__(
self,
inner_step: ExtractionStep,
statements: List[StatementStepOutput],
) -> None:
# 不调用 super().__init__() — 本类是薄包装,不需要 StepContext
self._inner = inner_step
self._statements = statements
# ── Step 身份属性(满足 ExtractionStep 抽象接口) ──
@property
def name(self) -> str:
return self._inner.name
@property
def is_critical(self) -> bool:
return False
def get_default_output(self) -> Dict[str, EmotionStepOutput]:
return {}
# ── 未使用的生命周期方法(批量包装器不走 render→call→parse 流程) ──
async def render_prompt(self, input_data: Any) -> Any:
raise NotImplementedError
async def call_llm(self, prompt: Any) -> Any:
raise NotImplementedError
async def parse_response(self, raw_response: Any, input_data: Any) -> Any:
raise NotImplementedError
# ── 批量执行入口 ──
async def run(self, input_data: Any) -> Dict[str, EmotionStepOutput]:
"""对所有预收集的 statement 并发执行情绪提取,单条失败返回默认值。"""
if not self._statements:
return {}
async def _extract_one(stmt: StatementStepOutput) -> Tuple[str, EmotionStepOutput]:
inp = EmotionStepInput(
statement_id=stmt.statement_id,
statement_text=stmt.statement_text,
speaker=stmt.speaker,
)
try:
result = await self._inner.run(inp)
return stmt.statement_id, result
except Exception:
return stmt.statement_id, self._inner.get_default_output()
pairs = await asyncio.gather(
*[_extract_one(s) for s in self._statements]
)
return dict(pairs)

View File

@@ -0,0 +1,366 @@
"""
GraphBuildStep — 从 DialogData 构建 Neo4j 图节点和边。
职责:
- 遍历 DialogData 列表,构建 DialogueNode、ChunkNode、StatementNode、
ExtractedEntityNode、PerceptualNode 及各类 Edge
- 不涉及 LLM 调用、去重、Neo4j 写入
依赖:
- embedder_client可选为 PerceptualNode 生成 summary embedding
- progress_callback可选流式输出关系创建进度
从 ExtractionOrchestrator._create_nodes_and_edges() 提取而来,
旧编排器保留原方法不变,新旧流水线完全隔离。
"""
from __future__ import annotations
import logging
from typing import Any, Awaitable, Callable, Dict, List, Optional
from app.core.memory.models.graph_models import (
ChunkNode,
DialogueNode,
EntityEntityEdge,
ExtractedEntityNode,
PerceptualEdge,
PerceptualNode,
StatementChunkEdge,
StatementEntityEdge,
StatementNode,
)
from app.core.memory.models.message_models import DialogData, TemporalInfo
logger = logging.getLogger(__name__)
class GraphBuildResult:
"""图构建步骤的输出。"""
__slots__ = (
"dialogue_nodes",
"chunk_nodes",
"statement_nodes",
"entity_nodes",
"perceptual_nodes",
"stmt_chunk_edges",
"stmt_entity_edges",
"entity_entity_edges",
"perceptual_edges",
)
def __init__(
self,
dialogue_nodes: List[DialogueNode],
chunk_nodes: List[ChunkNode],
statement_nodes: List[StatementNode],
entity_nodes: List[ExtractedEntityNode],
perceptual_nodes: List[PerceptualNode],
stmt_chunk_edges: List[StatementChunkEdge],
stmt_entity_edges: List[StatementEntityEdge],
entity_entity_edges: List[EntityEntityEdge],
perceptual_edges: List[PerceptualEdge],
):
self.dialogue_nodes = dialogue_nodes
self.chunk_nodes = chunk_nodes
self.statement_nodes = statement_nodes
self.entity_nodes = entity_nodes
self.perceptual_nodes = perceptual_nodes
self.stmt_chunk_edges = stmt_chunk_edges
self.stmt_entity_edges = stmt_entity_edges
self.entity_entity_edges = entity_entity_edges
self.perceptual_edges = perceptual_edges
async def build_graph_nodes_and_edges(
dialog_data_list: List[DialogData],
embedder_client: Any = None,
progress_callback: Optional[
Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]
] = None,
) -> GraphBuildResult:
"""
从 DialogData 列表构建完整的图节点和边。
Args:
dialog_data_list: 经过萃取和数据赋值后的 DialogData 列表
embedder_client: 可选的嵌入客户端,用于 PerceptualNode summary embedding
progress_callback: 可选的进度回调
Returns:
GraphBuildResult 包含所有节点和边
"""
logger.info("开始创建节点和边")
dialogue_nodes: List[DialogueNode] = []
chunk_nodes: List[ChunkNode] = []
statement_nodes: List[StatementNode] = []
entity_nodes: List[ExtractedEntityNode] = []
perceptual_nodes: List[PerceptualNode] = []
stmt_chunk_edges: List[StatementChunkEdge] = []
stmt_entity_edges: List[StatementEntityEdge] = []
entity_entity_edges: List[EntityEntityEdge] = []
perceptual_edges: List[PerceptualEdge] = []
entity_id_set: set = set()
total_dialogs = len(dialog_data_list)
processed_dialogs = 0
for dialog_data in dialog_data_list:
processed_dialogs += 1
# region TODO 乐力齐 重构流水线切换生产环境稳定后修改
# ── 对话节点 ──
dialogue_node = DialogueNode(
id=dialog_data.id,
name=f"Dialog_{dialog_data.id}",
ref_id=dialog_data.ref_id,
end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id,
content=dialog_data.context.content if dialog_data.context else "",
dialog_embedding=dialog_data.dialog_embedding if hasattr(dialog_data, "dialog_embedding") else None,
created_at=dialog_data.created_at,
expired_at=dialog_data.expired_at,
metadata=dialog_data.metadata,
config_id=dialog_data.config_id if hasattr(dialog_data, "config_id") else None,
)
dialogue_nodes.append(dialogue_node)
# ── 分块节点 ──
for chunk_idx, chunk in enumerate(dialog_data.chunks):
chunk_node = ChunkNode(
id=chunk.id,
name=f"Chunk_{chunk.id}",
dialog_id=dialog_data.id,
end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id,
content=chunk.content,
speaker=getattr(chunk, "speaker", None),
chunk_embedding=chunk.chunk_embedding,
sequence_number=chunk_idx,
created_at=dialog_data.created_at,
expired_at=dialog_data.expired_at,
metadata=chunk.metadata,
)
chunk_nodes.append(chunk_node)
# ── 感知节点 ──
for p, file_type in chunk.files:
meta = p.meta_data or {}
content_meta = meta.get("content", {})
summary_embedding = None
if embedder_client and p.summary:
try:
summary_embedding = (await embedder_client.response([p.summary]))[0]
except Exception as emb_err:
logger.warning(f"Failed to embed perceptual summary: {emb_err}")
perceptual = PerceptualNode(
name=f"Perceptual_{p.id}",
id=str(p.id),
end_user_id=str(p.end_user_id),
perceptual_type=p.perceptual_type,
file_path=p.file_path or "",
file_name=p.file_name or "",
file_ext=p.file_ext or "",
summary=p.summary or "",
keywords=content_meta.get("keywords", []),
topic=content_meta.get("topic", ""),
domain=content_meta.get("domain", ""),
created_at=p.created_time.isoformat() if p.created_time else None,
file_type=file_type,
summary_embedding=summary_embedding,
)
perceptual_nodes.append(perceptual)
perceptual_edges.append(
PerceptualEdge(
source=perceptual.id,
target=chunk.id,
end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id,
created_at=dialog_data.created_at,
)
)
# ── 陈述句节点 + 边 ──
for statement in chunk.statements:
statement_node = StatementNode(
id=statement.id,
name=f"Statement_{statement.id}",
chunk_id=chunk.id,
stmt_type=getattr(statement, "stmt_type", "general"),
temporal_info=getattr(statement, "temporal_info", TemporalInfo.ATEMPORAL),
connect_strength=(
statement.connect_strength
if statement.connect_strength is not None
else "Strong"
),
end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id,
statement=statement.statement,
speaker=getattr(statement, "speaker", None),
statement_embedding=statement.statement_embedding,
valid_at=(
statement.temporal_validity.valid_at
if hasattr(statement, "temporal_validity") and statement.temporal_validity
else None
),
invalid_at=(
statement.temporal_validity.invalid_at
if hasattr(statement, "temporal_validity") and statement.temporal_validity
else None
),
created_at=dialog_data.created_at,
expired_at=dialog_data.expired_at,
config_id=dialog_data.config_id if hasattr(dialog_data, "config_id") else None,
emotion_type=getattr(statement, "emotion_type", None),
emotion_intensity=getattr(statement, "emotion_intensity", None),
emotion_keywords=getattr(statement, "emotion_keywords", None),
emotion_subject=getattr(statement, "emotion_subject", None),
emotion_target=getattr(statement, "emotion_target", None),
)
statement_nodes.append(statement_node)
stmt_chunk_edges.append(
StatementChunkEdge(
source=statement.id,
target=chunk.id,
end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id,
created_at=dialog_data.created_at,
)
)
# ── 三元组 → 实体节点 + 边 ──
if not statement.triplet_extraction_info:
continue
triplet_info = statement.triplet_extraction_info
entity_idx_to_id: Dict[int, str] = {}
for entity_idx, entity in enumerate(triplet_info.entities):
entity_idx_to_id[entity.entity_idx] = entity.id
entity_idx_to_id[entity_idx] = entity.id
if entity.id not in entity_id_set:
entity_connect_strength = getattr(entity, "connect_strength", "Strong")
entity_node = ExtractedEntityNode(
id=entity.id,
name=getattr(entity, "name", f"Entity_{entity.id}"),
entity_idx=entity.entity_idx,
statement_id=statement.id,
entity_type=getattr(entity, "type", "unknown"),
description=getattr(entity, "description", ""),
example=getattr(entity, "example", ""),
connect_strength=(
entity_connect_strength
if entity_connect_strength is not None
else "Strong"
),
aliases=getattr(entity, "aliases", []) or [],
name_embedding=getattr(entity, "name_embedding", None),
is_explicit_memory=getattr(entity, "is_explicit_memory", False),
end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id,
created_at=dialog_data.created_at,
expired_at=dialog_data.expired_at,
config_id=dialog_data.config_id if hasattr(dialog_data, "config_id") else None,
)
entity_nodes.append(entity_node)
entity_id_set.add(entity.id)
entity_connect_strength = getattr(entity, "connect_strength", "Strong")
stmt_entity_edges.append(
StatementEntityEdge(
source=statement.id,
target=entity.id,
connect_strength=(
entity_connect_strength
if entity_connect_strength is not None
else "Strong"
),
end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id,
created_at=dialog_data.created_at,
)
)
# endregion
for triplet in triplet_info.triplets:
subject_entity_id = entity_idx_to_id.get(triplet.subject_id)
object_entity_id = entity_idx_to_id.get(triplet.object_id)
if subject_entity_id and object_entity_id:
entity_entity_edges.append(
EntityEntityEdge(
source=subject_entity_id,
target=object_entity_id,
relation_type=triplet.predicate,
statement=statement.statement,
source_statement_id=statement.id,
end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id,
created_at=dialog_data.created_at,
expired_at=dialog_data.expired_at,
)
)
if progress_callback and len(entity_entity_edges) <= 10:
relationship_result = {
"result_type": "relationship_creation",
"relationship_index": len(entity_entity_edges),
"source_entity": triplet.subject_name,
"relation_type": triplet.predicate,
"target_entity": triplet.object_name,
"relationship_text": f"{triplet.subject_name} -[{triplet.predicate}]-> {triplet.object_name}",
"dialog_progress": f"{processed_dialogs}/{total_dialogs}",
}
await progress_callback(
"creating_nodes_edges_result",
f"关系创建中 ({processed_dialogs}/{total_dialogs})",
relationship_result,
)
else:
missing_subject = "subject" if not subject_entity_id else ""
missing_object = "object" if not object_entity_id else ""
missing_both = " and " if (not subject_entity_id and not object_entity_id) else ""
logger.debug(
f"跳过三元组 - 无法找到{missing_subject}{missing_both}{missing_object}实体ID: "
f"subject_id={triplet.subject_id} ({triplet.subject_name}), "
f"object_id={triplet.object_id} ({triplet.object_name}), "
f"predicate={triplet.predicate}, "
f"statement_id={statement.id}, "
f"available_indices={sorted(entity_idx_to_id.keys())}"
)
logger.info(
f"节点和边创建完成 - 对话节点: {len(dialogue_nodes)}, "
f"分块节点: {len(chunk_nodes)}, 陈述句节点: {len(statement_nodes)}, "
f"实体节点: {len(entity_nodes)}, 陈述句-分块边: {len(stmt_chunk_edges)}, "
f"陈述句-实体边: {len(stmt_entity_edges)}, "
f"实体-实体边: {len(entity_entity_edges)}"
)
if progress_callback:
nodes_edges_stats = {
"dialogue_nodes_count": len(dialogue_nodes),
"chunk_nodes_count": len(chunk_nodes),
"statement_nodes_count": len(statement_nodes),
"entity_nodes_count": len(entity_nodes),
"statement_chunk_edges_count": len(stmt_chunk_edges),
"statement_entity_edges_count": len(stmt_entity_edges),
"entity_entity_edges_count": len(entity_entity_edges),
}
await progress_callback("creating_nodes_edges_complete", "创建节点和边完成", nodes_edges_stats)
return GraphBuildResult(
dialogue_nodes=dialogue_nodes,
chunk_nodes=chunk_nodes,
statement_nodes=statement_nodes,
entity_nodes=entity_nodes,
perceptual_nodes=perceptual_nodes,
stmt_chunk_edges=stmt_chunk_edges,
stmt_entity_edges=stmt_entity_edges,
entity_entity_edges=entity_entity_edges,
perceptual_edges=perceptual_edges,
)

View File

@@ -0,0 +1,42 @@
"""Schema package for ExtractionStep inputs and outputs.
Re-exports all models for convenient access:
from .schema import StatementStepInput, EmotionStepOutput, ...
"""
from .extraction_step_schema import (
EmbeddingStepInput,
EmbeddingStepOutput,
EntityItem,
MessageItem,
StatementStepInput,
StatementStepOutput,
SupportingContext,
TripletItem,
TripletStepInput,
TripletStepOutput,
)
from .sidecar_step_schema import (
EmotionStepInput,
EmotionStepOutput,
)
__all__ = [
# Shared
"MessageItem",
"SupportingContext",
# Statement
"StatementStepInput",
"StatementStepOutput",
# Triplet
"TripletStepInput",
"TripletStepOutput",
"EntityItem",
"TripletItem",
# Embedding
"EmbeddingStepInput",
"EmbeddingStepOutput",
# Sidecar — Emotion
"EmotionStepInput",
"EmotionStepOutput",
]

View File

@@ -0,0 +1,115 @@
"""Pydantic models for base extraction pipeline inputs and outputs.
Covers the core (critical) stages: Statement extraction, Triplet extraction,
Embedding generation, and shared types used across stages.
Malformed LLM JSON will raise ``ValidationError`` and trigger stage-level retry.
"""
from typing import Dict, List
from pydantic import BaseModel, Field
# ── Shared types ──
class MessageItem(BaseModel):
"""Single conversation message."""
role: str # "User" / "Assistant"
msg: str
class SupportingContext(BaseModel):
"""Dialogue context window (used for pronoun resolution, etc.)."""
msgs: List[MessageItem] = Field(default_factory=list)
# ── Statement extraction ──
class StatementStepInput(BaseModel):
"""Input for StatementExtractionStep."""
chunk_id: str
end_user_id: str
target_content: str
target_message_date: str
supporting_context: SupportingContext
class StatementStepOutput(BaseModel):
"""Single extracted statement (including temporal info)."""
statement_id: str
statement_text: str
statement_type: str # FACT / OPINION / PREDICTION / SUGGESTION
temporal_type: str # STATIC / DYNAMIC / ATEMPORAL
relevance: str # RELEVANT / IRRELEVANT
speaker: str # "user" / "assistant"
valid_at: str # ISO 8601 or "NULL"
invalid_at: str # ISO 8601 or "NULL"
# ── Triplet extraction ──
class TripletStepInput(BaseModel):
"""Input for TripletExtractionStep."""
statement_id: str
statement_text: str
statement_type: str
temporal_type: str
supporting_context: SupportingContext
speaker: str
valid_at: str
invalid_at: str
class EntityItem(BaseModel):
"""Single entity extracted during triplet extraction."""
entity_idx: int
name: str
type: str
description: str
is_explicit_memory: bool = False
class TripletItem(BaseModel):
"""Single triplet (subject-predicate-object) relationship."""
subject_name: str
subject_id: int
predicate: str
object_name: str
object_id: int
class TripletStepOutput(BaseModel):
"""Output of TripletExtractionStep."""
entities: List[EntityItem] = Field(default_factory=list)
triplets: List[TripletItem] = Field(default_factory=list)
# ── Embedding generation ──
class EmbeddingStepInput(BaseModel):
"""Input for EmbeddingStep.
Each dict maps an ID to the text that should be embedded.
Fields can be left empty for partial embedding runs.
"""
statement_texts: Dict[str, str] = Field(default_factory=dict)
chunk_texts: Dict[str, str] = Field(default_factory=dict)
dialog_texts: List[str] = Field(default_factory=list)
entity_names: Dict[str, str] = Field(default_factory=dict)
entity_descriptions: Dict[str, str] = Field(default_factory=dict)
class EmbeddingStepOutput(BaseModel):
"""Output of EmbeddingStep."""
statement_embeddings: Dict[str, List[float]] = Field(default_factory=dict)
chunk_embeddings: Dict[str, List[float]] = Field(default_factory=dict)
dialog_embeddings: List[List[float]] = Field(default_factory=list)
entity_embeddings: Dict[str, List[float]] = Field(default_factory=dict)

View File

@@ -0,0 +1,26 @@
"""Pydantic models for hot-pluggable sidecar step inputs and outputs.
Sidecar steps are non-critical (is_critical=False) modules registered via
``@SidecarStepFactory.register`` that run concurrently alongside the main
extraction pipeline. Failures degrade gracefully to default outputs.
"""
from typing import List
from pydantic import BaseModel, Field
# ── Emotion extraction (sidecar) ──
class EmotionStepInput(BaseModel):
"""Input for EmotionExtractionStep."""
statement_id: str
statement_text: str
speaker: str
class EmotionStepOutput(BaseModel):
"""Output of EmotionExtractionStep."""
emotion_type: str = "neutral"
emotion_intensity: float = 0.0
emotion_keywords: List[str] = Field(default_factory=list)

View File

@@ -0,0 +1,97 @@
"""SidecarStepFactory — decorator-based registry for sidecar (non-critical) steps.
New sidecar modules self-register via ``@SidecarStepFactory.register`` and are
automatically discovered and instantiated by the orchestrator without any
changes to orchestrator code.
"""
import logging
from enum import Enum
from typing import Any, Dict, List, Tuple, Type
from .base import ExtractionStep, StepContext
logger = logging.getLogger(__name__)
class SidecarTiming(str, Enum):
"""Declares when a sidecar step runs relative to the main pipeline."""
AFTER_STATEMENT = "after_statement"
AFTER_TRIPLET = "after_triplet"
class SidecarStepFactory:
"""Factory that manages sidecar step registration and creation.
Registry maps ``config_key`` → ``(step_class, timing)``.
Adding a new sidecar only requires the ``@register`` decorator on the
step class — no orchestrator modifications needed.
"""
_registry: Dict[str, Tuple[Type[ExtractionStep], SidecarTiming]] = {}
@classmethod
def register(cls, config_key: str, timing: SidecarTiming):
"""Class decorator that registers a sidecar step.
Args:
config_key: Configuration flag name (e.g. ``"emotion_enabled"``).
The step is instantiated only when this flag is ``True``.
timing: When the sidecar runs relative to the main pipeline.
Returns:
The original class, unmodified.
"""
def decorator(step_class: Type[ExtractionStep]):
cls._registry[config_key] = (step_class, timing)
logger.debug(
"Registered sidecar '%s' (config_key=%s, timing=%s)",
step_class.__name__,
config_key,
timing.value,
)
return step_class
return decorator
@classmethod
def create_sidecars(
cls, config: Any, context: StepContext
) -> Dict[SidecarTiming, List[ExtractionStep]]:
"""Instantiate enabled sidecar steps, grouped by timing.
Args:
config: Pipeline configuration object. Each registered
``config_key`` is looked up via ``getattr(config, key, False)``.
context: Shared :class:`StepContext` injected into every step.
Returns:
A dict keyed by :class:`SidecarTiming`, each value a list of
instantiated sidecar steps whose config flag is ``True``.
"""
result: Dict[SidecarTiming, List[ExtractionStep]] = {
timing: [] for timing in SidecarTiming
}
for config_key, (step_class, timing) in cls._registry.items():
if getattr(config, config_key, False):
step = step_class(context)
result[timing].append(step)
logger.debug(
"Created sidecar '%s' (timing=%s)",
step_class.__name__,
timing.value,
)
else:
logger.debug(
"Skipped sidecar '%s' (config_key=%s is disabled)",
step_class.__name__,
config_key,
)
return result
@classmethod
def clear_registry(cls) -> None:
"""Remove all registered sidecars. Useful for testing."""
cls._registry.clear()

View File

@@ -0,0 +1,141 @@
"""StatementExtractionStep — critical step for extracting statements from chunks.
Replaces the legacy ``StatementExtractor`` with the unified ExtractionStep paradigm.
Temporal extraction logic (valid_at / invalid_at) is merged into this step,
eliminating the need for a separate ``TemporalExtractor`` call.
"""
import logging
import uuid
from typing import Any, List
from pydantic import BaseModel, Field, field_validator
from app.core.memory.utils.data.ontology import LABEL_DEFINITIONS
from app.core.memory.utils.prompt.prompt_utils import render_statement_extraction_prompt
from .base import ExtractionStep, StepContext
from .schema import StatementStepInput, StatementStepOutput
logger = logging.getLogger(__name__)
# ── LLM response schemas (internal) ──
class _ExtractedStatement(BaseModel):
"""Raw statement returned by the LLM (before enrichment)."""
statement: str = Field(..., description="The extracted statement text")
statement_type: str = Field(..., description="FACT / OPINION / SUGGESTION / PREDICTION")
temporal_type: str = Field(..., description="STATIC / DYNAMIC / ATEMPORAL")
relevance: str = Field("RELEVANT", description="RELEVANT / IRRELEVANT")
valid_at: str = Field("NULL", description="ISO 8601 or NULL")
invalid_at: str = Field("NULL", description="ISO 8601 or NULL")
class _StatementExtractionResponse(BaseModel):
"""Structured LLM response containing a list of extracted statements."""
statements: List[_ExtractedStatement] = Field(default_factory=list)
@field_validator("statements", mode="before")
@classmethod
def filter_empty(cls, v: Any) -> Any:
"""Drop empty / malformed dicts that the LLM occasionally produces."""
if isinstance(v, list):
return [s for s in v if isinstance(s, dict) and s.get("statement")]
return v
class StatementExtractionStep(ExtractionStep[StatementStepInput, List[StatementStepOutput]]):
"""Extract atomic statements (with temporal info) from a dialogue chunk.
This is a **critical** step — failure aborts the pipeline after retries.
Config params bound at init (from ``StepContext.config.statement_extraction``):
* ``definitions`` — label definitions for statement classification
* ``json_schema`` — JSON schema for the expected LLM output
* ``granularity`` — extraction granularity level (1-3)
* ``include_dialogue_context`` — whether to include full dialogue context
"""
def __init__(self, context: StepContext) -> None:
super().__init__(context)
stmt_cfg = getattr(self.config, "statement_extraction", None)
self.definitions = LABEL_DEFINITIONS
self.json_schema = _ExtractedStatement.model_json_schema()
self.granularity = getattr(stmt_cfg, "statement_granularity", None)
self.include_dialogue_context = getattr(stmt_cfg, "include_dialogue_context", True)
self.max_dialogue_context_chars = getattr(stmt_cfg, "max_dialogue_context_chars", 2000)
# ── Identity ──
@property
def name(self) -> str:
return "statement_extraction"
@property
def is_critical(self) -> bool:
return True
# ── Lifecycle ──
async def render_prompt(self, input_data: StatementStepInput) -> str:
# Build optional dialogue context from supporting_context messages
dialogue_content = None
if self.include_dialogue_context and input_data.supporting_context.msgs:
dialogue_content = "\n".join(
f"{m.role}: {m.msg}" for m in input_data.supporting_context.msgs
)
return await render_statement_extraction_prompt(
chunk_content=input_data.target_content,
definitions=self.definitions,
json_schema=self.json_schema,
granularity=self.granularity,
include_dialogue_context=self.include_dialogue_context,
dialogue_content=dialogue_content,
max_dialogue_chars=self.max_dialogue_context_chars,
language=self.language,
)
async def call_llm(self, prompt: Any) -> Any:
messages = [
{
"role": "system",
"content": (
"You are an expert at extracting and labeling atomic statements "
"from conversational text. Return valid JSON conforming to the schema."
),
},
{"role": "user", "content": prompt},
]
return await self.llm_client.response_structured(
messages, _StatementExtractionResponse
)
async def parse_response(
self, raw_response: Any, input_data: StatementStepInput
) -> List[StatementStepOutput]:
if not hasattr(raw_response, "statements") or raw_response.statements is None:
return []
results: List[StatementStepOutput] = []
for stmt in raw_response.statements:
results.append(
StatementStepOutput(
statement_id=uuid.uuid4().hex,
statement_text=stmt.statement,
statement_type=stmt.statement_type.strip().upper(),
temporal_type=stmt.temporal_type.strip().upper(),
relevance=stmt.relevance.strip().upper(),
speaker="user", # default; orchestrator overrides from chunk metadata
valid_at=stmt.valid_at or "NULL",
invalid_at=stmt.invalid_at or "NULL",
)
)
return results
def get_default_output(self) -> List[StatementStepOutput]:
return []

View File

@@ -0,0 +1,118 @@
"""TripletExtractionStep — critical step for extracting entities and triplets.
Replaces the legacy ``TripletExtractor`` with the unified ExtractionStep paradigm.
Predicate filtering against the ontology whitelist is performed in ``parse_response``.
"""
import logging
from typing import Any
from app.core.memory.models.triplet_models import TripletExtractionResponse
from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate
from app.core.memory.utils.prompt.prompt_utils import render_triplet_extraction_prompt
from .base import ExtractionStep, StepContext
from .schema import EntityItem, TripletItem, TripletStepInput, TripletStepOutput
logger = logging.getLogger(__name__)
class TripletExtractionStep(ExtractionStep[TripletStepInput, TripletStepOutput]):
"""Extract knowledge triplets and entities from a single statement.
This is a **critical** step — failure aborts the pipeline after retries.
Config params bound at init (from ``StepContext.config``):
* ``ontology_types`` — predefined ontology types for entity classification
* ``predicate_instructions`` — predicate definition guidance for the LLM
* ``json_schema`` — JSON schema for the expected LLM output
"""
def __init__(
self,
context: StepContext,
ontology_types: Any = None,
) -> None:
super().__init__(context)
self.ontology_types = ontology_types
self.predicate_instructions = PREDICATE_DEFINITIONS
self.json_schema = TripletExtractionResponse.model_json_schema()
self._allowed_predicates = {p.value for p in Predicate}
# ── Identity ──
@property
def name(self) -> str:
return "triplet_extraction"
@property
def is_critical(self) -> bool:
return True
# ── Lifecycle ──
async def render_prompt(self, input_data: TripletStepInput) -> str:
# Build chunk_content from supporting_context for pronoun resolution
chunk_content = "\n".join(
f"{m.role}: {m.msg}" for m in input_data.supporting_context.msgs
) if input_data.supporting_context.msgs else ""
return await render_triplet_extraction_prompt(
statement=input_data.statement_text,
chunk_content=chunk_content,
json_schema=self.json_schema,
predicate_instructions=self.predicate_instructions,
language=self.language,
ontology_types=self.ontology_types,
speaker=input_data.speaker,
)
async def call_llm(self, prompt: Any) -> Any:
messages = [
{
"role": "system",
"content": (
"You are an expert at extracting knowledge triplets and entities "
"from text. Follow the provided instructions carefully and return valid JSON."
),
},
{"role": "user", "content": prompt},
]
return await self.llm_client.response_structured(
messages, TripletExtractionResponse
)
async def parse_response(
self, raw_response: Any, input_data: TripletStepInput
) -> TripletStepOutput:
if not hasattr(raw_response, "triplets"):
return self.get_default_output()
# Filter triplets to allowed predicates from ontology whitelist
filtered_triplets = [
TripletItem(
subject_name=t.subject_name,
subject_id=t.subject_id,
predicate=t.predicate,
object_name=t.object_name,
object_id=t.object_id,
)
for t in raw_response.triplets
if getattr(t, "predicate", "") in self._allowed_predicates
]
entities = [
EntityItem(
entity_idx=e.entity_idx,
name=e.name,
type=e.type,
description=e.description,
is_explicit_memory=getattr(e, "is_explicit_memory", False),
)
for e in (raw_response.entities or [])
]
return TripletStepOutput(entities=entities, triplets=filtered_triplets)
def get_default_output(self) -> TripletStepOutput:
return TripletStepOutput(entities=[], triplets=[])

View File

@@ -421,6 +421,9 @@ class MemoryConfig:
pruning_scene: Optional[str] = "education"
pruning_threshold: float = 0.5
# Pipeline config: Emotion extraction
emotion_enabled: bool = False
# Ontology scene association
scene_id: Optional[UUID] = None
ontology_class_infos: list[dict] = field(default_factory=list)

View File

@@ -360,40 +360,64 @@ class MemoryAgentService:
await write_rag(end_user_id, message_text, user_rag_memory_id)
return "success"
else:
await write_neo4j(
end_user_id=end_user_id,
messages=messages,
memory_config=memory_config,
ref_id='',
language=language
)
# ── 影子运行:新流水线静默执行,只记录日志不影响主流程 ──
# TODO 乐力齐 重构流水线切换至生产环境后,更改如下代码
import os
if os.getenv("SHADOW_PIPELINE_ENABLED", "false").lower() == "true":
try:
from app.core.memory.memory_service import MemoryService
import copy
use_new_pipeline = os.getenv("NEW_PIPELINE_ENABLED", "false").lower() == "true"
shadow_messages = copy.deepcopy(messages)
shadow_service = MemoryService(
memory_config=memory_config,
end_user_id=end_user_id,
)
shadow_result = await shadow_service.write(
messages=shadow_messages,
language=language,
ref_id='',
is_pilot_run=True, # 试运行模式:只萃取不写入,避免重复写入 Neo4j
)
logger.info(
f"[Shadow] 新流水线影子运行完成: status={shadow_result.status}, "
f"elapsed={shadow_result.elapsed_seconds:.2f}s, "
f"extraction={shadow_result.extraction}"
)
except Exception as shadow_err:
logger.warning(f"[Shadow] 新流水线影子运行失败(不影响主流程): {shadow_err}")
# ── 影子运行结束 ──
if use_new_pipeline:
# ── 新流水线WritePipeline + NewExtractionOrchestrator ──
from app.core.memory.memory_service import MemoryService
service = MemoryService(
memory_config=memory_config,
end_user_id=end_user_id,
)
result = await service.write(
messages=messages,
language=language,
ref_id='',
is_pilot_run=False,
)
logger.info(
f"[NewPipeline] 完成: status={result.status}, "
f"elapsed={result.elapsed_seconds:.2f}s, "
f"extraction={result.extraction}"
)
else:
# ── 旧流水线write_tools.write() + ExtractionOrchestrator ──
await write_neo4j(
end_user_id=end_user_id,
messages=messages,
memory_config=memory_config,
ref_id='',
language=language
)
# ── 影子运行:新流水线静默执行,只记录日志不影响主流程 ──
if os.getenv("SHADOW_PIPELINE_ENABLED", "false").lower() == "true":
try:
from app.core.memory.memory_service import MemoryService
import copy
shadow_messages = copy.deepcopy(messages)
shadow_service = MemoryService(
memory_config=memory_config,
end_user_id=end_user_id,
)
shadow_result = await shadow_service.write(
messages=shadow_messages,
language=language,
ref_id='',
is_pilot_run=True,
)
logger.info(
f"[Shadow] 新流水线影子运行完成: status={shadow_result.status}, "
f"elapsed={shadow_result.elapsed_seconds:.2f}s, "
f"extraction={shadow_result.extraction}"
)
except Exception as shadow_err:
logger.warning(f"[Shadow] 新流水线影子运行失败(不影响主流程): {shadow_err}")
# ── 影子运行结束 ──
for lang in ["zh", "en"]:
deleted = await InterestMemoryCache.delete_interest_distribution(
end_user_id, lang

View File

@@ -418,6 +418,9 @@ class MemoryConfigService:
pruning_scene=memory_config.pruning_scene or "education",
pruning_threshold=float(
memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
# Pipeline config: Emotion extraction
emotion_enabled=bool(
memory_config.emotion_enabled) if memory_config.emotion_enabled is not None else False,
# Ontology scene association
scene_id=memory_config.scene_id,
ontology_class_infos=_load_ontology_class_infos(self.db, memory_config.scene_id),
@@ -573,6 +576,7 @@ class MemoryConfigService:
statement_extraction=stmt_config,
deduplication=dedup_config,
forgetting_engine=forget_config,
emotion_enabled=getattr(memory_config, "emotion_enabled", False),
)
@staticmethod