@@ -850,3 +850,87 @@ def _get_default_axes(n: int) -> tuple[int, ...]:
850850 if n < 1 :
851851 raise ValueError (f"only natural number dimensions are allowed. given: { n } " )
852852 return tuple (range (- n , 0 ))
853+
854+
855+ def _preprocess_deconstruction (
856+ data : torch .Tensor ,
857+ wavelet : Union [Wavelet , str ],
858+ * ,
859+ ndim : int ,
860+ axes : AxisHint = None ,
861+ ) -> tuple [torch .Tensor , list [int ], torch .Tensor , torch .Tensor , torch .Tensor ]:
862+ data , ds = _preprocess_tensor (data , ndim = ndim , axes = axes )
863+ dec_lo , dec_hi , _ , _ = _get_filter_tensors (
864+ wavelet , flip = True , device = data .device , dtype = data .dtype
865+ )
866+ dec_filt = _construct_nd_filt (dec_lo , dec_hi , n = ndim )
867+ return data , ds , dec_lo , dec_hi , dec_filt
868+
869+
870+ def _construct_nd_filt (lo : torch .Tensor , hi : torch .Tensor , n : int ) -> torch .Tensor :
871+ if n == 1 :
872+ return _construct_1d_filt (lo , hi )
873+ elif n == 2 :
874+ return _construct_2d_filt (lo , hi )
875+ elif n == 3 :
876+ return _construct_3d_filt (lo , hi )
877+ else :
878+ raise NotImplementedError ()
879+
880+
881+ def _construct_1d_filt (lo : torch .Tensor , hi : torch .Tensor ) -> torch .Tensor :
882+ """Construct one-dimensional filters."""
883+ return torch .stack ([lo , hi ], 0 )
884+
885+
886+ def _construct_2d_filt (lo : torch .Tensor , hi : torch .Tensor ) -> torch .Tensor :
887+ """Construct two-dimensional filters using outer products.
888+
889+ Args:
890+ lo (torch.Tensor): Low-pass input filter.
891+ hi (torch.Tensor): High-pass input filter
892+
893+ Returns:
894+ Stacked 2d-filters of dimension
895+
896+ [2^2, 1, height, width].
897+
898+ The four filters are ordered ll, lh, hl, hh.
899+
900+ """
901+ ll = _outer (lo , lo )
902+ lh = _outer (hi , lo )
903+ hl = _outer (lo , hi )
904+ hh = _outer (hi , hi )
905+ filt = torch .stack ([ll , lh , hl , hh ], 0 )
906+ filt = filt .unsqueeze (1 )
907+ return filt
908+
909+
910+ def _construct_3d_filt (lo : torch .Tensor , hi : torch .Tensor ) -> torch .Tensor :
911+ """Construct three-dimensional filters using outer products.
912+
913+ Args:
914+ lo (torch.Tensor): Low-pass input filter.
915+ hi (torch.Tensor): High-pass input filter
916+
917+ Returns:
918+ Stacked 3d filters of dimension::
919+
920+ [2^3, 1, length, height, width].
921+
922+ The four filters are ordered ll, lh, hl, hh.
923+ """
924+ dim_size = lo .shape [- 1 ]
925+ size = [dim_size ] * 3
926+ lll = _outer (lo , _outer (lo , lo )).reshape (size )
927+ llh = _outer (lo , _outer (lo , hi )).reshape (size )
928+ lhl = _outer (lo , _outer (hi , lo )).reshape (size )
929+ lhh = _outer (lo , _outer (hi , hi )).reshape (size )
930+ hll = _outer (hi , _outer (lo , lo )).reshape (size )
931+ hlh = _outer (hi , _outer (lo , hi )).reshape (size )
932+ hhl = _outer (hi , _outer (hi , lo )).reshape (size )
933+ hhh = _outer (hi , _outer (hi , hi )).reshape (size )
934+ filt = torch .stack ([lll , llh , lhl , lhh , hll , hlh , hhl , hhh ], 0 )
935+ filt = filt .unsqueeze (1 )
936+ return filt
0 commit comments