[fix]Address the shortcomings of intelligent pruning

This commit is contained in:
lanceyq
2026-02-27 16:09:22 +08:00
parent f7d92be5ea
commit 5253cf3899

View File

@@ -5,14 +5,17 @@
- 对话级一次性抽取判定相关性
- 仅对"不相关对话"的消息按比例删除
- 重要信息(时间、编号、金额、联系方式、地址等)优先保留
- 改进版:增强重要性判断、智能填充消息识别、问答对保护、并发优化
"""
import asyncio
import os
import hashlib
import json
import re
from collections import OrderedDict
from datetime import datetime
from typing import List, Optional
from typing import List, Optional, Dict, Tuple, Set
from pydantic import BaseModel, Field
from app.core.memory.models.message_models import DialogData, ConversationMessage, ConversationContext
@@ -36,6 +39,23 @@ class DialogExtractionResponse(BaseModel):
keywords: List[str] = Field(default_factory=list)
class MessageImportanceResponse(BaseModel):
"""消息重要性批量判断的结构化返回用于LLM语义判断
- importance_scores: 消息索引到重要性分数的映射 (0-10分)
- reasons: 可选的判断理由
"""
importance_scores: Dict[int, int] = Field(default_factory=dict, description="消息索引到重要性分数(0-10)的映射")
reasons: Optional[Dict[int, str]] = Field(default_factory=dict, description="可选的判断理由")
class QAPair(BaseModel):
"""问答对模型,用于识别和保护对话中的问答结构。"""
question_idx: int = Field(..., description="问题消息的索引")
answer_idx: int = Field(..., description="答案消息的索引")
confidence: float = Field(default=1.0, description="问答对的置信度(0-1)")
class SemanticPruner:
"""语义剪枝:在预处理与分块之间过滤与场景不相关内容。
@@ -43,109 +63,353 @@ class SemanticPruner:
重要信息(时间、编号、金额、联系方式、地址等)优先保留。
"""
def __init__(self, config: Optional[PruningConfig] = None, llm_client=None):
cfg_dict = get_pruning_config() if config is None else config.model_dump()
self.config = PruningConfig.model_validate(cfg_dict)
def __init__(self, config: Optional[PruningConfig] = None, llm_client=None, language: str = "zh", max_concurrent: int = 5):
# 如果没有提供config使用默认配置
if config is None:
# 使用默认的剪枝配置
config = PruningConfig(
pruning_switch=False, # 默认关闭剪枝,保持向后兼容
pruning_scene="education",
pruning_threshold=0.5
)
self.config = config
self.llm_client = llm_client
self.language = language # 保存语言配置
self.max_concurrent = max_concurrent # 新增:最大并发数
# Load Jinja2 template
self.template = prompt_env.get_template("extracat_Pruning.jinja2")
# 对话抽取缓存:避免同一对话重复调用 LLM / 重复渲染
self._dialog_extract_cache: dict[str, DialogExtractionResponse] = {}
# 对话抽取缓存:使用 OrderedDict 实现 LRU 缓存
self._dialog_extract_cache: OrderedDict[str, DialogExtractionResponse] = OrderedDict()
self._cache_max_size = 1000 # 缓存大小限制
# 运行日志:收集关键终端输出,便于写入 JSON
self.run_logs: List[str] = []
# 采用顺序处理,移除并发配置以简化与稳定执行
# 扩展的填充词库(包含表情符号和网络用语)
self._extended_fillers = [
# 基础寒暄
"你好", "您好", "在吗", "在的", "在呢", "", "嗯嗯", "", "哦哦",
"好的", "", "", "可以", "不可以", "谢谢", "多谢", "感谢",
"拜拜", "再见", "88", "", "回见",
# 口头禅
"哈哈", "呵呵", "哈哈哈", "嘿嘿", "嘻嘻", "hiahia",
"", "", "", "", "", "", "嗯哼",
# 确认词
"是的", "", "对的", "没错", "嗯嗯", "好嘞", "收到", "明白", "了解", "知道了",
# 标点和符号
"。。。", "...", "???", "", "!!!", "",
# 表情符号(文本形式)
"[微笑]", "[呲牙]", "[发呆]", "[得意]", "[流泪]", "[害羞]", "[闭嘴]",
"[睡]", "[大哭]", "[尴尬]", "[发怒]", "[调皮]", "[龇牙]", "[惊讶]",
"[难过]", "[酷]", "[冷汗]", "[抓狂]", "[吐]", "[偷笑]", "[可爱]",
"[白眼]", "[傲慢]", "[饥饿]", "[困]", "[惊恐]", "[流汗]", "[憨笑]",
# 网络用语
"hhh", "hhhh", "2333", "666", "gg", "ok", "OK", "okok",
"emmm", "emm", "em", "mmp", "wtf", "omg",
]
def _is_important_message(self, message: ConversationMessage) -> bool:
"""基于启发式规则识别重要信息消息,优先保留。
- 含日期/时间如YYYY-MM-DD、HH:MM、2024年11月10日、上午/下午)。
- 含编号/ID/订单号/申请号/账号/电话/金额等关键字段。
- 关键词:"时间""日期""编号""订单""流水""金额""""""电话""手机号""邮箱""地址"
改进版:增强了规则覆盖范围,包括:
- 含日期/时间如YYYY-MM-DD、HH:MM、2024年11月10日、上午/下午)
- 含编号/ID/订单号/申请号/账号/电话/金额等关键字段
- 关键词:"时间""日期""编号""订单""流水""金额""""""电话""手机号""邮箱""地址"
- 新增:问句识别、决策性语句、承诺性语句
"""
import re
text = message.msg.strip()
if not text:
return False
patterns = [
r"\b\d{4}-\d{1,2}-\d{1,2}\b",
r"\b\d{1,2}:\d{2}\b",
# 原有模式
r"\d{4}-\d{1,2}-\d{1,2}", # 修复:移除 \b 边界,因为中文前后没有单词边界
r"\d{1,2}:\d{2}", # 修复:移除 \b
r"\d{4}\d{1,2}月\d{1,2}日",
r"上午|下午|AM|PM",
r"订单号|工单|申请号|编号|ID|账号|账户",
r"电话|手机号|微信|QQ|邮箱",
r"地址|地点",
r"金额|费用|价格|¥|¥|\d+元",
r"时间|日期|有效期|截止",
r"上午|下午|AM|PM|今天|明天|后天|昨天|前天|本周|下周|上周|本月|下月|上月",
r"订单号|工单|申请号|编号|ID|账号|账户|流水号|单号",
r"电话|手机号|微信|QQ|邮箱|联系方式",
r"地址|地点|位置|门牌号",
r"金额|费用|价格|¥|¥|\d+元|人民币|美元|欧元",
r"时间|日期|有效期|截止|期限|到期",
# 新增模式
r"什么|为什么|怎么|如何|哪里|哪个|谁|多少|几点|何时", # 问句关键词
r"必须|一定|务必|需要|要求|规定|应该", # 决策性语句
r"承诺|保证|确保|负责|同意|答应", # 承诺性语句
r"\d{11}", # 11位手机号
r"\d{3,4}-\d{7,8}", # 固定电话
r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", # 邮箱
]
for p in patterns:
if re.search(p, text, flags=re.IGNORECASE):
return True
# 检查是否为问句(以问号结尾或包含疑问词)
if text.endswith("") or text.endswith("?"):
return True
return False
def _importance_score(self, message: ConversationMessage) -> int:
"""为重要消息打分,用于在保留比例内优先保留更关键的内容。
简单启发:匹配到的类别越多、越关键分值越高。
改进版更细致的评分体系0-10分
"""
import re
text = message.msg.strip()
score = 0
weights = [
(r"\b\d{4}-\d{1,2}-\d{1,2}\b", 3),
(r"\b\d{1,2}:\d{2}\b", 2),
# 高优先级4-5分
(r"订单号|工单|申请号|编号|ID|账号|账户", 5),
(r"金额|费用|价格|¥|¥|\d+元", 5),
(r"\d{11}", 4), # 手机号
(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", 4), # 邮箱
# 中优先级2-3分
(r"\d{4}-\d{1,2}-\d{1,2}", 3), # 修复:移除 \b
(r"\d{4}\d{1,2}月\d{1,2}日", 3),
(r"订单号|工单|申请号|编号|ID|账号|账户", 4),
(r"电话|手机号|微信|QQ|邮箱", 3),
(r"地址|地点", 2),
(r"金额|费用|价格|¥|¥|\d+元", 4),
(r"时间|日期|有效期|截止", 2),
(r"电话|手机号|微信|QQ|联系方式", 3),
(r"地址|地点|位置", 2),
(r"时间|日期|有效期|截止|明天|后天|下周|下月", 2), # 新增时间相关词
# 低优先级1分
(r"\d{1,2}:\d{2}", 1), # 修复:移除 \b
(r"上午|下午|AM|PM", 1),
]
for p, w in weights:
if re.search(p, text, flags=re.IGNORECASE):
score += w
return score
# 问句加分
if text.endswith("") or text.endswith("?"):
score += 2
# 长度加分(较长的消息通常包含更多信息)
if len(text) > 50:
score += 1
if len(text) > 100:
score += 1
return min(score, 10) # 最高10分
def _is_filler_message(self, message: ConversationMessage) -> bool:
"""检测典型寒暄/口头禅/确认类短消息用于跳过LLM分类以加速。
改进版:扩展了填充词库,支持表情符号和网络用语
满足以下之一视为填充消息:
- 纯标点或长度很短(<= 4 个汉字或 <= 8 个字符)且不包含数字或关键实体
- 常见词:你好/您好/在吗/嗯/嗯嗯/哦/好的/好/行/可以/不可以/谢谢/拜拜/再见/哈哈/呵呵/哈哈哈/。。。/??。
- 纯标点或长度很短(<= 4 个汉字或 <= 8 个字符)且不包含数字或关键实体
- 在扩展填充词库中
- 纯表情符号
"""
import re
t = message.msg.strip()
if not t:
return True
# 常见填充语
fillers = [
"你好", "您好", "在吗", "", "嗯嗯", "", "好的", "", "", "可以", "不可以", "谢谢",
"拜拜", "再见", "哈哈", "呵呵", "哈哈哈", "。。。", "??", ""
]
if t in fillers:
# 检查是否在扩展填充词库中
if t in self._extended_fillers:
return True
# 检查是否为纯表情符号(方括号包裹)
if re.fullmatch(r"(\[[^\]]+\])+", t):
return True
# 检查是否为纯emojiUnicode表情
emoji_pattern = re.compile(
"["
"\U0001F600-\U0001F64F" # 表情符号
"\U0001F300-\U0001F5FF" # 符号和象形文字
"\U0001F680-\U0001F6FF" # 交通和地图符号
"\U0001F1E0-\U0001F1FF" # 旗帜
"\U00002702-\U000027B0"
"\U000024C2-\U0001F251"
"]+", flags=re.UNICODE
)
if emoji_pattern.fullmatch(t):
return True
# 长度与字符类型判断
if len(t) <= 8:
# 非数字、无关键实体的短文本
if not re.search(r"[0-9]", t) and not self._is_important_message(message):
# 主要是标点或简单确认词
if re.fullmatch(r"[。!?,.!?…·\s]+", t) or t in fillers:
if re.fullmatch(r"[。!?,.!?…·\s]+", t):
return True
return False
async def _batch_evaluate_importance_with_llm(
self,
messages: List[ConversationMessage],
context: str = ""
) -> Dict[int, int]:
"""使用LLM批量评估消息的重要性语义层面
Args:
messages: 消息列表
context: 对话上下文(可选)
Returns:
消息索引到重要性分数(0-10)的映射
"""
if not self.llm_client or not messages:
return {}
# 构建批量评估的提示词
msg_list = []
for idx, msg in enumerate(messages):
msg_list.append(f"{idx}. {msg.msg}")
msg_text = "\n".join(msg_list)
prompt = f"""请评估以下消息的重要性给每条消息打分0-10分
- 0-2分无意义的寒暄、口头禅、纯表情
- 3-5分一般性对话有一定信息量但不关键
- 6-8分包含重要信息时间、地点、人物、事件等
- 9-10分关键决策、承诺、重要数据
对话上下文:
{context if context else ""}
待评估的消息:
{msg_text}
请以JSON格式返回格式为
{{
"importance_scores": {{
"0": 分数,
"1": 分数,
...
}}
}}
"""
try:
messages_for_llm = [
{"role": "system", "content": "你是一个专业的对话分析助手,擅长评估消息的重要性。"},
{"role": "user", "content": prompt}
]
response = await self.llm_client.response_structured(
messages_for_llm,
MessageImportanceResponse
)
# 转换字符串键为整数键
return {int(k): v for k, v in response.importance_scores.items()}
except Exception as e:
self._log(f"[剪枝-LLM] 批量重要性评估失败: {str(e)[:100]}")
return {}
def _identify_qa_pairs(self, messages: List[ConversationMessage]) -> List[QAPair]:
"""识别对话中的问答对,用于保护问答结构的完整性。
Args:
messages: 消息列表
Returns:
问答对列表
"""
qa_pairs = []
for i in range(len(messages) - 1):
current_msg = messages[i].msg.strip()
next_msg = messages[i + 1].msg.strip()
# 简单规则:如果当前消息是问句,下一条消息可能是答案
is_question = (
current_msg.endswith("") or
current_msg.endswith("?") or
any(word in current_msg for word in ["什么", "为什么", "怎么", "如何", "哪里", "哪个", "", "多少", "几点", "何时", ""])
)
if is_question and next_msg:
# 检查下一条消息是否像答案(不是另一个问句)
is_answer = not (next_msg.endswith("") or next_msg.endswith("?"))
if is_answer:
qa_pairs.append(QAPair(
question_idx=i,
answer_idx=i + 1,
confidence=0.8 # 基于规则的置信度
))
return qa_pairs
def _get_protected_indices(
self,
messages: List[ConversationMessage],
qa_pairs: List[QAPair],
window_size: int = 2
) -> Set[int]:
"""获取需要保护的消息索引集合(问答对+上下文窗口)。
Args:
messages: 消息列表
qa_pairs: 问答对列表
window_size: 上下文窗口大小(前后各保留几条消息)
Returns:
需要保护的消息索引集合
"""
protected = set()
for qa_pair in qa_pairs:
# 保护问答对本身
protected.add(qa_pair.question_idx)
protected.add(qa_pair.answer_idx)
# 保护上下文窗口
for offset in range(-window_size, window_size + 1):
q_idx = qa_pair.question_idx + offset
a_idx = qa_pair.answer_idx + offset
if 0 <= q_idx < len(messages):
protected.add(q_idx)
if 0 <= a_idx < len(messages):
protected.add(a_idx)
return protected
async def _extract_dialog_important(self, dialog_text: str) -> DialogExtractionResponse:
"""对话级一次性抽取:从整段对话中提取重要信息并判定相关性。
- 仅使用 LLM 结构化输出;
改进版:
- LRU缓存管理
- 重试机制
- 降级策略
"""
# 缓存命中则直接返回(场景+内容作为键)
cache_key = f"{self.config.pruning_scene}:" + hashlib.sha1(dialog_text.encode("utf-8")).hexdigest()
# LRU缓存如果命中移到末尾最近使用
if cache_key in self._dialog_extract_cache:
self._dialog_extract_cache.move_to_end(cache_key)
return self._dialog_extract_cache[cache_key]
rendered = self.template.render(pruning_scene=self.config.pruning_scene, dialog_text=dialog_text)
log_template_rendering("extracat_Pruning.jinja2", {"pruning_scene": self.config.pruning_scene})
# LRU缓存大小限制超过限制时删除最旧的条目
if len(self._dialog_extract_cache) >= self._cache_max_size:
# 删除最旧的条目OrderedDict的第一个
oldest_key = next(iter(self._dialog_extract_cache))
del self._dialog_extract_cache[oldest_key]
self._log(f"[剪枝-缓存] LRU缓存已满删除最旧条目")
rendered = self.template.render(
pruning_scene=self.config.pruning_scene,
dialog_text=dialog_text,
language=self.language
)
log_template_rendering("extracat_Pruning.jinja2", {
"pruning_scene": self.config.pruning_scene,
"language": self.language
})
log_prompt_rendering("pruning-extract", rendered)
# 强制使用 LLM;移除正则回退
# 强制使用 LLM
if not self.llm_client:
raise RuntimeError("llm_client 未配置;请配置 LLM 以进行结构化抽取。")
@@ -153,12 +417,32 @@ class SemanticPruner:
{"role": "system", "content": "你是一个严谨的场景抽取助手,只输出严格 JSON。"},
{"role": "user", "content": rendered},
]
try:
ex = await self.llm_client.response_structured(messages, DialogExtractionResponse)
self._dialog_extract_cache[cache_key] = ex
return ex
except Exception as e:
raise RuntimeError("LLM 结构化抽取失败;请检查 LLM 配置或重试。") from e
# 重试机制
max_retries = 3
for attempt in range(max_retries):
try:
ex = await self.llm_client.response_structured(messages, DialogExtractionResponse)
self._dialog_extract_cache[cache_key] = ex
return ex
except Exception as e:
if attempt < max_retries - 1:
self._log(f"[剪枝-LLM] 第 {attempt + 1} 次尝试失败,重试中... 错误: {str(e)[:100]}")
await asyncio.sleep(0.5 * (attempt + 1)) # 指数退避
continue
else:
# 降级策略:标记为相关,避免误删
self._log(f"[剪枝-LLM] LLM 调用失败 {max_retries} 次,使用降级策略(标记为相关)")
fallback_response = DialogExtractionResponse(
is_related=True,
times=[],
ids=[],
amounts=[],
contacts=[],
addresses=[],
keywords=[]
)
return fallback_response
def _msg_matches_tokens(self, message: ConversationMessage, tokens: List[str]) -> bool:
"""判断消息是否包含任意抽取到的重要片段。"""
@@ -248,12 +532,15 @@ class SemanticPruner:
async def prune_dataset(self, dialogs: List[DialogData]) -> List[DialogData]:
"""数据集层面:全局消息级剪枝,保留所有对话。
- 仅在"不相关对话"的范围内执行消息剪枝;相关对话不动。
- 只删除"不重要的不相关消息",重要信息(时间、编号等)强制保留。
- 删除总量 = 阈值 * 全部不相关可删消息数,按可删容量比例分配;顺序删除。
- 保证每段对话至少保留1条消息不会删除整段对话。
改进版:
- 并发处理对话级相关性判断
- 问答对识别和保护
- 优化删除策略,保持上下文连贯性
- 仅在"不相关对话"的范围内执行消息剪枝;相关对话不动
- 只删除"不重要的不相关消息",重要信息(时间、编号等)强制保留
- 保证每段对话至少保留1条消息不会删除整段对话
"""
# 如果剪枝功能关闭,直接返回原始数据集
# 如果剪枝功能关闭,直接返回原始数据集
if not self.config.pruning_switch:
return dialogs
@@ -264,29 +551,36 @@ class SemanticPruner:
proportion = 0.9
if proportion < 0.0:
proportion = 0.0
evaluated_dialogs = [] # list of dicts: {dialog, is_related}
self._log(
f"[剪枝-数据集] 对话总数={len(dialogs)} 场景={self.config.pruning_scene} 删除比例={proportion} 开关={self.config.pruning_switch}"
)
# 对话级相关性分类(一次性对整段对话文本进行判断,顺序执行并复用缓存)
evaluated_dialogs = []
for idx, dd in enumerate(dialogs):
try:
ex = await self._extract_dialog_important(dd.content)
evaluated_dialogs.append({
"dialog": dd,
"is_related": bool(ex.is_related),
"index": idx,
"extraction": ex
})
except Exception:
evaluated_dialogs.append({
"dialog": dd,
"is_related": True,
"index": idx,
"extraction": None
})
# 并发处理对话级相关性分类
semaphore = asyncio.Semaphore(self.max_concurrent)
async def classify_dialog(idx: int, dd: DialogData):
async with semaphore:
try:
ex = await self._extract_dialog_important(dd.content)
return {
"dialog": dd,
"is_related": bool(ex.is_related),
"index": idx,
"extraction": ex
}
except Exception as e:
self._log(f"[剪枝-并发] 对话 {idx} 分类失败: {str(e)[:100]}")
return {
"dialog": dd,
"is_related": True, # 失败时标记为相关,避免误删
"index": idx,
"extraction": None
}
# 并发执行所有对话的分类
tasks = [classify_dialog(idx, dd) for idx, dd in enumerate(dialogs)]
evaluated_dialogs = await asyncio.gather(*tasks)
# 统计相关 / 不相关对话
not_related_dialogs = [d for d in evaluated_dialogs if not d["is_related"]]
@@ -300,7 +594,6 @@ class SemanticPruner:
inds = [i["index"] + 1 for i in items]
if len(inds) <= cap:
return inds
# 超过上限时只打印前cap个并标注总数
return inds[:cap] + ["...", f"{len(inds)}"]
rel_inds = _fmt_indices(related_dialogs)
@@ -309,59 +602,83 @@ class SemanticPruner:
result: List[DialogData] = []
if not_related_dialogs:
# 为每个不相关对话进行一次性抽取,识别重要/不重要(避免逐条 LLM
# 为每个不相关对话进行分析
per_dialog_info = {}
total_unrelated = 0
total_capacity = 0
for d in not_related_dialogs:
dd = d["dialog"]
extraction = d.get("extraction")
if extraction is None:
extraction = await self._extract_dialog_important(dd.content)
# 合并所有重要标记
tokens = extraction.times + extraction.ids + extraction.amounts + extraction.contacts + extraction.addresses + extraction.keywords
msgs = dd.context.msgs
# 分类消息
imp_unrel_msgs = [m for m in msgs if self._msg_matches_tokens(m, tokens) or self._is_important_message(m)]
unimp_unrel_msgs = [m for m in msgs if m not in imp_unrel_msgs]
# 识别问答对
qa_pairs = self._identify_qa_pairs(msgs)
protected_indices = self._get_protected_indices(msgs, qa_pairs, window_size=1)
# 分类消息(考虑问答对保护)
imp_unrel_msgs = []
unimp_unrel_msgs = []
for idx, m in enumerate(msgs):
# 问答对中的消息自动标记为重要
if idx in protected_indices:
imp_unrel_msgs.append((idx, m))
elif self._msg_matches_tokens(m, tokens) or self._is_important_message(m):
imp_unrel_msgs.append((idx, m))
elif not self._is_filler_message(m):
unimp_unrel_msgs.append((idx, m))
# 填充消息不加入任何列表,优先删除
# 重要消息按重要性排序
imp_sorted_ids = [id(m) for m in sorted(imp_unrel_msgs, key=lambda m: self._importance_score(m))]
imp_sorted = sorted(imp_unrel_msgs, key=lambda x: self._importance_score(x[1]))
imp_sorted_ids = [id(m) for _, m in imp_sorted]
info = {
"dialog": dd,
"total_msgs": len(msgs),
"unrelated_count": len(msgs),
"imp_ids_sorted": imp_sorted_ids,
"unimp_ids": [id(m) for m in unimp_unrel_msgs],
"unimp_ids": [id(m) for _, m in unimp_unrel_msgs],
"protected_indices": protected_indices,
"qa_pairs_count": len(qa_pairs),
}
per_dialog_info[d["index"]] = info
total_unrelated += info["unrelated_count"]
# 全局删除配额:比例作用于全部不相关消息(重要+不重要)
# 全局删除配额计算
global_delete = int(total_unrelated * proportion)
if proportion > 0 and total_unrelated > 0 and global_delete == 0:
global_delete = 1
# 每段的最大可删容量:不重要全部 + 重要最多删除 floor(len(重要)*比例)且至少保留1条消息
# 每段的最大可删容量
capacities = []
for d in not_related_dialogs:
idx = d["index"]
info = per_dialog_info[idx]
# 统计重要数量
imp_count = len(info["imp_ids_sorted"])
unimp_count = len(info["unimp_ids"])
imp_cap = int(imp_count * proportion)
cap = min(unimp_count + imp_cap, max(0, info["total_msgs"] - 1))
capacities.append(cap)
total_capacity = sum(capacities)
if global_delete > total_capacity:
print(f"[剪枝-数据集] 不相关消息总数={total_unrelated},目标删除={global_delete},最大可删={total_capacity}(重要消息按比例保留)。将按最大可删执行。")
self._log(f"[剪枝-数据集] 不相关消息总数={total_unrelated},目标删除={global_delete},最大可删={total_capacity}。将按最大可删执行。")
global_delete = total_capacity
# 配额分配:按不相关消息占比分配到各对话,但不超过各自容量
# 配额分配
alloc = []
for i, d in enumerate(not_related_dialogs):
idx = d["index"]
info = per_dialog_info[idx]
share = int(global_delete * (info["unrelated_count"] / total_unrelated)) if total_unrelated > 0 else 0
alloc.append(min(share, capacities[i]))
allocated = sum(alloc)
rem = global_delete - allocated
turn = 0
@@ -378,34 +695,40 @@ class SemanticPruner:
break
turn += 1
# 应用删除:相关对话不动;不相关按分配先删不重要,再删重要(低分优先)
# 应用删除
total_deleted_confirm = 0
for d in evaluated_dialogs:
dd = d["dialog"]
msgs = dd.context.msgs
original = len(msgs)
if d["is_related"]:
result.append(dd)
continue
idx_in_unrel = next((k for k, x in enumerate(not_related_dialogs) if x["index"] == d["index"]), None)
if idx_in_unrel is None:
result.append(dd)
continue
quota = alloc[idx_in_unrel]
info = per_dialog_info[d["index"]]
# 计算本对话重要最多可删数量
# 计算删除ID
imp_count = len(info["imp_ids_sorted"])
imp_del_cap = int(imp_count * proportion)
# 先构造顺序删除的"不重要ID集合"(按出现顺序前 quota 条)
unimp_delete_ids = set(info["unimp_ids"][:min(quota, len(info["unimp_ids"]))])
del_unimp = min(quota, len(unimp_delete_ids))
rem_quota = quota - del_unimp
# 再从重要里选低分优先的删除ID不超过 imp_del_cap
imp_delete_ids = set(info["imp_ids_sorted"][:min(rem_quota, imp_del_cap)])
deleted_here = 0
actual_unimp_deleted = 0
actual_imp_deleted = 0
kept = []
for m in msgs:
mid = id(m)
if mid in unimp_delete_ids and actual_unimp_deleted < del_unimp:
@@ -417,26 +740,30 @@ class SemanticPruner:
deleted_here += 1
continue
kept.append(m)
if not kept and msgs:
kept = [msgs[0]]
dd.context.msgs = kept
total_deleted_confirm += deleted_here
qa_info = f",问答对={info['qa_pairs_count']}" if info['qa_pairs_count'] > 0 else ""
self._log(
f"[剪枝-对话] 对话 {d['index']+1} 总消息={original} 分配删除={quota} 实删={deleted_here} 保留={len(kept)}"
f"[剪枝-对话] 对话 {d['index']+1} 总消息={original} 分配删除={quota} 实删={deleted_here} 保留={len(kept)}{qa_info}"
)
result.append(dd)
self._log(f"[剪枝-数据集] 全局消息级顺序剪枝完成,总删除 {total_deleted_confirm} 条(不相关消息,重要按比例保留)。")
self._log(f"[剪枝-数据集] 全局消息级剪枝完成,总删除 {total_deleted_confirm} 条(保护问答对和上下文)。")
else:
# 全部相关:不执行剪枝
result = [d["dialog"] for d in evaluated_dialogs]
self._log(f"[剪枝-数据集] 剩余对话数={len(result)}")
# 将本次剪枝阶段的终端输出保存为 JSON 文件(仅在剪枝器内部完成)
# 保存日志
try:
from app.core.config import settings
settings.ensure_memory_output_dir()
log_output_path = settings.get_memory_output_path("pruned_terminal.json")
# 去除日志前缀标签(如 [剪枝-数据集]、[剪枝-对话])后再解析为结构化字段保存
sanitized_logs = [self._sanitize_log_line(l) for l in self.run_logs]
payload = self._parse_logs_to_structured(sanitized_logs)
with open(log_output_path, "w", encoding="utf-8") as f:
@@ -448,6 +775,7 @@ class SemanticPruner:
if not result:
print("警告: 语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断")
return dialogs
return result
def _log(self, msg: str) -> None: