Skip to content

Commit f580cb3

Browse files
authored
Refactor axis handling (#129)
1 parent 18db39a commit f580cb3

11 files changed

Lines changed: 152 additions & 141 deletions

src/ptwt/_util.py

Lines changed: 73 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import warnings
88
from collections.abc import Callable, Sequence
99
from functools import partial
10-
from typing import Any, Literal, Optional, Union, cast, overload
10+
from typing import Any, Literal, Optional, TypeAlias, Union, cast, overload
1111

1212
import numpy as np
1313
import pywt
@@ -43,6 +43,9 @@
4343
"symmetric": "symmetric",
4444
}
4545

46+
#: A hint for axes
47+
AxisHint: TypeAlias = int | Sequence[int] | None
48+
4649

4750
def _translate_boundary_strings(
4851
pywt_mode: BoundaryMode | None,
@@ -452,7 +455,7 @@ def _coeff_tree_map(
452455
def _preprocess_coeffs(
453456
coeffs: list[torch.Tensor],
454457
ndim: Literal[1],
455-
axes: int,
458+
axes: AxisHint = ...,
456459
add_channel_dim: bool = False,
457460
) -> tuple[list[torch.Tensor], list[int]]: ...
458461

@@ -462,7 +465,7 @@ def _preprocess_coeffs(
462465
def _preprocess_coeffs(
463466
coeffs: WaveletCoeff2d,
464467
ndim: Literal[2],
465-
axes: tuple[int, int],
468+
axes: AxisHint = ...,
466469
add_channel_dim: bool = False,
467470
) -> tuple[WaveletCoeff2d, list[int]]: ...
468471

@@ -472,7 +475,7 @@ def _preprocess_coeffs(
472475
def _preprocess_coeffs(
473476
coeffs: WaveletCoeffNd,
474477
ndim: int,
475-
axes: tuple[int, ...],
478+
axes: AxisHint = ...,
476479
add_channel_dim: bool = False,
477480
) -> tuple[WaveletCoeffNd, list[int]]: ...
478481

@@ -482,7 +485,7 @@ def _preprocess_coeffs(
482485
def _preprocess_coeffs(
483486
coeffs: list[torch.Tensor],
484487
ndim: int,
485-
axes: Union[tuple[int, ...], int],
488+
axes: AxisHint = ...,
486489
add_channel_dim: bool = False,
487490
) -> tuple[list[torch.Tensor], list[int]]: ...
488491

@@ -494,7 +497,7 @@ def _preprocess_coeffs(
494497
WaveletCoeffNd,
495498
],
496499
ndim: int,
497-
axes: Union[tuple[int, ...], int],
500+
axes: AxisHint = None,
498501
add_channel_dim: bool = False,
499502
) -> tuple[
500503
Union[
@@ -509,7 +512,7 @@ def _preprocess_coeffs(
509512
For each coefficient tensor in `coeffs` the transformed axes
510513
as specified by `axes` are moved to be the last.
511514
Adds a batch dim if a coefficient tensor has none.
512-
If it has has multiple batch dimensions, they are folded into a single
515+
If it has multiple batch dimensions, they are folded into a single
513516
batch dimension.
514517
515518
Args:
@@ -520,7 +523,7 @@ def _preprocess_coeffs(
520523
:data:`ptwt.constants.WaveletCoeffNd` (Nd case).
521524
ndim (int): The number of axes :math:`N` on which the transformation
522525
was applied.
523-
axes (int or tuple of ints): Axes on which the transform was calculated.
526+
axes : Axes on which the transform was calculated.
524527
add_channel_dim (bool): If True, ensures that all returned coefficients
525528
have at least `:math:`N + 2` axes by potentially adding a new axis at dim 1.
526529
Defaults to False.
@@ -536,23 +539,18 @@ def _preprocess_coeffs(
536539
ValueError: If the input dtype is unsupported or `ndim` does not
537540
fit to the passed `axes` or `coeffs` dimensions.
538541
"""
539-
if isinstance(axes, int):
540-
axes = (axes,)
542+
if ndim <= 0:
543+
raise ValueError("Number of dimensions must be positive")
541544

542545
torch_dtype = _check_if_tensor(coeffs[0]).dtype
543546
if not _is_dtype_supported(torch_dtype):
544547
raise ValueError(f"Input dtype {torch_dtype} not supported")
545548

546-
if ndim <= 0:
547-
raise ValueError("Number of dimensions must be positive")
548-
549-
if tuple(axes) != tuple(range(-ndim, 0)):
550-
if len(axes) != ndim:
551-
raise ValueError(f"{ndim}D transforms work with {ndim} axes.")
552-
else:
553-
# for all tensors in `coeffs`: swap the axes
554-
swap_fn = partial(_swap_axes, axes=axes)
555-
coeffs = _coeff_tree_map(coeffs, swap_fn)
549+
axes = _ensure_axes(axes=axes, dim=ndim)
550+
if axes != _get_default_axes(ndim):
551+
# for all tensors in `coeffs`: swap the axes
552+
swap_fn = partial(_swap_axes, axes=axes)
553+
coeffs = _coeff_tree_map(coeffs, swap_fn)
556554

557555
# Fold axes for the wavelets
558556
ds = list(coeffs[0].shape)
@@ -578,7 +576,7 @@ def _postprocess_coeffs(
578576
coeffs: list[torch.Tensor],
579577
ndim: Literal[1],
580578
ds: list[int],
581-
axes: int,
579+
axes: AxisHint = ...,
582580
) -> list[torch.Tensor]: ...
583581

584582

@@ -588,7 +586,7 @@ def _postprocess_coeffs(
588586
coeffs: WaveletCoeff2d,
589587
ndim: Literal[2],
590588
ds: list[int],
591-
axes: tuple[int, int],
589+
axes: AxisHint = ...,
592590
) -> WaveletCoeff2d: ...
593591

594592

@@ -598,7 +596,7 @@ def _postprocess_coeffs(
598596
coeffs: WaveletCoeffNd,
599597
ndim: int,
600598
ds: list[int],
601-
axes: tuple[int, ...],
599+
axes: AxisHint = ...,
602600
) -> WaveletCoeffNd: ...
603601

604602

@@ -608,7 +606,7 @@ def _postprocess_coeffs(
608606
coeffs: list[torch.Tensor],
609607
ndim: int,
610608
ds: list[int],
611-
axes: Union[tuple[int, ...], int],
609+
axes: AxisHint = ...,
612610
) -> list[torch.Tensor]: ...
613611

614612

@@ -620,7 +618,7 @@ def _postprocess_coeffs(
620618
],
621619
ndim: int,
622620
ds: list[int],
623-
axes: Union[tuple[int, ...], int],
621+
axes: AxisHint = None,
624622
) -> Union[
625623
list[torch.Tensor],
626624
WaveletCoeff2d,
@@ -645,7 +643,7 @@ def _postprocess_coeffs(
645643
applied.
646644
ds (list of ints): The shape of the original first coefficient before
647645
preprocessing, i.e. of ``coeffs[0]``.
648-
axes (int or tuple of ints): Axes on which the transform was calculated.
646+
axes : Axes on which the transform was calculated.
649647
650648
Returns:
651649
The result of undoing the preprocessing operations on `coeffs`.
@@ -654,12 +652,11 @@ def _postprocess_coeffs(
654652
ValueError: If `ndim` does not fit to the passed `axes`
655653
or `coeffs` dimensions.
656654
"""
657-
if isinstance(axes, int):
658-
axes = (axes,)
659-
660655
if ndim <= 0:
661656
raise ValueError("Number of dimensions must be positive")
662657

658+
axes = _ensure_axes(axes=axes, dim=ndim)
659+
663660
# Fold axes for the wavelets
664661
if len(ds) < ndim:
665662
raise ValueError(f"At least {ndim} input dimensions required.")
@@ -671,21 +668,19 @@ def _postprocess_coeffs(
671668
unfold_axes_fn = partial(_unfold_axes, ds=ds, keep_no=ndim)
672669
coeffs = _coeff_tree_map(coeffs, unfold_axes_fn)
673670

674-
if tuple(axes) != tuple(range(-ndim, 0)):
675-
if len(axes) != ndim:
676-
raise ValueError(f"{ndim}D transforms work with {ndim} axes.")
677-
else:
678-
# for all tensors in `coeffs`: undo axes swapping
679-
undo_swap_fn = partial(_undo_swap_axes, axes=axes)
680-
coeffs = _coeff_tree_map(coeffs, undo_swap_fn)
671+
if axes != _get_default_axes(ndim):
672+
# for all tensors in `coeffs`: undo axes swapping
673+
undo_swap_fn = partial(_undo_swap_axes, axes=axes)
674+
coeffs = _coeff_tree_map(coeffs, undo_swap_fn)
681675

682676
return coeffs
683677

684678

685679
def _preprocess_tensor(
686680
data: torch.Tensor,
687681
ndim: int,
688-
axes: Union[tuple[int, ...], int],
682+
*,
683+
axes: AxisHint = None,
689684
add_channel_dim: bool = True,
690685
) -> tuple[torch.Tensor, list[int]]:
691686
"""Preprocess input tensor dimensions.
@@ -699,7 +694,7 @@ def _preprocess_tensor(
699694
data (torch.Tensor): An input tensor with at least `ndim` axes.
700695
ndim (int): The number of axes :math:`N` on which the transformation is
701696
applied.
702-
axes (int or tuple of ints): Axes on which the transform is calculated.
697+
axes : Axes on which the transform is calculated.
703698
add_channel_dim (bool): If True, ensures that the return has at
704699
least :math:`N + 2` axes by potentially adding a new axis at dim 1.
705700
Defaults to True.
@@ -720,7 +715,7 @@ def _preprocess_tensor(
720715

721716

722717
def _postprocess_tensor(
723-
data: torch.Tensor, ndim: int, ds: list[int], axes: Union[tuple[int, ...], int]
718+
data: torch.Tensor, ndim: int, ds: list[int], axes: AxisHint | None = None
724719
) -> torch.Tensor:
725720
"""Postprocess input tensor dimensions.
726721
@@ -737,7 +732,7 @@ def _postprocess_tensor(
737732
applied.
738733
ds (list of ints): The shape of the original input tensor before
739734
preprocessing.
740-
axes (int or tuple of ints): Axes on which the transform was calculated.
735+
axes : Axes on which the transform was calculated.
741736
742737
Returns:
743738
The result of undoing the preprocessing operations on `data`.
@@ -817,3 +812,41 @@ def _get_padding_n(
817812
for i in range(1, n + 1):
818813
rv.extend(_get_pad(data.shape[-i], wavelet_length))
819814
return tuple(rv)
815+
816+
817+
def _ensure_axes(*, axes: AxisHint = None, dim: int) -> tuple[int, ...]:
818+
if axes is None:
819+
return _get_default_axes(dim)
820+
if isinstance(axes, int):
821+
if dim != 1:
822+
raise ValueError(f"tried passing single axis to {dim}D transform")
823+
return (axes,)
824+
if len(axes) != dim:
825+
raise ValueError(f"tried passing {len(axes)}D axes {axes} to {dim}D transform")
826+
_check_axes_argument(axes)
827+
return tuple(axes)
828+
829+
830+
def _get_default_axes(n: int) -> tuple[int, ...]:
831+
"""Get the default axes for a transformation.
832+
833+
Args:
834+
n: The number of dimensions of the convolution
835+
836+
Returns:
837+
A sequence of the default axes
838+
839+
Raises:
840+
ValueError: If the dimension is not a natural number
841+
842+
Examples:
843+
>>> _get_default_axes(1)
844+
(-1,)
845+
>>> _get_default_axes(2)
846+
(-2, -1)
847+
>>> _get_default_axes(3)
848+
(-3, -2, -1)
849+
"""
850+
if n < 1:
851+
raise ValueError(f"only natural number dimensions are allowed. given: {n}")
852+
return tuple(range(-n, 0))

src/ptwt/conv_transform.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import torch
1313

1414
from ._util import (
15+
AxisHint,
1516
_adjust_padding_at_reconstruction,
1617
_check_same_device_dtype,
1718
_get_filter_tensors,
@@ -146,7 +147,7 @@ def wavedec(
146147

147148

148149
def waverec(
149-
coeffs: WaveletCoeff1d, wavelet: Union[Wavelet, str], axis: int = -1
150+
coeffs: WaveletCoeff1d, wavelet: Union[Wavelet, str], *, axis: AxisHint = None
150151
) -> torch.Tensor:
151152
"""Reconstruct a 1d signal from wavelet coefficients.
152153
@@ -157,8 +158,7 @@ def waverec(
157158
the name of a pywt wavelet.
158159
Refer to the output from ``pywt.wavelist(kind='discrete')``
159160
for possible choices.
160-
axis (int): Compute the transform over this axis of the `data` tensor.
161-
Defaults to -1.
161+
axis : Compute the transform over this axis. If none, the last is used.
162162
163163
Returns:
164164
The reconstructed signal tensor.

src/ptwt/conv_transform_2.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import torch
1313

1414
from ._util import (
15+
AxisHint,
1516
_adjust_padding_at_reconstruction,
1617
_check_same_device_dtype,
1718
_get_filter_tensors,
@@ -185,7 +186,8 @@ def wavedec2(
185186
def waverec2(
186187
coeffs: WaveletCoeff2d,
187188
wavelet: Union[Wavelet, str],
188-
axes: tuple[int, int] = (-2, -1),
189+
*,
190+
axes: AxisHint = None,
189191
) -> torch.Tensor:
190192
"""Reconstruct a 2d signal from wavelet coefficients.
191193
@@ -196,8 +198,7 @@ def waverec2(
196198
the name of a pywt wavelet.
197199
Refer to the output from ``pywt.wavelist(kind='discrete')``
198200
for possible choices.
199-
axes (tuple[int, int]): Compute the transform over these axes of the `data`
200-
tensor. Defaults to (-2, -1).
201+
axes : Compute the transform over these axes. If none, the last 2 are used.
201202
202203
Returns:
203204
The reconstructed signal tensor.

src/ptwt/conv_transform_3.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import torch
1212

1313
from ._util import (
14+
AxisHint,
1415
_adjust_padding_at_reconstruction,
1516
_as_wavelet,
1617
_check_same_device_dtype,
@@ -179,7 +180,8 @@ def wavedec3(
179180
def waverec3(
180181
coeffs: WaveletCoeffNd,
181182
wavelet: Union[Wavelet, str],
182-
axes: tuple[int, int, int] = (-3, -2, -1),
183+
*,
184+
axes: AxisHint = None,
183185
) -> torch.Tensor:
184186
"""Reconstruct a 3d signal from wavelet coefficients.
185187
@@ -190,8 +192,8 @@ def waverec3(
190192
the name of a pywt wavelet.
191193
Refer to the output from ``pywt.wavelist(kind='discrete')``
192194
for possible choices.
193-
axes (tuple[int, int, int]): Compute the transform over these axes of the `data`
194-
tensor. Defaults to (-3, -2, -1).
195+
axes : Compute the transform over these axes. If none, the last 3 are used.
196+
195197
196198
Returns:
197199
The reconstructed signal tensor.

0 commit comments

Comments
 (0)