Feature/return memoryconfig (#89)

* [add]Newly added: Memory configuration for returning results

* [add]Newly added: Memory configuration for returning results

* [changes]Based on the improvement of AI review
This commit is contained in:
乐力齐
2026-01-13 14:55:12 +08:00
committed by GitHub
parent 042a34d22f
commit 0a73b18823
2 changed files with 184 additions and 16 deletions

View File

@@ -1,18 +1,15 @@
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.orm import Session
from typing import List, Optional
import uuid
from app.repositories.end_user_repository import update_end_user_other_name
import uuid
from typing import Optional
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.schemas.memory_agent_schema import End_User_Information
from app.schemas.response_schema import ApiResponse
from app.schemas.app_schema import App as AppSchema
from app.services import memory_dashboard_service, memory_storage_service, workspace_service
from app.services.memory_agent_service import get_end_users_connected_configs_batch
from app.core.logging_config import get_api_logger
# 获取API专用日志器
@@ -102,7 +99,8 @@ async def get_workspace_end_users(
"""
获取工作空间的宿主列表
返回格式与原 memory_list 接口中的 end_users 字段相同
返回格式与原 memory_list 接口中的 end_users 字段相同
并包含每个用户的记忆配置信息memory_config_id 和 memory_config_name
"""
workspace_id = current_user.current_workspace_id
# 获取当前空间类型
@@ -113,6 +111,17 @@ async def get_workspace_end_users(
workspace_id=workspace_id,
current_user=current_user
)
# 批量获取所有用户的记忆配置信息(优化:一次查询而非 N 次)
end_user_ids = [str(user.id) for user in end_users]
memory_configs_map = {}
if end_user_ids:
try:
memory_configs_map = get_end_users_connected_configs_batch(end_user_ids, db)
except Exception as e:
api_logger.error(f"批量获取记忆配置失败: {str(e)}")
# 失败时使用空字典,不影响其他数据返回
result = []
for end_user in end_users:
memory_num = {}
@@ -123,10 +132,25 @@ async def get_workspace_end_users(
memory_num = {
"total":memory_dashboard_service.get_current_user_total_chunk(str(end_user.id), db, current_user)
}
# 从批量查询结果中获取配置信息
user_id = str(end_user.id)
memory_config_info = memory_configs_map.get(user_id, {
"memory_config_id": None,
"memory_config_name": None
})
# 只保留需要的字段,移除 error 字段(如果有)
memory_config = {
"memory_config_id": memory_config_info.get("memory_config_id"),
"memory_config_name": memory_config_info.get("memory_config_name")
}
result.append(
{
'end_user':end_user,
'memory_num':memory_num
'end_user': end_user,
'memory_num': memory_num,
'memory_config': memory_config
}
)
@@ -465,7 +489,6 @@ async def dashboard_data(
if storage_type is None:
storage_type = 'neo4j'
user_rag_memory_id = None
# 根据 storage_type 决定返回哪个数据对象
# 如果是 'rag'neo4j_data 为 null否则 rag_data 为 null

View File

@@ -4,7 +4,6 @@ Memory Agent Service
Handles business logic for memory agent operations including read/write services,
health checks, and message type classification.
"""
import datetime
import json
import os
import re
@@ -27,7 +26,7 @@ from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType
from app.repositories.memory_short_repository import ShortTermMemoryRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import ConfigurationError, MemoryConfig
from app.schemas.memory_config_schema import ConfigurationError
from app.services.memory_config_service import MemoryConfigService
from app.services.memory_konwledges_server import (
write_rag,
@@ -610,7 +609,7 @@ class MemoryAgentService:
reranked_results=raw_results.get('reranked_results',[])
try:
statements=[statement['statement'] for statement in reranked_results.get('statements', [])]
except Exception as e:
except Exception:
statements=[]
statements=list(set(statements))
retrieved_content.append({query:statements})
@@ -832,7 +831,6 @@ class MemoryAgentService:
# 获取当前空间下的所有宿主
from app.repositories import app_repository, end_user_repository
from app.schemas.app_schema import App as AppSchema
from app.schemas.end_user_schema import EndUser as EndUserSchema
# 查询应用并转换为 Pydantic 模型
apps_orm = app_repository.get_apps_by_workspace_id(db, current_workspace_id)
@@ -1175,19 +1173,21 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
1. 根据 end_user_id 获取用户的 app_id
2. 获取该应用的最新发布版本
3. 从发布版本的 config 字段中提取 memory_config_id
4. 根据 memory_config_id 查询配置名称
Args:
end_user_id: 终端用户ID
db: 数据库会话
Returns:
包含 memory_config_id 和相关信息的字典
包含 memory_config_id、config_name 和相关信息的字典
Raises:
ValueError: 当终端用户不存在或应用未发布时
"""
from app.models.app_release_model import AppRelease
from app.models.end_user_model import EndUser
from app.models.data_config_model import DataConfig
from sqlalchemy import select
logger.info(f"Getting connected config for end_user: {end_user_id}")
@@ -1220,13 +1220,158 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
memory_obj = config.get('memory', {})
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
# 4. 根据 memory_config_id 查询配置名称
config_name = None
if memory_config_id:
try:
# memory_config_id 可能是整数或字符串,需要转换
config_id = int(memory_config_id) if isinstance(memory_config_id, str) else memory_config_id
data_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
if data_config:
config_name = data_config.config_name
logger.debug(f"Found config_name: {config_name} for config_id: {config_id}")
else:
logger.warning(f"DataConfig not found for config_id: {config_id}")
except (ValueError, TypeError) as e:
logger.warning(f"Invalid memory_config_id format: {memory_config_id}, error: {str(e)}")
result = {
"end_user_id": str(end_user_id),
"app_id": str(app_id),
"release_id": str(latest_release.id),
"release_version": latest_release.version,
"memory_config_id": memory_config_id
"memory_config_id": memory_config_id,
"memory_config_name": config_name
}
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}")
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}, config_name={config_name}")
return result
def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) -> Dict[str, Dict[str, Any]]:
"""
批量获取多个终端用户关联的记忆配置
通过优化的查询减少数据库往返次数:
1. 一次性查询所有 end_user 及其 app_id
2. 批量查询所有相关的 app_release
3. 批量查询所有相关的 data_config
Args:
end_user_ids: 终端用户ID列表
db: 数据库会话
Returns:
字典key 为 end_user_idvalue 为配置信息字典
对于查询失败的用户value 包含 error 字段
"""
from app.models.app_release_model import AppRelease
from app.models.end_user_model import EndUser
from app.models.data_config_model import DataConfig
from sqlalchemy import select
logger.info(f"Batch getting connected configs for {len(end_user_ids)} end users")
result = {}
# 1. 批量查询所有 end_user 及其 app_id
end_users = db.query(EndUser).filter(EndUser.id.in_(end_user_ids)).all()
# 构建 end_user_id -> end_user 的映射
end_user_map = {str(user.id): user for user in end_users}
# 记录不存在的用户
for user_id in end_user_ids:
if user_id not in end_user_map:
result[user_id] = {
"end_user_id": user_id,
"memory_config_id": None,
"memory_config_name": None,
"error": f"终端用户不存在: {user_id}"
}
if not end_users:
logger.warning("No valid end users found")
return result
# 2. 批量查询所有相关应用的最新发布版本
app_ids = [user.app_id for user in end_users]
# 使用子查询找到每个 app 的最新版本
from sqlalchemy import and_
# 查询所有相关的活跃发布版本
releases = db.query(AppRelease).filter(
and_(
AppRelease.app_id.in_(app_ids),
AppRelease.is_active.is_(True)
)
).order_by(AppRelease.app_id, AppRelease.version.desc()).all()
# 构建 app_id -> latest_release 的映射(每个 app 只保留最新版本)
app_release_map = {}
for release in releases:
app_id_str = str(release.app_id)
if app_id_str not in app_release_map:
app_release_map[app_id_str] = release
# 3. 收集所有 memory_config_id
memory_config_ids = []
for release in app_release_map.values():
config = release.config or {}
memory_obj = config.get('memory', {})
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
if memory_config_id:
try:
config_id = int(memory_config_id) if isinstance(memory_config_id, str) else memory_config_id
memory_config_ids.append(config_id)
except (ValueError, TypeError):
pass
# 4. 批量查询所有 data_config
config_name_map = {}
if memory_config_ids:
data_configs = db.query(DataConfig).filter(
DataConfig.config_id.in_(memory_config_ids)
).all()
config_name_map = {config.config_id: config.config_name for config in data_configs}
# 5. 组装结果
for user in end_users:
user_id = str(user.id)
app_id = str(user.app_id)
# 检查是否有发布版本
if app_id not in app_release_map:
result[user_id] = {
"end_user_id": user_id,
"memory_config_id": None,
"memory_config_name": None,
"error": f"应用未发布: {app_id}"
}
continue
release = app_release_map[app_id]
# 提取 memory_config_id
config = release.config or {}
memory_obj = config.get('memory', {})
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
# 获取 config_name
config_name = None
if memory_config_id:
try:
config_id = int(memory_config_id) if isinstance(memory_config_id, str) else memory_config_id
config_name = config_name_map.get(config_id)
except (ValueError, TypeError):
pass
result[user_id] = {
"end_user_id": user_id,
"memory_config_id": memory_config_id,
"memory_config_name": config_name
}
logger.info(f"Successfully retrieved batch configs: total={len(result)}, with_config={sum(1 for v in result.values() if v.get('memory_config_id'))}")
return result