[add] workflow llm & end logic
This commit is contained in:
@@ -87,11 +87,75 @@ class WorkflowExecutor:
|
|||||||
"workspace_id": self.workspace_id,
|
"workspace_id": self.workspace_id,
|
||||||
"user_id": self.user_id,
|
"user_id": self.user_id,
|
||||||
"error": None,
|
"error": None,
|
||||||
"error_node": None
|
"error_node": None,
|
||||||
|
"streaming_buffer": {} # 流式缓冲区
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _analyze_end_node_prefixes(self) -> dict[str, str]:
|
||||||
|
"""分析 End 节点的前缀配置
|
||||||
|
|
||||||
|
检查每个 End 节点的模板,找到直接上游节点的引用,
|
||||||
|
提取该引用之前的前缀部分。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
字典:{上游节点ID: End节点前缀}
|
||||||
|
"""
|
||||||
|
import re
|
||||||
|
|
||||||
|
prefixes = {}
|
||||||
|
|
||||||
|
# 找到所有 End 节点
|
||||||
|
end_nodes = [node for node in self.nodes if node.get("type") == "end"]
|
||||||
|
logger.info(f"[前缀分析] 找到 {len(end_nodes)} 个 End 节点")
|
||||||
|
|
||||||
|
for end_node in end_nodes:
|
||||||
|
end_node_id = end_node.get("id")
|
||||||
|
output_template = end_node.get("config", {}).get("output")
|
||||||
|
|
||||||
|
logger.info(f"[前缀分析] End 节点 {end_node_id} 模板: {output_template}")
|
||||||
|
|
||||||
|
if not output_template:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 找到所有直接连接到 End 节点的上游节点
|
||||||
|
direct_upstream_nodes = []
|
||||||
|
for edge in self.edges:
|
||||||
|
if edge.get("target") == end_node_id:
|
||||||
|
source_node_id = edge.get("source")
|
||||||
|
direct_upstream_nodes.append(source_node_id)
|
||||||
|
|
||||||
|
logger.info(f"[前缀分析] End 节点的直接上游节点: {direct_upstream_nodes}")
|
||||||
|
|
||||||
|
# 查找模板中引用了哪些节点
|
||||||
|
# 匹配 {{node_id.xxx}} 或 {{ node_id.xxx }} 格式(支持空格)
|
||||||
|
pattern = r'\{\{\s*([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\s*\}\}'
|
||||||
|
matches = list(re.finditer(pattern, output_template))
|
||||||
|
|
||||||
|
logger.info(f"[前缀分析] 模板中找到 {len(matches)} 个节点引用")
|
||||||
|
|
||||||
|
# 找到第一个直接上游节点的引用
|
||||||
|
for match in matches:
|
||||||
|
referenced_node_id = match.group(1)
|
||||||
|
logger.info(f"[前缀分析] 检查引用: {referenced_node_id}")
|
||||||
|
|
||||||
|
if referenced_node_id in direct_upstream_nodes:
|
||||||
|
# 这是直接上游节点的引用,提取前缀
|
||||||
|
prefix = output_template[:match.start()]
|
||||||
|
|
||||||
|
logger.info(f"[前缀分析] ✅ 找到直接上游节点 {referenced_node_id} 的引用,前缀: '{prefix}'")
|
||||||
|
|
||||||
|
if prefix:
|
||||||
|
prefixes[referenced_node_id] = prefix
|
||||||
|
logger.info(f"✅ [前缀分析] 为节点 {referenced_node_id} 配置前缀: '{prefix[:50]}...'")
|
||||||
|
|
||||||
|
# 只处理第一个直接上游节点的引用
|
||||||
|
break
|
||||||
|
|
||||||
|
logger.info(f"[前缀分析] 最终配置: {prefixes}")
|
||||||
|
return prefixes
|
||||||
|
|
||||||
def build_graph(self,stream=False) -> CompiledStateGraph:
|
def build_graph(self,stream=False) -> CompiledStateGraph:
|
||||||
"""构建 LangGraph
|
"""构建 LangGraph
|
||||||
|
|
||||||
@@ -99,6 +163,9 @@ class WorkflowExecutor:
|
|||||||
编译后的状态图
|
编译后的状态图
|
||||||
"""
|
"""
|
||||||
logger.info(f"开始构建工作流图: execution_id={self.execution_id}")
|
logger.info(f"开始构建工作流图: execution_id={self.execution_id}")
|
||||||
|
|
||||||
|
# 分析 End 节点的前缀配置
|
||||||
|
end_prefixes = self._analyze_end_node_prefixes() if stream else {}
|
||||||
|
|
||||||
# 1. 创建状态图
|
# 1. 创建状态图
|
||||||
workflow = StateGraph(WorkflowState)
|
workflow = StateGraph(WorkflowState)
|
||||||
@@ -120,6 +187,12 @@ class WorkflowExecutor:
|
|||||||
# 创建节点实例(现在 start 和 end 也会被创建)
|
# 创建节点实例(现在 start 和 end 也会被创建)
|
||||||
node_instance = NodeFactory.create_node(node, self.workflow_config)
|
node_instance = NodeFactory.create_node(node, self.workflow_config)
|
||||||
if node_instance:
|
if node_instance:
|
||||||
|
# 如果是流式模式,且节点有 End 前缀配置,注入配置
|
||||||
|
if stream and node_id in end_prefixes:
|
||||||
|
# 将 End 前缀配置注入到节点实例
|
||||||
|
node_instance._end_node_prefix = end_prefixes[node_id]
|
||||||
|
logger.info(f"为节点 {node_id} 注入 End 前缀配置")
|
||||||
|
|
||||||
# 包装节点的 run 方法
|
# 包装节点的 run 方法
|
||||||
# 使用函数工厂避免闭包问题
|
# 使用函数工厂避免闭包问题
|
||||||
if stream:
|
if stream:
|
||||||
@@ -309,29 +382,48 @@ class WorkflowExecutor:
|
|||||||
# 2. 初始化状态(自动注入系统变量)
|
# 2. 初始化状态(自动注入系统变量)
|
||||||
initial_state = self._prepare_initial_state(input_data)
|
initial_state = self._prepare_initial_state(input_data)
|
||||||
|
|
||||||
# 3. 执行工作流
|
# 3. Execute workflow
|
||||||
try:
|
try:
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
async for event in graph.astream(
|
async for event in graph.astream(
|
||||||
initial_state,
|
initial_state,
|
||||||
stream_mode=["updates", "debug"],
|
stream_mode=["updates", "debug", "custom"], # Use updates + debug + custom mode
|
||||||
):
|
):
|
||||||
mode, data = event
|
# event should be a tuple: (mode, data)
|
||||||
|
# But let's handle both cases
|
||||||
|
if isinstance(event, tuple) and len(event) == 2:
|
||||||
|
mode, data = event
|
||||||
|
else:
|
||||||
|
# Unexpected format, log and skip
|
||||||
|
logger.warning(f"[STREAM] Unexpected event format: {type(event)}, value: {event}")
|
||||||
|
continue
|
||||||
|
|
||||||
if mode == "debug":
|
if mode == "custom":
|
||||||
# 处理调试信息(节点执行状态)
|
# Handle custom streaming events (chunks from nodes via stream writer)
|
||||||
|
chunk_count += 1
|
||||||
|
logger.info(f"[CUSTOM] ✅ 收到 chunk #{chunk_count} from {data.get('node_id')}")
|
||||||
|
yield {
|
||||||
|
"type": "node_chunk",
|
||||||
|
"node_id": data.get("node_id"),
|
||||||
|
"chunk": data.get("chunk"),
|
||||||
|
"full_content": data.get("full_content"),
|
||||||
|
"chunk_index": data.get("chunk_index")
|
||||||
|
}
|
||||||
|
|
||||||
|
elif mode == "debug":
|
||||||
|
# Handle debug information (node execution status)
|
||||||
event_type = data.get("type")
|
event_type = data.get("type")
|
||||||
payload = data.get("payload", {})
|
payload = data.get("payload", {})
|
||||||
node_name = payload.get("name")
|
node_name = payload.get("name")
|
||||||
|
|
||||||
if event_type == "task":
|
if event_type == "task":
|
||||||
# 节点开始执行
|
# Node starts execution
|
||||||
inputv = payload.get("input", {})
|
inputv = payload.get("input", {})
|
||||||
variables = inputv.get("variables", {})
|
variables = inputv.get("variables", {})
|
||||||
variables_sys = variables.get("sys", {})
|
variables_sys = variables.get("sys", {})
|
||||||
conversation_id = variables_sys.get("conversation_id")
|
conversation_id = variables_sys.get("conversation_id")
|
||||||
execution_id = variables_sys.get("execution_id")
|
execution_id = variables_sys.get("execution_id")
|
||||||
logger.info(f"[DEBUG] 节点开始执行: {node_name}")
|
logger.info(f"[DEBUG] Node starts execution: {node_name}")
|
||||||
yield {
|
yield {
|
||||||
"type": "node_start",
|
"type": "node_start",
|
||||||
"node_id": node_name,
|
"node_id": node_name,
|
||||||
@@ -340,16 +432,16 @@ class WorkflowExecutor:
|
|||||||
"timestamp": data.get("timestamp")
|
"timestamp": data.get("timestamp")
|
||||||
}
|
}
|
||||||
elif event_type == "task_result":
|
elif event_type == "task_result":
|
||||||
# 节点执行完成
|
# Node execution completed
|
||||||
result = payload.get("result", {})
|
result = payload.get("result", {})
|
||||||
inputv = result.get("input", {})
|
inputv = result.get("input", {})
|
||||||
variables = inputv.get("variables", {})
|
variables = inputv.get("variables", {})
|
||||||
variables_sys = variables.get("sys", {})
|
variables_sys = variables.get("sys", {})
|
||||||
conversation_id = variables_sys.get("conversation_id")
|
conversation_id = variables_sys.get("conversation_id")
|
||||||
execution_id = variables_sys.get("execution_id")
|
execution_id = variables_sys.get("execution_id")
|
||||||
logger.info(f"[DEBUG] 节点执行完成: {node_name}")
|
logger.info(f"[DEBUG] Node execution completed: {node_name}")
|
||||||
yield {
|
yield {
|
||||||
"type": "node_end",
|
"type": "node_complete",
|
||||||
"node_id": node_name,
|
"node_id": node_name,
|
||||||
"conversation_id": conversation_id,
|
"conversation_id": conversation_id,
|
||||||
"execution_id": execution_id,
|
"execution_id": execution_id,
|
||||||
@@ -357,27 +449,10 @@ class WorkflowExecutor:
|
|||||||
}
|
}
|
||||||
|
|
||||||
elif mode == "updates":
|
elif mode == "updates":
|
||||||
# 处理 state 更新
|
# Handle state updates
|
||||||
# data 是一个字典,key 是节点 ID,value 是 state 更新或 chunk
|
logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())}")
|
||||||
print("="*50)
|
|
||||||
print(data)
|
|
||||||
print("-"*50)
|
|
||||||
for node_id, update in data.items():
|
|
||||||
if isinstance(update, dict) and update.get("type") == "chunk":
|
|
||||||
# 这是流式 chunk,转发给客户端
|
|
||||||
chunk_count += 1
|
|
||||||
logger.debug(f"[UPDATE] 收到 chunk #{chunk_count} from {node_id}: {update.get('content')[:50]}...")
|
|
||||||
yield {
|
|
||||||
"type": "node_chunk",
|
|
||||||
"node_id": update.get("node_id"),
|
|
||||||
"chunk": update.get("content"),
|
|
||||||
"full_content": update.get("full_content")
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
logger.debug(f"[UPDATE] 收到 state 更新 from {node_id}")
|
|
||||||
# 其他情况(state 更新)会被 LangGraph 自动合并到 state
|
|
||||||
|
|
||||||
logger.info(f"工作流执行完成(流式),总 chunks: {chunk_count}")
|
logger.info(f"Workflow execution completed (streaming), total chunks: {chunk_count}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# 计算耗时(即使失败也记录)
|
# 计算耗时(即使失败也记录)
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from abc import ABC, abstractmethod
|
|||||||
from typing import Any, TypedDict, Annotated
|
from typing import Any, TypedDict, Annotated
|
||||||
from operator import add
|
from operator import add
|
||||||
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage
|
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage
|
||||||
|
from langgraph.config import get_stream_writer
|
||||||
|
|
||||||
from app.core.workflow.variable_pool import VariablePool
|
from app.core.workflow.variable_pool import VariablePool
|
||||||
|
|
||||||
@@ -43,6 +44,10 @@ class WorkflowState(TypedDict):
|
|||||||
# 错误信息(用于错误边)
|
# 错误信息(用于错误边)
|
||||||
error: str | None
|
error: str | None
|
||||||
error_node: str | None
|
error_node: str | None
|
||||||
|
|
||||||
|
# 流式缓冲区(存储节点的实时流式输出)
|
||||||
|
# 格式:{node_id: {"chunks": [...], "full_content": "..."}}
|
||||||
|
streaming_buffer: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
|
||||||
|
|
||||||
|
|
||||||
class BaseNode(ABC):
|
class BaseNode(ABC):
|
||||||
@@ -201,23 +206,25 @@ class BaseNode(ABC):
|
|||||||
return self._wrap_error(str(e), elapsed_time, state)
|
return self._wrap_error(str(e), elapsed_time, state)
|
||||||
|
|
||||||
async def run_stream(self, state: WorkflowState):
|
async def run_stream(self, state: WorkflowState):
|
||||||
"""执行节点(带错误处理和输出包装,流式)
|
"""Execute node with error handling and output wrapping (streaming)
|
||||||
|
|
||||||
这个方法由 Executor 调用,负责:
|
This method is called by the Executor and is responsible for:
|
||||||
1. 时间统计
|
1. Time tracking
|
||||||
2. 调用节点的 execute_stream() 方法
|
2. Calling the node's execute_stream() method
|
||||||
3. 将业务数据包装成标准输出格式
|
3. Using LangGraph's stream writer to send chunks
|
||||||
4. 错误处理
|
4. Updating streaming buffer in state for downstream nodes
|
||||||
|
5. Wrapping business data into standard output format
|
||||||
|
6. Error handling
|
||||||
|
|
||||||
注意:在流式模式下,我们需要:
|
Special handling for End nodes:
|
||||||
- yield 中间的 chunk 事件(用于实时显示)
|
- End nodes don't send chunks via writer (prefix and LLM content already sent)
|
||||||
- 最后 yield 一个包含 state 更新的字典(LangGraph 会合并到 state)
|
- End nodes only yield suffix for final result assembly
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state: 工作流状态
|
state: Workflow state
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
标准化的流式事件和最终的 state 更新
|
State updates with streaming buffer and final result
|
||||||
"""
|
"""
|
||||||
import time
|
import time
|
||||||
|
|
||||||
@@ -226,63 +233,102 @@ class BaseNode(ABC):
|
|||||||
try:
|
try:
|
||||||
timeout = self.get_timeout()
|
timeout = self.get_timeout()
|
||||||
|
|
||||||
# 累积完整结果(用于最后的包装)
|
# Get LangGraph's stream writer for sending custom data
|
||||||
|
writer = get_stream_writer()
|
||||||
|
|
||||||
|
# Check if this is an End node
|
||||||
|
# End nodes CAN send chunks (for suffix), but only after LLM content
|
||||||
|
is_end_node = self.node_type == "end"
|
||||||
|
|
||||||
|
# Accumulate complete result (for final wrapping)
|
||||||
chunks = []
|
chunks = []
|
||||||
final_result = None
|
final_result = None
|
||||||
|
chunk_count = 0
|
||||||
|
|
||||||
# 使用异步生成器包装,支持超时
|
# Stream chunks in real-time
|
||||||
async def stream_with_timeout():
|
loop_start = asyncio.get_event_loop().time()
|
||||||
nonlocal final_result
|
|
||||||
loop_start = asyncio.get_event_loop().time()
|
async for item in self.execute_stream(state):
|
||||||
|
# Check timeout
|
||||||
|
if asyncio.get_event_loop().time() - loop_start > timeout:
|
||||||
|
raise TimeoutError()
|
||||||
|
|
||||||
async for item in self.execute_stream(state):
|
# Check if it's a completion marker
|
||||||
# 检查超时
|
if isinstance(item, dict) and item.get("__final__"):
|
||||||
if asyncio.get_event_loop().time() - loop_start > timeout:
|
final_result = item["result"]
|
||||||
raise TimeoutError()
|
elif isinstance(item, str):
|
||||||
|
# String is a chunk
|
||||||
|
chunk_count += 1
|
||||||
|
chunks.append(item)
|
||||||
|
full_content = "".join(chunks)
|
||||||
|
|
||||||
# 检查是否是完成标记
|
# Send chunks for all nodes (including End nodes for suffix)
|
||||||
if isinstance(item, dict) and item.get("__final__"):
|
logger.debug(f"节点 {self.node_id} 发送 chunk #{chunk_count}: {item[:50]}...")
|
||||||
final_result = item["result"]
|
|
||||||
elif isinstance(item, str):
|
# 1. Send via stream writer (for real-time client updates)
|
||||||
# 字符串是 chunk
|
writer({
|
||||||
# print("="*50)
|
"node_id": self.node_id,
|
||||||
# print(item)
|
"chunk": item,
|
||||||
# print("-"*50)
|
"full_content": full_content,
|
||||||
chunks.append(item)
|
"chunk_index": chunk_count
|
||||||
|
})
|
||||||
|
|
||||||
|
# 2. Update streaming buffer in state (for downstream nodes)
|
||||||
|
# Only non-End nodes need streaming buffer
|
||||||
|
if not is_end_node:
|
||||||
yield {
|
yield {
|
||||||
"type": "chunk",
|
"streaming_buffer": {
|
||||||
"node_id": self.node_id,
|
self.node_id: {
|
||||||
"content": item,
|
"full_content": full_content,
|
||||||
"full_content": "".join(chunks)
|
"chunk_count": chunk_count,
|
||||||
|
"is_complete": False
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
# 其他类型也当作 chunk 处理
|
# Other types are also treated as chunks
|
||||||
chunks.append(str(item))
|
chunk_count += 1
|
||||||
|
chunk_str = str(item)
|
||||||
|
chunks.append(chunk_str)
|
||||||
|
full_content = "".join(chunks)
|
||||||
|
|
||||||
|
# Send chunks for all nodes
|
||||||
|
writer({
|
||||||
|
"node_id": self.node_id,
|
||||||
|
"chunk": chunk_str,
|
||||||
|
"full_content": full_content,
|
||||||
|
"chunk_index": chunk_count
|
||||||
|
})
|
||||||
|
|
||||||
|
# Only non-End nodes need streaming buffer
|
||||||
|
if not is_end_node:
|
||||||
yield {
|
yield {
|
||||||
"type": "chunk",
|
"streaming_buffer": {
|
||||||
"node_id": self.node_id,
|
self.node_id: {
|
||||||
"content": str(item),
|
"full_content": full_content,
|
||||||
"full_content": "".join(chunks)
|
"chunk_count": chunk_count,
|
||||||
|
"is_complete": False
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async for chunk_event in stream_with_timeout():
|
|
||||||
yield chunk_event
|
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
# 提取处理后的输出(调用子类的 _extract_output)
|
logger.info(f"节点 {self.node_id} 流式执行完成,耗时: {elapsed_time:.2f}s, chunks: {chunk_count}")
|
||||||
|
|
||||||
|
# Extract processed output (call subclass's _extract_output)
|
||||||
extracted_output = self._extract_output(final_result)
|
extracted_output = self._extract_output(final_result)
|
||||||
|
|
||||||
# 包装最终结果
|
# Wrap final result
|
||||||
final_output = self._wrap_output(final_result, elapsed_time, state)
|
final_output = self._wrap_output(final_result, elapsed_time, state)
|
||||||
|
|
||||||
# 将提取后的输出存储到运行时变量中(供后续节点快速访问)
|
# Store extracted output in runtime variables (for quick access by subsequent nodes)
|
||||||
if isinstance(extracted_output, dict):
|
if isinstance(extracted_output, dict):
|
||||||
runtime_var = extracted_output
|
runtime_var = extracted_output
|
||||||
else:
|
else:
|
||||||
runtime_var = {"output": extracted_output}
|
runtime_var = {"output": extracted_output}
|
||||||
|
|
||||||
# 构建完整的 state 更新(包含 node_outputs 和 runtime_vars)
|
# Build complete state update (including node_outputs, runtime_vars, and final streaming buffer)
|
||||||
state_update = {
|
state_update = {
|
||||||
**final_output,
|
**final_output,
|
||||||
"runtime_vars": {
|
"runtime_vars": {
|
||||||
@@ -290,13 +336,24 @@ class BaseNode(ABC):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
# 最后 yield 纯粹的 state 更新(LangGraph 会合并到 state 中)
|
# Add streaming buffer for non-End nodes
|
||||||
|
if not is_end_node:
|
||||||
|
state_update["streaming_buffer"] = {
|
||||||
|
self.node_id: {
|
||||||
|
"full_content": "".join(chunks),
|
||||||
|
"chunk_count": chunk_count,
|
||||||
|
"is_complete": True # Mark as complete
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Finally yield state update
|
||||||
|
# LangGraph will merge this into state
|
||||||
yield state_update
|
yield state_update
|
||||||
|
|
||||||
except TimeoutError:
|
except TimeoutError:
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
logger.error(f"节点 {self.node_id} 执行超时({timeout}秒)")
|
logger.error(f"节点 {self.node_id} 执行超时 ({timeout}s)")
|
||||||
error_output = self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state)
|
error_output = self._wrap_error(f"节点执行超时 ({timeout}s)", elapsed_time, state)
|
||||||
yield error_output
|
yield error_output
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ End 节点实现
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
|
import asyncio
|
||||||
|
|
||||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||||
|
|
||||||
@@ -15,6 +17,7 @@ class EndNode(BaseNode):
|
|||||||
"""End 节点
|
"""End 节点
|
||||||
|
|
||||||
工作流的结束节点,根据配置的模板输出最终结果。
|
工作流的结束节点,根据配置的模板输出最终结果。
|
||||||
|
支持实时流式输出:如果模板引用了上游节点的输出,会实时监听其流式缓冲区。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState) -> str:
|
async def execute(self, state: WorkflowState) -> str:
|
||||||
@@ -45,42 +48,209 @@ class EndNode(BaseNode):
|
|||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
def _extract_referenced_nodes(self, template: str) -> list[str]:
|
||||||
|
"""从模板中提取引用的节点 ID
|
||||||
|
|
||||||
|
例如:'结果:{{llm_qa.output}}' -> ['llm_qa']
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template: 模板字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
引用的节点 ID 列表
|
||||||
|
"""
|
||||||
|
# 匹配 {{node_id.xxx}} 格式
|
||||||
|
pattern = r'\{\{([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\}\}'
|
||||||
|
matches = re.findall(pattern, template)
|
||||||
|
return list(set(matches)) # 去重
|
||||||
|
|
||||||
|
def _parse_template_parts(self, template: str, state: WorkflowState) -> list[dict]:
|
||||||
|
"""解析模板,分离静态文本和动态引用
|
||||||
|
|
||||||
|
例如:'你好 {{llm.output}}, 这是后缀'
|
||||||
|
返回:[
|
||||||
|
{"type": "static", "content": "你好 "},
|
||||||
|
{"type": "dynamic", "node_id": "llm", "field": "output"},
|
||||||
|
{"type": "static", "content": ", 这是后缀"}
|
||||||
|
]
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template: 模板字符串
|
||||||
|
state: 工作流状态
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
模板部分列表
|
||||||
|
"""
|
||||||
|
import re
|
||||||
|
|
||||||
|
parts = []
|
||||||
|
last_end = 0
|
||||||
|
|
||||||
|
# 匹配 {{xxx}} 或 {{ xxx }} 格式(支持空格)
|
||||||
|
pattern = r'\{\{\s*([^}]+?)\s*\}\}'
|
||||||
|
|
||||||
|
for match in re.finditer(pattern, template):
|
||||||
|
start, end = match.span()
|
||||||
|
|
||||||
|
# 添加前面的静态文本
|
||||||
|
if start > last_end:
|
||||||
|
static_text = template[last_end:start]
|
||||||
|
if static_text:
|
||||||
|
parts.append({"type": "static", "content": static_text})
|
||||||
|
|
||||||
|
# 解析动态引用
|
||||||
|
ref = match.group(1).strip()
|
||||||
|
|
||||||
|
# 检查是否是节点引用(如 llm.output 或 llm_qa.output)
|
||||||
|
if '.' in ref:
|
||||||
|
node_id, field = ref.split('.', 1)
|
||||||
|
parts.append({
|
||||||
|
"type": "dynamic",
|
||||||
|
"node_id": node_id,
|
||||||
|
"field": field,
|
||||||
|
"raw": ref
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
# 其他引用(如 {{var.xxx}}),当作静态处理
|
||||||
|
# 直接渲染这部分
|
||||||
|
rendered = self._render_template(f"{{{{{ref}}}}}", state)
|
||||||
|
parts.append({"type": "static", "content": rendered})
|
||||||
|
|
||||||
|
last_end = end
|
||||||
|
|
||||||
|
# 添加最后的静态文本
|
||||||
|
if last_end < len(template):
|
||||||
|
static_text = template[last_end:]
|
||||||
|
if static_text:
|
||||||
|
parts.append({"type": "static", "content": static_text})
|
||||||
|
|
||||||
|
return parts
|
||||||
|
|
||||||
async def execute_stream(self, state: WorkflowState):
|
async def execute_stream(self, state: WorkflowState):
|
||||||
"""流式执行 end 节点业务逻辑
|
"""流式执行 end 节点业务逻辑
|
||||||
|
|
||||||
当 end 节点前面是 LLM 节点时,流式输出其内容。
|
智能输出策略:
|
||||||
|
1. 检测模板中是否引用了直接上游节点
|
||||||
|
2. 如果引用了,只输出该引用**之后**的部分(后缀)
|
||||||
|
3. 前缀和引用内容已经在上游节点流式输出时发送了
|
||||||
|
|
||||||
|
示例:'{{start.test}}hahaha {{ llm_qa.output }} lalalalala a'
|
||||||
|
- 直接上游节点是 llm_qa
|
||||||
|
- 前缀 '{{start.test}}hahaha ' 已在 LLM 节点流式输出前发送
|
||||||
|
- LLM 内容在 LLM 节点流式输出
|
||||||
|
- End 节点只输出 ' lalalalala a'(后缀,一次性输出)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state: 工作流状态
|
state: 工作流状态
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
文本片段(chunk)或完成标记
|
完成标记
|
||||||
"""
|
"""
|
||||||
logger.info(f"节点 {self.node_id} (End) 开始执行(流式)")
|
logger.info(f"节点 {self.node_id} (End) 开始执行(流式)")
|
||||||
|
|
||||||
# 获取配置的输出模板
|
# 获取配置的输出模板
|
||||||
output_template = self.config.get("output")
|
output_template = self.config.get("output")
|
||||||
|
|
||||||
# 如果配置了输出模板,使用模板渲染
|
if not output_template:
|
||||||
if output_template:
|
|
||||||
output = self._render_template(output_template, state)
|
|
||||||
|
|
||||||
# 检查输出中是否包含节点引用(如 {{llm_node.output}})
|
|
||||||
# 如果包含,则逐字符流式输出
|
|
||||||
if output:
|
|
||||||
# 逐字符流式输出
|
|
||||||
for char in output:
|
|
||||||
yield char
|
|
||||||
else:
|
|
||||||
output = "工作流已完成"
|
output = "工作流已完成"
|
||||||
for char in output:
|
yield {"__final__": True, "result": output}
|
||||||
yield char
|
return
|
||||||
|
|
||||||
# 统计信息(用于日志)
|
# 找到直接上游节点
|
||||||
|
direct_upstream_nodes = []
|
||||||
|
for edge in self.workflow_config.get("edges", []):
|
||||||
|
if edge.get("target") == self.node_id:
|
||||||
|
source_node_id = edge.get("source")
|
||||||
|
direct_upstream_nodes.append(source_node_id)
|
||||||
|
|
||||||
|
logger.info(f"节点 {self.node_id} 的直接上游节点: {direct_upstream_nodes}")
|
||||||
|
|
||||||
|
# 解析模板部分
|
||||||
|
parts = self._parse_template_parts(output_template, state)
|
||||||
|
logger.info(f"节点 {self.node_id} 解析模板,共 {len(parts)} 个部分")
|
||||||
|
|
||||||
|
# 找到第一个引用直接上游节点的动态引用
|
||||||
|
upstream_ref_index = None
|
||||||
|
for i, part in enumerate(parts):
|
||||||
|
if part["type"] == "dynamic" and part["node_id"] in direct_upstream_nodes:
|
||||||
|
upstream_ref_index = i
|
||||||
|
logger.info(f"节点 {self.node_id} 找到直接上游节点 {part['node_id']} 的引用,索引: {i}")
|
||||||
|
break
|
||||||
|
|
||||||
|
if upstream_ref_index is None:
|
||||||
|
# 没有引用直接上游节点,正常输出(渲染完整模板)
|
||||||
|
output = self._render_template(output_template, state)
|
||||||
|
logger.info(f"节点 {self.node_id} 没有引用直接上游节点,输出完整内容")
|
||||||
|
yield {"__final__": True, "result": output}
|
||||||
|
return
|
||||||
|
|
||||||
|
# 有引用直接上游节点,只输出该引用之后的部分(后缀)
|
||||||
|
logger.info(f"节点 {self.node_id} 检测到直接上游节点引用,只输出后缀部分(从索引 {upstream_ref_index + 1} 开始)")
|
||||||
|
|
||||||
|
# 收集后缀部分
|
||||||
|
suffix_parts = []
|
||||||
|
for i in range(upstream_ref_index + 1, len(parts)):
|
||||||
|
part = parts[i]
|
||||||
|
|
||||||
|
if part["type"] == "static":
|
||||||
|
# 静态文本
|
||||||
|
suffix_parts.append(part["content"])
|
||||||
|
|
||||||
|
elif part["type"] == "dynamic":
|
||||||
|
# 其他动态引用(如果有多个引用)
|
||||||
|
node_id = part["node_id"]
|
||||||
|
field = part["field"]
|
||||||
|
|
||||||
|
# 从 streaming_buffer 或 node_outputs 读取
|
||||||
|
streaming_buffer = state.get("streaming_buffer", {})
|
||||||
|
if node_id in streaming_buffer:
|
||||||
|
buffer_data = streaming_buffer[node_id]
|
||||||
|
content = buffer_data.get("full_content", "")
|
||||||
|
else:
|
||||||
|
node_outputs = state.get("node_outputs", {})
|
||||||
|
runtime_vars = state.get("runtime_vars", {})
|
||||||
|
|
||||||
|
content = ""
|
||||||
|
if node_id in node_outputs:
|
||||||
|
node_output = node_outputs[node_id]
|
||||||
|
if isinstance(node_output, dict):
|
||||||
|
content = str(node_output.get(field, ""))
|
||||||
|
elif node_id in runtime_vars:
|
||||||
|
runtime_var = runtime_vars[node_id]
|
||||||
|
if isinstance(runtime_var, dict):
|
||||||
|
content = str(runtime_var.get(field, ""))
|
||||||
|
|
||||||
|
suffix_parts.append(content)
|
||||||
|
|
||||||
|
# 拼接后缀
|
||||||
|
suffix = "".join(suffix_parts)
|
||||||
|
|
||||||
|
# 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀)
|
||||||
|
full_output = self._render_template(output_template, state)
|
||||||
|
|
||||||
|
if suffix:
|
||||||
|
logger.info(f"节点 {self.node_id} 输出后缀: '{suffix[:50]}...' (长度: {len(suffix)})")
|
||||||
|
# 一次性输出后缀(作为单个 chunk)
|
||||||
|
# 注意:不要直接 yield 字符串,因为 base_node 会逐字符处理
|
||||||
|
# 而是通过 writer 直接发送
|
||||||
|
from langgraph.config import get_stream_writer
|
||||||
|
writer = get_stream_writer()
|
||||||
|
writer({
|
||||||
|
"node_id": self.node_id,
|
||||||
|
"chunk": suffix,
|
||||||
|
"full_content": full_output, # full_content 是完整的渲染结果(前缀+LLM+后缀)
|
||||||
|
"chunk_index": 1,
|
||||||
|
"is_suffix": True
|
||||||
|
})
|
||||||
|
logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀,full_content 长度: {len(full_output)}")
|
||||||
|
else:
|
||||||
|
logger.info(f"节点 {self.node_id} 没有后缀需要输出")
|
||||||
|
|
||||||
|
# 统计信息
|
||||||
node_outputs = state.get("node_outputs", {})
|
node_outputs = state.get("node_outputs", {})
|
||||||
total_nodes = len(node_outputs)
|
total_nodes = len(node_outputs)
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} (End) 执行完成(流式),共执行 {total_nodes} 个节点")
|
logger.info(f"节点 {self.node_id} (End) 执行完成(流式),共执行了 {total_nodes} 个节点")
|
||||||
|
|
||||||
# yield 完成标记
|
# yield 完成标记(包含完整输出)
|
||||||
yield {"__final__": True, "result": output}
|
yield {"__final__": True, "result": full_output}
|
||||||
|
|||||||
@@ -213,18 +213,44 @@ class LLMNode(BaseNode):
|
|||||||
Yields:
|
Yields:
|
||||||
文本片段(chunk)或完成标记
|
文本片段(chunk)或完成标记
|
||||||
"""
|
"""
|
||||||
|
from langgraph.config import get_stream_writer
|
||||||
|
|
||||||
llm, prompt_or_messages = self._prepare_llm(state, True)
|
llm, prompt_or_messages = self._prepare_llm(state, True)
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
|
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
|
||||||
logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
|
logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
|
||||||
|
|
||||||
|
# 检查是否有注入的 End 节点前缀配置
|
||||||
|
writer = get_stream_writer()
|
||||||
|
end_prefix = getattr(self, '_end_node_prefix', None)
|
||||||
|
|
||||||
|
logger.info(f"[LLM前缀] 节点 {self.node_id} 检查前缀配置: {end_prefix is not None}")
|
||||||
|
if end_prefix:
|
||||||
|
logger.info(f"[LLM前缀] 前缀内容: '{end_prefix}'")
|
||||||
|
|
||||||
|
if end_prefix:
|
||||||
|
# 渲染前缀(可能包含其他变量)
|
||||||
|
try:
|
||||||
|
rendered_prefix = self._render_template(end_prefix, state)
|
||||||
|
logger.info(f"节点 {self.node_id} 提前发送 End 节点前缀: '{rendered_prefix[:50]}...'")
|
||||||
|
|
||||||
|
# 提前发送 End 节点的前缀
|
||||||
|
writer({
|
||||||
|
"node_id": "end", # 标记为 end 节点的输出
|
||||||
|
"chunk": rendered_prefix,
|
||||||
|
"full_content": rendered_prefix,
|
||||||
|
"chunk_index": 0,
|
||||||
|
"is_prefix": True # 标记这是前缀
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"渲染/发送 End 节点前缀失败: {e}")
|
||||||
|
|
||||||
# 累积完整响应
|
# 累积完整响应
|
||||||
full_response = ""
|
full_response = ""
|
||||||
last_chunk = None
|
last_chunk = None
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
|
|
||||||
# 调用 LLM(流式,支持字符串或消息列表)
|
# 调用 LLM(流式,支持字符串或消息列表)
|
||||||
# 注意:astream 方法本身就是流式的,不需要额外配置
|
|
||||||
async for chunk in llm.astream(prompt_or_messages):
|
async for chunk in llm.astream(prompt_or_messages):
|
||||||
# 提取内容
|
# 提取内容
|
||||||
if hasattr(chunk, 'content'):
|
if hasattr(chunk, 'content'):
|
||||||
@@ -238,9 +264,8 @@ class LLMNode(BaseNode):
|
|||||||
last_chunk = chunk
|
last_chunk = chunk
|
||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
|
|
||||||
# logger.debug(f"节点 {self.node_id} LLM chunk #{chunk_count}: {content[:50]}...")
|
|
||||||
# 流式返回每个文本片段
|
# 流式返回每个文本片段
|
||||||
yield content #AIMessage(content=content)
|
yield content
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
|
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user