Skip to content

Commit a5dc691

Browse files
committed
fix: use numbers.Real for checking type
np.float32 are not float for isinstance, let's use a more generic checking.
1 parent 2040f73 commit a5dc691

File tree

3 files changed

+10
-7
lines changed

3 files changed

+10
-7
lines changed

src/modopt/base/types.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
77
"""
88

9+
import numbers
910
import numpy as np
1011
from modopt.interface.errors import warn
1112

@@ -68,14 +69,14 @@ def check_float(input_obj):
6869
check_int : related function
6970
7071
"""
71-
if not isinstance(input_obj, (int, float, list, tuple, np.ndarray)):
72+
if not isinstance(input_obj, (int, numbers.Real, list, tuple, np.ndarray)):
7273
raise TypeError("Invalid input type.")
7374
if isinstance(input_obj, int):
7475
input_obj = float(input_obj)
7576
elif isinstance(input_obj, (list, tuple)):
7677
input_obj = np.array(input_obj, dtype=float)
7778
elif isinstance(input_obj, np.ndarray) and (
78-
not np.issubdtype(input_obj.dtype, np.floating)
79+
not np.issubdtype(input_obj.dtype, numbers.Real)
7980
):
8081
input_obj = input_obj.astype(float)
8182

@@ -117,9 +118,9 @@ def check_int(input_obj):
117118
check_float : related function
118119
119120
"""
120-
if not isinstance(input_obj, (int, float, list, tuple, np.ndarray)):
121+
if not isinstance(input_obj, (int, numbers.Real, list, tuple, np.ndarray)):
121122
raise TypeError("Invalid input type.")
122-
if isinstance(input_obj, float):
123+
if isinstance(input_obj, numbers.Real):
123124
input_obj = int(input_obj)
124125
elif isinstance(input_obj, (list, tuple)):
125126
input_obj = np.array(input_obj, dtype=int)

src/modopt/opt/algorithms/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from inspect import getmro
44

55
import numpy as np
6+
import numbers
67
from tqdm.auto import tqdm
78

89
from modopt.base import backend
@@ -192,7 +193,7 @@ def _check_param(self, param_val):
192193
For invalid input type
193194
194195
"""
195-
if not isinstance(param_val, float):
196+
if not isinstance(param_val, numbers.Real):
196197
raise TypeError("Algorithm parameter must be a float value.")
197198

198199
def _check_param_update(self, param_update):

src/modopt/signal/positivity.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
88
"""
99

10+
import numbers
1011
import numpy as np
1112

1213

@@ -93,13 +94,13 @@ def positive(input_data, ragged=False):
9394
[1, 2, 3]])
9495
9596
"""
96-
if not isinstance(input_data, (int, float, list, tuple, np.ndarray)):
97+
if not isinstance(input_data, (int, numbers.Real, list, tuple, np.ndarray)):
9798
raise TypeError(
9899
"Invalid data type, input must be `int`, `float`, `list`, "
99100
+ "`tuple` or `np.ndarray`.",
100101
)
101102

102-
if isinstance(input_data, (int, float)):
103+
if isinstance(input_data, (int, numbers.Real)):
103104
return pos_thresh(input_data)
104105

105106
if ragged:

0 commit comments

Comments
 (0)