Initial commit
This commit is contained in:
192
app/core/rag/nlp/search.py
Normal file
192
app/core/rag/nlp/search.py
Normal file
@@ -0,0 +1,192 @@
|
||||
import uuid
|
||||
from typing import Dict, List, Any
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from app.db import get_db
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.models import RedBearLLM, RedBearRerank
|
||||
from app.models.models_model import ModelApiKey
|
||||
from app.models import knowledge_model
|
||||
from app.core.rag.models.chunk import DocumentChunk
|
||||
from app.repositories import knowledge_repository, knowledgeshare_repository
|
||||
from app.services.model_service import ModelConfigService
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
|
||||
|
||||
def knowledge_retrieval(
|
||||
query: str,
|
||||
config: Dict[str, Any],
|
||||
user_ids: List[str] = None,
|
||||
) -> list[DocumentChunk]:
|
||||
"""
|
||||
Knowledge retrieval with multiple knowledge bases and reranking
|
||||
|
||||
Args:
|
||||
query: Search query string
|
||||
config: Configuration dictionary containing:
|
||||
- knowledge_bases: List of knowledge base configs with:
|
||||
- kb_id: Knowledge base ID
|
||||
- similarity_threshold: float
|
||||
- vector_similarity_weight: float
|
||||
- top_k: int
|
||||
- retrieve_type: "participle" or "semantic" or "hybrid"
|
||||
- merge_strategy: "weight" or other strategies
|
||||
- reranker_id: UUID of the reranker to use
|
||||
- reranker_top_k: int
|
||||
|
||||
Returns:
|
||||
Rearranged document block list (in descending order of relevance)
|
||||
"""
|
||||
db = next(get_db()) # Manually call the generator
|
||||
try:
|
||||
# parse configuration
|
||||
knowledge_bases = config.get("knowledge_bases", [])
|
||||
merge_strategy = config.get("merge_strategy", "weight")
|
||||
reranker_id = config.get("reranker_id")
|
||||
reranker_top_k = config.get("reranker_top_k", 1024)
|
||||
|
||||
file_names_filter=[]
|
||||
if user_ids:
|
||||
file_names_filter.extend([f"{user_id}.txt" for user_id in user_ids])
|
||||
|
||||
if not knowledge_bases:
|
||||
return []
|
||||
|
||||
all_results = []
|
||||
# Search each knowledge base
|
||||
for kb_config in knowledge_bases:
|
||||
kb_id = kb_config["kb_id"]
|
||||
try:
|
||||
# Check whether the knowledge base exists and is available
|
||||
db_knowledge = knowledge_repository.get_knowledge_by_id(db, knowledge_id=kb_id)
|
||||
if db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1:
|
||||
# Process shared knowledge base
|
||||
if db_knowledge.permission_id.lower() == knowledge_model.PermissionType.Share:
|
||||
knowledgeshare = knowledgeshare_repository.get_knowledgeshare_by_id(db=db,
|
||||
knowledgeshare_id=db_knowledge.id)
|
||||
if knowledgeshare:
|
||||
db_knowledge = knowledge_repository.get_knowledge_by_id(db,
|
||||
knowledge_id=knowledgeshare.source_kb_id)
|
||||
if not (db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1):
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
# Retrieve according to the configured retrieval type
|
||||
match kb_config["retrieve_type"]:
|
||||
case "participle":
|
||||
rs = vector_service.search_by_full_text(
|
||||
query=query,
|
||||
top_k=kb_config["top_k"],
|
||||
score_threshold=kb_config["similarity_threshold"],
|
||||
file_names_filter=file_names_filter
|
||||
)
|
||||
case "semantic":
|
||||
rs = vector_service.search_by_vector(
|
||||
query=query,
|
||||
top_k=kb_config["top_k"],
|
||||
score_threshold=kb_config["vector_similarity_weight"],
|
||||
file_names_filter=file_names_filter
|
||||
)
|
||||
case _: # hybrid
|
||||
rs1 = vector_service.search_by_vector(
|
||||
query=query,
|
||||
top_k=kb_config["top_k"],
|
||||
score_threshold=kb_config["vector_similarity_weight"],
|
||||
file_names_filter=file_names_filter
|
||||
)
|
||||
rs2 = vector_service.search_by_full_text(
|
||||
query=query,
|
||||
top_k=kb_config["top_k"],
|
||||
score_threshold=kb_config["similarity_threshold"],
|
||||
file_names_filter=file_names_filter
|
||||
)
|
||||
|
||||
# Deduplication of merge results
|
||||
seen_ids = set()
|
||||
unique_rs = []
|
||||
for doc in rs1 + rs2:
|
||||
if doc.metadata["doc_id"] not in seen_ids:
|
||||
seen_ids.add(doc.metadata["doc_id"])
|
||||
unique_rs.append(doc)
|
||||
rs = unique_rs
|
||||
|
||||
all_results.extend(rs)
|
||||
except Exception as e:
|
||||
# Failure of retrieval in a single knowledge base does not affect other knowledge bases
|
||||
print(f"retrieval knowledge({kb_id}) failed: {str(e)}")
|
||||
continue
|
||||
|
||||
# Use the specified reranker for re-ranking
|
||||
if reranker_id:
|
||||
return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k)
|
||||
return all_results
|
||||
|
||||
except Exception as e:
|
||||
print(f"retrieval knowledge failed: {str(e)}")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def rerank(db: Session, reranker_id: uuid, query: str, docs: list[DocumentChunk], top_k: int) -> list[DocumentChunk]:
|
||||
"""
|
||||
Reorder the list of document blocks and return the top_k results most relevant to the query
|
||||
Args:
|
||||
reranker_id: reranker model id
|
||||
query: query string
|
||||
docs: List of document blocks to be rearranged
|
||||
top_k: Number of top-level documents returned
|
||||
|
||||
Returns:
|
||||
Rearranged document block list (in descending order of relevance)
|
||||
|
||||
Raises:
|
||||
ValueError: If the input document list is empty or top_k is invalid
|
||||
"""
|
||||
# 参数校验
|
||||
if not reranker_id:
|
||||
raise ValueError("reranker_id be empty")
|
||||
if not docs:
|
||||
raise ValueError("retrieval chunks be empty")
|
||||
if top_k <= 0:
|
||||
raise ValueError("top_k must be a positive integer")
|
||||
try:
|
||||
# initialize reranker
|
||||
config = ModelConfigService.get_model_by_id(db=db, model_id=reranker_id)
|
||||
apiConfig: ModelApiKey = config.api_keys[0]
|
||||
reranker = RedBearRerank(RedBearModelConfig(
|
||||
model_name=apiConfig.model_name,
|
||||
provider=apiConfig.provider,
|
||||
api_key=apiConfig.api_key,
|
||||
base_url=apiConfig.api_base
|
||||
))
|
||||
# Convert to LangChain Document object
|
||||
documents = [
|
||||
Document(
|
||||
page_content=doc.page_content, # Ensure that DocumentChunk possesses this attribute
|
||||
metadata=doc.metadata or {} # Deal with possible None metadata
|
||||
)
|
||||
for doc in docs
|
||||
]
|
||||
|
||||
# Perform reordering (compress_documents will automatically handle relevance scores and indexing)
|
||||
reranked_docs = list(reranker.compress_documents(documents, query))
|
||||
print(reranked_docs)
|
||||
|
||||
# Sort in descending order based on relevance score
|
||||
reranked_docs.sort(
|
||||
key=lambda x: x.metadata.get("relevance_score", 0),
|
||||
reverse=True
|
||||
)
|
||||
# Convert back to a list of DocumentChunk, and save the relevance_score to metadata["score"]
|
||||
result = []
|
||||
for item in reranked_docs[:top_k]:
|
||||
for doc in docs:
|
||||
if doc.page_content == item.page_content:
|
||||
doc.metadata["score"] = item.metadata["relevance_score"]
|
||||
result.append(doc)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to rerank documents: {str(e)}") from e
|
||||
Reference in New Issue
Block a user