Merge #104 into develop from feature/20251219_myh

feat(workflow): add support for question classifier in graph construction

* feature/20251219_myh: (11 commits)
  feat(workflow): support variable types(TODO)
  fix(workflow): fix passing of loop variable termination condition
  feat(workflow): add support for passing workspace ID
  feat(workflow): support retrieving variables wrapped in {{}} from variable pool
  feat(prompt_opt): support streaming output for prompt optimization API
  feat(workflow): update workflow conditional logic
  feat(workflow): enable front-end to cover pre-rendered non-variable values
  fix(workflow): ensure default values are properly retrieved in HTTP nodes
  refactor(workflow): refactor graph construction to support subgraph building
  Merge branch 'develop' into feature/20251219_myh
  feat(workflow): add support for question classifier in graph construction

Signed-off-by: Eternity <1533512157@qq.com>
Reviewed-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com>
Merged-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com>

CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/104
This commit is contained in:
朱文辉
2026-01-05 11:50:19 +08:00
26 changed files with 956 additions and 576 deletions

View File

@@ -597,7 +597,8 @@ async def draft_run(
async for event in workflow_service.run_stream( async for event in workflow_service.run_stream(
app_id=app_id, app_id=app_id,
payload=payload, payload=payload,
config=config config=config,
workspace_id=current_user.current_workspace_id
): ):
# 提取事件类型和数据 # 提取事件类型和数据
event_type = event.get("event", "message") event_type = event.get("event", "message")
@@ -627,7 +628,7 @@ async def draft_run(
} }
) )
result = await workflow_service.run(app_id, payload,config) result = await workflow_service.run(app_id, payload, config, current_user.current_workspace_id)
logger.debug( logger.debug(
"工作流试运行返回结果", "工作流试运行返回结果",

View File

@@ -1,7 +1,9 @@
import uuid import uuid
import json
from fastapi import APIRouter, Depends, Path from fastapi import APIRouter, Depends, Path
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse
from app.core.logging_config import get_api_logger from app.core.logging_config import get_api_logger
from app.core.response_utils import success from app.core.response_utils import success
@@ -104,35 +106,25 @@ async def get_prompt_opt(
ApiResponse: Contains the optimized prompt, description, and a list of variables. ApiResponse: Contains the optimized prompt, description, and a list of variables.
""" """
service = PromptOptimizerService(db) service = PromptOptimizerService(db)
service.create_message(
tenant_id=current_user.tenant_id,
session_id=session_id,
user_id=current_user.id,
role=RoleType.USER,
content=data.message
)
opt_result = await service.optimize_prompt(
tenant_id=current_user.tenant_id,
model_id=data.model_id,
session_id=session_id,
user_id=current_user.id,
current_prompt=data.current_prompt,
user_require=data.message
)
service.create_message(
tenant_id=current_user.tenant_id,
session_id=session_id,
user_id=current_user.id,
role=RoleType.ASSISTANT,
content=opt_result.desc
)
variables = service.parser_prompt_variables(opt_result.prompt)
result = {
"prompt": opt_result.prompt,
"desc": opt_result.desc,
"variables": variables
}
result_schema = OptimizePromptResponse.model_validate(result)
return success(data=result_schema)
async def event_generator():
async for chunk in service.optimize_prompt(
tenant_id=current_user.tenant_id,
model_id=data.model_id,
session_id=session_id,
user_id=current_user.id,
current_prompt=data.current_prompt,
user_require=data.message
):
# chunk 是 prompt 的增量内容
yield f"event:'message'\ndata: {json.dumps(chunk)}\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)

View File

@@ -10,11 +10,10 @@ import logging
from typing import Any from typing import Any
from langchain_core.messages import HumanMessage from langchain_core.messages import HumanMessage
from langgraph.graph import StateGraph, START, END
from langgraph.graph.state import CompiledStateGraph from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.expression_evaluator import evaluate_condition from app.core.workflow.graph_builder import GraphBuilder
from app.core.workflow.nodes import WorkflowState, NodeFactory from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.enums import NodeType from app.core.workflow.nodes.enums import NodeType
# from app.core.tools.registry import ToolRegistry # from app.core.tools.registry import ToolRegistry
@@ -191,155 +190,10 @@ class WorkflowExecutor:
编译后的状态图 编译后的状态图
""" """
logger.info(f"开始构建工作流图: execution_id={self.execution_id}") logger.info(f"开始构建工作流图: execution_id={self.execution_id}")
graph = GraphBuilder(
# 分析 End 节点的前缀配置和相邻且被引用的节点 self.workflow_config,
end_prefixes, adjacent_and_referenced = self._analyze_end_node_prefixes() if stream else ({}, set()) stream=stream,
).build()
# 1. 创建状态图
workflow = StateGraph(WorkflowState)
# 2. 添加所有节点(包括 start 和 end
start_node_id = None
end_node_ids = []
for node in self.nodes:
node_type = node.get("type")
node_id = node.get("id")
cycle_node = node.get("cycle")
if cycle_node:
# 处于循环子图中的节点由 CycleGraphNode 进行构建处理
continue
# 记录 start 和 end 节点 ID
if node_type == NodeType.START:
start_node_id = node_id
elif node_type == NodeType.END:
end_node_ids.append(node_id)
# 创建节点实例(现在 start 和 end 也会被创建)
node_instance = NodeFactory.create_node(node, self.workflow_config)
if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER]:
# Find all edges whose source is the current node
related_edge = [edge for edge in self.edges if edge.get("source") == node_id]
# Iterate over each branch
for idx in range(len(related_edge)):
# Generate a condition expression for each edge
# Used later to determine which branch to take based on the node's output
# Assumes node output `node.<node_id>.output` matches the edge's label
# For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1'
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
if node_instance:
# 如果是流式模式,且节点有 End 前缀配置,注入配置
if stream and node_id in end_prefixes:
# 将 End 前缀配置注入到节点实例
node_instance._end_node_prefix = end_prefixes[node_id]
logger.info(f"为节点 {node_id} 注入 End 前缀配置")
# 如果是流式模式,标记节点是否与 End 相邻且被引用
if stream:
node_instance._is_adjacent_to_end = node_id in adjacent_and_referenced
if node_id in adjacent_and_referenced:
logger.info(f"节点 {node_id} 标记为与 End 相邻且被引用")
# 包装节点的 run 方法
# 使用函数工厂避免闭包问题
if stream:
# 流式模式:创建 async generator 函数
# LangGraph 会收集所有 yield 的值,最后一个 yield 的字典会被合并到 state
def make_stream_func(inst):
async def node_func(state: WorkflowState):
# logger.debug(f"流式执行节点: {inst.node_id}, 支持流式: {inst.supports_streaming()}")
async for item in inst.run_stream(state):
yield item
return node_func
workflow.add_node(node_id, make_stream_func(node_instance))
else:
# 非流式模式:创建 async function
def make_func(inst):
async def node_func(state: WorkflowState):
return await inst.run(state)
return node_func
workflow.add_node(node_id, make_func(node_instance))
logger.debug(f"添加节点: {node_id} (type={node_type}, stream={stream})")
# 3. 添加边
# 从 START 连接到 start 节点
if start_node_id:
workflow.add_edge(START, start_node_id)
logger.debug(f"添加边: START -> {start_node_id}")
for edge in self.workflow_config.get("edges", []):
source = edge.get("source")
target = edge.get("target")
edge_type = edge.get("type")
condition = edge.get("condition")
# 跳过从 start 节点出发的边(因为已经从 START 连接到 start
if source == start_node_id:
# 但要连接 start 到下一个节点
workflow.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}")
continue
# # 处理到 end 节点的边
# if target in end_node_ids:
# # 连接到 end 节点
# workflow.add_edge(source, target)
# logger.debug(f"添加边: {source} -> {target}")
# continue
# 跳过错误边(在节点内部处理)
if edge_type == "error":
continue
if condition:
# 条件边
def make_router(cond, tgt):
"""Dynamically generate a conditional router function to ensure each branch has a unique name."""
def router_fn(state: WorkflowState):
if evaluate_condition(
cond,
state.get("variables", {}),
state.get("node_outputs", {}),
{
"execution_id": state.get("execution_id"),
"workspace_id": state.get("workspace_id"),
"user_id": state.get("user_id")
}
):
return tgt
return END
# 动态修改函数名,避免重复
router_fn.__name__ = f"router_{tgt}"
return router_fn
router_fn = make_router(condition, target)
workflow.add_conditional_edges(source, router_fn)
logger.debug(f"添加条件边: {source} -> {target} (condition={condition})")
else:
# 普通边
workflow.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}")
# 从 end 节点连接到 END
for end_node_id in end_node_ids:
workflow.add_edge(end_node_id, END)
logger.debug(f"添加边: {end_node_id} -> END")
# 4. 编译图
graph = workflow.compile()
logger.info(f"工作流图构建完成: execution_id={self.execution_id}") logger.info(f"工作流图构建完成: execution_id={self.execution_id}")
return graph return graph

View File

@@ -0,0 +1,253 @@
import logging
import uuid
from typing import Any
from langgraph.graph.state import CompiledStateGraph, StateGraph
from langgraph.graph import START, END
from app.core.workflow.expression_evaluator import evaluate_condition
from app.core.workflow.nodes import WorkflowState, NodeFactory
from app.core.workflow.nodes.enums import NodeType
logger = logging.getLogger(__name__)
# TODO: 子图拆解支持
class GraphBuilder:
def __init__(
self,
workflow_config: dict[str, Any],
stream: bool = False,
subgraph: bool = False,
):
self.workflow_config = workflow_config
self.stream = stream
self.subgraph = subgraph
self.start_node_id = None
self.end_node_ids = []
self.graph: StateGraph | CompiledStateGraph | None = None
@property
def nodes(self) -> list[dict[str, Any]]:
return self.workflow_config.get("nodes", [])
@property
def edges(self) -> list[dict[str, Any]]:
return self.workflow_config.get("edges", [])
def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]:
"""分析 End 节点的前缀配置
检查每个 End 节点的模板,找到直接上游节点的引用,
提取该引用之前的前缀部分。
Returns:
元组:({上游节点ID: End节点前缀}, {与End相邻且被引用的节点ID集合})
"""
import re
prefixes = {}
adjacent_and_referenced = set() # 记录与 End 节点相邻且被引用的节点
# 找到所有 End 节点
end_nodes = [node for node in self.nodes if node.get("type") == "end"]
logger.info(f"[前缀分析] 找到 {len(end_nodes)} 个 End 节点")
for end_node in end_nodes:
end_node_id = end_node.get("id")
output_template = end_node.get("config", {}).get("output")
logger.info(f"[前缀分析] End 节点 {end_node_id} 模板: {output_template}")
if not output_template:
continue
# 查找模板中引用了哪些节点
# 匹配 {{node_id.xxx}} 或 {{ node_id.xxx }} 格式(支持空格)
pattern = r'\{\{\s*([a-zA-Z0-9_-]+)\.[a-zA-Z0-9_]+\s*\}\}'
matches = list(re.finditer(pattern, output_template))
logger.info(f"[前缀分析] 模板中找到 {len(matches)} 个节点引用")
# 找到所有直接连接到 End 节点的上游节点
direct_upstream_nodes = []
for edge in self.edges:
if edge.get("target") == end_node_id:
source_node_id = edge.get("source")
direct_upstream_nodes.append(source_node_id)
logger.info(f"[前缀分析] End 节点的直接上游节点: {direct_upstream_nodes}")
# 找到第一个直接上游节点的引用
for match in matches:
referenced_node_id = match.group(1)
logger.info(f"[前缀分析] 检查引用: {referenced_node_id}")
if referenced_node_id in direct_upstream_nodes:
# 这是直接上游节点的引用,提取前缀
prefix = output_template[:match.start()]
logger.info(f"[前缀分析] ✅ 找到直接上游节点 {referenced_node_id} 的引用,前缀: '{prefix}'")
# 标记这个节点为"相邻且被引用"
adjacent_and_referenced.add(referenced_node_id)
if prefix:
prefixes[referenced_node_id] = prefix
logger.info(f"✅ [前缀分析] 为节点 {referenced_node_id} 配置前缀: '{prefix[:50]}...'")
# 只处理第一个直接上游节点的引用
break
logger.info(f"[前缀分析] 最终配置: {prefixes}")
logger.info(f"[前缀分析] 与 End 相邻且被引用的节点: {adjacent_and_referenced}")
return prefixes, adjacent_and_referenced
def add_nodes(self):
end_prefixes, adjacent_and_referenced = self._analyze_end_node_prefixes() if self.stream else ({}, set())
for node in self.nodes:
node_type = node.get("type")
node_id = node.get("id")
cycle_node = node.get("cycle")
if cycle_node:
# 处于循环子图中的节点由 CycleGraphNode 进行构建处理
if not self.subgraph:
continue
# 记录 start 和 end 节点 ID
if node_type in [NodeType.START, NodeType.CYCLE_START]:
self.start_node_id = node_id
elif node_type == NodeType.END:
self.end_node_ids.append(node_id)
# 创建节点实例(现在 start 和 end 也会被创建)
# NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph
node_instance = NodeFactory.create_node(node, self.workflow_config)
if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER]:
# Find all edges whose source is the current node
related_edge = [edge for edge in self.edges if edge.get("source") == node_id]
# Iterate over each branch
for idx in range(len(related_edge)):
# Generate a condition expression for each edge
# Used later to determine which branch to take based on the node's output
# Assumes node output `node.<node_id>.output` matches the edge's label
# For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1'
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
if node_instance:
# 如果是流式模式,且节点有 End 前缀配置,注入配置
if self.stream and node_id in end_prefixes:
# 将 End 前缀配置注入到节点实例
node_instance._end_node_prefix = end_prefixes[node_id]
logger.info(f"为节点 {node_id} 注入 End 前缀配置")
# 如果是流式模式,标记节点是否与 End 相邻且被引用
if self.stream:
node_instance._is_adjacent_to_end = node_id in adjacent_and_referenced
if node_id in adjacent_and_referenced:
logger.info(f"节点 {node_id} 标记为与 End 相邻且被引用")
# 包装节点的 run 方法
# 使用函数工厂避免闭包问题
if self.stream:
# 流式模式:创建 async generator 函数
# LangGraph 会收集所有 yield 的值,最后一个 yield 的字典会被合并到 state
def make_stream_func(inst):
async def node_func(state: WorkflowState):
# logger.debug(f"流式执行节点: {inst.node_id}, 支持流式: {inst.supports_streaming()}")
async for item in inst.run_stream(state):
yield item
return node_func
self.graph.add_node(node_id, make_stream_func(node_instance))
else:
# 非流式模式:创建 async function
def make_func(inst):
async def node_func(state: WorkflowState):
return await inst.run(state)
return node_func
self.graph.add_node(node_id, make_func(node_instance))
logger.debug(f"添加节点: {node_id} (type={node_type}, stream={self.stream})")
def add_edges(self):
if self.start_node_id:
self.graph.add_edge(START, self.start_node_id)
logger.debug(f"添加边: START -> {self.start_node_id}")
for edge in self.edges:
source = edge.get("source")
target = edge.get("target")
edge_type = edge.get("type")
condition = edge.get("condition")
# 跳过从 start 节点出发的边(因为已经从 START 连接到 start
if source == self.start_node_id:
# 但要连接 start 到下一个节点
self.graph.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}")
continue
# # 处理到 end 节点的边
# if target in end_node_ids:
# # 连接到 end 节点
# workflow.add_edge(source, target)
# logger.debug(f"添加边: {source} -> {target}")
# continue
# 跳过错误边(在节点内部处理)
if edge_type == "error":
continue
if condition:
# 条件边
def make_router(cond, tgt):
"""Dynamically generate a conditional router function to ensure each branch has a unique name."""
def router_fn(state: WorkflowState):
if evaluate_condition(
cond,
state.get("variables", {}),
state.get("runtime_vars", {}),
{
"execution_id": state.get("execution_id"),
"workspace_id": state.get("workspace_id"),
"user_id": state.get("user_id")
}
):
return tgt
return END
# 动态修改函数名,避免重复
router_fn.__name__ = f"router_{uuid.uuid4().hex[:8]}_{tgt}"
return router_fn
router_fn = make_router(condition, target)
self.graph.add_conditional_edges(source, router_fn)
logger.debug(f"添加条件边: {source} -> {target} (condition={condition})")
else:
# 普通边
self.graph.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}")
# 从 end 节点连接到 END
for end_node_id in self.end_node_ids:
self.graph.add_edge(end_node_id, END)
logger.debug(f"添加边: {end_node_id} -> END")
return
def build(self) -> CompiledStateGraph:
self.graph = StateGraph(WorkflowState)
self.add_nodes()
self.add_edges() # 添加边必须在添加节点之后
return self.graph.compile()

View File

@@ -1,3 +1,5 @@
from typing import Any
from pydantic import Field, BaseModel from pydantic import Field, BaseModel
from app.core.workflow.nodes.base_config import BaseNodeConfig from app.core.workflow.nodes.base_config import BaseNodeConfig
@@ -19,7 +21,7 @@ class AssignmentItem(BaseModel):
description="Assignment operator", description="Assignment operator",
) )
value: str | list[str] = Field( value: Any = Field(
..., ...,
description="Value(s) to assign to the variable(s)", description="Value(s) to assign to the variable(s)",
) )

View File

@@ -2,7 +2,6 @@ import logging
import re import re
from typing import Any from typing import Any
from app.core.workflow.expression_evaluator import ExpressionEvaluator
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import AssignmentOperator from app.core.workflow.nodes.enums import AssignmentOperator
@@ -29,6 +28,7 @@ class AssignerNode(BaseNode):
None or the result of the assignment operation. None or the result of the assignment operation.
""" """
# Initialize a variable pool for accessing conversation, node, and system variables # Initialize a variable pool for accessing conversation, node, and system variables
logger.info(f"节点 {self.node_id} 开始执行")
pool = VariablePool(state) pool = VariablePool(state)
for assignment in self.typed_config.assignments: for assignment in self.typed_config.assignments:
# Get the target variable selector (e.g., "conv.test") # Get the target variable selector (e.g., "conv.test")
@@ -45,14 +45,13 @@ class AssignerNode(BaseNode):
# Get the value or expression to assign # Get the value or expression to assign
value = assignment.value value = assignment.value
if isinstance(value, list): pattern = r"\{\{\s*(.*?)\s*\}\}"
value = '.'.join(value) if isinstance(value, str):
value = ExpressionEvaluator.evaluate( expression = re.match(pattern, value)
expression=value, if expression:
variables=pool.get_all_conversation_vars(), expression = expression.group(1)
node_outputs=pool.get_all_node_outputs(), expression = re.sub(pattern, r"\1", expression).strip()
system_vars=pool.get_all_system_vars(), value = self.get_variable(expression, state)
)
# Select the appropriate assignment operator instance based on the target variable type # Select the appropriate assignment operator instance based on the target variable type
operator: AssignmentOperatorInstance = AssignmentOperatorResolver.resolve_by_value( operator: AssignmentOperatorInstance = AssignmentOperatorResolver.resolve_by_value(
@@ -63,6 +62,8 @@ class AssignerNode(BaseNode):
# Execute the configured assignment operation # Execute the configured assignment operation
match assignment.operation: match assignment.operation:
case AssignmentOperator.COVER:
operator.assign()
case AssignmentOperator.ASSIGN: case AssignmentOperator.ASSIGN:
operator.assign() operator.assign()
case AssignmentOperator.CLEAR: case AssignmentOperator.CLEAR:

View File

@@ -4,8 +4,11 @@
""" """
from enum import StrEnum from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, ConfigDict
VARIABLE_PATTERN = r"\{\{\s*(.*?)\s*\}\}"
class VariableType(StrEnum): class VariableType(StrEnum):
@@ -22,6 +25,57 @@ class VariableType(StrEnum):
ARRAY_OBJECT = "array[object]" ARRAY_OBJECT = "array[object]"
class TypedVariable(BaseModel):
"""
TODO: 强类型限制
Strongly typed variable that validates value on assignment.
"""
value: Any = Field(..., description="Variable value")
type: VariableType = Field(..., description="Declared type of the variable")
model_config = ConfigDict(
validate_assignment=True
)
def __setattr__(self, name, value):
if name == "value":
self._validate_value(value)
if name == "type":
raise RuntimeError("Cannot modify variable type at runtime")
super().__setattr__(name, value)
def _validate_value(self, v: Any):
t = self.type
match t:
case VariableType.STRING:
if not isinstance(v, str):
raise TypeError("Variable value does not match type STRING")
case VariableType.BOOLEAN:
if not isinstance(v, bool):
raise TypeError("Variable value does not match type BOOLEAN")
case VariableType.NUMBER:
if not isinstance(v, (int, float)):
raise TypeError("Variable value does not match type NUMBER")
case VariableType.OBJECT:
if not isinstance(v, dict):
raise TypeError("Variable value does not match type OBJECT")
case VariableType.ARRAY_STRING:
if not isinstance(v, list) or not all(isinstance(i, str) for i in v):
raise TypeError("Variable value does not match type ARRAY_STRING")
case VariableType.ARRAY_NUMBER:
if not isinstance(v, list) or not all(isinstance(i, (int, float)) for i in v):
raise TypeError("Variable value does not match type ARRAY_NUMBER")
case VariableType.ARRAY_BOOLEAN:
if not isinstance(v, list) or not all(isinstance(i, bool) for i in v):
raise TypeError("Variable value does not match type ARRAY_BOOLEAN")
case VariableType.ARRAY_OBJECT:
if not isinstance(v, list) or not all(isinstance(i, dict) for i in v):
raise TypeError("Variable value does not match type ARRAY_OBJECT")
case _:
raise TypeError(f"Unknown variable type: {t}")
class VariableDefinition(BaseModel): class VariableDefinition(BaseModel):
"""变量定义 """变量定义

View File

@@ -356,7 +356,8 @@ class BaseNode(ABC):
**final_output, **final_output,
"runtime_vars": { "runtime_vars": {
self.node_id: runtime_var self.node_id: runtime_var
} },
"looping": state["looping"]
} }
# Add streaming buffer for non-End nodes # Add streaming buffer for non-End nodes

View File

@@ -1,7 +1,9 @@
from typing import Any
from pydantic import Field, BaseModel from pydantic import Field, BaseModel
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableType from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableType
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType
class CycleVariable(BaseNodeConfig): class CycleVariable(BaseNodeConfig):
@@ -9,18 +11,25 @@ class CycleVariable(BaseNodeConfig):
..., ...,
description="Name of the loop variable" description="Name of the loop variable"
) )
type: VariableType = Field( type: VariableType = Field(
..., ...,
description="Data type of the loop variable" description="Data type of the loop variable"
) )
value: str = Field(
input_type: ValueInputType = Field(
...,
description="Input type of the loop variable"
)
value: Any = Field(
..., ...,
description="Initial or current value of the loop variable" description="Initial or current value of the loop variable"
) )
class ConditionDetail(BaseModel): class ConditionDetail(BaseModel):
comparison_operator: ComparisonOperator = Field( operator: ComparisonOperator = Field(
..., ...,
description="Operator used to compare the left and right operands" description="Operator used to compare the left and right operands"
) )
@@ -30,11 +39,16 @@ class ConditionDetail(BaseModel):
description="Left-hand operand of the comparison expression" description="Left-hand operand of the comparison expression"
) )
right: str = Field( right: Any = Field(
..., ...,
description="Right-hand operand of the comparison expression" description="Right-hand operand of the comparison expression"
) )
input_type: ValueInputType = Field(
...,
description="Input type of the loop variable"
)
class ConditionsConfig(BaseModel): class ConditionsConfig(BaseModel):
"""Configuration for loop condition evaluation""" """Configuration for loop condition evaluation"""

View File

@@ -3,10 +3,11 @@ from typing import Any
from langgraph.graph.state import CompiledStateGraph from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.expression_evaluator import evaluate_condition, evaluate_expression from app.core.workflow.expression_evaluator import evaluate_expression
from app.core.workflow.nodes import WorkflowState from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig from app.core.workflow.nodes.cycle_graph import LoopNodeConfig
from app.core.workflow.nodes.operators import ConditionExpressionBuilder from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator
from app.core.workflow.nodes.operators import TypeTransformer, ConditionExpressionResolver, CompareOperatorInstance
from app.core.workflow.variable_pool import VariablePool from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -14,11 +15,13 @@ logger = logging.getLogger(__name__)
class LoopRuntime: class LoopRuntime:
""" """
Runtime executor for loop nodes in a workflow. Runtime executor for a loop node in a workflow graph.
Handles iterative execution of a loop node according to defined loop variables This class is responsible for executing a loop node at runtime:
and conditional expressions. Supports maximum loop count and loop control - Initializing loop-scoped variables
through the workflow state. - Evaluating loop continuation conditions
- Repeatedly invoking a compiled sub-graph
- Enforcing maximum loop count and external stop signals
""" """
def __init__( def __init__(
@@ -29,13 +32,13 @@ class LoopRuntime:
state: WorkflowState, state: WorkflowState,
): ):
""" """
Initialize the loop runtime. Initialize the loop runtime executor.
Args: Args:
graph: Compiled workflow graph capable of async invocation. graph: A compiled LangGraph state graph representing the loop body.
node_id: Unique identifier of the loop node. node_id: The unique identifier of the loop node in the workflow.
config: Dictionary containing loop node configuration. config: Raw configuration dictionary for the loop node.
state: Current workflow state at the point of loop execution. state: The current workflow state before entering the loop.
""" """
self.graph = graph self.graph = graph
self.state = state self.state = state
@@ -46,12 +49,15 @@ class LoopRuntime:
""" """
Initialize workflow state for loop execution. Initialize workflow state for loop execution.
- Evaluates initial values of loop variables. This method:
- Stores loop variables in runtime_vars and node_outputs. - Evaluates initial values of loop variables
- Marks the loop as active by setting 'looping' to True. - Stores loop variables into both `runtime_vars` and `node_outputs`
under the current loop node's scope
- Creates a shallow copy of the workflow state
- Marks the loop as active by setting `looping = True`
Returns: Returns:
A copy of the workflow state prepared for the loop execution. WorkflowState: A prepared workflow state used for loop execution.
""" """
pool = VariablePool(self.state) pool = VariablePool(self.state)
# 循环变量 # 循环变量
@@ -61,7 +67,7 @@ class LoopRuntime:
variables=pool.get_all_conversation_vars(), variables=pool.get_all_conversation_vars(),
node_outputs=pool.get_all_node_outputs(), node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars(), system_vars=pool.get_all_system_vars(),
) ) if variable.input_type == ValueInputType.VARIABLE else TypeTransformer.transform(variable.value, variable.type)
for variable in self.typed_config.cycle_vars for variable in self.typed_config.cycle_vars
} }
self.state["node_outputs"][self.node_id] = { self.state["node_outputs"][self.node_id] = {
@@ -70,7 +76,7 @@ class LoopRuntime:
variables=pool.get_all_conversation_vars(), variables=pool.get_all_conversation_vars(),
node_outputs=pool.get_all_node_outputs(), node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars(), system_vars=pool.get_all_system_vars(),
) ) if variable.input_type == ValueInputType.VARIABLE else TypeTransformer.transform(variable.value, variable.type)
for variable in self.typed_config.cycle_vars for variable in self.typed_config.cycle_vars
} }
loopstate = WorkflowState( loopstate = WorkflowState(
@@ -79,49 +85,93 @@ class LoopRuntime:
loopstate["looping"] = True loopstate["looping"] = True
return loopstate return loopstate
def _get_loop_expression(self): @staticmethod
def _evaluate(operator, instance: CompareOperatorInstance) -> Any:
""" """
Build the Python boolean expression for evaluating the loop condition. Dispatch and execute a comparison operator against a resolved
CompareOperatorInstance.
- Converts each condition in the loop configuration into a Python expression string. Args:
- Combines multiple conditions with the configured logical operator (AND/OR). operator: A ComparisonOperator enum value.
instance: A CompareOperatorInstance bound to concrete operands.
Returns: Returns:
A string representing the combined loop condition expression. Any: The evaluation result, typically a boolean.
""" """
branch_conditions = [ match operator:
ConditionExpressionBuilder( case ComparisonOperator.EMPTY:
left=condition.left, return instance.empty()
operator=condition.comparison_operator, case ComparisonOperator.NOT_EMPTY:
right=condition.right return instance.not_empty()
).build() case ComparisonOperator.CONTAINS:
for condition in self.typed_config.condition.expressions return instance.contains()
] case ComparisonOperator.NOT_CONTAINS:
if len(branch_conditions) > 1: return instance.not_contains()
combined_condition = f' {self.typed_config.condition.logical_operator} '.join(branch_conditions) case ComparisonOperator.START_WITH:
else: return instance.startswith()
combined_condition = branch_conditions[0] case ComparisonOperator.END_WITH:
return instance.endswith()
case ComparisonOperator.EQ:
return instance.eq()
case ComparisonOperator.NE:
return instance.ne()
case ComparisonOperator.LT:
return instance.lt()
case ComparisonOperator.LE:
return instance.le()
case ComparisonOperator.GT:
return instance.gt()
case ComparisonOperator.GE:
return instance.ge()
case _:
raise ValueError(f"Invalid condition: {operator}")
return combined_condition def evaluate_conditional(self, state) -> bool:
"""
Evaluate the loop continuation condition at runtime.
This method:
- Resolves all condition expressions against the current workflow state
- Evaluates each comparison expression immediately
- Combines results using the configured logical operator (AND / OR)
Args:
state: The current workflow state during loop execution.
Returns:
bool: True if the loop should continue, False otherwise.
"""
conditions = []
for expression in self.typed_config.condition.expressions:
left_value = VariablePool(state).get(expression.left)
evaluator = ConditionExpressionResolver.resolve_by_value(left_value)(
VariablePool(state),
expression.left,
expression.right,
expression.input_type
)
conditions.append(self._evaluate(expression.operator, evaluator))
if self.typed_config.condition.logical_operator == LogicOperator.AND:
return all(conditions)
else:
return any(conditions)
async def run(self): async def run(self):
""" """
Execute the loop node until the condition is no longer met, the loop is Execute the loop node until termination conditions are met.
manually stopped, or the maximum loop count is reached.
The loop terminates when any of the following occurs:
- The loop condition evaluates to False
- The `looping` flag in the workflow state is set to False
- The maximum loop count is reached
Returns: Returns:
The final runtime variables of this loop node after completion. dict[str, Any]: The final runtime variables of this loop node.
""" """
loopstate = self._init_loop_state() loopstate = self._init_loop_state()
expression = self._get_loop_expression()
loop_variable_pool = VariablePool(loopstate)
loop_time = self.typed_config.max_loop loop_time = self.typed_config.max_loop
while evaluate_condition( while self.evaluate_conditional(loopstate) and loopstate["looping"] and loop_time > 0:
expression=expression,
variables=loop_variable_pool.get_all_conversation_vars(),
node_outputs=loop_variable_pool.get_all_node_outputs(),
system_vars=loop_variable_pool.get_all_system_vars(),
) and loopstate["looping"] and loop_time > 0:
logger.info(f"loop node {self.node_id}: running") logger.info(f"loop node {self.node_id}: running")
await self.graph.ainvoke(loopstate) await self.graph.ainvoke(loopstate)
loop_time -= 1 loop_time -= 1

View File

@@ -1,10 +1,9 @@
import logging import logging
from typing import Any from typing import Any
from langgraph.graph import StateGraph, START, END from langgraph.graph import StateGraph
from langgraph.graph.state import CompiledStateGraph from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.expression_evaluator import evaluate_condition
from app.core.workflow.nodes import WorkflowState from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
@@ -17,12 +16,18 @@ logger = logging.getLogger(__name__)
class CycleGraphNode(BaseNode): class CycleGraphNode(BaseNode):
""" """
Node representing a cycle (loop) subgraph within the workflow. Node representing a cyclic (loop or iteration) subgraph within the workflow.
This node manages internal loop/iteration nodes, builds a subgraph A CycleGraphNode is a structural node that:
for execution, handles conditional routing, and executes loop - Extracts a group of nodes marked as belonging to the same cycle
or iteration logic based on node type. - Builds an isolated internal StateGraph (subgraph)
- Delegates runtime execution to LoopRuntime or IterationRuntime
depending on the node type
This node itself does NOT execute business logic directly.
It acts as a container and execution controller for a subgraph.
""" """
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config)
self.typed_config: LoopNodeConfig | IterationNodeConfig | None = None self.typed_config: LoopNodeConfig | IterationNodeConfig | None = None
@@ -38,16 +43,23 @@ class CycleGraphNode(BaseNode):
def pure_cycle_graph(self) -> tuple[list, list]: def pure_cycle_graph(self) -> tuple[list, list]:
""" """
Extract cycle nodes and internal edges from the workflow configuration, Extract cycle-scoped nodes and internal edges from the workflow configuration.
removing them from the global workflow.
Raises: This method:
ValueError: If cycle nodes are connected to external nodes improperly. - Identifies all nodes marked with `cycle == self.node_id`
- Collects edges that fully connect cycle nodes
- Removes extracted nodes and edges from the global workflow configuration
Safety check:
- Raises an error if a cycle node is connected to an external node
Returns: Returns:
Tuple containing: tuple[list, list]:
- cycle_nodes: List of removed nodes - cycle_nodes: Nodes belonging to this cycle
- cycle_edges: List of removed edges - cycle_edges: Edges connecting nodes within the cycle
Raises:
ValueError: If a cycle node is improperly connected to an external node.
""" """
nodes = self.workflow_config.get("nodes", []) nodes = self.workflow_config.get("nodes", [])
edges = self.workflow_config.get("edges", []) edges = self.workflow_config.get("edges", [])
@@ -83,131 +95,41 @@ class CycleGraphNode(BaseNode):
return cycle_nodes, cycle_edges return cycle_nodes, cycle_edges
def create_node(self):
"""
Instantiate node objects for each node in the cycle subgraph and add them to the graph.
Special handling is applied for conditional nodes to generate
edge conditions based on node outputs.
"""
from app.core.workflow.nodes import NodeFactory
for node in self.cycle_nodes:
node_type = node.get("type")
node_id = node.get("id")
if node_type == NodeType.CYCLE_START:
self.start_node_id = node_id
continue
elif node_type == NodeType.END:
self.end_node_ids.append(node_id)
node_instance = NodeFactory.create_node(node, self.workflow_config)
if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST]:
expressions = node_instance.build_conditional_edge_expressions()
# Number of branches, usually matches the number of conditional expressions
branch_number = len(expressions)
# Find all edges whose source is the current node
related_edge = [edge for edge in self.cycle_edges if edge.get("source") == node_id]
# Iterate over each branch
for idx in range(branch_number):
# Generate a condition expression for each edge
# Used later to determine which branch to take based on the node's output
# Assumes node output `node.<node_id>.output` matches the edge's label
# For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1'
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
def make_func(inst):
async def node_func(state: WorkflowState):
return await inst.run(state)
return node_func
self.graph.add_node(node_id, make_func(node_instance))
def create_edge(self):
"""
Connect nodes within the cycle subgraph by adding edges to the internal graph.
Conditional edges are routed based on evaluated expressions.
Start and end nodes are connected to global START and END nodes.
"""
for edge in self.cycle_edges:
source = edge.get("source")
target = edge.get("target")
edge_type = edge.get("type")
condition = edge.get("condition")
# 跳过从 start 节点出发的边(因为已经从 START 连接到 start
if source == self.start_node_id:
# 但要连接 start 到下一个节点
self.graph.add_edge(START, target)
logger.debug(f"添加边: {source} -> {target}")
continue
if condition:
# 条件边
def router(state: WorkflowState, cond=condition, tgt=target):
"""条件路由函数"""
if evaluate_condition(
cond,
state.get("variables", {}),
state.get("node_outputs", {}),
{
"execution_id": state.get("execution_id"),
"workspace_id": state.get("workspace_id"),
"user_id": state.get("user_id")
}
):
return tgt
return END # 条件不满足,结束
self.graph.add_conditional_edges(source, router)
logger.debug(f"添加条件边: {source} -> {target} (condition={condition})")
else:
# 普通边
self.graph.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}")
# 从 end 节点连接到 END
for end_node_id in self.end_node_ids:
self.graph.add_edge(end_node_id, END)
logger.debug(f"添加边: {end_node_id} -> END")
def build_graph(self): def build_graph(self):
""" """
Build the internal subgraph for the cycle node. Build and compile the internal subgraph for this cycle node.
Steps: Steps:
1. Extract cycle nodes and edges. 1. Extract cycle nodes and internal edges from the workflow
2. Create node instances and add them to the graph. 2. Construct a StateGraph using GraphBuilder in subgraph mode
3. Connect edges and conditional routes. 3. Compile the graph for runtime execution
4. Compile the graph for execution.
""" """
self.graph = StateGraph(WorkflowState) from app.core.workflow.graph_builder import GraphBuilder
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph() self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
self.create_node() self.graph = GraphBuilder(
self.create_edge() {
self.graph = self.graph.compile() "nodes": self.cycle_nodes,
"edges": self.cycle_edges,
},
subgraph=True
).build()
async def execute(self, state: WorkflowState) -> Any: async def execute(self, state: WorkflowState) -> Any:
""" """
Execute the cycle node at runtime. Execute the cycle node at runtime.
Depending on the node type, runs either a loop (LoopRuntime) Based on the node type:
or an iteration (IterationRuntime) over the internal subgraph. - LOOP: Executes LoopRuntime, repeatedly invoking the subgraph
- ITERATION: Executes IterationRuntime, iterating over a collection
Args: Args:
state: Current workflow state. state: The current workflow state when entering the cycle node.
Returns: Returns:
Runtime result of the cycle, typically the final loop/iteration variables. Any: The runtime result produced by the loop or iteration executor.
Raises: Raises:
RuntimeError: If node type is unrecognized. RuntimeError: If the node type is unsupported.
""" """
if self.node_type == NodeType.LOOP: if self.node_type == NodeType.LOOP:
return await LoopRuntime( return await LoopRuntime(

View File

@@ -61,7 +61,7 @@ class EndNode(BaseNode):
引用的节点 ID 列表 引用的节点 ID 列表
""" """
# 匹配 {{node_id.xxx}} 格式 # 匹配 {{node_id.xxx}} 格式
pattern = r'\{\{([a-zA-Z0-9_-]+)\.[a-zA-Z0-9_]+\}\}' pattern = r'\{\{([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\}\}'
matches = re.findall(pattern, template) matches = re.findall(pattern, template)
return list(set(matches)) # 去重 return list(set(matches)) # 去重

View File

@@ -45,7 +45,8 @@ class LogicOperator(StrEnum):
class AssignmentOperator(StrEnum): class AssignmentOperator(StrEnum):
ASSIGN = "assign" COVER = "cover" # 覆盖
ASSIGN = "assign" # 设置
CLEAR = "clear" CLEAR = "clear"
ADD = "add" # += ADD = "add" # +=
@@ -87,3 +88,8 @@ class HttpErrorHandle(StrEnum):
NONE = "none" NONE = "none"
DEFAULT = "default" DEFAULT = "default"
BRANCH = "branch" BRANCH = "branch"
class ValueInputType(StrEnum):
VARIABLE = "Variable"
CONSTANT = "Constant"

View File

@@ -63,7 +63,7 @@ class HttpContentTypeConfig(BaseModel):
) )
data: list[HttpFormData] | dict | str = Field( data: list[HttpFormData] | dict | str = Field(
..., default="",
description="Data of the HTTP request body; type depends on content_type", description="Data of the HTTP request body; type depends on content_type",
) )
@@ -98,6 +98,10 @@ class HttpTimeOutConfig(BaseModel):
class HttpRetryConfig(BaseModel): class HttpRetryConfig(BaseModel):
enable: bool = Field(
...,
description="Enable/disable retry logic",
)
max_attempts: int = Field( max_attempts: int = Field(
default=1, default=1,
description="Maximum number of retry attempts for failed requests", description="Maximum number of retry attempts for failed requests",
@@ -124,6 +128,11 @@ class HttpErrorDefaultTamplete(BaseModel):
description="Default HTTP headers returned on error", description="Default HTTP headers returned on error",
) )
output: str = Field(
default="SUCCESS",
description="HTTP response body",
)
class HttpErrorHandleConfig(BaseModel): class HttpErrorHandleConfig(BaseModel):
method: HttpErrorHandle = Field( method: HttpErrorHandle = Field(
@@ -131,8 +140,8 @@ class HttpErrorHandleConfig(BaseModel):
description="Error handling strategy: 'none', 'default', or 'branch'", description="Error handling strategy: 'none', 'default', or 'branch'",
) )
default: HttpErrorDefaultTamplete = Field( default: HttpErrorDefaultTamplete | None = Field(
..., default=None,
description="Default response template for error handling", description="Default response template for error handling",
) )

View File

@@ -165,24 +165,6 @@ class HttpRequestNode(BaseNode):
case _: case _:
raise RuntimeError(f"HttpRequest method not supported: {self.typed_config.method}") raise RuntimeError(f"HttpRequest method not supported: {self.typed_config.method}")
def build_conditional_edge_expressions(self):
"""
Build conditional edge expressions for workflow branching.
When the HTTP error handling strategy is set to `BRANCH`,
this node exposes a single conditional output labeled "ERROR".
The workflow engine uses this output to create an explicit
error-handling branch for downstream nodes.
Returns:
list[str]:
- ["ERROR"] if error handling strategy is BRANCH
- An empty list if no conditional branching is required
"""
if self.typed_config.error_handle.method == HttpErrorHandle.BRANCH:
return ["ERROR"]
return []
async def execute(self, state: WorkflowState) -> dict | str: async def execute(self, state: WorkflowState) -> dict | str:
""" """
Execute the HTTP request node. Execute the HTTP request node.

View File

@@ -1,12 +1,13 @@
"""Condition Configuration""" """Condition Configuration"""
from typing import Any
from pydantic import Field, BaseModel, field_validator from pydantic import Field, BaseModel, field_validator
from app.core.workflow.nodes.base_config import BaseNodeConfig from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType
class ConditionDetail(BaseModel): class ConditionDetail(BaseModel):
comparison_operator: ComparisonOperator = Field( operator: ComparisonOperator = Field(
..., ...,
description="Comparison operator used to evaluate the condition" description="Comparison operator used to evaluate the condition"
) )
@@ -16,17 +17,22 @@ class ConditionDetail(BaseModel):
description="Value to compare against" description="Value to compare against"
) )
right: str = Field( right: Any = Field(
..., ...,
description="Value to compare with" description="Value to compare with"
) )
input_type: ValueInputType = Field(
...,
description="Value input type for comparison"
)
class ConditionBranchConfig(BaseModel): class ConditionBranchConfig(BaseModel):
"""Configuration for a conditional branch""" """Configuration for a conditional branch"""
logical_operator: LogicOperator = Field( logical_operator: LogicOperator = Field(
default=LogicOperator.AND.value, default=LogicOperator.AND,
description="Logical operator used to combine multiple condition expressions" description="Logical operator used to combine multiple condition expressions"
) )

View File

@@ -1,10 +1,11 @@
import logging import logging
import re
from typing import Any from typing import Any
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator
from app.core.workflow.nodes.if_else import IfElseNodeConfig from app.core.workflow.nodes.if_else import IfElseNodeConfig
from app.core.workflow.nodes.if_else.config import ConditionDetail from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance
from app.core.workflow.nodes.operators import ConditionExpressionBuilder
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -15,30 +16,36 @@ class IfElseNode(BaseNode):
self.typed_config = IfElseNodeConfig(**self.config) self.typed_config = IfElseNodeConfig(**self.config)
@staticmethod @staticmethod
def _build_condition_expression( def _evaluate(operator, instance: CompareOperatorInstance) -> Any:
condition: ConditionDetail, match operator:
) -> str: case ComparisonOperator.EMPTY:
""" return instance.empty()
Build a single boolean condition expression string. case ComparisonOperator.NOT_EMPTY:
return instance.not_empty()
case ComparisonOperator.CONTAINS:
return instance.contains()
case ComparisonOperator.NOT_CONTAINS:
return instance.not_contains()
case ComparisonOperator.START_WITH:
return instance.startswith()
case ComparisonOperator.END_WITH:
return instance.endswith()
case ComparisonOperator.EQ:
return instance.eq()
case ComparisonOperator.NE:
return instance.ne()
case ComparisonOperator.LT:
return instance.lt()
case ComparisonOperator.LE:
return instance.le()
case ComparisonOperator.GT:
return instance.gt()
case ComparisonOperator.GE:
return instance.ge()
case _:
raise ValueError(f"Invalid condition: {operator}")
This method does NOT evaluate the condition. def evaluate_conditional_edge_expressions(self, state) -> list[bool]:
It only generates a valid Python boolean expression string
(e.g. "x > 10", "'a' in name") that can later be used
in a conditional edge or evaluated by the workflow engine.
Args:
condition (ConditionDetail): Definition of a single comparison condition.
Returns:
str: A Python boolean expression string.
"""
return ConditionExpressionBuilder(
left=condition.left,
operator=condition.comparison_operator,
right=condition.right
).build()
def build_conditional_edge_expressions(self) -> list[str]:
""" """
Build conditional edge expressions for the If-Else node. Build conditional edge expressions for the If-Else node.
@@ -60,19 +67,28 @@ class IfElseNode(BaseNode):
for case_branch in self.typed_config.cases: for case_branch in self.typed_config.cases:
branch_index += 1 branch_index += 1
branch_result = []
branch_conditions = [ for expression in case_branch.expressions:
self._build_condition_expression(condition) pattern = r"\{\{\s*(.*?)\s*\}\}"
for condition in case_branch.expressions left_string = re.sub(pattern, r"\1", expression.left).strip()
] left_value = self.get_variable(left_string, state)
if len(branch_conditions) > 1: evaluator = ConditionExpressionResolver.resolve_by_value(left_value)(
combined_condition = f' {case_branch.logical_operator} '.join(branch_conditions) self.get_variable_pool(state),
expression.left,
expression.right,
expression.input_type
)
branch_result.append(self._evaluate(expression.operator, evaluator))
if case_branch.logical_operator == LogicOperator.AND:
conditions.append(all(branch_result))
else: else:
combined_condition = branch_conditions[0] condition_res = any(branch_result)
conditions.append(combined_condition) conditions.append(condition_res)
if condition_res:
return conditions
# Default fallback branch # Default fallback branch
conditions.append("True") conditions.append(True)
return conditions return conditions
@@ -90,10 +106,10 @@ class IfElseNode(BaseNode):
Returns: Returns:
str: The matched branch identifier, e.g., 'CASE1', 'CASE2', ..., used for node transitions. str: The matched branch identifier, e.g., 'CASE1', 'CASE2', ..., used for node transitions.
""" """
expressions = self.build_conditional_edge_expressions() expressions = self.evaluate_conditional_edge_expressions(state)
# TODO: 变量类型及文本类型解析
for i in range(len(expressions)): for i in range(len(expressions)):
logger.info(expressions[i]) if expressions[i]:
if self._evaluate_condition(expressions[i], state):
logger.info(f"Node {self.node_id}: switched to branch CASE {i + 1}") logger.info(f"Node {self.node_id}: switched to branch CASE {i + 1}")
return f'CASE{i + 1}' return f'CASE{i + 1}'
return f'CASE{len(expressions)}' return f'CASE{len(expressions)}'

View File

@@ -8,6 +8,7 @@ from app.core.workflow.template_renderer import TemplateRenderer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class JinjaRenderNode(BaseNode): class JinjaRenderNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config)

View File

@@ -74,6 +74,7 @@ class NodeFactory:
NodeType.LOOP: CycleGraphNode, NodeType.LOOP: CycleGraphNode,
NodeType.ITERATION: CycleGraphNode, NodeType.ITERATION: CycleGraphNode,
NodeType.BREAK: BreakNode, NodeType.BREAK: BreakNode,
NodeType.CYCLE_START: StartNode,
NodeType.TOOL: ToolNode, NodeType.TOOL: ToolNode,
} }

View File

@@ -1,10 +1,73 @@
import json
import re
from abc import ABC from abc import ABC
from typing import Union, Type from typing import Union, Type, NoReturn
from app.core.workflow.nodes.enums import ComparisonOperator from app.core.workflow.nodes.base_config import VariableType
from app.core.workflow.nodes.enums import ValueInputType
from app.core.workflow.variable_pool import VariablePool from app.core.workflow.variable_pool import VariablePool
class TypeTransformer:
@classmethod
def _fail(cls, value, target) -> NoReturn:
raise TypeError(f"Cannot convert {value!r} to {target} type")
@classmethod
def _json_load(cls, value, target):
try:
return json.loads(value)
except Exception:
cls._fail(value, target)
@classmethod
def transform(cls, variable_literal: str | bool, target_type: VariableType):
match target_type:
case VariableType.STRING:
return str(variable_literal)
case VariableType.NUMBER:
for caster in (int, float):
try:
return caster(variable_literal)
except Exception:
pass
cls._fail(variable_literal, target_type)
case VariableType.BOOLEAN:
if isinstance(variable_literal, bool):
return variable_literal
cls._fail(variable_literal, target_type)
case VariableType.OBJECT:
obj = cls._json_load(variable_literal, target_type)
if isinstance(obj, dict):
return obj
cls._fail(variable_literal, target_type)
case VariableType.ARRAY_BOOLEAN:
return cls._parse_list(variable_literal, bool, target_type)
case VariableType.ARRAY_NUMBER:
return cls._parse_list(variable_literal, (int, float), target_type)
case VariableType.ARRAY_STRING:
return cls._parse_list(variable_literal, str, target_type)
case VariableType.ARRAY_OBJECT:
return cls._parse_list(variable_literal, dict, target_type)
case _:
raise TypeError("Invalid type")
@classmethod
def _parse_list(cls, value, item_type, target):
arr = cls._json_load(value, target)
if isinstance(arr, list) and all(isinstance(i, item_type) for i in arr):
return arr
cls._fail(value, target)
class OperatorBase(ABC): class OperatorBase(ABC):
def __init__(self, pool: VariablePool, left_selector, right): def __init__(self, pool: VariablePool, left_selector, right):
self.pool = pool self.pool = pool
@@ -19,7 +82,9 @@ class OperatorBase(ABC):
raise TypeError(f"The variable to be operated on must be of {self.type_limit} type") raise TypeError(f"The variable to be operated on must be of {self.type_limit} type")
if not no_right and not isinstance(self.right, self.type_limit): if not no_right and not isinstance(self.right, self.type_limit):
raise TypeError(f"The value assigned to the string variable must also be of {self.type_limit} type") raise TypeError(
f"The value assigned must be of {self.type_limit} type"
)
class StringOperator(OperatorBase): class StringOperator(OperatorBase):
@@ -126,7 +191,7 @@ class ArrayOperator(OperatorBase):
class ObjectOperator(OperatorBase): class ObjectOperator(OperatorBase):
def __init__(self, pool: VariablePool, left_selector, right): def __init__(self, pool: VariablePool, left_selector, right):
super().__init__(pool, left_selector, right) super().__init__(pool, left_selector, right)
self.type_limit = object self.type_limit = dict
def assign(self) -> None: def assign(self) -> None:
self.check() self.check()
@@ -138,20 +203,21 @@ class ObjectOperator(OperatorBase):
class AssignmentOperatorResolver: class AssignmentOperatorResolver:
OPERATOR_MAP = {
str: StringOperator,
bool: BooleanOperator,
int: NumberOperator,
float: NumberOperator,
list: ArrayOperator,
dict: ObjectOperator,
}
@classmethod @classmethod
def resolve_by_value(cls, value): def resolve_by_value(cls, value):
if isinstance(value, str): for t, op in cls.OPERATOR_MAP.items():
return StringOperator if isinstance(value, t):
elif isinstance(value, bool): return op
return BooleanOperator raise TypeError(f"Unsupported variable type: {type(value)}")
elif isinstance(value, (int, float)):
return NumberOperator
elif isinstance(value, list):
return ArrayOperator
elif isinstance(value, dict):
return ObjectOperator
else:
raise TypeError(f"Unsupported variable type: {type(value)}")
AssignmentOperatorInstance = Union[ AssignmentOperatorInstance = Union[
@@ -164,81 +230,186 @@ AssignmentOperatorInstance = Union[
AssignmentOperatorType = Type[AssignmentOperatorInstance] AssignmentOperatorType = Type[AssignmentOperatorInstance]
class ConditionExpressionBuilder: class ConditionBase(ABC):
""" type_limit: type[str, int, dict, list] = None
Build a Python boolean expression string based on a comparison operator.
This class does not evaluate the expression. def __init__(
It only generates a valid Python expression string self,
that can be evaluated later in a workflow context. pool: VariablePool,
""" left_selector,
right_selector: str,
input_type: ValueInputType
):
self.pool = pool
self.left_selector = left_selector
self.right_selector = right_selector
self.input_type = input_type
def __init__(self, left: str, operator: ComparisonOperator, right: str): self.left_value = self.pool.get(self.left_selector)
self.left = left self.right_value = self.resolve_right_literal_value()
self.operator = operator
self.right = right
def _empty(self): self.type_limit = getattr(self, "type_limit", None)
return f"{self.left} == ''"
def _not_empty(self): def resolve_right_literal_value(self):
return f"{self.left} != ''" if self.input_type == ValueInputType.VARIABLE:
pattern = r"\{\{\s*(.*?)\s*\}\}"
right_expression = re.sub(pattern, r"\1", self.right_selector).strip()
return self.pool.get(right_expression)
elif self.input_type == ValueInputType.CONSTANT:
return self.right_selector
raise RuntimeError("Unsupported variable type")
def _contains(self): def check(self, no_right=False):
return f"{self.right} in {self.left}" left = self.pool.get(self.left_selector.variable_selector)
if not isinstance(left, self.type_limit):
raise TypeError(f"The variable to be compared on must be of {self.type_limit} type")
if not no_right:
right = self.resolve_right_literal_value()
if not isinstance(right, self.type_limit):
raise TypeError(
f"The compared variable must be of {self.type_limit} type"
)
def _not_contains(self):
return f"{self.right} not in {self.left}"
def _startswith(self): class StringComparisonOperator(ConditionBase):
return f'{self.left}.startswith({self.right})' type_limit = str
def _endswith(self): def __init__(self, pool: VariablePool, left_selector, right_selector, input_type):
return f'{self.left}.endswith({self.right})' super().__init__(pool, left_selector, right_selector, input_type)
def _eq(self): def empty(self):
return f"{self.left} == {self.right}" self.check(no_right=True)
return self.left_value == ""
def _ne(self): def not_empty(self):
return f"{self.left} != {self.right}" return not self.empty()
def _lt(self): def contains(self):
return f"{self.left} < {self.right}" self.check()
return self.right_value in self.left_value
def _le(self): def not_contains(self):
return f"{self.left} <= {self.right}" return self.right_value not in self.left_value
def _gt(self): def startswith(self):
return f"{self.left} > {self.right}" self.check()
return self.left_value.startswith(self.right_value)
def _ge(self): def endswith(self):
return f"{self.left} >= {self.right}" return self.left_value.endswith(self.right_value)
def build(self): def eq(self):
match self.operator: return self.left_value == self.right_value
case ComparisonOperator.EMPTY:
return self._empty() def ne(self):
case ComparisonOperator.NOT_EMPTY: return self.left_value != self.right_value
return self._not_empty()
case ComparisonOperator.CONTAINS:
return self._contains() class NumberComparisonOperator(ConditionBase):
case ComparisonOperator.NOT_CONTAINS: type_limit = (int, float)
return self._not_contains()
case ComparisonOperator.START_WITH: def __init__(self, pool: VariablePool, left_selector, right_selector, input_type):
return self._startswith() super().__init__(pool, left_selector, right_selector, input_type)
case ComparisonOperator.END_WITH:
return self._endswith() def empty(self):
case ComparisonOperator.EQ: return self.left_value == 0
return self._eq()
case ComparisonOperator.NE: def not_empty(self):
return self._ne() return self.left_value != 0
case ComparisonOperator.LT:
return self._lt() def eq(self):
case ComparisonOperator.LE: return self.left_value == self.right_value
return self._le()
case ComparisonOperator.GT: def ne(self):
return self._gt() return self.left_value != self.right_value
case ComparisonOperator.GE:
return self._ge() def lt(self):
case _: return self.left_value < self.right_value
raise ValueError(f"Invalid condition: {self.operator}")
def le(self):
return self.left_value <= self.right_value
def gt(self):
return self.left_value > self.right_value
def ge(self):
return self.left_value >= self.right_value
class BooleanComparisonOperator(ConditionBase):
type_limit = bool
def __init__(self, pool: VariablePool, left_selector, right_selector, input_type):
super().__init__(pool, left_selector, right_selector, input_type)
def eq(self):
return self.left_value == self.right_value
def ne(self):
return self.left_value != self.right_value
class ObjectComparisonOperator(ConditionBase):
type_limit = dict
def __init__(self, pool: VariablePool, left_selector, right_selector, input_type):
super().__init__(pool, left_selector, right_selector, input_type)
def eq(self):
return self.left_value == self.right_value
def ne(self):
return self.left_value != self.right_value
def empty(self):
return not self.left_value
def not_empty(self):
return bool(self.left_value)
class ArrayComparisonOperator(ConditionBase):
type_limit = list
def __init__(self, pool: VariablePool, left_selector, right_selector, input_type):
super().__init__(pool, left_selector, right_selector, input_type)
def empty(self):
return not self.left_value
def not_empty(self):
return bool(self.left_value)
def contains(self):
return self.right_value in self.left_value
def not_contains(self):
return self.right_value not in self.left_value
CompareOperatorInstance = Union[
StringComparisonOperator,
NumberComparisonOperator,
BooleanComparisonOperator,
ArrayComparisonOperator,
ObjectComparisonOperator
]
CompareOperatorType = Type[CompareOperatorInstance]
class ConditionExpressionResolver:
CONDITION_OPERATOR_MAP = {
str: StringComparisonOperator,
bool: BooleanComparisonOperator,
int: NumberComparisonOperator,
float: NumberComparisonOperator,
list: ArrayComparisonOperator,
dict: ObjectComparisonOperator,
}
@classmethod
def resolve_by_value(cls, value) -> CompareOperatorType:
for t, op in cls.CONDITION_OPERATOR_MAP.items():
if isinstance(value, t):
return op
raise TypeError(f"Unsupported variable type: {type(value)}")

View File

@@ -4,11 +4,11 @@
从文件系统加载预定义的工作流模板 从文件系统加载预定义的工作流模板
""" """
import os
import yaml
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import yaml
class TemplateLoader: class TemplateLoader:
"""工作流模板加载器""" """工作流模板加载器"""

View File

@@ -10,6 +10,7 @@
""" """
import logging import logging
import re
from typing import Any, TYPE_CHECKING from typing import Any, TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -114,7 +115,9 @@ class VariablePool:
""" """
# 转换为 VariableSelector # 转换为 VariableSelector
if isinstance(selector, str): if isinstance(selector, str):
selector = VariableSelector.from_string(selector).path pattern = r"\{\{\s*(.*?)\s*\}\}"
variable_literal = re.sub(pattern, r"\1", selector).strip()
selector = VariableSelector.from_string(variable_literal).path
if not selector or len(selector) < 1: if not selector or len(selector) < 1:
raise ValueError("变量选择器不能为空") raise ValueError("变量选择器不能为空")

View File

@@ -1,5 +1,6 @@
import re import re
import uuid import uuid
from typing import Any, AsyncGenerator
import json_repair import json_repair
from langchain_core.prompts import ChatPromptTemplate from langchain_core.prompts import ChatPromptTemplate
@@ -123,7 +124,7 @@ class PromptOptimizerService:
user_id: uuid.UUID, user_id: uuid.UUID,
current_prompt: str, current_prompt: str,
user_require: str user_require: str
) -> OptimizePromptResult: ) -> AsyncGenerator[dict[str, str | Any], Any]:
""" """
Optimize a user-provided prompt using a configured prompt optimizer LLM. Optimize a user-provided prompt using a configured prompt optimizer LLM.
@@ -161,6 +162,7 @@ class PromptOptimizerService:
BusinessException: If the LLM response cannot be parsed as valid JSON BusinessException: If the LLM response cannot be parsed as valid JSON
or does not conform to the expected output format. or does not conform to the expected output format.
""" """
self.create_message(tenant_id, session_id, user_id, role=RoleType.USER, content=user_require)
model_config = self.get_model_config(tenant_id, model_id) model_config = self.get_model_config(tenant_id, model_id)
session_history = self.get_session_message_history(session_id=session_id, user_id=user_id) session_history = self.get_session_message_history(session_id=session_id, user_id=user_id)
@@ -202,17 +204,54 @@ class PromptOptimizerService:
messages.extend(session_history[:-1]) # last message is current message messages.extend(session_history[:-1]) # last message is current message
messages.extend([(RoleType.USER.value, rendered_user_message)]) messages.extend([(RoleType.USER.value, rendered_user_message)])
logger.info(f"Prompt optimization message: {messages}") logger.info(f"Prompt optimization message: {messages}")
optim_resp = await llm.ainvoke(messages) buffer = ""
logger.info(optim_resp.content) prompt_started = False
optim_result = json_repair.repair_json(optim_resp.content, return_objects=True) prompt_finished = False
prompt = optim_result.get("prompt") idx = 0
desc = optim_result.get("desc")
return OptimizePromptResult( async for chunk in llm.astream(messages):
prompt=prompt, content = getattr(chunk, "content", chunk)
desc=desc if not content:
continue
buffer += content
cache = buffer[:-20]
# 尝试找到 "prompt": " 开始位置
if prompt_finished:
continue
if not prompt_started:
m = re.search(r'"prompt"\s*:\s*"', cache)
if m:
prompt_started = True
prompt_index = m.end()
idx = prompt_index
else:
m = re.search(r'"\s*,\s*\\?n?\s*"desc"\s*:\s*"', buffer)
if m:
prompt_index = m.start()
prompt_finished = True
yield {"type": "delta", "content": buffer[idx:prompt_index]}
else:
yield {"type": "delta", "content": cache[idx:]}
if len(cache) != 0:
idx = len(cache)
# optim_resp = await llm.astream(messages)
logger.info(buffer)
optim_result = json_repair.repair_json(buffer, return_objects=True)
# prompt = optim_result.get("prompt")
desc = optim_result.get("desc")
self.create_message(
tenant_id=tenant_id,
session_id=session_id,
user_id=user_id,
role=RoleType.ASSISTANT,
content=desc
) )
yield {"type": "done", "desc": optim_result.get("desc")}
@staticmethod @staticmethod
def parser_prompt_variables(prompt: str): def parser_prompt_variables(prompt: str):
try: try:

View File

@@ -410,7 +410,8 @@ class WorkflowService:
self, self,
app_id: uuid.UUID, app_id: uuid.UUID,
payload: DraftRunRequest, payload: DraftRunRequest,
config: WorkflowConfig config: WorkflowConfig,
workspace_id: uuid.UUID,
): ):
"""运行工作流 """运行工作流
@@ -484,7 +485,7 @@ class WorkflowService:
workflow_config=workflow_config_dict, workflow_config=workflow_config_dict,
input_data=input_data, input_data=input_data,
execution_id=execution.execution_id, execution_id=execution.execution_id,
workspace_id="", workspace_id=str(workspace_id),
user_id=payload.user_id user_id=payload.user_id
) )
@@ -530,7 +531,8 @@ class WorkflowService:
self, self,
app_id: uuid.UUID, app_id: uuid.UUID,
payload: DraftRunRequest, payload: DraftRunRequest,
config: WorkflowConfig config: WorkflowConfig,
workspace_id: uuid.UUID,
): ):
"""运行工作流(流式) """运行工作流(流式)
@@ -603,7 +605,7 @@ class WorkflowService:
workflow_config=workflow_config_dict, workflow_config=workflow_config_dict,
input_data=input_data, input_data=input_data,
execution_id=execution.execution_id, execution_id=execution.execution_id,
workspace_id="", workspace_id=str(workspace_id),
user_id=payload.user_id user_id=payload.user_id
): ):
# 直接转发 executor 的事件(已经是正确的格式) # 直接转发 executor 的事件(已经是正确的格式)

View File

@@ -25,7 +25,7 @@ Rules
Basic Principles Basic Principles
Priority Rule: When historical requirements conflict with current requirements, unconditionally prioritize current requirements. Priority Rule: When historical requirements conflict with current requirements, unconditionally prioritize current requirements.
Completeness Rule: If the original prompt is empty, generate a complete prompt based on the current requirements. Completeness Rule: If the original prompt is empty, generate a complete prompt based on the current requirements.
Structure Rule: Use a clear block structure including [Role], [Task], [Requirements], [Input], [Output], [Constraints] labels. Structure Rule: Use a clear block structure, and the contents of each block are roles, tasks, requirements, inputs, outputs, and constraints
Language Rule: All label languages must fully match the user input language. Language Rule: All label languages must fully match the user input language.
Behavior Guidelines Behavior Guidelines