feat(model): add volcano model
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
OpenAI Embedder 客户端实现
|
||||
|
||||
基于 LangChain 和 RedBearEmbeddings 的 OpenAI 嵌入模型客户端实现。
|
||||
自动支持火山引擎的多模态 Embedding。
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
@@ -13,6 +14,7 @@ from app.core.memory.llm_tools.embedder_client import (
|
||||
)
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.models.embedding import RedBearEmbeddings
|
||||
from app.models.models_model import ModelProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -25,6 +27,7 @@ class OpenAIEmbedderClient(EmbedderClient):
|
||||
- 批量文本嵌入
|
||||
- 自动重试机制
|
||||
- 错误处理
|
||||
- 火山引擎多模态 Embedding(自动识别)
|
||||
"""
|
||||
|
||||
def __init__(self, model_config: RedBearModelConfig):
|
||||
@@ -36,7 +39,7 @@ class OpenAIEmbedderClient(EmbedderClient):
|
||||
"""
|
||||
super().__init__(model_config)
|
||||
|
||||
# 初始化 RedBearEmbeddings 模型
|
||||
# 初始化 RedBearEmbeddings(自动支持火山引擎多模态)
|
||||
self.model = RedBearEmbeddings(
|
||||
RedBearModelConfig(
|
||||
model_name=self.model_name,
|
||||
@@ -47,8 +50,9 @@ class OpenAIEmbedderClient(EmbedderClient):
|
||||
timeout=self.timeout,
|
||||
)
|
||||
)
|
||||
self.is_multimodal = self.model.is_multimodal_supported()
|
||||
|
||||
logger.info("OpenAI Embedder 客户端初始化完成")
|
||||
logger.info(f"OpenAI Embedder 客户端初始化完成 (provider={self.provider}, multimodal={self.is_multimodal})")
|
||||
|
||||
async def response(
|
||||
self,
|
||||
@@ -77,7 +81,14 @@ class OpenAIEmbedderClient(EmbedderClient):
|
||||
return []
|
||||
|
||||
# 生成嵌入向量
|
||||
embeddings = await self.model.aembed_documents(texts)
|
||||
if self.is_multimodal:
|
||||
# 火山引擎多模态 Embedding
|
||||
embeddings = await self.model.aembed_multimodal(
|
||||
[{"type": "text", "text": text} for text in texts]
|
||||
)
|
||||
else:
|
||||
# 普通 Embedding
|
||||
embeddings = await self.model.aembed_documents(texts)
|
||||
|
||||
logger.debug(f"成功生成 {len(embeddings)} 个嵌入向量")
|
||||
return embeddings
|
||||
|
||||
@@ -2,6 +2,7 @@ from .base import RedBearModelConfig, get_provider_llm_class, RedBearModelFacto
|
||||
from .llm import RedBearLLM
|
||||
from .embedding import RedBearEmbeddings
|
||||
from .rerank import RedBearRerank
|
||||
from .generation import RedBearImageGenerator, RedBearVideoGenerator
|
||||
|
||||
__all__ = [
|
||||
"RedBearModelConfig",
|
||||
@@ -9,5 +10,7 @@ __all__ = [
|
||||
"RedBearEmbeddings",
|
||||
"RedBearRerank",
|
||||
"RedBearModelFactory",
|
||||
"get_provider_llm_class"
|
||||
"get_provider_llm_class",
|
||||
"RedBearImageGenerator",
|
||||
"RedBearVideoGenerator"
|
||||
]
|
||||
@@ -67,7 +67,7 @@ class RedBearModelFactory:
|
||||
**config.extra_params
|
||||
}
|
||||
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA]:
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA, ModelProvider.VOLCANO]:
|
||||
# 使用 httpx.Timeout 对象来设置详细的超时配置
|
||||
# 这样可以分别控制连接超时和读取超时
|
||||
import httpx
|
||||
@@ -160,11 +160,13 @@ def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelTy
|
||||
# dashscope 的 omni 模型使用 OpenAI 兼容模式
|
||||
if provider == ModelProvider.DASHSCOPE and config.is_omni:
|
||||
return ChatOpenAI
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.VOLCANO]:
|
||||
if type == ModelType.LLM:
|
||||
return OpenAI
|
||||
elif type == ModelType.CHAT:
|
||||
return ChatOpenAI
|
||||
else:
|
||||
raise BusinessException(f"不支持的模型提供商及类型: {provider}-{type}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
return ChatTongyi
|
||||
elif provider == ModelProvider.OLLAMA:
|
||||
|
||||
@@ -1,23 +1,190 @@
|
||||
|
||||
from typing import Any, Dict, List, Optional, TypeVar, Callable
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
from app.core.models.base import RedBearModelConfig,get_provider_embedding_class,RedBearModelFactory
|
||||
from app.core.models.base import RedBearModelConfig, get_provider_embedding_class, RedBearModelFactory
|
||||
from app.models.models_model import ModelProvider
|
||||
|
||||
|
||||
class RedBearEmbeddings(Embeddings):
    """Unified embedding wrapper implementing the LangChain ``Embeddings`` interface.

    The provider named in the config decides the backend:

    * Volcano: requests go through the Ark SDK's ``multimodal_embeddings``
      endpoint, which also enables the image/video helpers below.
    * Any other provider: calls are delegated to the LangChain embedding class
      returned by ``get_provider_embedding_class``.
    """

    def __init__(self, config: RedBearModelConfig):
        """Create the backend matching ``config.provider``.

        Args:
            config: Model configuration (provider, model name, credentials).
        """
        self._config = config
        self._is_volcano = config.provider.lower() == ModelProvider.VOLCANO

        # Exactly one of ``_client`` (Ark SDK) / ``_model`` (LangChain) is set.
        if self._is_volcano:
            self._client = self._create_volcano_client(config)
            self._model = None
        else:
            self._model = self._create_model(config)
            self._client = None

    def _create_model(self, config: RedBearModelConfig) -> Embeddings:
        """Build the LangChain embedding model for a non-Volcano provider."""
        embedding_class = get_provider_embedding_class(config.provider)
        model_params = RedBearModelFactory.get_model_params(config)
        return embedding_class(**model_params)

    def _create_volcano_client(self, config: RedBearModelConfig):
        """Build the Ark SDK client used for Volcano embeddings."""
        # Imported lazily so the SDK is only required when Volcano is used.
        from volcenginesdkarkruntime import Ark
        return Ark(api_key=config.api_key, base_url=config.base_url)

    @staticmethod
    def _vectors_from_response(response) -> List[List[float]]:
        """Extract embedding vectors from an Ark multimodal response.

        Accepts both a single embedding object (``response.data.embedding``)
        and an iterable of embedding objects, so callers do not depend on the
        exact response shape.
        NOTE(review): confirm the actual shape against the installed SDK.
        """
        data = response.data
        if hasattr(data, "embedding"):
            return [data.embedding]
        return [item.embedding for item in data]

    def _embed_one(self, content: Dict[str, Any], **kwargs) -> List[float]:
        """Embed a single content item with its own Ark request.

        Returns:
            The item's vector, or ``[]`` when the response carries none.
        """
        response = self._client.multimodal_embeddings.create(
            model=self._config.model_name,
            input=[content],
            **kwargs
        )
        vectors = self._vectors_from_response(response)
        return vectors[0] if vectors else []

    # ==================== LangChain standard interface ====================

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed a batch of texts, returning one vector per input text.

        Args:
            texts: Texts to embed.

        Returns:
            Vectors aligned index-for-index with ``texts``.
        """
        if not self._is_volcano:
            return self._model.embed_documents(texts)

        # One request per text: a single multimodal request would not yield
        # one vector per text, which the LangChain contract requires.
        return [
            self._embed_one({"type": "text", "text": text}, encoding_format="float")
            for text in texts
        ]

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query text."""
        if not self._is_volcano:
            return self._model.embed_query(text)

        result = self.embed_documents([text])
        return result[0] if result else []

    # ==================== Multimodal extensions ====================

    def embed_multimodal(
        self,
        contents: List[Dict[str, Any]],
        **kwargs
    ) -> List[List[float]]:
        """Embed multimodal content items (Volcano only).

        Each item is embedded with its own request so that the result aligns
        index-for-index with ``contents``; the text, image and batch helpers
        of this class rely on that alignment.

        Args:
            contents: Content items, each shaped like one of:
                - text:  {"type": "text", "text": "..."}
                - image: {"type": "image_url", "image_url": {"url": "..."}}
                - video: {"type": "video_url", "video_url": {"url": "..."}}
            **kwargs: Extra arguments forwarded to the Ark SDK call.

        Returns:
            One vector per content item.

        Raises:
            NotImplementedError: If the provider is not Volcano.
        """
        if not self._is_volcano:
            raise NotImplementedError(
                f"多模态 Embedding 仅支持火山引擎,当前 provider: {self._config.provider}"
            )

        return [self._embed_one(content, **kwargs) for content in contents]

    async def aembed_multimodal(
        self,
        contents: List[Dict[str, Any]],
        **kwargs
    ) -> List[List[float]]:
        """Async variant of :meth:`embed_multimodal`.

        The SDK call used here is synchronous, so it is run in a worker
        thread to keep the event loop responsive.
        """
        import asyncio

        return await asyncio.to_thread(self.embed_multimodal, contents, **kwargs)

    def embed_text(self, text: str, **kwargs) -> List[float]:
        """Embed one text; ``kwargs`` are only used on the Volcano path."""
        if not self._is_volcano:
            return self.embed_query(text)

        result = self.embed_multimodal(
            [{"type": "text", "text": text}],
            **kwargs
        )
        return result[0] if result else []

    def embed_image(self, image_url: str, **kwargs) -> List[float]:
        """Embed one image by URL (Volcano only).

        Raises:
            NotImplementedError: If the provider is not Volcano.
        """
        if not self._is_volcano:
            raise NotImplementedError(
                f"图片向量化仅支持火山引擎,当前 provider: {self._config.provider}"
            )

        result = self.embed_multimodal(
            [{"type": "image_url", "image_url": {"url": image_url}}],
            **kwargs
        )
        return result[0] if result else []

    def embed_video(self, video_url: str, **kwargs) -> List[float]:
        """Embed one video by URL (Volcano only).

        Raises:
            NotImplementedError: If the provider is not Volcano.
        """
        if not self._is_volcano:
            raise NotImplementedError(
                f"视频向量化仅支持火山引擎,当前 provider: {self._config.provider}"
            )

        result = self.embed_multimodal(
            [{"type": "video_url", "video_url": {"url": video_url}}],
            **kwargs
        )
        return result[0] if result else []

    def embed_batch(
        self,
        items: List[Union[str, Dict[str, Any]]],
        **kwargs
    ) -> List[List[float]]:
        """Embed a batch that may mix plain strings and content dicts.

        Args:
            items: Strings and/or multimodal content dicts.
            **kwargs: Extra arguments forwarded on the multimodal path.

        Returns:
            One vector per item.

        Raises:
            NotImplementedError: If dict items are present and the provider
                is not Volcano.
            ValueError: If an item is neither ``str`` nor ``dict``.
        """
        # Pure-text batches work with every provider via the standard path.
        if all(isinstance(item, str) for item in items):
            return self.embed_documents(items)

        if not self._is_volcano:
            raise NotImplementedError(
                f"混合类型批量向量化仅支持火山引擎,当前 provider: {self._config.provider}"
            )

        # Normalise everything into the content-dict format.
        contents = []
        for item in items:
            if isinstance(item, str):
                contents.append({"type": "text", "text": item})
            elif isinstance(item, dict):
                contents.append(item)
            else:
                raise ValueError(f"不支持的输入类型: {type(item)}")

        return self.embed_multimodal(contents, **kwargs)

    # ==================== Utilities ====================

    def is_multimodal_supported(self) -> bool:
        """Return True when the backend is Volcano (multimodal-capable)."""
        return self._is_volcano

    def get_provider(self) -> str:
        """Return the provider string from the config."""
        return self._config.provider
|
||||
|
||||
|
||||
# 保留 RedBearMultimodalEmbeddings 作为别名,向后兼容
|
||||
RedBearMultimodalEmbeddings = RedBearEmbeddings
|
||||
|
||||
345
api/app/core/models/generation.py
Normal file
345
api/app/core/models/generation.py
Normal file
@@ -0,0 +1,345 @@
|
||||
"""
|
||||
图片和视频生成模型封装
|
||||
|
||||
支持的 Provider:
|
||||
- Volcano (火山引擎): 使用 volcenginesdkarkruntime
|
||||
- OpenAI: planned, not implemented yet (every provider other than Volcano is currently rejected)
|
||||
"""
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from volcenginesdkarkruntime import Ark
|
||||
from volcenginesdkarkruntime.types.images.images import (
|
||||
SequentialImageGenerationOptions,
|
||||
ContentGenerationTool,
|
||||
OptimizePromptOptions
|
||||
)
|
||||
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.models.models_model import ModelProvider
|
||||
|
||||
|
||||
class RedBearImageGenerator:
    """Image generation wrapper around a provider SDK client.

    Only the Volcano provider (Ark SDK) is handled; any other provider is
    rejected with ``BusinessException``.
    """

    def __init__(self, config: RedBearModelConfig):
        """Store the config and create the provider client."""
        self._config = config
        self._client = self._create_client(config)

    def _create_client(self, config: RedBearModelConfig):
        """Create the SDK client for ``config.provider``.

        Raises:
            BusinessException: If the provider is not supported.
        """
        provider = config.provider.lower()

        if provider == ModelProvider.VOLCANO:
            return Ark(api_key=config.api_key, base_url=config.base_url)

        raise BusinessException(
            f"不支持的图片生成提供商: {provider}",
            code=BizCode.PROVIDER_NOT_SUPPORTED
        )

    def generate(
        self,
        prompt: str,
        image: Optional[Any] = None,
        size: Optional[str] = "2K",
        output_format: str = "png",
        response_format: str = "url",
        watermark: bool = False,
        sequential_image_generation: Optional[str] = None,
        sequential_image_generation_options: Optional[Dict] = None,
        tools: Optional[list] = None,
        optimize_prompt_options: Optional[Dict] = None,
        stream: bool = False,
        **kwargs
    ) -> Dict[str, Any]:
        """Generate image(s) from a prompt.

        Args:
            prompt: Text prompt.
            image: Reference image URL or list of URLs (image-to-image /
                multi-image fusion).
            size: Output size, e.g. "2K", "2048x2048", "1920x1080".
            output_format: Output file format, e.g. "png", "jpg".
            response_format: "url" or "b64_json".
            watermark: Whether to add a watermark.
            sequential_image_generation: Image-series mode, "auto" or
                "disabled".
            sequential_image_generation_options: Image-series options, e.g.
                {"max_images": 4}.
            tools: Tool list, e.g. [{"type": "web_search"}]; dict entries are
                converted to ``ContentGenerationTool``.
            optimize_prompt_options: Prompt optimisation options, e.g.
                {"mode": "fast"}.
            stream: Whether to request streaming generation.
            **kwargs: Extra parameters; they override the ones built here.

        Returns:
            ``response.model_dump()`` when the SDK response offers it,
            otherwise the raw SDK response.

        Raises:
            BusinessException: If the provider is not supported.
        """
        provider = self._config.provider.lower()

        if provider != ModelProvider.VOLCANO:
            raise BusinessException(
                f"不支持的提供商: {provider}",
                code=BizCode.PROVIDER_NOT_SUPPORTED
            )

        params = {
            "model": self._config.model_name,
            "prompt": prompt,
            "size": size,
            "output_format": output_format,
            "response_format": response_format,
            "watermark": watermark,
        }

        if image is not None:
            params["image"] = image

        if sequential_image_generation:
            params["sequential_image_generation"] = sequential_image_generation
        if sequential_image_generation_options:
            params["sequential_image_generation_options"] = SequentialImageGenerationOptions(
                **sequential_image_generation_options
            )

        if tools:
            # Accept plain dicts for convenience; pass SDK objects through.
            params["tools"] = [
                ContentGenerationTool(**tool) if isinstance(tool, dict) else tool
                for tool in tools
            ]

        if optimize_prompt_options:
            params["optimize_prompt_options"] = OptimizePromptOptions(**optimize_prompt_options)

        if stream:
            params["stream"] = True

        # Caller-supplied extras win over the defaults assembled above.
        params.update(kwargs)
        response = self._client.images.generate(**params)

        return response.model_dump() if hasattr(response, 'model_dump') else response

    async def agenerate(
        self,
        prompt: str,
        image: Optional[Any] = None,
        size: Optional[str] = "2K",
        output_format: str = "png",
        response_format: str = "url",
        watermark: bool = False,
        **kwargs
    ) -> Dict[str, Any]:
        """Async variant of :meth:`generate`.

        ``generate`` performs a synchronous SDK call, so it is executed in a
        worker thread instead of blocking the event loop. Remaining options
        of ``generate`` can be passed through ``kwargs``.
        """
        import asyncio

        return await asyncio.to_thread(
            self.generate,
            prompt,
            image,
            size,
            output_format,
            response_format,
            watermark,
            **kwargs
        )
|
||||
|
||||
|
||||
class RedBearVideoGenerator:
    """Video generation wrapper around a provider SDK client.

    Only the Volcano provider (Ark SDK content-generation tasks) is handled;
    any other provider is rejected with ``BusinessException``.
    """

    def __init__(self, config: RedBearModelConfig):
        """Store the config and create the provider client."""
        self._config = config
        self._client = self._create_client(config)

    def _create_client(self, config: RedBearModelConfig):
        """Create the SDK client for ``config.provider``.

        Raises:
            BusinessException: If the provider is not supported.
        """
        provider = config.provider.lower()

        if provider == ModelProvider.VOLCANO:
            return Ark(api_key=config.api_key, base_url=config.base_url)

        raise BusinessException(
            f"不支持的视频生成提供商: {provider}",
            code=BizCode.PROVIDER_NOT_SUPPORTED
        )

    def _ensure_supported_provider(self) -> None:
        """Raise ``BusinessException`` unless the configured provider is Volcano."""
        provider = self._config.provider.lower()
        if provider != ModelProvider.VOLCANO:
            raise BusinessException(
                f"不支持的提供商: {provider}",
                code=BizCode.PROVIDER_NOT_SUPPORTED
            )

    @staticmethod
    def _build_content(
        prompt: str,
        image_url: Optional[str] = None,
        first_frame_url: Optional[str] = None,
        last_frame_url: Optional[str] = None,
        reference_images: Optional[list] = None,
        draft_task_id: Optional[str] = None,
    ) -> list:
        """Assemble the ``content`` list for a task-creation request."""
        # A draft task id replaces the prompt/image content entirely.
        if draft_task_id:
            return [{"type": "draft_task", "draft_task": {"id": draft_task_id}}]

        content = [{"type": "text", "text": prompt}]

        if image_url:
            content.append({"type": "image_url", "image_url": {"url": image_url}})

        if first_frame_url:
            content.append({"type": "image_url", "image_url": {"url": first_frame_url}, "role": "first_frame"})
        if last_frame_url:
            content.append({"type": "image_url", "image_url": {"url": last_frame_url}, "role": "last_frame"})

        if reference_images:
            for ref_url in reference_images:
                content.append({"type": "image_url", "image_url": {"url": ref_url}, "role": "reference_image"})

        return content

    def generate(
        self,
        prompt: str,
        image_url: Optional[str] = None,
        first_frame_url: Optional[str] = None,
        last_frame_url: Optional[str] = None,
        reference_images: Optional[list] = None,
        draft_task_id: Optional[str] = None,
        duration: Optional[int] = None,
        frames: Optional[int] = None,
        ratio: Optional[str] = None,
        resolution: Optional[str] = None,
        generate_audio: bool = False,
        watermark: bool = False,
        camera_fixed: bool = False,
        seed: Optional[int] = None,
        return_last_frame: bool = False,
        service_tier: str = "default",
        execution_expires_after: Optional[int] = None,
        draft: bool = False,
        **kwargs
    ) -> Dict[str, Any]:
        """Create a video generation task.

        Args:
            prompt: Text prompt.
            image_url: First-frame image URL (image-to-video from first frame).
            first_frame_url: First-frame image URL (first/last-frame mode).
            last_frame_url: Last-frame image URL (first/last-frame mode).
            reference_images: Reference image URLs (reference-image mode).
            draft_task_id: Draft task id; generates the final video from a
                draft and replaces all other content.
            duration: Video length in seconds (alternative to ``frames``).
            frames: Number of frames (alternative to ``duration``).
            ratio: Aspect ratio, e.g. "16:9", "9:16", "adaptive".
            resolution: Resolution, e.g. "720p", "1080p".
            generate_audio: Whether to generate audio.
            watermark: Whether to add a watermark.
            camera_fixed: Whether to keep the camera fixed.
            seed: Random seed.
            return_last_frame: Whether to return the last frame.
            service_tier: "default" or "flex" (offline inference).
            execution_expires_after: Task expiry in seconds.
            draft: Whether to generate a draft preview.
            **kwargs: Extra parameters; they override the ones built here.

        Returns:
            The created task (``model_dump()`` when available); it carries
            the task id, and the result must be polled separately.

        Raises:
            BusinessException: If the provider is not supported.
        """
        self._ensure_supported_provider()

        content = self._build_content(
            prompt,
            image_url=image_url,
            first_frame_url=first_frame_url,
            last_frame_url=last_frame_url,
            reference_images=reference_images,
            draft_task_id=draft_task_id,
        )
        params = {"model": self._config.model_name, "content": content, "watermark": watermark}

        # These options are only sent when set to a truthy value.
        truthy_only = {
            "duration": duration,
            "frames": frames,
            "ratio": ratio,
            "resolution": resolution,
            "generate_audio": generate_audio,
            "camera_fixed": camera_fixed,
            "return_last_frame": return_last_frame,
            "execution_expires_after": execution_expires_after,
            "draft": draft,
        }
        params.update({key: value for key, value in truthy_only.items() if value})

        # A seed of 0 is meaningful, so only None means "not provided".
        if seed is not None:
            params["seed"] = seed
        if service_tier != "default":
            params["service_tier"] = service_tier

        # Caller-supplied extras win over the values assembled above.
        params.update(kwargs)
        response = self._client.content_generation.tasks.create(**params)

        return response.model_dump() if hasattr(response, 'model_dump') else response

    async def agenerate(
        self,
        prompt: str,
        image_url: Optional[str] = None,
        duration: Optional[int] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """Async variant of :meth:`generate`.

        The synchronous SDK call runs in a worker thread so the event loop is
        not blocked. Remaining options of ``generate`` go through ``kwargs``.
        """
        import asyncio

        return await asyncio.to_thread(
            self.generate, prompt, image_url=image_url, duration=duration, **kwargs
        )

    def get_task_status(self, task_id: str) -> Dict[str, Any]:
        """Fetch the status of a video generation task.

        Args:
            task_id: Task id.

        Returns:
            Task status information (``model_dump()`` when available).

        Raises:
            BusinessException: If the provider is not supported.
        """
        self._ensure_supported_provider()

        response = self._client.content_generation.tasks.get(task_id=task_id)
        return response.model_dump() if hasattr(response, 'model_dump') else response

    async def aget_task_status(self, task_id: str) -> Dict[str, Any]:
        """Async variant of :meth:`get_task_status`, run in a worker thread."""
        import asyncio

        return await asyncio.to_thread(self.get_task_status, task_id)

    def list_tasks(self, page_size: int = 10, status: Optional[str] = None, **kwargs) -> Dict[str, Any]:
        """List video generation tasks.

        Args:
            page_size: Number of tasks per page.
            status: Optional status filter, e.g. "succeeded", "failed",
                "pending".
            **kwargs: Extra parameters; they override the ones built here.

        Returns:
            The task list (``model_dump()`` when available).

        Raises:
            BusinessException: If the provider is not supported.
        """
        self._ensure_supported_provider()

        params = {"page_size": page_size}
        if status:
            params["status"] = status
        params.update(kwargs)

        response = self._client.content_generation.tasks.list(**params)
        return response.model_dump() if hasattr(response, 'model_dump') else response

    def delete_task(self, task_id: str) -> None:
        """Delete or cancel a video generation task.

        Args:
            task_id: Task id.

        Raises:
            BusinessException: If the provider is not supported.
        """
        self._ensure_supported_provider()

        self._client.content_generation.tasks.delete(task_id=task_id)
|
||||
334
api/app/core/models/scripts/volcano_models.yaml
Normal file
334
api/app/core/models/scripts/volcano_models.yaml
Normal file
@@ -0,0 +1,334 @@
|
||||
provider: volcano
|
||||
models:
|
||||
# Doubao-Seed 2.0 系列
|
||||
- name: doubao-seed-2-0-pro-260215
|
||||
type: chat
|
||||
provider: volcano
|
||||
description: 旗舰级全能通用模型,面向 Agent 时代的复杂推理与长链路任务执行场景。强调多模态理解、长上下文推理、结构化生成与工具增强执行。复杂指令与多约束执行能力突出,可稳定应对多步复杂规划、复杂图文推理、视频内容理解与高难度分析等场景。侧重长链路推理能力与复杂任务稳定性,适配真实业务中的复杂场景。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
logo: volcano
|
||||
|
||||
- name: doubao-seed-2-0-lite-260215
|
||||
type: chat
|
||||
provider: volcano
|
||||
description: 面向高频企业场景兼顾性能与成本的均衡型模型,综合能力超越上一代Doubao-Seed-1.8。胜任非结构化信息处理、内容创作、搜索推荐、数据分析等生产型工作,支持长上下文、多源信息融合、多步指令执行与高保真结构化输出。在保障稳定效果的同时显著优化成本。兼顾生成质量与响应速度,适合作为通用生产级模型。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
logo: volcano
|
||||
|
||||
- name: doubao-seed-2-0-mini-260215
|
||||
type: chat
|
||||
provider: volcano
|
||||
description: 面向低时延、高并发与成本敏感场景,提供极致的模型推理速度。模型效果与Doubao-Seed-1.6相当。支持256k上下文、4档思考长度和多模态理解,适合成本和速度优先的轻量级任务。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
logo: volcano
|
||||
|
||||
- name: doubao-seed-2-0-code-preview-260215
|
||||
type: chat
|
||||
provider: volcano
|
||||
description: 面向真实编程环境优化的 Coding 模型,能稳定调用 Claude Code 等常见 IDE 中的工具。模型特别优化了前端能力,在使用常见的前端框架时能有良好表现。模型支持使用 Skills,可以配合多种自定义技能使用。Seed 2.0 的编程加强版,更适合 Agentic Coding。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 代码模型
|
||||
logo: volcano
|
||||
|
||||
# Doubao-Seed 1.x 系列
|
||||
- name: doubao-seed-1-8-251228
|
||||
type: chat
|
||||
provider: volcano
|
||||
description: Doubao-Seed-1.8 面向多模态 Agent 场景定向优化。Agent 能力上,Tool Use、复杂指令遵循等能力均大幅增强。多模态理解方面,视觉基础能力显著提升,可低帧率理解超长视频,视频运动理解、复杂空间理解及文档结构化解析能力也有所优化,还原生支持智能上下文管理,用户可配置上下文策略。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
logo: volcano
|
||||
|
||||
- name: doubao-seed-1-6-251015
|
||||
type: chat
|
||||
provider: volcano
|
||||
description: Doubao-Seed-1.6全新多模态深度思考模型,同时支持minimal/low/medium/high 四种reasoning effort。 更强模型效果,服务复杂任务和有挑战场景。支持 256k 上下文窗口,输出长度支持最大 32k tokens。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
logo: volcano
|
||||
|
||||
- name: doubao-seed-1-6-lite-251015
|
||||
type: chat
|
||||
provider: volcano
|
||||
description: 更高性价比,常见任务的最佳选择,支持minimal、low、medium、high 四种reasoning_effort思考深度
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
logo: volcano
|
||||
|
||||
- name: doubao-seed-1-6-flash-250828
|
||||
type: chat
|
||||
provider: volcano
|
||||
description: Doubao-Seed-1.6-flash推理速度极致的多模态深度思考模型,TPOT低至10ms; 同时支持文本和视觉理解,文本理解能力超过上一代lite,视觉理解比肩友商pro系列模型。支持 256k 上下文窗口,输出长度支持最大 16k tokens。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
logo: volcano
|
||||
|
||||
- name: doubao-seed-code-preview-251028
|
||||
type: chat
|
||||
provider: volcano
|
||||
description: 面向Agentic编程任务进行了深度优化。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 代码模型
|
||||
logo: volcano
|
||||
|
||||
- name: doubao-seed-1-6-vision-250815
|
||||
type: chat
|
||||
provider: volcano
|
||||
description: 全新Doubao-Seed-1.6系列视觉深度思考模型,视觉理解能力显著增强,并支持image_process视觉工具
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
logo: volcano
|
||||
|
||||
# Doubao 1.5 系列
|
||||
- name: doubao-1-5-vision-pro-32k-250115
|
||||
type: chat
|
||||
provider: volcano
|
||||
description: 全新升级的多模态大模型,支持任意分辨率和极端长宽比图像识别,增强视觉推理、文档识别、细节信息理解和指令遵循能力。支持 32k 上下文窗口,输出长度支持最大 12k tokens。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
logo: volcano
|
||||
|
||||
- name: doubao-1-5-pro-32k-250115
|
||||
type: chat
|
||||
provider: volcano
|
||||
description: 全新一代主力模型,性能全面升级,在知识、代码、推理等方面表现卓越。最大支持 128k 上下文窗口,输出长度支持最大 12k tokens。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
logo: volcano
|
||||
|
||||
- name: doubao-1-5-lite-32k-250115
|
||||
type: chat
|
||||
provider: volcano
|
||||
description: 全新一代轻量版模型,极致响应速度,效果与时延均达到全球一流水平。支持 32k 上下文窗口,输出长度支持最大 12k tokens。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
logo: volcano
|
||||
|
||||
# Doubao-Seedance 视频生成系列
|
||||
- name: doubao-seedance-1-5-pro-251215
|
||||
type: video
|
||||
provider: volcano
|
||||
description: 豆包视频生成模型Seedance 1.5 pro 作为全球领先的视频生成模型,可生成音画高精同步的视频内容。支持多人多语言对白,全面覆盖环境音、动作音、合成音、乐器音、背景音及人声,支持首尾帧,实现影视级叙事效果,满足影视、漫剧、电商及广告领域的高阶创作需求。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 视频生成
|
||||
logo: volcano
|
||||
|
||||
- name: doubao-seedance-1-0-pro-250528
|
||||
type: video
|
||||
provider: volcano
|
||||
description: 一款支持多镜头叙事的视频生成基础模型,在各维度表现出色。它在语义理解与指令遵循能力上取得突破,能生成运动流畅、细节丰富、风格多样且具备影视级美感的 1080P 高清视频
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 视频生成
|
||||
logo: volcano
|
||||
|
||||
- name: doubao-seedance-1-0-pro-fast-251015
|
||||
type: video
|
||||
provider: volcano
|
||||
description: 一款价格触底、效能封顶的全面模型,在视频生成质量、速度、价格之间取得了卓越平衡。它继承了Seedance 1.0 pro 核心优势,同时生成速度提升、价格更具竞争力,为创作者带来效率与成本双重优化的体验。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 视频生成
|
||||
logo: volcano
|
||||
|
||||
- name: doubao-seedance-1-0-lite-i2v-250428
|
||||
type: video
|
||||
provider: volcano
|
||||
description: 基于首帧图片、尾帧图片(可选)、参考图片(可选)和文本提示词(可选)相结合的方式生成视频
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 视频生成
|
||||
- 图生视频
|
||||
logo: volcano
|
||||
|
||||
- name: doubao-seedance-1-0-lite-t2v-250428
|
||||
type: video
|
||||
provider: volcano
|
||||
description: 基于文本提示词生成视频
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 视频生成
|
||||
- 文生视频
|
||||
logo: volcano
|
||||
|
||||
# Doubao-Seedream 图像生成系列
|
||||
- name: doubao-seedream-5-0-260128
|
||||
type: image
|
||||
provider: volcano
|
||||
description: 字节跳动发布的最新图像创作模型。该模型首次搭载联网检索功能,能融合实时网络信息,提升生图时效性。同时,模型的聪明度进一步升级,能够精准解析复杂指令和视觉内容。此外,模型在世界知识广度、参考一致性及专业场景生成质量上均有增强,可更好地满足企业级视觉创作需求。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 图像生成
|
||||
logo: volcano
|
||||
|
||||
- name: doubao-seedream-4-5-251128
|
||||
type: image
|
||||
provider: volcano
|
||||
description: 字节跳动最新推出的图像多模态模型,整合了文生图、图生图、组图输出等能力,融合常识和推理能力。相比前代4.0模型生成效果大幅提升,具备更好的编辑一致性和多图融合效果,能更精准的控制画面细节,小字、小人脸生成更自然,图片排版、色彩更和谐,美感提升。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 图像生成
|
||||
logo: volcano
|
||||
|
||||
- name: doubao-seedream-4-0-250828
|
||||
type: image
|
||||
provider: volcano
|
||||
description: 基于领先架构的SOTA级多模态图像创作模型,其生成美感、指令遵循、结构完整度、主体保持一致性处于世界头部水平。模型采用同一套架构实现文生图与编辑能力的统一,原生支持文本 、单图和多图输入,并能通过对提示词的深度推理,自动适配最优的图像比例尺寸与生成数量,可一次性连续输出最多 15 张内容关联的图像,支持 4K 超高清输出。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 图像生成
|
||||
logo: volcano
|
||||
|
||||
- name: doubao-seedream-3-0-t2i-250415
|
||||
type: image
|
||||
provider: volcano
|
||||
description: 一款支持原生高分辨率的中英双语图像生成基础模型,综合能力媲美GPT-4o,处于世界第一梯队。支持原生 2K 分辨率输出;响应速度更快;小字生成更准确,文本排版效果增强;指令遵循能力强,美感&结构提升,保真度和细节表现较好。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 图像生成
|
||||
- 文生图
|
||||
logo: volcano
|
||||
|
||||
# Doubao 翻译系列
|
||||
- name: doubao-seed-translation-250915
|
||||
type: chat
|
||||
provider: volcano
|
||||
description: 通用多语言翻译模型,支持30余种语言互译,支持 4K 上下文窗口,输出长度支持最大 3K tokens
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 翻译模型
|
||||
logo: volcano
|
||||
|
||||
# Doubao Embedding 系列
|
||||
- name: doubao-embedding-vision-251215
|
||||
type: embedding
|
||||
provider: volcano
|
||||
description: 主要面向图文多模向量检索的使用场景,支持图片输入及中、英双语文本输入,最长 128K 上下文长度。
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 向量模型
|
||||
- 多模态模型
|
||||
logo: volcano
|
||||
@@ -61,24 +61,16 @@ class ElasticSearchConfig(BaseModel):
|
||||
class ElasticSearchVector(BaseVector):
|
||||
def __init__(self, index_name: str, config: ElasticSearchConfig, embedding_config: ModelApiKey, reranker_config: ModelApiKey):
|
||||
super().__init__(index_name.lower())
|
||||
# self.embeddings = XinferenceEmbeddings(
|
||||
# server_url=os.getenv("XINFERENCE_URL", "http://127.0.0.1"), # Default Xinference port
|
||||
# model_uid="bge-m3" # replace model_uid with the model UID return from launching the model
|
||||
# )
|
||||
# Remove debug printing to avoid leaking sensitive information
|
||||
# print("embedding:" + embedding_config.model_name + "|" + embedding_config.provider + "|" + embedding_config.api_key + "|" + embedding_config.api_base)
|
||||
|
||||
# 初始化 Embedding 模型(自动支持火山引擎多模态)
|
||||
self.embeddings = RedBearEmbeddings(RedBearModelConfig(
|
||||
model_name=embedding_config.model_name,
|
||||
provider=embedding_config.provider,
|
||||
api_key=embedding_config.api_key,
|
||||
base_url=embedding_config.api_base
|
||||
))
|
||||
# self.reranker = XinferenceRerank(
|
||||
# server_url=os.getenv("XINFERENCE_URL", "http://127.0.0.1"),
|
||||
# model_uid="bge-reranker-large"
|
||||
# )
|
||||
# Remove debug printing to avoid leaking sensitive information
|
||||
# print("reranker:"+ reranker_config.model_name + "|" + reranker_config.provider + "|" + reranker_config.api_key + "|" + reranker_config.api_base)
|
||||
self.is_multimodal_embedding = self.embeddings.is_multimodal_supported()
|
||||
|
||||
self.reranker = RedBearRerank(RedBearModelConfig(
|
||||
model_name=reranker_config.model_name,
|
||||
provider=reranker_config.provider,
|
||||
@@ -144,7 +136,11 @@ class ElasticSearchVector(BaseVector):
|
||||
def add_chunks(self, chunks: list[DocumentChunk], **kwargs):
    """Embed the text of ``chunks`` and persist chunks plus vectors.

    The text of every chunk (``page_content``) is embedded exactly once:
    through ``embed_batch`` when the embedding model reports multimodal
    support (the Volcano Engine multimodal path), otherwise through
    ``embed_documents``. The resulting vectors are handed to ``self.create``
    together with the chunks.

    Args:
        chunks: Chunks to store; only ``page_content`` is read here.
        **kwargs: Forwarded unchanged to ``self.create``.
    """
    texts = [chunk.page_content for chunk in chunks]

    # Issue a single embedding request. The earlier unconditional
    # ``embed_documents`` call was redundant: its result was always
    # overwritten by one of the branches below.
    if self.is_multimodal_embedding:
        # Volcano Engine multimodal embedding.
        embeddings = self.embeddings.embed_batch(texts)
    else:
        embeddings = self.embeddings.embed_documents(texts)

    self.create(chunks, embeddings, **kwargs)
|
||||
|
||||
def create(self, chunks: list[DocumentChunk], embeddings: list[list[float]], **kwargs):
|
||||
@@ -394,7 +390,11 @@ class ElasticSearchVector(BaseVector):
|
||||
updated count.
|
||||
"""
|
||||
indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available,etc "index1,index2,index3"
|
||||
chunk.vector = self.embeddings.embed_query(chunk.page_content)
|
||||
if self.is_multimodal_embedding:
|
||||
# 火山引擎多模态 Embedding
|
||||
chunk.vector = self.embeddings.embed_text(chunk.page_content)
|
||||
else:
|
||||
chunk.vector = self.embeddings.embed_query(chunk.page_content)
|
||||
|
||||
body = {
|
||||
"script": {
|
||||
@@ -454,7 +454,11 @@ class ElasticSearchVector(BaseVector):
|
||||
|
||||
def search_by_vector(self, query: str, **kwargs: Any) -> list[DocumentChunk]:
|
||||
"""Search the nearest neighbors to a vector."""
|
||||
query_vector = self.embeddings.embed_query(query)
|
||||
if self.is_multimodal_embedding:
|
||||
# 火山引擎多模态 Embedding
|
||||
query_vector = self.embeddings.embed_text(query)
|
||||
else:
|
||||
query_vector = self.embeddings.embed_query(query)
|
||||
top_k = kwargs.get("top_k", 1024)
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.3)
|
||||
indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available,etc "index1,index2,index3"
|
||||
|
||||
@@ -26,9 +26,9 @@ class ModelType(StrEnum):
|
||||
RERANK = "rerank"
|
||||
# TTS = "tts"
|
||||
# SPEECH2TEXT = "speech2text"
|
||||
# IMAGE = "image"
|
||||
IMAGE = "image"
|
||||
# AUDIO = "audio"
|
||||
# VISION = "vision"
|
||||
VIDEO = "video"
|
||||
|
||||
|
||||
class ModelProvider(StrEnum):
|
||||
@@ -45,6 +45,7 @@ class ModelProvider(StrEnum):
|
||||
XINFERENCE = "xinference"
|
||||
GPUSTACK = "gpustack"
|
||||
BEDROCK = "bedrock"
|
||||
VOLCANO = "volcano"
|
||||
COMPOSITE = "composite"
|
||||
|
||||
|
||||
|
||||
@@ -435,7 +435,6 @@ class ModelConfigRepository:
|
||||
ModelConfig.is_public
|
||||
),
|
||||
ModelConfig.provider == provider,
|
||||
ModelConfig.is_active,
|
||||
~ModelConfig.is_composite
|
||||
)
|
||||
).all()
|
||||
|
||||
164
api/app/services/generation_service.py
Normal file
164
api/app/services/generation_service.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""
|
||||
图片和视频生成服务
|
||||
|
||||
提供统一的生成接口,支持多种 Provider
|
||||
"""
|
||||
from typing import Dict, Any, Optional
|
||||
from sqlalchemy.orm import Session
|
||||
import uuid
|
||||
|
||||
from app.core.models import RedBearModelConfig, RedBearImageGenerator, RedBearVideoGenerator
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.models.models_model import ModelType
|
||||
from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
|
||||
|
||||
class GenerationService:
    """Image and video generation service.

    Exposes one coroutine per generation task and hands the actual work to
    the RedBear image/video generators, so callers do not talk to a provider
    directly.
    """

    def __init__(self, db: Session):
        """Keep the database session used for model-config and API-key lookups."""
        self.db = db

    def _resolve_config(
        self,
        model_config_id: str,
        expected_type: Optional[ModelType] = None,
    ) -> RedBearModelConfig:
        """Turn a model-config id into a ready-to-use ``RedBearModelConfig``.

        Shared by every public method so the lookup, validation and API-key
        selection live in one place.

        Args:
            model_config_id: Model configuration id as a UUID string.
            expected_type: When given, the stored model type must equal it.

        Returns:
            A ``RedBearModelConfig`` built from the selected API key record.

        Raises:
            ValueError: If ``model_config_id`` is not a valid UUID string
                (raised by ``uuid.UUID``).
            BusinessException: If the model config does not exist, its type
                does not match ``expected_type``, or no API key is available.
        """
        # Parse once and reuse for both repository/service lookups.
        config_uuid = uuid.UUID(model_config_id)

        model_config = ModelConfigRepository.get_by_id(self.db, config_uuid)
        if not model_config:
            raise BusinessException("模型配置不存在", code=BizCode.NOT_FOUND)

        if expected_type is not None and model_config.type != expected_type:
            raise BusinessException(
                f"模型类型错误,期望 {expected_type},实际 {model_config.type}",
                code=BizCode.INVALID_PARAMETER
            )

        api_key_info = ModelApiKeyService.get_available_api_key(self.db, config_uuid)
        if not api_key_info:
            raise BusinessException("没有可用的 API Key", code=BizCode.NOT_FOUND)

        return RedBearModelConfig(
            model_name=api_key_info.model_name,
            provider=api_key_info.provider,
            api_key=api_key_info.api_key,
            base_url=api_key_info.api_base,
            extra_params=api_key_info.config or {}
        )

    async def generate_image(
        self,
        model_config_id: str,
        prompt: str,
        size: Optional[str] = "1024x1024",
        n: int = 1,
        **kwargs
    ) -> Dict[str, Any]:
        """Generate images with the configured image model.

        Args:
            model_config_id: Model configuration id (UUID string).
            prompt: Text prompt.
            size: Image size.
            n: Number of images to generate.
            **kwargs: Extra arguments forwarded to the generator.

        Returns:
            The result returned by ``RedBearImageGenerator.agenerate``.

        Raises:
            BusinessException: If the config is missing, is not an image
                model, or has no available API key.
        """
        config = self._resolve_config(model_config_id, ModelType.IMAGE)
        generator = RedBearImageGenerator(config)
        return await generator.agenerate(prompt, size, n, **kwargs)

    async def generate_video(
        self,
        model_config_id: str,
        prompt: str,
        duration: Optional[int] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """Start a video generation task with the configured video model.

        Args:
            model_config_id: Model configuration id (UUID string).
            prompt: Text prompt.
            duration: Video length in seconds.
            **kwargs: Extra arguments forwarded to the generator.

        Returns:
            The result returned by ``RedBearVideoGenerator.agenerate``
            (expected to carry the task id).

        Raises:
            BusinessException: If the config is missing, is not a video
                model, or has no available API key.
        """
        config = self._resolve_config(model_config_id, ModelType.VIDEO)
        generator = RedBearVideoGenerator(config)
        return await generator.agenerate(prompt, duration, **kwargs)

    async def get_video_task_status(
        self,
        model_config_id: str,
        task_id: str
    ) -> Dict[str, Any]:
        """Query the status of a video generation task.

        Args:
            model_config_id: Model configuration id (UUID string).
            task_id: Id of the task to query.

        Returns:
            The status information returned by
            ``RedBearVideoGenerator.aget_task_status``.

        Raises:
            BusinessException: If the config is missing or has no available
                API key.
        """
        # NOTE(review): unlike generate_video, the model type is not checked
        # here; kept as-is to preserve behavior — confirm whether intended.
        config = self._resolve_config(model_config_id)
        generator = RedBearVideoGenerator(config)
        return await generator.aget_task_status(task_id)
|
||||
@@ -154,10 +154,17 @@ class ModelConfigService:
|
||||
}
|
||||
|
||||
elif model_type_lower == "embedding":
|
||||
# Embedding 模型验证(在线程中运行同步方法)
|
||||
# Embedding 模型验证
|
||||
# 统一使用 RedBearEmbeddings(自动支持火山引擎多模态)
|
||||
embedding = RedBearEmbeddings(model_config)
|
||||
test_texts = [test_message, "测试文本"]
|
||||
vectors = await asyncio.to_thread(embedding.embed_documents, test_texts)
|
||||
|
||||
# 火山引擎使用 embed_batch,其他使用 embed_documents
|
||||
if provider.lower() == "volcano":
|
||||
vectors = await asyncio.to_thread(embedding.embed_batch, test_texts)
|
||||
else:
|
||||
vectors = await asyncio.to_thread(embedding.embed_documents, test_texts)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
return {
|
||||
@@ -193,6 +200,56 @@ class ModelConfigService:
|
||||
},
|
||||
"error": None
|
||||
}
|
||||
|
||||
elif model_type_lower == "image":
|
||||
# 图片生成模型验证
|
||||
from app.core.models.generation import RedBearImageGenerator
|
||||
|
||||
generator = RedBearImageGenerator(model_config)
|
||||
result = await generator.agenerate(
|
||||
prompt="a cute panda",
|
||||
size="2K"
|
||||
)
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.info(f"成功生成图片,结果: {result}")
|
||||
|
||||
return {
|
||||
"valid": True,
|
||||
"message": "图片生成模型配置验证成功",
|
||||
"response": f"成功生成图片,结果: {result}",
|
||||
"elapsed_time": elapsed_time,
|
||||
"usage": {
|
||||
"prompt_length": len("a cute panda"),
|
||||
"image_count": 1
|
||||
},
|
||||
"error": None
|
||||
}
|
||||
|
||||
elif model_type_lower == "video":
|
||||
# 视频生成模型验证
|
||||
from app.core.models.generation import RedBearVideoGenerator
|
||||
|
||||
generator = RedBearVideoGenerator(model_config)
|
||||
result = await generator.agenerate(
|
||||
prompt="a cute panda playing in bamboo forest",
|
||||
duration=5
|
||||
)
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# 视频生成是异步任务,返回任务ID
|
||||
task_id = result.get("task_id") if isinstance(result, dict) else None
|
||||
|
||||
return {
|
||||
"valid": True,
|
||||
"message": "视频生成模型配置验证成功",
|
||||
"response": f"成功创建视频生成任务,任务ID: {task_id}",
|
||||
"elapsed_time": elapsed_time,
|
||||
"usage": {
|
||||
"prompt_length": len("a cute panda playing in bamboo forest"),
|
||||
"task_id": task_id
|
||||
},
|
||||
"error": None
|
||||
}
|
||||
|
||||
else:
|
||||
return {
|
||||
|
||||
@@ -297,6 +297,7 @@ PROVIDER_STRATEGIES = {
|
||||
"bedrock": BedrockFormatStrategy,
|
||||
"anthropic": BedrockFormatStrategy,
|
||||
"openai": OpenAIFormatStrategy,
|
||||
"volcano": OpenAIFormatStrategy,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -147,6 +147,7 @@ dependencies = [
|
||||
"modelscope>=1.34.0",
|
||||
"python-magic>=0.4.14; sys_platform == 'linux' or sys_platform == 'darwin'",
|
||||
"python-magic-bin>=0.4.14; sys_platform=='win32'",
|
||||
"volcengine-python-sdk[ark]==5.0.19"
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
|
||||
Reference in New Issue
Block a user