[add] import qa chunks
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
import os
|
||||
import csv
|
||||
import io
|
||||
from typing import Any, Optional
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query, UploadFile, File
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -346,6 +348,46 @@ async def create_chunks_batch(
|
||||
return success(data=jsonable_encoder(chunks), msg=f"Batch created {len(chunks)} chunks successfully")
|
||||
|
||||
|
||||
@router.post("/{kb_id}/{document_id}/import_qa", response_model=ApiResponse)
|
||||
async def import_qa_chunks(
|
||||
kb_id: uuid.UUID,
|
||||
document_id: uuid.UUID,
|
||||
file: UploadFile = File(..., description="CSV 或 Excel 文件(第一行标题跳过,第一列问题,第二列答案)"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
导入 QA 问答对(CSV/Excel),异步处理
|
||||
"""
|
||||
api_logger.info(f"Import QA chunks: kb_id={kb_id}, document_id={document_id}, file={file.filename}, username: {current_user.username}")
|
||||
|
||||
# 1. 校验文件格式
|
||||
filename = file.filename or ""
|
||||
if not (filename.endswith(".csv") or filename.endswith(".xlsx") or filename.endswith(".xls")):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="仅支持 CSV (.csv) 或 Excel (.xlsx) 格式")
|
||||
|
||||
# 2. 校验知识库和文档
|
||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
|
||||
if not db_knowledge:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="知识库不存在或无权访问")
|
||||
|
||||
db_document = db.query(Document).filter(Document.id == document_id).first()
|
||||
if not db_document:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="文档不存在或无权访问")
|
||||
|
||||
# 3. 读取文件内容,派发异步任务
|
||||
contents = await file.read()
|
||||
|
||||
from app.celery_app import celery_app
|
||||
task = celery_app.send_task(
|
||||
"app.core.rag.tasks.import_qa_chunks",
|
||||
args=[str(kb_id), str(document_id), filename, contents],
|
||||
queue="qa_import"
|
||||
)
|
||||
|
||||
return success(data={"task_id": task.id}, msg="QA 导入任务已提交,后台处理中")
|
||||
|
||||
|
||||
@router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
|
||||
async def get_chunk(
|
||||
kb_id: uuid.UUID,
|
||||
|
||||
117
api/app/tasks.py
117
api/app/tasks.py
@@ -697,6 +697,123 @@ def build_graphrag_for_document(document_id: str, knowledge_id: str):
|
||||
return f"build_graphrag_for_document '{document_id}' failed: {e}"
|
||||
|
||||
|
||||
@celery_app.task(name="app.core.rag.tasks.import_qa_chunks", queue="qa_import")
|
||||
def import_qa_chunks(kb_id: str, document_id: str, filename: str, contents: bytes):
|
||||
"""
|
||||
异步导入 QA 问答对(CSV/Excel)
|
||||
|
||||
文件格式:第一行标题(跳过),第一列问题,第二列答案
|
||||
"""
|
||||
import csv as csv_module
|
||||
import io
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
db_document = db.query(Document).filter(Document.id == uuid.UUID(document_id)).first()
|
||||
db_knowledge = db.query(Knowledge).filter(Knowledge.id == uuid.UUID(kb_id)).first()
|
||||
if not db_document or not db_knowledge:
|
||||
logger.error(f"[ImportQA] document={document_id} or knowledge={kb_id} not found")
|
||||
return {"error": "document or knowledge not found", "imported": 0}
|
||||
|
||||
# 1. 解析文件
|
||||
qa_pairs = []
|
||||
failed_rows = []
|
||||
|
||||
if filename.endswith(".csv"):
|
||||
try:
|
||||
text = contents.decode("utf-8-sig")
|
||||
except UnicodeDecodeError:
|
||||
text = contents.decode("gbk", errors="ignore")
|
||||
|
||||
sniffer = csv_module.Sniffer()
|
||||
try:
|
||||
dialect = sniffer.sniff(text[:2048])
|
||||
delimiter = dialect.delimiter
|
||||
except csv_module.Error:
|
||||
delimiter = "," if "," in text[:500] else "\t"
|
||||
|
||||
reader = csv_module.reader(io.StringIO(text), delimiter=delimiter)
|
||||
for i, row in enumerate(reader):
|
||||
if i == 0:
|
||||
continue
|
||||
if len(row) >= 2 and row[0].strip() and row[1].strip():
|
||||
qa_pairs.append({"question": row[0].strip(), "answer": row[1].strip()})
|
||||
elif len(row) >= 1 and row[0].strip():
|
||||
failed_rows.append(i + 1)
|
||||
|
||||
elif filename.endswith(".xlsx") or filename.endswith(".xls"):
|
||||
try:
|
||||
import openpyxl
|
||||
wb = openpyxl.load_workbook(io.BytesIO(contents), read_only=True)
|
||||
for sheet in wb.worksheets:
|
||||
for i, row in enumerate(sheet.iter_rows(values_only=True)):
|
||||
if i == 0:
|
||||
continue
|
||||
if len(row) >= 2 and row[0] and row[1]:
|
||||
q = str(row[0]).strip()
|
||||
a = str(row[1]).strip()
|
||||
if q and a:
|
||||
qa_pairs.append({"question": q, "answer": a})
|
||||
elif len(row) >= 1 and row[0]:
|
||||
failed_rows.append(i + 1)
|
||||
wb.close()
|
||||
except Exception as e:
|
||||
logger.error(f"[ImportQA] Excel parse failed: {e}")
|
||||
return {"error": f"Excel parse failed: {e}", "imported": 0}
|
||||
|
||||
if not qa_pairs:
|
||||
logger.warning(f"[ImportQA] No valid QA pairs found in {filename}")
|
||||
return {"error": "No valid QA pairs found", "imported": 0}
|
||||
|
||||
logger.info(f"[ImportQA] Parsed {len(qa_pairs)} QA pairs from {filename}, failed_rows={failed_rows}")
|
||||
|
||||
# 2. 写入 ES
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
|
||||
sort_id = 0
|
||||
total, items = vector_service.search_by_segment(document_id=document_id, pagesize=1, page=1, asc=False)
|
||||
if items:
|
||||
sort_id = items[0].metadata["sort_id"]
|
||||
|
||||
chunks = []
|
||||
for pair in qa_pairs:
|
||||
sort_id += 1
|
||||
doc_id = uuid.uuid4().hex
|
||||
metadata = {
|
||||
"doc_id": doc_id,
|
||||
"file_id": str(db_document.file_id),
|
||||
"file_name": db_document.file_name,
|
||||
"file_created_at": int(db_document.created_at.timestamp() * 1000),
|
||||
"document_id": document_id,
|
||||
"knowledge_id": kb_id,
|
||||
"sort_id": sort_id,
|
||||
"status": 1,
|
||||
"chunk_type": "qa",
|
||||
"question": pair["question"],
|
||||
"answer": pair["answer"],
|
||||
}
|
||||
chunks.append(DocumentChunk(page_content=pair["question"], metadata=metadata))
|
||||
|
||||
batch_size = 50
|
||||
for i in range(0, len(chunks), batch_size):
|
||||
batch = chunks[i:i + batch_size]
|
||||
vector_service.add_chunks(batch)
|
||||
|
||||
# 3. 更新 chunk_num
|
||||
db_document.chunk_num += len(chunks)
|
||||
db.commit()
|
||||
|
||||
result = {"imported": len(chunks), "failed_rows": failed_rows}
|
||||
logger.info(f"[ImportQA] Done: imported={len(chunks)}, failed={len(failed_rows)}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[ImportQA] Failed: {e}", exc_info=True)
|
||||
return {"error": str(e), "imported": 0}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@celery_app.task(name="app.core.rag.tasks.sync_knowledge_for_kb")
|
||||
def sync_knowledge_for_kb(kb_id: uuid.UUID):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user