refactor(memory): add PilotWritePipeline and enrich extraction schema

- Add dedicated PilotWritePipeline (statement → triplet → graph_build → layer-1 dedup, no Neo4j write)
- Add type_description/predicate_description fields across entity and triplet models, Cypher queries, and graph builders
- Refactor data_pruning with LRU cache and snapshot support; skip assistant chunks in extraction
- Remove strict Predicate enum whitelist; support statement_text alias in legacy extractor
- Wire PipelineSnapshot through preprocessing and emotion extraction for debug tracing
- Add PILOT_RUN_USE_REFACTORED_PIPELINE env toggle for pipeline selection

This commit is contained in:
@@ -272,6 +272,12 @@ class Settings:
|
|||||||
|
|
||||||
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
|
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
|
||||||
MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory")
|
MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory")
|
||||||
|
# Pilot run pipeline switch:
|
||||||
|
# true -> use refactored PilotWritePipeline
|
||||||
|
# false -> use legacy ExtractionOrchestrator pipeline
|
||||||
|
PILOT_RUN_USE_REFACTORED_PIPELINE: bool = (
|
||||||
|
os.getenv("PILOT_RUN_USE_REFACTORED_PIPELINE", "true").lower() == "true"
|
||||||
|
)
|
||||||
|
|
||||||
# Tool Management Configuration
|
# Tool Management Configuration
|
||||||
TOOL_CONFIG_DIR: str = os.getenv("TOOL_CONFIG_DIR", "app/core/tools")
|
TOOL_CONFIG_DIR: str = os.getenv("TOOL_CONFIG_DIR", "app/core/tools")
|
||||||
|
|||||||
@@ -9,7 +9,8 @@ async def get_chunked_dialogs(
|
|||||||
end_user_id: str = "group_1",
|
end_user_id: str = "group_1",
|
||||||
messages: list = None,
|
messages: list = None,
|
||||||
ref_id: str = "",
|
ref_id: str = "",
|
||||||
config_id: str = None
|
config_id: str = None,
|
||||||
|
snapshot=None,
|
||||||
) -> List[DialogData]:
|
) -> List[DialogData]:
|
||||||
"""Generate chunks from structured messages using the specified chunker strategy.
|
"""Generate chunks from structured messages using the specified chunker strategy.
|
||||||
|
|
||||||
@@ -19,6 +20,7 @@ async def get_chunked_dialogs(
|
|||||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||||
ref_id: Reference identifier
|
ref_id: Reference identifier
|
||||||
config_id: Configuration ID for processing (used to load pruning config)
|
config_id: Configuration ID for processing (used to load pruning config)
|
||||||
|
snapshot: Optional PipelineSnapshot instance for saving pruning output
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of DialogData objects with generated chunks
|
List of DialogData objects with generated chunks
|
||||||
@@ -93,7 +95,7 @@ async def get_chunked_dialogs(
|
|||||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||||
|
|
||||||
# 执行剪枝 - 使用 prune_dataset 支持消息级剪枝
|
# 执行剪枝 - 使用 prune_dataset 支持消息级剪枝
|
||||||
pruner = SemanticPruner(config=pruning_config, llm_client=llm_client)
|
pruner = SemanticPruner(config=pruning_config, llm_client=llm_client, snapshot=snapshot)
|
||||||
original_msg_count = len(dialog_data.context.msgs)
|
original_msg_count = len(dialog_data.context.msgs)
|
||||||
|
|
||||||
# 使用 prune_dataset 而不是 prune_dialog
|
# 使用 prune_dataset 而不是 prune_dialog
|
||||||
|
|||||||
@@ -184,7 +184,8 @@ async def write(
|
|||||||
"entities": [
|
"entities": [
|
||||||
{
|
{
|
||||||
"entity_idx": e.entity_idx, "name": e.name,
|
"entity_idx": e.entity_idx, "name": e.name,
|
||||||
"type": e.type, "description": e.description,
|
"type": e.type, "type_description": getattr(e, "type_description", ""),
|
||||||
|
"description": e.description,
|
||||||
"is_explicit_memory": getattr(e, "is_explicit_memory", False),
|
"is_explicit_memory": getattr(e, "is_explicit_memory", False),
|
||||||
}
|
}
|
||||||
for e in s.triplet_extraction_info.entities
|
for e in s.triplet_extraction_info.entities
|
||||||
@@ -193,6 +194,7 @@ async def write(
|
|||||||
{
|
{
|
||||||
"subject_name": t.subject_name, "subject_id": t.subject_id,
|
"subject_name": t.subject_name, "subject_id": t.subject_id,
|
||||||
"predicate": t.predicate,
|
"predicate": t.predicate,
|
||||||
|
"predicate_description": getattr(t, "predicate_description", ""),
|
||||||
"object_name": t.object_name, "object_id": t.object_id,
|
"object_name": t.object_name, "object_id": t.object_id,
|
||||||
}
|
}
|
||||||
for t in s.triplet_extraction_info.triplets
|
for t in s.triplet_extraction_info.triplets
|
||||||
@@ -206,13 +208,13 @@ async def write(
|
|||||||
"chunk_nodes_count": len(all_chunk_nodes),
|
"chunk_nodes_count": len(all_chunk_nodes),
|
||||||
"statement_nodes_count": len(all_statement_nodes),
|
"statement_nodes_count": len(all_statement_nodes),
|
||||||
"entity_nodes": [
|
"entity_nodes": [
|
||||||
{"id": e.id, "name": e.name, "entity_type": e.entity_type, "description": e.description}
|
{"id": e.id, "name": e.name, "entity_type": e.entity_type, "type_description": e.type_description, "description": e.description}
|
||||||
for e in all_entity_nodes
|
for e in all_entity_nodes
|
||||||
],
|
],
|
||||||
"entity_entity_edges": [
|
"entity_entity_edges": [
|
||||||
{
|
{
|
||||||
"source": e.source, "target": e.target,
|
"source": e.source, "target": e.target,
|
||||||
"relation_type": e.relation_type, "statement": e.statement,
|
"relation_type": e.relation_type, "relation_type_description": e.relation_type_description, "statement": e.statement,
|
||||||
}
|
}
|
||||||
for e in all_entity_entity_edges
|
for e in all_entity_entity_edges
|
||||||
],
|
],
|
||||||
|
|||||||
@@ -162,6 +162,7 @@ class EntityEntityEdge(Edge):
|
|||||||
invalid_at: Optional end date of temporal validity
|
invalid_at: Optional end date of temporal validity
|
||||||
"""
|
"""
|
||||||
relation_type: str = Field(..., description="Relation type as defined in ontology")
|
relation_type: str = Field(..., description="Relation type as defined in ontology")
|
||||||
|
relation_type_description: str = Field(default="", description="Chinese definition of the relation type from ontology")
|
||||||
relation_value: Optional[str] = Field(None, description="Value of the relation")
|
relation_value: Optional[str] = Field(None, description="Value of the relation")
|
||||||
statement: str = Field(..., description='The statement of the edge.')
|
statement: str = Field(..., description='The statement of the edge.')
|
||||||
source_statement_id: str = Field(..., description="Statement where this relationship was extracted")
|
source_statement_id: str = Field(..., description="Statement where this relationship was extracted")
|
||||||
@@ -413,6 +414,7 @@ class ExtractedEntityNode(Node):
|
|||||||
entity_idx: int = Field(..., description="Unique identifier for the entity")
|
entity_idx: int = Field(..., description="Unique identifier for the entity")
|
||||||
statement_id: str = Field(..., description="Statement this entity was extracted from")
|
statement_id: str = Field(..., description="Statement this entity was extracted from")
|
||||||
entity_type: str = Field(..., description="Type of the entity")
|
entity_type: str = Field(..., description="Type of the entity")
|
||||||
|
type_description: str = Field(default="", description="Chinese definition of the entity type from ontology")
|
||||||
description: str = Field(..., description="Entity description")
|
description: str = Field(..., description="Entity description")
|
||||||
example: str = Field(
|
example: str = Field(
|
||||||
default="",
|
default="",
|
||||||
|
|||||||
@@ -96,6 +96,10 @@ class Statement(BaseModel):
|
|||||||
emotion_target: Optional[str] = Field(None, description="Emotion target: person or object name")
|
emotion_target: Optional[str] = Field(None, description="Emotion target: person or object name")
|
||||||
# Reference resolution
|
# Reference resolution
|
||||||
has_unsolved_reference: bool = Field(False, description="Whether the statement has unresolved references")
|
has_unsolved_reference: bool = Field(False, description="Whether the statement has unresolved references")
|
||||||
|
has_emotional_state: bool = Field(
|
||||||
|
False,
|
||||||
|
description="Whether the statement reflects user's emotional state",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ConversationContext(BaseModel):
|
class ConversationContext(BaseModel):
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ class Entity(BaseModel):
|
|||||||
name: str = Field(..., description="Name of the entity")
|
name: str = Field(..., description="Name of the entity")
|
||||||
name_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the entity name")
|
name_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the entity name")
|
||||||
type: str = Field(..., description="Type/category of the entity")
|
type: str = Field(..., description="Type/category of the entity")
|
||||||
|
type_description: str = Field(default="", description="Chinese definition of the entity type from ontology")
|
||||||
description: str = Field(..., description="Description of the entity")
|
description: str = Field(..., description="Description of the entity")
|
||||||
example: str = Field(
|
example: str = Field(
|
||||||
default="",
|
default="",
|
||||||
@@ -79,6 +80,7 @@ class Triplet(BaseModel):
|
|||||||
subject_name: str = Field(..., description="Name of the subject entity")
|
subject_name: str = Field(..., description="Name of the subject entity")
|
||||||
subject_id: int = Field(..., description="ID of the subject entity")
|
subject_id: int = Field(..., description="ID of the subject entity")
|
||||||
predicate: str = Field(..., description="Relationship/predicate between subject and object")
|
predicate: str = Field(..., description="Relationship/predicate between subject and object")
|
||||||
|
predicate_description: str = Field(default="", description="Chinese definition of the predicate from ontology")
|
||||||
object_name: str = Field(..., description="Name of the object entity")
|
object_name: str = Field(..., description="Name of the object entity")
|
||||||
object_id: int = Field(..., description="ID of the object entity")
|
object_id: int = Field(..., description="ID of the object entity")
|
||||||
value: Optional[str] = Field(None, description="Additional value or context")
|
value: Optional[str] = Field(None, description="Additional value or context")
|
||||||
|
|||||||
@@ -14,13 +14,31 @@ def __getattr__(name):
|
|||||||
WritePipeline,
|
WritePipeline,
|
||||||
WriteResult,
|
WriteResult,
|
||||||
)
|
)
|
||||||
|
|
||||||
_exports = {
|
_exports = {
|
||||||
"WritePipeline": WritePipeline,
|
"WritePipeline": WritePipeline,
|
||||||
"ExtractionResult": ExtractionResult,
|
"ExtractionResult": ExtractionResult,
|
||||||
"WriteResult": WriteResult,
|
"WriteResult": WriteResult,
|
||||||
}
|
}
|
||||||
return _exports[name]
|
return _exports[name]
|
||||||
|
if name in ("PilotWritePipeline", "PilotWriteResult"):
|
||||||
|
from app.core.memory.pipelines.pilot_write_pipeline import (
|
||||||
|
PilotWritePipeline,
|
||||||
|
PilotWriteResult,
|
||||||
|
)
|
||||||
|
|
||||||
|
_exports = {
|
||||||
|
"PilotWritePipeline": PilotWritePipeline,
|
||||||
|
"PilotWriteResult": PilotWriteResult,
|
||||||
|
}
|
||||||
|
return _exports[name]
|
||||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["WritePipeline", "ExtractionResult", "WriteResult"]
|
__all__ = [
|
||||||
|
"WritePipeline",
|
||||||
|
"ExtractionResult",
|
||||||
|
"WriteResult",
|
||||||
|
"PilotWritePipeline",
|
||||||
|
"PilotWriteResult",
|
||||||
|
]
|
||||||
|
|||||||
108
api/app/core/memory/pipelines/pilot_write_pipeline.py
Normal file
108
api/app/core/memory/pipelines/pilot_write_pipeline.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
"""PilotWritePipeline — 试运行专用萃取流水线。
|
||||||
|
|
||||||
|
职责边界:
|
||||||
|
- 只执行“萃取相关”链路:statement -> triplet -> graph_build -> 第一层去重消歧
|
||||||
|
- 不负责 Neo4j 写入、聚类、摘要、缓存更新
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
||||||
|
|
||||||
|
from app.core.memory.models.message_models import DialogData
|
||||||
|
from app.core.memory.models.variate_config import ExtractionPipelineConfig
|
||||||
|
from app.core.memory.storage_services.extraction_engine.dedup_step import (
|
||||||
|
DedupResult,
|
||||||
|
run_dedup,
|
||||||
|
)
|
||||||
|
from app.core.memory.storage_services.extraction_engine.steps.extraction_pipeline_orchestrator import (
|
||||||
|
NewExtractionOrchestrator,
|
||||||
|
)
|
||||||
|
from app.core.memory.storage_services.extraction_engine.steps.graph_build_step import (
|
||||||
|
GraphBuildResult,
|
||||||
|
build_graph_nodes_and_edges,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class PilotWriteResult:
    """Output bundle of a single pilot (trial-run) pipeline execution.

    Groups the processed dialogs with the graph built from them and the
    outcome of the dedup step, so the two can be compared side by side.
    """

    # Dialogs carried through the pipeline.
    dialog_data_list: List[DialogData]
    # Graph build output; ``stats`` reads its node/edge lists as the
    # "before dedup" view.
    graph: GraphBuildResult
    # Dedup output; ``stats`` reads its entity nodes and entity-entity edges
    # as the "after dedup" view.
    dedup: DedupResult

    @property
    def stats(self) -> Dict[str, int]:
        """Return element counts before and after deduplication.

        Returns:
            A dict with chunk and statement counts taken from ``graph``, plus
            entity and relation counts taken from both ``graph`` (before
            dedup) and ``dedup`` (after dedup).
        """
        before, after = self.graph, self.dedup
        counts: Dict[str, int] = {}
        counts["chunk_count"] = len(before.chunk_nodes)
        counts["statement_count"] = len(before.statement_nodes)
        counts["entity_count_before_dedup"] = len(before.entity_nodes)
        counts["entity_count_after_dedup"] = len(after.entity_nodes)
        counts["relation_count_before_dedup"] = len(before.entity_entity_edges)
        counts["relation_count_after_dedup"] = len(after.entity_entity_edges)
        return counts
|
||||||
|
|
||||||
|
|
||||||
|
class PilotWritePipeline:
    """Extraction-only pipeline used for pilot (trial) runs after the refactor.

    Chains the extraction orchestrator, graph building and the dedup step.
    The dedup step is invoked with ``connector=None``, i.e. without a
    database connector for layer-2 dedup.
    """

    def __init__(
        self,
        llm_client: Any,
        embedder_client: Any,
        pipeline_config: ExtractionPipelineConfig,
        embedding_id: Optional[str],
        language: str = "zh",
        ontology_types: Any = None,
        progress_callback: Optional[
            Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]
        ] = None,
    ) -> None:
        """Store the collaborators and settings used by :meth:`run`.

        Args:
            llm_client: Client handed to the orchestrator and the dedup step.
            embedder_client: Client handed to the orchestrator and the graph
                builder.
            pipeline_config: Extraction pipeline configuration.
            embedding_id: Embedding identifier forwarded to the orchestrator.
            language: Language code forwarded to the orchestrator.
            ontology_types: Ontology types forwarded to the orchestrator.
            progress_callback: Optional awaitable callback forwarded to every
                stage.
        """
        # Clients
        self.llm_client = llm_client
        self.embedder_client = embedder_client
        # Configuration
        self.pipeline_config = pipeline_config
        self.embedding_id = embedding_id
        self.language = language
        self.ontology_types = ontology_types
        # Progress reporting hook (may be None)
        self.progress_callback = progress_callback

    async def run(self, dialog_data_list: List[DialogData]) -> PilotWriteResult:
        """Execute the pilot extraction chain.

        Args:
            dialog_data_list: Dialogs to run extraction on.

        Returns:
            A ``PilotWriteResult`` holding the extracted dialogs, the built
            graph and the dedup result.
        """
        # Stage 1: extraction via the orchestrator, flagged as a pilot run.
        extractor = NewExtractionOrchestrator(
            llm_client=self.llm_client,
            embedder_client=self.embedder_client,
            config=self.pipeline_config,
            embedding_id=self.embedding_id,
            ontology_types=self.ontology_types,
            language=self.language,
            is_pilot_run=True,
            progress_callback=self.progress_callback,
        )
        extracted = await extractor.run(dialog_data_list)

        # Stage 2: build graph nodes and edges from the extracted dialogs.
        graph_result = await build_graph_nodes_and_edges(
            dialog_data_list=extracted,
            embedder_client=self.embedder_client,
            progress_callback=self.progress_callback,
        )

        # Stage 3: dedup over the freshly built graph.
        dedup_result = await run_dedup(
            entity_nodes=graph_result.entity_nodes,
            statement_entity_edges=graph_result.stmt_entity_edges,
            entity_entity_edges=graph_result.entity_entity_edges,
            dialog_data_list=extracted,
            pipeline_config=self.pipeline_config,
            connector=None,  # pilot: no layer-2 db dedup
            llm_client=self.llm_client,
            is_pilot_run=True,
            progress_callback=self.progress_callback,
        )

        return PilotWriteResult(
            dialog_data_list=extracted,
            graph=graph_result,
            dedup=dedup_result,
        )
|
||||||
|
|
||||||
@@ -180,7 +180,11 @@ class WritePipeline:
|
|||||||
self._init_clients()
|
self._init_clients()
|
||||||
self._init_neo4j_connector()
|
self._init_neo4j_connector()
|
||||||
|
|
||||||
# Step 1: 预处理 - 消息分块 + AI消息语义剪枝(暂无实现)
|
# 初始化 Snapshot(提前创建,供预处理阶段的剪枝使用)
|
||||||
|
from app.core.memory.utils.debug.pipeline_snapshot import PipelineSnapshot
|
||||||
|
self._snapshot = PipelineSnapshot("new")
|
||||||
|
|
||||||
|
# Step 1: 预处理 - 消息分块 + AI消息语义剪枝
|
||||||
step_start = time.time()
|
step_start = time.time()
|
||||||
chunked_dialogs = await self._preprocess(messages, ref_id)
|
chunked_dialogs = await self._preprocess(messages, ref_id)
|
||||||
chunks_count = sum(len(d.chunks) for d in chunked_dialogs)
|
chunks_count = sum(len(d.chunks) for d in chunked_dialogs)
|
||||||
@@ -220,7 +224,7 @@ class WritePipeline:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Step 3.5: 异步情绪提取(fire-and-forget,需在 _store 之后确保 Statement 节点已存在)
|
# Step 3.5: 异步情绪提取(fire-and-forget,需在 _store 之后确保 Statement 节点已存在)
|
||||||
self._extract_emotion(getattr(self, "_emotion_statements", []))
|
await self._extract_emotion(getattr(self, "_emotion_statements", []))
|
||||||
|
|
||||||
# Step 4: 聚类 - 增量更新社区(异步,不阻塞)
|
# Step 4: 聚类 - 增量更新社区(异步,不阻塞)
|
||||||
step_start = time.time()
|
step_start = time.time()
|
||||||
@@ -266,7 +270,7 @@ class WritePipeline:
|
|||||||
|
|
||||||
async def _preprocess(self, messages: List[dict], ref_id: str) -> List[DialogData]:
|
async def _preprocess(self, messages: List[dict], ref_id: str) -> List[DialogData]:
|
||||||
"""
|
"""
|
||||||
预处理:消息校验 → AI消息语义剪枝(暂未实现) → 对话分块。
|
预处理:消息校验 → AI消息语义剪枝 → 对话分块。
|
||||||
|
|
||||||
委托给 get_chunked_dialogs(),保持现有预处理逻辑不变。
|
委托给 get_chunked_dialogs(),保持现有预处理逻辑不变。
|
||||||
get_dialogs.py 内部已包含:
|
get_dialogs.py 内部已包含:
|
||||||
@@ -276,12 +280,15 @@ class WritePipeline:
|
|||||||
"""
|
"""
|
||||||
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
|
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
|
||||||
|
|
||||||
|
snapshot = getattr(self, "_snapshot", None)
|
||||||
|
|
||||||
return await get_chunked_dialogs(
|
return await get_chunked_dialogs(
|
||||||
chunker_strategy=self.memory_config.chunker_strategy,
|
chunker_strategy=self.memory_config.chunker_strategy,
|
||||||
end_user_id=self.end_user_id,
|
end_user_id=self.end_user_id,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
ref_id=ref_id,
|
ref_id=ref_id,
|
||||||
config_id=str(self.memory_config.config_id),
|
config_id=str(self.memory_config.config_id),
|
||||||
|
snapshot=snapshot,
|
||||||
)
|
)
|
||||||
|
|
||||||
# ──────────────────────────────────────────────
|
# ──────────────────────────────────────────────
|
||||||
@@ -321,7 +328,9 @@ class WritePipeline:
|
|||||||
pipeline_config = get_pipeline_config(self.memory_config)
|
pipeline_config = get_pipeline_config(self.memory_config)
|
||||||
ontology_types = self._load_ontology_types()
|
ontology_types = self._load_ontology_types()
|
||||||
|
|
||||||
snapshot = PipelineSnapshot("new")
|
# 复用 run() 中已创建的 snapshot(剪枝阶段已使用同一实例)
|
||||||
|
snapshot = getattr(self, "_snapshot", None) or PipelineSnapshot("new")
|
||||||
|
self._snapshot = snapshot
|
||||||
|
|
||||||
# ── 新编排器:LLM 萃取 + 数据赋值 ──
|
# ── 新编排器:LLM 萃取 + 数据赋值 ──
|
||||||
new_orchestrator = NewExtractionOrchestrator(
|
new_orchestrator = NewExtractionOrchestrator(
|
||||||
@@ -589,11 +598,15 @@ class WritePipeline:
|
|||||||
# fire-and-forget 提交 Celery 任务,不阻塞主流程
|
# fire-and-forget 提交 Celery 任务,不阻塞主流程
|
||||||
# ──────────────────────────────────────────────
|
# ──────────────────────────────────────────────
|
||||||
|
|
||||||
def _extract_emotion(self, emotion_statements: list) -> None:
|
async def _extract_emotion(self, emotion_statements: list) -> None:
|
||||||
"""提交异步情绪提取 Celery 任务。
|
"""提交异步情绪提取 Celery 任务。
|
||||||
|
|
||||||
从编排器收集的 user statement 列表中提取情绪,
|
从编排器收集的 user statement 列表中提取情绪,
|
||||||
异步回写到 Neo4j Statement 节点。失败不影响主流程。
|
异步回写到 Neo4j Statement 节点。失败不影响主流程。
|
||||||
|
|
||||||
|
在 PIPELINE_SNAPSHOT_ENABLED=true 时,会把当前运行的快照目录路径
|
||||||
|
通过 snapshot_dir 透传给 Celery 任务;worker 端在完成 LLM 抽取后,
|
||||||
|
将结果落盘到 <snapshot_dir>/4_emotion_outputs.json,避免主进程重复调用 LLM。
|
||||||
"""
|
"""
|
||||||
if not emotion_statements:
|
if not emotion_statements:
|
||||||
return
|
return
|
||||||
@@ -607,6 +620,14 @@ class WritePipeline:
|
|||||||
logger.warning("[Emotion] 无法提交情绪提取任务:llm_model_id 为空")
|
logger.warning("[Emotion] 无法提交情绪提取任务:llm_model_id 为空")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# 快照目录:仅在 PIPELINE_SNAPSHOT_ENABLED=true 时非空,供 worker 端落盘
|
||||||
|
snapshot = getattr(self, "_snapshot", None)
|
||||||
|
snapshot_dir = (
|
||||||
|
snapshot.directory
|
||||||
|
if snapshot is not None and getattr(snapshot, "enabled", False)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from app.celery_app import celery_app
|
from app.celery_app import celery_app
|
||||||
|
|
||||||
@@ -616,12 +637,14 @@ class WritePipeline:
|
|||||||
"statements": emotion_statements,
|
"statements": emotion_statements,
|
||||||
"llm_model_id": llm_model_id,
|
"llm_model_id": llm_model_id,
|
||||||
"language": self.language,
|
"language": self.language,
|
||||||
|
"snapshot_dir": snapshot_dir,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[Emotion] 异步情绪提取任务已提交 - "
|
f"[Emotion] 异步情绪提取任务已提交 - "
|
||||||
f"task_id={result.id}, "
|
f"task_id={result.id}, "
|
||||||
f"statement_count={len(emotion_statements)}, "
|
f"statement_count={len(emotion_statements)}, "
|
||||||
|
f"snapshot_dir={snapshot_dir}, "
|
||||||
f"source=async"
|
f"source=async"
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -629,6 +652,7 @@ class WritePipeline:
|
|||||||
f"[Emotion] 提交情绪提取任务失败(不影响主流程): {e}",
|
f"[Emotion] 提交情绪提取任务失败(不影响主流程): {e}",
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# ──────────────────────────────────────────────
|
# ──────────────────────────────────────────────
|
||||||
# Step 5: 摘要
|
# Step 5: 摘要
|
||||||
# (+ entity_description)+ meta_data部分在此提取
|
# (+ entity_description)+ meta_data部分在此提取
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1264,6 +1264,7 @@ class ExtractionOrchestrator:
|
|||||||
entity_idx=entity.entity_idx, # 使用实体自己的 entity_idx
|
entity_idx=entity.entity_idx, # 使用实体自己的 entity_idx
|
||||||
statement_id=statement.id, # 添加必需的 statement_id 字段
|
statement_id=statement.id, # 添加必需的 statement_id 字段
|
||||||
entity_type=getattr(entity, 'type', 'unknown'), # 使用 type 而不是 entity_type
|
entity_type=getattr(entity, 'type', 'unknown'), # 使用 type 而不是 entity_type
|
||||||
|
type_description=getattr(entity, 'type_description', ''),
|
||||||
description=getattr(entity, 'description', ''), # 添加必需的 description 字段
|
description=getattr(entity, 'description', ''), # 添加必需的 description 字段
|
||||||
example=getattr(entity, 'example', ''), # 新增:传递示例字段
|
example=getattr(entity, 'example', ''), # 新增:传递示例字段
|
||||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||||
@@ -1306,6 +1307,7 @@ class ExtractionOrchestrator:
|
|||||||
source=subject_entity_id,
|
source=subject_entity_id,
|
||||||
target=object_entity_id,
|
target=object_entity_id,
|
||||||
relation_type=triplet.predicate,
|
relation_type=triplet.predicate,
|
||||||
|
relation_type_description=getattr(triplet, 'predicate_description', ''),
|
||||||
statement=statement.statement,
|
statement=statement.statement,
|
||||||
source_statement_id=statement.id,
|
source_statement_id=statement.id,
|
||||||
end_user_id=dialog_data.end_user_id,
|
end_user_id=dialog_data.end_user_id,
|
||||||
|
|||||||
@@ -12,16 +12,21 @@ from app.core.memory.utils.data.ontology import (
|
|||||||
TemporalInfo,
|
TemporalInfo,
|
||||||
)
|
)
|
||||||
from app.core.memory.utils.prompt.prompt_utils import render_statement_extraction_prompt
|
from app.core.memory.utils.prompt.prompt_utils import render_statement_extraction_prompt
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import AliasChoices, BaseModel, Field, field_validator
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class ExtractedStatement(BaseModel):
|
class ExtractedStatement(BaseModel):
|
||||||
"""Schema for extracted statement from LLM"""
|
"""Schema for extracted statement from LLM"""
|
||||||
statement: str = Field(..., description="The extracted statement text")
|
statement: str = Field(
|
||||||
|
...,
|
||||||
|
validation_alias=AliasChoices("statement", "statement_text"),
|
||||||
|
description="The extracted statement text",
|
||||||
|
)
|
||||||
statement_type: str = Field(..., description="FACT, OPINION, SUGGESTION or PREDICTION")
|
statement_type: str = Field(..., description="FACT, OPINION, SUGGESTION or PREDICTION")
|
||||||
temporal_type: str = Field(..., description="STATIC, DYNAMIC, ATEMPORAL")
|
temporal_type: str = Field(..., description="STATIC, DYNAMIC, ATEMPORAL")
|
||||||
relevence: str = Field(..., description="RELEVANT or IRRELEVANT")
|
# New prompt no longer outputs relevence; keep backward-compatible default.
|
||||||
|
relevence: str = Field("RELEVANT", description="RELEVANT or IRRELEVANT")
|
||||||
has_unsolved_reference: bool = Field(False, description="Whether the statement has unresolved references")
|
has_unsolved_reference: bool = Field(False, description="Whether the statement has unresolved references")
|
||||||
|
|
||||||
class StatementExtractionResponse(BaseModel):
|
class StatementExtractionResponse(BaseModel):
|
||||||
@@ -41,7 +46,7 @@ class StatementExtractionResponse(BaseModel):
|
|||||||
valid_statements = []
|
valid_statements = []
|
||||||
filtered_count = 0
|
filtered_count = 0
|
||||||
for i, stmt in enumerate(v):
|
for i, stmt in enumerate(v):
|
||||||
if isinstance(stmt, dict) and stmt.get('statement'):
|
if isinstance(stmt, dict) and (stmt.get("statement") or stmt.get("statement_text")):
|
||||||
valid_statements.append(stmt)
|
valid_statements.append(stmt)
|
||||||
elif isinstance(stmt, dict):
|
elif isinstance(stmt, dict):
|
||||||
# Log which statement was filtered
|
# Log which statement was filtered
|
||||||
@@ -96,6 +101,11 @@ class StatementExtractor:
|
|||||||
"""
|
"""
|
||||||
chunk_content = chunk.content
|
chunk_content = chunk.content
|
||||||
chunk_speaker = self._get_speaker_from_chunk(chunk)
|
chunk_speaker = self._get_speaker_from_chunk(chunk)
|
||||||
|
logger.info(
|
||||||
|
"[LegacyStatementExtractor] chunk_id=%s content_len=%d",
|
||||||
|
getattr(chunk, "id", ""),
|
||||||
|
len(chunk_content or ""),
|
||||||
|
)
|
||||||
|
|
||||||
if not chunk_content or len(chunk_content.strip()) < 5:
|
if not chunk_content or len(chunk_content.strip()) < 5:
|
||||||
logger.warning(f"Chunk {chunk.id} content too short or empty, skipping")
|
logger.warning(f"Chunk {chunk.id} content too short or empty, skipping")
|
||||||
@@ -108,7 +118,18 @@ class StatementExtractor:
|
|||||||
granularity=self.config.statement_granularity,
|
granularity=self.config.statement_granularity,
|
||||||
include_dialogue_context=self.config.include_dialogue_context,
|
include_dialogue_context=self.config.include_dialogue_context,
|
||||||
dialogue_content=dialogue_content,
|
dialogue_content=dialogue_content,
|
||||||
max_dialogue_chars=self.config.max_dialogue_context_chars
|
max_dialogue_chars=self.config.max_dialogue_context_chars,
|
||||||
|
input_json={
|
||||||
|
"chunk_id": getattr(chunk, "id", ""),
|
||||||
|
"end_user_id": end_user_id or "",
|
||||||
|
"target_content": chunk_content,
|
||||||
|
"target_message_date": datetime.now().isoformat(),
|
||||||
|
"supporting_context": {
|
||||||
|
"msgs": [
|
||||||
|
{"role": "context", "msg": dialogue_content}
|
||||||
|
] if dialogue_content else []
|
||||||
|
},
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Simple system message
|
# Simple system message
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from typing import List, Dict, Optional
|
|||||||
from app.core.logging_config import get_memory_logger
|
from app.core.logging_config import get_memory_logger
|
||||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||||
from app.core.memory.utils.prompt.prompt_utils import render_triplet_extraction_prompt
|
from app.core.memory.utils.prompt.prompt_utils import render_triplet_extraction_prompt
|
||||||
from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate # 引入枚举 Predicate 白名单过滤
|
from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS
|
||||||
from app.core.memory.models.triplet_models import TripletExtractionResponse
|
from app.core.memory.models.triplet_models import TripletExtractionResponse
|
||||||
from app.core.memory.models.message_models import DialogData, Statement
|
from app.core.memory.models.message_models import DialogData, Statement
|
||||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
||||||
@@ -73,15 +73,9 @@ class TripletExtractor:
|
|||||||
try:
|
try:
|
||||||
# Get structured response from LLM
|
# Get structured response from LLM
|
||||||
response = await self.llm_client.response_structured(messages, TripletExtractionResponse)
|
response = await self.llm_client.response_structured(messages, TripletExtractionResponse)
|
||||||
# Filter triplets to only allowed predicates from ontology
|
|
||||||
# 这里过滤掉了不在 Predicate 枚举中的谓语 但是容易造成谓语太严格,有点语句的谓语没有在枚举中,就被判断为弱关系
|
|
||||||
allowed_predicates = {p.value for p in Predicate}
|
|
||||||
filtered_triplets = [t for t in response.triplets if getattr(t, "predicate", "") in allowed_predicates]
|
|
||||||
# 仅保留predicate ∈ Predicate 的三元组,其余全部剔除
|
|
||||||
|
|
||||||
# Create new triplets with statement_id set during creation
|
# Create new triplets with statement_id set during creation
|
||||||
updated_triplets = []
|
updated_triplets = []
|
||||||
for triplet in filtered_triplets: # 仅保留 predicate ∈ Predicate 的三元组
|
for triplet in response.triplets:
|
||||||
updated_triplet = triplet.model_copy(update={"statement_id": statement.id})
|
updated_triplet = triplet.model_copy(update={"statement_id": statement.id})
|
||||||
updated_triplets.append(updated_triplet)
|
updated_triplets.append(updated_triplet)
|
||||||
|
|
||||||
|
|||||||
@@ -300,6 +300,33 @@ class NewExtractionOrchestrator:
|
|||||||
"embedding_output": None,
|
"embedding_output": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if self.progress_callback:
|
||||||
|
statements_count = sum(
|
||||||
|
len(stmts)
|
||||||
|
for chunk_stmts in all_stmt_results.values()
|
||||||
|
for stmts in chunk_stmts.values()
|
||||||
|
)
|
||||||
|
entities_count = sum(
|
||||||
|
len(t_out.entities)
|
||||||
|
for stmt_triplets in all_triplet_results.values()
|
||||||
|
for t_out in stmt_triplets.values()
|
||||||
|
)
|
||||||
|
triplets_count = sum(
|
||||||
|
len(t_out.triplets)
|
||||||
|
for stmt_triplets in all_triplet_results.values()
|
||||||
|
for t_out in stmt_triplets.values()
|
||||||
|
)
|
||||||
|
await self.progress_callback(
|
||||||
|
"knowledge_extraction_complete",
|
||||||
|
"知识抽取完成",
|
||||||
|
{
|
||||||
|
"entities_count": entities_count,
|
||||||
|
"statements_count": statements_count,
|
||||||
|
"temporal_ranges_count": 0,
|
||||||
|
"triplets_count": triplets_count,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
logger.info("Pilot extraction complete")
|
logger.info("Pilot extraction complete")
|
||||||
return dialog_data_list
|
return dialog_data_list
|
||||||
|
|
||||||
@@ -467,6 +494,11 @@ class NewExtractionOrchestrator:
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
for chunk in dialog.chunks:
|
for chunk in dialog.chunks:
|
||||||
|
# 仅对 speaker="user" 的 chunk 进行陈述句抽取;assistant 内容交给
|
||||||
|
# 上游预处理/剪枝阶段处理,避免浪费 LLM 调用。
|
||||||
|
chunk_speaker = getattr(chunk, "speaker", "user")
|
||||||
|
if chunk_speaker != "user":
|
||||||
|
continue
|
||||||
inp = StatementStepInput(
|
inp = StatementStepInput(
|
||||||
chunk_id=chunk.id,
|
chunk_id=chunk.id,
|
||||||
end_user_id=dialog.end_user_id,
|
end_user_id=dialog.end_user_id,
|
||||||
@@ -478,7 +510,7 @@ class NewExtractionOrchestrator:
|
|||||||
)
|
)
|
||||||
tasks.append(self.statement_step.run(inp))
|
tasks.append(self.statement_step.run(inp))
|
||||||
task_meta.append(
|
task_meta.append(
|
||||||
(dialog.id, chunk.id, getattr(chunk, "speaker", "user"), ctx)
|
(dialog.id, chunk.id, chunk_speaker, ctx)
|
||||||
)
|
)
|
||||||
|
|
||||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
@@ -499,6 +531,15 @@ class NewExtractionOrchestrator:
|
|||||||
for s in stmts:
|
for s in stmts:
|
||||||
s.speaker = speaker
|
s.speaker = speaker
|
||||||
stmt_map[dialog_id][chunk_id] = stmts
|
stmt_map[dialog_id][chunk_id] = stmts
|
||||||
|
if self.progress_callback:
|
||||||
|
# Frontend consumes knowledge_extraction_result with data.statement.
|
||||||
|
# Emit one event per statement to keep payload contract simple.
|
||||||
|
for s in stmts:
|
||||||
|
await self.progress_callback(
|
||||||
|
"knowledge_extraction_result",
|
||||||
|
"知识抽取中",
|
||||||
|
{"statement": s.statement_text},
|
||||||
|
)
|
||||||
|
|
||||||
return stmt_map
|
return stmt_map
|
||||||
|
|
||||||
@@ -520,6 +561,11 @@ class NewExtractionOrchestrator:
|
|||||||
chunk_stmts = all_stmt_results.get(dialog.id, {})
|
chunk_stmts = all_stmt_results.get(dialog.id, {})
|
||||||
for _chunk_id, stmts in chunk_stmts.items():
|
for _chunk_id, stmts in chunk_stmts.items():
|
||||||
for stmt in stmts:
|
for stmt in stmts:
|
||||||
|
# 防御性过滤:三元组抽取仅针对 user statement。
|
||||||
|
# 上游 _extract_all_statements 已过滤 chunk.speaker,此处再做
|
||||||
|
# 一次 statement.speaker 的二次校验,防止外部注入或 legacy 数据脱漏。
|
||||||
|
if getattr(stmt, "speaker", "user") != "user":
|
||||||
|
continue
|
||||||
inp = self._convert_to_triplet_input(stmt, ctx)
|
inp = self._convert_to_triplet_input(stmt, ctx)
|
||||||
tasks.append(self.triplet_step.run(inp))
|
tasks.append(self.triplet_step.run(inp))
|
||||||
task_meta.append((dialog.id, stmt.statement_id))
|
task_meta.append((dialog.id, stmt.statement_id))
|
||||||
@@ -541,6 +587,24 @@ class NewExtractionOrchestrator:
|
|||||||
triplet_map[dialog_id][stmt_id] = self.triplet_step.get_default_output()
|
triplet_map[dialog_id][stmt_id] = self.triplet_step.get_default_output()
|
||||||
else:
|
else:
|
||||||
triplet_map[dialog_id][stmt_id] = result
|
triplet_map[dialog_id][stmt_id] = result
|
||||||
|
if self.progress_callback:
|
||||||
|
await self.progress_callback(
|
||||||
|
"extract_triplet_result",
|
||||||
|
f"statement {stmt_id} 提取完成",
|
||||||
|
{
|
||||||
|
"statement_id": stmt_id,
|
||||||
|
"triplet_count": len(result.triplets),
|
||||||
|
"entity_count": len(result.entities),
|
||||||
|
"triplets": [
|
||||||
|
{
|
||||||
|
"subject_name": t.subject_name,
|
||||||
|
"predicate": t.predicate,
|
||||||
|
"object_name": t.object_name,
|
||||||
|
}
|
||||||
|
for t in result.triplets[:5]
|
||||||
|
],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
return triplet_map
|
return triplet_map
|
||||||
|
|
||||||
@@ -842,6 +906,8 @@ class NewExtractionOrchestrator:
|
|||||||
temporal_info=_TEMPORAL_MAP.get(stmt_out.temporal_type, TemporalInfo.ATEMPORAL),
|
temporal_info=_TEMPORAL_MAP.get(stmt_out.temporal_type, TemporalInfo.ATEMPORAL),
|
||||||
# relevence_info=RelevenceInfo.RELEVANT if stmt_out.relevance == "RELEVANT" else RelevenceInfo.IRRELEVANT,
|
# relevence_info=RelevenceInfo.RELEVANT if stmt_out.relevance == "RELEVANT" else RelevenceInfo.IRRELEVANT,
|
||||||
temporal_validity=TemporalValidityRange(valid_at=valid_at, invalid_at=invalid_at),
|
temporal_validity=TemporalValidityRange(valid_at=valid_at, invalid_at=invalid_at),
|
||||||
|
has_unsolved_reference=stmt_out.has_unsolved_reference,
|
||||||
|
has_emotional_state=stmt_out.has_emotional_state,
|
||||||
triplet_extraction_info=triplet_info,
|
triplet_extraction_info=triplet_info,
|
||||||
statement_embedding=stmt_embedding,
|
statement_embedding=stmt_embedding,
|
||||||
**emotion_kwargs,
|
**emotion_kwargs,
|
||||||
|
|||||||
@@ -250,6 +250,7 @@ async def build_graph_nodes_and_edges(
|
|||||||
entity_idx=entity.entity_idx,
|
entity_idx=entity.entity_idx,
|
||||||
statement_id=statement.id,
|
statement_id=statement.id,
|
||||||
entity_type=getattr(entity, "type", "unknown"),
|
entity_type=getattr(entity, "type", "unknown"),
|
||||||
|
type_description=getattr(entity, "type_description", ""),
|
||||||
description=getattr(entity, "description", ""),
|
description=getattr(entity, "description", ""),
|
||||||
example=getattr(entity, "example", ""),
|
example=getattr(entity, "example", ""),
|
||||||
connect_strength=(
|
connect_strength=(
|
||||||
@@ -296,6 +297,7 @@ async def build_graph_nodes_and_edges(
|
|||||||
source=subject_entity_id,
|
source=subject_entity_id,
|
||||||
target=object_entity_id,
|
target=object_entity_id,
|
||||||
relation_type=triplet.predicate,
|
relation_type=triplet.predicate,
|
||||||
|
relation_type_description=getattr(triplet, "predicate_description", ""),
|
||||||
statement=statement.statement,
|
statement=statement.statement,
|
||||||
source_statement_id=statement.id,
|
source_statement_id=statement.id,
|
||||||
end_user_id=dialog_data.end_user_id,
|
end_user_id=dialog_data.end_user_id,
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ class StatementStepOutput(BaseModel):
|
|||||||
temporal_type: str # STATIC / DYNAMIC / ATEMPORAL
|
temporal_type: str # STATIC / DYNAMIC / ATEMPORAL
|
||||||
# relevance: str # RELEVANT / IRRELEVANT
|
# relevance: str # RELEVANT / IRRELEVANT
|
||||||
speaker: str # "user" / "assistant"
|
speaker: str # "user" / "assistant"
|
||||||
|
has_emotional_state: bool = False # Whether statement reflects user's emotional state
|
||||||
valid_at: str # ISO 8601 or "NULL"
|
valid_at: str # ISO 8601 or "NULL"
|
||||||
invalid_at: str # ISO 8601 or "NULL"
|
invalid_at: str # ISO 8601 or "NULL"
|
||||||
has_unsolved_reference: bool = False # Whether the statement has unresolved references
|
has_unsolved_reference: bool = False # Whether the statement has unresolved references
|
||||||
@@ -72,6 +73,7 @@ class EntityItem(BaseModel):
|
|||||||
entity_idx: int
|
entity_idx: int
|
||||||
name: str
|
name: str
|
||||||
type: str
|
type: str
|
||||||
|
type_description: str = ""
|
||||||
description: str
|
description: str
|
||||||
is_explicit_memory: bool = False
|
is_explicit_memory: bool = False
|
||||||
|
|
||||||
@@ -82,6 +84,7 @@ class TripletItem(BaseModel):
|
|||||||
subject_name: str
|
subject_name: str
|
||||||
subject_id: int
|
subject_id: int
|
||||||
predicate: str
|
predicate: str
|
||||||
|
predicate_description: str = ""
|
||||||
object_name: str
|
object_name: str
|
||||||
object_id: int
|
object_id: int
|
||||||
|
|
||||||
|
|||||||
@@ -34,6 +34,10 @@ class _ExtractedStatement(BaseModel):
|
|||||||
statement_type: str = Field(..., description="FACT / OPINION / OTHER")
|
statement_type: str = Field(..., description="FACT / OPINION / OTHER")
|
||||||
temporal_type: str = Field(..., description="STATIC / DYNAMIC / ATEMPORAL")
|
temporal_type: str = Field(..., description="STATIC / DYNAMIC / ATEMPORAL")
|
||||||
# relevance: str = Field("RELEVANT", description="RELEVANT / IRRELEVANT")
|
# relevance: str = Field("RELEVANT", description="RELEVANT / IRRELEVANT")
|
||||||
|
has_emotional_state: bool = Field(
|
||||||
|
False,
|
||||||
|
description="Whether the statement reflects user's emotional state",
|
||||||
|
)
|
||||||
valid_at: str = Field("NULL", description="ISO 8601 or NULL")
|
valid_at: str = Field("NULL", description="ISO 8601 or NULL")
|
||||||
invalid_at: str = Field("NULL", description="ISO 8601 or NULL")
|
invalid_at: str = Field("NULL", description="ISO 8601 or NULL")
|
||||||
has_unsolved_reference: bool = Field(False, description="Whether the statement has unresolved references")
|
has_unsolved_reference: bool = Field(False, description="Whether the statement has unresolved references")
|
||||||
@@ -155,6 +159,7 @@ class StatementExtractionStep(ExtractionStep[StatementStepInput, List[StatementS
|
|||||||
temporal_type=stmt.temporal_type.strip().upper(),
|
temporal_type=stmt.temporal_type.strip().upper(),
|
||||||
# relevance=stmt.relevance.strip().upper(),
|
# relevance=stmt.relevance.strip().upper(),
|
||||||
speaker="user", # default; orchestrator overrides from chunk metadata
|
speaker="user", # default; orchestrator overrides from chunk metadata
|
||||||
|
has_emotional_state=getattr(stmt, "has_emotional_state", False),
|
||||||
valid_at=stmt.valid_at or "NULL",
|
valid_at=stmt.valid_at or "NULL",
|
||||||
invalid_at=stmt.invalid_at or "NULL",
|
invalid_at=stmt.invalid_at or "NULL",
|
||||||
has_unsolved_reference=getattr(stmt, "has_unsolved_reference", False),
|
has_unsolved_reference=getattr(stmt, "has_unsolved_reference", False),
|
||||||
|
|||||||
@@ -112,6 +112,7 @@ class TripletExtractionStep(ExtractionStep[TripletStepInput, TripletStepOutput])
|
|||||||
subject_name=t.subject_name,
|
subject_name=t.subject_name,
|
||||||
subject_id=t.subject_id,
|
subject_id=t.subject_id,
|
||||||
predicate=t.predicate,
|
predicate=t.predicate,
|
||||||
|
predicate_description=getattr(t, "predicate_description", ""),
|
||||||
object_name=t.object_name,
|
object_name=t.object_name,
|
||||||
object_id=t.object_id,
|
object_id=t.object_id,
|
||||||
)
|
)
|
||||||
@@ -123,6 +124,7 @@ class TripletExtractionStep(ExtractionStep[TripletStepInput, TripletStepOutput])
|
|||||||
entity_idx=e.entity_idx,
|
entity_idx=e.entity_idx,
|
||||||
name=e.name,
|
name=e.name,
|
||||||
type=e.type,
|
type=e.type,
|
||||||
|
type_description=getattr(e, "type_description", ""),
|
||||||
description=e.description,
|
description=e.description,
|
||||||
is_explicit_memory=getattr(e, "is_explicit_memory", False),
|
is_explicit_memory=getattr(e, "is_explicit_memory", False),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -92,6 +92,7 @@ SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity
|
|||||||
THEN entity.expired_at ELSE e.expired_at END,
|
THEN entity.expired_at ELSE e.expired_at END,
|
||||||
e.entity_idx = CASE WHEN e.entity_idx IS NULL OR e.entity_idx = 0 THEN entity.entity_idx ELSE e.entity_idx END,
|
e.entity_idx = CASE WHEN e.entity_idx IS NULL OR e.entity_idx = 0 THEN entity.entity_idx ELSE e.entity_idx END,
|
||||||
e.entity_type = CASE WHEN entity.entity_type IS NOT NULL AND entity.entity_type <> '' THEN entity.entity_type ELSE e.entity_type END,
|
e.entity_type = CASE WHEN entity.entity_type IS NOT NULL AND entity.entity_type <> '' THEN entity.entity_type ELSE e.entity_type END,
|
||||||
|
e.type_description = CASE WHEN entity.type_description IS NOT NULL AND entity.type_description <> '' THEN entity.type_description ELSE coalesce(e.type_description, '') END,
|
||||||
e.description = CASE
|
e.description = CASE
|
||||||
WHEN entity.description IS NOT NULL AND entity.description <> ''
|
WHEN entity.description IS NOT NULL AND entity.description <> ''
|
||||||
AND (e.description IS NULL OR size(e.description) = 0 OR size(entity.description) > size(e.description))
|
AND (e.description IS NULL OR size(e.description) = 0 OR size(entity.description) > size(e.description))
|
||||||
@@ -147,6 +148,7 @@ MATCH (object:ExtractedEntity {id: rel.target_id, end_user_id: rel.end_user_id})
|
|||||||
// Avoid duplicate edges across runs for the same endpoints
|
// Avoid duplicate edges across runs for the same endpoints
|
||||||
MERGE (subject)-[r:EXTRACTED_RELATIONSHIP]->(object)
|
MERGE (subject)-[r:EXTRACTED_RELATIONSHIP]->(object)
|
||||||
SET r.predicate = rel.predicate,
|
SET r.predicate = rel.predicate,
|
||||||
|
r.predicate_description = rel.predicate_description,
|
||||||
r.statement_id = rel.statement_id,
|
r.statement_id = rel.statement_id,
|
||||||
r.value = rel.value,
|
r.value = rel.value,
|
||||||
r.statement = rel.statement,
|
r.statement = rel.statement,
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ async def save_entities_and_relationships(
|
|||||||
'source_id': edge.source,
|
'source_id': edge.source,
|
||||||
'target_id': edge.target,
|
'target_id': edge.target,
|
||||||
'predicate': edge.relation_type,
|
'predicate': edge.relation_type,
|
||||||
|
'predicate_description': edge.relation_type_description,
|
||||||
'statement_id': edge.source_statement_id,
|
'statement_id': edge.source_statement_id,
|
||||||
'value': edge.relation_value,
|
'value': edge.relation_value,
|
||||||
'statement': edge.statement,
|
'statement': edge.statement,
|
||||||
@@ -297,6 +298,7 @@ async def save_dialog_and_statements_to_neo4j(
|
|||||||
'source_id': edge.source,
|
'source_id': edge.source,
|
||||||
'target_id': edge.target,
|
'target_id': edge.target,
|
||||||
'predicate': edge.relation_type,
|
'predicate': edge.relation_type,
|
||||||
|
'predicate_description': edge.relation_type_description,
|
||||||
'statement_id': edge.source_statement_id,
|
'statement_id': edge.source_statement_id,
|
||||||
'value': edge.relation_value,
|
'value': edge.relation_value,
|
||||||
'statement': edge.statement,
|
'statement': edge.statement,
|
||||||
|
|||||||
@@ -441,21 +441,12 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
|||||||
with open(result_path, "r", encoding="utf-8") as rf:
|
with open(result_path, "r", encoding="utf-8") as rf:
|
||||||
extracted_result = json.load(rf)
|
extracted_result = json.load(rf)
|
||||||
|
|
||||||
# 步骤 6: 计算本体覆盖率并合并到结果中
|
# 步骤 6: 组装结果(试运行不做额外覆盖率后处理)
|
||||||
result_data = {
|
result_data = {
|
||||||
"config_id": cid,
|
"config_id": cid,
|
||||||
"time_log": os.path.join(project_root, "logs", "time.log"),
|
"time_log": os.path.join(project_root, "logs", "time.log"),
|
||||||
"extracted_result": extracted_result,
|
"extracted_result": extracted_result,
|
||||||
}
|
}
|
||||||
try:
|
|
||||||
ontology_coverage = await self._compute_ontology_coverage(
|
|
||||||
extracted_result=extracted_result,
|
|
||||||
memory_config=memory_config,
|
|
||||||
)
|
|
||||||
if ontology_coverage:
|
|
||||||
result_data["ontology_coverage"] = ontology_coverage
|
|
||||||
except Exception as cov_err:
|
|
||||||
logger.warning(f"[PILOT_RUN_STREAM] Ontology coverage computation failed: {cov_err}", exc_info=True)
|
|
||||||
|
|
||||||
yield format_sse_message("result", result_data)
|
yield format_sse_message("result", result_data)
|
||||||
|
|
||||||
@@ -479,100 +470,6 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
|||||||
"time": int(time.time() * 1000)
|
"time": int(time.time() * 1000)
|
||||||
})
|
})
|
||||||
|
|
||||||
async def _compute_ontology_coverage(
|
|
||||||
self,
|
|
||||||
extracted_result: Dict[str, Any],
|
|
||||||
memory_config,
|
|
||||||
) -> Optional[Dict[str, Any]]:
|
|
||||||
"""根据提取结果中的实体类型,与场景/通用本体类型做互斥分类统计。
|
|
||||||
|
|
||||||
分类规则(互斥):场景类型优先 > 通用类型 > 未匹配
|
|
||||||
确保: 场景实体数 + 通用实体数 + 未匹配数 = 总实体数
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
包含三部分统计的字典,或 None(无实体数据时)
|
|
||||||
"""
|
|
||||||
core_entities = extracted_result.get("core_entities", [])
|
|
||||||
if not core_entities:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 1. 加载场景本体类型集合
|
|
||||||
scene_ontology_types: set = set()
|
|
||||||
try:
|
|
||||||
from app.repositories.ontology_class_repository import OntologyClassRepository
|
|
||||||
|
|
||||||
if memory_config.scene_id:
|
|
||||||
class_repo = OntologyClassRepository(self.db)
|
|
||||||
ontology_classes = class_repo.get_classes_by_scene(memory_config.scene_id)
|
|
||||||
scene_ontology_types = {oc.class_name for oc in ontology_classes}
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to load scene ontology types: {e}")
|
|
||||||
|
|
||||||
# 2. 加载通用本体类型集合
|
|
||||||
general_ontology_types: set = set()
|
|
||||||
try:
|
|
||||||
from app.core.memory.ontology_services.ontology_type_loader import (
|
|
||||||
get_general_ontology_registry,
|
|
||||||
is_general_ontology_enabled,
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_general_ontology_enabled():
|
|
||||||
registry = get_general_ontology_registry()
|
|
||||||
if registry:
|
|
||||||
general_ontology_types = set(registry.types.keys())
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to load general ontology types: {e}")
|
|
||||||
|
|
||||||
# 3. 互斥分类:场景优先 > 通用 > 未匹配
|
|
||||||
scene_distribution: list = []
|
|
||||||
general_distribution: list = []
|
|
||||||
unmatched_distribution: list = []
|
|
||||||
scene_total = 0
|
|
||||||
general_total = 0
|
|
||||||
unmatched_total = 0
|
|
||||||
|
|
||||||
for item in core_entities:
|
|
||||||
entity_type = item.get("type", "")
|
|
||||||
count = item.get("count", 0)
|
|
||||||
|
|
||||||
if entity_type in scene_ontology_types:
|
|
||||||
scene_distribution.append({"type": entity_type, "count": count})
|
|
||||||
scene_total += count
|
|
||||||
elif entity_type in general_ontology_types:
|
|
||||||
general_distribution.append({"type": entity_type, "count": count})
|
|
||||||
general_total += count
|
|
||||||
else:
|
|
||||||
unmatched_distribution.append({"type": entity_type, "count": count})
|
|
||||||
unmatched_total += count
|
|
||||||
|
|
||||||
# 按数量降序排列
|
|
||||||
scene_distribution.sort(key=lambda x: x["count"], reverse=True)
|
|
||||||
general_distribution.sort(key=lambda x: x["count"], reverse=True)
|
|
||||||
unmatched_distribution.sort(key=lambda x: x["count"], reverse=True)
|
|
||||||
|
|
||||||
total_entities = scene_total + general_total + unmatched_total
|
|
||||||
|
|
||||||
return {
|
|
||||||
"scene_type_distribution": {
|
|
||||||
"type_count": len(scene_distribution),
|
|
||||||
"entity_total": scene_total,
|
|
||||||
"types": scene_distribution,
|
|
||||||
},
|
|
||||||
"general_type_distribution": {
|
|
||||||
"type_count": len(general_distribution),
|
|
||||||
"entity_total": general_total,
|
|
||||||
"types": general_distribution,
|
|
||||||
},
|
|
||||||
"unmatched": {
|
|
||||||
"type_count": len(unmatched_distribution),
|
|
||||||
"entity_total": unmatched_total,
|
|
||||||
"types": unmatched_distribution,
|
|
||||||
},
|
|
||||||
"total_entities": total_entities,
|
|
||||||
"time": int(time.time() * 1000),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# -------------------- Neo4j Search & Analytics (fused from data_search_service.py) --------------------
|
# -------------------- Neo4j Search & Analytics (fused from data_search_service.py) --------------------
|
||||||
# Ensure env for connector (e.g., NEO4J_PASSWORD)
|
# Ensure env for connector (e.g., NEO4J_PASSWORD)
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,9 @@ import time
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Awaitable, Callable, Optional
|
from typing import Awaitable, Callable, Optional
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
from app.core.logging_config import get_memory_logger, log_time
|
from app.core.logging_config import get_memory_logger, log_time
|
||||||
|
from app.core.memory.pipelines.pilot_write_pipeline import PilotWritePipeline
|
||||||
from app.core.memory.models.message_models import (
|
from app.core.memory.models.message_models import (
|
||||||
ConversationContext,
|
ConversationContext,
|
||||||
ConversationMessage,
|
ConversationMessage,
|
||||||
@@ -20,9 +22,11 @@ from app.core.memory.storage_services.extraction_engine.extraction_orchestrator
|
|||||||
ExtractionOrchestrator,
|
ExtractionOrchestrator,
|
||||||
get_chunked_dialogs_from_preprocessed,
|
get_chunked_dialogs_from_preprocessed,
|
||||||
)
|
)
|
||||||
from app.core.memory.utils.config.config_utils import (
|
from app.core.memory.storage_services.extraction_engine.pipeline_help import (
|
||||||
get_pipeline_config,
|
_write_extracted_result_summary,
|
||||||
|
export_test_input_doc,
|
||||||
)
|
)
|
||||||
|
from app.core.memory.utils.config.config_utils import get_pipeline_config
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
from app.schemas.memory_config_schema import MemoryConfig
|
from app.schemas.memory_config_schema import MemoryConfig
|
||||||
@@ -31,6 +35,42 @@ from sqlalchemy.orm import Session
|
|||||||
logger = get_memory_logger(__name__)
|
logger = get_memory_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _save_triplets_from_dialogs(dialog_data_list: list[DialogData], output_path: str) -> None:
|
||||||
|
"""Write triplet/entity text report compatible with pipeline_help parsers."""
|
||||||
|
all_triplets = []
|
||||||
|
all_entities = []
|
||||||
|
|
||||||
|
for dialog in dialog_data_list:
|
||||||
|
for chunk in getattr(dialog, "chunks", []) or []:
|
||||||
|
for statement in getattr(chunk, "statements", []) or []:
|
||||||
|
triplet_info = getattr(statement, "triplet_extraction_info", None)
|
||||||
|
if not triplet_info:
|
||||||
|
continue
|
||||||
|
all_triplets.extend(getattr(triplet_info, "triplets", []) or [])
|
||||||
|
all_entities.extend(getattr(triplet_info, "entities", []) or [])
|
||||||
|
|
||||||
|
with open(output_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(f"=== EXTRACTED TRIPLETS ({len(all_triplets)} total) ===\n\n")
|
||||||
|
for i, triplet in enumerate(all_triplets, 1):
|
||||||
|
f.write(f"Triplet {i}:\n")
|
||||||
|
f.write(f" Subject: {triplet.subject_name} (ID: {triplet.subject_id})\n")
|
||||||
|
f.write(f" Predicate: {triplet.predicate}\n")
|
||||||
|
f.write(f" Object: {triplet.object_name} (ID: {triplet.object_id})\n")
|
||||||
|
value = getattr(triplet, "value", None)
|
||||||
|
if value:
|
||||||
|
f.write(f" Value: {value}\n")
|
||||||
|
f.write("\n")
|
||||||
|
|
||||||
|
f.write(f"\n=== EXTRACTED ENTITIES ({len(all_entities)} total) ===\n\n")
|
||||||
|
for i, entity in enumerate(all_entities, 1):
|
||||||
|
f.write(f"Entity {i}:\n")
|
||||||
|
f.write(f" ID: {entity.entity_idx}\n")
|
||||||
|
f.write(f" Name: {entity.name}\n")
|
||||||
|
f.write(f" Type: {entity.type}\n")
|
||||||
|
f.write(f" Description: {entity.description}\n")
|
||||||
|
f.write("\n")
|
||||||
|
|
||||||
|
|
||||||
async def run_pilot_extraction(
|
async def run_pilot_extraction(
|
||||||
memory_config: MemoryConfig,
|
memory_config: MemoryConfig,
|
||||||
dialogue_text: str,
|
dialogue_text: str,
|
||||||
@@ -58,7 +98,6 @@ async def run_pilot_extraction(
|
|||||||
f.write(f"\n=== Pilot Run Started: {timestamp} ===\n")
|
f.write(f"\n=== Pilot Run Started: {timestamp} ===\n")
|
||||||
|
|
||||||
pipeline_start = time.time()
|
pipeline_start = time.time()
|
||||||
neo4j_connector = None
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 步骤 1: 初始化客户端
|
# 步骤 1: 初始化客户端
|
||||||
@@ -69,8 +108,6 @@ async def run_pilot_extraction(
|
|||||||
llm_client = client_factory.get_llm_client(str(memory_config.llm_model_id))
|
llm_client = client_factory.get_llm_client(str(memory_config.llm_model_id))
|
||||||
embedder_client = client_factory.get_embedder_client(str(memory_config.embedding_model_id))
|
embedder_client = client_factory.get_embedder_client(str(memory_config.embedding_model_id))
|
||||||
|
|
||||||
neo4j_connector = Neo4jConnector()
|
|
||||||
|
|
||||||
log_time("Client Initialization", time.time() - step_start, log_file)
|
log_time("Client Initialization", time.time() - step_start, log_file)
|
||||||
|
|
||||||
# 步骤 2: 解析对话文本
|
# 步骤 2: 解析对话文本
|
||||||
@@ -242,15 +279,17 @@ async def run_pilot_extraction(
|
|||||||
|
|
||||||
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
|
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
|
||||||
|
|
||||||
# 步骤 3: 初始化流水线编排器
|
# 步骤 3: 初始化并选择试运行流水线(环境变量可切换)
|
||||||
logger.info("Initializing extraction orchestrator...")
|
use_refactored = bool(settings.PILOT_RUN_USE_REFACTORED_PIPELINE)
|
||||||
step_start = time.time()
|
|
||||||
|
|
||||||
config = get_pipeline_config(memory_config)
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Pipeline config loaded: enable_llm_dedup_blockwise={config.deduplication.enable_llm_dedup_blockwise}, "
|
"Selecting pilot pipeline by env: PILOT_RUN_USE_REFACTORED_PIPELINE=%s",
|
||||||
f"enable_llm_disambiguation={config.deduplication.enable_llm_disambiguation}"
|
use_refactored,
|
||||||
)
|
)
|
||||||
|
logger.info(
|
||||||
|
"Initializing %s pilot pipeline...",
|
||||||
|
"refactored" if use_refactored else "legacy",
|
||||||
|
)
|
||||||
|
step_start = time.time()
|
||||||
|
|
||||||
# 加载本体类型(如果配置了 scene_id),支持通用类型回退
|
# 加载本体类型(如果配置了 scene_id),支持通用类型回退
|
||||||
ontology_types = None
|
ontology_types = None
|
||||||
@@ -266,100 +305,105 @@ async def run_pilot_extraction(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to load ontology types: {e}", exc_info=True)
|
logger.warning(f"Failed to load ontology types: {e}", exc_info=True)
|
||||||
|
|
||||||
orchestrator = ExtractionOrchestrator(
|
if use_refactored:
|
||||||
llm_client=llm_client,
|
pilot_pipeline = PilotWritePipeline(
|
||||||
embedder_client=embedder_client,
|
llm_client=llm_client,
|
||||||
connector=neo4j_connector,
|
embedder_client=embedder_client,
|
||||||
config=config,
|
pipeline_config=get_pipeline_config(memory_config),
|
||||||
progress_callback=progress_callback,
|
progress_callback=progress_callback,
|
||||||
embedding_id=str(memory_config.embedding_model_id),
|
embedding_id=str(memory_config.embedding_model_id),
|
||||||
language=language,
|
language=language,
|
||||||
ontology_types=ontology_types,
|
ontology_types=ontology_types,
|
||||||
)
|
)
|
||||||
|
log_time("Pilot Pipeline Initialization", time.time() - step_start, log_file)
|
||||||
|
|
||||||
log_time("Orchestrator Initialization", time.time() - step_start, log_file)
|
# 步骤 4a: 执行重构后试运行短链路
|
||||||
|
# statement -> triplet -> graph_build -> 第一层去重消歧(结束)
|
||||||
|
logger.info("Running refactored pilot extraction short pipeline...")
|
||||||
|
step_start = time.time()
|
||||||
|
|
||||||
# 步骤 4: 执行知识提取流水线
|
if progress_callback:
|
||||||
logger.info("Running extraction pipeline...")
|
await progress_callback("knowledge_extraction", "正在知识抽取...")
|
||||||
step_start = time.time()
|
|
||||||
|
|
||||||
if progress_callback:
|
pilot_result = await pilot_pipeline.run(chunked_dialogs)
|
||||||
await progress_callback("knowledge_extraction", "正在知识抽取...")
|
dialog_data_list = pilot_result.dialog_data_list
|
||||||
|
graph = pilot_result.graph
|
||||||
|
chunk_nodes = graph.chunk_nodes
|
||||||
|
export_entity_nodes = graph.entity_nodes
|
||||||
|
export_stmt_entity_edges = graph.stmt_entity_edges
|
||||||
|
export_entity_edges = graph.entity_entity_edges
|
||||||
|
else:
|
||||||
|
# 步骤 4b: 执行旧试运行流水线
|
||||||
|
logger.info("Running legacy pilot extraction pipeline...")
|
||||||
|
step_start = time.time()
|
||||||
|
|
||||||
extraction_result = await orchestrator.run(
|
if progress_callback:
|
||||||
dialog_data_list=chunked_dialogs,
|
await progress_callback("knowledge_extraction", "正在知识抽取...")
|
||||||
is_pilot_run=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 解包 extraction_result tuple (与 main.py 保持一致)
|
neo4j_connector = Neo4jConnector()
|
||||||
(
|
try:
|
||||||
dialogue_nodes,
|
legacy_orchestrator = ExtractionOrchestrator(
|
||||||
chunk_nodes,
|
llm_client=llm_client,
|
||||||
statement_nodes,
|
embedder_client=embedder_client,
|
||||||
entity_nodes,
|
connector=neo4j_connector,
|
||||||
_,
|
config=get_pipeline_config(memory_config),
|
||||||
statement_chunk_edges,
|
progress_callback=progress_callback,
|
||||||
statement_entity_edges,
|
embedding_id=str(memory_config.embedding_model_id),
|
||||||
entity_edges,
|
language=language,
|
||||||
_,
|
ontology_types=ontology_types,
|
||||||
_
|
)
|
||||||
) = extraction_result
|
extraction_result = await legacy_orchestrator.run(
|
||||||
|
dialog_data_list=chunked_dialogs,
|
||||||
|
is_pilot_run=True,
|
||||||
|
)
|
||||||
|
(
|
||||||
|
_dialogue_nodes,
|
||||||
|
chunk_nodes,
|
||||||
|
_statement_nodes,
|
||||||
|
entity_nodes,
|
||||||
|
_perceptual_nodes,
|
||||||
|
_statement_chunk_edges,
|
||||||
|
statement_entity_edges,
|
||||||
|
entity_edges,
|
||||||
|
_perceptual_edges,
|
||||||
|
_last_created_at,
|
||||||
|
) = extraction_result
|
||||||
|
dialog_data_list = chunked_dialogs
|
||||||
|
export_entity_nodes = entity_nodes
|
||||||
|
export_stmt_entity_edges = statement_entity_edges
|
||||||
|
export_entity_edges = entity_edges
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
await neo4j_connector.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
log_time("Extraction Pipeline", time.time() - step_start, log_file)
|
log_time("Extraction Pipeline", time.time() - step_start, log_file)
|
||||||
|
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
await progress_callback("generating_results", "正在生成结果...")
|
await progress_callback("generating_results", "正在生成结果...")
|
||||||
|
|
||||||
# 步骤 5: 生成记忆摘要(与 main.py 保持一致)
|
# 步骤 5: 输出试运行结果文件(保持 /pilot_run 返回契约)
|
||||||
try:
|
settings.ensure_memory_output_dir()
|
||||||
logger.info("Generating memory summaries...")
|
export_test_input_doc(
|
||||||
step_start = time.time()
|
entity_nodes=export_entity_nodes,
|
||||||
|
statement_entity_edges=export_stmt_entity_edges,
|
||||||
|
entity_entity_edges=export_entity_edges,
|
||||||
|
)
|
||||||
|
_save_triplets_from_dialogs(
|
||||||
|
dialog_data_list=dialog_data_list,
|
||||||
|
output_path=settings.get_memory_output_path("extracted_triplets.txt"),
|
||||||
|
)
|
||||||
|
_write_extracted_result_summary(
|
||||||
|
chunk_nodes=chunk_nodes,
|
||||||
|
pipeline_output_dir=settings.get_memory_output_path(),
|
||||||
|
)
|
||||||
|
|
||||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
|
logger.info("Pilot run completed: stop after layer-1 dedup (no layer-2 / no Neo4j write)")
|
||||||
memory_summary_generation,
|
|
||||||
)
|
|
||||||
|
|
||||||
summaries = await memory_summary_generation(
|
|
||||||
chunked_dialogs,
|
|
||||||
llm_client=llm_client,
|
|
||||||
embedder_client=embedder_client,
|
|
||||||
language=language,
|
|
||||||
)
|
|
||||||
|
|
||||||
log_time("Memory Summary Generation", time.time() - step_start, log_file)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Memory summary step failed: {e}", exc_info=True)
|
|
||||||
|
|
||||||
logger.info("Pilot run completed: Skipping Neo4j save")
|
|
||||||
|
|
||||||
# 将提取统计写入 Redis,按 workspace_id 存储
|
|
||||||
try:
|
|
||||||
from app.cache.memory.activity_stats_cache import ActivityStatsCache
|
|
||||||
|
|
||||||
stats_to_cache = {
|
|
||||||
"chunk_count": len(chunk_nodes) if chunk_nodes else 0,
|
|
||||||
"statements_count": len(statement_nodes) if statement_nodes else 0,
|
|
||||||
"triplet_entities_count": len(entity_nodes) if entity_nodes else 0,
|
|
||||||
"triplet_relations_count": len(entity_edges) if entity_edges else 0,
|
|
||||||
"temporal_count": 0, # temporal 数据在日志中,此处暂置0
|
|
||||||
}
|
|
||||||
await ActivityStatsCache.set_activity_stats(
|
|
||||||
workspace_id=str(memory_config.workspace_id),
|
|
||||||
stats=stats_to_cache,
|
|
||||||
)
|
|
||||||
logger.info(f"[PILOT_RUN] 活动统计已写入 Redis: workspace_id={memory_config.workspace_id}")
|
|
||||||
except Exception as cache_err:
|
|
||||||
logger.warning(f"[PILOT_RUN] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Pilot run failed: {e}", exc_info=True)
|
logger.error(f"Pilot run failed: {e}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
finally:
|
|
||||||
if neo4j_connector:
|
|
||||||
try:
|
|
||||||
await neo4j_connector.close()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
total_time = time.time() - pipeline_start
|
total_time = time.time() - pipeline_start
|
||||||
log_time("TOTAL PILOT RUN TIME", total_time, log_file)
|
log_time("TOTAL PILOT RUN TIME", total_time, log_file)
|
||||||
|
|||||||
@@ -1382,6 +1382,7 @@ def extract_emotion_batch_task(
|
|||||||
llm_model_id: str,
|
llm_model_id: str,
|
||||||
language: str = "zh",
|
language: str = "zh",
|
||||||
emotion_config: Optional[Dict[str, Any]] = None,
|
emotion_config: Optional[Dict[str, Any]] = None,
|
||||||
|
snapshot_dir: Optional[str] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Celery task: batch emotion extraction + Neo4j backfill.
|
"""Celery task: batch emotion extraction + Neo4j backfill.
|
||||||
|
|
||||||
@@ -1395,6 +1396,10 @@ def extract_emotion_batch_task(
|
|||||||
language: Language code ("zh" / "en").
|
language: Language code ("zh" / "en").
|
||||||
emotion_config: Optional dict with emotion step config overrides
|
emotion_config: Optional dict with emotion step config overrides
|
||||||
(emotion_extract_keywords, emotion_enable_subject).
|
(emotion_extract_keywords, emotion_enable_subject).
|
||||||
|
snapshot_dir: Optional absolute path of the current run's snapshot directory.
|
||||||
|
When provided (only in debug mode), emotion outputs will be
|
||||||
|
dumped to <snapshot_dir>/4_emotion_outputs.json for offline
|
||||||
|
comparison between the legacy / new pipelines.
|
||||||
"""
|
"""
|
||||||
task_id = self.request.id
|
task_id = self.request.id
|
||||||
total = len(statements)
|
total = len(statements)
|
||||||
@@ -1445,6 +1450,8 @@ def extract_emotion_batch_task(
|
|||||||
extracted = 0
|
extracted = 0
|
||||||
failed = 0
|
failed = 0
|
||||||
update_items = []
|
update_items = []
|
||||||
|
# Snapshot support: collect the EmotionStepOutput of each statement (used only when snapshot_dir is set)
|
||||||
|
snapshot_outputs: Dict[str, Any] = {} if snapshot_dir else None # type: ignore[assignment]
|
||||||
|
|
||||||
async def _extract_one(stmt_dict: Dict[str, str]):
|
async def _extract_one(stmt_dict: Dict[str, str]):
|
||||||
nonlocal extracted, failed
|
nonlocal extracted, failed
|
||||||
@@ -1461,6 +1468,8 @@ def extract_emotion_batch_task(
|
|||||||
"emotion_intensity": result.emotion_intensity,
|
"emotion_intensity": result.emotion_intensity,
|
||||||
"emotion_keywords": result.emotion_keywords,
|
"emotion_keywords": result.emotion_keywords,
|
||||||
})
|
})
|
||||||
|
if snapshot_outputs is not None:
|
||||||
|
snapshot_outputs[stmt_dict["statement_id"]] = result.model_dump()
|
||||||
extracted += 1
|
extracted += 1
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"[Emotion] 单条提取完成: stmt={stmt_dict['statement_id']}, "
|
f"[Emotion] 单条提取完成: stmt={stmt_dict['statement_id']}, "
|
||||||
@@ -1468,12 +1477,33 @@ def extract_emotion_batch_task(
|
|||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
failed += 1
|
failed += 1
|
||||||
|
if snapshot_outputs is not None:
|
||||||
|
snapshot_outputs[stmt_dict["statement_id"]] = {"error": str(e)}
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"[Emotion] 单条提取失败 stmt={stmt_dict['statement_id']}: {e}"
|
f"[Emotion] 单条提取失败 stmt={stmt_dict['statement_id']}: {e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
await asyncio.gather(*[_extract_one(s) for s in statements])
|
await asyncio.gather(*[_extract_one(s) for s in statements])
|
||||||
|
|
||||||
|
# Persist the snapshot to disk (worker side): does not affect the Neo4j write flow; failures are only logged
|
||||||
|
if snapshot_outputs is not None:
|
||||||
|
try:
|
||||||
|
from pathlib import Path as _Path
|
||||||
|
import json as _json
|
||||||
|
|
||||||
|
_dir = _Path(snapshot_dir)
|
||||||
|
_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
_path = _dir / "4_emotion_outputs.json"
|
||||||
|
with open(_path, "w", encoding="utf-8") as _f:
|
||||||
|
_json.dump(snapshot_outputs, _f, ensure_ascii=False, indent=2, default=str)
|
||||||
|
logger.info(
|
||||||
|
f"[Emotion][Snapshot] 已落盘 {len(snapshot_outputs)} 条情绪结果 → {_path}"
|
||||||
|
)
|
||||||
|
except Exception as _e:
|
||||||
|
logger.warning(
|
||||||
|
f"[Emotion][Snapshot] 快照落盘失败(不影响主流程): {_e}"
|
||||||
|
)
|
||||||
|
|
||||||
# Batch update Neo4j via write transaction
|
# Batch update Neo4j via write transaction
|
||||||
if update_items:
|
if update_items:
|
||||||
connector = Neo4jConnector()
|
connector = Neo4jConnector()
|
||||||
|
|||||||
Reference in New Issue
Block a user