diff --git a/api/app/celery_app.py b/api/app/celery_app.py index 23fd82ed..ba74fb19 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -113,9 +113,12 @@ celery_app.conf.update( # Document tasks → document_tasks queue (prefork worker) 'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'}, - 'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'}, 'app.core.rag.tasks.sync_knowledge_for_kb': {'queue': 'document_tasks'}, + # GraphRAG tasks → graphrag_tasks queue (独立队列,避免阻塞文档解析) + 'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'graphrag_tasks'}, + 'app.core.rag.tasks.build_graphrag_for_document': {'queue': 'graphrag_tasks'}, + # Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker) 'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'}, 'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'}, diff --git a/api/app/tasks.py b/api/app/tasks.py index f918743c..60e9f4e4 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -44,6 +44,19 @@ from app.utils.redis_lock import RedisFairLock logger = get_logger(__name__) +# ── 预编译文件类型正则 & 常量 ────────────────────────────────── +AUDIO_PATTERN = re.compile( + r"\.(da|wave|wav|mp3|aac|flac|ogg|aiff|au|midi|wma|realaudio|vqf|oggvorbis|ape?)$", + re.IGNORECASE, +) +VIDEO_IMAGE_PATTERN = re.compile( + r"\.(png|jpeg|jpg|gif|bmp|svg|mp4|mov|avi|flv|mpeg|mpg|webm|wmv|3gp|3gpp|mkv?)$", + re.IGNORECASE, +) +DEFAULT_PARSE_LANGUAGE = "Chinese" +DEFAULT_PARSE_TO_PAGE = 100_000 +EMBEDDING_BATCH_SIZE = 100 + # 模块级同步 Redis 连接池,供 Celery 任务共享使用 # 连接 CELERY_BACKEND DB,与 write_message:last_done 时间戳写入保持一致 # 使用连接池而非单例客户端,提供更好的并发性能和自动重连 @@ -160,28 +173,63 @@ def process_item(item: dict): return result +def _build_vision_model(file_path: str, db_knowledge): + """根据文件类型选择合适的视觉/音频模型,避免冗余初始化。""" + if AUDIO_PATTERN.search(file_path): + omni_key = os.getenv("QWEN3_OMNI_API_KEY", "") + omni_model = os.getenv("QWEN3_OMNI_MODEL_NAME", "qwen3-omni-flash") + omni_base = os.getenv("QWEN3_OMNI_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1") + return QWenSeq2txt( + key=omni_key, + model_name=omni_model, + lang=DEFAULT_PARSE_LANGUAGE, + base_url=omni_base, + ) + if VIDEO_IMAGE_PATTERN.search(file_path): + omni_key = os.getenv("QWEN3_OMNI_API_KEY", "") + omni_model = os.getenv("QWEN3_OMNI_MODEL_NAME", "qwen3-omni-flash") + omni_base = os.getenv("QWEN3_OMNI_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1") + return QWenCV( + key=omni_key, + model_name=omni_model, + lang=DEFAULT_PARSE_LANGUAGE, + base_url=omni_base, + ) + # 默认:使用知识库配置的 image2text 模型 + return QWenCV( + key=db_knowledge.image2text.api_keys[0].api_key, + model_name=db_knowledge.image2text.api_keys[0].model_name, + lang=DEFAULT_PARSE_LANGUAGE, + base_url=db_knowledge.image2text.api_keys[0].api_base, + ) + + @celery_app.task(name="app.core.rag.tasks.parse_document") def parse_document(file_path: str, document_id: uuid.UUID): """ Document parsing, vectorization, and storage """ - # Force re-importing Trio in child processes (to avoid inheriting the state of the parent process) - import importlib - import trio - importlib.reload(trio) - db = next(get_db()) # Manually call the generator db_document = None - db_knowledge = None - progress_msg = f"{datetime.now().strftime('%H:%M:%S')} Task has been received.\n" - try: + progress_lines: list[str] = [f"{datetime.now().strftime('%H:%M:%S')} Task has been received."] + + def _progress_msg() -> str: + return "\n".join(progress_lines) + "\n" + + with get_db_context() as db: + try: db_document = db.query(Document).filter(Document.id == document_id).first() + if db_document is None: + raise ValueError(f"Document {document_id} not found") db_knowledge = db.query(Knowledge).filter(Knowledge.id == db_document.kb_id).first() + if db_knowledge is None: + raise ValueError(f"Knowledge {db_document.kb_id} not found") + # 1. Document parsing & segmentation - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Start to parse.\n" + progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Start to parse.") start_time = time.time() db_document.progress = 0.0 - db_document.progress_msg = progress_msg + db_document.progress_msg = _progress_msg() db_document.process_begin_at = datetime.now(tz=timezone.utc) db_document.process_duration = 0.0 db_document.run = 1 @@ -189,220 +237,120 @@ def parse_document(file_path: str, document_id: uuid.UUID): db.refresh(db_document) def progress_callback(prog=None, msg=None): - nonlocal progress_msg # Declare the use of an external progress_msg variable - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} parse progress: {prog} msg: {msg}.\n" + progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} parse progress: {prog} msg: {msg}.") - # Prepare to configure chat_mdl、embedding_model、vision_model information - chat_model = Base( - key=db_knowledge.llm.api_keys[0].api_key, - model_name=db_knowledge.llm.api_keys[0].model_name, - base_url=db_knowledge.llm.api_keys[0].api_base - ) - embedding_model = OpenAIEmbed( - key=db_knowledge.embedding.api_keys[0].api_key, - model_name=db_knowledge.embedding.api_keys[0].model_name, - base_url=db_knowledge.embedding.api_keys[0].api_base - ) - vision_model = QWenCV( - key=db_knowledge.image2text.api_keys[0].api_key, - model_name=db_knowledge.image2text.api_keys[0].model_name, - lang="Chinese", - base_url=db_knowledge.image2text.api_keys[0].api_base - ) - if re.search(r"\.(da|wave|wav|mp3|aac|flac|ogg|aiff|au|midi|wma|realaudio|vqf|oggvorbis|ape?)$", file_path, - re.IGNORECASE): - vision_model = QWenSeq2txt( - key=os.getenv("QWEN3_OMNI_API_KEY", ""), - model_name=os.getenv("QWEN3_OMNI_MODEL_NAME", "qwen3-omni-flash"), - lang="Chinese", - base_url=os.getenv("QWEN3_OMNI_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1"), - ) - elif re.search(r"\.(png|jpeg|jpg|gif|bmp|svg|mp4|mov|avi|flv|mpeg|mpg|webm|wmv|3gp|3gpp|mkv?)$", file_path, - re.IGNORECASE): - vision_model = QWenCV( - key=os.getenv("QWEN3_OMNI_API_KEY", ""), - model_name=os.getenv("QWEN3_OMNI_MODEL_NAME", "qwen3-omni-flash"), - lang="Chinese", - base_url=os.getenv("QWEN3_OMNI_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1"), - ) - else: - print(file_path) + # Prepare vision_model for parsing + vision_model = _build_vision_model(file_path, db_knowledge) from app.core.rag.app.naive import chunk res = chunk(filename=file_path, from_page=0, - to_page=100000, + to_page=DEFAULT_PARSE_TO_PAGE, callback=progress_callback, vision_model=vision_model, parser_config=db_document.parser_config, is_root=False) - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Finish parsing.\n" + progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Finish parsing.") db_document.progress = 0.8 - db_document.progress_msg = progress_msg + db_document.progress_msg = _progress_msg() db.commit() db.refresh(db_document) # 2. Document vectorization and storage total_chunks = len(res) - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Generate {total_chunks} chunks.\n" - batch_size = 100 - total_batches = ceil(total_chunks / batch_size) - progress_per_batch = 0.2 / total_batches # Progress of each batch - vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) - # 2.1 Delete document vector index - vector_service.delete_by_metadata_field(key="document_id", value=str(document_id)) - # 2.2 Vectorize and import batch documents - for batch_start in range(0, total_chunks, batch_size): - batch_end = min(batch_start + batch_size, total_chunks) # prevent out-of-bounds - batch = res[batch_start: batch_end] # Retrieve the current batch - chunks = [] + progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Generate {total_chunks} chunks.") - # Process the current batch - for idx_in_batch, item in enumerate(batch): - global_idx = batch_start + idx_in_batch # Calculate global index - metadata = { - "doc_id": uuid.uuid4().hex, - "file_id": str(db_document.file_id), - "file_name": db_document.file_name, - "file_created_at": int(db_document.created_at.timestamp() * 1000), - "document_id": str(db_document.id), - "knowledge_id": str(db_document.kb_id), - "sort_id": global_idx, - "status": 1, - } - if db_document.parser_config.get("auto_questions", 0): - topn = db_document.parser_config["auto_questions"] - cached = get_llm_cache(chat_model.model_name, item["content_with_weight"], "question", - {"topn": topn}) - if not cached: - cached = question_proposal(chat_model, item["content_with_weight"], topn) - set_llm_cache(chat_model.model_name, item["content_with_weight"], cached, "question", - {"topn": topn}) - chunks.append( - DocumentChunk(page_content=f"question: {cached} answer: {item['content_with_weight']}", - metadata=metadata)) - else: - chunks.append(DocumentChunk(page_content=item["content_with_weight"], metadata=metadata)) + if total_chunks == 0: + progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} No chunks generated, skipping vectorization.") + else: + total_batches = ceil(total_chunks / EMBEDDING_BATCH_SIZE) + progress_per_batch = 0.2 / total_batches # Progress of each batch + vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) + # 2.1 Delete document vector index + vector_service.delete_by_metadata_field(key="document_id", value=str(document_id)) + # 2.2 Vectorize and import batch documents + auto_questions_topn = db_document.parser_config.get("auto_questions", 0) + chat_model = None + if auto_questions_topn: + chat_model = Base( + key=db_knowledge.llm.api_keys[0].api_key, + model_name=db_knowledge.llm.api_keys[0].model_name, + base_url=db_knowledge.llm.api_keys[0].api_base, + ) + for batch_start in range(0, total_chunks, EMBEDDING_BATCH_SIZE): + batch_end = min(batch_start + EMBEDDING_BATCH_SIZE, total_chunks) + batch = res[batch_start: batch_end] + chunks = [] - # Bulk segmented vector import - vector_service.add_chunks(chunks) + for idx_in_batch, item in enumerate(batch): + global_idx = batch_start + idx_in_batch + metadata = { + "doc_id": uuid.uuid4().hex, + "file_id": str(db_document.file_id), + "file_name": db_document.file_name, + "file_created_at": int(db_document.created_at.timestamp() * 1000), + "document_id": str(db_document.id), + "knowledge_id": str(db_document.kb_id), + "sort_id": global_idx, + "status": 1, + } + if auto_questions_topn: + cached = get_llm_cache(chat_model.model_name, item["content_with_weight"], "question", + {"topn": auto_questions_topn}) + if not cached: + cached = question_proposal(chat_model, item["content_with_weight"], auto_questions_topn) + set_llm_cache(chat_model.model_name, item["content_with_weight"], cached, "question", + {"topn": auto_questions_topn}) + chunks.append( + DocumentChunk(page_content=f"question: {cached} answer: {item['content_with_weight']}", + metadata=metadata)) + else: + chunks.append(DocumentChunk(page_content=item["content_with_weight"], metadata=metadata)) - # Update progress - db_document.progress += progress_per_batch - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Embedding progress ({db_document.progress}).\n" - db_document.progress_msg = progress_msg - db_document.process_duration = time.time() - start_time - db_document.run = 0 - db.commit() - db.refresh(db_document) + # Bulk segmented vector import + vector_service.add_chunks(chunks) + + # Update progress + db_document.progress += progress_per_batch + progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Embedding progress ({db_document.progress}).") + db_document.progress_msg = _progress_msg() + db_document.process_duration = time.time() - start_time + db_document.run = 0 + db.commit() + db.refresh(db_document) # Vectorization and data entry completed - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Indexing done.\n" + progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Indexing done.") db_document.chunk_num = total_chunks db_document.progress = 1.0 db_document.process_duration = time.time() - start_time - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Task done ({db_document.process_duration}s).\n" - db_document.progress_msg = progress_msg + progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Task done ({db_document.process_duration}s).") + db_document.progress_msg = _progress_msg() db_document.run = 0 db.commit() - # using graphrag + # GraphRAG: 异步派发到独立队列,不阻塞文档解析流程 if db_knowledge.parser_config and db_knowledge.parser_config.get("graphrag", {}).get("use_graphrag", False): - graphrag_conf = db_knowledge.parser_config.get("graphrag", {}) - with_resolution = graphrag_conf.get("resolution", False) - with_community = graphrag_conf.get("community", False) - - def callback(*args, msg=None, **kwargs): - nonlocal progress_msg - message = msg or (args[0] if args else "No message") - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} run graphrag msg: {message}.\n" - - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Start to run graphrag.\n" - start_time = time.time() - db_document.progress_msg = progress_msg + progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} GraphRAG enabled, dispatching async task.") + db_document.progress_msg = _progress_msg() db.commit() - db.refresh(db_document) - - task = { - "id": str(db_document.id), - "workspace_id": str(db_knowledge.workspace_id), - "kb_id": str(db_knowledge.id), - "parser_config": db_knowledge.parser_config, - } - - # init_graphrag - vts, _ = embedding_model.encode(["ok"]) - vector_size = len(vts[0]) - init_graphrag(task, vector_size) - - async def _run( - row: dict, - document_ids: list[str], - language: str, - parser_config: dict, - vector_service, - chat_model, - embedding_model, - callback, - with_resolution: bool = True, - with_community: bool = True - ) -> dict: - await trio.sleep(5) # Delay for 10 seconds - nonlocal progress_msg # Declare the use of an external progress_msg variable - result = await run_graphrag_for_kb( - row=row, - document_ids=document_ids, - language=language, - parser_config=parser_config, - vector_service=vector_service, - chat_model=chat_model, - embedding_model=embedding_model, - callback=callback, - with_resolution=with_resolution, - with_community=with_community, - ) - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task result for task {task}:\n{result}\n" - return result - - def sync_task(): - trio.run( - lambda: _run( - row=task, - document_ids=[str(db_document.id)], - language="Chinese", - parser_config=db_knowledge.parser_config, - vector_service=vector_service, - chat_model=chat_model, - embedding_model=embedding_model, - callback=callback, - with_resolution=with_resolution, - with_community=with_community, - ) - ) - - try: - with ThreadPoolExecutor(max_workers=1) as executor: - future = executor.submit(sync_task) - future.result() # Blocks until the task completes - except Exception as e: - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task failed for task {task}:\n{str(e)}\n" - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Knowledge Graph done ({time.time() - start_time}s)" - db_document.progress_msg = progress_msg - db.commit() - db.refresh(db_document) + build_graphrag_for_document.delay(str(document_id), str(db_knowledge.id)) result = f"parse document '{db_document.file_name}' processed successfully." + logger.info(f"[ParseDoc] document={document_id} file='{db_document.file_name}' done in {db_document.process_duration:.1f}s, chunks={total_chunks}") return result - except Exception as e: - if 'db_document' in locals(): - db_document.progress_msg += f"Failed to vectorize and import the parsed document:{str(e)}\n" - db_document.run = 0 - db.commit() - result = f"parse document '{db_document.file_name}' failed." - return result - finally: - db.close() + except Exception as e: + logger.error(f"[ParseDoc] document={document_id} failed: {e}", exc_info=True) + if db_document is not None: + try: + db_document.progress_msg = _progress_msg() + f"Failed to vectorize and import the parsed document:{str(e)}\n" + db_document.run = 0 + db.commit() + except Exception: + logger.warning(f"[ParseDoc] document={document_id} failed to update error status in DB", exc_info=True) + file_name = db_document.file_name if db_document else document_id + return f"parse document '{file_name}' failed." @celery_app.task(name="app.core.rag.tasks.build_graphrag_for_kb") @@ -410,51 +358,41 @@ def build_graphrag_for_kb(kb_id: uuid.UUID): """ build knowledge graph """ - # Force re-importing Trio in child processes (to avoid inheriting the state of the parent process) import importlib import trio importlib.reload(trio) - db = next(get_db()) # Manually call the generator - db_documents = None - db_knowledge = None - try: - db_documents = db.query(Document).filter(Document.kb_id == kb_id).all() - db_knowledge = db.query(Knowledge).filter(Knowledge.id == kb_id).first() - # 1. Prepare to configure chat_mdl、embedding_model、vision_model information - chat_model = Base( - key=db_knowledge.llm.api_keys[0].api_key, - model_name=db_knowledge.llm.api_keys[0].model_name, - base_url=db_knowledge.llm.api_keys[0].api_base - ) - embedding_model = OpenAIEmbed( - key=db_knowledge.embedding.api_keys[0].api_key, - model_name=db_knowledge.embedding.api_keys[0].model_name, - base_url=db_knowledge.embedding.api_keys[0].api_base - ) - vision_model = QWenCV( - key=db_knowledge.image2text.api_keys[0].api_key, - model_name=db_knowledge.image2text.api_keys[0].model_name, - lang="Chinese", - base_url=db_knowledge.image2text.api_keys[0].api_base - ) - # 2. get all document_ids from knowledge base - vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) - total, items = vector_service.search_by_segment(document_id=None, query=None, pagesize=9999, page=1, asc=True) - document_ids = [str(item.id) for item in db_documents] + with get_db_context() as db: + try: + db_knowledge = db.query(Knowledge).filter(Knowledge.id == kb_id).first() + if db_knowledge is None: + logger.error(f"[GraphRAG-KB] knowledge={kb_id} not found") + return f"build knowledge graph failed: knowledge not found" + + if not (db_knowledge.parser_config and + db_knowledge.parser_config.get("graphrag", {}).get("use_graphrag", False)): + return f"build knowledge graph '{db_knowledge.name}' skipped: graphrag not enabled" + + db_documents = db.query(Document).filter(Document.kb_id == kb_id).all() + document_ids = [str(doc.id) for doc in db_documents] + + chat_model = Base( + key=db_knowledge.llm.api_keys[0].api_key, + model_name=db_knowledge.llm.api_keys[0].model_name, + base_url=db_knowledge.llm.api_keys[0].api_base, + ) + embedding_model = OpenAIEmbed( + key=db_knowledge.embedding.api_keys[0].api_key, + model_name=db_knowledge.embedding.api_keys[0].model_name, + base_url=db_knowledge.embedding.api_keys[0].api_base, + ) + vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) - # 2. using graphrag - if db_knowledge.parser_config and db_knowledge.parser_config.get("graphrag", {}).get("use_graphrag", False): graphrag_conf = db_knowledge.parser_config.get("graphrag", {}) with_resolution = graphrag_conf.get("resolution", False) with_community = graphrag_conf.get("community", False) - def callback(*args, msg=None, **kwargs): - message = msg or (args[0] if args else "No message") - print(f"{datetime.now().strftime('%H:%M:%S')} run graphrag msg: {message}.\n") - - start_time = time.time() task = { "id": str(db_knowledge.id), "workspace_id": str(db_knowledge.workspace_id), @@ -467,14 +405,18 @@ def build_graphrag_for_kb(kb_id: uuid.UUID): vector_size = len(vts[0]) init_graphrag(task, vector_size) - async def _run(row: dict, document_ids: list[str], language: str, parser_config: dict, vector_service, - chat_model, embedding_model, callback, with_resolution: bool = True, - with_community: bool = True, ) -> dict: - result = await run_graphrag_for_kb( - row=row, + def callback(*args, msg=None, **kwargs): + message = msg or (args[0] if args else "No message") + logger.info(f"[GraphRAG-KB] kb={kb_id} msg: {message}") + + start_time = time.time() + + async def _run() -> dict: + return await run_graphrag_for_kb( + row=task, document_ids=document_ids, - language=language, - parser_config=parser_config, + language=DEFAULT_PARSE_LANGUAGE, + parser_config=db_knowledge.parser_config, vector_service=vector_service, chat_model=chat_model, embedding_model=embedding_model, @@ -482,46 +424,97 @@ def build_graphrag_for_kb(kb_id: uuid.UUID): with_resolution=with_resolution, with_community=with_community, ) - print(f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task result for task {task}:\n{result}\n") - return result - def sync_task(): - trio.run( - lambda: _run( - row=task, - document_ids=document_ids, - language="Chinese", - parser_config=db_knowledge.parser_config, - vector_service=vector_service, - chat_model=chat_model, - embedding_model=embedding_model, - callback=callback, - with_resolution=with_resolution, - with_community=with_community, - ) + result = trio.run(_run) + duration = time.time() - start_time + logger.info(f"[GraphRAG-KB] kb={kb_id} done in {duration:.1f}s, result: {result}") + + return f"build knowledge graph '{db_knowledge.name}' processed successfully." + except Exception as e: + logger.error(f"[GraphRAG-KB] kb={kb_id} failed: {e}", exc_info=True) + return f"build knowledge graph failed: {e}" + + +@celery_app.task(name="app.core.rag.tasks.build_graphrag_for_document") +def build_graphrag_for_document(document_id: str, knowledge_id: str): + """ + 为单个文档构建 GraphRAG,由 parse_document 异步派发。 + """ + import importlib + + import trio + importlib.reload(trio) + + with get_db_context() as db: + try: + db_document = db.query(Document).filter(Document.id == uuid.UUID(document_id)).first() + db_knowledge = db.query(Knowledge).filter(Knowledge.id == uuid.UUID(knowledge_id)).first() + if db_document is None or db_knowledge is None: + logger.error(f"[GraphRAG] document={document_id} or knowledge={knowledge_id} not found") + return f"build_graphrag_for_document failed: record not found" + + graphrag_conf = db_knowledge.parser_config.get("graphrag", {}) + with_resolution = graphrag_conf.get("resolution", False) + with_community = graphrag_conf.get("community", False) + + chat_model = Base( + key=db_knowledge.llm.api_keys[0].api_key, + model_name=db_knowledge.llm.api_keys[0].model_name, + base_url=db_knowledge.llm.api_keys[0].api_base, + ) + embedding_model = OpenAIEmbed( + key=db_knowledge.embedding.api_keys[0].api_key, + model_name=db_knowledge.embedding.api_keys[0].model_name, + base_url=db_knowledge.embedding.api_keys[0].api_base, + ) + vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) + + task = { + "id": document_id, + "workspace_id": str(db_knowledge.workspace_id), + "kb_id": str(db_knowledge.id), + "parser_config": db_knowledge.parser_config, + } + + # init_graphrag + vts, _ = embedding_model.encode(["ok"]) + vector_size = len(vts[0]) + init_graphrag(task, vector_size) + + def callback(*args, msg=None, **kwargs): + message = msg or (args[0] if args else "No message") + logger.info(f"[GraphRAG] doc={document_id} msg: {message}") + + start_time = time.time() + + async def _run() -> dict: + await trio.sleep(5) + return await run_graphrag_for_kb( + row=task, + document_ids=[document_id], + language=DEFAULT_PARSE_LANGUAGE, + parser_config=db_knowledge.parser_config, + vector_service=vector_service, + chat_model=chat_model, + embedding_model=embedding_model, + callback=callback, + with_resolution=with_resolution, + with_community=with_community, ) - try: - with ThreadPoolExecutor(max_workers=1) as executor: - future = executor.submit(sync_task) - future.result() # Blocks until the task completes - except Exception as e: - print(f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task failed for task {task}:\n{str(e)}\n") - finally: - if db: - db.close() - print(f"{datetime.now().strftime('%H:%M:%S')} Knowledge Graph done ({time.time() - start_time}s)") + result = trio.run(_run) + duration = time.time() - start_time + logger.info(f"[GraphRAG] doc={document_id} done in {duration:.1f}s") - result = f"build knowledge graph '{db_knowledge.name}' processed successfully." - return result - except Exception as e: - if 'db_knowledge' in locals(): - print(f"Failed to build knowledge grap:{str(e)}\n") - result = f"build knowledge grap '{db_knowledge.name}' failed." - return result - finally: - if db: - db.close() + # 更新文档进度信息 + db_document.progress_msg = (db_document.progress_msg or "") + \ + f"{datetime.now().strftime('%H:%M:%S')} Knowledge Graph done ({duration:.1f}s)\n" + db.commit() + + return f"build_graphrag_for_document '{document_id}' processed successfully." + except Exception as e: + logger.error(f"[GraphRAG] doc={document_id} failed: {e}", exc_info=True) + return f"build_graphrag_for_document '{document_id}' failed: {e}" @celery_app.task(name="app.core.rag.tasks.sync_knowledge_for_kb") @@ -529,10 +522,13 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID): """ sync knowledge document and Document parsing, vectorization, and storage """ - db = next(get_db()) # Manually call the generator - db_knowledge = None - try: + with get_db_context() as db: + try: db_knowledge = db.query(Knowledge).filter(Knowledge.id == kb_id).first() + if db_knowledge is None: + logger.error(f"[SyncKB] knowledge={kb_id} not found") + return f"sync knowledge failed: knowledge not found" + # 1. get vector_service vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) @@ -667,7 +663,7 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID): db.commit() except Exception as e: - print(f"\n\nError during crawl: {e}") + logger.error(f"[SyncKB] Error during crawl: {e}", exc_info=True) case "Third-party": # Integration of knowledge bases from three parties yuque_user_id = db_knowledge.parser_config.get("yuque_user_id", "") feishu_app_id = db_knowledge.parser_config.get("feishu_app_id", "") @@ -685,13 +681,9 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID): # Get all files from all repos async def async_get_files(api_client: YuqueAPIClient): async with api_client as client: - print("\n=== Fetching repositories ===") repos = await client.get_user_repos() - print(f"Found {len(repos)} repositories:") all_files = [] for repo in repos: - # Get documents from repository - print(f"\n=== Fetching documents from '{repo.name}' ===") docs = await client.get_repo_docs(repo.id) all_files.extend(docs) return all_files @@ -837,7 +829,7 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID): db.commit() except Exception as e: - print(f"\n\nError during fetch feishu: {e}") + logger.error(f"[SyncKB] Error during fetch yuque: {e}", exc_info=True) if feishu_app_id: # Feishu Knowledge Base feishu_app_secret = db_knowledge.parser_config.get("feishu_app_secret", "") feishu_folder_token = db_knowledge.parser_config.get("feishu_folder_token", "") @@ -999,19 +991,16 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID): db.commit() except Exception as e: - print(f"\n\nError during fetch feishu: {e}") + logger.error(f"[SyncKB] Error during fetch feishu: {e}", exc_info=True) case _: # General - print(f"General: No synchronization needed\n") + logger.info(f"[SyncKB] kb={kb_id} type={db_knowledge.type}: no synchronization needed") result = f"sync knowledge '{db_knowledge.name}' processed successfully." return result - except Exception as e: - if 'db_knowledge' in locals(): - print(f"Failed to sync knowledge:{str(e)}\n") - result = f"sync knowledge '{db_knowledge.name}' failed." - return result - finally: - db.close() + except Exception as e: + logger.error(f"[SyncKB] kb={kb_id} failed: {e}", exc_info=True) + kb_name = db_knowledge.name if db_knowledge else kb_id + return f"sync knowledge '{kb_name}' failed: {e}" @celery_app.task(name="app.core.memory.agent.read_message", bind=True)