Compare commits

...

13 Commits

Author  SHA1        Message                           Date
Mark    f8d1ed51a7  [fix] system prompt fit error     2026-05-07 19:37:34 +08:00
Mark    9fa83ed01e  [modify] QA pair                  2026-05-07 19:04:19 +08:00
Mark    e222490bce  [add] batch add chunk for v1      2026-05-07 18:45:36 +08:00
Mark    ad2e885f72  [fix] index_not_found_exception   2026-05-06 18:34:07 +08:00
Mark    70c6d161c8  [fix] delete chunk refresh index  2026-05-06 15:19:46 +08:00
Mark    f85c0594c9  [fix] es vector                   2026-04-29 15:24:25 +08:00
Mark    5fceba54b4  [fix] file upload                 2026-04-29 13:41:14 +08:00
Mark    6e89302cb2  no message                        2026-04-29 11:44:03 +08:00
Mark    90aa4cef21  [add] import qa chunks            2026-04-28 16:38:14 +08:00
Mark    6c47bb77ab  [add] task log                    2026-04-28 16:13:26 +08:00
Mark    f667936664  [fix] qa cache                    2026-04-28 15:53:07 +08:00
Mark    64e640d882  [add] batch chunk. qa_prompt set  2026-04-28 15:33:44 +08:00
Mark    140311048a  [modify] rag qa chunk             2026-04-28 14:04:36 +08:00
11 changed files with 702 additions and 141 deletions

View File

@@ -1,8 +1,10 @@
import os
import csv
import io
from typing import Any, Optional
import uuid
from fastapi import APIRouter, Depends, HTTPException, status, Query
from fastapi import APIRouter, Depends, HTTPException, status, Query, UploadFile, File
from fastapi.encoders import jsonable_encoder
from sqlalchemy.orm import Session
@@ -23,6 +25,7 @@ from app.models.user_model import User
from app.schemas import chunk_schema
from app.schemas.response_schema import ApiResponse
from app.services import knowledge_service, document_service, file_service, knowledgeshare_service
from app.services.file_storage_service import FileStorageService, get_file_storage_service, generate_kb_file_key
from app.services.model_service import ModelApiKeyService
# Obtain a dedicated API logger
@@ -271,6 +274,9 @@ async def create_chunk(
"sort_id": sort_id,
"status": 1,
}
# QA chunk: inject chunk_type/question/answer into metadata
if create_data.is_qa:
metadata.update(create_data.qa_metadata)
chunk = DocumentChunk(page_content=content, metadata=metadata)
# 3. Segmented vector storage
vector_service.add_chunks([chunk])
@@ -282,6 +288,187 @@ async def create_chunk(
return success(data=jsonable_encoder(chunk), msg="Document chunk creation successful")
@router.post("/{kb_id}/{document_id}/chunk/batch", response_model=ApiResponse)
async def create_chunks_batch(
kb_id: uuid.UUID,
document_id: uuid.UUID,
batch_data: chunk_schema.ChunkBatchCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Batch create chunks (max 8)
"""
api_logger.info(f"Batch create chunks: kb_id={kb_id}, document_id={document_id}, count={len(batch_data.items)}, username: {current_user.username}")
if len(batch_data.items) > settings.MAX_CHUNK_BATCH_SIZE:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Batch size exceeds limit: max {settings.MAX_CHUNK_BATCH_SIZE}, got {len(batch_data.items)}"
)
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
if not db_knowledge:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="The knowledge base does not exist or access is denied")
db_document = db.query(Document).filter(Document.id == document_id).first()
if not db_document:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="The document does not exist or you do not have permission to access it")
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
# Get current max sort_id
sort_id = 0
total, items = vector_service.search_by_segment(document_id=str(document_id), pagesize=1, page=1, asc=False)
if items:
sort_id = items[0].metadata["sort_id"]
chunks = []
for create_data in batch_data.items:
sort_id += 1
doc_id = uuid.uuid4().hex
metadata = {
"doc_id": doc_id,
"file_id": str(db_document.file_id),
"file_name": db_document.file_name,
"file_created_at": int(db_document.created_at.timestamp() * 1000),
"document_id": str(document_id),
"knowledge_id": str(kb_id),
"sort_id": sort_id,
"status": 1,
}
if create_data.is_qa:
metadata.update(create_data.qa_metadata)
chunks.append(DocumentChunk(page_content=create_data.chunk_content, metadata=metadata))
vector_service.add_chunks(chunks)
db_document.chunk_num += len(chunks)
db.commit()
return success(data=jsonable_encoder(chunks), msg=f"Batch created {len(chunks)} chunks successfully")
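For illustration, a minimal client call against this batch endpoint. This is a sketch: the /api/v1/knowledge path prefix, host, and bearer-token auth are assumptions not confirmed by this diff; the payload shape follows ChunkBatchCreate/ChunkCreate below, where content may be a plain string or a QAChunk object.

    import requests

    # Mixed batch: one plain chunk and one QA chunk.
    payload = {
        "items": [
            {"content": "Plain paragraph stored as a regular chunk."},
            {"content": {"question": "What is RAG?", "answer": "Retrieval-augmented generation."}},
        ]
    }
    resp = requests.post(
        "http://localhost:8000/api/v1/knowledge/<kb_id>/<document_id>/chunk/batch",  # prefix/host assumed
        json=payload,
        headers={"Authorization": "Bearer <token>"},  # auth scheme assumed
        timeout=30,
    )
    print(resp.json())  # ApiResponse wrapping the created chunks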
@router.post("/{kb_id}/import_qa", response_model=ApiResponse)
async def import_qa_new_doc(
kb_id: uuid.UUID,
file: UploadFile = File(..., description="CSV or Excel file (first row is a header and is skipped; column 1: question, column 2: answer)"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
storage_service: FileStorageService = Depends(get_file_storage_service),
):
"""
Import QA pairs and create a new document (CSV/Excel, processed asynchronously)
"""
from app.schemas import file_schema, document_schema
api_logger.info(f"Import QA (new doc): kb_id={kb_id}, file={file.filename}, username: {current_user.username}")
# 1. Validate the file format
filename = file.filename or ""
if not (filename.endswith(".csv") or filename.endswith(".xlsx") or filename.endswith(".xls")):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="仅支持 CSV (.csv) 或 Excel (.xlsx) 格式")
# 2. Validate the knowledge base
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
if not db_knowledge:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="知识库不存在或无权访问")
# 3. Read the file
contents = await file.read()
file_size = len(contents)
if file_size == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="文件为空")
_, file_extension = os.path.splitext(filename)
file_ext = file_extension.lower()
# 4. Create the File record
file_data = file_schema.FileCreate(
kb_id=kb_id, created_by=current_user.id,
parent_id=uuid.UUID("00000000-0000-0000-0000-000000000000"),
file_name=filename, file_ext=file_ext, file_size=file_size,
)
db_file = file_service.create_file(db=db, file=file_data, current_user=current_user)
# 5. Upload the file to the storage backend
file_key = generate_kb_file_key(kb_id=kb_id, file_id=db_file.id, file_ext=file_ext)
try:
await storage_service.storage.upload(file_key=file_key, content=contents, content_type=file.content_type)
except Exception as e:
api_logger.error(f"Storage upload failed: {e}")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"文件存储失败: {str(e)}")
db_file.file_key = file_key
db.commit()
db.refresh(db_file)
# 6. Create the Document record (marked as QA type)
doc_data = document_schema.DocumentCreate(
kb_id=kb_id, created_by=current_user.id, file_id=db_file.id,
file_name=filename, file_ext=file_ext, file_size=file_size,
file_meta={}, parser_id="qa",
parser_config={"doc_type": "qa", "auto_questions": 0}
)
db_document = document_service.create_document(db=db, document=doc_data, current_user=current_user)
api_logger.info(f"Created doc for QA import: file_id={db_file.id}, document_id={db_document.id}, file_key={file_key}")
# 7. Dispatch the async task
from app.celery_app import celery_app
task = celery_app.send_task(
"app.core.rag.tasks.import_qa_chunks",
args=[str(kb_id), str(db_document.id), filename, contents],
queue="qa_import"
)
return success(data={
"task_id": task.id,
"document_id": str(db_document.id),
"file_id": str(db_file.id),
}, msg="QA 导入任务已提交,后台处理中")
@router.post("/{kb_id}/{document_id}/import_qa", response_model=ApiResponse)
async def import_qa_chunks(
kb_id: uuid.UUID,
document_id: uuid.UUID,
file: UploadFile = File(..., description="CSV or Excel file (first row is a header and is skipped; column 1: question, column 2: answer)"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Import QA pairs into an existing document (CSV/Excel, processed asynchronously)
"""
api_logger.info(f"Import QA chunks: kb_id={kb_id}, document_id={document_id}, file={file.filename}, username: {current_user.username}")
# 1. Validate the file format
filename = file.filename or ""
if not (filename.endswith(".csv") or filename.endswith(".xlsx") or filename.endswith(".xls")):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="仅支持 CSV (.csv) 或 Excel (.xlsx) 格式")
# 2. Validate the knowledge base and document
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
if not db_knowledge:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="知识库不存在或无权访问")
db_document = db.query(Document).filter(Document.id == document_id).first()
if not db_document:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="文档不存在或无权访问")
# 3. Read the file contents and dispatch the async task
contents = await file.read()
from app.celery_app import celery_app
task = celery_app.send_task(
"app.core.rag.tasks.import_qa_chunks",
args=[str(kb_id), str(document_id), filename, contents],
queue="qa_import"
)
return success(data={"task_id": task.id}, msg="QA 导入任务已提交,后台处理中")
@router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
async def get_chunk(
kb_id: uuid.UUID,
@@ -342,6 +529,9 @@ async def update_chunk(
if total:
chunk = items[0]
chunk.page_content = content
# QA chunk: update question/answer in metadata
if update_data.is_qa:
chunk.metadata.update(update_data.qa_metadata)
vector_service.update_by_segment(chunk)
return success(data=jsonable_encoder(chunk), msg="The document chunk has been successfully updated")
else:
@@ -356,6 +546,7 @@ async def delete_chunk(
kb_id: uuid.UUID,
document_id: uuid.UUID,
doc_id: str,
force_refresh: bool = Query(False, description="Force Elasticsearch refresh after deletion"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
@@ -373,7 +564,7 @@ async def delete_chunk(
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
if vector_service.text_exists(doc_id):
vector_service.delete_by_ids([doc_id])
vector_service.delete_by_ids([doc_id], refresh=force_refresh)
# Update chunk_num
db_document = db.query(Document).filter(Document.id == document_id).first()
db_document.chunk_num -= 1
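Why force_refresh exists: Elasticsearch deletes are near-real-time, so a search issued immediately after delete_by_ids may still return the chunk until the next periodic index refresh (1s by default). A sketch of the two call styles, using the vector_service constructed above:

    # Fast path: the delete is acknowledged but may stay visible to searches
    # until the next scheduled refresh.
    vector_service.delete_by_ids([doc_id])

    # Consistent path: force an index refresh so an immediate re-query no
    # longer returns the deleted chunk (slower under heavy write load).
    vector_service.delete_by_ids([doc_id], refresh=True)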

View File

@@ -113,6 +113,33 @@ async def create_chunk(
current_user=current_user)
@router.post("/{kb_id}/{document_id}/chunk/batch", response_model=ApiResponse)
@require_api_key(scopes=["rag"])
async def create_chunks_batch(
kb_id: uuid.UUID,
document_id: uuid.UUID,
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
items: list = Body(..., description="chunk items list"),
):
"""
Batch create chunks (max 8)
"""
body = await request.json()
batch_data = chunk_schema.ChunkBatchCreate(**body)
# 0. Obtain the creator of the API key
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
current_user = api_key.creator
current_user.current_workspace_id = api_key_auth.workspace_id
return await chunk_controller.create_chunks_batch(kb_id=kb_id,
document_id=document_id,
batch_data=batch_data,
db=db,
current_user=current_user)
@router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
@require_api_key(scopes=["rag"])
async def get_chunk(
@@ -176,6 +203,7 @@ async def delete_chunk(
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
force_refresh: bool = Query(False, description="Force Elasticsearch refresh after deletion"),
):
"""
delete document chunk
@@ -188,6 +216,7 @@ async def delete_chunk(
return await chunk_controller.delete_chunk(kb_id=kb_id,
document_id=document_id,
doc_id=doc_id,
force_refresh=force_refresh,
db=db,
current_user=current_user)

View File

@@ -98,6 +98,7 @@ class Settings:
# File Upload
MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
MAX_FILE_COUNT: int = int(os.getenv("MAX_FILE_COUNT", "20"))
MAX_CHUNK_BATCH_SIZE: int = int(os.getenv("MAX_CHUNK_BATCH_SIZE", "8"))
FILE_PATH: str = os.getenv("FILE_PATH", "/files")
FILE_URL_EXPIRES: int = int(os.getenv("FILE_URL_EXPIRES", "3600"))
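The new limit, like its neighbors, is read from the environment at startup; an illustrative .env fragment (values are the defaults shown above):

    MAX_FILE_SIZE=52428800       # 50 MB
    MAX_FILE_COUNT=20
    MAX_CHUNK_BATCH_SIZE=8       # cap for POST .../chunk/batch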

View File

@@ -46,7 +46,10 @@ async def run_graphrag(
start = trio.current_time()
workspace_id, kb_id, document_id = row["workspace_id"], str(row["kb_id"]), row["document_id"]
chunks = []
for d in settings.retriever.chunk_list(document_id, workspace_id, [kb_id], fields=["page_content", "document_id"], sort_by_position=True):
for d in settings.retriever.chunk_list(document_id, workspace_id, [kb_id], fields=["page_content", "document_id", "chunk_type"], sort_by_position=True):
# Skip QA chunks; build the graph from source-text chunks only
if d.get("chunk_type") == "qa":
continue
chunks.append(d["page_content"])
with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000):
@@ -150,6 +153,9 @@ async def run_graphrag_for_kb(
total, items = vector_service.search_by_segment(document_id=str(document_id), query=None, pagesize=9999, page=1, asc=True)
for doc in items:
# Skip QA chunks; build the graph from source-text chunks only
if (doc.metadata or {}).get("chunk_type") == "qa":
continue
content = doc.page_content
if num_tokens_from_string(current_chunk + content) < 1024:
current_chunk += content

View File

@@ -131,18 +131,52 @@ def keyword_extraction(chat_mdl, content, topn=3):
def question_proposal(chat_mdl, content, topn=3):
template = PROMPT_JINJA_ENV.from_string(QUESTION_PROMPT_TEMPLATE)
rendered_prompt = template.render(content=content, topn=topn)
msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
_, msg = message_fit_in(msg, getattr(chat_mdl, 'max_length', 8096))
kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2})
if isinstance(kwd, tuple):
kwd = kwd[0]
kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
if kwd.find("**ERROR**") >= 0:
"""生成问题(向后兼容,返回纯文本问题列表)"""
pairs = qa_proposal(chat_mdl, content, topn)
if not pairs:
return ""
return kwd
return "\n".join([p["question"] for p in pairs])
def qa_proposal(chat_mdl, content, topn=3, custom_prompt=None):
"""生成 QA 对,返回 [{"question": ..., "answer": ...}, ...]
Args:
chat_mdl: LLM 模型
content: 文本内容
topn: 生成 QA 对数量
custom_prompt: 自定义 prompt 模板(支持 Jinja2可用变量: content, topn
"""
if custom_prompt:
template = PROMPT_JINJA_ENV.from_string(custom_prompt)
sys_prompt = template.render(content=content, topn=topn)
else:
template = PROMPT_JINJA_ENV.from_string(QUESTION_PROMPT_TEMPLATE)
sys_prompt = template.render(topn=topn)
msg = [{"role": "system", "content": sys_prompt}, {"role": "user", "content": content}]
_, msg = message_fit_in(msg, getattr(chat_mdl, 'max_length', 8096))
raw = chat_mdl.chat(sys_prompt, msg[1:], {"temperature": 0.2})
if isinstance(raw, tuple):
raw = raw[0]
raw = re.sub(r"^.*</think>", "", raw, flags=re.DOTALL)
if raw.find("**ERROR**") >= 0:
return []
return parse_qa_pairs(raw)
def parse_qa_pairs(text: str) -> list:
"""解析 LLM 返回的 QA 对文本,格式: Q: xxx A: xxx"""
pairs = []
for line in text.strip().split("\n"):
line = line.strip()
if not line:
continue
# Match the "Q: ... A: ..." format
match = re.match(r'^Q:\s*(.+?)\s+A:\s*(.+)$', line, re.IGNORECASE)
if match:
q, a = match.group(1).strip(), match.group(2).strip()
if q and a:
pairs.append({"question": q, "answer": a})
return pairs
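A quick usage sketch of the parser above; lines that do not match the Q:/A: pattern are silently skipped:

    sample = (
        "Q: What is the capital of France? A: The capital of France is Paris.\n"
        "Q: When was the Eiffel Tower built? A: The Eiffel Tower was built in 1889.\n"
        "Some stray commentary the LLM added.\n"
    )
    print(parse_qa_pairs(sample))
    # [{'question': 'What is the capital of France?', 'answer': 'The capital of France is Paris.'},
    #  {'question': 'When was the Eiffel Tower built?', 'answer': 'The Eiffel Tower was built in 1889.'}]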
def graph_entity_types(chat_mdl, scenario):

View File

@@ -1,19 +1,20 @@
## Role
You are a text analyzer.
You are a text analyzer and knowledge extraction expert.
## Task
Propose {{ topn }} questions about a given piece of text content.
Generate question-answer pairs from the given text content.
## Requirements
- Understand and summarize the text content, and propose the top {{ topn }} important questions.
- Understand and summarize the text content, then generate up to {{ topn }} important question-answer pairs.
- Each question-answer pair MUST be on a single line, formatted as: Q: <question> A: <answer>
- The questions SHOULD NOT have overlapping meanings.
- The questions SHOULD cover the main content of the text as much as possible.
- The questions MUST be in the same language as the given piece of text content.
- One question per line.
- Output questions ONLY.
---
## Text Content
{{ content }}
- The answers MUST be concise, accurate, and directly derived from the text content.
- The answers SHOULD be self-contained and understandable without additional context.
- Both questions and answers MUST be in the same language as the given text content.
- If the text is too short or lacks substantive content, generate fewer pairs rather than padding.
- Output question-answer pairs ONLY, no extra explanation or commentary.
## Example Output
Q: What is the capital of France? A: The capital of France is Paris.
Q: When was the Eiffel Tower built? A: The Eiffel Tower was built in 1889.

View File

@@ -5,7 +5,7 @@ from typing import Any
from urllib.parse import urlparse
import requests
from elasticsearch import Elasticsearch, helpers
from elasticsearch import Elasticsearch, helpers, NotFoundError
from elasticsearch.helpers import BulkIndexError
from packaging.version import parse as parse_version
# langchain-community
@@ -53,13 +53,30 @@ class ElasticSearchVector(BaseVector):
return "elasticsearch"
def add_chunks(self, chunks: list[DocumentChunk], **kwargs):
# Persist vectors into Elasticsearch
texts = [chunk.page_content for chunk in chunks]
# QA chunks: embed only the question field; source chunks: no embedding
texts_for_embedding = []
for chunk in chunks:
chunk_type = (chunk.metadata or {}).get("chunk_type", "chunk")
if chunk_type == "source":
# source chunks do not need vector indexing
texts_for_embedding.append("")
elif chunk_type == "qa":
# QA chunk: embed the question field
texts_for_embedding.append((chunk.metadata or {}).get("question", chunk.page_content))
else:
# regular chunk: embed page_content
texts_for_embedding.append(chunk.page_content)
if self.is_multimodal_embedding:
# Volcengine multimodal embedding
embeddings = self.embeddings.embed_batch(texts)
embeddings = self.embeddings.embed_batch(texts_for_embedding)
else:
embeddings = self.embeddings.embed_documents(list(texts))
embeddings = self.embeddings.embed_documents(texts_for_embedding)
# Null out vectors for source chunks
for i, chunk in enumerate(chunks):
if (chunk.metadata or {}).get("chunk_type") == "source":
embeddings[i] = None
self.create(chunks, embeddings, **kwargs)
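To summarize the branching above, a sketch of how the three chunk_type values drive the embedding input (metadata trimmed to the keys this branch reads; a real call also needs doc_id, document_id, and the other metadata shown elsewhere in this diff):

    chunks = [
        # "source": embedded as "" and its vector is nulled below (GraphRAG-only).
        DocumentChunk(page_content="original paragraph", metadata={"chunk_type": "source"}),
        # "qa": the question is embedded, not the page_content.
        DocumentChunk(page_content="What is RAG?",
                      metadata={"chunk_type": "qa", "question": "What is RAG?", "answer": "..."}),
        # default: page_content is embedded.
        DocumentChunk(page_content="a regular chunk", metadata={}),
    ]
    vector_service.add_chunks(chunks)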
def create(self, chunks: list[DocumentChunk], embeddings: list[list[float]], **kwargs):
@@ -72,13 +89,25 @@ class ElasticSearchVector(BaseVector):
uuids = self._get_uuids(chunks)
actions = []
for i, chunk in enumerate(chunks):
source = {
Field.CONTENT_KEY.value: chunk.page_content,
Field.METADATA_KEY.value: chunk.metadata or {},
Field.VECTOR.value: embeddings[i] or None
}
# Write the QA-related fields
meta = chunk.metadata or {}
if meta.get("chunk_type"):
source[Field.CHUNK_TYPE.value] = meta["chunk_type"]
if meta.get("question"):
source[Field.QUESTION.value] = meta["question"]
if meta.get("answer"):
source[Field.ANSWER.value] = meta["answer"]
if meta.get("source_chunk_id"):
source[Field.SOURCE_CHUNK_ID.value] = meta["source_chunk_id"]
action = {
"_index": self._collection_name,
"_source": {
Field.CONTENT_KEY.value: chunk.page_content,
Field.METADATA_KEY.value: chunk.metadata or {},
Field.VECTOR.value: embeddings[i] or None
}
"_source": source
}
actions.append(action)
# using bulk mode
@@ -113,7 +142,7 @@ class ElasticSearchVector(BaseVector):
return True
def delete_by_ids(self, ids: list[str]):
def delete_by_ids(self, ids: list[str], *, refresh: bool = False):
if not ids:
return
if not self._client.indices.exists(index=self._collection_name):
@@ -134,6 +163,8 @@ class ElasticSearchVector(BaseVector):
actions = [{"_op_type": "delete", "_index": self._collection_name, "_id": es_id} for es_id in actual_ids]
try:
helpers.bulk(self._client, actions)
if refresh:
self._client.indices.refresh(index=self._collection_name)
except BulkIndexError as e:
for error in e.errors:
delete_error = error.get('delete', {})
@@ -153,7 +184,7 @@ class ElasticSearchVector(BaseVector):
else:
return None
def delete_by_metadata_field(self, key: str, value: str):
def delete_by_metadata_field(self, key: str, value: str, *, refresh: bool = False):
if not self._client.indices.exists(index=self._collection_name):
return False
actual_ids = self.get_ids_by_metadata_field(key, value)
@@ -162,6 +193,8 @@ class ElasticSearchVector(BaseVector):
actions = [{"_op_type": "delete", "_index": self._collection_name, "_id": es_id} for es_id in actual_ids]
try:
helpers.bulk(self._client, actions)
if refresh:
self._client.indices.refresh(index=self._collection_name)
except BulkIndexError as e:
for error in e.errors:
delete_error = error.get('delete', {})
@@ -192,6 +225,8 @@ class ElasticSearchVector(BaseVector):
List of DocumentChunk objects that match the query.
"""
indices = kwargs.get("indices", self._collection_name) # Default single index, multiple indexes are also supported, such as "index1, index2, index3"
if not self._client.indices.exists(index=indices):
return 0, []
# Calculate the start position for the current page
from_ = pagesize * (page-1)
@@ -226,12 +261,15 @@ class ElasticSearchVector(BaseVector):
})
# For simplicity, we use from/size here which has a limit (usually up to 10,000).
result = self._client.search(
index=indices,
from_=from_, # Only use from_ for the first page (simplified)
size=pagesize,
body=query_str,
)
try:
result = self._client.search(
index=indices,
from_=from_, # Only use from_ for the first page (simplified)
size=pagesize,
body=query_str,
)
except NotFoundError:
return 0, []
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
@@ -241,10 +279,19 @@ class ElasticSearchVector(BaseVector):
for res in result["hits"]["hits"]:
source = res["_source"]
page_content = source.get(Field.CONTENT_KEY.value)
# vector = source.get(Field.VECTOR.value)
vector = None
metadata = source.get(Field.METADATA_KEY.value, {})
chunk_type = source.get(Field.CHUNK_TYPE.value)
score = res["_score"]
# Inject QA fields into metadata for frontend display
if chunk_type:
metadata["chunk_type"] = chunk_type
if chunk_type == "qa":
metadata["question"] = source.get(Field.QUESTION.value, "")
metadata["answer"] = source.get(Field.ANSWER.value, "")
page_content = f"Q: {metadata['question']}\nA: {metadata['answer']}"
docs_and_scores.append((DocumentChunk(page_content=page_content, vector=vector, metadata=metadata), score))
docs = []
@@ -267,13 +314,18 @@ class ElasticSearchVector(BaseVector):
List of DocumentChunk objects that match the query.
"""
indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index availableetc "index1,index2,index3"
if not self._client.indices.exists(index=indices):
return 0, []
query_str = {"query": {"term": {f"{Field.DOC_ID.value}": doc_id}}}
result = self._client.search(
index=indices,
from_=0, # Only use from_ for the first page (simplified)
size=1,
body=query_str,
)
try:
result = self._client.search(
index=indices,
from_=0, # Only use from_ for the first page (simplified)
size=1,
body=query_str,
)
except NotFoundError:
return 0, []
# print(result)
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
@@ -308,27 +360,43 @@ class ElasticSearchVector(BaseVector):
Returns:
updated count.
"""
indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index availableetc "index1,index2,index3"
if self.is_multimodal_embedding:
# Volcengine multimodal embedding
chunk.vector = self.embeddings.embed_text(chunk.page_content)
indices = kwargs.get("indices", self._collection_name)
chunk_type = (chunk.metadata or {}).get("chunk_type")
# QA chunk: embed based on question; source chunk: vector not updated
if chunk_type == "source":
embed_text = ""
elif chunk_type == "qa":
embed_text = (chunk.metadata or {}).get("question", chunk.page_content)
else:
chunk.vector = self.embeddings.embed_query(chunk.page_content)
embed_text = chunk.page_content
if chunk_type != "source":
if self.is_multimodal_embedding:
chunk.vector = self.embeddings.embed_text(embed_text)
else:
chunk.vector = self.embeddings.embed_query(embed_text)
script_source = "ctx._source.page_content = params.new_content; ctx._source.vector = params.new_vector;"
params = {
"new_content": chunk.page_content,
"new_vector": chunk.vector if chunk_type != "source" else None
}
# QA chunk: also update the question/answer fields
if chunk_type == "qa":
script_source += " ctx._source.question = params.new_question; ctx._source.answer = params.new_answer;"
params["new_question"] = (chunk.metadata or {}).get("question", "")
params["new_answer"] = (chunk.metadata or {}).get("answer", "")
body = {
"script": {
"source": """
ctx._source.page_content = params.new_content;
ctx._source.vector = params.new_vector;
""",
"params": {
"new_content": chunk.page_content,
"new_vector": chunk.vector
}
"source": script_source,
"params": params
},
"query": {
"term": {
Field.DOC_ID.value: chunk.metadata["doc_id"] # exact match doc_id
Field.DOC_ID.value: chunk.metadata["doc_id"]
}
}
}
@@ -336,9 +404,6 @@ class ElasticSearchVector(BaseVector):
index=indices,
body=body,
)
# Remove debug printing and use logging instead
# print(result)
# print(f"Update successful, number of affected documents: {result['updated']}")
return result['updated']
def change_status_by_document_id(self, document_id: str, status: int, **kwargs) -> str:
@@ -397,11 +462,11 @@ class ElasticSearchVector(BaseVector):
}
}
},
"filter": { # Add the filter condition of status=1
"term": {
"metadata.status": 1
}
}
"filter": [
{"term": {"metadata.status": 1}},
# Exclude source chunks (GraphRAG-only; not used for retrieval)
{"bool": {"must_not": {"term": {Field.CHUNK_TYPE.value: "source"}}}}
]
}
}
# If file_names_filter is passed in, merge the filtering conditions
@@ -415,22 +480,14 @@ class ElasticSearchVector(BaseVector):
},
"script": {
"source": f"cosineSimilarity(params.query_vector, '{Field.VECTOR.value}') + 1.0",
# The script_score query calculates the cosine similarity between the embedding field of each document and the query vector. The addition of +1.0 is to ensure that the scores returned by the script are non-negative, as the range of cosine similarity is [-1, 1]
"params": {"query_vector": query_vector}
}
}
},
"filter": [
{
"term": {
"metadata.status": 1
}
},
{
"terms": {
"metadata.file_name": file_names_filter # Additional file_name filtering
}
}
{"term": {"metadata.status": 1}},
{"terms": {"metadata.file_name": file_names_filter}},
{"bool": {"must_not": {"term": {Field.CHUNK_TYPE.value: "source"}}}}
],
}
}
@@ -451,8 +508,19 @@ class ElasticSearchVector(BaseVector):
source = res["_source"]
page_content = source.get(Field.CONTENT_KEY.value)
metadata = source.get(Field.METADATA_KEY.value, {})
chunk_type = source.get(Field.CHUNK_TYPE.value)
score = res["_score"]
score = score / 2 # Normalized [0-1]
# QA chunk: return the concatenated Q+A as retrieval context
if chunk_type == "qa":
question = source.get(Field.QUESTION.value, "")
answer = source.get(Field.ANSWER.value, "")
page_content = f"Q: {question}\nA: {answer}"
metadata["chunk_type"] = "qa"
metadata["question"] = question
metadata["answer"] = answer
docs_and_scores.append((DocumentChunk(page_content=page_content, metadata=metadata), score))
docs = []
@@ -491,11 +559,10 @@ class ElasticSearchVector(BaseVector):
}
}
},
"filter": { # Add the filter condition of status=1
"term": {
"metadata.status": 1
}
}
"filter": [
{"term": {"metadata.status": 1}},
{"bool": {"must_not": {"term": {Field.CHUNK_TYPE.value: "source"}}}}
]
}
}
@@ -512,16 +579,9 @@ class ElasticSearchVector(BaseVector):
}
},
"filter": [
{
"term": {
"metadata.status": 1
}
},
{
"terms": {
"metadata.file_name": file_names_filter # Additional file_name filtering
}
}
{"term": {"metadata.status": 1}},
{"terms": {"metadata.file_name": file_names_filter}},
{"bool": {"must_not": {"term": {Field.CHUNK_TYPE.value: "source"}}}}
],
}
}
@@ -543,6 +603,17 @@ class ElasticSearchVector(BaseVector):
source = res["_source"]
page_content = source.get(Field.CONTENT_KEY.value)
metadata = source.get(Field.METADATA_KEY.value, {})
chunk_type = source.get(Field.CHUNK_TYPE.value)
# QA chunk: return the concatenated Q+A as retrieval context
if chunk_type == "qa":
question = source.get(Field.QUESTION.value, "")
answer = source.get(Field.ANSWER.value, "")
page_content = f"Q: {question}\nA: {answer}"
metadata["chunk_type"] = "qa"
metadata["question"] = question
metadata["answer"] = answer
# Normalize the score to the [0,1] interval
normalized_score = res["_score"] / max_score
docs_and_scores.append((DocumentChunk(page_content=page_content, metadata=metadata), normalized_score))
@@ -652,7 +723,7 @@ class ElasticSearchVector(BaseVector):
},
Field.VECTOR.value: {
"type": "dense_vector",
"dims": len(embeddings[0]), # Make sure the dimension is correct here,The dimension size of the vector. When index is true, it cannot exceed 1024; when index is false or not specified, it cannot exceed 2048, which can improve retrieval efficiency
"dims": len(next((e for e in embeddings if e is not None), [0]*768)), # 跳过 None 获取向量维度fallback 768
"index": True,
"similarity": "cosine"
}

View File

@@ -14,3 +14,8 @@ class Field(StrEnum):
DOCUMENT_ID = "metadata.document_id"
KNOWLEDGE_ID = "metadata.knowledge_id"
SORT_ID = "metadata.sort_id"
# QA fields
CHUNK_TYPE = "chunk_type" # "chunk" | "source" | "qa"
QUESTION = "question"
ANSWER = "answer"
SOURCE_CHUNK_ID = "source_chunk_id"

View File

@@ -27,14 +27,14 @@ class BaseVector(ABC):
raise NotImplementedError
@abstractmethod
def delete_by_ids(self, ids: list[str]):
def delete_by_ids(self, ids: list[str], *, refresh: bool = False):
raise NotImplementedError
def get_ids_by_metadata_field(self, key: str, value: str):
raise NotImplementedError
@abstractmethod
def delete_by_metadata_field(self, key: str, value: str):
def delete_by_metadata_field(self, key: str, value: str, *, refresh: bool = False):
raise NotImplementedError
@abstractmethod

View File

@@ -20,13 +20,26 @@ class ChunkCreate(BaseModel):
@property
def chunk_content(self) -> str:
"""
Get the actual content string regardless of input type
"""
"""Get the actual content string regardless of input type"""
if isinstance(self.content, QAChunk):
return f"question: {self.content.question} answer: {self.content.answer}"
return self.content.question # In QA mode, page_content stores the question
return self.content
@property
def is_qa(self) -> bool:
return isinstance(self.content, QAChunk)
@property
def qa_metadata(self) -> dict:
"""返回 QA 相关的 metadata 字段"""
if isinstance(self.content, QAChunk):
return {
"chunk_type": "qa",
"question": self.content.question,
"answer": self.content.answer,
}
return {}
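A quick behavioral sketch of these properties (assuming QAChunk is a Pydantic model with question and answer fields, as the accessors above imply):

    qa = ChunkCreate(content=QAChunk(question="What is RAG?", answer="Retrieval-augmented generation."))
    qa.is_qa          # True
    qa.chunk_content  # "What is RAG?"  (the question becomes page_content in QA mode)
    qa.qa_metadata    # {"chunk_type": "qa", "question": "What is RAG?", "answer": "Retrieval-augmented generation."}

    plain = ChunkCreate(content="just text")
    plain.is_qa          # False
    plain.chunk_content  # "just text"
    plain.qa_metadata    # {}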
class ChunkUpdate(BaseModel):
content: Union[str, QAChunk] = Field(
@@ -35,13 +48,26 @@ class ChunkUpdate(BaseModel):
@property
def chunk_content(self) -> str:
"""
Get the actual content string regardless of input type
"""
"""Get the actual content string regardless of input type"""
if isinstance(self.content, QAChunk):
return f"question: {self.content.question} answer: {self.content.answer}"
return self.content.question # In QA mode, page_content stores the question
return self.content
@property
def is_qa(self) -> bool:
return isinstance(self.content, QAChunk)
@property
def qa_metadata(self) -> dict:
"""返回 QA 相关的 metadata 字段"""
if isinstance(self.content, QAChunk):
return {
"chunk_type": "qa",
"question": self.content.question,
"answer": self.content.answer,
}
return {}
class ChunkRetrieve(BaseModel):
query: str
@@ -51,3 +77,8 @@ class ChunkRetrieve(BaseModel):
vector_similarity_weight: float | None = Field(None)
top_k: int | None = Field(None)
retrieve_type: RetrieveType | None = Field(None)
class ChunkBatchCreate(BaseModel):
"""批量创建 chunk"""
items: list[ChunkCreate] = Field(..., min_length=1, description="list of chunks")

View File

@@ -30,7 +30,7 @@ from app.core.rag.llm.cv_model import QWenCV
from app.core.rag.llm.embedding_model import OpenAIEmbed
from app.core.rag.llm.sequence2txt_model import QWenSeq2txt
from app.core.rag.models.chunk import DocumentChunk
from app.core.rag.prompts.generator import question_proposal
from app.core.rag.prompts.generator import question_proposal, qa_proposal
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import (
ElasticSearchVectorFactory,
)
@@ -311,6 +311,7 @@ def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""):
vector_service.delete_by_metadata_field(key="document_id", value=str(document_id))
# 2.2 Vectorize and import batch documents
auto_questions_topn = db_document.parser_config.get("auto_questions", 0)
qa_prompt = db_document.parser_config.get("qa_prompt", None)
chat_model = None
if auto_questions_topn:
chat_model = Base(
@@ -318,62 +319,123 @@ def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""):
model_name=db_knowledge.llm.api_keys[0].model_name,
base_url=db_knowledge.llm.api_keys[0].api_base,
)
logger.info(f"[QA] LLM model: {db_knowledge.llm.api_keys[0].model_name}, base_url: {db_knowledge.llm.api_keys[0].api_base}")
if qa_prompt:
logger.info(f"[QA] Using custom prompt ({len(qa_prompt)} chars)")
# Pre-build chunks for all batches so sort_id stays globally ordered
all_batch_chunks: list[list[DocumentChunk]] = []
if auto_questions_topn:
# auto_questions enabled: first generate questions for all chunks concurrently, then group into batches
# Build the (global_idx, item) list
# QA mode (FastGPT approach):
# 1. Mark the original chunk as source (kept for GraphRAG; excluded from retrieval)
# 2. The LLM generates QA pairs; each pair is stored as an independent qa chunk
indexed_items = list(enumerate(res))
def _generate_question(idx_item: tuple[int, dict]) -> tuple[int, str]:
"""为单个 chunk 生成问题(带缓存),返回 (global_idx, question_text)"""
def _generate_qa(idx_item: tuple[int, dict]) -> tuple[int, list]:
"""为单个 chunk 生成 QA 对(带缓存),返回 (global_idx, qa_pairs)"""
global_idx, item = idx_item
content = item["content_with_weight"]
cached = get_llm_cache(chat_model.model_name, content, "question",
{"topn": auto_questions_topn})
cache_params = {"topn": auto_questions_topn}
if qa_prompt:
import hashlib
cache_params["prompt_hash"] = hashlib.md5(qa_prompt.encode()).hexdigest()[:8]
cached = get_llm_cache(chat_model.model_name, content, "qa", cache_params)
if not cached:
cached = question_proposal(chat_model, content, auto_questions_topn)
set_llm_cache(chat_model.model_name, content, cached, "question",
{"topn": auto_questions_topn})
return global_idx, cached
logger.info(f"[QA] Cache miss for chunk {global_idx}, calling LLM. cache_params={cache_params}")
try:
pairs = qa_proposal(chat_model, content, auto_questions_topn, custom_prompt=qa_prompt)
except Exception as e:
logger.error(f"[QA] LLM call failed: model={chat_model.model_name}, base_url={getattr(chat_model, 'base_url', 'N/A')}, error={e}")
return global_idx, []
logger.info(f"[QA] Chunk {global_idx} generated {len(pairs)} QA pairs")
# Cache as a JSON string
set_llm_cache(chat_model.model_name, content, json.dumps(pairs, ensure_ascii=False), "qa",
cache_params)
return global_idx, pairs
logger.info(f"[QA] Cache hit for chunk {global_idx}, cache_params={cache_params}, cached_type={type(cached).__name__}")
# Read from cache: may be a JSON string or legacy plain text
if isinstance(cached, str):
try:
parsed = json.loads(cached)
if isinstance(parsed, list):
logger.info(f"[QA] Chunk {global_idx} loaded {len(parsed)} QA pairs from cache")
return global_idx, parsed
except (json.JSONDecodeError, TypeError):
pass
# Legacy cache format (plain-text questions); try to parse
from app.core.rag.prompts.generator import parse_qa_pairs
return global_idx, parse_qa_pairs(cached) if cached else []
return global_idx, cached if isinstance(cached, list) else []
# Call the LLM concurrently to generate questions
question_map: dict[int, str] = {}
# Call the LLM concurrently to generate QA pairs
qa_map: dict[int, list] = {}
with ThreadPoolExecutor(max_workers=AUTO_QUESTIONS_MAX_WORKERS) as q_executor:
futures = {q_executor.submit(_generate_question, item): item[0]
futures = {q_executor.submit(_generate_qa, item): item[0]
for item in indexed_items}
for future in futures:
global_idx, cached = future.result()
question_map[global_idx] = cached
global_idx, pairs = future.result()
qa_map[global_idx] = pairs
progress_lines.append(
f"{datetime.now().strftime('%H:%M:%S')} Auto questions generated for {total_chunks} chunks "
f"{datetime.now().strftime('%H:%M:%S')} QA pairs generated for {total_chunks} chunks "
f"(workers={AUTO_QUESTIONS_MAX_WORKERS}).")
# Assemble DocumentChunks grouped by batch
for batch_start in range(0, total_chunks, EMBEDDING_BATCH_SIZE):
batch_end = min(batch_start + EMBEDDING_BATCH_SIZE, total_chunks)
chunks = []
for global_idx in range(batch_start, batch_end):
item = res[global_idx]
metadata = {
# Assemble chunks: source chunks + qa chunks
source_chunks = []
qa_chunks = []
qa_sort_id = 0
for global_idx in range(total_chunks):
item = res[global_idx]
source_chunk_id = uuid.uuid4().hex
# source chunk: keep the original text for GraphRAG; excluded from vector retrieval
source_meta = {
"doc_id": source_chunk_id,
"file_id": str(db_document.file_id),
"file_name": db_document.file_name,
"file_created_at": int(db_document.created_at.timestamp() * 1000),
"document_id": str(db_document.id),
"knowledge_id": str(db_document.kb_id),
"sort_id": global_idx,
"status": 1,
"chunk_type": "source",
}
source_chunks.append(
DocumentChunk(page_content=item["content_with_weight"], metadata=source_meta))
# qa chunks: each QA pair is stored independently
pairs = qa_map.get(global_idx, [])
for pair in pairs:
qa_meta = {
"doc_id": uuid.uuid4().hex,
"file_id": str(db_document.file_id),
"file_name": db_document.file_name,
"file_created_at": int(db_document.created_at.timestamp() * 1000),
"document_id": str(db_document.id),
"knowledge_id": str(db_document.kb_id),
"sort_id": global_idx,
"sort_id": qa_sort_id,
"status": 1,
"chunk_type": "qa",
"question": pair["question"],
"answer": pair["answer"],
"source_chunk_id": source_chunk_id,
}
cached = question_map[global_idx]
chunks.append(
DocumentChunk(
page_content=f"question: {cached} answer: {item['content_with_weight']}",
metadata=metadata))
all_batch_chunks.append(chunks)
# page_content stores the question, used for the vector index
qa_chunks.append(
DocumentChunk(page_content=pair["question"], metadata=qa_meta))
qa_sort_id += 1
# Group into batches (source + qa together)
all_chunks = source_chunks + qa_chunks
for batch_start in range(0, len(all_chunks), EMBEDDING_BATCH_SIZE):
batch_end = min(batch_start + EMBEDDING_BATCH_SIZE, len(all_chunks))
all_batch_chunks.append(all_chunks[batch_start:batch_end])
progress_lines.append(
f"{datetime.now().strftime('%H:%M:%S')} QA mode: {len(source_chunks)} source chunks + "
f"{len(qa_chunks)} QA chunks prepared.")
else:
# auto_questions disabled: build chunks directly
for batch_start in range(0, total_chunks, EMBEDDING_BATCH_SIZE):
@@ -635,6 +697,136 @@ def build_graphrag_for_document(document_id: str, knowledge_id: str):
return f"build_graphrag_for_document '{document_id}' failed: {e}"
@celery_app.task(name="app.core.rag.tasks.import_qa_chunks", queue="qa_import")
def import_qa_chunks(kb_id: str, document_id: str, filename: str, contents: bytes):
"""
Asynchronously import QA pairs (CSV/Excel).
File format: the first row is a header (skipped); column 1 is the question, column 2 the answer.
"""
import csv as csv_module
import io
db = None
try:
from app.db import get_db_context
with get_db_context() as db:
db_document = db.query(Document).filter(Document.id == uuid.UUID(document_id)).first()
db_knowledge = db.query(Knowledge).filter(Knowledge.id == uuid.UUID(kb_id)).first()
if not db_document or not db_knowledge:
logger.error(f"[ImportQA] document={document_id} or knowledge={kb_id} not found")
return {"error": "document or knowledge not found", "imported": 0}
# 1. Parse the file
qa_pairs = []
failed_rows = []
if filename.endswith(".csv"):
try:
text = contents.decode("utf-8-sig")
except UnicodeDecodeError:
text = contents.decode("gbk", errors="ignore")
sniffer = csv_module.Sniffer()
try:
dialect = sniffer.sniff(text[:2048])
delimiter = dialect.delimiter
except csv_module.Error:
delimiter = "," if "," in text[:500] else "\t"
reader = csv_module.reader(io.StringIO(text), delimiter=delimiter)
for i, row in enumerate(reader):
if i == 0:
continue
if len(row) >= 2 and row[0].strip() and row[1].strip():
qa_pairs.append({"question": row[0].strip(), "answer": row[1].strip()})
elif len(row) >= 1 and row[0].strip():
failed_rows.append(i + 1)
elif filename.endswith(".xlsx") or filename.endswith(".xls"):
try:
import openpyxl
wb = openpyxl.load_workbook(io.BytesIO(contents), read_only=True)
for sheet in wb.worksheets:
for i, row in enumerate(sheet.iter_rows(values_only=True)):
if i == 0:
continue
if len(row) >= 2 and row[0] and row[1]:
q = str(row[0]).strip()
a = str(row[1]).strip()
if q and a:
qa_pairs.append({"question": q, "answer": a})
elif len(row) >= 1 and row[0]:
failed_rows.append(i + 1)
wb.close()
except Exception as e:
logger.error(f"[ImportQA] Excel parse failed: {e}")
return {"error": f"Excel parse failed: {e}", "imported": 0}
if not qa_pairs:
logger.warning(f"[ImportQA] No valid QA pairs found in {filename}")
return {"error": "No valid QA pairs found", "imported": 0}
logger.info(f"[ImportQA] Parsed {len(qa_pairs)} QA pairs from {filename}, failed_rows={failed_rows}")
# 2. Write to ES
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
sort_id = 0
total, items = vector_service.search_by_segment(document_id=document_id, pagesize=1, page=1, asc=False)
if items:
sort_id = items[0].metadata["sort_id"]
chunks = []
for pair in qa_pairs:
sort_id += 1
doc_id = uuid.uuid4().hex
metadata = {
"doc_id": doc_id,
"file_id": str(db_document.file_id),
"file_name": db_document.file_name,
"file_created_at": int(db_document.created_at.timestamp() * 1000),
"document_id": document_id,
"knowledge_id": kb_id,
"sort_id": sort_id,
"status": 1,
"chunk_type": "qa",
"question": pair["question"],
"answer": pair["answer"],
}
chunks.append(DocumentChunk(page_content=pair["question"], metadata=metadata))
batch_size = 50
for i in range(0, len(chunks), batch_size):
batch = chunks[i:i + batch_size]
vector_service.add_chunks(batch)
# 3. Update chunk_num and progress
db_document.chunk_num += len(chunks)
db_document.progress = 1.0
db_document.progress_msg = f"QA 导入完成: {len(chunks)}"
db.commit()
result = {"imported": len(chunks), "failed_rows": failed_rows}
logger.info(f"[ImportQA] Done: imported={len(chunks)}, failed={len(failed_rows)}")
return result
except Exception as e:
logger.error(f"[ImportQA] Failed: {e}", exc_info=True)
# Try to mark the document as failed
try:
from app.db import get_db_context
with get_db_context() as err_db:
doc = err_db.query(Document).filter(Document.id == uuid.UUID(document_id)).first()
if doc:
doc.progress = -1.0
doc.progress_msg = f"QA 导入失败: {str(e)[:200]}"
err_db.commit()
except Exception:
pass
return {"error": str(e), "imported": 0}
@celery_app.task(name="app.core.rag.tasks.sync_knowledge_for_kb")
def sync_knowledge_for_kb(kb_id: uuid.UUID):
"""