Merge pull request #680 from SuanmoSuanyangTechnology/feature/agent-tool_xjn

feat(model)
This commit is contained in:
Mark
2026-03-25 18:55:10 +08:00
committed by GitHub
21 changed files with 1189 additions and 67 deletions

View File

@@ -574,8 +574,12 @@ async def get_file_url(
# For local storage, generate signed URL with expiration # For local storage, generate signed URL with expiration
url = generate_signed_url(str(file_id), expires) url = generate_signed_url(str(file_id), expires)
else: else:
# For remote storage (OSS/S3), get presigned URL # For remote storage (OSS/S3), get presigned URL with forced download
url = await storage_service.get_file_url(file_key, expires=expires) url = await storage_service.get_file_url(
file_key,
expires=expires,
file_name=file_metadata.file_name,
)
url = _match_scheme(request, url) url = _match_scheme(request, url)
api_logger.info(f"Generated file URL: file_id={file_id}") api_logger.info(f"Generated file URL: file_id={file_id}")
@@ -786,7 +790,7 @@ async def permanent_download_file(
# For remote storage, redirect to presigned URL with long expiration # For remote storage, redirect to presigned URL with long expiration
try: try:
# Use a very long expiration (7 days max for most cloud providers) # Use a very long expiration (7 days max for most cloud providers)
presigned_url = await storage_service.get_file_url(file_key, expires=604800) presigned_url = await storage_service.get_file_url(file_key, expires=604800, file_name=file_metadata.file_name)
presigned_url = _match_scheme(request, presigned_url) presigned_url = _match_scheme(request, presigned_url)
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND) return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
except Exception as e: except Exception as e:

View File

@@ -91,9 +91,11 @@ async def get_mcp_servers(
try: try:
cookies = api.get_cookies(token) cookies = api.get_cookies(token)
headers=api.builder_headers(api.headers)
headers['Authorization'] = f'Bearer {token}'
r = api.session.put( r = api.session.put(
url=api.mcp_base_url, url=api.mcp_base_url,
headers=api.builder_headers(api.headers), headers=headers,
json=body, json=body,
cookies=cookies) cookies=cookies)
raise_for_http_status(r) raise_for_http_status(r)
@@ -173,6 +175,7 @@ async def get_operational_mcp_servers(
url = f'{api.mcp_base_url}/operational' url = f'{api.mcp_base_url}/operational'
headers = api.builder_headers(api.headers) headers = api.builder_headers(api.headers)
headers['Authorization'] = f'Bearer {token}'
try: try:
cookies = api.get_cookies(access_token=token, cookies_required=True) cookies = api.get_cookies(access_token=token, cookies_required=True)
@@ -260,7 +263,9 @@ async def create_mcp_market_config(
api.login(create_data.token) api.login(create_data.token)
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None} body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
cookies = api.get_cookies(create_data.token) cookies = api.get_cookies(create_data.token)
r = api.session.put(url=api.mcp_base_url, headers=api.builder_headers(api.headers), json=body, cookies=cookies) headers = api.builder_headers(api.headers)
headers['Authorization'] = f'Bearer {create_data.token}'
r = api.session.put(url=api.mcp_base_url, headers=headers, json=body, cookies=cookies)
raise_for_http_status(r) raise_for_http_status(r)
except Exception as e: except Exception as e:
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}") api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")
@@ -290,9 +295,11 @@ async def create_mcp_market_config(
'search': "" 'search': ""
} }
cookies = api.get_cookies(token) cookies = api.get_cookies(token)
headers = api.builder_headers(api.headers)
headers['Authorization'] = f'Bearer {token}'
r = api.session.put( r = api.session.put(
url=api.mcp_base_url, url=api.mcp_base_url,
headers=api.builder_headers(api.headers), headers=headers,
json=body, json=body,
cookies=cookies) cookies=cookies)
raise_for_http_status(r) raise_for_http_status(r)
@@ -393,7 +400,9 @@ async def update_mcp_market_config(
api.login(update_data.token) api.login(update_data.token)
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None} body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
cookies = api.get_cookies(update_data.token) cookies = api.get_cookies(update_data.token)
r = api.session.put(url=api.mcp_base_url, headers=api.builder_headers(api.headers), json=body, cookies=cookies) headers = api.builder_headers(api.headers)
headers['Authorization'] = f'Bearer {update_data.token}'
r = api.session.put(url=api.mcp_base_url, headers=headers, json=body, cookies=cookies)
raise_for_http_status(r) raise_for_http_status(r)
except Exception as e: except Exception as e:
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}") api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")

View File

@@ -669,6 +669,7 @@ async def config_query(
content = { content = {
"app_type": release.app.type, "app_type": release.app.type,
"variables": release.config.get("variables"), "variables": release.config.get("variables"),
"memory": release.config.get("memory", {}).get("enabled"),
"features": release.config.get("features") "features": release.config.get("features")
} }
elif release.app.type == AppType.MULTI_AGENT: elif release.app.type == AppType.MULTI_AGENT:

View File

@@ -2,6 +2,7 @@
OpenAI Embedder 客户端实现 OpenAI Embedder 客户端实现
基于 LangChain 和 RedBearEmbeddings 的 OpenAI 嵌入模型客户端实现。 基于 LangChain 和 RedBearEmbeddings 的 OpenAI 嵌入模型客户端实现。
自动支持火山引擎的多模态 Embedding。
""" """
from typing import List from typing import List
@@ -13,6 +14,7 @@ from app.core.memory.llm_tools.embedder_client import (
) )
from app.core.models.base import RedBearModelConfig from app.core.models.base import RedBearModelConfig
from app.core.models.embedding import RedBearEmbeddings from app.core.models.embedding import RedBearEmbeddings
from app.models.models_model import ModelProvider
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -25,6 +27,7 @@ class OpenAIEmbedderClient(EmbedderClient):
- 批量文本嵌入 - 批量文本嵌入
- 自动重试机制 - 自动重试机制
- 错误处理 - 错误处理
- 火山引擎多模态 Embedding自动识别
""" """
def __init__(self, model_config: RedBearModelConfig): def __init__(self, model_config: RedBearModelConfig):
@@ -36,7 +39,7 @@ class OpenAIEmbedderClient(EmbedderClient):
""" """
super().__init__(model_config) super().__init__(model_config)
# 初始化 RedBearEmbeddings 模型 # 初始化 RedBearEmbeddings(自动支持火山引擎多模态)
self.model = RedBearEmbeddings( self.model = RedBearEmbeddings(
RedBearModelConfig( RedBearModelConfig(
model_name=self.model_name, model_name=self.model_name,
@@ -47,8 +50,9 @@ class OpenAIEmbedderClient(EmbedderClient):
timeout=self.timeout, timeout=self.timeout,
) )
) )
self.is_multimodal = self.model.is_multimodal_supported()
logger.info("OpenAI Embedder 客户端初始化完成") logger.info(f"OpenAI Embedder 客户端初始化完成 (provider={self.provider}, multimodal={self.is_multimodal})")
async def response( async def response(
self, self,
@@ -77,7 +81,14 @@ class OpenAIEmbedderClient(EmbedderClient):
return [] return []
# 生成嵌入向量 # 生成嵌入向量
embeddings = await self.model.aembed_documents(texts) if self.is_multimodal:
# 火山引擎多模态 Embedding
embeddings = await self.model.aembed_multimodal(
[{"type": "text", "text": text} for text in texts]
)
else:
# 普通 Embedding
embeddings = await self.model.aembed_documents(texts)
logger.debug(f"成功生成 {len(embeddings)} 个嵌入向量") logger.debug(f"成功生成 {len(embeddings)} 个嵌入向量")
return embeddings return embeddings

View File

@@ -2,6 +2,7 @@ from .base import RedBearModelConfig, get_provider_llm_class, RedBearModelFacto
from .llm import RedBearLLM from .llm import RedBearLLM
from .embedding import RedBearEmbeddings from .embedding import RedBearEmbeddings
from .rerank import RedBearRerank from .rerank import RedBearRerank
from .generation import RedBearImageGenerator, RedBearVideoGenerator
__all__ = [ __all__ = [
"RedBearModelConfig", "RedBearModelConfig",
@@ -9,5 +10,7 @@ __all__ = [
"RedBearEmbeddings", "RedBearEmbeddings",
"RedBearRerank", "RedBearRerank",
"RedBearModelFactory", "RedBearModelFactory",
"get_provider_llm_class" "get_provider_llm_class",
"RedBearImageGenerator",
"RedBearVideoGenerator"
] ]

View File

@@ -67,7 +67,7 @@ class RedBearModelFactory:
**config.extra_params **config.extra_params
} }
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA]: if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA, ModelProvider.VOLCANO]:
# 使用 httpx.Timeout 对象来设置详细的超时配置 # 使用 httpx.Timeout 对象来设置详细的超时配置
# 这样可以分别控制连接超时和读取超时 # 这样可以分别控制连接超时和读取超时
import httpx import httpx
@@ -160,11 +160,13 @@ def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelTy
# dashscope 的 omni 模型使用 OpenAI 兼容模式 # dashscope 的 omni 模型使用 OpenAI 兼容模式
if provider == ModelProvider.DASHSCOPE and config.is_omni: if provider == ModelProvider.DASHSCOPE and config.is_omni:
return ChatOpenAI return ChatOpenAI
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]: if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.VOLCANO]:
if type == ModelType.LLM: if type == ModelType.LLM:
return OpenAI return OpenAI
elif type == ModelType.CHAT: elif type == ModelType.CHAT:
return ChatOpenAI return ChatOpenAI
else:
raise BusinessException(f"不支持的模型提供商及类型: {provider}-{type}", code=BizCode.PROVIDER_NOT_SUPPORTED)
elif provider == ModelProvider.DASHSCOPE: elif provider == ModelProvider.DASHSCOPE:
return ChatTongyi return ChatTongyi
elif provider == ModelProvider.OLLAMA: elif provider == ModelProvider.OLLAMA:

View File

@@ -1,23 +1,190 @@
from typing import Any, Dict, List, Optional, TypeVar, Callable from typing import Any, Dict, List, Optional, Union
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from app.core.models.base import RedBearModelConfig,get_provider_embedding_class,RedBearModelFactory from app.core.models.base import RedBearModelConfig, get_provider_embedding_class, RedBearModelFactory
from app.models.models_model import ModelProvider
class RedBearEmbeddings(Embeddings): class RedBearEmbeddings(Embeddings):
"""Embedding → 完全符合 LangChain Embeddings""" """统一的 Embedding 类,自动支持多模态(根据 provider 判断)"""
def __init__(self, config: RedBearModelConfig): def __init__(self, config: RedBearModelConfig):
self._model = self._create_model(config)
self._config = config self._config = config
self._is_volcano = config.provider.lower() == ModelProvider.VOLCANO
if self._is_volcano:
# 火山引擎使用 Ark SDK
self._client = self._create_volcano_client(config)
self._model = None
else:
# 其他 provider 使用 LangChain
self._model = self._create_model(config)
self._client = None
def _create_model(self, config: RedBearModelConfig) -> Embeddings: def _create_model(self, config: RedBearModelConfig) -> Embeddings:
"""根据配置创建模型""" """根据配置创建 LangChain 模型"""
embedding_class = get_provider_embedding_class(config.provider) embedding_class = get_provider_embedding_class(config.provider)
model_params = RedBearModelFactory.get_model_params(config) model_params = RedBearModelFactory.get_model_params(config)
return embedding_class(**model_params) return embedding_class(**model_params)
def _create_volcano_client(self, config: RedBearModelConfig):
"""创建火山引擎客户端"""
from volcenginesdkarkruntime import Ark
return Ark(api_key=config.api_key, base_url=config.base_url)
# ==================== LangChain 标准接口 ====================
def embed_documents(self, texts: list[str]) -> list[list[float]]: def embed_documents(self, texts: list[str]) -> list[list[float]]:
return self._model.embed_documents(texts) """批量文本向量化LangChain 标准接口)"""
if self._is_volcano:
# 火山引擎多模态 Embedding
contents = [{"type": "text", "text": text} for text in texts]
response = self._client.multimodal_embeddings.create(
model=self._config.model_name,
input=contents,
encoding_format="float"
)
return [response.data.embedding]
else:
# 其他 provider
return self._model.embed_documents(texts)
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> List[float]:
return self._model.embed_query(text) """单个文本向量化LangChain 标准接口)"""
if self._is_volcano:
# 火山引擎多模态 Embedding
result = self.embed_documents([text])
return result[0] if result else []
else:
# 其他 provider
return self._model.embed_query(text)
# ==================== 多模态扩展方法 ====================
def embed_multimodal(
self,
contents: List[Dict[str, Any]],
**kwargs
) -> List[List[float]]:
"""
多模态向量化(仅火山引擎支持)
Args:
contents: 内容列表,格式:
- 文本: {"type": "text", "text": "..."}
- 图片: {"type": "image_url", "image_url": {"url": "..."}}
- 视频: {"type": "video_url", "video_url": {"url": "..."}}
**kwargs: 其他参数
Returns:
向量列表
"""
if not self._is_volcano:
raise NotImplementedError(
f"多模态 Embedding 仅支持火山引擎,当前 provider: {self._config.provider}"
)
response = self._client.multimodal_embeddings.create(
model=self._config.model_name,
input=contents,
**kwargs
)
return [response.data.embedding]
async def aembed_multimodal(
self,
contents: List[Dict[str, Any]],
**kwargs
) -> List[List[float]]:
"""异步多模态向量化"""
# 火山引擎 SDK 暂不支持异步,使用同步方法
return self.embed_multimodal(contents, **kwargs)
def embed_text(self, text: str, **kwargs) -> List[float]:
"""文本向量化(便捷方法)"""
if self._is_volcano:
result = self.embed_multimodal(
[{"type": "text", "text": text}],
**kwargs
)
return result[0] if result else []
else:
return self.embed_query(text)
def embed_image(self, image_url: str, **kwargs) -> List[float]:
"""图片向量化(仅火山引擎支持)"""
if not self._is_volcano:
raise NotImplementedError(
f"图片向量化仅支持火山引擎,当前 provider: {self._config.provider}"
)
result = self.embed_multimodal(
[{"type": "image_url", "image_url": {"url": image_url}}],
**kwargs
)
return result[0] if result else []
def embed_video(self, video_url: str, **kwargs) -> List[float]:
"""视频向量化(仅火山引擎支持)"""
if not self._is_volcano:
raise NotImplementedError(
f"视频向量化仅支持火山引擎,当前 provider: {self._config.provider}"
)
result = self.embed_multimodal(
[{"type": "video_url", "video_url": {"url": video_url}}],
**kwargs
)
return result[0] if result else []
def embed_batch(
self,
items: List[Union[str, Dict[str, Any]]],
**kwargs
) -> List[List[float]]:
"""
批量向量化(支持混合类型)
Args:
items: 可以是字符串列表或内容字典列表
**kwargs: 其他参数
Returns:
向量列表
"""
# 如果全是字符串,使用标准方法
if all(isinstance(item, str) for item in items):
return self.embed_documents(items)
# 如果包含字典,需要多模态支持
if not self._is_volcano:
raise NotImplementedError(
f"混合类型批量向量化仅支持火山引擎,当前 provider: {self._config.provider}"
)
# 标准化输入格式
contents = []
for item in items:
if isinstance(item, str):
contents.append({"type": "text", "text": item})
elif isinstance(item, dict):
contents.append(item)
else:
raise ValueError(f"不支持的输入类型: {type(item)}")
return self.embed_multimodal(contents, **kwargs)
# ==================== 工具方法 ====================
def is_multimodal_supported(self) -> bool:
"""检查是否支持多模态"""
return self._is_volcano
def get_provider(self) -> str:
"""获取 provider"""
return self._config.provider
# 保留 RedBearMultimodalEmbeddings 作为别名,向后兼容
RedBearMultimodalEmbeddings = RedBearEmbeddings

View File

@@ -0,0 +1,344 @@
"""
图片和视频生成模型封装
支持的 Provider:
- Volcano (火山引擎): 使用 volcenginesdkarkruntime
- OpenAI: 使用 openai SDK
"""
from typing import Any, Dict, Optional
from volcenginesdkarkruntime import Ark
from volcenginesdkarkruntime.types.images.images import (
SequentialImageGenerationOptions,
ContentGenerationTool,
OptimizePromptOptions
)
from app.core.models.base import RedBearModelConfig
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.models.models_model import ModelProvider
class RedBearImageGenerator:
"""图片生成模型封装"""
def __init__(self, config: RedBearModelConfig):
self._config = config
self._client = self._create_client(config)
def _create_client(self, config: RedBearModelConfig):
"""根据 provider 创建客户端"""
provider = config.provider.lower()
if provider == ModelProvider.VOLCANO:
return Ark(api_key=config.api_key, base_url=config.base_url)
# elif provider == ModelProvider.OPENAI:
# from openai import OpenAI
# return OpenAI(api_key=config.api_key, base_url=config.base_url)
else:
raise BusinessException(
f"不支持的图片生成提供商: {provider}",
code=BizCode.PROVIDER_NOT_SUPPORTED
)
def generate(
self,
prompt: str,
image: Optional[Any] = None,
size: Optional[str] = "2K",
output_format: str = "png",
response_format: str = "url",
watermark: bool = False,
sequential_image_generation: Optional[str] = None,
sequential_image_generation_options: Optional[Dict] = None,
tools: Optional[list] = None,
optimize_prompt_options: Optional[Dict] = None,
stream: bool = False,
**kwargs
) -> Dict[str, Any]:
"""
生成图片
Args:
prompt: 提示词
image: 参考图片URL或URL列表图文生图/多图融合)
size: 图片尺寸,支持 "2K", "2048x2048", "1920x1080"至少3686400像素
output_format: 输出格式,如 "png", "jpg"
response_format: 返回格式,"url""b64_json"
watermark: 是否添加水印
sequential_image_generation: 组图生成模式,"auto""disabled"
sequential_image_generation_options: 组图生成选项,如 {"max_images": 4}
tools: 工具列表,如 [{"type": "web_search"}] 用于联网搜索生图
optimize_prompt_options: 提示词优化选项,如 {"mode": "fast"}
stream: 是否使用流式生成
**kwargs: 其他参数
Returns:
生成结果
"""
provider = self._config.provider.lower()
if provider == ModelProvider.VOLCANO:
params = {
"model": self._config.model_name,
"prompt": prompt,
"size": size,
"output_format": output_format,
"response_format": response_format,
"watermark": watermark,
}
if image is not None:
params["image"] = image
if sequential_image_generation:
params["sequential_image_generation"] = sequential_image_generation
if sequential_image_generation_options:
params["sequential_image_generation_options"] = SequentialImageGenerationOptions(
**sequential_image_generation_options
)
if tools:
params["tools"] = [ContentGenerationTool(**tool) if isinstance(tool, dict) else tool for tool in tools]
if optimize_prompt_options:
params["optimize_prompt_options"] = OptimizePromptOptions(**optimize_prompt_options)
if stream:
params["stream"] = True
params.update(kwargs)
response = self._client.images.generate(**params)
# elif provider == ModelProvider.OPENAI:
# response = self._client.images.generate(
# model=self._config.model_name,
# prompt=prompt,
# size=size,
# n=n,
# **kwargs
# )
else:
raise BusinessException(
f"不支持的提供商: {provider}",
code=BizCode.PROVIDER_NOT_SUPPORTED
)
return response.model_dump() if hasattr(response, 'model_dump') else response
async def agenerate(
self,
prompt: str,
image: Optional[Any] = None,
size: Optional[str] = "2K",
output_format: str = "png",
response_format: str = "url",
watermark: bool = False,
**kwargs
) -> Dict[str, Any]:
"""异步生成图片"""
return self.generate(prompt, image, size, output_format, response_format, watermark, **kwargs)
class RedBearVideoGenerator:
"""视频生成模型封装"""
def __init__(self, config: RedBearModelConfig):
self._config = config
self._client = self._create_client(config)
def _create_client(self, config: RedBearModelConfig):
"""根据 provider 创建客户端"""
provider = config.provider.lower()
if provider == ModelProvider.VOLCANO:
return Ark(api_key=config.api_key, base_url=config.base_url)
else:
raise BusinessException(
f"不支持的视频生成提供商: {provider}",
code=BizCode.PROVIDER_NOT_SUPPORTED
)
def generate(
self,
prompt: str,
image_url: Optional[str] = None,
first_frame_url: Optional[str] = None,
last_frame_url: Optional[str] = None,
reference_images: Optional[list] = None,
draft_task_id: Optional[str] = None,
duration: Optional[int] = None,
frames: Optional[int] = None,
ratio: Optional[str] = None,
resolution: Optional[str] = None,
generate_audio: bool = False,
watermark: bool = False,
camera_fixed: bool = False,
seed: Optional[int] = None,
return_last_frame: bool = False,
service_tier: str = "default",
execution_expires_after: Optional[int] = None,
draft: bool = False,
**kwargs
) -> Dict[str, Any]:
"""
生成视频
Args:
prompt: 提示词
image_url: 首帧图片URL图生视频-基于首帧)
first_frame_url: 首帧图片URL图生视频-基于首尾帧)
last_frame_url: 尾帧图片URL图生视频-基于首尾帧)
reference_images: 参考图片URL列表图生视频-基于参考图)
draft_task_id: Draft任务ID基于Draft生成正式视频
duration: 视频时长与frames二选一
frames: 视频帧数与duration二选一
ratio: 视频比例,如 "16:9", "9:16", "adaptive"
resolution: 视频分辨率,如 "720p", "1080p"
generate_audio: 是否生成音频
watermark: 是否添加水印
camera_fixed: 是否固定镜头
seed: 随机种子
return_last_frame: 是否返回最后一帧
service_tier: 服务层级,"default""flex"(离线推理)
execution_expires_after: 任务过期时间(秒)
draft: 是否生成样片
**kwargs: 其他参数
Returns:
生成结果包含任务ID需要轮询获取结果
"""
provider = self._config.provider.lower()
if provider == ModelProvider.VOLCANO:
content = [{"type": "text", "text": prompt}]
if draft_task_id:
content = [{"type": "draft_task", "draft_task": {"id": draft_task_id}}]
else:
if image_url:
content.append({"type": "image_url", "image_url": {"url": image_url}})
if first_frame_url:
content.append({"type": "image_url", "image_url": {"url": first_frame_url}, "role": "first_frame"})
if last_frame_url:
content.append({"type": "image_url", "image_url": {"url": last_frame_url}, "role": "last_frame"})
if reference_images:
for ref_url in reference_images:
content.append({"type": "image_url", "image_url": {"url": ref_url}, "role": "reference_image"})
params = {"model": self._config.model_name, "content": content, "watermark": watermark}
if duration:
params["duration"] = duration
if frames:
params["frames"] = frames
if ratio:
params["ratio"] = ratio
if resolution:
params["resolution"] = resolution
if generate_audio:
params["generate_audio"] = generate_audio
if camera_fixed:
params["camera_fixed"] = camera_fixed
if seed is not None:
params["seed"] = seed
if return_last_frame:
params["return_last_frame"] = return_last_frame
if service_tier != "default":
params["service_tier"] = service_tier
if execution_expires_after:
params["execution_expires_after"] = execution_expires_after
if draft:
params["draft"] = draft
params.update(kwargs)
response = self._client.content_generation.tasks.create(**params)
else:
raise BusinessException(
f"不支持的提供商: {provider}",
code=BizCode.PROVIDER_NOT_SUPPORTED
)
return response.model_dump() if hasattr(response, 'model_dump') else response
async def agenerate(
self,
prompt: str,
image_url: Optional[str] = None,
duration: Optional[int] = None,
**kwargs
) -> Dict[str, Any]:
"""异步生成视频"""
return self.generate(prompt, image_url=image_url, duration=duration, **kwargs)
def get_task_status(self, task_id: str) -> Dict[str, Any]:
"""
查询视频生成任务状态
Args:
task_id: 任务ID
Returns:
任务状态信息
"""
provider = self._config.provider.lower()
if provider == ModelProvider.VOLCANO:
response = self._client.content_generation.tasks.get(task_id=task_id)
return response.model_dump() if hasattr(response, 'model_dump') else response
else:
raise BusinessException(
f"不支持的提供商: {provider}",
code=BizCode.PROVIDER_NOT_SUPPORTED
)
async def aget_task_status(self, task_id: str) -> Dict[str, Any]:
"""异步查询任务状态"""
return self.get_task_status(task_id)
def list_tasks(self, page_size: int = 10, status: Optional[str] = None, **kwargs) -> Dict[str, Any]:
"""
查询视频生成任务列表
Args:
page_size: 每页数量
status: 任务状态筛选,如 "succeeded", "failed", "pending"
**kwargs: 其他参数
Returns:
任务列表
"""
provider = self._config.provider.lower()
if provider == ModelProvider.VOLCANO:
params = {"page_size": page_size}
if status:
params["status"] = status
params.update(kwargs)
response = self._client.content_generation.tasks.list(**params)
return response.model_dump() if hasattr(response, 'model_dump') else response
else:
raise BusinessException(
f"不支持的提供商: {provider}",
code=BizCode.PROVIDER_NOT_SUPPORTED
)
def delete_task(self, task_id: str) -> None:
"""
删除或取消视频生成任务
Args:
task_id: 任务ID
"""
provider = self._config.provider.lower()
if provider == ModelProvider.VOLCANO:
self._client.content_generation.tasks.delete(task_id=task_id)
else:
raise BusinessException(
f"不支持的提供商: {provider}",
code=BizCode.PROVIDER_NOT_SUPPORTED
)

View File

@@ -0,0 +1,334 @@
provider: volcano
models:
# Doubao-Seed 2.0 系列
- name: doubao-seed-2-0-pro-260215
type: chat
provider: volcano
description: 旗舰级全能通用模型,面向 Agent 时代的复杂推理与长链路任务执行场景。强调多模态理解、长上下文推理、结构化生成与工具增强执行。复杂指令与多约束执行能力突出,可稳定应对多步复杂规划、复杂图文推理、视频内容理解与高难度分析等场景。侧重长链路推理能力与复杂任务稳定性,适配真实业务中的复杂场景。
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
logo: volcano
- name: doubao-seed-2-0-lite-260215
type: chat
provider: volcano
description: 面向高频企业场景兼顾性能与成本的均衡型模型综合能力超越上一代Doubao-Seed-1.8。胜任非结构化信息处理、内容创作、搜索推荐、数据分析等生产型工作,支持长上下文、多源信息融合、多步指令执行与高保真结构化输出。在保障稳定效果的同时显著优化成本。兼顾生成质量与响应速度,适合作为通用生产级模型。
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
logo: volcano
- name: doubao-seed-2-0-mini-260215
type: chat
provider: volcano
description: 面向低时延、高并发与成本敏感场景提供极致的模型推理速度。模型效果与Doubao-Seed-1.6相当。支持256k上下文、4档思考长度和多模态理解适合成本和速度优先的轻量级任务。
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
logo: volcano
- name: doubao-seed-2-0-code-preview-260215
type: chat
provider: volcano
description: 面向真实编程环境优化的 Coding 模型,能稳定调用 Claude Code 等常见 IDE 中的工具。模型特别优化了前端能力,在使用常见的前端框架时能有良好表现。模型支持使用 Skills可以配合多种自定义技能使用。Seed 2.0 的编程加强版,更适合 Agentic Coding。
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 代码模型
logo: volcano
# Doubao-Seed 1.x 系列
- name: doubao-seed-1-8-251228
type: chat
provider: volcano
description: Doubao-Seed-1.8 面向多模态 Agent 场景定向优化。Agent 能力上Tool Use、复杂指令遵循等能力均大幅增强。多模态理解方面视觉基础能力显著提升可低帧率理解超长视频视频运动理解、复杂空间理解及文档结构化解析能力也有所优化还原生支持智能上下文管理用户可配置上下文策略。
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
logo: volcano
- name: doubao-seed-1-6-251015
type: chat
provider: volcano
description: Doubao-Seed-1.6全新多模态深度思考模型同时支持minimal/low/medium/high 四种reasoning effort。 更强模型效果,服务复杂任务和有挑战场景。支持 256k 上下文窗口,输出长度支持最大 32k tokens。
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
logo: volcano
- name: doubao-seed-1-6-lite-251015
type: chat
provider: volcano
description: 更高性价比常见任务的最佳选择支持minimal、low、medium、high 四种reasoning_effort思考深度
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
logo: volcano
- name: doubao-seed-1-6-flash-250828
type: chat
provider: volcano
description: Doubao-Seed-1.6-flash推理速度极致的多模态深度思考模型TPOT低至10ms 同时支持文本和视觉理解文本理解能力超过上一代lite视觉理解比肩友商pro系列模型。支持 256k 上下文窗口,输出长度支持最大 16k tokens。
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
logo: volcano
- name: doubao-seed-code-preview-251028
type: chat
provider: volcano
description: 面向Agentic编程任务进行了深度优化。
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 代码模型
logo: volcano
- name: doubao-seed-1-6-vision-250815
type: chat
provider: volcano
description: 全新Doubao-Seed-1.6系列视觉深度思考模型视觉理解能力显著增强并支持image_process视觉工具
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
logo: volcano
# Doubao 1.5 系列
- name: doubao-1-5-vision-pro-32k-250115
type: chat
provider: volcano
description: 全新升级的多模态大模型,支持任意分辨率和极端长宽比图像识别,增强视觉推理、文档识别、细节信息理解和指令遵循能力。支持 32k 上下文窗口,输出长度支持最大 12k tokens。
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- 多模态模型
logo: volcano
- name: doubao-1-5-pro-32k-250115
type: chat
provider: volcano
description: 全新一代主力模型,性能全面升级,在知识、代码、推理等方面表现卓越。最大支持 128k 上下文窗口,输出长度支持最大 12k tokens。
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
logo: volcano
- name: doubao-1-5-lite-32k-250115
type: chat
provider: volcano
description: 全新一代轻量版模型,极致响应速度,效果与时延均达到全球一流水平。支持 32k 上下文窗口,输出长度支持最大 12k tokens。
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
logo: volcano
# Doubao-Seedance 视频生成系列
- name: doubao-seedance-1-5-pro-251215
type: video
provider: volcano
description: 豆包视频生成模型Seedance 1.5 pro 作为全球领先的视频生成模型,可生成音画高精同步的视频内容。支持多人多语言对白,全面覆盖环境音、动作音、合成音、乐器音、背景音及人声,支持首尾帧,实现影视级叙事效果,满足影视、漫剧、电商及广告领域的高阶创作需求。
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 视频生成
logo: volcano
- name: doubao-seedance-1-0-pro-250528
type: video
provider: volcano
description: 一款支持多镜头叙事的视频生成基础模型,在各维度表现出色。它在语义理解与指令遵循能力上取得突破,能生成运动流畅、细节丰富、风格多样且具备影视级美感的 1080P 高清视频
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 视频生成
logo: volcano
- name: doubao-seedance-1-0-pro-fast-251015
type: video
provider: volcano
description: 一款价格触底、效能封顶的全面模型在视频生成质量、速度、价格之间取得了卓越平衡。它继承了Seedance 1.0 pro 核心优势,同时生成速度提升、价格更具竞争力,为创作者带来效率与成本双重优化的体验。
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 视频生成
logo: volcano
- name: doubao-seedance-1-0-lite-i2v-250428
type: video
provider: volcano
description: 基于首帧图片、尾帧图片(可选)、参考图片(可选)和文本提示词(可选)相结合的方式生成视频
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 视频生成
- 图生视频
logo: volcano
- name: doubao-seedance-1-0-lite-t2v-250428
type: video
provider: volcano
description: 基于文本提示词生成视频
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 视频生成
- 文生视频
logo: volcano
# Doubao-Seedream 图像生成系列
- name: doubao-seedream-5-0-260128
type: image
provider: volcano
description: 字节跳动发布的最新图像创作模型。该模型首次搭载联网检索功能,能融合实时网络信息,提升生图时效性。同时,模型的聪明度进一步升级,能够精准解析复杂指令和视觉内容。此外,模型在世界知识广度、参考一致性及专业场景生成质量上均有增强,可更好地满足企业级视觉创作需求。
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 图像生成
logo: volcano
- name: doubao-seedream-4-5-251128
type: image
provider: volcano
description: 字节跳动最新推出的图像多模态模型整合了文生图、图生图、组图输出等能力融合常识和推理能力。相比前代4.0模型生成效果大幅提升,具备更好的编辑一致性和多图融合效果,能更精准的控制画面细节,小字、小人脸生成更自然,图片排版、色彩更和谐,美感提升。
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 图像生成
logo: volcano
- name: doubao-seedream-4-0-250828
type: image
provider: volcano
description: 基于领先架构的SOTA级多模态图像创作模型其生成美感、指令遵循、结构完整度、主体保持一致性处于世界头部水平。模型采用同一套架构实现文生图与编辑能力的统一原生支持文本 、单图和多图输入,并能通过对提示词的深度推理,自动适配最优的图像比例尺寸与生成数量,可一次性连续输出最多 15 张内容关联的图像,支持 4K 超高清输出。
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 图像生成
logo: volcano
- name: doubao-seedream-3-0-t2i-250415
type: image
provider: volcano
description: 一款支持原生高分辨率的中英双语图像生成基础模型综合能力媲美GPT-4o处于世界第一梯队。支持原生 2K 分辨率输出;响应速度更快;小字生成更准确,文本排版效果增强;指令遵循能力强,美感&结构提升,保真度和细节表现较好。
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 图像生成
- 文生图
logo: volcano
# Doubao 翻译系列
- name: doubao-seed-translation-250915
type: chat
provider: volcano
description: 通用多语言翻译模型支持30余种语言互译支持 4K 上下文窗口,输出长度支持最大 3K tokens
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 翻译模型
logo: volcano
# Doubao Embedding 系列
- name: doubao-embedding-vision-251215
type: embedding
provider: volcano
description: 主要面向图文多模向量检索的使用场景,支持图片输入及中、英双语文本输入,最长 128K 上下文长度。
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 向量模型
- 多模态模型
logo: volcano

View File

@@ -61,24 +61,16 @@ class ElasticSearchConfig(BaseModel):
class ElasticSearchVector(BaseVector): class ElasticSearchVector(BaseVector):
def __init__(self, index_name: str, config: ElasticSearchConfig, embedding_config: ModelApiKey, reranker_config: ModelApiKey): def __init__(self, index_name: str, config: ElasticSearchConfig, embedding_config: ModelApiKey, reranker_config: ModelApiKey):
super().__init__(index_name.lower()) super().__init__(index_name.lower())
# self.embeddings = XinferenceEmbeddings(
# server_url=os.getenv("XINFERENCE_URL", "http://127.0.0.1"), # Default Xinference port # 初始化 Embedding 模型(自动支持火山引擎多模态)
# model_uid="bge-m3" # replace model_uid with the model UID return from launching the model
# )
# Remove debug printing to avoid leaking sensitive information
# print("embedding:" + embedding_config.model_name + "|" + embedding_config.provider + "|" + embedding_config.api_key + "|" + embedding_config.api_base)
self.embeddings = RedBearEmbeddings(RedBearModelConfig( self.embeddings = RedBearEmbeddings(RedBearModelConfig(
model_name=embedding_config.model_name, model_name=embedding_config.model_name,
provider=embedding_config.provider, provider=embedding_config.provider,
api_key=embedding_config.api_key, api_key=embedding_config.api_key,
base_url=embedding_config.api_base base_url=embedding_config.api_base
)) ))
# self.reranker = XinferenceRerank( self.is_multimodal_embedding = self.embeddings.is_multimodal_supported()
# server_url=os.getenv("XINFERENCE_URL", "http://127.0.0.1"),
# model_uid="bge-reranker-large"
# )
# Remove debug printing to avoid leaking sensitive information
# print("reranker:"+ reranker_config.model_name + "|" + reranker_config.provider + "|" + reranker_config.api_key + "|" + reranker_config.api_base)
self.reranker = RedBearRerank(RedBearModelConfig( self.reranker = RedBearRerank(RedBearModelConfig(
model_name=reranker_config.model_name, model_name=reranker_config.model_name,
provider=reranker_config.provider, provider=reranker_config.provider,
@@ -144,7 +136,11 @@ class ElasticSearchVector(BaseVector):
def add_chunks(self, chunks: list[DocumentChunk], **kwargs): def add_chunks(self, chunks: list[DocumentChunk], **kwargs):
# 实现 Elasticsearch 保存向量 # 实现 Elasticsearch 保存向量
texts = [chunk.page_content for chunk in chunks] texts = [chunk.page_content for chunk in chunks]
embeddings = self.embeddings.embed_documents(list(texts)) if self.is_multimodal_embedding:
# 火山引擎多模态 Embedding
embeddings = self.embeddings.embed_batch(texts)
else:
embeddings = self.embeddings.embed_documents(list(texts))
self.create(chunks, embeddings, **kwargs) self.create(chunks, embeddings, **kwargs)
def create(self, chunks: list[DocumentChunk], embeddings: list[list[float]], **kwargs): def create(self, chunks: list[DocumentChunk], embeddings: list[list[float]], **kwargs):
@@ -394,7 +390,11 @@ class ElasticSearchVector(BaseVector):
updated count. updated count.
""" """
indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index availableetc "index1,index2,index3" indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index availableetc "index1,index2,index3"
chunk.vector = self.embeddings.embed_query(chunk.page_content) if self.is_multimodal_embedding:
# 火山引擎多模态 Embedding
chunk.vector = self.embeddings.embed_text(chunk.page_content)
else:
chunk.vector = self.embeddings.embed_query(chunk.page_content)
body = { body = {
"script": { "script": {
@@ -454,7 +454,11 @@ class ElasticSearchVector(BaseVector):
def search_by_vector(self, query: str, **kwargs: Any) -> list[DocumentChunk]: def search_by_vector(self, query: str, **kwargs: Any) -> list[DocumentChunk]:
"""Search the nearest neighbors to a vector.""" """Search the nearest neighbors to a vector."""
query_vector = self.embeddings.embed_query(query) if self.is_multimodal_embedding:
# 火山引擎多模态 Embedding
query_vector = self.embeddings.embed_text(query)
else:
query_vector = self.embeddings.embed_query(query)
top_k = kwargs.get("top_k", 1024) top_k = kwargs.get("top_k", 1024)
score_threshold = float(kwargs.get("score_threshold") or 0.3) score_threshold = float(kwargs.get("score_threshold") or 0.3)
indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index availableetc "index1,index2,index3" indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index availableetc "index1,index2,index3"

View File

@@ -109,17 +109,13 @@ class StorageBackend(ABC):
pass pass
@abstractmethod @abstractmethod
async def get_url(self, file_key: str, expires: int = 3600) -> str: async def get_url(
""" self,
Get an access URL for the file. file_key: str,
expires: int = 3600,
Args: file_name: Optional[str] = None
file_key: Unique identifier for the file in the storage system. ) -> str:
expires: URL validity period in seconds (default: 1 hour). """Get an access URL for the file."""
Returns:
URL for accessing the file.
"""
pass pass
async def get_permanent_url(self, file_key: str) -> Optional[str]: async def get_permanent_url(self, file_key: str) -> Optional[str]:

View File

@@ -210,7 +210,12 @@ class LocalStorage(StorageBackend):
cause=e, cause=e,
) )
async def get_url(self, file_key: str, expires: int = 3600) -> str: async def get_url(
self,
file_key: str,
expires: int = 3600,
file_name: Optional[str] = None
) -> str:
""" """
Get an access URL for the file. Get an access URL for the file.
@@ -220,6 +225,7 @@ class LocalStorage(StorageBackend):
Args: Args:
file_key: Unique identifier for the file in the storage system. file_key: Unique identifier for the file in the storage system.
expires: URL validity period in seconds (not used for local storage). expires: URL validity period in seconds (not used for local storage).
file_name: If set, adds Content-Disposition: attachment to force download.
Returns: Returns:
A relative URL path for accessing the file. A relative URL path for accessing the file.

View File

@@ -7,6 +7,7 @@ Storage Service (OSS) using the oss2 SDK.
import io import io
import logging import logging
import urllib.parse
from typing import AsyncIterator, Optional from typing import AsyncIterator, Optional
import oss2 import oss2
@@ -242,24 +243,33 @@ class OSSStorage(StorageBackend):
logger.error(f"Failed to check file existence in OSS {file_key}: {e}") logger.error(f"Failed to check file existence in OSS {file_key}: {e}")
return False return False
async def get_url(self, file_key: str, expires: int = 3600) -> str: async def get_url(
self,
file_key: str,
expires: int = 3600,
file_name: Optional[str] = None,
) -> str:
""" """
Get a presigned URL for accessing the file. Get a presigned URL for accessing the file.
Args: Args:
file_key: Unique identifier for the file in the storage system. file_key: Unique identifier for the file in the storage system.
expires: URL validity period in seconds (default: 1 hour). expires: URL validity period in seconds (default: 1 hour).
file_name: If set, adds Content-Disposition: attachment to force download.
Returns: Returns:
A presigned URL for accessing the file. A presigned URL for accessing the file.
""" """
try: try:
url = self.bucket.sign_url("GET", file_key, expires) params = {}
if file_name:
filename_encoded = urllib.parse.quote(file_name.encode("utf-8"))
params["response-content-disposition"] = f"attachment; filename*=UTF-8''{filename_encoded}"
url = self.bucket.sign_url("GET", file_key, expires, params=params if params else None)
logger.debug(f"Generated presigned URL for {file_key}, expires in {expires}s") logger.debug(f"Generated presigned URL for {file_key}, expires in {expires}s")
return url return url
except Exception as e: except Exception as e:
logger.error(f"Failed to generate presigned URL for {file_key}: {e}") logger.error(f"Failed to generate presigned URL for {file_key}: {e}")
# Return a basic URL format as fallback
return f"https://{self.bucket_name}.{self.endpoint.replace('https://', '').replace('http://', '')}/{file_key}" return f"https://{self.bucket_name}.{self.endpoint.replace('https://', '').replace('http://', '')}/{file_key}"
async def get_permanent_url(self, file_key: str) -> str: async def get_permanent_url(self, file_key: str) -> str:

View File

@@ -6,6 +6,7 @@ using the boto3 SDK.
""" """
import io import io
import urllib.parse
import logging import logging
from typing import AsyncIterator, Optional from typing import AsyncIterator, Optional
@@ -352,31 +353,37 @@ class S3Storage(StorageBackend):
logger.error(f"Failed to check file existence in S3 {file_key}: {e}") logger.error(f"Failed to check file existence in S3 {file_key}: {e}")
return False return False
async def get_url(self, file_key: str, expires: int = 3600) -> str: async def get_url(
self,
file_key: str,
expires: int = 3600,
file_name: Optional[str] = None,
) -> str:
""" """
Get a presigned URL for accessing the file. Get a presigned URL for accessing the file.
Args: Args:
file_key: Unique identifier for the file in the storage system. file_key: Unique identifier for the file in the storage system.
expires: URL validity period in seconds (default: 1 hour). expires: URL validity period in seconds (default: 1 hour).
file_name: If set, adds Content-Disposition: attachment to force download.
Returns: Returns:
A presigned URL for accessing the file. A presigned URL for accessing the file.
""" """
try: try:
params = {"Bucket": self.bucket_name, "Key": file_key}
if file_name:
filename_encoded = urllib.parse.quote(file_name.encode("utf-8"))
params["ResponseContentDisposition"] = f"attachment; filename*=UTF-8''{filename_encoded}"
url = self.client.generate_presigned_url( url = self.client.generate_presigned_url(
"get_object", "get_object",
Params={ Params=params,
"Bucket": self.bucket_name,
"Key": file_key,
},
ExpiresIn=expires, ExpiresIn=expires,
) )
logger.debug(f"Generated presigned URL for {file_key}, expires in {expires}s") logger.debug(f"Generated presigned URL for {file_key}, expires in {expires}s")
return url return url
except Exception as e: except Exception as e:
logger.error(f"Failed to generate presigned URL for {file_key}: {e}") logger.error(f"Failed to generate presigned URL for {file_key}: {e}")
# Return a basic URL format as fallback
return f"https://{self.bucket_name}.s3.{self.region}.amazonaws.com/{file_key}" return f"https://{self.bucket_name}.s3.{self.region}.amazonaws.com/{file_key}"
async def get_permanent_url(self, file_key: str) -> str: async def get_permanent_url(self, file_key: str) -> str:

View File

@@ -27,9 +27,9 @@ class ModelType(StrEnum):
RERANK = "rerank" RERANK = "rerank"
# TTS = "tts" # TTS = "tts"
# SPEECH2TEXT = "speech2text" # SPEECH2TEXT = "speech2text"
# IMAGE = "image" IMAGE = "image"
# AUDIO = "audio" # AUDIO = "audio"
# VISION = "vision" VIDEO = "video"
class ModelProvider(StrEnum): class ModelProvider(StrEnum):
@@ -46,6 +46,7 @@ class ModelProvider(StrEnum):
XINFERENCE = "xinference" XINFERENCE = "xinference"
GPUSTACK = "gpustack" GPUSTACK = "gpustack"
BEDROCK = "bedrock" BEDROCK = "bedrock"
VOLCANO = "volcano"
COMPOSITE = "composite" COMPOSITE = "composite"

View File

@@ -439,7 +439,6 @@ class ModelConfigRepository:
ModelConfig.is_public ModelConfig.is_public
), ),
ModelConfig.provider == provider, ModelConfig.provider == provider,
ModelConfig.is_active,
~ModelConfig.is_composite ~ModelConfig.is_composite
) )
).all() ).all()

View File

@@ -325,27 +325,30 @@ class FileStorageService:
) )
raise raise
async def get_file_url(self, file_key: str, expires: int = 3600) -> str: async def get_file_url(
self,
file_key: str,
expires: int = 3600,
file_name: Optional[str] = None,
) -> str:
""" """
Get an access URL for a file. Get an access URL for a file.
Args: Args:
file_key: The file key. file_key: The file key.
expires: URL validity period in seconds (default: 1 hour). expires: URL validity period in seconds (default: 1 hour).
file_name: If set, adds Content-Disposition: attachment to force download.
Returns: Returns:
URL for accessing the file. URL for accessing the file.
""" """
logger.debug(f"Getting file URL: file_key={file_key}, expires={expires}s") logger.debug(f"Getting file URL: file_key={file_key}, expires={expires}s")
try: try:
url = await self.storage.get_url(file_key, expires) url = await self.storage.get_url(file_key, expires, file_name=file_name)
logger.debug(f"File URL generated: file_key={file_key}") logger.debug(f"File URL generated: file_key={file_key}")
return url return url
except Exception as e: except Exception as e:
logger.error( logger.error(f"Error getting file URL: file_key={file_key}, error={str(e)}")
f"Error getting file URL: file_key={file_key}, error={str(e)}"
)
raise raise

View File

@@ -0,0 +1,162 @@
"""
图片和视频生成服务
提供统一的生成接口,支持多种 Provider
"""
from typing import Dict, Any, Optional
from sqlalchemy.orm import Session
import uuid
from app.core.models import RedBearModelConfig, RedBearImageGenerator, RedBearVideoGenerator
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.models.models_model import ModelType
from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository
from app.services.model_service import ModelApiKeyService
class GenerationService:
"""生成服务"""
def __init__(self, db: Session):
self.db = db
async def generate_image(
self,
model_config_id: str,
prompt: str,
size: Optional[str] = "2k",
**kwargs
) -> Dict[str, Any]:
"""
生成图片
Args:
model_config_id: 模型配置ID
prompt: 提示词
size: 图片尺寸
**kwargs: 其他参数
Returns:
生成结果
"""
# 获取模型配置
model_config = ModelConfigRepository.get_by_id(self.db, uuid.UUID(model_config_id))
if not model_config:
raise BusinessException("模型配置不存在", code=BizCode.NOT_FOUND)
if model_config.type != ModelType.IMAGE:
raise BusinessException(
f"模型类型错误,期望 {ModelType.IMAGE},实际 {model_config.type}",
code=BizCode.INVALID_PARAMETER
)
# 获取 API Key
api_key_info = ModelApiKeyService.get_available_api_key(self.db, uuid.UUID(model_config_id))
if not api_key_info:
raise BusinessException("没有可用的 API Key", code=BizCode.NOT_FOUND)
# 创建配置
config = RedBearModelConfig(
model_name=api_key_info.model_name,
provider=api_key_info.provider,
api_key=api_key_info.api_key,
base_url=api_key_info.api_base,
extra_params=api_key_info.config or {}
)
# 生成图片
generator = RedBearImageGenerator(config)
result = await generator.agenerate(prompt, size, **kwargs)
return result
async def generate_video(
self,
model_config_id: str,
prompt: str,
duration: Optional[int] = None,
**kwargs
) -> Dict[str, Any]:
"""
生成视频
Args:
model_config_id: 模型配置ID
prompt: 提示词
duration: 视频时长(秒)
**kwargs: 其他参数
Returns:
生成结果包含任务ID
"""
# 获取模型配置
model_config = ModelConfigRepository.get_by_id(self.db, uuid.UUID(model_config_id))
if not model_config:
raise BusinessException("模型配置不存在", code=BizCode.NOT_FOUND)
if model_config.type != ModelType.VIDEO:
raise BusinessException(
f"模型类型错误,期望 {ModelType.VIDEO},实际 {model_config.type}",
code=BizCode.INVALID_PARAMETER
)
# 获取 API Key
api_key_info = ModelApiKeyService.get_available_api_key(self.db, uuid.UUID(model_config_id))
if not api_key_info:
raise BusinessException("没有可用的 API Key", code=BizCode.NOT_FOUND)
# 创建配置
config = RedBearModelConfig(
model_name=api_key_info.model_name,
provider=api_key_info.provider,
api_key=api_key_info.api_key,
base_url=api_key_info.api_base,
extra_params=api_key_info.config or {}
)
# 生成视频
generator = RedBearVideoGenerator(config)
result = await generator.agenerate(prompt, duration, **kwargs)
return result
async def get_video_task_status(
self,
model_config_id: str,
task_id: str
) -> Dict[str, Any]:
"""
查询视频生成任务状态
Args:
model_config_id: 模型配置ID
task_id: 任务ID
Returns:
任务状态信息
"""
# 获取模型配置
model_config = ModelConfigRepository.get_by_id(self.db, uuid.UUID(model_config_id))
if not model_config:
raise BusinessException("模型配置不存在", code=BizCode.NOT_FOUND)
# 获取 API Key
api_key_info = ModelApiKeyService.get_available_api_key(self.db, uuid.UUID(model_config_id))
if not api_key_info:
raise BusinessException("没有可用的 API Key", code=BizCode.NOT_FOUND)
# 创建配置
config = RedBearModelConfig(
model_name=api_key_info.model_name,
provider=api_key_info.provider,
api_key=api_key_info.api_key,
base_url=api_key_info.api_base,
extra_params=api_key_info.config or {}
)
# 查询任务状态
generator = RedBearVideoGenerator(config)
result = await generator.aget_task_status(task_id)
return result

View File

@@ -154,10 +154,17 @@ class ModelConfigService:
} }
elif model_type_lower == "embedding": elif model_type_lower == "embedding":
# Embedding 模型验证(在线程中运行同步方法) # Embedding 模型验证
# 统一使用 RedBearEmbeddings自动支持火山引擎多模态
embedding = RedBearEmbeddings(model_config) embedding = RedBearEmbeddings(model_config)
test_texts = [test_message, "测试文本"] test_texts = [test_message, "测试文本"]
vectors = await asyncio.to_thread(embedding.embed_documents, test_texts)
# 火山引擎使用 embed_batch其他使用 embed_documents
if provider.lower() == "volcano":
vectors = await asyncio.to_thread(embedding.embed_batch, test_texts)
else:
vectors = await asyncio.to_thread(embedding.embed_documents, test_texts)
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
return { return {
@@ -193,6 +200,56 @@ class ModelConfigService:
}, },
"error": None "error": None
} }
elif model_type_lower == "image":
# 图片生成模型验证
from app.core.models.generation import RedBearImageGenerator
generator = RedBearImageGenerator(model_config)
result = await generator.agenerate(
prompt="a cute panda",
size="2K"
)
elapsed_time = time.time() - start_time
logger.info(f"成功生成图片,结果: {result}")
return {
"valid": True,
"message": "图片生成模型配置验证成功",
"response": f"成功生成图片,结果: {result}",
"elapsed_time": elapsed_time,
"usage": {
"prompt_length": len("a cute panda"),
"image_count": 1
},
"error": None
}
elif model_type_lower == "video":
# 视频生成模型验证
from app.core.models.generation import RedBearVideoGenerator
generator = RedBearVideoGenerator(model_config)
result = await generator.agenerate(
prompt="a cute panda playing in bamboo forest",
duration=5
)
elapsed_time = time.time() - start_time
# 视频生成是异步任务返回任务ID
task_id = result.get("task_id") if isinstance(result, dict) else None
return {
"valid": True,
"message": "视频生成模型配置验证成功",
"response": f"成功创建视频生成任务任务ID: {task_id}",
"elapsed_time": elapsed_time,
"usage": {
"prompt_length": len("a cute panda playing in bamboo forest"),
"task_id": task_id
},
"error": None
}
else: else:
return { return {

View File

@@ -294,6 +294,7 @@ PROVIDER_STRATEGIES = {
"bedrock": BedrockFormatStrategy, "bedrock": BedrockFormatStrategy,
"anthropic": BedrockFormatStrategy, "anthropic": BedrockFormatStrategy,
"openai": OpenAIFormatStrategy, "openai": OpenAIFormatStrategy,
"volcano": OpenAIFormatStrategy,
} }

View File

@@ -147,6 +147,7 @@ dependencies = [
"modelscope>=1.34.0", "modelscope>=1.34.0",
"python-magic>=0.4.14; sys_platform == 'linux' or sys_platform == 'darwin'", "python-magic>=0.4.14; sys_platform == 'linux' or sys_platform == 'darwin'",
"python-magic-bin>=0.4.14; sys_platform=='win32'", "python-magic-bin>=0.4.14; sys_platform=='win32'",
"volcengine-python-sdk[ark]==5.0.19"
] ]
[tool.pytest.ini_options] [tool.pytest.ini_options]