refactor(memory): remove expired_at field and add dialog_at timestamp
Remove the deprecated expired_at field from all graph models, Neo4j Cypher queries, repositories, and pipeline code. In its place, add dialog_at on StatementNode to track the original dialog timestamp.

- Strip expired_at from DialogueNode, ChunkNode, StatementNode, ExtractedEntityNode, edges, and all Cypher queries
- Add dialog_at to MessageItem schema and propagate it through the extraction and graph build steps
- Extract emotion/metadata async submission from WritePipeline into a generic _submit_celery_task helper
- Add post_store_dedup_and_alias_merge Celery task for async alias merging and second-layer dedup after the Neo4j write
- Switch pytest async backend from anyio to asyncio_mode=auto
This commit is contained in:
@@ -117,6 +117,12 @@ celery_app.conf.update(
|
|||||||
# Async emotion extraction → memory_tasks queue (IO-bound LLM calls)
|
# Async emotion extraction → memory_tasks queue (IO-bound LLM calls)
|
||||||
'app.tasks.extract_emotion_batch': {'queue': 'memory_tasks'},
|
'app.tasks.extract_emotion_batch': {'queue': 'memory_tasks'},
|
||||||
|
|
||||||
|
# Post-store dedup + alias merge → memory_tasks queue
|
||||||
|
'app.tasks.post_store_dedup_and_alias_merge': {'queue': 'memory_tasks'},
|
||||||
|
|
||||||
|
# Async metadata extraction → memory_tasks queue
|
||||||
|
'app.tasks.extract_metadata_batch': {'queue': 'memory_tasks'},
|
||||||
|
|
||||||
# Document tasks → document_tasks queue (prefork worker)
|
# Document tasks → document_tasks queue (prefork worker)
|
||||||
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
|
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
|
||||||
'app.core.rag.tasks.sync_knowledge_for_kb': {'queue': 'document_tasks'},
|
'app.core.rag.tasks.sync_knowledge_for_kb': {'queue': 'document_tasks'},
|
||||||
|
|||||||
@@ -252,7 +252,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
|||||||
# TODO: fact_summary functionality temporarily disabled, will be enabled after future development
|
# TODO: fact_summary functionality temporarily disabled, will be enabled after future development
|
||||||
fields_to_remove = {
|
fields_to_remove = {
|
||||||
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
||||||
'expired_at', 'created_at', 'chunk_id', 'apply_id',
|
'created_at', 'chunk_id', 'apply_id',
|
||||||
'user_id', 'statement_ids', 'updated_at', "chunk_ids", "fact_summary"
|
'user_id', 'statement_ids', 'updated_at', "chunk_ids", "fact_summary"
|
||||||
}
|
}
|
||||||
# 注意:'id' 字段保留,community 展开时需要用 community id 查询成员 statements
|
# 注意:'id' 字段保留,community 展开时需要用 community id 查询成员 statements
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ logger = get_agent_logger(__name__)
|
|||||||
# 需要从展开结果中过滤的字段(含 Neo4j DateTime,不可 JSON 序列化)
|
# 需要从展开结果中过滤的字段(含 Neo4j DateTime,不可 JSON 序列化)
|
||||||
_EXPAND_FIELDS_TO_REMOVE = {
|
_EXPAND_FIELDS_TO_REMOVE = {
|
||||||
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
||||||
'expired_at', 'created_at', 'chunk_id', 'apply_id',
|
'created_at', 'chunk_id', 'apply_id',
|
||||||
'user_id', 'statement_ids', 'updated_at', 'chunk_ids', 'fact_summary'
|
'user_id', 'statement_ids', 'updated_at', 'chunk_ids', 'fact_summary'
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ async def get_chunked_dialogs(
|
|||||||
Args:
|
Args:
|
||||||
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
|
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
|
||||||
end_user_id: Group identifier
|
end_user_id: Group identifier
|
||||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
messages: Structured message list [{"role": "user", "content": "...", "dialog_at": "..."}]
|
||||||
ref_id: Reference identifier
|
ref_id: Reference identifier
|
||||||
config_id: Configuration ID for processing (used to load pruning config)
|
config_id: Configuration ID for processing (used to load pruning config)
|
||||||
snapshot: Optional PipelineSnapshot instance for saving pruning output
|
snapshot: Optional PipelineSnapshot instance for saving pruning output
|
||||||
@@ -47,7 +47,12 @@ async def get_chunked_dialogs(
|
|||||||
raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}")
|
raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}")
|
||||||
|
|
||||||
if content.strip():
|
if content.strip():
|
||||||
conversation_messages.append(ConversationMessage(role=role, msg=content.strip(), files=files))
|
conversation_messages.append(ConversationMessage(
|
||||||
|
role=role,
|
||||||
|
msg=content.strip(),
|
||||||
|
dialog_at=msg.get("dialog_at"),
|
||||||
|
files=files,
|
||||||
|
))
|
||||||
|
|
||||||
if not conversation_messages:
|
if not conversation_messages:
|
||||||
raise ValueError("Message list cannot be empty after filtering")
|
raise ValueError("Message list cannot be empty after filtering")
|
||||||
@@ -57,7 +62,7 @@ async def get_chunked_dialogs(
|
|||||||
context=conversation_context,
|
context=conversation_context,
|
||||||
ref_id=ref_id,
|
ref_id=ref_id,
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
config_id=config_id
|
config_id=config_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# step2: 语义剪枝步骤(在分块之前)
|
# step2: 语义剪枝步骤(在分块之前)
|
||||||
|
|||||||
@@ -242,6 +242,7 @@ class ChunkerClient:
|
|||||||
chunk = Chunk(
|
chunk = Chunk(
|
||||||
content=f"{msg.role}: {sub_chunk_text}",
|
content=f"{msg.role}: {sub_chunk_text}",
|
||||||
speaker=msg.role, # 直接继承角色
|
speaker=msg.role, # 直接继承角色
|
||||||
|
dialog_at=getattr(msg, "dialog_at", None),
|
||||||
metadata={
|
metadata={
|
||||||
"message_index": msg_idx,
|
"message_index": msg_idx,
|
||||||
"message_role": msg.role,
|
"message_role": msg.role,
|
||||||
@@ -257,6 +258,7 @@ class ChunkerClient:
|
|||||||
chunk = Chunk(
|
chunk = Chunk(
|
||||||
content=f"{msg.role}: {msg_content}",
|
content=f"{msg.role}: {msg_content}",
|
||||||
speaker=msg.role, # 直接继承角色
|
speaker=msg.role, # 直接继承角色
|
||||||
|
dialog_at=getattr(msg, "dialog_at", None),
|
||||||
metadata={
|
metadata={
|
||||||
"message_index": msg_idx,
|
"message_index": msg_idx,
|
||||||
"message_role": msg.role,
|
"message_role": msg.role,
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ class MemoryService:
|
|||||||
"""写入记忆:对话 → 萃取 → 存储 → 聚类 → 摘要
|
"""写入记忆:对话 → 萃取 → 存储 → 聚类 → 摘要
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages: 结构化消息 [{"role": "user"/"assistant", "content": "..."}]
|
messages: 结构化消息 [{"role": "user"/"assistant", "content": "...", "dialog_at": "..."}]
|
||||||
language: 语言 ("zh" | "en")
|
language: 语言 ("zh" | "en")
|
||||||
ref_id: 引用 ID,为空则自动生成
|
ref_id: 引用 ID,为空则自动生成
|
||||||
is_pilot_run: 试运行模式(只萃取不写入)
|
is_pilot_run: 试运行模式(只萃取不写入)
|
||||||
|
|||||||
@@ -106,7 +106,6 @@ class Edge(BaseModel):
|
|||||||
end_user_id: End user ID for multi-tenancy
|
end_user_id: End user ID for multi-tenancy
|
||||||
run_id: Unique identifier for the pipeline run that created this edge
|
run_id: Unique identifier for the pipeline run that created this edge
|
||||||
created_at: Timestamp when the edge was created (system perspective)
|
created_at: Timestamp when the edge was created (system perspective)
|
||||||
expired_at: Optional timestamp when the edge expires (system perspective)
|
|
||||||
"""
|
"""
|
||||||
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the edge.")
|
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the edge.")
|
||||||
source: str = Field(..., description="The ID of the source node.")
|
source: str = Field(..., description="The ID of the source node.")
|
||||||
@@ -114,7 +113,6 @@ class Edge(BaseModel):
|
|||||||
end_user_id: str = Field(..., description="The end user ID of the edge.")
|
end_user_id: str = Field(..., description="The end user ID of the edge.")
|
||||||
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
||||||
created_at: datetime = Field(..., description="The valid time of the edge from system perspective.")
|
created_at: datetime = Field(..., description="The valid time of the edge from system perspective.")
|
||||||
expired_at: Optional[datetime] = Field(default=None, description="The expired time of the edge from system perspective.")
|
|
||||||
|
|
||||||
|
|
||||||
class ChunkEdge(Edge):
|
class ChunkEdge(Edge):
|
||||||
@@ -191,14 +189,12 @@ class Node(BaseModel):
|
|||||||
end_user_id: End user ID for multi-tenancy
|
end_user_id: End user ID for multi-tenancy
|
||||||
run_id: Unique identifier for the pipeline run that created this node
|
run_id: Unique identifier for the pipeline run that created this node
|
||||||
created_at: Timestamp when the node was created (system perspective)
|
created_at: Timestamp when the node was created (system perspective)
|
||||||
expired_at: Optional timestamp when the node expires (system perspective)
|
|
||||||
"""
|
"""
|
||||||
id: str = Field(..., description="The unique identifier for the node.")
|
id: str = Field(..., description="The unique identifier for the node.")
|
||||||
name: str = Field(..., description="The name of the node.")
|
name: str = Field(..., description="The name of the node.")
|
||||||
end_user_id: str = Field(..., description="The end user ID of the node.")
|
end_user_id: str = Field(..., description="The end user ID of the node.")
|
||||||
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
||||||
created_at: datetime = Field(..., description="The valid time of the node from system perspective.")
|
created_at: datetime = Field(..., description="The valid time of the node from system perspective.")
|
||||||
expired_at: Optional[datetime] = Field(None, description="The expired time of the node from system perspective.")
|
|
||||||
|
|
||||||
|
|
||||||
class DialogueNode(Node):
|
class DialogueNode(Node):
|
||||||
@@ -284,6 +280,7 @@ class StatementNode(Node):
|
|||||||
temporal_info: TemporalInfo = Field(..., description="Temporal information")
|
temporal_info: TemporalInfo = Field(..., description="Temporal information")
|
||||||
valid_at: Optional[datetime] = Field(None, description="Temporal validity start")
|
valid_at: Optional[datetime] = Field(None, description="Temporal validity start")
|
||||||
invalid_at: Optional[datetime] = Field(None, description="Temporal validity end")
|
invalid_at: Optional[datetime] = Field(None, description="Temporal validity end")
|
||||||
|
dialog_at: Optional[datetime] = Field(None, description="Absolute timestamp of the conversation this statement belongs to")
|
||||||
|
|
||||||
# Embedding and other fields
|
# Embedding and other fields
|
||||||
statement_embedding: Optional[List[float]] = Field(None, description="Statement embedding vector")
|
statement_embedding: Optional[List[float]] = Field(None, description="Statement embedding vector")
|
||||||
@@ -319,7 +316,7 @@ class StatementNode(Node):
|
|||||||
description="Total number of times this node has been accessed"
|
description="Total number of times this node has been accessed"
|
||||||
)
|
)
|
||||||
|
|
||||||
@field_validator('valid_at', 'invalid_at', mode='before')
|
@field_validator('valid_at', 'invalid_at', 'dialog_at', mode='before')
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_datetime(cls, v):
|
def validate_datetime(cls, v):
|
||||||
"""使用通用的历史日期解析函数"""
|
"""使用通用的历史日期解析函数"""
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ class ConversationMessage(BaseModel):
|
|||||||
"""
|
"""
|
||||||
role: str = Field(..., description="The role of the speaker (e.g., 'user', 'assistant').")
|
role: str = Field(..., description="The role of the speaker (e.g., 'user', 'assistant').")
|
||||||
msg: str = Field(..., description="The text content of the message.")
|
msg: str = Field(..., description="The text content of the message.")
|
||||||
|
dialog_at: Optional[str] = Field(None, description="Absolute timestamp of this message (ISO 8601).")
|
||||||
files: list[tuple] = Field(default_factory=list, description="The file content of the message", exclude=True)
|
files: list[tuple] = Field(default_factory=list, description="The file content of the message", exclude=True)
|
||||||
|
|
||||||
|
|
||||||
@@ -100,6 +101,7 @@ class Statement(BaseModel):
|
|||||||
False,
|
False,
|
||||||
description="Whether the statement reflects user's emotional state",
|
description="Whether the statement reflects user's emotional state",
|
||||||
)
|
)
|
||||||
|
dialog_at: Optional[str] = Field(None, description="Absolute timestamp of the source message (ISO 8601).")
|
||||||
|
|
||||||
|
|
||||||
class ConversationContext(BaseModel):
|
class ConversationContext(BaseModel):
|
||||||
@@ -139,6 +141,7 @@ class Chunk(BaseModel):
|
|||||||
statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.")
|
statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.")
|
||||||
files: list[tuple] = Field(default_factory=list, description="List of files in the chunk.")
|
files: list[tuple] = Field(default_factory=list, description="List of files in the chunk.")
|
||||||
chunk_embedding: Optional[List[float]] = Field(default=None, description="The embedding vector of the chunk.")
|
chunk_embedding: Optional[List[float]] = Field(default=None, description="The embedding vector of the chunk.")
|
||||||
|
dialog_at: Optional[str] = Field(None, description="Absolute timestamp of the source message (ISO 8601).")
|
||||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.")
|
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -155,6 +158,7 @@ class Chunk(BaseModel):
|
|||||||
return cls(
|
return cls(
|
||||||
content=f"{message.role}: {message.msg}",
|
content=f"{message.role}: {message.msg}",
|
||||||
speaker=message.role,
|
speaker=message.role,
|
||||||
|
dialog_at=message.dialog_at,
|
||||||
metadata=metadata or {}
|
metadata=metadata or {}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -169,7 +173,6 @@ class DialogData(BaseModel):
|
|||||||
ref_id: Reference ID linking to external dialog system
|
ref_id: Reference ID linking to external dialog system
|
||||||
end_user_id: End user ID for multi-tenancy
|
end_user_id: End user ID for multi-tenancy
|
||||||
created_at: Timestamp when the dialog was created
|
created_at: Timestamp when the dialog was created
|
||||||
expired_at: Timestamp when the dialog expires (default: far future)
|
|
||||||
metadata: Additional metadata as key-value pairs
|
metadata: Additional metadata as key-value pairs
|
||||||
chunks: List of chunks from the conversation
|
chunks: List of chunks from the conversation
|
||||||
config_id: Configuration ID used to process this dialog
|
config_id: Configuration ID used to process this dialog
|
||||||
@@ -184,7 +187,6 @@ class DialogData(BaseModel):
|
|||||||
end_user_id: str = Field(default=..., description="End user ID of dialogue data")
|
end_user_id: str = Field(default=..., description="End user ID of dialogue data")
|
||||||
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
||||||
created_at: datetime = Field(default_factory=datetime.now, description="The timestamp when the dialog was created.")
|
created_at: datetime = Field(default_factory=datetime.now, description="The timestamp when the dialog was created.")
|
||||||
expired_at: datetime = Field(default_factory=lambda: datetime(9999, 12, 31), description="The timestamp when the dialog expires.")
|
|
||||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the dialog.")
|
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the dialog.")
|
||||||
chunks: List[Chunk] = Field(default_factory=list, description="A list of chunks from the conversation context.")
|
chunks: List[Chunk] = Field(default_factory=list, description="A list of chunks from the conversation context.")
|
||||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this dialog (integer or string)")
|
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this dialog (integer or string)")
|
||||||
|
|||||||
@@ -198,11 +198,13 @@ class WritePipeline:
|
|||||||
chunked_dialogs = await self._preprocess(messages, ref_id)
|
chunked_dialogs = await self._preprocess(messages, ref_id)
|
||||||
s.metadata(chunks=sum(len(d.chunks) for d in chunked_dialogs))
|
s.metadata(chunks=sum(len(d.chunks) for d in chunked_dialogs))
|
||||||
|
|
||||||
# Step 2: 萃取 - 知识提取
|
# Step 2: 萃取 - 知识提取 + 第一层去重 + 别名归并(内存侧)
|
||||||
async with bear.step(2, 5, "萃取", "知识提取") as s:
|
async with bear.step(2, 5, "萃取", "知识提取") as s:
|
||||||
extraction_result = await self._extract(
|
extraction_result = await self._extract(
|
||||||
chunked_dialogs, is_pilot_run
|
chunked_dialogs, is_pilot_run
|
||||||
)
|
)
|
||||||
|
# 别名归并(内存侧):在写入前完成,确保写入的数据已归并
|
||||||
|
self._merge_alias_in_memory(extraction_result)
|
||||||
stats = extraction_result.stats
|
stats = extraction_result.stats
|
||||||
s.metadata(
|
s.metadata(
|
||||||
entities=stats["entity_count"],
|
entities=stats["entity_count"],
|
||||||
@@ -222,15 +224,8 @@ class WritePipeline:
|
|||||||
async with bear.step(3, 5, "存储", "写入 Neo4j"):
|
async with bear.step(3, 5, "存储", "写入 Neo4j"):
|
||||||
await self._store(extraction_result)
|
await self._store(extraction_result)
|
||||||
|
|
||||||
# Step 3.2: 别名归并
|
# Step 3.5: 异步后处理(别名归并 Neo4j 侧 + 第二层去重 + 情绪 + 元数据)
|
||||||
async with bear.step(3, 5, "别名归并", "处理别名属于关系"):
|
await self._post_store_async_tasks(extraction_result)
|
||||||
await self._merge_alias_belongs_to(extraction_result)
|
|
||||||
|
|
||||||
# Step 3.5: 异步情绪提取(fire-and-forget,需在 _store 之后确保 Statement 节点已存在)
|
|
||||||
await self._extract_emotion(getattr(self, "_emotion_statements", []))
|
|
||||||
|
|
||||||
# Step 3.6: 异步元数据提取(fire-and-forget,需在 _store 之后确保 Entity 节点已存在)
|
|
||||||
await self._extract_metadata(extraction_result)
|
|
||||||
|
|
||||||
# Step 4: 聚类 - 增量更新社区(异步,不阻塞)
|
# Step 4: 聚类 - 增量更新社区(异步,不阻塞)
|
||||||
async with bear.step(4, 5, "聚类", "增量更新社区") as s:
|
async with bear.step(4, 5, "聚类", "增量更新社区") as s:
|
||||||
@@ -359,16 +354,17 @@ class WritePipeline:
|
|||||||
# Snapshot: 图节点和边(去重前)
|
# Snapshot: 图节点和边(去重前)
|
||||||
recorder.record_graph_before_dedup(graph)
|
recorder.record_graph_before_dedup(graph)
|
||||||
|
|
||||||
# step3: 两阶段去重消歧
|
# step3: 第一层去重消歧(同一轮对话内的实体碎片合并)
|
||||||
|
# 第二层(Neo4j 联合去重)后移到 _store 之后异步执行
|
||||||
dedup_result = await run_dedup(
|
dedup_result = await run_dedup(
|
||||||
entity_nodes=graph.entity_nodes,
|
entity_nodes=graph.entity_nodes,
|
||||||
statement_entity_edges=graph.stmt_entity_edges,
|
statement_entity_edges=graph.stmt_entity_edges,
|
||||||
entity_entity_edges=graph.entity_entity_edges,
|
entity_entity_edges=graph.entity_entity_edges,
|
||||||
dialog_data_list=dialog_data_list,
|
dialog_data_list=dialog_data_list,
|
||||||
pipeline_config=pipeline_config,
|
pipeline_config=pipeline_config,
|
||||||
connector=self._neo4j_connector,
|
connector=None,
|
||||||
llm_client=self._llm_client,
|
llm_client=self._llm_client,
|
||||||
is_pilot_run=is_pilot_run,
|
is_pilot_run=True,
|
||||||
progress_callback=self.progress_callback,
|
progress_callback=self.progress_callback,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -455,29 +451,21 @@ class WritePipeline:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
# ──────────────────────────────────────────────
|
# ──────────────────────────────────────────────
|
||||||
# Step 3.2: 别名归并
|
# Step 3.2: 别名归并(内存侧)
|
||||||
# ──────────────────────────────────────────────
|
# ──────────────────────────────────────────────
|
||||||
|
|
||||||
async def _merge_alias_belongs_to(self, result: ExtractionResult) -> None:
|
def _merge_alias_in_memory(self, result: ExtractionResult) -> None:
|
||||||
|
"""别名归并(内存侧):处理 predicate="别名属于" 的边。
|
||||||
|
|
||||||
|
在写入 Neo4j 之前执行,确保写入的数据已经完成别名归并:
|
||||||
|
- 将别名实体的 name 追加到目标实体的 aliases
|
||||||
|
- 将别名实体的 description 拼接到目标实体的 description
|
||||||
|
- 重定向指向别名节点的边到目标节点
|
||||||
|
|
||||||
|
纯内存操作,不涉及 Neo4j。
|
||||||
"""
|
"""
|
||||||
所有去重合并都可以使用这个这种统一的处理方式(未实现)
|
|
||||||
别名归并:处理 predicate="别名属于" 的 EXTRACTED_RELATIONSHIP 边。
|
|
||||||
|
|
||||||
对每条 source -[EXTRACTED_RELATIONSHIP {predicate:"别名属于"}]-> target 边:
|
|
||||||
- 将 source.name 追加到 target.aliases(去重)
|
|
||||||
- 将 source.description append 进 target.description_list(list 形式)
|
|
||||||
|
|
||||||
同时在内存中同步更新 ExtractionResult.entity_nodes,保持内存与 Neo4j 一致。
|
|
||||||
失败不中断主流程。
|
|
||||||
"""
|
|
||||||
from app.repositories.neo4j.cypher_queries import (
|
|
||||||
MERGE_ALIAS_BELONGS_TO,
|
|
||||||
REDIRECT_ALIAS_EDGES,
|
|
||||||
)
|
|
||||||
|
|
||||||
ALIAS_PREDICATE = "别名属于"
|
ALIAS_PREDICATE = "别名属于"
|
||||||
|
|
||||||
# 筛选出所有 predicate="别名属于" 的边
|
|
||||||
alias_edges = [
|
alias_edges = [
|
||||||
e
|
e
|
||||||
for e in result.entity_entity_edges
|
for e in result.entity_entity_edges
|
||||||
@@ -490,10 +478,7 @@ class WritePipeline:
|
|||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# ── 1. 在内存中同步更新 entity_nodes ──
|
|
||||||
entity_map = {e.id: e for e in result.entity_nodes}
|
entity_map = {e.id: e for e in result.entity_nodes}
|
||||||
|
|
||||||
# 构建 alias_id → target_id 映射(别名节点 → 用户节点)
|
|
||||||
alias_to_target: dict[str, str] = {}
|
alias_to_target: dict[str, str] = {}
|
||||||
|
|
||||||
for edge in alias_edges:
|
for edge in alias_edges:
|
||||||
@@ -513,7 +498,7 @@ class WritePipeline:
|
|||||||
source_name
|
source_name
|
||||||
]
|
]
|
||||||
|
|
||||||
# 将 source.description append 进 target.description(追加,分号分隔)
|
# 将 source.description 拼接到 target.description(分号分隔,去重)
|
||||||
src_desc = (source_node.description or "").strip()
|
src_desc = (source_node.description or "").strip()
|
||||||
if src_desc:
|
if src_desc:
|
||||||
tgt_desc = (target_node.description or "").strip()
|
tgt_desc = (target_node.description or "").strip()
|
||||||
@@ -522,12 +507,11 @@ class WritePipeline:
|
|||||||
f"{tgt_desc};{src_desc}" if tgt_desc else src_desc
|
f"{tgt_desc};{src_desc}" if tgt_desc else src_desc
|
||||||
)
|
)
|
||||||
|
|
||||||
# ── 1.1 内存中重定向指向别名节点的边到用户节点 ──
|
# 重定向指向别名节点的边到目标节点
|
||||||
alias_ids = set(alias_to_target.keys())
|
alias_ids = set(alias_to_target.keys())
|
||||||
redirected_ee_count = 0
|
redirected_ee_count = 0
|
||||||
redirected_se_count = 0
|
redirected_se_count = 0
|
||||||
|
|
||||||
# 重定向 entity_entity_edges(排除"别名属于"边本身)
|
|
||||||
for edge in result.entity_entity_edges:
|
for edge in result.entity_entity_edges:
|
||||||
rel_type = getattr(edge, "relation_type", "")
|
rel_type = getattr(edge, "relation_type", "")
|
||||||
if rel_type == ALIAS_PREDICATE:
|
if rel_type == ALIAS_PREDICATE:
|
||||||
@@ -539,39 +523,101 @@ class WritePipeline:
|
|||||||
edge.target = alias_to_target[edge.target]
|
edge.target = alias_to_target[edge.target]
|
||||||
redirected_ee_count += 1
|
redirected_ee_count += 1
|
||||||
|
|
||||||
# 重定向 stmt_entity_edges(陈述句 → 实体边)
|
|
||||||
for edge in result.stmt_entity_edges:
|
for edge in result.stmt_entity_edges:
|
||||||
if edge.target in alias_ids:
|
if edge.target in alias_ids:
|
||||||
edge.target = alias_to_target[edge.target]
|
edge.target = alias_to_target[edge.target]
|
||||||
redirected_se_count += 1
|
redirected_se_count += 1
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[AliasMerge] 内存同步完成,处理 {len(alias_edges)} 条 '别名属于' 边,"
|
f"[AliasMerge] 内存归并完成,处理 {len(alias_edges)} 条 '别名属于' 边,"
|
||||||
f"重定向 entity_entity 边 {redirected_ee_count} 次,"
|
f"重定向 entity_entity 边 {redirected_ee_count} 次,"
|
||||||
f"重定向 stmt_entity 边 {redirected_se_count} 次"
|
f"重定向 stmt_entity 边 {redirected_se_count} 次"
|
||||||
)
|
)
|
||||||
|
|
||||||
# ── 2. 写入 Neo4j:别名属性归并 ──
|
|
||||||
records = await self._neo4j_connector.execute_query(
|
|
||||||
MERGE_ALIAS_BELONGS_TO,
|
|
||||||
end_user_id=self.end_user_id,
|
|
||||||
)
|
|
||||||
merged_count = len(records) if records else 0
|
|
||||||
logger.info(f"[AliasMerge] Neo4j 别名归并完成,影响 {merged_count} 条记录")
|
|
||||||
|
|
||||||
# ── 3. 写入 Neo4j:重定向指向别名节点的边到用户节点 ──
|
|
||||||
redirect_records = await self._neo4j_connector.execute_query(
|
|
||||||
REDIRECT_ALIAS_EDGES,
|
|
||||||
end_user_id=self.end_user_id,
|
|
||||||
)
|
|
||||||
redirect_count = len(redirect_records) if redirect_records else 0
|
|
||||||
logger.info(
|
|
||||||
f"[AliasMerge] Neo4j 边重定向完成,影响 {redirect_count} 条记录"
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"[AliasMerge] 别名归并失败(不影响主流程): {e}", exc_info=True
|
f"[AliasMerge] 内存归并失败(不影响主流程): {e}", exc_info=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────
|
||||||
|
# Step 3.5: 异步后处理(Neo4j 别名归并 + 第二层去重)
|
||||||
|
# ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _post_store_async_tasks(self, result: ExtractionResult) -> None:
    """Submit the post-store Celery tasks; each one is fire-and-forget.

    Three tasks are submitted through ``_submit_celery_task``:

    1. ``app.tasks.post_store_dedup_and_alias_merge`` - always submitted,
       with the IDs of every entity node in ``result``.
    2. ``app.tasks.extract_emotion_batch`` - submitted only when emotion
       statements were collected on ``self._emotion_statements``.
    3. ``app.tasks.extract_metadata_batch`` - submitted only when
       ``collect_user_entities_for_metadata`` returns a non-empty value.

    Tasks 2 and 3 need an LLM model ID. When there is work to do but
    ``memory_config.llm_model_id`` is empty, a warning is logged and the
    task is skipped, as the pre-refactor ``_extract_emotion`` and
    ``_extract_metadata`` did.

    Args:
        result: Extraction result whose ``entity_nodes`` feed the
            post-store and metadata tasks.
    """
    # Imported locally, as in the original block.
    from app.core.memory.storage_services.extraction_engine.knowledge_extraction.metadata_extractor import (
        collect_user_entities_for_metadata,
    )

    llm_model_id = (
        str(self.memory_config.llm_model_id)
        if self.memory_config.llm_model_id
        else None
    )

    # The snapshot directory is forwarded only when a recorder exists and is
    # enabled; otherwise the workers receive None.
    recorder = getattr(self, "_recorder", None)
    snapshot_dir = (
        recorder.snapshot_dir
        if recorder is not None and recorder.enabled
        else None
    )

    # 1. Neo4j-side alias merge + second-layer dedup.
    # NOTE(review): this task is submitted even when llm_model_id is None;
    # confirm the worker tolerates a missing model ID.
    self._submit_celery_task(
        "PostStore",
        "app.tasks.post_store_dedup_and_alias_merge",
        {
            "end_user_id": self.end_user_id,
            "entity_ids": [e.id for e in result.entity_nodes],
            "llm_model_id": llm_model_id,
        },
    )

    # 2. Async emotion extraction.
    emotion_statements = getattr(self, "_emotion_statements", [])
    if emotion_statements:
        if llm_model_id:
            self._submit_celery_task(
                "Emotion",
                "app.tasks.extract_emotion_batch",
                {
                    "statements": emotion_statements,
                    "llm_model_id": llm_model_id,
                    "language": self.language,
                    "snapshot_dir": snapshot_dir,
                },
            )
        else:
            # Make the skip visible instead of dropping the work silently.
            logger.warning("[Emotion] 无法提交情绪提取任务:llm_model_id 为空")

    # 3. Async metadata extraction.
    user_entities = collect_user_entities_for_metadata(result.entity_nodes)
    if user_entities:
        if llm_model_id:
            self._submit_celery_task(
                "Metadata",
                "app.tasks.extract_metadata_batch",
                {
                    "user_entities": user_entities,
                    "llm_model_id": llm_model_id,
                    "language": self.language,
                    "snapshot_dir": snapshot_dir,
                },
            )
        else:
            # Make the skip visible instead of dropping the work silently.
            logger.warning("[Metadata] 无法提交元数据提取任务:llm_model_id 为空")
|
||||||
|
|
||||||
|
def _submit_celery_task(
    self, label: str, task_name: str, kwargs: dict
) -> None:
    """Submit a Celery task by name; log failures instead of raising.

    Args:
        label: Short tag placed in square brackets at the start of each
            log line.
        task_name: Name of the Celery task passed to ``send_task``.
        kwargs: Keyword arguments forwarded to the task.
    """
    try:
        # Imported inside the try block, so an import failure is logged
        # like any other submission error.
        from app.celery_app import celery_app

        submitted = celery_app.send_task(task_name, kwargs=kwargs)
        task_id = submitted.id
        logger.info(f"[{label}] 异步任务已提交 - task_id={task_id}")
    except Exception as exc:
        failure_message = f"[{label}] 提交异步任务失败(不影响主流程): {exc}"
        logger.error(failure_message, exc_info=True)
|
||||||
|
|
||||||
# ──────────────────────────────────────────────
|
# ──────────────────────────────────────────────
|
||||||
@@ -625,127 +671,6 @@ class WritePipeline:
|
|||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# ──────────────────────────────────────────────
|
|
||||||
# Step 4.5: 异步情绪提取
|
|
||||||
# fire-and-forget 提交 Celery 任务,不阻塞主流程
|
|
||||||
# ──────────────────────────────────────────────
|
|
||||||
|
|
||||||
async def _extract_emotion(self, emotion_statements: list) -> None:
|
|
||||||
"""提交异步情绪提取 Celery 任务。
|
|
||||||
|
|
||||||
从编排器收集的 user statement 列表中提取情绪,
|
|
||||||
异步回写到 Neo4j Statement 节点。失败不影响主流程。
|
|
||||||
|
|
||||||
在 PIPELINE_SNAPSHOT_ENABLED=true 时,会把当前运行的快照目录路径
|
|
||||||
通过 snapshot_dir 透传给 Celery 任务;worker 端在完成 LLM 抽取后,
|
|
||||||
将结果落盘到 <snapshot_dir>/4_emotion_outputs.json,避免主进程重复调用 LLM。
|
|
||||||
"""
|
|
||||||
if not emotion_statements:
|
|
||||||
return
|
|
||||||
|
|
||||||
llm_model_id = (
|
|
||||||
str(self.memory_config.llm_model_id)
|
|
||||||
if self.memory_config.llm_model_id
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
if not llm_model_id:
|
|
||||||
logger.warning("[Emotion] 无法提交情绪提取任务:llm_model_id 为空")
|
|
||||||
return
|
|
||||||
|
|
||||||
# 快照目录:仅在 PIPELINE_SNAPSHOT_ENABLED=true 时非空,供 worker 端落盘
|
|
||||||
recorder = getattr(self, "_recorder", None)
|
|
||||||
snapshot_dir = (
|
|
||||||
recorder.snapshot_dir
|
|
||||||
if recorder is not None and recorder.enabled
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
from app.celery_app import celery_app
|
|
||||||
|
|
||||||
result = celery_app.send_task(
|
|
||||||
"app.tasks.extract_emotion_batch",
|
|
||||||
kwargs={
|
|
||||||
"statements": emotion_statements,
|
|
||||||
"llm_model_id": llm_model_id,
|
|
||||||
"language": self.language,
|
|
||||||
"snapshot_dir": snapshot_dir,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f"[Emotion] 异步情绪提取任务已提交 - "
|
|
||||||
f"task_id = {result.id}, "
|
|
||||||
f"statement_count = {len(emotion_statements)}, "
|
|
||||||
f"snapshot_dir = {snapshot_dir}, "
|
|
||||||
f"source=async"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"[Emotion] 提交情绪提取任务失败(不影响主流程): {e}",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# ──────────────────────────────────────────────
|
|
||||||
# Step 3.6: 异步元数据提取
|
|
||||||
# fire-and-forget 提交 Celery 任务,不阻塞主流程
|
|
||||||
# ──────────────────────────────────────────────
|
|
||||||
|
|
||||||
async def _extract_metadata(self, result: ExtractionResult) -> None:
|
|
||||||
"""提交异步元数据提取 Celery 任务。
|
|
||||||
|
|
||||||
从去重后的用户实体 description 中提取结构化元数据,
|
|
||||||
异步回写到 Neo4j ExtractedEntity 节点。失败不影响主流程。
|
|
||||||
"""
|
|
||||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.metadata_extractor import (
|
|
||||||
collect_user_entities_for_metadata,
|
|
||||||
)
|
|
||||||
|
|
||||||
user_entities = collect_user_entities_for_metadata(result.entity_nodes)
|
|
||||||
|
|
||||||
if not user_entities:
|
|
||||||
return
|
|
||||||
|
|
||||||
llm_model_id = (
|
|
||||||
str(self.memory_config.llm_model_id)
|
|
||||||
if self.memory_config.llm_model_id
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
if not llm_model_id:
|
|
||||||
logger.warning("[Metadata] 无法提交元数据提取任务:llm_model_id 为空")
|
|
||||||
return
|
|
||||||
|
|
||||||
# 快照目录
|
|
||||||
recorder = getattr(self, "_recorder", None)
|
|
||||||
snapshot_dir = (
|
|
||||||
recorder.snapshot_dir
|
|
||||||
if recorder is not None and recorder.enabled
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
from app.celery_app import celery_app
|
|
||||||
|
|
||||||
task_result = celery_app.send_task(
|
|
||||||
"app.tasks.extract_metadata_batch",
|
|
||||||
kwargs={
|
|
||||||
"user_entities": user_entities,
|
|
||||||
"llm_model_id": llm_model_id,
|
|
||||||
"language": self.language,
|
|
||||||
"snapshot_dir": snapshot_dir,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f"[Metadata] 异步元数据提取任务已提交 - "
|
|
||||||
f"task_id = {task_result.id}, "
|
|
||||||
f"entity_count = {len(user_entities)}, "
|
|
||||||
f"snapshot_dir = {snapshot_dir}"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"[Metadata] 提交元数据提取任务失败(不影响主流程): {e}",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# ──────────────────────────────────────────────
|
# ──────────────────────────────────────────────
|
||||||
# Step 5: 摘要
|
# Step 5: 摘要
|
||||||
# (+ entity_description)+ meta_data部分在此提取
|
# (+ entity_description)+ meta_data部分在此提取
|
||||||
|
|||||||
@@ -183,14 +183,8 @@ def _merge_attribute(canonical: ExtractedEntityNode, ent: ExtractedEntityNode):
|
|||||||
|
|
||||||
# 时间范围合并
|
# 时间范围合并
|
||||||
try:
|
try:
|
||||||
# 统一使用 created_at / expired_at
|
|
||||||
if getattr(ent, "created_at", None) and getattr(canonical, "created_at", None) and ent.created_at < canonical.created_at:
|
if getattr(ent, "created_at", None) and getattr(canonical, "created_at", None) and ent.created_at < canonical.created_at:
|
||||||
canonical.created_at = ent.created_at
|
canonical.created_at = ent.created_at
|
||||||
if getattr(ent, "expired_at", None) and getattr(canonical, "expired_at", None):
|
|
||||||
if canonical.expired_at is None:
|
|
||||||
canonical.expired_at = ent.expired_at
|
|
||||||
elif ent.expired_at and ent.expired_at > canonical.expired_at:
|
|
||||||
canonical.expired_at = ent.expired_at
|
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -65,7 +65,6 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode:
|
|||||||
user_id=row.get("user_id") or "",
|
user_id=row.get("user_id") or "",
|
||||||
apply_id=row.get("apply_id") or "",
|
apply_id=row.get("apply_id") or "",
|
||||||
created_at=_parse_dt(row.get("created_at")),
|
created_at=_parse_dt(row.get("created_at")),
|
||||||
expired_at=_parse_dt(row.get("expired_at")) if row.get("expired_at") else None,
|
|
||||||
entity_idx=int(row.get("entity_idx") or 0),
|
entity_idx=int(row.get("entity_idx") or 0),
|
||||||
statement_id=row.get("statement_id") or "",
|
statement_id=row.get("statement_id") or "",
|
||||||
entity_type=row.get("entity_type") or "",
|
entity_type=row.get("entity_type") or "",
|
||||||
|
|||||||
@@ -1089,7 +1089,6 @@ class ExtractionOrchestrator:
|
|||||||
content=dialog_data.context.content if dialog_data.context else "",
|
content=dialog_data.context.content if dialog_data.context else "",
|
||||||
dialog_embedding=dialog_data.dialog_embedding if hasattr(dialog_data, 'dialog_embedding') else None,
|
dialog_embedding=dialog_data.dialog_embedding if hasattr(dialog_data, 'dialog_embedding') else None,
|
||||||
created_at=dialog_data.created_at,
|
created_at=dialog_data.created_at,
|
||||||
expired_at=dialog_data.expired_at,
|
|
||||||
metadata=dialog_data.metadata,
|
metadata=dialog_data.metadata,
|
||||||
config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None,
|
config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None,
|
||||||
)
|
)
|
||||||
@@ -1109,7 +1108,6 @@ class ExtractionOrchestrator:
|
|||||||
chunk_embedding=chunk.chunk_embedding,
|
chunk_embedding=chunk.chunk_embedding,
|
||||||
sequence_number=chunk_idx, # 添加必需的 sequence_number 字段
|
sequence_number=chunk_idx, # 添加必需的 sequence_number 字段
|
||||||
created_at=dialog_data.created_at,
|
created_at=dialog_data.created_at,
|
||||||
expired_at=dialog_data.expired_at,
|
|
||||||
metadata=chunk.metadata,
|
metadata=chunk.metadata,
|
||||||
)
|
)
|
||||||
chunk_nodes.append(chunk_node)
|
chunk_nodes.append(chunk_node)
|
||||||
@@ -1175,7 +1173,6 @@ class ExtractionOrchestrator:
|
|||||||
invalid_at=statement.temporal_validity.invalid_at if hasattr(statement,
|
invalid_at=statement.temporal_validity.invalid_at if hasattr(statement,
|
||||||
'temporal_validity') and statement.temporal_validity else None,
|
'temporal_validity') and statement.temporal_validity else None,
|
||||||
created_at=dialog_data.created_at,
|
created_at=dialog_data.created_at,
|
||||||
expired_at=dialog_data.expired_at,
|
|
||||||
config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None,
|
config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None,
|
||||||
# Emotion fields
|
# Emotion fields
|
||||||
emotion_type=getattr(statement, 'emotion_type', None),
|
emotion_type=getattr(statement, 'emotion_type', None),
|
||||||
@@ -1232,7 +1229,6 @@ class ExtractionOrchestrator:
|
|||||||
end_user_id=dialog_data.end_user_id,
|
end_user_id=dialog_data.end_user_id,
|
||||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||||
created_at=dialog_data.created_at,
|
created_at=dialog_data.created_at,
|
||||||
expired_at=dialog_data.expired_at,
|
|
||||||
config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None,
|
config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None,
|
||||||
)
|
)
|
||||||
entity_nodes.append(entity_node)
|
entity_nodes.append(entity_node)
|
||||||
@@ -1269,7 +1265,6 @@ class ExtractionOrchestrator:
|
|||||||
end_user_id=dialog_data.end_user_id,
|
end_user_id=dialog_data.end_user_id,
|
||||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||||
created_at=dialog_data.created_at,
|
created_at=dialog_data.created_at,
|
||||||
expired_at=dialog_data.expired_at,
|
|
||||||
valid_at=_tv.valid_at if _tv else None,
|
valid_at=_tv.valid_at if _tv else None,
|
||||||
invalid_at=_tv.invalid_at if _tv else None,
|
invalid_at=_tv.invalid_at if _tv else None,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -912,6 +912,7 @@ class NewExtractionOrchestrator:
|
|||||||
has_emotional_state=stmt_out.has_emotional_state,
|
has_emotional_state=stmt_out.has_emotional_state,
|
||||||
triplet_extraction_info=triplet_info,
|
triplet_extraction_info=triplet_info,
|
||||||
statement_embedding=stmt_embedding,
|
statement_embedding=stmt_embedding,
|
||||||
|
dialog_at=getattr(chunk, "dialog_at", None),
|
||||||
**emotion_kwargs,
|
**emotion_kwargs,
|
||||||
)
|
)
|
||||||
new_statements.append(stmt)
|
new_statements.append(stmt)
|
||||||
|
|||||||
@@ -215,7 +215,6 @@ async def _process_chunk_summary(
|
|||||||
apply_id=dialog.end_user_id,
|
apply_id=dialog.end_user_id,
|
||||||
run_id=dialog.run_id, # 使用 dialog 的 run_id
|
run_id=dialog.run_id, # 使用 dialog 的 run_id
|
||||||
created_at=datetime.now(),
|
created_at=datetime.now(),
|
||||||
expired_at=datetime(9999, 12, 31),
|
|
||||||
dialog_id=dialog.id,
|
dialog_id=dialog.id,
|
||||||
chunk_ids=[chunk.id],
|
chunk_ids=[chunk.id],
|
||||||
content=summary_text,
|
content=summary_text,
|
||||||
|
|||||||
@@ -181,6 +181,7 @@ class StatementExtractor:
|
|||||||
chunk_id=chunk.id,
|
chunk_id=chunk.id,
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
speaker=chunk_speaker,
|
speaker=chunk_speaker,
|
||||||
|
dialog_at=getattr(chunk, "dialog_at", None),
|
||||||
has_unsolved_reference=getattr(extracted_stmt, "has_unsolved_reference", False),
|
has_unsolved_reference=getattr(extracted_stmt, "has_unsolved_reference", False),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -135,7 +135,6 @@ async def build_graph_nodes_and_edges(
|
|||||||
content=dialog_data.context.content if dialog_data.context else "",
|
content=dialog_data.context.content if dialog_data.context else "",
|
||||||
dialog_embedding=dialog_data.dialog_embedding if hasattr(dialog_data, "dialog_embedding") else None,
|
dialog_embedding=dialog_data.dialog_embedding if hasattr(dialog_data, "dialog_embedding") else None,
|
||||||
created_at=dialog_data.created_at,
|
created_at=dialog_data.created_at,
|
||||||
expired_at=dialog_data.expired_at,
|
|
||||||
metadata=dialog_data.metadata,
|
metadata=dialog_data.metadata,
|
||||||
config_id=dialog_data.config_id if hasattr(dialog_data, "config_id") else None,
|
config_id=dialog_data.config_id if hasattr(dialog_data, "config_id") else None,
|
||||||
)
|
)
|
||||||
@@ -154,7 +153,6 @@ async def build_graph_nodes_and_edges(
|
|||||||
chunk_embedding=chunk.chunk_embedding,
|
chunk_embedding=chunk.chunk_embedding,
|
||||||
sequence_number=chunk_idx,
|
sequence_number=chunk_idx,
|
||||||
created_at=dialog_data.created_at,
|
created_at=dialog_data.created_at,
|
||||||
expired_at=dialog_data.expired_at,
|
|
||||||
metadata=chunk.metadata,
|
metadata=chunk.metadata,
|
||||||
)
|
)
|
||||||
chunk_nodes.append(chunk_node)
|
chunk_nodes.append(chunk_node)
|
||||||
@@ -227,7 +225,7 @@ async def build_graph_nodes_and_edges(
|
|||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
created_at=dialog_data.created_at,
|
created_at=dialog_data.created_at,
|
||||||
expired_at=dialog_data.expired_at,
|
dialog_at=getattr(statement, "dialog_at", None),
|
||||||
config_id=dialog_data.config_id if hasattr(dialog_data, "config_id") else None,
|
config_id=dialog_data.config_id if hasattr(dialog_data, "config_id") else None,
|
||||||
emotion_type=getattr(statement, "emotion_type", None),
|
emotion_type=getattr(statement, "emotion_type", None),
|
||||||
emotion_intensity=getattr(statement, "emotion_intensity", None),
|
emotion_intensity=getattr(statement, "emotion_intensity", None),
|
||||||
@@ -280,7 +278,6 @@ async def build_graph_nodes_and_edges(
|
|||||||
end_user_id=dialog_data.end_user_id,
|
end_user_id=dialog_data.end_user_id,
|
||||||
run_id=dialog_data.run_id,
|
run_id=dialog_data.run_id,
|
||||||
created_at=dialog_data.created_at,
|
created_at=dialog_data.created_at,
|
||||||
expired_at=dialog_data.expired_at,
|
|
||||||
config_id=dialog_data.config_id if hasattr(dialog_data, "config_id") else None,
|
config_id=dialog_data.config_id if hasattr(dialog_data, "config_id") else None,
|
||||||
)
|
)
|
||||||
entity_nodes.append(entity_node)
|
entity_nodes.append(entity_node)
|
||||||
@@ -320,7 +317,6 @@ async def build_graph_nodes_and_edges(
|
|||||||
end_user_id=dialog_data.end_user_id,
|
end_user_id=dialog_data.end_user_id,
|
||||||
run_id=dialog_data.run_id,
|
run_id=dialog_data.run_id,
|
||||||
created_at=dialog_data.created_at,
|
created_at=dialog_data.created_at,
|
||||||
expired_at=dialog_data.expired_at,
|
|
||||||
valid_at=_tv.valid_at if _tv else None,
|
valid_at=_tv.valid_at if _tv else None,
|
||||||
invalid_at=_tv.invalid_at if _tv else None,
|
invalid_at=_tv.invalid_at if _tv else None,
|
||||||
)
|
)
|
||||||
@@ -382,7 +378,6 @@ async def build_graph_nodes_and_edges(
|
|||||||
end_user_id=dialog_data.end_user_id,
|
end_user_id=dialog_data.end_user_id,
|
||||||
run_id=dialog_data.run_id,
|
run_id=dialog_data.run_id,
|
||||||
created_at=dialog_data.created_at,
|
created_at=dialog_data.created_at,
|
||||||
expired_at=dialog_data.expired_at,
|
|
||||||
pair_id=pair_id,
|
pair_id=pair_id,
|
||||||
dialog_id=dialog_data.id,
|
dialog_id=dialog_data.id,
|
||||||
text=record["original_text"],
|
text=record["original_text"],
|
||||||
@@ -408,7 +403,6 @@ async def build_graph_nodes_and_edges(
|
|||||||
end_user_id=dialog_data.end_user_id,
|
end_user_id=dialog_data.end_user_id,
|
||||||
run_id=dialog_data.run_id,
|
run_id=dialog_data.run_id,
|
||||||
created_at=dialog_data.created_at,
|
created_at=dialog_data.created_at,
|
||||||
expired_at=dialog_data.expired_at,
|
|
||||||
pair_id=pair_id,
|
pair_id=pair_id,
|
||||||
dialog_id=dialog_data.id,
|
dialog_id=dialog_data.id,
|
||||||
text=record["pruned_text"],
|
text=record["pruned_text"],
|
||||||
|
|||||||
@@ -483,7 +483,7 @@ class ReflectionEngine:
|
|||||||
result_data['memory_verifies'] = memory_verifies
|
result_data['memory_verifies'] = memory_verifies
|
||||||
result_data['quality_assessments'] = quality_assessments
|
result_data['quality_assessments'] = quality_assessments
|
||||||
conflicts_found = 0 # Initialize as integer 0 instead of empty string
|
conflicts_found = 0 # Initialize as integer 0 instead of empty string
|
||||||
REMOVE_KEYS = {"created_at", "expired_at","relationship","predicate","statement_id","id","statement_id","relationship_statement_id"}
|
REMOVE_KEYS = {"created_at","relationship","predicate","statement_id","id","statement_id","relationship_statement_id"}
|
||||||
# Clean conflict_data, and memory_verify and quality_assessment
|
# Clean conflict_data, and memory_verify and quality_assessment
|
||||||
cleaned_conflict_data = []
|
cleaned_conflict_data = []
|
||||||
for item in conflict_data:
|
for item in conflict_data:
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ async def _load_(data: List[Any]) -> List[Dict]:
|
|||||||
"end_user_id",
|
"end_user_id",
|
||||||
"chunk_id",
|
"chunk_id",
|
||||||
"created_at",
|
"created_at",
|
||||||
"expired_at",
|
|
||||||
"valid_at",
|
"valid_at",
|
||||||
"invalid_at",
|
"invalid_at",
|
||||||
]
|
]
|
||||||
@@ -93,7 +92,6 @@ async def get_data(result):
|
|||||||
rel_filtered['run_id'] = value.get('run_id')
|
rel_filtered['run_id'] = value.get('run_id')
|
||||||
rel_filtered['statement'] = value.get('statement')
|
rel_filtered['statement'] = value.get('statement')
|
||||||
rel_filtered['statement_id'] = value.get('statement_id')
|
rel_filtered['statement_id'] = value.get('statement_id')
|
||||||
rel_filtered['expired_at'] = value.get('expired_at')
|
|
||||||
rel_filtered['created_at'] = value.get('created_at')
|
rel_filtered['created_at'] = value.get('created_at')
|
||||||
filtered_item[key] = value
|
filtered_item[key] = value
|
||||||
elif key == 'entity2' and value is not None:
|
elif key == 'entity2' and value is not None:
|
||||||
|
|||||||
@@ -37,7 +37,6 @@ async def add_chunk_statement_edges(chunks: List[Chunk], connector: Neo4jConnect
|
|||||||
"apply_id": getattr(stmt, 'apply_id', None),
|
"apply_id": getattr(stmt, 'apply_id', None),
|
||||||
"run_id": getattr(stmt, 'run_id', None) or getattr(chunk, 'run_id', None),
|
"run_id": getattr(stmt, 'run_id', None) or getattr(chunk, 'run_id', None),
|
||||||
"created_at": getattr(stmt, 'created_at', None),
|
"created_at": getattr(stmt, 'created_at', None),
|
||||||
"expired_at": getattr(stmt, 'expired_at', None),
|
|
||||||
# "created_at": getattr(statement, 'created_at', None),
|
# "created_at": getattr(statement, 'created_at', None),
|
||||||
# "expired_at": None # Set to None or appropriate default
|
# "expired_at": None # Set to None or appropriate default
|
||||||
}
|
}
|
||||||
@@ -87,7 +86,6 @@ async def add_memory_summary_statement_edges(summaries: List[MemorySummaryNode],
|
|||||||
"end_user_id": s.end_user_id,
|
"end_user_id": s.end_user_id,
|
||||||
"run_id": s.run_id,
|
"run_id": s.run_id,
|
||||||
"created_at": s.created_at.isoformat() if s.created_at else None,
|
"created_at": s.created_at.isoformat() if s.created_at else None,
|
||||||
"expired_at": s.expired_at.isoformat() if s.expired_at else None,
|
|
||||||
})
|
})
|
||||||
|
|
||||||
if not edges:
|
if not edges:
|
||||||
|
|||||||
@@ -42,7 +42,6 @@ async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConn
|
|||||||
"ref_id": dialogue.ref_id,
|
"ref_id": dialogue.ref_id,
|
||||||
"name": dialogue.name,
|
"name": dialogue.name,
|
||||||
"created_at": dialogue.created_at.isoformat() if dialogue.created_at else None,
|
"created_at": dialogue.created_at.isoformat() if dialogue.created_at else None,
|
||||||
"expired_at": dialogue.expired_at.isoformat() if dialogue.expired_at else None,
|
|
||||||
"content": dialogue.content,
|
"content": dialogue.content,
|
||||||
"dialog_embedding": dialogue.dialog_embedding
|
"dialog_embedding": dialogue.dialog_embedding
|
||||||
})
|
})
|
||||||
@@ -87,7 +86,6 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
|
|||||||
"chunk_id": statement.chunk_id,
|
"chunk_id": statement.chunk_id,
|
||||||
# "created_at": statement.created_at.isoformat(),
|
# "created_at": statement.created_at.isoformat(),
|
||||||
"created_at": statement.created_at.isoformat() if statement.created_at else None,
|
"created_at": statement.created_at.isoformat() if statement.created_at else None,
|
||||||
"expired_at": statement.expired_at.isoformat() if statement.expired_at else None,
|
|
||||||
"stmt_type": statement.stmt_type,
|
"stmt_type": statement.stmt_type,
|
||||||
"temporal_info": statement.temporal_info.value,
|
"temporal_info": statement.temporal_info.value,
|
||||||
"statement": statement.statement,
|
"statement": statement.statement,
|
||||||
@@ -115,7 +113,8 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
|
|||||||
"activation_value": statement.activation_value,
|
"activation_value": statement.activation_value,
|
||||||
"access_history": statement.access_history if statement.access_history else [],
|
"access_history": statement.access_history if statement.access_history else [],
|
||||||
"last_access_time": statement.last_access_time,
|
"last_access_time": statement.last_access_time,
|
||||||
"access_count": statement.access_count
|
"access_count": statement.access_count,
|
||||||
|
"dialog_at": statement.dialog_at.isoformat() if statement.dialog_at else None,
|
||||||
}
|
}
|
||||||
flattened_statements.append(flattened_statement)
|
flattened_statements.append(flattened_statement)
|
||||||
|
|
||||||
@@ -159,7 +158,6 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) ->
|
|||||||
"end_user_id": chunk.end_user_id,
|
"end_user_id": chunk.end_user_id,
|
||||||
"run_id": chunk.run_id,
|
"run_id": chunk.run_id,
|
||||||
"created_at": chunk.created_at.isoformat() if chunk.created_at else None,
|
"created_at": chunk.created_at.isoformat() if chunk.created_at else None,
|
||||||
"expired_at": chunk.expired_at.isoformat() if chunk.expired_at else None,
|
|
||||||
"dialog_id": chunk.dialog_id,
|
"dialog_id": chunk.dialog_id,
|
||||||
"content": chunk.content,
|
"content": chunk.content,
|
||||||
"chunk_embedding": chunk.chunk_embedding if chunk.chunk_embedding else None,
|
"chunk_embedding": chunk.chunk_embedding if chunk.chunk_embedding else None,
|
||||||
@@ -211,7 +209,6 @@ async def add_memory_summary_nodes(
|
|||||||
"end_user_id": s.end_user_id,
|
"end_user_id": s.end_user_id,
|
||||||
"run_id": s.run_id,
|
"run_id": s.run_id,
|
||||||
"created_at": s.created_at.isoformat() if s.created_at else None,
|
"created_at": s.created_at.isoformat() if s.created_at else None,
|
||||||
"expired_at": s.expired_at.isoformat() if s.expired_at else None,
|
|
||||||
"dialog_id": s.dialog_id,
|
"dialog_id": s.dialog_id,
|
||||||
"chunk_ids": s.chunk_ids,
|
"chunk_ids": s.chunk_ids,
|
||||||
"content": s.content,
|
"content": s.content,
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ DIALOGUE_NODE_SAVE = """
|
|||||||
n.run_id = dialogue.run_id,
|
n.run_id = dialogue.run_id,
|
||||||
n.ref_id = dialogue.ref_id,
|
n.ref_id = dialogue.ref_id,
|
||||||
n.created_at = dialogue.created_at,
|
n.created_at = dialogue.created_at,
|
||||||
n.expired_at = dialogue.expired_at,
|
|
||||||
n.content = dialogue.content,
|
n.content = dialogue.content,
|
||||||
n.dialog_embedding = dialogue.dialog_embedding
|
n.dialog_embedding = dialogue.dialog_embedding
|
||||||
RETURN n.id AS uuid
|
RETURN n.id AS uuid
|
||||||
@@ -32,7 +31,6 @@ SET s += {
|
|||||||
emotion_keywords: statement.emotion_keywords,
|
emotion_keywords: statement.emotion_keywords,
|
||||||
temporal_info: statement.temporal_info,
|
temporal_info: statement.temporal_info,
|
||||||
created_at: statement.created_at,
|
created_at: statement.created_at,
|
||||||
expired_at: statement.expired_at,
|
|
||||||
valid_at: coalesce(statement.valid_at, ""),
|
valid_at: coalesce(statement.valid_at, ""),
|
||||||
invalid_at: coalesce(statement.invalid_at, ""),
|
invalid_at: coalesce(statement.invalid_at, ""),
|
||||||
statement_embedding: statement.statement_embedding,
|
statement_embedding: statement.statement_embedding,
|
||||||
@@ -41,7 +39,8 @@ SET s += {
|
|||||||
activation_value: statement.activation_value,
|
activation_value: statement.activation_value,
|
||||||
access_history: statement.access_history,
|
access_history: statement.access_history,
|
||||||
last_access_time: statement.last_access_time,
|
last_access_time: statement.last_access_time,
|
||||||
access_count: statement.access_count
|
access_count: statement.access_count,
|
||||||
|
dialog_at: statement.dialog_at
|
||||||
}
|
}
|
||||||
RETURN s.id AS uuid
|
RETURN s.id AS uuid
|
||||||
"""
|
"""
|
||||||
@@ -64,7 +63,6 @@ SET c += {
|
|||||||
end_user_id: chunk.end_user_id,
|
end_user_id: chunk.end_user_id,
|
||||||
run_id: chunk.run_id,
|
run_id: chunk.run_id,
|
||||||
created_at: chunk.created_at,
|
created_at: chunk.created_at,
|
||||||
expired_at: chunk.expired_at,
|
|
||||||
dialog_id: chunk.dialog_id,
|
dialog_id: chunk.dialog_id,
|
||||||
content: chunk.content,
|
content: chunk.content,
|
||||||
speaker: chunk.speaker,
|
speaker: chunk.speaker,
|
||||||
@@ -87,9 +85,6 @@ SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity
|
|||||||
e.created_at = CASE
|
e.created_at = CASE
|
||||||
WHEN entity.created_at IS NOT NULL AND (e.created_at IS NULL OR entity.created_at < e.created_at)
|
WHEN entity.created_at IS NOT NULL AND (e.created_at IS NULL OR entity.created_at < e.created_at)
|
||||||
THEN entity.created_at ELSE e.created_at END,
|
THEN entity.created_at ELSE e.created_at END,
|
||||||
e.expired_at = CASE
|
|
||||||
WHEN entity.expired_at IS NOT NULL AND (e.expired_at IS NULL OR entity.expired_at > e.expired_at)
|
|
||||||
THEN entity.expired_at ELSE e.expired_at END,
|
|
||||||
e.entity_idx = CASE WHEN e.entity_idx IS NULL OR e.entity_idx = 0 THEN entity.entity_idx ELSE e.entity_idx END,
|
e.entity_idx = CASE WHEN e.entity_idx IS NULL OR e.entity_idx = 0 THEN entity.entity_idx ELSE e.entity_idx END,
|
||||||
e.entity_type = CASE WHEN entity.entity_type IS NOT NULL AND entity.entity_type <> '' THEN entity.entity_type ELSE e.entity_type END,
|
e.entity_type = CASE WHEN entity.entity_type IS NOT NULL AND entity.entity_type <> '' THEN entity.entity_type ELSE e.entity_type END,
|
||||||
e.type_description = CASE WHEN entity.type_description IS NOT NULL AND entity.type_description <> '' THEN entity.type_description ELSE coalesce(e.type_description, '') END,
|
e.type_description = CASE WHEN entity.type_description IS NOT NULL AND entity.type_description <> '' THEN entity.type_description ELSE coalesce(e.type_description, '') END,
|
||||||
@@ -214,12 +209,61 @@ SET r.predicate = rel.predicate,
|
|||||||
r.valid_at = coalesce(rel.valid_at, ""),
|
r.valid_at = coalesce(rel.valid_at, ""),
|
||||||
r.invalid_at = coalesce(rel.invalid_at, ""),
|
r.invalid_at = coalesce(rel.invalid_at, ""),
|
||||||
r.created_at = rel.created_at,
|
r.created_at = rel.created_at,
|
||||||
r.expired_at = rel.expired_at,
|
|
||||||
r.run_id = rel.run_id,
|
r.run_id = rel.run_id,
|
||||||
r.end_user_id = rel.end_user_id
|
r.end_user_id = rel.end_user_id
|
||||||
RETURN elementId(r) AS uuid
|
RETURN elementId(r) AS uuid
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# 在 Neo4j 5及后续版本中,id() 函数已被标记为弃用,用elementId() 函数替代
|
||||||
|
|
||||||
|
# 保存弱关系实体,设置 e.is_weak = true;不维护 e.relations 聚合字段
|
||||||
|
WEAK_ENTITY_NODE_SAVE = """
|
||||||
|
UNWIND $weak_entities AS entity
|
||||||
|
MERGE (e:ExtractedEntity {id: entity.id, run_id: entity.run_id})
|
||||||
|
SET e += {
|
||||||
|
name: entity.name,
|
||||||
|
end_user_id: entity.end_user_id,
|
||||||
|
run_id: entity.run_id,
|
||||||
|
description: entity.description,
|
||||||
|
chunk_id: entity.chunk_id,
|
||||||
|
dialog_id: entity.dialog_id
|
||||||
|
}
|
||||||
|
// Independent weak flag,仅标记弱关系,不再维护 relations 聚合字段
|
||||||
|
SET e.is_weak = true
|
||||||
|
RETURN e.id AS id
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 为强关系三元组中的主语和宾语创建/更新实体节点,仅设置 e.is_strong = true,不维护 e.relations 字段
|
||||||
|
SAVE_STRONG_TRIPLE_ENTITIES = """
|
||||||
|
UNWIND $items AS item
|
||||||
|
MERGE (s:ExtractedEntity {id: item.source_id, run_id: item.run_id})
|
||||||
|
SET s += {name: item.subject, end_user_id: item.end_user_id, run_id: item.run_id}
|
||||||
|
// Independent strong flag
|
||||||
|
SET s.is_strong = true
|
||||||
|
MERGE (o:ExtractedEntity {id: item.target_id, run_id: item.run_id})
|
||||||
|
SET o += {name: item.object, end_user_id: item.end_user_id, run_id: item.run_id}
|
||||||
|
// Independent strong flag
|
||||||
|
SET o.is_strong = true
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
DIALOGUE_STATEMENT_EDGE_SAVE = """
|
||||||
|
UNWIND $dialogue_statement_edges AS edge
|
||||||
|
// 支持按 uuid 或 ref_id 连接到 Dialogue,避免因来源 ID 不一致而断链
|
||||||
|
MATCH (dialogue:Dialogue)
|
||||||
|
WHERE dialogue.uuid = edge.source OR dialogue.ref_id = edge.source
|
||||||
|
MATCH (statement:Statement {id: edge.target})
|
||||||
|
// 仅按端点去重,关系属性可更新
|
||||||
|
MERGE (dialogue)-[e:MENTIONS]->(statement)
|
||||||
|
SET e.uuid = edge.id,
|
||||||
|
e.end_user_id = edge.end_user_id,
|
||||||
|
e.created_at = edge.created_at
|
||||||
|
RETURN e.uuid AS uuid
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 在 Neo4j 5及后续版本中,id() 函数已被标记为弃用,用elementId() 函数替代
|
||||||
|
|
||||||
|
|
||||||
CHUNK_STATEMENT_EDGE_SAVE = """
|
CHUNK_STATEMENT_EDGE_SAVE = """
|
||||||
UNWIND $chunk_statement_edges AS edge
|
UNWIND $chunk_statement_edges AS edge
|
||||||
MATCH (statement:Statement {id: edge.source, run_id: edge.run_id})
|
MATCH (statement:Statement {id: edge.source, run_id: edge.run_id})
|
||||||
@@ -227,8 +271,7 @@ CHUNK_STATEMENT_EDGE_SAVE = """
|
|||||||
MERGE (chunk)-[e:CONTAINS {id: edge.id}]->(statement)
|
MERGE (chunk)-[e:CONTAINS {id: edge.id}]->(statement)
|
||||||
SET e.end_user_id = edge.end_user_id,
|
SET e.end_user_id = edge.end_user_id,
|
||||||
e.run_id = edge.run_id,
|
e.run_id = edge.run_id,
|
||||||
e.created_at = edge.created_at,
|
e.created_at = edge.created_at
|
||||||
e.expired_at = edge.expired_at
|
|
||||||
RETURN e.id AS uuid
|
RETURN e.id AS uuid
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -243,11 +286,89 @@ MERGE (statement)-[r:REFERENCES_ENTITY]->(entity)
|
|||||||
SET r.end_user_id = rel.end_user_id,
|
SET r.end_user_id = rel.end_user_id,
|
||||||
r.run_id = rel.run_id,
|
r.run_id = rel.run_id,
|
||||||
r.created_at = rel.created_at,
|
r.created_at = rel.created_at,
|
||||||
r.expired_at = rel.expired_at,
|
|
||||||
r.connect_strength = rel.connect_strength
|
r.connect_strength = rel.connect_strength
|
||||||
RETURN elementId(r) AS uuid
|
RETURN elementId(r) AS uuid
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
ENTITY_EMBEDDING_SEARCH = """
|
||||||
|
CALL db.index.vector.queryNodes('entity_embedding_index', $limit * 100, $embedding)
|
||||||
|
YIELD node AS e, score
|
||||||
|
WHERE e.name_embedding IS NOT NULL
|
||||||
|
AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
|
||||||
|
RETURN e.id AS id,
|
||||||
|
e.name AS name,
|
||||||
|
e.end_user_id AS end_user_id,
|
||||||
|
e.entity_type AS entity_type,
|
||||||
|
COALESCE(e.activation_value, e.importance_score, 0.5) AS activation_value,
|
||||||
|
COALESCE(e.importance_score, 0.5) AS importance_score,
|
||||||
|
e.last_access_time AS last_access_time,
|
||||||
|
COALESCE(e.access_count, 0) AS access_count,
|
||||||
|
score
|
||||||
|
ORDER BY score DESC
|
||||||
|
LIMIT $limit
|
||||||
|
"""
|
||||||
|
# Embedding-based search: cosine similarity on Statement.statement_embedding
|
||||||
|
STATEMENT_EMBEDDING_SEARCH = """
|
||||||
|
CALL db.index.vector.queryNodes('statement_embedding_index', $limit * 100, $embedding)
|
||||||
|
YIELD node AS s, score
|
||||||
|
WHERE s.statement_embedding IS NOT NULL
|
||||||
|
AND ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
|
||||||
|
RETURN s.id AS id,
|
||||||
|
s.statement AS statement,
|
||||||
|
s.end_user_id AS end_user_id,
|
||||||
|
s.chunk_id AS chunk_id,
|
||||||
|
s.created_at AS created_at,
|
||||||
|
s.valid_at AS valid_at,
|
||||||
|
s.invalid_at AS invalid_at,
|
||||||
|
COALESCE(s.activation_value, s.importance_score, 0.5) AS activation_value,
|
||||||
|
COALESCE(s.importance_score, 0.5) AS importance_score,
|
||||||
|
s.last_access_time AS last_access_time,
|
||||||
|
COALESCE(s.access_count, 0) AS access_count,
|
||||||
|
score
|
||||||
|
ORDER BY score DESC
|
||||||
|
LIMIT $limit
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Embedding-based search: cosine similarity on Chunk.chunk_embedding
|
||||||
|
CHUNK_EMBEDDING_SEARCH = """
|
||||||
|
CALL db.index.vector.queryNodes('chunk_embedding_index', $limit * 100, $embedding)
|
||||||
|
YIELD node AS c, score
|
||||||
|
WHERE c.chunk_embedding IS NOT NULL
|
||||||
|
AND ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
|
||||||
|
RETURN c.id AS chunk_id,
|
||||||
|
c.end_user_id AS end_user_id,
|
||||||
|
c.content AS content,
|
||||||
|
c.dialog_id AS dialog_id,
|
||||||
|
COALESCE(c.activation_value, 0.5) AS activation_value,
|
||||||
|
c.last_access_time AS last_access_time,
|
||||||
|
COALESCE(c.access_count, 0) AS access_count,
|
||||||
|
score
|
||||||
|
ORDER BY score DESC
|
||||||
|
LIMIT $limit
|
||||||
|
"""
|
||||||
|
|
||||||
|
SEARCH_STATEMENTS_BY_KEYWORD = """
|
||||||
|
CALL db.index.fulltext.queryNodes("statementsFulltext", $query) YIELD node AS s, score
|
||||||
|
WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
|
||||||
|
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
|
||||||
|
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
|
||||||
|
RETURN s.id AS id,
|
||||||
|
s.statement AS statement,
|
||||||
|
s.end_user_id AS end_user_id,
|
||||||
|
s.chunk_id AS chunk_id,
|
||||||
|
s.created_at AS created_at,
|
||||||
|
s.valid_at AS valid_at,
|
||||||
|
s.invalid_at AS invalid_at,
|
||||||
|
c.id AS chunk_id_from_rel,
|
||||||
|
collect(DISTINCT e.id) AS entity_ids,
|
||||||
|
COALESCE(s.activation_value, s.importance_score, 0.5) AS activation_value,
|
||||||
|
COALESCE(s.importance_score, 0.5) AS importance_score,
|
||||||
|
s.last_access_time AS last_access_time,
|
||||||
|
COALESCE(s.access_count, 0) AS access_count,
|
||||||
|
score
|
||||||
|
ORDER BY score DESC
|
||||||
|
LIMIT $limit
|
||||||
|
"""
|
||||||
# 查询实体名称包含指定字符串的实体
|
# 查询实体名称包含指定字符串的实体
|
||||||
SEARCH_ENTITIES_BY_NAME = """
|
SEARCH_ENTITIES_BY_NAME = """
|
||||||
CALL db.index.fulltext.queryNodes("entitiesFulltext", $query) YIELD node AS e, score
|
CALL db.index.fulltext.queryNodes("entitiesFulltext", $query) YIELD node AS e, score
|
||||||
@@ -259,7 +380,6 @@ RETURN e.id AS id,
|
|||||||
e.end_user_id AS end_user_id,
|
e.end_user_id AS end_user_id,
|
||||||
e.entity_type AS entity_type,
|
e.entity_type AS entity_type,
|
||||||
e.created_at AS created_at,
|
e.created_at AS created_at,
|
||||||
e.expired_at AS expired_at,
|
|
||||||
e.entity_idx AS entity_idx,
|
e.entity_idx AS entity_idx,
|
||||||
e.statement_id AS statement_id,
|
e.statement_id AS statement_id,
|
||||||
e.description AS description,
|
e.description AS description,
|
||||||
@@ -279,6 +399,72 @@ ORDER BY score DESC
|
|||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Entity search that combines full-text hits with alias substring matches.
#
# Parameters:
#   $query        - search string; used for the "entitiesFulltext" index and,
#                   lower-cased, for the alias comparison.
#   $end_user_id  - when non-NULL, restricts both branches to this end user.
#   $limit        - maximum number of rows returned.
#
# Stages:
#   1. Collect the full-text hits (entity + index score) into one list, so the
#      query still continues with a single row when there are no hits.
#   2. Collect entities having an alias that contains $query (case-insensitive).
#   3. Give alias hits a fixed score: 1.0 exact alias match, 0.9 alias prefix
#      match, 0.8 otherwise; concatenate them with the full-text hits.
#   4. Keep the highest score per entity, then attach the statements that
#      reference it and the chunks containing those statements.
# NOTE(review): full-text index scores and the fixed alias scores are compared
# directly by MAX/ORDER BY - confirm the two scales are meant to be mixed.
SEARCH_ENTITIES_BY_NAME_OR_ALIAS = """
CALL db.index.fulltext.queryNodes("entitiesFulltext", $query) YIELD node AS e, score
WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
WITH e, score
WITH collect({entity: e, score: score}) AS fulltextResults

OPTIONAL MATCH (ae:ExtractedEntity)
WHERE ($end_user_id IS NULL OR ae.end_user_id = $end_user_id)
  AND ae.aliases IS NOT NULL
  AND ANY(alias IN ae.aliases WHERE toLower(alias) CONTAINS toLower($query))
WITH fulltextResults, collect(ae) AS aliasEntities

UNWIND (fulltextResults + [x IN aliasEntities | {entity: x, score:
    CASE
        WHEN ANY(alias IN x.aliases WHERE toLower(alias) = toLower($query)) THEN 1.0
        WHEN ANY(alias IN x.aliases WHERE toLower(alias) STARTS WITH toLower($query)) THEN 0.9
        ELSE 0.8
    END
}]) AS row
WITH row.entity AS e, row.score AS score
WITH DISTINCT e, MAX(score) AS score
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
RETURN e.id AS id,
       e.name AS name,
       e.end_user_id AS end_user_id,
       e.entity_type AS entity_type,
       e.created_at AS created_at,
       e.entity_idx AS entity_idx,
       e.statement_id AS statement_id,
       e.description AS description,
       e.aliases AS aliases,
       e.name_embedding AS name_embedding,
       e.connect_strength AS connect_strength,
       collect(DISTINCT s.id) AS statement_ids,
       collect(DISTINCT c.id) AS chunk_ids,
       COALESCE(e.activation_value, e.importance_score, 0.5) AS activation_value,
       COALESCE(e.importance_score, 0.5) AS importance_score,
       e.last_access_time AS last_access_time,
       COALESCE(e.access_count, 0) AS access_count,
       score
ORDER BY score DESC
LIMIT $limit
"""
|
||||||
|
|
||||||
|
|
||||||
|
# Full-text search over Chunk nodes by content.
#
# Parameters:
#   $query        - search string handed to the "chunksFulltext" index.
#   $end_user_id  - when non-NULL, only chunks owned by this end user are kept.
#   $limit        - maximum number of rows returned.
#
# Each row describes one chunk plus the ids of the statements it CONTAINS and
# of the entities those statements reference. activation_value defaults to 0.5
# and access_count to 0 when unset. Rows are ordered by index score, highest first.
SEARCH_CHUNKS_BY_CONTENT = """
CALL db.index.fulltext.queryNodes("chunksFulltext", $query) YIELD node AS c, score
WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
OPTIONAL MATCH (c)-[:CONTAINS]->(s:Statement)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
RETURN c.id AS chunk_id,
       c.end_user_id AS end_user_id,
       c.content AS content,
       c.dialog_id AS dialog_id,
       c.sequence_number AS sequence_number,
       collect(DISTINCT s.id) AS statement_ids,
       collect(DISTINCT e.id) AS entity_ids,
       COALESCE(c.activation_value, 0.5) AS activation_value,
       c.last_access_time AS last_access_time,
       COALESCE(c.access_count, 0) AS access_count,
       score
ORDER BY score DESC
LIMIT $limit
"""
|
||||||
|
|
||||||
# 以下是关于第二层去重消歧与数据库进行检索的语句,在最近的规划中不再使用
|
# 以下是关于第二层去重消歧与数据库进行检索的语句,在最近的规划中不再使用
|
||||||
|
|
||||||
# # 同组group_id下按“精确名字或别名+可选类型一致”来检索
|
# # 同组group_id下按“精确名字或别名+可选类型一致”来检索
|
||||||
@@ -332,8 +518,7 @@ WHERE ($end_user_id IS NULL OR d.end_user_id = $end_user_id)
|
|||||||
RETURN d.id AS dialog_id,
|
RETURN d.id AS dialog_id,
|
||||||
d.end_user_id AS end_user_id,
|
d.end_user_id AS end_user_id,
|
||||||
d.content AS content,
|
d.content AS content,
|
||||||
d.created_at AS created_at,
|
d.created_at AS created_at
|
||||||
d.expired_at AS expired_at
|
|
||||||
ORDER BY d.created_at DESC
|
ORDER BY d.created_at DESC
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
"""
|
"""
|
||||||
@@ -347,7 +532,6 @@ RETURN c.id AS chunk_id,
|
|||||||
c.content AS content,
|
c.content AS content,
|
||||||
c.dialog_id AS dialog_id,
|
c.dialog_id AS dialog_id,
|
||||||
c.created_at AS created_at,
|
c.created_at AS created_at,
|
||||||
c.expired_at AS expired_at,
|
|
||||||
c.sequence_number AS sequence_number
|
c.sequence_number AS sequence_number
|
||||||
ORDER BY c.created_at DESC
|
ORDER BY c.created_at DESC
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
@@ -560,7 +744,6 @@ SET m += {
|
|||||||
end_user_id: summary.end_user_id,
|
end_user_id: summary.end_user_id,
|
||||||
run_id: summary.run_id,
|
run_id: summary.run_id,
|
||||||
created_at: summary.created_at,
|
created_at: summary.created_at,
|
||||||
expired_at: summary.expired_at,
|
|
||||||
dialog_id: summary.dialog_id,
|
dialog_id: summary.dialog_id,
|
||||||
chunk_ids: summary.chunk_ids,
|
chunk_ids: summary.chunk_ids,
|
||||||
content: summary.content,
|
content: summary.content,
|
||||||
@@ -584,8 +767,7 @@ MATCH (c)-[:CONTAINS]->(s:Statement {run_id: e.run_id})
|
|||||||
MERGE (ms)-[r:DERIVED_FROM_STATEMENT]->(s)
|
MERGE (ms)-[r:DERIVED_FROM_STATEMENT]->(s)
|
||||||
SET r.end_user_id = e.end_user_id,
|
SET r.end_user_id = e.end_user_id,
|
||||||
r.run_id = e.run_id,
|
r.run_id = e.run_id,
|
||||||
r.created_at = e.created_at,
|
r.created_at = e.created_at
|
||||||
r.expired_at = e.expired_at
|
|
||||||
RETURN elementId(r) AS uuid
|
RETURN elementId(r) AS uuid
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -614,8 +796,7 @@ FOREACH (rel IN CASE WHEN r IS NOT NULL THEN [r] ELSE [] END |
|
|||||||
user_id: rel.user_id,
|
user_id: rel.user_id,
|
||||||
apply_id: rel.apply_id,
|
apply_id: rel.apply_id,
|
||||||
run_id: rel.run_id,
|
run_id: rel.run_id,
|
||||||
created_at: rel.created_at,
|
created_at: rel.created_at
|
||||||
expired_at: rel.expired_at
|
|
||||||
}]->(target)
|
}]->(target)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -636,8 +817,7 @@ FOREACH (rel IN CASE WHEN r IS NOT NULL THEN [r] ELSE [] END |
|
|||||||
user_id: rel.user_id,
|
user_id: rel.user_id,
|
||||||
apply_id: rel.apply_id,
|
apply_id: rel.apply_id,
|
||||||
run_id: rel.run_id,
|
run_id: rel.run_id,
|
||||||
created_at: rel.created_at,
|
created_at: rel.created_at
|
||||||
expired_at: rel.expired_at
|
|
||||||
}]->(canonical)
|
}]->(canonical)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -678,7 +858,6 @@ neo4j_query_part = """
|
|||||||
m.description as description,
|
m.description as description,
|
||||||
m.statement_id as statement_id,
|
m.statement_id as statement_id,
|
||||||
m.created_at as created_at,
|
m.created_at as created_at,
|
||||||
m.expired_at as expired_at,
|
|
||||||
CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type,
|
CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type,
|
||||||
elementId(rel) as rel_id,
|
elementId(rel) as rel_id,
|
||||||
rel.predicate as predicate,
|
rel.predicate as predicate,
|
||||||
@@ -698,7 +877,6 @@ neo4j_query_all = """
|
|||||||
m.description as description,
|
m.description as description,
|
||||||
m.statement_id as statement_id,
|
m.statement_id as statement_id,
|
||||||
m.created_at as created_at,
|
m.created_at as created_at,
|
||||||
m.expired_at as expired_at,
|
|
||||||
CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type,
|
CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type,
|
||||||
elementId(rel) as rel_id,
|
elementId(rel) as rel_id,
|
||||||
rel.predicate as predicate,
|
rel.predicate as predicate,
|
||||||
@@ -1513,8 +1691,7 @@ SET o += {
|
|||||||
dialog_id: orig.dialog_id,
|
dialog_id: orig.dialog_id,
|
||||||
pair_id: orig.pair_id,
|
pair_id: orig.pair_id,
|
||||||
text: orig.text,
|
text: orig.text,
|
||||||
created_at: orig.created_at,
|
created_at: orig.created_at
|
||||||
expired_at: orig.expired_at
|
|
||||||
}
|
}
|
||||||
RETURN o.id AS uuid
|
RETURN o.id AS uuid
|
||||||
"""
|
"""
|
||||||
@@ -1530,8 +1707,7 @@ SET pr += {
|
|||||||
text: p.text,
|
text: p.text,
|
||||||
memory_type: p.memory_type,
|
memory_type: p.memory_type,
|
||||||
text_embedding: p.text_embedding,
|
text_embedding: p.text_embedding,
|
||||||
created_at: p.created_at,
|
created_at: p.created_at
|
||||||
expired_at: p.expired_at
|
|
||||||
}
|
}
|
||||||
RETURN pr.id AS uuid
|
RETURN pr.id AS uuid
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -49,8 +49,6 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]):
|
|||||||
# 处理datetime字段
|
# 处理datetime字段
|
||||||
if isinstance(n.get('created_at'), str):
|
if isinstance(n.get('created_at'), str):
|
||||||
n['created_at'] = datetime.fromisoformat(n['created_at'])
|
n['created_at'] = datetime.fromisoformat(n['created_at'])
|
||||||
if n.get('expired_at') and isinstance(n['expired_at'], str):
|
|
||||||
n['expired_at'] = datetime.fromisoformat(n['expired_at'])
|
|
||||||
|
|
||||||
return DialogueNode(**n)
|
return DialogueNode(**n)
|
||||||
|
|
||||||
|
|||||||
@@ -48,8 +48,6 @@ class EntityRepository(BaseNeo4jRepository[ExtractedEntityNode]):
|
|||||||
# 处理datetime字段
|
# 处理datetime字段
|
||||||
if isinstance(n.get('created_at'), str):
|
if isinstance(n.get('created_at'), str):
|
||||||
n['created_at'] = datetime.fromisoformat(n['created_at'])
|
n['created_at'] = datetime.fromisoformat(n['created_at'])
|
||||||
if n.get('expired_at') and isinstance(n.get('expired_at'), str):
|
|
||||||
n['expired_at'] = datetime.fromisoformat(n['expired_at'])
|
|
||||||
|
|
||||||
# 确保aliases字段存在且为列表
|
# 确保aliases字段存在且为列表
|
||||||
if 'aliases' not in n or n['aliases'] is None:
|
if 'aliases' not in n or n['aliases'] is None:
|
||||||
|
|||||||
@@ -55,7 +55,6 @@ async def save_entities_and_relationships(
|
|||||||
'valid_at': edge.valid_at.isoformat() if edge.valid_at else None,
|
'valid_at': edge.valid_at.isoformat() if edge.valid_at else None,
|
||||||
'invalid_at': edge.invalid_at.isoformat() if edge.invalid_at else None,
|
'invalid_at': edge.invalid_at.isoformat() if edge.invalid_at else None,
|
||||||
'created_at': edge.created_at.isoformat() if edge.created_at else None,
|
'created_at': edge.created_at.isoformat() if edge.created_at else None,
|
||||||
'expired_at': edge.expired_at.isoformat() if edge.expired_at else None,
|
|
||||||
'run_id': edge.run_id,
|
'run_id': edge.run_id,
|
||||||
'end_user_id': edge.end_user_id,
|
'end_user_id': edge.end_user_id,
|
||||||
}
|
}
|
||||||
@@ -115,7 +114,6 @@ async def save_statement_chunk_edges(
|
|||||||
"end_user_id": edge.end_user_id,
|
"end_user_id": edge.end_user_id,
|
||||||
"run_id": edge.run_id,
|
"run_id": edge.run_id,
|
||||||
"created_at": edge.created_at.isoformat() if edge.created_at else None,
|
"created_at": edge.created_at.isoformat() if edge.created_at else None,
|
||||||
"expired_at": edge.expired_at.isoformat() if edge.expired_at else None,
|
|
||||||
})
|
})
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -145,7 +143,6 @@ async def save_statement_entity_edges(
|
|||||||
"run_id": edge.run_id,
|
"run_id": edge.run_id,
|
||||||
"connect_strength": edge.connect_strength,
|
"connect_strength": edge.connect_strength,
|
||||||
"created_at": edge.created_at.isoformat() if edge.created_at else None,
|
"created_at": edge.created_at.isoformat() if edge.created_at else None,
|
||||||
"expired_at": edge.expired_at.isoformat() if edge.expired_at else None,
|
|
||||||
}
|
}
|
||||||
all_se_edges.append(edge_data)
|
all_se_edges.append(edge_data)
|
||||||
|
|
||||||
@@ -313,7 +310,6 @@ async def save_dialog_and_statements_to_neo4j(
|
|||||||
'valid_at': edge.valid_at.isoformat() if edge.valid_at else None,
|
'valid_at': edge.valid_at.isoformat() if edge.valid_at else None,
|
||||||
'invalid_at': edge.invalid_at.isoformat() if edge.invalid_at else None,
|
'invalid_at': edge.invalid_at.isoformat() if edge.invalid_at else None,
|
||||||
'created_at': edge.created_at.isoformat() if edge.created_at else None,
|
'created_at': edge.created_at.isoformat() if edge.created_at else None,
|
||||||
'expired_at': edge.expired_at.isoformat() if edge.expired_at else None,
|
|
||||||
'run_id': edge.run_id,
|
'run_id': edge.run_id,
|
||||||
'end_user_id': edge.end_user_id,
|
'end_user_id': edge.end_user_id,
|
||||||
})
|
})
|
||||||
@@ -332,7 +328,6 @@ async def save_dialog_and_statements_to_neo4j(
|
|||||||
"source": edge.source,
|
"source": edge.source,
|
||||||
"target": edge.target,
|
"target": edge.target,
|
||||||
"created_at": edge.created_at.isoformat() if edge.created_at else None,
|
"created_at": edge.created_at.isoformat() if edge.created_at else None,
|
||||||
"expired_at": edge.expired_at.isoformat() if edge.expired_at else None,
|
|
||||||
"run_id": edge.run_id,
|
"run_id": edge.run_id,
|
||||||
"end_user_id": edge.end_user_id,
|
"end_user_id": edge.end_user_id,
|
||||||
})
|
})
|
||||||
@@ -350,7 +345,6 @@ async def save_dialog_and_statements_to_neo4j(
|
|||||||
"source": edge.source,
|
"source": edge.source,
|
||||||
"target": edge.target,
|
"target": edge.target,
|
||||||
"created_at": edge.created_at.isoformat() if edge.created_at else None,
|
"created_at": edge.created_at.isoformat() if edge.created_at else None,
|
||||||
"expired_at": edge.expired_at.isoformat() if edge.expired_at else None,
|
|
||||||
"run_id": edge.run_id,
|
"run_id": edge.run_id,
|
||||||
"end_user_id": edge.end_user_id,
|
"end_user_id": edge.end_user_id,
|
||||||
"connect_strength": getattr(edge, "connect_strength", "strong"),
|
"connect_strength": getattr(edge, "connect_strength", "strong"),
|
||||||
|
|||||||
@@ -232,8 +232,6 @@ async def neo4j_data(solved_data):
|
|||||||
updata_entity = {}
|
updata_entity = {}
|
||||||
ori_edge = {}
|
ori_edge = {}
|
||||||
updata_edge = {}
|
updata_edge = {}
|
||||||
ori_expired_at={}
|
|
||||||
updat_expired_at={}
|
|
||||||
for i in solved_data:
|
for i in solved_data:
|
||||||
databasets = i['data']
|
databasets = i['data']
|
||||||
for key, values in databasets.items():
|
for key, values in databasets.items():
|
||||||
@@ -247,12 +245,9 @@ async def neo4j_data(solved_data):
|
|||||||
key = 'name'
|
key = 'name'
|
||||||
ori_entity[key] = values[0]
|
ori_entity[key] = values[0]
|
||||||
updata_entity[key] = values[1]
|
updata_entity[key] = values[1]
|
||||||
ori_expired_at[key] = values[0]
|
|
||||||
if key == 'statement':
|
if key == 'statement':
|
||||||
ori_edge[key] = values[0]
|
ori_edge[key] = values[0]
|
||||||
updata_edge[key] = values[1]
|
updata_edge[key] = values[1]
|
||||||
if key=='expired_at':
|
|
||||||
updat_expired_at[key] = values[1]
|
|
||||||
|
|
||||||
elif key == 'id':
|
elif key == 'id':
|
||||||
ori_edge[key] = values
|
ori_edge[key] = values
|
||||||
@@ -260,8 +255,6 @@ async def neo4j_data(solved_data):
|
|||||||
|
|
||||||
ori_entity[key] = values
|
ori_entity[key] = values
|
||||||
updata_entity[key] = values
|
updata_entity[key] = values
|
||||||
|
|
||||||
ori_expired_at[key] = values
|
|
||||||
elif key == 'rel_id':
|
elif key == 'rel_id':
|
||||||
key='id'
|
key='id'
|
||||||
ori_edge[key] = values
|
ori_edge[key] = values
|
||||||
@@ -270,18 +263,12 @@ async def neo4j_data(solved_data):
|
|||||||
ori_entity[key] = values
|
ori_entity[key] = values
|
||||||
updata_entity[key] = values
|
updata_entity[key] = values
|
||||||
|
|
||||||
ori_expired_at[key] = values
|
|
||||||
|
|
||||||
|
|
||||||
print(ori_entity)
|
print(ori_entity)
|
||||||
print(updata_entity)
|
print(updata_entity)
|
||||||
print(100*'-')
|
print(100*'-')
|
||||||
print(ori_edge)
|
print(ori_edge)
|
||||||
print(updata_edge)
|
print(updata_edge)
|
||||||
expired_at_ = updat_expired_at.get('expired_at', None)
|
|
||||||
if expired_at_ is not None:
|
|
||||||
await update_neo4j_data(ori_expired_at, updat_expired_at)
|
|
||||||
success_count += 1
|
|
||||||
if ori_entity != updata_entity:
|
if ori_entity != updata_entity:
|
||||||
await update_neo4j_data(ori_entity, updata_entity)
|
await update_neo4j_data(ori_entity, updata_entity)
|
||||||
success_count += 1
|
success_count += 1
|
||||||
|
|||||||
@@ -50,12 +50,12 @@ class StatementRepository(BaseNeo4jRepository[StatementNode]):
|
|||||||
# 处理datetime字段
|
# 处理datetime字段
|
||||||
if isinstance(n.get('created_at'), str):
|
if isinstance(n.get('created_at'), str):
|
||||||
n['created_at'] = datetime.fromisoformat(n['created_at'])
|
n['created_at'] = datetime.fromisoformat(n['created_at'])
|
||||||
if n.get('expired_at') and isinstance(n['expired_at'], str):
|
|
||||||
n['expired_at'] = datetime.fromisoformat(n['expired_at'])
|
|
||||||
if n.get('valid_at') and isinstance(n['valid_at'], str):
|
if n.get('valid_at') and isinstance(n['valid_at'], str):
|
||||||
n['valid_at'] = datetime.fromisoformat(n['valid_at'])
|
n['valid_at'] = datetime.fromisoformat(n['valid_at'])
|
||||||
if n.get('invalid_at') and isinstance(n['invalid_at'], str):
|
if n.get('invalid_at') and isinstance(n['invalid_at'], str):
|
||||||
n['invalid_at'] = datetime.fromisoformat(n['invalid_at'])
|
n['invalid_at'] = datetime.fromisoformat(n['invalid_at'])
|
||||||
|
if n.get('dialog_at') and isinstance(n['dialog_at'], str):
|
||||||
|
n['dialog_at'] = datetime.fromisoformat(n['dialog_at'])
|
||||||
|
|
||||||
# 处理temporal_info字段
|
# 处理temporal_info字段
|
||||||
if isinstance(n.get('temporal_info'), str):
|
if isinstance(n.get('temporal_info'), str):
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from abc import ABC
|
from abc import ABC
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@@ -21,6 +22,10 @@ class MessageItem(BaseModel):
|
|||||||
"""单条消息结构"""
|
"""单条消息结构"""
|
||||||
role: str
|
role: str
|
||||||
content: str
|
content: str
|
||||||
|
dialog_at: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
description="该条消息发生的绝对时间(ISO 8601 格式),不传则使用服务端当前时间",
|
||||||
|
)
|
||||||
files: Optional[list[dict]] = None
|
files: Optional[list[dict]] = None
|
||||||
file_content: Optional[list[Any]] = None
|
file_content: Optional[list[Any]] = None
|
||||||
|
|
||||||
|
|||||||
@@ -457,7 +457,9 @@ class MemoryAgentService:
|
|||||||
|
|
||||||
if use_new_pipeline:
|
if use_new_pipeline:
|
||||||
service = MemoryService(memory_config=memory_config, end_user_id=end_user_id)
|
service = MemoryService(memory_config=memory_config, end_user_id=end_user_id)
|
||||||
result = await service.write(messages=messages_dict, language=language, ref_id='')
|
result = await service.write(
|
||||||
|
messages=messages_dict, language=language, ref_id='',
|
||||||
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[NewPipeline] 完成: status={result.status}, "
|
f"[NewPipeline] 完成: status={result.status}, "
|
||||||
f"elapsed={result.elapsed_seconds:.2f}s, "
|
f"elapsed={result.elapsed_seconds:.2f}s, "
|
||||||
|
|||||||
180
api/app/tasks.py
180
api/app/tasks.py
@@ -1564,6 +1564,186 @@ def extract_emotion_batch_task(
|
|||||||
_shutdown_loop_gracefully(loop)
|
_shutdown_loop_gracefully(loop)
|
||||||
|
|
||||||
|
|
||||||
|
@celery_app.task(
    bind=True,
    name="app.tasks.post_store_dedup_and_alias_merge",
    max_retries=1,
    default_retry_delay=30,
)
def post_store_dedup_and_alias_merge_task(
    self,
    end_user_id: str,
    entity_ids: List[str],
    llm_model_id: Optional[str] = None,
) -> Dict[str, Any]:
    """Celery task: asynchronous alias merge and second-layer dedup after a write.

    Runs three stages against Neo4j, in order:

    1. Alias merge: executes ``MERGE_ALIAS_BELONGS_TO`` for the end user.
    2. Edge redirect: executes ``REDIRECT_ALIAS_EDGES`` for the end user.
    3. Second-layer dedup: loads the entities named by ``entity_ids``, fuses
       them via ``second_layer_dedup_and_merge_with_neo4j``, cleans cross-role
       aliases and writes the fused entities back with
       ``EXTRACTED_ENTITY_NODE_SAVE``.

    Each stage is best-effort: a failure is logged and recorded in the result
    under an ``*_error`` key, and the following stages still run. Only a
    failure outside the stages (e.g. event-loop setup or connector creation)
    triggers a Celery retry.

    Args:
        end_user_id: End-user ID whose graph is processed.
        entity_ids: IDs of the entities written in this round; they are the
            input of the second-layer dedup. ``None`` or empty skips the
            entity load.
        llm_model_id: Optional LLM model ID. When given, an LLM client is
            built and passed to the dedup step; if building it fails the
            dedup runs without it.

    Returns:
        A dict with ``status``, ``elapsed_time``, ``task_id`` and the
        per-stage keys (``alias_merged``, ``edges_redirected``,
        ``layer2_input``, ``layer2_output``, ``layer2_rows_skipped``,
        ``layer2_skipped`` or the corresponding ``*_error`` keys).
    """
    task_id = self.request.id
    # Tolerate a missing list so the log line and the load query cannot crash.
    entity_ids = list(entity_ids or [])
    logger.info(
        f"[PostStore] 开始异步别名归并+第二层去重: "
        f"end_user_id={end_user_id}, entity_count={len(entity_ids)}, "
        f"task_id={task_id}"
    )
    start_time = time.time()

    async def _run() -> Dict[str, Any]:
        # Imports are local to the task body, as in the rest of this function.
        from app.repositories.neo4j.neo4j_connector import Neo4jConnector
        from app.repositories.neo4j.cypher_queries import (
            MERGE_ALIAS_BELONGS_TO,
            REDIRECT_ALIAS_EDGES,
        )

        connector = Neo4jConnector()
        result_info: Dict[str, Any] = {}

        try:
            # -- 1. Neo4j alias merge --
            try:
                records = await connector.execute_query(
                    MERGE_ALIAS_BELONGS_TO,
                    end_user_id=end_user_id,
                )
                merged_count = len(records) if records else 0
                result_info["alias_merged"] = merged_count
                logger.info(f"[PostStore] Neo4j 别名归并完成,影响 {merged_count} 条记录")
            except Exception as e:
                logger.warning(f"[PostStore] Neo4j 别名归并失败: {e}")
                result_info["alias_merge_error"] = str(e)

            # -- 2. Neo4j edge redirect --
            try:
                redirect_records = await connector.execute_query(
                    REDIRECT_ALIAS_EDGES,
                    end_user_id=end_user_id,
                )
                redirect_count = len(redirect_records) if redirect_records else 0
                result_info["edges_redirected"] = redirect_count
                logger.info(f"[PostStore] Neo4j 边重定向完成,影响 {redirect_count} 条记录")
            except Exception as e:
                logger.warning(f"[PostStore] Neo4j 边重定向失败: {e}")
                result_info["redirect_error"] = str(e)

            # -- 3. Second-layer dedup (joint dedup with entities already in Neo4j) --
            try:
                from app.core.memory.storage_services.extraction_engine.deduplication.second_layer_dedup import (
                    _row_to_entity,
                    second_layer_dedup_and_merge_with_neo4j,
                )
                from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
                    clean_cross_role_aliases,
                )
                from app.repositories.neo4j.cypher_queries import EXTRACTED_ENTITY_NODE_SAVE

                # Load the entities written in this round (first-layer dedup output).
                # NOTE(review): the lookup is by id only and is not scoped to
                # end_user_id - confirm entity ids are globally unique.
                load_query = """
                UNWIND $entity_ids AS eid
                MATCH (e:ExtractedEntity {id: eid})
                RETURN e {.*} AS entity
                """
                if entity_ids:
                    entity_records = await connector.execute_query(
                        load_query, entity_ids=entity_ids
                    )
                else:
                    # Nothing to look up; skip the round trip.
                    entity_records = []

                if entity_records:
                    current_entities = []
                    skipped_rows = 0
                    for rec in entity_records:
                        try:
                            entity_data = rec.get("entity") or rec
                            current_entities.append(_row_to_entity(entity_data))
                        except Exception as row_err:
                            # A malformed row must not abort the batch, but it
                            # should be visible instead of silently dropped.
                            skipped_rows += 1
                            logger.warning(
                                f"[PostStore] 实体记录转换失败,已跳过: {row_err}"
                            )
                    if skipped_rows:
                        result_info["layer2_rows_skipped"] = skipped_rows

                    if current_entities:
                        # Build the LLM client only when a model id was supplied.
                        llm_client = None
                        if llm_model_id:
                            try:
                                from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
                                from app.db import get_db_context

                                with get_db_context() as db:
                                    factory = MemoryClientFactory(db)
                                    llm_client = factory.get_llm_client(llm_model_id)
                            except Exception as e:
                                logger.warning(f"[PostStore] 构建 LLM client 失败,跳过 LLM 兜底: {e}")

                        fused_entities, _, _ = await second_layer_dedup_and_merge_with_neo4j(
                            connector=connector,
                            end_user_id=end_user_id,
                            entity_nodes=current_entities,
                            statement_entity_edges=[],
                            entity_entity_edges=[],
                            llm_client=llm_client,
                        )

                        # Remove alias pollution across roles.
                        clean_cross_role_aliases(fused_entities)

                        # Write the fused entities back to Neo4j.
                        if fused_entities:
                            fused_payload = [ent.model_dump() for ent in fused_entities]
                            await connector.execute_query(
                                EXTRACTED_ENTITY_NODE_SAVE, entities=fused_payload
                            )

                        result_info["layer2_input"] = len(current_entities)
                        result_info["layer2_output"] = len(fused_entities)
                        logger.info(
                            f"[PostStore] 第二层去重完成: "
                            f"{len(current_entities)} → {len(fused_entities)} 个实体"
                        )
                    else:
                        result_info["layer2_skipped"] = "no entities loaded"
                else:
                    result_info["layer2_skipped"] = "no entity records found"

            except Exception as e:
                logger.warning(f"[PostStore] 第二层去重失败(不影响主流程): {e}", exc_info=True)
                result_info["layer2_error"] = str(e)

        finally:
            # Always release the connector, whatever happened in the stages.
            await connector.close()

        return result_info

    loop = None
    try:
        loop = set_asyncio_event_loop()
        result = loop.run_until_complete(_run())
        elapsed = time.time() - start_time
        logger.info(
            f"[PostStore] 任务完成: {result}, 耗时={elapsed:.2f}s, task_id={task_id}"
        )
        return {
            "status": "SUCCESS",
            **result,
            "elapsed_time": elapsed,
            "task_id": task_id,
        }
    except Exception as e:
        elapsed = time.time() - start_time
        logger.error(
            f"[PostStore] 任务失败: {e}, 耗时={elapsed:.2f}s",
            exc_info=True,
        )
        raise self.retry(exc=e)
    finally:
        if loop:
            _shutdown_loop_gracefully(loop)
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(
|
@celery_app.task(
|
||||||
bind=True,
|
bind=True,
|
||||||
name="app.tasks.extract_metadata_batch",
|
name="app.tasks.extract_metadata_batch",
|
||||||
|
|||||||
@@ -156,5 +156,5 @@ testpaths = ["tests"]
|
|||||||
python_files = ["test_*.py"]
|
python_files = ["test_*.py"]
|
||||||
python_classes = ["Test*"]
|
python_classes = ["Test*"]
|
||||||
python_functions = ["test_*"]
|
python_functions = ["test_*"]
|
||||||
# 使用 anyio 作为异步测试后端
|
# 使用 asyncio 作为异步测试后端
|
||||||
anyio_backends = ["asyncio"]
|
asyncio_mode = "auto"
|
||||||
|
|||||||
Reference in New Issue
Block a user