Feature/return memoryconfig (#89)

* [add]Newly added: Memory configuration for returning results

* [add]Newly added: Memory configuration for returning results

* [changes]Based on the improvement of AI review
This commit is contained in:
乐力齐
2026-01-13 14:55:12 +08:00
committed by GitHub
parent 042a34d22f
commit 0a73b18823
2 changed files with 184 additions and 16 deletions

View File

@@ -1,18 +1,15 @@
from fastapi import APIRouter, Depends, HTTPException, status, Query from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from typing import List, Optional from typing import Optional
import uuid
from app.repositories.end_user_repository import update_end_user_other_name
import uuid
from app.core.response_utils import success from app.core.response_utils import success
from app.db import get_db from app.db import get_db
from app.dependencies import get_current_user from app.dependencies import get_current_user
from app.models.user_model import User from app.models.user_model import User
from app.schemas.memory_agent_schema import End_User_Information from app.schemas.memory_agent_schema import End_User_Information
from app.schemas.response_schema import ApiResponse from app.schemas.response_schema import ApiResponse
from app.schemas.app_schema import App as AppSchema
from app.services import memory_dashboard_service, memory_storage_service, workspace_service from app.services import memory_dashboard_service, memory_storage_service, workspace_service
from app.services.memory_agent_service import get_end_users_connected_configs_batch
from app.core.logging_config import get_api_logger from app.core.logging_config import get_api_logger
# 获取API专用日志器 # 获取API专用日志器
@@ -102,7 +99,8 @@ async def get_workspace_end_users(
""" """
获取工作空间的宿主列表 获取工作空间的宿主列表
返回格式与原 memory_list 接口中的 end_users 字段相同 返回格式与原 memory_list 接口中的 end_users 字段相同
并包含每个用户的记忆配置信息memory_config_id 和 memory_config_name
""" """
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
# 获取当前空间类型 # 获取当前空间类型
@@ -113,6 +111,17 @@ async def get_workspace_end_users(
workspace_id=workspace_id, workspace_id=workspace_id,
current_user=current_user current_user=current_user
) )
# 批量获取所有用户的记忆配置信息(优化:一次查询而非 N 次)
end_user_ids = [str(user.id) for user in end_users]
memory_configs_map = {}
if end_user_ids:
try:
memory_configs_map = get_end_users_connected_configs_batch(end_user_ids, db)
except Exception as e:
api_logger.error(f"批量获取记忆配置失败: {str(e)}")
# 失败时使用空字典,不影响其他数据返回
result = [] result = []
for end_user in end_users: for end_user in end_users:
memory_num = {} memory_num = {}
@@ -123,10 +132,25 @@ async def get_workspace_end_users(
memory_num = { memory_num = {
"total":memory_dashboard_service.get_current_user_total_chunk(str(end_user.id), db, current_user) "total":memory_dashboard_service.get_current_user_total_chunk(str(end_user.id), db, current_user)
} }
# 从批量查询结果中获取配置信息
user_id = str(end_user.id)
memory_config_info = memory_configs_map.get(user_id, {
"memory_config_id": None,
"memory_config_name": None
})
# 只保留需要的字段,移除 error 字段(如果有)
memory_config = {
"memory_config_id": memory_config_info.get("memory_config_id"),
"memory_config_name": memory_config_info.get("memory_config_name")
}
result.append( result.append(
{ {
'end_user':end_user, 'end_user': end_user,
'memory_num':memory_num 'memory_num': memory_num,
'memory_config': memory_config
} }
) )
@@ -465,7 +489,6 @@ async def dashboard_data(
if storage_type is None: if storage_type is None:
storage_type = 'neo4j' storage_type = 'neo4j'
user_rag_memory_id = None
# 根据 storage_type 决定返回哪个数据对象 # 根据 storage_type 决定返回哪个数据对象
# 如果是 'rag'neo4j_data 为 null否则 rag_data 为 null # 如果是 'rag'neo4j_data 为 null否则 rag_data 为 null

View File

@@ -4,7 +4,6 @@ Memory Agent Service
Handles business logic for memory agent operations including read/write services, Handles business logic for memory agent operations including read/write services,
health checks, and message type classification. health checks, and message type classification.
""" """
import datetime
import json import json
import os import os
import re import re
@@ -27,7 +26,7 @@ from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType from app.models.knowledge_model import Knowledge, KnowledgeType
from app.repositories.memory_short_repository import ShortTermMemoryRepository from app.repositories.memory_short_repository import ShortTermMemoryRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import ConfigurationError, MemoryConfig from app.schemas.memory_config_schema import ConfigurationError
from app.services.memory_config_service import MemoryConfigService from app.services.memory_config_service import MemoryConfigService
from app.services.memory_konwledges_server import ( from app.services.memory_konwledges_server import (
write_rag, write_rag,
@@ -610,7 +609,7 @@ class MemoryAgentService:
reranked_results=raw_results.get('reranked_results',[]) reranked_results=raw_results.get('reranked_results',[])
try: try:
statements=[statement['statement'] for statement in reranked_results.get('statements', [])] statements=[statement['statement'] for statement in reranked_results.get('statements', [])]
except Exception as e: except Exception:
statements=[] statements=[]
statements=list(set(statements)) statements=list(set(statements))
retrieved_content.append({query:statements}) retrieved_content.append({query:statements})
@@ -832,7 +831,6 @@ class MemoryAgentService:
# 获取当前空间下的所有宿主 # 获取当前空间下的所有宿主
from app.repositories import app_repository, end_user_repository from app.repositories import app_repository, end_user_repository
from app.schemas.app_schema import App as AppSchema from app.schemas.app_schema import App as AppSchema
from app.schemas.end_user_schema import EndUser as EndUserSchema
# 查询应用并转换为 Pydantic 模型 # 查询应用并转换为 Pydantic 模型
apps_orm = app_repository.get_apps_by_workspace_id(db, current_workspace_id) apps_orm = app_repository.get_apps_by_workspace_id(db, current_workspace_id)
@@ -1175,19 +1173,21 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
1. 根据 end_user_id 获取用户的 app_id 1. 根据 end_user_id 获取用户的 app_id
2. 获取该应用的最新发布版本 2. 获取该应用的最新发布版本
3. 从发布版本的 config 字段中提取 memory_config_id 3. 从发布版本的 config 字段中提取 memory_config_id
4. 根据 memory_config_id 查询配置名称
Args: Args:
end_user_id: 终端用户ID end_user_id: 终端用户ID
db: 数据库会话 db: 数据库会话
Returns: Returns:
包含 memory_config_id 和相关信息的字典 包含 memory_config_id、config_name 和相关信息的字典
Raises: Raises:
ValueError: 当终端用户不存在或应用未发布时 ValueError: 当终端用户不存在或应用未发布时
""" """
from app.models.app_release_model import AppRelease from app.models.app_release_model import AppRelease
from app.models.end_user_model import EndUser from app.models.end_user_model import EndUser
from app.models.data_config_model import DataConfig
from sqlalchemy import select from sqlalchemy import select
logger.info(f"Getting connected config for end_user: {end_user_id}") logger.info(f"Getting connected config for end_user: {end_user_id}")
@@ -1220,13 +1220,158 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
memory_obj = config.get('memory', {}) memory_obj = config.get('memory', {})
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
# 4. 根据 memory_config_id 查询配置名称
config_name = None
if memory_config_id:
try:
# memory_config_id 可能是整数或字符串,需要转换
config_id = int(memory_config_id) if isinstance(memory_config_id, str) else memory_config_id
data_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
if data_config:
config_name = data_config.config_name
logger.debug(f"Found config_name: {config_name} for config_id: {config_id}")
else:
logger.warning(f"DataConfig not found for config_id: {config_id}")
except (ValueError, TypeError) as e:
logger.warning(f"Invalid memory_config_id format: {memory_config_id}, error: {str(e)}")
result = { result = {
"end_user_id": str(end_user_id), "end_user_id": str(end_user_id),
"app_id": str(app_id), "app_id": str(app_id),
"release_id": str(latest_release.id), "release_id": str(latest_release.id),
"release_version": latest_release.version, "release_version": latest_release.version,
"memory_config_id": memory_config_id "memory_config_id": memory_config_id,
"memory_config_name": config_name
} }
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}") logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}, config_name={config_name}")
return result
def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) -> Dict[str, Dict[str, Any]]:
"""
批量获取多个终端用户关联的记忆配置
通过优化的查询减少数据库往返次数:
1. 一次性查询所有 end_user 及其 app_id
2. 批量查询所有相关的 app_release
3. 批量查询所有相关的 data_config
Args:
end_user_ids: 终端用户ID列表
db: 数据库会话
Returns:
字典key 为 end_user_idvalue 为配置信息字典
对于查询失败的用户value 包含 error 字段
"""
from app.models.app_release_model import AppRelease
from app.models.end_user_model import EndUser
from app.models.data_config_model import DataConfig
from sqlalchemy import select
logger.info(f"Batch getting connected configs for {len(end_user_ids)} end users")
result = {}
# 1. 批量查询所有 end_user 及其 app_id
end_users = db.query(EndUser).filter(EndUser.id.in_(end_user_ids)).all()
# 构建 end_user_id -> end_user 的映射
end_user_map = {str(user.id): user for user in end_users}
# 记录不存在的用户
for user_id in end_user_ids:
if user_id not in end_user_map:
result[user_id] = {
"end_user_id": user_id,
"memory_config_id": None,
"memory_config_name": None,
"error": f"终端用户不存在: {user_id}"
}
if not end_users:
logger.warning("No valid end users found")
return result
# 2. 批量查询所有相关应用的最新发布版本
app_ids = [user.app_id for user in end_users]
# 使用子查询找到每个 app 的最新版本
from sqlalchemy import and_
# 查询所有相关的活跃发布版本
releases = db.query(AppRelease).filter(
and_(
AppRelease.app_id.in_(app_ids),
AppRelease.is_active.is_(True)
)
).order_by(AppRelease.app_id, AppRelease.version.desc()).all()
# 构建 app_id -> latest_release 的映射(每个 app 只保留最新版本)
app_release_map = {}
for release in releases:
app_id_str = str(release.app_id)
if app_id_str not in app_release_map:
app_release_map[app_id_str] = release
# 3. 收集所有 memory_config_id
memory_config_ids = []
for release in app_release_map.values():
config = release.config or {}
memory_obj = config.get('memory', {})
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
if memory_config_id:
try:
config_id = int(memory_config_id) if isinstance(memory_config_id, str) else memory_config_id
memory_config_ids.append(config_id)
except (ValueError, TypeError):
pass
# 4. 批量查询所有 data_config
config_name_map = {}
if memory_config_ids:
data_configs = db.query(DataConfig).filter(
DataConfig.config_id.in_(memory_config_ids)
).all()
config_name_map = {config.config_id: config.config_name for config in data_configs}
# 5. 组装结果
for user in end_users:
user_id = str(user.id)
app_id = str(user.app_id)
# 检查是否有发布版本
if app_id not in app_release_map:
result[user_id] = {
"end_user_id": user_id,
"memory_config_id": None,
"memory_config_name": None,
"error": f"应用未发布: {app_id}"
}
continue
release = app_release_map[app_id]
# 提取 memory_config_id
config = release.config or {}
memory_obj = config.get('memory', {})
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
# 获取 config_name
config_name = None
if memory_config_id:
try:
config_id = int(memory_config_id) if isinstance(memory_config_id, str) else memory_config_id
config_name = config_name_map.get(config_id)
except (ValueError, TypeError):
pass
result[user_id] = {
"end_user_id": user_id,
"memory_config_id": memory_config_id,
"memory_config_name": config_name
}
logger.info(f"Successfully retrieved batch configs: total={len(result)}, with_config={sum(1 for v in result.values() if v.get('memory_config_id'))}")
return result return result