[modify] workflow executor support stream

This commit is contained in:
Mark
2025-12-20 13:59:20 +08:00
parent 0503b26232
commit 6c04c99073
7 changed files with 168 additions and 126 deletions

View File

@@ -9,18 +9,15 @@ LangChain Agent 封装
"""
import os
import time
import asyncio
from typing import Dict, Any, List, Optional, AsyncGenerator, Sequence
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage
from langchain_core.tools import BaseTool
from langchain.agents import create_agent
from app.core.memory.agent.mcp_server.services import session_service
from app.core.memory.agent.utils.redis_tool import store
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.models_model import ModelType
from app.core.logging_config import get_business_logger
from app.services.memory_agent_service import MemoryAgentService
from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task

View File

@@ -92,7 +92,7 @@ class WorkflowExecutor:
def build_graph(self) -> CompiledStateGraph:
def build_graph(self,stream=False) -> CompiledStateGraph:
"""构建 LangGraph
Returns:
@@ -122,12 +122,19 @@ class WorkflowExecutor:
if node_instance:
# 包装节点的 run 方法
# 使用函数工厂避免闭包问题
def make_node_func(inst):
async def node_func(state: WorkflowState):
if stream:
# 流式模式:创建 async generator 函数
# LangGraph 会收集所有 yield 的值,最后一个 yield 的字典会被合并到 state
async def node_func(state: WorkflowState, inst=node_instance):
async for item in inst.run_stream(state):
yield item
workflow.add_node(node_id, node_func)
else:
# 非流式模式:创建 async function
async def node_func(state: WorkflowState, inst=node_instance):
return await inst.run(state)
return node_func
workflow.add_node(node_id, node_func)
workflow.add_node(node_id, make_node_func(node_instance))
logger.debug(f"添加节点: {node_id} (type={node_type})")
# 3. 添加边
@@ -276,12 +283,13 @@ class WorkflowExecutor:
):
"""执行工作流(流式)
手动执行节点以支持细粒度的流式输出:
- workflow_start: 工作流开始
- node_start: 节点开始执行
- node_chunk: LLM 节点的流式输出片段(逐 token)
- node_complete: 节点执行完成
- workflow_complete: 工作流完成
使用 stream_mode="updates" 来获取每个节点的 state 更新。
节点的 generator 会 yield 多个值:
- 中间的 chunk 事件(带 type="chunk")
- 最后的 state 更新(纯字典,包含 node_outputs 等)
LangGraph 会将所有 yield 的值收集起来,并将它们合并到 state 中。
我们需要过滤出 chunk 事件并转发,同时确保 state 更新被正确处理。
Args:
input_data: 输入数据
@@ -289,27 +297,47 @@ class WorkflowExecutor:
Yields:
流式事件
"""
#
logger.info(f"开始执行工作流: execution_id={self.execution_id}")
# 记录开始时间
start_time = datetime.datetime.now()
# 1. 构建图
graph = self.build_graph()
graph = self.build_graph(True)
# 2. 初始化状态(自动注入系统变量)
initial_state = self._prepare_initial_state(input_data)
# 3. 执行工作流
try:
async for chunk in graph.astream(
async for mode, event in graph.astream(
initial_state,
# subgraphs=True,
stream_mode="updates",
stream_mode=["updates","messages"],
):
# print(chunk)
yield chunk
# print("刚才跑的节点:", event[0])
# # 通过图结构就能算出“接下来是谁”
# print("接下来可能跑:", graph.get_next(event[0]))
# print("="*50)
# # print("mode",mode)
# print("event",event)
# print("="*50)
# event 是一个字典,key 是节点 ID,value 是 state 更新或 chunk
for node_id, update in event.items():
print("="*50)
print("node_id",node_id)
print("update",update)
print("="*50)
if isinstance(update, dict) and update.get("type") == "chunk":
# 这是流式 chunk,转发给客户端
yield {
"type": "node_chunk",
"node_id": update.get("node_id"),
"chunk": update.get("content")
}
# 其他情况(state 更新)会被 LangGraph 自动合并到 state,不需要我们处理
print(event)
yield event
except Exception as e:
# 计算耗时(即使失败也记录)

View File

@@ -209,11 +209,15 @@ class BaseNode(ABC):
3. 将业务数据包装成标准输出格式
4. 错误处理
注意:在流式模式下,我们需要:
- yield 中间的 chunk 事件(用于实时显示)
- 最后 yield 一个包含 state 更新的字典(LangGraph 会合并到 state)
Args:
state: 工作流状态
Yields:
标准化的流式事件
标准化的流式事件和最终的 state 更新
"""
import time
@@ -263,27 +267,39 @@ class BaseNode(ABC):
elapsed_time = time.time() - start_time
# 提取处理后的输出(调用子类的 _extract_output
extracted_output = self._extract_output(final_result)
# 包装最终结果
final_output = self._wrap_output(final_result, elapsed_time, state)
yield {
"type": "complete",
**final_output
# 将提取后的输出存储到运行时变量中(供后续节点快速访问)
if isinstance(extracted_output, dict):
runtime_var = extracted_output
else:
runtime_var = {"output": extracted_output}
# 构建完整的 state 更新(包含 node_outputs 和 runtime_vars)
state_update = {
**final_output,
"runtime_vars": {
self.node_id: runtime_var
}
}
# 最后 yield 纯粹的 state 更新(LangGraph 会合并到 state 中)
yield state_update
except TimeoutError:
elapsed_time = time.time() - start_time
logger.error(f"节点 {self.node_id} 执行超时({timeout}秒)")
yield {
"type": "error",
**self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state)
}
error_output = self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state)
yield error_output
except Exception as e:
elapsed_time = time.time() - start_time
logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True)
yield {
"type": "error",
**self._wrap_error(str(e), elapsed_time, state)
}
error_output = self._wrap_error(str(e), elapsed_time, state)
yield error_output
def _wrap_output(
self,

View File

@@ -30,11 +30,11 @@ class EndNode(BaseNode):
# 获取配置的输出模板
output_template = self.config.get("output")
pool = self.get_variable_pool(state)
# pool = self.get_variable_pool(state)
print("="*20)
print( pool.get("start.test"))
print("="*20)
# print("="*20)
# print( pool.get("start.test"))
# print("="*20)
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
if output_template:
output = self._render_template(output_template, state)

View File

@@ -63,7 +63,7 @@ class LLMNode(BaseNode):
- ai/assistant: AI 消息(AIMessage)
"""
def _prepare_llm(self, state: WorkflowState) -> tuple[RedBearLLM, list | str]:
def _prepare_llm(self, state: WorkflowState,stream:bool = False) -> tuple[RedBearLLM, list | str]:
"""准备 LLM 实例(公共逻辑)
Args:
@@ -125,16 +125,19 @@ class LLMNode(BaseNode):
model_type = config.type
# 4. 创建 LLM 实例(使用已提取的数据)
print("="*50)
print("stream",stream)
print("="*50)
llm = RedBearLLM(
RedBearModelConfig(
model_name=model_name,
provider=provider,
api_key=api_key,
base_url=api_base
base_url=api_base,
extra_params={"streaming": stream}
),
type=model_type
)
return llm, prompt_or_messages
async def execute(self, state: WorkflowState) -> AIMessage:
@@ -146,13 +149,12 @@ class LLMNode(BaseNode):
Returns:
LLM 响应消息
"""
llm, prompt_or_messages = self._prepare_llm(state)
llm, prompt_or_messages = self._prepare_llm(state,True)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
# 调用 LLM(支持字符串或消息列表)
response = await llm.ainvoke(prompt_or_messages)
# 提取内容
if hasattr(response, 'content'):
content = response.content
@@ -199,47 +201,47 @@ class LLMNode(BaseNode):
}
return None
async def execute_stream(self, state: WorkflowState):
"""流式执行 LLM 调用
# async def execute_stream(self, state: WorkflowState):
# """流式执行 LLM 调用
Args:
state: 工作流状态
# Args:
# state: 工作流状态
Yields:
文本片段(chunk)或完成标记
"""
llm, prompt_or_messages = self._prepare_llm(state)
# Yields:
# 文本片段chunk或完成标记
# """
# llm, prompt_or_messages = self._prepare_llm(state,True)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
# logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
# 累积完整响应
full_response = ""
last_chunk = None
# # 累积完整响应
# full_response = ""
# last_chunk = None
# 调用 LLM(流式),支持字符串或消息列表
async for chunk in llm.astream(prompt_or_messages):
# 提取内容
if hasattr(chunk, 'content'):
content = chunk.content
else:
content = str(chunk)
# # 调用 LLM流式支持字符串或消息列表
# async for chunk in llm.astream(prompt_or_messages):
# # 提取内容
# if hasattr(chunk, 'content'):
# content = chunk.content
# else:
# content = str(chunk)
full_response += content
last_chunk = chunk
# full_response += content
# last_chunk = chunk
# logger.info(f"节点 {self.node_id} LLM : {content}")
# # 流式返回每个文本片段
# yield content
# 流式返回每个文本片段
yield content
# logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}")
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}")
# # 构建完整的 AIMessage包含元数据
# if isinstance(last_chunk, AIMessage):
# final_message = AIMessage(
# content=full_response,
# response_metadata=last_chunk.response_metadata if hasattr(last_chunk, 'response_metadata') else {}
# )
# else:
# final_message = AIMessage(content=full_response)
# 构建完整的 AIMessage(包含元数据)
if isinstance(last_chunk, AIMessage):
final_message = AIMessage(
content=full_response,
response_metadata=last_chunk.response_metadata if hasattr(last_chunk, 'response_metadata') else {}
)
else:
final_message = AIMessage(content=full_response)
# yield 完成标记
yield {"__final__": True, "result": final_message}
# # yield 完成标记
# yield {"__final__": True, "result": final_message}

View File

@@ -385,7 +385,7 @@ class LLMRouter:
# 获取 API Key 配置
api_key_config = self.db.query(ModelApiKey).filter(
ModelApiKey.model_config_id == self.routing_model_config.id,
ModelApiKey.is_active == True
ModelApiKey.is_active
).first()
if not api_key_config:

View File

@@ -1,29 +1,27 @@
"""
工作流服务层
"""
import datetime
import json
import logging
import uuid
import datetime
from typing import Any, Annotated
from sqlalchemy.orm import Session
from fastapi import Depends
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.workflow.validator import validate_workflow_config
from app.db import get_db
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
from app.repositories.workflow_repository import (
WorkflowConfigRepository,
WorkflowExecutionRepository,
WorkflowNodeExecutionRepository,
get_workflow_config_repository,
get_workflow_execution_repository,
get_workflow_node_execution_repository
WorkflowNodeExecutionRepository
)
from app.core.workflow.validator import validate_workflow_config
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.db import get_db
from app.schemas import DraftRunRequest
from app.utils.sse_utils import format_sse_message
logger = logging.getLogger(__name__)
@@ -81,7 +79,7 @@ class WorkflowService:
if not is_valid:
logger.warning(f"工作流配置验证失败: {errors}")
raise BusinessException(
error_code=BizCode.INVALID_PARAMETER,
code=BizCode.INVALID_PARAMETER,
message=f"工作流配置无效: {'; '.join(errors)}"
)
@@ -140,7 +138,7 @@ class WorkflowService:
config = self.get_workflow_config(app_id)
if not config:
raise BusinessException(
error_code=BizCode.RESOURCE_NOT_FOUND,
code=BizCode.NOT_FOUND,
message=f"工作流配置不存在: app_id={app_id}"
)
@@ -166,7 +164,7 @@ class WorkflowService:
if not is_valid:
logger.warning(f"工作流配置验证失败: {errors}")
raise BusinessException(
error_code=BizCode.INVALID_PARAMETER,
code=BizCode.INVALID_PARAMETER,
message=f"工作流配置无效: {'; '.join(errors)}"
)
@@ -195,8 +193,7 @@ class WorkflowService:
config = self.get_workflow_config(app_id)
if not config:
return False
self.config_repo.delete(config.id)
config.is_active = False
logger.info(f"删除工作流配置成功: app_id={app_id}, config_id={config.id}")
return True
@@ -245,7 +242,7 @@ class WorkflowService:
config = self.get_workflow_config(app_id)
if not config:
raise BusinessException(
error_code=BizCode.RESOURCE_NOT_FOUND,
code=BizCode.NOT_FOUND,
message=f"工作流配置不存在: app_id={app_id}"
)
@@ -359,7 +356,7 @@ class WorkflowService:
execution = self.get_execution(execution_id)
if not execution:
raise BusinessException(
error_code=BizCode.RESOURCE_NOT_FOUND,
code=BizCode.NOT_FOUND,
message=f"执行记录不存在: execution_id={execution_id}"
)
@@ -474,11 +471,9 @@ class WorkflowService:
}
# 4. 获取工作空间 ID(从 app 获取)
from app.models import App
# 5. 执行工作流
from app.core.workflow.executor import execute_workflow, execute_workflow_stream
from app.core.workflow.executor import execute_workflow
try:
# 更新状态为运行中
@@ -595,17 +590,18 @@ class WorkflowService:
}
# 4. 获取工作空间 ID(从 app 获取)
from app.models import App
# 5. 流式执行工作流
from app.core.workflow.executor import execute_workflow, execute_workflow_stream
try:
# 更新状态为运行中
self.update_execution_status(execution.execution_id, "running")
# 发送开始事件
yield f"data: {json.dumps({'type': 'workflow_start', 'execution_id': execution.execution_id})}\n\n"
yield format_sse_message("workflow_start", {
"execution_id": execution.execution_id,
"conversation_id_uuid": str(conversation_id_uuid),
})
# 调用流式执行
async for event in self._run_workflow_stream(
@@ -621,7 +617,10 @@ class WorkflowService:
yield f"data: {json.dumps(cleaned_event)}\n\n"
# 发送完成事件
yield f"data: {json.dumps({'type': 'workflow_end', 'execution_id': execution.execution_id})}\n\n"
yield format_sse_message("workflow_end", {
"execution_id": execution.execution_id,
"conversation_id_uuid": str(conversation_id_uuid),
})
except Exception as e:
logger.error(f"工作流流式执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
@@ -660,7 +659,7 @@ class WorkflowService:
config = self.get_workflow_config(app_id)
if not config:
raise BusinessException(
error_code=BizCode.RESOURCE_NOT_FOUND,
code=BizCode.NOT_FOUND,
message=f"工作流配置不存在: app_id={app_id}"
)
@@ -687,12 +686,12 @@ class WorkflowService:
app = self.db.query(App).filter(App.id == app_id).first()
if not app:
raise BusinessException(
error_code=BizCode.RESOURCE_NOT_FOUND,
code=BizCode.NOT_FOUND,
message=f"应用不存在: app_id={app_id}"
)
# 5. 执行工作流
from app.core.workflow.executor import execute_workflow, execute_workflow_stream
from app.core.workflow.executor import execute_workflow
try:
# 更新状态为运行中
@@ -750,7 +749,7 @@ class WorkflowService:
error_message=str(e)
)
raise BusinessException(
error_code=BizCode.INTERNAL_ERROR,
code=BizCode.INTERNAL_ERROR,
message=f"工作流执行失败: {str(e)}"
)
@@ -820,26 +819,26 @@ class WorkflowService:
yield event
# 收集输出数据
if event.get("type") == "node_complete":
node_data = event.get("data", {})
node_outputs = node_data.get("node_outputs", {})
output_data.update(node_outputs)
# 处理完成事件
if event.get("type") == "workflow_complete":
self.update_execution_status(
execution_id,
"completed",
output_data=output_data
)
# 处理错误事件
if event.get("type") == "workflow_error":
self.update_execution_status(
execution_id,
"failed",
error_message=event.get("error")
)
# if event.get("type") == "node_complete":
# node_data = event.get("data", {})
# node_outputs = node_data.get("node_outputs", {})
# output_data.update(node_outputs)
#
# # 处理完成事件
# if event.get("type") == "workflow_complete":
# self.update_execution_status(
# execution_id,
# "completed",
# output_data=output_data
# )
#
# # 处理错误事件
# if event.get("type") == "workflow_error":
# self.update_execution_status(
# execution_id,
# "failed",
# error_message=event.get("error")
# )
except Exception as e:
logger.error(f"工作流流式执行失败: execution_id={execution_id}, error={e}", exc_info=True)