Compare commits
13 Commits
feature/me
...
feature/ra
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f8d1ed51a7 | ||
|
|
9fa83ed01e | ||
|
|
e222490bce | ||
|
|
ad2e885f72 | ||
|
|
70c6d161c8 | ||
|
|
f85c0594c9 | ||
|
|
5fceba54b4 | ||
|
|
6e89302cb2 | ||
|
|
90aa4cef21 | ||
|
|
6c47bb77ab | ||
|
|
f667936664 | ||
|
|
64e640d882 | ||
|
|
140311048a |
@@ -1,8 +1,10 @@
|
||||
import os
|
||||
import csv
|
||||
import io
|
||||
from typing import Any, Optional
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query, UploadFile, File
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -23,6 +25,7 @@ from app.models.user_model import User
|
||||
from app.schemas import chunk_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import knowledge_service, document_service, file_service, knowledgeshare_service
|
||||
from app.services.file_storage_service import FileStorageService, get_file_storage_service, generate_kb_file_key
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
|
||||
# Obtain a dedicated API logger
|
||||
@@ -271,6 +274,9 @@ async def create_chunk(
|
||||
"sort_id": sort_id,
|
||||
"status": 1,
|
||||
}
|
||||
# QA chunk: 注入 chunk_type/question/answer 到 metadata
|
||||
if create_data.is_qa:
|
||||
metadata.update(create_data.qa_metadata)
|
||||
chunk = DocumentChunk(page_content=content, metadata=metadata)
|
||||
# 3. Segmented vector storage
|
||||
vector_service.add_chunks([chunk])
|
||||
@@ -282,6 +288,187 @@ async def create_chunk(
|
||||
return success(data=jsonable_encoder(chunk), msg="Document chunk creation successful")
|
||||
|
||||
|
||||
@router.post("/{kb_id}/{document_id}/chunk/batch", response_model=ApiResponse)
|
||||
async def create_chunks_batch(
|
||||
kb_id: uuid.UUID,
|
||||
document_id: uuid.UUID,
|
||||
batch_data: chunk_schema.ChunkBatchCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Batch create chunks (max 8)
|
||||
"""
|
||||
api_logger.info(f"Batch create chunks: kb_id={kb_id}, document_id={document_id}, count={len(batch_data.items)}, username: {current_user.username}")
|
||||
|
||||
if len(batch_data.items) > settings.MAX_CHUNK_BATCH_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Batch size exceeds limit: max {settings.MAX_CHUNK_BATCH_SIZE}, got {len(batch_data.items)}"
|
||||
)
|
||||
|
||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
|
||||
if not db_knowledge:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="The knowledge base does not exist or access is denied")
|
||||
|
||||
db_document = db.query(Document).filter(Document.id == document_id).first()
|
||||
if not db_document:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="The document does not exist or you do not have permission to access it")
|
||||
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
|
||||
# Get current max sort_id
|
||||
sort_id = 0
|
||||
total, items = vector_service.search_by_segment(document_id=str(document_id), pagesize=1, page=1, asc=False)
|
||||
if items:
|
||||
sort_id = items[0].metadata["sort_id"]
|
||||
|
||||
chunks = []
|
||||
for create_data in batch_data.items:
|
||||
sort_id += 1
|
||||
doc_id = uuid.uuid4().hex
|
||||
metadata = {
|
||||
"doc_id": doc_id,
|
||||
"file_id": str(db_document.file_id),
|
||||
"file_name": db_document.file_name,
|
||||
"file_created_at": int(db_document.created_at.timestamp() * 1000),
|
||||
"document_id": str(document_id),
|
||||
"knowledge_id": str(kb_id),
|
||||
"sort_id": sort_id,
|
||||
"status": 1,
|
||||
}
|
||||
if create_data.is_qa:
|
||||
metadata.update(create_data.qa_metadata)
|
||||
chunks.append(DocumentChunk(page_content=create_data.chunk_content, metadata=metadata))
|
||||
|
||||
vector_service.add_chunks(chunks)
|
||||
|
||||
db_document.chunk_num += len(chunks)
|
||||
db.commit()
|
||||
|
||||
return success(data=jsonable_encoder(chunks), msg=f"Batch created {len(chunks)} chunks successfully")
|
||||
|
||||
|
||||
@router.post("/{kb_id}/import_qa", response_model=ApiResponse)
|
||||
async def import_qa_new_doc(
|
||||
kb_id: uuid.UUID,
|
||||
file: UploadFile = File(..., description="CSV 或 Excel 文件(第一行标题跳过,第一列问题,第二列答案)"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
||||
):
|
||||
"""
|
||||
导入 QA 问答对并新建文档(CSV/Excel),异步处理
|
||||
"""
|
||||
from app.schemas import file_schema, document_schema
|
||||
|
||||
api_logger.info(f"Import QA (new doc): kb_id={kb_id}, file={file.filename}, username: {current_user.username}")
|
||||
|
||||
# 1. 校验文件格式
|
||||
filename = file.filename or ""
|
||||
if not (filename.endswith(".csv") or filename.endswith(".xlsx") or filename.endswith(".xls")):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="仅支持 CSV (.csv) 或 Excel (.xlsx) 格式")
|
||||
|
||||
# 2. 校验知识库
|
||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
|
||||
if not db_knowledge:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="知识库不存在或无权访问")
|
||||
|
||||
# 3. 读取文件
|
||||
contents = await file.read()
|
||||
file_size = len(contents)
|
||||
if file_size == 0:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="文件为空")
|
||||
|
||||
_, file_extension = os.path.splitext(filename)
|
||||
file_ext = file_extension.lower()
|
||||
|
||||
# 4. 创建 File 记录
|
||||
file_data = file_schema.FileCreate(
|
||||
kb_id=kb_id, created_by=current_user.id,
|
||||
parent_id=uuid.UUID("00000000-0000-0000-0000-000000000000"),
|
||||
file_name=filename, file_ext=file_ext, file_size=file_size,
|
||||
)
|
||||
db_file = file_service.create_file(db=db, file=file_data, current_user=current_user)
|
||||
|
||||
# 5. 上传文件到存储后端
|
||||
file_key = generate_kb_file_key(kb_id=kb_id, file_id=db_file.id, file_ext=file_ext)
|
||||
try:
|
||||
await storage_service.storage.upload(file_key=file_key, content=contents, content_type=file.content_type)
|
||||
except Exception as e:
|
||||
api_logger.error(f"Storage upload failed: {e}")
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"文件存储失败: {str(e)}")
|
||||
|
||||
db_file.file_key = file_key
|
||||
db.commit()
|
||||
db.refresh(db_file)
|
||||
|
||||
# 6. 创建 Document 记录(标记为 QA 类型)
|
||||
doc_data = document_schema.DocumentCreate(
|
||||
kb_id=kb_id, created_by=current_user.id, file_id=db_file.id,
|
||||
file_name=filename, file_ext=file_ext, file_size=file_size,
|
||||
file_meta={}, parser_id="qa",
|
||||
parser_config={"doc_type": "qa", "auto_questions": 0}
|
||||
)
|
||||
db_document = document_service.create_document(db=db, document=doc_data, current_user=current_user)
|
||||
|
||||
api_logger.info(f"Created doc for QA import: file_id={db_file.id}, document_id={db_document.id}, file_key={file_key}")
|
||||
|
||||
# 7. 派发异步任务
|
||||
from app.celery_app import celery_app
|
||||
task = celery_app.send_task(
|
||||
"app.core.rag.tasks.import_qa_chunks",
|
||||
args=[str(kb_id), str(db_document.id), filename, contents],
|
||||
queue="qa_import"
|
||||
)
|
||||
|
||||
return success(data={
|
||||
"task_id": task.id,
|
||||
"document_id": str(db_document.id),
|
||||
"file_id": str(db_file.id),
|
||||
}, msg="QA 导入任务已提交,后台处理中")
|
||||
|
||||
|
||||
@router.post("/{kb_id}/{document_id}/import_qa", response_model=ApiResponse)
|
||||
async def import_qa_chunks(
|
||||
kb_id: uuid.UUID,
|
||||
document_id: uuid.UUID,
|
||||
file: UploadFile = File(..., description="CSV 或 Excel 文件(第一行标题跳过,第一列问题,第二列答案)"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
导入 QA 问答对(CSV/Excel),异步处理
|
||||
"""
|
||||
api_logger.info(f"Import QA chunks: kb_id={kb_id}, document_id={document_id}, file={file.filename}, username: {current_user.username}")
|
||||
|
||||
# 1. 校验文件格式
|
||||
filename = file.filename or ""
|
||||
if not (filename.endswith(".csv") or filename.endswith(".xlsx") or filename.endswith(".xls")):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="仅支持 CSV (.csv) 或 Excel (.xlsx) 格式")
|
||||
|
||||
# 2. 校验知识库和文档
|
||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
|
||||
if not db_knowledge:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="知识库不存在或无权访问")
|
||||
|
||||
db_document = db.query(Document).filter(Document.id == document_id).first()
|
||||
if not db_document:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="文档不存在或无权访问")
|
||||
|
||||
# 3. 读取文件内容,派发异步任务
|
||||
contents = await file.read()
|
||||
|
||||
from app.celery_app import celery_app
|
||||
task = celery_app.send_task(
|
||||
"app.core.rag.tasks.import_qa_chunks",
|
||||
args=[str(kb_id), str(document_id), filename, contents],
|
||||
queue="qa_import"
|
||||
)
|
||||
|
||||
return success(data={"task_id": task.id}, msg="QA 导入任务已提交,后台处理中")
|
||||
|
||||
|
||||
@router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
|
||||
async def get_chunk(
|
||||
kb_id: uuid.UUID,
|
||||
@@ -342,6 +529,9 @@ async def update_chunk(
|
||||
if total:
|
||||
chunk = items[0]
|
||||
chunk.page_content = content
|
||||
# QA chunk: 更新 metadata 中的 question/answer
|
||||
if update_data.is_qa:
|
||||
chunk.metadata.update(update_data.qa_metadata)
|
||||
vector_service.update_by_segment(chunk)
|
||||
return success(data=jsonable_encoder(chunk), msg="The document chunk has been successfully updated")
|
||||
else:
|
||||
@@ -356,6 +546,7 @@ async def delete_chunk(
|
||||
kb_id: uuid.UUID,
|
||||
document_id: uuid.UUID,
|
||||
doc_id: str,
|
||||
force_refresh: bool = Query(False, description="Force Elasticsearch refresh after deletion"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
@@ -373,7 +564,7 @@ async def delete_chunk(
|
||||
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
if vector_service.text_exists(doc_id):
|
||||
vector_service.delete_by_ids([doc_id])
|
||||
vector_service.delete_by_ids([doc_id], refresh=force_refresh)
|
||||
# 更新 chunk_num
|
||||
db_document = db.query(Document).filter(Document.id == document_id).first()
|
||||
db_document.chunk_num -= 1
|
||||
|
||||
@@ -113,6 +113,33 @@ async def create_chunk(
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.post("/{kb_id}/{document_id}/chunk/batch", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def create_chunks_batch(
|
||||
kb_id: uuid.UUID,
|
||||
document_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
items: list = Body(..., description="chunk items list"),
|
||||
):
|
||||
"""
|
||||
Batch create chunks (max 8)
|
||||
"""
|
||||
body = await request.json()
|
||||
batch_data = chunk_schema.ChunkBatchCreate(**body)
|
||||
# 0. Obtain the creator of the api key
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await chunk_controller.create_chunks_batch(kb_id=kb_id,
|
||||
document_id=document_id,
|
||||
batch_data=batch_data,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def get_chunk(
|
||||
@@ -176,6 +203,7 @@ async def delete_chunk(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
force_refresh: bool = Query(False, description="Force Elasticsearch refresh after deletion"),
|
||||
):
|
||||
"""
|
||||
delete document chunk
|
||||
@@ -188,6 +216,7 @@ async def delete_chunk(
|
||||
return await chunk_controller.delete_chunk(kb_id=kb_id,
|
||||
document_id=document_id,
|
||||
doc_id=doc_id,
|
||||
force_refresh=force_refresh,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@@ -98,6 +98,7 @@ class Settings:
|
||||
# File Upload
|
||||
MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
|
||||
MAX_FILE_COUNT: int = int(os.getenv("MAX_FILE_COUNT", "20"))
|
||||
MAX_CHUNK_BATCH_SIZE: int = int(os.getenv("MAX_CHUNK_BATCH_SIZE", "8"))
|
||||
FILE_PATH: str = os.getenv("FILE_PATH", "/files")
|
||||
FILE_URL_EXPIRES: int = int(os.getenv("FILE_URL_EXPIRES", "3600"))
|
||||
|
||||
|
||||
@@ -46,7 +46,10 @@ async def run_graphrag(
|
||||
start = trio.current_time()
|
||||
workspace_id, kb_id, document_id = row["workspace_id"], str(row["kb_id"]), row["document_id"]
|
||||
chunks = []
|
||||
for d in settings.retriever.chunk_list(document_id, workspace_id, [kb_id], fields=["page_content", "document_id"], sort_by_position=True):
|
||||
for d in settings.retriever.chunk_list(document_id, workspace_id, [kb_id], fields=["page_content", "document_id", "chunk_type"], sort_by_position=True):
|
||||
# 跳过 QA chunks,只用原文 chunks 构建图谱
|
||||
if d.get("chunk_type") == "qa":
|
||||
continue
|
||||
chunks.append(d["page_content"])
|
||||
|
||||
with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000):
|
||||
@@ -150,6 +153,9 @@ async def run_graphrag_for_kb(
|
||||
|
||||
total, items = vector_service.search_by_segment(document_id=str(document_id), query=None, pagesize=9999, page=1, asc=True)
|
||||
for doc in items:
|
||||
# 跳过 QA chunks,只用原文 chunks 构建图谱
|
||||
if (doc.metadata or {}).get("chunk_type") == "qa":
|
||||
continue
|
||||
content = doc.page_content
|
||||
if num_tokens_from_string(current_chunk + content) < 1024:
|
||||
current_chunk += content
|
||||
|
||||
@@ -131,18 +131,52 @@ def keyword_extraction(chat_mdl, content, topn=3):
|
||||
|
||||
|
||||
def question_proposal(chat_mdl, content, topn=3):
|
||||
template = PROMPT_JINJA_ENV.from_string(QUESTION_PROMPT_TEMPLATE)
|
||||
rendered_prompt = template.render(content=content, topn=topn)
|
||||
|
||||
msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
|
||||
_, msg = message_fit_in(msg, getattr(chat_mdl, 'max_length', 8096))
|
||||
kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2})
|
||||
if isinstance(kwd, tuple):
|
||||
kwd = kwd[0]
|
||||
kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
|
||||
if kwd.find("**ERROR**") >= 0:
|
||||
"""生成问题(向后兼容,返回纯文本问题列表)"""
|
||||
pairs = qa_proposal(chat_mdl, content, topn)
|
||||
if not pairs:
|
||||
return ""
|
||||
return kwd
|
||||
return "\n".join([p["question"] for p in pairs])
|
||||
|
||||
|
||||
def qa_proposal(chat_mdl, content, topn=3, custom_prompt=None):
|
||||
"""生成 QA 对,返回 [{"question": ..., "answer": ...}, ...]
|
||||
|
||||
Args:
|
||||
chat_mdl: LLM 模型
|
||||
content: 文本内容
|
||||
topn: 生成 QA 对数量
|
||||
custom_prompt: 自定义 prompt 模板(支持 Jinja2,可用变量: content, topn)
|
||||
"""
|
||||
if custom_prompt:
|
||||
template = PROMPT_JINJA_ENV.from_string(custom_prompt)
|
||||
sys_prompt = template.render(topn=topn)
|
||||
else:
|
||||
sys_prompt = QUESTION_PROMPT_TEMPLATE
|
||||
msg = [{"role": "system", "content": sys_prompt}, {"role": "user", "content": content}]
|
||||
_, msg = message_fit_in(msg, getattr(chat_mdl, 'max_length', 8096))
|
||||
raw = chat_mdl.chat(sys_prompt, msg[1:], {"temperature": 0.2})
|
||||
if isinstance(raw, tuple):
|
||||
raw = raw[0]
|
||||
raw = re.sub(r"^.*</think>", "", raw, flags=re.DOTALL)
|
||||
if raw.find("**ERROR**") >= 0:
|
||||
return []
|
||||
return parse_qa_pairs(raw)
|
||||
|
||||
|
||||
def parse_qa_pairs(text: str) -> list:
|
||||
"""解析 LLM 返回的 QA 对文本,格式: Q: xxx A: xxx"""
|
||||
pairs = []
|
||||
for line in text.strip().split("\n"):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
# 匹配 Q: ... A: ... 格式
|
||||
match = re.match(r'^Q:\s*(.+?)\s+A:\s*(.+)$', line, re.IGNORECASE)
|
||||
if match:
|
||||
q, a = match.group(1).strip(), match.group(2).strip()
|
||||
if q and a:
|
||||
pairs.append({"question": q, "answer": a})
|
||||
return pairs
|
||||
|
||||
|
||||
def graph_entity_types(chat_mdl, scenario):
|
||||
|
||||
@@ -1,19 +1,20 @@
|
||||
## Role
|
||||
You are a text analyzer.
|
||||
You are a text analyzer and knowledge extraction expert.
|
||||
|
||||
## Task
|
||||
Propose {{ topn }} questions about a given piece of text content.
|
||||
Generate question-answer pairs from the given text content.
|
||||
|
||||
## Requirements
|
||||
- Understand and summarize the text content, and propose the top {{ topn }} important questions.
|
||||
- Understand and summarize the text content, then generate up to {{ topn }} important question-answer pairs.
|
||||
- Each question-answer pair MUST be on a single line, formatted as: Q: <question> A: <answer>
|
||||
- The questions SHOULD NOT have overlapping meanings.
|
||||
- The questions SHOULD cover the main content of the text as much as possible.
|
||||
- The questions MUST be in the same language as the given piece of text content.
|
||||
- One question per line.
|
||||
- Output questions ONLY.
|
||||
|
||||
---
|
||||
|
||||
## Text Content
|
||||
{{ content }}
|
||||
- The answers MUST be concise, accurate, and directly derived from the text content.
|
||||
- The answers SHOULD be self-contained and understandable without additional context.
|
||||
- Both questions and answers MUST be in the same language as the given text content.
|
||||
- If the text is too short or lacks substantive content, generate fewer pairs rather than padding.
|
||||
- Output question-answer pairs ONLY, no extra explanation or commentary.
|
||||
|
||||
## Example Output
|
||||
Q: What is the capital of France? A: The capital of France is Paris.
|
||||
Q: When was the Eiffel Tower built? A: The Eiffel Tower was built in 1889.
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
from elasticsearch import Elasticsearch, helpers
|
||||
from elasticsearch import Elasticsearch, helpers, NotFoundError
|
||||
from elasticsearch.helpers import BulkIndexError
|
||||
from packaging.version import parse as parse_version
|
||||
# langchain-community
|
||||
@@ -53,13 +53,30 @@ class ElasticSearchVector(BaseVector):
|
||||
return "elasticsearch"
|
||||
|
||||
def add_chunks(self, chunks: list[DocumentChunk], **kwargs):
|
||||
# 实现 Elasticsearch 保存向量
|
||||
texts = [chunk.page_content for chunk in chunks]
|
||||
# QA chunks: embedding 只对 question 字段做;source chunks: 不做 embedding
|
||||
texts_for_embedding = []
|
||||
for chunk in chunks:
|
||||
chunk_type = (chunk.metadata or {}).get("chunk_type", "chunk")
|
||||
if chunk_type == "source":
|
||||
# source chunk 不需要向量索引
|
||||
texts_for_embedding.append("")
|
||||
elif chunk_type == "qa":
|
||||
# QA chunk: 用 question 字段做 embedding
|
||||
texts_for_embedding.append((chunk.metadata or {}).get("question", chunk.page_content))
|
||||
else:
|
||||
# 普通 chunk: 用 page_content 做 embedding
|
||||
texts_for_embedding.append(chunk.page_content)
|
||||
|
||||
if self.is_multimodal_embedding:
|
||||
# 火山引擎多模态 Embedding
|
||||
embeddings = self.embeddings.embed_batch(texts)
|
||||
embeddings = self.embeddings.embed_batch(texts_for_embedding)
|
||||
else:
|
||||
embeddings = self.embeddings.embed_documents(list(texts))
|
||||
embeddings = self.embeddings.embed_documents(texts_for_embedding)
|
||||
|
||||
# source chunk 的向量置空
|
||||
for i, chunk in enumerate(chunks):
|
||||
if (chunk.metadata or {}).get("chunk_type") == "source":
|
||||
embeddings[i] = None
|
||||
|
||||
self.create(chunks, embeddings, **kwargs)
|
||||
|
||||
def create(self, chunks: list[DocumentChunk], embeddings: list[list[float]], **kwargs):
|
||||
@@ -72,13 +89,25 @@ class ElasticSearchVector(BaseVector):
|
||||
uuids = self._get_uuids(chunks)
|
||||
actions = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
source = {
|
||||
Field.CONTENT_KEY.value: chunk.page_content,
|
||||
Field.METADATA_KEY.value: chunk.metadata or {},
|
||||
Field.VECTOR.value: embeddings[i] or None
|
||||
}
|
||||
# 写入 QA 相关字段
|
||||
meta = chunk.metadata or {}
|
||||
if meta.get("chunk_type"):
|
||||
source[Field.CHUNK_TYPE.value] = meta["chunk_type"]
|
||||
if meta.get("question"):
|
||||
source[Field.QUESTION.value] = meta["question"]
|
||||
if meta.get("answer"):
|
||||
source[Field.ANSWER.value] = meta["answer"]
|
||||
if meta.get("source_chunk_id"):
|
||||
source[Field.SOURCE_CHUNK_ID.value] = meta["source_chunk_id"]
|
||||
|
||||
action = {
|
||||
"_index": self._collection_name,
|
||||
"_source": {
|
||||
Field.CONTENT_KEY.value: chunk.page_content,
|
||||
Field.METADATA_KEY.value: chunk.metadata or {},
|
||||
Field.VECTOR.value: embeddings[i] or None
|
||||
}
|
||||
"_source": source
|
||||
}
|
||||
actions.append(action)
|
||||
# using bulk mode
|
||||
@@ -113,7 +142,7 @@ class ElasticSearchVector(BaseVector):
|
||||
|
||||
return True
|
||||
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
def delete_by_ids(self, ids: list[str], *, refresh: bool = False):
|
||||
if not ids:
|
||||
return
|
||||
if not self._client.indices.exists(index=self._collection_name):
|
||||
@@ -134,6 +163,8 @@ class ElasticSearchVector(BaseVector):
|
||||
actions = [{"_op_type": "delete", "_index": self._collection_name, "_id": es_id} for es_id in actual_ids]
|
||||
try:
|
||||
helpers.bulk(self._client, actions)
|
||||
if refresh:
|
||||
self._client.indices.refresh(index=self._collection_name)
|
||||
except BulkIndexError as e:
|
||||
for error in e.errors:
|
||||
delete_error = error.get('delete', {})
|
||||
@@ -153,7 +184,7 @@ class ElasticSearchVector(BaseVector):
|
||||
else:
|
||||
return None
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
def delete_by_metadata_field(self, key: str, value: str, *, refresh: bool = False):
|
||||
if not self._client.indices.exists(index=self._collection_name):
|
||||
return False
|
||||
actual_ids = self.get_ids_by_metadata_field(key, value)
|
||||
@@ -162,6 +193,8 @@ class ElasticSearchVector(BaseVector):
|
||||
actions = [{"_op_type": "delete", "_index": self._collection_name, "_id": es_id} for es_id in actual_ids]
|
||||
try:
|
||||
helpers.bulk(self._client, actions)
|
||||
if refresh:
|
||||
self._client.indices.refresh(index=self._collection_name)
|
||||
except BulkIndexError as e:
|
||||
for error in e.errors:
|
||||
delete_error = error.get('delete', {})
|
||||
@@ -192,6 +225,8 @@ class ElasticSearchVector(BaseVector):
|
||||
List of DocumentChunk objects that match the query.
|
||||
"""
|
||||
indices = kwargs.get("indices", self._collection_name) # Default single index, multiple indexes are also supported, such as "index1, index2, index3"
|
||||
if not self._client.indices.exists(index=indices):
|
||||
return 0, []
|
||||
|
||||
# Calculate the start position for the current page
|
||||
from_ = pagesize * (page-1)
|
||||
@@ -226,12 +261,15 @@ class ElasticSearchVector(BaseVector):
|
||||
})
|
||||
|
||||
# For simplicity, we use from/size here which has a limit (usually up to 10,000).
|
||||
result = self._client.search(
|
||||
index=indices,
|
||||
from_=from_, # Only use from_ for the first page (simplified)
|
||||
size=pagesize,
|
||||
body=query_str,
|
||||
)
|
||||
try:
|
||||
result = self._client.search(
|
||||
index=indices,
|
||||
from_=from_, # Only use from_ for the first page (simplified)
|
||||
size=pagesize,
|
||||
body=query_str,
|
||||
)
|
||||
except NotFoundError:
|
||||
return 0, []
|
||||
|
||||
if "errors" in result:
|
||||
raise ValueError(f"Error during query: {result['errors']}")
|
||||
@@ -241,10 +279,19 @@ class ElasticSearchVector(BaseVector):
|
||||
for res in result["hits"]["hits"]:
|
||||
source = res["_source"]
|
||||
page_content = source.get(Field.CONTENT_KEY.value)
|
||||
# vector = source.get(Field.VECTOR.value)
|
||||
vector = None
|
||||
metadata = source.get(Field.METADATA_KEY.value, {})
|
||||
chunk_type = source.get(Field.CHUNK_TYPE.value)
|
||||
score = res["_score"]
|
||||
|
||||
# 将 QA 字段注入 metadata 供前端展示
|
||||
if chunk_type:
|
||||
metadata["chunk_type"] = chunk_type
|
||||
if chunk_type == "qa":
|
||||
metadata["question"] = source.get(Field.QUESTION.value, "")
|
||||
metadata["answer"] = source.get(Field.ANSWER.value, "")
|
||||
page_content = f"Q: {metadata['question']}\nA: {metadata['answer']}"
|
||||
|
||||
docs_and_scores.append((DocumentChunk(page_content=page_content, vector=vector, metadata=metadata), score))
|
||||
|
||||
docs = []
|
||||
@@ -267,13 +314,18 @@ class ElasticSearchVector(BaseVector):
|
||||
List of DocumentChunk objects that match the query.
|
||||
"""
|
||||
indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available,etc "index1,index2,index3"
|
||||
if not self._client.indices.exists(index=indices):
|
||||
return 0, []
|
||||
query_str = {"query": {"term": {f"{Field.DOC_ID.value}": doc_id}}}
|
||||
result = self._client.search(
|
||||
index=indices,
|
||||
from_=0, # Only use from_ for the first page (simplified)
|
||||
size=1,
|
||||
body=query_str,
|
||||
)
|
||||
try:
|
||||
result = self._client.search(
|
||||
index=indices,
|
||||
from_=0, # Only use from_ for the first page (simplified)
|
||||
size=1,
|
||||
body=query_str,
|
||||
)
|
||||
except NotFoundError:
|
||||
return 0, []
|
||||
# print(result)
|
||||
if "errors" in result:
|
||||
raise ValueError(f"Error during query: {result['errors']}")
|
||||
@@ -308,27 +360,43 @@ class ElasticSearchVector(BaseVector):
|
||||
Returns:
|
||||
updated count.
|
||||
"""
|
||||
indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available,etc "index1,index2,index3"
|
||||
if self.is_multimodal_embedding:
|
||||
# 火山引擎多模态 Embedding
|
||||
chunk.vector = self.embeddings.embed_text(chunk.page_content)
|
||||
indices = kwargs.get("indices", self._collection_name)
|
||||
chunk_type = (chunk.metadata or {}).get("chunk_type")
|
||||
|
||||
# QA chunk: embedding 基于 question;source chunk: 不更新向量
|
||||
if chunk_type == "source":
|
||||
embed_text = ""
|
||||
elif chunk_type == "qa":
|
||||
embed_text = (chunk.metadata or {}).get("question", chunk.page_content)
|
||||
else:
|
||||
chunk.vector = self.embeddings.embed_query(chunk.page_content)
|
||||
embed_text = chunk.page_content
|
||||
|
||||
if chunk_type != "source":
|
||||
if self.is_multimodal_embedding:
|
||||
chunk.vector = self.embeddings.embed_text(embed_text)
|
||||
else:
|
||||
chunk.vector = self.embeddings.embed_query(embed_text)
|
||||
|
||||
script_source = "ctx._source.page_content = params.new_content; ctx._source.vector = params.new_vector;"
|
||||
params = {
|
||||
"new_content": chunk.page_content,
|
||||
"new_vector": chunk.vector if chunk_type != "source" else None
|
||||
}
|
||||
|
||||
# QA chunk: 同时更新 question/answer 字段
|
||||
if chunk_type == "qa":
|
||||
script_source += " ctx._source.question = params.new_question; ctx._source.answer = params.new_answer;"
|
||||
params["new_question"] = (chunk.metadata or {}).get("question", "")
|
||||
params["new_answer"] = (chunk.metadata or {}).get("answer", "")
|
||||
|
||||
body = {
|
||||
"script": {
|
||||
"source": """
|
||||
ctx._source.page_content = params.new_content;
|
||||
ctx._source.vector = params.new_vector;
|
||||
""",
|
||||
"params": {
|
||||
"new_content": chunk.page_content,
|
||||
"new_vector": chunk.vector
|
||||
}
|
||||
"source": script_source,
|
||||
"params": params
|
||||
},
|
||||
"query": {
|
||||
"term": {
|
||||
Field.DOC_ID.value: chunk.metadata["doc_id"] # exact match doc_id
|
||||
Field.DOC_ID.value: chunk.metadata["doc_id"]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -336,9 +404,6 @@ class ElasticSearchVector(BaseVector):
|
||||
index=indices,
|
||||
body=body,
|
||||
)
|
||||
# Remove debug printing and use logging instead
|
||||
# print(result)
|
||||
# print(f"Update successful, number of affected documents: {result['updated']}")
|
||||
return result['updated']
|
||||
|
||||
def change_status_by_document_id(self, document_id: str, status: int, **kwargs) -> str:
|
||||
@@ -397,11 +462,11 @@ class ElasticSearchVector(BaseVector):
|
||||
}
|
||||
}
|
||||
},
|
||||
"filter": { # Add the filter condition of status=1
|
||||
"term": {
|
||||
"metadata.status": 1
|
||||
}
|
||||
}
|
||||
"filter": [
|
||||
{"term": {"metadata.status": 1}},
|
||||
# 排除 source chunk(仅供 GraphRAG 使用,不参与检索)
|
||||
{"bool": {"must_not": {"term": {Field.CHUNK_TYPE.value: "source"}}}}
|
||||
]
|
||||
}
|
||||
}
|
||||
# If file_names_filter is passed in, merge the filtering conditions
|
||||
@@ -415,22 +480,14 @@ class ElasticSearchVector(BaseVector):
|
||||
},
|
||||
"script": {
|
||||
"source": f"cosineSimilarity(params.query_vector, '{Field.VECTOR.value}') + 1.0",
|
||||
# The script_score query calculates the cosine similarity between the embedding field of each document and the query vector. The addition of +1.0 is to ensure that the scores returned by the script are non-negative, as the range of cosine similarity is [-1, 1]
|
||||
"params": {"query_vector": query_vector}
|
||||
}
|
||||
}
|
||||
},
|
||||
"filter": [
|
||||
{
|
||||
"term": {
|
||||
"metadata.status": 1
|
||||
}
|
||||
},
|
||||
{
|
||||
"terms": {
|
||||
"metadata.file_name": file_names_filter # Additional file_name filtering
|
||||
}
|
||||
}
|
||||
{"term": {"metadata.status": 1}},
|
||||
{"terms": {"metadata.file_name": file_names_filter}},
|
||||
{"bool": {"must_not": {"term": {Field.CHUNK_TYPE.value: "source"}}}}
|
||||
],
|
||||
}
|
||||
}
|
||||
@@ -451,8 +508,19 @@ class ElasticSearchVector(BaseVector):
|
||||
source = res["_source"]
|
||||
page_content = source.get(Field.CONTENT_KEY.value)
|
||||
metadata = source.get(Field.METADATA_KEY.value, {})
|
||||
chunk_type = source.get(Field.CHUNK_TYPE.value)
|
||||
score = res["_score"]
|
||||
score = score / 2 # Normalized [0-1]
|
||||
|
||||
# QA chunk: 返回 Q+A 拼接作为上下文
|
||||
if chunk_type == "qa":
|
||||
question = source.get(Field.QUESTION.value, "")
|
||||
answer = source.get(Field.ANSWER.value, "")
|
||||
page_content = f"Q: {question}\nA: {answer}"
|
||||
metadata["chunk_type"] = "qa"
|
||||
metadata["question"] = question
|
||||
metadata["answer"] = answer
|
||||
|
||||
docs_and_scores.append((DocumentChunk(page_content=page_content, metadata=metadata), score))
|
||||
|
||||
docs = []
|
||||
@@ -491,11 +559,10 @@ class ElasticSearchVector(BaseVector):
|
||||
}
|
||||
}
|
||||
},
|
||||
"filter": { # Add the filter condition of status=1
|
||||
"term": {
|
||||
"metadata.status": 1
|
||||
}
|
||||
}
|
||||
"filter": [
|
||||
{"term": {"metadata.status": 1}},
|
||||
{"bool": {"must_not": {"term": {Field.CHUNK_TYPE.value: "source"}}}}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -512,16 +579,9 @@ class ElasticSearchVector(BaseVector):
|
||||
}
|
||||
},
|
||||
"filter": [
|
||||
{
|
||||
"term": {
|
||||
"metadata.status": 1
|
||||
}
|
||||
},
|
||||
{
|
||||
"terms": {
|
||||
"metadata.file_name": file_names_filter # Additional file_name filtering
|
||||
}
|
||||
}
|
||||
{"term": {"metadata.status": 1}},
|
||||
{"terms": {"metadata.file_name": file_names_filter}},
|
||||
{"bool": {"must_not": {"term": {Field.CHUNK_TYPE.value: "source"}}}}
|
||||
],
|
||||
}
|
||||
}
|
||||
@@ -543,6 +603,17 @@ class ElasticSearchVector(BaseVector):
|
||||
source = res["_source"]
|
||||
page_content = source.get(Field.CONTENT_KEY.value)
|
||||
metadata = source.get(Field.METADATA_KEY.value, {})
|
||||
chunk_type = source.get(Field.CHUNK_TYPE.value)
|
||||
|
||||
# QA chunk: 返回 Q+A 拼接作为上下文
|
||||
if chunk_type == "qa":
|
||||
question = source.get(Field.QUESTION.value, "")
|
||||
answer = source.get(Field.ANSWER.value, "")
|
||||
page_content = f"Q: {question}\nA: {answer}"
|
||||
metadata["chunk_type"] = "qa"
|
||||
metadata["question"] = question
|
||||
metadata["answer"] = answer
|
||||
|
||||
# Normalize the score to the [0,1] interval
|
||||
normalized_score = res["_score"] / max_score
|
||||
docs_and_scores.append((DocumentChunk(page_content=page_content, metadata=metadata), normalized_score))
|
||||
@@ -652,7 +723,7 @@ class ElasticSearchVector(BaseVector):
|
||||
},
|
||||
Field.VECTOR.value: {
|
||||
"type": "dense_vector",
|
||||
"dims": len(embeddings[0]), # Make sure the dimension is correct here,The dimension size of the vector. When index is true, it cannot exceed 1024; when index is false or not specified, it cannot exceed 2048, which can improve retrieval efficiency
|
||||
"dims": len(next((e for e in embeddings if e is not None), [0]*768)), # 跳过 None 获取向量维度,fallback 768
|
||||
"index": True,
|
||||
"similarity": "cosine"
|
||||
}
|
||||
|
||||
@@ -14,3 +14,8 @@ class Field(StrEnum):
|
||||
DOCUMENT_ID = "metadata.document_id"
|
||||
KNOWLEDGE_ID = "metadata.knowledge_id"
|
||||
SORT_ID = "metadata.sort_id"
|
||||
# QA fields
|
||||
CHUNK_TYPE = "chunk_type" # "chunk" | "source" | "qa"
|
||||
QUESTION = "question"
|
||||
ANSWER = "answer"
|
||||
SOURCE_CHUNK_ID = "source_chunk_id"
|
||||
|
||||
@@ -27,14 +27,14 @@ class BaseVector(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
def delete_by_ids(self, ids: list[str], *, refresh: bool = False):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
def delete_by_metadata_field(self, key: str, value: str, *, refresh: bool = False):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -20,13 +20,26 @@ class ChunkCreate(BaseModel):
|
||||
|
||||
@property
|
||||
def chunk_content(self) -> str:
|
||||
"""
|
||||
Get the actual content string regardless of input type
|
||||
"""
|
||||
"""Get the actual content string regardless of input type"""
|
||||
if isinstance(self.content, QAChunk):
|
||||
return f"question: {self.content.question} answer: {self.content.answer}"
|
||||
return self.content.question # QA 模式下 page_content 存 question
|
||||
return self.content
|
||||
|
||||
@property
|
||||
def is_qa(self) -> bool:
|
||||
return isinstance(self.content, QAChunk)
|
||||
|
||||
@property
|
||||
def qa_metadata(self) -> dict:
|
||||
"""返回 QA 相关的 metadata 字段"""
|
||||
if isinstance(self.content, QAChunk):
|
||||
return {
|
||||
"chunk_type": "qa",
|
||||
"question": self.content.question,
|
||||
"answer": self.content.answer,
|
||||
}
|
||||
return {}
|
||||
|
||||
|
||||
class ChunkUpdate(BaseModel):
|
||||
content: Union[str, QAChunk] = Field(
|
||||
@@ -35,13 +48,26 @@ class ChunkUpdate(BaseModel):
|
||||
|
||||
@property
|
||||
def chunk_content(self) -> str:
|
||||
"""
|
||||
Get the actual content string regardless of input type
|
||||
"""
|
||||
"""Get the actual content string regardless of input type"""
|
||||
if isinstance(self.content, QAChunk):
|
||||
return f"question: {self.content.question} answer: {self.content.answer}"
|
||||
return self.content.question # QA 模式下 page_content 存 question
|
||||
return self.content
|
||||
|
||||
@property
|
||||
def is_qa(self) -> bool:
|
||||
return isinstance(self.content, QAChunk)
|
||||
|
||||
@property
|
||||
def qa_metadata(self) -> dict:
|
||||
"""返回 QA 相关的 metadata 字段"""
|
||||
if isinstance(self.content, QAChunk):
|
||||
return {
|
||||
"chunk_type": "qa",
|
||||
"question": self.content.question,
|
||||
"answer": self.content.answer,
|
||||
}
|
||||
return {}
|
||||
|
||||
|
||||
class ChunkRetrieve(BaseModel):
|
||||
query: str
|
||||
@@ -51,3 +77,8 @@ class ChunkRetrieve(BaseModel):
|
||||
vector_similarity_weight: float | None = Field(None)
|
||||
top_k: int | None = Field(None)
|
||||
retrieve_type: RetrieveType | None = Field(None)
|
||||
|
||||
|
||||
class ChunkBatchCreate(BaseModel):
|
||||
"""批量创建 chunk"""
|
||||
items: list[ChunkCreate] = Field(..., min_length=1, description="chunk 列表")
|
||||
|
||||
254
api/app/tasks.py
254
api/app/tasks.py
@@ -30,7 +30,7 @@ from app.core.rag.llm.cv_model import QWenCV
|
||||
from app.core.rag.llm.embedding_model import OpenAIEmbed
|
||||
from app.core.rag.llm.sequence2txt_model import QWenSeq2txt
|
||||
from app.core.rag.models.chunk import DocumentChunk
|
||||
from app.core.rag.prompts.generator import question_proposal
|
||||
from app.core.rag.prompts.generator import question_proposal, qa_proposal
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import (
|
||||
ElasticSearchVectorFactory,
|
||||
)
|
||||
@@ -311,6 +311,7 @@ def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""):
|
||||
vector_service.delete_by_metadata_field(key="document_id", value=str(document_id))
|
||||
# 2.2 Vectorize and import batch documents
|
||||
auto_questions_topn = db_document.parser_config.get("auto_questions", 0)
|
||||
qa_prompt = db_document.parser_config.get("qa_prompt", None)
|
||||
chat_model = None
|
||||
if auto_questions_topn:
|
||||
chat_model = Base(
|
||||
@@ -318,62 +319,123 @@ def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""):
|
||||
model_name=db_knowledge.llm.api_keys[0].model_name,
|
||||
base_url=db_knowledge.llm.api_keys[0].api_base,
|
||||
)
|
||||
logger.info(f"[QA] LLM model: {db_knowledge.llm.api_keys[0].model_name}, base_url: {db_knowledge.llm.api_keys[0].api_base}")
|
||||
if qa_prompt:
|
||||
logger.info(f"[QA] Using custom prompt ({len(qa_prompt)} chars)")
|
||||
|
||||
# 预先构建所有 batch 的 chunks,保证 sort_id 全局有序
|
||||
all_batch_chunks: list[list[DocumentChunk]] = []
|
||||
|
||||
if auto_questions_topn:
|
||||
# auto_questions 开启:先并发生成所有 chunk 的问题,再按 batch 分组
|
||||
# 构建 (global_idx, item) 列表
|
||||
# QA 模式(FastGPT 方案):
|
||||
# 1. 原 chunk 标记为 source(保留供 GraphRAG 使用,不参与检索)
|
||||
# 2. LLM 生成 QA 对,每个 QA 对独立存储为 qa chunk
|
||||
indexed_items = list(enumerate(res))
|
||||
|
||||
def _generate_question(idx_item: tuple[int, dict]) -> tuple[int, str]:
|
||||
"""为单个 chunk 生成问题(带缓存),返回 (global_idx, question_text)"""
|
||||
def _generate_qa(idx_item: tuple[int, dict]) -> tuple[int, list]:
|
||||
"""为单个 chunk 生成 QA 对(带缓存),返回 (global_idx, qa_pairs)"""
|
||||
global_idx, item = idx_item
|
||||
content = item["content_with_weight"]
|
||||
cached = get_llm_cache(chat_model.model_name, content, "question",
|
||||
{"topn": auto_questions_topn})
|
||||
cache_params = {"topn": auto_questions_topn}
|
||||
if qa_prompt:
|
||||
import hashlib
|
||||
cache_params["prompt_hash"] = hashlib.md5(qa_prompt.encode()).hexdigest()[:8]
|
||||
cached = get_llm_cache(chat_model.model_name, content, "qa", cache_params)
|
||||
if not cached:
|
||||
cached = question_proposal(chat_model, content, auto_questions_topn)
|
||||
set_llm_cache(chat_model.model_name, content, cached, "question",
|
||||
{"topn": auto_questions_topn})
|
||||
return global_idx, cached
|
||||
logger.info(f"[QA] Cache miss for chunk {global_idx}, calling LLM. cache_params={cache_params}")
|
||||
try:
|
||||
pairs = qa_proposal(chat_model, content, auto_questions_topn, custom_prompt=qa_prompt)
|
||||
except Exception as e:
|
||||
logger.error(f"[QA] LLM call failed: model={chat_model.model_name}, base_url={getattr(chat_model, 'base_url', 'N/A')}, error={e}")
|
||||
return global_idx, []
|
||||
logger.info(f"[QA] Chunk {global_idx} generated {len(pairs)} QA pairs")
|
||||
# 缓存存 JSON 字符串
|
||||
set_llm_cache(chat_model.model_name, content, json.dumps(pairs, ensure_ascii=False), "qa",
|
||||
cache_params)
|
||||
return global_idx, pairs
|
||||
logger.info(f"[QA] Cache hit for chunk {global_idx}, cache_params={cache_params}, cached_type={type(cached).__name__}")
|
||||
# 从缓存读取:可能是 JSON 字符串或旧格式纯文本
|
||||
if isinstance(cached, str):
|
||||
try:
|
||||
parsed = json.loads(cached)
|
||||
if isinstance(parsed, list):
|
||||
logger.info(f"[QA] Chunk {global_idx} loaded {len(parsed)} QA pairs from cache")
|
||||
return global_idx, parsed
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
# 旧缓存格式(纯文本问题),尝试解析
|
||||
from app.core.rag.prompts.generator import parse_qa_pairs
|
||||
return global_idx, parse_qa_pairs(cached) if cached else []
|
||||
return global_idx, cached if isinstance(cached, list) else []
|
||||
|
||||
# 并发调用 LLM 生成问题
|
||||
question_map: dict[int, str] = {}
|
||||
# 并发调用 LLM 生成 QA 对
|
||||
qa_map: dict[int, list] = {}
|
||||
with ThreadPoolExecutor(max_workers=AUTO_QUESTIONS_MAX_WORKERS) as q_executor:
|
||||
futures = {q_executor.submit(_generate_question, item): item[0]
|
||||
futures = {q_executor.submit(_generate_qa, item): item[0]
|
||||
for item in indexed_items}
|
||||
for future in futures:
|
||||
global_idx, cached = future.result()
|
||||
question_map[global_idx] = cached
|
||||
global_idx, pairs = future.result()
|
||||
qa_map[global_idx] = pairs
|
||||
|
||||
progress_lines.append(
|
||||
f"{datetime.now().strftime('%H:%M:%S')} Auto questions generated for {total_chunks} chunks "
|
||||
f"{datetime.now().strftime('%H:%M:%S')} QA pairs generated for {total_chunks} chunks "
|
||||
f"(workers={AUTO_QUESTIONS_MAX_WORKERS}).")
|
||||
|
||||
# 按 batch 分组组装 DocumentChunk
|
||||
for batch_start in range(0, total_chunks, EMBEDDING_BATCH_SIZE):
|
||||
batch_end = min(batch_start + EMBEDDING_BATCH_SIZE, total_chunks)
|
||||
chunks = []
|
||||
for global_idx in range(batch_start, batch_end):
|
||||
item = res[global_idx]
|
||||
metadata = {
|
||||
# 组装 chunks:source chunks + qa chunks
|
||||
source_chunks = []
|
||||
qa_chunks = []
|
||||
qa_sort_id = 0
|
||||
|
||||
for global_idx in range(total_chunks):
|
||||
item = res[global_idx]
|
||||
source_chunk_id = uuid.uuid4().hex
|
||||
|
||||
# source chunk:保留原文,供 GraphRAG 使用,不参与向量检索
|
||||
source_meta = {
|
||||
"doc_id": source_chunk_id,
|
||||
"file_id": str(db_document.file_id),
|
||||
"file_name": db_document.file_name,
|
||||
"file_created_at": int(db_document.created_at.timestamp() * 1000),
|
||||
"document_id": str(db_document.id),
|
||||
"knowledge_id": str(db_document.kb_id),
|
||||
"sort_id": global_idx,
|
||||
"status": 1,
|
||||
"chunk_type": "source",
|
||||
}
|
||||
source_chunks.append(
|
||||
DocumentChunk(page_content=item["content_with_weight"], metadata=source_meta))
|
||||
|
||||
# qa chunks:每个 QA 对独立存储
|
||||
pairs = qa_map.get(global_idx, [])
|
||||
for pair in pairs:
|
||||
qa_meta = {
|
||||
"doc_id": uuid.uuid4().hex,
|
||||
"file_id": str(db_document.file_id),
|
||||
"file_name": db_document.file_name,
|
||||
"file_created_at": int(db_document.created_at.timestamp() * 1000),
|
||||
"document_id": str(db_document.id),
|
||||
"knowledge_id": str(db_document.kb_id),
|
||||
"sort_id": global_idx,
|
||||
"sort_id": qa_sort_id,
|
||||
"status": 1,
|
||||
"chunk_type": "qa",
|
||||
"question": pair["question"],
|
||||
"answer": pair["answer"],
|
||||
"source_chunk_id": source_chunk_id,
|
||||
}
|
||||
cached = question_map[global_idx]
|
||||
chunks.append(
|
||||
DocumentChunk(
|
||||
page_content=f"question: {cached} answer: {item['content_with_weight']}",
|
||||
metadata=metadata))
|
||||
all_batch_chunks.append(chunks)
|
||||
# page_content 存 question,用于向量索引
|
||||
qa_chunks.append(
|
||||
DocumentChunk(page_content=pair["question"], metadata=qa_meta))
|
||||
qa_sort_id += 1
|
||||
|
||||
# 按 batch 分组(source + qa 一起)
|
||||
all_chunks = source_chunks + qa_chunks
|
||||
for batch_start in range(0, len(all_chunks), EMBEDDING_BATCH_SIZE):
|
||||
batch_end = min(batch_start + EMBEDDING_BATCH_SIZE, len(all_chunks))
|
||||
all_batch_chunks.append(all_chunks[batch_start:batch_end])
|
||||
|
||||
progress_lines.append(
|
||||
f"{datetime.now().strftime('%H:%M:%S')} QA mode: {len(source_chunks)} source chunks + "
|
||||
f"{len(qa_chunks)} QA chunks prepared.")
|
||||
else:
|
||||
# 无 auto_questions:直接构建 chunks
|
||||
for batch_start in range(0, total_chunks, EMBEDDING_BATCH_SIZE):
|
||||
@@ -635,6 +697,136 @@ def build_graphrag_for_document(document_id: str, knowledge_id: str):
|
||||
return f"build_graphrag_for_document '{document_id}' failed: {e}"
|
||||
|
||||
|
||||
@celery_app.task(name="app.core.rag.tasks.import_qa_chunks", queue="qa_import")
|
||||
def import_qa_chunks(kb_id: str, document_id: str, filename: str, contents: bytes):
|
||||
"""
|
||||
异步导入 QA 问答对(CSV/Excel)
|
||||
|
||||
文件格式:第一行标题(跳过),第一列问题,第二列答案
|
||||
"""
|
||||
import csv as csv_module
|
||||
import io
|
||||
|
||||
db = None
|
||||
try:
|
||||
from app.db import get_db_context
|
||||
with get_db_context() as db:
|
||||
db_document = db.query(Document).filter(Document.id == uuid.UUID(document_id)).first()
|
||||
db_knowledge = db.query(Knowledge).filter(Knowledge.id == uuid.UUID(kb_id)).first()
|
||||
if not db_document or not db_knowledge:
|
||||
logger.error(f"[ImportQA] document={document_id} or knowledge={kb_id} not found")
|
||||
return {"error": "document or knowledge not found", "imported": 0}
|
||||
|
||||
# 1. 解析文件
|
||||
qa_pairs = []
|
||||
failed_rows = []
|
||||
|
||||
if filename.endswith(".csv"):
|
||||
try:
|
||||
text = contents.decode("utf-8-sig")
|
||||
except UnicodeDecodeError:
|
||||
text = contents.decode("gbk", errors="ignore")
|
||||
|
||||
sniffer = csv_module.Sniffer()
|
||||
try:
|
||||
dialect = sniffer.sniff(text[:2048])
|
||||
delimiter = dialect.delimiter
|
||||
except csv_module.Error:
|
||||
delimiter = "," if "," in text[:500] else "\t"
|
||||
|
||||
reader = csv_module.reader(io.StringIO(text), delimiter=delimiter)
|
||||
for i, row in enumerate(reader):
|
||||
if i == 0:
|
||||
continue
|
||||
if len(row) >= 2 and row[0].strip() and row[1].strip():
|
||||
qa_pairs.append({"question": row[0].strip(), "answer": row[1].strip()})
|
||||
elif len(row) >= 1 and row[0].strip():
|
||||
failed_rows.append(i + 1)
|
||||
|
||||
elif filename.endswith(".xlsx") or filename.endswith(".xls"):
|
||||
try:
|
||||
import openpyxl
|
||||
wb = openpyxl.load_workbook(io.BytesIO(contents), read_only=True)
|
||||
for sheet in wb.worksheets:
|
||||
for i, row in enumerate(sheet.iter_rows(values_only=True)):
|
||||
if i == 0:
|
||||
continue
|
||||
if len(row) >= 2 and row[0] and row[1]:
|
||||
q = str(row[0]).strip()
|
||||
a = str(row[1]).strip()
|
||||
if q and a:
|
||||
qa_pairs.append({"question": q, "answer": a})
|
||||
elif len(row) >= 1 and row[0]:
|
||||
failed_rows.append(i + 1)
|
||||
wb.close()
|
||||
except Exception as e:
|
||||
logger.error(f"[ImportQA] Excel parse failed: {e}")
|
||||
return {"error": f"Excel parse failed: {e}", "imported": 0}
|
||||
|
||||
if not qa_pairs:
|
||||
logger.warning(f"[ImportQA] No valid QA pairs found in {filename}")
|
||||
return {"error": "No valid QA pairs found", "imported": 0}
|
||||
|
||||
logger.info(f"[ImportQA] Parsed {len(qa_pairs)} QA pairs from {filename}, failed_rows={failed_rows}")
|
||||
|
||||
# 2. 写入 ES
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
|
||||
sort_id = 0
|
||||
total, items = vector_service.search_by_segment(document_id=document_id, pagesize=1, page=1, asc=False)
|
||||
if items:
|
||||
sort_id = items[0].metadata["sort_id"]
|
||||
|
||||
chunks = []
|
||||
for pair in qa_pairs:
|
||||
sort_id += 1
|
||||
doc_id = uuid.uuid4().hex
|
||||
metadata = {
|
||||
"doc_id": doc_id,
|
||||
"file_id": str(db_document.file_id),
|
||||
"file_name": db_document.file_name,
|
||||
"file_created_at": int(db_document.created_at.timestamp() * 1000),
|
||||
"document_id": document_id,
|
||||
"knowledge_id": kb_id,
|
||||
"sort_id": sort_id,
|
||||
"status": 1,
|
||||
"chunk_type": "qa",
|
||||
"question": pair["question"],
|
||||
"answer": pair["answer"],
|
||||
}
|
||||
chunks.append(DocumentChunk(page_content=pair["question"], metadata=metadata))
|
||||
|
||||
batch_size = 50
|
||||
for i in range(0, len(chunks), batch_size):
|
||||
batch = chunks[i:i + batch_size]
|
||||
vector_service.add_chunks(batch)
|
||||
|
||||
# 3. 更新 chunk_num 和 progress
|
||||
db_document.chunk_num += len(chunks)
|
||||
db_document.progress = 1.0
|
||||
db_document.progress_msg = f"QA 导入完成: {len(chunks)} 条"
|
||||
db.commit()
|
||||
|
||||
result = {"imported": len(chunks), "failed_rows": failed_rows}
|
||||
logger.info(f"[ImportQA] Done: imported={len(chunks)}, failed={len(failed_rows)}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[ImportQA] Failed: {e}", exc_info=True)
|
||||
# 尝试更新文档状态为失败
|
||||
try:
|
||||
from app.db import get_db_context
|
||||
with get_db_context() as err_db:
|
||||
doc = err_db.query(Document).filter(Document.id == uuid.UUID(document_id)).first()
|
||||
if doc:
|
||||
doc.progress = -1.0
|
||||
doc.progress_msg = f"QA 导入失败: {str(e)[:200]}"
|
||||
err_db.commit()
|
||||
except Exception:
|
||||
pass
|
||||
return {"error": str(e), "imported": 0}
|
||||
|
||||
|
||||
@celery_app.task(name="app.core.rag.tasks.sync_knowledge_for_kb")
|
||||
def sync_knowledge_for_kb(kb_id: uuid.UUID):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user