feat: Add base project structure with API and web components
This commit is contained in:
13
api/app/core/models/__init__.py
Normal file
13
api/app/core/models/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from .base import RedBearModelConfig, get_provider_llm_class, RedBearModelFactory
|
||||
from .llm import RedBearLLM
|
||||
from .embedding import RedBearEmbeddings
|
||||
from .rerank import RedBearRerank
|
||||
|
||||
__all__ = [
|
||||
"RedBearModelConfig",
|
||||
"RedBearLLM",
|
||||
"RedBearEmbeddings",
|
||||
"RedBearRerank",
|
||||
"RedBearModelFactory",
|
||||
"get_provider_llm_class"
|
||||
]
|
||||
167
api/app/core/models/base.py
Normal file
167
api/app/core/models/base.py
Normal file
@@ -0,0 +1,167 @@
|
||||
from __future__ import annotations
|
||||
import asyncio, httpx, time, os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, TypeVar, Callable
|
||||
from langchain_community.document_compressors import JinaRerank
|
||||
from pydantic import BaseModel, Field
|
||||
from langchain_core.runnables import RunnableSerializable
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models import BaseLLM, BaseLanguageModel
|
||||
from langchain_core.outputs import LLMResult, Generation
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
from app.models.models_model import ModelProvider, ModelType
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
class RedBearModelConfig(BaseModel):
|
||||
"""模型配置基类"""
|
||||
model_name: str
|
||||
provider: str
|
||||
api_key: str
|
||||
base_url: Optional[str] = None
|
||||
# 请求超时时间(秒)- 默认120秒以支持复杂的LLM调用,可通过环境变量 LLM_TIMEOUT 配置
|
||||
timeout: float = Field(default_factory=lambda: float(os.getenv("LLM_TIMEOUT", "120.0")))
|
||||
# 最大重试次数 - 默认2次以避免过长等待,可通过环境变量 LLM_MAX_RETRIES 配置
|
||||
max_retries: int = Field(default_factory=lambda: int(os.getenv("LLM_MAX_RETRIES", "2")))
|
||||
concurrency: int = 5 # 并发限流
|
||||
extra_params: Dict[str, Any] = {}
|
||||
|
||||
class RedBearModelFactory:
|
||||
"""模型工厂类"""
|
||||
|
||||
@classmethod
|
||||
def get_model_params(cls, config: RedBearModelConfig) -> Dict[str, Any]:
|
||||
"""根据提供商获取模型参数"""
|
||||
provider = config.provider.lower()
|
||||
|
||||
# 打印供应商信息用于调试
|
||||
from app.core.logging_config import get_business_logger
|
||||
logger = get_business_logger()
|
||||
logger.debug(f"获取模型参数 - Provider: {provider}, Model: {config.model_name}")
|
||||
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA]:
|
||||
# 使用 httpx.Timeout 对象来设置详细的超时配置
|
||||
# 这样可以分别控制连接超时和读取超时
|
||||
import httpx
|
||||
timeout_config = httpx.Timeout(
|
||||
timeout=config.timeout, # 总超时时间
|
||||
connect=60.0, # 连接超时:60秒(足够建立 TCP 连接)
|
||||
read=config.timeout, # 读取超时:使用配置的超时时间
|
||||
write=60.0, # 写入超时:60秒
|
||||
pool=10.0, # 连接池超时:10秒
|
||||
)
|
||||
return {
|
||||
"model": config.model_name,
|
||||
"base_url": config.base_url,
|
||||
"api_key": config.api_key,
|
||||
"timeout": timeout_config,
|
||||
"max_retries": config.max_retries,
|
||||
**config.extra_params
|
||||
}
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
# DashScope (通义千问) 使用自己的参数格式
|
||||
# 注意: DashScopeEmbeddings 不支持 timeout 和 base_url 参数
|
||||
# 只支持: model, dashscope_api_key, max_retries, client
|
||||
return {
|
||||
"model": config.model_name,
|
||||
"dashscope_api_key": config.api_key,
|
||||
"max_retries": config.max_retries,
|
||||
**config.extra_params
|
||||
}
|
||||
elif provider == ModelProvider.BEDROCK:
|
||||
# Bedrock 使用 AWS 凭证
|
||||
# api_key 格式: "access_key_id:secret_access_key" 或只是 access_key_id
|
||||
# region 从 base_url 或 extra_params 获取
|
||||
params = {
|
||||
"model_id": config.model_name,
|
||||
**config.extra_params
|
||||
}
|
||||
|
||||
# 解析 API key (格式: access_key_id:secret_access_key)
|
||||
if config.api_key and ":" in config.api_key:
|
||||
access_key_id, secret_access_key = config.api_key.split(":", 1)
|
||||
params["aws_access_key_id"] = access_key_id
|
||||
params["aws_secret_access_key"] = secret_access_key
|
||||
elif config.api_key:
|
||||
params["aws_access_key_id"] = config.api_key
|
||||
|
||||
# 设置 region
|
||||
if config.base_url:
|
||||
params["region_name"] = config.base_url
|
||||
elif "region_name" not in params:
|
||||
params["region_name"] = "us-east-1" # 默认区域
|
||||
|
||||
return params
|
||||
else:
|
||||
raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
|
||||
@classmethod
|
||||
def get_rerank_model_params(cls, config: RedBearModelConfig) -> Dict[str, Any]:
|
||||
"""根据提供商获取模型参数"""
|
||||
provider = config.provider.lower()
|
||||
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||
return {
|
||||
"model": config.model_name,
|
||||
# "base_url": config.base_url,
|
||||
"jina_api_key": config.api_key,
|
||||
**config.extra_params
|
||||
}
|
||||
else:
|
||||
raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
|
||||
def get_provider_llm_class(config:RedBearModelConfig, type: ModelType=ModelType.LLM) -> type[BaseLLM]:
|
||||
"""根据模型提供商获取对应的模型类"""
|
||||
provider = config.provider.lower()
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
|
||||
if type == ModelType.LLM:
|
||||
from langchain_openai import OpenAI
|
||||
return OpenAI
|
||||
elif type == ModelType.CHAT:
|
||||
from langchain_openai import ChatOpenAI
|
||||
return ChatOpenAI
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
from langchain_community.chat_models import ChatTongyi
|
||||
return ChatTongyi
|
||||
elif provider == ModelProvider.OLLAMA:
|
||||
from langchain_ollama import OllamaLLM
|
||||
return OllamaLLM
|
||||
elif provider == ModelProvider.BEDROCK:
|
||||
from langchain_aws import ChatBedrock, ChatBedrockConverse
|
||||
|
||||
return ChatBedrock
|
||||
else:
|
||||
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
|
||||
def get_provider_embedding_class(provider: str) -> type[Embeddings]:
|
||||
"""根据模型提供商获取对应的模型类"""
|
||||
provider = provider.lower()
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
return OpenAIEmbeddings
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
from langchain_community.embeddings import DashScopeEmbeddings
|
||||
return DashScopeEmbeddings
|
||||
elif provider == ModelProvider.OLLAMA:
|
||||
from langchain_ollama import OllamaEmbeddings
|
||||
return OllamaEmbeddings
|
||||
elif provider == ModelProvider.BEDROCK:
|
||||
from langchain_aws import BedrockEmbeddings
|
||||
return BedrockEmbeddings
|
||||
else:
|
||||
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
|
||||
def get_provider_rerank_class(provider: str):
|
||||
"""根据模型提供商获取对应的模型类"""
|
||||
provider = provider.lower()
|
||||
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
|
||||
from langchain_community.document_compressors import JinaRerank
|
||||
return JinaRerank
|
||||
# elif provider == ModelProvider.OLLAMA:
|
||||
# from langchain_ollama import OllamaEmbeddings
|
||||
# return OllamaEmbeddings
|
||||
else:
|
||||
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
23
api/app/core/models/embedding.py
Normal file
23
api/app/core/models/embedding.py
Normal file
@@ -0,0 +1,23 @@
|
||||
|
||||
from typing import Any, Dict, List, Optional, TypeVar, Callable
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
from app.core.models.base import RedBearModelConfig,get_provider_embedding_class,RedBearModelFactory
|
||||
|
||||
class RedBearEmbeddings(Embeddings):
|
||||
"""Embedding → 完全符合 LangChain Embeddings"""
|
||||
def __init__(self, config: RedBearModelConfig):
|
||||
self._model = self._create_model(config)
|
||||
self._config = config
|
||||
|
||||
def _create_model(self, config: RedBearModelConfig) -> Embeddings:
|
||||
"""根据配置创建模型"""
|
||||
embedding_class = get_provider_embedding_class(config.provider)
|
||||
model_params = RedBearModelFactory.get_model_params(config)
|
||||
return embedding_class(**model_params)
|
||||
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
return self._model.embed_documents(texts)
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
return self._model.embed_query(text)
|
||||
16
api/app/core/models/factory.py
Normal file
16
api/app/core/models/factory.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# from typing import Optional
|
||||
# from app.core.model_client import RedBearEmbeddings, RedBearLLM, RedBearRerank, ModelConfig
|
||||
|
||||
|
||||
# class RedBearModelFactory:
|
||||
# @staticmethod
|
||||
# def llm(model: str, api_key: str, base_url: Optional[str] = None) -> RedBearLLM:
|
||||
# return RedBearLLM(ModelConfig(model_name=model, api_key=api_key, base_url=base_url))
|
||||
|
||||
# @staticmethod
|
||||
# def embeddings(model: str, api_key: str, base_url: Optional[str] = None) -> RedBearEmbeddings:
|
||||
# return RedBearEmbeddings(ModelConfig(model_name=model, api_key=api_key, base_url=base_url))
|
||||
|
||||
# @staticmethod
|
||||
# def reranker(model: str, api_key: str, base_url: Optional[str] = None) -> RedBearRerank:
|
||||
# return RedBearRerank(ModelConfig(model_name=model, api_key=api_key, base_url=base_url))
|
||||
133
api/app/core/models/llm.py
Normal file
133
api/app/core/models/llm.py
Normal file
@@ -0,0 +1,133 @@
|
||||
from __future__ import annotations
|
||||
from typing import Any, Dict, List, Optional
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun
|
||||
from langchain_core.language_models import BaseLLM
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
from app.core.models import RedBearModelConfig, RedBearModelFactory, get_provider_llm_class
|
||||
from app.models.models_model import ModelType
|
||||
|
||||
|
||||
class RedBearLLM(BaseLLM):
|
||||
"""
|
||||
RedBear LLM 模型包装器 - 完全动态代理实现
|
||||
|
||||
这个包装器自动将所有方法调用委托给内部模型,
|
||||
同时提供优雅的回退机制和错误处理。
|
||||
"""
|
||||
|
||||
def __init__(self, config: RedBearModelConfig, type: ModelType=ModelType.LLM):
|
||||
self._model = self._create_model(config, type)
|
||||
self._config = config
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""返回LLM类型标识符"""
|
||||
return self._model._llm_type
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any
|
||||
) -> LLMResult:
|
||||
"""同步生成文本"""
|
||||
return self._model._generate(prompts, stop=stop, run_manager=run_manager, **kwargs)
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any
|
||||
) -> LLMResult:
|
||||
"""异步生成文本"""
|
||||
return await self._model._agenerate(prompts, stop=stop, run_manager=run_manager, **kwargs)
|
||||
|
||||
# 关键:覆盖 invoke/ainvoke,直接委托到底层模型,避免 BaseLLM 的字符串化行为
|
||||
def invoke(self, input: Any, config: Optional[dict] = None, **kwargs: Any) -> Any:
|
||||
"""直接调用底层模型以支持 ChatPrompt 和消息列表。"""
|
||||
try:
|
||||
return self._model.invoke(input, config=config, **kwargs)
|
||||
except AttributeError as e:
|
||||
# 只在属性错误时回退(说明底层模型不支持该方法)
|
||||
if 'invoke' in str(e):
|
||||
return super().invoke(input, config=config, **kwargs)
|
||||
# 其他 AttributeError 直接抛出
|
||||
raise
|
||||
except Exception:
|
||||
# 其他所有异常(包括 ValidationException)直接抛出,不回退
|
||||
raise
|
||||
|
||||
async def ainvoke(self, input: Any, config: Optional[dict] = None, **kwargs: Any) -> Any:
|
||||
"""异步直接调用底层模型以支持 ChatPrompt 和消息列表。"""
|
||||
try:
|
||||
return await self._model.ainvoke(input, config=config, **kwargs)
|
||||
except AttributeError as e:
|
||||
# 只在属性错误时回退(说明底层模型不支持该方法)
|
||||
if 'ainvoke' in str(e):
|
||||
return await super().ainvoke(input, config=config, **kwargs)
|
||||
# 其他 AttributeError 直接抛出
|
||||
raise
|
||||
except Exception:
|
||||
# 其他所有异常(包括 ValidationException)直接抛出,不回退
|
||||
raise
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""
|
||||
动态代理:将所有未定义的属性和方法调用委托给内部模型
|
||||
|
||||
这是最优雅的包装器实现方式,完全避免了方法重复定义
|
||||
"""
|
||||
# 处理特殊属性以避免递归
|
||||
if name in ('__isabstractmethod__', '__dict__', '__class__'):
|
||||
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
|
||||
|
||||
# 检查内部模型是否有该属性(使用安全的方式避免递归)
|
||||
try:
|
||||
# 使用 object.__getattribute__ 来安全地检查内部模型的属性
|
||||
attr = object.__getattribute__(self._model, name)
|
||||
|
||||
# 如果是方法,返回一个包装器来处理调用
|
||||
if callable(attr):
|
||||
# 流式方法直接返回,不包装(保持生成器特性)
|
||||
if name in ('_stream', '_astream', 'stream', 'astream'):
|
||||
return attr
|
||||
|
||||
# 非流式方法使用包装器处理异常
|
||||
def method_wrapper(*args, **kwargs):
|
||||
return attr(*args, **kwargs)
|
||||
|
||||
# 保持方法的元信息
|
||||
method_wrapper.__name__ = name
|
||||
method_wrapper.__doc__ = getattr(attr, '__doc__', f"Delegated method: {name}")
|
||||
return method_wrapper
|
||||
|
||||
# 如果是普通属性,直接返回
|
||||
return attr
|
||||
|
||||
except AttributeError:
|
||||
# 内部模型没有该属性,尝试回退实现
|
||||
pass
|
||||
|
||||
# 检查是否有回退方法(使用安全的方式避免递归)
|
||||
fallback_name = f'_fallback_{name}'
|
||||
try:
|
||||
fallback_method = object.__getattribute__(self, fallback_name)
|
||||
return fallback_method
|
||||
except AttributeError:
|
||||
# 没有回退方法,抛出适当的错误
|
||||
pass
|
||||
|
||||
# 如果都没有,抛出适当的错误
|
||||
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
|
||||
|
||||
def _create_model(self, config: RedBearModelConfig, type: ModelType) -> BaseLLM:
|
||||
"""创建内部模型实例"""
|
||||
llm_class = get_provider_llm_class(config, type)
|
||||
model_params = RedBearModelFactory.get_model_params(config)
|
||||
return llm_class(**model_params)
|
||||
|
||||
|
||||
|
||||
35
api/app/core/models/rerank copy.py
Normal file
35
api/app/core/models/rerank copy.py
Normal file
@@ -0,0 +1,35 @@
|
||||
|
||||
# from typing import Any, Dict, List, Optional
|
||||
# from langchain_core.runnables import RunnableSerializable
|
||||
|
||||
# from app.core.models.base import RedBearModelConfig
|
||||
|
||||
# class RedBearRerank(RunnableSerializable[str, List[float]]):
|
||||
# """ Rerank → 作为 Runnable 插入任意 LCEL 链"""
|
||||
# def __init__(self, config: RedBearModelConfig):
|
||||
# super().__init__(self, config)
|
||||
|
||||
# def invoke(self, input: Dict[str, Any], config: Optional[Dict] = None) -> List[float]:
|
||||
# query, docs = input["query"], input["documents"]
|
||||
# url = (self.config.base_url or "https://api.cohere.ai/v1") + "/rerank"
|
||||
# body = {
|
||||
# "query": query,
|
||||
# "documents": docs,
|
||||
# "model": self.config.model_name,
|
||||
# "top_n": len(docs),
|
||||
# }
|
||||
# js = self._sync_post(url, body)
|
||||
# scores = [0.0] * len(docs)
|
||||
# for item in js["results"]:
|
||||
# scores[item["index"]] = item["relevance_score"]
|
||||
# return scores
|
||||
|
||||
# async def ainvoke(self, input: Dict[str, Any], config: Optional[Dict] = None) -> List[float]:
|
||||
# query, docs = input["query"], input["documents"]
|
||||
# url = (self.config.base_url or "https://api.cohere.ai/v1") + "/rerank"
|
||||
# body = {"query": query, "documents": docs, "model": self.config.model_name, "top_n": len(docs)}
|
||||
# js = await self._async_post(url, body)
|
||||
# scores = [0.0] * len(docs)
|
||||
# for item in js["results"]:
|
||||
# scores[item["index"]] = item["relevance_score"]
|
||||
# return scores
|
||||
80
api/app/core/models/rerank.py
Normal file
80
api/app/core/models/rerank.py
Normal file
@@ -0,0 +1,80 @@
|
||||
|
||||
from typing import Any, Dict, List, Optional, Sequence, Type, Union
|
||||
from copy import deepcopy
|
||||
from urllib.parse import urlparse
|
||||
from langchain_core.documents import BaseDocumentCompressor, Document
|
||||
from langchain_core.runnables import RunnableSerializable
|
||||
from langchain_core.callbacks import Callbacks
|
||||
from app.core.models.base import RedBearModelConfig, get_provider_rerank_class, RedBearModelFactory
|
||||
from app.models import ModelProvider
|
||||
|
||||
class RedBearRerank(BaseDocumentCompressor):
|
||||
""" Rerank → 作为 Runnable 插入任意 LCEL 链"""
|
||||
def __init__(self, config: RedBearModelConfig):
|
||||
self._model = self._create_model(config)
|
||||
self._config = config
|
||||
|
||||
def _create_model(self, config: RedBearModelConfig):
|
||||
"""创建内部模型实例"""
|
||||
model_class = get_provider_rerank_class(config.provider)
|
||||
model_params = RedBearModelFactory.get_rerank_model_params(config)
|
||||
print(model_params)
|
||||
return model_class(**model_params)
|
||||
|
||||
def compress_documents(
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
query: str,
|
||||
callbacks: Optional[Callbacks] = None,
|
||||
) -> Sequence[Document]:
|
||||
"""
|
||||
Compress documents using Jina's Rerank API.
|
||||
|
||||
Args:
|
||||
documents: A sequence of documents to compress.
|
||||
query: The query to use for compressing the documents.
|
||||
callbacks: Callbacks to run during the compression process.
|
||||
|
||||
Returns:
|
||||
A sequence of compressed documents.
|
||||
"""
|
||||
compressed = []
|
||||
for res in self.rerank(documents, query):
|
||||
doc = documents[res["index"]]
|
||||
doc_copy = Document(doc.page_content, metadata=deepcopy(doc.metadata))
|
||||
doc_copy.metadata["relevance_score"] = res["relevance_score"]
|
||||
compressed.append(doc_copy)
|
||||
return compressed
|
||||
|
||||
|
||||
def rerank(
|
||||
self,
|
||||
documents: Sequence[Union[str, Document, dict]],
|
||||
query: str,
|
||||
*,
|
||||
top_n: Optional[int] = -1,
|
||||
) -> List[Dict[str, Any]]:
|
||||
provider = self._config.provider.lower()
|
||||
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
|
||||
import langchain_community.document_compressors.jina_rerank as jina_mod
|
||||
# 规范化:如果不以 /v1/rerank 结尾,则补齐;若已以 /v1 结尾,则补 /rerank
|
||||
def _normalize_jina_base(base_url: Optional[str]) -> Optional[str]:
|
||||
if not base_url:
|
||||
return None
|
||||
url = base_url.rstrip('/')
|
||||
if url.endswith("/v1/rerank"):
|
||||
return url
|
||||
if url.endswith("/v1"):
|
||||
return url + "/rerank"
|
||||
return url + "/v1/rerank"
|
||||
|
||||
jina_base = _normalize_jina_base(self._config.base_url)
|
||||
if jina_base:
|
||||
# 设置完整的 rerank 端点,例如 http://host:port/v1/rerank
|
||||
jina_mod.JINA_API_URL = jina_base
|
||||
from langchain_community.document_compressors import JinaRerank
|
||||
model_instance : JinaRerank = self._model
|
||||
return model_instance.rerank(documents = documents, query = query, top_n=top_n)
|
||||
else:
|
||||
raise ValueError(f"不支持的模型提供商: {provider}")
|
||||
|
||||
Reference in New Issue
Block a user