Merge branch 'feature/rag2' into develop

* feature/rag2:
  [add] batch add chunk for v1
  [fix] index_not_found_exception
  [fix] delete chunk refresh index
  [fix] es vector
  [fix] file upload
  no message
  [add] import qa chunks
  [add] task log
  [fix] qa cache
  [add] batch chunk.  qa_prompt set
  [modify] rag qa chunk
This commit is contained in:
Mark
2026-05-07 18:47:42 +08:00
11 changed files with 700 additions and 134 deletions

View File

@@ -1,8 +1,10 @@
import os
import csv
import io
from typing import Any, Optional
import uuid
from fastapi import APIRouter, Depends, HTTPException, status, Query
from fastapi import APIRouter, Depends, HTTPException, status, Query, UploadFile, File
from fastapi.encoders import jsonable_encoder
from sqlalchemy.orm import Session
@@ -23,6 +25,7 @@ from app.models.user_model import User
from app.schemas import chunk_schema
from app.schemas.response_schema import ApiResponse
from app.services import knowledge_service, document_service, file_service, knowledgeshare_service
from app.services.file_storage_service import FileStorageService, get_file_storage_service, generate_kb_file_key
from app.services.model_service import ModelApiKeyService
# Obtain a dedicated API logger
@@ -271,6 +274,9 @@ async def create_chunk(
"sort_id": sort_id,
"status": 1,
}
# QA chunk: 注入 chunk_type/question/answer 到 metadata
if create_data.is_qa:
metadata.update(create_data.qa_metadata)
chunk = DocumentChunk(page_content=content, metadata=metadata)
# 3. Segmented vector storage
vector_service.add_chunks([chunk])
@@ -282,6 +288,187 @@ async def create_chunk(
return success(data=jsonable_encoder(chunk), msg="Document chunk creation successful")
@router.post("/{kb_id}/{document_id}/chunk/batch", response_model=ApiResponse)
async def create_chunks_batch(
kb_id: uuid.UUID,
document_id: uuid.UUID,
batch_data: chunk_schema.ChunkBatchCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Batch create chunks (max 8)
"""
api_logger.info(f"Batch create chunks: kb_id={kb_id}, document_id={document_id}, count={len(batch_data.items)}, username: {current_user.username}")
if len(batch_data.items) > settings.MAX_CHUNK_BATCH_SIZE:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Batch size exceeds limit: max {settings.MAX_CHUNK_BATCH_SIZE}, got {len(batch_data.items)}"
)
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
if not db_knowledge:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="The knowledge base does not exist or access is denied")
db_document = db.query(Document).filter(Document.id == document_id).first()
if not db_document:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="The document does not exist or you do not have permission to access it")
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
# Get current max sort_id
sort_id = 0
total, items = vector_service.search_by_segment(document_id=str(document_id), pagesize=1, page=1, asc=False)
if items:
sort_id = items[0].metadata["sort_id"]
chunks = []
for create_data in batch_data.items:
sort_id += 1
doc_id = uuid.uuid4().hex
metadata = {
"doc_id": doc_id,
"file_id": str(db_document.file_id),
"file_name": db_document.file_name,
"file_created_at": int(db_document.created_at.timestamp() * 1000),
"document_id": str(document_id),
"knowledge_id": str(kb_id),
"sort_id": sort_id,
"status": 1,
}
if create_data.is_qa:
metadata.update(create_data.qa_metadata)
chunks.append(DocumentChunk(page_content=create_data.chunk_content, metadata=metadata))
vector_service.add_chunks(chunks)
db_document.chunk_num += len(chunks)
db.commit()
return success(data=jsonable_encoder(chunks), msg=f"Batch created {len(chunks)} chunks successfully")
@router.post("/{kb_id}/import_qa", response_model=ApiResponse)
async def import_qa_new_doc(
kb_id: uuid.UUID,
file: UploadFile = File(..., description="CSV 或 Excel 文件(第一行标题跳过,第一列问题,第二列答案)"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
storage_service: FileStorageService = Depends(get_file_storage_service),
):
"""
导入 QA 问答对并新建文档CSV/Excel异步处理
"""
from app.schemas import file_schema, document_schema
api_logger.info(f"Import QA (new doc): kb_id={kb_id}, file={file.filename}, username: {current_user.username}")
# 1. 校验文件格式
filename = file.filename or ""
if not (filename.endswith(".csv") or filename.endswith(".xlsx") or filename.endswith(".xls")):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="仅支持 CSV (.csv) 或 Excel (.xlsx) 格式")
# 2. 校验知识库
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
if not db_knowledge:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="知识库不存在或无权访问")
# 3. 读取文件
contents = await file.read()
file_size = len(contents)
if file_size == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="文件为空")
_, file_extension = os.path.splitext(filename)
file_ext = file_extension.lower()
# 4. 创建 File 记录
file_data = file_schema.FileCreate(
kb_id=kb_id, created_by=current_user.id,
parent_id=uuid.UUID("00000000-0000-0000-0000-000000000000"),
file_name=filename, file_ext=file_ext, file_size=file_size,
)
db_file = file_service.create_file(db=db, file=file_data, current_user=current_user)
# 5. 上传文件到存储后端
file_key = generate_kb_file_key(kb_id=kb_id, file_id=db_file.id, file_ext=file_ext)
try:
await storage_service.storage.upload(file_key=file_key, content=contents, content_type=file.content_type)
except Exception as e:
api_logger.error(f"Storage upload failed: {e}")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"文件存储失败: {str(e)}")
db_file.file_key = file_key
db.commit()
db.refresh(db_file)
# 6. 创建 Document 记录(标记为 QA 类型)
doc_data = document_schema.DocumentCreate(
kb_id=kb_id, created_by=current_user.id, file_id=db_file.id,
file_name=filename, file_ext=file_ext, file_size=file_size,
file_meta={}, parser_id="qa",
parser_config={"doc_type": "qa", "auto_questions": 0}
)
db_document = document_service.create_document(db=db, document=doc_data, current_user=current_user)
api_logger.info(f"Created doc for QA import: file_id={db_file.id}, document_id={db_document.id}, file_key={file_key}")
# 7. 派发异步任务
from app.celery_app import celery_app
task = celery_app.send_task(
"app.core.rag.tasks.import_qa_chunks",
args=[str(kb_id), str(db_document.id), filename, contents],
queue="qa_import"
)
return success(data={
"task_id": task.id,
"document_id": str(db_document.id),
"file_id": str(db_file.id),
}, msg="QA 导入任务已提交,后台处理中")
@router.post("/{kb_id}/{document_id}/import_qa", response_model=ApiResponse)
async def import_qa_chunks(
kb_id: uuid.UUID,
document_id: uuid.UUID,
file: UploadFile = File(..., description="CSV 或 Excel 文件(第一行标题跳过,第一列问题,第二列答案)"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
导入 QA 问答对CSV/Excel异步处理
"""
api_logger.info(f"Import QA chunks: kb_id={kb_id}, document_id={document_id}, file={file.filename}, username: {current_user.username}")
# 1. 校验文件格式
filename = file.filename or ""
if not (filename.endswith(".csv") or filename.endswith(".xlsx") or filename.endswith(".xls")):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="仅支持 CSV (.csv) 或 Excel (.xlsx) 格式")
# 2. 校验知识库和文档
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
if not db_knowledge:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="知识库不存在或无权访问")
db_document = db.query(Document).filter(Document.id == document_id).first()
if not db_document:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="文档不存在或无权访问")
# 3. 读取文件内容,派发异步任务
contents = await file.read()
from app.celery_app import celery_app
task = celery_app.send_task(
"app.core.rag.tasks.import_qa_chunks",
args=[str(kb_id), str(document_id), filename, contents],
queue="qa_import"
)
return success(data={"task_id": task.id}, msg="QA 导入任务已提交,后台处理中")
@router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
async def get_chunk(
kb_id: uuid.UUID,
@@ -342,6 +529,9 @@ async def update_chunk(
if total:
chunk = items[0]
chunk.page_content = content
# QA chunk: 更新 metadata 中的 question/answer
if update_data.is_qa:
chunk.metadata.update(update_data.qa_metadata)
vector_service.update_by_segment(chunk)
return success(data=jsonable_encoder(chunk), msg="The document chunk has been successfully updated")
else:
@@ -356,6 +546,7 @@ async def delete_chunk(
kb_id: uuid.UUID,
document_id: uuid.UUID,
doc_id: str,
force_refresh: bool = Query(False, description="Force Elasticsearch refresh after deletion"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
@@ -373,7 +564,7 @@ async def delete_chunk(
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
if vector_service.text_exists(doc_id):
vector_service.delete_by_ids([doc_id])
vector_service.delete_by_ids([doc_id], refresh=force_refresh)
# 更新 chunk_num
db_document = db.query(Document).filter(Document.id == document_id).first()
db_document.chunk_num -= 1

View File

@@ -113,6 +113,33 @@ async def create_chunk(
current_user=current_user)
@router.post("/{kb_id}/{document_id}/chunk/batch", response_model=ApiResponse)
@require_api_key(scopes=["rag"])
async def create_chunks_batch(
kb_id: uuid.UUID,
document_id: uuid.UUID,
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
items: list = Body(..., description="chunk items list"),
):
"""
Batch create chunks (max 8)
"""
body = await request.json()
batch_data = chunk_schema.ChunkBatchCreate(**body)
# 0. Obtain the creator of the api key
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
current_user = api_key.creator
current_user.current_workspace_id = api_key_auth.workspace_id
return await chunk_controller.create_chunks_batch(kb_id=kb_id,
document_id=document_id,
batch_data=batch_data,
db=db,
current_user=current_user)
@router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
@require_api_key(scopes=["rag"])
async def get_chunk(
@@ -176,6 +203,7 @@ async def delete_chunk(
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
force_refresh: bool = Query(False, description="Force Elasticsearch refresh after deletion"),
):
"""
delete document chunk
@@ -188,6 +216,7 @@ async def delete_chunk(
return await chunk_controller.delete_chunk(kb_id=kb_id,
document_id=document_id,
doc_id=doc_id,
force_refresh=force_refresh,
db=db,
current_user=current_user)