feat(memory): support perception-aware memory writing in workflow and Neo4j nodes

This commit is contained in:
Eternity
2026-03-23 16:33:25 +08:00
parent 31085ed678
commit 2ff81ba101
22 changed files with 820 additions and 519 deletions

View File

@@ -5,6 +5,7 @@ This module provides the main write function for executing the knowledge extract
pipeline. Only MemoryConfig is needed - clients are constructed internally. pipeline. Only MemoryConfig is needed - clients are constructed internally.
""" """
import asyncio import asyncio
import uuid
import time import time
from datetime import datetime from datetime import datetime
@@ -13,28 +14,31 @@ from dotenv import load_dotenv
from app.core.logging_config import get_agent_logger from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import memory_summary_generation from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import \
memory_summary_generation
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.memory.utils.log.logging_utils import log_time from app.core.memory.utils.log.logging_utils import log_time
from app.db import get_db_context from app.db import get_db_context
from app.models import MemoryPerceptualModel
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes from app.repositories.neo4j.add_nodes import add_memory_summary_nodes, add_perceptual_nodes, \
add_perceptual_dialogue_edges
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, schedule_clustering_after_write from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, schedule_clustering_after_write
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import MemoryConfig from app.schemas.memory_config_schema import MemoryConfig
load_dotenv() load_dotenv()
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
async def write( async def write(
end_user_id: str, end_user_id: str,
memory_config: MemoryConfig, memory_config: MemoryConfig,
messages: list, messages: list,
ref_id: str = "wyl20251027", file_content: list[MemoryPerceptualModel],
language: str = "zh", ref_id: str = "",
language: str = "zh",
) -> None: ) -> None:
""" """
Execute the complete knowledge extraction pipeline. Execute the complete knowledge extraction pipeline.
@@ -43,9 +47,12 @@ async def write(
end_user_id: Group identifier end_user_id: Group identifier
memory_config: MemoryConfig object containing all configuration memory_config: MemoryConfig object containing all configuration
messages: Structured message list [{"role": "user", "content": "..."}, ...] messages: Structured message list [{"role": "user", "content": "..."}, ...]
ref_id: Reference ID, defaults to "wyl20251027" file_content: mutilmodal message list
ref_id: Reference ID, defaults to ""
language: 语言类型 ("zh" 中文, "en" 英文),默认中文 language: 语言类型 ("zh" 中文, "en" 英文),默认中文
""" """
if not ref_id:
ref_id = uuid.uuid4().hex
# Extract config values # Extract config values
embedding_model_id = str(memory_config.embedding_model_id) embedding_model_id = str(memory_config.embedding_model_id)
chunker_strategy = memory_config.chunker_strategy chunker_strategy = memory_config.chunker_strategy
@@ -173,7 +180,8 @@ async def write(
schedule_clustering_after_write( schedule_clustering_after_write(
all_entity_nodes, all_entity_nodes,
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None, llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None, embedding_model_id=str(
memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
) )
break break
else: else:
@@ -208,9 +216,8 @@ async def write(
summaries = await memory_summary_generation( summaries = await memory_summary_generation(
chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client, language=language chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client, language=language
) )
ms_connector = Neo4jConnector()
try: try:
ms_connector = Neo4jConnector()
await add_memory_summary_nodes(summaries, ms_connector) await add_memory_summary_nodes(summaries, ms_connector)
await add_memory_summary_statement_edges(summaries, ms_connector) await add_memory_summary_statement_edges(summaries, ms_connector)
finally: finally:
@@ -223,6 +230,34 @@ async def write(
finally: finally:
log_time("Memory Summary (Neo4j)", time.time() - step_start, log_file) log_time("Memory Summary (Neo4j)", time.time() - step_start, log_file)
# Step 5: Save perceptual memory to Neo4j
step_start = time.time()
if file_content:
try:
pc_connector = Neo4jConnector()
try:
created_ids = await add_perceptual_nodes(
perceptuals=file_content,
connector=pc_connector,
embedder_client=embedder_client,
)
# 如果有 ref_id建立感知记忆与对话的关联
if ref_id and created_ids:
await add_perceptual_dialogue_edges(
perceptuals=file_content,
dialog_id=ref_id,
connector=pc_connector,
)
logger.info(f"Successfully saved {len(created_ids or [])} perceptual memory nodes to Neo4j")
finally:
try:
await pc_connector.close()
except Exception:
pass
except Exception as e:
logger.error(f"Perceptual memory Neo4j save failed: {e}", exc_info=True)
log_time("Perceptual Memory (Neo4j)", time.time() - step_start, log_file)
# Log total pipeline time # Log total pipeline time
total_time = time.time() - pipeline_start total_time = time.time() - pipeline_start
log_time("TOTAL PIPELINE TIME", total_time, log_file) log_time("TOTAL PIPELINE TIME", total_time, log_file)

View File

@@ -553,3 +553,21 @@ class MemorySummaryNode(Node):
ge=0, ge=0,
description="Total number of times this node has been accessed (reset to 1 on creation)" description="Total number of times this node has been accessed (reset to 1 on creation)"
) )
class MutlimodalNode(Node):
"""Node representing a multimodal message in the knowledge graph.
Attributes:
dialog_id: ID of the parent dialog
message_id: ID of the message
metadata: Additional message metadata
embedding: Optional embedding vector for the message
"""
dialog_id: str = Field(..., description="ID of the parent dialog")
message_id: str = Field(..., description="ID of the message")
summary: str = Field(..., description="The text content of the message")
file_type: str = Field(..., description="Type of the message (e.g., 'text', 'image', 'audio', 'video')")
file_path: List[str] = Field(..., description="List of file paths for multimodal content")
metadata: dict = Field(default_factory=dict, description="Additional message metadata")
embedding: Optional[List[float]] = Field(None, description="Embedding vector for the message")

View File

@@ -25,17 +25,17 @@ from app.repositories.neo4j.neo4j_connector import Neo4jConnector
async def dedup_layers_and_merge_and_return( async def dedup_layers_and_merge_and_return(
dialogue_nodes: List[DialogueNode], dialogue_nodes: List[DialogueNode],
chunk_nodes: List[ChunkNode], chunk_nodes: List[ChunkNode],
statement_nodes: List[StatementNode], statement_nodes: List[StatementNode],
entity_nodes: List[ExtractedEntityNode], entity_nodes: List[ExtractedEntityNode],
statement_chunk_edges: List[StatementChunkEdge], statement_chunk_edges: List[StatementChunkEdge],
statement_entity_edges: List[StatementEntityEdge], statement_entity_edges: List[StatementEntityEdge],
entity_entity_edges: List[EntityEntityEdge], entity_entity_edges: List[EntityEntityEdge],
dialog_data_list: List[DialogData], dialog_data_list: List[DialogData],
pipeline_config: ExtractionPipelineConfig, pipeline_config: ExtractionPipelineConfig,
connector: Optional[Neo4jConnector] = None, connector: Optional[Neo4jConnector] = None,
llm_client = None, llm_client=None,
) -> Tuple[ ) -> Tuple[
List[DialogueNode], List[DialogueNode],
List[ChunkNode], List[ChunkNode],
@@ -44,7 +44,7 @@ async def dedup_layers_and_merge_and_return(
List[StatementChunkEdge], List[StatementChunkEdge],
List[StatementEntityEdge], List[StatementEntityEdge],
List[EntityEntityEdge], List[EntityEntityEdge],
dict, # 新增:返回去重详情 dict
]: ]:
""" """
执行两层实体去重与融合: 执行两层实体去重与融合:

View File

@@ -31,11 +31,10 @@ from app.core.memory.models.graph_models import (
ExtractedEntityNode, ExtractedEntityNode,
StatementChunkEdge, StatementChunkEdge,
StatementEntityEdge, StatementEntityEdge,
StatementNode, StatementNode
) )
from app.core.memory.models.message_models import DialogData from app.core.memory.models.message_models import DialogData
from app.core.memory.models.ontology_extraction_models import OntologyTypeList from app.core.memory.models.ontology_extraction_models import OntologyTypeList
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
from app.core.memory.models.variate_config import ( from app.core.memory.models.variate_config import (
ExtractionPipelineConfig, ExtractionPipelineConfig,
) )
@@ -46,7 +45,6 @@ from app.core.memory.storage_services.extraction_engine.knowledge_extraction.emb
embedding_generation, embedding_generation,
generate_entity_embeddings_from_triplets, generate_entity_embeddings_from_triplets,
) )
# 导入各个提取模块 # 导入各个提取模块
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.statement_extraction import ( from app.core.memory.storage_services.extraction_engine.knowledge_extraction.statement_extraction import (
StatementExtractor, StatementExtractor,
@@ -90,16 +88,16 @@ class ExtractionOrchestrator:
""" """
def __init__( def __init__(
self, self,
llm_client: LLMClient, llm_client: LLMClient,
embedder_client: OpenAIEmbedderClient, embedder_client: OpenAIEmbedderClient,
connector: Neo4jConnector, connector: Neo4jConnector,
config: Optional[ExtractionPipelineConfig] = None, config: Optional[ExtractionPipelineConfig] = None,
progress_callback: Optional[Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]] = None, progress_callback: Optional[Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]] = None,
embedding_id: Optional[str] = None, embedding_id: Optional[str] = None,
ontology_types: Optional[OntologyTypeList] = None, ontology_types: Optional[OntologyTypeList] = None,
enable_general_types: bool = True, enable_general_types: bool = True,
language: str = "zh", language: str = "zh",
): ):
""" """
初始化流水线编排器 初始化流水线编排器
@@ -157,19 +155,25 @@ class ExtractionOrchestrator:
llm_client=llm_client, llm_client=llm_client,
config=self.config.statement_extraction, config=self.config.statement_extraction,
) )
self.triplet_extractor = TripletExtractor(llm_client=llm_client,ontology_types=self.ontology_types, language=language) self.triplet_extractor = TripletExtractor(llm_client=llm_client, ontology_types=self.ontology_types,
language=language)
self.temporal_extractor = TemporalExtractor(llm_client=llm_client) self.temporal_extractor = TemporalExtractor(llm_client=llm_client)
logger.info("ExtractionOrchestrator 初始化完成") logger.info("ExtractionOrchestrator 初始化完成")
async def run( async def run(
self, self,
dialog_data_list: List[DialogData], dialog_data_list: List[DialogData],
is_pilot_run: bool = False, is_pilot_run: bool = False,
) -> Tuple[ ) -> tuple[
Tuple[List[DialogueNode], List[ChunkNode], List[StatementNode]], list[DialogueNode],
Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]], list[ChunkNode],
Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]], list[StatementNode],
list[ExtractedEntityNode],
list[StatementChunkEdge],
list[StatementEntityEdge],
list[EntityEntityEdge],
dict
]: ]:
""" """
运行完整的知识提取流水线(优化版:并行执行) 运行完整的知识提取流水线(优化版:并行执行)
@@ -208,7 +212,6 @@ class ExtractionOrchestrator:
for dialog in dialog_data_list: for dialog in dialog_data_list:
for chunk in dialog.chunks: for chunk in dialog.chunks:
all_statements_list.extend(chunk.statements) all_statements_list.extend(chunk.statements)
len(all_statements_list)
# 步骤 2: 并行执行三元组提取、时间信息提取、情绪提取和基础嵌入生成 # 步骤 2: 并行执行三元组提取、时间信息提取、情绪提取和基础嵌入生成
logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取、情绪提取和嵌入生成") logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取、情绪提取和嵌入生成")
@@ -230,10 +233,6 @@ class ExtractionOrchestrator:
all_entities_list.extend(triplet_info.entities) all_entities_list.extend(triplet_info.entities)
all_triplets_list.extend(triplet_info.triplets) all_triplets_list.extend(triplet_info.triplets)
len(all_entities_list)
len(all_triplets_list)
sum(len(temporal_map) for temporal_map in temporal_maps)
# 步骤 3: 生成实体嵌入(依赖三元组提取结果) # 步骤 3: 生成实体嵌入(依赖三元组提取结果)
logger.info("步骤 3/6: 生成实体嵌入") logger.info("步骤 3/6: 生成实体嵌入")
triplet_maps = await self._generate_entity_embeddings(triplet_maps) triplet_maps = await self._generate_entity_embeddings(triplet_maps)
@@ -287,8 +286,6 @@ class ExtractionOrchestrator:
dialog_data_list, dialog_data_list,
) )
logger.info(f"知识提取流水线运行完成({mode_str}") logger.info(f"知识提取流水线运行完成({mode_str}")
return result return result
@@ -297,7 +294,7 @@ class ExtractionOrchestrator:
raise raise
async def _extract_statements( async def _extract_statements(
self, dialog_data_list: List[DialogData] self, dialog_data_list: List[DialogData]
) -> List[DialogData]: ) -> List[DialogData]:
""" """
从对话中提取陈述句(流式输出版本:边提取边发送进度) 从对话中提取陈述句(流式输出版本:边提取边发送进度)
@@ -395,7 +392,7 @@ class ExtractionOrchestrator:
return dialog_data_list return dialog_data_list
async def _extract_triplets( async def _extract_triplets(
self, dialog_data_list: List[DialogData] self, dialog_data_list: List[DialogData]
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
从对话中提取三元组(流式输出版本:边提取边发送进度) 从对话中提取三元组(流式输出版本:边提取边发送进度)
@@ -478,7 +475,7 @@ class ExtractionOrchestrator:
return triplet_maps return triplet_maps
async def _extract_temporal( async def _extract_temporal(
self, dialog_data_list: List[DialogData] self, dialog_data_list: List[DialogData]
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
从对话中提取时间信息(流式输出版本:边提取边发送进度) 从对话中提取时间信息(流式输出版本:边提取边发送进度)
@@ -585,7 +582,7 @@ class ExtractionOrchestrator:
return temporal_maps return temporal_maps
async def _extract_emotions( async def _extract_emotions(
self, dialog_data_list: List[DialogData] self, dialog_data_list: List[DialogData]
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
从对话中提取情绪信息(仅针对用户消息,全局陈述句级并行) 从对话中提取情绪信息(仅针对用户消息,全局陈述句级并行)
@@ -706,7 +703,7 @@ class ExtractionOrchestrator:
return emotion_maps return emotion_maps
async def _parallel_extract_and_embed( async def _parallel_extract_and_embed(
self, dialog_data_list: List[DialogData] self, dialog_data_list: List[DialogData]
) -> Tuple[ ) -> Tuple[
List[Dict[str, Any]], List[Dict[str, Any]],
List[Dict[str, Any]], List[Dict[str, Any]],
@@ -777,7 +774,7 @@ class ExtractionOrchestrator:
) )
async def _generate_basic_embeddings( async def _generate_basic_embeddings(
self, dialog_data_list: List[DialogData] self, dialog_data_list: List[DialogData]
) -> Tuple[List[Dict[str, List[float]]], List[Dict[str, List[float]]], List[List[float]]]: ) -> Tuple[List[Dict[str, List[float]]], List[Dict[str, List[float]]], List[List[float]]]:
""" """
生成基础嵌入向量(陈述句、分块、对话) 生成基础嵌入向量(陈述句、分块、对话)
@@ -836,7 +833,7 @@ class ExtractionOrchestrator:
) )
async def _generate_entity_embeddings( async def _generate_entity_embeddings(
self, triplet_maps: List[Dict[str, Any]] self, triplet_maps: List[Dict[str, Any]]
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
生成实体嵌入向量 生成实体嵌入向量
@@ -874,17 +871,15 @@ class ExtractionOrchestrator:
logger.error(f"实体嵌入生成失败: {e}", exc_info=True) logger.error(f"实体嵌入生成失败: {e}", exc_info=True)
return triplet_maps return triplet_maps
async def _assign_extracted_data( async def _assign_extracted_data(
self, self,
dialog_data_list: List[DialogData], dialog_data_list: List[DialogData],
temporal_maps: List[Dict[str, Any]], temporal_maps: List[Dict[str, Any]],
triplet_maps: List[Dict[str, Any]], triplet_maps: List[Dict[str, Any]],
emotion_maps: List[Dict[str, Any]], emotion_maps: List[Dict[str, Any]],
statement_embedding_maps: List[Dict[str, List[float]]], statement_embedding_maps: List[Dict[str, List[float]]],
chunk_embedding_maps: List[Dict[str, List[float]]], chunk_embedding_maps: List[Dict[str, List[float]]],
dialog_embeddings: List[List[float]], dialog_embeddings: List[List[float]],
) -> List[DialogData]: ) -> List[DialogData]:
""" """
将提取的数据赋值到语句 将提取的数据赋值到语句
@@ -906,12 +901,12 @@ class ExtractionOrchestrator:
# 确保列表长度匹配 # 确保列表长度匹配
expected_length = len(dialog_data_list) expected_length = len(dialog_data_list)
if ( if (
len(temporal_maps) != expected_length len(temporal_maps) != expected_length
or len(triplet_maps) != expected_length or len(triplet_maps) != expected_length
or len(emotion_maps) != expected_length or len(emotion_maps) != expected_length
or len(statement_embedding_maps) != expected_length or len(statement_embedding_maps) != expected_length
or len(chunk_embedding_maps) != expected_length or len(chunk_embedding_maps) != expected_length
or len(dialog_embeddings) != expected_length or len(dialog_embeddings) != expected_length
): ):
logger.warning( logger.warning(
f"数据大小不匹配 - 对话: {len(dialog_data_list)}, " f"数据大小不匹配 - 对话: {len(dialog_data_list)}, "
@@ -999,7 +994,7 @@ class ExtractionOrchestrator:
return dialog_data_list return dialog_data_list
async def _create_nodes_and_edges( async def _create_nodes_and_edges(
self, dialog_data_list: List[DialogData] self, dialog_data_list: List[DialogData]
) -> Tuple[ ) -> Tuple[
List[DialogueNode], List[DialogueNode],
List[ChunkNode], List[ChunkNode],
@@ -1007,7 +1002,7 @@ class ExtractionOrchestrator:
List[ExtractedEntityNode], List[ExtractedEntityNode],
List[StatementChunkEdge], List[StatementChunkEdge],
List[StatementEntityEdge], List[StatementEntityEdge],
List[EntityEntityEdge], List[EntityEntityEdge]
]: ]:
""" """
创建图数据库节点和边 创建图数据库节点和边
@@ -1083,15 +1078,19 @@ class ExtractionOrchestrator:
name=f"Statement_{statement.id}", # 添加必需的 name 字段 name=f"Statement_{statement.id}", # 添加必需的 name 字段
chunk_id=chunk.id, chunk_id=chunk.id,
stmt_type=getattr(statement, 'stmt_type', 'general'), # 添加必需的 stmt_type 字段 stmt_type=getattr(statement, 'stmt_type', 'general'), # 添加必需的 stmt_type 字段
temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL), # 添加必需的 temporal_info 字段 temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL),
connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段 # 添加必需的 temporal_info 字段
connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong',
# 添加必需的 connect_strength 字段
end_user_id=dialog_data.end_user_id, end_user_id=dialog_data.end_user_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
statement=statement.statement, statement=statement.statement,
speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段 speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段
statement_embedding=statement.statement_embedding, statement_embedding=statement.statement_embedding,
valid_at=statement.temporal_validity.valid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None, valid_at=statement.temporal_validity.valid_at if hasattr(statement,
invalid_at=statement.temporal_validity.invalid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None, 'temporal_validity') and statement.temporal_validity else None,
invalid_at=statement.temporal_validity.invalid_at if hasattr(statement,
'temporal_validity') and statement.temporal_validity else None,
created_at=dialog_data.created_at, created_at=dialog_data.created_at,
expired_at=dialog_data.expired_at, expired_at=dialog_data.expired_at,
config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None, config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None,
@@ -1141,7 +1140,8 @@ class ExtractionOrchestrator:
example=getattr(entity, 'example', ''), # 新增:传递示例字段 example=getattr(entity, 'example', ''), # 新增:传递示例字段
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段 # fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段 connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong',
# 添加必需的 connect_strength 字段
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
name_embedding=getattr(entity, 'name_embedding', None), name_embedding=getattr(entity, 'name_embedding', None),
is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # 新增:传递语义记忆标记 is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # 新增:传递语义记忆标记
@@ -1254,19 +1254,24 @@ class ExtractionOrchestrator:
) )
async def _run_dedup_and_write_summary( async def _run_dedup_and_write_summary(
self, self,
dialogue_nodes: List[DialogueNode], dialogue_nodes: List[DialogueNode],
chunk_nodes: List[ChunkNode], chunk_nodes: List[ChunkNode],
statement_nodes: List[StatementNode], statement_nodes: List[StatementNode],
entity_nodes: List[ExtractedEntityNode], entity_nodes: List[ExtractedEntityNode],
statement_chunk_edges: List[StatementChunkEdge], statement_chunk_edges: List[StatementChunkEdge],
statement_entity_edges: List[StatementEntityEdge], statement_entity_edges: List[StatementEntityEdge],
entity_entity_edges: List[EntityEntityEdge], entity_entity_edges: List[EntityEntityEdge],
dialog_data_list: List[DialogData], dialog_data_list: List[DialogData],
) -> Tuple[ ) -> tuple[
Tuple[List[DialogueNode], List[ChunkNode], List[StatementNode]], list[DialogueNode],
Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]], list[ChunkNode],
Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]], list[StatementNode],
list[ExtractedEntityNode],
list[StatementChunkEdge],
list[StatementEntityEdge],
list[EntityEntityEdge],
dict
]: ]:
""" """
执行两阶段去重并写入汇总 执行两阶段去重并写入汇总
@@ -1415,7 +1420,6 @@ class ExtractionOrchestrator:
len(entity_entity_edges), len(final_entity_entity_edges) len(entity_entity_edges), len(final_entity_entity_edges)
) )
# 写入提取结果汇总(试运行和正式模式都需要生成) # 写入提取结果汇总(试运行和正式模式都需要生成)
try: try:
from app.core.config import settings from app.core.config import settings
@@ -1436,10 +1440,10 @@ class ExtractionOrchestrator:
raise raise
def _save_dedup_details( def _save_dedup_details(
self, self,
dedup_details: Dict[str, Any], dedup_details: Dict[str, Any],
original_entities: List[ExtractedEntityNode], original_entities: List[ExtractedEntityNode],
final_entities: List[ExtractedEntityNode] final_entities: List[ExtractedEntityNode]
): ):
""" """
保存去重消歧的详细记录到实例变量(基于内存数据结构) 保存去重消歧的详细记录到实例变量(基于内存数据结构)
@@ -1537,15 +1541,16 @@ class ExtractionOrchestrator:
except Exception as e: except Exception as e:
logger.debug(f"解析消歧记录失败: {record}, 错误: {e}") logger.debug(f"解析消歧记录失败: {record}, 错误: {e}")
logger.info(f"保存去重消歧记录:{len(self.dedup_merge_records)} 个合并记录,{len(self.dedup_disamb_records)} 个消歧记录") logger.info(
f"保存去重消歧记录:{len(self.dedup_merge_records)} 个合并记录,{len(self.dedup_disamb_records)} 个消歧记录")
except Exception as e: except Exception as e:
logger.error(f"保存去重消歧详情失败: {e}", exc_info=True) logger.error(f"保存去重消歧详情失败: {e}", exc_info=True)
async def _analyze_entity_merges( async def _analyze_entity_merges(
self, self,
original_entities: List[ExtractedEntityNode], original_entities: List[ExtractedEntityNode],
final_entities: List[ExtractedEntityNode] final_entities: List[ExtractedEntityNode]
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
分析实体合并情况,直接使用内存中的合并记录(不再解析日志文件) 分析实体合并情况,直接使用内存中的合并记录(不再解析日志文件)
@@ -1585,9 +1590,9 @@ class ExtractionOrchestrator:
return [] return []
async def _analyze_entity_disambiguation( async def _analyze_entity_disambiguation(
self, self,
original_entities: List[ExtractedEntityNode], original_entities: List[ExtractedEntityNode],
final_entities: List[ExtractedEntityNode] final_entities: List[ExtractedEntityNode]
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
分析实体消歧情况,直接使用内存中的消歧记录(不再解析日志文件) 分析实体消歧情况,直接使用内存中的消歧记录(不再解析日志文件)
@@ -1645,9 +1650,9 @@ class ExtractionOrchestrator:
return type_mapping.get(entity_type, f"{entity_type}实体节点") return type_mapping.get(entity_type, f"{entity_type}实体节点")
async def _output_relationship_creation_results( async def _output_relationship_creation_results(
self, self,
entity_entity_edges: List[EntityEntityEdge], entity_entity_edges: List[EntityEntityEdge],
entity_nodes: List[ExtractedEntityNode] entity_nodes: List[ExtractedEntityNode]
): ):
""" """
输出关系创建结果 输出关系创建结果
@@ -1681,13 +1686,13 @@ class ExtractionOrchestrator:
logger.error(f"输出关系创建结果失败: {e}", exc_info=True) logger.error(f"输出关系创建结果失败: {e}", exc_info=True)
async def _send_dedup_progress_callback( async def _send_dedup_progress_callback(
self, self,
original_entities: int, original_entities: int,
final_entities: int, final_entities: int,
original_stmt_edges: int, original_stmt_edges: int,
final_stmt_edges: int, final_stmt_edges: int,
original_ent_edges: int, original_ent_edges: int,
final_ent_edges: int, final_ent_edges: int,
): ):
""" """
发送去重消歧完成的进度回调,传递具体的去重和消歧效果 发送去重消歧完成的进度回调,传递具体的去重和消歧效果
@@ -1715,7 +1720,8 @@ class ExtractionOrchestrator:
"original_count": original_entities, "original_count": original_entities,
"final_count": final_entities, "final_count": final_entities,
"reduced_count": entities_reduced, "reduced_count": entities_reduced,
"reduction_rate": round(entities_reduced / original_entities * 100, 1) if original_entities > 0 else 0, "reduction_rate": round(entities_reduced / original_entities * 100,
1) if original_entities > 0 else 0,
}, },
"statement_entity_edges": { "statement_entity_edges": {
"original_count": original_stmt_edges, "original_count": original_stmt_edges,
@@ -1790,7 +1796,8 @@ class ExtractionOrchestrator:
disamb_examples.append({ disamb_examples.append({
"entity1_name": entity_name, "entity1_name": entity_name,
"entity1_type": disamb_type.split("vs")[0].replace("消歧阻断:", "").strip() if "vs" in disamb_type else "未知", "entity1_type": disamb_type.split("vs")[0].replace("消歧阻断:",
"").strip() if "vs" in disamb_type else "未知",
"entity2_name": entity_name, "entity2_name": entity_name,
"entity2_type": disamb_type.split("vs")[1].strip() if "vs" in disamb_type else "未知", "entity2_type": disamb_type.split("vs")[1].strip() if "vs" in disamb_type else "未知",
"description": f"{entity_name},消歧区分成功" "description": f"{entity_name},消歧区分成功"
@@ -1815,9 +1822,9 @@ class ExtractionOrchestrator:
async def get_chunked_dialogs( async def get_chunked_dialogs(
chunker_strategy: str = "RecursiveChunker", chunker_strategy: str = "RecursiveChunker",
end_user_id: str = "group_1", end_user_id: str = "group_1",
indices: Optional[List[int]] = None, indices: Optional[List[int]] = None,
) -> List[DialogData]: ) -> List[DialogData]:
"""从测试数据生成分块对话 """从测试数据生成分块对话
@@ -1924,10 +1931,10 @@ async def get_chunked_dialogs(
def preprocess_data( def preprocess_data(
input_path: Optional[str] = None, input_path: Optional[str] = None,
output_path: Optional[str] = None, output_path: Optional[str] = None,
skip_cleaning: bool = True, skip_cleaning: bool = True,
indices: Optional[List[int]] = None indices: Optional[List[int]] = None
) -> List[DialogData]: ) -> List[DialogData]:
"""数据预处理 """数据预处理
@@ -1946,7 +1953,8 @@ def preprocess_data(
) )
preprocessor = DataPreprocessor() preprocessor = DataPreprocessor()
try: try:
cleaned_data = preprocessor.preprocess(input_path=input_path, output_path=output_path, skip_cleaning=skip_cleaning, indices=indices) cleaned_data = preprocessor.preprocess(input_path=input_path, output_path=output_path,
skip_cleaning=skip_cleaning, indices=indices)
logger.debug(f"数据预处理完成!共处理了 {len(cleaned_data)} 条对话数据") logger.debug(f"数据预处理完成!共处理了 {len(cleaned_data)} 条对话数据")
return cleaned_data return cleaned_data
except Exception as e: except Exception as e:
@@ -1955,9 +1963,9 @@ def preprocess_data(
async def get_chunked_dialogs_from_preprocessed( async def get_chunked_dialogs_from_preprocessed(
data: List[DialogData], data: List[DialogData],
chunker_strategy: str = "RecursiveChunker", chunker_strategy: str = "RecursiveChunker",
llm_client: Optional[Any] = None, llm_client: Optional[Any] = None,
) -> List[DialogData]: ) -> List[DialogData]:
"""从预处理后的数据中生成分块 """从预处理后的数据中生成分块
@@ -1988,15 +1996,15 @@ async def get_chunked_dialogs_from_preprocessed(
async def get_chunked_dialogs_with_preprocessing( async def get_chunked_dialogs_with_preprocessing(
chunker_strategy: str = "RecursiveChunker", chunker_strategy: str = "RecursiveChunker",
end_user_id: str = "default", end_user_id: str = "default",
user_id: str = "default", user_id: str = "default",
apply_id: str = "default", apply_id: str = "default",
indices: Optional[List[int]] = None, indices: Optional[List[int]] = None,
input_data_path: Optional[str] = None, input_data_path: Optional[str] = None,
llm_client: Optional[Any] = None, llm_client: Optional[Any] = None,
skip_cleaning: bool = True, skip_cleaning: bool = True,
pruning_config: Optional[Dict] = None, pruning_config: Optional[Dict] = None,
) -> List[DialogData]: ) -> List[DialogData]:
"""包含数据预处理步骤的完整分块流程 """包含数据预处理步骤的完整分块流程
@@ -2046,7 +2054,8 @@ async def get_chunked_dialogs_with_preprocessing(
if pruning_config: if pruning_config:
# 使用传入的配置 # 使用传入的配置
config = PruningConfig(**pruning_config) config = PruningConfig(**pruning_config)
logger.debug(f"[剪枝] 使用传入配置: switch={config.pruning_switch}, scene={config.pruning_scene}, threshold={config.pruning_threshold}") logger.debug(
f"[剪枝] 使用传入配置: switch={config.pruning_switch}, scene={config.pruning_scene}, threshold={config.pruning_threshold}")
else: else:
# 使用默认配置(关闭剪枝) # 使用默认配置(关闭剪枝)
config = None config = None

View File

@@ -188,7 +188,6 @@ async def _process_chunk_summary(
response_model=MemorySummaryResponse, response_model=MemorySummaryResponse,
) )
summary_text = structured.summary.strip() summary_text = structured.summary.strip()
# Generate title and type for the summary # Generate title and type for the summary
title = None title = None
episodic_type = None episodic_type = None

View File

@@ -374,7 +374,9 @@ class VariablePool:
self.variables = deepcopy(pool.variables) self.variables = deepcopy(pool.variables)
def is_file_variable(self, selector): def is_file_variable(self, selector):
variable_struct = self._get_variable_struct(selector) variable_struct = self.get_instance(selector, default=None, strict=False)
if variable_struct is None:
return False
if isinstance(variable_struct, FileVariable): if isinstance(variable_struct, FileVariable):
return True return True
elif isinstance(variable_struct, ArrayVariable) and variable_struct.child_type == FileVariable: elif isinstance(variable_struct, ArrayVariable) and variable_struct.child_type == FileVariable:

View File

@@ -623,7 +623,6 @@ class BaseNode(ABC):
async def process_message( async def process_message(
api_config: ModelInfo, api_config: ModelInfo,
content: str | dict | FileObject, content: str | dict | FileObject,
end_user_id: str,
enable_file=False enable_file=False
) -> list | str | None: ) -> list | str | None:
provider = api_config.provider provider = api_config.provider
@@ -642,8 +641,8 @@ class BaseNode(ABC):
return content return content
elif isinstance(content, FileObject): elif isinstance(content, FileObject):
if content.content_cache.get(provider): if content.content_cache.get(f"{provider}_{ModelInfo.is_omni}"):
return content.content_cache[provider] return content.content_cache[f"{provider}_{ModelInfo.is_omni}"]
with get_db_read() as db: with get_db_read() as db:
multimodel_service = MultimodalService(db, api_config=api_config) multimodel_service = MultimodalService(db, api_config=api_config)
file_obj = FileInput( file_obj = FileInput(
@@ -655,12 +654,11 @@ class BaseNode(ABC):
) )
file_obj.set_content(content.get_content()) file_obj.set_content(content.get_content())
message = await multimodel_service.process_files( message = await multimodel_service.process_files(
end_user_id,
[file_obj], [file_obj],
) )
content.set_content(file_obj.get_content()) content.set_content(file_obj.get_content())
if message: if message:
content.content_cache[provider] = message content.content_cache[f"{provider}_{ModelInfo.is_omni}"] = message
return message return message
return None return None
raise TypeError(f'Unexpect input value type - {type(content)}') raise TypeError(f'Unexpect input value type - {type(content)}')

View File

@@ -144,7 +144,6 @@ class LLMNode(BaseNode):
f"创建 LLM 实例: provider={model_info.provider}, model={model_info.model_name}, streaming={stream}") f"创建 LLM 实例: provider={model_info.provider}, model={model_info.model_name}, streaming={stream}")
messages_config = self.typed_config.messages messages_config = self.typed_config.messages
if messages_config: if messages_config:
# 使用 LangChain 消息格式 # 使用 LangChain 消息格式
messages = [] messages = []
@@ -153,7 +152,6 @@ class LLMNode(BaseNode):
content_template = msg_config.content content_template = msg_config.content
content_template = self._render_context(content_template, variable_pool) content_template = self._render_context(content_template, variable_pool)
content = self._render_template(content_template, variable_pool) content = self._render_template(content_template, variable_pool)
user_id = self.get_variable("sys.user_id", variable_pool)
# 根据角色创建对应的消息对象 # 根据角色创建对应的消息对象
if role == "system": if role == "system":
messages.append({ messages.append({
@@ -161,32 +159,31 @@ class LLMNode(BaseNode):
"content": await self.process_message( "content": await self.process_message(
model_info, model_info,
content, content,
user_id,
self.typed_config.vision, self.typed_config.vision,
) )
}) })
elif role in ["user", "human"]: elif role in ["user", "human"]:
messages.append({ messages.append({
"role": "user", "role": "user",
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision) "content": await self.process_message(model_info, content, self.typed_config.vision)
}) })
elif role in ["ai", "assistant"]: elif role in ["ai", "assistant"]:
messages.append({ messages.append({
"role": "assistant", "role": "assistant",
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision) "content": await self.process_message(model_info, content, self.typed_config.vision)
}) })
else: else:
logger.warning(f"未知的消息角色: {role},默认使用 user") logger.warning(f"未知的消息角色: {role},默认使用 user")
messages.append({ messages.append({
"role": "user", "role": "user",
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision) "content": await self.process_message(model_info, content, self.typed_config.vision)
}) })
if self.typed_config.vision_input and self.typed_config.vision: if self.typed_config.vision_input and self.typed_config.vision:
file_content = [] file_content = []
files = variable_pool.get_instance(self.typed_config.vision_input) files = variable_pool.get_instance(self.typed_config.vision_input)
for file in files.value: for file in files.value:
content = await self.process_message(model_info, file.value, user_id, self.typed_config.vision) content = await self.process_message(model_info, file.value, self.typed_config.vision)
if content: if content:
file_content.extend(content) file_content.extend(content)
if messages and messages[-1]["role"] == 'user': if messages and messages[-1]["role"] == 'user':
@@ -200,7 +197,7 @@ class LLMNode(BaseNode):
if isinstance(message["content"], list): if isinstance(message["content"], list):
file_content = [] file_content = []
for file in message["content"]: for file in message["content"]:
content = await self.process_message(model_info, file, user_id, self.typed_config.vision) content = await self.process_message(model_info, file, self.typed_config.vision)
if content: if content:
file_content.extend(content) file_content.extend(content)
history_message.append( history_message.append(
@@ -210,7 +207,6 @@ class LLMNode(BaseNode):
message["content"] = await self.process_message( message["content"] = await self.process_message(
model_info, model_info,
message["content"], message["content"],
user_id,
self.typed_config.vision self.typed_config.vision
) )
history_message.append(message) history_message.append(message)

View File

@@ -116,6 +116,7 @@ class MemoryWriteNode(BaseNode):
write_message_task.delay( write_message_task.delay(
end_user_id=end_user_id, end_user_id=end_user_id,
message=messages, message=messages,
file_messages=multimodal_memories,
config_id=str(self.typed_config.config_id), config_id=str(self.typed_config.config_id),
storage_type=state["memory_storage_type"], storage_type=state["memory_storage_type"],
user_rag_memory_id=state["user_rag_memory_id"] user_rag_memory_id=state["user_rag_memory_id"]

View File

@@ -30,6 +30,9 @@ class MemoryConfig(Base):
llm_id = Column(String, nullable=True, comment="LLM模型配置ID") llm_id = Column(String, nullable=True, comment="LLM模型配置ID")
embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID") embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID")
rerank_id = Column(String, nullable=True, comment="重排序模型配置ID") rerank_id = Column(String, nullable=True, comment="重排序模型配置ID")
vision_id = Column(String, nullable=True, comment="视觉模型配置ID")
audio_id = Column(String, nullable=True, comment="语音模型配置ID")
video_id = Column(String, nullable=True, comment="视频模型配置ID")
# 记忆萃取引擎配置 # 记忆萃取引擎配置
enable_llm_dedup_blockwise = Column(Boolean, default=True, comment="启用LLM决策去重") enable_llm_dedup_blockwise = Column(Boolean, default=True, comment="启用LLM决策去重")

View File

@@ -9,21 +9,22 @@ Classes:
""" """
import uuid import uuid
from uuid import UUID
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
from uuid import UUID
from sqlalchemy import desc, select
from sqlalchemy.orm import Session
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
from app.core.logging_config import get_config_logger, get_db_logger from app.core.logging_config import get_config_logger, get_db_logger
from app.models.memory_config_model import MemoryConfig from app.models.memory_config_model import MemoryConfig
from app.models.workspace_model import Workspace
from app.schemas.memory_storage_schema import ( from app.schemas.memory_storage_schema import (
ConfigKey,
ConfigParamsCreate, ConfigParamsCreate,
ConfigUpdate, ConfigUpdate,
ConfigUpdateExtracted, ConfigUpdateExtracted,
ConfigUpdateForget, ConfigUpdateForget,
) )
from sqlalchemy import desc, select
from sqlalchemy.orm import Session
from app.utils.config_utils import resolve_config_id from app.utils.config_utils import resolve_config_id
# 获取数据库专用日志器 # 获取数据库专用日志器
@@ -157,7 +158,7 @@ class MemoryConfigRepository:
return memory_config_obj return memory_config_obj
@staticmethod @staticmethod
def query_reflection_config_by_id(db: Session, config_id: uuid.UUID|int|str) -> MemoryConfig: def query_reflection_config_by_id(db: Session, config_id: uuid.UUID | int | str) -> MemoryConfig:
"""构建反思配置查询语句通过config_id查询反思配置SQLAlchemy text() 命名参数) """构建反思配置查询语句通过config_id查询反思配置SQLAlchemy text() 命名参数)
Args: Args:
@@ -491,7 +492,10 @@ class MemoryConfigRepository:
raise raise
@staticmethod @staticmethod
def get_config_with_workspace(db: Session, config_id: uuid.UUID | int | str) -> Optional[tuple]: def get_config_with_workspace(
db: Session,
config_id: uuid.UUID | int | str
) -> Optional[tuple[MemoryConfig, Workspace]]:
"""Get memory config and its associated workspace information """Get memory config and its associated workspace information
Args: Args:
@@ -506,8 +510,6 @@ class MemoryConfigRepository:
""" """
import time import time
from app.models.workspace_model import Workspace
start_time = time.time() start_time = time.time()
config_id = resolve_config_id(config_id, db) config_id = resolve_config_id(config_id, db)
@@ -594,7 +596,7 @@ class MemoryConfigRepository:
db_logger.debug( db_logger.debug(
f"Memory config and workspace query successful: config={config.config_name}, workspace={workspace.name}") f"Memory config and workspace query successful: config={config.config_name}, workspace={workspace.name}")
return (config, workspace) return config, workspace
except ValueError: except ValueError:
# Re-raise known business exceptions # Re-raise known business exceptions
@@ -739,9 +741,9 @@ class MemoryConfigRepository:
@staticmethod @staticmethod
def get_with_fallback( def get_with_fallback(
db: Session, db: Session,
config_id: Optional[uuid.UUID], config_id: Optional[uuid.UUID],
workspace_id: uuid.UUID workspace_id: uuid.UUID
) -> Optional[MemoryConfig]: ) -> Optional[MemoryConfig]:
"""获取记忆配置,支持回退到工作空间默认配置 """获取记忆配置,支持回退到工作空间默认配置
@@ -771,4 +773,3 @@ class MemoryConfigRepository:
) )
return MemoryConfigRepository.get_workspace_default(db, workspace_id) return MemoryConfigRepository.get_workspace_default(db, workspace_id)

View File

@@ -1,7 +1,8 @@
from typing import List, Optional from typing import List, Optional
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE,MEMORY_SUMMARY_NODE_SAVE
from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE, \
MEMORY_SUMMARY_NODE_SAVE, PERCEPTUAL_NODE_SAVE, PERCEPTUAL_DIALOGUE_EDGE_SAVE
# 使用新的仓储层 # 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
@@ -12,6 +13,7 @@ async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector):
print(f"All end_user_id: {end_user_id} node and edge deleted successfully") print(f"All end_user_id: {end_user_id} node and edge deleted successfully")
return result return result
async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]: async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]:
"""Add dialogue nodes to Neo4j database. """Add dialogue nodes to Neo4j database.
@@ -127,6 +129,7 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
print(f"Error creating statement nodes: {e}") print(f"Error creating statement nodes: {e}")
return None return None
async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> Optional[List[str]]: async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> Optional[List[str]]:
"""Add chunk nodes to Neo4j in batch. """Add chunk nodes to Neo4j in batch.
@@ -179,8 +182,8 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) ->
return None return None
async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector: Neo4jConnector) -> Optional[
async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector: Neo4jConnector) -> Optional[List[str]]: List[str]]:
"""Add memory summary nodes to Neo4j in batch. """Add memory summary nodes to Neo4j in batch.
Args: Args:
@@ -224,3 +227,103 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
return None return None
async def add_perceptual_nodes(
perceptuals: list,
connector: Neo4jConnector,
embedder_client=None,
) -> Optional[List[str]]:
"""Add perceptual memory nodes to Neo4j in batch.
Args:
perceptuals: List of MemoryPerceptualModel objects from PostgreSQL
connector: Neo4j connector instance
embedder_client: Optional embedder client for generating summary embeddings
Returns:
List of created node UUIDs or None if failed
"""
if not perceptuals:
print("No perceptual nodes to add")
return []
try:
flattened = []
for p in perceptuals:
meta = p.meta_data or {}
content_meta = meta.get("content", {})
# 生成 summary embedding如果有 embedder_client
summary_embedding = None
if embedder_client and p.summary:
try:
summary_embedding = (await embedder_client.response([p.summary]))[0]
except Exception as emb_err:
print(f"Failed to embed perceptual summary: {emb_err}")
flattened.append({
"id": str(p.id),
"end_user_id": str(p.end_user_id),
"perceptual_type": p.perceptual_type,
"file_path": p.file_path or "",
"file_name": p.file_name or "",
"file_ext": p.file_ext or "",
"summary": p.summary or "",
"keywords": content_meta.get("keywords", []),
"topic": content_meta.get("topic", ""),
"domain": content_meta.get("domain", ""),
"created_at": p.created_time.isoformat() if p.created_time else None,
"summary_embedding": summary_embedding,
})
result = await connector.execute_query(
PERCEPTUAL_NODE_SAVE,
perceptuals=flattened,
)
created_uuids = [record.get("uuid") for record in result]
print(f"Successfully saved {len(created_uuids)} Perceptual nodes to Neo4j")
return created_uuids
except Exception as e:
print(f"Failed to save Perceptual nodes to Neo4j: {e}")
return None
async def add_perceptual_dialogue_edges(
perceptuals: list,
dialog_id: str,
connector: Neo4jConnector,
) -> Optional[List[str]]:
"""Add edges between Perceptual nodes and Dialogue nodes.
Args:
perceptuals: List of MemoryPerceptualModel objects
dialog_id: The dialogue ID (or ref_id) to link to
connector: Neo4j connector instance
Returns:
List of created edge element IDs or None if failed
"""
if not perceptuals or not dialog_id:
return []
try:
edges = []
for p in perceptuals:
edges.append({
"perceptual_id": str(p.id),
"dialog_id": dialog_id,
"end_user_id": str(p.end_user_id),
"created_at": p.created_time.isoformat() if p.created_time else None,
})
result = await connector.execute_query(
PERCEPTUAL_DIALOGUE_EDGE_SAVE,
edges=edges,
)
created_ids = [record.get("uuid") for record in result]
print(f"Successfully saved {len(created_ids)} Perceptual-Dialogue edges to Neo4j")
return created_ids
except Exception as e:
print(f"Failed to save Perceptual-Dialogue edges: {e}")
return None

View File

@@ -1323,3 +1323,36 @@ RETURN s.statement AS statement,
ORDER BY COALESCE(s.activation_value, 0) DESC ORDER BY COALESCE(s.activation_value, 0) DESC
LIMIT $limit LIMIT $limit
""" """
# 感知记忆节点保存
PERCEPTUAL_NODE_SAVE = """
UNWIND $perceptuals AS p
MERGE (n:Perceptual {id: p.id})
SET n += {
id: p.id,
end_user_id: p.end_user_id,
perceptual_type: p.perceptual_type,
file_path: p.file_path,
file_name: p.file_name,
file_ext: p.file_ext,
summary: p.summary,
keywords: p.keywords,
topic: p.topic,
domain: p.domain,
created_at: p.created_at,
summary_embedding: p.summary_embedding
}
RETURN n.id AS uuid
"""
# 感知记忆与对话的关联边
PERCEPTUAL_DIALOGUE_EDGE_SAVE = """
UNWIND $edges AS edge
MATCH (p:Perceptual {id: edge.perceptual_id, end_user_id: edge.end_user_id})
MATCH (d:Dialogue {end_user_id: edge.end_user_id})
WHERE d.id = edge.dialog_id OR d.ref_id = edge.dialog_id
MERGE (d)-[r:HAS_PERCEPTUAL]->(p)
SET r.end_user_id = edge.end_user_id,
r.created_at = edge.created_at
RETURN elementId(r) AS uuid
"""

View File

@@ -387,6 +387,12 @@ class MemoryConfig:
rerank_model_id: Optional[UUID] = None rerank_model_id: Optional[UUID] = None
rerank_model_name: Optional[str] = None rerank_model_name: Optional[str] = None
video_model_id: Optional[UUID] = None
video_model_name: Optional[str] = None
vision_model_id: Optional[UUID] = None
vision_model_name: Optional[str] = None
audio_model_id: Optional[UUID] = None
audio_model_name: Optional[str] = None
llm_params: Dict[str, Any] = field(default_factory=dict) llm_params: Dict[str, Any] = field(default_factory=dict)
embedding_params: Dict[str, Any] = field(default_factory=dict) embedding_params: Dict[str, Any] = field(default_factory=dict)

View File

@@ -141,7 +141,7 @@ class AppChatService:
model_type=ModelType.LLM model_type=ModelType.LLM
) )
multimodal_service = MultimodalService(self.db, model_info) multimodal_service = MultimodalService(self.db, model_info)
processed_files = await multimodal_service.process_files(user_id, files) processed_files = await multimodal_service.process_files(files)
logger.info(f"处理了 {len(processed_files)} 个文件") logger.info(f"处理了 {len(processed_files)} 个文件")
# 调用 Agent支持多模态 # 调用 Agent支持多模态
@@ -339,7 +339,7 @@ class AppChatService:
model_type=ModelType.LLM model_type=ModelType.LLM
) )
multimodal_service = MultimodalService(self.db, model_info) multimodal_service = MultimodalService(self.db, model_info)
processed_files = await multimodal_service.process_files(user_id, files) processed_files = await multimodal_service.process_files(files)
logger.info(f"处理了 {len(processed_files)} 个文件") logger.info(f"处理了 {len(processed_files)} 个文件")
# 流式调用 Agent支持多模态同时并行启动 TTS # 流式调用 Agent支持多模态同时并行启动 TTS

View File

@@ -600,7 +600,7 @@ class AgentRunService:
) )
provider = api_key_config.get("provider", "openai") provider = api_key_config.get("provider", "openai")
multimodal_service = MultimodalService(self.db, model_info) multimodal_service = MultimodalService(self.db, model_info)
processed_files = await multimodal_service.process_files(user_id, files) processed_files = await multimodal_service.process_files(files)
logger.info(f"处理了 {len(processed_files)} 个文件provider={provider}") logger.info(f"处理了 {len(processed_files)} 个文件provider={provider}")
# 7. 知识库检索 # 7. 知识库检索
@@ -836,7 +836,7 @@ class AgentRunService:
) )
provider = api_key_config.get("provider", "openai") provider = api_key_config.get("provider", "openai")
multimodal_service = MultimodalService(self.db, model_info) multimodal_service = MultimodalService(self.db, model_info)
processed_files = await multimodal_service.process_files(user_id, files) processed_files = await multimodal_service.process_files(files)
logger.info(f"处理了 {len(processed_files)} 个文件provider={provider}") logger.info(f"处理了 {len(processed_files)} 个文件provider={provider}")
# 7. 知识库检索 # 7. 知识库检索

View File

@@ -19,32 +19,35 @@ from typing import Any, AsyncGenerator, Dict, List, Optional
from uuid import UUID from uuid import UUID
import redis import redis
from langchain_core.messages import AIMessage, HumanMessage from langchain_core.messages import HumanMessage
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy import func from sqlalchemy import func
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.cache import InterestMemoryCache
from app.core.config import settings from app.core.config import settings
from app.core.logging_config import get_config_logger, get_logger from app.core.logging_config import get_config_logger, get_logger
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
from app.core.memory.agent.logger_file.log_streamer import LogStreamer from app.core.memory.agent.logger_file.log_streamer import LogStreamer
from app.core.memory.agent.utils.messages_tools import ( from app.core.memory.agent.utils.messages_tools import (
merge_multiple_search_results, merge_multiple_search_results,
reorder_output_results, reorder_output_results,
) )
from app.core.memory.agent.utils.type_classifier import status_typle from app.core.memory.agent.utils.type_classifier import status_typle
from app.core.memory.agent.utils.write_tools import write as write_neo4j
from app.core.memory.analytics.hot_memory_tags import get_interest_distribution from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType from app.models.knowledge_model import Knowledge, KnowledgeType
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas import FileInput
from app.schemas.memory_agent_schema import Write_UserInput from app.schemas.memory_agent_schema import Write_UserInput
from app.schemas.memory_config_schema import ConfigurationError from app.schemas.memory_config_schema import ConfigurationError
from app.services.memory_config_service import MemoryConfigService from app.services.memory_config_service import MemoryConfigService
from app.services.memory_konwledges_server import ( from app.services.memory_konwledges_server import (
write_rag, write_rag,
) )
from app.services.memory_perceptual_service import MemoryPerceptualService
try: try:
from app.core.memory.utils.log.audit_logger import audit_logger from app.core.memory.utils.log.audit_logger import audit_logger
@@ -271,6 +274,7 @@ class MemoryAgentService:
self, self,
end_user_id: str, end_user_id: str,
messages: list[dict], messages: list[dict],
file_messages: list[dict],
config_id: Optional[uuid.UUID] | int, config_id: Optional[uuid.UUID] | int,
db: Session, db: Session,
storage_type: str, storage_type: str,
@@ -283,6 +287,7 @@ class MemoryAgentService:
Args: Args:
end_user_id: Group identifier (also used as end_user_id) end_user_id: Group identifier (also used as end_user_id)
messages: Message to write messages: Message to write
files: Files to write
config_id: Configuration ID from database config_id: Configuration ID from database
db: SQLAlchemy database session db: SQLAlchemy database session
storage_type: Storage type (neo4j or rag) storage_type: Storage type (neo4j or rag)
@@ -342,48 +347,52 @@ class MemoryAgentService:
raise ValueError(error_msg) raise ValueError(error_msg)
perceptual_serivce = MemoryPerceptualService(db)
file_content = []
for message in file_messages:
for file in message["files"]:
file_object = await perceptual_serivce.generate_perceptual_memory(
end_user_id=end_user_id,
memory_config=memory_config,
file=FileInput(**file)
)
file_content.append(file_object)
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
try: try:
if storage_type == "rag": if storage_type == "rag":
# For RAG storage, convert messages to single string # For RAG storage, convert messages to single string
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
await write_rag(end_user_id, message_text, user_rag_memory_id) await write_rag(end_user_id, message_text, user_rag_memory_id)
return "success" return "success"
else: else:
async with make_write_graph() as graph: await write_neo4j(
config = {"configurable": {"thread_id": end_user_id}} end_user_id=end_user_id,
# Convert structured messages to LangChain messages messages=messages,
langchain_messages = [] file_content=file_content,
for msg in messages: memory_config=memory_config,
if msg['role'] == 'user': ref_id='',
langchain_messages.append(HumanMessage(content=msg['content'])) language=language
elif msg['role'] == 'assistant': )
langchain_messages.append(AIMessage(content=msg['content'])) for lang in ["zh", "en"]:
print(100 * '-') deleted = await InterestMemoryCache.delete_interest_distribution(
print(langchain_messages) end_user_id, lang
print(100 * '-') )
# 初始状态 - 包含所有必要字段 if deleted:
initial_state = { logger.info(
"messages": langchain_messages, f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}")
"end_user_id": end_user_id, return self.writer_messages_deal(
"memory_config": memory_config, "success",
"language": language start_time,
} end_user_id,
config_id,
# 获取节点更新信息 message_text,
async for update_event in graph.astream( {
initial_state, "status": "success",
stream_mode="updates", "data": messages,
config=config "config_id": memory_config.config_id,
): "config_name": memory_config.config_name
for node_name, node_data in update_event.items(): }
if 'save_neo4j' == node_name: )
massages = node_data
massagesstatus = massages.get('write_result')['status']
contents = massages.get('write_result')
# Convert messages back to string for logging
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text,
contents)
except Exception as e: except Exception as e:
# Ensure proper error handling and logging # Ensure proper error handling and logging
error_msg = f"Write operation failed: {str(e)}" error_msg = f"Write operation failed: {str(e)}"

View File

@@ -38,9 +38,9 @@ class MemoryAPIService:
self.db = db self.db = db
def validate_end_user( def validate_end_user(
self, self,
end_user_id: str, end_user_id: str,
workspace_id: uuid.UUID workspace_id: uuid.UUID
) -> EndUser: ) -> EndUser:
"""Validate that end_user exists and belongs to the workspace. """Validate that end_user exists and belongs to the workspace.
@@ -125,13 +125,14 @@ class MemoryAPIService:
logger.warning(f"Failed to update memory_config_id for end_user {end_user_id}: {e}") logger.warning(f"Failed to update memory_config_id for end_user {end_user_id}: {e}")
async def write_memory( async def write_memory(
self, self,
workspace_id: uuid.UUID, workspace_id: uuid.UUID,
end_user_id: str, end_user_id: str,
message: str, message: str,
config_id: str, config_id: str,
storage_type: str = "neo4j", storage_type: str = "neo4j",
user_rag_memory_id: Optional[str] = None, files: Optional[list]=None,
user_rag_memory_id: Optional[str] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Write memory with validation. """Write memory with validation.
@@ -153,6 +154,8 @@ class MemoryAPIService:
ResourceNotFoundException: If end_user not found ResourceNotFoundException: If end_user not found
BusinessException: If end_user not in authorized workspace or write fails BusinessException: If end_user not in authorized workspace or write fails
""" """
if files is None:
files = list()
logger.info(f"Writing memory for end_user: {end_user_id}, workspace: {workspace_id}") logger.info(f"Writing memory for end_user: {end_user_id}, workspace: {workspace_id}")
# Validate end_user exists and belongs to workspace # Validate end_user exists and belongs to workspace
@@ -171,7 +174,8 @@ class MemoryAPIService:
config_id=config_id, config_id=config_id,
db=self.db, db=self.db,
storage_type=storage_type, storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id or "" user_rag_memory_id=user_rag_memory_id or "",
files=files
) )
logger.info(f"Memory write successful for end_user: {end_user_id}") logger.info(f"Memory write successful for end_user: {end_user_id}")
@@ -206,14 +210,14 @@ class MemoryAPIService:
) )
async def read_memory( async def read_memory(
self, self,
workspace_id: uuid.UUID, workspace_id: uuid.UUID,
end_user_id: str, end_user_id: str,
message: str, message: str,
search_switch: str = "0", search_switch: str = "0",
config_id: str = "", config_id: str = "",
storage_type: str = "neo4j", storage_type: str = "neo4j",
user_rag_memory_id: Optional[str] = None, user_rag_memory_id: Optional[str] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Read memory with validation. """Read memory with validation.
@@ -244,7 +248,6 @@ class MemoryAPIService:
# Update end user's memory_config_id # Update end user's memory_config_id
self._update_end_user_config(end_user_id, config_id) self._update_end_user_config(end_user_id, config_id)
try: try:
# Delegate to MemoryAgentService # Delegate to MemoryAgentService
result = await MemoryAgentService().read_memory( result = await MemoryAgentService().read_memory(
@@ -282,8 +285,8 @@ class MemoryAPIService:
) )
def list_memory_configs( def list_memory_configs(
self, self,
workspace_id: uuid.UUID, workspace_id: uuid.UUID,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""List all memory configs for a workspace. """List all memory configs for a workspace.

View File

@@ -154,10 +154,10 @@ class MemoryConfigService:
self.db = db self.db = db
def load_memory_config( def load_memory_config(
self, self,
config_id: Optional[UUID] = None, config_id: Optional[UUID] = None,
workspace_id: Optional[UUID] = None, workspace_id: Optional[UUID] = None,
service_name: str = "MemoryConfigService", service_name: str = "MemoryConfigService",
) -> MemoryConfig: ) -> MemoryConfig:
""" """
Load memory configuration from database with optional fallback. Load memory configuration from database with optional fallback.
@@ -194,8 +194,8 @@ class MemoryConfigService:
try: try:
# Use get_config_with_fallback if workspace_id is provided # Use get_config_with_fallback if workspace_id is provided
memory_config = None memory_config = None
validated_config_id = None
if workspace_id: if workspace_id:
validated_config_id = None
if config_id: if config_id:
try: try:
validated_config_id = _validate_config_id(config_id, self.db) validated_config_id = _validate_config_id(config_id, self.db)
@@ -243,10 +243,10 @@ class MemoryConfigService:
# Helper function to validate model with workspace fallback # Helper function to validate model with workspace fallback
def _validate_model_with_fallback( def _validate_model_with_fallback(
model_id: str, model_id: str,
model_type: str, model_type: str,
workspace_default: str, workspace_default: str,
required: bool = False required: bool = False
) -> tuple: ) -> tuple:
"""Validate model ID, falling back to workspace default if invalid. """Validate model ID, falling back to workspace default if invalid.
@@ -343,6 +343,35 @@ class MemoryConfigService:
if memory_config.rerank_id or workspace.rerank: if memory_config.rerank_id or workspace.rerank:
logger.info(f"[PERF] Rerank validation: {rerank_time:.4f}s") logger.info(f"[PERF] Rerank validation: {rerank_time:.4f}s")
vision_uuid, vision_name = validate_and_resolve_model_id(
memory_config.vision_id,
"llm",
self.db,
workspace.tenant_id,
required=False,
config_id=validated_config_id,
workspace_id=workspace.id,
)
audio_uuid, audio_name = validate_and_resolve_model_id(
memory_config.audio_id,
"llm",
self.db,
workspace.tenant_id,
required=False,
config_id=validated_config_id,
workspace_id=workspace.id,
)
video_uuid, video_name = validate_and_resolve_model_id(
memory_config.video_id,
"llm",
self.db,
workspace.tenant_id,
required=False,
config_id=validated_config_id,
workspace_id=workspace.id,
)
# Create immutable MemoryConfig object # Create immutable MemoryConfig object
config = MemoryConfig( config = MemoryConfig(
config_id=memory_config.config_id, config_id=memory_config.config_id,
@@ -356,6 +385,12 @@ class MemoryConfigService:
embedding_model_name=embedding_name, embedding_model_name=embedding_name,
rerank_model_id=rerank_uuid, rerank_model_id=rerank_uuid,
rerank_model_name=rerank_name, rerank_model_name=rerank_name,
video_model_id=video_uuid,
video_model_name=video_name,
vision_model_id=vision_uuid,
vision_model_name=vision_name,
audio_model_id=audio_uuid,
audio_model_name=audio_name,
storage_type=workspace.storage_type or "neo4j", storage_type=workspace.storage_type or "neo4j",
chunker_strategy=memory_config.chunker_strategy or "RecursiveChunker", chunker_strategy=memory_config.chunker_strategy or "RecursiveChunker",
reflexion_enabled=memory_config.enable_self_reflexion or False, reflexion_enabled=memory_config.enable_self_reflexion or False,
@@ -364,24 +399,31 @@ class MemoryConfigService:
reflexion_baseline=memory_config.baseline or "Time", reflexion_baseline=memory_config.baseline or "Time",
loaded_at=datetime.now(), loaded_at=datetime.now(),
# Pipeline config: Deduplication # Pipeline config: Deduplication
enable_llm_dedup_blockwise=bool(memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False, enable_llm_dedup_blockwise=bool(
enable_llm_disambiguation=bool(memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False, memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False,
enable_llm_disambiguation=bool(
memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False,
deep_retrieval=bool(memory_config.deep_retrieval) if memory_config.deep_retrieval is not None else True, deep_retrieval=bool(memory_config.deep_retrieval) if memory_config.deep_retrieval is not None else True,
t_type_strict=float(memory_config.t_type_strict) if memory_config.t_type_strict is not None else 0.8, t_type_strict=float(memory_config.t_type_strict) if memory_config.t_type_strict is not None else 0.8,
t_name_strict=float(memory_config.t_name_strict) if memory_config.t_name_strict is not None else 0.8, t_name_strict=float(memory_config.t_name_strict) if memory_config.t_name_strict is not None else 0.8,
t_overall=float(memory_config.t_overall) if memory_config.t_overall is not None else 0.8, t_overall=float(memory_config.t_overall) if memory_config.t_overall is not None else 0.8,
# Pipeline config: Statement extraction # Pipeline config: Statement extraction
statement_granularity=int(memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2, statement_granularity=int(
include_dialogue_context=bool(memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False, memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2,
max_dialogue_context_chars=int(memory_config.max_context) if memory_config.max_context is not None else 1000, include_dialogue_context=bool(
memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False,
max_dialogue_context_chars=int(
memory_config.max_context) if memory_config.max_context is not None else 1000,
# Pipeline config: Forgetting engine # Pipeline config: Forgetting engine
lambda_time=float(memory_config.lambda_time) if memory_config.lambda_time is not None else 0.5, lambda_time=float(memory_config.lambda_time) if memory_config.lambda_time is not None else 0.5,
lambda_mem=float(memory_config.lambda_mem) if memory_config.lambda_mem is not None else 0.5, lambda_mem=float(memory_config.lambda_mem) if memory_config.lambda_mem is not None else 0.5,
offset=float(memory_config.offset) if memory_config.offset is not None else 0.0, offset=float(memory_config.offset) if memory_config.offset is not None else 0.0,
# Pipeline config: Pruning # Pipeline config: Pruning
pruning_enabled=bool(memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False, pruning_enabled=bool(
memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False,
pruning_scene=memory_config.pruning_scene or "education", pruning_scene=memory_config.pruning_scene or "education",
pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5, pruning_threshold=float(
memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
# Ontology scene association # Ontology scene association
scene_id=memory_config.scene_id, scene_id=memory_config.scene_id,
ontology_class_infos=_load_ontology_class_infos(self.db, memory_config.scene_id), ontology_class_infos=_load_ontology_class_infos(self.db, memory_config.scene_id),
@@ -598,8 +640,8 @@ class MemoryConfigService:
return None return None
def get_workspace_default_config( def get_workspace_default_config(
self, self,
workspace_id: UUID workspace_id: UUID
) -> Optional["MemoryConfigModel"]: ) -> Optional["MemoryConfigModel"]:
"""Get workspace default memory config. """Get workspace default memory config.
@@ -623,9 +665,9 @@ class MemoryConfigService:
return config return config
def get_config_with_fallback( def get_config_with_fallback(
self, self,
memory_config_id: Optional[UUID], memory_config_id: Optional[UUID],
workspace_id: UUID workspace_id: UUID
) -> Optional["MemoryConfigModel"]: ) -> Optional["MemoryConfigModel"]:
"""Get memory config with fallback to workspace default. """Get memory config with fallback to workspace default.
@@ -663,9 +705,9 @@ class MemoryConfigService:
return config return config
def delete_config( def delete_config(
self, self,
config_id: UUID | int, config_id: UUID | int,
force: bool = False force: bool = False
) -> dict: ) -> dict:
"""Delete memory config with protection against in-use configs. """Delete memory config with protection against in-use configs.
@@ -800,9 +842,9 @@ class MemoryConfigService:
# ==================== 记忆配置提取方法 ==================== # ==================== 记忆配置提取方法 ====================
def extract_memory_config_id( def extract_memory_config_id(
self, self,
app_type: str, app_type: str,
config: dict config: dict
) -> tuple[Optional[uuid.UUID], bool]: ) -> tuple[Optional[uuid.UUID], bool]:
"""从发布配置中提取 memory_config_id根据应用类型分发 """从发布配置中提取 memory_config_id根据应用类型分发
@@ -828,8 +870,8 @@ class MemoryConfigService:
return None, False return None, False
def _extract_memory_config_id_from_agent( def _extract_memory_config_id_from_agent(
self, self,
config: dict config: dict
) -> tuple[Optional[uuid.UUID], bool]: ) -> tuple[Optional[uuid.UUID], bool]:
"""从 Agent 应用配置中提取 memory_config_id """从 Agent 应用配置中提取 memory_config_id
@@ -888,8 +930,8 @@ class MemoryConfigService:
return None, False return None, False
def _extract_memory_config_id_from_workflow( def _extract_memory_config_id_from_workflow(
self, self,
config: dict config: dict
) -> tuple[Optional[uuid.UUID], bool]: ) -> tuple[Optional[uuid.UUID], bool]:
"""从 Workflow 应用配置中提取 memory_config_id """从 Workflow 应用配置中提取 memory_config_id

View File

@@ -12,11 +12,12 @@ from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
from app.core.models import RedBearLLM, RedBearModelConfig from app.core.models import RedBearLLM, RedBearModelConfig
from app.models import FileMetadata from app.models import FileMetadata, ModelApiKey, ModelType
from app.models.memory_perceptual_model import PerceptualType, FileStorageService from app.models.memory_perceptual_model import PerceptualType, FileStorageService
from app.models.prompt_optimizer_model import RoleType from app.models.prompt_optimizer_model import RoleType
from app.repositories.memory_perceptual_repository import MemoryPerceptualRepository from app.repositories.memory_perceptual_repository import MemoryPerceptualRepository
from app.schemas import FileType from app.schemas import FileType, FileInput
from app.schemas.memory_config_schema import MemoryConfig
from app.schemas.memory_perceptual_schema import ( from app.schemas.memory_perceptual_schema import (
PerceptualQuerySchema, PerceptualQuerySchema,
PerceptualTimelineResponse, PerceptualTimelineResponse,
@@ -24,6 +25,8 @@ from app.schemas.memory_perceptual_schema import (
AudioModal, Content, VideoModal, TextModal AudioModal, Content, VideoModal, TextModal
) )
from app.schemas.model_schema import ModelInfo from app.schemas.model_schema import ModelInfo
from app.services.model_service import ModelApiKeyService
from app.services.multimodal_service import MultimodalService
business_logger = get_business_logger() business_logger = get_business_logger()
@@ -195,21 +198,58 @@ class MemoryPerceptualService:
business_logger.error(f"Failed to fetch perceptual memory timeline: {str(e)}") business_logger.error(f"Failed to fetch perceptual memory timeline: {str(e)}")
raise BusinessException(f"Failed to fetch perceptual memory timeline: {str(e)}", BizCode.DB_ERROR) raise BusinessException(f"Failed to fetch perceptual memory timeline: {str(e)}", BizCode.DB_ERROR)
def _get_mutlimodal_client(
self,
file_type: FileType,
config: MemoryConfig
) -> tuple[RedBearLLM | None, ModelApiKey | None]:
model_config = None
if file_type == FileType.AUDIO:
model_config = ModelApiKeyService.get_available_api_key(
self.db,
config.audio_model_id
)
elif file_type == FileType.VIDEO:
model_config = ModelApiKeyService.get_available_api_key(
self.db,
config.video_model_id
)
elif file_type == FileType.DOCUMENT:
model_config = ModelApiKeyService.get_available_api_key(
self.db,
config.llm_model_id
)
elif file_type == FileType.IMAGE:
model_config = ModelApiKeyService.get_available_api_key(
self.db,
config.vision_model_id
)
llm = None
if model_config:
llm = RedBearLLM(
RedBearModelConfig(
model_name=model_config.model_name,
provider=model_config.provider,
api_key=model_config.api_key,
base_url=model_config.api_base,
is_omni=model_config.is_omni
)
)
return llm, model_config
async def generate_perceptual_memory( async def generate_perceptual_memory(
self, self,
end_user_id: str, end_user_id: str,
model_config: ModelInfo, memory_config: MemoryConfig,
file_type: str, file: FileInput
file_url: str,
file_message: dict,
): ):
memories = self.repository.get_by_url(file_url) memories = self.repository.get_by_url(file.url)
if memories: if memories:
business_logger.info(f"Perceptual memory already exists: {file_url}") business_logger.info(f"Perceptual memory already exists: {file.url}")
if end_user_id not in [memory.end_user_id for memory in memories]: if end_user_id not in [memory.end_user_id for memory in memories]:
business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}") business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}")
memory_cache = memories[0] memory_cache = memories[0]
self.repository.create_perceptual_memory( memory = self.repository.create_perceptual_memory(
end_user_id=uuid.UUID(end_user_id), end_user_id=uuid.UUID(end_user_id),
perceptual_type=PerceptualType(memory_cache.perceptual_type), perceptual_type=PerceptualType(memory_cache.perceptual_type),
file_path=memory_cache.file_path, file_path=memory_cache.file_path,
@@ -219,20 +259,31 @@ class MemoryPerceptualService:
meta_data=memory_cache.meta_data meta_data=memory_cache.meta_data
) )
self.db.commit() self.db.commit()
return memory
return else:
llm = RedBearLLM(RedBearModelConfig( for memory in memories:
if memory.end_user_id == uuid.UUID(end_user_id):
return memory
llm, model_config = self._get_mutlimodal_client(file.type, memory_config)
multimodel_service = MultimodalService(self.db, ModelInfo(
model_name=model_config.model_name, model_name=model_config.model_name,
provider=model_config.provider, provider=model_config.provider,
api_key=model_config.api_key, api_key=model_config.api_key,
base_url=model_config.api_base, api_base=model_config.api_base,
is_omni=model_config.is_omni is_omni=model_config.is_omni,
), type=model_config.model_type) capability=model_config.capability,
model_type=ModelType.LLM
))
file_message = await multimodel_service.process_files(
files=[file]
)
if file_message:
file_message = file_message[0]
try: try:
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt') prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f: with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f:
opt_system_prompt = f.read() opt_system_prompt = f.read()
rendered_system_message = Template(opt_system_prompt).render(file_type=file_type, language='zh') rendered_system_message = Template(opt_system_prompt).render(file_type=file.type, language='zh')
except FileNotFoundError: except FileNotFoundError:
raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND) raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND)
messages = [ messages = [
@@ -242,8 +293,22 @@ class MemoryPerceptualService:
]} ]}
] ]
result = await llm.ainvoke(messages) result = await llm.ainvoke(messages)
content = json_repair.repair_json(result.content, return_objects=True) content = result.content
path = urlparse(file_url).path final_output = ""
if isinstance(content, list):
for msg in content:
if isinstance(msg, dict):
final_output += msg.get("text", "")
elif isinstance(msg, str):
final_output += msg
elif isinstance(content, dict):
final_output += content.get("text", "")
elif isinstance(content, str):
final_output = content
else:
raise ValueError(f"Unexcept Model Output Type: {result.content}")
content = json_repair.repair_json(final_output, return_objects=True)
path = urlparse(file.url).path
filename = os.path.basename(path) filename = os.path.basename(path)
filename = unquote(filename) filename = unquote(filename)
file_ext = os.path.splitext(filename)[1] file_ext = os.path.splitext(filename)[1]
@@ -260,13 +325,13 @@ class MemoryPerceptualService:
except ValueError: except ValueError:
business_logger.debug(f"Remote file, file_id={filename}") business_logger.debug(f"Remote file, file_id={filename}")
if not file_ext: if not file_ext:
if file_type == FileType.AUDIO: if file.type == FileType.AUDIO:
file_ext = ".mp3" file_ext = ".mp3"
elif file_type == FileType.VIDEO: elif file.type == FileType.VIDEO:
file_ext = ".mp4" file_ext = ".mp4"
elif file_type == FileType.DOCUMENT: elif file.type == FileType.DOCUMENT:
file_ext = ".txt" file_ext = ".txt"
elif file_type == FileType.IMAGE: elif file.type == FileType.IMAGE:
file_ext = ".jpg" file_ext = ".jpg"
filename += file_ext filename += file_ext
file_content = { file_content = {
@@ -274,11 +339,11 @@ class MemoryPerceptualService:
"topic": content.get("topic"), "topic": content.get("topic"),
"domain": content.get("domain") "domain": content.get("domain")
} }
if file_type in [FileType.IMAGE, FileType.VIDEO]: if file.type in [FileType.IMAGE, FileType.VIDEO]:
file_modalities = { file_modalities = {
"scene": content.get("scene", []) "scene": content.get("scene", [])
} }
elif file_type in [FileType.DOCUMENT]: elif file.type in [FileType.DOCUMENT]:
file_modalities = { file_modalities = {
"section_count": content.get("section_count", 0), "section_count": content.get("section_count", 0),
"title": content.get("title", ""), "title": content.get("title", ""),
@@ -288,10 +353,10 @@ class MemoryPerceptualService:
file_modalities = { file_modalities = {
"speaker_count": content.get("speaker_count", 0) "speaker_count": content.get("speaker_count", 0)
} }
self.repository.create_perceptual_memory( memory = self.repository.create_perceptual_memory(
end_user_id=uuid.UUID(end_user_id), end_user_id=uuid.UUID(end_user_id),
perceptual_type=PerceptualType.trans_from_file_type(file_type), perceptual_type=PerceptualType.trans_from_file_type(file.type),
file_path=file_url, file_path=file.url,
file_name=filename, file_name=filename,
file_ext=file_ext, file_ext=file_ext,
summary=content.get('summary', ""), summary=content.get('summary', ""),
@@ -301,3 +366,4 @@ class MemoryPerceptualService:
} }
) )
self.db.commit() self.db.commit()
return memory

View File

@@ -9,14 +9,12 @@
- OpenAI: 支持 URL 和 base64 格式 - OpenAI: 支持 URL 和 base64 格式
""" """
import base64 import base64
import csv
import io import io
import uuid import json
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional
import csv
import json
import PyPDF2 import PyPDF2
import httpx import httpx
import magic import magic
@@ -33,7 +31,6 @@ from app.models.file_metadata_model import FileMetadata
from app.schemas.app_schema import FileInput, FileType, TransferMethod from app.schemas.app_schema import FileInput, FileType, TransferMethod
from app.schemas.model_schema import ModelInfo from app.schemas.model_schema import ModelInfo
from app.services.audio_transcription_service import AudioTranscriptionService from app.services.audio_transcription_service import AudioTranscriptionService
from app.tasks import write_perceptual_memory
logger = get_business_logger() logger = get_business_logger()
@@ -342,15 +339,12 @@ class MultimodalService:
async def process_files( async def process_files(
self, self,
end_user_id: uuid.UUID | str,
files: Optional[List[FileInput]], files: Optional[List[FileInput]],
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
处理文件列表,返回 LLM 可用的格式 处理文件列表,返回 LLM 可用的格式
Args: Args:
end_user_id: 用户ID
files: 文件输入列表 files: 文件输入列表
Returns: Returns:
@@ -358,8 +352,6 @@ class MultimodalService:
""" """
if not files: if not files:
return [] return []
if isinstance(end_user_id, uuid.UUID):
end_user_id = str(end_user_id)
# 获取对应的策略 # 获取对应的策略
# dashscope 的 omni 模型使用 OpenAI 兼容格式 # dashscope 的 omni 模型使用 OpenAI 兼容格式
@@ -380,23 +372,15 @@ class MultimodalService:
if file.type == FileType.IMAGE and "vision" in self.capability: if file.type == FileType.IMAGE and "vision" in self.capability:
is_support, content = await self._process_image(file, strategy) is_support, content = await self._process_image(file, strategy)
result.append(content) result.append(content)
if is_support:
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
elif file.type == FileType.DOCUMENT: elif file.type == FileType.DOCUMENT:
is_support, content = await self._process_document(file, strategy) is_support, content = await self._process_document(file, strategy)
result.append(content) result.append(content)
if is_support:
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
elif file.type == FileType.AUDIO and "audio" in self.capability: elif file.type == FileType.AUDIO and "audio" in self.capability:
is_support, content = await self._process_audio(file, strategy) is_support, content = await self._process_audio(file, strategy)
result.append(content) result.append(content)
if is_support:
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
elif file.type == FileType.VIDEO and "video" in self.capability: elif file.type == FileType.VIDEO and "video" in self.capability:
is_support, content = await self._process_video(file, strategy) is_support, content = await self._process_video(file, strategy)
result.append(content) result.append(content)
if is_support:
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
else: else:
logger.warning(f"不支持的文件类型: {file.type}") logger.warning(f"不支持的文件类型: {file.type}")
except Exception as e: except Exception as e:
@@ -418,17 +402,6 @@ class MultimodalService:
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件provider={self.provider}") logger.info(f"成功处理 {len(result)}/{len(files)} 个文件provider={self.provider}")
return result return result
def write_perceptual_memory(
self,
end_user_id: str,
file_type: str,
file_url: str,
file_message: dict
):
"""写入感知记忆"""
if end_user_id and self.api_config:
write_perceptual_memory.delay(end_user_id, self.api_config.model_dump(), file_type, file_url, file_message)
async def _process_image(self, file: FileInput, strategy) -> tuple[bool, Dict[str, Any]]: async def _process_image(self, file: FileInput, strategy) -> tuple[bool, Dict[str, Any]]:
""" """
处理图片文件 处理图片文件

View File

@@ -1080,12 +1080,14 @@ def write_message_task(
config_id: str | int, config_id: str | int,
storage_type: str, storage_type: str,
user_rag_memory_id: str, user_rag_memory_id: str,
file_messages: list[dict] | None,
language: str = "zh" language: str = "zh"
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Celery task to process a write message via MemoryAgentService. """Celery task to process a write message via MemoryAgentService.
Args: Args:
end_user_id: Group ID for the memory agent (also used as end_user_id) end_user_id: Group ID for the memory agent (also used as end_user_id)
message: Message to write message: Message to write
file_messages: Files to write
config_id: Configuration ID (can be UUID string, integer, or config_id_old) config_id: Configuration ID (can be UUID string, integer, or config_id_old)
storage_type: Storage type (neo4j or rag) storage_type: Storage type (neo4j or rag)
user_rag_memory_id: User RAG memory ID user_rag_memory_id: User RAG memory ID
@@ -1097,6 +1099,8 @@ def write_message_task(
Raises: Raises:
Exception on failure Exception on failure
""" """
if file_messages is None:
file_messages = []
logger.info( logger.info(
f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, " f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, "
@@ -1142,7 +1146,7 @@ def write_message_task(
f"[CELERY WRITE] Executing MemoryAgentService.write_memory " f"[CELERY WRITE] Executing MemoryAgentService.write_memory "
f"with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}") f"with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}")
service = MemoryAgentService() service = MemoryAgentService()
result = await service.write_memory(end_user_id, message, actual_config_id, db, storage_type, result = await service.write_memory(end_user_id, message, file_messages, actual_config_id, db, storage_type,
user_rag_memory_id, language) user_rag_memory_id, language)
logger.info(f"[CELERY WRITE] Write completed successfully: {result}") logger.info(f"[CELERY WRITE] Write completed successfully: {result}")
return result return result