feat(memory): support perception-aware memory writing in workflow and Neo4j nodes
This commit is contained in:
@@ -9,21 +9,22 @@ Classes:
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from uuid import UUID
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import desc, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_config_logger, get_db_logger
|
||||
from app.models.memory_config_model import MemoryConfig
|
||||
from app.models.workspace_model import Workspace
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigKey,
|
||||
ConfigParamsCreate,
|
||||
ConfigUpdate,
|
||||
ConfigUpdateExtracted,
|
||||
ConfigUpdateForget,
|
||||
)
|
||||
from sqlalchemy import desc, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
# 获取数据库专用日志器
|
||||
@@ -157,7 +158,7 @@ class MemoryConfigRepository:
|
||||
return memory_config_obj
|
||||
|
||||
@staticmethod
|
||||
def query_reflection_config_by_id(db: Session, config_id: uuid.UUID|int|str) -> MemoryConfig:
|
||||
def query_reflection_config_by_id(db: Session, config_id: uuid.UUID | int | str) -> MemoryConfig:
|
||||
"""构建反思配置查询语句,通过config_id查询反思配置(SQLAlchemy text() 命名参数)
|
||||
|
||||
Args:
|
||||
@@ -491,7 +492,10 @@ class MemoryConfigRepository:
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def get_config_with_workspace(db: Session, config_id: uuid.UUID | int | str) -> Optional[tuple]:
|
||||
def get_config_with_workspace(
|
||||
db: Session,
|
||||
config_id: uuid.UUID | int | str
|
||||
) -> Optional[tuple[MemoryConfig, Workspace]]:
|
||||
"""Get memory config and its associated workspace information
|
||||
|
||||
Args:
|
||||
@@ -506,8 +510,6 @@ class MemoryConfigRepository:
|
||||
"""
|
||||
import time
|
||||
|
||||
from app.models.workspace_model import Workspace
|
||||
|
||||
start_time = time.time()
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
|
||||
@@ -594,7 +596,7 @@ class MemoryConfigRepository:
|
||||
|
||||
db_logger.debug(
|
||||
f"Memory config and workspace query successful: config={config.config_name}, workspace={workspace.name}")
|
||||
return (config, workspace)
|
||||
return config, workspace
|
||||
|
||||
except ValueError:
|
||||
# Re-raise known business exceptions
|
||||
@@ -630,7 +632,7 @@ class MemoryConfigRepository:
|
||||
List[Tuple[MemoryConfig, Optional[str]]]: 配置列表,每项为 (配置对象, 场景名称)
|
||||
"""
|
||||
from app.models.ontology_scene import OntologyScene
|
||||
|
||||
|
||||
db_logger.debug(f"查询所有配置: workspace_id={workspace_id}")
|
||||
|
||||
try:
|
||||
@@ -694,7 +696,7 @@ class MemoryConfigRepository:
|
||||
Optional[MemoryConfig]: 默认配置对象,不存在则返回None
|
||||
"""
|
||||
db_logger.debug(f"查询工作空间默认配置: workspace_id={workspace_id}")
|
||||
|
||||
|
||||
try:
|
||||
# 优先查找显式标记为默认的配置
|
||||
stmt = (
|
||||
@@ -706,13 +708,13 @@ class MemoryConfigRepository:
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
|
||||
config = db.scalars(stmt).first()
|
||||
|
||||
|
||||
if config:
|
||||
db_logger.debug(f"找到默认配置: config_id={config.config_id}")
|
||||
return config
|
||||
|
||||
|
||||
# 回退:获取最早创建的活跃配置
|
||||
stmt = (
|
||||
select(MemoryConfig)
|
||||
@@ -723,25 +725,25 @@ class MemoryConfigRepository:
|
||||
.order_by(MemoryConfig.created_at.asc())
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
|
||||
config = db.scalars(stmt).first()
|
||||
|
||||
|
||||
if config:
|
||||
db_logger.debug(f"使用最早创建的配置作为默认: config_id={config.config_id}")
|
||||
else:
|
||||
db_logger.warning(f"工作空间没有活跃的记忆配置: workspace_id={workspace_id}")
|
||||
|
||||
|
||||
return config
|
||||
|
||||
|
||||
except Exception as e:
|
||||
db_logger.error(f"查询工作空间默认配置失败: workspace_id={workspace_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def get_with_fallback(
|
||||
db: Session,
|
||||
config_id: Optional[uuid.UUID],
|
||||
workspace_id: uuid.UUID
|
||||
db: Session,
|
||||
config_id: Optional[uuid.UUID],
|
||||
workspace_id: uuid.UUID
|
||||
) -> Optional[MemoryConfig]:
|
||||
"""获取记忆配置,支持回退到工作空间默认配置
|
||||
|
||||
@@ -756,19 +758,18 @@ class MemoryConfigRepository:
|
||||
Optional[MemoryConfig]: 配置对象,如果都不存在则返回None
|
||||
"""
|
||||
db_logger.debug(f"查询配置(支持回退): config_id={config_id}, workspace_id={workspace_id}")
|
||||
|
||||
|
||||
if not config_id:
|
||||
db_logger.debug("config_id 为空,使用工作空间默认配置")
|
||||
return MemoryConfigRepository.get_workspace_default(db, workspace_id)
|
||||
|
||||
|
||||
config = db.get(MemoryConfig, config_id)
|
||||
|
||||
|
||||
if config:
|
||||
return config
|
||||
|
||||
|
||||
db_logger.warning(
|
||||
f"配置不存在,回退到工作空间默认配置: missing_config_id={config_id}, workspace_id={workspace_id}"
|
||||
)
|
||||
|
||||
return MemoryConfigRepository.get_workspace_default(db, workspace_id)
|
||||
|
||||
return MemoryConfigRepository.get_workspace_default(db, workspace_id)
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE,MEMORY_SUMMARY_NODE_SAVE
|
||||
from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode
|
||||
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE, \
|
||||
MEMORY_SUMMARY_NODE_SAVE, PERCEPTUAL_NODE_SAVE, PERCEPTUAL_DIALOGUE_EDGE_SAVE
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
@@ -12,6 +13,7 @@ async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector):
|
||||
print(f"All end_user_id: {end_user_id} node and edge deleted successfully")
|
||||
return result
|
||||
|
||||
|
||||
async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]:
|
||||
"""Add dialogue nodes to Neo4j database.
|
||||
|
||||
@@ -127,6 +129,7 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
|
||||
print(f"Error creating statement nodes: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> Optional[List[str]]:
|
||||
"""Add chunk nodes to Neo4j in batch.
|
||||
|
||||
@@ -179,8 +182,8 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) ->
|
||||
return None
|
||||
|
||||
|
||||
|
||||
async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector: Neo4jConnector) -> Optional[List[str]]:
|
||||
async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector: Neo4jConnector) -> Optional[
|
||||
List[str]]:
|
||||
"""Add memory summary nodes to Neo4j in batch.
|
||||
|
||||
Args:
|
||||
@@ -211,7 +214,7 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
|
||||
"summary_embedding": s.summary_embedding if s.summary_embedding else None,
|
||||
"config_id": s.config_id, # 添加 config_id
|
||||
})
|
||||
|
||||
|
||||
result = await connector.execute_query(
|
||||
MEMORY_SUMMARY_NODE_SAVE,
|
||||
summaries=flattened
|
||||
@@ -224,3 +227,103 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
|
||||
return None
|
||||
|
||||
|
||||
def _flatten_perceptual(p, summary_embedding) -> dict:
    """Build the flat property map for one perceptual record."""
    # meta_data / meta_data["content"] may be missing, null, or not a mapping;
    # treat anything that is not a dict as "no content metadata" so that one
    # malformed record cannot abort the whole batch.
    meta = p.meta_data if isinstance(p.meta_data, dict) else {}
    content_meta = meta.get("content")
    if not isinstance(content_meta, dict):
        content_meta = {}

    return {
        "id": str(p.id),
        "end_user_id": str(p.end_user_id),
        "perceptual_type": p.perceptual_type,
        "file_path": p.file_path or "",
        "file_name": p.file_name or "",
        "file_ext": p.file_ext or "",
        "summary": p.summary or "",
        # Explicit nulls fall back to the same defaults as missing keys.
        "keywords": content_meta.get("keywords") or [],
        "topic": content_meta.get("topic") or "",
        "domain": content_meta.get("domain") or "",
        "created_at": p.created_time.isoformat() if p.created_time else None,
        "summary_embedding": summary_embedding,
    }


async def add_perceptual_nodes(
        perceptuals: list,
        connector: Neo4jConnector,
        embedder_client=None,
) -> Optional[List[str]]:
    """Add perceptual memory nodes to Neo4j in batch.

    Each record is flattened into a property map and the whole batch is sent
    in a single ``PERCEPTUAL_NODE_SAVE`` query.

    Args:
        perceptuals: Records exposing ``id``, ``end_user_id``,
            ``perceptual_type``, ``file_path``, ``file_name``, ``file_ext``,
            ``summary``, ``meta_data`` and ``created_time`` attributes.
        connector: Neo4j connector instance; its ``execute_query`` coroutine
            is awaited once with the flattened batch.
        embedder_client: Optional client whose ``response([text])`` coroutine
            returns a list of embeddings. When omitted, or when a record has
            no summary, ``summary_embedding`` is sent as ``None``.

    Returns:
        The ``uuid`` value of every record returned by the query, an empty
        list when ``perceptuals`` is empty, or ``None`` if saving failed.
    """
    if not perceptuals:
        print("No perceptual nodes to add")
        return []

    try:
        flattened = []
        for p in perceptuals:
            # Embedding is best-effort: a failure is reported and the node is
            # still saved without an embedding.
            summary_embedding = None
            if embedder_client and p.summary:
                try:
                    summary_embedding = (await embedder_client.response([p.summary]))[0]
                except Exception as emb_err:
                    print(f"Failed to embed perceptual summary: {emb_err}")

            flattened.append(_flatten_perceptual(p, summary_embedding))

        result = await connector.execute_query(
            PERCEPTUAL_NODE_SAVE,
            perceptuals=flattened,
        )
        created_uuids = [record.get("uuid") for record in result]
        print(f"Successfully saved {len(created_uuids)} Perceptual nodes to Neo4j")
        return created_uuids

    except Exception as e:
        print(f"Failed to save Perceptual nodes to Neo4j: {e}")
        return None
|
||||
|
||||
|
||||
async def add_perceptual_dialogue_edges(
        perceptuals: list,
        dialog_id: str,
        connector: Neo4jConnector,
) -> Optional[List[str]]:
    """Add edges between Perceptual nodes and Dialogue nodes.

    One edge descriptor is built per perceptual record and the whole batch is
    sent in a single ``PERCEPTUAL_DIALOGUE_EDGE_SAVE`` query.

    Args:
        perceptuals: Records exposing ``id``, ``end_user_id`` and
            ``created_time`` attributes.
        dialog_id: The dialogue ID (or ref_id) to link to. It is converted to
            ``str`` before being sent, like the other identifiers.
        connector: Neo4j connector instance; its ``execute_query`` coroutine
            is awaited once with the edge batch.

    Returns:
        The ``uuid`` value of every record returned by the query, an empty
        list when there is nothing to link, or ``None`` if saving failed.
    """
    if not perceptuals or not dialog_id:
        return []

    # Every other identifier is sent as a string; do the same here so a
    # non-string id (e.g. uuid.UUID) is not handed to the driver as-is.
    dialog_key = str(dialog_id)

    try:
        edges = [
            {
                "perceptual_id": str(p.id),
                "dialog_id": dialog_key,
                "end_user_id": str(p.end_user_id),
                "created_at": p.created_time.isoformat() if p.created_time else None,
            }
            for p in perceptuals
        ]

        result = await connector.execute_query(
            PERCEPTUAL_DIALOGUE_EDGE_SAVE,
            edges=edges,
        )
        created_ids = [record.get("uuid") for record in result]
        if len(created_ids) < len(edges):
            # The query only returns rows for edges whose endpoints matched;
            # surface the shortfall instead of reporting plain success.
            print(
                f"Only {len(created_ids)} of {len(edges)} Perceptual-Dialogue edges "
                f"matched existing nodes for dialog_id={dialog_key}"
            )
        print(f"Successfully saved {len(created_ids)} Perceptual-Dialogue edges to Neo4j")
        return created_ids

    except Exception as e:
        print(f"Failed to save Perceptual-Dialogue edges: {e}")
        return None
|
||||
|
||||
@@ -1323,3 +1323,36 @@ RETURN s.statement AS statement,
|
||||
ORDER BY COALESCE(s.activation_value, 0) DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
|
||||
# Save perceptual memory nodes.
# Expects $perceptuals: a list of maps carrying the keys referenced below.
# Nodes are merged on `id`, the listed properties are written with `+=`
# (other existing properties are left untouched), and each node's id is
# returned as `uuid`.
PERCEPTUAL_NODE_SAVE = """
UNWIND $perceptuals AS p
MERGE (n:Perceptual {id: p.id})
SET n += {
    id: p.id,
    end_user_id: p.end_user_id,
    perceptual_type: p.perceptual_type,
    file_path: p.file_path,
    file_name: p.file_name,
    file_ext: p.file_ext,
    summary: p.summary,
    keywords: p.keywords,
    topic: p.topic,
    domain: p.domain,
    created_at: p.created_at,
    summary_embedding: p.summary_embedding
}
RETURN n.id AS uuid
"""
|
||||
|
||||
# Link perceptual memory nodes to their dialogue.
# Expects $edges: a list of maps with perceptual_id, dialog_id, end_user_id
# and created_at. The dialogue is matched by either `id` or `ref_id` within
# the same end_user_id; rows are returned only for edges whose Perceptual and
# Dialogue nodes both matched.
PERCEPTUAL_DIALOGUE_EDGE_SAVE = """
UNWIND $edges AS edge
MATCH (p:Perceptual {id: edge.perceptual_id, end_user_id: edge.end_user_id})
MATCH (d:Dialogue {end_user_id: edge.end_user_id})
WHERE d.id = edge.dialog_id OR d.ref_id = edge.dialog_id
MERGE (d)-[r:HAS_PERCEPTUAL]->(p)
SET r.end_user_id = edge.end_user_id,
    r.created_at = edge.created_at
RETURN elementId(r) AS uuid
"""
|
||||
|
||||
Reference in New Issue
Block a user