[add] workflow support stream mode
This commit is contained in:
@@ -421,8 +421,8 @@ async def draft_run(
|
||||
# 流式返回
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
|
||||
|
||||
|
||||
|
||||
async for event in draft_service.run_stream(
|
||||
agent_config=agent_cfg,
|
||||
model_config=model_config,
|
||||
@@ -574,7 +574,7 @@ async def draft_run(
|
||||
# 3. 流式返回
|
||||
if payload.stream:
|
||||
logger.debug(
|
||||
"开始多智能体流式试运行",
|
||||
"开始工作流流式试运行",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"message_length": len(payload.message),
|
||||
@@ -583,16 +583,13 @@ async def draft_run(
|
||||
)
|
||||
|
||||
async def event_generator():
|
||||
"""多智能体流式事件生成器"""
|
||||
multiservice = MultiAgentService(db)
|
||||
"""工作流事件生成器"""
|
||||
|
||||
# 调用多智能体服务的流式方法
|
||||
async for event in multiservice.run_stream(
|
||||
async for event in workflow_service.run_stream(
|
||||
app_id=app_id,
|
||||
request=multi_agent_request,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
|
||||
payload=payload,
|
||||
config=config
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -617,7 +614,7 @@ async def draft_run(
|
||||
)
|
||||
|
||||
result = await workflow_service.run(app_id, payload,config)
|
||||
|
||||
|
||||
logger.debug(
|
||||
"工作流试运行返回结果",
|
||||
extra={
|
||||
|
||||
@@ -11,26 +11,24 @@ from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.core.workflow.nodes import WorkflowState, NodeFactory
|
||||
from app.core.workflow.expression_evaluator import evaluate_condition
|
||||
from app.models.workflow_model import WorkflowExecution, WorkflowNodeExecution
|
||||
from app.core.tools.registry import ToolRegistry
|
||||
from app.core.tools.executor import ToolExecutor
|
||||
from app.core.tools.langchain_adapter import LangchainAdapter
|
||||
TOOL_MANAGEMENT_AVAILABLE = True
|
||||
from app.db import get_db
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowExecutor:
|
||||
"""工作流执行器
|
||||
|
||||
|
||||
负责将工作流配置转换为 LangGraph 并执行。
|
||||
"""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workflow_config: dict[str, Any],
|
||||
@@ -39,7 +37,7 @@ class WorkflowExecutor:
|
||||
user_id: str
|
||||
):
|
||||
"""初始化执行器
|
||||
|
||||
|
||||
Args:
|
||||
workflow_config: 工作流配置
|
||||
execution_id: 执行 ID
|
||||
@@ -53,25 +51,25 @@ class WorkflowExecutor:
|
||||
self.nodes = workflow_config.get("nodes", [])
|
||||
self.edges = workflow_config.get("edges", [])
|
||||
self.execution_config = workflow_config.get("execution_config", {})
|
||||
|
||||
|
||||
def _prepare_initial_state(self, input_data: dict[str, Any]) -> WorkflowState:
|
||||
"""准备初始状态(注入系统变量和会话变量)
|
||||
|
||||
|
||||
变量命名空间:
|
||||
- sys.xxx - 系统变量(execution_id, workspace_id, user_id, message, input_variables 等)
|
||||
- conv.xxx - 会话变量(跨多轮对话保持)
|
||||
- node_id.xxx - 节点输出(执行时动态生成)
|
||||
|
||||
|
||||
Args:
|
||||
input_data: 输入数据
|
||||
|
||||
|
||||
Returns:
|
||||
初始化的工作流状态
|
||||
"""
|
||||
user_message = input_data.get("message") or ""
|
||||
conversation_vars = input_data.get("conversation_vars") or {}
|
||||
input_variables = input_data.get("variables") or {} # Start 节点的自定义变量
|
||||
|
||||
|
||||
# 构建分层的变量结构
|
||||
variables = {
|
||||
"sys": {
|
||||
@@ -84,7 +82,7 @@ class WorkflowExecutor:
|
||||
},
|
||||
"conv": conversation_vars # 会话级变量(跨多轮对话保持)
|
||||
}
|
||||
|
||||
|
||||
return {
|
||||
"messages": [HumanMessage(content=user_message)],
|
||||
"variables": variables,
|
||||
@@ -96,34 +94,34 @@ class WorkflowExecutor:
|
||||
"error": None,
|
||||
"error_node": None
|
||||
}
|
||||
|
||||
|
||||
|
||||
def build_graph(self) -> StateGraph:
|
||||
|
||||
|
||||
def build_graph(self) -> CompiledStateGraph:
|
||||
"""构建 LangGraph
|
||||
|
||||
|
||||
Returns:
|
||||
编译后的状态图
|
||||
"""
|
||||
logger.info(f"开始构建工作流图: execution_id={self.execution_id}")
|
||||
|
||||
|
||||
# 1. 创建状态图
|
||||
workflow = StateGraph(WorkflowState)
|
||||
|
||||
|
||||
# 2. 添加所有节点(包括 start 和 end)
|
||||
start_node_id = None
|
||||
end_node_ids = []
|
||||
|
||||
|
||||
for node in self.nodes:
|
||||
node_type = node.get("type")
|
||||
node_id = node.get("id")
|
||||
|
||||
|
||||
# 记录 start 和 end 节点 ID
|
||||
if node_type == "start":
|
||||
start_node_id = node_id
|
||||
elif node_type == "end":
|
||||
end_node_ids.append(node_id)
|
||||
|
||||
|
||||
# 创建节点实例(现在 start 和 end 也会被创建)
|
||||
node_instance = NodeFactory.create_node(node, self.workflow_config)
|
||||
if node_instance:
|
||||
@@ -133,40 +131,40 @@ class WorkflowExecutor:
|
||||
async def node_func(state: WorkflowState):
|
||||
return await inst.run(state)
|
||||
return node_func
|
||||
|
||||
|
||||
workflow.add_node(node_id, make_node_func(node_instance))
|
||||
logger.debug(f"添加节点: {node_id} (type={node_type})")
|
||||
|
||||
|
||||
# 3. 添加边
|
||||
# 从 START 连接到 start 节点
|
||||
if start_node_id:
|
||||
workflow.add_edge(START, start_node_id)
|
||||
logger.debug(f"添加边: START -> {start_node_id}")
|
||||
|
||||
|
||||
for edge in self.edges:
|
||||
source = edge.get("source")
|
||||
target = edge.get("target")
|
||||
edge_type = edge.get("type")
|
||||
condition = edge.get("condition")
|
||||
|
||||
|
||||
# 跳过从 start 节点出发的边(因为已经从 START 连接到 start)
|
||||
if source == start_node_id:
|
||||
# 但要连接 start 到下一个节点
|
||||
workflow.add_edge(source, target)
|
||||
logger.debug(f"添加边: {source} -> {target}")
|
||||
continue
|
||||
|
||||
|
||||
# 处理到 end 节点的边
|
||||
if target in end_node_ids:
|
||||
# 连接到 end 节点
|
||||
workflow.add_edge(source, target)
|
||||
logger.debug(f"添加边: {source} -> {target}")
|
||||
continue
|
||||
|
||||
|
||||
# 跳过错误边(在节点内部处理)
|
||||
if edge_type == "error":
|
||||
continue
|
||||
|
||||
|
||||
if condition:
|
||||
# 条件边
|
||||
def router(state: WorkflowState, cond=condition, tgt=target):
|
||||
@@ -183,74 +181,74 @@ class WorkflowExecutor:
|
||||
):
|
||||
return tgt
|
||||
return END # 条件不满足,结束
|
||||
|
||||
|
||||
workflow.add_conditional_edges(source, router)
|
||||
logger.debug(f"添加条件边: {source} -> {target} (condition={condition})")
|
||||
else:
|
||||
# 普通边
|
||||
workflow.add_edge(source, target)
|
||||
logger.debug(f"添加边: {source} -> {target}")
|
||||
|
||||
|
||||
# 从 end 节点连接到 END
|
||||
for end_node_id in end_node_ids:
|
||||
workflow.add_edge(end_node_id, END)
|
||||
logger.debug(f"添加边: {end_node_id} -> END")
|
||||
|
||||
|
||||
# 4. 编译图
|
||||
graph = workflow.compile()
|
||||
logger.info(f"工作流图构建完成: execution_id={self.execution_id}")
|
||||
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
input_data: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""执行工作流(非流式)
|
||||
|
||||
|
||||
Args:
|
||||
input_data: 输入数据,包含 message 和 variables
|
||||
|
||||
|
||||
Returns:
|
||||
执行结果,包含 status, output, node_outputs, elapsed_time, token_usage
|
||||
"""
|
||||
logger.info(f"开始执行工作流: execution_id={self.execution_id}")
|
||||
|
||||
|
||||
# 记录开始时间
|
||||
start_time = datetime.datetime.now()
|
||||
|
||||
|
||||
# 1. 构建图
|
||||
graph = self.build_graph()
|
||||
|
||||
|
||||
# 2. 初始化状态(自动注入系统变量)
|
||||
initial_state = self._prepare_initial_state(input_data)
|
||||
|
||||
|
||||
# 3. 执行工作流
|
||||
try:
|
||||
result = await graph.ainvoke(initial_state)
|
||||
|
||||
|
||||
# 计算耗时
|
||||
end_time = datetime.datetime.now()
|
||||
elapsed_time = (end_time - start_time).total_seconds()
|
||||
|
||||
|
||||
# 提取节点输出(现在包含 start 和 end 节点)
|
||||
node_outputs = result.get("node_outputs", {})
|
||||
|
||||
|
||||
# 提取最终输出(从最后一个非 start/end 节点)
|
||||
final_output = self._extract_final_output(node_outputs)
|
||||
|
||||
|
||||
# 聚合 token 使用情况
|
||||
token_usage = self._aggregate_token_usage(node_outputs)
|
||||
|
||||
|
||||
# 提取 conversation_id(从 start 节点输出)
|
||||
conversation_id = None
|
||||
for node_id, node_output in node_outputs.items():
|
||||
if node_output.get("node_type") == "start":
|
||||
conversation_id = node_output.get("output", {}).get("conversation_id")
|
||||
break
|
||||
|
||||
|
||||
logger.info(f"工作流执行完成: execution_id={self.execution_id}, elapsed_time={elapsed_time:.2f}s")
|
||||
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"output": final_output,
|
||||
@@ -261,12 +259,12 @@ class WorkflowExecutor:
|
||||
"token_usage": token_usage,
|
||||
"error": result.get("error")
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
# 计算耗时(即使失败也记录)
|
||||
end_time = datetime.datetime.now()
|
||||
elapsed_time = (end_time - start_time).total_seconds()
|
||||
|
||||
|
||||
logger.error(f"工作流执行失败: execution_id={self.execution_id}, error={e}", exc_info=True)
|
||||
return {
|
||||
"status": "failed",
|
||||
@@ -276,86 +274,94 @@ class WorkflowExecutor:
|
||||
"elapsed_time": elapsed_time,
|
||||
"token_usage": None
|
||||
}
|
||||
|
||||
|
||||
async def execute_stream(
|
||||
self,
|
||||
input_data: dict[str, Any]
|
||||
):
|
||||
"""执行工作流(流式)
|
||||
|
||||
|
||||
手动执行节点以支持细粒度的流式输出:
|
||||
- workflow_start: 工作流开始
|
||||
- node_start: 节点开始执行
|
||||
- node_chunk: LLM 节点的流式输出片段(逐 token)
|
||||
- node_complete: 节点执行完成
|
||||
- workflow_complete: 工作流完成
|
||||
|
||||
Args:
|
||||
input_data: 输入数据
|
||||
|
||||
|
||||
Yields:
|
||||
流式事件
|
||||
"""
|
||||
logger.info(f"开始执行工作流(流式): execution_id={self.execution_id}")
|
||||
|
||||
#
|
||||
logger.info(f"开始执行工作流: execution_id={self.execution_id}")
|
||||
|
||||
# 记录开始时间
|
||||
start_time = datetime.datetime.now()
|
||||
|
||||
# 1. 构建图
|
||||
graph = self.build_graph()
|
||||
|
||||
|
||||
# 2. 初始化状态(自动注入系统变量)
|
||||
initial_state = self._prepare_initial_state(input_data)
|
||||
|
||||
# 3. 流式执行工作流
|
||||
|
||||
# 3. 执行工作流
|
||||
try:
|
||||
# 使用 astream 获取节点级别的更新
|
||||
async for event in graph.astream(initial_state, stream_mode="updates"):
|
||||
for node_name, state_update in event.items():
|
||||
yield {
|
||||
"type": "node_complete",
|
||||
"node": node_name,
|
||||
"data": state_update,
|
||||
"execution_id": self.execution_id
|
||||
}
|
||||
|
||||
logger.info(f"工作流执行完成(流式): execution_id={self.execution_id}")
|
||||
|
||||
# 发送完成事件
|
||||
yield {
|
||||
"type": "workflow_complete",
|
||||
"execution_id": self.execution_id
|
||||
}
|
||||
|
||||
async for chunk in graph.astream(
|
||||
initial_state,
|
||||
# subgraphs=True,
|
||||
stream_mode="updates",
|
||||
):
|
||||
# print(chunk)
|
||||
yield chunk
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工作流执行失败(流式): execution_id={self.execution_id}, error={e}", exc_info=True)
|
||||
# 计算耗时(即使失败也记录)
|
||||
end_time = datetime.datetime.now()
|
||||
elapsed_time = (end_time - start_time).total_seconds()
|
||||
|
||||
logger.error(f"工作流执行失败: execution_id={self.execution_id}, error={e}", exc_info=True)
|
||||
yield {
|
||||
"type": "workflow_error",
|
||||
"execution_id": self.execution_id,
|
||||
"error": str(e)
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
"output": None,
|
||||
"node_outputs": {},
|
||||
"elapsed_time": elapsed_time,
|
||||
"token_usage": None
|
||||
}
|
||||
|
||||
|
||||
|
||||
def _extract_final_output(self, node_outputs: dict[str, Any]) -> str | None:
|
||||
"""从节点输出中提取最终输出
|
||||
|
||||
|
||||
优先级:
|
||||
1. 最后一个执行的非 start/end 节点的 output
|
||||
2. 如果没有节点输出,返回 None
|
||||
|
||||
|
||||
Args:
|
||||
node_outputs: 所有节点的输出
|
||||
|
||||
|
||||
Returns:
|
||||
最终输出字符串或 None
|
||||
"""
|
||||
if not node_outputs:
|
||||
return None
|
||||
|
||||
|
||||
# 获取最后一个节点的输出
|
||||
last_node_output = list(node_outputs.values())[-1] if node_outputs else None
|
||||
|
||||
|
||||
if last_node_output and isinstance(last_node_output, dict):
|
||||
return last_node_output.get("output")
|
||||
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _aggregate_token_usage(self, node_outputs: dict[str, Any]) -> dict[str, int] | None:
|
||||
"""聚合所有节点的 token 使用情况
|
||||
|
||||
|
||||
Args:
|
||||
node_outputs: 所有节点的输出
|
||||
|
||||
|
||||
Returns:
|
||||
聚合的 token 使用情况 {"prompt_tokens": x, "completion_tokens": y, "total_tokens": z}
|
||||
如果没有 token 使用信息,返回 None
|
||||
@@ -364,7 +370,7 @@ class WorkflowExecutor:
|
||||
total_completion_tokens = 0
|
||||
total_tokens = 0
|
||||
has_token_info = False
|
||||
|
||||
|
||||
for node_output in node_outputs.values():
|
||||
if isinstance(node_output, dict):
|
||||
token_usage = node_output.get("token_usage")
|
||||
@@ -373,16 +379,16 @@ class WorkflowExecutor:
|
||||
total_prompt_tokens += token_usage.get("prompt_tokens", 0)
|
||||
total_completion_tokens += token_usage.get("completion_tokens", 0)
|
||||
total_tokens += token_usage.get("total_tokens", 0)
|
||||
|
||||
|
||||
if not has_token_info:
|
||||
return None
|
||||
|
||||
|
||||
return {
|
||||
"prompt_tokens": total_prompt_tokens,
|
||||
"completion_tokens": total_completion_tokens,
|
||||
"total_tokens": total_tokens
|
||||
}
|
||||
|
||||
|
||||
|
||||
async def execute_workflow(
|
||||
workflow_config: dict[str, Any],
|
||||
@@ -392,14 +398,14 @@ async def execute_workflow(
|
||||
user_id: str
|
||||
) -> dict[str, Any]:
|
||||
"""执行工作流(便捷函数)
|
||||
|
||||
|
||||
Args:
|
||||
workflow_config: 工作流配置
|
||||
input_data: 输入数据
|
||||
execution_id: 执行 ID
|
||||
workspace_id: 工作空间 ID
|
||||
user_id: 用户 ID
|
||||
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
@@ -420,14 +426,14 @@ async def execute_workflow_stream(
|
||||
user_id: str
|
||||
):
|
||||
"""执行工作流(流式,便捷函数)
|
||||
|
||||
|
||||
Args:
|
||||
workflow_config: 工作流配置
|
||||
input_data: 输入数据
|
||||
execution_id: 执行 ID
|
||||
workspace_id: 工作空间 ID
|
||||
user_id: 用户 ID
|
||||
|
||||
|
||||
Yields:
|
||||
流式事件
|
||||
"""
|
||||
@@ -445,25 +451,25 @@ async def execute_workflow_stream(
|
||||
|
||||
def get_workflow_tools(workspace_id: str, user_id: str) -> list:
|
||||
"""获取工作流可用的工具列表
|
||||
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
user_id: 用户ID
|
||||
|
||||
|
||||
Returns:
|
||||
可用工具列表
|
||||
"""
|
||||
if not TOOL_MANAGEMENT_AVAILABLE:
|
||||
logger.warning("工具管理系统不可用")
|
||||
return []
|
||||
|
||||
|
||||
try:
|
||||
from sqlalchemy.orm import Session
|
||||
db = next(get_db())
|
||||
|
||||
|
||||
# 创建工具注册表
|
||||
registry = ToolRegistry(db)
|
||||
|
||||
|
||||
# 注册内置工具类
|
||||
from app.core.tools.builtin import (
|
||||
DateTimeTool, JsonTool, BaiduSearchTool, MinerUTool, TextInTool
|
||||
@@ -473,12 +479,12 @@ def get_workflow_tools(workspace_id: str, user_id: str) -> list:
|
||||
registry.register_tool_class(BaiduSearchTool)
|
||||
registry.register_tool_class(MinerUTool)
|
||||
registry.register_tool_class(TextInTool)
|
||||
|
||||
|
||||
# 获取活跃的工具
|
||||
import uuid
|
||||
tools = registry.list_tools(workspace_id=uuid.UUID(workspace_id))
|
||||
active_tools = [tool for tool in tools if tool.status.value == "active"]
|
||||
|
||||
|
||||
# 转换为Langchain工具
|
||||
langchain_tools = []
|
||||
for tool_info in active_tools:
|
||||
@@ -489,10 +495,10 @@ def get_workflow_tools(workspace_id: str, user_id: str) -> list:
|
||||
langchain_tools.append(langchain_tool)
|
||||
except Exception as e:
|
||||
logger.error(f"转换工具失败: {tool_info.name}, 错误: {e}")
|
||||
|
||||
|
||||
logger.info(f"为工作流获取了 {len(langchain_tools)} 个工具")
|
||||
return langchain_tools
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取工作流工具失败: {e}")
|
||||
return []
|
||||
@@ -500,10 +506,10 @@ def get_workflow_tools(workspace_id: str, user_id: str) -> list:
|
||||
|
||||
class ToolWorkflowNode:
|
||||
"""工具工作流节点 - 在工作流中执行工具"""
|
||||
|
||||
|
||||
def __init__(self, node_config: dict, workflow_config: dict):
|
||||
"""初始化工具节点
|
||||
|
||||
|
||||
Args:
|
||||
node_config: 节点配置
|
||||
workflow_config: 工作流配置
|
||||
@@ -512,25 +518,25 @@ class ToolWorkflowNode:
|
||||
self.workflow_config = workflow_config
|
||||
self.tool_id = node_config.get("tool_id")
|
||||
self.tool_parameters = node_config.get("parameters", {})
|
||||
|
||||
|
||||
async def run(self, state: WorkflowState) -> WorkflowState:
|
||||
"""执行工具节点"""
|
||||
if not TOOL_MANAGEMENT_AVAILABLE:
|
||||
logger.error("工具管理系统不可用")
|
||||
state["error"] = "工具管理系统不可用"
|
||||
return state
|
||||
|
||||
|
||||
try:
|
||||
from sqlalchemy.orm import Session
|
||||
db = next(get_db())
|
||||
|
||||
|
||||
# 创建工具执行器
|
||||
registry = ToolRegistry(db)
|
||||
executor = ToolExecutor(db, registry)
|
||||
|
||||
|
||||
# 准备参数(支持变量替换)
|
||||
parameters = self._prepare_parameters(state)
|
||||
|
||||
|
||||
# 执行工具
|
||||
result = await executor.execute_tool(
|
||||
tool_id=self.tool_id,
|
||||
@@ -538,7 +544,7 @@ class ToolWorkflowNode:
|
||||
user_id=uuid.UUID(state["user_id"]),
|
||||
workspace_id=uuid.UUID(state["workspace_id"])
|
||||
)
|
||||
|
||||
|
||||
# 更新状态
|
||||
node_id = self.node_config.get("id")
|
||||
if result.success:
|
||||
@@ -549,7 +555,7 @@ class ToolWorkflowNode:
|
||||
"execution_time": result.execution_time,
|
||||
"token_usage": result.token_usage
|
||||
}
|
||||
|
||||
|
||||
# 更新运行时变量
|
||||
if isinstance(result.data, dict):
|
||||
for key, value in result.data.items():
|
||||
@@ -565,29 +571,29 @@ class ToolWorkflowNode:
|
||||
"error": result.error,
|
||||
"execution_time": result.execution_time
|
||||
}
|
||||
|
||||
|
||||
return state
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工具节点执行失败: {e}")
|
||||
state["error"] = str(e)
|
||||
state["error_node"] = self.node_config.get("id")
|
||||
return state
|
||||
|
||||
|
||||
def _prepare_parameters(self, state: WorkflowState) -> dict:
|
||||
"""准备工具参数(支持变量替换)"""
|
||||
parameters = {}
|
||||
|
||||
|
||||
for key, value in self.tool_parameters.items():
|
||||
if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
|
||||
# 变量替换
|
||||
var_path = value[2:-1]
|
||||
|
||||
|
||||
# 支持多层级变量访问,如 ${sys.message} 或 ${node1.result}
|
||||
if "." in var_path:
|
||||
parts = var_path.split(".")
|
||||
current = state.get("variables", {})
|
||||
|
||||
|
||||
for part in parts:
|
||||
if isinstance(current, dict) and part in current:
|
||||
current = current[part]
|
||||
@@ -596,7 +602,7 @@ class ToolWorkflowNode:
|
||||
runtime_key = ".".join(parts)
|
||||
current = state.get("runtime_vars", {}).get(runtime_key, value)
|
||||
break
|
||||
|
||||
|
||||
parameters[key] = current
|
||||
else:
|
||||
# 简单变量
|
||||
@@ -604,7 +610,7 @@ class ToolWorkflowNode:
|
||||
parameters[key] = variables.get(var_path, value)
|
||||
else:
|
||||
parameters[key] = value
|
||||
|
||||
|
||||
return parameters
|
||||
|
||||
|
||||
|
||||
@@ -50,6 +50,11 @@ class VariableDefinition(BaseModel):
|
||||
description="变量描述"
|
||||
)
|
||||
|
||||
max_length: int = Field(
|
||||
default=200,
|
||||
description="只对字符串类型生效"
|
||||
)
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"examples": [
|
||||
|
||||
@@ -5,7 +5,6 @@ End 节点实现
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
|
||||
|
||||
@@ -10,10 +10,8 @@ from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.models import ModelConfig
|
||||
from app.db import get_db, get_db_context
|
||||
from app.models.models_model import ModelApiKey
|
||||
from app.services.model_service import ModelConfigService, ModelApiKeyService
|
||||
from app.db import get_db_context
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
工作流服务层
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
import datetime
|
||||
@@ -438,7 +438,7 @@ class WorkflowService:
|
||||
message=f"工作流配置不存在: app_id={app_id}"
|
||||
)
|
||||
input_data = {"message": payload.message, "variables": payload.variables, "conversation_id": payload.conversation_id}
|
||||
|
||||
|
||||
# 转换 user_id 为 UUID
|
||||
triggered_by_uuid = None
|
||||
if payload.user_id:
|
||||
@@ -446,7 +446,7 @@ class WorkflowService:
|
||||
triggered_by_uuid = uuid.UUID(payload.user_id)
|
||||
except (ValueError, AttributeError):
|
||||
logger.warning(f"无效的 user_id 格式: {payload.user_id}")
|
||||
|
||||
|
||||
# 转换 conversation_id 为 UUID
|
||||
conversation_id_uuid = None
|
||||
if payload.conversation_id:
|
||||
@@ -454,7 +454,7 @@ class WorkflowService:
|
||||
conversation_id_uuid = uuid.UUID(payload.conversation_id)
|
||||
except (ValueError, AttributeError):
|
||||
logger.warning(f"无效的 conversation_id 格式: {payload.conversation_id}")
|
||||
|
||||
|
||||
# 2. 创建执行记录
|
||||
execution = self.create_execution(
|
||||
workflow_config_id=config.id,
|
||||
@@ -530,6 +530,109 @@ class WorkflowService:
|
||||
message=f"工作流执行失败: {str(e)}"
|
||||
)
|
||||
|
||||
async def run_stream(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
payload: DraftRunRequest,
|
||||
config: WorkflowConfig
|
||||
):
|
||||
"""运行工作流(流式)
|
||||
|
||||
Args:
|
||||
app_id: 应用 ID
|
||||
payload: 请求对象(包含 message, variables, conversation_id 等)
|
||||
config: 存储类型(可选)
|
||||
|
||||
Yields:
|
||||
SSE 格式的流式事件
|
||||
|
||||
Raises:
|
||||
BusinessException: 配置不存在或执行失败时抛出
|
||||
"""
|
||||
# 1. 获取工作流配置
|
||||
if not config:
|
||||
config = self.get_workflow_config(app_id)
|
||||
if not config:
|
||||
raise BusinessException(
|
||||
code=BizCode.CONFIG_MISSING,
|
||||
message=f"工作流配置不存在: app_id={app_id}"
|
||||
)
|
||||
input_data = {"message": payload.message, "variables": payload.variables,
|
||||
"conversation_id": payload.conversation_id}
|
||||
|
||||
# 转换 user_id 为 UUID
|
||||
triggered_by_uuid = None
|
||||
if payload.user_id:
|
||||
try:
|
||||
triggered_by_uuid = uuid.UUID(payload.user_id)
|
||||
except (ValueError, AttributeError):
|
||||
logger.warning(f"无效的 user_id 格式: {payload.user_id}")
|
||||
|
||||
# 转换 conversation_id 为 UUID
|
||||
conversation_id_uuid = None
|
||||
if payload.conversation_id:
|
||||
try:
|
||||
conversation_id_uuid = uuid.UUID(payload.conversation_id)
|
||||
except (ValueError, AttributeError):
|
||||
logger.warning(f"无效的 conversation_id 格式: {payload.conversation_id}")
|
||||
|
||||
# 2. 创建执行记录
|
||||
execution = self.create_execution(
|
||||
workflow_config_id=config.id,
|
||||
app_id=app_id,
|
||||
trigger_type="manual",
|
||||
triggered_by=triggered_by_uuid,
|
||||
conversation_id=conversation_id_uuid,
|
||||
input_data=input_data
|
||||
)
|
||||
|
||||
# 3. 构建工作流配置字典
|
||||
workflow_config_dict = {
|
||||
"nodes": config.nodes,
|
||||
"edges": config.edges,
|
||||
"variables": config.variables,
|
||||
"execution_config": config.execution_config
|
||||
}
|
||||
|
||||
# 4. 获取工作空间 ID(从 app 获取)
|
||||
from app.models import App
|
||||
|
||||
# 5. 流式执行工作流
|
||||
from app.core.workflow.executor import execute_workflow, execute_workflow_stream
|
||||
|
||||
try:
|
||||
# 更新状态为运行中
|
||||
self.update_execution_status(execution.execution_id, "running")
|
||||
|
||||
# 发送开始事件
|
||||
yield f"data: {json.dumps({'type': 'workflow_start', 'execution_id': execution.execution_id})}\n\n"
|
||||
|
||||
# 调用流式执行
|
||||
async for event in self._run_workflow_stream(
|
||||
workflow_config=workflow_config_dict,
|
||||
input_data=input_data,
|
||||
execution_id=execution.execution_id,
|
||||
workspace_id="",
|
||||
user_id=payload.user_id
|
||||
):
|
||||
# 清理事件数据,移除不可序列化的对象
|
||||
cleaned_event = self._clean_event_for_json(event)
|
||||
# 转换为 SSE 格式
|
||||
yield f"data: {json.dumps(cleaned_event)}\n\n"
|
||||
|
||||
# 发送完成事件
|
||||
yield f"data: {json.dumps({'type': 'workflow_end', 'execution_id': execution.execution_id})}\n\n"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工作流流式执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
|
||||
self.update_execution_status(
|
||||
execution.execution_id,
|
||||
"failed",
|
||||
error_message=str(e)
|
||||
)
|
||||
# 发送错误事件
|
||||
yield f"data: {json.dumps({'type': 'error', 'execution_id': execution.execution_id, 'error': str(e)})}\n\n"
|
||||
|
||||
async def run_workflow(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
@@ -651,14 +754,44 @@ class WorkflowService:
|
||||
message=f"工作流执行失败: {str(e)}"
|
||||
)
|
||||
|
||||
def _clean_event_for_json(self, event: dict[str, Any]) -> dict[str, Any]:
|
||||
"""清理事件数据,移除不可序列化的对象
|
||||
|
||||
Args:
|
||||
event: 原始事件数据
|
||||
|
||||
Returns:
|
||||
可序列化的事件数据
|
||||
"""
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
def clean_value(value):
|
||||
"""递归清理值"""
|
||||
if isinstance(value, BaseMessage):
|
||||
# 将 Message 对象转换为字典
|
||||
return {
|
||||
"type": value.__class__.__name__,
|
||||
"content": value.content,
|
||||
}
|
||||
elif isinstance(value, dict):
|
||||
return {k: clean_value(v) for k, v in value.items()}
|
||||
elif isinstance(value, list):
|
||||
return [clean_value(item) for item in value]
|
||||
elif isinstance(value, (str, int, float, bool, type(None))):
|
||||
return value
|
||||
else:
|
||||
# 其他不可序列化的对象转换为字符串
|
||||
return str(value)
|
||||
|
||||
return clean_value(event)
|
||||
|
||||
async def _run_workflow_stream(
|
||||
self,
|
||||
workflow_config: dict[str, Any],
|
||||
input_data: dict[str, Any],
|
||||
execution_id: str,
|
||||
workspace_id: str,
|
||||
user_id: str
|
||||
):
|
||||
user_id: str):
|
||||
"""运行工作流(流式,内部方法)
|
||||
|
||||
Args:
|
||||
|
||||
Reference in New Issue
Block a user