[MODIFY] Code optimization

This commit is contained in:
Mark
2025-12-15 14:09:43 +08:00
parent d2a630addb
commit a4e276ab27
157 changed files with 15976 additions and 3601 deletions

View File

@@ -7,9 +7,12 @@ Handles business logic for memory storage operations.
from typing import Dict, List, Optional, Any
import os
import json
from sqlalchemy.orm import Session
from dotenv import load_dotenv
from app.models.user_model import User
from app.models.end_user_model import EndUser
from app.core.logging_config import get_logger
from app.schemas.memory_storage_schema import (
ConfigFilter,
@@ -23,11 +26,10 @@ from app.schemas.memory_storage_schema import (
)
from app.repositories.data_config_repository import DataConfigRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
# TODO 后续更新
# from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
# from app.core.memory.analytics.memory_insight import MemoryInsight
# from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats
# from app.core.memory.analytics.user_summary import generate_user_summary
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.analytics.memory_insight import MemoryInsight
from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats
from app.core.memory.analytics.user_summary import generate_user_summary
from app.repositories.data_config_repository import DataConfigRepository
logger = get_logger(__name__)
@@ -52,7 +54,7 @@ class MemoryStorageService:
Returns:
Storage information dictionary
"""
logger.info(f"Getting storage info ")
logger.info("Getting storage info ")
# Empty wrapper - implement your logic here
result = {
@@ -65,30 +67,28 @@ class MemoryStorageService:
class DataConfigService: # 数据配置服务类PostgreSQL
"""Service layer for config params CRUD.
The DB connection is optional; when absent, methods return a failure
response containing an SQL preview to aid integration.
使用 SQLAlchemy ORM 进行数据库操作。
"""
def __init__(self, db: Session) -> None:
    """Initialize the service with a database session.

    Args:
        db: SQLAlchemy session stored on the instance as ``self.db``.
    """
    self.db = db

# --- Driver compatibility helpers ---
@staticmethod
def _is_pgsql_conn(conn: Any) -> bool:
    """Return True when ``conn``'s type is defined in a psycopg module.

    The check is purely name-based: it looks for ``"psycopg2"`` or
    ``"psycopg"`` in the ``__module__`` of ``type(conn)``.
    """
    mod = type(conn).__module__
    return ("psycopg2" in mod) or ("psycopg" in mod)
@staticmethod
def _convert_timestamps_to_format(data_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""将 created_at 和 updated_at 字段从 datetime 对象转换为 YYYYMMDDHHmmss 格式"""
from datetime import datetime
for item in data_list:
for field in ['created_at', 'updated_at']:
if field in item and item[field] is not None:
value = item[field]
dt = None
# 如果是 datetime 对象,直接使用
if isinstance(value, datetime):
dt = value
@@ -98,24 +98,21 @@ class DataConfigService: # 数据配置服务类PostgreSQL
dt = datetime.fromisoformat(value.replace('Z', '+00:00'))
except Exception:
pass # 保持原值
# 转换为 YYYYMMDDHHmmss 格式
if dt:
item[field] = dt.strftime('%Y%m%d%H%M%S')
return data_list
# --- Create ---
def create(self, params: ConfigParamsCreate) -> Dict[str, Any]: # 创建配置参数(仅名称与描述)
if self.db_conn is None:
raise ConnectionError("数据库连接未配置")
# 如果workspace_id存在且模型字段未全部指定则自动获取
if params.workspace_id and not all([params.llm_id, params.embedding_id, params.rerank_id]):
configs = self._get_workspace_configs(params.workspace_id)
if configs is None:
raise ValueError(f"工作空间不存在: workspace_id={params.workspace_id}")
# 只在未指定时填充(允许手动覆盖)
if not params.llm_id:
params.llm_id = configs.get('llm')
@@ -123,19 +120,16 @@ class DataConfigService: # 数据配置服务类PostgreSQL
params.embedding_id = configs.get('embedding')
if not params.rerank_id:
params.rerank_id = configs.get('rerank')
query, qparams = DataConfigRepository.build_insert(params)
cur = self.db_conn.cursor()
# PostgreSQL 使用 psycopg2 的命名参数格式
cur.execute(query, qparams)
self.db_conn.commit()
return {"affected": getattr(cur, "rowcount", None)}
config = DataConfigRepository.create(self.db, params)
self.db.commit()
return {"affected": 1, "config_id": config.config_id}
def _get_workspace_configs(self, workspace_id) -> Optional[Dict[str, Any]]:
"""获取工作空间模型配置(内部方法,便于测试)"""
from app.db import SessionLocal
from app.repositories.workspace_repository import get_workspace_models_configs
db_session = SessionLocal()
try:
return get_workspace_models_configs(db_session, workspace_id)
@@ -143,121 +137,91 @@ class DataConfigService: # 数据配置服务类PostgreSQL
db_session.close()
# --- Delete ---
def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]:
    """Delete the configuration identified by ``key.config_id``.

    Args:
        key: Delete request; only ``key.config_id`` is read.

    Returns:
        ``{"affected": 1}`` when the repository reports success.

    Raises:
        ValueError: If the repository reports that nothing was deleted.
    """
    # NOTE(review): no explicit commit is issued here; presumably
    # DataConfigRepository.delete commits on its own -- confirm.
    deleted = DataConfigRepository.delete(self.db, key.config_id)
    if not deleted:
        raise ValueError("未找到配置")
    return {"affected": 1}
# --- Update ---
def update(self, update: ConfigUpdate) -> Dict[str, Any]:
    """Apply a partial update to a configuration.

    Args:
        update: Update payload, passed unchanged to the repository.

    Returns:
        ``{"affected": 1}`` when the repository returns a truthy result.

    Raises:
        ValueError: If the repository returns a falsy result (no matching
            configuration).
    """
    config = DataConfigRepository.update(self.db, update)
    if not config:
        raise ValueError("未找到配置")
    return {"affected": 1}
def update_extracted(self, update: ConfigUpdateExtracted) -> Dict[str, Any]:
    """Update the memory-extraction engine settings of a configuration.

    Args:
        update: Update payload, passed unchanged to the repository.

    Returns:
        ``{"affected": 1}`` when the repository returns a truthy result.

    Raises:
        ValueError: If the repository returns a falsy result (no matching
            configuration).
    """
    config = DataConfigRepository.update_extracted(self.db, update)
    if not config:
        raise ValueError("未找到配置")
    return {"affected": 1}
# --- Forget config params ---
def update_forget(self, update: ConfigUpdateForget) -> Dict[str, Any]:
    """Save the forgetting-engine settings of a configuration.

    Args:
        update: Update payload, passed unchanged to the repository.

    Returns:
        ``{"affected": 1}`` when the repository returns a truthy result.

    Raises:
        ValueError: If the repository returns a falsy result (no matching
            configuration).
    """
    config = DataConfigRepository.update_forget(self.db, update)
    if not config:
        raise ValueError("未找到配置")
    return {"affected": 1}
# --- Read ---
def get_extracted(self, key: ConfigKey) -> Dict[str, Any]:
    """Fetch the extraction settings for ``key.config_id``.

    Args:
        key: Lookup key; only ``key.config_id`` is read.

    Returns:
        The value returned by
        ``DataConfigRepository.get_extracted_config``, unmodified.

    Raises:
        ValueError: If the repository returns a falsy result (no matching
            configuration).
    """
    # NOTE(review): the result is returned as-is; any timestamp formatting
    # is presumably done inside the repository -- confirm.
    result = DataConfigRepository.get_extracted_config(self.db, key.config_id)
    if not result:
        raise ValueError("未找到配置")
    return result
def get_forget(self, key: ConfigKey) -> Dict[str, Any]:
    """Fetch the forgetting-engine settings for ``key.config_id``.

    Args:
        key: Lookup key; only ``key.config_id`` is read.

    Returns:
        The value returned by ``DataConfigRepository.get_forget_config``,
        unmodified.

    Raises:
        ValueError: If the repository returns a falsy result (no matching
            configuration).
    """
    # NOTE(review): the result is returned as-is; any timestamp formatting
    # is presumably done inside the repository -- confirm.
    result = DataConfigRepository.get_forget_config(self.db, key.config_id)
    if not result:
        raise ValueError("未找到配置")
    return result
# --- Read All ---
def get_all(self, workspace_id=None) -> List[Dict[str, Any]]:
    """Return every configuration as a list of plain dictionaries.

    Args:
        workspace_id: Optional filter forwarded to
            ``DataConfigRepository.get_all``.

    Returns:
        One dict per ORM object returned by the repository. Each dict
        holds the attributes listed in ``fields`` below, in that order.
        ``workspace_id`` is converted to ``str`` (or ``None`` when falsy),
        and ``created_at`` / ``updated_at`` are passed through
        ``_convert_timestamps_to_format``.
    """
    configs = DataConfigRepository.get_all(self.db, workspace_id)

    # Attributes copied from each ORM object; order defines the key order
    # of the resulting dictionaries.
    fields = (
        "config_id",
        "config_name",
        "config_desc",
        "workspace_id",
        "group_id",
        "user_id",
        "apply_id",
        "llm_id",
        "embedding_id",
        "rerank_id",
        "llm",
        "enable_llm_dedup_blockwise",
        "enable_llm_disambiguation",
        "deep_retrieval",
        "t_type_strict",
        "t_name_strict",
        "t_overall",
        "state",
        "chunker_strategy",
        "pruning_enabled",
        "pruning_scene",
        "pruning_threshold",
        "enable_self_reflexion",
        "iteration_period",
        "reflexion_range",
        "baseline",
        "statement_granularity",
        "include_dialogue_context",
        "max_context",
        "lambda_time",
        "lambda_mem",
        "offset",
        "created_at",
        "updated_at",
    )

    # Convert the ORM objects into a list of dictionaries.
    data_list = []
    for config in configs:
        config_dict = {name: getattr(config, name) for name in fields}
        # Overwriting an existing key keeps its position, so the key order
        # is unchanged while the workspace id becomes a string.
        config_dict["workspace_id"] = (
            str(config.workspace_id) if config.workspace_id else None
        )
        data_list.append(config_dict)

    # Convert created_at and updated_at to the YYYYMMDDHHmmss format.
    return self._convert_timestamps_to_format(data_list)
@@ -296,7 +260,7 @@ class DataConfigService: # 数据配置服务类PostgreSQL
# 应用内存覆写并刷新常量(在导入主管线前)
# 注意:仅在内存中覆写配置,不修改 runtime.json 文件
from app.core.memory.utils.config.definitions import reload_configuration_from_database
ok_override = reload_configuration_from_database(cid)
if not ok_override:
raise RuntimeError("运行时覆写失败config_id 无效或刷新常量失败")
@@ -308,7 +272,7 @@ class DataConfigService: # 数据配置服务类PostgreSQL
logger.info(f"[PILOT_RUN] Calling pipeline_main with dialogue_text length: {len(dialogue_text)}, is_pilot_run=True")
await pipeline_main(dialogue_text=dialogue_text, is_pilot_run=True)
logger.info("[PILOT_RUN] pipeline_main completed")
# 调用自我反思
# data = [
# {
@@ -346,10 +310,10 @@ class DataConfigService: # 数据配置服务类PostgreSQL
result_path = settings.get_memory_output_path("extracted_result.json")
if not os.path.isfile(result_path):
raise FileNotFoundError(f"试运行完成,但未找到提取结果文件: {result_path}")
with open(result_path, "r", encoding="utf-8") as rf:
extracted_result = json.load(rf)
extracted_result["self_reflexion"] = reflexion_result if reflexion_result else None
return {
"config_id": cid,
@@ -405,7 +369,7 @@ async def search_all(end_user_id: Optional[str] = None) -> Dict[str, Any]:
DataConfigRepository.SEARCH_FOR_ALL,
group_id=end_user_id,
)
# 检查结果是否为空或长度不足
if not result or len(result) < 4:
data = {
@@ -418,7 +382,7 @@ async def search_all(end_user_id: Optional[str] = None) -> Dict[str, Any]:
},
}
return data
data = {
"total": result[-1]["Count"],
"counts": {
@@ -504,14 +468,27 @@ async def search_entity_graph(end_user_id: Optional[str] = None) -> Dict[str, An
return data
async def analytics_hot_memory_tags(end_user_id: Optional[str] = None, limit: int = 10) -> List[Dict[str, Any]]:
async def analytics_hot_memory_tags(
db: Session,
current_user: User,
limit: int = 10
) -> List[Dict[str, Any]]:
"""
获取热门记忆标签按数量排序并返回前N个
"""
workspace_id = current_user.current_workspace_id
# 获取更多标签供LLM筛选获取limit*4个标签
raw_limit = limit * 4
tags = await get_hot_memory_tags(end_user_id, limit=raw_limit)
from app.services.memory_dashboard_service import get_workspace_end_users
end_users = get_workspace_end_users(db, workspace_id, current_user)
tags = []
for end_user in end_users:
tag = await get_hot_memory_tags(str(end_user.id), limit=raw_limit)
if tag:
# 将每个用户的标签列表展平到总列表中
tags.extend(tag)
# 按频率降序排序(虽然数据库已经排序,但为了确保正确性再次排序)
sorted_tags = sorted(tags, key=lambda x: x[1], reverse=True)