去掉MCP框架,重构

This commit is contained in:
lixinyue
2026-01-14 18:29:33 +08:00
parent 92b144d7f5
commit 0b685b136f
62 changed files with 3263 additions and 4421 deletions

View File

@@ -13,26 +13,26 @@ from threading import Lock
from typing import Any, AsyncGenerator, Dict, List, Optional
import redis
from langchain_core.messages import HumanMessage
from app.core.config import settings
from app.core.logging_config import get_config_logger, get_logger
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
from app.core.memory.agent.utils.mcp_tools import get_mcp_server_config
from app.core.memory.agent.utils.messages_tools import merge_multiple_search_results, reorder_output_results
from app.core.memory.agent.utils.type_classifier import status_typle
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.repositories.memory_short_repository import ShortTermMemoryRepository
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType
from app.repositories.memory_short_repository import ShortTermMemoryRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import ConfigurationError
from app.services.memory_config_service import MemoryConfigService
from app.services.memory_konwledges_server import (
write_rag,
)
from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_mcp_adapters.tools import load_mcp_tools
from pydantic import BaseModel, Field
from sqlalchemy import func
from sqlalchemy.orm import Session
@@ -55,18 +55,16 @@ class MemoryAgentService:
self.user_locks: Dict[str, Lock] = {}
self.locks_lock = Lock()
def writer_messages_deal(self,messages,start_time,group_id,config_id,message):
messages = str(messages).replace("'", '"').replace('\\n', '').replace('\n', '').replace('\\', '')
countext = re.findall(r'"status": "(.*?)",', messages)[0]
def writer_messages_deal(self, messages, start_time, group_id, config_id, message, context):
duration = time.time() - start_time
if countext == 'success':
if str(messages) == 'success':
logger.info(f"Write operation successful for group {group_id} with config_id {config_id}")
# 记录成功的操作
if audit_logger:
audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=True,
duration=duration, details={"message_length": len(message)})
return countext
return context
else:
logger.warning(f"Write operation failed for group {group_id}")
@@ -150,8 +148,26 @@ class MemoryAgentService:
else:
status = "unknown"
# Add database connection pool status
try:
from app.db import get_pool_status
pool_status = get_pool_status()
logger.info(f"Database pool status: {pool_status}")
# Check if pool usage is too high
if pool_status.get("usage_percent", 0) > 80:
logger.warning(f"High database pool usage: {pool_status['usage_percent']}%")
status = "warning"
except Exception as e:
logger.error(f"Failed to get pool status: {e}")
pool_status = {"error": str(e)}
logger.info(f"Health status: {status}")
return {"status": status}
return {
"status": status,
"database_pool": pool_status
}
def get_log_content(self) -> str:
"""
@@ -308,54 +324,42 @@ class MemoryAgentService:
audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg)
raise ValueError(error_msg)
mcp_config = get_mcp_server_config()
client = MultiServerMCPClient(mcp_config)
if storage_type == "rag":
result = await write_rag(group_id, message, user_rag_memory_id)
return result
else:
async with client.session("data_flow") as session:
logger.debug("Connected to MCP Server: data_flow")
tools = await load_mcp_tools(session)
workflow_errors = [] # Track errors from workflow
# Pass memory_config to the graph workflow
async with make_write_graph(group_id, tools, group_id, group_id, memory_config=memory_config) as graph:
logger.debug("Write graph created successfully")
try:
if storage_type == "rag":
result = await write_rag(group_id, message, user_rag_memory_id)
return result
else:
async with make_write_graph() as graph:
config = {"configurable": {"thread_id": group_id}}
# 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)], "group_id": group_id,
"memory_config": memory_config}
async for event in graph.astream(
{"messages": message, "memory_config": memory_config, "errors": []},
stream_mode="values",
# 获取节点更新信息
async for update_event in graph.astream(
initial_state,
stream_mode="updates",
config=config
):
messages = event.get('messages')
# Capture any errors from the state
if event.get('errors'):
workflow_errors.extend(event.get('errors', []))
# Check for workflow errors
if workflow_errors:
error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors])
logger.error(f"Write workflow failed with errors: {error_details}")
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(
operation="WRITE",
config_id=config_id,
group_id=group_id,
success=False,
duration=duration,
error=error_details
)
raise ValueError(f"Write workflow failed: {error_details}")
return self.writer_messages_deal(messages, start_time, group_id, config_id, message)
for node_name, node_data in update_event.items():
if 'save_neo4j' == node_name:
massages = node_data
massagesstatus = massages.get('write_result')['status']
contents = massages.get('write_result')
return self.writer_messages_deal(massagesstatus, start_time, group_id, config_id, message, contents)
except Exception as e:
# Ensure proper error handling and logging
error_msg = f"Write operation failed: {str(e)}"
logger.error(error_msg)
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg)
raise ValueError(error_msg)
async def read_memory(
self,
group_id: str,
@@ -394,8 +398,9 @@ class MemoryAgentService:
import time
start_time = time.time()
ori_message=message
end_user_id=group_id
ori_message=message
# Resolve config_id if None using end_user's connected config
if config_id is None:
try:
@@ -408,15 +413,15 @@ class MemoryAgentService:
raise # Re-raise our specific error
logger.error(f"Failed to get connected config for end_user {group_id}: {e}")
raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}")
logger.info(f"Read operation for group {group_id} with config_id {config_id}")
# 导入审计日志记录器
try:
from app.core.memory.utils.log.audit_logger import audit_logger
except ImportError:
audit_logger = None
# Get group lock to prevent concurrent processing
group_lock = self.get_group_lock(group_id)
@@ -432,7 +437,7 @@ class MemoryAgentService:
except ConfigurationError as e:
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
logger.error(error_msg)
# Log failed operation
if audit_logger:
duration = time.time() - start_time
@@ -444,305 +449,133 @@ class MemoryAgentService:
duration=duration,
error=error_msg
)
raise ValueError(error_msg)
# Step 2: Prepare history
history.append({"role": "user", "content": message})
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}")
# Step 3: Initialize MCP client and execute read workflow
mcp_config = get_mcp_server_config()
client = MultiServerMCPClient(mcp_config)
async with client.session('data_flow') as session:
session_start = time.time()
logger.debug("Connected to MCP Server: data_flow")
tools_start = time.time()
tools = await load_mcp_tools(session)
tools_time = time.time() - tools_start
logger.info(f"[PERF] MCP tools loading took: {tools_time:.4f}s")
outputs = []
intermediate_outputs = []
seen_intermediates = set() # Track seen intermediate outputs to avoid duplicates
# Pass memory_config to the graph workflow
graph_start = time.time()
async with make_read_graph(group_id, tools, search_switch, group_id, group_id, memory_config=memory_config, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id) as graph:
graph_init_time = time.time() - graph_start
logger.info(f"[PERF] Graph initialization took: {graph_init_time:.4f}s")
start = time.time()
try:
async with make_read_graph() as graph:
config = {"configurable": {"thread_id": group_id}}
workflow_errors = [] # Track errors from workflow
event_count = 0
async for event in graph.astream(
{"messages": history, "memory_config": memory_config, "errors": []},
stream_mode="values",
# 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch,
"group_id": group_id
, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id,
"memory_config": memory_config}
# 获取节点更新信息
_intermediate_outputs = []
summary = ''
async for update_event in graph.astream(
initial_state,
stream_mode="updates",
config=config
):
event_count += 1
event_start = time.time()
messages = event.get('messages')
# Capture any errors from the state
if event.get('errors'):
workflow_errors.extend(event.get('errors', []))
for node_name, node_data in update_event.items():
for msg in messages:
msg_content = msg.content
msg_role = msg.__class__.__name__.lower().replace("message", "")
outputs.append({
"role": msg_role,
"content": msg_content
})
# 处理不同Summary节点的返回结构
if 'Summary' in node_name:
if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']:
summary = node_data['InputSummary']['summary_result']
elif 'RetrieveSummary' in node_data and 'summary_result' in node_data['RetrieveSummary']:
summary = node_data['RetrieveSummary']['summary_result']
elif 'summary' in node_data and 'summary_result' in node_data['summary']:
summary = node_data['summary']['summary_result']
elif 'SummaryFails' in node_data and 'summary_result' in node_data['SummaryFails']:
summary = node_data['SummaryFails']['summary_result']
# Extract intermediate outputs
if hasattr(msg, 'content'):
try:
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
content_to_parse = msg_content
if isinstance(msg_content, list):
for block in msg_content:
if isinstance(block, dict) and block.get('type') == 'text':
content_to_parse = block.get('text', '')
break
else:
continue # No text block found
spit_data = node_data.get('spit_data', {}).get('_intermediate', None)
if spit_data and spit_data != [] and spit_data != {}:
_intermediate_outputs.append(spit_data)
# Try to parse content as JSON
if isinstance(content_to_parse, str):
try:
parsed = json.loads(content_to_parse)
if isinstance(parsed, dict):
# Check for single intermediate output
if '_intermediate' in parsed:
intermediate_data = parsed['_intermediate']
output_key = self._create_intermediate_key(intermediate_data)
# Problem_Extension 节点
problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None)
if problem_extension and problem_extension != [] and problem_extension != {}:
_intermediate_outputs.append(problem_extension)
if output_key not in seen_intermediates:
seen_intermediates.add(output_key)
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
# Retrieve 节点
retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None)
if retrieve_node and retrieve_node != [] and retrieve_node != {}:
_intermediate_outputs.extend(retrieve_node)
# Check for multiple intermediate outputs (from Retrieve)
if '_intermediates' in parsed:
for intermediate_data in parsed['_intermediates']:
output_key = self._create_intermediate_key(intermediate_data)
# Verify 节点
verify_n = node_data.get('verify', {}).get('_intermediate', None)
if verify_n and verify_n != [] and verify_n != {}:
_intermediate_outputs.append(verify_n)
if output_key not in seen_intermediates:
seen_intermediates.add(output_key)
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
except (json.JSONDecodeError, ValueError):
pass
except Exception as e:
logger.debug(f"Failed to extract intermediate output: {e}")
event_time = time.time() - event_start
logger.info(f"[PERF] Event {event_count} processing took: {event_time:.4f}s")
# Summary 节点
summary_n = node_data.get('summary', {}).get('_intermediate', None)
if summary_n and summary_n != [] and summary_n != {}:
_intermediate_outputs.append(summary_n)
workflow_duration = time.time() - start
session_duration = time.time() - session_start
logger.info(f"[PERF] Read graph workflow completed in {workflow_duration}s")
logger.info(f"[PERF] Total session duration: {session_duration:.4f}s")
logger.info(f"[PERF] Total events processed: {event_count}")
# Extract final answer
final_answer = ""
for messages in outputs:
if messages['role'] == 'tool':
message = messages['content']
_intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}]
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
if isinstance(message, list):
# Extract text from MCP content blocks
for block in message:
if isinstance(block, dict) and block.get('type') == 'text':
message = block.get('text', '')
break
else:
continue # No text block found
optimized_outputs = merge_multiple_search_results(_intermediate_outputs)
result = reorder_output_results(optimized_outputs)
try:
parsed = json.loads(message) if isinstance(message, str) else message
if isinstance(parsed, dict):
if parsed.get('status') == 'success':
summary_result = parsed.get('summary_result')
if summary_result:
final_answer = summary_result
except (json.JSONDecodeError, ValueError):
pass
# Log successful operation
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(
operation="READ",
config_id=config_id,
group_id=group_id,
success=True,
duration=duration
)
# 记录成功的操作
total_duration = time.time() - start_time
# Check for workflow errors
if workflow_errors:
error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors])
logger.warning(f"Read workflow completed with errors: {error_details}")
retrieved_content = []
repo = ShortTermMemoryRepository(db)
if str(search_switch).strip() != "2":
for intermediate in result:
intermediate_type = intermediate['type']
if intermediate_type == "search_result":
query = intermediate['query']
raw_results = intermediate['raw_results']
reranked_results = raw_results.get('reranked_results', [])
try:
statements = [statement['statement'] for statement in
reranked_results.get('statements', [])]
except Exception:
statements = []
statements = list(set(statements))
retrieved_content.append({query: statements})
if retrieved_content == []:
retrieved_content = ''
if '信息不足,无法回答。' != str(summary) and str(search_switch).strip() != "2": # and retrieved_content!=[]
# 使用 upsert 方法
repo.upsert(
end_user_id=end_user_id, # 确保这个变量在作用域内
messages=ori_message,
aimessages=summary,
retrieved_content=retrieved_content,
search_switch=str(search_switch)
)
print("写入成功")
return {
"answer": summary,
"intermediate_outputs": result
}
except Exception as e:
# Ensure proper error handling and logging
error_msg = f"Read operation failed: {str(e)}"
logger.error(error_msg)
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(
operation="READ",
config_id=config_id,
group_id=group_id,
success=False,
duration=total_duration,
error=error_details,
details={
"search_switch": search_switch,
"history_length": len(history),
"intermediate_outputs_count": len(intermediate_outputs),
"has_answer": bool(final_answer),
"errors": workflow_errors
}
duration=duration,
error=error_msg
)
# Raise error if no answer was produced
if not final_answer:
raise ValueError(f"Read workflow failed: {error_details}")
if audit_logger and not workflow_errors:
audit_logger.log_operation(
operation="READ",
config_id=config_id,
group_id=group_id,
success=True,
duration=total_duration,
details={
"search_switch": search_switch,
"history_length": len(history),
"intermediate_outputs_count": len(intermediate_outputs),
"has_answer": bool(final_answer)
}
)
retrieved_content=[]
repo = ShortTermMemoryRepository(db)
if str(search_switch)!="2":
for intermediate in intermediate_outputs:
print(intermediate)
intermediate_type=intermediate['type']
if intermediate_type=="search_result":
query=intermediate['query']
raw_results=intermediate['raw_results']
reranked_results=raw_results.get('reranked_results',[])
try:
statements=[statement['statement'] for statement in reranked_results.get('statements', [])]
except Exception:
statements=[]
statements=list(set(statements))
retrieved_content.append({query:statements})
if retrieved_content==[]:
retrieved_content=''
if '信息不足,无法回答。' != str(final_answer) :#and retrieved_content!=[]
# 使用 upsert 方法
repo.upsert(
end_user_id=end_user_id, # 确保这个变量在作用域内
messages=ori_message,
aimessages=final_answer,
retrieved_content=retrieved_content,
search_switch=str(search_switch)
)
print("写入成功")
return {
"answer": final_answer,
"intermediate_outputs": intermediate_outputs
}
raise ValueError(error_msg)
def _create_intermediate_key(self, output: Dict) -> str:
    """
    Create a unique key for an intermediate output to detect duplicates.

    Known output types map to a short prefix plus the field that identifies
    the output (``search_result`` additionally includes its ``index``).
    Any other type falls back to a serialized form of the whole dictionary.

    Args:
        output: Intermediate output dictionary. Its ``type`` entry selects
            the key format; a missing ``type`` is treated as ``'unknown'``.

    Returns:
        Unique string key for this output.
    """
    output_type = output.get('type', 'unknown')

    if output_type == 'search_result':
        # Several search results can share one query, so the index is part of the key.
        return f"search:{output.get('query', '')}:{output.get('index', 0)}"

    # type -> (key prefix, name of the field that identifies the output)
    key_formats = {
        'problem_split': ('split', 'original_query'),
        'problem_extension': ('extension', 'original_query'),
        'retrieval_summary': ('summary', 'query'),
        'verification': ('verification', 'query'),
        'input_summary': ('input_summary', 'query'),
    }
    # Guard with isinstance so an unhashable ``type`` value cannot break the lookup.
    key_format = key_formats.get(output_type) if isinstance(output_type, str) else None
    if key_format is not None:
        prefix, field_name = key_format
        return f"{prefix}:{output.get(field_name, '')}"

    # Fallback: use the JSON representation of the whole output.
    import json
    try:
        # default=str keeps values that JSON cannot encode (dates, sets, ...)
        # from raising; JSON-serializable outputs produce the same key as before.
        return json.dumps(output, sort_keys=True, default=str)
    except (TypeError, ValueError):
        # Mixed-type or non-primitive keys cannot be sorted/encoded, and circular
        # references raise ValueError; repr() still yields a usable key.
        return repr(output)
def _format_intermediate_output(self, output: Dict) -> Dict:
    """
    Format intermediate output for frontend display.

    For each recognized ``type`` a new dictionary is built that carries a
    display title plus the fields relevant to that type, with defaults
    filled in for missing fields. Outputs of any other type are returned
    unchanged (the same object, not a copy).

    Args:
        output: Intermediate output dictionary; a missing ``type`` is
            treated as ``'unknown'``.

    Returns:
        The formatted dictionary, or ``output`` itself for unrecognized types.
    """
    output_type = output.get('type', 'unknown')

    if output_type == 'problem_split':
        return {
            'type': 'problem_split',
            'title': '问题拆分',
            'data': output.get('data', []),
            'original_query': output.get('original_query', ''),
        }

    if output_type == 'problem_extension':
        return {
            'type': 'problem_extension',
            'title': '问题扩展',
            'data': output.get('data', {}),
            'original_query': output.get('original_query', ''),
        }

    if output_type == 'search_result':
        index = output.get('index', 0)
        total = output.get('total', 0)
        return {
            'type': 'search_result',
            'title': f'检索结果 ({index}/{total})',
            'query': output.get('query', ''),
            # Default to an empty dict rather than '': the read workflow calls
            # raw_results.get('reranked_results', ...) on this value, which
            # raises AttributeError on a string.
            'raw_results': output.get('raw_results', {}),
            'index': index,
            'total': total,
        }

    if output_type == 'retrieval_summary':
        return {
            'type': 'retrieval_summary',
            'title': '检索总结',
            'summary': output.get('summary', ''),
            'query': output.get('query', ''),
            'raw_results': output.get('raw_results'),
        }

    if output_type == 'verification':
        return {
            'type': 'verification',
            'title': '数据验证',
            'result': output.get('result', 'unknown'),
            'reason': output.get('reason', ''),
            'query': output.get('query', ''),
            'verified_count': output.get('verified_count', 0),
        }

    if output_type == 'input_summary':
        return {
            'type': 'input_summary',
            'title': '快速答案',
            'summary': output.get('summary', ''),
            'query': output.get('query', ''),
            'raw_results': output.get('raw_results'),
        }

    # Unrecognized type: pass through untouched.
    return output
async def classify_message_type(self, message: str, config_id: int, db: Session) -> Dict:
"""
@@ -850,6 +683,7 @@ class MemoryAgentService:
# 获取当前空间下的所有宿主
from app.repositories import app_repository, end_user_repository
from app.schemas.app_schema import App as AppSchema
from app.schemas.end_user_schema import EndUser as EndUserSchema
# 查询应用并转换为 Pydantic 模型
apps_orm = app_repository.get_apps_by_workspace_id(db, current_workspace_id)
@@ -1147,43 +981,6 @@ class MemoryAgentService:
logger.info("Log streaming completed, cleaning up resources")
# LogStreamer uses context manager for file handling, so cleanup is automatic
# async def get_api_docs(self, file_path: Optional[str] = None) -> Dict[str, Any]:
# """
# Parse and return API documentation
# Args:
# file_path: Optional path to API docs file. If None, uses default path.
# Returns:
# Dict containing parsed API documentation or error information
# """
# try:
# target = file_path or get_default_docs_path()
# if not os.path.isfile(target):
# return {
# "success": False,
# "msg": "API文档文件不存在",
# "error_code": "DOC_NOT_FOUND",
# "data": {"path": target}
# }
# data = parse_api_docs(target)
# return {
# "success": True,
# "msg": "解析成功",
# "data": data
# }
# except Exception as e:
# logger.error(f"Failed to parse API docs: {e}")
# return {
# "success": False,
# "msg": "解析失败",
# "error_code": "DOC_PARSE_ERROR",
# "data": {"error": str(e)}
# }
def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, Any]:
"""
获取终端用户关联的记忆配置
@@ -1192,20 +989,18 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
1. 根据 end_user_id 获取用户的 app_id
2. 获取该应用的最新发布版本
3. 从发布版本的 config 字段中提取 memory_config_id
4. 根据 memory_config_id 查询配置名称
Args:
end_user_id: 终端用户ID
db: 数据库会话
Returns:
包含 memory_config_id、config_name 和相关信息的字典
包含 memory_config_id 和相关信息的字典
Raises:
ValueError: 当终端用户不存在或应用未发布时
"""
from app.models.app_release_model import AppRelease
from app.models.data_config_model import DataConfig
from app.models.end_user_model import EndUser
from sqlalchemy import select
@@ -1239,31 +1034,15 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
memory_obj = config.get('memory', {})
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
# 4. 根据 memory_config_id 查询配置名称
config_name = None
if memory_config_id:
try:
# memory_config_id 可能是整数或字符串,需要转换
config_id = int(memory_config_id) if isinstance(memory_config_id, str) else memory_config_id
data_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
if data_config:
config_name = data_config.config_name
logger.debug(f"Found config_name: {config_name} for config_id: {config_id}")
else:
logger.warning(f"DataConfig not found for config_id: {config_id}")
except (ValueError, TypeError) as e:
logger.warning(f"Invalid memory_config_id format: {memory_config_id}, error: {str(e)}")
result = {
"end_user_id": str(end_user_id),
"app_id": str(app_id),
"release_id": str(latest_release.id),
"release_version": latest_release.version,
"memory_config_id": memory_config_id,
"memory_config_name": config_name
"memory_config_id": memory_config_id
}
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}, config_name={config_name}")
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}")
return result
@@ -1271,126 +1050,112 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
"""
批量获取多个终端用户关联的记忆配置
通过优化的查询减少数据库往返次数
1. 一次性查询所有 end_user 及其 app_id
2. 批量查询所有相关的 app_release
3. 批量查询所有相关的 data_config
通过以下流程获取配置
1. 批量查询所有 end_user 及其 app_id
2. 批量获取所有应用的最新发布版本
3. 从发布版本的 config 字段中提取 memory_config_id 和 memory_config_name
Args:
end_user_ids: 终端用户ID列表
db: 数据库会话
Returns:
字典key 为 end_user_idvalue 为配置信息字典
对于查询失败的用户value 包含 error 字段
字典key 为 end_user_idvalue 为包含 memory_config_id 和 memory_config_name 的字典
格式: {
"user_id_1": {"memory_config_id": "xxx", "memory_config_name": "xxx"},
"user_id_2": {"memory_config_id": None, "memory_config_name": None},
...
}
"""
from app.models.app_release_model import AppRelease
from app.models.data_config_model import DataConfig
from app.models.end_user_model import EndUser
from app.models.memory_config_model import MemoryConfig
from sqlalchemy import select
logger.info(f"Batch getting connected configs for {len(end_user_ids)} end users")
logger.info(f"Batch getting connected configs for {len(end_user_ids)} end_users")
result = {}
# 1. 批量查询所有 end_user 及其 app_id
end_users = db.query(EndUser).filter(EndUser.id.in_(end_user_ids)).all()
# 建 end_user_id -> end_user 的映射
end_user_map = {str(user.id): user for user in end_users}
# 建 end_user_id 到 app_id 的映射
user_to_app = {str(eu.id): eu.app_id for eu in end_users}
# 记录不存在的用户
for user_id in end_user_ids:
if user_id not in end_user_map:
result[user_id] = {
"end_user_id": user_id,
"memory_config_id": None,
"memory_config_name": None,
"error": f"终端用户不存在: {user_id}"
}
# 获取所有相关的 app_id
app_ids = list(set(user_to_app.values()))
if not end_users:
logger.warning("No valid end users found")
if not app_ids:
logger.warning("No valid app_ids found for the provided end_user_ids")
# 返回空配置
for user_id in end_user_ids:
result[user_id] = {"memory_config_id": None, "memory_config_name": None}
return result
# 2. 批量查询所有相关应用的最新发布版本
app_ids = [user.app_id for user in end_users]
# 2. 批量获取所有应用的最新发布版本
# 使用子查询找到每个 app 的最新版本
from sqlalchemy import and_
from sqlalchemy import func
# 查询所有相关的活跃发布版本
releases = db.query(AppRelease).filter(
and_(
AppRelease.app_id.in_(app_ids),
AppRelease.is_active.is_(True)
subq = (
select(
AppRelease.app_id,
func.max(AppRelease.version).label('max_version')
)
).order_by(AppRelease.app_id, AppRelease.version.desc()).all()
.where(AppRelease.app_id.in_(app_ids), AppRelease.is_active.is_(True))
.group_by(AppRelease.app_id)
.subquery()
)
# 构建 app_id -> latest_release 的映射(每个 app 只保留最新版本)
app_release_map = {}
for release in releases:
app_id_str = str(release.app_id)
if app_id_str not in app_release_map:
app_release_map[app_id_str] = release
stmt = (
select(AppRelease)
.join(
subq,
(AppRelease.app_id == subq.c.app_id) & (AppRelease.version == subq.c.max_version)
)
.where(AppRelease.is_active.is_(True))
)
# 3. 收集所有 memory_config_id
latest_releases = db.scalars(stmt).all()
# 创建 app_id 到 release 的映射
app_to_release = {str(release.app_id): release for release in latest_releases}
# 3. 提取所有 memory_config_id
memory_config_ids = []
for release in app_release_map.values():
for release in latest_releases:
config = release.config or {}
memory_obj = config.get('memory', {})
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
if memory_config_id:
try:
config_id = int(memory_config_id) if isinstance(memory_config_id, str) else memory_config_id
memory_config_ids.append(config_id)
except (ValueError, TypeError):
pass
memory_config_ids.append(memory_config_id)
# 4. 批量查询所有 data_config
config_name_map = {}
# 4. 批量查询 memory_config_name
memory_configs = {}
if memory_config_ids:
data_configs = db.query(DataConfig).filter(
DataConfig.config_id.in_(memory_config_ids)
).all()
config_name_map = {config.config_id: config.config_name for config in data_configs}
configs = db.query(MemoryConfig).filter(MemoryConfig.id.in_(memory_config_ids)).all()
memory_configs = {str(cfg.id): cfg.config_name for cfg in configs}
# 5. 组装结果
for user in end_users:
user_id = str(user.id)
app_id = str(user.app_id)
# 检查是否有发布版本
if app_id not in app_release_map:
result[user_id] = {
"end_user_id": user_id,
"memory_config_id": None,
"memory_config_name": None,
"error": f"应用未发布: {app_id}"
}
for user_id in end_user_ids:
app_id = user_to_app.get(user_id)
if not app_id:
result[user_id] = {"memory_config_id": None, "memory_config_name": None}
continue
release = app_release_map[app_id]
release = app_to_release.get(str(app_id))
if not release:
result[user_id] = {"memory_config_id": None, "memory_config_name": None}
continue
# 提取 memory_config_id
config = release.config or {}
memory_obj = config.get('memory', {})
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
# 获取 config_name
config_name = None
if memory_config_id:
try:
config_id = int(memory_config_id) if isinstance(memory_config_id, str) else memory_config_id
config_name = config_name_map.get(config_id)
except (ValueError, TypeError):
pass
memory_config_name = memory_configs.get(memory_config_id) if memory_config_id else None
result[user_id] = {
"end_user_id": user_id,
"memory_config_id": memory_config_id,
"memory_config_name": config_name
"memory_config_name": memory_config_name
}
logger.info(f"Successfully retrieved batch configs: total={len(result)}, with_config={sum(1 for v in result.values() if v.get('memory_config_id'))}")
logger.info(f"Successfully retrieved {len(result)} connected configs")
return result