feat(agent, memory): add agent-perceived memory writing
This commit is contained in:
@@ -69,7 +69,8 @@ class ModelConfigService:
|
||||
return items
|
||||
|
||||
@staticmethod
|
||||
def get_model_by_name(db: Session, name: str, provider: str | None = None, tenant_id: uuid.UUID | None = None) -> ModelConfig:
|
||||
def get_model_by_name(db: Session, name: str, provider: str | None = None,
|
||||
tenant_id: uuid.UUID | None = None) -> ModelConfig:
|
||||
"""根据名称获取模型配置"""
|
||||
model = ModelConfigRepository.get_by_name(db, name, provider=provider, tenant_id=tenant_id)
|
||||
if not model:
|
||||
@@ -77,21 +78,22 @@ class ModelConfigService:
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def search_models_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None, limit: int = 10) -> List[ModelConfig]:
|
||||
def search_models_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None, limit: int = 10) -> List[
|
||||
ModelConfig]:
|
||||
"""按名称模糊匹配获取模型配置列表"""
|
||||
return ModelConfigRepository.search_by_name(db, name, tenant_id=tenant_id, limit=limit)
|
||||
|
||||
@staticmethod
|
||||
async def validate_model_config(
|
||||
db: Session,
|
||||
*,
|
||||
model_name: str,
|
||||
provider: str,
|
||||
api_key: str,
|
||||
api_base: Optional[str] = None,
|
||||
model_type: str = "llm",
|
||||
test_message: str = "Hello",
|
||||
is_omni: bool = False
|
||||
db: Session,
|
||||
*,
|
||||
model_name: str,
|
||||
provider: str,
|
||||
api_key: str,
|
||||
api_base: Optional[str] = None,
|
||||
model_type: str = "llm",
|
||||
test_message: str = "Hello",
|
||||
is_omni: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""验证模型配置是否有效
|
||||
|
||||
@@ -158,13 +160,13 @@ class ModelConfigService:
|
||||
# 统一使用 RedBearEmbeddings(自动支持火山引擎多模态)
|
||||
embedding = RedBearEmbeddings(model_config)
|
||||
test_texts = [test_message, "测试文本"]
|
||||
|
||||
|
||||
# 火山引擎使用 embed_batch,其他使用 embed_documents
|
||||
if provider.lower() == "volcano":
|
||||
vectors = await asyncio.to_thread(embedding.embed_batch, test_texts)
|
||||
else:
|
||||
vectors = await asyncio.to_thread(embedding.embed_documents, test_texts)
|
||||
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
return {
|
||||
@@ -200,11 +202,11 @@ class ModelConfigService:
|
||||
},
|
||||
"error": None
|
||||
}
|
||||
|
||||
|
||||
elif model_type_lower == "image":
|
||||
# 图片生成模型验证
|
||||
from app.core.models.generation import RedBearImageGenerator
|
||||
|
||||
|
||||
generator = RedBearImageGenerator(model_config)
|
||||
result = await generator.agenerate(
|
||||
prompt="a cute panda",
|
||||
@@ -212,7 +214,7 @@ class ModelConfigService:
|
||||
)
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.info(f"成功生成图片,结果: {result}")
|
||||
|
||||
|
||||
return {
|
||||
"valid": True,
|
||||
"message": "图片生成模型配置验证成功",
|
||||
@@ -224,21 +226,21 @@ class ModelConfigService:
|
||||
},
|
||||
"error": None
|
||||
}
|
||||
|
||||
|
||||
elif model_type_lower == "video":
|
||||
# 视频生成模型验证
|
||||
from app.core.models.generation import RedBearVideoGenerator
|
||||
|
||||
|
||||
generator = RedBearVideoGenerator(model_config)
|
||||
result = await generator.agenerate(
|
||||
prompt="a cute panda playing in bamboo forest",
|
||||
duration=5
|
||||
)
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
|
||||
# 视频生成是异步任务,返回任务ID
|
||||
task_id = result.get("task_id") if isinstance(result, dict) else None
|
||||
|
||||
|
||||
return {
|
||||
"valid": True,
|
||||
"message": "视频生成模型配置验证成功",
|
||||
@@ -265,7 +267,6 @@ class ModelConfigService:
|
||||
# 提取详细的错误信息
|
||||
error_message = str(e)
|
||||
error_type = type(e).__name__
|
||||
print("=========error_message:",error_message.lower())
|
||||
# 特殊处理常见的错误类型
|
||||
if "unsupported countries" in error_message.lower() or "unsupported region" in error_message.lower():
|
||||
# 区域/国家限制(适用于所有提供商)
|
||||
@@ -354,14 +355,16 @@ class ModelConfigService:
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate, tenant_id: uuid.UUID | None = None) -> ModelConfig:
|
||||
def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate,
|
||||
tenant_id: uuid.UUID | None = None) -> ModelConfig:
|
||||
"""更新模型配置"""
|
||||
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
|
||||
if not existing_model:
|
||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
if model_data.name and model_data.name != existing_model.name:
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id):
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider,
|
||||
tenant_id=tenant_id):
|
||||
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
||||
|
||||
model = ModelConfigRepository.update(db, model_id, model_data, tenant_id=tenant_id)
|
||||
@@ -370,25 +373,27 @@ class ModelConfigService:
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig:
|
||||
async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate,
|
||||
tenant_id: uuid.UUID) -> ModelConfig:
|
||||
"""创建组合模型"""
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=ModelProvider.COMPOSITE, tenant_id=tenant_id):
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=ModelProvider.COMPOSITE,
|
||||
tenant_id=tenant_id):
|
||||
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
||||
|
||||
|
||||
# 验证所有 API Key 存在且类型匹配
|
||||
for api_key_id in model_data.api_key_ids:
|
||||
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
||||
if not api_key:
|
||||
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND)
|
||||
|
||||
|
||||
# 检查 API Key 关联的模型配置类型
|
||||
for model_config in api_key.model_configs:
|
||||
# chat 和 llm 类型可以兼容
|
||||
compatible_types = {ModelType.LLM, ModelType.CHAT}
|
||||
config_type = model_config.type
|
||||
request_type = model_data.type
|
||||
|
||||
if not (config_type == request_type or
|
||||
|
||||
if not (config_type == request_type or
|
||||
(config_type in compatible_types and request_type in compatible_types)):
|
||||
raise BusinessException(
|
||||
f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配",
|
||||
@@ -399,7 +404,7 @@ class ModelConfigService:
|
||||
# f"API Key {api_key_id} 关联的模型是组合模型,不能用于创建新的组合模型",
|
||||
# BizCode.INVALID_PARAMETER
|
||||
# )
|
||||
|
||||
|
||||
# 创建组合模型
|
||||
model_config_data = {
|
||||
"tenant_id": tenant_id,
|
||||
@@ -418,49 +423,51 @@ class ModelConfigService:
|
||||
|
||||
model = ModelConfigRepository.create(db, model_config_data)
|
||||
db.flush()
|
||||
|
||||
|
||||
# 关联 API Keys
|
||||
for api_key_id in model_data.api_key_ids:
|
||||
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
||||
if api_key:
|
||||
model.api_keys.append(api_key)
|
||||
|
||||
|
||||
db.commit()
|
||||
db.refresh(model)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
async def update_composite_model(db: Session, model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig:
|
||||
async def update_composite_model(db: Session, model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate,
|
||||
tenant_id: uuid.UUID) -> ModelConfig:
|
||||
"""更新组合模型"""
|
||||
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
|
||||
if not existing_model:
|
||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
if model_data.name and model_data.name != existing_model.name:
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id):
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider,
|
||||
tenant_id=tenant_id):
|
||||
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
||||
|
||||
|
||||
if not existing_model.is_composite:
|
||||
raise BusinessException("该模型不是组合模型", BizCode.INVALID_PARAMETER)
|
||||
|
||||
|
||||
# 验证所有 API Key 存在且类型匹配
|
||||
for api_key_id in model_data.api_key_ids:
|
||||
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
||||
if not api_key:
|
||||
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND)
|
||||
|
||||
|
||||
for model_config in api_key.model_configs:
|
||||
compatible_types = {ModelType.LLM, ModelType.CHAT}
|
||||
config_type = model_config.type
|
||||
request_type = existing_model.type
|
||||
|
||||
if not (config_type == request_type or
|
||||
|
||||
if not (config_type == request_type or
|
||||
(config_type in compatible_types and request_type in compatible_types)):
|
||||
raise BusinessException(
|
||||
f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配",
|
||||
BizCode.INVALID_PARAMETER
|
||||
)
|
||||
|
||||
|
||||
# 更新基本信息
|
||||
existing_model.name = model_data.name
|
||||
# existing_model.type = model_data.type
|
||||
@@ -471,14 +478,14 @@ class ModelConfigService:
|
||||
existing_model.is_public = model_data.is_public
|
||||
if "load_balance_strategy" in model_data.model_fields_set:
|
||||
existing_model.load_balance_strategy = model_data.load_balance_strategy
|
||||
|
||||
|
||||
# 更新 API Keys 关联
|
||||
existing_model.api_keys.clear()
|
||||
for api_key_id in model_data.api_key_ids:
|
||||
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
||||
if api_key:
|
||||
existing_model.api_keys.append(api_key)
|
||||
|
||||
|
||||
db.commit()
|
||||
db.refresh(existing_model)
|
||||
return existing_model
|
||||
@@ -532,7 +539,7 @@ class ModelApiKeyService:
|
||||
"""根据provider为多个ModelConfig创建API Key"""
|
||||
created_keys = []
|
||||
failed_models = [] # 记录验证失败的模型
|
||||
|
||||
|
||||
for model_config_id in data.model_config_ids:
|
||||
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
|
||||
if not model_config:
|
||||
@@ -540,10 +547,10 @@ class ModelApiKeyService:
|
||||
|
||||
data.is_omni = model_config.is_omni
|
||||
data.capability = model_config.capability
|
||||
|
||||
|
||||
# 从ModelBase获取model_name
|
||||
model_name = model_config.model_base.name if model_config.model_base else model_config.name
|
||||
|
||||
|
||||
# 检查是否存在API Key(包括软删除),需要考虑tenant_id
|
||||
existing_key = db.query(ModelApiKey).join(
|
||||
ModelApiKey.model_configs
|
||||
@@ -553,7 +560,7 @@ class ModelApiKeyService:
|
||||
ModelApiKey.model_name == model_name,
|
||||
ModelConfig.tenant_id == model_config.tenant_id
|
||||
).first()
|
||||
|
||||
|
||||
if existing_key:
|
||||
# 如果已存在,重新激活并更新
|
||||
if existing_key.is_active:
|
||||
@@ -566,14 +573,14 @@ class ModelApiKeyService:
|
||||
existing_key.model_name = model_name
|
||||
existing_key.capability = data.capability
|
||||
existing_key.is_omni = data.is_omni
|
||||
|
||||
|
||||
# 检查是否已关联该模型配置
|
||||
if model_config not in existing_key.model_configs:
|
||||
existing_key.model_configs.append(model_config)
|
||||
|
||||
|
||||
created_keys.append(existing_key)
|
||||
continue
|
||||
|
||||
|
||||
# 验证配置
|
||||
validation_result = await ModelConfigService.validate_model_config(
|
||||
db=db,
|
||||
@@ -589,7 +596,7 @@ class ModelApiKeyService:
|
||||
# 记录验证失败的模型,但不抛出异常
|
||||
failed_models.append(model_name)
|
||||
continue
|
||||
|
||||
|
||||
# 创建API Key
|
||||
api_key_data = ModelApiKeyCreate(
|
||||
model_config_ids=[model_config_id],
|
||||
@@ -606,12 +613,12 @@ class ModelApiKeyService:
|
||||
)
|
||||
api_key_obj = ModelApiKeyRepository.create(db, api_key_data)
|
||||
created_keys.append(api_key_obj)
|
||||
|
||||
|
||||
if created_keys:
|
||||
db.commit()
|
||||
for key in created_keys:
|
||||
db.refresh(key)
|
||||
|
||||
|
||||
return created_keys, failed_models
|
||||
|
||||
@staticmethod
|
||||
@@ -626,7 +633,7 @@ class ModelApiKeyService:
|
||||
api_key_data.is_omni = model_config.is_omni
|
||||
if api_key_data.capability is None:
|
||||
api_key_data.capability = model_config.capability
|
||||
|
||||
|
||||
# 检查API Key是否已存在(包括软删除),需要考虑tenant_id
|
||||
existing_key = db.query(ModelApiKey).join(
|
||||
ModelApiKey.model_configs
|
||||
@@ -650,15 +657,15 @@ class ModelApiKeyService:
|
||||
existing_key.model_name = api_key_data.model_name
|
||||
existing_key.capability = api_key_data.capability
|
||||
existing_key.is_omni = api_key_data.is_omni
|
||||
|
||||
|
||||
# 检查是否已关联该模型配置
|
||||
if model_config not in existing_key.model_configs:
|
||||
existing_key.model_configs.append(model_config)
|
||||
|
||||
|
||||
db.commit()
|
||||
db.refresh(existing_key)
|
||||
return existing_key
|
||||
|
||||
|
||||
# 验证配置
|
||||
validation_result = await ModelConfigService.validate_model_config(
|
||||
db=db,
|
||||
@@ -691,7 +698,7 @@ class ModelApiKeyService:
|
||||
# 获取关联的模型配置以获取模型类型
|
||||
if existing_api_key.model_configs:
|
||||
model_config = existing_api_key.model_configs[0]
|
||||
|
||||
|
||||
validation_result = await ModelConfigService.validate_model_config(
|
||||
db=db,
|
||||
model_name=api_key_data.model_name or existing_api_key.model_name,
|
||||
@@ -729,15 +736,15 @@ class ModelApiKeyService:
|
||||
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
|
||||
if not model_config:
|
||||
return None
|
||||
|
||||
|
||||
api_keys = [key for key in model_config.api_keys if key.is_active]
|
||||
if not api_keys:
|
||||
return None
|
||||
|
||||
|
||||
# 如果是轮询策略,按使用次数最少,次数相同则选最早使用的
|
||||
if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN:
|
||||
return min(api_keys, key=lambda x: (int(x.usage_count or "0"), x.last_used_at or datetime.min))
|
||||
|
||||
|
||||
# 否则返回第一个
|
||||
return api_keys[0]
|
||||
|
||||
@@ -760,20 +767,19 @@ class ModelApiKeyService:
|
||||
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
|
||||
|
||||
class ModelBaseService:
|
||||
"""基础模型服务"""
|
||||
|
||||
@staticmethod
|
||||
def get_model_base_list(db: Session, query: model_schema.ModelBaseQuery, tenant_id: uuid.UUID = None) -> List:
|
||||
models = ModelBaseRepository.get_list(db, query)
|
||||
|
||||
|
||||
provider_groups = {}
|
||||
for m in models:
|
||||
model_dict = model_schema.ModelBase.model_validate(m).model_dump()
|
||||
if tenant_id:
|
||||
model_dict['is_added'] = ModelBaseRepository.check_added_by_tenant(db, m.id, tenant_id)
|
||||
|
||||
|
||||
provider = m.provider
|
||||
if provider not in provider_groups:
|
||||
provider_groups[provider] = {
|
||||
@@ -781,7 +787,7 @@ class ModelBaseService:
|
||||
"models": []
|
||||
}
|
||||
provider_groups[provider]["models"].append(model_dict)
|
||||
|
||||
|
||||
return list(provider_groups.values())
|
||||
|
||||
@staticmethod
|
||||
@@ -823,10 +829,10 @@ class ModelBaseService:
|
||||
model_base = ModelBaseRepository.get_by_id(db, model_base_id)
|
||||
if not model_base:
|
||||
raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
|
||||
if ModelBaseRepository.check_added_by_tenant(db, model_base_id, tenant_id):
|
||||
raise BusinessException("模型已添加", BizCode.DUPLICATE_NAME)
|
||||
|
||||
|
||||
model_config_data = {
|
||||
"model_id": model_base_id,
|
||||
"tenant_id": tenant_id,
|
||||
|
||||
Reference in New Issue
Block a user