Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 88 additions & 13 deletions sqladmin/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
List,
Optional,
Sequence,
Set,
Tuple,
Type,
Union,
Expand Down Expand Up @@ -1141,6 +1142,83 @@ class UserAdmin(ModelView, model=User):
]
return ", ".join(field_names)

def _get_joined_entities(self, stmt: Select) -> Set[Any]:
"""Extract all joined entity classes from a Select statement.

Args:
stmt: The SQLAlchemy select statement to inspect

Returns:
Set of joined entity classes
"""
joined_entities: Set[Any] = set()

froms = stmt.get_final_froms()
if not froms:
return joined_entities

for from_clause in froms:
# Check if this is a Join object (has left and right attributes)
if not (hasattr(from_clause, "left") and hasattr(from_clause, "right")):
continue

# This is a Join object - recursively extract entities
left = from_clause.left # type: ignore[attr-defined]
while hasattr(left, "left") and hasattr(left, "right"):
# Left side is also a Join
right_entity = getattr(left.right, "entity_namespace", None) # type: ignore[attr-defined]
if right_entity and isinstance(right_entity, type):
joined_entities.add(right_entity)
left = left.left # type: ignore[attr-defined]

# Add the right side of the current join
right_entity = getattr(from_clause.right, "entity_namespace", None) # type: ignore[attr-defined]
if right_entity and isinstance(right_entity, type):
joined_entities.add(right_entity)

return joined_entities

def _join_relationship_paths(
self,
stmt: Select,
field_path: str,
joined_paths: Set[str],
) -> Tuple[Select, Any]:
"""Join relationship paths and return the statement and target model.

This helper method navigates through relationship paths (e.g., 'user.profile')
and joins each relationship only once, tracking which paths have been joined
to avoid duplicate JOINs.

Args:
stmt: The SQLAlchemy select statement to modify
field_path: The field path (e.g., 'user.profile.role')
joined_paths: Set tracking which relationship paths have been joined

Returns:
Tuple of (modified statement, target model class)
"""
model = self.model
parts = field_path.split(".")

# Get already joined entities from the statement
joined_entities = self._get_joined_entities(stmt)

current_path = ""
for part in parts[:-1]:
current_path = f"{current_path}.{part}" if current_path else part
next_model = getattr(model, part).mapper.class_

# Check if this path is already tracked OR if the entity is already joined
if current_path not in joined_paths and next_model not in joined_entities:
stmt = stmt.join(next_model)
joined_paths.add(current_path)
joined_entities.add(next_model)

model = next_model

return stmt, model

def search_query(self, stmt: Select, term: str) -> Select:
"""Specify the search query given the SQLAlchemy statement
and term to search for.
Expand All @@ -1152,15 +1230,13 @@ def search_query(self, stmt: Select, term: str) -> Select:
"""

expressions = []
joined_paths: Set[str] = set()

for field in self._search_fields:
model = self.model
stmt, model = self._join_relationship_paths(stmt, field, joined_paths)
parts = field.split(".")
for part in parts[:-1]:
model = getattr(model, part).mapper.class_
stmt = stmt.join(model)

field = getattr(model, parts[-1])
expressions.append(cast(field, String).ilike(f"%{term}%"))
field_attr = getattr(model, parts[-1])
expressions.append(cast(field_attr, String).ilike(f"%{term}%"))

return stmt.filter(or_(*expressions))

Expand Down Expand Up @@ -1224,13 +1300,12 @@ def sort_query(self, stmt: Select, request: Request) -> Select:
else:
sort_fields = self._get_default_sort()

for sort_field, is_desc in sort_fields:
model = self.model
joined_paths: Set[str] = set()

parts = self._get_prop_name(sort_field).split(".")
for part in parts[:-1]:
model = getattr(model, part).mapper.class_
stmt = stmt.join(model)
for sort_field, is_desc in sort_fields:
field_path = self._get_prop_name(sort_field)
stmt, model = self._join_relationship_paths(stmt, field_path, joined_paths)
parts = field_path.split(".")

if is_desc:
stmt = stmt.order_by(desc(getattr(model, parts[-1])))
Expand Down
30 changes: 30 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,36 @@ class AddressAdmin(ModelView, model=Address):
assert "lower(CAST(profiles.role AS VARCHAR))" in str(stmt)


def test_sort_multi_fields() -> None:
class AddressAdmin(ModelView, model=Address):
column_sortable_list = [Address.id, User.id, User.name]

query = select(Address)
request = Request({"type": "http", "query_string": b"sortBy=user.id&sort=asc"})
stmt = AddressAdmin().sort_query(query, request)

stmt_str = str(stmt)
assert "ORDER BY users.id ASC" in stmt_str
assert stmt_str.count("JOIN") == 1


def test_sort_then_search_no_duplicate_joins() -> None:
class AddressAdmin(ModelView, model=Address):
column_searchable_list = ["user.name"]
column_sortable_list = [User.id]

query = select(Address)
request = Request({"type": "http", "query_string": b"sortBy=user.id&sort=asc"})

stmt = AddressAdmin().sort_query(query, request)
stmt_after_sort = str(stmt)
assert stmt_after_sort.count("JOIN") == 1

stmt = AddressAdmin().search_query(stmt, "test")
stmt_after_search = str(stmt)
assert stmt_after_search.count("JOIN") == 1


def test_expose_decorator(client: TestClient) -> None:
class UserAdmin(ModelView, model=User):
@expose("/profile/{pk}")
Expand Down
4 changes: 2 additions & 2 deletions tests/test_views/test_view_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,8 @@ class UserAdmin(ModelView, model=User):
User.status,
]
column_labels = {User.email: "Email"}
column_searchable_list = [User.name]
column_sortable_list = [User.id]
column_searchable_list = [User.name, User.status]
column_sortable_list = [User.id, User.name]
column_export_list = [User.name, User.status]
column_formatters = {
User.addresses_formattable: lambda m, a: [
Expand Down
57 changes: 56 additions & 1 deletion tests/test_views/test_view_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class UserAdmin(ModelView, model=User):
User.status,
]
column_labels = {User.email: "Email"}
column_searchable_list = [User.name]
column_searchable_list = [User.name, User.id]
column_sortable_list = [User.id]
column_export_list = [User.name, User.status]
column_formatters = {
Expand Down Expand Up @@ -854,3 +854,58 @@ def test_export_bad_type_is_404(client: TestClient) -> None:
def test_export_permission(client: TestClient) -> None:
response = client.get("/admin/movie/export/csv")
assert response.status_code == 403


def test_search_multi_fields_no_duplicate_joins(client: TestClient) -> None:
class AddressAdmin(ModelView, model=Address):
column_searchable_list = [User.id, User.name]

admin.add_view(AddressAdmin)

with session_maker() as session:
user = User(name="Alice")
address = Address(user=user)
session.add_all([user, address])
session.commit()

response = client.get("/admin/address/list?search=Alice")
assert response.status_code == 200


def test_sort_multi_fields_no_duplicate_joins(client: TestClient) -> None:
class AddressAdmin(ModelView, model=Address):
column_sortable_list = [Address.id, User.id, User.name]

admin.add_view(AddressAdmin)

with session_maker() as session:
user1 = User(name="Bob")
user2 = User(name="Alice")
address1 = Address(user=user1)
address2 = Address(user=user2)
session.add_all([user1, user2, address1, address2])
session.commit()

response = client.get("/admin/address/list?sortBy=user.name&sort=asc")
assert response.status_code == 200


def test_sort_and_search_together_no_duplicate_joins(client: TestClient) -> None:
class AddressAdmin(ModelView, model=Address):
column_searchable_list = [User.name, User.id]
column_sortable_list = [Address.id, User.id, User.name]

admin.add_view(AddressAdmin)

with session_maker() as session:
user1 = User(name="Alice")
user2 = User(name="Bob")
user3 = User(name="Charlie")
address1 = Address(user=user1)
address2 = Address(user=user2)
address3 = Address(user=user3)
session.add_all([user1, user2, user3, address1, address2, address3])
session.commit()

response = client.get("/admin/address/list?sortBy=user.name&sort=asc&search=o")
assert response.status_code == 200