fix(workflow): support reordering without a rerank model in knowledge base
This commit is contained in:
@@ -45,7 +45,7 @@ class KnowledgeRetrievalNodeConfig(BaseNodeConfig):
|
||||
)
|
||||
|
||||
reranker_id: UUID = Field(
|
||||
...,
|
||||
default="",
|
||||
description="Reranker top k"
|
||||
)
|
||||
|
||||
|
||||
@@ -203,19 +203,34 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
rs2 = vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.similarity_threshold)
|
||||
|
||||
# Deduplicate hy brid retrieval results
|
||||
unique_rs = self._deduplicate_docs(rs1, rs2)
|
||||
if not unique_rs:
|
||||
continue
|
||||
vector_service.reranker = self.get_reranker_model()
|
||||
rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
|
||||
if self.typed_config.reranker_id:
|
||||
vector_service.reranker = self.get_reranker_model()
|
||||
rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
|
||||
else:
|
||||
rs.extend(sorted(
|
||||
unique_rs,
|
||||
key=lambda d: d.metadata.get("score", 0),
|
||||
reverse=True
|
||||
)[:kb_config.top_k])
|
||||
case _:
|
||||
raise RuntimeError("Unknown retrieval type")
|
||||
if not rs:
|
||||
return []
|
||||
vector_service.reranker = self.get_reranker_model()
|
||||
# TODO:其他重排序方式支持
|
||||
final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
|
||||
if self.typed_config.reranker_id:
|
||||
vector_service.reranker = self.get_reranker_model()
|
||||
final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
|
||||
else:
|
||||
final_rs = sorted(
|
||||
rs,
|
||||
key=lambda d: d.metadata.get("score", 0),
|
||||
reverse=True
|
||||
)[:self.typed_config.reranker_top_k]
|
||||
|
||||
logger.info(
|
||||
f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}"
|
||||
)
|
||||
|
||||
@@ -386,7 +386,10 @@ class ArrayComparisonOperator(ConditionBase):
|
||||
return self.right_value not in self.left_value
|
||||
|
||||
|
||||
class NoneObjectComparisonOperator(ConditionBase):
|
||||
class NoneObjectComparisonOperator:
|
||||
def __init__(self, *arg, **kwargs):
|
||||
pass
|
||||
|
||||
def __getattr__(self, name):
|
||||
return lambda *args, **kwargs: False
|
||||
|
||||
|
||||
Reference in New Issue
Block a user