feat(knowledge): support graph retrieval type with dynamic API key selection
This commit is contained in:
@@ -23,6 +23,7 @@ from app.models.user_model import User
|
|||||||
from app.schemas import chunk_schema
|
from app.schemas import chunk_schema
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services import knowledge_service, document_service, file_service, knowledgeshare_service
|
from app.services import knowledge_service, document_service, file_service, knowledgeshare_service
|
||||||
|
from app.services.model_service import ModelApiKeyService
|
||||||
|
|
||||||
# Obtain a dedicated API logger
|
# Obtain a dedicated API logger
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
@@ -460,18 +461,20 @@ async def retrieve_chunks(
|
|||||||
if retrieve_data.retrieve_type == chunk_schema.RetrieveType.Graph:
|
if retrieve_data.retrieve_type == chunk_schema.RetrieveType.Graph:
|
||||||
kb_ids = [str(kb_id) for kb_id in private_kb_ids]
|
kb_ids = [str(kb_id) for kb_id in private_kb_ids]
|
||||||
workspace_ids = [str(workspace_id) for workspace_id in private_workspace_ids]
|
workspace_ids = [str(workspace_id) for workspace_id in private_workspace_ids]
|
||||||
|
llm_key = ModelApiKeyService.get_available_api_key(db, db_knowledge.llm_id)
|
||||||
|
emb_key = ModelApiKeyService.get_available_api_key(db, db_knowledge.embedding_id)
|
||||||
# Prepare to configure chat_mdl、embedding_model、vision_model information
|
# Prepare to configure chat_mdl、embedding_model、vision_model information
|
||||||
chat_model = Base(
|
chat_model = Base(
|
||||||
key=db_knowledge.llm.api_keys[0].api_key,
|
key=llm_key.api_key,
|
||||||
model_name=db_knowledge.llm.api_keys[0].model_name,
|
model_name=llm_key.model_name,
|
||||||
base_url=db_knowledge.llm.api_keys[0].api_base
|
base_url=llm_key.api_base
|
||||||
)
|
)
|
||||||
embedding_model = OpenAIEmbed(
|
embedding_model = OpenAIEmbed(
|
||||||
key=db_knowledge.embedding.api_keys[0].api_key,
|
key=emb_key.api_key,
|
||||||
model_name=db_knowledge.embedding.api_keys[0].model_name,
|
model_name=emb_key.model_name,
|
||||||
base_url=db_knowledge.embedding.api_keys[0].api_base
|
base_url=emb_key.api_base
|
||||||
)
|
)
|
||||||
doc = kg_retriever.retrieval(question=retrieve_data.query, workspace_ids=workspace_ids, kb_ids= kb_ids, emb_mdl=embedding_model, llm=chat_model)
|
doc = kg_retriever.retrieval(question=retrieve_data.query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model)
|
||||||
if doc:
|
if doc:
|
||||||
rs.insert(0, doc)
|
rs.insert(0, doc)
|
||||||
return success(data=jsonable_encoder(rs), msg="retrieval successful")
|
return success(data=jsonable_encoder(rs), msg="retrieval successful")
|
||||||
@@ -28,6 +28,7 @@ from app.core.rag.common.float_utils import get_float
|
|||||||
from app.core.rag.common.constants import PAGERANK_FLD, TAG_FLD
|
from app.core.rag.common.constants import PAGERANK_FLD, TAG_FLD
|
||||||
from app.core.rag.llm.chat_model import Base
|
from app.core.rag.llm.chat_model import Base
|
||||||
from app.core.rag.llm.embedding_model import OpenAIEmbed
|
from app.core.rag.llm.embedding_model import OpenAIEmbed
|
||||||
|
from app.services.model_service import ModelApiKeyService
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -198,16 +199,18 @@ def _retrieve_for_knowledge(
|
|||||||
workspace_ids.append(str(db_knowledge.workspace_id))
|
workspace_ids.append(str(db_knowledge.workspace_id))
|
||||||
|
|
||||||
if not chat_model:
|
if not chat_model:
|
||||||
|
llm_key = ModelApiKeyService.get_available_api_key(db, db_knowledge.llm_id)
|
||||||
chat_model = Base(
|
chat_model = Base(
|
||||||
key=db_knowledge.llm.api_keys[0].api_key,
|
key=llm_key.api_key,
|
||||||
model_name=db_knowledge.llm.api_keys[0].model_name,
|
model_name=llm_key.model_name,
|
||||||
base_url=db_knowledge.llm.api_keys[0].api_base,
|
base_url=llm_key.api_base,
|
||||||
)
|
)
|
||||||
if not embedding_model:
|
if not embedding_model:
|
||||||
|
emb_key = ModelApiKeyService.get_available_api_key(db, db_knowledge.embedding_id)
|
||||||
embedding_model = OpenAIEmbed(
|
embedding_model = OpenAIEmbed(
|
||||||
key=db_knowledge.embedding.api_keys[0].api_key,
|
key=emb_key.api_key,
|
||||||
model_name=db_knowledge.embedding.api_keys[0].model_name,
|
model_name=emb_key.model_name,
|
||||||
base_url=db_knowledge.embedding.api_keys[0].api_base,
|
base_url=emb_key.api_base,
|
||||||
)
|
)
|
||||||
|
|
||||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||||
@@ -248,6 +251,20 @@ def _retrieve_for_knowledge(
|
|||||||
seen_ids.add(doc.metadata["doc_id"])
|
seen_ids.add(doc.metadata["doc_id"])
|
||||||
unique_rs.append(doc)
|
unique_rs.append(doc)
|
||||||
rs = unique_rs
|
rs = unique_rs
|
||||||
|
if kb_config["retrieve_type"] == "graph":
|
||||||
|
try:
|
||||||
|
from app.core.rag.common.settings import kg_retriever
|
||||||
|
graph_doc = kg_retriever.retrieval(
|
||||||
|
question=kb_config["query"],
|
||||||
|
workspace_ids=[str(db_knowledge.workspace_id)],
|
||||||
|
kb_ids=[str(db_knowledge.id)],
|
||||||
|
emb_mdl=embedding_model,
|
||||||
|
llm=chat_model,
|
||||||
|
)
|
||||||
|
if graph_doc:
|
||||||
|
rs.insert(0, graph_doc)
|
||||||
|
except Exception as graph_error:
|
||||||
|
logger.warning(f"Graph retrieval failed for kb {db_knowledge.id}: {graph_error}")
|
||||||
|
|
||||||
results.extend(rs)
|
results.extend(rs)
|
||||||
return results, chat_model, embedding_model
|
return results, chat_model, embedding_model
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ from langchain_core.documents import Document
|
|||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.core.models import RedBearRerank, RedBearModelConfig
|
from app.core.models import RedBearRerank, RedBearModelConfig
|
||||||
|
from app.core.rag.llm.chat_model import Base
|
||||||
|
from app.core.rag.llm.embedding_model import OpenAIEmbed
|
||||||
from app.core.rag.models.chunk import DocumentChunk
|
from app.core.rag.models.chunk import DocumentChunk
|
||||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||||
from app.core.workflow.engine.state_manager import WorkflowState
|
from app.core.workflow.engine.state_manager import WorkflowState
|
||||||
@@ -39,8 +41,9 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
if isinstance(business_result, dict) and "chunks" in business_result:
|
if isinstance(business_result, dict) and "chunks" in business_result:
|
||||||
return business_result["chunks"]
|
return business_result["chunks"]
|
||||||
return business_result
|
return business_result
|
||||||
|
|
||||||
def _extract_citations(self, business_result: Any) -> list:
|
@staticmethod
|
||||||
|
def _extract_citations(business_result: Any) -> list:
|
||||||
if isinstance(business_result, dict):
|
if isinstance(business_result, dict):
|
||||||
return business_result.get("citations", [])
|
return business_result.get("citations", [])
|
||||||
return []
|
return []
|
||||||
@@ -230,23 +233,23 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
case RetrieveType.HYBRID:
|
case RetrieveType.HYBRID | RetrieveType.Graph:
|
||||||
rs1_task = asyncio.to_thread(
|
rs1_task = asyncio.to_thread(
|
||||||
vector_service.search_by_vector, **{
|
vector_service.search_by_vector, **{
|
||||||
"query": query,
|
"query": query,
|
||||||
"top_k": kb_config.top_k,
|
"top_k": kb_config.top_k,
|
||||||
"indices": indices,
|
"indices": indices,
|
||||||
"score_threshold": kb_config.vector_similarity_weight
|
"score_threshold": kb_config.vector_similarity_weight
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
rs2_task = asyncio.to_thread(
|
rs2_task = asyncio.to_thread(
|
||||||
vector_service.search_by_full_text, **{
|
vector_service.search_by_full_text, **{
|
||||||
"query": query,
|
"query": query,
|
||||||
"top_k": kb_config.top_k,
|
"top_k": kb_config.top_k,
|
||||||
"indices": indices,
|
"indices": indices,
|
||||||
"score_threshold": kb_config.similarity_threshold
|
"score_threshold": kb_config.similarity_threshold
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
rs1, rs2 = await asyncio.gather(rs1_task, rs2_task)
|
rs1, rs2 = await asyncio.gather(rs1_task, rs2_task)
|
||||||
|
|
||||||
# Deduplicate hybrid retrieval results
|
# Deduplicate hybrid retrieval results
|
||||||
@@ -266,6 +269,33 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
key=lambda d: d.metadata.get("score", 0),
|
key=lambda d: d.metadata.get("score", 0),
|
||||||
reverse=True
|
reverse=True
|
||||||
)[:kb_config.top_k])
|
)[:kb_config.top_k])
|
||||||
|
if kb_config.retrieve_type == RetrieveType.Graph:
|
||||||
|
from app.core.rag.common.settings import kg_retriever
|
||||||
|
llm_key = self.model_balance(db_knowledge.llm)
|
||||||
|
emb_key = self.model_balance(db_knowledge.embedding)
|
||||||
|
chat_model = Base(
|
||||||
|
key=llm_key.api_key,
|
||||||
|
model_name=llm_key.model_name,
|
||||||
|
base_url=llm_key.api_base
|
||||||
|
)
|
||||||
|
embedding_model = OpenAIEmbed(
|
||||||
|
key=emb_key.api_key,
|
||||||
|
model_name=emb_key.model_name,
|
||||||
|
base_url=emb_key.api_base
|
||||||
|
)
|
||||||
|
doc = await asyncio.to_thread(
|
||||||
|
kg_retriever.retrieval,
|
||||||
|
question=query,
|
||||||
|
workspace_ids=[str(db_knowledge.workspace_id)],
|
||||||
|
kb_ids=[str(kb_config.kb_id)],
|
||||||
|
emb_mdl=embedding_model,
|
||||||
|
llm=chat_model
|
||||||
|
)
|
||||||
|
if doc:
|
||||||
|
rs.insert(0, DocumentChunk(
|
||||||
|
page_content=doc.get("page_content", ""),
|
||||||
|
metadata=doc.get("metadata", {})
|
||||||
|
))
|
||||||
case _:
|
case _:
|
||||||
raise RuntimeError("Unknown retrieval type")
|
raise RuntimeError("Unknown retrieval type")
|
||||||
return rs
|
return rs
|
||||||
|
|||||||
Reference in New Issue
Block a user