494 lines
20 KiB
Python
494 lines
20 KiB
Python
import os
|
|
from typing import Any, Optional
|
|
import uuid
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
|
from fastapi.encoders import jsonable_encoder
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.core.config import settings
|
|
from app.core.logging_config import get_api_logger
|
|
from app.core.rag.common.settings import kg_retriever
|
|
from app.core.rag.llm.chat_model import Base
|
|
from app.core.rag.llm.cv_model import QWenCV
|
|
from app.core.rag.llm.embedding_model import OpenAIEmbed
|
|
from app.core.rag.models.chunk import DocumentChunk
|
|
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
|
from app.core.response_utils import success
|
|
from app.db import get_db
|
|
from app.dependencies import get_current_user
|
|
from app.models import knowledge_model, knowledgeshare_model
|
|
from app.models.document_model import Document
|
|
from app.models.user_model import User
|
|
from app.schemas import chunk_schema
|
|
from app.schemas.response_schema import ApiResponse
|
|
from app.services import knowledge_service, document_service, file_service, knowledgeshare_service
|
|
from app.services.model_service import ModelApiKeyService
|
|
|
|
# Obtain a dedicated API logger
|
|
api_logger = get_api_logger()
|
|
|
|
router = APIRouter(
|
|
prefix="/chunks",
|
|
tags=["chunks"],
|
|
dependencies=[Depends(get_current_user)] # Apply auth to all routes in this controller
|
|
)
|
|
|
|
|
|
@router.get("/{kb_id}/{document_id}/previewchunks", response_model=ApiResponse)
|
|
async def get_preview_chunks(
|
|
kb_id: uuid.UUID,
|
|
document_id: uuid.UUID,
|
|
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
|
|
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
|
|
keywords: Optional[str] = Query(None, description="The keywords used to match chunk content"),
|
|
db: Session = Depends(get_db),
|
|
current_user: User = Depends(get_current_user)
|
|
):
|
|
"""
|
|
Paged query document block preview list
|
|
- Support filtering by document_id
|
|
- Support keyword search for segmented content
|
|
- Return paging metadata + file list
|
|
"""
|
|
api_logger.info(f"Paged query document block preview list: kb_id={kb_id}, document_id={document_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")
|
|
# 1. parameter validation
|
|
if page < 1 or pagesize < 1:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_400_BAD_REQUEST,
|
|
detail="The paging parameter must be greater than 0"
|
|
)
|
|
|
|
# 2. Obtain knowledge base information
|
|
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
|
|
if not db_knowledge:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
detail="The knowledge base does not exist or access is denied"
|
|
)
|
|
# 3. Check if the document exists
|
|
db_document = document_service.get_document_by_id(db, document_id=document_id, current_user=current_user)
|
|
if not db_document:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
detail="The document does not exist or you do not have permission to access it"
|
|
)
|
|
|
|
# 4. Check if the file exists
|
|
db_file = file_service.get_file_by_id(db, file_id=db_document.file_id)
|
|
if not db_file:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
detail="The file does not exist or you do not have permission to access it"
|
|
)
|
|
|
|
# 5. Get file content from storage backend
|
|
if not db_file.file_key:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
detail="File has no storage key (legacy data not migrated)"
|
|
)
|
|
|
|
from app.services.file_storage_service import FileStorageService
|
|
import asyncio
|
|
storage_service = FileStorageService()
|
|
|
|
async def _download():
|
|
return await storage_service.download_file(db_file.file_key)
|
|
|
|
try:
|
|
file_binary = asyncio.run(_download())
|
|
except RuntimeError:
|
|
loop = asyncio.new_event_loop()
|
|
try:
|
|
file_binary = loop.run_until_complete(_download())
|
|
finally:
|
|
loop.close()
|
|
except Exception as e:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
detail=f"File not found in storage: {e}"
|
|
)
|
|
|
|
# 7. Document parsing & segmentation
|
|
def progress_callback(prog=None, msg=None):
|
|
print(f"prog: {prog} msg: {msg}\n")
|
|
# Prepare to configure vision_model information
|
|
vision_model = QWenCV(
|
|
key=db_knowledge.image2text.api_keys[0].api_key,
|
|
model_name=db_knowledge.image2text.api_keys[0].model_name,
|
|
lang="Chinese",
|
|
base_url=db_knowledge.image2text.api_keys[0].api_base
|
|
)
|
|
from app.core.rag.app.naive import chunk
|
|
res = chunk(filename=db_file.file_name,
|
|
binary=file_binary,
|
|
from_page=0,
|
|
to_page=5,
|
|
callback=progress_callback,
|
|
vision_model=vision_model,
|
|
parser_config=db_document.parser_config,
|
|
is_root=False)
|
|
|
|
start_index = (page - 1) * pagesize
|
|
end_index = start_index + pagesize
|
|
# Use slicing to obtain the data of the current page
|
|
paginated_chunk_str_list = res[start_index:end_index]
|
|
chunks = []
|
|
for idx, item in enumerate(paginated_chunk_str_list):
|
|
metadata = {
|
|
"doc_id": uuid.uuid4().hex,
|
|
"file_id": str(db_document.file_id),
|
|
"file_name": db_document.file_name,
|
|
"file_created_at": int(db_document.created_at.timestamp() * 1000),
|
|
"document_id": str(db_document.id),
|
|
"knowledge_id": str(db_document.kb_id),
|
|
"sort_id": idx,
|
|
"status": 1,
|
|
}
|
|
chunks.append(DocumentChunk(page_content=item["content_with_weight"], metadata=metadata))
|
|
|
|
# 8. Return structured response
|
|
total = len(res)
|
|
result = {
|
|
"items": chunks,
|
|
"page": {
|
|
"page": page,
|
|
"pagesize": pagesize,
|
|
"total": total,
|
|
"has_next": True if page * pagesize < total else False
|
|
}
|
|
}
|
|
api_logger.info(f"Querying the document block preview list successful: total={total}, returned={len(chunks)} records")
|
|
return success(data=jsonable_encoder(result), msg="Querying the document block preview list succeeded")
|
|
|
|
|
|
@router.get("/{kb_id}/{document_id}/chunks", response_model=ApiResponse)
|
|
async def get_chunks(
|
|
kb_id: uuid.UUID,
|
|
document_id: uuid.UUID,
|
|
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
|
|
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
|
|
keywords: Optional[str] = Query(None, description="The keywords used to match chunk content"),
|
|
db: Session = Depends(get_db),
|
|
current_user: User = Depends(get_current_user)
|
|
):
|
|
"""
|
|
Paged query document chunk list
|
|
- Support filtering by document_id
|
|
- Support keyword search for segmented content
|
|
- Return paging metadata + file list
|
|
"""
|
|
api_logger.info(f"Paged query document chunk list: kb_id={kb_id}, document_id={document_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")
|
|
# 1. parameter validation
|
|
if page < 1 or pagesize < 1:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_400_BAD_REQUEST,
|
|
detail="The paging parameter must be greater than 0"
|
|
)
|
|
|
|
# 2. Obtain knowledge base information
|
|
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
|
|
if not db_knowledge:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
detail="The knowledge base does not exist or access is denied"
|
|
)
|
|
|
|
# 3. Execute paged query
|
|
try:
|
|
api_logger.debug("Start executing document chunk query")
|
|
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
|
total, items = vector_service.search_by_segment(document_id=str(document_id), query=keywords, pagesize=pagesize, page=page, asc=True)
|
|
api_logger.info(f"Document chunk query successful: total={total}, returned={len(items)} records")
|
|
except Exception as e:
|
|
api_logger.error(f"Document chunk query failed: {str(e)}")
|
|
raise HTTPException(
|
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
detail=f"Query failed: {str(e)}"
|
|
)
|
|
|
|
# 4. Return structured response
|
|
result = {
|
|
"items": items,
|
|
"page": {
|
|
"page": page,
|
|
"pagesize": pagesize,
|
|
"total": total,
|
|
"has_next": True if page * pagesize < total else False
|
|
}
|
|
}
|
|
return success(data=jsonable_encoder(result), msg="Query of document chunk list succeeded")
|
|
|
|
|
|
@router.post("/{kb_id}/{document_id}/chunk", response_model=ApiResponse)
|
|
async def create_chunk(
|
|
kb_id: uuid.UUID,
|
|
document_id: uuid.UUID,
|
|
create_data: chunk_schema.ChunkCreate,
|
|
db: Session = Depends(get_db),
|
|
current_user: User = Depends(get_current_user)
|
|
):
|
|
"""
|
|
create chunk
|
|
"""
|
|
# Obtain the actual content
|
|
content = create_data.chunk_content
|
|
api_logger.info(f"Create chunk request: kb_id={kb_id}, document_id={document_id}, content={content}, username: {current_user.username}")
|
|
|
|
# 1. Obtain knowledge base information
|
|
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
|
|
if not db_knowledge:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
detail="The knowledge base does not exist or access is denied"
|
|
)
|
|
# 1. Obtain document information
|
|
db_document = db.query(Document).filter(Document.id == document_id).first()
|
|
if not db_document:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
detail="The document does not exist or you do not have permission to access it"
|
|
)
|
|
|
|
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
|
|
|
# 2. Get the sort ID
|
|
sort_id = 0
|
|
total, items = vector_service.search_by_segment(document_id=str(document_id), pagesize=1, page=1, asc=False)
|
|
if items:
|
|
sort_id = items[0].metadata["sort_id"]
|
|
sort_id = sort_id + 1
|
|
|
|
doc_id = uuid.uuid4().hex
|
|
metadata = {
|
|
"doc_id": doc_id,
|
|
"file_id": str(db_document.file_id),
|
|
"file_name": db_document.file_name,
|
|
"file_created_at": int(db_document.created_at.timestamp() * 1000),
|
|
"document_id": str(document_id),
|
|
"knowledge_id": str(kb_id),
|
|
"sort_id": sort_id,
|
|
"status": 1,
|
|
}
|
|
chunk = DocumentChunk(page_content=content, metadata=metadata)
|
|
# 3. Segmented vector storage
|
|
vector_service.add_chunks([chunk])
|
|
|
|
# 4.update chunk_num
|
|
db_document.chunk_num += 1
|
|
db.commit()
|
|
|
|
return success(data=jsonable_encoder(chunk), msg="Document chunk creation successful")
|
|
|
|
|
|
@router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
|
|
async def get_chunk(
|
|
kb_id: uuid.UUID,
|
|
document_id: uuid.UUID,
|
|
doc_id: str,
|
|
db: Session = Depends(get_db),
|
|
current_user: User = Depends(get_current_user)
|
|
):
|
|
"""
|
|
Retrieve document chunk information based on doc_id
|
|
"""
|
|
api_logger.info(f"Obtain document chunk information: kb_id={kb_id}, document_id={document_id}, doc_id={doc_id}, username: {current_user.username}")
|
|
|
|
# 1. Obtain knowledge base information
|
|
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
|
|
if not db_knowledge:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
detail="The knowledge base does not exist or access is denied"
|
|
)
|
|
|
|
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
|
total, items = vector_service.get_by_segment(doc_id=doc_id)
|
|
if total:
|
|
return success(data=jsonable_encoder(items[0]), msg="Document chunk query successful")
|
|
else:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
detail="The document chunk does not exist or you do not have access"
|
|
)
|
|
|
|
|
|
@router.put("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
|
|
async def update_chunk(
|
|
kb_id: uuid.UUID,
|
|
document_id: uuid.UUID,
|
|
doc_id: str,
|
|
update_data: chunk_schema.ChunkUpdate,
|
|
db: Session = Depends(get_db),
|
|
current_user: User = Depends(get_current_user)
|
|
):
|
|
"""
|
|
Update document chunk content
|
|
"""
|
|
# Obtain the actual content
|
|
content = update_data.chunk_content
|
|
api_logger.info(f"Update document chunk content: kb_id={kb_id}, document_id={document_id}, doc_id={doc_id}, content={content}, username: {current_user.username}")
|
|
|
|
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
|
|
if not db_knowledge:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
detail="The knowledge base does not exist or access is denied"
|
|
)
|
|
|
|
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
|
total, items = vector_service.get_by_segment(doc_id=doc_id)
|
|
if total:
|
|
chunk = items[0]
|
|
chunk.page_content = content
|
|
vector_service.update_by_segment(chunk)
|
|
return success(data=jsonable_encoder(chunk), msg="The document chunk has been successfully updated")
|
|
else:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
detail="The document chunk does not exist or you do not have access to it"
|
|
)
|
|
|
|
|
|
@router.delete("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
|
|
async def delete_chunk(
|
|
kb_id: uuid.UUID,
|
|
document_id: uuid.UUID,
|
|
doc_id: str,
|
|
db: Session = Depends(get_db),
|
|
current_user: User = Depends(get_current_user)
|
|
):
|
|
"""
|
|
delete document chunk
|
|
"""
|
|
api_logger.info(f"Request to delete document chunk: kb_id={kb_id}, document_id={document_id}, doc_id={doc_id}, username: {current_user.username}")
|
|
|
|
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
|
|
if not db_knowledge:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
detail="The knowledge base does not exist or access is denied"
|
|
)
|
|
|
|
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
|
if vector_service.text_exists(doc_id):
|
|
vector_service.delete_by_ids([doc_id])
|
|
# 更新 chunk_num
|
|
db_document = db.query(Document).filter(Document.id == document_id).first()
|
|
db_document.chunk_num -= 1
|
|
db.commit()
|
|
return success(msg="The document chunk has been successfully deleted")
|
|
else:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
detail="The document chunk does not exist or you do not have access to it"
|
|
)
|
|
|
|
|
|
@router.get("/retrieve_type", response_model=ApiResponse)
|
|
def get_retrieve_types():
|
|
return success(msg="Successfully obtained the retrieval type", data=list(chunk_schema.RetrieveType))
|
|
|
|
|
|
@router.post("/retrieval", response_model=Any, status_code=status.HTTP_200_OK)
|
|
async def retrieve_chunks(
|
|
retrieve_data: chunk_schema.ChunkRetrieve,
|
|
db: Session = Depends(get_db),
|
|
current_user: User = Depends(get_current_user)
|
|
):
|
|
"""
|
|
retrieve chunk
|
|
"""
|
|
api_logger.info(f"retrieve chunk: query={retrieve_data.query}, username: {current_user.username}")
|
|
|
|
filters = [
|
|
knowledge_model.Knowledge.id.in_(retrieve_data.kb_ids),
|
|
knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Private,
|
|
knowledge_model.Knowledge.chunk_num > 0,
|
|
knowledge_model.Knowledge.status == 1
|
|
]
|
|
private_items = knowledge_service.get_chunked_knowledgeids(
|
|
db=db,
|
|
filters=filters,
|
|
current_user=current_user
|
|
)
|
|
private_kb_ids = [item[0] for item in private_items]
|
|
private_workspace_ids = [item[1] for item in private_items]
|
|
filters = [
|
|
knowledge_model.Knowledge.id.in_(retrieve_data.kb_ids),
|
|
knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Share,
|
|
knowledge_model.Knowledge.chunk_num > 0,
|
|
knowledge_model.Knowledge.status == 1
|
|
]
|
|
items = knowledge_service.get_chunked_knowledgeids(
|
|
db=db,
|
|
filters=filters,
|
|
current_user=current_user
|
|
)
|
|
if items:
|
|
filters = [
|
|
knowledgeshare_model.KnowledgeShare.target_kb_id.in_(retrieve_data.kb_ids)
|
|
]
|
|
share_items = knowledgeshare_service.get_source_kb_ids_by_target_kb_id(
|
|
db=db,
|
|
filters=filters,
|
|
current_user=current_user
|
|
)
|
|
share_kb_ids = [item[0] for item in share_items]
|
|
share_workspace_ids = [item[1] for item in share_items]
|
|
private_kb_ids.extend(share_kb_ids)
|
|
private_workspace_ids.extend(share_workspace_ids)
|
|
if not private_kb_ids:
|
|
return success(data=[], msg="retrieval successful")
|
|
kb_id = private_kb_ids[0]
|
|
uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in private_kb_ids]
|
|
indices = ",".join(uuid_strs)
|
|
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
|
|
if not db_knowledge:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
detail="The knowledge base does not exist or access is denied"
|
|
)
|
|
|
|
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
|
|
|
# 1 participle search, 2 semantic search, 3 hybrid search
|
|
match retrieve_data.retrieve_type:
|
|
case chunk_schema.RetrieveType.PARTICIPLE:
|
|
rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
|
|
return success(data=rs, msg="retrieval successful")
|
|
case chunk_schema.RetrieveType.SEMANTIC:
|
|
rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
|
|
return success(data=rs, msg="retrieval successful")
|
|
case _:
|
|
rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
|
|
rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
|
|
# Efficient deduplication
|
|
seen_ids = set()
|
|
unique_rs = []
|
|
for doc in rs1 + rs2:
|
|
if doc.metadata["doc_id"] not in seen_ids:
|
|
seen_ids.add(doc.metadata["doc_id"])
|
|
unique_rs.append(doc)
|
|
rs = vector_service.rerank(query=retrieve_data.query, docs=unique_rs, top_k=retrieve_data.top_k) if unique_rs else []
|
|
if retrieve_data.retrieve_type == chunk_schema.RetrieveType.Graph:
|
|
kb_ids = [str(kb_id) for kb_id in private_kb_ids]
|
|
workspace_ids = [str(workspace_id) for workspace_id in private_workspace_ids]
|
|
llm_key = ModelApiKeyService.get_available_api_key(db, db_knowledge.llm_id)
|
|
emb_key = ModelApiKeyService.get_available_api_key(db, db_knowledge.embedding_id)
|
|
# Prepare to configure chat_mdl、embedding_model、vision_model information
|
|
chat_model = Base(
|
|
key=llm_key.api_key,
|
|
model_name=llm_key.model_name,
|
|
base_url=llm_key.api_base
|
|
)
|
|
embedding_model = OpenAIEmbed(
|
|
key=emb_key.api_key,
|
|
model_name=emb_key.model_name,
|
|
base_url=emb_key.api_base
|
|
)
|
|
doc = kg_retriever.retrieval(question=retrieve_data.query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model)
|
|
if doc:
|
|
rs.insert(0, doc)
|
|
return success(data=jsonable_encoder(rs), msg="retrieval successful") |