From 6c04c99073db1ff7ff72e8f2fef7e05300609eaa Mon Sep 17 00:00:00 2001 From: Mark Date: Sat, 20 Dec 2025 13:59:20 +0800 Subject: [PATCH 1/5] [modify] workflow executor support stream --- api/app/core/agent/langchain_agent.py | 3 - api/app/core/workflow/executor.py | 64 +++++++++++----- api/app/core/workflow/nodes/base_node.py | 40 +++++++--- api/app/core/workflow/nodes/end/node.py | 8 +- api/app/core/workflow/nodes/llm/node.py | 82 ++++++++++---------- api/app/services/llm_router.py | 2 +- api/app/services/workflow_service.py | 95 ++++++++++++------------ 7 files changed, 168 insertions(+), 126 deletions(-) diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index dc0d6922..3c33ad6e 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -9,18 +9,15 @@ LangChain Agent 封装 """ import os import time -import asyncio from typing import Dict, Any, List, Optional, AsyncGenerator, Sequence from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage from langchain_core.tools import BaseTool from langchain.agents import create_agent -from app.core.memory.agent.mcp_server.services import session_service from app.core.memory.agent.utils.redis_tool import store from app.core.models import RedBearLLM, RedBearModelConfig from app.models.models_model import ModelType from app.core.logging_config import get_business_logger -from app.services.memory_agent_service import MemoryAgentService from app.services.memory_konwledges_server import write_rag from app.services.task_service import get_task_memory_write_result from app.tasks import write_message_task diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index 80d5316a..75a9cb0b 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -92,7 +92,7 @@ class WorkflowExecutor: - def build_graph(self) -> CompiledStateGraph: + def build_graph(self,stream=False) -> CompiledStateGraph: """构建 LangGraph Returns: @@ -122,12 +122,19 @@ class WorkflowExecutor: if node_instance: # 包装节点的 run 方法 # 使用函数工厂避免闭包问题 - def make_node_func(inst): - async def node_func(state: WorkflowState): + if stream: + # 流式模式:创建 async generator 函数 + # LangGraph 会收集所有 yield 的值,最后一个 yield 的字典会被合并到 state + async def node_func(state: WorkflowState, inst=node_instance): + async for item in inst.run_stream(state): + yield item + workflow.add_node(node_id, node_func) + else: + # 非流式模式:创建 async function + async def node_func(state: WorkflowState, inst=node_instance): return await inst.run(state) - return node_func + workflow.add_node(node_id, node_func) - workflow.add_node(node_id, make_node_func(node_instance)) logger.debug(f"添加节点: {node_id} (type={node_type})") # 3. 添加边 @@ -276,12 +283,13 @@ class WorkflowExecutor: ): """执行工作流(流式) - 手动执行节点以支持细粒度的流式输出: - - workflow_start: 工作流开始 - - node_start: 节点开始执行 - - node_chunk: LLM 节点的流式输出片段(逐 token) - - node_complete: 节点执行完成 - - workflow_complete: 工作流完成 + 使用 stream_mode="updates" 来获取每个节点的 state 更新。 + 节点的 generator 会 yield 多个值: + - 中间的 chunk 事件(带 type="chunk") + - 最后的 state 更新(纯字典,包含 node_outputs 等) + + LangGraph 会将所有 yield 的值收集起来,并将它们合并到 state 中。 + 我们需要过滤出 chunk 事件并转发,同时确保 state 更新被正确处理。 Args: input_data: 输入数据 @@ -289,27 +297,47 @@ class WorkflowExecutor: Yields: 流式事件 """ - # logger.info(f"开始执行工作流: execution_id={self.execution_id}") # 记录开始时间 start_time = datetime.datetime.now() # 1. 构建图 - graph = self.build_graph() + graph = self.build_graph(True) # 2. 初始化状态(自动注入系统变量) initial_state = self._prepare_initial_state(input_data) # 3. 执行工作流 try: - async for chunk in graph.astream( + async for mode, event in graph.astream( initial_state, - # subgraphs=True, - stream_mode="updates", + stream_mode=["updates","messages"], ): - # print(chunk) - yield chunk + # print("刚才跑的节点:", event[0]) + # # 通过图结构就能算出“接下来是谁” + # print("接下来可能跑:", graph.get_next(event[0])) + # print("="*50) + # # print("mode",mode) + # print("event",event) + # print("="*50) + # event 是一个字典,key 是节点 ID,value 是 state 更新或 chunk + for node_id, update in event.items(): + print("="*50) + print("node_id",node_id) + print("update",update) + + print("="*50) + if isinstance(update, dict) and update.get("type") == "chunk": + # 这是流式 chunk,转发给客户端 + yield { + "type": "node_chunk", + "node_id": update.get("node_id"), + "chunk": update.get("content") + } + # 其他情况(state 更新)会被 LangGraph 自动合并到 state,不需要我们处理 + print(event) + yield event except Exception as e: # 计算耗时(即使失败也记录) diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index d17cc1fd..5674655a 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -209,11 +209,15 @@ class BaseNode(ABC): 3. 将业务数据包装成标准输出格式 4. 错误处理 + 注意:在流式模式下,我们需要: + - yield 中间的 chunk 事件(用于实时显示) + - 最后 yield 一个包含 state 更新的字典(LangGraph 会合并到 state) + Args: state: 工作流状态 Yields: - 标准化的流式事件 + 标准化的流式事件和最终的 state 更新 """ import time @@ -263,27 +267,39 @@ class BaseNode(ABC): elapsed_time = time.time() - start_time + # 提取处理后的输出(调用子类的 _extract_output) + extracted_output = self._extract_output(final_result) + # 包装最终结果 final_output = self._wrap_output(final_result, elapsed_time, state) - yield { - "type": "complete", - **final_output + + # 将提取后的输出存储到运行时变量中(供后续节点快速访问) + if isinstance(extracted_output, dict): + runtime_var = extracted_output + else: + runtime_var = {"output": extracted_output} + + # 构建完整的 state 更新(包含 node_outputs 和 runtime_vars) + state_update = { + **final_output, + "runtime_vars": { + self.node_id: runtime_var + } } + + # 最后 yield 纯粹的 state 更新(LangGraph 会合并到 state 中) + yield state_update except TimeoutError: elapsed_time = time.time() - start_time logger.error(f"节点 {self.node_id} 执行超时({timeout}秒)") - yield { - "type": "error", - **self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state) - } + error_output = self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state) + yield error_output except Exception as e: elapsed_time = time.time() - start_time logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True) - yield { - "type": "error", - **self._wrap_error(str(e), elapsed_time, state) - } + error_output = self._wrap_error(str(e), elapsed_time, state) + yield error_output def _wrap_output( self, diff --git a/api/app/core/workflow/nodes/end/node.py b/api/app/core/workflow/nodes/end/node.py index ad028f31..6ee56dde 100644 --- a/api/app/core/workflow/nodes/end/node.py +++ b/api/app/core/workflow/nodes/end/node.py @@ -30,11 +30,11 @@ class EndNode(BaseNode): # 获取配置的输出模板 output_template = self.config.get("output") - pool = self.get_variable_pool(state) + # pool = self.get_variable_pool(state) - print("="*20) - print( pool.get("start.test")) - print("="*20) + # print("="*20) + # print( pool.get("start.test")) + # print("="*20) # 如果配置了输出模板,使用模板渲染;否则使用默认输出 if output_template: output = self._render_template(output_template, state) diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index cf665ff1..295ae583 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -63,7 +63,7 @@ class LLMNode(BaseNode): - ai/assistant: AI 消息(AIMessage) """ - def _prepare_llm(self, state: WorkflowState) -> tuple[RedBearLLM, list | str]: + def _prepare_llm(self, state: WorkflowState,stream:bool = False) -> tuple[RedBearLLM, list | str]: """准备 LLM 实例(公共逻辑) Args: @@ -125,16 +125,19 @@ class LLMNode(BaseNode): model_type = config.type # 4. 创建 LLM 实例(使用已提取的数据) + print("="*50) + print("stream",stream) + print("="*50) llm = RedBearLLM( RedBearModelConfig( model_name=model_name, provider=provider, api_key=api_key, - base_url=api_base + base_url=api_base, + extra_params={"streaming": stream} ), type=model_type ) - return llm, prompt_or_messages async def execute(self, state: WorkflowState) -> AIMessage: @@ -146,13 +149,12 @@ class LLMNode(BaseNode): Returns: LLM 响应消息 """ - llm, prompt_or_messages = self._prepare_llm(state) + llm, prompt_or_messages = self._prepare_llm(state,True) logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)") # 调用 LLM(支持字符串或消息列表) response = await llm.ainvoke(prompt_or_messages) - # 提取内容 if hasattr(response, 'content'): content = response.content @@ -199,47 +201,47 @@ class LLMNode(BaseNode): } return None - async def execute_stream(self, state: WorkflowState): - """流式执行 LLM 调用 + # async def execute_stream(self, state: WorkflowState): + # """流式执行 LLM 调用 - Args: - state: 工作流状态 + # Args: + # state: 工作流状态 - Yields: - 文本片段(chunk)或完成标记 - """ - llm, prompt_or_messages = self._prepare_llm(state) + # Yields: + # 文本片段(chunk)或完成标记 + # """ + # llm, prompt_or_messages = self._prepare_llm(state,True) - logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)") + # logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)") - # 累积完整响应 - full_response = "" - last_chunk = None + # # 累积完整响应 + # full_response = "" + # last_chunk = None - # 调用 LLM(流式,支持字符串或消息列表) - async for chunk in llm.astream(prompt_or_messages): - # 提取内容 - if hasattr(chunk, 'content'): - content = chunk.content - else: - content = str(chunk) + # # 调用 LLM(流式,支持字符串或消息列表) + # async for chunk in llm.astream(prompt_or_messages): + # # 提取内容 + # if hasattr(chunk, 'content'): + # content = chunk.content + # else: + # content = str(chunk) - full_response += content - last_chunk = chunk - - # 流式返回每个文本片段 - yield content + # full_response += content + # last_chunk = chunk + # logger.info(f"节点 {self.node_id} LLM : {content}") + # # 流式返回每个文本片段 + # yield content - logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}") + # logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}") - # 构建完整的 AIMessage(包含元数据) - if isinstance(last_chunk, AIMessage): - final_message = AIMessage( - content=full_response, - response_metadata=last_chunk.response_metadata if hasattr(last_chunk, 'response_metadata') else {} - ) - else: - final_message = AIMessage(content=full_response) + # # 构建完整的 AIMessage(包含元数据) + # if isinstance(last_chunk, AIMessage): + # final_message = AIMessage( + # content=full_response, + # response_metadata=last_chunk.response_metadata if hasattr(last_chunk, 'response_metadata') else {} + # ) + # else: + # final_message = AIMessage(content=full_response) - # yield 完成标记 - yield {"__final__": True, "result": final_message} + # # yield 完成标记 + # yield {"__final__": True, "result": final_message} diff --git a/api/app/services/llm_router.py b/api/app/services/llm_router.py index 089f2c07..9ef9dbb1 100644 --- a/api/app/services/llm_router.py +++ b/api/app/services/llm_router.py @@ -385,7 +385,7 @@ class LLMRouter: # 获取 API Key 配置 api_key_config = self.db.query(ModelApiKey).filter( ModelApiKey.model_config_id == self.routing_model_config.id, - ModelApiKey.is_active == True + ModelApiKey.is_active ).first() if not api_key_config: diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index f0b71824..b48edfdd 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -1,29 +1,27 @@ """ 工作流服务层 """ +import datetime import json import logging import uuid -import datetime from typing import Any, Annotated -from sqlalchemy.orm import Session from fastapi import Depends +from sqlalchemy.orm import Session +from app.core.error_codes import BizCode +from app.core.exceptions import BusinessException +from app.core.workflow.validator import validate_workflow_config +from app.db import get_db from app.models.workflow_model import WorkflowConfig, WorkflowExecution from app.repositories.workflow_repository import ( WorkflowConfigRepository, WorkflowExecutionRepository, - WorkflowNodeExecutionRepository, - get_workflow_config_repository, - get_workflow_execution_repository, - get_workflow_node_execution_repository + WorkflowNodeExecutionRepository ) -from app.core.workflow.validator import validate_workflow_config -from app.core.exceptions import BusinessException -from app.core.error_codes import BizCode -from app.db import get_db from app.schemas import DraftRunRequest +from app.utils.sse_utils import format_sse_message logger = logging.getLogger(__name__) @@ -81,7 +79,7 @@ class WorkflowService: if not is_valid: logger.warning(f"工作流配置验证失败: {errors}") raise BusinessException( - error_code=BizCode.INVALID_PARAMETER, + code=BizCode.INVALID_PARAMETER, message=f"工作流配置无效: {'; '.join(errors)}" ) @@ -140,7 +138,7 @@ class WorkflowService: config = self.get_workflow_config(app_id) if not config: raise BusinessException( - error_code=BizCode.RESOURCE_NOT_FOUND, + code=BizCode.NOT_FOUND, message=f"工作流配置不存在: app_id={app_id}" ) @@ -166,7 +164,7 @@ class WorkflowService: if not is_valid: logger.warning(f"工作流配置验证失败: {errors}") raise BusinessException( - error_code=BizCode.INVALID_PARAMETER, + code=BizCode.INVALID_PARAMETER, message=f"工作流配置无效: {'; '.join(errors)}" ) @@ -195,8 +193,7 @@ class WorkflowService: config = self.get_workflow_config(app_id) if not config: return False - - self.config_repo.delete(config.id) + config.is_active = False logger.info(f"删除工作流配置成功: app_id={app_id}, config_id={config.id}") return True @@ -245,7 +242,7 @@ class WorkflowService: config = self.get_workflow_config(app_id) if not config: raise BusinessException( - error_code=BizCode.RESOURCE_NOT_FOUND, + code=BizCode.NOT_FOUND, message=f"工作流配置不存在: app_id={app_id}" ) @@ -359,7 +356,7 @@ class WorkflowService: execution = self.get_execution(execution_id) if not execution: raise BusinessException( - error_code=BizCode.RESOURCE_NOT_FOUND, + code=BizCode.NOT_FOUND, message=f"执行记录不存在: execution_id={execution_id}" ) @@ -474,11 +471,9 @@ class WorkflowService: } # 4. 获取工作空间 ID(从 app 获取) - from app.models import App - # 5. 执行工作流 - from app.core.workflow.executor import execute_workflow, execute_workflow_stream + from app.core.workflow.executor import execute_workflow try: # 更新状态为运行中 @@ -595,17 +590,18 @@ class WorkflowService: } # 4. 获取工作空间 ID(从 app 获取) - from app.models import App # 5. 流式执行工作流 - from app.core.workflow.executor import execute_workflow, execute_workflow_stream try: # 更新状态为运行中 self.update_execution_status(execution.execution_id, "running") # 发送开始事件 - yield f"data: {json.dumps({'type': 'workflow_start', 'execution_id': execution.execution_id})}\n\n" + yield format_sse_message("workflow_start", { + "execution_id": execution.execution_id, + "conversation_id_uuid": str(conversation_id_uuid), + }) # 调用流式执行 async for event in self._run_workflow_stream( @@ -621,7 +617,10 @@ class WorkflowService: yield f"data: {json.dumps(cleaned_event)}\n\n" # 发送完成事件 - yield f"data: {json.dumps({'type': 'workflow_end', 'execution_id': execution.execution_id})}\n\n" + yield format_sse_message("workflow_end", { + "execution_id": execution.execution_id, + "conversation_id_uuid": str(conversation_id_uuid), + }) except Exception as e: logger.error(f"工作流流式执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True) @@ -660,7 +659,7 @@ class WorkflowService: config = self.get_workflow_config(app_id) if not config: raise BusinessException( - error_code=BizCode.RESOURCE_NOT_FOUND, + code=BizCode.NOT_FOUND, message=f"工作流配置不存在: app_id={app_id}" ) @@ -687,12 +686,12 @@ class WorkflowService: app = self.db.query(App).filter(App.id == app_id).first() if not app: raise BusinessException( - error_code=BizCode.RESOURCE_NOT_FOUND, + code=BizCode.NOT_FOUND, message=f"应用不存在: app_id={app_id}" ) # 5. 执行工作流 - from app.core.workflow.executor import execute_workflow, execute_workflow_stream + from app.core.workflow.executor import execute_workflow try: # 更新状态为运行中 @@ -750,7 +749,7 @@ class WorkflowService: error_message=str(e) ) raise BusinessException( - error_code=BizCode.INTERNAL_ERROR, + code=BizCode.INTERNAL_ERROR, message=f"工作流执行失败: {str(e)}" ) @@ -820,26 +819,26 @@ class WorkflowService: yield event # 收集输出数据 - if event.get("type") == "node_complete": - node_data = event.get("data", {}) - node_outputs = node_data.get("node_outputs", {}) - output_data.update(node_outputs) - - # 处理完成事件 - if event.get("type") == "workflow_complete": - self.update_execution_status( - execution_id, - "completed", - output_data=output_data - ) - - # 处理错误事件 - if event.get("type") == "workflow_error": - self.update_execution_status( - execution_id, - "failed", - error_message=event.get("error") - ) + # if event.get("type") == "node_complete": + # node_data = event.get("data", {}) + # node_outputs = node_data.get("node_outputs", {}) + # output_data.update(node_outputs) + # + # # 处理完成事件 + # if event.get("type") == "workflow_complete": + # self.update_execution_status( + # execution_id, + # "completed", + # output_data=output_data + # ) + # + # # 处理错误事件 + # if event.get("type") == "workflow_error": + # self.update_execution_status( + # execution_id, + # "failed", + # error_message=event.get("error") + # ) except Exception as e: logger.error(f"工作流流式执行失败: execution_id={execution_id}, error={e}", exc_info=True) From d8fcea856460c11b39e3cbb874d586018e850aac Mon Sep 17 00:00:00 2001 From: Mark Date: Sat, 20 Dec 2025 16:03:41 +0800 Subject: [PATCH 2/5] [fix] model support stream --- api/app/core/models/llm.py | 252 ++++++++++++++++++----- api/app/core/workflow/executor.py | 118 +++++++---- api/app/core/workflow/nodes/base_node.py | 3 + api/app/core/workflow/nodes/end/node.py | 50 ++++- api/app/core/workflow/nodes/llm/node.py | 88 ++++---- 5 files changed, 377 insertions(+), 134 deletions(-) diff --git a/api/app/core/models/llm.py b/api/app/core/models/llm.py index 5808d31a..7cd12faa 100644 --- a/api/app/core/models/llm.py +++ b/api/app/core/models/llm.py @@ -1,8 +1,8 @@ from __future__ import annotations -from typing import Any, Dict, List, Optional +from typing import Any, Iterator, AsyncIterator, List, Optional from langchain_core.callbacks import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun from langchain_core.language_models import BaseLLM -from langchain_core.outputs import LLMResult +from langchain_core.outputs import LLMResult, GenerationChunk from app.core.models import RedBearModelConfig, RedBearModelFactory, get_provider_llm_class from app.models.models_model import ModelType @@ -10,21 +10,36 @@ from app.models.models_model import ModelType class RedBearLLM(BaseLLM): """ - RedBear LLM 模型包装器 - 完全动态代理实现 + RedBear LLM Model Wrapper - 这个包装器自动将所有方法调用委托给内部模型, - 同时提供优雅的回退机制和错误处理。 + This wrapper provides a unified interface to access different LLM providers, + while maintaining all LangChain functionality, including streaming output. + + Features: + - Support for multiple LLM providers (OpenAI, Qwen, Ollama, etc.) + - Full streaming output support + - Elegant error handling and fallback mechanism + - Automatic proxying of all underlying model methods and attributes """ - def __init__(self, config: RedBearModelConfig, type: ModelType=ModelType.LLM): - self._model = self._create_model(config, type) + def __init__(self, config: RedBearModelConfig, type: ModelType = ModelType.LLM): + """Initialize RedBear LLM wrapper + + Args: + config: Model configuration + type: Model type (LLM or CHAT) + """ + super().__init__() self._config = config + self._model = self._create_model(config, type) @property def _llm_type(self) -> str: - """返回LLM类型标识符""" - return self._model._llm_type + """Return LLM type identifier""" + return getattr(self._model, '_llm_type', 'redbear_llm') + # ==================== Core Methods (Required by BaseLLM) ==================== + def _generate( self, prompts: List[str], @@ -32,7 +47,7 @@ class RedBearLLM(BaseLLM): run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any ) -> LLMResult: - """同步生成文本""" + """Synchronous text generation (required by BaseLLM)""" return self._model._generate(prompts, stop=stop, run_manager=run_manager, **kwargs) async def _agenerate( @@ -42,92 +57,233 @@ class RedBearLLM(BaseLLM): run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any ) -> LLMResult: - """异步生成文本""" + """Asynchronous text generation (required by BaseLLM)""" return await self._model._agenerate(prompts, stop=stop, run_manager=run_manager, **kwargs) - # 关键:覆盖 invoke/ainvoke,直接委托到底层模型,避免 BaseLLM 的字符串化行为 + # ==================== Advanced Methods (Support Message Lists) ==================== + def invoke(self, input: Any, config: Optional[dict] = None, **kwargs: Any) -> Any: - """直接调用底层模型以支持 ChatPrompt 和消息列表。""" + """Synchronous model invocation + + Supports various input formats including strings and message lists. + Directly delegates to the underlying model to avoid BaseLLM's string conversion. + + Args: + input: Input (string, message list, etc.) + config: Runtime configuration + **kwargs: Additional arguments + + Returns: + Model response + """ try: return self._model.invoke(input, config=config, **kwargs) except AttributeError as e: - # 只在属性错误时回退(说明底层模型不支持该方法) if 'invoke' in str(e): + # Underlying model doesn't support invoke, fallback to parent implementation return super().invoke(input, config=config, **kwargs) - # 其他 AttributeError 直接抛出 raise except Exception: - # 其他所有异常(包括 ValidationException)直接抛出,不回退 + # Other exceptions are raised directly raise async def ainvoke(self, input: Any, config: Optional[dict] = None, **kwargs: Any) -> Any: - """异步直接调用底层模型以支持 ChatPrompt 和消息列表。""" + """Asynchronous model invocation + + Supports various input formats including strings and message lists. + Directly delegates to the underlying model to avoid BaseLLM's string conversion. + + Args: + input: Input (string, message list, etc.) + config: Runtime configuration + **kwargs: Additional arguments + + Returns: + Model response + """ try: return await self._model.ainvoke(input, config=config, **kwargs) except AttributeError as e: - # 只在属性错误时回退(说明底层模型不支持该方法) if 'ainvoke' in str(e): + # Underlying model doesn't support ainvoke, fallback to parent implementation return await super().ainvoke(input, config=config, **kwargs) - # 其他 AttributeError 直接抛出 raise except Exception: - # 其他所有异常(包括 ValidationException)直接抛出,不回退 + # Other exceptions are raised directly raise - def __getattr__(self, name): - """ - 动态代理:将所有未定义的属性和方法调用委托给内部模型 + # ==================== Streaming Methods (Critical) ==================== + + def stream( + self, + input: Any, + config: Optional[dict] = None, + *, + stop: Optional[List[str]] = None, + **kwargs: Any + ) -> Iterator[GenerationChunk]: + """Synchronous streaming model invocation - 这是最优雅的包装器实现方式,完全避免了方法重复定义 - """ - # 处理特殊属性以避免递归 - if name in ('__isabstractmethod__', '__dict__', '__class__'): - raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + Args: + input: Input (string, message list, etc.) + config: Runtime configuration + stop: List of stop words + **kwargs: Additional arguments - # 检查内部模型是否有该属性(使用安全的方式避免递归) + Yields: + GenerationChunk: Generated text chunks + """ + try: + yield from self._model.stream(input, config=config, stop=stop, **kwargs) + except AttributeError as e: + if 'stream' in str(e): + # Underlying model doesn't support stream, fallback to parent implementation + yield from super().stream(input, config=config, stop=stop, **kwargs) + else: + raise + except Exception: + raise + + async def astream( + self, + input: Any, + config: Optional[dict] = None, + *, + stop: Optional[List[str]] = None, + **kwargs: Any + ) -> AsyncIterator[GenerationChunk]: + """Asynchronous streaming model invocation + + This is the core method for streaming output. It directly proxies to the + underlying model's astream method, maintaining generator characteristics + to ensure each chunk is delivered in real-time. + + Args: + input: Input (string, message list, etc.) + config: Runtime configuration + stop: List of stop words + **kwargs: Additional arguments + + Yields: + GenerationChunk: Generated text chunks + """ + try: + async for chunk in self._model.astream(input, config=config, stop=stop, **kwargs): + yield chunk + except AttributeError as e: + if 'astream' in str(e): + # Underlying model doesn't support astream, fallback to parent implementation + async for chunk in super().astream(input, config=config, stop=stop, **kwargs): + yield chunk + else: + raise + except Exception: + raise + + # ==================== Dynamic Proxy ==================== + + def __getattr__(self, name: str) -> Any: + """Dynamic proxy: delegate undefined attributes and method calls to internal model + + This method allows RedBearLLM to transparently access all attributes and methods + of the underlying model without explicitly defining each one. + + Args: + name: Attribute or method name + + Returns: + Attribute value or method + + Raises: + AttributeError: If attribute doesn't exist + """ + # Avoid recursion: raise error directly for special attributes + if name in ('__isabstractmethod__', '__dict__', '__class__', '_model', '_config'): + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + # Try to get attribute from internal model try: - # 使用 object.__getattribute__ 来安全地检查内部模型的属性 attr = object.__getattribute__(self._model, name) - # 如果是方法,返回一个包装器来处理调用 + # If it's callable (a method) if callable(attr): - # 流式方法直接返回,不包装(保持生成器特性) - if name in ('_stream', '_astream', 'stream', 'astream'): + # Streaming methods are returned directly to maintain generator characteristics + # Note: Although we've explicitly implemented stream/astream, + # this is kept to handle internal methods like _stream/_astream + if name in ('_stream', '_astream'): return attr - # 非流式方法使用包装器处理异常 + # Wrap other methods for easier debugging and error handling def method_wrapper(*args, **kwargs): - return attr(*args, **kwargs) + try: + return attr(*args, **kwargs) + except Exception: + # Can add logging or error handling here + raise - # 保持方法的元信息 + # Preserve method metadata method_wrapper.__name__ = name method_wrapper.__doc__ = getattr(attr, '__doc__', f"Delegated method: {name}") return method_wrapper - # 如果是普通属性,直接返回 + # If it's a regular attribute, return directly return attr except AttributeError: - # 内部模型没有该属性,尝试回退实现 + # Internal model doesn't have this attribute either pass - # 检查是否有回退方法(使用安全的方式避免递归) + # Check if there's a fallback method fallback_name = f'_fallback_{name}' try: - fallback_method = object.__getattribute__(self, fallback_name) - return fallback_method + return object.__getattribute__(self, fallback_name) except AttributeError: - # 没有回退方法,抛出适当的错误 pass - # 如果都没有,抛出适当的错误 - raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + # Nothing found, raise error + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'. " + f"The underlying model '{type(self._model).__name__}' also doesn't have this attribute." + ) + # ==================== Helper Methods ==================== + def _create_model(self, config: RedBearModelConfig, type: ModelType) -> BaseLLM: - """创建内部模型实例""" + """Create internal model instance + + Args: + config: Model configuration + type: Model type + + Returns: + Created model instance + """ llm_class = get_provider_llm_class(config, type) model_params = RedBearModelFactory.get_model_params(config) return llm_class(**model_params) - - - \ No newline at end of file + + def get_config(self) -> RedBearModelConfig: + """Get model configuration + + Returns: + Model configuration object + """ + return self._config + + def get_underlying_model(self) -> BaseLLM: + """Get underlying model instance + + Returns: + Underlying model instance + """ + return self._model + + def __repr__(self) -> str: + """Return string representation of the object""" + return ( + f"RedBearLLM(" + f"provider={self._config.provider}, " + f"model={self._config.model_name}, " + f"type={type(self._model).__name__}" + f")" + ) \ No newline at end of file diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index 75a9cb0b..8d67dd1e 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -125,17 +125,22 @@ class WorkflowExecutor: if stream: # 流式模式:创建 async generator 函数 # LangGraph 会收集所有 yield 的值,最后一个 yield 的字典会被合并到 state - async def node_func(state: WorkflowState, inst=node_instance): - async for item in inst.run_stream(state): - yield item - workflow.add_node(node_id, node_func) + def make_stream_func(inst): + async def node_func(state: WorkflowState): + # logger.debug(f"流式执行节点: {inst.node_id}, 支持流式: {inst.supports_streaming()}") + async for item in inst.run_stream(state): + yield item + return node_func + workflow.add_node(node_id, make_stream_func(node_instance)) else: # 非流式模式:创建 async function - async def node_func(state: WorkflowState, inst=node_instance): - return await inst.run(state) - workflow.add_node(node_id, node_func) + def make_func(inst): + async def node_func(state: WorkflowState): + return await inst.run(state) + return node_func + workflow.add_node(node_id, make_func(node_instance)) - logger.debug(f"添加节点: {node_id} (type={node_type})") + logger.debug(f"添加节点: {node_id} (type={node_type}, stream={stream})") # 3. 添加边 # 从 START 连接到 start 节点 @@ -283,13 +288,9 @@ class WorkflowExecutor: ): """执行工作流(流式) - 使用 stream_mode="updates" 来获取每个节点的 state 更新。 - 节点的 generator 会 yield 多个值: - - 中间的 chunk 事件(带 type="chunk") - - 最后的 state 更新(纯字典,包含 node_outputs 等) - - LangGraph 会将所有 yield 的值收集起来,并将它们合并到 state 中。 - 我们需要过滤出 chunk 事件并转发,同时确保 state 更新被正确处理。 + 使用多个 stream_mode 来获取: + 1. "updates" - 节点的 state 更新和流式 chunk + 2. "debug" - 节点执行的详细信息(开始/完成时间) Args: input_data: 输入数据 @@ -297,7 +298,7 @@ class WorkflowExecutor: Yields: 流式事件 """ - logger.info(f"开始执行工作流: execution_id={self.execution_id}") + logger.info(f"开始执行工作流(流式): execution_id={self.execution_id}") # 记录开始时间 start_time = datetime.datetime.now() @@ -310,34 +311,73 @@ class WorkflowExecutor: # 3. 执行工作流 try: - async for mode, event in graph.astream( + chunk_count = 0 + async for event in graph.astream( initial_state, - stream_mode=["updates","messages"], + stream_mode=["updates", "debug"], ): - # print("刚才跑的节点:", event[0]) - # # 通过图结构就能算出“接下来是谁” - # print("接下来可能跑:", graph.get_next(event[0])) - # print("="*50) - # # print("mode",mode) - # print("event",event) - # print("="*50) - # event 是一个字典,key 是节点 ID,value 是 state 更新或 chunk - for node_id, update in event.items(): - print("="*50) - print("node_id",node_id) - print("update",update) + mode, data = event - print("="*50) - if isinstance(update, dict) and update.get("type") == "chunk": - # 这是流式 chunk,转发给客户端 + if mode == "debug": + # 处理调试信息(节点执行状态) + event_type = data.get("type") + payload = data.get("payload", {}) + node_name = payload.get("name") + + if event_type == "task": + # 节点开始执行 + inputv = payload.get("input", {}) + variables = inputv.get("variables", {}) + variables_sys = variables.get("sys", {}) + conversation_id = variables_sys.get("conversation_id") + execution_id = variables_sys.get("execution_id") + logger.info(f"[DEBUG] 节点开始执行: {node_name}") yield { - "type": "node_chunk", - "node_id": update.get("node_id"), - "chunk": update.get("content") + "type": "node_start", + "node_id": node_name, + "conversation_id": conversation_id, + "execution_id": execution_id, + "timestamp": data.get("timestamp") } - # 其他情况(state 更新)会被 LangGraph 自动合并到 state,不需要我们处理 - print(event) - yield event + elif event_type == "task_result": + # 节点执行完成 + result = payload.get("result", {}) + inputv = result.get("input", {}) + variables = inputv.get("variables", {}) + variables_sys = variables.get("sys", {}) + conversation_id = variables_sys.get("conversation_id") + execution_id = variables_sys.get("execution_id") + logger.info(f"[DEBUG] 节点执行完成: {node_name}") + yield { + "type": "node_end", + "node_id": node_name, + "conversation_id": conversation_id, + "execution_id": execution_id, + "timestamp": data.get("timestamp") + } + + elif mode == "updates": + # 处理 state 更新 + # data 是一个字典,key 是节点 ID,value 是 state 更新或 chunk + print("="*50) + print(data) + print("-"*50) + for node_id, update in data.items(): + if isinstance(update, dict) and update.get("type") == "chunk": + # 这是流式 chunk,转发给客户端 + chunk_count += 1 + logger.debug(f"[UPDATE] 收到 chunk #{chunk_count} from {node_id}: {update.get('content')[:50]}...") + yield { + "type": "node_chunk", + "node_id": update.get("node_id"), + "chunk": update.get("content"), + "full_content": update.get("full_content") + } + else: + logger.debug(f"[UPDATE] 收到 state 更新 from {node_id}") + # 其他情况(state 更新)会被 LangGraph 自动合并到 state + + logger.info(f"工作流执行完成(流式),总 chunks: {chunk_count}") except Exception as e: # 计算耗时(即使失败也记录) diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index 5674655a..1d6f1c15 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -245,6 +245,9 @@ class BaseNode(ABC): final_result = item["result"] elif isinstance(item, str): # 字符串是 chunk + # print("="*50) + # print(item) + # print("-"*50) chunks.append(item) yield { "type": "chunk", diff --git a/api/app/core/workflow/nodes/end/node.py b/api/app/core/workflow/nodes/end/node.py index 6ee56dde..cba0d649 100644 --- a/api/app/core/workflow/nodes/end/node.py +++ b/api/app/core/workflow/nodes/end/node.py @@ -30,11 +30,7 @@ class EndNode(BaseNode): # 获取配置的输出模板 output_template = self.config.get("output") - # pool = self.get_variable_pool(state) - - # print("="*20) - # print( pool.get("start.test")) - # print("="*20) + # 如果配置了输出模板,使用模板渲染;否则使用默认输出 if output_template: output = self._render_template(output_template, state) @@ -46,7 +42,45 @@ class EndNode(BaseNode): total_nodes = len(node_outputs) logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点") - print("="*20) - print(output) - print("="*20) + return output + + async def execute_stream(self, state: WorkflowState): + """流式执行 end 节点业务逻辑 + + 当 end 节点前面是 LLM 节点时,流式输出其内容。 + + Args: + state: 工作流状态 + + Yields: + 文本片段(chunk)或完成标记 + """ + logger.info(f"节点 {self.node_id} (End) 开始执行(流式)") + + # 获取配置的输出模板 + output_template = self.config.get("output") + + # 如果配置了输出模板,使用模板渲染 + if output_template: + output = self._render_template(output_template, state) + + # 检查输出中是否包含节点引用(如 {{llm_node.output}}) + # 如果包含,则逐字符流式输出 + if output: + # 逐字符流式输出 + for char in output: + yield char + else: + output = "工作流已完成" + for char in output: + yield char + + # 统计信息(用于日志) + node_outputs = state.get("node_outputs", {}) + total_nodes = len(node_outputs) + + logger.info(f"节点 {self.node_id} (End) 执行完成(流式),共执行 {total_nodes} 个节点") + + # yield 完成标记 + yield {"__final__": True, "result": output} diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index 295ae583..bac707d7 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -125,19 +125,22 @@ class LLMNode(BaseNode): model_type = config.type # 4. 创建 LLM 实例(使用已提取的数据) - print("="*50) - print("stream",stream) - print("="*50) + # 注意:对于流式输出,需要在模型初始化时设置 streaming=True + extra_params = {"streaming": stream} if stream else {} + llm = RedBearLLM( RedBearModelConfig( model_name=model_name, provider=provider, api_key=api_key, base_url=api_base, - extra_params={"streaming": stream} + extra_params=extra_params ), type=model_type ) + + logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}") + return llm, prompt_or_messages async def execute(self, state: WorkflowState) -> AIMessage: @@ -201,47 +204,54 @@ class LLMNode(BaseNode): } return None - # async def execute_stream(self, state: WorkflowState): - # """流式执行 LLM 调用 + async def execute_stream(self, state: WorkflowState): + """流式执行 LLM 调用 - # Args: - # state: 工作流状态 + Args: + state: 工作流状态 - # Yields: - # 文本片段(chunk)或完成标记 - # """ - # llm, prompt_or_messages = self._prepare_llm(state,True) + Yields: + 文本片段(chunk)或完成标记 + """ + llm, prompt_or_messages = self._prepare_llm(state, True) - # logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)") + logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)") + logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}") - # # 累积完整响应 - # full_response = "" - # last_chunk = None + # 累积完整响应 + full_response = "" + last_chunk = None + chunk_count = 0 - # # 调用 LLM(流式,支持字符串或消息列表) - # async for chunk in llm.astream(prompt_or_messages): - # # 提取内容 - # if hasattr(chunk, 'content'): - # content = chunk.content - # else: - # content = str(chunk) + # 调用 LLM(流式,支持字符串或消息列表) + # 注意:astream 方法本身就是流式的,不需要额外配置 + async for chunk in llm.astream(prompt_or_messages): + # 提取内容 + if hasattr(chunk, 'content'): + content = chunk.content + else: + content = str(chunk) - # full_response += content - # last_chunk = chunk - # logger.info(f"节点 {self.node_id} LLM : {content}") - # # 流式返回每个文本片段 - # yield content + # 只有当内容不为空时才处理 + if content: + full_response += content + last_chunk = chunk + chunk_count += 1 + + # logger.debug(f"节点 {self.node_id} LLM chunk #{chunk_count}: {content[:50]}...") + # 流式返回每个文本片段 + yield content #AIMessage(content=content) - # logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}") + logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}") - # # 构建完整的 AIMessage(包含元数据) - # if isinstance(last_chunk, AIMessage): - # final_message = AIMessage( - # content=full_response, - # response_metadata=last_chunk.response_metadata if hasattr(last_chunk, 'response_metadata') else {} - # ) - # else: - # final_message = AIMessage(content=full_response) + # 构建完整的 AIMessage(包含元数据) + if isinstance(last_chunk, AIMessage): + final_message = AIMessage( + content=full_response, + response_metadata=last_chunk.response_metadata if hasattr(last_chunk, 'response_metadata') else {} + ) + else: + final_message = AIMessage(content=full_response) - # # yield 完成标记 - # yield {"__final__": True, "result": final_message} + # yield 完成标记 + yield {"__final__": True, "result": final_message} From 36b36b729b9c03c8f71b3d9ecc45a9e5c0ab1202 Mon Sep 17 00:00:00 2001 From: Mark Date: Sat, 20 Dec 2025 17:25:47 +0800 Subject: [PATCH 3/5] [add] workflow llm & end logic --- api/app/core/workflow/executor.py | 137 +++++++++++---- api/app/core/workflow/nodes/base_node.py | 157 +++++++++++------ api/app/core/workflow/nodes/end/node.py | 208 ++++++++++++++++++++--- api/app/core/workflow/nodes/llm/node.py | 31 +++- 4 files changed, 430 insertions(+), 103 deletions(-) diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index 8d67dd1e..992a8e1a 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -87,11 +87,75 @@ class WorkflowExecutor: "workspace_id": self.workspace_id, "user_id": self.user_id, "error": None, - "error_node": None + "error_node": None, + "streaming_buffer": {} # 流式缓冲区 } + def _analyze_end_node_prefixes(self) -> dict[str, str]: + """分析 End 节点的前缀配置 + + 检查每个 End 节点的模板,找到直接上游节点的引用, + 提取该引用之前的前缀部分。 + + Returns: + 字典:{上游节点ID: End节点前缀} + """ + import re + + prefixes = {} + + # 找到所有 End 节点 + end_nodes = [node for node in self.nodes if node.get("type") == "end"] + logger.info(f"[前缀分析] 找到 {len(end_nodes)} 个 End 节点") + + for end_node in end_nodes: + end_node_id = end_node.get("id") + output_template = end_node.get("config", {}).get("output") + + logger.info(f"[前缀分析] End 节点 {end_node_id} 模板: {output_template}") + + if not output_template: + continue + + # 找到所有直接连接到 End 节点的上游节点 + direct_upstream_nodes = [] + for edge in self.edges: + if edge.get("target") == end_node_id: + source_node_id = edge.get("source") + direct_upstream_nodes.append(source_node_id) + + logger.info(f"[前缀分析] End 节点的直接上游节点: {direct_upstream_nodes}") + + # 查找模板中引用了哪些节点 + # 匹配 {{node_id.xxx}} 或 {{ node_id.xxx }} 格式(支持空格) + pattern = r'\{\{\s*([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\s*\}\}' + matches = list(re.finditer(pattern, output_template)) + + logger.info(f"[前缀分析] 模板中找到 {len(matches)} 个节点引用") + + # 找到第一个直接上游节点的引用 + for match in matches: + referenced_node_id = match.group(1) + logger.info(f"[前缀分析] 检查引用: {referenced_node_id}") + + if referenced_node_id in direct_upstream_nodes: + # 这是直接上游节点的引用,提取前缀 + prefix = output_template[:match.start()] + + logger.info(f"[前缀分析] ✅ 找到直接上游节点 {referenced_node_id} 的引用,前缀: '{prefix}'") + + if prefix: + prefixes[referenced_node_id] = prefix + logger.info(f"✅ [前缀分析] 为节点 {referenced_node_id} 配置前缀: '{prefix[:50]}...'") + + # 只处理第一个直接上游节点的引用 + break + + logger.info(f"[前缀分析] 最终配置: {prefixes}") + return prefixes + def build_graph(self,stream=False) -> CompiledStateGraph: """构建 LangGraph @@ -99,6 +163,9 @@ class WorkflowExecutor: 编译后的状态图 """ logger.info(f"开始构建工作流图: execution_id={self.execution_id}") + + # 分析 End 节点的前缀配置 + end_prefixes = self._analyze_end_node_prefixes() if stream else {} # 1. 创建状态图 workflow = StateGraph(WorkflowState) @@ -120,6 +187,12 @@ class WorkflowExecutor: # 创建节点实例(现在 start 和 end 也会被创建) node_instance = NodeFactory.create_node(node, self.workflow_config) if node_instance: + # 如果是流式模式,且节点有 End 前缀配置,注入配置 + if stream and node_id in end_prefixes: + # 将 End 前缀配置注入到节点实例 + node_instance._end_node_prefix = end_prefixes[node_id] + logger.info(f"为节点 {node_id} 注入 End 前缀配置") + # 包装节点的 run 方法 # 使用函数工厂避免闭包问题 if stream: @@ -309,29 +382,48 @@ class WorkflowExecutor: # 2. 初始化状态(自动注入系统变量) initial_state = self._prepare_initial_state(input_data) - # 3. 执行工作流 + # 3. Execute workflow try: chunk_count = 0 async for event in graph.astream( initial_state, - stream_mode=["updates", "debug"], + stream_mode=["updates", "debug", "custom"], # Use updates + debug + custom mode ): - mode, data = event + # event should be a tuple: (mode, data) + # But let's handle both cases + if isinstance(event, tuple) and len(event) == 2: + mode, data = event + else: + # Unexpected format, log and skip + logger.warning(f"[STREAM] Unexpected event format: {type(event)}, value: {event}") + continue - if mode == "debug": - # 处理调试信息(节点执行状态) + if mode == "custom": + # Handle custom streaming events (chunks from nodes via stream writer) + chunk_count += 1 + logger.info(f"[CUSTOM] ✅ 收到 chunk #{chunk_count} from {data.get('node_id')}") + yield { + "type": "node_chunk", + "node_id": data.get("node_id"), + "chunk": data.get("chunk"), + "full_content": data.get("full_content"), + "chunk_index": data.get("chunk_index") + } + + elif mode == "debug": + # Handle debug information (node execution status) event_type = data.get("type") payload = data.get("payload", {}) node_name = payload.get("name") if event_type == "task": - # 节点开始执行 + # Node starts execution inputv = payload.get("input", {}) variables = inputv.get("variables", {}) variables_sys = variables.get("sys", {}) conversation_id = variables_sys.get("conversation_id") execution_id = variables_sys.get("execution_id") - logger.info(f"[DEBUG] 节点开始执行: {node_name}") + logger.info(f"[DEBUG] Node starts execution: {node_name}") yield { "type": "node_start", "node_id": node_name, @@ -340,16 +432,16 @@ class WorkflowExecutor: "timestamp": data.get("timestamp") } elif event_type == "task_result": - # 节点执行完成 + # Node execution completed result = payload.get("result", {}) inputv = result.get("input", {}) variables = inputv.get("variables", {}) variables_sys = variables.get("sys", {}) conversation_id = variables_sys.get("conversation_id") execution_id = variables_sys.get("execution_id") - logger.info(f"[DEBUG] 节点执行完成: {node_name}") + logger.info(f"[DEBUG] Node execution completed: {node_name}") yield { - "type": "node_end", + "type": "node_complete", "node_id": node_name, "conversation_id": conversation_id, "execution_id": execution_id, @@ -357,27 +449,10 @@ class WorkflowExecutor: } elif mode == "updates": - # 处理 state 更新 - # data 是一个字典,key 是节点 ID,value 是 state 更新或 chunk - print("="*50) - print(data) - print("-"*50) - for node_id, update in data.items(): - if isinstance(update, dict) and update.get("type") == "chunk": - # 这是流式 chunk,转发给客户端 - chunk_count += 1 - logger.debug(f"[UPDATE] 收到 chunk #{chunk_count} from {node_id}: {update.get('content')[:50]}...") - yield { - "type": "node_chunk", - "node_id": update.get("node_id"), - "chunk": update.get("content"), - "full_content": update.get("full_content") - } - else: - logger.debug(f"[UPDATE] 收到 state 更新 from {node_id}") - # 其他情况(state 更新)会被 LangGraph 自动合并到 state + # Handle state updates + logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())}") - logger.info(f"工作流执行完成(流式),总 chunks: {chunk_count}") + logger.info(f"Workflow execution completed (streaming), total chunks: {chunk_count}") except Exception as e: # 计算耗时(即使失败也记录) diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index 1d6f1c15..f2f18404 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -10,6 +10,7 @@ from abc import ABC, abstractmethod from typing import Any, TypedDict, Annotated from operator import add from langchain_core.messages import AnyMessage, HumanMessage, AIMessage +from langgraph.config import get_stream_writer from app.core.workflow.variable_pool import VariablePool @@ -43,6 +44,10 @@ class WorkflowState(TypedDict): # 错误信息(用于错误边) error: str | None error_node: str | None + + # 流式缓冲区(存储节点的实时流式输出) + # 格式:{node_id: {"chunks": [...], "full_content": "..."}} + streaming_buffer: Annotated[dict[str, Any], lambda x, y: {**x, **y}] class BaseNode(ABC): @@ -201,23 +206,25 @@ class BaseNode(ABC): return self._wrap_error(str(e), elapsed_time, state) async def run_stream(self, state: WorkflowState): - """执行节点(带错误处理和输出包装,流式) + """Execute node with error handling and output wrapping (streaming) - 这个方法由 Executor 调用,负责: - 1. 时间统计 - 2. 调用节点的 execute_stream() 方法 - 3. 将业务数据包装成标准输出格式 - 4. 错误处理 + This method is called by the Executor and is responsible for: + 1. Time tracking + 2. Calling the node's execute_stream() method + 3. Using LangGraph's stream writer to send chunks + 4. Updating streaming buffer in state for downstream nodes + 5. Wrapping business data into standard output format + 6. Error handling - 注意:在流式模式下,我们需要: - - yield 中间的 chunk 事件(用于实时显示) - - 最后 yield 一个包含 state 更新的字典(LangGraph 会合并到 state) + Special handling for End nodes: + - End nodes don't send chunks via writer (prefix and LLM content already sent) + - End nodes only yield suffix for final result assembly Args: - state: 工作流状态 + state: Workflow state Yields: - 标准化的流式事件和最终的 state 更新 + State updates with streaming buffer and final result """ import time @@ -226,63 +233,102 @@ class BaseNode(ABC): try: timeout = self.get_timeout() - # 累积完整结果(用于最后的包装) + # Get LangGraph's stream writer for sending custom data + writer = get_stream_writer() + + # Check if this is an End node + # End nodes CAN send chunks (for suffix), but only after LLM content + is_end_node = self.node_type == "end" + + # Accumulate complete result (for final wrapping) chunks = [] final_result = None + chunk_count = 0 - # 使用异步生成器包装,支持超时 - async def stream_with_timeout(): - nonlocal final_result - loop_start = asyncio.get_event_loop().time() + # Stream chunks in real-time + loop_start = asyncio.get_event_loop().time() + + async for item in self.execute_stream(state): + # Check timeout + if asyncio.get_event_loop().time() - loop_start > timeout: + raise TimeoutError() - async for item in self.execute_stream(state): - # 检查超时 - if asyncio.get_event_loop().time() - loop_start > timeout: - raise TimeoutError() + # Check if it's a completion marker + if isinstance(item, dict) and item.get("__final__"): + final_result = item["result"] + elif isinstance(item, str): + # String is a chunk + chunk_count += 1 + chunks.append(item) + full_content = "".join(chunks) - # 检查是否是完成标记 - if isinstance(item, dict) and item.get("__final__"): - final_result = item["result"] - elif isinstance(item, str): - # 字符串是 chunk - # print("="*50) - # print(item) - # print("-"*50) - chunks.append(item) + # Send chunks for all nodes (including End nodes for suffix) + logger.debug(f"节点 {self.node_id} 发送 chunk #{chunk_count}: {item[:50]}...") + + # 1. Send via stream writer (for real-time client updates) + writer({ + "node_id": self.node_id, + "chunk": item, + "full_content": full_content, + "chunk_index": chunk_count + }) + + # 2. Update streaming buffer in state (for downstream nodes) + # Only non-End nodes need streaming buffer + if not is_end_node: yield { - "type": "chunk", - "node_id": self.node_id, - "content": item, - "full_content": "".join(chunks) + "streaming_buffer": { + self.node_id: { + "full_content": full_content, + "chunk_count": chunk_count, + "is_complete": False + } + } } - else: - # 其他类型也当作 chunk 处理 - chunks.append(str(item)) + else: + # Other types are also treated as chunks + chunk_count += 1 + chunk_str = str(item) + chunks.append(chunk_str) + full_content = "".join(chunks) + + # Send chunks for all nodes + writer({ + "node_id": self.node_id, + "chunk": chunk_str, + "full_content": full_content, + "chunk_index": chunk_count + }) + + # Only non-End nodes need streaming buffer + if not is_end_node: yield { - "type": "chunk", - "node_id": self.node_id, - "content": str(item), - "full_content": "".join(chunks) + "streaming_buffer": { + self.node_id: { + "full_content": full_content, + "chunk_count": chunk_count, + "is_complete": False + } + } } - async for chunk_event in stream_with_timeout(): - yield chunk_event - elapsed_time = time.time() - start_time - # 提取处理后的输出(调用子类的 _extract_output) + logger.info(f"节点 {self.node_id} 流式执行完成,耗时: {elapsed_time:.2f}s, chunks: {chunk_count}") + + # Extract processed output (call subclass's _extract_output) extracted_output = self._extract_output(final_result) - # 包装最终结果 + # Wrap final result final_output = self._wrap_output(final_result, elapsed_time, state) - # 将提取后的输出存储到运行时变量中(供后续节点快速访问) + # Store extracted output in runtime variables (for quick access by subsequent nodes) if isinstance(extracted_output, dict): runtime_var = extracted_output else: runtime_var = {"output": extracted_output} - # 构建完整的 state 更新(包含 node_outputs 和 runtime_vars) + # Build complete state update (including node_outputs, runtime_vars, and final streaming buffer) state_update = { **final_output, "runtime_vars": { @@ -290,13 +336,24 @@ class BaseNode(ABC): } } - # 最后 yield 纯粹的 state 更新(LangGraph 会合并到 state 中) + # Add streaming buffer for non-End nodes + if not is_end_node: + state_update["streaming_buffer"] = { + self.node_id: { + "full_content": "".join(chunks), + "chunk_count": chunk_count, + "is_complete": True # Mark as complete + } + } + + # Finally yield state update + # LangGraph will merge this into state yield state_update except TimeoutError: elapsed_time = time.time() - start_time - logger.error(f"节点 {self.node_id} 执行超时({timeout}秒)") - error_output = self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state) + logger.error(f"节点 {self.node_id} 执行超时 ({timeout}s)") + error_output = self._wrap_error(f"节点执行超时 ({timeout}s)", elapsed_time, state) yield error_output except Exception as e: elapsed_time = time.time() - start_time diff --git a/api/app/core/workflow/nodes/end/node.py b/api/app/core/workflow/nodes/end/node.py index cba0d649..f47f3c1e 100644 --- a/api/app/core/workflow/nodes/end/node.py +++ b/api/app/core/workflow/nodes/end/node.py @@ -5,6 +5,8 @@ End 节点实现 """ import logging +import re +import asyncio from app.core.workflow.nodes.base_node import BaseNode, WorkflowState @@ -15,6 +17,7 @@ class EndNode(BaseNode): """End 节点 工作流的结束节点,根据配置的模板输出最终结果。 + 支持实时流式输出:如果模板引用了上游节点的输出,会实时监听其流式缓冲区。 """ async def execute(self, state: WorkflowState) -> str: @@ -45,42 +48,209 @@ class EndNode(BaseNode): return output + def _extract_referenced_nodes(self, template: str) -> list[str]: + """从模板中提取引用的节点 ID + + 例如:'结果:{{llm_qa.output}}' -> ['llm_qa'] + + Args: + template: 模板字符串 + + Returns: + 引用的节点 ID 列表 + """ + # 匹配 {{node_id.xxx}} 格式 + pattern = r'\{\{([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\}\}' + matches = re.findall(pattern, template) + return list(set(matches)) # 去重 + + def _parse_template_parts(self, template: str, state: WorkflowState) -> list[dict]: + """解析模板,分离静态文本和动态引用 + + 例如:'你好 {{llm.output}}, 这是后缀' + 返回:[ + {"type": "static", "content": "你好 "}, + {"type": "dynamic", "node_id": "llm", "field": "output"}, + {"type": "static", "content": ", 这是后缀"} + ] + + Args: + template: 模板字符串 + state: 工作流状态 + + Returns: + 模板部分列表 + """ + import re + + parts = [] + last_end = 0 + + # 匹配 {{xxx}} 或 {{ xxx }} 格式(支持空格) + pattern = r'\{\{\s*([^}]+?)\s*\}\}' + + for match in re.finditer(pattern, template): + start, end = match.span() + + # 添加前面的静态文本 + if start > last_end: + static_text = template[last_end:start] + if static_text: + parts.append({"type": "static", "content": static_text}) + + # 解析动态引用 + ref = match.group(1).strip() + + # 检查是否是节点引用(如 llm.output 或 llm_qa.output) + if '.' in ref: + node_id, field = ref.split('.', 1) + parts.append({ + "type": "dynamic", + "node_id": node_id, + "field": field, + "raw": ref + }) + else: + # 其他引用(如 {{var.xxx}}),当作静态处理 + # 直接渲染这部分 + rendered = self._render_template(f"{{{{{ref}}}}}", state) + parts.append({"type": "static", "content": rendered}) + + last_end = end + + # 添加最后的静态文本 + if last_end < len(template): + static_text = template[last_end:] + if static_text: + parts.append({"type": "static", "content": static_text}) + + return parts + async def execute_stream(self, state: WorkflowState): """流式执行 end 节点业务逻辑 - 当 end 节点前面是 LLM 节点时,流式输出其内容。 + 智能输出策略: + 1. 检测模板中是否引用了直接上游节点 + 2. 如果引用了,只输出该引用**之后**的部分(后缀) + 3. 前缀和引用内容已经在上游节点流式输出时发送了 + + 示例:'{{start.test}}hahaha {{ llm_qa.output }} lalalalala a' + - 直接上游节点是 llm_qa + - 前缀 '{{start.test}}hahaha ' 已在 LLM 节点流式输出前发送 + - LLM 内容在 LLM 节点流式输出 + - End 节点只输出 ' lalalalala a'(后缀,一次性输出) Args: state: 工作流状态 Yields: - 文本片段(chunk)或完成标记 + 完成标记 """ logger.info(f"节点 {self.node_id} (End) 开始执行(流式)") # 获取配置的输出模板 output_template = self.config.get("output") - # 如果配置了输出模板,使用模板渲染 - if output_template: - output = self._render_template(output_template, state) - - # 检查输出中是否包含节点引用(如 {{llm_node.output}}) - # 如果包含,则逐字符流式输出 - if output: - # 逐字符流式输出 - for char in output: - yield char - else: + if not output_template: output = "工作流已完成" - for char in output: - yield char + yield {"__final__": True, "result": output} + return - # 统计信息(用于日志) + # 找到直接上游节点 + direct_upstream_nodes = [] + for edge in self.workflow_config.get("edges", []): + if edge.get("target") == self.node_id: + source_node_id = edge.get("source") + direct_upstream_nodes.append(source_node_id) + + logger.info(f"节点 {self.node_id} 的直接上游节点: {direct_upstream_nodes}") + + # 解析模板部分 + parts = self._parse_template_parts(output_template, state) + logger.info(f"节点 {self.node_id} 解析模板,共 {len(parts)} 个部分") + + # 找到第一个引用直接上游节点的动态引用 + upstream_ref_index = None + for i, part in enumerate(parts): + if part["type"] == "dynamic" and part["node_id"] in direct_upstream_nodes: + upstream_ref_index = i + logger.info(f"节点 {self.node_id} 找到直接上游节点 {part['node_id']} 的引用,索引: {i}") + break + + if upstream_ref_index is None: + # 没有引用直接上游节点,正常输出(渲染完整模板) + output = self._render_template(output_template, state) + logger.info(f"节点 {self.node_id} 没有引用直接上游节点,输出完整内容") + yield {"__final__": True, "result": output} + return + + # 有引用直接上游节点,只输出该引用之后的部分(后缀) + logger.info(f"节点 {self.node_id} 检测到直接上游节点引用,只输出后缀部分(从索引 {upstream_ref_index + 1} 开始)") + + # 收集后缀部分 + suffix_parts = [] + for i in range(upstream_ref_index + 1, len(parts)): + part = parts[i] + + if part["type"] == "static": + # 静态文本 + suffix_parts.append(part["content"]) + + elif part["type"] == "dynamic": + # 其他动态引用(如果有多个引用) + node_id = part["node_id"] + field = part["field"] + + # 从 streaming_buffer 或 node_outputs 读取 + streaming_buffer = state.get("streaming_buffer", {}) + if node_id in streaming_buffer: + buffer_data = streaming_buffer[node_id] + content = buffer_data.get("full_content", "") + else: + node_outputs = state.get("node_outputs", {}) + runtime_vars = state.get("runtime_vars", {}) + + content = "" + if node_id in node_outputs: + node_output = node_outputs[node_id] + if isinstance(node_output, dict): + content = str(node_output.get(field, "")) + elif node_id in runtime_vars: + runtime_var = runtime_vars[node_id] + if isinstance(runtime_var, dict): + content = str(runtime_var.get(field, "")) + + suffix_parts.append(content) + + # 拼接后缀 + suffix = "".join(suffix_parts) + + # 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀) + full_output = self._render_template(output_template, state) + + if suffix: + logger.info(f"节点 {self.node_id} 输出后缀: '{suffix[:50]}...' (长度: {len(suffix)})") + # 一次性输出后缀(作为单个 chunk) + # 注意:不要直接 yield 字符串,因为 base_node 会逐字符处理 + # 而是通过 writer 直接发送 + from langgraph.config import get_stream_writer + writer = get_stream_writer() + writer({ + "node_id": self.node_id, + "chunk": suffix, + "full_content": full_output, # full_content 是完整的渲染结果(前缀+LLM+后缀) + "chunk_index": 1, + "is_suffix": True + }) + logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀,full_content 长度: {len(full_output)}") + else: + logger.info(f"节点 {self.node_id} 没有后缀需要输出") + + # 统计信息 node_outputs = state.get("node_outputs", {}) total_nodes = len(node_outputs) - logger.info(f"节点 {self.node_id} (End) 执行完成(流式),共执行 {total_nodes} 个节点") + logger.info(f"节点 {self.node_id} (End) 执行完成(流式),共执行了 {total_nodes} 个节点") - # yield 完成标记 - yield {"__final__": True, "result": output} + # yield 完成标记(包含完整输出) + yield {"__final__": True, "result": full_output} diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index bac707d7..56292b81 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -213,18 +213,44 @@ class LLMNode(BaseNode): Yields: 文本片段(chunk)或完成标记 """ + from langgraph.config import get_stream_writer + llm, prompt_or_messages = self._prepare_llm(state, True) logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)") logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}") + # 检查是否有注入的 End 节点前缀配置 + writer = get_stream_writer() + end_prefix = getattr(self, '_end_node_prefix', None) + + logger.info(f"[LLM前缀] 节点 {self.node_id} 检查前缀配置: {end_prefix is not None}") + if end_prefix: + logger.info(f"[LLM前缀] 前缀内容: '{end_prefix}'") + + if end_prefix: + # 渲染前缀(可能包含其他变量) + try: + rendered_prefix = self._render_template(end_prefix, state) + logger.info(f"节点 {self.node_id} 提前发送 End 节点前缀: '{rendered_prefix[:50]}...'") + + # 提前发送 End 节点的前缀 + writer({ + "node_id": "end", # 标记为 end 节点的输出 + "chunk": rendered_prefix, + "full_content": rendered_prefix, + "chunk_index": 0, + "is_prefix": True # 标记这是前缀 + }) + except Exception as e: + logger.warning(f"渲染/发送 End 节点前缀失败: {e}") + # 累积完整响应 full_response = "" last_chunk = None chunk_count = 0 # 调用 LLM(流式,支持字符串或消息列表) - # 注意:astream 方法本身就是流式的,不需要额外配置 async for chunk in llm.astream(prompt_or_messages): # 提取内容 if hasattr(chunk, 'content'): @@ -238,9 +264,8 @@ class LLMNode(BaseNode): last_chunk = chunk chunk_count += 1 - # logger.debug(f"节点 {self.node_id} LLM chunk #{chunk_count}: {content[:50]}...") # 流式返回每个文本片段 - yield content #AIMessage(content=content) + yield content logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}") From 43a427bac775c7ac05f5f980d373e46ab9bf1b67 Mon Sep 17 00:00:00 2001 From: Mark Date: Sat, 20 Dec 2025 17:37:36 +0800 Subject: [PATCH 4/5] [modify] llm & end logic --- api/app/core/workflow/executor.py | 30 +++++++++++++++++------- api/app/core/workflow/nodes/base_node.py | 10 ++++++++ api/app/core/workflow/nodes/end/node.py | 1 + api/app/core/workflow/nodes/llm/node.py | 3 ++- 4 files changed, 35 insertions(+), 9 deletions(-) diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index 992a8e1a..db4fa626 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -93,18 +93,19 @@ class WorkflowExecutor: - def _analyze_end_node_prefixes(self) -> dict[str, str]: + def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]: """分析 End 节点的前缀配置 检查每个 End 节点的模板,找到直接上游节点的引用, 提取该引用之前的前缀部分。 Returns: - 字典:{上游节点ID: End节点前缀} + 元组:({上游节点ID: End节点前缀}, {与End相邻且被引用的节点ID集合}) """ import re prefixes = {} + adjacent_and_referenced = set() # 记录与 End 节点相邻且被引用的节点 # 找到所有 End 节点 end_nodes = [node for node in self.nodes if node.get("type") == "end"] @@ -146,6 +147,9 @@ class WorkflowExecutor: logger.info(f"[前缀分析] ✅ 找到直接上游节点 {referenced_node_id} 的引用,前缀: '{prefix}'") + # 标记这个节点为"相邻且被引用" + adjacent_and_referenced.add(referenced_node_id) + if prefix: prefixes[referenced_node_id] = prefix logger.info(f"✅ [前缀分析] 为节点 {referenced_node_id} 配置前缀: '{prefix[:50]}...'") @@ -154,7 +158,8 @@ class WorkflowExecutor: break logger.info(f"[前缀分析] 最终配置: {prefixes}") - return prefixes + logger.info(f"[前缀分析] 与 End 相邻且被引用的节点: {adjacent_and_referenced}") + return prefixes, adjacent_and_referenced def build_graph(self,stream=False) -> CompiledStateGraph: """构建 LangGraph @@ -164,8 +169,8 @@ class WorkflowExecutor: """ logger.info(f"开始构建工作流图: execution_id={self.execution_id}") - # 分析 End 节点的前缀配置 - end_prefixes = self._analyze_end_node_prefixes() if stream else {} + # 分析 End 节点的前缀配置和相邻且被引用的节点 + end_prefixes, adjacent_and_referenced = self._analyze_end_node_prefixes() if stream else ({}, set()) # 1. 创建状态图 workflow = StateGraph(WorkflowState) @@ -193,6 +198,12 @@ class WorkflowExecutor: node_instance._end_node_prefix = end_prefixes[node_id] logger.info(f"为节点 {node_id} 注入 End 前缀配置") + # 如果是流式模式,标记节点是否与 End 相邻且被引用 + if stream: + node_instance._is_adjacent_to_end = node_id in adjacent_and_referenced + if node_id in adjacent_and_referenced: + logger.info(f"节点 {node_id} 标记为与 End 相邻且被引用") + # 包装节点的 run 方法 # 使用函数工厂避免闭包问题 if stream: @@ -401,13 +412,16 @@ class WorkflowExecutor: if mode == "custom": # Handle custom streaming events (chunks from nodes via stream writer) chunk_count += 1 - logger.info(f"[CUSTOM] ✅ 收到 chunk #{chunk_count} from {data.get('node_id')}") + event_type = data.get("type", "node_chunk") # 默认为 node_chunk + logger.info(f"[CUSTOM] ✅ 收到 {event_type} #{chunk_count} from {data.get('node_id')}") yield { - "type": "node_chunk", + "type": event_type, # "message" or "node_chunk" "node_id": data.get("node_id"), "chunk": data.get("chunk"), "full_content": data.get("full_content"), - "chunk_index": data.get("chunk_index") + "chunk_index": data.get("chunk_index"), + "is_prefix": data.get("is_prefix"), + "is_suffix": data.get("is_suffix") } elif mode == "debug": diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index f2f18404..25fdd29e 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -240,6 +240,14 @@ class BaseNode(ABC): # End nodes CAN send chunks (for suffix), but only after LLM content is_end_node = self.node_type == "end" + # Check if this node is adjacent to End node (for message type) + is_adjacent_to_end = getattr(self, '_is_adjacent_to_end', False) + + # Determine chunk type: "message" for End and adjacent nodes, "node_chunk" for others + chunk_type = "message" if (is_end_node or is_adjacent_to_end) else "node_chunk" + + logger.debug(f"节点 {self.node_id} chunk 类型: {chunk_type} (is_end={is_end_node}, adjacent={is_adjacent_to_end})") + # Accumulate complete result (for final wrapping) chunks = [] final_result = None @@ -267,6 +275,7 @@ class BaseNode(ABC): # 1. Send via stream writer (for real-time client updates) writer({ + "type": chunk_type, # "message" or "node_chunk" "node_id": self.node_id, "chunk": item, "full_content": full_content, @@ -294,6 +303,7 @@ class BaseNode(ABC): # Send chunks for all nodes writer({ + "type": chunk_type, # "message" or "node_chunk" "node_id": self.node_id, "chunk": chunk_str, "full_content": full_content, diff --git a/api/app/core/workflow/nodes/end/node.py b/api/app/core/workflow/nodes/end/node.py index f47f3c1e..8540cf9d 100644 --- a/api/app/core/workflow/nodes/end/node.py +++ b/api/app/core/workflow/nodes/end/node.py @@ -236,6 +236,7 @@ class EndNode(BaseNode): from langgraph.config import get_stream_writer writer = get_stream_writer() writer({ + "type": "message", # End 节点的输出使用 message 类型 "node_id": self.node_id, "chunk": suffix, "full_content": full_output, # full_content 是完整的渲染结果(前缀+LLM+后缀) diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index 56292b81..8f809923 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -234,8 +234,9 @@ class LLMNode(BaseNode): rendered_prefix = self._render_template(end_prefix, state) logger.info(f"节点 {self.node_id} 提前发送 End 节点前缀: '{rendered_prefix[:50]}...'") - # 提前发送 End 节点的前缀 + # 提前发送 End 节点的前缀(使用 "message" 类型) writer({ + "type": "message", # End 相关的内容都是 message 类型 "node_id": "end", # 标记为 end 节点的输出 "chunk": rendered_prefix, "full_content": rendered_prefix, From fafbe72ce2ea57c4678204110a0382c33924a7cb Mon Sep 17 00:00:00 2001 From: Mark Date: Sat, 20 Dec 2025 17:45:58 +0800 Subject: [PATCH 5/5] [modify] sse format --- api/app/controllers/app_controller.py | 20 ++++- api/app/controllers/workflow_controller.py | 42 +++++++-- api/app/core/workflow/executor.py | 100 +++++++++++++++------ api/app/services/workflow_service.py | 64 ++++--------- 4 files changed, 139 insertions(+), 87 deletions(-) diff --git a/api/app/controllers/app_controller.py b/api/app/controllers/app_controller.py index a92cfab2..29656608 100644 --- a/api/app/controllers/app_controller.py +++ b/api/app/controllers/app_controller.py @@ -583,15 +583,27 @@ async def draft_run( ) async def event_generator(): - """工作流事件生成器""" - - # 调用多智能体服务的流式方法 + """工作流事件生成器 + + 将事件转换为标准 SSE 格式: + event: + data: + """ + import json + + # 调用工作流服务的流式方法 async for event in workflow_service.run_stream( app_id=app_id, payload=payload, config=config ): - yield event + # 提取事件类型和数据 + event_type = event.get("event", "message") + event_data = event.get("data", {}) + + # 转换为标准 SSE 格式(字符串) + sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n" + yield sse_message return StreamingResponse( event_generator(), diff --git a/api/app/controllers/workflow_controller.py b/api/app/controllers/workflow_controller.py index 9ccfa858..91c21392 100644 --- a/api/app/controllers/workflow_controller.py +++ b/api/app/controllers/workflow_controller.py @@ -471,7 +471,20 @@ async def run_workflow( import json async def event_generator(): - """生成 SSE 事件""" + """生成 SSE 事件 + + SSE 格式: + event: + data: + + 支持的事件类型: + - workflow_start: 工作流开始 + - workflow_end: 工作流结束 + - node_start: 节点开始执行 + - node_end: 节点执行完成 + - node_chunk: 中间节点的流式输出 + - message: 最终消息的流式输出(End 节点及其相邻节点) + """ try: async for event in service.run_workflow( app_id=app_id, @@ -480,19 +493,30 @@ async def run_workflow( conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None, stream=True ): - # 转换为 SSE 格式 - yield f"data: {json.dumps(event)}\n\n" + # 提取事件类型和数据 + event_type = event.get("event", "message") + event_data = event.get("data", {}) + + # 转换为标准 SSE 格式(字符串) + # event: + # data: + sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n" + yield sse_message + except Exception as e: logger.error(f"流式执行异常: {e}", exc_info=True) - error_event = { - "type": "error", - "error": str(e) - } - yield f"data: {json.dumps(error_event)}\n\n" + # 发送错误事件 + sse_error = f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n" + yield sse_error return StreamingResponse( event_generator(), - media_type="text/event-stream" + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no" # 禁用 nginx 缓冲 + } ) else: # 非流式执行 diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index db4fa626..029de97f 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -375,17 +375,32 @@ class WorkflowExecutor: 使用多个 stream_mode 来获取: 1. "updates" - 节点的 state 更新和流式 chunk 2. "debug" - 节点执行的详细信息(开始/完成时间) + 3. "custom" - 自定义流式数据(chunks) Args: input_data: 输入数据 Yields: - 流式事件 + 流式事件,格式: + { + "event": "workflow_start" | "workflow_end" | "node_start" | "node_end" | "node_chunk" | "message", + "data": {...} + } """ logger.info(f"开始执行工作流(流式): execution_id={self.execution_id}") # 记录开始时间 start_time = datetime.datetime.now() + + # 发送 workflow_start 事件 + yield { + "event": "workflow_start", + "data": { + "execution_id": self.execution_id, + "workspace_id": self.workspace_id, + "timestamp": start_time.isoformat() + } + } # 1. 构建图 graph = self.build_graph(True) @@ -396,6 +411,8 @@ class WorkflowExecutor: # 3. Execute workflow try: chunk_count = 0 + final_state = None + async for event in graph.astream( initial_state, stream_mode=["updates", "debug", "custom"], # Use updates + debug + custom mode @@ -412,16 +429,19 @@ class WorkflowExecutor: if mode == "custom": # Handle custom streaming events (chunks from nodes via stream writer) chunk_count += 1 - event_type = data.get("type", "node_chunk") # 默认为 node_chunk + event_type = data.get("type", "node_chunk") # "message" or "node_chunk" logger.info(f"[CUSTOM] ✅ 收到 {event_type} #{chunk_count} from {data.get('node_id')}") + yield { - "type": event_type, # "message" or "node_chunk" - "node_id": data.get("node_id"), - "chunk": data.get("chunk"), - "full_content": data.get("full_content"), - "chunk_index": data.get("chunk_index"), - "is_prefix": data.get("is_prefix"), - "is_suffix": data.get("is_suffix") + "event": event_type, # "message" or "node_chunk" + "data": { + "node_id": data.get("node_id"), + "chunk": data.get("chunk"), + "full_content": data.get("full_content"), + "chunk_index": data.get("chunk_index"), + "is_prefix": data.get("is_prefix"), + "is_suffix": data.get("is_suffix") + } } elif mode == "debug": @@ -438,12 +458,15 @@ class WorkflowExecutor: conversation_id = variables_sys.get("conversation_id") execution_id = variables_sys.get("execution_id") logger.info(f"[DEBUG] Node starts execution: {node_name}") + yield { - "type": "node_start", - "node_id": node_name, - "conversation_id": conversation_id, - "execution_id": execution_id, - "timestamp": data.get("timestamp") + "event": "node_start", + "data": { + "node_id": node_name, + "conversation_id": conversation_id, + "execution_id": execution_id, + "timestamp": data.get("timestamp") + } } elif event_type == "task_result": # Node execution completed @@ -454,19 +477,38 @@ class WorkflowExecutor: conversation_id = variables_sys.get("conversation_id") execution_id = variables_sys.get("execution_id") logger.info(f"[DEBUG] Node execution completed: {node_name}") + yield { - "type": "node_complete", - "node_id": node_name, - "conversation_id": conversation_id, - "execution_id": execution_id, - "timestamp": data.get("timestamp") + "event": "node_end", + "data": { + "node_id": node_name, + "conversation_id": conversation_id, + "execution_id": execution_id, + "timestamp": data.get("timestamp") + } } elif mode == "updates": - # Handle state updates + # Handle state updates - store final state logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())}") + final_state = data - logger.info(f"Workflow execution completed (streaming), total chunks: {chunk_count}") + # 计算耗时 + end_time = datetime.datetime.now() + elapsed_time = (end_time - start_time).total_seconds() + + logger.info(f"Workflow execution completed (streaming), total chunks: {chunk_count}, elapsed: {elapsed_time:.2f}s") + + # 发送 workflow_end 事件 + yield { + "event": "workflow_end", + "data": { + "execution_id": self.execution_id, + "status": "completed", + "elapsed_time": elapsed_time, + "timestamp": end_time.isoformat() + } + } except Exception as e: # 计算耗时(即使失败也记录) @@ -474,13 +516,17 @@ class WorkflowExecutor: elapsed_time = (end_time - start_time).total_seconds() logger.error(f"工作流执行失败: execution_id={self.execution_id}, error={e}", exc_info=True) + + # 发送 workflow_end 事件(失败) yield { - "status": "failed", - "error": str(e), - "output": None, - "node_outputs": {}, - "elapsed_time": elapsed_time, - "token_usage": None + "event": "workflow_end", + "data": { + "execution_id": self.execution_id, + "status": "failed", + "error": str(e), + "elapsed_time": elapsed_time, + "timestamp": end_time.isoformat() + } } diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index b48edfdd..87f06c96 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -597,13 +597,7 @@ class WorkflowService: # 更新状态为运行中 self.update_execution_status(execution.execution_id, "running") - # 发送开始事件 - yield format_sse_message("workflow_start", { - "execution_id": execution.execution_id, - "conversation_id_uuid": str(conversation_id_uuid), - }) - - # 调用流式执行 + # 调用流式执行(executor 会发送 workflow_start 和 workflow_end 事件) async for event in self._run_workflow_stream( workflow_config=workflow_config_dict, input_data=input_data, @@ -611,16 +605,8 @@ class WorkflowService: workspace_id="", user_id=payload.user_id ): - # 清理事件数据,移除不可序列化的对象 - cleaned_event = self._clean_event_for_json(event) - # 转换为 SSE 格式 - yield f"data: {json.dumps(cleaned_event)}\n\n" - - # 发送完成事件 - yield format_sse_message("workflow_end", { - "execution_id": execution.execution_id, - "conversation_id_uuid": str(conversation_id_uuid), - }) + # 直接转发 executor 的事件(已经是正确的格式) + yield event except Exception as e: logger.error(f"工作流流式执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True) @@ -630,7 +616,13 @@ class WorkflowService: error_message=str(e) ) # 发送错误事件 - yield f"data: {json.dumps({'type': 'error', 'execution_id': execution.execution_id, 'error': str(e)})}\n\n" + yield { + "event": "error", + "data": { + "execution_id": execution.execution_id, + "error": str(e) + } + } async def run_workflow( self, @@ -801,13 +793,11 @@ class WorkflowService: user_id: 用户 ID Yields: - 流式事件 + 流式事件(格式:{"event": "", "data": {...}}) """ from app.core.workflow.executor import execute_workflow_stream try: - output_data = {} - async for event in execute_workflow_stream( workflow_config=workflow_config, input_data=input_data, @@ -815,31 +805,9 @@ class WorkflowService: workspace_id=workspace_id, user_id=user_id ): - # 转发事件 + # 直接转发事件(executor 已经返回正确格式) yield event - # 收集输出数据 - # if event.get("type") == "node_complete": - # node_data = event.get("data", {}) - # node_outputs = node_data.get("node_outputs", {}) - # output_data.update(node_outputs) - # - # # 处理完成事件 - # if event.get("type") == "workflow_complete": - # self.update_execution_status( - # execution_id, - # "completed", - # output_data=output_data - # ) - # - # # 处理错误事件 - # if event.get("type") == "workflow_error": - # self.update_execution_status( - # execution_id, - # "failed", - # error_message=event.get("error") - # ) - except Exception as e: logger.error(f"工作流流式执行失败: execution_id={execution_id}, error={e}", exc_info=True) self.update_execution_status( @@ -848,9 +816,11 @@ class WorkflowService: error_message=str(e) ) yield { - "type": "workflow_error", - "execution_id": execution_id, - "error": str(e) + "event": "error", + "data": { + "execution_id": execution_id, + "error": str(e) + } }