Skip to content

Commit a3bb25c

Browse files
committed
Type point_calculus
1 parent 40ae9a3 commit a3bb25c

File tree

2 files changed

+52
-16
lines changed

2 files changed

+52
-16
lines changed

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,14 @@ dependencies = [
4040

4141
[dependency-groups]
4242
dev = [
43+
{include-group = "type"},
4344
{include-group = "doc"},
4445
{include-group = "test"},
4546
{include-group = "lint"},
4647
]
48+
type = [
49+
"optype>=0.14"
50+
]
4751
lint = [
4852
"pylint",
4953
# https://github.com/astral-sh/ruff/issues/16943

sumpy/point_calculus.py

Lines changed: 48 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
2323
THE SOFTWARE.
2424
"""
25+
from typing import TYPE_CHECKING, Literal, TypeAlias
2526

2627
import numpy as np
2728
import numpy.linalg as la
@@ -30,13 +31,22 @@
3031
from pytools import memoize_method
3132

3233

34+
if TYPE_CHECKING:
35+
from collections.abc import Callable, Sequence
36+
37+
from optype.numpy import Array1D, Array2D
38+
39+
3340
__doc__ = """
3441
.. autoclass:: CalculusPatch
3542
3643
.. autofunction:: frequency_domain_maxwell
3744
"""
3845

3946

47+
NodesKind: TypeAlias = Literal["chebyshev", "equispaced", "legendre"]
48+
49+
4050
class CalculusPatch:
4151
"""Sets up a grid of points on which derivatives can be calculated. Useful
4252
to verify that an evaluated potential actually solves a PDE.
@@ -64,7 +74,16 @@ class CalculusPatch:
6474
.. automethod:: plot_nodes
6575
.. automethod:: plot
6676
"""
67-
def __init__(self, center, h=1e-1, order=4, nodes="chebyshev"):
77+
dim: int
78+
center: Array1D[np.floating]
79+
points: Array2D[np.floating]
80+
npoints: int
81+
82+
def __init__(self,
83+
center: Array1D[np.floating],
84+
h: float = 1e-1,
85+
order: int = 4,
86+
nodes: NodesKind = "chebyshev"):
6887
self.center = center
6988

7089
npoints = order + 1
@@ -119,7 +138,7 @@ def _zero_eval_vec_1d(self):
119138
# The zeroth coefficient--all others involve x=0.
120139
return self._vandermonde_1d()[0]
121140

122-
def basis(self):
141+
def basis(self) -> Sequence[Callable[[Array2D[np.floating]], Array1D[np.floating]]]:
123142
"""
124143
:returns: a :class:`list` containing functions that realize
125144
a high-order interpolation basis on the :py:attr:`points`.
@@ -129,7 +148,7 @@ def basis(self):
129148

130149
from pytools import indices_in_shape
131150

132-
def eval_basis(ind, x):
151+
def eval_basis(ind: tuple[int, ...], x: Array2D[np.floating]):
133152
result = 1
134153
for i in range(self.dim):
135154
coord = (x[i] - self.center[i])/(self.h/2)
@@ -172,7 +191,11 @@ def _diff_mat_1d(self, nderivs):
172191
deriv_coeffs_mat = la.solve(vandermonde.T, n_diff_mat.T).T
173192
return vandermonde.dot(deriv_coeffs_mat)
174193

175-
def diff(self, axis, f_values, nderivs=1):
194+
def diff(self,
195+
axis: int,
196+
f_values: Array1D[np.inexact],
197+
nderivs: int = 1
198+
) -> Array1D[np.inexact] | Literal[0]:
176199
"""Return the derivative along *axis* of *f_values*.
177200
178201
:arg f_values: an array of shape ``(npoints_total,)``
@@ -197,16 +220,16 @@ def diff(self, axis, f_values, nderivs=1):
197220
self._diff_mat_1d(nderivs),
198221
f_values.reshape(*self._pshape)).reshape(-1)
199222

200-
def dx(self, f_values):
223+
def dx(self, f_values: Array1D[np.inexact]):
201224
return self.diff(0, f_values)
202225

203-
def dy(self, f_values):
226+
def dy(self, f_values: Array1D[np.inexact]):
204227
return self.diff(1, f_values)
205228

206-
def dz(self, f_values):
229+
def dz(self, f_values: Array1D[np.inexact]):
207230
return self.diff(2, f_values)
208231

209-
def laplace(self, f_values):
232+
def laplace(self, f_values: Array1D[np.inexact]):
210233
"""Return the Laplacian of *f_values*.
211234
212235
:arg f_values: an array of shape ``(npoints_total,)``
@@ -215,18 +238,22 @@ def laplace(self, f_values):
215238

216239
return sum(self.diff(iaxis, f_values, 2) for iaxis in range(self.dim))
217240

218-
def div(self, arg):
241+
def div(self,
242+
arg: obj_array.ObjectArray1D[Array1D[np.inexact]]
243+
) -> Array1D[np.inexact] | int:
219244
r"""
220245
:arg arg: an object array containing
221246
:class:`numpy.ndarray`\ s with shape ``(npoints_total,)``.
222247
"""
223-
result = 0
248+
result: Array1D[np.inexact] | int = 0
224249
for i, arg_i in enumerate(arg):
225250
result = result + self.diff(i, arg_i)
226251

227252
return result
228253

229-
def curl(self, arg):
254+
def curl(self,
255+
arg: obj_array.ObjectArray1D[Array1D[np.inexact]]
256+
) -> obj_array.ObjectArray1D[Array1D[np.inexact]]:
230257
r"""Take the curl of the vector quantity *arg*.
231258
232259
:arg arg: an object array of shape ``(3,)`` containing
@@ -254,15 +281,15 @@ def eval_at_center(self, f_values):
254281
return f_values
255282

256283
@property
257-
def x(self):
284+
def x(self) -> Array1D[np.floating]:
258285
return self.points[0]
259286

260287
@property
261-
def y(self):
288+
def y(self) -> Array1D[np.floating]:
262289
return self.points[1]
263290

264291
@property
265-
def z(self):
292+
def z(self) -> Array1D[np.floating]:
266293
return self.points[2]
267294

268295
def norm(self, arg, p):
@@ -284,15 +311,20 @@ def plot_nodes(self):
284311
self._points_shaped[1].reshape(-1),
285312
"o")
286313

287-
def plot(self, f):
314+
def plot(self, f: Array1D[np.floating]):
288315
f = f.reshape(*self._pshape)
289316

290317
import matplotlib.pyplot as plt
291318
plt.gca().set_aspect("equal")
292319
plt.contourf(self._points_1d, self._points_1d, f)
293320

294321

295-
def frequency_domain_maxwell(cpatch, e, h, k):
322+
def frequency_domain_maxwell(
323+
cpatch: CalculusPatch,
324+
e: obj_array.ObjectArray1D[Array1D[np.complexfloating]],
325+
h: obj_array.ObjectArray1D[Array1D[np.complexfloating]],
326+
k: complex
327+
):
296328
mu = 1
297329
epsilon = 1
298330
c = 1/np.sqrt(mu*epsilon)

0 commit comments

Comments
 (0)