feat(workflow): support single-node execution and MCP Streamable HTTP protocol
- Add `run_single_node` method in workflow service for isolated node execution - Refactor MCP client to support Streamable HTTP protocol (2025-03-26) with session ID handling, SSE/JSON response parsing, and proper initialized notification - Update iteration node to conditionally initialize stream writer based on stream flag - Improve cycle graph node invocation with checkpoint config passing
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import uuid
|
||||
import io
|
||||
import json
|
||||
from typing import Optional, Annotated
|
||||
|
||||
import yaml
|
||||
@@ -1068,6 +1069,62 @@ async def draft_run_compare(
|
||||
return success(data=app_schema.DraftRunCompareResponse(**result))
|
||||
|
||||
|
||||
@router.post("/{app_id}/workflow/nodes/{node_id}/run", summary="单节点试运行")
|
||||
@cur_workspace_access_guard()
|
||||
async def run_single_workflow_node(
|
||||
app_id: uuid.UUID,
|
||||
node_id: str,
|
||||
payload: app_schema.NodeRunRequest,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
workflow_service: Annotated[WorkflowService, Depends(get_workflow_service)] = None,
|
||||
):
|
||||
"""单独执行工作流中的某个节点
|
||||
|
||||
inputs 支持以下 key 格式:
|
||||
- 节点变量: "node_id.var_name"
|
||||
- 系统变量: "sys.message"、"sys.files"
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
config = workflow_service.check_config(app_id)
|
||||
|
||||
raw_inputs = payload.inputs or {}
|
||||
input_data = {
|
||||
"message": raw_inputs.pop("sys.message", ""),
|
||||
"files": raw_inputs.pop("sys.files", []),
|
||||
"user_id": raw_inputs.pop("sys.user_id", str(current_user.id)),
|
||||
"inputs": raw_inputs,
|
||||
"conversation_id": "",
|
||||
"conv_messages": [],
|
||||
}
|
||||
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
async for event in workflow_service.run_single_node_stream(
|
||||
app_id=app_id,
|
||||
node_id=node_id,
|
||||
config=config,
|
||||
workspace_id=workspace_id,
|
||||
input_data=input_data,
|
||||
):
|
||||
yield f"event: {event['event']}\ndata: {json.dumps(event['data'], ensure_ascii=False)}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"}
|
||||
)
|
||||
|
||||
result = await workflow_service.run_single_node(
|
||||
app_id=app_id,
|
||||
node_id=node_id,
|
||||
config=config,
|
||||
workspace_id=workspace_id,
|
||||
input_data=input_data,
|
||||
)
|
||||
return success(data=result)
|
||||
|
||||
|
||||
@router.get("/{app_id}/workflow")
|
||||
@cur_workspace_access_guard()
|
||||
async def get_workflow_config(
|
||||
|
||||
@@ -87,11 +87,11 @@ class SimpleMCPClient:
|
||||
headers = self._build_headers()
|
||||
timeout = aiohttp.ClientTimeout(total=self.timeout)
|
||||
self._session = aiohttp.ClientSession(headers=headers, timeout=timeout)
|
||||
|
||||
|
||||
if self.is_sse:
|
||||
await self._initialize_sse_session()
|
||||
elif "modelscope.net" in self.server_url:
|
||||
await self._initialize_modelscope_session()
|
||||
else:
|
||||
await self._initialize_streamable_session()
|
||||
|
||||
async def _initialize_sse_session(self):
|
||||
"""初始化 SSE MCP 会话 - 参考 Dify 实现"""
|
||||
@@ -208,41 +208,41 @@ class SimpleMCPClient:
|
||||
if not (200 <= response.status < 300):
|
||||
logger.warning(f"通知发送失败: {response.status}")
|
||||
|
||||
async def _initialize_modelscope_session(self):
|
||||
"""初始化 ModelScope MCP 会话"""
|
||||
async def _initialize_streamable_session(self):
|
||||
"""初始化 Streamable HTTP MCP 会话(MCP 2025-03-26 规范)"""
|
||||
init_request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_request_id(),
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"protocolVersion": "2025-03-26",
|
||||
"capabilities": {"tools": {}},
|
||||
"clientInfo": {"name": "MemoryBear", "version": "1.0.0"}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
async with self._session.post(self.server_url, json=init_request) as response:
|
||||
if not (200 <= response.status < 300):
|
||||
error_text = await response.text()
|
||||
raise MCPConnectionError(f"初始化失败 {response.status}: {error_text}")
|
||||
|
||||
init_response = await response.json()
|
||||
if "error" in init_response:
|
||||
raise MCPConnectionError(f"初始化失败: {init_response['error']}")
|
||||
|
||||
|
||||
# 提取 session id(Streamable HTTP 规范要求后续请求携带)
|
||||
session_id = response.headers.get("Mcp-Session-Id") or response.headers.get("mcp-session-id")
|
||||
if session_id:
|
||||
self._session.headers.update({"Mcp-Session-Id": session_id})
|
||||
|
||||
initialized_notification = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": "notifications/initialized"
|
||||
}
|
||||
|
||||
async with self._session.post(self.server_url, json=initialized_notification):
|
||||
pass
|
||||
|
||||
|
||||
init_response = await self._parse_streamable_response(response)
|
||||
if "error" in init_response:
|
||||
raise MCPConnectionError(f"初始化失败: {init_response['error']}")
|
||||
|
||||
self._server_capabilities = init_response.get("result", {}).get("capabilities", {})
|
||||
|
||||
# 发送 initialized 通知
|
||||
notification = {"jsonrpc": "2.0", "method": "notifications/initialized"}
|
||||
async with self._session.post(self.server_url, json=notification):
|
||||
pass
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
raise MCPConnectionError(f"初始化连接失败: {e}")
|
||||
|
||||
@@ -310,6 +310,21 @@ class SimpleMCPClient:
|
||||
"method": "notifications/initialized"
|
||||
}))
|
||||
|
||||
async def _parse_streamable_response(self, response) -> Dict[str, Any]:
|
||||
"""解析 Streamable HTTP 响应(支持 JSON 和 SSE 两种格式)"""
|
||||
content_type = response.headers.get("Content-Type", "")
|
||||
if "text/event-stream" in content_type:
|
||||
# 服务端返回 SSE 流,读取第一条 data 消息
|
||||
async for line in response.content:
|
||||
line = line.decode("utf-8").strip()
|
||||
if line.startswith("data:"):
|
||||
data = line[5:].strip()
|
||||
if data and data != "[DONE]":
|
||||
return json.loads(data)
|
||||
raise MCPConnectionError("SSE 流中未收到有效响应")
|
||||
else:
|
||||
return await response.json()
|
||||
|
||||
async def list_tools(self) -> List[Dict[str, Any]]:
|
||||
"""获取工具列表"""
|
||||
request = {
|
||||
@@ -326,7 +341,7 @@ class SimpleMCPClient:
|
||||
response_data = await self._send_sse_request(request)
|
||||
else:
|
||||
async with self._session.post(self.server_url, json=request) as response:
|
||||
response_data = await response.json()
|
||||
response_data = await self._parse_streamable_response(response)
|
||||
|
||||
if "error" in response_data:
|
||||
raise MCPConnectionError(f"获取工具列表失败: {response_data['error']}")
|
||||
@@ -351,7 +366,7 @@ class SimpleMCPClient:
|
||||
response_data = await self._send_sse_request(request)
|
||||
else:
|
||||
async with self._session.post(self.server_url, json=request) as response:
|
||||
response_data = await response.json()
|
||||
response_data = await self._parse_streamable_response(response)
|
||||
|
||||
if "error" in response_data:
|
||||
error = response_data["error"]
|
||||
|
||||
@@ -70,7 +70,7 @@ class IterationRuntime:
|
||||
self.variable_pool = variable_pool
|
||||
self.cycle_nodes = cycle_nodes
|
||||
self.cycle_edges = cycle_edges
|
||||
self.event_write = get_stream_writer()
|
||||
self.event_write = get_stream_writer() if self.stream else (lambda x: None)
|
||||
|
||||
self.output_value = None
|
||||
self.result: list = []
|
||||
@@ -196,7 +196,7 @@ class IterationRuntime:
|
||||
})
|
||||
result = graph.get_state(config=checkpoint).values
|
||||
else:
|
||||
result = await graph.ainvoke(init_state)
|
||||
result = await graph.ainvoke(init_state, config=checkpoint)
|
||||
|
||||
output = child_pool.get_value(self.output_value)
|
||||
stopped = result["looping"] == 2
|
||||
|
||||
@@ -57,7 +57,7 @@ class LoopRuntime:
|
||||
self.looping = True
|
||||
self.variable_pool = variable_pool
|
||||
self.child_variable_pool = child_variable_pool
|
||||
self.event_write = get_stream_writer()
|
||||
self.event_write = get_stream_writer() if self.stream else (lambda x: None)
|
||||
|
||||
self.checkpoint = RunnableConfig(
|
||||
configurable={
|
||||
@@ -223,7 +223,7 @@ class LoopRuntime:
|
||||
})
|
||||
return self.graph.get_state(config=self.checkpoint).values
|
||||
else:
|
||||
return await self.graph.ainvoke(loopstate)
|
||||
return await self.graph.ainvoke(loopstate, config=self.checkpoint)
|
||||
|
||||
async def run(self):
|
||||
"""
|
||||
|
||||
@@ -385,6 +385,7 @@ class HttpRequestNode(BaseNode):
|
||||
logger.info(f"Node {self.node_id}: HTTP request succeeded")
|
||||
response = HttpResponse(resp)
|
||||
# Build raw request summary for process_data
|
||||
await resp.request.aread()
|
||||
raw_request = (
|
||||
f"{self.typed_config.method.upper()} {resp.request.url} HTTP/1.1\r\n"
|
||||
+ "".join(f"{k}: {v}\r\n" for k, v in resp.request.headers.items())
|
||||
|
||||
@@ -703,6 +703,24 @@ class ModelCompareItem(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class NodeRunRequest(BaseModel):
|
||||
"""单节点试运行请求"""
|
||||
# 扁平格式,支持:
|
||||
# 节点变量: {"node_id.var_name": value}
|
||||
# 系统变量: {"sys.message": "hello", "sys.files": [...]}
|
||||
inputs: Dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="节点输入变量,格式: {'node_id.var_name': value} 或 {'sys.message': 'hello'}",
|
||||
examples=[{
|
||||
"sys.message": "帮我写一首诗",
|
||||
"sys.user_id": "user-123",
|
||||
"sys.files": [],
|
||||
"llm_node_abc.output": "上游输出内容",
|
||||
}]
|
||||
)
|
||||
stream: bool = Field(default=False, description="是否流式返回")
|
||||
|
||||
|
||||
class DraftRunCompareRequest(BaseModel):
|
||||
"""多模型对比试运行请求"""
|
||||
message: str = Field(..., description="用户消息")
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
工作流服务层
|
||||
"""
|
||||
import datetime
|
||||
import time
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Annotated, Optional
|
||||
@@ -17,7 +18,6 @@ from app.core.workflow.executor import execute_workflow, execute_workflow_stream
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
from app.core.workflow.validator import validate_workflow_config
|
||||
from app.db import get_db
|
||||
from sqlalchemy import select
|
||||
from app.models import App
|
||||
from app.models.workflow_model import WorkflowConfig, WorkflowExecution, WorkflowNodeExecution
|
||||
from app.repositories import knowledge_repository
|
||||
@@ -1070,6 +1070,189 @@ class WorkflowService:
|
||||
}
|
||||
}
|
||||
|
||||
async def _build_node_context(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
node_id: str,
|
||||
config: WorkflowConfig,
|
||||
workspace_id: uuid.UUID,
|
||||
input_data: dict[str, Any],
|
||||
):
|
||||
"""构建单节点执行所需的上下文(node_config, node, state, variable_pool)"""
|
||||
from app.core.workflow.engine.runtime_schema import ExecutionContext
|
||||
from app.core.workflow.engine.variable_pool import VariablePool, VariablePoolInitializer
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.nodes.node_factory import NodeFactory
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
|
||||
if not config:
|
||||
config = self.get_workflow_config(app_id)
|
||||
if not config:
|
||||
raise BusinessException(code=BizCode.CONFIG_MISSING, message="工作流配置不存在")
|
||||
|
||||
node_config = next((n for n in config.nodes if n.get("id") == node_id), None)
|
||||
if not node_config:
|
||||
raise BusinessException(code=BizCode.NOT_FOUND, message=f"节点不存在: node_id={node_id}")
|
||||
|
||||
workflow_config_dict = {
|
||||
"nodes": config.nodes,
|
||||
"edges": config.edges,
|
||||
"variables": config.variables or [],
|
||||
"execution_config": config.execution_config or {},
|
||||
"features": config.features or {},
|
||||
}
|
||||
|
||||
storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id)
|
||||
execution_id = f"node_{uuid.uuid4().hex[:16]}"
|
||||
|
||||
execution_context = ExecutionContext.create(
|
||||
execution_id=execution_id,
|
||||
workspace_id=str(workspace_id),
|
||||
user_id=input_data.get("user_id", ""),
|
||||
conversation_id=input_data.get("conversation_id", ""),
|
||||
memory_storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
)
|
||||
|
||||
# sys.files 转换为 FileObject 格式
|
||||
raw_files = input_data.get("files") or []
|
||||
if raw_files:
|
||||
from app.schemas.app_schema import FileInput
|
||||
file_inputs = [
|
||||
FileInput(**f) if isinstance(f, dict) else f
|
||||
for f in raw_files
|
||||
]
|
||||
input_data["files"] = await self._handle_file_input(file_inputs)
|
||||
|
||||
variable_pool = VariablePool()
|
||||
await VariablePoolInitializer(workflow_config_dict).initialize(variable_pool, input_data, execution_context)
|
||||
|
||||
# 注入节点输入变量,支持扁平格式 {"node_id.var": value}
|
||||
for key, value in (input_data.get("inputs") or {}).items():
|
||||
if "." in key:
|
||||
ref_node_id, var_name = key.split(".", 1)
|
||||
var_type = VariableType.type_map(value)
|
||||
await variable_pool.new(ref_node_id, var_name, value, var_type, mut=False)
|
||||
|
||||
state = WorkflowState(
|
||||
messages=input_data.get("conv_messages", []),
|
||||
node_outputs={},
|
||||
execution_id=execution_id,
|
||||
workspace_id=str(workspace_id),
|
||||
user_id=input_data.get("user_id", ""),
|
||||
error=None,
|
||||
error_node=None,
|
||||
cycle_nodes=[],
|
||||
looping=0,
|
||||
activate={node_id: True},
|
||||
memory_storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
)
|
||||
|
||||
node = NodeFactory.create_node(node_config, workflow_config_dict, [])
|
||||
return node_config, node, state, variable_pool
|
||||
|
||||
async def run_single_node(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
node_id: str,
|
||||
config: WorkflowConfig,
|
||||
workspace_id: uuid.UUID,
|
||||
input_data: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""单节点执行(非流式)"""
|
||||
input_data = input_data or {}
|
||||
node_config, node, state, variable_pool = await self._build_node_context(
|
||||
app_id, node_id, config, workspace_id, input_data
|
||||
)
|
||||
start_time = time.time()
|
||||
try:
|
||||
result = await node.execute(state, variable_pool)
|
||||
elapsed = (time.time() - start_time) * 1000
|
||||
return {
|
||||
"status": "completed",
|
||||
"node_id": node_id,
|
||||
"node_type": node_config.get("type"),
|
||||
"inputs": node._extract_input(state, variable_pool),
|
||||
"outputs": node._extract_output(result),
|
||||
"token_usage": node._extract_token_usage(result),
|
||||
"elapsed_time": elapsed,
|
||||
"error": None,
|
||||
}
|
||||
except Exception as e:
|
||||
elapsed = (time.time() - start_time) * 1000
|
||||
logger.error(f"单节点执行失败: node_id={node_id}, error={e}", exc_info=True)
|
||||
return {
|
||||
"status": "failed",
|
||||
"node_id": node_id,
|
||||
"node_type": node_config.get("type"),
|
||||
"inputs": node._extract_input(state, variable_pool),
|
||||
"outputs": None,
|
||||
"token_usage": None,
|
||||
"elapsed_time": elapsed,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
async def run_single_node_stream(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
node_id: str,
|
||||
config: WorkflowConfig,
|
||||
workspace_id: uuid.UUID,
|
||||
input_data: dict[str, Any] | None = None,
|
||||
):
|
||||
"""单节点执行(流式)
|
||||
|
||||
Yields:
|
||||
node_start -> node_chunk(LLM 等流式节点)-> node_end / node_error
|
||||
"""
|
||||
input_data = input_data or {}
|
||||
node_config, node, state, variable_pool = await self._build_node_context(
|
||||
app_id, node_id, config, workspace_id, input_data
|
||||
)
|
||||
node_type = node_config.get("type")
|
||||
start_time = time.time()
|
||||
|
||||
yield {"event": "node_start", "data": {"node_id": node_id, "node_type": node_type}}
|
||||
|
||||
final_result = None
|
||||
try:
|
||||
async for item in node.execute_stream(state, variable_pool):
|
||||
if item.get("__final__"):
|
||||
final_result = item["result"]
|
||||
else:
|
||||
chunk = item.get("chunk", "")
|
||||
if chunk:
|
||||
yield {"event": "node_chunk", "data": {"node_id": node_id, "chunk": chunk}}
|
||||
|
||||
elapsed = (time.time() - start_time) * 1000
|
||||
yield {
|
||||
"event": "node_end",
|
||||
"data": {
|
||||
"node_id": node_id,
|
||||
"node_type": node_type,
|
||||
"status": "succeeded",
|
||||
"inputs": node._extract_input(state, variable_pool),
|
||||
"outputs": node._extract_output(final_result),
|
||||
"token_usage": node._extract_token_usage(final_result),
|
||||
"elapsed_time": elapsed,
|
||||
"error": None,
|
||||
}
|
||||
}
|
||||
except Exception as e:
|
||||
elapsed = (time.time() - start_time) * 1000
|
||||
logger.error(f"单节点流式执行失败: node_id={node_id}, error={e}", exc_info=True)
|
||||
yield {
|
||||
"event": "node_error",
|
||||
"data": {
|
||||
"node_id": node_id,
|
||||
"node_type": node_type,
|
||||
"inputs": node._extract_input(state, variable_pool),
|
||||
"elapsed_time": elapsed,
|
||||
"error": str(e),
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_start_node_variables(config: dict) -> list:
|
||||
nodes = config.get("nodes", [])
|
||||
|
||||
Reference in New Issue
Block a user