546 lines
20 KiB
Python
546 lines
20 KiB
Python
"""
|
||
Memory Storage Service
|
||
|
||
Handles business logic for memory storage operations.
|
||
"""
|
||
|
||
from typing import Dict, List, Optional, Any
|
||
import os
|
||
import json
|
||
from sqlalchemy.orm import Session
|
||
|
||
from dotenv import load_dotenv
|
||
|
||
from app.models.user_model import User
|
||
from app.models.end_user_model import EndUser
|
||
from app.core.logging_config import get_logger
|
||
from app.schemas.memory_storage_schema import (
|
||
ConfigFilter,
|
||
ConfigPilotRun,
|
||
ConfigParamsCreate,
|
||
ConfigParamsDelete,
|
||
ConfigUpdate,
|
||
ConfigUpdateExtracted,
|
||
ConfigUpdateForget,
|
||
ConfigKey,
|
||
)
|
||
from app.repositories.data_config_repository import DataConfigRepository
|
||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||
from app.core.memory.analytics.memory_insight import MemoryInsight
|
||
from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats
|
||
from app.core.memory.analytics.user_summary import generate_user_summary
|
||
from app.repositories.data_config_repository import DataConfigRepository
|
||
|
||
# Module-level logger named after this module.
logger = get_logger(__name__)

# Load variables from a .env file before the Neo4j connector is constructed.
load_dotenv()
# Module-level Neo4j connector.
# NOTE(review): the same name is assigned again further down this file,
# so this first instance is replaced at import time — confirm it is needed.
_neo4j_connector = Neo4jConnector()
|
||
|
||
|
||
class MemoryStorageService:
    """Business-logic entry point for memory storage operations."""

    def __init__(self):
        """Log that the service was constructed; no instance state is kept."""
        logger.info("MemoryStorageService initialized")

    async def get_storage_info(self) -> dict:
        """Return placeholder storage information.

        Returns:
            A dict with fixed ``status`` and ``message`` entries. The real
            storage lookup has not been implemented yet.
        """
        logger.info("Getting storage info ")

        # Placeholder payload until real storage logic is added.
        return {
            "status": "active",
            "message": "This is an example wrapper",
        }
|
||
|
||
class DataConfigService:
    """Service layer for config-parameter CRUD (PostgreSQL).

    Database work is delegated to ``DataConfigRepository`` using the
    SQLAlchemy session supplied at construction time.
    """

    # ORM attributes copied into each row returned by ``get_all``, in output order.
    _CONFIG_FIELDS = (
        "config_id",
        "config_name",
        "config_desc",
        "workspace_id",
        "group_id",
        "user_id",
        "apply_id",
        "llm_id",
        "embedding_id",
        "rerank_id",
        "llm",
        "enable_llm_dedup_blockwise",
        "enable_llm_disambiguation",
        "deep_retrieval",
        "t_type_strict",
        "t_name_strict",
        "t_overall",
        "state",
        "chunker_strategy",
        "pruning_enabled",
        "pruning_scene",
        "pruning_threshold",
        "enable_self_reflexion",
        "iteration_period",
        "reflexion_range",
        "baseline",
        "statement_granularity",
        "include_dialogue_context",
        "max_context",
        "lambda_time",
        "lambda_mem",
        "offset",
        "created_at",
        "updated_at",
    )

    def __init__(self, db: Session) -> None:
        """Store the database session used by every operation.

        Args:
            db: SQLAlchemy database session.
        """
        self.db = db

    @staticmethod
    def _convert_timestamps_to_format(data_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Rewrite ``created_at`` / ``updated_at`` values as ``YYYYMMDDHHmmss`` strings.

        ``datetime`` values are formatted directly. String values are parsed
        with ``datetime.fromisoformat`` (a ``Z`` is replaced by ``+00:00``
        first). Values that are missing, ``None``, of another type, or
        unparseable strings are left untouched.

        Args:
            data_list: Row dicts; they are modified in place.

        Returns:
            The same ``data_list`` object that was passed in.
        """
        from datetime import datetime

        for item in data_list:
            for field in ("created_at", "updated_at"):
                value = item.get(field)
                if value is None:
                    continue

                if isinstance(value, datetime):
                    moment = value
                elif isinstance(value, str):
                    try:
                        moment = datetime.fromisoformat(value.replace("Z", "+00:00"))
                    except ValueError:
                        # Not an ISO timestamp: keep the original value.
                        continue
                else:
                    continue

                item[field] = moment.strftime("%Y%m%d%H%M%S")

        return data_list

    # --- Create ---
    def create(self, params: ConfigParamsCreate) -> Dict[str, Any]:
        """Create a config record and commit it.

        When ``params.workspace_id`` is set and any of ``llm_id``,
        ``embedding_id`` or ``rerank_id`` is empty, the empty ones are filled
        from the workspace model configs; explicitly supplied ids are kept.
        ``params`` is modified in place.

        Args:
            params: Creation parameters.

        Returns:
            ``{"affected": 1, "config_id": <new id>}``.

        Raises:
            ValueError: The workspace config lookup returned ``None``.
        """
        model_ids = [params.llm_id, params.embedding_id, params.rerank_id]
        if params.workspace_id and not all(model_ids):
            configs = self._get_workspace_configs(params.workspace_id)
            if configs is None:
                raise ValueError(f"工作空间不存在: workspace_id={params.workspace_id}")

            # Only fill what the caller left empty, so manual overrides win.
            for attr, key in (
                ("llm_id", "llm"),
                ("embedding_id", "embedding"),
                ("rerank_id", "rerank"),
            ):
                if not getattr(params, attr):
                    setattr(params, attr, configs.get(key))

        config = DataConfigRepository.create(self.db, params)
        self.db.commit()
        return {"affected": 1, "config_id": config.config_id}

    def _get_workspace_configs(self, workspace_id) -> Optional[Dict[str, Any]]:
        """Fetch the workspace model configs (kept separate to ease testing).

        Opens its own ``SessionLocal`` session and always closes it.
        """
        from app.db import SessionLocal
        from app.repositories.workspace_repository import get_workspace_models_configs

        db_session = SessionLocal()
        try:
            return get_workspace_models_configs(db_session, workspace_id)
        finally:
            db_session.close()

    # --- Delete ---
    def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]:
        """Delete the config identified by ``key.config_id``.

        NOTE(review): unlike ``create`` this does not call ``self.db.commit()``;
        presumably the repository commits — confirm.

        Raises:
            ValueError: The repository reported that nothing was deleted.
        """
        if not DataConfigRepository.delete(self.db, key.config_id):
            raise ValueError("未找到配置")
        return {"affected": 1}

    # --- Update ---
    def update(self, update: ConfigUpdate) -> Dict[str, Any]:
        """Partially update a config.

        Raises:
            ValueError: The repository returned a falsy result.
        """
        if not DataConfigRepository.update(self.db, update):
            raise ValueError("未找到配置")
        return {"affected": 1}

    def update_extracted(self, update: ConfigUpdateExtracted) -> Dict[str, Any]:
        """Update the memory-extraction engine settings of a config.

        Raises:
            ValueError: The repository returned a falsy result.
        """
        if not DataConfigRepository.update_extracted(self.db, update):
            raise ValueError("未找到配置")
        return {"affected": 1}

    # --- Forget config params ---
    def update_forget(self, update: ConfigUpdateForget) -> Dict[str, Any]:
        """Save the forgetting-engine settings of a config.

        Raises:
            ValueError: The repository returned a falsy result.
        """
        if not DataConfigRepository.update_forget(self.db, update):
            raise ValueError("未找到配置")
        return {"affected": 1}

    # --- Read ---
    def get_extracted(self, key: ConfigKey) -> Dict[str, Any]:
        """Return the extraction settings for ``key.config_id``.

        Raises:
            ValueError: The repository returned a falsy result.
        """
        result = DataConfigRepository.get_extracted_config(self.db, key.config_id)
        if not result:
            raise ValueError("未找到配置")
        return result

    def get_forget(self, key: ConfigKey) -> Dict[str, Any]:
        """Return the forgetting settings for ``key.config_id``.

        Raises:
            ValueError: The repository returned a falsy result.
        """
        result = DataConfigRepository.get_forget_config(self.db, key.config_id)
        if not result:
            raise ValueError("未找到配置")
        return result

    # --- Read All ---
    def get_all(self, workspace_id=None) -> List[Dict[str, Any]]:
        """Return every config as a plain dict.

        Args:
            workspace_id: Forwarded to ``DataConfigRepository.get_all``.

        Returns:
            One dict per ORM object with the keys in ``_CONFIG_FIELDS``.
            ``workspace_id`` is stringified (``None`` when falsy) and
            ``created_at`` / ``updated_at`` are formatted as ``YYYYMMDDHHmmss``.
        """
        rows = []
        for config in DataConfigRepository.get_all(self.db, workspace_id):
            row = {name: getattr(config, name) for name in self._CONFIG_FIELDS}
            # Re-assigning an existing key keeps its position in the dict.
            row["workspace_id"] = str(config.workspace_id) if config.workspace_id else None
            rows.append(row)

        return self._convert_timestamps_to_format(rows)

    async def pilot_run(self, payload: ConfigPilotRun) -> Dict[str, Any]:
        """Run the memory pipeline once in pilot mode and return its output.

        The config id is taken from ``payload.config_id`` first, then from
        ``selections.config_id`` in ``dbrun.json``. The configuration is
        overridden in memory for that id, the main pipeline is awaited with
        ``payload.dialogue_text``, self-reflexion is run on example data, and
        the extracted-result file is read back.

        Args:
            payload: Pilot-run request; ``dialogue_text`` is required.

        Returns:
            A dict with ``config_id``, ``time_log`` (a path under the project
            root) and ``extracted_result`` (the loaded JSON with an added
            ``self_reflexion`` entry).

        Raises:
            ValueError: No config id could be resolved, or ``dialogue_text``
                is empty.
            RuntimeError: The in-memory configuration override failed.
            FileNotFoundError: The extracted-result file does not exist after
                the pipeline has run.
        """
        project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
        dbrun_path = os.path.join(project_root, "app", "core", "memory", "dbrun.json")

        cid: Optional[str] = str(getattr(payload, "config_id", "") or "").strip() or None

        if not cid and os.path.isfile(dbrun_path):
            try:
                with open(dbrun_path, "r", encoding="utf-8") as f:
                    dbrun = json.load(f)
            except (OSError, ValueError) as exc:
                # Unreadable or malformed fallback file: report it and fall
                # through to the "no config id" error below.
                logger.warning(f"[PILOT_RUN] Failed to read dbrun.json at {dbrun_path}: {exc}")
            else:
                sel = dbrun.get("selections", {}) if isinstance(dbrun, dict) else None
                if isinstance(sel, dict):
                    cid = str(sel.get("config_id") or "").strip() or None

        if not cid:
            raise ValueError("未提供 payload.config_id,且 dbrun.json 未设置 selections.config_id,禁止启动试运行")

        # dialogue_text is mandatory in pilot-run mode.
        dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
        logger.info(f"[PILOT_RUN] Received dialogue_text length: {len(dialogue_text)}, preview: {dialogue_text[:100]}")
        if not dialogue_text:
            raise ValueError("试运行模式必须提供 dialogue_text 参数")

        # Apply the in-memory override and refresh constants before the main
        # pipeline is imported. Per the original author's note, this only
        # overrides configuration in memory and does not modify runtime.json.
        from app.core.memory.utils.config.definitions import reload_configuration_from_database

        if not reload_configuration_from_database(cid):
            raise RuntimeError("运行时覆写失败,config_id 无效或刷新常量失败")

        # Import and await the main pipeline on the current event loop.
        from app.core.memory.main import main as pipeline_main
        from app.core.memory.utils.self_reflexion_utils import reflexion

        logger.info(f"[PILOT_RUN] Calling pipeline_main with dialogue_text length: {len(dialogue_text)}, is_pilot_run=True")
        await pipeline_main(dialogue_text=dialogue_text, is_pilot_run=True)
        logger.info("[PILOT_RUN] pipeline_main completed")

        # Run self-reflexion on the bundled example data.
        from app.core.memory.utils.config.get_example_data import get_example_data

        reflexion_result = await reflexion(get_example_data())

        # Read the pipeline output from the globally configured output path.
        from app.core.config import settings

        result_path = settings.get_memory_output_path("extracted_result.json")
        if not os.path.isfile(result_path):
            raise FileNotFoundError(f"试运行完成,但未找到提取结果文件: {result_path}")

        with open(result_path, "r", encoding="utf-8") as rf:
            extracted_result = json.load(rf)

        extracted_result["self_reflexion"] = reflexion_result or None
        return {
            "config_id": cid,
            "time_log": os.path.join(project_root, "time.log"),
            "extracted_result": extracted_result,
        }
|
||
|
||
|
||
# -------------------- Neo4j Search & Analytics (fused from data_search_service.py) --------------------
# Ensure environment variables needed by the connector (e.g., NEO4J_PASSWORD) are loaded.
# NOTE(review): load_dotenv() and Neo4jConnector() were already called near the
# top of this module; this rebinds _neo4j_connector to a second instance.
# Confirm whether the duplicate construction is intentional.
load_dotenv()
_neo4j_connector = Neo4jConnector()
|
||
|
||
|
||
async def search_dialogue(end_user_id: Optional[str] = None) -> Dict[str, Any]:
    """Run the ``SEARCH_FOR_DIALOGUE`` query and report its ``num`` value.

    Args:
        end_user_id: Passed to the query as the ``group_id`` parameter.

    Returns:
        ``{"search_for": "dialogue", "num": <value>}``; ``num`` is 0 when the
        query returns no rows.
    """
    rows = await _neo4j_connector.execute_query(
        DataConfigRepository.SEARCH_FOR_DIALOGUE,
        group_id=end_user_id,
    )
    # An empty result would make rows[0] raise IndexError; report zero
    # instead, matching the empty-result handling in search_all.
    count = rows[0]["num"] if rows else 0
    return {"search_for": "dialogue", "num": count}
|
||
|
||
|
||
async def search_chunk(end_user_id: Optional[str] = None) -> Dict[str, Any]:
    """Run the ``SEARCH_FOR_CHUNK`` query and report its ``num`` value.

    Args:
        end_user_id: Passed to the query as the ``group_id`` parameter.

    Returns:
        ``{"search_for": "chunk", "num": <value>}``; ``num`` is 0 when the
        query returns no rows.
    """
    rows = await _neo4j_connector.execute_query(
        DataConfigRepository.SEARCH_FOR_CHUNK,
        group_id=end_user_id,
    )
    # An empty result would make rows[0] raise IndexError; report zero
    # instead, matching the empty-result handling in search_all.
    count = rows[0]["num"] if rows else 0
    return {"search_for": "chunk", "num": count}
|
||
|
||
|
||
async def search_statement(end_user_id: Optional[str] = None) -> Dict[str, Any]:
    """Run the ``SEARCH_FOR_STATEMENT`` query and report its ``num`` value.

    Args:
        end_user_id: Passed to the query as the ``group_id`` parameter.

    Returns:
        ``{"search_for": "statement", "num": <value>}``; ``num`` is 0 when the
        query returns no rows.
    """
    rows = await _neo4j_connector.execute_query(
        DataConfigRepository.SEARCH_FOR_STATEMENT,
        group_id=end_user_id,
    )
    # An empty result would make rows[0] raise IndexError; report zero
    # instead, matching the empty-result handling in search_all.
    count = rows[0]["num"] if rows else 0
    return {"search_for": "statement", "num": count}
|
||
|
||
|
||
async def search_entity(end_user_id: Optional[str] = None) -> Dict[str, Any]:
    """Run the ``SEARCH_FOR_ENTITY`` query and report its ``num`` value.

    Args:
        end_user_id: Passed to the query as the ``group_id`` parameter.

    Returns:
        ``{"search_for": "entity", "num": <value>}``; ``num`` is 0 when the
        query returns no rows.
    """
    rows = await _neo4j_connector.execute_query(
        DataConfigRepository.SEARCH_FOR_ENTITY,
        group_id=end_user_id,
    )
    # An empty result would make rows[0] raise IndexError; report zero
    # instead, matching the empty-result handling in search_all.
    count = rows[0]["num"] if rows else 0
    return {"search_for": "entity", "num": count}
|
||
|
||
|
||
async def search_all(end_user_id: Optional[str] = None) -> Dict[str, Any]:
    """Run the ``SEARCH_FOR_ALL`` query and return per-type counts plus a total.

    The first four rows are read as the dialogue, chunk, statement and entity
    counts (in that order) and the last row as the total, each via its
    ``Count`` field. When the result is empty or has fewer than four rows,
    every count is reported as 0.

    Args:
        end_user_id: Passed to the query as the ``group_id`` parameter.

    Returns:
        ``{"total": ..., "counts": {"dialogue": ..., "chunk": ...,
        "statement": ..., "entity": ...}}``.
    """
    rows = await _neo4j_connector.execute_query(
        DataConfigRepository.SEARCH_FOR_ALL,
        group_id=end_user_id,
    )
    kinds = ("dialogue", "chunk", "statement", "entity")

    if rows and len(rows) >= 4:
        return {
            "total": rows[-1]["Count"],
            "counts": {kind: rows[pos]["Count"] for pos, kind in enumerate(kinds)},
        }

    # Empty or too-short result: report zeros for everything.
    return {"total": 0, "counts": {kind: 0 for kind in kinds}}
|
||
|
||
|
||
async def kb_type_distribution(end_user_id: Optional[str] = None) -> Dict[str, Any]:
    """Return the knowledge-base type distribution in one structure.

    Runs the ``SEARCH_FOR_ALL`` query and combines the dialogue, chunk,
    statement and entity counts into a single list so a client can consume
    them in one call. The first four rows supply the per-type counts (in that
    order) and the last row the total, each via its ``Count`` field. When the
    result is empty or has fewer than four rows, every count is 0.

    Args:
        end_user_id: Passed to the query as the ``group_id`` parameter.

    Returns:
        ``{"total": ..., "distribution": [{"type": ..., "count": ...}, ...]}``.
    """
    rows = await _neo4j_connector.execute_query(
        DataConfigRepository.SEARCH_FOR_ALL,
        group_id=end_user_id,
    )
    kinds = ("dialogue", "chunk", "statement", "entity")

    if rows and len(rows) >= 4:
        total = rows[-1]["Count"]
        distribution = [
            {"type": kind, "count": rows[pos]["Count"]}
            for pos, kind in enumerate(kinds)
        ]
        return {"total": total, "distribution": distribution}

    # Empty or too-short result: report zeros for everything.
    return {
        "total": 0,
        "distribution": [{"type": kind, "count": 0} for kind in kinds],
    }
|
||
|
||
|
||
async def search_detials(end_user_id: Optional[str] = None) -> List[Dict[str, Any]]:
    """Run the ``SEARCH_FOR_DETIALS`` query and return its rows unchanged.

    Args:
        end_user_id: Passed to the query as the ``group_id`` parameter.
    """
    return await _neo4j_connector.execute_query(
        DataConfigRepository.SEARCH_FOR_DETIALS,
        group_id=end_user_id,
    )
|
||
|
||
|
||
async def search_edges(end_user_id: Optional[str] = None) -> List[Dict[str, Any]]:
    """Run the ``SEARCH_FOR_EDGES`` query and return its rows unchanged.

    Args:
        end_user_id: Passed to the query as the ``group_id`` parameter.
    """
    return await _neo4j_connector.execute_query(
        DataConfigRepository.SEARCH_FOR_EDGES,
        group_id=end_user_id,
    )
|
||
|
||
|
||
async def search_entity_graph(end_user_id: Optional[str] = None) -> Dict[str, Any]:
    """Search the relationship network between entities (group scope).

    Runs the ``SEARCH_FOR_ENTITY_GRAPH`` query, then replaces the
    ``fact_summary`` of each row's ``sourceNode`` and ``targetNode`` with a
    list of at most its first four newline-separated lines (an empty list
    when the summary is falsy). Rows are modified in place.

    Args:
        end_user_id: Passed to the query as the ``group_id`` parameter.

    Returns:
        ``{"search_for": "entity_graph", "num": <row count>, "detials": <rows>}``.
    """

    def _leading_lines(text):
        # NOTE(review): the original comment speaks of the first three
        # "source" entries while the slice keeps four lines — presumably the
        # first line is a heading; confirm.
        return text.split("\n")[:4] if text else []

    rows = await _neo4j_connector.execute_query(
        DataConfigRepository.SEARCH_FOR_ENTITY_GRAPH,
        group_id=end_user_id,
    )

    for edge in rows:
        # Read both summaries before overwriting either of them.
        source_text = edge["sourceNode"]["fact_summary"]
        target_text = edge["targetNode"]["fact_summary"]
        edge["sourceNode"]["fact_summary"] = _leading_lines(source_text)
        edge["targetNode"]["fact_summary"] = _leading_lines(target_text)

    # Same return shape as the other search helpers: type, count, details.
    return {
        "search_for": "entity_graph",
        "num": len(rows),
        "detials": rows,
    }
|
||
|
||
|
||
async def analytics_hot_memory_tags(
    db: Session,
    current_user: User,
    limit: int = 10
) -> List[Dict[str, Any]]:
    """Return the most frequent memory tags across the workspace's end users.

    Tags are fetched per end user (``limit * 4`` each), pooled into one list,
    sorted by frequency in descending order and cut to ``limit`` entries.
    NOTE(review): entries with the same tag name from different end users are
    not merged — confirm this is intended.

    Args:
        db: Database session forwarded to ``get_workspace_end_users``.
        current_user: User whose ``current_workspace_id`` selects the workspace.
        limit: Maximum number of tags to return.

    Returns:
        A list of ``{"name": ..., "frequency": ...}`` dicts.
    """
    workspace_id = current_user.current_workspace_id
    # Over-fetch per user so the final cut has more candidates to choose from.
    per_user_limit = limit * 4
    from app.services.memory_dashboard_service import get_workspace_end_users

    pooled = []
    for member in get_workspace_end_users(db, workspace_id, current_user):
        member_tags = await get_hot_memory_tags(str(member.id), limit=per_user_limit)
        if member_tags:
            # Flatten each user's tag list into the shared pool.
            pooled.extend(member_tags)

    # Each source is already ordered, but the pooled list is not: re-sort.
    pooled.sort(key=lambda pair: pair[1], reverse=True)

    return [
        {"name": tag_name, "frequency": count}
        for tag_name, count in pooled[:limit]
    ]
|
||
|
||
|
||
async def analytics_memory_insight_report(end_user_id: Optional[str] = None) -> Dict[str, Any]:
    """Generate a memory insight report for one end user.

    Args:
        end_user_id: Passed to the ``MemoryInsight`` constructor.

    Returns:
        ``{"report": <result of generate_insight_report()>}``.
    """
    insight = MemoryInsight(end_user_id)
    try:
        report = await insight.generate_insight_report()
    finally:
        # Close the insight object even when report generation raises.
        await insight.close()
    return {"report": report}
|
||
|
||
|
||
def _describe_elapsed(seconds: float) -> str:
    """Render an elapsed time in seconds as a short relative label."""
    minutes = int(seconds // 60)
    if minutes < 1:
        return "刚刚"
    if minutes < 60:
        return f"{minutes}分钟前"
    hours = minutes // 60
    if hours < 24:
        return f"{hours}小时前"
    return f"{hours // 24}天前"


async def analytics_recent_activity_stats() -> Dict[str, Any]:
    """Summarise recent activity statistics.

    Sums the five counters from ``get_recent_activity_stats()`` and, as a
    best-effort extra, reports how long ago the most recent log file was
    modified. The log path is taken from the text after the last ``最新:``
    marker in ``stats["log_path"]``.

    Returns:
        ``{"total": ..., "stats": ..., "latest_relative": ...}`` where
        ``latest_relative`` is ``None`` when no usable path was found or the
        lookup failed.
    """
    stats, _msg = get_recent_activity_stats()
    total = sum(
        stats.get(key, 0)
        for key in (
            "chunk_count",
            "statements_count",
            "triplet_entities_count",
            "triplet_relations_count",
            "temporal_count",
        )
    )

    # Only "how long ago was the latest activity" is provided.
    latest_relative = None
    marker = "最新:"
    try:
        info = stats.get("log_path", "")
        idx = info.rfind(marker)
        if idx != -1:
            latest_path = info[idx + len(marker):].strip()
            if latest_path and os.path.exists(latest_path):
                import time

                age = max(0.0, time.time() - os.path.getmtime(latest_path))
                latest_relative = _describe_elapsed(age)
    except Exception as exc:
        # Best-effort: a failure here must not break the stats response,
        # but it is recorded instead of being silently dropped.
        logger.debug(f"Could not determine latest activity time: {exc}")

    return {"total": total, "stats": stats, "latest_relative": latest_relative}
|
||
|
||
|
||
async def analytics_user_summary(end_user_id: Optional[str] = None) -> Dict[str, Any]:
    """Wrap the result of ``generate_user_summary`` in a response dict.

    Args:
        end_user_id: Passed through to ``generate_user_summary``.

    Returns:
        ``{"summary": <generated summary>}``.
    """
    return {"summary": await generate_user_summary(end_user_id)}