Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/8525.enhance.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Apply the Purger pattern to `DomainRepository` for domain-related deletions, improving code consistency and maintainability
32 changes: 32 additions & 0 deletions src/ai/backend/manager/repositories/domain/purgers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import override

import sqlalchemy as sa

from ai.backend.manager.models.domain import DomainRow
from ai.backend.manager.models.kernel.row import KernelRow
from ai.backend.manager.repositories.base.purger import BatchPurgerSpec


@dataclass
class DomainKernelBatchPurgerSpec(BatchPurgerSpec[KernelRow]):
"""PurgerSpec for deleting all kernels belonging to a domain."""

domain_name: str

@override
def build_subquery(self) -> sa.sql.Select[tuple[KernelRow]]:
return sa.select(KernelRow).where(KernelRow.domain_name == self.domain_name)


@dataclass
class DomainBatchPurgerSpec(BatchPurgerSpec[DomainRow]):
"""PurgerSpec for deleting a domain."""

domain_name: str

@override
def build_subquery(self) -> sa.sql.Select[tuple[DomainRow]]:
return sa.select(DomainRow).where(DomainRow.name == self.domain_name)
74 changes: 44 additions & 30 deletions src/ai/backend/manager/repositories/domain/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import cast

import sqlalchemy as sa
from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection
from sqlalchemy.ext.asyncio import AsyncSession as SASession

from ai.backend.common.exception import BackendAIError, DomainNotFound, InvalidAPIParameters
Expand All @@ -26,21 +25,24 @@
InvalidDomainConfiguration,
)
from ai.backend.manager.models.domain import DomainRow, domains, get_domains
from ai.backend.manager.models.group import ProjectType, groups
from ai.backend.manager.models.kernel import (
AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES,
kernels,
)
from ai.backend.manager.models.group import GroupRow, ProjectType, groups
from ai.backend.manager.models.kernel import AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES
from ai.backend.manager.models.kernel.row import KernelRow
from ai.backend.manager.models.rbac import SystemScope
from ai.backend.manager.models.rbac.context import ClientContext
from ai.backend.manager.models.rbac.permission_defs import DomainPermission, ScalingGroupPermission
from ai.backend.manager.models.resource_policy import keypair_resource_policies
from ai.backend.manager.models.scaling_group import ScalingGroupForDomainRow, get_scaling_groups
from ai.backend.manager.models.user import users
from ai.backend.manager.models.user import UserRow
from ai.backend.manager.models.utils import ExtendedAsyncSAEngine
from ai.backend.manager.repositories.base.creator import Creator, execute_creator
from ai.backend.manager.repositories.base.purger import BatchPurger, execute_batch_purger
from ai.backend.manager.repositories.base.updater import Updater, execute_updater
from ai.backend.manager.repositories.domain.creators import DomainCreatorSpec
from ai.backend.manager.repositories.domain.purgers import (
DomainBatchPurgerSpec,
DomainKernelBatchPurgerSpec,
)
from ai.backend.manager.repositories.permission_controller.role_manager import RoleManager

domain_repository_resilience = Resilience(
Expand Down Expand Up @@ -123,28 +125,30 @@ async def purge_domain(self, domain_name: str) -> None:
Permanently deletes a domain after validation checks.
Validates domain purge permissions and prerequisites.
"""
async with self._db.begin() as conn:
async with self._db.begin_session() as session:
# Validate prerequisites
if await self._domain_has_active_kernels(conn, domain_name):
if await self._domain_has_active_kernels(session, domain_name):
raise DomainHasActiveKernels(
"Domain has some active kernels. Terminate them first."
)

user_count = await self._get_domain_user_count(conn, domain_name)
user_count = await self._get_domain_user_count(session, domain_name)
if user_count > 0:
raise DomainHasUsers("There are users bound to the domain. Remove users first.")

group_count = await self._get_domain_group_count(conn, domain_name)
group_count = await self._get_domain_group_count(session, domain_name)
if group_count > 0:
raise DomainHasGroups("There are groups bound to the domain. Remove groups first.")

# Clean up kernels
await self._delete_kernels(conn, domain_name)
await self._delete_kernels(session, domain_name)

# Delete domain
delete_query = sa.delete(domains).where(domains.c.name == domain_name)
result = await conn.execute(delete_query)
if result.rowcount == 0:
result = await execute_batch_purger(
session,
BatchPurger(spec=DomainBatchPurgerSpec(domain_name=domain_name), batch_size=1),
)
if result.deleted_count == 0:
raise DomainDeletionFailed(f"Failed to delete domain: {domain_name}")

@domain_repository_resilience.apply()
Expand Down Expand Up @@ -241,42 +245,52 @@ async def _create_model_store_group(self, db_session: SASession, domain_name: st
})
await db_session.execute(model_store_insert_query)

async def _delete_kernels(self, conn: SAConnection, domain_name: str) -> int:
async def _delete_kernels(self, session: SASession, domain_name: str) -> int:
"""
Private method to delete all kernels for a domain.
"""
delete_query = sa.delete(kernels).where(kernels.c.domain_name == domain_name)
result = await conn.execute(delete_query)
return result.rowcount
result = await execute_batch_purger(
session,
BatchPurger(spec=DomainKernelBatchPurgerSpec(domain_name=domain_name)),
)
return result.deleted_count

async def _domain_has_active_kernels(self, conn: SAConnection, domain_name: str) -> bool:
async def _domain_has_active_kernels(self, session: SASession, domain_name: str) -> bool:
"""
Private method to check if domain has active kernels.
"""
query = (
sa.select(sa.func.count())
.select_from(kernels)
.select_from(KernelRow)
.where(
(kernels.c.domain_name == domain_name)
& (kernels.c.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES))
(KernelRow.domain_name == domain_name)
& (KernelRow.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES))
)
)
active_kernel_count = await conn.scalar(query)
active_kernel_count = await session.scalar(query)
return (active_kernel_count or 0) > 0

async def _get_domain_user_count(self, conn: SAConnection, domain_name: str) -> int:
async def _get_domain_user_count(self, session: SASession, domain_name: str) -> int:
"""
Private method to get user count for a domain.
"""
query = sa.select(sa.func.count()).where(users.c.domain_name == domain_name)
return await conn.scalar(query) or 0
query = (
sa.select(sa.func.count())
.select_from(UserRow)
.where(UserRow.domain_name == domain_name)
)
return await session.scalar(query) or 0

async def _get_domain_group_count(self, conn: SAConnection, domain_name: str) -> int:
async def _get_domain_group_count(self, session: SASession, domain_name: str) -> int:
"""
Private method to get group count for a domain.
"""
query = sa.select(sa.func.count()).where(groups.c.domain_name == domain_name)
return await conn.scalar(query) or 0
query = (
sa.select(sa.func.count())
.select_from(GroupRow)
.where(GroupRow.domain_name == domain_name)
)
return await session.scalar(query) or 0

@domain_repository_resilience.apply()
async def create_domain_node_with_permissions(
Expand Down
Loading
Loading