From 87d53fb9b7d1986ae23837d1d679068d9f0da8eb Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Fri, 6 Feb 2026 15:17:58 +0800 Subject: [PATCH] perf(workflow): add tests, adapt some LLM node output formats, optimize sandbox return format --- api/app/core/workflow/nodes/__init__.py | 6 +- api/app/core/workflow/nodes/code/__init__.py | 3 +- api/app/core/workflow/nodes/llm/node.py | 2 +- .../nodes/parameter_extractor/node.py | 3 +- .../nodes/question_classifier/node.py | 2 +- api/tests/workflow/__init__.py | 4 + api/tests/workflow/executor/__init__.py | 4 + .../workflow/executor/test_vairable_pool.py | 622 +++++++++ api/tests/workflow/nodes/__init__.py | 4 + api/tests/workflow/nodes/base.py | 77 ++ .../workflow/nodes/test_assigner_node.py | 834 ++++++++++++ api/tests/workflow/nodes/test_breaker_node.py | 23 + api/tests/workflow/nodes/test_code.py | 279 ++++ api/tests/workflow/nodes/test_end_node.py | 42 + api/tests/workflow/nodes/test_ifelse_node.py | 1127 +++++++++++++++++ .../workflow/nodes/test_jinja_render_node.py | 889 +++++++++++++ api/tests/workflow/nodes/test_llm_node.py | 145 +++ .../nodes/test_parameter_extractor_node.py | 504 ++++++++ .../nodes/test_question_classifier_node.py | 647 ++++++++++ api/tests/workflow/nodes/test_start_node.py | 735 +++++++++++ .../nodes/test_variable_aggregator_node.py | 621 +++++++++ sandbox/app/controllers/sandbox_controller.py | 6 +- sandbox/app/models.py | 2 - sandbox/app/services/nodejs_service.py | 8 +- sandbox/app/services/python_service.py | 2 +- 25 files changed, 6576 insertions(+), 15 deletions(-) create mode 100644 api/tests/workflow/__init__.py create mode 100644 api/tests/workflow/executor/__init__.py create mode 100644 api/tests/workflow/executor/test_vairable_pool.py create mode 100644 api/tests/workflow/nodes/__init__.py create mode 100644 api/tests/workflow/nodes/base.py create mode 100644 api/tests/workflow/nodes/test_assigner_node.py create mode 100644 api/tests/workflow/nodes/test_breaker_node.py create mode 100644 api/tests/workflow/nodes/test_code.py create mode 100644 api/tests/workflow/nodes/test_end_node.py create mode 100644 api/tests/workflow/nodes/test_ifelse_node.py create mode 100644 api/tests/workflow/nodes/test_jinja_render_node.py create mode 100644 api/tests/workflow/nodes/test_llm_node.py create mode 100644 api/tests/workflow/nodes/test_parameter_extractor_node.py create mode 100644 api/tests/workflow/nodes/test_question_classifier_node.py create mode 100644 api/tests/workflow/nodes/test_start_node.py create mode 100644 api/tests/workflow/nodes/test_variable_aggregator_node.py diff --git a/api/app/core/workflow/nodes/__init__.py b/api/app/core/workflow/nodes/__init__.py index 1f2eb15b..885dfbc9 100644 --- a/api/app/core/workflow/nodes/__init__.py +++ b/api/app/core/workflow/nodes/__init__.py @@ -18,6 +18,8 @@ from app.core.workflow.nodes.start import StartNode from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode from app.core.workflow.nodes.question_classifier import QuestionClassifierNode from app.core.workflow.nodes.tool import ToolNode +from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode +from app.core.workflow.nodes.code import CodeNode __all__ = [ "BaseNode", @@ -35,5 +37,7 @@ __all__ = [ "JinjaRenderNode", "ParameterExtractorNode", "QuestionClassifierNode", - "ToolNode" + "ToolNode", + "CodeNode", + "VariableAggregatorNode" ] diff --git a/api/app/core/workflow/nodes/code/__init__.py b/api/app/core/workflow/nodes/code/__init__.py index 758ab3a5..1235db4f 100644 --- a/api/app/core/workflow/nodes/code/__init__.py +++ b/api/app/core/workflow/nodes/code/__init__.py @@ -1,3 +1,4 @@ +from app.core.workflow.nodes.code.config import CodeNodeConfig from app.core.workflow.nodes.code.node import CodeNode -__all__ = ["CodeNode"] +__all__ = ["CodeNode", "CodeNodeConfig"] diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index 4393e1ed..14bcb8ed 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -216,7 +216,7 @@ class LLMNode(BaseNode): logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}") # 返回 AIMessage(包含响应元数据) - return response if isinstance(response, AIMessage) else AIMessage(content=content) + return AIMessage(content=content, response_metadata=response.response_metadata) def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: """提取输入数据(用于记录)""" diff --git a/api/app/core/workflow/nodes/parameter_extractor/node.py b/api/app/core/workflow/nodes/parameter_extractor/node.py index 31acaafc..7dec03f1 100644 --- a/api/app/core/workflow/nodes/parameter_extractor/node.py +++ b/api/app/core/workflow/nodes/parameter_extractor/node.py @@ -193,7 +193,8 @@ class ParameterExtractorNode(BaseNode): model_resp = await llm.ainvoke(messages) self.response_metadata = model_resp.response_metadata - result = json_repair.repair_json(model_resp.content, return_objects=True) + model_message = self.process_model_output(model_resp.content) + result = json_repair.repair_json(model_message, return_objects=True) logger.info(f"node: {self.node_id} get params:{result}") return result diff --git a/api/app/core/workflow/nodes/question_classifier/node.py b/api/app/core/workflow/nodes/question_classifier/node.py index 38662b64..7f3d4edb 100644 --- a/api/app/core/workflow/nodes/question_classifier/node.py +++ b/api/app/core/workflow/nodes/question_classifier/node.py @@ -131,7 +131,7 @@ class QuestionClassifierNode(BaseNode): ] response = await llm.ainvoke(messages) - result = response.content.strip() + result = self.process_model_output(response.content) self.response_metadata = response.response_metadata if result in category_names: diff --git a/api/tests/workflow/__init__.py b/api/tests/workflow/__init__.py new file mode 100644 index 00000000..29ad3fc2 --- /dev/null +++ b/api/tests/workflow/__init__.py @@ -0,0 +1,4 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/5 15:36 diff --git a/api/tests/workflow/executor/__init__.py b/api/tests/workflow/executor/__init__.py new file mode 100644 index 00000000..fa60d940 --- /dev/null +++ b/api/tests/workflow/executor/__init__.py @@ -0,0 +1,4 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/6 14:45 diff --git a/api/tests/workflow/executor/test_vairable_pool.py b/api/tests/workflow/executor/test_vairable_pool.py new file mode 100644 index 00000000..6fb91bec --- /dev/null +++ b/api/tests/workflow/executor/test_vairable_pool.py @@ -0,0 +1,622 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/6 +import pytest + +from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.variable_pool import VariablePool, VariableSelector + + +# ==================== VariableSelector 测试 ==================== +def test_variable_selector_from_string(): + """测试从字符串创建变量选择器""" + selector = VariableSelector.from_string("sys.message") + + assert selector.namespace == "sys" + assert selector.key == "message" + assert selector.path == ["sys", "message"] + + +def test_variable_selector_from_list(): + """测试从列表创建变量选择器""" + selector = VariableSelector(["conv", "username"]) + + assert selector.namespace == "conv" + assert selector.key == "username" + assert str(selector) == "conv.username" + + +def test_variable_selector_empty_path(): + """测试空路径抛出异常""" + with pytest.raises(ValueError) as exc_info: + VariableSelector([]) + + assert "变量路径不能为空" in str(exc_info.value) + + +def test_variable_selector_single_element(): + """测试单元素路径""" + selector = VariableSelector(["sys"]) + + assert selector.namespace == "sys" + assert selector.key is None + + +# ==================== VariablePool 基础测试 ==================== +@pytest.mark.asyncio +async def test_variable_pool_new_variable(): + """测试创建新变量""" + pool = VariablePool() + + await pool.new("conv", "username", "Alice", VariableType.STRING, mut=True) + + assert pool.has("conv.username") + assert pool.get_value("conv.username") == "Alice" + + +@pytest.mark.asyncio +async def test_variable_pool_new_multiple_variables(): + """测试创建多个变量""" + pool = VariablePool() + + await pool.new("conv", "name", "Bob", VariableType.STRING, mut=True) + await pool.new("conv", "age", 25, VariableType.NUMBER, mut=True) + await pool.new("conv", "active", True, VariableType.BOOLEAN, mut=True) + + assert pool.get_value("conv.name") == "Bob" + assert pool.get_value("conv.age") == 25 + assert pool.get_value("conv.active") is True + + +@pytest.mark.asyncio +async def test_variable_pool_different_namespaces(): + """测试不同命名空间的变量""" + pool = VariablePool() + + await pool.new("sys", "message", "Hello", VariableType.STRING, mut=False) + await pool.new("conv", "message", "World", VariableType.STRING, mut=True) + await pool.new("node1", "output", "Result", VariableType.STRING, mut=False) + + assert pool.get_value("sys.message") == "Hello" + assert pool.get_value("conv.message") == "World" + assert pool.get_value("node1.output") == "Result" + + +# ==================== get_value 测试 ==================== +@pytest.mark.asyncio +async def test_get_value_with_template(): + """测试使用模板语法获取值""" + pool = VariablePool() + + await pool.new("conv", "test", "value", VariableType.STRING, mut=True) + + # 支持模板语法 + assert pool.get_value("{{ conv.test }}") == "value" + assert pool.get_value("{{conv.test}}") == "value" + assert pool.get_value("{{ conv.test}}") == "value" + + +@pytest.mark.asyncio +async def test_get_value_not_exist_strict(): + """测试获取不存在的变量(严格模式)""" + pool = VariablePool() + + with pytest.raises(KeyError) as exc_info: + pool.get_value("conv.nonexistent") + + assert "not exist" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_get_value_not_exist_with_default(): + """测试获取不存在的变量(使用默认值)""" + pool = VariablePool() + + result = pool.get_value("conv.nonexistent", default="default_value", strict=False) + + assert result == "default_value" + + +@pytest.mark.asyncio +async def test_get_value_different_types(): + """测试获取不同类型的变量值""" + pool = VariablePool() + + await pool.new("conv", "str", "text", VariableType.STRING, mut=True) + await pool.new("conv", "num", 42, VariableType.NUMBER, mut=True) + await pool.new("conv", "bool", False, VariableType.BOOLEAN, mut=True) + await pool.new("conv", "arr", [1, 2, 3], VariableType.ARRAY_NUMBER, mut=True) + await pool.new("conv", "obj", {"key": "value"}, VariableType.OBJECT, mut=True) + + assert pool.get_value("conv.str") == "text" + assert pool.get_value("conv.num") == 42 + assert pool.get_value("conv.bool") is False + assert pool.get_value("conv.arr") == [1, 2, 3] + assert pool.get_value("conv.obj") == {"key": "value"} + + +# ==================== set 测试 ==================== +@pytest.mark.asyncio +async def test_set_mutable_variable(): + """测试设置可变变量""" + pool = VariablePool() + + await pool.new("conv", "counter", 0, VariableType.NUMBER, mut=True) + await pool.set("conv.counter", 10) + + assert pool.get_value("conv.counter") == 10 + + +@pytest.mark.asyncio +async def test_set_immutable_variable(): + """测试设置不可变变量(应该失败)""" + pool = VariablePool() + + await pool.new("sys", "message", "original", VariableType.STRING, mut=False) + + with pytest.raises(KeyError) as exc_info: + await pool.set("sys.message", "modified") + + assert "cannot be modified" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_set_nonexistent_variable(): + """测试设置不存在的变量""" + pool = VariablePool() + + with pytest.raises(KeyError) as exc_info: + await pool.set("conv.nonexistent", "value") + + assert "is not defined" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_set_multiple_times(): + """测试多次设置变量""" + pool = VariablePool() + + await pool.new("conv", "value", "first", VariableType.STRING, mut=True) + await pool.set("conv.value", "second") + await pool.set("conv.value", "third") + + assert pool.get_value("conv.value") == "third" + + +# ==================== has 测试 ==================== +@pytest.mark.asyncio +async def test_has_existing_variable(): + """测试检查存在的变量""" + pool = VariablePool() + + await pool.new("conv", "test", "value", VariableType.STRING, mut=True) + + assert pool.has("conv.test") is True + + +@pytest.mark.asyncio +async def test_has_nonexistent_variable(): + """测试检查不存在的变量""" + pool = VariablePool() + + assert pool.has("conv.nonexistent") is False + + +# ==================== get_literal 测试 ==================== +@pytest.mark.asyncio +async def test_get_literal(): + """测试获取变量的字面量表示""" + pool = VariablePool() + + await pool.new("conv", "num", 42, VariableType.NUMBER, mut=True) + + literal = pool.get_literal("conv.num") + + assert isinstance(literal, str) + + +# ==================== 命名空间操作测试 ==================== +@pytest.mark.asyncio +async def test_get_all_system_vars(): + """测试获取所有系统变量""" + pool = VariablePool() + + await pool.new("sys", "message", "Hello", VariableType.STRING, mut=False) + await pool.new("sys", "user_id", "user123", VariableType.STRING, mut=False) + await pool.new("conv", "other", "value", VariableType.STRING, mut=True) + + sys_vars = pool.get_all_system_vars() + + assert "message" in sys_vars + assert "user_id" in sys_vars + assert "other" not in sys_vars + assert sys_vars["message"] == "Hello" + assert sys_vars["user_id"] == "user123" + + +@pytest.mark.asyncio +async def test_get_all_conversation_vars(): + """测试获取所有会话变量""" + pool = VariablePool() + + await pool.new("conv", "username", "Alice", VariableType.STRING, mut=True) + await pool.new("conv", "score", 100, VariableType.NUMBER, mut=True) + await pool.new("sys", "message", "Hello", VariableType.STRING, mut=False) + + conv_vars = pool.get_all_conversation_vars() + + assert "username" in conv_vars + assert "score" in conv_vars + assert "message" not in conv_vars + assert conv_vars["username"] == "Alice" + assert conv_vars["score"] == 100 + + +@pytest.mark.asyncio +async def test_get_all_node_outputs(): + """测试获取所有节点输出""" + pool = VariablePool() + + await pool.new("node1", "output", "result1", VariableType.STRING, mut=False) + await pool.new("node2", "output", "result2", VariableType.STRING, mut=False) + await pool.new("sys", "message", "Hello", VariableType.STRING, mut=False) + await pool.new("conv", "var", "value", VariableType.STRING, mut=True) + + node_outputs = pool.get_all_node_outputs() + + assert "node1" in node_outputs + assert "node2" in node_outputs + assert "sys" not in node_outputs + assert "conv" not in node_outputs + assert node_outputs["node1"]["output"] == "result1" + assert node_outputs["node2"]["output"] == "result2" + + +@pytest.mark.asyncio +async def test_get_node_output(): + """测试获取指定节点的输出""" + pool = VariablePool() + + await pool.new("node1", "output", "result", VariableType.STRING, mut=False) + await pool.new("node1", "status", "success", VariableType.STRING, mut=False) + + node_output = pool.get_node_output("node1") + + assert node_output["output"] == "result" + assert node_output["status"] == "success" + + +@pytest.mark.asyncio +async def test_get_node_output_not_exist_strict(): + """测试获取不存在的节点输出(严格模式)""" + pool = VariablePool() + + with pytest.raises(KeyError) as exc_info: + pool.get_node_output("nonexistent_node") + + assert "output not exist" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_get_node_output_not_exist_with_default(): + """测试获取不存在的节点输出(使用默认值)""" + pool = VariablePool() + + result = pool.get_node_output("nonexistent_node", defalut=None, strict=False) + + assert result is None + + +# ==================== 复杂场景测试 ==================== +@pytest.mark.asyncio +async def test_variable_pool_new_existing_mutable(): + """测试创建已存在的可变变量(应该更新值)""" + pool = VariablePool() + + await pool.new("conv", "counter", 0, VariableType.NUMBER, mut=True) + await pool.new("conv", "counter", 10, VariableType.NUMBER, mut=True) + + assert pool.get_value("conv.counter") == 10 + + +@pytest.mark.asyncio +async def test_variable_pool_new_existing_immutable(): + """测试创建已存在的不可变变量(应该为新值)""" + pool = VariablePool() + + await pool.new("sys", "message", "original", VariableType.STRING, mut=False) + await pool.new("sys", "message", "modified", VariableType.STRING, mut=False) + + # 不可变变量被更新 + assert pool.get_value("sys.message") == "modified" + + +@pytest.mark.asyncio +async def test_variable_pool_zero_and_false_values(): + """测试零值和 False 值""" + pool = VariablePool() + + await pool.new("conv", "zero", 0, VariableType.NUMBER, mut=True) + await pool.new("conv", "false", False, VariableType.BOOLEAN, mut=True) + await pool.new("conv", "empty_str", "", VariableType.STRING, mut=True) + await pool.new("conv", "empty_arr", [], VariableType.ARRAY_NUMBER, mut=True) + await pool.new("conv", "empty_obj", {}, VariableType.OBJECT, mut=True) + + assert pool.get_value("conv.zero") == 0 + assert pool.get_value("conv.false") is False + assert pool.get_value("conv.empty_str") == "" + assert pool.get_value("conv.empty_arr") == [] + assert pool.get_value("conv.empty_obj") == {} + + +@pytest.mark.asyncio +async def test_variable_pool_nested_objects(): + """测试嵌套对象""" + pool = VariablePool() + + nested_obj = { + "user": { + "name": "Alice", + "age": 25, + "address": { + "city": "Beijing" + } + }, + "items": [1, 2, 3] + } + + await pool.new("conv", "data", nested_obj, VariableType.OBJECT, mut=True) + + result = pool.get_value("conv.data") + assert result["user"]["name"] == "Alice" + assert result["user"]["address"]["city"] == "Beijing" + assert result["items"] == [1, 2, 3] + + +@pytest.mark.asyncio +async def test_variable_pool_array_of_objects(): + """测试对象数组""" + pool = VariablePool() + + users = [ + {"name": "Alice", "age": 25}, + {"name": "Bob", "age": 30} + ] + + await pool.new("conv", "users", users, VariableType.ARRAY_OBJECT, mut=True) + + result = pool.get_value("conv.users") + assert len(result) == 2 + assert result[0]["name"] == "Alice" + assert result[1]["age"] == 30 + + +@pytest.mark.asyncio +async def test_variable_pool_to_dict(): + """测试导出为字典""" + pool = VariablePool() + + await pool.new("sys", "message", "Hello", VariableType.STRING, mut=False) + await pool.new("conv", "username", "Alice", VariableType.STRING, mut=True) + await pool.new("node1", "output", "result", VariableType.STRING, mut=False) + + result = pool.to_dict() + + assert "system" in result + assert "conversation" in result + assert "nodes" in result + assert result["system"]["message"] == "Hello" + assert result["conversation"]["username"] == "Alice" + assert result["nodes"]["node1"]["output"] == "result" + + +@pytest.mark.asyncio +async def test_variable_pool_copy(): + """测试复制变量池""" + pool1 = VariablePool() + + await pool1.new("conv", "test", "value", VariableType.STRING, mut=True) + + pool2 = VariablePool() + pool2.copy(pool1) + + assert pool2.get_value("conv.test") == "value" + + # 修改 pool2 不应影响 pool1 + await pool2.set("conv.test", "modified") + assert pool2.get_value("conv.test") == "modified" + assert pool1.get_value("conv.test") == "value" + + +@pytest.mark.asyncio +async def test_variable_pool_repr(): + """测试字符串表示""" + pool = VariablePool() + + await pool.new("sys", "message", "Hello", VariableType.STRING, mut=False) + await pool.new("conv", "username", "Alice", VariableType.STRING, mut=True) + await pool.new("node1", "output", "result", VariableType.STRING, mut=False) + + repr_str = repr(pool) + + assert "VariablePool" in repr_str + assert "system_vars=1" in repr_str + assert "conversation_vars=1" in repr_str + assert "runtime_vars=1" in repr_str + + +# ==================== 并发测试 ==================== +@pytest.mark.asyncio +async def test_variable_pool_concurrent_set(): + """测试并发设置变量""" + import asyncio + + pool = VariablePool() + await pool.new("conv", "counter", 0, VariableType.NUMBER, mut=True) + + async def increment(): + for _ in range(100): + current = pool.get_value("conv.counter") + await pool.set("conv.counter", current + 1) + + # 并发执行多个增量操作 + await asyncio.gather(increment(), increment()) + + # 由于有锁保护,最终值应该是 200 + assert pool.get_value("conv.counter") == 200 + + +# ==================== 边界情况测试 ==================== +@pytest.mark.asyncio +async def test_variable_pool_empty(): + """测试空变量池""" + pool = VariablePool() + + assert pool.get_all_system_vars() == {} + assert pool.get_all_conversation_vars() == {} + assert pool.get_all_node_outputs() == {} + + +@pytest.mark.asyncio +async def test_variable_selector_invalid(): + """测试无效的变量选择器""" + pool = VariablePool() + + await pool.new("conv", "test", "value", VariableType.STRING, mut=True) + + # 选择器格式错误 + with pytest.raises(ValueError): + pool.get_value("conv.test.extra") + + +@pytest.mark.asyncio +async def test_variable_pool_special_characters(): + """测试包含特殊字符的变量名""" + pool = VariablePool() + + # 变量名可以包含下划线、数字等 + await pool.new("conv", "user_name_123", "Alice", VariableType.STRING, mut=True) + await pool.new("node_1", "output_data", "result", VariableType.STRING, mut=False) + + assert pool.get_value("conv.user_name_123") == "Alice" + assert pool.get_value("node_1.output_data") == "result" + + +@pytest.mark.asyncio +async def test_variable_pool_large_data(): + """测试大数据量""" + pool = VariablePool() + + # 创建大量变量 + for i in range(100): + await pool.new("conv", f"var_{i}", i, VariableType.NUMBER, mut=True) + + # 验证所有变量都存在 + for i in range(100): + assert pool.get_value(f"conv.var_{i}") == i + + conv_vars = pool.get_all_conversation_vars() + assert len(conv_vars) == 100 + + +@pytest.mark.asyncio +async def test_variable_pool_different_types_same_name(): + """测试不同命名空间中相同名称的变量""" + pool = VariablePool() + + await pool.new("sys", "value", "system", VariableType.STRING, mut=False) + await pool.new("conv", "value", "conversation", VariableType.STRING, mut=True) + await pool.new("node1", "value", "node", VariableType.STRING, mut=False) + + assert pool.get_value("sys.value") == "system" + assert pool.get_value("conv.value") == "conversation" + assert pool.get_value("node1.value") == "node" + + +@pytest.mark.asyncio +async def test_variable_pool_update_type(): + """测试更新变量类型""" + pool = VariablePool() + + # 创建字符串变量 + await pool.new("conv", "data", "text", VariableType.STRING, mut=True) + assert pool.get_value("conv.data") == "text" + + # 更新为数字类型变量类型不可变 + with pytest.raises(TypeError): + await pool.new("conv", "data", 123, VariableType.NUMBER, mut=True) + assert pool.get_value("conv.data") == "text" + + +@pytest.mark.asyncio +async def test_variable_pool_array_types(): + """测试不同类型的数组""" + pool = VariablePool() + + await pool.new("conv", "arr_str", ["a", "b", "c"], VariableType.ARRAY_STRING, mut=True) + await pool.new("conv", "arr_num", [1, 2, 3], VariableType.ARRAY_NUMBER, mut=True) + await pool.new("conv", "arr_bool", [True, False], VariableType.ARRAY_BOOLEAN, mut=True) + await pool.new("conv", "arr_obj", [{"id": 1}, {"id": 2}], VariableType.ARRAY_OBJECT, mut=True) + + assert pool.get_value("conv.arr_str") == ["a", "b", "c"] + assert pool.get_value("conv.arr_num") == [1, 2, 3] + assert pool.get_value("conv.arr_bool") == [True, False] + assert pool.get_value("conv.arr_obj") == [{"id": 1}, {"id": 2}] + + +@pytest.mark.asyncio +async def test_variable_pool_namespace_isolation(): + """测试命名空间隔离""" + pool = VariablePool() + + # 在不同命名空间创建变量 + await pool.new("sys", "var1", "sys_value", VariableType.STRING, mut=False) + await pool.new("conv", "var2", "conv_value", VariableType.STRING, mut=True) + await pool.new("node1", "var3", "node_value", VariableType.STRING, mut=False) + + # 获取各命名空间的变量 + sys_vars = pool.get_all_system_vars() + conv_vars = pool.get_all_conversation_vars() + node_outputs = pool.get_all_node_outputs() + + # 验证隔离性 + assert "var1" in sys_vars and "var2" not in sys_vars and "var3" not in sys_vars + assert "var2" in conv_vars and "var1" not in conv_vars and "var3" not in conv_vars + assert "node1" in node_outputs and "var3" in node_outputs["node1"] + + +@pytest.mark.asyncio +async def test_variable_pool_mutability_rules(): + """测试可变性规则""" + pool = VariablePool() + + # 系统变量应该是不可变的 + await pool.new("sys", "immutable", "value", VariableType.STRING, mut=False) + with pytest.raises(KeyError): + await pool.set("sys.immutable", "new_value") + + # 会话变量应该是可变的 + await pool.new("conv", "mutable", "value", VariableType.STRING, mut=True) + await pool.set("conv.mutable", "new_value") + assert pool.get_value("conv.mutable") == "new_value" + + # 节点输出应该是不可变的 + await pool.new("node1", "output", "value", VariableType.STRING, mut=False) + with pytest.raises(KeyError): + await pool.set("node1.output", "new_value") + + +@pytest.mark.asyncio +async def test_variable_pool_template_variations(): + """测试模板语法的各种变体""" + pool = VariablePool() + + await pool.new("conv", "test", "value", VariableType.STRING, mut=True) + + # 各种模板格式都应该工作 + assert pool.get_value("{{conv.test}}") == "value" + assert pool.get_value("{{ conv.test }}") == "value" + assert pool.get_value("{{ conv.test }}") == "value" + assert pool.get_value("{{ conv.test}}") == "value" + assert pool.get_value("{{conv.test }}") == "value" diff --git a/api/tests/workflow/nodes/__init__.py b/api/tests/workflow/nodes/__init__.py new file mode 100644 index 00000000..9297d3c1 --- /dev/null +++ b/api/tests/workflow/nodes/__init__.py @@ -0,0 +1,4 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/6 14:43 diff --git a/api/tests/workflow/nodes/base.py b/api/tests/workflow/nodes/base.py new file mode 100644 index 00000000..4dfc05ae --- /dev/null +++ b/api/tests/workflow/nodes/base.py @@ -0,0 +1,77 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/5 18:19 +import os + +import pytest + +from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE +from app.core.workflow.variable_pool import VariablePool + +TEST_WORKSPACE_ID = "test_workspace_id" +TEST_USER_ID = "test_user_id" +TEST_EXECUTION_ID = "test_execution_id" +TEST_CONVERSATION_ID = "test_conversation_id" +TEST_MODEL_ID = "" or os.getenv("TEST_MODEL_ID") +TEST_FILE = { + "type": "image", + "url": "https://inews.gtimg.com/om_bt/Ojy0PdDIWWXRTAMh2QjsiumDZh-D1x7qCkDSmoaaX6INAAA/641", + "__file": True +} +INPUT_DATA = { + "message": "", + "variables": [], + "conversation_id": TEST_CONVERSATION_ID, + "files": [TEST_FILE] +} + + +@pytest.fixture(scope="session", autouse=True) +def global_precheck(): + assert bool(TEST_MODEL_ID) is True, 'PLASE SET TEST_MODEL_ID FIRST' + + +def simple_state(): + return { + "messages": [{"role": "user", "content": "123456"}], + "node_outputs": {}, + "execution_id": TEST_EXECUTION_ID, + "workspace_id": TEST_WORKSPACE_ID, + "user_id": TEST_USER_ID, + "error": None, + "error_node": None, + "cycle_nodes": [], # loop, iteration node id + "looping": 0, # loop runing flag, only use in loop node,not use in main loop + "activate": {} + } + + +async def simple_vairable_pool(message): + # Initialize system variables (sys namespace) + variable_pool = VariablePool() + user_message = message + user_files = INPUT_DATA.get("files") or [] + + # Initialize system variables (sys namespace) + input_variables = INPUT_DATA.get("variables") or {} + sys_vars = { + "message": (user_message, VariableType.STRING), + "conversation_id": (INPUT_DATA.get("conversation_id"), VariableType.STRING), + "execution_id": (TEST_EXECUTION_ID, VariableType.STRING), + "workspace_id": (TEST_WORKSPACE_ID, VariableType.STRING), + "user_id": (TEST_USER_ID, VariableType.STRING), + "input_variables": (input_variables, VariableType.OBJECT), + "files": (user_files, VariableType.ARRAY_FILE) + } + for key, var_def in sys_vars.items(): + value = var_def[0] + var_type = var_def[1] + await variable_pool.new( + namespace='sys', + key=key, + value=value, + var_type=VariableType(var_type), + mut=False + ) + return variable_pool diff --git a/api/tests/workflow/nodes/test_assigner_node.py b/api/tests/workflow/nodes/test_assigner_node.py new file mode 100644 index 00000000..10f1dd40 --- /dev/null +++ b/api/tests/workflow/nodes/test_assigner_node.py @@ -0,0 +1,834 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/5 18:54 +import pytest + +from app.core.workflow.nodes import AssignerNode +from app.core.workflow.variable.base_variable import VariableType +from tests.workflow.nodes.base import simple_state, simple_vairable_pool + + +@pytest.mark.asyncio +async def test_assigner_number_add(): + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", 1, VariableType.NUMBER, mut=True) + config = { + "id": "assigner_test", + "type": "assigner", + "name": "赋值测试节点", + "config": { + "assignments": [ + { + "variable_selector": "{{conv.test}}", + "operation": "add", + "value": 3 + } + ] + } + } + await AssignerNode(config, {}).execute(state, variable_pool) + assert variable_pool.get_value("conv.test") == 4 + + +@pytest.mark.asyncio +async def test_assigner_number_subtract(): + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", 1, VariableType.NUMBER, mut=True) + config = { + "id": "assigner_test", + "type": "assigner", + "name": "赋值测试节点", + "config": { + "assignments": [ + { + "variable_selector": "{{conv.test}}", + "operation": "subtract", + "value": 3 + } + ] + } + } + await AssignerNode(config, {}).execute(state, variable_pool) + assert variable_pool.get_value("conv.test") == -2 + + +@pytest.mark.asyncio +async def test_assigner_number_multiply(): + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", 2, VariableType.NUMBER, mut=True) + config = { + "id": "assigner_test", + "type": "assigner", + "name": "赋值测试节点", + "config": { + "assignments": [ + { + "variable_selector": "{{conv.test}}", + "operation": "multiply", + "value": 3 + } + ] + } + } + await AssignerNode(config, {}).execute(state, variable_pool) + assert variable_pool.get_value("conv.test") == 6 + + +@pytest.mark.asyncio +async def test_assigner_number_divide(): + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", 6, VariableType.NUMBER, mut=True) + config = { + "id": "assigner_test", + "type": "assigner", + "name": "赋值测试节点", + "config": { + "assignments": [ + { + "variable_selector": "{{conv.test}}", + "operation": "divide", + "value": 2 + } + ] + } + } + await AssignerNode(config, {}).execute(state, variable_pool) + assert variable_pool.get_value("conv.test") == 3 + + +@pytest.mark.asyncio +async def test_assigner_number_assign(): + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", 1, VariableType.NUMBER, mut=True) + await variable_pool.new("conv", "test1", 4, VariableType.NUMBER, mut=True) + config = { + "id": "assigner_test", + "type": "assigner", + "name": "赋值测试节点", + "config": { + "assignments": [ + { + "variable_selector": "{{conv.test}}", + "operation": "assign", + "value": "{{conv.test1}}" + } + ] + } + } + await AssignerNode(config, {}).execute(state, variable_pool) + assert variable_pool.get_value("conv.test") == 4 + + +@pytest.mark.asyncio +async def test_assigner_number_cover(): + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", 1, VariableType.NUMBER, mut=True) + config = { + "id": "assigner_test", + "type": "assigner", + "name": "赋值测试节点", + "config": { + "assignments": [ + { + "variable_selector": "{{conv.test}}", + "operation": "cover", + "value": 4 + } + ] + } + } + await AssignerNode(config, {}).execute(state, variable_pool) + assert variable_pool.get_value("conv.test") == 4 + + +@pytest.mark.asyncio +async def test_assigner_number_clear(): + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", 1, VariableType.NUMBER, mut=True) + config = { + "id": "assigner_test", + "type": "assigner", + "name": "赋值测试节点", + "config": { + "assignments": [ + { + "variable_selector": "{{conv.test}}", + "operation": "clear", + } + ] + } + } + await AssignerNode(config, {}).execute(state, variable_pool) + assert variable_pool.get_value("conv.test") == 0 + + +@pytest.mark.asyncio +async def test_assigner_number_append(): + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", 1, VariableType.NUMBER, mut=True) + with pytest.raises(AttributeError) as exc_info: + config = { + "id": "assigner_test", + "type": "assigner", + "name": "赋值测试节点", + "config": { + "assignments": [ + { + "variable_selector": "{{conv.test}}", + "operation": "append", + "value": 3 + } + ] + } + } + await AssignerNode(config, {}).execute(state, variable_pool) + assert "'NumberOperator' object has no attribute 'append'" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_assigner_number_remove_last(): + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", 1, VariableType.NUMBER, mut=True) + with pytest.raises(AttributeError) as exc_info: + config = { + "id": "assigner_test", + "type": "assigner", + "name": "赋值测试节点", + "config": { + "assignments": [ + { + "variable_selector": "{{conv.test}}", + "operation": "remove_last" + } + ] + } + } + await AssignerNode(config, {}).execute(state, variable_pool) + assert "'NumberOperator' object has no attribute 'remove_last'" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_assigner_number_remove_first(): + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", 1, VariableType.NUMBER, mut=True) + with pytest.raises(AttributeError) as exc_info: + config = { + "id": "assigner_test", + "type": "assigner", + "name": "赋值测试节点", + "config": { + "assignments": [ + { + "variable_selector": "{{conv.test}}", + "operation": "remove_first" + } + ] + } + } + await AssignerNode(config, {}).execute(state, variable_pool) + assert "'NumberOperator' object has no attribute 'remove_first'" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_assigner_array_append(): + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", [1, 2], VariableType.ARRAY_NUMBER, mut=True) + config = { + "id": "assigner_test", + "type": "assigner", + "name": "赋值测试节点", + "config": { + "assignments": [ + { + "variable_selector": "{{conv.test}}", + "operation": "append", + "value": 3 + } + ] + } + } + await AssignerNode(config, {}).execute(state, variable_pool) + assert variable_pool.get_value("conv.test") == [1, 2, 3] + + +@pytest.mark.asyncio +async def test_assigner_array_remove_last(): + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", [1, 2], VariableType.ARRAY_NUMBER, mut=True) + config = { + "id": "assigner_test", + "type": "assigner", + "name": "赋值测试节点", + "config": { + "assignments": [ + { + "variable_selector": "{{conv.test}}", + "operation": "remove_last" + } + ] + } + } + await AssignerNode(config, {}).execute(state, variable_pool) + assert variable_pool.get_value("conv.test") == [1] + + +@pytest.mark.asyncio +async def test_assigner_array_remove_first(): + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", [1, 2], VariableType.ARRAY_NUMBER, mut=True) + config = { + "id": "assigner_test", + "type": "assigner", + "name": "赋值测试节点", + "config": { + "assignments": [ + { + "variable_selector": "{{conv.test}}", + "operation": "remove_first" + } + ] + } + } + await AssignerNode(config, {}).execute(state, variable_pool) + assert variable_pool.get_value("conv.test") == [2] + + +# String tests +@pytest.mark.asyncio +async def test_assigner_string_assign(): + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", "hello", VariableType.STRING, mut=True) + config = { + "id": "assigner_test", + "type": "assigner", + "name": "赋值测试节点", + "config": { + "assignments": [ + { + "variable_selector": "{{conv.test}}", + "operation": "assign", + "value": "world" + } + ] + } + } + await AssignerNode(config, {}).execute(state, variable_pool) + assert variable_pool.get_value("conv.test") == "world" + + +@pytest.mark.asyncio +async def test_assigner_string_cover(): + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", "hello", VariableType.STRING, mut=True) + config = { + "id": "assigner_test", + "type": "assigner", + "name": "赋值测试节点", + "config": { + "assignments": [ + { + "variable_selector": "{{conv.test}}", + "operation": "cover", + "value": "world" + } + ] + } + } + await AssignerNode(config, {}).execute(state, variable_pool) + assert variable_pool.get_value("conv.test") == "world" + + +@pytest.mark.asyncio +async def test_assigner_string_clear(): + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", "hello", VariableType.STRING, mut=True) + config = { + "id": "assigner_test", + "type": "assigner", + "name": "赋值测试节点", + "config": { + "assignments": [ + { + "variable_selector": "{{conv.test}}", + "operation": "clear" + } + ] + } + } + await AssignerNode(config, {}).execute(state, variable_pool) + assert variable_pool.get_value("conv.test") == "" + + +@pytest.mark.asyncio +async def test_assigner_string_invalid_operation(): + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", "hello", VariableType.STRING, mut=True) + config = { + "id": "assigner_test", + "type": "assigner", + "name": "赋值测试节点", + "config": { + "assignments": [ + { + "variable_selector": "{{conv.test}}", + "operation": "add", + "value": "world" + } + ] + } + } + with pytest.raises(AttributeError) as exc_info: + await AssignerNode(config, {}).execute(state, variable_pool) + assert "'StringOperator' object has no attribute 'add'" in str(exc_info.value) + + +# Boolean tests +@pytest.mark.asyncio +async def test_assigner_boolean_assign(): + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", True, VariableType.BOOLEAN, mut=True) + config = { + "id": "assigner_test", + "type": "assigner", + "name": "赋值测试节点", + "config": { + "assignments": [ + { + "variable_selector": "{{conv.test}}", + "operation": "assign", + "value": False + } + ] + } + } + await AssignerNode(config, {}).execute(state, variable_pool) + assert variable_pool.get_value("conv.test") is False + + +@pytest.mark.asyncio +async def test_assigner_boolean_cover(): + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", False, VariableType.BOOLEAN, mut=True) + config = { + "id": "assigner_test", + "type": "assigner", + "name": "赋值测试节点", + "config": { + "assignments": [ + { + "variable_selector": "{{conv.test}}", + "operation": "cover", + "value": True + } + ] + } + } + await AssignerNode(config, {}).execute(state, variable_pool) + assert variable_pool.get_value("conv.test") is True + + +@pytest.mark.asyncio +async def test_assigner_boolean_clear(): + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", True, VariableType.BOOLEAN, mut=True) + config = { + "id": "assigner_test", + "type": "assigner", + "name": "赋值测试节点", + "config": { + "assignments": [ + { + "variable_selector": "{{conv.test}}", + "operation": "clear" + } + ] + } + } + await AssignerNode(config, {}).execute(state, variable_pool) + assert variable_pool.get_value("conv.test") is False + + +# Object tests +@pytest.mark.asyncio +async def test_assigner_object_assign(): + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", {"key": "value"}, VariableType.OBJECT, mut=True) + config = { + "id": "assigner_test", + "type": "assigner", + "name": "赋值测试节点", + "config": { + "assignments": [ + { + "variable_selector": "{{conv.test}}", + "operation": "assign", + "value": {"new_key": "new_value"} + } + ] + } + } + await AssignerNode(config, {}).execute(state, variable_pool) + assert variable_pool.get_value("conv.test") == {"new_key": "new_value"} + + +@pytest.mark.asyncio +async def test_assigner_object_cover(): + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", {"key": "value"}, VariableType.OBJECT, mut=True) + config = { + "id": "assigner_test", + "type": "assigner", + "name": "赋值测试节点", + "config": { + "assignments": [ + { + "variable_selector": "{{conv.test}}", + "operation": "cover", + "value": {"new_key": "new_value"} + } + ] + } + } + await AssignerNode(config, {}).execute(state, variable_pool) + assert variable_pool.get_value("conv.test") == {"new_key": "new_value"} + + +@pytest.mark.asyncio +async def test_assigner_object_clear(): + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", {"key": "value"}, VariableType.OBJECT, mut=True) + config = { + "id": "assigner_test", + "type": "assigner", + "name": "赋值测试节点", + "config": { + "assignments": [ + { + "variable_selector": "{{conv.test}}", + "operation": "clear" + } + ] + } + } + await AssignerNode(config, {}).execute(state, variable_pool) + assert variable_pool.get_value("conv.test") == {} + + +# Array string tests +@pytest.mark.asyncio +async def test_assigner_array_string_append(): + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", ["a", "b"], VariableType.ARRAY_STRING, mut=True) + config = { + "id": "assigner_test", + "type": "assigner", + "name": "赋值测试节点", + "config": { + "assignments": [ + { + "variable_selector": "{{conv.test}}", + "operation": "append", + "value": "c" + } + ] + } + } + await AssignerNode(config, {}).execute(state, variable_pool) + assert variable_pool.get_value("conv.test") == ["a", "b", "c"] + + +@pytest.mark.asyncio +async def test_assigner_array_string_clear(): + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", ["a", "b"], VariableType.ARRAY_STRING, mut=True) + config = { + "id": "assigner_test", + "type": "assigner", + "name": "赋值测试节点", + "config": { + "assignments": [ + { + "variable_selector": "{{conv.test}}", + "operation": "clear" + } + ] + } + } + await AssignerNode(config, {}).execute(state, variable_pool) + assert variable_pool.get_value("conv.test") == [] + + +@pytest.mark.asyncio +async def test_assigner_array_object_append(): + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", [{"id": 1}], VariableType.ARRAY_OBJECT, mut=True) + config = { + "id": "assigner_test", + "type": "assigner", + "name": "赋值测试节点", + "config": { + "assignments": [ + { + "variable_selector": "{{conv.test}}", + "operation": "append", + "value": {"id": 2} + } + ] + } + } + await AssignerNode(config, {}).execute(state, variable_pool) + assert variable_pool.get_value("conv.test") == [{"id": 1}, {"id": 2}] + + +@pytest.mark.asyncio +async def test_assigner_array_assign(): + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", [1, 2], VariableType.ARRAY_NUMBER, mut=True) + config = { + "id": "assigner_test", + "type": "assigner", + "name": "赋值测试节点", + "config": { + "assignments": [ + { + "variable_selector": "{{conv.test}}", + "operation": "assign", + "value": [3, 4, 5] + } + ] + } + } + await AssignerNode(config, {}).execute(state, variable_pool) + assert variable_pool.get_value("conv.test") == [3, 4, 5] + + +@pytest.mark.asyncio +async def test_assigner_array_cover(): + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", [1, 2], VariableType.ARRAY_NUMBER, mut=True) + config = { + "id": "assigner_test", + "type": "assigner", + "name": "赋值测试节点", + "config": { + "assignments": [ + { + "variable_selector": "{{conv.test}}", + "operation": "cover", + "value": [3, 4, 5] + } + ] + } + } + await AssignerNode(config, {}).execute(state, variable_pool) + assert variable_pool.get_value("conv.test") == [3, 4, 5] + + +# Multiple assignments test +@pytest.mark.asyncio +async def test_assigner_multiple_assignments(): + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test1", 10, VariableType.NUMBER, mut=True) + await variable_pool.new("conv", "test2", "hello", VariableType.STRING, mut=True) + await variable_pool.new("conv", "test3", [1, 2], VariableType.ARRAY_NUMBER, mut=True) + config = { + "id": "assigner_test", + "type": "assigner", + "name": "赋值测试节点", + "config": { + "assignments": [ + { + "variable_selector": "{{conv.test1}}", + "operation": "add", + "value": 5 + }, + { + "variable_selector": "{{conv.test2}}", + "operation": "assign", + "value": "world" + }, + { + "variable_selector": "{{conv.test3}}", + "operation": "append", + "value": 3 + } + ] + } + } + await AssignerNode(config, {}).execute(state, variable_pool) + assert variable_pool.get_value("conv.test1") == 15 + assert variable_pool.get_value("conv.test2") == "world" + assert variable_pool.get_value("conv.test3") == [1, 2, 3] + + +# Variable reference test +@pytest.mark.asyncio +async def test_assigner_variable_reference(): + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "source", 100, VariableType.NUMBER, mut=True) + await variable_pool.new("conv", "target", 0, VariableType.NUMBER, mut=True) + config = { + "id": "assigner_test", + "type": "assigner", + "name": "赋值测试节点", + "config": { + "assignments": [ + { + "variable_selector": "{{conv.target}}", + "operation": "assign", + "value": "{{conv.source}}" + } + ] + } + } + await AssignerNode(config, {}).execute(state, variable_pool) + assert variable_pool.get_value("conv.target") == 100 + + +# Edge cases +@pytest.mark.asyncio +async def test_assigner_divide_by_zero(): + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", 10, VariableType.NUMBER, mut=True) + config = { + "id": "assigner_test", + "type": "assigner", + "name": "赋值测试节点", + "config": { + "assignments": [ + { + "variable_selector": "{{conv.test}}", + "operation": "divide", + "value": 0 + } + ] + } + } + with pytest.raises(ZeroDivisionError): + await AssignerNode(config, {}).execute(state, variable_pool) + + +@pytest.mark.asyncio +async def test_assigner_invalid_namespace(): + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("sys", "test", 10, VariableType.NUMBER, mut=False) + config = { + "id": "assigner_test", + "type": "assigner", + "name": "赋值测试节点", + "config": { + "assignments": [ + { + "variable_selector": "{{sys.test}}", + "operation": "add", + "value": 5 + } + ] + } + } + with pytest.raises(ValueError) as exc_info: + await AssignerNode(config, {}).execute(state, variable_pool) + assert "Only conversation or cycle variables can be assigned" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_assigner_empty_array_operations(): + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", [], VariableType.ARRAY_NUMBER, mut=True) + + # Test append on empty array + config = { + "id": "assigner_test", + "type": "assigner", + "name": "赋值测试节点", + "config": { + "assignments": [ + { + "variable_selector": "{{conv.test}}", + "operation": "append", + "value": 1 + } + ] + } + } + await AssignerNode(config, {}).execute(state, variable_pool) + assert variable_pool.get_value("conv.test") == [1] + + +@pytest.mark.asyncio +async def test_assigner_remove_from_single_element_array(): + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", [1], VariableType.ARRAY_NUMBER, mut=True) + config = { + "id": "assigner_test", + "type": "assigner", + "name": "赋值测试节点", + "config": { + "assignments": [ + { + "variable_selector": "{{conv.test}}", + "operation": "remove_last" + } + ] + } + } + await AssignerNode(config, {}).execute(state, variable_pool) + assert variable_pool.get_value("conv.test") == [] + + +@pytest.mark.asyncio +async def test_assigner_float_operations(): + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", 10.5, VariableType.NUMBER, mut=True) + config = { + "id": "assigner_test", + "type": "assigner", + "name": "赋值测试节点", + "config": { + "assignments": [ + { + "variable_selector": "{{conv.test}}", + "operation": "multiply", + "value": 2.0 + } + ] + } + } + await AssignerNode(config, {}).execute(state, variable_pool) + assert variable_pool.get_value("conv.test") == 21.0 diff --git a/api/tests/workflow/nodes/test_breaker_node.py b/api/tests/workflow/nodes/test_breaker_node.py new file mode 100644 index 00000000..913a299f --- /dev/null +++ b/api/tests/workflow/nodes/test_breaker_node.py @@ -0,0 +1,23 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/5 19:15 +import pytest + +from app.core.workflow.nodes.breaker import BreakNode +from tests.workflow.nodes.base import simple_state, simple_vairable_pool + + +@pytest.mark.asyncio +async def test_loop_breaker(): + node_config = { + "id": "breaker_test", + "type": "breaker", + "name": "breaker", + "config": { + } + } + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await BreakNode(node_config, {}).execute(state, variable_pool) + assert state["looping"] == 2 diff --git a/api/tests/workflow/nodes/test_code.py b/api/tests/workflow/nodes/test_code.py new file mode 100644 index 00000000..eca6e1ac --- /dev/null +++ b/api/tests/workflow/nodes/test_code.py @@ -0,0 +1,279 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/6 09:59 +import pytest + +from app.core.workflow.nodes.code import CodeNode +from app.core.workflow.variable.base_variable import VariableType +from tests.workflow.nodes.base import simple_state, simple_vairable_pool + + +@pytest.mark.asyncio +async def test_code_python_complex_output(): + node_config = { + "id": "code_test", + "type": "code", + "name": "代码执行", + "config": { + "code": "ZGVmJTIwbWFpbih4JTJDJTIweSklM0ElMEElMjAlMjAlMjAlMjByZXR1cm4lMjAlN0IlMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjJudW1iZXIlMjIlM0ElMjB4JTIwJTJCJTIweSUyQyUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMnN0cmluZyUyMiUzQSUyMHN0cih4JTIwJTJCJTIweSklMkMlMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjJib29sZWFuJTIyJTNBJTIwYm9vbCh4JTIwJTJCJTIweSklMkMlMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjJkaWN0JTIyJTNBJTIwJTdCJTIyc3VtJTIyJTNBJTIweCUyMCUyQiUyMHklN0QlMkMlMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjJhcnJheV9zdHJpbmclMjIlM0ElMjAlNUJzdHIoeCUyMCUyQiUyMHkpJTVEJTJDJTBBJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIyYXJyYXlfbnVtYmVyJTIyJTNBJTIwJTVCeCUyMCUyQiUyMHklNUQlMkMlMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjJhcnJheV9vYmplY3QlMjIlM0ElMjAlNUIlN0IlMjJzdW0lMjIlM0ElMjB4JTIwJTJCJTIweSU3RCU1RCUyQyUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMmFycmF5X2Jvb2xlYW4lMjIlM0ElMjAlNUJib29sKHglMjAlMkIlMjB5KSU1RCUwQSUyMCUyMCUyMCUyMCU3RA==", + "language": "python3", + "input_variables": [ + { + "name": "x", + "variable": "{{conv.x}}" + }, + { + "name": "y", + "variable": "{{conv.y}}" + } + ], + "output_variables": [ + { + "name": "number", + "type": VariableType.NUMBER + }, + { + "name": "string", + "type": VariableType.STRING + }, + { + "name": "boolean", + "type": VariableType.BOOLEAN + }, + { + "name": "dict", + "type": VariableType.OBJECT + }, + { + "name": "array_string", + "type": VariableType.ARRAY_STRING + }, + { + "name": "array_number", + "type": VariableType.ARRAY_NUMBER + }, + { + "name": "array_object", + "type": VariableType.ARRAY_OBJECT + }, + { + "name": "array_boolean", + "type": VariableType.ARRAY_BOOLEAN + }, + ] + } + } + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "x", 1, VariableType.NUMBER, mut=True) + await variable_pool.new("conv", "y", 2, VariableType.NUMBER, mut=True) + result = await CodeNode(node_config, {}).execute(state, variable_pool) + assert result == {'number': 3, 'string': '3', 'boolean': True, 'dict': {'sum': 3}, 'array_string': ['3'], + 'array_number': [3], 'array_object': [{'sum': 3}], 'array_boolean': [True]} + + +@pytest.mark.asyncio +async def test_code_javascript_complex_output(): + node_config = { + "id": "code_test", + "type": "code", + "name": "代码执行", + "config": { + "code": "ZnVuY3Rpb24gbWFpbih7eCwgeX0pIHsKICBjb25zdCBzdW0gPSB4ICsgeTsKCiAgcmV0dXJuIHsKICAgIG51bWJlcjogc3VtLAogICAgc3RyaW5nOiBTdHJpbmcoc3VtKSwKICAgIGJvb2xlYW46IEJvb2xlYW4oc3VtKSwKICAgIGRpY3Q6IHsgc3VtIH0sCiAgICBhcnJheV9zdHJpbmc6IFtTdHJpbmcoc3VtKV0sCiAgICBhcnJheV9udW1iZXI6IFtzdW1dLAogICAgYXJyYXlfb2JqZWN0OiBbeyBzdW0gfV0sCiAgICBhcnJheV9ib29sZWFuOiBbQm9vbGVhbihzdW0pXSwKICB9Owp9", + "language": "javascript", + "input_variables": [ + { + "name": "x", + "variable": "{{conv.x}}" + }, + { + "name": "y", + "variable": "{{conv.y}}" + } + ], + "output_variables": [ + { + "name": "number", + "type": VariableType.NUMBER + }, + { + "name": "string", + "type": VariableType.STRING + }, + { + "name": "boolean", + "type": VariableType.BOOLEAN + }, + { + "name": "dict", + "type": VariableType.OBJECT + }, + { + "name": "array_string", + "type": VariableType.ARRAY_STRING + }, + { + "name": "array_number", + "type": VariableType.ARRAY_NUMBER + }, + { + "name": "array_object", + "type": VariableType.ARRAY_OBJECT + }, + { + "name": "array_boolean", + "type": VariableType.ARRAY_BOOLEAN + }, + ] + } + } + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "x", 1, VariableType.NUMBER, mut=True) + await variable_pool.new("conv", "y", 2, VariableType.NUMBER, mut=True) + result = await CodeNode(node_config, {}).execute(state, variable_pool) + assert result == {'number': 3, 'string': '3', 'boolean': True, 'dict': {'sum': 3}, 'array_string': ['3'], + 'array_number': [3], 'array_object': [{'sum': 3}], 'array_boolean': [True]} + + +@pytest.mark.asyncio +async def test_code_python_operation_permissions(): + node_config = { + "id": "code_test", + "type": "code", + "name": "代码执行", + "config": { + "code": "ZGVmJTIwbWFpbih4JTJDJTIweSklM0ElMEElMjAlMjAlMjAlMjBpbXBvcnQlMjBvcyUwQSUyMCUyMCUyMCUyMG9zLmdldGN3ZCgpJTBBJTIwJTIwJTIwJTIwcmV0dXJuJTIwJTdCJTIycmVzdWx0JTIyJTNBJTIweCUyMCUyQiUyMHklN0QlMEE=", + "language": "python3", + "input_variables": [ + { + "name": "x", + "variable": "{{conv.x}}" + }, + { + "name": "y", + "variable": "{{conv.y}}" + } + ], + "output_variables": [ + { + "name": "result", + "type": "number" + } + ] + } + } + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "x", 1, VariableType.NUMBER, mut=True) + await variable_pool.new("conv", "y", 2, VariableType.NUMBER, mut=True) + with pytest.raises(RuntimeError, match="Operation not permitted"): + await CodeNode(node_config, {}).execute(state, variable_pool) + + +@pytest.mark.asyncio +async def test_code_javascript_operation_permissions(): + node_config = { + "id": "code_test", + "type": "code", + "name": "代码执行", + "config": { + "code": "Y29uc29sZS5sb2cocHJvY2Vzcy5nZXRldWlkKCkpOw==", + "language": "javascript", + "input_variables": [ + { + "name": "x", + "variable": "{{conv.x}}" + }, + { + "name": "y", + "variable": "{{conv.y}}" + } + ], + "output_variables": [ + { + "name": "result", + "type": "number" + } + ] + } + } + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "x", 1, VariableType.NUMBER, mut=True) + await variable_pool.new("conv", "y", 2, VariableType.NUMBER, mut=True) + with pytest.raises(RuntimeError, match="Operation not permitted"): + await CodeNode(node_config, {}).execute(state, variable_pool) + + +@pytest.mark.asyncio +async def test_code_python_run_error(): + node_config = { + "id": "code_test", + "type": "code", + "name": "代码执行", + "config": { + "code": "ZGVmJTIwbWFpbih4JTJDJTIweSUzQSUwQSUyMCUyMCUyMCUyMHJldHVybiUyMCU3QiUyMnJlc3VsdCUyMiUzQSUyMHglMjAlMkIlMjB5JTdEJTBB", + "language": "python3", + "input_variables": [ + { + "name": "x", + "variable": "{{conv.x}}" + }, + { + "name": "y", + "variable": "{{conv.y}}" + } + ], + "output_variables": [ + { + "name": "result", + "type": "number" + } + ] + } + } + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "x", 1, VariableType.NUMBER, mut=True) + await variable_pool.new("conv", "y", 2, VariableType.NUMBER, mut=True) + with pytest.raises(Exception) as exc_info: + await CodeNode(node_config, {}).execute(state, variable_pool) + assert "'(' was never closed" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_code_javascript_run_error(): + node_config = { + "id": "code_test", + "type": "code", + "name": "代码执行", + "config": { + "code": "Y29uc29sZS5sb2co", + "language": "javascript", + "input_variables": [ + { + "name": "x", + "variable": "{{conv.x}}" + }, + { + "name": "y", + "variable": "{{conv.y}}" + } + ], + "output_variables": [ + { + "name": "result", + "type": "number" + } + ] + } + } + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "x", 1, VariableType.NUMBER, mut=True) + await variable_pool.new("conv", "y", 2, VariableType.NUMBER, mut=True) + with pytest.raises(Exception) as exc_info: + await CodeNode(node_config, {}).execute(state, variable_pool) + assert "SyntaxError" in str(exc_info.value) diff --git a/api/tests/workflow/nodes/test_end_node.py b/api/tests/workflow/nodes/test_end_node.py new file mode 100644 index 00000000..2a5798e1 --- /dev/null +++ b/api/tests/workflow/nodes/test_end_node.py @@ -0,0 +1,42 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/6 12:22 +import pytest + +from app.core.workflow.nodes import EndNode +from app.core.workflow.variable.base_variable import VariableType +from tests.workflow.nodes.base import simple_state, simple_vairable_pool + + +@pytest.mark.asyncio +async def test_end_output(): + node_config = { + "id": "end_test", + "type": "end", + "name": "end", + "config": { + "output": "{{conv.x}}{{sys.message}}" + } + } + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "x", 1, VariableType.NUMBER, mut=True) + result = await EndNode(node_config, {}).execute(state, variable_pool) + assert result == "1test" + + +@pytest.mark.asyncio +async def test_end_output_miss(): + node_config = { + "id": "end_test", + "type": "end", + "name": "end", + "config": { + "output": "{{conv.x}}{{sys.message}}" + } + } + state = simple_state() + variable_pool = await simple_vairable_pool("test") + result = await EndNode(node_config, {}).execute(state, variable_pool) + assert result == "test" diff --git a/api/tests/workflow/nodes/test_ifelse_node.py b/api/tests/workflow/nodes/test_ifelse_node.py new file mode 100644 index 00000000..9e2eb7f0 --- /dev/null +++ b/api/tests/workflow/nodes/test_ifelse_node.py @@ -0,0 +1,1127 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/6 +import pytest + +from app.core.workflow.nodes import IfElseNode +from app.core.workflow.variable.base_variable import VariableType +from tests.workflow.nodes.base import simple_state, simple_vairable_pool + + +# 字符串比较测试配置 +STRING_EQ_CONFIG = { + "id": "ifelse_test", + "type": "if-else", + "name": "条件测试节点", + "config": { + "cases": [ + { + "logical_operator": "and", + "expressions": [ + { + "left": "{{conv.test}}", + "operator": "eq", + "right": "hello", + "input_type": "constant" + } + ] + } + ] + } +} + +STRING_CONTAINS_CONFIG = { + "id": "ifelse_test", + "type": "if-else", + "name": "条件测试节点", + "config": { + "cases": [ + { + "logical_operator": "and", + "expressions": [ + { + "left": "{{conv.test}}", + "operator": "contains", + "right": "world", + "input_type": "constant" + } + ] + } + ] + } +} + +STRING_STARTSWITH_CONFIG = { + "id": "ifelse_test", + "type": "if-else", + "name": "条件测试节点", + "config": { + "cases": [ + { + "logical_operator": "and", + "expressions": [ + { + "left": "{{conv.test}}", + "operator": "startwith", + "right": "hello", + "input_type": "constant" + } + ] + } + ] + } +} + +STRING_ENDSWITH_CONFIG = { + "id": "ifelse_test", + "type": "if-else", + "name": "条件测试节点", + "config": { + "cases": [ + { + "logical_operator": "and", + "expressions": [ + { + "left": "{{conv.test}}", + "operator": "endwith", + "right": "world", + "input_type": "constant" + } + ] + } + ] + } +} + +STRING_EMPTY_CONFIG = { + "id": "ifelse_test", + "type": "if-else", + "name": "条件测试节点", + "config": { + "cases": [ + { + "logical_operator": "and", + "expressions": [ + { + "left": "{{conv.test}}", + "operator": "empty", + "right": "", + "input_type": "constant" + } + ] + } + ] + } +} + +STRING_NOT_EMPTY_CONFIG = { + "id": "ifelse_test", + "type": "if-else", + "name": "条件测试节点", + "config": { + "cases": [ + { + "logical_operator": "and", + "expressions": [ + { + "left": "{{conv.test}}", + "operator": "not_empty", + "right": "", + "input_type": "constant" + } + ] + } + ] + } +} + +# 数字比较测试配置 +NUMBER_EQ_CONFIG = { + "id": "ifelse_test", + "type": "if-else", + "name": "条件测试节点", + "config": { + "cases": [ + { + "logical_operator": "and", + "expressions": [ + { + "left": "{{conv.test}}", + "operator": "eq", + "right": 10, + "input_type": "constant" + } + ] + } + ] + } +} + +NUMBER_LT_CONFIG = { + "id": "ifelse_test", + "type": "if-else", + "name": "条件测试节点", + "config": { + "cases": [ + { + "logical_operator": "and", + "expressions": [ + { + "left": "{{conv.test}}", + "operator": "lt", + "right": 10, + "input_type": "constant" + } + ] + } + ] + } +} + +NUMBER_GT_CONFIG = { + "id": "ifelse_test", + "type": "if-else", + "name": "条件测试节点", + "config": { + "cases": [ + { + "logical_operator": "and", + "expressions": [ + { + "left": "{{conv.test}}", + "operator": "gt", + "right": 10, + "input_type": "constant" + } + ] + } + ] + } +} + +NUMBER_LE_CONFIG = { + "id": "ifelse_test", + "type": "if-else", + "name": "条件测试节点", + "config": { + "cases": [ + { + "logical_operator": "and", + "expressions": [ + { + "left": "{{conv.test}}", + "operator": "le", + "right": 10, + "input_type": "constant" + } + ] + } + ] + } +} + +NUMBER_GE_CONFIG = { + "id": "ifelse_test", + "type": "if-else", + "name": "条件测试节点", + "config": { + "cases": [ + { + "logical_operator": "and", + "expressions": [ + { + "left": "{{conv.test}}", + "operator": "ge", + "right": 10, + "input_type": "constant" + } + ] + } + ] + } +} + +# 布尔比较测试配置 +BOOLEAN_EQ_CONFIG = { + "id": "ifelse_test", + "type": "if-else", + "name": "条件测试节点", + "config": { + "cases": [ + { + "logical_operator": "and", + "expressions": [ + { + "left": "{{conv.test}}", + "operator": "eq", + "right": True, + "input_type": "constant" + } + ] + } + ] + } +} + +# 数组比较测试配置 +ARRAY_CONTAINS_CONFIG = { + "id": "ifelse_test", + "type": "if-else", + "name": "条件测试节点", + "config": { + "cases": [ + { + "logical_operator": "and", + "expressions": [ + { + "left": "{{conv.test}}", + "operator": "contains", + "right": 2, + "input_type": "constant" + } + ] + } + ] + } +} + +ARRAY_EMPTY_CONFIG = { + "id": "ifelse_test", + "type": "if-else", + "name": "条件测试节点", + "config": { + "cases": [ + { + "logical_operator": "and", + "expressions": [ + { + "left": "{{conv.test}}", + "operator": "empty", + "right": "", + "input_type": "constant" + } + ] + } + ] + } +} + +# 对象比较测试配置 +OBJECT_EMPTY_CONFIG = { + "id": "ifelse_test", + "type": "if-else", + "name": "条件测试节点", + "config": { + "cases": [ + { + "logical_operator": "and", + "expressions": [ + { + "left": "{{conv.test}}", + "operator": "empty", + "right": "", + "input_type": "constant" + } + ] + } + ] + } +} + +# 多条件测试配置 +MULTI_CONDITION_AND_CONFIG = { + "id": "ifelse_test", + "type": "if-else", + "name": "条件测试节点", + "config": { + "cases": [ + { + "logical_operator": "and", + "expressions": [ + { + "left": "{{conv.test1}}", + "operator": "eq", + "right": 10, + "input_type": "constant" + }, + { + "left": "{{conv.test2}}", + "operator": "eq", + "right": "hello", + "input_type": "constant" + } + ] + } + ] + } +} + +MULTI_CONDITION_OR_CONFIG = { + "id": "ifelse_test", + "type": "if-else", + "name": "条件测试节点", + "config": { + "cases": [ + { + "logical_operator": "or", + "expressions": [ + { + "left": "{{conv.test1}}", + "operator": "eq", + "right": 10, + "input_type": "constant" + }, + { + "left": "{{conv.test2}}", + "operator": "eq", + "right": "hello", + "input_type": "constant" + } + ] + } + ] + } +} + +# 多分支测试配置 +MULTI_BRANCH_CONFIG = { + "id": "ifelse_test", + "type": "if-else", + "name": "条件测试节点", + "config": { + "cases": [ + { + "logical_operator": "and", + "expressions": [ + { + "left": "{{conv.test}}", + "operator": "eq", + "right": 1, + "input_type": "constant" + } + ] + }, + { + "logical_operator": "and", + "expressions": [ + { + "left": "{{conv.test}}", + "operator": "eq", + "right": 2, + "input_type": "constant" + } + ] + }, + { + "logical_operator": "and", + "expressions": [ + { + "left": "{{conv.test}}", + "operator": "eq", + "right": 3, + "input_type": "constant" + } + ] + } + ] + } +} + +# 变量引用测试配置 +VARIABLE_REFERENCE_CONFIG = { + "id": "ifelse_test", + "type": "if-else", + "name": "条件测试节点", + "config": { + "cases": [ + { + "logical_operator": "and", + "expressions": [ + { + "left": "{{conv.test1}}", + "operator": "eq", + "right": "{{conv.test2}}", + "input_type": "variable" + } + ] + } + ] + } +} + + +# ==================== 字符串比较测试 ==================== +@pytest.mark.asyncio +async def test_ifelse_string_eq_true(): + """测试字符串相等条件为真""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", "hello", VariableType.STRING, mut=True) + result = await IfElseNode(STRING_EQ_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE1" + + +@pytest.mark.asyncio +async def test_ifelse_string_eq_false(): + """测试字符串相等条件为假""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", "world", VariableType.STRING, mut=True) + result = await IfElseNode(STRING_EQ_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE2" + + +@pytest.mark.asyncio +async def test_ifelse_string_contains_true(): + """测试字符串包含条件为真""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", "hello world", VariableType.STRING, mut=True) + result = await IfElseNode(STRING_CONTAINS_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE1" + + +@pytest.mark.asyncio +async def test_ifelse_string_contains_false(): + """测试字符串包含条件为假""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", "hello", VariableType.STRING, mut=True) + result = await IfElseNode(STRING_CONTAINS_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE2" + + +@pytest.mark.asyncio +async def test_ifelse_string_startswith_true(): + """测试字符串开头匹配条件为真""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", "hello world", VariableType.STRING, mut=True) + result = await IfElseNode(STRING_STARTSWITH_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE1" + + +@pytest.mark.asyncio +async def test_ifelse_string_startswith_false(): + """测试字符串开头匹配条件为假""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", "world hello", VariableType.STRING, mut=True) + result = await IfElseNode(STRING_STARTSWITH_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE2" + + +@pytest.mark.asyncio +async def test_ifelse_string_endswith_true(): + """测试字符串结尾匹配条件为真""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", "hello world", VariableType.STRING, mut=True) + result = await IfElseNode(STRING_ENDSWITH_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE1" + + +@pytest.mark.asyncio +async def test_ifelse_string_endswith_false(): + """测试字符串结尾匹配条件为假""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", "world hello", VariableType.STRING, mut=True) + result = await IfElseNode(STRING_ENDSWITH_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE2" + + +@pytest.mark.asyncio +async def test_ifelse_string_empty_true(): + """测试字符串为空条件为真""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", "", VariableType.STRING, mut=True) + result = await IfElseNode(STRING_EMPTY_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE1" + + +@pytest.mark.asyncio +async def test_ifelse_string_empty_false(): + """测试字符串为空条件为假""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", "hello", VariableType.STRING, mut=True) + result = await IfElseNode(STRING_EMPTY_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE2" + + +@pytest.mark.asyncio +async def test_ifelse_string_not_empty_true(): + """测试字符串非空条件为真""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", "hello", VariableType.STRING, mut=True) + result = await IfElseNode(STRING_NOT_EMPTY_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE1" + + +@pytest.mark.asyncio +async def test_ifelse_string_not_empty_false(): + """测试字符串非空条件为假""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", "", VariableType.STRING, mut=True) + result = await IfElseNode(STRING_NOT_EMPTY_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE2" + + +# ==================== 数字比较测试 ==================== +@pytest.mark.asyncio +async def test_ifelse_number_eq_true(): + """测试数字相等条件为真""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", 10, VariableType.NUMBER, mut=True) + result = await IfElseNode(NUMBER_EQ_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE1" + + +@pytest.mark.asyncio +async def test_ifelse_number_eq_false(): + """测试数字相等条件为假""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", 5, VariableType.NUMBER, mut=True) + result = await IfElseNode(NUMBER_EQ_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE2" + + +@pytest.mark.asyncio +async def test_ifelse_number_lt_true(): + """测试数字小于条件为真""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", 5, VariableType.NUMBER, mut=True) + result = await IfElseNode(NUMBER_LT_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE1" + + +@pytest.mark.asyncio +async def test_ifelse_number_lt_false(): + """测试数字小于条件为假""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", 15, VariableType.NUMBER, mut=True) + result = await IfElseNode(NUMBER_LT_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE2" + + +@pytest.mark.asyncio +async def test_ifelse_number_gt_true(): + """测试数字大于条件为真""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", 15, VariableType.NUMBER, mut=True) + result = await IfElseNode(NUMBER_GT_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE1" + + +@pytest.mark.asyncio +async def test_ifelse_number_gt_false(): + """测试数字大于条件为假""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", 5, VariableType.NUMBER, mut=True) + result = await IfElseNode(NUMBER_GT_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE2" + + +@pytest.mark.asyncio +async def test_ifelse_number_le_true(): + """测试数字小于等于条件为真""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", 10, VariableType.NUMBER, mut=True) + result = await IfElseNode(NUMBER_LE_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE1" + + +@pytest.mark.asyncio +async def test_ifelse_number_le_false(): + """测试数字小于等于条件为假""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", 15, VariableType.NUMBER, mut=True) + result = await IfElseNode(NUMBER_LE_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE2" + + +@pytest.mark.asyncio +async def test_ifelse_number_ge_true(): + """测试数字大于等于条件为真""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", 10, VariableType.NUMBER, mut=True) + result = await IfElseNode(NUMBER_GE_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE1" + + +@pytest.mark.asyncio +async def test_ifelse_number_ge_false(): + """测试数字大于等于条件为假""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", 5, VariableType.NUMBER, mut=True) + result = await IfElseNode(NUMBER_GE_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE2" + + +# ==================== 布尔比较测试 ==================== +@pytest.mark.asyncio +async def test_ifelse_boolean_eq_true(): + """测试布尔值相等条件为真""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", True, VariableType.BOOLEAN, mut=True) + result = await IfElseNode(BOOLEAN_EQ_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE1" + + +@pytest.mark.asyncio +async def test_ifelse_boolean_eq_false(): + """测试布尔值相等条件为假""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", False, VariableType.BOOLEAN, mut=True) + result = await IfElseNode(BOOLEAN_EQ_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE2" + + +# ==================== 数组比较测试 ==================== +@pytest.mark.asyncio +async def test_ifelse_array_contains_true(): + """测试数组包含条件为真""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", [1, 2, 3], VariableType.ARRAY_NUMBER, mut=True) + result = await IfElseNode(ARRAY_CONTAINS_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE1" + + +@pytest.mark.asyncio +async def test_ifelse_array_contains_false(): + """测试数组包含条件为假""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", [1, 3, 4], VariableType.ARRAY_NUMBER, mut=True) + result = await IfElseNode(ARRAY_CONTAINS_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE2" + + +@pytest.mark.asyncio +async def test_ifelse_array_empty_true(): + """测试数组为空条件为真""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", [], VariableType.ARRAY_NUMBER, mut=True) + result = await IfElseNode(ARRAY_EMPTY_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE1" + + +@pytest.mark.asyncio +async def test_ifelse_array_empty_false(): + """测试数组为空条件为假""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", [1, 2], VariableType.ARRAY_NUMBER, mut=True) + result = await IfElseNode(ARRAY_EMPTY_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE2" + + +# ==================== 对象比较测试 ==================== +@pytest.mark.asyncio +async def test_ifelse_object_empty_true(): + """测试对象为空条件为真""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", {}, VariableType.OBJECT, mut=True) + result = await IfElseNode(OBJECT_EMPTY_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE1" + + +@pytest.mark.asyncio +async def test_ifelse_object_empty_false(): + """测试对象为空条件为假""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", {"key": "value"}, VariableType.OBJECT, mut=True) + result = await IfElseNode(OBJECT_EMPTY_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE2" + + +# ==================== 多条件测试 ==================== +@pytest.mark.asyncio +async def test_ifelse_multi_condition_and_all_true(): + """测试多条件AND逻辑,所有条件为真""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test1", 10, VariableType.NUMBER, mut=True) + await variable_pool.new("conv", "test2", "hello", VariableType.STRING, mut=True) + result = await IfElseNode(MULTI_CONDITION_AND_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE1" + + +@pytest.mark.asyncio +async def test_ifelse_multi_condition_and_one_false(): + """测试多条件AND逻辑,一个条件为假""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test1", 10, VariableType.NUMBER, mut=True) + await variable_pool.new("conv", "test2", "world", VariableType.STRING, mut=True) + result = await IfElseNode(MULTI_CONDITION_AND_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE2" + + +@pytest.mark.asyncio +async def test_ifelse_multi_condition_and_all_false(): + """测试多条件AND逻辑,所有条件为假""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test1", 5, VariableType.NUMBER, mut=True) + await variable_pool.new("conv", "test2", "world", VariableType.STRING, mut=True) + result = await IfElseNode(MULTI_CONDITION_AND_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE2" + + +@pytest.mark.asyncio +async def test_ifelse_multi_condition_or_all_true(): + """测试多条件OR逻辑,所有条件为真""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test1", 10, VariableType.NUMBER, mut=True) + await variable_pool.new("conv", "test2", "hello", VariableType.STRING, mut=True) + result = await IfElseNode(MULTI_CONDITION_OR_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE1" + + +@pytest.mark.asyncio +async def test_ifelse_multi_condition_or_one_true(): + """测试多条件OR逻辑,一个条件为真""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test1", 10, VariableType.NUMBER, mut=True) + await variable_pool.new("conv", "test2", "world", VariableType.STRING, mut=True) + result = await IfElseNode(MULTI_CONDITION_OR_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE1" + + +@pytest.mark.asyncio +async def test_ifelse_multi_condition_or_all_false(): + """测试多条件OR逻辑,所有条件为假""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test1", 5, VariableType.NUMBER, mut=True) + await variable_pool.new("conv", "test2", "world", VariableType.STRING, mut=True) + result = await IfElseNode(MULTI_CONDITION_OR_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE2" + + +# ==================== 多分支测试 ==================== +@pytest.mark.asyncio +async def test_ifelse_multi_branch_first(): + """测试多分支,匹配第一个分支""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", 1, VariableType.NUMBER, mut=True) + result = await IfElseNode(MULTI_BRANCH_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE1" + + +@pytest.mark.asyncio +async def test_ifelse_multi_branch_second(): + """测试多分支,匹配第二个分支""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", 2, VariableType.NUMBER, mut=True) + result = await IfElseNode(MULTI_BRANCH_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE2" + + +@pytest.mark.asyncio +async def test_ifelse_multi_branch_third(): + """测试多分支,匹配第三个分支""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", 3, VariableType.NUMBER, mut=True) + result = await IfElseNode(MULTI_BRANCH_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE3" + + +@pytest.mark.asyncio +async def test_ifelse_multi_branch_default(): + """测试多分支,匹配默认分支""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", 4, VariableType.NUMBER, mut=True) + result = await IfElseNode(MULTI_BRANCH_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE4" + + +# ==================== 变量引用测试 ==================== +@pytest.mark.asyncio +async def test_ifelse_variable_reference_true(): + """测试变量引用条件为真""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test1", 10, VariableType.NUMBER, mut=True) + await variable_pool.new("conv", "test2", 10, VariableType.NUMBER, mut=True) + result = await IfElseNode(VARIABLE_REFERENCE_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE1" + + +@pytest.mark.asyncio +async def test_ifelse_variable_reference_false(): + """测试变量引用条件为假""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test1", 10, VariableType.NUMBER, mut=True) + await variable_pool.new("conv", "test2", 20, VariableType.NUMBER, mut=True) + result = await IfElseNode(VARIABLE_REFERENCE_CONFIG, {}).execute(state, variable_pool) + assert result == "CASE2" + + +# ==================== 边界情况测试 ==================== +@pytest.mark.asyncio +async def test_ifelse_none_variable(): + """测试变量不存在的情况""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + config = { + "id": "ifelse_test", + "type": "if-else", + "name": "条件测试节点", + "config": { + "cases": [ + { + "logical_operator": "and", + "expressions": [ + { + "left": "{{conv.nonexistent}}", + "operator": "eq", + "right": 10, + "input_type": "constant" + } + ] + } + ] + } + } + result = await IfElseNode(config, {}).execute(state, variable_pool) + assert result == "CASE2" + + +@pytest.mark.asyncio +async def test_ifelse_float_comparison(): + """测试浮点数比较""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", 10.5, VariableType.NUMBER, mut=True) + config = { + "id": "ifelse_test", + "type": "if-else", + "name": "条件测试节点", + "config": { + "cases": [ + { + "logical_operator": "and", + "expressions": [ + { + "left": "{{conv.test}}", + "operator": "gt", + "right": 10.0, + "input_type": "constant" + } + ] + } + ] + } + } + result = await IfElseNode(config, {}).execute(state, variable_pool) + assert result == "CASE1" + + +@pytest.mark.asyncio +async def test_ifelse_string_ne(): + """测试字符串不等于""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", "hello", VariableType.STRING, mut=True) + config = { + "id": "ifelse_test", + "type": "if-else", + "name": "条件测试节点", + "config": { + "cases": [ + { + "logical_operator": "and", + "expressions": [ + { + "left": "{{conv.test}}", + "operator": "ne", + "right": "world", + "input_type": "constant" + } + ] + } + ] + } + } + result = await IfElseNode(config, {}).execute(state, variable_pool) + assert result == "CASE1" + + +@pytest.mark.asyncio +async def test_ifelse_number_ne(): + """测试数字不等于""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", 10, VariableType.NUMBER, mut=True) + config = { + "id": "ifelse_test", + "type": "if-else", + "name": "条件测试节点", + "config": { + "cases": [ + { + "logical_operator": "and", + "expressions": [ + { + "left": "{{conv.test}}", + "operator": "ne", + "right": 5, + "input_type": "constant" + } + ] + } + ] + } + } + result = await IfElseNode(config, {}).execute(state, variable_pool) + assert result == "CASE1" + + +@pytest.mark.asyncio +async def test_ifelse_array_not_contains(): + """测试数组不包含""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", [1, 2, 3], VariableType.ARRAY_NUMBER, mut=True) + config = { + "id": "ifelse_test", + "type": "if-else", + "name": "条件测试节点", + "config": { + "cases": [ + { + "logical_operator": "and", + "expressions": [ + { + "left": "{{conv.test}}", + "operator": "not_contains", + "right": 5, + "input_type": "constant" + } + ] + } + ] + } + } + result = await IfElseNode(config, {}).execute(state, variable_pool) + assert result == "CASE1" + + +@pytest.mark.asyncio +async def test_ifelse_string_not_contains(): + """测试字符串不包含""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", "hello", VariableType.STRING, mut=True) + config = { + "id": "ifelse_test", + "type": "if-else", + "name": "条件测试节点", + "config": { + "cases": [ + { + "logical_operator": "and", + "expressions": [ + { + "left": "{{conv.test}}", + "operator": "not_contains", + "right": "world", + "input_type": "constant" + } + ] + } + ] + } + } + result = await IfElseNode(config, {}).execute(state, variable_pool) + assert result == "CASE1" + + +@pytest.mark.asyncio +async def test_ifelse_object_not_empty(): + """测试对象非空""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", {"key": "value"}, VariableType.OBJECT, mut=True) + config = { + "id": "ifelse_test", + "type": "if-else", + "name": "条件测试节点", + "config": { + "cases": [ + { + "logical_operator": "and", + "expressions": [ + { + "left": "{{conv.test}}", + "operator": "not_empty", + "right": "", + "input_type": "constant" + } + ] + } + ] + } + } + result = await IfElseNode(config, {}).execute(state, variable_pool) + assert result == "CASE1" + + +@pytest.mark.asyncio +async def test_ifelse_array_not_empty(): + """测试数组非空""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "test", [1, 2], VariableType.ARRAY_NUMBER, mut=True) + config = { + "id": "ifelse_test", + "type": "if-else", + "name": "条件测试节点", + "config": { + "cases": [ + { + "logical_operator": "and", + "expressions": [ + { + "left": "{{conv.test}}", + "operator": "not_empty", + "right": "", + "input_type": "constant" + } + ] + } + ] + } + } + result = await IfElseNode(config, {}).execute(state, variable_pool) + assert result == "CASE1" diff --git a/api/tests/workflow/nodes/test_jinja_render_node.py b/api/tests/workflow/nodes/test_jinja_render_node.py new file mode 100644 index 00000000..e43c2055 --- /dev/null +++ b/api/tests/workflow/nodes/test_jinja_render_node.py @@ -0,0 +1,889 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/6 +import pytest + +from app.core.workflow.nodes import JinjaRenderNode +from app.core.workflow.variable.base_variable import VariableType +from tests.workflow.nodes.base import simple_state, simple_vairable_pool + + +# 基础模板渲染配置 +SIMPLE_TEMPLATE_CONFIG = { + "id": "jinja_test", + "type": "jinja-render", + "name": "Jinja渲染测试节点", + "config": { + "template": "Hello, {{ name }}!", + "mapping": [ + { + "name": "name", + "value": "conv.username" + } + ] + } +} + +# 多变量模板配置 +MULTI_VARIABLE_CONFIG = { + "id": "jinja_test", + "type": "jinja-render", + "name": "Jinja渲染测试节点", + "config": { + "template": "{{ greeting }}, {{ name }}! You are {{ age }} years old.", + "mapping": [ + { + "name": "greeting", + "value": "conv.greeting" + }, + { + "name": "name", + "value": "conv.name" + }, + { + "name": "age", + "value": "conv.age" + } + ] + } +} + +# 条件渲染配置 +CONDITIONAL_TEMPLATE_CONFIG = { + "id": "jinja_test", + "type": "jinja-render", + "name": "Jinja渲染测试节点", + "config": { + "template": "{% if is_admin %}Admin{% else %}User{% endif %}", + "mapping": [ + { + "name": "is_admin", + "value": "conv.is_admin" + } + ] + } +} + +# 循环渲染配置 +LOOP_TEMPLATE_CONFIG = { + "id": "jinja_test", + "type": "jinja-render", + "name": "Jinja渲染测试节点", + "config": { + "template": "{% for item in items %}{{ item }}{% if not loop.last %}, {% endif %}{% endfor %}", + "mapping": [ + { + "name": "items", + "value": "conv.items" + } + ] + } +} + +# 过滤器配置 +FILTER_TEMPLATE_CONFIG = { + "id": "jinja_test", + "type": "jinja-render", + "name": "Jinja渲染测试节点", + "config": { + "template": "{{ text | upper }}", + "mapping": [ + { + "name": "text", + "value": "conv.text" + } + ] + } +} + +# 对象属性访问配置 +OBJECT_TEMPLATE_CONFIG = { + "id": "jinja_test", + "type": "jinja-render", + "name": "Jinja渲染测试节点", + "config": { + "template": "Name: {{ user.name }}, Age: {{ user.age }}", + "mapping": [ + { + "name": "user", + "value": "conv.user" + } + ] + } +} + +# 数学运算配置 +MATH_TEMPLATE_CONFIG = { + "id": "jinja_test", + "type": "jinja-render", + "name": "Jinja渲染测试节点", + "config": { + "template": "{{ a }} + {{ b }} = {{ a + b }}", + "mapping": [ + { + "name": "a", + "value": "conv.a" + }, + { + "name": "b", + "value": "conv.b" + } + ] + } +} + +# 默认值配置 +DEFAULT_VALUE_CONFIG = { + "id": "jinja_test", + "type": "jinja-render", + "name": "Jinja渲染测试节点", + "config": { + "template": "{{ name | default('Guest') }}", + "mapping": [ + { + "name": "name", + "value": "conv.name" + } + ] + } +} + + +# ==================== 基础模板渲染测试 ==================== +@pytest.mark.asyncio +async def test_jinja_simple_template(): + """测试简单模板渲染""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "username", "Alice", VariableType.STRING, mut=True) + + result = await JinjaRenderNode(SIMPLE_TEMPLATE_CONFIG, {}).execute(state, variable_pool) + assert result == "Hello, Alice!" + + +@pytest.mark.asyncio +async def test_jinja_multi_variable(): + """测试多变量模板渲染""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "greeting", "Hi", VariableType.STRING, mut=True) + await variable_pool.new("conv", "name", "Bob", VariableType.STRING, mut=True) + await variable_pool.new("conv", "age", 25, VariableType.NUMBER, mut=True) + + result = await JinjaRenderNode(MULTI_VARIABLE_CONFIG, {}).execute(state, variable_pool) + assert result == "Hi, Bob! You are 25 years old." + + +# ==================== 条件渲染测试 ==================== +@pytest.mark.asyncio +async def test_jinja_conditional_true(): + """测试条件渲染为真""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "is_admin", True, VariableType.BOOLEAN, mut=True) + + result = await JinjaRenderNode(CONDITIONAL_TEMPLATE_CONFIG, {}).execute(state, variable_pool) + assert result == "Admin" + + +@pytest.mark.asyncio +async def test_jinja_conditional_false(): + """测试条件渲染为假""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "is_admin", False, VariableType.BOOLEAN, mut=True) + + result = await JinjaRenderNode(CONDITIONAL_TEMPLATE_CONFIG, {}).execute(state, variable_pool) + assert result == "User" + + +# ==================== 循环渲染测试 ==================== +@pytest.mark.asyncio +async def test_jinja_loop_array(): + """测试数组循环渲染""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "items", ["apple", "banana", "cherry"], VariableType.ARRAY_STRING, mut=True) + + result = await JinjaRenderNode(LOOP_TEMPLATE_CONFIG, {}).execute(state, variable_pool) + assert result == "apple, banana, cherry" + + +@pytest.mark.asyncio +async def test_jinja_loop_empty_array(): + """测试空数组循环渲染""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "items", [], VariableType.ARRAY_STRING, mut=True) + + result = await JinjaRenderNode(LOOP_TEMPLATE_CONFIG, {}).execute(state, variable_pool) + assert result == "" + + +@pytest.mark.asyncio +async def test_jinja_loop_single_item(): + """测试单元素数组循环渲染""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "items", ["apple"], VariableType.ARRAY_STRING, mut=True) + + result = await JinjaRenderNode(LOOP_TEMPLATE_CONFIG, {}).execute(state, variable_pool) + assert result == "apple" + + +# ==================== 过滤器测试 ==================== +@pytest.mark.asyncio +async def test_jinja_filter_upper(): + """测试大写过滤器""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "text", "hello world", VariableType.STRING, mut=True) + + result = await JinjaRenderNode(FILTER_TEMPLATE_CONFIG, {}).execute(state, variable_pool) + assert result == "HELLO WORLD" + + +@pytest.mark.asyncio +async def test_jinja_filter_lower(): + """测试小写过滤器""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "text", "HELLO WORLD", VariableType.STRING, mut=True) + + config = { + "id": "jinja_test", + "type": "jinja-render", + "name": "Jinja渲染测试节点", + "config": { + "template": "{{ text | lower }}", + "mapping": [ + { + "name": "text", + "value": "conv.text" + } + ] + } + } + result = await JinjaRenderNode(config, {}).execute(state, variable_pool) + assert result == "hello world" + + +@pytest.mark.asyncio +async def test_jinja_filter_title(): + """测试标题化过滤器""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "text", "hello world", VariableType.STRING, mut=True) + + config = { + "id": "jinja_test", + "type": "jinja-render", + "name": "Jinja渲染测试节点", + "config": { + "template": "{{ text | title }}", + "mapping": [ + { + "name": "text", + "value": "conv.text" + } + ] + } + } + result = await JinjaRenderNode(config, {}).execute(state, variable_pool) + assert result == "Hello World" + + +@pytest.mark.asyncio +async def test_jinja_filter_length(): + """测试长度过滤器""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "items", [1, 2, 3, 4, 5], VariableType.ARRAY_NUMBER, mut=True) + + config = { + "id": "jinja_test", + "type": "jinja-render", + "name": "Jinja渲染测试节点", + "config": { + "template": "Length: {{ items | length }}", + "mapping": [ + { + "name": "items", + "value": "conv.items" + } + ] + } + } + result = await JinjaRenderNode(config, {}).execute(state, variable_pool) + assert result == "Length: 5" + + +# ==================== 对象属性访问测试 ==================== +@pytest.mark.asyncio +async def test_jinja_object_access(): + """测试对象属性访问""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "user", {"name": "Alice", "age": 30}, VariableType.OBJECT, mut=True) + + result = await JinjaRenderNode(OBJECT_TEMPLATE_CONFIG, {}).execute(state, variable_pool) + assert result == "Name: Alice, Age: 30" + + +@pytest.mark.asyncio +async def test_jinja_nested_object(): + """测试嵌套对象访问""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "data", { + "user": { + "name": "Bob", + "address": { + "city": "Beijing" + } + } + }, VariableType.OBJECT, mut=True) + + config = { + "id": "jinja_test", + "type": "jinja-render", + "name": "Jinja渲染测试节点", + "config": { + "template": "{{ data.user.name }} lives in {{ data.user.address.city }}", + "mapping": [ + { + "name": "data", + "value": "conv.data" + } + ] + } + } + result = await JinjaRenderNode(config, {}).execute(state, variable_pool) + assert result == "Bob lives in Beijing" + + +# ==================== 数学运算测试 ==================== +@pytest.mark.asyncio +async def test_jinja_math_addition(): + """测试加法运算""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "a", 10, VariableType.NUMBER, mut=True) + await variable_pool.new("conv", "b", 20, VariableType.NUMBER, mut=True) + + result = await JinjaRenderNode(MATH_TEMPLATE_CONFIG, {}).execute(state, variable_pool) + assert result == "10 + 20 = 30" + + +@pytest.mark.asyncio +async def test_jinja_math_subtraction(): + """测试减法运算""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "a", 30, VariableType.NUMBER, mut=True) + await variable_pool.new("conv", "b", 10, VariableType.NUMBER, mut=True) + + config = { + "id": "jinja_test", + "type": "jinja-render", + "name": "Jinja渲染测试节点", + "config": { + "template": "{{ a }} - {{ b }} = {{ a - b }}", + "mapping": [ + { + "name": "a", + "value": "conv.a" + }, + { + "name": "b", + "value": "conv.b" + } + ] + } + } + result = await JinjaRenderNode(config, {}).execute(state, variable_pool) + assert result == "30 - 10 = 20" + + +@pytest.mark.asyncio +async def test_jinja_math_multiplication(): + """测试乘法运算""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "a", 5, VariableType.NUMBER, mut=True) + await variable_pool.new("conv", "b", 6, VariableType.NUMBER, mut=True) + + config = { + "id": "jinja_test", + "type": "jinja-render", + "name": "Jinja渲染测试节点", + "config": { + "template": "{{ a }} * {{ b }} = {{ a * b }}", + "mapping": [ + { + "name": "a", + "value": "conv.a" + }, + { + "name": "b", + "value": "conv.b" + } + ] + } + } + result = await JinjaRenderNode(config, {}).execute(state, variable_pool) + assert result == "5 * 6 = 30" + + +@pytest.mark.asyncio +async def test_jinja_math_division(): + """测试除法运算""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "a", 20, VariableType.NUMBER, mut=True) + await variable_pool.new("conv", "b", 4, VariableType.NUMBER, mut=True) + + config = { + "id": "jinja_test", + "type": "jinja-render", + "name": "Jinja渲染测试节点", + "config": { + "template": "{{ a }} / {{ b }} = {{ a / b }}", + "mapping": [ + { + "name": "a", + "value": "conv.a" + }, + { + "name": "b", + "value": "conv.b" + } + ] + } + } + result = await JinjaRenderNode(config, {}).execute(state, variable_pool) + assert result == "20 / 4 = 5.0" + + +# ==================== 默认值测试 ==================== +@pytest.mark.asyncio +async def test_jinja_default_value_missing(): + """测试变量缺失时使用默认值""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + # 不创建 name 变量 + + result = await JinjaRenderNode(DEFAULT_VALUE_CONFIG, {}).execute(state, variable_pool) + assert result == "Guest" + + +@pytest.mark.asyncio +async def test_jinja_default_value_present(): + """测试变量存在时不使用默认值""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "name", "Alice", VariableType.STRING, mut=True) + + result = await JinjaRenderNode(DEFAULT_VALUE_CONFIG, {}).execute(state, variable_pool) + assert result == "Alice" + + +# ==================== 字符串拼接测试 ==================== +@pytest.mark.asyncio +async def test_jinja_string_concatenation(): + """测试字符串拼接""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "first", "Hello", VariableType.STRING, mut=True) + await variable_pool.new("conv", "second", "World", VariableType.STRING, mut=True) + + config = { + "id": "jinja_test", + "type": "jinja-render", + "name": "Jinja渲染测试节点", + "config": { + "template": "{{ first ~ ' ' ~ second }}", + "mapping": [ + { + "name": "first", + "value": "conv.first" + }, + { + "name": "second", + "value": "conv.second" + } + ] + } + } + result = await JinjaRenderNode(config, {}).execute(state, variable_pool) + assert result == "Hello World" + + +# ==================== 比较运算测试 ==================== +@pytest.mark.asyncio +async def test_jinja_comparison(): + """测试比较运算""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "score", 85, VariableType.NUMBER, mut=True) + + config = { + "id": "jinja_test", + "type": "jinja-render", + "name": "Jinja渲染测试节点", + "config": { + "template": "{% if score >= 90 %}A{% elif score >= 80 %}B{% elif score >= 70 %}C{% else %}D{% endif %}", + "mapping": [ + { + "name": "score", + "value": "conv.score" + } + ] + } + } + result = await JinjaRenderNode(config, {}).execute(state, variable_pool) + assert result == "B" + + +# ==================== 数组操作测试 ==================== +@pytest.mark.asyncio +async def test_jinja_array_index(): + """测试数组索引访问""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "items", ["first", "second", "third"], VariableType.ARRAY_STRING, mut=True) + + config = { + "id": "jinja_test", + "type": "jinja-render", + "name": "Jinja渲染测试节点", + "config": { + "template": "First: {{ items[0] }}, Last: {{ items[-1] }}", + "mapping": [ + { + "name": "items", + "value": "conv.items" + } + ] + } + } + result = await JinjaRenderNode(config, {}).execute(state, variable_pool) + assert result == "First: first, Last: third" + + +@pytest.mark.asyncio +async def test_jinja_array_slice(): + """测试数组切片""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "numbers", [1, 2, 3, 4, 5], VariableType.ARRAY_NUMBER, mut=True) + + config = { + "id": "jinja_test", + "type": "jinja-render", + "name": "Jinja渲染测试节点", + "config": { + "template": "{% for n in numbers[1:4] %}{{ n }}{% endfor %}", + "mapping": [ + { + "name": "numbers", + "value": "conv.numbers" + } + ] + } + } + result = await JinjaRenderNode(config, {}).execute(state, variable_pool) + assert result == "234" + + +# ==================== 复杂模板测试 ==================== +@pytest.mark.asyncio +async def test_jinja_complex_template(): + """测试复杂模板""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "users", [ + {"name": "Alice", "age": 25}, + {"name": "Bob", "age": 30}, + {"name": "Charlie", "age": 35} + ], VariableType.ARRAY_OBJECT, mut=True) + + config = { + "id": "jinja_test", + "type": "jinja-render", + "name": "Jinja渲染测试节点", + "config": { + "template": "{% for user in users %}{{ user.name }} ({{ user.age }}){% if not loop.last %}, {% endif %}{% endfor %}", + "mapping": [ + { + "name": "users", + "value": "conv.users" + } + ] + } + } + result = await JinjaRenderNode(config, {}).execute(state, variable_pool) + assert result == "Alice (25), Bob (30), Charlie (35)" + + +# ==================== 空值处理测试 ==================== +@pytest.mark.asyncio +async def test_jinja_empty_string(): + """测试空字符串""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "text", "", VariableType.STRING, mut=True) + + config = { + "id": "jinja_test", + "type": "jinja-render", + "name": "Jinja渲染测试节点", + "config": { + "template": "{% if text %}{{ text }}{% else %}Empty{% endif %}", + "mapping": [ + { + "name": "text", + "value": "conv.text" + } + ] + } + } + result = await JinjaRenderNode(config, {}).execute(state, variable_pool) + assert result == "Empty" + + +@pytest.mark.asyncio +async def test_jinja_zero_value(): + """测试零值""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "count", 0, VariableType.NUMBER, mut=True) + + config = { + "id": "jinja_test", + "type": "jinja-render", + "name": "Jinja渲染测试节点", + "config": { + "template": "Count: {{ count }}", + "mapping": [ + { + "name": "count", + "value": "conv.count" + } + ] + } + } + result = await JinjaRenderNode(config, {}).execute(state, variable_pool) + assert result == "Count: 0" + + +# ==================== 特殊字符测试 ==================== +@pytest.mark.asyncio +async def test_jinja_special_characters(): + """测试特殊字符""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "text", "Hello \"World\"", VariableType.STRING, mut=True) + + config = { + "id": "jinja_test", + "type": "jinja-render", + "name": "Jinja渲染测试节点", + "config": { + "template": "{{ text }}", + "mapping": [ + { + "name": "text", + "value": "conv.text" + } + ] + } + } + result = await JinjaRenderNode(config, {}).execute(state, variable_pool) + assert result == "Hello \"World\"" + + +@pytest.mark.asyncio +async def test_jinja_newline(): + """测试换行符""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "line1", "First line", VariableType.STRING, mut=True) + await variable_pool.new("conv", "line2", "Second line", VariableType.STRING, mut=True) + + config = { + "id": "jinja_test", + "type": "jinja-render", + "name": "Jinja渲染测试节点", + "config": { + "template": "{{ line1 }}\n{{ line2 }}", + "mapping": [ + { + "name": "line1", + "value": "conv.line1" + }, + { + "name": "line2", + "value": "conv.line2" + } + ] + } + } + result = await JinjaRenderNode(config, {}).execute(state, variable_pool) + assert result == "First line\nSecond line" + + +# ==================== 错误处理测试 ==================== +@pytest.mark.asyncio +async def test_jinja_invalid_template(): + """测试无效模板语法""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "name", "Alice", VariableType.STRING, mut=True) + + config = { + "id": "jinja_test", + "type": "jinja-render", + "name": "Jinja渲染测试节点", + "config": { + "template": "{{ name", # 缺少闭合括号 + "mapping": [ + { + "name": "name", + "value": "conv.name" + } + ] + } + } + with pytest.raises(RuntimeError) as exc_info: + await JinjaRenderNode(config, {}).execute(state, variable_pool) + assert "render failed" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_jinja_undefined_variable_strict_false(): + """测试未定义变量(非严格模式)""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + # 不创建任何变量 + + config = { + "id": "jinja_test", + "type": "jinja-render", + "name": "Jinja渲染测试节点", + "config": { + "template": "Hello, {{ undefined_var }}!", + "mapping": [ + { + "name": "undefined_var", + "value": "conv.undefined" + } + ] + } + } + # 非严格模式下,未定义变量会被渲染为空字符串 + result = await JinjaRenderNode(config, {}).execute(state, variable_pool) + assert result == "Hello, !" + + +# ==================== 布尔值测试 ==================== +@pytest.mark.asyncio +async def test_jinja_boolean_true(): + """测试布尔值 True""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "flag", True, VariableType.BOOLEAN, mut=True) + + config = { + "id": "jinja_test", + "type": "jinja-render", + "name": "Jinja渲染测试节点", + "config": { + "template": "Flag is {{ flag }}", + "mapping": [ + { + "name": "flag", + "value": "conv.flag" + } + ] + } + } + result = await JinjaRenderNode(config, {}).execute(state, variable_pool) + assert result == "Flag is True" + + +@pytest.mark.asyncio +async def test_jinja_boolean_false(): + """测试布尔值 False""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "flag", False, VariableType.BOOLEAN, mut=True) + + config = { + "id": "jinja_test", + "type": "jinja-render", + "name": "Jinja渲染测试节点", + "config": { + "template": "Flag is {{ flag }}", + "mapping": [ + { + "name": "flag", + "value": "conv.flag" + } + ] + } + } + result = await JinjaRenderNode(config, {}).execute(state, variable_pool) + assert result == "Flag is False" + + +# ==================== 浮点数测试 ==================== +@pytest.mark.asyncio +async def test_jinja_float_number(): + """测试浮点数""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "price", 19.99, VariableType.NUMBER, mut=True) + + config = { + "id": "jinja_test", + "type": "jinja-render", + "name": "Jinja渲染测试节点", + "config": { + "template": "Price: ${{ price }}", + "mapping": [ + { + "name": "price", + "value": "conv.price" + } + ] + } + } + result = await JinjaRenderNode(config, {}).execute(state, variable_pool) + assert result == "Price: $19.99" + + +@pytest.mark.asyncio +async def test_jinja_float_formatting(): + """测试浮点数格式化""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "value", 3.14159, VariableType.NUMBER, mut=True) + + config = { + "id": "jinja_test", + "type": "jinja-render", + "name": "Jinja渲染测试节点", + "config": { + "template": "{{ '%.2f' | format(value) }}", + "mapping": [ + { + "name": "value", + "value": "conv.value" + } + ] + } + } + result = await JinjaRenderNode(config, {}).execute(state, variable_pool) + assert result == "3.14" diff --git a/api/tests/workflow/nodes/test_llm_node.py b/api/tests/workflow/nodes/test_llm_node.py new file mode 100644 index 00000000..c97cde26 --- /dev/null +++ b/api/tests/workflow/nodes/test_llm_node.py @@ -0,0 +1,145 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/5 15:39 +import pytest + +from app.core.workflow.nodes import LLMNode +from tests.workflow.nodes.base import TEST_MODEL_ID, simple_state, simple_vairable_pool + + +@pytest.mark.asyncio +async def test_llm_memory_no_stream(): + node_config = { + "id": "llm_test", + "type": "llm", + "name": "LLM 问答", + "config": { + "messages": [ + { + "role": "system", + "content": "你是一个专业、友好且乐于助人的 AI 助手。" + "你的职责:- " + "准确理解用户的问题并提供有价值的回答" + "- 保持回答的专业性和准确性" + "- 如果不确定答案,诚实地告知用户" + "- 使用清晰、易懂的语言进行交流" + "回答风格:" + "- 简洁明了,直击要点" + "- 必要时提供详细解释和示例" + "- 使用友好、礼貌的语气" + "- 适当使用格式化(如列表、段落)提高可读性" + }, + { + "role": "user", + "content": "{{ sys.message }}" + } + ], + "model_id": TEST_MODEL_ID, + "temperature": 0.7, + "max_tokens": 1000, + "memory": { + "enable": True, + "enable_window": True, + "window_size": 5 + }, + "vision": False, + "vision_input": "{{sys.files}}" + } + } + state = simple_state() + variable_pool = await simple_vairable_pool("输出上一句话") + result = await LLMNode(node_config, {}).execute(state, variable_pool) + assert '123456' in result.content + + +@pytest.mark.asyncio +async def test_llm_memory_stream(): + node_config = { + "id": "llm_test", + "type": "llm", + "name": "LLM 问答", + "config": { + "messages": [ + { + "role": "system", + "content": "你是一个专业、友好且乐于助人的 AI 助手。" + "你的职责:- " + "准确理解用户的问题并提供有价值的回答" + "- 保持回答的专业性和准确性" + "- 如果不确定答案,诚实地告知用户" + "- 使用清晰、易懂的语言进行交流" + "回答风格:" + "- 简洁明了,直击要点" + "- 必要时提供详细解释和示例" + "- 使用友好、礼貌的语气" + "- 适当使用格式化(如列表、段落)提高可读性" + }, + { + "role": "user", + "content": "{{ sys.message }}" + } + ], + "model_id": TEST_MODEL_ID, + "temperature": 0.7, + "max_tokens": 1000, + "memory": { + "enable": True, + "enable_window": True, + "window_size": 5 + }, + "vision": False, + "vision_input": "{{sys.files}}" + } + } + state = simple_state() + variable_pool = await simple_vairable_pool("输出上一句话") + async for event in LLMNode(node_config, {}).execute_stream(state, variable_pool): + if event.get("__final__"): + assert '123456' in event.get("result").content + + +@pytest.mark.asyncio +async def test_llm_vision(): + node_config = { + "id": "llm_test", + "type": "llm", + "name": "LLM 问答", + "config": { + "messages": [ + { + "role": "system", + "content": "你是一个专业、友好且乐于助人的 AI 助手。" + "你的职责:- " + "准确理解用户的问题并提供有价值的回答" + "- 保持回答的专业性和准确性" + "- 如果不确定答案,诚实地告知用户" + "- 使用清晰、易懂的语言进行交流" + "回答风格:" + "- 简洁明了,直击要点" + "- 必要时提供详细解释和示例" + "- 使用友好、礼貌的语气" + "- 适当使用格式化(如列表、段落)提高可读性" + }, + { + "role": "user", + "content": "{{ sys.message }}" + } + ], + "model_id": TEST_MODEL_ID, + "temperature": 0.7, + "max_tokens": 1000, + "memory": { + "enable": True, + "enable_window": True, + "window_size": 5 + }, + "vision": True, + "vision_input": "{{sys.files}}" + } + } + state = simple_state() + variable_pool = await simple_vairable_pool("图片里面有什么") + async for event in LLMNode(node_config, {}).execute_stream(state, variable_pool): + if event.get("__final__"): + assert '花' in event.get("result").content diff --git a/api/tests/workflow/nodes/test_parameter_extractor_node.py b/api/tests/workflow/nodes/test_parameter_extractor_node.py new file mode 100644 index 00000000..c9b775a9 --- /dev/null +++ b/api/tests/workflow/nodes/test_parameter_extractor_node.py @@ -0,0 +1,504 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/6 14:10 +import pytest + +from app.core.workflow.nodes import ParameterExtractorNode +from app.core.workflow.variable.base_variable import VariableType +from tests.workflow.nodes.base import TEST_MODEL_ID, simple_state, simple_vairable_pool + + +# 基础参数提取配置 - 单个字符串参数 +SINGLE_STRING_PARAM_CONFIG = { + "id": "param_extractor_test", + "type": "parameter-extractor", + "name": "参数提取测试节点", + "config": { + "model_id": TEST_MODEL_ID, + "text": "我的名字是张三,今年25岁", + "params": [ + { + "name": "name", + "type": "string", + "desc": "用户的姓名", + "required": True + } + ], + "prompt": "" + } +} + +# 多参数提取配置 +MULTI_PARAMS_CONFIG = { + "id": "param_extractor_test", + "type": "parameter-extractor", + "name": "参数提取测试节点", + "config": { + "model_id": TEST_MODEL_ID, + "text": "我的名字是李四,今年30岁,住在北京", + "params": [ + { + "name": "name", + "type": "string", + "desc": "用户的姓名", + "required": True + }, + { + "name": "age", + "type": "number", + "desc": "用户的年龄", + "required": True + }, + { + "name": "city", + "type": "string", + "desc": "用户所在的城市", + "required": False + } + ], + "prompt": "" + } +} + +# 数字参数提取配置 +NUMBER_PARAM_CONFIG = { + "id": "param_extractor_test", + "type": "parameter-extractor", + "name": "参数提取测试节点", + "config": { + "model_id": TEST_MODEL_ID, + "text": "这个产品的价格是99.99元,库存有100件", + "params": [ + { + "name": "price", + "type": "number", + "desc": "产品价格", + "required": True + }, + { + "name": "stock", + "type": "number", + "desc": "库存数量", + "required": True + } + ], + "prompt": "" + } +} + +# 布尔参数提取配置 +BOOLEAN_PARAM_CONFIG = { + "id": "param_extractor_test", + "type": "parameter-extractor", + "name": "参数提取测试节点", + "config": { + "model_id": TEST_MODEL_ID, + "text": "这个用户已经完成了实名认证,但还没有绑定手机号", + "params": [ + { + "name": "verified", + "type": "boolean", + "desc": "是否完成实名认证", + "required": True + }, + { + "name": "phone_bound", + "type": "boolean", + "desc": "是否绑定手机号", + "required": True + } + ], + "prompt": "" + } +} + +# 数组参数提取配置 +ARRAY_STRING_PARAM_CONFIG = { + "id": "param_extractor_test", + "type": "parameter-extractor", + "name": "参数提取测试节点", + "config": { + "model_id": TEST_MODEL_ID, + "text": "我喜欢的水果有苹果、香蕉、橙子", + "params": [ + { + "name": "fruits", + "type": "array[string]", + "desc": "喜欢的水果列表", + "required": True + } + ], + "prompt": "" + } +} + +# 数字数组参数提取配置 +ARRAY_NUMBER_PARAM_CONFIG = { + "id": "param_extractor_test", + "type": "parameter-extractor", + "name": "参数提取测试节点", + "config": { + "model_id": TEST_MODEL_ID, + "text": "这个月的销售额分别是:第一周10000,第二周12000,第三周15000,第四周18000", + "params": [ + { + "name": "weekly_sales", + "type": "array[number]", + "desc": "每周的销售额", + "required": True + } + ], + "prompt": "" + } +} + +# 带自定义提示的配置 +CUSTOM_PROMPT_CONFIG = { + "id": "param_extractor_test", + "type": "parameter-extractor", + "name": "参数提取测试节点", + "config": { + "model_id": TEST_MODEL_ID, + "text": "订单号:ORD123456,金额:299元", + "params": [ + { + "name": "order_id", + "type": "string", + "desc": "订单编号", + "required": True + }, + { + "name": "amount", + "type": "number", + "desc": "订单金额", + "required": True + } + ], + "prompt": "请仔细提取订单信息,确保订单号和金额准确无误" + } +} + +# 使用变量的配置 +VARIABLE_INPUT_CONFIG = { + "id": "param_extractor_test", + "type": "parameter-extractor", + "name": "参数提取测试节点", + "config": { + "model_id": TEST_MODEL_ID, + "text": "{{ conv.user_input }}", + "params": [ + { + "name": "name", + "type": "string", + "desc": "用户姓名", + "required": True + }, + { + "name": "age", + "type": "number", + "desc": "用户年龄", + "required": True + } + ], + "prompt": "" + } +} + + +# ==================== 基础参数提取测试 ==================== +@pytest.mark.asyncio +async def test_extract_single_string_param(): + """测试提取单个字符串参数""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + result = await ParameterExtractorNode(SINGLE_STRING_PARAM_CONFIG, {}).execute(state, variable_pool) + + assert isinstance(result, dict) + assert "name" in result + assert isinstance(result["name"], str) + assert "张三" in result["name"] + + +@pytest.mark.asyncio +async def test_extract_multi_params(): + """测试提取多个参数""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + result = await ParameterExtractorNode(MULTI_PARAMS_CONFIG, {}).execute(state, variable_pool) + + assert isinstance(result, dict) + assert "name" in result + assert "age" in result + assert "city" in result + assert isinstance(result["name"], str) + assert isinstance(result["age"], (int, float)) + assert "李四" in result["name"] + assert result["age"] == 30 + assert "北京" in result["city"] + + +# ==================== 数字参数提取测试 ==================== +@pytest.mark.asyncio +async def test_extract_number_params(): + """测试提取数字参数""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + result = await ParameterExtractorNode(NUMBER_PARAM_CONFIG, {}).execute(state, variable_pool) + + assert isinstance(result, dict) + assert "price" in result + assert "stock" in result + assert isinstance(result["price"], (int, float)) + assert isinstance(result["stock"], (int, float)) + assert abs(result["price"] - 99.99) < 0.1 + assert result["stock"] == 100 + + +# ==================== 布尔参数提取测试 ==================== +@pytest.mark.asyncio +async def test_extract_boolean_params(): + """测试提取布尔参数""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + result = await ParameterExtractorNode(BOOLEAN_PARAM_CONFIG, {}).execute(state, variable_pool) + + assert isinstance(result, dict) + assert "verified" in result + assert "phone_bound" in result + assert isinstance(result["verified"], bool) + assert isinstance(result["phone_bound"], bool) + assert result["verified"] is True + assert result["phone_bound"] is False + + +# ==================== 数组参数提取测试 ==================== +@pytest.mark.asyncio +async def test_extract_array_string_param(): + """测试提取字符串数组参数""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + result = await ParameterExtractorNode(ARRAY_STRING_PARAM_CONFIG, {}).execute(state, variable_pool) + + assert isinstance(result, dict) + assert "fruits" in result + assert isinstance(result["fruits"], list) + assert len(result["fruits"]) >= 3 + assert "苹果" in result["fruits"] + assert "香蕉" in result["fruits"] + assert "橙子" in result["fruits"] + + +@pytest.mark.asyncio +async def test_extract_array_number_param(): + """测试提取数字数组参数""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + result = await ParameterExtractorNode(ARRAY_NUMBER_PARAM_CONFIG, {}).execute(state, variable_pool) + + assert isinstance(result, dict) + assert "weekly_sales" in result + assert isinstance(result["weekly_sales"], list) + assert len(result["weekly_sales"]) == 4 + assert 10000 in result["weekly_sales"] + assert 12000 in result["weekly_sales"] + assert 15000 in result["weekly_sales"] + assert 18000 in result["weekly_sales"] + + +# ==================== 自定义提示测试 ==================== +@pytest.mark.asyncio +async def test_extract_with_custom_prompt(): + """测试使用自定义提示提取参数""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + result = await ParameterExtractorNode(CUSTOM_PROMPT_CONFIG, {}).execute(state, variable_pool) + + assert isinstance(result, dict) + assert "order_id" in result + assert "amount" in result + assert "ORD123456" in result["order_id"] + assert isinstance(result["amount"], (int, float)) + assert result["amount"] == 299 + + +# ==================== 变量输入测试 ==================== +@pytest.mark.asyncio +async def test_extract_with_variable_input(): + """测试使用变量作为输入文本""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "user_input", "我叫王五,今年28岁", VariableType.STRING, mut=True) + + result = await ParameterExtractorNode(VARIABLE_INPUT_CONFIG, {}).execute(state, variable_pool) + + assert isinstance(result, dict) + assert "name" in result + assert "age" in result + assert "王五" in result["name"] + assert result["age"] == 28 + + +# ==================== 复杂场景测试 ==================== +@pytest.mark.asyncio +async def test_extract_from_complex_text(): + """测试从复杂文本中提取参数""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + config = { + "id": "param_extractor_test", + "type": "parameter-extractor", + "name": "参数提取测试节点", + "config": { + "model_id": TEST_MODEL_ID, + "text": """ + 客户信息: + 姓名:赵六 + 年龄:35岁 + 职业:软件工程师 + 城市:上海 + 邮箱:zhaoliu@example.com + 是否VIP:是 + """, + "params": [ + { + "name": "name", + "type": "string", + "desc": "客户姓名", + "required": True + }, + { + "name": "age", + "type": "number", + "desc": "客户年龄", + "required": True + }, + { + "name": "occupation", + "type": "string", + "desc": "客户职业", + "required": False + }, + { + "name": "city", + "type": "string", + "desc": "所在城市", + "required": False + }, + { + "name": "is_vip", + "type": "boolean", + "desc": "是否为VIP客户", + "required": False + } + ], + "prompt": "" + } + } + + result = await ParameterExtractorNode(config, {}).execute(state, variable_pool) + + assert isinstance(result, dict) + assert "name" in result + assert "age" in result + assert "赵六" in result["name"] + assert result["age"] == 35 + if "occupation" in result: + assert "工程师" in result["occupation"] + if "city" in result: + assert "上海" in result["city"] + if "is_vip" in result: + assert result["is_vip"] is True + + +@pytest.mark.asyncio +async def test_extract_optional_params(): + """测试提取可选参数""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + config = { + "id": "param_extractor_test", + "type": "parameter-extractor", + "name": "参数提取测试节点", + "config": { + "model_id": TEST_MODEL_ID, + "text": "我叫小明", + "params": [ + { + "name": "name", + "type": "string", + "desc": "用户姓名", + "required": True + }, + { + "name": "age", + "type": "number", + "desc": "用户年龄", + "required": False + }, + { + "name": "city", + "type": "string", + "desc": "所在城市", + "required": False + } + ], + "prompt": "" + } + } + + result = await ParameterExtractorNode(config, {}).execute(state, variable_pool) + + assert isinstance(result, dict) + assert "name" in result + assert "小明" in result["name"] + # 可选参数可能不存在或为 None + + +@pytest.mark.asyncio +async def test_extract_with_sys_message(): + """测试使用系统消息变量""" + state = simple_state() + variable_pool = await simple_vairable_pool("我叫小红,今年22岁") + + config = { + "id": "param_extractor_test", + "type": "parameter-extractor", + "name": "参数提取测试节点", + "config": { + "model_id": TEST_MODEL_ID, + "text": "{{ sys.message }}", + "params": [ + { + "name": "name", + "type": "string", + "desc": "用户姓名", + "required": True + }, + { + "name": "age", + "type": "number", + "desc": "用户年龄", + "required": True + } + ], + "prompt": "" + } + } + + result = await ParameterExtractorNode(config, {}).execute(state, variable_pool) + + assert isinstance(result, dict) + assert "name" in result + assert "age" in result + assert "小红" in result["name"] + assert result["age"] == 22 diff --git a/api/tests/workflow/nodes/test_question_classifier_node.py b/api/tests/workflow/nodes/test_question_classifier_node.py new file mode 100644 index 00000000..777033ae --- /dev/null +++ b/api/tests/workflow/nodes/test_question_classifier_node.py @@ -0,0 +1,647 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/6 +import pytest + +from app.core.workflow.nodes import QuestionClassifierNode +from app.core.workflow.variable.base_variable import VariableType +from tests.workflow.nodes.base import TEST_MODEL_ID, simple_state, simple_vairable_pool + + +# 基础分类配置 - 两个类别 +BASIC_TWO_CATEGORIES_CONFIG = { + "id": "classifier_test", + "type": "question-classifier", + "name": "问题分类测试节点", + "config": { + "model_id": TEST_MODEL_ID, + "input_variable": "我想买一台笔记本电脑", + "categories": [ + { + "class_name": "产品咨询" + }, + { + "class_name": "售后服务" + } + ], + "system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。", + "user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。", + "user_supplement_prompt": None + } +} + +# 多类别配置 +MULTI_CATEGORIES_CONFIG = { + "id": "classifier_test", + "type": "question-classifier", + "name": "问题分类测试节点", + "config": { + "model_id": TEST_MODEL_ID, + "input_variable": "我的订单什么时候能到?", + "categories": [ + { + "class_name": "产品咨询" + }, + { + "class_name": "订单查询" + }, + { + "class_name": "售后服务" + }, + { + "class_name": "投诉建议" + } + ], + "system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。", + "user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。", + "user_supplement_prompt": None + } +} + +# 带补充提示的配置 +WITH_SUPPLEMENT_PROMPT_CONFIG = { + "id": "classifier_test", + "type": "question-classifier", + "name": "问题分类测试节点", + "config": { + "model_id": TEST_MODEL_ID, + "input_variable": "这个产品怎么样?", + "categories": [ + { + "class_name": "产品咨询" + }, + { + "class_name": "用户评价" + } + ], + "system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。", + "user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。", + "user_supplement_prompt": "如果用户在询问产品信息或特性,归类为产品咨询;如果是评价或反馈,归类为用户评价" + } +} + +# 使用变量的配置 +VARIABLE_INPUT_CONFIG = { + "id": "classifier_test", + "type": "question-classifier", + "name": "问题分类测试节点", + "config": { + "model_id": TEST_MODEL_ID, + "input_variable": "{{ conv.user_question }}", + "categories": [ + { + "class_name": "技术支持" + }, + { + "class_name": "账号问题" + } + ], + "system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。", + "user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。", + "user_supplement_prompt": None + } +} + +# 使用系统消息的配置 +SYS_MESSAGE_CONFIG = { + "id": "classifier_test", + "type": "question-classifier", + "name": "问题分类测试节点", + "config": { + "model_id": TEST_MODEL_ID, + "input_variable": "{{ sys.message }}", + "categories": [ + { + "class_name": "产品咨询" + }, + { + "class_name": "售后服务" + } + ], + "system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。", + "user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。", + "user_supplement_prompt": None + } +} + +# 空问题配置 +EMPTY_QUESTION_CONFIG = { + "id": "classifier_test", + "type": "question-classifier", + "name": "问题分类测试节点", + "config": { + "model_id": TEST_MODEL_ID, + "input_variable": "", + "categories": [ + { + "class_name": "产品咨询" + }, + { + "class_name": "售后服务" + } + ], + "system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。", + "user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。", + "user_supplement_prompt": None + } +} + + +# ==================== 基础分类测试 ==================== +@pytest.mark.asyncio +async def test_classify_product_inquiry(): + """测试产品咨询分类""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + result = await QuestionClassifierNode(BASIC_TWO_CATEGORIES_CONFIG, {}).execute(state, variable_pool) + + assert isinstance(result, dict) + assert "class_name" in result + assert "output" in result + assert result["class_name"] == "产品咨询" + assert result["output"] == "CASE1" + + +@pytest.mark.asyncio +async def test_classify_after_sales(): + """测试售后服务分类""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + config = { + "id": "classifier_test", + "type": "question-classifier", + "name": "问题分类测试节点", + "config": { + "model_id": TEST_MODEL_ID, + "input_variable": "我的产品坏了,怎么维修?", + "categories": [ + { + "class_name": "产品咨询" + }, + { + "class_name": "售后服务" + } + ], + "system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。", + "user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。", + "user_supplement_prompt": None + } + } + + result = await QuestionClassifierNode(config, {}).execute(state, variable_pool) + + assert isinstance(result, dict) + assert result["class_name"] == "售后服务" + assert result["output"] == "CASE2" + + +# ==================== 多类别分类测试 ==================== +@pytest.mark.asyncio +async def test_classify_order_inquiry(): + """测试订单查询分类""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + result = await QuestionClassifierNode(MULTI_CATEGORIES_CONFIG, {}).execute(state, variable_pool) + + assert isinstance(result, dict) + assert result["class_name"] == "订单查询" + assert result["output"] == "CASE2" + + +@pytest.mark.asyncio +async def test_classify_complaint(): + """测试投诉建议分类""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + config = { + "id": "classifier_test", + "type": "question-classifier", + "name": "问题分类测试节点", + "config": { + "model_id": TEST_MODEL_ID, + "input_variable": "你们的服务态度太差了!", + "categories": [ + { + "class_name": "产品咨询" + }, + { + "class_name": "订单查询" + }, + { + "class_name": "售后服务" + }, + { + "class_name": "投诉建议" + } + ], + "system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。", + "user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。", + "user_supplement_prompt": None + } + } + + result = await QuestionClassifierNode(config, {}).execute(state, variable_pool) + + assert isinstance(result, dict) + assert result["class_name"] == "投诉建议" + assert result["output"] == "CASE4" + + +# ==================== 补充提示测试 ==================== +@pytest.mark.asyncio +async def test_classify_with_supplement_prompt(): + """测试使用补充提示进行分类""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + result = await QuestionClassifierNode(WITH_SUPPLEMENT_PROMPT_CONFIG, {}).execute(state, variable_pool) + + assert isinstance(result, dict) + assert "class_name" in result + assert "output" in result + assert result["class_name"] in ["产品咨询", "用户评价"] + assert result["output"] in ["CASE1", "CASE2"] + + +# ==================== 变量输入测试 ==================== +@pytest.mark.asyncio +async def test_classify_with_conv_variable(): + """测试使用 conv 变量作为输入""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + await variable_pool.new("conv", "user_question", "我忘记密码了", VariableType.STRING, mut=True) + + result = await QuestionClassifierNode(VARIABLE_INPUT_CONFIG, {}).execute(state, variable_pool) + + assert isinstance(result, dict) + assert result["class_name"] == "账号问题" + assert result["output"] == "CASE2" + + +@pytest.mark.asyncio +async def test_classify_with_sys_message(): + """测试使用系统消息变量""" + state = simple_state() + variable_pool = await simple_vairable_pool("我想了解一下你们的产品功能") + + result = await QuestionClassifierNode(SYS_MESSAGE_CONFIG, {}).execute(state, variable_pool) + + assert isinstance(result, dict) + assert result["class_name"] == "产品咨询" + assert result["output"] == "CASE1" + + +# ==================== 边界情况测试 ==================== +@pytest.mark.asyncio +async def test_classify_empty_question(): + """测试空问题输入""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + result = await QuestionClassifierNode(EMPTY_QUESTION_CONFIG, {}).execute(state, variable_pool) + + assert isinstance(result, dict) + assert "class_name" in result + assert "output" in result + # 空问题应该返回默认分类(第一个分类) + assert result["class_name"] == "产品咨询" + assert result["output"] == "CASE1" + + +@pytest.mark.asyncio +async def test_classify_single_category(): + """测试只有一个分类的情况""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + config = { + "id": "classifier_test", + "type": "question-classifier", + "name": "问题分类测试节点", + "config": { + "model_id": TEST_MODEL_ID, + "input_variable": "任何问题", + "categories": [ + { + "class_name": "通用咨询" + } + ], + "system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。", + "user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。", + "user_supplement_prompt": None + } + } + + result = await QuestionClassifierNode(config, {}).execute(state, variable_pool) + + assert isinstance(result, dict) + assert result["class_name"] == "通用咨询" + assert result["output"] == "CASE1" + + +# ==================== 复杂场景测试 ==================== +@pytest.mark.asyncio +async def test_classify_ambiguous_question(): + """测试模糊问题分类""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + config = { + "id": "classifier_test", + "type": "question-classifier", + "name": "问题分类测试节点", + "config": { + "model_id": TEST_MODEL_ID, + "input_variable": "你好", + "categories": [ + { + "class_name": "产品咨询" + }, + { + "class_name": "售后服务" + }, + { + "class_name": "闲聊" + } + ], + "system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。", + "user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。", + "user_supplement_prompt": None + } + } + + result = await QuestionClassifierNode(config, {}).execute(state, variable_pool) + + assert isinstance(result, dict) + assert result["class_name"] in ["产品咨询", "售后服务", "闲聊"] + assert result["output"] in ["CASE1", "CASE2", "CASE3"] + + +@pytest.mark.asyncio +async def test_classify_long_question(): + """测试长问题分类""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + config = { + "id": "classifier_test", + "type": "question-classifier", + "name": "问题分类测试节点", + "config": { + "model_id": TEST_MODEL_ID, + "input_variable": "我在上个月购买了你们的产品,使用了一段时间后发现有一些问题,想咨询一下售后政策和维修流程,请问应该怎么办?", + "categories": [ + { + "class_name": "产品咨询" + }, + { + "class_name": "售后服务" + } + ], + "system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。", + "user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。", + "user_supplement_prompt": None + } + } + + result = await QuestionClassifierNode(config, {}).execute(state, variable_pool) + + assert isinstance(result, dict) + assert result["class_name"] == "售后服务" + assert result["output"] == "CASE2" + + +@pytest.mark.asyncio +async def test_classify_technical_support(): + """测试技术支持分类""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + config = { + "id": "classifier_test", + "type": "question-classifier", + "name": "问题分类测试节点", + "config": { + "model_id": TEST_MODEL_ID, + "input_variable": "软件安装失败,报错代码0x80070005", + "categories": [ + { + "class_name": "技术支持" + }, + { + "class_name": "账号问题" + } + ], + "system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。", + "user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。", + "user_supplement_prompt": None + } + } + + result = await QuestionClassifierNode(config, {}).execute(state, variable_pool) + + assert isinstance(result, dict) + assert result["class_name"] == "技术支持" + assert result["output"] == "CASE1" + + +@pytest.mark.asyncio +async def test_classify_multiple_categories(): + """测试多个类别的详细分类""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + config = { + "id": "classifier_test", + "type": "question-classifier", + "name": "问题分类测试节点", + "config": { + "model_id": TEST_MODEL_ID, + "input_variable": "我想申请退款", + "categories": [ + { + "class_name": "产品咨询" + }, + { + "class_name": "订单查询" + }, + { + "class_name": "退换货" + }, + { + "class_name": "售后服务" + }, + { + "class_name": "投诉建议" + } + ], + "system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。", + "user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。", + "user_supplement_prompt": None + } + } + + result = await QuestionClassifierNode(config, {}).execute(state, variable_pool) + + assert isinstance(result, dict) + assert result["class_name"] == "退换货" + assert result["output"] == "CASE3" + + +@pytest.mark.asyncio +async def test_classify_with_detailed_supplement(): + """测试使用详细补充提示""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + config = { + "id": "classifier_test", + "type": "question-classifier", + "name": "问题分类测试节点", + "config": { + "model_id": TEST_MODEL_ID, + "input_variable": "这个功能怎么用?", + "categories": [ + { + "class_name": "产品使用" + }, + { + "class_name": "产品介绍" + } + ], + "system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。", + "user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。", + "user_supplement_prompt": "如果用户询问如何使用某个功能,归类为产品使用;如果询问功能是什么或有什么功能,归类为产品介绍" + } + } + + result = await QuestionClassifierNode(config, {}).execute(state, variable_pool) + + assert isinstance(result, dict) + assert result["class_name"] == "产品使用" + assert result["output"] == "CASE1" + + +@pytest.mark.asyncio +async def test_classify_chinese_categories(): + """测试中文类别名称""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + config = { + "id": "classifier_test", + "type": "question-classifier", + "name": "问题分类测试节点", + "config": { + "model_id": TEST_MODEL_ID, + "input_variable": "我要投诉", + "categories": [ + { + "class_name": "咨询类" + }, + { + "class_name": "投诉类" + }, + { + "class_name": "建议类" + } + ], + "system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。", + "user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。", + "user_supplement_prompt": None + } + } + + result = await QuestionClassifierNode(config, {}).execute(state, variable_pool) + + assert isinstance(result, dict) + assert result["class_name"] == "投诉类" + assert result["output"] == "CASE2" + + +@pytest.mark.asyncio +async def test_classify_case_mapping(): + """测试分类到 CASE 的映射关系""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + config = { + "id": "classifier_test", + "type": "question-classifier", + "name": "问题分类测试节点", + "config": { + "model_id": TEST_MODEL_ID, + "input_variable": "测试问题", + "categories": [ + { + "class_name": "类别A" + }, + { + "class_name": "类别B" + }, + { + "class_name": "类别C" + }, + { + "class_name": "类别D" + }, + { + "class_name": "类别E" + } + ], + "system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。", + "user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。", + "user_supplement_prompt": None + } + } + + result = await QuestionClassifierNode(config, {}).execute(state, variable_pool) + + assert isinstance(result, dict) + assert "class_name" in result + assert "output" in result + + # 验证 CASE 映射关系 + category_names = ["类别A", "类别B", "类别C", "类别D", "类别E"] + if result["class_name"] in category_names: + expected_case = f"CASE{category_names.index(result['class_name']) + 1}" + assert result["output"] == expected_case + + +@pytest.mark.asyncio +async def test_classify_with_special_characters(): + """测试包含特殊字符的问题""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + config = { + "id": "classifier_test", + "type": "question-classifier", + "name": "问题分类测试节点", + "config": { + "model_id": TEST_MODEL_ID, + "input_variable": "产品价格是多少?有优惠吗?", + "categories": [ + { + "class_name": "价格咨询" + }, + { + "class_name": "促销活动" + } + ], + "system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。", + "user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。", + "user_supplement_prompt": None + } + } + + result = await QuestionClassifierNode(config, {}).execute(state, variable_pool) + + assert isinstance(result, dict) + assert result["class_name"] in ["价格咨询", "促销活动"] + assert result["output"] in ["CASE1", "CASE2"] diff --git a/api/tests/workflow/nodes/test_start_node.py b/api/tests/workflow/nodes/test_start_node.py new file mode 100644 index 00000000..fb6a3140 --- /dev/null +++ b/api/tests/workflow/nodes/test_start_node.py @@ -0,0 +1,735 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/6 +import pytest + +from app.core.workflow.nodes import StartNode +from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.variable_pool import VariablePool +from tests.workflow.nodes.base import ( + simple_state, + simple_vairable_pool, + TEST_EXECUTION_ID, + TEST_WORKSPACE_ID, + TEST_USER_ID, + TEST_CONVERSATION_ID, + TEST_FILE +) + + +async def create_variable_pool_with_inputs(message: str, input_variables: dict = None): + """创建带有自定义输入变量的变量池""" + variable_pool = VariablePool() + + sys_vars = { + "message": (message, VariableType.STRING), + "conversation_id": (TEST_CONVERSATION_ID, VariableType.STRING), + "execution_id": (TEST_EXECUTION_ID, VariableType.STRING), + "workspace_id": (TEST_WORKSPACE_ID, VariableType.STRING), + "user_id": (TEST_USER_ID, VariableType.STRING), + "input_variables": (input_variables or {}, VariableType.OBJECT), + "files": ([TEST_FILE], VariableType.ARRAY_FILE) + } + + for key, var_def in sys_vars.items(): + value = var_def[0] + var_type = var_def[1] + await variable_pool.new( + namespace='sys', + key=key, + value=value, + var_type=VariableType(var_type), + mut=False # 系统变量不可变 + ) + + return variable_pool + + +# 基础配置 - 无自定义变量 +BASIC_CONFIG = { + "id": "start_test", + "type": "start", + "name": "开始节点", + "config": { + "variables": [] + } +} + +# 带单个自定义变量的配置 +SINGLE_VARIABLE_CONFIG = { + "id": "start_test", + "type": "start", + "name": "开始节点", + "config": { + "variables": [ + { + "name": "language", + "type": "string", + "required": False, + "default": "zh-CN", + "description": "语言设置" + } + ] + } +} + +# 带多个自定义变量的配置 +MULTI_VARIABLES_CONFIG = { + "id": "start_test", + "type": "start", + "name": "开始节点", + "config": { + "variables": [ + { + "name": "language", + "type": "string", + "required": False, + "default": "zh-CN", + "description": "语言设置" + }, + { + "name": "max_length", + "type": "number", + "required": False, + "default": 1000, + "description": "最大长度" + }, + { + "name": "enable_cache", + "type": "boolean", + "required": False, + "default": True, + "description": "是否启用缓存" + } + ] + } +} + +# 带必需变量的配置 +REQUIRED_VARIABLE_CONFIG = { + "id": "start_test", + "type": "start", + "name": "开始节点", + "config": { + "variables": [ + { + "name": "api_key", + "type": "string", + "required": True, + "description": "API密钥" + } + ] + } +} + +# 混合必需和可选变量的配置 +MIXED_VARIABLES_CONFIG = { + "id": "start_test", + "type": "start", + "name": "开始节点", + "config": { + "variables": [ + { + "name": "user_id", + "type": "string", + "required": True, + "description": "用户ID" + }, + { + "name": "timeout", + "type": "number", + "required": False, + "default": 30, + "description": "超时时间(秒)" + } + ] + } +} + + +# 不同类型变量的配置 +DIFFERENT_TYPES_CONFIG = { + "id": "start_test", + "type": "start", + "name": "开始节点", + "config": { + "variables": [ + { + "name": "name", + "type": "string", + "required": False, + "default": "default_name", + "description": "名称" + }, + { + "name": "count", + "type": "number", + "required": False, + "default": 0, + "description": "计数" + }, + { + "name": "enabled", + "type": "boolean", + "required": False, + "default": False, + "description": "是否启用" + }, + { + "name": "tags", + "type": "array[string]", + "required": False, + "default": [], + "description": "标签列表" + }, + { + "name": "config", + "type": "object", + "required": False, + "default": {}, + "description": "配置对象" + } + ] + } +} + + +# ==================== 基础功能测试 ==================== +@pytest.mark.asyncio +async def test_start_node_basic(): + """测试基础 Start 节点(无自定义变量)""" + state = simple_state() + variable_pool = await simple_vairable_pool("test message") + + result = await StartNode(BASIC_CONFIG, {}).execute(state, variable_pool) + + assert isinstance(result, dict) + assert "message" in result + assert "execution_id" in result + assert "conversation_id" in result + assert "workspace_id" in result + assert "user_id" in result + assert result["message"] == "test message" + + +@pytest.mark.asyncio +async def test_start_node_system_variables(): + """测试系统变量输出""" + state = simple_state() + variable_pool = await simple_vairable_pool("hello world") + + result = await StartNode(BASIC_CONFIG, {}).execute(state, variable_pool) + + assert result["message"] == "hello world" + assert result["execution_id"] == state["execution_id"] + assert result["workspace_id"] == state["workspace_id"] + assert result["user_id"] == state["user_id"] + + +# ==================== 自定义变量测试 ==================== +@pytest.mark.asyncio +async def test_start_node_single_variable_with_default(): + """测试单个自定义变量使用默认值""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + result = await StartNode(SINGLE_VARIABLE_CONFIG, {}).execute(state, variable_pool) + + assert "language" in result + assert result["language"] == "zh-CN" + + +@pytest.mark.asyncio +async def test_start_node_single_variable_with_input(): + """测试单个自定义变量使用输入值""" + state = simple_state() + + # 使用带输入变量的变量池 + input_vars = {"language": "en-US"} + variable_pool = await create_variable_pool_with_inputs("test", input_vars) + + result = await StartNode(SINGLE_VARIABLE_CONFIG, {}).execute(state, variable_pool) + + assert "language" in result + assert result["language"] == "en-US" + + +@pytest.mark.asyncio +async def test_start_node_multi_variables_with_defaults(): + """测试多个自定义变量使用默认值""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + result = await StartNode(MULTI_VARIABLES_CONFIG, {}).execute(state, variable_pool) + + assert "language" in result + assert "max_length" in result + assert "enable_cache" in result + assert result["language"] == "zh-CN" + assert result["max_length"] == 1000 + assert result["enable_cache"] is True + + +@pytest.mark.asyncio +async def test_start_node_multi_variables_with_inputs(): + """测试多个自定义变量使用输入值""" + state = simple_state() + + # 使用带输入变量的变量池 + input_vars = { + "language": "ja-JP", + "max_length": 2000, + "enable_cache": False + } + variable_pool = await create_variable_pool_with_inputs("test", input_vars) + + result = await StartNode(MULTI_VARIABLES_CONFIG, {}).execute(state, variable_pool) + + assert result["language"] == "ja-JP" + assert result["max_length"] == 2000 + assert result["enable_cache"] is False + + +@pytest.mark.asyncio +async def test_start_node_partial_inputs(): + """测试部分输入变量,其他使用默认值""" + state = simple_state() + + # 只设置部分输入变量 + input_vars = { + "language": "fr-FR" + } + variable_pool = await create_variable_pool_with_inputs("test", input_vars) + + result = await StartNode(MULTI_VARIABLES_CONFIG, {}).execute(state, variable_pool) + + assert result["language"] == "fr-FR" # 使用输入值 + assert result["max_length"] == 1000 # 使用默认值 + assert result["enable_cache"] is True # 使用默认值 + + +# ==================== 必需变量测试 ==================== +@pytest.mark.asyncio +async def test_start_node_required_variable_provided(): + """测试提供必需变量""" + state = simple_state() + + # 提供必需变量 + input_vars = { + "api_key": "test_api_key_12345" + } + variable_pool = await create_variable_pool_with_inputs("test", input_vars) + + result = await StartNode(REQUIRED_VARIABLE_CONFIG, {}).execute(state, variable_pool) + + assert "api_key" in result + assert result["api_key"] == "test_api_key_12345" + + +@pytest.mark.asyncio +async def test_start_node_required_variable_missing(): + """测试缺少必需变量""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + # 不提供必需变量 + with pytest.raises(ValueError) as exc_info: + await StartNode(REQUIRED_VARIABLE_CONFIG, {}).execute(state, variable_pool) + + assert "缺少必需的输入变量" in str(exc_info.value) + assert "api_key" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_start_node_mixed_variables(): + """测试混合必需和可选变量""" + state = simple_state() + + # 只提供必需变量 + input_vars = { + "user_id": "user_123" + } + variable_pool = await create_variable_pool_with_inputs("test", input_vars) + + result = await StartNode(MIXED_VARIABLES_CONFIG, {}).execute(state, variable_pool) + + assert result["user_id"] == "user_123" # 必需变量 + assert result["timeout"] == 30 # 可选变量使用默认值 + + +@pytest.mark.asyncio +async def test_start_node_mixed_variables_all_provided(): + """测试混合变量全部提供""" + state = simple_state() + + # 提供所有变量 + input_vars = { + "user_id": "user_456", + "timeout": 60 + } + variable_pool = await create_variable_pool_with_inputs("test", input_vars) + + result = await StartNode(MIXED_VARIABLES_CONFIG, {}).execute(state, variable_pool) + + assert result["user_id"] == "user_456" + assert result["timeout"] == 60 + + +# ==================== 不同类型变量测试 ==================== +@pytest.mark.asyncio +async def test_start_node_different_types_defaults(): + """测试不同类型变量的默认值""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + result = await StartNode(DIFFERENT_TYPES_CONFIG, {}).execute(state, variable_pool) + + assert result["name"] == "default_name" + assert result["count"] == 0 + assert result["enabled"] is False + assert result["tags"] == [] + assert result["config"] == {} + + +@pytest.mark.asyncio +async def test_start_node_different_types_inputs(): + """测试不同类型变量的输入值""" + state = simple_state() + + # 提供不同类型的输入值 + input_vars = { + "name": "custom_name", + "count": 100, + "enabled": True, + "tags": ["tag1", "tag2", "tag3"], + "config": {"key": "value", "nested": {"data": 123}} + } + variable_pool = await create_variable_pool_with_inputs("test", input_vars) + + result = await StartNode(DIFFERENT_TYPES_CONFIG, {}).execute(state, variable_pool) + + assert result["name"] == "custom_name" + assert result["count"] == 100 + assert result["enabled"] is True + assert result["tags"] == ["tag1", "tag2", "tag3"] + assert result["config"] == {"key": "value", "nested": {"data": 123}} + + +# ==================== 边界情况测试 ==================== +@pytest.mark.asyncio +async def test_start_node_empty_message(): + """测试空消息""" + state = simple_state() + variable_pool = await simple_vairable_pool("") + + result = await StartNode(BASIC_CONFIG, {}).execute(state, variable_pool) + + assert result["message"] == "" + + +@pytest.mark.asyncio +async def test_start_node_no_input_variables(): + """测试没有输入变量的情况""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + # 不设置 input_variables + result = await StartNode(SINGLE_VARIABLE_CONFIG, {}).execute(state, variable_pool) + + # 应该使用默认值 + assert result["language"] == "zh-CN" + + +@pytest.mark.asyncio +async def test_start_node_empty_input_variables(): + """测试空的输入变量字典""" + state = simple_state() + + # 设置空的输入变量字典 + variable_pool = await create_variable_pool_with_inputs("test", {}) + + result = await StartNode(SINGLE_VARIABLE_CONFIG, {}).execute(state, variable_pool) + + # 应该使用默认值 + assert result["language"] == "zh-CN" + + +@pytest.mark.asyncio +async def test_start_node_extra_input_variables(): + """测试额外的输入变量(未在配置中定义)""" + state = simple_state() + + # 提供额外的未定义变量 + input_vars = { + "language": "de-DE", + "extra_var": "should_be_ignored" + } + variable_pool = await create_variable_pool_with_inputs("test", input_vars) + + result = await StartNode(SINGLE_VARIABLE_CONFIG, {}).execute(state, variable_pool) + + assert result["language"] == "de-DE" + assert "extra_var" not in result # 额外变量不应该出现在输出中 + + +# ==================== 数组类型变量测试 ==================== +@pytest.mark.asyncio +async def test_start_node_array_string_variable(): + """测试字符串数组变量""" + state = simple_state() + + config = { + "id": "start_test", + "type": "start", + "name": "开始节点", + "config": { + "variables": [ + { + "name": "categories", + "type": "array[string]", + "required": False, + "default": ["default1", "default2"], + "description": "分类列表" + } + ] + } + } + + input_vars = { + "categories": ["cat1", "cat2", "cat3"] + } + variable_pool = await create_variable_pool_with_inputs("test", input_vars) + + result = await StartNode(config, {}).execute(state, variable_pool) + + assert result["categories"] == ["cat1", "cat2", "cat3"] + + +@pytest.mark.asyncio +async def test_start_node_array_number_variable(): + """测试数字数组变量""" + state = simple_state() + + config = { + "id": "start_test", + "type": "start", + "name": "开始节点", + "config": { + "variables": [ + { + "name": "scores", + "type": "array[number]", + "required": False, + "default": [0, 0, 0], + "description": "分数列表" + } + ] + } + } + + input_vars = { + "scores": [85, 90, 95] + } + variable_pool = await create_variable_pool_with_inputs("test", input_vars) + + result = await StartNode(config, {}).execute(state, variable_pool) + + assert result["scores"] == [85, 90, 95] + + +@pytest.mark.asyncio +async def test_start_node_array_object_variable(): + """测试对象数组变量""" + state = simple_state() + + config = { + "id": "start_test", + "type": "start", + "name": "开始节点", + "config": { + "variables": [ + { + "name": "users", + "type": "array[object]", + "required": False, + "default": [], + "description": "用户列表" + } + ] + } + } + + input_vars = { + "users": [ + {"name": "Alice", "age": 25}, + {"name": "Bob", "age": 30} + ] + } + variable_pool = await create_variable_pool_with_inputs("test", input_vars) + + result = await StartNode(config, {}).execute(state, variable_pool) + + assert len(result["users"]) == 2 + assert result["users"][0]["name"] == "Alice" + assert result["users"][1]["age"] == 30 + + +# ==================== 复杂场景测试 ==================== +@pytest.mark.asyncio +async def test_start_node_complex_object(): + """测试复杂对象变量""" + state = simple_state() + + config = { + "id": "start_test", + "type": "start", + "name": "开始节点", + "config": { + "variables": [ + { + "name": "settings", + "type": "object", + "required": False, + "default": {"theme": "light"}, + "description": "设置对象" + } + ] + } + } + + input_vars = { + "settings": { + "theme": "dark", + "language": "zh-CN", + "notifications": { + "email": True, + "sms": False + }, + "features": ["feature1", "feature2"] + } + } + variable_pool = await create_variable_pool_with_inputs("test", input_vars) + + result = await StartNode(config, {}).execute(state, variable_pool) + + assert result["settings"]["theme"] == "dark" + assert result["settings"]["language"] == "zh-CN" + assert result["settings"]["notifications"]["email"] is True + assert result["settings"]["features"] == ["feature1", "feature2"] + + +@pytest.mark.asyncio +async def test_start_node_zero_and_false_values(): + """测试零值和 False 值(确保不被当作空值)""" + state = simple_state() + + config = { + "id": "start_test", + "type": "start", + "name": "开始节点", + "config": { + "variables": [ + { + "name": "count", + "type": "number", + "required": False, + "default": 10, + "description": "计数" + }, + { + "name": "enabled", + "type": "boolean", + "required": False, + "default": True, + "description": "是否启用" + } + ] + } + } + + input_vars = { + "count": 0, + "enabled": False + } + variable_pool = await create_variable_pool_with_inputs("test", input_vars) + + result = await StartNode(config, {}).execute(state, variable_pool) + + # 0 和 False 应该被正确识别,而不是使用默认值 + assert result["count"] == 0 + assert result["enabled"] is False + + +@pytest.mark.asyncio +async def test_start_node_output_types(): + """测试输出类型定义""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + node = StartNode(MULTI_VARIABLES_CONFIG, {}) + await node.execute(state, variable_pool) + + output_types = node._output_types() + + # 验证系统变量类型 + assert output_types["message"] == VariableType.STRING + assert output_types["execution_id"] == VariableType.STRING + assert output_types["conversation_id"] == VariableType.STRING + assert output_types["workspace_id"] == VariableType.STRING + assert output_types["user_id"] == VariableType.STRING + + # 验证自定义变量类型 + assert output_types["language"] == VariableType.STRING + assert output_types["max_length"] == VariableType.NUMBER + assert output_types["enable_cache"] == VariableType.BOOLEAN + + +@pytest.mark.asyncio +async def test_start_node_multiple_executions(): + """测试多次执行 Start 节点""" + state = simple_state() + + node = StartNode(SINGLE_VARIABLE_CONFIG, {}) + + # 第一次执行 + variable_pool1 = await create_variable_pool_with_inputs("first message", {}) + result1 = await node.execute(state, variable_pool1) + assert result1["message"] == "first message" + assert result1["language"] == "zh-CN" + + # 第二次执行(使用新的变量池) + variable_pool2 = await create_variable_pool_with_inputs("second message", {}) + result2 = await node.execute(state, variable_pool2) + assert result2["message"] == "second message" + assert result2["language"] == "zh-CN" + + +@pytest.mark.asyncio +async def test_start_node_with_description(): + """测试带描述的变量""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + config = { + "id": "start_test", + "type": "start", + "name": "开始节点", + "config": { + "variables": [ + { + "name": "api_endpoint", + "type": "string", + "required": True, + "description": "API 端点 URL,用于连接外部服务" + } + ] + } + } + + # 测试缺少必需变量时,错误信息包含描述 + with pytest.raises(ValueError) as exc_info: + await StartNode(config, {}).execute(state, variable_pool) + + assert "api_endpoint" in str(exc_info.value) + assert "API 端点 URL" in str(exc_info.value) diff --git a/api/tests/workflow/nodes/test_variable_aggregator_node.py b/api/tests/workflow/nodes/test_variable_aggregator_node.py new file mode 100644 index 00000000..3086e9eb --- /dev/null +++ b/api/tests/workflow/nodes/test_variable_aggregator_node.py @@ -0,0 +1,621 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/6 +import pytest + +from app.core.workflow.nodes import VariableAggregatorNode +from app.core.workflow.variable.base_variable import VariableType +from tests.workflow.nodes.base import simple_state, simple_vairable_pool + + +# 非分组模式配置 - 返回第一个非空变量 +NON_GROUP_CONFIG = { + "id": "aggregator_test", + "type": "var-aggregator", + "name": "变量聚合测试节点", + "config": { + "group": False, + "group_variables": [ + "{{conv.var1}}", + "{{conv.var2}}", + "{{conv.var3}}" + ] + } +} + +# 非分组模式配置 - 带类型定义 +NON_GROUP_WITH_TYPE_CONFIG = { + "id": "aggregator_test", + "type": "var-aggregator", + "name": "变量聚合测试节点", + "config": { + "group": False, + "group_variables": [ + "{{conv.var1}}", + "{{conv.var2}}" + ], + "group_type": { + "output": "string" + } + } +} + +# 分组模式配置 +GROUP_CONFIG = { + "id": "aggregator_test", + "type": "var-aggregator", + "name": "变量聚合测试节点", + "config": { + "group": True, + "group_variables": { + "user_message": [ + "{{conv.msg1}}", + "{{conv.msg2}}" + ], + "user_name": [ + "{{conv.name1}}", + "{{conv.name2}}" + ] + } + } +} + +# 分组模式配置 - 带类型定义 +GROUP_WITH_TYPE_CONFIG = { + "id": "aggregator_test", + "type": "var-aggregator", + "name": "变量聚合测试节点", + "config": { + "group": True, + "group_variables": { + "count": [ + "{{conv.count1}}", + "{{conv.count2}}" + ], + "enabled": [ + "{{conv.flag1}}", + "{{conv.flag2}}" + ] + }, + "group_type": { + "count": "number", + "enabled": "boolean" + } + } +} + + +# ==================== 非分组模式测试 ==================== +@pytest.mark.asyncio +async def test_non_group_first_variable(): + """测试非分组模式返回第一个非空变量""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + # 设置变量 + await variable_pool.new("conv", "var1", "first_value", VariableType.STRING, mut=True) + await variable_pool.new("conv", "var2", "second_value", VariableType.STRING, mut=True) + await variable_pool.new("conv", "var3", "third_value", VariableType.STRING, mut=True) + + result = await VariableAggregatorNode(NON_GROUP_CONFIG, {}).execute(state, variable_pool) + + assert result == "first_value" + + +@pytest.mark.asyncio +async def test_non_group_skip_none(): + """测试非分组模式跳过 None 值""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + # 第一个变量不存在,第二个存在 + await variable_pool.new("conv", "var2", "second_value", VariableType.STRING, mut=True) + await variable_pool.new("conv", "var3", "third_value", VariableType.STRING, mut=True) + + result = await VariableAggregatorNode(NON_GROUP_CONFIG, {}).execute(state, variable_pool) + + assert result == "second_value" + + +@pytest.mark.asyncio +async def test_non_group_all_none(): + """测试非分组模式所有变量都不存在""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + # 不创建任何变量 + result = await VariableAggregatorNode(NON_GROUP_CONFIG, {}).execute(state, variable_pool) + + assert result == "" + + +@pytest.mark.asyncio +async def test_non_group_with_type_all_none(): + """测试非分组模式带类型定义,所有变量都不存在""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + # 不创建任何变量 + result = await VariableAggregatorNode(NON_GROUP_WITH_TYPE_CONFIG, {}).execute(state, variable_pool) + + # 应该返回类型的默认值 + assert result == "" + + +@pytest.mark.asyncio +async def test_non_group_different_types(): + """测试非分组模式不同类型的变量""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + config = { + "id": "aggregator_test", + "type": "var-aggregator", + "name": "变量聚合测试节点", + "config": { + "group": False, + "group_variables": [ + "{{conv.num}}", + "{{conv.str}}", + "{{conv.bool}}" + ] + } + } + + # 设置不同类型的变量 + await variable_pool.new("conv", "num", 123, VariableType.NUMBER, mut=True) + await variable_pool.new("conv", "str", "text", VariableType.STRING, mut=True) + await variable_pool.new("conv", "bool", True, VariableType.BOOLEAN, mut=True) + + result = await VariableAggregatorNode(config, {}).execute(state, variable_pool) + + assert result == 123 + + +@pytest.mark.asyncio +async def test_non_group_zero_and_false(): + """测试非分组模式零值和 False 值(不应被视为 None)""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + config = { + "id": "aggregator_test", + "type": "var-aggregator", + "name": "变量聚合测试节点", + "config": { + "group": False, + "group_variables": [ + "{{conv.zero}}", + "{{conv.text}}" + ] + } + } + + # 设置零值 + await variable_pool.new("conv", "zero", 0, VariableType.NUMBER, mut=True) + await variable_pool.new("conv", "text", "fallback", VariableType.STRING, mut=True) + + result = await VariableAggregatorNode(config, {}).execute(state, variable_pool) + + # 0 不应被视为 None,应该返回 0 + assert result == 0 + + +@pytest.mark.asyncio +async def test_non_group_false_value(): + """测试非分组模式 False 值""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + config = { + "id": "aggregator_test", + "type": "var-aggregator", + "name": "变量聚合测试节点", + "config": { + "group": False, + "group_variables": [ + "{{conv.flag}}", + "{{conv.text}}" + ] + } + } + + # 设置 False 值 + await variable_pool.new("conv", "flag", False, VariableType.BOOLEAN, mut=True) + await variable_pool.new("conv", "text", "fallback", VariableType.STRING, mut=True) + + result = await VariableAggregatorNode(config, {}).execute(state, variable_pool) + + # False 不应被视为 None,应该返回 False + assert result is False + + +# ==================== 分组模式测试 ==================== +@pytest.mark.asyncio +async def test_group_mode_all_groups(): + """测试分组模式所有分组都有值""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + # 设置变量 + await variable_pool.new("conv", "msg1", "Hello", VariableType.STRING, mut=True) + await variable_pool.new("conv", "name1", "Alice", VariableType.STRING, mut=True) + + result = await VariableAggregatorNode(GROUP_CONFIG, {}).execute(state, variable_pool) + + assert isinstance(result, dict) + assert result["user_message"] == "Hello" + assert result["user_name"] == "Alice" + + +@pytest.mark.asyncio +async def test_group_mode_fallback(): + """测试分组模式使用备用变量""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + # 第一个变量不存在,使用第二个 + await variable_pool.new("conv", "msg2", "Fallback message", VariableType.STRING, mut=True) + await variable_pool.new("conv", "name2", "Bob", VariableType.STRING, mut=True) + + result = await VariableAggregatorNode(GROUP_CONFIG, {}).execute(state, variable_pool) + + assert result["user_message"] == "Fallback message" + assert result["user_name"] == "Bob" + + +@pytest.mark.asyncio +async def test_group_mode_partial_none(): + """测试分组模式部分分组没有值""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + # 只设置一个分组的变量 + await variable_pool.new("conv", "msg1", "Hello", VariableType.STRING, mut=True) + + result = await VariableAggregatorNode(GROUP_CONFIG, {}).execute(state, variable_pool) + + assert result["user_message"] == "Hello" + assert result["user_name"] == "" # 没有值的分组返回空字符串 + + +@pytest.mark.asyncio +async def test_group_mode_all_none(): + """测试分组模式所有分组都没有值""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + # 不创建任何变量 + result = await VariableAggregatorNode(GROUP_CONFIG, {}).execute(state, variable_pool) + + assert isinstance(result, dict) + assert result["user_message"] == "" + assert result["user_name"] == "" + + +@pytest.mark.asyncio +async def test_group_mode_with_type(): + """测试分组模式带类型定义""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + # 设置变量 + await variable_pool.new("conv", "count1", 100, VariableType.NUMBER, mut=True) + await variable_pool.new("conv", "flag1", True, VariableType.BOOLEAN, mut=True) + + result = await VariableAggregatorNode(GROUP_WITH_TYPE_CONFIG, {}).execute(state, variable_pool) + + assert result["count"] == 100 + assert result["enabled"] is True + + +@pytest.mark.asyncio +async def test_group_mode_with_type_defaults(): + """测试分组模式带类型定义,使用默认值""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + # 不创建任何变量 + result = await VariableAggregatorNode(GROUP_WITH_TYPE_CONFIG, {}).execute(state, variable_pool) + + # 应该返回类型的默认值 + assert result["count"] == 0 # number 的默认值 + assert result["enabled"] is False # boolean 的默认值 + + +@pytest.mark.asyncio +async def test_group_mode_mixed_values(): + """测试分组模式混合有值和无值的情况""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + # 只设置 count2 + await variable_pool.new("conv", "count2", 200, VariableType.NUMBER, mut=True) + + result = await VariableAggregatorNode(GROUP_WITH_TYPE_CONFIG, {}).execute(state, variable_pool) + + assert result["count"] == 200 # 使用第二个变量 + assert result["enabled"] is False # 没有值,使用默认值 + + +@pytest.mark.asyncio +async def test_group_mode_multiple_groups(): + """测试分组模式多个分组""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + config = { + "id": "aggregator_test", + "type": "var-aggregator", + "name": "变量聚合测试节点", + "config": { + "group": True, + "group_variables": { + "group1": ["{{conv.g1_v1}}", "{{conv.g1_v2}}"], + "group2": ["{{conv.g2_v1}}", "{{conv.g2_v2}}"], + "group3": ["{{conv.g3_v1}}", "{{conv.g3_v2}}"] + } + } + } + + # 设置不同分组的变量 + await variable_pool.new("conv", "g1_v1", "value1", VariableType.STRING, mut=True) + await variable_pool.new("conv", "g2_v2", "value2", VariableType.STRING, mut=True) + await variable_pool.new("conv", "g3_v1", "value3", VariableType.STRING, mut=True) + + result = await VariableAggregatorNode(config, {}).execute(state, variable_pool) + + assert result["group1"] == "value1" + assert result["group2"] == "value2" + assert result["group3"] == "value3" + + +# ==================== 复杂场景测试 ==================== +@pytest.mark.asyncio +async def test_aggregator_with_array(): + """测试聚合数组变量""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + config = { + "id": "aggregator_test", + "type": "var-aggregator", + "name": "变量聚合测试节点", + "config": { + "group": False, + "group_variables": [ + "{{conv.arr1}}", + "{{conv.arr2}}" + ] + } + } + + # 设置数组变量 + await variable_pool.new("conv", "arr1", [1, 2, 3], VariableType.ARRAY_NUMBER, mut=True) + await variable_pool.new("conv", "arr2", [4, 5, 6], VariableType.ARRAY_NUMBER, mut=True) + + result = await VariableAggregatorNode(config, {}).execute(state, variable_pool) + + assert result == [1, 2, 3] + + +@pytest.mark.asyncio +async def test_aggregator_with_object(): + """测试聚合对象变量""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + config = { + "id": "aggregator_test", + "type": "var-aggregator", + "name": "变量聚合测试节点", + "config": { + "group": False, + "group_variables": [ + "{{conv.obj1}}", + "{{conv.obj2}}" + ] + } + } + + # 设置对象变量 + await variable_pool.new("conv", "obj1", {"key": "value1"}, VariableType.OBJECT, mut=True) + await variable_pool.new("conv", "obj2", {"key": "value2"}, VariableType.OBJECT, mut=True) + + result = await VariableAggregatorNode(config, {}).execute(state, variable_pool) + + assert result == {"key": "value1"} + + +@pytest.mark.asyncio +async def test_aggregator_empty_string(): + """测试空字符串不被视为 None""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + config = { + "id": "aggregator_test", + "type": "var-aggregator", + "name": "变量聚合测试节点", + "config": { + "group": False, + "group_variables": [ + "{{conv.empty}}", + "{{conv.text}}" + ] + } + } + + # 设置空字符串 + await variable_pool.new("conv", "empty", "", VariableType.STRING, mut=True) + await variable_pool.new("conv", "text", "fallback", VariableType.STRING, mut=True) + + result = await VariableAggregatorNode(config, {}).execute(state, variable_pool) + + # 空字符串不应被视为 None,应该返回空字符串 + assert result == "" + + +@pytest.mark.asyncio +async def test_aggregator_empty_array(): + """测试空数组不被视为 None""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + config = { + "id": "aggregator_test", + "type": "var-aggregator", + "name": "变量聚合测试节点", + "config": { + "group": False, + "group_variables": [ + "{{conv.empty_arr}}", + "{{conv.arr}}" + ] + } + } + + # 设置空数组 + await variable_pool.new("conv", "empty_arr", [], VariableType.ARRAY_NUMBER, mut=True) + await variable_pool.new("conv", "arr", [1, 2], VariableType.ARRAY_NUMBER, mut=True) + + result = await VariableAggregatorNode(config, {}).execute(state, variable_pool) + + # 空数组不应被视为 None,应该返回空数组 + assert result == [] + + +@pytest.mark.asyncio +async def test_aggregator_empty_object(): + """测试空对象不被视为 None""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + config = { + "id": "aggregator_test", + "type": "var-aggregator", + "name": "变量聚合测试节点", + "config": { + "group": False, + "group_variables": [ + "{{conv.empty_obj}}", + "{{conv.obj}}" + ] + } + } + + # 设置空对象 + await variable_pool.new("conv", "empty_obj", {}, VariableType.OBJECT, mut=True) + await variable_pool.new("conv", "obj", {"key": "value"}, VariableType.OBJECT, mut=True) + + result = await VariableAggregatorNode(config, {}).execute(state, variable_pool) + + # 空对象不应被视为 None,应该返回空对象 + assert result == {} + + +@pytest.mark.asyncio +async def test_group_mode_with_different_types(): + """测试分组模式不同类型的变量""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + config = { + "id": "aggregator_test", + "type": "var-aggregator", + "name": "变量聚合测试节点", + "config": { + "group": True, + "group_variables": { + "text": ["{{conv.str1}}", "{{conv.str2}}"], + "number": ["{{conv.num1}}", "{{conv.num2}}"], + "array": ["{{conv.arr1}}", "{{conv.arr2}}"], + "object": ["{{conv.obj1}}", "{{conv.obj2}}"] + }, + "group_type": { + "text": "string", + "number": "number", + "array": "array[number]", + "object": "object" + } + } + } + + # 设置不同类型的变量 + await variable_pool.new("conv", "str1", "hello", VariableType.STRING, mut=True) + await variable_pool.new("conv", "num1", 42, VariableType.NUMBER, mut=True) + await variable_pool.new("conv", "arr1", [1, 2, 3], VariableType.ARRAY_NUMBER, mut=True) + await variable_pool.new("conv", "obj1", {"key": "value"}, VariableType.OBJECT, mut=True) + + result = await VariableAggregatorNode(config, {}).execute(state, variable_pool) + + assert result["text"] == "hello" + assert result["number"] == 42 + assert result["array"] == [1, 2, 3] + assert result["object"] == {"key": "value"} + + +@pytest.mark.asyncio +async def test_aggregator_output_types(): + """测试输出类型定义""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + node = VariableAggregatorNode(GROUP_WITH_TYPE_CONFIG, {}) + + output_types = node._output_types() + + assert output_types["count"] == VariableType.NUMBER + assert output_types["enabled"] == VariableType.BOOLEAN + + +@pytest.mark.asyncio +async def test_non_group_single_variable(): + """测试非分组模式只有一个变量""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + config = { + "id": "aggregator_test", + "type": "var-aggregator", + "name": "变量聚合测试节点", + "config": { + "group": False, + "group_variables": [ + "{{conv.only_var}}" + ] + } + } + + await variable_pool.new("conv", "only_var", "single_value", VariableType.STRING, mut=True) + + result = await VariableAggregatorNode(config, {}).execute(state, variable_pool) + + assert result == "single_value" + + +@pytest.mark.asyncio +async def test_group_mode_single_group(): + """测试分组模式只有一个分组""" + state = simple_state() + variable_pool = await simple_vairable_pool("test") + + config = { + "id": "aggregator_test", + "type": "var-aggregator", + "name": "变量聚合测试节点", + "config": { + "group": True, + "group_variables": { + "only_group": ["{{conv.var1}}", "{{conv.var2}}"] + } + } + } + + await variable_pool.new("conv", "var1", "value", VariableType.STRING, mut=True) + + result = await VariableAggregatorNode(config, {}).execute(state, variable_pool) + + assert isinstance(result, dict) + assert result["only_group"] == "value" diff --git a/sandbox/app/controllers/sandbox_controller.py b/sandbox/app/controllers/sandbox_controller.py index f9bc3fc0..a006ad20 100644 --- a/sandbox/app/controllers/sandbox_controller.py +++ b/sandbox/app/controllers/sandbox_controller.py @@ -36,7 +36,7 @@ async def run_code(request: RunCodeRequest): elif request.language == "javascript": return await run_nodejs_code(request.code, request.preload, request.options) else: - return error_response(-400, "unsupported language") + return error_response(400, "unsupported language") @router.get("/dependencies", response_model=ApiResponse) @@ -45,7 +45,7 @@ async def get_dependencies(language: str): if language == "python3": return await list_python_dependencies() else: - return error_response(-400, "unsupported language") + return error_response(400, "unsupported language") @router.post("/dependencies/update", response_model=ApiResponse) @@ -54,4 +54,4 @@ async def update_dependencies(request: UpdateDependencyRequest): if request.language == "python3": return await update_python_dependencies() else: - return error_response(-400, "unsupported language") + return error_response(400, "unsupported language") diff --git a/sandbox/app/models.py b/sandbox/app/models.py index e7492b4c..1c157a24 100644 --- a/sandbox/app/models.py +++ b/sandbox/app/models.py @@ -75,6 +75,4 @@ def success_response(data: Any) -> ApiResponse: def error_response(code: int, message: str) -> ApiResponse: """Create error response""" - if code >= 0: - code = -1 return ApiResponse(code=code, message=message, data=None) diff --git a/sandbox/app/services/nodejs_service.py b/sandbox/app/services/nodejs_service.py index ffd6127b..fb5d99cc 100644 --- a/sandbox/app/services/nodejs_service.py +++ b/sandbox/app/services/nodejs_service.py @@ -27,11 +27,11 @@ async def run_nodejs_code(code: str, preload: str, options: RunnerOptions): try: runner = NodejsRunner() result = await runner.run(code, options, preload) - if result.exit_code == signal.SIGSYS + 0x80: + if result.exit_code in [signal.SIGSYS + 0x80, -signal.SIGSYS]: return error_response(31, "sandbox security policy violation") if result.exit_code != 0: - return error_response(500, result.stderr) + return error_response(result.exit_code, result.stderr) return success_response(RunCodeResponse( stdout=result.stdout, @@ -39,5 +39,5 @@ async def run_nodejs_code(code: str, preload: str, options: RunnerOptions): )) except Exception as e: - logger.error(f"Python execution failed: {e}", exc_info=True) - return error_response(-500, str(e)) + logger.error(f"JavaScript execution failed: {e}", exc_info=True) + return error_response(500, str(e)) diff --git a/sandbox/app/services/python_service.py b/sandbox/app/services/python_service.py index 210b2086..ff3bbd04 100644 --- a/sandbox/app/services/python_service.py +++ b/sandbox/app/services/python_service.py @@ -47,7 +47,7 @@ async def run_python_code(code: str, preload: str, options: RunnerOptions): except Exception as e: logger.error(f"Python execution failed: {e}", exc_info=True) - return error_response(-500, str(e)) + return error_response(500, str(e)) async def list_python_dependencies():