From 3ff44f01087e4e650f52091b3740862350460f12 Mon Sep 17 00:00:00 2001 From: Mark Date: Thu, 9 Apr 2026 11:59:02 +0800 Subject: [PATCH 1/3] =?UTF-8?q?[modify]=20=E4=BC=98=E5=8C=96tasks=20?= =?UTF-8?q?=EF=BC=8C=E6=8B=86=E5=88=86graphirag=20=E9=98=9F=E5=88=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/app/celery_app.py | 5 +- api/app/tasks.py | 565 +++++++++++++++++++++--------------------- 2 files changed, 281 insertions(+), 289 deletions(-) diff --git a/api/app/celery_app.py b/api/app/celery_app.py index 23fd82ed..ba74fb19 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -113,9 +113,12 @@ celery_app.conf.update( # Document tasks → document_tasks queue (prefork worker) 'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'}, - 'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'}, 'app.core.rag.tasks.sync_knowledge_for_kb': {'queue': 'document_tasks'}, + # GraphRAG tasks → graphrag_tasks queue (独立队列,避免阻塞文档解析) + 'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'graphrag_tasks'}, + 'app.core.rag.tasks.build_graphrag_for_document': {'queue': 'graphrag_tasks'}, + # Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker) 'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'}, 'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'}, diff --git a/api/app/tasks.py b/api/app/tasks.py index f918743c..60e9f4e4 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -44,6 +44,19 @@ from app.utils.redis_lock import RedisFairLock logger = get_logger(__name__) +# ── 预编译文件类型正则 & 常量 ────────────────────────────────── +AUDIO_PATTERN = re.compile( + r"\.(da|wave|wav|mp3|aac|flac|ogg|aiff|au|midi|wma|realaudio|vqf|oggvorbis|ape?)$", + re.IGNORECASE, +) +VIDEO_IMAGE_PATTERN = re.compile( + r"\.(png|jpeg|jpg|gif|bmp|svg|mp4|mov|avi|flv|mpeg|mpg|webm|wmv|3gp|3gpp|mkv?)$", + re.IGNORECASE, +) +DEFAULT_PARSE_LANGUAGE = "Chinese" +DEFAULT_PARSE_TO_PAGE = 100_000 +EMBEDDING_BATCH_SIZE = 100 + # 模块级同步 Redis 连接池,供 Celery 任务共享使用 # 连接 CELERY_BACKEND DB,与 write_message:last_done 时间戳写入保持一致 # 使用连接池而非单例客户端,提供更好的并发性能和自动重连 @@ -160,28 +173,63 @@ def process_item(item: dict): return result +def _build_vision_model(file_path: str, db_knowledge): + """根据文件类型选择合适的视觉/音频模型,避免冗余初始化。""" + if AUDIO_PATTERN.search(file_path): + omni_key = os.getenv("QWEN3_OMNI_API_KEY", "") + omni_model = os.getenv("QWEN3_OMNI_MODEL_NAME", "qwen3-omni-flash") + omni_base = os.getenv("QWEN3_OMNI_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1") + return QWenSeq2txt( + key=omni_key, + model_name=omni_model, + lang=DEFAULT_PARSE_LANGUAGE, + base_url=omni_base, + ) + if VIDEO_IMAGE_PATTERN.search(file_path): + omni_key = os.getenv("QWEN3_OMNI_API_KEY", "") + omni_model = os.getenv("QWEN3_OMNI_MODEL_NAME", "qwen3-omni-flash") + omni_base = os.getenv("QWEN3_OMNI_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1") + return QWenCV( + key=omni_key, + model_name=omni_model, + lang=DEFAULT_PARSE_LANGUAGE, + base_url=omni_base, + ) + # 默认:使用知识库配置的 image2text 模型 + return QWenCV( + key=db_knowledge.image2text.api_keys[0].api_key, + model_name=db_knowledge.image2text.api_keys[0].model_name, + lang=DEFAULT_PARSE_LANGUAGE, + base_url=db_knowledge.image2text.api_keys[0].api_base, + ) + + @celery_app.task(name="app.core.rag.tasks.parse_document") def parse_document(file_path: str, document_id: uuid.UUID): """ Document parsing, vectorization, and storage """ - # Force re-importing Trio in child processes (to avoid inheriting the state of the parent process) - import importlib - import trio - importlib.reload(trio) - db = next(get_db()) # Manually call the generator db_document = None - db_knowledge = None - progress_msg = f"{datetime.now().strftime('%H:%M:%S')} Task has been received.\n" - try: + progress_lines: list[str] = [f"{datetime.now().strftime('%H:%M:%S')} Task has been received."] + + def _progress_msg() -> str: + return "\n".join(progress_lines) + "\n" + + with get_db_context() as db: + try: db_document = db.query(Document).filter(Document.id == document_id).first() + if db_document is None: + raise ValueError(f"Document {document_id} not found") db_knowledge = db.query(Knowledge).filter(Knowledge.id == db_document.kb_id).first() + if db_knowledge is None: + raise ValueError(f"Knowledge {db_document.kb_id} not found") + # 1. Document parsing & segmentation - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Start to parse.\n" + progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Start to parse.") start_time = time.time() db_document.progress = 0.0 - db_document.progress_msg = progress_msg + db_document.progress_msg = _progress_msg() db_document.process_begin_at = datetime.now(tz=timezone.utc) db_document.process_duration = 0.0 db_document.run = 1 @@ -189,220 +237,120 @@ def parse_document(file_path: str, document_id: uuid.UUID): db.refresh(db_document) def progress_callback(prog=None, msg=None): - nonlocal progress_msg # Declare the use of an external progress_msg variable - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} parse progress: {prog} msg: {msg}.\n" + progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} parse progress: {prog} msg: {msg}.") - # Prepare to configure chat_mdl、embedding_model、vision_model information - chat_model = Base( - key=db_knowledge.llm.api_keys[0].api_key, - model_name=db_knowledge.llm.api_keys[0].model_name, - base_url=db_knowledge.llm.api_keys[0].api_base - ) - embedding_model = OpenAIEmbed( - key=db_knowledge.embedding.api_keys[0].api_key, - model_name=db_knowledge.embedding.api_keys[0].model_name, - base_url=db_knowledge.embedding.api_keys[0].api_base - ) - vision_model = QWenCV( - key=db_knowledge.image2text.api_keys[0].api_key, - model_name=db_knowledge.image2text.api_keys[0].model_name, - lang="Chinese", - base_url=db_knowledge.image2text.api_keys[0].api_base - ) - if re.search(r"\.(da|wave|wav|mp3|aac|flac|ogg|aiff|au|midi|wma|realaudio|vqf|oggvorbis|ape?)$", file_path, - re.IGNORECASE): - vision_model = QWenSeq2txt( - key=os.getenv("QWEN3_OMNI_API_KEY", ""), - model_name=os.getenv("QWEN3_OMNI_MODEL_NAME", "qwen3-omni-flash"), - lang="Chinese", - base_url=os.getenv("QWEN3_OMNI_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1"), - ) - elif re.search(r"\.(png|jpeg|jpg|gif|bmp|svg|mp4|mov|avi|flv|mpeg|mpg|webm|wmv|3gp|3gpp|mkv?)$", file_path, - re.IGNORECASE): - vision_model = QWenCV( - key=os.getenv("QWEN3_OMNI_API_KEY", ""), - model_name=os.getenv("QWEN3_OMNI_MODEL_NAME", "qwen3-omni-flash"), - lang="Chinese", - base_url=os.getenv("QWEN3_OMNI_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1"), - ) - else: - print(file_path) + # Prepare vision_model for parsing + vision_model = _build_vision_model(file_path, db_knowledge) from app.core.rag.app.naive import chunk res = chunk(filename=file_path, from_page=0, - to_page=100000, + to_page=DEFAULT_PARSE_TO_PAGE, callback=progress_callback, vision_model=vision_model, parser_config=db_document.parser_config, is_root=False) - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Finish parsing.\n" + progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Finish parsing.") db_document.progress = 0.8 - db_document.progress_msg = progress_msg + db_document.progress_msg = _progress_msg() db.commit() db.refresh(db_document) # 2. Document vectorization and storage total_chunks = len(res) - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Generate {total_chunks} chunks.\n" - batch_size = 100 - total_batches = ceil(total_chunks / batch_size) - progress_per_batch = 0.2 / total_batches # Progress of each batch - vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) - # 2.1 Delete document vector index - vector_service.delete_by_metadata_field(key="document_id", value=str(document_id)) - # 2.2 Vectorize and import batch documents - for batch_start in range(0, total_chunks, batch_size): - batch_end = min(batch_start + batch_size, total_chunks) # prevent out-of-bounds - batch = res[batch_start: batch_end] # Retrieve the current batch - chunks = [] + progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Generate {total_chunks} chunks.") - # Process the current batch - for idx_in_batch, item in enumerate(batch): - global_idx = batch_start + idx_in_batch # Calculate global index - metadata = { - "doc_id": uuid.uuid4().hex, - "file_id": str(db_document.file_id), - "file_name": db_document.file_name, - "file_created_at": int(db_document.created_at.timestamp() * 1000), - "document_id": str(db_document.id), - "knowledge_id": str(db_document.kb_id), - "sort_id": global_idx, - "status": 1, - } - if db_document.parser_config.get("auto_questions", 0): - topn = db_document.parser_config["auto_questions"] - cached = get_llm_cache(chat_model.model_name, item["content_with_weight"], "question", - {"topn": topn}) - if not cached: - cached = question_proposal(chat_model, item["content_with_weight"], topn) - set_llm_cache(chat_model.model_name, item["content_with_weight"], cached, "question", - {"topn": topn}) - chunks.append( - DocumentChunk(page_content=f"question: {cached} answer: {item['content_with_weight']}", - metadata=metadata)) - else: - chunks.append(DocumentChunk(page_content=item["content_with_weight"], metadata=metadata)) + if total_chunks == 0: + progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} No chunks generated, skipping vectorization.") + else: + total_batches = ceil(total_chunks / EMBEDDING_BATCH_SIZE) + progress_per_batch = 0.2 / total_batches # Progress of each batch + vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) + # 2.1 Delete document vector index + vector_service.delete_by_metadata_field(key="document_id", value=str(document_id)) + # 2.2 Vectorize and import batch documents + auto_questions_topn = db_document.parser_config.get("auto_questions", 0) + chat_model = None + if auto_questions_topn: + chat_model = Base( + key=db_knowledge.llm.api_keys[0].api_key, + model_name=db_knowledge.llm.api_keys[0].model_name, + base_url=db_knowledge.llm.api_keys[0].api_base, + ) + for batch_start in range(0, total_chunks, EMBEDDING_BATCH_SIZE): + batch_end = min(batch_start + EMBEDDING_BATCH_SIZE, total_chunks) + batch = res[batch_start: batch_end] + chunks = [] - # Bulk segmented vector import - vector_service.add_chunks(chunks) + for idx_in_batch, item in enumerate(batch): + global_idx = batch_start + idx_in_batch + metadata = { + "doc_id": uuid.uuid4().hex, + "file_id": str(db_document.file_id), + "file_name": db_document.file_name, + "file_created_at": int(db_document.created_at.timestamp() * 1000), + "document_id": str(db_document.id), + "knowledge_id": str(db_document.kb_id), + "sort_id": global_idx, + "status": 1, + } + if auto_questions_topn: + cached = get_llm_cache(chat_model.model_name, item["content_with_weight"], "question", + {"topn": auto_questions_topn}) + if not cached: + cached = question_proposal(chat_model, item["content_with_weight"], auto_questions_topn) + set_llm_cache(chat_model.model_name, item["content_with_weight"], cached, "question", + {"topn": auto_questions_topn}) + chunks.append( + DocumentChunk(page_content=f"question: {cached} answer: {item['content_with_weight']}", + metadata=metadata)) + else: + chunks.append(DocumentChunk(page_content=item["content_with_weight"], metadata=metadata)) - # Update progress - db_document.progress += progress_per_batch - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Embedding progress ({db_document.progress}).\n" - db_document.progress_msg = progress_msg - db_document.process_duration = time.time() - start_time - db_document.run = 0 - db.commit() - db.refresh(db_document) + # Bulk segmented vector import + vector_service.add_chunks(chunks) + + # Update progress + db_document.progress += progress_per_batch + progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Embedding progress ({db_document.progress}).") + db_document.progress_msg = _progress_msg() + db_document.process_duration = time.time() - start_time + db_document.run = 0 + db.commit() + db.refresh(db_document) # Vectorization and data entry completed - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Indexing done.\n" + progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Indexing done.") db_document.chunk_num = total_chunks db_document.progress = 1.0 db_document.process_duration = time.time() - start_time - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Task done ({db_document.process_duration}s).\n" - db_document.progress_msg = progress_msg + progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Task done ({db_document.process_duration}s).") + db_document.progress_msg = _progress_msg() db_document.run = 0 db.commit() - # using graphrag + # GraphRAG: 异步派发到独立队列,不阻塞文档解析流程 if db_knowledge.parser_config and db_knowledge.parser_config.get("graphrag", {}).get("use_graphrag", False): - graphrag_conf = db_knowledge.parser_config.get("graphrag", {}) - with_resolution = graphrag_conf.get("resolution", False) - with_community = graphrag_conf.get("community", False) - - def callback(*args, msg=None, **kwargs): - nonlocal progress_msg - message = msg or (args[0] if args else "No message") - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} run graphrag msg: {message}.\n" - - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Start to run graphrag.\n" - start_time = time.time() - db_document.progress_msg = progress_msg + progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} GraphRAG enabled, dispatching async task.") + db_document.progress_msg = _progress_msg() db.commit() - db.refresh(db_document) - - task = { - "id": str(db_document.id), - "workspace_id": str(db_knowledge.workspace_id), - "kb_id": str(db_knowledge.id), - "parser_config": db_knowledge.parser_config, - } - - # init_graphrag - vts, _ = embedding_model.encode(["ok"]) - vector_size = len(vts[0]) - init_graphrag(task, vector_size) - - async def _run( - row: dict, - document_ids: list[str], - language: str, - parser_config: dict, - vector_service, - chat_model, - embedding_model, - callback, - with_resolution: bool = True, - with_community: bool = True - ) -> dict: - await trio.sleep(5) # Delay for 10 seconds - nonlocal progress_msg # Declare the use of an external progress_msg variable - result = await run_graphrag_for_kb( - row=row, - document_ids=document_ids, - language=language, - parser_config=parser_config, - vector_service=vector_service, - chat_model=chat_model, - embedding_model=embedding_model, - callback=callback, - with_resolution=with_resolution, - with_community=with_community, - ) - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task result for task {task}:\n{result}\n" - return result - - def sync_task(): - trio.run( - lambda: _run( - row=task, - document_ids=[str(db_document.id)], - language="Chinese", - parser_config=db_knowledge.parser_config, - vector_service=vector_service, - chat_model=chat_model, - embedding_model=embedding_model, - callback=callback, - with_resolution=with_resolution, - with_community=with_community, - ) - ) - - try: - with ThreadPoolExecutor(max_workers=1) as executor: - future = executor.submit(sync_task) - future.result() # Blocks until the task completes - except Exception as e: - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task failed for task {task}:\n{str(e)}\n" - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Knowledge Graph done ({time.time() - start_time}s)" - db_document.progress_msg = progress_msg - db.commit() - db.refresh(db_document) + build_graphrag_for_document.delay(str(document_id), str(db_knowledge.id)) result = f"parse document '{db_document.file_name}' processed successfully." + logger.info(f"[ParseDoc] document={document_id} file='{db_document.file_name}' done in {db_document.process_duration:.1f}s, chunks={total_chunks}") return result - except Exception as e: - if 'db_document' in locals(): - db_document.progress_msg += f"Failed to vectorize and import the parsed document:{str(e)}\n" - db_document.run = 0 - db.commit() - result = f"parse document '{db_document.file_name}' failed." - return result - finally: - db.close() + except Exception as e: + logger.error(f"[ParseDoc] document={document_id} failed: {e}", exc_info=True) + if db_document is not None: + try: + db_document.progress_msg = _progress_msg() + f"Failed to vectorize and import the parsed document:{str(e)}\n" + db_document.run = 0 + db.commit() + except Exception: + logger.warning(f"[ParseDoc] document={document_id} failed to update error status in DB", exc_info=True) + file_name = db_document.file_name if db_document else document_id + return f"parse document '{file_name}' failed." @celery_app.task(name="app.core.rag.tasks.build_graphrag_for_kb") @@ -410,51 +358,41 @@ def build_graphrag_for_kb(kb_id: uuid.UUID): """ build knowledge graph """ - # Force re-importing Trio in child processes (to avoid inheriting the state of the parent process) import importlib import trio importlib.reload(trio) - db = next(get_db()) # Manually call the generator - db_documents = None - db_knowledge = None - try: - db_documents = db.query(Document).filter(Document.kb_id == kb_id).all() - db_knowledge = db.query(Knowledge).filter(Knowledge.id == kb_id).first() - # 1. Prepare to configure chat_mdl、embedding_model、vision_model information - chat_model = Base( - key=db_knowledge.llm.api_keys[0].api_key, - model_name=db_knowledge.llm.api_keys[0].model_name, - base_url=db_knowledge.llm.api_keys[0].api_base - ) - embedding_model = OpenAIEmbed( - key=db_knowledge.embedding.api_keys[0].api_key, - model_name=db_knowledge.embedding.api_keys[0].model_name, - base_url=db_knowledge.embedding.api_keys[0].api_base - ) - vision_model = QWenCV( - key=db_knowledge.image2text.api_keys[0].api_key, - model_name=db_knowledge.image2text.api_keys[0].model_name, - lang="Chinese", - base_url=db_knowledge.image2text.api_keys[0].api_base - ) - # 2. get all document_ids from knowledge base - vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) - total, items = vector_service.search_by_segment(document_id=None, query=None, pagesize=9999, page=1, asc=True) - document_ids = [str(item.id) for item in db_documents] + with get_db_context() as db: + try: + db_knowledge = db.query(Knowledge).filter(Knowledge.id == kb_id).first() + if db_knowledge is None: + logger.error(f"[GraphRAG-KB] knowledge={kb_id} not found") + return f"build knowledge graph failed: knowledge not found" + + if not (db_knowledge.parser_config and + db_knowledge.parser_config.get("graphrag", {}).get("use_graphrag", False)): + return f"build knowledge graph '{db_knowledge.name}' skipped: graphrag not enabled" + + db_documents = db.query(Document).filter(Document.kb_id == kb_id).all() + document_ids = [str(doc.id) for doc in db_documents] + + chat_model = Base( + key=db_knowledge.llm.api_keys[0].api_key, + model_name=db_knowledge.llm.api_keys[0].model_name, + base_url=db_knowledge.llm.api_keys[0].api_base, + ) + embedding_model = OpenAIEmbed( + key=db_knowledge.embedding.api_keys[0].api_key, + model_name=db_knowledge.embedding.api_keys[0].model_name, + base_url=db_knowledge.embedding.api_keys[0].api_base, + ) + vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) - # 2. using graphrag - if db_knowledge.parser_config and db_knowledge.parser_config.get("graphrag", {}).get("use_graphrag", False): graphrag_conf = db_knowledge.parser_config.get("graphrag", {}) with_resolution = graphrag_conf.get("resolution", False) with_community = graphrag_conf.get("community", False) - def callback(*args, msg=None, **kwargs): - message = msg or (args[0] if args else "No message") - print(f"{datetime.now().strftime('%H:%M:%S')} run graphrag msg: {message}.\n") - - start_time = time.time() task = { "id": str(db_knowledge.id), "workspace_id": str(db_knowledge.workspace_id), @@ -467,14 +405,18 @@ def build_graphrag_for_kb(kb_id: uuid.UUID): vector_size = len(vts[0]) init_graphrag(task, vector_size) - async def _run(row: dict, document_ids: list[str], language: str, parser_config: dict, vector_service, - chat_model, embedding_model, callback, with_resolution: bool = True, - with_community: bool = True, ) -> dict: - result = await run_graphrag_for_kb( - row=row, + def callback(*args, msg=None, **kwargs): + message = msg or (args[0] if args else "No message") + logger.info(f"[GraphRAG-KB] kb={kb_id} msg: {message}") + + start_time = time.time() + + async def _run() -> dict: + return await run_graphrag_for_kb( + row=task, document_ids=document_ids, - language=language, - parser_config=parser_config, + language=DEFAULT_PARSE_LANGUAGE, + parser_config=db_knowledge.parser_config, vector_service=vector_service, chat_model=chat_model, embedding_model=embedding_model, @@ -482,46 +424,97 @@ def build_graphrag_for_kb(kb_id: uuid.UUID): with_resolution=with_resolution, with_community=with_community, ) - print(f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task result for task {task}:\n{result}\n") - return result - def sync_task(): - trio.run( - lambda: _run( - row=task, - document_ids=document_ids, - language="Chinese", - parser_config=db_knowledge.parser_config, - vector_service=vector_service, - chat_model=chat_model, - embedding_model=embedding_model, - callback=callback, - with_resolution=with_resolution, - with_community=with_community, - ) + result = trio.run(_run) + duration = time.time() - start_time + logger.info(f"[GraphRAG-KB] kb={kb_id} done in {duration:.1f}s, result: {result}") + + return f"build knowledge graph '{db_knowledge.name}' processed successfully." + except Exception as e: + logger.error(f"[GraphRAG-KB] kb={kb_id} failed: {e}", exc_info=True) + return f"build knowledge graph failed: {e}" + + +@celery_app.task(name="app.core.rag.tasks.build_graphrag_for_document") +def build_graphrag_for_document(document_id: str, knowledge_id: str): + """ + 为单个文档构建 GraphRAG,由 parse_document 异步派发。 + """ + import importlib + + import trio + importlib.reload(trio) + + with get_db_context() as db: + try: + db_document = db.query(Document).filter(Document.id == uuid.UUID(document_id)).first() + db_knowledge = db.query(Knowledge).filter(Knowledge.id == uuid.UUID(knowledge_id)).first() + if db_document is None or db_knowledge is None: + logger.error(f"[GraphRAG] document={document_id} or knowledge={knowledge_id} not found") + return f"build_graphrag_for_document failed: record not found" + + graphrag_conf = db_knowledge.parser_config.get("graphrag", {}) + with_resolution = graphrag_conf.get("resolution", False) + with_community = graphrag_conf.get("community", False) + + chat_model = Base( + key=db_knowledge.llm.api_keys[0].api_key, + model_name=db_knowledge.llm.api_keys[0].model_name, + base_url=db_knowledge.llm.api_keys[0].api_base, + ) + embedding_model = OpenAIEmbed( + key=db_knowledge.embedding.api_keys[0].api_key, + model_name=db_knowledge.embedding.api_keys[0].model_name, + base_url=db_knowledge.embedding.api_keys[0].api_base, + ) + vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) + + task = { + "id": document_id, + "workspace_id": str(db_knowledge.workspace_id), + "kb_id": str(db_knowledge.id), + "parser_config": db_knowledge.parser_config, + } + + # init_graphrag + vts, _ = embedding_model.encode(["ok"]) + vector_size = len(vts[0]) + init_graphrag(task, vector_size) + + def callback(*args, msg=None, **kwargs): + message = msg or (args[0] if args else "No message") + logger.info(f"[GraphRAG] doc={document_id} msg: {message}") + + start_time = time.time() + + async def _run() -> dict: + await trio.sleep(5) + return await run_graphrag_for_kb( + row=task, + document_ids=[document_id], + language=DEFAULT_PARSE_LANGUAGE, + parser_config=db_knowledge.parser_config, + vector_service=vector_service, + chat_model=chat_model, + embedding_model=embedding_model, + callback=callback, + with_resolution=with_resolution, + with_community=with_community, ) - try: - with ThreadPoolExecutor(max_workers=1) as executor: - future = executor.submit(sync_task) - future.result() # Blocks until the task completes - except Exception as e: - print(f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task failed for task {task}:\n{str(e)}\n") - finally: - if db: - db.close() - print(f"{datetime.now().strftime('%H:%M:%S')} Knowledge Graph done ({time.time() - start_time}s)") + result = trio.run(_run) + duration = time.time() - start_time + logger.info(f"[GraphRAG] doc={document_id} done in {duration:.1f}s") - result = f"build knowledge graph '{db_knowledge.name}' processed successfully." - return result - except Exception as e: - if 'db_knowledge' in locals(): - print(f"Failed to build knowledge grap:{str(e)}\n") - result = f"build knowledge grap '{db_knowledge.name}' failed." - return result - finally: - if db: - db.close() + # 更新文档进度信息 + db_document.progress_msg = (db_document.progress_msg or "") + \ + f"{datetime.now().strftime('%H:%M:%S')} Knowledge Graph done ({duration:.1f}s)\n" + db.commit() + + return f"build_graphrag_for_document '{document_id}' processed successfully." + except Exception as e: + logger.error(f"[GraphRAG] doc={document_id} failed: {e}", exc_info=True) + return f"build_graphrag_for_document '{document_id}' failed: {e}" @celery_app.task(name="app.core.rag.tasks.sync_knowledge_for_kb") @@ -529,10 +522,13 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID): """ sync knowledge document and Document parsing, vectorization, and storage """ - db = next(get_db()) # Manually call the generator - db_knowledge = None - try: + with get_db_context() as db: + try: db_knowledge = db.query(Knowledge).filter(Knowledge.id == kb_id).first() + if db_knowledge is None: + logger.error(f"[SyncKB] knowledge={kb_id} not found") + return f"sync knowledge failed: knowledge not found" + # 1. get vector_service vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) @@ -667,7 +663,7 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID): db.commit() except Exception as e: - print(f"\n\nError during crawl: {e}") + logger.error(f"[SyncKB] Error during crawl: {e}", exc_info=True) case "Third-party": # Integration of knowledge bases from three parties yuque_user_id = db_knowledge.parser_config.get("yuque_user_id", "") feishu_app_id = db_knowledge.parser_config.get("feishu_app_id", "") @@ -685,13 +681,9 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID): # Get all files from all repos async def async_get_files(api_client: YuqueAPIClient): async with api_client as client: - print("\n=== Fetching repositories ===") repos = await client.get_user_repos() - print(f"Found {len(repos)} repositories:") all_files = [] for repo in repos: - # Get documents from repository - print(f"\n=== Fetching documents from '{repo.name}' ===") docs = await client.get_repo_docs(repo.id) all_files.extend(docs) return all_files @@ -837,7 +829,7 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID): db.commit() except Exception as e: - print(f"\n\nError during fetch feishu: {e}") + logger.error(f"[SyncKB] Error during fetch yuque: {e}", exc_info=True) if feishu_app_id: # Feishu Knowledge Base feishu_app_secret = db_knowledge.parser_config.get("feishu_app_secret", "") feishu_folder_token = db_knowledge.parser_config.get("feishu_folder_token", "") @@ -999,19 +991,16 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID): db.commit() except Exception as e: - print(f"\n\nError during fetch feishu: {e}") + logger.error(f"[SyncKB] Error during fetch feishu: {e}", exc_info=True) case _: # General - print(f"General: No synchronization needed\n") + logger.info(f"[SyncKB] kb={kb_id} type={db_knowledge.type}: no synchronization needed") result = f"sync knowledge '{db_knowledge.name}' processed successfully." return result - except Exception as e: - if 'db_knowledge' in locals(): - print(f"Failed to sync knowledge:{str(e)}\n") - result = f"sync knowledge '{db_knowledge.name}' failed." - return result - finally: - db.close() + except Exception as e: + logger.error(f"[SyncKB] kb={kb_id} failed: {e}", exc_info=True) + kb_name = db_knowledge.name if db_knowledge else kb_id + return f"sync knowledge '{kb_name}' failed: {e}" @celery_app.task(name="app.core.memory.agent.read_message", bind=True) From 0f50537d7d25d2569896f36afeadd5d7170709bf Mon Sep 17 00:00:00 2001 From: Mark Date: Thu, 9 Apr 2026 14:11:01 +0800 Subject: [PATCH 2/3] [modify] mineru --- .../core/rag/deepdoc/parser/mineru_parser.py | 7 +- api/app/tasks.py | 161 +++++++++++++----- 2 files changed, 126 insertions(+), 42 deletions(-) diff --git a/api/app/core/rag/deepdoc/parser/mineru_parser.py b/api/app/core/rag/deepdoc/parser/mineru_parser.py index fe6178ec..c2f7af16 100644 --- a/api/app/core/rag/deepdoc/parser/mineru_parser.py +++ b/api/app/core/rag/deepdoc/parser/mineru_parser.py @@ -292,9 +292,10 @@ class MinerUParser(RAGPdfParser): self.page_from = page_from self.page_to = page_to try: - with pdfplumber.open(fnm) if isinstance(fnm, (str, PathLike)) else pdfplumber.open(BytesIO(fnm)) as pdf: - self.pdf = pdf - self.page_images = [p.to_image(resolution=72 * zoomin, antialias=True).original for _, p in enumerate(self.pdf.pages[page_from:page_to])] + with sys.modules[LOCK_KEY_pdfplumber]: # ← 加这一行,获取全局锁 + with pdfplumber.open(fnm) if isinstance(fnm, (str, PathLike)) else pdfplumber.open(BytesIO(fnm)) as pdf: + self.pdf = pdf + self.page_images = [p.to_image(resolution=72 * zoomin, antialias=True).original for _, p in enumerate(self.pdf.pages[page_from:page_to])] except Exception as e: self.page_images = None self.total_page = 0 diff --git a/api/app/tasks.py b/api/app/tasks.py index 60e9f4e4..c32b8ad9 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -56,6 +56,10 @@ VIDEO_IMAGE_PATTERN = re.compile( DEFAULT_PARSE_LANGUAGE = "Chinese" DEFAULT_PARSE_TO_PAGE = 100_000 EMBEDDING_BATCH_SIZE = 100 +# Embedding 并发写入的最大线程数,需根据模型 API rate limit 调整 +EMBEDDING_MAX_WORKERS = int(os.getenv("EMBEDDING_MAX_WORKERS", "3")) +# auto_questions LLM 并发调用的最大线程数 +AUTO_QUESTIONS_MAX_WORKERS = int(os.getenv("AUTO_QUESTIONS_MAX_WORKERS", "5")) # 模块级同步 Redis 连接池,供 Celery 任务共享使用 # 连接 CELERY_BACKEND DB,与 write_message:last_done 时间戳写入保持一致 @@ -218,6 +222,10 @@ def parse_document(file_path: str, document_id: uuid.UUID): with get_db_context() as db: try: + # Celery JSON 序列化会将 UUID 转为字符串,需要确保类型正确 + if not isinstance(document_id, uuid.UUID): + document_id = uuid.UUID(str(document_id)) + db_document = db.query(Document).filter(Document.id == document_id).first() if db_document is None: raise ValueError(f"Document {document_id} not found") @@ -265,7 +273,7 @@ def parse_document(file_path: str, document_id: uuid.UUID): progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} No chunks generated, skipping vectorization.") else: total_batches = ceil(total_chunks / EMBEDDING_BATCH_SIZE) - progress_per_batch = 0.2 / total_batches # Progress of each batch + progress_per_batch = 0.2 / total_batches vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) # 2.1 Delete document vector index vector_service.delete_by_metadata_field(key="document_id", value=str(document_id)) @@ -278,47 +286,114 @@ def parse_document(file_path: str, document_id: uuid.UUID): model_name=db_knowledge.llm.api_keys[0].model_name, base_url=db_knowledge.llm.api_keys[0].api_base, ) - for batch_start in range(0, total_chunks, EMBEDDING_BATCH_SIZE): - batch_end = min(batch_start + EMBEDDING_BATCH_SIZE, total_chunks) - batch = res[batch_start: batch_end] - chunks = [] - for idx_in_batch, item in enumerate(batch): - global_idx = batch_start + idx_in_batch - metadata = { - "doc_id": uuid.uuid4().hex, - "file_id": str(db_document.file_id), - "file_name": db_document.file_name, - "file_created_at": int(db_document.created_at.timestamp() * 1000), - "document_id": str(db_document.id), - "knowledge_id": str(db_document.kb_id), - "sort_id": global_idx, - "status": 1, - } - if auto_questions_topn: - cached = get_llm_cache(chat_model.model_name, item["content_with_weight"], "question", - {"topn": auto_questions_topn}) - if not cached: - cached = question_proposal(chat_model, item["content_with_weight"], auto_questions_topn) - set_llm_cache(chat_model.model_name, item["content_with_weight"], cached, "question", - {"topn": auto_questions_topn}) + # 预先构建所有 batch 的 chunks,保证 sort_id 全局有序 + all_batch_chunks: list[list[DocumentChunk]] = [] + + if auto_questions_topn: + # auto_questions 开启:先并发生成所有 chunk 的问题,再按 batch 分组 + # 构建 (global_idx, item) 列表 + indexed_items = list(enumerate(res)) + + def _generate_question(idx_item: tuple[int, dict]) -> tuple[int, str]: + """为单个 chunk 生成问题(带缓存),返回 (global_idx, question_text)""" + global_idx, item = idx_item + content = item["content_with_weight"] + cached = get_llm_cache(chat_model.model_name, content, "question", + {"topn": auto_questions_topn}) + if not cached: + cached = question_proposal(chat_model, content, auto_questions_topn) + set_llm_cache(chat_model.model_name, content, cached, "question", + {"topn": auto_questions_topn}) + return global_idx, cached + + # 并发调用 LLM 生成问题 + question_map: dict[int, str] = {} + with ThreadPoolExecutor(max_workers=AUTO_QUESTIONS_MAX_WORKERS) as q_executor: + futures = {q_executor.submit(_generate_question, item): item[0] + for item in indexed_items} + for future in futures: + global_idx, cached = future.result() + question_map[global_idx] = cached + + progress_lines.append( + f"{datetime.now().strftime('%H:%M:%S')} Auto questions generated for {total_chunks} chunks " + f"(workers={AUTO_QUESTIONS_MAX_WORKERS}).") + + # 按 batch 分组组装 DocumentChunk + for batch_start in range(0, total_chunks, EMBEDDING_BATCH_SIZE): + batch_end = min(batch_start + EMBEDDING_BATCH_SIZE, total_chunks) + chunks = [] + for global_idx in range(batch_start, batch_end): + item = res[global_idx] + metadata = { + "doc_id": uuid.uuid4().hex, + "file_id": str(db_document.file_id), + "file_name": db_document.file_name, + "file_created_at": int(db_document.created_at.timestamp() * 1000), + "document_id": str(db_document.id), + "knowledge_id": str(db_document.kb_id), + "sort_id": global_idx, + "status": 1, + } + cached = question_map[global_idx] chunks.append( - DocumentChunk(page_content=f"question: {cached} answer: {item['content_with_weight']}", - metadata=metadata)) - else: + DocumentChunk( + page_content=f"question: {cached} answer: {item['content_with_weight']}", + metadata=metadata)) + all_batch_chunks.append(chunks) + else: + # 无 auto_questions:直接构建 chunks + for batch_start in range(0, total_chunks, EMBEDDING_BATCH_SIZE): + batch_end = min(batch_start + EMBEDDING_BATCH_SIZE, total_chunks) + chunks = [] + for global_idx in range(batch_start, batch_end): + item = res[global_idx] + metadata = { + "doc_id": uuid.uuid4().hex, + "file_id": str(db_document.file_id), + "file_name": db_document.file_name, + "file_created_at": int(db_document.created_at.timestamp() * 1000), + "document_id": str(db_document.id), + "knowledge_id": str(db_document.kb_id), + "sort_id": global_idx, + "status": 1, + } chunks.append(DocumentChunk(page_content=item["content_with_weight"], metadata=metadata)) + all_batch_chunks.append(chunks) - # Bulk segmented vector import - vector_service.add_chunks(chunks) + # 并发提交 embedding + ES 写入,max_workers 控制模型 API 并发压力 + batch_errors: dict[int, Exception] = {} - # Update progress - db_document.progress += progress_per_batch - progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Embedding progress ({db_document.progress}).") - db_document.progress_msg = _progress_msg() - db_document.process_duration = time.time() - start_time - db_document.run = 0 - db.commit() - db.refresh(db_document) + def _embed_and_store(batch_idx: int, batch_chunks: list[DocumentChunk]): + try: + vector_service.add_chunks(batch_chunks) + except Exception as exc: + batch_errors[batch_idx] = exc + + with ThreadPoolExecutor(max_workers=EMBEDDING_MAX_WORKERS) as executor: + futures = { + executor.submit(_embed_and_store, i, batch_chunks): i + for i, batch_chunks in enumerate(all_batch_chunks) + } + # 按提交顺序收集结果,逐批更新进度 + for future in futures: + future.result() # 等待完成(异常已在 _embed_and_store 内捕获) + + # 如果有 batch 失败,汇总抛出 + if batch_errors: + failed = ", ".join(str(i) for i in sorted(batch_errors)) + first_err = next(iter(batch_errors.values())) + raise RuntimeError(f"Embedding failed for batch(es) [{failed}]: {first_err}") from first_err + + # 所有 batch 完成后一次性更新进度 + db_document.progress = 0.8 + 0.2 # 直接到 1.0 前的状态 + progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} All {total_batches} batches embedded (workers={EMBEDDING_MAX_WORKERS}).") + db_document.progress_msg = _progress_msg() + db_document.process_duration = time.time() - start_time + db_document.run = 0 + db.commit() + db.refresh(db_document) # Vectorization and data entry completed progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Indexing done.") @@ -344,13 +419,15 @@ def parse_document(file_path: str, document_id: uuid.UUID): logger.error(f"[ParseDoc] document={document_id} failed: {e}", exc_info=True) if db_document is not None: try: + db.rollback() db_document.progress_msg = _progress_msg() + f"Failed to vectorize and import the parsed document:{str(e)}\n" db_document.run = 0 db.commit() except Exception: logger.warning(f"[ParseDoc] document={document_id} failed to update error status in DB", exc_info=True) - file_name = db_document.file_name if db_document else document_id - return f"parse document '{file_name}' failed." + # db_document 可能处于 detached/expired 状态,用之前缓存的值或 document_id 兜底 + file_name = getattr(db_document, 'file_name', None) if db_document else None + return f"parse document '{file_name or document_id}' failed." @celery_app.task(name="app.core.rag.tasks.build_graphrag_for_kb") @@ -365,6 +442,9 @@ def build_graphrag_for_kb(kb_id: uuid.UUID): with get_db_context() as db: try: + if not isinstance(kb_id, uuid.UUID): + kb_id = uuid.UUID(str(kb_id)) + db_knowledge = db.query(Knowledge).filter(Knowledge.id == kb_id).first() if db_knowledge is None: logger.error(f"[GraphRAG-KB] knowledge={kb_id} not found") @@ -524,6 +604,9 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID): """ with get_db_context() as db: try: + if not isinstance(kb_id, uuid.UUID): + kb_id = uuid.UUID(str(kb_id)) + db_knowledge = db.query(Knowledge).filter(Knowledge.id == kb_id).first() if db_knowledge is None: logger.error(f"[SyncKB] knowledge={kb_id} not found") From a96f20ee0542baadd5bd3a859b99eac89ac5ac08 Mon Sep 17 00:00:00 2001 From: Mark <348207283@qq.com> Date: Mon, 13 Apr 2026 10:40:58 +0800 Subject: [PATCH 3/3] [modify] parse document workflow, add graph queue hand build graph --- api/app/tasks.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/api/app/tasks.py b/api/app/tasks.py index c32b8ad9..98b413c3 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -55,7 +55,7 @@ VIDEO_IMAGE_PATTERN = re.compile( ) DEFAULT_PARSE_LANGUAGE = "Chinese" DEFAULT_PARSE_TO_PAGE = 100_000 -EMBEDDING_BATCH_SIZE = 100 +EMBEDDING_BATCH_SIZE = int(os.getenv("EMBEDDING_BATCH_SIZE", "20")) # Embedding 并发写入的最大线程数,需根据模型 API rate limit 调整 EMBEDDING_MAX_WORKERS = int(os.getenv("EMBEDDING_MAX_WORKERS", "3")) # auto_questions LLM 并发调用的最大线程数 @@ -369,22 +369,28 @@ def parse_document(file_path: str, document_id: uuid.UUID): try: vector_service.add_chunks(batch_chunks) except Exception as exc: - batch_errors[batch_idx] = exc + logger.warning(f"[ParseDoc] batch {batch_idx} failed, retrying: {exc}") + try: + vector_service.add_chunks(batch_chunks) + except Exception as retry_exc: + logger.error(f"[ParseDoc] batch {batch_idx} retry failed: {retry_exc}", exc_info=True) + batch_errors[batch_idx] = retry_exc with ThreadPoolExecutor(max_workers=EMBEDDING_MAX_WORKERS) as executor: futures = { executor.submit(_embed_and_store, i, batch_chunks): i for i, batch_chunks in enumerate(all_batch_chunks) } - # 按提交顺序收集结果,逐批更新进度 for future in futures: - future.result() # 等待完成(异常已在 _embed_and_store 内捕获) + future.result() # 如果有 batch 失败,汇总抛出 if batch_errors: - failed = ", ".join(str(i) for i in sorted(batch_errors)) - first_err = next(iter(batch_errors.values())) - raise RuntimeError(f"Embedding failed for batch(es) [{failed}]: {first_err}") from first_err + failed_detail = "; ".join( + f"batch {i}: {type(err).__name__}: {err}" + for i, err in sorted(batch_errors.items()) + ) + raise RuntimeError(f"Embedding failed for {len(batch_errors)}/{total_batches} batch(es). {failed_detail}") # 所有 batch 完成后一次性更新进度 db_document.progress = 0.8 + 0.2 # 直接到 1.0 前的状态