feat(memory): support perception-aware memory writing in workflow and Neo4j nodes
This commit is contained in:
@@ -5,6 +5,7 @@ This module provides the main write function for executing the knowledge extract
|
||||
pipeline. Only MemoryConfig is needed - clients are constructed internally.
|
||||
"""
|
||||
import asyncio
|
||||
import uuid
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
@@ -13,28 +14,31 @@ from dotenv import load_dotenv
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import memory_summary_generation
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import \
|
||||
memory_summary_generation
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.memory.utils.log.logging_utils import log_time
|
||||
from app.db import get_db_context
|
||||
from app.models import MemoryPerceptualModel
|
||||
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes, add_perceptual_nodes, \
|
||||
add_perceptual_dialogue_edges
|
||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, schedule_clustering_after_write
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
async def write(
|
||||
end_user_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
messages: list,
|
||||
ref_id: str = "wyl20251027",
|
||||
language: str = "zh",
|
||||
end_user_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
messages: list,
|
||||
file_content: list[MemoryPerceptualModel],
|
||||
ref_id: str = "",
|
||||
language: str = "zh",
|
||||
) -> None:
|
||||
"""
|
||||
Execute the complete knowledge extraction pipeline.
|
||||
@@ -43,9 +47,12 @@ async def write(
|
||||
end_user_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||
ref_id: Reference ID, defaults to "wyl20251027"
|
||||
file_content: Multimodal message list (MemoryPerceptualModel items)
|
||||
ref_id: Reference ID, defaults to ""; when empty, a random UUID hex string is generated
|
||||
language: Language type ("zh" Chinese, "en" English), defaults to Chinese
|
||||
"""
|
||||
if not ref_id:
|
||||
ref_id = uuid.uuid4().hex
|
||||
# Extract config values
|
||||
embedding_model_id = str(memory_config.embedding_model_id)
|
||||
chunker_strategy = memory_config.chunker_strategy
|
||||
@@ -99,14 +106,14 @@ async def write(
|
||||
if memory_config.scene_id:
|
||||
try:
|
||||
from app.core.memory.ontology_services.ontology_type_loader import load_ontology_types_for_scene
|
||||
|
||||
|
||||
with get_db_context() as db:
|
||||
ontology_types = load_ontology_types_for_scene(
|
||||
scene_id=memory_config.scene_id,
|
||||
workspace_id=memory_config.workspace_id,
|
||||
db=db
|
||||
)
|
||||
|
||||
|
||||
if ontology_types:
|
||||
logger.info(
|
||||
f"Loaded {len(ontology_types.types)} ontology types for scene_id: {memory_config.scene_id}"
|
||||
@@ -173,7 +180,8 @@ async def write(
|
||||
schedule_clustering_after_write(
|
||||
all_entity_nodes,
|
||||
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
||||
embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
|
||||
embedding_model_id=str(
|
||||
memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
|
||||
)
|
||||
break
|
||||
else:
|
||||
@@ -208,9 +216,8 @@ async def write(
|
||||
summaries = await memory_summary_generation(
|
||||
chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client, language=language
|
||||
)
|
||||
|
||||
ms_connector = Neo4jConnector()
|
||||
try:
|
||||
ms_connector = Neo4jConnector()
|
||||
await add_memory_summary_nodes(summaries, ms_connector)
|
||||
await add_memory_summary_statement_edges(summaries, ms_connector)
|
||||
finally:
|
||||
@@ -223,6 +230,34 @@ async def write(
|
||||
finally:
|
||||
log_time("Memory Summary (Neo4j)", time.time() - step_start, log_file)
|
||||
|
||||
# Step 5: Save perceptual memory to Neo4j
|
||||
step_start = time.time()
|
||||
if file_content:
|
||||
try:
|
||||
pc_connector = Neo4jConnector()
|
||||
try:
|
||||
created_ids = await add_perceptual_nodes(
|
||||
perceptuals=file_content,
|
||||
connector=pc_connector,
|
||||
embedder_client=embedder_client,
|
||||
)
|
||||
# If ref_id is set, link the perceptual memories to the dialogue
|
||||
if ref_id and created_ids:
|
||||
await add_perceptual_dialogue_edges(
|
||||
perceptuals=file_content,
|
||||
dialog_id=ref_id,
|
||||
connector=pc_connector,
|
||||
)
|
||||
logger.info(f"Successfully saved {len(created_ids or [])} perceptual memory nodes to Neo4j")
|
||||
finally:
|
||||
try:
|
||||
await pc_connector.close()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Perceptual memory Neo4j save failed: {e}", exc_info=True)
|
||||
log_time("Perceptual Memory (Neo4j)", time.time() - step_start, log_file)
|
||||
|
||||
# Log total pipeline time
|
||||
total_time = time.time() - pipeline_start
|
||||
log_time("TOTAL PIPELINE TIME", total_time, log_file)
|
||||
@@ -251,4 +286,4 @@ async def write(
|
||||
logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
|
||||
|
||||
logger.info("=== Pipeline Complete ===")
|
||||
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
||||
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
||||
|
||||
@@ -553,3 +553,21 @@ class MemorySummaryNode(Node):
|
||||
ge=0,
|
||||
description="Total number of times this node has been accessed (reset to 1 on creation)"
|
||||
)
|
||||
|
||||
|
||||
class MutlimodalNode(Node):
|
||||
"""Node representing a multimodal message in the knowledge graph.
|
||||
|
||||
Attributes:
|
||||
dialog_id: ID of the parent dialog
|
||||
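message_id: ID of the message
|
||||
summary: The text content of the message
|
||||
file_type: Type of the message (e.g., 'text', 'image', 'audio', 'video')
|
||||
file_path: List of file paths for multimodal content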
|
||||
metadata: Additional message metadata
|
||||
embedding: Optional embedding vector for the message
|
||||
"""
|
||||
dialog_id: str = Field(..., description="ID of the parent dialog")
|
||||
message_id: str = Field(..., description="ID of the message")
|
||||
summary: str = Field(..., description="The text content of the message")
|
||||
file_type: str = Field(..., description="Type of the message (e.g., 'text', 'image', 'audio', 'video')")
|
||||
file_path: List[str] = Field(..., description="List of file paths for multimodal content")
|
||||
metadata: dict = Field(default_factory=dict, description="Additional message metadata")
|
||||
embedding: Optional[List[float]] = Field(None, description="Embedding vector for the message")
|
||||
|
||||
@@ -25,17 +25,17 @@ from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
|
||||
async def dedup_layers_and_merge_and_return(
|
||||
dialogue_nodes: List[DialogueNode],
|
||||
chunk_nodes: List[ChunkNode],
|
||||
statement_nodes: List[StatementNode],
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
statement_chunk_edges: List[StatementChunkEdge],
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
dialog_data_list: List[DialogData],
|
||||
pipeline_config: ExtractionPipelineConfig,
|
||||
connector: Optional[Neo4jConnector] = None,
|
||||
llm_client = None,
|
||||
dialogue_nodes: List[DialogueNode],
|
||||
chunk_nodes: List[ChunkNode],
|
||||
statement_nodes: List[StatementNode],
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
statement_chunk_edges: List[StatementChunkEdge],
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
dialog_data_list: List[DialogData],
|
||||
pipeline_config: ExtractionPipelineConfig,
|
||||
connector: Optional[Neo4jConnector] = None,
|
||||
llm_client=None,
|
||||
) -> Tuple[
|
||||
List[DialogueNode],
|
||||
List[ChunkNode],
|
||||
@@ -44,7 +44,7 @@ async def dedup_layers_and_merge_and_return(
|
||||
List[StatementChunkEdge],
|
||||
List[StatementEntityEdge],
|
||||
List[EntityEntityEdge],
|
||||
dict, # 新增:返回去重详情
|
||||
dict
|
||||
]:
|
||||
"""
|
||||
执行两层实体去重与融合:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -188,7 +188,6 @@ async def _process_chunk_summary(
|
||||
response_model=MemorySummaryResponse,
|
||||
)
|
||||
summary_text = structured.summary.strip()
|
||||
|
||||
# Generate title and type for the summary
|
||||
title = None
|
||||
episodic_type = None
|
||||
|
||||
@@ -374,7 +374,9 @@ class VariablePool:
|
||||
self.variables = deepcopy(pool.variables)
|
||||
|
||||
def is_file_variable(self, selector):
|
||||
variable_struct = self._get_variable_struct(selector)
|
||||
variable_struct = self.get_instance(selector, default=None, strict=False)
|
||||
if variable_struct is None:
|
||||
return False
|
||||
if isinstance(variable_struct, FileVariable):
|
||||
return True
|
||||
elif isinstance(variable_struct, ArrayVariable) and variable_struct.child_type == FileVariable:
|
||||
|
||||
@@ -623,7 +623,6 @@ class BaseNode(ABC):
|
||||
async def process_message(
|
||||
api_config: ModelInfo,
|
||||
content: str | dict | FileObject,
|
||||
end_user_id: str,
|
||||
enable_file=False
|
||||
) -> list | str | None:
|
||||
provider = api_config.provider
|
||||
@@ -642,8 +641,8 @@ class BaseNode(ABC):
|
||||
return content
|
||||
|
||||
elif isinstance(content, FileObject):
|
||||
if content.content_cache.get(provider):
|
||||
return content.content_cache[provider]
|
||||
if content.content_cache.get(f"{provider}_{ModelInfo.is_omni}"):
|
||||
return content.content_cache[f"{provider}_{ModelInfo.is_omni}"]
|
||||
with get_db_read() as db:
|
||||
multimodel_service = MultimodalService(db, api_config=api_config)
|
||||
file_obj = FileInput(
|
||||
@@ -655,12 +654,11 @@ class BaseNode(ABC):
|
||||
)
|
||||
file_obj.set_content(content.get_content())
|
||||
message = await multimodel_service.process_files(
|
||||
end_user_id,
|
||||
[file_obj],
|
||||
)
|
||||
content.set_content(file_obj.get_content())
|
||||
if message:
|
||||
content.content_cache[provider] = message
|
||||
content.content_cache[f"{provider}_{ModelInfo.is_omni}"] = message
|
||||
return message
|
||||
return None
|
||||
raise TypeError(f'Unexpect input value type - {type(content)}')
|
||||
|
||||
@@ -144,7 +144,6 @@ class LLMNode(BaseNode):
|
||||
f"创建 LLM 实例: provider={model_info.provider}, model={model_info.model_name}, streaming={stream}")
|
||||
|
||||
messages_config = self.typed_config.messages
|
||||
|
||||
if messages_config:
|
||||
# 使用 LangChain 消息格式
|
||||
messages = []
|
||||
@@ -153,7 +152,6 @@ class LLMNode(BaseNode):
|
||||
content_template = msg_config.content
|
||||
content_template = self._render_context(content_template, variable_pool)
|
||||
content = self._render_template(content_template, variable_pool)
|
||||
user_id = self.get_variable("sys.user_id", variable_pool)
|
||||
# 根据角色创建对应的消息对象
|
||||
if role == "system":
|
||||
messages.append({
|
||||
@@ -161,32 +159,31 @@ class LLMNode(BaseNode):
|
||||
"content": await self.process_message(
|
||||
model_info,
|
||||
content,
|
||||
user_id,
|
||||
self.typed_config.vision,
|
||||
)
|
||||
})
|
||||
elif role in ["user", "human"]:
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
|
||||
"content": await self.process_message(model_info, content, self.typed_config.vision)
|
||||
})
|
||||
elif role in ["ai", "assistant"]:
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
|
||||
"content": await self.process_message(model_info, content, self.typed_config.vision)
|
||||
})
|
||||
else:
|
||||
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
|
||||
"content": await self.process_message(model_info, content, self.typed_config.vision)
|
||||
})
|
||||
|
||||
if self.typed_config.vision_input and self.typed_config.vision:
|
||||
file_content = []
|
||||
files = variable_pool.get_instance(self.typed_config.vision_input)
|
||||
for file in files.value:
|
||||
content = await self.process_message(model_info, file.value, user_id, self.typed_config.vision)
|
||||
content = await self.process_message(model_info, file.value, self.typed_config.vision)
|
||||
if content:
|
||||
file_content.extend(content)
|
||||
if messages and messages[-1]["role"] == 'user':
|
||||
@@ -200,7 +197,7 @@ class LLMNode(BaseNode):
|
||||
if isinstance(message["content"], list):
|
||||
file_content = []
|
||||
for file in message["content"]:
|
||||
content = await self.process_message(model_info, file, user_id, self.typed_config.vision)
|
||||
content = await self.process_message(model_info, file, self.typed_config.vision)
|
||||
if content:
|
||||
file_content.extend(content)
|
||||
history_message.append(
|
||||
@@ -210,7 +207,6 @@ class LLMNode(BaseNode):
|
||||
message["content"] = await self.process_message(
|
||||
model_info,
|
||||
message["content"],
|
||||
user_id,
|
||||
self.typed_config.vision
|
||||
)
|
||||
history_message.append(message)
|
||||
|
||||
@@ -116,6 +116,7 @@ class MemoryWriteNode(BaseNode):
|
||||
write_message_task.delay(
|
||||
end_user_id=end_user_id,
|
||||
message=messages,
|
||||
file_messages=multimodal_memories,
|
||||
config_id=str(self.typed_config.config_id),
|
||||
storage_type=state["memory_storage_type"],
|
||||
user_rag_memory_id=state["user_rag_memory_id"]
|
||||
|
||||
Reference in New Issue
Block a user