[modify] workflow executor support stream

This commit is contained in:
Mark
2025-12-20 13:59:20 +08:00
parent 0503b26232
commit 6c04c99073
7 changed files with 168 additions and 126 deletions

View File

@@ -9,18 +9,15 @@ LangChain Agent 封装
"""
import os
import time
import asyncio
from typing import Dict, Any, List, Optional, AsyncGenerator, Sequence
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage
from langchain_core.tools import BaseTool
from langchain.agents import create_agent
from app.core.memory.agent.mcp_server.services import session_service
from app.core.memory.agent.utils.redis_tool import store
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.models_model import ModelType
from app.core.logging_config import get_business_logger
from app.services.memory_agent_service import MemoryAgentService
from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task

View File

@@ -92,7 +92,7 @@ class WorkflowExecutor:
def build_graph(self) -> CompiledStateGraph:
def build_graph(self,stream=False) -> CompiledStateGraph:
"""构建 LangGraph
Returns:
@@ -122,12 +122,19 @@ class WorkflowExecutor:
if node_instance:
# 包装节点的 run 方法
# 使用函数工厂避免闭包问题
def make_node_func(inst):
async def node_func(state: WorkflowState):
if stream:
# 流式模式:创建 async generator 函数
# LangGraph 会收集所有 yield 的值,最后一个 yield 的字典会被合并到 state
async def node_func(state: WorkflowState, inst=node_instance):
async for item in inst.run_stream(state):
yield item
workflow.add_node(node_id, node_func)
else:
# 非流式模式:创建 async function
async def node_func(state: WorkflowState, inst=node_instance):
return await inst.run(state)
return node_func
workflow.add_node(node_id, node_func)
workflow.add_node(node_id, make_node_func(node_instance))
logger.debug(f"添加节点: {node_id} (type={node_type})")
# 3. 添加边
@@ -276,12 +283,13 @@ class WorkflowExecutor:
):
"""执行工作流(流式)
手动执行节点以支持细粒度的流式输出:
- workflow_start: 工作流开始
- node_start: 节点开始执行
- node_chunk: LLM 节点的流式输出片段(逐 token
- node_complete: 节点执行完成
- workflow_complete: 工作流完成
使用 stream_mode="updates" 来获取每个节点的 state 更新。
节点的 generator 会 yield 多个值:
- 中间的 chunk 事件(带 type="chunk"
- 最后的 state 更新(纯字典,包含 node_outputs 等
LangGraph 会将所有 yield 的值收集起来,并将它们合并到 state 中。
我们需要过滤出 chunk 事件并转发,同时确保 state 更新被正确处理。
Args:
input_data: 输入数据
@@ -289,27 +297,47 @@ class WorkflowExecutor:
Yields:
流式事件
"""
#
logger.info(f"开始执行工作流: execution_id={self.execution_id}")
# 记录开始时间
start_time = datetime.datetime.now()
# 1. 构建图
graph = self.build_graph()
graph = self.build_graph(True)
# 2. 初始化状态(自动注入系统变量)
initial_state = self._prepare_initial_state(input_data)
# 3. 执行工作流
try:
async for chunk in graph.astream(
async for mode, event in graph.astream(
initial_state,
# subgraphs=True,
stream_mode="updates",
stream_mode=["updates","messages"],
):
# print(chunk)
yield chunk
# print("刚才跑的节点:", event[0])
# # 通过图结构就能算出“接下来是谁”
# print("接下来可能跑:", graph.get_next(event[0]))
# print("="*50)
# # print("mode",mode)
# print("event",event)
# print("="*50)
# event 是一个字典key 是节点 IDvalue 是 state 更新或 chunk
for node_id, update in event.items():
print("="*50)
print("node_id",node_id)
print("update",update)
print("="*50)
if isinstance(update, dict) and update.get("type") == "chunk":
# 这是流式 chunk转发给客户端
yield {
"type": "node_chunk",
"node_id": update.get("node_id"),
"chunk": update.get("content")
}
# 其他情况state 更新)会被 LangGraph 自动合并到 state不需要我们处理
print(event)
yield event
except Exception as e:
# 计算耗时(即使失败也记录)

View File

@@ -209,11 +209,15 @@ class BaseNode(ABC):
3. 将业务数据包装成标准输出格式
4. 错误处理
注意:在流式模式下,我们需要:
- yield 中间的 chunk 事件(用于实时显示)
- 最后 yield 一个包含 state 更新的字典LangGraph 会合并到 state
Args:
state: 工作流状态
Yields:
标准化的流式事件
标准化的流式事件和最终的 state 更新
"""
import time
@@ -263,27 +267,39 @@ class BaseNode(ABC):
elapsed_time = time.time() - start_time
# 提取处理后的输出(调用子类的 _extract_output
extracted_output = self._extract_output(final_result)
# 包装最终结果
final_output = self._wrap_output(final_result, elapsed_time, state)
yield {
"type": "complete",
**final_output
# 将提取后的输出存储到运行时变量中(供后续节点快速访问)
if isinstance(extracted_output, dict):
runtime_var = extracted_output
else:
runtime_var = {"output": extracted_output}
# 构建完整的 state 更新(包含 node_outputs 和 runtime_vars
state_update = {
**final_output,
"runtime_vars": {
self.node_id: runtime_var
}
}
# 最后 yield 纯粹的 state 更新LangGraph 会合并到 state 中)
yield state_update
except TimeoutError:
elapsed_time = time.time() - start_time
logger.error(f"节点 {self.node_id} 执行超时({timeout}秒)")
yield {
"type": "error",
**self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state)
}
error_output = self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state)
yield error_output
except Exception as e:
elapsed_time = time.time() - start_time
logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True)
yield {
"type": "error",
**self._wrap_error(str(e), elapsed_time, state)
}
error_output = self._wrap_error(str(e), elapsed_time, state)
yield error_output
def _wrap_output(
self,

View File

@@ -30,11 +30,11 @@ class EndNode(BaseNode):
# 获取配置的输出模板
output_template = self.config.get("output")
pool = self.get_variable_pool(state)
# pool = self.get_variable_pool(state)
print("="*20)
print( pool.get("start.test"))
print("="*20)
# print("="*20)
# print( pool.get("start.test"))
# print("="*20)
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
if output_template:
output = self._render_template(output_template, state)

View File

@@ -63,7 +63,7 @@ class LLMNode(BaseNode):
- ai/assistant: AI 消息AIMessage
"""
def _prepare_llm(self, state: WorkflowState) -> tuple[RedBearLLM, list | str]:
def _prepare_llm(self, state: WorkflowState,stream:bool = False) -> tuple[RedBearLLM, list | str]:
"""准备 LLM 实例(公共逻辑)
Args:
@@ -125,16 +125,19 @@ class LLMNode(BaseNode):
model_type = config.type
# 4. 创建 LLM 实例(使用已提取的数据)
print("="*50)
print("stream",stream)
print("="*50)
llm = RedBearLLM(
RedBearModelConfig(
model_name=model_name,
provider=provider,
api_key=api_key,
base_url=api_base
base_url=api_base,
extra_params={"streaming": stream}
),
type=model_type
)
return llm, prompt_or_messages
async def execute(self, state: WorkflowState) -> AIMessage:
@@ -146,13 +149,12 @@ class LLMNode(BaseNode):
Returns:
LLM 响应消息
"""
llm, prompt_or_messages = self._prepare_llm(state)
llm, prompt_or_messages = self._prepare_llm(state,True)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
# 调用 LLM支持字符串或消息列表
response = await llm.ainvoke(prompt_or_messages)
# 提取内容
if hasattr(response, 'content'):
content = response.content
@@ -199,47 +201,47 @@ class LLMNode(BaseNode):
}
return None
async def execute_stream(self, state: WorkflowState):
"""流式执行 LLM 调用
# async def execute_stream(self, state: WorkflowState):
# """流式执行 LLM 调用
Args:
state: 工作流状态
# Args:
# state: 工作流状态
Yields:
文本片段chunk或完成标记
"""
llm, prompt_or_messages = self._prepare_llm(state)
# Yields:
# 文本片段chunk或完成标记
# """
# llm, prompt_or_messages = self._prepare_llm(state,True)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
# logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
# 累积完整响应
full_response = ""
last_chunk = None
# # 累积完整响应
# full_response = ""
# last_chunk = None
# 调用 LLM流式支持字符串或消息列表
async for chunk in llm.astream(prompt_or_messages):
# 提取内容
if hasattr(chunk, 'content'):
content = chunk.content
else:
content = str(chunk)
# # 调用 LLM流式支持字符串或消息列表
# async for chunk in llm.astream(prompt_or_messages):
# # 提取内容
# if hasattr(chunk, 'content'):
# content = chunk.content
# else:
# content = str(chunk)
full_response += content
last_chunk = chunk
# 流式返回每个文本片段
yield content
# full_response += content
# last_chunk = chunk
# logger.info(f"节点 {self.node_id} LLM : {content}")
# # 流式返回每个文本片段
# yield content
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}")
# logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}")
# 构建完整的 AIMessage包含元数据
if isinstance(last_chunk, AIMessage):
final_message = AIMessage(
content=full_response,
response_metadata=last_chunk.response_metadata if hasattr(last_chunk, 'response_metadata') else {}
)
else:
final_message = AIMessage(content=full_response)
# # 构建完整的 AIMessage包含元数据
# if isinstance(last_chunk, AIMessage):
# final_message = AIMessage(
# content=full_response,
# response_metadata=last_chunk.response_metadata if hasattr(last_chunk, 'response_metadata') else {}
# )
# else:
# final_message = AIMessage(content=full_response)
# yield 完成标记
yield {"__final__": True, "result": final_message}
# # yield 完成标记
# yield {"__final__": True, "result": final_message}