From 0f50537d7d25d2569896f36afeadd5d7170709bf Mon Sep 17 00:00:00 2001 From: Mark Date: Thu, 9 Apr 2026 14:11:01 +0800 Subject: [PATCH] [modify] mineru --- .../core/rag/deepdoc/parser/mineru_parser.py | 7 +- api/app/tasks.py | 161 +++++++++++++----- 2 files changed, 126 insertions(+), 42 deletions(-) diff --git a/api/app/core/rag/deepdoc/parser/mineru_parser.py b/api/app/core/rag/deepdoc/parser/mineru_parser.py index fe6178ec..c2f7af16 100644 --- a/api/app/core/rag/deepdoc/parser/mineru_parser.py +++ b/api/app/core/rag/deepdoc/parser/mineru_parser.py @@ -292,9 +292,10 @@ class MinerUParser(RAGPdfParser): self.page_from = page_from self.page_to = page_to try: - with pdfplumber.open(fnm) if isinstance(fnm, (str, PathLike)) else pdfplumber.open(BytesIO(fnm)) as pdf: - self.pdf = pdf - self.page_images = [p.to_image(resolution=72 * zoomin, antialias=True).original for _, p in enumerate(self.pdf.pages[page_from:page_to])] + with sys.modules[LOCK_KEY_pdfplumber]: # hold the global lock while pdfplumber opens the file and renders the page images; NOTE(review): confirm the lock object is registered under this key before this runs + with pdfplumber.open(fnm) if isinstance(fnm, (str, PathLike)) else pdfplumber.open(BytesIO(fnm)) as pdf: + self.pdf = pdf + self.page_images = [p.to_image(resolution=72 * zoomin, antialias=True).original for _, p in enumerate(self.pdf.pages[page_from:page_to])] except Exception as e: self.page_images = None self.total_page = 0 diff --git a/api/app/tasks.py b/api/app/tasks.py index 60e9f4e4..c32b8ad9 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -56,6 +56,10 @@ VIDEO_IMAGE_PATTERN = re.compile( DEFAULT_PARSE_LANGUAGE = "Chinese" DEFAULT_PARSE_TO_PAGE = 100_000 EMBEDDING_BATCH_SIZE = 100 +# Max worker threads for concurrent embedding writes; tune to the model API rate limit (env EMBEDDING_MAX_WORKERS, default 3) +EMBEDDING_MAX_WORKERS = int(os.getenv("EMBEDDING_MAX_WORKERS", "3")) +# Max worker threads for concurrent auto_questions LLM calls (env AUTO_QUESTIONS_MAX_WORKERS, default 5) +AUTO_QUESTIONS_MAX_WORKERS = int(os.getenv("AUTO_QUESTIONS_MAX_WORKERS", "5")) # 模块级同步 Redis 连接池,供 Celery 任务共享使用 # 连接 CELERY_BACKEND DB,与 write_message:last_done 时间戳写入保持一致 @@ -218,6 +222,10 @@ def parse_document(file_path: str, document_id: uuid.UUID): 
with get_db_context() as db: try: + # Celery's JSON serialization turns the UUID into a string, so coerce it back to uuid.UUID before querying + if not isinstance(document_id, uuid.UUID): + document_id = uuid.UUID(str(document_id)) + db_document = db.query(Document).filter(Document.id == document_id).first() if db_document is None: raise ValueError(f"Document {document_id} not found") @@ -265,7 +273,7 @@ def parse_document(file_path: str, document_id: uuid.UUID): progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} No chunks generated, skipping vectorization.") else: total_batches = ceil(total_chunks / EMBEDDING_BATCH_SIZE) - progress_per_batch = 0.2 / total_batches # Progress of each batch + progress_per_batch = 0.2 / total_batches # NOTE(review): the only read of this value visible in this patch is removed below - confirm it is still needed vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) # 2.1 Delete document vector index vector_service.delete_by_metadata_field(key="document_id", value=str(document_id)) @@ -278,47 +286,114 @@ def parse_document(file_path: str, document_id: uuid.UUID): model_name=db_knowledge.llm.api_keys[0].model_name, base_url=db_knowledge.llm.api_keys[0].api_base, ) - for batch_start in range(0, total_chunks, EMBEDDING_BATCH_SIZE): - batch_end = min(batch_start + EMBEDDING_BATCH_SIZE, total_chunks) - batch = res[batch_start: batch_end] - chunks = [] - for idx_in_batch, item in enumerate(batch): - global_idx = batch_start + idx_in_batch - metadata = { - "doc_id": uuid.uuid4().hex, - "file_id": str(db_document.file_id), - "file_name": db_document.file_name, - "file_created_at": int(db_document.created_at.timestamp() * 1000), - "document_id": str(db_document.id), - "knowledge_id": str(db_document.kb_id), - "sort_id": global_idx, - "status": 1, - } - if auto_questions_topn: - cached = get_llm_cache(chat_model.model_name, item["content_with_weight"], "question", - {"topn": auto_questions_topn}) - if not cached: - cached = question_proposal(chat_model, item["content_with_weight"], auto_questions_topn) - set_llm_cache(chat_model.model_name, item["content_with_weight"], cached, 
"question", - {"topn": auto_questions_topn}) + # Build the chunks for every batch up front so sort_id stays globally ordered + all_batch_chunks: list[list[DocumentChunk]] = [] + + if auto_questions_topn: + # auto_questions enabled: generate the questions for all chunks concurrently first, then group by batch + # Build the list of (global_idx, item) pairs + indexed_items = list(enumerate(res)) + + def _generate_question(idx_item: tuple[int, dict]) -> tuple[int, str]: + """Generate the question text for one chunk (going through the LLM cache); return (global_idx, question_text).""" + global_idx, item = idx_item + content = item["content_with_weight"] + cached = get_llm_cache(chat_model.model_name, content, "question", + {"topn": auto_questions_topn}) + if not cached: + cached = question_proposal(chat_model, content, auto_questions_topn) + set_llm_cache(chat_model.model_name, content, cached, "question", + {"topn": auto_questions_topn}) + return global_idx, cached + + # Call the LLM concurrently to generate the questions + question_map: dict[int, str] = {} + with ThreadPoolExecutor(max_workers=AUTO_QUESTIONS_MAX_WORKERS) as q_executor: + futures = {q_executor.submit(_generate_question, item): item[0] + for item in indexed_items} + for future in futures: + global_idx, cached = future.result() + question_map[global_idx] = cached + + progress_lines.append( + f"{datetime.now().strftime('%H:%M:%S')} Auto questions generated for {total_chunks} chunks " + f"(workers={AUTO_QUESTIONS_MAX_WORKERS}).") + + # Assemble the DocumentChunk objects grouped by batch + for batch_start in range(0, total_chunks, EMBEDDING_BATCH_SIZE): + batch_end = min(batch_start + EMBEDDING_BATCH_SIZE, total_chunks) + chunks = [] + for global_idx in range(batch_start, batch_end): + item = res[global_idx] + metadata = { + "doc_id": uuid.uuid4().hex, + "file_id": str(db_document.file_id), + "file_name": db_document.file_name, + "file_created_at": int(db_document.created_at.timestamp() * 1000), + "document_id": str(db_document.id), + "knowledge_id": str(db_document.kb_id), + "sort_id": global_idx, + "status": 1, + } + cached = question_map[global_idx] chunks.append( - DocumentChunk(page_content=f"question: {cached} answer: 
{item['content_with_weight']}", - metadata=metadata)) - else: + DocumentChunk( + page_content=f"question: {cached} answer: {item['content_with_weight']}", + metadata=metadata)) + all_batch_chunks.append(chunks) + else: + # auto_questions disabled: build the chunks directly + for batch_start in range(0, total_chunks, EMBEDDING_BATCH_SIZE): + batch_end = min(batch_start + EMBEDDING_BATCH_SIZE, total_chunks) + chunks = [] + for global_idx in range(batch_start, batch_end): + item = res[global_idx] + metadata = { + "doc_id": uuid.uuid4().hex, + "file_id": str(db_document.file_id), + "file_name": db_document.file_name, + "file_created_at": int(db_document.created_at.timestamp() * 1000), + "document_id": str(db_document.id), + "knowledge_id": str(db_document.kb_id), + "sort_id": global_idx, + "status": 1, + } chunks.append(DocumentChunk(page_content=item["content_with_weight"], metadata=metadata)) + all_batch_chunks.append(chunks) - # Bulk segmented vector import - vector_service.add_chunks(chunks) + # Submit the embedding and ES writes concurrently; max_workers bounds the concurrent load on the model API + batch_errors: dict[int, Exception] = {} - # Update progress - db_document.progress += progress_per_batch - progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Embedding progress ({db_document.progress}).") - db_document.progress_msg = _progress_msg() - db_document.process_duration = time.time() - start_time - db_document.run = 0 - db.commit() - db.refresh(db_document) + def _embed_and_store(batch_idx: int, batch_chunks: list[DocumentChunk]): + try: + vector_service.add_chunks(batch_chunks) + except Exception as exc: + batch_errors[batch_idx] = exc + + with ThreadPoolExecutor(max_workers=EMBEDDING_MAX_WORKERS) as executor: + futures = { + executor.submit(_embed_and_store, i, batch_chunks): i + for i, batch_chunks in enumerate(all_batch_chunks) + } + # Wait for every batch in submission order (progress is written once, after all batches finish) + for future in futures: + future.result() # wait for completion (exceptions are already caught inside _embed_and_store) + + # If any batch failed, raise a single aggregated error + if batch_errors: + failed = ", ".join(str(i) for i in 
sorted(batch_errors)) + first_err = next(iter(batch_errors.values())) + raise RuntimeError(f"Embedding failed for batch(es) [{failed}]: {first_err}") from first_err + + # Update the progress once, after all batches have completed + db_document.progress = 0.8 + 0.2 # set the total directly instead of adding progress_per_batch after each batch + progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} All {total_batches} batches embedded (workers={EMBEDDING_MAX_WORKERS}).") + db_document.progress_msg = _progress_msg() + db_document.process_duration = time.time() - start_time + db_document.run = 0 + db.commit() + db.refresh(db_document) # Vectorization and data entry completed progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Indexing done.") @@ -344,13 +419,15 @@ def parse_document(file_path: str, document_id: uuid.UUID): logger.error(f"[ParseDoc] document={document_id} failed: {e}", exc_info=True) if db_document is not None: try: + db.rollback() db_document.progress_msg = _progress_msg() + f"Failed to vectorize and import the parsed document:{str(e)}\n" db_document.run = 0 db.commit() except Exception: logger.warning(f"[ParseDoc] document={document_id} failed to update error status in DB", exc_info=True) - file_name = db_document.file_name if db_document else document_id - return f"parse document '{file_name}' failed." + # db_document may be in a detached/expired state here, so read file_name defensively and fall back to document_id + file_name = getattr(db_document, 'file_name', None) if db_document else None + return f"parse document '{file_name or document_id}' failed." 
@celery_app.task(name="app.core.rag.tasks.build_graphrag_for_kb") @@ -365,6 +442,9 @@ def build_graphrag_for_kb(kb_id: uuid.UUID): with get_db_context() as db: try: + if not isinstance(kb_id, uuid.UUID): + kb_id = uuid.UUID(str(kb_id)) + db_knowledge = db.query(Knowledge).filter(Knowledge.id == kb_id).first() if db_knowledge is None: logger.error(f"[GraphRAG-KB] knowledge={kb_id} not found") @@ -524,6 +604,9 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID): """ with get_db_context() as db: try: + if not isinstance(kb_id, uuid.UUID): + kb_id = uuid.UUID(str(kb_id)) + db_knowledge = db.query(Knowledge).filter(Knowledge.id == kb_id).first() if db_knowledge is None: logger.error(f"[SyncKB] knowledge={kb_id} not found")