feat(workspace): add workspace models configuration update endpoint
- Add PUT endpoint to update workspace LLM, embedding, and rerank model configurations - Create WorkspaceModelsUpdate schema for model configuration update requests - Create WorkspaceModelsConfig schema for model configuration responses with proper validation - Implement update_workspace_models_configs service method to persist model configuration changes - Update workspace_models_configs GET endpoint to return validated WorkspaceModelsConfig response - Reorganize imports across controller, schema, and service files for consistency and readability - Add proper logging for model configuration updates with user and workspace context
This commit is contained in:
@@ -1,25 +1,38 @@
|
|||||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
from typing import List, Optional
|
|
||||||
import uuid
|
import uuid
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from app.core.logging_config import get_api_logger
|
||||||
from app.core.response_utils import success
|
from app.core.response_utils import success
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.dependencies import get_current_superuser, get_current_user, get_current_tenant, workspace_access_guard, cur_workspace_access_guard
|
from app.dependencies import (
|
||||||
from app.models.user_model import User
|
cur_workspace_access_guard,
|
||||||
|
get_current_superuser,
|
||||||
|
get_current_tenant,
|
||||||
|
get_current_user,
|
||||||
|
workspace_access_guard,
|
||||||
|
)
|
||||||
from app.models.tenant_model import Tenants
|
from app.models.tenant_model import Tenants
|
||||||
from app.models.workspace_model import Workspace, InviteStatus
|
from app.models.user_model import User
|
||||||
|
from app.models.workspace_model import InviteStatus, Workspace
|
||||||
|
from app.schemas import knowledge_schema
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.schemas.workspace_schema import (
|
from app.schemas.workspace_schema import (
|
||||||
WorkspaceCreate, WorkspaceUpdate, WorkspaceResponse,
|
InviteAcceptRequest,
|
||||||
WorkspaceInviteCreate, WorkspaceInviteResponse,
|
InviteValidateResponse,
|
||||||
InviteValidateResponse, InviteAcceptRequest,
|
WorkspaceCreate,
|
||||||
WorkspaceMemberUpdate, WorkspaceMemberItem
|
WorkspaceInviteCreate,
|
||||||
|
WorkspaceInviteResponse,
|
||||||
|
WorkspaceMemberItem,
|
||||||
|
WorkspaceMemberUpdate,
|
||||||
|
WorkspaceModelsConfig,
|
||||||
|
WorkspaceModelsUpdate,
|
||||||
|
WorkspaceResponse,
|
||||||
|
WorkspaceUpdate,
|
||||||
)
|
)
|
||||||
from app.schemas import knowledge_schema
|
from app.services import document_service, knowledge_service, workspace_service
|
||||||
from app.services import workspace_service
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||||
from app.core.logging_config import get_api_logger
|
from sqlalchemy.orm import Session
|
||||||
from app.services import knowledge_service, document_service
|
|
||||||
# 获取API专用日志器
|
# 获取API专用日志器
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
# 需要认证的路由器
|
# 需要认证的路由器
|
||||||
@@ -338,5 +351,30 @@ def workspace_models_configs(
|
|||||||
f"成功获取工作空间 {workspace_id} 的模型配置: "
|
f"成功获取工作空间 {workspace_id} 的模型配置: "
|
||||||
f"llm={configs.get('llm')}, embedding={configs.get('embedding')}, rerank={configs.get('rerank')}"
|
f"llm={configs.get('llm')}, embedding={configs.get('embedding')}, rerank={configs.get('rerank')}"
|
||||||
)
|
)
|
||||||
return success(data=configs, msg="模型配置获取成功")
|
return success(data=WorkspaceModelsConfig.model_validate(configs), msg="模型配置获取成功")
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/workspace_models", response_model=ApiResponse)
|
||||||
|
@cur_workspace_access_guard()
|
||||||
|
def update_workspace_models_configs(
|
||||||
|
models_update: WorkspaceModelsUpdate,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""更新当前工作空间的模型配置(llm, embedding, rerank)"""
|
||||||
|
workspace_id = current_user.current_workspace_id
|
||||||
|
api_logger.info(f"用户 {current_user.username} 请求更新工作空间 {workspace_id} 的模型配置")
|
||||||
|
|
||||||
|
updated_workspace = workspace_service.update_workspace_models_configs(
|
||||||
|
db=db,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
models_update=models_update,
|
||||||
|
user=current_user
|
||||||
|
)
|
||||||
|
|
||||||
|
api_logger.info(
|
||||||
|
f"成功更新工作空间 {workspace_id} 的模型配置: "
|
||||||
|
f"llm={updated_workspace.llm}, embedding={updated_workspace.embedding}, rerank={updated_workspace.rerank}"
|
||||||
|
)
|
||||||
|
return success(data=WorkspaceModelsConfig.model_validate(updated_workspace), msg="模型配置更新成功")
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,17 @@
|
|||||||
import email
|
|
||||||
from pydantic import BaseModel, Field, EmailStr, field_serializer, computed_field, ConfigDict
|
|
||||||
import datetime
|
import datetime
|
||||||
|
import email
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Literal
|
from typing import Literal, Optional
|
||||||
from app.models.workspace_model import WorkspaceRole, InviteStatus
|
|
||||||
|
from app.models.workspace_model import InviteStatus, WorkspaceRole
|
||||||
|
from pydantic import (
|
||||||
|
BaseModel,
|
||||||
|
ConfigDict,
|
||||||
|
EmailStr,
|
||||||
|
Field,
|
||||||
|
computed_field,
|
||||||
|
field_serializer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class WorkspaceBase(BaseModel):
|
class WorkspaceBase(BaseModel):
|
||||||
@@ -170,3 +178,19 @@ class InviteValidateResponse(BaseModel):
|
|||||||
|
|
||||||
class InviteAcceptRequest(BaseModel):
|
class InviteAcceptRequest(BaseModel):
|
||||||
token: str = Field(..., description="邀请令牌")
|
token: str = Field(..., description="邀请令牌")
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceModelsUpdate(BaseModel):
|
||||||
|
"""工作空间模型配置更新请求"""
|
||||||
|
llm: Optional[uuid.UUID] = Field(default=None, description="LLM模型ID")
|
||||||
|
embedding: Optional[uuid.UUID] = Field(default=None, description="嵌入模型ID")
|
||||||
|
rerank: Optional[uuid.UUID] = Field(default=None, description="重排序模型ID")
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceModelsConfig(BaseModel):
|
||||||
|
"""工作空间模型配置响应"""
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
llm: Optional[str] = Field(default=None, description="LLM模型ID")
|
||||||
|
embedding: Optional[str] = Field(default=None, description="嵌入模型ID")
|
||||||
|
rerank: Optional[str] = Field(default=None, description="重排序模型ID")
|
||||||
|
|||||||
@@ -1,36 +1,38 @@
|
|||||||
from sqlalchemy.orm import Session
|
|
||||||
from typing import List, Optional
|
|
||||||
import uuid
|
|
||||||
import secrets
|
|
||||||
import hashlib
|
|
||||||
import datetime
|
import datetime
|
||||||
from fastapi import HTTPException, status
|
import hashlib
|
||||||
|
import secrets
|
||||||
|
import uuid
|
||||||
|
from os import getenv
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.exceptions import BusinessException, PermissionDeniedException
|
from app.core.exceptions import BusinessException, PermissionDeniedException
|
||||||
from app.models.tenant_model import Tenants
|
from app.core.logging_config import get_business_logger
|
||||||
from app.models.user_model import User
|
from app.models.user_model import User
|
||||||
from app.models.app_model import App
|
from app.models.workspace_model import (
|
||||||
from app.models.end_user_model import EndUser
|
InviteStatus,
|
||||||
from app.models.workspace_model import Workspace, WorkspaceRole, WorkspaceInvite, InviteStatus, WorkspaceMember
|
Workspace,
|
||||||
from app.schemas.workspace_schema import (
|
WorkspaceMember,
|
||||||
WorkspaceCreate,
|
WorkspaceRole,
|
||||||
WorkspaceUpdate,
|
|
||||||
WorkspaceInviteCreate,
|
|
||||||
WorkspaceInviteResponse,
|
|
||||||
InviteValidateResponse,
|
|
||||||
InviteAcceptRequest,
|
|
||||||
WorkspaceMemberUpdate
|
|
||||||
)
|
)
|
||||||
from app.repositories import workspace_repository
|
from app.repositories import workspace_repository
|
||||||
from app.repositories.workspace_invite_repository import WorkspaceInviteRepository
|
from app.repositories.workspace_invite_repository import WorkspaceInviteRepository
|
||||||
from app.core.logging_config import get_business_logger
|
from app.schemas.workspace_schema import (
|
||||||
from app.core.config import settings
|
InviteAcceptRequest,
|
||||||
from app.services import user_service
|
InviteValidateResponse,
|
||||||
from os import getenv
|
WorkspaceCreate,
|
||||||
|
WorkspaceInviteCreate,
|
||||||
|
WorkspaceInviteResponse,
|
||||||
|
WorkspaceMemberUpdate,
|
||||||
|
WorkspaceModelsUpdate,
|
||||||
|
WorkspaceUpdate,
|
||||||
|
)
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
# 获取业务逻辑专用日志器
|
# 获取业务逻辑专用日志器
|
||||||
business_logger = get_business_logger()
|
business_logger = get_business_logger()
|
||||||
import os #
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
def switch_workspace(
|
def switch_workspace(
|
||||||
db: Session,
|
db: Session,
|
||||||
@@ -134,10 +136,9 @@ def create_workspace(
|
|||||||
f"{db_workspace.id} 创建知识库"
|
f"{db_workspace.id} 创建知识库"
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
import os
|
|
||||||
from app.schemas.knowledge_schema import KnowledgeCreate
|
|
||||||
from app.models.knowledge_model import KnowledgeType, PermissionType
|
from app.models.knowledge_model import KnowledgeType, PermissionType
|
||||||
from app.repositories import knowledge_repository
|
from app.repositories import knowledge_repository
|
||||||
|
from app.schemas.knowledge_schema import KnowledgeCreate
|
||||||
|
|
||||||
# 创建知识库数据
|
# 创建知识库数据
|
||||||
knowledge_data = KnowledgeCreate(
|
knowledge_data = KnowledgeCreate(
|
||||||
@@ -232,7 +233,7 @@ def get_workspace_members(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 权限检查:工作空间成员或超级管理员可以查看成员列表
|
# 权限检查:工作空间成员或超级管理员可以查看成员列表
|
||||||
from app.core.permissions import permission_service, Subject, Resource, Action
|
from app.core.permissions import Action, Resource, Subject, permission_service
|
||||||
member = workspace_repository.get_member_in_workspace(
|
member = workspace_repository.get_member_in_workspace(
|
||||||
db=db, user_id=user.id, workspace_id=workspace_id
|
db=db, user_id=user.id, workspace_id=workspace_id
|
||||||
)
|
)
|
||||||
@@ -287,7 +288,7 @@ def _check_workspace_member_permission(db: Session, workspace_id: uuid.UUID, use
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 使用统一权限服务检查访问权限
|
# 使用统一权限服务检查访问权限
|
||||||
from app.core.permissions import permission_service, Subject, Resource, Action
|
from app.core.permissions import Action, Resource, Subject, permission_service
|
||||||
|
|
||||||
# 获取用户的工作空间成员关系
|
# 获取用户的工作空间成员关系
|
||||||
member = workspace_repository.get_member_in_workspace(
|
member = workspace_repository.get_member_in_workspace(
|
||||||
@@ -325,7 +326,7 @@ def _check_workspace_admin_permission(db: Session, workspace_id: uuid.UUID, user
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 使用统一权限服务检查管理权限
|
# 使用统一权限服务检查管理权限
|
||||||
from app.core.permissions import permission_service, Subject, Resource, Action
|
from app.core.permissions import Action, Resource, Subject, permission_service
|
||||||
|
|
||||||
# 获取用户的工作空间成员关系
|
# 获取用户的工作空间成员关系
|
||||||
member = workspace_repository.get_member_in_workspace(
|
member = workspace_repository.get_member_in_workspace(
|
||||||
@@ -802,3 +803,54 @@ def get_workspace_models_configs(
|
|||||||
f"llm={configs.get('llm')}, embedding={configs.get('embedding')}, rerank={configs.get('rerank')}"
|
f"llm={configs.get('llm')}, embedding={configs.get('embedding')}, rerank={configs.get('rerank')}"
|
||||||
)
|
)
|
||||||
return configs
|
return configs
|
||||||
|
|
||||||
|
|
||||||
|
def update_workspace_models_configs(
|
||||||
|
db: Session,
|
||||||
|
workspace_id: uuid.UUID,
|
||||||
|
models_update: WorkspaceModelsUpdate,
|
||||||
|
user: User,
|
||||||
|
) -> Workspace:
|
||||||
|
"""更新工作空间的模型配置(llm, embedding, rerank)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库会话
|
||||||
|
workspace_id: 工作空间ID
|
||||||
|
models_update: 模型配置更新对象
|
||||||
|
user: 当前用户
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Workspace: 更新后的工作空间对象
|
||||||
|
"""
|
||||||
|
business_logger.info(f"用户 {user.username} 请求更新工作空间 {workspace_id} 的模型配置")
|
||||||
|
|
||||||
|
# 检查用户是否有管理员权限
|
||||||
|
db_workspace = _check_workspace_admin_permission(db, workspace_id, user)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if models_update.llm is not None:
|
||||||
|
db_workspace.llm = str(models_update.llm) if models_update.llm else None
|
||||||
|
business_logger.debug(f"更新LLM配置: {models_update.llm}")
|
||||||
|
|
||||||
|
if models_update.embedding is not None:
|
||||||
|
db_workspace.embedding = str(models_update.embedding) if models_update.embedding else None
|
||||||
|
business_logger.debug(f"更新嵌入模型配置: {models_update.embedding}")
|
||||||
|
|
||||||
|
if models_update.rerank is not None:
|
||||||
|
db_workspace.rerank = str(models_update.rerank) if models_update.rerank else None
|
||||||
|
business_logger.debug(f"更新重排序模型配置: {models_update.rerank}")
|
||||||
|
|
||||||
|
db.add(db_workspace)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(db_workspace)
|
||||||
|
|
||||||
|
business_logger.info(
|
||||||
|
f"工作空间模型配置更新成功: workspace_id={workspace_id}, "
|
||||||
|
f"llm={db_workspace.llm}, embedding={db_workspace.embedding}, rerank={db_workspace.rerank}"
|
||||||
|
)
|
||||||
|
return db_workspace
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
business_logger.error(f"工作空间模型配置更新失败: workspace_id={workspace_id} - {str(e)}")
|
||||||
|
db.rollback()
|
||||||
|
raise BusinessException(f"更新模型配置失败: {str(e)}", BizCode.INTERNAL_ERROR)
|
||||||
Reference in New Issue
Block a user