refactor(memory): add PilotWritePipeline and enrich extraction schema
- Add dedicated PilotWritePipeline (statement → triplet → graph_build → layer-1 dedup, no Neo4j write) - Add type_description/predicate_description fields across entity and triplet models, Cypher queries, and graph builders - Refactor data_pruning with LRU cache and snapshot support; skip assistant chunks in extraction - Remove strict Predicate enum whitelist; support statement_text alias in legacy extractor - Wire PipelineSnapshot through preprocessing and emotion extraction for debug tracing - Add PILOT_RUN_USE_REFACTORED_PIPELINE env toggle for pipeline selection
This commit is contained in:
@@ -272,6 +272,12 @@ class Settings:
|
||||
|
||||
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
|
||||
MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory")
|
||||
# Pilot run pipeline switch:
|
||||
# true -> use refactored PilotWritePipeline
|
||||
# false -> use legacy ExtractionOrchestrator pipeline
|
||||
PILOT_RUN_USE_REFACTORED_PIPELINE: bool = (
|
||||
os.getenv("PILOT_RUN_USE_REFACTORED_PIPELINE", "true").lower() == "true"
|
||||
)
|
||||
|
||||
# Tool Management Configuration
|
||||
TOOL_CONFIG_DIR: str = os.getenv("TOOL_CONFIG_DIR", "app/core/tools")
|
||||
|
||||
@@ -9,7 +9,8 @@ async def get_chunked_dialogs(
|
||||
end_user_id: str = "group_1",
|
||||
messages: list = None,
|
||||
ref_id: str = "",
|
||||
config_id: str = None
|
||||
config_id: str = None,
|
||||
snapshot=None,
|
||||
) -> List[DialogData]:
|
||||
"""Generate chunks from structured messages using the specified chunker strategy.
|
||||
|
||||
@@ -19,6 +20,7 @@ async def get_chunked_dialogs(
|
||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||
ref_id: Reference identifier
|
||||
config_id: Configuration ID for processing (used to load pruning config)
|
||||
snapshot: Optional PipelineSnapshot instance for saving pruning output
|
||||
|
||||
Returns:
|
||||
List of DialogData objects with generated chunks
|
||||
@@ -93,7 +95,7 @@ async def get_chunked_dialogs(
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
|
||||
# 执行剪枝 - 使用 prune_dataset 支持消息级剪枝
|
||||
pruner = SemanticPruner(config=pruning_config, llm_client=llm_client)
|
||||
pruner = SemanticPruner(config=pruning_config, llm_client=llm_client, snapshot=snapshot)
|
||||
original_msg_count = len(dialog_data.context.msgs)
|
||||
|
||||
# 使用 prune_dataset 而不是 prune_dialog
|
||||
|
||||
@@ -184,7 +184,8 @@ async def write(
|
||||
"entities": [
|
||||
{
|
||||
"entity_idx": e.entity_idx, "name": e.name,
|
||||
"type": e.type, "description": e.description,
|
||||
"type": e.type, "type_description": getattr(e, "type_description", ""),
|
||||
"description": e.description,
|
||||
"is_explicit_memory": getattr(e, "is_explicit_memory", False),
|
||||
}
|
||||
for e in s.triplet_extraction_info.entities
|
||||
@@ -193,6 +194,7 @@ async def write(
|
||||
{
|
||||
"subject_name": t.subject_name, "subject_id": t.subject_id,
|
||||
"predicate": t.predicate,
|
||||
"predicate_description": getattr(t, "predicate_description", ""),
|
||||
"object_name": t.object_name, "object_id": t.object_id,
|
||||
}
|
||||
for t in s.triplet_extraction_info.triplets
|
||||
@@ -206,13 +208,13 @@ async def write(
|
||||
"chunk_nodes_count": len(all_chunk_nodes),
|
||||
"statement_nodes_count": len(all_statement_nodes),
|
||||
"entity_nodes": [
|
||||
{"id": e.id, "name": e.name, "entity_type": e.entity_type, "description": e.description}
|
||||
{"id": e.id, "name": e.name, "entity_type": e.entity_type, "type_description": e.type_description, "description": e.description}
|
||||
for e in all_entity_nodes
|
||||
],
|
||||
"entity_entity_edges": [
|
||||
{
|
||||
"source": e.source, "target": e.target,
|
||||
"relation_type": e.relation_type, "statement": e.statement,
|
||||
"relation_type": e.relation_type, "relation_type_description": e.relation_type_description, "statement": e.statement,
|
||||
}
|
||||
for e in all_entity_entity_edges
|
||||
],
|
||||
|
||||
@@ -162,6 +162,7 @@ class EntityEntityEdge(Edge):
|
||||
invalid_at: Optional end date of temporal validity
|
||||
"""
|
||||
relation_type: str = Field(..., description="Relation type as defined in ontology")
|
||||
relation_type_description: str = Field(default="", description="Chinese definition of the relation type from ontology")
|
||||
relation_value: Optional[str] = Field(None, description="Value of the relation")
|
||||
statement: str = Field(..., description='The statement of the edge.')
|
||||
source_statement_id: str = Field(..., description="Statement where this relationship was extracted")
|
||||
@@ -413,6 +414,7 @@ class ExtractedEntityNode(Node):
|
||||
entity_idx: int = Field(..., description="Unique identifier for the entity")
|
||||
statement_id: str = Field(..., description="Statement this entity was extracted from")
|
||||
entity_type: str = Field(..., description="Type of the entity")
|
||||
type_description: str = Field(default="", description="Chinese definition of the entity type from ontology")
|
||||
description: str = Field(..., description="Entity description")
|
||||
example: str = Field(
|
||||
default="",
|
||||
|
||||
@@ -96,6 +96,10 @@ class Statement(BaseModel):
|
||||
emotion_target: Optional[str] = Field(None, description="Emotion target: person or object name")
|
||||
# Reference resolution
|
||||
has_unsolved_reference: bool = Field(False, description="Whether the statement has unresolved references")
|
||||
has_emotional_state: bool = Field(
|
||||
False,
|
||||
description="Whether the statement reflects user's emotional state",
|
||||
)
|
||||
|
||||
|
||||
class ConversationContext(BaseModel):
|
||||
|
||||
@@ -37,6 +37,7 @@ class Entity(BaseModel):
|
||||
name: str = Field(..., description="Name of the entity")
|
||||
name_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the entity name")
|
||||
type: str = Field(..., description="Type/category of the entity")
|
||||
type_description: str = Field(default="", description="Chinese definition of the entity type from ontology")
|
||||
description: str = Field(..., description="Description of the entity")
|
||||
example: str = Field(
|
||||
default="",
|
||||
@@ -79,6 +80,7 @@ class Triplet(BaseModel):
|
||||
subject_name: str = Field(..., description="Name of the subject entity")
|
||||
subject_id: int = Field(..., description="ID of the subject entity")
|
||||
predicate: str = Field(..., description="Relationship/predicate between subject and object")
|
||||
predicate_description: str = Field(default="", description="Chinese definition of the predicate from ontology")
|
||||
object_name: str = Field(..., description="Name of the object entity")
|
||||
object_id: int = Field(..., description="ID of the object entity")
|
||||
value: Optional[str] = Field(None, description="Additional value or context")
|
||||
|
||||
@@ -14,13 +14,31 @@ def __getattr__(name):
|
||||
WritePipeline,
|
||||
WriteResult,
|
||||
)
|
||||
|
||||
_exports = {
|
||||
"WritePipeline": WritePipeline,
|
||||
"ExtractionResult": ExtractionResult,
|
||||
"WriteResult": WriteResult,
|
||||
}
|
||||
return _exports[name]
|
||||
if name in ("PilotWritePipeline", "PilotWriteResult"):
|
||||
from app.core.memory.pipelines.pilot_write_pipeline import (
|
||||
PilotWritePipeline,
|
||||
PilotWriteResult,
|
||||
)
|
||||
|
||||
_exports = {
|
||||
"PilotWritePipeline": PilotWritePipeline,
|
||||
"PilotWriteResult": PilotWriteResult,
|
||||
}
|
||||
return _exports[name]
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
|
||||
__all__ = ["WritePipeline", "ExtractionResult", "WriteResult"]
|
||||
__all__ = [
|
||||
"WritePipeline",
|
||||
"ExtractionResult",
|
||||
"WriteResult",
|
||||
"PilotWritePipeline",
|
||||
"PilotWriteResult",
|
||||
]
|
||||
|
||||
108
api/app/core/memory/pipelines/pilot_write_pipeline.py
Normal file
108
api/app/core/memory/pipelines/pilot_write_pipeline.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""PilotWritePipeline — 试运行专用萃取流水线。
|
||||
|
||||
职责边界:
|
||||
- 只执行“萃取相关”链路:statement -> triplet -> graph_build -> 第一层去重消歧
|
||||
- 不负责 Neo4j 写入、聚类、摘要、缓存更新
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
||||
|
||||
from app.core.memory.models.message_models import DialogData
|
||||
from app.core.memory.models.variate_config import ExtractionPipelineConfig
|
||||
from app.core.memory.storage_services.extraction_engine.dedup_step import (
|
||||
DedupResult,
|
||||
run_dedup,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.steps.extraction_pipeline_orchestrator import (
|
||||
NewExtractionOrchestrator,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.steps.graph_build_step import (
|
||||
GraphBuildResult,
|
||||
build_graph_nodes_and_edges,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PilotWriteResult:
|
||||
"""试运行流水线输出。"""
|
||||
|
||||
dialog_data_list: List[DialogData]
|
||||
graph: GraphBuildResult
|
||||
dedup: DedupResult
|
||||
|
||||
@property
|
||||
def stats(self) -> Dict[str, int]:
|
||||
return {
|
||||
"chunk_count": len(self.graph.chunk_nodes),
|
||||
"statement_count": len(self.graph.statement_nodes),
|
||||
"entity_count_before_dedup": len(self.graph.entity_nodes),
|
||||
"entity_count_after_dedup": len(self.dedup.entity_nodes),
|
||||
"relation_count_before_dedup": len(self.graph.entity_entity_edges),
|
||||
"relation_count_after_dedup": len(self.dedup.entity_entity_edges),
|
||||
}
|
||||
|
||||
|
||||
class PilotWritePipeline:
|
||||
"""重构后试运行专用流水线。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_client: Any,
|
||||
embedder_client: Any,
|
||||
pipeline_config: ExtractionPipelineConfig,
|
||||
embedding_id: Optional[str],
|
||||
language: str = "zh",
|
||||
ontology_types: Any = None,
|
||||
progress_callback: Optional[
|
||||
Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]
|
||||
] = None,
|
||||
) -> None:
|
||||
self.llm_client = llm_client
|
||||
self.embedder_client = embedder_client
|
||||
self.pipeline_config = pipeline_config
|
||||
self.embedding_id = embedding_id
|
||||
self.language = language
|
||||
self.ontology_types = ontology_types
|
||||
self.progress_callback = progress_callback
|
||||
|
||||
async def run(self, dialog_data_list: List[DialogData]) -> PilotWriteResult:
|
||||
"""执行试运行萃取链路。"""
|
||||
orchestrator = NewExtractionOrchestrator(
|
||||
llm_client=self.llm_client,
|
||||
embedder_client=self.embedder_client,
|
||||
config=self.pipeline_config,
|
||||
embedding_id=self.embedding_id,
|
||||
ontology_types=self.ontology_types,
|
||||
language=self.language,
|
||||
is_pilot_run=True,
|
||||
progress_callback=self.progress_callback,
|
||||
)
|
||||
extracted_dialogs = await orchestrator.run(dialog_data_list)
|
||||
|
||||
graph = await build_graph_nodes_and_edges(
|
||||
dialog_data_list=extracted_dialogs,
|
||||
embedder_client=self.embedder_client,
|
||||
progress_callback=self.progress_callback,
|
||||
)
|
||||
|
||||
dedup = await run_dedup(
|
||||
entity_nodes=graph.entity_nodes,
|
||||
statement_entity_edges=graph.stmt_entity_edges,
|
||||
entity_entity_edges=graph.entity_entity_edges,
|
||||
dialog_data_list=extracted_dialogs,
|
||||
pipeline_config=self.pipeline_config,
|
||||
connector=None, # pilot: no layer-2 db dedup
|
||||
llm_client=self.llm_client,
|
||||
is_pilot_run=True,
|
||||
progress_callback=self.progress_callback,
|
||||
)
|
||||
|
||||
return PilotWriteResult(
|
||||
dialog_data_list=extracted_dialogs,
|
||||
graph=graph,
|
||||
dedup=dedup,
|
||||
)
|
||||
|
||||
@@ -180,7 +180,11 @@ class WritePipeline:
|
||||
self._init_clients()
|
||||
self._init_neo4j_connector()
|
||||
|
||||
# Step 1: 预处理 - 消息分块 + AI消息语义剪枝(暂无实现)
|
||||
# 初始化 Snapshot(提前创建,供预处理阶段的剪枝使用)
|
||||
from app.core.memory.utils.debug.pipeline_snapshot import PipelineSnapshot
|
||||
self._snapshot = PipelineSnapshot("new")
|
||||
|
||||
# Step 1: 预处理 - 消息分块 + AI消息语义剪枝
|
||||
step_start = time.time()
|
||||
chunked_dialogs = await self._preprocess(messages, ref_id)
|
||||
chunks_count = sum(len(d.chunks) for d in chunked_dialogs)
|
||||
@@ -220,7 +224,7 @@ class WritePipeline:
|
||||
)
|
||||
|
||||
# Step 3.5: 异步情绪提取(fire-and-forget,需在 _store 之后确保 Statement 节点已存在)
|
||||
self._extract_emotion(getattr(self, "_emotion_statements", []))
|
||||
await self._extract_emotion(getattr(self, "_emotion_statements", []))
|
||||
|
||||
# Step 4: 聚类 - 增量更新社区(异步,不阻塞)
|
||||
step_start = time.time()
|
||||
@@ -266,7 +270,7 @@ class WritePipeline:
|
||||
|
||||
async def _preprocess(self, messages: List[dict], ref_id: str) -> List[DialogData]:
|
||||
"""
|
||||
预处理:消息校验 → AI消息语义剪枝(暂未实现) → 对话分块。
|
||||
预处理:消息校验 → AI消息语义剪枝 → 对话分块。
|
||||
|
||||
委托给 get_chunked_dialogs(),保持现有预处理逻辑不变。
|
||||
get_dialogs.py 内部已包含:
|
||||
@@ -276,12 +280,15 @@ class WritePipeline:
|
||||
"""
|
||||
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
|
||||
|
||||
snapshot = getattr(self, "_snapshot", None)
|
||||
|
||||
return await get_chunked_dialogs(
|
||||
chunker_strategy=self.memory_config.chunker_strategy,
|
||||
end_user_id=self.end_user_id,
|
||||
messages=messages,
|
||||
ref_id=ref_id,
|
||||
config_id=str(self.memory_config.config_id),
|
||||
snapshot=snapshot,
|
||||
)
|
||||
|
||||
# ──────────────────────────────────────────────
|
||||
@@ -321,7 +328,9 @@ class WritePipeline:
|
||||
pipeline_config = get_pipeline_config(self.memory_config)
|
||||
ontology_types = self._load_ontology_types()
|
||||
|
||||
snapshot = PipelineSnapshot("new")
|
||||
# 复用 run() 中已创建的 snapshot(剪枝阶段已使用同一实例)
|
||||
snapshot = getattr(self, "_snapshot", None) or PipelineSnapshot("new")
|
||||
self._snapshot = snapshot
|
||||
|
||||
# ── 新编排器:LLM 萃取 + 数据赋值 ──
|
||||
new_orchestrator = NewExtractionOrchestrator(
|
||||
@@ -589,11 +598,15 @@ class WritePipeline:
|
||||
# fire-and-forget 提交 Celery 任务,不阻塞主流程
|
||||
# ──────────────────────────────────────────────
|
||||
|
||||
def _extract_emotion(self, emotion_statements: list) -> None:
|
||||
async def _extract_emotion(self, emotion_statements: list) -> None:
|
||||
"""提交异步情绪提取 Celery 任务。
|
||||
|
||||
从编排器收集的 user statement 列表中提取情绪,
|
||||
异步回写到 Neo4j Statement 节点。失败不影响主流程。
|
||||
|
||||
在 PIPELINE_SNAPSHOT_ENABLED=true 时,会把当前运行的快照目录路径
|
||||
通过 snapshot_dir 透传给 Celery 任务;worker 端在完成 LLM 抽取后,
|
||||
将结果落盘到 <snapshot_dir>/4_emotion_outputs.json,避免主进程重复调用 LLM。
|
||||
"""
|
||||
if not emotion_statements:
|
||||
return
|
||||
@@ -607,6 +620,14 @@ class WritePipeline:
|
||||
logger.warning("[Emotion] 无法提交情绪提取任务:llm_model_id 为空")
|
||||
return
|
||||
|
||||
# 快照目录:仅在 PIPELINE_SNAPSHOT_ENABLED=true 时非空,供 worker 端落盘
|
||||
snapshot = getattr(self, "_snapshot", None)
|
||||
snapshot_dir = (
|
||||
snapshot.directory
|
||||
if snapshot is not None and getattr(snapshot, "enabled", False)
|
||||
else None
|
||||
)
|
||||
|
||||
try:
|
||||
from app.celery_app import celery_app
|
||||
|
||||
@@ -616,12 +637,14 @@ class WritePipeline:
|
||||
"statements": emotion_statements,
|
||||
"llm_model_id": llm_model_id,
|
||||
"language": self.language,
|
||||
"snapshot_dir": snapshot_dir,
|
||||
},
|
||||
)
|
||||
logger.info(
|
||||
f"[Emotion] 异步情绪提取任务已提交 - "
|
||||
f"task_id={result.id}, "
|
||||
f"statement_count={len(emotion_statements)}, "
|
||||
f"snapshot_dir={snapshot_dir}, "
|
||||
f"source=async"
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -629,6 +652,7 @@ class WritePipeline:
|
||||
f"[Emotion] 提交情绪提取任务失败(不影响主流程): {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# ──────────────────────────────────────────────
|
||||
# Step 5: 摘要
|
||||
# (+ entity_description)+ meta_data部分在此提取
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1264,6 +1264,7 @@ class ExtractionOrchestrator:
|
||||
entity_idx=entity.entity_idx, # 使用实体自己的 entity_idx
|
||||
statement_id=statement.id, # 添加必需的 statement_id 字段
|
||||
entity_type=getattr(entity, 'type', 'unknown'), # 使用 type 而不是 entity_type
|
||||
type_description=getattr(entity, 'type_description', ''),
|
||||
description=getattr(entity, 'description', ''), # 添加必需的 description 字段
|
||||
example=getattr(entity, 'example', ''), # 新增:传递示例字段
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
@@ -1306,6 +1307,7 @@ class ExtractionOrchestrator:
|
||||
source=subject_entity_id,
|
||||
target=object_entity_id,
|
||||
relation_type=triplet.predicate,
|
||||
relation_type_description=getattr(triplet, 'predicate_description', ''),
|
||||
statement=statement.statement,
|
||||
source_statement_id=statement.id,
|
||||
end_user_id=dialog_data.end_user_id,
|
||||
|
||||
@@ -12,16 +12,21 @@ from app.core.memory.utils.data.ontology import (
|
||||
TemporalInfo,
|
||||
)
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_statement_extraction_prompt
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import AliasChoices, BaseModel, Field, field_validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ExtractedStatement(BaseModel):
|
||||
"""Schema for extracted statement from LLM"""
|
||||
statement: str = Field(..., description="The extracted statement text")
|
||||
statement: str = Field(
|
||||
...,
|
||||
validation_alias=AliasChoices("statement", "statement_text"),
|
||||
description="The extracted statement text",
|
||||
)
|
||||
statement_type: str = Field(..., description="FACT, OPINION, SUGGESTION or PREDICTION")
|
||||
temporal_type: str = Field(..., description="STATIC, DYNAMIC, ATEMPORAL")
|
||||
relevence: str = Field(..., description="RELEVANT or IRRELEVANT")
|
||||
# New prompt no longer outputs relevence; keep backward-compatible default.
|
||||
relevence: str = Field("RELEVANT", description="RELEVANT or IRRELEVANT")
|
||||
has_unsolved_reference: bool = Field(False, description="Whether the statement has unresolved references")
|
||||
|
||||
class StatementExtractionResponse(BaseModel):
|
||||
@@ -41,7 +46,7 @@ class StatementExtractionResponse(BaseModel):
|
||||
valid_statements = []
|
||||
filtered_count = 0
|
||||
for i, stmt in enumerate(v):
|
||||
if isinstance(stmt, dict) and stmt.get('statement'):
|
||||
if isinstance(stmt, dict) and (stmt.get("statement") or stmt.get("statement_text")):
|
||||
valid_statements.append(stmt)
|
||||
elif isinstance(stmt, dict):
|
||||
# Log which statement was filtered
|
||||
@@ -96,6 +101,11 @@ class StatementExtractor:
|
||||
"""
|
||||
chunk_content = chunk.content
|
||||
chunk_speaker = self._get_speaker_from_chunk(chunk)
|
||||
logger.info(
|
||||
"[LegacyStatementExtractor] chunk_id=%s content_len=%d",
|
||||
getattr(chunk, "id", ""),
|
||||
len(chunk_content or ""),
|
||||
)
|
||||
|
||||
if not chunk_content or len(chunk_content.strip()) < 5:
|
||||
logger.warning(f"Chunk {chunk.id} content too short or empty, skipping")
|
||||
@@ -108,7 +118,18 @@ class StatementExtractor:
|
||||
granularity=self.config.statement_granularity,
|
||||
include_dialogue_context=self.config.include_dialogue_context,
|
||||
dialogue_content=dialogue_content,
|
||||
max_dialogue_chars=self.config.max_dialogue_context_chars
|
||||
max_dialogue_chars=self.config.max_dialogue_context_chars,
|
||||
input_json={
|
||||
"chunk_id": getattr(chunk, "id", ""),
|
||||
"end_user_id": end_user_id or "",
|
||||
"target_content": chunk_content,
|
||||
"target_message_date": datetime.now().isoformat(),
|
||||
"supporting_context": {
|
||||
"msgs": [
|
||||
{"role": "context", "msg": dialogue_content}
|
||||
] if dialogue_content else []
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
# Simple system message
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import List, Dict, Optional
|
||||
from app.core.logging_config import get_memory_logger
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_triplet_extraction_prompt
|
||||
from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate # 引入枚举 Predicate 白名单过滤
|
||||
from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS
|
||||
from app.core.memory.models.triplet_models import TripletExtractionResponse
|
||||
from app.core.memory.models.message_models import DialogData, Statement
|
||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
||||
@@ -73,15 +73,9 @@ class TripletExtractor:
|
||||
try:
|
||||
# Get structured response from LLM
|
||||
response = await self.llm_client.response_structured(messages, TripletExtractionResponse)
|
||||
# Filter triplets to only allowed predicates from ontology
|
||||
# 这里过滤掉了不在 Predicate 枚举中的谓语 但是容易造成谓语太严格,有点语句的谓语没有在枚举中,就被判断为弱关系
|
||||
allowed_predicates = {p.value for p in Predicate}
|
||||
filtered_triplets = [t for t in response.triplets if getattr(t, "predicate", "") in allowed_predicates]
|
||||
# 仅保留predicate ∈ Predicate 的三元组,其余全部剔除
|
||||
|
||||
# Create new triplets with statement_id set during creation
|
||||
updated_triplets = []
|
||||
for triplet in filtered_triplets: # 仅保留 predicate ∈ Predicate 的三元组
|
||||
for triplet in response.triplets:
|
||||
updated_triplet = triplet.model_copy(update={"statement_id": statement.id})
|
||||
updated_triplets.append(updated_triplet)
|
||||
|
||||
|
||||
@@ -300,6 +300,33 @@ class NewExtractionOrchestrator:
|
||||
"embedding_output": None,
|
||||
}
|
||||
|
||||
if self.progress_callback:
|
||||
statements_count = sum(
|
||||
len(stmts)
|
||||
for chunk_stmts in all_stmt_results.values()
|
||||
for stmts in chunk_stmts.values()
|
||||
)
|
||||
entities_count = sum(
|
||||
len(t_out.entities)
|
||||
for stmt_triplets in all_triplet_results.values()
|
||||
for t_out in stmt_triplets.values()
|
||||
)
|
||||
triplets_count = sum(
|
||||
len(t_out.triplets)
|
||||
for stmt_triplets in all_triplet_results.values()
|
||||
for t_out in stmt_triplets.values()
|
||||
)
|
||||
await self.progress_callback(
|
||||
"knowledge_extraction_complete",
|
||||
"知识抽取完成",
|
||||
{
|
||||
"entities_count": entities_count,
|
||||
"statements_count": statements_count,
|
||||
"temporal_ranges_count": 0,
|
||||
"triplets_count": triplets_count,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info("Pilot extraction complete")
|
||||
return dialog_data_list
|
||||
|
||||
@@ -467,6 +494,11 @@ class NewExtractionOrchestrator:
|
||||
else None
|
||||
)
|
||||
for chunk in dialog.chunks:
|
||||
# 仅对 speaker="user" 的 chunk 进行陈述句抽取;assistant 内容交给
|
||||
# 上游预处理/剪枝阶段处理,避免浪费 LLM 调用。
|
||||
chunk_speaker = getattr(chunk, "speaker", "user")
|
||||
if chunk_speaker != "user":
|
||||
continue
|
||||
inp = StatementStepInput(
|
||||
chunk_id=chunk.id,
|
||||
end_user_id=dialog.end_user_id,
|
||||
@@ -478,7 +510,7 @@ class NewExtractionOrchestrator:
|
||||
)
|
||||
tasks.append(self.statement_step.run(inp))
|
||||
task_meta.append(
|
||||
(dialog.id, chunk.id, getattr(chunk, "speaker", "user"), ctx)
|
||||
(dialog.id, chunk.id, chunk_speaker, ctx)
|
||||
)
|
||||
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
@@ -499,6 +531,15 @@ class NewExtractionOrchestrator:
|
||||
for s in stmts:
|
||||
s.speaker = speaker
|
||||
stmt_map[dialog_id][chunk_id] = stmts
|
||||
if self.progress_callback:
|
||||
# Frontend consumes knowledge_extraction_result with data.statement.
|
||||
# Emit one event per statement to keep payload contract simple.
|
||||
for s in stmts:
|
||||
await self.progress_callback(
|
||||
"knowledge_extraction_result",
|
||||
"知识抽取中",
|
||||
{"statement": s.statement_text},
|
||||
)
|
||||
|
||||
return stmt_map
|
||||
|
||||
@@ -520,6 +561,11 @@ class NewExtractionOrchestrator:
|
||||
chunk_stmts = all_stmt_results.get(dialog.id, {})
|
||||
for _chunk_id, stmts in chunk_stmts.items():
|
||||
for stmt in stmts:
|
||||
# 防御性过滤:三元组抽取仅针对 user statement。
|
||||
# 上游 _extract_all_statements 已过滤 chunk.speaker,此处再做
|
||||
# 一次 statement.speaker 的二次校验,防止外部注入或 legacy 数据脱漏。
|
||||
if getattr(stmt, "speaker", "user") != "user":
|
||||
continue
|
||||
inp = self._convert_to_triplet_input(stmt, ctx)
|
||||
tasks.append(self.triplet_step.run(inp))
|
||||
task_meta.append((dialog.id, stmt.statement_id))
|
||||
@@ -541,6 +587,24 @@ class NewExtractionOrchestrator:
|
||||
triplet_map[dialog_id][stmt_id] = self.triplet_step.get_default_output()
|
||||
else:
|
||||
triplet_map[dialog_id][stmt_id] = result
|
||||
if self.progress_callback:
|
||||
await self.progress_callback(
|
||||
"extract_triplet_result",
|
||||
f"statement {stmt_id} 提取完成",
|
||||
{
|
||||
"statement_id": stmt_id,
|
||||
"triplet_count": len(result.triplets),
|
||||
"entity_count": len(result.entities),
|
||||
"triplets": [
|
||||
{
|
||||
"subject_name": t.subject_name,
|
||||
"predicate": t.predicate,
|
||||
"object_name": t.object_name,
|
||||
}
|
||||
for t in result.triplets[:5]
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
return triplet_map
|
||||
|
||||
@@ -842,6 +906,8 @@ class NewExtractionOrchestrator:
|
||||
temporal_info=_TEMPORAL_MAP.get(stmt_out.temporal_type, TemporalInfo.ATEMPORAL),
|
||||
# relevence_info=RelevenceInfo.RELEVANT if stmt_out.relevance == "RELEVANT" else RelevenceInfo.IRRELEVANT,
|
||||
temporal_validity=TemporalValidityRange(valid_at=valid_at, invalid_at=invalid_at),
|
||||
has_unsolved_reference=stmt_out.has_unsolved_reference,
|
||||
has_emotional_state=stmt_out.has_emotional_state,
|
||||
triplet_extraction_info=triplet_info,
|
||||
statement_embedding=stmt_embedding,
|
||||
**emotion_kwargs,
|
||||
|
||||
@@ -250,6 +250,7 @@ async def build_graph_nodes_and_edges(
|
||||
entity_idx=entity.entity_idx,
|
||||
statement_id=statement.id,
|
||||
entity_type=getattr(entity, "type", "unknown"),
|
||||
type_description=getattr(entity, "type_description", ""),
|
||||
description=getattr(entity, "description", ""),
|
||||
example=getattr(entity, "example", ""),
|
||||
connect_strength=(
|
||||
@@ -296,6 +297,7 @@ async def build_graph_nodes_and_edges(
|
||||
source=subject_entity_id,
|
||||
target=object_entity_id,
|
||||
relation_type=triplet.predicate,
|
||||
relation_type_description=getattr(triplet, "predicate_description", ""),
|
||||
statement=statement.statement,
|
||||
source_statement_id=statement.id,
|
||||
end_user_id=dialog_data.end_user_id,
|
||||
|
||||
@@ -46,6 +46,7 @@ class StatementStepOutput(BaseModel):
|
||||
temporal_type: str # STATIC / DYNAMIC / ATEMPORAL
|
||||
# relevance: str # RELEVANT / IRRELEVANT
|
||||
speaker: str # "user" / "assistant"
|
||||
has_emotional_state: bool = False # Whether statement reflects user's emotional state
|
||||
valid_at: str # ISO 8601 or "NULL"
|
||||
invalid_at: str # ISO 8601 or "NULL"
|
||||
has_unsolved_reference: bool = False # Whether the statement has unresolved references
|
||||
@@ -72,6 +73,7 @@ class EntityItem(BaseModel):
|
||||
entity_idx: int
|
||||
name: str
|
||||
type: str
|
||||
type_description: str = ""
|
||||
description: str
|
||||
is_explicit_memory: bool = False
|
||||
|
||||
@@ -82,6 +84,7 @@ class TripletItem(BaseModel):
|
||||
subject_name: str
|
||||
subject_id: int
|
||||
predicate: str
|
||||
predicate_description: str = ""
|
||||
object_name: str
|
||||
object_id: int
|
||||
|
||||
|
||||
@@ -34,6 +34,10 @@ class _ExtractedStatement(BaseModel):
|
||||
statement_type: str = Field(..., description="FACT / OPINION / OTHER")
|
||||
temporal_type: str = Field(..., description="STATIC / DYNAMIC / ATEMPORAL")
|
||||
# relevance: str = Field("RELEVANT", description="RELEVANT / IRRELEVANT")
|
||||
has_emotional_state: bool = Field(
|
||||
False,
|
||||
description="Whether the statement reflects user's emotional state",
|
||||
)
|
||||
valid_at: str = Field("NULL", description="ISO 8601 or NULL")
|
||||
invalid_at: str = Field("NULL", description="ISO 8601 or NULL")
|
||||
has_unsolved_reference: bool = Field(False, description="Whether the statement has unresolved references")
|
||||
@@ -155,6 +159,7 @@ class StatementExtractionStep(ExtractionStep[StatementStepInput, List[StatementS
|
||||
temporal_type=stmt.temporal_type.strip().upper(),
|
||||
# relevance=stmt.relevance.strip().upper(),
|
||||
speaker="user", # default; orchestrator overrides from chunk metadata
|
||||
has_emotional_state=getattr(stmt, "has_emotional_state", False),
|
||||
valid_at=stmt.valid_at or "NULL",
|
||||
invalid_at=stmt.invalid_at or "NULL",
|
||||
has_unsolved_reference=getattr(stmt, "has_unsolved_reference", False),
|
||||
|
||||
@@ -112,6 +112,7 @@ class TripletExtractionStep(ExtractionStep[TripletStepInput, TripletStepOutput])
|
||||
subject_name=t.subject_name,
|
||||
subject_id=t.subject_id,
|
||||
predicate=t.predicate,
|
||||
predicate_description=getattr(t, "predicate_description", ""),
|
||||
object_name=t.object_name,
|
||||
object_id=t.object_id,
|
||||
)
|
||||
@@ -123,6 +124,7 @@ class TripletExtractionStep(ExtractionStep[TripletStepInput, TripletStepOutput])
|
||||
entity_idx=e.entity_idx,
|
||||
name=e.name,
|
||||
type=e.type,
|
||||
type_description=getattr(e, "type_description", ""),
|
||||
description=e.description,
|
||||
is_explicit_memory=getattr(e, "is_explicit_memory", False),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user