[fix] file upload

This commit is contained in:
Mark
2026-04-29 13:41:14 +08:00
parent 6e89302cb2
commit 5fceba54b4
2 changed files with 43 additions and 12 deletions

View File

@@ -25,6 +25,7 @@ from app.models.user_model import User
from app.schemas import chunk_schema
from app.schemas.response_schema import ApiResponse
from app.services import knowledge_service, document_service, file_service, knowledgeshare_service
from app.services.file_storage_service import FileStorageService, get_file_storage_service, generate_kb_file_key
from app.services.model_service import ModelApiKeyService
# Obtain a dedicated API logger
@@ -353,11 +354,14 @@ async def import_qa_new_doc(
kb_id: uuid.UUID,
file: UploadFile = File(..., description="CSV 或 Excel 文件(第一行标题跳过,第一列问题,第二列答案)"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
storage_service: FileStorageService = Depends(get_file_storage_service),
):
"""
导入 QA 问答对并新建文档CSV/Excel异步处理
"""
from app.schemas import file_schema, document_schema
api_logger.info(f"Import QA (new doc): kb_id={kb_id}, file={file.filename}, username: {current_user.username}")
# 1. 校验文件格式
@@ -370,13 +374,16 @@ async def import_qa_new_doc(
if not db_knowledge:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="知识库不存在或无权访问")
# 3. 创建 File 记录
from app.schemas import file_schema, document_schema
_, file_extension = os.path.splitext(filename)
file_ext = file_extension.lower()
# 3. 读取文件
contents = await file.read()
file_size = len(contents)
if file_size == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="文件为空")
_, file_extension = os.path.splitext(filename)
file_ext = file_extension.lower()
# 4. 创建 File 记录
file_data = file_schema.FileCreate(
kb_id=kb_id, created_by=current_user.id,
parent_id=uuid.UUID("00000000-0000-0000-0000-000000000000"),
@@ -384,19 +391,30 @@ async def import_qa_new_doc(
)
db_file = file_service.create_file(db=db, file=file_data, current_user=current_user)
# 4. 创建 Document 记录
# 5. 上传文件到存储后端
file_key = generate_kb_file_key(kb_id=kb_id, file_id=db_file.id, file_ext=file_ext)
try:
await storage_service.storage.upload(file_key=file_key, content=contents, content_type=file.content_type)
except Exception as e:
api_logger.error(f"Storage upload failed: {e}")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"文件存储失败: {str(e)}")
db_file.file_key = file_key
db.commit()
db.refresh(db_file)
# 6. 创建 Document 记录(标记为 QA 类型)
doc_data = document_schema.DocumentCreate(
kb_id=kb_id, created_by=current_user.id, file_id=db_file.id,
file_name=filename, file_ext=file_ext, file_size=file_size,
file_meta={}, parser_id="naive",
parser_config={"layout_recognize": "DeepDOC", "chunk_token_num": 128,
"delimiter": "\n", "auto_keywords": 0, "auto_questions": 0, "html4excel": "false"}
file_meta={}, parser_id="qa",
parser_config={"doc_type": "qa", "auto_questions": 0}
)
db_document = document_service.create_document(db=db, document=doc_data, current_user=current_user)
api_logger.info(f"Created doc for QA import: file_id={db_file.id}, document_id={db_document.id}")
api_logger.info(f"Created doc for QA import: file_id={db_file.id}, document_id={db_document.id}, file_key={file_key}")
# 5. 派发异步任务
# 7. 派发异步任务
from app.celery_app import celery_app
task = celery_app.send_task(
"app.core.rag.tasks.import_qa_chunks",

View File

@@ -801,8 +801,10 @@ def import_qa_chunks(kb_id: str, document_id: str, filename: str, contents: byte
batch = chunks[i:i + batch_size]
vector_service.add_chunks(batch)
# 3. 更新 chunk_num
# 3. 更新 chunk_num 和 progress
db_document.chunk_num += len(chunks)
db_document.progress = 1.0
db_document.progress_msg = f"QA 导入完成: {len(chunks)}"
db.commit()
result = {"imported": len(chunks), "failed_rows": failed_rows}
@@ -811,6 +813,17 @@ def import_qa_chunks(kb_id: str, document_id: str, filename: str, contents: byte
except Exception as e:
logger.error(f"[ImportQA] Failed: {e}", exc_info=True)
# 尝试更新文档状态为失败
try:
from app.db import get_db_context
with get_db_context() as err_db:
doc = err_db.query(Document).filter(Document.id == uuid.UUID(document_id)).first()
if doc:
doc.progress = -1.0
doc.progress_msg = f"QA 导入失败: {str(e)[:200]}"
err_db.commit()
except Exception:
pass
return {"error": str(e), "imported": 0}