[add] workflow support stream mode

This commit is contained in:
Mark
2025-12-18 19:46:36 +08:00
committed by 谢俊男
parent 9e48f2143e
commit 3aff6baccb
6 changed files with 282 additions and 144 deletions

View File

@@ -421,8 +421,8 @@ async def draft_run(
# 流式返回
if payload.stream:
async def event_generator():
async for event in draft_service.run_stream(
agent_config=agent_cfg,
model_config=model_config,
@@ -574,7 +574,7 @@ async def draft_run(
# 3. 流式返回
if payload.stream:
logger.debug(
"开始多智能体流式试运行",
"开始工作流流式试运行",
extra={
"app_id": str(app_id),
"message_length": len(payload.message),
@@ -583,16 +583,13 @@ async def draft_run(
)
async def event_generator():
"""多智能体流式事件生成器"""
multiservice = MultiAgentService(db)
"""工作流事件生成器"""
# 调用多智能体服务的流式方法
async for event in multiservice.run_stream(
async for event in workflow_service.run_stream(
app_id=app_id,
request=multi_agent_request,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
payload=payload,
config=config
):
yield event
@@ -617,7 +614,7 @@ async def draft_run(
)
result = await workflow_service.run(app_id, payload,config)
logger.debug(
"工作流试运行返回结果",
extra={

View File

@@ -11,26 +11,24 @@ from typing import Any
from langchain_core.messages import HumanMessage
from langgraph.graph import StateGraph, START, END
from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.nodes import WorkflowState, NodeFactory
from app.core.workflow.expression_evaluator import evaluate_condition
from app.models.workflow_model import WorkflowExecution, WorkflowNodeExecution
from app.core.tools.registry import ToolRegistry
from app.core.tools.executor import ToolExecutor
from app.core.tools.langchain_adapter import LangchainAdapter
TOOL_MANAGEMENT_AVAILABLE = True
from app.db import get_db
logger = logging.getLogger(__name__)
class WorkflowExecutor:
"""工作流执行器
负责将工作流配置转换为 LangGraph 并执行。
"""
def __init__(
self,
workflow_config: dict[str, Any],
@@ -39,7 +37,7 @@ class WorkflowExecutor:
user_id: str
):
"""初始化执行器
Args:
workflow_config: 工作流配置
execution_id: 执行 ID
@@ -53,25 +51,25 @@ class WorkflowExecutor:
self.nodes = workflow_config.get("nodes", [])
self.edges = workflow_config.get("edges", [])
self.execution_config = workflow_config.get("execution_config", {})
def _prepare_initial_state(self, input_data: dict[str, Any]) -> WorkflowState:
"""准备初始状态(注入系统变量和会话变量)
变量命名空间:
- sys.xxx - 系统变量execution_id, workspace_id, user_id, message, input_variables 等)
- conv.xxx - 会话变量(跨多轮对话保持)
- node_id.xxx - 节点输出(执行时动态生成)
Args:
input_data: 输入数据
Returns:
初始化的工作流状态
"""
user_message = input_data.get("message") or ""
conversation_vars = input_data.get("conversation_vars") or {}
input_variables = input_data.get("variables") or {} # Start 节点的自定义变量
# 构建分层的变量结构
variables = {
"sys": {
@@ -84,7 +82,7 @@ class WorkflowExecutor:
},
"conv": conversation_vars # 会话级变量(跨多轮对话保持)
}
return {
"messages": [HumanMessage(content=user_message)],
"variables": variables,
@@ -96,34 +94,34 @@ class WorkflowExecutor:
"error": None,
"error_node": None
}
def build_graph(self) -> StateGraph:
def build_graph(self) -> CompiledStateGraph:
"""构建 LangGraph
Returns:
编译后的状态图
"""
logger.info(f"开始构建工作流图: execution_id={self.execution_id}")
# 1. 创建状态图
workflow = StateGraph(WorkflowState)
# 2. 添加所有节点(包括 start 和 end
start_node_id = None
end_node_ids = []
for node in self.nodes:
node_type = node.get("type")
node_id = node.get("id")
# 记录 start 和 end 节点 ID
if node_type == "start":
start_node_id = node_id
elif node_type == "end":
end_node_ids.append(node_id)
# 创建节点实例(现在 start 和 end 也会被创建)
node_instance = NodeFactory.create_node(node, self.workflow_config)
if node_instance:
@@ -133,40 +131,40 @@ class WorkflowExecutor:
async def node_func(state: WorkflowState):
return await inst.run(state)
return node_func
workflow.add_node(node_id, make_node_func(node_instance))
logger.debug(f"添加节点: {node_id} (type={node_type})")
# 3. 添加边
# 从 START 连接到 start 节点
if start_node_id:
workflow.add_edge(START, start_node_id)
logger.debug(f"添加边: START -> {start_node_id}")
for edge in self.edges:
source = edge.get("source")
target = edge.get("target")
edge_type = edge.get("type")
condition = edge.get("condition")
# 跳过从 start 节点出发的边(因为已经从 START 连接到 start
if source == start_node_id:
# 但要连接 start 到下一个节点
workflow.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}")
continue
# 处理到 end 节点的边
if target in end_node_ids:
# 连接到 end 节点
workflow.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}")
continue
# 跳过错误边(在节点内部处理)
if edge_type == "error":
continue
if condition:
# 条件边
def router(state: WorkflowState, cond=condition, tgt=target):
@@ -183,74 +181,74 @@ class WorkflowExecutor:
):
return tgt
return END # 条件不满足,结束
workflow.add_conditional_edges(source, router)
logger.debug(f"添加条件边: {source} -> {target} (condition={condition})")
else:
# 普通边
workflow.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}")
# 从 end 节点连接到 END
for end_node_id in end_node_ids:
workflow.add_edge(end_node_id, END)
logger.debug(f"添加边: {end_node_id} -> END")
# 4. 编译图
graph = workflow.compile()
logger.info(f"工作流图构建完成: execution_id={self.execution_id}")
return graph
async def execute(
self,
input_data: dict[str, Any]
) -> dict[str, Any]:
"""执行工作流(非流式)
Args:
input_data: 输入数据,包含 message 和 variables
Returns:
执行结果,包含 status, output, node_outputs, elapsed_time, token_usage
"""
logger.info(f"开始执行工作流: execution_id={self.execution_id}")
# 记录开始时间
start_time = datetime.datetime.now()
# 1. 构建图
graph = self.build_graph()
# 2. 初始化状态(自动注入系统变量)
initial_state = self._prepare_initial_state(input_data)
# 3. 执行工作流
try:
result = await graph.ainvoke(initial_state)
# 计算耗时
end_time = datetime.datetime.now()
elapsed_time = (end_time - start_time).total_seconds()
# 提取节点输出(现在包含 start 和 end 节点)
node_outputs = result.get("node_outputs", {})
# 提取最终输出(从最后一个非 start/end 节点)
final_output = self._extract_final_output(node_outputs)
# 聚合 token 使用情况
token_usage = self._aggregate_token_usage(node_outputs)
# 提取 conversation_id从 start 节点输出)
conversation_id = None
for node_id, node_output in node_outputs.items():
if node_output.get("node_type") == "start":
conversation_id = node_output.get("output", {}).get("conversation_id")
break
logger.info(f"工作流执行完成: execution_id={self.execution_id}, elapsed_time={elapsed_time:.2f}s")
return {
"status": "completed",
"output": final_output,
@@ -261,12 +259,12 @@ class WorkflowExecutor:
"token_usage": token_usage,
"error": result.get("error")
}
except Exception as e:
# 计算耗时(即使失败也记录)
end_time = datetime.datetime.now()
elapsed_time = (end_time - start_time).total_seconds()
logger.error(f"工作流执行失败: execution_id={self.execution_id}, error={e}", exc_info=True)
return {
"status": "failed",
@@ -276,86 +274,94 @@ class WorkflowExecutor:
"elapsed_time": elapsed_time,
"token_usage": None
}
async def execute_stream(
self,
input_data: dict[str, Any]
):
"""执行工作流(流式)
手动执行节点以支持细粒度的流式输出:
- workflow_start: 工作流开始
- node_start: 节点开始执行
- node_chunk: LLM 节点的流式输出片段(逐 token
- node_complete: 节点执行完成
- workflow_complete: 工作流完成
Args:
input_data: 输入数据
Yields:
流式事件
"""
logger.info(f"开始执行工作流(流式): execution_id={self.execution_id}")
#
logger.info(f"开始执行工作流: execution_id={self.execution_id}")
# 记录开始时间
start_time = datetime.datetime.now()
# 1. 构建图
graph = self.build_graph()
# 2. 初始化状态(自动注入系统变量)
initial_state = self._prepare_initial_state(input_data)
# 3. 流式执行工作流
# 3. 执行工作流
try:
# 使用 astream 获取节点级别的更新
async for event in graph.astream(initial_state, stream_mode="updates"):
for node_name, state_update in event.items():
yield {
"type": "node_complete",
"node": node_name,
"data": state_update,
"execution_id": self.execution_id
}
logger.info(f"工作流执行完成(流式): execution_id={self.execution_id}")
# 发送完成事件
yield {
"type": "workflow_complete",
"execution_id": self.execution_id
}
async for chunk in graph.astream(
initial_state,
# subgraphs=True,
stream_mode="updates",
):
# print(chunk)
yield chunk
except Exception as e:
logger.error(f"工作流执行失败(流式): execution_id={self.execution_id}, error={e}", exc_info=True)
# 计算耗时(即使失败也记录)
end_time = datetime.datetime.now()
elapsed_time = (end_time - start_time).total_seconds()
logger.error(f"工作流执行失败: execution_id={self.execution_id}, error={e}", exc_info=True)
yield {
"type": "workflow_error",
"execution_id": self.execution_id,
"error": str(e)
"status": "failed",
"error": str(e),
"output": None,
"node_outputs": {},
"elapsed_time": elapsed_time,
"token_usage": None
}
def _extract_final_output(self, node_outputs: dict[str, Any]) -> str | None:
"""从节点输出中提取最终输出
优先级:
1. 最后一个执行的非 start/end 节点的 output
2. 如果没有节点输出,返回 None
Args:
node_outputs: 所有节点的输出
Returns:
最终输出字符串或 None
"""
if not node_outputs:
return None
# 获取最后一个节点的输出
last_node_output = list(node_outputs.values())[-1] if node_outputs else None
if last_node_output and isinstance(last_node_output, dict):
return last_node_output.get("output")
return None
def _aggregate_token_usage(self, node_outputs: dict[str, Any]) -> dict[str, int] | None:
"""聚合所有节点的 token 使用情况
Args:
node_outputs: 所有节点的输出
Returns:
聚合的 token 使用情况 {"prompt_tokens": x, "completion_tokens": y, "total_tokens": z}
如果没有 token 使用信息,返回 None
@@ -364,7 +370,7 @@ class WorkflowExecutor:
total_completion_tokens = 0
total_tokens = 0
has_token_info = False
for node_output in node_outputs.values():
if isinstance(node_output, dict):
token_usage = node_output.get("token_usage")
@@ -373,16 +379,16 @@ class WorkflowExecutor:
total_prompt_tokens += token_usage.get("prompt_tokens", 0)
total_completion_tokens += token_usage.get("completion_tokens", 0)
total_tokens += token_usage.get("total_tokens", 0)
if not has_token_info:
return None
return {
"prompt_tokens": total_prompt_tokens,
"completion_tokens": total_completion_tokens,
"total_tokens": total_tokens
}
async def execute_workflow(
workflow_config: dict[str, Any],
@@ -392,14 +398,14 @@ async def execute_workflow(
user_id: str
) -> dict[str, Any]:
"""执行工作流(便捷函数)
Args:
workflow_config: 工作流配置
input_data: 输入数据
execution_id: 执行 ID
workspace_id: 工作空间 ID
user_id: 用户 ID
Returns:
执行结果
"""
@@ -420,14 +426,14 @@ async def execute_workflow_stream(
user_id: str
):
"""执行工作流(流式,便捷函数)
Args:
workflow_config: 工作流配置
input_data: 输入数据
execution_id: 执行 ID
workspace_id: 工作空间 ID
user_id: 用户 ID
Yields:
流式事件
"""
@@ -445,25 +451,25 @@ async def execute_workflow_stream(
def get_workflow_tools(workspace_id: str, user_id: str) -> list:
"""获取工作流可用的工具列表
Args:
workspace_id: 工作空间ID
user_id: 用户ID
Returns:
可用工具列表
"""
if not TOOL_MANAGEMENT_AVAILABLE:
logger.warning("工具管理系统不可用")
return []
try:
from sqlalchemy.orm import Session
db = next(get_db())
# 创建工具注册表
registry = ToolRegistry(db)
# 注册内置工具类
from app.core.tools.builtin import (
DateTimeTool, JsonTool, BaiduSearchTool, MinerUTool, TextInTool
@@ -473,12 +479,12 @@ def get_workflow_tools(workspace_id: str, user_id: str) -> list:
registry.register_tool_class(BaiduSearchTool)
registry.register_tool_class(MinerUTool)
registry.register_tool_class(TextInTool)
# 获取活跃的工具
import uuid
tools = registry.list_tools(workspace_id=uuid.UUID(workspace_id))
active_tools = [tool for tool in tools if tool.status.value == "active"]
# 转换为Langchain工具
langchain_tools = []
for tool_info in active_tools:
@@ -489,10 +495,10 @@ def get_workflow_tools(workspace_id: str, user_id: str) -> list:
langchain_tools.append(langchain_tool)
except Exception as e:
logger.error(f"转换工具失败: {tool_info.name}, 错误: {e}")
logger.info(f"为工作流获取了 {len(langchain_tools)} 个工具")
return langchain_tools
except Exception as e:
logger.error(f"获取工作流工具失败: {e}")
return []
@@ -500,10 +506,10 @@ def get_workflow_tools(workspace_id: str, user_id: str) -> list:
class ToolWorkflowNode:
"""工具工作流节点 - 在工作流中执行工具"""
def __init__(self, node_config: dict, workflow_config: dict):
"""初始化工具节点
Args:
node_config: 节点配置
workflow_config: 工作流配置
@@ -512,25 +518,25 @@ class ToolWorkflowNode:
self.workflow_config = workflow_config
self.tool_id = node_config.get("tool_id")
self.tool_parameters = node_config.get("parameters", {})
async def run(self, state: WorkflowState) -> WorkflowState:
"""执行工具节点"""
if not TOOL_MANAGEMENT_AVAILABLE:
logger.error("工具管理系统不可用")
state["error"] = "工具管理系统不可用"
return state
try:
from sqlalchemy.orm import Session
db = next(get_db())
# 创建工具执行器
registry = ToolRegistry(db)
executor = ToolExecutor(db, registry)
# 准备参数(支持变量替换)
parameters = self._prepare_parameters(state)
# 执行工具
result = await executor.execute_tool(
tool_id=self.tool_id,
@@ -538,7 +544,7 @@ class ToolWorkflowNode:
user_id=uuid.UUID(state["user_id"]),
workspace_id=uuid.UUID(state["workspace_id"])
)
# 更新状态
node_id = self.node_config.get("id")
if result.success:
@@ -549,7 +555,7 @@ class ToolWorkflowNode:
"execution_time": result.execution_time,
"token_usage": result.token_usage
}
# 更新运行时变量
if isinstance(result.data, dict):
for key, value in result.data.items():
@@ -565,29 +571,29 @@ class ToolWorkflowNode:
"error": result.error,
"execution_time": result.execution_time
}
return state
except Exception as e:
logger.error(f"工具节点执行失败: {e}")
state["error"] = str(e)
state["error_node"] = self.node_config.get("id")
return state
def _prepare_parameters(self, state: WorkflowState) -> dict:
"""准备工具参数(支持变量替换)"""
parameters = {}
for key, value in self.tool_parameters.items():
if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
# 变量替换
var_path = value[2:-1]
# 支持多层级变量访问,如 ${sys.message} 或 ${node1.result}
if "." in var_path:
parts = var_path.split(".")
current = state.get("variables", {})
for part in parts:
if isinstance(current, dict) and part in current:
current = current[part]
@@ -596,7 +602,7 @@ class ToolWorkflowNode:
runtime_key = ".".join(parts)
current = state.get("runtime_vars", {}).get(runtime_key, value)
break
parameters[key] = current
else:
# 简单变量
@@ -604,7 +610,7 @@ class ToolWorkflowNode:
parameters[key] = variables.get(var_path, value)
else:
parameters[key] = value
return parameters

View File

@@ -50,6 +50,11 @@ class VariableDefinition(BaseModel):
description="变量描述"
)
max_length: int = Field(
default=200,
description="只对字符串类型生效"
)
class Config:
json_schema_extra = {
"examples": [

View File

@@ -5,7 +5,6 @@ End 节点实现
"""
import logging
from typing import Any
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState

View File

@@ -10,10 +10,8 @@ from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models import ModelConfig
from app.db import get_db, get_db_context
from app.models.models_model import ModelApiKey
from app.services.model_service import ModelConfigService, ModelApiKeyService
from app.db import get_db_context
from app.services.model_service import ModelConfigService
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode

View File

@@ -1,7 +1,7 @@
"""
工作流服务层
"""
import json
import logging
import uuid
import datetime
@@ -438,7 +438,7 @@ class WorkflowService:
message=f"工作流配置不存在: app_id={app_id}"
)
input_data = {"message": payload.message, "variables": payload.variables, "conversation_id": payload.conversation_id}
# 转换 user_id 为 UUID
triggered_by_uuid = None
if payload.user_id:
@@ -446,7 +446,7 @@ class WorkflowService:
triggered_by_uuid = uuid.UUID(payload.user_id)
except (ValueError, AttributeError):
logger.warning(f"无效的 user_id 格式: {payload.user_id}")
# 转换 conversation_id 为 UUID
conversation_id_uuid = None
if payload.conversation_id:
@@ -454,7 +454,7 @@ class WorkflowService:
conversation_id_uuid = uuid.UUID(payload.conversation_id)
except (ValueError, AttributeError):
logger.warning(f"无效的 conversation_id 格式: {payload.conversation_id}")
# 2. 创建执行记录
execution = self.create_execution(
workflow_config_id=config.id,
@@ -530,6 +530,109 @@ class WorkflowService:
message=f"工作流执行失败: {str(e)}"
)
async def run_stream(
    self,
    app_id: uuid.UUID,
    payload: DraftRunRequest,
    config: WorkflowConfig
):
    """Run a workflow and stream its progress as Server-Sent Events.

    This is an async generator, so nothing runs (and no exception is
    raised) until the caller starts iterating.

    Args:
        app_id: ID of the application whose workflow is run.
        payload: Request object; ``message``, ``variables``,
            ``conversation_id`` and ``user_id`` are read from it.
        config: Workflow configuration to execute. When falsy, it is
            looked up with ``get_workflow_config(app_id)``.

    Yields:
        SSE frames (``data: <json>`` followed by a blank line): one
        ``workflow_start`` event, one frame per event produced by
        ``_run_workflow_stream``, then ``workflow_end``. If execution
        raises, an ``error`` frame is emitted instead and the execution
        record is marked ``failed``.

    Raises:
        BusinessException: If no workflow configuration can be found.
    """

    def to_sse(event_payload: dict) -> str:
        """Format one event dict as a single SSE ``data:`` frame."""
        return f"data: {json.dumps(event_payload)}\n\n"

    # 1. Resolve the workflow configuration.
    if not config:
        config = self.get_workflow_config(app_id)
    if not config:
        raise BusinessException(
            code=BizCode.CONFIG_MISSING,
            message=f"工作流配置不存在: app_id={app_id}"
        )

    input_data = {
        "message": payload.message,
        "variables": payload.variables,
        "conversation_id": payload.conversation_id,
    }

    # An unparsable user_id is logged and recorded as "no trigger user"
    # rather than aborting the run.
    triggered_by_uuid = None
    if payload.user_id:
        try:
            triggered_by_uuid = uuid.UUID(payload.user_id)
        except (ValueError, AttributeError):
            logger.warning(f"无效的 user_id 格式: {payload.user_id}")

    # Same best-effort conversion for conversation_id.
    conversation_id_uuid = None
    if payload.conversation_id:
        try:
            conversation_id_uuid = uuid.UUID(payload.conversation_id)
        except (ValueError, AttributeError):
            logger.warning(f"无效的 conversation_id 格式: {payload.conversation_id}")

    # 2. Create the execution record.
    execution = self.create_execution(
        workflow_config_id=config.id,
        app_id=app_id,
        trigger_type="manual",
        triggered_by=triggered_by_uuid,
        conversation_id=conversation_id_uuid,
        input_data=input_data
    )
    # Read the ID once so the error path does not depend on touching the
    # execution object again after a failure.
    execution_id = execution.execution_id

    # 3. Build the plain-dict workflow configuration.
    workflow_config_dict = {
        "nodes": config.nodes,
        "edges": config.edges,
        "variables": config.variables,
        "execution_config": config.execution_config
    }

    # 4. Stream the workflow.
    try:
        self.update_execution_status(execution_id, "running")

        yield to_sse({'type': 'workflow_start', 'execution_id': execution_id})

        # NOTE(review): workspace_id is passed as an empty string; the
        # workspace is never looked up from the app here -- confirm that
        # downstream nodes do not need it.
        async for event in self._run_workflow_stream(
            workflow_config=workflow_config_dict,
            input_data=input_data,
            execution_id=execution_id,
            workspace_id="",
            user_id=payload.user_id
        ):
            # Events may carry message objects and other values json
            # cannot encode, so they are cleaned before framing.
            yield to_sse(self._clean_event_for_json(event))

        # NOTE(review): the execution status is not set to a completed
        # state here; presumably _run_workflow_stream does that -- verify.
        yield to_sse({'type': 'workflow_end', 'execution_id': execution_id})

    except Exception as e:
        logger.error(f"工作流流式执行失败: execution_id={execution_id}, error={e}", exc_info=True)
        self.update_execution_status(
            execution_id,
            "failed",
            error_message=str(e)
        )
        # Tell the client the stream ended in failure.
        yield to_sse({'type': 'error', 'execution_id': execution_id, 'error': str(e)})
async def run_workflow(
self,
app_id: uuid.UUID,
@@ -651,14 +754,44 @@ class WorkflowService:
message=f"工作流执行失败: {str(e)}"
)
def _clean_event_for_json(self, event: dict[str, Any]) -> dict[str, Any]:
    """Convert a workflow event into a structure ``json.dumps`` can encode.

    Conversion rules, applied recursively:
    - ``BaseMessage`` instances become ``{"type": <class name>,
      "content": <cleaned content>}``.
    - dicts keep their shape; keys that json cannot encode are
      stringified.
    - lists and tuples become lists.
    - ``str``/``int``/``float``/``bool``/``None`` pass through unchanged.
    - anything else is replaced by ``str(value)``.

    Args:
        event: Raw event data.

    Returns:
        A JSON-serializable copy of ``event``.
    """
    from langchain_core.messages import BaseMessage

    json_scalars = (str, int, float, bool, type(None))

    def clean_key(key):
        """Return a dict key that ``json.dumps`` accepts."""
        # json.dumps raises TypeError for keys other than
        # str/int/float/bool/None (e.g. UUIDs or tuples).
        return key if isinstance(key, json_scalars) else str(key)

    def clean_value(value):
        """Recursively clean a single value."""
        if isinstance(value, BaseMessage):
            return {
                "type": value.__class__.__name__,
                # Message content can be a list of content blocks rather
                # than a plain string, so it is cleaned as well.
                "content": clean_value(value.content),
            }
        if isinstance(value, dict):
            return {clean_key(k): clean_value(v) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            # Tuples map to JSON arrays, same as json.dumps would encode
            # them, instead of being flattened to their repr string.
            return [clean_value(item) for item in value]
        if isinstance(value, json_scalars):
            return value
        # Last resort for objects json cannot encode.
        return str(value)

    return clean_value(event)
async def _run_workflow_stream(
self,
workflow_config: dict[str, Any],
input_data: dict[str, Any],
execution_id: str,
workspace_id: str,
user_id: str
):
user_id: str):
"""运行工作流(流式,内部方法)
Args: