[add] workflow llm & end logic

This commit is contained in:
Mark
2025-12-20 17:25:47 +08:00
parent d8fcea8564
commit 36b36b729b
4 changed files with 430 additions and 103 deletions

View File

@@ -87,11 +87,75 @@ class WorkflowExecutor:
"workspace_id": self.workspace_id,
"user_id": self.user_id,
"error": None,
"error_node": None
"error_node": None,
"streaming_buffer": {} # 流式缓冲区
}
def _analyze_end_node_prefixes(self) -> dict[str, str]:
"""分析 End 节点的前缀配置
检查每个 End 节点的模板,找到直接上游节点的引用,
提取该引用之前的前缀部分。
Returns:
字典:{上游节点ID: End节点前缀}
"""
import re
prefixes = {}
# 找到所有 End 节点
end_nodes = [node for node in self.nodes if node.get("type") == "end"]
logger.info(f"[前缀分析] 找到 {len(end_nodes)} 个 End 节点")
for end_node in end_nodes:
end_node_id = end_node.get("id")
output_template = end_node.get("config", {}).get("output")
logger.info(f"[前缀分析] End 节点 {end_node_id} 模板: {output_template}")
if not output_template:
continue
# 找到所有直接连接到 End 节点的上游节点
direct_upstream_nodes = []
for edge in self.edges:
if edge.get("target") == end_node_id:
source_node_id = edge.get("source")
direct_upstream_nodes.append(source_node_id)
logger.info(f"[前缀分析] End 节点的直接上游节点: {direct_upstream_nodes}")
# 查找模板中引用了哪些节点
# 匹配 {{node_id.xxx}} 或 {{ node_id.xxx }} 格式(支持空格)
pattern = r'\{\{\s*([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\s*\}\}'
matches = list(re.finditer(pattern, output_template))
logger.info(f"[前缀分析] 模板中找到 {len(matches)} 个节点引用")
# 找到第一个直接上游节点的引用
for match in matches:
referenced_node_id = match.group(1)
logger.info(f"[前缀分析] 检查引用: {referenced_node_id}")
if referenced_node_id in direct_upstream_nodes:
# 这是直接上游节点的引用,提取前缀
prefix = output_template[:match.start()]
logger.info(f"[前缀分析] ✅ 找到直接上游节点 {referenced_node_id} 的引用,前缀: '{prefix}'")
if prefix:
prefixes[referenced_node_id] = prefix
logger.info(f"✅ [前缀分析] 为节点 {referenced_node_id} 配置前缀: '{prefix[:50]}...'")
# 只处理第一个直接上游节点的引用
break
logger.info(f"[前缀分析] 最终配置: {prefixes}")
return prefixes
def build_graph(self,stream=False) -> CompiledStateGraph:
"""构建 LangGraph
@@ -99,6 +163,9 @@ class WorkflowExecutor:
编译后的状态图
"""
logger.info(f"开始构建工作流图: execution_id={self.execution_id}")
# 分析 End 节点的前缀配置
end_prefixes = self._analyze_end_node_prefixes() if stream else {}
# 1. 创建状态图
workflow = StateGraph(WorkflowState)
@@ -120,6 +187,12 @@ class WorkflowExecutor:
# 创建节点实例(现在 start 和 end 也会被创建)
node_instance = NodeFactory.create_node(node, self.workflow_config)
if node_instance:
# 如果是流式模式,且节点有 End 前缀配置,注入配置
if stream and node_id in end_prefixes:
# 将 End 前缀配置注入到节点实例
node_instance._end_node_prefix = end_prefixes[node_id]
logger.info(f"为节点 {node_id} 注入 End 前缀配置")
# 包装节点的 run 方法
# 使用函数工厂避免闭包问题
if stream:
@@ -309,29 +382,48 @@ class WorkflowExecutor:
# 2. 初始化状态(自动注入系统变量)
initial_state = self._prepare_initial_state(input_data)
# 3. 执行工作流
# 3. Execute workflow
try:
chunk_count = 0
async for event in graph.astream(
initial_state,
stream_mode=["updates", "debug"],
stream_mode=["updates", "debug", "custom"], # Use updates + debug + custom mode
):
mode, data = event
# event should be a tuple: (mode, data)
# But let's handle both cases
if isinstance(event, tuple) and len(event) == 2:
mode, data = event
else:
# Unexpected format, log and skip
logger.warning(f"[STREAM] Unexpected event format: {type(event)}, value: {event}")
continue
if mode == "debug":
# 处理调试信息(节点执行状态)
if mode == "custom":
# Handle custom streaming events (chunks from nodes via stream writer)
chunk_count += 1
logger.info(f"[CUSTOM] ✅ 收到 chunk #{chunk_count} from {data.get('node_id')}")
yield {
"type": "node_chunk",
"node_id": data.get("node_id"),
"chunk": data.get("chunk"),
"full_content": data.get("full_content"),
"chunk_index": data.get("chunk_index")
}
elif mode == "debug":
# Handle debug information (node execution status)
event_type = data.get("type")
payload = data.get("payload", {})
node_name = payload.get("name")
if event_type == "task":
# 节点开始执行
# Node starts execution
inputv = payload.get("input", {})
variables = inputv.get("variables", {})
variables_sys = variables.get("sys", {})
conversation_id = variables_sys.get("conversation_id")
execution_id = variables_sys.get("execution_id")
logger.info(f"[DEBUG] 节点开始执行: {node_name}")
logger.info(f"[DEBUG] Node starts execution: {node_name}")
yield {
"type": "node_start",
"node_id": node_name,
@@ -340,16 +432,16 @@ class WorkflowExecutor:
"timestamp": data.get("timestamp")
}
elif event_type == "task_result":
# 节点执行完成
# Node execution completed
result = payload.get("result", {})
inputv = result.get("input", {})
variables = inputv.get("variables", {})
variables_sys = variables.get("sys", {})
conversation_id = variables_sys.get("conversation_id")
execution_id = variables_sys.get("execution_id")
logger.info(f"[DEBUG] 节点执行完成: {node_name}")
logger.info(f"[DEBUG] Node execution completed: {node_name}")
yield {
"type": "node_end",
"type": "node_complete",
"node_id": node_name,
"conversation_id": conversation_id,
"execution_id": execution_id,
@@ -357,27 +449,10 @@ class WorkflowExecutor:
}
elif mode == "updates":
# 处理 state 更新
# data 是一个字典key 是节点 IDvalue 是 state 更新或 chunk
print("="*50)
print(data)
print("-"*50)
for node_id, update in data.items():
if isinstance(update, dict) and update.get("type") == "chunk":
# 这是流式 chunk转发给客户端
chunk_count += 1
logger.debug(f"[UPDATE] 收到 chunk #{chunk_count} from {node_id}: {update.get('content')[:50]}...")
yield {
"type": "node_chunk",
"node_id": update.get("node_id"),
"chunk": update.get("content"),
"full_content": update.get("full_content")
}
else:
logger.debug(f"[UPDATE] 收到 state 更新 from {node_id}")
# 其他情况state 更新)会被 LangGraph 自动合并到 state
# Handle state updates
logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())}")
logger.info(f"工作流执行完成(流式),总 chunks: {chunk_count}")
logger.info(f"Workflow execution completed (streaming), total chunks: {chunk_count}")
except Exception as e:
# 计算耗时(即使失败也记录)

View File

@@ -10,6 +10,7 @@ from abc import ABC, abstractmethod
from typing import Any, TypedDict, Annotated
from operator import add
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage
from langgraph.config import get_stream_writer
from app.core.workflow.variable_pool import VariablePool
@@ -43,6 +44,10 @@ class WorkflowState(TypedDict):
# 错误信息(用于错误边)
error: str | None
error_node: str | None
# 流式缓冲区(存储节点的实时流式输出)
# 格式:{node_id: {"chunks": [...], "full_content": "..."}}
streaming_buffer: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
class BaseNode(ABC):
@@ -201,23 +206,25 @@ class BaseNode(ABC):
return self._wrap_error(str(e), elapsed_time, state)
async def run_stream(self, state: WorkflowState):
"""执行节点(带错误处理和输出包装,流式)
"""Execute node with error handling and output wrapping (streaming)
这个方法由 Executor 调用,负责:
1. 时间统计
2. 调用节点的 execute_stream() 方法
3. 将业务数据包装成标准输出格式
4. 错误处理
This method is called by the Executor and is responsible for:
1. Time tracking
2. Calling the node's execute_stream() method
3. Using LangGraph's stream writer to send chunks
4. Updating streaming buffer in state for downstream nodes
5. Wrapping business data into standard output format
6. Error handling
注意:在流式模式下,我们需要:
- yield 中间的 chunk 事件(用于实时显示)
- 最后 yield 一个包含 state 更新的字典LangGraph 会合并到 state
Special handling for End nodes:
- End nodes don't send chunks via writer (prefix and LLM content already sent)
- End nodes only yield suffix for final result assembly
Args:
state: 工作流状态
state: Workflow state
Yields:
标准化的流式事件和最终的 state 更新
State updates with streaming buffer and final result
"""
import time
@@ -226,63 +233,102 @@ class BaseNode(ABC):
try:
timeout = self.get_timeout()
# 累积完整结果(用于最后的包装)
# Get LangGraph's stream writer for sending custom data
writer = get_stream_writer()
# Check if this is an End node
# End nodes CAN send chunks (for suffix), but only after LLM content
is_end_node = self.node_type == "end"
# Accumulate complete result (for final wrapping)
chunks = []
final_result = None
chunk_count = 0
# 使用异步生成器包装,支持超时
async def stream_with_timeout():
nonlocal final_result
loop_start = asyncio.get_event_loop().time()
# Stream chunks in real-time
loop_start = asyncio.get_event_loop().time()
async for item in self.execute_stream(state):
# Check timeout
if asyncio.get_event_loop().time() - loop_start > timeout:
raise TimeoutError()
async for item in self.execute_stream(state):
# 检查超时
if asyncio.get_event_loop().time() - loop_start > timeout:
raise TimeoutError()
# Check if it's a completion marker
if isinstance(item, dict) and item.get("__final__"):
final_result = item["result"]
elif isinstance(item, str):
# String is a chunk
chunk_count += 1
chunks.append(item)
full_content = "".join(chunks)
# 检查是否是完成标记
if isinstance(item, dict) and item.get("__final__"):
final_result = item["result"]
elif isinstance(item, str):
# 字符串是 chunk
# print("="*50)
# print(item)
# print("-"*50)
chunks.append(item)
# Send chunks for all nodes (including End nodes for suffix)
logger.debug(f"节点 {self.node_id} 发送 chunk #{chunk_count}: {item[:50]}...")
# 1. Send via stream writer (for real-time client updates)
writer({
"node_id": self.node_id,
"chunk": item,
"full_content": full_content,
"chunk_index": chunk_count
})
# 2. Update streaming buffer in state (for downstream nodes)
# Only non-End nodes need streaming buffer
if not is_end_node:
yield {
"type": "chunk",
"node_id": self.node_id,
"content": item,
"full_content": "".join(chunks)
"streaming_buffer": {
self.node_id: {
"full_content": full_content,
"chunk_count": chunk_count,
"is_complete": False
}
}
}
else:
# 其他类型也当作 chunk 处理
chunks.append(str(item))
else:
# Other types are also treated as chunks
chunk_count += 1
chunk_str = str(item)
chunks.append(chunk_str)
full_content = "".join(chunks)
# Send chunks for all nodes
writer({
"node_id": self.node_id,
"chunk": chunk_str,
"full_content": full_content,
"chunk_index": chunk_count
})
# Only non-End nodes need streaming buffer
if not is_end_node:
yield {
"type": "chunk",
"node_id": self.node_id,
"content": str(item),
"full_content": "".join(chunks)
"streaming_buffer": {
self.node_id: {
"full_content": full_content,
"chunk_count": chunk_count,
"is_complete": False
}
}
}
async for chunk_event in stream_with_timeout():
yield chunk_event
elapsed_time = time.time() - start_time
# 提取处理后的输出(调用子类的 _extract_output
logger.info(f"节点 {self.node_id} 流式执行完成,耗时: {elapsed_time:.2f}s, chunks: {chunk_count}")
# Extract processed output (call subclass's _extract_output)
extracted_output = self._extract_output(final_result)
# 包装最终结果
# Wrap final result
final_output = self._wrap_output(final_result, elapsed_time, state)
# 将提取后的输出存储到运行时变量中(供后续节点快速访问)
# Store extracted output in runtime variables (for quick access by subsequent nodes)
if isinstance(extracted_output, dict):
runtime_var = extracted_output
else:
runtime_var = {"output": extracted_output}
# 构建完整的 state 更新(包含 node_outputs runtime_vars
# Build complete state update (including node_outputs, runtime_vars, and final streaming buffer)
state_update = {
**final_output,
"runtime_vars": {
@@ -290,13 +336,24 @@ class BaseNode(ABC):
}
}
# 最后 yield 纯粹的 state 更新LangGraph 会合并到 state 中)
# Add streaming buffer for non-End nodes
if not is_end_node:
state_update["streaming_buffer"] = {
self.node_id: {
"full_content": "".join(chunks),
"chunk_count": chunk_count,
"is_complete": True # Mark as complete
}
}
# Finally yield state update
# LangGraph will merge this into state
yield state_update
except TimeoutError:
elapsed_time = time.time() - start_time
logger.error(f"节点 {self.node_id} 执行超时{timeout}秒)")
error_output = self._wrap_error(f"节点执行超时{timeout}秒)", elapsed_time, state)
logger.error(f"节点 {self.node_id} 执行超时 ({timeout}s)")
error_output = self._wrap_error(f"节点执行超时 ({timeout}s)", elapsed_time, state)
yield error_output
except Exception as e:
elapsed_time = time.time() - start_time

View File

@@ -5,6 +5,8 @@ End 节点实现
"""
import logging
import re
import asyncio
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
@@ -15,6 +17,7 @@ class EndNode(BaseNode):
"""End 节点
工作流的结束节点,根据配置的模板输出最终结果。
支持实时流式输出:如果模板引用了上游节点的输出,会实时监听其流式缓冲区。
"""
async def execute(self, state: WorkflowState) -> str:
@@ -45,42 +48,209 @@ class EndNode(BaseNode):
return output
def _extract_referenced_nodes(self, template: str) -> list[str]:
"""从模板中提取引用的节点 ID
例如:'结果:{{llm_qa.output}}' -> ['llm_qa']
Args:
template: 模板字符串
Returns:
引用的节点 ID 列表
"""
# 匹配 {{node_id.xxx}} 格式
pattern = r'\{\{([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\}\}'
matches = re.findall(pattern, template)
return list(set(matches)) # 去重
def _parse_template_parts(self, template: str, state: WorkflowState) -> list[dict]:
"""解析模板,分离静态文本和动态引用
例如:'你好 {{llm.output}}, 这是后缀'
返回:[
{"type": "static", "content": "你好 "},
{"type": "dynamic", "node_id": "llm", "field": "output"},
{"type": "static", "content": ", 这是后缀"}
]
Args:
template: 模板字符串
state: 工作流状态
Returns:
模板部分列表
"""
import re
parts = []
last_end = 0
# 匹配 {{xxx}} 或 {{ xxx }} 格式(支持空格)
pattern = r'\{\{\s*([^}]+?)\s*\}\}'
for match in re.finditer(pattern, template):
start, end = match.span()
# 添加前面的静态文本
if start > last_end:
static_text = template[last_end:start]
if static_text:
parts.append({"type": "static", "content": static_text})
# 解析动态引用
ref = match.group(1).strip()
# 检查是否是节点引用(如 llm.output 或 llm_qa.output
if '.' in ref:
node_id, field = ref.split('.', 1)
parts.append({
"type": "dynamic",
"node_id": node_id,
"field": field,
"raw": ref
})
else:
# 其他引用(如 {{var.xxx}}),当作静态处理
# 直接渲染这部分
rendered = self._render_template(f"{{{{{ref}}}}}", state)
parts.append({"type": "static", "content": rendered})
last_end = end
# 添加最后的静态文本
if last_end < len(template):
static_text = template[last_end:]
if static_text:
parts.append({"type": "static", "content": static_text})
return parts
async def execute_stream(self, state: WorkflowState):
"""流式执行 end 节点业务逻辑
当 end 节点前面是 LLM 节点时,流式输出其内容。
智能输出策略:
1. 检测模板中是否引用了直接上游节点
2. 如果引用了,只输出该引用**之后**的部分(后缀)
3. 前缀和引用内容已经在上游节点流式输出时发送了
示例:'{{start.test}}hahaha {{ llm_qa.output }} lalalalala a'
- 直接上游节点是 llm_qa
- 前缀 '{{start.test}}hahaha ' 已在 LLM 节点流式输出前发送
- LLM 内容在 LLM 节点流式输出
- End 节点只输出 ' lalalalala a'(后缀,一次性输出)
Args:
state: 工作流状态
Yields:
文本片段chunk完成标记
完成标记
"""
logger.info(f"节点 {self.node_id} (End) 开始执行(流式)")
# 获取配置的输出模板
output_template = self.config.get("output")
# 如果配置了输出模板,使用模板渲染
if output_template:
output = self._render_template(output_template, state)
# 检查输出中是否包含节点引用(如 {{llm_node.output}}
# 如果包含,则逐字符流式输出
if output:
# 逐字符流式输出
for char in output:
yield char
else:
if not output_template:
output = "工作流已完成"
for char in output:
yield char
yield {"__final__": True, "result": output}
return
# 统计信息(用于日志)
# 找到直接上游节点
direct_upstream_nodes = []
for edge in self.workflow_config.get("edges", []):
if edge.get("target") == self.node_id:
source_node_id = edge.get("source")
direct_upstream_nodes.append(source_node_id)
logger.info(f"节点 {self.node_id} 的直接上游节点: {direct_upstream_nodes}")
# 解析模板部分
parts = self._parse_template_parts(output_template, state)
logger.info(f"节点 {self.node_id} 解析模板,共 {len(parts)} 个部分")
# 找到第一个引用直接上游节点的动态引用
upstream_ref_index = None
for i, part in enumerate(parts):
if part["type"] == "dynamic" and part["node_id"] in direct_upstream_nodes:
upstream_ref_index = i
logger.info(f"节点 {self.node_id} 找到直接上游节点 {part['node_id']} 的引用,索引: {i}")
break
if upstream_ref_index is None:
# 没有引用直接上游节点,正常输出(渲染完整模板)
output = self._render_template(output_template, state)
logger.info(f"节点 {self.node_id} 没有引用直接上游节点,输出完整内容")
yield {"__final__": True, "result": output}
return
# 有引用直接上游节点,只输出该引用之后的部分(后缀)
logger.info(f"节点 {self.node_id} 检测到直接上游节点引用,只输出后缀部分(从索引 {upstream_ref_index + 1} 开始)")
# 收集后缀部分
suffix_parts = []
for i in range(upstream_ref_index + 1, len(parts)):
part = parts[i]
if part["type"] == "static":
# 静态文本
suffix_parts.append(part["content"])
elif part["type"] == "dynamic":
# 其他动态引用(如果有多个引用)
node_id = part["node_id"]
field = part["field"]
# 从 streaming_buffer 或 node_outputs 读取
streaming_buffer = state.get("streaming_buffer", {})
if node_id in streaming_buffer:
buffer_data = streaming_buffer[node_id]
content = buffer_data.get("full_content", "")
else:
node_outputs = state.get("node_outputs", {})
runtime_vars = state.get("runtime_vars", {})
content = ""
if node_id in node_outputs:
node_output = node_outputs[node_id]
if isinstance(node_output, dict):
content = str(node_output.get(field, ""))
elif node_id in runtime_vars:
runtime_var = runtime_vars[node_id]
if isinstance(runtime_var, dict):
content = str(runtime_var.get(field, ""))
suffix_parts.append(content)
# 拼接后缀
suffix = "".join(suffix_parts)
# 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀)
full_output = self._render_template(output_template, state)
if suffix:
logger.info(f"节点 {self.node_id} 输出后缀: '{suffix[:50]}...' (长度: {len(suffix)})")
# 一次性输出后缀(作为单个 chunk
# 注意:不要直接 yield 字符串,因为 base_node 会逐字符处理
# 而是通过 writer 直接发送
from langgraph.config import get_stream_writer
writer = get_stream_writer()
writer({
"node_id": self.node_id,
"chunk": suffix,
"full_content": full_output, # full_content 是完整的渲染结果(前缀+LLM+后缀)
"chunk_index": 1,
"is_suffix": True
})
logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀full_content 长度: {len(full_output)}")
else:
logger.info(f"节点 {self.node_id} 没有后缀需要输出")
# 统计信息
node_outputs = state.get("node_outputs", {})
total_nodes = len(node_outputs)
logger.info(f"节点 {self.node_id} (End) 执行完成(流式),共执行 {total_nodes} 个节点")
logger.info(f"节点 {self.node_id} (End) 执行完成(流式),共执行 {total_nodes} 个节点")
# yield 完成标记
yield {"__final__": True, "result": output}
# yield 完成标记(包含完整输出)
yield {"__final__": True, "result": full_output}

View File

@@ -213,18 +213,44 @@ class LLMNode(BaseNode):
Yields:
文本片段chunk或完成标记
"""
from langgraph.config import get_stream_writer
llm, prompt_or_messages = self._prepare_llm(state, True)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
# 检查是否有注入的 End 节点前缀配置
writer = get_stream_writer()
end_prefix = getattr(self, '_end_node_prefix', None)
logger.info(f"[LLM前缀] 节点 {self.node_id} 检查前缀配置: {end_prefix is not None}")
if end_prefix:
logger.info(f"[LLM前缀] 前缀内容: '{end_prefix}'")
if end_prefix:
# 渲染前缀(可能包含其他变量)
try:
rendered_prefix = self._render_template(end_prefix, state)
logger.info(f"节点 {self.node_id} 提前发送 End 节点前缀: '{rendered_prefix[:50]}...'")
# 提前发送 End 节点的前缀
writer({
"node_id": "end", # 标记为 end 节点的输出
"chunk": rendered_prefix,
"full_content": rendered_prefix,
"chunk_index": 0,
"is_prefix": True # 标记这是前缀
})
except Exception as e:
logger.warning(f"渲染/发送 End 节点前缀失败: {e}")
# 累积完整响应
full_response = ""
last_chunk = None
chunk_count = 0
# 调用 LLM流式支持字符串或消息列表
# 注意astream 方法本身就是流式的,不需要额外配置
async for chunk in llm.astream(prompt_or_messages):
# 提取内容
if hasattr(chunk, 'content'):
@@ -238,9 +264,8 @@ class LLMNode(BaseNode):
last_chunk = chunk
chunk_count += 1
# logger.debug(f"节点 {self.node_id} LLM chunk #{chunk_count}: {content[:50]}...")
# 流式返回每个文本片段
yield content #AIMessage(content=content)
yield content
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")