Merge branch 'develop-mark' into develop
# Conflicts: # api/app/core/workflow/executor.py # api/app/services/workflow_service.py
This commit is contained in:
@@ -583,15 +583,27 @@ async def draft_run(
|
||||
)
|
||||
|
||||
async def event_generator():
    """Stream workflow events as Server-Sent Events (SSE).

    Every event produced by ``workflow_service.run_stream`` is rendered as a
    single SSE frame::

        event: <event_type>
        data: <json_data>

    The event type falls back to ``"message"`` and the payload to ``{}`` when
    the corresponding key is missing from the event.

    Yields:
        str: One SSE frame per workflow event.
    """
    import json

    # Call the workflow service's streaming method.
    async for event in workflow_service.run_stream(
        app_id=app_id,
        payload=payload,
        config=config
    ):
        # Extract the event type and payload.
        event_type = event.get("event", "message")
        event_data = event.get("data", {})

        # Only the formatted string frame is emitted; the raw event dict is
        # no longer forwarded alongside it.
        sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
        yield sse_message
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
|
||||
@@ -471,7 +471,20 @@ async def run_workflow(
|
||||
import json
|
||||
|
||||
async def event_generator():
|
||||
"""生成 SSE 事件"""
|
||||
"""生成 SSE 事件
|
||||
|
||||
SSE 格式:
|
||||
event: <event_type>
|
||||
data: <json_data>
|
||||
|
||||
支持的事件类型:
|
||||
- workflow_start: 工作流开始
|
||||
- workflow_end: 工作流结束
|
||||
- node_start: 节点开始执行
|
||||
- node_end: 节点执行完成
|
||||
- node_chunk: 中间节点的流式输出
|
||||
- message: 最终消息的流式输出(End 节点及其相邻节点)
|
||||
"""
|
||||
try:
|
||||
async for event in await service.run_workflow(
|
||||
app_id=app_id,
|
||||
@@ -480,19 +493,30 @@ async def run_workflow(
|
||||
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
|
||||
stream=True
|
||||
):
|
||||
# 转换为 SSE 格式
|
||||
yield f"data: {json.dumps(event)}\n\n"
|
||||
# 提取事件类型和数据
|
||||
event_type = event.get("event", "message")
|
||||
event_data = event.get("data", {})
|
||||
|
||||
# 转换为标准 SSE 格式(字符串)
|
||||
# event: <type>
|
||||
# data: <json>
|
||||
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
|
||||
yield sse_message
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"流式执行异常: {e}", exc_info=True)
|
||||
error_event = {
|
||||
"type": "error",
|
||||
"error": str(e)
|
||||
}
|
||||
yield f"data: {json.dumps(error_event)}\n\n"
|
||||
# 发送错误事件
|
||||
sse_error = f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"
|
||||
yield sse_error
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream"
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no" # 禁用 nginx 缓冲
|
||||
}
|
||||
)
|
||||
else:
|
||||
# 非流式执行
|
||||
|
||||
@@ -9,18 +9,15 @@ LangChain Agent 封装
|
||||
"""
|
||||
import os
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Dict, Any, List, Optional, AsyncGenerator, Sequence
|
||||
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain.agents import create_agent
|
||||
|
||||
from app.core.memory.agent.mcp_server.services import session_service
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.models.models_model import ModelType
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from app.services.task_service import get_task_memory_write_result
|
||||
from app.tasks import write_message_task
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from __future__ import annotations
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Iterator, AsyncIterator, List, Optional
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun
|
||||
from langchain_core.language_models import BaseLLM
|
||||
from langchain_core.outputs import LLMResult
|
||||
from langchain_core.outputs import LLMResult, GenerationChunk
|
||||
|
||||
from app.core.models import RedBearModelConfig, RedBearModelFactory, get_provider_llm_class
|
||||
from app.models.models_model import ModelType
|
||||
@@ -10,21 +10,36 @@ from app.models.models_model import ModelType
|
||||
|
||||
class RedBearLLM(BaseLLM):
|
||||
"""
|
||||
RedBear LLM 模型包装器 - 完全动态代理实现
|
||||
RedBear LLM Model Wrapper
|
||||
|
||||
这个包装器自动将所有方法调用委托给内部模型,
|
||||
同时提供优雅的回退机制和错误处理。
|
||||
This wrapper provides a unified interface to access different LLM providers,
|
||||
while maintaining all LangChain functionality, including streaming output.
|
||||
|
||||
Features:
|
||||
- Support for multiple LLM providers (OpenAI, Qwen, Ollama, etc.)
|
||||
- Full streaming output support
|
||||
- Elegant error handling and fallback mechanism
|
||||
- Automatic proxying of all underlying model methods and attributes
|
||||
"""
|
||||
|
||||
def __init__(self, config: RedBearModelConfig, type: ModelType = ModelType.LLM):
    """Initialize the RedBear LLM wrapper.

    Args:
        config: Model configuration; kept on the instance as ``_config``.
        type: Model type passed to ``_create_model`` (defaults to
            ``ModelType.LLM``).
    """
    # Initialise the base class before attaching wrapper state.
    super().__init__()
    self._config = config
    self._model = self._create_model(config, type)
|
||||
|
||||
@property
def _llm_type(self) -> str:
    """Return the LLM type identifier.

    Uses the wrapped model's ``_llm_type`` when it has one and falls back to
    ``'redbear_llm'`` otherwise.
    """
    return getattr(self._model, '_llm_type', 'redbear_llm')
|
||||
|
||||
# ==================== Core Methods (Required by BaseLLM) ====================
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
@@ -32,7 +47,7 @@ class RedBearLLM(BaseLLM):
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any
|
||||
) -> LLMResult:
|
||||
"""同步生成文本"""
|
||||
"""Synchronous text generation (required by BaseLLM)"""
|
||||
return self._model._generate(prompts, stop=stop, run_manager=run_manager, **kwargs)
|
||||
|
||||
async def _agenerate(
|
||||
@@ -42,92 +57,233 @@ class RedBearLLM(BaseLLM):
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any
|
||||
) -> LLMResult:
|
||||
"""异步生成文本"""
|
||||
"""Asynchronous text generation (required by BaseLLM)"""
|
||||
return await self._model._agenerate(prompts, stop=stop, run_manager=run_manager, **kwargs)
|
||||
|
||||
# 关键:覆盖 invoke/ainvoke,直接委托到底层模型,避免 BaseLLM 的字符串化行为
|
||||
# ==================== Advanced Methods (Support Message Lists) ====================
|
||||
|
||||
def invoke(self, input: Any, config: Optional[dict] = None, **kwargs: Any) -> Any:
    """Invoke the wrapped model synchronously.

    The call is delegated to the wrapped model's own ``invoke`` so that the
    input (string, message list, ...) is passed through unchanged. The parent
    implementation is used only when the wrapped model has no ``invoke``
    attribute at all.

    Args:
        input: Input forwarded to the model (string, message list, etc.).
        config: Runtime configuration forwarded as ``config=``.
        **kwargs: Additional keyword arguments forwarded to the model.

    Returns:
        Whatever the delegated ``invoke`` returns.

    Raises:
        Any exception raised by the delegated call, unchanged.
    """
    # Decide on the fallback by looking the method up, not by inspecting the
    # text of an AttributeError: an AttributeError raised *inside* the model
    # call must propagate instead of silently re-running the request through
    # the parent implementation.
    delegate = getattr(self._model, "invoke", None)
    if delegate is None:
        return super().invoke(input, config=config, **kwargs)
    return delegate(input, config=config, **kwargs)
|
||||
|
||||
async def ainvoke(self, input: Any, config: Optional[dict] = None, **kwargs: Any) -> Any:
    """Invoke the wrapped model asynchronously.

    The call is delegated to the wrapped model's own ``ainvoke`` so that the
    input (string, message list, ...) is passed through unchanged. The parent
    implementation is used only when the wrapped model has no ``ainvoke``
    attribute at all.

    Args:
        input: Input forwarded to the model (string, message list, etc.).
        config: Runtime configuration forwarded as ``config=``.
        **kwargs: Additional keyword arguments forwarded to the model.

    Returns:
        Whatever the delegated ``ainvoke`` returns.

    Raises:
        Any exception raised by the delegated call, unchanged.
    """
    # Look the method up instead of matching on an AttributeError message, so
    # an AttributeError raised inside the model call propagates rather than
    # triggering a second invocation through the parent implementation.
    delegate = getattr(self._model, "ainvoke", None)
    if delegate is None:
        return await super().ainvoke(input, config=config, **kwargs)
    return await delegate(input, config=config, **kwargs)
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""
|
||||
动态代理:将所有未定义的属性和方法调用委托给内部模型
|
||||
# ==================== Streaming Methods (Critical) ====================
|
||||
|
||||
def stream(
    self,
    input: Any,
    config: Optional[dict] = None,
    *,
    stop: Optional[List[str]] = None,
    **kwargs: Any
) -> Iterator[GenerationChunk]:
    """Stream the wrapped model's output synchronously.

    Chunks are re-yielded one by one from the wrapped model's own ``stream``.
    The parent implementation is used only when the wrapped model has no
    ``stream`` attribute at all.

    Args:
        input: Input forwarded to the model (string, message list, etc.).
        config: Runtime configuration forwarded as ``config=``.
        stop: List of stop words forwarded as ``stop=``.
        **kwargs: Additional keyword arguments forwarded to the model.

    Yields:
        The chunks produced by the delegated ``stream`` call.

    Raises:
        Any exception raised while streaming, unchanged.
    """
    # Choose the source up front. Catching AttributeError around the whole
    # loop and then falling back would restart the stream after some chunks
    # had already been delivered to the caller.
    delegate = getattr(self._model, "stream", None)
    if delegate is None:
        yield from super().stream(input, config=config, stop=stop, **kwargs)
        return
    yield from delegate(input, config=config, stop=stop, **kwargs)
|
||||
|
||||
async def astream(
    self,
    input: Any,
    config: Optional[dict] = None,
    *,
    stop: Optional[List[str]] = None,
    **kwargs: Any
) -> AsyncIterator[GenerationChunk]:
    """Stream the wrapped model's output asynchronously.

    Chunks are re-yielded one by one, as they arrive, from the wrapped
    model's own ``astream``. The parent implementation is used only when the
    wrapped model has no ``astream`` attribute at all.

    Args:
        input: Input forwarded to the model (string, message list, etc.).
        config: Runtime configuration forwarded as ``config=``.
        stop: List of stop words forwarded as ``stop=``.
        **kwargs: Additional keyword arguments forwarded to the model.

    Yields:
        The chunks produced by the delegated ``astream`` call.

    Raises:
        Any exception raised while streaming, unchanged.
    """
    # Choose the source up front. Catching AttributeError around the whole
    # loop and then falling back would restart the stream after some chunks
    # had already been delivered to the caller.
    delegate = getattr(self._model, "astream", None)
    if delegate is None:
        async for chunk in super().astream(input, config=config, stop=stop, **kwargs):
            yield chunk
        return
    async for chunk in delegate(input, config=config, stop=stop, **kwargs):
        yield chunk
|
||||
|
||||
# ==================== Dynamic Proxy ====================
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
    """Delegate attributes not found on the wrapper to the wrapped model.

    Lookup order:
      1. The wrapped model (``self._model``), looked up with
         ``object.__getattribute__``.
      2. A ``_fallback_<name>`` attribute on the wrapper itself.

    Callables found on the wrapped model are returned inside a thin
    forwarding function, except ``_stream`` / ``_astream``, which are
    returned as-is so they keep their generator characteristics.

    Args:
        name: Attribute or method name.

    Returns:
        The attribute value, or a callable forwarding to the model's method.

    Raises:
        AttributeError: If neither the wrapped model nor a fallback provides
            the attribute.
    """
    # These names must never be resolved through the proxy: looking up
    # ``_model`` here before it is set would recurse into this method.
    if name in ('__isabstractmethod__', '__dict__', '__class__', '_model', '_config'):
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    # 1. Try the wrapped model.
    try:
        attr = object.__getattribute__(self._model, name)
    except AttributeError:
        # The wrapped model does not have it either; try a fallback below.
        pass
    else:
        # Regular attributes are returned directly.
        if not callable(attr):
            return attr

        # Internal streaming hooks are handed back untouched.
        # (``stream`` / ``astream`` are defined on the class, so they never
        # reach this method.)
        if name in ('_stream', '_astream'):
            return attr

        def method_wrapper(*args, **kwargs):
            # Plain forwarding; exceptions propagate unchanged.
            return attr(*args, **kwargs)

        # Preserve method metadata.
        method_wrapper.__name__ = name
        method_wrapper.__doc__ = getattr(attr, '__doc__', f"Delegated method: {name}")
        return method_wrapper

    # 2. Try a ``_fallback_<name>`` attribute on the wrapper.
    fallback_name = f'_fallback_{name}'
    try:
        return object.__getattribute__(self, fallback_name)
    except AttributeError:
        pass

    # Nothing found.
    raise AttributeError(
        f"'{type(self).__name__}' object has no attribute '{name}'. "
        f"The underlying model '{type(self._model).__name__}' also doesn't have this attribute."
    )
|
||||
|
||||
# ==================== Helper Methods ====================
|
||||
|
||||
def _create_model(self, config: RedBearModelConfig, type: ModelType) -> BaseLLM:
    """Create the wrapped model instance.

    The class comes from ``get_provider_llm_class(config, type)`` and is
    instantiated with the keyword arguments returned by
    ``RedBearModelFactory.get_model_params(config)``.

    Args:
        config: Model configuration.
        type: Model type.

    Returns:
        The created model instance.
    """
    llm_class = get_provider_llm_class(config, type)
    model_params = RedBearModelFactory.get_model_params(config)
    return llm_class(**model_params)
|
||||
|
||||
|
||||
|
||||
|
||||
def get_config(self) -> RedBearModelConfig:
    """Return the model configuration stored on this wrapper.

    Returns:
        The object held in ``self._config``.
    """
    return self._config
|
||||
|
||||
def get_underlying_model(self) -> BaseLLM:
    """Return the wrapped model instance.

    Returns:
        The object held in ``self._model``.
    """
    return self._model
|
||||
|
||||
def __repr__(self) -> str:
    """Return a short description of the wrapper.

    The text contains the configured provider and model name plus the class
    name of the wrapped model.
    """
    return (
        f"RedBearLLM("
        f"provider={self._config.provider}, "
        f"model={self._config.model_name}, "
        f"type={type(self._model).__name__}"
        f")"
    )
|
||||
@@ -94,16 +94,90 @@ class WorkflowExecutor:
|
||||
"workspace_id": self.workspace_id,
|
||||
"user_id": self.user_id,
|
||||
"error": None,
|
||||
"error_node": None
|
||||
"error_node": None,
|
||||
"streaming_buffer": {} # 流式缓冲区
|
||||
}
|
||||
|
||||
def build_graph(self) -> CompiledStateGraph:
|
||||
|
||||
|
||||
def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]:
    """Analyze the prefix configuration of End nodes.

    For every node in ``self.nodes`` whose type is ``"end"``, the output
    template (``config.output``) is scanned for ``{{node_id.field}}``
    references. The first reference that points at a node directly connected
    to the End node (an edge in ``self.edges`` whose target is the End node)
    determines the prefix: the template text that precedes that reference.

    Returns:
        A tuple ``(prefixes, adjacent_and_referenced)`` where ``prefixes``
        maps an upstream node id to its non-empty End-node prefix, and
        ``adjacent_and_referenced`` is the set of node ids that are adjacent
        to an End node and referenced by its template.
    """
    import re

    # Matches {{node_id.xxx}} and {{ node_id.xxx }} (inner spaces allowed).
    # NOTE(review): node ids containing characters outside [a-zA-Z0-9_]
    # (e.g. hyphens) are not matched -- confirm against the id format in use.
    pattern = r'\{\{\s*([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\s*\}\}'

    prefixes = {}
    adjacent_and_referenced = set()  # nodes adjacent to an End node and referenced by it

    # Collect every End node.
    end_nodes = [node for node in self.nodes if node.get("type") == "end"]
    logger.info(f"[前缀分析] 找到 {len(end_nodes)} 个 End 节点")

    for end_node in end_nodes:
        end_node_id = end_node.get("id")
        output_template = end_node.get("config", {}).get("output")

        logger.info(f"[前缀分析] End 节点 {end_node_id} 模板: {output_template}")

        if not output_template:
            continue

        # Upstream nodes wired directly into this End node.
        direct_upstream_nodes = [
            edge.get("source")
            for edge in self.edges
            if edge.get("target") == end_node_id
        ]

        logger.info(f"[前缀分析] End 节点的直接上游节点: {direct_upstream_nodes}")

        # Node references that appear in the template.
        matches = list(re.finditer(pattern, output_template))

        logger.info(f"[前缀分析] 模板中找到 {len(matches)} 个节点引用")

        # Look for the first reference to a direct upstream node.
        for match in matches:
            referenced_node_id = match.group(1)
            logger.info(f"[前缀分析] 检查引用: {referenced_node_id}")

            if referenced_node_id not in direct_upstream_nodes:
                continue

            # Everything before the reference is the prefix.
            prefix = output_template[:match.start()]

            logger.info(f"[前缀分析] ✅ 找到直接上游节点 {referenced_node_id} 的引用,前缀: '{prefix}'")

            # The node is recorded as adjacent-and-referenced even when the
            # prefix is empty.
            adjacent_and_referenced.add(referenced_node_id)

            if prefix:
                prefixes[referenced_node_id] = prefix
                logger.info(f"✅ [前缀分析] 为节点 {referenced_node_id} 配置前缀: '{prefix[:50]}...'")

            # Only the first direct-upstream reference is handled.
            break

    logger.info(f"[前缀分析] 最终配置: {prefixes}")
    logger.info(f"[前缀分析] 与 End 相邻且被引用的节点: {adjacent_and_referenced}")
    return prefixes, adjacent_and_referenced
|
||||
|
||||
def build_graph(self,stream=False) -> CompiledStateGraph:
|
||||
"""构建 LangGraph
|
||||
|
||||
Returns:
|
||||
编译后的状态图
|
||||
"""
|
||||
logger.info(f"开始构建工作流图: execution_id={self.execution_id}")
|
||||
|
||||
# 分析 End 节点的前缀配置和相邻且被引用的节点
|
||||
end_prefixes, adjacent_and_referenced = self._analyze_end_node_prefixes() if stream else ({}, set())
|
||||
|
||||
# 1. 创建状态图
|
||||
workflow = StateGraph(WorkflowState)
|
||||
@@ -143,16 +217,39 @@ class WorkflowExecutor:
|
||||
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
|
||||
|
||||
if node_instance:
|
||||
# 如果是流式模式,且节点有 End 前缀配置,注入配置
|
||||
if stream and node_id in end_prefixes:
|
||||
# 将 End 前缀配置注入到节点实例
|
||||
node_instance._end_node_prefix = end_prefixes[node_id]
|
||||
logger.info(f"为节点 {node_id} 注入 End 前缀配置")
|
||||
|
||||
# 如果是流式模式,标记节点是否与 End 相邻且被引用
|
||||
if stream:
|
||||
node_instance._is_adjacent_to_end = node_id in adjacent_and_referenced
|
||||
if node_id in adjacent_and_referenced:
|
||||
logger.info(f"节点 {node_id} 标记为与 End 相邻且被引用")
|
||||
|
||||
# 包装节点的 run 方法
|
||||
# 使用函数工厂避免闭包问题
|
||||
def make_node_func(inst):
|
||||
async def node_func(state: WorkflowState):
|
||||
return await inst.run(state)
|
||||
if stream:
|
||||
# 流式模式:创建 async generator 函数
|
||||
# LangGraph 会收集所有 yield 的值,最后一个 yield 的字典会被合并到 state
|
||||
def make_stream_func(inst):
|
||||
async def node_func(state: WorkflowState):
|
||||
# logger.debug(f"流式执行节点: {inst.node_id}, 支持流式: {inst.supports_streaming()}")
|
||||
async for item in inst.run_stream(state):
|
||||
yield item
|
||||
return node_func
|
||||
workflow.add_node(node_id, make_stream_func(node_instance))
|
||||
else:
|
||||
# 非流式模式:创建 async function
|
||||
def make_func(inst):
|
||||
async def node_func(state: WorkflowState):
|
||||
return await inst.run(state)
|
||||
return node_func
|
||||
workflow.add_node(node_id, make_func(node_instance))
|
||||
|
||||
return node_func
|
||||
|
||||
workflow.add_node(node_id, make_node_func(node_instance))
|
||||
logger.debug(f"添加节点: {node_id} (type={node_type})")
|
||||
logger.debug(f"添加节点: {node_id} (type={node_type}, stream={stream})")
|
||||
|
||||
# 3. 添加边
|
||||
# 从 START 连接到 start 节点
|
||||
@@ -300,40 +397,143 @@ class WorkflowExecutor:
|
||||
):
|
||||
"""执行工作流(流式)
|
||||
|
||||
手动执行节点以支持细粒度的流式输出:
|
||||
- workflow_start: 工作流开始
|
||||
- node_start: 节点开始执行
|
||||
- node_chunk: LLM 节点的流式输出片段(逐 token)
|
||||
- node_complete: 节点执行完成
|
||||
- workflow_complete: 工作流完成
|
||||
使用多个 stream_mode 来获取:
|
||||
1. "updates" - 节点的 state 更新和流式 chunk
|
||||
2. "debug" - 节点执行的详细信息(开始/完成时间)
|
||||
3. "custom" - 自定义流式数据(chunks)
|
||||
|
||||
Args:
|
||||
input_data: 输入数据
|
||||
|
||||
Yields:
|
||||
流式事件
|
||||
流式事件,格式:
|
||||
{
|
||||
"event": "workflow_start" | "workflow_end" | "node_start" | "node_end" | "node_chunk" | "message",
|
||||
"data": {...}
|
||||
}
|
||||
"""
|
||||
#
|
||||
logger.info(f"开始执行工作流: execution_id={self.execution_id}")
|
||||
logger.info(f"开始执行工作流(流式): execution_id={self.execution_id}")
|
||||
|
||||
# 记录开始时间
|
||||
start_time = datetime.datetime.now()
|
||||
|
||||
# 发送 workflow_start 事件
|
||||
yield {
|
||||
"event": "workflow_start",
|
||||
"data": {
|
||||
"execution_id": self.execution_id,
|
||||
"workspace_id": self.workspace_id,
|
||||
"timestamp": start_time.isoformat()
|
||||
}
|
||||
}
|
||||
|
||||
# 1. 构建图
|
||||
graph = self.build_graph()
|
||||
graph = self.build_graph(True)
|
||||
|
||||
# 2. 初始化状态(自动注入系统变量)
|
||||
initial_state = self._prepare_initial_state(input_data)
|
||||
|
||||
# 3. 执行工作流
|
||||
# 3. Execute workflow
|
||||
try:
|
||||
async for chunk in graph.astream(
|
||||
chunk_count = 0
|
||||
final_state = None
|
||||
|
||||
async for event in graph.astream(
|
||||
initial_state,
|
||||
# subgraphs=True,
|
||||
stream_mode="updates",
|
||||
stream_mode=["updates", "debug", "custom"], # Use updates + debug + custom mode
|
||||
):
|
||||
# print(chunk)
|
||||
yield chunk
|
||||
# event should be a tuple: (mode, data)
|
||||
# But let's handle both cases
|
||||
if isinstance(event, tuple) and len(event) == 2:
|
||||
mode, data = event
|
||||
else:
|
||||
# Unexpected format, log and skip
|
||||
logger.warning(f"[STREAM] Unexpected event format: {type(event)}, value: {event}")
|
||||
continue
|
||||
|
||||
if mode == "custom":
|
||||
# Handle custom streaming events (chunks from nodes via stream writer)
|
||||
chunk_count += 1
|
||||
event_type = data.get("type", "node_chunk") # "message" or "node_chunk"
|
||||
logger.info(f"[CUSTOM] ✅ 收到 {event_type} #{chunk_count} from {data.get('node_id')}")
|
||||
|
||||
yield {
|
||||
"event": event_type, # "message" or "node_chunk"
|
||||
"data": {
|
||||
"node_id": data.get("node_id"),
|
||||
"chunk": data.get("chunk"),
|
||||
"full_content": data.get("full_content"),
|
||||
"chunk_index": data.get("chunk_index"),
|
||||
"is_prefix": data.get("is_prefix"),
|
||||
"is_suffix": data.get("is_suffix")
|
||||
}
|
||||
}
|
||||
|
||||
elif mode == "debug":
|
||||
# Handle debug information (node execution status)
|
||||
event_type = data.get("type")
|
||||
payload = data.get("payload", {})
|
||||
node_name = payload.get("name")
|
||||
|
||||
if event_type == "task":
|
||||
# Node starts execution
|
||||
inputv = payload.get("input", {})
|
||||
variables = inputv.get("variables", {})
|
||||
variables_sys = variables.get("sys", {})
|
||||
conversation_id = variables_sys.get("conversation_id")
|
||||
execution_id = variables_sys.get("execution_id")
|
||||
logger.info(f"[DEBUG] Node starts execution: {node_name}")
|
||||
|
||||
yield {
|
||||
"event": "node_start",
|
||||
"data": {
|
||||
"node_id": node_name,
|
||||
"conversation_id": conversation_id,
|
||||
"execution_id": execution_id,
|
||||
"timestamp": data.get("timestamp")
|
||||
}
|
||||
}
|
||||
elif event_type == "task_result":
|
||||
# Node execution completed
|
||||
result = payload.get("result", {})
|
||||
inputv = result.get("input", {})
|
||||
variables = inputv.get("variables", {})
|
||||
variables_sys = variables.get("sys", {})
|
||||
conversation_id = variables_sys.get("conversation_id")
|
||||
execution_id = variables_sys.get("execution_id")
|
||||
logger.info(f"[DEBUG] Node execution completed: {node_name}")
|
||||
|
||||
yield {
|
||||
"event": "node_end",
|
||||
"data": {
|
||||
"node_id": node_name,
|
||||
"conversation_id": conversation_id,
|
||||
"execution_id": execution_id,
|
||||
"timestamp": data.get("timestamp")
|
||||
}
|
||||
}
|
||||
|
||||
elif mode == "updates":
|
||||
# Handle state updates - store final state
|
||||
logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())}")
|
||||
final_state = data
|
||||
|
||||
# 计算耗时
|
||||
end_time = datetime.datetime.now()
|
||||
elapsed_time = (end_time - start_time).total_seconds()
|
||||
|
||||
logger.info(f"Workflow execution completed (streaming), total chunks: {chunk_count}, elapsed: {elapsed_time:.2f}s")
|
||||
|
||||
# 发送 workflow_end 事件
|
||||
yield {
|
||||
"event": "workflow_end",
|
||||
"data": {
|
||||
"execution_id": self.execution_id,
|
||||
"status": "completed",
|
||||
"elapsed_time": elapsed_time,
|
||||
"timestamp": end_time.isoformat()
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
# 计算耗时(即使失败也记录)
|
||||
@@ -341,13 +541,17 @@ class WorkflowExecutor:
|
||||
elapsed_time = (end_time - start_time).total_seconds()
|
||||
|
||||
logger.error(f"工作流执行失败: execution_id={self.execution_id}, error={e}", exc_info=True)
|
||||
|
||||
# 发送 workflow_end 事件(失败)
|
||||
yield {
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
"output": None,
|
||||
"node_outputs": {},
|
||||
"elapsed_time": elapsed_time,
|
||||
"token_usage": None
|
||||
"event": "workflow_end",
|
||||
"data": {
|
||||
"execution_id": self.execution_id,
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
"elapsed_time": elapsed_time,
|
||||
"timestamp": end_time.isoformat()
|
||||
}
|
||||
}
|
||||
|
||||
def _extract_final_output(self, node_outputs: dict[str, Any]) -> str | None:
|
||||
|
||||
@@ -10,6 +10,7 @@ from abc import ABC, abstractmethod
|
||||
from typing import Any, TypedDict, Annotated
|
||||
from operator import add
|
||||
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
@@ -43,6 +44,10 @@ class WorkflowState(TypedDict):
|
||||
# 错误信息(用于错误边)
|
||||
error: str | None
|
||||
error_node: str | None
|
||||
|
||||
# 流式缓冲区(存储节点的实时流式输出)
|
||||
# 格式:{node_id: {"chunks": [...], "full_content": "..."}}
|
||||
streaming_buffer: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
|
||||
|
||||
|
||||
class BaseNode(ABC):
|
||||
@@ -201,19 +206,25 @@ class BaseNode(ABC):
|
||||
return self._wrap_error(str(e), elapsed_time, state)
|
||||
|
||||
async def run_stream(self, state: WorkflowState):
|
||||
"""执行节点(带错误处理和输出包装,流式)
|
||||
"""Execute node with error handling and output wrapping (streaming)
|
||||
|
||||
这个方法由 Executor 调用,负责:
|
||||
1. 时间统计
|
||||
2. 调用节点的 execute_stream() 方法
|
||||
3. 将业务数据包装成标准输出格式
|
||||
4. 错误处理
|
||||
This method is called by the Executor and is responsible for:
|
||||
1. Time tracking
|
||||
2. Calling the node's execute_stream() method
|
||||
3. Using LangGraph's stream writer to send chunks
|
||||
4. Updating streaming buffer in state for downstream nodes
|
||||
5. Wrapping business data into standard output format
|
||||
6. Error handling
|
||||
|
||||
Special handling for End nodes:
|
||||
- End nodes don't send chunks via writer (prefix and LLM content already sent)
|
||||
- End nodes only yield suffix for final result assembly
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
state: Workflow state
|
||||
|
||||
Yields:
|
||||
标准化的流式事件
|
||||
State updates with streaming buffer and final result
|
||||
"""
|
||||
import time
|
||||
|
||||
@@ -222,68 +233,143 @@ class BaseNode(ABC):
|
||||
try:
|
||||
timeout = self.get_timeout()
|
||||
|
||||
# 累积完整结果(用于最后的包装)
|
||||
# Get LangGraph's stream writer for sending custom data
|
||||
writer = get_stream_writer()
|
||||
|
||||
# Check if this is an End node
|
||||
# End nodes CAN send chunks (for suffix), but only after LLM content
|
||||
is_end_node = self.node_type == "end"
|
||||
|
||||
# Check if this node is adjacent to End node (for message type)
|
||||
is_adjacent_to_end = getattr(self, '_is_adjacent_to_end', False)
|
||||
|
||||
# Determine chunk type: "message" for End and adjacent nodes, "node_chunk" for others
|
||||
chunk_type = "message" if (is_end_node or is_adjacent_to_end) else "node_chunk"
|
||||
|
||||
logger.debug(f"节点 {self.node_id} chunk 类型: {chunk_type} (is_end={is_end_node}, adjacent={is_adjacent_to_end})")
|
||||
|
||||
# Accumulate complete result (for final wrapping)
|
||||
chunks = []
|
||||
final_result = None
|
||||
chunk_count = 0
|
||||
|
||||
# 使用异步生成器包装,支持超时
|
||||
async def stream_with_timeout():
|
||||
nonlocal final_result
|
||||
loop_start = asyncio.get_event_loop().time()
|
||||
# Stream chunks in real-time
|
||||
loop_start = asyncio.get_event_loop().time()
|
||||
|
||||
async for item in self.execute_stream(state):
|
||||
# Check timeout
|
||||
if asyncio.get_event_loop().time() - loop_start > timeout:
|
||||
raise TimeoutError()
|
||||
|
||||
async for item in self.execute_stream(state):
|
||||
# 检查超时
|
||||
if asyncio.get_event_loop().time() - loop_start > timeout:
|
||||
raise TimeoutError()
|
||||
# Check if it's a completion marker
|
||||
if isinstance(item, dict) and item.get("__final__"):
|
||||
final_result = item["result"]
|
||||
elif isinstance(item, str):
|
||||
# String is a chunk
|
||||
chunk_count += 1
|
||||
chunks.append(item)
|
||||
full_content = "".join(chunks)
|
||||
|
||||
# 检查是否是完成标记
|
||||
if isinstance(item, dict) and item.get("__final__"):
|
||||
final_result = item["result"]
|
||||
elif isinstance(item, str):
|
||||
# 字符串是 chunk
|
||||
chunks.append(item)
|
||||
# Send chunks for all nodes (including End nodes for suffix)
|
||||
logger.debug(f"节点 {self.node_id} 发送 chunk #{chunk_count}: {item[:50]}...")
|
||||
|
||||
# 1. Send via stream writer (for real-time client updates)
|
||||
writer({
|
||||
"type": chunk_type, # "message" or "node_chunk"
|
||||
"node_id": self.node_id,
|
||||
"chunk": item,
|
||||
"full_content": full_content,
|
||||
"chunk_index": chunk_count
|
||||
})
|
||||
|
||||
# 2. Update streaming buffer in state (for downstream nodes)
|
||||
# Only non-End nodes need streaming buffer
|
||||
if not is_end_node:
|
||||
yield {
|
||||
"type": "chunk",
|
||||
"node_id": self.node_id,
|
||||
"content": item,
|
||||
"full_content": "".join(chunks)
|
||||
"streaming_buffer": {
|
||||
self.node_id: {
|
||||
"full_content": full_content,
|
||||
"chunk_count": chunk_count,
|
||||
"is_complete": False
|
||||
}
|
||||
}
|
||||
}
|
||||
else:
|
||||
# 其他类型也当作 chunk 处理
|
||||
chunks.append(str(item))
|
||||
else:
|
||||
# Other types are also treated as chunks
|
||||
chunk_count += 1
|
||||
chunk_str = str(item)
|
||||
chunks.append(chunk_str)
|
||||
full_content = "".join(chunks)
|
||||
|
||||
# Send chunks for all nodes
|
||||
writer({
|
||||
"type": chunk_type, # "message" or "node_chunk"
|
||||
"node_id": self.node_id,
|
||||
"chunk": chunk_str,
|
||||
"full_content": full_content,
|
||||
"chunk_index": chunk_count
|
||||
})
|
||||
|
||||
# Only non-End nodes need streaming buffer
|
||||
if not is_end_node:
|
||||
yield {
|
||||
"type": "chunk",
|
||||
"node_id": self.node_id,
|
||||
"content": str(item),
|
||||
"full_content": "".join(chunks)
|
||||
"streaming_buffer": {
|
||||
self.node_id: {
|
||||
"full_content": full_content,
|
||||
"chunk_count": chunk_count,
|
||||
"is_complete": False
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async for chunk_event in stream_with_timeout():
|
||||
yield chunk_event
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# 包装最终结果
|
||||
logger.info(f"节点 {self.node_id} 流式执行完成,耗时: {elapsed_time:.2f}s, chunks: {chunk_count}")
|
||||
|
||||
# Extract processed output (call subclass's _extract_output)
|
||||
extracted_output = self._extract_output(final_result)
|
||||
|
||||
# Wrap final result
|
||||
final_output = self._wrap_output(final_result, elapsed_time, state)
|
||||
yield {
|
||||
"type": "complete",
|
||||
**final_output
|
||||
|
||||
# Store extracted output in runtime variables (for quick access by subsequent nodes)
|
||||
if isinstance(extracted_output, dict):
|
||||
runtime_var = extracted_output
|
||||
else:
|
||||
runtime_var = {"output": extracted_output}
|
||||
|
||||
# Build complete state update (including node_outputs, runtime_vars, and final streaming buffer)
|
||||
state_update = {
|
||||
**final_output,
|
||||
"runtime_vars": {
|
||||
self.node_id: runtime_var
|
||||
}
|
||||
}
|
||||
|
||||
# Add streaming buffer for non-End nodes
|
||||
if not is_end_node:
|
||||
state_update["streaming_buffer"] = {
|
||||
self.node_id: {
|
||||
"full_content": "".join(chunks),
|
||||
"chunk_count": chunk_count,
|
||||
"is_complete": True # Mark as complete
|
||||
}
|
||||
}
|
||||
|
||||
# Finally yield state update
|
||||
# LangGraph will merge this into state
|
||||
yield state_update
|
||||
|
||||
except TimeoutError:
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.error(f"节点 {self.node_id} 执行超时({timeout}秒)")
|
||||
yield {
|
||||
"type": "error",
|
||||
**self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state)
|
||||
}
|
||||
logger.error(f"节点 {self.node_id} 执行超时 ({timeout}s)")
|
||||
error_output = self._wrap_error(f"节点执行超时 ({timeout}s)", elapsed_time, state)
|
||||
yield error_output
|
||||
except Exception as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True)
|
||||
yield {
|
||||
"type": "error",
|
||||
**self._wrap_error(str(e), elapsed_time, state)
|
||||
}
|
||||
error_output = self._wrap_error(str(e), elapsed_time, state)
|
||||
yield error_output
|
||||
|
||||
def _wrap_output(
|
||||
self,
|
||||
|
||||
@@ -5,6 +5,8 @@ End 节点实现
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import asyncio
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
|
||||
@@ -15,6 +17,7 @@ class EndNode(BaseNode):
|
||||
"""End 节点
|
||||
|
||||
工作流的结束节点,根据配置的模板输出最终结果。
|
||||
支持实时流式输出:如果模板引用了上游节点的输出,会实时监听其流式缓冲区。
|
||||
"""
|
||||
|
||||
async def execute(self, state: WorkflowState) -> str:
|
||||
@@ -30,11 +33,7 @@ class EndNode(BaseNode):
|
||||
|
||||
# 获取配置的输出模板
|
||||
output_template = self.config.get("output")
|
||||
pool = self.get_variable_pool(state)
|
||||
|
||||
print("="*20)
|
||||
print( pool.get("start.test"))
|
||||
print("="*20)
|
||||
|
||||
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
|
||||
if output_template:
|
||||
output = self._render_template(output_template, state)
|
||||
@@ -46,7 +45,213 @@ class EndNode(BaseNode):
|
||||
total_nodes = len(node_outputs)
|
||||
|
||||
logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
|
||||
print("="*20)
|
||||
print(output)
|
||||
print("="*20)
|
||||
|
||||
return output
|
||||
|
||||
def _extract_referenced_nodes(self, template: str) -> list[str]:
|
||||
"""从模板中提取引用的节点 ID
|
||||
|
||||
例如:'结果:{{llm_qa.output}}' -> ['llm_qa']
|
||||
|
||||
Args:
|
||||
template: 模板字符串
|
||||
|
||||
Returns:
|
||||
引用的节点 ID 列表
|
||||
"""
|
||||
# 匹配 {{node_id.xxx}} 格式
|
||||
pattern = r'\{\{([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\}\}'
|
||||
matches = re.findall(pattern, template)
|
||||
return list(set(matches)) # 去重
|
||||
|
||||
def _parse_template_parts(self, template: str, state: WorkflowState) -> list[dict]:
|
||||
"""解析模板,分离静态文本和动态引用
|
||||
|
||||
例如:'你好 {{llm.output}}, 这是后缀'
|
||||
返回:[
|
||||
{"type": "static", "content": "你好 "},
|
||||
{"type": "dynamic", "node_id": "llm", "field": "output"},
|
||||
{"type": "static", "content": ", 这是后缀"}
|
||||
]
|
||||
|
||||
Args:
|
||||
template: 模板字符串
|
||||
state: 工作流状态
|
||||
|
||||
Returns:
|
||||
模板部分列表
|
||||
"""
|
||||
import re
|
||||
|
||||
parts = []
|
||||
last_end = 0
|
||||
|
||||
# 匹配 {{xxx}} 或 {{ xxx }} 格式(支持空格)
|
||||
pattern = r'\{\{\s*([^}]+?)\s*\}\}'
|
||||
|
||||
for match in re.finditer(pattern, template):
|
||||
start, end = match.span()
|
||||
|
||||
# 添加前面的静态文本
|
||||
if start > last_end:
|
||||
static_text = template[last_end:start]
|
||||
if static_text:
|
||||
parts.append({"type": "static", "content": static_text})
|
||||
|
||||
# 解析动态引用
|
||||
ref = match.group(1).strip()
|
||||
|
||||
# 检查是否是节点引用(如 llm.output 或 llm_qa.output)
|
||||
if '.' in ref:
|
||||
node_id, field = ref.split('.', 1)
|
||||
parts.append({
|
||||
"type": "dynamic",
|
||||
"node_id": node_id,
|
||||
"field": field,
|
||||
"raw": ref
|
||||
})
|
||||
else:
|
||||
# 其他引用(如 {{var.xxx}}),当作静态处理
|
||||
# 直接渲染这部分
|
||||
rendered = self._render_template(f"{{{{{ref}}}}}", state)
|
||||
parts.append({"type": "static", "content": rendered})
|
||||
|
||||
last_end = end
|
||||
|
||||
# 添加最后的静态文本
|
||||
if last_end < len(template):
|
||||
static_text = template[last_end:]
|
||||
if static_text:
|
||||
parts.append({"type": "static", "content": static_text})
|
||||
|
||||
return parts
|
||||
|
||||
async def execute_stream(self, state: WorkflowState):
    """Run the End node in streaming mode.

    Output strategy:
    1. Check whether the output template references a direct upstream node
       (a node that is the ``source`` of an edge targeting this node).
    2. If it does, only the template parts *after* the first such reference
       (the suffix) are sent, as one single chunk through the stream writer.
    3. The prefix and the referenced content are expected to have been
       streamed already while the upstream node was running.

    Example: ``'{{start.test}}hahaha {{ llm_qa.output }} lalalalala a'``
    with ``llm_qa`` as the direct upstream node: this node only sends
    ``' lalalalala a'``.

    Args:
        state: Workflow state.

    Yields:
        Exactly one completion marker ``{"__final__": True, "result": ...}``
        whose ``result`` is the default message (no template configured) or
        the fully rendered template.
    """
    logger.info(f"节点 {self.node_id} (End) 开始执行(流式)")

    output_template = self.config.get("output")

    # No template configured: finish with the default message.
    if not output_template:
        output = "工作流已完成"
        yield {"__final__": True, "result": output}
        return

    # Direct upstream nodes are the sources of the edges that end here.
    direct_upstream_nodes = [
        edge.get("source")
        for edge in self.workflow_config.get("edges", [])
        if edge.get("target") == self.node_id
    ]
    logger.info(f"节点 {self.node_id} 的直接上游节点: {direct_upstream_nodes}")

    parts = self._parse_template_parts(output_template, state)
    logger.info(f"节点 {self.node_id} 解析模板,共 {len(parts)} 个部分")

    # Index of the first dynamic part that points at a direct upstream node.
    upstream_ref_index = next(
        (
            i for i, part in enumerate(parts)
            if part["type"] == "dynamic" and part["node_id"] in direct_upstream_nodes
        ),
        None,
    )

    if upstream_ref_index is None:
        # Nothing was streamed on behalf of this node: emit the whole
        # rendered template as the final result.
        output = self._render_template(output_template, state)
        logger.info(f"节点 {self.node_id} 没有引用直接上游节点,输出完整内容")
        yield {"__final__": True, "result": output}
        return

    logger.info(
        f"节点 {self.node_id} 找到直接上游节点 {parts[upstream_ref_index]['node_id']} 的引用,索引: {upstream_ref_index}"
    )
    logger.info(f"节点 {self.node_id} 检测到直接上游节点引用,只输出后缀部分(从索引 {upstream_ref_index + 1} 开始)")

    # Everything after the upstream reference forms the suffix.
    suffix_parts = []
    for part in parts[upstream_ref_index + 1:]:
        if part["type"] == "static":
            suffix_parts.append(part["content"])
        elif part["type"] == "dynamic":
            # Further references (when the template has more than one).
            suffix_parts.append(self._resolve_suffix_reference(part, state))
    suffix = "".join(suffix_parts)

    # Full output (prefix + dynamic content + suffix) for the final result.
    full_output = self._render_template(output_template, state)

    if suffix:
        logger.info(f"节点 {self.node_id} 输出后缀: '{suffix[:50]}...' (长度: {len(suffix)})")
        # The suffix is sent through the writer as a single chunk rather
        # than yielded as a string (the original author notes that the base
        # node would otherwise process a yielded string chunk by chunk).
        from langgraph.config import get_stream_writer
        writer = get_stream_writer()
        writer({
            "type": "message",  # End node output uses the "message" type
            "node_id": self.node_id,
            "chunk": suffix,
            "full_content": full_output,  # complete rendered result
            "chunk_index": 1,
            "is_suffix": True
        })
        logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀,full_content 长度: {len(full_output)}")
    else:
        logger.info(f"节点 {self.node_id} 没有后缀需要输出")

    # Statistics for the completion log line.
    node_outputs = state.get("node_outputs", {})
    total_nodes = len(node_outputs)

    logger.info(f"节点 {self.node_id} (End) 执行完成(流式),共执行了 {total_nodes} 个节点")

    # Completion marker carrying the full rendered output.
    yield {"__final__": True, "result": full_output}

def _resolve_suffix_reference(self, part: dict, state: WorkflowState):
    """Look up the current content for one dynamic template part.

    Lookup order: the node's entry in ``streaming_buffer`` (its
    ``full_content``, regardless of ``field``), then ``node_outputs``,
    then ``runtime_vars``. An empty string is returned when the node is
    unknown or the stored value is not a dict.
    """
    node_id = part["node_id"]
    field = part["field"]

    streaming_buffer = state.get("streaming_buffer", {})
    if node_id in streaming_buffer:
        return streaming_buffer[node_id].get("full_content", "")

    node_outputs = state.get("node_outputs", {})
    runtime_vars = state.get("runtime_vars", {})

    # node_outputs wins over runtime_vars; runtime_vars is only consulted
    # when the node has no entry in node_outputs at all.
    if node_id in node_outputs:
        stored = node_outputs[node_id]
    elif node_id in runtime_vars:
        stored = runtime_vars[node_id]
    else:
        return ""

    if isinstance(stored, dict):
        return str(stored.get(field, ""))
    return ""
|
||||
|
||||
@@ -63,7 +63,7 @@ class LLMNode(BaseNode):
|
||||
- ai/assistant: AI 消息(AIMessage)
|
||||
"""
|
||||
|
||||
def _prepare_llm(self, state: WorkflowState) -> tuple[RedBearLLM, list | str]:
|
||||
def _prepare_llm(self, state: WorkflowState,stream:bool = False) -> tuple[RedBearLLM, list | str]:
|
||||
"""准备 LLM 实例(公共逻辑)
|
||||
|
||||
Args:
|
||||
@@ -125,16 +125,22 @@ class LLMNode(BaseNode):
|
||||
model_type = config.type
|
||||
|
||||
# 4. 创建 LLM 实例(使用已提取的数据)
|
||||
# 注意:对于流式输出,需要在模型初始化时设置 streaming=True
|
||||
extra_params = {"streaming": stream} if stream else {}
|
||||
|
||||
llm = RedBearLLM(
|
||||
RedBearModelConfig(
|
||||
model_name=model_name,
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base
|
||||
base_url=api_base,
|
||||
extra_params=extra_params
|
||||
),
|
||||
type=model_type
|
||||
)
|
||||
|
||||
logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}")
|
||||
|
||||
return llm, prompt_or_messages
|
||||
|
||||
async def execute(self, state: WorkflowState) -> AIMessage:
|
||||
@@ -146,13 +152,12 @@ class LLMNode(BaseNode):
|
||||
Returns:
|
||||
LLM 响应消息
|
||||
"""
|
||||
llm, prompt_or_messages = self._prepare_llm(state)
|
||||
llm, prompt_or_messages = self._prepare_llm(state,True)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
|
||||
|
||||
# 调用 LLM(支持字符串或消息列表)
|
||||
response = await llm.ainvoke(prompt_or_messages)
|
||||
|
||||
# 提取内容
|
||||
if hasattr(response, 'content'):
|
||||
content = response.content
|
||||
@@ -208,13 +213,43 @@ class LLMNode(BaseNode):
|
||||
Yields:
|
||||
文本片段(chunk)或完成标记
|
||||
"""
|
||||
llm, prompt_or_messages = self._prepare_llm(state)
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
llm, prompt_or_messages = self._prepare_llm(state, True)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
|
||||
logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
|
||||
|
||||
# 检查是否有注入的 End 节点前缀配置
|
||||
writer = get_stream_writer()
|
||||
end_prefix = getattr(self, '_end_node_prefix', None)
|
||||
|
||||
logger.info(f"[LLM前缀] 节点 {self.node_id} 检查前缀配置: {end_prefix is not None}")
|
||||
if end_prefix:
|
||||
logger.info(f"[LLM前缀] 前缀内容: '{end_prefix}'")
|
||||
|
||||
if end_prefix:
|
||||
# 渲染前缀(可能包含其他变量)
|
||||
try:
|
||||
rendered_prefix = self._render_template(end_prefix, state)
|
||||
logger.info(f"节点 {self.node_id} 提前发送 End 节点前缀: '{rendered_prefix[:50]}...'")
|
||||
|
||||
# 提前发送 End 节点的前缀(使用 "message" 类型)
|
||||
writer({
|
||||
"type": "message", # End 相关的内容都是 message 类型
|
||||
"node_id": "end", # 标记为 end 节点的输出
|
||||
"chunk": rendered_prefix,
|
||||
"full_content": rendered_prefix,
|
||||
"chunk_index": 0,
|
||||
"is_prefix": True # 标记这是前缀
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning(f"渲染/发送 End 节点前缀失败: {e}")
|
||||
|
||||
# 累积完整响应
|
||||
full_response = ""
|
||||
last_chunk = None
|
||||
chunk_count = 0
|
||||
|
||||
# 调用 LLM(流式,支持字符串或消息列表)
|
||||
async for chunk in llm.astream(prompt_or_messages):
|
||||
@@ -224,13 +259,16 @@ class LLMNode(BaseNode):
|
||||
else:
|
||||
content = str(chunk)
|
||||
|
||||
full_response += content
|
||||
last_chunk = chunk
|
||||
|
||||
# 流式返回每个文本片段
|
||||
yield content
|
||||
# 只有当内容不为空时才处理
|
||||
if content:
|
||||
full_response += content
|
||||
last_chunk = chunk
|
||||
chunk_count += 1
|
||||
|
||||
# 流式返回每个文本片段
|
||||
yield content
|
||||
|
||||
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}")
|
||||
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
|
||||
|
||||
# 构建完整的 AIMessage(包含元数据)
|
||||
if isinstance(last_chunk, AIMessage):
|
||||
|
||||
@@ -385,7 +385,7 @@ class LLMRouter:
|
||||
# 获取 API Key 配置
|
||||
api_key_config = self.db.query(ModelApiKey).filter(
|
||||
ModelApiKey.model_config_id == self.routing_model_config.id,
|
||||
ModelApiKey.is_active == True
|
||||
ModelApiKey.is_active
|
||||
).first()
|
||||
|
||||
if not api_key_config:
|
||||
|
||||
@@ -1,29 +1,28 @@
|
||||
"""
|
||||
工作流服务层
|
||||
"""
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
import datetime
|
||||
from typing import Any, Annotated, AsyncGenerator
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.workflow.validator import validate_workflow_config
|
||||
from app.db import get_db
|
||||
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
|
||||
from app.repositories.workflow_repository import (
|
||||
WorkflowConfigRepository,
|
||||
WorkflowExecutionRepository,
|
||||
WorkflowNodeExecutionRepository,
|
||||
get_workflow_config_repository,
|
||||
get_workflow_execution_repository,
|
||||
get_workflow_node_execution_repository
|
||||
WorkflowNodeExecutionRepository
|
||||
)
|
||||
from app.core.workflow.validator import validate_workflow_config
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.db import get_db
|
||||
from app.schemas import DraftRunRequest
|
||||
from app.utils.sse_utils import format_sse_message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -195,8 +194,7 @@ class WorkflowService:
|
||||
config = self.get_workflow_config(app_id)
|
||||
if not config:
|
||||
return False
|
||||
|
||||
self.config_repo.delete(config.id)
|
||||
config.is_active = False
|
||||
logger.info(f"删除工作流配置成功: app_id={app_id}, config_id={config.id}")
|
||||
return True
|
||||
|
||||
@@ -474,11 +472,9 @@ class WorkflowService:
|
||||
}
|
||||
|
||||
# 4. 获取工作空间 ID(从 app 获取)
|
||||
from app.models import App
|
||||
|
||||
|
||||
# 5. 执行工作流
|
||||
from app.core.workflow.executor import execute_workflow, execute_workflow_stream
|
||||
from app.core.workflow.executor import execute_workflow
|
||||
|
||||
try:
|
||||
# 更新状态为运行中
|
||||
@@ -595,19 +591,14 @@ class WorkflowService:
|
||||
}
|
||||
|
||||
# 4. 获取工作空间 ID(从 app 获取)
|
||||
from app.models import App
|
||||
|
||||
# 5. 流式执行工作流
|
||||
from app.core.workflow.executor import execute_workflow, execute_workflow_stream
|
||||
|
||||
try:
|
||||
# 更新状态为运行中
|
||||
self.update_execution_status(execution.execution_id, "running")
|
||||
|
||||
# 发送开始事件
|
||||
yield f"data: {json.dumps({'type': 'workflow_start', 'execution_id': execution.execution_id})}\n\n"
|
||||
|
||||
# 调用流式执行
|
||||
# 调用流式执行(executor 会发送 workflow_start 和 workflow_end 事件)
|
||||
async for event in self._run_workflow_stream(
|
||||
workflow_config=workflow_config_dict,
|
||||
input_data=input_data,
|
||||
@@ -615,13 +606,8 @@ class WorkflowService:
|
||||
workspace_id="",
|
||||
user_id=payload.user_id
|
||||
):
|
||||
# 清理事件数据,移除不可序列化的对象
|
||||
cleaned_event = self._clean_event_for_json(event)
|
||||
# 转换为 SSE 格式
|
||||
yield f"data: {json.dumps(cleaned_event)}\n\n"
|
||||
|
||||
# 发送完成事件
|
||||
yield f"data: {json.dumps({'type': 'workflow_end', 'execution_id': execution.execution_id})}\n\n"
|
||||
# 直接转发 executor 的事件(已经是正确的格式)
|
||||
yield event
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工作流流式执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
|
||||
@@ -631,7 +617,13 @@ class WorkflowService:
|
||||
error_message=str(e)
|
||||
)
|
||||
# 发送错误事件
|
||||
yield f"data: {json.dumps({'type': 'error', 'execution_id': execution.execution_id, 'error': str(e)})}\n\n"
|
||||
yield {
|
||||
"event": "error",
|
||||
"data": {
|
||||
"execution_id": execution.execution_id,
|
||||
"error": str(e)
|
||||
}
|
||||
}
|
||||
|
||||
async def run_workflow(
|
||||
self,
|
||||
@@ -692,7 +684,7 @@ class WorkflowService:
|
||||
)
|
||||
|
||||
# 5. 执行工作流
|
||||
from app.core.workflow.executor import execute_workflow, execute_workflow_stream
|
||||
from app.core.workflow.executor import execute_workflow
|
||||
|
||||
try:
|
||||
# 更新状态为运行中
|
||||
@@ -802,13 +794,11 @@ class WorkflowService:
|
||||
user_id: 用户 ID
|
||||
|
||||
Yields:
|
||||
流式事件
|
||||
流式事件(格式:{"event": "<type>", "data": {...}})
|
||||
"""
|
||||
from app.core.workflow.executor import execute_workflow_stream
|
||||
|
||||
try:
|
||||
output_data = {}
|
||||
|
||||
async for event in execute_workflow_stream(
|
||||
workflow_config=workflow_config,
|
||||
input_data=input_data,
|
||||
@@ -816,31 +806,9 @@ class WorkflowService:
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id
|
||||
):
|
||||
# 转发事件
|
||||
# 直接转发事件(executor 已经返回正确格式)
|
||||
yield event
|
||||
|
||||
# 收集输出数据
|
||||
if event.get("type") == "node_complete":
|
||||
node_data = event.get("data", {})
|
||||
node_outputs = node_data.get("node_outputs", {})
|
||||
output_data.update(node_outputs)
|
||||
|
||||
# 处理完成事件
|
||||
if event.get("type") == "workflow_complete":
|
||||
self.update_execution_status(
|
||||
execution_id,
|
||||
"completed",
|
||||
output_data=output_data
|
||||
)
|
||||
|
||||
# 处理错误事件
|
||||
if event.get("type") == "workflow_error":
|
||||
self.update_execution_status(
|
||||
execution_id,
|
||||
"failed",
|
||||
error_message=event.get("error")
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工作流流式执行失败: execution_id={execution_id}, error={e}", exc_info=True)
|
||||
self.update_execution_status(
|
||||
@@ -849,9 +817,11 @@ class WorkflowService:
|
||||
error_message=str(e)
|
||||
)
|
||||
yield {
|
||||
"type": "workflow_error",
|
||||
"execution_id": execution_id,
|
||||
"error": str(e)
|
||||
"event": "error",
|
||||
"data": {
|
||||
"execution_id": execution_id,
|
||||
"error": str(e)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user