Merge branch 'feature/20251219_lxc' into develop
This commit is contained in:
@@ -26,6 +26,8 @@ from app.core.rag.utils.doc_store_conn import DocStoreConnection, MatchDenseExpr
|
||||
from app.core.rag.common.string_utils import remove_redundant_spaces
|
||||
from app.core.rag.common.float_utils import get_float
|
||||
from app.core.rag.common.constants import PAGERANK_FLD, TAG_FLD
|
||||
from app.core.rag.llm.chat_model import Base
|
||||
from app.core.rag.llm.embedding_model import OpenAIEmbed
|
||||
|
||||
|
||||
def knowledge_retrieval(
|
||||
@@ -48,6 +50,7 @@ def knowledge_retrieval(
|
||||
- merge_strategy: "weight" or other strategies
|
||||
- reranker_id: UUID of the reranker to use
|
||||
- reranker_top_k: int
|
||||
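- use_graph: str, "true" or "false" (case-insensitive, default "false"); when "true", the knowledge-graph retrieval result (if any) is inserted at the front of the results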
|
||||
|
||||
Returns:
|
||||
Rearranged document block list (in descending order of relevance)
|
||||
@@ -59,6 +62,7 @@ def knowledge_retrieval(
|
||||
merge_strategy = config.get("merge_strategy", "weight")
|
||||
reranker_id = config.get("reranker_id")
|
||||
reranker_top_k = config.get("reranker_top_k", 1024)
|
||||
use_graph = config.get("use_graph", "false").lower() == "true"
|
||||
|
||||
file_names_filter = []
|
||||
if user_ids:
|
||||
@@ -67,6 +71,10 @@ def knowledge_retrieval(
|
||||
if not knowledge_bases:
|
||||
return []
|
||||
|
||||
kb_ids = []
|
||||
workspace_ids = []
|
||||
chat_model = None
|
||||
embedding_model = None
|
||||
all_results = []
|
||||
# Search each knowledge base
|
||||
for kb_config in knowledge_bases:
|
||||
@@ -87,6 +95,22 @@ def knowledge_retrieval(
|
||||
else:
|
||||
continue
|
||||
|
||||
if str(db_knowledge.id) not in kb_ids:
|
||||
kb_ids.append(str(db_knowledge.id))
|
||||
if str(db_knowledge.workspace_id) not in workspace_ids:
|
||||
workspace_ids.append(str(db_knowledge.workspace_id))
|
||||
if not chat_model:
|
||||
chat_model = Base(
|
||||
key=db_knowledge.llm.api_keys[0].api_key,
|
||||
model_name=db_knowledge.llm.api_keys[0].model_name,
|
||||
base_url=db_knowledge.llm.api_keys[0].api_base
|
||||
)
|
||||
if not embedding_model:
|
||||
embedding_model = OpenAIEmbed(
|
||||
key=db_knowledge.embedding.api_keys[0].api_key,
|
||||
model_name=db_knowledge.embedding.api_keys[0].model_name,
|
||||
base_url=db_knowledge.embedding.api_keys[0].api_base
|
||||
)
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
# Retrieve according to the configured retrieval type
|
||||
match kb_config["retrieve_type"]:
|
||||
@@ -136,6 +160,12 @@ def knowledge_retrieval(
|
||||
# Use the specified reranker for re-ranking
|
||||
if reranker_id:
|
||||
return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k)
|
||||
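# Knowledge-graph retrieval: insert the graph result, if any, at the front of the results (not reached when reranker_id is set, because that path returns above)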
|
||||
if use_graph:
|
||||
from app.core.rag.common.settings import kg_retriever
|
||||
doc = kg_retriever.retrieval(question=query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model)
|
||||
if doc:
|
||||
all_results.insert(0, doc)
|
||||
return all_results
|
||||
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user