842 lines
32 KiB
Python
842 lines
32 KiB
Python
"""
|
||
Memory Storage Service
|
||
|
||
Handles business logic for memory storage operations.
|
||
"""
|
||
|
||
import asyncio
|
||
import json
|
||
import os
|
||
import time
|
||
from datetime import datetime
|
||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||
|
||
from dotenv import load_dotenv
|
||
from sqlalchemy.orm import Session
|
||
|
||
from app.core.logging_config import get_config_logger, get_logger
|
||
from app.core.memory.analytics.hot_memory_tags import (
|
||
get_raw_tags_from_db,
|
||
filter_tags_with_llm,
|
||
)
|
||
from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats
|
||
from app.models.user_model import User
|
||
from app.repositories.memory_config_repository import MemoryConfigRepository
|
||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||
from app.schemas.memory_config_schema import ConfigurationError
|
||
from app.schemas.memory_storage_schema import (
|
||
ConfigKey,
|
||
ConfigParamsCreate,
|
||
ConfigParamsDelete,
|
||
ConfigPilotRun,
|
||
ConfigUpdate,
|
||
ConfigUpdateExtracted,
|
||
)
|
||
from app.services.memory_config_service import MemoryConfigService
|
||
from app.utils.sse_utils import format_sse_message
|
||
|
||
logger = get_logger(__name__)
|
||
config_logger = get_config_logger()
|
||
|
||
# Load environment variables for Neo4j connector
|
||
load_dotenv()
|
||
_neo4j_connector = Neo4jConnector()
|
||
|
||
|
||
class MemoryStorageService:
|
||
"""Service for memory storage operations"""
|
||
|
||
def __init__(self):
|
||
logger.info("MemoryStorageService initialized")
|
||
|
||
async def get_storage_info(self) -> dict:
|
||
"""
|
||
Example wrapper method - retrieves storage information
|
||
|
||
Args:
|
||
|
||
Returns:
|
||
Storage information dictionary
|
||
"""
|
||
logger.info("Getting storage info ")
|
||
|
||
# Empty wrapper - implement your logic here
|
||
result = {
|
||
"status": "active",
|
||
"message": "This is an example wrapper"
|
||
}
|
||
|
||
return result
|
||
|
||
|
||
class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||
"""Service layer for config params CRUD.
|
||
|
||
使用 SQLAlchemy ORM 进行数据库操作。
|
||
"""
|
||
|
||
def __init__(self, db: Session) -> None:
|
||
"""初始化服务
|
||
|
||
Args:
|
||
db: SQLAlchemy 数据库会话
|
||
"""
|
||
self.db = db
|
||
|
||
@staticmethod
|
||
def _convert_timestamps_to_format(data_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||
"""将 created_at 和 updated_at 字段从 datetime 对象转换为 YYYYMMDDHHmmss 格式"""
|
||
|
||
for item in data_list:
|
||
for field in ['created_at', 'updated_at']:
|
||
if field in item and item[field] is not None:
|
||
value = item[field]
|
||
dt = None
|
||
|
||
# 处理不同类型的时间值
|
||
if hasattr(value, 'to_native'):
|
||
# Neo4j DateTime 对象
|
||
dt = value.to_native()
|
||
elif isinstance(value, datetime):
|
||
# Python datetime 对象
|
||
dt = value
|
||
elif isinstance(value, str):
|
||
# 字符串格式
|
||
try:
|
||
dt = datetime.fromisoformat(value.replace('Z', '+00:00'))
|
||
except Exception:
|
||
pass # 保持原值
|
||
|
||
# 转换为 YYYYMMDDHHmmss 格式
|
||
if dt:
|
||
item[field] = dt.strftime('%Y%m%d%H%M%S')
|
||
|
||
return data_list
|
||
|
||
# --- Create ---
|
||
def create(self, params: ConfigParamsCreate) -> Dict[str, Any]: # 创建配置参数(仅名称与描述)
|
||
# 业务层检查同一工作空间下是否已存在同名配置
|
||
if params.workspace_id and params.config_name:
|
||
from app.models.memory_config_model import MemoryConfig
|
||
existing = (
|
||
self.db.query(MemoryConfig)
|
||
.filter_by(workspace_id=params.workspace_id, config_name=params.config_name)
|
||
.first()
|
||
)
|
||
if existing:
|
||
raise ValueError(f"DUPLICATE_CONFIG_NAME:{params.config_name}")
|
||
|
||
# 如果workspace_id存在且模型字段未全部指定,则自动获取
|
||
if params.workspace_id and not all([params.llm_id, params.embedding_id, params.rerank_id]):
|
||
configs = self._get_workspace_configs(params.workspace_id)
|
||
if configs is None:
|
||
raise ValueError(f"工作空间不存在: workspace_id={params.workspace_id}")
|
||
|
||
# 只在未指定时填充(允许手动覆盖)
|
||
if not params.llm_id:
|
||
params.llm_id = configs.get('llm')
|
||
if not params.embedding_id:
|
||
params.embedding_id = configs.get('embedding')
|
||
if not params.rerank_id:
|
||
params.rerank_id = configs.get('rerank')
|
||
|
||
# reflection_model_id 和 emotion_model_id 默认与 llm_id 一致
|
||
if not params.reflection_model_id:
|
||
params.reflection_model_id = params.llm_id
|
||
if not params.emotion_model_id:
|
||
params.emotion_model_id = params.llm_id
|
||
|
||
# 根据关联的本体场景推导 pruning_scene(语义剪枝场景与本体工程场景保持一致)
|
||
if params.scene_id and not getattr(params, 'pruning_scene', None):
|
||
params.pruning_scene = self._resolve_pruning_scene_from_scene_id(params.scene_id)
|
||
|
||
config = MemoryConfigRepository.create(self.db, params)
|
||
self.db.commit()
|
||
return {"affected": 1, "config_id": config.config_id}
|
||
|
||
def _get_workspace_configs(self, workspace_id) -> Optional[Dict[str, Any]]:
|
||
"""获取工作空间模型配置(内部方法,便于测试)"""
|
||
from app.db import SessionLocal
|
||
from app.repositories.workspace_repository import get_workspace_models_configs
|
||
|
||
db_session = SessionLocal()
|
||
try:
|
||
return get_workspace_models_configs(db_session, workspace_id)
|
||
finally:
|
||
db_session.close()
|
||
|
||
def _resolve_pruning_scene_from_scene_id(self, scene_id) -> Optional[str]:
|
||
"""根据本体场景ID获取对应的 scene_name,作为语义剪枝场景值
|
||
|
||
Args:
|
||
scene_id: 本体场景UUID
|
||
|
||
Returns:
|
||
scene_name 字符串,查询失败时返回 None
|
||
"""
|
||
try:
|
||
from app.models.ontology_scene import OntologyScene
|
||
scene = self.db.query(OntologyScene).filter_by(scene_id=scene_id).first()
|
||
return scene.scene_name if scene else None
|
||
except Exception as e:
|
||
logger.warning(f"_resolve_pruning_scene_from_scene_id failed for scene_id={scene_id}: {e}", exc_info=True)
|
||
return None
|
||
|
||
# --- Delete ---
|
||
def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]: # 删除配置参数(按配置ID)
|
||
success = MemoryConfigRepository.delete(self.db, key.config_id)
|
||
if not success:
|
||
raise ValueError("未找到配置")
|
||
return {"affected": 1}
|
||
|
||
# --- Update ---
|
||
def update(self, update: ConfigUpdate) -> Dict[str, Any]: # 部分更新配置参数
|
||
config = MemoryConfigRepository.update(self.db, update)
|
||
if not config:
|
||
raise ValueError("未找到配置")
|
||
return {"affected": 1}
|
||
|
||
def update_extracted(self, update: ConfigUpdateExtracted) -> Dict[str, Any]: # 更新记忆萃取引擎配置参数
|
||
config = MemoryConfigRepository.update_extracted(self.db, update)
|
||
if not config:
|
||
raise ValueError("未找到配置")
|
||
return {"affected": 1}
|
||
|
||
# --- Forget config params ---
|
||
# 遗忘引擎配置方法已迁移到 memory_forget_service.py
|
||
# 使用新方法: MemoryForgetService.read_forgetting_config() 和 MemoryForgetService.update_forgetting_config()
|
||
|
||
# --- Read ---
|
||
def get_extracted(self, key: ConfigKey) -> Dict[str, Any]: # 获取萃取配置参数
|
||
result = MemoryConfigRepository.get_extracted_config(self.db, key.config_id)
|
||
if not result:
|
||
raise ValueError("未找到配置")
|
||
return result
|
||
|
||
# --- Read All ---
|
||
def get_all(self, workspace_id=None) -> List[Dict[str, Any]]: # 获取所有配置参数
|
||
results = MemoryConfigRepository.get_all(self.db, workspace_id)
|
||
|
||
# 检查并修正 pruning_scene 与 scene_name 不一致的记录
|
||
needs_commit = False
|
||
for config, scene_name in results:
|
||
if scene_name and config.pruning_scene != scene_name:
|
||
logger.info(
|
||
f"修正 pruning_scene: config_id={config.config_id} "
|
||
f"'{config.pruning_scene}' -> '{scene_name}'"
|
||
)
|
||
config.pruning_scene = scene_name
|
||
needs_commit = True
|
||
if needs_commit:
|
||
self.db.commit()
|
||
|
||
# 将 ORM 对象转换为字典列表
|
||
data_list = []
|
||
for config, scene_name in results:
|
||
# 安全地转换 user_id 为 int
|
||
config_id_old = None
|
||
if config.config_id_old:
|
||
try:
|
||
config_id_old = int(config.config_id_old)
|
||
except (ValueError, TypeError):
|
||
config_id_old = None
|
||
|
||
config_dict = {
|
||
"config_id": str(config.config_id),
|
||
"config_name": config.config_name,
|
||
"config_desc": config.config_desc,
|
||
"workspace_id": str(config.workspace_id) if config.workspace_id else None,
|
||
"end_user_id": config.end_user_id,
|
||
"config_id_old": config_id_old,
|
||
"apply_id": config.apply_id,
|
||
"scene_id": str(config.scene_id) if config.scene_id else None,
|
||
"scene_name": scene_name, # 新增:场景名称
|
||
"is_system_default": config.is_default, # 是否为系统默认配置
|
||
"llm_id": config.llm_id,
|
||
"embedding_id": config.embedding_id,
|
||
"rerank_id": config.rerank_id,
|
||
"enable_llm_dedup_blockwise": config.enable_llm_dedup_blockwise,
|
||
"enable_llm_disambiguation": config.enable_llm_disambiguation,
|
||
"deep_retrieval": config.deep_retrieval,
|
||
"t_type_strict": config.t_type_strict,
|
||
"t_name_strict": config.t_name_strict,
|
||
"t_overall": config.t_overall,
|
||
"state": config.state,
|
||
"chunker_strategy": config.chunker_strategy,
|
||
"pruning_enabled": config.pruning_enabled,
|
||
"pruning_scene": config.pruning_scene,
|
||
"pruning_threshold": config.pruning_threshold,
|
||
"enable_self_reflexion": config.enable_self_reflexion,
|
||
"iteration_period": config.iteration_period,
|
||
"reflexion_range": config.reflexion_range,
|
||
"baseline": config.baseline,
|
||
"statement_granularity": config.statement_granularity,
|
||
"include_dialogue_context": config.include_dialogue_context,
|
||
"max_context": config.max_context,
|
||
"lambda_time": config.lambda_time,
|
||
"lambda_mem": config.lambda_mem,
|
||
"offset": config.offset,
|
||
"created_at": config.created_at,
|
||
"updated_at": config.updated_at,
|
||
}
|
||
data_list.append(config_dict)
|
||
|
||
# 将 created_at 和 updated_at 转换为 YYYYMMDDHHmmss 格式
|
||
return self._convert_timestamps_to_format(data_list)
|
||
|
||
async def pilot_run_stream(self, payload: ConfigPilotRun, language: str = "zh") -> AsyncGenerator[str, None]:
|
||
"""
|
||
流式执行试运行,产生 SSE 格式的进度事件
|
||
|
||
Args:
|
||
payload: 试运行配置和对话文本
|
||
language: 语言类型 ("zh" 中文, "en" 英文),默认中文
|
||
|
||
Yields:
|
||
SSE 格式的字符串,包含以下事件类型:
|
||
- 各种阶段名称: 进度更新 (如 starting, knowledge_extraction_complete 等)
|
||
- result: 最终结果
|
||
- error: 错误信息
|
||
- done: 完成标记
|
||
|
||
Raises:
|
||
ValueError: 当配置无效或参数缺失时
|
||
RuntimeError: 当管线执行失败时
|
||
"""
|
||
from pathlib import Path
|
||
project_root = str(Path(__file__).resolve().parents[2])
|
||
|
||
try:
|
||
# 发出初始进度事件
|
||
yield format_sse_message("starting", {
|
||
"message": "开始试运行...",
|
||
"time": int(time.time() * 1000)
|
||
})
|
||
|
||
# 步骤 1: 配置加载和验证(数据库优先)
|
||
payload_cid = str(getattr(payload, "config_id", "") or "").strip()
|
||
cid: Optional[str] = payload_cid if payload_cid else None
|
||
|
||
if not cid:
|
||
raise ValueError("未提供 payload.config_id,禁止启动试运行")
|
||
|
||
# Load configuration from database only using centralized manager
|
||
try:
|
||
config_service = MemoryConfigService(self.db)
|
||
memory_config = config_service.load_memory_config(
|
||
config_id=str(cid),
|
||
service_name="MemoryStorageService.pilot_run_stream"
|
||
)
|
||
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
|
||
except ConfigurationError as e:
|
||
raise RuntimeError(f"Configuration loading failed: {e}")
|
||
|
||
# 根据是否关联本体场景选择使用的文本
|
||
# 如果配置关联了本体场景(scene_id 不为空),使用 custom_text(如果提供)
|
||
# 否则使用 dialogue_text
|
||
if memory_config.scene_id:
|
||
# 关联了本体场景,优先使用 custom_text
|
||
if hasattr(payload, 'custom_text') and payload.custom_text:
|
||
dialogue_text = payload.custom_text.strip()
|
||
logger.info(
|
||
f"[PILOT_RUN_STREAM] Using custom_text for scene_id={memory_config.scene_id}, length: {len(dialogue_text)}")
|
||
else:
|
||
# 如果没有提供 custom_text,回退到 dialogue_text
|
||
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
|
||
logger.info(
|
||
f"[PILOT_RUN_STREAM] No custom_text provided, using dialogue_text for scene_id={memory_config.scene_id}")
|
||
else:
|
||
# 没有关联本体场景,使用 dialogue_text
|
||
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
|
||
logger.info(f"[PILOT_RUN_STREAM] No scene_id, using dialogue_text, length: {len(dialogue_text)}")
|
||
|
||
# 验证最终使用的文本不为空
|
||
if not dialogue_text:
|
||
raise ValueError("试运行模式必须提供有效的文本内容(dialogue_text 或 custom_text)")
|
||
|
||
logger.info(f"[PILOT_RUN_STREAM] Final text preview: {dialogue_text[:100]}")
|
||
|
||
# 步骤 2: 创建进度回调函数捕获管线进度
|
||
# 使用队列在回调和生成器之间传递进度事件
|
||
progress_queue: asyncio.Queue = asyncio.Queue()
|
||
|
||
async def progress_callback(stage: str, message: str, data: Optional[Dict[str, Any]] = None) -> None:
|
||
"""
|
||
进度回调函数,将进度事件放入队列
|
||
|
||
Args:
|
||
stage: 阶段标识
|
||
message: 进度消息
|
||
data: 可选的结果数据(用于传递节点执行结果)
|
||
"""
|
||
await progress_queue.put((stage, message, data))
|
||
|
||
# 步骤 3: 在后台任务中执行管线
|
||
async def run_pipeline():
|
||
"""在后台执行管线并捕获异常"""
|
||
try:
|
||
from app.services.pilot_run_service import run_pilot_extraction
|
||
|
||
logger.info(
|
||
f"[PILOT_RUN_STREAM] Calling run_pilot_extraction with dialogue_text length: {len(dialogue_text)}")
|
||
await run_pilot_extraction(
|
||
memory_config=memory_config,
|
||
dialogue_text=dialogue_text,
|
||
db=self.db,
|
||
progress_callback=progress_callback,
|
||
language=language,
|
||
)
|
||
logger.info("[PILOT_RUN_STREAM] pipeline_main completed")
|
||
|
||
# 标记管线完成
|
||
await progress_queue.put(("__PIPELINE_COMPLETE__", "", None))
|
||
except Exception as e:
|
||
# 将异常放入队列
|
||
await progress_queue.put(("__PIPELINE_ERROR__", str(e), None))
|
||
|
||
# 启动后台任务
|
||
pipeline_task = asyncio.create_task(run_pipeline())
|
||
|
||
# 步骤 4: 从队列中读取进度事件并发出
|
||
while True:
|
||
try:
|
||
# 等待进度事件,设置超时以检测客户端断开
|
||
stage, message, data = await asyncio.wait_for(
|
||
progress_queue.get(),
|
||
timeout=0.5
|
||
)
|
||
|
||
# 检查特殊标记
|
||
if stage == "__PIPELINE_COMPLETE__":
|
||
break
|
||
elif stage == "__PIPELINE_ERROR__":
|
||
raise RuntimeError(message)
|
||
|
||
# 构建进度事件数据
|
||
progress_data = {
|
||
"message": message,
|
||
"time": int(time.time() * 1000)
|
||
}
|
||
|
||
# 如果有结果数据,添加到事件中
|
||
if data:
|
||
progress_data["data"] = data
|
||
|
||
# 发出进度事件,使用 stage 作为事件类型
|
||
yield format_sse_message(stage, progress_data)
|
||
|
||
except TimeoutError:
|
||
# 超时,继续等待(这允许检测客户端断开)
|
||
continue
|
||
|
||
# 等待管线任务完成
|
||
await pipeline_task
|
||
|
||
# 步骤 5: 读取提取结果
|
||
from app.core.config import settings
|
||
result_path = settings.get_memory_output_path("extracted_result.json")
|
||
if not os.path.isfile(result_path):
|
||
raise FileNotFoundError(f"试运行完成,但未找到提取结果文件: {result_path}")
|
||
|
||
with open(result_path, "r", encoding="utf-8") as rf:
|
||
extracted_result = json.load(rf)
|
||
|
||
# 步骤 6: 计算本体覆盖率并合并到结果中
|
||
result_data = {
|
||
"config_id": cid,
|
||
"time_log": os.path.join(project_root, "logs", "time.log"),
|
||
"extracted_result": extracted_result,
|
||
}
|
||
try:
|
||
ontology_coverage = await self._compute_ontology_coverage(
|
||
extracted_result=extracted_result,
|
||
memory_config=memory_config,
|
||
)
|
||
if ontology_coverage:
|
||
result_data["ontology_coverage"] = ontology_coverage
|
||
except Exception as cov_err:
|
||
logger.warning(f"[PILOT_RUN_STREAM] Ontology coverage computation failed: {cov_err}", exc_info=True)
|
||
|
||
yield format_sse_message("result", result_data)
|
||
|
||
# 步骤 7: 发出完成事件
|
||
yield format_sse_message("done", {
|
||
"message": "试运行完成",
|
||
"time": int(time.time() * 1000)
|
||
})
|
||
|
||
except asyncio.CancelledError:
|
||
# 客户端断开连接
|
||
logger.info("[PILOT_RUN_STREAM] Client disconnected during streaming")
|
||
raise
|
||
except Exception as e:
|
||
# 发出错误事件
|
||
logger.error(f"[PILOT_RUN_STREAM] Error during streaming: {e}", exc_info=True)
|
||
yield format_sse_message("error", {
|
||
"code": 5000,
|
||
"message": "试运行失败",
|
||
"error": str(e),
|
||
"time": int(time.time() * 1000)
|
||
})
|
||
|
||
async def _compute_ontology_coverage(
|
||
self,
|
||
extracted_result: Dict[str, Any],
|
||
memory_config,
|
||
) -> Optional[Dict[str, Any]]:
|
||
"""根据提取结果中的实体类型,与场景/通用本体类型做互斥分类统计。
|
||
|
||
分类规则(互斥):场景类型优先 > 通用类型 > 未匹配
|
||
确保: 场景实体数 + 通用实体数 + 未匹配数 = 总实体数
|
||
|
||
Returns:
|
||
包含三部分统计的字典,或 None(无实体数据时)
|
||
"""
|
||
core_entities = extracted_result.get("core_entities", [])
|
||
if not core_entities:
|
||
return None
|
||
|
||
# 1. 加载场景本体类型集合
|
||
scene_ontology_types: set = set()
|
||
try:
|
||
from app.repositories.ontology_class_repository import OntologyClassRepository
|
||
|
||
if memory_config.scene_id:
|
||
class_repo = OntologyClassRepository(self.db)
|
||
ontology_classes = class_repo.get_classes_by_scene(memory_config.scene_id)
|
||
scene_ontology_types = {oc.class_name for oc in ontology_classes}
|
||
except Exception as e:
|
||
logger.warning(f"Failed to load scene ontology types: {e}")
|
||
|
||
# 2. 加载通用本体类型集合
|
||
general_ontology_types: set = set()
|
||
try:
|
||
from app.core.memory.ontology_services.ontology_type_loader import (
|
||
get_general_ontology_registry,
|
||
is_general_ontology_enabled,
|
||
)
|
||
|
||
if is_general_ontology_enabled():
|
||
registry = get_general_ontology_registry()
|
||
if registry:
|
||
general_ontology_types = set(registry.types.keys())
|
||
except Exception as e:
|
||
logger.warning(f"Failed to load general ontology types: {e}")
|
||
|
||
# 3. 互斥分类:场景优先 > 通用 > 未匹配
|
||
scene_distribution: list = []
|
||
general_distribution: list = []
|
||
unmatched_distribution: list = []
|
||
scene_total = 0
|
||
general_total = 0
|
||
unmatched_total = 0
|
||
|
||
for item in core_entities:
|
||
entity_type = item.get("type", "")
|
||
count = item.get("count", 0)
|
||
|
||
if entity_type in scene_ontology_types:
|
||
scene_distribution.append({"type": entity_type, "count": count})
|
||
scene_total += count
|
||
elif entity_type in general_ontology_types:
|
||
general_distribution.append({"type": entity_type, "count": count})
|
||
general_total += count
|
||
else:
|
||
unmatched_distribution.append({"type": entity_type, "count": count})
|
||
unmatched_total += count
|
||
|
||
# 按数量降序排列
|
||
scene_distribution.sort(key=lambda x: x["count"], reverse=True)
|
||
general_distribution.sort(key=lambda x: x["count"], reverse=True)
|
||
unmatched_distribution.sort(key=lambda x: x["count"], reverse=True)
|
||
|
||
total_entities = scene_total + general_total + unmatched_total
|
||
|
||
return {
|
||
"scene_type_distribution": {
|
||
"type_count": len(scene_distribution),
|
||
"entity_total": scene_total,
|
||
"types": scene_distribution,
|
||
},
|
||
"general_type_distribution": {
|
||
"type_count": len(general_distribution),
|
||
"entity_total": general_total,
|
||
"types": general_distribution,
|
||
},
|
||
"unmatched": {
|
||
"type_count": len(unmatched_distribution),
|
||
"entity_total": unmatched_total,
|
||
"types": unmatched_distribution,
|
||
},
|
||
"total_entities": total_entities,
|
||
"time": int(time.time() * 1000),
|
||
}
|
||
|
||
|
||
# -------------------- Neo4j Search & Analytics (fused from data_search_service.py) --------------------
|
||
# Ensure env for connector (e.g., NEO4J_PASSWORD)
|
||
|
||
|
||
async def search_dialogue(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||
result = await _neo4j_connector.execute_query(
|
||
MemoryConfigRepository.SEARCH_FOR_DIALOGUE,
|
||
end_user_id=end_user_id,
|
||
)
|
||
data = {"search_for": "dialogue", "num": result[0]["num"]}
|
||
return data
|
||
|
||
|
||
async def search_chunk(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||
result = await _neo4j_connector.execute_query(
|
||
MemoryConfigRepository.SEARCH_FOR_CHUNK,
|
||
end_user_id=end_user_id,
|
||
)
|
||
data = {"search_for": "chunk", "num": result[0]["num"]}
|
||
return data
|
||
|
||
|
||
async def search_statement(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||
result = await _neo4j_connector.execute_query(
|
||
MemoryConfigRepository.SEARCH_FOR_STATEMENT,
|
||
end_user_id=end_user_id,
|
||
)
|
||
data = {"search_for": "statement", "num": result[0]["num"]}
|
||
return data
|
||
|
||
|
||
async def search_entity(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||
result = await _neo4j_connector.execute_query(
|
||
MemoryConfigRepository.SEARCH_FOR_ENTITY,
|
||
end_user_id=end_user_id,
|
||
)
|
||
data = {"search_for": "entity", "num": result[0]["num"]}
|
||
return data
|
||
|
||
|
||
async def kb_type_distribution(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||
"""统一知识库类型分布接口。
|
||
|
||
聚合 dialogue/chunk/statement/entity 四类计数,返回统一的分布结构,便于前端一次性消费。
|
||
"""
|
||
result = await _neo4j_connector.execute_query(
|
||
MemoryConfigRepository.SEARCH_FOR_ALL,
|
||
end_user_id=end_user_id,
|
||
)
|
||
|
||
# 检查结果是否为空或长度不足
|
||
if not result or len(result) < 4:
|
||
data = {
|
||
"total": 0,
|
||
"distribution": [
|
||
{"type": "dialogue", "count": 0},
|
||
{"type": "chunk", "count": 0},
|
||
{"type": "statement", "count": 0},
|
||
{"type": "entity", "count": 0},
|
||
]
|
||
}
|
||
return data
|
||
|
||
total = result[-1]["Count"]
|
||
distribution = [
|
||
{"type": "dialogue", "count": result[0]["Count"]},
|
||
{"type": "chunk", "count": result[1]["Count"]},
|
||
{"type": "statement", "count": result[2]["Count"]},
|
||
{"type": "entity", "count": result[3]["Count"]},
|
||
]
|
||
|
||
data = {"total": total, "distribution": distribution}
|
||
return data
|
||
|
||
|
||
async def search_detials(end_user_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||
result = await _neo4j_connector.execute_query(
|
||
MemoryConfigRepository.SEARCH_FOR_DETIALS,
|
||
end_user_id=end_user_id,
|
||
)
|
||
return result
|
||
|
||
|
||
async def search_edges(end_user_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||
result = await _neo4j_connector.execute_query(
|
||
MemoryConfigRepository.SEARCH_FOR_EDGES,
|
||
end_user_id=end_user_id,
|
||
)
|
||
return result
|
||
|
||
|
||
async def search_all_batch(end_user_ids: List[str]) -> Dict[str, int]:
|
||
"""批量查询多个用户的记忆数量(简化版本,只返回total)
|
||
|
||
Args:
|
||
end_user_ids: 用户ID列表
|
||
|
||
Returns:
|
||
Dict[str, int]: 以user_id为key的记忆数量字典
|
||
格式: {"user_id": total_count}
|
||
"""
|
||
if not end_user_ids:
|
||
return {}
|
||
|
||
result = await _neo4j_connector.execute_query(
|
||
MemoryConfigRepository.SEARCH_FOR_ALL_BATCH,
|
||
end_user_ids=end_user_ids,
|
||
)
|
||
|
||
# 转换结果为字典格式,字典格式在查询中无需遍历结果集,直接返回
|
||
data = {}
|
||
for row in result:
|
||
data[row["user_id"]] = row["total"]
|
||
|
||
# 为没有数据的用户填充默认值,转换字典格式还为无数据填充默认值
|
||
for user_id in end_user_ids:
|
||
if user_id not in data:
|
||
data[user_id] = 0
|
||
|
||
return data
|
||
|
||
|
||
async def analytics_hot_memory_tags(
|
||
db: Session,
|
||
current_user: User,
|
||
limit: int = 10
|
||
) -> List[Dict[str, Any]]:
|
||
"""
|
||
获取热门记忆标签,按数量排序并返回前N个
|
||
|
||
优化策略:
|
||
1. 先从所有用户收集原始标签(不调用LLM)
|
||
2. 聚合并合并相同标签的频率
|
||
3. 排序后取前N个
|
||
4. 只调用一次LLM进行筛选
|
||
"""
|
||
workspace_id = current_user.current_workspace_id
|
||
# 获取更多标签供LLM筛选(获取limit*4个标签)
|
||
raw_limit = limit * 4
|
||
from app.services.memory_dashboard_service import get_workspace_end_users
|
||
# 使用 asyncio.to_thread 避免阻塞事件循环
|
||
end_users = await asyncio.to_thread(get_workspace_end_users, db, workspace_id, current_user)
|
||
|
||
if not end_users:
|
||
return []
|
||
|
||
# 步骤1: 收集所有用户的原始标签(不调用LLM)
|
||
connector = Neo4jConnector()
|
||
try:
|
||
all_raw_tags = []
|
||
for end_user in end_users:
|
||
raw_tags = await get_raw_tags_from_db(
|
||
connector,
|
||
str(end_user.id),
|
||
limit=raw_limit,
|
||
by_user=False
|
||
)
|
||
if raw_tags:
|
||
all_raw_tags.extend(raw_tags)
|
||
|
||
if not all_raw_tags:
|
||
return []
|
||
|
||
# 步骤2: 聚合相同标签的频率
|
||
tag_frequency_map = {}
|
||
for tag_name, frequency in all_raw_tags:
|
||
if tag_name in tag_frequency_map:
|
||
tag_frequency_map[tag_name] += frequency
|
||
else:
|
||
tag_frequency_map[tag_name] = frequency
|
||
|
||
# 步骤3: 按频率降序排序,取前raw_limit个
|
||
sorted_tags = sorted(
|
||
tag_frequency_map.items(),
|
||
key=lambda x: x[1],
|
||
reverse=True
|
||
)[:raw_limit]
|
||
|
||
if not sorted_tags:
|
||
return []
|
||
|
||
# 步骤4: 只调用一次LLM进行筛选
|
||
tag_names = [tag for tag, _ in sorted_tags]
|
||
|
||
# 使用第一个用户的end_user_id来获取LLM配置
|
||
# 因为同一工作空间下的用户应该使用相同的配置
|
||
first_end_user_id = str(end_users[0].id)
|
||
filtered_tag_names = await filter_tags_with_llm(tag_names, first_end_user_id)
|
||
|
||
# 步骤5: 根据LLM筛选结果构建最终列表(保留频率)
|
||
final_tags = []
|
||
for tag, freq in sorted_tags:
|
||
if tag in filtered_tag_names:
|
||
final_tags.append((tag, freq))
|
||
|
||
# 步骤6: 只返回前limit个
|
||
top_tags = final_tags[:limit]
|
||
|
||
return [{"name": t, "frequency": f} for t, f in top_tags]
|
||
|
||
finally:
|
||
await connector.close()
|
||
|
||
|
||
async def analytics_recent_activity_stats(workspace_id: Optional[str] = None) -> Dict[str, Any]:
|
||
"""获取最近记忆提取活动统计。
|
||
|
||
优先从 Redis 缓存读取(按 workspace_id),缓存不存在时降级到日志文件解析。
|
||
|
||
Args:
|
||
workspace_id: 工作空间ID,用于从 Redis 读取对应缓存
|
||
|
||
Returns:
|
||
包含 total、stats、latest_relative、source 的统计字典
|
||
"""
|
||
stats = None
|
||
source = "log"
|
||
|
||
# 优先从 Redis 读取
|
||
if workspace_id:
|
||
try:
|
||
from app.cache.memory.activity_stats_cache import ActivityStatsCache
|
||
cached = await ActivityStatsCache.get_activity_stats(workspace_id)
|
||
if cached:
|
||
stats = cached.get("stats", {})
|
||
source = "redis"
|
||
logger.info(f"[ANALYTICS] 从 Redis 读取活动统计: workspace_id={workspace_id}")
|
||
except Exception as e:
|
||
logger.warning(f"[ANALYTICS] 读取 Redis 活动统计失败,降级到日志: {e}")
|
||
|
||
# 降级:从日志文件解析
|
||
if stats is None:
|
||
stats, _msg = get_recent_activity_stats()
|
||
source = "log"
|
||
|
||
total = (
|
||
stats.get("chunk_count", 0)
|
||
+ stats.get("statements_count", 0)
|
||
+ stats.get("triplet_entities_count", 0)
|
||
+ stats.get("triplet_relations_count", 0)
|
||
+ stats.get("temporal_count", 0)
|
||
)
|
||
|
||
# 计算"最新一次活动多久前"(仅日志来源时有效)
|
||
latest_relative = None
|
||
if source == "log":
|
||
try:
|
||
info = stats.get("log_path", "")
|
||
idx = info.rfind("最新:")
|
||
if idx != -1:
|
||
latest_path = info[idx + 3:].strip()
|
||
if latest_path and os.path.exists(latest_path):
|
||
import time
|
||
diff = max(0.0, time.time() - os.path.getmtime(latest_path))
|
||
m = int(diff // 60)
|
||
if m < 1:
|
||
latest_relative = "刚刚"
|
||
elif m < 60:
|
||
latest_relative = "一会前"
|
||
else:
|
||
latest_relative = "较早前"
|
||
except Exception:
|
||
pass
|
||
|
||
data = {"total": total, "stats": stats, "latest_relative": latest_relative, "source": source}
|
||
return data
|