feat(model): add volcano model

This commit is contained in:
Timebomb2018
2026-03-25 11:45:49 +08:00
parent 6348304b7d
commit 54cff5861a
13 changed files with 1122 additions and 33 deletions

View File

@@ -1,23 +1,190 @@
from typing import Any, Dict, List, Optional, TypeVar, Callable
from typing import Any, Dict, List, Optional, Union
from langchain_core.embeddings import Embeddings
from app.core.models.base import RedBearModelConfig,get_provider_embedding_class,RedBearModelFactory
from app.core.models.base import RedBearModelConfig, get_provider_embedding_class, RedBearModelFactory
from app.models.models_model import ModelProvider
class RedBearEmbeddings(Embeddings):
"""Embedding → 完全符合 LangChain Embeddings"""
"""统一的 Embedding 类,自动支持多模态(根据 provider 判断)"""
def __init__(self, config: RedBearModelConfig):
self._model = self._create_model(config)
self._config = config
self._is_volcano = config.provider.lower() == ModelProvider.VOLCANO
if self._is_volcano:
# 火山引擎使用 Ark SDK
self._client = self._create_volcano_client(config)
self._model = None
else:
# 其他 provider 使用 LangChain
self._model = self._create_model(config)
self._client = None
def _create_model(self, config: RedBearModelConfig) -> Embeddings:
"""根据配置创建模型"""
"""根据配置创建 LangChain 模型"""
embedding_class = get_provider_embedding_class(config.provider)
model_params = RedBearModelFactory.get_model_params(config)
return embedding_class(**model_params)
def _create_volcano_client(self, config: RedBearModelConfig):
"""创建火山引擎客户端"""
from volcenginesdkarkruntime import Ark
return Ark(api_key=config.api_key, base_url=config.base_url)
# ==================== LangChain 标准接口 ====================
def embed_documents(self, texts: list[str]) -> list[list[float]]:
return self._model.embed_documents(texts)
"""批量文本向量化LangChain 标准接口)"""
if self._is_volcano:
# 火山引擎多模态 Embedding
contents = [{"type": "text", "text": text} for text in texts]
response = self._client.multimodal_embeddings.create(
model=self._config.model_name,
input=contents,
encoding_format="float"
)
return [response.data.embedding]
else:
# 其他 provider
return self._model.embed_documents(texts)
def embed_query(self, text: str) -> List[float]:
return self._model.embed_query(text)
"""单个文本向量化LangChain 标准接口)"""
if self._is_volcano:
# 火山引擎多模态 Embedding
result = self.embed_documents([text])
return result[0] if result else []
else:
# 其他 provider
return self._model.embed_query(text)
# ==================== 多模态扩展方法 ====================
def embed_multimodal(
self,
contents: List[Dict[str, Any]],
**kwargs
) -> List[List[float]]:
"""
多模态向量化(仅火山引擎支持)
Args:
contents: 内容列表,格式:
- 文本: {"type": "text", "text": "..."}
- 图片: {"type": "image_url", "image_url": {"url": "..."}}
- 视频: {"type": "video_url", "video_url": {"url": "..."}}
**kwargs: 其他参数
Returns:
向量列表
"""
if not self._is_volcano:
raise NotImplementedError(
f"多模态 Embedding 仅支持火山引擎,当前 provider: {self._config.provider}"
)
response = self._client.multimodal_embeddings.create(
model=self._config.model_name,
input=contents,
**kwargs
)
return [item.embedding for item in response.data]
async def aembed_multimodal(
self,
contents: List[Dict[str, Any]],
**kwargs
) -> List[List[float]]:
"""异步多模态向量化"""
# 火山引擎 SDK 暂不支持异步,使用同步方法
return self.embed_multimodal(contents, **kwargs)
def embed_text(self, text: str, **kwargs) -> List[float]:
"""文本向量化(便捷方法)"""
if self._is_volcano:
result = self.embed_multimodal(
[{"type": "text", "text": text}],
**kwargs
)
return result[0] if result else []
else:
return self.embed_query(text)
def embed_image(self, image_url: str, **kwargs) -> List[float]:
"""图片向量化(仅火山引擎支持)"""
if not self._is_volcano:
raise NotImplementedError(
f"图片向量化仅支持火山引擎,当前 provider: {self._config.provider}"
)
result = self.embed_multimodal(
[{"type": "image_url", "image_url": {"url": image_url}}],
**kwargs
)
return result[0] if result else []
def embed_video(self, video_url: str, **kwargs) -> List[float]:
"""视频向量化(仅火山引擎支持)"""
if not self._is_volcano:
raise NotImplementedError(
f"视频向量化仅支持火山引擎,当前 provider: {self._config.provider}"
)
result = self.embed_multimodal(
[{"type": "video_url", "video_url": {"url": video_url}}],
**kwargs
)
return result[0] if result else []
def embed_batch(
self,
items: List[Union[str, Dict[str, Any]]],
**kwargs
) -> List[List[float]]:
"""
批量向量化(支持混合类型)
Args:
items: 可以是字符串列表或内容字典列表
**kwargs: 其他参数
Returns:
向量列表
"""
# 如果全是字符串,使用标准方法
if all(isinstance(item, str) for item in items):
return self.embed_documents(items)
# 如果包含字典,需要多模态支持
if not self._is_volcano:
raise NotImplementedError(
f"混合类型批量向量化仅支持火山引擎,当前 provider: {self._config.provider}"
)
# 标准化输入格式
contents = []
for item in items:
if isinstance(item, str):
contents.append({"type": "text", "text": item})
elif isinstance(item, dict):
contents.append(item)
else:
raise ValueError(f"不支持的输入类型: {type(item)}")
return self.embed_multimodal(contents, **kwargs)
# ==================== 工具方法 ====================
def is_multimodal_supported(self) -> bool:
"""检查是否支持多模态"""
return self._is_volcano
def get_provider(self) -> str:
"""获取 provider"""
return self._config.provider
# 保留 RedBearMultimodalEmbeddings 作为别名,向后兼容
RedBearMultimodalEmbeddings = RedBearEmbeddings