Merge #104 into develop from feature/20251219_myh

feat(workflow): add support for question classifier in graph construction

* feature/20251219_myh: (11 commits)
  feat(workflow): support variable types(TODO)
  fix(workflow): fix passing of loop variable termination condition
  feat(workflow): add support for passing workspace ID
  feat(workflow): support retrieving variables wrapped in {{}} from variable pool
  feat(prompt_opt): support streaming output for prompt optimization API
  feat(workflow): update workflow conditional logic
  feat(workflow): enable front-end to cover pre-rendered non-variable values
  fix(workflow): ensure default values are properly retrieved in HTTP nodes
  refactor(workflow): refactor graph construction to support subgraph building
  Merge branch 'develop' into feature/20251219_myh
  feat(workflow): add support for question classifier in graph construction

Signed-off-by: Eternity <1533512157@qq.com>
Reviewed-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com>
Merged-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com>

CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/104
This commit is contained in:
朱文辉
2026-01-05 11:50:19 +08:00
26 changed files with 956 additions and 576 deletions

View File

@@ -597,7 +597,8 @@ async def draft_run(
async for event in workflow_service.run_stream(
app_id=app_id,
payload=payload,
config=config
config=config,
workspace_id=current_user.current_workspace_id
):
# 提取事件类型和数据
event_type = event.get("event", "message")
@@ -627,7 +628,7 @@ async def draft_run(
}
)
result = await workflow_service.run(app_id, payload,config)
result = await workflow_service.run(app_id, payload, config, current_user.current_workspace_id)
logger.debug(
"工作流试运行返回结果",

View File

@@ -1,7 +1,9 @@
import uuid
import json
from fastapi import APIRouter, Depends, Path
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse
from app.core.logging_config import get_api_logger
from app.core.response_utils import success
@@ -70,12 +72,12 @@ def get_prompt_session(
SessionMessage(role=role, content=content)
for role, content in history
]
result = SessionHistoryResponse(
session_id=session_id,
messages=messages
)
return success(data=result)
@@ -104,35 +106,25 @@ async def get_prompt_opt(
ApiResponse: Contains the optimized prompt, description, and a list of variables.
"""
service = PromptOptimizerService(db)
service.create_message(
tenant_id=current_user.tenant_id,
session_id=session_id,
user_id=current_user.id,
role=RoleType.USER,
content=data.message
)
opt_result = await service.optimize_prompt(
tenant_id=current_user.tenant_id,
model_id=data.model_id,
session_id=session_id,
user_id=current_user.id,
current_prompt=data.current_prompt,
user_require=data.message
)
service.create_message(
tenant_id=current_user.tenant_id,
session_id=session_id,
user_id=current_user.id,
role=RoleType.ASSISTANT,
content=opt_result.desc
)
variables = service.parser_prompt_variables(opt_result.prompt)
result = {
"prompt": opt_result.prompt,
"desc": opt_result.desc,
"variables": variables
}
result_schema = OptimizePromptResponse.model_validate(result)
return success(data=result_schema)
async def event_generator():
    """Stream prompt-optimization output as Server-Sent Events.

    Iterates the async stream returned by ``service.optimize_prompt`` and
    wraps every chunk in one SSE frame whose ``data`` field is the
    JSON-encoded chunk.

    Yields:
        str: One SSE frame per chunk, terminated by a blank line.
    """
    async for chunk in service.optimize_prompt(
        tenant_id=current_user.tenant_id,
        model_id=data.model_id,
        session_id=session_id,
        user_id=current_user.id,
        current_prompt=data.current_prompt,
        user_require=data.message
    ):
        # Each chunk is an incremental piece of the optimized prompt.
        # The SSE event name must be a bare token: writing it as
        # event:'message' makes the quotes part of the event type, so
        # standard "message" listeners never see the frame.
        yield f"event: message\ndata: {json.dumps(chunk)}\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)

View File

@@ -10,11 +10,10 @@ import logging
from typing import Any
from langchain_core.messages import HumanMessage
from langgraph.graph import StateGraph, START, END
from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.expression_evaluator import evaluate_condition
from app.core.workflow.nodes import WorkflowState, NodeFactory
from app.core.workflow.graph_builder import GraphBuilder
from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.enums import NodeType
# from app.core.tools.registry import ToolRegistry
@@ -191,155 +190,10 @@ class WorkflowExecutor:
编译后的状态图
"""
logger.info(f"开始构建工作流图: execution_id={self.execution_id}")
# 分析 End 节点的前缀配置和相邻且被引用的节点
end_prefixes, adjacent_and_referenced = self._analyze_end_node_prefixes() if stream else ({}, set())
# 1. 创建状态图
workflow = StateGraph(WorkflowState)
# 2. 添加所有节点(包括 start 和 end
start_node_id = None
end_node_ids = []
for node in self.nodes:
node_type = node.get("type")
node_id = node.get("id")
cycle_node = node.get("cycle")
if cycle_node:
# 处于循环子图中的节点由 CycleGraphNode 进行构建处理
continue
# 记录 start 和 end 节点 ID
if node_type == NodeType.START:
start_node_id = node_id
elif node_type == NodeType.END:
end_node_ids.append(node_id)
# 创建节点实例(现在 start 和 end 也会被创建)
node_instance = NodeFactory.create_node(node, self.workflow_config)
if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER]:
# Find all edges whose source is the current node
related_edge = [edge for edge in self.edges if edge.get("source") == node_id]
# Iterate over each branch
for idx in range(len(related_edge)):
# Generate a condition expression for each edge
# Used later to determine which branch to take based on the node's output
# Assumes node output `node.<node_id>.output` matches the edge's label
# For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1'
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
if node_instance:
# 如果是流式模式,且节点有 End 前缀配置,注入配置
if stream and node_id in end_prefixes:
# 将 End 前缀配置注入到节点实例
node_instance._end_node_prefix = end_prefixes[node_id]
logger.info(f"为节点 {node_id} 注入 End 前缀配置")
# 如果是流式模式,标记节点是否与 End 相邻且被引用
if stream:
node_instance._is_adjacent_to_end = node_id in adjacent_and_referenced
if node_id in adjacent_and_referenced:
logger.info(f"节点 {node_id} 标记为与 End 相邻且被引用")
# 包装节点的 run 方法
# 使用函数工厂避免闭包问题
if stream:
# 流式模式:创建 async generator 函数
# LangGraph 会收集所有 yield 的值,最后一个 yield 的字典会被合并到 state
def make_stream_func(inst):
async def node_func(state: WorkflowState):
# logger.debug(f"流式执行节点: {inst.node_id}, 支持流式: {inst.supports_streaming()}")
async for item in inst.run_stream(state):
yield item
return node_func
workflow.add_node(node_id, make_stream_func(node_instance))
else:
# 非流式模式:创建 async function
def make_func(inst):
async def node_func(state: WorkflowState):
return await inst.run(state)
return node_func
workflow.add_node(node_id, make_func(node_instance))
logger.debug(f"添加节点: {node_id} (type={node_type}, stream={stream})")
# 3. 添加边
# 从 START 连接到 start 节点
if start_node_id:
workflow.add_edge(START, start_node_id)
logger.debug(f"添加边: START -> {start_node_id}")
for edge in self.workflow_config.get("edges", []):
source = edge.get("source")
target = edge.get("target")
edge_type = edge.get("type")
condition = edge.get("condition")
# 跳过从 start 节点出发的边(因为已经从 START 连接到 start
if source == start_node_id:
# 但要连接 start 到下一个节点
workflow.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}")
continue
# # 处理到 end 节点的边
# if target in end_node_ids:
# # 连接到 end 节点
# workflow.add_edge(source, target)
# logger.debug(f"添加边: {source} -> {target}")
# continue
# 跳过错误边(在节点内部处理)
if edge_type == "error":
continue
if condition:
# 条件边
def make_router(cond, tgt):
"""Dynamically generate a conditional router function to ensure each branch has a unique name."""
def router_fn(state: WorkflowState):
if evaluate_condition(
cond,
state.get("variables", {}),
state.get("node_outputs", {}),
{
"execution_id": state.get("execution_id"),
"workspace_id": state.get("workspace_id"),
"user_id": state.get("user_id")
}
):
return tgt
return END
# 动态修改函数名,避免重复
router_fn.__name__ = f"router_{tgt}"
return router_fn
router_fn = make_router(condition, target)
workflow.add_conditional_edges(source, router_fn)
logger.debug(f"添加条件边: {source} -> {target} (condition={condition})")
else:
# 普通边
workflow.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}")
# 从 end 节点连接到 END
for end_node_id in end_node_ids:
workflow.add_edge(end_node_id, END)
logger.debug(f"添加边: {end_node_id} -> END")
# 4. 编译图
graph = workflow.compile()
graph = GraphBuilder(
self.workflow_config,
stream=stream,
).build()
logger.info(f"工作流图构建完成: execution_id={self.execution_id}")
return graph

View File

@@ -0,0 +1,253 @@
import logging
import uuid
from typing import Any
from langgraph.graph.state import CompiledStateGraph, StateGraph
from langgraph.graph import START, END
from app.core.workflow.expression_evaluator import evaluate_condition
from app.core.workflow.nodes import WorkflowState, NodeFactory
from app.core.workflow.nodes.enums import NodeType
logger = logging.getLogger(__name__)
# TODO: support decomposing the workflow into subgraphs
class GraphBuilder:
    """Build a compiled LangGraph state graph from a workflow configuration.

    The configuration is a dict with ``"nodes"`` and ``"edges"`` lists.
    ``build()`` creates a ``StateGraph``, registers one graph node per
    workflow node, wires the edges (plain and conditional) and compiles it.

    Args:
        workflow_config: Workflow definition containing ``nodes`` and ``edges``.
        stream: When True, nodes are wrapped as async generators that forward
            ``run_stream`` output, and End-node prefix analysis is performed.
        subgraph: When True, nodes marked with a ``cycle`` owner are built
            into this graph instead of being skipped.
    """

    def __init__(
            self,
            workflow_config: dict[str, Any],
            stream: bool = False,
            subgraph: bool = False,
    ):
        self.workflow_config = workflow_config
        self.stream = stream
        self.subgraph = subgraph

        # Filled in by add_nodes(); reset at the start of every build().
        self.start_node_id = None
        self.end_node_ids = []

        self.graph: StateGraph | CompiledStateGraph | None = None

    @property
    def nodes(self) -> list[dict[str, Any]]:
        """Node definitions of the workflow (empty list when absent)."""
        return self.workflow_config.get("nodes", [])

    @property
    def edges(self) -> list[dict[str, Any]]:
        """Edge definitions of the workflow (empty list when absent)."""
        return self.workflow_config.get("edges", [])

    def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]:
        """Analyze the output templates of End nodes.

        For every End node, find the first ``{{node_id.field}}`` reference in
        its output template that points at a node directly connected to that
        End node, and take the template text before that reference as the
        prefix for the referenced node.

        Returns:
            A tuple ``(prefixes, adjacent_and_referenced)`` where ``prefixes``
            maps an upstream node id to the non-empty template prefix, and
            ``adjacent_and_referenced`` is the set of node ids that are both
            directly upstream of an End node and referenced by its template.
        """
        import re

        # Matches {{node_id.xxx}} and {{ node_id.xxx }} (whitespace tolerated).
        # Compiled once instead of once per End node.
        reference_pattern = re.compile(r'\{\{\s*([a-zA-Z0-9_-]+)\.[a-zA-Z0-9_]+\s*\}\}')

        prefixes = {}
        adjacent_and_referenced = set()  # upstream neighbours referenced by an End template

        end_nodes = [node for node in self.nodes if node.get("type") == "end"]
        logger.info(f"[前缀分析] 找到 {len(end_nodes)} 个 End 节点")

        for end_node in end_nodes:
            end_node_id = end_node.get("id")
            # Tolerate an explicit ``"config": None`` as well as a missing key.
            output_template = (end_node.get("config") or {}).get("output")

            logger.info(f"[前缀分析] End 节点 {end_node_id} 模板: {output_template}")

            if not output_template:
                continue

            matches = list(reference_pattern.finditer(output_template))
            logger.info(f"[前缀分析] 模板中找到 {len(matches)} 个节点引用")

            # Sources of every edge that targets this End node.
            direct_upstream_nodes = [
                edge.get("source")
                for edge in self.edges
                if edge.get("target") == end_node_id
            ]
            logger.info(f"[前缀分析] End 节点的直接上游节点: {direct_upstream_nodes}")

            for match in matches:
                referenced_node_id = match.group(1)
                logger.info(f"[前缀分析] 检查引用: {referenced_node_id}")

                if referenced_node_id not in direct_upstream_nodes:
                    continue

                # Text of the template that precedes the reference.
                prefix = output_template[:match.start()]
                logger.info(f"[前缀分析] ✅ 找到直接上游节点 {referenced_node_id} 的引用,前缀: '{prefix}'")

                adjacent_and_referenced.add(referenced_node_id)

                if prefix:
                    prefixes[referenced_node_id] = prefix
                    logger.info(f"✅ [前缀分析] 为节点 {referenced_node_id} 配置前缀: '{prefix[:50]}...'")

                # Only the first reference to a direct upstream node is used.
                break

        logger.info(f"[前缀分析] 最终配置: {prefixes}")
        logger.info(f"[前缀分析] 与 End 相邻且被引用的节点: {adjacent_and_referenced}")
        return prefixes, adjacent_and_referenced

    def add_nodes(self):
        """Create node instances and register them on ``self.graph``.

        Also records the start node id and End node ids, and attaches a
        ``condition`` expression to every outgoing edge of branching nodes
        (if/else, HTTP request, question classifier) so that ``add_edges``
        can route on the node's output.
        """
        if self.stream:
            end_prefixes, adjacent_and_referenced = self._analyze_end_node_prefixes()
        else:
            end_prefixes, adjacent_and_referenced = {}, set()

        branching_types = [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER]

        # NOTE(review): loop node creation reportedly removes the subgraph's
        # nodes and edges from the workflow config while this loop is running;
        # confirm that removal cannot cause sibling nodes to be skipped here.
        for node in self.nodes:
            node_type = node.get("type")
            node_id = node.get("id")

            # Nodes owned by a cycle are built by CycleGraphNode unless this
            # builder is itself building that subgraph.
            if node.get("cycle") and not self.subgraph:
                continue

            # Remember start and End node ids for edge wiring.
            if node_type in [NodeType.START, NodeType.CYCLE_START]:
                self.start_node_id = node_id
            elif node_type == NodeType.END:
                self.end_node_ids.append(node_id)

            # Start and End nodes are instantiated like any other node.
            node_instance = NodeFactory.create_node(node, self.workflow_config)

            if node_type in branching_types:
                # Each outgoing edge is taken when the node's output equals the
                # edge label, e.g. node.123.output == 'CASE1'.
                for edge in self.edges:
                    if edge.get("source") == node_id:
                        edge['condition'] = f"node.{node_id}.output == '{edge['label']}'"

            if not node_instance:
                continue

            if self.stream:
                # Inject the End prefix configured for this node, if any.
                if node_id in end_prefixes:
                    node_instance._end_node_prefix = end_prefixes[node_id]
                    logger.info(f"为节点 {node_id} 注入 End 前缀配置")

                # Mark whether the node is adjacent to and referenced by an End node.
                node_instance._is_adjacent_to_end = node_id in adjacent_and_referenced
                if node_id in adjacent_and_referenced:
                    logger.info(f"节点 {node_id} 标记为与 End 相邻且被引用")

                # Factory function binds the instance per node (avoids the
                # late-binding closure pitfall inside the loop).
                def make_stream_func(inst):
                    async def node_func(state: WorkflowState):
                        async for item in inst.run_stream(state):
                            yield item

                    return node_func

                self.graph.add_node(node_id, make_stream_func(node_instance))
            else:
                def make_func(inst):
                    async def node_func(state: WorkflowState):
                        return await inst.run(state)

                    return node_func

                self.graph.add_node(node_id, make_func(node_instance))

            logger.debug(f"添加节点: {node_id} (type={node_type}, stream={self.stream})")

    def add_edges(self):
        """Wire START, workflow edges and END on ``self.graph``.

        Edges that carry a ``condition`` become conditional edges whose router
        returns the target when the condition holds and ``END`` otherwise.
        Edges of type ``"error"`` are skipped.
        """
        if self.start_node_id:
            self.graph.add_edge(START, self.start_node_id)
            logger.debug(f"添加边: START -> {self.start_node_id}")

        for edge in self.edges:
            source = edge.get("source")
            target = edge.get("target")
            edge_type = edge.get("type")
            condition = edge.get("condition")

            # Edges leaving the start node are always plain edges.
            if source == self.start_node_id:
                self.graph.add_edge(source, target)
                logger.debug(f"添加边: {source} -> {target}")
                continue

            # Error edges are handled inside the nodes themselves.
            if edge_type == "error":
                continue

            if condition:
                def make_router(cond, tgt):
                    """Create a router bound to one condition/target pair."""

                    def router_fn(state: WorkflowState):
                        if evaluate_condition(
                                cond,
                                state.get("variables", {}),
                                state.get("runtime_vars", {}),
                                {
                                    "execution_id": state.get("execution_id"),
                                    "workspace_id": state.get("workspace_id"),
                                    "user_id": state.get("user_id")
                                }
                        ):
                            return tgt
                        return END

                    # Give every router a unique name to avoid collisions.
                    router_fn.__name__ = f"router_{uuid.uuid4().hex[:8]}_{tgt}"
                    return router_fn

                self.graph.add_conditional_edges(source, make_router(condition, target))
                logger.debug(f"添加条件边: {source} -> {target} (condition={condition})")
            else:
                self.graph.add_edge(source, target)
                logger.debug(f"添加边: {source} -> {target}")

        # Connect every End node to the graph's END.
        for end_node_id in self.end_node_ids:
            self.graph.add_edge(end_node_id, END)
            logger.debug(f"添加边: {end_node_id} -> END")

    def build(self) -> "CompiledStateGraph":
        """Create, populate and compile the state graph.

        Bookkeeping from a previous ``build()`` call is cleared first so that
        building twice on the same instance does not add duplicate END edges.

        Returns:
            The compiled state graph.
        """
        self.start_node_id = None
        self.end_node_ids = []

        self.graph = StateGraph(WorkflowState)
        self.add_nodes()
        self.add_edges()  # edges must be added after the nodes exist
        return self.graph.compile()

View File

@@ -1,3 +1,5 @@
from typing import Any
from pydantic import Field, BaseModel
from app.core.workflow.nodes.base_config import BaseNodeConfig
@@ -19,7 +21,7 @@ class AssignmentItem(BaseModel):
description="Assignment operator",
)
value: str | list[str] = Field(
value: Any = Field(
...,
description="Value(s) to assign to the variable(s)",
)

View File

@@ -2,7 +2,6 @@ import logging
import re
from typing import Any
from app.core.workflow.expression_evaluator import ExpressionEvaluator
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import AssignmentOperator
@@ -29,6 +28,7 @@ class AssignerNode(BaseNode):
None or the result of the assignment operation.
"""
# Initialize a variable pool for accessing conversation, node, and system variables
logger.info(f"节点 {self.node_id} 开始执行")
pool = VariablePool(state)
for assignment in self.typed_config.assignments:
# Get the target variable selector (e.g., "conv.test")
@@ -45,14 +45,13 @@ class AssignerNode(BaseNode):
# Get the value or expression to assign
value = assignment.value
if isinstance(value, list):
value = '.'.join(value)
value = ExpressionEvaluator.evaluate(
expression=value,
variables=pool.get_all_conversation_vars(),
node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars(),
)
pattern = r"\{\{\s*(.*?)\s*\}\}"
if isinstance(value, str):
expression = re.match(pattern, value)
if expression:
expression = expression.group(1)
expression = re.sub(pattern, r"\1", expression).strip()
value = self.get_variable(expression, state)
# Select the appropriate assignment operator instance based on the target variable type
operator: AssignmentOperatorInstance = AssignmentOperatorResolver.resolve_by_value(
@@ -63,6 +62,8 @@ class AssignerNode(BaseNode):
# Execute the configured assignment operation
match assignment.operation:
case AssignmentOperator.COVER:
operator.assign()
case AssignmentOperator.ASSIGN:
operator.assign()
case AssignmentOperator.CLEAR:

View File

@@ -4,13 +4,16 @@
"""
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, ConfigDict
VARIABLE_PATTERN = r"\{\{\s*(.*?)\s*\}\}"
class VariableType(StrEnum):
"""变量类型枚举"""
STRING = "string"
NUMBER = "number"
BOOLEAN = "boolean"
@@ -22,43 +25,94 @@ class VariableType(StrEnum):
ARRAY_OBJECT = "array[object]"
class TypedVariable(BaseModel):
"""
TODO: 强类型限制
Strongly typed variable that validates value on assignment.
"""
value: Any = Field(..., description="Variable value")
type: VariableType = Field(..., description="Declared type of the variable")
model_config = ConfigDict(
validate_assignment=True
)
def __setattr__(self, name, value):
if name == "value":
self._validate_value(value)
if name == "type":
raise RuntimeError("Cannot modify variable type at runtime")
super().__setattr__(name, value)
def _validate_value(self, v: Any):
t = self.type
match t:
case VariableType.STRING:
if not isinstance(v, str):
raise TypeError("Variable value does not match type STRING")
case VariableType.BOOLEAN:
if not isinstance(v, bool):
raise TypeError("Variable value does not match type BOOLEAN")
case VariableType.NUMBER:
if not isinstance(v, (int, float)):
raise TypeError("Variable value does not match type NUMBER")
case VariableType.OBJECT:
if not isinstance(v, dict):
raise TypeError("Variable value does not match type OBJECT")
case VariableType.ARRAY_STRING:
if not isinstance(v, list) or not all(isinstance(i, str) for i in v):
raise TypeError("Variable value does not match type ARRAY_STRING")
case VariableType.ARRAY_NUMBER:
if not isinstance(v, list) or not all(isinstance(i, (int, float)) for i in v):
raise TypeError("Variable value does not match type ARRAY_NUMBER")
case VariableType.ARRAY_BOOLEAN:
if not isinstance(v, list) or not all(isinstance(i, bool) for i in v):
raise TypeError("Variable value does not match type ARRAY_BOOLEAN")
case VariableType.ARRAY_OBJECT:
if not isinstance(v, list) or not all(isinstance(i, dict) for i in v):
raise TypeError("Variable value does not match type ARRAY_OBJECT")
case _:
raise TypeError(f"Unknown variable type: {t}")
class VariableDefinition(BaseModel):
"""变量定义
定义工作流或节点的输入/输出变量。
这是一个通用的数据结构,可以在多个地方使用。
"""
name: str = Field(
...,
description="变量名称"
)
type: VariableType = Field(
default=VariableType.STRING,
description="变量类型"
)
required: bool = Field(
default=False,
description="是否必需"
)
default: str | int | float | bool | list | dict | None = Field(
default=None,
description="默认值"
)
description: str | None = Field(
default=None,
description="变量描述"
)
max_length: int = Field(
default=200,
description="只对字符串类型生效"
)
class Config:
json_schema_extra = {
"examples": [
@@ -96,22 +150,22 @@ class BaseNodeConfig(BaseModel):
- description: 节点描述
- tags: 节点标签(用于分类和搜索)
"""
name: str | None = Field(
default=None,
description="节点名称(显示名称),如果不设置则使用节点 ID"
)
description: str | None = Field(
default=None,
description="节点描述,说明节点的作用"
)
tags: list[str] = Field(
default_factory=list,
description="节点标签,用于分类和搜索"
)
class Config:
"""Pydantic 配置"""
# 允许额外字段(向后兼容)

View File

@@ -356,7 +356,8 @@ class BaseNode(ABC):
**final_output,
"runtime_vars": {
self.node_id: runtime_var
}
},
"looping": state["looping"]
}
# Add streaming buffer for non-End nodes

View File

@@ -1,7 +1,9 @@
from typing import Any
from pydantic import Field, BaseModel
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableType
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType
class CycleVariable(BaseNodeConfig):
@@ -9,18 +11,25 @@ class CycleVariable(BaseNodeConfig):
...,
description="Name of the loop variable"
)
type: VariableType = Field(
...,
description="Data type of the loop variable"
)
value: str = Field(
input_type: ValueInputType = Field(
...,
description="Input type of the loop variable"
)
value: Any = Field(
...,
description="Initial or current value of the loop variable"
)
class ConditionDetail(BaseModel):
comparison_operator: ComparisonOperator = Field(
operator: ComparisonOperator = Field(
...,
description="Operator used to compare the left and right operands"
)
@@ -30,11 +39,16 @@ class ConditionDetail(BaseModel):
description="Left-hand operand of the comparison expression"
)
right: str = Field(
right: Any = Field(
...,
description="Right-hand operand of the comparison expression"
)
input_type: ValueInputType = Field(
...,
description="Input type of the loop variable"
)
class ConditionsConfig(BaseModel):
"""Configuration for loop condition evaluation"""

View File

@@ -3,10 +3,11 @@ from typing import Any
from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.expression_evaluator import evaluate_condition, evaluate_expression
from app.core.workflow.expression_evaluator import evaluate_expression
from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig
from app.core.workflow.nodes.operators import ConditionExpressionBuilder
from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator
from app.core.workflow.nodes.operators import TypeTransformer, ConditionExpressionResolver, CompareOperatorInstance
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
@@ -14,11 +15,13 @@ logger = logging.getLogger(__name__)
class LoopRuntime:
"""
Runtime executor for loop nodes in a workflow.
Runtime executor for a loop node in a workflow graph.
Handles iterative execution of a loop node according to defined loop variables
and conditional expressions. Supports maximum loop count and loop control
through the workflow state.
This class is responsible for executing a loop node at runtime:
- Initializing loop-scoped variables
- Evaluating loop continuation conditions
- Repeatedly invoking a compiled sub-graph
- Enforcing maximum loop count and external stop signals
"""
def __init__(
@@ -29,13 +32,13 @@ class LoopRuntime:
state: WorkflowState,
):
"""
Initialize the loop runtime.
Initialize the loop runtime executor.
Args:
graph: Compiled workflow graph capable of async invocation.
node_id: Unique identifier of the loop node.
config: Dictionary containing loop node configuration.
state: Current workflow state at the point of loop execution.
graph: A compiled LangGraph state graph representing the loop body.
node_id: The unique identifier of the loop node in the workflow.
config: Raw configuration dictionary for the loop node.
state: The current workflow state before entering the loop.
"""
self.graph = graph
self.state = state
@@ -46,12 +49,15 @@ class LoopRuntime:
"""
Initialize workflow state for loop execution.
- Evaluates initial values of loop variables.
- Stores loop variables in runtime_vars and node_outputs.
- Marks the loop as active by setting 'looping' to True.
This method:
- Evaluates initial values of loop variables
- Stores loop variables into both `runtime_vars` and `node_outputs`
under the current loop node's scope
- Creates a shallow copy of the workflow state
- Marks the loop as active by setting `looping = True`
Returns:
A copy of the workflow state prepared for the loop execution.
WorkflowState: A prepared workflow state used for loop execution.
"""
pool = VariablePool(self.state)
# 循环变量
@@ -61,7 +67,7 @@ class LoopRuntime:
variables=pool.get_all_conversation_vars(),
node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars(),
)
) if variable.input_type == ValueInputType.VARIABLE else TypeTransformer.transform(variable.value, variable.type)
for variable in self.typed_config.cycle_vars
}
self.state["node_outputs"][self.node_id] = {
@@ -70,7 +76,7 @@ class LoopRuntime:
variables=pool.get_all_conversation_vars(),
node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars(),
)
) if variable.input_type == ValueInputType.VARIABLE else TypeTransformer.transform(variable.value, variable.type)
for variable in self.typed_config.cycle_vars
}
loopstate = WorkflowState(
@@ -79,49 +85,93 @@ class LoopRuntime:
loopstate["looping"] = True
return loopstate
def _get_loop_expression(self):
@staticmethod
def _evaluate(operator, instance: CompareOperatorInstance) -> Any:
"""
Build the Python boolean expression for evaluating the loop condition.
Dispatch and execute a comparison operator against a resolved
CompareOperatorInstance.
- Converts each condition in the loop configuration into a Python expression string.
- Combines multiple conditions with the configured logical operator (AND/OR).
Args:
operator: A ComparisonOperator enum value.
instance: A CompareOperatorInstance bound to concrete operands.
Returns:
A string representing the combined loop condition expression.
Any: The evaluation result, typically a boolean.
"""
branch_conditions = [
ConditionExpressionBuilder(
left=condition.left,
operator=condition.comparison_operator,
right=condition.right
).build()
for condition in self.typed_config.condition.expressions
]
if len(branch_conditions) > 1:
combined_condition = f' {self.typed_config.condition.logical_operator} '.join(branch_conditions)
else:
combined_condition = branch_conditions[0]
match operator:
case ComparisonOperator.EMPTY:
return instance.empty()
case ComparisonOperator.NOT_EMPTY:
return instance.not_empty()
case ComparisonOperator.CONTAINS:
return instance.contains()
case ComparisonOperator.NOT_CONTAINS:
return instance.not_contains()
case ComparisonOperator.START_WITH:
return instance.startswith()
case ComparisonOperator.END_WITH:
return instance.endswith()
case ComparisonOperator.EQ:
return instance.eq()
case ComparisonOperator.NE:
return instance.ne()
case ComparisonOperator.LT:
return instance.lt()
case ComparisonOperator.LE:
return instance.le()
case ComparisonOperator.GT:
return instance.gt()
case ComparisonOperator.GE:
return instance.ge()
case _:
raise ValueError(f"Invalid condition: {operator}")
return combined_condition
def evaluate_conditional(self, state) -> bool:
    """
    Evaluate the loop continuation condition at runtime.

    Every configured expression is resolved against ``state`` and evaluated
    immediately (no short-circuiting); the individual results are then
    combined with the configured logical operator. With no expressions
    configured, AND yields True and any other operator yields False.

    Args:
        state: The current workflow state during loop execution.

    Returns:
        bool: True if the loop should continue, False otherwise.
    """
    # One pool per evaluation pass instead of two fresh pools per expression.
    pool = VariablePool(state)
    condition = self.typed_config.condition

    results = []
    for expression in condition.expressions:
        # The comparison implementation is chosen from the left operand's value.
        left_value = pool.get(expression.left)
        comparator_cls = ConditionExpressionResolver.resolve_by_value(left_value)
        comparator = comparator_cls(
            pool,
            expression.left,
            expression.right,
            expression.input_type
        )
        results.append(self._evaluate(expression.operator, comparator))

    if condition.logical_operator == LogicOperator.AND:
        return all(results)
    return any(results)
async def run(self):
"""
Execute the loop node until the condition is no longer met, the loop is
manually stopped, or the maximum loop count is reached.
Execute the loop node until termination conditions are met.
The loop terminates when any of the following occurs:
- The loop condition evaluates to False
- The `looping` flag in the workflow state is set to False
- The maximum loop count is reached
Returns:
The final runtime variables of this loop node after completion.
dict[str, Any]: The final runtime variables of this loop node.
"""
loopstate = self._init_loop_state()
expression = self._get_loop_expression()
loop_variable_pool = VariablePool(loopstate)
loop_time = self.typed_config.max_loop
while evaluate_condition(
expression=expression,
variables=loop_variable_pool.get_all_conversation_vars(),
node_outputs=loop_variable_pool.get_all_node_outputs(),
system_vars=loop_variable_pool.get_all_system_vars(),
) and loopstate["looping"] and loop_time > 0:
while self.evaluate_conditional(loopstate) and loopstate["looping"] and loop_time > 0:
logger.info(f"loop node {self.node_id}: running")
await self.graph.ainvoke(loopstate)
loop_time -= 1

View File

@@ -1,10 +1,9 @@
import logging
from typing import Any
from langgraph.graph import StateGraph, START, END
from langgraph.graph import StateGraph
from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.expression_evaluator import evaluate_condition
from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
@@ -17,12 +16,18 @@ logger = logging.getLogger(__name__)
class CycleGraphNode(BaseNode):
"""
Node representing a cycle (loop) subgraph within the workflow.
Node representing a cyclic (loop or iteration) subgraph within the workflow.
This node manages internal loop/iteration nodes, builds a subgraph
for execution, handles conditional routing, and executes loop
or iteration logic based on node type.
A CycleGraphNode is a structural node that:
- Extracts a group of nodes marked as belonging to the same cycle
- Builds an isolated internal StateGraph (subgraph)
- Delegates runtime execution to LoopRuntime or IterationRuntime
depending on the node type
This node itself does NOT execute business logic directly.
It acts as a container and execution controller for a subgraph.
"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config: LoopNodeConfig | IterationNodeConfig | None = None
@@ -38,16 +43,23 @@ class CycleGraphNode(BaseNode):
def pure_cycle_graph(self) -> tuple[list, list]:
"""
Extract cycle nodes and internal edges from the workflow configuration,
removing them from the global workflow.
Extract cycle-scoped nodes and internal edges from the workflow configuration.
Raises:
ValueError: If cycle nodes are connected to external nodes improperly.
This method:
- Identifies all nodes marked with `cycle == self.node_id`
- Collects edges that fully connect cycle nodes
- Removes extracted nodes and edges from the global workflow configuration
Safety check:
- Raises an error if a cycle node is connected to an external node
Returns:
Tuple containing:
- cycle_nodes: List of removed nodes
- cycle_edges: List of removed edges
tuple[list, list]:
- cycle_nodes: Nodes belonging to this cycle
- cycle_edges: Edges connecting nodes within the cycle
Raises:
ValueError: If a cycle node is improperly connected to an external node.
"""
nodes = self.workflow_config.get("nodes", [])
edges = self.workflow_config.get("edges", [])
@@ -83,131 +95,41 @@ class CycleGraphNode(BaseNode):
return cycle_nodes, cycle_edges
def create_node(self):
"""
Instantiate node objects for each node in the cycle subgraph and add them to the graph.
The CYCLE_START node is only recorded as ``self.start_node_id`` and is not
instantiated or added to the graph. END nodes are recorded in
``self.end_node_ids`` and are still added to the graph.
Special handling is applied for IF_ELSE and HTTP_REQUEST nodes: a
``condition`` expression string is written onto their outgoing edges
in ``self.cycle_edges`` (the edge dicts are mutated in place).
"""
# Imported locally rather than at module level (presumably to avoid a circular import - TODO confirm).
from app.core.workflow.nodes import NodeFactory
for node in self.cycle_nodes:
node_type = node.get("type")
node_id = node.get("id")
if node_type == NodeType.CYCLE_START:
self.start_node_id = node_id
continue
elif node_type == NodeType.END:
self.end_node_ids.append(node_id)
node_instance = NodeFactory.create_node(node, self.workflow_config)
if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST]:
expressions = node_instance.build_conditional_edge_expressions()
# Number of branches, usually matches the number of conditional expressions
branch_number = len(expressions)
# Find all edges whose source is the current node
related_edge = [edge for edge in self.cycle_edges if edge.get("source") == node_id]
# Iterate over each branch
# NOTE(review): indexing is unchecked - fewer outgoing edges than branches raises
# IndexError, and an edge without a 'label' key raises KeyError.
for idx in range(branch_number):
# Generate a condition expression for each edge
# Used later to determine which branch to take based on the node's output
# Assumes node output `node.<node_id>.output` matches the edge's label
# For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1'
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
# Factory function: the node instance is passed as an argument so each
# node_func keeps its own instance instead of the loop variable.
def make_func(inst):
async def node_func(state: WorkflowState):
return await inst.run(state)
return node_func
self.graph.add_node(node_id, make_func(node_instance))
def create_edge(self):
"""
Connect nodes within the cycle subgraph by adding edges to the internal graph.
Edges leaving the recorded start node are re-rooted at START.
Edges carrying a ``condition`` become conditional edges whose router
returns the edge target when the condition evaluates truthy and END otherwise.
All other edges are added as plain edges, and every recorded end node
is finally connected to END.
"""
for edge in self.cycle_edges:
source = edge.get("source")
target = edge.get("target")
# NOTE(review): edge_type is read but never used in this method.
edge_type = edge.get("type")
condition = edge.get("condition")
# Edges leaving the cycle start node are not added as-is ...
if source == self.start_node_id:
# ... instead START is wired directly to the start node's successor.
self.graph.add_edge(START, target)
logger.debug(f"添加边: {source} -> {target}")
continue
if condition:
# Conditional edge.
# cond/tgt are bound as default arguments so each router keeps the
# values of its own edge rather than the loop variables.
def router(state: WorkflowState, cond=condition, tgt=target):
"""Conditional routing function: return the target node or END."""
if evaluate_condition(
cond,
state.get("variables", {}),
state.get("node_outputs", {}),
{
"execution_id": state.get("execution_id"),
"workspace_id": state.get("workspace_id"),
"user_id": state.get("user_id")
}
):
return tgt
return END # Condition not satisfied: terminate.
self.graph.add_conditional_edges(source, router)
logger.debug(f"添加条件边: {source} -> {target} (condition={condition})")
else:
# Plain (unconditional) edge.
self.graph.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}")
# Connect every recorded end node to END.
for end_node_id in self.end_node_ids:
self.graph.add_edge(end_node_id, END)
logger.debug(f"添加边: {end_node_id} -> END")
def build_graph(self):
"""
Build the internal subgraph for the cycle node.
Build and compile the internal subgraph for this cycle node.
Steps:
1. Extract cycle nodes and edges.
2. Create node instances and add them to the graph.
3. Connect edges and conditional routes.
4. Compile the graph for execution.
1. Extract cycle nodes and internal edges from the workflow
2. Construct a StateGraph using GraphBuilder in subgraph mode
3. Compile the graph for runtime execution
"""
self.graph = StateGraph(WorkflowState)
from app.core.workflow.graph_builder import GraphBuilder
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
self.create_node()
self.create_edge()
self.graph = self.graph.compile()
self.graph = GraphBuilder(
{
"nodes": self.cycle_nodes,
"edges": self.cycle_edges,
},
subgraph=True
).build()
async def execute(self, state: WorkflowState) -> Any:
"""
Execute the cycle node at runtime.
Depending on the node type, runs either a loop (LoopRuntime)
or an iteration (IterationRuntime) over the internal subgraph.
Based on the node type:
- LOOP: Executes LoopRuntime, repeatedly invoking the subgraph
- ITERATION: Executes IterationRuntime, iterating over a collection
Args:
state: Current workflow state.
state: The current workflow state when entering the cycle node.
Returns:
Runtime result of the cycle, typically the final loop/iteration variables.
Any: The runtime result produced by the loop or iteration executor.
Raises:
RuntimeError: If node type is unrecognized.
RuntimeError: If the node type is unsupported.
"""
if self.node_type == NodeType.LOOP:
return await LoopRuntime(

View File

@@ -61,7 +61,7 @@ class EndNode(BaseNode):
引用的节点 ID 列表
"""
# 匹配 {{node_id.xxx}} 格式
pattern = r'\{\{([a-zA-Z0-9_-]+)\.[a-zA-Z0-9_]+\}\}'
pattern = r'\{\{([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\}\}'
matches = re.findall(pattern, template)
return list(set(matches)) # 去重

View File

@@ -45,7 +45,8 @@ class LogicOperator(StrEnum):
class AssignmentOperator(StrEnum):
ASSIGN = "assign"
COVER = "cover" # 覆盖
ASSIGN = "assign" # 设置
CLEAR = "clear"
ADD = "add" # +=
@@ -87,3 +88,8 @@ class HttpErrorHandle(StrEnum):
NONE = "none"
DEFAULT = "default"
BRANCH = "branch"
class ValueInputType(StrEnum):
# How the right-hand operand of a comparison is supplied.
# VARIABLE: the operand is a variable reference (optionally wrapped in {{ }})
# that ConditionBase.resolve_right_literal_value looks up in the variable pool.
VARIABLE = "Variable"
# CONSTANT: the operand is used as given, without any lookup.
CONSTANT = "Constant"

View File

@@ -63,7 +63,7 @@ class HttpContentTypeConfig(BaseModel):
)
data: list[HttpFormData] | dict | str = Field(
...,
default="",
description="Data of the HTTP request body; type depends on content_type",
)
@@ -98,6 +98,10 @@ class HttpTimeOutConfig(BaseModel):
class HttpRetryConfig(BaseModel):
enable: bool = Field(
...,
description="Enable/disable retry logic",
)
max_attempts: int = Field(
default=1,
description="Maximum number of retry attempts for failed requests",
@@ -124,6 +128,11 @@ class HttpErrorDefaultTamplete(BaseModel):
description="Default HTTP headers returned on error",
)
output: str = Field(
default="SUCCESS",
description="HTTP response body",
)
class HttpErrorHandleConfig(BaseModel):
method: HttpErrorHandle = Field(
@@ -131,8 +140,8 @@ class HttpErrorHandleConfig(BaseModel):
description="Error handling strategy: 'none', 'default', or 'branch'",
)
default: HttpErrorDefaultTamplete = Field(
...,
default: HttpErrorDefaultTamplete | None = Field(
default=None,
description="Default response template for error handling",
)

View File

@@ -165,24 +165,6 @@ class HttpRequestNode(BaseNode):
case _:
raise RuntimeError(f"HttpRequest method not supported: {self.typed_config.method}")
def build_conditional_edge_expressions(self):
    """
    Report the conditional outputs this HTTP node exposes for branching.

    Only the configured error-handling strategy is inspected: with the
    `BRANCH` strategy the node exposes one conditional output labeled
    "ERROR"; with any other strategy it exposes none.

    Returns:
        list[str]: ["ERROR"] when the strategy is BRANCH, otherwise [].
    """
    uses_error_branch = self.typed_config.error_handle.method == HttpErrorHandle.BRANCH
    return ["ERROR"] if uses_error_branch else []
async def execute(self, state: WorkflowState) -> dict | str:
"""
Execute the HTTP request node.

View File

@@ -1,12 +1,13 @@
"""Condition Configuration"""
from typing import Any
from pydantic import Field, BaseModel, field_validator
from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType
class ConditionDetail(BaseModel):
comparison_operator: ComparisonOperator = Field(
operator: ComparisonOperator = Field(
...,
description="Comparison operator used to evaluate the condition"
)
@@ -16,17 +17,22 @@ class ConditionDetail(BaseModel):
description="Value to compare against"
)
right: str = Field(
right: Any = Field(
...,
description="Value to compare with"
)
input_type: ValueInputType = Field(
...,
description="Value input type for comparison"
)
class ConditionBranchConfig(BaseModel):
"""Configuration for a conditional branch"""
logical_operator: LogicOperator = Field(
default=LogicOperator.AND.value,
default=LogicOperator.AND,
description="Logical operator used to combine multiple condition expressions"
)

View File

@@ -1,10 +1,11 @@
import logging
import re
from typing import Any
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator
from app.core.workflow.nodes.if_else import IfElseNodeConfig
from app.core.workflow.nodes.if_else.config import ConditionDetail
from app.core.workflow.nodes.operators import ConditionExpressionBuilder
from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance
logger = logging.getLogger(__name__)
@@ -15,30 +16,36 @@ class IfElseNode(BaseNode):
self.typed_config = IfElseNodeConfig(**self.config)
@staticmethod
def _build_condition_expression(
condition: ConditionDetail,
) -> str:
"""
Build a single boolean condition expression string.
def _evaluate(operator, instance: CompareOperatorInstance) -> Any:
    """
    Run the comparison selected by ``operator`` on a comparison instance.

    Args:
        operator: A ComparisonOperator member naming the comparison.
        instance: Comparison object already bound to its operands.

    Returns:
        Whatever the selected comparison method returns.

    Raises:
        ValueError: If ``operator`` is not a known comparison operator, or
            if the instance's class does not implement that comparison.
    """
    method_names = {
        ComparisonOperator.EMPTY: "empty",
        ComparisonOperator.NOT_EMPTY: "not_empty",
        ComparisonOperator.CONTAINS: "contains",
        ComparisonOperator.NOT_CONTAINS: "not_contains",
        ComparisonOperator.START_WITH: "startswith",
        ComparisonOperator.END_WITH: "endswith",
        ComparisonOperator.EQ: "eq",
        ComparisonOperator.NE: "ne",
        ComparisonOperator.LT: "lt",
        ComparisonOperator.LE: "le",
        ComparisonOperator.GT: "gt",
        ComparisonOperator.GE: "ge",
    }
    try:
        method_name = method_names[operator]
    except (KeyError, TypeError):
        # TypeError covers unhashable operators, which are equally invalid.
        raise ValueError(f"Invalid condition: {operator}") from None

    # Not every comparison class implements every operator; report that
    # explicitly instead of letting a bare AttributeError escape.
    handler = getattr(instance, method_name, None)
    if handler is None:
        raise ValueError(
            f"Invalid condition: {operator} is not supported by {type(instance).__name__}"
        )
    return handler()
This method does NOT evaluate the condition.
It only generates a valid Python boolean expression string
(e.g. "x > 10", "'a' in name") that can later be used
in a conditional edge or evaluated by the workflow engine.
Args:
condition (ConditionDetail): Definition of a single comparison condition.
Returns:
str: A Python boolean expression string.
"""
return ConditionExpressionBuilder(
left=condition.left,
operator=condition.comparison_operator,
right=condition.right
).build()
def build_conditional_edge_expressions(self) -> list[str]:
def evaluate_conditional_edge_expressions(self, state) -> list[bool]:
"""
Build conditional edge expressions for the If-Else node.
@@ -60,19 +67,28 @@ class IfElseNode(BaseNode):
for case_branch in self.typed_config.cases:
branch_index += 1
branch_conditions = [
self._build_condition_expression(condition)
for condition in case_branch.expressions
]
if len(branch_conditions) > 1:
combined_condition = f' {case_branch.logical_operator} '.join(branch_conditions)
branch_result = []
for expression in case_branch.expressions:
pattern = r"\{\{\s*(.*?)\s*\}\}"
left_string = re.sub(pattern, r"\1", expression.left).strip()
left_value = self.get_variable(left_string, state)
evaluator = ConditionExpressionResolver.resolve_by_value(left_value)(
self.get_variable_pool(state),
expression.left,
expression.right,
expression.input_type
)
branch_result.append(self._evaluate(expression.operator, evaluator))
if case_branch.logical_operator == LogicOperator.AND:
conditions.append(all(branch_result))
else:
combined_condition = branch_conditions[0]
conditions.append(combined_condition)
condition_res = any(branch_result)
conditions.append(condition_res)
if condition_res:
return conditions
# Default fallback branch
conditions.append("True")
conditions.append(True)
return conditions
@@ -90,10 +106,10 @@ class IfElseNode(BaseNode):
Returns:
str: The matched branch identifier, e.g., 'CASE1', 'CASE2', ..., used for node transitions.
"""
expressions = self.build_conditional_edge_expressions()
expressions = self.evaluate_conditional_edge_expressions(state)
# TODO: 变量类型及文本类型解析
for i in range(len(expressions)):
logger.info(expressions[i])
if self._evaluate_condition(expressions[i], state):
if expressions[i]:
logger.info(f"Node {self.node_id}: switched to branch CASE {i + 1}")
return f'CASE{i + 1}'
return f'CASE{len(expressions)}'

View File

@@ -8,6 +8,7 @@ from app.core.workflow.template_renderer import TemplateRenderer
logger = logging.getLogger(__name__)
class JinjaRenderNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)

View File

@@ -74,6 +74,7 @@ class NodeFactory:
NodeType.LOOP: CycleGraphNode,
NodeType.ITERATION: CycleGraphNode,
NodeType.BREAK: BreakNode,
NodeType.CYCLE_START: StartNode,
NodeType.TOOL: ToolNode,
}

View File

@@ -1,10 +1,73 @@
import json
import re
from abc import ABC
from typing import Union, Type
from typing import Union, Type, NoReturn
from app.core.workflow.nodes.enums import ComparisonOperator
from app.core.workflow.nodes.base_config import VariableType
from app.core.workflow.nodes.enums import ValueInputType
from app.core.workflow.variable_pool import VariablePool
class TypeTransformer:
@classmethod
def _fail(cls, value, target) -> NoReturn:
raise TypeError(f"Cannot convert {value!r} to {target} type")
@classmethod
def _json_load(cls, value, target):
try:
return json.loads(value)
except Exception:
cls._fail(value, target)
@classmethod
def transform(cls, variable_literal: str | bool, target_type: VariableType):
match target_type:
case VariableType.STRING:
return str(variable_literal)
case VariableType.NUMBER:
for caster in (int, float):
try:
return caster(variable_literal)
except Exception:
pass
cls._fail(variable_literal, target_type)
case VariableType.BOOLEAN:
if isinstance(variable_literal, bool):
return variable_literal
cls._fail(variable_literal, target_type)
case VariableType.OBJECT:
obj = cls._json_load(variable_literal, target_type)
if isinstance(obj, dict):
return obj
cls._fail(variable_literal, target_type)
case VariableType.ARRAY_BOOLEAN:
return cls._parse_list(variable_literal, bool, target_type)
case VariableType.ARRAY_NUMBER:
return cls._parse_list(variable_literal, (int, float), target_type)
case VariableType.ARRAY_STRING:
return cls._parse_list(variable_literal, str, target_type)
case VariableType.ARRAY_OBJECT:
return cls._parse_list(variable_literal, dict, target_type)
case _:
raise TypeError("Invalid type")
@classmethod
def _parse_list(cls, value, item_type, target):
arr = cls._json_load(value, target)
if isinstance(arr, list) and all(isinstance(i, item_type) for i in arr):
return arr
cls._fail(value, target)
class OperatorBase(ABC):
def __init__(self, pool: VariablePool, left_selector, right):
self.pool = pool
@@ -19,7 +82,9 @@ class OperatorBase(ABC):
raise TypeError(f"The variable to be operated on must be of {self.type_limit} type")
if not no_right and not isinstance(self.right, self.type_limit):
raise TypeError(f"The value assigned to the string variable must also be of {self.type_limit} type")
raise TypeError(
f"The value assigned must be of {self.type_limit} type"
)
class StringOperator(OperatorBase):
@@ -126,7 +191,7 @@ class ArrayOperator(OperatorBase):
class ObjectOperator(OperatorBase):
def __init__(self, pool: VariablePool, left_selector, right):
super().__init__(pool, left_selector, right)
self.type_limit = object
self.type_limit = dict
def assign(self) -> None:
self.check()
@@ -138,20 +203,21 @@ class ObjectOperator(OperatorBase):
class AssignmentOperatorResolver:
OPERATOR_MAP = {
str: StringOperator,
bool: BooleanOperator,
int: NumberOperator,
float: NumberOperator,
list: ArrayOperator,
dict: ObjectOperator,
}
@classmethod
def resolve_by_value(cls, value):
if isinstance(value, str):
return StringOperator
elif isinstance(value, bool):
return BooleanOperator
elif isinstance(value, (int, float)):
return NumberOperator
elif isinstance(value, list):
return ArrayOperator
elif isinstance(value, dict):
return ObjectOperator
else:
raise TypeError(f"Unsupported variable type: {type(value)}")
for t, op in cls.OPERATOR_MAP.items():
if isinstance(value, t):
return op
raise TypeError(f"Unsupported variable type: {type(value)}")
AssignmentOperatorInstance = Union[
@@ -164,81 +230,186 @@ AssignmentOperatorInstance = Union[
AssignmentOperatorType = Type[AssignmentOperatorInstance]
class ConditionExpressionBuilder:
"""
Build a Python boolean expression string based on a comparison operator.
class ConditionBase(ABC):
type_limit: type[str, int, dict, list] = None
This class does not evaluate the expression.
It only generates a valid Python expression string
that can be evaluated later in a workflow context.
"""
def __init__(
self,
pool: VariablePool,
left_selector,
right_selector: str,
input_type: ValueInputType
):
self.pool = pool
self.left_selector = left_selector
self.right_selector = right_selector
self.input_type = input_type
def __init__(self, left: str, operator: ComparisonOperator, right: str):
self.left = left
self.operator = operator
self.right = right
self.left_value = self.pool.get(self.left_selector)
self.right_value = self.resolve_right_literal_value()
def _empty(self):
return f"{self.left} == ''"
self.type_limit = getattr(self, "type_limit", None)
def _not_empty(self):
return f"{self.left} != ''"
def resolve_right_literal_value(self):
    """
    Produce the concrete right-hand value for this comparison.

    A CONSTANT operand is returned as given. A VARIABLE operand has any
    surrounding ``{{ }}`` removed and is then looked up in the variable pool.

    Raises:
        RuntimeError: If ``input_type`` is neither VARIABLE nor CONSTANT.
    """
    kind = self.input_type
    if kind == ValueInputType.VARIABLE:
        unwrapped = re.sub(r"\{\{\s*(.*?)\s*\}\}", r"\1", self.right_selector)
        return self.pool.get(unwrapped.strip())
    if kind == ValueInputType.CONSTANT:
        return self.right_selector
    raise RuntimeError("Unsupported variable type")
def _contains(self):
return f"{self.right} in {self.left}"
def check(self, no_right=False):
    """
    Validate that the operands have the type this comparison supports.

    Args:
        no_right: When True only the left operand is validated (for unary
            comparisons that have no right-hand value).

    Raises:
        TypeError: If the left operand, or the resolved right operand when
            it is checked, is not an instance of ``self.type_limit``.
    """
    # left_selector is normally a plain selector (the form __init__ passes
    # straight to pool.get); an object exposing ``variable_selector`` is
    # still accepted for compatibility.
    selector = getattr(self.left_selector, "variable_selector", self.left_selector)
    left = self.pool.get(selector)
    if not isinstance(left, self.type_limit):
        raise TypeError(f"The variable to be compared on must be of {self.type_limit} type")
    if no_right:
        return
    right = self.resolve_right_literal_value()
    if not isinstance(right, self.type_limit):
        raise TypeError(
            f"The compared variable must be of {self.type_limit} type"
        )
def _not_contains(self):
return f"{self.right} not in {self.left}"
def _startswith(self):
return f'{self.left}.startswith({self.right})'
class StringComparisonOperator(ConditionBase):
type_limit = str
def _endswith(self):
return f'{self.left}.endswith({self.right})'
def __init__(self, pool: VariablePool, left_selector, right_selector, input_type):
super().__init__(pool, left_selector, right_selector, input_type)
def _eq(self):
return f"{self.left} == {self.right}"
def empty(self):
self.check(no_right=True)
return self.left_value == ""
def _ne(self):
return f"{self.left} != {self.right}"
def not_empty(self):
return not self.empty()
def _lt(self):
return f"{self.left} < {self.right}"
def contains(self):
self.check()
return self.right_value in self.left_value
def _le(self):
return f"{self.left} <= {self.right}"
def not_contains(self):
return self.right_value not in self.left_value
def _gt(self):
return f"{self.left} > {self.right}"
def startswith(self):
self.check()
return self.left_value.startswith(self.right_value)
def _ge(self):
return f"{self.left} >= {self.right}"
def endswith(self):
return self.left_value.endswith(self.right_value)
def build(self):
match self.operator:
case ComparisonOperator.EMPTY:
return self._empty()
case ComparisonOperator.NOT_EMPTY:
return self._not_empty()
case ComparisonOperator.CONTAINS:
return self._contains()
case ComparisonOperator.NOT_CONTAINS:
return self._not_contains()
case ComparisonOperator.START_WITH:
return self._startswith()
case ComparisonOperator.END_WITH:
return self._endswith()
case ComparisonOperator.EQ:
return self._eq()
case ComparisonOperator.NE:
return self._ne()
case ComparisonOperator.LT:
return self._lt()
case ComparisonOperator.LE:
return self._le()
case ComparisonOperator.GT:
return self._gt()
case ComparisonOperator.GE:
return self._ge()
case _:
raise ValueError(f"Invalid condition: {self.operator}")
def eq(self):
return self.left_value == self.right_value
def ne(self):
return self.left_value != self.right_value
class NumberComparisonOperator(ConditionBase):
    """Comparisons for a numeric left operand.

    A number counts as "empty" when it equals zero. Construction is
    inherited unchanged from ConditionBase.
    """

    type_limit = (int, float)

    def empty(self):
        """Return True when the left value equals 0."""
        return self.left_value == 0

    def not_empty(self):
        """Return True when the left value differs from 0."""
        return self.left_value != 0

    def eq(self):
        """Return ``left == right``."""
        lhs, rhs = self.left_value, self.right_value
        return lhs == rhs

    def ne(self):
        """Return ``left != right``."""
        lhs, rhs = self.left_value, self.right_value
        return lhs != rhs

    def lt(self):
        """Return ``left < right``."""
        lhs, rhs = self.left_value, self.right_value
        return lhs < rhs

    def le(self):
        """Return ``left <= right``."""
        lhs, rhs = self.left_value, self.right_value
        return lhs <= rhs

    def gt(self):
        """Return ``left > right``."""
        lhs, rhs = self.left_value, self.right_value
        return lhs > rhs

    def ge(self):
        """Return ``left >= right``."""
        lhs, rhs = self.left_value, self.right_value
        return lhs >= rhs
class BooleanComparisonOperator(ConditionBase):
    """Comparisons for a boolean left operand: equality and inequality only.

    Construction is inherited unchanged from ConditionBase.
    """

    type_limit = bool

    def eq(self):
        """Return ``left == right``."""
        lhs, rhs = self.left_value, self.right_value
        return lhs == rhs

    def ne(self):
        """Return ``left != right``."""
        lhs, rhs = self.left_value, self.right_value
        return lhs != rhs
class ObjectComparisonOperator(ConditionBase):
    """Comparisons for a dict left operand.

    Emptiness follows the truthiness of the left value. Construction is
    inherited unchanged from ConditionBase.
    """

    type_limit = dict

    def eq(self):
        """Return ``left == right``."""
        lhs, rhs = self.left_value, self.right_value
        return lhs == rhs

    def ne(self):
        """Return ``left != right``."""
        lhs, rhs = self.left_value, self.right_value
        return lhs != rhs

    def empty(self):
        """Return True when the left value is falsy."""
        if self.left_value:
            return False
        return True

    def not_empty(self):
        """Return True when the left value is truthy."""
        if self.left_value:
            return True
        return False
class ArrayComparisonOperator(ConditionBase):
    """Comparisons for a list left operand: emptiness and membership.

    Construction is inherited unchanged from ConditionBase.
    """

    type_limit = list

    def empty(self):
        """Return True when the left value is falsy."""
        if self.left_value:
            return False
        return True

    def not_empty(self):
        """Return True when the left value is truthy."""
        if self.left_value:
            return True
        return False

    def contains(self):
        """Return True when the right value is an element of the left value."""
        needle, haystack = self.right_value, self.left_value
        return needle in haystack

    def not_contains(self):
        """Return True when the right value is not an element of the left value."""
        needle, haystack = self.right_value, self.left_value
        return needle not in haystack
CompareOperatorInstance = Union[
StringComparisonOperator,
NumberComparisonOperator,
BooleanComparisonOperator,
ArrayComparisonOperator,
ObjectComparisonOperator
]
CompareOperatorType = Type[CompareOperatorInstance]
class ConditionExpressionResolver:
    """Select the comparison class that matches a value's runtime type."""

    # Checked in insertion order; bool is listed before int because every
    # bool is also an int instance.
    CONDITION_OPERATOR_MAP = {
        str: StringComparisonOperator,
        bool: BooleanComparisonOperator,
        int: NumberComparisonOperator,
        float: NumberComparisonOperator,
        list: ArrayComparisonOperator,
        dict: ObjectComparisonOperator,
    }

    @classmethod
    def resolve_by_value(cls, value) -> CompareOperatorType:
        """Return the comparison class for ``value``.

        Raises:
            TypeError: If ``value`` is not an instance of any mapped type.
        """
        candidates = (
            handler
            for kind, handler in cls.CONDITION_OPERATOR_MAP.items()
            if isinstance(value, kind)
        )
        chosen = next(candidates, None)
        if chosen is None:
            raise TypeError(f"Unsupported variable type: {type(value)}")
        return chosen

View File

@@ -1,4 +1,4 @@
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
from app.core.workflow.nodes.variable_aggregator.node import VariableAggregatorNode
__all__ = ["VariableAggregatorNode", "VariableAggregatorNodeConfig"]
__all__ = ["VariableAggregatorNode", "VariableAggregatorNodeConfig"]

View File

@@ -4,11 +4,11 @@
从文件系统加载预定义的工作流模板
"""
import os
import yaml
from pathlib import Path
from typing import Optional
import yaml
class TemplateLoader:
"""工作流模板加载器"""

View File

@@ -10,6 +10,7 @@
"""
import logging
import re
from typing import Any, TYPE_CHECKING
if TYPE_CHECKING:
@@ -28,7 +29,7 @@ class VariableSelector:
>>> selector = VariableSelector(["node_A", "output"])
>>> selector = VariableSelector.from_string("sys.message")
"""
def __init__(self, path: list[str]):
"""初始化变量选择器
@@ -37,11 +38,11 @@ class VariableSelector:
"""
if not path or len(path) < 1:
raise ValueError("变量路径不能为空")
self.path = path
self.namespace = path[0] # sys, var, 或 node_id
self.key = path[1] if len(path) > 1 else None
@classmethod
def from_string(cls, selector_str: str) -> "VariableSelector":
"""从字符串创建选择器
@@ -58,10 +59,10 @@ class VariableSelector:
"""
path = selector_str.split(".")
return cls(path)
def __str__(self) -> str:
return ".".join(self.path)
def __repr__(self) -> str:
return f"VariableSelector({self.path})"
@@ -84,7 +85,7 @@ class VariablePool:
"AI 的回答"
>>> pool.set(["conv", "user_name"], "张三")
"""
def __init__(self, state: "WorkflowState"):
"""初始化变量池
@@ -92,7 +93,7 @@ class VariablePool:
state: 工作流状态LangGraph State
"""
self.state = state
def get(self, selector: list[str] | str, default: Any = None) -> Any:
"""获取变量值
@@ -114,13 +115,15 @@ class VariablePool:
"""
# 转换为 VariableSelector
if isinstance(selector, str):
selector = VariableSelector.from_string(selector).path
pattern = r"\{\{\s*(.*?)\s*\}\}"
variable_literal = re.sub(pattern, r"\1", selector).strip()
selector = VariableSelector.from_string(variable_literal).path
if not selector or len(selector) < 1:
raise ValueError("变量选择器不能为空")
namespace = selector[0]
try:
# 系统变量
if namespace == "sys":
@@ -128,30 +131,30 @@ class VariablePool:
if not key:
return self.state.get("variables", {}).get("sys", {})
return self.state.get("variables", {}).get("sys", {}).get(key, default)
# 会话变量
elif namespace == "conv":
key = selector[1] if len(selector) > 1 else None
if not key:
return self.state.get("variables", {}).get("conv", {})
return self.state.get("variables", {}).get("conv", {}).get(key, default)
# 节点输出(从 runtime_vars 读取)
else:
node_id = namespace
runtime_vars = self.state.get("runtime_vars", {})
if node_id not in runtime_vars:
if default is not None:
return default
raise KeyError(f"节点 '{node_id}' 的输出不存在")
node_var = runtime_vars[node_id]
# 如果只有节点 ID返回整个变量
if len(selector) == 1:
return node_var
# 获取特定字段
# 支持嵌套访问,如 node_id.field.subfield
result = node_var
@@ -166,14 +169,14 @@ class VariablePool:
if default is not None:
return default
raise KeyError(f"无法访问 '{'.'.join(selector)}'")
return result
except KeyError:
if default is not None:
return default
raise
def set(self, selector: list[str] | str, value: Any):
"""设置变量值
@@ -192,17 +195,17 @@ class VariablePool:
# 转换为 VariableSelector
if isinstance(selector, str):
selector = VariableSelector.from_string(selector).path
if not selector or len(selector) < 2:
raise ValueError("变量选择器必须包含命名空间和键名")
namespace = selector[0]
if namespace != "conv" and namespace not in self.state["cycle_nodes"]:
raise ValueError("Only conversation or cycle variables can be assigned.")
key = selector[1]
# 确保 variables 结构存在
if "variables" not in self.state:
self.state["variables"] = {"sys": {}, "conv": {}}
@@ -214,9 +217,9 @@ class VariablePool:
self.state["variables"]["conv"][key] = value
elif namespace in self.state["cycle_nodes"]:
self.state["runtime_vars"][namespace][key] = value
logger.debug(f"设置变量: {'.'.join(selector)} = {value}")
def has(self, selector: list[str] | str) -> bool:
"""检查变量是否存在
@@ -237,7 +240,7 @@ class VariablePool:
return True
except KeyError:
return False
def get_all_system_vars(self) -> dict[str, Any]:
"""获取所有系统变量
@@ -245,7 +248,7 @@ class VariablePool:
系统变量字典
"""
return self.state.get("variables", {}).get("sys", {})
def get_all_conversation_vars(self) -> dict[str, Any]:
"""获取所有会话变量
@@ -253,7 +256,7 @@ class VariablePool:
会话变量字典
"""
return self.state.get("variables", {}).get("conv", {})
def get_all_node_outputs(self) -> dict[str, Any]:
"""获取所有节点输出(运行时变量)
@@ -261,7 +264,7 @@ class VariablePool:
节点输出字典,键为节点 ID
"""
return self.state.get("runtime_vars", {})
def get_node_output(self, node_id: str) -> dict[str, Any] | None:
"""获取指定节点的输出(运行时变量)
@@ -272,7 +275,7 @@ class VariablePool:
节点输出或 None
"""
return self.state.get("runtime_vars", {}).get(node_id)
def to_dict(self) -> dict[str, Any]:
"""导出为字典
@@ -284,12 +287,12 @@ class VariablePool:
"conversation": self.get_all_conversation_vars(),
"nodes": self.get_all_node_outputs() # 从 runtime_vars 读取
}
def __repr__(self) -> str:
sys_vars = self.get_all_system_vars()
conv_vars = self.get_all_conversation_vars()
runtime_vars = self.get_all_node_outputs()
return (
f"VariablePool(\n"
f" system_vars={len(sys_vars)},\n"

View File

@@ -1,5 +1,6 @@
import re
import uuid
from typing import Any, AsyncGenerator
import json_repair
from langchain_core.prompts import ChatPromptTemplate
@@ -123,7 +124,7 @@ class PromptOptimizerService:
user_id: uuid.UUID,
current_prompt: str,
user_require: str
) -> OptimizePromptResult:
) -> AsyncGenerator[dict[str, str | Any], Any]:
"""
Optimize a user-provided prompt using a configured prompt optimizer LLM.
@@ -161,6 +162,7 @@ class PromptOptimizerService:
BusinessException: If the LLM response cannot be parsed as valid JSON
or does not conform to the expected output format.
"""
self.create_message(tenant_id, session_id, user_id, role=RoleType.USER, content=user_require)
model_config = self.get_model_config(tenant_id, model_id)
session_history = self.get_session_message_history(session_id=session_id, user_id=user_id)
@@ -202,17 +204,54 @@ class PromptOptimizerService:
messages.extend(session_history[:-1]) # last message is current message
messages.extend([(RoleType.USER.value, rendered_user_message)])
logger.info(f"Prompt optimization message: {messages}")
optim_resp = await llm.ainvoke(messages)
logger.info(optim_resp.content)
optim_result = json_repair.repair_json(optim_resp.content, return_objects=True)
prompt = optim_result.get("prompt")
desc = optim_result.get("desc")
buffer = ""
prompt_started = False
prompt_finished = False
idx = 0
return OptimizePromptResult(
prompt=prompt,
desc=desc
async for chunk in llm.astream(messages):
content = getattr(chunk, "content", chunk)
if not content:
continue
buffer += content
cache = buffer[:-20]
# 尝试找到 "prompt": " 开始位置
if prompt_finished:
continue
if not prompt_started:
m = re.search(r'"prompt"\s*:\s*"', cache)
if m:
prompt_started = True
prompt_index = m.end()
idx = prompt_index
else:
m = re.search(r'"\s*,\s*\\?n?\s*"desc"\s*:\s*"', buffer)
if m:
prompt_index = m.start()
prompt_finished = True
yield {"type": "delta", "content": buffer[idx:prompt_index]}
else:
yield {"type": "delta", "content": cache[idx:]}
if len(cache) != 0:
idx = len(cache)
# optim_resp = await llm.astream(messages)
logger.info(buffer)
optim_result = json_repair.repair_json(buffer, return_objects=True)
# prompt = optim_result.get("prompt")
desc = optim_result.get("desc")
self.create_message(
tenant_id=tenant_id,
session_id=session_id,
user_id=user_id,
role=RoleType.ASSISTANT,
content=desc
)
yield {"type": "done", "desc": optim_result.get("desc")}
@staticmethod
def parser_prompt_variables(prompt: str):
try:

View File

@@ -410,7 +410,8 @@ class WorkflowService:
self,
app_id: uuid.UUID,
payload: DraftRunRequest,
config: WorkflowConfig
config: WorkflowConfig,
workspace_id: uuid.UUID,
):
"""运行工作流
@@ -484,7 +485,7 @@ class WorkflowService:
workflow_config=workflow_config_dict,
input_data=input_data,
execution_id=execution.execution_id,
workspace_id="",
workspace_id=str(workspace_id),
user_id=payload.user_id
)
@@ -530,7 +531,8 @@ class WorkflowService:
self,
app_id: uuid.UUID,
payload: DraftRunRequest,
config: WorkflowConfig
config: WorkflowConfig,
workspace_id: uuid.UUID,
):
"""运行工作流(流式)
@@ -603,7 +605,7 @@ class WorkflowService:
workflow_config=workflow_config_dict,
input_data=input_data,
execution_id=execution.execution_id,
workspace_id="",
workspace_id=str(workspace_id),
user_id=payload.user_id
):
# 直接转发 executor 的事件(已经是正确的格式)

View File

@@ -25,7 +25,7 @@ Rules
Basic Principles
Priority Rule: When historical requirements conflict with current requirements, unconditionally prioritize current requirements.
Completeness Rule: If the original prompt is empty, generate a complete prompt based on the current requirements.
Structure Rule: Use a clear block structure including [Role], [Task], [Requirements], [Input], [Output], [Constraints] labels.
Structure Rule: Use a clear block structure; the blocks contain, respectively, the roles, tasks, requirements, inputs, outputs, and constraints.
Language Rule: All label languages must fully match the user input language.
Behavior Guidelines