[add] workflow llm & end logic

This commit is contained in:
Mark
2025-12-20 17:25:47 +08:00
parent d8fcea8564
commit 36b36b729b
4 changed files with 430 additions and 103 deletions

View File

@@ -87,11 +87,75 @@ class WorkflowExecutor:
"workspace_id": self.workspace_id, "workspace_id": self.workspace_id,
"user_id": self.user_id, "user_id": self.user_id,
"error": None, "error": None,
"error_node": None "error_node": None,
"streaming_buffer": {} # 流式缓冲区
} }
def _analyze_end_node_prefixes(self) -> dict[str, str]:
"""分析 End 节点的前缀配置
检查每个 End 节点的模板,找到直接上游节点的引用,
提取该引用之前的前缀部分。
Returns:
字典:{上游节点ID: End节点前缀}
"""
import re
prefixes = {}
# 找到所有 End 节点
end_nodes = [node for node in self.nodes if node.get("type") == "end"]
logger.info(f"[前缀分析] 找到 {len(end_nodes)} 个 End 节点")
for end_node in end_nodes:
end_node_id = end_node.get("id")
output_template = end_node.get("config", {}).get("output")
logger.info(f"[前缀分析] End 节点 {end_node_id} 模板: {output_template}")
if not output_template:
continue
# 找到所有直接连接到 End 节点的上游节点
direct_upstream_nodes = []
for edge in self.edges:
if edge.get("target") == end_node_id:
source_node_id = edge.get("source")
direct_upstream_nodes.append(source_node_id)
logger.info(f"[前缀分析] End 节点的直接上游节点: {direct_upstream_nodes}")
# 查找模板中引用了哪些节点
# 匹配 {{node_id.xxx}} 或 {{ node_id.xxx }} 格式(支持空格)
pattern = r'\{\{\s*([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\s*\}\}'
matches = list(re.finditer(pattern, output_template))
logger.info(f"[前缀分析] 模板中找到 {len(matches)} 个节点引用")
# 找到第一个直接上游节点的引用
for match in matches:
referenced_node_id = match.group(1)
logger.info(f"[前缀分析] 检查引用: {referenced_node_id}")
if referenced_node_id in direct_upstream_nodes:
# 这是直接上游节点的引用,提取前缀
prefix = output_template[:match.start()]
logger.info(f"[前缀分析] ✅ 找到直接上游节点 {referenced_node_id} 的引用,前缀: '{prefix}'")
if prefix:
prefixes[referenced_node_id] = prefix
logger.info(f"✅ [前缀分析] 为节点 {referenced_node_id} 配置前缀: '{prefix[:50]}...'")
# 只处理第一个直接上游节点的引用
break
logger.info(f"[前缀分析] 最终配置: {prefixes}")
return prefixes
def build_graph(self,stream=False) -> CompiledStateGraph: def build_graph(self,stream=False) -> CompiledStateGraph:
"""构建 LangGraph """构建 LangGraph
@@ -99,6 +163,9 @@ class WorkflowExecutor:
编译后的状态图 编译后的状态图
""" """
logger.info(f"开始构建工作流图: execution_id={self.execution_id}") logger.info(f"开始构建工作流图: execution_id={self.execution_id}")
# 分析 End 节点的前缀配置
end_prefixes = self._analyze_end_node_prefixes() if stream else {}
# 1. 创建状态图 # 1. 创建状态图
workflow = StateGraph(WorkflowState) workflow = StateGraph(WorkflowState)
@@ -120,6 +187,12 @@ class WorkflowExecutor:
# 创建节点实例(现在 start 和 end 也会被创建) # 创建节点实例(现在 start 和 end 也会被创建)
node_instance = NodeFactory.create_node(node, self.workflow_config) node_instance = NodeFactory.create_node(node, self.workflow_config)
if node_instance: if node_instance:
# 如果是流式模式,且节点有 End 前缀配置,注入配置
if stream and node_id in end_prefixes:
# 将 End 前缀配置注入到节点实例
node_instance._end_node_prefix = end_prefixes[node_id]
logger.info(f"为节点 {node_id} 注入 End 前缀配置")
# 包装节点的 run 方法 # 包装节点的 run 方法
# 使用函数工厂避免闭包问题 # 使用函数工厂避免闭包问题
if stream: if stream:
@@ -309,29 +382,48 @@ class WorkflowExecutor:
# 2. 初始化状态(自动注入系统变量) # 2. 初始化状态(自动注入系统变量)
initial_state = self._prepare_initial_state(input_data) initial_state = self._prepare_initial_state(input_data)
# 3. 执行工作流 # 3. Execute workflow
try: try:
chunk_count = 0 chunk_count = 0
async for event in graph.astream( async for event in graph.astream(
initial_state, initial_state,
stream_mode=["updates", "debug"], stream_mode=["updates", "debug", "custom"], # Use updates + debug + custom mode
): ):
mode, data = event # event should be a tuple: (mode, data)
# But let's handle both cases
if isinstance(event, tuple) and len(event) == 2:
mode, data = event
else:
# Unexpected format, log and skip
logger.warning(f"[STREAM] Unexpected event format: {type(event)}, value: {event}")
continue
if mode == "debug": if mode == "custom":
# 处理调试信息(节点执行状态) # Handle custom streaming events (chunks from nodes via stream writer)
chunk_count += 1
logger.info(f"[CUSTOM] ✅ 收到 chunk #{chunk_count} from {data.get('node_id')}")
yield {
"type": "node_chunk",
"node_id": data.get("node_id"),
"chunk": data.get("chunk"),
"full_content": data.get("full_content"),
"chunk_index": data.get("chunk_index")
}
elif mode == "debug":
# Handle debug information (node execution status)
event_type = data.get("type") event_type = data.get("type")
payload = data.get("payload", {}) payload = data.get("payload", {})
node_name = payload.get("name") node_name = payload.get("name")
if event_type == "task": if event_type == "task":
# 节点开始执行 # Node starts execution
inputv = payload.get("input", {}) inputv = payload.get("input", {})
variables = inputv.get("variables", {}) variables = inputv.get("variables", {})
variables_sys = variables.get("sys", {}) variables_sys = variables.get("sys", {})
conversation_id = variables_sys.get("conversation_id") conversation_id = variables_sys.get("conversation_id")
execution_id = variables_sys.get("execution_id") execution_id = variables_sys.get("execution_id")
logger.info(f"[DEBUG] 节点开始执行: {node_name}") logger.info(f"[DEBUG] Node starts execution: {node_name}")
yield { yield {
"type": "node_start", "type": "node_start",
"node_id": node_name, "node_id": node_name,
@@ -340,16 +432,16 @@ class WorkflowExecutor:
"timestamp": data.get("timestamp") "timestamp": data.get("timestamp")
} }
elif event_type == "task_result": elif event_type == "task_result":
# 节点执行完成 # Node execution completed
result = payload.get("result", {}) result = payload.get("result", {})
inputv = result.get("input", {}) inputv = result.get("input", {})
variables = inputv.get("variables", {}) variables = inputv.get("variables", {})
variables_sys = variables.get("sys", {}) variables_sys = variables.get("sys", {})
conversation_id = variables_sys.get("conversation_id") conversation_id = variables_sys.get("conversation_id")
execution_id = variables_sys.get("execution_id") execution_id = variables_sys.get("execution_id")
logger.info(f"[DEBUG] 节点执行完成: {node_name}") logger.info(f"[DEBUG] Node execution completed: {node_name}")
yield { yield {
"type": "node_end", "type": "node_complete",
"node_id": node_name, "node_id": node_name,
"conversation_id": conversation_id, "conversation_id": conversation_id,
"execution_id": execution_id, "execution_id": execution_id,
@@ -357,27 +449,10 @@ class WorkflowExecutor:
} }
elif mode == "updates": elif mode == "updates":
# 处理 state 更新 # Handle state updates
# data 是一个字典key 是节点 IDvalue 是 state 更新或 chunk logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())}")
print("="*50)
print(data)
print("-"*50)
for node_id, update in data.items():
if isinstance(update, dict) and update.get("type") == "chunk":
# 这是流式 chunk转发给客户端
chunk_count += 1
logger.debug(f"[UPDATE] 收到 chunk #{chunk_count} from {node_id}: {update.get('content')[:50]}...")
yield {
"type": "node_chunk",
"node_id": update.get("node_id"),
"chunk": update.get("content"),
"full_content": update.get("full_content")
}
else:
logger.debug(f"[UPDATE] 收到 state 更新 from {node_id}")
# 其他情况state 更新)会被 LangGraph 自动合并到 state
logger.info(f"工作流执行完成(流式),总 chunks: {chunk_count}") logger.info(f"Workflow execution completed (streaming), total chunks: {chunk_count}")
except Exception as e: except Exception as e:
# 计算耗时(即使失败也记录) # 计算耗时(即使失败也记录)

View File

@@ -10,6 +10,7 @@ from abc import ABC, abstractmethod
from typing import Any, TypedDict, Annotated from typing import Any, TypedDict, Annotated
from operator import add from operator import add
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage from langchain_core.messages import AnyMessage, HumanMessage, AIMessage
from langgraph.config import get_stream_writer
from app.core.workflow.variable_pool import VariablePool from app.core.workflow.variable_pool import VariablePool
@@ -43,6 +44,10 @@ class WorkflowState(TypedDict):
# 错误信息(用于错误边) # 错误信息(用于错误边)
error: str | None error: str | None
error_node: str | None error_node: str | None
# 流式缓冲区(存储节点的实时流式输出)
# 格式:{node_id: {"chunks": [...], "full_content": "..."}}
streaming_buffer: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
class BaseNode(ABC): class BaseNode(ABC):
@@ -201,23 +206,25 @@ class BaseNode(ABC):
return self._wrap_error(str(e), elapsed_time, state) return self._wrap_error(str(e), elapsed_time, state)
async def run_stream(self, state: WorkflowState): async def run_stream(self, state: WorkflowState):
"""执行节点(带错误处理和输出包装,流式) """Execute node with error handling and output wrapping (streaming)
这个方法由 Executor 调用,负责: This method is called by the Executor and is responsible for:
1. 时间统计 1. Time tracking
2. 调用节点的 execute_stream() 方法 2. Calling the node's execute_stream() method
3. 将业务数据包装成标准输出格式 3. Using LangGraph's stream writer to send chunks
4. 错误处理 4. Updating streaming buffer in state for downstream nodes
5. Wrapping business data into standard output format
6. Error handling
注意:在流式模式下,我们需要: Special handling for End nodes:
- yield 中间的 chunk 事件(用于实时显示) - End nodes don't send chunks via writer (prefix and LLM content already sent)
- 最后 yield 一个包含 state 更新的字典LangGraph 会合并到 state - End nodes only yield suffix for final result assembly
Args: Args:
state: 工作流状态 state: Workflow state
Yields: Yields:
标准化的流式事件和最终的 state 更新 State updates with streaming buffer and final result
""" """
import time import time
@@ -226,63 +233,102 @@ class BaseNode(ABC):
try: try:
timeout = self.get_timeout() timeout = self.get_timeout()
# 累积完整结果(用于最后的包装) # Get LangGraph's stream writer for sending custom data
writer = get_stream_writer()
# Check if this is an End node
# End nodes CAN send chunks (for suffix), but only after LLM content
is_end_node = self.node_type == "end"
# Accumulate complete result (for final wrapping)
chunks = [] chunks = []
final_result = None final_result = None
chunk_count = 0
# 使用异步生成器包装,支持超时 # Stream chunks in real-time
async def stream_with_timeout(): loop_start = asyncio.get_event_loop().time()
nonlocal final_result
loop_start = asyncio.get_event_loop().time() async for item in self.execute_stream(state):
# Check timeout
if asyncio.get_event_loop().time() - loop_start > timeout:
raise TimeoutError()
async for item in self.execute_stream(state): # Check if it's a completion marker
# 检查超时 if isinstance(item, dict) and item.get("__final__"):
if asyncio.get_event_loop().time() - loop_start > timeout: final_result = item["result"]
raise TimeoutError() elif isinstance(item, str):
# String is a chunk
chunk_count += 1
chunks.append(item)
full_content = "".join(chunks)
# 检查是否是完成标记 # Send chunks for all nodes (including End nodes for suffix)
if isinstance(item, dict) and item.get("__final__"): logger.debug(f"节点 {self.node_id} 发送 chunk #{chunk_count}: {item[:50]}...")
final_result = item["result"]
elif isinstance(item, str): # 1. Send via stream writer (for real-time client updates)
# 字符串是 chunk writer({
# print("="*50) "node_id": self.node_id,
# print(item) "chunk": item,
# print("-"*50) "full_content": full_content,
chunks.append(item) "chunk_index": chunk_count
})
# 2. Update streaming buffer in state (for downstream nodes)
# Only non-End nodes need streaming buffer
if not is_end_node:
yield { yield {
"type": "chunk", "streaming_buffer": {
"node_id": self.node_id, self.node_id: {
"content": item, "full_content": full_content,
"full_content": "".join(chunks) "chunk_count": chunk_count,
"is_complete": False
}
}
} }
else: else:
# 其他类型也当作 chunk 处理 # Other types are also treated as chunks
chunks.append(str(item)) chunk_count += 1
chunk_str = str(item)
chunks.append(chunk_str)
full_content = "".join(chunks)
# Send chunks for all nodes
writer({
"node_id": self.node_id,
"chunk": chunk_str,
"full_content": full_content,
"chunk_index": chunk_count
})
# Only non-End nodes need streaming buffer
if not is_end_node:
yield { yield {
"type": "chunk", "streaming_buffer": {
"node_id": self.node_id, self.node_id: {
"content": str(item), "full_content": full_content,
"full_content": "".join(chunks) "chunk_count": chunk_count,
"is_complete": False
}
}
} }
async for chunk_event in stream_with_timeout():
yield chunk_event
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
# 提取处理后的输出(调用子类的 _extract_output logger.info(f"节点 {self.node_id} 流式执行完成,耗时: {elapsed_time:.2f}s, chunks: {chunk_count}")
# Extract processed output (call subclass's _extract_output)
extracted_output = self._extract_output(final_result) extracted_output = self._extract_output(final_result)
# 包装最终结果 # Wrap final result
final_output = self._wrap_output(final_result, elapsed_time, state) final_output = self._wrap_output(final_result, elapsed_time, state)
# 将提取后的输出存储到运行时变量中(供后续节点快速访问) # Store extracted output in runtime variables (for quick access by subsequent nodes)
if isinstance(extracted_output, dict): if isinstance(extracted_output, dict):
runtime_var = extracted_output runtime_var = extracted_output
else: else:
runtime_var = {"output": extracted_output} runtime_var = {"output": extracted_output}
# 构建完整的 state 更新(包含 node_outputs runtime_vars # Build complete state update (including node_outputs, runtime_vars, and final streaming buffer)
state_update = { state_update = {
**final_output, **final_output,
"runtime_vars": { "runtime_vars": {
@@ -290,13 +336,24 @@ class BaseNode(ABC):
} }
} }
# 最后 yield 纯粹的 state 更新LangGraph 会合并到 state 中) # Add streaming buffer for non-End nodes
if not is_end_node:
state_update["streaming_buffer"] = {
self.node_id: {
"full_content": "".join(chunks),
"chunk_count": chunk_count,
"is_complete": True # Mark as complete
}
}
# Finally yield state update
# LangGraph will merge this into state
yield state_update yield state_update
except TimeoutError: except TimeoutError:
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
logger.error(f"节点 {self.node_id} 执行超时{timeout}秒)") logger.error(f"节点 {self.node_id} 执行超时 ({timeout}s)")
error_output = self._wrap_error(f"节点执行超时{timeout}秒)", elapsed_time, state) error_output = self._wrap_error(f"节点执行超时 ({timeout}s)", elapsed_time, state)
yield error_output yield error_output
except Exception as e: except Exception as e:
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time

View File

@@ -5,6 +5,8 @@ End 节点实现
""" """
import logging import logging
import re
import asyncio
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
@@ -15,6 +17,7 @@ class EndNode(BaseNode):
"""End 节点 """End 节点
工作流的结束节点,根据配置的模板输出最终结果。 工作流的结束节点,根据配置的模板输出最终结果。
支持实时流式输出:如果模板引用了上游节点的输出,会实时监听其流式缓冲区。
""" """
async def execute(self, state: WorkflowState) -> str: async def execute(self, state: WorkflowState) -> str:
@@ -45,42 +48,209 @@ class EndNode(BaseNode):
return output return output
def _extract_referenced_nodes(self, template: str) -> list[str]:
"""从模板中提取引用的节点 ID
例如:'结果:{{llm_qa.output}}' -> ['llm_qa']
Args:
template: 模板字符串
Returns:
引用的节点 ID 列表
"""
# 匹配 {{node_id.xxx}} 格式
pattern = r'\{\{([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\}\}'
matches = re.findall(pattern, template)
return list(set(matches)) # 去重
def _parse_template_parts(self, template: str, state: WorkflowState) -> list[dict]:
"""解析模板,分离静态文本和动态引用
例如:'你好 {{llm.output}}, 这是后缀'
返回:[
{"type": "static", "content": "你好 "},
{"type": "dynamic", "node_id": "llm", "field": "output"},
{"type": "static", "content": ", 这是后缀"}
]
Args:
template: 模板字符串
state: 工作流状态
Returns:
模板部分列表
"""
import re
parts = []
last_end = 0
# 匹配 {{xxx}} 或 {{ xxx }} 格式(支持空格)
pattern = r'\{\{\s*([^}]+?)\s*\}\}'
for match in re.finditer(pattern, template):
start, end = match.span()
# 添加前面的静态文本
if start > last_end:
static_text = template[last_end:start]
if static_text:
parts.append({"type": "static", "content": static_text})
# 解析动态引用
ref = match.group(1).strip()
# 检查是否是节点引用(如 llm.output 或 llm_qa.output
if '.' in ref:
node_id, field = ref.split('.', 1)
parts.append({
"type": "dynamic",
"node_id": node_id,
"field": field,
"raw": ref
})
else:
# 其他引用(如 {{var.xxx}}),当作静态处理
# 直接渲染这部分
rendered = self._render_template(f"{{{{{ref}}}}}", state)
parts.append({"type": "static", "content": rendered})
last_end = end
# 添加最后的静态文本
if last_end < len(template):
static_text = template[last_end:]
if static_text:
parts.append({"type": "static", "content": static_text})
return parts
async def execute_stream(self, state: WorkflowState): async def execute_stream(self, state: WorkflowState):
"""流式执行 end 节点业务逻辑 """流式执行 end 节点业务逻辑
当 end 节点前面是 LLM 节点时,流式输出其内容。 智能输出策略:
1. 检测模板中是否引用了直接上游节点
2. 如果引用了,只输出该引用**之后**的部分(后缀)
3. 前缀和引用内容已经在上游节点流式输出时发送了
示例:'{{start.test}}hahaha {{ llm_qa.output }} lalalalala a'
- 直接上游节点是 llm_qa
- 前缀 '{{start.test}}hahaha ' 已在 LLM 节点流式输出前发送
- LLM 内容在 LLM 节点流式输出
- End 节点只输出 ' lalalalala a'(后缀,一次性输出)
Args: Args:
state: 工作流状态 state: 工作流状态
Yields: Yields:
文本片段chunk完成标记 完成标记
""" """
logger.info(f"节点 {self.node_id} (End) 开始执行(流式)") logger.info(f"节点 {self.node_id} (End) 开始执行(流式)")
# 获取配置的输出模板 # 获取配置的输出模板
output_template = self.config.get("output") output_template = self.config.get("output")
# 如果配置了输出模板,使用模板渲染 if not output_template:
if output_template:
output = self._render_template(output_template, state)
# 检查输出中是否包含节点引用(如 {{llm_node.output}}
# 如果包含,则逐字符流式输出
if output:
# 逐字符流式输出
for char in output:
yield char
else:
output = "工作流已完成" output = "工作流已完成"
for char in output: yield {"__final__": True, "result": output}
yield char return
# 统计信息(用于日志) # 找到直接上游节点
direct_upstream_nodes = []
for edge in self.workflow_config.get("edges", []):
if edge.get("target") == self.node_id:
source_node_id = edge.get("source")
direct_upstream_nodes.append(source_node_id)
logger.info(f"节点 {self.node_id} 的直接上游节点: {direct_upstream_nodes}")
# 解析模板部分
parts = self._parse_template_parts(output_template, state)
logger.info(f"节点 {self.node_id} 解析模板,共 {len(parts)} 个部分")
# 找到第一个引用直接上游节点的动态引用
upstream_ref_index = None
for i, part in enumerate(parts):
if part["type"] == "dynamic" and part["node_id"] in direct_upstream_nodes:
upstream_ref_index = i
logger.info(f"节点 {self.node_id} 找到直接上游节点 {part['node_id']} 的引用,索引: {i}")
break
if upstream_ref_index is None:
# 没有引用直接上游节点,正常输出(渲染完整模板)
output = self._render_template(output_template, state)
logger.info(f"节点 {self.node_id} 没有引用直接上游节点,输出完整内容")
yield {"__final__": True, "result": output}
return
# 有引用直接上游节点,只输出该引用之后的部分(后缀)
logger.info(f"节点 {self.node_id} 检测到直接上游节点引用,只输出后缀部分(从索引 {upstream_ref_index + 1} 开始)")
# 收集后缀部分
suffix_parts = []
for i in range(upstream_ref_index + 1, len(parts)):
part = parts[i]
if part["type"] == "static":
# 静态文本
suffix_parts.append(part["content"])
elif part["type"] == "dynamic":
# 其他动态引用(如果有多个引用)
node_id = part["node_id"]
field = part["field"]
# 从 streaming_buffer 或 node_outputs 读取
streaming_buffer = state.get("streaming_buffer", {})
if node_id in streaming_buffer:
buffer_data = streaming_buffer[node_id]
content = buffer_data.get("full_content", "")
else:
node_outputs = state.get("node_outputs", {})
runtime_vars = state.get("runtime_vars", {})
content = ""
if node_id in node_outputs:
node_output = node_outputs[node_id]
if isinstance(node_output, dict):
content = str(node_output.get(field, ""))
elif node_id in runtime_vars:
runtime_var = runtime_vars[node_id]
if isinstance(runtime_var, dict):
content = str(runtime_var.get(field, ""))
suffix_parts.append(content)
# 拼接后缀
suffix = "".join(suffix_parts)
# 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀)
full_output = self._render_template(output_template, state)
if suffix:
logger.info(f"节点 {self.node_id} 输出后缀: '{suffix[:50]}...' (长度: {len(suffix)})")
# 一次性输出后缀(作为单个 chunk
# 注意:不要直接 yield 字符串,因为 base_node 会逐字符处理
# 而是通过 writer 直接发送
from langgraph.config import get_stream_writer
writer = get_stream_writer()
writer({
"node_id": self.node_id,
"chunk": suffix,
"full_content": full_output, # full_content 是完整的渲染结果(前缀+LLM+后缀)
"chunk_index": 1,
"is_suffix": True
})
logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀full_content 长度: {len(full_output)}")
else:
logger.info(f"节点 {self.node_id} 没有后缀需要输出")
# 统计信息
node_outputs = state.get("node_outputs", {}) node_outputs = state.get("node_outputs", {})
total_nodes = len(node_outputs) total_nodes = len(node_outputs)
logger.info(f"节点 {self.node_id} (End) 执行完成(流式),共执行 {total_nodes} 个节点") logger.info(f"节点 {self.node_id} (End) 执行完成(流式),共执行 {total_nodes} 个节点")
# yield 完成标记 # yield 完成标记(包含完整输出)
yield {"__final__": True, "result": output} yield {"__final__": True, "result": full_output}

View File

@@ -213,18 +213,44 @@ class LLMNode(BaseNode):
Yields: Yields:
文本片段chunk或完成标记 文本片段chunk或完成标记
""" """
from langgraph.config import get_stream_writer
llm, prompt_or_messages = self._prepare_llm(state, True) llm, prompt_or_messages = self._prepare_llm(state, True)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)") logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}") logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
# 检查是否有注入的 End 节点前缀配置
writer = get_stream_writer()
end_prefix = getattr(self, '_end_node_prefix', None)
logger.info(f"[LLM前缀] 节点 {self.node_id} 检查前缀配置: {end_prefix is not None}")
if end_prefix:
logger.info(f"[LLM前缀] 前缀内容: '{end_prefix}'")
if end_prefix:
# 渲染前缀(可能包含其他变量)
try:
rendered_prefix = self._render_template(end_prefix, state)
logger.info(f"节点 {self.node_id} 提前发送 End 节点前缀: '{rendered_prefix[:50]}...'")
# 提前发送 End 节点的前缀
writer({
"node_id": "end", # 标记为 end 节点的输出
"chunk": rendered_prefix,
"full_content": rendered_prefix,
"chunk_index": 0,
"is_prefix": True # 标记这是前缀
})
except Exception as e:
logger.warning(f"渲染/发送 End 节点前缀失败: {e}")
# 累积完整响应 # 累积完整响应
full_response = "" full_response = ""
last_chunk = None last_chunk = None
chunk_count = 0 chunk_count = 0
# 调用 LLM流式支持字符串或消息列表 # 调用 LLM流式支持字符串或消息列表
# 注意astream 方法本身就是流式的,不需要额外配置
async for chunk in llm.astream(prompt_or_messages): async for chunk in llm.astream(prompt_or_messages):
# 提取内容 # 提取内容
if hasattr(chunk, 'content'): if hasattr(chunk, 'content'):
@@ -238,9 +264,8 @@ class LLMNode(BaseNode):
last_chunk = chunk last_chunk = chunk
chunk_count += 1 chunk_count += 1
# logger.debug(f"节点 {self.node_id} LLM chunk #{chunk_count}: {content[:50]}...")
# 流式返回每个文本片段 # 流式返回每个文本片段
yield content #AIMessage(content=content) yield content
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}") logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")