[modify] workflow executor support stream
This commit is contained in:
@@ -9,18 +9,15 @@ LangChain Agent 封装
|
||||
"""
|
||||
import os
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Dict, Any, List, Optional, AsyncGenerator, Sequence
|
||||
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain.agents import create_agent
|
||||
|
||||
from app.core.memory.agent.mcp_server.services import session_service
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.models.models_model import ModelType
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from app.services.task_service import get_task_memory_write_result
|
||||
from app.tasks import write_message_task
|
||||
|
||||
@@ -92,7 +92,7 @@ class WorkflowExecutor:
|
||||
|
||||
|
||||
|
||||
def build_graph(self) -> CompiledStateGraph:
|
||||
def build_graph(self,stream=False) -> CompiledStateGraph:
|
||||
"""构建 LangGraph
|
||||
|
||||
Returns:
|
||||
@@ -122,12 +122,19 @@ class WorkflowExecutor:
|
||||
if node_instance:
|
||||
# 包装节点的 run 方法
|
||||
# 使用函数工厂避免闭包问题
|
||||
def make_node_func(inst):
|
||||
async def node_func(state: WorkflowState):
|
||||
if stream:
|
||||
# 流式模式:创建 async generator 函数
|
||||
# LangGraph 会收集所有 yield 的值,最后一个 yield 的字典会被合并到 state
|
||||
async def node_func(state: WorkflowState, inst=node_instance):
|
||||
async for item in inst.run_stream(state):
|
||||
yield item
|
||||
workflow.add_node(node_id, node_func)
|
||||
else:
|
||||
# 非流式模式:创建 async function
|
||||
async def node_func(state: WorkflowState, inst=node_instance):
|
||||
return await inst.run(state)
|
||||
return node_func
|
||||
workflow.add_node(node_id, node_func)
|
||||
|
||||
workflow.add_node(node_id, make_node_func(node_instance))
|
||||
logger.debug(f"添加节点: {node_id} (type={node_type})")
|
||||
|
||||
# 3. 添加边
|
||||
@@ -276,12 +283,13 @@ class WorkflowExecutor:
|
||||
):
|
||||
"""执行工作流(流式)
|
||||
|
||||
手动执行节点以支持细粒度的流式输出:
|
||||
- workflow_start: 工作流开始
|
||||
- node_start: 节点开始执行
|
||||
- node_chunk: LLM 节点的流式输出片段(逐 token)
|
||||
- node_complete: 节点执行完成
|
||||
- workflow_complete: 工作流完成
|
||||
使用 stream_mode="updates" 来获取每个节点的 state 更新。
|
||||
节点的 generator 会 yield 多个值:
|
||||
- 中间的 chunk 事件(带 type="chunk")
|
||||
- 最后的 state 更新(纯字典,包含 node_outputs 等)
|
||||
|
||||
LangGraph 会将所有 yield 的值收集起来,并将它们合并到 state 中。
|
||||
我们需要过滤出 chunk 事件并转发,同时确保 state 更新被正确处理。
|
||||
|
||||
Args:
|
||||
input_data: 输入数据
|
||||
@@ -289,27 +297,47 @@ class WorkflowExecutor:
|
||||
Yields:
|
||||
流式事件
|
||||
"""
|
||||
#
|
||||
logger.info(f"开始执行工作流: execution_id={self.execution_id}")
|
||||
|
||||
# 记录开始时间
|
||||
start_time = datetime.datetime.now()
|
||||
|
||||
# 1. 构建图
|
||||
graph = self.build_graph()
|
||||
graph = self.build_graph(True)
|
||||
|
||||
# 2. 初始化状态(自动注入系统变量)
|
||||
initial_state = self._prepare_initial_state(input_data)
|
||||
|
||||
# 3. 执行工作流
|
||||
try:
|
||||
async for chunk in graph.astream(
|
||||
async for mode, event in graph.astream(
|
||||
initial_state,
|
||||
# subgraphs=True,
|
||||
stream_mode="updates",
|
||||
stream_mode=["updates","messages"],
|
||||
):
|
||||
# print(chunk)
|
||||
yield chunk
|
||||
# print("刚才跑的节点:", event[0])
|
||||
# # 通过图结构就能算出“接下来是谁”
|
||||
# print("接下来可能跑:", graph.get_next(event[0]))
|
||||
# print("="*50)
|
||||
# # print("mode",mode)
|
||||
# print("event",event)
|
||||
# print("="*50)
|
||||
# event 是一个字典,key 是节点 ID,value 是 state 更新或 chunk
|
||||
for node_id, update in event.items():
|
||||
print("="*50)
|
||||
print("node_id",node_id)
|
||||
print("update",update)
|
||||
|
||||
print("="*50)
|
||||
if isinstance(update, dict) and update.get("type") == "chunk":
|
||||
# 这是流式 chunk,转发给客户端
|
||||
yield {
|
||||
"type": "node_chunk",
|
||||
"node_id": update.get("node_id"),
|
||||
"chunk": update.get("content")
|
||||
}
|
||||
# 其他情况(state 更新)会被 LangGraph 自动合并到 state,不需要我们处理
|
||||
print(event)
|
||||
yield event
|
||||
|
||||
except Exception as e:
|
||||
# 计算耗时(即使失败也记录)
|
||||
|
||||
@@ -209,11 +209,15 @@ class BaseNode(ABC):
|
||||
3. 将业务数据包装成标准输出格式
|
||||
4. 错误处理
|
||||
|
||||
注意:在流式模式下,我们需要:
|
||||
- yield 中间的 chunk 事件(用于实时显示)
|
||||
- 最后 yield 一个包含 state 更新的字典(LangGraph 会合并到 state)
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
|
||||
Yields:
|
||||
标准化的流式事件
|
||||
标准化的流式事件和最终的 state 更新
|
||||
"""
|
||||
import time
|
||||
|
||||
@@ -263,27 +267,39 @@ class BaseNode(ABC):
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# 提取处理后的输出(调用子类的 _extract_output)
|
||||
extracted_output = self._extract_output(final_result)
|
||||
|
||||
# 包装最终结果
|
||||
final_output = self._wrap_output(final_result, elapsed_time, state)
|
||||
yield {
|
||||
"type": "complete",
|
||||
**final_output
|
||||
|
||||
# 将提取后的输出存储到运行时变量中(供后续节点快速访问)
|
||||
if isinstance(extracted_output, dict):
|
||||
runtime_var = extracted_output
|
||||
else:
|
||||
runtime_var = {"output": extracted_output}
|
||||
|
||||
# 构建完整的 state 更新(包含 node_outputs 和 runtime_vars)
|
||||
state_update = {
|
||||
**final_output,
|
||||
"runtime_vars": {
|
||||
self.node_id: runtime_var
|
||||
}
|
||||
}
|
||||
|
||||
# 最后 yield 纯粹的 state 更新(LangGraph 会合并到 state 中)
|
||||
yield state_update
|
||||
|
||||
except TimeoutError:
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.error(f"节点 {self.node_id} 执行超时({timeout}秒)")
|
||||
yield {
|
||||
"type": "error",
|
||||
**self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state)
|
||||
}
|
||||
error_output = self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state)
|
||||
yield error_output
|
||||
except Exception as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True)
|
||||
yield {
|
||||
"type": "error",
|
||||
**self._wrap_error(str(e), elapsed_time, state)
|
||||
}
|
||||
error_output = self._wrap_error(str(e), elapsed_time, state)
|
||||
yield error_output
|
||||
|
||||
def _wrap_output(
|
||||
self,
|
||||
|
||||
@@ -30,11 +30,11 @@ class EndNode(BaseNode):
|
||||
|
||||
# 获取配置的输出模板
|
||||
output_template = self.config.get("output")
|
||||
pool = self.get_variable_pool(state)
|
||||
# pool = self.get_variable_pool(state)
|
||||
|
||||
print("="*20)
|
||||
print( pool.get("start.test"))
|
||||
print("="*20)
|
||||
# print("="*20)
|
||||
# print( pool.get("start.test"))
|
||||
# print("="*20)
|
||||
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
|
||||
if output_template:
|
||||
output = self._render_template(output_template, state)
|
||||
|
||||
@@ -63,7 +63,7 @@ class LLMNode(BaseNode):
|
||||
- ai/assistant: AI 消息(AIMessage)
|
||||
"""
|
||||
|
||||
def _prepare_llm(self, state: WorkflowState) -> tuple[RedBearLLM, list | str]:
|
||||
def _prepare_llm(self, state: WorkflowState,stream:bool = False) -> tuple[RedBearLLM, list | str]:
|
||||
"""准备 LLM 实例(公共逻辑)
|
||||
|
||||
Args:
|
||||
@@ -125,16 +125,19 @@ class LLMNode(BaseNode):
|
||||
model_type = config.type
|
||||
|
||||
# 4. 创建 LLM 实例(使用已提取的数据)
|
||||
print("="*50)
|
||||
print("stream",stream)
|
||||
print("="*50)
|
||||
llm = RedBearLLM(
|
||||
RedBearModelConfig(
|
||||
model_name=model_name,
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base
|
||||
base_url=api_base,
|
||||
extra_params={"streaming": stream}
|
||||
),
|
||||
type=model_type
|
||||
)
|
||||
|
||||
return llm, prompt_or_messages
|
||||
|
||||
async def execute(self, state: WorkflowState) -> AIMessage:
|
||||
@@ -146,13 +149,12 @@ class LLMNode(BaseNode):
|
||||
Returns:
|
||||
LLM 响应消息
|
||||
"""
|
||||
llm, prompt_or_messages = self._prepare_llm(state)
|
||||
llm, prompt_or_messages = self._prepare_llm(state,True)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
|
||||
|
||||
# 调用 LLM(支持字符串或消息列表)
|
||||
response = await llm.ainvoke(prompt_or_messages)
|
||||
|
||||
# 提取内容
|
||||
if hasattr(response, 'content'):
|
||||
content = response.content
|
||||
@@ -199,47 +201,47 @@ class LLMNode(BaseNode):
|
||||
}
|
||||
return None
|
||||
|
||||
async def execute_stream(self, state: WorkflowState):
|
||||
"""流式执行 LLM 调用
|
||||
# async def execute_stream(self, state: WorkflowState):
|
||||
# """流式执行 LLM 调用
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
# Args:
|
||||
# state: 工作流状态
|
||||
|
||||
Yields:
|
||||
文本片段(chunk)或完成标记
|
||||
"""
|
||||
llm, prompt_or_messages = self._prepare_llm(state)
|
||||
# Yields:
|
||||
# 文本片段(chunk)或完成标记
|
||||
# """
|
||||
# llm, prompt_or_messages = self._prepare_llm(state,True)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
|
||||
# logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
|
||||
|
||||
# 累积完整响应
|
||||
full_response = ""
|
||||
last_chunk = None
|
||||
# # 累积完整响应
|
||||
# full_response = ""
|
||||
# last_chunk = None
|
||||
|
||||
# 调用 LLM(流式,支持字符串或消息列表)
|
||||
async for chunk in llm.astream(prompt_or_messages):
|
||||
# 提取内容
|
||||
if hasattr(chunk, 'content'):
|
||||
content = chunk.content
|
||||
else:
|
||||
content = str(chunk)
|
||||
# # 调用 LLM(流式,支持字符串或消息列表)
|
||||
# async for chunk in llm.astream(prompt_or_messages):
|
||||
# # 提取内容
|
||||
# if hasattr(chunk, 'content'):
|
||||
# content = chunk.content
|
||||
# else:
|
||||
# content = str(chunk)
|
||||
|
||||
full_response += content
|
||||
last_chunk = chunk
|
||||
|
||||
# 流式返回每个文本片段
|
||||
yield content
|
||||
# full_response += content
|
||||
# last_chunk = chunk
|
||||
# logger.info(f"节点 {self.node_id} LLM : {content}")
|
||||
# # 流式返回每个文本片段
|
||||
# yield content
|
||||
|
||||
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}")
|
||||
# logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}")
|
||||
|
||||
# 构建完整的 AIMessage(包含元数据)
|
||||
if isinstance(last_chunk, AIMessage):
|
||||
final_message = AIMessage(
|
||||
content=full_response,
|
||||
response_metadata=last_chunk.response_metadata if hasattr(last_chunk, 'response_metadata') else {}
|
||||
)
|
||||
else:
|
||||
final_message = AIMessage(content=full_response)
|
||||
# # 构建完整的 AIMessage(包含元数据)
|
||||
# if isinstance(last_chunk, AIMessage):
|
||||
# final_message = AIMessage(
|
||||
# content=full_response,
|
||||
# response_metadata=last_chunk.response_metadata if hasattr(last_chunk, 'response_metadata') else {}
|
||||
# )
|
||||
# else:
|
||||
# final_message = AIMessage(content=full_response)
|
||||
|
||||
# yield 完成标记
|
||||
yield {"__final__": True, "result": final_message}
|
||||
# # yield 完成标记
|
||||
# yield {"__final__": True, "result": final_message}
|
||||
|
||||
@@ -385,7 +385,7 @@ class LLMRouter:
|
||||
# 获取 API Key 配置
|
||||
api_key_config = self.db.query(ModelApiKey).filter(
|
||||
ModelApiKey.model_config_id == self.routing_model_config.id,
|
||||
ModelApiKey.is_active == True
|
||||
ModelApiKey.is_active
|
||||
).first()
|
||||
|
||||
if not api_key_config:
|
||||
|
||||
@@ -1,29 +1,27 @@
|
||||
"""
|
||||
工作流服务层
|
||||
"""
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
import datetime
|
||||
from typing import Any, Annotated
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.workflow.validator import validate_workflow_config
|
||||
from app.db import get_db
|
||||
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
|
||||
from app.repositories.workflow_repository import (
|
||||
WorkflowConfigRepository,
|
||||
WorkflowExecutionRepository,
|
||||
WorkflowNodeExecutionRepository,
|
||||
get_workflow_config_repository,
|
||||
get_workflow_execution_repository,
|
||||
get_workflow_node_execution_repository
|
||||
WorkflowNodeExecutionRepository
|
||||
)
|
||||
from app.core.workflow.validator import validate_workflow_config
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.db import get_db
|
||||
from app.schemas import DraftRunRequest
|
||||
from app.utils.sse_utils import format_sse_message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -81,7 +79,7 @@ class WorkflowService:
|
||||
if not is_valid:
|
||||
logger.warning(f"工作流配置验证失败: {errors}")
|
||||
raise BusinessException(
|
||||
error_code=BizCode.INVALID_PARAMETER,
|
||||
code=BizCode.INVALID_PARAMETER,
|
||||
message=f"工作流配置无效: {'; '.join(errors)}"
|
||||
)
|
||||
|
||||
@@ -140,7 +138,7 @@ class WorkflowService:
|
||||
config = self.get_workflow_config(app_id)
|
||||
if not config:
|
||||
raise BusinessException(
|
||||
error_code=BizCode.RESOURCE_NOT_FOUND,
|
||||
code=BizCode.NOT_FOUND,
|
||||
message=f"工作流配置不存在: app_id={app_id}"
|
||||
)
|
||||
|
||||
@@ -166,7 +164,7 @@ class WorkflowService:
|
||||
if not is_valid:
|
||||
logger.warning(f"工作流配置验证失败: {errors}")
|
||||
raise BusinessException(
|
||||
error_code=BizCode.INVALID_PARAMETER,
|
||||
code=BizCode.INVALID_PARAMETER,
|
||||
message=f"工作流配置无效: {'; '.join(errors)}"
|
||||
)
|
||||
|
||||
@@ -195,8 +193,7 @@ class WorkflowService:
|
||||
config = self.get_workflow_config(app_id)
|
||||
if not config:
|
||||
return False
|
||||
|
||||
self.config_repo.delete(config.id)
|
||||
config.is_active = False
|
||||
logger.info(f"删除工作流配置成功: app_id={app_id}, config_id={config.id}")
|
||||
return True
|
||||
|
||||
@@ -245,7 +242,7 @@ class WorkflowService:
|
||||
config = self.get_workflow_config(app_id)
|
||||
if not config:
|
||||
raise BusinessException(
|
||||
error_code=BizCode.RESOURCE_NOT_FOUND,
|
||||
code=BizCode.NOT_FOUND,
|
||||
message=f"工作流配置不存在: app_id={app_id}"
|
||||
)
|
||||
|
||||
@@ -359,7 +356,7 @@ class WorkflowService:
|
||||
execution = self.get_execution(execution_id)
|
||||
if not execution:
|
||||
raise BusinessException(
|
||||
error_code=BizCode.RESOURCE_NOT_FOUND,
|
||||
code=BizCode.NOT_FOUND,
|
||||
message=f"执行记录不存在: execution_id={execution_id}"
|
||||
)
|
||||
|
||||
@@ -474,11 +471,9 @@ class WorkflowService:
|
||||
}
|
||||
|
||||
# 4. 获取工作空间 ID(从 app 获取)
|
||||
from app.models import App
|
||||
|
||||
|
||||
# 5. 执行工作流
|
||||
from app.core.workflow.executor import execute_workflow, execute_workflow_stream
|
||||
from app.core.workflow.executor import execute_workflow
|
||||
|
||||
try:
|
||||
# 更新状态为运行中
|
||||
@@ -595,17 +590,18 @@ class WorkflowService:
|
||||
}
|
||||
|
||||
# 4. 获取工作空间 ID(从 app 获取)
|
||||
from app.models import App
|
||||
|
||||
# 5. 流式执行工作流
|
||||
from app.core.workflow.executor import execute_workflow, execute_workflow_stream
|
||||
|
||||
try:
|
||||
# 更新状态为运行中
|
||||
self.update_execution_status(execution.execution_id, "running")
|
||||
|
||||
# 发送开始事件
|
||||
yield f"data: {json.dumps({'type': 'workflow_start', 'execution_id': execution.execution_id})}\n\n"
|
||||
yield format_sse_message("workflow_start", {
|
||||
"execution_id": execution.execution_id,
|
||||
"conversation_id_uuid": str(conversation_id_uuid),
|
||||
})
|
||||
|
||||
# 调用流式执行
|
||||
async for event in self._run_workflow_stream(
|
||||
@@ -621,7 +617,10 @@ class WorkflowService:
|
||||
yield f"data: {json.dumps(cleaned_event)}\n\n"
|
||||
|
||||
# 发送完成事件
|
||||
yield f"data: {json.dumps({'type': 'workflow_end', 'execution_id': execution.execution_id})}\n\n"
|
||||
yield format_sse_message("workflow_end", {
|
||||
"execution_id": execution.execution_id,
|
||||
"conversation_id_uuid": str(conversation_id_uuid),
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工作流流式执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
|
||||
@@ -660,7 +659,7 @@ class WorkflowService:
|
||||
config = self.get_workflow_config(app_id)
|
||||
if not config:
|
||||
raise BusinessException(
|
||||
error_code=BizCode.RESOURCE_NOT_FOUND,
|
||||
code=BizCode.NOT_FOUND,
|
||||
message=f"工作流配置不存在: app_id={app_id}"
|
||||
)
|
||||
|
||||
@@ -687,12 +686,12 @@ class WorkflowService:
|
||||
app = self.db.query(App).filter(App.id == app_id).first()
|
||||
if not app:
|
||||
raise BusinessException(
|
||||
error_code=BizCode.RESOURCE_NOT_FOUND,
|
||||
code=BizCode.NOT_FOUND,
|
||||
message=f"应用不存在: app_id={app_id}"
|
||||
)
|
||||
|
||||
# 5. 执行工作流
|
||||
from app.core.workflow.executor import execute_workflow, execute_workflow_stream
|
||||
from app.core.workflow.executor import execute_workflow
|
||||
|
||||
try:
|
||||
# 更新状态为运行中
|
||||
@@ -750,7 +749,7 @@ class WorkflowService:
|
||||
error_message=str(e)
|
||||
)
|
||||
raise BusinessException(
|
||||
error_code=BizCode.INTERNAL_ERROR,
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
message=f"工作流执行失败: {str(e)}"
|
||||
)
|
||||
|
||||
@@ -820,26 +819,26 @@ class WorkflowService:
|
||||
yield event
|
||||
|
||||
# 收集输出数据
|
||||
if event.get("type") == "node_complete":
|
||||
node_data = event.get("data", {})
|
||||
node_outputs = node_data.get("node_outputs", {})
|
||||
output_data.update(node_outputs)
|
||||
|
||||
# 处理完成事件
|
||||
if event.get("type") == "workflow_complete":
|
||||
self.update_execution_status(
|
||||
execution_id,
|
||||
"completed",
|
||||
output_data=output_data
|
||||
)
|
||||
|
||||
# 处理错误事件
|
||||
if event.get("type") == "workflow_error":
|
||||
self.update_execution_status(
|
||||
execution_id,
|
||||
"failed",
|
||||
error_message=event.get("error")
|
||||
)
|
||||
# if event.get("type") == "node_complete":
|
||||
# node_data = event.get("data", {})
|
||||
# node_outputs = node_data.get("node_outputs", {})
|
||||
# output_data.update(node_outputs)
|
||||
#
|
||||
# # 处理完成事件
|
||||
# if event.get("type") == "workflow_complete":
|
||||
# self.update_execution_status(
|
||||
# execution_id,
|
||||
# "completed",
|
||||
# output_data=output_data
|
||||
# )
|
||||
#
|
||||
# # 处理错误事件
|
||||
# if event.get("type") == "workflow_error":
|
||||
# self.update_execution_status(
|
||||
# execution_id,
|
||||
# "failed",
|
||||
# error_message=event.get("error")
|
||||
# )
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工作流流式执行失败: execution_id={execution_id}, error={e}", exc_info=True)
|
||||
|
||||
Reference in New Issue
Block a user