From f38c065f944b0328e0e6dfd7bb4dccb7b036fce4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=90=E5=8A=9B=E9=BD=90?= Date: Thu, 18 Dec 2025 09:56:35 +0000 Subject: [PATCH] Merge #13 into develop from fix/stream-output MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 'fix/stream-output' * fix/stream-output: (17 commits squashed) - [fix]Fix the issue where the streaming output effect is not obvious. - [fix]Fix the issue where the streaming output effect is not obvious. - Merge branch 'fix/stream-output' of codeup.aliyun.com:redbearai/python/redbear-mem-open into fix/stream-output - [fix] - [fix]Skip time extraction - [fix] - [fix]Skip time extraction - Merge branch 'fix/stream-output' of codeup.aliyun.com:redbearai/python/redbear-mem-open into fix/stream-output - [fix]Remove human-induced delays - [fix]Fix the issue where the streaming output effect is not obvious. - [fix] - [fix]Skip time extraction - [fix]Fix the issue where the streaming output effect is not obvious. - [fix] - [fix]Skip time extraction - [fix]Remove human-induced delays - Merge branch 'fix/stream-output' of codeup.aliyun.com:redbearai/python/redbear-mem-open into fix/stream-output Signed-off-by: 乐力齐 Reviewed-by: aliyun6762716068 Merged-by: aliyun6762716068 CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/13 --- .../extraction_orchestrator.py | 239 ++++++++++-------- 1 file changed, 138 insertions(+), 101 deletions(-) diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index 7eec1189..e00bcf0a 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -179,8 +179,21 @@ class ExtractionOrchestrator: all_statements_list.extend(chunk.statements) total_statements = len(all_statements_list) - # 步骤 2: 并行执行三元组提取、时间信息提取和基础嵌入生成 - logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取和嵌入生成") + # 🔥 陈述句提取完成后,立即发送知识抽取完成消息 + if self.progress_callback: + extraction_stats = { + "statements_count": total_statements, + "entities_count": 0, # 暂时为0,后续会更新 + "triplets_count": 0, # 暂时为0,后续会更新 + "temporal_ranges_count": 0, # 暂时为0,后续会更新 + } + await self.progress_callback("knowledge_extraction_complete", "知识抽取完成", extraction_stats) + + # 🔥 立即发送下一阶段的开始消息,让前端知道进入了创建节点和边阶段 + await self.progress_callback("creating_nodes_edges", "正在创建节点和边...") + + # 步骤 2: 并行执行三元组提取、时间信息提取和基础嵌入生成(后台静默执行) + logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取和嵌入生成(后台静默执行)") ( triplet_maps, temporal_maps, @@ -206,72 +219,6 @@ class ExtractionOrchestrator: logger.info("步骤 3/6: 生成实体嵌入") triplet_maps = await self._generate_entity_embeddings(triplet_maps) - # 进度回调:按三个阶段分别输出知识抽取结果 - if self.progress_callback: - # 第一阶段:陈述句提取结果 - for i, stmt in enumerate(all_statements_list[:10]): # 只输出前10个陈述句 - stmt_result = { - "extraction_type": "statement", - "statement_index": i + 1, - "statement": stmt.statement, - "statement_id": stmt.id - } - await self.progress_callback("knowledge_extraction_result", "陈述句提取完成", stmt_result) - - # 第二阶段:三元组提取结果 - for i, triplet in enumerate(all_triplets_list[:10]): # 只输出前10个三元组 - triplet_result = { - "extraction_type": "triplet", - "triplet_index": i + 1, - "subject": triplet.subject_name, - "predicate": triplet.predicate, - "object": triplet.object_name - } - await self.progress_callback("knowledge_extraction_result", "三元组提取完成", triplet_result) - - # 第三阶段:时间提取结果 - if total_temporal > 0: - # 收集时间信息 - temporal_results = [] - for dialog in dialog_data_list: - for chunk in dialog.chunks: - for statement in chunk.statements: - if hasattr(statement, 'temporal_validity') and statement.temporal_validity: - temporal_results.append({ - "statement_id": statement.id, - "statement": statement.statement, - "valid_at": statement.temporal_validity.valid_at, - "invalid_at": statement.temporal_validity.invalid_at - }) - - # 输出时间提取结果 - for i, temporal_result in enumerate(temporal_results[:5]): # 只输出前5个时间提取结果 - time_result = { - "extraction_type": "temporal", - "temporal_index": i + 1, - "statement": temporal_result["statement"], - "valid_at": temporal_result["valid_at"], - "invalid_at": temporal_result["invalid_at"] - } - await self.progress_callback("knowledge_extraction_result", "时间提取完成", time_result) - else: - # 如果没有时间信息,也发送一个时间提取完成的消息 - time_result = { - "extraction_type": "temporal", - "temporal_index": 0, - "message": "未发现时间信息" - } - await self.progress_callback("knowledge_extraction_result", "时间提取完成", time_result) - - # 进度回调:知识抽取完成,传递知识抽取的统计信息 - extraction_stats = { - "statements_count": total_statements, - "entities_count": total_entities, - "triplets_count": total_triplets, - "temporal_ranges_count": total_temporal, - } - await self.progress_callback("knowledge_extraction_complete", "知识抽取完成", extraction_stats) - # 步骤 4: 将提取的数据赋值到语句 logger.info("步骤 4/6: 数据赋值") dialog_data_list = await self._assign_extracted_data( @@ -285,6 +232,9 @@ class ExtractionOrchestrator: # 步骤 5: 创建节点和边 logger.info("步骤 5/6: 创建节点和边") + + # 注意:creating_nodes_edges 消息已在知识抽取完成后立即发送 + ( dialogue_nodes, chunk_nodes, @@ -304,6 +254,8 @@ class ExtractionOrchestrator: else: logger.info("步骤 6/6: 两阶段去重和消歧") + # 注意:deduplication 消息已在创建节点和边完成后立即发送 + result = await self._run_dedup_and_write_summary( dialogue_nodes, chunk_nodes, @@ -328,7 +280,7 @@ class ExtractionOrchestrator: self, dialog_data_list: List[DialogData] ) -> List[DialogData]: """ - 从对话中提取陈述句(优化版:全局分块级并行) + 从对话中提取陈述句(流式输出版本:边提取边发送进度) Args: dialog_data_list: 对话数据列表 @@ -336,7 +288,7 @@ class ExtractionOrchestrator: Returns: 更新后的对话数据列表(包含提取的陈述句) """ - logger.info("开始陈述句提取(全局分块级并行)") + logger.info("开始陈述句提取(全局分块级并行 + 流式输出)") # 收集所有分块及其元数据 all_chunks = [] @@ -349,17 +301,44 @@ class ExtractionOrchestrator: chunk_metadata.append((d_idx, c_idx)) logger.info(f"收集到 {len(all_chunks)} 个分块,开始全局并行提取") + + # 用于跟踪已完成的分块数量 + completed_chunks = 0 + total_chunks = len(all_chunks) # 全局并行处理所有分块 - async def extract_for_chunk(chunk_data): + async def extract_for_chunk(chunk_data, chunk_index): + nonlocal completed_chunks chunk, group_id, dialogue_content = chunk_data try: - return await self.statement_extractor._extract_statements(chunk, group_id, dialogue_content) + statements = await self.statement_extractor._extract_statements(chunk, group_id, dialogue_content) + + # 流式输出:每提取完一个分块的陈述句,立即发送进度 + # 注意:只在试运行模式下发送陈述句详情,正式模式不发送 + completed_chunks += 1 + if self.progress_callback and statements and self.is_pilot_run: + # 发送前3个陈述句作为示例 + for idx, stmt in enumerate(statements[:3]): + stmt_result = { + "extraction_type": "statement", + "statement": stmt.statement, + "statement_id": stmt.id, + "chunk_progress": f"{completed_chunks}/{total_chunks}", + "statement_index_in_chunk": idx + 1 + } + await self.progress_callback( + "knowledge_extraction_result", + f"陈述句提取中 ({completed_chunks}/{total_chunks})", + stmt_result + ) + + return statements except Exception as e: logger.error(f"分块 {chunk.id} 陈述句提取失败: {e}") + completed_chunks += 1 return [] - tasks = [extract_for_chunk(chunk_data) for chunk_data in all_chunks] + tasks = [extract_for_chunk(chunk_data, i) for i, chunk_data in enumerate(all_chunks)] results = await asyncio.gather(*tasks, return_exceptions=True) # 将结果分配回对话 @@ -391,7 +370,7 @@ class ExtractionOrchestrator: self, dialog_data_list: List[DialogData] ) -> List[Dict[str, Any]]: """ - 从对话中提取三元组(优化版:全局陈述句级并行) + 从对话中提取三元组(流式输出版本:边提取边发送进度) Args: dialog_data_list: 对话数据列表 @@ -399,7 +378,7 @@ class ExtractionOrchestrator: Returns: 三元组映射列表,每个对话对应一个字典 """ - logger.info("开始三元组提取(全局陈述句级并行)") + logger.info("开始三元组提取(全局陈述句级并行 + 流式输出)") # 收集所有陈述句及其元数据 all_statements = [] @@ -412,18 +391,30 @@ class ExtractionOrchestrator: statement_metadata.append((d_idx, statement.id)) logger.info(f"收集到 {len(all_statements)} 个陈述句,开始全局并行提取三元组") + + # 用于跟踪已完成的陈述句数量 + completed_statements = 0 + total_statements = len(all_statements) # 全局并行处理所有陈述句 - async def extract_for_statement(stmt_data): + async def extract_for_statement(stmt_data, stmt_index): + nonlocal completed_statements statement, chunk_content = stmt_data try: - return await self.triplet_extractor._extract_triplets(statement, chunk_content) + triplet_info = await self.triplet_extractor._extract_triplets(statement, chunk_content) + + # 注意:不再发送三元组提取的流式输出 + # 三元组提取在后台执行,但不向前端发送详细信息 + completed_statements += 1 + + return triplet_info except Exception as e: logger.error(f"陈述句 {statement.id} 三元组提取失败: {e}") + completed_statements += 1 from app.core.memory.models.triplet_models import TripletExtractionResponse return TripletExtractionResponse(triplets=[], entities=[]) - tasks = [extract_for_statement(stmt_data) for stmt_data in all_statements] + tasks = [extract_for_statement(stmt_data, i) for i, stmt_data in enumerate(all_statements)] results = await asyncio.gather(*tasks, return_exceptions=True) # 将结果组织成对话级别的映射 @@ -458,7 +449,7 @@ class ExtractionOrchestrator: self, dialog_data_list: List[DialogData] ) -> List[Dict[str, Any]]: """ - 从对话中提取时间信息(优化版:全局陈述句级并行) + 从对话中提取时间信息(流式输出版本:边提取边发送进度) Args: dialog_data_list: 对话数据列表 @@ -466,7 +457,21 @@ class ExtractionOrchestrator: Returns: 时间信息映射列表,每个对话对应一个字典 """ - logger.info("开始时间信息提取(全局陈述句级并行)") + # 试运行模式:跳过时间提取以节省时间 + if self.is_pilot_run: + logger.info("试运行模式:跳过时间信息提取(节省约 10-15 秒)") + # 为所有陈述句返回空的时间范围 + from app.core.memory.models.message_models import TemporalValidityRange + temporal_maps = [] + for dialog in dialog_data_list: + temporal_map = {} + for chunk in dialog.chunks: + for statement in chunk.statements: + temporal_map[statement.id] = TemporalValidityRange(valid_at=None, invalid_at=None) + temporal_maps.append(temporal_map) + return temporal_maps + + logger.info("开始时间信息提取(全局陈述句级并行 + 流式输出)") # 收集所有需要提取时间的陈述句 all_statements = [] @@ -494,18 +499,30 @@ class ExtractionOrchestrator: statement_metadata.append((d_idx, statement.id)) logger.info(f"收集到 {len(all_statements)} 个需要时间提取的陈述句,开始全局并行提取") + + # 用于跟踪已完成的时间提取数量 + completed_temporal = 0 + total_temporal_statements = len(all_statements) # 全局并行处理所有陈述句 - async def extract_for_statement(stmt_data): + async def extract_for_statement(stmt_data, stmt_index): + nonlocal completed_temporal statement, ref_dates = stmt_data try: - return await self.temporal_extractor._extract_temporal_ranges(statement, ref_dates) + temporal_range = await self.temporal_extractor._extract_temporal_ranges(statement, ref_dates) + + # 注意:不再发送时间提取的流式输出 + # 时间提取在后台执行,但不向前端发送详细信息 + completed_temporal += 1 + + return temporal_range except Exception as e: logger.error(f"陈述句 {statement.id} 时间信息提取失败: {e}") + completed_temporal += 1 from app.core.memory.models.message_models import TemporalValidityRange return TemporalValidityRange(valid_at=None, invalid_at=None) - tasks = [extract_for_statement(stmt_data) for stmt_data in all_statements] + tasks = [extract_for_statement(stmt_data, i) for i, stmt_data in enumerate(all_statements)] results = await asyncio.gather(*tasks, return_exceptions=True) # 将结果组织成对话级别的映射 @@ -832,9 +849,7 @@ class ExtractionOrchestrator: """ logger.info("开始创建节点和边") - # 进度回调:正在创建节点和边 - if self.progress_callback: - await self.progress_callback("creating_nodes_edges", "正在创建节点和边...") + # 注意:开始消息已在 run 方法中发送,这里不再重复发送 dialogue_nodes = [] chunk_nodes = [] @@ -846,8 +861,13 @@ class ExtractionOrchestrator: # 用于去重的集合 entity_id_set = set() + + # 用于跟踪进度 + total_dialogs = len(dialog_data_list) + processed_dialogs = 0 for dialog_data in dialog_data_list: + processed_dialogs += 1 # 创建对话节点 dialogue_node = DialogueNode( id=dialog_data.id, @@ -994,6 +1014,26 @@ class ExtractionOrchestrator: expired_at=dialog_data.expired_at, ) entity_entity_edges.append(entity_entity_edge) + + # 流式输出:每创建一个关系边,立即发送进度(限制发送数量) + if self.progress_callback and len(entity_entity_edges) <= 10: + # 获取实体名称 + source_name = triplet.subject_name + target_name = triplet.object_name + relationship_result = { + "result_type": "relationship_creation", + "relationship_index": len(entity_entity_edges), + "source_entity": source_name, + "relation_type": triplet.predicate, + "target_entity": target_name, + "relationship_text": f"{source_name} -[{triplet.predicate}]-> {target_name}", + "dialog_progress": f"{processed_dialogs}/{total_dialogs}" + } + await self.progress_callback( + "creating_nodes_edges_result", + f"关系创建中 ({processed_dialogs}/{total_dialogs})", + relationship_result + ) else: logger.warning( f"跳过三元组 - 无法找到实体ID: subject_id={triplet.subject_id}, " @@ -1008,12 +1048,9 @@ class ExtractionOrchestrator: f"实体-实体边: {len(entity_entity_edges)}" ) - # 进度回调:只输出关系创建结果 + # 进度回调:创建节点和边完成,传递结果统计 + # 注意:具体的关系创建结果已经在创建过程中实时发送了 if self.progress_callback: - # 输出关系创建结果 - await self._output_relationship_creation_results(entity_entity_edges, entity_nodes) - - # 进度回调:创建节点和边完成,传递结果统计 nodes_edges_stats = { "dialogue_nodes_count": len(dialogue_nodes), "chunk_nodes_count": len(chunk_nodes), @@ -1071,7 +1108,7 @@ class ExtractionOrchestrator: """ logger.info("开始两阶段实体去重和消歧") - # 进度回调:正在去重消歧 + # 进度回调:发送去重消歧开始消息 if self.progress_callback: await self.progress_callback("deduplication", "正在去重消歧...") @@ -1154,25 +1191,26 @@ class ExtractionOrchestrator: f"实体-实体边减少 {len(entity_entity_edges) - len(final_entity_entity_edges)}" ) - # 进度回调:输出去重消歧的具体结果 + # 流式输出:实时输出去重消歧的具体结果 if self.progress_callback: - # 分析实体合并情况 + # 分析实体合并情况(使用内存中的记录) merge_info = await self._analyze_entity_merges(entity_nodes, final_entity_nodes) - # 输出去重合并的实体示例 + # 逐个输出去重合并的实体示例 for i, merge_detail in enumerate(merge_info[:5]): # 输出前5个去重结果 dedup_result = { "result_type": "entity_merge", "merged_entity_name": merge_detail["main_entity_name"], "merged_count": merge_detail["merged_count"], + "merge_progress": f"{i + 1}/{min(len(merge_info), 5)}", "message": f"{merge_detail['main_entity_name']}合并{merge_detail['merged_count']}个:相似实体已合并" } - await self.progress_callback("dedup_disambiguation_result", "实体去重完成", dedup_result) + await self.progress_callback("dedup_disambiguation_result", "实体去重中", dedup_result) - # 分析实体消歧情况 + # 分析实体消歧情况(使用内存中的记录) disamb_info = await self._analyze_entity_disambiguation(entity_nodes, final_entity_nodes) - # 输出实体消歧的结果 + # 逐个输出实体消歧的结果 for i, disamb_detail in enumerate(disamb_info[:5]): # 输出前5个消歧结果 disamb_result = { "result_type": "entity_disambiguation", @@ -1180,11 +1218,10 @@ class ExtractionOrchestrator: "disambiguation_type": disamb_detail["disamb_type"], "confidence": disamb_detail.get("confidence", "unknown"), "reason": disamb_detail.get("reason", ""), + "disamb_progress": f"{i + 1}/{min(len(disamb_info), 5)}", "message": f"{disamb_detail['entity_name']}消歧完成:{disamb_detail['disamb_type']}" } - await self.progress_callback("dedup_disambiguation_result", "实体消歧完成", disamb_result) - - + await self.progress_callback("dedup_disambiguation_result", "实体消歧中", disamb_result) # 进度回调:去重消歧完成,传递去重和消歧的具体效果 await self._send_dedup_progress_callback(