[fix] model support stream
This commit is contained in:
@@ -1,8 +1,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Iterator, AsyncIterator, List, Optional
|
||||||
from langchain_core.callbacks import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun
|
from langchain_core.callbacks import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun
|
||||||
from langchain_core.language_models import BaseLLM
|
from langchain_core.language_models import BaseLLM
|
||||||
from langchain_core.outputs import LLMResult
|
from langchain_core.outputs import LLMResult, GenerationChunk
|
||||||
|
|
||||||
from app.core.models import RedBearModelConfig, RedBearModelFactory, get_provider_llm_class
|
from app.core.models import RedBearModelConfig, RedBearModelFactory, get_provider_llm_class
|
||||||
from app.models.models_model import ModelType
|
from app.models.models_model import ModelType
|
||||||
@@ -10,21 +10,36 @@ from app.models.models_model import ModelType
|
|||||||
|
|
||||||
class RedBearLLM(BaseLLM):
|
class RedBearLLM(BaseLLM):
|
||||||
"""
|
"""
|
||||||
RedBear LLM 模型包装器 - 完全动态代理实现
|
RedBear LLM Model Wrapper
|
||||||
|
|
||||||
这个包装器自动将所有方法调用委托给内部模型,
|
This wrapper provides a unified interface to access different LLM providers,
|
||||||
同时提供优雅的回退机制和错误处理。
|
while maintaining all LangChain functionality, including streaming output.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Support for multiple LLM providers (OpenAI, Qwen, Ollama, etc.)
|
||||||
|
- Full streaming output support
|
||||||
|
- Elegant error handling and fallback mechanism
|
||||||
|
- Automatic proxying of all underlying model methods and attributes
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config: RedBearModelConfig, type: ModelType=ModelType.LLM):
|
def __init__(self, config: RedBearModelConfig, type: ModelType = ModelType.LLM):
|
||||||
self._model = self._create_model(config, type)
|
"""Initialize RedBear LLM wrapper
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Model configuration
|
||||||
|
type: Model type (LLM or CHAT)
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
self._config = config
|
self._config = config
|
||||||
|
self._model = self._create_model(config, type)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _llm_type(self) -> str:
|
def _llm_type(self) -> str:
|
||||||
"""返回LLM类型标识符"""
|
"""Return LLM type identifier"""
|
||||||
return self._model._llm_type
|
return getattr(self._model, '_llm_type', 'redbear_llm')
|
||||||
|
|
||||||
|
# ==================== Core Methods (Required by BaseLLM) ====================
|
||||||
|
|
||||||
def _generate(
|
def _generate(
|
||||||
self,
|
self,
|
||||||
prompts: List[str],
|
prompts: List[str],
|
||||||
@@ -32,7 +47,7 @@ class RedBearLLM(BaseLLM):
|
|||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
**kwargs: Any
|
**kwargs: Any
|
||||||
) -> LLMResult:
|
) -> LLMResult:
|
||||||
"""同步生成文本"""
|
"""Synchronous text generation (required by BaseLLM)"""
|
||||||
return self._model._generate(prompts, stop=stop, run_manager=run_manager, **kwargs)
|
return self._model._generate(prompts, stop=stop, run_manager=run_manager, **kwargs)
|
||||||
|
|
||||||
async def _agenerate(
|
async def _agenerate(
|
||||||
@@ -42,92 +57,233 @@ class RedBearLLM(BaseLLM):
|
|||||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||||
**kwargs: Any
|
**kwargs: Any
|
||||||
) -> LLMResult:
|
) -> LLMResult:
|
||||||
"""异步生成文本"""
|
"""Asynchronous text generation (required by BaseLLM)"""
|
||||||
return await self._model._agenerate(prompts, stop=stop, run_manager=run_manager, **kwargs)
|
return await self._model._agenerate(prompts, stop=stop, run_manager=run_manager, **kwargs)
|
||||||
|
|
||||||
# 关键:覆盖 invoke/ainvoke,直接委托到底层模型,避免 BaseLLM 的字符串化行为
|
# ==================== Advanced Methods (Support Message Lists) ====================
|
||||||
|
|
||||||
def invoke(self, input: Any, config: Optional[dict] = None, **kwargs: Any) -> Any:
|
def invoke(self, input: Any, config: Optional[dict] = None, **kwargs: Any) -> Any:
|
||||||
"""直接调用底层模型以支持 ChatPrompt 和消息列表。"""
|
"""Synchronous model invocation
|
||||||
|
|
||||||
|
Supports various input formats including strings and message lists.
|
||||||
|
Directly delegates to the underlying model to avoid BaseLLM's string conversion.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input: Input (string, message list, etc.)
|
||||||
|
config: Runtime configuration
|
||||||
|
**kwargs: Additional arguments
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Model response
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
return self._model.invoke(input, config=config, **kwargs)
|
return self._model.invoke(input, config=config, **kwargs)
|
||||||
except AttributeError as e:
|
except AttributeError as e:
|
||||||
# 只在属性错误时回退(说明底层模型不支持该方法)
|
|
||||||
if 'invoke' in str(e):
|
if 'invoke' in str(e):
|
||||||
|
# Underlying model doesn't support invoke, fallback to parent implementation
|
||||||
return super().invoke(input, config=config, **kwargs)
|
return super().invoke(input, config=config, **kwargs)
|
||||||
# 其他 AttributeError 直接抛出
|
|
||||||
raise
|
raise
|
||||||
except Exception:
|
except Exception:
|
||||||
# 其他所有异常(包括 ValidationException)直接抛出,不回退
|
# Other exceptions are raised directly
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def ainvoke(self, input: Any, config: Optional[dict] = None, **kwargs: Any) -> Any:
|
async def ainvoke(self, input: Any, config: Optional[dict] = None, **kwargs: Any) -> Any:
|
||||||
"""异步直接调用底层模型以支持 ChatPrompt 和消息列表。"""
|
"""Asynchronous model invocation
|
||||||
|
|
||||||
|
Supports various input formats including strings and message lists.
|
||||||
|
Directly delegates to the underlying model to avoid BaseLLM's string conversion.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input: Input (string, message list, etc.)
|
||||||
|
config: Runtime configuration
|
||||||
|
**kwargs: Additional arguments
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Model response
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
return await self._model.ainvoke(input, config=config, **kwargs)
|
return await self._model.ainvoke(input, config=config, **kwargs)
|
||||||
except AttributeError as e:
|
except AttributeError as e:
|
||||||
# 只在属性错误时回退(说明底层模型不支持该方法)
|
|
||||||
if 'ainvoke' in str(e):
|
if 'ainvoke' in str(e):
|
||||||
|
# Underlying model doesn't support ainvoke, fallback to parent implementation
|
||||||
return await super().ainvoke(input, config=config, **kwargs)
|
return await super().ainvoke(input, config=config, **kwargs)
|
||||||
# 其他 AttributeError 直接抛出
|
|
||||||
raise
|
raise
|
||||||
except Exception:
|
except Exception:
|
||||||
# 其他所有异常(包括 ValidationException)直接抛出,不回退
|
# Other exceptions are raised directly
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def __getattr__(self, name):
|
# ==================== Streaming Methods (Critical) ====================
|
||||||
"""
|
|
||||||
动态代理:将所有未定义的属性和方法调用委托给内部模型
|
def stream(
|
||||||
|
self,
|
||||||
|
input: Any,
|
||||||
|
config: Optional[dict] = None,
|
||||||
|
*,
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> Iterator[GenerationChunk]:
|
||||||
|
"""Synchronous streaming model invocation
|
||||||
|
|
||||||
这是最优雅的包装器实现方式,完全避免了方法重复定义
|
Args:
|
||||||
"""
|
input: Input (string, message list, etc.)
|
||||||
# 处理特殊属性以避免递归
|
config: Runtime configuration
|
||||||
if name in ('__isabstractmethod__', '__dict__', '__class__'):
|
stop: List of stop words
|
||||||
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
|
**kwargs: Additional arguments
|
||||||
|
|
||||||
# 检查内部模型是否有该属性(使用安全的方式避免递归)
|
Yields:
|
||||||
|
GenerationChunk: Generated text chunks
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
yield from self._model.stream(input, config=config, stop=stop, **kwargs)
|
||||||
|
except AttributeError as e:
|
||||||
|
if 'stream' in str(e):
|
||||||
|
# Underlying model doesn't support stream, fallback to parent implementation
|
||||||
|
yield from super().stream(input, config=config, stop=stop, **kwargs)
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def astream(
|
||||||
|
self,
|
||||||
|
input: Any,
|
||||||
|
config: Optional[dict] = None,
|
||||||
|
*,
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> AsyncIterator[GenerationChunk]:
|
||||||
|
"""Asynchronous streaming model invocation
|
||||||
|
|
||||||
|
This is the core method for streaming output. It directly proxies to the
|
||||||
|
underlying model's astream method, maintaining generator characteristics
|
||||||
|
to ensure each chunk is delivered in real-time.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input: Input (string, message list, etc.)
|
||||||
|
config: Runtime configuration
|
||||||
|
stop: List of stop words
|
||||||
|
**kwargs: Additional arguments
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
GenerationChunk: Generated text chunks
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
async for chunk in self._model.astream(input, config=config, stop=stop, **kwargs):
|
||||||
|
yield chunk
|
||||||
|
except AttributeError as e:
|
||||||
|
if 'astream' in str(e):
|
||||||
|
# Underlying model doesn't support astream, fallback to parent implementation
|
||||||
|
async for chunk in super().astream(input, config=config, stop=stop, **kwargs):
|
||||||
|
yield chunk
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
raise
|
||||||
|
|
||||||
|
# ==================== Dynamic Proxy ====================
|
||||||
|
|
||||||
|
def __getattr__(self, name: str) -> Any:
|
||||||
|
"""Dynamic proxy: delegate undefined attributes and method calls to internal model
|
||||||
|
|
||||||
|
This method allows RedBearLLM to transparently access all attributes and methods
|
||||||
|
of the underlying model without explicitly defining each one.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Attribute or method name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Attribute value or method
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AttributeError: If attribute doesn't exist
|
||||||
|
"""
|
||||||
|
# Avoid recursion: raise error directly for special attributes
|
||||||
|
if name in ('__isabstractmethod__', '__dict__', '__class__', '_model', '_config'):
|
||||||
|
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
|
||||||
|
|
||||||
|
# Try to get attribute from internal model
|
||||||
try:
|
try:
|
||||||
# 使用 object.__getattribute__ 来安全地检查内部模型的属性
|
|
||||||
attr = object.__getattribute__(self._model, name)
|
attr = object.__getattribute__(self._model, name)
|
||||||
|
|
||||||
# 如果是方法,返回一个包装器来处理调用
|
# If it's callable (a method)
|
||||||
if callable(attr):
|
if callable(attr):
|
||||||
# 流式方法直接返回,不包装(保持生成器特性)
|
# Streaming methods are returned directly to maintain generator characteristics
|
||||||
if name in ('_stream', '_astream', 'stream', 'astream'):
|
# Note: Although we've explicitly implemented stream/astream,
|
||||||
|
# this is kept to handle internal methods like _stream/_astream
|
||||||
|
if name in ('_stream', '_astream'):
|
||||||
return attr
|
return attr
|
||||||
|
|
||||||
# 非流式方法使用包装器处理异常
|
# Wrap other methods for easier debugging and error handling
|
||||||
def method_wrapper(*args, **kwargs):
|
def method_wrapper(*args, **kwargs):
|
||||||
return attr(*args, **kwargs)
|
try:
|
||||||
|
return attr(*args, **kwargs)
|
||||||
|
except Exception:
|
||||||
|
# Can add logging or error handling here
|
||||||
|
raise
|
||||||
|
|
||||||
# 保持方法的元信息
|
# Preserve method metadata
|
||||||
method_wrapper.__name__ = name
|
method_wrapper.__name__ = name
|
||||||
method_wrapper.__doc__ = getattr(attr, '__doc__', f"Delegated method: {name}")
|
method_wrapper.__doc__ = getattr(attr, '__doc__', f"Delegated method: {name}")
|
||||||
return method_wrapper
|
return method_wrapper
|
||||||
|
|
||||||
# 如果是普通属性,直接返回
|
# If it's a regular attribute, return directly
|
||||||
return attr
|
return attr
|
||||||
|
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
# 内部模型没有该属性,尝试回退实现
|
# Internal model doesn't have this attribute either
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# 检查是否有回退方法(使用安全的方式避免递归)
|
# Check if there's a fallback method
|
||||||
fallback_name = f'_fallback_{name}'
|
fallback_name = f'_fallback_{name}'
|
||||||
try:
|
try:
|
||||||
fallback_method = object.__getattribute__(self, fallback_name)
|
return object.__getattribute__(self, fallback_name)
|
||||||
return fallback_method
|
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
# 没有回退方法,抛出适当的错误
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# 如果都没有,抛出适当的错误
|
# Nothing found, raise error
|
||||||
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
|
raise AttributeError(
|
||||||
|
f"'{type(self).__name__}' object has no attribute '{name}'. "
|
||||||
|
f"The underlying model '{type(self._model).__name__}' also doesn't have this attribute."
|
||||||
|
)
|
||||||
|
|
||||||
|
# ==================== Helper Methods ====================
|
||||||
|
|
||||||
def _create_model(self, config: RedBearModelConfig, type: ModelType) -> BaseLLM:
|
def _create_model(self, config: RedBearModelConfig, type: ModelType) -> BaseLLM:
|
||||||
"""创建内部模型实例"""
|
"""Create internal model instance
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Model configuration
|
||||||
|
type: Model type
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Created model instance
|
||||||
|
"""
|
||||||
llm_class = get_provider_llm_class(config, type)
|
llm_class = get_provider_llm_class(config, type)
|
||||||
model_params = RedBearModelFactory.get_model_params(config)
|
model_params = RedBearModelFactory.get_model_params(config)
|
||||||
return llm_class(**model_params)
|
return llm_class(**model_params)
|
||||||
|
|
||||||
|
def get_config(self) -> RedBearModelConfig:
|
||||||
|
"""Get model configuration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Model configuration object
|
||||||
|
"""
|
||||||
|
return self._config
|
||||||
|
|
||||||
|
def get_underlying_model(self) -> BaseLLM:
|
||||||
|
"""Get underlying model instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Underlying model instance
|
||||||
|
"""
|
||||||
|
return self._model
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
"""Return string representation of the object"""
|
||||||
|
return (
|
||||||
|
f"RedBearLLM("
|
||||||
|
f"provider={self._config.provider}, "
|
||||||
|
f"model={self._config.model_name}, "
|
||||||
|
f"type={type(self._model).__name__}"
|
||||||
|
f")"
|
||||||
|
)
|
||||||
@@ -125,17 +125,22 @@ class WorkflowExecutor:
|
|||||||
if stream:
|
if stream:
|
||||||
# 流式模式:创建 async generator 函数
|
# 流式模式:创建 async generator 函数
|
||||||
# LangGraph 会收集所有 yield 的值,最后一个 yield 的字典会被合并到 state
|
# LangGraph 会收集所有 yield 的值,最后一个 yield 的字典会被合并到 state
|
||||||
async def node_func(state: WorkflowState, inst=node_instance):
|
def make_stream_func(inst):
|
||||||
async for item in inst.run_stream(state):
|
async def node_func(state: WorkflowState):
|
||||||
yield item
|
# logger.debug(f"流式执行节点: {inst.node_id}, 支持流式: {inst.supports_streaming()}")
|
||||||
workflow.add_node(node_id, node_func)
|
async for item in inst.run_stream(state):
|
||||||
|
yield item
|
||||||
|
return node_func
|
||||||
|
workflow.add_node(node_id, make_stream_func(node_instance))
|
||||||
else:
|
else:
|
||||||
# 非流式模式:创建 async function
|
# 非流式模式:创建 async function
|
||||||
async def node_func(state: WorkflowState, inst=node_instance):
|
def make_func(inst):
|
||||||
return await inst.run(state)
|
async def node_func(state: WorkflowState):
|
||||||
workflow.add_node(node_id, node_func)
|
return await inst.run(state)
|
||||||
|
return node_func
|
||||||
|
workflow.add_node(node_id, make_func(node_instance))
|
||||||
|
|
||||||
logger.debug(f"添加节点: {node_id} (type={node_type})")
|
logger.debug(f"添加节点: {node_id} (type={node_type}, stream={stream})")
|
||||||
|
|
||||||
# 3. 添加边
|
# 3. 添加边
|
||||||
# 从 START 连接到 start 节点
|
# 从 START 连接到 start 节点
|
||||||
@@ -283,13 +288,9 @@ class WorkflowExecutor:
|
|||||||
):
|
):
|
||||||
"""执行工作流(流式)
|
"""执行工作流(流式)
|
||||||
|
|
||||||
使用 stream_mode="updates" 来获取每个节点的 state 更新。
|
使用多个 stream_mode 来获取:
|
||||||
节点的 generator 会 yield 多个值:
|
1. "updates" - 节点的 state 更新和流式 chunk
|
||||||
- 中间的 chunk 事件(带 type="chunk")
|
2. "debug" - 节点执行的详细信息(开始/完成时间)
|
||||||
- 最后的 state 更新(纯字典,包含 node_outputs 等)
|
|
||||||
|
|
||||||
LangGraph 会将所有 yield 的值收集起来,并将它们合并到 state 中。
|
|
||||||
我们需要过滤出 chunk 事件并转发,同时确保 state 更新被正确处理。
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_data: 输入数据
|
input_data: 输入数据
|
||||||
@@ -297,7 +298,7 @@ class WorkflowExecutor:
|
|||||||
Yields:
|
Yields:
|
||||||
流式事件
|
流式事件
|
||||||
"""
|
"""
|
||||||
logger.info(f"开始执行工作流: execution_id={self.execution_id}")
|
logger.info(f"开始执行工作流(流式): execution_id={self.execution_id}")
|
||||||
|
|
||||||
# 记录开始时间
|
# 记录开始时间
|
||||||
start_time = datetime.datetime.now()
|
start_time = datetime.datetime.now()
|
||||||
@@ -310,34 +311,73 @@ class WorkflowExecutor:
|
|||||||
|
|
||||||
# 3. 执行工作流
|
# 3. 执行工作流
|
||||||
try:
|
try:
|
||||||
async for mode, event in graph.astream(
|
chunk_count = 0
|
||||||
|
async for event in graph.astream(
|
||||||
initial_state,
|
initial_state,
|
||||||
stream_mode=["updates","messages"],
|
stream_mode=["updates", "debug"],
|
||||||
):
|
):
|
||||||
# print("刚才跑的节点:", event[0])
|
mode, data = event
|
||||||
# # 通过图结构就能算出“接下来是谁”
|
|
||||||
# print("接下来可能跑:", graph.get_next(event[0]))
|
|
||||||
# print("="*50)
|
|
||||||
# # print("mode",mode)
|
|
||||||
# print("event",event)
|
|
||||||
# print("="*50)
|
|
||||||
# event 是一个字典,key 是节点 ID,value 是 state 更新或 chunk
|
|
||||||
for node_id, update in event.items():
|
|
||||||
print("="*50)
|
|
||||||
print("node_id",node_id)
|
|
||||||
print("update",update)
|
|
||||||
|
|
||||||
print("="*50)
|
if mode == "debug":
|
||||||
if isinstance(update, dict) and update.get("type") == "chunk":
|
# 处理调试信息(节点执行状态)
|
||||||
# 这是流式 chunk,转发给客户端
|
event_type = data.get("type")
|
||||||
|
payload = data.get("payload", {})
|
||||||
|
node_name = payload.get("name")
|
||||||
|
|
||||||
|
if event_type == "task":
|
||||||
|
# 节点开始执行
|
||||||
|
inputv = payload.get("input", {})
|
||||||
|
variables = inputv.get("variables", {})
|
||||||
|
variables_sys = variables.get("sys", {})
|
||||||
|
conversation_id = variables_sys.get("conversation_id")
|
||||||
|
execution_id = variables_sys.get("execution_id")
|
||||||
|
logger.info(f"[DEBUG] 节点开始执行: {node_name}")
|
||||||
yield {
|
yield {
|
||||||
"type": "node_chunk",
|
"type": "node_start",
|
||||||
"node_id": update.get("node_id"),
|
"node_id": node_name,
|
||||||
"chunk": update.get("content")
|
"conversation_id": conversation_id,
|
||||||
|
"execution_id": execution_id,
|
||||||
|
"timestamp": data.get("timestamp")
|
||||||
}
|
}
|
||||||
# 其他情况(state 更新)会被 LangGraph 自动合并到 state,不需要我们处理
|
elif event_type == "task_result":
|
||||||
print(event)
|
# 节点执行完成
|
||||||
yield event
|
result = payload.get("result", {})
|
||||||
|
inputv = result.get("input", {})
|
||||||
|
variables = inputv.get("variables", {})
|
||||||
|
variables_sys = variables.get("sys", {})
|
||||||
|
conversation_id = variables_sys.get("conversation_id")
|
||||||
|
execution_id = variables_sys.get("execution_id")
|
||||||
|
logger.info(f"[DEBUG] 节点执行完成: {node_name}")
|
||||||
|
yield {
|
||||||
|
"type": "node_end",
|
||||||
|
"node_id": node_name,
|
||||||
|
"conversation_id": conversation_id,
|
||||||
|
"execution_id": execution_id,
|
||||||
|
"timestamp": data.get("timestamp")
|
||||||
|
}
|
||||||
|
|
||||||
|
elif mode == "updates":
|
||||||
|
# 处理 state 更新
|
||||||
|
# data 是一个字典,key 是节点 ID,value 是 state 更新或 chunk
|
||||||
|
print("="*50)
|
||||||
|
print(data)
|
||||||
|
print("-"*50)
|
||||||
|
for node_id, update in data.items():
|
||||||
|
if isinstance(update, dict) and update.get("type") == "chunk":
|
||||||
|
# 这是流式 chunk,转发给客户端
|
||||||
|
chunk_count += 1
|
||||||
|
logger.debug(f"[UPDATE] 收到 chunk #{chunk_count} from {node_id}: {update.get('content')[:50]}...")
|
||||||
|
yield {
|
||||||
|
"type": "node_chunk",
|
||||||
|
"node_id": update.get("node_id"),
|
||||||
|
"chunk": update.get("content"),
|
||||||
|
"full_content": update.get("full_content")
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
logger.debug(f"[UPDATE] 收到 state 更新 from {node_id}")
|
||||||
|
# 其他情况(state 更新)会被 LangGraph 自动合并到 state
|
||||||
|
|
||||||
|
logger.info(f"工作流执行完成(流式),总 chunks: {chunk_count}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# 计算耗时(即使失败也记录)
|
# 计算耗时(即使失败也记录)
|
||||||
|
|||||||
@@ -245,6 +245,9 @@ class BaseNode(ABC):
|
|||||||
final_result = item["result"]
|
final_result = item["result"]
|
||||||
elif isinstance(item, str):
|
elif isinstance(item, str):
|
||||||
# 字符串是 chunk
|
# 字符串是 chunk
|
||||||
|
# print("="*50)
|
||||||
|
# print(item)
|
||||||
|
# print("-"*50)
|
||||||
chunks.append(item)
|
chunks.append(item)
|
||||||
yield {
|
yield {
|
||||||
"type": "chunk",
|
"type": "chunk",
|
||||||
|
|||||||
@@ -30,11 +30,7 @@ class EndNode(BaseNode):
|
|||||||
|
|
||||||
# 获取配置的输出模板
|
# 获取配置的输出模板
|
||||||
output_template = self.config.get("output")
|
output_template = self.config.get("output")
|
||||||
# pool = self.get_variable_pool(state)
|
|
||||||
|
|
||||||
# print("="*20)
|
|
||||||
# print( pool.get("start.test"))
|
|
||||||
# print("="*20)
|
|
||||||
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
|
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
|
||||||
if output_template:
|
if output_template:
|
||||||
output = self._render_template(output_template, state)
|
output = self._render_template(output_template, state)
|
||||||
@@ -46,7 +42,45 @@ class EndNode(BaseNode):
|
|||||||
total_nodes = len(node_outputs)
|
total_nodes = len(node_outputs)
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
|
logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
|
||||||
print("="*20)
|
|
||||||
print(output)
|
|
||||||
print("="*20)
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
async def execute_stream(self, state: WorkflowState):
|
||||||
|
"""流式执行 end 节点业务逻辑
|
||||||
|
|
||||||
|
当 end 节点前面是 LLM 节点时,流式输出其内容。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: 工作流状态
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
文本片段(chunk)或完成标记
|
||||||
|
"""
|
||||||
|
logger.info(f"节点 {self.node_id} (End) 开始执行(流式)")
|
||||||
|
|
||||||
|
# 获取配置的输出模板
|
||||||
|
output_template = self.config.get("output")
|
||||||
|
|
||||||
|
# 如果配置了输出模板,使用模板渲染
|
||||||
|
if output_template:
|
||||||
|
output = self._render_template(output_template, state)
|
||||||
|
|
||||||
|
# 检查输出中是否包含节点引用(如 {{llm_node.output}})
|
||||||
|
# 如果包含,则逐字符流式输出
|
||||||
|
if output:
|
||||||
|
# 逐字符流式输出
|
||||||
|
for char in output:
|
||||||
|
yield char
|
||||||
|
else:
|
||||||
|
output = "工作流已完成"
|
||||||
|
for char in output:
|
||||||
|
yield char
|
||||||
|
|
||||||
|
# 统计信息(用于日志)
|
||||||
|
node_outputs = state.get("node_outputs", {})
|
||||||
|
total_nodes = len(node_outputs)
|
||||||
|
|
||||||
|
logger.info(f"节点 {self.node_id} (End) 执行完成(流式),共执行 {total_nodes} 个节点")
|
||||||
|
|
||||||
|
# yield 完成标记
|
||||||
|
yield {"__final__": True, "result": output}
|
||||||
|
|||||||
@@ -125,19 +125,22 @@ class LLMNode(BaseNode):
|
|||||||
model_type = config.type
|
model_type = config.type
|
||||||
|
|
||||||
# 4. 创建 LLM 实例(使用已提取的数据)
|
# 4. 创建 LLM 实例(使用已提取的数据)
|
||||||
print("="*50)
|
# 注意:对于流式输出,需要在模型初始化时设置 streaming=True
|
||||||
print("stream",stream)
|
extra_params = {"streaming": stream} if stream else {}
|
||||||
print("="*50)
|
|
||||||
llm = RedBearLLM(
|
llm = RedBearLLM(
|
||||||
RedBearModelConfig(
|
RedBearModelConfig(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
base_url=api_base,
|
base_url=api_base,
|
||||||
extra_params={"streaming": stream}
|
extra_params=extra_params
|
||||||
),
|
),
|
||||||
type=model_type
|
type=model_type
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}")
|
||||||
|
|
||||||
return llm, prompt_or_messages
|
return llm, prompt_or_messages
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState) -> AIMessage:
|
async def execute(self, state: WorkflowState) -> AIMessage:
|
||||||
@@ -201,47 +204,54 @@ class LLMNode(BaseNode):
|
|||||||
}
|
}
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# async def execute_stream(self, state: WorkflowState):
|
async def execute_stream(self, state: WorkflowState):
|
||||||
# """流式执行 LLM 调用
|
"""流式执行 LLM 调用
|
||||||
|
|
||||||
# Args:
|
Args:
|
||||||
# state: 工作流状态
|
state: 工作流状态
|
||||||
|
|
||||||
# Yields:
|
Yields:
|
||||||
# 文本片段(chunk)或完成标记
|
文本片段(chunk)或完成标记
|
||||||
# """
|
"""
|
||||||
# llm, prompt_or_messages = self._prepare_llm(state,True)
|
llm, prompt_or_messages = self._prepare_llm(state, True)
|
||||||
|
|
||||||
# logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
|
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
|
||||||
|
logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
|
||||||
|
|
||||||
# # 累积完整响应
|
# 累积完整响应
|
||||||
# full_response = ""
|
full_response = ""
|
||||||
# last_chunk = None
|
last_chunk = None
|
||||||
|
chunk_count = 0
|
||||||
|
|
||||||
# # 调用 LLM(流式,支持字符串或消息列表)
|
# 调用 LLM(流式,支持字符串或消息列表)
|
||||||
# async for chunk in llm.astream(prompt_or_messages):
|
# 注意:astream 方法本身就是流式的,不需要额外配置
|
||||||
# # 提取内容
|
async for chunk in llm.astream(prompt_or_messages):
|
||||||
# if hasattr(chunk, 'content'):
|
# 提取内容
|
||||||
# content = chunk.content
|
if hasattr(chunk, 'content'):
|
||||||
# else:
|
content = chunk.content
|
||||||
# content = str(chunk)
|
else:
|
||||||
|
content = str(chunk)
|
||||||
|
|
||||||
# full_response += content
|
# 只有当内容不为空时才处理
|
||||||
# last_chunk = chunk
|
if content:
|
||||||
# logger.info(f"节点 {self.node_id} LLM : {content}")
|
full_response += content
|
||||||
# # 流式返回每个文本片段
|
last_chunk = chunk
|
||||||
# yield content
|
chunk_count += 1
|
||||||
|
|
||||||
|
# logger.debug(f"节点 {self.node_id} LLM chunk #{chunk_count}: {content[:50]}...")
|
||||||
|
# 流式返回每个文本片段
|
||||||
|
yield content #AIMessage(content=content)
|
||||||
|
|
||||||
# logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}")
|
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
|
||||||
|
|
||||||
# # 构建完整的 AIMessage(包含元数据)
|
# 构建完整的 AIMessage(包含元数据)
|
||||||
# if isinstance(last_chunk, AIMessage):
|
if isinstance(last_chunk, AIMessage):
|
||||||
# final_message = AIMessage(
|
final_message = AIMessage(
|
||||||
# content=full_response,
|
content=full_response,
|
||||||
# response_metadata=last_chunk.response_metadata if hasattr(last_chunk, 'response_metadata') else {}
|
response_metadata=last_chunk.response_metadata if hasattr(last_chunk, 'response_metadata') else {}
|
||||||
# )
|
)
|
||||||
# else:
|
else:
|
||||||
# final_message = AIMessage(content=full_response)
|
final_message = AIMessage(content=full_response)
|
||||||
|
|
||||||
# # yield 完成标记
|
# yield 完成标记
|
||||||
# yield {"__final__": True, "result": final_message}
|
yield {"__final__": True, "result": final_message}
|
||||||
|
|||||||
Reference in New Issue
Block a user