no message

This commit is contained in:
Mark
2026-04-29 11:44:03 +08:00
parent 90aa4cef21
commit 6e89302cb2
2 changed files with 150 additions and 87 deletions

View File

@@ -348,6 +348,69 @@ async def create_chunks_batch(
return success(data=jsonable_encoder(chunks), msg=f"Batch created {len(chunks)} chunks successfully")
@router.post("/{kb_id}/import_qa", response_model=ApiResponse)
async def import_qa_new_doc(
    kb_id: uuid.UUID,
    file: UploadFile = File(..., description="CSV 或 Excel 文件(第一行标题跳过,第一列问题,第二列答案)"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Import QA pairs from an uploaded CSV/Excel file into a newly created document.

    The endpoint validates the file name suffix and the knowledge base,
    creates a File record and a Document record for the upload, and then
    dispatches the ``app.core.rag.tasks.import_qa_chunks`` Celery task on the
    ``qa_import`` queue so the file is processed in the background.

    Returns a success response carrying ``task_id``, ``document_id`` and
    ``file_id``. Raises ``HTTPException`` with status 400 when the suffix is
    not accepted, and 404 when the knowledge base lookup returns nothing.
    """
    api_logger.info(f"Import QA (new doc): kb_id={kb_id}, file={file.filename}, username: {current_user.username}")

    # 1. Validate the file format. The suffix comparison is case-sensitive.
    # NOTE(review): ".xls" passes this check although the error text below
    # only mentions .csv and .xlsx -- confirm which one is intended.
    upload_name = file.filename or ""
    accepted_suffixes = (".csv", ".xlsx", ".xls")
    if not upload_name.endswith(accepted_suffixes):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="仅支持 CSV (.csv) 或 Excel (.xlsx) 格式",
        )

    # 2. Validate the knowledge base for the current user.
    knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
    if not knowledge:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="知识库不存在或无权访问",
        )

    # 3. Create the File record.
    from app.schemas import file_schema, document_schema

    extension = os.path.splitext(upload_name)[1].lower()
    # The whole upload is read into memory; the same bytes are later handed
    # to the background task.
    payload = await file.read()
    payload_size = len(payload)
    file_payload = file_schema.FileCreate(
        kb_id=kb_id,
        created_by=current_user.id,
        # All-zero UUID as parent -- presumably the root folder; TODO confirm.
        parent_id=uuid.UUID("00000000-0000-0000-0000-000000000000"),
        file_name=upload_name,
        file_ext=extension,
        file_size=payload_size,
    )
    file_record = file_service.create_file(db=db, file=file_payload, current_user=current_user)

    # 4. Create the Document record that points at the new file.
    parser_settings = {
        "layout_recognize": "DeepDOC",
        "chunk_token_num": 128,
        "delimiter": "\n",
        "auto_keywords": 0,
        "auto_questions": 0,
        "html4excel": "false",
    }
    doc_payload = document_schema.DocumentCreate(
        kb_id=kb_id,
        created_by=current_user.id,
        file_id=file_record.id,
        file_name=upload_name,
        file_ext=extension,
        file_size=payload_size,
        file_meta={},
        parser_id="naive",
        parser_config=parser_settings,
    )
    doc_record = document_service.create_document(db=db, document=doc_payload, current_user=current_user)
    api_logger.info(f"Created doc for QA import: file_id={file_record.id}, document_id={doc_record.id}")

    # 5. Dispatch the asynchronous import task.
    from app.celery_app import celery_app

    task_args = [str(kb_id), str(doc_record.id), upload_name, payload]
    async_result = celery_app.send_task(
        "app.core.rag.tasks.import_qa_chunks",
        args=task_args,
        queue="qa_import",
    )

    response_data = {
        "task_id": async_result.id,
        "document_id": str(doc_record.id),
        "file_id": str(file_record.id),
    }
    return success(data=response_data, msg="QA 导入任务已提交,后台处理中")
@router.post("/{kb_id}/{document_id}/import_qa", response_model=ApiResponse)
async def import_qa_chunks(
kb_id: uuid.UUID,

View File

@@ -707,111 +707,111 @@ def import_qa_chunks(kb_id: str, document_id: str, filename: str, contents: byte
import csv as csv_module
import io
db = SessionLocal()
db = None
try:
db_document = db.query(Document).filter(Document.id == uuid.UUID(document_id)).first()
db_knowledge = db.query(Knowledge).filter(Knowledge.id == uuid.UUID(kb_id)).first()
if not db_document or not db_knowledge:
logger.error(f"[ImportQA] document={document_id} or knowledge={kb_id} not found")
return {"error": "document or knowledge not found", "imported": 0}
from app.db import get_db_context
with get_db_context() as db:
db_document = db.query(Document).filter(Document.id == uuid.UUID(document_id)).first()
db_knowledge = db.query(Knowledge).filter(Knowledge.id == uuid.UUID(kb_id)).first()
if not db_document or not db_knowledge:
logger.error(f"[ImportQA] document={document_id} or knowledge={kb_id} not found")
return {"error": "document or knowledge not found", "imported": 0}
# 1. 解析文件
qa_pairs = []
failed_rows = []
# 1. 解析文件
qa_pairs = []
failed_rows = []
if filename.endswith(".csv"):
try:
text = contents.decode("utf-8-sig")
except UnicodeDecodeError:
text = contents.decode("gbk", errors="ignore")
if filename.endswith(".csv"):
try:
text = contents.decode("utf-8-sig")
except UnicodeDecodeError:
text = contents.decode("gbk", errors="ignore")
sniffer = csv_module.Sniffer()
try:
dialect = sniffer.sniff(text[:2048])
delimiter = dialect.delimiter
except csv_module.Error:
delimiter = "," if "," in text[:500] else "\t"
sniffer = csv_module.Sniffer()
try:
dialect = sniffer.sniff(text[:2048])
delimiter = dialect.delimiter
except csv_module.Error:
delimiter = "," if "," in text[:500] else "\t"
reader = csv_module.reader(io.StringIO(text), delimiter=delimiter)
for i, row in enumerate(reader):
if i == 0:
continue
if len(row) >= 2 and row[0].strip() and row[1].strip():
qa_pairs.append({"question": row[0].strip(), "answer": row[1].strip()})
elif len(row) >= 1 and row[0].strip():
failed_rows.append(i + 1)
reader = csv_module.reader(io.StringIO(text), delimiter=delimiter)
for i, row in enumerate(reader):
if i == 0:
continue
if len(row) >= 2 and row[0].strip() and row[1].strip():
qa_pairs.append({"question": row[0].strip(), "answer": row[1].strip()})
elif len(row) >= 1 and row[0].strip():
failed_rows.append(i + 1)
elif filename.endswith(".xlsx") or filename.endswith(".xls"):
try:
import openpyxl
wb = openpyxl.load_workbook(io.BytesIO(contents), read_only=True)
for sheet in wb.worksheets:
for i, row in enumerate(sheet.iter_rows(values_only=True)):
if i == 0:
continue
if len(row) >= 2 and row[0] and row[1]:
q = str(row[0]).strip()
a = str(row[1]).strip()
if q and a:
qa_pairs.append({"question": q, "answer": a})
elif len(row) >= 1 and row[0]:
failed_rows.append(i + 1)
wb.close()
except Exception as e:
logger.error(f"[ImportQA] Excel parse failed: {e}")
return {"error": f"Excel parse failed: {e}", "imported": 0}
elif filename.endswith(".xlsx") or filename.endswith(".xls"):
try:
import openpyxl
wb = openpyxl.load_workbook(io.BytesIO(contents), read_only=True)
for sheet in wb.worksheets:
for i, row in enumerate(sheet.iter_rows(values_only=True)):
if i == 0:
continue
if len(row) >= 2 and row[0] and row[1]:
q = str(row[0]).strip()
a = str(row[1]).strip()
if q and a:
qa_pairs.append({"question": q, "answer": a})
elif len(row) >= 1 and row[0]:
failed_rows.append(i + 1)
wb.close()
except Exception as e:
logger.error(f"[ImportQA] Excel parse failed: {e}")
return {"error": f"Excel parse failed: {e}", "imported": 0}
if not qa_pairs:
logger.warning(f"[ImportQA] No valid QA pairs found in {filename}")
return {"error": "No valid QA pairs found", "imported": 0}
if not qa_pairs:
logger.warning(f"[ImportQA] No valid QA pairs found in {filename}")
return {"error": "No valid QA pairs found", "imported": 0}
logger.info(f"[ImportQA] Parsed {len(qa_pairs)} QA pairs from {filename}, failed_rows={failed_rows}")
logger.info(f"[ImportQA] Parsed {len(qa_pairs)} QA pairs from {filename}, failed_rows={failed_rows}")
# 2. 写入 ES
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
# 2. 写入 ES
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
sort_id = 0
total, items = vector_service.search_by_segment(document_id=document_id, pagesize=1, page=1, asc=False)
if items:
sort_id = items[0].metadata["sort_id"]
sort_id = 0
total, items = vector_service.search_by_segment(document_id=document_id, pagesize=1, page=1, asc=False)
if items:
sort_id = items[0].metadata["sort_id"]
chunks = []
for pair in qa_pairs:
sort_id += 1
doc_id = uuid.uuid4().hex
metadata = {
"doc_id": doc_id,
"file_id": str(db_document.file_id),
"file_name": db_document.file_name,
"file_created_at": int(db_document.created_at.timestamp() * 1000),
"document_id": document_id,
"knowledge_id": kb_id,
"sort_id": sort_id,
"status": 1,
"chunk_type": "qa",
"question": pair["question"],
"answer": pair["answer"],
}
chunks.append(DocumentChunk(page_content=pair["question"], metadata=metadata))
chunks = []
for pair in qa_pairs:
sort_id += 1
doc_id = uuid.uuid4().hex
metadata = {
"doc_id": doc_id,
"file_id": str(db_document.file_id),
"file_name": db_document.file_name,
"file_created_at": int(db_document.created_at.timestamp() * 1000),
"document_id": document_id,
"knowledge_id": kb_id,
"sort_id": sort_id,
"status": 1,
"chunk_type": "qa",
"question": pair["question"],
"answer": pair["answer"],
}
chunks.append(DocumentChunk(page_content=pair["question"], metadata=metadata))
batch_size = 50
for i in range(0, len(chunks), batch_size):
batch = chunks[i:i + batch_size]
vector_service.add_chunks(batch)
batch_size = 50
for i in range(0, len(chunks), batch_size):
batch = chunks[i:i + batch_size]
vector_service.add_chunks(batch)
# 3. 更新 chunk_num
db_document.chunk_num += len(chunks)
db.commit()
# 3. 更新 chunk_num
db_document.chunk_num += len(chunks)
db.commit()
result = {"imported": len(chunks), "failed_rows": failed_rows}
logger.info(f"[ImportQA] Done: imported={len(chunks)}, failed={len(failed_rows)}")
return result
result = {"imported": len(chunks), "failed_rows": failed_rows}
logger.info(f"[ImportQA] Done: imported={len(chunks)}, failed={len(failed_rows)}")
return result
except Exception as e:
logger.error(f"[ImportQA] Failed: {e}", exc_info=True)
return {"error": str(e), "imported": 0}
finally:
db.close()
@celery_app.task(name="app.core.rag.tasks.sync_knowledge_for_kb")