Skip to content

Commit eb1c071

Browse files
authored
Improve typing for boundaries (#126)
1 parent c81ba5a commit eb1c071

3 files changed

Lines changed: 23 additions & 16 deletions

File tree

src/ptwt/_util.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,27 @@
2626
WaveletDetailTuple2d,
2727
)
2828

29-
30-
def _translate_boundary_strings(pywt_mode: BoundaryMode) -> str:
29+
#: All the PyTorch boundary modes for :func:`torch.nn.functional.pad`
30+
PyTorchBoundaryMode = Literal["replicate", "constant", "reflect", "circular"]
31+
32+
#: All the PyTorch boundary modes for :func:`torch.nn.functional.pad`
33+
#: plus `symmetric` for the custom ptwt boundary
34+
ExtendedPyTorchBoundaryMode = PyTorchBoundaryMode | Literal["symmetric"]
35+
36+
translation_dict: dict[BoundaryMode, ExtendedPyTorchBoundaryMode] = {
37+
"constant": "replicate",
38+
"zero": "constant",
39+
"reflect": "reflect",
40+
"periodic": "circular",
41+
# pytorch does not support symmetric mode,
42+
# we have our own implementation.
43+
"symmetric": "symmetric",
44+
}
45+
46+
47+
def _translate_boundary_strings(
48+
pywt_mode: BoundaryMode | None,
49+
) -> ExtendedPyTorchBoundaryMode:
3150
"""Translate pywt mode strings to PyTorch mode strings.
3251
3352
We support ``constant``, ``zero``, ``reflect``,
@@ -38,15 +57,8 @@ def _translate_boundary_strings(pywt_mode: BoundaryMode) -> str:
3857
Raises:
3958
ValueError: If the padding mode is not supported.
4059
"""
41-
translation_dict = {
42-
"constant": "replicate",
43-
"zero": "constant",
44-
"reflect": "reflect",
45-
"periodic": "circular",
46-
# pytorch does not support symmetric mode,
47-
# we have our own implementation.
48-
"symmetric": "symmetric",
49-
}
60+
if pywt_mode is None:
61+
return translation_dict["reflect"]
5062
if pywt_mode in translation_dict:
5163
return translation_dict[pywt_mode]
5264
else:

src/ptwt/conv_transform.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,6 @@ def _fwt_pad(
5555
Returns:
5656
A PyTorch tensor with the padded input data
5757
"""
58-
# convert pywt to pytorch convention.
59-
if mode is None:
60-
mode = "reflect"
6158
pytorch_mode = _translate_boundary_strings(mode)
6259

6360
if padding is None:

src/ptwt/conv_transform_2.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,6 @@ def _fwt_pad2(
8484
The padded output tensor.
8585
8686
"""
87-
if mode is None:
88-
mode = "reflect"
8987
pytorch_mode = _translate_boundary_strings(mode)
9088

9189
if padding is None:

0 commit comments

Comments
 (0)