Merge branch 'feature/knowledge_lxc' into develop
This commit is contained in:
@@ -94,72 +94,16 @@ def knowledge_retrieval(
|
||||
db_knowledge = knowledge_repository.get_knowledge_by_id(db, knowledge_id=kb_id)
|
||||
if db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1:
|
||||
# Process shared knowledge base
|
||||
if db_knowledge.permission_id.lower() == knowledge_model.PermissionType.Share:
|
||||
knowledgeshare = knowledgeshare_repository.get_knowledgeshare_by_id(db=db,
|
||||
knowledgeshare_id=db_knowledge.id)
|
||||
if knowledgeshare:
|
||||
db_knowledge = knowledge_repository.get_knowledge_by_id(db,
|
||||
knowledge_id=knowledgeshare.source_kb_id)
|
||||
if not (db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1):
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
|
||||
if str(db_knowledge.id) not in kb_ids:
|
||||
kb_ids.append(str(db_knowledge.id))
|
||||
if str(db_knowledge.workspace_id) not in workspace_ids:
|
||||
workspace_ids.append(str(db_knowledge.workspace_id))
|
||||
if not chat_model:
|
||||
chat_model = Base(
|
||||
key=db_knowledge.llm.api_keys[0].api_key,
|
||||
model_name=db_knowledge.llm.api_keys[0].model_name,
|
||||
base_url=db_knowledge.llm.api_keys[0].api_base
|
||||
)
|
||||
if not embedding_model:
|
||||
embedding_model = OpenAIEmbed(
|
||||
key=db_knowledge.embedding.api_keys[0].api_key,
|
||||
model_name=db_knowledge.embedding.api_keys[0].model_name,
|
||||
base_url=db_knowledge.embedding.api_keys[0].api_base
|
||||
)
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
# Retrieve according to the configured retrieval type
|
||||
match kb_config["retrieve_type"]:
|
||||
case "participle":
|
||||
rs = vector_service.search_by_full_text(
|
||||
query=query,
|
||||
top_k=kb_config["top_k"],
|
||||
score_threshold=kb_config["similarity_threshold"],
|
||||
file_names_filter=file_names_filter
|
||||
)
|
||||
case "semantic":
|
||||
rs = vector_service.search_by_vector(
|
||||
query=query,
|
||||
top_k=kb_config["top_k"],
|
||||
score_threshold=kb_config["vector_similarity_weight"],
|
||||
file_names_filter=file_names_filter
|
||||
)
|
||||
case _: # hybrid
|
||||
rs1 = vector_service.search_by_vector(
|
||||
query=query,
|
||||
top_k=kb_config["top_k"],
|
||||
score_threshold=kb_config["vector_similarity_weight"],
|
||||
file_names_filter=file_names_filter
|
||||
)
|
||||
rs2 = vector_service.search_by_full_text(
|
||||
query=query,
|
||||
top_k=kb_config["top_k"],
|
||||
score_threshold=kb_config["similarity_threshold"],
|
||||
file_names_filter=file_names_filter
|
||||
)
|
||||
|
||||
# Deduplication of merge results
|
||||
seen_ids = set()
|
||||
unique_rs = []
|
||||
for doc in rs1 + rs2:
|
||||
if doc.metadata["doc_id"] not in seen_ids:
|
||||
seen_ids.add(doc.metadata["doc_id"])
|
||||
unique_rs.append(doc)
|
||||
rs = unique_rs
|
||||
rs, chat_model, embedding_model = _retrieve_for_knowledge(
|
||||
db=db,
|
||||
db_knowledge=db_knowledge,
|
||||
kb_config={**kb_config, "query": query}, # 或改为单独参数
|
||||
file_names_filter=file_names_filter,
|
||||
chat_model=chat_model,
|
||||
embedding_model=embedding_model,
|
||||
kb_ids=kb_ids,
|
||||
workspace_ids=workspace_ids,
|
||||
)
|
||||
|
||||
all_results.extend(rs)
|
||||
except Exception as e:
|
||||
@@ -199,6 +143,115 @@ def knowledge_retrieval(
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def _retrieve_for_knowledge(
    db: Session,
    db_knowledge,
    kb_config: Dict[str, Any],
    file_names_filter: list[str],
    chat_model: Base | None,
    embedding_model: OpenAIEmbed | None,
    kb_ids: list[str],
    workspace_ids: list[str],
) -> tuple[list[DocumentChunk], Base | None, OpenAIEmbed | None]:
    """Retrieve document chunks from a single knowledge base.

    - A shared knowledge base is first resolved to its source KB.
    - A Folder-type KB is handled by recursing into each child
      (a child that is itself a Folder keeps descending).
    - ``kb_ids`` and ``workspace_ids`` are mutated in place: the id of
      every KB actually searched is appended exactly once.
    - ``chat_model`` / ``embedding_model`` are lazily created from the
      first KB that reaches the retrieval step; the possibly-updated
      instances are returned so the caller can reuse them.

    Args:
        db: Active database session.
        db_knowledge: Knowledge ORM object to search (caller has already
            checked it is non-empty and active).
        kb_config: Retrieval settings: ``retrieve_type``, ``top_k``,
            ``similarity_threshold``, ``vector_similarity_weight``, and
            the search ``query`` itself (merged in by the caller —
            original TODO suggests making query a separate parameter).
        file_names_filter: File names to restrict the search to.
        chat_model: Existing chat model, or None to create one here.
        embedding_model: Existing embedding model, or None to create one.
        kb_ids: Accumulator of searched knowledge-base ids (mutated).
        workspace_ids: Accumulator of workspace ids (mutated).

    Returns:
        Tuple of (retrieved chunks, chat_model, embedding_model).
    """
    results: list[DocumentChunk] = []

    # Resolve a shared knowledge base to the source KB it points at.
    if db_knowledge.permission_id.lower() == knowledge_model.PermissionType.Share:
        knowledgeshare = knowledgeshare_repository.get_knowledgeshare_by_id(db=db, knowledgeshare_id=db_knowledge.id)
        if not knowledgeshare:
            return results, chat_model, embedding_model

        db_knowledge = knowledge_repository.get_knowledge_by_id(db, knowledge_id=knowledgeshare.source_kb_id)
        # The source KB must exist, contain chunks, and be active.
        if not (db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1):
            return results, chat_model, embedding_model

    # Folder type: recurse into child knowledge bases.
    # NOTE(review): no cycle guard here — a parent/child cycle in the data
    # would recurse forever; confirm the schema prevents cycles.
    if db_knowledge.type == knowledge_model.KnowledgeType.Folder:
        children = knowledge_repository.get_knowledges_by_parent_id(db=db, parent_id=db_knowledge.id)
        for child in children:
            # Skip empty or inactive children.
            if not (child and child.chunk_num > 0 and child.status == 1):
                continue
            # Recurse into the child (nested Folders keep descending).
            child_results, chat_model, embedding_model = _retrieve_for_knowledge(
                db=db,
                db_knowledge=child,
                kb_config=kb_config,
                file_names_filter=file_names_filter,
                chat_model=chat_model,
                embedding_model=embedding_model,
                kb_ids=kb_ids,
                workspace_ids=workspace_ids,
            )
            results.extend(child_results)
        return results, chat_model, embedding_model

    # Plain knowledge base: perform a single retrieval.
    if str(db_knowledge.id) not in kb_ids:
        kb_ids.append(str(db_knowledge.id))
    if str(db_knowledge.workspace_id) not in workspace_ids:
        workspace_ids.append(str(db_knowledge.workspace_id))

    # Lazily build the models from this KB's first configured API key.
    if not chat_model:
        chat_model = Base(
            key=db_knowledge.llm.api_keys[0].api_key,
            model_name=db_knowledge.llm.api_keys[0].model_name,
            base_url=db_knowledge.llm.api_keys[0].api_base,
        )
    if not embedding_model:
        embedding_model = OpenAIEmbed(
            key=db_knowledge.embedding.api_keys[0].api_key,
            model_name=db_knowledge.embedding.api_keys[0].model_name,
            base_url=db_knowledge.embedding.api_keys[0].api_base,
        )

    vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)

    # Dispatch on the configured retrieval type.
    match kb_config["retrieve_type"]:
        case "participle":  # full-text (keyword) search
            rs = vector_service.search_by_full_text(
                query=kb_config["query"],  # or pass query in directly as an extra parameter
                top_k=kb_config["top_k"],
                score_threshold=kb_config["similarity_threshold"],
                file_names_filter=file_names_filter,
            )
        case "semantic":  # vector (embedding) search
            rs = vector_service.search_by_vector(
                query=kb_config["query"],
                top_k=kb_config["top_k"],
                score_threshold=kb_config["vector_similarity_weight"],
                file_names_filter=file_names_filter,
            )
        case _:  # hybrid: run both searches, then merge
            rs1 = vector_service.search_by_vector(
                query=kb_config["query"],
                top_k=kb_config["top_k"],
                score_threshold=kb_config["vector_similarity_weight"],
                file_names_filter=file_names_filter,
            )
            rs2 = vector_service.search_by_full_text(
                query=kb_config["query"],
                top_k=kb_config["top_k"],
                score_threshold=kb_config["similarity_threshold"],
                file_names_filter=file_names_filter,
            )
            # Merge and de-duplicate by doc_id (vector results take
            # precedence because rs1 is scanned first).
            seen_ids = set()
            unique_rs = []
            for doc in rs1 + rs2:
                if doc.metadata["doc_id"] not in seen_ids:
                    seen_ids.add(doc.metadata["doc_id"])
                    unique_rs.append(doc)
            rs = unique_rs

    results.extend(rs)
    return results, chat_model, embedding_model
|
||||
|
||||
|
||||
def rerank(db: Session, reranker_id: uuid, query: str, docs: list[DocumentChunk], top_k: int) -> list[DocumentChunk]:
|
||||
"""
|
||||
|
||||
@@ -111,6 +111,20 @@ def get_knowledge_by_id(db: Session, knowledge_id: uuid.UUID) -> Knowledge | Non
|
||||
raise
|
||||
|
||||
|
||||
def get_knowledges_by_parent_id(db: Session, parent_id: uuid.UUID) -> list[Knowledge]:
    """Return every Knowledge row whose parent_id equals *parent_id*.

    The outcome (hit count or no-match) is logged at debug level. Any
    exception raised by the query is logged at error level and re-raised
    unchanged for the caller to handle.
    """
    db_logger.debug(f"Query knowledge bases based on parent ID: parent_id={parent_id}")
    try:
        children = db.query(Knowledge).filter(Knowledge.parent_id == parent_id).all()
        outcome = (
            f"Knowledge bases query successful: count={len(children)} (parent_id: {parent_id})"
            if children
            else f"No knowledge bases found for given parent: parent_id={parent_id}"
        )
        db_logger.debug(outcome)
        return children
    except Exception as e:
        db_logger.error(f"Failed to query the knowledge bases based on parent ID: parent_id={parent_id} - {str(e)}")
        raise
|
||||
|
||||
|
||||
def get_knowledge_by_name(db: Session, name: str, workspace_id: uuid.UUID) -> Knowledge | None:
|
||||
db_logger.debug(f"Query knowledge base based on name and workspace_id: name={name}, workspace_id={workspace_id}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user