[fix] end stream output
This commit is contained in:
@@ -9,28 +9,29 @@ import re
|
||||
import asyncio
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EndNode(BaseNode):
|
||||
"""End 节点
|
||||
|
||||
|
||||
工作流的结束节点,根据配置的模板输出最终结果。
|
||||
支持实时流式输出:如果模板引用了上游节点的输出,会实时监听其流式缓冲区。
|
||||
"""
|
||||
|
||||
|
||||
async def execute(self, state: WorkflowState) -> str:
|
||||
"""执行 end 节点业务逻辑
|
||||
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
|
||||
|
||||
Returns:
|
||||
最终输出字符串
|
||||
"""
|
||||
logger.info(f"节点 {self.node_id} (End) 开始执行")
|
||||
|
||||
|
||||
# 获取配置的输出模板
|
||||
output_template = self.config.get("output")
|
||||
|
||||
@@ -39,11 +40,11 @@ class EndNode(BaseNode):
|
||||
output = self._render_template(output_template, state)
|
||||
else:
|
||||
output = "工作流已完成"
|
||||
|
||||
|
||||
# 统计信息(用于日志)
|
||||
node_outputs = state.get("node_outputs", {})
|
||||
total_nodes = len(node_outputs)
|
||||
|
||||
|
||||
logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
|
||||
|
||||
return output
|
||||
@@ -127,24 +128,26 @@ class EndNode(BaseNode):
|
||||
return parts
|
||||
|
||||
async def execute_stream(self, state: WorkflowState):
|
||||
"""流式执行 end 节点业务逻辑
|
||||
"""Execute End node business logic (streaming)
|
||||
|
||||
智能输出策略:
|
||||
1. 检测模板中是否引用了直接上游节点
|
||||
2. 如果引用了,只输出该引用**之后**的部分(后缀)
|
||||
3. 前缀和引用内容已经在上游节点流式输出时发送了
|
||||
Smart output strategy:
|
||||
1. Check if template references a direct upstream LLM node
|
||||
2. If yes, only output the part AFTER that reference (suffix)
|
||||
3. Prefix and LLM content have already been sent during LLM node streaming
|
||||
|
||||
示例:'{{start.test}}hahaha {{ llm_qa.output }} lalalalala a'
|
||||
- 直接上游节点是 llm_qa
|
||||
- 前缀 '{{start.test}}hahaha ' 已在 LLM 节点流式输出前发送
|
||||
- LLM 内容在 LLM 节点流式输出
|
||||
- End 节点只输出 ' lalalalala a'(后缀,一次性输出)
|
||||
Note: Only LLM nodes get this special treatment. Other node types output normally.
|
||||
|
||||
Example: '{{start.test}}hahaha {{ llm_qa.output }} lalalalala a'
|
||||
- Direct upstream LLM node is llm_qa
|
||||
- Prefix '{{start.test}}hahaha ' was sent before LLM node streaming
|
||||
- LLM content was streamed during LLM node execution
|
||||
- End node only outputs ' lalalalala a' (suffix, sent as one chunk)
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
state: Workflow state
|
||||
|
||||
Yields:
|
||||
完成标记
|
||||
Completion marker
|
||||
"""
|
||||
logger.info(f"节点 {self.node_id} (End) 开始执行(流式)")
|
||||
|
||||
@@ -156,39 +159,45 @@ class EndNode(BaseNode):
|
||||
yield {"__final__": True, "result": output}
|
||||
return
|
||||
|
||||
# 找到直接上游节点
|
||||
direct_upstream_nodes = []
|
||||
# Find direct upstream LLM nodes
|
||||
direct_upstream_llm_nodes = []
|
||||
for edge in self.workflow_config.get("edges", []):
|
||||
if edge.get("target") == self.node_id:
|
||||
source_node_id = edge.get("source")
|
||||
direct_upstream_nodes.append(source_node_id)
|
||||
# Check if the source node is an LLM node
|
||||
for node in self.workflow_config.get("nodes", []):
|
||||
print("="*50)
|
||||
logger.info(f"节点 {self.node_id} 的类型 {node.get("type")}")
|
||||
if node.get("id") == source_node_id and node.get("type") == NodeType.LLM:
|
||||
direct_upstream_llm_nodes.append(source_node_id)
|
||||
break
|
||||
|
||||
logger.info(f"节点 {self.node_id} 的直接上游节点: {direct_upstream_nodes}")
|
||||
logger.info(f"节点 {self.node_id} 的直接上游 LLM 节点: {direct_upstream_llm_nodes}")
|
||||
|
||||
# 解析模板部分
|
||||
# Parse template parts
|
||||
parts = self._parse_template_parts(output_template, state)
|
||||
logger.info(f"节点 {self.node_id} 解析模板,共 {len(parts)} 个部分")
|
||||
for i, part in enumerate(parts):
|
||||
logger.info(f"[模板解析] part[{i}]: {part}")
|
||||
|
||||
# 找到第一个引用直接上游节点的动态引用
|
||||
upstream_ref_index = None
|
||||
# Find the first reference to a direct upstream LLM node
|
||||
upstream_llm_ref_index = None
|
||||
for i, part in enumerate(parts):
|
||||
if part["type"] == "dynamic" and part["node_id"] in direct_upstream_nodes:
|
||||
upstream_ref_index = i
|
||||
logger.info(f"节点 {self.node_id} 找到直接上游节点 {part['node_id']} 的引用,索引: {i}")
|
||||
if part["type"] == "dynamic" and part["node_id"] in direct_upstream_llm_nodes:
|
||||
upstream_llm_ref_index = i
|
||||
logger.info(f"节点 {self.node_id} 找到直接上游 LLM 节点 {part['node_id']} 的引用,索引: {i}")
|
||||
break
|
||||
|
||||
if upstream_ref_index is None:
|
||||
# 没有引用直接上游节点,输出完整模板内容
|
||||
if upstream_llm_ref_index is None:
|
||||
# No reference to direct upstream LLM node, output complete template content
|
||||
output = self._render_template(output_template, state)
|
||||
logger.info(f"节点 {self.node_id} 没有引用直接上游节点,输出完整内容: '{output[:50]}...'")
|
||||
logger.info(f"节点 {self.node_id} 没有引用直接上游 LLM 节点,输出完整内容: '{output[:50]}...'")
|
||||
|
||||
# 通过 writer 发送完整内容(作为一个 message chunk)
|
||||
# Send complete content via writer (as a single message chunk)
|
||||
from langgraph.config import get_stream_writer
|
||||
writer = get_stream_writer()
|
||||
writer({
|
||||
"type": "message", # End 节点的输出使用 message 类型
|
||||
"type": "message", # End node output uses message type
|
||||
"node_id": self.node_id,
|
||||
"chunk": output,
|
||||
"full_content": output,
|
||||
@@ -197,17 +206,17 @@ class EndNode(BaseNode):
|
||||
})
|
||||
logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容")
|
||||
|
||||
# yield 完成标记
|
||||
# yield completion marker
|
||||
yield {"__final__": True, "result": output}
|
||||
return
|
||||
|
||||
# 有引用直接上游节点,只输出该引用之后的部分(后缀)
|
||||
logger.info(f"节点 {self.node_id} 检测到直接上游节点引用,只输出后缀部分(从索引 {upstream_ref_index + 1} 开始)")
|
||||
# Has reference to direct upstream LLM node, only output the part after that reference (suffix)
|
||||
logger.info(f"节点 {self.node_id} 检测到直接上游 LLM 节点引用,只输出后缀部分(从索引 {upstream_llm_ref_index + 1} 开始)")
|
||||
|
||||
# 收集后缀部分
|
||||
# Collect suffix parts
|
||||
suffix_parts = []
|
||||
logger.info(f"[后缀调试] 开始收集后缀,从索引 {upstream_ref_index + 1} 到 {len(parts) - 1}")
|
||||
for i in range(upstream_ref_index + 1, len(parts)):
|
||||
logger.info(f"[后缀调试] 开始收集后缀,从索引 {upstream_llm_ref_index + 1} 到 {len(parts) - 1}")
|
||||
for i in range(upstream_llm_ref_index + 1, len(parts)):
|
||||
part = parts[i]
|
||||
logger.info(f"[后缀调试] 处理 part[{i}]: {part}")
|
||||
if part["type"] == "static":
|
||||
@@ -219,7 +228,7 @@ class EndNode(BaseNode):
|
||||
# Other dynamic references (if there are multiple references)
|
||||
node_id = part["node_id"]
|
||||
field = part["field"]
|
||||
|
||||
|
||||
# Use VariablePool to get variable value
|
||||
pool = self.get_variable_pool(state)
|
||||
try:
|
||||
@@ -232,7 +241,7 @@ class EndNode(BaseNode):
|
||||
|
||||
# Convert to string if not None
|
||||
suffix_parts.append(str(content) if content is not None else "")
|
||||
|
||||
|
||||
# 拼接后缀
|
||||
suffix = "".join(suffix_parts)
|
||||
|
||||
@@ -261,8 +270,8 @@ class EndNode(BaseNode):
|
||||
})
|
||||
logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀,full_content 长度: {len(full_output)}")
|
||||
else:
|
||||
logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!upstream_ref_index={upstream_ref_index}, parts数量={len(parts)}")
|
||||
|
||||
logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!upstream_llm_ref_index={upstream_llm_ref_index}, parts数量={len(parts)}")
|
||||
|
||||
# 统计信息
|
||||
node_outputs = state.get("node_outputs", {})
|
||||
total_nodes = len(node_outputs)
|
||||
|
||||
Reference in New Issue
Block a user