Feature/return memoryconfig (#89)
* [add]Newly added: Memory configuration for returning results * [add]Newly added: Memory configuration for returning results * [changes]Based on the improvement of AI review
This commit is contained in:
@@ -1,18 +1,15 @@
|
|||||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
import uuid
|
|
||||||
from app.repositories.end_user_repository import update_end_user_other_name
|
|
||||||
import uuid
|
|
||||||
from app.core.response_utils import success
|
from app.core.response_utils import success
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.dependencies import get_current_user
|
from app.dependencies import get_current_user
|
||||||
from app.models.user_model import User
|
from app.models.user_model import User
|
||||||
from app.schemas.memory_agent_schema import End_User_Information
|
from app.schemas.memory_agent_schema import End_User_Information
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.schemas.app_schema import App as AppSchema
|
|
||||||
|
|
||||||
from app.services import memory_dashboard_service, memory_storage_service, workspace_service
|
from app.services import memory_dashboard_service, memory_storage_service, workspace_service
|
||||||
|
from app.services.memory_agent_service import get_end_users_connected_configs_batch
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
|
|
||||||
# 获取API专用日志器
|
# 获取API专用日志器
|
||||||
@@ -102,7 +99,8 @@ async def get_workspace_end_users(
|
|||||||
"""
|
"""
|
||||||
获取工作空间的宿主列表
|
获取工作空间的宿主列表
|
||||||
|
|
||||||
返回格式与原 memory_list 接口中的 end_users 字段相同
|
返回格式与原 memory_list 接口中的 end_users 字段相同,
|
||||||
|
并包含每个用户的记忆配置信息(memory_config_id 和 memory_config_name)
|
||||||
"""
|
"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
# 获取当前空间类型
|
# 获取当前空间类型
|
||||||
@@ -113,6 +111,17 @@ async def get_workspace_end_users(
|
|||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
current_user=current_user
|
current_user=current_user
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 批量获取所有用户的记忆配置信息(优化:一次查询而非 N 次)
|
||||||
|
end_user_ids = [str(user.id) for user in end_users]
|
||||||
|
memory_configs_map = {}
|
||||||
|
if end_user_ids:
|
||||||
|
try:
|
||||||
|
memory_configs_map = get_end_users_connected_configs_batch(end_user_ids, db)
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"批量获取记忆配置失败: {str(e)}")
|
||||||
|
# 失败时使用空字典,不影响其他数据返回
|
||||||
|
|
||||||
result = []
|
result = []
|
||||||
for end_user in end_users:
|
for end_user in end_users:
|
||||||
memory_num = {}
|
memory_num = {}
|
||||||
@@ -123,10 +132,25 @@ async def get_workspace_end_users(
|
|||||||
memory_num = {
|
memory_num = {
|
||||||
"total":memory_dashboard_service.get_current_user_total_chunk(str(end_user.id), db, current_user)
|
"total":memory_dashboard_service.get_current_user_total_chunk(str(end_user.id), db, current_user)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# 从批量查询结果中获取配置信息
|
||||||
|
user_id = str(end_user.id)
|
||||||
|
memory_config_info = memory_configs_map.get(user_id, {
|
||||||
|
"memory_config_id": None,
|
||||||
|
"memory_config_name": None
|
||||||
|
})
|
||||||
|
|
||||||
|
# 只保留需要的字段,移除 error 字段(如果有)
|
||||||
|
memory_config = {
|
||||||
|
"memory_config_id": memory_config_info.get("memory_config_id"),
|
||||||
|
"memory_config_name": memory_config_info.get("memory_config_name")
|
||||||
|
}
|
||||||
|
|
||||||
result.append(
|
result.append(
|
||||||
{
|
{
|
||||||
'end_user': end_user,
|
'end_user': end_user,
|
||||||
'memory_num':memory_num
|
'memory_num': memory_num,
|
||||||
|
'memory_config': memory_config
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -465,7 +489,6 @@ async def dashboard_data(
|
|||||||
if storage_type is None:
|
if storage_type is None:
|
||||||
storage_type = 'neo4j'
|
storage_type = 'neo4j'
|
||||||
|
|
||||||
user_rag_memory_id = None
|
|
||||||
|
|
||||||
# 根据 storage_type 决定返回哪个数据对象
|
# 根据 storage_type 决定返回哪个数据对象
|
||||||
# 如果是 'rag',neo4j_data 为 null;否则 rag_data 为 null
|
# 如果是 'rag',neo4j_data 为 null;否则 rag_data 为 null
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ Memory Agent Service
|
|||||||
Handles business logic for memory agent operations including read/write services,
|
Handles business logic for memory agent operations including read/write services,
|
||||||
health checks, and message type classification.
|
health checks, and message type classification.
|
||||||
"""
|
"""
|
||||||
import datetime
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
@@ -27,7 +26,7 @@ from app.db import get_db_context
|
|||||||
from app.models.knowledge_model import Knowledge, KnowledgeType
|
from app.models.knowledge_model import Knowledge, KnowledgeType
|
||||||
from app.repositories.memory_short_repository import ShortTermMemoryRepository
|
from app.repositories.memory_short_repository import ShortTermMemoryRepository
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
from app.schemas.memory_config_schema import ConfigurationError, MemoryConfig
|
from app.schemas.memory_config_schema import ConfigurationError
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
from app.services.memory_konwledges_server import (
|
from app.services.memory_konwledges_server import (
|
||||||
write_rag,
|
write_rag,
|
||||||
@@ -610,7 +609,7 @@ class MemoryAgentService:
|
|||||||
reranked_results=raw_results.get('reranked_results',[])
|
reranked_results=raw_results.get('reranked_results',[])
|
||||||
try:
|
try:
|
||||||
statements=[statement['statement'] for statement in reranked_results.get('statements', [])]
|
statements=[statement['statement'] for statement in reranked_results.get('statements', [])]
|
||||||
except Exception as e:
|
except Exception:
|
||||||
statements=[]
|
statements=[]
|
||||||
statements=list(set(statements))
|
statements=list(set(statements))
|
||||||
retrieved_content.append({query:statements})
|
retrieved_content.append({query:statements})
|
||||||
@@ -832,7 +831,6 @@ class MemoryAgentService:
|
|||||||
# 获取当前空间下的所有宿主
|
# 获取当前空间下的所有宿主
|
||||||
from app.repositories import app_repository, end_user_repository
|
from app.repositories import app_repository, end_user_repository
|
||||||
from app.schemas.app_schema import App as AppSchema
|
from app.schemas.app_schema import App as AppSchema
|
||||||
from app.schemas.end_user_schema import EndUser as EndUserSchema
|
|
||||||
|
|
||||||
# 查询应用并转换为 Pydantic 模型
|
# 查询应用并转换为 Pydantic 模型
|
||||||
apps_orm = app_repository.get_apps_by_workspace_id(db, current_workspace_id)
|
apps_orm = app_repository.get_apps_by_workspace_id(db, current_workspace_id)
|
||||||
@@ -1175,19 +1173,21 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
|||||||
1. 根据 end_user_id 获取用户的 app_id
|
1. 根据 end_user_id 获取用户的 app_id
|
||||||
2. 获取该应用的最新发布版本
|
2. 获取该应用的最新发布版本
|
||||||
3. 从发布版本的 config 字段中提取 memory_config_id
|
3. 从发布版本的 config 字段中提取 memory_config_id
|
||||||
|
4. 根据 memory_config_id 查询配置名称
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
end_user_id: 终端用户ID
|
end_user_id: 终端用户ID
|
||||||
db: 数据库会话
|
db: 数据库会话
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
包含 memory_config_id 和相关信息的字典
|
包含 memory_config_id、config_name 和相关信息的字典
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: 当终端用户不存在或应用未发布时
|
ValueError: 当终端用户不存在或应用未发布时
|
||||||
"""
|
"""
|
||||||
from app.models.app_release_model import AppRelease
|
from app.models.app_release_model import AppRelease
|
||||||
from app.models.end_user_model import EndUser
|
from app.models.end_user_model import EndUser
|
||||||
|
from app.models.data_config_model import DataConfig
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
logger.info(f"Getting connected config for end_user: {end_user_id}")
|
logger.info(f"Getting connected config for end_user: {end_user_id}")
|
||||||
@@ -1220,13 +1220,158 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
|||||||
memory_obj = config.get('memory', {})
|
memory_obj = config.get('memory', {})
|
||||||
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
|
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
|
||||||
|
|
||||||
|
# 4. 根据 memory_config_id 查询配置名称
|
||||||
|
config_name = None
|
||||||
|
if memory_config_id:
|
||||||
|
try:
|
||||||
|
# memory_config_id 可能是整数或字符串,需要转换
|
||||||
|
config_id = int(memory_config_id) if isinstance(memory_config_id, str) else memory_config_id
|
||||||
|
data_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
|
||||||
|
if data_config:
|
||||||
|
config_name = data_config.config_name
|
||||||
|
logger.debug(f"Found config_name: {config_name} for config_id: {config_id}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"DataConfig not found for config_id: {config_id}")
|
||||||
|
except (ValueError, TypeError) as e:
|
||||||
|
logger.warning(f"Invalid memory_config_id format: {memory_config_id}, error: {str(e)}")
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"end_user_id": str(end_user_id),
|
"end_user_id": str(end_user_id),
|
||||||
"app_id": str(app_id),
|
"app_id": str(app_id),
|
||||||
"release_id": str(latest_release.id),
|
"release_id": str(latest_release.id),
|
||||||
"release_version": latest_release.version,
|
"release_version": latest_release.version,
|
||||||
"memory_config_id": memory_config_id
|
"memory_config_id": memory_config_id,
|
||||||
|
"memory_config_name": config_name
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}")
|
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}, config_name={config_name}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) -> Dict[str, Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
批量获取多个终端用户关联的记忆配置
|
||||||
|
|
||||||
|
通过优化的查询减少数据库往返次数:
|
||||||
|
1. 一次性查询所有 end_user 及其 app_id
|
||||||
|
2. 批量查询所有相关的 app_release
|
||||||
|
3. 批量查询所有相关的 data_config
|
||||||
|
|
||||||
|
Args:
|
||||||
|
end_user_ids: 终端用户ID列表
|
||||||
|
db: 数据库会话
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
字典,key 为 end_user_id,value 为配置信息字典
|
||||||
|
对于查询失败的用户,value 包含 error 字段
|
||||||
|
"""
|
||||||
|
from app.models.app_release_model import AppRelease
|
||||||
|
from app.models.end_user_model import EndUser
|
||||||
|
from app.models.data_config_model import DataConfig
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
logger.info(f"Batch getting connected configs for {len(end_user_ids)} end users")
|
||||||
|
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
# 1. 批量查询所有 end_user 及其 app_id
|
||||||
|
end_users = db.query(EndUser).filter(EndUser.id.in_(end_user_ids)).all()
|
||||||
|
|
||||||
|
# 构建 end_user_id -> end_user 的映射
|
||||||
|
end_user_map = {str(user.id): user for user in end_users}
|
||||||
|
|
||||||
|
# 记录不存在的用户
|
||||||
|
for user_id in end_user_ids:
|
||||||
|
if user_id not in end_user_map:
|
||||||
|
result[user_id] = {
|
||||||
|
"end_user_id": user_id,
|
||||||
|
"memory_config_id": None,
|
||||||
|
"memory_config_name": None,
|
||||||
|
"error": f"终端用户不存在: {user_id}"
|
||||||
|
}
|
||||||
|
|
||||||
|
if not end_users:
|
||||||
|
logger.warning("No valid end users found")
|
||||||
|
return result
|
||||||
|
|
||||||
|
# 2. 批量查询所有相关应用的最新发布版本
|
||||||
|
app_ids = [user.app_id for user in end_users]
|
||||||
|
|
||||||
|
# 使用子查询找到每个 app 的最新版本
|
||||||
|
from sqlalchemy import and_
|
||||||
|
|
||||||
|
# 查询所有相关的活跃发布版本
|
||||||
|
releases = db.query(AppRelease).filter(
|
||||||
|
and_(
|
||||||
|
AppRelease.app_id.in_(app_ids),
|
||||||
|
AppRelease.is_active.is_(True)
|
||||||
|
)
|
||||||
|
).order_by(AppRelease.app_id, AppRelease.version.desc()).all()
|
||||||
|
|
||||||
|
# 构建 app_id -> latest_release 的映射(每个 app 只保留最新版本)
|
||||||
|
app_release_map = {}
|
||||||
|
for release in releases:
|
||||||
|
app_id_str = str(release.app_id)
|
||||||
|
if app_id_str not in app_release_map:
|
||||||
|
app_release_map[app_id_str] = release
|
||||||
|
|
||||||
|
# 3. 收集所有 memory_config_id
|
||||||
|
memory_config_ids = []
|
||||||
|
for release in app_release_map.values():
|
||||||
|
config = release.config or {}
|
||||||
|
memory_obj = config.get('memory', {})
|
||||||
|
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
|
||||||
|
if memory_config_id:
|
||||||
|
try:
|
||||||
|
config_id = int(memory_config_id) if isinstance(memory_config_id, str) else memory_config_id
|
||||||
|
memory_config_ids.append(config_id)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 4. 批量查询所有 data_config
|
||||||
|
config_name_map = {}
|
||||||
|
if memory_config_ids:
|
||||||
|
data_configs = db.query(DataConfig).filter(
|
||||||
|
DataConfig.config_id.in_(memory_config_ids)
|
||||||
|
).all()
|
||||||
|
config_name_map = {config.config_id: config.config_name for config in data_configs}
|
||||||
|
|
||||||
|
# 5. 组装结果
|
||||||
|
for user in end_users:
|
||||||
|
user_id = str(user.id)
|
||||||
|
app_id = str(user.app_id)
|
||||||
|
|
||||||
|
# 检查是否有发布版本
|
||||||
|
if app_id not in app_release_map:
|
||||||
|
result[user_id] = {
|
||||||
|
"end_user_id": user_id,
|
||||||
|
"memory_config_id": None,
|
||||||
|
"memory_config_name": None,
|
||||||
|
"error": f"应用未发布: {app_id}"
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
|
||||||
|
release = app_release_map[app_id]
|
||||||
|
|
||||||
|
# 提取 memory_config_id
|
||||||
|
config = release.config or {}
|
||||||
|
memory_obj = config.get('memory', {})
|
||||||
|
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
|
||||||
|
|
||||||
|
# 获取 config_name
|
||||||
|
config_name = None
|
||||||
|
if memory_config_id:
|
||||||
|
try:
|
||||||
|
config_id = int(memory_config_id) if isinstance(memory_config_id, str) else memory_config_id
|
||||||
|
config_name = config_name_map.get(config_id)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
result[user_id] = {
|
||||||
|
"end_user_id": user_id,
|
||||||
|
"memory_config_id": memory_config_id,
|
||||||
|
"memory_config_name": config_name
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(f"Successfully retrieved batch configs: total={len(result)}, with_config={sum(1 for v in result.values() if v.get('memory_config_id'))}")
|
||||||
return result
|
return result
|
||||||
Reference in New Issue
Block a user