[fix] Nested query of folder knowledge base
This commit is contained in:
@@ -84,72 +84,16 @@ def knowledge_retrieval(
|
|||||||
db_knowledge = knowledge_repository.get_knowledge_by_id(db, knowledge_id=kb_id)
|
db_knowledge = knowledge_repository.get_knowledge_by_id(db, knowledge_id=kb_id)
|
||||||
if db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1:
|
if db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1:
|
||||||
# Process shared knowledge base
|
# Process shared knowledge base
|
||||||
if db_knowledge.permission_id.lower() == knowledge_model.PermissionType.Share:
|
rs, chat_model, embedding_model = _retrieve_for_knowledge(
|
||||||
knowledgeshare = knowledgeshare_repository.get_knowledgeshare_by_id(db=db,
|
db=db,
|
||||||
knowledgeshare_id=db_knowledge.id)
|
db_knowledge=db_knowledge,
|
||||||
if knowledgeshare:
|
kb_config={**kb_config, "query": query}, # 或改为单独参数
|
||||||
db_knowledge = knowledge_repository.get_knowledge_by_id(db,
|
file_names_filter=file_names_filter,
|
||||||
knowledge_id=knowledgeshare.source_kb_id)
|
chat_model=chat_model,
|
||||||
if not (db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1):
|
embedding_model=embedding_model,
|
||||||
continue
|
kb_ids=kb_ids,
|
||||||
else:
|
workspace_ids=workspace_ids,
|
||||||
continue
|
)
|
||||||
|
|
||||||
if str(db_knowledge.id) not in kb_ids:
|
|
||||||
kb_ids.append(str(db_knowledge.id))
|
|
||||||
if str(db_knowledge.workspace_id) not in workspace_ids:
|
|
||||||
workspace_ids.append(str(db_knowledge.workspace_id))
|
|
||||||
if not chat_model:
|
|
||||||
chat_model = Base(
|
|
||||||
key=db_knowledge.llm.api_keys[0].api_key,
|
|
||||||
model_name=db_knowledge.llm.api_keys[0].model_name,
|
|
||||||
base_url=db_knowledge.llm.api_keys[0].api_base
|
|
||||||
)
|
|
||||||
if not embedding_model:
|
|
||||||
embedding_model = OpenAIEmbed(
|
|
||||||
key=db_knowledge.embedding.api_keys[0].api_key,
|
|
||||||
model_name=db_knowledge.embedding.api_keys[0].model_name,
|
|
||||||
base_url=db_knowledge.embedding.api_keys[0].api_base
|
|
||||||
)
|
|
||||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
|
||||||
# Retrieve according to the configured retrieval type
|
|
||||||
match kb_config["retrieve_type"]:
|
|
||||||
case "participle":
|
|
||||||
rs = vector_service.search_by_full_text(
|
|
||||||
query=query,
|
|
||||||
top_k=kb_config["top_k"],
|
|
||||||
score_threshold=kb_config["similarity_threshold"],
|
|
||||||
file_names_filter=file_names_filter
|
|
||||||
)
|
|
||||||
case "semantic":
|
|
||||||
rs = vector_service.search_by_vector(
|
|
||||||
query=query,
|
|
||||||
top_k=kb_config["top_k"],
|
|
||||||
score_threshold=kb_config["vector_similarity_weight"],
|
|
||||||
file_names_filter=file_names_filter
|
|
||||||
)
|
|
||||||
case _: # hybrid
|
|
||||||
rs1 = vector_service.search_by_vector(
|
|
||||||
query=query,
|
|
||||||
top_k=kb_config["top_k"],
|
|
||||||
score_threshold=kb_config["vector_similarity_weight"],
|
|
||||||
file_names_filter=file_names_filter
|
|
||||||
)
|
|
||||||
rs2 = vector_service.search_by_full_text(
|
|
||||||
query=query,
|
|
||||||
top_k=kb_config["top_k"],
|
|
||||||
score_threshold=kb_config["similarity_threshold"],
|
|
||||||
file_names_filter=file_names_filter
|
|
||||||
)
|
|
||||||
|
|
||||||
# Deduplication of merge results
|
|
||||||
seen_ids = set()
|
|
||||||
unique_rs = []
|
|
||||||
for doc in rs1 + rs2:
|
|
||||||
if doc.metadata["doc_id"] not in seen_ids:
|
|
||||||
seen_ids.add(doc.metadata["doc_id"])
|
|
||||||
unique_rs.append(doc)
|
|
||||||
rs = unique_rs
|
|
||||||
|
|
||||||
all_results.extend(rs)
|
all_results.extend(rs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -173,6 +117,115 @@ def knowledge_retrieval(
|
|||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
def _retrieve_for_knowledge(
|
||||||
|
db: Session,
|
||||||
|
db_knowledge,
|
||||||
|
kb_config: Dict[str, Any],
|
||||||
|
file_names_filter: list[str],
|
||||||
|
chat_model: Base | None,
|
||||||
|
embedding_model: OpenAIEmbed | None,
|
||||||
|
kb_ids: list[str],
|
||||||
|
workspace_ids: list[str],
|
||||||
|
) -> tuple[list[DocumentChunk], Base | None, OpenAIEmbed | None]:
|
||||||
|
"""
|
||||||
|
对单个知识库进行检索。
|
||||||
|
- 处理共享知识库
|
||||||
|
- 如果是 Folder,则递归检索其子知识库
|
||||||
|
- 返回本知识库(含子库)的检索结果和可能更新后的 chat_model/embedding_model
|
||||||
|
"""
|
||||||
|
results: list[DocumentChunk] = []
|
||||||
|
|
||||||
|
# 处理共享知识库
|
||||||
|
if db_knowledge.permission_id.lower() == knowledge_model.PermissionType.Share:
|
||||||
|
knowledgeshare = knowledgeshare_repository.get_knowledgeshare_by_id(db=db, knowledgeshare_id=db_knowledge.id)
|
||||||
|
if not knowledgeshare:
|
||||||
|
return results, chat_model, embedding_model
|
||||||
|
|
||||||
|
db_knowledge = knowledge_repository.get_knowledge_by_id(db, knowledge_id=knowledgeshare.source_kb_id)
|
||||||
|
if not (db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1):
|
||||||
|
return results, chat_model, embedding_model
|
||||||
|
|
||||||
|
# Folder 类型:递归处理子知识库
|
||||||
|
if db_knowledge.type == knowledge_model.KnowledgeType.Folder:
|
||||||
|
children = knowledge_repository.get_knowledges_by_parent_id(db=db, parent_id=db_knowledge.id)
|
||||||
|
for child in children:
|
||||||
|
if not (child and child.chunk_num > 0 and child.status == 1):
|
||||||
|
continue
|
||||||
|
# 递归处理子知识库(子库如果还是 Folder,会继续往下)
|
||||||
|
child_results, chat_model, embedding_model = _retrieve_for_knowledge(
|
||||||
|
db=db,
|
||||||
|
db_knowledge=child,
|
||||||
|
kb_config=kb_config,
|
||||||
|
file_names_filter=file_names_filter,
|
||||||
|
chat_model=chat_model,
|
||||||
|
embedding_model=embedding_model,
|
||||||
|
kb_ids=kb_ids,
|
||||||
|
workspace_ids=workspace_ids,
|
||||||
|
)
|
||||||
|
results.extend(child_results)
|
||||||
|
return results, chat_model, embedding_model
|
||||||
|
|
||||||
|
# 普通知识库,执行一次检索
|
||||||
|
if str(db_knowledge.id) not in kb_ids:
|
||||||
|
kb_ids.append(str(db_knowledge.id))
|
||||||
|
if str(db_knowledge.workspace_id) not in workspace_ids:
|
||||||
|
workspace_ids.append(str(db_knowledge.workspace_id))
|
||||||
|
|
||||||
|
if not chat_model:
|
||||||
|
chat_model = Base(
|
||||||
|
key=db_knowledge.llm.api_keys[0].api_key,
|
||||||
|
model_name=db_knowledge.llm.api_keys[0].model_name,
|
||||||
|
base_url=db_knowledge.llm.api_keys[0].api_base,
|
||||||
|
)
|
||||||
|
if not embedding_model:
|
||||||
|
embedding_model = OpenAIEmbed(
|
||||||
|
key=db_knowledge.embedding.api_keys[0].api_key,
|
||||||
|
model_name=db_knowledge.embedding.api_keys[0].model_name,
|
||||||
|
base_url=db_knowledge.embedding.api_keys[0].api_base,
|
||||||
|
)
|
||||||
|
|
||||||
|
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||||
|
|
||||||
|
match kb_config["retrieve_type"]:
|
||||||
|
case "participle":
|
||||||
|
rs = vector_service.search_by_full_text(
|
||||||
|
query=kb_config["query"], # 或者直接把 query 作为额外参数传进来
|
||||||
|
top_k=kb_config["top_k"],
|
||||||
|
score_threshold=kb_config["similarity_threshold"],
|
||||||
|
file_names_filter=file_names_filter,
|
||||||
|
)
|
||||||
|
case "semantic":
|
||||||
|
rs = vector_service.search_by_vector(
|
||||||
|
query=kb_config["query"],
|
||||||
|
top_k=kb_config["top_k"],
|
||||||
|
score_threshold=kb_config["vector_similarity_weight"],
|
||||||
|
file_names_filter=file_names_filter,
|
||||||
|
)
|
||||||
|
case _:
|
||||||
|
rs1 = vector_service.search_by_vector(
|
||||||
|
query=kb_config["query"],
|
||||||
|
top_k=kb_config["top_k"],
|
||||||
|
score_threshold=kb_config["vector_similarity_weight"],
|
||||||
|
file_names_filter=file_names_filter,
|
||||||
|
)
|
||||||
|
rs2 = vector_service.search_by_full_text(
|
||||||
|
query=kb_config["query"],
|
||||||
|
top_k=kb_config["top_k"],
|
||||||
|
score_threshold=kb_config["similarity_threshold"],
|
||||||
|
file_names_filter=file_names_filter,
|
||||||
|
)
|
||||||
|
# 合并去重
|
||||||
|
seen_ids = set()
|
||||||
|
unique_rs = []
|
||||||
|
for doc in rs1 + rs2:
|
||||||
|
if doc.metadata["doc_id"] not in seen_ids:
|
||||||
|
seen_ids.add(doc.metadata["doc_id"])
|
||||||
|
unique_rs.append(doc)
|
||||||
|
rs = unique_rs
|
||||||
|
|
||||||
|
results.extend(rs)
|
||||||
|
return results, chat_model, embedding_model
|
||||||
|
|
||||||
|
|
||||||
def rerank(db: Session, reranker_id: uuid, query: str, docs: list[DocumentChunk], top_k: int) -> list[DocumentChunk]:
|
def rerank(db: Session, reranker_id: uuid, query: str, docs: list[DocumentChunk], top_k: int) -> list[DocumentChunk]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -111,6 +111,20 @@ def get_knowledge_by_id(db: Session, knowledge_id: uuid.UUID) -> Knowledge | Non
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def get_knowledges_by_parent_id(db: Session, parent_id: uuid.UUID) -> list[Knowledge]:
    """Return every knowledge base whose ``parent_id`` equals ``parent_id``.

    Args:
        db: Database session to query with.
        parent_id: Id of the parent knowledge base.

    Returns:
        The matching ``Knowledge`` rows; an empty list when there are none.

    Raises:
        Exception: Any error raised while querying is logged and re-raised.
    """
    db_logger.debug(f"Query knowledge bases based on parent ID: parent_id={parent_id}")
    try:
        children_query = db.query(Knowledge).filter(Knowledge.parent_id == parent_id)
        knowledges = children_query.all()

        # Guard clause for the empty case; the success path follows.
        if not knowledges:
            db_logger.debug(f"No knowledge bases found for given parent: parent_id={parent_id}")
            return knowledges

        db_logger.debug(f"Knowledge bases query successful: count={len(knowledges)} (parent_id: {parent_id})")
        return knowledges
    except Exception as e:
        db_logger.error(f"Failed to query the knowledge bases based on parent ID: parent_id={parent_id} - {str(e)}")
        raise
|
||||||
|
|
||||||
|
|
||||||
def get_knowledge_by_name(db: Session, name: str, workspace_id: uuid.UUID) -> Knowledge | None:
|
def get_knowledge_by_name(db: Session, name: str, workspace_id: uuid.UUID) -> Knowledge | None:
|
||||||
db_logger.debug(f"Query knowledge base based on name and workspace_id: name={name}, workspace_id={workspace_id}")
|
db_logger.debug(f"Query knowledge base based on name and workspace_id: name={name}, workspace_id={workspace_id}")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user