diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index 8d67dd1e..992a8e1a 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -87,11 +87,75 @@ class WorkflowExecutor: "workspace_id": self.workspace_id, "user_id": self.user_id, "error": None, - "error_node": None + "error_node": None, + "streaming_buffer": {} # 流式缓冲区 } + def _analyze_end_node_prefixes(self) -> dict[str, str]: + """分析 End 节点的前缀配置 + + 检查每个 End 节点的模板,找到直接上游节点的引用, + 提取该引用之前的前缀部分。 + + Returns: + 字典:{上游节点ID: End节点前缀} + """ + import re + + prefixes = {} + + # 找到所有 End 节点 + end_nodes = [node for node in self.nodes if node.get("type") == "end"] + logger.info(f"[前缀分析] 找到 {len(end_nodes)} 个 End 节点") + + for end_node in end_nodes: + end_node_id = end_node.get("id") + output_template = end_node.get("config", {}).get("output") + + logger.info(f"[前缀分析] End 节点 {end_node_id} 模板: {output_template}") + + if not output_template: + continue + + # 找到所有直接连接到 End 节点的上游节点 + direct_upstream_nodes = [] + for edge in self.edges: + if edge.get("target") == end_node_id: + source_node_id = edge.get("source") + direct_upstream_nodes.append(source_node_id) + + logger.info(f"[前缀分析] End 节点的直接上游节点: {direct_upstream_nodes}") + + # 查找模板中引用了哪些节点 + # 匹配 {{node_id.xxx}} 或 {{ node_id.xxx }} 格式(支持空格) + pattern = r'\{\{\s*([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\s*\}\}' + matches = list(re.finditer(pattern, output_template)) + + logger.info(f"[前缀分析] 模板中找到 {len(matches)} 个节点引用") + + # 找到第一个直接上游节点的引用 + for match in matches: + referenced_node_id = match.group(1) + logger.info(f"[前缀分析] 检查引用: {referenced_node_id}") + + if referenced_node_id in direct_upstream_nodes: + # 这是直接上游节点的引用,提取前缀 + prefix = output_template[:match.start()] + + logger.info(f"[前缀分析] ✅ 找到直接上游节点 {referenced_node_id} 的引用,前缀: '{prefix}'") + + if prefix: + prefixes[referenced_node_id] = prefix + logger.info(f"✅ [前缀分析] 为节点 {referenced_node_id} 配置前缀: '{prefix[:50]}...'") + + # 只处理第一个直接上游节点的引用 + break + + logger.info(f"[前缀分析] 最终配置: {prefixes}") + return prefixes + def build_graph(self,stream=False) -> CompiledStateGraph: """构建 LangGraph @@ -99,6 +163,9 @@ class WorkflowExecutor: 编译后的状态图 """ logger.info(f"开始构建工作流图: execution_id={self.execution_id}") + + # 分析 End 节点的前缀配置 + end_prefixes = self._analyze_end_node_prefixes() if stream else {} # 1. 创建状态图 workflow = StateGraph(WorkflowState) @@ -120,6 +187,12 @@ class WorkflowExecutor: # 创建节点实例(现在 start 和 end 也会被创建) node_instance = NodeFactory.create_node(node, self.workflow_config) if node_instance: + # 如果是流式模式,且节点有 End 前缀配置,注入配置 + if stream and node_id in end_prefixes: + # 将 End 前缀配置注入到节点实例 + node_instance._end_node_prefix = end_prefixes[node_id] + logger.info(f"为节点 {node_id} 注入 End 前缀配置") + # 包装节点的 run 方法 # 使用函数工厂避免闭包问题 if stream: @@ -309,29 +382,48 @@ class WorkflowExecutor: # 2. 初始化状态(自动注入系统变量) initial_state = self._prepare_initial_state(input_data) - # 3. 执行工作流 + # 3. Execute workflow try: chunk_count = 0 async for event in graph.astream( initial_state, - stream_mode=["updates", "debug"], + stream_mode=["updates", "debug", "custom"], # Use updates + debug + custom mode ): - mode, data = event + # event should be a tuple: (mode, data) + # But let's handle both cases + if isinstance(event, tuple) and len(event) == 2: + mode, data = event + else: + # Unexpected format, log and skip + logger.warning(f"[STREAM] Unexpected event format: {type(event)}, value: {event}") + continue - if mode == "debug": - # 处理调试信息(节点执行状态) + if mode == "custom": + # Handle custom streaming events (chunks from nodes via stream writer) + chunk_count += 1 + logger.info(f"[CUSTOM] ✅ 收到 chunk #{chunk_count} from {data.get('node_id')}") + yield { + "type": "node_chunk", + "node_id": data.get("node_id"), + "chunk": data.get("chunk"), + "full_content": data.get("full_content"), + "chunk_index": data.get("chunk_index") + } + + elif mode == "debug": + # Handle debug information (node execution status) event_type = data.get("type") payload = data.get("payload", {}) node_name = payload.get("name") if event_type == "task": - # 节点开始执行 + # Node starts execution inputv = payload.get("input", {}) variables = inputv.get("variables", {}) variables_sys = variables.get("sys", {}) conversation_id = variables_sys.get("conversation_id") execution_id = variables_sys.get("execution_id") - logger.info(f"[DEBUG] 节点开始执行: {node_name}") + logger.info(f"[DEBUG] Node starts execution: {node_name}") yield { "type": "node_start", "node_id": node_name, @@ -340,16 +432,16 @@ class WorkflowExecutor: "timestamp": data.get("timestamp") } elif event_type == "task_result": - # 节点执行完成 + # Node execution completed result = payload.get("result", {}) inputv = result.get("input", {}) variables = inputv.get("variables", {}) variables_sys = variables.get("sys", {}) conversation_id = variables_sys.get("conversation_id") execution_id = variables_sys.get("execution_id") - logger.info(f"[DEBUG] 节点执行完成: {node_name}") + logger.info(f"[DEBUG] Node execution completed: {node_name}") yield { - "type": "node_end", + "type": "node_complete", "node_id": node_name, "conversation_id": conversation_id, "execution_id": execution_id, @@ -357,27 +449,10 @@ class WorkflowExecutor: } elif mode == "updates": - # 处理 state 更新 - # data 是一个字典,key 是节点 ID,value 是 state 更新或 chunk - print("="*50) - print(data) - print("-"*50) - for node_id, update in data.items(): - if isinstance(update, dict) and update.get("type") == "chunk": - # 这是流式 chunk,转发给客户端 - chunk_count += 1 - logger.debug(f"[UPDATE] 收到 chunk #{chunk_count} from {node_id}: {update.get('content')[:50]}...") - yield { - "type": "node_chunk", - "node_id": update.get("node_id"), - "chunk": update.get("content"), - "full_content": update.get("full_content") - } - else: - logger.debug(f"[UPDATE] 收到 state 更新 from {node_id}") - # 其他情况(state 更新)会被 LangGraph 自动合并到 state + # Handle state updates + logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())}") - logger.info(f"工作流执行完成(流式),总 chunks: {chunk_count}") + logger.info(f"Workflow execution completed (streaming), total chunks: {chunk_count}") except Exception as e: # 计算耗时(即使失败也记录) diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index 1d6f1c15..f2f18404 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -10,6 +10,7 @@ from abc import ABC, abstractmethod from typing import Any, TypedDict, Annotated from operator import add from langchain_core.messages import AnyMessage, HumanMessage, AIMessage +from langgraph.config import get_stream_writer from app.core.workflow.variable_pool import VariablePool @@ -43,6 +44,10 @@ class WorkflowState(TypedDict): # 错误信息(用于错误边) error: str | None error_node: str | None + + # 流式缓冲区(存储节点的实时流式输出) + # 格式:{node_id: {"chunks": [...], "full_content": "..."}} + streaming_buffer: Annotated[dict[str, Any], lambda x, y: {**x, **y}] class BaseNode(ABC): @@ -201,23 +206,25 @@ class BaseNode(ABC): return self._wrap_error(str(e), elapsed_time, state) async def run_stream(self, state: WorkflowState): - """执行节点(带错误处理和输出包装,流式) + """Execute node with error handling and output wrapping (streaming) - 这个方法由 Executor 调用,负责: - 1. 时间统计 - 2. 调用节点的 execute_stream() 方法 - 3. 将业务数据包装成标准输出格式 - 4. 错误处理 + This method is called by the Executor and is responsible for: + 1. Time tracking + 2. Calling the node's execute_stream() method + 3. Using LangGraph's stream writer to send chunks + 4. Updating streaming buffer in state for downstream nodes + 5. Wrapping business data into standard output format + 6. Error handling - 注意:在流式模式下,我们需要: - - yield 中间的 chunk 事件(用于实时显示) - - 最后 yield 一个包含 state 更新的字典(LangGraph 会合并到 state) + Special handling for End nodes: + - End nodes don't send chunks via writer (prefix and LLM content already sent) + - End nodes only yield suffix for final result assembly Args: - state: 工作流状态 + state: Workflow state Yields: - 标准化的流式事件和最终的 state 更新 + State updates with streaming buffer and final result """ import time @@ -226,63 +233,102 @@ class BaseNode(ABC): try: timeout = self.get_timeout() - # 累积完整结果(用于最后的包装) + # Get LangGraph's stream writer for sending custom data + writer = get_stream_writer() + + # Check if this is an End node + # End nodes CAN send chunks (for suffix), but only after LLM content + is_end_node = self.node_type == "end" + + # Accumulate complete result (for final wrapping) chunks = [] final_result = None + chunk_count = 0 - # 使用异步生成器包装,支持超时 - async def stream_with_timeout(): - nonlocal final_result - loop_start = asyncio.get_event_loop().time() + # Stream chunks in real-time + loop_start = asyncio.get_event_loop().time() + + async for item in self.execute_stream(state): + # Check timeout + if asyncio.get_event_loop().time() - loop_start > timeout: + raise TimeoutError() - async for item in self.execute_stream(state): - # 检查超时 - if asyncio.get_event_loop().time() - loop_start > timeout: - raise TimeoutError() + # Check if it's a completion marker + if isinstance(item, dict) and item.get("__final__"): + final_result = item["result"] + elif isinstance(item, str): + # String is a chunk + chunk_count += 1 + chunks.append(item) + full_content = "".join(chunks) - # 检查是否是完成标记 - if isinstance(item, dict) and item.get("__final__"): - final_result = item["result"] - elif isinstance(item, str): - # 字符串是 chunk - # print("="*50) - # print(item) - # print("-"*50) - chunks.append(item) + # Send chunks for all nodes (including End nodes for suffix) + logger.debug(f"节点 {self.node_id} 发送 chunk #{chunk_count}: {item[:50]}...") + + # 1. Send via stream writer (for real-time client updates) + writer({ + "node_id": self.node_id, + "chunk": item, + "full_content": full_content, + "chunk_index": chunk_count + }) + + # 2. Update streaming buffer in state (for downstream nodes) + # Only non-End nodes need streaming buffer + if not is_end_node: yield { - "type": "chunk", - "node_id": self.node_id, - "content": item, - "full_content": "".join(chunks) + "streaming_buffer": { + self.node_id: { + "full_content": full_content, + "chunk_count": chunk_count, + "is_complete": False + } + } } - else: - # 其他类型也当作 chunk 处理 - chunks.append(str(item)) + else: + # Other types are also treated as chunks + chunk_count += 1 + chunk_str = str(item) + chunks.append(chunk_str) + full_content = "".join(chunks) + + # Send chunks for all nodes + writer({ + "node_id": self.node_id, + "chunk": chunk_str, + "full_content": full_content, + "chunk_index": chunk_count + }) + + # Only non-End nodes need streaming buffer + if not is_end_node: yield { - "type": "chunk", - "node_id": self.node_id, - "content": str(item), - "full_content": "".join(chunks) + "streaming_buffer": { + self.node_id: { + "full_content": full_content, + "chunk_count": chunk_count, + "is_complete": False + } + } } - async for chunk_event in stream_with_timeout(): - yield chunk_event - elapsed_time = time.time() - start_time - # 提取处理后的输出(调用子类的 _extract_output) + logger.info(f"节点 {self.node_id} 流式执行完成,耗时: {elapsed_time:.2f}s, chunks: {chunk_count}") + + # Extract processed output (call subclass's _extract_output) extracted_output = self._extract_output(final_result) - # 包装最终结果 + # Wrap final result final_output = self._wrap_output(final_result, elapsed_time, state) - # 将提取后的输出存储到运行时变量中(供后续节点快速访问) + # Store extracted output in runtime variables (for quick access by subsequent nodes) if isinstance(extracted_output, dict): runtime_var = extracted_output else: runtime_var = {"output": extracted_output} - # 构建完整的 state 更新(包含 node_outputs 和 runtime_vars) + # Build complete state update (including node_outputs, runtime_vars, and final streaming buffer) state_update = { **final_output, "runtime_vars": { @@ -290,13 +336,24 @@ class BaseNode(ABC): } } - # 最后 yield 纯粹的 state 更新(LangGraph 会合并到 state 中) + # Add streaming buffer for non-End nodes + if not is_end_node: + state_update["streaming_buffer"] = { + self.node_id: { + "full_content": "".join(chunks), + "chunk_count": chunk_count, + "is_complete": True # Mark as complete + } + } + + # Finally yield state update + # LangGraph will merge this into state yield state_update except TimeoutError: elapsed_time = time.time() - start_time - logger.error(f"节点 {self.node_id} 执行超时({timeout}秒)") - error_output = self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state) + logger.error(f"节点 {self.node_id} 执行超时 ({timeout}s)") + error_output = self._wrap_error(f"节点执行超时 ({timeout}s)", elapsed_time, state) yield error_output except Exception as e: elapsed_time = time.time() - start_time diff --git a/api/app/core/workflow/nodes/end/node.py b/api/app/core/workflow/nodes/end/node.py index cba0d649..f47f3c1e 100644 --- a/api/app/core/workflow/nodes/end/node.py +++ b/api/app/core/workflow/nodes/end/node.py @@ -5,6 +5,8 @@ End 节点实现 """ import logging +import re +import asyncio from app.core.workflow.nodes.base_node import BaseNode, WorkflowState @@ -15,6 +17,7 @@ class EndNode(BaseNode): """End 节点 工作流的结束节点,根据配置的模板输出最终结果。 + 支持实时流式输出:如果模板引用了上游节点的输出,会实时监听其流式缓冲区。 """ async def execute(self, state: WorkflowState) -> str: @@ -45,42 +48,209 @@ class EndNode(BaseNode): return output + def _extract_referenced_nodes(self, template: str) -> list[str]: + """从模板中提取引用的节点 ID + + 例如:'结果:{{llm_qa.output}}' -> ['llm_qa'] + + Args: + template: 模板字符串 + + Returns: + 引用的节点 ID 列表 + """ + # 匹配 {{node_id.xxx}} 格式 + pattern = r'\{\{([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\}\}' + matches = re.findall(pattern, template) + return list(set(matches)) # 去重 + + def _parse_template_parts(self, template: str, state: WorkflowState) -> list[dict]: + """解析模板,分离静态文本和动态引用 + + 例如:'你好 {{llm.output}}, 这是后缀' + 返回:[ + {"type": "static", "content": "你好 "}, + {"type": "dynamic", "node_id": "llm", "field": "output"}, + {"type": "static", "content": ", 这是后缀"} + ] + + Args: + template: 模板字符串 + state: 工作流状态 + + Returns: + 模板部分列表 + """ + import re + + parts = [] + last_end = 0 + + # 匹配 {{xxx}} 或 {{ xxx }} 格式(支持空格) + pattern = r'\{\{\s*([^}]+?)\s*\}\}' + + for match in re.finditer(pattern, template): + start, end = match.span() + + # 添加前面的静态文本 + if start > last_end: + static_text = template[last_end:start] + if static_text: + parts.append({"type": "static", "content": static_text}) + + # 解析动态引用 + ref = match.group(1).strip() + + # 检查是否是节点引用(如 llm.output 或 llm_qa.output) + if '.' in ref: + node_id, field = ref.split('.', 1) + parts.append({ + "type": "dynamic", + "node_id": node_id, + "field": field, + "raw": ref + }) + else: + # 其他引用(如 {{var.xxx}}),当作静态处理 + # 直接渲染这部分 + rendered = self._render_template(f"{{{{{ref}}}}}", state) + parts.append({"type": "static", "content": rendered}) + + last_end = end + + # 添加最后的静态文本 + if last_end < len(template): + static_text = template[last_end:] + if static_text: + parts.append({"type": "static", "content": static_text}) + + return parts + async def execute_stream(self, state: WorkflowState): """流式执行 end 节点业务逻辑 - 当 end 节点前面是 LLM 节点时,流式输出其内容。 + 智能输出策略: + 1. 检测模板中是否引用了直接上游节点 + 2. 如果引用了,只输出该引用**之后**的部分(后缀) + 3. 前缀和引用内容已经在上游节点流式输出时发送了 + + 示例:'{{start.test}}hahaha {{ llm_qa.output }} lalalalala a' + - 直接上游节点是 llm_qa + - 前缀 '{{start.test}}hahaha ' 已在 LLM 节点流式输出前发送 + - LLM 内容在 LLM 节点流式输出 + - End 节点只输出 ' lalalalala a'(后缀,一次性输出) Args: state: 工作流状态 Yields: - 文本片段(chunk)或完成标记 + 完成标记 """ logger.info(f"节点 {self.node_id} (End) 开始执行(流式)") # 获取配置的输出模板 output_template = self.config.get("output") - # 如果配置了输出模板,使用模板渲染 - if output_template: - output = self._render_template(output_template, state) - - # 检查输出中是否包含节点引用(如 {{llm_node.output}}) - # 如果包含,则逐字符流式输出 - if output: - # 逐字符流式输出 - for char in output: - yield char - else: + if not output_template: output = "工作流已完成" - for char in output: - yield char + yield {"__final__": True, "result": output} + return - # 统计信息(用于日志) + # 找到直接上游节点 + direct_upstream_nodes = [] + for edge in self.workflow_config.get("edges", []): + if edge.get("target") == self.node_id: + source_node_id = edge.get("source") + direct_upstream_nodes.append(source_node_id) + + logger.info(f"节点 {self.node_id} 的直接上游节点: {direct_upstream_nodes}") + + # 解析模板部分 + parts = self._parse_template_parts(output_template, state) + logger.info(f"节点 {self.node_id} 解析模板,共 {len(parts)} 个部分") + + # 找到第一个引用直接上游节点的动态引用 + upstream_ref_index = None + for i, part in enumerate(parts): + if part["type"] == "dynamic" and part["node_id"] in direct_upstream_nodes: + upstream_ref_index = i + logger.info(f"节点 {self.node_id} 找到直接上游节点 {part['node_id']} 的引用,索引: {i}") + break + + if upstream_ref_index is None: + # 没有引用直接上游节点,正常输出(渲染完整模板) + output = self._render_template(output_template, state) + logger.info(f"节点 {self.node_id} 没有引用直接上游节点,输出完整内容") + yield {"__final__": True, "result": output} + return + + # 有引用直接上游节点,只输出该引用之后的部分(后缀) + logger.info(f"节点 {self.node_id} 检测到直接上游节点引用,只输出后缀部分(从索引 {upstream_ref_index + 1} 开始)") + + # 收集后缀部分 + suffix_parts = [] + for i in range(upstream_ref_index + 1, len(parts)): + part = parts[i] + + if part["type"] == "static": + # 静态文本 + suffix_parts.append(part["content"]) + + elif part["type"] == "dynamic": + # 其他动态引用(如果有多个引用) + node_id = part["node_id"] + field = part["field"] + + # 从 streaming_buffer 或 node_outputs 读取 + streaming_buffer = state.get("streaming_buffer", {}) + if node_id in streaming_buffer: + buffer_data = streaming_buffer[node_id] + content = buffer_data.get("full_content", "") + else: + node_outputs = state.get("node_outputs", {}) + runtime_vars = state.get("runtime_vars", {}) + + content = "" + if node_id in node_outputs: + node_output = node_outputs[node_id] + if isinstance(node_output, dict): + content = str(node_output.get(field, "")) + elif node_id in runtime_vars: + runtime_var = runtime_vars[node_id] + if isinstance(runtime_var, dict): + content = str(runtime_var.get(field, "")) + + suffix_parts.append(content) + + # 拼接后缀 + suffix = "".join(suffix_parts) + + # 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀) + full_output = self._render_template(output_template, state) + + if suffix: + logger.info(f"节点 {self.node_id} 输出后缀: '{suffix[:50]}...' (长度: {len(suffix)})") + # 一次性输出后缀(作为单个 chunk) + # 注意:不要直接 yield 字符串,因为 base_node 会逐字符处理 + # 而是通过 writer 直接发送 + from langgraph.config import get_stream_writer + writer = get_stream_writer() + writer({ + "node_id": self.node_id, + "chunk": suffix, + "full_content": full_output, # full_content 是完整的渲染结果(前缀+LLM+后缀) + "chunk_index": 1, + "is_suffix": True + }) + logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀,full_content 长度: {len(full_output)}") + else: + logger.info(f"节点 {self.node_id} 没有后缀需要输出") + + # 统计信息 node_outputs = state.get("node_outputs", {}) total_nodes = len(node_outputs) - logger.info(f"节点 {self.node_id} (End) 执行完成(流式),共执行 {total_nodes} 个节点") + logger.info(f"节点 {self.node_id} (End) 执行完成(流式),共执行了 {total_nodes} 个节点") - # yield 完成标记 - yield {"__final__": True, "result": output} + # yield 完成标记(包含完整输出) + yield {"__final__": True, "result": full_output} diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index bac707d7..56292b81 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -213,18 +213,44 @@ class LLMNode(BaseNode): Yields: 文本片段(chunk)或完成标记 """ + from langgraph.config import get_stream_writer + llm, prompt_or_messages = self._prepare_llm(state, True) logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)") logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}") + # 检查是否有注入的 End 节点前缀配置 + writer = get_stream_writer() + end_prefix = getattr(self, '_end_node_prefix', None) + + logger.info(f"[LLM前缀] 节点 {self.node_id} 检查前缀配置: {end_prefix is not None}") + if end_prefix: + logger.info(f"[LLM前缀] 前缀内容: '{end_prefix}'") + + if end_prefix: + # 渲染前缀(可能包含其他变量) + try: + rendered_prefix = self._render_template(end_prefix, state) + logger.info(f"节点 {self.node_id} 提前发送 End 节点前缀: '{rendered_prefix[:50]}...'") + + # 提前发送 End 节点的前缀 + writer({ + "node_id": "end", # 标记为 end 节点的输出 + "chunk": rendered_prefix, + "full_content": rendered_prefix, + "chunk_index": 0, + "is_prefix": True # 标记这是前缀 + }) + except Exception as e: + logger.warning(f"渲染/发送 End 节点前缀失败: {e}") + # 累积完整响应 full_response = "" last_chunk = None chunk_count = 0 # 调用 LLM(流式,支持字符串或消息列表) - # 注意:astream 方法本身就是流式的,不需要额外配置 async for chunk in llm.astream(prompt_or_messages): # 提取内容 if hasattr(chunk, 'content'): @@ -238,9 +264,8 @@ class LLMNode(BaseNode): last_chunk = chunk chunk_count += 1 - # logger.debug(f"节点 {self.node_id} LLM chunk #{chunk_count}: {content[:50]}...") # 流式返回每个文本片段 - yield content #AIMessage(content=content) + yield content logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")