fix(db): fix database connection leak
This commit is contained in:
@@ -1,45 +1,42 @@
|
||||
# 修改 memory_konwledges_server.py 文件
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from fastapi import HTTPException, status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.rag.models.chunk import DocumentChunk
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.schemas import file_schema, document_schema
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, File, UploadFile, Query
|
||||
from app.db import get_db_context
|
||||
from app.models.document_model import Document
|
||||
import uuid
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
from app.core.config import settings
|
||||
from app.models.user_model import User
|
||||
from app.schemas import file_schema, document_schema
|
||||
from app.schemas.file_schema import CustomTextFileCreate
|
||||
from app.services import document_service, file_service, knowledge_service
|
||||
from app.celery_app import celery_app
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.schemas.file_schema import CustomTextFileCreate
|
||||
from app.db import get_db
|
||||
|
||||
# 创建一个简单的用户类用于测试
|
||||
api_logger = get_api_logger()
|
||||
|
||||
|
||||
class ChunkCreate(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class SimpleUser:
|
||||
def __init__(self, user_id: str):
|
||||
# 确保ID是UUID类型
|
||||
self.id = user_id
|
||||
self.username = user_id
|
||||
|
||||
'''解析'''
|
||||
|
||||
async def parse_document_by_id(document_id: uuid.UUID, db: Session, current_user: User):
|
||||
"""
|
||||
解析指定文档
|
||||
@@ -120,7 +117,7 @@ async def parse_document_by_id(document_id: uuid.UUID, db: Session, current_user
|
||||
api_logger.error(f"文档解析失败: document_id={document_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
'''获取块ID'''
|
||||
|
||||
async def get_document_chunks(
|
||||
kb_id: uuid.UUID,
|
||||
document_id: uuid.UUID,
|
||||
@@ -198,7 +195,7 @@ async def get_document_chunks(
|
||||
|
||||
return success(data=result, msg="文档块列表查询成功")
|
||||
|
||||
'''查找文档ID'''
|
||||
|
||||
def find_document_id_by_kb_and_filename(
|
||||
db: Session,
|
||||
kb_id: str,
|
||||
@@ -231,7 +228,7 @@ def find_document_id_by_kb_and_filename(
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
'''获取知识库ID'''
|
||||
|
||||
def find_documents_by_kb_id(
|
||||
db: Session,
|
||||
kb_id: str,
|
||||
@@ -268,18 +265,14 @@ def find_documents_by_kb_id(
|
||||
except Exception as e:
|
||||
return []
|
||||
|
||||
''''上传文件'''
|
||||
|
||||
async def memory_konwledges_up(
|
||||
kb_id: str,
|
||||
parent_id: str,
|
||||
create_data: file_schema.CustomTextFileCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: SimpleUser = None, # 修改为SimpleUser
|
||||
db: Session,
|
||||
current_user: SimpleUser,
|
||||
):
|
||||
# 如果没有提供current_user,则创建一个默认的
|
||||
if current_user is None:
|
||||
current_user = SimpleUser("5d27df0b-7eec-4fa6-9f8b-0f9b7e852f60")
|
||||
|
||||
content_bytes = create_data.content.encode('utf-8')
|
||||
file_size = len(content_bytes)
|
||||
print(f"file size: {file_size} byte")
|
||||
@@ -350,8 +343,6 @@ async def memory_konwledges_up(
|
||||
|
||||
return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful")
|
||||
|
||||
'''添加新块'''
|
||||
|
||||
|
||||
async def create_document_chunk(
|
||||
kb_id: uuid.UUID,
|
||||
@@ -417,7 +408,7 @@ async def create_document_chunk(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"查询文档块失败: {error_msg}"
|
||||
)
|
||||
|
||||
|
||||
sort_id = sort_id + 1
|
||||
|
||||
# 5. 创建文档块
|
||||
@@ -450,6 +441,7 @@ async def create_document_chunk(
|
||||
|
||||
return success(data=chunk, msg="文档块创建成功")
|
||||
|
||||
|
||||
async def write_rag(end_user_id, message, user_rag_memory_id):
|
||||
"""
|
||||
将消息写入 RAG 知识库
|
||||
@@ -483,15 +475,12 @@ async def write_rag(end_user_id, message, user_rag_memory_id):
|
||||
detail=f"知识库ID格式无效: {user_rag_memory_id}"
|
||||
)
|
||||
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
create_data = CustomTextFileCreate(title=end_user_id, content=message)
|
||||
current_user = SimpleUser(user_rag_memory_id)
|
||||
# 检查文档是否已存在
|
||||
document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{end_user_id}.txt")
|
||||
print('======',document)
|
||||
print('======', document)
|
||||
api_logger.info(f"查找文档结果: document_id={document}")
|
||||
if document is not None:
|
||||
# 文档已存在,直接添加新块
|
||||
@@ -528,6 +517,3 @@ async def write_rag(end_user_id, message, user_rag_memory_id):
|
||||
else:
|
||||
api_logger.error(f"创建文档后无法找到文档ID: end_user_id={end_user_id}")
|
||||
return result
|
||||
finally:
|
||||
# 确保数据库会话被关闭
|
||||
db.close()
|
||||
Reference in New Issue
Block a user