refactor(memory): add PilotWritePipeline and enrich extraction schema

- Add dedicated PilotWritePipeline (statement → triplet → graph_build → layer-1 dedup, no Neo4j write)
- Add type_description/predicate_description fields across entity and triplet models, Cypher queries, and graph builders
- Refactor data_pruning with LRU cache and snapshot support; skip assistant chunks in extraction
- Remove strict Predicate enum whitelist; support statement_text alias in legacy extractor
- Wire PipelineSnapshot through preprocessing and emotion extraction for debug tracing
- Add PILOT_RUN_USE_REFACTORED_PIPELINE env toggle for pipeline selection
lanceyq
2026-04-27 18:15:46 +08:00
parent b0ddd12cc6
commit 2355536b44
23 changed files with 806 additions and 1070 deletions

View File

@@ -272,6 +272,12 @@ class Settings:
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory")
# Pilot run pipeline switch:
# true -> use refactored PilotWritePipeline
# false -> use legacy ExtractionOrchestrator pipeline
PILOT_RUN_USE_REFACTORED_PIPELINE: bool = (
os.getenv("PILOT_RUN_USE_REFACTORED_PIPELINE", "true").lower() == "true"
)
# Tool Management Configuration
TOOL_CONFIG_DIR: str = os.getenv("TOOL_CONFIG_DIR", "app/core/tools")
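A hedged usage sketch of the new toggle: the actual branching lives in run_pilot_extraction further down in this commit; the helper name here is illustrative only.

# Sketch only: callers read the flag from Settings and pick a pipeline.
# Setting PILOT_RUN_USE_REFACTORED_PIPELINE=false in the environment falls back to the legacy path.
from app.core.config import settings

def select_pilot_pipeline_name() -> str:
    # Mirrors the branching shown in run_pilot_extraction below.
    return "PilotWritePipeline" if settings.PILOT_RUN_USE_REFACTORED_PIPELINE else "ExtractionOrchestrator"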

View File

@@ -9,7 +9,8 @@ async def get_chunked_dialogs(
end_user_id: str = "group_1",
messages: list = None,
ref_id: str = "",
config_id: str = None
config_id: str = None,
snapshot=None,
) -> List[DialogData]:
"""Generate chunks from structured messages using the specified chunker strategy.
@@ -19,6 +20,7 @@ async def get_chunked_dialogs(
messages: Structured message list [{"role": "user", "content": "..."}, ...]
ref_id: Reference identifier
config_id: Configuration ID for processing (used to load pruning config)
snapshot: Optional PipelineSnapshot instance for saving pruning output
Returns:
List of DialogData objects with generated chunks
@@ -93,7 +95,7 @@ async def get_chunked_dialogs(
llm_client = factory.get_llm_client_from_config(memory_config)
# Run pruning - use prune_dataset to support message-level pruning
pruner = SemanticPruner(config=pruning_config, llm_client=llm_client)
pruner = SemanticPruner(config=pruning_config, llm_client=llm_client, snapshot=snapshot)
original_msg_count = len(dialog_data.context.msgs)
# Use prune_dataset instead of prune_dialog

View File

@@ -184,7 +184,8 @@ async def write(
"entities": [
{
"entity_idx": e.entity_idx, "name": e.name,
"type": e.type, "description": e.description,
"type": e.type, "type_description": getattr(e, "type_description", ""),
"description": e.description,
"is_explicit_memory": getattr(e, "is_explicit_memory", False),
}
for e in s.triplet_extraction_info.entities
@@ -193,6 +194,7 @@ async def write(
{
"subject_name": t.subject_name, "subject_id": t.subject_id,
"predicate": t.predicate,
"predicate_description": getattr(t, "predicate_description", ""),
"object_name": t.object_name, "object_id": t.object_id,
}
for t in s.triplet_extraction_info.triplets
@@ -206,13 +208,13 @@ async def write(
"chunk_nodes_count": len(all_chunk_nodes),
"statement_nodes_count": len(all_statement_nodes),
"entity_nodes": [
{"id": e.id, "name": e.name, "entity_type": e.entity_type, "description": e.description}
{"id": e.id, "name": e.name, "entity_type": e.entity_type, "type_description": e.type_description, "description": e.description}
for e in all_entity_nodes
],
"entity_entity_edges": [
{
"source": e.source, "target": e.target,
"relation_type": e.relation_type, "statement": e.statement,
"relation_type": e.relation_type, "relation_type_description": e.relation_type_description, "statement": e.statement,
}
for e in all_entity_entity_edges
],

View File

@@ -162,6 +162,7 @@ class EntityEntityEdge(Edge):
invalid_at: Optional end date of temporal validity
"""
relation_type: str = Field(..., description="Relation type as defined in ontology")
relation_type_description: str = Field(default="", description="Chinese definition of the relation type from ontology")
relation_value: Optional[str] = Field(None, description="Value of the relation")
statement: str = Field(..., description='The statement of the edge.')
source_statement_id: str = Field(..., description="Statement where this relationship was extracted")
@@ -413,6 +414,7 @@ class ExtractedEntityNode(Node):
entity_idx: int = Field(..., description="Unique identifier for the entity")
statement_id: str = Field(..., description="Statement this entity was extracted from")
entity_type: str = Field(..., description="Type of the entity")
type_description: str = Field(default="", description="Chinese definition of the entity type from ontology")
description: str = Field(..., description="Entity description")
example: str = Field(
default="",

View File

@@ -96,6 +96,10 @@ class Statement(BaseModel):
emotion_target: Optional[str] = Field(None, description="Emotion target: person or object name")
# Reference resolution
has_unsolved_reference: bool = Field(False, description="Whether the statement has unresolved references")
has_emotional_state: bool = Field(
False,
description="Whether the statement reflects user's emotional state",
)
class ConversationContext(BaseModel):

View File

@@ -37,6 +37,7 @@ class Entity(BaseModel):
name: str = Field(..., description="Name of the entity")
name_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the entity name")
type: str = Field(..., description="Type/category of the entity")
type_description: str = Field(default="", description="Chinese definition of the entity type from ontology")
description: str = Field(..., description="Description of the entity")
example: str = Field(
default="",
@@ -79,6 +80,7 @@ class Triplet(BaseModel):
subject_name: str = Field(..., description="Name of the subject entity")
subject_id: int = Field(..., description="ID of the subject entity")
predicate: str = Field(..., description="Relationship/predicate between subject and object")
predicate_description: str = Field(default="", description="Chinese definition of the predicate from ontology")
object_name: str = Field(..., description="Name of the object entity")
object_id: int = Field(..., description="ID of the object entity")
value: Optional[str] = Field(None, description="Additional value or context")
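Both new description fields default to an empty string, so payloads produced by older prompts that never emit them still validate. A minimal standalone sketch of the pattern (not the real models, which have additional required fields):

from pydantic import BaseModel, Field

class TripletSketch(BaseModel):
    # Mirrors only the fields relevant to this change.
    predicate: str
    predicate_description: str = Field(default="", description="Chinese definition of the predicate from ontology")

old_payload = {"predicate": "works_at"}  # older prompt output, no description
new_payload = {"predicate": "works_at", "predicate_description": "任职于"}
assert TripletSketch(**old_payload).predicate_description == ""
assert TripletSketch(**new_payload).predicate_description == "任职于"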

View File

@@ -14,13 +14,31 @@ def __getattr__(name):
WritePipeline,
WriteResult,
)
_exports = {
"WritePipeline": WritePipeline,
"ExtractionResult": ExtractionResult,
"WriteResult": WriteResult,
}
return _exports[name]
if name in ("PilotWritePipeline", "PilotWriteResult"):
from app.core.memory.pipelines.pilot_write_pipeline import (
PilotWritePipeline,
PilotWriteResult,
)
_exports = {
"PilotWritePipeline": PilotWritePipeline,
"PilotWriteResult": PilotWriteResult,
}
return _exports[name]
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
__all__ = ["WritePipeline", "ExtractionResult", "WriteResult"]
__all__ = [
"WritePipeline",
"ExtractionResult",
"WriteResult",
"PilotWritePipeline",
"PilotWriteResult",
]
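The pipelines package keeps its PEP 562 module-level __getattr__, so importing the new names stays lazy; a usage sketch:

# Importing the public names triggers pipelines/__init__.__getattr__ on first access,
# so pilot_write_pipeline (and its dependencies) is only loaded when actually used.
from app.core.memory.pipelines import PilotWritePipeline, PilotWriteResult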

View File

@@ -0,0 +1,108 @@
"""PilotWritePipeline — 试运行专用萃取流水线。
职责边界:
- 只执行“萃取相关”链路statement -> triplet -> graph_build -> 第一层去重消歧
- 不负责 Neo4j 写入、聚类、摘要、缓存更新
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Dict, List, Optional
from app.core.memory.models.message_models import DialogData
from app.core.memory.models.variate_config import ExtractionPipelineConfig
from app.core.memory.storage_services.extraction_engine.dedup_step import (
DedupResult,
run_dedup,
)
from app.core.memory.storage_services.extraction_engine.steps.extraction_pipeline_orchestrator import (
NewExtractionOrchestrator,
)
from app.core.memory.storage_services.extraction_engine.steps.graph_build_step import (
GraphBuildResult,
build_graph_nodes_and_edges,
)
@dataclass
class PilotWriteResult:
"""试运行流水线输出。"""
dialog_data_list: List[DialogData]
graph: GraphBuildResult
dedup: DedupResult
@property
def stats(self) -> Dict[str, int]:
return {
"chunk_count": len(self.graph.chunk_nodes),
"statement_count": len(self.graph.statement_nodes),
"entity_count_before_dedup": len(self.graph.entity_nodes),
"entity_count_after_dedup": len(self.dedup.entity_nodes),
"relation_count_before_dedup": len(self.graph.entity_entity_edges),
"relation_count_after_dedup": len(self.dedup.entity_entity_edges),
}
class PilotWritePipeline:
"""重构后试运行专用流水线。"""
def __init__(
self,
llm_client: Any,
embedder_client: Any,
pipeline_config: ExtractionPipelineConfig,
embedding_id: Optional[str],
language: str = "zh",
ontology_types: Any = None,
progress_callback: Optional[
Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]
] = None,
) -> None:
self.llm_client = llm_client
self.embedder_client = embedder_client
self.pipeline_config = pipeline_config
self.embedding_id = embedding_id
self.language = language
self.ontology_types = ontology_types
self.progress_callback = progress_callback
async def run(self, dialog_data_list: List[DialogData]) -> PilotWriteResult:
"""执行试运行萃取链路。"""
orchestrator = NewExtractionOrchestrator(
llm_client=self.llm_client,
embedder_client=self.embedder_client,
config=self.pipeline_config,
embedding_id=self.embedding_id,
ontology_types=self.ontology_types,
language=self.language,
is_pilot_run=True,
progress_callback=self.progress_callback,
)
extracted_dialogs = await orchestrator.run(dialog_data_list)
graph = await build_graph_nodes_and_edges(
dialog_data_list=extracted_dialogs,
embedder_client=self.embedder_client,
progress_callback=self.progress_callback,
)
dedup = await run_dedup(
entity_nodes=graph.entity_nodes,
statement_entity_edges=graph.stmt_entity_edges,
entity_entity_edges=graph.entity_entity_edges,
dialog_data_list=extracted_dialogs,
pipeline_config=self.pipeline_config,
connector=None, # pilot: no layer-2 db dedup
llm_client=self.llm_client,
is_pilot_run=True,
progress_callback=self.progress_callback,
)
return PilotWriteResult(
dialog_data_list=extracted_dialogs,
graph=graph,
dedup=dedup,
)
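A hedged usage sketch of the new pipeline; it assumes the LLM/embedder clients and the pipeline config were already built elsewhere (for example via MemoryClientFactory and get_pipeline_config, as in run_pilot_extraction below) and that chunked DialogData objects are available:

from typing import Any, List

from app.core.memory.models.message_models import DialogData
from app.core.memory.pipelines import PilotWritePipeline

async def run_pilot(llm_client: Any, embedder_client: Any, pipeline_config,
                    embedding_id: str, chunked_dialogs: List[DialogData]):
    pipeline = PilotWritePipeline(
        llm_client=llm_client,
        embedder_client=embedder_client,
        pipeline_config=pipeline_config,
        embedding_id=embedding_id,
        language="zh",
    )
    result = await pipeline.run(chunked_dialogs)
    # stats compares pre/post layer-1 dedup counts; nothing is written to Neo4j.
    return result.stats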

View File

@@ -180,7 +180,11 @@ class WritePipeline:
self._init_clients()
self._init_neo4j_connector()
# Step 1: Preprocessing - message chunking + AI-message semantic pruning (not yet implemented)
# Initialize the snapshot early so pruning in the preprocessing step can use it
from app.core.memory.utils.debug.pipeline_snapshot import PipelineSnapshot
self._snapshot = PipelineSnapshot("new")
# Step 1: Preprocessing - message chunking + AI-message semantic pruning
step_start = time.time()
chunked_dialogs = await self._preprocess(messages, ref_id)
chunks_count = sum(len(d.chunks) for d in chunked_dialogs)
@@ -220,7 +224,7 @@ class WritePipeline:
)
# Step 3.5: Async emotion extraction (fire-and-forget; must run after _store so the Statement nodes already exist)
self._extract_emotion(getattr(self, "_emotion_statements", []))
await self._extract_emotion(getattr(self, "_emotion_statements", []))
# Step 4: Clustering - incrementally update communities (async, non-blocking)
step_start = time.time()
@@ -266,7 +270,7 @@ class WritePipeline:
async def _preprocess(self, messages: List[dict], ref_id: str) -> List[DialogData]:
"""
Preprocessing: message validation → AI-message semantic pruning (not yet implemented) → dialog chunking.
Preprocessing: message validation → AI-message semantic pruning → dialog chunking.
Delegates to get_chunked_dialogs(); the existing preprocessing logic stays unchanged.
get_dialogs.py already covers:
@@ -276,12 +280,15 @@ class WritePipeline:
"""
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
snapshot = getattr(self, "_snapshot", None)
return await get_chunked_dialogs(
chunker_strategy=self.memory_config.chunker_strategy,
end_user_id=self.end_user_id,
messages=messages,
ref_id=ref_id,
config_id=str(self.memory_config.config_id),
snapshot=snapshot,
)
# ──────────────────────────────────────────────
@@ -321,7 +328,9 @@ class WritePipeline:
pipeline_config = get_pipeline_config(self.memory_config)
ontology_types = self._load_ontology_types()
snapshot = PipelineSnapshot("new")
# Reuse the snapshot created in run() (the pruning step already used the same instance)
snapshot = getattr(self, "_snapshot", None) or PipelineSnapshot("new")
self._snapshot = snapshot
# ── New orchestrator (LLM extraction + data assignment) ──
new_orchestrator = NewExtractionOrchestrator(
@@ -589,11 +598,15 @@ class WritePipeline:
# Submit the Celery task fire-and-forget; does not block the main flow
# ──────────────────────────────────────────────
def _extract_emotion(self, emotion_statements: list) -> None:
async def _extract_emotion(self, emotion_statements: list) -> None:
"""提交异步情绪提取 Celery 任务。
从编排器收集的 user statement 列表中提取情绪,
异步回写到 Neo4j Statement 节点。失败不影响主流程。
在 PIPELINE_SNAPSHOT_ENABLED=true 时,会把当前运行的快照目录路径
通过 snapshot_dir 透传给 Celery 任务worker 端在完成 LLM 抽取后,
将结果落盘到 <snapshot_dir>/4_emotion_outputs.json避免主进程重复调用 LLM。
"""
if not emotion_statements:
return
@@ -607,6 +620,14 @@ class WritePipeline:
logger.warning("[Emotion] 无法提交情绪提取任务llm_model_id 为空")
return
# Snapshot directory: non-empty only when PIPELINE_SNAPSHOT_ENABLED=true; used by the worker to dump results
snapshot = getattr(self, "_snapshot", None)
snapshot_dir = (
snapshot.directory
if snapshot is not None and getattr(snapshot, "enabled", False)
else None
)
try:
from app.celery_app import celery_app
@@ -616,12 +637,14 @@ class WritePipeline:
"statements": emotion_statements,
"llm_model_id": llm_model_id,
"language": self.language,
"snapshot_dir": snapshot_dir,
},
)
logger.info(
f"[Emotion] 异步情绪提取任务已提交 - "
f"task_id={result.id}, "
f"statement_count={len(emotion_statements)}, "
f"snapshot_dir={snapshot_dir}, "
f"source=async"
)
except Exception as e:
@@ -629,6 +652,7 @@ class WritePipeline:
f"[Emotion] 提交情绪提取任务失败(不影响主流程): {e}",
exc_info=True,
)
# ──────────────────────────────────────────────
# Step 5: Summary
# (+ entity_description; part of meta_data is also extracted here)
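A short sketch of how the snapshot directory reaches the worker, under the assumption that PipelineSnapshot exposes the enabled/directory attributes used above and that the statements/model id are already collected:

from app.core.memory.utils.debug.pipeline_snapshot import PipelineSnapshot

# snapshot_dir is only non-None in debug runs (PIPELINE_SNAPSHOT_ENABLED=true);
# the Celery worker then writes <snapshot_dir>/4_emotion_outputs.json after extraction.
snapshot = PipelineSnapshot("new")
snapshot_dir = snapshot.directory if getattr(snapshot, "enabled", False) else None
payload = {
    "statements": emotion_statements,  # assumed: user statements collected by the orchestrator
    "llm_model_id": llm_model_id,      # assumed: resolved from the memory config
    "language": "zh",
    "snapshot_dir": snapshot_dir,
}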

View File

@@ -1264,6 +1264,7 @@ class ExtractionOrchestrator:
entity_idx=entity.entity_idx,  # use the entity's own entity_idx
statement_id=statement.id,  # required statement_id field
entity_type=getattr(entity, 'type', 'unknown'),  # use type rather than entity_type
type_description=getattr(entity, 'type_description', ''),
description=getattr(entity, 'description', ''),  # required description field
example=getattr(entity, 'example', ''),  # new: pass through the example field
# TODO: fact_summary is temporarily disabled; enable it after follow-up development is complete
@@ -1306,6 +1307,7 @@ class ExtractionOrchestrator:
source=subject_entity_id,
target=object_entity_id,
relation_type=triplet.predicate,
relation_type_description=getattr(triplet, 'predicate_description', ''),
statement=statement.statement,
source_statement_id=statement.id,
end_user_id=dialog_data.end_user_id,

View File

@@ -12,16 +12,21 @@ from app.core.memory.utils.data.ontology import (
TemporalInfo,
)
from app.core.memory.utils.prompt.prompt_utils import render_statement_extraction_prompt
from pydantic import BaseModel, Field, field_validator
from pydantic import AliasChoices, BaseModel, Field, field_validator
logger = logging.getLogger(__name__)
class ExtractedStatement(BaseModel):
"""Schema for extracted statement from LLM"""
statement: str = Field(..., description="The extracted statement text")
statement: str = Field(
...,
validation_alias=AliasChoices("statement", "statement_text"),
description="The extracted statement text",
)
statement_type: str = Field(..., description="FACT, OPINION, SUGGESTION or PREDICTION")
temporal_type: str = Field(..., description="STATIC, DYNAMIC, ATEMPORAL")
relevence: str = Field(..., description="RELEVANT or IRRELEVANT")
# New prompt no longer outputs relevence; keep backward-compatible default.
relevence: str = Field("RELEVANT", description="RELEVANT or IRRELEVANT")
has_unsolved_reference: bool = Field(False, description="Whether the statement has unresolved references")
class StatementExtractionResponse(BaseModel):
@@ -41,7 +46,7 @@ class StatementExtractionResponse(BaseModel):
valid_statements = []
filtered_count = 0
for i, stmt in enumerate(v):
if isinstance(stmt, dict) and stmt.get('statement'):
if isinstance(stmt, dict) and (stmt.get("statement") or stmt.get("statement_text")):
valid_statements.append(stmt)
elif isinstance(stmt, dict):
# Log which statement was filtered
@@ -96,6 +101,11 @@ class StatementExtractor:
"""
chunk_content = chunk.content
chunk_speaker = self._get_speaker_from_chunk(chunk)
logger.info(
"[LegacyStatementExtractor] chunk_id=%s content_len=%d",
getattr(chunk, "id", ""),
len(chunk_content or ""),
)
if not chunk_content or len(chunk_content.strip()) < 5:
logger.warning(f"Chunk {chunk.id} content too short or empty, skipping")
@@ -108,7 +118,18 @@ class StatementExtractor:
granularity=self.config.statement_granularity,
include_dialogue_context=self.config.include_dialogue_context,
dialogue_content=dialogue_content,
max_dialogue_chars=self.config.max_dialogue_context_chars
max_dialogue_chars=self.config.max_dialogue_context_chars,
input_json={
"chunk_id": getattr(chunk, "id", ""),
"end_user_id": end_user_id or "",
"target_content": chunk_content,
"target_message_date": datetime.now().isoformat(),
"supporting_context": {
"msgs": [
{"role": "context", "msg": dialogue_content}
] if dialogue_content else []
},
},
)
# Simple system message
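With AliasChoices, either JSON key emitted by the LLM binds to the same field; a standalone sketch of the behaviour:

from pydantic import AliasChoices, BaseModel, Field

class StatementSketch(BaseModel):
    statement: str = Field(..., validation_alias=AliasChoices("statement", "statement_text"))

# Old-prompt and new-prompt spellings both validate into `statement`:
assert StatementSketch.model_validate({"statement": "user lives in Beijing"}).statement
assert StatementSketch.model_validate({"statement_text": "user lives in Beijing"}).statement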

View File

@@ -4,7 +4,7 @@ from typing import List, Dict, Optional
from app.core.logging_config import get_memory_logger
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.memory.utils.prompt.prompt_utils import render_triplet_extraction_prompt
from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate  # import the Predicate enum for whitelist filtering
from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS
from app.core.memory.models.triplet_models import TripletExtractionResponse
from app.core.memory.models.message_models import DialogData, Statement
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
@@ -73,15 +73,9 @@ class TripletExtractor:
try:
# Get structured response from LLM
response = await self.llm_client.response_structured(messages, TripletExtractionResponse)
# Filter triplets to only allowed predicates from ontology
# This drops predicates that are not in the Predicate enum, but it tends to be too strict: some statements' predicates are missing from the enum and get demoted to weak relations
allowed_predicates = {p.value for p in Predicate}
filtered_triplets = [t for t in response.triplets if getattr(t, "predicate", "") in allowed_predicates]
# Keep only triplets whose predicate ∈ Predicate; drop everything else
# Create new triplets with statement_id set during creation
updated_triplets = []
for triplet in filtered_triplets:  # keep only triplets with predicate ∈ Predicate
for triplet in response.triplets:
updated_triplet = triplet.model_copy(update={"statement_id": statement.id})
updated_triplets.append(updated_triplet)

View File

@@ -300,6 +300,33 @@ class NewExtractionOrchestrator:
"embedding_output": None,
}
if self.progress_callback:
statements_count = sum(
len(stmts)
for chunk_stmts in all_stmt_results.values()
for stmts in chunk_stmts.values()
)
entities_count = sum(
len(t_out.entities)
for stmt_triplets in all_triplet_results.values()
for t_out in stmt_triplets.values()
)
triplets_count = sum(
len(t_out.triplets)
for stmt_triplets in all_triplet_results.values()
for t_out in stmt_triplets.values()
)
await self.progress_callback(
"knowledge_extraction_complete",
"知识抽取完成",
{
"entities_count": entities_count,
"statements_count": statements_count,
"temporal_ranges_count": 0,
"triplets_count": triplets_count,
},
)
logger.info("Pilot extraction complete")
return dialog_data_list
@@ -467,6 +494,11 @@ class NewExtractionOrchestrator:
else None
)
for chunk in dialog.chunks:
# Only run statement extraction on chunks with speaker="user"; assistant content is
# handled by the upstream preprocessing/pruning stage, to avoid wasting LLM calls.
chunk_speaker = getattr(chunk, "speaker", "user")
if chunk_speaker != "user":
continue
inp = StatementStepInput(
chunk_id=chunk.id,
end_user_id=dialog.end_user_id,
@@ -478,7 +510,7 @@ class NewExtractionOrchestrator:
)
tasks.append(self.statement_step.run(inp))
task_meta.append(
(dialog.id, chunk.id, getattr(chunk, "speaker", "user"), ctx)
(dialog.id, chunk.id, chunk_speaker, ctx)
)
results = await asyncio.gather(*tasks, return_exceptions=True)
@@ -499,6 +531,15 @@ class NewExtractionOrchestrator:
for s in stmts:
s.speaker = speaker
stmt_map[dialog_id][chunk_id] = stmts
if self.progress_callback:
# Frontend consumes knowledge_extraction_result with data.statement.
# Emit one event per statement to keep payload contract simple.
for s in stmts:
await self.progress_callback(
"knowledge_extraction_result",
"知识抽取中",
{"statement": s.statement_text},
)
return stmt_map
@@ -520,6 +561,11 @@ class NewExtractionOrchestrator:
chunk_stmts = all_stmt_results.get(dialog.id, {})
for _chunk_id, stmts in chunk_stmts.items():
for stmt in stmts:
# Defensive filter: triplet extraction targets user statements only.
# _extract_all_statements upstream already filters on chunk.speaker; re-check
# statement.speaker here to guard against external injection or legacy data slipping through.
if getattr(stmt, "speaker", "user") != "user":
continue
inp = self._convert_to_triplet_input(stmt, ctx)
tasks.append(self.triplet_step.run(inp))
task_meta.append((dialog.id, stmt.statement_id))
@@ -541,6 +587,24 @@ class NewExtractionOrchestrator:
triplet_map[dialog_id][stmt_id] = self.triplet_step.get_default_output()
else:
triplet_map[dialog_id][stmt_id] = result
if self.progress_callback:
await self.progress_callback(
"extract_triplet_result",
f"statement {stmt_id} 提取完成",
{
"statement_id": stmt_id,
"triplet_count": len(result.triplets),
"entity_count": len(result.entities),
"triplets": [
{
"subject_name": t.subject_name,
"predicate": t.predicate,
"object_name": t.object_name,
}
for t in result.triplets[:5]
],
},
)
return triplet_map
@@ -842,6 +906,8 @@ class NewExtractionOrchestrator:
temporal_info=_TEMPORAL_MAP.get(stmt_out.temporal_type, TemporalInfo.ATEMPORAL),
# relevence_info=RelevenceInfo.RELEVANT if stmt_out.relevance == "RELEVANT" else RelevenceInfo.IRRELEVANT,
temporal_validity=TemporalValidityRange(valid_at=valid_at, invalid_at=invalid_at),
has_unsolved_reference=stmt_out.has_unsolved_reference,
has_emotional_state=stmt_out.has_emotional_state,
triplet_extraction_info=triplet_info,
statement_embedding=stmt_embedding,
**emotion_kwargs,
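All progress events added above go through the callback signature declared on the orchestrator/pipeline (event code, human-readable message, optional data dict); a minimal consumer sketch, with the print sink being an assumption:

from typing import Any, Dict, Optional

async def progress_callback(event: str, message: str, data: Optional[Dict[str, Any]] = None) -> None:
    # e.g. event == "knowledge_extraction_result" with data == {"statement": "..."},
    # or "extract_triplet_result" with per-statement triplet/entity counts.
    print(f"[{event}] {message}: {data or {}}")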

View File

@@ -250,6 +250,7 @@ async def build_graph_nodes_and_edges(
entity_idx=entity.entity_idx,
statement_id=statement.id,
entity_type=getattr(entity, "type", "unknown"),
type_description=getattr(entity, "type_description", ""),
description=getattr(entity, "description", ""),
example=getattr(entity, "example", ""),
connect_strength=(
@@ -296,6 +297,7 @@ async def build_graph_nodes_and_edges(
source=subject_entity_id,
target=object_entity_id,
relation_type=triplet.predicate,
relation_type_description=getattr(triplet, "predicate_description", ""),
statement=statement.statement,
source_statement_id=statement.id,
end_user_id=dialog_data.end_user_id,

View File

@@ -46,6 +46,7 @@ class StatementStepOutput(BaseModel):
temporal_type: str # STATIC / DYNAMIC / ATEMPORAL
# relevance: str # RELEVANT / IRRELEVANT
speaker: str # "user" / "assistant"
has_emotional_state: bool = False # Whether statement reflects user's emotional state
valid_at: str # ISO 8601 or "NULL"
invalid_at: str # ISO 8601 or "NULL"
has_unsolved_reference: bool = False # Whether the statement has unresolved references
@@ -72,6 +73,7 @@ class EntityItem(BaseModel):
entity_idx: int
name: str
type: str
type_description: str = ""
description: str
is_explicit_memory: bool = False
@@ -82,6 +84,7 @@ class TripletItem(BaseModel):
subject_name: str
subject_id: int
predicate: str
predicate_description: str = ""
object_name: str
object_id: int

View File

@@ -34,6 +34,10 @@ class _ExtractedStatement(BaseModel):
statement_type: str = Field(..., description="FACT / OPINION / OTHER")
temporal_type: str = Field(..., description="STATIC / DYNAMIC / ATEMPORAL")
# relevance: str = Field("RELEVANT", description="RELEVANT / IRRELEVANT")
has_emotional_state: bool = Field(
False,
description="Whether the statement reflects user's emotional state",
)
valid_at: str = Field("NULL", description="ISO 8601 or NULL")
invalid_at: str = Field("NULL", description="ISO 8601 or NULL")
has_unsolved_reference: bool = Field(False, description="Whether the statement has unresolved references")
@@ -155,6 +159,7 @@ class StatementExtractionStep(ExtractionStep[StatementStepInput, List[StatementS
temporal_type=stmt.temporal_type.strip().upper(),
# relevance=stmt.relevance.strip().upper(),
speaker="user", # default; orchestrator overrides from chunk metadata
has_emotional_state=getattr(stmt, "has_emotional_state", False),
valid_at=stmt.valid_at or "NULL",
invalid_at=stmt.invalid_at or "NULL",
has_unsolved_reference=getattr(stmt, "has_unsolved_reference", False),

View File

@@ -112,6 +112,7 @@ class TripletExtractionStep(ExtractionStep[TripletStepInput, TripletStepOutput])
subject_name=t.subject_name,
subject_id=t.subject_id,
predicate=t.predicate,
predicate_description=getattr(t, "predicate_description", ""),
object_name=t.object_name,
object_id=t.object_id,
)
@@ -123,6 +124,7 @@ class TripletExtractionStep(ExtractionStep[TripletStepInput, TripletStepOutput])
entity_idx=e.entity_idx,
name=e.name,
type=e.type,
type_description=getattr(e, "type_description", ""),
description=e.description,
is_explicit_memory=getattr(e, "is_explicit_memory", False),
)

View File

@@ -92,6 +92,7 @@ SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity
THEN entity.expired_at ELSE e.expired_at END,
e.entity_idx = CASE WHEN e.entity_idx IS NULL OR e.entity_idx = 0 THEN entity.entity_idx ELSE e.entity_idx END,
e.entity_type = CASE WHEN entity.entity_type IS NOT NULL AND entity.entity_type <> '' THEN entity.entity_type ELSE e.entity_type END,
e.type_description = CASE WHEN entity.type_description IS NOT NULL AND entity.type_description <> '' THEN entity.type_description ELSE coalesce(e.type_description, '') END,
e.description = CASE
WHEN entity.description IS NOT NULL AND entity.description <> ''
AND (e.description IS NULL OR size(e.description) = 0 OR size(entity.description) > size(e.description))
@@ -147,6 +148,7 @@ MATCH (object:ExtractedEntity {id: rel.target_id, end_user_id: rel.end_user_id})
// Avoid duplicate edges across runs for the same endpoints
MERGE (subject)-[r:EXTRACTED_RELATIONSHIP]->(object)
SET r.predicate = rel.predicate,
r.predicate_description = rel.predicate_description,
r.statement_id = rel.statement_id,
r.value = rel.value,
r.statement = rel.statement,

View File

@@ -44,6 +44,7 @@ async def save_entities_and_relationships(
'source_id': edge.source,
'target_id': edge.target,
'predicate': edge.relation_type,
'predicate_description': edge.relation_type_description,
'statement_id': edge.source_statement_id,
'value': edge.relation_value,
'statement': edge.statement,
@@ -297,6 +298,7 @@ async def save_dialog_and_statements_to_neo4j(
'source_id': edge.source,
'target_id': edge.target,
'predicate': edge.relation_type,
'predicate_description': edge.relation_type_description,
'statement_id': edge.source_statement_id,
'value': edge.relation_value,
'statement': edge.statement,

View File

@@ -441,21 +441,12 @@ class DataConfigService:  # Data config service class (PostgreSQL)
with open(result_path, "r", encoding="utf-8") as rf:
extracted_result = json.load(rf)
# Step 6: Compute ontology coverage and merge it into the result
# Step 6: Assemble the result (pilot run does no extra coverage post-processing)
result_data = {
"config_id": cid,
"time_log": os.path.join(project_root, "logs", "time.log"),
"extracted_result": extracted_result,
}
try:
ontology_coverage = await self._compute_ontology_coverage(
extracted_result=extracted_result,
memory_config=memory_config,
)
if ontology_coverage:
result_data["ontology_coverage"] = ontology_coverage
except Exception as cov_err:
logger.warning(f"[PILOT_RUN_STREAM] Ontology coverage computation failed: {cov_err}", exc_info=True)
yield format_sse_message("result", result_data)
@@ -479,100 +470,6 @@ class DataConfigService:  # Data config service class (PostgreSQL)
"time": int(time.time() * 1000)
})
async def _compute_ontology_coverage(
self,
extracted_result: Dict[str, Any],
memory_config,
) -> Optional[Dict[str, Any]]:
"""根据提取结果中的实体类型,与场景/通用本体类型做互斥分类统计。
分类规则(互斥):场景类型优先 > 通用类型 > 未匹配
确保: 场景实体数 + 通用实体数 + 未匹配数 = 总实体数
Returns:
包含三部分统计的字典,或 None无实体数据时
"""
core_entities = extracted_result.get("core_entities", [])
if not core_entities:
return None
# 1. Load the set of scene ontology types
scene_ontology_types: set = set()
try:
from app.repositories.ontology_class_repository import OntologyClassRepository
if memory_config.scene_id:
class_repo = OntologyClassRepository(self.db)
ontology_classes = class_repo.get_classes_by_scene(memory_config.scene_id)
scene_ontology_types = {oc.class_name for oc in ontology_classes}
except Exception as e:
logger.warning(f"Failed to load scene ontology types: {e}")
# 2. Load the set of general ontology types
general_ontology_types: set = set()
try:
from app.core.memory.ontology_services.ontology_type_loader import (
get_general_ontology_registry,
is_general_ontology_enabled,
)
if is_general_ontology_enabled():
registry = get_general_ontology_registry()
if registry:
general_ontology_types = set(registry.types.keys())
except Exception as e:
logger.warning(f"Failed to load general ontology types: {e}")
# 3. Mutually exclusive classification: scene first > general > unmatched
scene_distribution: list = []
general_distribution: list = []
unmatched_distribution: list = []
scene_total = 0
general_total = 0
unmatched_total = 0
for item in core_entities:
entity_type = item.get("type", "")
count = item.get("count", 0)
if entity_type in scene_ontology_types:
scene_distribution.append({"type": entity_type, "count": count})
scene_total += count
elif entity_type in general_ontology_types:
general_distribution.append({"type": entity_type, "count": count})
general_total += count
else:
unmatched_distribution.append({"type": entity_type, "count": count})
unmatched_total += count
# Sort by count in descending order
scene_distribution.sort(key=lambda x: x["count"], reverse=True)
general_distribution.sort(key=lambda x: x["count"], reverse=True)
unmatched_distribution.sort(key=lambda x: x["count"], reverse=True)
total_entities = scene_total + general_total + unmatched_total
return {
"scene_type_distribution": {
"type_count": len(scene_distribution),
"entity_total": scene_total,
"types": scene_distribution,
},
"general_type_distribution": {
"type_count": len(general_distribution),
"entity_total": general_total,
"types": general_distribution,
},
"unmatched": {
"type_count": len(unmatched_distribution),
"entity_total": unmatched_total,
"types": unmatched_distribution,
},
"total_entities": total_entities,
"time": int(time.time() * 1000),
}
# -------------------- Neo4j Search & Analytics (fused from data_search_service.py) --------------------
# Ensure env for connector (e.g., NEO4J_PASSWORD)

View File

@@ -10,7 +10,9 @@ import time
from datetime import datetime
from typing import Awaitable, Callable, Optional
from app.core.config import settings
from app.core.logging_config import get_memory_logger, log_time
from app.core.memory.pipelines.pilot_write_pipeline import PilotWritePipeline
from app.core.memory.models.message_models import (
ConversationContext,
ConversationMessage,
@@ -20,9 +22,11 @@ from app.core.memory.storage_services.extraction_engine.extraction_orchestrator
ExtractionOrchestrator,
get_chunked_dialogs_from_preprocessed,
)
from app.core.memory.utils.config.config_utils import (
get_pipeline_config,
from app.core.memory.storage_services.extraction_engine.pipeline_help import (
_write_extracted_result_summary,
export_test_input_doc,
)
from app.core.memory.utils.config.config_utils import get_pipeline_config
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import MemoryConfig
@@ -31,6 +35,42 @@ from sqlalchemy.orm import Session
logger = get_memory_logger(__name__)
def _save_triplets_from_dialogs(dialog_data_list: list[DialogData], output_path: str) -> None:
"""Write triplet/entity text report compatible with pipeline_help parsers."""
all_triplets = []
all_entities = []
for dialog in dialog_data_list:
for chunk in getattr(dialog, "chunks", []) or []:
for statement in getattr(chunk, "statements", []) or []:
triplet_info = getattr(statement, "triplet_extraction_info", None)
if not triplet_info:
continue
all_triplets.extend(getattr(triplet_info, "triplets", []) or [])
all_entities.extend(getattr(triplet_info, "entities", []) or [])
with open(output_path, "w", encoding="utf-8") as f:
f.write(f"=== EXTRACTED TRIPLETS ({len(all_triplets)} total) ===\n\n")
for i, triplet in enumerate(all_triplets, 1):
f.write(f"Triplet {i}:\n")
f.write(f" Subject: {triplet.subject_name} (ID: {triplet.subject_id})\n")
f.write(f" Predicate: {triplet.predicate}\n")
f.write(f" Object: {triplet.object_name} (ID: {triplet.object_id})\n")
value = getattr(triplet, "value", None)
if value:
f.write(f" Value: {value}\n")
f.write("\n")
f.write(f"\n=== EXTRACTED ENTITIES ({len(all_entities)} total) ===\n\n")
for i, entity in enumerate(all_entities, 1):
f.write(f"Entity {i}:\n")
f.write(f" ID: {entity.entity_idx}\n")
f.write(f" Name: {entity.name}\n")
f.write(f" Type: {entity.type}\n")
f.write(f" Description: {entity.description}\n")
f.write("\n")
async def run_pilot_extraction(
memory_config: MemoryConfig,
dialogue_text: str,
@@ -58,7 +98,6 @@ async def run_pilot_extraction(
f.write(f"\n=== Pilot Run Started: {timestamp} ===\n")
pipeline_start = time.time()
neo4j_connector = None
try:
# Step 1: Initialize clients
@@ -69,8 +108,6 @@ async def run_pilot_extraction(
llm_client = client_factory.get_llm_client(str(memory_config.llm_model_id))
embedder_client = client_factory.get_embedder_client(str(memory_config.embedding_model_id))
neo4j_connector = Neo4jConnector()
log_time("Client Initialization", time.time() - step_start, log_file)
# Step 2: Parse the dialogue text
@@ -242,15 +279,17 @@ async def run_pilot_extraction(
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
# Step 3: Initialize the pipeline orchestrator
logger.info("Initializing extraction orchestrator...")
step_start = time.time()
config = get_pipeline_config(memory_config)
# Step 3: Initialize and select the pilot pipeline (switchable via environment variable)
use_refactored = bool(settings.PILOT_RUN_USE_REFACTORED_PIPELINE)
logger.info(
f"Pipeline config loaded: enable_llm_dedup_blockwise={config.deduplication.enable_llm_dedup_blockwise}, "
f"enable_llm_disambiguation={config.deduplication.enable_llm_disambiguation}"
"Selecting pilot pipeline by env: PILOT_RUN_USE_REFACTORED_PIPELINE=%s",
use_refactored,
)
logger.info(
"Initializing %s pilot pipeline...",
"refactored" if use_refactored else "legacy",
)
step_start = time.time()
# Load ontology types (if scene_id is configured), with fallback to general types
ontology_types = None
@@ -266,100 +305,105 @@ async def run_pilot_extraction(
except Exception as e:
logger.warning(f"Failed to load ontology types: {e}", exc_info=True)
orchestrator = ExtractionOrchestrator(
llm_client=llm_client,
embedder_client=embedder_client,
connector=neo4j_connector,
config=config,
progress_callback=progress_callback,
embedding_id=str(memory_config.embedding_model_id),
language=language,
ontology_types=ontology_types,
)
if use_refactored:
pilot_pipeline = PilotWritePipeline(
llm_client=llm_client,
embedder_client=embedder_client,
pipeline_config=get_pipeline_config(memory_config),
progress_callback=progress_callback,
embedding_id=str(memory_config.embedding_model_id),
language=language,
ontology_types=ontology_types,
)
log_time("Pilot Pipeline Initialization", time.time() - step_start, log_file)
log_time("Orchestrator Initialization", time.time() - step_start, log_file)
# Step 4a: Run the refactored pilot short pipeline
# statement -> triplet -> graph_build -> layer-1 dedup/disambiguation (stop here)
logger.info("Running refactored pilot extraction short pipeline...")
step_start = time.time()
# Step 4: Run the knowledge extraction pipeline
logger.info("Running extraction pipeline...")
step_start = time.time()
if progress_callback:
await progress_callback("knowledge_extraction", "正在知识抽取...")
if progress_callback:
await progress_callback("knowledge_extraction", "正在知识抽取...")
pilot_result = await pilot_pipeline.run(chunked_dialogs)
dialog_data_list = pilot_result.dialog_data_list
graph = pilot_result.graph
chunk_nodes = graph.chunk_nodes
export_entity_nodes = graph.entity_nodes
export_stmt_entity_edges = graph.stmt_entity_edges
export_entity_edges = graph.entity_entity_edges
else:
# Step 4b: Run the legacy pilot pipeline
logger.info("Running legacy pilot extraction pipeline...")
step_start = time.time()
extraction_result = await orchestrator.run(
dialog_data_list=chunked_dialogs,
is_pilot_run=True,
)
if progress_callback:
await progress_callback("knowledge_extraction", "正在知识抽取...")
# Unpack the extraction_result tuple (consistent with main.py)
(
dialogue_nodes,
chunk_nodes,
statement_nodes,
entity_nodes,
_,
statement_chunk_edges,
statement_entity_edges,
entity_edges,
_,
_
) = extraction_result
neo4j_connector = Neo4jConnector()
try:
legacy_orchestrator = ExtractionOrchestrator(
llm_client=llm_client,
embedder_client=embedder_client,
connector=neo4j_connector,
config=get_pipeline_config(memory_config),
progress_callback=progress_callback,
embedding_id=str(memory_config.embedding_model_id),
language=language,
ontology_types=ontology_types,
)
extraction_result = await legacy_orchestrator.run(
dialog_data_list=chunked_dialogs,
is_pilot_run=True,
)
(
_dialogue_nodes,
chunk_nodes,
_statement_nodes,
entity_nodes,
_perceptual_nodes,
_statement_chunk_edges,
statement_entity_edges,
entity_edges,
_perceptual_edges,
_last_created_at,
) = extraction_result
dialog_data_list = chunked_dialogs
export_entity_nodes = entity_nodes
export_stmt_entity_edges = statement_entity_edges
export_entity_edges = entity_edges
finally:
try:
await neo4j_connector.close()
except Exception:
pass
log_time("Extraction Pipeline", time.time() - step_start, log_file)
if progress_callback:
await progress_callback("generating_results", "正在生成结果...")
# Step 5: Generate memory summaries (consistent with main.py)
try:
logger.info("Generating memory summaries...")
step_start = time.time()
# Step 5: Write pilot-run result files (keep the /pilot_run response contract)
settings.ensure_memory_output_dir()
export_test_input_doc(
entity_nodes=export_entity_nodes,
statement_entity_edges=export_stmt_entity_edges,
entity_entity_edges=export_entity_edges,
)
_save_triplets_from_dialogs(
dialog_data_list=dialog_data_list,
output_path=settings.get_memory_output_path("extracted_triplets.txt"),
)
_write_extracted_result_summary(
chunk_nodes=chunk_nodes,
pipeline_output_dir=settings.get_memory_output_path(),
)
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
memory_summary_generation,
)
summaries = await memory_summary_generation(
chunked_dialogs,
llm_client=llm_client,
embedder_client=embedder_client,
language=language,
)
log_time("Memory Summary Generation", time.time() - step_start, log_file)
except Exception as e:
logger.error(f"Memory summary step failed: {e}", exc_info=True)
logger.info("Pilot run completed: Skipping Neo4j save")
# Write extraction stats to Redis (keyed by workspace_id)
try:
from app.cache.memory.activity_stats_cache import ActivityStatsCache
stats_to_cache = {
"chunk_count": len(chunk_nodes) if chunk_nodes else 0,
"statements_count": len(statement_nodes) if statement_nodes else 0,
"triplet_entities_count": len(entity_nodes) if entity_nodes else 0,
"triplet_relations_count": len(entity_edges) if entity_edges else 0,
"temporal_count": 0, # temporal 数据在日志中此处暂置0
}
await ActivityStatsCache.set_activity_stats(
workspace_id=str(memory_config.workspace_id),
stats=stats_to_cache,
)
logger.info(f"[PILOT_RUN] 活动统计已写入 Redis: workspace_id={memory_config.workspace_id}")
except Exception as cache_err:
logger.warning(f"[PILOT_RUN] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
logger.info("Pilot run completed: stop after layer-1 dedup (no layer-2 / no Neo4j write)")
except Exception as e:
logger.error(f"Pilot run failed: {e}", exc_info=True)
raise
finally:
if neo4j_connector:
try:
await neo4j_connector.close()
except Exception:
pass
total_time = time.time() - pipeline_start
log_time("TOTAL PILOT RUN TIME", total_time, log_file)

View File

@@ -1382,6 +1382,7 @@ def extract_emotion_batch_task(
llm_model_id: str,
language: str = "zh",
emotion_config: Optional[Dict[str, Any]] = None,
snapshot_dir: Optional[str] = None,
) -> Dict[str, Any]:
"""Celery task: batch emotion extraction + Neo4j backfill.
@@ -1395,6 +1396,10 @@ def extract_emotion_batch_task(
language: Language code ("zh" / "en").
emotion_config: Optional dict with emotion step config overrides
(emotion_extract_keywords, emotion_enable_subject).
snapshot_dir: Optional absolute path of the current run's snapshot directory.
When provided (only in debug mode), emotion outputs will be
dumped to <snapshot_dir>/4_emotion_outputs.json for offline
comparison between the legacy / new pipelines.
"""
task_id = self.request.id
total = len(statements)
@@ -1445,6 +1450,8 @@ def extract_emotion_batch_task(
extracted = 0
failed = 0
update_items = []
# For snapshots: collect each statement's EmotionStepOutput (only used when snapshot_dir is non-empty)
snapshot_outputs: Dict[str, Any] = {} if snapshot_dir else None # type: ignore[assignment]
async def _extract_one(stmt_dict: Dict[str, str]):
nonlocal extracted, failed
@@ -1461,6 +1468,8 @@ def extract_emotion_batch_task(
"emotion_intensity": result.emotion_intensity,
"emotion_keywords": result.emotion_keywords,
})
if snapshot_outputs is not None:
snapshot_outputs[stmt_dict["statement_id"]] = result.model_dump()
extracted += 1
logger.debug(
f"[Emotion] 单条提取完成: stmt={stmt_dict['statement_id']}, "
@@ -1468,12 +1477,33 @@ def extract_emotion_batch_task(
)
except Exception as e:
failed += 1
if snapshot_outputs is not None:
snapshot_outputs[stmt_dict["statement_id"]] = {"error": str(e)}
logger.warning(
f"[Emotion] 单条提取失败 stmt={stmt_dict['statement_id']}: {e}"
)
await asyncio.gather(*[_extract_one(s) for s in statements])
# Snapshot dump (worker side): does not affect the Neo4j write flow; failures are only logged
if snapshot_outputs is not None:
try:
from pathlib import Path as _Path
import json as _json
_dir = _Path(snapshot_dir)
_dir.mkdir(parents=True, exist_ok=True)
_path = _dir / "4_emotion_outputs.json"
with open(_path, "w", encoding="utf-8") as _f:
_json.dump(snapshot_outputs, _f, ensure_ascii=False, indent=2, default=str)
logger.info(
f"[Emotion][Snapshot] 已落盘 {len(snapshot_outputs)} 条情绪结果 → {_path}"
)
except Exception as _e:
logger.warning(
f"[Emotion][Snapshot] 快照落盘失败(不影响主流程): {_e}"
)
# Batch update Neo4j via write transaction
if update_items:
connector = Neo4jConnector()