fix(db): fix database connection leak

This commit is contained in:
Eternity
2026-03-06 10:12:21 +08:00
parent f90e102854
commit aaa0410781
12 changed files with 505 additions and 566 deletions

View File

@@ -22,6 +22,7 @@ from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context
from app.models import AgentConfig, ModelConfig
from app.repositories.tool_repository import ToolRepository
from app.schemas.app_schema import FileInput
@@ -103,9 +104,7 @@ def create_long_term_memory_tool(
"""
logger.info(f" 长期记忆工具被调用question={question}, user={end_user_id}")
try:
from app.db import get_db
db = next(get_db())
try:
with get_db_context() as db:
memory_content = asyncio.run(
MemoryAgentService().read_memory(
end_user_id=end_user_id,
@@ -127,9 +126,6 @@ def create_long_term_memory_tool(
logger.info(f"读取任务状态:{status}")
if memory_content:
memory_content = memory_content['answer']
finally:
db.close()
logger.info(f'用户IDAgent:{end_user_id}')
logger.debug("调用长期记忆 API", extra={"question": question, "end_user_id": end_user_id})

View File

@@ -13,7 +13,6 @@ TODO: Refactor get_end_user_connected_config
"""
import json
import os
import re
import time
import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional
@@ -35,12 +34,10 @@ from app.core.memory.agent.utils.messages_tools import (
reorder_output_results,
)
from app.core.memory.agent.utils.type_classifier import status_typle
from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags, get_interest_distribution
from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType
from app.repositories.memory_short_repository import ShortTermMemoryRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_agent_schema import Write_UserInput
from app.schemas.memory_config_schema import ConfigurationError
@@ -69,7 +66,8 @@ class MemoryAgentService:
logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}")
# 记录成功的操作
if audit_logger:
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=True,
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
success=True,
duration=duration, details={"message_length": len(message)})
return context
else:
@@ -88,8 +86,6 @@ class MemoryAgentService:
raise ValueError(f"写入失败: {messages}")
def extract_tool_call_info(self, event: Dict) -> bool:
"""Extract tool call information from event"""
last_message = event["messages"][-1]
@@ -271,7 +267,8 @@ class MemoryAgentService:
logger.info("Log streaming completed, cleaning up resources")
# LogStreamer uses context manager for file handling, so cleanup is automatic
async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID]|int, db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str:
async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID] | int,
db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str:
"""
Process write operation with config_id
@@ -300,7 +297,8 @@ class MemoryAgentService:
config_id = connected_config.get("memory_config_id")
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
if config_id is None and workspace_id is None:
raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
raise ValueError(
f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
except Exception as e:
if "No memory configuration found" in str(e):
raise # Re-raise our specific error
@@ -331,7 +329,8 @@ class MemoryAgentService:
# Log failed operation
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
success=False, duration=duration, error=error_msg)
raise ValueError(error_msg)
@@ -351,9 +350,9 @@ class MemoryAgentService:
langchain_messages.append(HumanMessage(content=msg['content']))
elif msg['role'] == 'assistant':
langchain_messages.append(AIMessage(content=msg['content']))
print(100*'-')
print(100 * '-')
print(langchain_messages)
print(100*'-')
print(100 * '-')
# 初始状态 - 包含所有必要字段
initial_state = {
"messages": langchain_messages,
@@ -375,29 +374,28 @@ class MemoryAgentService:
contents = massages.get('write_result')
# Convert messages back to string for logging
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents)
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text,
contents)
except Exception as e:
# Ensure proper error handling and logging
error_msg = f"Write operation failed: {str(e)}"
logger.error(error_msg)
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
success=False, duration=duration, error=error_msg)
raise ValueError(error_msg)
async def read_memory(
self,
end_user_id: str,
message: str,
history: List[Dict],
search_switch: str,
config_id: Optional[uuid.UUID]|int,
db: Session,
storage_type: str,
user_rag_memory_id: str) -> Dict:
self,
end_user_id: str,
message: str,
history: List[Dict],
search_switch: str,
config_id: Optional[uuid.UUID] | int,
db: Session,
storage_type: str,
user_rag_memory_id: str) -> Dict:
"""
Process read operation with config_id
@@ -425,7 +423,7 @@ class MemoryAgentService:
import time
start_time = time.time()
ori_message= message
ori_message = message
# Resolve config_id and workspace_id
# Always get workspace_id from end_user for fallback, even if config_id is provided
@@ -437,7 +435,8 @@ class MemoryAgentService:
config_id = connected_config.get("memory_config_id")
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
if config_id is None and workspace_id is None:
raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
raise ValueError(
f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
except Exception as e:
if "No memory configuration found" in str(e):
raise # Re-raise our specific error
@@ -454,7 +453,6 @@ class MemoryAgentService:
except ImportError:
audit_logger = None
config_load_start = time.time()
try:
# Use a separate database session to avoid transaction failures
@@ -562,34 +560,35 @@ class MemoryAgentService:
from app.repositories.memory_short_repository import (
ShortTermMemoryRepository,
)
retrieved_content = []
repo = ShortTermMemoryRepository(db)
if str(search_switch) != "2":
for intermediate in _intermediate_outputs:
logger.debug(f"处理中间结果: {intermediate}")
intermediate_type = intermediate.get('type', '')
if intermediate_type == "search_result":
query = intermediate.get('query', '')
raw_results = intermediate.get('raw_results', {})
try:
reranked_results = raw_results.get('reranked_results', [])
statements = [statement['statement'] for statement in reranked_results.get('statements', [])]
statements = [statement['statement'] for statement in
reranked_results.get('statements', [])]
except Exception:
statements = []
# 去重
statements = list(set(statements))
if query and statements:
retrieved_content.append({query: statements})
# 如果 retrieved_content 为空,设置为空字符串
if retrieved_content == []:
retrieved_content = ''
# 只有当回答不是"信息不足"且不是快速检索时才保存
if '信息不足,无法回答。' != str(summary) and str(search_switch).strip() != "2":
# 使用 upsert 方法
@@ -602,15 +601,17 @@ class MemoryAgentService:
)
logger.info(f"成功保存短期记忆: end_user_id={end_user_id}, search_switch={search_switch}")
else:
logger.debug(f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}")
logger.debug(
f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}")
except Exception as save_error:
# 保存失败不应该影响主流程,只记录错误
logger.error(f"保存短期记忆失败: {str(save_error)}", exc_info=True)
# Log successful operation
total_time = time.time() - start_time
logger.info(f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
logger.info(
f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(
@@ -641,7 +642,6 @@ class MemoryAgentService:
)
raise ValueError(error_msg)
def get_messages_list(self, user_input: Write_UserInput) -> list[dict]:
"""
Get standardized message list from user input.
@@ -657,41 +657,43 @@ class MemoryAgentService:
"""
from app.core.logging_config import get_api_logger
logger = get_api_logger()
if len(user_input.messages) == 0:
logger.error("Validation failed: Message list cannot be empty")
raise ValueError("Message list cannot be empty")
for idx, msg in enumerate(user_input.messages):
if not isinstance(msg, dict):
logger.error(f"Validation failed: Message {idx} is not a dict: {type(msg)}")
raise ValueError(f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}")
raise ValueError(
f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}")
if 'role' not in msg:
logger.error(f"Validation failed: Message {idx} missing 'role' field: {msg}")
raise ValueError(f"Message format error: Message must contain 'role' field. Error message index: {idx}")
if 'content' not in msg:
logger.error(f"Validation failed: Message {idx} missing 'content' field: {msg}")
raise ValueError(f"Message format error: Message must contain 'content' field. Error message index: {idx}")
raise ValueError(
f"Message format error: Message must contain 'content' field. Error message index: {idx}")
if msg['role'] not in ['user', 'assistant']:
logger.error(f"Validation failed: Message {idx} invalid role: {msg['role']}")
raise ValueError(f"Role must be 'user' or 'assistant', got: {msg['role']}. Message index: {idx}")
if not msg['content'] or not msg['content'].strip():
logger.error(f"Validation failed: Message {idx} content is empty")
raise ValueError(f"Message content cannot be empty. Message index: {idx}, role: {msg['role']}")
logger.info(f"Validation successful: Structured message list, count: {len(user_input.messages)}")
return user_input.messages
async def classify_message_type(
self,
message: str,
config_id: UUID,
db: Session,
workspace_id: Optional[UUID] = None
self,
message: str,
config_id: UUID,
db: Session,
workspace_id: Optional[UUID] = None
) -> Dict:
"""
Determine the type of user message (read or write)
@@ -719,14 +721,15 @@ class MemoryAgentService:
status = await status_typle(message, memory_config.llm_model_id)
logger.debug(f"Message type: {status}")
return status
async def generate_summary_from_retrieve(
self,
end_user_id: str,
retrieve_info: str,
history: List[Dict],
query: str,
config_id: str,
db: Session
self,
end_user_id: str,
retrieve_info: str,
history: List[Dict],
query: str,
config_id: str,
db: Session
) -> str:
"""
基于检索信息、历史对话和查询生成最终答案
@@ -761,9 +764,9 @@ class MemoryAgentService:
if config_id is None:
raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}")
# If config_id was provided, continue without workspace_id fallback
logger.info(f"Generating summary from retrieve info for query: {query[:50]}...")
try:
# 加载配置
config_service = MemoryConfigService(db)
@@ -772,7 +775,7 @@ class MemoryAgentService:
workspace_id=workspace_id,
service_name="MemoryAgentService"
)
# 导入必要的模块
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import (
summary_llm,
@@ -780,13 +783,13 @@ class MemoryAgentService:
from app.core.memory.agent.models.summary_models import (
RetrieveSummaryResponse,
)
# 构建状态对象
state = {
"data": query,
"memory_config": memory_config
}
# 直接调用 summary_llm 函数
answer = await summary_llm(
state=state,
@@ -797,21 +800,20 @@ class MemoryAgentService:
response_model=RetrieveSummaryResponse,
search_mode="1"
)
logger.info(f"Successfully generated summary: {answer[:100] if answer else 'None'}...")
return answer if answer else "信息不足,无法回答。"
except Exception as e:
logger.error(f"生成摘要失败: {str(e)}", exc_info=True)
return "信息不足,无法回答。"
async def get_knowledge_type_stats(
self,
end_user_id: Optional[str] = None,
only_active: bool = True,
current_workspace_id: Optional[uuid.UUID] = None,
db: Session = None
self,
db: Session,
end_user_id: Optional[str] = None,
only_active: bool = True,
current_workspace_id: Optional[uuid.UUID] = None
) -> Dict[str, Any]:
"""
统计知识库类型分布,包含:
@@ -837,11 +839,6 @@ class MemoryAgentService:
# 1. 统计 PostgreSQL 中的知识库类型
try:
if db is None:
from app.db import get_db
db_gen = get_db()
db = next(db_gen)
# 初始化所有标准类型为 0
for kb_type in KnowledgeType:
result[kb_type.value] = 0
@@ -881,21 +878,19 @@ class MemoryAgentService:
# 3. 计算知识库类型总和(不包括 memory
result["total"] = (
result.get("General", 0) +
result.get("Web", 0) +
result.get("Third-party", 0) +
result.get("Folder", 0)
result.get("General", 0) +
result.get("Web", 0) +
result.get("Third-party", 0) +
result.get("Folder", 0)
)
return result
async def get_interest_distribution_by_user(
self,
end_user_id: Optional[str] = None,
limit: int = 5,
language: str = "zh"
self,
end_user_id: Optional[str] = None,
limit: int = 5,
language: str = "zh"
) -> List[Dict[str, Any]]:
"""
获取指定用户的兴趣分布标签。
@@ -921,13 +916,12 @@ class MemoryAgentService:
logger.error(f"兴趣分布标签查询失败: {e}")
raise Exception(f"兴趣分布标签查询失败: {e}")
async def get_user_profile(
self,
end_user_id: Optional[str] = None,
current_user_id: Optional[str] = None,
llm_id: Optional[str] = None,
db: Session = None
self,
end_user_id: Optional[str] = None,
current_user_id: Optional[str] = None,
llm_id: Optional[str] = None,
db: Session = None
) -> Dict[str, Any]:
"""
获取用户详情,包含:
@@ -1017,7 +1011,8 @@ class MemoryAgentService:
# 定义标签提取的结构
class UserTags(BaseModel):
tags: list[str] = Field(..., description="3个描述用户特征的标签产品设计师、旅行爱好者、摄影发烧友")
tags: list[str] = Field(...,
description="3个描述用户特征的标签产品设计师、旅行爱好者、摄影发烧友")
messages = [
{
@@ -1160,7 +1155,6 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
ValueError: 当终端用户不存在或应用未发布时
"""
import json as json_module
import uuid
from sqlalchemy import select
@@ -1192,14 +1186,14 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
# 3. 兼容旧数据:如果 memory_config_id 为空,从 AppRelease.config 获取并回填
memory_config_id_to_use = end_user.memory_config_id
# 如果已有 memory_config_id直接使用
# 如果新创建enduserenduser.memory_config_id 必定为none
# 那么使用从release中获取memory_config_id为预期行为并且回填到
# end_user.memory_config_id
if not memory_config_id_to_use:
logger.info(f"end_user.memory_config_id is None, migrating from AppRelease.config")
# 获取最新发布版本
stmt = (
select(AppRelease)
@@ -1208,10 +1202,10 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
)
# TODO: change to current_release_id
latest_release = db.scalars(stmt).first()
if latest_release:
config = latest_release.config or {}
# 如果 config 是字符串,解析为字典
if isinstance(config, str):
try:
@@ -1219,22 +1213,22 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
except json_module.JSONDecodeError:
logger.warning(f"Failed to parse config JSON for release {latest_release.id}")
config = {}
# 使用 MemoryConfigService 的提取方法
memory_config_service = MemoryConfigService(db)
legacy_config_id, is_legacy_int = memory_config_service.extract_memory_config_id(
app_type=app.type,
config=config
)
if legacy_config_id:
# 验证提取的 config_id 是否存在于数据库中
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
existing_config = db.get(MemoryConfigModel, legacy_config_id)
if existing_config:
memory_config_id_to_use = legacy_config_id
# 回填到 end_user 表lazy update
end_user.memory_config_id = memory_config_id_to_use
db.commit()
@@ -1268,7 +1262,8 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
"workspace_id": str(app.workspace_id)
}
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={app.workspace_id}")
logger.info(
f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={app.workspace_id}")
return result
@@ -1312,7 +1307,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
# 1. 批量查询所有 end_user 及其 app_id 和 memory_config_id
end_users = db.query(EndUser).filter(EndUser.id.in_(end_user_ids)).all()
# 创建映射 - 保留 EndUser 对象引用以便回填
end_user_map = {str(eu.id): eu for eu in end_users}
user_data = {str(eu.id): {"app_id": eu.app_id, "memory_config_id": eu.memory_config_id} for eu in end_users}
@@ -1336,15 +1331,15 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
# 3. 对于没有 memory_config_id 的用户,尝试从 AppRelease.config 提取
users_needing_migration = [
(end_user_id, data["app_id"])
for end_user_id, data in user_data.items()
(end_user_id, data["app_id"])
for end_user_id, data in user_data.items()
if not data["memory_config_id"]
]
if users_needing_migration:
# 批量获取相关应用的最新发布版本
migration_app_ids = list(set(app_id for _, app_id in users_needing_migration))
# 查询每个应用的最新活跃发布版本
app_latest_releases = {}
for app_id in migration_app_ids:
@@ -1357,18 +1352,18 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
latest_release = db.scalars(stmt).first()
if latest_release:
app_latest_releases[app_id] = latest_release
# 为每个需要迁移的用户提取 memory_config_id
config_service = MemoryConfigService(db)
users_to_backfill = [] # [(end_user, memory_config_id), ...]
for end_user_id, app_id in users_needing_migration:
latest_release = app_latest_releases.get(app_id)
if not latest_release:
continue
config = latest_release.config or {}
# 如果 config 是字符串,解析为字典
if isinstance(config, str):
try:
@@ -1376,21 +1371,21 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
except json_module.JSONDecodeError:
logger.warning(f"Failed to parse config JSON for release {latest_release.id}")
continue
# 使用 MemoryConfigService 的提取方法
app = app_map.get(app_id)
if not app:
continue
legacy_config_id, is_legacy_int = config_service.extract_memory_config_id(
app_type=app.type,
config=config
)
if legacy_config_id:
# 更新 user_data 中的 memory_config_id
user_data[end_user_id]["memory_config_id"] = legacy_config_id
# 记录需要回填的用户(稍后验证配置存在后再回填)
end_user = end_user_map.get(end_user_id)
if end_user:
@@ -1399,7 +1394,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
logger.info(
f"Legacy int config detected for end_user {end_user_id}, will use workspace default"
)
# 验证提取的 config_id 是否存在于数据库中
if users_to_backfill:
config_ids_to_validate = list(set(cid for _, cid in users_to_backfill))
@@ -1407,17 +1402,17 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
MemoryConfig.config_id.in_(config_ids_to_validate)
).all()
valid_config_ids = {mc.config_id for mc in existing_configs}
# 只回填存在的配置
valid_backfills = [
(eu, cid) for eu, cid in users_to_backfill
(eu, cid) for eu, cid in users_to_backfill
if cid in valid_config_ids
]
invalid_backfills = [
(eu, cid) for eu, cid in users_to_backfill
(eu, cid) for eu, cid in users_to_backfill
if cid not in valid_config_ids
]
if invalid_backfills:
invalid_ids = [str(cid) for _, cid in invalid_backfills]
logger.warning(
@@ -1426,7 +1421,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
# 清除 user_data 中无效的 config_id
for eu, cid in invalid_backfills:
user_data[str(eu.id)]["memory_config_id"] = None
# 批量回填 end_user.memory_config_id
if valid_backfills:
for end_user, memory_config_id in valid_backfills:
@@ -1437,7 +1432,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
# 4. 收集需要查询的 memory_config_id 和需要回退的 workspace_id
direct_config_ids = []
workspace_fallback_users = [] # [(end_user_id, workspace_id), ...]
for end_user_id, data in user_data.items():
if data["memory_config_id"]:
direct_config_ids.append(data["memory_config_id"])
@@ -1455,7 +1450,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
# 6. 获取工作空间默认配置(需要逐个查询,因为 get_workspace_default_config 有复杂逻辑)
workspace_default_configs = {}
unique_workspace_ids = list(set(ws_id for _, ws_id in workspace_fallback_users))
if unique_workspace_ids:
config_service = MemoryConfigService(db)
for workspace_id in unique_workspace_ids:
@@ -1466,11 +1461,11 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
# 7. 构建最终结果
for end_user_id, data in user_data.items():
memory_config = None
# 优先使用 end_user 直接分配的配置
if data["memory_config_id"]:
memory_config = config_id_to_config.get(data["memory_config_id"])
# 回退到工作空间默认配置
if not memory_config:
workspace_id = app_to_workspace.get(data["app_id"])
@@ -1486,4 +1481,4 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
result[end_user_id] = {"memory_config_id": None, "memory_config_name": None}
logger.info(f"Successfully retrieved {len(result)} connected configs")
return result
return result

View File

@@ -1,45 +1,42 @@
# 修改 memory_konwledges_server.py 文件
import asyncio
import os
import re
import uuid
from pathlib import Path
from typing import Optional
from pydantic import BaseModel, Field
from fastapi import HTTPException, status
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.celery_app import celery_app
from app.core.config import settings
from app.core.logging_config import get_api_logger
from app.core.rag.models.chunk import DocumentChunk
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.response_utils import success
from app.db import get_db
from app.schemas import file_schema, document_schema
from fastapi import APIRouter, Depends, HTTPException, status, File, UploadFile, Query
from app.db import get_db_context
from app.models.document_model import Document
import uuid
from sqlalchemy.orm import Session
from fastapi import HTTPException, status
from app.core.config import settings
from app.models.user_model import User
from app.schemas import file_schema, document_schema
from app.schemas.file_schema import CustomTextFileCreate
from app.services import document_service, file_service, knowledge_service
from app.celery_app import celery_app
from app.core.logging_config import get_api_logger
from app.schemas.file_schema import CustomTextFileCreate
from app.db import get_db
# 创建一个简单的用户类用于测试
api_logger = get_api_logger()
class ChunkCreate(BaseModel):
content: str
class SimpleUser:
def __init__(self, user_id: str):
# 确保ID是UUID类型
self.id = user_id
self.username = user_id
'''解析'''
async def parse_document_by_id(document_id: uuid.UUID, db: Session, current_user: User):
"""
解析指定文档
@@ -120,7 +117,7 @@ async def parse_document_by_id(document_id: uuid.UUID, db: Session, current_user
api_logger.error(f"文档解析失败: document_id={document_id} - {str(e)}")
raise
'''获取块ID'''
async def get_document_chunks(
kb_id: uuid.UUID,
document_id: uuid.UUID,
@@ -198,7 +195,7 @@ async def get_document_chunks(
return success(data=result, msg="文档块列表查询成功")
'''查找文档ID'''
def find_document_id_by_kb_and_filename(
db: Session,
kb_id: str,
@@ -231,7 +228,7 @@ def find_document_id_by_kb_and_filename(
except Exception as e:
return None
'''获取知识库ID'''
def find_documents_by_kb_id(
db: Session,
kb_id: str,
@@ -268,18 +265,14 @@ def find_documents_by_kb_id(
except Exception as e:
return []
''''上传文件'''
async def memory_konwledges_up(
kb_id: str,
parent_id: str,
create_data: file_schema.CustomTextFileCreate,
db: Session = Depends(get_db),
current_user: SimpleUser = None, # 修改为SimpleUser
db: Session,
current_user: SimpleUser,
):
# 如果没有提供current_user则创建一个默认的
if current_user is None:
current_user = SimpleUser("5d27df0b-7eec-4fa6-9f8b-0f9b7e852f60")
content_bytes = create_data.content.encode('utf-8')
file_size = len(content_bytes)
print(f"file size: {file_size} byte")
@@ -350,8 +343,6 @@ async def memory_konwledges_up(
return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful")
'''添加新块'''
async def create_document_chunk(
kb_id: uuid.UUID,
@@ -417,7 +408,7 @@ async def create_document_chunk(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"查询文档块失败: {error_msg}"
)
sort_id = sort_id + 1
# 5. 创建文档块
@@ -450,6 +441,7 @@ async def create_document_chunk(
return success(data=chunk, msg="文档块创建成功")
async def write_rag(end_user_id, message, user_rag_memory_id):
"""
将消息写入 RAG 知识库
@@ -483,15 +475,12 @@ async def write_rag(end_user_id, message, user_rag_memory_id):
detail=f"知识库ID格式无效: {user_rag_memory_id}"
)
db_gen = get_db()
db = next(db_gen)
try:
with get_db_context() as db:
create_data = CustomTextFileCreate(title=end_user_id, content=message)
current_user = SimpleUser(user_rag_memory_id)
# 检查文档是否已存在
document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{end_user_id}.txt")
print('======',document)
print('======', document)
api_logger.info(f"查找文档结果: document_id={document}")
if document is not None:
# 文档已存在,直接添加新块
@@ -528,6 +517,3 @@ async def write_rag(end_user_id, message, user_rag_memory_id):
else:
api_logger.error(f"创建文档后无法找到文档ID: end_user_id={end_user_id}")
return result
finally:
# 确保数据库会话被关闭
db.close()

View File

@@ -21,8 +21,7 @@ from app.repositories.end_user_repository import EndUserRepository
from app.repositories.neo4j.cypher_queries import Graph_Node_query
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping
from app.services.implicit_memory_service import ImplicitMemoryService
from app.services.memory_base_service import MemoryBaseService, MemoryTransService
from app.services.memory_base_service import MemoryBaseService
from app.services.memory_config_service import MemoryConfigService
from app.services.memory_perceptual_service import MemoryPerceptualService
from app.services.memory_short_service import ShortService
@@ -1167,7 +1166,6 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st
from app.core.language_utils import validate_language
from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt
from app.db import get_db
from app.repositories.end_user_repository import EndUserRepository
# 验证语言参数
@@ -1178,8 +1176,7 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st
if end_user_id:
try:
# 获取数据库会话并查询用户信息
db = next(get_db())
try:
with get_db_context() as db:
repo = EndUserRepository(db)
end_user = repo.get_by_id(uuid.UUID(end_user_id))
if end_user and end_user.other_name:
@@ -1187,8 +1184,7 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st
logger.info(f"使用 other_name 作为用户显示名称: {user_display_name}")
else:
logger.info(f"用户 {end_user_id} 的 other_name 为空,使用默认称呼: {user_display_name}")
finally:
db.close()
except Exception as e:
logger.warning(f"获取用户 other_name 失败,使用默认称呼: {str(e)}")