diff --git a/api/app/controllers/chunk_controller.py b/api/app/controllers/chunk_controller.py index 07379ee4..cb927504 100644 --- a/api/app/controllers/chunk_controller.py +++ b/api/app/controllers/chunk_controller.py @@ -1,8 +1,10 @@ import os +import csv +import io from typing import Any, Optional import uuid -from fastapi import APIRouter, Depends, HTTPException, status, Query +from fastapi import APIRouter, Depends, HTTPException, status, Query, UploadFile, File from fastapi.encoders import jsonable_encoder from sqlalchemy.orm import Session @@ -346,6 +348,46 @@ async def create_chunks_batch( return success(data=jsonable_encoder(chunks), msg=f"Batch created {len(chunks)} chunks successfully") +@router.post("/{kb_id}/{document_id}/import_qa", response_model=ApiResponse) +async def import_qa_chunks( + kb_id: uuid.UUID, + document_id: uuid.UUID, + file: UploadFile = File(..., description="CSV 或 Excel 文件(第一行标题跳过,第一列问题,第二列答案)"), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """ + 导入 QA 问答对(CSV/Excel),异步处理 + """ + api_logger.info(f"Import QA chunks: kb_id={kb_id}, document_id={document_id}, file={file.filename}, username: {current_user.username}") + + # 1. 校验文件格式 + filename = file.filename or "" + if not (filename.endswith(".csv") or filename.endswith(".xlsx") or filename.endswith(".xls")): + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="仅支持 CSV (.csv) 或 Excel (.xlsx) 格式") + + # 2. 校验知识库和文档 + db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user) + if not db_knowledge: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="知识库不存在或无权访问") + + db_document = db.query(Document).filter(Document.id == document_id).first() + if not db_document: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="文档不存在或无权访问") + + # 3. 读取文件内容,派发异步任务 + contents = await file.read() + + from app.celery_app import celery_app + task = celery_app.send_task( + "app.core.rag.tasks.import_qa_chunks", + args=[str(kb_id), str(document_id), filename, contents], + queue="qa_import" + ) + + return success(data={"task_id": task.id}, msg="QA 导入任务已提交,后台处理中") + + @router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse) async def get_chunk( kb_id: uuid.UUID, diff --git a/api/app/tasks.py b/api/app/tasks.py index 4d39cf7a..2fcab818 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -697,6 +697,123 @@ def build_graphrag_for_document(document_id: str, knowledge_id: str): return f"build_graphrag_for_document '{document_id}' failed: {e}" +@celery_app.task(name="app.core.rag.tasks.import_qa_chunks", queue="qa_import") +def import_qa_chunks(kb_id: str, document_id: str, filename: str, contents: bytes): + """ + 异步导入 QA 问答对(CSV/Excel) + + 文件格式:第一行标题(跳过),第一列问题,第二列答案 + """ + import csv as csv_module + import io + + db = SessionLocal() + try: + db_document = db.query(Document).filter(Document.id == uuid.UUID(document_id)).first() + db_knowledge = db.query(Knowledge).filter(Knowledge.id == uuid.UUID(kb_id)).first() + if not db_document or not db_knowledge: + logger.error(f"[ImportQA] document={document_id} or knowledge={kb_id} not found") + return {"error": "document or knowledge not found", "imported": 0} + + # 1. 解析文件 + qa_pairs = [] + failed_rows = [] + + if filename.endswith(".csv"): + try: + text = contents.decode("utf-8-sig") + except UnicodeDecodeError: + text = contents.decode("gbk", errors="ignore") + + sniffer = csv_module.Sniffer() + try: + dialect = sniffer.sniff(text[:2048]) + delimiter = dialect.delimiter + except csv_module.Error: + delimiter = "," if "," in text[:500] else "\t" + + reader = csv_module.reader(io.StringIO(text), delimiter=delimiter) + for i, row in enumerate(reader): + if i == 0: + continue + if len(row) >= 2 and row[0].strip() and row[1].strip(): + qa_pairs.append({"question": row[0].strip(), "answer": row[1].strip()}) + elif len(row) >= 1 and row[0].strip(): + failed_rows.append(i + 1) + + elif filename.endswith(".xlsx") or filename.endswith(".xls"): + try: + import openpyxl + wb = openpyxl.load_workbook(io.BytesIO(contents), read_only=True) + for sheet in wb.worksheets: + for i, row in enumerate(sheet.iter_rows(values_only=True)): + if i == 0: + continue + if len(row) >= 2 and row[0] and row[1]: + q = str(row[0]).strip() + a = str(row[1]).strip() + if q and a: + qa_pairs.append({"question": q, "answer": a}) + elif len(row) >= 1 and row[0]: + failed_rows.append(i + 1) + wb.close() + except Exception as e: + logger.error(f"[ImportQA] Excel parse failed: {e}") + return {"error": f"Excel parse failed: {e}", "imported": 0} + + if not qa_pairs: + logger.warning(f"[ImportQA] No valid QA pairs found in {filename}") + return {"error": "No valid QA pairs found", "imported": 0} + + logger.info(f"[ImportQA] Parsed {len(qa_pairs)} QA pairs from {filename}, failed_rows={failed_rows}") + + # 2. 写入 ES + vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) + + sort_id = 0 + total, items = vector_service.search_by_segment(document_id=document_id, pagesize=1, page=1, asc=False) + if items: + sort_id = items[0].metadata["sort_id"] + + chunks = [] + for pair in qa_pairs: + sort_id += 1 + doc_id = uuid.uuid4().hex + metadata = { + "doc_id": doc_id, + "file_id": str(db_document.file_id), + "file_name": db_document.file_name, + "file_created_at": int(db_document.created_at.timestamp() * 1000), + "document_id": document_id, + "knowledge_id": kb_id, + "sort_id": sort_id, + "status": 1, + "chunk_type": "qa", + "question": pair["question"], + "answer": pair["answer"], + } + chunks.append(DocumentChunk(page_content=pair["question"], metadata=metadata)) + + batch_size = 50 + for i in range(0, len(chunks), batch_size): + batch = chunks[i:i + batch_size] + vector_service.add_chunks(batch) + + # 3. 更新 chunk_num + db_document.chunk_num += len(chunks) + db.commit() + + result = {"imported": len(chunks), "failed_rows": failed_rows} + logger.info(f"[ImportQA] Done: imported={len(chunks)}, failed={len(failed_rows)}") + return result + + except Exception as e: + logger.error(f"[ImportQA] Failed: {e}", exc_info=True) + return {"error": str(e), "imported": 0} + finally: + db.close() + + @celery_app.task(name="app.core.rag.tasks.sync_knowledge_for_kb") def sync_knowledge_for_kb(kb_id: uuid.UUID): """