[modify] workflow executor support stream
This commit is contained in:
@@ -92,7 +92,7 @@ class WorkflowExecutor:
|
||||
|
||||
|
||||
|
||||
def build_graph(self) -> CompiledStateGraph:
|
||||
def build_graph(self,stream=False) -> CompiledStateGraph:
|
||||
"""构建 LangGraph
|
||||
|
||||
Returns:
|
||||
@@ -122,12 +122,19 @@ class WorkflowExecutor:
|
||||
if node_instance:
|
||||
# 包装节点的 run 方法
|
||||
# 使用函数工厂避免闭包问题
|
||||
def make_node_func(inst):
|
||||
async def node_func(state: WorkflowState):
|
||||
if stream:
|
||||
# 流式模式:创建 async generator 函数
|
||||
# LangGraph 会收集所有 yield 的值,最后一个 yield 的字典会被合并到 state
|
||||
async def node_func(state: WorkflowState, inst=node_instance):
|
||||
async for item in inst.run_stream(state):
|
||||
yield item
|
||||
workflow.add_node(node_id, node_func)
|
||||
else:
|
||||
# 非流式模式:创建 async function
|
||||
async def node_func(state: WorkflowState, inst=node_instance):
|
||||
return await inst.run(state)
|
||||
return node_func
|
||||
workflow.add_node(node_id, node_func)
|
||||
|
||||
workflow.add_node(node_id, make_node_func(node_instance))
|
||||
logger.debug(f"添加节点: {node_id} (type={node_type})")
|
||||
|
||||
# 3. 添加边
|
||||
@@ -276,12 +283,13 @@ class WorkflowExecutor:
|
||||
):
|
||||
"""执行工作流(流式)
|
||||
|
||||
手动执行节点以支持细粒度的流式输出:
|
||||
- workflow_start: 工作流开始
|
||||
- node_start: 节点开始执行
|
||||
- node_chunk: LLM 节点的流式输出片段(逐 token)
|
||||
- node_complete: 节点执行完成
|
||||
- workflow_complete: 工作流完成
|
||||
使用 stream_mode="updates" 来获取每个节点的 state 更新。
|
||||
节点的 generator 会 yield 多个值:
|
||||
- 中间的 chunk 事件(带 type="chunk")
|
||||
- 最后的 state 更新(纯字典,包含 node_outputs 等)
|
||||
|
||||
LangGraph 会将所有 yield 的值收集起来,并将它们合并到 state 中。
|
||||
我们需要过滤出 chunk 事件并转发,同时确保 state 更新被正确处理。
|
||||
|
||||
Args:
|
||||
input_data: 输入数据
|
||||
@@ -289,27 +297,47 @@ class WorkflowExecutor:
|
||||
Yields:
|
||||
流式事件
|
||||
"""
|
||||
#
|
||||
logger.info(f"开始执行工作流: execution_id={self.execution_id}")
|
||||
|
||||
# 记录开始时间
|
||||
start_time = datetime.datetime.now()
|
||||
|
||||
# 1. 构建图
|
||||
graph = self.build_graph()
|
||||
graph = self.build_graph(True)
|
||||
|
||||
# 2. 初始化状态(自动注入系统变量)
|
||||
initial_state = self._prepare_initial_state(input_data)
|
||||
|
||||
# 3. 执行工作流
|
||||
try:
|
||||
async for chunk in graph.astream(
|
||||
async for mode, event in graph.astream(
|
||||
initial_state,
|
||||
# subgraphs=True,
|
||||
stream_mode="updates",
|
||||
stream_mode=["updates","messages"],
|
||||
):
|
||||
# print(chunk)
|
||||
yield chunk
|
||||
# print("刚才跑的节点:", event[0])
|
||||
# # 通过图结构就能算出“接下来是谁”
|
||||
# print("接下来可能跑:", graph.get_next(event[0]))
|
||||
# print("="*50)
|
||||
# # print("mode",mode)
|
||||
# print("event",event)
|
||||
# print("="*50)
|
||||
# event 是一个字典,key 是节点 ID,value 是 state 更新或 chunk
|
||||
for node_id, update in event.items():
|
||||
print("="*50)
|
||||
print("node_id",node_id)
|
||||
print("update",update)
|
||||
|
||||
print("="*50)
|
||||
if isinstance(update, dict) and update.get("type") == "chunk":
|
||||
# 这是流式 chunk,转发给客户端
|
||||
yield {
|
||||
"type": "node_chunk",
|
||||
"node_id": update.get("node_id"),
|
||||
"chunk": update.get("content")
|
||||
}
|
||||
# 其他情况(state 更新)会被 LangGraph 自动合并到 state,不需要我们处理
|
||||
print(event)
|
||||
yield event
|
||||
|
||||
except Exception as e:
|
||||
# 计算耗时(即使失败也记录)
|
||||
|
||||
@@ -209,11 +209,15 @@ class BaseNode(ABC):
|
||||
3. 将业务数据包装成标准输出格式
|
||||
4. 错误处理
|
||||
|
||||
注意:在流式模式下,我们需要:
|
||||
- yield 中间的 chunk 事件(用于实时显示)
|
||||
- 最后 yield 一个包含 state 更新的字典(LangGraph 会合并到 state)
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
|
||||
Yields:
|
||||
标准化的流式事件
|
||||
标准化的流式事件和最终的 state 更新
|
||||
"""
|
||||
import time
|
||||
|
||||
@@ -263,27 +267,39 @@ class BaseNode(ABC):
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# 提取处理后的输出(调用子类的 _extract_output)
|
||||
extracted_output = self._extract_output(final_result)
|
||||
|
||||
# 包装最终结果
|
||||
final_output = self._wrap_output(final_result, elapsed_time, state)
|
||||
yield {
|
||||
"type": "complete",
|
||||
**final_output
|
||||
|
||||
# 将提取后的输出存储到运行时变量中(供后续节点快速访问)
|
||||
if isinstance(extracted_output, dict):
|
||||
runtime_var = extracted_output
|
||||
else:
|
||||
runtime_var = {"output": extracted_output}
|
||||
|
||||
# 构建完整的 state 更新(包含 node_outputs 和 runtime_vars)
|
||||
state_update = {
|
||||
**final_output,
|
||||
"runtime_vars": {
|
||||
self.node_id: runtime_var
|
||||
}
|
||||
}
|
||||
|
||||
# 最后 yield 纯粹的 state 更新(LangGraph 会合并到 state 中)
|
||||
yield state_update
|
||||
|
||||
except TimeoutError:
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.error(f"节点 {self.node_id} 执行超时({timeout}秒)")
|
||||
yield {
|
||||
"type": "error",
|
||||
**self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state)
|
||||
}
|
||||
error_output = self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state)
|
||||
yield error_output
|
||||
except Exception as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True)
|
||||
yield {
|
||||
"type": "error",
|
||||
**self._wrap_error(str(e), elapsed_time, state)
|
||||
}
|
||||
error_output = self._wrap_error(str(e), elapsed_time, state)
|
||||
yield error_output
|
||||
|
||||
def _wrap_output(
|
||||
self,
|
||||
|
||||
@@ -30,11 +30,11 @@ class EndNode(BaseNode):
|
||||
|
||||
# 获取配置的输出模板
|
||||
output_template = self.config.get("output")
|
||||
pool = self.get_variable_pool(state)
|
||||
# pool = self.get_variable_pool(state)
|
||||
|
||||
print("="*20)
|
||||
print( pool.get("start.test"))
|
||||
print("="*20)
|
||||
# print("="*20)
|
||||
# print( pool.get("start.test"))
|
||||
# print("="*20)
|
||||
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
|
||||
if output_template:
|
||||
output = self._render_template(output_template, state)
|
||||
|
||||
@@ -63,7 +63,7 @@ class LLMNode(BaseNode):
|
||||
- ai/assistant: AI 消息(AIMessage)
|
||||
"""
|
||||
|
||||
def _prepare_llm(self, state: WorkflowState) -> tuple[RedBearLLM, list | str]:
|
||||
def _prepare_llm(self, state: WorkflowState,stream:bool = False) -> tuple[RedBearLLM, list | str]:
|
||||
"""准备 LLM 实例(公共逻辑)
|
||||
|
||||
Args:
|
||||
@@ -125,16 +125,19 @@ class LLMNode(BaseNode):
|
||||
model_type = config.type
|
||||
|
||||
# 4. 创建 LLM 实例(使用已提取的数据)
|
||||
print("="*50)
|
||||
print("stream",stream)
|
||||
print("="*50)
|
||||
llm = RedBearLLM(
|
||||
RedBearModelConfig(
|
||||
model_name=model_name,
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base
|
||||
base_url=api_base,
|
||||
extra_params={"streaming": stream}
|
||||
),
|
||||
type=model_type
|
||||
)
|
||||
|
||||
return llm, prompt_or_messages
|
||||
|
||||
async def execute(self, state: WorkflowState) -> AIMessage:
|
||||
@@ -146,13 +149,12 @@ class LLMNode(BaseNode):
|
||||
Returns:
|
||||
LLM 响应消息
|
||||
"""
|
||||
llm, prompt_or_messages = self._prepare_llm(state)
|
||||
llm, prompt_or_messages = self._prepare_llm(state,True)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
|
||||
|
||||
# 调用 LLM(支持字符串或消息列表)
|
||||
response = await llm.ainvoke(prompt_or_messages)
|
||||
|
||||
# 提取内容
|
||||
if hasattr(response, 'content'):
|
||||
content = response.content
|
||||
@@ -199,47 +201,47 @@ class LLMNode(BaseNode):
|
||||
}
|
||||
return None
|
||||
|
||||
async def execute_stream(self, state: WorkflowState):
|
||||
"""流式执行 LLM 调用
|
||||
# async def execute_stream(self, state: WorkflowState):
|
||||
# """流式执行 LLM 调用
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
# Args:
|
||||
# state: 工作流状态
|
||||
|
||||
Yields:
|
||||
文本片段(chunk)或完成标记
|
||||
"""
|
||||
llm, prompt_or_messages = self._prepare_llm(state)
|
||||
# Yields:
|
||||
# 文本片段(chunk)或完成标记
|
||||
# """
|
||||
# llm, prompt_or_messages = self._prepare_llm(state,True)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
|
||||
# logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
|
||||
|
||||
# 累积完整响应
|
||||
full_response = ""
|
||||
last_chunk = None
|
||||
# # 累积完整响应
|
||||
# full_response = ""
|
||||
# last_chunk = None
|
||||
|
||||
# 调用 LLM(流式,支持字符串或消息列表)
|
||||
async for chunk in llm.astream(prompt_or_messages):
|
||||
# 提取内容
|
||||
if hasattr(chunk, 'content'):
|
||||
content = chunk.content
|
||||
else:
|
||||
content = str(chunk)
|
||||
# # 调用 LLM(流式,支持字符串或消息列表)
|
||||
# async for chunk in llm.astream(prompt_or_messages):
|
||||
# # 提取内容
|
||||
# if hasattr(chunk, 'content'):
|
||||
# content = chunk.content
|
||||
# else:
|
||||
# content = str(chunk)
|
||||
|
||||
full_response += content
|
||||
last_chunk = chunk
|
||||
|
||||
# 流式返回每个文本片段
|
||||
yield content
|
||||
# full_response += content
|
||||
# last_chunk = chunk
|
||||
# logger.info(f"节点 {self.node_id} LLM : {content}")
|
||||
# # 流式返回每个文本片段
|
||||
# yield content
|
||||
|
||||
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}")
|
||||
# logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}")
|
||||
|
||||
# 构建完整的 AIMessage(包含元数据)
|
||||
if isinstance(last_chunk, AIMessage):
|
||||
final_message = AIMessage(
|
||||
content=full_response,
|
||||
response_metadata=last_chunk.response_metadata if hasattr(last_chunk, 'response_metadata') else {}
|
||||
)
|
||||
else:
|
||||
final_message = AIMessage(content=full_response)
|
||||
# # 构建完整的 AIMessage(包含元数据)
|
||||
# if isinstance(last_chunk, AIMessage):
|
||||
# final_message = AIMessage(
|
||||
# content=full_response,
|
||||
# response_metadata=last_chunk.response_metadata if hasattr(last_chunk, 'response_metadata') else {}
|
||||
# )
|
||||
# else:
|
||||
# final_message = AIMessage(content=full_response)
|
||||
|
||||
# yield 完成标记
|
||||
yield {"__final__": True, "result": final_message}
|
||||
# # yield 完成标记
|
||||
# yield {"__final__": True, "result": final_message}
|
||||
|
||||
Reference in New Issue
Block a user