diff --git a/api/app/controllers/app_controller.py b/api/app/controllers/app_controller.py index 41422bd4..ccb193af 100644 --- a/api/app/controllers/app_controller.py +++ b/api/app/controllers/app_controller.py @@ -1,5 +1,6 @@ import uuid import io +import json from typing import Optional, Annotated import yaml @@ -1068,6 +1069,62 @@ async def draft_run_compare( return success(data=app_schema.DraftRunCompareResponse(**result)) +@router.post("/{app_id}/workflow/nodes/{node_id}/run", summary="单节点试运行") +@cur_workspace_access_guard() +async def run_single_workflow_node( + app_id: uuid.UUID, + node_id: str, + payload: app_schema.NodeRunRequest, + db: Annotated[Session, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_user)], + workflow_service: Annotated[WorkflowService, Depends(get_workflow_service)] = None, +): + """单独执行工作流中的某个节点 + + inputs 支持以下 key 格式: + - 节点变量: "node_id.var_name" + - 系统变量: "sys.message"、"sys.files" + """ + workspace_id = current_user.current_workspace_id + config = workflow_service.check_config(app_id) + + raw_inputs = payload.inputs or {} + input_data = { + "message": raw_inputs.pop("sys.message", ""), + "files": raw_inputs.pop("sys.files", []), + "user_id": raw_inputs.pop("sys.user_id", str(current_user.id)), + "inputs": raw_inputs, + "conversation_id": "", + "conv_messages": [], + } + + if payload.stream: + async def event_generator(): + async for event in workflow_service.run_single_node_stream( + app_id=app_id, + node_id=node_id, + config=config, + workspace_id=workspace_id, + input_data=input_data, + ): + yield f"event: {event['event']}\ndata: {json.dumps(event['data'], ensure_ascii=False)}\n\n" + + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"} + ) + + result = await workflow_service.run_single_node( + app_id=app_id, + node_id=node_id, + config=config, + workspace_id=workspace_id, + input_data=input_data, + ) + return success(data=result) + + @router.get("/{app_id}/workflow") @cur_workspace_access_guard() async def get_workflow_config( diff --git a/api/app/core/tools/mcp/client.py b/api/app/core/tools/mcp/client.py index 3539d33a..ac9d47c4 100644 --- a/api/app/core/tools/mcp/client.py +++ b/api/app/core/tools/mcp/client.py @@ -87,11 +87,11 @@ class SimpleMCPClient: headers = self._build_headers() timeout = aiohttp.ClientTimeout(total=self.timeout) self._session = aiohttp.ClientSession(headers=headers, timeout=timeout) - + if self.is_sse: await self._initialize_sse_session() - elif "modelscope.net" in self.server_url: - await self._initialize_modelscope_session() + else: + await self._initialize_streamable_session() async def _initialize_sse_session(self): """初始化 SSE MCP 会话 - 参考 Dify 实现""" @@ -208,41 +208,41 @@ class SimpleMCPClient: if not (200 <= response.status < 300): logger.warning(f"通知发送失败: {response.status}") - async def _initialize_modelscope_session(self): - """初始化 ModelScope MCP 会话""" + async def _initialize_streamable_session(self): + """初始化 Streamable HTTP MCP 会话(MCP 2025-03-26 规范)""" init_request = { "jsonrpc": "2.0", "id": self._get_request_id(), "method": "initialize", "params": { - "protocolVersion": "2024-11-05", + "protocolVersion": "2025-03-26", "capabilities": {"tools": {}}, "clientInfo": {"name": "MemoryBear", "version": "1.0.0"} } } - + try: async with self._session.post(self.server_url, json=init_request) as response: if not (200 <= response.status < 300): error_text = await response.text() raise MCPConnectionError(f"初始化失败 {response.status}: {error_text}") - - init_response = await response.json() - if "error" in init_response: - raise MCPConnectionError(f"初始化失败: {init_response['error']}") - + + # 提取 session id(Streamable HTTP 规范要求后续请求携带) session_id = response.headers.get("Mcp-Session-Id") or response.headers.get("mcp-session-id") if session_id: self._session.headers.update({"Mcp-Session-Id": session_id}) - - initialized_notification = { - "jsonrpc": "2.0", - "method": "notifications/initialized" - } - - async with self._session.post(self.server_url, json=initialized_notification): - pass - + + init_response = await self._parse_streamable_response(response) + if "error" in init_response: + raise MCPConnectionError(f"初始化失败: {init_response['error']}") + + self._server_capabilities = init_response.get("result", {}).get("capabilities", {}) + + # 发送 initialized 通知 + notification = {"jsonrpc": "2.0", "method": "notifications/initialized"} + async with self._session.post(self.server_url, json=notification): + pass + except aiohttp.ClientError as e: raise MCPConnectionError(f"初始化连接失败: {e}") @@ -310,6 +310,21 @@ class SimpleMCPClient: "method": "notifications/initialized" })) + async def _parse_streamable_response(self, response) -> Dict[str, Any]: + """解析 Streamable HTTP 响应(支持 JSON 和 SSE 两种格式)""" + content_type = response.headers.get("Content-Type", "") + if "text/event-stream" in content_type: + # 服务端返回 SSE 流,读取第一条 data 消息 + async for line in response.content: + line = line.decode("utf-8").strip() + if line.startswith("data:"): + data = line[5:].strip() + if data and data != "[DONE]": + return json.loads(data) + raise MCPConnectionError("SSE 流中未收到有效响应") + else: + return await response.json() + async def list_tools(self) -> List[Dict[str, Any]]: """获取工具列表""" request = { @@ -326,7 +341,7 @@ class SimpleMCPClient: response_data = await self._send_sse_request(request) else: async with self._session.post(self.server_url, json=request) as response: - response_data = await response.json() + response_data = await self._parse_streamable_response(response) if "error" in response_data: raise MCPConnectionError(f"获取工具列表失败: {response_data['error']}") @@ -351,7 +366,7 @@ class SimpleMCPClient: response_data = await self._send_sse_request(request) else: async with self._session.post(self.server_url, json=request) as response: - response_data = await response.json() + response_data = await self._parse_streamable_response(response) if "error" in response_data: error = response_data["error"] diff --git a/api/app/core/workflow/nodes/cycle_graph/iteration.py b/api/app/core/workflow/nodes/cycle_graph/iteration.py index 3ce22ab2..b40bdef2 100644 --- a/api/app/core/workflow/nodes/cycle_graph/iteration.py +++ b/api/app/core/workflow/nodes/cycle_graph/iteration.py @@ -70,7 +70,7 @@ class IterationRuntime: self.variable_pool = variable_pool self.cycle_nodes = cycle_nodes self.cycle_edges = cycle_edges - self.event_write = get_stream_writer() + self.event_write = get_stream_writer() if self.stream else (lambda x: None) self.output_value = None self.result: list = [] @@ -196,7 +196,7 @@ class IterationRuntime: }) result = graph.get_state(config=checkpoint).values else: - result = await graph.ainvoke(init_state) + result = await graph.ainvoke(init_state, config=checkpoint) output = child_pool.get_value(self.output_value) stopped = result["looping"] == 2 diff --git a/api/app/core/workflow/nodes/cycle_graph/loop.py b/api/app/core/workflow/nodes/cycle_graph/loop.py index 93f1a1e4..87e3bd8d 100644 --- a/api/app/core/workflow/nodes/cycle_graph/loop.py +++ b/api/app/core/workflow/nodes/cycle_graph/loop.py @@ -57,7 +57,7 @@ class LoopRuntime: self.looping = True self.variable_pool = variable_pool self.child_variable_pool = child_variable_pool - self.event_write = get_stream_writer() + self.event_write = get_stream_writer() if self.stream else (lambda x: None) self.checkpoint = RunnableConfig( configurable={ @@ -223,7 +223,7 @@ class LoopRuntime: }) return self.graph.get_state(config=self.checkpoint).values else: - return await self.graph.ainvoke(loopstate) + return await self.graph.ainvoke(loopstate, config=self.checkpoint) async def run(self): """ diff --git a/api/app/core/workflow/nodes/http_request/node.py b/api/app/core/workflow/nodes/http_request/node.py index 6b117368..f954bd79 100644 --- a/api/app/core/workflow/nodes/http_request/node.py +++ b/api/app/core/workflow/nodes/http_request/node.py @@ -385,6 +385,7 @@ class HttpRequestNode(BaseNode): logger.info(f"Node {self.node_id}: HTTP request succeeded") response = HttpResponse(resp) # Build raw request summary for process_data + await resp.request.aread() raw_request = ( f"{self.typed_config.method.upper()} {resp.request.url} HTTP/1.1\r\n" + "".join(f"{k}: {v}\r\n" for k, v in resp.request.headers.items()) diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index 7facf381..40d927d7 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -703,6 +703,24 @@ class ModelCompareItem(BaseModel): ) +class NodeRunRequest(BaseModel): + """单节点试运行请求""" + # 扁平格式,支持: + # 节点变量: {"node_id.var_name": value} + # 系统变量: {"sys.message": "hello", "sys.files": [...]} + inputs: Dict[str, Any] = Field( + default_factory=dict, + description="节点输入变量,格式: {'node_id.var_name': value} 或 {'sys.message': 'hello'}", + examples=[{ + "sys.message": "帮我写一首诗", + "sys.user_id": "user-123", + "sys.files": [], + "llm_node_abc.output": "上游输出内容", + }] + ) + stream: bool = Field(default=False, description="是否流式返回") + + class DraftRunCompareRequest(BaseModel): """多模型对比试运行请求""" message: str = Field(..., description="用户消息") diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index 27327e99..23cd8833 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -2,6 +2,7 @@ 工作流服务层 """ import datetime +import time import logging import uuid from typing import Any, Annotated, Optional @@ -17,7 +18,6 @@ from app.core.workflow.executor import execute_workflow, execute_workflow_stream from app.core.workflow.nodes.enums import NodeType from app.core.workflow.validator import validate_workflow_config from app.db import get_db -from sqlalchemy import select from app.models import App from app.models.workflow_model import WorkflowConfig, WorkflowExecution, WorkflowNodeExecution from app.repositories import knowledge_repository @@ -1070,6 +1070,189 @@ class WorkflowService: } } + async def _build_node_context( + self, + app_id: uuid.UUID, + node_id: str, + config: WorkflowConfig, + workspace_id: uuid.UUID, + input_data: dict[str, Any], + ): + """构建单节点执行所需的上下文(node_config, node, state, variable_pool)""" + from app.core.workflow.engine.runtime_schema import ExecutionContext + from app.core.workflow.engine.variable_pool import VariablePool, VariablePoolInitializer + from app.core.workflow.engine.state_manager import WorkflowState + from app.core.workflow.nodes.node_factory import NodeFactory + from app.core.workflow.variable.base_variable import VariableType + + if not config: + config = self.get_workflow_config(app_id) + if not config: + raise BusinessException(code=BizCode.CONFIG_MISSING, message="工作流配置不存在") + + node_config = next((n for n in config.nodes if n.get("id") == node_id), None) + if not node_config: + raise BusinessException(code=BizCode.NOT_FOUND, message=f"节点不存在: node_id={node_id}") + + workflow_config_dict = { + "nodes": config.nodes, + "edges": config.edges, + "variables": config.variables or [], + "execution_config": config.execution_config or {}, + "features": config.features or {}, + } + + storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id) + execution_id = f"node_{uuid.uuid4().hex[:16]}" + + execution_context = ExecutionContext.create( + execution_id=execution_id, + workspace_id=str(workspace_id), + user_id=input_data.get("user_id", ""), + conversation_id=input_data.get("conversation_id", ""), + memory_storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id, + ) + + # sys.files 转换为 FileObject 格式 + raw_files = input_data.get("files") or [] + if raw_files: + from app.schemas.app_schema import FileInput + file_inputs = [ + FileInput(**f) if isinstance(f, dict) else f + for f in raw_files + ] + input_data["files"] = await self._handle_file_input(file_inputs) + + variable_pool = VariablePool() + await VariablePoolInitializer(workflow_config_dict).initialize(variable_pool, input_data, execution_context) + + # 注入节点输入变量,支持扁平格式 {"node_id.var": value} + for key, value in (input_data.get("inputs") or {}).items(): + if "." in key: + ref_node_id, var_name = key.split(".", 1) + var_type = VariableType.type_map(value) + await variable_pool.new(ref_node_id, var_name, value, var_type, mut=False) + + state = WorkflowState( + messages=input_data.get("conv_messages", []), + node_outputs={}, + execution_id=execution_id, + workspace_id=str(workspace_id), + user_id=input_data.get("user_id", ""), + error=None, + error_node=None, + cycle_nodes=[], + looping=0, + activate={node_id: True}, + memory_storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id, + ) + + node = NodeFactory.create_node(node_config, workflow_config_dict, []) + return node_config, node, state, variable_pool + + async def run_single_node( + self, + app_id: uuid.UUID, + node_id: str, + config: WorkflowConfig, + workspace_id: uuid.UUID, + input_data: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """单节点执行(非流式)""" + input_data = input_data or {} + node_config, node, state, variable_pool = await self._build_node_context( + app_id, node_id, config, workspace_id, input_data + ) + start_time = time.time() + try: + result = await node.execute(state, variable_pool) + elapsed = (time.time() - start_time) * 1000 + return { + "status": "completed", + "node_id": node_id, + "node_type": node_config.get("type"), + "inputs": node._extract_input(state, variable_pool), + "outputs": node._extract_output(result), + "token_usage": node._extract_token_usage(result), + "elapsed_time": elapsed, + "error": None, + } + except Exception as e: + elapsed = (time.time() - start_time) * 1000 + logger.error(f"单节点执行失败: node_id={node_id}, error={e}", exc_info=True) + return { + "status": "failed", + "node_id": node_id, + "node_type": node_config.get("type"), + "inputs": node._extract_input(state, variable_pool), + "outputs": None, + "token_usage": None, + "elapsed_time": elapsed, + "error": str(e), + } + + async def run_single_node_stream( + self, + app_id: uuid.UUID, + node_id: str, + config: WorkflowConfig, + workspace_id: uuid.UUID, + input_data: dict[str, Any] | None = None, + ): + """单节点执行(流式) + + Yields: + node_start -> node_chunk(LLM 等流式节点)-> node_end / node_error + """ + input_data = input_data or {} + node_config, node, state, variable_pool = await self._build_node_context( + app_id, node_id, config, workspace_id, input_data + ) + node_type = node_config.get("type") + start_time = time.time() + + yield {"event": "node_start", "data": {"node_id": node_id, "node_type": node_type}} + + final_result = None + try: + async for item in node.execute_stream(state, variable_pool): + if item.get("__final__"): + final_result = item["result"] + else: + chunk = item.get("chunk", "") + if chunk: + yield {"event": "node_chunk", "data": {"node_id": node_id, "chunk": chunk}} + + elapsed = (time.time() - start_time) * 1000 + yield { + "event": "node_end", + "data": { + "node_id": node_id, + "node_type": node_type, + "status": "succeeded", + "inputs": node._extract_input(state, variable_pool), + "outputs": node._extract_output(final_result), + "token_usage": node._extract_token_usage(final_result), + "elapsed_time": elapsed, + "error": None, + } + } + except Exception as e: + elapsed = (time.time() - start_time) * 1000 + logger.error(f"单节点流式执行失败: node_id={node_id}, error={e}", exc_info=True) + yield { + "event": "node_error", + "data": { + "node_id": node_id, + "node_type": node_type, + "inputs": node._extract_input(state, variable_pool), + "elapsed_time": elapsed, + "error": str(e), + } + } + @staticmethod def get_start_node_variables(config: dict) -> list: nodes = config.get("nodes", [])