Files
MemoryBear/api/app/repositories/tenant_repository.py
2026-03-27 14:36:19 +08:00

168 lines
5.5 KiB
Python

import uuid
from sqlalchemy.orm import Session, joinedload
from sqlalchemy import and_, or_, func
from typing import List, Optional
from app.models.tenant_model import Tenants
from app.models.user_model import User
from app.schemas.tenant_schema import TenantCreate, TenantUpdate
class TenantRepository:
"""租户数据访问层"""
def __init__(self, db: Session):
self.db = db
def create_tenant(self, tenant_data: TenantCreate) -> Tenants:
"""创建租户"""
db_tenant = Tenants(
name=tenant_data.name,
id=uuid.uuid4(),
description=tenant_data.description,
is_active=tenant_data.is_active
)
self.db.add(db_tenant)
self.db.flush()
return db_tenant
def get_tenant_by_id(self, tenant_id: uuid.UUID) -> Optional[Tenants]:
"""根据ID获取租户"""
return self.db.query(Tenants).filter(Tenants.id == tenant_id).first()
def get_tenant_by_name(self, name: str) -> Optional[Tenants]:
"""根据名称获取租户"""
return self.db.query(Tenants).filter(Tenants.name == name).first()
def get_tenants(
self,
skip: int = 0,
limit: int = 100,
is_active: Optional[bool] = None,
search: Optional[str] = None
) -> List[Tenants]:
"""获取租户列表"""
query = self.db.query(Tenants)
if is_active is not None:
query = query.filter(Tenants.is_active == is_active)
if search:
query = query.filter(
or_(
Tenants.name.ilike(f"%{search}%"),
Tenants.description.ilike(f"%{search}%")
)
)
return query.offset(skip).limit(limit).all()
def count_tenants(
self,
is_active: Optional[bool] = None,
search: Optional[str] = None
) -> int:
"""统计租户数量"""
query = self.db.query(func.count(Tenants.id))
if is_active is not None:
query = query.filter(Tenants.is_active == is_active)
if search:
query = query.filter(
or_(
Tenants.name.ilike(f"%{search}%"),
Tenants.description.ilike(f"%{search}%")
)
)
return query.scalar()
def update_tenant(self, tenant_id: uuid.UUID, tenant_data: TenantUpdate) -> Optional[Tenants]:
"""更新租户"""
db_tenant = self.get_tenant_by_id(tenant_id)
if not db_tenant:
return None
for field, value in tenant_data.dict(exclude_unset=True).items():
setattr(db_tenant, field, value)
self.db.flush()
return db_tenant
def delete_tenant(self, tenant_id: uuid.UUID) -> bool:
"""逻辑删除租户"""
db_tenant = self.get_tenant_by_id(tenant_id)
if not db_tenant:
return False
# 逻辑删除租户
db_tenant.is_active = False
return True
def get_tenant_users(self, tenant_id: uuid.UUID, is_active: Optional[bool] = None) -> List[User]:
"""获取租户下的所有用户"""
query = self.db.query(User).filter(User.tenant_id == tenant_id)
if is_active is not None:
query = query.filter(User.is_active == is_active)
return query.all()
def get_user_tenant(self, user_id: uuid.UUID) -> Optional[Tenants]:
"""获取用户所属的租户"""
user = self.db.query(User).filter(User.id == user_id).first()
if not user or not user.tenant_id:
return None
return self.get_tenant_by_id(user.tenant_id)
def assign_user_to_tenant(self, user_id: uuid.UUID, tenant_id: uuid.UUID) -> bool:
"""将用户分配给租户"""
user = self.db.query(User).filter(User.id == user_id).first()
if not user:
return False
# 验证租户存在
tenant = self.get_tenant_by_id(tenant_id)
if not tenant:
return False
user.tenant_id = tenant_id
self.db.flush()
return True
def count_tenant_users(self, tenant_id: uuid.UUID, is_active: Optional[bool] = None) -> int:
"""统计租户下的用户数量"""
query = self.db.query(func.count(User.id)).filter(User.tenant_id == tenant_id)
if is_active is not None:
query = query.filter(User.is_active == is_active)
return query.scalar()
# 便利函数,保持向后兼容
def create_tenant(db: Session, tenant_data: TenantCreate) -> Tenants:
"""创建租户"""
return TenantRepository(db).create_tenant(tenant_data)
def get_tenant_by_id(db: Session, tenant_id: uuid.UUID) -> Optional[Tenants]:
"""根据ID获取租户"""
return TenantRepository(db).get_tenant_by_id(tenant_id)
def get_tenant_by_name(db: Session, name: str) -> Optional[Tenants]:
"""根据名称获取租户"""
return TenantRepository(db).get_tenant_by_name(name)
def get_tenants(db: Session, skip: int = 0, limit: int = 100) -> List[Tenants]:
"""获取租户列表"""
return TenantRepository(db).get_tenants(skip=skip, limit=limit)
def get_user_tenant(db: Session, user_id: uuid.UUID) -> Optional[Tenants]:
"""获取用户所属的租户"""
return TenantRepository(db).get_user_tenant(user_id)
def get_tenant_users(db: Session, tenant_id: uuid.UUID) -> List[User]:
"""获取租户下的所有用户"""
return TenantRepository(db).get_tenant_users(tenant_id)