Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 76 additions & 13 deletions sky/users/permission.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import os
import threading
import time
from typing import Generator, List, Optional, Set

import casbin
Expand All @@ -15,10 +16,10 @@
from sky import sky_logging
from sky.skylet import constants
from sky.users import rbac
from sky.utils import annotations
from sky.utils import common
from sky.utils import common_utils
from sky.utils.db import db_utils
from sky.utils.db import kv_cache

logging.getLogger('casbin.policy').setLevel(sky_logging.ERROR)
logging.getLogger('casbin.role').setLevel(sky_logging.ERROR)
Expand All @@ -32,6 +33,12 @@

_enforcer_instance: Optional['PermissionService'] = None

# KV cache constants for workspace permission checks.
# Cache keys are built as
#   <_WORKSPACE_PERM_CACHE_PREFIX><workspace_name><_WORKSPACE_PERM_CACHE_KEY_SEP><user_id>
# (see PermissionService._workspace_perm_cache_key); the stored value is
# '1' when access is allowed and '0' when it is denied.
_WORKSPACE_PERM_CACHE_PREFIX = 'perm:ws:'
_WORKSPACE_PERM_CACHE_KEY_SEP = ':'
# Long TTL as safety net; primary freshness is explicit invalidation on update.
# The TTL is added to time.time() to form the entry's absolute expiry.
_WORKSPACE_PERM_CACHE_TTL_SECONDS = 30 * 86400  # 30 days


class PermissionService:
"""Permission service for SkyPilot API Server."""
Expand Down Expand Up @@ -251,9 +258,11 @@ def delete_user(self, user_id: str) -> None:
return
enforcer.remove_grouping_policy(user_id, current_roles[0])
enforcer.save_policy()
self.invalidate_user_permission_cache(user_id)

def update_role(self, user_id: str, new_role: str) -> None:
"""Update user role relationship."""
role_changed = False
with _policy_lock():
self._load_policy_no_lock()
enforcer = self._ensure_enforcer()
Expand All @@ -267,10 +276,13 @@ def update_role(self, user_id: str, new_role: str) -> None:
logger.debug(f'User {user_id} already has role {new_role}')
return
enforcer.remove_grouping_policy(user_id, current_role)
role_changed = True

# Update user role
enforcer.add_grouping_policy(user_id, new_role)
enforcer.save_policy()
if role_changed:
self.invalidate_user_permission_cache(user_id)

def get_user_roles(self, user_id: str) -> List[str]:
"""Get all roles for a user.
Expand Down Expand Up @@ -345,15 +357,35 @@ def load_policy(self):
with _policy_lock():
self._load_policy_no_lock()

# Allow many cached (user, workspace) pairs so hot paths with many
# workspaces stay fast when batch get_accessible_workspace_names isn't used.
@annotations.lru_cache(scope='request', maxsize=256)
def _workspace_perm_cache_key(self, workspace_name: str,
                              user_id: str) -> str:
    """Return the KV cache key for one (workspace, user) permission entry.

    The key is laid out as ``<prefix><workspace_name><sep><user_id>`` so
    that entries can later be invalidated per workspace (by prefix) or
    per user (by suffix).

    Args:
        workspace_name: Name of the workspace the entry belongs to.
        user_id: ID of the user the entry belongs to.

    Returns:
        The cache key string.
    """
    workspace_scope = f'{_WORKSPACE_PERM_CACHE_PREFIX}{workspace_name}'
    return f'{workspace_scope}{_WORKSPACE_PERM_CACHE_KEY_SEP}{user_id}'

def invalidate_workspace_permission_cache(self,
                                          workspace_name: str) -> None:
    """Drop every cached permission entry recorded for a workspace.

    All keys for a workspace share the leading
    ``<prefix><workspace_name><sep>`` segment, so a single prefix delete
    removes the entries of every user for that workspace.

    Args:
        workspace_name: Name of the workspace whose entries are removed.
    """
    workspace_scope = f'{_WORKSPACE_PERM_CACHE_PREFIX}{workspace_name}'
    # The trailing separator keeps e.g. 'ws1' from also matching 'ws11'.
    kv_cache.delete_cache_entries_by_prefix(
        f'{workspace_scope}{_WORKSPACE_PERM_CACHE_KEY_SEP}')

def invalidate_user_permission_cache(self, user_id: str) -> None:
    """Drop every cached workspace-permission entry recorded for a user.

    Keys end with ``<sep><user_id>``, so the entries are selected by the
    shared permission prefix plus that user-specific suffix, whatever
    workspace name sits in between.

    Args:
        user_id: ID of the user whose entries are removed.
    """
    user_suffix = f'{_WORKSPACE_PERM_CACHE_KEY_SEP}{user_id}'
    kv_cache.delete_cache_entries_by_prefix_suffix(
        prefix=_WORKSPACE_PERM_CACHE_PREFIX, suffix=user_suffix)

def check_workspace_permission(self, user_id: str,
workspace_name: str) -> bool:
"""Check workspace permission.

This method checks if a user has permission to access a specific
workspace.
workspace. Results are cached in a DB-backed KV cache so that all
server/executor processes share the same view.

For private workspaces, the user must have explicit permission.

Expand All @@ -364,18 +396,38 @@ def check_workspace_permission(self, user_id: str,
# When it is not on API server, we allow all users to access all
# workspaces, as the workspace check has been done on API server.
return True

# Check DB-backed KV cache (covers both admin and non-admin results).
cache_key = self._workspace_perm_cache_key(workspace_name, user_id)
cached = kv_cache.get_cache_entry(cache_key)
if cached is not None:
return cached == '1'

# Cache miss — compute the permission.
# Admin users have access to all workspaces.
role = self.get_user_roles(user_id)
if rbac.RoleName.ADMIN.value in role:
return True
# The Casbin model matcher already handles the wildcard '*' case:
# m = (g(r.sub, p.sub)|| p.sub == '*') && r.obj == p.obj &&
# r.act == p.act
# This means if there's a policy ('*', workspace_name, '*'), it will
# match any user
enforcer = self._ensure_enforcer()
result = enforcer.enforce(user_id, workspace_name, '*')
result = True
else:
# The Casbin model matcher already handles the wildcard '*' case:
# m = (g(r.sub, p.sub)|| p.sub == '*') && r.obj == p.obj &&
# r.act == p.act
# This means if there's a policy ('*', workspace_name, '*'), it
# will match any user
enforcer = self._ensure_enforcer()
result = enforcer.enforce(user_id, workspace_name, '*')

logger.debug(f'Workspace permission check: user={user_id}, '
f'workspace={workspace_name}, result={result}')

# Cache the result; failures are non-critical.
try:
kv_cache.add_or_update_cache_entry(
cache_key, '1' if result else '0',
time.time() + _WORKSPACE_PERM_CACHE_TTL_SECONDS)
except Exception as e: # pylint: disable=broad-except
logger.debug(f'Failed to cache workspace permission: {e}')

return result

def check_service_account_token_permission(self, user_id: str,
Expand Down Expand Up @@ -424,6 +476,9 @@ def add_workspace_policy(self, workspace_name: str,
f'workspace={workspace_name}')
enforcer.add_policy(user, workspace_name, '*')
enforcer.save_policy()
# Invalidate stale cached denials (e.g. from checks between a
# workspace deletion and its re-creation with the same name).
self.invalidate_workspace_permission_cache(workspace_name)

def update_workspace_policy(self, workspace_name: str,
users: List[str]) -> None:
Expand All @@ -446,13 +501,21 @@ def update_workspace_policy(self, workspace_name: str,
f'workspace={workspace_name}')
enforcer.add_policy(user, workspace_name, '*')
enforcer.save_policy()
# Invalidate cached permission entries after the policy is
# persisted so other processes re-compute permissions on next
# check.
self.invalidate_workspace_permission_cache(workspace_name)

def remove_workspace_policy(self, workspace_name: str) -> None:
"""Remove workspace policy."""
with _policy_lock():
enforcer = self._ensure_enforcer()
enforcer.remove_filtered_policy(1, workspace_name)
enforcer.save_policy()
# Invalidate cached permission entries after the policy is
# persisted so other processes re-compute permissions on next
# check.
self.invalidate_workspace_permission_cache(workspace_name)


@contextlib.contextmanager
Expand Down
52 changes: 52 additions & 0 deletions sky/utils/db/kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,55 @@ def get_cache_entry(key: str) -> Optional[str]:
kv_cache_table.c.key == key).where(
kv_cache_table.c.expires_at > time.time()))
return result.scalar()


# Character passed as the ESCAPE argument of the LIKE queries below.
_LIKE_ESCAPE_CHAR = '\\'


def _escape_like(value: str) -> str:
    """Return *value* with SQL LIKE metacharacters escaped.

    The wildcards ``%`` and ``_`` as well as the escape character itself
    are each prefixed with the escape character, so the result matches
    *value* literally when used in a LIKE pattern with that escape.

    Args:
        value: Literal text to embed in a LIKE pattern.

    Returns:
        The escaped text.
    """
    needs_escape = (_LIKE_ESCAPE_CHAR, '%', '_')
    # Single pass: every special character gets exactly one escape prefix.
    return ''.join(f'{_LIKE_ESCAPE_CHAR}{ch}' if ch in needs_escape else ch
                   for ch in value)


@metrics_lib.time_me
def delete_cache_entries_by_prefix(prefix: str) -> None:
    """Remove every cache entry whose key begins with *prefix*.

    The prefix is matched literally: SQL LIKE wildcards (%, _) and the
    escape character occurring in *prefix* are escaped before the pattern
    is built.

    NOTE(review): matching is done with SQL LIKE, whose case sensitivity
    depends on the database backend -- confirm that any case-insensitive
    over-matching is acceptable for callers.

    Args:
        prefix: The literal prefix to match against cache keys.
    """
    like_pattern = f'{_escape_like(prefix)}%'
    delete_stmt = sqlalchemy.delete(kv_cache_table).where(
        kv_cache_table.c.key.like(like_pattern, escape=_LIKE_ESCAPE_CHAR))
    with orm.Session(_db_manager.get_engine()) as session:
        session.execute(delete_stmt)
        session.commit()


@metrics_lib.time_me
def delete_cache_entries_by_prefix_suffix(prefix: str, suffix: str) -> None:
    """Remove every cache entry whose key starts with *prefix* and ends
    with *suffix*, regardless of what lies between them.

    Both arguments are matched literally: SQL LIKE wildcards (%, _) and
    the escape character they contain are escaped before the pattern is
    built; only the gap between them is a wildcard.

    NOTE(review): matching is done with SQL LIKE, whose case sensitivity
    depends on the database backend -- confirm that any case-insensitive
    over-matching is acceptable for callers.

    Args:
        prefix: Literal prefix to match against cache keys.
        suffix: Literal suffix to match against cache keys.
    """
    like_pattern = '%'.join((_escape_like(prefix), _escape_like(suffix)))
    delete_stmt = sqlalchemy.delete(kv_cache_table).where(
        kv_cache_table.c.key.like(like_pattern, escape=_LIKE_ESCAPE_CHAR))
    with orm.Session(_db_manager.get_engine()) as session:
        session.execute(delete_stmt)
        session.commit()
109 changes: 109 additions & 0 deletions tests/unit_tests/test_sky/db/test_kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,112 @@ def test_get_cache_entry_expired(isolated_database):
kv_cache.add_or_update_cache_entry('test_key', 'test_value',
time.time() + 3600)
assert kv_cache.get_cache_entry('test_key') == 'test_value'


def test_delete_cache_entries_by_prefix(isolated_database):
    """Prefix deletion removes only the keys under the given prefix."""
    expiry = time.time() + 3600
    seeded = {
        'perm:ws:ws1:user1': '1',
        'perm:ws:ws1:user2': '0',
        'perm:ws:ws2:user1': '1',
        'perm:ws:ws11:user1': '1',
        'other:key': 'val',
    }
    for key, value in seeded.items():
        kv_cache.add_or_update_cache_entry(key, value, expiry)

    # Target workspace ws1 only (note the trailing separator).
    kv_cache.delete_cache_entries_by_prefix('perm:ws:ws1:')

    # Every ws1 entry must be gone ...
    for removed in ('perm:ws:ws1:user1', 'perm:ws:ws1:user2'):
        assert kv_cache.get_cache_entry(removed) is None
    # ... while ws2, the look-alike ws11 and unrelated keys survive.
    for kept in ('perm:ws:ws2:user1', 'perm:ws:ws11:user1', 'other:key'):
        assert kv_cache.get_cache_entry(kept) == seeded[kept]


def test_delete_cache_entries_by_prefix_no_matches(isolated_database):
    """A prefix that matches nothing neither fails nor deletes anything."""
    expiry = time.time() + 3600
    kv_cache.add_or_update_cache_entry('key1', 'value1', expiry)

    kv_cache.delete_cache_entries_by_prefix('nonexistent_prefix')

    survivor = kv_cache.get_cache_entry('key1')
    assert survivor == 'value1'


def test_delete_cache_entries_by_prefix_suffix(isolated_database):
    """Prefix+suffix deletion removes one user's entries in all workspaces."""
    expiry = time.time() + 3600
    seeded = {
        'perm:ws:ws1:user1': '1',
        'perm:ws:ws2:user1': '1',
        'perm:ws:ws1:user2': '0',
        'other:key': 'val',
    }
    for key, value in seeded.items():
        kv_cache.add_or_update_cache_entry(key, value, expiry)

    # Target user1 across every workspace.
    kv_cache.delete_cache_entries_by_prefix_suffix('perm:ws:', ':user1')

    # user1 entries must be gone ...
    for removed in ('perm:ws:ws1:user1', 'perm:ws:ws2:user1'):
        assert kv_cache.get_cache_entry(removed) is None
    # ... while user2 and unrelated keys survive.
    for kept in ('perm:ws:ws1:user2', 'other:key'):
        assert kv_cache.get_cache_entry(kept) == seeded[kept]


def test_delete_cache_entries_by_prefix_suffix_no_matches(isolated_database):
    """A prefix/suffix pair matching nothing is a harmless no-op."""
    expiry = time.time() + 3600
    kv_cache.add_or_update_cache_entry('key1', 'value1', expiry)

    kv_cache.delete_cache_entries_by_prefix_suffix('perm:ws:', ':nonexistent')

    survivor = kv_cache.get_cache_entry('key1')
    assert survivor == 'value1'


def test_delete_by_prefix_escapes_percent_in_data(isolated_database):
    """A '%' inside the prefix must match literally, not as a wildcard."""
    expiry = time.time() + 3600
    seeded = {
        '50%off:user1': '1',
        '50Xoff:user1': '1',
        'other:key': 'val',
    }
    for key, value in seeded.items():
        kv_cache.add_or_update_cache_entry(key, value, expiry)

    # Only keys starting with the literal text '50%off:' are targeted.
    kv_cache.delete_cache_entries_by_prefix('50%off:')

    assert kv_cache.get_cache_entry('50%off:user1') is None
    # '50Xoff:user1' would vanish if '%' were treated as a wildcard.
    for kept in ('50Xoff:user1', 'other:key'):
        assert kv_cache.get_cache_entry(kept) == seeded[kept]


def test_delete_by_prefix_escapes_underscore_in_data(isolated_database):
    """A '_' inside the prefix must match literally, not as a wildcard."""
    expiry = time.time() + 3600
    for key in ('my_ws:user1', 'myXws:user1'):
        kv_cache.add_or_update_cache_entry(key, '1', expiry)

    kv_cache.delete_cache_entries_by_prefix('my_ws:')

    literal_match = kv_cache.get_cache_entry('my_ws:user1')
    lookalike = kv_cache.get_cache_entry('myXws:user1')
    assert literal_match is None
    # 'myXws:user1' would vanish if '_' were allowed to match 'X'.
    assert lookalike == '1'


def test_delete_by_prefix_escapes_backslash_in_data(isolated_database):
    r"""A '\' inside the prefix must match literally, not act as escape.

    If the backslash were left unescaped, LIKE with ESCAPE '\' would read
    '\d' as a literal 'd', and the pattern 'team\dev:%' would hit
    'teamdev:*' rather than 'team\dev:*'.
    """
    expiry = time.time() + 3600
    with_backslash = 'team\\dev:user1'
    without_backslash = 'teamdev:user1'
    for key in (with_backslash, without_backslash):
        kv_cache.add_or_update_cache_entry(key, '1', expiry)

    kv_cache.delete_cache_entries_by_prefix('team\\dev:')

    assert kv_cache.get_cache_entry(with_backslash) is None
    # The backslash-free key survives only if the backslash was not
    # swallowed as a LIKE escape character.
    assert kv_cache.get_cache_entry(without_backslash) == '1'


def test_escape_like_helper():
    """_escape_like escapes %, _ and backslash and leaves the rest alone."""
    expectations = [
        ('normal', 'normal'),
        ('50%off', '50\\%off'),
        ('my_ws', 'my\\_ws'),
        ('a\\b', 'a\\\\b'),
        ('50%_x\\y', '50\\%\\_x\\\\y'),
    ]
    for raw, escaped in expectations:
        assert kv_cache._escape_like(raw) == escaped
Loading
Loading