[add] batch chunk. qa_prompt set

This commit is contained in:
Mark
2026-04-28 15:33:44 +08:00
parent 140311048a
commit 64e640d882
5 changed files with 103 additions and 10 deletions

View File

@@ -285,6 +285,67 @@ async def create_chunk(
return success(data=jsonable_encoder(chunk), msg="Document chunk creation successful")
@router.post("/{kb_id}/{document_id}/chunk/batch", response_model=ApiResponse)
async def create_chunks_batch(
kb_id: uuid.UUID,
document_id: uuid.UUID,
batch_data: chunk_schema.ChunkBatchCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Batch create chunks (max 8)
"""
api_logger.info(f"Batch create chunks: kb_id={kb_id}, document_id={document_id}, count={len(batch_data.items)}, username: {current_user.username}")
if len(batch_data.items) > settings.MAX_CHUNK_BATCH_SIZE:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Batch size exceeds limit: max {settings.MAX_CHUNK_BATCH_SIZE}, got {len(batch_data.items)}"
)
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
if not db_knowledge:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="The knowledge base does not exist or access is denied")
db_document = db.query(Document).filter(Document.id == document_id).first()
if not db_document:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="The document does not exist or you do not have permission to access it")
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
# Get current max sort_id
sort_id = 0
total, items = vector_service.search_by_segment(document_id=str(document_id), pagesize=1, page=1, asc=False)
if items:
sort_id = items[0].metadata["sort_id"]
chunks = []
for create_data in batch_data.items:
sort_id += 1
doc_id = uuid.uuid4().hex
metadata = {
"doc_id": doc_id,
"file_id": str(db_document.file_id),
"file_name": db_document.file_name,
"file_created_at": int(db_document.created_at.timestamp() * 1000),
"document_id": str(document_id),
"knowledge_id": str(kb_id),
"sort_id": sort_id,
"status": 1,
}
if create_data.is_qa:
metadata.update(create_data.qa_metadata)
chunks.append(DocumentChunk(page_content=create_data.chunk_content, metadata=metadata))
vector_service.add_chunks(chunks)
db_document.chunk_num += len(chunks)
db.commit()
return success(data=jsonable_encoder(chunks), msg=f"Batch created {len(chunks)} chunks successfully")
@router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
async def get_chunk(
kb_id: uuid.UUID,

View File

@@ -98,6 +98,7 @@ class Settings:
# File Upload
MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
MAX_FILE_COUNT: int = int(os.getenv("MAX_FILE_COUNT", "20"))
MAX_CHUNK_BATCH_SIZE: int = int(os.getenv("MAX_CHUNK_BATCH_SIZE", "8"))
FILE_PATH: str = os.getenv("FILE_PATH", "/files")
FILE_URL_EXPIRES: int = int(os.getenv("FILE_URL_EXPIRES", "3600"))

View File

@@ -138,9 +138,19 @@ def question_proposal(chat_mdl, content, topn=3):
return "\n".join([p["question"] for p in pairs])
def qa_proposal(chat_mdl, content, topn=3):
"""生成 QA 对,返回 [{"question": ..., "answer": ...}, ...]"""
template = PROMPT_JINJA_ENV.from_string(QUESTION_PROMPT_TEMPLATE)
def qa_proposal(chat_mdl, content, topn=3, custom_prompt=None):
"""生成 QA 对,返回 [{"question": ..., "answer": ...}, ...]
Args:
chat_mdl: LLM 模型
content: 文本内容
topn: 生成 QA 对数量
custom_prompt: 自定义 prompt 模板(支持 Jinja2可用变量: content, topn
"""
if custom_prompt:
template = PROMPT_JINJA_ENV.from_string(custom_prompt)
else:
template = PROMPT_JINJA_ENV.from_string(QUESTION_PROMPT_TEMPLATE)
rendered_prompt = template.render(content=content, topn=topn)
msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]

View File

@@ -77,3 +77,8 @@ class ChunkRetrieve(BaseModel):
vector_similarity_weight: float | None = Field(None)
top_k: int | None = Field(None)
retrieve_type: RetrieveType | None = Field(None)
class ChunkBatchCreate(BaseModel):
"""批量创建 chunk"""
items: list[ChunkCreate] = Field(..., min_length=1, description="chunk 列表")

View File

@@ -311,6 +311,7 @@ def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""):
vector_service.delete_by_metadata_field(key="document_id", value=str(document_id))
# 2.2 Vectorize and import batch documents
auto_questions_topn = db_document.parser_config.get("auto_questions", 0)
qa_prompt = db_document.parser_config.get("qa_prompt", None)
chat_model = None
if auto_questions_topn:
chat_model = Base(
@@ -318,6 +319,9 @@ def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""):
model_name=db_knowledge.llm.api_keys[0].model_name,
base_url=db_knowledge.llm.api_keys[0].api_base,
)
logger.info(f"[QA] LLM model: {db_knowledge.llm.api_keys[0].model_name}, base_url: {db_knowledge.llm.api_keys[0].api_base}")
if qa_prompt:
logger.info(f"[QA] Using custom prompt ({len(qa_prompt)} chars)")
# 预先构建所有 batch 的 chunks保证 sort_id 全局有序
all_batch_chunks: list[list[DocumentChunk]] = []
@@ -335,15 +339,27 @@ def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""):
cached = get_llm_cache(chat_model.model_name, content, "qa",
{"topn": auto_questions_topn})
if not cached:
pairs = qa_proposal(chat_model, content, auto_questions_topn)
cached = pairs
set_llm_cache(chat_model.model_name, content, cached, "qa",
try:
pairs = qa_proposal(chat_model, content, auto_questions_topn, custom_prompt=qa_prompt)
except Exception as e:
logger.error(f"[QA] LLM call failed: model={chat_model.model_name}, base_url={getattr(chat_model, 'base_url', 'N/A')}, error={e}")
return global_idx, []
# 缓存存 JSON 字符串
set_llm_cache(chat_model.model_name, content, json.dumps(pairs, ensure_ascii=False), "qa",
{"topn": auto_questions_topn})
elif isinstance(cached, str):
# 兼容旧缓存格式纯文本问题)
return global_idx, pairs
# 从缓存读取:可能是 JSON 字符串或旧格式纯文本
if isinstance(cached, str):
try:
parsed = json.loads(cached)
if isinstance(parsed, list):
return global_idx, parsed
except (json.JSONDecodeError, TypeError):
pass
# 旧缓存格式(纯文本问题),尝试解析
from app.core.rag.prompts.generator import parse_qa_pairs
cached = parse_qa_pairs(cached) if cached else []
return global_idx, cached
return global_idx, parse_qa_pairs(cached) if cached else []
return global_idx, cached if isinstance(cached, list) else []
# 并发调用 LLM 生成 QA 对
qa_map: dict[int, list] = {}