[MODIFY] Code optimization
This commit is contained in:
@@ -7,9 +7,12 @@ Handles business logic for memory storage operations.
|
||||
from typing import Dict, List, Optional, Any
|
||||
import os
|
||||
import json
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from app.models.user_model import User
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.core.logging_config import get_logger
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigFilter,
|
||||
@@ -23,11 +26,10 @@ from app.schemas.memory_storage_schema import (
|
||||
)
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
# TODO 后续更新
|
||||
# from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
# from app.core.memory.analytics.memory_insight import MemoryInsight
|
||||
# from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats
|
||||
# from app.core.memory.analytics.user_summary import generate_user_summary
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.analytics.memory_insight import MemoryInsight
|
||||
from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats
|
||||
from app.core.memory.analytics.user_summary import generate_user_summary
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -52,7 +54,7 @@ class MemoryStorageService:
|
||||
Returns:
|
||||
Storage information dictionary
|
||||
"""
|
||||
logger.info(f"Getting storage info ")
|
||||
logger.info("Getting storage info ")
|
||||
|
||||
# Empty wrapper - implement your logic here
|
||||
result = {
|
||||
@@ -65,30 +67,28 @@ class MemoryStorageService:
|
||||
class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
"""Service layer for config params CRUD.
|
||||
|
||||
The DB connection is optional; when absent, methods return a failure
|
||||
response containing an SQL preview to aid integration.
|
||||
使用 SQLAlchemy ORM 进行数据库操作。
|
||||
"""
|
||||
|
||||
def __init__(self, db: Session) -> None:
    """Bind the service to a SQLAlchemy session.

    Args:
        db: SQLAlchemy database session; stored as ``self.db`` and passed
            to ``DataConfigRepository`` by the CRUD methods of this class.
    """
    self.db = db

# --- Driver compatibility helpers ---
@staticmethod
def _is_pgsql_conn(conn: Any) -> bool:
    """Report whether ``conn`` was created by a psycopg driver module.

    The check looks only at the module name of ``type(conn)``.

    NOTE(review): no caller of this helper is visible in this excerpt;
    confirm whether it is still needed now that the service works on a
    SQLAlchemy session.

    Args:
        conn: Any connection-like object.

    Returns:
        True when the defining module's name contains ``"psycopg"``.
    """
    # "psycopg" is a substring of "psycopg2", so one test covers both drivers.
    return "psycopg" in type(conn).__module__
|
||||
|
||||
@staticmethod
|
||||
def _convert_timestamps_to_format(data_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""将 created_at 和 updated_at 字段从 datetime 对象转换为 YYYYMMDDHHmmss 格式"""
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
for item in data_list:
|
||||
for field in ['created_at', 'updated_at']:
|
||||
if field in item and item[field] is not None:
|
||||
value = item[field]
|
||||
dt = None
|
||||
|
||||
|
||||
# 如果是 datetime 对象,直接使用
|
||||
if isinstance(value, datetime):
|
||||
dt = value
|
||||
@@ -98,24 +98,21 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
dt = datetime.fromisoformat(value.replace('Z', '+00:00'))
|
||||
except Exception:
|
||||
pass # 保持原值
|
||||
|
||||
|
||||
# 转换为 YYYYMMDDHHmmss 格式
|
||||
if dt:
|
||||
item[field] = dt.strftime('%Y%m%d%H%M%S')
|
||||
|
||||
|
||||
return data_list
|
||||
|
||||
# --- Create ---
|
||||
def create(self, params: ConfigParamsCreate) -> Dict[str, Any]: # 创建配置参数(仅名称与描述)
|
||||
if self.db_conn is None:
|
||||
raise ConnectionError("数据库连接未配置")
|
||||
|
||||
# 如果workspace_id存在且模型字段未全部指定,则自动获取
|
||||
if params.workspace_id and not all([params.llm_id, params.embedding_id, params.rerank_id]):
|
||||
configs = self._get_workspace_configs(params.workspace_id)
|
||||
if configs is None:
|
||||
raise ValueError(f"工作空间不存在: workspace_id={params.workspace_id}")
|
||||
|
||||
|
||||
# 只在未指定时填充(允许手动覆盖)
|
||||
if not params.llm_id:
|
||||
params.llm_id = configs.get('llm')
|
||||
@@ -123,19 +120,16 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
params.embedding_id = configs.get('embedding')
|
||||
if not params.rerank_id:
|
||||
params.rerank_id = configs.get('rerank')
|
||||
|
||||
query, qparams = DataConfigRepository.build_insert(params)
|
||||
cur = self.db_conn.cursor()
|
||||
# PostgreSQL 使用 psycopg2 的命名参数格式
|
||||
cur.execute(query, qparams)
|
||||
self.db_conn.commit()
|
||||
return {"affected": getattr(cur, "rowcount", None)}
|
||||
|
||||
|
||||
config = DataConfigRepository.create(self.db, params)
|
||||
self.db.commit()
|
||||
return {"affected": 1, "config_id": config.config_id}
|
||||
|
||||
def _get_workspace_configs(self, workspace_id) -> Optional[Dict[str, Any]]:
|
||||
"""获取工作空间模型配置(内部方法,便于测试)"""
|
||||
from app.db import SessionLocal
|
||||
from app.repositories.workspace_repository import get_workspace_models_configs
|
||||
|
||||
|
||||
db_session = SessionLocal()
|
||||
try:
|
||||
return get_workspace_models_configs(db_session, workspace_id)
|
||||
@@ -143,121 +137,91 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
db_session.close()
|
||||
|
||||
# --- Delete ---
|
||||
def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]:
    """Delete one configuration, identified by ``key.config_id``.

    Delegates to ``DataConfigRepository.delete`` using the session held in
    ``self.db``.

    NOTE(review): no commit is issued here (``create`` commits explicitly);
    presumably the repository commits on its own -- confirm.

    Args:
        key: Delete request carrying the ``config_id`` to remove.

    Returns:
        ``{"affected": 1}`` when the repository reports success.

    Raises:
        ValueError: If the repository returns a falsy result (nothing deleted).
    """
    success = DataConfigRepository.delete(self.db, key.config_id)
    if not success:
        raise ValueError("未找到配置")
    return {"affected": 1}
|
||||
|
||||
# --- Update ---
|
||||
def update(self, update: ConfigUpdate) -> Dict[str, Any]:
    """Partially update a configuration.

    Delegates to ``DataConfigRepository.update`` using the session held in
    ``self.db``.

    Args:
        update: Update request forwarded unchanged to the repository.

    Returns:
        ``{"affected": 1}`` when the repository returns a configuration.

    Raises:
        ValueError: If the repository returns a falsy result (no matching
            configuration).
    """
    config = DataConfigRepository.update(self.db, update)
    if not config:
        raise ValueError("未找到配置")
    return {"affected": 1}
|
||||
|
||||
def update_extracted(self, update: ConfigUpdateExtracted) -> Dict[str, Any]:
    """Update the memory-extraction engine settings of a configuration.

    Delegates to ``DataConfigRepository.update_extracted`` using the session
    held in ``self.db``.

    Args:
        update: Update request forwarded unchanged to the repository.

    Returns:
        ``{"affected": 1}`` when the repository returns a configuration.

    Raises:
        ValueError: If the repository returns a falsy result (no matching
            configuration).
    """
    config = DataConfigRepository.update_extracted(self.db, update)
    if not config:
        raise ValueError("未找到配置")
    return {"affected": 1}
|
||||
|
||||
|
||||
# --- Forget config params ---
|
||||
def update_forget(self, update: ConfigUpdateForget) -> Dict[str, Any]:
    """Save the forgetting-engine settings of a configuration.

    Delegates to ``DataConfigRepository.update_forget`` using the session
    held in ``self.db``.

    Args:
        update: Update request forwarded unchanged to the repository.

    Returns:
        ``{"affected": 1}`` when the repository returns a configuration.

    Raises:
        ValueError: If the repository returns a falsy result (no matching
            configuration).
    """
    config = DataConfigRepository.update_forget(self.db, update)
    if not config:
        raise ValueError("未找到配置")
    return {"affected": 1}
|
||||
|
||||
# --- Read ---
|
||||
def get_extracted(self, key: ConfigKey) -> Dict[str, Any]:
    """Fetch the extraction settings of one configuration.

    Delegates to ``DataConfigRepository.get_extracted_config`` using the
    session held in ``self.db`` and returns its result unchanged.

    NOTE(review): unlike ``get_all``, the result is not passed through
    ``_convert_timestamps_to_format``; confirm the repository already
    returns timestamps in the format callers expect.

    Args:
        key: Lookup key carrying the ``config_id``.

    Returns:
        Whatever the repository returns for that ``config_id``.

    Raises:
        ValueError: If the repository returns a falsy result.
    """
    result = DataConfigRepository.get_extracted_config(self.db, key.config_id)
    if not result:
        raise ValueError("未找到配置")
    return result
|
||||
|
||||
def get_forget(self, key: ConfigKey) -> Dict[str, Any]:
    """Fetch the forgetting-engine settings of one configuration.

    Delegates to ``DataConfigRepository.get_forget_config`` using the
    session held in ``self.db`` and returns its result unchanged.

    NOTE(review): unlike ``get_all``, the result is not passed through
    ``_convert_timestamps_to_format``; confirm the repository already
    returns timestamps in the format callers expect.

    Args:
        key: Lookup key carrying the ``config_id``.

    Returns:
        Whatever the repository returns for that ``config_id``.

    Raises:
        ValueError: If the repository returns a falsy result.
    """
    result = DataConfigRepository.get_forget_config(self.db, key.config_id)
    if not result:
        raise ValueError("未找到配置")
    return result
|
||||
|
||||
# --- Read All ---
|
||||
def get_all(self, workspace_id=None) -> List[Dict[str, Any]]:
    """List configurations as plain dictionaries.

    Fetches ORM objects through ``DataConfigRepository.get_all`` using the
    session held in ``self.db``, copies the exported attributes of each one
    into a dict, and reformats ``created_at`` / ``updated_at`` with
    ``_convert_timestamps_to_format``.

    Args:
        workspace_id: Forwarded unchanged to the repository; defaults to
            ``None``.

    Returns:
        One dict per configuration, with keys in the order listed below.
        ``workspace_id`` is converted with ``str()`` when truthy, otherwise
        it is ``None``. An empty repository result yields an empty list.
    """
    # Attributes copied verbatim from each ORM object; the order here is the
    # key order of the resulting dictionaries.
    exported_fields = (
        "config_id",
        "config_name",
        "config_desc",
        "workspace_id",
        "group_id",
        "user_id",
        "apply_id",
        "llm_id",
        "embedding_id",
        "rerank_id",
        "llm",
        "enable_llm_dedup_blockwise",
        "enable_llm_disambiguation",
        "deep_retrieval",
        "t_type_strict",
        "t_name_strict",
        "t_overall",
        "state",
        "chunker_strategy",
        "pruning_enabled",
        "pruning_scene",
        "pruning_threshold",
        "enable_self_reflexion",
        "iteration_period",
        "reflexion_range",
        "baseline",
        "statement_granularity",
        "include_dialogue_context",
        "max_context",
        "lambda_time",
        "lambda_mem",
        "offset",
        "created_at",
        "updated_at",
    )

    configs = DataConfigRepository.get_all(self.db, workspace_id)

    data_list = []
    for config in configs:
        config_dict = {name: getattr(config, name) for name in exported_fields}
        # Overwrite in place (key position is kept): stringify the workspace
        # id when present, otherwise normalise it to None.
        config_dict["workspace_id"] = (
            str(config.workspace_id) if config.workspace_id else None
        )
        data_list.append(config_dict)

    # Convert created_at / updated_at to the YYYYMMDDHHmmss format.
    return self._convert_timestamps_to_format(data_list)
|
||||
|
||||
|
||||
@@ -296,7 +260,7 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
# 应用内存覆写并刷新常量(在导入主管线前)
|
||||
# 注意:仅在内存中覆写配置,不修改 runtime.json 文件
|
||||
from app.core.memory.utils.config.definitions import reload_configuration_from_database
|
||||
|
||||
|
||||
ok_override = reload_configuration_from_database(cid)
|
||||
if not ok_override:
|
||||
raise RuntimeError("运行时覆写失败,config_id 无效或刷新常量失败")
|
||||
@@ -308,7 +272,7 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
logger.info(f"[PILOT_RUN] Calling pipeline_main with dialogue_text length: {len(dialogue_text)}, is_pilot_run=True")
|
||||
await pipeline_main(dialogue_text=dialogue_text, is_pilot_run=True)
|
||||
logger.info("[PILOT_RUN] pipeline_main completed")
|
||||
|
||||
|
||||
# 调用自我反思
|
||||
# data = [
|
||||
# {
|
||||
@@ -346,10 +310,10 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
result_path = settings.get_memory_output_path("extracted_result.json")
|
||||
if not os.path.isfile(result_path):
|
||||
raise FileNotFoundError(f"试运行完成,但未找到提取结果文件: {result_path}")
|
||||
|
||||
|
||||
with open(result_path, "r", encoding="utf-8") as rf:
|
||||
extracted_result = json.load(rf)
|
||||
|
||||
|
||||
extracted_result["self_reflexion"] = reflexion_result if reflexion_result else None
|
||||
return {
|
||||
"config_id": cid,
|
||||
@@ -405,7 +369,7 @@ async def search_all(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
DataConfigRepository.SEARCH_FOR_ALL,
|
||||
group_id=end_user_id,
|
||||
)
|
||||
|
||||
|
||||
# 检查结果是否为空或长度不足
|
||||
if not result or len(result) < 4:
|
||||
data = {
|
||||
@@ -418,7 +382,7 @@ async def search_all(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
},
|
||||
}
|
||||
return data
|
||||
|
||||
|
||||
data = {
|
||||
"total": result[-1]["Count"],
|
||||
"counts": {
|
||||
@@ -504,14 +468,27 @@ async def search_entity_graph(end_user_id: Optional[str] = None) -> Dict[str, An
|
||||
return data
|
||||
|
||||
|
||||
async def analytics_hot_memory_tags(end_user_id: Optional[str] = None, limit: int = 10) -> List[Dict[str, Any]]:
|
||||
async def analytics_hot_memory_tags(
|
||||
db: Session,
|
||||
current_user: User,
|
||||
limit: int = 10
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取热门记忆标签,按数量排序并返回前N个
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
# 获取更多标签供LLM筛选(获取limit*4个标签)
|
||||
raw_limit = limit * 4
|
||||
tags = await get_hot_memory_tags(end_user_id, limit=raw_limit)
|
||||
from app.services.memory_dashboard_service import get_workspace_end_users
|
||||
end_users = get_workspace_end_users(db, workspace_id, current_user)
|
||||
|
||||
tags = []
|
||||
for end_user in end_users:
|
||||
tag = await get_hot_memory_tags(str(end_user.id), limit=raw_limit)
|
||||
if tag:
|
||||
# 将每个用户的标签列表展平到总列表中
|
||||
tags.extend(tag)
|
||||
|
||||
# 按频率降序排序(虽然数据库已经排序,但为了确保正确性再次排序)
|
||||
sorted_tags = sorted(tags, key=lambda x: x[1], reverse=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user