168 lines
5.5 KiB
Python
168 lines
5.5 KiB
Python
import uuid
|
|
from sqlalchemy.orm import Session, joinedload
|
|
from sqlalchemy import and_, or_, func
|
|
from typing import List, Optional
|
|
|
|
from app.models.tenant_model import Tenants
|
|
from app.models.user_model import User
|
|
from app.schemas.tenant_schema import TenantCreate, TenantUpdate
|
|
|
|
|
|
class TenantRepository:
    """Data-access layer for tenants and the users attached to them.

    Every method works on the SQLAlchemy session given to the constructor.
    Mutating methods call ``flush()`` on that session; none of them commits.
    """

    def __init__(self, db: Session):
        """Store the session used by all repository operations.

        Args:
            db: SQLAlchemy session that queries and writes go through.
        """
        self.db = db

    @staticmethod
    def _apply_tenant_filters(query, is_active: Optional[bool], search: Optional[str]):
        """Apply the shared ``is_active`` / ``search`` filters to a tenant query.

        Used by both ``get_tenants`` and ``count_tenants`` so that a listing and
        its total are always computed from the same conditions.
        """
        if is_active is not None:
            query = query.filter(Tenants.is_active == is_active)

        if search:
            # NOTE: ``search`` is embedded as-is, so any ``%`` or ``_`` it
            # contains acts as a LIKE wildcard.
            pattern = f"%{search}%"
            query = query.filter(
                or_(
                    Tenants.name.ilike(pattern),
                    Tenants.description.ilike(pattern),
                )
            )

        return query

    def create_tenant(self, tenant_data: TenantCreate) -> Tenants:
        """Create a tenant with a freshly generated UUID and flush it.

        Args:
            tenant_data: Source of ``name``, ``description`` and ``is_active``.

        Returns:
            The new ``Tenants`` instance, added to the session and flushed.
        """
        db_tenant = Tenants(
            name=tenant_data.name,
            id=uuid.uuid4(),
            description=tenant_data.description,
            is_active=tenant_data.is_active,
        )
        self.db.add(db_tenant)
        self.db.flush()
        return db_tenant

    def get_tenant_by_id(self, tenant_id: uuid.UUID) -> Optional[Tenants]:
        """Return the tenant whose ``id`` equals ``tenant_id``, or ``None``."""
        return self.db.query(Tenants).filter(Tenants.id == tenant_id).first()

    def get_tenant_by_name(self, name: str) -> Optional[Tenants]:
        """Return the first tenant whose ``name`` equals ``name``, or ``None``."""
        return self.db.query(Tenants).filter(Tenants.name == name).first()

    def get_tenants(
        self,
        skip: int = 0,
        limit: int = 100,
        is_active: Optional[bool] = None,
        search: Optional[str] = None
    ) -> List[Tenants]:
        """Return one page of tenants, optionally filtered.

        Args:
            skip: Number of rows to skip (OFFSET).
            limit: Maximum number of rows to return (LIMIT).
            is_active: When not ``None``, keep only tenants with this flag.
            search: When non-empty, keep tenants whose name or description
                matches ``%search%`` case-insensitively.

        Returns:
            The matching tenants for the requested page, ordered by tenant id.
        """
        query = self._apply_tenant_filters(self.db.query(Tenants), is_active, search)
        # OFFSET/LIMIT without ORDER BY gives no guaranteed row order, so
        # consecutive pages could repeat or skip rows; order by id to make
        # pagination deterministic.
        return query.order_by(Tenants.id).offset(skip).limit(limit).all()

    def count_tenants(
        self,
        is_active: Optional[bool] = None,
        search: Optional[str] = None
    ) -> int:
        """Count tenants using the same filters as ``get_tenants``.

        Args:
            is_active: When not ``None``, count only tenants with this flag.
            search: When non-empty, count tenants whose name or description
                matches ``%search%`` case-insensitively.

        Returns:
            The number of matching tenants.
        """
        query = self._apply_tenant_filters(
            self.db.query(func.count(Tenants.id)), is_active, search
        )
        return query.scalar()

    def update_tenant(self, tenant_id: uuid.UUID, tenant_data: TenantUpdate) -> Optional[Tenants]:
        """Apply the explicitly-set fields of ``tenant_data`` to a tenant.

        Args:
            tenant_id: Id of the tenant to update.
            tenant_data: Update payload; only fields that were set are applied
                (``exclude_unset=True``).

        Returns:
            The updated tenant after a flush, or ``None`` if no tenant has
            that id.
        """
        db_tenant = self.get_tenant_by_id(tenant_id)
        if not db_tenant:
            return None

        for field, value in tenant_data.dict(exclude_unset=True).items():
            setattr(db_tenant, field, value)

        self.db.flush()
        return db_tenant

    def delete_tenant(self, tenant_id: uuid.UUID) -> bool:
        """Soft-delete a tenant by setting ``is_active`` to ``False``.

        Args:
            tenant_id: Id of the tenant to deactivate.

        Returns:
            ``True`` if the tenant was found and deactivated, ``False`` if no
            tenant has that id.
        """
        db_tenant = self.get_tenant_by_id(tenant_id)
        if not db_tenant:
            return False

        # Logical delete: the row is kept, only the flag changes.
        db_tenant.is_active = False
        # Flush like the other mutating methods so the change is sent to the
        # database before this method returns.
        self.db.flush()
        return True

    def get_tenant_users(self, tenant_id: uuid.UUID, is_active: Optional[bool] = None) -> List[User]:
        """Return all users whose ``tenant_id`` equals ``tenant_id``.

        Args:
            tenant_id: Id of the tenant whose users are requested.
            is_active: When not ``None``, keep only users with this flag.

        Returns:
            The matching users.
        """
        query = self.db.query(User).filter(User.tenant_id == tenant_id)

        if is_active is not None:
            query = query.filter(User.is_active == is_active)

        return query.all()

    def get_user_tenant(self, user_id: uuid.UUID) -> Optional[Tenants]:
        """Return the tenant a user belongs to.

        Args:
            user_id: Id of the user to look up.

        Returns:
            The tenant referenced by the user's ``tenant_id``, or ``None`` if
            the user does not exist, has no ``tenant_id``, or the tenant is
            not found.
        """
        user = self.db.query(User).filter(User.id == user_id).first()
        if not user or not user.tenant_id:
            return None

        return self.get_tenant_by_id(user.tenant_id)

    def assign_user_to_tenant(self, user_id: uuid.UUID, tenant_id: uuid.UUID) -> bool:
        """Set a user's ``tenant_id`` to the given tenant and flush.

        Args:
            user_id: Id of the user to reassign.
            tenant_id: Id of the target tenant.

        Returns:
            ``True`` on success; ``False`` if either the user or the tenant
            does not exist.
        """
        user = self.db.query(User).filter(User.id == user_id).first()
        if not user:
            return False

        # Make sure the target tenant exists before pointing the user at it.
        tenant = self.get_tenant_by_id(tenant_id)
        if not tenant:
            return False

        user.tenant_id = tenant_id
        self.db.flush()
        return True

    def count_tenant_users(self, tenant_id: uuid.UUID, is_active: Optional[bool] = None) -> int:
        """Count the users whose ``tenant_id`` equals ``tenant_id``.

        Args:
            tenant_id: Id of the tenant whose users are counted.
            is_active: When not ``None``, count only users with this flag.

        Returns:
            The number of matching users.
        """
        query = self.db.query(func.count(User.id)).filter(User.tenant_id == tenant_id)

        if is_active is not None:
            query = query.filter(User.is_active == is_active)

        return query.scalar()
|
|
|
|
|
|
# Convenience functions, kept for backward compatibility
|
|
def create_tenant(db: Session, tenant_data: TenantCreate) -> Tenants:
    """Create a tenant by delegating to ``TenantRepository.create_tenant``.

    Args:
        db: Session handed to the repository.
        tenant_data: Creation payload, forwarded unchanged.

    Returns:
        Whatever the repository method returns.
    """
    repository = TenantRepository(db)
    return repository.create_tenant(tenant_data)
|
|
|
|
def get_tenant_by_id(db: Session, tenant_id: uuid.UUID) -> Optional[Tenants]:
    """Look up a tenant by id via ``TenantRepository.get_tenant_by_id``.

    Args:
        db: Session handed to the repository.
        tenant_id: Id of the tenant to fetch.

    Returns:
        Whatever the repository method returns.
    """
    repository = TenantRepository(db)
    return repository.get_tenant_by_id(tenant_id)
|
|
|
|
def get_tenant_by_name(db: Session, name: str) -> Optional[Tenants]:
    """Look up a tenant by name via ``TenantRepository.get_tenant_by_name``.

    Args:
        db: Session handed to the repository.
        name: Tenant name to search for.

    Returns:
        Whatever the repository method returns.
    """
    repository = TenantRepository(db)
    return repository.get_tenant_by_name(name)
|
|
|
|
def get_tenants(db: Session, skip: int = 0, limit: int = 100) -> List[Tenants]:
    """List tenants via ``TenantRepository.get_tenants``.

    Only ``skip`` and ``limit`` are forwarded (as keyword arguments); no
    ``is_active`` or ``search`` argument is passed to the repository.

    Args:
        db: Session handed to the repository.
        skip: Forwarded as the repository's ``skip`` argument.
        limit: Forwarded as the repository's ``limit`` argument.

    Returns:
        Whatever the repository method returns.
    """
    repository = TenantRepository(db)
    return repository.get_tenants(skip=skip, limit=limit)
|
|
|
|
def get_user_tenant(db: Session, user_id: uuid.UUID) -> Optional[Tenants]:
    """Fetch a user's tenant via ``TenantRepository.get_user_tenant``.

    Args:
        db: Session handed to the repository.
        user_id: Id of the user whose tenant is requested.

    Returns:
        Whatever the repository method returns.
    """
    repository = TenantRepository(db)
    return repository.get_user_tenant(user_id)
|
|
|
|
def get_tenant_users(db: Session, tenant_id: uuid.UUID) -> List[User]:
    """List a tenant's users via ``TenantRepository.get_tenant_users``.

    Only ``tenant_id`` is forwarded; no ``is_active`` argument is passed to
    the repository.

    Args:
        db: Session handed to the repository.
        tenant_id: Id of the tenant whose users are requested.

    Returns:
        Whatever the repository method returns.
    """
    repository = TenantRepository(db)
    return repository.get_tenant_users(tenant_id)