Merge #34 into develop from feature/20251219_myh

feat(workflow): add assigner node and fix circular imports with minor code style cleanup

* feature/20251219_myh: (7 commits)
  style(service): workflow
  style(workflow): remove unnecessary indentation
  revert(workflow): read conversation variables from database instead of API input
  feat(workflow): add assigner node and fix circular imports with minor code style cleanup
  fix(workflow): fix incorrect list append/pop logic in assigner node
  fix(workflow): fix incorrect list extend logic in assigner node
  fix(workflow): fix incorrect list append logic in assigner node

Signed-off-by: Eternity <1533512157@qq.com>
Commented-by: Eternity <1533512157@qq.com>
Reviewed-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com>
Merged-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com>

CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/34
This commit is contained in:
朱文辉
2025-12-23 17:06:43 +08:00
16 changed files with 466 additions and 181 deletions

View File

@@ -5,9 +5,11 @@
""" """
from app.core.workflow.nodes.agent import AgentNode from app.core.workflow.nodes.agent import AgentNode
from app.core.workflow.nodes.assigner import AssignerNode
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.end import EndNode from app.core.workflow.nodes.end import EndNode
from app.core.workflow.nodes.if_else import IfElseNode from app.core.workflow.nodes.if_else import IfElseNode
# from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
from app.core.workflow.nodes.llm import LLMNode from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode
from app.core.workflow.nodes.start import StartNode from app.core.workflow.nodes.start import StartNode
@@ -23,5 +25,7 @@ __all__ = [
"StartNode", "StartNode",
"EndNode", "EndNode",
"NodeFactory", "NodeFactory",
"WorkflowNode" "WorkflowNode",
# "KnowledgeRetrievalNode",
"AssignerNode",
] ]

View File

@@ -0,0 +1,4 @@
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
from app.core.workflow.nodes.assigner.node import AssignerNode
__all__ = ["AssignerNode", "AssignerNodeConfig"]

View File

@@ -0,0 +1,21 @@
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.core.workflow.nodes.enums import AssignmentOperator
class AssignerNodeConfig(BaseNodeConfig):
variable_selector: str | list[str] = Field(
...,
description="Variables to be assigned",
)
operation: AssignmentOperator = Field(
...,
description="Operator to assign",
)
value: str | list[str] = Field(
...,
description="Values to assign",
)

View File

@@ -0,0 +1,80 @@
import logging
from typing import Any
from app.core.workflow.expression_evaluator import ExpressionEvaluator
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import AssignmentOperator
from app.core.workflow.nodes.operators import AssignmentOperatorInstance
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
class AssignerNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = AssignerNodeConfig(**self.config)
async def execute(self, state: WorkflowState) -> Any:
"""
Execute the assignment operation defined by this node.
Args:
state: The current workflow state, including conversation variables,
node outputs, and system variables.
Returns:
None or the result of the assignment operation.
"""
# Initialize a variable pool for accessing conversation, node, and system variables
pool = VariablePool(state)
# Get the target variable selector (e.g., "conv.test")
variable_selector = self.typed_config.variable_selector
if isinstance(variable_selector, str):
# Support dot-separated string paths, e.g., "conv.test" -> ["conv", "test"]
variable_selector = variable_selector.split('.')
# Only conversation variables ('conv') are allowed
if variable_selector[0] != 'conv': # TODO: Loop node variable support (Feature)
raise ValueError("Only conversation variables can be assigned.")
# Get the value or expression to assign
value = self.typed_config.value
if isinstance(value, list):
value = '.'.join(value)
value = ExpressionEvaluator.evaluate(
expression=value,
variables=pool.get_all_conversation_vars(),
node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars(),
)
# Select the appropriate assignment operator instance based on the target variable type
operator: AssignmentOperatorInstance = AssignmentOperator.get_operator(pool.get(variable_selector))(
pool, variable_selector, value
)
# Execute the configured assignment operation
match self.typed_config.operation:
case AssignmentOperator.ASSIGN:
operator.assign()
case AssignmentOperator.CLEAR:
operator.clear()
case AssignmentOperator.ADD:
operator.add()
case AssignmentOperator.SUBTRACT:
operator.subtract()
case AssignmentOperator.MULTIPLY:
operator.multiply()
case AssignmentOperator.DIVIDE:
operator.divide()
case AssignmentOperator.APPEND:
operator.append()
case AssignmentOperator.REMOVE_FIRST:
operator.remove_first()
case AssignmentOperator.REMOVE_LAST:
operator.remove_last()
case _:
raise ValueError(f"Invalid Operator: {self.typed_config.operation}")

View File

@@ -14,6 +14,8 @@ from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig
from app.core.workflow.nodes.agent.config import AgentNodeConfig from app.core.workflow.nodes.agent.config import AgentNodeConfig
from app.core.workflow.nodes.transform.config import TransformNodeConfig from app.core.workflow.nodes.transform.config import TransformNodeConfig
from app.core.workflow.nodes.if_else.config import IfElseNodeConfig from app.core.workflow.nodes.if_else.config import IfElseNodeConfig
# from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
__all__ = [ __all__ = [
# 基础类 # 基础类
@@ -28,4 +30,6 @@ __all__ = [
"AgentNodeConfig", "AgentNodeConfig",
"TransformNodeConfig", "TransformNodeConfig",
"IfElseNodeConfig", "IfElseNodeConfig",
# "KnowledgeRetrievalNodeConfig",
"AssignerNodeConfig",
] ]

View File

@@ -33,7 +33,7 @@ class EndNode(BaseNode):
# 获取配置的输出模板 # 获取配置的输出模板
output_template = self.config.get("output") output_template = self.config.get("output")
# 如果配置了输出模板,使用模板渲染;否则使用默认输出 # 如果配置了输出模板,使用模板渲染;否则使用默认输出
if output_template: if output_template:
output = self._render_template(output_template, state) output = self._render_template(output_template, state)
@@ -45,17 +45,17 @@ class EndNode(BaseNode):
total_nodes = len(node_outputs) total_nodes = len(node_outputs)
logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点") logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
return output return output
def _extract_referenced_nodes(self, template: str) -> list[str]: def _extract_referenced_nodes(self, template: str) -> list[str]:
"""从模板中提取引用的节点 ID """从模板中提取引用的节点 ID
例如:'结果:{{llm_qa.output}}' -> ['llm_qa'] 例如:'结果:{{llm_qa.output}}' -> ['llm_qa']
Args: Args:
template: 模板字符串 template: 模板字符串
Returns: Returns:
引用的节点 ID 列表 引用的节点 ID 列表
""" """
@@ -63,44 +63,44 @@ class EndNode(BaseNode):
pattern = r'\{\{([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\}\}' pattern = r'\{\{([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\}\}'
matches = re.findall(pattern, template) matches = re.findall(pattern, template)
return list(set(matches)) # 去重 return list(set(matches)) # 去重
def _parse_template_parts(self, template: str, state: WorkflowState) -> list[dict]: def _parse_template_parts(self, template: str, state: WorkflowState) -> list[dict]:
"""解析模板,分离静态文本和动态引用 """解析模板,分离静态文本和动态引用
例如:'你好 {{llm.output}}, 这是后缀' 例如:'你好 {{llm.output}}, 这是后缀'
返回:[ 返回:[
{"type": "static", "content": "你好 "}, {"type": "static", "content": "你好 "},
{"type": "dynamic", "node_id": "llm", "field": "output"}, {"type": "dynamic", "node_id": "llm", "field": "output"},
{"type": "static", "content": ", 这是后缀"} {"type": "static", "content": ", 这是后缀"}
] ]
Args: Args:
template: 模板字符串 template: 模板字符串
state: 工作流状态 state: 工作流状态
Returns: Returns:
模板部分列表 模板部分列表
""" """
import re import re
parts = [] parts = []
last_end = 0 last_end = 0
# 匹配 {{xxx}} 或 {{ xxx }} 格式(支持空格) # 匹配 {{xxx}} 或 {{ xxx }} 格式(支持空格)
pattern = r'\{\{\s*([^}]+?)\s*\}\}' pattern = r'\{\{\s*([^}]+?)\s*\}\}'
for match in re.finditer(pattern, template): for match in re.finditer(pattern, template):
start, end = match.span() start, end = match.span()
# 添加前面的静态文本 # 添加前面的静态文本
if start > last_end: if start > last_end:
static_text = template[last_end:start] static_text = template[last_end:start]
if static_text: if static_text:
parts.append({"type": "static", "content": static_text}) parts.append({"type": "static", "content": static_text})
# 解析动态引用 # 解析动态引用
ref = match.group(1).strip() ref = match.group(1).strip()
# 检查是否是节点引用(如 llm.output 或 llm_qa.output # 检查是否是节点引用(如 llm.output 或 llm_qa.output
if '.' in ref: if '.' in ref:
node_id, field = ref.split('.', 1) node_id, field = ref.split('.', 1)
@@ -115,62 +115,62 @@ class EndNode(BaseNode):
# 直接渲染这部分 # 直接渲染这部分
rendered = self._render_template(f"{{{{{ref}}}}}", state) rendered = self._render_template(f"{{{{{ref}}}}}", state)
parts.append({"type": "static", "content": rendered}) parts.append({"type": "static", "content": rendered})
last_end = end last_end = end
# 添加最后的静态文本 # 添加最后的静态文本
if last_end < len(template): if last_end < len(template):
static_text = template[last_end:] static_text = template[last_end:]
if static_text: if static_text:
parts.append({"type": "static", "content": static_text}) parts.append({"type": "static", "content": static_text})
return parts return parts
async def execute_stream(self, state: WorkflowState): async def execute_stream(self, state: WorkflowState):
"""流式执行 end 节点业务逻辑 """流式执行 end 节点业务逻辑
智能输出策略: 智能输出策略:
1. 检测模板中是否引用了直接上游节点 1. 检测模板中是否引用了直接上游节点
2. 如果引用了,只输出该引用**之后**的部分(后缀) 2. 如果引用了,只输出该引用**之后**的部分(后缀)
3. 前缀和引用内容已经在上游节点流式输出时发送了 3. 前缀和引用内容已经在上游节点流式输出时发送了
示例:'{{start.test}}hahaha {{ llm_qa.output }} lalalalala a' 示例:'{{start.test}}hahaha {{ llm_qa.output }} lalalalala a'
- 直接上游节点是 llm_qa - 直接上游节点是 llm_qa
- 前缀 '{{start.test}}hahaha ' 已在 LLM 节点流式输出前发送 - 前缀 '{{start.test}}hahaha ' 已在 LLM 节点流式输出前发送
- LLM 内容在 LLM 节点流式输出 - LLM 内容在 LLM 节点流式输出
- End 节点只输出 ' lalalalala a'(后缀,一次性输出) - End 节点只输出 ' lalalalala a'(后缀,一次性输出)
Args: Args:
state: 工作流状态 state: 工作流状态
Yields: Yields:
完成标记 完成标记
""" """
logger.info(f"节点 {self.node_id} (End) 开始执行(流式)") logger.info(f"节点 {self.node_id} (End) 开始执行(流式)")
# 获取配置的输出模板 # 获取配置的输出模板
output_template = self.config.get("output") output_template = self.config.get("output")
if not output_template: if not output_template:
output = "工作流已完成" output = "工作流已完成"
yield {"__final__": True, "result": output} yield {"__final__": True, "result": output}
return return
# 找到直接上游节点 # 找到直接上游节点
direct_upstream_nodes = [] direct_upstream_nodes = []
for edge in self.workflow_config.get("edges", []): for edge in self.workflow_config.get("edges", []):
if edge.get("target") == self.node_id: if edge.get("target") == self.node_id:
source_node_id = edge.get("source") source_node_id = edge.get("source")
direct_upstream_nodes.append(source_node_id) direct_upstream_nodes.append(source_node_id)
logger.info(f"节点 {self.node_id} 的直接上游节点: {direct_upstream_nodes}") logger.info(f"节点 {self.node_id} 的直接上游节点: {direct_upstream_nodes}")
# 解析模板部分 # 解析模板部分
parts = self._parse_template_parts(output_template, state) parts = self._parse_template_parts(output_template, state)
logger.info(f"节点 {self.node_id} 解析模板,共 {len(parts)} 个部分") logger.info(f"节点 {self.node_id} 解析模板,共 {len(parts)} 个部分")
for i, part in enumerate(parts): for i, part in enumerate(parts):
logger.info(f"[模板解析] part[{i}]: {part}") logger.info(f"[模板解析] part[{i}]: {part}")
# 找到第一个引用直接上游节点的动态引用 # 找到第一个引用直接上游节点的动态引用
upstream_ref_index = None upstream_ref_index = None
for i, part in enumerate(parts): for i, part in enumerate(parts):
@@ -178,12 +178,12 @@ class EndNode(BaseNode):
upstream_ref_index = i upstream_ref_index = i
logger.info(f"节点 {self.node_id} 找到直接上游节点 {part['node_id']} 的引用,索引: {i}") logger.info(f"节点 {self.node_id} 找到直接上游节点 {part['node_id']} 的引用,索引: {i}")
break break
if upstream_ref_index is None: if upstream_ref_index is None:
# 没有引用直接上游节点,输出完整模板内容 # 没有引用直接上游节点,输出完整模板内容
output = self._render_template(output_template, state) output = self._render_template(output_template, state)
logger.info(f"节点 {self.node_id} 没有引用直接上游节点,输出完整内容: '{output[:50]}...'") logger.info(f"节点 {self.node_id} 没有引用直接上游节点,输出完整内容: '{output[:50]}...'")
# 通过 writer 发送完整内容(作为一个 message chunk # 通过 writer 发送完整内容(作为一个 message chunk
from langgraph.config import get_stream_writer from langgraph.config import get_stream_writer
writer = get_stream_writer() writer = get_stream_writer()
@@ -196,14 +196,14 @@ class EndNode(BaseNode):
"is_suffix": False "is_suffix": False
}) })
logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容") logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容")
# yield 完成标记 # yield 完成标记
yield {"__final__": True, "result": output} yield {"__final__": True, "result": output}
return return
# 有引用直接上游节点,只输出该引用之后的部分(后缀) # 有引用直接上游节点,只输出该引用之后的部分(后缀)
logger.info(f"节点 {self.node_id} 检测到直接上游节点引用,只输出后缀部分(从索引 {upstream_ref_index + 1} 开始)") logger.info(f"节点 {self.node_id} 检测到直接上游节点引用,只输出后缀部分(从索引 {upstream_ref_index + 1} 开始)")
# 收集后缀部分 # 收集后缀部分
suffix_parts = [] suffix_parts = []
logger.info(f"[后缀调试] 开始收集后缀,从索引 {upstream_ref_index + 1}{len(parts) - 1}") logger.info(f"[后缀调试] 开始收集后缀,从索引 {upstream_ref_index + 1}{len(parts) - 1}")
@@ -214,7 +214,7 @@ class EndNode(BaseNode):
# 静态文本 # 静态文本
logger.info(f"[后缀调试] 添加静态文本: '{part['content']}'") logger.info(f"[后缀调试] 添加静态文本: '{part['content']}'")
suffix_parts.append(part["content"]) suffix_parts.append(part["content"])
elif part["type"] == "dynamic": elif part["type"] == "dynamic":
# Other dynamic references (if there are multiple references) # Other dynamic references (if there are multiple references)
node_id = part["node_id"] node_id = part["node_id"]
@@ -229,21 +229,21 @@ class EndNode(BaseNode):
except Exception as e: except Exception as e:
logger.warning(f"[后缀调试] 获取变量 {node_id}.{field} 失败: {e}") logger.warning(f"[后缀调试] 获取变量 {node_id}.{field} 失败: {e}")
content = "" content = ""
# Convert to string if not None # Convert to string if not None
suffix_parts.append(str(content) if content is not None else "") suffix_parts.append(str(content) if content is not None else "")
# 拼接后缀 # 拼接后缀
suffix = "".join(suffix_parts) suffix = "".join(suffix_parts)
# 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀) # 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀)
full_output = self._render_template(output_template, state) full_output = self._render_template(output_template, state)
logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}") logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}")
logger.info(f"[后缀调试] 后缀内容: '{suffix}'") logger.info(f"[后缀调试] 后缀内容: '{suffix}'")
logger.info(f"[后缀调试] 后缀长度: {len(suffix)}") logger.info(f"[后缀调试] 后缀长度: {len(suffix)}")
logger.info(f"[后缀调试] 后缀是否为空: {not suffix}") logger.info(f"[后缀调试] 后缀是否为空: {not suffix}")
if suffix: if suffix:
logger.info(f"节点 {self.node_id} 输出后缀: '{suffix}...' (长度: {len(suffix)})") logger.info(f"节点 {self.node_id} 输出后缀: '{suffix}...' (长度: {len(suffix)})")
# 一次性输出后缀(作为单个 chunk # 一次性输出后缀(作为单个 chunk
@@ -266,8 +266,8 @@ class EndNode(BaseNode):
# 统计信息 # 统计信息
node_outputs = state.get("node_outputs", {}) node_outputs = state.get("node_outputs", {})
total_nodes = len(node_outputs) total_nodes = len(node_outputs)
logger.info(f"节点 {self.node_id} (End) 执行完成(流式),共执行了 {total_nodes} 个节点") logger.info(f"节点 {self.node_id} (End) 执行完成(流式),共执行了 {total_nodes} 个节点")
# yield 完成标记(包含完整输出) # yield 完成标记(包含完整输出)
yield {"__final__": True, "result": full_output} yield {"__final__": True, "result": full_output}

View File

@@ -1,5 +1,14 @@
from enum import StrEnum from enum import StrEnum
from app.core.workflow.nodes.operators import (
StringOperator,
NumberOperator,
AssignmentOperatorType,
BooleanOperator,
ArrayOperator,
ObjectOperator
)
class NodeType(StrEnum): class NodeType(StrEnum):
START = "start" START = "start"
@@ -14,6 +23,7 @@ class NodeType(StrEnum):
HTTP_REQUEST = "http-request" HTTP_REQUEST = "http-request"
TOOL = "tool" TOOL = "tool"
AGENT = "agent" AGENT = "agent"
ASSIGNER = "assigner"
class ComparisonOperator(StrEnum): class ComparisonOperator(StrEnum):
@@ -34,3 +44,32 @@ class ComparisonOperator(StrEnum):
class LogicOperator(StrEnum): class LogicOperator(StrEnum):
AND = "and" AND = "and"
OR = "or" OR = "or"
class AssignmentOperator(StrEnum):
ASSIGN = "assign"
CLEAR = "clear"
ADD = "add" # +=
SUBTRACT = "subtract" # -=
MULTIPLY = "multiply" # *=
DIVIDE = "divide" # /=
APPEND = "append"
REMOVE_LAST = "remove_last"
REMOVE_FIRST = "remove_first"
@classmethod
def get_operator(cls, obj) -> AssignmentOperatorType:
if isinstance(obj, str):
return StringOperator
elif isinstance(obj, bool):
return BooleanOperator
elif isinstance(obj, (int, float)):
return NumberOperator
elif isinstance(obj, list):
return ArrayOperator
elif isinstance(obj, dict):
return ObjectOperator
raise TypeError(f"Unsupported variable type ({type(obj)})")

View File

@@ -1,7 +1,7 @@
import logging import logging
from typing import Any from typing import Any
from app.core.workflow.nodes import BaseNode, WorkflowState from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import ComparisonOperator from app.core.workflow.nodes.enums import ComparisonOperator
from app.core.workflow.nodes.if_else import IfElseNodeConfig from app.core.workflow.nodes.if_else import IfElseNodeConfig
from app.core.workflow.nodes.if_else.config import ConditionDetail from app.core.workflow.nodes.if_else.config import ConditionDetail

View File

@@ -11,6 +11,7 @@ from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.models import RedBearLLM, RedBearModelConfig from app.core.models import RedBearLLM, RedBearModelConfig
from app.db import get_db_context from app.db import get_db_context
from app.models import ModelType
from app.services.model_service import ModelConfigService from app.services.model_service import ModelConfigService
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
@@ -136,7 +137,7 @@ class LLMNode(BaseNode):
base_url=api_base, base_url=api_base,
extra_params=extra_params extra_params=extra_params
), ),
type=model_type type=ModelType(model_type)
) )
logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}") logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}")

View File

@@ -7,6 +7,7 @@
import logging import logging
from typing import Any, Union from typing import Any, Union
# from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
from app.core.workflow.nodes.agent import AgentNode from app.core.workflow.nodes.agent import AgentNode
from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.end import EndNode from app.core.workflow.nodes.end import EndNode
@@ -15,6 +16,7 @@ from app.core.workflow.nodes.if_else import IfElseNode
from app.core.workflow.nodes.llm import LLMNode from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.start import StartNode from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.transform import TransformNode from app.core.workflow.nodes.transform import TransformNode
from app.core.workflow.nodes.assigner import AssignerNode
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -26,6 +28,8 @@ WorkflowNode = Union[
IfElseNode, IfElseNode,
AgentNode, AgentNode,
TransformNode, TransformNode,
AssignerNode,
# KnowledgeRetrievalNode,
] ]
@@ -42,7 +46,9 @@ class NodeFactory:
NodeType.LLM: LLMNode, NodeType.LLM: LLMNode,
NodeType.AGENT: AgentNode, NodeType.AGENT: AgentNode,
NodeType.TRANSFORM: TransformNode, NodeType.TRANSFORM: TransformNode,
NodeType.IF_ELSE: IfElseNode NodeType.IF_ELSE: IfElseNode,
# NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
NodeType.ASSIGNER: AssignerNode,
} }
@classmethod @classmethod
@@ -82,10 +88,6 @@ class NodeFactory:
""" """
node_type = node_config.get("type") node_type = node_config.get("type")
# 跳过条件节点(由 LangGraph 处理)
if node_type == "condition":
return None
# 获取节点类 # 获取节点类
node_class = cls._node_types.get(node_type) node_class = cls._node_types.get(node_type)
if not node_class: if not node_class:

View File

@@ -0,0 +1,146 @@
from abc import ABC
from typing import Union, Type
from app.core.workflow.variable_pool import VariablePool
class OperatorBase(ABC):
def __init__(self, pool: VariablePool, left_selector, right):
self.pool = pool
self.left_selector = left_selector
self.right = right
self.type_limit: type[str, int, dict, list] = None
def check(self, no_right=False):
left = self.pool.get(self.left_selector)
if not isinstance(left, self.type_limit):
raise TypeError(f"The variable to be operated on must be of {self.type_limit} type")
if not no_right and not isinstance(self.right, self.type_limit):
raise TypeError(f"The value assigned to the string variable must also be of {self.type_limit} type")
class StringOperator(OperatorBase):
def __init__(self, pool: VariablePool, left_selector, right):
super().__init__(pool, left_selector, right)
self.type_limit = str
def assign(self) -> None:
self.check()
self.pool.set(self.left_selector, self.right)
def clear(self) -> None:
self.check(no_right=True)
self.pool.set(self.left_selector, '')
class NumberOperator(OperatorBase):
def __init__(self, pool: VariablePool, left_selector, right):
super().__init__(pool, left_selector, right)
self.type_limit = (float, int)
def assign(self) -> None:
self.check()
self.pool.set(self.left_selector, self.right)
def clear(self) -> None:
self.check(no_right=True)
self.pool.set(self.left_selector, 0)
def add(self) -> None:
self.check()
origin = self.pool.get(self.left_selector)
self.pool.set(self.left_selector, origin + self.right)
def subtract(self) -> None:
self.check()
origin = self.pool.get(self.left_selector)
self.pool.set(self.left_selector, origin - self.right)
def multiply(self) -> None:
self.check()
origin = self.pool.get(self.left_selector)
self.pool.set(self.left_selector, origin * self.right)
def divide(self) -> None:
self.check()
origin = self.pool.get(self.left_selector)
self.pool.set(self.left_selector, origin / self.right)
class BooleanOperator(OperatorBase):
def __init__(self, pool: VariablePool, left_selector, right):
super().__init__(pool, left_selector, right)
self.type_limit = bool
def assign(self) -> None:
self.check()
self.pool.set(self.left_selector, self.right)
def clear(self) -> None:
self.check(no_right=True)
self.pool.set(self.left_selector, False)
class ArrayOperator(OperatorBase):
def __init__(self, pool: VariablePool, left_selector, right):
super().__init__(pool, left_selector, right)
self.type_limit = list
def assign(self) -> None:
self.check()
self.pool.set(self.left_selector, self.right)
def clear(self) -> None:
self.check(no_right=True)
self.pool.set(self.left_selector, list())
def append(self) -> None:
self.check(no_right=True)
# TODOrequire type limit in list
origin = self.pool.get(self.left_selector)
origin.append(self.right)
self.pool.set(self.left_selector, origin)
def extend(self) -> None:
self.check(no_right=True)
origin = self.pool.get(self.left_selector)
origin.extend(self.right)
self.pool.set(self.left_selector, origin)
def remove_last(self) -> None:
self.check(no_right=True)
origin = self.pool.get(self.left_selector)
origin.pop()
self.pool.set(self.left_selector, origin)
def remove_first(self) -> None:
self.check(no_right=True)
origin = self.pool.get(self.left_selector)
origin.pop(0)
self.pool.set(self.left_selector, origin)
class ObjectOperator(OperatorBase):
def __init__(self, pool: VariablePool, left_selector, right):
super().__init__(pool, left_selector, right)
self.type_limit = object
def assign(self) -> None:
self.check()
self.pool.set(self.left_selector, self.right)
def clear(self) -> None:
self.check(no_right=True)
self.pool.set(self.left_selector, dict())
AssignmentOperatorInstance = Union[
StringOperator,
NumberOperator,
BooleanOperator,
ArrayOperator,
ObjectOperator
]
AssignmentOperatorType = Type[AssignmentOperatorInstance]

View File

@@ -10,7 +10,10 @@
""" """
import logging import logging
from typing import Any from typing import Any, TYPE_CHECKING
if TYPE_CHECKING:
from app.core.workflow.nodes import WorkflowState
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -82,7 +85,7 @@ class VariablePool:
>>> pool.set(["conv", "user_name"], "张三") >>> pool.set(["conv", "user_name"], "张三")
""" """
def __init__(self, state: dict[str, Any]): def __init__(self, state: "WorkflowState"):
"""初始化变量池 """初始化变量池
Args: Args:

View File

@@ -15,25 +15,6 @@ class ModelType(StrEnum):
EMBEDDING = "embedding" EMBEDDING = "embedding"
RERANK = "rerank" RERANK = "rerank"
@classmethod
def from_str(cls, value: str) -> "ModelType":
"""
Get a ModelType enum instance from a string value.
Args:
value (str): The string representation of the model type.
Returns:
ModelType: The corresponding ModelType enum object.
Raises:
ValueError: If the given value does not match any ModelType.
"""
try:
return cls(value)
except ValueError:
raise ValueError(f"Invalid ModelType: {value}")
class ModelProvider(StrEnum): class ModelProvider(StrEnum):
"""模型提供商枚举""" """模型提供商枚举"""

View File

@@ -1,6 +1,7 @@
import uuid
import datetime import datetime
from typing import Optional, Any, List, Dict, TYPE_CHECKING import uuid
from typing import Optional, Any, List, Dict
from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator
@@ -20,20 +21,19 @@ class KnowledgeBaseConfig(BaseModel):
class KnowledgeRetrievalConfig(BaseModel): class KnowledgeRetrievalConfig(BaseModel):
"""知识库检索配置(支持多个知识库,每个有独立配置)""" """知识库检索配置(支持多个知识库,每个有独立配置)"""
knowledge_bases: List[KnowledgeBaseConfig] = Field( knowledge_bases: List[KnowledgeBaseConfig] = Field(
default_factory=list, default_factory=list,
description="关联的知识库列表,每个知识库有独立配置" description="关联的知识库列表,每个知识库有独立配置"
) )
# 多知识库融合策略 # 多知识库融合策略
merge_strategy: str = Field( merge_strategy: str = Field(
default="weighted", default="weighted",
description="多知识库结果融合策略: weighted | rrf | concat" description="多知识库结果融合策略: weighted | rrf | concat"
) )
reranker_id: Optional[str] = Field(default=None, description="多知识库结果融合的模型ID") reranker_id: Optional[str] = Field(default=None, description="多知识库结果融合的模型ID")
reranker_top_k: int = Field(default=10, ge=0, le=1024, description="多知识库结果融合的模型参数") reranker_top_k: int = Field(default=10, ge=0, le=1024, description="多知识库结果融合的模型参数")
class ToolConfig(BaseModel): class ToolConfig(BaseModel):
"""工具配置""" """工具配置"""
enabled: bool = Field(default=False, description="是否启用该工具") enabled: bool = Field(default=False, description="是否启用该工具")
@@ -63,7 +63,7 @@ class VariableDefinition(BaseModel):
name: str = Field(..., description="变量名称(标识符)") name: str = Field(..., description="变量名称(标识符)")
display_name: Optional[str] = Field(None, description="显示名称(用户看到的名称)") display_name: Optional[str] = Field(None, description="显示名称(用户看到的名称)")
type: str = Field( type: str = Field(
default="string", default="string",
description="变量类型: string(单行文本) | text(多行文本) | number(数字)" description="变量类型: string(单行文本) | text(多行文本) | number(数字)"
) )
required: bool = Field(default=False, description="是否必填") required: bool = Field(default=False, description="是否必填")
@@ -75,32 +75,32 @@ class AgentConfigCreate(BaseModel):
"""Agent 行为配置""" """Agent 行为配置"""
# 提示词配置 # 提示词配置
system_prompt: Optional[str] = Field(default=None, description="系统提示词,定义 Agent 的角色和行为准则") system_prompt: Optional[str] = Field(default=None, description="系统提示词,定义 Agent 的角色和行为准则")
# 模型配置 # 模型配置
default_model_config_id: Optional[uuid.UUID] = Field(default=None, description="默认使用的模型配置ID") default_model_config_id: Optional[uuid.UUID] = Field(default=None, description="默认使用的模型配置ID")
model_parameters: ModelParameters = Field( model_parameters: ModelParameters = Field(
default_factory=ModelParameters, default_factory=ModelParameters,
description="模型参数配置temperature、max_tokens 等)" description="模型参数配置temperature、max_tokens 等)"
) )
# 知识库关联 # 知识库关联
knowledge_retrieval: Optional[KnowledgeRetrievalConfig] = Field( knowledge_retrieval: Optional[KnowledgeRetrievalConfig] = Field(
default=None, default=None,
description="知识库检索配置" description="知识库检索配置"
) )
# 记忆配置 # 记忆配置
memory: MemoryConfig = Field( memory: MemoryConfig = Field(
default_factory=lambda: MemoryConfig(enabled=True), default_factory=lambda: MemoryConfig(enabled=True),
description="对话历史记忆配置" description="对话历史记忆配置"
) )
# 变量配置 # 变量配置
variables: List[VariableDefinition] = Field( variables: List[VariableDefinition] = Field(
default_factory=list, default_factory=list,
description="Agent 可用的变量列表" description="Agent 可用的变量列表"
) )
# 工具配置 # 工具配置
tools: Dict[str, ToolConfig] = Field( tools: Dict[str, ToolConfig] = Field(
default_factory=dict, default_factory=dict,
@@ -120,7 +120,7 @@ class AppCreate(BaseModel):
# only for type=agent # only for type=agent
agent_config: Optional[AgentConfigCreate] = None agent_config: Optional[AgentConfigCreate] = None
# only for type=multi_agent # only for type=multi_agent
multi_agent_config: Optional[Dict[str, Any]] = None multi_agent_config: Optional[Dict[str, Any]] = None
@@ -139,23 +139,23 @@ class AgentConfigUpdate(BaseModel):
"""更新 Agent 行为配置""" """更新 Agent 行为配置"""
# 提示词配置 # 提示词配置
system_prompt: Optional[str] = Field(default=None, description="系统提示词") system_prompt: Optional[str] = Field(default=None, description="系统提示词")
# 模型配置 # 模型配置
default_model_config_id: Optional[uuid.UUID] = Field(default=None, description="默认模型配置ID") default_model_config_id: Optional[uuid.UUID] = Field(default=None, description="默认模型配置ID")
model_parameters: Optional[ModelParameters] = Field(default=None, description="模型参数配置") model_parameters: Optional[ModelParameters] = Field(default=None, description="模型参数配置")
# 知识库关联 # 知识库关联
knowledge_retrieval: Optional[KnowledgeRetrievalConfig] = Field( knowledge_retrieval: Optional[KnowledgeRetrievalConfig] = Field(
default=None, default=None,
description="知识库检索配置" description="知识库检索配置"
) )
# 记忆配置 # 记忆配置
memory: Optional[MemoryConfig] = Field(default=None, description="对话历史记忆配置") memory: Optional[MemoryConfig] = Field(default=None, description="对话历史记忆配置")
# 变量配置 # 变量配置
variables: Optional[List[VariableDefinition]] = Field(default=None, description="变量列表") variables: Optional[List[VariableDefinition]] = Field(default=None, description="变量列表")
# 工具配置 # 工具配置
tools: Optional[Dict[str, ToolConfig]] = Field(default=None, description="工具配置") tools: Optional[Dict[str, ToolConfig]] = Field(default=None, description="工具配置")
@@ -185,7 +185,7 @@ class App(BaseModel):
@field_serializer("created_at", when_used="json") @field_serializer("created_at", when_used="json")
def _serialize_created_at(self, dt: datetime.datetime): def _serialize_created_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None return int(dt.timestamp() * 1000) if dt else None
@field_serializer("updated_at", when_used="json") @field_serializer("updated_at", when_used="json")
def _serialize_updated_at(self, dt: datetime.datetime): def _serialize_updated_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None return int(dt.timestamp() * 1000) if dt else None
@@ -197,26 +197,26 @@ class AgentConfig(BaseModel):
id: uuid.UUID id: uuid.UUID
app_id: uuid.UUID app_id: uuid.UUID
# 提示词 # 提示词
system_prompt: Optional[str] = None system_prompt: Optional[str] = None
# 模型配置 # 模型配置
default_model_config_id: Optional[uuid.UUID] = None default_model_config_id: Optional[uuid.UUID] = None
model_parameters: ModelParameters = Field(default_factory=ModelParameters) model_parameters: ModelParameters = Field(default_factory=ModelParameters)
# 知识库检索 # 知识库检索
knowledge_retrieval: Optional[KnowledgeRetrievalConfig] = None knowledge_retrieval: Optional[KnowledgeRetrievalConfig] = None
# 记忆配置 # 记忆配置
memory: MemoryConfig = Field(default_factory=lambda: MemoryConfig(enabled=True)) memory: MemoryConfig = Field(default_factory=lambda: MemoryConfig(enabled=True))
# 变量配置 # 变量配置
variables: List[VariableDefinition] = [] variables: List[VariableDefinition] = []
# 工具配置 # 工具配置
tools: Dict[str, ToolConfig] = {} tools: Dict[str, ToolConfig] = {}
is_active: bool is_active: bool
created_at: datetime.datetime created_at: datetime.datetime
updated_at: datetime.datetime updated_at: datetime.datetime
@@ -228,7 +228,7 @@ class AgentConfig(BaseModel):
if v is None: if v is None:
return ModelParameters() return ModelParameters()
return v return v
@field_validator("memory", mode="before") @field_validator("memory", mode="before")
@classmethod @classmethod
def validate_memory(cls, v): def validate_memory(cls, v):
@@ -236,7 +236,7 @@ class AgentConfig(BaseModel):
if v is None: if v is None:
return MemoryConfig(enabled=True) return MemoryConfig(enabled=True)
return v return v
@field_validator("variables", mode="before") @field_validator("variables", mode="before")
@classmethod @classmethod
def validate_variables(cls, v): def validate_variables(cls, v):
@@ -244,7 +244,7 @@ class AgentConfig(BaseModel):
if v is None: if v is None:
return [] return []
return v return v
@field_validator("tools", mode="before") @field_validator("tools", mode="before")
@classmethod @classmethod
def validate_tools(cls, v): def validate_tools(cls, v):
@@ -256,7 +256,7 @@ class AgentConfig(BaseModel):
@field_serializer("created_at", when_used="json") @field_serializer("created_at", when_used="json")
def _serialize_created_at(self, dt: datetime.datetime): def _serialize_created_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None return int(dt.timestamp() * 1000) if dt else None
@field_serializer("updated_at", when_used="json") @field_serializer("updated_at", when_used="json")
def _serialize_updated_at(self, dt: datetime.datetime): def _serialize_updated_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None return int(dt.timestamp() * 1000) if dt else None
@@ -294,15 +294,15 @@ class AppRelease(BaseModel):
@field_serializer("created_at", when_used="json") @field_serializer("created_at", when_used="json")
def _serialize_created_at(self, dt: datetime.datetime): def _serialize_created_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None return int(dt.timestamp() * 1000) if dt else None
@field_serializer("updated_at", when_used="json") @field_serializer("updated_at", when_used="json")
def _serialize_updated_at(self, dt: datetime.datetime): def _serialize_updated_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None return int(dt.timestamp() * 1000) if dt else None
@field_serializer("published_at", when_used="json") @field_serializer("published_at", when_used="json")
def _serialize_published_at(self, dt: datetime.datetime): def _serialize_published_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None return int(dt.timestamp() * 1000) if dt else None
# ---------- App Share Schemas ---------- # ---------- App Share Schemas ----------
@@ -314,7 +314,7 @@ class AppShareCreate(BaseModel):
class AppShare(BaseModel): class AppShare(BaseModel):
"""应用分享输出""" """应用分享输出"""
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
id: uuid.UUID id: uuid.UUID
source_app_id: uuid.UUID source_app_id: uuid.UUID
source_workspace_id: uuid.UUID source_workspace_id: uuid.UUID
@@ -322,11 +322,11 @@ class AppShare(BaseModel):
shared_by: uuid.UUID shared_by: uuid.UUID
created_at: datetime.datetime created_at: datetime.datetime
updated_at: datetime.datetime updated_at: datetime.datetime
@field_serializer("created_at", when_used="json") @field_serializer("created_at", when_used="json")
def _serialize_created_at(self, dt: datetime.datetime): def _serialize_created_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None return int(dt.timestamp() * 1000) if dt else None
@field_serializer("updated_at", when_used="json") @field_serializer("updated_at", when_used="json")
def _serialize_updated_at(self, dt: datetime.datetime): def _serialize_updated_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None return int(dt.timestamp() * 1000) if dt else None
@@ -382,14 +382,14 @@ class DraftRunCompareRequest(BaseModel):
conversation_id: Optional[str] = Field(None, description="会话ID") conversation_id: Optional[str] = Field(None, description="会话ID")
user_id: Optional[str] = Field(None, description="用户ID") user_id: Optional[str] = Field(None, description="用户ID")
variables: Optional[Dict[str, Any]] = Field(None, description="变量参数") variables: Optional[Dict[str, Any]] = Field(None, description="变量参数")
models: List[ModelCompareItem] = Field( models: List[ModelCompareItem] = Field(
..., ...,
min_length=1, min_length=1,
max_length=5, max_length=5,
description="要对比的模型列表1-5个" description="要对比的模型列表1-5个"
) )
parallel: bool = Field(True, description="是否并行执行") parallel: bool = Field(True, description="是否并行执行")
stream: bool = Field(False, description="是否流式返回") stream: bool = Field(False, description="是否流式返回")
timeout: Optional[int] = Field(60, ge=10, le=300, description="超时时间(秒)") timeout: Optional[int] = Field(60, ge=10, le=300, description="超时时间(秒)")
@@ -400,14 +400,14 @@ class ModelRunResult(BaseModel):
model_config_id: uuid.UUID model_config_id: uuid.UUID
model_name: str model_name: str
label: Optional[str] = None label: Optional[str] = None
parameters_used: Dict[str, Any] = Field(..., description="实际使用的参数") parameters_used: Dict[str, Any] = Field(..., description="实际使用的参数")
message: Optional[str] = None message: Optional[str] = None
usage: Optional[Dict[str, Any]] = None usage: Optional[Dict[str, Any]] = None
elapsed_time: float elapsed_time: float
error: Optional[str] = None error: Optional[str] = None
tokens_per_second: Optional[float] = None tokens_per_second: Optional[float] = None
cost_estimate: Optional[float] = None cost_estimate: Optional[float] = None
conversation_id: Optional[str] = None conversation_id: Optional[str] = None
@@ -416,10 +416,10 @@ class ModelRunResult(BaseModel):
class DraftRunCompareResponse(BaseModel): class DraftRunCompareResponse(BaseModel):
"""多模型对比响应""" """多模型对比响应"""
results: List[ModelRunResult] results: List[ModelRunResult]
total_elapsed_time: float total_elapsed_time: float
successful_count: int successful_count: int
failed_count: int failed_count: int
fastest_model: Optional[str] = None fastest_model: Optional[str] = None
cheapest_model: Optional[str] = None cheapest_model: Optional[str] = None

View File

@@ -169,7 +169,7 @@ class PromptOptimizerService:
provider=api_config.provider, provider=api_config.provider,
api_key=api_config.api_key, api_key=api_config.api_key,
base_url=api_config.api_base base_url=api_config.api_base
), type=ModelType.from_str(model_config.type)) ), type=ModelType(model_config.type))
# build message # build message
messages = [ messages = [

View File

@@ -39,14 +39,14 @@ class WorkflowService:
# ==================== 配置管理 ==================== # ==================== 配置管理 ====================
def create_workflow_config( def create_workflow_config(
self, self,
app_id: uuid.UUID, app_id: uuid.UUID,
nodes: list[dict[str, Any]], nodes: list[dict[str, Any]],
edges: list[dict[str, Any]], edges: list[dict[str, Any]],
variables: list[dict[str, Any]] | None = None, variables: list[dict[str, Any]] | None = None,
execution_config: dict[str, Any] | None = None, execution_config: dict[str, Any] | None = None,
triggers: list[dict[str, Any]] | None = None, triggers: list[dict[str, Any]] | None = None,
validate: bool = True validate: bool = True
) -> WorkflowConfig: ) -> WorkflowConfig:
"""创建工作流配置 """创建工作流配置
@@ -109,14 +109,14 @@ class WorkflowService:
return self.config_repo.get_by_app_id(app_id) return self.config_repo.get_by_app_id(app_id)
def update_workflow_config( def update_workflow_config(
self, self,
app_id: uuid.UUID, app_id: uuid.UUID,
nodes: list[dict[str, Any]] | None = None, nodes: list[dict[str, Any]] | None = None,
edges: list[dict[str, Any]] | None = None, edges: list[dict[str, Any]] | None = None,
variables: list[dict[str, Any]] | None = None, variables: list[dict[str, Any]] | None = None,
execution_config: dict[str, Any] | None = None, execution_config: dict[str, Any] | None = None,
triggers: list[dict[str, Any]] | None = None, triggers: list[dict[str, Any]] | None = None,
validate: bool = True validate: bool = True
) -> WorkflowConfig: ) -> WorkflowConfig:
"""更新工作流配置 """更新工作流配置
@@ -226,8 +226,8 @@ class WorkflowService:
return config return config
def validate_workflow_config_for_publish( def validate_workflow_config_for_publish(
self, self,
app_id: uuid.UUID app_id: uuid.UUID
) -> tuple[bool, list[str]]: ) -> tuple[bool, list[str]]:
"""验证工作流配置是否可以发布 """验证工作流配置是否可以发布
@@ -260,13 +260,13 @@ class WorkflowService:
# ==================== 执行管理 ==================== # ==================== 执行管理 ====================
def create_execution( def create_execution(
self, self,
workflow_config_id: uuid.UUID, workflow_config_id: uuid.UUID,
app_id: uuid.UUID, app_id: uuid.UUID,
trigger_type: str, trigger_type: str,
triggered_by: uuid.UUID | None = None, triggered_by: uuid.UUID | None = None,
conversation_id: uuid.UUID | None = None, conversation_id: uuid.UUID | None = None,
input_data: dict[str, Any] | None = None input_data: dict[str, Any] | None = None
) -> WorkflowExecution: ) -> WorkflowExecution:
"""创建工作流执行记录 """创建工作流执行记录
@@ -314,10 +314,10 @@ class WorkflowService:
return self.execution_repo.get_by_execution_id(execution_id) return self.execution_repo.get_by_execution_id(execution_id)
def get_executions_by_app( def get_executions_by_app(
self, self,
app_id: uuid.UUID, app_id: uuid.UUID,
limit: int = 50, limit: int = 50,
offset: int = 0 offset: int = 0
) -> list[WorkflowExecution]: ) -> list[WorkflowExecution]:
"""获取应用的执行记录列表 """获取应用的执行记录列表
@@ -332,12 +332,12 @@ class WorkflowService:
return self.execution_repo.get_by_app_id(app_id, limit, offset) return self.execution_repo.get_by_app_id(app_id, limit, offset)
def update_execution_status( def update_execution_status(
self, self,
execution_id: str, execution_id: str,
status: str, status: str,
output_data: dict[str, Any] | None = None, output_data: dict[str, Any] | None = None,
error_message: str | None = None, error_message: str | None = None,
error_node_id: str | None = None error_node_id: str | None = None
) -> WorkflowExecution: ) -> WorkflowExecution:
"""更新执行状态 """更新执行状态
@@ -407,10 +407,10 @@ class WorkflowService:
# ==================== 工作流执行 ==================== # ==================== 工作流执行 ====================
async def run( async def run(
self, self,
app_id: uuid.UUID, app_id: uuid.UUID,
payload: DraftRunRequest, payload: DraftRunRequest,
config: WorkflowConfig config: WorkflowConfig
): ):
"""运行工作流 """运行工作流
@@ -527,10 +527,10 @@ class WorkflowService:
) )
async def run_stream( async def run_stream(
self, self,
app_id: uuid.UUID, app_id: uuid.UUID,
payload: DraftRunRequest, payload: DraftRunRequest,
config: WorkflowConfig config: WorkflowConfig
): ):
"""运行工作流(流式) """运行工作流(流式)
@@ -600,11 +600,11 @@ class WorkflowService:
# 调用流式执行executor 会发送 workflow_start 和 workflow_end 事件) # 调用流式执行executor 会发送 workflow_start 和 workflow_end 事件)
async for event in self._run_workflow_stream( async for event in self._run_workflow_stream(
workflow_config=workflow_config_dict, workflow_config=workflow_config_dict,
input_data=input_data, input_data=input_data,
execution_id=execution.execution_id, execution_id=execution.execution_id,
workspace_id="", workspace_id="",
user_id=payload.user_id user_id=payload.user_id
): ):
# 直接转发 executor 的事件(已经是正确的格式) # 直接转发 executor 的事件(已经是正确的格式)
yield event yield event
@@ -626,12 +626,12 @@ class WorkflowService:
} }
async def run_workflow( async def run_workflow(
self, self,
app_id: uuid.UUID, app_id: uuid.UUID,
input_data: dict[str, Any], input_data: dict[str, Any],
triggered_by: uuid.UUID, triggered_by: uuid.UUID,
conversation_id: uuid.UUID | None = None, conversation_id: uuid.UUID | None = None,
stream: bool = False stream: bool = False
) -> AsyncGenerator | dict: ) -> AsyncGenerator | dict:
"""运行工作流 """运行工作流
@@ -778,12 +778,12 @@ class WorkflowService:
return clean_value(event) return clean_value(event)
async def _run_workflow_stream( async def _run_workflow_stream(
self, self,
workflow_config: dict[str, Any], workflow_config: dict[str, Any],
input_data: dict[str, Any], input_data: dict[str, Any],
execution_id: str, execution_id: str,
workspace_id: str, workspace_id: str,
user_id: str): user_id: str):
"""运行工作流(流式,内部方法) """运行工作流(流式,内部方法)
Args: Args:
@@ -800,11 +800,11 @@ class WorkflowService:
try: try:
async for event in execute_workflow_stream( async for event in execute_workflow_stream(
workflow_config=workflow_config, workflow_config=workflow_config,
input_data=input_data, input_data=input_data,
execution_id=execution_id, execution_id=execution_id,
workspace_id=workspace_id, workspace_id=workspace_id,
user_id=user_id user_id=user_id
): ):
# 直接转发事件executor 已经返回正确格式) # 直接转发事件executor 已经返回正确格式)
yield event yield event
@@ -828,7 +828,7 @@ class WorkflowService:
# ==================== 依赖注入函数 ==================== # ==================== 依赖注入函数 ====================
def get_workflow_service( def get_workflow_service(
db: Annotated[Session, Depends(get_db)] db: Annotated[Session, Depends(get_db)]
) -> WorkflowService: ) -> WorkflowService:
"""获取工作流服务(依赖注入)""" """获取工作流服务(依赖注入)"""
return WorkflowService(db) return WorkflowService(db)