fix(workflow): support reordering without a rerank model in knowledge base
This commit is contained in:
@@ -45,7 +45,7 @@ class KnowledgeRetrievalNodeConfig(BaseNodeConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
reranker_id: UUID = Field(
|
reranker_id: UUID = Field(
|
||||||
...,
|
default="",
|
||||||
description="Reranker top k"
|
description="Reranker top k"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -203,19 +203,34 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
rs2 = vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
rs2 = vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
||||||
indices=indices,
|
indices=indices,
|
||||||
score_threshold=kb_config.similarity_threshold)
|
score_threshold=kb_config.similarity_threshold)
|
||||||
|
|
||||||
# Deduplicate hybrid retrieval results
|
# Deduplicate hybrid retrieval results
|
||||||
unique_rs = self._deduplicate_docs(rs1, rs2)
|
unique_rs = self._deduplicate_docs(rs1, rs2)
|
||||||
if not unique_rs:
|
if not unique_rs:
|
||||||
continue
|
continue
|
||||||
vector_service.reranker = self.get_reranker_model()
|
if self.typed_config.reranker_id:
|
||||||
rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
|
vector_service.reranker = self.get_reranker_model()
|
||||||
|
rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
|
||||||
|
else:
|
||||||
|
rs.extend(sorted(
|
||||||
|
unique_rs,
|
||||||
|
key=lambda d: d.metadata.get("score", 0),
|
||||||
|
reverse=True
|
||||||
|
)[:kb_config.top_k])
|
||||||
case _:
|
case _:
|
||||||
raise RuntimeError("Unknown retrieval type")
|
raise RuntimeError("Unknown retrieval type")
|
||||||
if not rs:
|
if not rs:
|
||||||
return []
|
return []
|
||||||
vector_service.reranker = self.get_reranker_model()
|
if self.typed_config.reranker_id:
|
||||||
# TODO:其他重排序方式支持
|
vector_service.reranker = self.get_reranker_model()
|
||||||
final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
|
final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
|
||||||
|
else:
|
||||||
|
final_rs = sorted(
|
||||||
|
rs,
|
||||||
|
key=lambda d: d.metadata.get("score", 0),
|
||||||
|
reverse=True
|
||||||
|
)[:self.typed_config.reranker_top_k]
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}"
|
f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -386,7 +386,10 @@ class ArrayComparisonOperator(ConditionBase):
|
|||||||
return self.right_value not in self.left_value
|
return self.right_value not in self.left_value
|
||||||
|
|
||||||
|
|
||||||
class NoneObjectComparisonOperator(ConditionBase):
|
class NoneObjectComparisonOperator:
|
||||||
|
def __init__(self, *arg, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
def __getattr__(self, name):
|
def __getattr__(self, name):
|
||||||
return lambda *args, **kwargs: False
|
return lambda *args, **kwargs: False
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user