[modify] multi agent model parameter
This commit is contained in:
@@ -3,11 +3,14 @@ import datetime
|
||||
import uuid
|
||||
from enum import StrEnum
|
||||
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, Integer, Float, Text, ForeignKey
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, Integer, Float, Text, ForeignKey, TypeDecorator
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSON
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.db import Base
|
||||
from app.schemas import ModelParameters
|
||||
|
||||
|
||||
class OrchestrationMode(StrEnum):
|
||||
"""图标类型枚举"""
|
||||
@@ -21,6 +24,28 @@ class AggregationStrategy(StrEnum):
|
||||
VOTE = "vote"
|
||||
PRIORITY = "priority"
|
||||
|
||||
class PydanticType(TypeDecorator):
|
||||
impl = JSON
|
||||
|
||||
def __init__(self, pydantic_model: type[BaseModel]):
|
||||
super().__init__()
|
||||
self.model = pydantic_model
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
# 入库:Model -> dict
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, self.model):
|
||||
return value.dict()
|
||||
return value # 已经是 dict 也放行
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
# 出库:dict -> Model
|
||||
if value is None:
|
||||
return None
|
||||
# return self.model.parse_obj(value) # pydantic v1
|
||||
return self.model.model_validate(value) # pydantic v2
|
||||
|
||||
class MultiAgentConfig(Base):
|
||||
"""多 Agent 配置表"""
|
||||
__tablename__ = "multi_agent_configs"
|
||||
@@ -36,7 +61,7 @@ class MultiAgentConfig(Base):
|
||||
|
||||
default_model_config_id = Column(UUID(as_uuid=True), ForeignKey("model_configs.id", name="multi_agent_configs_default_model_config_id_fkey"), nullable=True, index=True, comment="默认模型配置ID")
|
||||
# 结构化配置(直接存储 JSON)
|
||||
model_parameters = Column(JSON, nullable=True, comment="模型参数配置(temperature、max_tokens等)")
|
||||
model_parameters = Column(PydanticType(ModelParameters), nullable=True, comment="模型参数配置(temperature、max_tokens等)")
|
||||
# 协作模式
|
||||
orchestration_mode = Column(
|
||||
String(20),
|
||||
|
||||
@@ -366,10 +366,26 @@ class MasterAgentRouter:
|
||||
# if self.model_parameters:
|
||||
# temperature = self.model_parameters["temperature"]
|
||||
# max_tokens = self.model_parameters["max_tokens"]
|
||||
|
||||
extra_params = {"temperature": self.model_parameters.get("temperature", 0.3),
|
||||
"max_tokens":self.model_parameters.get("max_tokens", 1000)
|
||||
}
|
||||
if self.model_parameters:
|
||||
if hasattr(self.model_parameters, 'temperature'):
|
||||
# Pydantic 模型
|
||||
temperature = self.model_parameters.temperature
|
||||
max_tokens = getattr(self.model_parameters, 'max_tokens', 1000)
|
||||
elif isinstance(self.model_parameters, dict):
|
||||
# 字典
|
||||
temperature = self.model_parameters.get("temperature", 0.3)
|
||||
max_tokens = self.model_parameters.get("max_tokens", 1000)
|
||||
else:
|
||||
temperature = 0.3
|
||||
max_tokens = 1000
|
||||
else:
|
||||
temperature = 0.3
|
||||
max_tokens = 1000
|
||||
# extra_params = {"temperature": self.model_parameters.get("temperature", 0.3),
|
||||
# "max_tokens":self.model_parameters.get("max_tokens", 1000)
|
||||
# }
|
||||
extra_params = {"temperature": temperature, "max_tokens": max_tokens}
|
||||
|
||||
# 创建 RedBearModelConfig
|
||||
model_config = RedBearModelConfig(
|
||||
model_name=api_key_config.model_name,
|
||||
|
||||
@@ -293,12 +293,13 @@ class MultiAgentService:
|
||||
# 处理 model_parameters(可能是 None、字典或 Pydantic 模型)
|
||||
if data.model_parameters is None:
|
||||
model_parameters_data = None
|
||||
elif isinstance(data.model_parameters, dict):
|
||||
# 过滤掉值为 None 的字段
|
||||
model_parameters_data = {k: v for k, v in data.model_parameters.items() if v is not None}
|
||||
# elif isinstance(data.model_parameters, dict):
|
||||
# # 过滤掉值为 None 的字段
|
||||
# model_parameters_data = {k: v for k, v in data.model_parameters.items() if v is not None}
|
||||
else:
|
||||
# 过滤掉值为 None 的字段
|
||||
model_parameters_data = {k: v for k, v in data.model_parameters.model_dump().items() if v is not None}
|
||||
# model_parameters_data = {k: v for k, v in data.model_parameters.model_dump().items() if v is not None}
|
||||
model_parameters_data = data.model_parameters
|
||||
|
||||
config = MultiAgentConfig(
|
||||
app_id=app_id,
|
||||
|
||||
Reference in New Issue
Block a user