feat(knowledge): support graph retrieval type with dynamic API key selection

This commit is contained in:
Timebomb2018
2026-04-09 15:00:49 +08:00
parent edac6a164e
commit 70aab94fc3
3 changed files with 80 additions and 30 deletions

View File

@@ -23,6 +23,7 @@ from app.models.user_model import User
from app.schemas import chunk_schema from app.schemas import chunk_schema
from app.schemas.response_schema import ApiResponse from app.schemas.response_schema import ApiResponse
from app.services import knowledge_service, document_service, file_service, knowledgeshare_service from app.services import knowledge_service, document_service, file_service, knowledgeshare_service
from app.services.model_service import ModelApiKeyService
# Obtain a dedicated API logger # Obtain a dedicated API logger
api_logger = get_api_logger() api_logger = get_api_logger()
@@ -460,18 +461,20 @@ async def retrieve_chunks(
if retrieve_data.retrieve_type == chunk_schema.RetrieveType.Graph: if retrieve_data.retrieve_type == chunk_schema.RetrieveType.Graph:
kb_ids = [str(kb_id) for kb_id in private_kb_ids] kb_ids = [str(kb_id) for kb_id in private_kb_ids]
workspace_ids = [str(workspace_id) for workspace_id in private_workspace_ids] workspace_ids = [str(workspace_id) for workspace_id in private_workspace_ids]
llm_key = ModelApiKeyService.get_available_api_key(db, db_knowledge.llm_id)
emb_key = ModelApiKeyService.get_available_api_key(db, db_knowledge.embedding_id)
# Prepare the configuration for chat_mdl, embedding_model, and vision_model # Prepare the configuration for chat_mdl, embedding_model, and vision_model
chat_model = Base( chat_model = Base(
key=db_knowledge.llm.api_keys[0].api_key, key=llm_key.api_key,
model_name=db_knowledge.llm.api_keys[0].model_name, model_name=llm_key.model_name,
base_url=db_knowledge.llm.api_keys[0].api_base base_url=llm_key.api_base
) )
embedding_model = OpenAIEmbed( embedding_model = OpenAIEmbed(
key=db_knowledge.embedding.api_keys[0].api_key, key=emb_key.api_key,
model_name=db_knowledge.embedding.api_keys[0].model_name, model_name=emb_key.model_name,
base_url=db_knowledge.embedding.api_keys[0].api_base base_url=emb_key.api_base
) )
doc = kg_retriever.retrieval(question=retrieve_data.query, workspace_ids=workspace_ids, kb_ids= kb_ids, emb_mdl=embedding_model, llm=chat_model) doc = kg_retriever.retrieval(question=retrieve_data.query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model)
if doc: if doc:
rs.insert(0, doc) rs.insert(0, doc)
return success(data=jsonable_encoder(rs), msg="retrieval successful") return success(data=jsonable_encoder(rs), msg="retrieval successful")

View File

@@ -28,6 +28,7 @@ from app.core.rag.common.float_utils import get_float
from app.core.rag.common.constants import PAGERANK_FLD, TAG_FLD from app.core.rag.common.constants import PAGERANK_FLD, TAG_FLD
from app.core.rag.llm.chat_model import Base from app.core.rag.llm.chat_model import Base
from app.core.rag.llm.embedding_model import OpenAIEmbed from app.core.rag.llm.embedding_model import OpenAIEmbed
from app.services.model_service import ModelApiKeyService
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -198,16 +199,18 @@ def _retrieve_for_knowledge(
workspace_ids.append(str(db_knowledge.workspace_id)) workspace_ids.append(str(db_knowledge.workspace_id))
if not chat_model: if not chat_model:
llm_key = ModelApiKeyService.get_available_api_key(db, db_knowledge.llm_id)
chat_model = Base( chat_model = Base(
key=db_knowledge.llm.api_keys[0].api_key, key=llm_key.api_key,
model_name=db_knowledge.llm.api_keys[0].model_name, model_name=llm_key.model_name,
base_url=db_knowledge.llm.api_keys[0].api_base, base_url=llm_key.api_base,
) )
if not embedding_model: if not embedding_model:
emb_key = ModelApiKeyService.get_available_api_key(db, db_knowledge.embedding_id)
embedding_model = OpenAIEmbed( embedding_model = OpenAIEmbed(
key=db_knowledge.embedding.api_keys[0].api_key, key=emb_key.api_key,
model_name=db_knowledge.embedding.api_keys[0].model_name, model_name=emb_key.model_name,
base_url=db_knowledge.embedding.api_keys[0].api_base, base_url=emb_key.api_base,
) )
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
@@ -248,6 +251,20 @@ def _retrieve_for_knowledge(
seen_ids.add(doc.metadata["doc_id"]) seen_ids.add(doc.metadata["doc_id"])
unique_rs.append(doc) unique_rs.append(doc)
rs = unique_rs rs = unique_rs
if kb_config["retrieve_type"] == "graph":
try:
from app.core.rag.common.settings import kg_retriever
graph_doc = kg_retriever.retrieval(
question=kb_config["query"],
workspace_ids=[str(db_knowledge.workspace_id)],
kb_ids=[str(db_knowledge.id)],
emb_mdl=embedding_model,
llm=chat_model,
)
if graph_doc:
rs.insert(0, graph_doc)
except Exception as graph_error:
logger.warning(f"Graph retrieval failed for kb {db_knowledge.id}: {graph_error}")
results.extend(rs) results.extend(rs)
return results, chat_model, embedding_model return results, chat_model, embedding_model

View File

@@ -8,6 +8,8 @@ from langchain_core.documents import Document
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
from app.core.models import RedBearRerank, RedBearModelConfig from app.core.models import RedBearRerank, RedBearModelConfig
from app.core.rag.llm.chat_model import Base
from app.core.rag.llm.embedding_model import OpenAIEmbed
from app.core.rag.models.chunk import DocumentChunk from app.core.rag.models.chunk import DocumentChunk
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.workflow.engine.state_manager import WorkflowState from app.core.workflow.engine.state_manager import WorkflowState
@@ -40,7 +42,8 @@ class KnowledgeRetrievalNode(BaseNode):
return business_result["chunks"] return business_result["chunks"]
return business_result return business_result
def _extract_citations(self, business_result: Any) -> list: @staticmethod
def _extract_citations(business_result: Any) -> list:
if isinstance(business_result, dict): if isinstance(business_result, dict):
return business_result.get("citations", []) return business_result.get("citations", [])
return [] return []
@@ -230,23 +233,23 @@ class KnowledgeRetrievalNode(BaseNode):
} }
) )
) )
case RetrieveType.HYBRID: case RetrieveType.HYBRID | RetrieveType.Graph:
rs1_task = asyncio.to_thread( rs1_task = asyncio.to_thread(
vector_service.search_by_vector, **{ vector_service.search_by_vector, **{
"query": query, "query": query,
"top_k": kb_config.top_k, "top_k": kb_config.top_k,
"indices": indices, "indices": indices,
"score_threshold": kb_config.vector_similarity_weight "score_threshold": kb_config.vector_similarity_weight
} }
) )
rs2_task = asyncio.to_thread( rs2_task = asyncio.to_thread(
vector_service.search_by_full_text, **{ vector_service.search_by_full_text, **{
"query": query, "query": query,
"top_k": kb_config.top_k, "top_k": kb_config.top_k,
"indices": indices, "indices": indices,
"score_threshold": kb_config.similarity_threshold "score_threshold": kb_config.similarity_threshold
} }
) )
rs1, rs2 = await asyncio.gather(rs1_task, rs2_task) rs1, rs2 = await asyncio.gather(rs1_task, rs2_task)
# Deduplicate hybrid retrieval results # Deduplicate hybrid retrieval results
@@ -266,6 +269,33 @@ class KnowledgeRetrievalNode(BaseNode):
key=lambda d: d.metadata.get("score", 0), key=lambda d: d.metadata.get("score", 0),
reverse=True reverse=True
)[:kb_config.top_k]) )[:kb_config.top_k])
if kb_config.retrieve_type == RetrieveType.Graph:
from app.core.rag.common.settings import kg_retriever
llm_key = self.model_balance(db_knowledge.llm)
emb_key = self.model_balance(db_knowledge.embedding)
chat_model = Base(
key=llm_key.api_key,
model_name=llm_key.model_name,
base_url=llm_key.api_base
)
embedding_model = OpenAIEmbed(
key=emb_key.api_key,
model_name=emb_key.model_name,
base_url=emb_key.api_base
)
doc = await asyncio.to_thread(
kg_retriever.retrieval,
question=query,
workspace_ids=[str(db_knowledge.workspace_id)],
kb_ids=[str(kb_config.kb_id)],
emb_mdl=embedding_model,
llm=chat_model
)
if doc:
rs.insert(0, DocumentChunk(
page_content=doc.get("page_content", ""),
metadata=doc.get("metadata", {})
))
case _: case _:
raise RuntimeError("Unknown retrieval type") raise RuntimeError("Unknown retrieval type")
return rs return rs