Merge branch 'develop-mark' into develop

# Conflicts:
#	api/app/core/workflow/executor.py
#	api/app/services/workflow_service.py
This commit is contained in:
Mark
2025-12-20 17:51:49 +08:00
10 changed files with 915 additions and 223 deletions

View File

@@ -94,16 +94,90 @@ class WorkflowExecutor:
"workspace_id": self.workspace_id,
"user_id": self.user_id,
"error": None,
"error_node": None
"error_node": None,
"streaming_buffer": {} # 流式缓冲区
}
def build_graph(self) -> CompiledStateGraph:
def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]:
    """Analyze the output template of every End node.

    For each End node that has an ``output`` template, look for the first
    ``{{node_id.field}}`` reference whose ``node_id`` is the source of an
    edge targeting that End node, and take the template text that precedes
    the reference as that node's prefix.

    Returns:
        A tuple ``(prefixes, adjacent_and_referenced)`` where ``prefixes``
        maps an upstream node ID to the (non-empty) template prefix, and
        ``adjacent_and_referenced`` is the set of node IDs that are directly
        connected to an End node and referenced by its template.
    """
    import re

    # Matches {{node_id.xxx}} and {{ node_id.xxx }} (inner whitespace allowed).
    reference_pattern = r'\{\{\s*([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\s*\}\}'

    prefix_by_node = {}
    referenced_neighbors = set()

    end_nodes = [candidate for candidate in self.nodes if candidate.get("type") == "end"]
    logger.info(f"[前缀分析] 找到 {len(end_nodes)} 个 End 节点")

    for end_node in end_nodes:
        end_node_id = end_node.get("id")
        output_template = end_node.get("config", {}).get("output")
        logger.info(f"[前缀分析] End 节点 {end_node_id} 模板: {output_template}")

        if not output_template:
            continue

        # Sources of all edges that point at this End node.
        direct_upstream_nodes = [
            edge.get("source")
            for edge in self.edges
            if edge.get("target") == end_node_id
        ]
        logger.info(f"[前缀分析] End 节点的直接上游节点: {direct_upstream_nodes}")

        matches = list(re.finditer(reference_pattern, output_template))
        logger.info(f"[前缀分析] 模板中找到 {len(matches)} 个节点引用")

        for match in matches:
            referenced_node_id = match.group(1)
            logger.info(f"[前缀分析] 检查引用: {referenced_node_id}")

            if referenced_node_id not in direct_upstream_nodes:
                continue

            # Everything before the reference is the prefix.
            prefix = output_template[:match.start()]
            logger.info(f"[前缀分析] ✅ 找到直接上游节点 {referenced_node_id} 的引用,前缀: '{prefix}'")

            referenced_neighbors.add(referenced_node_id)

            if prefix:
                prefix_by_node[referenced_node_id] = prefix
                logger.info(f"✅ [前缀分析] 为节点 {referenced_node_id} 配置前缀: '{prefix[:50]}...'")

            # Only the first reference to a direct upstream node is handled.
            break

    logger.info(f"[前缀分析] 最终配置: {prefix_by_node}")
    logger.info(f"[前缀分析] 与 End 相邻且被引用的节点: {referenced_neighbors}")
    return prefix_by_node, referenced_neighbors
def build_graph(self,stream=False) -> CompiledStateGraph:
"""构建 LangGraph
Returns:
编译后的状态图
"""
logger.info(f"开始构建工作流图: execution_id={self.execution_id}")
# 分析 End 节点的前缀配置和相邻且被引用的节点
end_prefixes, adjacent_and_referenced = self._analyze_end_node_prefixes() if stream else ({}, set())
# 1. 创建状态图
workflow = StateGraph(WorkflowState)
@@ -143,16 +217,39 @@ class WorkflowExecutor:
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
if node_instance:
# 如果是流式模式,且节点有 End 前缀配置,注入配置
if stream and node_id in end_prefixes:
# 将 End 前缀配置注入到节点实例
node_instance._end_node_prefix = end_prefixes[node_id]
logger.info(f"为节点 {node_id} 注入 End 前缀配置")
# 如果是流式模式,标记节点是否与 End 相邻且被引用
if stream:
node_instance._is_adjacent_to_end = node_id in adjacent_and_referenced
if node_id in adjacent_and_referenced:
logger.info(f"节点 {node_id} 标记为与 End 相邻且被引用")
# 包装节点的 run 方法
# 使用函数工厂避免闭包问题
def make_node_func(inst):
async def node_func(state: WorkflowState):
return await inst.run(state)
if stream:
# 流式模式:创建 async generator 函数
# LangGraph 会收集所有 yield 的值,最后一个 yield 的字典会被合并到 state
def make_stream_func(inst):
async def node_func(state: WorkflowState):
# logger.debug(f"流式执行节点: {inst.node_id}, 支持流式: {inst.supports_streaming()}")
async for item in inst.run_stream(state):
yield item
return node_func
workflow.add_node(node_id, make_stream_func(node_instance))
else:
# 非流式模式:创建 async function
def make_func(inst):
async def node_func(state: WorkflowState):
return await inst.run(state)
return node_func
workflow.add_node(node_id, make_func(node_instance))
return node_func
workflow.add_node(node_id, make_node_func(node_instance))
logger.debug(f"添加节点: {node_id} (type={node_type})")
logger.debug(f"添加节点: {node_id} (type={node_type}, stream={stream})")
# 3. 添加边
# 从 START 连接到 start 节点
@@ -300,40 +397,143 @@ class WorkflowExecutor:
):
"""执行工作流(流式)
手动执行节点以支持细粒度的流式输出
- workflow_start: 工作流开始
- node_start: 节点开始执行
- node_chunk: LLM 节点的流式输出片段(逐 token
- node_complete: 节点执行完成
- workflow_complete: 工作流完成
使用多个 stream_mode 来获取
1. "updates" - 节点的 state 更新和流式 chunk
2. "debug" - 节点执行的详细信息(开始/完成时间)
3. "custom" - 自定义流式数据chunks
Args:
input_data: 输入数据
Yields:
流式事件
流式事件,格式:
{
"event": "workflow_start" | "workflow_end" | "node_start" | "node_end" | "node_chunk" | "message",
"data": {...}
}
"""
#
logger.info(f"开始执行工作流: execution_id={self.execution_id}")
logger.info(f"开始执行工作流(流式): execution_id={self.execution_id}")
# 记录开始时间
start_time = datetime.datetime.now()
# 发送 workflow_start 事件
yield {
"event": "workflow_start",
"data": {
"execution_id": self.execution_id,
"workspace_id": self.workspace_id,
"timestamp": start_time.isoformat()
}
}
# 1. 构建图
graph = self.build_graph()
graph = self.build_graph(True)
# 2. 初始化状态(自动注入系统变量)
initial_state = self._prepare_initial_state(input_data)
# 3. 执行工作流
# 3. Execute workflow
try:
async for chunk in graph.astream(
chunk_count = 0
final_state = None
async for event in graph.astream(
initial_state,
# subgraphs=True,
stream_mode="updates",
stream_mode=["updates", "debug", "custom"], # Use updates + debug + custom mode
):
# print(chunk)
yield chunk
# event should be a tuple: (mode, data)
# But let's handle both cases
if isinstance(event, tuple) and len(event) == 2:
mode, data = event
else:
# Unexpected format, log and skip
logger.warning(f"[STREAM] Unexpected event format: {type(event)}, value: {event}")
continue
if mode == "custom":
# Handle custom streaming events (chunks from nodes via stream writer)
chunk_count += 1
event_type = data.get("type", "node_chunk") # "message" or "node_chunk"
logger.info(f"[CUSTOM] ✅ 收到 {event_type} #{chunk_count} from {data.get('node_id')}")
yield {
"event": event_type, # "message" or "node_chunk"
"data": {
"node_id": data.get("node_id"),
"chunk": data.get("chunk"),
"full_content": data.get("full_content"),
"chunk_index": data.get("chunk_index"),
"is_prefix": data.get("is_prefix"),
"is_suffix": data.get("is_suffix")
}
}
elif mode == "debug":
# Handle debug information (node execution status)
event_type = data.get("type")
payload = data.get("payload", {})
node_name = payload.get("name")
if event_type == "task":
# Node starts execution
inputv = payload.get("input", {})
variables = inputv.get("variables", {})
variables_sys = variables.get("sys", {})
conversation_id = variables_sys.get("conversation_id")
execution_id = variables_sys.get("execution_id")
logger.info(f"[DEBUG] Node starts execution: {node_name}")
yield {
"event": "node_start",
"data": {
"node_id": node_name,
"conversation_id": conversation_id,
"execution_id": execution_id,
"timestamp": data.get("timestamp")
}
}
elif event_type == "task_result":
# Node execution completed
result = payload.get("result", {})
inputv = result.get("input", {})
variables = inputv.get("variables", {})
variables_sys = variables.get("sys", {})
conversation_id = variables_sys.get("conversation_id")
execution_id = variables_sys.get("execution_id")
logger.info(f"[DEBUG] Node execution completed: {node_name}")
yield {
"event": "node_end",
"data": {
"node_id": node_name,
"conversation_id": conversation_id,
"execution_id": execution_id,
"timestamp": data.get("timestamp")
}
}
elif mode == "updates":
# Handle state updates - store final state
logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())}")
final_state = data
# 计算耗时
end_time = datetime.datetime.now()
elapsed_time = (end_time - start_time).total_seconds()
logger.info(f"Workflow execution completed (streaming), total chunks: {chunk_count}, elapsed: {elapsed_time:.2f}s")
# 发送 workflow_end 事件
yield {
"event": "workflow_end",
"data": {
"execution_id": self.execution_id,
"status": "completed",
"elapsed_time": elapsed_time,
"timestamp": end_time.isoformat()
}
}
except Exception as e:
# 计算耗时(即使失败也记录)
@@ -341,13 +541,17 @@ class WorkflowExecutor:
elapsed_time = (end_time - start_time).total_seconds()
logger.error(f"工作流执行失败: execution_id={self.execution_id}, error={e}", exc_info=True)
# 发送 workflow_end 事件(失败)
yield {
"status": "failed",
"error": str(e),
"output": None,
"node_outputs": {},
"elapsed_time": elapsed_time,
"token_usage": None
"event": "workflow_end",
"data": {
"execution_id": self.execution_id,
"status": "failed",
"error": str(e),
"elapsed_time": elapsed_time,
"timestamp": end_time.isoformat()
}
}
def _extract_final_output(self, node_outputs: dict[str, Any]) -> str | None:

View File

@@ -10,6 +10,7 @@ from abc import ABC, abstractmethod
from typing import Any, TypedDict, Annotated
from operator import add
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage
from langgraph.config import get_stream_writer
from app.core.workflow.variable_pool import VariablePool
@@ -43,6 +44,10 @@ class WorkflowState(TypedDict):
# 错误信息(用于错误边)
error: str | None
error_node: str | None
# 流式缓冲区(存储节点的实时流式输出)
# 格式:{node_id: {"chunks": [...], "full_content": "..."}}
streaming_buffer: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
class BaseNode(ABC):
@@ -201,19 +206,25 @@ class BaseNode(ABC):
return self._wrap_error(str(e), elapsed_time, state)
async def run_stream(self, state: WorkflowState):
"""执行节点(带错误处理和输出包装,流式)
"""Execute node with error handling and output wrapping (streaming)
这个方法由 Executor 调用,负责:
1. 时间统计
2. 调用节点的 execute_stream() 方法
3. 将业务数据包装成标准输出格式
4. 错误处理
This method is called by the Executor and is responsible for:
1. Time tracking
2. Calling the node's execute_stream() method
3. Using LangGraph's stream writer to send chunks
4. Updating streaming buffer in state for downstream nodes
5. Wrapping business data into standard output format
6. Error handling
Special handling for End nodes:
- End nodes don't send chunks via writer (prefix and LLM content already sent)
- End nodes only yield suffix for final result assembly
Args:
state: 工作流状态
state: Workflow state
Yields:
标准化的流式事件
State updates with streaming buffer and final result
"""
import time
@@ -222,68 +233,143 @@ class BaseNode(ABC):
try:
timeout = self.get_timeout()
# 累积完整结果(用于最后的包装)
# Get LangGraph's stream writer for sending custom data
writer = get_stream_writer()
# Check if this is an End node
# End nodes CAN send chunks (for suffix), but only after LLM content
is_end_node = self.node_type == "end"
# Check if this node is adjacent to End node (for message type)
is_adjacent_to_end = getattr(self, '_is_adjacent_to_end', False)
# Determine chunk type: "message" for End and adjacent nodes, "node_chunk" for others
chunk_type = "message" if (is_end_node or is_adjacent_to_end) else "node_chunk"
logger.debug(f"节点 {self.node_id} chunk 类型: {chunk_type} (is_end={is_end_node}, adjacent={is_adjacent_to_end})")
# Accumulate complete result (for final wrapping)
chunks = []
final_result = None
chunk_count = 0
# 使用异步生成器包装,支持超时
async def stream_with_timeout():
nonlocal final_result
loop_start = asyncio.get_event_loop().time()
# Stream chunks in real-time
loop_start = asyncio.get_event_loop().time()
async for item in self.execute_stream(state):
# Check timeout
if asyncio.get_event_loop().time() - loop_start > timeout:
raise TimeoutError()
async for item in self.execute_stream(state):
# 检查超时
if asyncio.get_event_loop().time() - loop_start > timeout:
raise TimeoutError()
# Check if it's a completion marker
if isinstance(item, dict) and item.get("__final__"):
final_result = item["result"]
elif isinstance(item, str):
# String is a chunk
chunk_count += 1
chunks.append(item)
full_content = "".join(chunks)
# 检查是否是完成标记
if isinstance(item, dict) and item.get("__final__"):
final_result = item["result"]
elif isinstance(item, str):
# 字符串是 chunk
chunks.append(item)
# Send chunks for all nodes (including End nodes for suffix)
logger.debug(f"节点 {self.node_id} 发送 chunk #{chunk_count}: {item[:50]}...")
# 1. Send via stream writer (for real-time client updates)
writer({
"type": chunk_type, # "message" or "node_chunk"
"node_id": self.node_id,
"chunk": item,
"full_content": full_content,
"chunk_index": chunk_count
})
# 2. Update streaming buffer in state (for downstream nodes)
# Only non-End nodes need streaming buffer
if not is_end_node:
yield {
"type": "chunk",
"node_id": self.node_id,
"content": item,
"full_content": "".join(chunks)
"streaming_buffer": {
self.node_id: {
"full_content": full_content,
"chunk_count": chunk_count,
"is_complete": False
}
}
}
else:
# 其他类型也当作 chunk 处理
chunks.append(str(item))
else:
# Other types are also treated as chunks
chunk_count += 1
chunk_str = str(item)
chunks.append(chunk_str)
full_content = "".join(chunks)
# Send chunks for all nodes
writer({
"type": chunk_type, # "message" or "node_chunk"
"node_id": self.node_id,
"chunk": chunk_str,
"full_content": full_content,
"chunk_index": chunk_count
})
# Only non-End nodes need streaming buffer
if not is_end_node:
yield {
"type": "chunk",
"node_id": self.node_id,
"content": str(item),
"full_content": "".join(chunks)
"streaming_buffer": {
self.node_id: {
"full_content": full_content,
"chunk_count": chunk_count,
"is_complete": False
}
}
}
async for chunk_event in stream_with_timeout():
yield chunk_event
elapsed_time = time.time() - start_time
# 包装最终结果
logger.info(f"节点 {self.node_id} 流式执行完成,耗时: {elapsed_time:.2f}s, chunks: {chunk_count}")
# Extract processed output (call subclass's _extract_output)
extracted_output = self._extract_output(final_result)
# Wrap final result
final_output = self._wrap_output(final_result, elapsed_time, state)
yield {
"type": "complete",
**final_output
# Store extracted output in runtime variables (for quick access by subsequent nodes)
if isinstance(extracted_output, dict):
runtime_var = extracted_output
else:
runtime_var = {"output": extracted_output}
# Build complete state update (including node_outputs, runtime_vars, and final streaming buffer)
state_update = {
**final_output,
"runtime_vars": {
self.node_id: runtime_var
}
}
# Add streaming buffer for non-End nodes
if not is_end_node:
state_update["streaming_buffer"] = {
self.node_id: {
"full_content": "".join(chunks),
"chunk_count": chunk_count,
"is_complete": True # Mark as complete
}
}
# Finally yield state update
# LangGraph will merge this into state
yield state_update
except TimeoutError:
elapsed_time = time.time() - start_time
logger.error(f"节点 {self.node_id} 执行超时{timeout}秒)")
yield {
"type": "error",
**self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state)
}
logger.error(f"节点 {self.node_id} 执行超时 ({timeout}s)")
error_output = self._wrap_error(f"节点执行超时 ({timeout}s)", elapsed_time, state)
yield error_output
except Exception as e:
elapsed_time = time.time() - start_time
logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True)
yield {
"type": "error",
**self._wrap_error(str(e), elapsed_time, state)
}
error_output = self._wrap_error(str(e), elapsed_time, state)
yield error_output
def _wrap_output(
self,

View File

@@ -5,6 +5,8 @@ End 节点实现
"""
import logging
import re
import asyncio
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
@@ -15,6 +17,7 @@ class EndNode(BaseNode):
"""End 节点
工作流的结束节点,根据配置的模板输出最终结果。
支持实时流式输出:如果模板引用了上游节点的输出,会实时监听其流式缓冲区。
"""
async def execute(self, state: WorkflowState) -> str:
@@ -30,11 +33,7 @@ class EndNode(BaseNode):
# 获取配置的输出模板
output_template = self.config.get("output")
pool = self.get_variable_pool(state)
print("="*20)
print( pool.get("start.test"))
print("="*20)
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
if output_template:
output = self._render_template(output_template, state)
@@ -46,7 +45,213 @@ class EndNode(BaseNode):
total_nodes = len(node_outputs)
logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
print("="*20)
print(output)
print("="*20)
return output
def _extract_referenced_nodes(self, template: str) -> list[str]:
"""从模板中提取引用的节点 ID
例如:'结果:{{llm_qa.output}}' -> ['llm_qa']
Args:
template: 模板字符串
Returns:
引用的节点 ID 列表
"""
# 匹配 {{node_id.xxx}} 格式
pattern = r'\{\{([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\}\}'
matches = re.findall(pattern, template)
return list(set(matches)) # 去重
def _parse_template_parts(self, template: str, state: WorkflowState) -> list[dict]:
"""解析模板,分离静态文本和动态引用
例如:'你好 {{llm.output}}, 这是后缀'
返回:[
{"type": "static", "content": "你好 "},
{"type": "dynamic", "node_id": "llm", "field": "output"},
{"type": "static", "content": ", 这是后缀"}
]
Args:
template: 模板字符串
state: 工作流状态
Returns:
模板部分列表
"""
import re
parts = []
last_end = 0
# 匹配 {{xxx}} 或 {{ xxx }} 格式(支持空格)
pattern = r'\{\{\s*([^}]+?)\s*\}\}'
for match in re.finditer(pattern, template):
start, end = match.span()
# 添加前面的静态文本
if start > last_end:
static_text = template[last_end:start]
if static_text:
parts.append({"type": "static", "content": static_text})
# 解析动态引用
ref = match.group(1).strip()
# 检查是否是节点引用(如 llm.output 或 llm_qa.output
if '.' in ref:
node_id, field = ref.split('.', 1)
parts.append({
"type": "dynamic",
"node_id": node_id,
"field": field,
"raw": ref
})
else:
# 其他引用(如 {{var.xxx}}),当作静态处理
# 直接渲染这部分
rendered = self._render_template(f"{{{{{ref}}}}}", state)
parts.append({"type": "static", "content": rendered})
last_end = end
# 添加最后的静态文本
if last_end < len(template):
static_text = template[last_end:]
if static_text:
parts.append({"type": "static", "content": static_text})
return parts
async def execute_stream(self, state: WorkflowState):
    """Run the End node's business logic in streaming mode.

    Output strategy:
    1. Find the first template reference that points at a direct upstream
       node (the source of an edge targeting this node).
    2. If there is one, emit only the part of the template that comes
       *after* that reference (the suffix) through the LangGraph stream
       writer, as a single ``message`` chunk.
    3. If there is none, or no template is configured, emit nothing through
       the writer and only yield the final result.

    Example: ``'{{start.test}}hahaha {{ llm_qa.output }} lalalalala a'`` with
    ``llm_qa`` as direct upstream node -> this node sends ``' lalalalala a'``.
    The prefix and the referenced content are expected to have been streamed
    by the upstream node already.

    Args:
        state: Workflow state.

    Yields:
        A single completion marker ``{"__final__": True, "result": <str>}``
        whose result is the fully rendered template (or a default message
        when no template is configured).
    """
    logger.info(f"节点 {self.node_id} (End) 开始执行(流式)")

    output_template = self.config.get("output")

    if not output_template:
        output = "工作流已完成"
        yield {"__final__": True, "result": output}
        return

    # Sources of all edges that point at this node.
    direct_upstream_nodes = [
        edge.get("source")
        for edge in self.workflow_config.get("edges", [])
        if edge.get("target") == self.node_id
    ]
    logger.info(f"节点 {self.node_id} 的直接上游节点: {direct_upstream_nodes}")

    parts = self._parse_template_parts(output_template, state)
    logger.info(f"节点 {self.node_id} 解析模板,共 {len(parts)} 个部分")

    # Index of the first dynamic part that references a direct upstream node.
    upstream_ref_index = None
    for i, part in enumerate(parts):
        if part["type"] == "dynamic" and part["node_id"] in direct_upstream_nodes:
            upstream_ref_index = i
            logger.info(f"节点 {self.node_id} 找到直接上游节点 {part['node_id']} 的引用,索引: {i}")
            break

    if upstream_ref_index is None:
        # No direct upstream reference: render the whole template as usual.
        output = self._render_template(output_template, state)
        logger.info(f"节点 {self.node_id} 没有引用直接上游节点,输出完整内容")
        yield {"__final__": True, "result": output}
        return

    logger.info(f"节点 {self.node_id} 检测到直接上游节点引用,只输出后缀部分(从索引 {upstream_ref_index + 1} 开始)")

    # State lookups do not change while the suffix is assembled; read them once.
    streaming_buffer = state.get("streaming_buffer", {})
    node_outputs = state.get("node_outputs", {})
    runtime_vars = state.get("runtime_vars", {})

    suffix_parts = []
    for part in parts[upstream_ref_index + 1:]:
        if part["type"] == "static":
            suffix_parts.append(part["content"])
            continue

        # Remaining parts are dynamic references located after the upstream one.
        node_id = part["node_id"]
        field = part["field"]

        if node_id in streaming_buffer:
            # NOTE(review): the buffer content is used regardless of `field`;
            # confirm that only the streamed field is ever referenced here.
            content = streaming_buffer[node_id].get("full_content", "")
        elif node_id in node_outputs:
            node_output = node_outputs[node_id]
            content = str(node_output.get(field, "")) if isinstance(node_output, dict) else ""
        elif node_id in runtime_vars:
            runtime_var = runtime_vars[node_id]
            content = str(runtime_var.get(field, "")) if isinstance(runtime_var, dict) else ""
        else:
            # The reference is not backed by any node store. Render it the same
            # way the full template is rendered below, so the streamed suffix
            # does not silently drop text that appears in the final output.
            content = self._render_template(f"{{{{{part['raw']}}}}}", state)

        suffix_parts.append(content)

    suffix = "".join(suffix_parts)

    # Complete rendered output (prefix + referenced content + suffix) for the result.
    full_output = self._render_template(output_template, state)

    if suffix:
        logger.info(f"节点 {self.node_id} 输出后缀: '{suffix[:50]}...' (长度: {len(suffix)})")

        # The suffix is sent as one chunk directly through the writer instead
        # of being yielded as a string.
        from langgraph.config import get_stream_writer
        writer = get_stream_writer()
        writer({
            "type": "message",  # End node output uses the "message" type
            "node_id": self.node_id,
            "chunk": suffix,
            "full_content": full_output,  # the complete rendered result
            "chunk_index": 1,
            "is_suffix": True
        })
        logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀full_content 长度: {len(full_output)}")
    else:
        logger.info(f"节点 {self.node_id} 没有后缀需要输出")

    total_nodes = len(node_outputs)
    logger.info(f"节点 {self.node_id} (End) 执行完成(流式),共执行了 {total_nodes} 个节点")

    # Completion marker carrying the full output.
    yield {"__final__": True, "result": full_output}

View File

@@ -63,7 +63,7 @@ class LLMNode(BaseNode):
- ai/assistant: AI 消息AIMessage
"""
def _prepare_llm(self, state: WorkflowState) -> tuple[RedBearLLM, list | str]:
def _prepare_llm(self, state: WorkflowState,stream:bool = False) -> tuple[RedBearLLM, list | str]:
"""准备 LLM 实例(公共逻辑)
Args:
@@ -125,16 +125,22 @@ class LLMNode(BaseNode):
model_type = config.type
# 4. 创建 LLM 实例(使用已提取的数据)
# 注意:对于流式输出,需要在模型初始化时设置 streaming=True
extra_params = {"streaming": stream} if stream else {}
llm = RedBearLLM(
RedBearModelConfig(
model_name=model_name,
provider=provider,
api_key=api_key,
base_url=api_base
base_url=api_base,
extra_params=extra_params
),
type=model_type
)
logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}")
return llm, prompt_or_messages
async def execute(self, state: WorkflowState) -> AIMessage:
@@ -146,13 +152,12 @@ class LLMNode(BaseNode):
Returns:
LLM 响应消息
"""
llm, prompt_or_messages = self._prepare_llm(state)
llm, prompt_or_messages = self._prepare_llm(state,True)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
# 调用 LLM支持字符串或消息列表
response = await llm.ainvoke(prompt_or_messages)
# 提取内容
if hasattr(response, 'content'):
content = response.content
@@ -208,13 +213,43 @@ class LLMNode(BaseNode):
Yields:
文本片段chunk或完成标记
"""
llm, prompt_or_messages = self._prepare_llm(state)
from langgraph.config import get_stream_writer
llm, prompt_or_messages = self._prepare_llm(state, True)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
# 检查是否有注入的 End 节点前缀配置
writer = get_stream_writer()
end_prefix = getattr(self, '_end_node_prefix', None)
logger.info(f"[LLM前缀] 节点 {self.node_id} 检查前缀配置: {end_prefix is not None}")
if end_prefix:
logger.info(f"[LLM前缀] 前缀内容: '{end_prefix}'")
if end_prefix:
# 渲染前缀(可能包含其他变量)
try:
rendered_prefix = self._render_template(end_prefix, state)
logger.info(f"节点 {self.node_id} 提前发送 End 节点前缀: '{rendered_prefix[:50]}...'")
# 提前发送 End 节点的前缀(使用 "message" 类型)
writer({
"type": "message", # End 相关的内容都是 message 类型
"node_id": "end", # 标记为 end 节点的输出
"chunk": rendered_prefix,
"full_content": rendered_prefix,
"chunk_index": 0,
"is_prefix": True # 标记这是前缀
})
except Exception as e:
logger.warning(f"渲染/发送 End 节点前缀失败: {e}")
# 累积完整响应
full_response = ""
last_chunk = None
chunk_count = 0
# 调用 LLM流式支持字符串或消息列表
async for chunk in llm.astream(prompt_or_messages):
@@ -224,13 +259,16 @@ class LLMNode(BaseNode):
else:
content = str(chunk)
full_response += content
last_chunk = chunk
# 流式返回每个文本片段
yield content
# 只有当内容不为空时才处理
if content:
full_response += content
last_chunk = chunk
chunk_count += 1
# 流式返回每个文本片段
yield content
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}")
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
# 构建完整的 AIMessage包含元数据
if isinstance(last_chunk, AIMessage):