Skip to content

Commit d2204e8

Browse files
authored
DualSet: BLAS3 optimization for IntegralMomentOfDerivative (#210)
* DualSet: BLAS3 optimization for IntegralMomentOfDerivative
1 parent a087dbd commit d2204e8

File tree

3 files changed

+53
-39
lines changed

3 files changed

+53
-39
lines changed

FIAT/dual_set.py

Lines changed: 51 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
# SPDX-License-Identifier: LGPL-3.0-or-later
88

99
import numpy
10+
from itertools import chain
11+
from collections import defaultdict
1012

1113
from FIAT import polynomial_set, functional
1214
from FIAT.reference_element import compute_unflattening_map
@@ -117,36 +119,37 @@ def to_riesz(self, poly_set):
117119
riesz_shape = (num_nodes, *tshape, num_exp)
118120
mat = numpy.zeros(riesz_shape, "d")
119121

120-
pts = set()
121-
dpts = set()
122-
Qs_to_ells = dict()
123-
for i, ell in enumerate(self.nodes):
124-
if len(ell.deriv_dict) > 0:
125-
dpts.update(ell.deriv_dict.keys())
126-
continue
127-
if isinstance(ell, functional.IntegralMoment):
128-
Q = ell.Q
129-
else:
130-
Q = None
131-
pts.update(ell.pt_dict.keys())
132-
if Q in Qs_to_ells:
122+
def map_quadratures_to_points(nodes, deriv=False):
123+
Qs_to_ells = defaultdict(list)
124+
for i, ell in enumerate(nodes):
125+
if deriv and len(ell.deriv_dict) == 0:
126+
continue
127+
elif not deriv and len(ell.pt_dict) == 0:
128+
continue
129+
if isinstance(ell, (functional.IntegralMoment, functional.IntegralMomentOfDerivative)):
130+
Q = ell.Q
131+
else:
132+
Q = None
133133
Qs_to_ells[Q].append(i)
134-
else:
135-
Qs_to_ells[Q] = [i]
136-
137-
Qs_to_pts = {}
138-
if len(pts) > 0:
139-
Qs_to_pts[None] = tuple(sorted(pts))
140-
for Q in Qs_to_ells:
141-
if Q is not None:
142-
cur_pts = tuple(map(tuple, Q.pts))
134+
pts = set()
135+
Qs_to_pts = {}
136+
for Q in Qs_to_ells:
137+
if Q is None:
138+
if deriv:
139+
cur_pts = chain.from_iterable(nodes[i].deriv_dict.keys() for i in Qs_to_ells[None])
140+
else:
141+
cur_pts = chain.from_iterable(nodes[i].pt_dict.keys() for i in Qs_to_ells[None])
142+
cur_pts = tuple(set(cur_pts))
143+
else:
144+
cur_pts = tuple(map(tuple, Q.pts))
143145
Qs_to_pts[Q] = cur_pts
144146
pts.update(cur_pts)
147+
pts = list(sorted(pts))
148+
return Qs_to_ells, Qs_to_pts, pts
145149

146150
# Now tabulate the function values
147-
pts = list(sorted(pts))
151+
Qs_to_ells, Qs_to_pts, pts = map_quadratures_to_points(self.nodes)
148152
expansion_values = numpy.transpose(es.tabulate(ed, pts))
149-
150153
for Q in Qs_to_ells:
151154
ells = Qs_to_ells[Q]
152155
cur_pts = Qs_to_pts[Q]
@@ -171,25 +174,35 @@ def to_riesz(self, poly_set):
171174
# Tabulate the derivative values that are needed
172175
max_deriv_order = max(ell.max_deriv_order for ell in self.nodes)
173176
if max_deriv_order > 0:
174-
dpts = list(sorted(dpts))
177+
Qs_to_ells, Qs_to_pts, pts = map_quadratures_to_points(self.nodes, deriv=True)
175178
# It's easiest/most efficient to get derivatives of the
176179
# expansion set through the polynomial set interface.
177180
# This is creating a short-lived set to do just this.
178181
coeffs = numpy.eye(num_exp)
179182
expansion = polynomial_set.PolynomialSet(self.ref_el, ed, ed, es, coeffs)
180-
dexpansion_values = expansion.tabulate(dpts, max_deriv_order)
181-
182-
ells = [k for k, ell in enumerate(self.nodes) if len(ell.deriv_dict) > 0]
183-
wshape = (len(ells), *tshape, len(dpts))
184-
dwts = {alpha: numpy.zeros(wshape, "d") for alpha in dexpansion_values if sum(alpha) > 0}
185-
for i, k in enumerate(ells):
186-
ell = self.nodes[k]
187-
for pt, wac_list in ell.deriv_dict.items():
188-
j = dpts.index(pt)
189-
for (w, alpha, c) in wac_list:
190-
dwts[alpha][i][c][j] = w
191-
for alpha in dwts:
192-
mat[ells] += numpy.dot(dwts[alpha], dexpansion_values[alpha].T)
183+
dexpansion_values = expansion.tabulate(pts, max_deriv_order)
184+
for Q in Qs_to_ells:
185+
ells = Qs_to_ells[Q]
186+
cur_pts = Qs_to_pts[Q]
187+
indices = list(map(pts.index, cur_pts))
188+
wshape = (len(ells), *tshape, len(cur_pts))
189+
dwts = {alpha: numpy.zeros(wshape, "d") for alpha in dexpansion_values if sum(alpha) > 0}
190+
if Q is None:
191+
for i, k in enumerate(ells):
192+
ell = self.nodes[k]
193+
for pt, wac_list in ell.deriv_dict.items():
194+
j = cur_pts.index(pt)
195+
for (w, alpha, c) in wac_list:
196+
dwts[alpha][i][c][j] = w
197+
else:
198+
for i, k in enumerate(ells):
199+
ell = self.nodes[k]
200+
for alpha in ell.weights:
201+
dwts[alpha][i][ell.comp][:] = ell.weights[alpha]
202+
for alpha in dwts:
203+
wts = dwts[alpha]
204+
expansion_values = dexpansion_values[alpha].T
205+
mat[ells] += numpy.dot(wts, expansion_values[indices])
193206
return mat
194207

195208
def get_indices(self, restriction_domain, take_closure=True):

FIAT/functional.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,8 @@ def __init__(self, ref_el, Q, f_at_qpts, *directions, comp=(), shp=(), nm=""):
347347

348348
points = Q.get_points()
349349
weights = numpy.multiply(f_at_qpts, Q.get_weights())
350+
self.weights = {alpha: weights*tau[alpha] for alpha in tau}
351+
350352
dpt_dict = {tuple(pt): [(wt*tau[alpha], alpha, comp) for alpha in tau]
351353
for pt, wt in zip(points, weights)}
352354

FIAT/wuxu.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,6 @@ def __init__(self, ref_el, degree):
168168
Q = FacetQuadratureRule(ref_el, 1, e, Q_ref, avg=True)
169169
cur = len(nodes)
170170
nodes.append(IntegralMomentOfDerivative(ref_el, Q, f, n, n))
171-
172171
entity_ids[1][e].extend(range(cur, len(nodes)))
173172

174173
super().__init__(nodes, ref_el, entity_ids)

0 commit comments

Comments
 (0)