feat(knowledge): support graph retrieval type with dynamic API key selection
This commit is contained in:
@@ -8,6 +8,8 @@ from langchain_core.documents import Document
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.models import RedBearRerank, RedBearModelConfig
|
||||
from app.core.rag.llm.chat_model import Base
|
||||
from app.core.rag.llm.embedding_model import OpenAIEmbed
|
||||
from app.core.rag.models.chunk import DocumentChunk
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
@@ -39,8 +41,9 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
if isinstance(business_result, dict) and "chunks" in business_result:
|
||||
return business_result["chunks"]
|
||||
return business_result
|
||||
|
||||
def _extract_citations(self, business_result: Any) -> list:
|
||||
|
||||
@staticmethod
def _extract_citations(business_result: Any) -> list:
    """Return the citations carried by a business result.

    Args:
        business_result: The business result to inspect. Only ``dict``
            values are examined; any other type is treated as carrying
            no citations.

    Returns:
        The value stored under the ``"citations"`` key, or an empty list
        when the result is not a dict, the key is absent, or the stored
        value is ``None``.
    """
    if not isinstance(business_result, dict):
        return []
    citations = business_result.get("citations")
    # ``dict.get`` only applies its default when the key is missing, so a
    # present-but-None entry would leak ``None`` to callers expecting a list.
    return [] if citations is None else citations
|
||||
@@ -230,23 +233,23 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
}
|
||||
)
|
||||
)
|
||||
case RetrieveType.HYBRID:
|
||||
case RetrieveType.HYBRID | RetrieveType.Graph:
|
||||
rs1_task = asyncio.to_thread(
|
||||
vector_service.search_by_vector, **{
|
||||
"query": query,
|
||||
"top_k": kb_config.top_k,
|
||||
"indices": indices,
|
||||
"score_threshold": kb_config.vector_similarity_weight
|
||||
}
|
||||
)
|
||||
vector_service.search_by_vector, **{
|
||||
"query": query,
|
||||
"top_k": kb_config.top_k,
|
||||
"indices": indices,
|
||||
"score_threshold": kb_config.vector_similarity_weight
|
||||
}
|
||||
)
|
||||
rs2_task = asyncio.to_thread(
|
||||
vector_service.search_by_full_text, **{
|
||||
"query": query,
|
||||
"top_k": kb_config.top_k,
|
||||
"indices": indices,
|
||||
"score_threshold": kb_config.similarity_threshold
|
||||
}
|
||||
)
|
||||
vector_service.search_by_full_text, **{
|
||||
"query": query,
|
||||
"top_k": kb_config.top_k,
|
||||
"indices": indices,
|
||||
"score_threshold": kb_config.similarity_threshold
|
||||
}
|
||||
)
|
||||
rs1, rs2 = await asyncio.gather(rs1_task, rs2_task)
|
||||
|
||||
# Deduplicate hybrid retrieval results
|
||||
@@ -266,6 +269,33 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
key=lambda d: d.metadata.get("score", 0),
|
||||
reverse=True
|
||||
)[:kb_config.top_k])
|
||||
if kb_config.retrieve_type == RetrieveType.Graph:
|
||||
from app.core.rag.common.settings import kg_retriever
|
||||
llm_key = self.model_balance(db_knowledge.llm)
|
||||
emb_key = self.model_balance(db_knowledge.embedding)
|
||||
chat_model = Base(
|
||||
key=llm_key.api_key,
|
||||
model_name=llm_key.model_name,
|
||||
base_url=llm_key.api_base
|
||||
)
|
||||
embedding_model = OpenAIEmbed(
|
||||
key=emb_key.api_key,
|
||||
model_name=emb_key.model_name,
|
||||
base_url=emb_key.api_base
|
||||
)
|
||||
doc = await asyncio.to_thread(
|
||||
kg_retriever.retrieval,
|
||||
question=query,
|
||||
workspace_ids=[str(db_knowledge.workspace_id)],
|
||||
kb_ids=[str(kb_config.kb_id)],
|
||||
emb_mdl=embedding_model,
|
||||
llm=chat_model
|
||||
)
|
||||
if doc:
|
||||
rs.insert(0, DocumentChunk(
|
||||
page_content=doc.get("page_content", ""),
|
||||
metadata=doc.get("metadata", {})
|
||||
))
|
||||
case _:
|
||||
raise RuntimeError("Unknown retrieval type")
|
||||
return rs
|
||||
|
||||
Reference in New Issue
Block a user