From f52b68113348c8dc0d028e431b89f5f33e5ae22b Mon Sep 17 00:00:00 2001 From: lixiangcheng1 Date: Thu, 19 Mar 2026 08:17:58 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90fix]Nested=20query=20of=20folder=20kno?= =?UTF-8?q?wledge=20base?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/app/core/rag/nlp/search.py | 185 ++++++++++++------- api/app/repositories/knowledge_repository.py | 14 ++ 2 files changed, 133 insertions(+), 66 deletions(-) diff --git a/api/app/core/rag/nlp/search.py b/api/app/core/rag/nlp/search.py index 1f696c98..56f6ba47 100644 --- a/api/app/core/rag/nlp/search.py +++ b/api/app/core/rag/nlp/search.py @@ -84,72 +84,16 @@ def knowledge_retrieval( db_knowledge = knowledge_repository.get_knowledge_by_id(db, knowledge_id=kb_id) if db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1: # Process shared knowledge base - if db_knowledge.permission_id.lower() == knowledge_model.PermissionType.Share: - knowledgeshare = knowledgeshare_repository.get_knowledgeshare_by_id(db=db, - knowledgeshare_id=db_knowledge.id) - if knowledgeshare: - db_knowledge = knowledge_repository.get_knowledge_by_id(db, - knowledge_id=knowledgeshare.source_kb_id) - if not (db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1): - continue - else: - continue - - if str(db_knowledge.id) not in kb_ids: - kb_ids.append(str(db_knowledge.id)) - if str(db_knowledge.workspace_id) not in workspace_ids: - workspace_ids.append(str(db_knowledge.workspace_id)) - if not chat_model: - chat_model = Base( - key=db_knowledge.llm.api_keys[0].api_key, - model_name=db_knowledge.llm.api_keys[0].model_name, - base_url=db_knowledge.llm.api_keys[0].api_base - ) - if not embedding_model: - embedding_model = OpenAIEmbed( - key=db_knowledge.embedding.api_keys[0].api_key, - model_name=db_knowledge.embedding.api_keys[0].model_name, - base_url=db_knowledge.embedding.api_keys[0].api_base - ) - vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) - # Retrieve according to the configured retrieval type - match kb_config["retrieve_type"]: - case "participle": - rs = vector_service.search_by_full_text( - query=query, - top_k=kb_config["top_k"], - score_threshold=kb_config["similarity_threshold"], - file_names_filter=file_names_filter - ) - case "semantic": - rs = vector_service.search_by_vector( - query=query, - top_k=kb_config["top_k"], - score_threshold=kb_config["vector_similarity_weight"], - file_names_filter=file_names_filter - ) - case _: # hybrid - rs1 = vector_service.search_by_vector( - query=query, - top_k=kb_config["top_k"], - score_threshold=kb_config["vector_similarity_weight"], - file_names_filter=file_names_filter - ) - rs2 = vector_service.search_by_full_text( - query=query, - top_k=kb_config["top_k"], - score_threshold=kb_config["similarity_threshold"], - file_names_filter=file_names_filter - ) - - # Deduplication of merge results - seen_ids = set() - unique_rs = [] - for doc in rs1 + rs2: - if doc.metadata["doc_id"] not in seen_ids: - seen_ids.add(doc.metadata["doc_id"]) - unique_rs.append(doc) - rs = unique_rs + rs, chat_model, embedding_model = _retrieve_for_knowledge( + db=db, + db_knowledge=db_knowledge, + kb_config={**kb_config, "query": query}, # 或改为单独参数 + file_names_filter=file_names_filter, + chat_model=chat_model, + embedding_model=embedding_model, + kb_ids=kb_ids, + workspace_ids=workspace_ids, + ) all_results.extend(rs) except Exception as e: @@ -173,6 +117,115 @@ def knowledge_retrieval( finally: db.close() +def _retrieve_for_knowledge( + db: Session, + db_knowledge, + kb_config: Dict[str, Any], + file_names_filter: list[str], + chat_model: Base | None, + embedding_model: OpenAIEmbed | None, + kb_ids: list[str], + workspace_ids: list[str], +) -> tuple[list[DocumentChunk], Base | None, OpenAIEmbed | None]: + """ + 对单个知识库进行检索。 + - 处理共享知识库 + - 如果是 Folder,则递归检索其子知识库 + - 返回本知识库(含子库)的检索结果和可能更新后的 chat_model/embedding_model + """ + results: list[DocumentChunk] = [] + + # 处理共享知识库 + if db_knowledge.permission_id.lower() == knowledge_model.PermissionType.Share: + knowledgeshare = knowledgeshare_repository.get_knowledgeshare_by_id(db=db, knowledgeshare_id=db_knowledge.id) + if not knowledgeshare: + return results, chat_model, embedding_model + + db_knowledge = knowledge_repository.get_knowledge_by_id(db, knowledge_id=knowledgeshare.source_kb_id) + if not (db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1): + return results, chat_model, embedding_model + + # Folder 类型:递归处理子知识库 + if db_knowledge.type == knowledge_model.KnowledgeType.Folder: + children = knowledge_repository.get_knowledges_by_parent_id(db=db, parent_id=db_knowledge.id) + for child in children: + if not (child and child.chunk_num > 0 and child.status == 1): + continue + # 递归处理子知识库(子库如果还是 Folder,会继续往下) + child_results, chat_model, embedding_model = _retrieve_for_knowledge( + db=db, + db_knowledge=child, + kb_config=kb_config, + file_names_filter=file_names_filter, + chat_model=chat_model, + embedding_model=embedding_model, + kb_ids=kb_ids, + workspace_ids=workspace_ids, + ) + results.extend(child_results) + return results, chat_model, embedding_model + + # 普通知识库,执行一次检索 + if str(db_knowledge.id) not in kb_ids: + kb_ids.append(str(db_knowledge.id)) + if str(db_knowledge.workspace_id) not in workspace_ids: + workspace_ids.append(str(db_knowledge.workspace_id)) + + if not chat_model: + chat_model = Base( + key=db_knowledge.llm.api_keys[0].api_key, + model_name=db_knowledge.llm.api_keys[0].model_name, + base_url=db_knowledge.llm.api_keys[0].api_base, + ) + if not embedding_model: + embedding_model = OpenAIEmbed( + key=db_knowledge.embedding.api_keys[0].api_key, + model_name=db_knowledge.embedding.api_keys[0].model_name, + base_url=db_knowledge.embedding.api_keys[0].api_base, + ) + + vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) + + match kb_config["retrieve_type"]: + case "participle": + rs = vector_service.search_by_full_text( + query=kb_config["query"], # 或者直接把 query 作为额外参数传进来 + top_k=kb_config["top_k"], + score_threshold=kb_config["similarity_threshold"], + file_names_filter=file_names_filter, + ) + case "semantic": + rs = vector_service.search_by_vector( + query=kb_config["query"], + top_k=kb_config["top_k"], + score_threshold=kb_config["vector_similarity_weight"], + file_names_filter=file_names_filter, + ) + case _: + rs1 = vector_service.search_by_vector( + query=kb_config["query"], + top_k=kb_config["top_k"], + score_threshold=kb_config["vector_similarity_weight"], + file_names_filter=file_names_filter, + ) + rs2 = vector_service.search_by_full_text( + query=kb_config["query"], + top_k=kb_config["top_k"], + score_threshold=kb_config["similarity_threshold"], + file_names_filter=file_names_filter, + ) + # 合并去重 + seen_ids = set() + unique_rs = [] + for doc in rs1 + rs2: + if doc.metadata["doc_id"] not in seen_ids: + seen_ids.add(doc.metadata["doc_id"]) + unique_rs.append(doc) + rs = unique_rs + + results.extend(rs) + return results, chat_model, embedding_model + def rerank(db: Session, reranker_id: uuid, query: str, docs: list[DocumentChunk], top_k: int) -> list[DocumentChunk]: """ diff --git a/api/app/repositories/knowledge_repository.py b/api/app/repositories/knowledge_repository.py index 681d1c10..bc9681e1 100644 --- a/api/app/repositories/knowledge_repository.py +++ b/api/app/repositories/knowledge_repository.py @@ -111,6 +111,20 @@ def get_knowledge_by_id(db: Session, knowledge_id: uuid.UUID) -> Knowledge | Non raise +def get_knowledges_by_parent_id(db: Session, parent_id: uuid.UUID) -> list[Knowledge]: + db_logger.debug(f"Query knowledge bases based on parent ID: parent_id={parent_id}") + try: + knowledges = db.query(Knowledge).filter(Knowledge.parent_id == parent_id).all() + if knowledges: + db_logger.debug(f"Knowledge bases query successful: count={len(knowledges)} (parent_id: {parent_id})") + else: + db_logger.debug(f"No knowledge bases found for given parent: parent_id={parent_id}") + return knowledges + except Exception as e: + db_logger.error(f"Failed to query the knowledge bases based on parent ID: parent_id={parent_id} - {str(e)}") + raise + + def get_knowledge_by_name(db: Session, name: str, workspace_id: uuid.UUID) -> Knowledge | None: db_logger.debug(f"Query knowledge base based on name and workspace_id: name={name}, workspace_id={workspace_id}")