[fix] end stream output

This commit is contained in:
Mark
2025-12-24 12:37:50 +08:00
parent 9124a54b0f
commit 63d5047d21

View File

@@ -9,28 +9,29 @@ import re
import asyncio
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import NodeType
logger = logging.getLogger(__name__)
class EndNode(BaseNode):
"""End 节点
工作流的结束节点,根据配置的模板输出最终结果。
支持实时流式输出:如果模板引用了上游节点的输出,会实时监听其流式缓冲区。
"""
async def execute(self, state: WorkflowState) -> str:
"""执行 end 节点业务逻辑
Args:
state: 工作流状态
Returns:
最终输出字符串
"""
logger.info(f"节点 {self.node_id} (End) 开始执行")
# 获取配置的输出模板
output_template = self.config.get("output")
@@ -39,11 +40,11 @@ class EndNode(BaseNode):
output = self._render_template(output_template, state)
else:
output = "工作流已完成"
# 统计信息(用于日志)
node_outputs = state.get("node_outputs", {})
total_nodes = len(node_outputs)
logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
return output
@@ -127,24 +128,26 @@ class EndNode(BaseNode):
return parts
async def execute_stream(self, state: WorkflowState):
"""流式执行 end 节点业务逻辑
"""Execute End node business logic (streaming)
智能输出策略:
1. 检测模板中是否引用了直接上游节点
2. 如果引用了,只输出该引用**之后**的部分(后缀)
3. 前缀和引用内容已经在上游节点流式输出时发送了
Smart output strategy:
1. Check if template references a direct upstream LLM node
2. If yes, only output the part AFTER that reference (suffix)
3. Prefix and LLM content have already been sent during LLM node streaming
示例:'{{start.test}}hahaha {{ llm_qa.output }} lalalalala a'
- 直接上游节点是 llm_qa
- 前缀 '{{start.test}}hahaha ' 已在 LLM 节点流式输出前发送
- LLM 内容在 LLM 节点流式输出
- End 节点只输出 ' lalalalala a'(后缀,一次性输出)
Note: Only LLM nodes get this special treatment. Other node types output normally.
Example: '{{start.test}}hahaha {{ llm_qa.output }} lalalalala a'
- Direct upstream LLM node is llm_qa
- Prefix '{{start.test}}hahaha ' was sent before LLM node streaming
- LLM content was streamed during LLM node execution
- End node only outputs ' lalalalala a' (suffix, sent as one chunk)
Args:
state: 工作流状态
state: Workflow state
Yields:
完成标记
Completion marker
"""
logger.info(f"节点 {self.node_id} (End) 开始执行(流式)")
@@ -156,39 +159,45 @@ class EndNode(BaseNode):
yield {"__final__": True, "result": output}
return
# 找到直接上游节点
direct_upstream_nodes = []
# Find direct upstream LLM nodes
direct_upstream_llm_nodes = []
for edge in self.workflow_config.get("edges", []):
if edge.get("target") == self.node_id:
source_node_id = edge.get("source")
direct_upstream_nodes.append(source_node_id)
# Check if the source node is an LLM node
for node in self.workflow_config.get("nodes", []):
print("="*50)
logger.info(f"节点 {self.node_id} 的类型 {node.get("type")}")
if node.get("id") == source_node_id and node.get("type") == NodeType.LLM:
direct_upstream_llm_nodes.append(source_node_id)
break
logger.info(f"节点 {self.node_id} 的直接上游节点: {direct_upstream_nodes}")
logger.info(f"节点 {self.node_id} 的直接上游 LLM 节点: {direct_upstream_llm_nodes}")
# 解析模板部分
# Parse template parts
parts = self._parse_template_parts(output_template, state)
logger.info(f"节点 {self.node_id} 解析模板,共 {len(parts)} 个部分")
for i, part in enumerate(parts):
logger.info(f"[模板解析] part[{i}]: {part}")
# 找到第一个引用直接上游节点的动态引用
upstream_ref_index = None
# Find the first reference to a direct upstream LLM node
upstream_llm_ref_index = None
for i, part in enumerate(parts):
if part["type"] == "dynamic" and part["node_id"] in direct_upstream_nodes:
upstream_ref_index = i
logger.info(f"节点 {self.node_id} 找到直接上游节点 {part['node_id']} 的引用,索引: {i}")
if part["type"] == "dynamic" and part["node_id"] in direct_upstream_llm_nodes:
upstream_llm_ref_index = i
logger.info(f"节点 {self.node_id} 找到直接上游 LLM 节点 {part['node_id']} 的引用,索引: {i}")
break
if upstream_ref_index is None:
# 没有引用直接上游节点,输出完整模板内容
if upstream_llm_ref_index is None:
# No reference to direct upstream LLM node, output complete template content
output = self._render_template(output_template, state)
logger.info(f"节点 {self.node_id} 没有引用直接上游节点,输出完整内容: '{output[:50]}...'")
logger.info(f"节点 {self.node_id} 没有引用直接上游 LLM 节点,输出完整内容: '{output[:50]}...'")
# 通过 writer 发送完整内容(作为一个 message chunk
# Send complete content via writer (as a single message chunk)
from langgraph.config import get_stream_writer
writer = get_stream_writer()
writer({
"type": "message", # End 节点的输出使用 message 类型
"type": "message", # End node output uses message type
"node_id": self.node_id,
"chunk": output,
"full_content": output,
@@ -197,17 +206,17 @@ class EndNode(BaseNode):
})
logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容")
# yield 完成标记
# yield completion marker
yield {"__final__": True, "result": output}
return
# 有引用直接上游节点,只输出该引用之后的部分(后缀)
logger.info(f"节点 {self.node_id} 检测到直接上游节点引用,只输出后缀部分(从索引 {upstream_ref_index + 1} 开始)")
# Has reference to direct upstream LLM node, only output the part after that reference (suffix)
logger.info(f"节点 {self.node_id} 检测到直接上游 LLM 节点引用,只输出后缀部分(从索引 {upstream_llm_ref_index + 1} 开始)")
# 收集后缀部分
# Collect suffix parts
suffix_parts = []
logger.info(f"[后缀调试] 开始收集后缀,从索引 {upstream_ref_index + 1}{len(parts) - 1}")
for i in range(upstream_ref_index + 1, len(parts)):
logger.info(f"[后缀调试] 开始收集后缀,从索引 {upstream_llm_ref_index + 1}{len(parts) - 1}")
for i in range(upstream_llm_ref_index + 1, len(parts)):
part = parts[i]
logger.info(f"[后缀调试] 处理 part[{i}]: {part}")
if part["type"] == "static":
@@ -219,7 +228,7 @@ class EndNode(BaseNode):
# Other dynamic references (if there are multiple references)
node_id = part["node_id"]
field = part["field"]
# Use VariablePool to get variable value
pool = self.get_variable_pool(state)
try:
@@ -232,7 +241,7 @@ class EndNode(BaseNode):
# Convert to string if not None
suffix_parts.append(str(content) if content is not None else "")
# 拼接后缀
suffix = "".join(suffix_parts)
@@ -261,8 +270,8 @@ class EndNode(BaseNode):
})
logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀full_content 长度: {len(full_output)}")
else:
logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空不发送upstream_ref_index={upstream_ref_index}, parts数量={len(parts)}")
logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空不发送upstream_llm_ref_index={upstream_llm_ref_index}, parts数量={len(parts)}")
# 统计信息
node_outputs = state.get("node_outputs", {})
total_nodes = len(node_outputs)