Merge #13 into develop from fix/stream-output
'fix/stream-output' * fix/stream-output: (17 commits squashed) - [fix] Fix the issue where the streaming output effect is not noticeable. - [fix] Fix the issue where the streaming output effect is not noticeable. - Merge branch 'fix/stream-output' of codeup.aliyun.com:redbearai/python/redbear-mem-open into fix/stream-output - [fix] - [fix] Skip time extraction - [fix] - [fix] Skip time extraction - Merge branch 'fix/stream-output' of codeup.aliyun.com:redbearai/python/redbear-mem-open into fix/stream-output - [fix] Remove artificially added delays - [fix] Fix the issue where the streaming output effect is not noticeable. - [fix] - [fix] Skip time extraction - [fix] Fix the issue where the streaming output effect is not noticeable. - [fix] - [fix] Skip time extraction - [fix] Remove artificially added delays - Merge branch 'fix/stream-output' of codeup.aliyun.com:redbearai/python/redbear-mem-open into fix/stream-output Signed-off-by: 乐力齐 <accounts_690c7b0af9007d7e338af636@mail.teambition.com> Reviewed-by: aliyun6762716068 <accounts_68cb7c6b61f5dcc4200d6251@mail.teambition.com> Merged-by: aliyun6762716068 <accounts_68cb7c6b61f5dcc4200d6251@mail.teambition.com> CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/13
This commit is contained in:
@@ -179,8 +179,21 @@ class ExtractionOrchestrator:
|
||||
all_statements_list.extend(chunk.statements)
|
||||
total_statements = len(all_statements_list)
|
||||
|
||||
# 步骤 2: 并行执行三元组提取、时间信息提取和基础嵌入生成
|
||||
logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取和嵌入生成")
|
||||
# 🔥 陈述句提取完成后,立即发送知识抽取完成消息
|
||||
if self.progress_callback:
|
||||
extraction_stats = {
|
||||
"statements_count": total_statements,
|
||||
"entities_count": 0, # 暂时为0,后续会更新
|
||||
"triplets_count": 0, # 暂时为0,后续会更新
|
||||
"temporal_ranges_count": 0, # 暂时为0,后续会更新
|
||||
}
|
||||
await self.progress_callback("knowledge_extraction_complete", "知识抽取完成", extraction_stats)
|
||||
|
||||
# 🔥 立即发送下一阶段的开始消息,让前端知道进入了创建节点和边阶段
|
||||
await self.progress_callback("creating_nodes_edges", "正在创建节点和边...")
|
||||
|
||||
# 步骤 2: 并行执行三元组提取、时间信息提取和基础嵌入生成(后台静默执行)
|
||||
logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取和嵌入生成(后台静默执行)")
|
||||
(
|
||||
triplet_maps,
|
||||
temporal_maps,
|
||||
@@ -206,72 +219,6 @@ class ExtractionOrchestrator:
|
||||
logger.info("步骤 3/6: 生成实体嵌入")
|
||||
triplet_maps = await self._generate_entity_embeddings(triplet_maps)
|
||||
|
||||
# 进度回调:按三个阶段分别输出知识抽取结果
|
||||
if self.progress_callback:
|
||||
# 第一阶段:陈述句提取结果
|
||||
for i, stmt in enumerate(all_statements_list[:10]): # 只输出前10个陈述句
|
||||
stmt_result = {
|
||||
"extraction_type": "statement",
|
||||
"statement_index": i + 1,
|
||||
"statement": stmt.statement,
|
||||
"statement_id": stmt.id
|
||||
}
|
||||
await self.progress_callback("knowledge_extraction_result", "陈述句提取完成", stmt_result)
|
||||
|
||||
# 第二阶段:三元组提取结果
|
||||
for i, triplet in enumerate(all_triplets_list[:10]): # 只输出前10个三元组
|
||||
triplet_result = {
|
||||
"extraction_type": "triplet",
|
||||
"triplet_index": i + 1,
|
||||
"subject": triplet.subject_name,
|
||||
"predicate": triplet.predicate,
|
||||
"object": triplet.object_name
|
||||
}
|
||||
await self.progress_callback("knowledge_extraction_result", "三元组提取完成", triplet_result)
|
||||
|
||||
# 第三阶段:时间提取结果
|
||||
if total_temporal > 0:
|
||||
# 收集时间信息
|
||||
temporal_results = []
|
||||
for dialog in dialog_data_list:
|
||||
for chunk in dialog.chunks:
|
||||
for statement in chunk.statements:
|
||||
if hasattr(statement, 'temporal_validity') and statement.temporal_validity:
|
||||
temporal_results.append({
|
||||
"statement_id": statement.id,
|
||||
"statement": statement.statement,
|
||||
"valid_at": statement.temporal_validity.valid_at,
|
||||
"invalid_at": statement.temporal_validity.invalid_at
|
||||
})
|
||||
|
||||
# 输出时间提取结果
|
||||
for i, temporal_result in enumerate(temporal_results[:5]): # 只输出前5个时间提取结果
|
||||
time_result = {
|
||||
"extraction_type": "temporal",
|
||||
"temporal_index": i + 1,
|
||||
"statement": temporal_result["statement"],
|
||||
"valid_at": temporal_result["valid_at"],
|
||||
"invalid_at": temporal_result["invalid_at"]
|
||||
}
|
||||
await self.progress_callback("knowledge_extraction_result", "时间提取完成", time_result)
|
||||
else:
|
||||
# 如果没有时间信息,也发送一个时间提取完成的消息
|
||||
time_result = {
|
||||
"extraction_type": "temporal",
|
||||
"temporal_index": 0,
|
||||
"message": "未发现时间信息"
|
||||
}
|
||||
await self.progress_callback("knowledge_extraction_result", "时间提取完成", time_result)
|
||||
|
||||
# 进度回调:知识抽取完成,传递知识抽取的统计信息
|
||||
extraction_stats = {
|
||||
"statements_count": total_statements,
|
||||
"entities_count": total_entities,
|
||||
"triplets_count": total_triplets,
|
||||
"temporal_ranges_count": total_temporal,
|
||||
}
|
||||
await self.progress_callback("knowledge_extraction_complete", "知识抽取完成", extraction_stats)
|
||||
|
||||
# 步骤 4: 将提取的数据赋值到语句
|
||||
logger.info("步骤 4/6: 数据赋值")
|
||||
dialog_data_list = await self._assign_extracted_data(
|
||||
@@ -285,6 +232,9 @@ class ExtractionOrchestrator:
|
||||
|
||||
# 步骤 5: 创建节点和边
|
||||
logger.info("步骤 5/6: 创建节点和边")
|
||||
|
||||
# 注意:creating_nodes_edges 消息已在知识抽取完成后立即发送
|
||||
|
||||
(
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
@@ -304,6 +254,8 @@ class ExtractionOrchestrator:
|
||||
else:
|
||||
logger.info("步骤 6/6: 两阶段去重和消歧")
|
||||
|
||||
# 注意:deduplication 消息已在创建节点和边完成后立即发送
|
||||
|
||||
result = await self._run_dedup_and_write_summary(
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
@@ -328,7 +280,7 @@ class ExtractionOrchestrator:
|
||||
self, dialog_data_list: List[DialogData]
|
||||
) -> List[DialogData]:
|
||||
"""
|
||||
从对话中提取陈述句(优化版:全局分块级并行)
|
||||
从对话中提取陈述句(流式输出版本:边提取边发送进度)
|
||||
|
||||
Args:
|
||||
dialog_data_list: 对话数据列表
|
||||
@@ -336,7 +288,7 @@ class ExtractionOrchestrator:
|
||||
Returns:
|
||||
更新后的对话数据列表(包含提取的陈述句)
|
||||
"""
|
||||
logger.info("开始陈述句提取(全局分块级并行)")
|
||||
logger.info("开始陈述句提取(全局分块级并行 + 流式输出)")
|
||||
|
||||
# 收集所有分块及其元数据
|
||||
all_chunks = []
|
||||
@@ -349,17 +301,44 @@ class ExtractionOrchestrator:
|
||||
chunk_metadata.append((d_idx, c_idx))
|
||||
|
||||
logger.info(f"收集到 {len(all_chunks)} 个分块,开始全局并行提取")
|
||||
|
||||
# 用于跟踪已完成的分块数量
|
||||
completed_chunks = 0
|
||||
total_chunks = len(all_chunks)
|
||||
|
||||
# 全局并行处理所有分块
|
||||
async def extract_for_chunk(chunk_data):
|
||||
async def extract_for_chunk(chunk_data, chunk_index):
|
||||
nonlocal completed_chunks
|
||||
chunk, group_id, dialogue_content = chunk_data
|
||||
try:
|
||||
return await self.statement_extractor._extract_statements(chunk, group_id, dialogue_content)
|
||||
statements = await self.statement_extractor._extract_statements(chunk, group_id, dialogue_content)
|
||||
|
||||
# 流式输出:每提取完一个分块的陈述句,立即发送进度
|
||||
# 注意:只在试运行模式下发送陈述句详情,正式模式不发送
|
||||
completed_chunks += 1
|
||||
if self.progress_callback and statements and self.is_pilot_run:
|
||||
# 发送前3个陈述句作为示例
|
||||
for idx, stmt in enumerate(statements[:3]):
|
||||
stmt_result = {
|
||||
"extraction_type": "statement",
|
||||
"statement": stmt.statement,
|
||||
"statement_id": stmt.id,
|
||||
"chunk_progress": f"{completed_chunks}/{total_chunks}",
|
||||
"statement_index_in_chunk": idx + 1
|
||||
}
|
||||
await self.progress_callback(
|
||||
"knowledge_extraction_result",
|
||||
f"陈述句提取中 ({completed_chunks}/{total_chunks})",
|
||||
stmt_result
|
||||
)
|
||||
|
||||
return statements
|
||||
except Exception as e:
|
||||
logger.error(f"分块 {chunk.id} 陈述句提取失败: {e}")
|
||||
completed_chunks += 1
|
||||
return []
|
||||
|
||||
tasks = [extract_for_chunk(chunk_data) for chunk_data in all_chunks]
|
||||
tasks = [extract_for_chunk(chunk_data, i) for i, chunk_data in enumerate(all_chunks)]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 将结果分配回对话
|
||||
@@ -391,7 +370,7 @@ class ExtractionOrchestrator:
|
||||
self, dialog_data_list: List[DialogData]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
从对话中提取三元组(优化版:全局陈述句级并行)
|
||||
从对话中提取三元组(流式输出版本:边提取边发送进度)
|
||||
|
||||
Args:
|
||||
dialog_data_list: 对话数据列表
|
||||
@@ -399,7 +378,7 @@ class ExtractionOrchestrator:
|
||||
Returns:
|
||||
三元组映射列表,每个对话对应一个字典
|
||||
"""
|
||||
logger.info("开始三元组提取(全局陈述句级并行)")
|
||||
logger.info("开始三元组提取(全局陈述句级并行 + 流式输出)")
|
||||
|
||||
# 收集所有陈述句及其元数据
|
||||
all_statements = []
|
||||
@@ -412,18 +391,30 @@ class ExtractionOrchestrator:
|
||||
statement_metadata.append((d_idx, statement.id))
|
||||
|
||||
logger.info(f"收集到 {len(all_statements)} 个陈述句,开始全局并行提取三元组")
|
||||
|
||||
# 用于跟踪已完成的陈述句数量
|
||||
completed_statements = 0
|
||||
total_statements = len(all_statements)
|
||||
|
||||
# 全局并行处理所有陈述句
|
||||
async def extract_for_statement(stmt_data):
|
||||
async def extract_for_statement(stmt_data, stmt_index):
|
||||
nonlocal completed_statements
|
||||
statement, chunk_content = stmt_data
|
||||
try:
|
||||
return await self.triplet_extractor._extract_triplets(statement, chunk_content)
|
||||
triplet_info = await self.triplet_extractor._extract_triplets(statement, chunk_content)
|
||||
|
||||
# 注意:不再发送三元组提取的流式输出
|
||||
# 三元组提取在后台执行,但不向前端发送详细信息
|
||||
completed_statements += 1
|
||||
|
||||
return triplet_info
|
||||
except Exception as e:
|
||||
logger.error(f"陈述句 {statement.id} 三元组提取失败: {e}")
|
||||
completed_statements += 1
|
||||
from app.core.memory.models.triplet_models import TripletExtractionResponse
|
||||
return TripletExtractionResponse(triplets=[], entities=[])
|
||||
|
||||
tasks = [extract_for_statement(stmt_data) for stmt_data in all_statements]
|
||||
tasks = [extract_for_statement(stmt_data, i) for i, stmt_data in enumerate(all_statements)]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 将结果组织成对话级别的映射
|
||||
@@ -458,7 +449,7 @@ class ExtractionOrchestrator:
|
||||
self, dialog_data_list: List[DialogData]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
从对话中提取时间信息(优化版:全局陈述句级并行)
|
||||
从对话中提取时间信息(流式输出版本:边提取边发送进度)
|
||||
|
||||
Args:
|
||||
dialog_data_list: 对话数据列表
|
||||
@@ -466,7 +457,21 @@ class ExtractionOrchestrator:
|
||||
Returns:
|
||||
时间信息映射列表,每个对话对应一个字典
|
||||
"""
|
||||
logger.info("开始时间信息提取(全局陈述句级并行)")
|
||||
# 试运行模式:跳过时间提取以节省时间
|
||||
if self.is_pilot_run:
|
||||
logger.info("试运行模式:跳过时间信息提取(节省约 10-15 秒)")
|
||||
# 为所有陈述句返回空的时间范围
|
||||
from app.core.memory.models.message_models import TemporalValidityRange
|
||||
temporal_maps = []
|
||||
for dialog in dialog_data_list:
|
||||
temporal_map = {}
|
||||
for chunk in dialog.chunks:
|
||||
for statement in chunk.statements:
|
||||
temporal_map[statement.id] = TemporalValidityRange(valid_at=None, invalid_at=None)
|
||||
temporal_maps.append(temporal_map)
|
||||
return temporal_maps
|
||||
|
||||
logger.info("开始时间信息提取(全局陈述句级并行 + 流式输出)")
|
||||
|
||||
# 收集所有需要提取时间的陈述句
|
||||
all_statements = []
|
||||
@@ -494,18 +499,30 @@ class ExtractionOrchestrator:
|
||||
statement_metadata.append((d_idx, statement.id))
|
||||
|
||||
logger.info(f"收集到 {len(all_statements)} 个需要时间提取的陈述句,开始全局并行提取")
|
||||
|
||||
# 用于跟踪已完成的时间提取数量
|
||||
completed_temporal = 0
|
||||
total_temporal_statements = len(all_statements)
|
||||
|
||||
# 全局并行处理所有陈述句
|
||||
async def extract_for_statement(stmt_data):
|
||||
async def extract_for_statement(stmt_data, stmt_index):
|
||||
nonlocal completed_temporal
|
||||
statement, ref_dates = stmt_data
|
||||
try:
|
||||
return await self.temporal_extractor._extract_temporal_ranges(statement, ref_dates)
|
||||
temporal_range = await self.temporal_extractor._extract_temporal_ranges(statement, ref_dates)
|
||||
|
||||
# 注意:不再发送时间提取的流式输出
|
||||
# 时间提取在后台执行,但不向前端发送详细信息
|
||||
completed_temporal += 1
|
||||
|
||||
return temporal_range
|
||||
except Exception as e:
|
||||
logger.error(f"陈述句 {statement.id} 时间信息提取失败: {e}")
|
||||
completed_temporal += 1
|
||||
from app.core.memory.models.message_models import TemporalValidityRange
|
||||
return TemporalValidityRange(valid_at=None, invalid_at=None)
|
||||
|
||||
tasks = [extract_for_statement(stmt_data) for stmt_data in all_statements]
|
||||
tasks = [extract_for_statement(stmt_data, i) for i, stmt_data in enumerate(all_statements)]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 将结果组织成对话级别的映射
|
||||
@@ -832,9 +849,7 @@ class ExtractionOrchestrator:
|
||||
"""
|
||||
logger.info("开始创建节点和边")
|
||||
|
||||
# 进度回调:正在创建节点和边
|
||||
if self.progress_callback:
|
||||
await self.progress_callback("creating_nodes_edges", "正在创建节点和边...")
|
||||
# 注意:开始消息已在 run 方法中发送,这里不再重复发送
|
||||
|
||||
dialogue_nodes = []
|
||||
chunk_nodes = []
|
||||
@@ -846,8 +861,13 @@ class ExtractionOrchestrator:
|
||||
|
||||
# 用于去重的集合
|
||||
entity_id_set = set()
|
||||
|
||||
# 用于跟踪进度
|
||||
total_dialogs = len(dialog_data_list)
|
||||
processed_dialogs = 0
|
||||
|
||||
for dialog_data in dialog_data_list:
|
||||
processed_dialogs += 1
|
||||
# 创建对话节点
|
||||
dialogue_node = DialogueNode(
|
||||
id=dialog_data.id,
|
||||
@@ -994,6 +1014,26 @@ class ExtractionOrchestrator:
|
||||
expired_at=dialog_data.expired_at,
|
||||
)
|
||||
entity_entity_edges.append(entity_entity_edge)
|
||||
|
||||
# 流式输出:每创建一个关系边,立即发送进度(限制发送数量)
|
||||
if self.progress_callback and len(entity_entity_edges) <= 10:
|
||||
# 获取实体名称
|
||||
source_name = triplet.subject_name
|
||||
target_name = triplet.object_name
|
||||
relationship_result = {
|
||||
"result_type": "relationship_creation",
|
||||
"relationship_index": len(entity_entity_edges),
|
||||
"source_entity": source_name,
|
||||
"relation_type": triplet.predicate,
|
||||
"target_entity": target_name,
|
||||
"relationship_text": f"{source_name} -[{triplet.predicate}]-> {target_name}",
|
||||
"dialog_progress": f"{processed_dialogs}/{total_dialogs}"
|
||||
}
|
||||
await self.progress_callback(
|
||||
"creating_nodes_edges_result",
|
||||
f"关系创建中 ({processed_dialogs}/{total_dialogs})",
|
||||
relationship_result
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"跳过三元组 - 无法找到实体ID: subject_id={triplet.subject_id}, "
|
||||
@@ -1008,12 +1048,9 @@ class ExtractionOrchestrator:
|
||||
f"实体-实体边: {len(entity_entity_edges)}"
|
||||
)
|
||||
|
||||
# 进度回调:只输出关系创建结果
|
||||
# 进度回调:创建节点和边完成,传递结果统计
|
||||
# 注意:具体的关系创建结果已经在创建过程中实时发送了
|
||||
if self.progress_callback:
|
||||
# 输出关系创建结果
|
||||
await self._output_relationship_creation_results(entity_entity_edges, entity_nodes)
|
||||
|
||||
# 进度回调:创建节点和边完成,传递结果统计
|
||||
nodes_edges_stats = {
|
||||
"dialogue_nodes_count": len(dialogue_nodes),
|
||||
"chunk_nodes_count": len(chunk_nodes),
|
||||
@@ -1071,7 +1108,7 @@ class ExtractionOrchestrator:
|
||||
"""
|
||||
logger.info("开始两阶段实体去重和消歧")
|
||||
|
||||
# 进度回调:正在去重消歧
|
||||
# 进度回调:发送去重消歧开始消息
|
||||
if self.progress_callback:
|
||||
await self.progress_callback("deduplication", "正在去重消歧...")
|
||||
|
||||
@@ -1154,25 +1191,26 @@ class ExtractionOrchestrator:
|
||||
f"实体-实体边减少 {len(entity_entity_edges) - len(final_entity_entity_edges)}"
|
||||
)
|
||||
|
||||
# 进度回调:输出去重消歧的具体结果
|
||||
# 流式输出:实时输出去重消歧的具体结果
|
||||
if self.progress_callback:
|
||||
# 分析实体合并情况
|
||||
# 分析实体合并情况(使用内存中的记录)
|
||||
merge_info = await self._analyze_entity_merges(entity_nodes, final_entity_nodes)
|
||||
|
||||
# 输出去重合并的实体示例
|
||||
# 逐个输出去重合并的实体示例
|
||||
for i, merge_detail in enumerate(merge_info[:5]): # 输出前5个去重结果
|
||||
dedup_result = {
|
||||
"result_type": "entity_merge",
|
||||
"merged_entity_name": merge_detail["main_entity_name"],
|
||||
"merged_count": merge_detail["merged_count"],
|
||||
"merge_progress": f"{i + 1}/{min(len(merge_info), 5)}",
|
||||
"message": f"{merge_detail['main_entity_name']}合并{merge_detail['merged_count']}个:相似实体已合并"
|
||||
}
|
||||
await self.progress_callback("dedup_disambiguation_result", "实体去重完成", dedup_result)
|
||||
await self.progress_callback("dedup_disambiguation_result", "实体去重中", dedup_result)
|
||||
|
||||
# 分析实体消歧情况
|
||||
# 分析实体消歧情况(使用内存中的记录)
|
||||
disamb_info = await self._analyze_entity_disambiguation(entity_nodes, final_entity_nodes)
|
||||
|
||||
# 输出实体消歧的结果
|
||||
# 逐个输出实体消歧的结果
|
||||
for i, disamb_detail in enumerate(disamb_info[:5]): # 输出前5个消歧结果
|
||||
disamb_result = {
|
||||
"result_type": "entity_disambiguation",
|
||||
@@ -1180,11 +1218,10 @@ class ExtractionOrchestrator:
|
||||
"disambiguation_type": disamb_detail["disamb_type"],
|
||||
"confidence": disamb_detail.get("confidence", "unknown"),
|
||||
"reason": disamb_detail.get("reason", ""),
|
||||
"disamb_progress": f"{i + 1}/{min(len(disamb_info), 5)}",
|
||||
"message": f"{disamb_detail['entity_name']}消歧完成:{disamb_detail['disamb_type']}"
|
||||
}
|
||||
await self.progress_callback("dedup_disambiguation_result", "实体消歧完成", disamb_result)
|
||||
|
||||
|
||||
await self.progress_callback("dedup_disambiguation_result", "实体消歧中", disamb_result)
|
||||
|
||||
# 进度回调:去重消歧完成,传递去重和消歧的具体效果
|
||||
await self._send_dedup_progress_callback(
|
||||
|
||||
Reference in New Issue
Block a user