@router.post("/{kb_id}/import_qa", response_model=ApiResponse)
async def import_qa_new_doc(
    kb_id: uuid.UUID,
    file: UploadFile = File(..., description="CSV 或 Excel 文件(第一行标题跳过,第一列问题,第二列答案)"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Import QA pairs from a CSV/Excel upload into a newly created document.

    The request only validates the upload, creates the File and Document
    records and enqueues the ``app.core.rag.tasks.import_qa_chunks`` Celery
    task on the ``qa_import`` queue; parsing and indexing happen in the worker.

    Args:
        kb_id: ID of the knowledge base that receives the new document.
        file: Uploaded CSV or Excel file.
        db: Database session injected by FastAPI.
        current_user: Authenticated user injected by FastAPI.

    Returns:
        A ``success`` response whose data holds ``task_id``, ``document_id``
        and ``file_id``.

    Raises:
        HTTPException: 400 when the file name has an unsupported extension or
            the upload is empty; 404 when the knowledge base lookup returns
            nothing for this user.
    """
    api_logger.info(f"Import QA (new doc): kb_id={kb_id}, file={file.filename}, username: {current_user.username}")

    # 1. Validate the file extension.
    # NOTE(review): ".xls" is accepted here although the error message only
    # mentions .csv/.xlsx — confirm the worker can actually parse legacy .xls.
    filename = file.filename or ""
    if not filename.endswith((".csv", ".xlsx", ".xls")):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="仅支持 CSV (.csv) 或 Excel (.xlsx) 格式")

    # 2. Validate the knowledge base (existence and access for this user).
    db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
    if not db_knowledge:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="知识库不存在或无权访问")

    # 3. Read the upload and reject empty content *before* any record is
    #    created, so an empty file cannot leave orphan File/Document rows
    #    behind a task that has nothing to import.
    contents = await file.read()
    file_size = len(contents)
    if file_size == 0:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="上传文件为空")

    # 4. Create the File record.
    from app.schemas import file_schema, document_schema
    _, file_extension = os.path.splitext(filename)
    file_ext = file_extension.lower()

    file_data = file_schema.FileCreate(
        kb_id=kb_id, created_by=current_user.id,
        # Nil UUID as parent — presumably the root folder; TODO confirm.
        parent_id=uuid.UUID("00000000-0000-0000-0000-000000000000"),
        file_name=filename, file_ext=file_ext, file_size=file_size,
    )
    db_file = file_service.create_file(db=db, file=file_data, current_user=current_user)

    # 5. Create the Document record that the imported chunks will belong to.
    doc_data = document_schema.DocumentCreate(
        kb_id=kb_id, created_by=current_user.id, file_id=db_file.id,
        file_name=filename, file_ext=file_ext, file_size=file_size,
        file_meta={}, parser_id="naive",
        parser_config={"layout_recognize": "DeepDOC", "chunk_token_num": 128,
                       "delimiter": "\n", "auto_keywords": 0, "auto_questions": 0, "html4excel": "false"}
    )
    db_document = document_service.create_document(db=db, document=doc_data, current_user=current_user)

    api_logger.info(f"Created doc for QA import: file_id={db_file.id}, document_id={db_document.id}")

    # 6. Dispatch the asynchronous import task; the raw file bytes travel in
    #    the task arguments.
    from app.celery_app import celery_app
    task = celery_app.send_task(
        "app.core.rag.tasks.import_qa_chunks",
        args=[str(kb_id), str(db_document.id), filename, contents],
        queue="qa_import"
    )

    return success(data={
        "task_id": task.id,
        "document_id": str(db_document.id),
        "file_id": str(db_file.id),
    }, msg="QA 导入任务已提交,后台处理中")
db_document or not db_knowledge: + logger.error(f"[ImportQA] document={document_id} or knowledge={kb_id} not found") + return {"error": "document or knowledge not found", "imported": 0} - # 1. 解析文件 - qa_pairs = [] - failed_rows = [] + # 1. 解析文件 + qa_pairs = [] + failed_rows = [] - if filename.endswith(".csv"): - try: - text = contents.decode("utf-8-sig") - except UnicodeDecodeError: - text = contents.decode("gbk", errors="ignore") + if filename.endswith(".csv"): + try: + text = contents.decode("utf-8-sig") + except UnicodeDecodeError: + text = contents.decode("gbk", errors="ignore") - sniffer = csv_module.Sniffer() - try: - dialect = sniffer.sniff(text[:2048]) - delimiter = dialect.delimiter - except csv_module.Error: - delimiter = "," if "," in text[:500] else "\t" + sniffer = csv_module.Sniffer() + try: + dialect = sniffer.sniff(text[:2048]) + delimiter = dialect.delimiter + except csv_module.Error: + delimiter = "," if "," in text[:500] else "\t" - reader = csv_module.reader(io.StringIO(text), delimiter=delimiter) - for i, row in enumerate(reader): - if i == 0: - continue - if len(row) >= 2 and row[0].strip() and row[1].strip(): - qa_pairs.append({"question": row[0].strip(), "answer": row[1].strip()}) - elif len(row) >= 1 and row[0].strip(): - failed_rows.append(i + 1) + reader = csv_module.reader(io.StringIO(text), delimiter=delimiter) + for i, row in enumerate(reader): + if i == 0: + continue + if len(row) >= 2 and row[0].strip() and row[1].strip(): + qa_pairs.append({"question": row[0].strip(), "answer": row[1].strip()}) + elif len(row) >= 1 and row[0].strip(): + failed_rows.append(i + 1) - elif filename.endswith(".xlsx") or filename.endswith(".xls"): - try: - import openpyxl - wb = openpyxl.load_workbook(io.BytesIO(contents), read_only=True) - for sheet in wb.worksheets: - for i, row in enumerate(sheet.iter_rows(values_only=True)): - if i == 0: - continue - if len(row) >= 2 and row[0] and row[1]: - q = str(row[0]).strip() - a = str(row[1]).strip() - if q and 
a: - qa_pairs.append({"question": q, "answer": a}) - elif len(row) >= 1 and row[0]: - failed_rows.append(i + 1) - wb.close() - except Exception as e: - logger.error(f"[ImportQA] Excel parse failed: {e}") - return {"error": f"Excel parse failed: {e}", "imported": 0} + elif filename.endswith(".xlsx") or filename.endswith(".xls"): + try: + import openpyxl + wb = openpyxl.load_workbook(io.BytesIO(contents), read_only=True) + for sheet in wb.worksheets: + for i, row in enumerate(sheet.iter_rows(values_only=True)): + if i == 0: + continue + if len(row) >= 2 and row[0] and row[1]: + q = str(row[0]).strip() + a = str(row[1]).strip() + if q and a: + qa_pairs.append({"question": q, "answer": a}) + elif len(row) >= 1 and row[0]: + failed_rows.append(i + 1) + wb.close() + except Exception as e: + logger.error(f"[ImportQA] Excel parse failed: {e}") + return {"error": f"Excel parse failed: {e}", "imported": 0} - if not qa_pairs: - logger.warning(f"[ImportQA] No valid QA pairs found in {filename}") - return {"error": "No valid QA pairs found", "imported": 0} + if not qa_pairs: + logger.warning(f"[ImportQA] No valid QA pairs found in {filename}") + return {"error": "No valid QA pairs found", "imported": 0} - logger.info(f"[ImportQA] Parsed {len(qa_pairs)} QA pairs from {filename}, failed_rows={failed_rows}") + logger.info(f"[ImportQA] Parsed {len(qa_pairs)} QA pairs from {filename}, failed_rows={failed_rows}") - # 2. 写入 ES - vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) + # 2. 
写入 ES + vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) - sort_id = 0 - total, items = vector_service.search_by_segment(document_id=document_id, pagesize=1, page=1, asc=False) - if items: - sort_id = items[0].metadata["sort_id"] + sort_id = 0 + total, items = vector_service.search_by_segment(document_id=document_id, pagesize=1, page=1, asc=False) + if items: + sort_id = items[0].metadata["sort_id"] - chunks = [] - for pair in qa_pairs: - sort_id += 1 - doc_id = uuid.uuid4().hex - metadata = { - "doc_id": doc_id, - "file_id": str(db_document.file_id), - "file_name": db_document.file_name, - "file_created_at": int(db_document.created_at.timestamp() * 1000), - "document_id": document_id, - "knowledge_id": kb_id, - "sort_id": sort_id, - "status": 1, - "chunk_type": "qa", - "question": pair["question"], - "answer": pair["answer"], - } - chunks.append(DocumentChunk(page_content=pair["question"], metadata=metadata)) + chunks = [] + for pair in qa_pairs: + sort_id += 1 + doc_id = uuid.uuid4().hex + metadata = { + "doc_id": doc_id, + "file_id": str(db_document.file_id), + "file_name": db_document.file_name, + "file_created_at": int(db_document.created_at.timestamp() * 1000), + "document_id": document_id, + "knowledge_id": kb_id, + "sort_id": sort_id, + "status": 1, + "chunk_type": "qa", + "question": pair["question"], + "answer": pair["answer"], + } + chunks.append(DocumentChunk(page_content=pair["question"], metadata=metadata)) - batch_size = 50 - for i in range(0, len(chunks), batch_size): - batch = chunks[i:i + batch_size] - vector_service.add_chunks(batch) + batch_size = 50 + for i in range(0, len(chunks), batch_size): + batch = chunks[i:i + batch_size] + vector_service.add_chunks(batch) - # 3. 更新 chunk_num - db_document.chunk_num += len(chunks) - db.commit() + # 3. 
更新 chunk_num + db_document.chunk_num += len(chunks) + db.commit() - result = {"imported": len(chunks), "failed_rows": failed_rows} - logger.info(f"[ImportQA] Done: imported={len(chunks)}, failed={len(failed_rows)}") - return result + result = {"imported": len(chunks), "failed_rows": failed_rows} + logger.info(f"[ImportQA] Done: imported={len(chunks)}, failed={len(failed_rows)}") + return result except Exception as e: logger.error(f"[ImportQA] Failed: {e}", exc_info=True) return {"error": str(e), "imported": 0} - finally: - db.close() @celery_app.task(name="app.core.rag.tasks.sync_knowledge_for_kb")