From 64e640d882a82e83a90568f8cfc7ce31dfbae1c5 Mon Sep 17 00:00:00 2001 From: Mark <348207283@qq.com> Date: Tue, 28 Apr 2026 15:33:44 +0800 Subject: [PATCH] [add] batch chunk. qa_prompt set --- api/app/controllers/chunk_controller.py | 61 +++++++++++++++++++++++++ api/app/core/config.py | 1 + api/app/core/rag/prompts/generator.py | 16 +++++-- api/app/schemas/chunk_schema.py | 5 ++ api/app/tasks.py | 30 +++++++++--- 5 files changed, 103 insertions(+), 10 deletions(-) diff --git a/api/app/controllers/chunk_controller.py b/api/app/controllers/chunk_controller.py index fee26669..07379ee4 100644 --- a/api/app/controllers/chunk_controller.py +++ b/api/app/controllers/chunk_controller.py @@ -285,6 +285,67 @@ async def create_chunk( return success(data=jsonable_encoder(chunk), msg="Document chunk creation successful") +@router.post("/{kb_id}/{document_id}/chunk/batch", response_model=ApiResponse) +async def create_chunks_batch( + kb_id: uuid.UUID, + document_id: uuid.UUID, + batch_data: chunk_schema.ChunkBatchCreate, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """ + Batch create chunks (max 8) + """ + api_logger.info(f"Batch create chunks: kb_id={kb_id}, document_id={document_id}, count={len(batch_data.items)}, username: {current_user.username}") + + if len(batch_data.items) > settings.MAX_CHUNK_BATCH_SIZE: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Batch size exceeds limit: max {settings.MAX_CHUNK_BATCH_SIZE}, got {len(batch_data.items)}" + ) + + db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user) + if not db_knowledge: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="The knowledge base does not exist or access is denied") + + db_document = db.query(Document).filter(Document.id == document_id).first() + if not db_document: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="The document does not exist or you do not have permission to access it") + + vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) + + # Get current max sort_id + sort_id = 0 + total, items = vector_service.search_by_segment(document_id=str(document_id), pagesize=1, page=1, asc=False) + if items: + sort_id = items[0].metadata["sort_id"] + + chunks = [] + for create_data in batch_data.items: + sort_id += 1 + doc_id = uuid.uuid4().hex + metadata = { + "doc_id": doc_id, + "file_id": str(db_document.file_id), + "file_name": db_document.file_name, + "file_created_at": int(db_document.created_at.timestamp() * 1000), + "document_id": str(document_id), + "knowledge_id": str(kb_id), + "sort_id": sort_id, + "status": 1, + } + if create_data.is_qa: + metadata.update(create_data.qa_metadata) + chunks.append(DocumentChunk(page_content=create_data.chunk_content, metadata=metadata)) + + vector_service.add_chunks(chunks) + + db_document.chunk_num += len(chunks) + db.commit() + + return success(data=jsonable_encoder(chunks), msg=f"Batch created {len(chunks)} chunks successfully") + + @router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse) async def get_chunk( kb_id: uuid.UUID, diff --git a/api/app/core/config.py b/api/app/core/config.py index 64c5520e..725c35ce 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -98,6 +98,7 @@ class Settings: # File Upload MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800")) MAX_FILE_COUNT: int = int(os.getenv("MAX_FILE_COUNT", "20")) + MAX_CHUNK_BATCH_SIZE: int = int(os.getenv("MAX_CHUNK_BATCH_SIZE", "8")) FILE_PATH: str = os.getenv("FILE_PATH", "/files") FILE_URL_EXPIRES: int = int(os.getenv("FILE_URL_EXPIRES", "3600")) diff --git a/api/app/core/rag/prompts/generator.py b/api/app/core/rag/prompts/generator.py index 88e9171f..4af8fa6f 100644 --- a/api/app/core/rag/prompts/generator.py +++ b/api/app/core/rag/prompts/generator.py @@ -138,9 +138,19 @@ def question_proposal(chat_mdl, content, topn=3): return "\n".join([p["question"] for p in pairs]) -def qa_proposal(chat_mdl, content, topn=3): - """生成 QA 对,返回 [{"question": ..., "answer": ...}, ...]""" - template = PROMPT_JINJA_ENV.from_string(QUESTION_PROMPT_TEMPLATE) +def qa_proposal(chat_mdl, content, topn=3, custom_prompt=None): + """生成 QA 对,返回 [{"question": ..., "answer": ...}, ...] + + Args: + chat_mdl: LLM 模型 + content: 文本内容 + topn: 生成 QA 对数量 + custom_prompt: 自定义 prompt 模板(支持 Jinja2,可用变量: content, topn) + """ + if custom_prompt: + template = PROMPT_JINJA_ENV.from_string(custom_prompt) + else: + template = PROMPT_JINJA_ENV.from_string(QUESTION_PROMPT_TEMPLATE) rendered_prompt = template.render(content=content, topn=topn) msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}] diff --git a/api/app/schemas/chunk_schema.py b/api/app/schemas/chunk_schema.py index b888c361..9c785c5e 100644 --- a/api/app/schemas/chunk_schema.py +++ b/api/app/schemas/chunk_schema.py @@ -77,3 +77,8 @@ class ChunkRetrieve(BaseModel): vector_similarity_weight: float | None = Field(None) top_k: int | None = Field(None) retrieve_type: RetrieveType | None = Field(None) + + +class ChunkBatchCreate(BaseModel): + """批量创建 chunk""" + items: list[ChunkCreate] = Field(..., min_length=1, description="chunk 列表") diff --git a/api/app/tasks.py b/api/app/tasks.py index ed961115..114959b6 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -311,6 +311,7 @@ def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""): vector_service.delete_by_metadata_field(key="document_id", value=str(document_id)) # 2.2 Vectorize and import batch documents auto_questions_topn = db_document.parser_config.get("auto_questions", 0) + qa_prompt = db_document.parser_config.get("qa_prompt", None) chat_model = None if auto_questions_topn: chat_model = Base( @@ -318,6 +319,9 @@ def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""): model_name=db_knowledge.llm.api_keys[0].model_name, base_url=db_knowledge.llm.api_keys[0].api_base, ) + logger.info(f"[QA] LLM model: {db_knowledge.llm.api_keys[0].model_name}, base_url: {db_knowledge.llm.api_keys[0].api_base}") + if qa_prompt: + logger.info(f"[QA] Using custom prompt ({len(qa_prompt)} chars)") # 预先构建所有 batch 的 chunks,保证 sort_id 全局有序 all_batch_chunks: list[list[DocumentChunk]] = [] @@ -335,15 +339,27 @@ def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""): cached = get_llm_cache(chat_model.model_name, content, "qa", {"topn": auto_questions_topn}) if not cached: - pairs = qa_proposal(chat_model, content, auto_questions_topn) - cached = pairs - set_llm_cache(chat_model.model_name, content, cached, "qa", + try: + pairs = qa_proposal(chat_model, content, auto_questions_topn, custom_prompt=qa_prompt) + except Exception as e: + logger.error(f"[QA] LLM call failed: model={chat_model.model_name}, base_url={getattr(chat_model, 'base_url', 'N/A')}, error={e}") + return global_idx, [] + # 缓存存 JSON 字符串 + set_llm_cache(chat_model.model_name, content, json.dumps(pairs, ensure_ascii=False), "qa", {"topn": auto_questions_topn}) - elif isinstance(cached, str): - # 兼容旧缓存格式(纯文本问题) + return global_idx, pairs + # 从缓存读取:可能是 JSON 字符串或旧格式纯文本 + if isinstance(cached, str): + try: + parsed = json.loads(cached) + if isinstance(parsed, list): + return global_idx, parsed + except (json.JSONDecodeError, TypeError): + pass + # 旧缓存格式(纯文本问题),尝试解析 from app.core.rag.prompts.generator import parse_qa_pairs - cached = parse_qa_pairs(cached) if cached else [] - return global_idx, cached + return global_idx, parse_qa_pairs(cached) if cached else [] + return global_idx, cached if isinstance(cached, list) else [] # 并发调用 LLM 生成 QA 对 qa_map: dict[int, list] = {}