feat(workflow): implement a workflow node for knowledge base retrieval
This commit is contained in:
@@ -740,8 +740,9 @@ class ElasticSearchVector(BaseVector):
|
||||
self._client.indices.create(index=self._collection_name, body=index_mapping)
|
||||
|
||||
|
||||
class ElasticSearchVectorFactory(ABC):
|
||||
def init_vector(self, knowledge: Knowledge) -> ElasticSearchVector:
|
||||
class ElasticSearchVectorFactory:
|
||||
@staticmethod
|
||||
def init_vector(knowledge: Knowledge) -> ElasticSearchVector:
|
||||
collection_name = f"Vector_index_{knowledge.id}_Node"
|
||||
|
||||
# Use regular Elasticsearch with config values
|
||||
@@ -763,17 +764,17 @@ class ElasticSearchVectorFactory(ABC):
|
||||
}
|
||||
)
|
||||
|
||||
if knowledge.embedding and knowledge.reranker:
|
||||
if knowledge.embedding is None:
|
||||
raise ValueError(f"embedding_id config error: {str(knowledge.embedding_id)}")
|
||||
if knowledge.reranker is None:
|
||||
raise ValueError(f"reranker_id config error: {str(knowledge.reranker_id)}")
|
||||
|
||||
return ElasticSearchVector(
|
||||
index_name=collection_name,
|
||||
config=ElasticSearchConfig(**config_dict),
|
||||
embedding_config=knowledge.embedding.api_keys[0],
|
||||
reranker_config=knowledge.reranker.api_keys[0]
|
||||
)
|
||||
else:
|
||||
if knowledge.embedding is None:
|
||||
raise ValueError(f"embedding_id config error: {str(knowledge.embedding_id)}")
|
||||
if knowledge.reranker is None:
|
||||
raise ValueError(f"reranker_id config error: {str(knowledge.reranker_id)}")
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from app.core.workflow.nodes.assigner import AssignerNode
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.end import EndNode
|
||||
from app.core.workflow.nodes.if_else import IfElseNode
|
||||
# from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
|
||||
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
|
||||
from app.core.workflow.nodes.llm import LLMNode
|
||||
from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode
|
||||
from app.core.workflow.nodes.start import StartNode
|
||||
@@ -26,6 +26,6 @@ __all__ = [
|
||||
"EndNode",
|
||||
"NodeFactory",
|
||||
"WorkflowNode",
|
||||
# "KnowledgeRetrievalNode",
|
||||
"KnowledgeRetrievalNode",
|
||||
"AssignerNode",
|
||||
]
|
||||
|
||||
@@ -14,7 +14,7 @@ from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig
|
||||
from app.core.workflow.nodes.agent.config import AgentNodeConfig
|
||||
from app.core.workflow.nodes.transform.config import TransformNodeConfig
|
||||
from app.core.workflow.nodes.if_else.config import IfElseNodeConfig
|
||||
# from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig
|
||||
from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig
|
||||
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
|
||||
|
||||
__all__ = [
|
||||
@@ -30,6 +30,6 @@ __all__ = [
|
||||
"AgentNodeConfig",
|
||||
"TransformNodeConfig",
|
||||
"IfElseNodeConfig",
|
||||
# "KnowledgeRetrievalNodeConfig",
|
||||
"KnowledgeRetrievalNodeConfig",
|
||||
"AssignerNodeConfig",
|
||||
]
|
||||
|
||||
4
api/app/core/workflow/nodes/knowledge/__init__.py
Normal file
4
api/app/core/workflow/nodes/knowledge/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig
|
||||
from app.core.workflow.nodes.knowledge.node import KnowledgeRetrievalNode
|
||||
|
||||
__all__ = ["KnowledgeRetrievalNode", "KnowledgeRetrievalNodeConfig"]
|
||||
38
api/app/core/workflow/nodes/knowledge/config.py
Normal file
38
api/app/core/workflow/nodes/knowledge/config.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
from app.schemas.chunk_schema import RetrieveType
|
||||
|
||||
|
||||
class KnowledgeRetrievalNodeConfig(BaseNodeConfig):
|
||||
query: str = Field(
|
||||
...,
|
||||
description="Search query string"
|
||||
)
|
||||
|
||||
kb_ids: list[UUID] = Field(
|
||||
...,
|
||||
description="Knowledge base IDs"
|
||||
)
|
||||
|
||||
similarity_threshold: float = Field(
|
||||
default=0.2,
|
||||
description="Knowledge base similarity threshold"
|
||||
)
|
||||
|
||||
vector_similarity_weight: float = Field(
|
||||
default=0.3,
|
||||
description="Knowledge base vector similarity weight"
|
||||
)
|
||||
|
||||
top_k: int = Field(
|
||||
default=4,
|
||||
description="Knowledge base top k"
|
||||
)
|
||||
|
||||
retrieve_type: RetrieveType = Field(
|
||||
default=RetrieveType.PARTICIPLE,
|
||||
description="Retrieve type"
|
||||
)
|
||||
97
api/app/core/workflow/nodes/knowledge/node.py
Normal file
97
api/app/core/workflow/nodes/knowledge/node.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
|
||||
from app.db import get_db
|
||||
from app.models import knowledge_model, knowledgeshare_model
|
||||
from app.repositories import knowledge_repository
|
||||
from app.schemas.chunk_schema import RetrieveType
|
||||
from app.services import knowledge_service, knowledgeshare_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class KnowledgeRetrievalNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = KnowledgeRetrievalNodeConfig(**self.config)
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
query = self._render_template(self.typed_config.query, state)
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
try:
|
||||
filters = [
|
||||
knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids),
|
||||
knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Private,
|
||||
knowledge_model.Knowledge.chunk_num > 0,
|
||||
knowledge_model.Knowledge.status == 1
|
||||
]
|
||||
existing_ids = knowledge_repository.get_chunked_knowledgeids(
|
||||
db=db,
|
||||
filters=filters
|
||||
)
|
||||
filters = [
|
||||
knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids),
|
||||
knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Share,
|
||||
knowledge_model.Knowledge.chunk_num > 0,
|
||||
knowledge_model.Knowledge.status == 1
|
||||
]
|
||||
share_ids = knowledge_service.knowledge_repository.get_chunked_knowledgeids(
|
||||
db=db,
|
||||
filters=filters
|
||||
)
|
||||
if share_ids:
|
||||
filters = [
|
||||
knowledgeshare_model.KnowledgeShare.target_kb_id.in_(self.typed_config.kb_ids)
|
||||
]
|
||||
items = knowledgeshare_service.knowledgeshare_repository.get_source_kb_ids_by_target_kb_id(
|
||||
db=db,
|
||||
filters=filters
|
||||
)
|
||||
existing_ids.extend(items)
|
||||
|
||||
if not existing_ids:
|
||||
raise RuntimeError("Knowledge base retrieval failed: the knowledge base does not exist.")
|
||||
|
||||
kb_id = existing_ids[0]
|
||||
uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in existing_ids]
|
||||
indices = ",".join(uuid_strs)
|
||||
|
||||
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_id)
|
||||
if not db_knowledge:
|
||||
raise RuntimeError("The knowledge base does not exist or access is denied.")
|
||||
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
|
||||
match self.typed_config.retrieve_type:
|
||||
case RetrieveType.PARTICIPLE:
|
||||
rs = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=self.typed_config.similarity_threshold)
|
||||
return [chunk.model_dump() for chunk in rs]
|
||||
case RetrieveType.SEMANTIC:
|
||||
rs = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=self.typed_config.vector_similarity_weight)
|
||||
return [chunk.model_dump() for chunk in rs]
|
||||
case _:
|
||||
rs1 = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=self.typed_config.vector_similarity_weight)
|
||||
rs2 = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=self.typed_config.similarity_threshold)
|
||||
# Efficient deduplication
|
||||
seen_ids = set()
|
||||
unique_rs = []
|
||||
for doc in rs1 + rs2:
|
||||
if doc.metadata["doc_id"] not in seen_ids:
|
||||
seen_ids.add(doc.metadata["doc_id"])
|
||||
unique_rs.append(doc)
|
||||
rs = vector_service.rerank(query=query, docs=unique_rs, top_k=self.typed_config.top_k)
|
||||
return [chunk.model_dump() for chunk in rs]
|
||||
finally:
|
||||
next(db_gen)
|
||||
@@ -7,7 +7,7 @@
|
||||
import logging
|
||||
from typing import Any, Union
|
||||
|
||||
# from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
|
||||
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
|
||||
from app.core.workflow.nodes.agent import AgentNode
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.end import EndNode
|
||||
@@ -29,7 +29,7 @@ WorkflowNode = Union[
|
||||
AgentNode,
|
||||
TransformNode,
|
||||
AssignerNode,
|
||||
# KnowledgeRetrievalNode,
|
||||
KnowledgeRetrievalNode,
|
||||
]
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ class NodeFactory:
|
||||
NodeType.AGENT: AgentNode,
|
||||
NodeType.TRANSFORM: TransformNode,
|
||||
NodeType.IF_ELSE: IfElseNode,
|
||||
# NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
|
||||
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
|
||||
NodeType.ASSIGNER: AssignerNode,
|
||||
}
|
||||
|
||||
|
||||
@@ -116,7 +116,6 @@ class StartNode(BaseNode):
|
||||
|
||||
return processed
|
||||
|
||||
|
||||
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
|
||||
"""提取输入数据(用于记录)
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ def get_knowledges_paginated(
|
||||
raise
|
||||
|
||||
|
||||
def get_chunded_knowledgeids(
|
||||
def get_chunked_knowledgeids(
|
||||
db: Session,
|
||||
filters: list
|
||||
) -> list:
|
||||
|
||||
@@ -45,7 +45,7 @@ def get_chunded_knowledgeids(
|
||||
business_logger.debug(f"Query the list of vectorized knowledge base IDs: username={current_user.username}")
|
||||
|
||||
try:
|
||||
items = knowledge_repository.get_chunded_knowledgeids(
|
||||
items = knowledge_repository.get_chunked_knowledgeids(
|
||||
db=db,
|
||||
filters=filters
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user