[add] workflow llm & end logic
This commit is contained in:
@@ -87,11 +87,75 @@ class WorkflowExecutor:
|
||||
"workspace_id": self.workspace_id,
|
||||
"user_id": self.user_id,
|
||||
"error": None,
|
||||
"error_node": None
|
||||
"error_node": None,
|
||||
"streaming_buffer": {} # 流式缓冲区
|
||||
}
|
||||
|
||||
|
||||
|
||||
def _analyze_end_node_prefixes(self) -> dict[str, str]:
|
||||
"""分析 End 节点的前缀配置
|
||||
|
||||
检查每个 End 节点的模板,找到直接上游节点的引用,
|
||||
提取该引用之前的前缀部分。
|
||||
|
||||
Returns:
|
||||
字典:{上游节点ID: End节点前缀}
|
||||
"""
|
||||
import re
|
||||
|
||||
prefixes = {}
|
||||
|
||||
# 找到所有 End 节点
|
||||
end_nodes = [node for node in self.nodes if node.get("type") == "end"]
|
||||
logger.info(f"[前缀分析] 找到 {len(end_nodes)} 个 End 节点")
|
||||
|
||||
for end_node in end_nodes:
|
||||
end_node_id = end_node.get("id")
|
||||
output_template = end_node.get("config", {}).get("output")
|
||||
|
||||
logger.info(f"[前缀分析] End 节点 {end_node_id} 模板: {output_template}")
|
||||
|
||||
if not output_template:
|
||||
continue
|
||||
|
||||
# 找到所有直接连接到 End 节点的上游节点
|
||||
direct_upstream_nodes = []
|
||||
for edge in self.edges:
|
||||
if edge.get("target") == end_node_id:
|
||||
source_node_id = edge.get("source")
|
||||
direct_upstream_nodes.append(source_node_id)
|
||||
|
||||
logger.info(f"[前缀分析] End 节点的直接上游节点: {direct_upstream_nodes}")
|
||||
|
||||
# 查找模板中引用了哪些节点
|
||||
# 匹配 {{node_id.xxx}} 或 {{ node_id.xxx }} 格式(支持空格)
|
||||
pattern = r'\{\{\s*([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\s*\}\}'
|
||||
matches = list(re.finditer(pattern, output_template))
|
||||
|
||||
logger.info(f"[前缀分析] 模板中找到 {len(matches)} 个节点引用")
|
||||
|
||||
# 找到第一个直接上游节点的引用
|
||||
for match in matches:
|
||||
referenced_node_id = match.group(1)
|
||||
logger.info(f"[前缀分析] 检查引用: {referenced_node_id}")
|
||||
|
||||
if referenced_node_id in direct_upstream_nodes:
|
||||
# 这是直接上游节点的引用,提取前缀
|
||||
prefix = output_template[:match.start()]
|
||||
|
||||
logger.info(f"[前缀分析] ✅ 找到直接上游节点 {referenced_node_id} 的引用,前缀: '{prefix}'")
|
||||
|
||||
if prefix:
|
||||
prefixes[referenced_node_id] = prefix
|
||||
logger.info(f"✅ [前缀分析] 为节点 {referenced_node_id} 配置前缀: '{prefix[:50]}...'")
|
||||
|
||||
# 只处理第一个直接上游节点的引用
|
||||
break
|
||||
|
||||
logger.info(f"[前缀分析] 最终配置: {prefixes}")
|
||||
return prefixes
|
||||
|
||||
def build_graph(self,stream=False) -> CompiledStateGraph:
|
||||
"""构建 LangGraph
|
||||
|
||||
@@ -99,6 +163,9 @@ class WorkflowExecutor:
|
||||
编译后的状态图
|
||||
"""
|
||||
logger.info(f"开始构建工作流图: execution_id={self.execution_id}")
|
||||
|
||||
# 分析 End 节点的前缀配置
|
||||
end_prefixes = self._analyze_end_node_prefixes() if stream else {}
|
||||
|
||||
# 1. 创建状态图
|
||||
workflow = StateGraph(WorkflowState)
|
||||
@@ -120,6 +187,12 @@ class WorkflowExecutor:
|
||||
# 创建节点实例(现在 start 和 end 也会被创建)
|
||||
node_instance = NodeFactory.create_node(node, self.workflow_config)
|
||||
if node_instance:
|
||||
# 如果是流式模式,且节点有 End 前缀配置,注入配置
|
||||
if stream and node_id in end_prefixes:
|
||||
# 将 End 前缀配置注入到节点实例
|
||||
node_instance._end_node_prefix = end_prefixes[node_id]
|
||||
logger.info(f"为节点 {node_id} 注入 End 前缀配置")
|
||||
|
||||
# 包装节点的 run 方法
|
||||
# 使用函数工厂避免闭包问题
|
||||
if stream:
|
||||
@@ -309,29 +382,48 @@ class WorkflowExecutor:
|
||||
# 2. 初始化状态(自动注入系统变量)
|
||||
initial_state = self._prepare_initial_state(input_data)
|
||||
|
||||
# 3. 执行工作流
|
||||
# 3. Execute workflow
|
||||
try:
|
||||
chunk_count = 0
|
||||
async for event in graph.astream(
|
||||
initial_state,
|
||||
stream_mode=["updates", "debug"],
|
||||
stream_mode=["updates", "debug", "custom"], # Use updates + debug + custom mode
|
||||
):
|
||||
mode, data = event
|
||||
# event should be a tuple: (mode, data)
|
||||
# But let's handle both cases
|
||||
if isinstance(event, tuple) and len(event) == 2:
|
||||
mode, data = event
|
||||
else:
|
||||
# Unexpected format, log and skip
|
||||
logger.warning(f"[STREAM] Unexpected event format: {type(event)}, value: {event}")
|
||||
continue
|
||||
|
||||
if mode == "debug":
|
||||
# 处理调试信息(节点执行状态)
|
||||
if mode == "custom":
|
||||
# Handle custom streaming events (chunks from nodes via stream writer)
|
||||
chunk_count += 1
|
||||
logger.info(f"[CUSTOM] ✅ 收到 chunk #{chunk_count} from {data.get('node_id')}")
|
||||
yield {
|
||||
"type": "node_chunk",
|
||||
"node_id": data.get("node_id"),
|
||||
"chunk": data.get("chunk"),
|
||||
"full_content": data.get("full_content"),
|
||||
"chunk_index": data.get("chunk_index")
|
||||
}
|
||||
|
||||
elif mode == "debug":
|
||||
# Handle debug information (node execution status)
|
||||
event_type = data.get("type")
|
||||
payload = data.get("payload", {})
|
||||
node_name = payload.get("name")
|
||||
|
||||
if event_type == "task":
|
||||
# 节点开始执行
|
||||
# Node starts execution
|
||||
inputv = payload.get("input", {})
|
||||
variables = inputv.get("variables", {})
|
||||
variables_sys = variables.get("sys", {})
|
||||
conversation_id = variables_sys.get("conversation_id")
|
||||
execution_id = variables_sys.get("execution_id")
|
||||
logger.info(f"[DEBUG] 节点开始执行: {node_name}")
|
||||
logger.info(f"[DEBUG] Node starts execution: {node_name}")
|
||||
yield {
|
||||
"type": "node_start",
|
||||
"node_id": node_name,
|
||||
@@ -340,16 +432,16 @@ class WorkflowExecutor:
|
||||
"timestamp": data.get("timestamp")
|
||||
}
|
||||
elif event_type == "task_result":
|
||||
# 节点执行完成
|
||||
# Node execution completed
|
||||
result = payload.get("result", {})
|
||||
inputv = result.get("input", {})
|
||||
variables = inputv.get("variables", {})
|
||||
variables_sys = variables.get("sys", {})
|
||||
conversation_id = variables_sys.get("conversation_id")
|
||||
execution_id = variables_sys.get("execution_id")
|
||||
logger.info(f"[DEBUG] 节点执行完成: {node_name}")
|
||||
logger.info(f"[DEBUG] Node execution completed: {node_name}")
|
||||
yield {
|
||||
"type": "node_end",
|
||||
"type": "node_complete",
|
||||
"node_id": node_name,
|
||||
"conversation_id": conversation_id,
|
||||
"execution_id": execution_id,
|
||||
@@ -357,27 +449,10 @@ class WorkflowExecutor:
|
||||
}
|
||||
|
||||
elif mode == "updates":
|
||||
# 处理 state 更新
|
||||
# data 是一个字典,key 是节点 ID,value 是 state 更新或 chunk
|
||||
print("="*50)
|
||||
print(data)
|
||||
print("-"*50)
|
||||
for node_id, update in data.items():
|
||||
if isinstance(update, dict) and update.get("type") == "chunk":
|
||||
# 这是流式 chunk,转发给客户端
|
||||
chunk_count += 1
|
||||
logger.debug(f"[UPDATE] 收到 chunk #{chunk_count} from {node_id}: {update.get('content')[:50]}...")
|
||||
yield {
|
||||
"type": "node_chunk",
|
||||
"node_id": update.get("node_id"),
|
||||
"chunk": update.get("content"),
|
||||
"full_content": update.get("full_content")
|
||||
}
|
||||
else:
|
||||
logger.debug(f"[UPDATE] 收到 state 更新 from {node_id}")
|
||||
# 其他情况(state 更新)会被 LangGraph 自动合并到 state
|
||||
# Handle state updates
|
||||
logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())}")
|
||||
|
||||
logger.info(f"工作流执行完成(流式),总 chunks: {chunk_count}")
|
||||
logger.info(f"Workflow execution completed (streaming), total chunks: {chunk_count}")
|
||||
|
||||
except Exception as e:
|
||||
# 计算耗时(即使失败也记录)
|
||||
|
||||
@@ -10,6 +10,7 @@ from abc import ABC, abstractmethod
|
||||
from typing import Any, TypedDict, Annotated
|
||||
from operator import add
|
||||
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
@@ -43,6 +44,10 @@ class WorkflowState(TypedDict):
|
||||
# 错误信息(用于错误边)
|
||||
error: str | None
|
||||
error_node: str | None
|
||||
|
||||
# 流式缓冲区(存储节点的实时流式输出)
|
||||
# 格式:{node_id: {"chunks": [...], "full_content": "..."}}
|
||||
streaming_buffer: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
|
||||
|
||||
|
||||
class BaseNode(ABC):
|
||||
@@ -201,23 +206,25 @@ class BaseNode(ABC):
|
||||
return self._wrap_error(str(e), elapsed_time, state)
|
||||
|
||||
async def run_stream(self, state: WorkflowState):
|
||||
"""执行节点(带错误处理和输出包装,流式)
|
||||
"""Execute node with error handling and output wrapping (streaming)
|
||||
|
||||
这个方法由 Executor 调用,负责:
|
||||
1. 时间统计
|
||||
2. 调用节点的 execute_stream() 方法
|
||||
3. 将业务数据包装成标准输出格式
|
||||
4. 错误处理
|
||||
This method is called by the Executor and is responsible for:
|
||||
1. Time tracking
|
||||
2. Calling the node's execute_stream() method
|
||||
3. Using LangGraph's stream writer to send chunks
|
||||
4. Updating streaming buffer in state for downstream nodes
|
||||
5. Wrapping business data into standard output format
|
||||
6. Error handling
|
||||
|
||||
注意:在流式模式下,我们需要:
|
||||
- yield 中间的 chunk 事件(用于实时显示)
|
||||
- 最后 yield 一个包含 state 更新的字典(LangGraph 会合并到 state)
|
||||
Special handling for End nodes:
|
||||
- End nodes don't send chunks via writer (prefix and LLM content already sent)
|
||||
- End nodes only yield suffix for final result assembly
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
state: Workflow state
|
||||
|
||||
Yields:
|
||||
标准化的流式事件和最终的 state 更新
|
||||
State updates with streaming buffer and final result
|
||||
"""
|
||||
import time
|
||||
|
||||
@@ -226,63 +233,102 @@ class BaseNode(ABC):
|
||||
try:
|
||||
timeout = self.get_timeout()
|
||||
|
||||
# 累积完整结果(用于最后的包装)
|
||||
# Get LangGraph's stream writer for sending custom data
|
||||
writer = get_stream_writer()
|
||||
|
||||
# Check if this is an End node
|
||||
# End nodes CAN send chunks (for suffix), but only after LLM content
|
||||
is_end_node = self.node_type == "end"
|
||||
|
||||
# Accumulate complete result (for final wrapping)
|
||||
chunks = []
|
||||
final_result = None
|
||||
chunk_count = 0
|
||||
|
||||
# 使用异步生成器包装,支持超时
|
||||
async def stream_with_timeout():
|
||||
nonlocal final_result
|
||||
loop_start = asyncio.get_event_loop().time()
|
||||
# Stream chunks in real-time
|
||||
loop_start = asyncio.get_event_loop().time()
|
||||
|
||||
async for item in self.execute_stream(state):
|
||||
# Check timeout
|
||||
if asyncio.get_event_loop().time() - loop_start > timeout:
|
||||
raise TimeoutError()
|
||||
|
||||
async for item in self.execute_stream(state):
|
||||
# 检查超时
|
||||
if asyncio.get_event_loop().time() - loop_start > timeout:
|
||||
raise TimeoutError()
|
||||
# Check if it's a completion marker
|
||||
if isinstance(item, dict) and item.get("__final__"):
|
||||
final_result = item["result"]
|
||||
elif isinstance(item, str):
|
||||
# String is a chunk
|
||||
chunk_count += 1
|
||||
chunks.append(item)
|
||||
full_content = "".join(chunks)
|
||||
|
||||
# 检查是否是完成标记
|
||||
if isinstance(item, dict) and item.get("__final__"):
|
||||
final_result = item["result"]
|
||||
elif isinstance(item, str):
|
||||
# 字符串是 chunk
|
||||
# print("="*50)
|
||||
# print(item)
|
||||
# print("-"*50)
|
||||
chunks.append(item)
|
||||
# Send chunks for all nodes (including End nodes for suffix)
|
||||
logger.debug(f"节点 {self.node_id} 发送 chunk #{chunk_count}: {item[:50]}...")
|
||||
|
||||
# 1. Send via stream writer (for real-time client updates)
|
||||
writer({
|
||||
"node_id": self.node_id,
|
||||
"chunk": item,
|
||||
"full_content": full_content,
|
||||
"chunk_index": chunk_count
|
||||
})
|
||||
|
||||
# 2. Update streaming buffer in state (for downstream nodes)
|
||||
# Only non-End nodes need streaming buffer
|
||||
if not is_end_node:
|
||||
yield {
|
||||
"type": "chunk",
|
||||
"node_id": self.node_id,
|
||||
"content": item,
|
||||
"full_content": "".join(chunks)
|
||||
"streaming_buffer": {
|
||||
self.node_id: {
|
||||
"full_content": full_content,
|
||||
"chunk_count": chunk_count,
|
||||
"is_complete": False
|
||||
}
|
||||
}
|
||||
}
|
||||
else:
|
||||
# 其他类型也当作 chunk 处理
|
||||
chunks.append(str(item))
|
||||
else:
|
||||
# Other types are also treated as chunks
|
||||
chunk_count += 1
|
||||
chunk_str = str(item)
|
||||
chunks.append(chunk_str)
|
||||
full_content = "".join(chunks)
|
||||
|
||||
# Send chunks for all nodes
|
||||
writer({
|
||||
"node_id": self.node_id,
|
||||
"chunk": chunk_str,
|
||||
"full_content": full_content,
|
||||
"chunk_index": chunk_count
|
||||
})
|
||||
|
||||
# Only non-End nodes need streaming buffer
|
||||
if not is_end_node:
|
||||
yield {
|
||||
"type": "chunk",
|
||||
"node_id": self.node_id,
|
||||
"content": str(item),
|
||||
"full_content": "".join(chunks)
|
||||
"streaming_buffer": {
|
||||
self.node_id: {
|
||||
"full_content": full_content,
|
||||
"chunk_count": chunk_count,
|
||||
"is_complete": False
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async for chunk_event in stream_with_timeout():
|
||||
yield chunk_event
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# 提取处理后的输出(调用子类的 _extract_output)
|
||||
logger.info(f"节点 {self.node_id} 流式执行完成,耗时: {elapsed_time:.2f}s, chunks: {chunk_count}")
|
||||
|
||||
# Extract processed output (call subclass's _extract_output)
|
||||
extracted_output = self._extract_output(final_result)
|
||||
|
||||
# 包装最终结果
|
||||
# Wrap final result
|
||||
final_output = self._wrap_output(final_result, elapsed_time, state)
|
||||
|
||||
# 将提取后的输出存储到运行时变量中(供后续节点快速访问)
|
||||
# Store extracted output in runtime variables (for quick access by subsequent nodes)
|
||||
if isinstance(extracted_output, dict):
|
||||
runtime_var = extracted_output
|
||||
else:
|
||||
runtime_var = {"output": extracted_output}
|
||||
|
||||
# 构建完整的 state 更新(包含 node_outputs 和 runtime_vars)
|
||||
# Build complete state update (including node_outputs, runtime_vars, and final streaming buffer)
|
||||
state_update = {
|
||||
**final_output,
|
||||
"runtime_vars": {
|
||||
@@ -290,13 +336,24 @@ class BaseNode(ABC):
|
||||
}
|
||||
}
|
||||
|
||||
# 最后 yield 纯粹的 state 更新(LangGraph 会合并到 state 中)
|
||||
# Add streaming buffer for non-End nodes
|
||||
if not is_end_node:
|
||||
state_update["streaming_buffer"] = {
|
||||
self.node_id: {
|
||||
"full_content": "".join(chunks),
|
||||
"chunk_count": chunk_count,
|
||||
"is_complete": True # Mark as complete
|
||||
}
|
||||
}
|
||||
|
||||
# Finally yield state update
|
||||
# LangGraph will merge this into state
|
||||
yield state_update
|
||||
|
||||
except TimeoutError:
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.error(f"节点 {self.node_id} 执行超时({timeout}秒)")
|
||||
error_output = self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state)
|
||||
logger.error(f"节点 {self.node_id} 执行超时 ({timeout}s)")
|
||||
error_output = self._wrap_error(f"节点执行超时 ({timeout}s)", elapsed_time, state)
|
||||
yield error_output
|
||||
except Exception as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
@@ -5,6 +5,8 @@ End 节点实现
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import asyncio
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
|
||||
@@ -15,6 +17,7 @@ class EndNode(BaseNode):
|
||||
"""End 节点
|
||||
|
||||
工作流的结束节点,根据配置的模板输出最终结果。
|
||||
支持实时流式输出:如果模板引用了上游节点的输出,会实时监听其流式缓冲区。
|
||||
"""
|
||||
|
||||
async def execute(self, state: WorkflowState) -> str:
|
||||
@@ -45,42 +48,209 @@ class EndNode(BaseNode):
|
||||
|
||||
return output
|
||||
|
||||
def _extract_referenced_nodes(self, template: str) -> list[str]:
|
||||
"""从模板中提取引用的节点 ID
|
||||
|
||||
例如:'结果:{{llm_qa.output}}' -> ['llm_qa']
|
||||
|
||||
Args:
|
||||
template: 模板字符串
|
||||
|
||||
Returns:
|
||||
引用的节点 ID 列表
|
||||
"""
|
||||
# 匹配 {{node_id.xxx}} 格式
|
||||
pattern = r'\{\{([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\}\}'
|
||||
matches = re.findall(pattern, template)
|
||||
return list(set(matches)) # 去重
|
||||
|
||||
def _parse_template_parts(self, template: str, state: WorkflowState) -> list[dict]:
|
||||
"""解析模板,分离静态文本和动态引用
|
||||
|
||||
例如:'你好 {{llm.output}}, 这是后缀'
|
||||
返回:[
|
||||
{"type": "static", "content": "你好 "},
|
||||
{"type": "dynamic", "node_id": "llm", "field": "output"},
|
||||
{"type": "static", "content": ", 这是后缀"}
|
||||
]
|
||||
|
||||
Args:
|
||||
template: 模板字符串
|
||||
state: 工作流状态
|
||||
|
||||
Returns:
|
||||
模板部分列表
|
||||
"""
|
||||
import re
|
||||
|
||||
parts = []
|
||||
last_end = 0
|
||||
|
||||
# 匹配 {{xxx}} 或 {{ xxx }} 格式(支持空格)
|
||||
pattern = r'\{\{\s*([^}]+?)\s*\}\}'
|
||||
|
||||
for match in re.finditer(pattern, template):
|
||||
start, end = match.span()
|
||||
|
||||
# 添加前面的静态文本
|
||||
if start > last_end:
|
||||
static_text = template[last_end:start]
|
||||
if static_text:
|
||||
parts.append({"type": "static", "content": static_text})
|
||||
|
||||
# 解析动态引用
|
||||
ref = match.group(1).strip()
|
||||
|
||||
# 检查是否是节点引用(如 llm.output 或 llm_qa.output)
|
||||
if '.' in ref:
|
||||
node_id, field = ref.split('.', 1)
|
||||
parts.append({
|
||||
"type": "dynamic",
|
||||
"node_id": node_id,
|
||||
"field": field,
|
||||
"raw": ref
|
||||
})
|
||||
else:
|
||||
# 其他引用(如 {{var.xxx}}),当作静态处理
|
||||
# 直接渲染这部分
|
||||
rendered = self._render_template(f"{{{{{ref}}}}}", state)
|
||||
parts.append({"type": "static", "content": rendered})
|
||||
|
||||
last_end = end
|
||||
|
||||
# 添加最后的静态文本
|
||||
if last_end < len(template):
|
||||
static_text = template[last_end:]
|
||||
if static_text:
|
||||
parts.append({"type": "static", "content": static_text})
|
||||
|
||||
return parts
|
||||
|
||||
async def execute_stream(self, state: WorkflowState):
|
||||
"""流式执行 end 节点业务逻辑
|
||||
|
||||
当 end 节点前面是 LLM 节点时,流式输出其内容。
|
||||
智能输出策略:
|
||||
1. 检测模板中是否引用了直接上游节点
|
||||
2. 如果引用了,只输出该引用**之后**的部分(后缀)
|
||||
3. 前缀和引用内容已经在上游节点流式输出时发送了
|
||||
|
||||
示例:'{{start.test}}hahaha {{ llm_qa.output }} lalalalala a'
|
||||
- 直接上游节点是 llm_qa
|
||||
- 前缀 '{{start.test}}hahaha ' 已在 LLM 节点流式输出前发送
|
||||
- LLM 内容在 LLM 节点流式输出
|
||||
- End 节点只输出 ' lalalalala a'(后缀,一次性输出)
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
|
||||
Yields:
|
||||
文本片段(chunk)或完成标记
|
||||
完成标记
|
||||
"""
|
||||
logger.info(f"节点 {self.node_id} (End) 开始执行(流式)")
|
||||
|
||||
# 获取配置的输出模板
|
||||
output_template = self.config.get("output")
|
||||
|
||||
# 如果配置了输出模板,使用模板渲染
|
||||
if output_template:
|
||||
output = self._render_template(output_template, state)
|
||||
|
||||
# 检查输出中是否包含节点引用(如 {{llm_node.output}})
|
||||
# 如果包含,则逐字符流式输出
|
||||
if output:
|
||||
# 逐字符流式输出
|
||||
for char in output:
|
||||
yield char
|
||||
else:
|
||||
if not output_template:
|
||||
output = "工作流已完成"
|
||||
for char in output:
|
||||
yield char
|
||||
yield {"__final__": True, "result": output}
|
||||
return
|
||||
|
||||
# 统计信息(用于日志)
|
||||
# 找到直接上游节点
|
||||
direct_upstream_nodes = []
|
||||
for edge in self.workflow_config.get("edges", []):
|
||||
if edge.get("target") == self.node_id:
|
||||
source_node_id = edge.get("source")
|
||||
direct_upstream_nodes.append(source_node_id)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 的直接上游节点: {direct_upstream_nodes}")
|
||||
|
||||
# 解析模板部分
|
||||
parts = self._parse_template_parts(output_template, state)
|
||||
logger.info(f"节点 {self.node_id} 解析模板,共 {len(parts)} 个部分")
|
||||
|
||||
# 找到第一个引用直接上游节点的动态引用
|
||||
upstream_ref_index = None
|
||||
for i, part in enumerate(parts):
|
||||
if part["type"] == "dynamic" and part["node_id"] in direct_upstream_nodes:
|
||||
upstream_ref_index = i
|
||||
logger.info(f"节点 {self.node_id} 找到直接上游节点 {part['node_id']} 的引用,索引: {i}")
|
||||
break
|
||||
|
||||
if upstream_ref_index is None:
|
||||
# 没有引用直接上游节点,正常输出(渲染完整模板)
|
||||
output = self._render_template(output_template, state)
|
||||
logger.info(f"节点 {self.node_id} 没有引用直接上游节点,输出完整内容")
|
||||
yield {"__final__": True, "result": output}
|
||||
return
|
||||
|
||||
# 有引用直接上游节点,只输出该引用之后的部分(后缀)
|
||||
logger.info(f"节点 {self.node_id} 检测到直接上游节点引用,只输出后缀部分(从索引 {upstream_ref_index + 1} 开始)")
|
||||
|
||||
# 收集后缀部分
|
||||
suffix_parts = []
|
||||
for i in range(upstream_ref_index + 1, len(parts)):
|
||||
part = parts[i]
|
||||
|
||||
if part["type"] == "static":
|
||||
# 静态文本
|
||||
suffix_parts.append(part["content"])
|
||||
|
||||
elif part["type"] == "dynamic":
|
||||
# 其他动态引用(如果有多个引用)
|
||||
node_id = part["node_id"]
|
||||
field = part["field"]
|
||||
|
||||
# 从 streaming_buffer 或 node_outputs 读取
|
||||
streaming_buffer = state.get("streaming_buffer", {})
|
||||
if node_id in streaming_buffer:
|
||||
buffer_data = streaming_buffer[node_id]
|
||||
content = buffer_data.get("full_content", "")
|
||||
else:
|
||||
node_outputs = state.get("node_outputs", {})
|
||||
runtime_vars = state.get("runtime_vars", {})
|
||||
|
||||
content = ""
|
||||
if node_id in node_outputs:
|
||||
node_output = node_outputs[node_id]
|
||||
if isinstance(node_output, dict):
|
||||
content = str(node_output.get(field, ""))
|
||||
elif node_id in runtime_vars:
|
||||
runtime_var = runtime_vars[node_id]
|
||||
if isinstance(runtime_var, dict):
|
||||
content = str(runtime_var.get(field, ""))
|
||||
|
||||
suffix_parts.append(content)
|
||||
|
||||
# 拼接后缀
|
||||
suffix = "".join(suffix_parts)
|
||||
|
||||
# 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀)
|
||||
full_output = self._render_template(output_template, state)
|
||||
|
||||
if suffix:
|
||||
logger.info(f"节点 {self.node_id} 输出后缀: '{suffix[:50]}...' (长度: {len(suffix)})")
|
||||
# 一次性输出后缀(作为单个 chunk)
|
||||
# 注意:不要直接 yield 字符串,因为 base_node 会逐字符处理
|
||||
# 而是通过 writer 直接发送
|
||||
from langgraph.config import get_stream_writer
|
||||
writer = get_stream_writer()
|
||||
writer({
|
||||
"node_id": self.node_id,
|
||||
"chunk": suffix,
|
||||
"full_content": full_output, # full_content 是完整的渲染结果(前缀+LLM+后缀)
|
||||
"chunk_index": 1,
|
||||
"is_suffix": True
|
||||
})
|
||||
logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀,full_content 长度: {len(full_output)}")
|
||||
else:
|
||||
logger.info(f"节点 {self.node_id} 没有后缀需要输出")
|
||||
|
||||
# 统计信息
|
||||
node_outputs = state.get("node_outputs", {})
|
||||
total_nodes = len(node_outputs)
|
||||
|
||||
logger.info(f"节点 {self.node_id} (End) 执行完成(流式),共执行 {total_nodes} 个节点")
|
||||
logger.info(f"节点 {self.node_id} (End) 执行完成(流式),共执行了 {total_nodes} 个节点")
|
||||
|
||||
# yield 完成标记
|
||||
yield {"__final__": True, "result": output}
|
||||
# yield 完成标记(包含完整输出)
|
||||
yield {"__final__": True, "result": full_output}
|
||||
|
||||
@@ -213,18 +213,44 @@ class LLMNode(BaseNode):
|
||||
Yields:
|
||||
文本片段(chunk)或完成标记
|
||||
"""
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
llm, prompt_or_messages = self._prepare_llm(state, True)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
|
||||
logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
|
||||
|
||||
# 检查是否有注入的 End 节点前缀配置
|
||||
writer = get_stream_writer()
|
||||
end_prefix = getattr(self, '_end_node_prefix', None)
|
||||
|
||||
logger.info(f"[LLM前缀] 节点 {self.node_id} 检查前缀配置: {end_prefix is not None}")
|
||||
if end_prefix:
|
||||
logger.info(f"[LLM前缀] 前缀内容: '{end_prefix}'")
|
||||
|
||||
if end_prefix:
|
||||
# 渲染前缀(可能包含其他变量)
|
||||
try:
|
||||
rendered_prefix = self._render_template(end_prefix, state)
|
||||
logger.info(f"节点 {self.node_id} 提前发送 End 节点前缀: '{rendered_prefix[:50]}...'")
|
||||
|
||||
# 提前发送 End 节点的前缀
|
||||
writer({
|
||||
"node_id": "end", # 标记为 end 节点的输出
|
||||
"chunk": rendered_prefix,
|
||||
"full_content": rendered_prefix,
|
||||
"chunk_index": 0,
|
||||
"is_prefix": True # 标记这是前缀
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning(f"渲染/发送 End 节点前缀失败: {e}")
|
||||
|
||||
# 累积完整响应
|
||||
full_response = ""
|
||||
last_chunk = None
|
||||
chunk_count = 0
|
||||
|
||||
# 调用 LLM(流式,支持字符串或消息列表)
|
||||
# 注意:astream 方法本身就是流式的,不需要额外配置
|
||||
async for chunk in llm.astream(prompt_or_messages):
|
||||
# 提取内容
|
||||
if hasattr(chunk, 'content'):
|
||||
@@ -238,9 +264,8 @@ class LLMNode(BaseNode):
|
||||
last_chunk = chunk
|
||||
chunk_count += 1
|
||||
|
||||
# logger.debug(f"节点 {self.node_id} LLM chunk #{chunk_count}: {content[:50]}...")
|
||||
# 流式返回每个文本片段
|
||||
yield content #AIMessage(content=content)
|
||||
yield content
|
||||
|
||||
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user