220 lines
7.9 KiB
Python
220 lines
7.9 KiB
Python
|
||
from typing import Any, Dict, List, Union
|
||
from langchain_core.embeddings import Embeddings
|
||
|
||
from app.core.models.base import RedBearModelConfig, get_provider_embedding_class, RedBearModelFactory
|
||
from app.models.models_model import ModelProvider
|
||
|
||
|
||
class RedBearEmbeddings(Embeddings):
|
||
"""统一的 Embedding 类,自动支持多模态(根据 provider 判断)"""
|
||
|
||
def __init__(self, config: RedBearModelConfig):
|
||
self._config = config
|
||
self._is_volcano = config.provider.lower() == ModelProvider.VOLCANO
|
||
|
||
if self._is_volcano:
|
||
# 火山引擎使用 Ark SDK
|
||
self._client = self._create_volcano_client(config)
|
||
self._model = None
|
||
else:
|
||
# 其他 provider 使用 LangChain
|
||
self._model = self._create_model(config)
|
||
self._client = None
|
||
|
||
@staticmethod
|
||
def _create_model(config: RedBearModelConfig) -> Embeddings:
|
||
"""根据配置创建 LangChain 模型"""
|
||
embedding_class = get_provider_embedding_class(config.provider)
|
||
provider = config.provider.lower()
|
||
# Embedding models only need connection params, never LLM-specific ones
|
||
# (e.g. enable_thinking, model_kwargs) — build params directly.
|
||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||
import httpx
|
||
params = {
|
||
"model": config.model_name,
|
||
"base_url": config.base_url,
|
||
"api_key": config.api_key,
|
||
"timeout": httpx.Timeout(timeout=config.timeout, connect=60.0),
|
||
"max_retries": config.max_retries,
|
||
"check_embedding_ctx_length": False,
|
||
"encoding_format": "float"
|
||
}
|
||
elif provider == ModelProvider.DASHSCOPE:
|
||
params = {
|
||
"model": config.model_name,
|
||
"dashscope_api_key": config.api_key,
|
||
"max_retries": config.max_retries,
|
||
}
|
||
elif provider == ModelProvider.OLLAMA:
|
||
params = {
|
||
"model": config.model_name,
|
||
"base_url": config.base_url,
|
||
}
|
||
elif provider == ModelProvider.BEDROCK:
|
||
params = RedBearModelFactory.get_model_params(config)
|
||
else:
|
||
params = RedBearModelFactory.get_model_params(config)
|
||
return embedding_class(**params)
|
||
|
||
def _create_volcano_client(self, config: RedBearModelConfig):
|
||
"""创建火山引擎客户端"""
|
||
from volcenginesdkarkruntime import Ark
|
||
return Ark(api_key=config.api_key, base_url=config.base_url)
|
||
|
||
# ==================== LangChain 标准接口 ====================
|
||
|
||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||
"""批量文本向量化(LangChain 标准接口)"""
|
||
if self._is_volcano:
|
||
# 火山引擎多模态 Embedding
|
||
contents = [{"type": "text", "text": text} for text in texts]
|
||
response = self._client.multimodal_embeddings.create(
|
||
model=self._config.model_name,
|
||
input=contents,
|
||
encoding_format="float"
|
||
)
|
||
return [response.data.embedding]
|
||
else:
|
||
# 其他 provider
|
||
return self._model.embed_documents(texts)
|
||
|
||
def embed_query(self, text: str) -> List[float]:
|
||
"""单个文本向量化(LangChain 标准接口)"""
|
||
if self._is_volcano:
|
||
# 火山引擎多模态 Embedding
|
||
result = self.embed_documents([text])
|
||
return result[0] if result else []
|
||
else:
|
||
# 其他 provider
|
||
return self._model.embed_query(text)
|
||
|
||
# ==================== 多模态扩展方法 ====================
|
||
|
||
def embed_multimodal(
|
||
self,
|
||
contents: List[Dict[str, Any]],
|
||
**kwargs
|
||
) -> List[List[float]]:
|
||
"""
|
||
多模态向量化(仅火山引擎支持)
|
||
|
||
Args:
|
||
contents: 内容列表,格式:
|
||
- 文本: {"type": "text", "text": "..."}
|
||
- 图片: {"type": "image_url", "image_url": {"url": "..."}}
|
||
- 视频: {"type": "video_url", "video_url": {"url": "..."}}
|
||
**kwargs: 其他参数
|
||
|
||
Returns:
|
||
向量列表
|
||
"""
|
||
if not self._is_volcano:
|
||
raise NotImplementedError(
|
||
f"多模态 Embedding 仅支持火山引擎,当前 provider: {self._config.provider}"
|
||
)
|
||
|
||
response = self._client.multimodal_embeddings.create(
|
||
model=self._config.model_name,
|
||
input=contents,
|
||
**kwargs
|
||
)
|
||
return [response.data.embedding]
|
||
|
||
async def aembed_multimodal(
|
||
self,
|
||
contents: List[Dict[str, Any]],
|
||
**kwargs
|
||
) -> List[List[float]]:
|
||
"""异步多模态向量化"""
|
||
# 火山引擎 SDK 暂不支持异步,使用同步方法
|
||
return self.embed_multimodal(contents, **kwargs)
|
||
|
||
def embed_text(self, text: str, **kwargs) -> List[float]:
|
||
"""文本向量化(便捷方法)"""
|
||
if self._is_volcano:
|
||
result = self.embed_multimodal(
|
||
[{"type": "text", "text": text}],
|
||
**kwargs
|
||
)
|
||
return result[0] if result else []
|
||
else:
|
||
return self.embed_query(text)
|
||
|
||
def embed_image(self, image_url: str, **kwargs) -> List[float]:
|
||
"""图片向量化(仅火山引擎支持)"""
|
||
if not self._is_volcano:
|
||
raise NotImplementedError(
|
||
f"图片向量化仅支持火山引擎,当前 provider: {self._config.provider}"
|
||
)
|
||
|
||
result = self.embed_multimodal(
|
||
[{"type": "image_url", "image_url": {"url": image_url}}],
|
||
**kwargs
|
||
)
|
||
return result[0] if result else []
|
||
|
||
def embed_video(self, video_url: str, **kwargs) -> List[float]:
|
||
"""视频向量化(仅火山引擎支持)"""
|
||
if not self._is_volcano:
|
||
raise NotImplementedError(
|
||
f"视频向量化仅支持火山引擎,当前 provider: {self._config.provider}"
|
||
)
|
||
|
||
result = self.embed_multimodal(
|
||
[{"type": "video_url", "video_url": {"url": video_url}}],
|
||
**kwargs
|
||
)
|
||
return result[0] if result else []
|
||
|
||
def embed_batch(
|
||
self,
|
||
items: List[Union[str, Dict[str, Any]]],
|
||
**kwargs
|
||
) -> List[List[float]]:
|
||
"""
|
||
批量向量化(支持混合类型)
|
||
|
||
Args:
|
||
items: 可以是字符串列表或内容字典列表
|
||
**kwargs: 其他参数
|
||
|
||
Returns:
|
||
向量列表
|
||
"""
|
||
# 如果全是字符串,使用标准方法
|
||
if all(isinstance(item, str) for item in items):
|
||
return self.embed_documents(items)
|
||
|
||
# 如果包含字典,需要多模态支持
|
||
if not self._is_volcano:
|
||
raise NotImplementedError(
|
||
f"混合类型批量向量化仅支持火山引擎,当前 provider: {self._config.provider}"
|
||
)
|
||
|
||
# 标准化输入格式
|
||
contents = []
|
||
for item in items:
|
||
if isinstance(item, str):
|
||
contents.append({"type": "text", "text": item})
|
||
elif isinstance(item, dict):
|
||
contents.append(item)
|
||
else:
|
||
raise ValueError(f"不支持的输入类型: {type(item)}")
|
||
|
||
return self.embed_multimodal(contents, **kwargs)
|
||
|
||
# ==================== 工具方法 ====================
|
||
|
||
def is_multimodal_supported(self) -> bool:
|
||
"""检查是否支持多模态"""
|
||
return self._is_volcano
|
||
|
||
def get_provider(self) -> str:
|
||
"""获取 provider"""
|
||
return self._config.provider
|
||
|
||
|
||
# 保留 RedBearMultimodalEmbeddings 作为别名,向后兼容
|
||
RedBearMultimodalEmbeddings = RedBearEmbeddings
|