refactor(memory): remove expired_at field and add dialog_at timestamp

Remove the deprecated expired_at field from all graph models, Neo4j
Cypher queries, repositories, and pipeline code. Replace with dialog_at
on StatementNode to track the original dialog timestamp.

- Strip expired_at from DialogueNode, ChunkNode, StatementNode,
  ExtractedEntityNode, edges, and all Cypher queries
- Add dialog_at to MessageItem schema and propagate through extraction
  and graph build steps
- Extract emotion/metadata async submission from WritePipeline into
  a generic _submit_celery_task helper
- Add post_store_dedup_and_alias_merge Celery task for async alias
  merging and second-layer dedup after Neo4j write
- Switch pytest async backend from anyio to asyncio_mode=auto
This commit is contained in:
lanceyq
2026-04-30 12:20:47 +08:00
parent d66d601e41
commit cf389bb978
30 changed files with 531 additions and 278 deletions

View File

@@ -117,6 +117,12 @@ celery_app.conf.update(
# Async emotion extraction → memory_tasks queue (IO-bound LLM calls) # Async emotion extraction → memory_tasks queue (IO-bound LLM calls)
'app.tasks.extract_emotion_batch': {'queue': 'memory_tasks'}, 'app.tasks.extract_emotion_batch': {'queue': 'memory_tasks'},
# Post-store dedup + alias merge → memory_tasks queue
'app.tasks.post_store_dedup_and_alias_merge': {'queue': 'memory_tasks'},
# Async metadata extraction → memory_tasks queue
'app.tasks.extract_metadata_batch': {'queue': 'memory_tasks'},
# Document tasks → document_tasks queue (prefork worker) # Document tasks → document_tasks queue (prefork worker)
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'}, 'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
'app.core.rag.tasks.sync_knowledge_for_kb': {'queue': 'document_tasks'}, 'app.core.rag.tasks.sync_knowledge_for_kb': {'queue': 'document_tasks'},

View File

@@ -252,7 +252,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
# TODO: fact_summary functionality temporarily disabled, will be enabled after future development # TODO: fact_summary functionality temporarily disabled, will be enabled after future development
fields_to_remove = { fields_to_remove = {
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids', 'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
'expired_at', 'created_at', 'chunk_id', 'apply_id', 'created_at', 'chunk_id', 'apply_id',
'user_id', 'statement_ids', 'updated_at', "chunk_ids", "fact_summary" 'user_id', 'statement_ids', 'updated_at', "chunk_ids", "fact_summary"
} }
# 注意:'id' 字段保留community 展开时需要用 community id 查询成员 statements # 注意:'id' 字段保留community 展开时需要用 community id 查询成员 statements

View File

@@ -16,7 +16,7 @@ logger = get_agent_logger(__name__)
# 需要从展开结果中过滤的字段(含 Neo4j DateTime不可 JSON 序列化) # 需要从展开结果中过滤的字段(含 Neo4j DateTime不可 JSON 序列化)
_EXPAND_FIELDS_TO_REMOVE = { _EXPAND_FIELDS_TO_REMOVE = {
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids', 'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
'expired_at', 'created_at', 'chunk_id', 'apply_id', 'created_at', 'chunk_id', 'apply_id',
'user_id', 'statement_ids', 'updated_at', 'chunk_ids', 'fact_summary' 'user_id', 'statement_ids', 'updated_at', 'chunk_ids', 'fact_summary'
} }

View File

@@ -18,7 +18,7 @@ async def get_chunked_dialogs(
Args: Args:
chunker_strategy: The chunking strategy to use (default: RecursiveChunker) chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
end_user_id: Group identifier end_user_id: Group identifier
messages: Structured message list [{"role": "user", "content": "..."}, ...] messages: Structured message list [{"role": "user", "content": "...", "dialog_at": "..."}]
ref_id: Reference identifier ref_id: Reference identifier
config_id: Configuration ID for processing (used to load pruning config) config_id: Configuration ID for processing (used to load pruning config)
snapshot: Optional PipelineSnapshot instance for saving pruning output snapshot: Optional PipelineSnapshot instance for saving pruning output
@@ -47,7 +47,12 @@ async def get_chunked_dialogs(
raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}") raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}")
if content.strip(): if content.strip():
conversation_messages.append(ConversationMessage(role=role, msg=content.strip(), files=files)) conversation_messages.append(ConversationMessage(
role=role,
msg=content.strip(),
dialog_at=msg.get("dialog_at"),
files=files,
))
if not conversation_messages: if not conversation_messages:
raise ValueError("Message list cannot be empty after filtering") raise ValueError("Message list cannot be empty after filtering")
@@ -57,7 +62,7 @@ async def get_chunked_dialogs(
context=conversation_context, context=conversation_context,
ref_id=ref_id, ref_id=ref_id,
end_user_id=end_user_id, end_user_id=end_user_id,
config_id=config_id config_id=config_id,
) )
# step2: 语义剪枝步骤(在分块之前) # step2: 语义剪枝步骤(在分块之前)

View File

@@ -242,6 +242,7 @@ class ChunkerClient:
chunk = Chunk( chunk = Chunk(
content=f"{msg.role}: {sub_chunk_text}", content=f"{msg.role}: {sub_chunk_text}",
speaker=msg.role, # 直接继承角色 speaker=msg.role, # 直接继承角色
dialog_at=getattr(msg, "dialog_at", None),
metadata={ metadata={
"message_index": msg_idx, "message_index": msg_idx,
"message_role": msg.role, "message_role": msg.role,
@@ -257,6 +258,7 @@ class ChunkerClient:
chunk = Chunk( chunk = Chunk(
content=f"{msg.role}: {msg_content}", content=f"{msg.role}: {msg_content}",
speaker=msg.role, # 直接继承角色 speaker=msg.role, # 直接继承角色
dialog_at=getattr(msg, "dialog_at", None),
metadata={ metadata={
"message_index": msg_idx, "message_index": msg_idx,
"message_role": msg.role, "message_role": msg.role,

View File

@@ -62,7 +62,7 @@ class MemoryService:
"""写入记忆:对话 → 萃取 → 存储 → 聚类 → 摘要 """写入记忆:对话 → 萃取 → 存储 → 聚类 → 摘要
Args: Args:
messages: 结构化消息 [{"role": "user"/"assistant", "content": "..."}] messages: 结构化消息 [{"role": "user"/"assistant", "content": "...", "dialog_at": "..."}]
language: 语言 ("zh" | "en") language: 语言 ("zh" | "en")
ref_id: 引用 ID为空则自动生成 ref_id: 引用 ID为空则自动生成
is_pilot_run: 试运行模式(只萃取不写入) is_pilot_run: 试运行模式(只萃取不写入)

View File

@@ -106,7 +106,6 @@ class Edge(BaseModel):
end_user_id: End user ID for multi-tenancy end_user_id: End user ID for multi-tenancy
run_id: Unique identifier for the pipeline run that created this edge run_id: Unique identifier for the pipeline run that created this edge
created_at: Timestamp when the edge was created (system perspective) created_at: Timestamp when the edge was created (system perspective)
expired_at: Optional timestamp when the edge expires (system perspective)
""" """
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the edge.") id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the edge.")
source: str = Field(..., description="The ID of the source node.") source: str = Field(..., description="The ID of the source node.")
@@ -114,7 +113,6 @@ class Edge(BaseModel):
end_user_id: str = Field(..., description="The end user ID of the edge.") end_user_id: str = Field(..., description="The end user ID of the edge.")
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.") run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
created_at: datetime = Field(..., description="The valid time of the edge from system perspective.") created_at: datetime = Field(..., description="The valid time of the edge from system perspective.")
expired_at: Optional[datetime] = Field(default=None, description="The expired time of the edge from system perspective.")
class ChunkEdge(Edge): class ChunkEdge(Edge):
@@ -191,14 +189,12 @@ class Node(BaseModel):
end_user_id: End user ID for multi-tenancy end_user_id: End user ID for multi-tenancy
run_id: Unique identifier for the pipeline run that created this node run_id: Unique identifier for the pipeline run that created this node
created_at: Timestamp when the node was created (system perspective) created_at: Timestamp when the node was created (system perspective)
expired_at: Optional timestamp when the node expires (system perspective)
""" """
id: str = Field(..., description="The unique identifier for the node.") id: str = Field(..., description="The unique identifier for the node.")
name: str = Field(..., description="The name of the node.") name: str = Field(..., description="The name of the node.")
end_user_id: str = Field(..., description="The end user ID of the node.") end_user_id: str = Field(..., description="The end user ID of the node.")
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.") run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
created_at: datetime = Field(..., description="The valid time of the node from system perspective.") created_at: datetime = Field(..., description="The valid time of the node from system perspective.")
expired_at: Optional[datetime] = Field(None, description="The expired time of the node from system perspective.")
class DialogueNode(Node): class DialogueNode(Node):
@@ -284,6 +280,7 @@ class StatementNode(Node):
temporal_info: TemporalInfo = Field(..., description="Temporal information") temporal_info: TemporalInfo = Field(..., description="Temporal information")
valid_at: Optional[datetime] = Field(None, description="Temporal validity start") valid_at: Optional[datetime] = Field(None, description="Temporal validity start")
invalid_at: Optional[datetime] = Field(None, description="Temporal validity end") invalid_at: Optional[datetime] = Field(None, description="Temporal validity end")
dialog_at: Optional[datetime] = Field(None, description="Absolute timestamp of the conversation this statement belongs to")
# Embedding and other fields # Embedding and other fields
statement_embedding: Optional[List[float]] = Field(None, description="Statement embedding vector") statement_embedding: Optional[List[float]] = Field(None, description="Statement embedding vector")
@@ -319,7 +316,7 @@ class StatementNode(Node):
description="Total number of times this node has been accessed" description="Total number of times this node has been accessed"
) )
@field_validator('valid_at', 'invalid_at', mode='before') @field_validator('valid_at', 'invalid_at', 'dialog_at', mode='before')
@classmethod @classmethod
def validate_datetime(cls, v): def validate_datetime(cls, v):
"""使用通用的历史日期解析函数""" """使用通用的历史日期解析函数"""

View File

@@ -30,6 +30,7 @@ class ConversationMessage(BaseModel):
""" """
role: str = Field(..., description="The role of the speaker (e.g., 'user', 'assistant').") role: str = Field(..., description="The role of the speaker (e.g., 'user', 'assistant').")
msg: str = Field(..., description="The text content of the message.") msg: str = Field(..., description="The text content of the message.")
dialog_at: Optional[str] = Field(None, description="Absolute timestamp of this message (ISO 8601).")
files: list[tuple] = Field(default_factory=list, description="The file content of the message", exclude=True) files: list[tuple] = Field(default_factory=list, description="The file content of the message", exclude=True)
@@ -100,6 +101,7 @@ class Statement(BaseModel):
False, False,
description="Whether the statement reflects user's emotional state", description="Whether the statement reflects user's emotional state",
) )
dialog_at: Optional[str] = Field(None, description="Absolute timestamp of the source message (ISO 8601).")
class ConversationContext(BaseModel): class ConversationContext(BaseModel):
@@ -139,6 +141,7 @@ class Chunk(BaseModel):
statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.") statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.")
files: list[tuple] = Field(default_factory=list, description="List of files in the chunk.") files: list[tuple] = Field(default_factory=list, description="List of files in the chunk.")
chunk_embedding: Optional[List[float]] = Field(default=None, description="The embedding vector of the chunk.") chunk_embedding: Optional[List[float]] = Field(default=None, description="The embedding vector of the chunk.")
dialog_at: Optional[str] = Field(None, description="Absolute timestamp of the source message (ISO 8601).")
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.") metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.")
@classmethod @classmethod
@@ -155,6 +158,7 @@ class Chunk(BaseModel):
return cls( return cls(
content=f"{message.role}: {message.msg}", content=f"{message.role}: {message.msg}",
speaker=message.role, speaker=message.role,
dialog_at=message.dialog_at,
metadata=metadata or {} metadata=metadata or {}
) )
@@ -169,7 +173,6 @@ class DialogData(BaseModel):
ref_id: Reference ID linking to external dialog system ref_id: Reference ID linking to external dialog system
end_user_id: End user ID for multi-tenancy end_user_id: End user ID for multi-tenancy
created_at: Timestamp when the dialog was created created_at: Timestamp when the dialog was created
expired_at: Timestamp when the dialog expires (default: far future)
metadata: Additional metadata as key-value pairs metadata: Additional metadata as key-value pairs
chunks: List of chunks from the conversation chunks: List of chunks from the conversation
config_id: Configuration ID used to process this dialog config_id: Configuration ID used to process this dialog
@@ -184,7 +187,6 @@ class DialogData(BaseModel):
end_user_id: str = Field(default=..., description="End user ID of dialogue data") end_user_id: str = Field(default=..., description="End user ID of dialogue data")
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.") run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
created_at: datetime = Field(default_factory=datetime.now, description="The timestamp when the dialog was created.") created_at: datetime = Field(default_factory=datetime.now, description="The timestamp when the dialog was created.")
expired_at: datetime = Field(default_factory=lambda: datetime(9999, 12, 31), description="The timestamp when the dialog expires.")
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the dialog.") metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the dialog.")
chunks: List[Chunk] = Field(default_factory=list, description="A list of chunks from the conversation context.") chunks: List[Chunk] = Field(default_factory=list, description="A list of chunks from the conversation context.")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this dialog (integer or string)") config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this dialog (integer or string)")

View File

@@ -198,11 +198,13 @@ class WritePipeline:
chunked_dialogs = await self._preprocess(messages, ref_id) chunked_dialogs = await self._preprocess(messages, ref_id)
s.metadata(chunks=sum(len(d.chunks) for d in chunked_dialogs)) s.metadata(chunks=sum(len(d.chunks) for d in chunked_dialogs))
# Step 2: 萃取 - 知识提取 # Step 2: 萃取 - 知识提取 + 第一层去重 + 别名归并(内存侧)
async with bear.step(2, 5, "萃取", "知识提取") as s: async with bear.step(2, 5, "萃取", "知识提取") as s:
extraction_result = await self._extract( extraction_result = await self._extract(
chunked_dialogs, is_pilot_run chunked_dialogs, is_pilot_run
) )
# 别名归并(内存侧):在写入前完成,确保写入的数据已归并
self._merge_alias_in_memory(extraction_result)
stats = extraction_result.stats stats = extraction_result.stats
s.metadata( s.metadata(
entities=stats["entity_count"], entities=stats["entity_count"],
@@ -222,15 +224,8 @@ class WritePipeline:
async with bear.step(3, 5, "存储", "写入 Neo4j"): async with bear.step(3, 5, "存储", "写入 Neo4j"):
await self._store(extraction_result) await self._store(extraction_result)
# Step 3.2: 别名归并 # Step 3.5: 异步后处理(别名归并 Neo4j 侧 + 第二层去重 + 情绪 + 元数据)
async with bear.step(3, 5, "别名归并", "处理别名属于关系"): await self._post_store_async_tasks(extraction_result)
await self._merge_alias_belongs_to(extraction_result)
# Step 3.5: 异步情绪提取fire-and-forget需在 _store 之后确保 Statement 节点已存在)
await self._extract_emotion(getattr(self, "_emotion_statements", []))
# Step 3.6: 异步元数据提取fire-and-forget需在 _store 之后确保 Entity 节点已存在)
await self._extract_metadata(extraction_result)
# Step 4: 聚类 - 增量更新社区(异步,不阻塞) # Step 4: 聚类 - 增量更新社区(异步,不阻塞)
async with bear.step(4, 5, "聚类", "增量更新社区") as s: async with bear.step(4, 5, "聚类", "增量更新社区") as s:
@@ -359,16 +354,17 @@ class WritePipeline:
# Snapshot: 图节点和边(去重前) # Snapshot: 图节点和边(去重前)
recorder.record_graph_before_dedup(graph) recorder.record_graph_before_dedup(graph)
# step3: 两阶段去重消歧 # step3: 第一层去重消歧(同一轮对话内的实体碎片合并)
# 第二层Neo4j 联合去重)后移到 _store 之后异步执行
dedup_result = await run_dedup( dedup_result = await run_dedup(
entity_nodes=graph.entity_nodes, entity_nodes=graph.entity_nodes,
statement_entity_edges=graph.stmt_entity_edges, statement_entity_edges=graph.stmt_entity_edges,
entity_entity_edges=graph.entity_entity_edges, entity_entity_edges=graph.entity_entity_edges,
dialog_data_list=dialog_data_list, dialog_data_list=dialog_data_list,
pipeline_config=pipeline_config, pipeline_config=pipeline_config,
connector=self._neo4j_connector, connector=None,
llm_client=self._llm_client, llm_client=self._llm_client,
is_pilot_run=is_pilot_run, is_pilot_run=True,
progress_callback=self.progress_callback, progress_callback=self.progress_callback,
) )
@@ -455,29 +451,21 @@ class WritePipeline:
raise raise
# ────────────────────────────────────────────── # ──────────────────────────────────────────────
# Step 3.2: 别名归并 # Step 3.2: 别名归并(内存侧)
# ────────────────────────────────────────────── # ──────────────────────────────────────────────
async def _merge_alias_belongs_to(self, result: ExtractionResult) -> None: def _merge_alias_in_memory(self, result: ExtractionResult) -> None:
"""别名归并(内存侧):处理 predicate="别名属于" 的边。
在写入 Neo4j 之前执行,确保写入的数据已经完成别名归并:
- 将别名实体的 name 追加到目标实体的 aliases
- 将别名实体的 description 拼接到目标实体的 description
- 重定向指向别名节点的边到目标节点
纯内存操作,不涉及 Neo4j。
""" """
所有去重合并都可以使用这个这种统一的处理方式(未实现)
别名归并:处理 predicate="别名属于" 的 EXTRACTED_RELATIONSHIP 边。
对每条 source -[EXTRACTED_RELATIONSHIP {predicate:"别名属于"}]-> target 边:
- 将 source.name 追加到 target.aliases去重
- 将 source.description append 进 target.description_listlist 形式)
同时在内存中同步更新 ExtractionResult.entity_nodes保持内存与 Neo4j 一致。
失败不中断主流程。
"""
from app.repositories.neo4j.cypher_queries import (
MERGE_ALIAS_BELONGS_TO,
REDIRECT_ALIAS_EDGES,
)
ALIAS_PREDICATE = "别名属于" ALIAS_PREDICATE = "别名属于"
# 筛选出所有 predicate="别名属于" 的边
alias_edges = [ alias_edges = [
e e
for e in result.entity_entity_edges for e in result.entity_entity_edges
@@ -490,10 +478,7 @@ class WritePipeline:
return return
try: try:
# ── 1. 在内存中同步更新 entity_nodes ──
entity_map = {e.id: e for e in result.entity_nodes} entity_map = {e.id: e for e in result.entity_nodes}
# 构建 alias_id → target_id 映射(别名节点 → 用户节点)
alias_to_target: dict[str, str] = {} alias_to_target: dict[str, str] = {}
for edge in alias_edges: for edge in alias_edges:
@@ -513,7 +498,7 @@ class WritePipeline:
source_name source_name
] ]
# 将 source.description append 进 target.description追加,分号分隔) # 将 source.description 拼接到 target.description分号分隔,去重
src_desc = (source_node.description or "").strip() src_desc = (source_node.description or "").strip()
if src_desc: if src_desc:
tgt_desc = (target_node.description or "").strip() tgt_desc = (target_node.description or "").strip()
@@ -522,12 +507,11 @@ class WritePipeline:
f"{tgt_desc}{src_desc}" if tgt_desc else src_desc f"{tgt_desc}{src_desc}" if tgt_desc else src_desc
) )
# ── 1.1 内存中重定向指向别名节点的边到用户节点 ── # 重定向指向别名节点的边到目标节点
alias_ids = set(alias_to_target.keys()) alias_ids = set(alias_to_target.keys())
redirected_ee_count = 0 redirected_ee_count = 0
redirected_se_count = 0 redirected_se_count = 0
# 重定向 entity_entity_edges排除"别名属于"边本身)
for edge in result.entity_entity_edges: for edge in result.entity_entity_edges:
rel_type = getattr(edge, "relation_type", "") rel_type = getattr(edge, "relation_type", "")
if rel_type == ALIAS_PREDICATE: if rel_type == ALIAS_PREDICATE:
@@ -539,39 +523,101 @@ class WritePipeline:
edge.target = alias_to_target[edge.target] edge.target = alias_to_target[edge.target]
redirected_ee_count += 1 redirected_ee_count += 1
# 重定向 stmt_entity_edges陈述句 → 实体边)
for edge in result.stmt_entity_edges: for edge in result.stmt_entity_edges:
if edge.target in alias_ids: if edge.target in alias_ids:
edge.target = alias_to_target[edge.target] edge.target = alias_to_target[edge.target]
redirected_se_count += 1 redirected_se_count += 1
logger.info( logger.info(
f"[AliasMerge] 内存同步完成,处理 {len(alias_edges)}'别名属于' 边," f"[AliasMerge] 内存归并完成,处理 {len(alias_edges)}'别名属于' 边,"
f"重定向 entity_entity 边 {redirected_ee_count} 次," f"重定向 entity_entity 边 {redirected_ee_count} 次,"
f"重定向 stmt_entity 边 {redirected_se_count}" f"重定向 stmt_entity 边 {redirected_se_count}"
) )
# ── 2. 写入 Neo4j别名属性归并 ──
records = await self._neo4j_connector.execute_query(
MERGE_ALIAS_BELONGS_TO,
end_user_id=self.end_user_id,
)
merged_count = len(records) if records else 0
logger.info(f"[AliasMerge] Neo4j 别名归并完成,影响 {merged_count} 条记录")
# ── 3. 写入 Neo4j重定向指向别名节点的边到用户节点 ──
redirect_records = await self._neo4j_connector.execute_query(
REDIRECT_ALIAS_EDGES,
end_user_id=self.end_user_id,
)
redirect_count = len(redirect_records) if redirect_records else 0
logger.info(
f"[AliasMerge] Neo4j 边重定向完成,影响 {redirect_count} 条记录"
)
except Exception as e: except Exception as e:
logger.warning( logger.warning(
f"[AliasMerge] 别名归并失败(不影响主流程): {e}", exc_info=True f"[AliasMerge] 内存归并失败(不影响主流程): {e}", exc_info=True
)
# ──────────────────────────────────────────────
# Step 3.5: 异步后处理Neo4j 别名归并 + 第二层去重)
# ──────────────────────────────────────────────
async def _post_store_async_tasks(self, result: ExtractionResult) -> None:
"""提交写入后的异步 Celery 任务(全部 fire-and-forget失败不影响主流程
1. Neo4j 别名归并 + 第二层去重
2. 异步情绪提取
3. 异步元数据提取
"""
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.metadata_extractor import (
collect_user_entities_for_metadata,
)
llm_model_id = (
str(self.memory_config.llm_model_id)
if self.memory_config.llm_model_id
else None
)
recorder = getattr(self, "_recorder", None)
snapshot_dir = (
recorder.snapshot_dir
if recorder is not None and recorder.enabled
else None
)
# ── 1. Neo4j 别名归并 + 第二层去重 ──
self._submit_celery_task(
"PostStore",
"app.tasks.post_store_dedup_and_alias_merge",
{
"end_user_id": self.end_user_id,
"entity_ids": [e.id for e in result.entity_nodes],
"llm_model_id": llm_model_id,
},
)
# ── 2. 异步情绪提取 ──
emotion_statements = getattr(self, "_emotion_statements", [])
if emotion_statements and llm_model_id:
self._submit_celery_task(
"Emotion",
"app.tasks.extract_emotion_batch",
{
"statements": emotion_statements,
"llm_model_id": llm_model_id,
"language": self.language,
"snapshot_dir": snapshot_dir,
},
)
# ── 3. 异步元数据提取 ──
user_entities = collect_user_entities_for_metadata(result.entity_nodes)
if user_entities and llm_model_id:
self._submit_celery_task(
"Metadata",
"app.tasks.extract_metadata_batch",
{
"user_entities": user_entities,
"llm_model_id": llm_model_id,
"language": self.language,
"snapshot_dir": snapshot_dir,
},
)
def _submit_celery_task(
self, label: str, task_name: str, kwargs: dict
) -> None:
"""提交 Celery 异步任务的通用方法。失败只记日志,不抛异常。"""
try:
from app.celery_app import celery_app
task_result = celery_app.send_task(task_name, kwargs=kwargs)
logger.info(f"[{label}] 异步任务已提交 - task_id={task_result.id}")
except Exception as e:
logger.error(
f"[{label}] 提交异步任务失败(不影响主流程): {e}",
exc_info=True,
) )
# ────────────────────────────────────────────── # ──────────────────────────────────────────────
@@ -625,127 +671,6 @@ class WritePipeline:
exc_info=True, exc_info=True,
) )
# ──────────────────────────────────────────────
# Step 4.5: 异步情绪提取
# fire-and-forget 提交 Celery 任务,不阻塞主流程
# ──────────────────────────────────────────────
async def _extract_emotion(self, emotion_statements: list) -> None:
"""提交异步情绪提取 Celery 任务。
从编排器收集的 user statement 列表中提取情绪,
异步回写到 Neo4j Statement 节点。失败不影响主流程。
在 PIPELINE_SNAPSHOT_ENABLED=true 时,会把当前运行的快照目录路径
通过 snapshot_dir 透传给 Celery 任务worker 端在完成 LLM 抽取后,
将结果落盘到 <snapshot_dir>/4_emotion_outputs.json避免主进程重复调用 LLM。
"""
if not emotion_statements:
return
llm_model_id = (
str(self.memory_config.llm_model_id)
if self.memory_config.llm_model_id
else None
)
if not llm_model_id:
logger.warning("[Emotion] 无法提交情绪提取任务llm_model_id 为空")
return
# 快照目录:仅在 PIPELINE_SNAPSHOT_ENABLED=true 时非空,供 worker 端落盘
recorder = getattr(self, "_recorder", None)
snapshot_dir = (
recorder.snapshot_dir
if recorder is not None and recorder.enabled
else None
)
try:
from app.celery_app import celery_app
result = celery_app.send_task(
"app.tasks.extract_emotion_batch",
kwargs={
"statements": emotion_statements,
"llm_model_id": llm_model_id,
"language": self.language,
"snapshot_dir": snapshot_dir,
},
)
logger.info(
f"[Emotion] 异步情绪提取任务已提交 - "
f"task_id = {result.id}, "
f"statement_count = {len(emotion_statements)}, "
f"snapshot_dir = {snapshot_dir}, "
f"source=async"
)
except Exception as e:
logger.error(
f"[Emotion] 提交情绪提取任务失败(不影响主流程): {e}",
exc_info=True,
)
# ──────────────────────────────────────────────
# Step 3.6: 异步元数据提取
# fire-and-forget 提交 Celery 任务,不阻塞主流程
# ──────────────────────────────────────────────
async def _extract_metadata(self, result: ExtractionResult) -> None:
"""提交异步元数据提取 Celery 任务。
从去重后的用户实体 description 中提取结构化元数据,
异步回写到 Neo4j ExtractedEntity 节点。失败不影响主流程。
"""
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.metadata_extractor import (
collect_user_entities_for_metadata,
)
user_entities = collect_user_entities_for_metadata(result.entity_nodes)
if not user_entities:
return
llm_model_id = (
str(self.memory_config.llm_model_id)
if self.memory_config.llm_model_id
else None
)
if not llm_model_id:
logger.warning("[Metadata] 无法提交元数据提取任务llm_model_id 为空")
return
# 快照目录
recorder = getattr(self, "_recorder", None)
snapshot_dir = (
recorder.snapshot_dir
if recorder is not None and recorder.enabled
else None
)
try:
from app.celery_app import celery_app
task_result = celery_app.send_task(
"app.tasks.extract_metadata_batch",
kwargs={
"user_entities": user_entities,
"llm_model_id": llm_model_id,
"language": self.language,
"snapshot_dir": snapshot_dir,
},
)
logger.info(
f"[Metadata] 异步元数据提取任务已提交 - "
f"task_id = {task_result.id}, "
f"entity_count = {len(user_entities)}, "
f"snapshot_dir = {snapshot_dir}"
)
except Exception as e:
logger.error(
f"[Metadata] 提交元数据提取任务失败(不影响主流程): {e}",
exc_info=True,
)
# ────────────────────────────────────────────── # ──────────────────────────────────────────────
# Step 5: 摘要 # Step 5: 摘要
# + entity_description+ meta_data部分在此提取 # + entity_description+ meta_data部分在此提取

View File

@@ -183,14 +183,8 @@ def _merge_attribute(canonical: ExtractedEntityNode, ent: ExtractedEntityNode):
# 时间范围合并 # 时间范围合并
try: try:
# 统一使用 created_at / expired_at
if getattr(ent, "created_at", None) and getattr(canonical, "created_at", None) and ent.created_at < canonical.created_at: if getattr(ent, "created_at", None) and getattr(canonical, "created_at", None) and ent.created_at < canonical.created_at:
canonical.created_at = ent.created_at canonical.created_at = ent.created_at
if getattr(ent, "expired_at", None) and getattr(canonical, "expired_at", None):
if canonical.expired_at is None:
canonical.expired_at = ent.expired_at
elif ent.expired_at and ent.expired_at > canonical.expired_at:
canonical.expired_at = ent.expired_at
except Exception: except Exception:
pass pass

View File

@@ -65,7 +65,6 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode:
user_id=row.get("user_id") or "", user_id=row.get("user_id") or "",
apply_id=row.get("apply_id") or "", apply_id=row.get("apply_id") or "",
created_at=_parse_dt(row.get("created_at")), created_at=_parse_dt(row.get("created_at")),
expired_at=_parse_dt(row.get("expired_at")) if row.get("expired_at") else None,
entity_idx=int(row.get("entity_idx") or 0), entity_idx=int(row.get("entity_idx") or 0),
statement_id=row.get("statement_id") or "", statement_id=row.get("statement_id") or "",
entity_type=row.get("entity_type") or "", entity_type=row.get("entity_type") or "",

View File

@@ -1089,7 +1089,6 @@ class ExtractionOrchestrator:
content=dialog_data.context.content if dialog_data.context else "", content=dialog_data.context.content if dialog_data.context else "",
dialog_embedding=dialog_data.dialog_embedding if hasattr(dialog_data, 'dialog_embedding') else None, dialog_embedding=dialog_data.dialog_embedding if hasattr(dialog_data, 'dialog_embedding') else None,
created_at=dialog_data.created_at, created_at=dialog_data.created_at,
expired_at=dialog_data.expired_at,
metadata=dialog_data.metadata, metadata=dialog_data.metadata,
config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None, config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None,
) )
@@ -1109,7 +1108,6 @@ class ExtractionOrchestrator:
chunk_embedding=chunk.chunk_embedding, chunk_embedding=chunk.chunk_embedding,
sequence_number=chunk_idx, # 添加必需的 sequence_number 字段 sequence_number=chunk_idx, # 添加必需的 sequence_number 字段
created_at=dialog_data.created_at, created_at=dialog_data.created_at,
expired_at=dialog_data.expired_at,
metadata=chunk.metadata, metadata=chunk.metadata,
) )
chunk_nodes.append(chunk_node) chunk_nodes.append(chunk_node)
@@ -1175,7 +1173,6 @@ class ExtractionOrchestrator:
invalid_at=statement.temporal_validity.invalid_at if hasattr(statement, invalid_at=statement.temporal_validity.invalid_at if hasattr(statement,
'temporal_validity') and statement.temporal_validity else None, 'temporal_validity') and statement.temporal_validity else None,
created_at=dialog_data.created_at, created_at=dialog_data.created_at,
expired_at=dialog_data.expired_at,
config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None, config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None,
# Emotion fields # Emotion fields
emotion_type=getattr(statement, 'emotion_type', None), emotion_type=getattr(statement, 'emotion_type', None),
@@ -1232,7 +1229,6 @@ class ExtractionOrchestrator:
end_user_id=dialog_data.end_user_id, end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
created_at=dialog_data.created_at, created_at=dialog_data.created_at,
expired_at=dialog_data.expired_at,
config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None, config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None,
) )
entity_nodes.append(entity_node) entity_nodes.append(entity_node)
@@ -1269,7 +1265,6 @@ class ExtractionOrchestrator:
end_user_id=dialog_data.end_user_id, end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
created_at=dialog_data.created_at, created_at=dialog_data.created_at,
expired_at=dialog_data.expired_at,
valid_at=_tv.valid_at if _tv else None, valid_at=_tv.valid_at if _tv else None,
invalid_at=_tv.invalid_at if _tv else None, invalid_at=_tv.invalid_at if _tv else None,
) )

View File

@@ -912,6 +912,7 @@ class NewExtractionOrchestrator:
has_emotional_state=stmt_out.has_emotional_state, has_emotional_state=stmt_out.has_emotional_state,
triplet_extraction_info=triplet_info, triplet_extraction_info=triplet_info,
statement_embedding=stmt_embedding, statement_embedding=stmt_embedding,
dialog_at=getattr(chunk, "dialog_at", None),
**emotion_kwargs, **emotion_kwargs,
) )
new_statements.append(stmt) new_statements.append(stmt)

View File

@@ -215,7 +215,6 @@ async def _process_chunk_summary(
apply_id=dialog.end_user_id, apply_id=dialog.end_user_id,
run_id=dialog.run_id, # 使用 dialog 的 run_id run_id=dialog.run_id, # 使用 dialog 的 run_id
created_at=datetime.now(), created_at=datetime.now(),
expired_at=datetime(9999, 12, 31),
dialog_id=dialog.id, dialog_id=dialog.id,
chunk_ids=[chunk.id], chunk_ids=[chunk.id],
content=summary_text, content=summary_text,

View File

@@ -181,6 +181,7 @@ class StatementExtractor:
chunk_id=chunk.id, chunk_id=chunk.id,
end_user_id=end_user_id, end_user_id=end_user_id,
speaker=chunk_speaker, speaker=chunk_speaker,
dialog_at=getattr(chunk, "dialog_at", None),
has_unsolved_reference=getattr(extracted_stmt, "has_unsolved_reference", False), has_unsolved_reference=getattr(extracted_stmt, "has_unsolved_reference", False),
) )

View File

@@ -135,7 +135,6 @@ async def build_graph_nodes_and_edges(
content=dialog_data.context.content if dialog_data.context else "", content=dialog_data.context.content if dialog_data.context else "",
dialog_embedding=dialog_data.dialog_embedding if hasattr(dialog_data, "dialog_embedding") else None, dialog_embedding=dialog_data.dialog_embedding if hasattr(dialog_data, "dialog_embedding") else None,
created_at=dialog_data.created_at, created_at=dialog_data.created_at,
expired_at=dialog_data.expired_at,
metadata=dialog_data.metadata, metadata=dialog_data.metadata,
config_id=dialog_data.config_id if hasattr(dialog_data, "config_id") else None, config_id=dialog_data.config_id if hasattr(dialog_data, "config_id") else None,
) )
@@ -154,7 +153,6 @@ async def build_graph_nodes_and_edges(
chunk_embedding=chunk.chunk_embedding, chunk_embedding=chunk.chunk_embedding,
sequence_number=chunk_idx, sequence_number=chunk_idx,
created_at=dialog_data.created_at, created_at=dialog_data.created_at,
expired_at=dialog_data.expired_at,
metadata=chunk.metadata, metadata=chunk.metadata,
) )
chunk_nodes.append(chunk_node) chunk_nodes.append(chunk_node)
@@ -227,7 +225,7 @@ async def build_graph_nodes_and_edges(
else None else None
), ),
created_at=dialog_data.created_at, created_at=dialog_data.created_at,
expired_at=dialog_data.expired_at, dialog_at=getattr(statement, "dialog_at", None),
config_id=dialog_data.config_id if hasattr(dialog_data, "config_id") else None, config_id=dialog_data.config_id if hasattr(dialog_data, "config_id") else None,
emotion_type=getattr(statement, "emotion_type", None), emotion_type=getattr(statement, "emotion_type", None),
emotion_intensity=getattr(statement, "emotion_intensity", None), emotion_intensity=getattr(statement, "emotion_intensity", None),
@@ -280,7 +278,6 @@ async def build_graph_nodes_and_edges(
end_user_id=dialog_data.end_user_id, end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id, run_id=dialog_data.run_id,
created_at=dialog_data.created_at, created_at=dialog_data.created_at,
expired_at=dialog_data.expired_at,
config_id=dialog_data.config_id if hasattr(dialog_data, "config_id") else None, config_id=dialog_data.config_id if hasattr(dialog_data, "config_id") else None,
) )
entity_nodes.append(entity_node) entity_nodes.append(entity_node)
@@ -320,7 +317,6 @@ async def build_graph_nodes_and_edges(
end_user_id=dialog_data.end_user_id, end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id, run_id=dialog_data.run_id,
created_at=dialog_data.created_at, created_at=dialog_data.created_at,
expired_at=dialog_data.expired_at,
valid_at=_tv.valid_at if _tv else None, valid_at=_tv.valid_at if _tv else None,
invalid_at=_tv.invalid_at if _tv else None, invalid_at=_tv.invalid_at if _tv else None,
) )
@@ -382,7 +378,6 @@ async def build_graph_nodes_and_edges(
end_user_id=dialog_data.end_user_id, end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id, run_id=dialog_data.run_id,
created_at=dialog_data.created_at, created_at=dialog_data.created_at,
expired_at=dialog_data.expired_at,
pair_id=pair_id, pair_id=pair_id,
dialog_id=dialog_data.id, dialog_id=dialog_data.id,
text=record["original_text"], text=record["original_text"],
@@ -408,7 +403,6 @@ async def build_graph_nodes_and_edges(
end_user_id=dialog_data.end_user_id, end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id, run_id=dialog_data.run_id,
created_at=dialog_data.created_at, created_at=dialog_data.created_at,
expired_at=dialog_data.expired_at,
pair_id=pair_id, pair_id=pair_id,
dialog_id=dialog_data.id, dialog_id=dialog_data.id,
text=record["pruned_text"], text=record["pruned_text"],

View File

@@ -483,7 +483,7 @@ class ReflectionEngine:
result_data['memory_verifies'] = memory_verifies result_data['memory_verifies'] = memory_verifies
result_data['quality_assessments'] = quality_assessments result_data['quality_assessments'] = quality_assessments
conflicts_found = 0 # Initialize as integer 0 instead of empty string conflicts_found = 0 # Initialize as integer 0 instead of empty string
REMOVE_KEYS = {"created_at", "expired_at","relationship","predicate","statement_id","id","statement_id","relationship_statement_id"} REMOVE_KEYS = {"created_at","relationship","predicate","statement_id","id","statement_id","relationship_statement_id"}
# Clean conflict_data, and memory_verify and quality_assessment # Clean conflict_data, and memory_verify and quality_assessment
cleaned_conflict_data = [] cleaned_conflict_data = []
for item in conflict_data: for item in conflict_data:

View File

@@ -26,7 +26,6 @@ async def _load_(data: List[Any]) -> List[Dict]:
"end_user_id", "end_user_id",
"chunk_id", "chunk_id",
"created_at", "created_at",
"expired_at",
"valid_at", "valid_at",
"invalid_at", "invalid_at",
] ]
@@ -93,7 +92,6 @@ async def get_data(result):
rel_filtered['run_id'] = value.get('run_id') rel_filtered['run_id'] = value.get('run_id')
rel_filtered['statement'] = value.get('statement') rel_filtered['statement'] = value.get('statement')
rel_filtered['statement_id'] = value.get('statement_id') rel_filtered['statement_id'] = value.get('statement_id')
rel_filtered['expired_at'] = value.get('expired_at')
rel_filtered['created_at'] = value.get('created_at') rel_filtered['created_at'] = value.get('created_at')
filtered_item[key] = value filtered_item[key] = value
elif key == 'entity2' and value is not None: elif key == 'entity2' and value is not None:

View File

@@ -37,7 +37,6 @@ async def add_chunk_statement_edges(chunks: List[Chunk], connector: Neo4jConnect
"apply_id": getattr(stmt, 'apply_id', None), "apply_id": getattr(stmt, 'apply_id', None),
"run_id": getattr(stmt, 'run_id', None) or getattr(chunk, 'run_id', None), "run_id": getattr(stmt, 'run_id', None) or getattr(chunk, 'run_id', None),
"created_at": getattr(stmt, 'created_at', None), "created_at": getattr(stmt, 'created_at', None),
"expired_at": getattr(stmt, 'expired_at', None),
# "created_at": getattr(statement, 'created_at', None), # "created_at": getattr(statement, 'created_at', None),
# "expired_at": None # Set to None or appropriate default # "expired_at": None # Set to None or appropriate default
} }
@@ -87,7 +86,6 @@ async def add_memory_summary_statement_edges(summaries: List[MemorySummaryNode],
"end_user_id": s.end_user_id, "end_user_id": s.end_user_id,
"run_id": s.run_id, "run_id": s.run_id,
"created_at": s.created_at.isoformat() if s.created_at else None, "created_at": s.created_at.isoformat() if s.created_at else None,
"expired_at": s.expired_at.isoformat() if s.expired_at else None,
}) })
if not edges: if not edges:

View File

@@ -42,7 +42,6 @@ async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConn
"ref_id": dialogue.ref_id, "ref_id": dialogue.ref_id,
"name": dialogue.name, "name": dialogue.name,
"created_at": dialogue.created_at.isoformat() if dialogue.created_at else None, "created_at": dialogue.created_at.isoformat() if dialogue.created_at else None,
"expired_at": dialogue.expired_at.isoformat() if dialogue.expired_at else None,
"content": dialogue.content, "content": dialogue.content,
"dialog_embedding": dialogue.dialog_embedding "dialog_embedding": dialogue.dialog_embedding
}) })
@@ -87,7 +86,6 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
"chunk_id": statement.chunk_id, "chunk_id": statement.chunk_id,
# "created_at": statement.created_at.isoformat(), # "created_at": statement.created_at.isoformat(),
"created_at": statement.created_at.isoformat() if statement.created_at else None, "created_at": statement.created_at.isoformat() if statement.created_at else None,
"expired_at": statement.expired_at.isoformat() if statement.expired_at else None,
"stmt_type": statement.stmt_type, "stmt_type": statement.stmt_type,
"temporal_info": statement.temporal_info.value, "temporal_info": statement.temporal_info.value,
"statement": statement.statement, "statement": statement.statement,
@@ -115,7 +113,8 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
"activation_value": statement.activation_value, "activation_value": statement.activation_value,
"access_history": statement.access_history if statement.access_history else [], "access_history": statement.access_history if statement.access_history else [],
"last_access_time": statement.last_access_time, "last_access_time": statement.last_access_time,
"access_count": statement.access_count "access_count": statement.access_count,
"dialog_at": statement.dialog_at.isoformat() if statement.dialog_at else None,
} }
flattened_statements.append(flattened_statement) flattened_statements.append(flattened_statement)
@@ -159,7 +158,6 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) ->
"end_user_id": chunk.end_user_id, "end_user_id": chunk.end_user_id,
"run_id": chunk.run_id, "run_id": chunk.run_id,
"created_at": chunk.created_at.isoformat() if chunk.created_at else None, "created_at": chunk.created_at.isoformat() if chunk.created_at else None,
"expired_at": chunk.expired_at.isoformat() if chunk.expired_at else None,
"dialog_id": chunk.dialog_id, "dialog_id": chunk.dialog_id,
"content": chunk.content, "content": chunk.content,
"chunk_embedding": chunk.chunk_embedding if chunk.chunk_embedding else None, "chunk_embedding": chunk.chunk_embedding if chunk.chunk_embedding else None,
@@ -211,7 +209,6 @@ async def add_memory_summary_nodes(
"end_user_id": s.end_user_id, "end_user_id": s.end_user_id,
"run_id": s.run_id, "run_id": s.run_id,
"created_at": s.created_at.isoformat() if s.created_at else None, "created_at": s.created_at.isoformat() if s.created_at else None,
"expired_at": s.expired_at.isoformat() if s.expired_at else None,
"dialog_id": s.dialog_id, "dialog_id": s.dialog_id,
"chunk_ids": s.chunk_ids, "chunk_ids": s.chunk_ids,
"content": s.content, "content": s.content,

View File

@@ -8,7 +8,6 @@ DIALOGUE_NODE_SAVE = """
n.run_id = dialogue.run_id, n.run_id = dialogue.run_id,
n.ref_id = dialogue.ref_id, n.ref_id = dialogue.ref_id,
n.created_at = dialogue.created_at, n.created_at = dialogue.created_at,
n.expired_at = dialogue.expired_at,
n.content = dialogue.content, n.content = dialogue.content,
n.dialog_embedding = dialogue.dialog_embedding n.dialog_embedding = dialogue.dialog_embedding
RETURN n.id AS uuid RETURN n.id AS uuid
@@ -32,7 +31,6 @@ SET s += {
emotion_keywords: statement.emotion_keywords, emotion_keywords: statement.emotion_keywords,
temporal_info: statement.temporal_info, temporal_info: statement.temporal_info,
created_at: statement.created_at, created_at: statement.created_at,
expired_at: statement.expired_at,
valid_at: coalesce(statement.valid_at, ""), valid_at: coalesce(statement.valid_at, ""),
invalid_at: coalesce(statement.invalid_at, ""), invalid_at: coalesce(statement.invalid_at, ""),
statement_embedding: statement.statement_embedding, statement_embedding: statement.statement_embedding,
@@ -41,7 +39,8 @@ SET s += {
activation_value: statement.activation_value, activation_value: statement.activation_value,
access_history: statement.access_history, access_history: statement.access_history,
last_access_time: statement.last_access_time, last_access_time: statement.last_access_time,
access_count: statement.access_count access_count: statement.access_count,
dialog_at: statement.dialog_at
} }
RETURN s.id AS uuid RETURN s.id AS uuid
""" """
@@ -64,7 +63,6 @@ SET c += {
end_user_id: chunk.end_user_id, end_user_id: chunk.end_user_id,
run_id: chunk.run_id, run_id: chunk.run_id,
created_at: chunk.created_at, created_at: chunk.created_at,
expired_at: chunk.expired_at,
dialog_id: chunk.dialog_id, dialog_id: chunk.dialog_id,
content: chunk.content, content: chunk.content,
speaker: chunk.speaker, speaker: chunk.speaker,
@@ -87,9 +85,6 @@ SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity
e.created_at = CASE e.created_at = CASE
WHEN entity.created_at IS NOT NULL AND (e.created_at IS NULL OR entity.created_at < e.created_at) WHEN entity.created_at IS NOT NULL AND (e.created_at IS NULL OR entity.created_at < e.created_at)
THEN entity.created_at ELSE e.created_at END, THEN entity.created_at ELSE e.created_at END,
e.expired_at = CASE
WHEN entity.expired_at IS NOT NULL AND (e.expired_at IS NULL OR entity.expired_at > e.expired_at)
THEN entity.expired_at ELSE e.expired_at END,
e.entity_idx = CASE WHEN e.entity_idx IS NULL OR e.entity_idx = 0 THEN entity.entity_idx ELSE e.entity_idx END, e.entity_idx = CASE WHEN e.entity_idx IS NULL OR e.entity_idx = 0 THEN entity.entity_idx ELSE e.entity_idx END,
e.entity_type = CASE WHEN entity.entity_type IS NOT NULL AND entity.entity_type <> '' THEN entity.entity_type ELSE e.entity_type END, e.entity_type = CASE WHEN entity.entity_type IS NOT NULL AND entity.entity_type <> '' THEN entity.entity_type ELSE e.entity_type END,
e.type_description = CASE WHEN entity.type_description IS NOT NULL AND entity.type_description <> '' THEN entity.type_description ELSE coalesce(e.type_description, '') END, e.type_description = CASE WHEN entity.type_description IS NOT NULL AND entity.type_description <> '' THEN entity.type_description ELSE coalesce(e.type_description, '') END,
@@ -214,12 +209,61 @@ SET r.predicate = rel.predicate,
r.valid_at = coalesce(rel.valid_at, ""), r.valid_at = coalesce(rel.valid_at, ""),
r.invalid_at = coalesce(rel.invalid_at, ""), r.invalid_at = coalesce(rel.invalid_at, ""),
r.created_at = rel.created_at, r.created_at = rel.created_at,
r.expired_at = rel.expired_at,
r.run_id = rel.run_id, r.run_id = rel.run_id,
r.end_user_id = rel.end_user_id r.end_user_id = rel.end_user_id
RETURN elementId(r) AS uuid RETURN elementId(r) AS uuid
""" """
# 在 Neo4j 5及后续版本中id() 函数已被标记为弃用用elementId() 函数替代
# 保存弱关系实体,设置 e.is_weak = true不维护 e.relations 聚合字段
WEAK_ENTITY_NODE_SAVE = """
UNWIND $weak_entities AS entity
MERGE (e:ExtractedEntity {id: entity.id, run_id: entity.run_id})
SET e += {
name: entity.name,
end_user_id: entity.end_user_id,
run_id: entity.run_id,
description: entity.description,
chunk_id: entity.chunk_id,
dialog_id: entity.dialog_id
}
// Independent weak flag仅标记弱关系不再维护 relations 聚合字段
SET e.is_weak = true
RETURN e.id AS id
"""
# 为强关系三元组中的主语和宾语创建/更新实体节点,仅设置 e.is_strong = true不维护 e.relations 字段
SAVE_STRONG_TRIPLE_ENTITIES = """
UNWIND $items AS item
MERGE (s:ExtractedEntity {id: item.source_id, run_id: item.run_id})
SET s += {name: item.subject, end_user_id: item.end_user_id, run_id: item.run_id}
// Independent strong flag
SET s.is_strong = true
MERGE (o:ExtractedEntity {id: item.target_id, run_id: item.run_id})
SET o += {name: item.object, end_user_id: item.end_user_id, run_id: item.run_id}
// Independent strong flag
SET o.is_strong = true
"""
DIALOGUE_STATEMENT_EDGE_SAVE = """
UNWIND $dialogue_statement_edges AS edge
// 支持按 uuid 或 ref_id 连接到 Dialogue避免因来源 ID 不一致而断链
MATCH (dialogue:Dialogue)
WHERE dialogue.uuid = edge.source OR dialogue.ref_id = edge.source
MATCH (statement:Statement {id: edge.target})
// 仅按端点去重,关系属性可更新
MERGE (dialogue)-[e:MENTIONS]->(statement)
SET e.uuid = edge.id,
e.end_user_id = edge.end_user_id,
e.created_at = edge.created_at
RETURN e.uuid AS uuid
"""
# 在 Neo4j 5及后续版本中id() 函数已被标记为弃用用elementId() 函数替代
CHUNK_STATEMENT_EDGE_SAVE = """ CHUNK_STATEMENT_EDGE_SAVE = """
UNWIND $chunk_statement_edges AS edge UNWIND $chunk_statement_edges AS edge
MATCH (statement:Statement {id: edge.source, run_id: edge.run_id}) MATCH (statement:Statement {id: edge.source, run_id: edge.run_id})
@@ -227,8 +271,7 @@ CHUNK_STATEMENT_EDGE_SAVE = """
MERGE (chunk)-[e:CONTAINS {id: edge.id}]->(statement) MERGE (chunk)-[e:CONTAINS {id: edge.id}]->(statement)
SET e.end_user_id = edge.end_user_id, SET e.end_user_id = edge.end_user_id,
e.run_id = edge.run_id, e.run_id = edge.run_id,
e.created_at = edge.created_at, e.created_at = edge.created_at
e.expired_at = edge.expired_at
RETURN e.id AS uuid RETURN e.id AS uuid
""" """
@@ -243,11 +286,89 @@ MERGE (statement)-[r:REFERENCES_ENTITY]->(entity)
SET r.end_user_id = rel.end_user_id, SET r.end_user_id = rel.end_user_id,
r.run_id = rel.run_id, r.run_id = rel.run_id,
r.created_at = rel.created_at, r.created_at = rel.created_at,
r.expired_at = rel.expired_at,
r.connect_strength = rel.connect_strength r.connect_strength = rel.connect_strength
RETURN elementId(r) AS uuid RETURN elementId(r) AS uuid
""" """
ENTITY_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('entity_embedding_index', $limit * 100, $embedding)
YIELD node AS e, score
WHERE e.name_embedding IS NOT NULL
AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
RETURN e.id AS id,
e.name AS name,
e.end_user_id AS end_user_id,
e.entity_type AS entity_type,
COALESCE(e.activation_value, e.importance_score, 0.5) AS activation_value,
COALESCE(e.importance_score, 0.5) AS importance_score,
e.last_access_time AS last_access_time,
COALESCE(e.access_count, 0) AS access_count,
score
ORDER BY score DESC
LIMIT $limit
"""
# Embedding-based search: cosine similarity on Statement.statement_embedding
STATEMENT_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('statement_embedding_index', $limit * 100, $embedding)
YIELD node AS s, score
WHERE s.statement_embedding IS NOT NULL
AND ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
RETURN s.id AS id,
s.statement AS statement,
s.end_user_id AS end_user_id,
s.chunk_id AS chunk_id,
s.created_at AS created_at,
s.valid_at AS valid_at,
s.invalid_at AS invalid_at,
COALESCE(s.activation_value, s.importance_score, 0.5) AS activation_value,
COALESCE(s.importance_score, 0.5) AS importance_score,
s.last_access_time AS last_access_time,
COALESCE(s.access_count, 0) AS access_count,
score
ORDER BY score DESC
LIMIT $limit
"""
# Embedding-based search: cosine similarity on Chunk.chunk_embedding
CHUNK_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('chunk_embedding_index', $limit * 100, $embedding)
YIELD node AS c, score
WHERE c.chunk_embedding IS NOT NULL
AND ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
RETURN c.id AS chunk_id,
c.end_user_id AS end_user_id,
c.content AS content,
c.dialog_id AS dialog_id,
COALESCE(c.activation_value, 0.5) AS activation_value,
c.last_access_time AS last_access_time,
COALESCE(c.access_count, 0) AS access_count,
score
ORDER BY score DESC
LIMIT $limit
"""
SEARCH_STATEMENTS_BY_KEYWORD = """
CALL db.index.fulltext.queryNodes("statementsFulltext", $query) YIELD node AS s, score
WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
RETURN s.id AS id,
s.statement AS statement,
s.end_user_id AS end_user_id,
s.chunk_id AS chunk_id,
s.created_at AS created_at,
s.valid_at AS valid_at,
s.invalid_at AS invalid_at,
c.id AS chunk_id_from_rel,
collect(DISTINCT e.id) AS entity_ids,
COALESCE(s.activation_value, s.importance_score, 0.5) AS activation_value,
COALESCE(s.importance_score, 0.5) AS importance_score,
s.last_access_time AS last_access_time,
COALESCE(s.access_count, 0) AS access_count,
score
ORDER BY score DESC
LIMIT $limit
"""
# 查询实体名称包含指定字符串的实体 # 查询实体名称包含指定字符串的实体
SEARCH_ENTITIES_BY_NAME = """ SEARCH_ENTITIES_BY_NAME = """
CALL db.index.fulltext.queryNodes("entitiesFulltext", $query) YIELD node AS e, score CALL db.index.fulltext.queryNodes("entitiesFulltext", $query) YIELD node AS e, score
@@ -259,7 +380,6 @@ RETURN e.id AS id,
e.end_user_id AS end_user_id, e.end_user_id AS end_user_id,
e.entity_type AS entity_type, e.entity_type AS entity_type,
e.created_at AS created_at, e.created_at AS created_at,
e.expired_at AS expired_at,
e.entity_idx AS entity_idx, e.entity_idx AS entity_idx,
e.statement_id AS statement_id, e.statement_id AS statement_id,
e.description AS description, e.description AS description,
@@ -279,6 +399,72 @@ ORDER BY score DESC
LIMIT $limit LIMIT $limit
""" """
SEARCH_ENTITIES_BY_NAME_OR_ALIAS = """
CALL db.index.fulltext.queryNodes("entitiesFulltext", $query) YIELD node AS e, score
WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
WITH e, score
With collect({entity: e, score: score}) AS fulltextResults
OPTIONAL MATCH (ae:ExtractedEntity)
WHERE ($end_user_id IS NULL OR ae.end_user_id = $end_user_id)
AND ae.aliases IS NOT NULL
AND ANY(alias IN ae.aliases WHERE toLower(alias) CONTAINS toLower($query))
WITH fulltextResults, collect(ae) AS aliasEntities
UNWIND (fulltextResults + [x IN aliasEntities | {entity: x, score:
CASE
WHEN ANY(alias IN x.aliases WHERE toLower(alias) = toLower($query)) THEN 1.0
WHEN ANY(alias IN x.aliases WHERE toLower(alias) STARTS WITH toLower($query)) THEN 0.9
ELSE 0.8
END
}]) AS row
WITH row.entity AS e, row.score AS score
WITH DISTINCT e, MAX(score) AS score
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
RETURN e.id AS id,
e.name AS name,
e.end_user_id AS end_user_id,
e.entity_type AS entity_type,
e.created_at AS created_at,
e.entity_idx AS entity_idx,
e.statement_id AS statement_id,
e.description AS description,
e.aliases AS aliases,
e.name_embedding AS name_embedding,
e.connect_strength AS connect_strength,
collect(DISTINCT s.id) AS statement_ids,
collect(DISTINCT c.id) AS chunk_ids,
COALESCE(e.activation_value, e.importance_score, 0.5) AS activation_value,
COALESCE(e.importance_score, 0.5) AS importance_score,
e.last_access_time AS last_access_time,
COALESCE(e.access_count, 0) AS access_count,
score
ORDER BY score DESC
LIMIT $limit
"""
SEARCH_CHUNKS_BY_CONTENT = """
CALL db.index.fulltext.queryNodes("chunksFulltext", $query) YIELD node AS c, score
WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
OPTIONAL MATCH (c)-[:CONTAINS]->(s:Statement)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
RETURN c.id AS chunk_id,
c.end_user_id AS end_user_id,
c.content AS content,
c.dialog_id AS dialog_id,
c.sequence_number AS sequence_number,
collect(DISTINCT s.id) AS statement_ids,
collect(DISTINCT e.id) AS entity_ids,
COALESCE(c.activation_value, 0.5) AS activation_value,
c.last_access_time AS last_access_time,
COALESCE(c.access_count, 0) AS access_count,
score
ORDER BY score DESC
LIMIT $limit
"""
# 以下是关于第二层去重消歧与数据库进行检索的语句,在最近的规划中不再使用 # 以下是关于第二层去重消歧与数据库进行检索的语句,在最近的规划中不再使用
# # 同组group_id下按“精确名字或别名+可选类型一致”来检索 # # 同组group_id下按“精确名字或别名+可选类型一致”来检索
@@ -332,8 +518,7 @@ WHERE ($end_user_id IS NULL OR d.end_user_id = $end_user_id)
RETURN d.id AS dialog_id, RETURN d.id AS dialog_id,
d.end_user_id AS end_user_id, d.end_user_id AS end_user_id,
d.content AS content, d.content AS content,
d.created_at AS created_at, d.created_at AS created_at
d.expired_at AS expired_at
ORDER BY d.created_at DESC ORDER BY d.created_at DESC
LIMIT $limit LIMIT $limit
""" """
@@ -347,7 +532,6 @@ RETURN c.id AS chunk_id,
c.content AS content, c.content AS content,
c.dialog_id AS dialog_id, c.dialog_id AS dialog_id,
c.created_at AS created_at, c.created_at AS created_at,
c.expired_at AS expired_at,
c.sequence_number AS sequence_number c.sequence_number AS sequence_number
ORDER BY c.created_at DESC ORDER BY c.created_at DESC
LIMIT $limit LIMIT $limit
@@ -560,7 +744,6 @@ SET m += {
end_user_id: summary.end_user_id, end_user_id: summary.end_user_id,
run_id: summary.run_id, run_id: summary.run_id,
created_at: summary.created_at, created_at: summary.created_at,
expired_at: summary.expired_at,
dialog_id: summary.dialog_id, dialog_id: summary.dialog_id,
chunk_ids: summary.chunk_ids, chunk_ids: summary.chunk_ids,
content: summary.content, content: summary.content,
@@ -584,8 +767,7 @@ MATCH (c)-[:CONTAINS]->(s:Statement {run_id: e.run_id})
MERGE (ms)-[r:DERIVED_FROM_STATEMENT]->(s) MERGE (ms)-[r:DERIVED_FROM_STATEMENT]->(s)
SET r.end_user_id = e.end_user_id, SET r.end_user_id = e.end_user_id,
r.run_id = e.run_id, r.run_id = e.run_id,
r.created_at = e.created_at, r.created_at = e.created_at
r.expired_at = e.expired_at
RETURN elementId(r) AS uuid RETURN elementId(r) AS uuid
""" """
@@ -614,8 +796,7 @@ FOREACH (rel IN CASE WHEN r IS NOT NULL THEN [r] ELSE [] END |
user_id: rel.user_id, user_id: rel.user_id,
apply_id: rel.apply_id, apply_id: rel.apply_id,
run_id: rel.run_id, run_id: rel.run_id,
created_at: rel.created_at, created_at: rel.created_at
expired_at: rel.expired_at
}]->(target) }]->(target)
) )
@@ -636,8 +817,7 @@ FOREACH (rel IN CASE WHEN r IS NOT NULL THEN [r] ELSE [] END |
user_id: rel.user_id, user_id: rel.user_id,
apply_id: rel.apply_id, apply_id: rel.apply_id,
run_id: rel.run_id, run_id: rel.run_id,
created_at: rel.created_at, created_at: rel.created_at
expired_at: rel.expired_at
}]->(canonical) }]->(canonical)
) )
@@ -678,7 +858,6 @@ neo4j_query_part = """
m.description as description, m.description as description,
m.statement_id as statement_id, m.statement_id as statement_id,
m.created_at as created_at, m.created_at as created_at,
m.expired_at as expired_at,
CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type, CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type,
elementId(rel) as rel_id, elementId(rel) as rel_id,
rel.predicate as predicate, rel.predicate as predicate,
@@ -698,7 +877,6 @@ neo4j_query_all = """
m.description as description, m.description as description,
m.statement_id as statement_id, m.statement_id as statement_id,
m.created_at as created_at, m.created_at as created_at,
m.expired_at as expired_at,
CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type, CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type,
elementId(rel) as rel_id, elementId(rel) as rel_id,
rel.predicate as predicate, rel.predicate as predicate,
@@ -1513,8 +1691,7 @@ SET o += {
dialog_id: orig.dialog_id, dialog_id: orig.dialog_id,
pair_id: orig.pair_id, pair_id: orig.pair_id,
text: orig.text, text: orig.text,
created_at: orig.created_at, created_at: orig.created_at
expired_at: orig.expired_at
} }
RETURN o.id AS uuid RETURN o.id AS uuid
""" """
@@ -1530,8 +1707,7 @@ SET pr += {
text: p.text, text: p.text,
memory_type: p.memory_type, memory_type: p.memory_type,
text_embedding: p.text_embedding, text_embedding: p.text_embedding,
created_at: p.created_at, created_at: p.created_at
expired_at: p.expired_at
} }
RETURN pr.id AS uuid RETURN pr.id AS uuid
""" """

View File

@@ -49,8 +49,6 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
# 处理datetime字段 # 处理datetime字段
if isinstance(n.get('created_at'), str): if isinstance(n.get('created_at'), str):
n['created_at'] = datetime.fromisoformat(n['created_at']) n['created_at'] = datetime.fromisoformat(n['created_at'])
if n.get('expired_at') and isinstance(n['expired_at'], str):
n['expired_at'] = datetime.fromisoformat(n['expired_at'])
return DialogueNode(**n) return DialogueNode(**n)

View File

@@ -48,8 +48,6 @@ class EntityRepository(BaseNeo4jRepository[ExtractedEntityNode]):
# 处理datetime字段 # 处理datetime字段
if isinstance(n.get('created_at'), str): if isinstance(n.get('created_at'), str):
n['created_at'] = datetime.fromisoformat(n['created_at']) n['created_at'] = datetime.fromisoformat(n['created_at'])
if n.get('expired_at') and isinstance(n.get('expired_at'), str):
n['expired_at'] = datetime.fromisoformat(n['expired_at'])
# 确保aliases字段存在且为列表 # 确保aliases字段存在且为列表
if 'aliases' not in n or n['aliases'] is None: if 'aliases' not in n or n['aliases'] is None:

View File

@@ -55,7 +55,6 @@ async def save_entities_and_relationships(
'valid_at': edge.valid_at.isoformat() if edge.valid_at else None, 'valid_at': edge.valid_at.isoformat() if edge.valid_at else None,
'invalid_at': edge.invalid_at.isoformat() if edge.invalid_at else None, 'invalid_at': edge.invalid_at.isoformat() if edge.invalid_at else None,
'created_at': edge.created_at.isoformat() if edge.created_at else None, 'created_at': edge.created_at.isoformat() if edge.created_at else None,
'expired_at': edge.expired_at.isoformat() if edge.expired_at else None,
'run_id': edge.run_id, 'run_id': edge.run_id,
'end_user_id': edge.end_user_id, 'end_user_id': edge.end_user_id,
} }
@@ -115,7 +114,6 @@ async def save_statement_chunk_edges(
"end_user_id": edge.end_user_id, "end_user_id": edge.end_user_id,
"run_id": edge.run_id, "run_id": edge.run_id,
"created_at": edge.created_at.isoformat() if edge.created_at else None, "created_at": edge.created_at.isoformat() if edge.created_at else None,
"expired_at": edge.expired_at.isoformat() if edge.expired_at else None,
}) })
try: try:
@@ -145,7 +143,6 @@ async def save_statement_entity_edges(
"run_id": edge.run_id, "run_id": edge.run_id,
"connect_strength": edge.connect_strength, "connect_strength": edge.connect_strength,
"created_at": edge.created_at.isoformat() if edge.created_at else None, "created_at": edge.created_at.isoformat() if edge.created_at else None,
"expired_at": edge.expired_at.isoformat() if edge.expired_at else None,
} }
all_se_edges.append(edge_data) all_se_edges.append(edge_data)
@@ -313,7 +310,6 @@ async def save_dialog_and_statements_to_neo4j(
'valid_at': edge.valid_at.isoformat() if edge.valid_at else None, 'valid_at': edge.valid_at.isoformat() if edge.valid_at else None,
'invalid_at': edge.invalid_at.isoformat() if edge.invalid_at else None, 'invalid_at': edge.invalid_at.isoformat() if edge.invalid_at else None,
'created_at': edge.created_at.isoformat() if edge.created_at else None, 'created_at': edge.created_at.isoformat() if edge.created_at else None,
'expired_at': edge.expired_at.isoformat() if edge.expired_at else None,
'run_id': edge.run_id, 'run_id': edge.run_id,
'end_user_id': edge.end_user_id, 'end_user_id': edge.end_user_id,
}) })
@@ -332,7 +328,6 @@ async def save_dialog_and_statements_to_neo4j(
"source": edge.source, "source": edge.source,
"target": edge.target, "target": edge.target,
"created_at": edge.created_at.isoformat() if edge.created_at else None, "created_at": edge.created_at.isoformat() if edge.created_at else None,
"expired_at": edge.expired_at.isoformat() if edge.expired_at else None,
"run_id": edge.run_id, "run_id": edge.run_id,
"end_user_id": edge.end_user_id, "end_user_id": edge.end_user_id,
}) })
@@ -350,7 +345,6 @@ async def save_dialog_and_statements_to_neo4j(
"source": edge.source, "source": edge.source,
"target": edge.target, "target": edge.target,
"created_at": edge.created_at.isoformat() if edge.created_at else None, "created_at": edge.created_at.isoformat() if edge.created_at else None,
"expired_at": edge.expired_at.isoformat() if edge.expired_at else None,
"run_id": edge.run_id, "run_id": edge.run_id,
"end_user_id": edge.end_user_id, "end_user_id": edge.end_user_id,
"connect_strength": getattr(edge, "connect_strength", "strong"), "connect_strength": getattr(edge, "connect_strength", "strong"),

View File

@@ -232,8 +232,6 @@ async def neo4j_data(solved_data):
updata_entity = {} updata_entity = {}
ori_edge = {} ori_edge = {}
updata_edge = {} updata_edge = {}
ori_expired_at={}
updat_expired_at={}
for i in solved_data: for i in solved_data:
databasets = i['data'] databasets = i['data']
for key, values in databasets.items(): for key, values in databasets.items():
@@ -247,12 +245,9 @@ async def neo4j_data(solved_data):
key = 'name' key = 'name'
ori_entity[key] = values[0] ori_entity[key] = values[0]
updata_entity[key] = values[1] updata_entity[key] = values[1]
ori_expired_at[key] = values[0]
if key == 'statement': if key == 'statement':
ori_edge[key] = values[0] ori_edge[key] = values[0]
updata_edge[key] = values[1] updata_edge[key] = values[1]
if key=='expired_at':
updat_expired_at[key] = values[1]
elif key == 'id': elif key == 'id':
ori_edge[key] = values ori_edge[key] = values
@@ -260,8 +255,6 @@ async def neo4j_data(solved_data):
ori_entity[key] = values ori_entity[key] = values
updata_entity[key] = values updata_entity[key] = values
ori_expired_at[key] = values
elif key == 'rel_id': elif key == 'rel_id':
key='id' key='id'
ori_edge[key] = values ori_edge[key] = values
@@ -270,18 +263,12 @@ async def neo4j_data(solved_data):
ori_entity[key] = values ori_entity[key] = values
updata_entity[key] = values updata_entity[key] = values
ori_expired_at[key] = values
print(ori_entity) print(ori_entity)
print(updata_entity) print(updata_entity)
print(100*'-') print(100*'-')
print(ori_edge) print(ori_edge)
print(updata_edge) print(updata_edge)
expired_at_ = updat_expired_at.get('expired_at', None)
if expired_at_ is not None:
await update_neo4j_data(ori_expired_at, updat_expired_at)
success_count += 1
if ori_entity != updata_entity: if ori_entity != updata_entity:
await update_neo4j_data(ori_entity, updata_entity) await update_neo4j_data(ori_entity, updata_entity)
success_count += 1 success_count += 1

View File

@@ -50,12 +50,12 @@ class StatementRepository(BaseNeo4jRepository[StatementNode]):
# 处理datetime字段 # 处理datetime字段
if isinstance(n.get('created_at'), str): if isinstance(n.get('created_at'), str):
n['created_at'] = datetime.fromisoformat(n['created_at']) n['created_at'] = datetime.fromisoformat(n['created_at'])
if n.get('expired_at') and isinstance(n['expired_at'], str):
n['expired_at'] = datetime.fromisoformat(n['expired_at'])
if n.get('valid_at') and isinstance(n['valid_at'], str): if n.get('valid_at') and isinstance(n['valid_at'], str):
n['valid_at'] = datetime.fromisoformat(n['valid_at']) n['valid_at'] = datetime.fromisoformat(n['valid_at'])
if n.get('invalid_at') and isinstance(n['invalid_at'], str): if n.get('invalid_at') and isinstance(n['invalid_at'], str):
n['invalid_at'] = datetime.fromisoformat(n['invalid_at']) n['invalid_at'] = datetime.fromisoformat(n['invalid_at'])
if n.get('dialog_at') and isinstance(n['dialog_at'], str):
n['dialog_at'] = datetime.fromisoformat(n['dialog_at'])
# 处理temporal_info字段 # 处理temporal_info字段
if isinstance(n.get('temporal_info'), str): if isinstance(n.get('temporal_info'), str):

View File

@@ -1,6 +1,7 @@
from abc import ABC from abc import ABC
from enum import Enum from enum import Enum
from typing import Any, Optional from typing import Any, Optional
from pydantic import Field
from pydantic import BaseModel from pydantic import BaseModel
@@ -21,6 +22,10 @@ class MessageItem(BaseModel):
"""单条消息结构""" """单条消息结构"""
role: str role: str
content: str content: str
dialog_at: Optional[str] = Field(
None,
description="该条消息发生的绝对时间ISO 8601 格式),不传则使用服务端当前时间",
)
files: Optional[list[dict]] = None files: Optional[list[dict]] = None
file_content: Optional[list[Any]] = None file_content: Optional[list[Any]] = None

View File

@@ -457,7 +457,9 @@ class MemoryAgentService:
if use_new_pipeline: if use_new_pipeline:
service = MemoryService(memory_config=memory_config, end_user_id=end_user_id) service = MemoryService(memory_config=memory_config, end_user_id=end_user_id)
result = await service.write(messages=messages_dict, language=language, ref_id='') result = await service.write(
messages=messages_dict, language=language, ref_id='',
)
logger.info( logger.info(
f"[NewPipeline] 完成: status={result.status}, " f"[NewPipeline] 完成: status={result.status}, "
f"elapsed={result.elapsed_seconds:.2f}s, " f"elapsed={result.elapsed_seconds:.2f}s, "

View File

@@ -1564,6 +1564,186 @@ def extract_emotion_batch_task(
_shutdown_loop_gracefully(loop) _shutdown_loop_gracefully(loop)
@celery_app.task(
bind=True,
name="app.tasks.post_store_dedup_and_alias_merge",
max_retries=1,
default_retry_delay=30,
)
def post_store_dedup_and_alias_merge_task(
self,
end_user_id: str,
entity_ids: List[str],
llm_model_id: Optional[str] = None,
) -> Dict[str, Any]:
"""Celery task: 写入后异步执行 Neo4j 别名归并 + 第二层去重。
在主写入流水线将第一层去重结果写入 Neo4j 之后执行:
1. Neo4j 别名归并:将 "别名属于" 边的 source.name 合并到 target.aliases
2. Neo4j 边重定向:将指向别名节点的边重定向到目标节点
3. 第二层去重:与 Neo4j 中已有的同组实体做联合去重
Args:
end_user_id: 终端用户 ID
entity_ids: 本轮写入的实体 ID 列表(用于第二层去重的候选检索)
llm_model_id: LLM 模型 UUID用于第二层去重的 LLM 兜底判定)
"""
task_id = self.request.id
logger.info(
f"[PostStore] 开始异步别名归并+第二层去重: "
f"end_user_id={end_user_id}, entity_count={len(entity_ids)}, "
f"task_id={task_id}"
)
start_time = time.time()
async def _run() -> Dict[str, Any]:
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.repositories.neo4j.cypher_queries import (
MERGE_ALIAS_BELONGS_TO,
REDIRECT_ALIAS_EDGES,
)
connector = Neo4jConnector()
result_info: Dict[str, Any] = {}
try:
# ── 1. Neo4j 别名归并 ──
try:
records = await connector.execute_query(
MERGE_ALIAS_BELONGS_TO,
end_user_id=end_user_id,
)
merged_count = len(records) if records else 0
result_info["alias_merged"] = merged_count
logger.info(f"[PostStore] Neo4j 别名归并完成,影响 {merged_count} 条记录")
except Exception as e:
logger.warning(f"[PostStore] Neo4j 别名归并失败: {e}")
result_info["alias_merge_error"] = str(e)
# ── 2. Neo4j 边重定向 ──
try:
redirect_records = await connector.execute_query(
REDIRECT_ALIAS_EDGES,
end_user_id=end_user_id,
)
redirect_count = len(redirect_records) if redirect_records else 0
result_info["edges_redirected"] = redirect_count
logger.info(f"[PostStore] Neo4j 边重定向完成,影响 {redirect_count} 条记录")
except Exception as e:
logger.warning(f"[PostStore] Neo4j 边重定向失败: {e}")
result_info["redirect_error"] = str(e)
# ── 3. 第二层去重(与 Neo4j 已有实体联合去重) ──
try:
from app.core.memory.storage_services.extraction_engine.deduplication.second_layer_dedup import (
second_layer_dedup_and_merge_with_neo4j,
)
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
clean_cross_role_aliases,
)
from app.repositories.neo4j.cypher_queries import EXTRACTED_ENTITY_NODE_SAVE
# 从 Neo4j 加载本轮写入的实体(第一层去重后的结果)
load_query = """
UNWIND $entity_ids AS eid
MATCH (e:ExtractedEntity {id: eid})
RETURN e {.*} AS entity
"""
entity_records = await connector.execute_query(
load_query, entity_ids=entity_ids
)
if entity_records:
from app.core.memory.storage_services.extraction_engine.deduplication.second_layer_dedup import (
_row_to_entity,
)
current_entities = []
for rec in entity_records:
try:
entity_data = rec.get("entity") or rec
current_entities.append(_row_to_entity(entity_data))
except Exception:
pass
if current_entities:
# 构建 LLM client如果有 llm_model_id
llm_client = None
if llm_model_id:
try:
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(llm_model_id)
except Exception as e:
logger.warning(f"[PostStore] 构建 LLM client 失败,跳过 LLM 兜底: {e}")
fused_entities, _, _ = await second_layer_dedup_and_merge_with_neo4j(
connector=connector,
end_user_id=end_user_id,
entity_nodes=current_entities,
statement_entity_edges=[],
entity_entity_edges=[],
llm_client=llm_client,
)
# 清洗跨角色别名污染
clean_cross_role_aliases(fused_entities)
# 将融合后的实体回写 Neo4j
if fused_entities:
entity_data = [e.model_dump() for e in fused_entities]
await connector.execute_query(
EXTRACTED_ENTITY_NODE_SAVE, entities=entity_data
)
result_info["layer2_input"] = len(current_entities)
result_info["layer2_output"] = len(fused_entities)
logger.info(
f"[PostStore] 第二层去重完成: "
f"{len(current_entities)}{len(fused_entities)} 个实体"
)
else:
result_info["layer2_skipped"] = "no entities loaded"
else:
result_info["layer2_skipped"] = "no entity records found"
except Exception as e:
logger.warning(f"[PostStore] 第二层去重失败(不影响主流程): {e}", exc_info=True)
result_info["layer2_error"] = str(e)
finally:
await connector.close()
return result_info
loop = None
try:
loop = set_asyncio_event_loop()
result = loop.run_until_complete(_run())
elapsed = time.time() - start_time
logger.info(
f"[PostStore] 任务完成: {result}, 耗时={elapsed:.2f}s, task_id={task_id}"
)
return {
"status": "SUCCESS",
**result,
"elapsed_time": elapsed,
"task_id": task_id,
}
except Exception as e:
elapsed = time.time() - start_time
logger.error(
f"[PostStore] 任务失败: {e}, 耗时={elapsed:.2f}s",
exc_info=True,
)
raise self.retry(exc=e)
finally:
if loop:
_shutdown_loop_gracefully(loop)
@celery_app.task( @celery_app.task(
bind=True, bind=True,
name="app.tasks.extract_metadata_batch", name="app.tasks.extract_metadata_batch",

View File

@@ -156,5 +156,5 @@ testpaths = ["tests"]
python_files = ["test_*.py"] python_files = ["test_*.py"]
python_classes = ["Test*"] python_classes = ["Test*"]
python_functions = ["test_*"] python_functions = ["test_*"]
# 使用 anyio 作为异步测试后端 # 使用 asyncio 作为异步测试后端
anyio_backends = ["asyncio"] asyncio_mode = "auto"