diff --git a/api/app/controllers/chunk_controller.py b/api/app/controllers/chunk_controller.py index 988aa706..b5c0a5ae 100644 --- a/api/app/controllers/chunk_controller.py +++ b/api/app/controllers/chunk_controller.py @@ -23,6 +23,7 @@ from app.models.user_model import User from app.schemas import chunk_schema from app.schemas.response_schema import ApiResponse from app.services import knowledge_service, document_service, file_service, knowledgeshare_service +from app.services.model_service import ModelApiKeyService # Obtain a dedicated API logger api_logger = get_api_logger() @@ -460,18 +461,20 @@ async def retrieve_chunks( if retrieve_data.retrieve_type == chunk_schema.RetrieveType.Graph: kb_ids = [str(kb_id) for kb_id in private_kb_ids] workspace_ids = [str(workspace_id) for workspace_id in private_workspace_ids] + llm_key = ModelApiKeyService.get_available_api_key(db, db_knowledge.llm_id) + emb_key = ModelApiKeyService.get_available_api_key(db, db_knowledge.embedding_id) # Prepare to configure chat_mdl、embedding_model、vision_model information chat_model = Base( - key=db_knowledge.llm.api_keys[0].api_key, - model_name=db_knowledge.llm.api_keys[0].model_name, - base_url=db_knowledge.llm.api_keys[0].api_base + key=llm_key.api_key, + model_name=llm_key.model_name, + base_url=llm_key.api_base ) embedding_model = OpenAIEmbed( - key=db_knowledge.embedding.api_keys[0].api_key, - model_name=db_knowledge.embedding.api_keys[0].model_name, - base_url=db_knowledge.embedding.api_keys[0].api_base + key=emb_key.api_key, + model_name=emb_key.model_name, + base_url=emb_key.api_base ) - doc = kg_retriever.retrieval(question=retrieve_data.query, workspace_ids=workspace_ids, kb_ids= kb_ids, emb_mdl=embedding_model, llm=chat_model) + doc = kg_retriever.retrieval(question=retrieve_data.query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model) if doc: rs.insert(0, doc) return success(data=jsonable_encoder(rs), msg="retrieval successful") \ No newline at end of file diff --git a/api/app/core/rag/nlp/search.py b/api/app/core/rag/nlp/search.py index db93bc48..1a84b8a7 100644 --- a/api/app/core/rag/nlp/search.py +++ b/api/app/core/rag/nlp/search.py @@ -28,6 +28,7 @@ from app.core.rag.common.float_utils import get_float from app.core.rag.common.constants import PAGERANK_FLD, TAG_FLD from app.core.rag.llm.chat_model import Base from app.core.rag.llm.embedding_model import OpenAIEmbed +from app.services.model_service import ModelApiKeyService import logging logger = logging.getLogger(__name__) @@ -198,16 +199,18 @@ def _retrieve_for_knowledge( workspace_ids.append(str(db_knowledge.workspace_id)) if not chat_model: + llm_key = ModelApiKeyService.get_available_api_key(db, db_knowledge.llm_id) chat_model = Base( - key=db_knowledge.llm.api_keys[0].api_key, - model_name=db_knowledge.llm.api_keys[0].model_name, - base_url=db_knowledge.llm.api_keys[0].api_base, + key=llm_key.api_key, + model_name=llm_key.model_name, + base_url=llm_key.api_base, ) if not embedding_model: + emb_key = ModelApiKeyService.get_available_api_key(db, db_knowledge.embedding_id) embedding_model = OpenAIEmbed( - key=db_knowledge.embedding.api_keys[0].api_key, - model_name=db_knowledge.embedding.api_keys[0].model_name, - base_url=db_knowledge.embedding.api_keys[0].api_base, + key=emb_key.api_key, + model_name=emb_key.model_name, + base_url=emb_key.api_base, ) vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) @@ -248,6 +251,20 @@ def _retrieve_for_knowledge( seen_ids.add(doc.metadata["doc_id"]) unique_rs.append(doc) rs = unique_rs + if kb_config["retrieve_type"] == "graph": + try: + from app.core.rag.common.settings import kg_retriever + graph_doc = kg_retriever.retrieval( + question=kb_config["query"], + workspace_ids=[str(db_knowledge.workspace_id)], + kb_ids=[str(db_knowledge.id)], + emb_mdl=embedding_model, + llm=chat_model, + ) + if graph_doc: + rs.insert(0, graph_doc) + except Exception as graph_error: + logger.warning(f"Graph retrieval failed for kb {db_knowledge.id}: {graph_error}") results.extend(rs) return results, chat_model, embedding_model diff --git a/api/app/core/workflow/nodes/knowledge/node.py b/api/app/core/workflow/nodes/knowledge/node.py index 97fa86cb..29e46902 100644 --- a/api/app/core/workflow/nodes/knowledge/node.py +++ b/api/app/core/workflow/nodes/knowledge/node.py @@ -8,6 +8,8 @@ from langchain_core.documents import Document from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.models import RedBearRerank, RedBearModelConfig +from app.core.rag.llm.chat_model import Base +from app.core.rag.llm.embedding_model import OpenAIEmbed from app.core.rag.models.chunk import DocumentChunk from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory from app.core.workflow.engine.state_manager import WorkflowState @@ -39,8 +41,9 @@ class KnowledgeRetrievalNode(BaseNode): if isinstance(business_result, dict) and "chunks" in business_result: return business_result["chunks"] return business_result - - def _extract_citations(self, business_result: Any) -> list: + + @staticmethod + def _extract_citations(business_result: Any) -> list: if isinstance(business_result, dict): return business_result.get("citations", []) return [] @@ -230,23 +233,23 @@ class KnowledgeRetrievalNode(BaseNode): } ) ) - case RetrieveType.HYBRID: + case RetrieveType.HYBRID | RetrieveType.Graph: rs1_task = asyncio.to_thread( - vector_service.search_by_vector, **{ - "query": query, - "top_k": kb_config.top_k, - "indices": indices, - "score_threshold": kb_config.vector_similarity_weight - } - ) + vector_service.search_by_vector, **{ + "query": query, + "top_k": kb_config.top_k, + "indices": indices, + "score_threshold": kb_config.vector_similarity_weight + } + ) rs2_task = asyncio.to_thread( - vector_service.search_by_full_text, **{ - "query": query, - "top_k": kb_config.top_k, - "indices": indices, - "score_threshold": kb_config.similarity_threshold - } - ) + vector_service.search_by_full_text, **{ + "query": query, + "top_k": kb_config.top_k, + "indices": indices, + "score_threshold": kb_config.similarity_threshold + } + ) rs1, rs2 = await asyncio.gather(rs1_task, rs2_task) # Deduplicate hybrid retrieval results @@ -266,6 +269,33 @@ class KnowledgeRetrievalNode(BaseNode): key=lambda d: d.metadata.get("score", 0), reverse=True )[:kb_config.top_k]) + if kb_config.retrieve_type == RetrieveType.Graph: + from app.core.rag.common.settings import kg_retriever + llm_key = self.model_balance(db_knowledge.llm) + emb_key = self.model_balance(db_knowledge.embedding) + chat_model = Base( + key=llm_key.api_key, + model_name=llm_key.model_name, + base_url=llm_key.api_base + ) + embedding_model = OpenAIEmbed( + key=emb_key.api_key, + model_name=emb_key.model_name, + base_url=emb_key.api_base + ) + doc = await asyncio.to_thread( + kg_retriever.retrieval, + question=query, + workspace_ids=[str(db_knowledge.workspace_id)], + kb_ids=[str(kb_config.kb_id)], + emb_mdl=embedding_model, + llm=chat_model + ) + if doc: + rs.insert(0, DocumentChunk( + page_content=doc.get("page_content", ""), + metadata=doc.get("metadata", {}) + )) case _: raise RuntimeError("Unknown retrieval type") return rs