Merge branch 'feature/rag2' into develop
* feature/rag2: [modify] parse document workflow, add graph queue hand build graph [modify] mineru [modify] 优化tasks ,拆分graphirag 队列 # Conflicts: # api/app/tasks.py
This commit is contained in:
@@ -116,9 +116,12 @@ celery_app.conf.update(
|
||||
|
||||
# Document tasks → document_tasks queue (prefork worker)
|
||||
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
|
||||
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
|
||||
'app.core.rag.tasks.sync_knowledge_for_kb': {'queue': 'document_tasks'},
|
||||
|
||||
# GraphRAG tasks → graphrag_tasks queue (独立队列,避免阻塞文档解析)
|
||||
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'graphrag_tasks'},
|
||||
'app.core.rag.tasks.build_graphrag_for_document': {'queue': 'graphrag_tasks'},
|
||||
|
||||
# Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker)
|
||||
'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'},
|
||||
|
||||
644
api/app/tasks.py
644
api/app/tasks.py
@@ -45,6 +45,23 @@ from app.utils.redis_lock import RedisFairLock
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# ── 预编译文件类型正则 & 常量 ──────────────────────────────────
|
||||
AUDIO_PATTERN = re.compile(
|
||||
r"\.(da|wave|wav|mp3|aac|flac|ogg|aiff|au|midi|wma|realaudio|vqf|oggvorbis|ape?)$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
VIDEO_IMAGE_PATTERN = re.compile(
|
||||
r"\.(png|jpeg|jpg|gif|bmp|svg|mp4|mov|avi|flv|mpeg|mpg|webm|wmv|3gp|3gpp|mkv?)$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
DEFAULT_PARSE_LANGUAGE = "Chinese"
|
||||
DEFAULT_PARSE_TO_PAGE = 100_000
|
||||
EMBEDDING_BATCH_SIZE = int(os.getenv("EMBEDDING_BATCH_SIZE", "20"))
|
||||
# Embedding 并发写入的最大线程数,需根据模型 API rate limit 调整
|
||||
EMBEDDING_MAX_WORKERS = int(os.getenv("EMBEDDING_MAX_WORKERS", "3"))
|
||||
# auto_questions LLM 并发调用的最大线程数
|
||||
AUTO_QUESTIONS_MAX_WORKERS = int(os.getenv("AUTO_QUESTIONS_MAX_WORKERS", "5"))
|
||||
|
||||
# 模块级同步 Redis 连接池,供 Celery 任务共享使用
|
||||
# 连接 CELERY_BACKEND DB,与 write_message:last_done 时间戳写入保持一致
|
||||
# 使用连接池而非单例客户端,提供更好的并发性能和自动重连
|
||||
@@ -161,28 +178,67 @@ def process_item(item: dict):
|
||||
return result
|
||||
|
||||
|
||||
def _build_vision_model(file_path: str, db_knowledge):
    """Build the media-to-text model matching the extension of ``file_path``.

    Only the model that is actually needed gets constructed:

    * path matches ``AUDIO_PATTERN``       -> ``QWenSeq2txt`` configured from
      the ``QWEN3_OMNI_*`` environment variables;
    * path matches ``VIDEO_IMAGE_PATTERN`` -> ``QWenCV`` configured from the
      same environment variables;
    * anything else                        -> ``QWenCV`` configured from the
      first API key of the knowledge base's ``image2text`` model.

    Args:
        file_path: Path (or file name) whose extension selects the model.
        db_knowledge: Knowledge-base record; only read in the fallback case.

    Returns:
        The constructed model instance.
    """
    if AUDIO_PATTERN.search(file_path):
        omni_cls = QWenSeq2txt
    elif VIDEO_IMAGE_PATTERN.search(file_path):
        omni_cls = QWenCV
    else:
        # Neither audio nor image/video: use the image2text model configured
        # on the knowledge base itself.
        kb_credentials = db_knowledge.image2text.api_keys[0]
        return QWenCV(
            key=kb_credentials.api_key,
            model_name=kb_credentials.model_name,
            lang=DEFAULT_PARSE_LANGUAGE,
            base_url=kb_credentials.api_base,
        )

    # Audio and image/video inputs share one environment-driven configuration;
    # they differ only in the wrapper class chosen above.
    return omni_cls(
        key=os.getenv("QWEN3_OMNI_API_KEY", ""),
        model_name=os.getenv("QWEN3_OMNI_MODEL_NAME", "qwen3-omni-flash"),
        lang=DEFAULT_PARSE_LANGUAGE,
        base_url=os.getenv("QWEN3_OMNI_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
    )
|
||||
|
||||
|
||||
@celery_app.task(name="app.core.rag.tasks.parse_document")
|
||||
def parse_document(file_path: str, document_id: uuid.UUID):
|
||||
"""
|
||||
Document parsing, vectorization, and storage
|
||||
"""
|
||||
# Force re-importing Trio in child processes (to avoid inheriting the state of the parent process)
|
||||
import importlib
|
||||
|
||||
import trio
|
||||
importlib.reload(trio)
|
||||
db = next(get_db()) # Manually call the generator
|
||||
db_document = None
|
||||
db_knowledge = None
|
||||
progress_msg = f"{datetime.now().strftime('%H:%M:%S')} Task has been received.\n"
|
||||
try:
|
||||
progress_lines: list[str] = [f"{datetime.now().strftime('%H:%M:%S')} Task has been received."]
|
||||
|
||||
def _progress_msg() -> str:
|
||||
return "\n".join(progress_lines) + "\n"
|
||||
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
# Celery JSON 序列化会将 UUID 转为字符串,需要确保类型正确
|
||||
if not isinstance(document_id, uuid.UUID):
|
||||
document_id = uuid.UUID(str(document_id))
|
||||
|
||||
db_document = db.query(Document).filter(Document.id == document_id).first()
|
||||
if db_document is None:
|
||||
raise ValueError(f"Document {document_id} not found")
|
||||
db_knowledge = db.query(Knowledge).filter(Knowledge.id == db_document.kb_id).first()
|
||||
if db_knowledge is None:
|
||||
raise ValueError(f"Knowledge {db_document.kb_id} not found")
|
||||
|
||||
# 1. Document parsing & segmentation
|
||||
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Start to parse.\n"
|
||||
progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Start to parse.")
|
||||
start_time = time.time()
|
||||
db_document.progress = 0.0
|
||||
db_document.progress_msg = progress_msg
|
||||
db_document.progress_msg = _progress_msg()
|
||||
db_document.process_begin_at = datetime.now(tz=timezone.utc)
|
||||
db_document.process_duration = 0.0
|
||||
db_document.run = 1
|
||||
@@ -190,220 +246,195 @@ def parse_document(file_path: str, document_id: uuid.UUID):
|
||||
db.refresh(db_document)
|
||||
|
||||
def progress_callback(prog=None, msg=None):
|
||||
nonlocal progress_msg # Declare the use of an external progress_msg variable
|
||||
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} parse progress: {prog} msg: {msg}.\n"
|
||||
progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} parse progress: {prog} msg: {msg}.")
|
||||
|
||||
# Prepare to configure chat_mdl、embedding_model、vision_model information
|
||||
chat_model = Base(
|
||||
key=db_knowledge.llm.api_keys[0].api_key,
|
||||
model_name=db_knowledge.llm.api_keys[0].model_name,
|
||||
base_url=db_knowledge.llm.api_keys[0].api_base
|
||||
)
|
||||
embedding_model = OpenAIEmbed(
|
||||
key=db_knowledge.embedding.api_keys[0].api_key,
|
||||
model_name=db_knowledge.embedding.api_keys[0].model_name,
|
||||
base_url=db_knowledge.embedding.api_keys[0].api_base
|
||||
)
|
||||
vision_model = QWenCV(
|
||||
key=db_knowledge.image2text.api_keys[0].api_key,
|
||||
model_name=db_knowledge.image2text.api_keys[0].model_name,
|
||||
lang="Chinese",
|
||||
base_url=db_knowledge.image2text.api_keys[0].api_base
|
||||
)
|
||||
if re.search(r"\.(da|wave|wav|mp3|aac|flac|ogg|aiff|au|midi|wma|realaudio|vqf|oggvorbis|ape?)$", file_path,
|
||||
re.IGNORECASE):
|
||||
vision_model = QWenSeq2txt(
|
||||
key=os.getenv("QWEN3_OMNI_API_KEY", ""),
|
||||
model_name=os.getenv("QWEN3_OMNI_MODEL_NAME", "qwen3-omni-flash"),
|
||||
lang="Chinese",
|
||||
base_url=os.getenv("QWEN3_OMNI_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
|
||||
)
|
||||
elif re.search(r"\.(png|jpeg|jpg|gif|bmp|svg|mp4|mov|avi|flv|mpeg|mpg|webm|wmv|3gp|3gpp|mkv?)$", file_path,
|
||||
re.IGNORECASE):
|
||||
vision_model = QWenCV(
|
||||
key=os.getenv("QWEN3_OMNI_API_KEY", ""),
|
||||
model_name=os.getenv("QWEN3_OMNI_MODEL_NAME", "qwen3-omni-flash"),
|
||||
lang="Chinese",
|
||||
base_url=os.getenv("QWEN3_OMNI_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
|
||||
)
|
||||
else:
|
||||
print(file_path)
|
||||
# Prepare vision_model for parsing
|
||||
vision_model = _build_vision_model(file_path, db_knowledge)
|
||||
|
||||
from app.core.rag.app.naive import chunk
|
||||
res = chunk(filename=file_path,
|
||||
from_page=0,
|
||||
to_page=100000,
|
||||
to_page=DEFAULT_PARSE_TO_PAGE,
|
||||
callback=progress_callback,
|
||||
vision_model=vision_model,
|
||||
parser_config=db_document.parser_config,
|
||||
is_root=False)
|
||||
|
||||
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Finish parsing.\n"
|
||||
progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Finish parsing.")
|
||||
db_document.progress = 0.8
|
||||
db_document.progress_msg = progress_msg
|
||||
db_document.progress_msg = _progress_msg()
|
||||
db.commit()
|
||||
db.refresh(db_document)
|
||||
|
||||
# 2. Document vectorization and storage
|
||||
total_chunks = len(res)
|
||||
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Generate {total_chunks} chunks.\n"
|
||||
batch_size = 100
|
||||
total_batches = ceil(total_chunks / batch_size)
|
||||
progress_per_batch = 0.2 / total_batches # Progress of each batch
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
# 2.1 Delete document vector index
|
||||
vector_service.delete_by_metadata_field(key="document_id", value=str(document_id))
|
||||
# 2.2 Vectorize and import batch documents
|
||||
for batch_start in range(0, total_chunks, batch_size):
|
||||
batch_end = min(batch_start + batch_size, total_chunks) # prevent out-of-bounds
|
||||
batch = res[batch_start: batch_end] # Retrieve the current batch
|
||||
chunks = []
|
||||
progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Generate {total_chunks} chunks.")
|
||||
|
||||
# Process the current batch
|
||||
for idx_in_batch, item in enumerate(batch):
|
||||
global_idx = batch_start + idx_in_batch # Calculate global index
|
||||
metadata = {
|
||||
"doc_id": uuid.uuid4().hex,
|
||||
"file_id": str(db_document.file_id),
|
||||
"file_name": db_document.file_name,
|
||||
"file_created_at": int(db_document.created_at.timestamp() * 1000),
|
||||
"document_id": str(db_document.id),
|
||||
"knowledge_id": str(db_document.kb_id),
|
||||
"sort_id": global_idx,
|
||||
"status": 1,
|
||||
}
|
||||
if db_document.parser_config.get("auto_questions", 0):
|
||||
topn = db_document.parser_config["auto_questions"]
|
||||
cached = get_llm_cache(chat_model.model_name, item["content_with_weight"], "question",
|
||||
{"topn": topn})
|
||||
if total_chunks == 0:
|
||||
progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} No chunks generated, skipping vectorization.")
|
||||
else:
|
||||
total_batches = ceil(total_chunks / EMBEDDING_BATCH_SIZE)
|
||||
progress_per_batch = 0.2 / total_batches
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
# 2.1 Delete document vector index
|
||||
vector_service.delete_by_metadata_field(key="document_id", value=str(document_id))
|
||||
# 2.2 Vectorize and import batch documents
|
||||
auto_questions_topn = db_document.parser_config.get("auto_questions", 0)
|
||||
chat_model = None
|
||||
if auto_questions_topn:
|
||||
chat_model = Base(
|
||||
key=db_knowledge.llm.api_keys[0].api_key,
|
||||
model_name=db_knowledge.llm.api_keys[0].model_name,
|
||||
base_url=db_knowledge.llm.api_keys[0].api_base,
|
||||
)
|
||||
|
||||
# 预先构建所有 batch 的 chunks,保证 sort_id 全局有序
|
||||
all_batch_chunks: list[list[DocumentChunk]] = []
|
||||
|
||||
if auto_questions_topn:
|
||||
# auto_questions 开启:先并发生成所有 chunk 的问题,再按 batch 分组
|
||||
# 构建 (global_idx, item) 列表
|
||||
indexed_items = list(enumerate(res))
|
||||
|
||||
def _generate_question(idx_item: tuple[int, dict]) -> tuple[int, str]:
|
||||
"""为单个 chunk 生成问题(带缓存),返回 (global_idx, question_text)"""
|
||||
global_idx, item = idx_item
|
||||
content = item["content_with_weight"]
|
||||
cached = get_llm_cache(chat_model.model_name, content, "question",
|
||||
{"topn": auto_questions_topn})
|
||||
if not cached:
|
||||
cached = question_proposal(chat_model, item["content_with_weight"], topn)
|
||||
set_llm_cache(chat_model.model_name, item["content_with_weight"], cached, "question",
|
||||
{"topn": topn})
|
||||
chunks.append(
|
||||
DocumentChunk(page_content=f"question: {cached} answer: {item['content_with_weight']}",
|
||||
metadata=metadata))
|
||||
else:
|
||||
chunks.append(DocumentChunk(page_content=item["content_with_weight"], metadata=metadata))
|
||||
cached = question_proposal(chat_model, content, auto_questions_topn)
|
||||
set_llm_cache(chat_model.model_name, content, cached, "question",
|
||||
{"topn": auto_questions_topn})
|
||||
return global_idx, cached
|
||||
|
||||
# Bulk segmented vector import
|
||||
vector_service.add_chunks(chunks)
|
||||
# 并发调用 LLM 生成问题
|
||||
question_map: dict[int, str] = {}
|
||||
with ThreadPoolExecutor(max_workers=AUTO_QUESTIONS_MAX_WORKERS) as q_executor:
|
||||
futures = {q_executor.submit(_generate_question, item): item[0]
|
||||
for item in indexed_items}
|
||||
for future in futures:
|
||||
global_idx, cached = future.result()
|
||||
question_map[global_idx] = cached
|
||||
|
||||
# Update progress
|
||||
db_document.progress += progress_per_batch
|
||||
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Embedding progress ({db_document.progress}).\n"
|
||||
db_document.progress_msg = progress_msg
|
||||
progress_lines.append(
|
||||
f"{datetime.now().strftime('%H:%M:%S')} Auto questions generated for {total_chunks} chunks "
|
||||
f"(workers={AUTO_QUESTIONS_MAX_WORKERS}).")
|
||||
|
||||
# 按 batch 分组组装 DocumentChunk
|
||||
for batch_start in range(0, total_chunks, EMBEDDING_BATCH_SIZE):
|
||||
batch_end = min(batch_start + EMBEDDING_BATCH_SIZE, total_chunks)
|
||||
chunks = []
|
||||
for global_idx in range(batch_start, batch_end):
|
||||
item = res[global_idx]
|
||||
metadata = {
|
||||
"doc_id": uuid.uuid4().hex,
|
||||
"file_id": str(db_document.file_id),
|
||||
"file_name": db_document.file_name,
|
||||
"file_created_at": int(db_document.created_at.timestamp() * 1000),
|
||||
"document_id": str(db_document.id),
|
||||
"knowledge_id": str(db_document.kb_id),
|
||||
"sort_id": global_idx,
|
||||
"status": 1,
|
||||
}
|
||||
cached = question_map[global_idx]
|
||||
chunks.append(
|
||||
DocumentChunk(
|
||||
page_content=f"question: {cached} answer: {item['content_with_weight']}",
|
||||
metadata=metadata))
|
||||
all_batch_chunks.append(chunks)
|
||||
else:
|
||||
# 无 auto_questions:直接构建 chunks
|
||||
for batch_start in range(0, total_chunks, EMBEDDING_BATCH_SIZE):
|
||||
batch_end = min(batch_start + EMBEDDING_BATCH_SIZE, total_chunks)
|
||||
chunks = []
|
||||
for global_idx in range(batch_start, batch_end):
|
||||
item = res[global_idx]
|
||||
metadata = {
|
||||
"doc_id": uuid.uuid4().hex,
|
||||
"file_id": str(db_document.file_id),
|
||||
"file_name": db_document.file_name,
|
||||
"file_created_at": int(db_document.created_at.timestamp() * 1000),
|
||||
"document_id": str(db_document.id),
|
||||
"knowledge_id": str(db_document.kb_id),
|
||||
"sort_id": global_idx,
|
||||
"status": 1,
|
||||
}
|
||||
chunks.append(DocumentChunk(page_content=item["content_with_weight"], metadata=metadata))
|
||||
all_batch_chunks.append(chunks)
|
||||
|
||||
# 并发提交 embedding + ES 写入,max_workers 控制模型 API 并发压力
|
||||
batch_errors: dict[int, Exception] = {}
|
||||
|
||||
def _embed_and_store(batch_idx: int, batch_chunks: list[DocumentChunk]):
|
||||
try:
|
||||
vector_service.add_chunks(batch_chunks)
|
||||
except Exception as exc:
|
||||
logger.warning(f"[ParseDoc] batch {batch_idx} failed, retrying: {exc}")
|
||||
try:
|
||||
vector_service.add_chunks(batch_chunks)
|
||||
except Exception as retry_exc:
|
||||
logger.error(f"[ParseDoc] batch {batch_idx} retry failed: {retry_exc}", exc_info=True)
|
||||
batch_errors[batch_idx] = retry_exc
|
||||
|
||||
with ThreadPoolExecutor(max_workers=EMBEDDING_MAX_WORKERS) as executor:
|
||||
futures = {
|
||||
executor.submit(_embed_and_store, i, batch_chunks): i
|
||||
for i, batch_chunks in enumerate(all_batch_chunks)
|
||||
}
|
||||
for future in futures:
|
||||
future.result()
|
||||
|
||||
# 如果有 batch 失败,汇总抛出
|
||||
if batch_errors:
|
||||
failed_detail = "; ".join(
|
||||
f"batch {i}: {type(err).__name__}: {err}"
|
||||
for i, err in sorted(batch_errors.items())
|
||||
)
|
||||
raise RuntimeError(f"Embedding failed for {len(batch_errors)}/{total_batches} batch(es). {failed_detail}")
|
||||
|
||||
# 所有 batch 完成后一次性更新进度
|
||||
db_document.progress = 0.8 + 0.2 # 直接到 1.0 前的状态
|
||||
progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} All {total_batches} batches embedded (workers={EMBEDDING_MAX_WORKERS}).")
|
||||
db_document.progress_msg = _progress_msg()
|
||||
db_document.process_duration = time.time() - start_time
|
||||
db_document.run = 0
|
||||
db.commit()
|
||||
db.refresh(db_document)
|
||||
|
||||
# Vectorization and data entry completed
|
||||
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Indexing done.\n"
|
||||
progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Indexing done.")
|
||||
db_document.chunk_num = total_chunks
|
||||
db_document.progress = 1.0
|
||||
db_document.process_duration = time.time() - start_time
|
||||
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Task done ({db_document.process_duration}s).\n"
|
||||
db_document.progress_msg = progress_msg
|
||||
progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Task done ({db_document.process_duration}s).")
|
||||
db_document.progress_msg = _progress_msg()
|
||||
db_document.run = 0
|
||||
db.commit()
|
||||
|
||||
# using graphrag
|
||||
# GraphRAG: 异步派发到独立队列,不阻塞文档解析流程
|
||||
if db_knowledge.parser_config and db_knowledge.parser_config.get("graphrag", {}).get("use_graphrag", False):
|
||||
graphrag_conf = db_knowledge.parser_config.get("graphrag", {})
|
||||
with_resolution = graphrag_conf.get("resolution", False)
|
||||
with_community = graphrag_conf.get("community", False)
|
||||
|
||||
def callback(*args, msg=None, **kwargs):
|
||||
nonlocal progress_msg
|
||||
message = msg or (args[0] if args else "No message")
|
||||
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} run graphrag msg: {message}.\n"
|
||||
|
||||
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Start to run graphrag.\n"
|
||||
start_time = time.time()
|
||||
db_document.progress_msg = progress_msg
|
||||
progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} GraphRAG enabled, dispatching async task.")
|
||||
db_document.progress_msg = _progress_msg()
|
||||
db.commit()
|
||||
db.refresh(db_document)
|
||||
|
||||
task = {
|
||||
"id": str(db_document.id),
|
||||
"workspace_id": str(db_knowledge.workspace_id),
|
||||
"kb_id": str(db_knowledge.id),
|
||||
"parser_config": db_knowledge.parser_config,
|
||||
}
|
||||
|
||||
# init_graphrag
|
||||
vts, _ = embedding_model.encode(["ok"])
|
||||
vector_size = len(vts[0])
|
||||
init_graphrag(task, vector_size)
|
||||
|
||||
async def _run(
|
||||
row: dict,
|
||||
document_ids: list[str],
|
||||
language: str,
|
||||
parser_config: dict,
|
||||
vector_service,
|
||||
chat_model,
|
||||
embedding_model,
|
||||
callback,
|
||||
with_resolution: bool = True,
|
||||
with_community: bool = True
|
||||
) -> dict:
|
||||
await trio.sleep(5) # Delay for 10 seconds
|
||||
nonlocal progress_msg # Declare the use of an external progress_msg variable
|
||||
result = await run_graphrag_for_kb(
|
||||
row=row,
|
||||
document_ids=document_ids,
|
||||
language=language,
|
||||
parser_config=parser_config,
|
||||
vector_service=vector_service,
|
||||
chat_model=chat_model,
|
||||
embedding_model=embedding_model,
|
||||
callback=callback,
|
||||
with_resolution=with_resolution,
|
||||
with_community=with_community,
|
||||
)
|
||||
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task result for task {task}:\n{result}\n"
|
||||
return result
|
||||
|
||||
def sync_task():
|
||||
trio.run(
|
||||
lambda: _run(
|
||||
row=task,
|
||||
document_ids=[str(db_document.id)],
|
||||
language="Chinese",
|
||||
parser_config=db_knowledge.parser_config,
|
||||
vector_service=vector_service,
|
||||
chat_model=chat_model,
|
||||
embedding_model=embedding_model,
|
||||
callback=callback,
|
||||
with_resolution=with_resolution,
|
||||
with_community=with_community,
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||
future = executor.submit(sync_task)
|
||||
future.result() # Blocks until the task completes
|
||||
except Exception as e:
|
||||
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task failed for task {task}:\n{str(e)}\n"
|
||||
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Knowledge Graph done ({time.time() - start_time}s)"
|
||||
db_document.progress_msg = progress_msg
|
||||
db.commit()
|
||||
db.refresh(db_document)
|
||||
build_graphrag_for_document.delay(str(document_id), str(db_knowledge.id))
|
||||
|
||||
result = f"parse document '{db_document.file_name}' processed successfully."
|
||||
logger.info(f"[ParseDoc] document={document_id} file='{db_document.file_name}' done in {db_document.process_duration:.1f}s, chunks={total_chunks}")
|
||||
return result
|
||||
except Exception as e:
|
||||
if 'db_document' in locals():
|
||||
db_document.progress_msg += f"Failed to vectorize and import the parsed document:{str(e)}\n"
|
||||
db_document.run = 0
|
||||
db.commit()
|
||||
result = f"parse document '{db_document.file_name}' failed."
|
||||
return result
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.error(f"[ParseDoc] document={document_id} failed: {e}", exc_info=True)
|
||||
if db_document is not None:
|
||||
try:
|
||||
db.rollback()
|
||||
db_document.progress_msg = _progress_msg() + f"Failed to vectorize and import the parsed document:{str(e)}\n"
|
||||
db_document.run = 0
|
||||
db.commit()
|
||||
except Exception:
|
||||
logger.warning(f"[ParseDoc] document={document_id} failed to update error status in DB", exc_info=True)
|
||||
# db_document 可能处于 detached/expired 状态,用之前缓存的值或 document_id 兜底
|
||||
file_name = getattr(db_document, 'file_name', None) if db_document else None
|
||||
return f"parse document '{file_name or document_id}' failed."
|
||||
|
||||
|
||||
@celery_app.task(name="app.core.rag.tasks.build_graphrag_for_kb")
|
||||
@@ -411,51 +442,44 @@ def build_graphrag_for_kb(kb_id: uuid.UUID):
|
||||
"""
|
||||
build knowledge graph
|
||||
"""
|
||||
# Force re-importing Trio in child processes (to avoid inheriting the state of the parent process)
|
||||
import importlib
|
||||
|
||||
import trio
|
||||
importlib.reload(trio)
|
||||
db = next(get_db()) # Manually call the generator
|
||||
db_documents = None
|
||||
db_knowledge = None
|
||||
try:
|
||||
db_documents = db.query(Document).filter(Document.kb_id == kb_id).all()
|
||||
db_knowledge = db.query(Knowledge).filter(Knowledge.id == kb_id).first()
|
||||
# 1. Prepare to configure chat_mdl、embedding_model、vision_model information
|
||||
chat_model = Base(
|
||||
key=db_knowledge.llm.api_keys[0].api_key,
|
||||
model_name=db_knowledge.llm.api_keys[0].model_name,
|
||||
base_url=db_knowledge.llm.api_keys[0].api_base
|
||||
)
|
||||
embedding_model = OpenAIEmbed(
|
||||
key=db_knowledge.embedding.api_keys[0].api_key,
|
||||
model_name=db_knowledge.embedding.api_keys[0].model_name,
|
||||
base_url=db_knowledge.embedding.api_keys[0].api_base
|
||||
)
|
||||
vision_model = QWenCV(
|
||||
key=db_knowledge.image2text.api_keys[0].api_key,
|
||||
model_name=db_knowledge.image2text.api_keys[0].model_name,
|
||||
lang="Chinese",
|
||||
base_url=db_knowledge.image2text.api_keys[0].api_base
|
||||
)
|
||||
|
||||
# 2. get all document_ids from knowledge base
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
total, items = vector_service.search_by_segment(document_id=None, query=None, pagesize=9999, page=1, asc=True)
|
||||
document_ids = [str(item.id) for item in db_documents]
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
if not isinstance(kb_id, uuid.UUID):
|
||||
kb_id = uuid.UUID(str(kb_id))
|
||||
|
||||
db_knowledge = db.query(Knowledge).filter(Knowledge.id == kb_id).first()
|
||||
if db_knowledge is None:
|
||||
logger.error(f"[GraphRAG-KB] knowledge={kb_id} not found")
|
||||
return f"build knowledge graph failed: knowledge not found"
|
||||
|
||||
if not (db_knowledge.parser_config and
|
||||
db_knowledge.parser_config.get("graphrag", {}).get("use_graphrag", False)):
|
||||
return f"build knowledge graph '{db_knowledge.name}' skipped: graphrag not enabled"
|
||||
|
||||
db_documents = db.query(Document).filter(Document.kb_id == kb_id).all()
|
||||
document_ids = [str(doc.id) for doc in db_documents]
|
||||
|
||||
chat_model = Base(
|
||||
key=db_knowledge.llm.api_keys[0].api_key,
|
||||
model_name=db_knowledge.llm.api_keys[0].model_name,
|
||||
base_url=db_knowledge.llm.api_keys[0].api_base,
|
||||
)
|
||||
embedding_model = OpenAIEmbed(
|
||||
key=db_knowledge.embedding.api_keys[0].api_key,
|
||||
model_name=db_knowledge.embedding.api_keys[0].model_name,
|
||||
base_url=db_knowledge.embedding.api_keys[0].api_base,
|
||||
)
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
|
||||
# 2. using graphrag
|
||||
if db_knowledge.parser_config and db_knowledge.parser_config.get("graphrag", {}).get("use_graphrag", False):
|
||||
graphrag_conf = db_knowledge.parser_config.get("graphrag", {})
|
||||
with_resolution = graphrag_conf.get("resolution", False)
|
||||
with_community = graphrag_conf.get("community", False)
|
||||
|
||||
def callback(*args, msg=None, **kwargs):
|
||||
message = msg or (args[0] if args else "No message")
|
||||
print(f"{datetime.now().strftime('%H:%M:%S')} run graphrag msg: {message}.\n")
|
||||
|
||||
start_time = time.time()
|
||||
task = {
|
||||
"id": str(db_knowledge.id),
|
||||
"workspace_id": str(db_knowledge.workspace_id),
|
||||
@@ -468,14 +492,18 @@ def build_graphrag_for_kb(kb_id: uuid.UUID):
|
||||
vector_size = len(vts[0])
|
||||
init_graphrag(task, vector_size)
|
||||
|
||||
async def _run(row: dict, document_ids: list[str], language: str, parser_config: dict, vector_service,
|
||||
chat_model, embedding_model, callback, with_resolution: bool = True,
|
||||
with_community: bool = True, ) -> dict:
|
||||
result = await run_graphrag_for_kb(
|
||||
row=row,
|
||||
def callback(*args, msg=None, **kwargs):
|
||||
message = msg or (args[0] if args else "No message")
|
||||
logger.info(f"[GraphRAG-KB] kb={kb_id} msg: {message}")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> dict:
|
||||
return await run_graphrag_for_kb(
|
||||
row=task,
|
||||
document_ids=document_ids,
|
||||
language=language,
|
||||
parser_config=parser_config,
|
||||
language=DEFAULT_PARSE_LANGUAGE,
|
||||
parser_config=db_knowledge.parser_config,
|
||||
vector_service=vector_service,
|
||||
chat_model=chat_model,
|
||||
embedding_model=embedding_model,
|
||||
@@ -483,46 +511,97 @@ def build_graphrag_for_kb(kb_id: uuid.UUID):
|
||||
with_resolution=with_resolution,
|
||||
with_community=with_community,
|
||||
)
|
||||
print(f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task result for task {task}:\n{result}\n")
|
||||
return result
|
||||
|
||||
def sync_task():
|
||||
trio.run(
|
||||
lambda: _run(
|
||||
row=task,
|
||||
document_ids=document_ids,
|
||||
language="Chinese",
|
||||
parser_config=db_knowledge.parser_config,
|
||||
vector_service=vector_service,
|
||||
chat_model=chat_model,
|
||||
embedding_model=embedding_model,
|
||||
callback=callback,
|
||||
with_resolution=with_resolution,
|
||||
with_community=with_community,
|
||||
)
|
||||
result = trio.run(_run)
|
||||
duration = time.time() - start_time
|
||||
logger.info(f"[GraphRAG-KB] kb={kb_id} done in {duration:.1f}s, result: {result}")
|
||||
|
||||
return f"build knowledge graph '{db_knowledge.name}' processed successfully."
|
||||
except Exception as e:
|
||||
logger.error(f"[GraphRAG-KB] kb={kb_id} failed: {e}", exc_info=True)
|
||||
return f"build knowledge graph failed: {e}"
|
||||
|
||||
|
||||
@celery_app.task(name="app.core.rag.tasks.build_graphrag_for_document")
def build_graphrag_for_document(document_id: str, knowledge_id: str):
    """
    Build the GraphRAG knowledge graph for a single document.

    Dispatched asynchronously by ``parse_document`` so that graph construction
    does not block the document-parsing task.

    Args:
        document_id: ID of the document to process (UUID string).
        knowledge_id: ID of the knowledge base owning the document (UUID string).

    Returns:
        A human-readable status string. Errors are logged and reported through
        the returned string rather than raised.
    """
    # Force re-importing Trio in child processes (to avoid inheriting the state of the parent process)
    import importlib

    import trio
    importlib.reload(trio)

    with get_db_context() as db:
        try:
            # Accept either str or uuid.UUID ids; everything below works on the string form.
            document_id = str(document_id)
            knowledge_id = str(knowledge_id)

            db_document = db.query(Document).filter(Document.id == uuid.UUID(document_id)).first()
            db_knowledge = db.query(Knowledge).filter(Knowledge.id == uuid.UUID(knowledge_id)).first()
            if db_document is None or db_knowledge is None:
                logger.error(f"[GraphRAG] document={document_id} or knowledge={knowledge_id} not found")
                return "build_graphrag_for_document failed: record not found"

            # parser_config may be None; treat that the same as "graphrag not configured",
            # matching the guard used by build_graphrag_for_kb and parse_document.
            graphrag_conf = (db_knowledge.parser_config or {}).get("graphrag", {})
            if not graphrag_conf.get("use_graphrag", False):
                logger.info(f"[GraphRAG] doc={document_id} skipped: graphrag not enabled")
                return f"build_graphrag_for_document '{document_id}' skipped: graphrag not enabled"
            with_resolution = graphrag_conf.get("resolution", False)
            with_community = graphrag_conf.get("community", False)

            chat_model = Base(
                key=db_knowledge.llm.api_keys[0].api_key,
                model_name=db_knowledge.llm.api_keys[0].model_name,
                base_url=db_knowledge.llm.api_keys[0].api_base,
            )
            embedding_model = OpenAIEmbed(
                key=db_knowledge.embedding.api_keys[0].api_key,
                model_name=db_knowledge.embedding.api_keys[0].model_name,
                base_url=db_knowledge.embedding.api_keys[0].api_base,
            )
            vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)

            task = {
                "id": document_id,
                "workspace_id": str(db_knowledge.workspace_id),
                "kb_id": str(db_knowledge.id),
                "parser_config": db_knowledge.parser_config,
            }

            # init_graphrag: embed a throwaway string to learn the embedding dimension.
            vts, _ = embedding_model.encode(["ok"])
            vector_size = len(vts[0])
            init_graphrag(task, vector_size)

            def callback(*args, msg=None, **kwargs):
                # Progress may arrive as the `msg` keyword or as the first positional argument.
                message = msg or (args[0] if args else "No message")
                logger.info(f"[GraphRAG] doc={document_id} msg: {message}")

            start_time = time.time()

            async def _run() -> dict:
                # NOTE(review): presumably gives the vector store time to expose the
                # freshly written chunks before the graph is built - confirm.
                await trio.sleep(5)
                return await run_graphrag_for_kb(
                    row=task,
                    document_ids=[document_id],
                    language=DEFAULT_PARSE_LANGUAGE,
                    parser_config=db_knowledge.parser_config,
                    vector_service=vector_service,
                    chat_model=chat_model,
                    embedding_model=embedding_model,
                    callback=callback,
                    with_resolution=with_resolution,
                    with_community=with_community,
                )

            trio.run(_run)
            duration = time.time() - start_time
            logger.info(f"[GraphRAG] doc={document_id} done in {duration:.1f}s")

            # Append the outcome to the document's progress log.
            db_document.progress_msg = (db_document.progress_msg or "") + \
                f"{datetime.now().strftime('%H:%M:%S')} Knowledge Graph done ({duration:.1f}s)\n"
            db.commit()

            return f"build_graphrag_for_document '{document_id}' processed successfully."
        except Exception as e:
            logger.error(f"[GraphRAG] doc={document_id} failed: {e}", exc_info=True)
            return f"build_graphrag_for_document '{document_id}' failed: {e}"
|
||||
|
||||
|
||||
@celery_app.task(name="app.core.rag.tasks.sync_knowledge_for_kb")
|
||||
@@ -530,10 +609,16 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID):
|
||||
"""
|
||||
sync knowledge document and Document parsing, vectorization, and storage
|
||||
"""
|
||||
db = next(get_db()) # Manually call the generator
|
||||
db_knowledge = None
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
if not isinstance(kb_id, uuid.UUID):
|
||||
kb_id = uuid.UUID(str(kb_id))
|
||||
|
||||
db_knowledge = db.query(Knowledge).filter(Knowledge.id == kb_id).first()
|
||||
if db_knowledge is None:
|
||||
logger.error(f"[SyncKB] knowledge={kb_id} not found")
|
||||
return f"sync knowledge failed: knowledge not found"
|
||||
|
||||
# 1. get vector_service
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
|
||||
@@ -668,7 +753,7 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID):
|
||||
db.commit()
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n\nError during crawl: {e}")
|
||||
logger.error(f"[SyncKB] Error during crawl: {e}", exc_info=True)
|
||||
case "Third-party": # Integration of knowledge bases from three parties
|
||||
yuque_user_id = db_knowledge.parser_config.get("yuque_user_id", "")
|
||||
feishu_app_id = db_knowledge.parser_config.get("feishu_app_id", "")
|
||||
@@ -686,13 +771,9 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID):
|
||||
# Get all files from all repos
|
||||
async def async_get_files(api_client: YuqueAPIClient):
|
||||
async with api_client as client:
|
||||
print("\n=== Fetching repositories ===")
|
||||
repos = await client.get_user_repos()
|
||||
print(f"Found {len(repos)} repositories:")
|
||||
all_files = []
|
||||
for repo in repos:
|
||||
# Get documents from repository
|
||||
print(f"\n=== Fetching documents from '{repo.name}' ===")
|
||||
docs = await client.get_repo_docs(repo.id)
|
||||
all_files.extend(docs)
|
||||
return all_files
|
||||
@@ -838,7 +919,7 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID):
|
||||
db.commit()
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n\nError during fetch feishu: {e}")
|
||||
logger.error(f"[SyncKB] Error during fetch yuque: {e}", exc_info=True)
|
||||
if feishu_app_id: # Feishu Knowledge Base
|
||||
feishu_app_secret = db_knowledge.parser_config.get("feishu_app_secret", "")
|
||||
feishu_folder_token = db_knowledge.parser_config.get("feishu_folder_token", "")
|
||||
@@ -1000,19 +1081,16 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID):
|
||||
db.commit()
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n\nError during fetch feishu: {e}")
|
||||
logger.error(f"[SyncKB] Error during fetch feishu: {e}", exc_info=True)
|
||||
case _: # General
|
||||
print("General: No synchronization needed\n")
|
||||
logger.info(f"[SyncKB] kb={kb_id} type={db_knowledge.type}: no synchronization needed")
|
||||
|
||||
result = f"sync knowledge '{db_knowledge.name}' processed successfully."
|
||||
return result
|
||||
except Exception as e:
|
||||
if 'db_knowledge' in locals():
|
||||
print(f"Failed to sync knowledge:{str(e)}\n")
|
||||
result = f"sync knowledge '{db_knowledge.name}' failed."
|
||||
return result
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.error(f"[SyncKB] kb={kb_id} failed: {e}", exc_info=True)
|
||||
kb_name = db_knowledge.name if db_knowledge else kb_id
|
||||
return f"sync knowledge '{kb_name}' failed: {e}"
|
||||
|
||||
|
||||
@celery_app.task(name="app.core.memory.agent.read_message", bind=True)
|
||||
|
||||
Reference in New Issue
Block a user