[add] batch chunk. qa_prompt set

This commit is contained in:
Mark
2026-04-28 15:33:44 +08:00
parent 140311048a
commit 64e640d882
5 changed files with 103 additions and 10 deletions

View File

@@ -311,6 +311,7 @@ def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""):
vector_service.delete_by_metadata_field(key="document_id", value=str(document_id))
# 2.2 Vectorize and import batch documents
auto_questions_topn = db_document.parser_config.get("auto_questions", 0)
qa_prompt = db_document.parser_config.get("qa_prompt", None)
chat_model = None
if auto_questions_topn:
chat_model = Base(
@@ -318,6 +319,9 @@ def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""):
model_name=db_knowledge.llm.api_keys[0].model_name,
base_url=db_knowledge.llm.api_keys[0].api_base,
)
logger.info(f"[QA] LLM model: {db_knowledge.llm.api_keys[0].model_name}, base_url: {db_knowledge.llm.api_keys[0].api_base}")
if qa_prompt:
logger.info(f"[QA] Using custom prompt ({len(qa_prompt)} chars)")
# 预先构建所有 batch 的 chunks保证 sort_id 全局有序
all_batch_chunks: list[list[DocumentChunk]] = []
@@ -335,15 +339,27 @@ def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""):
cached = get_llm_cache(chat_model.model_name, content, "qa",
{"topn": auto_questions_topn})
if not cached:
pairs = qa_proposal(chat_model, content, auto_questions_topn)
cached = pairs
set_llm_cache(chat_model.model_name, content, cached, "qa",
try:
pairs = qa_proposal(chat_model, content, auto_questions_topn, custom_prompt=qa_prompt)
except Exception as e:
logger.error(f"[QA] LLM call failed: model={chat_model.model_name}, base_url={getattr(chat_model, 'base_url', 'N/A')}, error={e}")
return global_idx, []
# 缓存存 JSON 字符串
set_llm_cache(chat_model.model_name, content, json.dumps(pairs, ensure_ascii=False), "qa",
{"topn": auto_questions_topn})
elif isinstance(cached, str):
# 兼容旧缓存格式纯文本问题)
return global_idx, pairs
# 从缓存读取:可能是 JSON 字符串或旧格式纯文本
if isinstance(cached, str):
try:
parsed = json.loads(cached)
if isinstance(parsed, list):
return global_idx, parsed
except (json.JSONDecodeError, TypeError):
pass
# 旧缓存格式(纯文本问题),尝试解析
from app.core.rag.prompts.generator import parse_qa_pairs
cached = parse_qa_pairs(cached) if cached else []
return global_idx, cached
return global_idx, parse_qa_pairs(cached) if cached else []
return global_idx, cached if isinstance(cached, list) else []
# 并发调用 LLM 生成 QA 对
qa_map: dict[int, list] = {}