Merge #51 into develop from feature/20251219_myh

feat(workflow): add HTTP request node

* feature/20251219_myh: (3 commits)
  feat(workflow): add HTTP request node
  refactor(workflow): organize knowledge base code structure and add comments
  fix(workflow): correct property usage in HTTP request node

Signed-off-by: Eternity <1533512157@qq.com>
Commented-by: Eternity <1533512157@qq.com>
Reviewed-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com>
Merged-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com>

CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/51
This commit is contained in:
朱文辉
2025-12-25 13:52:35 +08:00
9 changed files with 660 additions and 85 deletions

View File

@@ -16,6 +16,7 @@ from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.expression_evaluator import evaluate_condition
from app.core.workflow.nodes import WorkflowState, NodeFactory
from app.core.workflow.nodes.enums import NodeType
# from app.core.tools.registry import ToolRegistry
# from app.core.tools.executor import ToolExecutor
# from app.core.tools.langchain_adapter import LangchainAdapter
@@ -69,7 +70,7 @@ class WorkflowExecutor:
初始化的工作流状态
"""
user_message = input_data.get("message") or ""
# 会话变量处理从配置文件获取变量定义列表转换为字典name -> default value
config_variables_list = self.workflow_config.get("variables") or []
conversation_vars = {}
@@ -78,19 +79,20 @@ class WorkflowExecutor:
var_name = var_def.get("name")
var_default = var_def.get("default")
if var_name:
# TODO: 入参类型校验
conversation_vars[var_name] = var_default
input_variables = input_data.get("variables") or {} # Start 节点的自定义变量
# 构建分层的变量结构
variables = {
"sys": {
"message": user_message, # 用户消息
"message": user_message, # 用户消息
"conversation_id": input_data.get("conversation_id"), # 会话 ID
"execution_id": self.execution_id, # 执行 ID
"workspace_id": self.workspace_id, # 工作空间 ID
"user_id": self.user_id, # 用户 ID
"input_variables": input_variables, # 自定义输入变量(给 Start 节点使用)
"execution_id": self.execution_id, # 执行 ID
"workspace_id": self.workspace_id, # 工作空间 ID
"user_id": self.user_id, # 用户 ID
"input_variables": input_variables, # 自定义输入变量(给 Start 节点使用)
},
"conv": conversation_vars # 会话级变量(跨多轮对话保持)
}
@@ -108,8 +110,6 @@ class WorkflowExecutor:
"streaming_buffer": {} # 流式缓冲区
}
def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]:
"""分析 End 节点的前缀配置
@@ -120,72 +120,72 @@ class WorkflowExecutor:
元组:({上游节点ID: End节点前缀}, {与End相邻且被引用的节点ID集合})
"""
import re
prefixes = {}
adjacent_and_referenced = set() # 记录与 End 节点相邻且被引用的节点
# 找到所有 End 节点
end_nodes = [node for node in self.nodes if node.get("type") == "end"]
logger.info(f"[前缀分析] 找到 {len(end_nodes)} 个 End 节点")
for end_node in end_nodes:
end_node_id = end_node.get("id")
output_template = end_node.get("config", {}).get("output")
logger.info(f"[前缀分析] End 节点 {end_node_id} 模板: {output_template}")
if not output_template:
continue
# 找到所有直接连接到 End 节点的上游节点
direct_upstream_nodes = []
for edge in self.edges:
if edge.get("target") == end_node_id:
source_node_id = edge.get("source")
direct_upstream_nodes.append(source_node_id)
logger.info(f"[前缀分析] End 节点的直接上游节点: {direct_upstream_nodes}")
# 查找模板中引用了哪些节点
# 匹配 {{node_id.xxx}} 或 {{ node_id.xxx }} 格式(支持空格)
pattern = r'\{\{\s*([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\s*\}\}'
matches = list(re.finditer(pattern, output_template))
logger.info(f"[前缀分析] 模板中找到 {len(matches)} 个节点引用")
# 找到第一个直接上游节点的引用
for match in matches:
referenced_node_id = match.group(1)
logger.info(f"[前缀分析] 检查引用: {referenced_node_id}")
if referenced_node_id in direct_upstream_nodes:
# 这是直接上游节点的引用,提取前缀
prefix = output_template[:match.start()]
logger.info(f"[前缀分析] ✅ 找到直接上游节点 {referenced_node_id} 的引用,前缀: '{prefix}'")
# 标记这个节点为"相邻且被引用"
adjacent_and_referenced.add(referenced_node_id)
if prefix:
prefixes[referenced_node_id] = prefix
logger.info(f"✅ [前缀分析] 为节点 {referenced_node_id} 配置前缀: '{prefix[:50]}...'")
# 只处理第一个直接上游节点的引用
break
logger.info(f"[前缀分析] 最终配置: {prefixes}")
logger.info(f"[前缀分析] 与 End 相邻且被引用的节点: {adjacent_and_referenced}")
return prefixes, adjacent_and_referenced
def build_graph(self,stream=False) -> CompiledStateGraph:
def build_graph(self, stream=False) -> CompiledStateGraph:
"""构建 LangGraph
Returns:
编译后的状态图
"""
logger.info(f"开始构建工作流图: execution_id={self.execution_id}")
# 分析 End 节点的前缀配置和相邻且被引用的节点
end_prefixes, adjacent_and_referenced = self._analyze_end_node_prefixes() if stream else ({}, set())
@@ -209,7 +209,7 @@ class WorkflowExecutor:
# 创建节点实例(现在 start 和 end 也会被创建)
node_instance = NodeFactory.create_node(node, self.workflow_config)
if node_type in [NodeType.IF_ELSE]:
if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST]:
expressions = node_instance.build_conditional_edge_expressions()
# Number of branches, usually matches the number of conditional expressions
@@ -232,13 +232,13 @@ class WorkflowExecutor:
# 将 End 前缀配置注入到节点实例
node_instance._end_node_prefix = end_prefixes[node_id]
logger.info(f"为节点 {node_id} 注入 End 前缀配置")
# 如果是流式模式,标记节点是否与 End 相邻且被引用
if stream:
node_instance._is_adjacent_to_end = node_id in adjacent_and_referenced
if node_id in adjacent_and_referenced:
logger.info(f"节点 {node_id} 标记为与 End 相邻且被引用")
# 包装节点的 run 方法
# 使用函数工厂避免闭包问题
if stream:
@@ -249,14 +249,18 @@ class WorkflowExecutor:
# logger.debug(f"流式执行节点: {inst.node_id}, 支持流式: {inst.supports_streaming()}")
async for item in inst.run_stream(state):
yield item
return node_func
workflow.add_node(node_id, make_stream_func(node_instance))
else:
# 非流式模式:创建 async function
def make_func(inst):
async def node_func(state: WorkflowState):
return await inst.run(state)
return node_func
workflow.add_node(node_id, make_func(node_instance))
logger.debug(f"添加节点: {node_id} (type={node_type}, stream={stream})")
@@ -426,7 +430,7 @@ class WorkflowExecutor:
# 记录开始时间
start_time = datetime.datetime.now()
# 发送 workflow_start 事件
yield {
"event": "workflow_start",
@@ -447,7 +451,7 @@ class WorkflowExecutor:
try:
chunk_count = 0
final_state = None
async for event in graph.astream(
initial_state,
stream_mode=["updates", "debug", "custom"], # Use updates + debug + custom mode
@@ -466,7 +470,7 @@ class WorkflowExecutor:
chunk_count += 1
event_type = data.get("type", "node_chunk") # "message" or "node_chunk"
logger.info(f"[CUSTOM] ✅ 收到 {event_type} #{chunk_count} from {data.get('node_id')}")
yield {
"event": event_type, # "message" or "node_chunk"
"data": {
@@ -478,7 +482,7 @@ class WorkflowExecutor:
"is_suffix": data.get("is_suffix")
}
}
elif mode == "debug":
# Handle debug information (node execution status)
event_type = data.get("type")
@@ -493,7 +497,7 @@ class WorkflowExecutor:
conversation_id = variables_sys.get("conversation_id")
execution_id = variables_sys.get("execution_id")
logger.info(f"[DEBUG] Node starts execution: {node_name}")
yield {
"event": "node_start",
"data": {
@@ -512,7 +516,7 @@ class WorkflowExecutor:
conversation_id = variables_sys.get("conversation_id")
execution_id = variables_sys.get("execution_id")
logger.info(f"[DEBUG] Node execution completed: {node_name}")
yield {
"event": "node_end",
"data": {
@@ -527,13 +531,16 @@ class WorkflowExecutor:
# Handle state updates - store final state
logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())}")
final_state = data
# 计算耗时
end_time = datetime.datetime.now()
elapsed_time = (end_time - start_time).total_seconds()
logger.info(f"Workflow execution completed (streaming), total chunks: {chunk_count}, elapsed: {elapsed_time:.2f}s")
logger.info(
f"Workflow execution completed (streaming), "
f"total chunks: {chunk_count}, elapsed: {elapsed_time:.2f}s"
)
# 发送 workflow_end 事件
yield {
"event": "workflow_end",
@@ -551,7 +558,7 @@ class WorkflowExecutor:
elapsed_time = (end_time - start_time).total_seconds()
logger.error(f"工作流执行失败: execution_id={self.execution_id}, error={e}", exc_info=True)
# 发送 workflow_end 事件(失败)
yield {
"event": "workflow_end",
@@ -678,7 +685,6 @@ async def execute_workflow_stream(
async for event in executor.execute_stream(input_data):
yield event
# ==================== 工具管理系统集成 ====================
# def get_workflow_tools(workspace_id: str, user_id: str) -> list:
@@ -852,4 +858,4 @@ async def execute_workflow_stream(
# NodeFactory.register_node_type("tool", ToolWorkflowNode)
# logger.info("工具节点已注册到工作流系统")
# except Exception as e:
# logger.warning(f"注册工具节点失败: {e}")
# logger.warning(f"注册工具节点失败: {e}")

View File

@@ -11,6 +11,7 @@ from app.core.workflow.nodes.end import EndNode
from app.core.workflow.nodes.if_else import IfElseNode
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.http_request import HttpRequestNode
from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode
from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.transform import TransformNode
@@ -28,4 +29,5 @@ __all__ = [
"WorkflowNode",
"KnowledgeRetrievalNode",
"AssignerNode",
"HttpRequestNode"
]

View File

@@ -15,6 +15,7 @@ from app.core.workflow.nodes.agent.config import AgentNodeConfig
from app.core.workflow.nodes.transform.config import TransformNodeConfig
from app.core.workflow.nodes.if_else.config import IfElseNodeConfig
from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
__all__ = [
@@ -32,4 +33,5 @@ __all__ = [
"IfElseNodeConfig",
"KnowledgeRetrievalNodeConfig",
"AssignerNodeConfig",
"HttpRequestNodeConfig"
]

View File

@@ -73,3 +73,34 @@ class AssignmentOperator(StrEnum):
return ObjectOperator
raise TypeError(f"Unsupported variable type ({type(obj)})")
class HttpRequestMethod(StrEnum):
GET = "GET"
POST = "POST"
HEAD = "HEAD"
PUT = "PUT"
PATCH = "PATCH"
DELETE = "DELETE"
class HttpAuthType(StrEnum):
NONE = "none"
BASIC = "basic"
BEARER = "bearer"
CUSTOM = "custom"
class HttpContentType(StrEnum):
NONE = "none"
FROM_DATA = "form-data"
WWW_FORM = "x-www-form-urlencoded"
JSON = "json"
RAW = "raw"
BINARY = "binary"
class HttpErrorHandle(StrEnum):
NONE = "none"
DEFAULT = "default"
BRANCH = "branch"

View File

@@ -0,0 +1,4 @@
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig
from app.core.workflow.nodes.http_request.node import HttpRequestNode
__all__ = ['HttpRequestNode', 'HttpRequestNodeConfig']

View File

@@ -0,0 +1,215 @@
from typing import Literal
from pydantic import Field, BaseModel, field_validator
from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpAuthType, HttpContentType, HttpErrorHandle
class HttpAuthConfig(BaseModel):
    """Authentication settings of an HTTP request node.

    ``header`` and ``api_key`` are validated against ``auth_type``. Both
    fields set ``validate_default=True`` because pydantic does not run field
    validators on default values otherwise, which would let a configuration
    such as ``{"auth_type": "bearer"}`` (no key at all) pass validation.
    """

    auth_type: HttpAuthType = Field(
        default=HttpAuthType.NONE,
        description="Type of HTTP authentication to use",
    )
    header: str = Field(
        default="",
        validate_default=True,
        description="Custom HTTP Authorization header (used if auth_type is CUSTOM)",
    )
    api_key: str = Field(
        default="",
        validate_default=True,
        description="API key for authentication (used if auth_type is not NONE)",
    )

    @field_validator("header")
    @classmethod
    def validate_header(cls, v, info):
        """Require a non-empty header name when ``auth_type`` is CUSTOM.

        Raises:
            ValueError: If ``auth_type`` is CUSTOM and the header is empty.
        """
        auth_type = info.data.get("auth_type")
        if auth_type == HttpAuthType.CUSTOM and not v:
            raise ValueError("Custom auth header not specified")
        return v

    @field_validator("api_key")
    @classmethod
    def validate_api_key(cls, v, info):
        """Require a non-empty API key for every auth type except NONE.

        Raises:
            ValueError: If authentication is enabled and the key is empty.
        """
        auth_type = info.data.get("auth_type")
        if auth_type is None:
            # auth_type itself failed validation and is reported on its own;
            # do not pile a misleading "API key" error on top of it.
            return v
        if auth_type != HttpAuthType.NONE and not v:
            raise ValueError("API key for authentication not specified")
        return v
class HttpFormData(BaseModel):
    """A single key/value entry of a form-data request body."""

    # Name of the form field.
    key: str = Field(..., description="Form-data field name")
    # Kind of entry; only the two literal values below are accepted.
    type: Literal["text", "file"] = Field(..., description="Form-data type: 'text' or 'file'")
    # Value of the form field.
    value: str = Field(..., description="Form-data field value")
class HttpContentTypeConfig(BaseModel):
    """Request body configuration of an HTTP request node.

    The expected Python type of ``data`` depends on ``content_type``:

    - ``form-data``: a list of ``HttpFormData`` entries
    - ``json`` / ``x-www-form-urlencoded``: a dict
    - ``raw`` / ``binary``: a string
    - ``none``: not checked
    """

    content_type: HttpContentType = Field(
        ...,
        description="HTTP content type of the request body",
    )
    data: list[HttpFormData] | dict | str = Field(
        ...,
        description="Data of the HTTP request body; type depends on content_type",
    )

    @field_validator("data")
    @classmethod
    def validate_data(cls, v, info):
        """Check that ``data`` has the shape required by ``content_type``.

        Raises:
            ValueError: If the type of ``data`` does not match ``content_type``.
        """
        # content_type is missing from info.data when it failed its own
        # validation; in that case none of the branches below applies.
        content_type = info.data.get("content_type")
        if content_type == HttpContentType.FROM_DATA:
            # The field is declared as list[HttpFormData]; a bare
            # HttpFormData instance can never be the validated value, so the
            # check has to be on the list and its items.
            is_form_list = isinstance(v, list) and all(isinstance(item, HttpFormData) for item in v)
            if not is_form_list:
                raise ValueError("When content_type is 'form-data', data must be a list of HttpFormData")
        elif content_type in [HttpContentType.JSON, HttpContentType.WWW_FORM] and not isinstance(v, dict):
            raise ValueError("When content_type is JSON or x-www-form-urlencoded, data must be a object")
        elif content_type in [HttpContentType.RAW, HttpContentType.BINARY] and not isinstance(v, str):
            raise ValueError("When content_type is raw/binary, data must be a string (File descriptor)")
        return v
class HttpTimeOutConfig(BaseModel):
    """Per-phase timeout settings (in seconds) of an HTTP request node."""

    connect_timeout: int = Field(default=5, description="Connection timeout in seconds")
    read_timeout: int = Field(default=5, description="Read timeout in seconds")
    write_timeout: int = Field(default=5, description="Write timeout in seconds")
class HttpRetryConfig(BaseModel):
    """Retry settings of an HTTP request node."""

    # Total number of attempts, per the field description.
    max_attempts: int = Field(default=1, description="Maximum number of retry attempts for failed requests")
    # Pause between two attempts, expressed in milliseconds.
    retry_interval: int = Field(default=100, description="Interval between retries in milliseconds")
class HttpErrorDefaultTamplete(BaseModel):
    """Fallback response values used by the 'default' error-handling strategy."""

    body: str = Field(default="", description="Default body returned on HTTP error")
    status_code: int = Field(default=400, description="Default HTTP status code returned on error")
    # default_factory gives every instance its own dict.
    headers: dict = Field(default_factory=dict, description="Default HTTP headers returned on error")
class HttpErrorHandleConfig(BaseModel):
    """Error-handling configuration of an HTTP request node.

    ``default`` holds the fallback response template. It falls back to a
    template built from the template model's own field defaults, so a
    configuration that only names a strategy (or relies on the default
    strategy ``none``) does not have to spell out a template it may never use.
    """

    method: HttpErrorHandle = Field(
        default=HttpErrorHandle.NONE,
        description="Error handling strategy: 'none', 'default', or 'branch'",
    )
    default: HttpErrorDefaultTamplete = Field(
        default_factory=HttpErrorDefaultTamplete,
        description="Default response template for error handling",
    )
class HttpRequestNodeConfig(BaseNodeConfig):
    """Complete configuration of an HTTP request workflow node.

    Only ``headers`` and ``params`` have defaults; every other field below
    must be supplied.
    """

    # --- Request line -------------------------------------------------------
    method: HttpRequestMethod = Field(..., description="HTTP method for the request (GET, POST, etc.)")
    url: str = Field(..., description="URL of the HTTP request")

    # --- Authentication, headers and query string ---------------------------
    auth: HttpAuthConfig = Field(..., description="HTTP authentication configuration")
    headers: dict = Field(default_factory=dict, description="HTTP request headers")
    params: dict = Field(default_factory=dict, description="Query parameters for the HTTP request")

    # --- Body ---------------------------------------------------------------
    body: HttpContentTypeConfig = Field(..., description="HTTP request body configuration")

    # --- Transport behaviour ------------------------------------------------
    verify_ssl: bool = Field(..., description="Whether to verify SSL certificates")
    timeouts: HttpTimeOutConfig = Field(..., description="Timeout settings for the request")
    retry: HttpRetryConfig = Field(..., description="Retry configuration for failed requests")

    # --- Failure handling ---------------------------------------------------
    error_handle: HttpErrorHandleConfig = Field(..., description="Configuration for handling HTTP request errors")
class HttpRequestNodeOutput(BaseModel):
    """Result produced by an HTTP request node."""

    body: str = Field(..., description="Body of the HTTP response")
    status_code: int = Field(..., description="HTTP response status code")
    headers: dict = Field(..., description="Http response headers")
    # NOTE(review): the description says "HTTP response body", but the default
    # is the literal "SUCCESS", so this looks like a status / branch label
    # rather than a body — confirm intent before relying on the description.
    output: str = Field(default="SUCCESS", description="HTTP response body")

    # files: list[File] = Field(
    #     ...
    # )

View File

@@ -0,0 +1,238 @@
import asyncio
import json
import logging
from typing import Any, Callable, Coroutine
import httpx
# import filetypes # TODO: File support (Feature)
from httpx import AsyncClient, Response, Timeout
from app.core.workflow.nodes import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput
# Name the logger after the module's import path rather than its filesystem
# path (__file__), so it takes part in the dotted logger hierarchy and can be
# configured through its parent package loggers.
logger = logging.getLogger(__name__)
class HttpRequestNode(BaseNode):
"""
HTTP Request Workflow Node.
This node executes an HTTP request as part of a workflow execution.
It supports:
- Multiple HTTP methods (GET, POST, PUT, DELETE, PATCH, HEAD)
- Multiple authentication strategies
- Multiple request body content types
- Retry mechanism with configurable interval
- Flexible error handling strategies
The execution result is returned as a serialized HttpRequestNodeOutput,
or a branch identifier string when error branching is enabled.
"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = HttpRequestNodeConfig(**self.config)
def _build_timeout(self) -> Timeout:
"""
Build httpx Timeout configuration.
All four timeout dimensions are explicitly defined to avoid
implicit defaults that may lead to unpredictable behavior
in production environments.
"""
timeout = httpx.Timeout(
connect=self.typed_config.timeouts.connect_timeout,
read=self.typed_config.timeouts.read_timeout,
write=self.typed_config.timeouts.write_timeout,
pool=5
)
return timeout
def _build_auth(self, state: WorkflowState) -> dict[str, str]:
"""
Build authentication-related HTTP headers.
Authentication values support template rendering based on
the current workflow runtime state.
Args:
state: Current workflow runtime state.
Returns:
A dictionary of HTTP headers used for authentication.
"""
api_key = self._render_template(self.typed_config.auth.api_key, state)
match self.typed_config.auth.auth_type:
case HttpAuthType.NONE:
return {}
case HttpAuthType.BASIC:
return {
"Authorization": f"Basic {api_key}",
}
case HttpAuthType.BEARER:
return {
"Authorization": f"Bearer {api_key}",
}
case HttpAuthType.CUSTOM:
return {
self.typed_config.auth.header: api_key
}
case _:
raise RuntimeError(f"Auth type not supported: {self.typed_config.auth.auth_type}")
def _build_header(self, state: WorkflowState) -> dict[str, str]:
"""
Build HTTP request headers.
Both header keys and values support runtime template rendering.
"""
headers = {}
for key, value in self.typed_config.headers.items():
headers[self._render_template(key, state)] = self._render_template(value, state)
return headers
def _build_params(self, state: WorkflowState) -> dict[str, str]:
"""
Build URL query parameters.
Parameter keys and values support runtime template rendering.
"""
params = {}
for key, value in self.typed_config.params.items():
params[self._render_template(key, state)] = self._render_template(value, state)
return params
def _build_content(self, state) -> dict[str, Any]:
"""
Build HTTP request body arguments for httpx request methods.
The returned dictionary is directly unpacked into the httpx
request call (e.g., json=, data=, content=).
Returns:
A dictionary containing httpx-compatible request body arguments.
"""
content = {}
match self.typed_config.body.content_type:
case HttpContentType.NONE:
return {}
case HttpContentType.JSON:
content["json"] = json.loads(self._render_template(
json.dumps(self.typed_config.body.data), state
))
case HttpContentType.FROM_DATA:
data = {}
for item in self.typed_config.body.data:
if item.type == "text":
data[self._render_template(item.key, state)] = self._render_template(item.value, state)
elif item.type == "file":
# TODO: File support (Feature)
pass
content["data"] = data
case HttpContentType.BINARY:
# TODO: File support (Feature)
pass
case HttpContentType.WWW_FORM:
content["data"] = json.loads(self._render_template(
json.dumps(self.typed_config.body.data), state
))
case HttpContentType.RAW:
content["content"] = self._render_template(self.typed_config.body.data, state)
case _:
raise RuntimeError(f"Content type not supported: {self.typed_config.body.content_type}")
return content
def _get_client_method(self, client: AsyncClient) -> Callable[..., Coroutine[Any, Any, Response]]:
"""
Resolve the httpx AsyncClient method based on configured HTTP method.
"""
match self.typed_config.method:
case HttpRequestMethod.GET:
return client.get
case HttpRequestMethod.POST:
return client.post
case HttpRequestMethod.PUT:
return client.put
case HttpRequestMethod.DELETE:
return client.delete
case HttpRequestMethod.PATCH:
return client.patch
case HttpRequestMethod.HEAD:
return client.head
case _:
raise RuntimeError(f"HttpRequest method not supported: {self.typed_config.method}")
def build_conditional_edge_expressions(self):
"""
Build conditional edge expressions for workflow branching.
When the HTTP error handling strategy is set to `BRANCH`,
this node exposes a single conditional output labeled "ERROR".
The workflow engine uses this output to create an explicit
error-handling branch for downstream nodes.
Returns:
list[str]:
- ["ERROR"] if error handling strategy is BRANCH
- An empty list if no conditional branching is required
"""
if self.typed_config.error_handle.method == HttpErrorHandle.BRANCH:
return ["ERROR"]
return []
async def execute(self, state: WorkflowState) -> dict | str:
"""
Execute the HTTP request node.
Execution flow:
1. Initialize AsyncClient with configured options
2. Perform HTTP request with retry mechanism
3. Apply configured error handling strategy on failure
Args:
state: Current workflow runtime state.
Returns:
- dict: Serialized HttpRequestNodeOutput on success
- str: Branch identifier (e.g. "ERROR") when branching is enabled
"""
async with httpx.AsyncClient(
verify=self.typed_config.verify_ssl,
timeout=self._build_timeout(),
headers=self._build_header(state) | self._build_auth(state),
params=self._build_params(state),
) as client:
retries = self.typed_config.retry.max_attempts
while retries > 0:
try:
request_func = self._get_client_method(client)
resp = await request_func(
url=self._render_template(self.typed_config.url, state),
**self._build_content(state)
)
resp.raise_for_status()
return HttpRequestNodeOutput(
body=resp.text,
status_code=resp.status_code,
headers=resp.headers,
).model_dump()
except (httpx.HTTPStatusError, httpx.RequestError) as e:
logger.error(f"HTTP request node exception: {e}")
retries -= 1
if retries > 0:
await asyncio.sleep(self.typed_config.retry.retry_interval / 1000)
else:
match self.typed_config.error_handle.method:
case HttpErrorHandle.NONE:
return HttpRequestNodeOutput(
body="",
status_code=resp.status_code,
headers=resp.headers,
).model_dump()
case HttpErrorHandle.DEFAULT:
return self.typed_config.error_handle.default.model_dump()
case HttpErrorHandle.BRANCH:
return "ERROR"

View File

@@ -1,10 +1,11 @@
import logging
import uuid
from typing import Any
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
from app.db import get_db_context
from app.db import get_db_read
from app.models import knowledge_model, knowledgeshare_model
from app.repositories import knowledge_repository
from app.schemas.chunk_schema import RetrieveType
@@ -18,38 +19,119 @@ class KnowledgeRetrievalNode(BaseNode):
super().__init__(node_config, workflow_config)
self.typed_config = KnowledgeRetrievalNodeConfig(**self.config)
@staticmethod
def _build_kb_filter(kb_ids: list[uuid.UUID], permission: knowledge_model.PermissionType):
    """
    Assemble the filter expressions used to look up usable knowledge bases.

    The returned conditions require that a knowledge base:
    - has an ID contained in ``kb_ids``
    - has ``permission_id`` equal to ``permission``
    - has ``chunk_num`` greater than zero
    - has ``status`` equal to 1 (presumably the "active" status - confirm
      against the model definition)

    Args:
        kb_ids (list[UUID]): Candidate knowledge base IDs.
        permission (PermissionType): Required permission type.

    Returns:
        list: Filter expressions built on ``knowledge_model.Knowledge``.
    """
    kb = knowledge_model.Knowledge
    id_matches = kb.id.in_(kb_ids)
    permission_matches = kb.permission_id == permission
    has_chunks = kb.chunk_num > 0
    status_is_one = kb.status == 1
    return [id_matches, permission_matches, has_chunks, status_is_one]
@staticmethod
def _deduplicate_docs(*doc_lists):
"""
Deduplicate documents from multiple retrieval result lists
while preserving original order.
Deduplication is based on `doc.metadata["doc_id"]`.
Args:
*doc_lists: Multiple lists of retrieved documents.
Returns:
list: Deduplicated document list.
"""
seen = set()
unique = []
for doc in (doc for lst in doc_lists for doc in lst):
doc_id = doc.metadata["doc_id"]
if doc_id not in seen:
seen.add(doc_id)
unique.append(doc)
return unique
def _get_existing_kb_ids(self, db, kb_ids):
    """
    Resolve the knowledge base IDs that can actually be used for retrieval.

    Steps:
    1. Look up the IDs in ``kb_ids`` that pass the Private-permission filter.
    2. Look up the IDs in ``kb_ids`` that pass the Share-permission filter.
    3. If any shared knowledge base was found, fetch the source knowledge
       base IDs whose sharing relationship targets one of ``kb_ids`` and
       append them to the result of step 1.

    Args:
        db: Database session.
        kb_ids (list[UUID]): Knowledge base IDs from node configuration.

    Returns:
        list[UUID]: Final list of valid knowledge base IDs.
    """
    private_filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Private)
    resolved_ids = knowledge_repository.get_chunked_knowledgeids(db=db, filters=private_filters)

    shared_filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Share)
    shared_ids = knowledge_service.knowledge_repository.get_chunked_knowledgeids(db=db, filters=shared_filters)
    if not shared_ids:
        return resolved_ids

    # The shared IDs themselves are not added; only the source knowledge
    # bases behind the sharing relationships are.
    share_relation_filters = [knowledgeshare_model.KnowledgeShare.target_kb_id.in_(kb_ids)]
    source_ids = knowledgeshare_service.knowledgeshare_repository.get_source_kb_ids_by_target_kb_id(
        db=db,
        filters=share_relation_filters,
    )
    resolved_ids.extend(source_ids)
    return resolved_ids
async def execute(self, state: WorkflowState) -> Any:
"""
Execute the knowledge retrieval workflow node.
Steps:
1. Render query template using workflow state
2. Resolve accessible knowledge bases
3. Initialize Elasticsearch vector service
4. Perform retrieval based on configured retrieve type
5. Deduplicate results if necessary
6. Serialize and return retrieved chunks
Args:
state (WorkflowState): Current workflow execution state.
Returns:
Any: List of retrieved knowledge chunks (dict format).
Raises:
RuntimeError: If no valid knowledge base is found or access is denied.
"""
query = self._render_template(self.typed_config.query, state)
with get_db_context() as db:
filters = [
knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids),
knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Private,
knowledge_model.Knowledge.chunk_num > 0,
knowledge_model.Knowledge.status == 1
]
existing_ids = knowledge_repository.get_chunked_knowledgeids(
db=db,
filters=filters
)
filters = [
knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids),
knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Share,
knowledge_model.Knowledge.chunk_num > 0,
knowledge_model.Knowledge.status == 1
]
share_ids = knowledge_service.knowledge_repository.get_chunked_knowledgeids(
db=db,
filters=filters
)
if share_ids:
filters = [
knowledgeshare_model.KnowledgeShare.target_kb_id.in_(self.typed_config.kb_ids)
]
items = knowledgeshare_service.knowledgeshare_repository.get_source_kb_ids_by_target_kb_id(
db=db,
filters=filters
)
existing_ids.extend(items)
with get_db_read() as db:
existing_ids = self._get_existing_kb_ids(db, self.typed_config.kb_ids)
if not existing_ids:
raise RuntimeError("Knowledge base retrieval failed: the knowledge base does not exist.")
@@ -69,12 +151,10 @@ class KnowledgeRetrievalNode(BaseNode):
rs = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k,
indices=indices,
score_threshold=self.typed_config.similarity_threshold)
return [chunk.model_dump() for chunk in rs]
case RetrieveType.SEMANTIC:
rs = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k,
indices=indices,
score_threshold=self.typed_config.vector_similarity_weight)
return [chunk.model_dump() for chunk in rs]
case _:
rs1 = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k,
indices=indices,
@@ -82,12 +162,6 @@ class KnowledgeRetrievalNode(BaseNode):
rs2 = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k,
indices=indices,
score_threshold=self.typed_config.similarity_threshold)
# Efficient deduplication
seen_ids = set()
unique_rs = []
for doc in rs1 + rs2:
if doc.metadata["doc_id"] not in seen_ids:
seen_ids.add(doc.metadata["doc_id"])
unique_rs.append(doc)
rs = vector_service.rerank(query=query, docs=unique_rs, top_k=self.typed_config.top_k)
return [chunk.model_dump() for chunk in rs]
# Deduplicate hybrid retrieval results
rs = self._deduplicate_docs(rs1, rs2)
return [chunk.model_dump() for chunk in rs]

View File

@@ -8,6 +8,7 @@ import logging
from typing import Any, Union
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
from app.core.workflow.nodes.http_request import HttpRequestNode
from app.core.workflow.nodes.agent import AgentNode
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.end import EndNode
@@ -29,6 +30,7 @@ WorkflowNode = Union[
AgentNode,
TransformNode,
AssignerNode,
HttpRequestNode,
KnowledgeRetrievalNode,
]
@@ -49,6 +51,7 @@ class NodeFactory:
NodeType.IF_ELSE: IfElseNode,
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
NodeType.ASSIGNER: AssignerNode,
NodeType.HTTP_REQUEST: HttpRequestNode,
}
@classmethod