Merge branch 'feature/knowledge_lxc' into develop

This commit is contained in:
lixiangcheng1
2026-03-19 08:19:24 +08:00
2 changed files with 133 additions and 66 deletions

View File

@@ -94,72 +94,16 @@ def knowledge_retrieval(
db_knowledge = knowledge_repository.get_knowledge_by_id(db, knowledge_id=kb_id)
if db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1:
# Process shared knowledge base
if db_knowledge.permission_id.lower() == knowledge_model.PermissionType.Share:
knowledgeshare = knowledgeshare_repository.get_knowledgeshare_by_id(db=db,
knowledgeshare_id=db_knowledge.id)
if knowledgeshare:
db_knowledge = knowledge_repository.get_knowledge_by_id(db,
knowledge_id=knowledgeshare.source_kb_id)
if not (db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1):
continue
else:
continue
if str(db_knowledge.id) not in kb_ids:
kb_ids.append(str(db_knowledge.id))
if str(db_knowledge.workspace_id) not in workspace_ids:
workspace_ids.append(str(db_knowledge.workspace_id))
if not chat_model:
chat_model = Base(
key=db_knowledge.llm.api_keys[0].api_key,
model_name=db_knowledge.llm.api_keys[0].model_name,
base_url=db_knowledge.llm.api_keys[0].api_base
)
if not embedding_model:
embedding_model = OpenAIEmbed(
key=db_knowledge.embedding.api_keys[0].api_key,
model_name=db_knowledge.embedding.api_keys[0].model_name,
base_url=db_knowledge.embedding.api_keys[0].api_base
)
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
# Retrieve according to the configured retrieval type
match kb_config["retrieve_type"]:
case "participle":
rs = vector_service.search_by_full_text(
query=query,
top_k=kb_config["top_k"],
score_threshold=kb_config["similarity_threshold"],
file_names_filter=file_names_filter
)
case "semantic":
rs = vector_service.search_by_vector(
query=query,
top_k=kb_config["top_k"],
score_threshold=kb_config["vector_similarity_weight"],
file_names_filter=file_names_filter
)
case _: # hybrid
rs1 = vector_service.search_by_vector(
query=query,
top_k=kb_config["top_k"],
score_threshold=kb_config["vector_similarity_weight"],
file_names_filter=file_names_filter
)
rs2 = vector_service.search_by_full_text(
query=query,
top_k=kb_config["top_k"],
score_threshold=kb_config["similarity_threshold"],
file_names_filter=file_names_filter
)
# Deduplication of merge results
seen_ids = set()
unique_rs = []
for doc in rs1 + rs2:
if doc.metadata["doc_id"] not in seen_ids:
seen_ids.add(doc.metadata["doc_id"])
unique_rs.append(doc)
rs = unique_rs
rs, chat_model, embedding_model = _retrieve_for_knowledge(
db=db,
db_knowledge=db_knowledge,
kb_config={**kb_config, "query": query}, # 或改为单独参数
file_names_filter=file_names_filter,
chat_model=chat_model,
embedding_model=embedding_model,
kb_ids=kb_ids,
workspace_ids=workspace_ids,
)
all_results.extend(rs)
except Exception as e:
@@ -199,6 +143,115 @@ def knowledge_retrieval(
finally:
db.close()
def _retrieve_for_knowledge(
    db: Session,
    db_knowledge,
    kb_config: Dict[str, Any],
    file_names_filter: list[str],
    chat_model: Base | None,
    embedding_model: OpenAIEmbed | None,
    kb_ids: list[str],
    workspace_ids: list[str],
) -> tuple[list[DocumentChunk], Base | None, OpenAIEmbed | None]:
    """Run retrieval against a single knowledge base.

    Three cases are handled:
    - A shared knowledge base is first resolved to its source knowledge base.
    - A Folder-type knowledge base is expanded by recursing into each eligible
      child (children that are themselves folders keep recursing downward).
    - A plain knowledge base is searched directly through the vector service.

    Side effects: appends to ``kb_ids`` / ``workspace_ids`` in place, and lazily
    builds ``chat_model`` / ``embedding_model`` from the first usable knowledge
    base encountered.

    Returns:
        (retrieved chunks for this knowledge base and its children,
         possibly-initialized chat_model, possibly-initialized embedding_model)
    """
    collected: list[DocumentChunk] = []

    # Shared KB: follow the share record back to its source knowledge base.
    if db_knowledge.permission_id.lower() == knowledge_model.PermissionType.Share:
        share = knowledgeshare_repository.get_knowledgeshare_by_id(db=db, knowledgeshare_id=db_knowledge.id)
        if not share:
            return collected, chat_model, embedding_model
        db_knowledge = knowledge_repository.get_knowledge_by_id(db, knowledge_id=share.source_kb_id)
        if not (db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1):
            return collected, chat_model, embedding_model

    # Folder: no content of its own — recurse into every eligible child.
    # NOTE(review): children with chunk_num == 0 are skipped even if they are
    # folders themselves; assumes folder chunk_num aggregates descendants — confirm.
    if db_knowledge.type == knowledge_model.KnowledgeType.Folder:
        for child in knowledge_repository.get_knowledges_by_parent_id(db=db, parent_id=db_knowledge.id):
            if not (child and child.chunk_num > 0 and child.status == 1):
                continue
            child_hits, chat_model, embedding_model = _retrieve_for_knowledge(
                db=db,
                db_knowledge=child,
                kb_config=kb_config,
                file_names_filter=file_names_filter,
                chat_model=chat_model,
                embedding_model=embedding_model,
                kb_ids=kb_ids,
                workspace_ids=workspace_ids,
            )
            collected.extend(child_hits)
        return collected, chat_model, embedding_model

    # Plain KB: record its id / workspace once, then run the actual search.
    kb_id = str(db_knowledge.id)
    if kb_id not in kb_ids:
        kb_ids.append(kb_id)
    ws_id = str(db_knowledge.workspace_id)
    if ws_id not in workspace_ids:
        workspace_ids.append(ws_id)

    # Lazily create the models from the first usable KB's configured keys.
    if not chat_model:
        llm_key = db_knowledge.llm.api_keys[0]
        chat_model = Base(
            key=llm_key.api_key,
            model_name=llm_key.model_name,
            base_url=llm_key.api_base,
        )
    if not embedding_model:
        embed_key = db_knowledge.embedding.api_keys[0]
        embedding_model = OpenAIEmbed(
            key=embed_key.api_key,
            model_name=embed_key.model_name,
            base_url=embed_key.api_base,
        )

    vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)

    # Parameters shared by every search variant; "query" rides in kb_config
    # (set by the caller) — TODO: promote it to a dedicated parameter.
    shared_kwargs = dict(
        query=kb_config["query"],
        top_k=kb_config["top_k"],
        file_names_filter=file_names_filter,
    )
    retrieve_type = kb_config["retrieve_type"]
    if retrieve_type == "participle":
        hits = vector_service.search_by_full_text(
            score_threshold=kb_config["similarity_threshold"],
            **shared_kwargs,
        )
    elif retrieve_type == "semantic":
        hits = vector_service.search_by_vector(
            score_threshold=kb_config["vector_similarity_weight"],
            **shared_kwargs,
        )
    else:
        # Hybrid: run both searches, merge, and deduplicate by doc_id
        # (vector results win ties because they come first).
        vector_hits = vector_service.search_by_vector(
            score_threshold=kb_config["vector_similarity_weight"],
            **shared_kwargs,
        )
        text_hits = vector_service.search_by_full_text(
            score_threshold=kb_config["similarity_threshold"],
            **shared_kwargs,
        )
        seen_ids: set = set()
        hits = []
        for doc in vector_hits + text_hits:
            doc_id = doc.metadata["doc_id"]
            if doc_id not in seen_ids:
                seen_ids.add(doc_id)
                hits.append(doc)

    collected.extend(hits)
    return collected, chat_model, embedding_model
def rerank(db: Session, reranker_id: uuid, query: str, docs: list[DocumentChunk], top_k: int) -> list[DocumentChunk]:
"""

View File

@@ -111,6 +111,20 @@ def get_knowledge_by_id(db: Session, knowledge_id: uuid.UUID) -> Knowledge | Non
raise
def get_knowledges_by_parent_id(db: Session, parent_id: uuid.UUID) -> list[Knowledge]:
    """Fetch every knowledge base whose parent is ``parent_id``.

    Returns an empty list when the parent has no children. Any database
    error is logged and re-raised unchanged.
    """
    db_logger.debug(f"Query knowledge bases based on parent ID: parent_id={parent_id}")
    try:
        children = db.query(Knowledge).filter(Knowledge.parent_id == parent_id).all()
        db_logger.debug(
            f"Knowledge bases query successful: count={len(children)} (parent_id: {parent_id})"
            if children
            else f"No knowledge bases found for given parent: parent_id={parent_id}"
        )
        return children
    except Exception as e:
        db_logger.error(f"Failed to query the knowledge bases based on parent ID: parent_id={parent_id} - {str(e)}")
        raise
def get_knowledge_by_name(db: Session, name: str, workspace_id: uuid.UUID) -> Knowledge | None:
db_logger.debug(f"Query knowledge base based on name and workspace_id: name={name}, workspace_id={workspace_id}")