fix(db): fix database connection leak

This commit is contained in:
Eternity
2026-03-06 10:12:21 +08:00
parent f90e102854
commit aaa0410781
12 changed files with 505 additions and 566 deletions

View File

@@ -13,7 +13,6 @@ TODO: Refactor get_end_user_connected_config
"""
import json
import os
import re
import time
import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional
@@ -35,12 +34,10 @@ from app.core.memory.agent.utils.messages_tools import (
reorder_output_results,
)
from app.core.memory.agent.utils.type_classifier import status_typle
from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags, get_interest_distribution
from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType
from app.repositories.memory_short_repository import ShortTermMemoryRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_agent_schema import Write_UserInput
from app.schemas.memory_config_schema import ConfigurationError
@@ -69,7 +66,8 @@ class MemoryAgentService:
logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}")
# 记录成功的操作
if audit_logger:
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=True,
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
success=True,
duration=duration, details={"message_length": len(message)})
return context
else:
@@ -88,8 +86,6 @@ class MemoryAgentService:
raise ValueError(f"写入失败: {messages}")
def extract_tool_call_info(self, event: Dict) -> bool:
"""Extract tool call information from event"""
last_message = event["messages"][-1]
@@ -271,7 +267,8 @@ class MemoryAgentService:
logger.info("Log streaming completed, cleaning up resources")
# LogStreamer uses context manager for file handling, so cleanup is automatic
async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID]|int, db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str:
async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID] | int,
db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str:
"""
Process write operation with config_id
@@ -300,7 +297,8 @@ class MemoryAgentService:
config_id = connected_config.get("memory_config_id")
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
if config_id is None and workspace_id is None:
raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
raise ValueError(
f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
except Exception as e:
if "No memory configuration found" in str(e):
raise # Re-raise our specific error
@@ -331,7 +329,8 @@ class MemoryAgentService:
# Log failed operation
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
success=False, duration=duration, error=error_msg)
raise ValueError(error_msg)
@@ -351,9 +350,9 @@ class MemoryAgentService:
langchain_messages.append(HumanMessage(content=msg['content']))
elif msg['role'] == 'assistant':
langchain_messages.append(AIMessage(content=msg['content']))
print(100*'-')
print(100 * '-')
print(langchain_messages)
print(100*'-')
print(100 * '-')
# 初始状态 - 包含所有必要字段
initial_state = {
"messages": langchain_messages,
@@ -375,29 +374,28 @@ class MemoryAgentService:
contents = massages.get('write_result')
# Convert messages back to string for logging
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents)
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text,
contents)
except Exception as e:
# Ensure proper error handling and logging
error_msg = f"Write operation failed: {str(e)}"
logger.error(error_msg)
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
success=False, duration=duration, error=error_msg)
raise ValueError(error_msg)
async def read_memory(
self,
end_user_id: str,
message: str,
history: List[Dict],
search_switch: str,
config_id: Optional[uuid.UUID]|int,
db: Session,
storage_type: str,
user_rag_memory_id: str) -> Dict:
self,
end_user_id: str,
message: str,
history: List[Dict],
search_switch: str,
config_id: Optional[uuid.UUID] | int,
db: Session,
storage_type: str,
user_rag_memory_id: str) -> Dict:
"""
Process read operation with config_id
@@ -425,7 +423,7 @@ class MemoryAgentService:
import time
start_time = time.time()
ori_message= message
ori_message = message
# Resolve config_id and workspace_id
# Always get workspace_id from end_user for fallback, even if config_id is provided
@@ -437,7 +435,8 @@ class MemoryAgentService:
config_id = connected_config.get("memory_config_id")
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
if config_id is None and workspace_id is None:
raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
raise ValueError(
f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
except Exception as e:
if "No memory configuration found" in str(e):
raise # Re-raise our specific error
@@ -454,7 +453,6 @@ class MemoryAgentService:
except ImportError:
audit_logger = None
config_load_start = time.time()
try:
# Use a separate database session to avoid transaction failures
@@ -562,34 +560,35 @@ class MemoryAgentService:
from app.repositories.memory_short_repository import (
ShortTermMemoryRepository,
)
retrieved_content = []
repo = ShortTermMemoryRepository(db)
if str(search_switch) != "2":
for intermediate in _intermediate_outputs:
logger.debug(f"处理中间结果: {intermediate}")
intermediate_type = intermediate.get('type', '')
if intermediate_type == "search_result":
query = intermediate.get('query', '')
raw_results = intermediate.get('raw_results', {})
try:
reranked_results = raw_results.get('reranked_results', [])
statements = [statement['statement'] for statement in reranked_results.get('statements', [])]
statements = [statement['statement'] for statement in
reranked_results.get('statements', [])]
except Exception:
statements = []
# 去重
statements = list(set(statements))
if query and statements:
retrieved_content.append({query: statements})
# 如果 retrieved_content 为空,设置为空字符串
if retrieved_content == []:
retrieved_content = ''
# 只有当回答不是"信息不足"且不是快速检索时才保存
if '信息不足,无法回答。' != str(summary) and str(search_switch).strip() != "2":
# 使用 upsert 方法
@@ -602,15 +601,17 @@ class MemoryAgentService:
)
logger.info(f"成功保存短期记忆: end_user_id={end_user_id}, search_switch={search_switch}")
else:
logger.debug(f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}")
logger.debug(
f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}")
except Exception as save_error:
# 保存失败不应该影响主流程,只记录错误
logger.error(f"保存短期记忆失败: {str(save_error)}", exc_info=True)
# Log successful operation
total_time = time.time() - start_time
logger.info(f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
logger.info(
f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(
@@ -641,7 +642,6 @@ class MemoryAgentService:
)
raise ValueError(error_msg)
def get_messages_list(self, user_input: Write_UserInput) -> list[dict]:
"""
Get standardized message list from user input.
@@ -657,41 +657,43 @@ class MemoryAgentService:
"""
from app.core.logging_config import get_api_logger
logger = get_api_logger()
if len(user_input.messages) == 0:
logger.error("Validation failed: Message list cannot be empty")
raise ValueError("Message list cannot be empty")
for idx, msg in enumerate(user_input.messages):
if not isinstance(msg, dict):
logger.error(f"Validation failed: Message {idx} is not a dict: {type(msg)}")
raise ValueError(f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}")
raise ValueError(
f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}")
if 'role' not in msg:
logger.error(f"Validation failed: Message {idx} missing 'role' field: {msg}")
raise ValueError(f"Message format error: Message must contain 'role' field. Error message index: {idx}")
if 'content' not in msg:
logger.error(f"Validation failed: Message {idx} missing 'content' field: {msg}")
raise ValueError(f"Message format error: Message must contain 'content' field. Error message index: {idx}")
raise ValueError(
f"Message format error: Message must contain 'content' field. Error message index: {idx}")
if msg['role'] not in ['user', 'assistant']:
logger.error(f"Validation failed: Message {idx} invalid role: {msg['role']}")
raise ValueError(f"Role must be 'user' or 'assistant', got: {msg['role']}. Message index: {idx}")
if not msg['content'] or not msg['content'].strip():
logger.error(f"Validation failed: Message {idx} content is empty")
raise ValueError(f"Message content cannot be empty. Message index: {idx}, role: {msg['role']}")
logger.info(f"Validation successful: Structured message list, count: {len(user_input.messages)}")
return user_input.messages
async def classify_message_type(
self,
message: str,
config_id: UUID,
db: Session,
workspace_id: Optional[UUID] = None
self,
message: str,
config_id: UUID,
db: Session,
workspace_id: Optional[UUID] = None
) -> Dict:
"""
Determine the type of user message (read or write)
@@ -719,14 +721,15 @@ class MemoryAgentService:
status = await status_typle(message, memory_config.llm_model_id)
logger.debug(f"Message type: {status}")
return status
async def generate_summary_from_retrieve(
self,
end_user_id: str,
retrieve_info: str,
history: List[Dict],
query: str,
config_id: str,
db: Session
self,
end_user_id: str,
retrieve_info: str,
history: List[Dict],
query: str,
config_id: str,
db: Session
) -> str:
"""
基于检索信息、历史对话和查询生成最终答案
@@ -761,9 +764,9 @@ class MemoryAgentService:
if config_id is None:
raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}")
# If config_id was provided, continue without workspace_id fallback
logger.info(f"Generating summary from retrieve info for query: {query[:50]}...")
try:
# 加载配置
config_service = MemoryConfigService(db)
@@ -772,7 +775,7 @@ class MemoryAgentService:
workspace_id=workspace_id,
service_name="MemoryAgentService"
)
# 导入必要的模块
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import (
summary_llm,
@@ -780,13 +783,13 @@ class MemoryAgentService:
from app.core.memory.agent.models.summary_models import (
RetrieveSummaryResponse,
)
# 构建状态对象
state = {
"data": query,
"memory_config": memory_config
}
# 直接调用 summary_llm 函数
answer = await summary_llm(
state=state,
@@ -797,21 +800,20 @@ class MemoryAgentService:
response_model=RetrieveSummaryResponse,
search_mode="1"
)
logger.info(f"Successfully generated summary: {answer[:100] if answer else 'None'}...")
return answer if answer else "信息不足,无法回答。"
except Exception as e:
logger.error(f"生成摘要失败: {str(e)}", exc_info=True)
return "信息不足,无法回答。"
async def get_knowledge_type_stats(
self,
end_user_id: Optional[str] = None,
only_active: bool = True,
current_workspace_id: Optional[uuid.UUID] = None,
db: Session = None
self,
db: Session,
end_user_id: Optional[str] = None,
only_active: bool = True,
current_workspace_id: Optional[uuid.UUID] = None
) -> Dict[str, Any]:
"""
统计知识库类型分布,包含:
@@ -837,11 +839,6 @@ class MemoryAgentService:
# 1. 统计 PostgreSQL 中的知识库类型
try:
if db is None:
from app.db import get_db
db_gen = get_db()
db = next(db_gen)
# 初始化所有标准类型为 0
for kb_type in KnowledgeType:
result[kb_type.value] = 0
@@ -881,21 +878,19 @@ class MemoryAgentService:
# 3. 计算知识库类型总和(不包括 memory
result["total"] = (
result.get("General", 0) +
result.get("Web", 0) +
result.get("Third-party", 0) +
result.get("Folder", 0)
result.get("General", 0) +
result.get("Web", 0) +
result.get("Third-party", 0) +
result.get("Folder", 0)
)
return result
async def get_interest_distribution_by_user(
self,
end_user_id: Optional[str] = None,
limit: int = 5,
language: str = "zh"
self,
end_user_id: Optional[str] = None,
limit: int = 5,
language: str = "zh"
) -> List[Dict[str, Any]]:
"""
获取指定用户的兴趣分布标签。
@@ -921,13 +916,12 @@ class MemoryAgentService:
logger.error(f"兴趣分布标签查询失败: {e}")
raise Exception(f"兴趣分布标签查询失败: {e}")
async def get_user_profile(
self,
end_user_id: Optional[str] = None,
current_user_id: Optional[str] = None,
llm_id: Optional[str] = None,
db: Session = None
self,
end_user_id: Optional[str] = None,
current_user_id: Optional[str] = None,
llm_id: Optional[str] = None,
db: Session = None
) -> Dict[str, Any]:
"""
获取用户详情,包含:
@@ -1017,7 +1011,8 @@ class MemoryAgentService:
# 定义标签提取的结构
class UserTags(BaseModel):
tags: list[str] = Field(..., description="3个描述用户特征的标签产品设计师、旅行爱好者、摄影发烧友")
tags: list[str] = Field(...,
description="3个描述用户特征的标签产品设计师、旅行爱好者、摄影发烧友")
messages = [
{
@@ -1160,7 +1155,6 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
ValueError: 当终端用户不存在或应用未发布时
"""
import json as json_module
import uuid
from sqlalchemy import select
@@ -1192,14 +1186,14 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
# 3. 兼容旧数据:如果 memory_config_id 为空,从 AppRelease.config 获取并回填
memory_config_id_to_use = end_user.memory_config_id
# 如果已有 memory_config_id直接使用
# 如果新创建enduserenduser.memory_config_id 必定为none
# 那么使用从release中获取memory_config_id为预期行为并且回填到
# end_user.memory_config_id
if not memory_config_id_to_use:
logger.info(f"end_user.memory_config_id is None, migrating from AppRelease.config")
# 获取最新发布版本
stmt = (
select(AppRelease)
@@ -1208,10 +1202,10 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
)
# TODO: change to current_release_id
latest_release = db.scalars(stmt).first()
if latest_release:
config = latest_release.config or {}
# 如果 config 是字符串,解析为字典
if isinstance(config, str):
try:
@@ -1219,22 +1213,22 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
except json_module.JSONDecodeError:
logger.warning(f"Failed to parse config JSON for release {latest_release.id}")
config = {}
# 使用 MemoryConfigService 的提取方法
memory_config_service = MemoryConfigService(db)
legacy_config_id, is_legacy_int = memory_config_service.extract_memory_config_id(
app_type=app.type,
config=config
)
if legacy_config_id:
# 验证提取的 config_id 是否存在于数据库中
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
existing_config = db.get(MemoryConfigModel, legacy_config_id)
if existing_config:
memory_config_id_to_use = legacy_config_id
# 回填到 end_user 表lazy update
end_user.memory_config_id = memory_config_id_to_use
db.commit()
@@ -1268,7 +1262,8 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
"workspace_id": str(app.workspace_id)
}
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={app.workspace_id}")
logger.info(
f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={app.workspace_id}")
return result
@@ -1312,7 +1307,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
# 1. 批量查询所有 end_user 及其 app_id 和 memory_config_id
end_users = db.query(EndUser).filter(EndUser.id.in_(end_user_ids)).all()
# 创建映射 - 保留 EndUser 对象引用以便回填
end_user_map = {str(eu.id): eu for eu in end_users}
user_data = {str(eu.id): {"app_id": eu.app_id, "memory_config_id": eu.memory_config_id} for eu in end_users}
@@ -1336,15 +1331,15 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
# 3. 对于没有 memory_config_id 的用户,尝试从 AppRelease.config 提取
users_needing_migration = [
(end_user_id, data["app_id"])
for end_user_id, data in user_data.items()
(end_user_id, data["app_id"])
for end_user_id, data in user_data.items()
if not data["memory_config_id"]
]
if users_needing_migration:
# 批量获取相关应用的最新发布版本
migration_app_ids = list(set(app_id for _, app_id in users_needing_migration))
# 查询每个应用的最新活跃发布版本
app_latest_releases = {}
for app_id in migration_app_ids:
@@ -1357,18 +1352,18 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
latest_release = db.scalars(stmt).first()
if latest_release:
app_latest_releases[app_id] = latest_release
# 为每个需要迁移的用户提取 memory_config_id
config_service = MemoryConfigService(db)
users_to_backfill = [] # [(end_user, memory_config_id), ...]
for end_user_id, app_id in users_needing_migration:
latest_release = app_latest_releases.get(app_id)
if not latest_release:
continue
config = latest_release.config or {}
# 如果 config 是字符串,解析为字典
if isinstance(config, str):
try:
@@ -1376,21 +1371,21 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
except json_module.JSONDecodeError:
logger.warning(f"Failed to parse config JSON for release {latest_release.id}")
continue
# 使用 MemoryConfigService 的提取方法
app = app_map.get(app_id)
if not app:
continue
legacy_config_id, is_legacy_int = config_service.extract_memory_config_id(
app_type=app.type,
config=config
)
if legacy_config_id:
# 更新 user_data 中的 memory_config_id
user_data[end_user_id]["memory_config_id"] = legacy_config_id
# 记录需要回填的用户(稍后验证配置存在后再回填)
end_user = end_user_map.get(end_user_id)
if end_user:
@@ -1399,7 +1394,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
logger.info(
f"Legacy int config detected for end_user {end_user_id}, will use workspace default"
)
# 验证提取的 config_id 是否存在于数据库中
if users_to_backfill:
config_ids_to_validate = list(set(cid for _, cid in users_to_backfill))
@@ -1407,17 +1402,17 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
MemoryConfig.config_id.in_(config_ids_to_validate)
).all()
valid_config_ids = {mc.config_id for mc in existing_configs}
# 只回填存在的配置
valid_backfills = [
(eu, cid) for eu, cid in users_to_backfill
(eu, cid) for eu, cid in users_to_backfill
if cid in valid_config_ids
]
invalid_backfills = [
(eu, cid) for eu, cid in users_to_backfill
(eu, cid) for eu, cid in users_to_backfill
if cid not in valid_config_ids
]
if invalid_backfills:
invalid_ids = [str(cid) for _, cid in invalid_backfills]
logger.warning(
@@ -1426,7 +1421,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
# 清除 user_data 中无效的 config_id
for eu, cid in invalid_backfills:
user_data[str(eu.id)]["memory_config_id"] = None
# 批量回填 end_user.memory_config_id
if valid_backfills:
for end_user, memory_config_id in valid_backfills:
@@ -1437,7 +1432,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
# 4. 收集需要查询的 memory_config_id 和需要回退的 workspace_id
direct_config_ids = []
workspace_fallback_users = [] # [(end_user_id, workspace_id), ...]
for end_user_id, data in user_data.items():
if data["memory_config_id"]:
direct_config_ids.append(data["memory_config_id"])
@@ -1455,7 +1450,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
# 6. 获取工作空间默认配置(需要逐个查询,因为 get_workspace_default_config 有复杂逻辑)
workspace_default_configs = {}
unique_workspace_ids = list(set(ws_id for _, ws_id in workspace_fallback_users))
if unique_workspace_ids:
config_service = MemoryConfigService(db)
for workspace_id in unique_workspace_ids:
@@ -1466,11 +1461,11 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
# 7. 构建最终结果
for end_user_id, data in user_data.items():
memory_config = None
# 优先使用 end_user 直接分配的配置
if data["memory_config_id"]:
memory_config = config_id_to_config.get(data["memory_config_id"])
# 回退到工作空间默认配置
if not memory_config:
workspace_id = app_to_workspace.get(data["app_id"])
@@ -1486,4 +1481,4 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
result[end_user_id] = {"memory_config_id": None, "memory_config_name": None}
logger.info(f"Successfully retrieved {len(result)} connected configs")
return result
return result