Merge branch 'develop-mark' into develop

# Conflicts:
#	api/app/core/workflow/executor.py
#	api/app/services/workflow_service.py
This commit is contained in:
Mark
2025-12-20 17:51:49 +08:00
10 changed files with 915 additions and 223 deletions

View File

@@ -583,15 +583,27 @@ async def draft_run(
)
async def event_generator():
"""工作流事件生成器"""
# 调用多智能体服务的流式方法
"""工作流事件生成器
将事件转换为标准 SSE 格式:
event: <event_type>
data: <json_data>
"""
import json
# 调用工作流服务的流式方法
async for event in workflow_service.run_stream(
app_id=app_id,
payload=payload,
config=config
):
yield event
# 提取事件类型和数据
event_type = event.get("event", "message")
event_data = event.get("data", {})
# 转换为标准 SSE 格式(字符串)
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
yield sse_message
return StreamingResponse(
event_generator(),

View File

@@ -471,7 +471,20 @@ async def run_workflow(
import json
async def event_generator():
"""生成 SSE 事件"""
"""生成 SSE 事件
SSE 格式:
event: <event_type>
data: <json_data>
支持的事件类型:
- workflow_start: 工作流开始
- workflow_end: 工作流结束
- node_start: 节点开始执行
- node_end: 节点执行完成
- node_chunk: 中间节点的流式输出
- message: 最终消息的流式输出End 节点及其相邻节点)
"""
try:
async for event in await service.run_workflow(
app_id=app_id,
@@ -480,19 +493,30 @@ async def run_workflow(
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
stream=True
):
# 转换为 SSE 格式
yield f"data: {json.dumps(event)}\n\n"
# 提取事件类型和数据
event_type = event.get("event", "message")
event_data = event.get("data", {})
# 转换为标准 SSE 格式(字符串)
# event: <type>
# data: <json>
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
yield sse_message
except Exception as e:
logger.error(f"流式执行异常: {e}", exc_info=True)
error_event = {
"type": "error",
"error": str(e)
}
yield f"data: {json.dumps(error_event)}\n\n"
# 发送错误事件
sse_error = f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"
yield sse_error
return StreamingResponse(
event_generator(),
media_type="text/event-stream"
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no" # 禁用 nginx 缓冲
}
)
else:
# 非流式执行

View File

@@ -9,18 +9,15 @@ LangChain Agent 封装
"""
import os
import time
import asyncio
from typing import Dict, Any, List, Optional, AsyncGenerator, Sequence
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage
from langchain_core.tools import BaseTool
from langchain.agents import create_agent
from app.core.memory.agent.mcp_server.services import session_service
from app.core.memory.agent.utils.redis_tool import store
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.models_model import ModelType
from app.core.logging_config import get_business_logger
from app.services.memory_agent_service import MemoryAgentService
from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task

View File

@@ -1,8 +1,8 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional
from typing import Any, Iterator, AsyncIterator, List, Optional
from langchain_core.callbacks import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun
from langchain_core.language_models import BaseLLM
from langchain_core.outputs import LLMResult
from langchain_core.outputs import LLMResult, GenerationChunk
from app.core.models import RedBearModelConfig, RedBearModelFactory, get_provider_llm_class
from app.models.models_model import ModelType
@@ -10,21 +10,36 @@ from app.models.models_model import ModelType
class RedBearLLM(BaseLLM):
"""
RedBear LLM 模型包装器 - 完全动态代理实现
RedBear LLM Model Wrapper
这个包装器自动将所有方法调用委托给内部模型,
同时提供优雅的回退机制和错误处理。
This wrapper provides a unified interface to access different LLM providers,
while maintaining all LangChain functionality, including streaming output.
Features:
- Support for multiple LLM providers (OpenAI, Qwen, Ollama, etc.)
- Full streaming output support
- Elegant error handling and fallback mechanism
- Automatic proxying of all underlying model methods and attributes
"""
def __init__(self, config: RedBearModelConfig, type: ModelType=ModelType.LLM):
self._model = self._create_model(config, type)
def __init__(self, config: RedBearModelConfig, type: ModelType = ModelType.LLM):
"""Initialize RedBear LLM wrapper
Args:
config: Model configuration
type: Model type (LLM or CHAT)
"""
super().__init__()
self._config = config
self._model = self._create_model(config, type)
@property
def _llm_type(self) -> str:
"""返回LLM类型标识符"""
return self._model._llm_type
"""Return LLM type identifier"""
return getattr(self._model, '_llm_type', 'redbear_llm')
# ==================== Core Methods (Required by BaseLLM) ====================
def _generate(
self,
prompts: List[str],
@@ -32,7 +47,7 @@ class RedBearLLM(BaseLLM):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any
) -> LLMResult:
"""同步生成文本"""
"""Synchronous text generation (required by BaseLLM)"""
return self._model._generate(prompts, stop=stop, run_manager=run_manager, **kwargs)
async def _agenerate(
@@ -42,92 +57,233 @@ class RedBearLLM(BaseLLM):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any
) -> LLMResult:
"""异步生成文本"""
"""Asynchronous text generation (required by BaseLLM)"""
return await self._model._agenerate(prompts, stop=stop, run_manager=run_manager, **kwargs)
# 关键:覆盖 invoke/ainvoke直接委托到底层模型避免 BaseLLM 的字符串化行为
# ==================== Advanced Methods (Support Message Lists) ====================
def invoke(self, input: Any, config: Optional[dict] = None, **kwargs: Any) -> Any:
"""直接调用底层模型以支持 ChatPrompt 和消息列表。"""
"""Synchronous model invocation
Supports various input formats including strings and message lists.
Directly delegates to the underlying model to avoid BaseLLM's string conversion.
Args:
input: Input (string, message list, etc.)
config: Runtime configuration
**kwargs: Additional arguments
Returns:
Model response
"""
try:
return self._model.invoke(input, config=config, **kwargs)
except AttributeError as e:
# 只在属性错误时回退(说明底层模型不支持该方法)
if 'invoke' in str(e):
# Underlying model doesn't support invoke, fallback to parent implementation
return super().invoke(input, config=config, **kwargs)
# 其他 AttributeError 直接抛出
raise
except Exception:
# 其他所有异常(包括 ValidationException直接抛出不回退
# Other exceptions are raised directly
raise
async def ainvoke(self, input: Any, config: Optional[dict] = None, **kwargs: Any) -> Any:
"""异步直接调用底层模型以支持 ChatPrompt 和消息列表。"""
"""Asynchronous model invocation
Supports various input formats including strings and message lists.
Directly delegates to the underlying model to avoid BaseLLM's string conversion.
Args:
input: Input (string, message list, etc.)
config: Runtime configuration
**kwargs: Additional arguments
Returns:
Model response
"""
try:
return await self._model.ainvoke(input, config=config, **kwargs)
except AttributeError as e:
# 只在属性错误时回退(说明底层模型不支持该方法)
if 'ainvoke' in str(e):
# Underlying model doesn't support ainvoke, fallback to parent implementation
return await super().ainvoke(input, config=config, **kwargs)
# 其他 AttributeError 直接抛出
raise
except Exception:
# 其他所有异常(包括 ValidationException直接抛出不回退
# Other exceptions are raised directly
raise
def __getattr__(self, name):
"""
动态代理:将所有未定义的属性和方法调用委托给内部模型
# ==================== Streaming Methods (Critical) ====================
def stream(
self,
input: Any,
config: Optional[dict] = None,
*,
stop: Optional[List[str]] = None,
**kwargs: Any
) -> Iterator[GenerationChunk]:
"""Synchronous streaming model invocation
这是最优雅的包装器实现方式,完全避免了方法重复定义
"""
# 处理特殊属性以避免递归
if name in ('__isabstractmethod__', '__dict__', '__class__'):
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
Args:
input: Input (string, message list, etc.)
config: Runtime configuration
stop: List of stop words
**kwargs: Additional arguments
# 检查内部模型是否有该属性(使用安全的方式避免递归)
Yields:
GenerationChunk: Generated text chunks
"""
try:
yield from self._model.stream(input, config=config, stop=stop, **kwargs)
except AttributeError as e:
if 'stream' in str(e):
# Underlying model doesn't support stream, fallback to parent implementation
yield from super().stream(input, config=config, stop=stop, **kwargs)
else:
raise
except Exception:
raise
async def astream(
self,
input: Any,
config: Optional[dict] = None,
*,
stop: Optional[List[str]] = None,
**kwargs: Any
) -> AsyncIterator[GenerationChunk]:
"""Asynchronous streaming model invocation
This is the core method for streaming output. It directly proxies to the
underlying model's astream method, maintaining generator characteristics
to ensure each chunk is delivered in real-time.
Args:
input: Input (string, message list, etc.)
config: Runtime configuration
stop: List of stop words
**kwargs: Additional arguments
Yields:
GenerationChunk: Generated text chunks
"""
try:
async for chunk in self._model.astream(input, config=config, stop=stop, **kwargs):
yield chunk
except AttributeError as e:
if 'astream' in str(e):
# Underlying model doesn't support astream, fallback to parent implementation
async for chunk in super().astream(input, config=config, stop=stop, **kwargs):
yield chunk
else:
raise
except Exception:
raise
# ==================== Dynamic Proxy ====================
def __getattr__(self, name: str) -> Any:
"""Dynamic proxy: delegate undefined attributes and method calls to internal model
This method allows RedBearLLM to transparently access all attributes and methods
of the underlying model without explicitly defining each one.
Args:
name: Attribute or method name
Returns:
Attribute value or method
Raises:
AttributeError: If attribute doesn't exist
"""
# Avoid recursion: raise error directly for special attributes
if name in ('__isabstractmethod__', '__dict__', '__class__', '_model', '_config'):
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
# Try to get attribute from internal model
try:
# 使用 object.__getattribute__ 来安全地检查内部模型的属性
attr = object.__getattribute__(self._model, name)
# 如果是方法,返回一个包装器来处理调用
# If it's callable (a method)
if callable(attr):
# 流式方法直接返回,不包装(保持生成器特性)
if name in ('_stream', '_astream', 'stream', 'astream'):
# Streaming methods are returned directly to maintain generator characteristics
# Note: Although we've explicitly implemented stream/astream,
# this is kept to handle internal methods like _stream/_astream
if name in ('_stream', '_astream'):
return attr
# 非流式方法使用包装器处理异常
# Wrap other methods for easier debugging and error handling
def method_wrapper(*args, **kwargs):
return attr(*args, **kwargs)
try:
return attr(*args, **kwargs)
except Exception:
# Can add logging or error handling here
raise
# 保持方法的元信息
# Preserve method metadata
method_wrapper.__name__ = name
method_wrapper.__doc__ = getattr(attr, '__doc__', f"Delegated method: {name}")
return method_wrapper
# 如果是普通属性,直接返回
# If it's a regular attribute, return directly
return attr
except AttributeError:
# 内部模型没有该属性,尝试回退实现
# Internal model doesn't have this attribute either
pass
# 检查是否有回退方法(使用安全的方式避免递归)
# Check if there's a fallback method
fallback_name = f'_fallback_{name}'
try:
fallback_method = object.__getattribute__(self, fallback_name)
return fallback_method
return object.__getattribute__(self, fallback_name)
except AttributeError:
# 没有回退方法,抛出适当的错误
pass
# 如果都没有,抛出适当的错误
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
# Nothing found, raise error
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{name}'. "
f"The underlying model '{type(self._model).__name__}' also doesn't have this attribute."
)
# ==================== Helper Methods ====================
def _create_model(self, config: RedBearModelConfig, type: ModelType) -> BaseLLM:
"""创建内部模型实例"""
"""Create internal model instance
Args:
config: Model configuration
type: Model type
Returns:
Created model instance
"""
llm_class = get_provider_llm_class(config, type)
model_params = RedBearModelFactory.get_model_params(config)
return llm_class(**model_params)
def get_config(self) -> RedBearModelConfig:
"""Get model configuration
Returns:
Model configuration object
"""
return self._config
def get_underlying_model(self) -> BaseLLM:
"""Get underlying model instance
Returns:
Underlying model instance
"""
return self._model
def __repr__(self) -> str:
"""Return string representation of the object"""
return (
f"RedBearLLM("
f"provider={self._config.provider}, "
f"model={self._config.model_name}, "
f"type={type(self._model).__name__}"
f")"
)

View File

@@ -94,16 +94,90 @@ class WorkflowExecutor:
"workspace_id": self.workspace_id,
"user_id": self.user_id,
"error": None,
"error_node": None
"error_node": None,
"streaming_buffer": {} # 流式缓冲区
}
def build_graph(self) -> CompiledStateGraph:
def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]:
"""分析 End 节点的前缀配置
检查每个 End 节点的模板,找到直接上游节点的引用,
提取该引用之前的前缀部分。
Returns:
元组:({上游节点ID: End节点前缀}, {与End相邻且被引用的节点ID集合})
"""
import re
prefixes = {}
adjacent_and_referenced = set() # 记录与 End 节点相邻且被引用的节点
# 找到所有 End 节点
end_nodes = [node for node in self.nodes if node.get("type") == "end"]
logger.info(f"[前缀分析] 找到 {len(end_nodes)} 个 End 节点")
for end_node in end_nodes:
end_node_id = end_node.get("id")
output_template = end_node.get("config", {}).get("output")
logger.info(f"[前缀分析] End 节点 {end_node_id} 模板: {output_template}")
if not output_template:
continue
# 找到所有直接连接到 End 节点的上游节点
direct_upstream_nodes = []
for edge in self.edges:
if edge.get("target") == end_node_id:
source_node_id = edge.get("source")
direct_upstream_nodes.append(source_node_id)
logger.info(f"[前缀分析] End 节点的直接上游节点: {direct_upstream_nodes}")
# 查找模板中引用了哪些节点
# 匹配 {{node_id.xxx}} 或 {{ node_id.xxx }} 格式(支持空格)
pattern = r'\{\{\s*([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\s*\}\}'
matches = list(re.finditer(pattern, output_template))
logger.info(f"[前缀分析] 模板中找到 {len(matches)} 个节点引用")
# 找到第一个直接上游节点的引用
for match in matches:
referenced_node_id = match.group(1)
logger.info(f"[前缀分析] 检查引用: {referenced_node_id}")
if referenced_node_id in direct_upstream_nodes:
# 这是直接上游节点的引用,提取前缀
prefix = output_template[:match.start()]
logger.info(f"[前缀分析] ✅ 找到直接上游节点 {referenced_node_id} 的引用,前缀: '{prefix}'")
# 标记这个节点为"相邻且被引用"
adjacent_and_referenced.add(referenced_node_id)
if prefix:
prefixes[referenced_node_id] = prefix
logger.info(f"✅ [前缀分析] 为节点 {referenced_node_id} 配置前缀: '{prefix[:50]}...'")
# 只处理第一个直接上游节点的引用
break
logger.info(f"[前缀分析] 最终配置: {prefixes}")
logger.info(f"[前缀分析] 与 End 相邻且被引用的节点: {adjacent_and_referenced}")
return prefixes, adjacent_and_referenced
def build_graph(self,stream=False) -> CompiledStateGraph:
"""构建 LangGraph
Returns:
编译后的状态图
"""
logger.info(f"开始构建工作流图: execution_id={self.execution_id}")
# 分析 End 节点的前缀配置和相邻且被引用的节点
end_prefixes, adjacent_and_referenced = self._analyze_end_node_prefixes() if stream else ({}, set())
# 1. 创建状态图
workflow = StateGraph(WorkflowState)
@@ -143,16 +217,39 @@ class WorkflowExecutor:
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
if node_instance:
# 如果是流式模式,且节点有 End 前缀配置,注入配置
if stream and node_id in end_prefixes:
# 将 End 前缀配置注入到节点实例
node_instance._end_node_prefix = end_prefixes[node_id]
logger.info(f"为节点 {node_id} 注入 End 前缀配置")
# 如果是流式模式,标记节点是否与 End 相邻且被引用
if stream:
node_instance._is_adjacent_to_end = node_id in adjacent_and_referenced
if node_id in adjacent_and_referenced:
logger.info(f"节点 {node_id} 标记为与 End 相邻且被引用")
# 包装节点的 run 方法
# 使用函数工厂避免闭包问题
def make_node_func(inst):
async def node_func(state: WorkflowState):
return await inst.run(state)
if stream:
# 流式模式:创建 async generator 函数
# LangGraph 会收集所有 yield 的值,最后一个 yield 的字典会被合并到 state
def make_stream_func(inst):
async def node_func(state: WorkflowState):
# logger.debug(f"流式执行节点: {inst.node_id}, 支持流式: {inst.supports_streaming()}")
async for item in inst.run_stream(state):
yield item
return node_func
workflow.add_node(node_id, make_stream_func(node_instance))
else:
# 非流式模式:创建 async function
def make_func(inst):
async def node_func(state: WorkflowState):
return await inst.run(state)
return node_func
workflow.add_node(node_id, make_func(node_instance))
return node_func
workflow.add_node(node_id, make_node_func(node_instance))
logger.debug(f"添加节点: {node_id} (type={node_type})")
logger.debug(f"添加节点: {node_id} (type={node_type}, stream={stream})")
# 3. 添加边
# 从 START 连接到 start 节点
@@ -300,40 +397,143 @@ class WorkflowExecutor:
):
"""执行工作流(流式)
手动执行节点以支持细粒度的流式输出
- workflow_start: 工作流开始
- node_start: 节点开始执行
- node_chunk: LLM 节点的流式输出片段(逐 token
- node_complete: 节点执行完成
- workflow_complete: 工作流完成
使用多个 stream_mode 来获取
1. "updates" - 节点的 state 更新和流式 chunk
2. "debug" - 节点执行的详细信息(开始/完成时间)
3. "custom" - 自定义流式数据chunks
Args:
input_data: 输入数据
Yields:
流式事件
流式事件,格式:
{
"event": "workflow_start" | "workflow_end" | "node_start" | "node_end" | "node_chunk" | "message",
"data": {...}
}
"""
#
logger.info(f"开始执行工作流: execution_id={self.execution_id}")
logger.info(f"开始执行工作流(流式): execution_id={self.execution_id}")
# 记录开始时间
start_time = datetime.datetime.now()
# 发送 workflow_start 事件
yield {
"event": "workflow_start",
"data": {
"execution_id": self.execution_id,
"workspace_id": self.workspace_id,
"timestamp": start_time.isoformat()
}
}
# 1. 构建图
graph = self.build_graph()
graph = self.build_graph(True)
# 2. 初始化状态(自动注入系统变量)
initial_state = self._prepare_initial_state(input_data)
# 3. 执行工作流
# 3. Execute workflow
try:
async for chunk in graph.astream(
chunk_count = 0
final_state = None
async for event in graph.astream(
initial_state,
# subgraphs=True,
stream_mode="updates",
stream_mode=["updates", "debug", "custom"], # Use updates + debug + custom mode
):
# print(chunk)
yield chunk
# event should be a tuple: (mode, data)
# But let's handle both cases
if isinstance(event, tuple) and len(event) == 2:
mode, data = event
else:
# Unexpected format, log and skip
logger.warning(f"[STREAM] Unexpected event format: {type(event)}, value: {event}")
continue
if mode == "custom":
# Handle custom streaming events (chunks from nodes via stream writer)
chunk_count += 1
event_type = data.get("type", "node_chunk") # "message" or "node_chunk"
logger.info(f"[CUSTOM] ✅ 收到 {event_type} #{chunk_count} from {data.get('node_id')}")
yield {
"event": event_type, # "message" or "node_chunk"
"data": {
"node_id": data.get("node_id"),
"chunk": data.get("chunk"),
"full_content": data.get("full_content"),
"chunk_index": data.get("chunk_index"),
"is_prefix": data.get("is_prefix"),
"is_suffix": data.get("is_suffix")
}
}
elif mode == "debug":
# Handle debug information (node execution status)
event_type = data.get("type")
payload = data.get("payload", {})
node_name = payload.get("name")
if event_type == "task":
# Node starts execution
inputv = payload.get("input", {})
variables = inputv.get("variables", {})
variables_sys = variables.get("sys", {})
conversation_id = variables_sys.get("conversation_id")
execution_id = variables_sys.get("execution_id")
logger.info(f"[DEBUG] Node starts execution: {node_name}")
yield {
"event": "node_start",
"data": {
"node_id": node_name,
"conversation_id": conversation_id,
"execution_id": execution_id,
"timestamp": data.get("timestamp")
}
}
elif event_type == "task_result":
# Node execution completed
result = payload.get("result", {})
inputv = result.get("input", {})
variables = inputv.get("variables", {})
variables_sys = variables.get("sys", {})
conversation_id = variables_sys.get("conversation_id")
execution_id = variables_sys.get("execution_id")
logger.info(f"[DEBUG] Node execution completed: {node_name}")
yield {
"event": "node_end",
"data": {
"node_id": node_name,
"conversation_id": conversation_id,
"execution_id": execution_id,
"timestamp": data.get("timestamp")
}
}
elif mode == "updates":
# Handle state updates - store final state
logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())}")
final_state = data
# 计算耗时
end_time = datetime.datetime.now()
elapsed_time = (end_time - start_time).total_seconds()
logger.info(f"Workflow execution completed (streaming), total chunks: {chunk_count}, elapsed: {elapsed_time:.2f}s")
# 发送 workflow_end 事件
yield {
"event": "workflow_end",
"data": {
"execution_id": self.execution_id,
"status": "completed",
"elapsed_time": elapsed_time,
"timestamp": end_time.isoformat()
}
}
except Exception as e:
# 计算耗时(即使失败也记录)
@@ -341,13 +541,17 @@ class WorkflowExecutor:
elapsed_time = (end_time - start_time).total_seconds()
logger.error(f"工作流执行失败: execution_id={self.execution_id}, error={e}", exc_info=True)
# 发送 workflow_end 事件(失败)
yield {
"status": "failed",
"error": str(e),
"output": None,
"node_outputs": {},
"elapsed_time": elapsed_time,
"token_usage": None
"event": "workflow_end",
"data": {
"execution_id": self.execution_id,
"status": "failed",
"error": str(e),
"elapsed_time": elapsed_time,
"timestamp": end_time.isoformat()
}
}
def _extract_final_output(self, node_outputs: dict[str, Any]) -> str | None:

View File

@@ -10,6 +10,7 @@ from abc import ABC, abstractmethod
from typing import Any, TypedDict, Annotated
from operator import add
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage
from langgraph.config import get_stream_writer
from app.core.workflow.variable_pool import VariablePool
@@ -43,6 +44,10 @@ class WorkflowState(TypedDict):
# 错误信息(用于错误边)
error: str | None
error_node: str | None
# 流式缓冲区(存储节点的实时流式输出)
# 格式:{node_id: {"chunks": [...], "full_content": "..."}}
streaming_buffer: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
class BaseNode(ABC):
@@ -201,19 +206,25 @@ class BaseNode(ABC):
return self._wrap_error(str(e), elapsed_time, state)
async def run_stream(self, state: WorkflowState):
"""执行节点(带错误处理和输出包装,流式)
"""Execute node with error handling and output wrapping (streaming)
这个方法由 Executor 调用,负责:
1. 时间统计
2. 调用节点的 execute_stream() 方法
3. 将业务数据包装成标准输出格式
4. 错误处理
This method is called by the Executor and is responsible for:
1. Time tracking
2. Calling the node's execute_stream() method
3. Using LangGraph's stream writer to send chunks
4. Updating streaming buffer in state for downstream nodes
5. Wrapping business data into standard output format
6. Error handling
Special handling for End nodes:
- End nodes don't send chunks via writer (prefix and LLM content already sent)
- End nodes only yield suffix for final result assembly
Args:
state: 工作流状态
state: Workflow state
Yields:
标准化的流式事件
State updates with streaming buffer and final result
"""
import time
@@ -222,68 +233,143 @@ class BaseNode(ABC):
try:
timeout = self.get_timeout()
# 累积完整结果(用于最后的包装)
# Get LangGraph's stream writer for sending custom data
writer = get_stream_writer()
# Check if this is an End node
# End nodes CAN send chunks (for suffix), but only after LLM content
is_end_node = self.node_type == "end"
# Check if this node is adjacent to End node (for message type)
is_adjacent_to_end = getattr(self, '_is_adjacent_to_end', False)
# Determine chunk type: "message" for End and adjacent nodes, "node_chunk" for others
chunk_type = "message" if (is_end_node or is_adjacent_to_end) else "node_chunk"
logger.debug(f"节点 {self.node_id} chunk 类型: {chunk_type} (is_end={is_end_node}, adjacent={is_adjacent_to_end})")
# Accumulate complete result (for final wrapping)
chunks = []
final_result = None
chunk_count = 0
# 使用异步生成器包装,支持超时
async def stream_with_timeout():
nonlocal final_result
loop_start = asyncio.get_event_loop().time()
# Stream chunks in real-time
loop_start = asyncio.get_event_loop().time()
async for item in self.execute_stream(state):
# Check timeout
if asyncio.get_event_loop().time() - loop_start > timeout:
raise TimeoutError()
async for item in self.execute_stream(state):
# 检查超时
if asyncio.get_event_loop().time() - loop_start > timeout:
raise TimeoutError()
# Check if it's a completion marker
if isinstance(item, dict) and item.get("__final__"):
final_result = item["result"]
elif isinstance(item, str):
# String is a chunk
chunk_count += 1
chunks.append(item)
full_content = "".join(chunks)
# 检查是否是完成标记
if isinstance(item, dict) and item.get("__final__"):
final_result = item["result"]
elif isinstance(item, str):
# 字符串是 chunk
chunks.append(item)
# Send chunks for all nodes (including End nodes for suffix)
logger.debug(f"节点 {self.node_id} 发送 chunk #{chunk_count}: {item[:50]}...")
# 1. Send via stream writer (for real-time client updates)
writer({
"type": chunk_type, # "message" or "node_chunk"
"node_id": self.node_id,
"chunk": item,
"full_content": full_content,
"chunk_index": chunk_count
})
# 2. Update streaming buffer in state (for downstream nodes)
# Only non-End nodes need streaming buffer
if not is_end_node:
yield {
"type": "chunk",
"node_id": self.node_id,
"content": item,
"full_content": "".join(chunks)
"streaming_buffer": {
self.node_id: {
"full_content": full_content,
"chunk_count": chunk_count,
"is_complete": False
}
}
}
else:
# 其他类型也当作 chunk 处理
chunks.append(str(item))
else:
# Other types are also treated as chunks
chunk_count += 1
chunk_str = str(item)
chunks.append(chunk_str)
full_content = "".join(chunks)
# Send chunks for all nodes
writer({
"type": chunk_type, # "message" or "node_chunk"
"node_id": self.node_id,
"chunk": chunk_str,
"full_content": full_content,
"chunk_index": chunk_count
})
# Only non-End nodes need streaming buffer
if not is_end_node:
yield {
"type": "chunk",
"node_id": self.node_id,
"content": str(item),
"full_content": "".join(chunks)
"streaming_buffer": {
self.node_id: {
"full_content": full_content,
"chunk_count": chunk_count,
"is_complete": False
}
}
}
async for chunk_event in stream_with_timeout():
yield chunk_event
elapsed_time = time.time() - start_time
# 包装最终结果
logger.info(f"节点 {self.node_id} 流式执行完成,耗时: {elapsed_time:.2f}s, chunks: {chunk_count}")
# Extract processed output (call subclass's _extract_output)
extracted_output = self._extract_output(final_result)
# Wrap final result
final_output = self._wrap_output(final_result, elapsed_time, state)
yield {
"type": "complete",
**final_output
# Store extracted output in runtime variables (for quick access by subsequent nodes)
if isinstance(extracted_output, dict):
runtime_var = extracted_output
else:
runtime_var = {"output": extracted_output}
# Build complete state update (including node_outputs, runtime_vars, and final streaming buffer)
state_update = {
**final_output,
"runtime_vars": {
self.node_id: runtime_var
}
}
# Add streaming buffer for non-End nodes
if not is_end_node:
state_update["streaming_buffer"] = {
self.node_id: {
"full_content": "".join(chunks),
"chunk_count": chunk_count,
"is_complete": True # Mark as complete
}
}
# Finally yield state update
# LangGraph will merge this into state
yield state_update
except TimeoutError:
elapsed_time = time.time() - start_time
logger.error(f"节点 {self.node_id} 执行超时{timeout}秒)")
yield {
"type": "error",
**self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state)
}
logger.error(f"节点 {self.node_id} 执行超时 ({timeout}s)")
error_output = self._wrap_error(f"节点执行超时 ({timeout}s)", elapsed_time, state)
yield error_output
except Exception as e:
elapsed_time = time.time() - start_time
logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True)
yield {
"type": "error",
**self._wrap_error(str(e), elapsed_time, state)
}
error_output = self._wrap_error(str(e), elapsed_time, state)
yield error_output
def _wrap_output(
self,

View File

@@ -5,6 +5,8 @@ End 节点实现
"""
import logging
import re
import asyncio
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
@@ -15,6 +17,7 @@ class EndNode(BaseNode):
"""End 节点
工作流的结束节点,根据配置的模板输出最终结果。
支持实时流式输出:如果模板引用了上游节点的输出,会实时监听其流式缓冲区。
"""
async def execute(self, state: WorkflowState) -> str:
@@ -30,11 +33,7 @@ class EndNode(BaseNode):
# 获取配置的输出模板
output_template = self.config.get("output")
pool = self.get_variable_pool(state)
print("="*20)
print( pool.get("start.test"))
print("="*20)
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
if output_template:
output = self._render_template(output_template, state)
@@ -46,7 +45,213 @@ class EndNode(BaseNode):
total_nodes = len(node_outputs)
logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
print("="*20)
print(output)
print("="*20)
return output
def _extract_referenced_nodes(self, template: str) -> list[str]:
"""从模板中提取引用的节点 ID
例如:'结果:{{llm_qa.output}}' -> ['llm_qa']
Args:
template: 模板字符串
Returns:
引用的节点 ID 列表
"""
# 匹配 {{node_id.xxx}} 格式
pattern = r'\{\{([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\}\}'
matches = re.findall(pattern, template)
return list(set(matches)) # 去重
def _parse_template_parts(self, template: str, state: WorkflowState) -> list[dict]:
"""解析模板,分离静态文本和动态引用
例如:'你好 {{llm.output}}, 这是后缀'
返回:[
{"type": "static", "content": "你好 "},
{"type": "dynamic", "node_id": "llm", "field": "output"},
{"type": "static", "content": ", 这是后缀"}
]
Args:
template: 模板字符串
state: 工作流状态
Returns:
模板部分列表
"""
import re
parts = []
last_end = 0
# 匹配 {{xxx}} 或 {{ xxx }} 格式(支持空格)
pattern = r'\{\{\s*([^}]+?)\s*\}\}'
for match in re.finditer(pattern, template):
start, end = match.span()
# 添加前面的静态文本
if start > last_end:
static_text = template[last_end:start]
if static_text:
parts.append({"type": "static", "content": static_text})
# 解析动态引用
ref = match.group(1).strip()
# 检查是否是节点引用(如 llm.output 或 llm_qa.output
if '.' in ref:
node_id, field = ref.split('.', 1)
parts.append({
"type": "dynamic",
"node_id": node_id,
"field": field,
"raw": ref
})
else:
# 其他引用(如 {{var.xxx}}),当作静态处理
# 直接渲染这部分
rendered = self._render_template(f"{{{{{ref}}}}}", state)
parts.append({"type": "static", "content": rendered})
last_end = end
# 添加最后的静态文本
if last_end < len(template):
static_text = template[last_end:]
if static_text:
parts.append({"type": "static", "content": static_text})
return parts
async def execute_stream(self, state: WorkflowState):
"""流式执行 end 节点业务逻辑
智能输出策略:
1. 检测模板中是否引用了直接上游节点
2. 如果引用了,只输出该引用**之后**的部分(后缀)
3. 前缀和引用内容已经在上游节点流式输出时发送了
示例:'{{start.test}}hahaha {{ llm_qa.output }} lalalalala a'
- 直接上游节点是 llm_qa
- 前缀 '{{start.test}}hahaha ' 已在 LLM 节点流式输出前发送
- LLM 内容在 LLM 节点流式输出
- End 节点只输出 ' lalalalala a'(后缀,一次性输出)
Args:
state: 工作流状态
Yields:
完成标记
"""
logger.info(f"节点 {self.node_id} (End) 开始执行(流式)")
# 获取配置的输出模板
output_template = self.config.get("output")
if not output_template:
output = "工作流已完成"
yield {"__final__": True, "result": output}
return
# 找到直接上游节点
direct_upstream_nodes = []
for edge in self.workflow_config.get("edges", []):
if edge.get("target") == self.node_id:
source_node_id = edge.get("source")
direct_upstream_nodes.append(source_node_id)
logger.info(f"节点 {self.node_id} 的直接上游节点: {direct_upstream_nodes}")
# 解析模板部分
parts = self._parse_template_parts(output_template, state)
logger.info(f"节点 {self.node_id} 解析模板,共 {len(parts)} 个部分")
# 找到第一个引用直接上游节点的动态引用
upstream_ref_index = None
for i, part in enumerate(parts):
if part["type"] == "dynamic" and part["node_id"] in direct_upstream_nodes:
upstream_ref_index = i
logger.info(f"节点 {self.node_id} 找到直接上游节点 {part['node_id']} 的引用,索引: {i}")
break
if upstream_ref_index is None:
# 没有引用直接上游节点,正常输出(渲染完整模板)
output = self._render_template(output_template, state)
logger.info(f"节点 {self.node_id} 没有引用直接上游节点,输出完整内容")
yield {"__final__": True, "result": output}
return
# 有引用直接上游节点,只输出该引用之后的部分(后缀)
logger.info(f"节点 {self.node_id} 检测到直接上游节点引用,只输出后缀部分(从索引 {upstream_ref_index + 1} 开始)")
# 收集后缀部分
suffix_parts = []
for i in range(upstream_ref_index + 1, len(parts)):
part = parts[i]
if part["type"] == "static":
# 静态文本
suffix_parts.append(part["content"])
elif part["type"] == "dynamic":
# 其他动态引用(如果有多个引用)
node_id = part["node_id"]
field = part["field"]
# 从 streaming_buffer 或 node_outputs 读取
streaming_buffer = state.get("streaming_buffer", {})
if node_id in streaming_buffer:
buffer_data = streaming_buffer[node_id]
content = buffer_data.get("full_content", "")
else:
node_outputs = state.get("node_outputs", {})
runtime_vars = state.get("runtime_vars", {})
content = ""
if node_id in node_outputs:
node_output = node_outputs[node_id]
if isinstance(node_output, dict):
content = str(node_output.get(field, ""))
elif node_id in runtime_vars:
runtime_var = runtime_vars[node_id]
if isinstance(runtime_var, dict):
content = str(runtime_var.get(field, ""))
suffix_parts.append(content)
# 拼接后缀
suffix = "".join(suffix_parts)
# 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀)
full_output = self._render_template(output_template, state)
if suffix:
logger.info(f"节点 {self.node_id} 输出后缀: '{suffix[:50]}...' (长度: {len(suffix)})")
# 一次性输出后缀(作为单个 chunk
# 注意:不要直接 yield 字符串,因为 base_node 会逐字符处理
# 而是通过 writer 直接发送
from langgraph.config import get_stream_writer
writer = get_stream_writer()
writer({
"type": "message", # End 节点的输出使用 message 类型
"node_id": self.node_id,
"chunk": suffix,
"full_content": full_output, # full_content 是完整的渲染结果(前缀+LLM+后缀)
"chunk_index": 1,
"is_suffix": True
})
logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀full_content 长度: {len(full_output)}")
else:
logger.info(f"节点 {self.node_id} 没有后缀需要输出")
# 统计信息
node_outputs = state.get("node_outputs", {})
total_nodes = len(node_outputs)
logger.info(f"节点 {self.node_id} (End) 执行完成(流式),共执行了 {total_nodes} 个节点")
# yield 完成标记(包含完整输出)
yield {"__final__": True, "result": full_output}

View File

@@ -63,7 +63,7 @@ class LLMNode(BaseNode):
- ai/assistant: AI 消息AIMessage
"""
def _prepare_llm(self, state: WorkflowState) -> tuple[RedBearLLM, list | str]:
def _prepare_llm(self, state: WorkflowState,stream:bool = False) -> tuple[RedBearLLM, list | str]:
"""准备 LLM 实例(公共逻辑)
Args:
@@ -125,16 +125,22 @@ class LLMNode(BaseNode):
model_type = config.type
# 4. 创建 LLM 实例(使用已提取的数据)
# 注意:对于流式输出,需要在模型初始化时设置 streaming=True
extra_params = {"streaming": stream} if stream else {}
llm = RedBearLLM(
RedBearModelConfig(
model_name=model_name,
provider=provider,
api_key=api_key,
base_url=api_base
base_url=api_base,
extra_params=extra_params
),
type=model_type
)
logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}")
return llm, prompt_or_messages
async def execute(self, state: WorkflowState) -> AIMessage:
@@ -146,13 +152,12 @@ class LLMNode(BaseNode):
Returns:
LLM 响应消息
"""
llm, prompt_or_messages = self._prepare_llm(state)
llm, prompt_or_messages = self._prepare_llm(state,True)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
# 调用 LLM支持字符串或消息列表
response = await llm.ainvoke(prompt_or_messages)
# 提取内容
if hasattr(response, 'content'):
content = response.content
@@ -208,13 +213,43 @@ class LLMNode(BaseNode):
Yields:
文本片段chunk或完成标记
"""
llm, prompt_or_messages = self._prepare_llm(state)
from langgraph.config import get_stream_writer
llm, prompt_or_messages = self._prepare_llm(state, True)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
# 检查是否有注入的 End 节点前缀配置
writer = get_stream_writer()
end_prefix = getattr(self, '_end_node_prefix', None)
logger.info(f"[LLM前缀] 节点 {self.node_id} 检查前缀配置: {end_prefix is not None}")
if end_prefix:
logger.info(f"[LLM前缀] 前缀内容: '{end_prefix}'")
if end_prefix:
# 渲染前缀(可能包含其他变量)
try:
rendered_prefix = self._render_template(end_prefix, state)
logger.info(f"节点 {self.node_id} 提前发送 End 节点前缀: '{rendered_prefix[:50]}...'")
# 提前发送 End 节点的前缀(使用 "message" 类型)
writer({
"type": "message", # End 相关的内容都是 message 类型
"node_id": "end", # 标记为 end 节点的输出
"chunk": rendered_prefix,
"full_content": rendered_prefix,
"chunk_index": 0,
"is_prefix": True # 标记这是前缀
})
except Exception as e:
logger.warning(f"渲染/发送 End 节点前缀失败: {e}")
# 累积完整响应
full_response = ""
last_chunk = None
chunk_count = 0
# 调用 LLM流式支持字符串或消息列表
async for chunk in llm.astream(prompt_or_messages):
@@ -224,13 +259,16 @@ class LLMNode(BaseNode):
else:
content = str(chunk)
full_response += content
last_chunk = chunk
# 流式返回每个文本片段
yield content
# 只有当内容不为空时才处理
if content:
full_response += content
last_chunk = chunk
chunk_count += 1
# 流式返回每个文本片段
yield content
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}")
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
# 构建完整的 AIMessage包含元数据
if isinstance(last_chunk, AIMessage):

View File

@@ -385,7 +385,7 @@ class LLMRouter:
# 获取 API Key 配置
api_key_config = self.db.query(ModelApiKey).filter(
ModelApiKey.model_config_id == self.routing_model_config.id,
ModelApiKey.is_active == True
ModelApiKey.is_active
).first()
if not api_key_config:

View File

@@ -1,29 +1,28 @@
"""
工作流服务层
"""
import datetime
import json
import logging
import uuid
import datetime
from typing import Any, Annotated, AsyncGenerator
from sqlalchemy.orm import Session
from fastapi import Depends
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.workflow.validator import validate_workflow_config
from app.db import get_db
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
from app.repositories.workflow_repository import (
WorkflowConfigRepository,
WorkflowExecutionRepository,
WorkflowNodeExecutionRepository,
get_workflow_config_repository,
get_workflow_execution_repository,
get_workflow_node_execution_repository
WorkflowNodeExecutionRepository
)
from app.core.workflow.validator import validate_workflow_config
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.db import get_db
from app.schemas import DraftRunRequest
from app.utils.sse_utils import format_sse_message
logger = logging.getLogger(__name__)
@@ -195,8 +194,7 @@ class WorkflowService:
config = self.get_workflow_config(app_id)
if not config:
return False
self.config_repo.delete(config.id)
config.is_active = False
logger.info(f"删除工作流配置成功: app_id={app_id}, config_id={config.id}")
return True
@@ -474,11 +472,9 @@ class WorkflowService:
}
# 4. 获取工作空间 ID从 app 获取)
from app.models import App
# 5. 执行工作流
from app.core.workflow.executor import execute_workflow, execute_workflow_stream
from app.core.workflow.executor import execute_workflow
try:
# 更新状态为运行中
@@ -595,19 +591,14 @@ class WorkflowService:
}
# 4. 获取工作空间 ID从 app 获取)
from app.models import App
# 5. 流式执行工作流
from app.core.workflow.executor import execute_workflow, execute_workflow_stream
try:
# 更新状态为运行中
self.update_execution_status(execution.execution_id, "running")
# 发送开始事件
yield f"data: {json.dumps({'type': 'workflow_start', 'execution_id': execution.execution_id})}\n\n"
# 调用流式执行
# 调用流式执行executor 会发送 workflow_start 和 workflow_end 事件
async for event in self._run_workflow_stream(
workflow_config=workflow_config_dict,
input_data=input_data,
@@ -615,13 +606,8 @@ class WorkflowService:
workspace_id="",
user_id=payload.user_id
):
# 清理事件数据,移除不可序列化的对象
cleaned_event = self._clean_event_for_json(event)
# 转换为 SSE 格式
yield f"data: {json.dumps(cleaned_event)}\n\n"
# 发送完成事件
yield f"data: {json.dumps({'type': 'workflow_end', 'execution_id': execution.execution_id})}\n\n"
# 直接转发 executor 的事件(已经是正确的格式)
yield event
except Exception as e:
logger.error(f"工作流流式执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
@@ -631,7 +617,13 @@ class WorkflowService:
error_message=str(e)
)
# 发送错误事件
yield f"data: {json.dumps({'type': 'error', 'execution_id': execution.execution_id, 'error': str(e)})}\n\n"
yield {
"event": "error",
"data": {
"execution_id": execution.execution_id,
"error": str(e)
}
}
async def run_workflow(
self,
@@ -692,7 +684,7 @@ class WorkflowService:
)
# 5. 执行工作流
from app.core.workflow.executor import execute_workflow, execute_workflow_stream
from app.core.workflow.executor import execute_workflow
try:
# 更新状态为运行中
@@ -802,13 +794,11 @@ class WorkflowService:
user_id: 用户 ID
Yields:
流式事件
流式事件(格式:{"event": "<type>", "data": {...}}
"""
from app.core.workflow.executor import execute_workflow_stream
try:
output_data = {}
async for event in execute_workflow_stream(
workflow_config=workflow_config,
input_data=input_data,
@@ -816,31 +806,9 @@ class WorkflowService:
workspace_id=workspace_id,
user_id=user_id
):
# 转发事件
# 直接转发事件executor 已经返回正确格式)
yield event
# 收集输出数据
if event.get("type") == "node_complete":
node_data = event.get("data", {})
node_outputs = node_data.get("node_outputs", {})
output_data.update(node_outputs)
# 处理完成事件
if event.get("type") == "workflow_complete":
self.update_execution_status(
execution_id,
"completed",
output_data=output_data
)
# 处理错误事件
if event.get("type") == "workflow_error":
self.update_execution_status(
execution_id,
"failed",
error_message=event.get("error")
)
except Exception as e:
logger.error(f"工作流流式执行失败: execution_id={execution_id}, error={e}", exc_info=True)
self.update_execution_status(
@@ -849,9 +817,11 @@ class WorkflowService:
error_message=str(e)
)
yield {
"type": "workflow_error",
"execution_id": execution_id,
"error": str(e)
"event": "error",
"data": {
"execution_id": execution_id,
"error": str(e)
}
}