feat(workflow): Add a new node for executing code
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
from app.core.workflow.nodes.code.node import CodeNode
|
||||
|
||||
__all__ = ["CodeNode"]
|
||||
50
api/app/core/workflow/nodes/code/config.py
Normal file
50
api/app/core/workflow/nodes/code/config.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from typing import Literal
|
||||
from pydantic import Field, BaseModel
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableType
|
||||
|
||||
|
||||
class InputVariable(BaseModel):
|
||||
name: str = Field(
|
||||
...,
|
||||
description="variable name"
|
||||
)
|
||||
|
||||
variable: str = Field(
|
||||
...,
|
||||
description="variable selector"
|
||||
)
|
||||
|
||||
|
||||
class OutputVariable(BaseModel):
|
||||
name: str = Field(
|
||||
...,
|
||||
description="variable name"
|
||||
)
|
||||
|
||||
type: VariableType = Field(
|
||||
...,
|
||||
description="variable selector"
|
||||
)
|
||||
|
||||
|
||||
class CodeNodeConfig(BaseNodeConfig):
|
||||
input_variables: list[InputVariable] = Field(
|
||||
default_factory=list,
|
||||
description="input variables"
|
||||
)
|
||||
|
||||
output_variables: list[OutputVariable] = Field(
|
||||
default_factory=list,
|
||||
description="output variables"
|
||||
)
|
||||
|
||||
code_content: str = Field(
|
||||
default="",
|
||||
description="code content"
|
||||
)
|
||||
|
||||
language: Literal['python3', 'nodejs'] = Field(
|
||||
...,
|
||||
description="language"
|
||||
)
|
||||
122
api/app/core/workflow/nodes/code/node.py
Normal file
122
api/app/core/workflow/nodes/code/node.py
Normal file
@@ -0,0 +1,122 @@
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from string import Template
|
||||
from textwrap import dedent
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from sympy.physics.vector import vlatex
|
||||
|
||||
from app.core.workflow.nodes import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.base_config import VariableType
|
||||
from app.core.workflow.nodes.code.config import CodeNodeConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SCRIPT_TEMPLATE = Template(dedent("""
|
||||
$code
|
||||
|
||||
import json
|
||||
from base64 import b64decode
|
||||
|
||||
# decode and prepare input dict
|
||||
inputs_obj = json.loads(b64decode('$inputs_variable').decode('utf-8'))
|
||||
|
||||
# execute main function
|
||||
output_obj = main(**inputs_obj)
|
||||
|
||||
# convert output to json and print
|
||||
output_json = json.dumps(output_obj, indent=4)
|
||||
result = "<<RESULT>>" + output_json + "<<RESULT>>"
|
||||
print(result)
|
||||
"""))
|
||||
|
||||
|
||||
class CodeNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config: CodeNodeConfig | None = None
|
||||
|
||||
def extract_result(self, content: str):
|
||||
match = re.search(r'<<RESULT>>(.*?)<<RESULT>>', content, re.DOTALL)
|
||||
if match:
|
||||
extracted = match.group(1)
|
||||
exec_result = json.loads(extracted)
|
||||
result = {}
|
||||
for output in self.typed_config.output_variables:
|
||||
value = exec_result.get(output.name)
|
||||
if not value:
|
||||
raise RuntimeError(f"Return value {output.name} does not exist")
|
||||
match output.type:
|
||||
case VariableType.STRING:
|
||||
if not isinstance(value, str):
|
||||
raise RuntimeError(f"Return value {output.name} should be a string")
|
||||
case VariableType.BOOLEAN:
|
||||
if not isinstance(value, bool):
|
||||
raise RuntimeError(f"Return value {output.name} should be a boolean")
|
||||
case VariableType.NUMBER:
|
||||
if not isinstance(value, (int, float)):
|
||||
raise RuntimeError(f"Return value {output.name} should be a number")
|
||||
case VariableType.OBJECT:
|
||||
if not isinstance(value, dict):
|
||||
raise RuntimeError(f"Return value {output.name} should be a dictionary")
|
||||
case VariableType.ARRAY_STRING:
|
||||
if not isinstance(value, list) or not all(isinstance(v, str) for v in value):
|
||||
raise RuntimeError(f"Return value {output.name} should be a list of strings")
|
||||
case VariableType.ARRAY_NUMBER:
|
||||
if not isinstance(value, list) or not all(isinstance(v, (int, float)) for v in value):
|
||||
raise RuntimeError(f"Return value {output.name} should be a list of numbers")
|
||||
case VariableType.ARRAY_OBJECT:
|
||||
if not isinstance(value, list) or not all(isinstance(v, dict) for v in value):
|
||||
raise RuntimeError(f"Return value {output.name} should be a list of dictionaries")
|
||||
case VariableType.ARRAY_BOOLEAN:
|
||||
if not isinstance(value, list) or not all(isinstance(v, bool) for v in value):
|
||||
raise RuntimeError(f"Return value {output.name} should be a list of booleans")
|
||||
result[output.name] = value
|
||||
return result
|
||||
else:
|
||||
raise RuntimeError("The output of main must be a dictionary")
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
self.typed_config = CodeNodeConfig(**self.config)
|
||||
input_variable_dict = {}
|
||||
for input_variable in self.typed_config.input_variables:
|
||||
input_variable_dict[input_variable.name] = self.get_variable(input_variable.variable, state)
|
||||
code = base64.b64decode(
|
||||
self.typed_config.code
|
||||
).decode("utf-8")
|
||||
|
||||
input_variable_dict = base64.b64encode(
|
||||
json.dumps(input_variable_dict).encode("utf-8")
|
||||
).decode("utf-8")
|
||||
|
||||
final_script = SCRIPT_TEMPLATE.substitute(
|
||||
code=code,
|
||||
inputs_variable=input_variable_dict,
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
"http://sandbox:8194/v1/sandbox/run",
|
||||
headers={
|
||||
"x-api-key": 'redbear-sandbox'
|
||||
},
|
||||
json={
|
||||
"language": "python3",
|
||||
"code": base64.b64encode(final_script.encode("utf-8")).decode("utf-8"),
|
||||
"options": {
|
||||
"enable_network": True
|
||||
}
|
||||
}
|
||||
)
|
||||
resp = response.json()
|
||||
|
||||
match resp['code']:
|
||||
case 31:
|
||||
raise RuntimeError("Operation not permitted")
|
||||
case 0:
|
||||
return self.extract_result(resp["data"]["stdout"])
|
||||
case _:
|
||||
raise Exception(resp["message"])
|
||||
@@ -10,21 +10,22 @@ from app.core.workflow.nodes.base_config import (
|
||||
VariableDefinition,
|
||||
VariableType,
|
||||
)
|
||||
from app.core.workflow.nodes.code.config import CodeNodeConfig
|
||||
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
|
||||
from app.core.workflow.nodes.end.config import EndNodeConfig
|
||||
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig
|
||||
from app.core.workflow.nodes.if_else.config import IfElseNodeConfig
|
||||
from app.core.workflow.nodes.jinja_render.config import JinjaRenderNodeConfig
|
||||
from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig
|
||||
from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig
|
||||
from app.core.workflow.nodes.start.config import StartNodeConfig
|
||||
from app.core.workflow.nodes.transform.config import TransformNodeConfig
|
||||
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
|
||||
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
|
||||
from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig
|
||||
from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig
|
||||
from app.core.workflow.nodes.start.config import StartNodeConfig
|
||||
from app.core.workflow.nodes.tool.config import ToolNodeConfig
|
||||
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
|
||||
from app.core.workflow.nodes.transform.config import TransformNodeConfig
|
||||
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
|
||||
|
||||
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
|
||||
__all__ = [
|
||||
# 基础类
|
||||
"BaseNodeConfig",
|
||||
@@ -49,5 +50,6 @@ __all__ = [
|
||||
"QuestionClassifierNodeConfig",
|
||||
"ToolNodeConfig",
|
||||
"MemoryReadNodeConfig",
|
||||
"MemoryWriteNodeConfig"
|
||||
"MemoryWriteNodeConfig",
|
||||
"CodeNodeConfig"
|
||||
]
|
||||
|
||||
@@ -10,6 +10,7 @@ from typing import Any, Union
|
||||
from app.core.workflow.nodes.agent import AgentNode
|
||||
from app.core.workflow.nodes.assigner import AssignerNode
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.code import CodeNode
|
||||
from app.core.workflow.nodes.cycle_graph.node import CycleGraphNode
|
||||
from app.core.workflow.nodes.end import EndNode
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
@@ -49,7 +50,8 @@ WorkflowNode = Union[
|
||||
QuestionClassifierNode,
|
||||
ToolNode,
|
||||
MemoryReadNode,
|
||||
MemoryWriteNode
|
||||
MemoryWriteNode,
|
||||
CodeNode
|
||||
]
|
||||
|
||||
|
||||
@@ -81,6 +83,7 @@ class NodeFactory:
|
||||
NodeType.TOOL: ToolNode,
|
||||
NodeType.MEMORY_READ: MemoryReadNode,
|
||||
NodeType.MEMORY_WRITE: MemoryWriteNode,
|
||||
NodeType.CODE: CodeNode,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -15,7 +15,6 @@ class ExecutionResult:
|
||||
self.stdout = stdout
|
||||
self.stderr = stderr
|
||||
self.exit_code = exit_code
|
||||
self.error = error
|
||||
|
||||
|
||||
class CodeExecutor(ABC):
|
||||
|
||||
@@ -9,12 +9,15 @@ from app.config import SANDBOX_USER_ID, SANDBOX_GROUP_ID, get_config
|
||||
from app.core.encryption import generate_key, encrypt_code
|
||||
from app.core.executor import CodeExecutor, ExecutionResult
|
||||
from app.core.runners.python.settings import check_lib_avaiable, release_lib_binary, LIB_PATH
|
||||
from app.logger import get_logger
|
||||
from app.models import RunnerOptions
|
||||
|
||||
# Python sandbox prescript template
|
||||
with open("app/core/runners/python/prescript.py") as f:
|
||||
PYTHON_PRESCRIPT = f.read()
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class PythonRunner(CodeExecutor):
|
||||
"""Python code runner with security isolation"""
|
||||
@@ -106,6 +109,7 @@ class PythonRunner(CodeExecutor):
|
||||
env["ALLOWED_SYSCALLS"] = ",".join(map(str, config.allowed_syscalls))
|
||||
|
||||
# Execute with Python interpreter
|
||||
logger.info(encoded_key)
|
||||
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
config.python_path,
|
||||
@@ -143,7 +147,6 @@ class PythonRunner(CodeExecutor):
|
||||
stdout="",
|
||||
stderr="Execution timeout",
|
||||
exit_code=-1,
|
||||
error="Execution timeout"
|
||||
)
|
||||
|
||||
finally:
|
||||
|
||||
@@ -37,8 +37,8 @@ async def run_python_code(code: str, preload: str, options: RunnerOptions):
|
||||
if result.exit_code == -signal.SIGSYS:
|
||||
return error_response(31, "sandbox security policy violation")
|
||||
|
||||
if result.error:
|
||||
return error_response(-500, result.error)
|
||||
if result.stderr:
|
||||
return error_response(500, result.stderr)
|
||||
|
||||
return success_response(RunCodeResponse(
|
||||
stdout=result.stdout,
|
||||
|
||||
Reference in New Issue
Block a user