feat(workflow): support single-node execution and MCP Streamable HTTP protocol

- Add `run_single_node` method in workflow service for isolated node execution
- Refactor MCP client to support Streamable HTTP protocol (2025-03-26) with session ID handling, SSE/JSON response parsing, and proper initialized notification
- Update iteration node to conditionally initialize stream writer based on stream flag
- Improve cycle graph node invocation with checkpoint config passing
This commit is contained in:
Timebomb2018
2026-05-07 17:18:21 +08:00
parent 595c3517e3
commit 8d3da2fd0e
7 changed files with 302 additions and 28 deletions

View File

@@ -2,6 +2,7 @@
工作流服务层
"""
import datetime
import time
import logging
import uuid
from typing import Any, Annotated, Optional
@@ -17,7 +18,6 @@ from app.core.workflow.executor import execute_workflow, execute_workflow_stream
from app.core.workflow.nodes.enums import NodeType
from app.core.workflow.validator import validate_workflow_config
from app.db import get_db
from sqlalchemy import select
from app.models import App
from app.models.workflow_model import WorkflowConfig, WorkflowExecution, WorkflowNodeExecution
from app.repositories import knowledge_repository
@@ -1070,6 +1070,189 @@ class WorkflowService:
}
}
async def _build_node_context(
self,
app_id: uuid.UUID,
node_id: str,
config: WorkflowConfig,
workspace_id: uuid.UUID,
input_data: dict[str, Any],
):
"""构建单节点执行所需的上下文node_config, node, state, variable_pool"""
from app.core.workflow.engine.runtime_schema import ExecutionContext
from app.core.workflow.engine.variable_pool import VariablePool, VariablePoolInitializer
from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.nodes.node_factory import NodeFactory
from app.core.workflow.variable.base_variable import VariableType
if not config:
config = self.get_workflow_config(app_id)
if not config:
raise BusinessException(code=BizCode.CONFIG_MISSING, message="工作流配置不存在")
node_config = next((n for n in config.nodes if n.get("id") == node_id), None)
if not node_config:
raise BusinessException(code=BizCode.NOT_FOUND, message=f"节点不存在: node_id={node_id}")
workflow_config_dict = {
"nodes": config.nodes,
"edges": config.edges,
"variables": config.variables or [],
"execution_config": config.execution_config or {},
"features": config.features or {},
}
storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id)
execution_id = f"node_{uuid.uuid4().hex[:16]}"
execution_context = ExecutionContext.create(
execution_id=execution_id,
workspace_id=str(workspace_id),
user_id=input_data.get("user_id", ""),
conversation_id=input_data.get("conversation_id", ""),
memory_storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
)
# sys.files 转换为 FileObject 格式
raw_files = input_data.get("files") or []
if raw_files:
from app.schemas.app_schema import FileInput
file_inputs = [
FileInput(**f) if isinstance(f, dict) else f
for f in raw_files
]
input_data["files"] = await self._handle_file_input(file_inputs)
variable_pool = VariablePool()
await VariablePoolInitializer(workflow_config_dict).initialize(variable_pool, input_data, execution_context)
# 注入节点输入变量,支持扁平格式 {"node_id.var": value}
for key, value in (input_data.get("inputs") or {}).items():
if "." in key:
ref_node_id, var_name = key.split(".", 1)
var_type = VariableType.type_map(value)
await variable_pool.new(ref_node_id, var_name, value, var_type, mut=False)
state = WorkflowState(
messages=input_data.get("conv_messages", []),
node_outputs={},
execution_id=execution_id,
workspace_id=str(workspace_id),
user_id=input_data.get("user_id", ""),
error=None,
error_node=None,
cycle_nodes=[],
looping=0,
activate={node_id: True},
memory_storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
)
node = NodeFactory.create_node(node_config, workflow_config_dict, [])
return node_config, node, state, variable_pool
async def run_single_node(
self,
app_id: uuid.UUID,
node_id: str,
config: WorkflowConfig,
workspace_id: uuid.UUID,
input_data: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""单节点执行(非流式)"""
input_data = input_data or {}
node_config, node, state, variable_pool = await self._build_node_context(
app_id, node_id, config, workspace_id, input_data
)
start_time = time.time()
try:
result = await node.execute(state, variable_pool)
elapsed = (time.time() - start_time) * 1000
return {
"status": "completed",
"node_id": node_id,
"node_type": node_config.get("type"),
"inputs": node._extract_input(state, variable_pool),
"outputs": node._extract_output(result),
"token_usage": node._extract_token_usage(result),
"elapsed_time": elapsed,
"error": None,
}
except Exception as e:
elapsed = (time.time() - start_time) * 1000
logger.error(f"单节点执行失败: node_id={node_id}, error={e}", exc_info=True)
return {
"status": "failed",
"node_id": node_id,
"node_type": node_config.get("type"),
"inputs": node._extract_input(state, variable_pool),
"outputs": None,
"token_usage": None,
"elapsed_time": elapsed,
"error": str(e),
}
async def run_single_node_stream(
self,
app_id: uuid.UUID,
node_id: str,
config: WorkflowConfig,
workspace_id: uuid.UUID,
input_data: dict[str, Any] | None = None,
):
"""单节点执行(流式)
Yields:
node_start -> node_chunkLLM 等流式节点)-> node_end / node_error
"""
input_data = input_data or {}
node_config, node, state, variable_pool = await self._build_node_context(
app_id, node_id, config, workspace_id, input_data
)
node_type = node_config.get("type")
start_time = time.time()
yield {"event": "node_start", "data": {"node_id": node_id, "node_type": node_type}}
final_result = None
try:
async for item in node.execute_stream(state, variable_pool):
if item.get("__final__"):
final_result = item["result"]
else:
chunk = item.get("chunk", "")
if chunk:
yield {"event": "node_chunk", "data": {"node_id": node_id, "chunk": chunk}}
elapsed = (time.time() - start_time) * 1000
yield {
"event": "node_end",
"data": {
"node_id": node_id,
"node_type": node_type,
"status": "succeeded",
"inputs": node._extract_input(state, variable_pool),
"outputs": node._extract_output(final_result),
"token_usage": node._extract_token_usage(final_result),
"elapsed_time": elapsed,
"error": None,
}
}
except Exception as e:
elapsed = (time.time() - start_time) * 1000
logger.error(f"单节点流式执行失败: node_id={node_id}, error={e}", exc_info=True)
yield {
"event": "node_error",
"data": {
"node_id": node_id,
"node_type": node_type,
"inputs": node._extract_input(state, variable_pool),
"elapsed_time": elapsed,
"error": str(e),
}
}
@staticmethod
def get_start_node_variables(config: dict) -> list:
nodes = config.get("nodes", [])