feat(workflow): support single-node execution and MCP Streamable HTTP protocol

- Add `run_single_node` method in workflow service for isolated node execution
- Refactor MCP client to support Streamable HTTP protocol (2025-03-26) with session ID handling, SSE/JSON response parsing, and proper initialized notification
- Update iteration node to conditionally initialize stream writer based on stream flag
- Improve cycle graph node invocation with checkpoint config passing
This commit is contained in:
Timebomb2018
2026-05-07 17:18:21 +08:00
parent 595c3517e3
commit 8d3da2fd0e
7 changed files with 302 additions and 28 deletions

View File

@@ -1,5 +1,6 @@
import uuid
import io
import json
from typing import Optional, Annotated
import yaml
@@ -1068,6 +1069,62 @@ async def draft_run_compare(
return success(data=app_schema.DraftRunCompareResponse(**result))
@router.post("/{app_id}/workflow/nodes/{node_id}/run", summary="单节点试运行")
@cur_workspace_access_guard()
async def run_single_workflow_node(
app_id: uuid.UUID,
node_id: str,
payload: app_schema.NodeRunRequest,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
workflow_service: Annotated[WorkflowService, Depends(get_workflow_service)] = None,
):
"""单独执行工作流中的某个节点
inputs 支持以下 key 格式:
- 节点变量: "node_id.var_name"
- 系统变量: "sys.message""sys.files"
"""
workspace_id = current_user.current_workspace_id
config = workflow_service.check_config(app_id)
raw_inputs = payload.inputs or {}
input_data = {
"message": raw_inputs.pop("sys.message", ""),
"files": raw_inputs.pop("sys.files", []),
"user_id": raw_inputs.pop("sys.user_id", str(current_user.id)),
"inputs": raw_inputs,
"conversation_id": "",
"conv_messages": [],
}
if payload.stream:
async def event_generator():
async for event in workflow_service.run_single_node_stream(
app_id=app_id,
node_id=node_id,
config=config,
workspace_id=workspace_id,
input_data=input_data,
):
yield f"event: {event['event']}\ndata: {json.dumps(event['data'], ensure_ascii=False)}\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"}
)
result = await workflow_service.run_single_node(
app_id=app_id,
node_id=node_id,
config=config,
workspace_id=workspace_id,
input_data=input_data,
)
return success(data=result)
@router.get("/{app_id}/workflow")
@cur_workspace_access_guard()
async def get_workflow_config(

View File

@@ -87,11 +87,11 @@ class SimpleMCPClient:
headers = self._build_headers()
timeout = aiohttp.ClientTimeout(total=self.timeout)
self._session = aiohttp.ClientSession(headers=headers, timeout=timeout)
if self.is_sse:
await self._initialize_sse_session()
elif "modelscope.net" in self.server_url:
await self._initialize_modelscope_session()
else:
await self._initialize_streamable_session()
async def _initialize_sse_session(self):
"""初始化 SSE MCP 会话 - 参考 Dify 实现"""
@@ -208,41 +208,41 @@ class SimpleMCPClient:
if not (200 <= response.status < 300):
logger.warning(f"通知发送失败: {response.status}")
async def _initialize_modelscope_session(self):
"""初始化 ModelScope MCP 会话"""
async def _initialize_streamable_session(self):
"""初始化 Streamable HTTP MCP 会话MCP 2025-03-26 规范)"""
init_request = {
"jsonrpc": "2.0",
"id": self._get_request_id(),
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"protocolVersion": "2025-03-26",
"capabilities": {"tools": {}},
"clientInfo": {"name": "MemoryBear", "version": "1.0.0"}
}
}
try:
async with self._session.post(self.server_url, json=init_request) as response:
if not (200 <= response.status < 300):
error_text = await response.text()
raise MCPConnectionError(f"初始化失败 {response.status}: {error_text}")
init_response = await response.json()
if "error" in init_response:
raise MCPConnectionError(f"初始化失败: {init_response['error']}")
# 提取 session idStreamable HTTP 规范要求后续请求携带)
session_id = response.headers.get("Mcp-Session-Id") or response.headers.get("mcp-session-id")
if session_id:
self._session.headers.update({"Mcp-Session-Id": session_id})
initialized_notification = {
"jsonrpc": "2.0",
"method": "notifications/initialized"
}
async with self._session.post(self.server_url, json=initialized_notification):
pass
init_response = await self._parse_streamable_response(response)
if "error" in init_response:
raise MCPConnectionError(f"初始化失败: {init_response['error']}")
self._server_capabilities = init_response.get("result", {}).get("capabilities", {})
# 发送 initialized 通知
notification = {"jsonrpc": "2.0", "method": "notifications/initialized"}
async with self._session.post(self.server_url, json=notification):
pass
except aiohttp.ClientError as e:
raise MCPConnectionError(f"初始化连接失败: {e}")
@@ -310,6 +310,21 @@ class SimpleMCPClient:
"method": "notifications/initialized"
}))
async def _parse_streamable_response(self, response) -> Dict[str, Any]:
"""解析 Streamable HTTP 响应(支持 JSON 和 SSE 两种格式)"""
content_type = response.headers.get("Content-Type", "")
if "text/event-stream" in content_type:
# 服务端返回 SSE 流,读取第一条 data 消息
async for line in response.content:
line = line.decode("utf-8").strip()
if line.startswith("data:"):
data = line[5:].strip()
if data and data != "[DONE]":
return json.loads(data)
raise MCPConnectionError("SSE 流中未收到有效响应")
else:
return await response.json()
async def list_tools(self) -> List[Dict[str, Any]]:
"""获取工具列表"""
request = {
@@ -326,7 +341,7 @@ class SimpleMCPClient:
response_data = await self._send_sse_request(request)
else:
async with self._session.post(self.server_url, json=request) as response:
response_data = await response.json()
response_data = await self._parse_streamable_response(response)
if "error" in response_data:
raise MCPConnectionError(f"获取工具列表失败: {response_data['error']}")
@@ -351,7 +366,7 @@ class SimpleMCPClient:
response_data = await self._send_sse_request(request)
else:
async with self._session.post(self.server_url, json=request) as response:
response_data = await response.json()
response_data = await self._parse_streamable_response(response)
if "error" in response_data:
error = response_data["error"]

View File

@@ -70,7 +70,7 @@ class IterationRuntime:
self.variable_pool = variable_pool
self.cycle_nodes = cycle_nodes
self.cycle_edges = cycle_edges
self.event_write = get_stream_writer()
self.event_write = get_stream_writer() if self.stream else (lambda x: None)
self.output_value = None
self.result: list = []
@@ -196,7 +196,7 @@ class IterationRuntime:
})
result = graph.get_state(config=checkpoint).values
else:
result = await graph.ainvoke(init_state)
result = await graph.ainvoke(init_state, config=checkpoint)
output = child_pool.get_value(self.output_value)
stopped = result["looping"] == 2

View File

@@ -57,7 +57,7 @@ class LoopRuntime:
self.looping = True
self.variable_pool = variable_pool
self.child_variable_pool = child_variable_pool
self.event_write = get_stream_writer()
self.event_write = get_stream_writer() if self.stream else (lambda x: None)
self.checkpoint = RunnableConfig(
configurable={
@@ -223,7 +223,7 @@ class LoopRuntime:
})
return self.graph.get_state(config=self.checkpoint).values
else:
return await self.graph.ainvoke(loopstate)
return await self.graph.ainvoke(loopstate, config=self.checkpoint)
async def run(self):
"""

View File

@@ -385,6 +385,7 @@ class HttpRequestNode(BaseNode):
logger.info(f"Node {self.node_id}: HTTP request succeeded")
response = HttpResponse(resp)
# Build raw request summary for process_data
await resp.request.aread()
raw_request = (
f"{self.typed_config.method.upper()} {resp.request.url} HTTP/1.1\r\n"
+ "".join(f"{k}: {v}\r\n" for k, v in resp.request.headers.items())

View File

@@ -703,6 +703,24 @@ class ModelCompareItem(BaseModel):
)
class NodeRunRequest(BaseModel):
"""单节点试运行请求"""
# 扁平格式,支持:
# 节点变量: {"node_id.var_name": value}
# 系统变量: {"sys.message": "hello", "sys.files": [...]}
inputs: Dict[str, Any] = Field(
default_factory=dict,
description="节点输入变量,格式: {'node_id.var_name': value} 或 {'sys.message': 'hello'}",
examples=[{
"sys.message": "帮我写一首诗",
"sys.user_id": "user-123",
"sys.files": [],
"llm_node_abc.output": "上游输出内容",
}]
)
stream: bool = Field(default=False, description="是否流式返回")
class DraftRunCompareRequest(BaseModel):
"""多模型对比试运行请求"""
message: str = Field(..., description="用户消息")

View File

@@ -2,6 +2,7 @@
工作流服务层
"""
import datetime
import time
import logging
import uuid
from typing import Any, Annotated, Optional
@@ -17,7 +18,6 @@ from app.core.workflow.executor import execute_workflow, execute_workflow_stream
from app.core.workflow.nodes.enums import NodeType
from app.core.workflow.validator import validate_workflow_config
from app.db import get_db
from sqlalchemy import select
from app.models import App
from app.models.workflow_model import WorkflowConfig, WorkflowExecution, WorkflowNodeExecution
from app.repositories import knowledge_repository
@@ -1070,6 +1070,189 @@ class WorkflowService:
}
}
async def _build_node_context(
self,
app_id: uuid.UUID,
node_id: str,
config: WorkflowConfig,
workspace_id: uuid.UUID,
input_data: dict[str, Any],
):
"""构建单节点执行所需的上下文node_config, node, state, variable_pool"""
from app.core.workflow.engine.runtime_schema import ExecutionContext
from app.core.workflow.engine.variable_pool import VariablePool, VariablePoolInitializer
from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.nodes.node_factory import NodeFactory
from app.core.workflow.variable.base_variable import VariableType
if not config:
config = self.get_workflow_config(app_id)
if not config:
raise BusinessException(code=BizCode.CONFIG_MISSING, message="工作流配置不存在")
node_config = next((n for n in config.nodes if n.get("id") == node_id), None)
if not node_config:
raise BusinessException(code=BizCode.NOT_FOUND, message=f"节点不存在: node_id={node_id}")
workflow_config_dict = {
"nodes": config.nodes,
"edges": config.edges,
"variables": config.variables or [],
"execution_config": config.execution_config or {},
"features": config.features or {},
}
storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id)
execution_id = f"node_{uuid.uuid4().hex[:16]}"
execution_context = ExecutionContext.create(
execution_id=execution_id,
workspace_id=str(workspace_id),
user_id=input_data.get("user_id", ""),
conversation_id=input_data.get("conversation_id", ""),
memory_storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
)
# sys.files 转换为 FileObject 格式
raw_files = input_data.get("files") or []
if raw_files:
from app.schemas.app_schema import FileInput
file_inputs = [
FileInput(**f) if isinstance(f, dict) else f
for f in raw_files
]
input_data["files"] = await self._handle_file_input(file_inputs)
variable_pool = VariablePool()
await VariablePoolInitializer(workflow_config_dict).initialize(variable_pool, input_data, execution_context)
# 注入节点输入变量,支持扁平格式 {"node_id.var": value}
for key, value in (input_data.get("inputs") or {}).items():
if "." in key:
ref_node_id, var_name = key.split(".", 1)
var_type = VariableType.type_map(value)
await variable_pool.new(ref_node_id, var_name, value, var_type, mut=False)
state = WorkflowState(
messages=input_data.get("conv_messages", []),
node_outputs={},
execution_id=execution_id,
workspace_id=str(workspace_id),
user_id=input_data.get("user_id", ""),
error=None,
error_node=None,
cycle_nodes=[],
looping=0,
activate={node_id: True},
memory_storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
)
node = NodeFactory.create_node(node_config, workflow_config_dict, [])
return node_config, node, state, variable_pool
async def run_single_node(
self,
app_id: uuid.UUID,
node_id: str,
config: WorkflowConfig,
workspace_id: uuid.UUID,
input_data: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""单节点执行(非流式)"""
input_data = input_data or {}
node_config, node, state, variable_pool = await self._build_node_context(
app_id, node_id, config, workspace_id, input_data
)
start_time = time.time()
try:
result = await node.execute(state, variable_pool)
elapsed = (time.time() - start_time) * 1000
return {
"status": "completed",
"node_id": node_id,
"node_type": node_config.get("type"),
"inputs": node._extract_input(state, variable_pool),
"outputs": node._extract_output(result),
"token_usage": node._extract_token_usage(result),
"elapsed_time": elapsed,
"error": None,
}
except Exception as e:
elapsed = (time.time() - start_time) * 1000
logger.error(f"单节点执行失败: node_id={node_id}, error={e}", exc_info=True)
return {
"status": "failed",
"node_id": node_id,
"node_type": node_config.get("type"),
"inputs": node._extract_input(state, variable_pool),
"outputs": None,
"token_usage": None,
"elapsed_time": elapsed,
"error": str(e),
}
async def run_single_node_stream(
self,
app_id: uuid.UUID,
node_id: str,
config: WorkflowConfig,
workspace_id: uuid.UUID,
input_data: dict[str, Any] | None = None,
):
"""单节点执行(流式)
Yields:
node_start -> node_chunkLLM 等流式节点)-> node_end / node_error
"""
input_data = input_data or {}
node_config, node, state, variable_pool = await self._build_node_context(
app_id, node_id, config, workspace_id, input_data
)
node_type = node_config.get("type")
start_time = time.time()
yield {"event": "node_start", "data": {"node_id": node_id, "node_type": node_type}}
final_result = None
try:
async for item in node.execute_stream(state, variable_pool):
if item.get("__final__"):
final_result = item["result"]
else:
chunk = item.get("chunk", "")
if chunk:
yield {"event": "node_chunk", "data": {"node_id": node_id, "chunk": chunk}}
elapsed = (time.time() - start_time) * 1000
yield {
"event": "node_end",
"data": {
"node_id": node_id,
"node_type": node_type,
"status": "succeeded",
"inputs": node._extract_input(state, variable_pool),
"outputs": node._extract_output(final_result),
"token_usage": node._extract_token_usage(final_result),
"elapsed_time": elapsed,
"error": None,
}
}
except Exception as e:
elapsed = (time.time() - start_time) * 1000
logger.error(f"单节点流式执行失败: node_id={node_id}, error={e}", exc_info=True)
yield {
"event": "node_error",
"data": {
"node_id": node_id,
"node_type": node_type,
"inputs": node._extract_input(state, variable_pool),
"elapsed_time": elapsed,
"error": str(e),
}
}
@staticmethod
def get_start_node_variables(config: dict) -> list:
nodes = config.get("nodes", [])