feat(workflow): officially support workflow session variables
This commit is contained in:
@@ -60,14 +60,14 @@ def list_apps(
|
|||||||
"""
|
"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
service = app_service.AppService(db)
|
service = app_service.AppService(db)
|
||||||
|
|
||||||
# 当 ids 存在且不为 None 时,根据 ids 获取应用
|
# 当 ids 存在且不为 None 时,根据 ids 获取应用
|
||||||
if ids is not None:
|
if ids is not None:
|
||||||
app_ids = [id.strip() for id in ids.split(',') if id.strip()]
|
app_ids = [id.strip() for id in ids.split(',') if id.strip()]
|
||||||
items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id)
|
items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id)
|
||||||
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
|
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
|
||||||
return success(data=items)
|
return success(data=items)
|
||||||
|
|
||||||
# 正常分页查询
|
# 正常分页查询
|
||||||
items_orm, total = app_service.list_apps(
|
items_orm, total = app_service.list_apps(
|
||||||
db,
|
db,
|
||||||
|
|||||||
@@ -3,13 +3,11 @@
|
|||||||
|
|
||||||
基于 LangGraph 的工作流执行引擎。
|
基于 LangGraph 的工作流执行引擎。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# import uuid
|
|
||||||
import datetime
|
import datetime
|
||||||
import logging
|
import logging
|
||||||
|
import uuid
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage
|
|
||||||
from langgraph.graph.state import CompiledStateGraph
|
from langgraph.graph.state import CompiledStateGraph
|
||||||
|
|
||||||
from app.core.workflow.graph_builder import GraphBuilder
|
from app.core.workflow.graph_builder import GraphBuilder
|
||||||
@@ -55,6 +53,12 @@ class WorkflowExecutor:
|
|||||||
self.edges = workflow_config.get("edges", [])
|
self.edges = workflow_config.get("edges", [])
|
||||||
self.execution_config = workflow_config.get("execution_config", {})
|
self.execution_config = workflow_config.get("execution_config", {})
|
||||||
|
|
||||||
|
self.checkpoint_config = {
|
||||||
|
"configurable": {
|
||||||
|
"thread_id": uuid.uuid4(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
def _prepare_initial_state(self, input_data: dict[str, Any]) -> WorkflowState:
|
def _prepare_initial_state(self, input_data: dict[str, Any]) -> WorkflowState:
|
||||||
"""准备初始状态(注入系统变量和会话变量)
|
"""准备初始状态(注入系统变量和会话变量)
|
||||||
|
|
||||||
@@ -95,7 +99,7 @@ class WorkflowExecutor:
|
|||||||
case VariableType.ARRAY_NUMBER | VariableType.ARRAY_OBJECT | VariableType.ARRAY_BOOLEAN | VariableType.ARRAY_STRING:
|
case VariableType.ARRAY_NUMBER | VariableType.ARRAY_OBJECT | VariableType.ARRAY_BOOLEAN | VariableType.ARRAY_STRING:
|
||||||
conversation_vars[var_name] = []
|
conversation_vars[var_name] = []
|
||||||
input_variables = input_data.get("variables") or {} # Start 节点的自定义变量
|
input_variables = input_data.get("variables") or {} # Start 节点的自定义变量
|
||||||
|
conversation_vars = conversation_vars | input_data.get("conv", {})
|
||||||
# 构建分层的变量结构
|
# 构建分层的变量结构
|
||||||
variables = {
|
variables = {
|
||||||
"sys": {
|
"sys": {
|
||||||
@@ -110,7 +114,7 @@ class WorkflowExecutor:
|
|||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"messages": [HumanMessage(content=user_message)],
|
"messages": [('user', user_message)],
|
||||||
"variables": variables,
|
"variables": variables,
|
||||||
"node_outputs": {},
|
"node_outputs": {},
|
||||||
"runtime_vars": {}, # 运行时节点变量(简化版,供快速访问)
|
"runtime_vars": {}, # 运行时节点变量(简化版,供快速访问)
|
||||||
@@ -196,6 +200,28 @@ class WorkflowExecutor:
|
|||||||
logger.info(f"[前缀分析] 与 End 相邻且被引用的节点: {adjacent_and_referenced}")
|
logger.info(f"[前缀分析] 与 End 相邻且被引用的节点: {adjacent_and_referenced}")
|
||||||
return prefixes, adjacent_and_referenced
|
return prefixes, adjacent_and_referenced
|
||||||
|
|
||||||
|
def _build_final_output(self, result, elapsed_time):
|
||||||
|
node_outputs = result.get("node_outputs", {})
|
||||||
|
final_output = self._extract_final_output(node_outputs)
|
||||||
|
token_usage = self._aggregate_token_usage(node_outputs)
|
||||||
|
conversation_id = None
|
||||||
|
for node_id, node_output in node_outputs.items():
|
||||||
|
if node_output.get("node_type") == "start":
|
||||||
|
conversation_id = node_output.get("output", {}).get("conversation_id")
|
||||||
|
break
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "completed",
|
||||||
|
"output": final_output,
|
||||||
|
"node_outputs": node_outputs,
|
||||||
|
"messages": result.get("messages", []),
|
||||||
|
"conversation_id": conversation_id,
|
||||||
|
"elapsed_time": elapsed_time,
|
||||||
|
"token_usage": token_usage,
|
||||||
|
"error": result.get("error"),
|
||||||
|
"variables": result.get("variables", {}),
|
||||||
|
}
|
||||||
|
|
||||||
def build_graph(self, stream=False) -> CompiledStateGraph:
|
def build_graph(self, stream=False) -> CompiledStateGraph:
|
||||||
"""构建 LangGraph
|
"""构建 LangGraph
|
||||||
|
|
||||||
@@ -236,40 +262,16 @@ class WorkflowExecutor:
|
|||||||
|
|
||||||
# 3. 执行工作流
|
# 3. 执行工作流
|
||||||
try:
|
try:
|
||||||
result = await graph.ainvoke(initial_state)
|
|
||||||
|
result = await graph.ainvoke(initial_state, config=self.checkpoint_config)
|
||||||
|
|
||||||
# 计算耗时
|
# 计算耗时
|
||||||
end_time = datetime.datetime.now()
|
end_time = datetime.datetime.now()
|
||||||
elapsed_time = (end_time - start_time).total_seconds()
|
elapsed_time = (end_time - start_time).total_seconds()
|
||||||
|
|
||||||
# 提取节点输出(现在包含 start 和 end 节点)
|
|
||||||
node_outputs = result.get("node_outputs", {})
|
|
||||||
|
|
||||||
# 提取最终输出(从最后一个非 start/end 节点)
|
|
||||||
final_output = self._extract_final_output(node_outputs)
|
|
||||||
|
|
||||||
# 聚合 token 使用情况
|
|
||||||
token_usage = self._aggregate_token_usage(node_outputs)
|
|
||||||
|
|
||||||
# 提取 conversation_id(从 start 节点输出)
|
|
||||||
conversation_id = None
|
|
||||||
for node_id, node_output in node_outputs.items():
|
|
||||||
if node_output.get("node_type") == "start":
|
|
||||||
conversation_id = node_output.get("output", {}).get("conversation_id")
|
|
||||||
break
|
|
||||||
|
|
||||||
logger.info(f"工作流执行完成: execution_id={self.execution_id}, elapsed_time={elapsed_time:.2f}s")
|
logger.info(f"工作流执行完成: execution_id={self.execution_id}, elapsed_time={elapsed_time:.2f}s")
|
||||||
|
|
||||||
return {
|
return self._build_final_output(result, elapsed_time)
|
||||||
"status": "completed",
|
|
||||||
"output": final_output,
|
|
||||||
"node_outputs": node_outputs,
|
|
||||||
"messages": result.get("messages", []),
|
|
||||||
"conversation_id": conversation_id,
|
|
||||||
"elapsed_time": elapsed_time,
|
|
||||||
"token_usage": token_usage,
|
|
||||||
"error": result.get("error")
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# 计算耗时(即使失败也记录)
|
# 计算耗时(即使失败也记录)
|
||||||
@@ -331,11 +333,11 @@ class WorkflowExecutor:
|
|||||||
# 3. Execute workflow
|
# 3. Execute workflow
|
||||||
try:
|
try:
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
final_state = None
|
|
||||||
|
|
||||||
async for event in graph.astream(
|
async for event in graph.astream(
|
||||||
initial_state,
|
initial_state,
|
||||||
stream_mode=["updates", "debug", "custom"], # Use updates + debug + custom mode
|
stream_mode=["updates", "debug", "custom"], # Use updates + debug + custom mode
|
||||||
|
config=self.checkpoint_config
|
||||||
):
|
):
|
||||||
# event should be a tuple: (mode, data)
|
# event should be a tuple: (mode, data)
|
||||||
# But let's handle both cases
|
# But let's handle both cases
|
||||||
@@ -411,12 +413,11 @@ class WorkflowExecutor:
|
|||||||
elif mode == "updates":
|
elif mode == "updates":
|
||||||
# Handle state updates - store final state
|
# Handle state updates - store final state
|
||||||
logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())}")
|
logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())}")
|
||||||
final_state = data
|
|
||||||
|
|
||||||
# 计算耗时
|
# 计算耗时
|
||||||
end_time = datetime.datetime.now()
|
end_time = datetime.datetime.now()
|
||||||
elapsed_time = (end_time - start_time).total_seconds()
|
elapsed_time = (end_time - start_time).total_seconds()
|
||||||
|
result = graph.get_state(self.checkpoint_config).values
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Workflow execution completed (streaming), "
|
f"Workflow execution completed (streaming), "
|
||||||
f"total chunks: {chunk_count}, elapsed: {elapsed_time:.2f}s"
|
f"total chunks: {chunk_count}, elapsed: {elapsed_time:.2f}s"
|
||||||
@@ -425,12 +426,7 @@ class WorkflowExecutor:
|
|||||||
# 发送 workflow_end 事件
|
# 发送 workflow_end 事件
|
||||||
yield {
|
yield {
|
||||||
"event": "workflow_end",
|
"event": "workflow_end",
|
||||||
"data": {
|
"data": self._build_final_output(result, elapsed_time)
|
||||||
"execution_id": self.execution_id,
|
|
||||||
"status": "completed",
|
|
||||||
"elapsed_time": elapsed_time,
|
|
||||||
"timestamp": end_time.isoformat()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from typing import Any
|
|||||||
|
|
||||||
from langgraph.graph.state import CompiledStateGraph, StateGraph
|
from langgraph.graph.state import CompiledStateGraph, StateGraph
|
||||||
from langgraph.graph import START, END
|
from langgraph.graph import START, END
|
||||||
|
from langgraph.checkpoint.memory import InMemorySaver
|
||||||
|
|
||||||
from app.core.workflow.expression_evaluator import evaluate_condition
|
from app.core.workflow.expression_evaluator import evaluate_condition
|
||||||
from app.core.workflow.nodes import WorkflowState, NodeFactory
|
from app.core.workflow.nodes import WorkflowState, NodeFactory
|
||||||
@@ -249,4 +250,5 @@ class GraphBuilder:
|
|||||||
self.graph = StateGraph(WorkflowState)
|
self.graph = StateGraph(WorkflowState)
|
||||||
self.add_nodes()
|
self.add_nodes()
|
||||||
self.add_edges() # 添加边必须在添加节点之后
|
self.add_edges() # 添加边必须在添加节点之后
|
||||||
return self.graph.compile()
|
checkpointer = InMemorySaver()
|
||||||
|
return self.graph.compile(checkpointer=checkpointer)
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ class WorkflowState(TypedDict):
|
|||||||
The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc.
|
The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc.
|
||||||
"""
|
"""
|
||||||
# List of messages (append mode)
|
# List of messages (append mode)
|
||||||
messages: Annotated[list[AnyMessage], add]
|
messages: Annotated[list[tuple[str, str]], add]
|
||||||
|
|
||||||
# Set of loop node IDs, used for assigning values in loop nodes
|
# Set of loop node IDs, used for assigning values in loop nodes
|
||||||
cycle_nodes: list
|
cycle_nodes: list
|
||||||
@@ -203,6 +203,7 @@ class BaseNode(ABC):
|
|||||||
# 返回包装后的输出和运行时变量
|
# 返回包装后的输出和运行时变量
|
||||||
return {
|
return {
|
||||||
**wrapped_output,
|
**wrapped_output,
|
||||||
|
"variables": state["variables"],
|
||||||
"runtime_vars": {
|
"runtime_vars": {
|
||||||
self.node_id: runtime_var
|
self.node_id: runtime_var
|
||||||
},
|
},
|
||||||
@@ -355,6 +356,7 @@ class BaseNode(ABC):
|
|||||||
# Build complete state update (including node_outputs, runtime_vars, and final streaming buffer)
|
# Build complete state update (including node_outputs, runtime_vars, and final streaming buffer)
|
||||||
state_update = {
|
state_update = {
|
||||||
**final_output,
|
**final_output,
|
||||||
|
"variables": state["variables"],
|
||||||
"runtime_vars": {
|
"runtime_vars": {
|
||||||
self.node_id: runtime_var
|
self.node_id: runtime_var
|
||||||
},
|
},
|
||||||
|
|||||||
Reference in New Issue
Block a user