diff --git a/api/app/core/workflow/nodes/knowledge/config.py b/api/app/core/workflow/nodes/knowledge/config.py index cdb83131..9d307216 100644 --- a/api/app/core/workflow/nodes/knowledge/config.py +++ b/api/app/core/workflow/nodes/knowledge/config.py @@ -45,7 +45,7 @@ class KnowledgeRetrievalNodeConfig(BaseNodeConfig): ) reranker_id: UUID = Field( - ..., + default="", description="Reranker top k" ) diff --git a/api/app/core/workflow/nodes/knowledge/node.py b/api/app/core/workflow/nodes/knowledge/node.py index 5a6b2a7f..061328e1 100644 --- a/api/app/core/workflow/nodes/knowledge/node.py +++ b/api/app/core/workflow/nodes/knowledge/node.py @@ -203,19 +203,34 @@ class KnowledgeRetrievalNode(BaseNode): rs2 = vector_service.search_by_full_text(query=query, top_k=kb_config.top_k, indices=indices, score_threshold=kb_config.similarity_threshold) + # Deduplicate hy brid retrieval results unique_rs = self._deduplicate_docs(rs1, rs2) if not unique_rs: continue - vector_service.reranker = self.get_reranker_model() - rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k)) + if self.typed_config.reranker_id: + vector_service.reranker = self.get_reranker_model() + rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k)) + else: + rs.extend(sorted( + unique_rs, + key=lambda d: d.metadata.get("score", 0), + reverse=True + )[:kb_config.top_k]) case _: raise RuntimeError("Unknown retrieval type") if not rs: return [] - vector_service.reranker = self.get_reranker_model() - # TODO:其他重排序方式支持 - final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k) + if self.typed_config.reranker_id: + vector_service.reranker = self.get_reranker_model() + final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k) + else: + final_rs = sorted( + rs, + key=lambda d: d.metadata.get("score", 0), + reverse=True + )[:self.typed_config.reranker_top_k] + logger.info( f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}" ) diff --git a/api/app/core/workflow/nodes/operators.py b/api/app/core/workflow/nodes/operators.py index 25caec07..ad38284a 100644 --- a/api/app/core/workflow/nodes/operators.py +++ b/api/app/core/workflow/nodes/operators.py @@ -386,7 +386,10 @@ class ArrayComparisonOperator(ConditionBase): return self.right_value not in self.left_value -class NoneObjectComparisonOperator(ConditionBase): +class NoneObjectComparisonOperator: + def __init__(self, *arg, **kwargs): + pass + def __getattr__(self, name): return lambda *args, **kwargs: False