[add] Display semantic pruning results in streaming output
This commit is contained in:
@@ -101,14 +101,101 @@ async def run_pilot_extraction(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
await progress_callback("text_preprocessing", "开始预处理文本...")
|
await progress_callback("text_preprocessing", "开始预处理文本(语义剪枝 + 语义分块)...")
|
||||||
|
|
||||||
|
# ========== 步骤 2.1: 语义剪枝 ==========
|
||||||
|
pruned_dialogs = [dialog]
|
||||||
|
deleted_messages = [] # 记录被删除的消息
|
||||||
|
|
||||||
|
if memory_config.pruning_enabled:
|
||||||
|
try:
|
||||||
|
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import (
|
||||||
|
SemanticPruner,
|
||||||
|
)
|
||||||
|
from app.core.memory.models.config_models import PruningConfig
|
||||||
|
|
||||||
|
# 构建剪枝配置
|
||||||
|
pruning_config_dict = {
|
||||||
|
"pruning_switch": memory_config.pruning_enabled,
|
||||||
|
"pruning_scene": memory_config.pruning_scene,
|
||||||
|
"pruning_threshold": memory_config.pruning_threshold,
|
||||||
|
"llm_model_id": str(memory_config.llm_model_id),
|
||||||
|
}
|
||||||
|
config = PruningConfig(**pruning_config_dict)
|
||||||
|
|
||||||
|
logger.info(f"[PILOT_RUN] 开始语义剪枝: scene={config.pruning_scene}, threshold={config.pruning_threshold}")
|
||||||
|
|
||||||
|
# 记录剪枝前的消息(用于对比)
|
||||||
|
original_messages = [{"role": msg.role, "content": msg.msg} for msg in dialog.context.msgs]
|
||||||
|
original_msg_count = len(original_messages)
|
||||||
|
|
||||||
|
# 执行剪枝
|
||||||
|
pruner = SemanticPruner(config=config, llm_client=llm_client)
|
||||||
|
pruned_dialogs = await pruner.prune_dataset([dialog])
|
||||||
|
|
||||||
|
# 计算剪枝结果并找出被删除的消息
|
||||||
|
if pruned_dialogs and pruned_dialogs[0].context:
|
||||||
|
remaining_messages = [{"role": msg.role, "content": msg.msg} for msg in pruned_dialogs[0].context.msgs]
|
||||||
|
remaining_msg_count = len(remaining_messages)
|
||||||
|
deleted_msg_count = original_msg_count - remaining_msg_count
|
||||||
|
|
||||||
|
# 找出被删除的消息(通过内容对比)
|
||||||
|
remaining_contents = {msg["content"] for msg in remaining_messages}
|
||||||
|
deleted_messages = [
|
||||||
|
{"index": idx, "role": msg["role"], "content": msg["content"]}
|
||||||
|
for idx, msg in enumerate(original_messages)
|
||||||
|
if msg["content"] not in remaining_contents
|
||||||
|
]
|
||||||
|
|
||||||
|
pruning_result = {
|
||||||
|
"enabled": True,
|
||||||
|
"scene": config.pruning_scene,
|
||||||
|
"threshold": config.pruning_threshold,
|
||||||
|
"original_count": original_msg_count,
|
||||||
|
"remaining_count": remaining_msg_count,
|
||||||
|
"deleted_count": deleted_msg_count,
|
||||||
|
"deleted_messages": deleted_messages,
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[PILOT_RUN] 语义剪枝完成: 原始{original_msg_count}条 -> "
|
||||||
|
f"保留{remaining_msg_count}条 (删除{deleted_msg_count}条)"
|
||||||
|
)
|
||||||
|
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback("text_preprocessing_pruning", "语义剪枝完成", pruning_result)
|
||||||
|
else:
|
||||||
|
logger.warning("[PILOT_RUN] 剪枝后对话为空,使用原始对话")
|
||||||
|
pruned_dialogs = [dialog]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[PILOT_RUN] 语义剪枝失败,使用原始对话: {e}", exc_info=True)
|
||||||
|
pruned_dialogs = [dialog]
|
||||||
|
if progress_callback:
|
||||||
|
error_result = {
|
||||||
|
"enabled": True,
|
||||||
|
"error": str(e),
|
||||||
|
"fallback": "使用原始对话"
|
||||||
|
}
|
||||||
|
await progress_callback("text_preprocessing_pruning", "语义剪枝失败", error_result)
|
||||||
|
else:
|
||||||
|
logger.info("[PILOT_RUN] 语义剪枝已关闭,跳过")
|
||||||
|
if progress_callback:
|
||||||
|
pruning_result = {
|
||||||
|
"enabled": False,
|
||||||
|
"message": "语义剪枝已关闭"
|
||||||
|
}
|
||||||
|
await progress_callback("text_preprocessing_pruning", "语义剪枝已关闭", pruning_result)
|
||||||
|
|
||||||
|
# ========== 步骤 2.2: 语义分块 ==========
|
||||||
chunked_dialogs = await get_chunked_dialogs_from_preprocessed(
|
chunked_dialogs = await get_chunked_dialogs_from_preprocessed(
|
||||||
data=[dialog],
|
data=pruned_dialogs,
|
||||||
chunker_strategy=memory_config.chunker_strategy,
|
chunker_strategy=memory_config.chunker_strategy,
|
||||||
llm_client=llm_client,
|
llm_client=llm_client,
|
||||||
)
|
)
|
||||||
logger.info(f"Processed dialogue text: {len(messages)} messages")
|
|
||||||
|
remaining_msg_count = len(pruned_dialogs[0].context.msgs) if pruned_dialogs and pruned_dialogs[0].context else 0
|
||||||
|
logger.info(f"Processed dialogue text: {remaining_msg_count} messages after pruning")
|
||||||
|
|
||||||
# 进度回调:输出每个分块的结果
|
# 进度回调:输出每个分块的结果
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
@@ -121,14 +208,14 @@ async def run_pilot_extraction(
|
|||||||
"dialog_id": dlg.id,
|
"dialog_id": dlg.id,
|
||||||
"chunker_strategy": memory_config.chunker_strategy,
|
"chunker_strategy": memory_config.chunker_strategy,
|
||||||
}
|
}
|
||||||
await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)
|
await progress_callback("text_preprocessing_chunking", f"分块 {i + 1} 处理完成", chunk_result)
|
||||||
|
|
||||||
preprocessing_summary = {
|
preprocessing_summary = {
|
||||||
"total_chunks": sum(len(dlg.chunks) for dlg in chunked_dialogs),
|
"total_chunks": sum(len(dlg.chunks) for dlg in chunked_dialogs),
|
||||||
"total_dialogs": len(chunked_dialogs),
|
"total_dialogs": len(chunked_dialogs),
|
||||||
"chunker_strategy": memory_config.chunker_strategy,
|
"chunker_strategy": memory_config.chunker_strategy,
|
||||||
}
|
}
|
||||||
await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary)
|
await progress_callback("text_preprocessing_complete", "预处理文本完成(剪枝 + 分块)", preprocessing_summary)
|
||||||
|
|
||||||
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
|
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user