[add] workflow support stream mode

This commit is contained in:
Mark
2025-12-18 19:46:36 +08:00
committed by 谢俊男
parent 9e48f2143e
commit 3aff6baccb
6 changed files with 282 additions and 144 deletions

View File

@@ -421,8 +421,8 @@ async def draft_run(
# 流式返回 # 流式返回
if payload.stream: if payload.stream:
async def event_generator(): async def event_generator():
async for event in draft_service.run_stream( async for event in draft_service.run_stream(
agent_config=agent_cfg, agent_config=agent_cfg,
model_config=model_config, model_config=model_config,
@@ -574,7 +574,7 @@ async def draft_run(
# 3. 流式返回 # 3. 流式返回
if payload.stream: if payload.stream:
logger.debug( logger.debug(
"开始多智能体流式试运行", "开始工作流流式试运行",
extra={ extra={
"app_id": str(app_id), "app_id": str(app_id),
"message_length": len(payload.message), "message_length": len(payload.message),
@@ -583,16 +583,13 @@ async def draft_run(
) )
async def event_generator(): async def event_generator():
"""多智能体流式事件生成器""" """工作流事件生成器"""
multiservice = MultiAgentService(db)
# 调用多智能体服务的流式方法 # 调用多智能体服务的流式方法
async for event in multiservice.run_stream( async for event in workflow_service.run_stream(
app_id=app_id, app_id=app_id,
request=multi_agent_request, payload=payload,
storage_type=storage_type, config=config
user_rag_memory_id=user_rag_memory_id
): ):
yield event yield event
@@ -617,7 +614,7 @@ async def draft_run(
) )
result = await workflow_service.run(app_id, payload,config) result = await workflow_service.run(app_id, payload,config)
logger.debug( logger.debug(
"工作流试运行返回结果", "工作流试运行返回结果",
extra={ extra={

View File

@@ -11,26 +11,24 @@ from typing import Any
from langchain_core.messages import HumanMessage from langchain_core.messages import HumanMessage
from langgraph.graph import StateGraph, START, END from langgraph.graph import StateGraph, START, END
from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.nodes import WorkflowState, NodeFactory from app.core.workflow.nodes import WorkflowState, NodeFactory
from app.core.workflow.expression_evaluator import evaluate_condition from app.core.workflow.expression_evaluator import evaluate_condition
from app.models.workflow_model import WorkflowExecution, WorkflowNodeExecution
from app.core.tools.registry import ToolRegistry from app.core.tools.registry import ToolRegistry
from app.core.tools.executor import ToolExecutor from app.core.tools.executor import ToolExecutor
from app.core.tools.langchain_adapter import LangchainAdapter from app.core.tools.langchain_adapter import LangchainAdapter
TOOL_MANAGEMENT_AVAILABLE = True TOOL_MANAGEMENT_AVAILABLE = True
from app.db import get_db
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class WorkflowExecutor: class WorkflowExecutor:
"""工作流执行器 """工作流执行器
负责将工作流配置转换为 LangGraph 并执行。 负责将工作流配置转换为 LangGraph 并执行。
""" """
def __init__( def __init__(
self, self,
workflow_config: dict[str, Any], workflow_config: dict[str, Any],
@@ -39,7 +37,7 @@ class WorkflowExecutor:
user_id: str user_id: str
): ):
"""初始化执行器 """初始化执行器
Args: Args:
workflow_config: 工作流配置 workflow_config: 工作流配置
execution_id: 执行 ID execution_id: 执行 ID
@@ -53,25 +51,25 @@ class WorkflowExecutor:
self.nodes = workflow_config.get("nodes", []) self.nodes = workflow_config.get("nodes", [])
self.edges = workflow_config.get("edges", []) self.edges = workflow_config.get("edges", [])
self.execution_config = workflow_config.get("execution_config", {}) self.execution_config = workflow_config.get("execution_config", {})
def _prepare_initial_state(self, input_data: dict[str, Any]) -> WorkflowState: def _prepare_initial_state(self, input_data: dict[str, Any]) -> WorkflowState:
"""准备初始状态(注入系统变量和会话变量) """准备初始状态(注入系统变量和会话变量)
变量命名空间: 变量命名空间:
- sys.xxx - 系统变量execution_id, workspace_id, user_id, message, input_variables 等) - sys.xxx - 系统变量execution_id, workspace_id, user_id, message, input_variables 等)
- conv.xxx - 会话变量(跨多轮对话保持) - conv.xxx - 会话变量(跨多轮对话保持)
- node_id.xxx - 节点输出(执行时动态生成) - node_id.xxx - 节点输出(执行时动态生成)
Args: Args:
input_data: 输入数据 input_data: 输入数据
Returns: Returns:
初始化的工作流状态 初始化的工作流状态
""" """
user_message = input_data.get("message") or "" user_message = input_data.get("message") or ""
conversation_vars = input_data.get("conversation_vars") or {} conversation_vars = input_data.get("conversation_vars") or {}
input_variables = input_data.get("variables") or {} # Start 节点的自定义变量 input_variables = input_data.get("variables") or {} # Start 节点的自定义变量
# 构建分层的变量结构 # 构建分层的变量结构
variables = { variables = {
"sys": { "sys": {
@@ -84,7 +82,7 @@ class WorkflowExecutor:
}, },
"conv": conversation_vars # 会话级变量(跨多轮对话保持) "conv": conversation_vars # 会话级变量(跨多轮对话保持)
} }
return { return {
"messages": [HumanMessage(content=user_message)], "messages": [HumanMessage(content=user_message)],
"variables": variables, "variables": variables,
@@ -96,34 +94,34 @@ class WorkflowExecutor:
"error": None, "error": None,
"error_node": None "error_node": None
} }
def build_graph(self) -> StateGraph:
def build_graph(self) -> CompiledStateGraph:
"""构建 LangGraph """构建 LangGraph
Returns: Returns:
编译后的状态图 编译后的状态图
""" """
logger.info(f"开始构建工作流图: execution_id={self.execution_id}") logger.info(f"开始构建工作流图: execution_id={self.execution_id}")
# 1. 创建状态图 # 1. 创建状态图
workflow = StateGraph(WorkflowState) workflow = StateGraph(WorkflowState)
# 2. 添加所有节点(包括 start 和 end # 2. 添加所有节点(包括 start 和 end
start_node_id = None start_node_id = None
end_node_ids = [] end_node_ids = []
for node in self.nodes: for node in self.nodes:
node_type = node.get("type") node_type = node.get("type")
node_id = node.get("id") node_id = node.get("id")
# 记录 start 和 end 节点 ID # 记录 start 和 end 节点 ID
if node_type == "start": if node_type == "start":
start_node_id = node_id start_node_id = node_id
elif node_type == "end": elif node_type == "end":
end_node_ids.append(node_id) end_node_ids.append(node_id)
# 创建节点实例(现在 start 和 end 也会被创建) # 创建节点实例(现在 start 和 end 也会被创建)
node_instance = NodeFactory.create_node(node, self.workflow_config) node_instance = NodeFactory.create_node(node, self.workflow_config)
if node_instance: if node_instance:
@@ -133,40 +131,40 @@ class WorkflowExecutor:
async def node_func(state: WorkflowState): async def node_func(state: WorkflowState):
return await inst.run(state) return await inst.run(state)
return node_func return node_func
workflow.add_node(node_id, make_node_func(node_instance)) workflow.add_node(node_id, make_node_func(node_instance))
logger.debug(f"添加节点: {node_id} (type={node_type})") logger.debug(f"添加节点: {node_id} (type={node_type})")
# 3. 添加边 # 3. 添加边
# 从 START 连接到 start 节点 # 从 START 连接到 start 节点
if start_node_id: if start_node_id:
workflow.add_edge(START, start_node_id) workflow.add_edge(START, start_node_id)
logger.debug(f"添加边: START -> {start_node_id}") logger.debug(f"添加边: START -> {start_node_id}")
for edge in self.edges: for edge in self.edges:
source = edge.get("source") source = edge.get("source")
target = edge.get("target") target = edge.get("target")
edge_type = edge.get("type") edge_type = edge.get("type")
condition = edge.get("condition") condition = edge.get("condition")
# 跳过从 start 节点出发的边(因为已经从 START 连接到 start # 跳过从 start 节点出发的边(因为已经从 START 连接到 start
if source == start_node_id: if source == start_node_id:
# 但要连接 start 到下一个节点 # 但要连接 start 到下一个节点
workflow.add_edge(source, target) workflow.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}") logger.debug(f"添加边: {source} -> {target}")
continue continue
# 处理到 end 节点的边 # 处理到 end 节点的边
if target in end_node_ids: if target in end_node_ids:
# 连接到 end 节点 # 连接到 end 节点
workflow.add_edge(source, target) workflow.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}") logger.debug(f"添加边: {source} -> {target}")
continue continue
# 跳过错误边(在节点内部处理) # 跳过错误边(在节点内部处理)
if edge_type == "error": if edge_type == "error":
continue continue
if condition: if condition:
# 条件边 # 条件边
def router(state: WorkflowState, cond=condition, tgt=target): def router(state: WorkflowState, cond=condition, tgt=target):
@@ -183,74 +181,74 @@ class WorkflowExecutor:
): ):
return tgt return tgt
return END # 条件不满足,结束 return END # 条件不满足,结束
workflow.add_conditional_edges(source, router) workflow.add_conditional_edges(source, router)
logger.debug(f"添加条件边: {source} -> {target} (condition={condition})") logger.debug(f"添加条件边: {source} -> {target} (condition={condition})")
else: else:
# 普通边 # 普通边
workflow.add_edge(source, target) workflow.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}") logger.debug(f"添加边: {source} -> {target}")
# 从 end 节点连接到 END # 从 end 节点连接到 END
for end_node_id in end_node_ids: for end_node_id in end_node_ids:
workflow.add_edge(end_node_id, END) workflow.add_edge(end_node_id, END)
logger.debug(f"添加边: {end_node_id} -> END") logger.debug(f"添加边: {end_node_id} -> END")
# 4. 编译图 # 4. 编译图
graph = workflow.compile() graph = workflow.compile()
logger.info(f"工作流图构建完成: execution_id={self.execution_id}") logger.info(f"工作流图构建完成: execution_id={self.execution_id}")
return graph return graph
async def execute( async def execute(
self, self,
input_data: dict[str, Any] input_data: dict[str, Any]
) -> dict[str, Any]: ) -> dict[str, Any]:
"""执行工作流(非流式) """执行工作流(非流式)
Args: Args:
input_data: 输入数据,包含 message 和 variables input_data: 输入数据,包含 message 和 variables
Returns: Returns:
执行结果,包含 status, output, node_outputs, elapsed_time, token_usage 执行结果,包含 status, output, node_outputs, elapsed_time, token_usage
""" """
logger.info(f"开始执行工作流: execution_id={self.execution_id}") logger.info(f"开始执行工作流: execution_id={self.execution_id}")
# 记录开始时间 # 记录开始时间
start_time = datetime.datetime.now() start_time = datetime.datetime.now()
# 1. 构建图 # 1. 构建图
graph = self.build_graph() graph = self.build_graph()
# 2. 初始化状态(自动注入系统变量) # 2. 初始化状态(自动注入系统变量)
initial_state = self._prepare_initial_state(input_data) initial_state = self._prepare_initial_state(input_data)
# 3. 执行工作流 # 3. 执行工作流
try: try:
result = await graph.ainvoke(initial_state) result = await graph.ainvoke(initial_state)
# 计算耗时 # 计算耗时
end_time = datetime.datetime.now() end_time = datetime.datetime.now()
elapsed_time = (end_time - start_time).total_seconds() elapsed_time = (end_time - start_time).total_seconds()
# 提取节点输出(现在包含 start 和 end 节点) # 提取节点输出(现在包含 start 和 end 节点)
node_outputs = result.get("node_outputs", {}) node_outputs = result.get("node_outputs", {})
# 提取最终输出(从最后一个非 start/end 节点) # 提取最终输出(从最后一个非 start/end 节点)
final_output = self._extract_final_output(node_outputs) final_output = self._extract_final_output(node_outputs)
# 聚合 token 使用情况 # 聚合 token 使用情况
token_usage = self._aggregate_token_usage(node_outputs) token_usage = self._aggregate_token_usage(node_outputs)
# 提取 conversation_id从 start 节点输出) # 提取 conversation_id从 start 节点输出)
conversation_id = None conversation_id = None
for node_id, node_output in node_outputs.items(): for node_id, node_output in node_outputs.items():
if node_output.get("node_type") == "start": if node_output.get("node_type") == "start":
conversation_id = node_output.get("output", {}).get("conversation_id") conversation_id = node_output.get("output", {}).get("conversation_id")
break break
logger.info(f"工作流执行完成: execution_id={self.execution_id}, elapsed_time={elapsed_time:.2f}s") logger.info(f"工作流执行完成: execution_id={self.execution_id}, elapsed_time={elapsed_time:.2f}s")
return { return {
"status": "completed", "status": "completed",
"output": final_output, "output": final_output,
@@ -261,12 +259,12 @@ class WorkflowExecutor:
"token_usage": token_usage, "token_usage": token_usage,
"error": result.get("error") "error": result.get("error")
} }
except Exception as e: except Exception as e:
# 计算耗时(即使失败也记录) # 计算耗时(即使失败也记录)
end_time = datetime.datetime.now() end_time = datetime.datetime.now()
elapsed_time = (end_time - start_time).total_seconds() elapsed_time = (end_time - start_time).total_seconds()
logger.error(f"工作流执行失败: execution_id={self.execution_id}, error={e}", exc_info=True) logger.error(f"工作流执行失败: execution_id={self.execution_id}, error={e}", exc_info=True)
return { return {
"status": "failed", "status": "failed",
@@ -276,86 +274,94 @@ class WorkflowExecutor:
"elapsed_time": elapsed_time, "elapsed_time": elapsed_time,
"token_usage": None "token_usage": None
} }
async def execute_stream( async def execute_stream(
self, self,
input_data: dict[str, Any] input_data: dict[str, Any]
): ):
"""执行工作流(流式) """执行工作流(流式)
手动执行节点以支持细粒度的流式输出:
- workflow_start: 工作流开始
- node_start: 节点开始执行
- node_chunk: LLM 节点的流式输出片段(逐 token
- node_complete: 节点执行完成
- workflow_complete: 工作流完成
Args: Args:
input_data: 输入数据 input_data: 输入数据
Yields: Yields:
流式事件 流式事件
""" """
logger.info(f"开始执行工作流(流式): execution_id={self.execution_id}") #
logger.info(f"开始执行工作流: execution_id={self.execution_id}")
# 记录开始时间
start_time = datetime.datetime.now()
# 1. 构建图 # 1. 构建图
graph = self.build_graph() graph = self.build_graph()
# 2. 初始化状态(自动注入系统变量) # 2. 初始化状态(自动注入系统变量)
initial_state = self._prepare_initial_state(input_data) initial_state = self._prepare_initial_state(input_data)
# 3. 流式执行工作流 # 3. 执行工作流
try: try:
# 使用 astream 获取节点级别的更新 async for chunk in graph.astream(
async for event in graph.astream(initial_state, stream_mode="updates"): initial_state,
for node_name, state_update in event.items(): # subgraphs=True,
yield { stream_mode="updates",
"type": "node_complete", ):
"node": node_name, # print(chunk)
"data": state_update, yield chunk
"execution_id": self.execution_id
}
logger.info(f"工作流执行完成(流式): execution_id={self.execution_id}")
# 发送完成事件
yield {
"type": "workflow_complete",
"execution_id": self.execution_id
}
except Exception as e: except Exception as e:
logger.error(f"工作流执行失败(流式): execution_id={self.execution_id}, error={e}", exc_info=True) # 计算耗时(即使失败也记录)
end_time = datetime.datetime.now()
elapsed_time = (end_time - start_time).total_seconds()
logger.error(f"工作流执行失败: execution_id={self.execution_id}, error={e}", exc_info=True)
yield { yield {
"type": "workflow_error", "status": "failed",
"execution_id": self.execution_id, "error": str(e),
"error": str(e) "output": None,
"node_outputs": {},
"elapsed_time": elapsed_time,
"token_usage": None
} }
def _extract_final_output(self, node_outputs: dict[str, Any]) -> str | None: def _extract_final_output(self, node_outputs: dict[str, Any]) -> str | None:
"""从节点输出中提取最终输出 """从节点输出中提取最终输出
优先级: 优先级:
1. 最后一个执行的非 start/end 节点的 output 1. 最后一个执行的非 start/end 节点的 output
2. 如果没有节点输出,返回 None 2. 如果没有节点输出,返回 None
Args: Args:
node_outputs: 所有节点的输出 node_outputs: 所有节点的输出
Returns: Returns:
最终输出字符串或 None 最终输出字符串或 None
""" """
if not node_outputs: if not node_outputs:
return None return None
# 获取最后一个节点的输出 # 获取最后一个节点的输出
last_node_output = list(node_outputs.values())[-1] if node_outputs else None last_node_output = list(node_outputs.values())[-1] if node_outputs else None
if last_node_output and isinstance(last_node_output, dict): if last_node_output and isinstance(last_node_output, dict):
return last_node_output.get("output") return last_node_output.get("output")
return None return None
def _aggregate_token_usage(self, node_outputs: dict[str, Any]) -> dict[str, int] | None: def _aggregate_token_usage(self, node_outputs: dict[str, Any]) -> dict[str, int] | None:
"""聚合所有节点的 token 使用情况 """聚合所有节点的 token 使用情况
Args: Args:
node_outputs: 所有节点的输出 node_outputs: 所有节点的输出
Returns: Returns:
聚合的 token 使用情况 {"prompt_tokens": x, "completion_tokens": y, "total_tokens": z} 聚合的 token 使用情况 {"prompt_tokens": x, "completion_tokens": y, "total_tokens": z}
如果没有 token 使用信息,返回 None 如果没有 token 使用信息,返回 None
@@ -364,7 +370,7 @@ class WorkflowExecutor:
total_completion_tokens = 0 total_completion_tokens = 0
total_tokens = 0 total_tokens = 0
has_token_info = False has_token_info = False
for node_output in node_outputs.values(): for node_output in node_outputs.values():
if isinstance(node_output, dict): if isinstance(node_output, dict):
token_usage = node_output.get("token_usage") token_usage = node_output.get("token_usage")
@@ -373,16 +379,16 @@ class WorkflowExecutor:
total_prompt_tokens += token_usage.get("prompt_tokens", 0) total_prompt_tokens += token_usage.get("prompt_tokens", 0)
total_completion_tokens += token_usage.get("completion_tokens", 0) total_completion_tokens += token_usage.get("completion_tokens", 0)
total_tokens += token_usage.get("total_tokens", 0) total_tokens += token_usage.get("total_tokens", 0)
if not has_token_info: if not has_token_info:
return None return None
return { return {
"prompt_tokens": total_prompt_tokens, "prompt_tokens": total_prompt_tokens,
"completion_tokens": total_completion_tokens, "completion_tokens": total_completion_tokens,
"total_tokens": total_tokens "total_tokens": total_tokens
} }
async def execute_workflow( async def execute_workflow(
workflow_config: dict[str, Any], workflow_config: dict[str, Any],
@@ -392,14 +398,14 @@ async def execute_workflow(
user_id: str user_id: str
) -> dict[str, Any]: ) -> dict[str, Any]:
"""执行工作流(便捷函数) """执行工作流(便捷函数)
Args: Args:
workflow_config: 工作流配置 workflow_config: 工作流配置
input_data: 输入数据 input_data: 输入数据
execution_id: 执行 ID execution_id: 执行 ID
workspace_id: 工作空间 ID workspace_id: 工作空间 ID
user_id: 用户 ID user_id: 用户 ID
Returns: Returns:
执行结果 执行结果
""" """
@@ -420,14 +426,14 @@ async def execute_workflow_stream(
user_id: str user_id: str
): ):
"""执行工作流(流式,便捷函数) """执行工作流(流式,便捷函数)
Args: Args:
workflow_config: 工作流配置 workflow_config: 工作流配置
input_data: 输入数据 input_data: 输入数据
execution_id: 执行 ID execution_id: 执行 ID
workspace_id: 工作空间 ID workspace_id: 工作空间 ID
user_id: 用户 ID user_id: 用户 ID
Yields: Yields:
流式事件 流式事件
""" """
@@ -445,25 +451,25 @@ async def execute_workflow_stream(
def get_workflow_tools(workspace_id: str, user_id: str) -> list: def get_workflow_tools(workspace_id: str, user_id: str) -> list:
"""获取工作流可用的工具列表 """获取工作流可用的工具列表
Args: Args:
workspace_id: 工作空间ID workspace_id: 工作空间ID
user_id: 用户ID user_id: 用户ID
Returns: Returns:
可用工具列表 可用工具列表
""" """
if not TOOL_MANAGEMENT_AVAILABLE: if not TOOL_MANAGEMENT_AVAILABLE:
logger.warning("工具管理系统不可用") logger.warning("工具管理系统不可用")
return [] return []
try: try:
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
db = next(get_db()) db = next(get_db())
# 创建工具注册表 # 创建工具注册表
registry = ToolRegistry(db) registry = ToolRegistry(db)
# 注册内置工具类 # 注册内置工具类
from app.core.tools.builtin import ( from app.core.tools.builtin import (
DateTimeTool, JsonTool, BaiduSearchTool, MinerUTool, TextInTool DateTimeTool, JsonTool, BaiduSearchTool, MinerUTool, TextInTool
@@ -473,12 +479,12 @@ def get_workflow_tools(workspace_id: str, user_id: str) -> list:
registry.register_tool_class(BaiduSearchTool) registry.register_tool_class(BaiduSearchTool)
registry.register_tool_class(MinerUTool) registry.register_tool_class(MinerUTool)
registry.register_tool_class(TextInTool) registry.register_tool_class(TextInTool)
# 获取活跃的工具 # 获取活跃的工具
import uuid import uuid
tools = registry.list_tools(workspace_id=uuid.UUID(workspace_id)) tools = registry.list_tools(workspace_id=uuid.UUID(workspace_id))
active_tools = [tool for tool in tools if tool.status.value == "active"] active_tools = [tool for tool in tools if tool.status.value == "active"]
# 转换为Langchain工具 # 转换为Langchain工具
langchain_tools = [] langchain_tools = []
for tool_info in active_tools: for tool_info in active_tools:
@@ -489,10 +495,10 @@ def get_workflow_tools(workspace_id: str, user_id: str) -> list:
langchain_tools.append(langchain_tool) langchain_tools.append(langchain_tool)
except Exception as e: except Exception as e:
logger.error(f"转换工具失败: {tool_info.name}, 错误: {e}") logger.error(f"转换工具失败: {tool_info.name}, 错误: {e}")
logger.info(f"为工作流获取了 {len(langchain_tools)} 个工具") logger.info(f"为工作流获取了 {len(langchain_tools)} 个工具")
return langchain_tools return langchain_tools
except Exception as e: except Exception as e:
logger.error(f"获取工作流工具失败: {e}") logger.error(f"获取工作流工具失败: {e}")
return [] return []
@@ -500,10 +506,10 @@ def get_workflow_tools(workspace_id: str, user_id: str) -> list:
class ToolWorkflowNode: class ToolWorkflowNode:
"""工具工作流节点 - 在工作流中执行工具""" """工具工作流节点 - 在工作流中执行工具"""
def __init__(self, node_config: dict, workflow_config: dict): def __init__(self, node_config: dict, workflow_config: dict):
"""初始化工具节点 """初始化工具节点
Args: Args:
node_config: 节点配置 node_config: 节点配置
workflow_config: 工作流配置 workflow_config: 工作流配置
@@ -512,25 +518,25 @@ class ToolWorkflowNode:
self.workflow_config = workflow_config self.workflow_config = workflow_config
self.tool_id = node_config.get("tool_id") self.tool_id = node_config.get("tool_id")
self.tool_parameters = node_config.get("parameters", {}) self.tool_parameters = node_config.get("parameters", {})
async def run(self, state: WorkflowState) -> WorkflowState: async def run(self, state: WorkflowState) -> WorkflowState:
"""执行工具节点""" """执行工具节点"""
if not TOOL_MANAGEMENT_AVAILABLE: if not TOOL_MANAGEMENT_AVAILABLE:
logger.error("工具管理系统不可用") logger.error("工具管理系统不可用")
state["error"] = "工具管理系统不可用" state["error"] = "工具管理系统不可用"
return state return state
try: try:
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
db = next(get_db()) db = next(get_db())
# 创建工具执行器 # 创建工具执行器
registry = ToolRegistry(db) registry = ToolRegistry(db)
executor = ToolExecutor(db, registry) executor = ToolExecutor(db, registry)
# 准备参数(支持变量替换) # 准备参数(支持变量替换)
parameters = self._prepare_parameters(state) parameters = self._prepare_parameters(state)
# 执行工具 # 执行工具
result = await executor.execute_tool( result = await executor.execute_tool(
tool_id=self.tool_id, tool_id=self.tool_id,
@@ -538,7 +544,7 @@ class ToolWorkflowNode:
user_id=uuid.UUID(state["user_id"]), user_id=uuid.UUID(state["user_id"]),
workspace_id=uuid.UUID(state["workspace_id"]) workspace_id=uuid.UUID(state["workspace_id"])
) )
# 更新状态 # 更新状态
node_id = self.node_config.get("id") node_id = self.node_config.get("id")
if result.success: if result.success:
@@ -549,7 +555,7 @@ class ToolWorkflowNode:
"execution_time": result.execution_time, "execution_time": result.execution_time,
"token_usage": result.token_usage "token_usage": result.token_usage
} }
# 更新运行时变量 # 更新运行时变量
if isinstance(result.data, dict): if isinstance(result.data, dict):
for key, value in result.data.items(): for key, value in result.data.items():
@@ -565,29 +571,29 @@ class ToolWorkflowNode:
"error": result.error, "error": result.error,
"execution_time": result.execution_time "execution_time": result.execution_time
} }
return state return state
except Exception as e: except Exception as e:
logger.error(f"工具节点执行失败: {e}") logger.error(f"工具节点执行失败: {e}")
state["error"] = str(e) state["error"] = str(e)
state["error_node"] = self.node_config.get("id") state["error_node"] = self.node_config.get("id")
return state return state
def _prepare_parameters(self, state: WorkflowState) -> dict: def _prepare_parameters(self, state: WorkflowState) -> dict:
"""准备工具参数(支持变量替换)""" """准备工具参数(支持变量替换)"""
parameters = {} parameters = {}
for key, value in self.tool_parameters.items(): for key, value in self.tool_parameters.items():
if isinstance(value, str) and value.startswith("${") and value.endswith("}"): if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
# 变量替换 # 变量替换
var_path = value[2:-1] var_path = value[2:-1]
# 支持多层级变量访问,如 ${sys.message} 或 ${node1.result} # 支持多层级变量访问,如 ${sys.message} 或 ${node1.result}
if "." in var_path: if "." in var_path:
parts = var_path.split(".") parts = var_path.split(".")
current = state.get("variables", {}) current = state.get("variables", {})
for part in parts: for part in parts:
if isinstance(current, dict) and part in current: if isinstance(current, dict) and part in current:
current = current[part] current = current[part]
@@ -596,7 +602,7 @@ class ToolWorkflowNode:
runtime_key = ".".join(parts) runtime_key = ".".join(parts)
current = state.get("runtime_vars", {}).get(runtime_key, value) current = state.get("runtime_vars", {}).get(runtime_key, value)
break break
parameters[key] = current parameters[key] = current
else: else:
# 简单变量 # 简单变量
@@ -604,7 +610,7 @@ class ToolWorkflowNode:
parameters[key] = variables.get(var_path, value) parameters[key] = variables.get(var_path, value)
else: else:
parameters[key] = value parameters[key] = value
return parameters return parameters

View File

@@ -50,6 +50,11 @@ class VariableDefinition(BaseModel):
description="变量描述" description="变量描述"
) )
max_length: int = Field(
default=200,
description="只对字符串类型生效"
)
class Config: class Config:
json_schema_extra = { json_schema_extra = {
"examples": [ "examples": [

View File

@@ -5,7 +5,6 @@ End 节点实现
""" """
import logging import logging
from typing import Any
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.base_node import BaseNode, WorkflowState

View File

@@ -10,10 +10,8 @@ from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.models import RedBearLLM, RedBearModelConfig from app.core.models import RedBearLLM, RedBearModelConfig
from app.models import ModelConfig from app.db import get_db_context
from app.db import get_db, get_db_context from app.services.model_service import ModelConfigService
from app.models.models_model import ModelApiKey
from app.services.model_service import ModelConfigService, ModelApiKeyService
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode from app.core.error_codes import BizCode

View File

@@ -1,7 +1,7 @@
""" """
工作流服务层 工作流服务层
""" """
import json
import logging import logging
import uuid import uuid
import datetime import datetime
@@ -438,7 +438,7 @@ class WorkflowService:
message=f"工作流配置不存在: app_id={app_id}" message=f"工作流配置不存在: app_id={app_id}"
) )
input_data = {"message": payload.message, "variables": payload.variables, "conversation_id": payload.conversation_id} input_data = {"message": payload.message, "variables": payload.variables, "conversation_id": payload.conversation_id}
# 转换 user_id 为 UUID # 转换 user_id 为 UUID
triggered_by_uuid = None triggered_by_uuid = None
if payload.user_id: if payload.user_id:
@@ -446,7 +446,7 @@ class WorkflowService:
triggered_by_uuid = uuid.UUID(payload.user_id) triggered_by_uuid = uuid.UUID(payload.user_id)
except (ValueError, AttributeError): except (ValueError, AttributeError):
logger.warning(f"无效的 user_id 格式: {payload.user_id}") logger.warning(f"无效的 user_id 格式: {payload.user_id}")
# 转换 conversation_id 为 UUID # 转换 conversation_id 为 UUID
conversation_id_uuid = None conversation_id_uuid = None
if payload.conversation_id: if payload.conversation_id:
@@ -454,7 +454,7 @@ class WorkflowService:
conversation_id_uuid = uuid.UUID(payload.conversation_id) conversation_id_uuid = uuid.UUID(payload.conversation_id)
except (ValueError, AttributeError): except (ValueError, AttributeError):
logger.warning(f"无效的 conversation_id 格式: {payload.conversation_id}") logger.warning(f"无效的 conversation_id 格式: {payload.conversation_id}")
# 2. 创建执行记录 # 2. 创建执行记录
execution = self.create_execution( execution = self.create_execution(
workflow_config_id=config.id, workflow_config_id=config.id,
@@ -530,6 +530,109 @@ class WorkflowService:
message=f"工作流执行失败: {str(e)}" message=f"工作流执行失败: {str(e)}"
) )
async def run_stream(
self,
app_id: uuid.UUID,
payload: DraftRunRequest,
config: WorkflowConfig
):
"""运行工作流(流式)
Args:
app_id: 应用 ID
payload: 请求对象(包含 message, variables, conversation_id 等)
config: 存储类型(可选)
Yields:
SSE 格式的流式事件
Raises:
BusinessException: 配置不存在或执行失败时抛出
"""
# 1. 获取工作流配置
if not config:
config = self.get_workflow_config(app_id)
if not config:
raise BusinessException(
code=BizCode.CONFIG_MISSING,
message=f"工作流配置不存在: app_id={app_id}"
)
input_data = {"message": payload.message, "variables": payload.variables,
"conversation_id": payload.conversation_id}
# 转换 user_id 为 UUID
triggered_by_uuid = None
if payload.user_id:
try:
triggered_by_uuid = uuid.UUID(payload.user_id)
except (ValueError, AttributeError):
logger.warning(f"无效的 user_id 格式: {payload.user_id}")
# 转换 conversation_id 为 UUID
conversation_id_uuid = None
if payload.conversation_id:
try:
conversation_id_uuid = uuid.UUID(payload.conversation_id)
except (ValueError, AttributeError):
logger.warning(f"无效的 conversation_id 格式: {payload.conversation_id}")
# 2. 创建执行记录
execution = self.create_execution(
workflow_config_id=config.id,
app_id=app_id,
trigger_type="manual",
triggered_by=triggered_by_uuid,
conversation_id=conversation_id_uuid,
input_data=input_data
)
# 3. 构建工作流配置字典
workflow_config_dict = {
"nodes": config.nodes,
"edges": config.edges,
"variables": config.variables,
"execution_config": config.execution_config
}
# 4. 获取工作空间 ID从 app 获取)
from app.models import App
# 5. 流式执行工作流
from app.core.workflow.executor import execute_workflow, execute_workflow_stream
try:
# 更新状态为运行中
self.update_execution_status(execution.execution_id, "running")
# 发送开始事件
yield f"data: {json.dumps({'type': 'workflow_start', 'execution_id': execution.execution_id})}\n\n"
# 调用流式执行
async for event in self._run_workflow_stream(
workflow_config=workflow_config_dict,
input_data=input_data,
execution_id=execution.execution_id,
workspace_id="",
user_id=payload.user_id
):
# 清理事件数据,移除不可序列化的对象
cleaned_event = self._clean_event_for_json(event)
# 转换为 SSE 格式
yield f"data: {json.dumps(cleaned_event)}\n\n"
# 发送完成事件
yield f"data: {json.dumps({'type': 'workflow_end', 'execution_id': execution.execution_id})}\n\n"
except Exception as e:
logger.error(f"工作流流式执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
self.update_execution_status(
execution.execution_id,
"failed",
error_message=str(e)
)
# 发送错误事件
yield f"data: {json.dumps({'type': 'error', 'execution_id': execution.execution_id, 'error': str(e)})}\n\n"
async def run_workflow( async def run_workflow(
self, self,
app_id: uuid.UUID, app_id: uuid.UUID,
@@ -651,14 +754,44 @@ class WorkflowService:
message=f"工作流执行失败: {str(e)}" message=f"工作流执行失败: {str(e)}"
) )
def _clean_event_for_json(self, event: dict[str, Any]) -> dict[str, Any]:
"""清理事件数据,移除不可序列化的对象
Args:
event: 原始事件数据
Returns:
可序列化的事件数据
"""
from langchain_core.messages import BaseMessage
def clean_value(value):
"""递归清理值"""
if isinstance(value, BaseMessage):
# 将 Message 对象转换为字典
return {
"type": value.__class__.__name__,
"content": value.content,
}
elif isinstance(value, dict):
return {k: clean_value(v) for k, v in value.items()}
elif isinstance(value, list):
return [clean_value(item) for item in value]
elif isinstance(value, (str, int, float, bool, type(None))):
return value
else:
# 其他不可序列化的对象转换为字符串
return str(value)
return clean_value(event)
async def _run_workflow_stream( async def _run_workflow_stream(
self, self,
workflow_config: dict[str, Any], workflow_config: dict[str, Any],
input_data: dict[str, Any], input_data: dict[str, Any],
execution_id: str, execution_id: str,
workspace_id: str, workspace_id: str,
user_id: str user_id: str):
):
"""运行工作流(流式,内部方法) """运行工作流(流式,内部方法)
Args: Args: