[modify] multi agent model parameter
This commit is contained in:
@@ -3,11 +3,14 @@ import datetime
|
|||||||
import uuid
|
import uuid
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
|
|
||||||
from sqlalchemy import Column, String, Boolean, DateTime, Integer, Float, Text, ForeignKey
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import Column, String, Boolean, DateTime, Integer, Float, Text, ForeignKey, TypeDecorator
|
||||||
from sqlalchemy.dialects.postgresql import UUID, JSON
|
from sqlalchemy.dialects.postgresql import UUID, JSON
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
from app.db import Base
|
from app.db import Base
|
||||||
|
from app.schemas import ModelParameters
|
||||||
|
|
||||||
|
|
||||||
class OrchestrationMode(StrEnum):
|
class OrchestrationMode(StrEnum):
|
||||||
"""图标类型枚举"""
|
"""图标类型枚举"""
|
||||||
@@ -21,6 +24,28 @@ class AggregationStrategy(StrEnum):
|
|||||||
VOTE = "vote"
|
VOTE = "vote"
|
||||||
PRIORITY = "priority"
|
PRIORITY = "priority"
|
||||||
|
|
||||||
|
class PydanticType(TypeDecorator):
|
||||||
|
impl = JSON
|
||||||
|
|
||||||
|
def __init__(self, pydantic_model: type[BaseModel]):
|
||||||
|
super().__init__()
|
||||||
|
self.model = pydantic_model
|
||||||
|
|
||||||
|
def process_bind_param(self, value, dialect):
|
||||||
|
# 入库:Model -> dict
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
if isinstance(value, self.model):
|
||||||
|
return value.dict()
|
||||||
|
return value # 已经是 dict 也放行
|
||||||
|
|
||||||
|
def process_result_value(self, value, dialect):
|
||||||
|
# 出库:dict -> Model
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
# return self.model.parse_obj(value) # pydantic v1
|
||||||
|
return self.model.model_validate(value) # pydantic v2
|
||||||
|
|
||||||
class MultiAgentConfig(Base):
|
class MultiAgentConfig(Base):
|
||||||
"""多 Agent 配置表"""
|
"""多 Agent 配置表"""
|
||||||
__tablename__ = "multi_agent_configs"
|
__tablename__ = "multi_agent_configs"
|
||||||
@@ -36,7 +61,7 @@ class MultiAgentConfig(Base):
|
|||||||
|
|
||||||
default_model_config_id = Column(UUID(as_uuid=True), ForeignKey("model_configs.id", name="multi_agent_configs_default_model_config_id_fkey"), nullable=True, index=True, comment="默认模型配置ID")
|
default_model_config_id = Column(UUID(as_uuid=True), ForeignKey("model_configs.id", name="multi_agent_configs_default_model_config_id_fkey"), nullable=True, index=True, comment="默认模型配置ID")
|
||||||
# 结构化配置(直接存储 JSON)
|
# 结构化配置(直接存储 JSON)
|
||||||
model_parameters = Column(JSON, nullable=True, comment="模型参数配置(temperature、max_tokens等)")
|
model_parameters = Column(PydanticType(ModelParameters), nullable=True, comment="模型参数配置(temperature、max_tokens等)")
|
||||||
# 协作模式
|
# 协作模式
|
||||||
orchestration_mode = Column(
|
orchestration_mode = Column(
|
||||||
String(20),
|
String(20),
|
||||||
|
|||||||
@@ -366,10 +366,26 @@ class MasterAgentRouter:
|
|||||||
# if self.model_parameters:
|
# if self.model_parameters:
|
||||||
# temperature = self.model_parameters["temperature"]
|
# temperature = self.model_parameters["temperature"]
|
||||||
# max_tokens = self.model_parameters["max_tokens"]
|
# max_tokens = self.model_parameters["max_tokens"]
|
||||||
|
if self.model_parameters:
|
||||||
|
if hasattr(self.model_parameters, 'temperature'):
|
||||||
|
# Pydantic 模型
|
||||||
|
temperature = self.model_parameters.temperature
|
||||||
|
max_tokens = getattr(self.model_parameters, 'max_tokens', 1000)
|
||||||
|
elif isinstance(self.model_parameters, dict):
|
||||||
|
# 字典
|
||||||
|
temperature = self.model_parameters.get("temperature", 0.3)
|
||||||
|
max_tokens = self.model_parameters.get("max_tokens", 1000)
|
||||||
|
else:
|
||||||
|
temperature = 0.3
|
||||||
|
max_tokens = 1000
|
||||||
|
else:
|
||||||
|
temperature = 0.3
|
||||||
|
max_tokens = 1000
|
||||||
|
# extra_params = {"temperature": self.model_parameters.get("temperature", 0.3),
|
||||||
|
# "max_tokens":self.model_parameters.get("max_tokens", 1000)
|
||||||
|
# }
|
||||||
|
extra_params = {"temperature": temperature, "max_tokens": max_tokens}
|
||||||
|
|
||||||
extra_params = {"temperature": self.model_parameters.get("temperature", 0.3),
|
|
||||||
"max_tokens":self.model_parameters.get("max_tokens", 1000)
|
|
||||||
}
|
|
||||||
# 创建 RedBearModelConfig
|
# 创建 RedBearModelConfig
|
||||||
model_config = RedBearModelConfig(
|
model_config = RedBearModelConfig(
|
||||||
model_name=api_key_config.model_name,
|
model_name=api_key_config.model_name,
|
||||||
|
|||||||
@@ -293,12 +293,13 @@ class MultiAgentService:
|
|||||||
# 处理 model_parameters(可能是 None、字典或 Pydantic 模型)
|
# 处理 model_parameters(可能是 None、字典或 Pydantic 模型)
|
||||||
if data.model_parameters is None:
|
if data.model_parameters is None:
|
||||||
model_parameters_data = None
|
model_parameters_data = None
|
||||||
elif isinstance(data.model_parameters, dict):
|
# elif isinstance(data.model_parameters, dict):
|
||||||
# 过滤掉值为 None 的字段
|
# # 过滤掉值为 None 的字段
|
||||||
model_parameters_data = {k: v for k, v in data.model_parameters.items() if v is not None}
|
# model_parameters_data = {k: v for k, v in data.model_parameters.items() if v is not None}
|
||||||
else:
|
else:
|
||||||
# 过滤掉值为 None 的字段
|
# 过滤掉值为 None 的字段
|
||||||
model_parameters_data = {k: v for k, v in data.model_parameters.model_dump().items() if v is not None}
|
# model_parameters_data = {k: v for k, v in data.model_parameters.model_dump().items() if v is not None}
|
||||||
|
model_parameters_data = data.model_parameters
|
||||||
|
|
||||||
config = MultiAgentConfig(
|
config = MultiAgentConfig(
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
|
|||||||
Reference in New Issue
Block a user