refactor(memory): remove expired_at field and add dialog_at timestamp
Remove the deprecated expired_at field from all graph models, Neo4j Cypher queries, repositories, and pipeline code. Replace it with dialog_at on StatementNode to track the original dialog timestamp.

- Strip expired_at from DialogueNode, ChunkNode, StatementNode, ExtractedEntityNode, edges, and all Cypher queries
- Add dialog_at to MessageItem schema and propagate through extraction and graph build steps
- Extract emotion/metadata async submission from WritePipeline into a generic _submit_celery_task helper
- Add post_store_dedup_and_alias_merge Celery task for async alias merging and second-layer dedup after Neo4j write
- Switch pytest async backend from anyio to asyncio_mode=auto
This commit is contained in:
@@ -252,7 +252,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
# TODO: fact_summary functionality temporarily disabled, will be enabled after future development
|
||||
fields_to_remove = {
|
||||
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
||||
'expired_at', 'created_at', 'chunk_id', 'apply_id',
|
||||
'created_at', 'chunk_id', 'apply_id',
|
||||
'user_id', 'statement_ids', 'updated_at', "chunk_ids", "fact_summary"
|
||||
}
|
||||
# 注意:'id' 字段保留,community 展开时需要用 community id 查询成员 statements
|
||||
|
||||
@@ -16,7 +16,7 @@ logger = get_agent_logger(__name__)
|
||||
# 需要从展开结果中过滤的字段(含 Neo4j DateTime,不可 JSON 序列化)
|
||||
_EXPAND_FIELDS_TO_REMOVE = {
|
||||
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
||||
'expired_at', 'created_at', 'chunk_id', 'apply_id',
|
||||
'created_at', 'chunk_id', 'apply_id',
|
||||
'user_id', 'statement_ids', 'updated_at', 'chunk_ids', 'fact_summary'
|
||||
}
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ async def get_chunked_dialogs(
|
||||
Args:
|
||||
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
|
||||
end_user_id: Group identifier
|
||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||
messages: Structured message list [{"role": "user", "content": "...", "dialog_at": "..."}]
|
||||
ref_id: Reference identifier
|
||||
config_id: Configuration ID for processing (used to load pruning config)
|
||||
snapshot: Optional PipelineSnapshot instance for saving pruning output
|
||||
@@ -47,7 +47,12 @@ async def get_chunked_dialogs(
|
||||
raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}")
|
||||
|
||||
if content.strip():
|
||||
conversation_messages.append(ConversationMessage(role=role, msg=content.strip(), files=files))
|
||||
conversation_messages.append(ConversationMessage(
|
||||
role=role,
|
||||
msg=content.strip(),
|
||||
dialog_at=msg.get("dialog_at"),
|
||||
files=files,
|
||||
))
|
||||
|
||||
if not conversation_messages:
|
||||
raise ValueError("Message list cannot be empty after filtering")
|
||||
@@ -57,7 +62,7 @@ async def get_chunked_dialogs(
|
||||
context=conversation_context,
|
||||
ref_id=ref_id,
|
||||
end_user_id=end_user_id,
|
||||
config_id=config_id
|
||||
config_id=config_id,
|
||||
)
|
||||
|
||||
# step2: 语义剪枝步骤(在分块之前)
|
||||
|
||||
@@ -242,6 +242,7 @@ class ChunkerClient:
|
||||
chunk = Chunk(
|
||||
content=f"{msg.role}: {sub_chunk_text}",
|
||||
speaker=msg.role, # 直接继承角色
|
||||
dialog_at=getattr(msg, "dialog_at", None),
|
||||
metadata={
|
||||
"message_index": msg_idx,
|
||||
"message_role": msg.role,
|
||||
@@ -257,6 +258,7 @@ class ChunkerClient:
|
||||
chunk = Chunk(
|
||||
content=f"{msg.role}: {msg_content}",
|
||||
speaker=msg.role, # 直接继承角色
|
||||
dialog_at=getattr(msg, "dialog_at", None),
|
||||
metadata={
|
||||
"message_index": msg_idx,
|
||||
"message_role": msg.role,
|
||||
|
||||
@@ -62,7 +62,7 @@ class MemoryService:
|
||||
"""写入记忆:对话 → 萃取 → 存储 → 聚类 → 摘要
|
||||
|
||||
Args:
|
||||
messages: 结构化消息 [{"role": "user"/"assistant", "content": "..."}]
|
||||
messages: 结构化消息 [{"role": "user"/"assistant", "content": "...", "dialog_at": "..."}]
|
||||
language: 语言 ("zh" | "en")
|
||||
ref_id: 引用 ID,为空则自动生成
|
||||
is_pilot_run: 试运行模式(只萃取不写入)
|
||||
|
||||
@@ -106,7 +106,6 @@ class Edge(BaseModel):
|
||||
end_user_id: End user ID for multi-tenancy
|
||||
run_id: Unique identifier for the pipeline run that created this edge
|
||||
created_at: Timestamp when the edge was created (system perspective)
|
||||
expired_at: Optional timestamp when the edge expires (system perspective)
|
||||
"""
|
||||
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the edge.")
|
||||
source: str = Field(..., description="The ID of the source node.")
|
||||
@@ -114,7 +113,6 @@ class Edge(BaseModel):
|
||||
end_user_id: str = Field(..., description="The end user ID of the edge.")
|
||||
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
||||
created_at: datetime = Field(..., description="The valid time of the edge from system perspective.")
|
||||
expired_at: Optional[datetime] = Field(default=None, description="The expired time of the edge from system perspective.")
|
||||
|
||||
|
||||
class ChunkEdge(Edge):
|
||||
@@ -191,14 +189,12 @@ class Node(BaseModel):
|
||||
end_user_id: End user ID for multi-tenancy
|
||||
run_id: Unique identifier for the pipeline run that created this node
|
||||
created_at: Timestamp when the node was created (system perspective)
|
||||
expired_at: Optional timestamp when the node expires (system perspective)
|
||||
"""
|
||||
id: str = Field(..., description="The unique identifier for the node.")
|
||||
name: str = Field(..., description="The name of the node.")
|
||||
end_user_id: str = Field(..., description="The end user ID of the node.")
|
||||
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
||||
created_at: datetime = Field(..., description="The valid time of the node from system perspective.")
|
||||
expired_at: Optional[datetime] = Field(None, description="The expired time of the node from system perspective.")
|
||||
|
||||
|
||||
class DialogueNode(Node):
|
||||
@@ -284,6 +280,7 @@ class StatementNode(Node):
|
||||
temporal_info: TemporalInfo = Field(..., description="Temporal information")
|
||||
valid_at: Optional[datetime] = Field(None, description="Temporal validity start")
|
||||
invalid_at: Optional[datetime] = Field(None, description="Temporal validity end")
|
||||
dialog_at: Optional[datetime] = Field(None, description="Absolute timestamp of the conversation this statement belongs to")
|
||||
|
||||
# Embedding and other fields
|
||||
statement_embedding: Optional[List[float]] = Field(None, description="Statement embedding vector")
|
||||
@@ -319,7 +316,7 @@ class StatementNode(Node):
|
||||
description="Total number of times this node has been accessed"
|
||||
)
|
||||
|
||||
@field_validator('valid_at', 'invalid_at', mode='before')
|
||||
@field_validator('valid_at', 'invalid_at', 'dialog_at', mode='before')
|
||||
@classmethod
|
||||
def validate_datetime(cls, v):
|
||||
"""使用通用的历史日期解析函数"""
|
||||
|
||||
@@ -30,6 +30,7 @@ class ConversationMessage(BaseModel):
|
||||
"""
|
||||
role: str = Field(..., description="The role of the speaker (e.g., 'user', 'assistant').")
|
||||
msg: str = Field(..., description="The text content of the message.")
|
||||
dialog_at: Optional[str] = Field(None, description="Absolute timestamp of this message (ISO 8601).")
|
||||
files: list[tuple] = Field(default_factory=list, description="The file content of the message", exclude=True)
|
||||
|
||||
|
||||
@@ -100,6 +101,7 @@ class Statement(BaseModel):
|
||||
False,
|
||||
description="Whether the statement reflects user's emotional state",
|
||||
)
|
||||
dialog_at: Optional[str] = Field(None, description="Absolute timestamp of the source message (ISO 8601).")
|
||||
|
||||
|
||||
class ConversationContext(BaseModel):
|
||||
@@ -139,6 +141,7 @@ class Chunk(BaseModel):
|
||||
statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.")
|
||||
files: list[tuple] = Field(default_factory=list, description="List of files in the chunk.")
|
||||
chunk_embedding: Optional[List[float]] = Field(default=None, description="The embedding vector of the chunk.")
|
||||
dialog_at: Optional[str] = Field(None, description="Absolute timestamp of the source message (ISO 8601).")
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.")
|
||||
|
||||
@classmethod
|
||||
@@ -155,6 +158,7 @@ class Chunk(BaseModel):
|
||||
return cls(
|
||||
content=f"{message.role}: {message.msg}",
|
||||
speaker=message.role,
|
||||
dialog_at=message.dialog_at,
|
||||
metadata=metadata or {}
|
||||
)
|
||||
|
||||
@@ -169,7 +173,6 @@ class DialogData(BaseModel):
|
||||
ref_id: Reference ID linking to external dialog system
|
||||
end_user_id: End user ID for multi-tenancy
|
||||
created_at: Timestamp when the dialog was created
|
||||
expired_at: Timestamp when the dialog expires (default: far future)
|
||||
metadata: Additional metadata as key-value pairs
|
||||
chunks: List of chunks from the conversation
|
||||
config_id: Configuration ID used to process this dialog
|
||||
@@ -184,7 +187,6 @@ class DialogData(BaseModel):
|
||||
end_user_id: str = Field(default=..., description="End user ID of dialogue data")
|
||||
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
||||
created_at: datetime = Field(default_factory=datetime.now, description="The timestamp when the dialog was created.")
|
||||
expired_at: datetime = Field(default_factory=lambda: datetime(9999, 12, 31), description="The timestamp when the dialog expires.")
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the dialog.")
|
||||
chunks: List[Chunk] = Field(default_factory=list, description="A list of chunks from the conversation context.")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this dialog (integer or string)")
|
||||
|
||||
@@ -198,11 +198,13 @@ class WritePipeline:
|
||||
chunked_dialogs = await self._preprocess(messages, ref_id)
|
||||
s.metadata(chunks=sum(len(d.chunks) for d in chunked_dialogs))
|
||||
|
||||
# Step 2: 萃取 - 知识提取
|
||||
# Step 2: 萃取 - 知识提取 + 第一层去重 + 别名归并(内存侧)
|
||||
async with bear.step(2, 5, "萃取", "知识提取") as s:
|
||||
extraction_result = await self._extract(
|
||||
chunked_dialogs, is_pilot_run
|
||||
)
|
||||
# 别名归并(内存侧):在写入前完成,确保写入的数据已归并
|
||||
self._merge_alias_in_memory(extraction_result)
|
||||
stats = extraction_result.stats
|
||||
s.metadata(
|
||||
entities=stats["entity_count"],
|
||||
@@ -222,15 +224,8 @@ class WritePipeline:
|
||||
async with bear.step(3, 5, "存储", "写入 Neo4j"):
|
||||
await self._store(extraction_result)
|
||||
|
||||
# Step 3.2: 别名归并
|
||||
async with bear.step(3, 5, "别名归并", "处理别名属于关系"):
|
||||
await self._merge_alias_belongs_to(extraction_result)
|
||||
|
||||
# Step 3.5: 异步情绪提取(fire-and-forget,需在 _store 之后确保 Statement 节点已存在)
|
||||
await self._extract_emotion(getattr(self, "_emotion_statements", []))
|
||||
|
||||
# Step 3.6: 异步元数据提取(fire-and-forget,需在 _store 之后确保 Entity 节点已存在)
|
||||
await self._extract_metadata(extraction_result)
|
||||
# Step 3.5: 异步后处理(别名归并 Neo4j 侧 + 第二层去重 + 情绪 + 元数据)
|
||||
await self._post_store_async_tasks(extraction_result)
|
||||
|
||||
# Step 4: 聚类 - 增量更新社区(异步,不阻塞)
|
||||
async with bear.step(4, 5, "聚类", "增量更新社区") as s:
|
||||
@@ -359,16 +354,17 @@ class WritePipeline:
|
||||
# Snapshot: 图节点和边(去重前)
|
||||
recorder.record_graph_before_dedup(graph)
|
||||
|
||||
# step3: 两阶段去重消歧
|
||||
# step3: 第一层去重消歧(同一轮对话内的实体碎片合并)
|
||||
# 第二层(Neo4j 联合去重)后移到 _store 之后异步执行
|
||||
dedup_result = await run_dedup(
|
||||
entity_nodes=graph.entity_nodes,
|
||||
statement_entity_edges=graph.stmt_entity_edges,
|
||||
entity_entity_edges=graph.entity_entity_edges,
|
||||
dialog_data_list=dialog_data_list,
|
||||
pipeline_config=pipeline_config,
|
||||
connector=self._neo4j_connector,
|
||||
connector=None,
|
||||
llm_client=self._llm_client,
|
||||
is_pilot_run=is_pilot_run,
|
||||
is_pilot_run=True,
|
||||
progress_callback=self.progress_callback,
|
||||
)
|
||||
|
||||
@@ -455,29 +451,21 @@ class WritePipeline:
|
||||
raise
|
||||
|
||||
# ──────────────────────────────────────────────
|
||||
# Step 3.2: 别名归并
|
||||
# Step 3.2: 别名归并(内存侧)
|
||||
# ──────────────────────────────────────────────
|
||||
|
||||
async def _merge_alias_belongs_to(self, result: ExtractionResult) -> None:
|
||||
def _merge_alias_in_memory(self, result: ExtractionResult) -> None:
|
||||
"""别名归并(内存侧):处理 predicate="别名属于" 的边。
|
||||
|
||||
在写入 Neo4j 之前执行,确保写入的数据已经完成别名归并:
|
||||
- 将别名实体的 name 追加到目标实体的 aliases
|
||||
- 将别名实体的 description 拼接到目标实体的 description
|
||||
- 重定向指向别名节点的边到目标节点
|
||||
|
||||
纯内存操作,不涉及 Neo4j。
|
||||
"""
|
||||
所有去重合并都可以使用这个这种统一的处理方式(未实现)
|
||||
别名归并:处理 predicate="别名属于" 的 EXTRACTED_RELATIONSHIP 边。
|
||||
|
||||
对每条 source -[EXTRACTED_RELATIONSHIP {predicate:"别名属于"}]-> target 边:
|
||||
- 将 source.name 追加到 target.aliases(去重)
|
||||
- 将 source.description append 进 target.description_list(list 形式)
|
||||
|
||||
同时在内存中同步更新 ExtractionResult.entity_nodes,保持内存与 Neo4j 一致。
|
||||
失败不中断主流程。
|
||||
"""
|
||||
from app.repositories.neo4j.cypher_queries import (
|
||||
MERGE_ALIAS_BELONGS_TO,
|
||||
REDIRECT_ALIAS_EDGES,
|
||||
)
|
||||
|
||||
ALIAS_PREDICATE = "别名属于"
|
||||
|
||||
# 筛选出所有 predicate="别名属于" 的边
|
||||
alias_edges = [
|
||||
e
|
||||
for e in result.entity_entity_edges
|
||||
@@ -490,10 +478,7 @@ class WritePipeline:
|
||||
return
|
||||
|
||||
try:
|
||||
# ── 1. 在内存中同步更新 entity_nodes ──
|
||||
entity_map = {e.id: e for e in result.entity_nodes}
|
||||
|
||||
# 构建 alias_id → target_id 映射(别名节点 → 用户节点)
|
||||
alias_to_target: dict[str, str] = {}
|
||||
|
||||
for edge in alias_edges:
|
||||
@@ -513,7 +498,7 @@ class WritePipeline:
|
||||
source_name
|
||||
]
|
||||
|
||||
# 将 source.description append 进 target.description(追加,分号分隔)
|
||||
# 将 source.description 拼接到 target.description(分号分隔,去重)
|
||||
src_desc = (source_node.description or "").strip()
|
||||
if src_desc:
|
||||
tgt_desc = (target_node.description or "").strip()
|
||||
@@ -522,12 +507,11 @@ class WritePipeline:
|
||||
f"{tgt_desc};{src_desc}" if tgt_desc else src_desc
|
||||
)
|
||||
|
||||
# ── 1.1 内存中重定向指向别名节点的边到用户节点 ──
|
||||
# 重定向指向别名节点的边到目标节点
|
||||
alias_ids = set(alias_to_target.keys())
|
||||
redirected_ee_count = 0
|
||||
redirected_se_count = 0
|
||||
|
||||
# 重定向 entity_entity_edges(排除"别名属于"边本身)
|
||||
for edge in result.entity_entity_edges:
|
||||
rel_type = getattr(edge, "relation_type", "")
|
||||
if rel_type == ALIAS_PREDICATE:
|
||||
@@ -539,39 +523,101 @@ class WritePipeline:
|
||||
edge.target = alias_to_target[edge.target]
|
||||
redirected_ee_count += 1
|
||||
|
||||
# 重定向 stmt_entity_edges(陈述句 → 实体边)
|
||||
for edge in result.stmt_entity_edges:
|
||||
if edge.target in alias_ids:
|
||||
edge.target = alias_to_target[edge.target]
|
||||
redirected_se_count += 1
|
||||
|
||||
logger.info(
|
||||
f"[AliasMerge] 内存同步完成,处理 {len(alias_edges)} 条 '别名属于' 边,"
|
||||
f"[AliasMerge] 内存归并完成,处理 {len(alias_edges)} 条 '别名属于' 边,"
|
||||
f"重定向 entity_entity 边 {redirected_ee_count} 次,"
|
||||
f"重定向 stmt_entity 边 {redirected_se_count} 次"
|
||||
)
|
||||
|
||||
# ── 2. 写入 Neo4j:别名属性归并 ──
|
||||
records = await self._neo4j_connector.execute_query(
|
||||
MERGE_ALIAS_BELONGS_TO,
|
||||
end_user_id=self.end_user_id,
|
||||
)
|
||||
merged_count = len(records) if records else 0
|
||||
logger.info(f"[AliasMerge] Neo4j 别名归并完成,影响 {merged_count} 条记录")
|
||||
|
||||
# ── 3. 写入 Neo4j:重定向指向别名节点的边到用户节点 ──
|
||||
redirect_records = await self._neo4j_connector.execute_query(
|
||||
REDIRECT_ALIAS_EDGES,
|
||||
end_user_id=self.end_user_id,
|
||||
)
|
||||
redirect_count = len(redirect_records) if redirect_records else 0
|
||||
logger.info(
|
||||
f"[AliasMerge] Neo4j 边重定向完成,影响 {redirect_count} 条记录"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[AliasMerge] 别名归并失败(不影响主流程): {e}", exc_info=True
|
||||
f"[AliasMerge] 内存归并失败(不影响主流程): {e}", exc_info=True
|
||||
)
|
||||
|
||||
# ──────────────────────────────────────────────
|
||||
# Step 3.5: 异步后处理(Neo4j 别名归并 + 第二层去重)
|
||||
# ──────────────────────────────────────────────
|
||||
|
||||
async def _post_store_async_tasks(self, result: ExtractionResult) -> None:
|
||||
"""提交写入后的异步 Celery 任务(全部 fire-and-forget,失败不影响主流程):
|
||||
|
||||
1. Neo4j 别名归并 + 第二层去重
|
||||
2. 异步情绪提取
|
||||
3. 异步元数据提取
|
||||
"""
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.metadata_extractor import (
|
||||
collect_user_entities_for_metadata,
|
||||
)
|
||||
|
||||
llm_model_id = (
|
||||
str(self.memory_config.llm_model_id)
|
||||
if self.memory_config.llm_model_id
|
||||
else None
|
||||
)
|
||||
recorder = getattr(self, "_recorder", None)
|
||||
snapshot_dir = (
|
||||
recorder.snapshot_dir
|
||||
if recorder is not None and recorder.enabled
|
||||
else None
|
||||
)
|
||||
|
||||
# ── 1. Neo4j 别名归并 + 第二层去重 ──
|
||||
self._submit_celery_task(
|
||||
"PostStore",
|
||||
"app.tasks.post_store_dedup_and_alias_merge",
|
||||
{
|
||||
"end_user_id": self.end_user_id,
|
||||
"entity_ids": [e.id for e in result.entity_nodes],
|
||||
"llm_model_id": llm_model_id,
|
||||
},
|
||||
)
|
||||
|
||||
# ── 2. 异步情绪提取 ──
|
||||
emotion_statements = getattr(self, "_emotion_statements", [])
|
||||
if emotion_statements and llm_model_id:
|
||||
self._submit_celery_task(
|
||||
"Emotion",
|
||||
"app.tasks.extract_emotion_batch",
|
||||
{
|
||||
"statements": emotion_statements,
|
||||
"llm_model_id": llm_model_id,
|
||||
"language": self.language,
|
||||
"snapshot_dir": snapshot_dir,
|
||||
},
|
||||
)
|
||||
|
||||
# ── 3. 异步元数据提取 ──
|
||||
user_entities = collect_user_entities_for_metadata(result.entity_nodes)
|
||||
if user_entities and llm_model_id:
|
||||
self._submit_celery_task(
|
||||
"Metadata",
|
||||
"app.tasks.extract_metadata_batch",
|
||||
{
|
||||
"user_entities": user_entities,
|
||||
"llm_model_id": llm_model_id,
|
||||
"language": self.language,
|
||||
"snapshot_dir": snapshot_dir,
|
||||
},
|
||||
)
|
||||
|
||||
def _submit_celery_task(
|
||||
self, label: str, task_name: str, kwargs: dict
|
||||
) -> None:
|
||||
"""提交 Celery 异步任务的通用方法。失败只记日志,不抛异常。"""
|
||||
try:
|
||||
from app.celery_app import celery_app
|
||||
|
||||
task_result = celery_app.send_task(task_name, kwargs=kwargs)
|
||||
logger.info(f"[{label}] 异步任务已提交 - task_id={task_result.id}")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[{label}] 提交异步任务失败(不影响主流程): {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# ──────────────────────────────────────────────
|
||||
@@ -625,127 +671,6 @@ class WritePipeline:
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# ──────────────────────────────────────────────
|
||||
# Step 4.5: 异步情绪提取
|
||||
# fire-and-forget 提交 Celery 任务,不阻塞主流程
|
||||
# ──────────────────────────────────────────────
|
||||
|
||||
async def _extract_emotion(self, emotion_statements: list) -> None:
|
||||
"""提交异步情绪提取 Celery 任务。
|
||||
|
||||
从编排器收集的 user statement 列表中提取情绪,
|
||||
异步回写到 Neo4j Statement 节点。失败不影响主流程。
|
||||
|
||||
在 PIPELINE_SNAPSHOT_ENABLED=true 时,会把当前运行的快照目录路径
|
||||
通过 snapshot_dir 透传给 Celery 任务;worker 端在完成 LLM 抽取后,
|
||||
将结果落盘到 <snapshot_dir>/4_emotion_outputs.json,避免主进程重复调用 LLM。
|
||||
"""
|
||||
if not emotion_statements:
|
||||
return
|
||||
|
||||
llm_model_id = (
|
||||
str(self.memory_config.llm_model_id)
|
||||
if self.memory_config.llm_model_id
|
||||
else None
|
||||
)
|
||||
if not llm_model_id:
|
||||
logger.warning("[Emotion] 无法提交情绪提取任务:llm_model_id 为空")
|
||||
return
|
||||
|
||||
# 快照目录:仅在 PIPELINE_SNAPSHOT_ENABLED=true 时非空,供 worker 端落盘
|
||||
recorder = getattr(self, "_recorder", None)
|
||||
snapshot_dir = (
|
||||
recorder.snapshot_dir
|
||||
if recorder is not None and recorder.enabled
|
||||
else None
|
||||
)
|
||||
|
||||
try:
|
||||
from app.celery_app import celery_app
|
||||
|
||||
result = celery_app.send_task(
|
||||
"app.tasks.extract_emotion_batch",
|
||||
kwargs={
|
||||
"statements": emotion_statements,
|
||||
"llm_model_id": llm_model_id,
|
||||
"language": self.language,
|
||||
"snapshot_dir": snapshot_dir,
|
||||
},
|
||||
)
|
||||
logger.info(
|
||||
f"[Emotion] 异步情绪提取任务已提交 - "
|
||||
f"task_id = {result.id}, "
|
||||
f"statement_count = {len(emotion_statements)}, "
|
||||
f"snapshot_dir = {snapshot_dir}, "
|
||||
f"source=async"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[Emotion] 提交情绪提取任务失败(不影响主流程): {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# ──────────────────────────────────────────────
|
||||
# Step 3.6: 异步元数据提取
|
||||
# fire-and-forget 提交 Celery 任务,不阻塞主流程
|
||||
# ──────────────────────────────────────────────
|
||||
|
||||
async def _extract_metadata(self, result: ExtractionResult) -> None:
|
||||
"""提交异步元数据提取 Celery 任务。
|
||||
|
||||
从去重后的用户实体 description 中提取结构化元数据,
|
||||
异步回写到 Neo4j ExtractedEntity 节点。失败不影响主流程。
|
||||
"""
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.metadata_extractor import (
|
||||
collect_user_entities_for_metadata,
|
||||
)
|
||||
|
||||
user_entities = collect_user_entities_for_metadata(result.entity_nodes)
|
||||
|
||||
if not user_entities:
|
||||
return
|
||||
|
||||
llm_model_id = (
|
||||
str(self.memory_config.llm_model_id)
|
||||
if self.memory_config.llm_model_id
|
||||
else None
|
||||
)
|
||||
if not llm_model_id:
|
||||
logger.warning("[Metadata] 无法提交元数据提取任务:llm_model_id 为空")
|
||||
return
|
||||
|
||||
# 快照目录
|
||||
recorder = getattr(self, "_recorder", None)
|
||||
snapshot_dir = (
|
||||
recorder.snapshot_dir
|
||||
if recorder is not None and recorder.enabled
|
||||
else None
|
||||
)
|
||||
|
||||
try:
|
||||
from app.celery_app import celery_app
|
||||
|
||||
task_result = celery_app.send_task(
|
||||
"app.tasks.extract_metadata_batch",
|
||||
kwargs={
|
||||
"user_entities": user_entities,
|
||||
"llm_model_id": llm_model_id,
|
||||
"language": self.language,
|
||||
"snapshot_dir": snapshot_dir,
|
||||
},
|
||||
)
|
||||
logger.info(
|
||||
f"[Metadata] 异步元数据提取任务已提交 - "
|
||||
f"task_id = {task_result.id}, "
|
||||
f"entity_count = {len(user_entities)}, "
|
||||
f"snapshot_dir = {snapshot_dir}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[Metadata] 提交元数据提取任务失败(不影响主流程): {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# ──────────────────────────────────────────────
|
||||
# Step 5: 摘要
|
||||
# (+ entity_description)+ meta_data部分在此提取
|
||||
|
||||
@@ -183,14 +183,8 @@ def _merge_attribute(canonical: ExtractedEntityNode, ent: ExtractedEntityNode):
|
||||
|
||||
# 时间范围合并
|
||||
try:
|
||||
# 统一使用 created_at / expired_at
|
||||
if getattr(ent, "created_at", None) and getattr(canonical, "created_at", None) and ent.created_at < canonical.created_at:
|
||||
canonical.created_at = ent.created_at
|
||||
if getattr(ent, "expired_at", None) and getattr(canonical, "expired_at", None):
|
||||
if canonical.expired_at is None:
|
||||
canonical.expired_at = ent.expired_at
|
||||
elif ent.expired_at and ent.expired_at > canonical.expired_at:
|
||||
canonical.expired_at = ent.expired_at
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@@ -65,7 +65,6 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode:
|
||||
user_id=row.get("user_id") or "",
|
||||
apply_id=row.get("apply_id") or "",
|
||||
created_at=_parse_dt(row.get("created_at")),
|
||||
expired_at=_parse_dt(row.get("expired_at")) if row.get("expired_at") else None,
|
||||
entity_idx=int(row.get("entity_idx") or 0),
|
||||
statement_id=row.get("statement_id") or "",
|
||||
entity_type=row.get("entity_type") or "",
|
||||
|
||||
@@ -1089,7 +1089,6 @@ class ExtractionOrchestrator:
|
||||
content=dialog_data.context.content if dialog_data.context else "",
|
||||
dialog_embedding=dialog_data.dialog_embedding if hasattr(dialog_data, 'dialog_embedding') else None,
|
||||
created_at=dialog_data.created_at,
|
||||
expired_at=dialog_data.expired_at,
|
||||
metadata=dialog_data.metadata,
|
||||
config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None,
|
||||
)
|
||||
@@ -1109,7 +1108,6 @@ class ExtractionOrchestrator:
|
||||
chunk_embedding=chunk.chunk_embedding,
|
||||
sequence_number=chunk_idx, # 添加必需的 sequence_number 字段
|
||||
created_at=dialog_data.created_at,
|
||||
expired_at=dialog_data.expired_at,
|
||||
metadata=chunk.metadata,
|
||||
)
|
||||
chunk_nodes.append(chunk_node)
|
||||
@@ -1175,7 +1173,6 @@ class ExtractionOrchestrator:
|
||||
invalid_at=statement.temporal_validity.invalid_at if hasattr(statement,
|
||||
'temporal_validity') and statement.temporal_validity else None,
|
||||
created_at=dialog_data.created_at,
|
||||
expired_at=dialog_data.expired_at,
|
||||
config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None,
|
||||
# Emotion fields
|
||||
emotion_type=getattr(statement, 'emotion_type', None),
|
||||
@@ -1232,7 +1229,6 @@ class ExtractionOrchestrator:
|
||||
end_user_id=dialog_data.end_user_id,
|
||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||
created_at=dialog_data.created_at,
|
||||
expired_at=dialog_data.expired_at,
|
||||
config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None,
|
||||
)
|
||||
entity_nodes.append(entity_node)
|
||||
@@ -1269,7 +1265,6 @@ class ExtractionOrchestrator:
|
||||
end_user_id=dialog_data.end_user_id,
|
||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||
created_at=dialog_data.created_at,
|
||||
expired_at=dialog_data.expired_at,
|
||||
valid_at=_tv.valid_at if _tv else None,
|
||||
invalid_at=_tv.invalid_at if _tv else None,
|
||||
)
|
||||
|
||||
@@ -912,6 +912,7 @@ class NewExtractionOrchestrator:
|
||||
has_emotional_state=stmt_out.has_emotional_state,
|
||||
triplet_extraction_info=triplet_info,
|
||||
statement_embedding=stmt_embedding,
|
||||
dialog_at=getattr(chunk, "dialog_at", None),
|
||||
**emotion_kwargs,
|
||||
)
|
||||
new_statements.append(stmt)
|
||||
|
||||
@@ -215,7 +215,6 @@ async def _process_chunk_summary(
|
||||
apply_id=dialog.end_user_id,
|
||||
run_id=dialog.run_id, # 使用 dialog 的 run_id
|
||||
created_at=datetime.now(),
|
||||
expired_at=datetime(9999, 12, 31),
|
||||
dialog_id=dialog.id,
|
||||
chunk_ids=[chunk.id],
|
||||
content=summary_text,
|
||||
|
||||
@@ -181,6 +181,7 @@ class StatementExtractor:
|
||||
chunk_id=chunk.id,
|
||||
end_user_id=end_user_id,
|
||||
speaker=chunk_speaker,
|
||||
dialog_at=getattr(chunk, "dialog_at", None),
|
||||
has_unsolved_reference=getattr(extracted_stmt, "has_unsolved_reference", False),
|
||||
)
|
||||
|
||||
|
||||
@@ -135,7 +135,6 @@ async def build_graph_nodes_and_edges(
|
||||
content=dialog_data.context.content if dialog_data.context else "",
|
||||
dialog_embedding=dialog_data.dialog_embedding if hasattr(dialog_data, "dialog_embedding") else None,
|
||||
created_at=dialog_data.created_at,
|
||||
expired_at=dialog_data.expired_at,
|
||||
metadata=dialog_data.metadata,
|
||||
config_id=dialog_data.config_id if hasattr(dialog_data, "config_id") else None,
|
||||
)
|
||||
@@ -154,7 +153,6 @@ async def build_graph_nodes_and_edges(
|
||||
chunk_embedding=chunk.chunk_embedding,
|
||||
sequence_number=chunk_idx,
|
||||
created_at=dialog_data.created_at,
|
||||
expired_at=dialog_data.expired_at,
|
||||
metadata=chunk.metadata,
|
||||
)
|
||||
chunk_nodes.append(chunk_node)
|
||||
@@ -227,7 +225,7 @@ async def build_graph_nodes_and_edges(
|
||||
else None
|
||||
),
|
||||
created_at=dialog_data.created_at,
|
||||
expired_at=dialog_data.expired_at,
|
||||
dialog_at=getattr(statement, "dialog_at", None),
|
||||
config_id=dialog_data.config_id if hasattr(dialog_data, "config_id") else None,
|
||||
emotion_type=getattr(statement, "emotion_type", None),
|
||||
emotion_intensity=getattr(statement, "emotion_intensity", None),
|
||||
@@ -280,7 +278,6 @@ async def build_graph_nodes_and_edges(
|
||||
end_user_id=dialog_data.end_user_id,
|
||||
run_id=dialog_data.run_id,
|
||||
created_at=dialog_data.created_at,
|
||||
expired_at=dialog_data.expired_at,
|
||||
config_id=dialog_data.config_id if hasattr(dialog_data, "config_id") else None,
|
||||
)
|
||||
entity_nodes.append(entity_node)
|
||||
@@ -320,7 +317,6 @@ async def build_graph_nodes_and_edges(
|
||||
end_user_id=dialog_data.end_user_id,
|
||||
run_id=dialog_data.run_id,
|
||||
created_at=dialog_data.created_at,
|
||||
expired_at=dialog_data.expired_at,
|
||||
valid_at=_tv.valid_at if _tv else None,
|
||||
invalid_at=_tv.invalid_at if _tv else None,
|
||||
)
|
||||
@@ -382,7 +378,6 @@ async def build_graph_nodes_and_edges(
|
||||
end_user_id=dialog_data.end_user_id,
|
||||
run_id=dialog_data.run_id,
|
||||
created_at=dialog_data.created_at,
|
||||
expired_at=dialog_data.expired_at,
|
||||
pair_id=pair_id,
|
||||
dialog_id=dialog_data.id,
|
||||
text=record["original_text"],
|
||||
@@ -408,7 +403,6 @@ async def build_graph_nodes_and_edges(
|
||||
end_user_id=dialog_data.end_user_id,
|
||||
run_id=dialog_data.run_id,
|
||||
created_at=dialog_data.created_at,
|
||||
expired_at=dialog_data.expired_at,
|
||||
pair_id=pair_id,
|
||||
dialog_id=dialog_data.id,
|
||||
text=record["pruned_text"],
|
||||
|
||||
@@ -483,7 +483,7 @@ class ReflectionEngine:
|
||||
result_data['memory_verifies'] = memory_verifies
|
||||
result_data['quality_assessments'] = quality_assessments
|
||||
conflicts_found = 0 # Initialize as integer 0 instead of empty string
|
||||
REMOVE_KEYS = {"created_at", "expired_at","relationship","predicate","statement_id","id","statement_id","relationship_statement_id"}
|
||||
REMOVE_KEYS = {"created_at","relationship","predicate","statement_id","id","statement_id","relationship_statement_id"}
|
||||
# Clean conflict_data, and memory_verify and quality_assessment
|
||||
cleaned_conflict_data = []
|
||||
for item in conflict_data:
|
||||
|
||||
@@ -26,7 +26,6 @@ async def _load_(data: List[Any]) -> List[Dict]:
|
||||
"end_user_id",
|
||||
"chunk_id",
|
||||
"created_at",
|
||||
"expired_at",
|
||||
"valid_at",
|
||||
"invalid_at",
|
||||
]
|
||||
@@ -93,7 +92,6 @@ async def get_data(result):
|
||||
rel_filtered['run_id'] = value.get('run_id')
|
||||
rel_filtered['statement'] = value.get('statement')
|
||||
rel_filtered['statement_id'] = value.get('statement_id')
|
||||
rel_filtered['expired_at'] = value.get('expired_at')
|
||||
rel_filtered['created_at'] = value.get('created_at')
|
||||
filtered_item[key] = value
|
||||
elif key == 'entity2' and value is not None:
|
||||
|
||||
Reference in New Issue
Block a user