Files
MemoryBear/api/tests/workflow/nodes/test_start_node.py

736 lines
21 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/6
import pytest
from app.core.workflow.nodes import StartNode
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable_pool import VariablePool
from tests.workflow.nodes.base import (
simple_state,
simple_vairable_pool,
TEST_EXECUTION_ID,
TEST_WORKSPACE_ID,
TEST_USER_ID,
TEST_CONVERSATION_ID,
TEST_FILE
)
async def create_variable_pool_with_inputs(message: str, input_variables: dict = None):
    """Build a VariablePool seeded with system variables and custom inputs.

    Every entry is registered in the ``sys`` namespace as immutable.

    Args:
        message: Value stored as ``sys.message``.
        input_variables: Mapping stored as ``sys.input_variables``; ``None``
            (or any falsy value) is replaced by an empty dict.

    Returns:
        The populated VariablePool.
    """
    pool = VariablePool()
    system_entries = [
        ("message", message, VariableType.STRING),
        ("conversation_id", TEST_CONVERSATION_ID, VariableType.STRING),
        ("execution_id", TEST_EXECUTION_ID, VariableType.STRING),
        ("workspace_id", TEST_WORKSPACE_ID, VariableType.STRING),
        ("user_id", TEST_USER_ID, VariableType.STRING),
        ("input_variables", input_variables or {}, VariableType.OBJECT),
        ("files", [TEST_FILE], VariableType.ARRAY_FILE),
    ]
    for entry_key, entry_value, entry_type in system_entries:
        await pool.new(
            namespace='sys',
            key=entry_key,
            value=entry_value,
            var_type=VariableType(entry_type),
            mut=False,  # system variables are registered as immutable
        )
    return pool
def _start_node_config(variables):
    """Return a Start-node config dict carrying the shared test metadata.

    Every module-level config below uses the same node id, type and display
    name; only the list of custom-variable definitions differs.

    Args:
        variables: List of custom-variable definition dicts, stored under
            ``config["variables"]``.

    Returns:
        A new dict: ``{"id", "type", "name", "config": {"variables": ...}}``.
    """
    return {
        "id": "start_test",
        "type": "start",
        "name": "开始节点",
        "config": {
            "variables": variables,
        },
    }


# Basic config - no custom variables.
BASIC_CONFIG = _start_node_config([])

# Config with a single optional custom variable.
SINGLE_VARIABLE_CONFIG = _start_node_config([
    {
        "name": "language",
        "type": "string",
        "required": False,
        "default": "zh-CN",
        "description": "语言设置",
    },
])

# Config with several optional custom variables of different scalar types.
MULTI_VARIABLES_CONFIG = _start_node_config([
    {
        "name": "language",
        "type": "string",
        "required": False,
        "default": "zh-CN",
        "description": "语言设置",
    },
    {
        "name": "max_length",
        "type": "number",
        "required": False,
        "default": 1000,
        "description": "最大长度",
    },
    {
        "name": "enable_cache",
        "type": "boolean",
        "required": False,
        "default": True,
        "description": "是否启用缓存",
    },
])

# Config with a single required variable (no default).
REQUIRED_VARIABLE_CONFIG = _start_node_config([
    {
        "name": "api_key",
        "type": "string",
        "required": True,
        "description": "API密钥",
    },
])

# Config mixing a required variable and an optional one.
# NOTE: "user_id" has the same name as the system variable sys.user_id.
MIXED_VARIABLES_CONFIG = _start_node_config([
    {
        "name": "user_id",
        "type": "string",
        "required": True,
        "description": "用户ID",
    },
    {
        "name": "timeout",
        "type": "number",
        "required": False,
        "default": 30,
        "description": "超时时间(秒)",
    },
])

# Config covering string, number, boolean, array and object variable types.
DIFFERENT_TYPES_CONFIG = _start_node_config([
    {
        "name": "name",
        "type": "string",
        "required": False,
        "default": "default_name",
        "description": "名称",
    },
    {
        "name": "count",
        "type": "number",
        "required": False,
        "default": 0,
        "description": "计数",
    },
    {
        "name": "enabled",
        "type": "boolean",
        "required": False,
        "default": False,
        "description": "是否启用",
    },
    {
        "name": "tags",
        "type": "array[string]",
        "required": False,
        "default": [],
        "description": "标签列表",
    },
    {
        "name": "config",
        "type": "object",
        "required": False,
        "default": {},
        "description": "配置对象",
    },
])
# ==================== Basic functionality tests ====================
@pytest.mark.asyncio
async def test_start_node_basic():
    """A Start node with no custom variables outputs the system variables."""
    state = simple_state()
    pool = await simple_vairable_pool("test message")

    output = await StartNode(BASIC_CONFIG, {}).execute(state, pool)

    assert isinstance(output, dict)
    for expected_key in ("message", "execution_id", "conversation_id", "workspace_id", "user_id"):
        assert expected_key in output
    assert output["message"] == "test message"
@pytest.mark.asyncio
async def test_start_node_system_variables():
    """System outputs echo the message and match the ids held in the state."""
    state = simple_state()
    pool = await simple_vairable_pool("hello world")

    output = await StartNode(BASIC_CONFIG, {}).execute(state, pool)

    assert output["message"] == "hello world"
    for id_key in ("execution_id", "workspace_id", "user_id"):
        assert output[id_key] == state[id_key]
# ==================== Custom variable tests ====================
@pytest.mark.asyncio
async def test_start_node_single_variable_with_default():
    """An optional variable with no supplied input falls back to its default."""
    state = simple_state()
    pool = await simple_vairable_pool("test")
    node = StartNode(SINGLE_VARIABLE_CONFIG, {})
    output = await node.execute(state, pool)

    assert "language" in output
    assert output["language"] == "zh-CN"
@pytest.mark.asyncio
async def test_start_node_single_variable_with_input():
    """A supplied input value takes precedence over the variable's default."""
    state = simple_state()
    # The override travels in sys.input_variables.
    overrides = {"language": "en-US"}
    pool = await create_variable_pool_with_inputs("test", overrides)
    node = StartNode(SINGLE_VARIABLE_CONFIG, {})
    output = await node.execute(state, pool)

    assert "language" in output
    assert output["language"] == "en-US"
@pytest.mark.asyncio
async def test_start_node_multi_variables_with_defaults():
    """With no inputs, every optional variable takes its configured default."""
    state = simple_state()
    pool = await simple_vairable_pool("test")
    output = await StartNode(MULTI_VARIABLES_CONFIG, {}).execute(state, pool)

    for var_name in ("language", "max_length", "enable_cache"):
        assert var_name in output
    assert output["language"] == "zh-CN"
    assert output["max_length"] == 1000
    assert output["enable_cache"] is True
@pytest.mark.asyncio
async def test_start_node_multi_variables_with_inputs():
    """Supplied inputs override the defaults of every configured variable."""
    state = simple_state()
    overrides = {"language": "ja-JP", "max_length": 2000, "enable_cache": False}
    pool = await create_variable_pool_with_inputs("test", overrides)
    output = await StartNode(MULTI_VARIABLES_CONFIG, {}).execute(state, pool)

    assert output["language"] == "ja-JP"
    assert output["max_length"] == 2000
    assert output["enable_cache"] is False
@pytest.mark.asyncio
async def test_start_node_partial_inputs():
    """Only some variables are supplied; the others keep their defaults."""
    state = simple_state()
    pool = await create_variable_pool_with_inputs("test", {"language": "fr-FR"})
    output = await StartNode(MULTI_VARIABLES_CONFIG, {}).execute(state, pool)

    assert output["language"] == "fr-FR"  # from the input
    assert output["max_length"] == 1000  # default
    assert output["enable_cache"] is True  # default
# ==================== Required variable tests ====================
@pytest.mark.asyncio
async def test_start_node_required_variable_provided():
    """A required variable that is supplied is passed through to the output."""
    state = simple_state()
    secret = "test_api_key_12345"
    pool = await create_variable_pool_with_inputs("test", {"api_key": secret})
    output = await StartNode(REQUIRED_VARIABLE_CONFIG, {}).execute(state, pool)

    assert "api_key" in output
    assert output["api_key"] == secret
@pytest.mark.asyncio
async def test_start_node_required_variable_missing():
    """Omitting a required variable raises ValueError naming that variable."""
    state = simple_state()
    pool = await simple_vairable_pool("test")

    # The pool carries no input variables, so "api_key" is missing.
    with pytest.raises(ValueError) as excinfo:
        await StartNode(REQUIRED_VARIABLE_CONFIG, {}).execute(state, pool)

    error_text = str(excinfo.value)
    assert "缺少必需的输入变量" in error_text
    assert "api_key" in error_text
@pytest.mark.asyncio
async def test_start_node_mixed_variables():
    """Supplying only the required variable leaves the optional one at its default."""
    state = simple_state()
    pool = await create_variable_pool_with_inputs("test", {"user_id": "user_123"})
    output = await StartNode(MIXED_VARIABLES_CONFIG, {}).execute(state, pool)

    # NOTE: "user_id" is also a system variable; the custom input value is
    # what this test expects to see in the output.
    assert output["user_id"] == "user_123"  # required variable
    assert output["timeout"] == 30  # optional variable, default value
@pytest.mark.asyncio
async def test_start_node_mixed_variables_all_provided():
    """When every variable is supplied, no defaults are used."""
    state = simple_state()
    supplied = {"user_id": "user_456", "timeout": 60}
    pool = await create_variable_pool_with_inputs("test", supplied)
    output = await StartNode(MIXED_VARIABLES_CONFIG, {}).execute(state, pool)

    assert output["user_id"] == "user_456"
    assert output["timeout"] == 60
# ==================== Variable type tests ====================
@pytest.mark.asyncio
async def test_start_node_different_types_defaults():
    """Defaults for string/number/boolean/array/object variables appear in the output."""
    state = simple_state()
    pool = await simple_vairable_pool("test")
    node = StartNode(DIFFERENT_TYPES_CONFIG, {})
    output = await node.execute(state, pool)

    assert output["name"] == "default_name"
    assert output["count"] == 0
    assert output["enabled"] is False
    assert output["tags"] == []
    assert output["config"] == {}
@pytest.mark.asyncio
async def test_start_node_different_types_inputs():
    """Inputs of each supported type are returned equal to what was supplied."""
    state = simple_state()
    supplied = {
        "name": "custom_name",
        "count": 100,
        "enabled": True,
        "tags": ["tag1", "tag2", "tag3"],
        "config": {"key": "value", "nested": {"data": 123}},
    }
    pool = await create_variable_pool_with_inputs("test", supplied)
    node = StartNode(DIFFERENT_TYPES_CONFIG, {})
    output = await node.execute(state, pool)

    assert output["name"] == "custom_name"
    assert output["count"] == 100
    assert output["enabled"] is True
    assert output["tags"] == ["tag1", "tag2", "tag3"]
    assert output["config"] == {"key": "value", "nested": {"data": 123}}
# ==================== Edge-case tests ====================
@pytest.mark.asyncio
async def test_start_node_empty_message():
    """An empty user message comes back as an empty string."""
    state = simple_state()
    pool = await simple_vairable_pool("")
    output = await StartNode(BASIC_CONFIG, {}).execute(state, pool)
    assert output["message"] == ""
@pytest.mark.asyncio
async def test_start_node_no_input_variables():
    """With a pool built without explicit input variables, defaults are used."""
    state = simple_state()
    pool = await simple_vairable_pool("test")
    # No input_variables are set explicitly here.
    node = StartNode(SINGLE_VARIABLE_CONFIG, {})
    output = await node.execute(state, pool)
    # The default value is expected.
    assert output["language"] == "zh-CN"
@pytest.mark.asyncio
async def test_start_node_empty_input_variables():
    """An empty input-variable dict behaves like no inputs: defaults are used."""
    state = simple_state()
    pool = await create_variable_pool_with_inputs("test", {})
    node = StartNode(SINGLE_VARIABLE_CONFIG, {})
    output = await node.execute(state, pool)
    assert output["language"] == "zh-CN"
@pytest.mark.asyncio
async def test_start_node_extra_input_variables():
    """Inputs not declared in the config are left out of the output."""
    state = simple_state()
    supplied = {
        "language": "de-DE",
        "extra_var": "should_be_ignored",  # not declared in the config
    }
    pool = await create_variable_pool_with_inputs("test", supplied)
    output = await StartNode(SINGLE_VARIABLE_CONFIG, {}).execute(state, pool)

    assert output["language"] == "de-DE"
    assert "extra_var" not in output
# ==================== Array variable tests ====================
@pytest.mark.asyncio
async def test_start_node_array_string_variable():
    """An array[string] input is returned unchanged."""
    state = simple_state()
    categories_var = {
        "name": "categories",
        "type": "array[string]",
        "required": False,
        "default": ["default1", "default2"],
        "description": "分类列表",
    }
    node_config = {
        "id": "start_test",
        "type": "start",
        "name": "开始节点",
        "config": {"variables": [categories_var]},
    }
    supplied = {"categories": ["cat1", "cat2", "cat3"]}
    pool = await create_variable_pool_with_inputs("test", supplied)
    output = await StartNode(node_config, {}).execute(state, pool)

    assert output["categories"] == ["cat1", "cat2", "cat3"]
@pytest.mark.asyncio
async def test_start_node_array_number_variable():
    """An array[number] input is returned unchanged."""
    state = simple_state()
    scores_var = {
        "name": "scores",
        "type": "array[number]",
        "required": False,
        "default": [0, 0, 0],
        "description": "分数列表",
    }
    node_config = {
        "id": "start_test",
        "type": "start",
        "name": "开始节点",
        "config": {"variables": [scores_var]},
    }
    supplied = {"scores": [85, 90, 95]}
    pool = await create_variable_pool_with_inputs("test", supplied)
    output = await StartNode(node_config, {}).execute(state, pool)

    assert output["scores"] == [85, 90, 95]
@pytest.mark.asyncio
async def test_start_node_array_object_variable():
    """An array[object] input keeps its elements and their fields."""
    state = simple_state()
    users_var = {
        "name": "users",
        "type": "array[object]",
        "required": False,
        "default": [],
        "description": "用户列表",
    }
    node_config = {
        "id": "start_test",
        "type": "start",
        "name": "开始节点",
        "config": {"variables": [users_var]},
    }
    supplied = {
        "users": [
            {"name": "Alice", "age": 25},
            {"name": "Bob", "age": 30},
        ]
    }
    pool = await create_variable_pool_with_inputs("test", supplied)
    output = await StartNode(node_config, {}).execute(state, pool)

    users = output["users"]
    assert len(users) == 2
    assert users[0]["name"] == "Alice"
    assert users[1]["age"] == 30
# ==================== Complex scenario tests ====================
@pytest.mark.asyncio
async def test_start_node_complex_object():
    """A nested object input is returned with its nested values intact."""
    state = simple_state()
    settings_var = {
        "name": "settings",
        "type": "object",
        "required": False,
        "default": {"theme": "light"},
        "description": "设置对象",
    }
    node_config = {
        "id": "start_test",
        "type": "start",
        "name": "开始节点",
        "config": {"variables": [settings_var]},
    }
    supplied = {
        "settings": {
            "theme": "dark",
            "language": "zh-CN",
            "notifications": {"email": True, "sms": False},
            "features": ["feature1", "feature2"],
        }
    }
    pool = await create_variable_pool_with_inputs("test", supplied)
    output = await StartNode(node_config, {}).execute(state, pool)

    settings = output["settings"]
    assert settings["theme"] == "dark"
    assert settings["language"] == "zh-CN"
    assert settings["notifications"]["email"] is True
    assert settings["features"] == ["feature1", "feature2"]
@pytest.mark.asyncio
async def test_start_node_zero_and_false_values():
    """Falsy inputs (0, False) are kept instead of being replaced by defaults."""
    state = simple_state()
    count_var = {
        "name": "count",
        "type": "number",
        "required": False,
        "default": 10,
        "description": "计数",
    }
    enabled_var = {
        "name": "enabled",
        "type": "boolean",
        "required": False,
        "default": True,
        "description": "是否启用",
    }
    node_config = {
        "id": "start_test",
        "type": "start",
        "name": "开始节点",
        "config": {"variables": [count_var, enabled_var]},
    }
    supplied = {"count": 0, "enabled": False}
    pool = await create_variable_pool_with_inputs("test", supplied)
    output = await StartNode(node_config, {}).execute(state, pool)

    # The defaults (10 / True) differ, so these prove the inputs were honoured.
    assert output["count"] == 0
    assert output["enabled"] is False
@pytest.mark.asyncio
async def test_start_node_output_types():
    """_output_types() reports the expected type for system and custom outputs."""
    state = simple_state()
    pool = await simple_vairable_pool("test")
    node = StartNode(MULTI_VARIABLES_CONFIG, {})
    await node.execute(state, pool)
    reported = node._output_types()

    expected_types = {
        # System variables
        "message": VariableType.STRING,
        "execution_id": VariableType.STRING,
        "conversation_id": VariableType.STRING,
        "workspace_id": VariableType.STRING,
        "user_id": VariableType.STRING,
        # Custom variables from MULTI_VARIABLES_CONFIG
        "language": VariableType.STRING,
        "max_length": VariableType.NUMBER,
        "enable_cache": VariableType.BOOLEAN,
    }
    for output_name, expected_type in expected_types.items():
        assert reported[output_name] == expected_type
@pytest.mark.asyncio
async def test_start_node_multiple_executions():
    """One StartNode instance can be executed repeatedly against fresh pools."""
    state = simple_state()
    node = StartNode(SINGLE_VARIABLE_CONFIG, {})

    # Each run uses its own, newly built variable pool.
    for msg in ("first message", "second message"):
        pool = await create_variable_pool_with_inputs(msg, {})
        output = await node.execute(state, pool)
        assert output["message"] == msg
        assert output["language"] == "zh-CN"
@pytest.mark.asyncio
async def test_start_node_with_description():
    """The missing-required-variable error mentions the variable's description."""
    state = simple_state()
    pool = await simple_vairable_pool("test")
    endpoint_var = {
        "name": "api_endpoint",
        "type": "string",
        "required": True,
        "description": "API 端点 URL用于连接外部服务",
    }
    node_config = {
        "id": "start_test",
        "type": "start",
        "name": "开始节点",
        "config": {"variables": [endpoint_var]},
    }

    # No inputs are supplied, so the required variable is missing.
    with pytest.raises(ValueError) as excinfo:
        await StartNode(node_config, {}).execute(state, pool)

    error_text = str(excinfo.value)
    assert "api_endpoint" in error_text
    assert "API 端点 URL" in error_text