去掉MCP框架,重构

This commit is contained in:
lixinyue
2026-01-19 11:56:10 +08:00
parent 546d52149d
commit 622e67e952

View File

@@ -9,30 +9,28 @@ import os
import re import re
import time import time
import uuid import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional from typing import Any, AsyncGenerator, Dict, List, Optional
import redis import redis
from langchain_core.messages import HumanMessage
from app.core.config import settings from app.core.config import settings
from app.core.logging_config import get_config_logger, get_logger from app.core.logging_config import get_config_logger, get_logger
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
from app.core.memory.agent.logger_file.log_streamer import LogStreamer from app.core.memory.agent.logger_file.log_streamer import LogStreamer
from app.core.memory.agent.utils.mcp_tools import get_mcp_server_config from app.core.memory.agent.utils.messages_tools import merge_multiple_search_results, reorder_output_results
from app.core.memory.agent.utils.type_classifier import status_typle from app.core.memory.agent.utils.type_classifier import status_typle
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType from app.models.knowledge_model import Knowledge, KnowledgeType
from app.repositories.memory_short_repository import ShortTermMemoryRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import ConfigurationError from app.schemas.memory_config_schema import ConfigurationError
from app.services.memory_config_service import MemoryConfigService from app.services.memory_config_service import MemoryConfigService
from app.services.memory_konwledges_server import ( from app.services.memory_konwledges_server import (
write_rag, write_rag,
) )
from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_mcp_adapters.tools import load_mcp_tools
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy import func from sqlalchemy import func
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -51,20 +49,16 @@ _neo4j_connector = Neo4jConnector()
class MemoryAgentService: class MemoryAgentService:
"""Service for memory agent operations""" """Service for memory agent operations"""
def writer_messages_deal(self, messages, start_time, group_id, config_id, message, context):
def writer_messages_deal(self,messages,start_time,group_id,config_id,message):
messages = str(messages).replace("'", '"').replace('\\n', '').replace('\n', '').replace('\\', '')
countext = re.findall(r'"status": "(.*?)",', messages)[0]
duration = time.time() - start_time duration = time.time() - start_time
if countext == 'success': if str(messages) == 'success':
logger.info(f"Write operation successful for group {group_id} with config_id {config_id}") logger.info(f"Write operation successful for group {group_id} with config_id {config_id}")
# 记录成功的操作 # 记录成功的操作
if audit_logger: if audit_logger:
audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=True, audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=True,
duration=duration, details={"message_length": len(message)}) duration=duration, details={"message_length": len(message)})
return countext return context
else: else:
logger.warning(f"Write operation failed for group {group_id}") logger.warning(f"Write operation failed for group {group_id}")
@@ -143,8 +137,26 @@ class MemoryAgentService:
else: else:
status = "unknown" status = "unknown"
# Add database connection pool status
try:
from app.db import get_pool_status
pool_status = get_pool_status()
logger.info(f"Database pool status: {pool_status}")
# Check if pool usage is too high
if pool_status.get("usage_percent", 0) > 80:
logger.warning(f"High database pool usage: {pool_status['usage_percent']}%")
status = "warning"
except Exception as e:
logger.error(f"Failed to get pool status: {e}")
pool_status = {"error": str(e)}
logger.info(f"Health status: {status}") logger.info(f"Health status: {status}")
return {"status": status} return {
"status": status,
"database_pool": pool_status
}
def get_log_content(self) -> str: def get_log_content(self) -> str:
""" """
@@ -156,8 +168,7 @@ class MemoryAgentService:
""" """
logger.info("Reading log file") logger.info("Reading log file")
# Use project root directory for logs
# Get the project root (redbear-mem directory)
current_file = os.path.abspath(__file__) # app/services/memory_agent_service.py current_file = os.path.abspath(__file__) # app/services/memory_agent_service.py
app_dir = os.path.dirname(os.path.dirname(current_file)) # app directory app_dir = os.path.dirname(os.path.dirname(current_file)) # app directory
project_root = os.path.dirname(app_dir) # redbear-mem directory project_root = os.path.dirname(app_dir) # redbear-mem directory
@@ -301,53 +312,41 @@ class MemoryAgentService:
audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg) audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg)
raise ValueError(error_msg) raise ValueError(error_msg)
mcp_config = get_mcp_server_config()
client = MultiServerMCPClient(mcp_config)
if storage_type == "rag":
result = await write_rag(group_id, message, user_rag_memory_id)
return result
else:
async with client.session("data_flow") as session:
logger.debug("Connected to MCP Server: data_flow")
tools = await load_mcp_tools(session)
workflow_errors = [] # Track errors from workflow
# Pass memory_config to the graph workflow
async with make_write_graph(group_id, tools, group_id, group_id, memory_config=memory_config) as graph:
logger.debug("Write graph created successfully")
try:
if storage_type == "rag":
result = await write_rag(group_id, message, user_rag_memory_id)
return result
else:
async with make_write_graph() as graph:
config = {"configurable": {"thread_id": group_id}} config = {"configurable": {"thread_id": group_id}}
# 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)], "group_id": group_id,
"memory_config": memory_config}
async for event in graph.astream( # 获取节点更新信息
{"messages": message, "memory_config": memory_config, "errors": []}, async for update_event in graph.astream(
stream_mode="values", initial_state,
stream_mode="updates",
config=config config=config
): ):
messages = event.get('messages') for node_name, node_data in update_event.items():
# Capture any errors from the state if 'save_neo4j' == node_name:
if event.get('errors'): massages = node_data
workflow_errors.extend(event.get('errors', [])) massagesstatus = massages.get('write_result')['status']
contents = massages.get('write_result')
return self.writer_messages_deal(massagesstatus, start_time, group_id, config_id, message, contents)
except Exception as e:
# Ensure proper error handling and logging
error_msg = f"Write operation failed: {str(e)}"
logger.error(error_msg)
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg)
raise ValueError(error_msg)
# Check for workflow errors
if workflow_errors:
error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors])
logger.error(f"Write workflow failed with errors: {error_details}")
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(
operation="WRITE",
config_id=config_id,
group_id=group_id,
success=False,
duration=duration,
error=error_details
)
raise ValueError(f"Write workflow failed: {error_details}")
return self.writer_messages_deal(messages, start_time, group_id, config_id, message)
async def read_memory( async def read_memory(
self, self,
@@ -387,8 +386,7 @@ class MemoryAgentService:
import time import time
start_time = time.time() start_time = time.time()
ori_message=message
end_user_id=group_id
# Resolve config_id if None using end_user's connected config # Resolve config_id if None using end_user's connected config
if config_id is None: if config_id is None:
try: try:
@@ -410,6 +408,7 @@ class MemoryAgentService:
except ImportError: except ImportError:
audit_logger = None audit_logger = None
try: try:
config_service = MemoryConfigService(db) config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config( memory_config = config_service.load_memory_config(
@@ -440,298 +439,98 @@ class MemoryAgentService:
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}") logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}")
# Step 3: Initialize MCP client and execute read workflow # Step 3: Initialize MCP client and execute read workflow
mcp_config = get_mcp_server_config() try:
client = MultiServerMCPClient(mcp_config) async with make_read_graph() as graph:
async with client.session('data_flow') as session:
session_start = time.time()
logger.debug("Connected to MCP Server: data_flow")
tools_start = time.time()
tools = await load_mcp_tools(session)
tools_time = time.time() - tools_start
logger.info(f"[PERF] MCP tools loading took: {tools_time:.4f}s")
outputs = []
intermediate_outputs = []
seen_intermediates = set() # Track seen intermediate outputs to avoid duplicates
# Pass memory_config to the graph workflow
graph_start = time.time()
async with make_read_graph(group_id, tools, search_switch, group_id, group_id, memory_config=memory_config, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id) as graph:
graph_init_time = time.time() - graph_start
logger.info(f"[PERF] Graph initialization took: {graph_init_time:.4f}s")
start = time.time()
config = {"configurable": {"thread_id": group_id}} config = {"configurable": {"thread_id": group_id}}
workflow_errors = [] # Track errors from workflow # 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch,
event_count = 0 "group_id": group_id
async for event in graph.astream( , "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id,
{"messages": history, "memory_config": memory_config, "errors": []}, "memory_config": memory_config}
stream_mode="values", # 获取节点更新信息
_intermediate_outputs = []
summary = ''
async for update_event in graph.astream(
initial_state,
stream_mode="updates",
config=config config=config
): ):
event_count += 1 for node_name, node_data in update_event.items():
event_start = time.time() print(f"处理节点: {node_name}")
messages = event.get('messages')
# Capture any errors from the state
if event.get('errors'):
workflow_errors.extend(event.get('errors', []))
for msg in messages: # 处理不同Summary节点的返回结构
msg_content = msg.content if 'Summary' in node_name:
msg_role = msg.__class__.__name__.lower().replace("message", "") if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']:
outputs.append({ summary = node_data['InputSummary']['summary_result']
"role": msg_role, elif 'RetrieveSummary' in node_data and 'summary_result' in node_data['RetrieveSummary']:
"content": msg_content summary = node_data['RetrieveSummary']['summary_result']
}) elif 'summary' in node_data and 'summary_result' in node_data['summary']:
summary = node_data['summary']['summary_result']
elif 'SummaryFails' in node_data and 'summary_result' in node_data['SummaryFails']:
summary = node_data['SummaryFails']['summary_result']
# Extract intermediate outputs spit_data = node_data.get('spit_data', {}).get('_intermediate', None)
if hasattr(msg, 'content'): if spit_data and spit_data != [] and spit_data != {}:
try: _intermediate_outputs.append(spit_data)
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
content_to_parse = msg_content
if isinstance(msg_content, list):
for block in msg_content:
if isinstance(block, dict) and block.get('type') == 'text':
content_to_parse = block.get('text', '')
break
else:
continue # No text block found
# Try to parse content as JSON # Problem_Extension 节点
if isinstance(content_to_parse, str): problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None)
try: if problem_extension and problem_extension != [] and problem_extension != {}:
parsed = json.loads(content_to_parse) _intermediate_outputs.append(problem_extension)
if isinstance(parsed, dict):
# Check for single intermediate output
if '_intermediate' in parsed:
intermediate_data = parsed['_intermediate']
output_key = self._create_intermediate_key(intermediate_data)
if output_key not in seen_intermediates: # Retrieve 节点
seen_intermediates.add(output_key) retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None)
intermediate_outputs.append(self._format_intermediate_output(intermediate_data)) if retrieve_node and retrieve_node != [] and retrieve_node != {}:
_intermediate_outputs.extend(retrieve_node)
# Check for multiple intermediate outputs (from Retrieve) # Verify 节点
if '_intermediates' in parsed: verify_n = node_data.get('verify', {}).get('_intermediate', None)
for intermediate_data in parsed['_intermediates']: if verify_n and verify_n != [] and verify_n != {}:
output_key = self._create_intermediate_key(intermediate_data) _intermediate_outputs.append(verify_n)
if output_key not in seen_intermediates: # Summary 节点
seen_intermediates.add(output_key) summary_n = node_data.get('summary', {}).get('_intermediate', None)
intermediate_outputs.append(self._format_intermediate_output(intermediate_data)) if summary_n and summary_n != [] and summary_n != {}:
except (json.JSONDecodeError, ValueError): _intermediate_outputs.append(summary_n)
pass
except Exception as e:
logger.debug(f"Failed to extract intermediate output: {e}")
event_time = time.time() - event_start _intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}]
logger.info(f"[PERF] Event {event_count} processing took: {event_time:.4f}s")
workflow_duration = time.time() - start optimized_outputs = merge_multiple_search_results(_intermediate_outputs)
session_duration = time.time() - session_start result = reorder_output_results(optimized_outputs)
logger.info(f"[PERF] Read graph workflow completed in {workflow_duration}s")
logger.info(f"[PERF] Total session duration: {session_duration:.4f}s")
logger.info(f"[PERF] Total events processed: {event_count}")
# Extract final answer
final_answer = ""
for messages in outputs:
if messages['role'] == 'tool':
message = messages['content']
# Handle MCP content format: [{'type': 'text', 'text': '...'}] # Log successful operation
if isinstance(message, list): if audit_logger:
# Extract text from MCP content blocks duration = time.time() - start_time
for block in message: audit_logger.log_operation(
if isinstance(block, dict) and block.get('type') == 'text': operation="READ",
message = block.get('text', '') config_id=config_id,
break group_id=group_id,
else: success=True,
continue # No text block found duration=duration
)
try:
parsed = json.loads(message) if isinstance(message, str) else message
if isinstance(parsed, dict):
if parsed.get('status') == 'success':
summary_result = parsed.get('summary_result')
if summary_result:
final_answer = summary_result
except (json.JSONDecodeError, ValueError):
pass
# 记录成功的操作
total_duration = time.time() - start_time
# Check for workflow errors
if workflow_errors:
error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors])
logger.warning(f"Read workflow completed with errors: {error_details}")
return {
"answer": summary,
"intermediate_outputs": result
}
except Exception as e:
# Ensure proper error handling and logging
error_msg = f"Read operation failed: {str(e)}"
logger.error(error_msg)
if audit_logger: if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation( audit_logger.log_operation(
operation="READ", operation="READ",
config_id=config_id, config_id=config_id,
group_id=group_id, group_id=group_id,
success=False, success=False,
duration=total_duration, duration=duration,
error=error_details, error=error_msg
details={
"search_switch": search_switch,
"history_length": len(history),
"intermediate_outputs_count": len(intermediate_outputs),
"has_answer": bool(final_answer),
"errors": workflow_errors
}
) )
raise ValueError(error_msg)
# Raise error if no answer was produced
if not final_answer:
raise ValueError(f"Read workflow failed: {error_details}")
if audit_logger and not workflow_errors:
audit_logger.log_operation(
operation="READ",
config_id=config_id,
group_id=group_id,
success=True,
duration=total_duration,
details={
"search_switch": search_switch,
"history_length": len(history),
"intermediate_outputs_count": len(intermediate_outputs),
"has_answer": bool(final_answer)
}
)
retrieved_content=[]
repo = ShortTermMemoryRepository(db)
if str(search_switch)!="2":
for intermediate in intermediate_outputs:
print(intermediate)
intermediate_type=intermediate['type']
if intermediate_type=="search_result":
query=intermediate['query']
raw_results=intermediate['raw_results']
reranked_results=raw_results.get('reranked_results',[])
try:
statements=[statement['statement'] for statement in reranked_results.get('statements', [])]
except Exception:
statements=[]
statements=list(set(statements))
retrieved_content.append({query:statements})
if retrieved_content==[]:
retrieved_content=''
if '信息不足,无法回答。' != str(final_answer) and str(search_switch).strip() != "2":#and retrieved_content!=[]
# 使用 upsert 方法
repo.upsert(
end_user_id=end_user_id, # 确保这个变量在作用域内
messages=ori_message,
aimessages=final_answer,
retrieved_content=retrieved_content,
search_switch=str(search_switch)
)
print("写入成功")
return {
"answer": final_answer,
"intermediate_outputs": intermediate_outputs
}
def _create_intermediate_key(self, output: Dict) -> str:
    """
    Derive the string key used to recognise a repeated intermediate output.

    The key is built from the output's ``type`` plus the field that
    identifies it (the original query, the query, or the query and its
    index for search results). Outputs of any other type are keyed by
    their sorted-key JSON representation.

    Args:
        output: Intermediate output dictionary

    Returns:
        Key string; equal keys mean the outputs count as duplicates
    """
    kind = output.get('type', 'unknown')

    # Split / extension steps are identified by the query they started from.
    if kind == 'problem_split':
        return f"split:{output.get('original_query', '')}"
    if kind == 'problem_extension':
        return f"extension:{output.get('original_query', '')}"

    # One query can produce several search results, so the index is part of the key.
    if kind == 'search_result':
        return f"search:{output.get('query', '')}:{output.get('index', 0)}"

    # Remaining known types are identified by their query alone.
    if kind == 'retrieval_summary':
        return f"summary:{output.get('query', '')}"
    if kind == 'verification':
        return f"verification:{output.get('query', '')}"
    if kind == 'input_summary':
        return f"input_summary:{output.get('query', '')}"

    # Unknown type: fall back to the whole payload, serialised with sorted keys.
    import json
    return json.dumps(output, sort_keys=True)
def _format_intermediate_output(self, output: Dict) -> Dict:
    """
    Shape one intermediate output into the dictionary shown by the frontend.

    Each known ``type`` is mapped to a dictionary with a fixed set of keys
    and a display title; missing fields are filled with defaults. An output
    whose type is not recognised is returned unchanged.

    Args:
        output: Intermediate output dictionary

    Returns:
        Display dictionary for known types, otherwise ``output`` itself
    """
    kind = output.get('type', 'unknown')

    if kind == 'problem_split':
        return {
            'type': 'problem_split',
            'title': '问题拆分',
            'data': output.get('data', []),
            'original_query': output.get('original_query', ''),
        }

    if kind == 'problem_extension':
        return {
            'type': 'problem_extension',
            'title': '问题扩展',
            'data': output.get('data', {}),
            'original_query': output.get('original_query', ''),
        }

    if kind == 'search_result':
        # The title carries the "index/total" position of this result.
        return {
            'type': 'search_result',
            'title': f'检索结果 ({output.get("index", 0)}/{output.get("total", 0)})',
            'query': output.get('query', ''),
            'raw_results': output.get('raw_results', ''),
            'index': output.get('index', 0),
            'total': output.get('total', 0),
        }

    if kind == 'retrieval_summary':
        return {
            'type': 'retrieval_summary',
            'title': '检索总结',
            'summary': output.get('summary', ''),
            'query': output.get('query', ''),
            'raw_results': output.get('raw_results'),
        }

    if kind == 'verification':
        return {
            'type': 'verification',
            'title': '数据验证',
            'result': output.get('result', 'unknown'),
            'reason': output.get('reason', ''),
            'query': output.get('query', ''),
            'verified_count': output.get('verified_count', 0),
        }

    if kind == 'input_summary':
        return {
            'type': 'input_summary',
            'title': '快速答案',
            'summary': output.get('summary', ''),
            'query': output.get('query', ''),
            'raw_results': output.get('raw_results'),
        }

    # Unrecognised type: hand the payload back untouched.
    return output
async def classify_message_type(self, message: str, config_id: int, db: Session) -> Dict: async def classify_message_type(self, message: str, config_id: int, db: Session) -> Dict:
""" """
Determine the type of user message (read or write) Determine the type of user message (read or write)
@@ -838,6 +637,7 @@ class MemoryAgentService:
# 获取当前空间下的所有宿主 # 获取当前空间下的所有宿主
from app.repositories import app_repository, end_user_repository from app.repositories import app_repository, end_user_repository
from app.schemas.app_schema import App as AppSchema from app.schemas.app_schema import App as AppSchema
from app.schemas.end_user_schema import EndUser as EndUserSchema
# 查询应用并转换为 Pydantic 模型 # 查询应用并转换为 Pydantic 模型
apps_orm = app_repository.get_apps_by_workspace_id(db, current_workspace_id) apps_orm = app_repository.get_apps_by_workspace_id(db, current_workspace_id)
@@ -1135,43 +935,6 @@ class MemoryAgentService:
logger.info("Log streaming completed, cleaning up resources") logger.info("Log streaming completed, cleaning up resources")
# LogStreamer uses context manager for file handling, so cleanup is automatic # LogStreamer uses context manager for file handling, so cleanup is automatic
# async def get_api_docs(self, file_path: Optional[str] = None) -> Dict[str, Any]:
# """
# Parse and return API documentation
# Args:
# file_path: Optional path to API docs file. If None, uses default path.
# Returns:
# Dict containing parsed API documentation or error information
# """
# try:
# target = file_path or get_default_docs_path()
# if not os.path.isfile(target):
# return {
# "success": False,
# "msg": "API文档文件不存在",
# "error_code": "DOC_NOT_FOUND",
# "data": {"path": target}
# }
# data = parse_api_docs(target)
# return {
# "success": True,
# "msg": "解析成功",
# "data": data
# }
# except Exception as e:
# logger.error(f"Failed to parse API docs: {e}")
# return {
# "success": False,
# "msg": "解析失败",
# "error_code": "DOC_PARSE_ERROR",
# "data": {"error": str(e)}
# }
def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, Any]: def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, Any]:
""" """
获取终端用户关联的记忆配置 获取终端用户关联的记忆配置
@@ -1180,20 +943,18 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
1. 根据 end_user_id 获取用户的 app_id 1. 根据 end_user_id 获取用户的 app_id
2. 获取该应用的最新发布版本 2. 获取该应用的最新发布版本
3. 从发布版本的 config 字段中提取 memory_config_id 3. 从发布版本的 config 字段中提取 memory_config_id
4. 根据 memory_config_id 查询配置名称
Args: Args:
end_user_id: 终端用户ID end_user_id: 终端用户ID
db: 数据库会话 db: 数据库会话
Returns: Returns:
包含 memory_config_id、config_name 和相关信息的字典 包含 memory_config_id 和相关信息的字典
Raises: Raises:
ValueError: 当终端用户不存在或应用未发布时 ValueError: 当终端用户不存在或应用未发布时
""" """
from app.models.app_release_model import AppRelease from app.models.app_release_model import AppRelease
from app.models.data_config_model import DataConfig
from app.models.end_user_model import EndUser from app.models.end_user_model import EndUser
from sqlalchemy import select from sqlalchemy import select
@@ -1227,158 +988,13 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
memory_obj = config.get('memory', {}) memory_obj = config.get('memory', {})
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
# 4. 根据 memory_config_id 查询配置名称
config_name = None
if memory_config_id:
try:
# memory_config_id 可能是整数或字符串,需要转换
config_id = int(memory_config_id) if isinstance(memory_config_id, str) else memory_config_id
data_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
if data_config:
config_name = data_config.config_name
logger.debug(f"Found config_name: {config_name} for config_id: {config_id}")
else:
logger.warning(f"DataConfig not found for config_id: {config_id}")
except (ValueError, TypeError) as e:
logger.warning(f"Invalid memory_config_id format: {memory_config_id}, error: {str(e)}")
result = { result = {
"end_user_id": str(end_user_id), "end_user_id": str(end_user_id),
"app_id": str(app_id), "app_id": str(app_id),
"release_id": str(latest_release.id), "release_id": str(latest_release.id),
"release_version": latest_release.version, "release_version": latest_release.version,
"memory_config_id": memory_config_id, "memory_config_id": memory_config_id
"memory_config_name": config_name
} }
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}, config_name={config_name}") logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}")
return result
def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) -> Dict[str, Dict[str, Any]]:
    """
    Fetch the memory configuration connected to each of several end users.

    The lookup is done with three bulk queries instead of one round trip
    per user:
    1. all EndUser rows for the given ids (gives each user's app_id)
    2. all active AppRelease rows for those apps, newest version first
    3. all DataConfig rows for the memory_config_ids found in the releases

    Args:
        end_user_ids: List of end user ids
        db: Database session

    Returns:
        Dict keyed by end_user_id. Each value holds ``end_user_id``,
        ``memory_config_id`` and ``memory_config_name``; users that do not
        exist, or whose app has no active release, additionally carry an
        ``error`` field. An empty input list yields an empty dict.
    """
    from app.models.app_release_model import AppRelease
    from app.models.data_config_model import DataConfig
    from app.models.end_user_model import EndUser
    from sqlalchemy import and_

    logger.info(f"Batch getting connected configs for {len(end_user_ids)} end users")

    result: Dict[str, Dict[str, Any]] = {}

    # Nothing requested: skip the queries instead of issuing an empty IN ().
    if not end_user_ids:
        return result

    # 1. Load every requested end user in one query.
    end_users = db.query(EndUser).filter(EndUser.id.in_(end_user_ids)).all()
    end_user_map = {str(user.id): user for user in end_users}

    # Requested ids without a matching row are reported as errors.
    for user_id in end_user_ids:
        if user_id not in end_user_map:
            result[user_id] = {
                "end_user_id": user_id,
                "memory_config_id": None,
                "memory_config_name": None,
                "error": f"终端用户不存在: {user_id}"
            }

    if not end_users:
        logger.warning("No valid end users found")
        return result

    # 2. Load the active releases of all involved apps, newest version first.
    app_ids = [user.app_id for user in end_users]
    releases = db.query(AppRelease).filter(
        and_(
            AppRelease.app_id.in_(app_ids),
            AppRelease.is_active.is_(True)
        )
    ).order_by(AppRelease.app_id, AppRelease.version.desc()).all()

    # Rows are ordered newest-first per app, so the first one seen wins.
    app_release_map = {}
    for release in releases:
        app_release_map.setdefault(str(release.app_id), release)

    # 3. Collect every usable memory_config_id referenced by those releases.
    memory_config_ids = []
    for release in app_release_map.values():
        config_id = _coerce_memory_config_id(_release_memory_config_id(release))
        if config_id is not None:
            memory_config_ids.append(config_id)

    # 4. Resolve all config names in one query.
    config_name_map = {}
    if memory_config_ids:
        data_configs = db.query(DataConfig).filter(
            DataConfig.config_id.in_(memory_config_ids)
        ).all()
        config_name_map = {config.config_id: config.config_name for config in data_configs}

    # 5. Assemble the per-user result.
    for user in end_users:
        user_id = str(user.id)
        app_id = str(user.app_id)

        release = app_release_map.get(app_id)
        if release is None:
            result[user_id] = {
                "end_user_id": user_id,
                "memory_config_id": None,
                "memory_config_name": None,
                "error": f"应用未发布: {app_id}"
            }
            continue

        memory_config_id = _release_memory_config_id(release)
        config_id = _coerce_memory_config_id(memory_config_id)

        config_name = None
        if config_id is not None:
            try:
                config_name = config_name_map.get(config_id)
            except TypeError:
                # Unhashable id taken from a malformed release config: no name.
                config_name = None

        result[user_id] = {
            "end_user_id": user_id,
            "memory_config_id": memory_config_id,
            "memory_config_name": config_name
        }

    logger.info(f"Successfully retrieved batch configs: total={len(result)}, with_config={sum(1 for v in result.values() if v.get('memory_config_id'))}")
    return result


def _release_memory_config_id(release):
    """Return the raw ``memory.memory_content`` value of a release's config, or None."""
    config = release.config or {}
    memory_obj = config.get('memory', {})
    return memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None


def _coerce_memory_config_id(raw_config_id):
    """Turn a raw memory_config_id into the value used to match DataConfig.config_id.

    Strings are parsed as integers and other values pass through unchanged.
    Returns None for falsy ids and for strings that are not valid integers.
    """
    if not raw_config_id:
        return None
    if isinstance(raw_config_id, str):
        try:
            return int(raw_config_id)
        except ValueError:
            return None
    return raw_config_id