[fix] model support stream

This commit is contained in:
Mark
2025-12-20 16:03:41 +08:00
parent 6c04c99073
commit d8fcea8564
5 changed files with 377 additions and 134 deletions

View File

@@ -1,8 +1,8 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional
from typing import Any, Iterator, AsyncIterator, List, Optional
from langchain_core.callbacks import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun
from langchain_core.language_models import BaseLLM
from langchain_core.outputs import LLMResult
from langchain_core.outputs import LLMResult, GenerationChunk
from app.core.models import RedBearModelConfig, RedBearModelFactory, get_provider_llm_class
from app.models.models_model import ModelType
@@ -10,21 +10,36 @@ from app.models.models_model import ModelType
class RedBearLLM(BaseLLM):
"""
RedBear LLM 模型包装器 - 完全动态代理实现
RedBear LLM Model Wrapper
这个包装器自动将所有方法调用委托给内部模型,
同时提供优雅的回退机制和错误处理。
This wrapper provides a unified interface to access different LLM providers,
while maintaining all LangChain functionality, including streaming output.
Features:
- Support for multiple LLM providers (OpenAI, Qwen, Ollama, etc.)
- Full streaming output support
- Elegant error handling and fallback mechanism
- Automatic proxying of all underlying model methods and attributes
"""
def __init__(self, config: RedBearModelConfig, type: ModelType=ModelType.LLM):
self._model = self._create_model(config, type)
def __init__(self, config: RedBearModelConfig, type: ModelType = ModelType.LLM):
    """Set up the wrapper and build the provider model it delegates to.

    Args:
        config: Model configuration handed to ``_create_model``.
        type: Model type to create; defaults to ``ModelType.LLM``.
    """
    super().__init__()
    # Keep the configuration on the instance before building the model.
    self._config = config
    wrapped_model = self._create_model(config, type)
    self._model = wrapped_model
@property
def _llm_type(self) -> str:
    """Identifier of the LLM type, taken from the wrapped model.

    Returns:
        The wrapped model's ``_llm_type`` when it exposes one, otherwise
        the literal ``'redbear_llm'``.
    """
    wrapped = self._model
    try:
        return wrapped._llm_type
    except AttributeError:
        # The wrapped model has no type identifier of its own.
        return 'redbear_llm'
# ==================== Core Methods (Required by BaseLLM) ====================
def _generate(
self,
prompts: List[str],
@@ -32,7 +47,7 @@ class RedBearLLM(BaseLLM):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any
) -> LLMResult:
"""同步生成文本"""
"""Synchronous text generation (required by BaseLLM)"""
return self._model._generate(prompts, stop=stop, run_manager=run_manager, **kwargs)
async def _agenerate(
@@ -42,92 +57,233 @@ class RedBearLLM(BaseLLM):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any
) -> LLMResult:
"""异步生成文本"""
"""Asynchronous text generation (required by BaseLLM)"""
return await self._model._agenerate(prompts, stop=stop, run_manager=run_manager, **kwargs)
# 关键:覆盖 invoke/ainvoke直接委托到底层模型避免 BaseLLM 的字符串化行为
# ==================== Advanced Methods (Support Message Lists) ====================
def invoke(self, input: Any, config: Optional[dict] = None, **kwargs: Any) -> Any:
    """Synchronously invoke the wrapped model.

    The call is delegated straight to the wrapped model's ``invoke`` so the
    input (string, message list, chat prompt, ...) reaches it unchanged
    instead of going through the parent class implementation.

    Args:
        input: Input passed through to the wrapped model.
        config: Runtime configuration passed through as ``config=``.
        **kwargs: Additional keyword arguments passed through unchanged.

    Returns:
        Whatever the wrapped model's ``invoke`` returns, or the parent
        implementation's result when the wrapped model has no ``invoke``.

    Raises:
        Any exception raised by the wrapped model propagates unchanged.
    """
    # Decide on the fallback by looking the method up, not by inspecting
    # exception text: an AttributeError raised *inside* the wrapped model
    # must surface to the caller rather than silently trigger the fallback.
    delegate = getattr(self._model, 'invoke', None)
    if delegate is None:
        return super().invoke(input, config=config, **kwargs)
    return delegate(input, config=config, **kwargs)
async def ainvoke(self, input: Any, config: Optional[dict] = None, **kwargs: Any) -> Any:
    """Asynchronously invoke the wrapped model.

    The call is delegated straight to the wrapped model's ``ainvoke`` so the
    input (string, message list, chat prompt, ...) reaches it unchanged
    instead of going through the parent class implementation.

    Args:
        input: Input passed through to the wrapped model.
        config: Runtime configuration passed through as ``config=``.
        **kwargs: Additional keyword arguments passed through unchanged.

    Returns:
        Whatever the wrapped model's ``ainvoke`` returns, or the parent
        implementation's result when the wrapped model has no ``ainvoke``.

    Raises:
        Any exception raised by the wrapped model propagates unchanged.
    """
    # Decide on the fallback by looking the method up, not by inspecting
    # exception text: an AttributeError raised *inside* the wrapped model
    # must surface to the caller rather than silently trigger the fallback.
    delegate = getattr(self._model, 'ainvoke', None)
    if delegate is None:
        return await super().ainvoke(input, config=config, **kwargs)
    return await delegate(input, config=config, **kwargs)
def __getattr__(self, name):
"""
动态代理:将所有未定义的属性和方法调用委托给内部模型
# ==================== Streaming Methods (Critical) ====================
def stream(
    self,
    input: Any,
    config: Optional[dict] = None,
    *,
    stop: Optional[List[str]] = None,
    **kwargs: Any
) -> Iterator[GenerationChunk]:
    """Synchronously stream output from the wrapped model.

    Chunks are yielded one by one exactly as the wrapped model's ``stream``
    produces them.

    Args:
        input: Input passed through to the wrapped model.
        config: Runtime configuration passed through as ``config=``.
        stop: Stop words passed through as ``stop=``.
        **kwargs: Additional keyword arguments passed through unchanged.

    Yields:
        Each chunk produced by the wrapped model's ``stream`` (or by the
        parent implementation when the wrapped model has no ``stream``).

    Raises:
        Any exception raised while streaming propagates unchanged.
    """
    # Choose the fallback up front by attribute lookup. Matching on the
    # exception text while iterating would (a) also match unrelated names
    # such as "astream"/"streaming" and (b) restart generation through the
    # parent after chunks were already yielded, duplicating output.
    delegate = getattr(self._model, 'stream', None)
    if delegate is None:
        yield from super().stream(input, config=config, stop=stop, **kwargs)
        return
    yield from delegate(input, config=config, stop=stop, **kwargs)
async def astream(
    self,
    input: Any,
    config: Optional[dict] = None,
    *,
    stop: Optional[List[str]] = None,
    **kwargs: Any
) -> AsyncIterator[GenerationChunk]:
    """Asynchronously stream output from the wrapped model.

    This stays an async generator so every chunk from the wrapped model's
    ``astream`` is handed to the caller as soon as it is produced.

    Args:
        input: Input passed through to the wrapped model.
        config: Runtime configuration passed through as ``config=``.
        stop: Stop words passed through as ``stop=``.
        **kwargs: Additional keyword arguments passed through unchanged.

    Yields:
        Each chunk produced by the wrapped model's ``astream`` (or by the
        parent implementation when the wrapped model has no ``astream``).

    Raises:
        Any exception raised while streaming propagates unchanged.
    """
    # Choose the fallback up front by attribute lookup. Matching on the
    # exception text while iterating could restart generation through the
    # parent after chunks were already yielded, duplicating output.
    delegate = getattr(self._model, 'astream', None)
    if delegate is None:
        async for chunk in super().astream(input, config=config, stop=stop, **kwargs):
            yield chunk
        return
    async for chunk in delegate(input, config=config, stop=stop, **kwargs):
        yield chunk
# ==================== Dynamic Proxy ====================
def __getattr__(self, name: str) -> Any:
    """Delegate attributes not found on the wrapper to the wrapped model.

    Python calls this only after normal attribute lookup has failed. The
    attribute is looked up on the wrapped model first, then as a
    ``_fallback_<name>`` attribute on the wrapper itself.

    Attributes found on the wrapped model are returned as-is. Bound methods
    are not re-wrapped, so coroutine functions, generators and callable
    objects keep their identity and introspection behaviour.

    Args:
        name: Attribute or method name being looked up.

    Returns:
        The attribute value or bound method.

    Raises:
        AttributeError: If neither the wrapped model nor a fallback provides
            the attribute, or if *name* is one of the guarded names.
    """
    # Guarded names must fail fast: resolving them through the proxy would
    # recurse back into this method.
    if name in ('__isabstractmethod__', '__dict__', '__class__', '_model', '_config'):
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    # The wrapped model may not be set yet (e.g. during construction); the
    # guard above turns that lookup into a plain AttributeError.
    try:
        model = self._model
    except AttributeError:
        model = None

    if model is not None:
        try:
            return object.__getattribute__(model, name)
        except AttributeError:
            # Wrapped model doesn't have it either; try the fallback next.
            pass

    try:
        return object.__getattribute__(self, f'_fallback_{name}')
    except AttributeError:
        pass

    message = f"'{type(self).__name__}' object has no attribute '{name}'. "
    if model is not None:
        message += (
            f"The underlying model '{type(model).__name__}' also doesn't have this attribute."
        )
    else:
        message += "The underlying model has not been created yet."
    raise AttributeError(message)
# ==================== Helper Methods ====================
def _create_model(self, config: RedBearModelConfig, type: ModelType) -> BaseLLM:
    """Instantiate the provider model that this wrapper delegates to.

    Args:
        config: Model configuration used to pick the class and its parameters.
        type: Model type used to pick the provider class.

    Returns:
        A new instance of the provider class returned by
        ``get_provider_llm_class``, built from the factory's model parameters.
    """
    provider_cls = get_provider_llm_class(config, type)
    return provider_cls(**RedBearModelFactory.get_model_params(config))
def get_config(self) -> RedBearModelConfig:
    """Return the configuration object stored on this wrapper.

    Returns:
        The value held in ``self._config``.
    """
    stored_config = self._config
    return stored_config
def get_underlying_model(self) -> BaseLLM:
    """Return the wrapped model instance this wrapper delegates to.

    Returns:
        The value held in ``self._model``.
    """
    wrapped_model = self._model
    return wrapped_model
def __repr__(self) -> str:
    """Describe the wrapper by provider, model name and wrapped class name."""
    cfg = self._config
    fields = (
        f"provider={cfg.provider}",
        f"model={cfg.model_name}",
        f"type={type(self._model).__name__}",
    )
    return "RedBearLLM(" + ", ".join(fields) + ")"

View File

@@ -125,17 +125,22 @@ class WorkflowExecutor:
if stream:
# 流式模式:创建 async generator 函数
# LangGraph 会收集所有 yield 的值,最后一个 yield 的字典会被合并到 state
async def node_func(state: WorkflowState, inst=node_instance):
async for item in inst.run_stream(state):
yield item
workflow.add_node(node_id, node_func)
def make_stream_func(inst):
    """Build a streaming node callable bound to *inst*.

    The returned async generator takes only the workflow state and re-yields
    every item produced by ``inst.run_stream(state)``.
    """
    async def node_func(state: WorkflowState):
        produced = inst.run_stream(state)
        async for piece in produced:
            yield piece

    return node_func
workflow.add_node(node_id, make_stream_func(node_instance))
else:
# 非流式模式:创建 async function
async def node_func(state: WorkflowState, inst=node_instance):
return await inst.run(state)
workflow.add_node(node_id, node_func)
def make_func(inst):
    """Build a non-streaming node callable bound to *inst*.

    The returned coroutine function takes only the workflow state and
    returns the awaited result of ``inst.run(state)``.
    """
    async def node_func(state: WorkflowState):
        outcome = await inst.run(state)
        return outcome

    return node_func
workflow.add_node(node_id, make_func(node_instance))
logger.debug(f"添加节点: {node_id} (type={node_type})")
logger.debug(f"添加节点: {node_id} (type={node_type}, stream={stream})")
# 3. 添加边
# 从 START 连接到 start 节点
@@ -283,13 +288,9 @@ class WorkflowExecutor:
):
"""执行工作流(流式)
使用 stream_mode="updates" 来获取每个节点的 state 更新。
节点的 generator 会 yield 多个值:
- 中间的 chunk 事件(带 type="chunk"
- 最后的 state 更新(纯字典,包含 node_outputs 等)
LangGraph 会将所有 yield 的值收集起来,并将它们合并到 state 中。
我们需要过滤出 chunk 事件并转发,同时确保 state 更新被正确处理。
使用多个 stream_mode 来获取:
1. "updates" - 节点的 state 更新和流式 chunk
2. "debug" - 节点执行的详细信息(开始/完成时间
Args:
input_data: 输入数据
@@ -297,7 +298,7 @@ class WorkflowExecutor:
Yields:
流式事件
"""
logger.info(f"开始执行工作流: execution_id={self.execution_id}")
logger.info(f"开始执行工作流(流式): execution_id={self.execution_id}")
# 记录开始时间
start_time = datetime.datetime.now()
@@ -310,34 +311,73 @@ class WorkflowExecutor:
# 3. 执行工作流
try:
async for mode, event in graph.astream(
chunk_count = 0
async for event in graph.astream(
initial_state,
stream_mode=["updates","messages"],
stream_mode=["updates", "debug"],
):
# print("刚才跑的节点:", event[0])
# # 通过图结构就能算出“接下来是谁”
# print("接下来可能跑:", graph.get_next(event[0]))
# print("="*50)
# # print("mode",mode)
# print("event",event)
# print("="*50)
# event 是一个字典key 是节点 IDvalue 是 state 更新或 chunk
for node_id, update in event.items():
print("="*50)
print("node_id",node_id)
print("update",update)
mode, data = event
print("="*50)
if isinstance(update, dict) and update.get("type") == "chunk":
# 这是流式 chunk转发给客户端
if mode == "debug":
# 处理调试信息(节点执行状态)
event_type = data.get("type")
payload = data.get("payload", {})
node_name = payload.get("name")
if event_type == "task":
# 节点开始执行
inputv = payload.get("input", {})
variables = inputv.get("variables", {})
variables_sys = variables.get("sys", {})
conversation_id = variables_sys.get("conversation_id")
execution_id = variables_sys.get("execution_id")
logger.info(f"[DEBUG] 节点开始执行: {node_name}")
yield {
"type": "node_chunk",
"node_id": update.get("node_id"),
"chunk": update.get("content")
"type": "node_start",
"node_id": node_name,
"conversation_id": conversation_id,
"execution_id": execution_id,
"timestamp": data.get("timestamp")
}
# 其他情况state 更新)会被 LangGraph 自动合并到 state不需要我们处理
print(event)
yield event
elif event_type == "task_result":
# 节点执行完成
result = payload.get("result", {})
inputv = result.get("input", {})
variables = inputv.get("variables", {})
variables_sys = variables.get("sys", {})
conversation_id = variables_sys.get("conversation_id")
execution_id = variables_sys.get("execution_id")
logger.info(f"[DEBUG] 节点执行完成: {node_name}")
yield {
"type": "node_end",
"node_id": node_name,
"conversation_id": conversation_id,
"execution_id": execution_id,
"timestamp": data.get("timestamp")
}
elif mode == "updates":
# 处理 state 更新
# data 是一个字典key 是节点 IDvalue 是 state 更新或 chunk
print("="*50)
print(data)
print("-"*50)
for node_id, update in data.items():
if isinstance(update, dict) and update.get("type") == "chunk":
# 这是流式 chunk转发给客户端
chunk_count += 1
logger.debug(f"[UPDATE] 收到 chunk #{chunk_count} from {node_id}: {update.get('content')[:50]}...")
yield {
"type": "node_chunk",
"node_id": update.get("node_id"),
"chunk": update.get("content"),
"full_content": update.get("full_content")
}
else:
logger.debug(f"[UPDATE] 收到 state 更新 from {node_id}")
# 其他情况state 更新)会被 LangGraph 自动合并到 state
logger.info(f"工作流执行完成(流式),总 chunks: {chunk_count}")
except Exception as e:
# 计算耗时(即使失败也记录)

View File

@@ -245,6 +245,9 @@ class BaseNode(ABC):
final_result = item["result"]
elif isinstance(item, str):
# 字符串是 chunk
# print("="*50)
# print(item)
# print("-"*50)
chunks.append(item)
yield {
"type": "chunk",

View File

@@ -30,11 +30,7 @@ class EndNode(BaseNode):
# 获取配置的输出模板
output_template = self.config.get("output")
# pool = self.get_variable_pool(state)
# print("="*20)
# print( pool.get("start.test"))
# print("="*20)
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
if output_template:
output = self._render_template(output_template, state)
@@ -46,7 +42,45 @@ class EndNode(BaseNode):
total_nodes = len(node_outputs)
logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
print("="*20)
print(output)
print("="*20)
return output
async def execute_stream(self, state: WorkflowState):
    """Stream the end node's output one character at a time.

    When an output template is configured it is rendered against the state
    and the rendered text is streamed; otherwise a fixed default message is
    streamed.

    Args:
        state: Workflow state.

    Yields:
        Single-character strings of the output, followed by a final dict
        ``{"__final__": True, "result": output}``.
    """
    logger.info(f"节点 {self.node_id} (End) 开始执行(流式)")

    template = self.config.get("output")
    if template:
        output = self._render_template(template, state)
        # An empty rendering streams nothing but is still the final result.
        streamed_text = output if output else ""
    else:
        output = "工作流已完成"
        streamed_text = output

    for char in streamed_text:
        yield char

    # Node count is used for logging only.
    total_nodes = len(state.get("node_outputs", {}))
    logger.info(f"节点 {self.node_id} (End) 执行完成(流式),共执行 {total_nodes} 个节点")

    # Completion marker carrying the full output.
    yield {"__final__": True, "result": output}

View File

@@ -125,19 +125,22 @@ class LLMNode(BaseNode):
model_type = config.type
# 4. 创建 LLM 实例(使用已提取的数据)
print("="*50)
print("stream",stream)
print("="*50)
# 注意:对于流式输出,需要在模型初始化时设置 streaming=True
extra_params = {"streaming": stream} if stream else {}
llm = RedBearLLM(
RedBearModelConfig(
model_name=model_name,
provider=provider,
api_key=api_key,
base_url=api_base,
extra_params={"streaming": stream}
extra_params=extra_params
),
type=model_type
)
logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}")
return llm, prompt_or_messages
async def execute(self, state: WorkflowState) -> AIMessage:
@@ -201,47 +204,54 @@ class LLMNode(BaseNode):
}
return None
# async def execute_stream(self, state: WorkflowState):
# """流式执行 LLM 调用
async def execute_stream(self, state: WorkflowState):
    """Stream the LLM response for this node.

    Args:
        state: Workflow state.

    Yields:
        Non-empty text fragments as they arrive, followed by a final dict
        ``{"__final__": True, "result": AIMessage}`` whose message holds the
        concatenated text and the response metadata gathered from the
        streamed chunks.
    """
    def _chunk_text(chunk) -> str:
        """Return the text carried by a streamed chunk."""
        content = chunk.content if hasattr(chunk, 'content') else str(chunk)
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            # Message content may be a list of strings / content-block dicts;
            # keep only the textual parts so accumulation never fails.
            texts = []
            for part in content:
                if isinstance(part, str):
                    texts.append(part)
                elif isinstance(part, dict) and isinstance(part.get("text"), str):
                    texts.append(part["text"])
            return "".join(texts)
        return "" if content is None else str(content)

    llm, prompt_or_messages = self._prepare_llm(state, True)

    logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
    logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")

    pieces = []
    response_metadata = {}
    chunk_count = 0

    # astream accepts either a string prompt or a message list.
    async for chunk in llm.astream(prompt_or_messages):
        # Collect metadata from every message chunk, including empty ones:
        # a trailing chunk may carry only metadata and no text.
        if isinstance(chunk, AIMessage) and chunk.response_metadata:
            response_metadata.update(chunk.response_metadata)

        content = _chunk_text(chunk)
        if not content:
            continue

        pieces.append(content)
        chunk_count += 1
        yield content

    full_response = "".join(pieces)
    logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")

    final_message = AIMessage(content=full_response, response_metadata=response_metadata)

    # Completion marker carrying the assembled message.
    yield {"__final__": True, "result": final_message}