77import warnings
88from collections .abc import Callable , Sequence
99from functools import partial
10- from typing import Any , Literal , Optional , Union , cast , overload
10+ from typing import Any , Literal , Optional , TypeAlias , Union , cast , overload
1111
1212import numpy as np
1313import pywt
4343 "symmetric" : "symmetric" ,
4444}
4545
46+ #: A hint for axes
47+ AxisHint : TypeAlias = int | Sequence [int ] | None
48+
4649
4750def _translate_boundary_strings (
4851 pywt_mode : BoundaryMode | None ,
@@ -452,7 +455,7 @@ def _coeff_tree_map(
452455def _preprocess_coeffs (
453456 coeffs : list [torch .Tensor ],
454457 ndim : Literal [1 ],
455- axes : int ,
458+ axes : AxisHint = ... ,
456459 add_channel_dim : bool = False ,
457460) -> tuple [list [torch .Tensor ], list [int ]]: ...
458461
@@ -462,7 +465,7 @@ def _preprocess_coeffs(
462465def _preprocess_coeffs (
463466 coeffs : WaveletCoeff2d ,
464467 ndim : Literal [2 ],
465- axes : tuple [ int , int ] ,
468+ axes : AxisHint = ... ,
466469 add_channel_dim : bool = False ,
467470) -> tuple [WaveletCoeff2d , list [int ]]: ...
468471
@@ -472,7 +475,7 @@ def _preprocess_coeffs(
472475def _preprocess_coeffs (
473476 coeffs : WaveletCoeffNd ,
474477 ndim : int ,
475- axes : tuple [ int , ...] ,
478+ axes : AxisHint = ...,
476479 add_channel_dim : bool = False ,
477480) -> tuple [WaveletCoeffNd , list [int ]]: ...
478481
@@ -482,7 +485,7 @@ def _preprocess_coeffs(
482485def _preprocess_coeffs (
483486 coeffs : list [torch .Tensor ],
484487 ndim : int ,
485- axes : Union [ tuple [ int , ...], int ] ,
488+ axes : AxisHint = ...,
486489 add_channel_dim : bool = False ,
487490) -> tuple [list [torch .Tensor ], list [int ]]: ...
488491
@@ -494,7 +497,7 @@ def _preprocess_coeffs(
494497 WaveletCoeffNd ,
495498 ],
496499 ndim : int ,
497- axes : Union [ tuple [ int , ...], int ] ,
500+ axes : AxisHint = None ,
498501 add_channel_dim : bool = False ,
499502) -> tuple [
500503 Union [
@@ -509,7 +512,7 @@ def _preprocess_coeffs(
509512 For each coefficient tensor in `coeffs` the transformed axes
510513 as specified by `axes` are moved to be the last.
511514 Adds a batch dim if a coefficient tensor has none.
512- If it has has multiple batch dimensions, they are folded into a single
515+ If it has multiple batch dimensions, they are folded into a single
513516 batch dimension.
514517
515518 Args:
@@ -520,7 +523,7 @@ def _preprocess_coeffs(
520523 :data:`ptwt.constants.WaveletCoeffNd` (Nd case).
521524 ndim (int): The number of axes :math:`N` on which the transformation
522525 was applied.
523- axes (int or tuple of ints) : Axes on which the transform was calculated.
526+ axes : Axes on which the transform was calculated.
524527 add_channel_dim (bool): If True, ensures that all returned coefficients
525528 have at least `:math:`N + 2` axes by potentially adding a new axis at dim 1.
526529 Defaults to False.
@@ -536,23 +539,18 @@ def _preprocess_coeffs(
536539 ValueError: If the input dtype is unsupported or `ndim` does not
537540 fit to the passed `axes` or `coeffs` dimensions.
538541 """
539- if isinstance ( axes , int ) :
540- axes = ( axes , )
542+ if ndim <= 0 :
543+ raise ValueError ( "Number of dimensions must be positive" )
541544
542545 torch_dtype = _check_if_tensor (coeffs [0 ]).dtype
543546 if not _is_dtype_supported (torch_dtype ):
544547 raise ValueError (f"Input dtype { torch_dtype } not supported" )
545548
546- if ndim <= 0 :
547- raise ValueError ("Number of dimensions must be positive" )
548-
549- if tuple (axes ) != tuple (range (- ndim , 0 )):
550- if len (axes ) != ndim :
551- raise ValueError (f"{ ndim } D transforms work with { ndim } axes." )
552- else :
553- # for all tensors in `coeffs`: swap the axes
554- swap_fn = partial (_swap_axes , axes = axes )
555- coeffs = _coeff_tree_map (coeffs , swap_fn )
549+ axes = _ensure_axes (axes = axes , dim = ndim )
550+ if axes != _get_default_axes (ndim ):
551+ # for all tensors in `coeffs`: swap the axes
552+ swap_fn = partial (_swap_axes , axes = axes )
553+ coeffs = _coeff_tree_map (coeffs , swap_fn )
556554
557555 # Fold axes for the wavelets
558556 ds = list (coeffs [0 ].shape )
@@ -578,7 +576,7 @@ def _postprocess_coeffs(
578576 coeffs : list [torch .Tensor ],
579577 ndim : Literal [1 ],
580578 ds : list [int ],
581- axes : int ,
579+ axes : AxisHint = ... ,
582580) -> list [torch .Tensor ]: ...
583581
584582
@@ -588,7 +586,7 @@ def _postprocess_coeffs(
588586 coeffs : WaveletCoeff2d ,
589587 ndim : Literal [2 ],
590588 ds : list [int ],
591- axes : tuple [ int , int ] ,
589+ axes : AxisHint = ... ,
592590) -> WaveletCoeff2d : ...
593591
594592
@@ -598,7 +596,7 @@ def _postprocess_coeffs(
598596 coeffs : WaveletCoeffNd ,
599597 ndim : int ,
600598 ds : list [int ],
601- axes : tuple [ int , ...] ,
599+ axes : AxisHint = ...,
602600) -> WaveletCoeffNd : ...
603601
604602
@@ -608,7 +606,7 @@ def _postprocess_coeffs(
608606 coeffs : list [torch .Tensor ],
609607 ndim : int ,
610608 ds : list [int ],
611- axes : Union [ tuple [ int , ...], int ] ,
609+ axes : AxisHint = ...,
612610) -> list [torch .Tensor ]: ...
613611
614612
@@ -620,7 +618,7 @@ def _postprocess_coeffs(
620618 ],
621619 ndim : int ,
622620 ds : list [int ],
623- axes : Union [ tuple [ int , ...], int ] ,
621+ axes : AxisHint = None ,
624622) -> Union [
625623 list [torch .Tensor ],
626624 WaveletCoeff2d ,
@@ -645,7 +643,7 @@ def _postprocess_coeffs(
645643 applied.
646644 ds (list of ints): The shape of the original first coefficient before
647645 preprocessing, i.e. of ``coeffs[0]``.
648- axes (int or tuple of ints) : Axes on which the transform was calculated.
646+ axes : Axes on which the transform was calculated.
649647
650648 Returns:
651649 The result of undoing the preprocessing operations on `coeffs`.
@@ -654,12 +652,11 @@ def _postprocess_coeffs(
654652 ValueError: If `ndim` does not fit to the passed `axes`
655653 or `coeffs` dimensions.
656654 """
657- if isinstance (axes , int ):
658- axes = (axes ,)
659-
660655 if ndim <= 0 :
661656 raise ValueError ("Number of dimensions must be positive" )
662657
658+ axes = _ensure_axes (axes = axes , dim = ndim )
659+
663660 # Fold axes for the wavelets
664661 if len (ds ) < ndim :
665662 raise ValueError (f"At least { ndim } input dimensions required." )
@@ -671,21 +668,19 @@ def _postprocess_coeffs(
671668 unfold_axes_fn = partial (_unfold_axes , ds = ds , keep_no = ndim )
672669 coeffs = _coeff_tree_map (coeffs , unfold_axes_fn )
673670
674- if tuple (axes ) != tuple (range (- ndim , 0 )):
675- if len (axes ) != ndim :
676- raise ValueError (f"{ ndim } D transforms work with { ndim } axes." )
677- else :
678- # for all tensors in `coeffs`: undo axes swapping
679- undo_swap_fn = partial (_undo_swap_axes , axes = axes )
680- coeffs = _coeff_tree_map (coeffs , undo_swap_fn )
671+ if axes != _get_default_axes (ndim ):
672+ # for all tensors in `coeffs`: undo axes swapping
673+ undo_swap_fn = partial (_undo_swap_axes , axes = axes )
674+ coeffs = _coeff_tree_map (coeffs , undo_swap_fn )
681675
682676 return coeffs
683677
684678
685679def _preprocess_tensor (
686680 data : torch .Tensor ,
687681 ndim : int ,
688- axes : Union [tuple [int , ...], int ],
682+ * ,
683+ axes : AxisHint = None ,
689684 add_channel_dim : bool = True ,
690685) -> tuple [torch .Tensor , list [int ]]:
691686 """Preprocess input tensor dimensions.
@@ -699,7 +694,7 @@ def _preprocess_tensor(
699694 data (torch.Tensor): An input tensor with at least `ndim` axes.
700695 ndim (int): The number of axes :math:`N` on which the transformation is
701696 applied.
702- axes (int or tuple of ints) : Axes on which the transform is calculated.
697+ axes : Axes on which the transform is calculated.
703698 add_channel_dim (bool): If True, ensures that the return has at
704699 least :math:`N + 2` axes by potentially adding a new axis at dim 1.
705700 Defaults to True.
@@ -720,7 +715,7 @@ def _preprocess_tensor(
720715
721716
722717def _postprocess_tensor (
723- data : torch .Tensor , ndim : int , ds : list [int ], axes : Union [ tuple [ int , ...], int ]
718+ data : torch .Tensor , ndim : int , ds : list [int ], axes : AxisHint | None = None
724719) -> torch .Tensor :
725720 """Postprocess input tensor dimensions.
726721
@@ -737,7 +732,7 @@ def _postprocess_tensor(
737732 applied.
738733 ds (list of ints): The shape of the original input tensor before
739734 preprocessing.
740- axes (int or tuple of ints) : Axes on which the transform was calculated.
735+ axes : Axes on which the transform was calculated.
741736
742737 Returns:
743738 The result of undoing the preprocessing operations on `data`.
@@ -817,3 +812,41 @@ def _get_padding_n(
817812 for i in range (1 , n + 1 ):
818813 rv .extend (_get_pad (data .shape [- i ], wavelet_length ))
819814 return tuple (rv )
815+
816+
817+ def _ensure_axes (* , axes : AxisHint = None , dim : int ) -> tuple [int , ...]:
818+ if axes is None :
819+ return _get_default_axes (dim )
820+ if isinstance (axes , int ):
821+ if dim != 1 :
822+ raise ValueError (f"tried passing single axis to { dim } D transform" )
823+ return (axes ,)
824+ if len (axes ) != dim :
825+ raise ValueError (f"tried passing { len (axes )} D axes { axes } to { dim } D transform" )
826+ _check_axes_argument (axes )
827+ return tuple (axes )
828+
829+
830+ def _get_default_axes (n : int ) -> tuple [int , ...]:
831+ """Get the default axes for a transformation.
832+
833+ Args:
834+ n: The number of dimensions of the convolution
835+
836+ Returns:
837+ A sequence of the default axes
838+
839+ Raises:
840+ ValueError: If the dimension is not a natural number
841+
842+ Examples:
843+ >>> _get_default_axes(1)
844+ (-1,)
845+ >>> _get_default_axes(2)
846+ (-2, -1)
847+ >>> _get_default_axes(3)
848+ (-3, -2, -1)
849+ """
850+ if n < 1 :
851+ raise ValueError (f"only natural number dimensions are allowed. given: { n } " )
852+ return tuple (range (- n , 0 ))
0 commit comments