[fix]Address the shortcomings of intelligent pruning
This commit is contained in:
@@ -5,14 +5,17 @@
|
||||
- 对话级一次性抽取判定相关性
|
||||
- 仅对"不相关对话"的消息按比例删除
|
||||
- 重要信息(时间、编号、金额、联系方式、地址等)优先保留
|
||||
- 改进版:增强重要性判断、智能填充消息识别、问答对保护、并发优化
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import hashlib
|
||||
import json
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Dict, Tuple, Set
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.memory.models.message_models import DialogData, ConversationMessage, ConversationContext
|
||||
@@ -36,6 +39,23 @@ class DialogExtractionResponse(BaseModel):
|
||||
keywords: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class MessageImportanceResponse(BaseModel):
|
||||
"""消息重要性批量判断的结构化返回(用于LLM语义判断)。
|
||||
|
||||
- importance_scores: 消息索引到重要性分数的映射 (0-10分)
|
||||
- reasons: 可选的判断理由
|
||||
"""
|
||||
importance_scores: Dict[int, int] = Field(default_factory=dict, description="消息索引到重要性分数(0-10)的映射")
|
||||
reasons: Optional[Dict[int, str]] = Field(default_factory=dict, description="可选的判断理由")
|
||||
|
||||
|
||||
class QAPair(BaseModel):
|
||||
"""问答对模型,用于识别和保护对话中的问答结构。"""
|
||||
question_idx: int = Field(..., description="问题消息的索引")
|
||||
answer_idx: int = Field(..., description="答案消息的索引")
|
||||
confidence: float = Field(default=1.0, description="问答对的置信度(0-1)")
|
||||
|
||||
|
||||
class SemanticPruner:
|
||||
"""语义剪枝:在预处理与分块之间过滤与场景不相关内容。
|
||||
|
||||
@@ -43,109 +63,353 @@ class SemanticPruner:
|
||||
重要信息(时间、编号、金额、联系方式、地址等)优先保留。
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[PruningConfig] = None, llm_client=None):
|
||||
cfg_dict = get_pruning_config() if config is None else config.model_dump()
|
||||
self.config = PruningConfig.model_validate(cfg_dict)
|
||||
def __init__(self, config: Optional[PruningConfig] = None, llm_client=None, language: str = "zh", max_concurrent: int = 5):
|
||||
# 如果没有提供config,使用默认配置
|
||||
if config is None:
|
||||
# 使用默认的剪枝配置
|
||||
config = PruningConfig(
|
||||
pruning_switch=False, # 默认关闭剪枝,保持向后兼容
|
||||
pruning_scene="education",
|
||||
pruning_threshold=0.5
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.llm_client = llm_client
|
||||
self.language = language # 保存语言配置
|
||||
self.max_concurrent = max_concurrent # 新增:最大并发数
|
||||
|
||||
# Load Jinja2 template
|
||||
self.template = prompt_env.get_template("extracat_Pruning.jinja2")
|
||||
# 对话抽取缓存:避免同一对话重复调用 LLM / 重复渲染
|
||||
self._dialog_extract_cache: dict[str, DialogExtractionResponse] = {}
|
||||
|
||||
# 对话抽取缓存:使用 OrderedDict 实现 LRU 缓存
|
||||
self._dialog_extract_cache: OrderedDict[str, DialogExtractionResponse] = OrderedDict()
|
||||
self._cache_max_size = 1000 # 缓存大小限制
|
||||
|
||||
# 运行日志:收集关键终端输出,便于写入 JSON
|
||||
self.run_logs: List[str] = []
|
||||
# 采用顺序处理,移除并发配置以简化与稳定执行
|
||||
|
||||
# 扩展的填充词库(包含表情符号和网络用语)
|
||||
self._extended_fillers = [
|
||||
# 基础寒暄
|
||||
"你好", "您好", "在吗", "在的", "在呢", "嗯", "嗯嗯", "哦", "哦哦",
|
||||
"好的", "好", "行", "可以", "不可以", "谢谢", "多谢", "感谢",
|
||||
"拜拜", "再见", "88", "拜", "回见",
|
||||
# 口头禅
|
||||
"哈哈", "呵呵", "哈哈哈", "嘿嘿", "嘻嘻", "hiahia",
|
||||
"额", "呃", "啊", "诶", "唉", "哎", "嗯哼",
|
||||
# 确认词
|
||||
"是的", "对", "对的", "没错", "嗯嗯", "好嘞", "收到", "明白", "了解", "知道了",
|
||||
# 标点和符号
|
||||
"。。。", "...", "???", "???", "!!!", "!!!",
|
||||
# 表情符号(文本形式)
|
||||
"[微笑]", "[呲牙]", "[发呆]", "[得意]", "[流泪]", "[害羞]", "[闭嘴]",
|
||||
"[睡]", "[大哭]", "[尴尬]", "[发怒]", "[调皮]", "[龇牙]", "[惊讶]",
|
||||
"[难过]", "[酷]", "[冷汗]", "[抓狂]", "[吐]", "[偷笑]", "[可爱]",
|
||||
"[白眼]", "[傲慢]", "[饥饿]", "[困]", "[惊恐]", "[流汗]", "[憨笑]",
|
||||
# 网络用语
|
||||
"hhh", "hhhh", "2333", "666", "gg", "ok", "OK", "okok",
|
||||
"emmm", "emm", "em", "mmp", "wtf", "omg",
|
||||
]
|
||||
|
||||
def _is_important_message(self, message: ConversationMessage) -> bool:
|
||||
"""基于启发式规则识别重要信息消息,优先保留。
|
||||
|
||||
- 含日期/时间(如YYYY-MM-DD、HH:MM、2024年11月10日、上午/下午)。
|
||||
- 含编号/ID/订单号/申请号/账号/电话/金额等关键字段。
|
||||
- 关键词:"时间"、"日期"、"编号"、"订单"、"流水"、"金额"、"¥"、"元"、"电话"、"手机号"、"邮箱"、"地址"。
|
||||
改进版:增强了规则覆盖范围,包括:
|
||||
- 含日期/时间(如YYYY-MM-DD、HH:MM、2024年11月10日、上午/下午)
|
||||
- 含编号/ID/订单号/申请号/账号/电话/金额等关键字段
|
||||
- 关键词:"时间"、"日期"、"编号"、"订单"、"流水"、"金额"、"¥"、"元"、"电话"、"手机号"、"邮箱"、"地址"
|
||||
- 新增:问句识别、决策性语句、承诺性语句
|
||||
"""
|
||||
import re
|
||||
text = message.msg.strip()
|
||||
if not text:
|
||||
return False
|
||||
|
||||
patterns = [
|
||||
r"\b\d{4}-\d{1,2}-\d{1,2}\b",
|
||||
r"\b\d{1,2}:\d{2}\b",
|
||||
# 原有模式
|
||||
r"\d{4}-\d{1,2}-\d{1,2}", # 修复:移除 \b 边界,因为中文前后没有单词边界
|
||||
r"\d{1,2}:\d{2}", # 修复:移除 \b
|
||||
r"\d{4}年\d{1,2}月\d{1,2}日",
|
||||
r"上午|下午|AM|PM",
|
||||
r"订单号|工单|申请号|编号|ID|账号|账户",
|
||||
r"电话|手机号|微信|QQ|邮箱",
|
||||
r"地址|地点",
|
||||
r"金额|费用|价格|¥|¥|\d+元",
|
||||
r"时间|日期|有效期|截止",
|
||||
r"上午|下午|AM|PM|今天|明天|后天|昨天|前天|本周|下周|上周|本月|下月|上月",
|
||||
r"订单号|工单|申请号|编号|ID|账号|账户|流水号|单号",
|
||||
r"电话|手机号|微信|QQ|邮箱|联系方式",
|
||||
r"地址|地点|位置|门牌号",
|
||||
r"金额|费用|价格|¥|¥|\d+元|人民币|美元|欧元",
|
||||
r"时间|日期|有效期|截止|期限|到期",
|
||||
# 新增模式
|
||||
r"什么|为什么|怎么|如何|哪里|哪个|谁|多少|几点|何时", # 问句关键词
|
||||
r"必须|一定|务必|需要|要求|规定|应该", # 决策性语句
|
||||
r"承诺|保证|确保|负责|同意|答应", # 承诺性语句
|
||||
r"\d{11}", # 11位手机号
|
||||
r"\d{3,4}-\d{7,8}", # 固定电话
|
||||
r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", # 邮箱
|
||||
]
|
||||
|
||||
for p in patterns:
|
||||
if re.search(p, text, flags=re.IGNORECASE):
|
||||
return True
|
||||
|
||||
# 检查是否为问句(以问号结尾或包含疑问词)
|
||||
if text.endswith("?") or text.endswith("?"):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _importance_score(self, message: ConversationMessage) -> int:
|
||||
"""为重要消息打分,用于在保留比例内优先保留更关键的内容。
|
||||
|
||||
简单启发:匹配到的类别越多、越关键分值越高。
|
||||
改进版:更细致的评分体系(0-10分)
|
||||
"""
|
||||
import re
|
||||
text = message.msg.strip()
|
||||
score = 0
|
||||
|
||||
weights = [
|
||||
(r"\b\d{4}-\d{1,2}-\d{1,2}\b", 3),
|
||||
(r"\b\d{1,2}:\d{2}\b", 2),
|
||||
# 高优先级(4-5分)
|
||||
(r"订单号|工单|申请号|编号|ID|账号|账户", 5),
|
||||
(r"金额|费用|价格|¥|¥|\d+元", 5),
|
||||
(r"\d{11}", 4), # 手机号
|
||||
(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", 4), # 邮箱
|
||||
|
||||
# 中优先级(2-3分)
|
||||
(r"\d{4}-\d{1,2}-\d{1,2}", 3), # 修复:移除 \b
|
||||
(r"\d{4}年\d{1,2}月\d{1,2}日", 3),
|
||||
(r"订单号|工单|申请号|编号|ID|账号|账户", 4),
|
||||
(r"电话|手机号|微信|QQ|邮箱", 3),
|
||||
(r"地址|地点", 2),
|
||||
(r"金额|费用|价格|¥|¥|\d+元", 4),
|
||||
(r"时间|日期|有效期|截止", 2),
|
||||
(r"电话|手机号|微信|QQ|联系方式", 3),
|
||||
(r"地址|地点|位置", 2),
|
||||
(r"时间|日期|有效期|截止|明天|后天|下周|下月", 2), # 新增时间相关词
|
||||
|
||||
# 低优先级(1分)
|
||||
(r"\d{1,2}:\d{2}", 1), # 修复:移除 \b
|
||||
(r"上午|下午|AM|PM", 1),
|
||||
]
|
||||
|
||||
for p, w in weights:
|
||||
if re.search(p, text, flags=re.IGNORECASE):
|
||||
score += w
|
||||
return score
|
||||
|
||||
# 问句加分
|
||||
if text.endswith("?") or text.endswith("?"):
|
||||
score += 2
|
||||
|
||||
# 长度加分(较长的消息通常包含更多信息)
|
||||
if len(text) > 50:
|
||||
score += 1
|
||||
if len(text) > 100:
|
||||
score += 1
|
||||
|
||||
return min(score, 10) # 最高10分
|
||||
|
||||
def _is_filler_message(self, message: ConversationMessage) -> bool:
|
||||
"""检测典型寒暄/口头禅/确认类短消息,用于跳过LLM分类以加速。
|
||||
|
||||
改进版:扩展了填充词库,支持表情符号和网络用语
|
||||
满足以下之一视为填充消息:
|
||||
- 纯标点或长度很短(<= 4 个汉字或 <= 8 个字符)且不包含数字或关键实体;
|
||||
- 常见词:你好/您好/在吗/嗯/嗯嗯/哦/好的/好/行/可以/不可以/谢谢/拜拜/再见/哈哈/呵呵/哈哈哈/。。。/??。
|
||||
- 纯标点或长度很短(<= 4 个汉字或 <= 8 个字符)且不包含数字或关键实体
|
||||
- 在扩展填充词库中
|
||||
- 纯表情符号
|
||||
"""
|
||||
import re
|
||||
t = message.msg.strip()
|
||||
if not t:
|
||||
return True
|
||||
# 常见填充语
|
||||
fillers = [
|
||||
"你好", "您好", "在吗", "嗯", "嗯嗯", "哦", "好的", "好", "行", "可以", "不可以", "谢谢",
|
||||
"拜拜", "再见", "哈哈", "呵呵", "哈哈哈", "。。。", "??", "??"
|
||||
]
|
||||
if t in fillers:
|
||||
|
||||
# 检查是否在扩展填充词库中
|
||||
if t in self._extended_fillers:
|
||||
return True
|
||||
|
||||
# 检查是否为纯表情符号(方括号包裹)
|
||||
if re.fullmatch(r"(\[[^\]]+\])+", t):
|
||||
return True
|
||||
|
||||
# 检查是否为纯emoji(Unicode表情)
|
||||
emoji_pattern = re.compile(
|
||||
"["
|
||||
"\U0001F600-\U0001F64F" # 表情符号
|
||||
"\U0001F300-\U0001F5FF" # 符号和象形文字
|
||||
"\U0001F680-\U0001F6FF" # 交通和地图符号
|
||||
"\U0001F1E0-\U0001F1FF" # 旗帜
|
||||
"\U00002702-\U000027B0"
|
||||
"\U000024C2-\U0001F251"
|
||||
"]+", flags=re.UNICODE
|
||||
)
|
||||
if emoji_pattern.fullmatch(t):
|
||||
return True
|
||||
|
||||
# 长度与字符类型判断
|
||||
if len(t) <= 8:
|
||||
# 非数字、无关键实体的短文本
|
||||
if not re.search(r"[0-9]", t) and not self._is_important_message(message):
|
||||
# 主要是标点或简单确认词
|
||||
if re.fullmatch(r"[。!?,.!?…·\s]+", t) or t in fillers:
|
||||
if re.fullmatch(r"[。!?,.!?…·\s]+", t):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def _batch_evaluate_importance_with_llm(
|
||||
self,
|
||||
messages: List[ConversationMessage],
|
||||
context: str = ""
|
||||
) -> Dict[int, int]:
|
||||
"""使用LLM批量评估消息的重要性(语义层面)。
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
context: 对话上下文(可选)
|
||||
|
||||
Returns:
|
||||
消息索引到重要性分数(0-10)的映射
|
||||
"""
|
||||
if not self.llm_client or not messages:
|
||||
return {}
|
||||
|
||||
# 构建批量评估的提示词
|
||||
msg_list = []
|
||||
for idx, msg in enumerate(messages):
|
||||
msg_list.append(f"{idx}. {msg.msg}")
|
||||
|
||||
msg_text = "\n".join(msg_list)
|
||||
|
||||
prompt = f"""请评估以下消息的重要性,给每条消息打分(0-10分):
|
||||
- 0-2分:无意义的寒暄、口头禅、纯表情
|
||||
- 3-5分:一般性对话,有一定信息量但不关键
|
||||
- 6-8分:包含重要信息(时间、地点、人物、事件等)
|
||||
- 9-10分:关键决策、承诺、重要数据
|
||||
|
||||
对话上下文:
|
||||
{context if context else "无"}
|
||||
|
||||
待评估的消息:
|
||||
{msg_text}
|
||||
|
||||
请以JSON格式返回,格式为:
|
||||
{{
|
||||
"importance_scores": {{
|
||||
"0": 分数,
|
||||
"1": 分数,
|
||||
...
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
|
||||
try:
|
||||
messages_for_llm = [
|
||||
{"role": "system", "content": "你是一个专业的对话分析助手,擅长评估消息的重要性。"},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
|
||||
response = await self.llm_client.response_structured(
|
||||
messages_for_llm,
|
||||
MessageImportanceResponse
|
||||
)
|
||||
|
||||
# 转换字符串键为整数键
|
||||
return {int(k): v for k, v in response.importance_scores.items()}
|
||||
except Exception as e:
|
||||
self._log(f"[剪枝-LLM] 批量重要性评估失败: {str(e)[:100]}")
|
||||
return {}
|
||||
|
||||
def _identify_qa_pairs(self, messages: List[ConversationMessage]) -> List[QAPair]:
|
||||
"""识别对话中的问答对,用于保护问答结构的完整性。
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
|
||||
Returns:
|
||||
问答对列表
|
||||
"""
|
||||
qa_pairs = []
|
||||
|
||||
for i in range(len(messages) - 1):
|
||||
current_msg = messages[i].msg.strip()
|
||||
next_msg = messages[i + 1].msg.strip()
|
||||
|
||||
# 简单规则:如果当前消息是问句,下一条消息可能是答案
|
||||
is_question = (
|
||||
current_msg.endswith("?") or
|
||||
current_msg.endswith("?") or
|
||||
any(word in current_msg for word in ["什么", "为什么", "怎么", "如何", "哪里", "哪个", "谁", "多少", "几点", "何时", "吗"])
|
||||
)
|
||||
|
||||
if is_question and next_msg:
|
||||
# 检查下一条消息是否像答案(不是另一个问句)
|
||||
is_answer = not (next_msg.endswith("?") or next_msg.endswith("?"))
|
||||
|
||||
if is_answer:
|
||||
qa_pairs.append(QAPair(
|
||||
question_idx=i,
|
||||
answer_idx=i + 1,
|
||||
confidence=0.8 # 基于规则的置信度
|
||||
))
|
||||
|
||||
return qa_pairs
|
||||
|
||||
def _get_protected_indices(
|
||||
self,
|
||||
messages: List[ConversationMessage],
|
||||
qa_pairs: List[QAPair],
|
||||
window_size: int = 2
|
||||
) -> Set[int]:
|
||||
"""获取需要保护的消息索引集合(问答对+上下文窗口)。
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
qa_pairs: 问答对列表
|
||||
window_size: 上下文窗口大小(前后各保留几条消息)
|
||||
|
||||
Returns:
|
||||
需要保护的消息索引集合
|
||||
"""
|
||||
protected = set()
|
||||
|
||||
for qa_pair in qa_pairs:
|
||||
# 保护问答对本身
|
||||
protected.add(qa_pair.question_idx)
|
||||
protected.add(qa_pair.answer_idx)
|
||||
|
||||
# 保护上下文窗口
|
||||
for offset in range(-window_size, window_size + 1):
|
||||
q_idx = qa_pair.question_idx + offset
|
||||
a_idx = qa_pair.answer_idx + offset
|
||||
|
||||
if 0 <= q_idx < len(messages):
|
||||
protected.add(q_idx)
|
||||
if 0 <= a_idx < len(messages):
|
||||
protected.add(a_idx)
|
||||
|
||||
return protected
|
||||
|
||||
async def _extract_dialog_important(self, dialog_text: str) -> DialogExtractionResponse:
|
||||
"""对话级一次性抽取:从整段对话中提取重要信息并判定相关性。
|
||||
|
||||
- 仅使用 LLM 结构化输出;
|
||||
改进版:
|
||||
- LRU缓存管理
|
||||
- 重试机制
|
||||
- 降级策略
|
||||
"""
|
||||
# 缓存命中则直接返回(场景+内容作为键)
|
||||
cache_key = f"{self.config.pruning_scene}:" + hashlib.sha1(dialog_text.encode("utf-8")).hexdigest()
|
||||
|
||||
# LRU缓存:如果命中,移到末尾(最近使用)
|
||||
if cache_key in self._dialog_extract_cache:
|
||||
self._dialog_extract_cache.move_to_end(cache_key)
|
||||
return self._dialog_extract_cache[cache_key]
|
||||
|
||||
rendered = self.template.render(pruning_scene=self.config.pruning_scene, dialog_text=dialog_text)
|
||||
log_template_rendering("extracat_Pruning.jinja2", {"pruning_scene": self.config.pruning_scene})
|
||||
# LRU缓存大小限制:超过限制时删除最旧的条目
|
||||
if len(self._dialog_extract_cache) >= self._cache_max_size:
|
||||
# 删除最旧的条目(OrderedDict的第一个)
|
||||
oldest_key = next(iter(self._dialog_extract_cache))
|
||||
del self._dialog_extract_cache[oldest_key]
|
||||
self._log(f"[剪枝-缓存] LRU缓存已满,删除最旧条目")
|
||||
|
||||
rendered = self.template.render(
|
||||
pruning_scene=self.config.pruning_scene,
|
||||
dialog_text=dialog_text,
|
||||
language=self.language
|
||||
)
|
||||
log_template_rendering("extracat_Pruning.jinja2", {
|
||||
"pruning_scene": self.config.pruning_scene,
|
||||
"language": self.language
|
||||
})
|
||||
log_prompt_rendering("pruning-extract", rendered)
|
||||
|
||||
# 强制使用 LLM;移除正则回退
|
||||
# 强制使用 LLM
|
||||
if not self.llm_client:
|
||||
raise RuntimeError("llm_client 未配置;请配置 LLM 以进行结构化抽取。")
|
||||
|
||||
@@ -153,12 +417,32 @@ class SemanticPruner:
|
||||
{"role": "system", "content": "你是一个严谨的场景抽取助手,只输出严格 JSON。"},
|
||||
{"role": "user", "content": rendered},
|
||||
]
|
||||
try:
|
||||
ex = await self.llm_client.response_structured(messages, DialogExtractionResponse)
|
||||
self._dialog_extract_cache[cache_key] = ex
|
||||
return ex
|
||||
except Exception as e:
|
||||
raise RuntimeError("LLM 结构化抽取失败;请检查 LLM 配置或重试。") from e
|
||||
|
||||
# 重试机制
|
||||
max_retries = 3
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
ex = await self.llm_client.response_structured(messages, DialogExtractionResponse)
|
||||
self._dialog_extract_cache[cache_key] = ex
|
||||
return ex
|
||||
except Exception as e:
|
||||
if attempt < max_retries - 1:
|
||||
self._log(f"[剪枝-LLM] 第 {attempt + 1} 次尝试失败,重试中... 错误: {str(e)[:100]}")
|
||||
await asyncio.sleep(0.5 * (attempt + 1)) # 指数退避
|
||||
continue
|
||||
else:
|
||||
# 降级策略:标记为相关,避免误删
|
||||
self._log(f"[剪枝-LLM] LLM 调用失败 {max_retries} 次,使用降级策略(标记为相关)")
|
||||
fallback_response = DialogExtractionResponse(
|
||||
is_related=True,
|
||||
times=[],
|
||||
ids=[],
|
||||
amounts=[],
|
||||
contacts=[],
|
||||
addresses=[],
|
||||
keywords=[]
|
||||
)
|
||||
return fallback_response
|
||||
|
||||
def _msg_matches_tokens(self, message: ConversationMessage, tokens: List[str]) -> bool:
|
||||
"""判断消息是否包含任意抽取到的重要片段。"""
|
||||
@@ -248,12 +532,15 @@ class SemanticPruner:
|
||||
async def prune_dataset(self, dialogs: List[DialogData]) -> List[DialogData]:
|
||||
"""数据集层面:全局消息级剪枝,保留所有对话。
|
||||
|
||||
- 仅在"不相关对话"的范围内执行消息剪枝;相关对话不动。
|
||||
- 只删除"不重要的不相关消息",重要信息(时间、编号等)强制保留。
|
||||
- 删除总量 = 阈值 * 全部不相关可删消息数,按可删容量比例分配;顺序删除。
|
||||
- 保证每段对话至少保留1条消息,不会删除整段对话。
|
||||
改进版:
|
||||
- 并发处理对话级相关性判断
|
||||
- 问答对识别和保护
|
||||
- 优化删除策略,保持上下文连贯性
|
||||
- 仅在"不相关对话"的范围内执行消息剪枝;相关对话不动
|
||||
- 只删除"不重要的不相关消息",重要信息(时间、编号等)强制保留
|
||||
- 保证每段对话至少保留1条消息,不会删除整段对话
|
||||
"""
|
||||
# 如果剪枝功能关闭,直接返回原始数据集。
|
||||
# 如果剪枝功能关闭,直接返回原始数据集
|
||||
if not self.config.pruning_switch:
|
||||
return dialogs
|
||||
|
||||
@@ -264,29 +551,36 @@ class SemanticPruner:
|
||||
proportion = 0.9
|
||||
if proportion < 0.0:
|
||||
proportion = 0.0
|
||||
evaluated_dialogs = [] # list of dicts: {dialog, is_related}
|
||||
|
||||
self._log(
|
||||
f"[剪枝-数据集] 对话总数={len(dialogs)} 场景={self.config.pruning_scene} 删除比例={proportion} 开关={self.config.pruning_switch}"
|
||||
)
|
||||
# 对话级相关性分类(一次性对整段对话文本进行判断,顺序执行并复用缓存)
|
||||
evaluated_dialogs = []
|
||||
for idx, dd in enumerate(dialogs):
|
||||
try:
|
||||
ex = await self._extract_dialog_important(dd.content)
|
||||
evaluated_dialogs.append({
|
||||
"dialog": dd,
|
||||
"is_related": bool(ex.is_related),
|
||||
"index": idx,
|
||||
"extraction": ex
|
||||
})
|
||||
except Exception:
|
||||
evaluated_dialogs.append({
|
||||
"dialog": dd,
|
||||
"is_related": True,
|
||||
"index": idx,
|
||||
"extraction": None
|
||||
})
|
||||
|
||||
# 并发处理对话级相关性分类
|
||||
semaphore = asyncio.Semaphore(self.max_concurrent)
|
||||
|
||||
async def classify_dialog(idx: int, dd: DialogData):
|
||||
async with semaphore:
|
||||
try:
|
||||
ex = await self._extract_dialog_important(dd.content)
|
||||
return {
|
||||
"dialog": dd,
|
||||
"is_related": bool(ex.is_related),
|
||||
"index": idx,
|
||||
"extraction": ex
|
||||
}
|
||||
except Exception as e:
|
||||
self._log(f"[剪枝-并发] 对话 {idx} 分类失败: {str(e)[:100]}")
|
||||
return {
|
||||
"dialog": dd,
|
||||
"is_related": True, # 失败时标记为相关,避免误删
|
||||
"index": idx,
|
||||
"extraction": None
|
||||
}
|
||||
|
||||
# 并发执行所有对话的分类
|
||||
tasks = [classify_dialog(idx, dd) for idx, dd in enumerate(dialogs)]
|
||||
evaluated_dialogs = await asyncio.gather(*tasks)
|
||||
|
||||
# 统计相关 / 不相关对话
|
||||
not_related_dialogs = [d for d in evaluated_dialogs if not d["is_related"]]
|
||||
@@ -300,7 +594,6 @@ class SemanticPruner:
|
||||
inds = [i["index"] + 1 for i in items]
|
||||
if len(inds) <= cap:
|
||||
return inds
|
||||
# 超过上限时只打印前cap个,并标注总数
|
||||
return inds[:cap] + ["...", f"共{len(inds)}个"]
|
||||
|
||||
rel_inds = _fmt_indices(related_dialogs)
|
||||
@@ -309,59 +602,83 @@ class SemanticPruner:
|
||||
|
||||
result: List[DialogData] = []
|
||||
if not_related_dialogs:
|
||||
# 为每个不相关对话进行一次性抽取,识别重要/不重要(避免逐条 LLM)
|
||||
# 为每个不相关对话进行分析
|
||||
per_dialog_info = {}
|
||||
total_unrelated = 0
|
||||
total_capacity = 0
|
||||
|
||||
for d in not_related_dialogs:
|
||||
dd = d["dialog"]
|
||||
extraction = d.get("extraction")
|
||||
if extraction is None:
|
||||
extraction = await self._extract_dialog_important(dd.content)
|
||||
|
||||
# 合并所有重要标记
|
||||
tokens = extraction.times + extraction.ids + extraction.amounts + extraction.contacts + extraction.addresses + extraction.keywords
|
||||
msgs = dd.context.msgs
|
||||
# 分类消息
|
||||
imp_unrel_msgs = [m for m in msgs if self._msg_matches_tokens(m, tokens) or self._is_important_message(m)]
|
||||
unimp_unrel_msgs = [m for m in msgs if m not in imp_unrel_msgs]
|
||||
|
||||
# 识别问答对
|
||||
qa_pairs = self._identify_qa_pairs(msgs)
|
||||
protected_indices = self._get_protected_indices(msgs, qa_pairs, window_size=1)
|
||||
|
||||
# 分类消息(考虑问答对保护)
|
||||
imp_unrel_msgs = []
|
||||
unimp_unrel_msgs = []
|
||||
|
||||
for idx, m in enumerate(msgs):
|
||||
# 问答对中的消息自动标记为重要
|
||||
if idx in protected_indices:
|
||||
imp_unrel_msgs.append((idx, m))
|
||||
elif self._msg_matches_tokens(m, tokens) or self._is_important_message(m):
|
||||
imp_unrel_msgs.append((idx, m))
|
||||
elif not self._is_filler_message(m):
|
||||
unimp_unrel_msgs.append((idx, m))
|
||||
# 填充消息不加入任何列表,优先删除
|
||||
|
||||
# 重要消息按重要性排序
|
||||
imp_sorted_ids = [id(m) for m in sorted(imp_unrel_msgs, key=lambda m: self._importance_score(m))]
|
||||
imp_sorted = sorted(imp_unrel_msgs, key=lambda x: self._importance_score(x[1]))
|
||||
imp_sorted_ids = [id(m) for _, m in imp_sorted]
|
||||
|
||||
info = {
|
||||
"dialog": dd,
|
||||
"total_msgs": len(msgs),
|
||||
"unrelated_count": len(msgs),
|
||||
"imp_ids_sorted": imp_sorted_ids,
|
||||
"unimp_ids": [id(m) for m in unimp_unrel_msgs],
|
||||
"unimp_ids": [id(m) for _, m in unimp_unrel_msgs],
|
||||
"protected_indices": protected_indices,
|
||||
"qa_pairs_count": len(qa_pairs),
|
||||
}
|
||||
per_dialog_info[d["index"]] = info
|
||||
total_unrelated += info["unrelated_count"]
|
||||
# 全局删除配额:比例作用于全部不相关消息(重要+不重要)
|
||||
|
||||
# 全局删除配额计算
|
||||
global_delete = int(total_unrelated * proportion)
|
||||
if proportion > 0 and total_unrelated > 0 and global_delete == 0:
|
||||
global_delete = 1
|
||||
# 每段的最大可删容量:不重要全部 + 重要最多删除 floor(len(重要)*比例),且至少保留1条消息
|
||||
|
||||
# 每段的最大可删容量
|
||||
capacities = []
|
||||
for d in not_related_dialogs:
|
||||
idx = d["index"]
|
||||
info = per_dialog_info[idx]
|
||||
# 统计重要数量
|
||||
imp_count = len(info["imp_ids_sorted"])
|
||||
unimp_count = len(info["unimp_ids"])
|
||||
imp_cap = int(imp_count * proportion)
|
||||
cap = min(unimp_count + imp_cap, max(0, info["total_msgs"] - 1))
|
||||
capacities.append(cap)
|
||||
|
||||
total_capacity = sum(capacities)
|
||||
if global_delete > total_capacity:
|
||||
print(f"[剪枝-数据集] 不相关消息总数={total_unrelated},目标删除={global_delete},最大可删={total_capacity}(重要消息按比例保留)。将按最大可删执行。")
|
||||
self._log(f"[剪枝-数据集] 不相关消息总数={total_unrelated},目标删除={global_delete},最大可删={total_capacity}。将按最大可删执行。")
|
||||
global_delete = total_capacity
|
||||
|
||||
# 配额分配:按不相关消息占比分配到各对话,但不超过各自容量
|
||||
# 配额分配
|
||||
alloc = []
|
||||
for i, d in enumerate(not_related_dialogs):
|
||||
idx = d["index"]
|
||||
info = per_dialog_info[idx]
|
||||
share = int(global_delete * (info["unrelated_count"] / total_unrelated)) if total_unrelated > 0 else 0
|
||||
alloc.append(min(share, capacities[i]))
|
||||
|
||||
allocated = sum(alloc)
|
||||
rem = global_delete - allocated
|
||||
turn = 0
|
||||
@@ -378,34 +695,40 @@ class SemanticPruner:
|
||||
break
|
||||
turn += 1
|
||||
|
||||
# 应用删除:相关对话不动;不相关按分配先删不重要,再删重要(低分优先)
|
||||
# 应用删除
|
||||
total_deleted_confirm = 0
|
||||
for d in evaluated_dialogs:
|
||||
dd = d["dialog"]
|
||||
msgs = dd.context.msgs
|
||||
original = len(msgs)
|
||||
|
||||
if d["is_related"]:
|
||||
result.append(dd)
|
||||
continue
|
||||
|
||||
idx_in_unrel = next((k for k, x in enumerate(not_related_dialogs) if x["index"] == d["index"]), None)
|
||||
if idx_in_unrel is None:
|
||||
result.append(dd)
|
||||
continue
|
||||
|
||||
quota = alloc[idx_in_unrel]
|
||||
info = per_dialog_info[d["index"]]
|
||||
# 计算本对话重要最多可删数量
|
||||
|
||||
# 计算删除ID
|
||||
imp_count = len(info["imp_ids_sorted"])
|
||||
imp_del_cap = int(imp_count * proportion)
|
||||
# 先构造顺序删除的"不重要ID集合"(按出现顺序前 quota 条)
|
||||
|
||||
unimp_delete_ids = set(info["unimp_ids"][:min(quota, len(info["unimp_ids"]))])
|
||||
del_unimp = min(quota, len(unimp_delete_ids))
|
||||
rem_quota = quota - del_unimp
|
||||
# 再从重要里选低分优先的删除ID(不超过 imp_del_cap)
|
||||
|
||||
imp_delete_ids = set(info["imp_ids_sorted"][:min(rem_quota, imp_del_cap)])
|
||||
|
||||
deleted_here = 0
|
||||
actual_unimp_deleted = 0
|
||||
actual_imp_deleted = 0
|
||||
kept = []
|
||||
|
||||
for m in msgs:
|
||||
mid = id(m)
|
||||
if mid in unimp_delete_ids and actual_unimp_deleted < del_unimp:
|
||||
@@ -417,26 +740,30 @@ class SemanticPruner:
|
||||
deleted_here += 1
|
||||
continue
|
||||
kept.append(m)
|
||||
|
||||
if not kept and msgs:
|
||||
kept = [msgs[0]]
|
||||
|
||||
dd.context.msgs = kept
|
||||
total_deleted_confirm += deleted_here
|
||||
|
||||
qa_info = f",问答对={info['qa_pairs_count']}" if info['qa_pairs_count'] > 0 else ""
|
||||
self._log(
|
||||
f"[剪枝-对话] 对话 {d['index']+1} 总消息={original} 分配删除={quota} 实删={deleted_here} 保留={len(kept)}"
|
||||
f"[剪枝-对话] 对话 {d['index']+1} 总消息={original} 分配删除={quota} 实删={deleted_here} 保留={len(kept)}{qa_info}"
|
||||
)
|
||||
result.append(dd)
|
||||
self._log(f"[剪枝-数据集] 全局消息级顺序剪枝完成,总删除 {total_deleted_confirm} 条(不相关消息,重要按比例保留)。")
|
||||
|
||||
self._log(f"[剪枝-数据集] 全局消息级剪枝完成,总删除 {total_deleted_confirm} 条(保护问答对和上下文)。")
|
||||
else:
|
||||
# 全部相关:不执行剪枝
|
||||
result = [d["dialog"] for d in evaluated_dialogs]
|
||||
|
||||
self._log(f"[剪枝-数据集] 剩余对话数={len(result)}")
|
||||
|
||||
# 将本次剪枝阶段的终端输出保存为 JSON 文件(仅在剪枝器内部完成)
|
||||
# 保存日志
|
||||
try:
|
||||
from app.core.config import settings
|
||||
settings.ensure_memory_output_dir()
|
||||
log_output_path = settings.get_memory_output_path("pruned_terminal.json")
|
||||
# 去除日志前缀标签(如 [剪枝-数据集]、[剪枝-对话])后再解析为结构化字段保存
|
||||
sanitized_logs = [self._sanitize_log_line(l) for l in self.run_logs]
|
||||
payload = self._parse_logs_to_structured(sanitized_logs)
|
||||
with open(log_output_path, "w", encoding="utf-8") as f:
|
||||
@@ -448,6 +775,7 @@ class SemanticPruner:
|
||||
if not result:
|
||||
print("警告: 语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断")
|
||||
return dialogs
|
||||
|
||||
return result
|
||||
|
||||
def _log(self, msg: str) -> None:
|
||||
|
||||
Reference in New Issue
Block a user