Merge pull request #425 from SuanmoSuanyangTechnology/fix/2.6-bug
Fix/2.6 bug
This commit is contained in:
@@ -82,7 +82,7 @@ celery_app.conf.update(
|
||||
'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.run_forgetting_cycle_task': {'queue': 'periodic_tasks'},
|
||||
'app.controllers.memory_storage_controller.search_all': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.write_all_workspaces_memory_task': {'queue': 'periodic_tasks'},
|
||||
},
|
||||
)
|
||||
|
||||
@@ -115,16 +115,11 @@ beat_schedule_config = {
|
||||
"config_id": None, # 使用默认配置,可以通过环境变量配置
|
||||
},
|
||||
},
|
||||
"write-all-workspaces-memory": {
|
||||
"task": "app.tasks.write_all_workspaces_memory_task",
|
||||
"schedule": memory_increment_schedule,
|
||||
"args": (),
|
||||
},
|
||||
}
|
||||
|
||||
#如果配置了默认工作空间ID,则添加记忆总量统计任务
|
||||
if settings.DEFAULT_WORKSPACE_ID:
|
||||
beat_schedule_config["write-total-memory"] = {
|
||||
"task": "app.controllers.memory_storage_controller.search_all",
|
||||
"schedule": memory_increment_schedule,
|
||||
"kwargs": {
|
||||
"workspace_id": settings.DEFAULT_WORKSPACE_ID,
|
||||
},
|
||||
}
|
||||
|
||||
celery_app.conf.beat_schedule = beat_schedule_config
|
||||
|
||||
@@ -633,12 +633,11 @@ async def get_knowledge_type_stats_api(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder | Memory。
|
||||
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder。
|
||||
会对缺失类型补 0,返回字典形式。
|
||||
可选按状态过滤。
|
||||
- 知识库类型根据当前用户的 current_workspace_id 过滤
|
||||
- Memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (end_user_id) 过滤
|
||||
- 如果用户没有当前工作空间或未提供 end_user_id,对应的统计返回 0
|
||||
- 如果用户没有当前工作空间,对应的统计返回 0
|
||||
"""
|
||||
api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
|
||||
try:
|
||||
|
||||
@@ -9,6 +9,7 @@ from app.schemas.response_schema import ApiResponse
|
||||
|
||||
from app.services import memory_dashboard_service, memory_storage_service, workspace_service
|
||||
from app.services.memory_agent_service import get_end_users_connected_configs_batch
|
||||
from app.services.app_statistics_service import AppStatisticsService
|
||||
from app.core.logging_config import get_api_logger
|
||||
|
||||
# 获取API专用日志器
|
||||
@@ -469,6 +470,8 @@ async def get_chunk_insight(
|
||||
@router.get("/dashboard_data", response_model=ApiResponse)
|
||||
async def dashboard_data(
|
||||
end_user_id: Optional[str] = Query(None, description="可选的用户ID"),
|
||||
start_date: Optional[int] = Query(None, description="开始时间戳(毫秒)"),
|
||||
end_date: Optional[int] = Query(None, description="结束时间戳(毫秒)"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
@@ -503,6 +506,15 @@ async def dashboard_data(
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的dashboard整合数据")
|
||||
|
||||
# 如果没有提供时间范围,默认使用最近30天
|
||||
if start_date is None or end_date is None:
|
||||
from datetime import datetime, timedelta
|
||||
end_dt = datetime.now()
|
||||
start_dt = end_dt - timedelta(days=30)
|
||||
end_date = int(end_dt.timestamp() * 1000)
|
||||
start_date = int(start_dt.timestamp() * 1000)
|
||||
api_logger.info(f"使用默认时间范围: {start_dt} 到 {end_dt}")
|
||||
|
||||
# 获取 storage_type,如果为 None 则使用默认值
|
||||
storage_type = workspace_service.get_workspace_storage_type(
|
||||
db=db,
|
||||
@@ -563,17 +575,22 @@ async def dashboard_data(
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取知识库类型统计失败: {str(e)}")
|
||||
|
||||
# 3. 获取API调用增量(total_api_call,转换为整数)
|
||||
# 3. 获取API调用统计(total_api_call)
|
||||
try:
|
||||
api_increment = memory_dashboard_service.get_workspace_api_increment(
|
||||
db=db,
|
||||
# 使用 AppStatisticsService 获取真实的API调用统计
|
||||
app_stats_service = AppStatisticsService(db)
|
||||
api_stats = app_stats_service.get_workspace_api_statistics(
|
||||
workspace_id=workspace_id,
|
||||
current_user=current_user
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
neo4j_data["total_api_call"] = api_increment
|
||||
api_logger.info(f"成功获取API调用增量: {neo4j_data['total_api_call']}")
|
||||
# 计算总调用次数
|
||||
total_api_calls = sum(item.get("total_calls", 0) for item in api_stats)
|
||||
neo4j_data["total_api_call"] = total_api_calls
|
||||
api_logger.info(f"成功获取API调用统计: {neo4j_data['total_api_call']}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取API调用增量失败: {str(e)}")
|
||||
api_logger.error(f"获取API调用统计失败: {str(e)}")
|
||||
neo4j_data["total_api_call"] = 0
|
||||
|
||||
result["neo4j_data"] = neo4j_data
|
||||
api_logger.info("成功获取neo4j_data")
|
||||
@@ -602,10 +619,23 @@ async def dashboard_data(
|
||||
total_kb = memory_dashboard_service.get_rag_total_kb(db, current_user)
|
||||
rag_data["total_knowledge"] = total_kb
|
||||
|
||||
# total_api_call: 固定值
|
||||
rag_data["total_api_call"] = 1024
|
||||
# total_api_call: 使用 AppStatisticsService 获取真实的API调用统计
|
||||
try:
|
||||
app_stats_service = AppStatisticsService(db)
|
||||
api_stats = app_stats_service.get_workspace_api_statistics(
|
||||
workspace_id=workspace_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
# 计算总调用次数
|
||||
total_api_calls = sum(item.get("total_calls", 0) for item in api_stats)
|
||||
rag_data["total_api_call"] = total_api_calls
|
||||
api_logger.info(f"成功获取RAG模式API调用统计: {rag_data['total_api_call']}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取RAG模式API调用统计失败,使用默认值: {str(e)}")
|
||||
rag_data["total_api_call"] = 0
|
||||
|
||||
api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={len(apps_orm)}, knowledge={total_kb}")
|
||||
api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={len(apps_orm)}, knowledge={total_kb}, api_calls={rag_data['total_api_call']}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取RAG相关数据失败: {str(e)}")
|
||||
|
||||
|
||||
@@ -201,7 +201,6 @@ class Settings:
|
||||
REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300"))
|
||||
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
|
||||
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
|
||||
DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None)
|
||||
REFLECTION_INTERVAL_TIME: Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
|
||||
|
||||
# Memory Cache Regeneration Configuration
|
||||
|
||||
@@ -21,7 +21,7 @@ async def get_chunked_dialogs(
|
||||
end_user_id: Group identifier
|
||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||
ref_id: Reference identifier
|
||||
config_id: Configuration ID for processing
|
||||
config_id: Configuration ID for processing (used to load pruning config)
|
||||
|
||||
Returns:
|
||||
List of DialogData objects with generated chunks
|
||||
@@ -57,6 +57,61 @@ async def get_chunked_dialogs(
|
||||
end_user_id=end_user_id,
|
||||
config_id=config_id
|
||||
)
|
||||
|
||||
# 语义剪枝步骤(在分块之前)
|
||||
try:
|
||||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import SemanticPruner
|
||||
from app.core.memory.models.config_models import PruningConfig
|
||||
from app.db import get_db_context
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
|
||||
# 加载剪枝配置
|
||||
pruning_config = None
|
||||
if config_id:
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
# 使用 MemoryConfigService 加载完整的 MemoryConfig 对象
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=config_id,
|
||||
service_name="semantic_pruning"
|
||||
)
|
||||
|
||||
if memory_config:
|
||||
pruning_config = PruningConfig(
|
||||
pruning_switch=memory_config.pruning_enabled,
|
||||
pruning_scene=memory_config.pruning_scene or "education",
|
||||
pruning_threshold=memory_config.pruning_threshold
|
||||
)
|
||||
logger.info(f"[剪枝] 加载配置: switch={pruning_config.pruning_switch}, scene={pruning_config.pruning_scene}, threshold={pruning_config.pruning_threshold}")
|
||||
|
||||
# 获取LLM客户端用于剪枝
|
||||
if pruning_config.pruning_switch:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
|
||||
# 执行剪枝 - 使用 prune_dataset 支持消息级剪枝
|
||||
pruner = SemanticPruner(config=pruning_config, llm_client=llm_client)
|
||||
original_msg_count = len(dialog_data.context.msgs)
|
||||
|
||||
# 使用 prune_dataset 而不是 prune_dialog
|
||||
# prune_dataset 会进行消息级剪枝,即使对话整体相关也会删除不重要消息
|
||||
pruned_dialogs = await pruner.prune_dataset([dialog_data])
|
||||
|
||||
if pruned_dialogs:
|
||||
dialog_data = pruned_dialogs[0]
|
||||
remaining_msg_count = len(dialog_data.context.msgs)
|
||||
deleted_count = original_msg_count - remaining_msg_count
|
||||
logger.info(f"[剪枝] 完成: 原始{original_msg_count}条 -> 保留{remaining_msg_count}条 (删除{deleted_count}条)")
|
||||
else:
|
||||
logger.warning("[剪枝] prune_dataset 返回空列表")
|
||||
else:
|
||||
logger.info("[剪枝] 配置中剪枝开关关闭,跳过剪枝")
|
||||
except Exception as e:
|
||||
logger.warning(f"[剪枝] 加载配置失败,跳过剪枝: {e}", exc_info=True)
|
||||
except Exception as e:
|
||||
logger.warning(f"[剪枝] 执行失败,跳过剪枝: {e}", exc_info=True)
|
||||
|
||||
chunker = DialogueChunker(chunker_strategy)
|
||||
extracted_chunks = await chunker.process_dialogue(dialog_data)
|
||||
|
||||
@@ -139,14 +139,14 @@ async def get_raw_tags_from_db(
|
||||
|
||||
return [(record["name"], record["frequency"]) for record in results]
|
||||
|
||||
async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]:
|
||||
async def get_hot_memory_tags(end_user_id: str, limit: int = 10, by_user: bool = False) -> List[Tuple[str, int]]:
|
||||
"""
|
||||
获取原始标签,然后使用LLM进行筛选,返回最终的热门标签列表。
|
||||
查询更多的标签(limit=40)给LLM提供更丰富的上下文进行筛选。
|
||||
查询更多的标签(40条)给LLM提供更丰富的上下文进行筛选,但最终返回数量由limit参数控制。
|
||||
|
||||
Args:
|
||||
end_user_id: 必需参数。如果by_user=False,则为end_user_id;如果by_user=True,则为user_id
|
||||
limit: 返回的标签数量限制
|
||||
limit: 最终返回的标签数量限制(默认10)
|
||||
by_user: 是否按user_id查询(默认False,按end_user_id查询)
|
||||
|
||||
Raises:
|
||||
@@ -161,8 +161,9 @@ async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool =
|
||||
# 使用项目的Neo4jConnector
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
# 1. 从数据库获取原始排名靠前的标签
|
||||
raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, limit, by_user=by_user)
|
||||
# 1. 从数据库获取原始排名靠前的标签(查询40条给LLM提供更丰富的上下文)
|
||||
query_limit = 40
|
||||
raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, query_limit, by_user=by_user)
|
||||
if not raw_tags_with_freq:
|
||||
return []
|
||||
|
||||
@@ -177,7 +178,8 @@ async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool =
|
||||
if tag in meaningful_tag_names:
|
||||
final_tags.append((tag, freq))
|
||||
|
||||
return final_tags
|
||||
# 4. 限制返回的标签数量
|
||||
return final_tags[:limit]
|
||||
finally:
|
||||
# 确保关闭连接
|
||||
await connector.close()
|
||||
|
||||
@@ -5,20 +5,27 @@
|
||||
- 对话级一次性抽取判定相关性
|
||||
- 仅对"不相关对话"的消息按比例删除
|
||||
- 重要信息(时间、编号、金额、联系方式、地址等)优先保留
|
||||
- 改进版:增强重要性判断、智能填充消息识别、问答对保护、并发优化
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import hashlib
|
||||
import json
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Dict, Tuple, Set
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.memory.models.message_models import DialogData, ConversationMessage, ConversationContext
|
||||
from app.core.memory.models.config_models import PruningConfig
|
||||
from app.core.memory.utils.config.config_utils import get_pruning_config
|
||||
from app.core.memory.utils.prompt.prompt_utils import prompt_env, log_prompt_rendering, log_template_rendering
|
||||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene_config import (
|
||||
SceneConfigRegistry,
|
||||
ScenePatterns
|
||||
)
|
||||
|
||||
|
||||
class DialogExtractionResponse(BaseModel):
|
||||
@@ -36,6 +43,23 @@ class DialogExtractionResponse(BaseModel):
|
||||
keywords: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class MessageImportanceResponse(BaseModel):
|
||||
"""消息重要性批量判断的结构化返回(用于LLM语义判断)。
|
||||
|
||||
- importance_scores: 消息索引到重要性分数的映射 (0-10分)
|
||||
- reasons: 可选的判断理由
|
||||
"""
|
||||
importance_scores: Dict[int, int] = Field(default_factory=dict, description="消息索引到重要性分数(0-10)的映射")
|
||||
reasons: Optional[Dict[int, str]] = Field(default_factory=dict, description="可选的判断理由")
|
||||
|
||||
|
||||
class QAPair(BaseModel):
|
||||
"""问答对模型,用于识别和保护对话中的问答结构。"""
|
||||
question_idx: int = Field(..., description="问题消息的索引")
|
||||
answer_idx: int = Field(..., description="答案消息的索引")
|
||||
confidence: float = Field(default=1.0, description="问答对的置信度(0-1)")
|
||||
|
||||
|
||||
class SemanticPruner:
|
||||
"""语义剪枝:在预处理与分块之间过滤与场景不相关内容。
|
||||
|
||||
@@ -43,109 +67,374 @@ class SemanticPruner:
|
||||
重要信息(时间、编号、金额、联系方式、地址等)优先保留。
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[PruningConfig] = None, llm_client=None):
|
||||
cfg_dict = get_pruning_config() if config is None else config.model_dump()
|
||||
self.config = PruningConfig.model_validate(cfg_dict)
|
||||
def __init__(self, config: Optional[PruningConfig] = None, llm_client=None, language: str = "zh", max_concurrent: int = 5):
|
||||
# 如果没有提供config,使用默认配置
|
||||
if config is None:
|
||||
# 使用默认的剪枝配置
|
||||
config = PruningConfig(
|
||||
pruning_switch=False, # 默认关闭剪枝,保持向后兼容
|
||||
pruning_scene="education",
|
||||
pruning_threshold=0.5
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.llm_client = llm_client
|
||||
self.language = language # 保存语言配置
|
||||
self.max_concurrent = max_concurrent # 新增:最大并发数
|
||||
|
||||
# 详细日志配置:限制逐条消息日志的数量
|
||||
self._detailed_prune_logging = True # 是否启用详细日志
|
||||
self._max_debug_msgs_per_dialog = 20 # 每个对话最多记录前N条消息的详细日志
|
||||
|
||||
# 加载场景特定配置
|
||||
self.scene_config: ScenePatterns = SceneConfigRegistry.get_config(
|
||||
self.config.pruning_scene,
|
||||
fallback_to_generic=True
|
||||
)
|
||||
|
||||
# 检查场景是否有专门支持
|
||||
is_supported = SceneConfigRegistry.is_scene_supported(self.config.pruning_scene)
|
||||
if is_supported:
|
||||
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 使用专门配置")
|
||||
else:
|
||||
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 未预定义,使用通用配置(保守策略)")
|
||||
self._log(f"[剪枝-初始化] 支持的场景: {SceneConfigRegistry.get_all_scenes()}")
|
||||
|
||||
# Load Jinja2 template
|
||||
self.template = prompt_env.get_template("extracat_Pruning.jinja2")
|
||||
# 对话抽取缓存:避免同一对话重复调用 LLM / 重复渲染
|
||||
self._dialog_extract_cache: dict[str, DialogExtractionResponse] = {}
|
||||
|
||||
# 对话抽取缓存:使用 OrderedDict 实现 LRU 缓存
|
||||
self._dialog_extract_cache: OrderedDict[str, DialogExtractionResponse] = OrderedDict()
|
||||
self._cache_max_size = 1000 # 缓存大小限制
|
||||
|
||||
# 运行日志:收集关键终端输出,便于写入 JSON
|
||||
self.run_logs: List[str] = []
|
||||
# 采用顺序处理,移除并发配置以简化与稳定执行
|
||||
|
||||
def _is_important_message(self, message: ConversationMessage) -> bool:
|
||||
"""基于启发式规则识别重要信息消息,优先保留。
|
||||
|
||||
- 含日期/时间(如YYYY-MM-DD、HH:MM、2024年11月10日、上午/下午)。
|
||||
- 含编号/ID/订单号/申请号/账号/电话/金额等关键字段。
|
||||
- 关键词:"时间"、"日期"、"编号"、"订单"、"流水"、"金额"、"¥"、"元"、"电话"、"手机号"、"邮箱"、"地址"。
|
||||
改进版:使用场景特定的模式进行识别
|
||||
- 根据 pruning_scene 动态加载对应的识别规则
|
||||
- 支持教育、在线服务、外呼三个场景的特定模式
|
||||
"""
|
||||
import re
|
||||
text = message.msg.strip()
|
||||
if not text:
|
||||
return False
|
||||
patterns = [
|
||||
r"\b\d{4}-\d{1,2}-\d{1,2}\b",
|
||||
r"\b\d{1,2}:\d{2}\b",
|
||||
r"\d{4}年\d{1,2}月\d{1,2}日",
|
||||
r"上午|下午|AM|PM",
|
||||
r"订单号|工单|申请号|编号|ID|账号|账户",
|
||||
r"电话|手机号|微信|QQ|邮箱",
|
||||
r"地址|地点",
|
||||
r"金额|费用|价格|¥|¥|\d+元",
|
||||
r"时间|日期|有效期|截止",
|
||||
]
|
||||
for p in patterns:
|
||||
if re.search(p, text, flags=re.IGNORECASE):
|
||||
|
||||
# 使用场景特定的模式
|
||||
all_patterns = (
|
||||
self.scene_config.high_priority_patterns +
|
||||
self.scene_config.medium_priority_patterns +
|
||||
self.scene_config.low_priority_patterns
|
||||
)
|
||||
|
||||
for pattern, _ in all_patterns:
|
||||
if re.search(pattern, text, flags=re.IGNORECASE):
|
||||
return True
|
||||
|
||||
# 检查是否为问句(以问号结尾或包含疑问词)
|
||||
if text.endswith("?") or text.endswith("?"):
|
||||
return True
|
||||
|
||||
# 检查是否包含问句关键词
|
||||
if any(keyword in text for keyword in self.scene_config.question_keywords):
|
||||
return True
|
||||
|
||||
# 检查是否包含决策性关键词
|
||||
if any(keyword in text for keyword in self.scene_config.decision_keywords):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _importance_score(self, message: ConversationMessage) -> int:
|
||||
"""为重要消息打分,用于在保留比例内优先保留更关键的内容。
|
||||
|
||||
简单启发:匹配到的类别越多、越关键分值越高。
|
||||
改进版:使用场景特定的权重体系(0-10分)
|
||||
- 根据场景动态调整不同信息类型的权重
|
||||
- 高优先级模式:4-6分
|
||||
- 中优先级模式:2-3分
|
||||
- 低优先级模式:1分
|
||||
"""
|
||||
import re
|
||||
text = message.msg.strip()
|
||||
score = 0
|
||||
weights = [
|
||||
(r"\b\d{4}-\d{1,2}-\d{1,2}\b", 3),
|
||||
(r"\b\d{1,2}:\d{2}\b", 2),
|
||||
(r"\d{4}年\d{1,2}月\d{1,2}日", 3),
|
||||
(r"订单号|工单|申请号|编号|ID|账号|账户", 4),
|
||||
(r"电话|手机号|微信|QQ|邮箱", 3),
|
||||
(r"地址|地点", 2),
|
||||
(r"金额|费用|价格|¥|¥|\d+元", 4),
|
||||
(r"时间|日期|有效期|截止", 2),
|
||||
]
|
||||
for p, w in weights:
|
||||
if re.search(p, text, flags=re.IGNORECASE):
|
||||
score += w
|
||||
return score
|
||||
|
||||
# 使用场景特定的权重
|
||||
for pattern, weight in self.scene_config.high_priority_patterns:
|
||||
if re.search(pattern, text, flags=re.IGNORECASE):
|
||||
score += weight
|
||||
|
||||
for pattern, weight in self.scene_config.medium_priority_patterns:
|
||||
if re.search(pattern, text, flags=re.IGNORECASE):
|
||||
score += weight
|
||||
|
||||
for pattern, weight in self.scene_config.low_priority_patterns:
|
||||
if re.search(pattern, text, flags=re.IGNORECASE):
|
||||
score += weight
|
||||
|
||||
# 问句加分
|
||||
if text.endswith("?") or text.endswith("?"):
|
||||
score += 2
|
||||
|
||||
# 包含问句关键词加分
|
||||
if any(keyword in text for keyword in self.scene_config.question_keywords):
|
||||
score += 1
|
||||
|
||||
# 包含决策性关键词加分
|
||||
if any(keyword in text for keyword in self.scene_config.decision_keywords):
|
||||
score += 2
|
||||
|
||||
# 长度加分(较长的消息通常包含更多信息)
|
||||
if len(text) > 50:
|
||||
score += 1
|
||||
if len(text) > 100:
|
||||
score += 1
|
||||
|
||||
return min(score, 10) # 最高10分
|
||||
|
||||
def _is_filler_message(self, message: ConversationMessage) -> bool:
|
||||
"""检测典型寒暄/口头禅/确认类短消息,用于跳过LLM分类以加速。
|
||||
"""检测典型寒暄/口头禅/确认类短消息。
|
||||
|
||||
改进版:更严格的填充消息判断,避免误删场景相关内容
|
||||
满足以下之一视为填充消息:
|
||||
- 纯标点或长度很短(<= 4 个汉字或 <= 8 个字符)且不包含数字或关键实体;
|
||||
- 常见词:你好/您好/在吗/嗯/嗯嗯/哦/好的/好/行/可以/不可以/谢谢/拜拜/再见/哈哈/呵呵/哈哈哈/。。。/??。
|
||||
- 纯标点或空白
|
||||
- 在场景特定填充词库中(精确匹配)
|
||||
- 纯表情符号
|
||||
- 常见寒暄(精确匹配短语)
|
||||
|
||||
注意:不再使用长度判断,避免误删短但重要的消息
|
||||
"""
|
||||
import re
|
||||
t = message.msg.strip()
|
||||
if not t:
|
||||
return True
|
||||
# 常见填充语
|
||||
fillers = [
|
||||
"你好", "您好", "在吗", "嗯", "嗯嗯", "哦", "好的", "好", "行", "可以", "不可以", "谢谢",
|
||||
"拜拜", "再见", "哈哈", "呵呵", "哈哈哈", "。。。", "??", "??"
|
||||
]
|
||||
if t in fillers:
|
||||
|
||||
# 检查是否在场景特定填充词库中(精确匹配)
|
||||
if t in self.scene_config.filler_phrases:
|
||||
return True
|
||||
# 长度与字符类型判断
|
||||
if len(t) <= 8:
|
||||
# 非数字、无关键实体的短文本
|
||||
if not re.search(r"[0-9]", t) and not self._is_important_message(message):
|
||||
# 主要是标点或简单确认词
|
||||
if re.fullmatch(r"[。!?,.!?…·\s]+", t) or t in fillers:
|
||||
return True
|
||||
|
||||
# 常见寒暄和问候(精确匹配,避免误删)
|
||||
common_greetings = {
|
||||
"在吗", "在不在", "在呢", "在的",
|
||||
"你好", "您好", "hello", "hi",
|
||||
"拜拜", "再见", "拜", "88", "bye",
|
||||
"好的", "好", "行", "可以", "嗯", "哦", "啊",
|
||||
"是的", "对", "对的", "没错", "是啊",
|
||||
"哈哈", "呵呵", "嘿嘿", "嗯嗯"
|
||||
}
|
||||
if t in common_greetings:
|
||||
return True
|
||||
|
||||
# 检查是否为纯表情符号(方括号包裹)
|
||||
if re.fullmatch(r"(\[[^\]]+\])+", t):
|
||||
return True
|
||||
|
||||
# 检查是否为纯emoji(Unicode表情)
|
||||
emoji_pattern = re.compile(
|
||||
"["
|
||||
"\U0001F600-\U0001F64F" # 表情符号
|
||||
"\U0001F300-\U0001F5FF" # 符号和象形文字
|
||||
"\U0001F680-\U0001F6FF" # 交通和地图符号
|
||||
"\U0001F1E0-\U0001F1FF" # 旗帜
|
||||
"\U00002702-\U000027B0"
|
||||
"\U000024C2-\U0001F251"
|
||||
"]+", flags=re.UNICODE
|
||||
)
|
||||
if emoji_pattern.fullmatch(t):
|
||||
return True
|
||||
|
||||
# 纯标点符号
|
||||
if re.fullmatch(r"[。!?,.!?…·\s]+", t):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def _batch_evaluate_importance_with_llm(
|
||||
self,
|
||||
messages: List[ConversationMessage],
|
||||
context: str = ""
|
||||
) -> Dict[int, int]:
|
||||
"""使用LLM批量评估消息的重要性(语义层面)。
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
context: 对话上下文(可选)
|
||||
|
||||
Returns:
|
||||
消息索引到重要性分数(0-10)的映射
|
||||
"""
|
||||
if not self.llm_client or not messages:
|
||||
return {}
|
||||
|
||||
# 构建批量评估的提示词
|
||||
msg_list = []
|
||||
for idx, msg in enumerate(messages):
|
||||
msg_list.append(f"{idx}. {msg.msg}")
|
||||
|
||||
msg_text = "\n".join(msg_list)
|
||||
|
||||
prompt = f"""请评估以下消息的重要性,给每条消息打分(0-10分):
|
||||
- 0-2分:无意义的寒暄、口头禅、纯表情
|
||||
- 3-5分:一般性对话,有一定信息量但不关键
|
||||
- 6-8分:包含重要信息(时间、地点、人物、事件等)
|
||||
- 9-10分:关键决策、承诺、重要数据
|
||||
|
||||
对话上下文:
|
||||
{context if context else "无"}
|
||||
|
||||
待评估的消息:
|
||||
{msg_text}
|
||||
|
||||
请以JSON格式返回,格式为:
|
||||
{{
|
||||
"importance_scores": {{
|
||||
"0": 分数,
|
||||
"1": 分数,
|
||||
...
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
|
||||
try:
|
||||
messages_for_llm = [
|
||||
{"role": "system", "content": "你是一个专业的对话分析助手,擅长评估消息的重要性。"},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
|
||||
response = await self.llm_client.response_structured(
|
||||
messages_for_llm,
|
||||
MessageImportanceResponse
|
||||
)
|
||||
|
||||
# 转换字符串键为整数键
|
||||
return {int(k): v for k, v in response.importance_scores.items()}
|
||||
except Exception as e:
|
||||
self._log(f"[剪枝-LLM] 批量重要性评估失败: {str(e)[:100]}")
|
||||
return {}
|
||||
|
||||
def _identify_qa_pairs(self, messages: List[ConversationMessage]) -> List[QAPair]:
|
||||
"""识别对话中的问答对,用于保护问答结构的完整性。
|
||||
|
||||
改进版:使用场景特定的问句关键词,并排除寒暄类问句
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
|
||||
Returns:
|
||||
问答对列表
|
||||
"""
|
||||
qa_pairs = []
|
||||
|
||||
# 寒暄类问句,不应该被保护(这些不是真正的问答)
|
||||
greeting_questions = {
|
||||
"在吗", "在不在", "你好吗", "怎么样", "好吗",
|
||||
"有空吗", "忙吗", "睡了吗", "起床了吗"
|
||||
}
|
||||
|
||||
for i in range(len(messages) - 1):
|
||||
current_msg = messages[i].msg.strip()
|
||||
next_msg = messages[i + 1].msg.strip()
|
||||
|
||||
# 排除寒暄类问句
|
||||
if current_msg in greeting_questions:
|
||||
continue
|
||||
|
||||
# 使用场景特定的问句关键词,但要求更严格
|
||||
is_question = False
|
||||
|
||||
# 1. 以问号结尾
|
||||
if current_msg.endswith("?") or current_msg.endswith("?"):
|
||||
is_question = True
|
||||
# 2. 包含实质性问句关键词(排除"吗"这种太宽泛的)
|
||||
elif any(word in current_msg for word in ["什么", "为什么", "怎么", "如何", "哪里", "哪个", "谁", "多少", "几点", "何时"]):
|
||||
is_question = True
|
||||
|
||||
if is_question and next_msg:
|
||||
# 检查下一条消息是否像答案(不是另一个问句,也不是寒暄)
|
||||
is_answer = not (next_msg.endswith("?") or next_msg.endswith("?"))
|
||||
|
||||
# 排除寒暄类回复
|
||||
greeting_answers = {"你好", "您好", "在呢", "在的", "嗯", "哦", "好的"}
|
||||
if next_msg in greeting_answers:
|
||||
is_answer = False
|
||||
|
||||
if is_answer:
|
||||
qa_pairs.append(QAPair(
|
||||
question_idx=i,
|
||||
answer_idx=i + 1,
|
||||
confidence=0.8 # 基于规则的置信度
|
||||
))
|
||||
|
||||
return qa_pairs
|
||||
|
||||
def _get_protected_indices(
|
||||
self,
|
||||
messages: List[ConversationMessage],
|
||||
qa_pairs: List[QAPair],
|
||||
window_size: int = 2
|
||||
) -> Set[int]:
|
||||
"""获取需要保护的消息索引集合(问答对+上下文窗口)。
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
qa_pairs: 问答对列表
|
||||
window_size: 上下文窗口大小(前后各保留几条消息)
|
||||
|
||||
Returns:
|
||||
需要保护的消息索引集合
|
||||
"""
|
||||
protected = set()
|
||||
|
||||
for qa_pair in qa_pairs:
|
||||
# 保护问答对本身
|
||||
protected.add(qa_pair.question_idx)
|
||||
protected.add(qa_pair.answer_idx)
|
||||
|
||||
# 保护上下文窗口
|
||||
for offset in range(-window_size, window_size + 1):
|
||||
q_idx = qa_pair.question_idx + offset
|
||||
a_idx = qa_pair.answer_idx + offset
|
||||
|
||||
if 0 <= q_idx < len(messages):
|
||||
protected.add(q_idx)
|
||||
if 0 <= a_idx < len(messages):
|
||||
protected.add(a_idx)
|
||||
|
||||
return protected
|
||||
|
||||
async def _extract_dialog_important(self, dialog_text: str) -> DialogExtractionResponse:
|
||||
"""对话级一次性抽取:从整段对话中提取重要信息并判定相关性。
|
||||
|
||||
- 仅使用 LLM 结构化输出;
|
||||
改进版:
|
||||
- LRU缓存管理
|
||||
- 重试机制
|
||||
- 降级策略
|
||||
"""
|
||||
# 缓存命中则直接返回(场景+内容作为键)
|
||||
cache_key = f"{self.config.pruning_scene}:" + hashlib.sha1(dialog_text.encode("utf-8")).hexdigest()
|
||||
|
||||
# LRU缓存:如果命中,移到末尾(最近使用)
|
||||
if cache_key in self._dialog_extract_cache:
|
||||
self._dialog_extract_cache.move_to_end(cache_key)
|
||||
return self._dialog_extract_cache[cache_key]
|
||||
|
||||
rendered = self.template.render(pruning_scene=self.config.pruning_scene, dialog_text=dialog_text)
|
||||
log_template_rendering("extracat_Pruning.jinja2", {"pruning_scene": self.config.pruning_scene})
|
||||
# LRU缓存大小限制:超过限制时删除最旧的条目
|
||||
if len(self._dialog_extract_cache) >= self._cache_max_size:
|
||||
# 删除最旧的条目(OrderedDict的第一个)
|
||||
oldest_key = next(iter(self._dialog_extract_cache))
|
||||
del self._dialog_extract_cache[oldest_key]
|
||||
self._log(f"[剪枝-缓存] LRU缓存已满,删除最旧条目")
|
||||
|
||||
rendered = self.template.render(
|
||||
pruning_scene=self.config.pruning_scene,
|
||||
dialog_text=dialog_text,
|
||||
language=self.language
|
||||
)
|
||||
log_template_rendering("extracat_Pruning.jinja2", {
|
||||
"pruning_scene": self.config.pruning_scene,
|
||||
"language": self.language
|
||||
})
|
||||
log_prompt_rendering("pruning-extract", rendered)
|
||||
|
||||
# 强制使用 LLM;移除正则回退
|
||||
# 强制使用 LLM
|
||||
if not self.llm_client:
|
||||
raise RuntimeError("llm_client 未配置;请配置 LLM 以进行结构化抽取。")
|
||||
|
||||
@@ -153,12 +442,32 @@ class SemanticPruner:
|
||||
{"role": "system", "content": "你是一个严谨的场景抽取助手,只输出严格 JSON。"},
|
||||
{"role": "user", "content": rendered},
|
||||
]
|
||||
try:
|
||||
ex = await self.llm_client.response_structured(messages, DialogExtractionResponse)
|
||||
self._dialog_extract_cache[cache_key] = ex
|
||||
return ex
|
||||
except Exception as e:
|
||||
raise RuntimeError("LLM 结构化抽取失败;请检查 LLM 配置或重试。") from e
|
||||
|
||||
# 重试机制
|
||||
max_retries = 3
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
ex = await self.llm_client.response_structured(messages, DialogExtractionResponse)
|
||||
self._dialog_extract_cache[cache_key] = ex
|
||||
return ex
|
||||
except Exception as e:
|
||||
if attempt < max_retries - 1:
|
||||
self._log(f"[剪枝-LLM] 第 {attempt + 1} 次尝试失败,重试中... 错误: {str(e)[:100]}")
|
||||
await asyncio.sleep(0.5 * (attempt + 1)) # 指数退避
|
||||
continue
|
||||
else:
|
||||
# 降级策略:标记为相关,避免误删
|
||||
self._log(f"[剪枝-LLM] LLM 调用失败 {max_retries} 次,使用降级策略(标记为相关)")
|
||||
fallback_response = DialogExtractionResponse(
|
||||
is_related=True,
|
||||
times=[],
|
||||
ids=[],
|
||||
amounts=[],
|
||||
contacts=[],
|
||||
addresses=[],
|
||||
keywords=[]
|
||||
)
|
||||
return fallback_response
|
||||
|
||||
def _msg_matches_tokens(self, message: ConversationMessage, tokens: List[str]) -> bool:
|
||||
"""判断消息是否包含任意抽取到的重要片段。"""
|
||||
@@ -248,12 +557,14 @@ class SemanticPruner:
|
||||
async def prune_dataset(self, dialogs: List[DialogData]) -> List[DialogData]:
|
||||
"""数据集层面:全局消息级剪枝,保留所有对话。
|
||||
|
||||
- 仅在"不相关对话"的范围内执行消息剪枝;相关对话不动。
|
||||
- 只删除"不重要的不相关消息",重要信息(时间、编号等)强制保留。
|
||||
- 删除总量 = 阈值 * 全部不相关可删消息数,按可删容量比例分配;顺序删除。
|
||||
- 保证每段对话至少保留1条消息,不会删除整段对话。
|
||||
改进版:
|
||||
- 消息级独立判断,每条消息根据场景规则独立评估
|
||||
- 问答对保护已注释(暂不启用,留作观察)
|
||||
- 优化删除策略:填充消息 → 不重要消息 → 低分重要消息
|
||||
- 只删除"不重要的不相关消息",重要信息(时间、编号等)强制保留
|
||||
- 保证每段对话至少保留1条消息,不会删除整段对话
|
||||
"""
|
||||
# 如果剪枝功能关闭,直接返回原始数据集。
|
||||
# 如果剪枝功能关闭,直接返回原始数据集
|
||||
if not self.config.pruning_switch:
|
||||
return dialogs
|
||||
|
||||
@@ -264,179 +575,140 @@ class SemanticPruner:
|
||||
proportion = 0.9
|
||||
if proportion < 0.0:
|
||||
proportion = 0.0
|
||||
evaluated_dialogs = [] # list of dicts: {dialog, is_related}
|
||||
|
||||
self._log(
|
||||
f"[剪枝-数据集] 对话总数={len(dialogs)} 场景={self.config.pruning_scene} 删除比例={proportion} 开关={self.config.pruning_switch}"
|
||||
f"[剪枝-数据集] 对话总数={len(dialogs)} 场景={self.config.pruning_scene} 删除比例={proportion} 开关={self.config.pruning_switch} 模式=消息级独立判断"
|
||||
)
|
||||
# 对话级相关性分类(一次性对整段对话文本进行判断,顺序执行并复用缓存)
|
||||
evaluated_dialogs = []
|
||||
for idx, dd in enumerate(dialogs):
|
||||
try:
|
||||
ex = await self._extract_dialog_important(dd.content)
|
||||
evaluated_dialogs.append({
|
||||
"dialog": dd,
|
||||
"is_related": bool(ex.is_related),
|
||||
"index": idx,
|
||||
"extraction": ex
|
||||
})
|
||||
except Exception:
|
||||
evaluated_dialogs.append({
|
||||
"dialog": dd,
|
||||
"is_related": True,
|
||||
"index": idx,
|
||||
"extraction": None
|
||||
})
|
||||
|
||||
# 统计相关 / 不相关对话
|
||||
not_related_dialogs = [d for d in evaluated_dialogs if not d["is_related"]]
|
||||
related_dialogs = [d for d in evaluated_dialogs if d["is_related"]]
|
||||
self._log(
|
||||
f"[剪枝-数据集] 相关对话数={len(related_dialogs)} 不相关对话数={len(not_related_dialogs)}"
|
||||
)
|
||||
|
||||
# 简洁打印第几段对话相关/不相关(索引基于1)
|
||||
def _fmt_indices(items, cap: int = 10):
|
||||
inds = [i["index"] + 1 for i in items]
|
||||
if len(inds) <= cap:
|
||||
return inds
|
||||
# 超过上限时只打印前cap个,并标注总数
|
||||
return inds[:cap] + ["...", f"共{len(inds)}个"]
|
||||
|
||||
rel_inds = _fmt_indices(related_dialogs)
|
||||
nrel_inds = _fmt_indices(not_related_dialogs)
|
||||
self._log(f"[剪枝-数据集] 相关对话:第{rel_inds}段;不相关对话:第{nrel_inds}段")
|
||||
|
||||
|
||||
result: List[DialogData] = []
|
||||
if not_related_dialogs:
|
||||
# 为每个不相关对话进行一次性抽取,识别重要/不重要(避免逐条 LLM)
|
||||
per_dialog_info = {}
|
||||
total_unrelated = 0
|
||||
total_capacity = 0
|
||||
for d in not_related_dialogs:
|
||||
dd = d["dialog"]
|
||||
extraction = d.get("extraction")
|
||||
if extraction is None:
|
||||
extraction = await self._extract_dialog_important(dd.content)
|
||||
# 合并所有重要标记
|
||||
tokens = extraction.times + extraction.ids + extraction.amounts + extraction.contacts + extraction.addresses + extraction.keywords
|
||||
msgs = dd.context.msgs
|
||||
# 分类消息
|
||||
imp_unrel_msgs = [m for m in msgs if self._msg_matches_tokens(m, tokens) or self._is_important_message(m)]
|
||||
unimp_unrel_msgs = [m for m in msgs if m not in imp_unrel_msgs]
|
||||
# 重要消息按重要性排序
|
||||
imp_sorted_ids = [id(m) for m in sorted(imp_unrel_msgs, key=lambda m: self._importance_score(m))]
|
||||
info = {
|
||||
"dialog": dd,
|
||||
"total_msgs": len(msgs),
|
||||
"unrelated_count": len(msgs),
|
||||
"imp_ids_sorted": imp_sorted_ids,
|
||||
"unimp_ids": [id(m) for m in unimp_unrel_msgs],
|
||||
}
|
||||
per_dialog_info[d["index"]] = info
|
||||
total_unrelated += info["unrelated_count"]
|
||||
# 全局删除配额:比例作用于全部不相关消息(重要+不重要)
|
||||
global_delete = int(total_unrelated * proportion)
|
||||
if proportion > 0 and total_unrelated > 0 and global_delete == 0:
|
||||
global_delete = 1
|
||||
# 每段的最大可删容量:不重要全部 + 重要最多删除 floor(len(重要)*比例),且至少保留1条消息
|
||||
capacities = []
|
||||
for d in not_related_dialogs:
|
||||
idx = d["index"]
|
||||
info = per_dialog_info[idx]
|
||||
# 统计重要数量
|
||||
imp_count = len(info["imp_ids_sorted"])
|
||||
unimp_count = len(info["unimp_ids"])
|
||||
imp_cap = int(imp_count * proportion)
|
||||
cap = min(unimp_count + imp_cap, max(0, info["total_msgs"] - 1))
|
||||
capacities.append(cap)
|
||||
total_capacity = sum(capacities)
|
||||
if global_delete > total_capacity:
|
||||
print(f"[剪枝-数据集] 不相关消息总数={total_unrelated},目标删除={global_delete},最大可删={total_capacity}(重要消息按比例保留)。将按最大可删执行。")
|
||||
global_delete = total_capacity
|
||||
|
||||
# 配额分配:按不相关消息占比分配到各对话,但不超过各自容量
|
||||
alloc = []
|
||||
for i, d in enumerate(not_related_dialogs):
|
||||
idx = d["index"]
|
||||
info = per_dialog_info[idx]
|
||||
share = int(global_delete * (info["unrelated_count"] / total_unrelated)) if total_unrelated > 0 else 0
|
||||
alloc.append(min(share, capacities[i]))
|
||||
allocated = sum(alloc)
|
||||
rem = global_delete - allocated
|
||||
turn = 0
|
||||
while rem > 0 and turn < 100000:
|
||||
progressed = False
|
||||
for i in range(len(not_related_dialogs)):
|
||||
if rem <= 0:
|
||||
break
|
||||
if alloc[i] < capacities[i]:
|
||||
alloc[i] += 1
|
||||
rem -= 1
|
||||
progressed = True
|
||||
if not progressed:
|
||||
break
|
||||
turn += 1
|
||||
|
||||
# 应用删除:相关对话不动;不相关按分配先删不重要,再删重要(低分优先)
|
||||
total_deleted_confirm = 0
|
||||
for d in evaluated_dialogs:
|
||||
dd = d["dialog"]
|
||||
msgs = dd.context.msgs
|
||||
original = len(msgs)
|
||||
if d["is_related"]:
|
||||
result.append(dd)
|
||||
continue
|
||||
idx_in_unrel = next((k for k, x in enumerate(not_related_dialogs) if x["index"] == d["index"]), None)
|
||||
if idx_in_unrel is None:
|
||||
result.append(dd)
|
||||
continue
|
||||
quota = alloc[idx_in_unrel]
|
||||
info = per_dialog_info[d["index"]]
|
||||
# 计算本对话重要最多可删数量
|
||||
imp_count = len(info["imp_ids_sorted"])
|
||||
imp_del_cap = int(imp_count * proportion)
|
||||
# 先构造顺序删除的"不重要ID集合"(按出现顺序前 quota 条)
|
||||
unimp_delete_ids = set(info["unimp_ids"][:min(quota, len(info["unimp_ids"]))])
|
||||
del_unimp = min(quota, len(unimp_delete_ids))
|
||||
rem_quota = quota - del_unimp
|
||||
# 再从重要里选低分优先的删除ID(不超过 imp_del_cap)
|
||||
imp_delete_ids = set(info["imp_ids_sorted"][:min(rem_quota, imp_del_cap)])
|
||||
deleted_here = 0
|
||||
actual_unimp_deleted = 0
|
||||
actual_imp_deleted = 0
|
||||
kept = []
|
||||
for m in msgs:
|
||||
mid = id(m)
|
||||
if mid in unimp_delete_ids and actual_unimp_deleted < del_unimp:
|
||||
actual_unimp_deleted += 1
|
||||
deleted_here += 1
|
||||
continue
|
||||
if mid in imp_delete_ids and actual_imp_deleted < len(imp_delete_ids):
|
||||
actual_imp_deleted += 1
|
||||
deleted_here += 1
|
||||
continue
|
||||
kept.append(m)
|
||||
if not kept and msgs:
|
||||
kept = [msgs[0]]
|
||||
dd.context.msgs = kept
|
||||
total_deleted_confirm += deleted_here
|
||||
self._log(
|
||||
f"[剪枝-对话] 对话 {d['index']+1} 总消息={original} 分配删除={quota} 实删={deleted_here} 保留={len(kept)}"
|
||||
)
|
||||
result.append(dd)
|
||||
self._log(f"[剪枝-数据集] 全局消息级顺序剪枝完成,总删除 {total_deleted_confirm} 条(不相关消息,重要按比例保留)。")
|
||||
else:
|
||||
# 全部相关:不执行剪枝
|
||||
result = [d["dialog"] for d in evaluated_dialogs]
|
||||
total_original_msgs = 0
|
||||
total_deleted_msgs = 0
|
||||
|
||||
for d_idx, dd in enumerate(dialogs):
|
||||
msgs = dd.context.msgs
|
||||
original_count = len(msgs)
|
||||
total_original_msgs += original_count
|
||||
|
||||
# ========== 问答对保护(已注释,暂不启用,留作观察) ==========
|
||||
# qa_pairs = self._identify_qa_pairs(msgs)
|
||||
# protected_indices = self._get_protected_indices(msgs, qa_pairs, window_size=0)
|
||||
# ========================================================
|
||||
|
||||
# 消息级分类:每条消息独立判断
|
||||
important_msgs = [] # 重要消息(保留)
|
||||
unimportant_msgs = [] # 不重要消息(可删除)
|
||||
filler_msgs = [] # 填充消息(优先删除)
|
||||
|
||||
# 判断是否需要详细日志(仅对前N条消息记录)
|
||||
should_log_details = self._detailed_prune_logging and original_count <= self._max_debug_msgs_per_dialog
|
||||
if self._detailed_prune_logging and original_count > self._max_debug_msgs_per_dialog:
|
||||
self._log(f" 对话[{d_idx}]消息数={original_count},仅采样前{self._max_debug_msgs_per_dialog}条进行详细日志")
|
||||
|
||||
for idx, m in enumerate(msgs):
|
||||
msg_text = m.msg.strip()
|
||||
|
||||
# ========== 问答对保护判断(已注释) ==========
|
||||
# if idx in protected_indices:
|
||||
# important_msgs.append((idx, m))
|
||||
# self._log(f" [{idx}] '{msg_text[:30]}...' → 重要(问答对保护)")
|
||||
# ==========================================
|
||||
|
||||
# 填充消息(寒暄、表情等)
|
||||
if self._is_filler_message(m):
|
||||
filler_msgs.append((idx, m))
|
||||
if should_log_details or idx < self._max_debug_msgs_per_dialog:
|
||||
self._log(f" [{idx}] '{msg_text[:30]}...' → 填充")
|
||||
# 重要信息(学号、成绩、时间、金额等)
|
||||
elif self._is_important_message(m):
|
||||
important_msgs.append((idx, m))
|
||||
if should_log_details or idx < self._max_debug_msgs_per_dialog:
|
||||
self._log(f" [{idx}] '{msg_text[:30]}...' → 重要(场景规则)")
|
||||
# 其他消息
|
||||
else:
|
||||
unimportant_msgs.append((idx, m))
|
||||
if should_log_details or idx < self._max_debug_msgs_per_dialog:
|
||||
self._log(f" [{idx}] '{msg_text[:30]}...' → 不重要")
|
||||
|
||||
# 计算删除配额
|
||||
delete_target = int(original_count * proportion)
|
||||
if proportion > 0 and original_count > 0 and delete_target == 0:
|
||||
delete_target = 1
|
||||
|
||||
# 确保至少保留1条消息
|
||||
max_deletable = max(0, original_count - 1)
|
||||
delete_target = min(delete_target, max_deletable)
|
||||
|
||||
# 删除策略:优先删除填充消息,再删除不重要消息
|
||||
to_delete_indices = set()
|
||||
deleted_details = [] # 记录删除的消息详情
|
||||
|
||||
# 第一步:删除填充消息
|
||||
filler_to_delete = min(len(filler_msgs), delete_target)
|
||||
for i in range(filler_to_delete):
|
||||
idx, msg = filler_msgs[i]
|
||||
to_delete_indices.add(idx)
|
||||
deleted_details.append(f"[{idx}] 填充: '{msg.msg[:50]}'")
|
||||
|
||||
# 第二步:如果还需要删除,删除不重要消息
|
||||
remaining_quota = delete_target - len(to_delete_indices)
|
||||
if remaining_quota > 0:
|
||||
unimp_to_delete = min(len(unimportant_msgs), remaining_quota)
|
||||
for i in range(unimp_to_delete):
|
||||
idx, msg = unimportant_msgs[i]
|
||||
to_delete_indices.add(idx)
|
||||
deleted_details.append(f"[{idx}] 不重要: '{msg.msg[:50]}'")
|
||||
|
||||
# 第三步:如果还需要删除,按重要性分数删除重要消息
|
||||
remaining_quota = delete_target - len(to_delete_indices)
|
||||
if remaining_quota > 0 and important_msgs:
|
||||
# 按重要性分数排序(分数低的优先删除)
|
||||
imp_sorted = sorted(important_msgs, key=lambda x: self._importance_score(x[1]))
|
||||
imp_to_delete = min(len(imp_sorted), remaining_quota)
|
||||
for i in range(imp_to_delete):
|
||||
idx, msg = imp_sorted[i]
|
||||
to_delete_indices.add(idx)
|
||||
score = self._importance_score(msg)
|
||||
deleted_details.append(f"[{idx}] 重要(分数{score}): '{msg.msg[:50]}'")
|
||||
|
||||
# 执行删除
|
||||
kept_msgs = []
|
||||
for idx, m in enumerate(msgs):
|
||||
if idx not in to_delete_indices:
|
||||
kept_msgs.append(m)
|
||||
|
||||
# 确保至少保留1条
|
||||
if not kept_msgs and msgs:
|
||||
kept_msgs = [msgs[0]]
|
||||
|
||||
dd.context.msgs = kept_msgs
|
||||
deleted_count = original_count - len(kept_msgs)
|
||||
total_deleted_msgs += deleted_count
|
||||
|
||||
# 输出删除详情
|
||||
if deleted_details:
|
||||
self._log(f"[剪枝-删除详情] 对话 {d_idx+1} 删除了以下消息:")
|
||||
for detail in deleted_details:
|
||||
self._log(f" {detail}")
|
||||
|
||||
# ========== 问答对统计(已注释) ==========
|
||||
# qa_info = f",问答对={len(qa_pairs)}" if qa_pairs else ""
|
||||
# ========================================
|
||||
|
||||
self._log(
|
||||
f"[剪枝-对话] 对话 {d_idx+1} 总消息={original_count} "
|
||||
f"(重要={len(important_msgs)} 不重要={len(unimportant_msgs)} 填充={len(filler_msgs)}) "
|
||||
f"删除={deleted_count} 保留={len(kept_msgs)}"
|
||||
)
|
||||
|
||||
result.append(dd)
|
||||
|
||||
self._log(f"[剪枝-数据集] 剩余对话数={len(result)}")
|
||||
|
||||
# 将本次剪枝阶段的终端输出保存为 JSON 文件(仅在剪枝器内部完成)
|
||||
# 保存日志
|
||||
try:
|
||||
from app.core.config import settings
|
||||
settings.ensure_memory_output_dir()
|
||||
log_output_path = settings.get_memory_output_path("pruned_terminal.json")
|
||||
# 去除日志前缀标签(如 [剪枝-数据集]、[剪枝-对话])后再解析为结构化字段保存
|
||||
sanitized_logs = [self._sanitize_log_line(l) for l in self.run_logs]
|
||||
payload = self._parse_logs_to_structured(sanitized_logs)
|
||||
with open(log_output_path, "w", encoding="utf-8") as f:
|
||||
@@ -448,6 +720,7 @@ class SemanticPruner:
|
||||
if not result:
|
||||
print("警告: 语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断")
|
||||
return dialogs
|
||||
|
||||
return result
|
||||
|
||||
def _log(self, msg: str) -> None:
|
||||
|
||||
@@ -0,0 +1,326 @@
|
||||
"""
|
||||
场景特定配置 - 为不同场景提供定制化的剪枝规则
|
||||
|
||||
功能:
|
||||
- 场景特定的重要信息识别模式
|
||||
- 场景特定的重要性评分权重
|
||||
- 场景特定的填充词库
|
||||
- 场景特定的问答对识别规则
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Set, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScenePatterns:
|
||||
"""场景特定的识别模式"""
|
||||
|
||||
# 重要信息的正则模式(优先级从高到低)
|
||||
high_priority_patterns: List[Tuple[str, int]] = field(default_factory=list) # (pattern, weight)
|
||||
medium_priority_patterns: List[Tuple[str, int]] = field(default_factory=list)
|
||||
low_priority_patterns: List[Tuple[str, int]] = field(default_factory=list)
|
||||
|
||||
# 填充词库(无意义对话)
|
||||
filler_phrases: Set[str] = field(default_factory=set)
|
||||
|
||||
# 问句关键词(用于识别问答对)
|
||||
question_keywords: Set[str] = field(default_factory=set)
|
||||
|
||||
# 决策性/承诺性关键词
|
||||
decision_keywords: Set[str] = field(default_factory=set)
|
||||
|
||||
|
||||
class SceneConfigRegistry:
|
||||
"""场景配置注册表 - 管理所有场景的特定配置"""
|
||||
|
||||
# 基础通用模式(所有场景共享)
|
||||
BASE_HIGH_PRIORITY = [
|
||||
(r"订单号|工单|申请号|编号|ID|账号|账户", 5),
|
||||
(r"金额|费用|价格|¥|¥|\d+元", 5),
|
||||
(r"\d{11}", 4), # 手机号
|
||||
(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", 4), # 邮箱
|
||||
]
|
||||
|
||||
BASE_MEDIUM_PRIORITY = [
|
||||
(r"\d{4}-\d{1,2}-\d{1,2}", 3), # 日期
|
||||
(r"\d{4}年\d{1,2}月\d{1,2}日", 3),
|
||||
(r"电话|手机号|微信|QQ|联系方式", 3),
|
||||
(r"地址|地点|位置", 2),
|
||||
(r"时间|日期|有效期|截止", 2),
|
||||
(r"今天|明天|后天|昨天|前天", 3), # 相对时间(提高权重)
|
||||
(r"下周|下月|下年|上周|上月|上年|本周|本月|本年", 3),
|
||||
(r"今年|去年|明年", 3),
|
||||
]
|
||||
|
||||
BASE_LOW_PRIORITY = [
|
||||
(r"\d{1,2}:\d{2}", 2), # 时间点 HH:MM
|
||||
(r"\d{1,2}点\d{0,2}分?", 2), # 时间点 X点Y分 或 X点
|
||||
(r"上午|下午|中午|晚上|早上|傍晚|凌晨", 2), # 时段(提高权重并扩充)
|
||||
(r"AM|PM|am|pm", 1),
|
||||
]
|
||||
|
||||
BASE_FILLERS = {
|
||||
# 基础寒暄
|
||||
"你好", "您好", "在吗", "在的", "在呢", "嗯", "嗯嗯", "哦", "哦哦",
|
||||
"好的", "好", "行", "可以", "不可以", "谢谢", "多谢", "感谢",
|
||||
"拜拜", "再见", "88", "拜", "回见",
|
||||
# 口头禅
|
||||
"哈哈", "呵呵", "哈哈哈", "嘿嘿", "嘻嘻", "hiahia",
|
||||
"额", "呃", "啊", "诶", "唉", "哎", "嗯哼",
|
||||
# 确认词
|
||||
"是的", "对", "对的", "没错", "嗯嗯", "好嘞", "收到", "明白", "了解", "知道了",
|
||||
# 标点和符号
|
||||
"。。。", "...", "???", "???", "!!!", "!!!",
|
||||
# 表情符号
|
||||
"[微笑]", "[呲牙]", "[发呆]", "[得意]", "[流泪]", "[害羞]", "[闭嘴]",
|
||||
"[睡]", "[大哭]", "[尴尬]", "[发怒]", "[调皮]", "[龇牙]", "[惊讶]",
|
||||
"[难过]", "[酷]", "[冷汗]", "[抓狂]", "[吐]", "[偷笑]", "[可爱]",
|
||||
"[白眼]", "[傲慢]", "[饥饿]", "[困]", "[惊恐]", "[流汗]", "[憨笑]",
|
||||
# 网络用语
|
||||
"hhh", "hhhh", "2333", "666", "gg", "ok", "OK", "okok",
|
||||
"emmm", "emm", "em", "mmp", "wtf", "omg",
|
||||
}
|
||||
|
||||
BASE_QUESTION_KEYWORDS = {
|
||||
"什么", "为什么", "怎么", "如何", "哪里", "哪个", "谁", "多少", "几点", "何时", "吗"
|
||||
}
|
||||
|
||||
BASE_DECISION_KEYWORDS = {
|
||||
"必须", "一定", "务必", "需要", "要求", "规定", "应该",
|
||||
"承诺", "保证", "确保", "负责", "同意", "答应"
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_education_config(cls) -> ScenePatterns:
|
||||
"""教育场景配置"""
|
||||
return ScenePatterns(
|
||||
high_priority_patterns=cls.BASE_HIGH_PRIORITY + [
|
||||
# 成绩相关(最高优先级)
|
||||
(r"成绩|分数|得分|满分|及格|不及格", 6),
|
||||
(r"GPA|绩点|学分|平均分", 6),
|
||||
(r"\d+分|\d+\.?\d*分", 5), # 具体分数
|
||||
(r"排名|名次|第.{1,3}名", 5), # 支持"第三名"、"第1名"等
|
||||
|
||||
# 学籍信息
|
||||
(r"学号|学生证|教师工号|工号", 5),
|
||||
(r"班级|年级|专业|院系", 4),
|
||||
|
||||
# 课程相关
|
||||
(r"课程|科目|学科|必修|选修", 4),
|
||||
(r"教材|课本|教科书|参考书", 4),
|
||||
(r"章节|第.{1,3}章|第.{1,3}节", 3), # 支持"第三章"、"第1章"等
|
||||
|
||||
# 学科内容(新增)
|
||||
(r"微积分|导数|积分|函数|极限|微分", 4),
|
||||
(r"代数|几何|三角|概率|统计", 4),
|
||||
(r"物理|化学|生物|历史|地理", 4),
|
||||
(r"英语|语文|数学|政治|哲学", 4),
|
||||
(r"定义|定理|公式|概念|原理|法则", 3),
|
||||
(r"例题|解题|证明|推导|计算", 3),
|
||||
],
|
||||
medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [
|
||||
# 教学活动
|
||||
(r"作业|练习|习题|题目", 3),
|
||||
(r"考试|测验|测试|考核|期中|期末", 3),
|
||||
(r"上课|下课|课堂|讲课", 2),
|
||||
(r"提问|回答|发言|讨论", 2),
|
||||
(r"问一下|请教|咨询|询问", 2), # 新增:问询相关
|
||||
(r"理解|明白|懂|掌握|学会", 2), # 新增:学习状态
|
||||
|
||||
# 时间安排
|
||||
(r"课表|课程表|时间表", 3),
|
||||
(r"第.{1,3}节课|第.{1,3}周", 2), # 支持"第三节课"、"第1周"等
|
||||
],
|
||||
low_priority_patterns=cls.BASE_LOW_PRIORITY + [
|
||||
(r"老师|教师|同学|学生", 1),
|
||||
(r"教室|实验室|图书馆", 1),
|
||||
],
|
||||
filler_phrases=cls.BASE_FILLERS | {
|
||||
# 教育场景特有填充词(移除了"明白了"、"懂了"、"不懂"等,这些在教育场景中有意义)
|
||||
"老师好", "同学们好", "上课", "下课", "起立", "坐下",
|
||||
"举手", "请坐", "很好", "不错", "继续",
|
||||
"下一个", "下一题", "下一位", "还有吗", "还有问题吗",
|
||||
},
|
||||
question_keywords=cls.BASE_QUESTION_KEYWORDS | {
|
||||
"为啥", "咋", "咋办", "怎样", "如何做",
|
||||
"能不能", "可不可以", "行不行", "对不对", "是不是",
|
||||
},
|
||||
decision_keywords=cls.BASE_DECISION_KEYWORDS | {
|
||||
"必考", "重点", "考点", "难点", "关键",
|
||||
"记住", "背诵", "掌握", "理解", "复习",
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_online_service_config(cls) -> ScenePatterns:
|
||||
"""在线服务场景配置"""
|
||||
return ScenePatterns(
|
||||
high_priority_patterns=cls.BASE_HIGH_PRIORITY + [
|
||||
# 工单相关(最高优先级)
|
||||
(r"工单号|工单编号|ticket|TK\d+", 6),
|
||||
(r"工单状态|处理中|已解决|已关闭|待处理", 5),
|
||||
(r"优先级|紧急|高优先级|P0|P1|P2", 5),
|
||||
|
||||
# 产品信息
|
||||
(r"产品型号|型号|SKU|产品编号", 5),
|
||||
(r"序列号|SN|设备号", 5),
|
||||
(r"版本号|软件版本|固件版本", 4),
|
||||
|
||||
# 问题描述
|
||||
(r"故障|错误|异常|bug|问题", 4),
|
||||
(r"错误代码|故障代码|error code", 5),
|
||||
(r"无法|不能|失败|报错", 3),
|
||||
],
|
||||
medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [
|
||||
# 服务相关
|
||||
(r"退款|退货|换货|补发", 4),
|
||||
(r"发票|收据|凭证", 3),
|
||||
(r"物流|快递|运单号", 3),
|
||||
(r"保修|质保|售后", 3),
|
||||
|
||||
# 时效相关
|
||||
(r"SLA|响应时间|处理时长", 4),
|
||||
(r"超时|延迟|等待", 2),
|
||||
],
|
||||
low_priority_patterns=cls.BASE_LOW_PRIORITY + [
|
||||
(r"客服|工程师|技术支持", 1),
|
||||
(r"用户|客户|会员", 1),
|
||||
],
|
||||
filler_phrases=cls.BASE_FILLERS | {
|
||||
# 在线服务特有填充词
|
||||
"您好", "请问", "请稍等", "稍等", "马上", "立即",
|
||||
"正在查询", "正在处理", "正在为您", "帮您查一下",
|
||||
"还有其他问题吗", "还需要什么帮助", "很高兴为您服务",
|
||||
"感谢您的耐心等待", "抱歉让您久等了",
|
||||
"已记录", "已反馈", "已转接", "已升级",
|
||||
"祝您生活愉快", "再见", "欢迎下次咨询",
|
||||
},
|
||||
question_keywords=cls.BASE_QUESTION_KEYWORDS | {
|
||||
"能否", "可否", "是否", "有没有", "能不能",
|
||||
"怎么办", "如何处理", "怎么解决",
|
||||
},
|
||||
decision_keywords=cls.BASE_DECISION_KEYWORDS | {
|
||||
"立即处理", "马上解决", "尽快", "优先",
|
||||
"升级", "转接", "派单", "跟进",
|
||||
"补偿", "赔偿", "退款", "换货",
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_outbound_config(cls) -> ScenePatterns:
|
||||
"""外呼场景配置"""
|
||||
return ScenePatterns(
|
||||
high_priority_patterns=cls.BASE_HIGH_PRIORITY + [
|
||||
# 意向相关(最高优先级)
|
||||
(r"意向|意愿|兴趣|感兴趣", 6),
|
||||
(r"A类|B类|C类|D类|高意向|低意向", 6),
|
||||
(r"成交|签约|下单|购买|确认", 6),
|
||||
|
||||
# 联系信息(外呼场景中更重要)
|
||||
(r"预约|约定|安排|确定时间", 5),
|
||||
(r"下次联系|回访|跟进", 5),
|
||||
(r"方便|有空|可以|时间", 4),
|
||||
|
||||
# 通话状态
|
||||
(r"接通|未接通|占线|关机|停机", 4),
|
||||
(r"通话时长|通话时间", 3),
|
||||
],
|
||||
medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [
|
||||
# 客户信息
|
||||
(r"姓名|称呼|先生|女士", 3),
|
||||
(r"公司|单位|职位|职务", 3),
|
||||
(r"需求|要求|期望", 3),
|
||||
|
||||
# 跟进状态
|
||||
(r"跟进状态|进展|进度", 3),
|
||||
(r"已联系|待联系|联系中", 2),
|
||||
(r"拒绝|不感兴趣|考虑|再说", 3),
|
||||
],
|
||||
low_priority_patterns=cls.BASE_LOW_PRIORITY + [
|
||||
(r"销售|客户经理|业务员", 1),
|
||||
(r"产品|服务|方案", 1),
|
||||
],
|
||||
filler_phrases=cls.BASE_FILLERS | {
|
||||
# 外呼场景特有填充词
|
||||
"您好", "喂", "hello", "打扰了", "不好意思",
|
||||
"方便接电话吗", "现在方便吗", "占用您一点时间",
|
||||
"我是", "我们是", "我们公司", "我们这边",
|
||||
"了解一下", "介绍一下", "简单说一下",
|
||||
"考虑考虑", "想一想", "再说", "再看看",
|
||||
"不需要", "不感兴趣", "没兴趣", "不用了",
|
||||
"好的", "行", "可以", "没问题", "那就这样",
|
||||
"再联系", "回头聊", "有需要再说",
|
||||
},
|
||||
question_keywords=cls.BASE_QUESTION_KEYWORDS | {
|
||||
"有没有", "需不需要", "要不要", "考虑不考虑",
|
||||
"了解吗", "知道吗", "听说过吗",
|
||||
"方便吗", "有空吗", "在吗",
|
||||
},
|
||||
decision_keywords=cls.BASE_DECISION_KEYWORDS | {
|
||||
"确定", "决定", "选择", "购买", "下单",
|
||||
"预约", "安排", "约定", "确认",
|
||||
"跟进", "回访", "联系", "沟通",
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls, scene: str, fallback_to_generic: bool = True) -> ScenePatterns:
|
||||
"""根据场景名称获取配置
|
||||
|
||||
Args:
|
||||
scene: 场景名称 ('education', 'online_service', 'outbound' 或其他)
|
||||
fallback_to_generic: 如果场景不存在,是否降级到通用配置
|
||||
|
||||
Returns:
|
||||
对应场景的配置,如果场景不存在:
|
||||
- fallback_to_generic=True: 返回通用配置(仅基础规则)
|
||||
- fallback_to_generic=False: 抛出异常
|
||||
"""
|
||||
scene_map = {
|
||||
'education': cls.get_education_config,
|
||||
'online_service': cls.get_online_service_config,
|
||||
'outbound': cls.get_outbound_config,
|
||||
}
|
||||
|
||||
if scene in scene_map:
|
||||
return scene_map[scene]()
|
||||
|
||||
if fallback_to_generic:
|
||||
# 返回通用配置(仅包含基础规则,不包含场景特定规则)
|
||||
return cls.get_generic_config()
|
||||
else:
|
||||
raise ValueError(f"不支持的场景: {scene},支持的场景: {list(scene_map.keys())}")
|
||||
|
||||
@classmethod
|
||||
def get_generic_config(cls) -> ScenePatterns:
|
||||
"""通用场景配置 - 仅包含基础规则,适用于未定义的场景
|
||||
|
||||
这是一个保守的配置,只使用最通用的规则,避免误删重要信息
|
||||
"""
|
||||
return ScenePatterns(
|
||||
high_priority_patterns=cls.BASE_HIGH_PRIORITY,
|
||||
medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY,
|
||||
low_priority_patterns=cls.BASE_LOW_PRIORITY,
|
||||
filler_phrases=cls.BASE_FILLERS,
|
||||
question_keywords=cls.BASE_QUESTION_KEYWORDS,
|
||||
decision_keywords=cls.BASE_DECISION_KEYWORDS
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_all_scenes(cls) -> List[str]:
|
||||
"""获取所有预定义场景的列表"""
|
||||
return ['education', 'online_service', 'outbound']
|
||||
|
||||
@classmethod
|
||||
def is_scene_supported(cls, scene: str) -> bool:
|
||||
"""检查场景是否有专门的配置支持
|
||||
|
||||
Args:
|
||||
scene: 场景名称
|
||||
|
||||
Returns:
|
||||
True: 有专门配置
|
||||
False: 将使用通用配置
|
||||
"""
|
||||
return scene in cls.get_all_scenes()
|
||||
@@ -1932,17 +1932,17 @@ def preprocess_data(
|
||||
Returns:
|
||||
经过清洗转换后的 DialogData 列表
|
||||
"""
|
||||
print("\n=== 数据预处理 ===")
|
||||
logger.debug("=== 数据预处理 ===")
|
||||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_preprocessor import (
|
||||
DataPreprocessor,
|
||||
)
|
||||
preprocessor = DataPreprocessor()
|
||||
try:
|
||||
cleaned_data = preprocessor.preprocess(input_path=input_path, output_path=output_path, skip_cleaning=skip_cleaning, indices=indices)
|
||||
print(f"数据预处理完成!共处理了 {len(cleaned_data)} 条对话数据")
|
||||
logger.debug(f"数据预处理完成!共处理了 {len(cleaned_data)} 条对话数据")
|
||||
return cleaned_data
|
||||
except Exception as e:
|
||||
print(f"数据预处理过程中出现错误: {e}")
|
||||
logger.error(f"数据预处理过程中出现错误: {e}")
|
||||
raise
|
||||
|
||||
|
||||
@@ -1961,7 +1961,7 @@ async def get_chunked_dialogs_from_preprocessed(
|
||||
Returns:
|
||||
带 chunks 的 DialogData 列表
|
||||
"""
|
||||
print(f"\n=== 批量对话分块处理 (使用 {chunker_strategy}) ===")
|
||||
logger.debug(f"=== 批量对话分块处理 (使用 {chunker_strategy}) ===")
|
||||
if not data:
|
||||
raise ValueError("预处理数据为空,无法进行分块")
|
||||
|
||||
@@ -1988,6 +1988,7 @@ async def get_chunked_dialogs_with_preprocessing(
|
||||
input_data_path: Optional[str] = None,
|
||||
llm_client: Optional[Any] = None,
|
||||
skip_cleaning: bool = True,
|
||||
pruning_config: Optional[Dict] = None,
|
||||
) -> List[DialogData]:
|
||||
"""包含数据预处理步骤的完整分块流程
|
||||
|
||||
@@ -2000,11 +2001,12 @@ async def get_chunked_dialogs_with_preprocessing(
|
||||
input_data_path: 输入数据路径
|
||||
llm_client: LLM 客户端
|
||||
skip_cleaning: 是否跳过数据清洗步骤(默认False)
|
||||
pruning_config: 剪枝配置字典,包含 pruning_switch, pruning_scene, pruning_threshold
|
||||
|
||||
Returns:
|
||||
带 chunks 的 DialogData 列表
|
||||
"""
|
||||
print("\n=== 完整数据处理流程(包含预处理)===")
|
||||
logger.debug("=== 完整数据处理流程(包含预处理)===")
|
||||
|
||||
if input_data_path is None:
|
||||
input_data_path = os.path.join(
|
||||
@@ -2030,7 +2032,19 @@ async def get_chunked_dialogs_with_preprocessing(
|
||||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import (
|
||||
SemanticPruner,
|
||||
)
|
||||
pruner = SemanticPruner(llm_client=llm_client)
|
||||
from app.core.memory.models.config_models import PruningConfig
|
||||
|
||||
# 构建剪枝配置
|
||||
if pruning_config:
|
||||
# 使用传入的配置
|
||||
config = PruningConfig(**pruning_config)
|
||||
logger.debug(f"[剪枝] 使用传入配置: switch={config.pruning_switch}, scene={config.pruning_scene}, threshold={config.pruning_threshold}")
|
||||
else:
|
||||
# 使用默认配置(关闭剪枝)
|
||||
config = None
|
||||
logger.debug("[剪枝] 未提供配置,使用默认配置(剪枝关闭)")
|
||||
|
||||
pruner = SemanticPruner(config=config, llm_client=llm_client)
|
||||
|
||||
# 记录单对话场景下剪枝前的消息数量
|
||||
single_dialog_original_msgs = None
|
||||
@@ -2043,12 +2057,12 @@ async def get_chunked_dialogs_with_preprocessing(
|
||||
if len(preprocessed_data) == 1 and single_dialog_original_msgs is not None:
|
||||
remaining_msgs = len(preprocessed_data[0].context.msgs) if preprocessed_data[0].context else 0
|
||||
deleted_msgs = max(0, single_dialog_original_msgs - remaining_msgs)
|
||||
print(
|
||||
logger.debug(
|
||||
f"语义剪枝完成!剩余 1 条对话!原始消息数:{single_dialog_original_msgs},"
|
||||
f"保留消息数:{remaining_msgs},删除 {deleted_msgs} 条。"
|
||||
)
|
||||
else:
|
||||
print(f"语义剪枝完成!剩余 {len(preprocessed_data)} 条对话")
|
||||
logger.debug(f"语义剪枝完成!剩余 {len(preprocessed_data)} 条对话")
|
||||
|
||||
# 保存剪枝后的数据
|
||||
try:
|
||||
@@ -2059,9 +2073,9 @@ async def get_chunked_dialogs_with_preprocessing(
|
||||
dp = DataPreprocessor(output_file_path=pruned_output_path)
|
||||
dp.save_data(preprocessed_data, output_path=pruned_output_path)
|
||||
except Exception as se:
|
||||
print(f"保存剪枝结果失败:{se}")
|
||||
logger.error(f"保存剪枝结果失败:{se}")
|
||||
except Exception as e:
|
||||
print(f"语义剪枝过程中出现错误,跳过剪枝: {e}")
|
||||
logger.error(f"语义剪枝过程中出现错误,跳过剪枝: {e}")
|
||||
|
||||
# 步骤3: 对话分块
|
||||
return await get_chunked_dialogs_from_preprocessed(
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
from typing import Optional, List, Any
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
|
||||
from app.core.logging_config import get_memory_logger
|
||||
from app.core.memory.models.message_models import DialogData, Chunk
|
||||
@@ -10,6 +12,20 @@ from app.core.memory.utils.config.config_utils import get_chunker_config
|
||||
logger = get_memory_logger(__name__)
|
||||
|
||||
|
||||
class ChunkerStrategy(Enum):
|
||||
"""Supported chunking strategies."""
|
||||
RECURSIVE = "RecursiveChunker"
|
||||
SEMANTIC = "SemanticChunker"
|
||||
LATE = "LateChunker"
|
||||
NEURAL = "NeuralChunker"
|
||||
LLM = "LLMChunker"
|
||||
|
||||
@classmethod
|
||||
def get_valid_strategies(cls) -> List[str]:
|
||||
"""Get list of valid strategy names."""
|
||||
return [strategy.value for strategy in cls]
|
||||
|
||||
|
||||
class DialogueChunker:
|
||||
"""A class that processes dialogues and fills them with chunks based on a specified strategy.
|
||||
|
||||
@@ -17,23 +33,51 @@ class DialogueChunker:
|
||||
of different chunking strategies to dialogue data.
|
||||
"""
|
||||
|
||||
def __init__(self, chunker_strategy: str = "RecursiveChunker", llm_client=None):
|
||||
def __init__(self, chunker_strategy: str = "RecursiveChunker", llm_client: Optional[Any] = None):
|
||||
"""Initialize the DialogueChunker with a specific chunking strategy.
|
||||
|
||||
Args:
|
||||
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
|
||||
Options: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker
|
||||
Options: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker, LLMChunker
|
||||
llm_client: LLM client instance (required for LLMChunker strategy)
|
||||
|
||||
Raises:
|
||||
ValueError: If chunker_strategy is invalid or required parameters are missing
|
||||
"""
|
||||
self.chunker_strategy = chunker_strategy
|
||||
chunker_config_dict = get_chunker_config(chunker_strategy)
|
||||
self.chunker_config = ChunkerConfig.model_validate(chunker_config_dict)
|
||||
# Validate strategy
|
||||
valid_strategies = ChunkerStrategy.get_valid_strategies()
|
||||
if chunker_strategy not in valid_strategies:
|
||||
raise ValueError(
|
||||
f"Invalid chunker_strategy: '{chunker_strategy}'. "
|
||||
f"Must be one of {valid_strategies}"
|
||||
)
|
||||
|
||||
if self.chunker_config.chunker_strategy == "LLMChunker":
|
||||
self.chunker_client = ChunkerClient(self.chunker_config, llm_client)
|
||||
else:
|
||||
self.chunker_client = ChunkerClient(self.chunker_config)
|
||||
self.chunker_strategy = chunker_strategy
|
||||
logger.info(f"Initializing DialogueChunker with strategy: {chunker_strategy}")
|
||||
|
||||
try:
|
||||
# Load and validate configuration
|
||||
chunker_config_dict = get_chunker_config(chunker_strategy)
|
||||
if not chunker_config_dict:
|
||||
raise ValueError(f"Failed to load configuration for strategy: {chunker_strategy}")
|
||||
|
||||
self.chunker_config = ChunkerConfig.model_validate(chunker_config_dict)
|
||||
|
||||
# Initialize chunker client
|
||||
if self.chunker_config.chunker_strategy == "LLMChunker":
|
||||
if not llm_client:
|
||||
raise ValueError("llm_client is required for LLMChunker strategy")
|
||||
self.chunker_client = ChunkerClient(self.chunker_config, llm_client)
|
||||
else:
|
||||
self.chunker_client = ChunkerClient(self.chunker_config)
|
||||
|
||||
logger.info(f"DialogueChunker initialized successfully with strategy: {chunker_strategy}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize DialogueChunker: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def process_dialogue(self, dialogue: DialogData) -> list[Chunk]:
|
||||
async def process_dialogue(self, dialogue: DialogData) -> List[Chunk]:
|
||||
"""Process a dialogue by generating chunks and adding them to the DialogData object.
|
||||
|
||||
Args:
|
||||
@@ -43,54 +87,125 @@ class DialogueChunker:
|
||||
A list of Chunk objects
|
||||
|
||||
Raises:
|
||||
ValueError: If chunking fails or returns empty chunks
|
||||
ValueError: If dialogue is invalid or chunking fails
|
||||
Exception: If chunking process encounters an error
|
||||
"""
|
||||
result_dialogue = await self.chunker_client.generate_chunks(dialogue)
|
||||
chunks = result_dialogue.chunks
|
||||
|
||||
if not chunks or len(chunks) == 0:
|
||||
# Validate input
|
||||
if not dialogue:
|
||||
raise ValueError("dialogue cannot be None")
|
||||
|
||||
if not dialogue.context or not dialogue.context.msgs:
|
||||
raise ValueError(
|
||||
f"Chunking failed: No chunks generated for dialogue {dialogue.ref_id}. "
|
||||
f"Messages: {len(dialogue.context.msgs) if dialogue.context else 0}, "
|
||||
f"Strategy: {self.chunker_config.chunker_strategy}"
|
||||
f"Dialogue {dialogue.ref_id} has no messages to chunk. "
|
||||
f"Context: {dialogue.context is not None}, "
|
||||
f"Messages: {len(dialogue.context.msgs) if dialogue.context else 0}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Processing dialogue {dialogue.ref_id} with {len(dialogue.context.msgs)} messages "
|
||||
f"using strategy: {self.chunker_strategy}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Generate chunks
|
||||
result_dialogue = await self.chunker_client.generate_chunks(dialogue)
|
||||
chunks = result_dialogue.chunks
|
||||
|
||||
return chunks
|
||||
# Validate results
|
||||
if not chunks or len(chunks) == 0:
|
||||
raise ValueError(
|
||||
f"Chunking failed: No chunks generated for dialogue {dialogue.ref_id}. "
|
||||
f"Messages: {len(dialogue.context.msgs)}, "
|
||||
f"Content length: {len(dialogue.content) if dialogue.content else 0}, "
|
||||
f"Strategy: {self.chunker_config.chunker_strategy}"
|
||||
)
|
||||
|
||||
def save_chunking_results(self, dialogue: DialogData, output_path: Optional[str] = None) -> str:
|
||||
logger.info(
|
||||
f"Successfully generated {len(chunks)} chunks for dialogue {dialogue.ref_id}. "
|
||||
f"Total characters processed: {len(dialogue.content) if dialogue.content else 0}"
|
||||
)
|
||||
|
||||
return chunks
|
||||
|
||||
except ValueError:
|
||||
# Re-raise validation errors
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error processing dialogue {dialogue.ref_id} with strategy {self.chunker_strategy}: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
def save_chunking_results(
|
||||
self,
|
||||
chunks: List[Chunk],
|
||||
dialogue: DialogData,
|
||||
output_path: Optional[str] = None,
|
||||
preview_length: int = 100
|
||||
) -> str:
|
||||
"""Save the chunking results to a file and return the output path.
|
||||
|
||||
Args:
|
||||
dialogue: The processed DialogData object with chunks
|
||||
output_path: Optional path to save the output
|
||||
chunks: List of Chunk objects to save
|
||||
dialogue: The DialogData object that was processed
|
||||
output_path: Optional path to save the output (defaults to current directory)
|
||||
preview_length: Maximum length of content preview (default: 100)
|
||||
|
||||
Returns:
|
||||
The path where the output was saved
|
||||
|
||||
Raises:
|
||||
ValueError: If chunks or dialogue is invalid
|
||||
IOError: If file writing fails
|
||||
"""
|
||||
if not output_path:
|
||||
output_path = os.path.join(
|
||||
os.path.dirname(__file__), "..", "..",
|
||||
f"chunker_output_{self.chunker_strategy.lower()}.txt"
|
||||
)
|
||||
|
||||
output_lines = [
|
||||
f"=== Chunking Results ({self.chunker_strategy}) ===",
|
||||
f"Dialogue ID: {dialogue.ref_id}",
|
||||
f"Original conversation has {len(dialogue.context.msgs)} messages",
|
||||
f"Total characters: {len(dialogue.content)}",
|
||||
f"Generated {len(dialogue.chunks)} chunks:"
|
||||
]
|
||||
# Validate input
|
||||
if not chunks:
|
||||
raise ValueError("chunks list cannot be empty")
|
||||
if not dialogue:
|
||||
raise ValueError("dialogue cannot be None")
|
||||
|
||||
for i, chunk in enumerate(dialogue.chunks):
|
||||
output_lines.append(f" Chunk {i+1}: {len(chunk.content)} characters")
|
||||
output_lines.append(f" Content preview: {chunk.content}...")
|
||||
if chunk.metadata:
|
||||
output_lines.append(f" Metadata: {chunk.metadata}")
|
||||
# Generate default output path if not provided
|
||||
if not output_path:
|
||||
output_dir = Path(__file__).parent.parent.parent
|
||||
output_path = str(output_dir / f"chunker_output_{self.chunker_strategy.lower()}.txt")
|
||||
|
||||
logger.info(f"Saving chunking results to: {output_path}")
|
||||
|
||||
try:
|
||||
# Prepare output content
|
||||
output_lines = [
|
||||
f"=== Chunking Results ({self.chunker_strategy}) ===",
|
||||
f"Dialogue ID: {dialogue.ref_id}",
|
||||
f"Original conversation has {len(dialogue.context.msgs) if dialogue.context else 0} messages",
|
||||
f"Total characters: {len(dialogue.content) if dialogue.content else 0}",
|
||||
f"Generated {len(chunks)} chunks:",
|
||||
""
|
||||
]
|
||||
|
||||
for i, chunk in enumerate(chunks, 1):
|
||||
content_preview = chunk.content[:preview_length] if chunk.content else ""
|
||||
if len(chunk.content) > preview_length:
|
||||
content_preview += "..."
|
||||
|
||||
output_lines.append(f" Chunk {i}: {len(chunk.content)} characters")
|
||||
output_lines.append(f" Content preview: {content_preview}")
|
||||
if chunk.metadata:
|
||||
output_lines.append(f" Metadata: {chunk.metadata}")
|
||||
output_lines.append("")
|
||||
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(output_lines))
|
||||
# Write to file
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(output_lines))
|
||||
|
||||
logger.info(f"Chunking results saved to: {output_path}")
|
||||
return output_path
|
||||
logger.info(f"Successfully saved chunking results to: {output_path}")
|
||||
return output_path
|
||||
|
||||
except IOError as e:
|
||||
logger.error(f"Failed to write chunking results to {output_path}: {e}", exc_info=True)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error saving chunking results: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
|
||||
@@ -400,7 +400,8 @@ async def render_user_summary_prompt(
|
||||
user_id: str,
|
||||
entities: str,
|
||||
statements: str,
|
||||
language: str = "zh"
|
||||
language: str = "zh",
|
||||
user_display_name: str = None
|
||||
) -> str:
|
||||
"""
|
||||
Renders the user summary prompt using the user_summary.jinja2 template.
|
||||
@@ -410,16 +411,22 @@ async def render_user_summary_prompt(
|
||||
entities: Core entities with frequency information
|
||||
statements: Representative statement samples
|
||||
language: The language to use for summary generation ("zh" for Chinese, "en" for English)
|
||||
user_display_name: Display name for the user (e.g., other_name or "该用户"/"the user")
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
"""
|
||||
# 如果没有提供 user_display_name,使用默认值
|
||||
if user_display_name is None:
|
||||
user_display_name = "该用户" if language == "zh" else "the user"
|
||||
|
||||
template = prompt_env.get_template("user_summary.jinja2")
|
||||
rendered_prompt = template.render(
|
||||
user_id=user_id,
|
||||
entities=entities,
|
||||
statements=statements,
|
||||
language=language
|
||||
language=language,
|
||||
user_display_name=user_display_name
|
||||
)
|
||||
|
||||
# 记录渲染结果到提示日志
|
||||
@@ -429,7 +436,8 @@ async def render_user_summary_prompt(
|
||||
'user_id': user_id,
|
||||
'entities_len': len(entities),
|
||||
'statements_len': len(statements),
|
||||
'language': language
|
||||
'language': language,
|
||||
'user_display_name': user_display_name
|
||||
})
|
||||
|
||||
return rendered_prompt
|
||||
|
||||
@@ -14,8 +14,8 @@ Your task is to generate a comprehensive user profile based on the provided enti
|
||||
{% endif %}
|
||||
|
||||
===Inputs===
|
||||
{% if user_id %}
|
||||
- User ID: {{ user_id }}
|
||||
{% if user_display_name %}
|
||||
- User Display Name: {{ user_display_name }}
|
||||
{% endif %}
|
||||
{% if entities %}
|
||||
- Core Entities & Frequency: {{ entities }}
|
||||
@@ -33,6 +33,20 @@ Your task is to generate a comprehensive user profile based on the provided enti
|
||||
3. Avoid excessive adjectives and empty phrases
|
||||
4. Strictly follow the output format specified below
|
||||
|
||||
{% if language == "zh" %}
|
||||
**【严格人称规定】**
|
||||
- 在描述用户时,必须使用"{{ user_display_name }}"作为人称
|
||||
- 绝对禁止使用用户ID(如 {{ user_id }})来称呼用户
|
||||
- 绝对禁止在摘要中出现任何形式的UUID或ID字符串
|
||||
- 如果需要指代用户,只能使用"{{ user_display_name }}"或相应的代词(他/她/TA)
|
||||
{% else %}
|
||||
**【STRICT PRONOUN RULES】**
|
||||
- When describing the user, you MUST use "{{ user_display_name }}" as the reference
|
||||
- It is ABSOLUTELY FORBIDDEN to use the user ID (such as {{ user_id }}) to refer to the user
|
||||
- It is ABSOLUTELY FORBIDDEN to include any form of UUID or ID string in the summary
|
||||
- If you need to refer to the user, you can ONLY use "{{ user_display_name }}" or appropriate pronouns (he/she/they)
|
||||
{% endif %}
|
||||
|
||||
**Section-Specific Requirements:**
|
||||
|
||||
{% if language == "zh" %}
|
||||
@@ -103,13 +117,13 @@ Your task is to generate a comprehensive user profile based on the provided enti
|
||||
|
||||
{% if language == "zh" %}
|
||||
Example Input:
|
||||
- User ID: user_12345
|
||||
- User Display Name: 张三
|
||||
- Core Entities & Frequency: 产品经理 (15), AI (12), 深圳 (10), 数据分析 (8), 团队协作 (7)
|
||||
- Representative Statement Samples: 我在深圳从事产品经理工作已经5年了 | 我相信好的产品源于对用户需求的深刻理解 | 我喜欢在团队中起到协调作用 | 数据驱动决策是我的工作原则
|
||||
|
||||
Example Output:
|
||||
【基本介绍】
|
||||
我是张三,一名充满热情的高级产品经理。在过去的5年里,我专注于AI和数据驱动的产品设计,致力于创造能够真正改善用户生活的产品。我相信好的产品源于对用户需求的深刻理解和对技术可能性的不断探索。
|
||||
张三是一名充满热情的高级产品经理,在深圳工作。在过去的5年里,张三专注于AI和数据驱动的产品设计,致力于创造能够真正改善用户生活的产品。张三相信好的产品源于对用户需求的深刻理解和对技术可能性的不断探索。
|
||||
|
||||
【性格特点】
|
||||
性格开朗,善于沟通,注重细节。喜欢在团队中起到协调作用,帮助大家达成共识。面对挑战时保持乐观,相信每个问题都有解决方案。
|
||||
@@ -121,13 +135,13 @@ Example Output:
|
||||
"让每一个产品决策都充满温度。"
|
||||
{% else %}
|
||||
Example Input:
|
||||
- User ID: user_12345
|
||||
- User Display Name: John
|
||||
- Core Entities & Frequency: Product Manager (15), AI (12), San Francisco (10), Data Analysis (8), Team Collaboration (7)
|
||||
- Representative Statement Samples: I have been working as a product manager in San Francisco for 5 years | I believe good products come from deep understanding of user needs | I enjoy playing a coordinating role in teams | Data-driven decision making is my work principle
|
||||
|
||||
Example Output:
|
||||
【Basic Introduction】
|
||||
This is a passionate senior product manager based in San Francisco. Over the past 5 years, they have focused on AI and data-driven product design, dedicated to creating products that truly improve users' lives. They believe good products stem from deep understanding of user needs and continuous exploration of technological possibilities.
|
||||
John is a passionate senior product manager based in San Francisco. Over the past 5 years, John has focused on AI and data-driven product design, dedicated to creating products that truly improve users' lives. John believes good products stem from deep understanding of user needs and continuous exploration of technological possibilities.
|
||||
|
||||
【Personality Traits】
|
||||
Outgoing personality with excellent communication skills and attention to detail. Enjoys playing a coordinating role in teams, helping everyone reach consensus. Maintains optimism when facing challenges, believing every problem has a solution.
|
||||
|
||||
@@ -816,11 +816,10 @@ class MemoryAgentService:
|
||||
"""
|
||||
统计知识库类型分布,包含:
|
||||
1. PostgreSQL 中的知识库类型:General, Web, Third-party, Folder(根据 workspace_id 过滤)
|
||||
2. Neo4j 中的 Memory 类型(仅统计 Chunk 数量,根据 end_user_id/end_user_id 过滤)
|
||||
3. total: 所有类型的总和
|
||||
2. total: 所有类型的总和
|
||||
|
||||
参数:
|
||||
- end_user_id: 用户组ID(可选,未提供时 Memory 统计为 0)
|
||||
- end_user_id: 用户组ID(可选,保留参数以保持接口兼容性)
|
||||
- only_active: 是否仅统计有效记录
|
||||
- current_workspace_id: 当前工作空间ID(可选,未提供时知识库统计为 0)
|
||||
- db: 数据库会话
|
||||
@@ -831,7 +830,6 @@ class MemoryAgentService:
|
||||
"Web": count,
|
||||
"Third-party": count,
|
||||
"Folder": count,
|
||||
"Memory": chunk_count,
|
||||
"total": sum_of_all
|
||||
}
|
||||
"""
|
||||
@@ -878,51 +876,8 @@ class MemoryAgentService:
|
||||
logger.error(f"知识库类型统计失败: {e}")
|
||||
raise Exception(f"知识库类型统计失败: {e}")
|
||||
|
||||
# 2. 统计 Neo4j 中的 memory 总量(统计当前空间下所有宿主的 Chunk 总数)
|
||||
try:
|
||||
if current_workspace_id:
|
||||
# 获取当前空间下的所有宿主
|
||||
from app.repositories import app_repository, end_user_repository
|
||||
from app.schemas.app_schema import App as AppSchema
|
||||
from app.schemas.end_user_schema import EndUser as EndUserSchema
|
||||
|
||||
# 查询应用并转换为 Pydantic 模型
|
||||
apps_orm = app_repository.get_apps_by_workspace_id(db, current_workspace_id)
|
||||
apps = [AppSchema.model_validate(h) for h in apps_orm]
|
||||
app_ids = [app.id for app in apps]
|
||||
|
||||
# 获取所有宿主
|
||||
end_users = []
|
||||
for app_id in app_ids:
|
||||
end_user_orm_list = end_user_repository.get_end_users_by_app_id(db, app_id)
|
||||
end_users.extend(h for h in end_user_orm_list)
|
||||
|
||||
# 统计所有宿主的 Chunk 总数
|
||||
total_chunks = 0
|
||||
for end_user in end_users:
|
||||
end_user_id_str = str(end_user.id)
|
||||
memory_query = """
|
||||
MATCH (n:Chunk) WHERE n.end_user_id = $end_user_id RETURN count(n) AS Count
|
||||
"""
|
||||
neo4j_result = await _neo4j_connector.execute_query(
|
||||
memory_query,
|
||||
end_user_id=end_user_id_str,
|
||||
)
|
||||
chunk_count = neo4j_result[0]["Count"] if neo4j_result else 0
|
||||
total_chunks += chunk_count
|
||||
logger.debug(f"EndUser {end_user_id_str} Chunk数量: {chunk_count}")
|
||||
|
||||
result["Memory"] = total_chunks
|
||||
logger.info(f"Neo4j memory统计成功: 总Chunk数={total_chunks}, 宿主数={len(end_users)}")
|
||||
else:
|
||||
# 没有 workspace_id 时,返回 0
|
||||
result["Memory"] = 0
|
||||
logger.info("未提供 workspace_id,memory 统计为 0")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Neo4j memory统计失败: {e}", exc_info=True)
|
||||
# 如果 Neo4j 查询失败,memory 设为 0
|
||||
result["Memory"] = 0
|
||||
# 2. 统计 Neo4j 中的 memory 总量已移除
|
||||
# memory 字段不再返回
|
||||
|
||||
# 3. 计算知识库类型总和(不包括 memory)
|
||||
result["total"] = (
|
||||
|
||||
@@ -101,34 +101,141 @@ async def run_pilot_extraction(
|
||||
)
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("text_preprocessing", "开始预处理文本...")
|
||||
await progress_callback("text_preprocessing", "开始预处理文本(语义剪枝 + 语义分块)...")
|
||||
|
||||
# ========== 步骤 2.1: 语义剪枝 ==========
|
||||
pruned_dialogs = [dialog]
|
||||
deleted_messages = [] # 记录被删除的消息
|
||||
pruning_stats = None # 保存剪枝统计信息,用于最终汇总
|
||||
|
||||
if memory_config.pruning_enabled:
|
||||
try:
|
||||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import (
|
||||
SemanticPruner,
|
||||
)
|
||||
from app.core.memory.models.config_models import PruningConfig
|
||||
|
||||
# 构建剪枝配置
|
||||
pruning_config_dict = {
|
||||
"pruning_switch": memory_config.pruning_enabled,
|
||||
"pruning_scene": memory_config.pruning_scene,
|
||||
"pruning_threshold": memory_config.pruning_threshold,
|
||||
"llm_model_id": str(memory_config.llm_model_id),
|
||||
}
|
||||
config = PruningConfig(**pruning_config_dict)
|
||||
|
||||
logger.info(f"[PILOT_RUN] 开始语义剪枝: scene={config.pruning_scene}, threshold={config.pruning_threshold}")
|
||||
|
||||
# 记录剪枝前的消息(用于对比)
|
||||
original_messages = [{"role": msg.role, "content": msg.msg} for msg in dialog.context.msgs]
|
||||
original_msg_count = len(original_messages)
|
||||
|
||||
# 执行剪枝
|
||||
pruner = SemanticPruner(config=config, llm_client=llm_client)
|
||||
pruned_dialogs = await pruner.prune_dataset([dialog])
|
||||
|
||||
# 计算剪枝结果并找出被删除的消息
|
||||
if pruned_dialogs and pruned_dialogs[0].context:
|
||||
remaining_messages = [{"role": msg.role, "content": msg.msg} for msg in pruned_dialogs[0].context.msgs]
|
||||
remaining_msg_count = len(remaining_messages)
|
||||
deleted_msg_count = original_msg_count - remaining_msg_count
|
||||
|
||||
# 找出被删除的消息(基于索引精确匹配)
|
||||
# 为剩余消息创建带索引的列表,用于精确追踪
|
||||
remaining_with_index = []
|
||||
remaining_idx = 0
|
||||
for orig_idx, orig_msg in enumerate(original_messages):
|
||||
if remaining_idx < len(remaining_messages) and \
|
||||
orig_msg["role"] == remaining_messages[remaining_idx]["role"] and \
|
||||
orig_msg["content"] == remaining_messages[remaining_idx]["content"]:
|
||||
remaining_with_index.append(orig_idx)
|
||||
remaining_idx += 1
|
||||
|
||||
# 找出未在保留列表中的消息索引
|
||||
deleted_messages = [
|
||||
{"index": idx, "role": msg["role"], "content": msg["content"]}
|
||||
for idx, msg in enumerate(original_messages)
|
||||
if idx not in remaining_with_index
|
||||
]
|
||||
|
||||
# 保存剪枝统计信息(用于最终汇总,只保留deleted_count)
|
||||
pruning_stats = {
|
||||
"enabled": True,
|
||||
"scene": config.pruning_scene,
|
||||
"threshold": config.pruning_threshold,
|
||||
"deleted_count": deleted_msg_count,
|
||||
}
|
||||
|
||||
# 输出剪枝结果(显示删除的消息详情)
|
||||
pruning_result = {
|
||||
"type": "pruning",
|
||||
"deleted_messages": deleted_messages,
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"[PILOT_RUN] 语义剪枝完成: 原始{original_msg_count}条 -> "
|
||||
f"保留{remaining_msg_count}条 (删除{deleted_msg_count}条)"
|
||||
)
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("text_preprocessing_result", "语义剪枝完成", pruning_result)
|
||||
else:
|
||||
logger.warning("[PILOT_RUN] 剪枝后对话为空,使用原始对话")
|
||||
pruned_dialogs = [dialog]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[PILOT_RUN] 语义剪枝失败,使用原始对话: {e}", exc_info=True)
|
||||
pruned_dialogs = [dialog]
|
||||
if progress_callback:
|
||||
error_result = {
|
||||
"type": "pruning",
|
||||
"error": str(e),
|
||||
"fallback": "使用原始对话"
|
||||
}
|
||||
await progress_callback("text_preprocessing_result", "语义剪枝失败", error_result)
|
||||
else:
|
||||
logger.info("[PILOT_RUN] 语义剪枝已关闭,跳过")
|
||||
pruning_stats = {
|
||||
"enabled": False,
|
||||
}
|
||||
|
||||
# ========== 步骤 2.2: 语义分块 ==========
|
||||
chunked_dialogs = await get_chunked_dialogs_from_preprocessed(
|
||||
data=[dialog],
|
||||
data=pruned_dialogs,
|
||||
chunker_strategy=memory_config.chunker_strategy,
|
||||
llm_client=llm_client,
|
||||
)
|
||||
logger.info(f"Processed dialogue text: {len(messages)} messages")
|
||||
|
||||
remaining_msg_count = len(pruned_dialogs[0].context.msgs) if pruned_dialogs and pruned_dialogs[0].context else 0
|
||||
logger.info(f"Processed dialogue text: {remaining_msg_count} messages after pruning")
|
||||
|
||||
# 进度回调:输出每个分块的结果
|
||||
if progress_callback:
|
||||
for dlg in chunked_dialogs:
|
||||
for i, chunk in enumerate(dlg.chunks):
|
||||
chunk_result = {
|
||||
"chunk_index": i + 1,
|
||||
"content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
|
||||
"full_length": len(chunk.content),
|
||||
"dialog_id": dlg.id,
|
||||
"chunker_strategy": memory_config.chunker_strategy,
|
||||
}
|
||||
await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)
|
||||
if hasattr(dlg, 'chunks') and dlg.chunks:
|
||||
for i, chunk in enumerate(dlg.chunks):
|
||||
chunk_result = {
|
||||
"type": "chunking",
|
||||
"chunk_index": i + 1,
|
||||
"content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
|
||||
"full_length": len(chunk.content),
|
||||
"dialog_id": dlg.id,
|
||||
"chunker_strategy": memory_config.chunker_strategy,
|
||||
}
|
||||
await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)
|
||||
|
||||
# 构建预处理完成总结(包含剪枝统计)
|
||||
preprocessing_summary = {
|
||||
"total_chunks": sum(len(dlg.chunks) for dlg in chunked_dialogs),
|
||||
"total_chunks": sum(len(dlg.chunks) for dlg in chunked_dialogs if hasattr(dlg, 'chunks') and dlg.chunks),
|
||||
"total_dialogs": len(chunked_dialogs),
|
||||
"chunker_strategy": memory_config.chunker_strategy,
|
||||
}
|
||||
await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary)
|
||||
|
||||
# 添加剪枝统计信息
|
||||
if pruning_stats:
|
||||
preprocessing_summary["pruning"] = pruning_stats
|
||||
|
||||
await progress_callback("text_preprocessing_complete", "预处理文本完成(剪枝 + 分块)", preprocessing_summary)
|
||||
|
||||
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
|
||||
|
||||
|
||||
@@ -1163,11 +1163,32 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st
|
||||
"""
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt
|
||||
from app.core.language_utils import validate_language
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.db import get_db
|
||||
import re
|
||||
|
||||
# 验证语言参数
|
||||
language = validate_language(language)
|
||||
|
||||
# 获取用户的 other_name 字段
|
||||
user_display_name = "该用户" if language == "zh" else "the user"
|
||||
if end_user_id:
|
||||
try:
|
||||
# 获取数据库会话并查询用户信息
|
||||
db = next(get_db())
|
||||
try:
|
||||
repo = EndUserRepository(db)
|
||||
end_user = repo.get_by_id(uuid.UUID(end_user_id))
|
||||
if end_user and end_user.other_name:
|
||||
user_display_name = end_user.other_name
|
||||
logger.info(f"使用 other_name 作为用户显示名称: {user_display_name}")
|
||||
else:
|
||||
logger.info(f"用户 {end_user_id} 的 other_name 为空,使用默认称呼: {user_display_name}")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"获取用户 other_name 失败,使用默认称呼: {str(e)}")
|
||||
|
||||
# 创建 UserSummaryHelper 实例
|
||||
user_summary_tool = UserSummaryHelper(end_user_id or os.getenv("SELECTED_end_user_id", "group_123"))
|
||||
|
||||
@@ -1184,7 +1205,8 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st
|
||||
user_id=user_summary_tool.user_id,
|
||||
entities=", ".join(entity_lines) if entity_lines else "(空)" if language == "zh" else "(empty)",
|
||||
statements=" | ".join(statement_samples) if statement_samples else "(空)" if language == "zh" else "(empty)",
|
||||
language=language
|
||||
language=language,
|
||||
user_display_name=user_display_name
|
||||
)
|
||||
|
||||
messages = [
|
||||
|
||||
197
api/app/tasks.py
197
api/app/tasks.py
@@ -1304,6 +1304,203 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]:
|
||||
"workspace_id": workspace_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
}
|
||||
@celery_app.task(
|
||||
name="app.tasks.write_all_workspaces_memory_task",
|
||||
bind=True,
|
||||
ignore_result=False,
|
||||
max_retries=3,
|
||||
acks_late=True,
|
||||
time_limit=3600,
|
||||
soft_time_limit=3300,
|
||||
)
|
||||
def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
|
||||
"""定时任务:遍历所有工作空间,统计并写入记忆增量
|
||||
|
||||
此任务会:
|
||||
1. 查询所有活跃的工作空间
|
||||
2. 对每个工作空间统计记忆总量
|
||||
3. 将统计结果写入 memory_increments 表
|
||||
|
||||
Returns:
|
||||
包含任务执行结果的字典
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.models.workspace_model import Workspace
|
||||
from app.models.app_model import App
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.repositories.memory_increment_repository import write_memory_increment
|
||||
from app.services.memory_storage_service import search_all
|
||||
|
||||
api_logger = get_api_logger()
|
||||
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
# 获取所有活跃的工作空间
|
||||
workspaces = db.query(Workspace).filter(
|
||||
Workspace.is_active.is_(True)
|
||||
).all()
|
||||
|
||||
if not workspaces:
|
||||
api_logger.warning("没有找到活跃的工作空间")
|
||||
return {
|
||||
"status": "SUCCESS",
|
||||
"message": "没有找到活跃的工作空间",
|
||||
"workspace_count": 0,
|
||||
"workspace_results": []
|
||||
}
|
||||
|
||||
api_logger.info(f"开始统计 {len(workspaces)} 个工作空间的记忆增量")
|
||||
all_workspace_results = []
|
||||
|
||||
# 遍历每个工作空间
|
||||
for workspace in workspaces:
|
||||
workspace_id = workspace.id
|
||||
api_logger.info(f"开始处理工作空间: {workspace.name} (ID: {workspace_id})")
|
||||
|
||||
try:
|
||||
# 1. 查询当前workspace下的所有app(仅未删除的)
|
||||
apps = db.query(App).filter(
|
||||
App.workspace_id == workspace_id,
|
||||
App.is_active.is_(True)
|
||||
).all()
|
||||
|
||||
if not apps:
|
||||
# 如果没有app,总量为0
|
||||
memory_increment = write_memory_increment(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
total_num=0
|
||||
)
|
||||
all_workspace_results.append({
|
||||
"workspace_id": str(workspace_id),
|
||||
"workspace_name": workspace.name,
|
||||
"status": "SUCCESS",
|
||||
"total_num": 0,
|
||||
"end_user_count": 0,
|
||||
"memory_increment_id": str(memory_increment.id),
|
||||
"created_at": memory_increment.created_at.isoformat(),
|
||||
})
|
||||
api_logger.info(f"工作空间 {workspace.name} 没有应用,记录总量为0")
|
||||
continue
|
||||
|
||||
# 2. 查询所有app下的end_user_id(去重)
|
||||
app_ids = [app.id for app in apps]
|
||||
end_users = db.query(EndUser.id).filter(
|
||||
EndUser.app_id.in_(app_ids)
|
||||
).distinct().all()
|
||||
|
||||
# 3. 遍历所有end_user,查询每个宿主的记忆总量并累加
|
||||
total_num = 0
|
||||
end_user_details = []
|
||||
|
||||
for (end_user_id,) in end_users:
|
||||
try:
|
||||
# 调用 search_all 接口查询该宿主的总量
|
||||
result = await search_all(str(end_user_id))
|
||||
user_total = result.get("total", 0)
|
||||
total_num += user_total
|
||||
end_user_details.append({
|
||||
"end_user_id": str(end_user_id),
|
||||
"total": user_total
|
||||
})
|
||||
except Exception as e:
|
||||
# 记录单个用户查询失败,但继续处理其他用户
|
||||
api_logger.warning(f"查询用户 {end_user_id} 记忆失败: {str(e)}")
|
||||
end_user_details.append({
|
||||
"end_user_id": str(end_user_id),
|
||||
"total": 0,
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
# 4. 写入数据库
|
||||
memory_increment = write_memory_increment(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
total_num=total_num
|
||||
)
|
||||
|
||||
all_workspace_results.append({
|
||||
"workspace_id": str(workspace_id),
|
||||
"workspace_name": workspace.name,
|
||||
"status": "SUCCESS",
|
||||
"total_num": total_num,
|
||||
"end_user_count": len(end_users),
|
||||
"memory_increment_id": str(memory_increment.id),
|
||||
"created_at": memory_increment.created_at.isoformat(),
|
||||
})
|
||||
|
||||
api_logger.info(
|
||||
f"工作空间 {workspace.name} 统计完成: 总量={total_num}, 用户数={len(end_users)}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
db.rollback() # 回滚失败的事务,允许继续处理下一个工作空间
|
||||
api_logger.error(f"处理工作空间 {workspace.name} (ID: {workspace_id}) 失败: {str(e)}")
|
||||
all_workspace_results.append({
|
||||
"workspace_id": str(workspace_id),
|
||||
"workspace_name": workspace.name,
|
||||
"status": "FAILURE",
|
||||
"error": str(e),
|
||||
"total_num": 0,
|
||||
"end_user_count": 0,
|
||||
})
|
||||
|
||||
total_memory = sum(r.get("total_num", 0) for r in all_workspace_results)
|
||||
success_count = sum(1 for r in all_workspace_results if r.get("status") == "SUCCESS")
|
||||
|
||||
return {
|
||||
"status": "SUCCESS",
|
||||
"message": f"成功处理 {success_count}/{len(workspaces)} 个工作空间,总记忆量: {total_memory}",
|
||||
"workspace_count": len(workspaces),
|
||||
"success_count": success_count,
|
||||
"total_memory": total_memory,
|
||||
"workspace_results": all_workspace_results
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"记忆增量统计任务执行失败: {str(e)}")
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": str(e),
|
||||
"workspace_count": 0,
|
||||
"workspace_results": []
|
||||
}
|
||||
|
||||
try:
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
elapsed_time = time.time() - start_time
|
||||
result["elapsed_time"] = elapsed_time
|
||||
result["task_id"] = self.request.id
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": str(e),
|
||||
"elapsed_time": elapsed_time,
|
||||
"task_id": self.request.id
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
|
||||
Reference in New Issue
Block a user