diff --git a/api/app/controllers/app_controller.py b/api/app/controllers/app_controller.py index 2300f148..f55ea5b5 100644 --- a/api/app/controllers/app_controller.py +++ b/api/app/controllers/app_controller.py @@ -60,14 +60,14 @@ def list_apps( """ workspace_id = current_user.current_workspace_id service = app_service.AppService(db) - + # 当 ids 存在且不为 None 时,根据 ids 获取应用 if ids is not None: app_ids = [id.strip() for id in ids.split(',') if id.strip()] items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id) items = [service._convert_to_schema(app, workspace_id) for app in items_orm] return success(data=items) - + # 正常分页查询 items_orm, total = app_service.list_apps( db, diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index e3d634d8..67689935 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -3,13 +3,11 @@ 基于 LangGraph 的工作流执行引擎。 """ - -# import uuid import datetime import logging +import uuid from typing import Any -from langchain_core.messages import HumanMessage from langgraph.graph.state import CompiledStateGraph from app.core.workflow.graph_builder import GraphBuilder @@ -55,6 +53,12 @@ class WorkflowExecutor: self.edges = workflow_config.get("edges", []) self.execution_config = workflow_config.get("execution_config", {}) + self.checkpoint_config = { + "configurable": { + "thread_id": uuid.uuid4(), + } + } + def _prepare_initial_state(self, input_data: dict[str, Any]) -> WorkflowState: """准备初始状态(注入系统变量和会话变量) @@ -95,7 +99,7 @@ class WorkflowExecutor: case VariableType.ARRAY_NUMBER | VariableType.ARRAY_OBJECT | VariableType.ARRAY_BOOLEAN | VariableType.ARRAY_STRING: conversation_vars[var_name] = [] input_variables = input_data.get("variables") or {} # Start 节点的自定义变量 - + conversation_vars = conversation_vars | input_data.get("conv", {}) # 构建分层的变量结构 variables = { "sys": { @@ -110,7 +114,7 @@ class WorkflowExecutor: } return { - "messages": [HumanMessage(content=user_message)], + "messages": [('user', user_message)], "variables": variables, "node_outputs": {}, "runtime_vars": {}, # 运行时节点变量(简化版,供快速访问) @@ -196,6 +200,28 @@ class WorkflowExecutor: logger.info(f"[前缀分析] 与 End 相邻且被引用的节点: {adjacent_and_referenced}") return prefixes, adjacent_and_referenced + def _build_final_output(self, result, elapsed_time): + node_outputs = result.get("node_outputs", {}) + final_output = self._extract_final_output(node_outputs) + token_usage = self._aggregate_token_usage(node_outputs) + conversation_id = None + for node_id, node_output in node_outputs.items(): + if node_output.get("node_type") == "start": + conversation_id = node_output.get("output", {}).get("conversation_id") + break + + return { + "status": "completed", + "output": final_output, + "node_outputs": node_outputs, + "messages": result.get("messages", []), + "conversation_id": conversation_id, + "elapsed_time": elapsed_time, + "token_usage": token_usage, + "error": result.get("error"), + "variables": result.get("variables", {}), + } + def build_graph(self, stream=False) -> CompiledStateGraph: """构建 LangGraph @@ -236,40 +262,16 @@ class WorkflowExecutor: # 3. 执行工作流 try: - result = await graph.ainvoke(initial_state) + + result = await graph.ainvoke(initial_state, config=self.checkpoint_config) # 计算耗时 end_time = datetime.datetime.now() elapsed_time = (end_time - start_time).total_seconds() - # 提取节点输出(现在包含 start 和 end 节点) - node_outputs = result.get("node_outputs", {}) - - # 提取最终输出(从最后一个非 start/end 节点) - final_output = self._extract_final_output(node_outputs) - - # 聚合 token 使用情况 - token_usage = self._aggregate_token_usage(node_outputs) - - # 提取 conversation_id(从 start 节点输出) - conversation_id = None - for node_id, node_output in node_outputs.items(): - if node_output.get("node_type") == "start": - conversation_id = node_output.get("output", {}).get("conversation_id") - break - logger.info(f"工作流执行完成: execution_id={self.execution_id}, elapsed_time={elapsed_time:.2f}s") - return { - "status": "completed", - "output": final_output, - "node_outputs": node_outputs, - "messages": result.get("messages", []), - "conversation_id": conversation_id, - "elapsed_time": elapsed_time, - "token_usage": token_usage, - "error": result.get("error") - } + return self._build_final_output(result, elapsed_time) except Exception as e: # 计算耗时(即使失败也记录) @@ -331,11 +333,11 @@ class WorkflowExecutor: # 3. Execute workflow try: chunk_count = 0 - final_state = None async for event in graph.astream( initial_state, stream_mode=["updates", "debug", "custom"], # Use updates + debug + custom mode + config=self.checkpoint_config ): # event should be a tuple: (mode, data) # But let's handle both cases @@ -411,12 +413,11 @@ class WorkflowExecutor: elif mode == "updates": # Handle state updates - store final state logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())}") - final_state = data # 计算耗时 end_time = datetime.datetime.now() elapsed_time = (end_time - start_time).total_seconds() - + result = graph.get_state(self.checkpoint_config).values logger.info( f"Workflow execution completed (streaming), " f"total chunks: {chunk_count}, elapsed: {elapsed_time:.2f}s" @@ -425,12 +426,7 @@ class WorkflowExecutor: # 发送 workflow_end 事件 yield { "event": "workflow_end", - "data": { - "execution_id": self.execution_id, - "status": "completed", - "elapsed_time": elapsed_time, - "timestamp": end_time.isoformat() - } + "data": self._build_final_output(result, elapsed_time) } except Exception as e: diff --git a/api/app/core/workflow/graph_builder.py b/api/app/core/workflow/graph_builder.py index 69ed3b6a..b75b867e 100644 --- a/api/app/core/workflow/graph_builder.py +++ b/api/app/core/workflow/graph_builder.py @@ -4,6 +4,7 @@ from typing import Any from langgraph.graph.state import CompiledStateGraph, StateGraph from langgraph.graph import START, END +from langgraph.checkpoint.memory import InMemorySaver from app.core.workflow.expression_evaluator import evaluate_condition from app.core.workflow.nodes import WorkflowState, NodeFactory @@ -249,4 +250,5 @@ class GraphBuilder: self.graph = StateGraph(WorkflowState) self.add_nodes() self.add_edges() # 添加边必须在添加节点之后 - return self.graph.compile() + checkpointer = InMemorySaver() + return self.graph.compile(checkpointer=checkpointer) diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index 727f7391..e3bf36c9 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -25,7 +25,7 @@ class WorkflowState(TypedDict): The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc. """ # List of messages (append mode) - messages: Annotated[list[AnyMessage], add] + messages: Annotated[list[tuple[str, str]], add] # Set of loop node IDs, used for assigning values in loop nodes cycle_nodes: list @@ -203,6 +203,7 @@ class BaseNode(ABC): # 返回包装后的输出和运行时变量 return { **wrapped_output, + "variables": state["variables"], "runtime_vars": { self.node_id: runtime_var }, @@ -355,6 +356,7 @@ class BaseNode(ABC): # Build complete state update (including node_outputs, runtime_vars, and final streaming buffer) state_update = { **final_output, + "variables": state["variables"], "runtime_vars": { self.node_id: runtime_var },