From 2ff81ba101b422e600982dcc96fc673bc4684223 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Mon, 23 Mar 2026 16:33:25 +0800 Subject: [PATCH] feat(memory): support perception-aware memory writing in workflow and Neo4j nodes --- .../core/memory/agent/utils/write_tools.py | 65 ++- api/app/core/memory/models/graph_models.py | 18 + .../deduplication/two_stage_dedup.py | 24 +- .../extraction_orchestrator.py | 465 +++++++++--------- .../knowledge_extraction/memory_summary.py | 1 - api/app/core/workflow/engine/variable_pool.py | 4 +- api/app/core/workflow/nodes/base_node.py | 8 +- api/app/core/workflow/nodes/llm/node.py | 14 +- api/app/core/workflow/nodes/memory/node.py | 1 + api/app/models/memory_config_model.py | 3 + .../repositories/memory_config_repository.py | 57 +-- api/app/repositories/neo4j/add_nodes.py | 111 ++++- api/app/repositories/neo4j/cypher_queries.py | 33 ++ api/app/schemas/memory_config_schema.py | 6 + api/app/services/app_chat_service.py | 4 +- api/app/services/draft_run_service.py | 4 +- api/app/services/memory_agent_service.py | 85 ++-- api/app/services/memory_api_service.py | 85 ++-- api/app/services/memory_config_service.py | 194 +++++--- api/app/services/memory_perceptual_service.py | 120 ++++- api/app/services/multimodal_service.py | 31 +- api/app/tasks.py | 6 +- 22 files changed, 820 insertions(+), 519 deletions(-) diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index b62eb50a..147a0316 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -5,6 +5,7 @@ This module provides the main write function for executing the knowledge extract pipeline. Only MemoryConfig is needed - clients are constructed internally. """ import asyncio +import uuid import time from datetime import datetime @@ -13,28 +14,31 @@ from dotenv import load_dotenv from app.core.logging_config import get_agent_logger from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator -from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import memory_summary_generation +from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import \ + memory_summary_generation from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.log.logging_utils import log_time from app.db import get_db_context +from app.models import MemoryPerceptualModel from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges -from app.repositories.neo4j.add_nodes import add_memory_summary_nodes +from app.repositories.neo4j.add_nodes import add_memory_summary_nodes, add_perceptual_nodes, \ + add_perceptual_dialogue_edges from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, schedule_clustering_after_write from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_config_schema import MemoryConfig - load_dotenv() logger = get_agent_logger(__name__) async def write( - end_user_id: str, - memory_config: MemoryConfig, - messages: list, - ref_id: str = "wyl20251027", - language: str = "zh", + end_user_id: str, + memory_config: MemoryConfig, + messages: list, + file_content: list[MemoryPerceptualModel], + ref_id: str = "", + language: str = "zh", ) -> None: """ Execute the complete knowledge extraction pipeline. @@ -43,9 +47,12 @@ async def write( end_user_id: Group identifier memory_config: MemoryConfig object containing all configuration messages: Structured message list [{"role": "user", "content": "..."}, ...] - ref_id: Reference ID, defaults to "wyl20251027" + file_content: mutilmodal message list + ref_id: Reference ID, defaults to "" language: 语言类型 ("zh" 中文, "en" 英文),默认中文 """ + if not ref_id: + ref_id = uuid.uuid4().hex # Extract config values embedding_model_id = str(memory_config.embedding_model_id) chunker_strategy = memory_config.chunker_strategy @@ -99,14 +106,14 @@ async def write( if memory_config.scene_id: try: from app.core.memory.ontology_services.ontology_type_loader import load_ontology_types_for_scene - + with get_db_context() as db: ontology_types = load_ontology_types_for_scene( scene_id=memory_config.scene_id, workspace_id=memory_config.workspace_id, db=db ) - + if ontology_types: logger.info( f"Loaded {len(ontology_types.types)} ontology types for scene_id: {memory_config.scene_id}" @@ -173,7 +180,8 @@ async def write( schedule_clustering_after_write( all_entity_nodes, llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None, - embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None, + embedding_model_id=str( + memory_config.embedding_model_id) if memory_config.embedding_model_id else None, ) break else: @@ -208,9 +216,8 @@ async def write( summaries = await memory_summary_generation( chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client, language=language ) - + ms_connector = Neo4jConnector() try: - ms_connector = Neo4jConnector() await add_memory_summary_nodes(summaries, ms_connector) await add_memory_summary_statement_edges(summaries, ms_connector) finally: @@ -223,6 +230,34 @@ async def write( finally: log_time("Memory Summary (Neo4j)", time.time() - step_start, log_file) + # Step 5: Save perceptual memory to Neo4j + step_start = time.time() + if file_content: + try: + pc_connector = Neo4jConnector() + try: + created_ids = await add_perceptual_nodes( + perceptuals=file_content, + connector=pc_connector, + embedder_client=embedder_client, + ) + # 如果有 ref_id,建立感知记忆与对话的关联 + if ref_id and created_ids: + await add_perceptual_dialogue_edges( + perceptuals=file_content, + dialog_id=ref_id, + connector=pc_connector, + ) + logger.info(f"Successfully saved {len(created_ids or [])} perceptual memory nodes to Neo4j") + finally: + try: + await pc_connector.close() + except Exception: + pass + except Exception as e: + logger.error(f"Perceptual memory Neo4j save failed: {e}", exc_info=True) + log_time("Perceptual Memory (Neo4j)", time.time() - step_start, log_file) + # Log total pipeline time total_time = time.time() - pipeline_start log_time("TOTAL PIPELINE TIME", total_time, log_file) @@ -251,4 +286,4 @@ async def write( logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True) logger.info("=== Pipeline Complete ===") - logger.info(f"Total execution time: {total_time:.2f} seconds") \ No newline at end of file + logger.info(f"Total execution time: {total_time:.2f} seconds") diff --git a/api/app/core/memory/models/graph_models.py b/api/app/core/memory/models/graph_models.py index 5a2d8c2e..fb251f1f 100644 --- a/api/app/core/memory/models/graph_models.py +++ b/api/app/core/memory/models/graph_models.py @@ -553,3 +553,21 @@ class MemorySummaryNode(Node): ge=0, description="Total number of times this node has been accessed (reset to 1 on creation)" ) + + +class MutlimodalNode(Node): + """Node representing a multimodal message in the knowledge graph. + + Attributes: + dialog_id: ID of the parent dialog + message_id: ID of the message + metadata: Additional message metadata + embedding: Optional embedding vector for the message + """ + dialog_id: str = Field(..., description="ID of the parent dialog") + message_id: str = Field(..., description="ID of the message") + summary: str = Field(..., description="The text content of the message") + file_type: str = Field(..., description="Type of the message (e.g., 'text', 'image', 'audio', 'video')") + file_path: List[str] = Field(..., description="List of file paths for multimodal content") + metadata: dict = Field(default_factory=dict, description="Additional message metadata") + embedding: Optional[List[float]] = Field(None, description="Embedding vector for the message") diff --git a/api/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py b/api/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py index f28b8a5f..4b9c5718 100644 --- a/api/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py +++ b/api/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py @@ -25,17 +25,17 @@ from app.repositories.neo4j.neo4j_connector import Neo4jConnector async def dedup_layers_and_merge_and_return( - dialogue_nodes: List[DialogueNode], - chunk_nodes: List[ChunkNode], - statement_nodes: List[StatementNode], - entity_nodes: List[ExtractedEntityNode], - statement_chunk_edges: List[StatementChunkEdge], - statement_entity_edges: List[StatementEntityEdge], - entity_entity_edges: List[EntityEntityEdge], - dialog_data_list: List[DialogData], - pipeline_config: ExtractionPipelineConfig, - connector: Optional[Neo4jConnector] = None, - llm_client = None, + dialogue_nodes: List[DialogueNode], + chunk_nodes: List[ChunkNode], + statement_nodes: List[StatementNode], + entity_nodes: List[ExtractedEntityNode], + statement_chunk_edges: List[StatementChunkEdge], + statement_entity_edges: List[StatementEntityEdge], + entity_entity_edges: List[EntityEntityEdge], + dialog_data_list: List[DialogData], + pipeline_config: ExtractionPipelineConfig, + connector: Optional[Neo4jConnector] = None, + llm_client=None, ) -> Tuple[ List[DialogueNode], List[ChunkNode], @@ -44,7 +44,7 @@ async def dedup_layers_and_merge_and_return( List[StatementChunkEdge], List[StatementEntityEdge], List[EntityEntityEdge], - dict, # 新增:返回去重详情 + dict ]: """ 执行两层实体去重与融合: diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index 00d06f72..6e94a84f 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -31,11 +31,10 @@ from app.core.memory.models.graph_models import ( ExtractedEntityNode, StatementChunkEdge, StatementEntityEdge, - StatementNode, + StatementNode ) from app.core.memory.models.message_models import DialogData from app.core.memory.models.ontology_extraction_models import OntologyTypeList -from app.core.memory.models.ontology_extraction_models import OntologyTypeList from app.core.memory.models.variate_config import ( ExtractionPipelineConfig, ) @@ -46,7 +45,6 @@ from app.core.memory.storage_services.extraction_engine.knowledge_extraction.emb embedding_generation, generate_entity_embeddings_from_triplets, ) - # 导入各个提取模块 from app.core.memory.storage_services.extraction_engine.knowledge_extraction.statement_extraction import ( StatementExtractor, @@ -90,16 +88,16 @@ class ExtractionOrchestrator: """ def __init__( - self, - llm_client: LLMClient, - embedder_client: OpenAIEmbedderClient, - connector: Neo4jConnector, - config: Optional[ExtractionPipelineConfig] = None, - progress_callback: Optional[Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]] = None, - embedding_id: Optional[str] = None, - ontology_types: Optional[OntologyTypeList] = None, - enable_general_types: bool = True, - language: str = "zh", + self, + llm_client: LLMClient, + embedder_client: OpenAIEmbedderClient, + connector: Neo4jConnector, + config: Optional[ExtractionPipelineConfig] = None, + progress_callback: Optional[Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]] = None, + embedding_id: Optional[str] = None, + ontology_types: Optional[OntologyTypeList] = None, + enable_general_types: bool = True, + language: str = "zh", ): """ 初始化流水线编排器 @@ -123,7 +121,7 @@ class ExtractionOrchestrator: self.progress_callback = progress_callback # 保存进度回调函数 self.embedding_id = embedding_id # 保存嵌入模型ID self.language = language # 保存语言配置 - + # 处理本体类型配置 # 根据 enable_general_types 参数决定是否将通用本体类型与场景特定类型合并 # 如果启用合并且配置中开启了通用本体功能,则使用 OntologyTypeMerger 进行融合 @@ -146,7 +144,7 @@ class ExtractionOrchestrator: self.ontology_types = ontology_types if not enable_general_types and ontology_types: logger.info("enable_general_types=False,仅使用场景类型") - + # 保存去重消歧的详细记录(内存中的数据结构) self.dedup_merge_records: List[Dict[str, Any]] = [] # 实体合并记录 self.dedup_disamb_records: List[Dict[str, Any]] = [] # 实体消歧记录 @@ -157,19 +155,25 @@ class ExtractionOrchestrator: llm_client=llm_client, config=self.config.statement_extraction, ) - self.triplet_extractor = TripletExtractor(llm_client=llm_client,ontology_types=self.ontology_types, language=language) + self.triplet_extractor = TripletExtractor(llm_client=llm_client, ontology_types=self.ontology_types, + language=language) self.temporal_extractor = TemporalExtractor(llm_client=llm_client) logger.info("ExtractionOrchestrator 初始化完成") async def run( - self, - dialog_data_list: List[DialogData], - is_pilot_run: bool = False, - ) -> Tuple[ - Tuple[List[DialogueNode], List[ChunkNode], List[StatementNode]], - Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]], - Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]], + self, + dialog_data_list: List[DialogData], + is_pilot_run: bool = False, + ) -> tuple[ + list[DialogueNode], + list[ChunkNode], + list[StatementNode], + list[ExtractedEntityNode], + list[StatementChunkEdge], + list[StatementEntityEdge], + list[EntityEntityEdge], + dict ]: """ 运行完整的知识提取流水线(优化版:并行执行) @@ -202,13 +206,12 @@ class ExtractionOrchestrator: # 步骤 1: 陈述句提取 logger.info("步骤 1/6: 陈述句提取(全局分块级并行)") dialog_data_list = await self._extract_statements(dialog_data_list) - + # 收集陈述句内容和统计数量 all_statements_list = [] for dialog in dialog_data_list: for chunk in dialog.chunks: all_statements_list.extend(chunk.statements) - len(all_statements_list) # 步骤 2: 并行执行三元组提取、时间信息提取、情绪提取和基础嵌入生成 logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取、情绪提取和嵌入生成") @@ -220,7 +223,7 @@ class ExtractionOrchestrator: chunk_embedding_maps, dialog_embeddings, ) = await self._parallel_extract_and_embed(dialog_data_list) - + # 收集实体和三元组内容,并统计数量 all_entities_list = [] all_triplets_list = [] @@ -229,10 +232,6 @@ class ExtractionOrchestrator: if triplet_info: all_entities_list.extend(triplet_info.entities) all_triplets_list.extend(triplet_info.triplets) - - len(all_entities_list) - len(all_triplets_list) - sum(len(temporal_map) for temporal_map in temporal_maps) # 步骤 3: 生成实体嵌入(依赖三元组提取结果) logger.info("步骤 3/6: 生成实体嵌入") @@ -252,9 +251,9 @@ class ExtractionOrchestrator: # 步骤 5: 创建节点和边 logger.info("步骤 5/6: 创建节点和边") - + # 注意:creating_nodes_edges 消息已在知识抽取完成后立即发送 - + ( dialogue_nodes, chunk_nodes, @@ -273,9 +272,9 @@ class ExtractionOrchestrator: logger.info("步骤 6/6: 去重和消歧(试运行模式:仅第一层去重)") else: logger.info("步骤 6/6: 两阶段去重和消歧") - + # 注意:deduplication 消息已在创建节点和边完成后立即发送 - + result = await self._run_dedup_and_write_summary( dialogue_nodes, chunk_nodes, @@ -287,8 +286,6 @@ class ExtractionOrchestrator: dialog_data_list, ) - - logger.info(f"知识提取流水线运行完成({mode_str})") return result @@ -297,7 +294,7 @@ class ExtractionOrchestrator: raise async def _extract_statements( - self, dialog_data_list: List[DialogData] + self, dialog_data_list: List[DialogData] ) -> List[DialogData]: """ 从对话中提取陈述句(流式输出版本:边提取边发送进度) @@ -313,7 +310,7 @@ class ExtractionOrchestrator: # 收集所有分块及其元数据 all_chunks = [] chunk_metadata = [] # (dialog_idx, chunk_idx) - + for d_idx, dialog in enumerate(dialog_data_list): dialogue_content = dialog.content if self.config.statement_extraction.include_dialogue_context else None for c_idx, chunk in enumerate(dialog.chunks): @@ -321,7 +318,7 @@ class ExtractionOrchestrator: chunk_metadata.append((d_idx, c_idx)) logger.info(f"收集到 {len(all_chunks)} 个分块,开始全局并行提取") - + # 用于跟踪已完成的分块数量 completed_chunks = 0 total_chunks = len(all_chunks) @@ -332,7 +329,7 @@ class ExtractionOrchestrator: chunk, end_user_id, dialogue_content = chunk_data try: statements = await self.statement_extractor._extract_statements(chunk, end_user_id, dialogue_content) - + # 流式输出:每提取完一个分块的陈述句,立即发送进度 # 注意:只在试运行模式下发送陈述句详情,正式模式不发送 completed_chunks += 1 @@ -347,11 +344,11 @@ class ExtractionOrchestrator: "statement_index_in_chunk": idx + 1 } await self.progress_callback( - "knowledge_extraction_result", - f"陈述句提取中 ({completed_chunks}/{total_chunks})", + "knowledge_extraction_result", + f"陈述句提取中 ({completed_chunks}/{total_chunks})", stmt_result ) - + return statements except Exception as e: logger.error(f"分块 {chunk.id} 陈述句提取失败: {e}") @@ -381,7 +378,7 @@ class ExtractionOrchestrator: # 保存陈述句到文件(试运行和正式模式都需要) self.statement_extractor.save_statements(all_statements) - + logger.info(f"陈述句提取完成,共提取 {len(all_statements)} 条陈述句") # 试运行模式下,所有分块提取完成后发送完成事件 @@ -395,7 +392,7 @@ class ExtractionOrchestrator: return dialog_data_list async def _extract_triplets( - self, dialog_data_list: List[DialogData] + self, dialog_data_list: List[DialogData] ) -> List[Dict[str, Any]]: """ 从对话中提取三元组(流式输出版本:边提取边发送进度) @@ -411,7 +408,7 @@ class ExtractionOrchestrator: # 收集所有陈述句及其元数据 all_statements = [] statement_metadata = [] # (dialog_idx, statement_id, chunk_content) - + for d_idx, dialog in enumerate(dialog_data_list): for chunk in dialog.chunks: for statement in chunk.statements: @@ -419,7 +416,7 @@ class ExtractionOrchestrator: statement_metadata.append((d_idx, statement.id)) logger.info(f"收集到 {len(all_statements)} 个陈述句,开始全局并行提取三元组") - + # 用于跟踪已完成的陈述句数量 completed_statements = 0 len(all_statements) @@ -430,11 +427,11 @@ class ExtractionOrchestrator: statement, chunk_content = stmt_data try: triplet_info = await self.triplet_extractor._extract_triplets(statement, chunk_content) - + # 注意:不再发送三元组提取的流式输出 # 三元组提取在后台执行,但不向前端发送详细信息 completed_statements += 1 - + return triplet_info except Exception as e: logger.error(f"陈述句 {statement.id} 三元组提取失败: {e}") @@ -450,7 +447,7 @@ class ExtractionOrchestrator: # 将结果组织成对话级别的映射 triplet_maps = [{} for _ in dialog_data_list] all_responses = [] - + for i, result in enumerate(results): d_idx, stmt_id = statement_metadata[i] if isinstance(result, Exception): @@ -478,7 +475,7 @@ class ExtractionOrchestrator: return triplet_maps async def _extract_temporal( - self, dialog_data_list: List[DialogData] + self, dialog_data_list: List[DialogData] ) -> List[Dict[str, Any]]: """ 从对话中提取时间信息(流式输出版本:边提取边发送进度) @@ -502,13 +499,13 @@ class ExtractionOrchestrator: temporal_map[statement.id] = TemporalValidityRange(valid_at=None, invalid_at=None) temporal_maps.append(temporal_map) return temporal_maps - + logger.info("开始时间信息提取(全局陈述句级并行 + 流式输出)") # 收集所有需要提取时间的陈述句 all_statements = [] statement_metadata = [] # (dialog_idx, statement_id, ref_dates) - + for d_idx, dialog in enumerate(dialog_data_list): # 获取参考日期 ref_dates = {} @@ -517,11 +514,11 @@ class ExtractionOrchestrator: ref_dates['conversation_date'] = dialog.metadata['conversation_date'] if 'publication_date' in dialog.metadata: ref_dates['publication_date'] = dialog.metadata['publication_date'] - + if not ref_dates: from datetime import datetime ref_dates = {"today": datetime.now().strftime("%Y-%m-%d")} - + for chunk in dialog.chunks: for statement in chunk.statements: # 跳过 ATEMPORAL 类型的陈述句 @@ -531,7 +528,7 @@ class ExtractionOrchestrator: statement_metadata.append((d_idx, statement.id)) logger.info(f"收集到 {len(all_statements)} 个需要时间提取的陈述句,开始全局并行提取") - + # 用于跟踪已完成的时间提取数量 completed_temporal = 0 len(all_statements) @@ -542,11 +539,11 @@ class ExtractionOrchestrator: statement, ref_dates = stmt_data try: temporal_range = await self.temporal_extractor._extract_temporal_ranges(statement, ref_dates) - + # 注意:不再发送时间提取的流式输出 # 时间提取在后台执行,但不向前端发送详细信息 completed_temporal += 1 - + return temporal_range except Exception as e: logger.error(f"陈述句 {statement.id} 时间信息提取失败: {e}") @@ -559,7 +556,7 @@ class ExtractionOrchestrator: # 将结果组织成对话级别的映射 temporal_maps = [{} for _ in dialog_data_list] - + for i, result in enumerate(results): d_idx, stmt_id = statement_metadata[i] if isinstance(result, Exception): @@ -585,7 +582,7 @@ class ExtractionOrchestrator: return temporal_maps async def _extract_emotions( - self, dialog_data_list: List[DialogData] + self, dialog_data_list: List[DialogData] ) -> List[Dict[str, Any]]: """ 从对话中提取情绪信息(仅针对用户消息,全局陈述句级并行) @@ -601,36 +598,36 @@ class ExtractionOrchestrator: # 收集所有陈述句及其配置 all_statements = [] statement_metadata = [] # (dialog_idx, statement_id) - + # 获取第一个对话的config_id来加载配置 config_id = None if dialog_data_list and hasattr(dialog_data_list[0], 'config_id'): config_id = dialog_data_list[0].config_id - + # 加载MemoryConfig memory_config = None if config_id: try: from app.db import SessionLocal from app.repositories.memory_config_repository import MemoryConfigRepository - + db = SessionLocal() try: memory_config = MemoryConfigRepository.get_by_id(db, config_id) finally: db.close() - + if memory_config and not memory_config.emotion_enabled: logger.info("情绪提取已在配置中禁用,跳过情绪提取") return [{} for _ in dialog_data_list] - + except Exception as e: logger.warning(f"加载MemoryConfig失败: {e},将跳过情绪提取") return [{} for _ in dialog_data_list] else: logger.info("未找到config_id,跳过情绪提取") return [{} for _ in dialog_data_list] - + # 如果配置未启用情绪提取,直接返回空映射 if not memory_config or not memory_config.emotion_enabled: logger.info("情绪提取未启用,跳过") @@ -639,7 +636,7 @@ class ExtractionOrchestrator: # 收集所有陈述句(只收集 speaker 为 "user" 的) total_statements = 0 filtered_statements = 0 - + for d_idx, dialog in enumerate(dialog_data_list): for chunk in dialog.chunks: for statement in chunk.statements: @@ -655,12 +652,12 @@ class ExtractionOrchestrator: # 初始化情绪提取服务 # 如果 emotion_model_id 为空,回退到工作空间默认 LLM from app.services.emotion_extraction_service import EmotionExtractionService - + emotion_model_id = memory_config.emotion_model_id if not emotion_model_id and memory_config.workspace_id: from app.repositories.workspace_repository import get_workspace_models_configs from app.db import SessionLocal - + db = SessionLocal() try: workspace_models = get_workspace_models_configs(db, memory_config.workspace_id) @@ -669,7 +666,7 @@ class ExtractionOrchestrator: logger.info(f"emotion_model_id 为空,使用工作空间默认 LLM: {emotion_model_id}") finally: db.close() - + emotion_service = EmotionExtractionService( llm_id=emotion_model_id if emotion_model_id else None ) @@ -689,7 +686,7 @@ class ExtractionOrchestrator: # 将结果组织成对话级别的映射 emotion_maps = [{} for _ in dialog_data_list] successful_extractions = 0 - + for i, result in enumerate(results): d_idx, stmt_id = statement_metadata[i] if isinstance(result, Exception): @@ -706,7 +703,7 @@ class ExtractionOrchestrator: return emotion_maps async def _parallel_extract_and_embed( - self, dialog_data_list: List[DialogData] + self, dialog_data_list: List[DialogData] ) -> Tuple[ List[Dict[str, Any]], List[Dict[str, Any]], @@ -757,7 +754,7 @@ class ExtractionOrchestrator: triplet_maps = results[0] if not isinstance(results[0], Exception) else [{} for _ in dialog_data_list] temporal_maps = results[1] if not isinstance(results[1], Exception) else [{} for _ in dialog_data_list] emotion_maps = results[2] if not isinstance(results[2], Exception) else [{} for _ in dialog_data_list] - + if isinstance(results[3], Exception): logger.error(f"基础嵌入生成失败: {results[3]}") statement_embedding_maps = [{} for _ in dialog_data_list] @@ -777,7 +774,7 @@ class ExtractionOrchestrator: ) async def _generate_basic_embeddings( - self, dialog_data_list: List[DialogData] + self, dialog_data_list: List[DialogData] ) -> Tuple[List[Dict[str, List[float]]], List[Dict[str, List[float]]], List[List[float]]]: """ 生成基础嵌入向量(陈述句、分块、对话) @@ -810,7 +807,7 @@ class ExtractionOrchestrator: if not self.embedding_id: logger.error("embedding_id is required but was not provided to ExtractionOrchestrator") raise ValueError("embedding_id is required but was not provided") - + # 只生成陈述句、分块和对话的嵌入(不包括实体) statement_embedding_maps, chunk_embedding_maps, dialog_embeddings = await embedding_generation( dialog_data_list, self.embedding_id @@ -836,7 +833,7 @@ class ExtractionOrchestrator: ) async def _generate_entity_embeddings( - self, triplet_maps: List[Dict[str, Any]] + self, triplet_maps: List[Dict[str, Any]] ) -> List[Dict[str, Any]]: """ 生成实体嵌入向量 @@ -861,7 +858,7 @@ class ExtractionOrchestrator: if not self.embedding_id: logger.error("embedding_id is required but was not provided to ExtractionOrchestrator") return triplet_maps - + # 生成实体嵌入 updated_triplet_maps = await generate_entity_embeddings_from_triplets( triplet_maps, self.embedding_id @@ -874,17 +871,15 @@ class ExtractionOrchestrator: logger.error(f"实体嵌入生成失败: {e}", exc_info=True) return triplet_maps - - async def _assign_extracted_data( - self, - dialog_data_list: List[DialogData], - temporal_maps: List[Dict[str, Any]], - triplet_maps: List[Dict[str, Any]], - emotion_maps: List[Dict[str, Any]], - statement_embedding_maps: List[Dict[str, List[float]]], - chunk_embedding_maps: List[Dict[str, List[float]]], - dialog_embeddings: List[List[float]], + self, + dialog_data_list: List[DialogData], + temporal_maps: List[Dict[str, Any]], + triplet_maps: List[Dict[str, Any]], + emotion_maps: List[Dict[str, Any]], + statement_embedding_maps: List[Dict[str, List[float]]], + chunk_embedding_maps: List[Dict[str, List[float]]], + dialog_embeddings: List[List[float]], ) -> List[DialogData]: """ 将提取的数据赋值到语句 @@ -906,12 +901,12 @@ class ExtractionOrchestrator: # 确保列表长度匹配 expected_length = len(dialog_data_list) if ( - len(temporal_maps) != expected_length - or len(triplet_maps) != expected_length - or len(emotion_maps) != expected_length - or len(statement_embedding_maps) != expected_length - or len(chunk_embedding_maps) != expected_length - or len(dialog_embeddings) != expected_length + len(temporal_maps) != expected_length + or len(triplet_maps) != expected_length + or len(emotion_maps) != expected_length + or len(statement_embedding_maps) != expected_length + or len(chunk_embedding_maps) != expected_length + or len(dialog_embeddings) != expected_length ): logger.warning( f"数据大小不匹配 - 对话: {len(dialog_data_list)}, " @@ -999,7 +994,7 @@ class ExtractionOrchestrator: return dialog_data_list async def _create_nodes_and_edges( - self, dialog_data_list: List[DialogData] + self, dialog_data_list: List[DialogData] ) -> Tuple[ List[DialogueNode], List[ChunkNode], @@ -1007,7 +1002,7 @@ class ExtractionOrchestrator: List[ExtractedEntityNode], List[StatementChunkEdge], List[StatementEntityEdge], - List[EntityEntityEdge], + List[EntityEntityEdge] ]: """ 创建图数据库节点和边 @@ -1021,7 +1016,7 @@ class ExtractionOrchestrator: 包含所有节点和边的元组 """ logger.info("开始创建节点和边") - + # 注意:开始消息已在 run 方法中发送,这里不再重复发送 dialogue_nodes = [] @@ -1034,7 +1029,7 @@ class ExtractionOrchestrator: # 用于去重的集合 entity_id_set = set() - + # 用于跟踪进度 total_dialogs = len(dialog_data_list) processed_dialogs = 0 @@ -1083,15 +1078,19 @@ class ExtractionOrchestrator: name=f"Statement_{statement.id}", # 添加必需的 name 字段 chunk_id=chunk.id, stmt_type=getattr(statement, 'stmt_type', 'general'), # 添加必需的 stmt_type 字段 - temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL), # 添加必需的 temporal_info 字段 - connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段 + temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL), + # 添加必需的 temporal_info 字段 + connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong', + # 添加必需的 connect_strength 字段 end_user_id=dialog_data.end_user_id, run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id statement=statement.statement, speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段 statement_embedding=statement.statement_embedding, - valid_at=statement.temporal_validity.valid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None, - invalid_at=statement.temporal_validity.invalid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None, + valid_at=statement.temporal_validity.valid_at if hasattr(statement, + 'temporal_validity') and statement.temporal_validity else None, + invalid_at=statement.temporal_validity.invalid_at if hasattr(statement, + 'temporal_validity') and statement.temporal_validity else None, created_at=dialog_data.created_at, expired_at=dialog_data.expired_at, config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None, @@ -1120,7 +1119,7 @@ class ExtractionOrchestrator: # 创建实体索引到ID的映射(支持多种索引方式) entity_idx_to_id = {} - + # 创建实体节点 for entity_idx, entity in enumerate(triplet_info.entities): # 映射实体索引到实体ID(使用多个键以提高容错性) @@ -1128,7 +1127,7 @@ class ExtractionOrchestrator: entity_idx_to_id[entity.entity_idx] = entity.id # 2. 使用枚举索引(从0开始) entity_idx_to_id[entity_idx] = entity.id - + if entity.id not in entity_id_set: entity_connect_strength = getattr(entity, 'connect_strength', 'Strong') entity_node = ExtractedEntityNode( @@ -1141,7 +1140,8 @@ class ExtractionOrchestrator: example=getattr(entity, 'example', ''), # 新增:传递示例字段 # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 # fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段 - connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段 + connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', + # 添加必需的 connect_strength 字段 aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases name_embedding=getattr(entity, 'name_embedding', None), is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # 新增:传递语义记忆标记 @@ -1171,7 +1171,7 @@ class ExtractionOrchestrator: # 将三元组中的整数索引映射到实体ID subject_entity_id = entity_idx_to_id.get(triplet.subject_id) object_entity_id = entity_idx_to_id.get(triplet.object_id) - + # 只有当两个实体ID都存在时才创建边 if subject_entity_id and object_entity_id: entity_entity_edge = EntityEntityEdge( @@ -1186,7 +1186,7 @@ class ExtractionOrchestrator: expired_at=dialog_data.expired_at, ) entity_entity_edges.append(entity_entity_edge) - + # 流式输出:每创建一个关系边,立即发送进度(限制发送数量) if self.progress_callback and len(entity_entity_edges) <= 10: # 获取实体名称 @@ -1202,8 +1202,8 @@ class ExtractionOrchestrator: "dialog_progress": f"{processed_dialogs}/{total_dialogs}" } await self.progress_callback( - "creating_nodes_edges_result", - f"关系创建中 ({processed_dialogs}/{total_dialogs})", + "creating_nodes_edges_result", + f"关系创建中 ({processed_dialogs}/{total_dialogs})", relationship_result ) else: @@ -1211,7 +1211,7 @@ class ExtractionOrchestrator: missing_subject = "subject" if not subject_entity_id else "" missing_object = "object" if not object_entity_id else "" missing_both = " and " if (not subject_entity_id and not object_entity_id) else "" - + logger.debug( f"跳过三元组 - 无法找到{missing_subject}{missing_both}{missing_object}实体ID: " f"subject_id={triplet.subject_id} ({triplet.subject_name}), " @@ -1228,7 +1228,7 @@ class ExtractionOrchestrator: f"陈述句-实体边: {len(statement_entity_edges)}, " f"实体-实体边: {len(entity_entity_edges)}" ) - + # 进度回调:创建节点和边完成,传递结果统计 # 注意:具体的关系创建结果已经在创建过程中实时发送了 if self.progress_callback: @@ -1254,19 +1254,24 @@ class ExtractionOrchestrator: ) async def _run_dedup_and_write_summary( - self, - dialogue_nodes: List[DialogueNode], - chunk_nodes: List[ChunkNode], - statement_nodes: List[StatementNode], - entity_nodes: List[ExtractedEntityNode], - statement_chunk_edges: List[StatementChunkEdge], - statement_entity_edges: List[StatementEntityEdge], - entity_entity_edges: List[EntityEntityEdge], - dialog_data_list: List[DialogData], - ) -> Tuple[ - Tuple[List[DialogueNode], List[ChunkNode], List[StatementNode]], - Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]], - Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]], + self, + dialogue_nodes: List[DialogueNode], + chunk_nodes: List[ChunkNode], + statement_nodes: List[StatementNode], + entity_nodes: List[ExtractedEntityNode], + statement_chunk_edges: List[StatementChunkEdge], + statement_entity_edges: List[StatementEntityEdge], + entity_entity_edges: List[EntityEntityEdge], + dialog_data_list: List[DialogData], + ) -> tuple[ + list[DialogueNode], + list[ChunkNode], + list[StatementNode], + list[ExtractedEntityNode], + list[StatementChunkEdge], + list[StatementEntityEdge], + list[EntityEntityEdge], + dict ]: """ 执行两阶段去重并写入汇总 @@ -1288,11 +1293,11 @@ class ExtractionOrchestrator: - 第三个元组:去重后的 (实体节点列表, 陈述句-实体边列表, 实体-实体边列表) """ logger.info("开始两阶段实体去重和消歧") - + # 进度回调:发送去重消歧开始消息 if self.progress_callback: await self.progress_callback("deduplication", "正在去重消歧...") - + logger.info( f"去重前: {len(entity_nodes)} 个实体节点, " f"{len(statement_entity_edges)} 条陈述句-实体边, " @@ -1307,7 +1312,7 @@ class ExtractionOrchestrator: from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import ( deduplicate_entities_and_edges, ) - + dedup_entity_nodes, dedup_statement_entity_edges, dedup_entity_entity_edges, dedup_details = await deduplicate_entities_and_edges( entity_nodes, statement_entity_edges, @@ -1317,10 +1322,10 @@ class ExtractionOrchestrator: dedup_config=self.config.deduplication, llm_client=self.llm_client, ) - + # 保存去重消歧的详细记录到实例变量 self._save_dedup_details(dedup_details, entity_nodes, dedup_entity_nodes) - + result_tuple = ( dialogue_nodes, chunk_nodes, @@ -1330,7 +1335,7 @@ class ExtractionOrchestrator: dedup_statement_entity_edges, dedup_entity_entity_edges, ) - + final_entity_nodes = dedup_entity_nodes final_statement_entity_edges = dedup_statement_entity_edges final_entity_entity_edges = dedup_entity_entity_edges @@ -1361,7 +1366,7 @@ class ExtractionOrchestrator: final_entity_entity_edges, dedup_details, ) = result_tuple - + # 保存去重消歧的详细记录到实例变量 self._save_dedup_details(dedup_details, entity_nodes, final_entity_nodes) @@ -1375,12 +1380,12 @@ class ExtractionOrchestrator: f"陈述句-实体边减少 {len(statement_entity_edges) - len(final_statement_entity_edges)}, " f"实体-实体边减少 {len(entity_entity_edges) - len(final_entity_entity_edges)}" ) - + # 流式输出:实时输出去重消歧的具体结果 if self.progress_callback: # 分析实体合并情况(使用内存中的记录) merge_info = await self._analyze_entity_merges(entity_nodes, final_entity_nodes) - + # 逐个输出去重合并的实体示例 for i, merge_detail in enumerate(merge_info[:5]): # 输出前5个去重结果 dedup_result = { @@ -1391,10 +1396,10 @@ class ExtractionOrchestrator: "message": f"{merge_detail['main_entity_name']}合并{merge_detail['merged_count']}个:相似实体已合并" } await self.progress_callback("dedup_disambiguation_result", "实体去重中", dedup_result) - + # 分析实体消歧情况(使用内存中的记录) disamb_info = await self._analyze_entity_disambiguation(entity_nodes, final_entity_nodes) - + # 逐个输出实体消歧的结果 for i, disamb_detail in enumerate(disamb_info[:5]): # 输出前5个消歧结果 disamb_result = { @@ -1407,14 +1412,13 @@ class ExtractionOrchestrator: "message": f"{disamb_detail['entity_name']}消歧完成:{disamb_detail['disamb_type']}" } await self.progress_callback("dedup_disambiguation_result", "实体消歧中", disamb_result) - + # 进度回调:去重消歧完成,传递去重和消歧的具体效果 await self._send_dedup_progress_callback( len(entity_nodes), len(final_entity_nodes), len(statement_entity_edges), len(final_statement_entity_edges), len(entity_entity_edges), len(final_entity_entity_edges) ) - # 写入提取结果汇总(试运行和正式模式都需要生成) try: @@ -1436,10 +1440,10 @@ class ExtractionOrchestrator: raise def _save_dedup_details( - self, - dedup_details: Dict[str, Any], - original_entities: List[ExtractedEntityNode], - final_entities: List[ExtractedEntityNode] + self, + dedup_details: Dict[str, Any], + original_entities: List[ExtractedEntityNode], + final_entities: List[ExtractedEntityNode] ): """ 保存去重消歧的详细记录到实例变量(基于内存数据结构) @@ -1452,7 +1456,7 @@ class ExtractionOrchestrator: try: # 保存ID重定向映射 self.id_redirect_map = dedup_details.get("id_redirect", {}) - + # 处理精确匹配的合并记录 exact_merge_map = dedup_details.get("exact_merge_map", {}) for key, info in exact_merge_map.items(): @@ -1466,7 +1470,7 @@ class ExtractionOrchestrator: "merged_count": len(merged_ids), "merged_ids": list(merged_ids) }) - + # 处理模糊匹配的合并记录 fuzzy_merge_records = dedup_details.get("fuzzy_merge_records", []) for record in fuzzy_merge_records: @@ -1486,7 +1490,7 @@ class ExtractionOrchestrator: }) except Exception as e: logger.debug(f"解析模糊匹配记录失败: {record}, 错误: {e}") - + # 处理LLM去重的合并记录 llm_decision_records = dedup_details.get("llm_decision_records", []) for record in llm_decision_records: @@ -1505,7 +1509,7 @@ class ExtractionOrchestrator: }) except Exception as e: logger.debug(f"解析LLM去重记录失败: {record}, 错误: {e}") - + # 处理消歧记录 disamb_records = dedup_details.get("disamb_records", []) for record in disamb_records: @@ -1520,14 +1524,14 @@ class ExtractionOrchestrator: entity1_type = match.group(2) match.group(3).strip() entity2_type = match.group(4) - + # 提取置信度和原因 conf_match = re.search(r"conf=([0-9.]+)", str(record)) confidence = conf_match.group(1) if conf_match else "unknown" - + reason_match = re.search(r"reason=([^|]+)", str(record)) reason = reason_match.group(1).strip() if reason_match else "" - + self.dedup_disamb_records.append({ "entity_name": entity1_name, "disamb_type": f"消歧阻断:{entity1_type} vs {entity2_type}", @@ -1536,16 +1540,17 @@ class ExtractionOrchestrator: }) except Exception as e: logger.debug(f"解析消歧记录失败: {record}, 错误: {e}") - - logger.info(f"保存去重消歧记录:{len(self.dedup_merge_records)} 个合并记录,{len(self.dedup_disamb_records)} 个消歧记录") - + + logger.info( + f"保存去重消歧记录:{len(self.dedup_merge_records)} 个合并记录,{len(self.dedup_disamb_records)} 个消歧记录") + except Exception as e: logger.error(f"保存去重消歧详情失败: {e}", exc_info=True) async def _analyze_entity_merges( - self, - original_entities: List[ExtractedEntityNode], - final_entities: List[ExtractedEntityNode] + self, + original_entities: List[ExtractedEntityNode], + final_entities: List[ExtractedEntityNode] ) -> List[Dict[str, Any]]: """ 分析实体合并情况,直接使用内存中的合并记录(不再解析日志文件) @@ -1566,28 +1571,28 @@ class ExtractionOrchestrator: key=lambda x: x.get("merged_count", 0), reverse=True ) - + merge_info = [] for record in sorted_records: merge_info.append({ "main_entity_name": record.get("entity_name", "未知实体"), "merged_count": record.get("merged_count", 1) }) - + return merge_info - + # 如果没有保存的记录,返回空列表 logger.info("未找到实体合并记录") return [] - + except Exception as e: logger.error(f"分析实体合并情况失败: {e}", exc_info=True) return [] async def _analyze_entity_disambiguation( - self, - original_entities: List[ExtractedEntityNode], - final_entities: List[ExtractedEntityNode] + self, + original_entities: List[ExtractedEntityNode], + final_entities: List[ExtractedEntityNode] ) -> List[Dict[str, Any]]: """ 分析实体消歧情况,直接使用内存中的消歧记录(不再解析日志文件) @@ -1603,11 +1608,11 @@ class ExtractionOrchestrator: # 直接使用保存的消歧记录 if self.dedup_disamb_records: return self.dedup_disamb_records - + # 如果没有保存的记录,返回空列表 logger.info("未找到实体消歧记录") return [] - + except Exception as e: logger.error(f"分析实体消歧情况失败: {e}", exc_info=True) return [] @@ -1624,7 +1629,7 @@ class ExtractionOrchestrator: """ type_mapping = { "Person": "人物实体节点", - "Organization": "组织实体节点", + "Organization": "组织实体节点", "ORG": "组织实体节点", "Location": "地点实体节点", "LOC": "地点实体节点", @@ -1645,9 +1650,9 @@ class ExtractionOrchestrator: return type_mapping.get(entity_type, f"{entity_type}实体节点") async def _output_relationship_creation_results( - self, - entity_entity_edges: List[EntityEntityEdge], - entity_nodes: List[ExtractedEntityNode] + self, + entity_entity_edges: List[EntityEntityEdge], + entity_nodes: List[ExtractedEntityNode] ): """ 输出关系创建结果 @@ -1659,13 +1664,13 @@ class ExtractionOrchestrator: try: # 创建实体ID到名称的映射 entity_id_to_name = {node.id: node.name for node in entity_nodes} - + # 输出关系创建结果 for i, edge in enumerate(entity_entity_edges[:10]): # 只输出前10个关系 source_name = entity_id_to_name.get(edge.source, f"Entity_{edge.source}") target_name = entity_id_to_name.get(edge.target, f"Entity_{edge.target}") relation_type = edge.relation_type - + relationship_result = { "result_type": "relationship_creation", "relationship_index": i + 1, @@ -1674,20 +1679,20 @@ class ExtractionOrchestrator: "target_entity": target_name, "relationship_text": f"{source_name} -[{relation_type}]-> {target_name}" } - + await self.progress_callback("creating_nodes_edges_result", "关系创建", relationship_result) - + except Exception as e: logger.error(f"输出关系创建结果失败: {e}", exc_info=True) async def _send_dedup_progress_callback( - self, - original_entities: int, - final_entities: int, - original_stmt_edges: int, - final_stmt_edges: int, - original_ent_edges: int, - final_ent_edges: int, + self, + original_entities: int, + final_entities: int, + original_stmt_edges: int, + final_stmt_edges: int, + original_ent_edges: int, + final_ent_edges: int, ): """ 发送去重消歧完成的进度回调,传递具体的去重和消歧效果 @@ -1703,19 +1708,20 @@ class ExtractionOrchestrator: try: # 解析去重消歧报告文件,获取具体的去重和消歧效果 dedup_details = await self._parse_dedup_report() - + # 计算去重效果统计 entities_reduced = original_entities - final_entities stmt_edges_reduced = original_stmt_edges - final_stmt_edges ent_edges_reduced = original_ent_edges - final_ent_edges - + # 构建进度回调数据 dedup_stats = { "entities": { "original_count": original_entities, "final_count": final_entities, "reduced_count": entities_reduced, - "reduction_rate": round(entities_reduced / original_entities * 100, 1) if original_entities > 0 else 0, + "reduction_rate": round(entities_reduced / original_entities * 100, + 1) if original_entities > 0 else 0, }, "statement_entity_edges": { "original_count": original_stmt_edges, @@ -1734,9 +1740,9 @@ class ExtractionOrchestrator: "total_disambiguations": dedup_details.get("total_disambiguations", 0), } } - + await self.progress_callback("dedup_disambiguation_complete", "去重消歧完成", dedup_stats) - + except Exception as e: logger.error(f"发送去重消歧进度回调失败: {e}", exc_info=True) # 即使解析失败,也发送基本的统计信息 @@ -1766,12 +1772,12 @@ class ExtractionOrchestrator: disamb_examples = [] total_merges = 0 total_disambiguations = 0 - + # 处理合并记录 for record in self.dedup_merge_records: merge_count = record.get("merged_count", 0) total_merges += merge_count - + dedup_examples.append({ "type": record.get("type", "未知"), "entity_name": record.get("entity_name", "未知实体"), @@ -1779,30 +1785,31 @@ class ExtractionOrchestrator: "merge_count": merge_count, "description": f"{record.get('entity_name', '未知实体')}实体去重合并{merge_count}个" }) - + # 处理消歧记录 for record in self.dedup_disamb_records: total_disambiguations += 1 - + # 从消歧类型中提取实体类型信息 disamb_type = record.get("disamb_type", "") entity_name = record.get("entity_name", "未知实体") - + disamb_examples.append({ "entity1_name": entity_name, - "entity1_type": disamb_type.split("vs")[0].replace("消歧阻断:", "").strip() if "vs" in disamb_type else "未知", + "entity1_type": disamb_type.split("vs")[0].replace("消歧阻断:", + "").strip() if "vs" in disamb_type else "未知", "entity2_name": entity_name, "entity2_type": disamb_type.split("vs")[1].strip() if "vs" in disamb_type else "未知", "description": f"{entity_name},消歧区分成功" }) - + return { "dedup_examples": dedup_examples[:5], # 只返回前5个示例 "disamb_examples": disamb_examples[:5], # 只返回前5个示例 "total_merges": total_merges, "total_disambiguations": total_disambiguations, } - + except Exception as e: logger.error(f"获取去重报告失败: {e}", exc_info=True) return {"dedup_examples": [], "disamb_examples": [], "total_merges": 0, "total_disambiguations": 0} @@ -1815,9 +1822,9 @@ class ExtractionOrchestrator: async def get_chunked_dialogs( - chunker_strategy: str = "RecursiveChunker", - end_user_id: str = "group_1", - indices: Optional[List[int]] = None, + chunker_strategy: str = "RecursiveChunker", + end_user_id: str = "group_1", + indices: Optional[List[int]] = None, ) -> List[DialogData]: """从测试数据生成分块对话 @@ -1831,7 +1838,7 @@ async def get_chunked_dialogs( """ import json import re - + # 加载测试数据 testdata_path = os.path.join(os.path.dirname(__file__), "../../data", "testdata.json") with open(testdata_path, "r", encoding="utf-8") as f: @@ -1845,7 +1852,7 @@ async def get_chunked_dialogs( else: # 默认使用所有数据 selected_data = test_data - + for data in selected_data: # 解析对话上下文 context_text = data["context"] @@ -1861,7 +1868,7 @@ async def get_chunked_dialogs( if m: y, mo, d = int(m.group(1)), int(m.group(2)), int(m.group(3)) conv_date = f"{y:04d}-{mo:02d}-{d:02d}" - + dialog_metadata: Dict[str, Any] = {} if conv_date: dialog_metadata["conversation_date"] = conv_date @@ -1890,7 +1897,7 @@ async def get_chunked_dialogs( end_user_id=end_user_id, metadata=dialog_metadata, ) - + # 创建分块器并处理对话 from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import ( DialogueChunker, @@ -1913,7 +1920,7 @@ async def get_chunked_dialogs( from app.core.config import settings settings.ensure_memory_output_dir() output_path = settings.get_memory_output_path("chunker_test_output.txt") - + import json with open(output_path, "w", encoding="utf-8") as f: json.dump( @@ -1924,10 +1931,10 @@ async def get_chunked_dialogs( def preprocess_data( - input_path: Optional[str] = None, - output_path: Optional[str] = None, - skip_cleaning: bool = True, - indices: Optional[List[int]] = None + input_path: Optional[str] = None, + output_path: Optional[str] = None, + skip_cleaning: bool = True, + indices: Optional[List[int]] = None ) -> List[DialogData]: """数据预处理 @@ -1946,7 +1953,8 @@ def preprocess_data( ) preprocessor = DataPreprocessor() try: - cleaned_data = preprocessor.preprocess(input_path=input_path, output_path=output_path, skip_cleaning=skip_cleaning, indices=indices) + cleaned_data = preprocessor.preprocess(input_path=input_path, output_path=output_path, + skip_cleaning=skip_cleaning, indices=indices) logger.debug(f"数据预处理完成!共处理了 {len(cleaned_data)} 条对话数据") return cleaned_data except Exception as e: @@ -1955,9 +1963,9 @@ def preprocess_data( async def get_chunked_dialogs_from_preprocessed( - data: List[DialogData], - chunker_strategy: str = "RecursiveChunker", - llm_client: Optional[Any] = None, + data: List[DialogData], + chunker_strategy: str = "RecursiveChunker", + llm_client: Optional[Any] = None, ) -> List[DialogData]: """从预处理后的数据中生成分块 @@ -1972,31 +1980,31 @@ async def get_chunked_dialogs_from_preprocessed( logger.debug(f"=== 批量对话分块处理 (使用 {chunker_strategy}) ===") if not data: raise ValueError("预处理数据为空,无法进行分块") - + all_chunked_dialogs: List[DialogData] = [] from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import ( DialogueChunker, ) - + for dialog_data in data: chunker = DialogueChunker(chunker_strategy, llm_client=llm_client) chunks = await chunker.process_dialogue(dialog_data) dialog_data.chunks = chunks all_chunked_dialogs.append(dialog_data) - + return all_chunked_dialogs async def get_chunked_dialogs_with_preprocessing( - chunker_strategy: str = "RecursiveChunker", - end_user_id: str = "default", - user_id: str = "default", - apply_id: str = "default", - indices: Optional[List[int]] = None, - input_data_path: Optional[str] = None, - llm_client: Optional[Any] = None, - skip_cleaning: bool = True, - pruning_config: Optional[Dict] = None, + chunker_strategy: str = "RecursiveChunker", + end_user_id: str = "default", + user_id: str = "default", + apply_id: str = "default", + indices: Optional[List[int]] = None, + input_data_path: Optional[str] = None, + llm_client: Optional[Any] = None, + skip_cleaning: bool = True, + pruning_config: Optional[Dict] = None, ) -> List[DialogData]: """包含数据预处理步骤的完整分块流程 @@ -2020,7 +2028,7 @@ async def get_chunked_dialogs_with_preprocessing( input_data_path = os.path.join( os.path.dirname(__file__), "../../data", "testdata.json" ) - + # 步骤1: 数据预处理(包含索引筛选) from app.core.config import settings settings.ensure_memory_output_dir() @@ -2030,37 +2038,38 @@ async def get_chunked_dialogs_with_preprocessing( skip_cleaning=skip_cleaning, indices=indices, ) - + # 设置 end_user_id for dd in preprocessed_data: dd.end_user_id = end_user_id - + # 步骤2: 语义剪枝 try: from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import ( SemanticPruner, ) from app.core.memory.models.config_models import PruningConfig - + # 构建剪枝配置 if pruning_config: # 使用传入的配置 config = PruningConfig(**pruning_config) - logger.debug(f"[剪枝] 使用传入配置: switch={config.pruning_switch}, scene={config.pruning_scene}, threshold={config.pruning_threshold}") + logger.debug( + f"[剪枝] 使用传入配置: switch={config.pruning_switch}, scene={config.pruning_scene}, threshold={config.pruning_threshold}") else: # 使用默认配置(关闭剪枝) config = None logger.debug("[剪枝] 未提供配置,使用默认配置(剪枝关闭)") - + pruner = SemanticPruner(config=config, llm_client=llm_client) - + # 记录单对话场景下剪枝前的消息数量 single_dialog_original_msgs = None if len(preprocessed_data) == 1 and preprocessed_data[0].context: single_dialog_original_msgs = len(preprocessed_data[0].context.msgs) preprocessed_data = await pruner.prune_dataset(preprocessed_data) - + # 单对话:打印清洗与剪枝信息 if len(preprocessed_data) == 1 and single_dialog_original_msgs is not None: remaining_msgs = len(preprocessed_data[0].context.msgs) if preprocessed_data[0].context else 0 @@ -2071,7 +2080,7 @@ async def get_chunked_dialogs_with_preprocessing( ) else: logger.debug(f"语义剪枝完成!剩余 {len(preprocessed_data)} 条对话") - + # 保存剪枝后的数据 try: from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_preprocessor import ( @@ -2084,7 +2093,7 @@ async def get_chunked_dialogs_with_preprocessing( logger.error(f"保存剪枝结果失败:{se}") except Exception as e: logger.error(f"语义剪枝过程中出现错误,跳过剪枝: {e}") - + # 步骤3: 对话分块 return await get_chunked_dialogs_from_preprocessed( preprocessed_data, diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py index 443ee36a..5e39ba36 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py @@ -188,7 +188,6 @@ async def _process_chunk_summary( response_model=MemorySummaryResponse, ) summary_text = structured.summary.strip() - # Generate title and type for the summary title = None episodic_type = None diff --git a/api/app/core/workflow/engine/variable_pool.py b/api/app/core/workflow/engine/variable_pool.py index d4e1b488..60f1257e 100644 --- a/api/app/core/workflow/engine/variable_pool.py +++ b/api/app/core/workflow/engine/variable_pool.py @@ -374,7 +374,9 @@ class VariablePool: self.variables = deepcopy(pool.variables) def is_file_variable(self, selector): - variable_struct = self._get_variable_struct(selector) + variable_struct = self.get_instance(selector, default=None, strict=False) + if variable_struct is None: + return False if isinstance(variable_struct, FileVariable): return True elif isinstance(variable_struct, ArrayVariable) and variable_struct.child_type == FileVariable: diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index 0e3fecee..7f2b8aa6 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -623,7 +623,6 @@ class BaseNode(ABC): async def process_message( api_config: ModelInfo, content: str | dict | FileObject, - end_user_id: str, enable_file=False ) -> list | str | None: provider = api_config.provider @@ -642,8 +641,8 @@ class BaseNode(ABC): return content elif isinstance(content, FileObject): - if content.content_cache.get(provider): - return content.content_cache[provider] + if content.content_cache.get(f"{provider}_{ModelInfo.is_omni}"): + return content.content_cache[f"{provider}_{ModelInfo.is_omni}"] with get_db_read() as db: multimodel_service = MultimodalService(db, api_config=api_config) file_obj = FileInput( @@ -655,12 +654,11 @@ class BaseNode(ABC): ) file_obj.set_content(content.get_content()) message = await multimodel_service.process_files( - end_user_id, [file_obj], ) content.set_content(file_obj.get_content()) if message: - content.content_cache[provider] = message + content.content_cache[f"{provider}_{ModelInfo.is_omni}"] = message return message return None raise TypeError(f'Unexpect input value type - {type(content)}') diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index b293d1f4..66a0f1ac 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -144,7 +144,6 @@ class LLMNode(BaseNode): f"创建 LLM 实例: provider={model_info.provider}, model={model_info.model_name}, streaming={stream}") messages_config = self.typed_config.messages - if messages_config: # 使用 LangChain 消息格式 messages = [] @@ -153,7 +152,6 @@ class LLMNode(BaseNode): content_template = msg_config.content content_template = self._render_context(content_template, variable_pool) content = self._render_template(content_template, variable_pool) - user_id = self.get_variable("sys.user_id", variable_pool) # 根据角色创建对应的消息对象 if role == "system": messages.append({ @@ -161,32 +159,31 @@ class LLMNode(BaseNode): "content": await self.process_message( model_info, content, - user_id, self.typed_config.vision, ) }) elif role in ["user", "human"]: messages.append({ "role": "user", - "content": await self.process_message(model_info, content, user_id, self.typed_config.vision) + "content": await self.process_message(model_info, content, self.typed_config.vision) }) elif role in ["ai", "assistant"]: messages.append({ "role": "assistant", - "content": await self.process_message(model_info, content, user_id, self.typed_config.vision) + "content": await self.process_message(model_info, content, self.typed_config.vision) }) else: logger.warning(f"未知的消息角色: {role},默认使用 user") messages.append({ "role": "user", - "content": await self.process_message(model_info, content, user_id, self.typed_config.vision) + "content": await self.process_message(model_info, content, self.typed_config.vision) }) if self.typed_config.vision_input and self.typed_config.vision: file_content = [] files = variable_pool.get_instance(self.typed_config.vision_input) for file in files.value: - content = await self.process_message(model_info, file.value, user_id, self.typed_config.vision) + content = await self.process_message(model_info, file.value, self.typed_config.vision) if content: file_content.extend(content) if messages and messages[-1]["role"] == 'user': @@ -200,7 +197,7 @@ class LLMNode(BaseNode): if isinstance(message["content"], list): file_content = [] for file in message["content"]: - content = await self.process_message(model_info, file, user_id, self.typed_config.vision) + content = await self.process_message(model_info, file, self.typed_config.vision) if content: file_content.extend(content) history_message.append( @@ -210,7 +207,6 @@ class LLMNode(BaseNode): message["content"] = await self.process_message( model_info, message["content"], - user_id, self.typed_config.vision ) history_message.append(message) diff --git a/api/app/core/workflow/nodes/memory/node.py b/api/app/core/workflow/nodes/memory/node.py index 82363056..cbdad0fa 100644 --- a/api/app/core/workflow/nodes/memory/node.py +++ b/api/app/core/workflow/nodes/memory/node.py @@ -116,6 +116,7 @@ class MemoryWriteNode(BaseNode): write_message_task.delay( end_user_id=end_user_id, message=messages, + file_messages=multimodal_memories, config_id=str(self.typed_config.config_id), storage_type=state["memory_storage_type"], user_rag_memory_id=state["user_rag_memory_id"] diff --git a/api/app/models/memory_config_model.py b/api/app/models/memory_config_model.py index 1095a386..616f7f3a 100644 --- a/api/app/models/memory_config_model.py +++ b/api/app/models/memory_config_model.py @@ -30,6 +30,9 @@ class MemoryConfig(Base): llm_id = Column(String, nullable=True, comment="LLM模型配置ID") embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID") rerank_id = Column(String, nullable=True, comment="重排序模型配置ID") + vision_id = Column(String, nullable=True, comment="视觉模型配置ID") + audio_id = Column(String, nullable=True, comment="语音模型配置ID") + video_id = Column(String, nullable=True, comment="视频模型配置ID") # 记忆萃取引擎配置 enable_llm_dedup_blockwise = Column(Boolean, default=True, comment="启用LLM决策去重") diff --git a/api/app/repositories/memory_config_repository.py b/api/app/repositories/memory_config_repository.py index 5c2f81a7..6fb41914 100644 --- a/api/app/repositories/memory_config_repository.py +++ b/api/app/repositories/memory_config_repository.py @@ -9,21 +9,22 @@ Classes: """ import uuid -from uuid import UUID from typing import Dict, List, Optional, Tuple +from uuid import UUID + +from sqlalchemy import desc, select +from sqlalchemy.orm import Session + from app.core.exceptions import BusinessException from app.core.logging_config import get_config_logger, get_db_logger from app.models.memory_config_model import MemoryConfig +from app.models.workspace_model import Workspace from app.schemas.memory_storage_schema import ( - ConfigKey, ConfigParamsCreate, ConfigUpdate, ConfigUpdateExtracted, ConfigUpdateForget, ) -from sqlalchemy import desc, select -from sqlalchemy.orm import Session - from app.utils.config_utils import resolve_config_id # 获取数据库专用日志器 @@ -157,7 +158,7 @@ class MemoryConfigRepository: return memory_config_obj @staticmethod - def query_reflection_config_by_id(db: Session, config_id: uuid.UUID|int|str) -> MemoryConfig: + def query_reflection_config_by_id(db: Session, config_id: uuid.UUID | int | str) -> MemoryConfig: """构建反思配置查询语句,通过config_id查询反思配置(SQLAlchemy text() 命名参数) Args: @@ -491,7 +492,10 @@ class MemoryConfigRepository: raise @staticmethod - def get_config_with_workspace(db: Session, config_id: uuid.UUID | int | str) -> Optional[tuple]: + def get_config_with_workspace( + db: Session, + config_id: uuid.UUID | int | str + ) -> Optional[tuple[MemoryConfig, Workspace]]: """Get memory config and its associated workspace information Args: @@ -506,8 +510,6 @@ class MemoryConfigRepository: """ import time - from app.models.workspace_model import Workspace - start_time = time.time() config_id = resolve_config_id(config_id, db) @@ -594,7 +596,7 @@ class MemoryConfigRepository: db_logger.debug( f"Memory config and workspace query successful: config={config.config_name}, workspace={workspace.name}") - return (config, workspace) + return config, workspace except ValueError: # Re-raise known business exceptions @@ -630,7 +632,7 @@ class MemoryConfigRepository: List[Tuple[MemoryConfig, Optional[str]]]: 配置列表,每项为 (配置对象, 场景名称) """ from app.models.ontology_scene import OntologyScene - + db_logger.debug(f"查询所有配置: workspace_id={workspace_id}") try: @@ -694,7 +696,7 @@ class MemoryConfigRepository: Optional[MemoryConfig]: 默认配置对象,不存在则返回None """ db_logger.debug(f"查询工作空间默认配置: workspace_id={workspace_id}") - + try: # 优先查找显式标记为默认的配置 stmt = ( @@ -706,13 +708,13 @@ class MemoryConfigRepository: ) .limit(1) ) - + config = db.scalars(stmt).first() - + if config: db_logger.debug(f"找到默认配置: config_id={config.config_id}") return config - + # 回退:获取最早创建的活跃配置 stmt = ( select(MemoryConfig) @@ -723,25 +725,25 @@ class MemoryConfigRepository: .order_by(MemoryConfig.created_at.asc()) .limit(1) ) - + config = db.scalars(stmt).first() - + if config: db_logger.debug(f"使用最早创建的配置作为默认: config_id={config.config_id}") else: db_logger.warning(f"工作空间没有活跃的记忆配置: workspace_id={workspace_id}") - + return config - + except Exception as e: db_logger.error(f"查询工作空间默认配置失败: workspace_id={workspace_id} - {str(e)}") raise @staticmethod def get_with_fallback( - db: Session, - config_id: Optional[uuid.UUID], - workspace_id: uuid.UUID + db: Session, + config_id: Optional[uuid.UUID], + workspace_id: uuid.UUID ) -> Optional[MemoryConfig]: """获取记忆配置,支持回退到工作空间默认配置 @@ -756,19 +758,18 @@ class MemoryConfigRepository: Optional[MemoryConfig]: 配置对象,如果都不存在则返回None """ db_logger.debug(f"查询配置(支持回退): config_id={config_id}, workspace_id={workspace_id}") - + if not config_id: db_logger.debug("config_id 为空,使用工作空间默认配置") return MemoryConfigRepository.get_workspace_default(db, workspace_id) - + config = db.get(MemoryConfig, config_id) - + if config: return config - + db_logger.warning( f"配置不存在,回退到工作空间默认配置: missing_config_id={config_id}, workspace_id={workspace_id}" ) - - return MemoryConfigRepository.get_workspace_default(db, workspace_id) + return MemoryConfigRepository.get_workspace_default(db, workspace_id) diff --git a/api/app/repositories/neo4j/add_nodes.py b/api/app/repositories/neo4j/add_nodes.py index 42c178b3..3a017089 100644 --- a/api/app/repositories/neo4j/add_nodes.py +++ b/api/app/repositories/neo4j/add_nodes.py @@ -1,7 +1,8 @@ from typing import List, Optional -from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE,MEMORY_SUMMARY_NODE_SAVE from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode +from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE, \ + MEMORY_SUMMARY_NODE_SAVE, PERCEPTUAL_NODE_SAVE, PERCEPTUAL_DIALOGUE_EDGE_SAVE # 使用新的仓储层 from app.repositories.neo4j.neo4j_connector import Neo4jConnector @@ -12,6 +13,7 @@ async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector): print(f"All end_user_id: {end_user_id} node and edge deleted successfully") return result + async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]: """Add dialogue nodes to Neo4j database. @@ -127,6 +129,7 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC print(f"Error creating statement nodes: {e}") return None + async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> Optional[List[str]]: """Add chunk nodes to Neo4j in batch. @@ -179,8 +182,8 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> return None - -async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector: Neo4jConnector) -> Optional[List[str]]: +async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector: Neo4jConnector) -> Optional[ + List[str]]: """Add memory summary nodes to Neo4j in batch. Args: @@ -211,7 +214,7 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector "summary_embedding": s.summary_embedding if s.summary_embedding else None, "config_id": s.config_id, # 添加 config_id }) - + result = await connector.execute_query( MEMORY_SUMMARY_NODE_SAVE, summaries=flattened @@ -224,3 +227,103 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector return None +async def add_perceptual_nodes( + perceptuals: list, + connector: Neo4jConnector, + embedder_client=None, +) -> Optional[List[str]]: + """Add perceptual memory nodes to Neo4j in batch. + + Args: + perceptuals: List of MemoryPerceptualModel objects from PostgreSQL + connector: Neo4j connector instance + embedder_client: Optional embedder client for generating summary embeddings + + Returns: + List of created node UUIDs or None if failed + """ + if not perceptuals: + print("No perceptual nodes to add") + return [] + + try: + flattened = [] + for p in perceptuals: + meta = p.meta_data or {} + content_meta = meta.get("content", {}) + + # 生成 summary embedding(如果有 embedder_client) + summary_embedding = None + if embedder_client and p.summary: + try: + summary_embedding = (await embedder_client.response([p.summary]))[0] + except Exception as emb_err: + print(f"Failed to embed perceptual summary: {emb_err}") + + flattened.append({ + "id": str(p.id), + "end_user_id": str(p.end_user_id), + "perceptual_type": p.perceptual_type, + "file_path": p.file_path or "", + "file_name": p.file_name or "", + "file_ext": p.file_ext or "", + "summary": p.summary or "", + "keywords": content_meta.get("keywords", []), + "topic": content_meta.get("topic", ""), + "domain": content_meta.get("domain", ""), + "created_at": p.created_time.isoformat() if p.created_time else None, + "summary_embedding": summary_embedding, + }) + + result = await connector.execute_query( + PERCEPTUAL_NODE_SAVE, + perceptuals=flattened, + ) + created_uuids = [record.get("uuid") for record in result] + print(f"Successfully saved {len(created_uuids)} Perceptual nodes to Neo4j") + return created_uuids + + except Exception as e: + print(f"Failed to save Perceptual nodes to Neo4j: {e}") + return None + + +async def add_perceptual_dialogue_edges( + perceptuals: list, + dialog_id: str, + connector: Neo4jConnector, +) -> Optional[List[str]]: + """Add edges between Perceptual nodes and Dialogue nodes. + + Args: + perceptuals: List of MemoryPerceptualModel objects + dialog_id: The dialogue ID (or ref_id) to link to + connector: Neo4j connector instance + + Returns: + List of created edge element IDs or None if failed + """ + if not perceptuals or not dialog_id: + return [] + + try: + edges = [] + for p in perceptuals: + edges.append({ + "perceptual_id": str(p.id), + "dialog_id": dialog_id, + "end_user_id": str(p.end_user_id), + "created_at": p.created_time.isoformat() if p.created_time else None, + }) + + result = await connector.execute_query( + PERCEPTUAL_DIALOGUE_EDGE_SAVE, + edges=edges, + ) + created_ids = [record.get("uuid") for record in result] + print(f"Successfully saved {len(created_ids)} Perceptual-Dialogue edges to Neo4j") + return created_ids + + except Exception as e: + print(f"Failed to save Perceptual-Dialogue edges: {e}") + return None diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index 0ac7dcb1..49dbe2a5 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -1323,3 +1323,36 @@ RETURN s.statement AS statement, ORDER BY COALESCE(s.activation_value, 0) DESC LIMIT $limit """ + +# 感知记忆节点保存 +PERCEPTUAL_NODE_SAVE = """ +UNWIND $perceptuals AS p +MERGE (n:Perceptual {id: p.id}) +SET n += { + id: p.id, + end_user_id: p.end_user_id, + perceptual_type: p.perceptual_type, + file_path: p.file_path, + file_name: p.file_name, + file_ext: p.file_ext, + summary: p.summary, + keywords: p.keywords, + topic: p.topic, + domain: p.domain, + created_at: p.created_at, + summary_embedding: p.summary_embedding +} +RETURN n.id AS uuid +""" + +# 感知记忆与对话的关联边 +PERCEPTUAL_DIALOGUE_EDGE_SAVE = """ +UNWIND $edges AS edge +MATCH (p:Perceptual {id: edge.perceptual_id, end_user_id: edge.end_user_id}) +MATCH (d:Dialogue {end_user_id: edge.end_user_id}) +WHERE d.id = edge.dialog_id OR d.ref_id = edge.dialog_id +MERGE (d)-[r:HAS_PERCEPTUAL]->(p) +SET r.end_user_id = edge.end_user_id, + r.created_at = edge.created_at +RETURN elementId(r) AS uuid +""" diff --git a/api/app/schemas/memory_config_schema.py b/api/app/schemas/memory_config_schema.py index 8d7490fe..e186e54b 100644 --- a/api/app/schemas/memory_config_schema.py +++ b/api/app/schemas/memory_config_schema.py @@ -387,6 +387,12 @@ class MemoryConfig: rerank_model_id: Optional[UUID] = None rerank_model_name: Optional[str] = None + video_model_id: Optional[UUID] = None + video_model_name: Optional[str] = None + vision_model_id: Optional[UUID] = None + vision_model_name: Optional[str] = None + audio_model_id: Optional[UUID] = None + audio_model_name: Optional[str] = None llm_params: Dict[str, Any] = field(default_factory=dict) embedding_params: Dict[str, Any] = field(default_factory=dict) diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index 604514b4..98f93408 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -141,7 +141,7 @@ class AppChatService: model_type=ModelType.LLM ) multimodal_service = MultimodalService(self.db, model_info) - processed_files = await multimodal_service.process_files(user_id, files) + processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件") # 调用 Agent(支持多模态) @@ -339,7 +339,7 @@ class AppChatService: model_type=ModelType.LLM ) multimodal_service = MultimodalService(self.db, model_info) - processed_files = await multimodal_service.process_files(user_id, files) + processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件") # 流式调用 Agent(支持多模态),同时并行启动 TTS diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index ba41d323..f7331851 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -600,7 +600,7 @@ class AgentRunService: ) provider = api_key_config.get("provider", "openai") multimodal_service = MultimodalService(self.db, model_info) - processed_files = await multimodal_service.process_files(user_id, files) + processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}") # 7. 知识库检索 @@ -836,7 +836,7 @@ class AgentRunService: ) provider = api_key_config.get("provider", "openai") multimodal_service = MultimodalService(self.db, model_info) - processed_files = await multimodal_service.process_files(user_id, files) + processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}") # 7. 知识库检索 diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index 514cb12f..875f02bb 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -19,32 +19,35 @@ from typing import Any, AsyncGenerator, Dict, List, Optional from uuid import UUID import redis -from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.messages import HumanMessage from pydantic import BaseModel, Field from sqlalchemy import func from sqlalchemy.orm import Session +from app.cache import InterestMemoryCache from app.core.config import settings from app.core.logging_config import get_config_logger, get_logger from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph -from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph from app.core.memory.agent.logger_file.log_streamer import LogStreamer from app.core.memory.agent.utils.messages_tools import ( merge_multiple_search_results, reorder_output_results, ) from app.core.memory.agent.utils.type_classifier import status_typle +from app.core.memory.agent.utils.write_tools import write as write_neo4j from app.core.memory.analytics.hot_memory_tags import get_interest_distribution from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context from app.models.knowledge_model import Knowledge, KnowledgeType from app.repositories.neo4j.neo4j_connector import Neo4jConnector +from app.schemas import FileInput from app.schemas.memory_agent_schema import Write_UserInput from app.schemas.memory_config_schema import ConfigurationError from app.services.memory_config_service import MemoryConfigService from app.services.memory_konwledges_server import ( write_rag, ) +from app.services.memory_perceptual_service import MemoryPerceptualService try: from app.core.memory.utils.log.audit_logger import audit_logger @@ -271,6 +274,7 @@ class MemoryAgentService: self, end_user_id: str, messages: list[dict], + file_messages: list[dict], config_id: Optional[uuid.UUID] | int, db: Session, storage_type: str, @@ -283,6 +287,7 @@ class MemoryAgentService: Args: end_user_id: Group identifier (also used as end_user_id) messages: Message to write + files: Files to write config_id: Configuration ID from database db: SQLAlchemy database session storage_type: Storage type (neo4j or rag) @@ -342,48 +347,52 @@ class MemoryAgentService: raise ValueError(error_msg) + perceptual_serivce = MemoryPerceptualService(db) + file_content = [] + for message in file_messages: + for file in message["files"]: + file_object = await perceptual_serivce.generate_perceptual_memory( + end_user_id=end_user_id, + memory_config=memory_config, + file=FileInput(**file) + ) + file_content.append(file_object) + + message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) try: if storage_type == "rag": # For RAG storage, convert messages to single string - message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) await write_rag(end_user_id, message_text, user_rag_memory_id) return "success" else: - async with make_write_graph() as graph: - config = {"configurable": {"thread_id": end_user_id}} - # Convert structured messages to LangChain messages - langchain_messages = [] - for msg in messages: - if msg['role'] == 'user': - langchain_messages.append(HumanMessage(content=msg['content'])) - elif msg['role'] == 'assistant': - langchain_messages.append(AIMessage(content=msg['content'])) - print(100 * '-') - print(langchain_messages) - print(100 * '-') - # 初始状态 - 包含所有必要字段 - initial_state = { - "messages": langchain_messages, - "end_user_id": end_user_id, - "memory_config": memory_config, - "language": language - } - - # 获取节点更新信息 - async for update_event in graph.astream( - initial_state, - stream_mode="updates", - config=config - ): - for node_name, node_data in update_event.items(): - if 'save_neo4j' == node_name: - massages = node_data - massagesstatus = massages.get('write_result')['status'] - contents = massages.get('write_result') - # Convert messages back to string for logging - message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) - return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, - contents) + await write_neo4j( + end_user_id=end_user_id, + messages=messages, + file_content=file_content, + memory_config=memory_config, + ref_id='', + language=language + ) + for lang in ["zh", "en"]: + deleted = await InterestMemoryCache.delete_interest_distribution( + end_user_id, lang + ) + if deleted: + logger.info( + f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}") + return self.writer_messages_deal( + "success", + start_time, + end_user_id, + config_id, + message_text, + { + "status": "success", + "data": messages, + "config_id": memory_config.config_id, + "config_name": memory_config.config_name + } + ) except Exception as e: # Ensure proper error handling and logging error_msg = f"Write operation failed: {str(e)}" diff --git a/api/app/services/memory_api_service.py b/api/app/services/memory_api_service.py index 01bc6267..9a0fb8ed 100644 --- a/api/app/services/memory_api_service.py +++ b/api/app/services/memory_api_service.py @@ -28,7 +28,7 @@ class MemoryAPIService: 2. Maps end_user_id to end_user_id for memory operations 3. Delegates to MemoryAgentService for actual memory read/write operations """ - + def __init__(self, db: Session): """Initialize MemoryAPIService. @@ -36,11 +36,11 @@ class MemoryAPIService: db: SQLAlchemy database session """ self.db = db - + def validate_end_user( - self, - end_user_id: str, - workspace_id: uuid.UUID + self, + end_user_id: str, + workspace_id: uuid.UUID ) -> EndUser: """Validate that end_user exists and belongs to the workspace. @@ -56,7 +56,7 @@ class MemoryAPIService: BusinessException: If end_user not in authorized workspace """ logger.info(f"Validating end_user: {end_user_id} for workspace: {workspace_id}") - + # Query end_user by ID try: end_user_uuid = uuid.UUID(end_user_id) @@ -66,7 +66,7 @@ class MemoryAPIService: message=f"Invalid end_user_id format: {end_user_id}", code=BizCode.INVALID_PARAMETER ) - + end_user = self.db.query(EndUser).filter(EndUser.id == end_user_uuid).first() if not end_user: @@ -75,13 +75,13 @@ class MemoryAPIService: resource_type="EndUser", resource_id=end_user_id ) - + # Verify end_user belongs to the workspace via App relationship app = self.db.query(App).filter( App.id == end_user.app_id, App.is_active.is_(True) ).first() - + if not app: logger.warning(f"App not found for end_user: {end_user_id}") # raise ResourceNotFoundException( @@ -99,7 +99,7 @@ class MemoryAPIService: # message=f"End user does not belong to authorized workspace. end_user.workspace_id={end_user.workspace_id}, api_key.workspace_id={workspace_id}", # code=BizCode.FORBIDDEN # ) - + logger.info(f"End user {end_user_id} validated successfully") return end_user @@ -125,13 +125,14 @@ class MemoryAPIService: logger.warning(f"Failed to update memory_config_id for end_user {end_user_id}: {e}") async def write_memory( - self, - workspace_id: uuid.UUID, - end_user_id: str, - message: str, - config_id: str, - storage_type: str = "neo4j", - user_rag_memory_id: Optional[str] = None, + self, + workspace_id: uuid.UUID, + end_user_id: str, + message: str, + config_id: str, + storage_type: str = "neo4j", + files: Optional[list]=None, + user_rag_memory_id: Optional[str] = None, ) -> Dict[str, Any]: """Write memory with validation. @@ -153,14 +154,16 @@ class MemoryAPIService: ResourceNotFoundException: If end_user not found BusinessException: If end_user not in authorized workspace or write fails """ + if files is None: + files = list() logger.info(f"Writing memory for end_user: {end_user_id}, workspace: {workspace_id}") - + # Validate end_user exists and belongs to workspace self.validate_end_user(end_user_id, workspace_id) - + # Update end user's memory_config_id self._update_end_user_config(end_user_id, config_id) - + try: # Delegate to MemoryAgentService # Convert string message to list[dict] format expected by MemoryAgentService @@ -171,11 +174,12 @@ class MemoryAPIService: config_id=config_id, db=self.db, storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id or "" + user_rag_memory_id=user_rag_memory_id or "", + files=files ) - + logger.info(f"Memory write successful for end_user: {end_user_id}") - + # result may be a string "success" or a dict with a "status" key # Preserve the full dict so callers don't silently lose extra fields # (e.g. error codes, metadata) returned by MemoryAgentService. @@ -189,7 +193,7 @@ class MemoryAPIService: "status": result if isinstance(result, str) else "success", "end_user_id": end_user_id, } - + except ConfigurationError as e: logger.error(f"Memory configuration error for end_user {end_user_id}: {e}") raise BusinessException( @@ -204,16 +208,16 @@ class MemoryAPIService: message=f"Memory write failed: {str(e)}", code=BizCode.MEMORY_WRITE_FAILED ) - + async def read_memory( - self, - workspace_id: uuid.UUID, - end_user_id: str, - message: str, - search_switch: str = "0", - config_id: str = "", - storage_type: str = "neo4j", - user_rag_memory_id: Optional[str] = None, + self, + workspace_id: uuid.UUID, + end_user_id: str, + message: str, + search_switch: str = "0", + config_id: str = "", + storage_type: str = "neo4j", + user_rag_memory_id: Optional[str] = None, ) -> Dict[str, Any]: """Read memory with validation. @@ -237,14 +241,13 @@ class MemoryAPIService: BusinessException: If end_user not in authorized workspace or read fails """ logger.info(f"Reading memory for end_user: {end_user_id}, workspace: {workspace_id}") - + # Validate end_user exists and belongs to workspace self.validate_end_user(end_user_id, workspace_id) - + # Update end user's memory_config_id self._update_end_user_config(end_user_id, config_id) - try: # Delegate to MemoryAgentService result = await MemoryAgentService().read_memory( @@ -257,15 +260,15 @@ class MemoryAPIService: storage_type=storage_type, user_rag_memory_id=user_rag_memory_id or "" ) - + logger.info(f"Memory read successful for end_user: {end_user_id}") - + return { "answer": result.get("answer", ""), "intermediate_outputs": result.get("intermediate_outputs", []), "end_user_id": end_user_id } - + except ConfigurationError as e: logger.error(f"Memory configuration error for end_user {end_user_id}: {e}") raise BusinessException( @@ -282,8 +285,8 @@ class MemoryAPIService: ) def list_memory_configs( - self, - workspace_id: uuid.UUID, + self, + workspace_id: uuid.UUID, ) -> Dict[str, Any]: """List all memory configs for a workspace. diff --git a/api/app/services/memory_config_service.py b/api/app/services/memory_config_service.py index a3751c07..1a4af531 100644 --- a/api/app/services/memory_config_service.py +++ b/api/app/services/memory_config_service.py @@ -37,7 +37,7 @@ def _validate_config_id(config_id, db: Session = None): """Validate configuration ID format (supports both UUID and integer).""" if isinstance(config_id, uuid.UUID): return config_id - + if config_id is None: raise InvalidConfigError( "Configuration ID cannot be None", @@ -60,18 +60,18 @@ def _validate_config_id(config_id, db: Session = None): if result: logger.info(f"Found config_id {result.config_id} for user_id {config_id}") return result.config_id - + return config_id if isinstance(config_id, str): config_id_stripped = config_id.strip() - + # Try parsing as UUID first try: return uuid.UUID(config_id_stripped) except ValueError: pass - + # Fall back to integer parsing try: parsed_id = int(config_id_stripped) @@ -81,17 +81,17 @@ def _validate_config_id(config_id, db: Session = None): field_name="config_id", invalid_value=config_id, ) - + # 如果提供了数据库会话,尝试通过 user_id 查询 config_id if db is not None: # 查询 user_id 匹配的记录 stmt = select(MemoryConfigModel).where(MemoryConfigModel.user_id == str(parsed_id)) result = db.execute(stmt).scalars().first() - + if result: logger.info(f"Found config_id {result.config_id} for user_id {parsed_id}") return result.config_id - + return parsed_id except ValueError: raise InvalidConfigError( @@ -154,10 +154,10 @@ class MemoryConfigService: self.db = db def load_memory_config( - self, - config_id: Optional[UUID] = None, - workspace_id: Optional[UUID] = None, - service_name: str = "MemoryConfigService", + self, + config_id: Optional[UUID] = None, + workspace_id: Optional[UUID] = None, + service_name: str = "MemoryConfigService", ) -> MemoryConfig: """ Load memory configuration from database with optional fallback. @@ -194,14 +194,14 @@ class MemoryConfigService: try: # Use get_config_with_fallback if workspace_id is provided memory_config = None + validated_config_id = None if workspace_id: - validated_config_id = None if config_id: try: validated_config_id = _validate_config_id(config_id, self.db) except Exception: validated_config_id = None - + memory_config = self.get_config_with_fallback( memory_config_id=validated_config_id, workspace_id=workspace_id @@ -210,7 +210,7 @@ class MemoryConfigService: validated_config_id = _validate_config_id(config_id, self.db) from app.models.memory_config_model import MemoryConfig as MemoryConfigModel memory_config = self.db.get(MemoryConfigModel, validated_config_id) - + if not memory_config: elapsed_ms = (time.time() - start_time) * 1000 config_logger.error( @@ -233,7 +233,7 @@ class MemoryConfigService: result = MemoryConfigRepository.get_config_with_workspace(self.db, memory_config.config_id) db_query_time = time.time() - db_query_start logger.info(f"[PERF] Config+Workspace query: {db_query_time:.4f}s") - + if not result: raise ConfigurationError( f"Workspace not found for config {memory_config.config_id}" @@ -243,10 +243,10 @@ class MemoryConfigService: # Helper function to validate model with workspace fallback def _validate_model_with_fallback( - model_id: str, - model_type: str, - workspace_default: str, - required: bool = False + model_id: str, + model_type: str, + workspace_default: str, + required: bool = False ) -> tuple: """Validate model ID, falling back to workspace default if invalid. @@ -275,7 +275,7 @@ class MemoryConfigService: logger.warning( f"{model_type} model validation failed, trying workspace default: {e}" ) - + # Fallback to workspace default if workspace_default: try: @@ -297,7 +297,7 @@ class MemoryConfigService: logger.error(f"Workspace default {model_type} model also invalid: {e}") if required: raise - + if required: raise InvalidConfigError( f"{model_type.title()} model is required but not configured", @@ -306,7 +306,7 @@ class MemoryConfigService: config_id=validated_config_id, workspace_id=workspace.id ) - + return None, None # Step 2: Validate embedding model with workspace fallback @@ -343,6 +343,35 @@ class MemoryConfigService: if memory_config.rerank_id or workspace.rerank: logger.info(f"[PERF] Rerank validation: {rerank_time:.4f}s") + vision_uuid, vision_name = validate_and_resolve_model_id( + memory_config.vision_id, + "llm", + self.db, + workspace.tenant_id, + required=False, + config_id=validated_config_id, + workspace_id=workspace.id, + ) + + audio_uuid, audio_name = validate_and_resolve_model_id( + memory_config.audio_id, + "llm", + self.db, + workspace.tenant_id, + required=False, + config_id=validated_config_id, + workspace_id=workspace.id, + ) + + video_uuid, video_name = validate_and_resolve_model_id( + memory_config.video_id, + "llm", + self.db, + workspace.tenant_id, + required=False, + config_id=validated_config_id, + workspace_id=workspace.id, + ) # Create immutable MemoryConfig object config = MemoryConfig( config_id=memory_config.config_id, @@ -356,6 +385,12 @@ class MemoryConfigService: embedding_model_name=embedding_name, rerank_model_id=rerank_uuid, rerank_model_name=rerank_name, + video_model_id=video_uuid, + video_model_name=video_name, + vision_model_id=vision_uuid, + vision_model_name=vision_name, + audio_model_id=audio_uuid, + audio_model_name=audio_name, storage_type=workspace.storage_type or "neo4j", chunker_strategy=memory_config.chunker_strategy or "RecursiveChunker", reflexion_enabled=memory_config.enable_self_reflexion or False, @@ -364,24 +399,31 @@ class MemoryConfigService: reflexion_baseline=memory_config.baseline or "Time", loaded_at=datetime.now(), # Pipeline config: Deduplication - enable_llm_dedup_blockwise=bool(memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False, - enable_llm_disambiguation=bool(memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False, + enable_llm_dedup_blockwise=bool( + memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False, + enable_llm_disambiguation=bool( + memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False, deep_retrieval=bool(memory_config.deep_retrieval) if memory_config.deep_retrieval is not None else True, t_type_strict=float(memory_config.t_type_strict) if memory_config.t_type_strict is not None else 0.8, t_name_strict=float(memory_config.t_name_strict) if memory_config.t_name_strict is not None else 0.8, t_overall=float(memory_config.t_overall) if memory_config.t_overall is not None else 0.8, # Pipeline config: Statement extraction - statement_granularity=int(memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2, - include_dialogue_context=bool(memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False, - max_dialogue_context_chars=int(memory_config.max_context) if memory_config.max_context is not None else 1000, + statement_granularity=int( + memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2, + include_dialogue_context=bool( + memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False, + max_dialogue_context_chars=int( + memory_config.max_context) if memory_config.max_context is not None else 1000, # Pipeline config: Forgetting engine lambda_time=float(memory_config.lambda_time) if memory_config.lambda_time is not None else 0.5, lambda_mem=float(memory_config.lambda_mem) if memory_config.lambda_mem is not None else 0.5, offset=float(memory_config.offset) if memory_config.offset is not None else 0.0, # Pipeline config: Pruning - pruning_enabled=bool(memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False, + pruning_enabled=bool( + memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False, pruning_scene=memory_config.pruning_scene or "education", - pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5, + pruning_threshold=float( + memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5, # Ontology scene association scene_id=memory_config.scene_id, ontology_class_infos=_load_ontology_class_infos(self.db, memory_config.scene_id), @@ -448,9 +490,9 @@ class MemoryConfigService: if not config: logger.warning(f"Model ID {model_id} not found") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在") - + api_config: ModelApiKey = config.api_keys[0] - + return { "model_name": api_config.model_name, "provider": api_config.provider, @@ -481,9 +523,9 @@ class MemoryConfigService: if not config: logger.warning(f"Embedding model ID {embedding_id} not found") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="嵌入模型ID不存在") - + api_config: ModelApiKey = config.api_keys[0] - + return { "model_name": api_config.model_name, "provider": api_config.provider, @@ -571,25 +613,25 @@ class MemoryConfigService: """ from app.core.memory.models.ontology_extraction_models import OntologyTypeList from app.repositories.ontology_class_repository import OntologyClassRepository - + if not memory_config.scene_id: logger.debug("No scene_id configured, skipping ontology type fetch") return None - + try: ontology_repo = OntologyClassRepository(self.db) ontology_classes = ontology_repo.get_classes_by_scene(memory_config.scene_id) - + if not ontology_classes: logger.info(f"No ontology classes found for scene_id: {memory_config.scene_id}") return None - + ontology_types = OntologyTypeList.from_db_models(ontology_classes) logger.info( f"Loaded {len(ontology_types.types)} ontology types for scene_id: {memory_config.scene_id}" ) return ontology_types - + except Exception as e: logger.warning( f"Failed to fetch ontology types for scene_id {memory_config.scene_id}: {e}", @@ -598,8 +640,8 @@ class MemoryConfigService: return None def get_workspace_default_config( - self, - workspace_id: UUID + self, + workspace_id: UUID ) -> Optional["MemoryConfigModel"]: """Get workspace default memory config. @@ -613,19 +655,19 @@ class MemoryConfigService: Optional[MemoryConfigModel]: Default config or None if no configs exist """ config = MemoryConfigRepository.get_workspace_default(self.db, workspace_id) - + if not config: logger.warning( "No active memory config found for workspace fallback", extra={"workspace_id": str(workspace_id)} ) - + return config def get_config_with_fallback( - self, - memory_config_id: Optional[UUID], - workspace_id: UUID + self, + memory_config_id: Optional[UUID], + workspace_id: UUID ) -> Optional["MemoryConfigModel"]: """Get memory config with fallback to workspace default. @@ -644,13 +686,13 @@ class MemoryConfigService: "No memory config ID provided, using workspace default", extra={"workspace_id": str(workspace_id)} ) - + config = MemoryConfigRepository.get_with_fallback( self.db, memory_config_id, workspace_id ) - + if not config and memory_config_id: logger.warning( "Memory config not found, falling back to workspace default", @@ -659,13 +701,13 @@ class MemoryConfigService: "workspace_id": str(workspace_id) } ) - + return config def delete_config( - self, - config_id: UUID | int, - force: bool = False + self, + config_id: UUID | int, + force: bool = False ) -> dict: """Delete memory config with protection against in-use configs. @@ -687,7 +729,7 @@ class MemoryConfigService: from app.core.exceptions import ResourceNotFoundException from app.models.memory_config_model import MemoryConfig as MemoryConfigModel from app.repositories.end_user_repository import EndUserRepository - + # 处理旧格式 int 类型的 config_id if isinstance(config_id, int): logger.warning( @@ -699,11 +741,11 @@ class MemoryConfigService: "message": "旧格式配置ID不支持删除操作,请使用新版配置", "legacy_int_id": config_id } - + config = self.db.get(MemoryConfigModel, config_id) if not config: raise ResourceNotFoundException("MemoryConfig", str(config_id)) - + # Check if this is the default config - default configs cannot be deleted if config.is_default: logger.warning( @@ -715,11 +757,11 @@ class MemoryConfigService: "message": "默认配置不允许删除", "is_default": True } - + # Use repository to count connected end users end_user_repo = EndUserRepository(self.db) connected_count = end_user_repo.count_by_memory_config_id(config_id) - + if connected_count > 0 and not force: logger.warning( "Attempted to delete memory config with connected end users", @@ -728,18 +770,18 @@ class MemoryConfigService: "connected_count": connected_count } ) - + return { "status": "warning", "message": f"无法删除记忆配置:{connected_count} 个终端用户正在使用此配置", "connected_count": connected_count, "force_required": True } - + # Force delete: use repository to clear end user references first if connected_count > 0 and force: cleared_count = end_user_repo.clear_memory_config_id(config_id) - + logger.warning( "Force deleting memory config, clearing end user references", extra={ @@ -747,11 +789,11 @@ class MemoryConfigService: "cleared_end_users": cleared_count } ) - + try: self.db.delete(config) self.db.commit() - + logger.info( "Memory config deleted", extra={ @@ -760,16 +802,16 @@ class MemoryConfigService: "affected_users": connected_count } ) - + return { "status": "success", "message": "记忆配置删除成功", "affected_users": connected_count } - + except IntegrityError as e: self.db.rollback() - + # Handle foreign key violation gracefully error_str = str(e.orig) if e.orig else str(e) if "ForeignKeyViolation" in error_str or "foreign key constraint" in error_str.lower(): @@ -785,7 +827,7 @@ class MemoryConfigService: "message": "无法删除记忆配置:仍有终端用户引用此配置,请使用 force=true 强制删除", "force_required": True } - + # Re-raise other integrity errors logger.error( "Delete failed due to integrity error", @@ -800,9 +842,9 @@ class MemoryConfigService: # ==================== 记忆配置提取方法 ==================== def extract_memory_config_id( - self, - app_type: str, - config: dict + self, + app_type: str, + config: dict ) -> tuple[Optional[uuid.UUID], bool]: """从发布配置中提取 memory_config_id(根据应用类型分发) @@ -828,8 +870,8 @@ class MemoryConfigService: return None, False def _extract_memory_config_id_from_agent( - self, - config: dict + self, + config: dict ) -> tuple[Optional[uuid.UUID], bool]: """从 Agent 应用配置中提取 memory_config_id @@ -888,8 +930,8 @@ class MemoryConfigService: return None, False def _extract_memory_config_id_from_workflow( - self, - config: dict + self, + config: dict ) -> tuple[Optional[uuid.UUID], bool]: """从 Workflow 应用配置中提取 memory_config_id @@ -905,14 +947,14 @@ class MemoryConfigService: - is_legacy_int: 是否检测到旧格式 int 数据 """ nodes = config.get("nodes", []) - + for node in nodes: node_type = node.get("type", "") - + # 检查是否为记忆节点 (support both formats: memory-read/memory-write and MemoryRead/MemoryWrite) if node_type.lower() in ["memoryread", "memorywrite", "memory-read", "memory-write"]: config_id = node.get("config", {}).get("config_id") - + if config_id: try: # 处理字符串、UUID 和 int(旧数据兼容)三种情况 @@ -937,6 +979,6 @@ class MemoryConfigService: f"工作流记忆节点 config_id 格式无效: node_id={node.get('id')}, " f"node_type={node_type}, error={str(e)}" ) - + logger.debug("工作流配置中未找到记忆节点") return None, False diff --git a/api/app/services/memory_perceptual_service.py b/api/app/services/memory_perceptual_service.py index 8a7c86e2..d6c1de87 100644 --- a/api/app/services/memory_perceptual_service.py +++ b/api/app/services/memory_perceptual_service.py @@ -12,11 +12,12 @@ from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger from app.core.models import RedBearLLM, RedBearModelConfig -from app.models import FileMetadata +from app.models import FileMetadata, ModelApiKey, ModelType from app.models.memory_perceptual_model import PerceptualType, FileStorageService from app.models.prompt_optimizer_model import RoleType from app.repositories.memory_perceptual_repository import MemoryPerceptualRepository -from app.schemas import FileType +from app.schemas import FileType, FileInput +from app.schemas.memory_config_schema import MemoryConfig from app.schemas.memory_perceptual_schema import ( PerceptualQuerySchema, PerceptualTimelineResponse, @@ -24,6 +25,8 @@ from app.schemas.memory_perceptual_schema import ( AudioModal, Content, VideoModal, TextModal ) from app.schemas.model_schema import ModelInfo +from app.services.model_service import ModelApiKeyService +from app.services.multimodal_service import MultimodalService business_logger = get_business_logger() @@ -195,21 +198,58 @@ class MemoryPerceptualService: business_logger.error(f"Failed to fetch perceptual memory timeline: {str(e)}") raise BusinessException(f"Failed to fetch perceptual memory timeline: {str(e)}", BizCode.DB_ERROR) + def _get_mutlimodal_client( + self, + file_type: FileType, + config: MemoryConfig + ) -> tuple[RedBearLLM | None, ModelApiKey | None]: + model_config = None + if file_type == FileType.AUDIO: + model_config = ModelApiKeyService.get_available_api_key( + self.db, + config.audio_model_id + ) + elif file_type == FileType.VIDEO: + model_config = ModelApiKeyService.get_available_api_key( + self.db, + config.video_model_id + ) + elif file_type == FileType.DOCUMENT: + model_config = ModelApiKeyService.get_available_api_key( + self.db, + config.llm_model_id + ) + elif file_type == FileType.IMAGE: + model_config = ModelApiKeyService.get_available_api_key( + self.db, + config.vision_model_id + ) + llm = None + if model_config: + llm = RedBearLLM( + RedBearModelConfig( + model_name=model_config.model_name, + provider=model_config.provider, + api_key=model_config.api_key, + base_url=model_config.api_base, + is_omni=model_config.is_omni + ) + ) + return llm, model_config + async def generate_perceptual_memory( self, end_user_id: str, - model_config: ModelInfo, - file_type: str, - file_url: str, - file_message: dict, + memory_config: MemoryConfig, + file: FileInput ): - memories = self.repository.get_by_url(file_url) + memories = self.repository.get_by_url(file.url) if memories: - business_logger.info(f"Perceptual memory already exists: {file_url}") + business_logger.info(f"Perceptual memory already exists: {file.url}") if end_user_id not in [memory.end_user_id for memory in memories]: business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}") memory_cache = memories[0] - self.repository.create_perceptual_memory( + memory = self.repository.create_perceptual_memory( end_user_id=uuid.UUID(end_user_id), perceptual_type=PerceptualType(memory_cache.perceptual_type), file_path=memory_cache.file_path, @@ -219,20 +259,31 @@ class MemoryPerceptualService: meta_data=memory_cache.meta_data ) self.db.commit() - - return - llm = RedBearLLM(RedBearModelConfig( + return memory + else: + for memory in memories: + if memory.end_user_id == uuid.UUID(end_user_id): + return memory + llm, model_config = self._get_mutlimodal_client(file.type, memory_config) + multimodel_service = MultimodalService(self.db, ModelInfo( model_name=model_config.model_name, provider=model_config.provider, api_key=model_config.api_key, - base_url=model_config.api_base, - is_omni=model_config.is_omni - ), type=model_config.model_type) + api_base=model_config.api_base, + is_omni=model_config.is_omni, + capability=model_config.capability, + model_type=ModelType.LLM + )) + file_message = await multimodel_service.process_files( + files=[file] + ) + if file_message: + file_message = file_message[0] try: prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt') with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f: opt_system_prompt = f.read() - rendered_system_message = Template(opt_system_prompt).render(file_type=file_type, language='zh') + rendered_system_message = Template(opt_system_prompt).render(file_type=file.type, language='zh') except FileNotFoundError: raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND) messages = [ @@ -242,8 +293,22 @@ class MemoryPerceptualService: ]} ] result = await llm.ainvoke(messages) - content = json_repair.repair_json(result.content, return_objects=True) - path = urlparse(file_url).path + content = result.content + final_output = "" + if isinstance(content, list): + for msg in content: + if isinstance(msg, dict): + final_output += msg.get("text", "") + elif isinstance(msg, str): + final_output += msg + elif isinstance(content, dict): + final_output += content.get("text", "") + elif isinstance(content, str): + final_output = content + else: + raise ValueError(f"Unexcept Model Output Type: {result.content}") + content = json_repair.repair_json(final_output, return_objects=True) + path = urlparse(file.url).path filename = os.path.basename(path) filename = unquote(filename) file_ext = os.path.splitext(filename)[1] @@ -260,13 +325,13 @@ class MemoryPerceptualService: except ValueError: business_logger.debug(f"Remote file, file_id={filename}") if not file_ext: - if file_type == FileType.AUDIO: + if file.type == FileType.AUDIO: file_ext = ".mp3" - elif file_type == FileType.VIDEO: + elif file.type == FileType.VIDEO: file_ext = ".mp4" - elif file_type == FileType.DOCUMENT: + elif file.type == FileType.DOCUMENT: file_ext = ".txt" - elif file_type == FileType.IMAGE: + elif file.type == FileType.IMAGE: file_ext = ".jpg" filename += file_ext file_content = { @@ -274,11 +339,11 @@ class MemoryPerceptualService: "topic": content.get("topic"), "domain": content.get("domain") } - if file_type in [FileType.IMAGE, FileType.VIDEO]: + if file.type in [FileType.IMAGE, FileType.VIDEO]: file_modalities = { "scene": content.get("scene", []) } - elif file_type in [FileType.DOCUMENT]: + elif file.type in [FileType.DOCUMENT]: file_modalities = { "section_count": content.get("section_count", 0), "title": content.get("title", ""), @@ -288,10 +353,10 @@ class MemoryPerceptualService: file_modalities = { "speaker_count": content.get("speaker_count", 0) } - self.repository.create_perceptual_memory( + memory = self.repository.create_perceptual_memory( end_user_id=uuid.UUID(end_user_id), - perceptual_type=PerceptualType.trans_from_file_type(file_type), - file_path=file_url, + perceptual_type=PerceptualType.trans_from_file_type(file.type), + file_path=file.url, file_name=filename, file_ext=file_ext, summary=content.get('summary', ""), @@ -301,3 +366,4 @@ class MemoryPerceptualService: } ) self.db.commit() + return memory diff --git a/api/app/services/multimodal_service.py b/api/app/services/multimodal_service.py index f0c7cee2..eb8df242 100644 --- a/api/app/services/multimodal_service.py +++ b/api/app/services/multimodal_service.py @@ -9,14 +9,12 @@ - OpenAI: 支持 URL 和 base64 格式 """ import base64 +import csv import io -import uuid +import json from abc import ABC, abstractmethod from typing import List, Dict, Any, Optional -import csv -import json - import PyPDF2 import httpx import magic @@ -33,7 +31,6 @@ from app.models.file_metadata_model import FileMetadata from app.schemas.app_schema import FileInput, FileType, TransferMethod from app.schemas.model_schema import ModelInfo from app.services.audio_transcription_service import AudioTranscriptionService -from app.tasks import write_perceptual_memory logger = get_business_logger() @@ -342,15 +339,12 @@ class MultimodalService: async def process_files( self, - end_user_id: uuid.UUID | str, files: Optional[List[FileInput]], - ) -> List[Dict[str, Any]]: """ 处理文件列表,返回 LLM 可用的格式 Args: - end_user_id: 用户ID files: 文件输入列表 Returns: @@ -358,8 +352,6 @@ class MultimodalService: """ if not files: return [] - if isinstance(end_user_id, uuid.UUID): - end_user_id = str(end_user_id) # 获取对应的策略 # dashscope 的 omni 模型使用 OpenAI 兼容格式 @@ -380,23 +372,15 @@ class MultimodalService: if file.type == FileType.IMAGE and "vision" in self.capability: is_support, content = await self._process_image(file, strategy) result.append(content) - if is_support: - self.write_perceptual_memory(end_user_id, file.type, file.url, content) elif file.type == FileType.DOCUMENT: is_support, content = await self._process_document(file, strategy) result.append(content) - if is_support: - self.write_perceptual_memory(end_user_id, file.type, file.url, content) elif file.type == FileType.AUDIO and "audio" in self.capability: is_support, content = await self._process_audio(file, strategy) result.append(content) - if is_support: - self.write_perceptual_memory(end_user_id, file.type, file.url, content) elif file.type == FileType.VIDEO and "video" in self.capability: is_support, content = await self._process_video(file, strategy) result.append(content) - if is_support: - self.write_perceptual_memory(end_user_id, file.type, file.url, content) else: logger.warning(f"不支持的文件类型: {file.type}") except Exception as e: @@ -418,17 +402,6 @@ class MultimodalService: logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}") return result - def write_perceptual_memory( - self, - end_user_id: str, - file_type: str, - file_url: str, - file_message: dict - ): - """写入感知记忆""" - if end_user_id and self.api_config: - write_perceptual_memory.delay(end_user_id, self.api_config.model_dump(), file_type, file_url, file_message) - async def _process_image(self, file: FileInput, strategy) -> tuple[bool, Dict[str, Any]]: """ 处理图片文件 diff --git a/api/app/tasks.py b/api/app/tasks.py index c37e564e..8afb2194 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1080,12 +1080,14 @@ def write_message_task( config_id: str | int, storage_type: str, user_rag_memory_id: str, + file_messages: list[dict] | None, language: str = "zh" ) -> Dict[str, Any]: """Celery task to process a write message via MemoryAgentService. Args: end_user_id: Group ID for the memory agent (also used as end_user_id) message: Message to write + file_messages: Files to write config_id: Configuration ID (can be UUID string, integer, or config_id_old) storage_type: Storage type (neo4j or rag) user_rag_memory_id: User RAG memory ID @@ -1097,6 +1099,8 @@ def write_message_task( Raises: Exception on failure """ + if file_messages is None: + file_messages = [] logger.info( f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, " @@ -1142,7 +1146,7 @@ def write_message_task( f"[CELERY WRITE] Executing MemoryAgentService.write_memory " f"with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}") service = MemoryAgentService() - result = await service.write_memory(end_user_id, message, actual_config_id, db, storage_type, + result = await service.write_memory(end_user_id, message, file_messages, actual_config_id, db, storage_type, user_rag_memory_id, language) logger.info(f"[CELERY WRITE] Write completed successfully: {result}") return result