From 902dd18bc829f7ce0c55189d0134da35d2748992 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=90=E5=8A=9B=E9=BD=90?= Date: Sat, 20 Dec 2025 07:02:46 +0000 Subject: [PATCH] Merge #21 into develop from feature/emotion-engine MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit feature/情绪引擎 * feature/emotion-engine: (7 commits squashed) - [feature]Emotion Engine Development - [feature]Emotion Engine Development - Merge branch 'feature/emotion-engine' of codeup.aliyun.com:redbearai/python/redbear-mem-open into feature/emotion-engine - [fix]1.Fix the front-end files;2.Cache Management Deletion;3.Delete "check_code.py" - [fix]1.Fix the front-end files;2.Cache Management Deletion;3.Delete "check_code.py" - Merge branch 'feature/emotion-engine' of codeup.aliyun.com:redbearai/python/redbear-mem-open into feature/emotion-engine - [fix]fix vite.config.ts Signed-off-by: 乐力齐 Commented-by: aliyun6762716068 Commented-by: 乐力齐 Reviewed-by: aliyun6762716068 Merged-by: aliyun6762716068 CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/21 --- .../extraction_orchestrator.py | 173 +++++++++++++++--- 1 file changed, 144 insertions(+), 29 deletions(-) diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index e00bcf0a..91529aa9 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -35,7 +35,6 @@ from app.core.memory.models.graph_models import ( from app.core.memory.utils.data.ontology import TemporalInfo from app.core.memory.models.variate_config import ( ExtractionPipelineConfig, - StatementExtractionConfig, ) from app.core.memory.llm_tools.openai_client import LLMClient from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient @@ -53,7 +52,6 @@ from app.core.memory.storage_services.extraction_engine.knowledge_extraction.tem ) from app.core.memory.storage_services.extraction_engine.knowledge_extraction.embedding_generation import ( embedding_generation, - embedding_generation_all, generate_entity_embeddings_from_triplets, ) from app.core.memory.storage_services.extraction_engine.deduplication.two_stage_dedup import ( @@ -179,24 +177,12 @@ class ExtractionOrchestrator: all_statements_list.extend(chunk.statements) total_statements = len(all_statements_list) - # 🔥 陈述句提取完成后,立即发送知识抽取完成消息 - if self.progress_callback: - extraction_stats = { - "statements_count": total_statements, - "entities_count": 0, # 暂时为0,后续会更新 - "triplets_count": 0, # 暂时为0,后续会更新 - "temporal_ranges_count": 0, # 暂时为0,后续会更新 - } - await self.progress_callback("knowledge_extraction_complete", "知识抽取完成", extraction_stats) - - # 🔥 立即发送下一阶段的开始消息,让前端知道进入了创建节点和边阶段 - await self.progress_callback("creating_nodes_edges", "正在创建节点和边...") - - # 步骤 2: 并行执行三元组提取、时间信息提取和基础嵌入生成(后台静默执行) - logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取和嵌入生成(后台静默执行)") + # 步骤 2: 并行执行三元组提取、时间信息提取、情绪提取和基础嵌入生成 + logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取、情绪提取和嵌入生成") ( triplet_maps, temporal_maps, + emotion_maps, statement_embedding_maps, chunk_embedding_maps, dialog_embeddings, @@ -225,6 +211,7 @@ class ExtractionOrchestrator: dialog_data_list, temporal_maps, triplet_maps, + emotion_maps, statement_embedding_maps, chunk_embedding_maps, dialog_embeddings, @@ -552,9 +539,108 @@ class ExtractionOrchestrator: return temporal_maps + async def _extract_emotions( + self, dialog_data_list: List[DialogData] + ) -> List[Dict[str, Any]]: + """ + 从对话中提取情绪信息(优化版:全局陈述句级并行) + + Args: + dialog_data_list: 对话数据列表 + + Returns: + 情绪信息映射列表,每个对话对应一个字典 + """ + logger.info("开始情绪信息提取(全局陈述句级并行)") + + # 收集所有陈述句及其配置 + all_statements = [] + statement_metadata = [] # (dialog_idx, statement_id) + + # 获取第一个对话的config_id来加载配置 + config_id = None + if dialog_data_list and hasattr(dialog_data_list[0], 'config_id'): + config_id = dialog_data_list[0].config_id + + # 加载DataConfig + data_config = None + if config_id: + try: + from app.db import SessionLocal + from app.repositories.data_config_repository import DataConfigRepository + + db = SessionLocal() + try: + data_config = DataConfigRepository.get_by_id(db, config_id) + finally: + db.close() + + if data_config and not data_config.emotion_enabled: + logger.info("情绪提取已在配置中禁用,跳过情绪提取") + return [{} for _ in dialog_data_list] + + except Exception as e: + logger.warning(f"加载DataConfig失败: {e},将跳过情绪提取") + return [{} for _ in dialog_data_list] + else: + logger.info("未找到config_id,跳过情绪提取") + return [{} for _ in dialog_data_list] + + # 如果配置未启用情绪提取,直接返回空映射 + if not data_config or not data_config.emotion_enabled: + logger.info("情绪提取未启用,跳过") + return [{} for _ in dialog_data_list] + + # 收集所有陈述句 + for d_idx, dialog in enumerate(dialog_data_list): + for chunk in dialog.chunks: + for statement in chunk.statements: + all_statements.append((statement, data_config)) + statement_metadata.append((d_idx, statement.id)) + + logger.info(f"收集到 {len(all_statements)} 个陈述句,开始全局并行提取情绪") + + # 初始化情绪提取服务 + from app.services.emotion_extraction_service import EmotionExtractionService + emotion_service = EmotionExtractionService( + llm_id=data_config.emotion_model_id if data_config.emotion_model_id else None + ) + + # 全局并行处理所有陈述句 + async def extract_for_statement(stmt_data): + statement, config = stmt_data + try: + return await emotion_service.extract_emotion(statement.statement, config) + except Exception as e: + logger.error(f"陈述句 {statement.id} 情绪提取失败: {e}") + return None + + tasks = [extract_for_statement(stmt_data) for stmt_data in all_statements] + results = await asyncio.gather(*tasks, return_exceptions=True) + + # 将结果组织成对话级别的映射 + emotion_maps = [{} for _ in dialog_data_list] + successful_extractions = 0 + + for i, result in enumerate(results): + d_idx, stmt_id = statement_metadata[i] + if isinstance(result, Exception): + logger.error(f"陈述句处理异常: {result}") + emotion_maps[d_idx][stmt_id] = None + else: + emotion_maps[d_idx][stmt_id] = result + if result is not None: + successful_extractions += 1 + + # 统计提取结果 + logger.info(f"情绪信息提取完成,共成功提取 {successful_extractions}/{len(all_statements)} 个情绪") + + return emotion_maps + async def _parallel_extract_and_embed( self, dialog_data_list: List[DialogData] ) -> Tuple[ + List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, List[float]]], @@ -562,35 +648,39 @@ class ExtractionOrchestrator: List[List[float]], ]: """ - 并行执行三元组提取、时间信息提取和基础嵌入生成 + 并行执行三元组提取、时间信息提取、情绪提取和基础嵌入生成 - 这三个任务都依赖陈述句提取的结果,但彼此独立,可以并行执行: + 这四个任务都依赖陈述句提取的结果,但彼此独立,可以并行执行: - 三元组提取:从陈述句中提取实体和关系 - 时间信息提取:从陈述句中提取时间范围 + - 情绪提取:从陈述句中提取情绪信息 - 嵌入生成:为陈述句、分块和对话生成向量(不依赖三元组) Args: dialog_data_list: 对话数据列表 Returns: - 五个列表的元组: + 六个列表的元组: - 三元组映射列表 - 时间信息映射列表 + - 情绪映射列表 - 陈述句嵌入映射列表 - 分块嵌入映射列表 - 对话嵌入列表 """ - logger.info("并行执行:三元组提取 + 时间信息提取 + 基础嵌入生成") + logger.info("并行执行:三元组提取 + 时间信息提取 + 情绪提取 + 基础嵌入生成") - # 创建三个并行任务 + # 创建四个并行任务 triplet_task = self._extract_triplets(dialog_data_list) temporal_task = self._extract_temporal(dialog_data_list) + emotion_task = self._extract_emotions(dialog_data_list) embedding_task = self._generate_basic_embeddings(dialog_data_list) # 并行执行 results = await asyncio.gather( triplet_task, temporal_task, + emotion_task, embedding_task, return_exceptions=True ) @@ -598,19 +688,21 @@ class ExtractionOrchestrator: # 解包结果 triplet_maps = results[0] if not isinstance(results[0], Exception) else [{} for _ in dialog_data_list] temporal_maps = results[1] if not isinstance(results[1], Exception) else [{} for _ in dialog_data_list] + emotion_maps = results[2] if not isinstance(results[2], Exception) else [{} for _ in dialog_data_list] - if isinstance(results[2], Exception): - logger.error(f"基础嵌入生成失败: {results[2]}") + if isinstance(results[3], Exception): + logger.error(f"基础嵌入生成失败: {results[3]}") statement_embedding_maps = [{} for _ in dialog_data_list] chunk_embedding_maps = [{} for _ in dialog_data_list] dialog_embeddings = [[] for _ in dialog_data_list] else: - statement_embedding_maps, chunk_embedding_maps, dialog_embeddings = results[2] + statement_embedding_maps, chunk_embedding_maps, dialog_embeddings = results[3] logger.info("并行任务执行完成") return ( triplet_maps, temporal_maps, + emotion_maps, statement_embedding_maps, chunk_embedding_maps, dialog_embeddings, @@ -727,6 +819,7 @@ class ExtractionOrchestrator: dialog_data_list: List[DialogData], temporal_maps: List[Dict[str, Any]], triplet_maps: List[Dict[str, Any]], + emotion_maps: List[Dict[str, Any]], statement_embedding_maps: List[Dict[str, List[float]]], chunk_embedding_maps: List[Dict[str, List[float]]], dialog_embeddings: List[List[float]], @@ -738,6 +831,7 @@ class ExtractionOrchestrator: dialog_data_list: 对话数据列表 temporal_maps: 时间信息映射列表 triplet_maps: 三元组映射列表 + emotion_maps: 情绪信息映射列表 statement_embedding_maps: 陈述句嵌入映射列表 chunk_embedding_maps: 分块嵌入映射列表 dialog_embeddings: 对话嵌入列表 @@ -752,6 +846,7 @@ class ExtractionOrchestrator: if ( len(temporal_maps) != expected_length or len(triplet_maps) != expected_length + or len(emotion_maps) != expected_length or len(statement_embedding_maps) != expected_length or len(chunk_embedding_maps) != expected_length or len(dialog_embeddings) != expected_length @@ -759,6 +854,7 @@ class ExtractionOrchestrator: logger.warning( f"数据大小不匹配 - 对话: {len(dialog_data_list)}, " f"时间映射: {len(temporal_maps)}, 三元组映射: {len(triplet_maps)}, " + f"情绪映射: {len(emotion_maps)}, " f"陈述句嵌入: {len(statement_embedding_maps)}, " f"分块嵌入: {len(chunk_embedding_maps)}, " f"对话嵌入: {len(dialog_embeddings)}" @@ -767,6 +863,7 @@ class ExtractionOrchestrator: total_statements = 0 assigned_temporal = 0 assigned_triplets = 0 + assigned_emotions = 0 assigned_statement_embeddings = 0 assigned_chunk_embeddings = 0 assigned_dialog_embeddings = 0 @@ -774,12 +871,13 @@ class ExtractionOrchestrator: # 处理每个对话 for i, dialog_data in enumerate(dialog_data_list): # 检查是否有缺失的数据 - if i >= len(temporal_maps) or i >= len(triplet_maps): + if i >= len(temporal_maps) or i >= len(triplet_maps) or i >= len(emotion_maps): logger.warning(f"对话 {dialog_data.id} 缺少提取数据,跳过赋值") continue temporal_map = temporal_maps[i] triplet_map = triplet_maps[i] + emotion_map = emotion_maps[i] statement_embedding_map = statement_embedding_maps[i] if i < len(statement_embedding_maps) else {} chunk_embedding_map = chunk_embedding_maps[i] if i < len(chunk_embedding_maps) else {} dialog_embedding = dialog_embeddings[i] if i < len(dialog_embeddings) else [] @@ -810,6 +908,18 @@ class ExtractionOrchestrator: statement.triplet_extraction_info = triplet_map[statement.id] assigned_triplets += 1 + # 赋值情绪信息 + if statement.id in emotion_map: + emotion_data = emotion_map[statement.id] + if emotion_data is not None: + # 将EmotionExtraction对象的字段赋值到Statement + statement.emotion_type = emotion_data.emotion_type + statement.emotion_intensity = emotion_data.emotion_intensity + statement.emotion_keywords = emotion_data.emotion_keywords + statement.emotion_subject = emotion_data.emotion_subject + statement.emotion_target = emotion_data.emotion_target + assigned_emotions += 1 + # 赋值陈述句嵌入 if statement.id in statement_embedding_map: statement.statement_embedding = statement_embedding_map[statement.id] @@ -818,6 +928,7 @@ class ExtractionOrchestrator: logger.info( f"数据赋值完成 - 总陈述句: {total_statements}, " f"时间信息: {assigned_temporal}, 三元组: {assigned_triplets}, " + f"情绪信息: {assigned_emotions}, " f"陈述句嵌入: {assigned_statement_embeddings}, " f"分块嵌入: {assigned_chunk_embeddings}, " f"对话嵌入: {assigned_dialog_embeddings}" @@ -927,6 +1038,12 @@ class ExtractionOrchestrator: created_at=dialog_data.created_at, expired_at=dialog_data.expired_at, config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None, + # Emotion fields + emotion_type=getattr(statement, 'emotion_type', None), + emotion_intensity=getattr(statement, 'emotion_intensity', None), + emotion_keywords=getattr(statement, 'emotion_keywords', None), + emotion_subject=getattr(statement, 'emotion_subject', None), + emotion_target=getattr(statement, 'emotion_target', None), ) statement_nodes.append(statement_node) @@ -1333,7 +1450,7 @@ class ExtractionOrchestrator: if match: entity1_name = match.group(1).strip() entity1_type = match.group(2) - entity2_name = match.group(3).strip() + match.group(3).strip() entity2_type = match.group(4) # 提取置信度和原因 @@ -1646,7 +1763,6 @@ async def get_chunked_dialogs( """ import json import re - import os # 加载测试数据 testdata_path = os.path.join(os.path.dirname(__file__), "../../data", "testdata.json") @@ -1822,7 +1938,6 @@ async def get_chunked_dialogs_with_preprocessing( Returns: 带 chunks 的 DialogData 列表 """ - import os print("\n=== 完整数据处理流程(包含预处理)===") if input_data_path is None: