[modify] mineru
This commit is contained in:
@@ -292,9 +292,10 @@ class MinerUParser(RAGPdfParser):
|
|||||||
self.page_from = page_from
|
self.page_from = page_from
|
||||||
self.page_to = page_to
|
self.page_to = page_to
|
||||||
try:
|
try:
|
||||||
with pdfplumber.open(fnm) if isinstance(fnm, (str, PathLike)) else pdfplumber.open(BytesIO(fnm)) as pdf:
|
with sys.modules[LOCK_KEY_pdfplumber]: # ← 加这一行,获取全局锁
|
||||||
self.pdf = pdf
|
with pdfplumber.open(fnm) if isinstance(fnm, (str, PathLike)) else pdfplumber.open(BytesIO(fnm)) as pdf:
|
||||||
self.page_images = [p.to_image(resolution=72 * zoomin, antialias=True).original for _, p in enumerate(self.pdf.pages[page_from:page_to])]
|
self.pdf = pdf
|
||||||
|
self.page_images = [p.to_image(resolution=72 * zoomin, antialias=True).original for _, p in enumerate(self.pdf.pages[page_from:page_to])]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.page_images = None
|
self.page_images = None
|
||||||
self.total_page = 0
|
self.total_page = 0
|
||||||
|
|||||||
161
api/app/tasks.py
161
api/app/tasks.py
@@ -56,6 +56,10 @@ VIDEO_IMAGE_PATTERN = re.compile(
|
|||||||
DEFAULT_PARSE_LANGUAGE = "Chinese"
|
DEFAULT_PARSE_LANGUAGE = "Chinese"
|
||||||
DEFAULT_PARSE_TO_PAGE = 100_000
|
DEFAULT_PARSE_TO_PAGE = 100_000
|
||||||
EMBEDDING_BATCH_SIZE = 100
|
EMBEDDING_BATCH_SIZE = 100
|
||||||
|
# Embedding 并发写入的最大线程数,需根据模型 API rate limit 调整
|
||||||
|
EMBEDDING_MAX_WORKERS = int(os.getenv("EMBEDDING_MAX_WORKERS", "3"))
|
||||||
|
# auto_questions LLM 并发调用的最大线程数
|
||||||
|
AUTO_QUESTIONS_MAX_WORKERS = int(os.getenv("AUTO_QUESTIONS_MAX_WORKERS", "5"))
|
||||||
|
|
||||||
# 模块级同步 Redis 连接池,供 Celery 任务共享使用
|
# 模块级同步 Redis 连接池,供 Celery 任务共享使用
|
||||||
# 连接 CELERY_BACKEND DB,与 write_message:last_done 时间戳写入保持一致
|
# 连接 CELERY_BACKEND DB,与 write_message:last_done 时间戳写入保持一致
|
||||||
@@ -218,6 +222,10 @@ def parse_document(file_path: str, document_id: uuid.UUID):
|
|||||||
|
|
||||||
with get_db_context() as db:
|
with get_db_context() as db:
|
||||||
try:
|
try:
|
||||||
|
# Celery JSON 序列化会将 UUID 转为字符串,需要确保类型正确
|
||||||
|
if not isinstance(document_id, uuid.UUID):
|
||||||
|
document_id = uuid.UUID(str(document_id))
|
||||||
|
|
||||||
db_document = db.query(Document).filter(Document.id == document_id).first()
|
db_document = db.query(Document).filter(Document.id == document_id).first()
|
||||||
if db_document is None:
|
if db_document is None:
|
||||||
raise ValueError(f"Document {document_id} not found")
|
raise ValueError(f"Document {document_id} not found")
|
||||||
@@ -265,7 +273,7 @@ def parse_document(file_path: str, document_id: uuid.UUID):
|
|||||||
progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} No chunks generated, skipping vectorization.")
|
progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} No chunks generated, skipping vectorization.")
|
||||||
else:
|
else:
|
||||||
total_batches = ceil(total_chunks / EMBEDDING_BATCH_SIZE)
|
total_batches = ceil(total_chunks / EMBEDDING_BATCH_SIZE)
|
||||||
progress_per_batch = 0.2 / total_batches # Progress of each batch
|
progress_per_batch = 0.2 / total_batches
|
||||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||||
# 2.1 Delete document vector index
|
# 2.1 Delete document vector index
|
||||||
vector_service.delete_by_metadata_field(key="document_id", value=str(document_id))
|
vector_service.delete_by_metadata_field(key="document_id", value=str(document_id))
|
||||||
@@ -278,47 +286,114 @@ def parse_document(file_path: str, document_id: uuid.UUID):
|
|||||||
model_name=db_knowledge.llm.api_keys[0].model_name,
|
model_name=db_knowledge.llm.api_keys[0].model_name,
|
||||||
base_url=db_knowledge.llm.api_keys[0].api_base,
|
base_url=db_knowledge.llm.api_keys[0].api_base,
|
||||||
)
|
)
|
||||||
for batch_start in range(0, total_chunks, EMBEDDING_BATCH_SIZE):
|
|
||||||
batch_end = min(batch_start + EMBEDDING_BATCH_SIZE, total_chunks)
|
|
||||||
batch = res[batch_start: batch_end]
|
|
||||||
chunks = []
|
|
||||||
|
|
||||||
for idx_in_batch, item in enumerate(batch):
|
# 预先构建所有 batch 的 chunks,保证 sort_id 全局有序
|
||||||
global_idx = batch_start + idx_in_batch
|
all_batch_chunks: list[list[DocumentChunk]] = []
|
||||||
metadata = {
|
|
||||||
"doc_id": uuid.uuid4().hex,
|
if auto_questions_topn:
|
||||||
"file_id": str(db_document.file_id),
|
# auto_questions 开启:先并发生成所有 chunk 的问题,再按 batch 分组
|
||||||
"file_name": db_document.file_name,
|
# 构建 (global_idx, item) 列表
|
||||||
"file_created_at": int(db_document.created_at.timestamp() * 1000),
|
indexed_items = list(enumerate(res))
|
||||||
"document_id": str(db_document.id),
|
|
||||||
"knowledge_id": str(db_document.kb_id),
|
def _generate_question(idx_item: tuple[int, dict]) -> tuple[int, str]:
|
||||||
"sort_id": global_idx,
|
"""为单个 chunk 生成问题(带缓存),返回 (global_idx, question_text)"""
|
||||||
"status": 1,
|
global_idx, item = idx_item
|
||||||
}
|
content = item["content_with_weight"]
|
||||||
if auto_questions_topn:
|
cached = get_llm_cache(chat_model.model_name, content, "question",
|
||||||
cached = get_llm_cache(chat_model.model_name, item["content_with_weight"], "question",
|
{"topn": auto_questions_topn})
|
||||||
{"topn": auto_questions_topn})
|
if not cached:
|
||||||
if not cached:
|
cached = question_proposal(chat_model, content, auto_questions_topn)
|
||||||
cached = question_proposal(chat_model, item["content_with_weight"], auto_questions_topn)
|
set_llm_cache(chat_model.model_name, content, cached, "question",
|
||||||
set_llm_cache(chat_model.model_name, item["content_with_weight"], cached, "question",
|
{"topn": auto_questions_topn})
|
||||||
{"topn": auto_questions_topn})
|
return global_idx, cached
|
||||||
|
|
||||||
|
# 并发调用 LLM 生成问题
|
||||||
|
question_map: dict[int, str] = {}
|
||||||
|
with ThreadPoolExecutor(max_workers=AUTO_QUESTIONS_MAX_WORKERS) as q_executor:
|
||||||
|
futures = {q_executor.submit(_generate_question, item): item[0]
|
||||||
|
for item in indexed_items}
|
||||||
|
for future in futures:
|
||||||
|
global_idx, cached = future.result()
|
||||||
|
question_map[global_idx] = cached
|
||||||
|
|
||||||
|
progress_lines.append(
|
||||||
|
f"{datetime.now().strftime('%H:%M:%S')} Auto questions generated for {total_chunks} chunks "
|
||||||
|
f"(workers={AUTO_QUESTIONS_MAX_WORKERS}).")
|
||||||
|
|
||||||
|
# 按 batch 分组组装 DocumentChunk
|
||||||
|
for batch_start in range(0, total_chunks, EMBEDDING_BATCH_SIZE):
|
||||||
|
batch_end = min(batch_start + EMBEDDING_BATCH_SIZE, total_chunks)
|
||||||
|
chunks = []
|
||||||
|
for global_idx in range(batch_start, batch_end):
|
||||||
|
item = res[global_idx]
|
||||||
|
metadata = {
|
||||||
|
"doc_id": uuid.uuid4().hex,
|
||||||
|
"file_id": str(db_document.file_id),
|
||||||
|
"file_name": db_document.file_name,
|
||||||
|
"file_created_at": int(db_document.created_at.timestamp() * 1000),
|
||||||
|
"document_id": str(db_document.id),
|
||||||
|
"knowledge_id": str(db_document.kb_id),
|
||||||
|
"sort_id": global_idx,
|
||||||
|
"status": 1,
|
||||||
|
}
|
||||||
|
cached = question_map[global_idx]
|
||||||
chunks.append(
|
chunks.append(
|
||||||
DocumentChunk(page_content=f"question: {cached} answer: {item['content_with_weight']}",
|
DocumentChunk(
|
||||||
metadata=metadata))
|
page_content=f"question: {cached} answer: {item['content_with_weight']}",
|
||||||
else:
|
metadata=metadata))
|
||||||
|
all_batch_chunks.append(chunks)
|
||||||
|
else:
|
||||||
|
# 无 auto_questions:直接构建 chunks
|
||||||
|
for batch_start in range(0, total_chunks, EMBEDDING_BATCH_SIZE):
|
||||||
|
batch_end = min(batch_start + EMBEDDING_BATCH_SIZE, total_chunks)
|
||||||
|
chunks = []
|
||||||
|
for global_idx in range(batch_start, batch_end):
|
||||||
|
item = res[global_idx]
|
||||||
|
metadata = {
|
||||||
|
"doc_id": uuid.uuid4().hex,
|
||||||
|
"file_id": str(db_document.file_id),
|
||||||
|
"file_name": db_document.file_name,
|
||||||
|
"file_created_at": int(db_document.created_at.timestamp() * 1000),
|
||||||
|
"document_id": str(db_document.id),
|
||||||
|
"knowledge_id": str(db_document.kb_id),
|
||||||
|
"sort_id": global_idx,
|
||||||
|
"status": 1,
|
||||||
|
}
|
||||||
chunks.append(DocumentChunk(page_content=item["content_with_weight"], metadata=metadata))
|
chunks.append(DocumentChunk(page_content=item["content_with_weight"], metadata=metadata))
|
||||||
|
all_batch_chunks.append(chunks)
|
||||||
|
|
||||||
# Bulk segmented vector import
|
# 并发提交 embedding + ES 写入,max_workers 控制模型 API 并发压力
|
||||||
vector_service.add_chunks(chunks)
|
batch_errors: dict[int, Exception] = {}
|
||||||
|
|
||||||
# Update progress
|
def _embed_and_store(batch_idx: int, batch_chunks: list[DocumentChunk]):
|
||||||
db_document.progress += progress_per_batch
|
try:
|
||||||
progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Embedding progress ({db_document.progress}).")
|
vector_service.add_chunks(batch_chunks)
|
||||||
db_document.progress_msg = _progress_msg()
|
except Exception as exc:
|
||||||
db_document.process_duration = time.time() - start_time
|
batch_errors[batch_idx] = exc
|
||||||
db_document.run = 0
|
|
||||||
db.commit()
|
with ThreadPoolExecutor(max_workers=EMBEDDING_MAX_WORKERS) as executor:
|
||||||
db.refresh(db_document)
|
futures = {
|
||||||
|
executor.submit(_embed_and_store, i, batch_chunks): i
|
||||||
|
for i, batch_chunks in enumerate(all_batch_chunks)
|
||||||
|
}
|
||||||
|
# 按提交顺序收集结果,逐批更新进度
|
||||||
|
for future in futures:
|
||||||
|
future.result() # 等待完成(异常已在 _embed_and_store 内捕获)
|
||||||
|
|
||||||
|
# 如果有 batch 失败,汇总抛出
|
||||||
|
if batch_errors:
|
||||||
|
failed = ", ".join(str(i) for i in sorted(batch_errors))
|
||||||
|
first_err = next(iter(batch_errors.values()))
|
||||||
|
raise RuntimeError(f"Embedding failed for batch(es) [{failed}]: {first_err}") from first_err
|
||||||
|
|
||||||
|
# 所有 batch 完成后一次性更新进度
|
||||||
|
db_document.progress = 0.8 + 0.2 # 直接到 1.0 前的状态
|
||||||
|
progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} All {total_batches} batches embedded (workers={EMBEDDING_MAX_WORKERS}).")
|
||||||
|
db_document.progress_msg = _progress_msg()
|
||||||
|
db_document.process_duration = time.time() - start_time
|
||||||
|
db_document.run = 0
|
||||||
|
db.commit()
|
||||||
|
db.refresh(db_document)
|
||||||
|
|
||||||
# Vectorization and data entry completed
|
# Vectorization and data entry completed
|
||||||
progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Indexing done.")
|
progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Indexing done.")
|
||||||
@@ -344,13 +419,15 @@ def parse_document(file_path: str, document_id: uuid.UUID):
|
|||||||
logger.error(f"[ParseDoc] document={document_id} failed: {e}", exc_info=True)
|
logger.error(f"[ParseDoc] document={document_id} failed: {e}", exc_info=True)
|
||||||
if db_document is not None:
|
if db_document is not None:
|
||||||
try:
|
try:
|
||||||
|
db.rollback()
|
||||||
db_document.progress_msg = _progress_msg() + f"Failed to vectorize and import the parsed document:{str(e)}\n"
|
db_document.progress_msg = _progress_msg() + f"Failed to vectorize and import the parsed document:{str(e)}\n"
|
||||||
db_document.run = 0
|
db_document.run = 0
|
||||||
db.commit()
|
db.commit()
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning(f"[ParseDoc] document={document_id} failed to update error status in DB", exc_info=True)
|
logger.warning(f"[ParseDoc] document={document_id} failed to update error status in DB", exc_info=True)
|
||||||
file_name = db_document.file_name if db_document else document_id
|
# db_document 可能处于 detached/expired 状态,用之前缓存的值或 document_id 兜底
|
||||||
return f"parse document '{file_name}' failed."
|
file_name = getattr(db_document, 'file_name', None) if db_document else None
|
||||||
|
return f"parse document '{file_name or document_id}' failed."
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(name="app.core.rag.tasks.build_graphrag_for_kb")
|
@celery_app.task(name="app.core.rag.tasks.build_graphrag_for_kb")
|
||||||
@@ -365,6 +442,9 @@ def build_graphrag_for_kb(kb_id: uuid.UUID):
|
|||||||
|
|
||||||
with get_db_context() as db:
|
with get_db_context() as db:
|
||||||
try:
|
try:
|
||||||
|
if not isinstance(kb_id, uuid.UUID):
|
||||||
|
kb_id = uuid.UUID(str(kb_id))
|
||||||
|
|
||||||
db_knowledge = db.query(Knowledge).filter(Knowledge.id == kb_id).first()
|
db_knowledge = db.query(Knowledge).filter(Knowledge.id == kb_id).first()
|
||||||
if db_knowledge is None:
|
if db_knowledge is None:
|
||||||
logger.error(f"[GraphRAG-KB] knowledge={kb_id} not found")
|
logger.error(f"[GraphRAG-KB] knowledge={kb_id} not found")
|
||||||
@@ -524,6 +604,9 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID):
|
|||||||
"""
|
"""
|
||||||
with get_db_context() as db:
|
with get_db_context() as db:
|
||||||
try:
|
try:
|
||||||
|
if not isinstance(kb_id, uuid.UUID):
|
||||||
|
kb_id = uuid.UUID(str(kb_id))
|
||||||
|
|
||||||
db_knowledge = db.query(Knowledge).filter(Knowledge.id == kb_id).first()
|
db_knowledge = db.query(Knowledge).filter(Knowledge.id == kb_id).first()
|
||||||
if db_knowledge is None:
|
if db_knowledge is None:
|
||||||
logger.error(f"[SyncKB] knowledge={kb_id} not found")
|
logger.error(f"[SyncKB] knowledge={kb_id} not found")
|
||||||
|
|||||||
Reference in New Issue
Block a user