From caf9b8a9dae1c410fe6f08a361cd70c97767f817 Mon Sep 17 00:00:00 2001 From: Ke Sun Date: Wed, 24 Dec 2025 18:21:36 +0800 Subject: [PATCH] feat(workspace): add workspace models configuration update endpoint - Add PUT endpoint to update workspace LLM, embedding, and rerank model configurations - Create WorkspaceModelsUpdate schema for model configuration update requests - Create WorkspaceModelsConfig schema for model configuration responses with proper validation - Implement update_workspace_models_configs service method to persist model configuration changes - Update workspace_models_configs GET endpoint to return validated WorkspaceModelsConfig response - Reorganize imports across controller, schema, and service files for consistency and readability - Add proper logging for model configuration updates with user and workspace context --- api/app/controllers/workspace_controller.py | 68 +++++++++--- api/app/schemas/workspace_schema.py | 32 +++++- api/app/services/workspace_service.py | 112 ++++++++++++++------ 3 files changed, 163 insertions(+), 49 deletions(-) diff --git a/api/app/controllers/workspace_controller.py b/api/app/controllers/workspace_controller.py index f4390568..eb6065e0 100644 --- a/api/app/controllers/workspace_controller.py +++ b/api/app/controllers/workspace_controller.py @@ -1,25 +1,38 @@ -from fastapi import APIRouter, Depends, HTTPException, status, Query -from sqlalchemy.orm import Session -from typing import List, Optional import uuid +from typing import List, Optional +from app.core.logging_config import get_api_logger from app.core.response_utils import success from app.db import get_db -from app.dependencies import get_current_superuser, get_current_user, get_current_tenant, workspace_access_guard, cur_workspace_access_guard -from app.models.user_model import User +from app.dependencies import ( + cur_workspace_access_guard, + get_current_superuser, + get_current_tenant, + get_current_user, + workspace_access_guard, +) from app.models.tenant_model import Tenants -from app.models.workspace_model import Workspace, InviteStatus +from app.models.user_model import User +from app.models.workspace_model import InviteStatus, Workspace +from app.schemas import knowledge_schema from app.schemas.response_schema import ApiResponse from app.schemas.workspace_schema import ( - WorkspaceCreate, WorkspaceUpdate, WorkspaceResponse, - WorkspaceInviteCreate, WorkspaceInviteResponse, - InviteValidateResponse, InviteAcceptRequest, - WorkspaceMemberUpdate, WorkspaceMemberItem + InviteAcceptRequest, + InviteValidateResponse, + WorkspaceCreate, + WorkspaceInviteCreate, + WorkspaceInviteResponse, + WorkspaceMemberItem, + WorkspaceMemberUpdate, + WorkspaceModelsConfig, + WorkspaceModelsUpdate, + WorkspaceResponse, + WorkspaceUpdate, ) -from app.schemas import knowledge_schema -from app.services import workspace_service -from app.core.logging_config import get_api_logger -from app.services import knowledge_service, document_service +from app.services import document_service, knowledge_service, workspace_service +from fastapi import APIRouter, Depends, HTTPException, Query, status +from sqlalchemy.orm import Session + # 获取API专用日志器 api_logger = get_api_logger() # 需要认证的路由器 @@ -338,5 +351,30 @@ def workspace_models_configs( f"成功获取工作空间 {workspace_id} 的模型配置: " f"llm={configs.get('llm')}, embedding={configs.get('embedding')}, rerank={configs.get('rerank')}" ) - return success(data=configs, msg="模型配置获取成功") + return success(data=WorkspaceModelsConfig.model_validate(configs), msg="模型配置获取成功") + + +@router.put("/workspace_models", response_model=ApiResponse) +@cur_workspace_access_guard() +def update_workspace_models_configs( + models_update: WorkspaceModelsUpdate, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), +): + """更新当前工作空间的模型配置(llm, embedding, rerank)""" + workspace_id = current_user.current_workspace_id + api_logger.info(f"用户 {current_user.username} 请求更新工作空间 {workspace_id} 的模型配置") + + updated_workspace = workspace_service.update_workspace_models_configs( + db=db, + workspace_id=workspace_id, + models_update=models_update, + user=current_user + ) + + api_logger.info( + f"成功更新工作空间 {workspace_id} 的模型配置: " + f"llm={updated_workspace.llm}, embedding={updated_workspace.embedding}, rerank={updated_workspace.rerank}" + ) + return success(data=WorkspaceModelsConfig.model_validate(updated_workspace), msg="模型配置更新成功") diff --git a/api/app/schemas/workspace_schema.py b/api/app/schemas/workspace_schema.py index eb3e31e2..1fdfe426 100644 --- a/api/app/schemas/workspace_schema.py +++ b/api/app/schemas/workspace_schema.py @@ -1,9 +1,17 @@ -import email -from pydantic import BaseModel, Field, EmailStr, field_serializer, computed_field, ConfigDict import datetime +import email import uuid -from typing import Literal -from app.models.workspace_model import WorkspaceRole, InviteStatus +from typing import Literal, Optional + +from app.models.workspace_model import InviteStatus, WorkspaceRole +from pydantic import ( + BaseModel, + ConfigDict, + EmailStr, + Field, + computed_field, + field_serializer, +) class WorkspaceBase(BaseModel): @@ -170,3 +178,19 @@ class InviteValidateResponse(BaseModel): class InviteAcceptRequest(BaseModel): token: str = Field(..., description="邀请令牌") + + +class WorkspaceModelsUpdate(BaseModel): + """工作空间模型配置更新请求""" + llm: Optional[uuid.UUID] = Field(default=None, description="LLM模型ID") + embedding: Optional[uuid.UUID] = Field(default=None, description="嵌入模型ID") + rerank: Optional[uuid.UUID] = Field(default=None, description="重排序模型ID") + + +class WorkspaceModelsConfig(BaseModel): + """工作空间模型配置响应""" + model_config = ConfigDict(from_attributes=True) + + llm: Optional[str] = Field(default=None, description="LLM模型ID") + embedding: Optional[str] = Field(default=None, description="嵌入模型ID") + rerank: Optional[str] = Field(default=None, description="重排序模型ID") diff --git a/api/app/services/workspace_service.py b/api/app/services/workspace_service.py index 04ee647c..bb6c53dc 100644 --- a/api/app/services/workspace_service.py +++ b/api/app/services/workspace_service.py @@ -1,36 +1,38 @@ -from sqlalchemy.orm import Session -from typing import List, Optional -import uuid -import secrets -import hashlib import datetime -from fastapi import HTTPException, status +import hashlib +import secrets +import uuid +from os import getenv +from typing import List, Optional + +from app.core.config import settings from app.core.error_codes import BizCode from app.core.exceptions import BusinessException, PermissionDeniedException -from app.models.tenant_model import Tenants +from app.core.logging_config import get_business_logger from app.models.user_model import User -from app.models.app_model import App -from app.models.end_user_model import EndUser -from app.models.workspace_model import Workspace, WorkspaceRole, WorkspaceInvite, InviteStatus, WorkspaceMember -from app.schemas.workspace_schema import ( - WorkspaceCreate, - WorkspaceUpdate, - WorkspaceInviteCreate, - WorkspaceInviteResponse, - InviteValidateResponse, - InviteAcceptRequest, - WorkspaceMemberUpdate +from app.models.workspace_model import ( + InviteStatus, + Workspace, + WorkspaceMember, + WorkspaceRole, ) from app.repositories import workspace_repository from app.repositories.workspace_invite_repository import WorkspaceInviteRepository -from app.core.logging_config import get_business_logger -from app.core.config import settings -from app.services import user_service -from os import getenv +from app.schemas.workspace_schema import ( + InviteAcceptRequest, + InviteValidateResponse, + WorkspaceCreate, + WorkspaceInviteCreate, + WorkspaceInviteResponse, + WorkspaceMemberUpdate, + WorkspaceModelsUpdate, + WorkspaceUpdate, +) +from dotenv import load_dotenv +from sqlalchemy.orm import Session + # 获取业务逻辑专用日志器 business_logger = get_business_logger() -import os # -from dotenv import load_dotenv load_dotenv() def switch_workspace( db: Session, @@ -134,10 +136,9 @@ def create_workspace( f"{db_workspace.id} 创建知识库" ) try: - import os - from app.schemas.knowledge_schema import KnowledgeCreate from app.models.knowledge_model import KnowledgeType, PermissionType from app.repositories import knowledge_repository + from app.schemas.knowledge_schema import KnowledgeCreate # 创建知识库数据 knowledge_data = KnowledgeCreate( @@ -232,7 +233,7 @@ def get_workspace_members( ) # 权限检查:工作空间成员或超级管理员可以查看成员列表 - from app.core.permissions import permission_service, Subject, Resource, Action + from app.core.permissions import Action, Resource, Subject, permission_service member = workspace_repository.get_member_in_workspace( db=db, user_id=user.id, workspace_id=workspace_id ) @@ -287,7 +288,7 @@ def _check_workspace_member_permission(db: Session, workspace_id: uuid.UUID, use ) # 使用统一权限服务检查访问权限 - from app.core.permissions import permission_service, Subject, Resource, Action + from app.core.permissions import Action, Resource, Subject, permission_service # 获取用户的工作空间成员关系 member = workspace_repository.get_member_in_workspace( @@ -325,7 +326,7 @@ def _check_workspace_admin_permission(db: Session, workspace_id: uuid.UUID, user ) # 使用统一权限服务检查管理权限 - from app.core.permissions import permission_service, Subject, Resource, Action + from app.core.permissions import Action, Resource, Subject, permission_service # 获取用户的工作空间成员关系 member = workspace_repository.get_member_in_workspace( @@ -801,4 +802,55 @@ def get_workspace_models_configs( f"成功获取工作空间 {workspace_id} 的模型配置: " f"llm={configs.get('llm')}, embedding={configs.get('embedding')}, rerank={configs.get('rerank')}" ) - return configs \ No newline at end of file + return configs + + +def update_workspace_models_configs( + db: Session, + workspace_id: uuid.UUID, + models_update: WorkspaceModelsUpdate, + user: User, +) -> Workspace: + """更新工作空间的模型配置(llm, embedding, rerank) + + Args: + db: 数据库会话 + workspace_id: 工作空间ID + models_update: 模型配置更新对象 + user: 当前用户 + + Returns: + Workspace: 更新后的工作空间对象 + """ + business_logger.info(f"用户 {user.username} 请求更新工作空间 {workspace_id} 的模型配置") + + # 检查用户是否有管理员权限 + db_workspace = _check_workspace_admin_permission(db, workspace_id, user) + + try: + if models_update.llm is not None: + db_workspace.llm = str(models_update.llm) if models_update.llm else None + business_logger.debug(f"更新LLM配置: {models_update.llm}") + + if models_update.embedding is not None: + db_workspace.embedding = str(models_update.embedding) if models_update.embedding else None + business_logger.debug(f"更新嵌入模型配置: {models_update.embedding}") + + if models_update.rerank is not None: + db_workspace.rerank = str(models_update.rerank) if models_update.rerank else None + business_logger.debug(f"更新重排序模型配置: {models_update.rerank}") + + db.add(db_workspace) + db.commit() + db.refresh(db_workspace) + + business_logger.info( + f"工作空间模型配置更新成功: workspace_id={workspace_id}, " + f"llm={db_workspace.llm}, embedding={db_workspace.embedding}, rerank={db_workspace.rerank}" + ) + return db_workspace + + except Exception as e: + business_logger.error(f"工作空间模型配置更新失败: workspace_id={workspace_id} - {str(e)}") + db.rollback() + raise BusinessException(f"更新模型配置失败: {str(e)}", BizCode.INTERNAL_ERROR) \ No newline at end of file