Merge #51 into develop from feature/20251219_myh

feat(workflow): add HTTP request node

* feature/20251219_myh: (3 commits)
  feat(workflow): add HTTP request node
  refactor(workflow): organize knowledge base code structure and add comments
  fix(workflow): correct property usage in HTTP request node

Signed-off-by: Eternity <1533512157@qq.com>
Commented-by: Eternity <1533512157@qq.com>
Reviewed-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com>
Merged-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com>

CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/51
This commit is contained in:
朱文辉
2025-12-25 13:52:35 +08:00
9 changed files with 660 additions and 85 deletions

View File

@@ -16,6 +16,7 @@ from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.expression_evaluator import evaluate_condition from app.core.workflow.expression_evaluator import evaluate_condition
from app.core.workflow.nodes import WorkflowState, NodeFactory from app.core.workflow.nodes import WorkflowState, NodeFactory
from app.core.workflow.nodes.enums import NodeType from app.core.workflow.nodes.enums import NodeType
# from app.core.tools.registry import ToolRegistry # from app.core.tools.registry import ToolRegistry
# from app.core.tools.executor import ToolExecutor # from app.core.tools.executor import ToolExecutor
# from app.core.tools.langchain_adapter import LangchainAdapter # from app.core.tools.langchain_adapter import LangchainAdapter
@@ -78,6 +79,7 @@ class WorkflowExecutor:
var_name = var_def.get("name") var_name = var_def.get("name")
var_default = var_def.get("default") var_default = var_def.get("default")
if var_name: if var_name:
# TODO: 入参类型校验
conversation_vars[var_name] = var_default conversation_vars[var_name] = var_default
input_variables = input_data.get("variables") or {} # Start 节点的自定义变量 input_variables = input_data.get("variables") or {} # Start 节点的自定义变量
@@ -85,12 +87,12 @@ class WorkflowExecutor:
# 构建分层的变量结构 # 构建分层的变量结构
variables = { variables = {
"sys": { "sys": {
"message": user_message, # 用户消息 "message": user_message, # 用户消息
"conversation_id": input_data.get("conversation_id"), # 会话 ID "conversation_id": input_data.get("conversation_id"), # 会话 ID
"execution_id": self.execution_id, # 执行 ID "execution_id": self.execution_id, # 执行 ID
"workspace_id": self.workspace_id, # 工作空间 ID "workspace_id": self.workspace_id, # 工作空间 ID
"user_id": self.user_id, # 用户 ID "user_id": self.user_id, # 用户 ID
"input_variables": input_variables, # 自定义输入变量(给 Start 节点使用) "input_variables": input_variables, # 自定义输入变量(给 Start 节点使用)
}, },
"conv": conversation_vars # 会话级变量(跨多轮对话保持) "conv": conversation_vars # 会话级变量(跨多轮对话保持)
} }
@@ -108,8 +110,6 @@ class WorkflowExecutor:
"streaming_buffer": {} # 流式缓冲区 "streaming_buffer": {} # 流式缓冲区
} }
def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]: def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]:
"""分析 End 节点的前缀配置 """分析 End 节点的前缀配置
@@ -178,7 +178,7 @@ class WorkflowExecutor:
logger.info(f"[前缀分析] 与 End 相邻且被引用的节点: {adjacent_and_referenced}") logger.info(f"[前缀分析] 与 End 相邻且被引用的节点: {adjacent_and_referenced}")
return prefixes, adjacent_and_referenced return prefixes, adjacent_and_referenced
def build_graph(self,stream=False) -> CompiledStateGraph: def build_graph(self, stream=False) -> CompiledStateGraph:
"""构建 LangGraph """构建 LangGraph
Returns: Returns:
@@ -209,7 +209,7 @@ class WorkflowExecutor:
# 创建节点实例(现在 start 和 end 也会被创建) # 创建节点实例(现在 start 和 end 也会被创建)
node_instance = NodeFactory.create_node(node, self.workflow_config) node_instance = NodeFactory.create_node(node, self.workflow_config)
if node_type in [NodeType.IF_ELSE]: if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST]:
expressions = node_instance.build_conditional_edge_expressions() expressions = node_instance.build_conditional_edge_expressions()
# Number of branches, usually matches the number of conditional expressions # Number of branches, usually matches the number of conditional expressions
@@ -249,14 +249,18 @@ class WorkflowExecutor:
# logger.debug(f"流式执行节点: {inst.node_id}, 支持流式: {inst.supports_streaming()}") # logger.debug(f"流式执行节点: {inst.node_id}, 支持流式: {inst.supports_streaming()}")
async for item in inst.run_stream(state): async for item in inst.run_stream(state):
yield item yield item
return node_func return node_func
workflow.add_node(node_id, make_stream_func(node_instance)) workflow.add_node(node_id, make_stream_func(node_instance))
else: else:
# 非流式模式:创建 async function # 非流式模式:创建 async function
def make_func(inst): def make_func(inst):
async def node_func(state: WorkflowState): async def node_func(state: WorkflowState):
return await inst.run(state) return await inst.run(state)
return node_func return node_func
workflow.add_node(node_id, make_func(node_instance)) workflow.add_node(node_id, make_func(node_instance))
logger.debug(f"添加节点: {node_id} (type={node_type}, stream={stream})") logger.debug(f"添加节点: {node_id} (type={node_type}, stream={stream})")
@@ -532,7 +536,10 @@ class WorkflowExecutor:
end_time = datetime.datetime.now() end_time = datetime.datetime.now()
elapsed_time = (end_time - start_time).total_seconds() elapsed_time = (end_time - start_time).total_seconds()
logger.info(f"Workflow execution completed (streaming), total chunks: {chunk_count}, elapsed: {elapsed_time:.2f}s") logger.info(
f"Workflow execution completed (streaming), "
f"total chunks: {chunk_count}, elapsed: {elapsed_time:.2f}s"
)
# 发送 workflow_end 事件 # 发送 workflow_end 事件
yield { yield {
@@ -678,7 +685,6 @@ async def execute_workflow_stream(
async for event in executor.execute_stream(input_data): async for event in executor.execute_stream(input_data):
yield event yield event
# ==================== 工具管理系统集成 ==================== # ==================== 工具管理系统集成 ====================
# def get_workflow_tools(workspace_id: str, user_id: str) -> list: # def get_workflow_tools(workspace_id: str, user_id: str) -> list:

View File

@@ -11,6 +11,7 @@ from app.core.workflow.nodes.end import EndNode
from app.core.workflow.nodes.if_else import IfElseNode from app.core.workflow.nodes.if_else import IfElseNode
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
from app.core.workflow.nodes.llm import LLMNode from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.http_request import HttpRequestNode
from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode
from app.core.workflow.nodes.start import StartNode from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.transform import TransformNode from app.core.workflow.nodes.transform import TransformNode
@@ -28,4 +29,5 @@ __all__ = [
"WorkflowNode", "WorkflowNode",
"KnowledgeRetrievalNode", "KnowledgeRetrievalNode",
"AssignerNode", "AssignerNode",
"HttpRequestNode"
] ]

View File

@@ -15,6 +15,7 @@ from app.core.workflow.nodes.agent.config import AgentNodeConfig
from app.core.workflow.nodes.transform.config import TransformNodeConfig from app.core.workflow.nodes.transform.config import TransformNodeConfig
from app.core.workflow.nodes.if_else.config import IfElseNodeConfig from app.core.workflow.nodes.if_else.config import IfElseNodeConfig
from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
__all__ = [ __all__ = [
@@ -32,4 +33,5 @@ __all__ = [
"IfElseNodeConfig", "IfElseNodeConfig",
"KnowledgeRetrievalNodeConfig", "KnowledgeRetrievalNodeConfig",
"AssignerNodeConfig", "AssignerNodeConfig",
"HttpRequestNodeConfig"
] ]

View File

@@ -73,3 +73,34 @@ class AssignmentOperator(StrEnum):
return ObjectOperator return ObjectOperator
raise TypeError(f"Unsupported variable type ({type(obj)})") raise TypeError(f"Unsupported variable type ({type(obj)})")
class HttpRequestMethod(StrEnum):
GET = "GET"
POST = "POST"
HEAD = "HEAD"
PUT = "PUT"
PATCH = "PATCH"
DELETE = "DELETE"
class HttpAuthType(StrEnum):
NONE = "none"
BASIC = "basic"
BEARER = "bearer"
CUSTOM = "custom"
class HttpContentType(StrEnum):
NONE = "none"
FROM_DATA = "form-data"
WWW_FORM = "x-www-form-urlencoded"
JSON = "json"
RAW = "raw"
BINARY = "binary"
class HttpErrorHandle(StrEnum):
NONE = "none"
DEFAULT = "default"
BRANCH = "branch"

View File

@@ -0,0 +1,4 @@
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig
from app.core.workflow.nodes.http_request.node import HttpRequestNode
__all__ = ['HttpRequestNode', 'HttpRequestNodeConfig']

View File

@@ -0,0 +1,215 @@
from typing import Literal
from pydantic import Field, BaseModel, field_validator
from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpAuthType, HttpContentType, HttpErrorHandle
class HttpAuthConfig(BaseModel):
    """Authentication settings of an HTTP request node.

    Validation rules:
        - ``header`` must be non-empty when ``auth_type`` is CUSTOM.
        - ``api_key`` must be non-empty whenever ``auth_type`` is not NONE.

    Both rules are also enforced when the field is omitted: the fields are
    declared with ``validate_default=True`` because pydantic does not run
    field validators on default values otherwise.
    """

    auth_type: HttpAuthType = Field(
        default=HttpAuthType.NONE,
        description="Type of HTTP authentication to use",
    )

    header: str = Field(
        default="",
        validate_default=True,
        description="Custom HTTP Authorization header (used if auth_type is CUSTOM)",
    )

    api_key: str = Field(
        default="",
        validate_default=True,
        description="API key for authentication (used if auth_type is not NONE)",
    )

    @field_validator("header")
    @classmethod
    def validate_header(cls, v, info):
        """Require a header name for the CUSTOM authentication type.

        Raises:
            ValueError: If ``auth_type`` is CUSTOM and the header is empty.
        """
        auth_type = info.data.get("auth_type")
        if auth_type == HttpAuthType.CUSTOM and not v:
            raise ValueError("Custom auth header not specified")
        return v

    @field_validator("api_key")
    @classmethod
    def validate_api_key(cls, v, info):
        """Require an API key for every authentication type except NONE.

        Raises:
            ValueError: If authentication is enabled and the key is empty.
        """
        auth_type = info.data.get("auth_type")
        # ``auth_type`` is missing from ``info.data`` when it failed its own
        # validation; that error is already reported, so do not pile a
        # misleading "API key" error on top of it.
        if auth_type is None:
            return v
        if auth_type != HttpAuthType.NONE and not v:
            raise ValueError("API key for authentication not specified")
        return v
class HttpFormData(BaseModel):
    """A single key/value entry of a form-data request body."""

    key: str = Field(..., description="Form-data field name")

    # "text" entries carry a plain value; "file" entries are declared here as well.
    type: Literal["text", "file"] = Field(..., description="Form-data type: 'text' or 'file'")

    value: str = Field(..., description="Form-data field value")
class HttpContentTypeConfig(BaseModel):
    """Request body configuration of an HTTP request node.

    The accepted shape of ``data`` depends on ``content_type``:
        - form-data: a list of ``HttpFormData`` entries
        - json / x-www-form-urlencoded: a dict
        - raw / binary: a string
    """

    content_type: HttpContentType = Field(
        ...,
        description="HTTP content type of the request body",
    )

    data: list[HttpFormData] | dict | str = Field(
        ...,
        description="Data of the HTTP request body; type depends on content_type",
    )

    @field_validator("data")
    @classmethod
    def validate_data(cls, v, info):
        """Check that ``data`` has the shape required by ``content_type``.

        Raises:
            ValueError: If the shape of ``data`` does not match ``content_type``.
        """
        content_type = info.data.get("content_type")
        if content_type == HttpContentType.FROM_DATA:
            # The field is declared as ``list[HttpFormData]``, so a valid
            # form-data body is a list of entries, never a single entry.
            if not isinstance(v, list) or not all(isinstance(item, HttpFormData) for item in v):
                raise ValueError("When content_type is 'form-data', data must be a list of HttpFormData")
        elif content_type in [HttpContentType.JSON, HttpContentType.WWW_FORM] and not isinstance(v, dict):
            raise ValueError("When content_type is JSON or x-www-form-urlencoded, data must be a object")
        elif content_type in [HttpContentType.RAW, HttpContentType.BINARY] and not isinstance(v, str):
            raise ValueError("When content_type is raw/binary, data must be a string (File descriptor)")
        return v
class HttpTimeOutConfig(BaseModel):
    """Per-phase timeouts of an HTTP request, each expressed in seconds."""

    connect_timeout: int = Field(default=5, description="Connection timeout in seconds")
    read_timeout: int = Field(default=5, description="Read timeout in seconds")
    write_timeout: int = Field(default=5, description="Write timeout in seconds")
class HttpRetryConfig(BaseModel):
    """Retry policy applied to a failed HTTP request."""

    max_attempts: int = Field(default=1, description="Maximum number of retry attempts for failed requests")

    # Expressed in milliseconds, unlike the timeout settings which use seconds.
    retry_interval: int = Field(default=100, description="Interval between retries in milliseconds")
class HttpErrorDefaultTamplete(BaseModel):
    """Canned response used in place of a real one when a request fails.

    NOTE(review): the class name misspells "Template"; it is kept as-is
    because other modules may import it under this name.
    """

    body: str = Field(default="", description="Default body returned on HTTP error")
    status_code: int = Field(default=400, description="Default HTTP status code returned on error")
    headers: dict = Field(default_factory=dict, description="Default HTTP headers returned on error")
class HttpErrorHandleConfig(BaseModel):
    """Selects how a failed HTTP request is reported to the workflow."""

    method: HttpErrorHandle = Field(
        description="Error handling strategy: 'none', 'default', or 'branch'",
        default=HttpErrorHandle.NONE,
    )

    # Required regardless of the selected strategy.
    default: HttpErrorDefaultTamplete = Field(..., description="Default response template for error handling")
class HttpRequestNodeConfig(BaseNodeConfig):
    """Complete configuration of an HTTP request workflow node.

    ``headers`` and ``params`` default to empty dicts; every other field
    declared here is required.
    """

    # --- request line -----------------------------------------------------
    method: HttpRequestMethod = Field(..., description="HTTP method for the request (GET, POST, etc.)")
    url: str = Field(..., description="URL of the HTTP request")

    # --- authentication, headers and query string -------------------------
    auth: HttpAuthConfig = Field(..., description="HTTP authentication configuration")
    headers: dict = Field(default_factory=dict, description="HTTP request headers")
    params: dict = Field(default_factory=dict, description="Query parameters for the HTTP request")

    # --- payload ------------------------------------------------------------
    body: HttpContentTypeConfig = Field(..., description="HTTP request body configuration")

    # --- transport behaviour ------------------------------------------------
    verify_ssl: bool = Field(..., description="Whether to verify SSL certificates")
    timeouts: HttpTimeOutConfig = Field(..., description="Timeout settings for the request")
    retry: HttpRetryConfig = Field(..., description="Retry configuration for failed requests")

    # --- failure handling ---------------------------------------------------
    error_handle: HttpErrorHandleConfig = Field(..., description="Configuration for handling HTTP request errors")
class HttpRequestNodeOutput(BaseModel):
    """Result produced by an HTTP request node."""

    body: str = Field(..., description="Body of the HTTP response")
    status_code: int = Field(..., description="HTTP response status code")
    headers: dict = Field(..., description="Http response headers")

    # NOTE(review): the description below duplicates the one of ``body`` and
    # does not fit the "SUCCESS" default; presumably this field is a status
    # marker - confirm and reword.
    output: str = Field(default="SUCCESS", description="HTTP response body")

    # TODO: File support (Feature)
    # files: list[File] = Field(
    #     ...
    # )

View File

@@ -0,0 +1,238 @@
import asyncio
import json
import logging
from typing import Any, Callable, Coroutine
import httpx
# import filetypes # TODO: File support (Feature)
from httpx import AsyncClient, Response, Timeout
from app.core.workflow.nodes import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput
# Name the logger after the module's dotted path (not its file-system path) so
# that it takes part in the hierarchical logging configuration.
logger = logging.getLogger(__name__)
class HttpRequestNode(BaseNode):
"""
HTTP Request Workflow Node.
This node executes an HTTP request as part of a workflow execution.
It supports:
- Multiple HTTP methods (GET, POST, PUT, DELETE, PATCH, HEAD)
- Multiple authentication strategies
- Multiple request body content types
- Retry mechanism with configurable interval
- Flexible error handling strategies
The execution result is returned as a serialized HttpRequestNodeOutput,
or a branch identifier string when error branching is enabled.
"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = HttpRequestNodeConfig(**self.config)
def _build_timeout(self) -> Timeout:
"""
Build httpx Timeout configuration.
All four timeout dimensions are explicitly defined to avoid
implicit defaults that may lead to unpredictable behavior
in production environments.
"""
timeout = httpx.Timeout(
connect=self.typed_config.timeouts.connect_timeout,
read=self.typed_config.timeouts.read_timeout,
write=self.typed_config.timeouts.write_timeout,
pool=5
)
return timeout
def _build_auth(self, state: WorkflowState) -> dict[str, str]:
"""
Build authentication-related HTTP headers.
Authentication values support template rendering based on
the current workflow runtime state.
Args:
state: Current workflow runtime state.
Returns:
A dictionary of HTTP headers used for authentication.
"""
api_key = self._render_template(self.typed_config.auth.api_key, state)
match self.typed_config.auth.auth_type:
case HttpAuthType.NONE:
return {}
case HttpAuthType.BASIC:
return {
"Authorization": f"Basic {api_key}",
}
case HttpAuthType.BEARER:
return {
"Authorization": f"Bearer {api_key}",
}
case HttpAuthType.CUSTOM:
return {
self.typed_config.auth.header: api_key
}
case _:
raise RuntimeError(f"Auth type not supported: {self.typed_config.auth.auth_type}")
def _build_header(self, state: WorkflowState) -> dict[str, str]:
"""
Build HTTP request headers.
Both header keys and values support runtime template rendering.
"""
headers = {}
for key, value in self.typed_config.headers.items():
headers[self._render_template(key, state)] = self._render_template(value, state)
return headers
def _build_params(self, state: WorkflowState) -> dict[str, str]:
"""
Build URL query parameters.
Parameter keys and values support runtime template rendering.
"""
params = {}
for key, value in self.typed_config.params.items():
params[self._render_template(key, state)] = self._render_template(value, state)
return params
def _build_content(self, state) -> dict[str, Any]:
"""
Build HTTP request body arguments for httpx request methods.
The returned dictionary is directly unpacked into the httpx
request call (e.g., json=, data=, content=).
Returns:
A dictionary containing httpx-compatible request body arguments.
"""
content = {}
match self.typed_config.body.content_type:
case HttpContentType.NONE:
return {}
case HttpContentType.JSON:
content["json"] = json.loads(self._render_template(
json.dumps(self.typed_config.body.data), state
))
case HttpContentType.FROM_DATA:
data = {}
for item in self.typed_config.body.data:
if item.type == "text":
data[self._render_template(item.key, state)] = self._render_template(item.value, state)
elif item.type == "file":
# TODO: File support (Feature)
pass
content["data"] = data
case HttpContentType.BINARY:
# TODO: File support (Feature)
pass
case HttpContentType.WWW_FORM:
content["data"] = json.loads(self._render_template(
json.dumps(self.typed_config.body.data), state
))
case HttpContentType.RAW:
content["content"] = self._render_template(self.typed_config.body.data, state)
case _:
raise RuntimeError(f"Content type not supported: {self.typed_config.body.content_type}")
return content
def _get_client_method(self, client: AsyncClient) -> Callable[..., Coroutine[Any, Any, Response]]:
"""
Resolve the httpx AsyncClient method based on configured HTTP method.
"""
match self.typed_config.method:
case HttpRequestMethod.GET:
return client.get
case HttpRequestMethod.POST:
return client.post
case HttpRequestMethod.PUT:
return client.put
case HttpRequestMethod.DELETE:
return client.delete
case HttpRequestMethod.PATCH:
return client.patch
case HttpRequestMethod.HEAD:
return client.head
case _:
raise RuntimeError(f"HttpRequest method not supported: {self.typed_config.method}")
def build_conditional_edge_expressions(self):
"""
Build conditional edge expressions for workflow branching.
When the HTTP error handling strategy is set to `BRANCH`,
this node exposes a single conditional output labeled "ERROR".
The workflow engine uses this output to create an explicit
error-handling branch for downstream nodes.
Returns:
list[str]:
- ["ERROR"] if error handling strategy is BRANCH
- An empty list if no conditional branching is required
"""
if self.typed_config.error_handle.method == HttpErrorHandle.BRANCH:
return ["ERROR"]
return []
async def execute(self, state: WorkflowState) -> dict | str:
"""
Execute the HTTP request node.
Execution flow:
1. Initialize AsyncClient with configured options
2. Perform HTTP request with retry mechanism
3. Apply configured error handling strategy on failure
Args:
state: Current workflow runtime state.
Returns:
- dict: Serialized HttpRequestNodeOutput on success
- str: Branch identifier (e.g. "ERROR") when branching is enabled
"""
async with httpx.AsyncClient(
verify=self.typed_config.verify_ssl,
timeout=self._build_timeout(),
headers=self._build_header(state) | self._build_auth(state),
params=self._build_params(state),
) as client:
retries = self.typed_config.retry.max_attempts
while retries > 0:
try:
request_func = self._get_client_method(client)
resp = await request_func(
url=self._render_template(self.typed_config.url, state),
**self._build_content(state)
)
resp.raise_for_status()
return HttpRequestNodeOutput(
body=resp.text,
status_code=resp.status_code,
headers=resp.headers,
).model_dump()
except (httpx.HTTPStatusError, httpx.RequestError) as e:
logger.error(f"HTTP request node exception: {e}")
retries -= 1
if retries > 0:
await asyncio.sleep(self.typed_config.retry.retry_interval / 1000)
else:
match self.typed_config.error_handle.method:
case HttpErrorHandle.NONE:
return HttpRequestNodeOutput(
body="",
status_code=resp.status_code,
headers=resp.headers,
).model_dump()
case HttpErrorHandle.DEFAULT:
return self.typed_config.error_handle.default.model_dump()
case HttpErrorHandle.BRANCH:
return "ERROR"

View File

@@ -1,10 +1,11 @@
import logging import logging
import uuid
from typing import Any from typing import Any
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
from app.db import get_db_context from app.db import get_db_read
from app.models import knowledge_model, knowledgeshare_model from app.models import knowledge_model, knowledgeshare_model
from app.repositories import knowledge_repository from app.repositories import knowledge_repository
from app.schemas.chunk_schema import RetrieveType from app.schemas.chunk_schema import RetrieveType
@@ -18,38 +19,119 @@ class KnowledgeRetrievalNode(BaseNode):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config)
self.typed_config = KnowledgeRetrievalNodeConfig(**self.config) self.typed_config = KnowledgeRetrievalNodeConfig(**self.config)
@staticmethod
def _build_kb_filter(kb_ids: list[uuid.UUID], permission: knowledge_model.PermissionType):
    """
    Assemble the SQLAlchemy conditions that select usable knowledge bases.

    A knowledge base matches when all of the following hold:
    - its ID is one of ``kb_ids``
    - its permission type equals ``permission``
    - it has at least one chunk (``chunk_num > 0``)
    - its ``status`` column equals 1

    Args:
        kb_ids (list[UUID]): Candidate knowledge base IDs.
        permission (PermissionType): Required permission type.

    Returns:
        list: SQLAlchemy filter expressions.
    """
    kb = knowledge_model.Knowledge
    id_matches = kb.id.in_(kb_ids)
    permission_matches = kb.permission_id == permission
    has_chunks = kb.chunk_num > 0
    status_matches = kb.status == 1
    return [id_matches, permission_matches, has_chunks, status_matches]
@staticmethod
def _deduplicate_docs(*doc_lists):
"""
Deduplicate documents from multiple retrieval result lists
while preserving original order.
Deduplication is based on `doc.metadata["doc_id"]`.
Args:
*doc_lists: Multiple lists of retrieved documents.
Returns:
list: Deduplicated document list.
"""
seen = set()
unique = []
for doc in (doc for lst in doc_lists for doc in lst):
doc_id = doc.metadata["doc_id"]
if doc_id not in seen:
seen.add(doc_id)
unique.append(doc)
return unique
def _get_existing_kb_ids(self, db, kb_ids):
    """
    Resolve the knowledge base IDs that retrieval may run against.

    The result starts with the configured knowledge bases that pass the
    Private-permission filter. If any configured knowledge base passes the
    Share-permission filter, the source knowledge base IDs of the sharing
    relations whose target is one of ``kb_ids`` are appended.

    Args:
        db: Database session.
        kb_ids (list[UUID]): Knowledge base IDs from node configuration.

    Returns:
        list[UUID]: Final list of valid knowledge base IDs.
    """
    private_filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Private)
    existing_ids = knowledge_repository.get_chunked_knowledgeids(db=db, filters=private_filters)

    shared_filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Share)
    share_ids = knowledge_service.knowledge_repository.get_chunked_knowledgeids(
        db=db,
        filters=shared_filters
    )
    if not share_ids:
        return existing_ids

    # NOTE(review): the lookup is keyed on all configured ``kb_ids`` rather
    # than on ``share_ids`` only - confirm this is intended.
    target_filters = [knowledgeshare_model.KnowledgeShare.target_kb_id.in_(kb_ids)]
    source_ids = knowledgeshare_service.knowledgeshare_repository.get_source_kb_ids_by_target_kb_id(
        db=db,
        filters=target_filters
    )
    existing_ids.extend(source_ids)
    return existing_ids
async def execute(self, state: WorkflowState) -> Any: async def execute(self, state: WorkflowState) -> Any:
"""
Execute the knowledge retrieval workflow node.
Steps:
1. Render query template using workflow state
2. Resolve accessible knowledge bases
3. Initialize Elasticsearch vector service
4. Perform retrieval based on configured retrieve type
5. Deduplicate results if necessary
6. Serialize and return retrieved chunks
Args:
state (WorkflowState): Current workflow execution state.
Returns:
Any: List of retrieved knowledge chunks (dict format).
Raises:
RuntimeError: If no valid knowledge base is found or access is denied.
"""
query = self._render_template(self.typed_config.query, state) query = self._render_template(self.typed_config.query, state)
with get_db_context() as db: with get_db_read() as db:
filters = [ existing_ids = self._get_existing_kb_ids(db, self.typed_config.kb_ids)
knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids),
knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Private,
knowledge_model.Knowledge.chunk_num > 0,
knowledge_model.Knowledge.status == 1
]
existing_ids = knowledge_repository.get_chunked_knowledgeids(
db=db,
filters=filters
)
filters = [
knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids),
knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Share,
knowledge_model.Knowledge.chunk_num > 0,
knowledge_model.Knowledge.status == 1
]
share_ids = knowledge_service.knowledge_repository.get_chunked_knowledgeids(
db=db,
filters=filters
)
if share_ids:
filters = [
knowledgeshare_model.KnowledgeShare.target_kb_id.in_(self.typed_config.kb_ids)
]
items = knowledgeshare_service.knowledgeshare_repository.get_source_kb_ids_by_target_kb_id(
db=db,
filters=filters
)
existing_ids.extend(items)
if not existing_ids: if not existing_ids:
raise RuntimeError("Knowledge base retrieval failed: the knowledge base does not exist.") raise RuntimeError("Knowledge base retrieval failed: the knowledge base does not exist.")
@@ -69,12 +151,10 @@ class KnowledgeRetrievalNode(BaseNode):
rs = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k, rs = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k,
indices=indices, indices=indices,
score_threshold=self.typed_config.similarity_threshold) score_threshold=self.typed_config.similarity_threshold)
return [chunk.model_dump() for chunk in rs]
case RetrieveType.SEMANTIC: case RetrieveType.SEMANTIC:
rs = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k, rs = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k,
indices=indices, indices=indices,
score_threshold=self.typed_config.vector_similarity_weight) score_threshold=self.typed_config.vector_similarity_weight)
return [chunk.model_dump() for chunk in rs]
case _: case _:
rs1 = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k, rs1 = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k,
indices=indices, indices=indices,
@@ -82,12 +162,6 @@ class KnowledgeRetrievalNode(BaseNode):
rs2 = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k, rs2 = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k,
indices=indices, indices=indices,
score_threshold=self.typed_config.similarity_threshold) score_threshold=self.typed_config.similarity_threshold)
# Efficient deduplication # Deduplicate hybrid retrieval results
seen_ids = set() rs = self._deduplicate_docs(rs1, rs2)
unique_rs = [] return [chunk.model_dump() for chunk in rs]
for doc in rs1 + rs2:
if doc.metadata["doc_id"] not in seen_ids:
seen_ids.add(doc.metadata["doc_id"])
unique_rs.append(doc)
rs = vector_service.rerank(query=query, docs=unique_rs, top_k=self.typed_config.top_k)
return [chunk.model_dump() for chunk in rs]

View File

@@ -8,6 +8,7 @@ import logging
from typing import Any, Union from typing import Any, Union
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
from app.core.workflow.nodes.http_request import HttpRequestNode
from app.core.workflow.nodes.agent import AgentNode from app.core.workflow.nodes.agent import AgentNode
from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.end import EndNode from app.core.workflow.nodes.end import EndNode
@@ -29,6 +30,7 @@ WorkflowNode = Union[
AgentNode, AgentNode,
TransformNode, TransformNode,
AssignerNode, AssignerNode,
HttpRequestNode,
KnowledgeRetrievalNode, KnowledgeRetrievalNode,
] ]
@@ -49,6 +51,7 @@ class NodeFactory:
NodeType.IF_ELSE: IfElseNode, NodeType.IF_ELSE: IfElseNode,
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode, NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
NodeType.ASSIGNER: AssignerNode, NodeType.ASSIGNER: AssignerNode,
NodeType.HTTP_REQUEST: HttpRequestNode,
} }
@classmethod @classmethod