[fix] model support stream
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
from __future__ import annotations
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Iterator, AsyncIterator, List, Optional
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun
|
||||
from langchain_core.language_models import BaseLLM
|
||||
from langchain_core.outputs import LLMResult
|
||||
from langchain_core.outputs import LLMResult, GenerationChunk
|
||||
|
||||
from app.core.models import RedBearModelConfig, RedBearModelFactory, get_provider_llm_class
|
||||
from app.models.models_model import ModelType
|
||||
@@ -10,21 +10,36 @@ from app.models.models_model import ModelType
|
||||
|
||||
class RedBearLLM(BaseLLM):
    """RedBear LLM model wrapper.

    Wraps a provider-specific model instance (built by ``_create_model``)
    behind a single ``BaseLLM`` subclass while keeping the LangChain call
    surface, including streaming output.

    Features:
    - Support for multiple LLM providers (OpenAI, Qwen, Ollama, etc.)
    - Streaming (``stream`` / ``astream``) proxied chunk by chunk
    - Fallback to the ``BaseLLM`` implementation when the underlying model
      does not provide a method
    - Automatic proxying of the underlying model's other attributes and
      methods via ``__getattr__``
    """

    def __init__(self, config: RedBearModelConfig, type: ModelType = ModelType.LLM):
        """Initialize the wrapper and build the underlying model.

        Args:
            config: Model configuration.
            type: Model type (LLM or CHAT). The name shadows the builtin but
                is kept so existing keyword callers continue to work.
        """
        super().__init__()
        self._config = config
        self._model = self._create_model(config, type)

    @property
    def _llm_type(self) -> str:
        """Return the LLM type identifier.

        Uses the underlying model's ``_llm_type`` when it has one and
        ``'redbear_llm'`` otherwise.
        """
        return getattr(self._model, '_llm_type', 'redbear_llm')

    # ==================== Core Methods (Required by BaseLLM) ====================

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any
    ) -> LLMResult:
        """Synchronous text generation (required by BaseLLM).

        Delegates unchanged to the underlying model's ``_generate``.
        """
        return self._model._generate(prompts, stop=stop, run_manager=run_manager, **kwargs)

    async def _agenerate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any
    ) -> LLMResult:
        """Asynchronous text generation (required by BaseLLM).

        Delegates unchanged to the underlying model's ``_agenerate``.
        """
        return await self._model._agenerate(prompts, stop=stop, run_manager=run_manager, **kwargs)

    # ==================== Advanced Methods (Support Message Lists) ====================

    def invoke(self, input: Any, config: Optional[dict] = None, **kwargs: Any) -> Any:
        """Synchronous model invocation.

        Delegates directly to the underlying model so that non-string inputs
        (e.g. message lists) are not converted to strings by ``BaseLLM``.

        Args:
            input: Input (string, message list, etc.).
            config: Runtime configuration.
            **kwargs: Additional arguments forwarded to the model.

        Returns:
            Whatever the underlying model's ``invoke`` returns.
        """
        # Resolve the method first: only a genuinely missing method triggers
        # the fallback. Errors raised by the call itself always propagate.
        model_invoke = getattr(self._model, 'invoke', None)
        if model_invoke is None:
            return super().invoke(input, config=config, **kwargs)
        return model_invoke(input, config=config, **kwargs)

    async def ainvoke(self, input: Any, config: Optional[dict] = None, **kwargs: Any) -> Any:
        """Asynchronous model invocation.

        Delegates directly to the underlying model so that non-string inputs
        (e.g. message lists) are not converted to strings by ``BaseLLM``.

        Args:
            input: Input (string, message list, etc.).
            config: Runtime configuration.
            **kwargs: Additional arguments forwarded to the model.

        Returns:
            Whatever the underlying model's ``ainvoke`` returns.
        """
        model_ainvoke = getattr(self._model, 'ainvoke', None)
        if model_ainvoke is None:
            return await super().ainvoke(input, config=config, **kwargs)
        return await model_ainvoke(input, config=config, **kwargs)

    # ==================== Streaming Methods (Critical) ====================

    def stream(
        self,
        input: Any,
        config: Optional[dict] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any
    ) -> Iterator[Any]:
        """Synchronous streaming model invocation.

        Args:
            input: Input (string, message list, etc.).
            config: Runtime configuration.
            stop: List of stop words.
            **kwargs: Additional arguments forwarded to the model.

        Yields:
            Chunks exactly as produced by the underlying model's ``stream``
            (or by the ``BaseLLM`` implementation when the model has none).
        """
        # Choose the source before iterating so that an error raised midway
        # through the stream can never restart generation and duplicate output.
        model_stream = getattr(self._model, 'stream', None)
        if model_stream is None:
            chunks = super().stream(input, config=config, stop=stop, **kwargs)
        else:
            chunks = model_stream(input, config=config, stop=stop, **kwargs)
        yield from chunks

    async def astream(
        self,
        input: Any,
        config: Optional[dict] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any
    ) -> AsyncIterator[Any]:
        """Asynchronous streaming model invocation.

        Proxies the underlying model's ``astream`` as an async generator so
        each chunk is handed to the caller as soon as it is produced.

        Args:
            input: Input (string, message list, etc.).
            config: Runtime configuration.
            stop: List of stop words.
            **kwargs: Additional arguments forwarded to the model.

        Yields:
            Chunks exactly as produced by the underlying model's ``astream``
            (or by the ``BaseLLM`` implementation when the model has none).
        """
        # Same rule as ``stream``: decide on the source up front, never
        # fall back after chunks may already have been delivered.
        model_astream = getattr(self._model, 'astream', None)
        if model_astream is None:
            chunks = super().astream(input, config=config, stop=stop, **kwargs)
        else:
            chunks = model_astream(input, config=config, stop=stop, **kwargs)
        async for chunk in chunks:
            yield chunk

    # ==================== Dynamic Proxy ====================

    def __getattr__(self, name: str) -> Any:
        """Dynamic proxy: delegate undefined attributes to the internal model.

        Python calls this only after normal attribute lookup on the wrapper
        has failed. Lookup order: the underlying model, then a
        ``_fallback_<name>`` attribute on the wrapper itself.

        Args:
            name: Attribute or method name.

        Returns:
            The attribute value or bound method, returned as-is so callables
            keep their real signature and identity.

        Raises:
            AttributeError: If neither the underlying model nor a fallback
                provides the attribute.
        """
        # Never proxy the wrapper's own state or dunder names: doing so would
        # recurse while ``_model`` is unset, and would let protocol probes be
        # answered by the wrapped model instead of the wrapper.
        if name in ('_model', '_config') or (name.startswith('__') and name.endswith('__')):
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

        # Read ``_model`` without re-entering ``__getattr__``; if it is not
        # set yet, report the attribute that was actually requested.
        try:
            model = object.__getattribute__(self, '_model')
        except AttributeError:
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{name}'"
            ) from None

        # Use ``getattr`` so the model's own dynamic attributes are honoured.
        try:
            return getattr(model, name)
        except AttributeError:
            # Internal model doesn't have this attribute either
            pass

        # Check if there's a fallback method
        try:
            return object.__getattribute__(self, f'_fallback_{name}')
        except AttributeError:
            pass

        # Nothing found, raise error
        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{name}'. "
            f"The underlying model '{type(model).__name__}' also doesn't have this attribute."
        )

    # ==================== Helper Methods ====================

    def _create_model(self, config: RedBearModelConfig, type: ModelType) -> BaseLLM:
        """Create the internal model instance.

        Args:
            config: Model configuration.
            type: Model type.

        Returns:
            An instance of the class returned by ``get_provider_llm_class``,
            constructed with ``RedBearModelFactory.get_model_params(config)``.
        """
        llm_class = get_provider_llm_class(config, type)
        model_params = RedBearModelFactory.get_model_params(config)
        return llm_class(**model_params)

    def get_config(self) -> RedBearModelConfig:
        """Get the model configuration.

        Returns:
            The configuration object passed to ``__init__``.
        """
        return self._config

    def get_underlying_model(self) -> BaseLLM:
        """Get the underlying model instance.

        Returns:
            The model instance created by ``_create_model``.
        """
        return self._model

    def __repr__(self) -> str:
        """Return a string showing provider, model name and underlying class."""
        return (
            f"RedBearLLM("
            f"provider={self._config.provider}, "
            f"model={self._config.model_name}, "
            f"type={type(self._model).__name__}"
            f")"
        )
|
||||
Reference in New Issue
Block a user