feat(agent): add input variable validation
This commit is contained in:
@@ -21,6 +21,7 @@ from pydantic import BaseModel, Field
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class RedBearModelConfig(BaseModel):
|
||||
"""模型配置基类"""
|
||||
model_name: str
|
||||
@@ -32,17 +33,18 @@ class RedBearModelConfig(BaseModel):
|
||||
timeout: float = Field(default_factory=lambda: float(os.getenv("LLM_TIMEOUT", "120.0")))
|
||||
# 最大重试次数 - 默认2次以避免过长等待,可通过环境变量 LLM_MAX_RETRIES 配置
|
||||
max_retries: int = Field(default_factory=lambda: int(os.getenv("LLM_MAX_RETRIES", "2")))
|
||||
concurrency: int = 5 # 并发限流
|
||||
concurrency: int = 5 # 并发限流
|
||||
extra_params: Dict[str, Any] = {}
|
||||
|
||||
|
||||
class RedBearModelFactory:
|
||||
"""模型工厂类"""
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_model_params(cls, config: RedBearModelConfig) -> Dict[str, Any]:
|
||||
"""根据提供商获取模型参数"""
|
||||
provider = config.provider.lower()
|
||||
|
||||
|
||||
# 打印供应商信息用于调试
|
||||
from app.core.logging_config import get_business_logger
|
||||
logger = get_business_logger()
|
||||
@@ -87,7 +89,7 @@ class RedBearModelFactory:
|
||||
"timeout": timeout_config,
|
||||
"max_retries": config.max_retries,
|
||||
**config.extra_params
|
||||
}
|
||||
}
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
# DashScope (通义千问) 使用自己的参数格式
|
||||
# 注意: DashScopeEmbeddings 不支持 timeout 和 base_url 参数
|
||||
@@ -104,7 +106,7 @@ class RedBearModelFactory:
|
||||
# region 从 base_url 或 extra_params 获取
|
||||
from botocore.config import Config as BotoConfig
|
||||
from app.core.models.bedrock_model_mapper import normalize_bedrock_model_id
|
||||
|
||||
|
||||
max_pool_connections = int(os.getenv("BEDROCK_MAX_POOL_CONNECTIONS", "50"))
|
||||
max_retries = int(os.getenv("BEDROCK_MAX_RETRIES", "2"))
|
||||
# Configure with increased connection pool
|
||||
@@ -112,16 +114,16 @@ class RedBearModelFactory:
|
||||
max_pool_connections=max_pool_connections,
|
||||
retries={'max_attempts': max_retries, 'mode': 'adaptive'}
|
||||
)
|
||||
|
||||
|
||||
# 标准化模型 ID(自动转换简化名称为完整 Bedrock Model ID)
|
||||
model_id = normalize_bedrock_model_id(config.model_name)
|
||||
|
||||
|
||||
params = {
|
||||
"model_id": model_id,
|
||||
"config": boto_config,
|
||||
**config.extra_params
|
||||
}
|
||||
|
||||
|
||||
# 解析 API key (格式: access_key_id:secret_access_key)
|
||||
if config.api_key and ":" in config.api_key:
|
||||
access_key_id, secret_access_key = config.api_key.split(":", 1)
|
||||
@@ -129,51 +131,52 @@ class RedBearModelFactory:
|
||||
params["aws_secret_access_key"] = secret_access_key
|
||||
elif config.api_key:
|
||||
params["aws_access_key_id"] = config.api_key
|
||||
|
||||
|
||||
# 设置 region
|
||||
if config.base_url:
|
||||
params["region_name"] = config.base_url
|
||||
elif "region_name" not in params:
|
||||
params["region_name"] = "us-east-1" # 默认区域
|
||||
|
||||
|
||||
return params
|
||||
else:
|
||||
raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_rerank_model_params(cls, config: RedBearModelConfig) -> Dict[str, Any]:
|
||||
"""根据提供商获取模型参数"""
|
||||
provider = config.provider.lower()
|
||||
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||
return {
|
||||
return {
|
||||
"model": config.model_name,
|
||||
# "base_url": config.base_url,
|
||||
"jina_api_key": config.api_key,
|
||||
**config.extra_params
|
||||
}
|
||||
}
|
||||
else:
|
||||
raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
|
||||
def get_provider_llm_class(config:RedBearModelConfig, type: ModelType=ModelType.LLM) -> type[BaseLLM]:
|
||||
|
||||
def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelType.LLM) -> type[BaseLLM]:
|
||||
"""根据模型提供商获取对应的模型类"""
|
||||
provider = config.provider.lower()
|
||||
|
||||
|
||||
# dashscope 的 omni 模型使用 OpenAI 兼容模式
|
||||
if provider == ModelProvider.DASHSCOPE and config.is_omni:
|
||||
from langchain_openai import ChatOpenAI
|
||||
return ChatOpenAI
|
||||
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
|
||||
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
|
||||
if type == ModelType.LLM:
|
||||
from langchain_openai import OpenAI
|
||||
return OpenAI
|
||||
return OpenAI
|
||||
elif type == ModelType.CHAT:
|
||||
from langchain_openai import ChatOpenAI
|
||||
return ChatOpenAI
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
from langchain_community.chat_models import ChatTongyi
|
||||
return ChatTongyi
|
||||
elif provider == ModelProvider.OLLAMA:
|
||||
elif provider == ModelProvider.OLLAMA:
|
||||
from langchain_ollama import OllamaLLM
|
||||
return OllamaLLM
|
||||
elif provider == ModelProvider.BEDROCK:
|
||||
@@ -183,15 +186,16 @@ def get_provider_llm_class(config:RedBearModelConfig, type: ModelType=ModelType.
|
||||
else:
|
||||
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
|
||||
|
||||
def get_provider_embedding_class(provider: str) -> type[Embeddings]:
|
||||
"""根据模型提供商获取对应的模型类"""
|
||||
provider = provider.lower()
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
return OpenAIEmbeddings
|
||||
return OpenAIEmbeddings
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
from langchain_community.embeddings import DashScopeEmbeddings
|
||||
return DashScopeEmbeddings
|
||||
return DashScopeEmbeddings
|
||||
elif provider == ModelProvider.OLLAMA:
|
||||
from langchain_ollama import OllamaEmbeddings
|
||||
return OllamaEmbeddings
|
||||
@@ -201,14 +205,15 @@ def get_provider_embedding_class(provider: str) -> type[Embeddings]:
|
||||
else:
|
||||
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
|
||||
|
||||
def get_provider_rerank_class(provider: str):
|
||||
"""根据模型提供商获取对应的模型类"""
|
||||
provider = provider.lower()
|
||||
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
|
||||
provider = provider.lower()
|
||||
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||
from langchain_community.document_compressors import JinaRerank
|
||||
return JinaRerank
|
||||
# elif provider == ModelProvider.OLLAMA:
|
||||
return JinaRerank
|
||||
# elif provider == ModelProvider.OLLAMA:
|
||||
# from langchain_ollama import OllamaEmbeddings
|
||||
# return OllamaEmbeddings
|
||||
else:
|
||||
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
|
||||
@@ -16,7 +16,7 @@ from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.db import get_db
|
||||
from app.models import AppRelease
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
from app.services.draft_run_service import AgentRunService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -39,7 +39,7 @@ class AgentNode(BaseNode):
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {"output": VariableType.STRING}
|
||||
|
||||
def _prepare_agent(self, variable_pool: VariablePool) -> tuple[DraftRunService, AppRelease, str]:
|
||||
def _prepare_agent(self, variable_pool: VariablePool) -> tuple[AgentRunService, AppRelease, str]:
|
||||
"""准备 Agent(公共逻辑)
|
||||
|
||||
Args:
|
||||
@@ -65,7 +65,7 @@ class AgentNode(BaseNode):
|
||||
if not release:
|
||||
raise ValueError(f"Agent 不存在: {agent_id}")
|
||||
|
||||
draft_service = DraftRunService(db)
|
||||
draft_service = AgentRunService(db)
|
||||
|
||||
return draft_service, release, message
|
||||
|
||||
|
||||
Reference in New Issue
Block a user