Merge branch 'refs/heads/feature/agent-tool_xjn' into feature/20260105_xjn

This commit is contained in:
Timebomb2018
2026-03-25 11:48:42 +08:00
13 changed files with 1122 additions and 33 deletions

View File

@@ -0,0 +1,164 @@
"""
图片和视频生成服务
提供统一的生成接口,支持多种 Provider
"""
from typing import Dict, Any, Optional
from sqlalchemy.orm import Session
import uuid
from app.core.models import RedBearModelConfig, RedBearImageGenerator, RedBearVideoGenerator
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.models.models_model import ModelType
from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository
from app.services.model_service import ModelApiKeyService
class GenerationService:
"""生成服务"""
def __init__(self, db: Session):
self.db = db
async def generate_image(
self,
model_config_id: str,
prompt: str,
size: Optional[str] = "1024x1024",
n: int = 1,
**kwargs
) -> Dict[str, Any]:
"""
生成图片
Args:
model_config_id: 模型配置ID
prompt: 提示词
size: 图片尺寸
n: 生成数量
**kwargs: 其他参数
Returns:
生成结果
"""
# 获取模型配置
model_config = ModelConfigRepository.get_by_id(self.db, uuid.UUID(model_config_id))
if not model_config:
raise BusinessException("模型配置不存在", code=BizCode.NOT_FOUND)
if model_config.type != ModelType.IMAGE:
raise BusinessException(
f"模型类型错误,期望 {ModelType.IMAGE},实际 {model_config.type}",
code=BizCode.INVALID_PARAMETER
)
# 获取 API Key
api_key_info = ModelApiKeyService.get_available_api_key(self.db, uuid.UUID(model_config_id))
if not api_key_info:
raise BusinessException("没有可用的 API Key", code=BizCode.NOT_FOUND)
# 创建配置
config = RedBearModelConfig(
model_name=api_key_info.model_name,
provider=api_key_info.provider,
api_key=api_key_info.api_key,
base_url=api_key_info.api_base,
extra_params=api_key_info.config or {}
)
# 生成图片
generator = RedBearImageGenerator(config)
result = await generator.agenerate(prompt, size, n, **kwargs)
return result
async def generate_video(
self,
model_config_id: str,
prompt: str,
duration: Optional[int] = None,
**kwargs
) -> Dict[str, Any]:
"""
生成视频
Args:
model_config_id: 模型配置ID
prompt: 提示词
duration: 视频时长(秒)
**kwargs: 其他参数
Returns:
生成结果包含任务ID
"""
# 获取模型配置
model_config = ModelConfigRepository.get_by_id(self.db, uuid.UUID(model_config_id))
if not model_config:
raise BusinessException("模型配置不存在", code=BizCode.NOT_FOUND)
if model_config.type != ModelType.VIDEO:
raise BusinessException(
f"模型类型错误,期望 {ModelType.VIDEO},实际 {model_config.type}",
code=BizCode.INVALID_PARAMETER
)
# 获取 API Key
api_key_info = ModelApiKeyService.get_available_api_key(self.db, uuid.UUID(model_config_id))
if not api_key_info:
raise BusinessException("没有可用的 API Key", code=BizCode.NOT_FOUND)
# 创建配置
config = RedBearModelConfig(
model_name=api_key_info.model_name,
provider=api_key_info.provider,
api_key=api_key_info.api_key,
base_url=api_key_info.api_base,
extra_params=api_key_info.config or {}
)
# 生成视频
generator = RedBearVideoGenerator(config)
result = await generator.agenerate(prompt, duration, **kwargs)
return result
async def get_video_task_status(
self,
model_config_id: str,
task_id: str
) -> Dict[str, Any]:
"""
查询视频生成任务状态
Args:
model_config_id: 模型配置ID
task_id: 任务ID
Returns:
任务状态信息
"""
# 获取模型配置
model_config = ModelConfigRepository.get_by_id(self.db, uuid.UUID(model_config_id))
if not model_config:
raise BusinessException("模型配置不存在", code=BizCode.NOT_FOUND)
# 获取 API Key
api_key_info = ModelApiKeyService.get_available_api_key(self.db, uuid.UUID(model_config_id))
if not api_key_info:
raise BusinessException("没有可用的 API Key", code=BizCode.NOT_FOUND)
# 创建配置
config = RedBearModelConfig(
model_name=api_key_info.model_name,
provider=api_key_info.provider,
api_key=api_key_info.api_key,
base_url=api_key_info.api_base,
extra_params=api_key_info.config or {}
)
# 查询任务状态
generator = RedBearVideoGenerator(config)
result = await generator.aget_task_status(task_id)
return result

View File

@@ -154,10 +154,17 @@ class ModelConfigService:
}
elif model_type_lower == "embedding":
# Embedding 模型验证(在线程中运行同步方法)
# Embedding 模型验证
# 统一使用 RedBearEmbeddings自动支持火山引擎多模态
embedding = RedBearEmbeddings(model_config)
test_texts = [test_message, "测试文本"]
vectors = await asyncio.to_thread(embedding.embed_documents, test_texts)
# 火山引擎使用 embed_batch其他使用 embed_documents
if provider.lower() == "volcano":
vectors = await asyncio.to_thread(embedding.embed_batch, test_texts)
else:
vectors = await asyncio.to_thread(embedding.embed_documents, test_texts)
elapsed_time = time.time() - start_time
return {
@@ -193,6 +200,56 @@ class ModelConfigService:
},
"error": None
}
elif model_type_lower == "image":
# 图片生成模型验证
from app.core.models.generation import RedBearImageGenerator
generator = RedBearImageGenerator(model_config)
result = await generator.agenerate(
prompt="a cute panda",
size="2K"
)
elapsed_time = time.time() - start_time
logger.info(f"成功生成图片,结果: {result}")
return {
"valid": True,
"message": "图片生成模型配置验证成功",
"response": f"成功生成图片,结果: {result}",
"elapsed_time": elapsed_time,
"usage": {
"prompt_length": len("a cute panda"),
"image_count": 1
},
"error": None
}
elif model_type_lower == "video":
# 视频生成模型验证
from app.core.models.generation import RedBearVideoGenerator
generator = RedBearVideoGenerator(model_config)
result = await generator.agenerate(
prompt="a cute panda playing in bamboo forest",
duration=5
)
elapsed_time = time.time() - start_time
# 视频生成是异步任务返回任务ID
task_id = result.get("task_id") if isinstance(result, dict) else None
return {
"valid": True,
"message": "视频生成模型配置验证成功",
"response": f"成功创建视频生成任务任务ID: {task_id}",
"elapsed_time": elapsed_time,
"usage": {
"prompt_length": len("a cute panda playing in bamboo forest"),
"task_id": task_id
},
"error": None
}
else:
return {

View File

@@ -294,6 +294,7 @@ PROVIDER_STRATEGIES = {
"bedrock": BedrockFormatStrategy,
"anthropic": BedrockFormatStrategy,
"openai": OpenAIFormatStrategy,
"volcano": OpenAIFormatStrategy,
}