Skip to content

Commit 5ab175f

Browse files
committed
np.ndarray output types
1 parent 3699305 commit 5ab175f

File tree

2 files changed

+11
-16
lines changed

2 files changed

+11
-16
lines changed

attune/_discrete_tune.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def __init__(
3535
def __repr__(self):
3636
return f"DiscreteTune({repr(self.ranges)}, {repr(self.default)})"
3737

38-
def __call__(self, ind_value, *, ind_units=None):
38+
def __call__(self, ind_value, *, ind_units=None) -> np.ndarray:
3939
"""Evaluate the DiscreteTune at specific independent value(s).
4040
4141
Paramters
@@ -54,20 +54,15 @@ def __call__(self, ind_value, *, ind_units=None):
5454
"""
5555
if ind_units is not None and self._ind_units is not None:
5656
ind_value = wt.units.convert(ind_value, ind_units, self._ind_units)
57-
if isinstance(ind_value, np.ndarray):
58-
out = np.full(
59-
ind_value.shape,
60-
self.default,
61-
dtype=f"U{max([len(s) for s in self.ranges.keys()])}",
62-
)
63-
for key, (imin, imax) in self.ranges.items():
64-
out[(ind_value >= imin) & (ind_value <= imax)] = key
65-
return out
66-
else:
67-
for key, (imin, imax) in self.ranges.items():
68-
if imin <= ind_value <= imax:
69-
return key
70-
return self.default
57+
ind_value = np.asarray(ind_value)
58+
out = np.full(
59+
ind_value.shape,
60+
self.default,
61+
dtype=f"U{max([len(s) for s in self.ranges.keys()])}",
62+
)
63+
for key, (imin, imax) in self.ranges.items():
64+
out[(ind_value >= imin) & (ind_value <= imax)] = key
65+
return out
7166

7267
def __eq__(self, other):
7368
return self.ranges == other.ranges and self.default == other.default

attune/_tune.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def __repr__(self):
5555
return f"Tune({repr(self.independent)}, {repr(self.dependent)})"
5656
return f"Tune({repr(self.independent)}, {repr(self.dependent)}, dep_units={repr(self.dep_units)})"
5757

58-
def __call__(self, ind_value, *, ind_units=None, dep_units=None):
58+
def __call__(self, ind_value, *, ind_units=None, dep_units=None) -> np.ndarray:
5959
if ind_units is not None and self._ind_units is not None:
6060
ind_value = wt.units.convert(ind_value, ind_units, self._ind_units)
6161
ret = self._interp(ind_value)

0 commit comments

Comments
 (0)