Skip to content

Commit 6c3b62c

Browse files
authored
Refactor convolutional transforms (#130)
1 parent f580cb3 commit 6c3b62c

5 files changed

Lines changed: 100 additions & 76 deletions

File tree

src/ptwt/_util.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -850,3 +850,87 @@ def _get_default_axes(n: int) -> tuple[int, ...]:
850850
if n < 1:
851851
raise ValueError(f"only natural number dimensions are allowed. given: {n}")
852852
return tuple(range(-n, 0))
853+
854+
855+
def _preprocess_deconstruction(
856+
data: torch.Tensor,
857+
wavelet: Union[Wavelet, str],
858+
*,
859+
ndim: int,
860+
axes: AxisHint = None,
861+
) -> tuple[torch.Tensor, list[int], torch.Tensor, torch.Tensor, torch.Tensor]:
862+
data, ds = _preprocess_tensor(data, ndim=ndim, axes=axes)
863+
dec_lo, dec_hi, _, _ = _get_filter_tensors(
864+
wavelet, flip=True, device=data.device, dtype=data.dtype
865+
)
866+
dec_filt = _construct_nd_filt(dec_lo, dec_hi, n=ndim)
867+
return data, ds, dec_lo, dec_hi, dec_filt
868+
869+
870+
def _construct_nd_filt(lo: torch.Tensor, hi: torch.Tensor, n: int) -> torch.Tensor:
871+
if n == 1:
872+
return _construct_1d_filt(lo, hi)
873+
elif n == 2:
874+
return _construct_2d_filt(lo, hi)
875+
elif n == 3:
876+
return _construct_3d_filt(lo, hi)
877+
else:
878+
raise NotImplementedError()
879+
880+
881+
def _construct_1d_filt(lo: torch.Tensor, hi: torch.Tensor) -> torch.Tensor:
882+
"""Construct one-dimensional filters."""
883+
return torch.stack([lo, hi], 0)
884+
885+
886+
def _construct_2d_filt(lo: torch.Tensor, hi: torch.Tensor) -> torch.Tensor:
887+
"""Construct two-dimensional filters using outer products.
888+
889+
Args:
890+
lo (torch.Tensor): Low-pass input filter.
891+
hi (torch.Tensor): High-pass input filter
892+
893+
Returns:
894+
Stacked 2d-filters of dimension
895+
896+
[2^2, 1, height, width].
897+
898+
The four filters are ordered ll, lh, hl, hh.
899+
900+
"""
901+
ll = _outer(lo, lo)
902+
lh = _outer(hi, lo)
903+
hl = _outer(lo, hi)
904+
hh = _outer(hi, hi)
905+
filt = torch.stack([ll, lh, hl, hh], 0)
906+
filt = filt.unsqueeze(1)
907+
return filt
908+
909+
910+
def _construct_3d_filt(lo: torch.Tensor, hi: torch.Tensor) -> torch.Tensor:
911+
"""Construct three-dimensional filters using outer products.
912+
913+
Args:
914+
lo (torch.Tensor): Low-pass input filter.
915+
hi (torch.Tensor): High-pass input filter
916+
917+
Returns:
918+
Stacked 3d filters of dimension::
919+
920+
[2^3, 1, length, height, width].
921+
922+
The four filters are ordered ll, lh, hl, hh.
923+
"""
924+
dim_size = lo.shape[-1]
925+
size = [dim_size] * 3
926+
lll = _outer(lo, _outer(lo, lo)).reshape(size)
927+
llh = _outer(lo, _outer(lo, hi)).reshape(size)
928+
lhl = _outer(lo, _outer(hi, lo)).reshape(size)
929+
lhh = _outer(lo, _outer(hi, hi)).reshape(size)
930+
hll = _outer(hi, _outer(lo, lo)).reshape(size)
931+
hlh = _outer(hi, _outer(lo, hi)).reshape(size)
932+
hhl = _outer(hi, _outer(hi, lo)).reshape(size)
933+
hhh = _outer(hi, _outer(hi, hi)).reshape(size)
934+
filt = torch.stack([lll, llh, lhl, lhh, hll, hlh, hhl, hhh], 0)
935+
filt = filt.unsqueeze(1)
936+
return filt

src/ptwt/conv_transform.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
_postprocess_coeffs,
2323
_postprocess_tensor,
2424
_preprocess_coeffs,
25-
_preprocess_tensor,
25+
_preprocess_deconstruction,
2626
_translate_boundary_strings,
2727
)
2828
from .constants import BoundaryMode, Wavelet, WaveletCoeff1d
@@ -91,7 +91,7 @@ def wavedec(
9191
9292
Args:
9393
data (torch.Tensor): The input time series to transform.
94-
By default the last axis is transformed.
94+
By default, the last axis is transformed.
9595
wavelet (Wavelet or str): A pywt wavelet compatible object or
9696
the name of a pywt wavelet.
9797
Please consider the output from ``pywt.wavelist(kind='discrete')``
@@ -122,22 +122,19 @@ def wavedec(
122122
>>> # compute the forward fwt coefficients
123123
>>> ptwt.wavedec(data, 'haar', mode='zero', level=2)
124124
"""
125-
data, ds = _preprocess_tensor(data, ndim=1, axes=axis)
126-
127-
dec_lo, dec_hi, _, _ = _get_filter_tensors(
128-
wavelet, flip=True, device=data.device, dtype=data.dtype
125+
data, ds, dec_lo, dec_hi, dec_filt = _preprocess_deconstruction(
126+
data, wavelet, axes=axis, ndim=1
129127
)
130-
filt_len = dec_lo.shape[-1]
131-
filt = torch.stack([dec_lo, dec_hi], 0)
132128

133129
if level is None:
130+
filt_len = dec_lo.shape[-1]
134131
level = pywt.dwt_max_level(data.shape[-1], filt_len)
135132

136133
result_list = []
137134
res_lo = data
138135
for _ in range(level):
139136
res_lo = _fwt_pad(res_lo, wavelet, mode=mode)
140-
res = torch.nn.functional.conv1d(res_lo, filt, stride=2)
137+
res = torch.nn.functional.conv1d(res_lo, dec_filt, stride=2)
141138
res_lo, res_hi = torch.split(res, 1, 1)
142139
result_list.append(res_hi.squeeze(1))
143140
result_list.append(res_lo.squeeze(1))

src/ptwt/conv_transform_2.py

Lines changed: 4 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -15,46 +15,22 @@
1515
AxisHint,
1616
_adjust_padding_at_reconstruction,
1717
_check_same_device_dtype,
18+
_construct_2d_filt,
1819
_get_filter_tensors,
1920
_get_padding_n,
2021
_group_for_symmetric,
21-
_outer,
2222
_pad_symmetric,
2323
_postprocess_coeffs,
2424
_postprocess_tensor,
2525
_preprocess_coeffs,
26-
_preprocess_tensor,
26+
_preprocess_deconstruction,
2727
_translate_boundary_strings,
2828
)
2929
from .constants import BoundaryMode, Wavelet, WaveletCoeff2d, WaveletDetailTuple2d
3030

3131
__all__ = ["wavedec2", "waverec2"]
3232

3333

34-
def _construct_2d_filt(lo: torch.Tensor, hi: torch.Tensor) -> torch.Tensor:
35-
"""Construct two-dimensional filters using outer products.
36-
37-
Args:
38-
lo (torch.Tensor): Low-pass input filter.
39-
hi (torch.Tensor): High-pass input filter
40-
41-
Returns:
42-
Stacked 2d-filters of dimension
43-
44-
[2^2, 1, height, width].
45-
46-
The four filters are ordered ll, lh, hl, hh.
47-
48-
"""
49-
ll = _outer(lo, lo)
50-
lh = _outer(hi, lo)
51-
hl = _outer(lo, hi)
52-
hh = _outer(hi, hi)
53-
filt = torch.stack([ll, lh, hl, hh], 0)
54-
filt = filt.unsqueeze(1)
55-
return filt
56-
57-
5834
def _fwt_pad2(
5935
data: torch.Tensor,
6036
wavelet: Union[Wavelet, str],
@@ -154,11 +130,9 @@ def wavedec2(
154130
>>> coefficients = ptwt.wavedec2(data, "haar", level=2, mode="zero")
155131
156132
"""
157-
data, ds = _preprocess_tensor(data, ndim=2, axes=axes)
158-
dec_lo, dec_hi, _, _ = _get_filter_tensors(
159-
wavelet, flip=True, device=data.device, dtype=data.dtype
133+
data, ds, dec_lo, dec_hi, dec_filt = _preprocess_deconstruction(
134+
data, wavelet, axes=axes, ndim=2
160135
)
161-
dec_filt = _construct_2d_filt(lo=dec_lo, hi=dec_hi)
162136

163137
if level is None:
164138
level = pywt.dwtn_max_level([data.shape[-1], data.shape[-2]], wavelet)

src/ptwt/conv_transform_3.py

Lines changed: 4 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -15,51 +15,22 @@
1515
_adjust_padding_at_reconstruction,
1616
_as_wavelet,
1717
_check_same_device_dtype,
18+
_construct_3d_filt,
1819
_get_filter_tensors,
1920
_get_padding_n,
2021
_group_for_symmetric,
21-
_outer,
2222
_pad_symmetric,
2323
_postprocess_coeffs,
2424
_postprocess_tensor,
2525
_preprocess_coeffs,
26-
_preprocess_tensor,
26+
_preprocess_deconstruction,
2727
_translate_boundary_strings,
2828
)
2929
from .constants import BoundaryMode, Wavelet, WaveletCoeffNd, WaveletDetailDict
3030

3131
__all__ = ["wavedec3", "waverec3"]
3232

3333

34-
def _construct_3d_filt(lo: torch.Tensor, hi: torch.Tensor) -> torch.Tensor:
35-
"""Construct three-dimensional filters using outer products.
36-
37-
Args:
38-
lo (torch.Tensor): Low-pass input filter.
39-
hi (torch.Tensor): High-pass input filter
40-
41-
Returns:
42-
Stacked 3d filters of dimension::
43-
44-
[2^3, 1, length, height, width].
45-
46-
The four filters are ordered ll, lh, hl, hh.
47-
"""
48-
dim_size = lo.shape[-1]
49-
size = [dim_size] * 3
50-
lll = _outer(lo, _outer(lo, lo)).reshape(size)
51-
llh = _outer(lo, _outer(lo, hi)).reshape(size)
52-
lhl = _outer(lo, _outer(hi, lo)).reshape(size)
53-
lhh = _outer(lo, _outer(hi, hi)).reshape(size)
54-
hll = _outer(hi, _outer(lo, lo)).reshape(size)
55-
hlh = _outer(hi, _outer(lo, hi)).reshape(size)
56-
hhl = _outer(hi, _outer(hi, lo)).reshape(size)
57-
hhh = _outer(hi, _outer(hi, hi)).reshape(size)
58-
filt = torch.stack([lll, llh, lhl, lhh, hll, hlh, hhl, hhh], 0)
59-
filt = filt.unsqueeze(1)
60-
return filt
61-
62-
6334
def _fwt_pad3(
6435
data: torch.Tensor,
6536
wavelet: Union[Wavelet, str],
@@ -137,13 +108,10 @@ def wavedec3(
137108
>>> data = torch.randn(5, 16, 16, 16)
138109
>>> transformed = ptwt.wavedec3(data, "haar", level=2, mode="reflect")
139110
"""
140-
data, ds = _preprocess_tensor(data, ndim=3, axes=axes)
141-
142111
wavelet = _as_wavelet(wavelet)
143-
dec_lo, dec_hi, _, _ = _get_filter_tensors(
144-
wavelet, flip=True, device=data.device, dtype=data.dtype
112+
data, ds, dec_lo, dec_hi, dec_filt = _preprocess_deconstruction(
113+
data, wavelet, axes=axes, ndim=3
145114
)
146-
dec_filt = _construct_3d_filt(lo=dec_lo, hi=dec_hi)
147115

148116
if level is None:
149117
level = pywt.dwtn_max_level(

src/ptwt/matmul_transform_2.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
AxisHint,
1616
_as_wavelet,
1717
_check_same_device_dtype,
18+
_construct_2d_filt,
1819
_deprecated_alias,
1920
_ensure_axes,
2021
_get_filter_tensors,
@@ -32,7 +33,7 @@
3233
WaveletCoeff2d,
3334
WaveletDetailTuple2d,
3435
)
35-
from .conv_transform_2 import _construct_2d_filt, _fwt_pad2
36+
from .conv_transform_2 import _fwt_pad2
3637
from .matmul_transform import (
3738
BaseMatrixWaveDec,
3839
construct_boundary_a,

0 commit comments

Comments
 (0)