Merge #78 into develop from feature/20251219_myh
Merge branch 'develop' into feature/20251219_myh
* feature/20251219_myh: (8 commits squashed)
- feat(workflow): update reranker model configuration for knowledge base retrieval
- fix(workflow): fix output issue in parameter extraction node
- fix(workflow): fix output issue in parameter extraction node
- feat(workflow): add user prompt to parameter extraction node
- perf(workflow): change grouped variable input to key-value format in variable aggregator
- feat(workflow): Add new cycle node for iterative workflow execution
- Introduce a new Loop/Iteration node in the workflow engine.
- Supports both conditional loops and iteration over lists.
- Allows parallel execution and flattening of iteration outputs.
- Maintains runtime state, node outputs, and loop variables for downstream nodes.
- Enhances workflow flexibility for complex, repeated operations.
- Merge branch 'develop' into feature/20251219_myh
# Conflicts:
#	api/app/core/workflow/nodes/configs.py
#	api/app/core/workflow/nodes/node_factory.py
- feat(workflow): Add new cycle node for iterative workflow execution
- Introduce a new Loop/Iteration node in the workflow engine.
- Supports both conditional loops and iteration over lists.
- Allows parallel execution and flattening of iteration outputs.
- Maintains runtime state, node outputs, and loop variables for downstream nodes.
- Enhances workflow flexibility for complex, repeated operations.
Signed-off-by: Eternity <1533512157@qq.com>
Reviewed-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com>
Reviewed-by: aliyun6762716068 <accounts_68cb7c6b61f5dcc4200d6251@mail.teambition.com>
Merged-by: aliyun6762716068 <accounts_68cb7c6b61f5dcc4200d6251@mail.teambition.com>
CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/78
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
|
||||
from typing import Any, Dict, List, Optional, Sequence, Type, Union
|
||||
from copy import deepcopy
|
||||
from urllib.parse import urlparse
|
||||
@@ -8,8 +7,10 @@ from langchain_core.callbacks import Callbacks
|
||||
from app.core.models.base import RedBearModelConfig, get_provider_rerank_class, RedBearModelFactory
|
||||
from app.models import ModelProvider
|
||||
|
||||
|
||||
class RedBearRerank(BaseDocumentCompressor):
|
||||
""" Rerank → 作为 Runnable 插入任意 LCEL 链"""
|
||||
|
||||
def __init__(self, config: RedBearModelConfig):
|
||||
self._model = self._create_model(config)
|
||||
self._config = config
|
||||
@@ -22,10 +23,10 @@ class RedBearRerank(BaseDocumentCompressor):
|
||||
return model_class(**model_params)
|
||||
|
||||
def compress_documents(
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
query: str,
|
||||
callbacks: Optional[Callbacks] = None,
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
query: str,
|
||||
callbacks: Optional[Callbacks] = None,
|
||||
) -> Sequence[Document]:
|
||||
"""
|
||||
Compress documents using Jina's Rerank API.
|
||||
@@ -46,17 +47,17 @@ class RedBearRerank(BaseDocumentCompressor):
|
||||
compressed.append(doc_copy)
|
||||
return compressed
|
||||
|
||||
|
||||
def rerank(
|
||||
self,
|
||||
documents: Sequence[Union[str, Document, dict]],
|
||||
query: str,
|
||||
*,
|
||||
top_n: Optional[int] = -1,
|
||||
) -> List[Dict[str, Any]]:
|
||||
provider = self._config.provider.lower()
|
||||
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
|
||||
self,
|
||||
documents: Sequence[Union[str, Document, dict]],
|
||||
query: str,
|
||||
*,
|
||||
top_n: Optional[int] = -1,
|
||||
) -> List[Dict[str, Any]]:
|
||||
provider = self._config.provider.lower()
|
||||
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||
import langchain_community.document_compressors.jina_rerank as jina_mod
|
||||
|
||||
# 规范化:如果不以 /v1/rerank 结尾,则补齐;若已以 /v1 结尾,则补 /rerank
|
||||
def _normalize_jina_base(base_url: Optional[str]) -> Optional[str]:
|
||||
if not base_url:
|
||||
@@ -73,8 +74,7 @@ class RedBearRerank(BaseDocumentCompressor):
|
||||
# 设置完整的 rerank 端点,例如 http://host:port/v1/rerank
|
||||
jina_mod.JINA_API_URL = jina_base
|
||||
from langchain_community.document_compressors import JinaRerank
|
||||
model_instance : JinaRerank = self._model
|
||||
return model_instance.rerank(documents = documents, query = query, top_n=top_n)
|
||||
model_instance: JinaRerank = self._model
|
||||
return model_instance.rerank(documents=documents, query=query, top_n=top_n)
|
||||
else:
|
||||
raise ValueError(f"不支持的模型提供商: {provider}")
|
||||
|
||||
@@ -4,9 +4,9 @@
|
||||
基于 LangGraph 的工作流执行引擎。
|
||||
"""
|
||||
|
||||
import logging
|
||||
# import uuid
|
||||
import datetime
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
@@ -107,7 +107,13 @@ class WorkflowExecutor:
|
||||
"user_id": self.user_id,
|
||||
"error": None,
|
||||
"error_node": None,
|
||||
"streaming_buffer": {} # 流式缓冲区
|
||||
"streaming_buffer": {}, # 流式缓冲区
|
||||
"cycle_nodes": [
|
||||
node.get("id")
|
||||
for node in self.workflow_config.get("nodes")
|
||||
if node.get("type") in [NodeType.LOOP, NodeType.ITERATION]
|
||||
], # loop, iteration node id
|
||||
"looping": False # loop runing flag, only use in loop node,not use in main loop
|
||||
}
|
||||
|
||||
def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]:
|
||||
@@ -199,6 +205,10 @@ class WorkflowExecutor:
|
||||
for node in self.nodes:
|
||||
node_type = node.get("type")
|
||||
node_id = node.get("id")
|
||||
cycle_node = node.get("cycle")
|
||||
if cycle_node:
|
||||
# 处于循环子图中的节点由 CycleGraphNode 进行构建处理
|
||||
continue
|
||||
|
||||
# 记录 start 和 end 节点 ID
|
||||
if node_type == NodeType.START:
|
||||
@@ -271,7 +281,7 @@ class WorkflowExecutor:
|
||||
workflow.add_edge(START, start_node_id)
|
||||
logger.debug(f"添加边: START -> {start_node_id}")
|
||||
|
||||
for edge in self.edges:
|
||||
for edge in self.workflow_config.get("edges", []):
|
||||
source = edge.get("source")
|
||||
target = edge.get("target")
|
||||
edge_type = edge.get("type")
|
||||
@@ -284,12 +294,12 @@ class WorkflowExecutor:
|
||||
logger.debug(f"添加边: {source} -> {target}")
|
||||
continue
|
||||
|
||||
# 处理到 end 节点的边
|
||||
if target in end_node_ids:
|
||||
# 连接到 end 节点
|
||||
workflow.add_edge(source, target)
|
||||
logger.debug(f"添加边: {source} -> {target}")
|
||||
continue
|
||||
# # 处理到 end 节点的边
|
||||
# if target in end_node_ids:
|
||||
# # 连接到 end 节点
|
||||
# workflow.add_edge(source, target)
|
||||
# logger.debug(f"添加边: {source} -> {target}")
|
||||
# continue
|
||||
|
||||
# 跳过错误边(在节点内部处理)
|
||||
if edge_type == "error":
|
||||
|
||||
@@ -74,6 +74,7 @@ class ExpressionEvaluator:
|
||||
# 为了向后兼容,也支持直接访问(但会在日志中警告)
|
||||
context.update(variables)
|
||||
context["nodes"] = node_outputs
|
||||
context.update(node_outputs)
|
||||
|
||||
try:
|
||||
# simpleeval 只支持安全的操作:
|
||||
|
||||
@@ -6,7 +6,7 @@ from app.core.workflow.expression_evaluator import ExpressionEvaluator
|
||||
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.enums import AssignmentOperator
|
||||
from app.core.workflow.nodes.operators import AssignmentOperatorInstance
|
||||
from app.core.workflow.nodes.operators import AssignmentOperatorInstance, AssignmentOperatorResolver
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -40,8 +40,8 @@ class AssignerNode(BaseNode):
|
||||
variable_selector = expression.split('.')
|
||||
|
||||
# Only conversation variables ('conv') are allowed
|
||||
if variable_selector[0] != 'conv': # TODO: Loop node variable support (Feature)
|
||||
raise ValueError("Only conversation variables can be assigned.")
|
||||
if variable_selector[0] != 'conv' and variable_selector[0] not in state["cycle_nodes"]:
|
||||
raise ValueError("Only conversation or cycle variables can be assigned.")
|
||||
|
||||
# Get the value or expression to assign
|
||||
value = assignment.value
|
||||
@@ -55,7 +55,9 @@ class AssignerNode(BaseNode):
|
||||
)
|
||||
|
||||
# Select the appropriate assignment operator instance based on the target variable type
|
||||
operator: AssignmentOperatorInstance = AssignmentOperator.get_operator(pool.get(variable_selector))(
|
||||
operator: AssignmentOperatorInstance = AssignmentOperatorResolver.resolve_by_value(
|
||||
pool.get(variable_selector)
|
||||
)(
|
||||
pool, variable_selector, value
|
||||
)
|
||||
|
||||
|
||||
@@ -14,9 +14,13 @@ class VariableType(StrEnum):
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
BOOLEAN = "boolean"
|
||||
ARRAY = "array"
|
||||
OBJECT = "object"
|
||||
|
||||
ARRAY_STRING = "array[string]"
|
||||
ARRAY_NUMBER = "array[number]"
|
||||
ARRAY_BOOLEAN = "array[boolean]"
|
||||
ARRAY_OBJECT = "array[object]"
|
||||
|
||||
|
||||
class VariableDefinition(BaseModel):
|
||||
"""变量定义
|
||||
|
||||
@@ -20,40 +20,44 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowState(TypedDict):
|
||||
"""工作流状态
|
||||
|
||||
在节点间传递的状态对象,包含消息、变量、节点输出等信息。
|
||||
"""Workflow state
|
||||
|
||||
The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc.
|
||||
"""
|
||||
# 消息列表(追加模式)
|
||||
# List of messages (append mode)
|
||||
messages: Annotated[list[AnyMessage], add]
|
||||
|
||||
# 输入变量(从配置的 variables 传入)
|
||||
# 使用深度合并函数,支持嵌套字典的更新(如 conv.xxx)
|
||||
|
||||
# IDs of the loop/iteration nodes in the workflow (a list); lets the assigner node accept cycle variables as assignment targets
|
||||
cycle_nodes: list
|
||||
looping: bool
|
||||
|
||||
# Input variables (passed from configured variables)
|
||||
# Uses a deep merge function, supporting nested dict updates (e.g., conv.xxx)
|
||||
variables: Annotated[dict[str, Any], lambda x, y: {
|
||||
**x,
|
||||
**{k: {**x.get(k, {}), **v} if isinstance(v, dict) and isinstance(x.get(k), dict) else v
|
||||
for k, v in y.items()}
|
||||
}]
|
||||
|
||||
# 节点输出(存储每个节点的执行结果,用于变量引用)
|
||||
# 使用自定义合并函数,将新的节点输出合并到现有字典中
|
||||
|
||||
# Node outputs (stores execution results of each node for variable references)
|
||||
# Uses a custom merge function to combine new node outputs into the existing dictionary
|
||||
node_outputs: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
|
||||
|
||||
# 运行时节点变量(简化版,只存储业务数据,供节点间快速访问)
|
||||
# 格式:{node_id: business_result}
|
||||
|
||||
# Runtime node variables (simplified version, stores business data for fast access between nodes)
|
||||
# Format: {node_id: business_result}
|
||||
runtime_vars: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
|
||||
|
||||
# 执行上下文
|
||||
# Execution context
|
||||
execution_id: str
|
||||
workspace_id: str
|
||||
user_id: str
|
||||
|
||||
# 错误信息(用于错误边)
|
||||
# Error information (for error edges)
|
||||
error: str | None
|
||||
error_node: str | None
|
||||
|
||||
# 流式缓冲区(存储节点的实时流式输出)
|
||||
# 格式:{node_id: {"chunks": [...], "full_content": "..."}}
|
||||
|
||||
# Streaming buffer (stores real-time streaming output of nodes)
|
||||
# Format: {node_id: {"chunks": [...], "full_content": "..."}}
|
||||
streaming_buffer: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
|
||||
|
||||
|
||||
@@ -170,10 +174,10 @@ class BaseNode(ABC):
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
timeout = self.get_timeout()
|
||||
|
||||
try:
|
||||
timeout = self.get_timeout()
|
||||
|
||||
# 调用节点的业务逻辑
|
||||
business_result = await asyncio.wait_for(
|
||||
self.execute(state),
|
||||
@@ -200,7 +204,8 @@ class BaseNode(ABC):
|
||||
**wrapped_output,
|
||||
"runtime_vars": {
|
||||
self.node_id: runtime_var
|
||||
}
|
||||
},
|
||||
"looping": state["looping"]
|
||||
}
|
||||
|
||||
except TimeoutError:
|
||||
@@ -236,10 +241,10 @@ class BaseNode(ABC):
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
timeout = self.get_timeout()
|
||||
|
||||
try:
|
||||
timeout = self.get_timeout()
|
||||
|
||||
# Get LangGraph's stream writer for sending custom data
|
||||
writer = get_stream_writer()
|
||||
|
||||
|
||||
3
api/app/core/workflow/nodes/breaker/__init__.py
Normal file
3
api/app/core/workflow/nodes/breaker/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from app.core.workflow.nodes.breaker.node import BreakNode
|
||||
|
||||
__all__ = ["BreakNode"]
|
||||
32
api/app/core/workflow/nodes/breaker/node.py
Normal file
32
api/app/core/workflow/nodes/breaker/node.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes import BaseNode, WorkflowState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BreakNode(BaseNode):
    """
    Loop-control node that asks the enclosing cycle to stop.

    Its only effect is to clear the ``looping`` flag on the workflow state it
    is given; the loop and iteration runtimes consult that flag to decide
    whether another pass should be started.
    """

    async def execute(self, state: WorkflowState) -> Any:
        """
        Clear the loop flag on the given state.

        Args:
            state: Current workflow state; its ``looping`` entry is
                overwritten in place.

        Returns:
            Nothing; the method has no explicit return value.
        """
        # WorkflowState is a plain dict at runtime, so update() is an in-place write.
        state.update(looping=False)
        logger.info(f"run break node, looping={state['looping']}")
|
||||
@@ -22,6 +22,7 @@ from app.core.workflow.nodes.variable_aggregator.config import VariableAggregato
|
||||
from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig
|
||||
from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig
|
||||
|
||||
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
|
||||
__all__ = [
|
||||
# 基础类
|
||||
"BaseNodeConfig",
|
||||
@@ -41,5 +42,7 @@ __all__ = [
|
||||
"JinjaRenderNodeConfig",
|
||||
"VariableAggregatorNodeConfig",
|
||||
"ParameterExtractorNodeConfig",
|
||||
"LoopNodeConfig",
|
||||
"IterationNodeConfig",
|
||||
"QuestionClassifierNodeConfig"
|
||||
]
|
||||
|
||||
4
api/app/core/workflow/nodes/cycle_graph/__init__.py
Normal file
4
api/app/core/workflow/nodes/cycle_graph/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
|
||||
from app.core.workflow.nodes.cycle_graph.node import CycleGraphNode
|
||||
|
||||
__all__ = ['CycleGraphNode', 'LoopNodeConfig', 'IterationNodeConfig']
|
||||
96
api/app/core/workflow/nodes/cycle_graph/config.py
Normal file
96
api/app/core/workflow/nodes/cycle_graph/config.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from pydantic import Field, BaseModel
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableType
|
||||
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator
|
||||
|
||||
|
||||
class CycleVariable(BaseNodeConfig):
    # Declaration of a single variable owned by a loop node.
    # Documented with comments (not a docstring) so the generated model
    # schema description stays exactly as before.

    # Required: key under which the variable is stored for the loop node.
    name: str = Field(default=..., description="Name of the loop variable")

    # Required: declared data type of the variable.
    type: VariableType = Field(default=..., description="Data type of the loop variable")

    # Required: value given as a string.
    value: str = Field(default=..., description="Initial or current value of the loop variable")
|
||||
|
||||
|
||||
class ConditionDetail(BaseModel):
    # One comparison of the form ``left <operator> right``.
    # Documented with comments (not a docstring) so the generated model
    # schema description stays exactly as before.

    # Required: how the two operands are compared.
    comparison_operator: ComparisonOperator = Field(
        default=...,
        description="Operator used to compare the left and right operands",
    )

    # Required: left operand, given as a string.
    left: str = Field(default=..., description="Left-hand operand of the comparison expression")

    # Required: right operand, given as a string.
    right: str = Field(default=..., description="Right-hand operand of the comparison expression")
|
||||
|
||||
|
||||
class ConditionsConfig(BaseModel):
    """Configuration for loop condition evaluation"""

    # How the individual comparisons are combined; defaults to the raw value
    # of LogicOperator.AND.
    logical_operator: LogicOperator = Field(
        LogicOperator.AND.value,
        description="Logical operator used to combine multiple condition expressions",
    )

    # Required: the comparisons that make up the condition.
    expressions: list[ConditionDetail] = Field(
        default=...,
        description="Collection of condition expressions to be evaluated",
    )
|
||||
|
||||
|
||||
class LoopNodeConfig(BaseNodeConfig):
    # Settings of a condition-driven loop node.
    # Documented with comments (not a docstring) so the generated model
    # schema description stays exactly as before.

    # Condition that controls loop execution.
    # The default has to be a real ConditionsConfig: pydantic does not
    # validate default values unless explicitly asked to, so a
    # ``default_factory=list`` would hand consumers a bare list that has
    # neither ``expressions`` nor ``logical_operator``.
    condition: ConditionsConfig = Field(
        default_factory=lambda: ConditionsConfig(expressions=[]),
        description="Conditional configuration that controls loop execution"
    )

    # Variables owned by the loop; a fresh empty list per instance by default.
    cycle_vars: list[CycleVariable] = Field(
        default_factory=list,
        description="List of variables used and updated during the loop"
    )

    # Upper bound on the number of passes through the loop body.
    max_loop: int = Field(
        default=10,
        description="Maximum number of loop iterations"
    )
|
||||
|
||||
|
||||
class IterationNodeConfig(BaseNodeConfig):
    # Settings of an iteration node (one sub-graph run per input element).
    # Documented with comments (not a docstring) so the generated model
    # schema description stays exactly as before.

    # Required: reference to the value that is iterated over.
    input: str = Field(default=..., description="Input of the loop iteration")

    # Run iterations concurrently instead of one after another.
    parallel: bool = Field(False, description="Whether to execute loop iterations in parallel")

    # Batch size used when ``parallel`` is enabled.
    parallel_count: int = Field(4, description="Number of iterations to run in parallel")

    # When enabled, list outputs are spliced into the result instead of nested.
    flatten: bool = Field(False, description="Whether to flatten the output list from iterations")

    # Required: reference to the value collected from each iteration.
    output: str = Field(default=..., description="Output of the loop iteration")
|
||||
|
||||
|
||||
146
api/app/core/workflow/nodes/cycle_graph/iteration.py
Normal file
146
api/app/core/workflow/nodes/cycle_graph/iteration.py
Normal file
@@ -0,0 +1,146 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.nodes.cycle_graph import IterationNodeConfig
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
|
||||
class IterationRuntime:
    """
    Runtime executor for iteration nodes in a workflow.

    Runs the compiled sub-graph once per element of a list variable. Iterations
    run either sequentially or in fixed-size concurrent batches. Each
    iteration's output is gathered into ``self.result`` in input order,
    optionally flattened, and iteration stops early once a sub-graph run
    reports ``looping`` as false.
    """

    def __init__(
            self,
            graph: CompiledStateGraph,
            node_id: str,
            config: dict[str, Any],
            state: WorkflowState,
    ):
        """
        Initialize the iteration runtime.

        Args:
            graph: Compiled workflow graph capable of async invocation.
            node_id: Unique identifier of the iteration node.
            config: Dictionary containing iteration node configuration.
            state: Current workflow state at the point of iteration.
        """
        self.graph = graph
        self.state = state
        self.node_id = node_id
        self.typed_config = IterationNodeConfig(**config)
        # Cleared as soon as any iteration comes back with looping == False.
        self.looping = True

        # Selector of the per-iteration output; filled in by run().
        self.output_value = None
        self.result: list = []

    def _init_iteration_state(self, item, idx):
        """
        Build the isolated state handed to one iteration.

        Args:
            item: Current element from the input array for this iteration.
            idx: Index of the element in the input array.

        Returns:
            A deep copy of the workflow state in which this node's entries in
            ``runtime_vars`` and ``node_outputs`` hold ``item`` and ``index``
            and ``looping`` is True.
        """
        loopstate = WorkflowState(
            **copy.deepcopy(self.state)
        )
        loopstate["runtime_vars"][self.node_id] = {"item": item, "index": idx}
        loopstate["node_outputs"][self.node_id] = {"item": item, "index": idx}
        loopstate["looping"] = True
        return loopstate

    def _batch_size(self) -> int:
        """Return the parallel batch size, never smaller than 1.

        A configured ``parallel_count`` of zero or less would otherwise create
        empty batches and never advance the batch loop in ``run``.
        """
        return max(1, self.typed_config.parallel_count)

    def _collect(self, output) -> None:
        """Store one iteration output, splicing lists in when flatten is on."""
        if isinstance(output, list) and self.typed_config.flatten:
            self.result.extend(output)
        else:
            self.result.append(output)

    async def _invoke_iteration(self, item, idx):
        """
        Run the sub-graph for one element and return its selected output.

        Also records a stop request: when the final state of the run has
        ``looping`` false, ``self.looping`` is cleared.
        """
        result = await self.graph.ainvoke(self._init_iteration_state(item, idx))
        output = VariablePool(result).get(self.output_value)
        if not result["looping"]:
            self.looping = False
        return output

    async def run_task(self, item, idx):
        """
        Execute a single iteration asynchronously and store its output.

        Args:
            item: The input element for this iteration.
            idx: The index of this iteration.
        """
        self._collect(await self._invoke_iteration(item, idx))

    def _create_iteration_tasks(self, array_obj, idx):
        """
        Create the coroutines for one batch of iterations.

        Args:
            array_obj: The input array to iterate over.
            idx: Starting index for this batch of iterations.

        Returns:
            List of coroutines, one per element of the batch in input order;
            each resolves to that iteration's output.
        """
        batch = array_obj[idx: idx + self._batch_size()]
        return [
            self._invoke_iteration(item, idx + offset)
            for offset, item in enumerate(batch)
        ]

    async def run(self):
        """
        Execute the iteration over the input array according to configuration.

        Returns:
            A list of outputs from all executed iterations in input order,
            optionally flattened.

        Raises:
            RuntimeError: If the input variable is not a list.
        """
        # Strip the surrounding "{{ ... }}" so only the bare selector remains.
        pattern = r"\{\{\s*(.*?)\s*\}\}"
        input_expression = re.sub(pattern, r"\1", self.typed_config.input).strip()
        self.output_value = re.sub(pattern, r"\1", self.typed_config.output).strip()

        array_obj = VariablePool(self.state).get(input_expression)
        if not isinstance(array_obj, list):
            raise RuntimeError("Cannot iterate over a non-list variable")

        if self.typed_config.parallel:
            # Execute iterations in parallel batches. A stop request only
            # takes effect between batches; a started batch always finishes.
            idx = 0
            while idx < len(array_obj) and self.looping:
                tasks = self._create_iteration_tasks(array_obj, idx)
                idx += self._batch_size()
                # gather() returns results in the order of its arguments, so
                # outputs are stored in input order regardless of which
                # iteration finishes first.
                for output in await asyncio.gather(*tasks):
                    self._collect(output)
        else:
            # Execute iterations sequentially, stopping once a run asks to.
            for idx, item in enumerate(array_obj):
                if not self.looping:
                    break
                self._collect(await self._invoke_iteration(item, idx))
        return self.result
|
||||
124
api/app/core/workflow/nodes/cycle_graph/loop.py
Normal file
124
api/app/core/workflow/nodes/cycle_graph/loop.py
Normal file
@@ -0,0 +1,124 @@
|
||||
from typing import Any
|
||||
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.core.workflow.expression_evaluator import evaluate_condition, evaluate_expression
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig
|
||||
from app.core.workflow.nodes.operators import ConditionExpressionBuilder
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
|
||||
class LoopRuntime:
    """
    Runtime executor for condition-driven loop nodes.

    Seeds the loop variables, builds the loop condition from the node
    configuration and keeps invoking the compiled sub-graph while the
    condition holds, the ``looping`` flag is set and the iteration budget
    (``max_loop``) is not used up.
    """

    def __init__(
            self,
            graph: CompiledStateGraph,
            node_id: str,
            config: dict[str, Any],
            state: WorkflowState,
    ):
        """
        Initialize the loop runtime.

        Args:
            graph: Compiled workflow graph capable of async invocation.
            node_id: Unique identifier of the loop node.
            config: Dictionary containing loop node configuration.
            state: Current workflow state at the point of loop execution.
        """
        self.typed_config = LoopNodeConfig(**config)
        self.node_id = node_id
        self.graph = graph
        self.state = state

    def _init_loop_state(self):
        """
        Prepare the workflow state used for the loop.

        Evaluates every configured loop variable and writes the values into
        this node's entry of both ``runtime_vars`` and ``node_outputs`` of
        ``self.state`` (in place), then returns a top-level copy of that state
        with ``looping`` set to True.

        Returns:
            The state object handed to the loop body.
        """
        pool = VariablePool(self.state)

        def evaluate_cycle_vars():
            # Evaluated once per target store so each store gets its own dict.
            return {
                variable.name: evaluate_expression(
                    expression=variable.value,
                    variables=pool.get_all_conversation_vars(),
                    node_outputs=pool.get_all_node_outputs(),
                    system_vars=pool.get_all_system_vars(),
                )
                for variable in self.typed_config.cycle_vars
            }

        # Loop variables
        self.state["runtime_vars"][self.node_id] = evaluate_cycle_vars()
        self.state["node_outputs"][self.node_id] = evaluate_cycle_vars()

        prepared = WorkflowState(**self.state)
        prepared["looping"] = True
        return prepared

    def _get_loop_expression(self):
        """
        Build the expression string that represents the loop condition.

        Each configured comparison is turned into an expression string by
        ConditionExpressionBuilder; several comparisons are joined with the
        configured logical operator.

        Returns:
            A string representing the combined loop condition expression.
        """
        clauses = [
            ConditionExpressionBuilder(
                left=detail.left,
                operator=detail.comparison_operator,
                right=detail.right
            ).build()
            for detail in self.typed_config.condition.expressions
        ]
        if len(clauses) <= 1:
            # Single comparison: used as is (an empty list fails on the lookup here).
            return clauses[0]
        return f' {self.typed_config.condition.logical_operator} '.join(clauses)

    async def run(self):
        """
        Run the loop body until the condition is no longer met, the loop is
        stopped through the ``looping`` flag, or ``max_loop`` passes have run.

        Returns:
            This loop node's entry in ``runtime_vars`` of the loop state.
        """
        loop_state = self._init_loop_state()
        condition = self._get_loop_expression()
        pool = VariablePool(loop_state)
        remaining = self.typed_config.max_loop

        while True:
            # Checked in this order: condition, stop flag, remaining budget.
            keep_going = evaluate_condition(
                expression=condition,
                variables=pool.get_all_conversation_vars(),
                node_outputs=pool.get_all_node_outputs(),
                system_vars=pool.get_all_system_vars(),
            ) and loop_state["looping"] and remaining > 0
            if not keep_going:
                break
            # NOTE(review): the state returned by ainvoke() is not used, so
            # progress is only visible here if the sub-graph changes objects
            # reachable from loop_state in place - confirm that this holds.
            await self.graph.ainvoke(loop_state)
            remaining -= 1
        return loop_state["runtime_vars"][self.node_id]
|
||||
223
api/app/core/workflow/nodes/cycle_graph/node.py
Normal file
223
api/app/core/workflow/nodes/cycle_graph/node.py
Normal file
@@ -0,0 +1,223 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.core.workflow.expression_evaluator import evaluate_condition
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
|
||||
from app.core.workflow.nodes.cycle_graph.iteration import IterationRuntime
|
||||
from app.core.workflow.nodes.cycle_graph.loop import LoopRuntime
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CycleGraphNode(BaseNode):
|
||||
"""
|
||||
Node representing a cycle (loop) subgraph within the workflow.
|
||||
|
||||
This node manages internal loop/iteration nodes, builds a subgraph
|
||||
for execution, handles conditional routing, and executes loop
|
||||
or iteration logic based on node type.
|
||||
"""
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
    """
    Set up the bookkeeping for the cycle sub-graph and build it right away.

    Args:
        node_config: Configuration of this loop/iteration node.
        workflow_config: Configuration of the whole workflow.
    """
    super().__init__(node_config, workflow_config)
    self.typed_config: LoopNodeConfig | IterationNodeConfig | None = None

    # Sub-graph contents; these must exist before build_graph() is called.
    self.cycle_nodes = []  # Nodes belonging to this cycle
    self.cycle_edges = []  # Edges connecting nodes within the cycle
    self.start_node_id = None  # ID of the start node within the cycle
    self.end_node_ids = []  # IDs of end nodes within the cycle

    self.graph: StateGraph | CompiledStateGraph | None = None
    self.build_graph()
    # Assigned after the graph has been built, as before.
    self.iteration_flag = True
|
||||
|
||||
def pure_cycle_graph(self) -> tuple[list, list]:
    """
    Split this cycle's nodes and internal edges out of the workflow config.

    Nodes whose ``cycle`` field equals this node's ID, and the edges that run
    between two such nodes, are removed from ``self.workflow_config`` (its
    ``nodes`` and ``edges`` entries are replaced) and returned to the caller.

    Raises:
        ValueError: If an edge connects a node inside the cycle with a node
            outside of it. The workflow config is left untouched in that case.

    Returns:
        Tuple containing:
            - cycle_nodes: List of removed nodes
            - cycle_edges: List of removed edges
    """
    nodes = self.workflow_config.get("nodes", [])
    edges = self.workflow_config.get("edges", [])

    # Select all nodes that belong to the current cycle
    cycle_nodes = [node for node in nodes if node.get("cycle") == self.node_id]
    cycle_node_ids = {node.get("id") for node in cycle_nodes}

    cycle_edges = []
    remain_edges = []

    for edge in edges:
        source = edge.get("source")
        target = edge.get("target")
        source_in = source in cycle_node_ids
        target_in = target in cycle_node_ids

        # Exactly one endpoint inside the cycle: the edge crosses the boundary.
        if source_in != target_in:
            # The endpoints are bound to locals first: reusing the enclosing
            # quote character inside an f-string replacement field is a
            # syntax error before Python 3.12.
            raise ValueError(
                f"循环节点与外部节点存在连接,source: {source}, target:{target}"
            )

        # Both flags are equal here, so one of them decides the bucket.
        if source_in:
            cycle_edges.append(edge)
        else:
            remain_edges.append(edge)

    # Update workflow_config by removing cycle nodes and internal edges
    self.workflow_config["nodes"] = [
        node for node in nodes if node.get("cycle") != self.node_id
    ]
    self.workflow_config["edges"] = remain_edges

    return cycle_nodes, cycle_edges
|
||||
|
||||
def create_node(self):
    """
    Create a node instance for every node of the cycle and register it on
    the internal graph.

    The cycle-start node is only remembered, not instantiated. End nodes are
    remembered and instantiated. For if/else and HTTP-request nodes, the
    outgoing edges additionally receive a ``condition`` string derived from
    the edge label.
    """
    # Imported here rather than at module level, as in the original code.
    from app.core.workflow.nodes import NodeFactory

    def as_graph_callable(instance):
        # Bind the instance through a parameter so every registered callable
        # keeps its own node.
        async def node_func(state: WorkflowState):
            return await instance.run(state)

        return node_func

    for node in self.cycle_nodes:
        node_type = node.get("type")
        node_id = node.get("id")

        if node_type == NodeType.CYCLE_START:
            self.start_node_id = node_id
            continue
        if node_type == NodeType.END:
            self.end_node_ids.append(node_id)

        node_instance = NodeFactory.create_node(node, self.workflow_config)

        if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST]:
            # Number of branches, usually matches the number of conditional expressions
            branch_count = len(node_instance.build_conditional_edge_expressions())

            # Edges leaving the current node, in configuration order
            outgoing = [edge for edge in self.cycle_edges if edge.get("source") == node_id]

            for branch_idx in range(branch_count):
                # The generated condition selects the edge whose label equals
                # the node's output, e.g. node.123.output == 'CASE1' takes the
                # edge labelled 'CASE1'.
                branch_edge = outgoing[branch_idx]
                label = branch_edge['label']
                branch_edge['condition'] = f"node.{node_id}.output == '{label}'"

        self.graph.add_node(node_id, as_graph_callable(node_instance))
|
||||
|
||||
def create_edge(self):
    """
    Connect the nodes of the cycle subgraph on the internal graph.

    * Edges leaving the cycle-start node are attached to the graph's ``START``
      sentinel instead of the cycle-start node id.
    * Edges that carry a ``condition`` become conditional edges: the router
      returns the edge target when the condition evaluates truthy and ``END``
      otherwise.
    * All remaining edges are added as plain edges.
    * Every id in ``self.end_node_ids`` is connected to ``END``.
    """
    for index, edge in enumerate(self.cycle_edges):
        source = edge.get("source")
        target = edge.get("target")
        condition = edge.get("condition")

        # Edges that start at the cycle-start node are re-rooted at START.
        if source == self.start_node_id:
            self.graph.add_edge(START, target)
            logger.debug(f"添加边: {source} -> {target}")
            continue

        if condition:
            # Conditional edge. The condition and target are bound as default
            # arguments so each router keeps the values of its own edge.
            def router(state: WorkflowState, cond=condition, tgt=target):
                """Route to the edge target when the condition holds, else to END."""
                if evaluate_condition(
                    cond,
                    state.get("variables", {}),
                    state.get("node_outputs", {}),
                    {
                        "execution_id": state.get("execution_id"),
                        "workspace_id": state.get("workspace_id"),
                        "user_id": state.get("user_id")
                    }
                ):
                    return tgt
                return END  # condition not satisfied: finish this path

            # LangGraph keys the conditional branches of a source node by the
            # callable's name and rejects a duplicate, so several conditional
            # edges leaving the same node need distinct router names.
            router.__name__ = f"router_{index}"

            self.graph.add_conditional_edges(source, router)
            logger.debug(f"添加条件边: {source} -> {target} (condition={condition})")
        else:
            # Plain edge
            self.graph.add_edge(source, target)
            logger.debug(f"添加边: {source} -> {target}")

    # Connect every end node to END
    for end_node_id in self.end_node_ids:
        self.graph.add_edge(end_node_id, END)
        logger.debug(f"添加边: {end_node_id} -> END")
|
||||
|
||||
def build_graph(self):
    """
    Assemble and compile the internal subgraph of this cycle node.

    The order matters: a fresh ``StateGraph`` is created first, the cycle's
    nodes and edges are extracted from the workflow config, node instances are
    registered, edges are wired, and finally ``self.graph`` is replaced by the
    compiled graph.
    """
    self.graph = StateGraph(WorkflowState)

    extracted_nodes, extracted_edges = self.pure_cycle_graph()
    self.cycle_nodes = extracted_nodes
    self.cycle_edges = extracted_edges

    self.create_node()
    self.create_edge()

    compiled = self.graph.compile()
    self.graph = compiled
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
    """
    Run the cycle node against the compiled internal subgraph.

    A LOOP node is driven by ``LoopRuntime`` and an ITERATION node by
    ``IterationRuntime``; both receive the same graph, node id, config and
    state.

    Args:
        state: Current workflow state.

    Returns:
        Whatever the selected runtime's ``run()`` coroutine returns.

    Raises:
        RuntimeError: If the node type is neither LOOP nor ITERATION.
    """
    if self.node_type == NodeType.LOOP:
        runtime_cls = LoopRuntime
    elif self.node_type == NodeType.ITERATION:
        runtime_cls = IterationRuntime
    else:
        raise RuntimeError("未知循环节点类型")

    runtime = runtime_cls(
        graph=self.graph,
        node_id=self.node_id,
        config=self.config,
        state=state,
    )
    return await runtime.run()
|
||||
@@ -1,14 +1,5 @@
|
||||
from enum import StrEnum
|
||||
|
||||
from app.core.workflow.nodes.operators import (
|
||||
StringOperator,
|
||||
NumberOperator,
|
||||
AssignmentOperatorType,
|
||||
BooleanOperator,
|
||||
ArrayOperator,
|
||||
ObjectOperator
|
||||
)
|
||||
|
||||
|
||||
class NodeType(StrEnum):
|
||||
START = "start"
|
||||
@@ -27,6 +18,10 @@ class NodeType(StrEnum):
|
||||
JINJARENDER = "jinja-render"
|
||||
VAR_AGGREGATOR = "var-aggregator"
|
||||
PARAMETER_EXTRACTOR = "parameter-extractor"
|
||||
LOOP = "loop"
|
||||
ITERATION = "iteration"
|
||||
CYCLE_START = "cycle-start"
|
||||
BREAK = "break"
|
||||
|
||||
|
||||
class ComparisonOperator(StrEnum):
|
||||
@@ -62,21 +57,6 @@ class AssignmentOperator(StrEnum):
|
||||
REMOVE_LAST = "remove_last"
|
||||
REMOVE_FIRST = "remove_first"
|
||||
|
||||
@classmethod
def get_operator(cls, obj) -> AssignmentOperatorType:
    """
    Pick the operator class that matches the runtime type of ``obj``.

    Raises:
        TypeError: If ``obj`` is not a str, bool, int, float, list or dict.
    """
    # Order matters: bool is a subclass of int, so it is tested before numbers.
    type_to_operator = (
        (str, StringOperator),
        (bool, BooleanOperator),
        ((int, float), NumberOperator),
        (list, ArrayOperator),
        (dict, ObjectOperator),
    )
    for accepted_type, operator_cls in type_to_operator:
        if isinstance(obj, accepted_type):
            return operator_cls

    raise TypeError(f"Unsupported variable type ({type(obj)})")
|
||||
|
||||
|
||||
class HttpRequestMethod(StrEnum):
|
||||
GET = "GET"
|
||||
|
||||
@@ -30,7 +30,7 @@ class ConditionBranchConfig(BaseModel):
|
||||
description="Logical operator used to combine multiple condition expressions"
|
||||
)
|
||||
|
||||
conditions: list[ConditionDetail] = Field(
|
||||
expressions: list[ConditionDetail] = Field(
|
||||
...,
|
||||
description="List of condition expressions within this branch"
|
||||
)
|
||||
@@ -57,7 +57,7 @@ class IfElseNodeConfig(BaseNodeConfig):
|
||||
# CASE1 / IF Branch
|
||||
{
|
||||
"logical_operator": "and",
|
||||
"conditions": [
|
||||
"expressions": [
|
||||
[
|
||||
{
|
||||
"left": "node.userinput.message",
|
||||
@@ -75,7 +75,7 @@ class IfElseNodeConfig(BaseNodeConfig):
|
||||
# CASE1 / ELIF Branch
|
||||
{
|
||||
"logical_operator": "or",
|
||||
"conditions": [
|
||||
"expressions": [
|
||||
[
|
||||
{
|
||||
"left": "node.userinput.test",
|
||||
|
||||
@@ -2,93 +2,13 @@ import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.enums import ComparisonOperator
|
||||
from app.core.workflow.nodes.if_else import IfElseNodeConfig
|
||||
from app.core.workflow.nodes.if_else.config import ConditionDetail
|
||||
from app.core.workflow.nodes.operators import ConditionExpressionBuilder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConditionExpressionBuilder:
    """
    Render a comparison as a Python boolean expression string.

    Nothing is evaluated here: ``left`` and ``right`` are inserted verbatim
    into the expression text, which is meant to be evaluated later in a
    workflow context.
    """

    def __init__(self, left: str, operator: ComparisonOperator, right: str):
        self.left = left
        self.operator = operator
        self.right = right

    def _infix(self, first, symbol, second):
        # Shared renderer for every "<a> <op> <b>" shaped expression.
        return f"{first} {symbol} {second}"

    def _empty(self):
        return self._infix(self.left, "==", "''")

    def _not_empty(self):
        return self._infix(self.left, "!=", "''")

    def _contains(self):
        return self._infix(self.right, "in", self.left)

    def _not_contains(self):
        return self._infix(self.right, "not in", self.left)

    def _startwith(self):
        return f'{self.left}.startswith({self.right})'

    def _endwith(self):
        return f'{self.left}.endswith({self.right})'

    def _eq(self):
        return self._infix(self.left, "==", self.right)

    def _ne(self):
        return self._infix(self.left, "!=", self.right)

    def _lt(self):
        return self._infix(self.left, "<", self.right)

    def _le(self):
        return self._infix(self.left, "<=", self.right)

    def _gt(self):
        return self._infix(self.left, ">", self.right)

    def _ge(self):
        return self._infix(self.left, ">=", self.right)

    def build(self):
        """
        Return the expression string for the configured operator.

        Raises:
            ValueError: If the operator is not a supported comparison operator.
        """
        renderers = (
            (ComparisonOperator.EMPTY, self._empty),
            (ComparisonOperator.NOT_EMPTY, self._not_empty),
            (ComparisonOperator.CONTAINS, self._contains),
            (ComparisonOperator.NOT_CONTAINS, self._not_contains),
            (ComparisonOperator.START_WITH, self._startwith),
            (ComparisonOperator.END_WITH, self._endwith),
            (ComparisonOperator.EQ, self._eq),
            (ComparisonOperator.NE, self._ne),
            (ComparisonOperator.LT, self._lt),
            (ComparisonOperator.LE, self._le),
            (ComparisonOperator.GT, self._gt),
            (ComparisonOperator.GE, self._ge),
        )
        # Equality comparison (not hashing) keeps the same matching rule as a
        # value-pattern match statement.
        for candidate, render in renderers:
            if self.operator == candidate:
                return render()
        raise ValueError(f"Invalid condition: {self.operator}")
|
||||
|
||||
|
||||
class IfElseNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
@@ -143,7 +63,7 @@ class IfElseNode(BaseNode):
|
||||
|
||||
branch_conditions = [
|
||||
self._build_condition_expression(condition)
|
||||
for condition in case_branch.conditions
|
||||
for condition in case_branch.expressions
|
||||
]
|
||||
if len(branch_conditions) > 1:
|
||||
combined_condition = f' {case_branch.logical_operator} '.join(branch_conditions)
|
||||
|
||||
@@ -1,18 +1,13 @@
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import Field, BaseModel
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
from app.schemas.chunk_schema import RetrieveType
|
||||
|
||||
|
||||
class KnowledgeRetrievalNodeConfig(BaseNodeConfig):
|
||||
query: str = Field(
|
||||
...,
|
||||
description="Search query string"
|
||||
)
|
||||
|
||||
kb_ids: list[UUID] = Field(
|
||||
class KnowledgeBaseConfig(BaseModel):
|
||||
kb_id: UUID = Field(
|
||||
...,
|
||||
description="Knowledge base IDs"
|
||||
)
|
||||
@@ -37,18 +32,42 @@ class KnowledgeRetrievalNodeConfig(BaseNodeConfig):
|
||||
description="Retrieve type"
|
||||
)
|
||||
|
||||
|
||||
class KnowledgeRetrievalNodeConfig(BaseNodeConfig):
    """
    Configuration of the knowledge retrieval node.

    Retrieval settings are given per knowledge base in ``knowledge_bases``;
    ``reranker_id`` and ``reranker_top_k`` configure the reranker.
    """

    # Query text for the search.
    query: str = Field(
        ...,
        description="Search query string"
    )

    # One entry per knowledge base to search, each with its own retrieval settings.
    knowledge_bases: list[KnowledgeBaseConfig] = Field(
        ...,
        description="Knowledge base config"
    )

    # ID of the model configuration to use as the reranker.
    reranker_id: UUID = Field(
        ...,
        description="Reranker model ID"
    )

    reranker_top_k: int = Field(
        default=4,
        description="Reranker top k"
    )

    class Config:
        json_schema_extra = {
            "examples": [
                {
                    "query": "{{sys.message}}",
                    "knowledge_bases": [{
                        "kb_id": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx",
                        "similarity_threshold": 0.2,
                        "vector_similarity_weight": 0.3,
                        "top_k": 4,
                        "retrieve_type": "hybrid"
                    }],
                    "reranker_top_k": 1,
                    "reranker_id": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"
                }
            ]
        }
|
||||
|
||||
@@ -2,14 +2,18 @@ import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.models import RedBearRerank, RedBearModelConfig
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
|
||||
from app.db import get_db_read
|
||||
from app.models import knowledge_model, knowledgeshare_model
|
||||
from app.models import knowledge_model, knowledgeshare_model, ModelType
|
||||
from app.repositories import knowledge_repository
|
||||
from app.schemas.chunk_schema import RetrieveType
|
||||
from app.services import knowledge_service, knowledgeshare_service
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -108,6 +112,44 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
existing_ids.extend(items)
|
||||
return existing_ids
|
||||
|
||||
def get_reranker_model(self) -> RedBearRerank:
    """
    Build the reranker described by ``self.typed_config.reranker_id``.

    Returns:
        A ``RedBearRerank`` built from the first API key entry of the
        configured model.

    Raises:
        BusinessException: If the model configuration does not exist or has
            no API keys.
        RuntimeError: If the configured model's type is not RERANK.
    """
    with get_db_read() as db:
        model_config = ModelConfigService.get_model_by_id(
            db=db,
            model_id=self.typed_config.reranker_id,
        )
        if not model_config:
            raise BusinessException("Configured model does not exist", BizCode.NOT_FOUND)

        key_entries = model_config.api_keys
        if not key_entries or len(key_entries) == 0:
            raise BusinessException("Model configuration is missing API Key", BizCode.INVALID_PARAMETER)

        # Copy every needed value into locals before the session is closed.
        primary_key = key_entries[0]
        model_name = primary_key.model_name
        provider = primary_key.provider
        api_key = primary_key.api_key
        api_base = primary_key.api_base
        model_type = model_config.type

    if model_type != ModelType.RERANK:
        raise RuntimeError("Model is not a reranker")

    rerank_settings = RedBearModelConfig(
        model_name=model_name,
        provider=provider,
        api_key=api_key,
        base_url=api_base,
    )
    return RedBearRerank(rerank_settings)
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
"""
|
||||
Execute the knowledge retrieval workflow node.
|
||||
@@ -131,38 +173,41 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
"""
|
||||
query = self._render_template(self.typed_config.query, state)
|
||||
with get_db_read() as db:
|
||||
existing_ids = self._get_existing_kb_ids(db, self.typed_config.kb_ids)
|
||||
knowledge_bases = self.typed_config.knowledge_bases
|
||||
existing_ids = self._get_existing_kb_ids(db, [kb.kb_id for kb in knowledge_bases])
|
||||
|
||||
if not existing_ids:
|
||||
raise RuntimeError("Knowledge base retrieval failed: the knowledge base does not exist.")
|
||||
|
||||
kb_id = existing_ids[0]
|
||||
uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in existing_ids]
|
||||
indices = ",".join(uuid_strs)
|
||||
rs = []
|
||||
for kb_config in knowledge_bases:
|
||||
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id)
|
||||
if not db_knowledge:
|
||||
raise RuntimeError("The knowledge base does not exist or access is denied.")
|
||||
|
||||
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_id)
|
||||
if not db_knowledge:
|
||||
raise RuntimeError("The knowledge base does not exist or access is denied.")
|
||||
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
|
||||
match self.typed_config.retrieve_type:
|
||||
case RetrieveType.PARTICIPLE:
|
||||
rs = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=self.typed_config.similarity_threshold)
|
||||
case RetrieveType.SEMANTIC:
|
||||
rs = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=self.typed_config.vector_similarity_weight)
|
||||
case _:
|
||||
rs1 = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=self.typed_config.vector_similarity_weight)
|
||||
rs2 = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k,
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
indices = f"Vector_index_{kb_config.kb_id}_Node".lower()
|
||||
match kb_config.retrieve_type:
|
||||
case RetrieveType.PARTICIPLE:
|
||||
rs.extend(vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.similarity_threshold))
|
||||
case RetrieveType.SEMANTIC:
|
||||
rs.extend(vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=self.typed_config.similarity_threshold)
|
||||
# Deduplicate hybrid retrieval results
|
||||
unique_rs = self._deduplicate_docs(rs1, rs2)
|
||||
rs = vector_service.rerank(query=query, docs=unique_rs, top_k=self.typed_config.top_k)
|
||||
return [chunk.model_dump() for chunk in rs]
|
||||
score_threshold=kb_config.vector_similarity_weight))
|
||||
case RetrieveType.HYBRID:
|
||||
rs1 = vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.vector_similarity_weight)
|
||||
rs2 = vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.similarity_threshold)
|
||||
# Deduplicate hybrid retrieval results
|
||||
unique_rs = self._deduplicate_docs(rs1, rs2)
|
||||
vector_service.reranker = self.get_reranker_model()
|
||||
rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
|
||||
case _:
|
||||
raise RuntimeError("Unknown retrieval type")
|
||||
final_rs = vector_service.rerank(query=query, docs=rs, top_k=kb_config.top_k)
|
||||
return [chunk.model_dump() for chunk in final_rs]
|
||||
|
||||
@@ -10,6 +10,7 @@ from typing import Any, Union
|
||||
from app.core.workflow.nodes.agent import AgentNode
|
||||
from app.core.workflow.nodes.assigner import AssignerNode
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.cycle_graph.node import CycleGraphNode
|
||||
from app.core.workflow.nodes.end import EndNode
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
from app.core.workflow.nodes.http_request import HttpRequestNode
|
||||
@@ -22,6 +23,7 @@ from app.core.workflow.nodes.start import StartNode
|
||||
from app.core.workflow.nodes.transform import TransformNode
|
||||
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode
|
||||
from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
|
||||
from app.core.workflow.nodes.breaker import BreakNode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -39,6 +41,9 @@ WorkflowNode = Union[
|
||||
JinjaRenderNode,
|
||||
VariableAggregatorNode,
|
||||
ParameterExtractorNode,
|
||||
CycleGraphNode,
|
||||
BreakNode,
|
||||
ParameterExtractorNode,
|
||||
QuestionClassifierNode
|
||||
]
|
||||
|
||||
@@ -64,6 +69,9 @@ class NodeFactory:
|
||||
NodeType.VAR_AGGREGATOR: VariableAggregatorNode,
|
||||
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode,
|
||||
NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
|
||||
NodeType.LOOP: CycleGraphNode,
|
||||
NodeType.ITERATION: CycleGraphNode,
|
||||
NodeType.BREAK: BreakNode,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from abc import ABC
|
||||
from typing import Union, Type
|
||||
|
||||
from app.core.workflow.nodes.enums import ComparisonOperator
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
|
||||
@@ -136,6 +137,23 @@ class ObjectOperator(OperatorBase):
|
||||
self.pool.set(self.left_selector, dict())
|
||||
|
||||
|
||||
class AssignmentOperatorResolver:
    """Select the assignment operator class that fits a value's runtime type."""

    @classmethod
    def resolve_by_value(cls, value):
        """
        Return the operator class matching the type of ``value``.

        Raises:
            TypeError: If ``value`` is not a str, bool, int, float, list or dict.
        """
        # Order matters: bool is a subclass of int, so it is tested before numbers.
        type_to_operator = (
            (str, StringOperator),
            (bool, BooleanOperator),
            ((int, float), NumberOperator),
            (list, ArrayOperator),
            (dict, ObjectOperator),
        )
        for accepted_type, operator_cls in type_to_operator:
            if isinstance(value, accepted_type):
                return operator_cls

        raise TypeError(f"Unsupported variable type: {type(value)}")
|
||||
|
||||
|
||||
AssignmentOperatorInstance = Union[
|
||||
StringOperator,
|
||||
NumberOperator,
|
||||
@@ -144,3 +162,83 @@ AssignmentOperatorInstance = Union[
|
||||
ObjectOperator
|
||||
]
|
||||
AssignmentOperatorType = Type[AssignmentOperatorInstance]
|
||||
|
||||
|
||||
class ConditionExpressionBuilder:
    """
    Render a comparison as a Python boolean expression string.

    Nothing is evaluated here: ``left`` and ``right`` are inserted verbatim
    into the expression text, which is meant to be evaluated later in a
    workflow context.
    """

    def __init__(self, left: str, operator: ComparisonOperator, right: str):
        self.left = left
        self.operator = operator
        self.right = right

    def _infix(self, first, symbol, second):
        # Shared renderer for every "<a> <op> <b>" shaped expression.
        return f"{first} {symbol} {second}"

    def _empty(self):
        return self._infix(self.left, "==", "''")

    def _not_empty(self):
        return self._infix(self.left, "!=", "''")

    def _contains(self):
        return self._infix(self.right, "in", self.left)

    def _not_contains(self):
        return self._infix(self.right, "not in", self.left)

    def _startswith(self):
        return f'{self.left}.startswith({self.right})'

    def _endswith(self):
        return f'{self.left}.endswith({self.right})'

    def _eq(self):
        return self._infix(self.left, "==", self.right)

    def _ne(self):
        return self._infix(self.left, "!=", self.right)

    def _lt(self):
        return self._infix(self.left, "<", self.right)

    def _le(self):
        return self._infix(self.left, "<=", self.right)

    def _gt(self):
        return self._infix(self.left, ">", self.right)

    def _ge(self):
        return self._infix(self.left, ">=", self.right)

    def build(self):
        """
        Return the expression string for the configured operator.

        Raises:
            ValueError: If the operator is not a supported comparison operator.
        """
        renderers = (
            (ComparisonOperator.EMPTY, self._empty),
            (ComparisonOperator.NOT_EMPTY, self._not_empty),
            (ComparisonOperator.CONTAINS, self._contains),
            (ComparisonOperator.NOT_CONTAINS, self._not_contains),
            (ComparisonOperator.START_WITH, self._startswith),
            (ComparisonOperator.END_WITH, self._endswith),
            (ComparisonOperator.EQ, self._eq),
            (ComparisonOperator.NE, self._ne),
            (ComparisonOperator.LT, self._lt),
            (ComparisonOperator.LE, self._le),
            (ComparisonOperator.GT, self._gt),
            (ComparisonOperator.GE, self._ge),
        )
        # Equality comparison (not hashing) keeps the same matching rule as a
        # value-pattern match statement.
        for candidate, render in renderers:
            if self.operator == candidate:
                return render()
        raise ValueError(f"Invalid condition: {self.operator}")
|
||||
|
||||
@@ -36,6 +36,11 @@ class ParamsConfig(BaseModel):
|
||||
description="Description of the parameter"
|
||||
)
|
||||
|
||||
required: bool = Field(
|
||||
...,
|
||||
description="Whether the parameter is required"
|
||||
)
|
||||
|
||||
|
||||
class ParameterExtractorNodeConfig(BaseNodeConfig):
|
||||
model_id: uuid.UUID = Field(
|
||||
@@ -52,3 +57,8 @@ class ParameterExtractorNodeConfig(BaseNodeConfig):
|
||||
...,
|
||||
description="List of parameters"
|
||||
)
|
||||
|
||||
prompt: str = Field(
|
||||
...,
|
||||
description="User-provided supplemental prompt"
|
||||
)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import logging
|
||||
|
||||
import json_repair
|
||||
from typing import Any
|
||||
@@ -15,6 +16,8 @@ from app.db import get_db_read
|
||||
from app.models import ModelType
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ParameterExtractorNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
@@ -114,7 +117,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
"""
|
||||
field_type = {}
|
||||
for param in self.typed_config.params:
|
||||
field_type[param.name] = param.type
|
||||
field_type[param.name] = f'{param.type}, required:{str(param.required)}'
|
||||
return field_type
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
@@ -154,12 +157,12 @@ class ParameterExtractorNode(BaseNode):
|
||||
|
||||
messages = [
|
||||
("system", system_prompt),
|
||||
("user", self._render_template(self.typed_config.prompt, state)),
|
||||
("user", rendered_user_prompt),
|
||||
]
|
||||
|
||||
model_resp = await llm.ainvoke(messages)
|
||||
result = json_repair.repair_json(model_resp.content)
|
||||
result = json_repair.repair_json(model_resp.content, return_objects=True)
|
||||
logger.info(f"get prarms:{result}")
|
||||
|
||||
return {
|
||||
"output": result,
|
||||
}
|
||||
return result
|
||||
|
||||
@@ -9,43 +9,27 @@ class VariableAggregatorNodeConfig(BaseNodeConfig):
|
||||
description="输出变量是否需要分组",
|
||||
)
|
||||
|
||||
group_names: list[str] = Field(
|
||||
default_factory=lambda: ["output"],
|
||||
description="各个分组的名称"
|
||||
)
|
||||
|
||||
group_variables: list[str] | list[list[str]] = Field(
|
||||
group_variables: list[str] | dict[str, list[str]] = Field(
|
||||
...,
|
||||
description="需要被聚合的变量"
|
||||
)
|
||||
|
||||
@field_validator("group_names", mode="before")
|
||||
@classmethod
|
||||
def group_names_validator(cls, v, info):
|
||||
group_status = info.data.get("group")
|
||||
if not group_status or not v:
|
||||
return ["output"]
|
||||
return v
|
||||
|
||||
@field_validator("group_variables")
|
||||
@classmethod
|
||||
def group_variables_validator(cls, v, info):
|
||||
group_status = info.data.get("group")
|
||||
group_names = info.data.get("group_names")
|
||||
if not isinstance(v, list):
|
||||
raise ValueError("group_variables must be a list")
|
||||
|
||||
if not group_status:
|
||||
for variable in v:
|
||||
if not isinstance(variable, str):
|
||||
raise ValueError("When group=False, group_variables must be a list of strings")
|
||||
else:
|
||||
if len(group_names) != len(v):
|
||||
raise ValueError("group_names and group_variables length mismatch")
|
||||
for group in v:
|
||||
if not isinstance(group, list):
|
||||
if not isinstance(v, dict):
|
||||
raise ValueError("When group=True, group_variables must be a dict")
|
||||
for group_name, group_values in v.items():
|
||||
if not isinstance(group_name, str):
|
||||
raise ValueError("When group=True, each element of group_variables must be a list")
|
||||
for variable in group:
|
||||
for variable in group_values:
|
||||
if not isinstance(variable, str):
|
||||
raise ValueError("Each element inside group_variables lists must be a string")
|
||||
return v
|
||||
|
||||
@@ -59,7 +59,7 @@ class VariableAggregatorNode(BaseNode):
|
||||
# Group mode
|
||||
# --------------------------
|
||||
result = {}
|
||||
for group_name, variables in zip(self.typed_config.group_names, self.typed_config.group_variables):
|
||||
for group_name, variables in self.typed_config.group_variables.items():
|
||||
for variable in variables:
|
||||
var_express = self._get_express(variable)
|
||||
try:
|
||||
|
||||
@@ -198,19 +198,22 @@ class VariablePool:
|
||||
|
||||
namespace = selector[0]
|
||||
|
||||
if namespace != "conv":
|
||||
raise ValueError("只能设置会话变量 (conv.*)")
|
||||
if namespace != "conv" and namespace not in self.state["cycle_nodes"]:
|
||||
raise ValueError("Only conversation or cycle variables can be assigned.")
|
||||
|
||||
key = selector[1]
|
||||
|
||||
# 确保 variables 结构存在
|
||||
if "variables" not in self.state:
|
||||
self.state["variables"] = {"sys": {}, "conv": {}}
|
||||
if "conv" not in self.state["variables"]:
|
||||
self.state["variables"]["conv"] = {}
|
||||
|
||||
# 设置值
|
||||
self.state["variables"]["conv"][key] = value
|
||||
if namespace == "conv":
|
||||
if "conv" not in self.state["variables"]:
|
||||
self.state["variables"]["conv"] = {}
|
||||
|
||||
# 设置值
|
||||
self.state["variables"]["conv"][key] = value
|
||||
elif namespace in self.state["cycle_nodes"]:
|
||||
self.state["runtime_vars"][namespace][key] = value
|
||||
|
||||
logger.debug(f"设置变量: {'.'.join(selector)} = {value}")
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ class NodeDefinition(BaseModel):
|
||||
id: str = Field(..., description="节点唯一标识")
|
||||
type: str = Field(..., description="节点类型: start, end, llm, agent, tool, condition, loop, transform, human, code")
|
||||
name: str | None = Field(None, description="节点名称")
|
||||
cycle: str | None = Field(None, description="父循环节点id")
|
||||
description: str | None = Field(None, description="节点描述")
|
||||
config: dict[str, Any] = Field(default_factory=dict, description="节点配置")
|
||||
position: dict[str, float] | None = Field(None, description="节点位置 {x, y}")
|
||||
|
||||
Reference in New Issue
Block a user