fix(db): fix database connection leak
This commit is contained in:
@@ -13,7 +13,6 @@ TODO: Refactor get_end_user_connected_config
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
@@ -35,12 +34,10 @@ from app.core.memory.agent.utils.messages_tools import (
|
||||
reorder_output_results,
|
||||
)
|
||||
from app.core.memory.agent.utils.type_classifier import status_typle
|
||||
from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags, get_interest_distribution
|
||||
from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.models.knowledge_model import Knowledge, KnowledgeType
|
||||
from app.repositories.memory_short_repository import ShortTermMemoryRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_agent_schema import Write_UserInput
|
||||
from app.schemas.memory_config_schema import ConfigurationError
|
||||
@@ -69,7 +66,8 @@ class MemoryAgentService:
|
||||
logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}")
|
||||
# 记录成功的操作
|
||||
if audit_logger:
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=True,
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
||||
success=True,
|
||||
duration=duration, details={"message_length": len(message)})
|
||||
return context
|
||||
else:
|
||||
@@ -88,8 +86,6 @@ class MemoryAgentService:
|
||||
|
||||
raise ValueError(f"写入失败: {messages}")
|
||||
|
||||
|
||||
|
||||
def extract_tool_call_info(self, event: Dict) -> bool:
|
||||
"""Extract tool call information from event"""
|
||||
last_message = event["messages"][-1]
|
||||
@@ -271,7 +267,8 @@ class MemoryAgentService:
|
||||
logger.info("Log streaming completed, cleaning up resources")
|
||||
# LogStreamer uses context manager for file handling, so cleanup is automatic
|
||||
|
||||
async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID]|int, db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str:
|
||||
async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID] | int,
|
||||
db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str:
|
||||
"""
|
||||
Process write operation with config_id
|
||||
|
||||
@@ -300,7 +297,8 @@ class MemoryAgentService:
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
|
||||
if config_id is None and workspace_id is None:
|
||||
raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
|
||||
raise ValueError(
|
||||
f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
|
||||
except Exception as e:
|
||||
if "No memory configuration found" in str(e):
|
||||
raise # Re-raise our specific error
|
||||
@@ -331,7 +329,8 @@ class MemoryAgentService:
|
||||
# Log failed operation
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
||||
success=False, duration=duration, error=error_msg)
|
||||
|
||||
raise ValueError(error_msg)
|
||||
|
||||
@@ -351,9 +350,9 @@ class MemoryAgentService:
|
||||
langchain_messages.append(HumanMessage(content=msg['content']))
|
||||
elif msg['role'] == 'assistant':
|
||||
langchain_messages.append(AIMessage(content=msg['content']))
|
||||
print(100*'-')
|
||||
print(100 * '-')
|
||||
print(langchain_messages)
|
||||
print(100*'-')
|
||||
print(100 * '-')
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {
|
||||
"messages": langchain_messages,
|
||||
@@ -375,29 +374,28 @@ class MemoryAgentService:
|
||||
contents = massages.get('write_result')
|
||||
# Convert messages back to string for logging
|
||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents)
|
||||
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text,
|
||||
contents)
|
||||
except Exception as e:
|
||||
# Ensure proper error handling and logging
|
||||
error_msg = f"Write operation failed: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
||||
success=False, duration=duration, error=error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
|
||||
|
||||
|
||||
async def read_memory(
|
||||
self,
|
||||
end_user_id: str,
|
||||
message: str,
|
||||
history: List[Dict],
|
||||
search_switch: str,
|
||||
config_id: Optional[uuid.UUID]|int,
|
||||
db: Session,
|
||||
storage_type: str,
|
||||
user_rag_memory_id: str) -> Dict:
|
||||
self,
|
||||
end_user_id: str,
|
||||
message: str,
|
||||
history: List[Dict],
|
||||
search_switch: str,
|
||||
config_id: Optional[uuid.UUID] | int,
|
||||
db: Session,
|
||||
storage_type: str,
|
||||
user_rag_memory_id: str) -> Dict:
|
||||
"""
|
||||
Process read operation with config_id
|
||||
|
||||
@@ -425,7 +423,7 @@ class MemoryAgentService:
|
||||
|
||||
import time
|
||||
start_time = time.time()
|
||||
ori_message= message
|
||||
ori_message = message
|
||||
|
||||
# Resolve config_id and workspace_id
|
||||
# Always get workspace_id from end_user for fallback, even if config_id is provided
|
||||
@@ -437,7 +435,8 @@ class MemoryAgentService:
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
|
||||
if config_id is None and workspace_id is None:
|
||||
raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
|
||||
raise ValueError(
|
||||
f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
|
||||
except Exception as e:
|
||||
if "No memory configuration found" in str(e):
|
||||
raise # Re-raise our specific error
|
||||
@@ -454,7 +453,6 @@ class MemoryAgentService:
|
||||
except ImportError:
|
||||
audit_logger = None
|
||||
|
||||
|
||||
config_load_start = time.time()
|
||||
try:
|
||||
# Use a separate database session to avoid transaction failures
|
||||
@@ -562,34 +560,35 @@ class MemoryAgentService:
|
||||
from app.repositories.memory_short_repository import (
|
||||
ShortTermMemoryRepository,
|
||||
)
|
||||
|
||||
|
||||
retrieved_content = []
|
||||
repo = ShortTermMemoryRepository(db)
|
||||
|
||||
|
||||
if str(search_switch) != "2":
|
||||
for intermediate in _intermediate_outputs:
|
||||
logger.debug(f"处理中间结果: {intermediate}")
|
||||
intermediate_type = intermediate.get('type', '')
|
||||
|
||||
|
||||
if intermediate_type == "search_result":
|
||||
query = intermediate.get('query', '')
|
||||
raw_results = intermediate.get('raw_results', {})
|
||||
try:
|
||||
reranked_results = raw_results.get('reranked_results', [])
|
||||
statements = [statement['statement'] for statement in reranked_results.get('statements', [])]
|
||||
statements = [statement['statement'] for statement in
|
||||
reranked_results.get('statements', [])]
|
||||
except Exception:
|
||||
statements = []
|
||||
|
||||
|
||||
# 去重
|
||||
statements = list(set(statements))
|
||||
|
||||
|
||||
if query and statements:
|
||||
retrieved_content.append({query: statements})
|
||||
|
||||
|
||||
# 如果 retrieved_content 为空,设置为空字符串
|
||||
if retrieved_content == []:
|
||||
retrieved_content = ''
|
||||
|
||||
|
||||
# 只有当回答不是"信息不足"且不是快速检索时才保存
|
||||
if '信息不足,无法回答。' != str(summary) and str(search_switch).strip() != "2":
|
||||
# 使用 upsert 方法
|
||||
@@ -602,15 +601,17 @@ class MemoryAgentService:
|
||||
)
|
||||
logger.info(f"成功保存短期记忆: end_user_id={end_user_id}, search_switch={search_switch}")
|
||||
else:
|
||||
logger.debug(f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}")
|
||||
|
||||
logger.debug(
|
||||
f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}")
|
||||
|
||||
except Exception as save_error:
|
||||
# 保存失败不应该影响主流程,只记录错误
|
||||
logger.error(f"保存短期记忆失败: {str(save_error)}", exc_info=True)
|
||||
|
||||
# Log successful operation
|
||||
total_time = time.time() - start_time
|
||||
logger.info(f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
|
||||
logger.info(
|
||||
f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(
|
||||
@@ -641,7 +642,6 @@ class MemoryAgentService:
|
||||
)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
|
||||
def get_messages_list(self, user_input: Write_UserInput) -> list[dict]:
|
||||
"""
|
||||
Get standardized message list from user input.
|
||||
@@ -657,41 +657,43 @@ class MemoryAgentService:
|
||||
"""
|
||||
from app.core.logging_config import get_api_logger
|
||||
logger = get_api_logger()
|
||||
|
||||
|
||||
if len(user_input.messages) == 0:
|
||||
logger.error("Validation failed: Message list cannot be empty")
|
||||
raise ValueError("Message list cannot be empty")
|
||||
|
||||
|
||||
for idx, msg in enumerate(user_input.messages):
|
||||
if not isinstance(msg, dict):
|
||||
logger.error(f"Validation failed: Message {idx} is not a dict: {type(msg)}")
|
||||
raise ValueError(f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}")
|
||||
|
||||
raise ValueError(
|
||||
f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}")
|
||||
|
||||
if 'role' not in msg:
|
||||
logger.error(f"Validation failed: Message {idx} missing 'role' field: {msg}")
|
||||
raise ValueError(f"Message format error: Message must contain 'role' field. Error message index: {idx}")
|
||||
|
||||
|
||||
if 'content' not in msg:
|
||||
logger.error(f"Validation failed: Message {idx} missing 'content' field: {msg}")
|
||||
raise ValueError(f"Message format error: Message must contain 'content' field. Error message index: {idx}")
|
||||
|
||||
raise ValueError(
|
||||
f"Message format error: Message must contain 'content' field. Error message index: {idx}")
|
||||
|
||||
if msg['role'] not in ['user', 'assistant']:
|
||||
logger.error(f"Validation failed: Message {idx} invalid role: {msg['role']}")
|
||||
raise ValueError(f"Role must be 'user' or 'assistant', got: {msg['role']}. Message index: {idx}")
|
||||
|
||||
|
||||
if not msg['content'] or not msg['content'].strip():
|
||||
logger.error(f"Validation failed: Message {idx} content is empty")
|
||||
raise ValueError(f"Message content cannot be empty. Message index: {idx}, role: {msg['role']}")
|
||||
|
||||
|
||||
logger.info(f"Validation successful: Structured message list, count: {len(user_input.messages)}")
|
||||
return user_input.messages
|
||||
|
||||
async def classify_message_type(
|
||||
self,
|
||||
message: str,
|
||||
config_id: UUID,
|
||||
db: Session,
|
||||
workspace_id: Optional[UUID] = None
|
||||
self,
|
||||
message: str,
|
||||
config_id: UUID,
|
||||
db: Session,
|
||||
workspace_id: Optional[UUID] = None
|
||||
) -> Dict:
|
||||
"""
|
||||
Determine the type of user message (read or write)
|
||||
@@ -719,14 +721,15 @@ class MemoryAgentService:
|
||||
status = await status_typle(message, memory_config.llm_model_id)
|
||||
logger.debug(f"Message type: {status}")
|
||||
return status
|
||||
|
||||
async def generate_summary_from_retrieve(
|
||||
self,
|
||||
end_user_id: str,
|
||||
retrieve_info: str,
|
||||
history: List[Dict],
|
||||
query: str,
|
||||
config_id: str,
|
||||
db: Session
|
||||
self,
|
||||
end_user_id: str,
|
||||
retrieve_info: str,
|
||||
history: List[Dict],
|
||||
query: str,
|
||||
config_id: str,
|
||||
db: Session
|
||||
) -> str:
|
||||
"""
|
||||
基于检索信息、历史对话和查询生成最终答案
|
||||
@@ -761,9 +764,9 @@ class MemoryAgentService:
|
||||
if config_id is None:
|
||||
raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}")
|
||||
# If config_id was provided, continue without workspace_id fallback
|
||||
|
||||
|
||||
logger.info(f"Generating summary from retrieve info for query: {query[:50]}...")
|
||||
|
||||
|
||||
try:
|
||||
# 加载配置
|
||||
config_service = MemoryConfigService(db)
|
||||
@@ -772,7 +775,7 @@ class MemoryAgentService:
|
||||
workspace_id=workspace_id,
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
|
||||
|
||||
# 导入必要的模块
|
||||
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import (
|
||||
summary_llm,
|
||||
@@ -780,13 +783,13 @@ class MemoryAgentService:
|
||||
from app.core.memory.agent.models.summary_models import (
|
||||
RetrieveSummaryResponse,
|
||||
)
|
||||
|
||||
|
||||
# 构建状态对象
|
||||
state = {
|
||||
"data": query,
|
||||
"memory_config": memory_config
|
||||
}
|
||||
|
||||
|
||||
# 直接调用 summary_llm 函数
|
||||
answer = await summary_llm(
|
||||
state=state,
|
||||
@@ -797,21 +800,20 @@ class MemoryAgentService:
|
||||
response_model=RetrieveSummaryResponse,
|
||||
search_mode="1"
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"Successfully generated summary: {answer[:100] if answer else 'None'}...")
|
||||
return answer if answer else "信息不足,无法回答。"
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成摘要失败: {str(e)}", exc_info=True)
|
||||
return "信息不足,无法回答。"
|
||||
|
||||
|
||||
async def get_knowledge_type_stats(
|
||||
self,
|
||||
end_user_id: Optional[str] = None,
|
||||
only_active: bool = True,
|
||||
current_workspace_id: Optional[uuid.UUID] = None,
|
||||
db: Session = None
|
||||
self,
|
||||
db: Session,
|
||||
end_user_id: Optional[str] = None,
|
||||
only_active: bool = True,
|
||||
current_workspace_id: Optional[uuid.UUID] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
统计知识库类型分布,包含:
|
||||
@@ -837,11 +839,6 @@ class MemoryAgentService:
|
||||
|
||||
# 1. 统计 PostgreSQL 中的知识库类型
|
||||
try:
|
||||
if db is None:
|
||||
from app.db import get_db
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
|
||||
# 初始化所有标准类型为 0
|
||||
for kb_type in KnowledgeType:
|
||||
result[kb_type.value] = 0
|
||||
@@ -881,21 +878,19 @@ class MemoryAgentService:
|
||||
|
||||
# 3. 计算知识库类型总和(不包括 memory)
|
||||
result["total"] = (
|
||||
result.get("General", 0) +
|
||||
result.get("Web", 0) +
|
||||
result.get("Third-party", 0) +
|
||||
result.get("Folder", 0)
|
||||
result.get("General", 0) +
|
||||
result.get("Web", 0) +
|
||||
result.get("Third-party", 0) +
|
||||
result.get("Folder", 0)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
||||
async def get_interest_distribution_by_user(
|
||||
self,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 5,
|
||||
language: str = "zh"
|
||||
self,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 5,
|
||||
language: str = "zh"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定用户的兴趣分布标签。
|
||||
@@ -921,13 +916,12 @@ class MemoryAgentService:
|
||||
logger.error(f"兴趣分布标签查询失败: {e}")
|
||||
raise Exception(f"兴趣分布标签查询失败: {e}")
|
||||
|
||||
|
||||
async def get_user_profile(
|
||||
self,
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user_id: Optional[str] = None,
|
||||
llm_id: Optional[str] = None,
|
||||
db: Session = None
|
||||
self,
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user_id: Optional[str] = None,
|
||||
llm_id: Optional[str] = None,
|
||||
db: Session = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取用户详情,包含:
|
||||
@@ -1017,7 +1011,8 @@ class MemoryAgentService:
|
||||
|
||||
# 定义标签提取的结构
|
||||
class UserTags(BaseModel):
|
||||
tags: list[str] = Field(..., description="3个描述用户特征的标签,如:产品设计师、旅行爱好者、摄影发烧友")
|
||||
tags: list[str] = Field(...,
|
||||
description="3个描述用户特征的标签,如:产品设计师、旅行爱好者、摄影发烧友")
|
||||
|
||||
messages = [
|
||||
{
|
||||
@@ -1160,7 +1155,6 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
||||
ValueError: 当终端用户不存在或应用未发布时
|
||||
"""
|
||||
import json as json_module
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
@@ -1192,14 +1186,14 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
||||
|
||||
# 3. 兼容旧数据:如果 memory_config_id 为空,从 AppRelease.config 获取并回填
|
||||
memory_config_id_to_use = end_user.memory_config_id
|
||||
|
||||
|
||||
# 如果已有 memory_config_id,直接使用
|
||||
# 如果新创建enduser,enduser.memory_config_id 必定为none
|
||||
# 那么使用从release中获取memory_config_id为预期行为,并且回填到
|
||||
# end_user.memory_config_id
|
||||
if not memory_config_id_to_use:
|
||||
logger.info(f"end_user.memory_config_id is None, migrating from AppRelease.config")
|
||||
|
||||
|
||||
# 获取最新发布版本
|
||||
stmt = (
|
||||
select(AppRelease)
|
||||
@@ -1208,10 +1202,10 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
||||
)
|
||||
# TODO: change to current_release_id
|
||||
latest_release = db.scalars(stmt).first()
|
||||
|
||||
|
||||
if latest_release:
|
||||
config = latest_release.config or {}
|
||||
|
||||
|
||||
# 如果 config 是字符串,解析为字典
|
||||
if isinstance(config, str):
|
||||
try:
|
||||
@@ -1219,22 +1213,22 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
||||
except json_module.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse config JSON for release {latest_release.id}")
|
||||
config = {}
|
||||
|
||||
|
||||
# 使用 MemoryConfigService 的提取方法
|
||||
memory_config_service = MemoryConfigService(db)
|
||||
legacy_config_id, is_legacy_int = memory_config_service.extract_memory_config_id(
|
||||
app_type=app.type,
|
||||
config=config
|
||||
)
|
||||
|
||||
|
||||
if legacy_config_id:
|
||||
# 验证提取的 config_id 是否存在于数据库中
|
||||
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
|
||||
existing_config = db.get(MemoryConfigModel, legacy_config_id)
|
||||
|
||||
|
||||
if existing_config:
|
||||
memory_config_id_to_use = legacy_config_id
|
||||
|
||||
|
||||
# 回填到 end_user 表(lazy update)
|
||||
end_user.memory_config_id = memory_config_id_to_use
|
||||
db.commit()
|
||||
@@ -1268,7 +1262,8 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
||||
"workspace_id": str(app.workspace_id)
|
||||
}
|
||||
|
||||
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={app.workspace_id}")
|
||||
logger.info(
|
||||
f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={app.workspace_id}")
|
||||
return result
|
||||
|
||||
|
||||
@@ -1312,7 +1307,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
||||
|
||||
# 1. 批量查询所有 end_user 及其 app_id 和 memory_config_id
|
||||
end_users = db.query(EndUser).filter(EndUser.id.in_(end_user_ids)).all()
|
||||
|
||||
|
||||
# 创建映射 - 保留 EndUser 对象引用以便回填
|
||||
end_user_map = {str(eu.id): eu for eu in end_users}
|
||||
user_data = {str(eu.id): {"app_id": eu.app_id, "memory_config_id": eu.memory_config_id} for eu in end_users}
|
||||
@@ -1336,15 +1331,15 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
||||
|
||||
# 3. 对于没有 memory_config_id 的用户,尝试从 AppRelease.config 提取
|
||||
users_needing_migration = [
|
||||
(end_user_id, data["app_id"])
|
||||
for end_user_id, data in user_data.items()
|
||||
(end_user_id, data["app_id"])
|
||||
for end_user_id, data in user_data.items()
|
||||
if not data["memory_config_id"]
|
||||
]
|
||||
|
||||
|
||||
if users_needing_migration:
|
||||
# 批量获取相关应用的最新发布版本
|
||||
migration_app_ids = list(set(app_id for _, app_id in users_needing_migration))
|
||||
|
||||
|
||||
# 查询每个应用的最新活跃发布版本
|
||||
app_latest_releases = {}
|
||||
for app_id in migration_app_ids:
|
||||
@@ -1357,18 +1352,18 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
||||
latest_release = db.scalars(stmt).first()
|
||||
if latest_release:
|
||||
app_latest_releases[app_id] = latest_release
|
||||
|
||||
|
||||
# 为每个需要迁移的用户提取 memory_config_id
|
||||
config_service = MemoryConfigService(db)
|
||||
users_to_backfill = [] # [(end_user, memory_config_id), ...]
|
||||
|
||||
|
||||
for end_user_id, app_id in users_needing_migration:
|
||||
latest_release = app_latest_releases.get(app_id)
|
||||
if not latest_release:
|
||||
continue
|
||||
|
||||
|
||||
config = latest_release.config or {}
|
||||
|
||||
|
||||
# 如果 config 是字符串,解析为字典
|
||||
if isinstance(config, str):
|
||||
try:
|
||||
@@ -1376,21 +1371,21 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
||||
except json_module.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse config JSON for release {latest_release.id}")
|
||||
continue
|
||||
|
||||
|
||||
# 使用 MemoryConfigService 的提取方法
|
||||
app = app_map.get(app_id)
|
||||
if not app:
|
||||
continue
|
||||
|
||||
|
||||
legacy_config_id, is_legacy_int = config_service.extract_memory_config_id(
|
||||
app_type=app.type,
|
||||
config=config
|
||||
)
|
||||
|
||||
|
||||
if legacy_config_id:
|
||||
# 更新 user_data 中的 memory_config_id
|
||||
user_data[end_user_id]["memory_config_id"] = legacy_config_id
|
||||
|
||||
|
||||
# 记录需要回填的用户(稍后验证配置存在后再回填)
|
||||
end_user = end_user_map.get(end_user_id)
|
||||
if end_user:
|
||||
@@ -1399,7 +1394,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
||||
logger.info(
|
||||
f"Legacy int config detected for end_user {end_user_id}, will use workspace default"
|
||||
)
|
||||
|
||||
|
||||
# 验证提取的 config_id 是否存在于数据库中
|
||||
if users_to_backfill:
|
||||
config_ids_to_validate = list(set(cid for _, cid in users_to_backfill))
|
||||
@@ -1407,17 +1402,17 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
||||
MemoryConfig.config_id.in_(config_ids_to_validate)
|
||||
).all()
|
||||
valid_config_ids = {mc.config_id for mc in existing_configs}
|
||||
|
||||
|
||||
# 只回填存在的配置
|
||||
valid_backfills = [
|
||||
(eu, cid) for eu, cid in users_to_backfill
|
||||
(eu, cid) for eu, cid in users_to_backfill
|
||||
if cid in valid_config_ids
|
||||
]
|
||||
invalid_backfills = [
|
||||
(eu, cid) for eu, cid in users_to_backfill
|
||||
(eu, cid) for eu, cid in users_to_backfill
|
||||
if cid not in valid_config_ids
|
||||
]
|
||||
|
||||
|
||||
if invalid_backfills:
|
||||
invalid_ids = [str(cid) for _, cid in invalid_backfills]
|
||||
logger.warning(
|
||||
@@ -1426,7 +1421,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
||||
# 清除 user_data 中无效的 config_id
|
||||
for eu, cid in invalid_backfills:
|
||||
user_data[str(eu.id)]["memory_config_id"] = None
|
||||
|
||||
|
||||
# 批量回填 end_user.memory_config_id
|
||||
if valid_backfills:
|
||||
for end_user, memory_config_id in valid_backfills:
|
||||
@@ -1437,7 +1432,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
||||
# 4. 收集需要查询的 memory_config_id 和需要回退的 workspace_id
|
||||
direct_config_ids = []
|
||||
workspace_fallback_users = [] # [(end_user_id, workspace_id), ...]
|
||||
|
||||
|
||||
for end_user_id, data in user_data.items():
|
||||
if data["memory_config_id"]:
|
||||
direct_config_ids.append(data["memory_config_id"])
|
||||
@@ -1455,7 +1450,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
||||
# 6. 获取工作空间默认配置(需要逐个查询,因为 get_workspace_default_config 有复杂逻辑)
|
||||
workspace_default_configs = {}
|
||||
unique_workspace_ids = list(set(ws_id for _, ws_id in workspace_fallback_users))
|
||||
|
||||
|
||||
if unique_workspace_ids:
|
||||
config_service = MemoryConfigService(db)
|
||||
for workspace_id in unique_workspace_ids:
|
||||
@@ -1466,11 +1461,11 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
||||
# 7. 构建最终结果
|
||||
for end_user_id, data in user_data.items():
|
||||
memory_config = None
|
||||
|
||||
|
||||
# 优先使用 end_user 直接分配的配置
|
||||
if data["memory_config_id"]:
|
||||
memory_config = config_id_to_config.get(data["memory_config_id"])
|
||||
|
||||
|
||||
# 回退到工作空间默认配置
|
||||
if not memory_config:
|
||||
workspace_id = app_to_workspace.get(data["app_id"])
|
||||
@@ -1486,4 +1481,4 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
||||
result[end_user_id] = {"memory_config_id": None, "memory_config_name": None}
|
||||
|
||||
logger.info(f"Successfully retrieved {len(result)} connected configs")
|
||||
return result
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user