[add] batch chunk. qa_prompt set
This commit is contained in:
@@ -285,6 +285,67 @@ async def create_chunk(
|
||||
return success(data=jsonable_encoder(chunk), msg="Document chunk creation successful")
|
||||
|
||||
|
||||
@router.post("/{kb_id}/{document_id}/chunk/batch", response_model=ApiResponse)
|
||||
async def create_chunks_batch(
|
||||
kb_id: uuid.UUID,
|
||||
document_id: uuid.UUID,
|
||||
batch_data: chunk_schema.ChunkBatchCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Batch create chunks (max 8)
|
||||
"""
|
||||
api_logger.info(f"Batch create chunks: kb_id={kb_id}, document_id={document_id}, count={len(batch_data.items)}, username: {current_user.username}")
|
||||
|
||||
if len(batch_data.items) > settings.MAX_CHUNK_BATCH_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Batch size exceeds limit: max {settings.MAX_CHUNK_BATCH_SIZE}, got {len(batch_data.items)}"
|
||||
)
|
||||
|
||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
|
||||
if not db_knowledge:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="The knowledge base does not exist or access is denied")
|
||||
|
||||
db_document = db.query(Document).filter(Document.id == document_id).first()
|
||||
if not db_document:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="The document does not exist or you do not have permission to access it")
|
||||
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
|
||||
# Get current max sort_id
|
||||
sort_id = 0
|
||||
total, items = vector_service.search_by_segment(document_id=str(document_id), pagesize=1, page=1, asc=False)
|
||||
if items:
|
||||
sort_id = items[0].metadata["sort_id"]
|
||||
|
||||
chunks = []
|
||||
for create_data in batch_data.items:
|
||||
sort_id += 1
|
||||
doc_id = uuid.uuid4().hex
|
||||
metadata = {
|
||||
"doc_id": doc_id,
|
||||
"file_id": str(db_document.file_id),
|
||||
"file_name": db_document.file_name,
|
||||
"file_created_at": int(db_document.created_at.timestamp() * 1000),
|
||||
"document_id": str(document_id),
|
||||
"knowledge_id": str(kb_id),
|
||||
"sort_id": sort_id,
|
||||
"status": 1,
|
||||
}
|
||||
if create_data.is_qa:
|
||||
metadata.update(create_data.qa_metadata)
|
||||
chunks.append(DocumentChunk(page_content=create_data.chunk_content, metadata=metadata))
|
||||
|
||||
vector_service.add_chunks(chunks)
|
||||
|
||||
db_document.chunk_num += len(chunks)
|
||||
db.commit()
|
||||
|
||||
return success(data=jsonable_encoder(chunks), msg=f"Batch created {len(chunks)} chunks successfully")
|
||||
|
||||
|
||||
@router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
|
||||
async def get_chunk(
|
||||
kb_id: uuid.UUID,
|
||||
|
||||
@@ -98,6 +98,7 @@ class Settings:
|
||||
# File Upload
|
||||
MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
|
||||
MAX_FILE_COUNT: int = int(os.getenv("MAX_FILE_COUNT", "20"))
|
||||
MAX_CHUNK_BATCH_SIZE: int = int(os.getenv("MAX_CHUNK_BATCH_SIZE", "8"))
|
||||
FILE_PATH: str = os.getenv("FILE_PATH", "/files")
|
||||
FILE_URL_EXPIRES: int = int(os.getenv("FILE_URL_EXPIRES", "3600"))
|
||||
|
||||
|
||||
@@ -138,9 +138,19 @@ def question_proposal(chat_mdl, content, topn=3):
|
||||
return "\n".join([p["question"] for p in pairs])
|
||||
|
||||
|
||||
def qa_proposal(chat_mdl, content, topn=3):
|
||||
"""生成 QA 对,返回 [{"question": ..., "answer": ...}, ...]"""
|
||||
template = PROMPT_JINJA_ENV.from_string(QUESTION_PROMPT_TEMPLATE)
|
||||
def qa_proposal(chat_mdl, content, topn=3, custom_prompt=None):
|
||||
"""生成 QA 对,返回 [{"question": ..., "answer": ...}, ...]
|
||||
|
||||
Args:
|
||||
chat_mdl: LLM 模型
|
||||
content: 文本内容
|
||||
topn: 生成 QA 对数量
|
||||
custom_prompt: 自定义 prompt 模板(支持 Jinja2,可用变量: content, topn)
|
||||
"""
|
||||
if custom_prompt:
|
||||
template = PROMPT_JINJA_ENV.from_string(custom_prompt)
|
||||
else:
|
||||
template = PROMPT_JINJA_ENV.from_string(QUESTION_PROMPT_TEMPLATE)
|
||||
rendered_prompt = template.render(content=content, topn=topn)
|
||||
|
||||
msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
|
||||
|
||||
@@ -77,3 +77,8 @@ class ChunkRetrieve(BaseModel):
|
||||
vector_similarity_weight: float | None = Field(None)
|
||||
top_k: int | None = Field(None)
|
||||
retrieve_type: RetrieveType | None = Field(None)
|
||||
|
||||
|
||||
class ChunkBatchCreate(BaseModel):
|
||||
"""批量创建 chunk"""
|
||||
items: list[ChunkCreate] = Field(..., min_length=1, description="chunk 列表")
|
||||
|
||||
@@ -311,6 +311,7 @@ def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""):
|
||||
vector_service.delete_by_metadata_field(key="document_id", value=str(document_id))
|
||||
# 2.2 Vectorize and import batch documents
|
||||
auto_questions_topn = db_document.parser_config.get("auto_questions", 0)
|
||||
qa_prompt = db_document.parser_config.get("qa_prompt", None)
|
||||
chat_model = None
|
||||
if auto_questions_topn:
|
||||
chat_model = Base(
|
||||
@@ -318,6 +319,9 @@ def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""):
|
||||
model_name=db_knowledge.llm.api_keys[0].model_name,
|
||||
base_url=db_knowledge.llm.api_keys[0].api_base,
|
||||
)
|
||||
logger.info(f"[QA] LLM model: {db_knowledge.llm.api_keys[0].model_name}, base_url: {db_knowledge.llm.api_keys[0].api_base}")
|
||||
if qa_prompt:
|
||||
logger.info(f"[QA] Using custom prompt ({len(qa_prompt)} chars)")
|
||||
|
||||
# 预先构建所有 batch 的 chunks,保证 sort_id 全局有序
|
||||
all_batch_chunks: list[list[DocumentChunk]] = []
|
||||
@@ -335,15 +339,27 @@ def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""):
|
||||
cached = get_llm_cache(chat_model.model_name, content, "qa",
|
||||
{"topn": auto_questions_topn})
|
||||
if not cached:
|
||||
pairs = qa_proposal(chat_model, content, auto_questions_topn)
|
||||
cached = pairs
|
||||
set_llm_cache(chat_model.model_name, content, cached, "qa",
|
||||
try:
|
||||
pairs = qa_proposal(chat_model, content, auto_questions_topn, custom_prompt=qa_prompt)
|
||||
except Exception as e:
|
||||
logger.error(f"[QA] LLM call failed: model={chat_model.model_name}, base_url={getattr(chat_model, 'base_url', 'N/A')}, error={e}")
|
||||
return global_idx, []
|
||||
# 缓存存 JSON 字符串
|
||||
set_llm_cache(chat_model.model_name, content, json.dumps(pairs, ensure_ascii=False), "qa",
|
||||
{"topn": auto_questions_topn})
|
||||
elif isinstance(cached, str):
|
||||
# 兼容旧缓存格式(纯文本问题)
|
||||
return global_idx, pairs
|
||||
# 从缓存读取:可能是 JSON 字符串或旧格式纯文本
|
||||
if isinstance(cached, str):
|
||||
try:
|
||||
parsed = json.loads(cached)
|
||||
if isinstance(parsed, list):
|
||||
return global_idx, parsed
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
# 旧缓存格式(纯文本问题),尝试解析
|
||||
from app.core.rag.prompts.generator import parse_qa_pairs
|
||||
cached = parse_qa_pairs(cached) if cached else []
|
||||
return global_idx, cached
|
||||
return global_idx, parse_qa_pairs(cached) if cached else []
|
||||
return global_idx, cached if isinstance(cached, list) else []
|
||||
|
||||
# 并发调用 LLM 生成 QA 对
|
||||
qa_map: dict[int, list] = {}
|
||||
|
||||
Reference in New Issue
Block a user