2626 WaveletDetailTuple2d ,
2727)
2828
29-
30- def _translate_boundary_strings (pywt_mode : BoundaryMode ) -> str :
29+ #: All the PyTorch boundary modes for :func:`torch.nn.functional.pad`
30+ PyTorchBoundaryMode = Literal ["replicate" , "constant" , "reflect" , "circular" ]
31+
32+ #: All the PyTorch boundary modes for :func:`torch.nn.functional.pad`
33+ #: plus `symmetric` for the custom ptwt boundary
34+ ExtendedPyTorchBoundaryMode = PyTorchBoundaryMode | Literal ["symmetric" ]
35+
36+ translation_dict : dict [BoundaryMode , ExtendedPyTorchBoundaryMode ] = {
37+ "constant" : "replicate" ,
38+ "zero" : "constant" ,
39+ "reflect" : "reflect" ,
40+ "periodic" : "circular" ,
41+ # pytorch does not support symmetric mode,
42+ # we have our own implementation.
43+ "symmetric" : "symmetric" ,
44+ }
45+
46+
47+ def _translate_boundary_strings (
48+ pywt_mode : BoundaryMode | None ,
49+ ) -> ExtendedPyTorchBoundaryMode :
3150 """Translate pywt mode strings to PyTorch mode strings.
3251
3352 We support ``constant``, ``zero``, ``reflect``,
@@ -38,15 +57,8 @@ def _translate_boundary_strings(pywt_mode: BoundaryMode) -> str:
3857 Raises:
3958 ValueError: If the padding mode is not supported.
4059 """
41- translation_dict = {
42- "constant" : "replicate" ,
43- "zero" : "constant" ,
44- "reflect" : "reflect" ,
45- "periodic" : "circular" ,
46- # pytorch does not support symmetric mode,
47- # we have our own implementation.
48- "symmetric" : "symmetric" ,
49- }
60+ if pywt_mode is None :
61+ return translation_dict ["reflect" ]
5062 if pywt_mode in translation_dict :
5163 return translation_dict [pywt_mode ]
5264 else :
0 commit comments