refactor(memory): add PilotWritePipeline and enrich extraction schema

- Add dedicated PilotWritePipeline (statement → triplet → graph_build → layer-1 dedup, no Neo4j write)
- Add type_description/predicate_description fields across entity and triplet models, Cypher queries, and graph builders
- Refactor data_pruning with LRU cache and snapshot support; skip assistant chunks in extraction
- Remove strict Predicate enum whitelist; support statement_text alias in legacy extractor
- Wire PipelineSnapshot through preprocessing and emotion extraction for debug tracing
- Add PILOT_RUN_USE_REFACTORED_PIPELINE env toggle for pipeline selection
This commit is contained in:
lanceyq
2026-04-27 18:15:46 +08:00
parent b0ddd12cc6
commit 2355536b44
23 changed files with 806 additions and 1070 deletions

View File

@@ -9,7 +9,8 @@ async def get_chunked_dialogs(
end_user_id: str = "group_1",
messages: list = None,
ref_id: str = "",
config_id: str = None
config_id: str = None,
snapshot=None,
) -> List[DialogData]:
"""Generate chunks from structured messages using the specified chunker strategy.
@@ -19,6 +20,7 @@ async def get_chunked_dialogs(
messages: Structured message list [{"role": "user", "content": "..."}, ...]
ref_id: Reference identifier
config_id: Configuration ID for processing (used to load pruning config)
snapshot: Optional PipelineSnapshot instance for saving pruning output
Returns:
List of DialogData objects with generated chunks
@@ -93,7 +95,7 @@ async def get_chunked_dialogs(
llm_client = factory.get_llm_client_from_config(memory_config)
# 执行剪枝 - 使用 prune_dataset 支持消息级剪枝
pruner = SemanticPruner(config=pruning_config, llm_client=llm_client)
pruner = SemanticPruner(config=pruning_config, llm_client=llm_client, snapshot=snapshot)
original_msg_count = len(dialog_data.context.msgs)
# 使用 prune_dataset 而不是 prune_dialog

View File

@@ -184,7 +184,8 @@ async def write(
"entities": [
{
"entity_idx": e.entity_idx, "name": e.name,
"type": e.type, "description": e.description,
"type": e.type, "type_description": getattr(e, "type_description", ""),
"description": e.description,
"is_explicit_memory": getattr(e, "is_explicit_memory", False),
}
for e in s.triplet_extraction_info.entities
@@ -193,6 +194,7 @@ async def write(
{
"subject_name": t.subject_name, "subject_id": t.subject_id,
"predicate": t.predicate,
"predicate_description": getattr(t, "predicate_description", ""),
"object_name": t.object_name, "object_id": t.object_id,
}
for t in s.triplet_extraction_info.triplets
@@ -206,13 +208,13 @@ async def write(
"chunk_nodes_count": len(all_chunk_nodes),
"statement_nodes_count": len(all_statement_nodes),
"entity_nodes": [
{"id": e.id, "name": e.name, "entity_type": e.entity_type, "description": e.description}
{"id": e.id, "name": e.name, "entity_type": e.entity_type, "type_description": e.type_description, "description": e.description}
for e in all_entity_nodes
],
"entity_entity_edges": [
{
"source": e.source, "target": e.target,
"relation_type": e.relation_type, "statement": e.statement,
"relation_type": e.relation_type, "relation_type_description": e.relation_type_description, "statement": e.statement,
}
for e in all_entity_entity_edges
],

View File

@@ -162,6 +162,7 @@ class EntityEntityEdge(Edge):
invalid_at: Optional end date of temporal validity
"""
relation_type: str = Field(..., description="Relation type as defined in ontology")
relation_type_description: str = Field(default="", description="Chinese definition of the relation type from ontology")
relation_value: Optional[str] = Field(None, description="Value of the relation")
statement: str = Field(..., description='The statement of the edge.')
source_statement_id: str = Field(..., description="Statement where this relationship was extracted")
@@ -413,6 +414,7 @@ class ExtractedEntityNode(Node):
entity_idx: int = Field(..., description="Unique identifier for the entity")
statement_id: str = Field(..., description="Statement this entity was extracted from")
entity_type: str = Field(..., description="Type of the entity")
type_description: str = Field(default="", description="Chinese definition of the entity type from ontology")
description: str = Field(..., description="Entity description")
example: str = Field(
default="",

View File

@@ -96,6 +96,10 @@ class Statement(BaseModel):
emotion_target: Optional[str] = Field(None, description="Emotion target: person or object name")
# Reference resolution
has_unsolved_reference: bool = Field(False, description="Whether the statement has unresolved references")
has_emotional_state: bool = Field(
False,
description="Whether the statement reflects user's emotional state",
)
class ConversationContext(BaseModel):

View File

@@ -37,6 +37,7 @@ class Entity(BaseModel):
name: str = Field(..., description="Name of the entity")
name_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the entity name")
type: str = Field(..., description="Type/category of the entity")
type_description: str = Field(default="", description="Chinese definition of the entity type from ontology")
description: str = Field(..., description="Description of the entity")
example: str = Field(
default="",
@@ -79,6 +80,7 @@ class Triplet(BaseModel):
subject_name: str = Field(..., description="Name of the subject entity")
subject_id: int = Field(..., description="ID of the subject entity")
predicate: str = Field(..., description="Relationship/predicate between subject and object")
predicate_description: str = Field(default="", description="Chinese definition of the predicate from ontology")
object_name: str = Field(..., description="Name of the object entity")
object_id: int = Field(..., description="ID of the object entity")
value: Optional[str] = Field(None, description="Additional value or context")

View File

@@ -14,13 +14,31 @@ def __getattr__(name):
WritePipeline,
WriteResult,
)
_exports = {
"WritePipeline": WritePipeline,
"ExtractionResult": ExtractionResult,
"WriteResult": WriteResult,
}
return _exports[name]
if name in ("PilotWritePipeline", "PilotWriteResult"):
from app.core.memory.pipelines.pilot_write_pipeline import (
PilotWritePipeline,
PilotWriteResult,
)
_exports = {
"PilotWritePipeline": PilotWritePipeline,
"PilotWriteResult": PilotWriteResult,
}
return _exports[name]
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
__all__ = ["WritePipeline", "ExtractionResult", "WriteResult"]
__all__ = [
"WritePipeline",
"ExtractionResult",
"WriteResult",
"PilotWritePipeline",
"PilotWriteResult",
]

View File

@@ -0,0 +1,108 @@
"""PilotWritePipeline — 试运行专用萃取流水线。
职责边界:
- 只执行“萃取相关”链路statement -> triplet -> graph_build -> 第一层去重消歧
- 不负责 Neo4j 写入、聚类、摘要、缓存更新
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Dict, List, Optional
from app.core.memory.models.message_models import DialogData
from app.core.memory.models.variate_config import ExtractionPipelineConfig
from app.core.memory.storage_services.extraction_engine.dedup_step import (
DedupResult,
run_dedup,
)
from app.core.memory.storage_services.extraction_engine.steps.extraction_pipeline_orchestrator import (
NewExtractionOrchestrator,
)
from app.core.memory.storage_services.extraction_engine.steps.graph_build_step import (
GraphBuildResult,
build_graph_nodes_and_edges,
)
@dataclass
class PilotWriteResult:
"""试运行流水线输出。"""
dialog_data_list: List[DialogData]
graph: GraphBuildResult
dedup: DedupResult
@property
def stats(self) -> Dict[str, int]:
return {
"chunk_count": len(self.graph.chunk_nodes),
"statement_count": len(self.graph.statement_nodes),
"entity_count_before_dedup": len(self.graph.entity_nodes),
"entity_count_after_dedup": len(self.dedup.entity_nodes),
"relation_count_before_dedup": len(self.graph.entity_entity_edges),
"relation_count_after_dedup": len(self.dedup.entity_entity_edges),
}
class PilotWritePipeline:
"""重构后试运行专用流水线。"""
def __init__(
self,
llm_client: Any,
embedder_client: Any,
pipeline_config: ExtractionPipelineConfig,
embedding_id: Optional[str],
language: str = "zh",
ontology_types: Any = None,
progress_callback: Optional[
Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]
] = None,
) -> None:
self.llm_client = llm_client
self.embedder_client = embedder_client
self.pipeline_config = pipeline_config
self.embedding_id = embedding_id
self.language = language
self.ontology_types = ontology_types
self.progress_callback = progress_callback
async def run(self, dialog_data_list: List[DialogData]) -> PilotWriteResult:
"""执行试运行萃取链路。"""
orchestrator = NewExtractionOrchestrator(
llm_client=self.llm_client,
embedder_client=self.embedder_client,
config=self.pipeline_config,
embedding_id=self.embedding_id,
ontology_types=self.ontology_types,
language=self.language,
is_pilot_run=True,
progress_callback=self.progress_callback,
)
extracted_dialogs = await orchestrator.run(dialog_data_list)
graph = await build_graph_nodes_and_edges(
dialog_data_list=extracted_dialogs,
embedder_client=self.embedder_client,
progress_callback=self.progress_callback,
)
dedup = await run_dedup(
entity_nodes=graph.entity_nodes,
statement_entity_edges=graph.stmt_entity_edges,
entity_entity_edges=graph.entity_entity_edges,
dialog_data_list=extracted_dialogs,
pipeline_config=self.pipeline_config,
connector=None, # pilot: no layer-2 db dedup
llm_client=self.llm_client,
is_pilot_run=True,
progress_callback=self.progress_callback,
)
return PilotWriteResult(
dialog_data_list=extracted_dialogs,
graph=graph,
dedup=dedup,
)

View File

@@ -180,7 +180,11 @@ class WritePipeline:
self._init_clients()
self._init_neo4j_connector()
# Step 1: 预处理 - 消息分块 + AI消息语义剪枝暂无实现
# 初始化 Snapshot提前创建供预处理阶段的剪枝使用
from app.core.memory.utils.debug.pipeline_snapshot import PipelineSnapshot
self._snapshot = PipelineSnapshot("new")
# Step 1: 预处理 - 消息分块 + AI消息语义剪枝
step_start = time.time()
chunked_dialogs = await self._preprocess(messages, ref_id)
chunks_count = sum(len(d.chunks) for d in chunked_dialogs)
@@ -220,7 +224,7 @@ class WritePipeline:
)
# Step 3.5: 异步情绪提取fire-and-forget需在 _store 之后确保 Statement 节点已存在)
self._extract_emotion(getattr(self, "_emotion_statements", []))
await self._extract_emotion(getattr(self, "_emotion_statements", []))
# Step 4: 聚类 - 增量更新社区(异步,不阻塞)
step_start = time.time()
@@ -266,7 +270,7 @@ class WritePipeline:
async def _preprocess(self, messages: List[dict], ref_id: str) -> List[DialogData]:
"""
预处理:消息校验 → AI消息语义剪枝(暂未实现) → 对话分块。
预处理:消息校验 → AI消息语义剪枝 → 对话分块。
委托给 get_chunked_dialogs(),保持现有预处理逻辑不变。
get_dialogs.py 内部已包含:
@@ -276,12 +280,15 @@ class WritePipeline:
"""
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
snapshot = getattr(self, "_snapshot", None)
return await get_chunked_dialogs(
chunker_strategy=self.memory_config.chunker_strategy,
end_user_id=self.end_user_id,
messages=messages,
ref_id=ref_id,
config_id=str(self.memory_config.config_id),
snapshot=snapshot,
)
# ──────────────────────────────────────────────
@@ -321,7 +328,9 @@ class WritePipeline:
pipeline_config = get_pipeline_config(self.memory_config)
ontology_types = self._load_ontology_types()
snapshot = PipelineSnapshot("new")
# 复用 run() 中已创建的 snapshot剪枝阶段已使用同一实例
snapshot = getattr(self, "_snapshot", None) or PipelineSnapshot("new")
self._snapshot = snapshot
# ── 新编排器LLM 萃取 + 数据赋值 ──
new_orchestrator = NewExtractionOrchestrator(
@@ -589,11 +598,15 @@ class WritePipeline:
# fire-and-forget 提交 Celery 任务,不阻塞主流程
# ──────────────────────────────────────────────
def _extract_emotion(self, emotion_statements: list) -> None:
async def _extract_emotion(self, emotion_statements: list) -> None:
"""提交异步情绪提取 Celery 任务。
从编排器收集的 user statement 列表中提取情绪,
异步回写到 Neo4j Statement 节点。失败不影响主流程。
在 PIPELINE_SNAPSHOT_ENABLED=true 时,会把当前运行的快照目录路径
通过 snapshot_dir 透传给 Celery 任务worker 端在完成 LLM 抽取后,
将结果落盘到 <snapshot_dir>/4_emotion_outputs.json避免主进程重复调用 LLM。
"""
if not emotion_statements:
return
@@ -607,6 +620,14 @@ class WritePipeline:
logger.warning("[Emotion] 无法提交情绪提取任务llm_model_id 为空")
return
# 快照目录:仅在 PIPELINE_SNAPSHOT_ENABLED=true 时非空,供 worker 端落盘
snapshot = getattr(self, "_snapshot", None)
snapshot_dir = (
snapshot.directory
if snapshot is not None and getattr(snapshot, "enabled", False)
else None
)
try:
from app.celery_app import celery_app
@@ -616,12 +637,14 @@ class WritePipeline:
"statements": emotion_statements,
"llm_model_id": llm_model_id,
"language": self.language,
"snapshot_dir": snapshot_dir,
},
)
logger.info(
f"[Emotion] 异步情绪提取任务已提交 - "
f"task_id={result.id}, "
f"statement_count={len(emotion_statements)}, "
f"snapshot_dir={snapshot_dir}, "
f"source=async"
)
except Exception as e:
@@ -629,6 +652,7 @@ class WritePipeline:
f"[Emotion] 提交情绪提取任务失败(不影响主流程): {e}",
exc_info=True,
)
# ──────────────────────────────────────────────
# Step 5: 摘要
# + entity_description+ meta_data部分在此提取

View File

@@ -1264,6 +1264,7 @@ class ExtractionOrchestrator:
entity_idx=entity.entity_idx, # 使用实体自己的 entity_idx
statement_id=statement.id, # 添加必需的 statement_id 字段
entity_type=getattr(entity, 'type', 'unknown'), # 使用 type 而不是 entity_type
type_description=getattr(entity, 'type_description', ''),
description=getattr(entity, 'description', ''), # 添加必需的 description 字段
example=getattr(entity, 'example', ''), # 新增:传递示例字段
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
@@ -1306,6 +1307,7 @@ class ExtractionOrchestrator:
source=subject_entity_id,
target=object_entity_id,
relation_type=triplet.predicate,
relation_type_description=getattr(triplet, 'predicate_description', ''),
statement=statement.statement,
source_statement_id=statement.id,
end_user_id=dialog_data.end_user_id,

View File

@@ -12,16 +12,21 @@ from app.core.memory.utils.data.ontology import (
TemporalInfo,
)
from app.core.memory.utils.prompt.prompt_utils import render_statement_extraction_prompt
from pydantic import BaseModel, Field, field_validator
from pydantic import AliasChoices, BaseModel, Field, field_validator
logger = logging.getLogger(__name__)
class ExtractedStatement(BaseModel):
"""Schema for extracted statement from LLM"""
statement: str = Field(..., description="The extracted statement text")
statement: str = Field(
...,
validation_alias=AliasChoices("statement", "statement_text"),
description="The extracted statement text",
)
statement_type: str = Field(..., description="FACT, OPINION, SUGGESTION or PREDICTION")
temporal_type: str = Field(..., description="STATIC, DYNAMIC, ATEMPORAL")
relevence: str = Field(..., description="RELEVANT or IRRELEVANT")
# New prompt no longer outputs relevence; keep backward-compatible default.
relevence: str = Field("RELEVANT", description="RELEVANT or IRRELEVANT")
has_unsolved_reference: bool = Field(False, description="Whether the statement has unresolved references")
class StatementExtractionResponse(BaseModel):
@@ -41,7 +46,7 @@ class StatementExtractionResponse(BaseModel):
valid_statements = []
filtered_count = 0
for i, stmt in enumerate(v):
if isinstance(stmt, dict) and stmt.get('statement'):
if isinstance(stmt, dict) and (stmt.get("statement") or stmt.get("statement_text")):
valid_statements.append(stmt)
elif isinstance(stmt, dict):
# Log which statement was filtered
@@ -96,6 +101,11 @@ class StatementExtractor:
"""
chunk_content = chunk.content
chunk_speaker = self._get_speaker_from_chunk(chunk)
logger.info(
"[LegacyStatementExtractor] chunk_id=%s content_len=%d",
getattr(chunk, "id", ""),
len(chunk_content or ""),
)
if not chunk_content or len(chunk_content.strip()) < 5:
logger.warning(f"Chunk {chunk.id} content too short or empty, skipping")
@@ -108,7 +118,18 @@ class StatementExtractor:
granularity=self.config.statement_granularity,
include_dialogue_context=self.config.include_dialogue_context,
dialogue_content=dialogue_content,
max_dialogue_chars=self.config.max_dialogue_context_chars
max_dialogue_chars=self.config.max_dialogue_context_chars,
input_json={
"chunk_id": getattr(chunk, "id", ""),
"end_user_id": end_user_id or "",
"target_content": chunk_content,
"target_message_date": datetime.now().isoformat(),
"supporting_context": {
"msgs": [
{"role": "context", "msg": dialogue_content}
] if dialogue_content else []
},
},
)
# Simple system message

View File

@@ -4,7 +4,7 @@ from typing import List, Dict, Optional
from app.core.logging_config import get_memory_logger
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.memory.utils.prompt.prompt_utils import render_triplet_extraction_prompt
from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate # 引入枚举 Predicate 白名单过滤
from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS
from app.core.memory.models.triplet_models import TripletExtractionResponse
from app.core.memory.models.message_models import DialogData, Statement
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
@@ -73,15 +73,9 @@ class TripletExtractor:
try:
# Get structured response from LLM
response = await self.llm_client.response_structured(messages, TripletExtractionResponse)
# Filter triplets to only allowed predicates from ontology
# 这里过滤掉了不在 Predicate 枚举中的谓语 但是容易造成谓语太严格,有点语句的谓语没有在枚举中,就被判断为弱关系
allowed_predicates = {p.value for p in Predicate}
filtered_triplets = [t for t in response.triplets if getattr(t, "predicate", "") in allowed_predicates]
# 仅保留predicate ∈ Predicate 的三元组,其余全部剔除
# Create new triplets with statement_id set during creation
updated_triplets = []
for triplet in filtered_triplets: # 仅保留 predicate ∈ Predicate 的三元组
for triplet in response.triplets:
updated_triplet = triplet.model_copy(update={"statement_id": statement.id})
updated_triplets.append(updated_triplet)

View File

@@ -300,6 +300,33 @@ class NewExtractionOrchestrator:
"embedding_output": None,
}
if self.progress_callback:
statements_count = sum(
len(stmts)
for chunk_stmts in all_stmt_results.values()
for stmts in chunk_stmts.values()
)
entities_count = sum(
len(t_out.entities)
for stmt_triplets in all_triplet_results.values()
for t_out in stmt_triplets.values()
)
triplets_count = sum(
len(t_out.triplets)
for stmt_triplets in all_triplet_results.values()
for t_out in stmt_triplets.values()
)
await self.progress_callback(
"knowledge_extraction_complete",
"知识抽取完成",
{
"entities_count": entities_count,
"statements_count": statements_count,
"temporal_ranges_count": 0,
"triplets_count": triplets_count,
},
)
logger.info("Pilot extraction complete")
return dialog_data_list
@@ -467,6 +494,11 @@ class NewExtractionOrchestrator:
else None
)
for chunk in dialog.chunks:
# 仅对 speaker="user" 的 chunk 进行陈述句抽取assistant 内容交给
# 上游预处理/剪枝阶段处理,避免浪费 LLM 调用。
chunk_speaker = getattr(chunk, "speaker", "user")
if chunk_speaker != "user":
continue
inp = StatementStepInput(
chunk_id=chunk.id,
end_user_id=dialog.end_user_id,
@@ -478,7 +510,7 @@ class NewExtractionOrchestrator:
)
tasks.append(self.statement_step.run(inp))
task_meta.append(
(dialog.id, chunk.id, getattr(chunk, "speaker", "user"), ctx)
(dialog.id, chunk.id, chunk_speaker, ctx)
)
results = await asyncio.gather(*tasks, return_exceptions=True)
@@ -499,6 +531,15 @@ class NewExtractionOrchestrator:
for s in stmts:
s.speaker = speaker
stmt_map[dialog_id][chunk_id] = stmts
if self.progress_callback:
# Frontend consumes knowledge_extraction_result with data.statement.
# Emit one event per statement to keep payload contract simple.
for s in stmts:
await self.progress_callback(
"knowledge_extraction_result",
"知识抽取中",
{"statement": s.statement_text},
)
return stmt_map
@@ -520,6 +561,11 @@ class NewExtractionOrchestrator:
chunk_stmts = all_stmt_results.get(dialog.id, {})
for _chunk_id, stmts in chunk_stmts.items():
for stmt in stmts:
# 防御性过滤:三元组抽取仅针对 user statement。
# 上游 _extract_all_statements 已过滤 chunk.speaker此处再做
# 一次 statement.speaker 的二次校验,防止外部注入或 legacy 数据脱漏。
if getattr(stmt, "speaker", "user") != "user":
continue
inp = self._convert_to_triplet_input(stmt, ctx)
tasks.append(self.triplet_step.run(inp))
task_meta.append((dialog.id, stmt.statement_id))
@@ -541,6 +587,24 @@ class NewExtractionOrchestrator:
triplet_map[dialog_id][stmt_id] = self.triplet_step.get_default_output()
else:
triplet_map[dialog_id][stmt_id] = result
if self.progress_callback:
await self.progress_callback(
"extract_triplet_result",
f"statement {stmt_id} 提取完成",
{
"statement_id": stmt_id,
"triplet_count": len(result.triplets),
"entity_count": len(result.entities),
"triplets": [
{
"subject_name": t.subject_name,
"predicate": t.predicate,
"object_name": t.object_name,
}
for t in result.triplets[:5]
],
},
)
return triplet_map
@@ -842,6 +906,8 @@ class NewExtractionOrchestrator:
temporal_info=_TEMPORAL_MAP.get(stmt_out.temporal_type, TemporalInfo.ATEMPORAL),
# relevence_info=RelevenceInfo.RELEVANT if stmt_out.relevance == "RELEVANT" else RelevenceInfo.IRRELEVANT,
temporal_validity=TemporalValidityRange(valid_at=valid_at, invalid_at=invalid_at),
has_unsolved_reference=stmt_out.has_unsolved_reference,
has_emotional_state=stmt_out.has_emotional_state,
triplet_extraction_info=triplet_info,
statement_embedding=stmt_embedding,
**emotion_kwargs,

View File

@@ -250,6 +250,7 @@ async def build_graph_nodes_and_edges(
entity_idx=entity.entity_idx,
statement_id=statement.id,
entity_type=getattr(entity, "type", "unknown"),
type_description=getattr(entity, "type_description", ""),
description=getattr(entity, "description", ""),
example=getattr(entity, "example", ""),
connect_strength=(
@@ -296,6 +297,7 @@ async def build_graph_nodes_and_edges(
source=subject_entity_id,
target=object_entity_id,
relation_type=triplet.predicate,
relation_type_description=getattr(triplet, "predicate_description", ""),
statement=statement.statement,
source_statement_id=statement.id,
end_user_id=dialog_data.end_user_id,

View File

@@ -46,6 +46,7 @@ class StatementStepOutput(BaseModel):
temporal_type: str # STATIC / DYNAMIC / ATEMPORAL
# relevance: str # RELEVANT / IRRELEVANT
speaker: str # "user" / "assistant"
has_emotional_state: bool = False # Whether statement reflects user's emotional state
valid_at: str # ISO 8601 or "NULL"
invalid_at: str # ISO 8601 or "NULL"
has_unsolved_reference: bool = False # Whether the statement has unresolved references
@@ -72,6 +73,7 @@ class EntityItem(BaseModel):
entity_idx: int
name: str
type: str
type_description: str = ""
description: str
is_explicit_memory: bool = False
@@ -82,6 +84,7 @@ class TripletItem(BaseModel):
subject_name: str
subject_id: int
predicate: str
predicate_description: str = ""
object_name: str
object_id: int

View File

@@ -34,6 +34,10 @@ class _ExtractedStatement(BaseModel):
statement_type: str = Field(..., description="FACT / OPINION / OTHER")
temporal_type: str = Field(..., description="STATIC / DYNAMIC / ATEMPORAL")
# relevance: str = Field("RELEVANT", description="RELEVANT / IRRELEVANT")
has_emotional_state: bool = Field(
False,
description="Whether the statement reflects user's emotional state",
)
valid_at: str = Field("NULL", description="ISO 8601 or NULL")
invalid_at: str = Field("NULL", description="ISO 8601 or NULL")
has_unsolved_reference: bool = Field(False, description="Whether the statement has unresolved references")
@@ -155,6 +159,7 @@ class StatementExtractionStep(ExtractionStep[StatementStepInput, List[StatementS
temporal_type=stmt.temporal_type.strip().upper(),
# relevance=stmt.relevance.strip().upper(),
speaker="user", # default; orchestrator overrides from chunk metadata
has_emotional_state=getattr(stmt, "has_emotional_state", False),
valid_at=stmt.valid_at or "NULL",
invalid_at=stmt.invalid_at or "NULL",
has_unsolved_reference=getattr(stmt, "has_unsolved_reference", False),

View File

@@ -112,6 +112,7 @@ class TripletExtractionStep(ExtractionStep[TripletStepInput, TripletStepOutput])
subject_name=t.subject_name,
subject_id=t.subject_id,
predicate=t.predicate,
predicate_description=getattr(t, "predicate_description", ""),
object_name=t.object_name,
object_id=t.object_id,
)
@@ -123,6 +124,7 @@ class TripletExtractionStep(ExtractionStep[TripletStepInput, TripletStepOutput])
entity_idx=e.entity_idx,
name=e.name,
type=e.type,
type_description=getattr(e, "type_description", ""),
description=e.description,
is_explicit_memory=getattr(e, "is_explicit_memory", False),
)