[fix] end stream output

This commit is contained in:
Mark
2025-12-24 12:37:50 +08:00
parent 9124a54b0f
commit 63d5047d21

View File

@@ -9,28 +9,29 @@ import re
import asyncio import asyncio
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import NodeType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class EndNode(BaseNode): class EndNode(BaseNode):
"""End 节点 """End 节点
工作流的结束节点,根据配置的模板输出最终结果。 工作流的结束节点,根据配置的模板输出最终结果。
支持实时流式输出:如果模板引用了上游节点的输出,会实时监听其流式缓冲区。 支持实时流式输出:如果模板引用了上游节点的输出,会实时监听其流式缓冲区。
""" """
async def execute(self, state: WorkflowState) -> str: async def execute(self, state: WorkflowState) -> str:
"""执行 end 节点业务逻辑 """执行 end 节点业务逻辑
Args: Args:
state: 工作流状态 state: 工作流状态
Returns: Returns:
最终输出字符串 最终输出字符串
""" """
logger.info(f"节点 {self.node_id} (End) 开始执行") logger.info(f"节点 {self.node_id} (End) 开始执行")
# 获取配置的输出模板 # 获取配置的输出模板
output_template = self.config.get("output") output_template = self.config.get("output")
@@ -39,11 +40,11 @@ class EndNode(BaseNode):
output = self._render_template(output_template, state) output = self._render_template(output_template, state)
else: else:
output = "工作流已完成" output = "工作流已完成"
# 统计信息(用于日志) # 统计信息(用于日志)
node_outputs = state.get("node_outputs", {}) node_outputs = state.get("node_outputs", {})
total_nodes = len(node_outputs) total_nodes = len(node_outputs)
logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点") logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
return output return output
@@ -127,24 +128,26 @@ class EndNode(BaseNode):
return parts return parts
async def execute_stream(self, state: WorkflowState): async def execute_stream(self, state: WorkflowState):
"""流式执行 end 节点业务逻辑 """Execute End node business logic (streaming)
智能输出策略: Smart output strategy:
1. 检测模板中是否引用了直接上游节点 1. Check if template references a direct upstream LLM node
2. 如果引用了,只输出该引用**之后**的部分(后缀) 2. If yes, only output the part AFTER that reference (suffix)
3. 前缀和引用内容已经在上游节点流式输出时发送了 3. Prefix and LLM content have already been sent during LLM node streaming
示例:'{{start.test}}hahaha {{ llm_qa.output }} lalalalala a' Note: Only LLM nodes get this special treatment. Other node types output normally.
- 直接上游节点是 llm_qa
- 前缀 '{{start.test}}hahaha ' 已在 LLM 节点流式输出前发送 Example: '{{start.test}}hahaha {{ llm_qa.output }} lalalalala a'
- LLM 内容在 LLM 节点流式输出 - Direct upstream LLM node is llm_qa
- End 节点只输出 ' lalalalala a'(后缀,一次性输出) - Prefix '{{start.test}}hahaha ' was sent before LLM node streaming
- LLM content was streamed during LLM node execution
- End node only outputs ' lalalalala a' (suffix, sent as one chunk)
Args: Args:
state: 工作流状态 state: Workflow state
Yields: Yields:
完成标记 Completion marker
""" """
logger.info(f"节点 {self.node_id} (End) 开始执行(流式)") logger.info(f"节点 {self.node_id} (End) 开始执行(流式)")
@@ -156,39 +159,45 @@ class EndNode(BaseNode):
yield {"__final__": True, "result": output} yield {"__final__": True, "result": output}
return return
# 找到直接上游节点 # Find direct upstream LLM nodes
direct_upstream_nodes = [] direct_upstream_llm_nodes = []
for edge in self.workflow_config.get("edges", []): for edge in self.workflow_config.get("edges", []):
if edge.get("target") == self.node_id: if edge.get("target") == self.node_id:
source_node_id = edge.get("source") source_node_id = edge.get("source")
direct_upstream_nodes.append(source_node_id) # Check if the source node is an LLM node
for node in self.workflow_config.get("nodes", []):
print("="*50)
logger.info(f"节点 {self.node_id} 的类型 {node.get("type")}")
if node.get("id") == source_node_id and node.get("type") == NodeType.LLM:
direct_upstream_llm_nodes.append(source_node_id)
break
logger.info(f"节点 {self.node_id} 的直接上游节点: {direct_upstream_nodes}") logger.info(f"节点 {self.node_id} 的直接上游 LLM 节点: {direct_upstream_llm_nodes}")
# 解析模板部分 # Parse template parts
parts = self._parse_template_parts(output_template, state) parts = self._parse_template_parts(output_template, state)
logger.info(f"节点 {self.node_id} 解析模板,共 {len(parts)} 个部分") logger.info(f"节点 {self.node_id} 解析模板,共 {len(parts)} 个部分")
for i, part in enumerate(parts): for i, part in enumerate(parts):
logger.info(f"[模板解析] part[{i}]: {part}") logger.info(f"[模板解析] part[{i}]: {part}")
# 找到第一个引用直接上游节点的动态引用 # Find the first reference to a direct upstream LLM node
upstream_ref_index = None upstream_llm_ref_index = None
for i, part in enumerate(parts): for i, part in enumerate(parts):
if part["type"] == "dynamic" and part["node_id"] in direct_upstream_nodes: if part["type"] == "dynamic" and part["node_id"] in direct_upstream_llm_nodes:
upstream_ref_index = i upstream_llm_ref_index = i
logger.info(f"节点 {self.node_id} 找到直接上游节点 {part['node_id']} 的引用,索引: {i}") logger.info(f"节点 {self.node_id} 找到直接上游 LLM 节点 {part['node_id']} 的引用,索引: {i}")
break break
if upstream_ref_index is None: if upstream_llm_ref_index is None:
# 没有引用直接上游节点,输出完整模板内容 # No reference to direct upstream LLM node, output complete template content
output = self._render_template(output_template, state) output = self._render_template(output_template, state)
logger.info(f"节点 {self.node_id} 没有引用直接上游节点,输出完整内容: '{output[:50]}...'") logger.info(f"节点 {self.node_id} 没有引用直接上游 LLM 节点,输出完整内容: '{output[:50]}...'")
# 通过 writer 发送完整内容(作为一个 message chunk # Send complete content via writer (as a single message chunk)
from langgraph.config import get_stream_writer from langgraph.config import get_stream_writer
writer = get_stream_writer() writer = get_stream_writer()
writer({ writer({
"type": "message", # End 节点的输出使用 message 类型 "type": "message", # End node output uses message type
"node_id": self.node_id, "node_id": self.node_id,
"chunk": output, "chunk": output,
"full_content": output, "full_content": output,
@@ -197,17 +206,17 @@ class EndNode(BaseNode):
}) })
logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容") logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容")
# yield 完成标记 # yield completion marker
yield {"__final__": True, "result": output} yield {"__final__": True, "result": output}
return return
# 有引用直接上游节点,只输出该引用之后的部分(后缀) # Has reference to direct upstream LLM node, only output the part after that reference (suffix)
logger.info(f"节点 {self.node_id} 检测到直接上游节点引用,只输出后缀部分(从索引 {upstream_ref_index + 1} 开始)") logger.info(f"节点 {self.node_id} 检测到直接上游 LLM 节点引用,只输出后缀部分(从索引 {upstream_llm_ref_index + 1} 开始)")
# 收集后缀部分 # Collect suffix parts
suffix_parts = [] suffix_parts = []
logger.info(f"[后缀调试] 开始收集后缀,从索引 {upstream_ref_index + 1}{len(parts) - 1}") logger.info(f"[后缀调试] 开始收集后缀,从索引 {upstream_llm_ref_index + 1}{len(parts) - 1}")
for i in range(upstream_ref_index + 1, len(parts)): for i in range(upstream_llm_ref_index + 1, len(parts)):
part = parts[i] part = parts[i]
logger.info(f"[后缀调试] 处理 part[{i}]: {part}") logger.info(f"[后缀调试] 处理 part[{i}]: {part}")
if part["type"] == "static": if part["type"] == "static":
@@ -219,7 +228,7 @@ class EndNode(BaseNode):
# Other dynamic references (if there are multiple references) # Other dynamic references (if there are multiple references)
node_id = part["node_id"] node_id = part["node_id"]
field = part["field"] field = part["field"]
# Use VariablePool to get variable value # Use VariablePool to get variable value
pool = self.get_variable_pool(state) pool = self.get_variable_pool(state)
try: try:
@@ -232,7 +241,7 @@ class EndNode(BaseNode):
# Convert to string if not None # Convert to string if not None
suffix_parts.append(str(content) if content is not None else "") suffix_parts.append(str(content) if content is not None else "")
# 拼接后缀 # 拼接后缀
suffix = "".join(suffix_parts) suffix = "".join(suffix_parts)
@@ -261,8 +270,8 @@ class EndNode(BaseNode):
}) })
logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀full_content 长度: {len(full_output)}") logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀full_content 长度: {len(full_output)}")
else: else:
logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空不发送upstream_ref_index={upstream_ref_index}, parts数量={len(parts)}") logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空不发送upstream_llm_ref_index={upstream_llm_ref_index}, parts数量={len(parts)}")
# 统计信息 # 统计信息
node_outputs = state.get("node_outputs", {}) node_outputs = state.get("node_outputs", {})
total_nodes = len(node_outputs) total_nodes = len(node_outputs)