refactor(memory): add PilotWritePipeline and enrich extraction schema
- Add dedicated PilotWritePipeline (statement → triplet → graph_build → layer-1 dedup, no Neo4j write) - Add type_description/predicate_description fields across entity and triplet models, Cypher queries, and graph builders - Refactor data_pruning with LRU cache and snapshot support; skip assistant chunks in extraction - Remove strict Predicate enum whitelist; support statement_text alias in legacy extractor - Wire PipelineSnapshot through preprocessing and emotion extraction for debug tracing - Add PILOT_RUN_USE_REFACTORED_PIPELINE env toggle for pipeline selection
This commit is contained in:
@@ -272,6 +272,12 @@ class Settings:
|
||||
|
||||
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
|
||||
MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory")
|
||||
# Pilot run pipeline switch:
|
||||
# true -> use refactored PilotWritePipeline
|
||||
# false -> use legacy ExtractionOrchestrator pipeline
|
||||
PILOT_RUN_USE_REFACTORED_PIPELINE: bool = (
|
||||
os.getenv("PILOT_RUN_USE_REFACTORED_PIPELINE", "true").lower() == "true"
|
||||
)
|
||||
|
||||
# Tool Management Configuration
|
||||
TOOL_CONFIG_DIR: str = os.getenv("TOOL_CONFIG_DIR", "app/core/tools")
|
||||
|
||||
@@ -9,7 +9,8 @@ async def get_chunked_dialogs(
|
||||
end_user_id: str = "group_1",
|
||||
messages: list = None,
|
||||
ref_id: str = "",
|
||||
config_id: str = None
|
||||
config_id: str = None,
|
||||
snapshot=None,
|
||||
) -> List[DialogData]:
|
||||
"""Generate chunks from structured messages using the specified chunker strategy.
|
||||
|
||||
@@ -19,6 +20,7 @@ async def get_chunked_dialogs(
|
||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||
ref_id: Reference identifier
|
||||
config_id: Configuration ID for processing (used to load pruning config)
|
||||
snapshot: Optional PipelineSnapshot instance for saving pruning output
|
||||
|
||||
Returns:
|
||||
List of DialogData objects with generated chunks
|
||||
@@ -93,7 +95,7 @@ async def get_chunked_dialogs(
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
|
||||
# 执行剪枝 - 使用 prune_dataset 支持消息级剪枝
|
||||
pruner = SemanticPruner(config=pruning_config, llm_client=llm_client)
|
||||
pruner = SemanticPruner(config=pruning_config, llm_client=llm_client, snapshot=snapshot)
|
||||
original_msg_count = len(dialog_data.context.msgs)
|
||||
|
||||
# 使用 prune_dataset 而不是 prune_dialog
|
||||
|
||||
@@ -184,7 +184,8 @@ async def write(
|
||||
"entities": [
|
||||
{
|
||||
"entity_idx": e.entity_idx, "name": e.name,
|
||||
"type": e.type, "description": e.description,
|
||||
"type": e.type, "type_description": getattr(e, "type_description", ""),
|
||||
"description": e.description,
|
||||
"is_explicit_memory": getattr(e, "is_explicit_memory", False),
|
||||
}
|
||||
for e in s.triplet_extraction_info.entities
|
||||
@@ -193,6 +194,7 @@ async def write(
|
||||
{
|
||||
"subject_name": t.subject_name, "subject_id": t.subject_id,
|
||||
"predicate": t.predicate,
|
||||
"predicate_description": getattr(t, "predicate_description", ""),
|
||||
"object_name": t.object_name, "object_id": t.object_id,
|
||||
}
|
||||
for t in s.triplet_extraction_info.triplets
|
||||
@@ -206,13 +208,13 @@ async def write(
|
||||
"chunk_nodes_count": len(all_chunk_nodes),
|
||||
"statement_nodes_count": len(all_statement_nodes),
|
||||
"entity_nodes": [
|
||||
{"id": e.id, "name": e.name, "entity_type": e.entity_type, "description": e.description}
|
||||
{"id": e.id, "name": e.name, "entity_type": e.entity_type, "type_description": e.type_description, "description": e.description}
|
||||
for e in all_entity_nodes
|
||||
],
|
||||
"entity_entity_edges": [
|
||||
{
|
||||
"source": e.source, "target": e.target,
|
||||
"relation_type": e.relation_type, "statement": e.statement,
|
||||
"relation_type": e.relation_type, "relation_type_description": e.relation_type_description, "statement": e.statement,
|
||||
}
|
||||
for e in all_entity_entity_edges
|
||||
],
|
||||
|
||||
@@ -162,6 +162,7 @@ class EntityEntityEdge(Edge):
|
||||
invalid_at: Optional end date of temporal validity
|
||||
"""
|
||||
relation_type: str = Field(..., description="Relation type as defined in ontology")
|
||||
relation_type_description: str = Field(default="", description="Chinese definition of the relation type from ontology")
|
||||
relation_value: Optional[str] = Field(None, description="Value of the relation")
|
||||
statement: str = Field(..., description='The statement of the edge.')
|
||||
source_statement_id: str = Field(..., description="Statement where this relationship was extracted")
|
||||
@@ -413,6 +414,7 @@ class ExtractedEntityNode(Node):
|
||||
entity_idx: int = Field(..., description="Unique identifier for the entity")
|
||||
statement_id: str = Field(..., description="Statement this entity was extracted from")
|
||||
entity_type: str = Field(..., description="Type of the entity")
|
||||
type_description: str = Field(default="", description="Chinese definition of the entity type from ontology")
|
||||
description: str = Field(..., description="Entity description")
|
||||
example: str = Field(
|
||||
default="",
|
||||
|
||||
@@ -96,6 +96,10 @@ class Statement(BaseModel):
|
||||
emotion_target: Optional[str] = Field(None, description="Emotion target: person or object name")
|
||||
# Reference resolution
|
||||
has_unsolved_reference: bool = Field(False, description="Whether the statement has unresolved references")
|
||||
has_emotional_state: bool = Field(
|
||||
False,
|
||||
description="Whether the statement reflects user's emotional state",
|
||||
)
|
||||
|
||||
|
||||
class ConversationContext(BaseModel):
|
||||
|
||||
@@ -37,6 +37,7 @@ class Entity(BaseModel):
|
||||
name: str = Field(..., description="Name of the entity")
|
||||
name_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the entity name")
|
||||
type: str = Field(..., description="Type/category of the entity")
|
||||
type_description: str = Field(default="", description="Chinese definition of the entity type from ontology")
|
||||
description: str = Field(..., description="Description of the entity")
|
||||
example: str = Field(
|
||||
default="",
|
||||
@@ -79,6 +80,7 @@ class Triplet(BaseModel):
|
||||
subject_name: str = Field(..., description="Name of the subject entity")
|
||||
subject_id: int = Field(..., description="ID of the subject entity")
|
||||
predicate: str = Field(..., description="Relationship/predicate between subject and object")
|
||||
predicate_description: str = Field(default="", description="Chinese definition of the predicate from ontology")
|
||||
object_name: str = Field(..., description="Name of the object entity")
|
||||
object_id: int = Field(..., description="ID of the object entity")
|
||||
value: Optional[str] = Field(None, description="Additional value or context")
|
||||
|
||||
@@ -14,13 +14,31 @@ def __getattr__(name):
|
||||
WritePipeline,
|
||||
WriteResult,
|
||||
)
|
||||
|
||||
_exports = {
|
||||
"WritePipeline": WritePipeline,
|
||||
"ExtractionResult": ExtractionResult,
|
||||
"WriteResult": WriteResult,
|
||||
}
|
||||
return _exports[name]
|
||||
if name in ("PilotWritePipeline", "PilotWriteResult"):
|
||||
from app.core.memory.pipelines.pilot_write_pipeline import (
|
||||
PilotWritePipeline,
|
||||
PilotWriteResult,
|
||||
)
|
||||
|
||||
_exports = {
|
||||
"PilotWritePipeline": PilotWritePipeline,
|
||||
"PilotWriteResult": PilotWriteResult,
|
||||
}
|
||||
return _exports[name]
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
|
||||
__all__ = ["WritePipeline", "ExtractionResult", "WriteResult"]
|
||||
__all__ = [
|
||||
"WritePipeline",
|
||||
"ExtractionResult",
|
||||
"WriteResult",
|
||||
"PilotWritePipeline",
|
||||
"PilotWriteResult",
|
||||
]
|
||||
|
||||
108
api/app/core/memory/pipelines/pilot_write_pipeline.py
Normal file
108
api/app/core/memory/pipelines/pilot_write_pipeline.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""PilotWritePipeline — 试运行专用萃取流水线。
|
||||
|
||||
职责边界:
|
||||
- 只执行“萃取相关”链路:statement -> triplet -> graph_build -> 第一层去重消歧
|
||||
- 不负责 Neo4j 写入、聚类、摘要、缓存更新
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
||||
|
||||
from app.core.memory.models.message_models import DialogData
|
||||
from app.core.memory.models.variate_config import ExtractionPipelineConfig
|
||||
from app.core.memory.storage_services.extraction_engine.dedup_step import (
|
||||
DedupResult,
|
||||
run_dedup,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.steps.extraction_pipeline_orchestrator import (
|
||||
NewExtractionOrchestrator,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.steps.graph_build_step import (
|
||||
GraphBuildResult,
|
||||
build_graph_nodes_and_edges,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PilotWriteResult:
|
||||
"""试运行流水线输出。"""
|
||||
|
||||
dialog_data_list: List[DialogData]
|
||||
graph: GraphBuildResult
|
||||
dedup: DedupResult
|
||||
|
||||
@property
|
||||
def stats(self) -> Dict[str, int]:
|
||||
return {
|
||||
"chunk_count": len(self.graph.chunk_nodes),
|
||||
"statement_count": len(self.graph.statement_nodes),
|
||||
"entity_count_before_dedup": len(self.graph.entity_nodes),
|
||||
"entity_count_after_dedup": len(self.dedup.entity_nodes),
|
||||
"relation_count_before_dedup": len(self.graph.entity_entity_edges),
|
||||
"relation_count_after_dedup": len(self.dedup.entity_entity_edges),
|
||||
}
|
||||
|
||||
|
||||
class PilotWritePipeline:
|
||||
"""重构后试运行专用流水线。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_client: Any,
|
||||
embedder_client: Any,
|
||||
pipeline_config: ExtractionPipelineConfig,
|
||||
embedding_id: Optional[str],
|
||||
language: str = "zh",
|
||||
ontology_types: Any = None,
|
||||
progress_callback: Optional[
|
||||
Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]
|
||||
] = None,
|
||||
) -> None:
|
||||
self.llm_client = llm_client
|
||||
self.embedder_client = embedder_client
|
||||
self.pipeline_config = pipeline_config
|
||||
self.embedding_id = embedding_id
|
||||
self.language = language
|
||||
self.ontology_types = ontology_types
|
||||
self.progress_callback = progress_callback
|
||||
|
||||
async def run(self, dialog_data_list: List[DialogData]) -> PilotWriteResult:
|
||||
"""执行试运行萃取链路。"""
|
||||
orchestrator = NewExtractionOrchestrator(
|
||||
llm_client=self.llm_client,
|
||||
embedder_client=self.embedder_client,
|
||||
config=self.pipeline_config,
|
||||
embedding_id=self.embedding_id,
|
||||
ontology_types=self.ontology_types,
|
||||
language=self.language,
|
||||
is_pilot_run=True,
|
||||
progress_callback=self.progress_callback,
|
||||
)
|
||||
extracted_dialogs = await orchestrator.run(dialog_data_list)
|
||||
|
||||
graph = await build_graph_nodes_and_edges(
|
||||
dialog_data_list=extracted_dialogs,
|
||||
embedder_client=self.embedder_client,
|
||||
progress_callback=self.progress_callback,
|
||||
)
|
||||
|
||||
dedup = await run_dedup(
|
||||
entity_nodes=graph.entity_nodes,
|
||||
statement_entity_edges=graph.stmt_entity_edges,
|
||||
entity_entity_edges=graph.entity_entity_edges,
|
||||
dialog_data_list=extracted_dialogs,
|
||||
pipeline_config=self.pipeline_config,
|
||||
connector=None, # pilot: no layer-2 db dedup
|
||||
llm_client=self.llm_client,
|
||||
is_pilot_run=True,
|
||||
progress_callback=self.progress_callback,
|
||||
)
|
||||
|
||||
return PilotWriteResult(
|
||||
dialog_data_list=extracted_dialogs,
|
||||
graph=graph,
|
||||
dedup=dedup,
|
||||
)
|
||||
|
||||
@@ -180,7 +180,11 @@ class WritePipeline:
|
||||
self._init_clients()
|
||||
self._init_neo4j_connector()
|
||||
|
||||
# Step 1: 预处理 - 消息分块 + AI消息语义剪枝(暂无实现)
|
||||
# 初始化 Snapshot(提前创建,供预处理阶段的剪枝使用)
|
||||
from app.core.memory.utils.debug.pipeline_snapshot import PipelineSnapshot
|
||||
self._snapshot = PipelineSnapshot("new")
|
||||
|
||||
# Step 1: 预处理 - 消息分块 + AI消息语义剪枝
|
||||
step_start = time.time()
|
||||
chunked_dialogs = await self._preprocess(messages, ref_id)
|
||||
chunks_count = sum(len(d.chunks) for d in chunked_dialogs)
|
||||
@@ -220,7 +224,7 @@ class WritePipeline:
|
||||
)
|
||||
|
||||
# Step 3.5: 异步情绪提取(fire-and-forget,需在 _store 之后确保 Statement 节点已存在)
|
||||
self._extract_emotion(getattr(self, "_emotion_statements", []))
|
||||
await self._extract_emotion(getattr(self, "_emotion_statements", []))
|
||||
|
||||
# Step 4: 聚类 - 增量更新社区(异步,不阻塞)
|
||||
step_start = time.time()
|
||||
@@ -266,7 +270,7 @@ class WritePipeline:
|
||||
|
||||
async def _preprocess(self, messages: List[dict], ref_id: str) -> List[DialogData]:
|
||||
"""
|
||||
预处理:消息校验 → AI消息语义剪枝(暂未实现) → 对话分块。
|
||||
预处理:消息校验 → AI消息语义剪枝 → 对话分块。
|
||||
|
||||
委托给 get_chunked_dialogs(),保持现有预处理逻辑不变。
|
||||
get_dialogs.py 内部已包含:
|
||||
@@ -276,12 +280,15 @@ class WritePipeline:
|
||||
"""
|
||||
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
|
||||
|
||||
snapshot = getattr(self, "_snapshot", None)
|
||||
|
||||
return await get_chunked_dialogs(
|
||||
chunker_strategy=self.memory_config.chunker_strategy,
|
||||
end_user_id=self.end_user_id,
|
||||
messages=messages,
|
||||
ref_id=ref_id,
|
||||
config_id=str(self.memory_config.config_id),
|
||||
snapshot=snapshot,
|
||||
)
|
||||
|
||||
# ──────────────────────────────────────────────
|
||||
@@ -321,7 +328,9 @@ class WritePipeline:
|
||||
pipeline_config = get_pipeline_config(self.memory_config)
|
||||
ontology_types = self._load_ontology_types()
|
||||
|
||||
snapshot = PipelineSnapshot("new")
|
||||
# 复用 run() 中已创建的 snapshot(剪枝阶段已使用同一实例)
|
||||
snapshot = getattr(self, "_snapshot", None) or PipelineSnapshot("new")
|
||||
self._snapshot = snapshot
|
||||
|
||||
# ── 新编排器:LLM 萃取 + 数据赋值 ──
|
||||
new_orchestrator = NewExtractionOrchestrator(
|
||||
@@ -589,11 +598,15 @@ class WritePipeline:
|
||||
# fire-and-forget 提交 Celery 任务,不阻塞主流程
|
||||
# ──────────────────────────────────────────────
|
||||
|
||||
def _extract_emotion(self, emotion_statements: list) -> None:
|
||||
async def _extract_emotion(self, emotion_statements: list) -> None:
|
||||
"""提交异步情绪提取 Celery 任务。
|
||||
|
||||
从编排器收集的 user statement 列表中提取情绪,
|
||||
异步回写到 Neo4j Statement 节点。失败不影响主流程。
|
||||
|
||||
在 PIPELINE_SNAPSHOT_ENABLED=true 时,会把当前运行的快照目录路径
|
||||
通过 snapshot_dir 透传给 Celery 任务;worker 端在完成 LLM 抽取后,
|
||||
将结果落盘到 <snapshot_dir>/4_emotion_outputs.json,避免主进程重复调用 LLM。
|
||||
"""
|
||||
if not emotion_statements:
|
||||
return
|
||||
@@ -607,6 +620,14 @@ class WritePipeline:
|
||||
logger.warning("[Emotion] 无法提交情绪提取任务:llm_model_id 为空")
|
||||
return
|
||||
|
||||
# 快照目录:仅在 PIPELINE_SNAPSHOT_ENABLED=true 时非空,供 worker 端落盘
|
||||
snapshot = getattr(self, "_snapshot", None)
|
||||
snapshot_dir = (
|
||||
snapshot.directory
|
||||
if snapshot is not None and getattr(snapshot, "enabled", False)
|
||||
else None
|
||||
)
|
||||
|
||||
try:
|
||||
from app.celery_app import celery_app
|
||||
|
||||
@@ -616,12 +637,14 @@ class WritePipeline:
|
||||
"statements": emotion_statements,
|
||||
"llm_model_id": llm_model_id,
|
||||
"language": self.language,
|
||||
"snapshot_dir": snapshot_dir,
|
||||
},
|
||||
)
|
||||
logger.info(
|
||||
f"[Emotion] 异步情绪提取任务已提交 - "
|
||||
f"task_id={result.id}, "
|
||||
f"statement_count={len(emotion_statements)}, "
|
||||
f"snapshot_dir={snapshot_dir}, "
|
||||
f"source=async"
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -629,6 +652,7 @@ class WritePipeline:
|
||||
f"[Emotion] 提交情绪提取任务失败(不影响主流程): {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# ──────────────────────────────────────────────
|
||||
# Step 5: 摘要
|
||||
# (+ entity_description)+ meta_data部分在此提取
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1264,6 +1264,7 @@ class ExtractionOrchestrator:
|
||||
entity_idx=entity.entity_idx, # 使用实体自己的 entity_idx
|
||||
statement_id=statement.id, # 添加必需的 statement_id 字段
|
||||
entity_type=getattr(entity, 'type', 'unknown'), # 使用 type 而不是 entity_type
|
||||
type_description=getattr(entity, 'type_description', ''),
|
||||
description=getattr(entity, 'description', ''), # 添加必需的 description 字段
|
||||
example=getattr(entity, 'example', ''), # 新增:传递示例字段
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
@@ -1306,6 +1307,7 @@ class ExtractionOrchestrator:
|
||||
source=subject_entity_id,
|
||||
target=object_entity_id,
|
||||
relation_type=triplet.predicate,
|
||||
relation_type_description=getattr(triplet, 'predicate_description', ''),
|
||||
statement=statement.statement,
|
||||
source_statement_id=statement.id,
|
||||
end_user_id=dialog_data.end_user_id,
|
||||
|
||||
@@ -12,16 +12,21 @@ from app.core.memory.utils.data.ontology import (
|
||||
TemporalInfo,
|
||||
)
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_statement_extraction_prompt
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import AliasChoices, BaseModel, Field, field_validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ExtractedStatement(BaseModel):
|
||||
"""Schema for extracted statement from LLM"""
|
||||
statement: str = Field(..., description="The extracted statement text")
|
||||
statement: str = Field(
|
||||
...,
|
||||
validation_alias=AliasChoices("statement", "statement_text"),
|
||||
description="The extracted statement text",
|
||||
)
|
||||
statement_type: str = Field(..., description="FACT, OPINION, SUGGESTION or PREDICTION")
|
||||
temporal_type: str = Field(..., description="STATIC, DYNAMIC, ATEMPORAL")
|
||||
relevence: str = Field(..., description="RELEVANT or IRRELEVANT")
|
||||
# New prompt no longer outputs relevence; keep backward-compatible default.
|
||||
relevence: str = Field("RELEVANT", description="RELEVANT or IRRELEVANT")
|
||||
has_unsolved_reference: bool = Field(False, description="Whether the statement has unresolved references")
|
||||
|
||||
class StatementExtractionResponse(BaseModel):
|
||||
@@ -41,7 +46,7 @@ class StatementExtractionResponse(BaseModel):
|
||||
valid_statements = []
|
||||
filtered_count = 0
|
||||
for i, stmt in enumerate(v):
|
||||
if isinstance(stmt, dict) and stmt.get('statement'):
|
||||
if isinstance(stmt, dict) and (stmt.get("statement") or stmt.get("statement_text")):
|
||||
valid_statements.append(stmt)
|
||||
elif isinstance(stmt, dict):
|
||||
# Log which statement was filtered
|
||||
@@ -96,6 +101,11 @@ class StatementExtractor:
|
||||
"""
|
||||
chunk_content = chunk.content
|
||||
chunk_speaker = self._get_speaker_from_chunk(chunk)
|
||||
logger.info(
|
||||
"[LegacyStatementExtractor] chunk_id=%s content_len=%d",
|
||||
getattr(chunk, "id", ""),
|
||||
len(chunk_content or ""),
|
||||
)
|
||||
|
||||
if not chunk_content or len(chunk_content.strip()) < 5:
|
||||
logger.warning(f"Chunk {chunk.id} content too short or empty, skipping")
|
||||
@@ -108,7 +118,18 @@ class StatementExtractor:
|
||||
granularity=self.config.statement_granularity,
|
||||
include_dialogue_context=self.config.include_dialogue_context,
|
||||
dialogue_content=dialogue_content,
|
||||
max_dialogue_chars=self.config.max_dialogue_context_chars
|
||||
max_dialogue_chars=self.config.max_dialogue_context_chars,
|
||||
input_json={
|
||||
"chunk_id": getattr(chunk, "id", ""),
|
||||
"end_user_id": end_user_id or "",
|
||||
"target_content": chunk_content,
|
||||
"target_message_date": datetime.now().isoformat(),
|
||||
"supporting_context": {
|
||||
"msgs": [
|
||||
{"role": "context", "msg": dialogue_content}
|
||||
] if dialogue_content else []
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
# Simple system message
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import List, Dict, Optional
|
||||
from app.core.logging_config import get_memory_logger
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_triplet_extraction_prompt
|
||||
from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate # 引入枚举 Predicate 白名单过滤
|
||||
from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS
|
||||
from app.core.memory.models.triplet_models import TripletExtractionResponse
|
||||
from app.core.memory.models.message_models import DialogData, Statement
|
||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
||||
@@ -73,15 +73,9 @@ class TripletExtractor:
|
||||
try:
|
||||
# Get structured response from LLM
|
||||
response = await self.llm_client.response_structured(messages, TripletExtractionResponse)
|
||||
# Filter triplets to only allowed predicates from ontology
|
||||
# 这里过滤掉了不在 Predicate 枚举中的谓语 但是容易造成谓语太严格,有点语句的谓语没有在枚举中,就被判断为弱关系
|
||||
allowed_predicates = {p.value for p in Predicate}
|
||||
filtered_triplets = [t for t in response.triplets if getattr(t, "predicate", "") in allowed_predicates]
|
||||
# 仅保留predicate ∈ Predicate 的三元组,其余全部剔除
|
||||
|
||||
# Create new triplets with statement_id set during creation
|
||||
updated_triplets = []
|
||||
for triplet in filtered_triplets: # 仅保留 predicate ∈ Predicate 的三元组
|
||||
for triplet in response.triplets:
|
||||
updated_triplet = triplet.model_copy(update={"statement_id": statement.id})
|
||||
updated_triplets.append(updated_triplet)
|
||||
|
||||
|
||||
@@ -300,6 +300,33 @@ class NewExtractionOrchestrator:
|
||||
"embedding_output": None,
|
||||
}
|
||||
|
||||
if self.progress_callback:
|
||||
statements_count = sum(
|
||||
len(stmts)
|
||||
for chunk_stmts in all_stmt_results.values()
|
||||
for stmts in chunk_stmts.values()
|
||||
)
|
||||
entities_count = sum(
|
||||
len(t_out.entities)
|
||||
for stmt_triplets in all_triplet_results.values()
|
||||
for t_out in stmt_triplets.values()
|
||||
)
|
||||
triplets_count = sum(
|
||||
len(t_out.triplets)
|
||||
for stmt_triplets in all_triplet_results.values()
|
||||
for t_out in stmt_triplets.values()
|
||||
)
|
||||
await self.progress_callback(
|
||||
"knowledge_extraction_complete",
|
||||
"知识抽取完成",
|
||||
{
|
||||
"entities_count": entities_count,
|
||||
"statements_count": statements_count,
|
||||
"temporal_ranges_count": 0,
|
||||
"triplets_count": triplets_count,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info("Pilot extraction complete")
|
||||
return dialog_data_list
|
||||
|
||||
@@ -467,6 +494,11 @@ class NewExtractionOrchestrator:
|
||||
else None
|
||||
)
|
||||
for chunk in dialog.chunks:
|
||||
# 仅对 speaker="user" 的 chunk 进行陈述句抽取;assistant 内容交给
|
||||
# 上游预处理/剪枝阶段处理,避免浪费 LLM 调用。
|
||||
chunk_speaker = getattr(chunk, "speaker", "user")
|
||||
if chunk_speaker != "user":
|
||||
continue
|
||||
inp = StatementStepInput(
|
||||
chunk_id=chunk.id,
|
||||
end_user_id=dialog.end_user_id,
|
||||
@@ -478,7 +510,7 @@ class NewExtractionOrchestrator:
|
||||
)
|
||||
tasks.append(self.statement_step.run(inp))
|
||||
task_meta.append(
|
||||
(dialog.id, chunk.id, getattr(chunk, "speaker", "user"), ctx)
|
||||
(dialog.id, chunk.id, chunk_speaker, ctx)
|
||||
)
|
||||
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
@@ -499,6 +531,15 @@ class NewExtractionOrchestrator:
|
||||
for s in stmts:
|
||||
s.speaker = speaker
|
||||
stmt_map[dialog_id][chunk_id] = stmts
|
||||
if self.progress_callback:
|
||||
# Frontend consumes knowledge_extraction_result with data.statement.
|
||||
# Emit one event per statement to keep payload contract simple.
|
||||
for s in stmts:
|
||||
await self.progress_callback(
|
||||
"knowledge_extraction_result",
|
||||
"知识抽取中",
|
||||
{"statement": s.statement_text},
|
||||
)
|
||||
|
||||
return stmt_map
|
||||
|
||||
@@ -520,6 +561,11 @@ class NewExtractionOrchestrator:
|
||||
chunk_stmts = all_stmt_results.get(dialog.id, {})
|
||||
for _chunk_id, stmts in chunk_stmts.items():
|
||||
for stmt in stmts:
|
||||
# 防御性过滤:三元组抽取仅针对 user statement。
|
||||
# 上游 _extract_all_statements 已过滤 chunk.speaker,此处再做
|
||||
# 一次 statement.speaker 的二次校验,防止外部注入或 legacy 数据脱漏。
|
||||
if getattr(stmt, "speaker", "user") != "user":
|
||||
continue
|
||||
inp = self._convert_to_triplet_input(stmt, ctx)
|
||||
tasks.append(self.triplet_step.run(inp))
|
||||
task_meta.append((dialog.id, stmt.statement_id))
|
||||
@@ -541,6 +587,24 @@ class NewExtractionOrchestrator:
|
||||
triplet_map[dialog_id][stmt_id] = self.triplet_step.get_default_output()
|
||||
else:
|
||||
triplet_map[dialog_id][stmt_id] = result
|
||||
if self.progress_callback:
|
||||
await self.progress_callback(
|
||||
"extract_triplet_result",
|
||||
f"statement {stmt_id} 提取完成",
|
||||
{
|
||||
"statement_id": stmt_id,
|
||||
"triplet_count": len(result.triplets),
|
||||
"entity_count": len(result.entities),
|
||||
"triplets": [
|
||||
{
|
||||
"subject_name": t.subject_name,
|
||||
"predicate": t.predicate,
|
||||
"object_name": t.object_name,
|
||||
}
|
||||
for t in result.triplets[:5]
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
return triplet_map
|
||||
|
||||
@@ -842,6 +906,8 @@ class NewExtractionOrchestrator:
|
||||
temporal_info=_TEMPORAL_MAP.get(stmt_out.temporal_type, TemporalInfo.ATEMPORAL),
|
||||
# relevence_info=RelevenceInfo.RELEVANT if stmt_out.relevance == "RELEVANT" else RelevenceInfo.IRRELEVANT,
|
||||
temporal_validity=TemporalValidityRange(valid_at=valid_at, invalid_at=invalid_at),
|
||||
has_unsolved_reference=stmt_out.has_unsolved_reference,
|
||||
has_emotional_state=stmt_out.has_emotional_state,
|
||||
triplet_extraction_info=triplet_info,
|
||||
statement_embedding=stmt_embedding,
|
||||
**emotion_kwargs,
|
||||
|
||||
@@ -250,6 +250,7 @@ async def build_graph_nodes_and_edges(
|
||||
entity_idx=entity.entity_idx,
|
||||
statement_id=statement.id,
|
||||
entity_type=getattr(entity, "type", "unknown"),
|
||||
type_description=getattr(entity, "type_description", ""),
|
||||
description=getattr(entity, "description", ""),
|
||||
example=getattr(entity, "example", ""),
|
||||
connect_strength=(
|
||||
@@ -296,6 +297,7 @@ async def build_graph_nodes_and_edges(
|
||||
source=subject_entity_id,
|
||||
target=object_entity_id,
|
||||
relation_type=triplet.predicate,
|
||||
relation_type_description=getattr(triplet, "predicate_description", ""),
|
||||
statement=statement.statement,
|
||||
source_statement_id=statement.id,
|
||||
end_user_id=dialog_data.end_user_id,
|
||||
|
||||
@@ -46,6 +46,7 @@ class StatementStepOutput(BaseModel):
|
||||
temporal_type: str # STATIC / DYNAMIC / ATEMPORAL
|
||||
# relevance: str # RELEVANT / IRRELEVANT
|
||||
speaker: str # "user" / "assistant"
|
||||
has_emotional_state: bool = False # Whether statement reflects user's emotional state
|
||||
valid_at: str # ISO 8601 or "NULL"
|
||||
invalid_at: str # ISO 8601 or "NULL"
|
||||
has_unsolved_reference: bool = False # Whether the statement has unresolved references
|
||||
@@ -72,6 +73,7 @@ class EntityItem(BaseModel):
|
||||
entity_idx: int
|
||||
name: str
|
||||
type: str
|
||||
type_description: str = ""
|
||||
description: str
|
||||
is_explicit_memory: bool = False
|
||||
|
||||
@@ -82,6 +84,7 @@ class TripletItem(BaseModel):
|
||||
subject_name: str
|
||||
subject_id: int
|
||||
predicate: str
|
||||
predicate_description: str = ""
|
||||
object_name: str
|
||||
object_id: int
|
||||
|
||||
|
||||
@@ -34,6 +34,10 @@ class _ExtractedStatement(BaseModel):
|
||||
statement_type: str = Field(..., description="FACT / OPINION / OTHER")
|
||||
temporal_type: str = Field(..., description="STATIC / DYNAMIC / ATEMPORAL")
|
||||
# relevance: str = Field("RELEVANT", description="RELEVANT / IRRELEVANT")
|
||||
has_emotional_state: bool = Field(
|
||||
False,
|
||||
description="Whether the statement reflects user's emotional state",
|
||||
)
|
||||
valid_at: str = Field("NULL", description="ISO 8601 or NULL")
|
||||
invalid_at: str = Field("NULL", description="ISO 8601 or NULL")
|
||||
has_unsolved_reference: bool = Field(False, description="Whether the statement has unresolved references")
|
||||
@@ -155,6 +159,7 @@ class StatementExtractionStep(ExtractionStep[StatementStepInput, List[StatementS
|
||||
temporal_type=stmt.temporal_type.strip().upper(),
|
||||
# relevance=stmt.relevance.strip().upper(),
|
||||
speaker="user", # default; orchestrator overrides from chunk metadata
|
||||
has_emotional_state=getattr(stmt, "has_emotional_state", False),
|
||||
valid_at=stmt.valid_at or "NULL",
|
||||
invalid_at=stmt.invalid_at or "NULL",
|
||||
has_unsolved_reference=getattr(stmt, "has_unsolved_reference", False),
|
||||
|
||||
@@ -112,6 +112,7 @@ class TripletExtractionStep(ExtractionStep[TripletStepInput, TripletStepOutput])
|
||||
subject_name=t.subject_name,
|
||||
subject_id=t.subject_id,
|
||||
predicate=t.predicate,
|
||||
predicate_description=getattr(t, "predicate_description", ""),
|
||||
object_name=t.object_name,
|
||||
object_id=t.object_id,
|
||||
)
|
||||
@@ -123,6 +124,7 @@ class TripletExtractionStep(ExtractionStep[TripletStepInput, TripletStepOutput])
|
||||
entity_idx=e.entity_idx,
|
||||
name=e.name,
|
||||
type=e.type,
|
||||
type_description=getattr(e, "type_description", ""),
|
||||
description=e.description,
|
||||
is_explicit_memory=getattr(e, "is_explicit_memory", False),
|
||||
)
|
||||
|
||||
@@ -92,6 +92,7 @@ SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity
|
||||
THEN entity.expired_at ELSE e.expired_at END,
|
||||
e.entity_idx = CASE WHEN e.entity_idx IS NULL OR e.entity_idx = 0 THEN entity.entity_idx ELSE e.entity_idx END,
|
||||
e.entity_type = CASE WHEN entity.entity_type IS NOT NULL AND entity.entity_type <> '' THEN entity.entity_type ELSE e.entity_type END,
|
||||
e.type_description = CASE WHEN entity.type_description IS NOT NULL AND entity.type_description <> '' THEN entity.type_description ELSE coalesce(e.type_description, '') END,
|
||||
e.description = CASE
|
||||
WHEN entity.description IS NOT NULL AND entity.description <> ''
|
||||
AND (e.description IS NULL OR size(e.description) = 0 OR size(entity.description) > size(e.description))
|
||||
@@ -147,6 +148,7 @@ MATCH (object:ExtractedEntity {id: rel.target_id, end_user_id: rel.end_user_id})
|
||||
// Avoid duplicate edges across runs for the same endpoints
|
||||
MERGE (subject)-[r:EXTRACTED_RELATIONSHIP]->(object)
|
||||
SET r.predicate = rel.predicate,
|
||||
r.predicate_description = rel.predicate_description,
|
||||
r.statement_id = rel.statement_id,
|
||||
r.value = rel.value,
|
||||
r.statement = rel.statement,
|
||||
|
||||
@@ -44,6 +44,7 @@ async def save_entities_and_relationships(
|
||||
'source_id': edge.source,
|
||||
'target_id': edge.target,
|
||||
'predicate': edge.relation_type,
|
||||
'predicate_description': edge.relation_type_description,
|
||||
'statement_id': edge.source_statement_id,
|
||||
'value': edge.relation_value,
|
||||
'statement': edge.statement,
|
||||
@@ -297,6 +298,7 @@ async def save_dialog_and_statements_to_neo4j(
|
||||
'source_id': edge.source,
|
||||
'target_id': edge.target,
|
||||
'predicate': edge.relation_type,
|
||||
'predicate_description': edge.relation_type_description,
|
||||
'statement_id': edge.source_statement_id,
|
||||
'value': edge.relation_value,
|
||||
'statement': edge.statement,
|
||||
|
||||
@@ -441,21 +441,12 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
with open(result_path, "r", encoding="utf-8") as rf:
|
||||
extracted_result = json.load(rf)
|
||||
|
||||
# 步骤 6: 计算本体覆盖率并合并到结果中
|
||||
# 步骤 6: 组装结果(试运行不做额外覆盖率后处理)
|
||||
result_data = {
|
||||
"config_id": cid,
|
||||
"time_log": os.path.join(project_root, "logs", "time.log"),
|
||||
"extracted_result": extracted_result,
|
||||
}
|
||||
try:
|
||||
ontology_coverage = await self._compute_ontology_coverage(
|
||||
extracted_result=extracted_result,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
if ontology_coverage:
|
||||
result_data["ontology_coverage"] = ontology_coverage
|
||||
except Exception as cov_err:
|
||||
logger.warning(f"[PILOT_RUN_STREAM] Ontology coverage computation failed: {cov_err}", exc_info=True)
|
||||
|
||||
yield format_sse_message("result", result_data)
|
||||
|
||||
@@ -479,100 +470,6 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
"time": int(time.time() * 1000)
|
||||
})
|
||||
|
||||
async def _compute_ontology_coverage(
|
||||
self,
|
||||
extracted_result: Dict[str, Any],
|
||||
memory_config,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""根据提取结果中的实体类型,与场景/通用本体类型做互斥分类统计。
|
||||
|
||||
分类规则(互斥):场景类型优先 > 通用类型 > 未匹配
|
||||
确保: 场景实体数 + 通用实体数 + 未匹配数 = 总实体数
|
||||
|
||||
Returns:
|
||||
包含三部分统计的字典,或 None(无实体数据时)
|
||||
"""
|
||||
core_entities = extracted_result.get("core_entities", [])
|
||||
if not core_entities:
|
||||
return None
|
||||
|
||||
# 1. 加载场景本体类型集合
|
||||
scene_ontology_types: set = set()
|
||||
try:
|
||||
from app.repositories.ontology_class_repository import OntologyClassRepository
|
||||
|
||||
if memory_config.scene_id:
|
||||
class_repo = OntologyClassRepository(self.db)
|
||||
ontology_classes = class_repo.get_classes_by_scene(memory_config.scene_id)
|
||||
scene_ontology_types = {oc.class_name for oc in ontology_classes}
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load scene ontology types: {e}")
|
||||
|
||||
# 2. 加载通用本体类型集合
|
||||
general_ontology_types: set = set()
|
||||
try:
|
||||
from app.core.memory.ontology_services.ontology_type_loader import (
|
||||
get_general_ontology_registry,
|
||||
is_general_ontology_enabled,
|
||||
)
|
||||
|
||||
if is_general_ontology_enabled():
|
||||
registry = get_general_ontology_registry()
|
||||
if registry:
|
||||
general_ontology_types = set(registry.types.keys())
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load general ontology types: {e}")
|
||||
|
||||
# 3. 互斥分类:场景优先 > 通用 > 未匹配
|
||||
scene_distribution: list = []
|
||||
general_distribution: list = []
|
||||
unmatched_distribution: list = []
|
||||
scene_total = 0
|
||||
general_total = 0
|
||||
unmatched_total = 0
|
||||
|
||||
for item in core_entities:
|
||||
entity_type = item.get("type", "")
|
||||
count = item.get("count", 0)
|
||||
|
||||
if entity_type in scene_ontology_types:
|
||||
scene_distribution.append({"type": entity_type, "count": count})
|
||||
scene_total += count
|
||||
elif entity_type in general_ontology_types:
|
||||
general_distribution.append({"type": entity_type, "count": count})
|
||||
general_total += count
|
||||
else:
|
||||
unmatched_distribution.append({"type": entity_type, "count": count})
|
||||
unmatched_total += count
|
||||
|
||||
# 按数量降序排列
|
||||
scene_distribution.sort(key=lambda x: x["count"], reverse=True)
|
||||
general_distribution.sort(key=lambda x: x["count"], reverse=True)
|
||||
unmatched_distribution.sort(key=lambda x: x["count"], reverse=True)
|
||||
|
||||
total_entities = scene_total + general_total + unmatched_total
|
||||
|
||||
return {
|
||||
"scene_type_distribution": {
|
||||
"type_count": len(scene_distribution),
|
||||
"entity_total": scene_total,
|
||||
"types": scene_distribution,
|
||||
},
|
||||
"general_type_distribution": {
|
||||
"type_count": len(general_distribution),
|
||||
"entity_total": general_total,
|
||||
"types": general_distribution,
|
||||
},
|
||||
"unmatched": {
|
||||
"type_count": len(unmatched_distribution),
|
||||
"entity_total": unmatched_total,
|
||||
"types": unmatched_distribution,
|
||||
},
|
||||
"total_entities": total_entities,
|
||||
"time": int(time.time() * 1000),
|
||||
}
|
||||
|
||||
|
||||
# -------------------- Neo4j Search & Analytics (fused from data_search_service.py) --------------------
|
||||
# Ensure env for connector (e.g., NEO4J_PASSWORD)
|
||||
|
||||
|
||||
@@ -10,7 +10,9 @@ import time
|
||||
from datetime import datetime
|
||||
from typing import Awaitable, Callable, Optional
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_memory_logger, log_time
|
||||
from app.core.memory.pipelines.pilot_write_pipeline import PilotWritePipeline
|
||||
from app.core.memory.models.message_models import (
|
||||
ConversationContext,
|
||||
ConversationMessage,
|
||||
@@ -20,9 +22,11 @@ from app.core.memory.storage_services.extraction_engine.extraction_orchestrator
|
||||
ExtractionOrchestrator,
|
||||
get_chunked_dialogs_from_preprocessed,
|
||||
)
|
||||
from app.core.memory.utils.config.config_utils import (
|
||||
get_pipeline_config,
|
||||
from app.core.memory.storage_services.extraction_engine.pipeline_help import (
|
||||
_write_extracted_result_summary,
|
||||
export_test_input_doc,
|
||||
)
|
||||
from app.core.memory.utils.config.config_utils import get_pipeline_config
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
@@ -31,6 +35,42 @@ from sqlalchemy.orm import Session
|
||||
logger = get_memory_logger(__name__)
|
||||
|
||||
|
||||
def _save_triplets_from_dialogs(dialog_data_list: list[DialogData], output_path: str) -> None:
|
||||
"""Write triplet/entity text report compatible with pipeline_help parsers."""
|
||||
all_triplets = []
|
||||
all_entities = []
|
||||
|
||||
for dialog in dialog_data_list:
|
||||
for chunk in getattr(dialog, "chunks", []) or []:
|
||||
for statement in getattr(chunk, "statements", []) or []:
|
||||
triplet_info = getattr(statement, "triplet_extraction_info", None)
|
||||
if not triplet_info:
|
||||
continue
|
||||
all_triplets.extend(getattr(triplet_info, "triplets", []) or [])
|
||||
all_entities.extend(getattr(triplet_info, "entities", []) or [])
|
||||
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
f.write(f"=== EXTRACTED TRIPLETS ({len(all_triplets)} total) ===\n\n")
|
||||
for i, triplet in enumerate(all_triplets, 1):
|
||||
f.write(f"Triplet {i}:\n")
|
||||
f.write(f" Subject: {triplet.subject_name} (ID: {triplet.subject_id})\n")
|
||||
f.write(f" Predicate: {triplet.predicate}\n")
|
||||
f.write(f" Object: {triplet.object_name} (ID: {triplet.object_id})\n")
|
||||
value = getattr(triplet, "value", None)
|
||||
if value:
|
||||
f.write(f" Value: {value}\n")
|
||||
f.write("\n")
|
||||
|
||||
f.write(f"\n=== EXTRACTED ENTITIES ({len(all_entities)} total) ===\n\n")
|
||||
for i, entity in enumerate(all_entities, 1):
|
||||
f.write(f"Entity {i}:\n")
|
||||
f.write(f" ID: {entity.entity_idx}\n")
|
||||
f.write(f" Name: {entity.name}\n")
|
||||
f.write(f" Type: {entity.type}\n")
|
||||
f.write(f" Description: {entity.description}\n")
|
||||
f.write("\n")
|
||||
|
||||
|
||||
async def run_pilot_extraction(
|
||||
memory_config: MemoryConfig,
|
||||
dialogue_text: str,
|
||||
@@ -58,7 +98,6 @@ async def run_pilot_extraction(
|
||||
f.write(f"\n=== Pilot Run Started: {timestamp} ===\n")
|
||||
|
||||
pipeline_start = time.time()
|
||||
neo4j_connector = None
|
||||
|
||||
try:
|
||||
# 步骤 1: 初始化客户端
|
||||
@@ -69,8 +108,6 @@ async def run_pilot_extraction(
|
||||
llm_client = client_factory.get_llm_client(str(memory_config.llm_model_id))
|
||||
embedder_client = client_factory.get_embedder_client(str(memory_config.embedding_model_id))
|
||||
|
||||
neo4j_connector = Neo4jConnector()
|
||||
|
||||
log_time("Client Initialization", time.time() - step_start, log_file)
|
||||
|
||||
# 步骤 2: 解析对话文本
|
||||
@@ -242,15 +279,17 @@ async def run_pilot_extraction(
|
||||
|
||||
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
|
||||
|
||||
# 步骤 3: 初始化流水线编排器
|
||||
logger.info("Initializing extraction orchestrator...")
|
||||
step_start = time.time()
|
||||
|
||||
config = get_pipeline_config(memory_config)
|
||||
# 步骤 3: 初始化并选择试运行流水线(环境变量可切换)
|
||||
use_refactored = bool(settings.PILOT_RUN_USE_REFACTORED_PIPELINE)
|
||||
logger.info(
|
||||
f"Pipeline config loaded: enable_llm_dedup_blockwise={config.deduplication.enable_llm_dedup_blockwise}, "
|
||||
f"enable_llm_disambiguation={config.deduplication.enable_llm_disambiguation}"
|
||||
"Selecting pilot pipeline by env: PILOT_RUN_USE_REFACTORED_PIPELINE=%s",
|
||||
use_refactored,
|
||||
)
|
||||
logger.info(
|
||||
"Initializing %s pilot pipeline...",
|
||||
"refactored" if use_refactored else "legacy",
|
||||
)
|
||||
step_start = time.time()
|
||||
|
||||
# 加载本体类型(如果配置了 scene_id),支持通用类型回退
|
||||
ontology_types = None
|
||||
@@ -266,100 +305,105 @@ async def run_pilot_extraction(
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load ontology types: {e}", exc_info=True)
|
||||
|
||||
orchestrator = ExtractionOrchestrator(
|
||||
llm_client=llm_client,
|
||||
embedder_client=embedder_client,
|
||||
connector=neo4j_connector,
|
||||
config=config,
|
||||
progress_callback=progress_callback,
|
||||
embedding_id=str(memory_config.embedding_model_id),
|
||||
language=language,
|
||||
ontology_types=ontology_types,
|
||||
)
|
||||
if use_refactored:
|
||||
pilot_pipeline = PilotWritePipeline(
|
||||
llm_client=llm_client,
|
||||
embedder_client=embedder_client,
|
||||
pipeline_config=get_pipeline_config(memory_config),
|
||||
progress_callback=progress_callback,
|
||||
embedding_id=str(memory_config.embedding_model_id),
|
||||
language=language,
|
||||
ontology_types=ontology_types,
|
||||
)
|
||||
log_time("Pilot Pipeline Initialization", time.time() - step_start, log_file)
|
||||
|
||||
log_time("Orchestrator Initialization", time.time() - step_start, log_file)
|
||||
# 步骤 4a: 执行重构后试运行短链路
|
||||
# statement -> triplet -> graph_build -> 第一层去重消歧(结束)
|
||||
logger.info("Running refactored pilot extraction short pipeline...")
|
||||
step_start = time.time()
|
||||
|
||||
# 步骤 4: 执行知识提取流水线
|
||||
logger.info("Running extraction pipeline...")
|
||||
step_start = time.time()
|
||||
if progress_callback:
|
||||
await progress_callback("knowledge_extraction", "正在知识抽取...")
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("knowledge_extraction", "正在知识抽取...")
|
||||
pilot_result = await pilot_pipeline.run(chunked_dialogs)
|
||||
dialog_data_list = pilot_result.dialog_data_list
|
||||
graph = pilot_result.graph
|
||||
chunk_nodes = graph.chunk_nodes
|
||||
export_entity_nodes = graph.entity_nodes
|
||||
export_stmt_entity_edges = graph.stmt_entity_edges
|
||||
export_entity_edges = graph.entity_entity_edges
|
||||
else:
|
||||
# 步骤 4b: 执行旧试运行流水线
|
||||
logger.info("Running legacy pilot extraction pipeline...")
|
||||
step_start = time.time()
|
||||
|
||||
extraction_result = await orchestrator.run(
|
||||
dialog_data_list=chunked_dialogs,
|
||||
is_pilot_run=True,
|
||||
)
|
||||
if progress_callback:
|
||||
await progress_callback("knowledge_extraction", "正在知识抽取...")
|
||||
|
||||
# 解包 extraction_result tuple (与 main.py 保持一致)
|
||||
(
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
entity_nodes,
|
||||
_,
|
||||
statement_chunk_edges,
|
||||
statement_entity_edges,
|
||||
entity_edges,
|
||||
_,
|
||||
_
|
||||
) = extraction_result
|
||||
neo4j_connector = Neo4jConnector()
|
||||
try:
|
||||
legacy_orchestrator = ExtractionOrchestrator(
|
||||
llm_client=llm_client,
|
||||
embedder_client=embedder_client,
|
||||
connector=neo4j_connector,
|
||||
config=get_pipeline_config(memory_config),
|
||||
progress_callback=progress_callback,
|
||||
embedding_id=str(memory_config.embedding_model_id),
|
||||
language=language,
|
||||
ontology_types=ontology_types,
|
||||
)
|
||||
extraction_result = await legacy_orchestrator.run(
|
||||
dialog_data_list=chunked_dialogs,
|
||||
is_pilot_run=True,
|
||||
)
|
||||
(
|
||||
_dialogue_nodes,
|
||||
chunk_nodes,
|
||||
_statement_nodes,
|
||||
entity_nodes,
|
||||
_perceptual_nodes,
|
||||
_statement_chunk_edges,
|
||||
statement_entity_edges,
|
||||
entity_edges,
|
||||
_perceptual_edges,
|
||||
_last_created_at,
|
||||
) = extraction_result
|
||||
dialog_data_list = chunked_dialogs
|
||||
export_entity_nodes = entity_nodes
|
||||
export_stmt_entity_edges = statement_entity_edges
|
||||
export_entity_edges = entity_edges
|
||||
finally:
|
||||
try:
|
||||
await neo4j_connector.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
log_time("Extraction Pipeline", time.time() - step_start, log_file)
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("generating_results", "正在生成结果...")
|
||||
|
||||
# 步骤 5: 生成记忆摘要(与 main.py 保持一致)
|
||||
try:
|
||||
logger.info("Generating memory summaries...")
|
||||
step_start = time.time()
|
||||
# 步骤 5: 输出试运行结果文件(保持 /pilot_run 返回契约)
|
||||
settings.ensure_memory_output_dir()
|
||||
export_test_input_doc(
|
||||
entity_nodes=export_entity_nodes,
|
||||
statement_entity_edges=export_stmt_entity_edges,
|
||||
entity_entity_edges=export_entity_edges,
|
||||
)
|
||||
_save_triplets_from_dialogs(
|
||||
dialog_data_list=dialog_data_list,
|
||||
output_path=settings.get_memory_output_path("extracted_triplets.txt"),
|
||||
)
|
||||
_write_extracted_result_summary(
|
||||
chunk_nodes=chunk_nodes,
|
||||
pipeline_output_dir=settings.get_memory_output_path(),
|
||||
)
|
||||
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
|
||||
memory_summary_generation,
|
||||
)
|
||||
|
||||
summaries = await memory_summary_generation(
|
||||
chunked_dialogs,
|
||||
llm_client=llm_client,
|
||||
embedder_client=embedder_client,
|
||||
language=language,
|
||||
)
|
||||
|
||||
log_time("Memory Summary Generation", time.time() - step_start, log_file)
|
||||
except Exception as e:
|
||||
logger.error(f"Memory summary step failed: {e}", exc_info=True)
|
||||
|
||||
logger.info("Pilot run completed: Skipping Neo4j save")
|
||||
|
||||
# 将提取统计写入 Redis,按 workspace_id 存储
|
||||
try:
|
||||
from app.cache.memory.activity_stats_cache import ActivityStatsCache
|
||||
|
||||
stats_to_cache = {
|
||||
"chunk_count": len(chunk_nodes) if chunk_nodes else 0,
|
||||
"statements_count": len(statement_nodes) if statement_nodes else 0,
|
||||
"triplet_entities_count": len(entity_nodes) if entity_nodes else 0,
|
||||
"triplet_relations_count": len(entity_edges) if entity_edges else 0,
|
||||
"temporal_count": 0, # temporal 数据在日志中,此处暂置0
|
||||
}
|
||||
await ActivityStatsCache.set_activity_stats(
|
||||
workspace_id=str(memory_config.workspace_id),
|
||||
stats=stats_to_cache,
|
||||
)
|
||||
logger.info(f"[PILOT_RUN] 活动统计已写入 Redis: workspace_id={memory_config.workspace_id}")
|
||||
except Exception as cache_err:
|
||||
logger.warning(f"[PILOT_RUN] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
|
||||
logger.info("Pilot run completed: stop after layer-1 dedup (no layer-2 / no Neo4j write)")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Pilot run failed: {e}", exc_info=True)
|
||||
raise
|
||||
finally:
|
||||
if neo4j_connector:
|
||||
try:
|
||||
await neo4j_connector.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
total_time = time.time() - pipeline_start
|
||||
log_time("TOTAL PILOT RUN TIME", total_time, log_file)
|
||||
|
||||
@@ -1382,6 +1382,7 @@ def extract_emotion_batch_task(
|
||||
llm_model_id: str,
|
||||
language: str = "zh",
|
||||
emotion_config: Optional[Dict[str, Any]] = None,
|
||||
snapshot_dir: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Celery task: batch emotion extraction + Neo4j backfill.
|
||||
|
||||
@@ -1395,6 +1396,10 @@ def extract_emotion_batch_task(
|
||||
language: Language code ("zh" / "en").
|
||||
emotion_config: Optional dict with emotion step config overrides
|
||||
(emotion_extract_keywords, emotion_enable_subject).
|
||||
snapshot_dir: Optional absolute path of the current run's snapshot directory.
|
||||
When provided (only in debug mode), emotion outputs will be
|
||||
dumped to <snapshot_dir>/4_emotion_outputs.json for offline
|
||||
comparison between the legacy / new pipelines.
|
||||
"""
|
||||
task_id = self.request.id
|
||||
total = len(statements)
|
||||
@@ -1445,6 +1450,8 @@ def extract_emotion_batch_task(
|
||||
extracted = 0
|
||||
failed = 0
|
||||
update_items = []
|
||||
# 快照用:收集每条 statement 的 EmotionStepOutput(仅当 snapshot_dir 非空时使用)
|
||||
snapshot_outputs: Dict[str, Any] = {} if snapshot_dir else None # type: ignore[assignment]
|
||||
|
||||
async def _extract_one(stmt_dict: Dict[str, str]):
|
||||
nonlocal extracted, failed
|
||||
@@ -1461,6 +1468,8 @@ def extract_emotion_batch_task(
|
||||
"emotion_intensity": result.emotion_intensity,
|
||||
"emotion_keywords": result.emotion_keywords,
|
||||
})
|
||||
if snapshot_outputs is not None:
|
||||
snapshot_outputs[stmt_dict["statement_id"]] = result.model_dump()
|
||||
extracted += 1
|
||||
logger.debug(
|
||||
f"[Emotion] 单条提取完成: stmt={stmt_dict['statement_id']}, "
|
||||
@@ -1468,12 +1477,33 @@ def extract_emotion_batch_task(
|
||||
)
|
||||
except Exception as e:
|
||||
failed += 1
|
||||
if snapshot_outputs is not None:
|
||||
snapshot_outputs[stmt_dict["statement_id"]] = {"error": str(e)}
|
||||
logger.warning(
|
||||
f"[Emotion] 单条提取失败 stmt={stmt_dict['statement_id']}: {e}"
|
||||
)
|
||||
|
||||
await asyncio.gather(*[_extract_one(s) for s in statements])
|
||||
|
||||
# 快照落盘(worker 端):不影响 Neo4j 写入流程,失败只打日志
|
||||
if snapshot_outputs is not None:
|
||||
try:
|
||||
from pathlib import Path as _Path
|
||||
import json as _json
|
||||
|
||||
_dir = _Path(snapshot_dir)
|
||||
_dir.mkdir(parents=True, exist_ok=True)
|
||||
_path = _dir / "4_emotion_outputs.json"
|
||||
with open(_path, "w", encoding="utf-8") as _f:
|
||||
_json.dump(snapshot_outputs, _f, ensure_ascii=False, indent=2, default=str)
|
||||
logger.info(
|
||||
f"[Emotion][Snapshot] 已落盘 {len(snapshot_outputs)} 条情绪结果 → {_path}"
|
||||
)
|
||||
except Exception as _e:
|
||||
logger.warning(
|
||||
f"[Emotion][Snapshot] 快照落盘失败(不影响主流程): {_e}"
|
||||
)
|
||||
|
||||
# Batch update Neo4j via write transaction
|
||||
if update_items:
|
||||
connector = Neo4jConnector()
|
||||
|
||||
Reference in New Issue
Block a user