77# SPDX-License-Identifier: LGPL-3.0-or-later
88
99import numpy
10+ from itertools import chain
11+ from collections import defaultdict
1012
1113from FIAT import polynomial_set , functional
1214from FIAT .reference_element import compute_unflattening_map
@@ -117,36 +119,37 @@ def to_riesz(self, poly_set):
117119 riesz_shape = (num_nodes , * tshape , num_exp )
118120 mat = numpy .zeros (riesz_shape , "d" )
119121
120- pts = set ()
121- dpts = set ()
122- Qs_to_ells = dict ()
123- for i , ell in enumerate (self .nodes ):
124- if len (ell .deriv_dict ) > 0 :
125- dpts .update (ell .deriv_dict .keys ())
126- continue
127- if isinstance (ell , functional .IntegralMoment ):
128- Q = ell .Q
129- else :
130- Q = None
131- pts .update (ell .pt_dict .keys ())
132- if Q in Qs_to_ells :
122+ def map_quadratures_to_points (nodes , deriv = False ):
123+ Qs_to_ells = defaultdict (list )
124+ for i , ell in enumerate (nodes ):
125+ if deriv and len (ell .deriv_dict ) == 0 :
126+ continue
127+ elif not deriv and len (ell .pt_dict ) == 0 :
128+ continue
129+ if isinstance (ell , (functional .IntegralMoment , functional .IntegralMomentOfDerivative )):
130+ Q = ell .Q
131+ else :
132+ Q = None
133133 Qs_to_ells [Q ].append (i )
134- else :
135- Qs_to_ells [Q ] = [i ]
136-
137- Qs_to_pts = {}
138- if len (pts ) > 0 :
139- Qs_to_pts [None ] = tuple (sorted (pts ))
140- for Q in Qs_to_ells :
141- if Q is not None :
142- cur_pts = tuple (map (tuple , Q .pts ))
134+ pts = set ()
135+ Qs_to_pts = {}
136+ for Q in Qs_to_ells :
137+ if Q is None :
138+ if deriv :
139+ cur_pts = chain .from_iterable (nodes [i ].deriv_dict .keys () for i in Qs_to_ells [None ])
140+ else :
141+ cur_pts = chain .from_iterable (nodes [i ].pt_dict .keys () for i in Qs_to_ells [None ])
142+ cur_pts = tuple (set (cur_pts ))
143+ else :
144+ cur_pts = tuple (map (tuple , Q .pts ))
143145 Qs_to_pts [Q ] = cur_pts
144146 pts .update (cur_pts )
147+ pts = list (sorted (pts ))
148+ return Qs_to_ells , Qs_to_pts , pts
145149
146150 # Now tabulate the function values
147- pts = list ( sorted ( pts ) )
151+ Qs_to_ells , Qs_to_pts , pts = map_quadratures_to_points ( self . nodes )
148152 expansion_values = numpy .transpose (es .tabulate (ed , pts ))
149-
150153 for Q in Qs_to_ells :
151154 ells = Qs_to_ells [Q ]
152155 cur_pts = Qs_to_pts [Q ]
@@ -171,25 +174,35 @@ def to_riesz(self, poly_set):
171174 # Tabulate the derivative values that are needed
172175 max_deriv_order = max (ell .max_deriv_order for ell in self .nodes )
173176 if max_deriv_order > 0 :
174- dpts = list ( sorted ( dpts ) )
177+ Qs_to_ells , Qs_to_pts , pts = map_quadratures_to_points ( self . nodes , deriv = True )
175178 # It's easiest/most efficient to get derivatives of the
176179 # expansion set through the polynomial set interface.
177180 # This is creating a short-lived set to do just this.
178181 coeffs = numpy .eye (num_exp )
179182 expansion = polynomial_set .PolynomialSet (self .ref_el , ed , ed , es , coeffs )
180- dexpansion_values = expansion .tabulate (dpts , max_deriv_order )
181-
182- ells = [k for k , ell in enumerate (self .nodes ) if len (ell .deriv_dict ) > 0 ]
183- wshape = (len (ells ), * tshape , len (dpts ))
184- dwts = {alpha : numpy .zeros (wshape , "d" ) for alpha in dexpansion_values if sum (alpha ) > 0 }
185- for i , k in enumerate (ells ):
186- ell = self .nodes [k ]
187- for pt , wac_list in ell .deriv_dict .items ():
188- j = dpts .index (pt )
189- for (w , alpha , c ) in wac_list :
190- dwts [alpha ][i ][c ][j ] = w
191- for alpha in dwts :
192- mat [ells ] += numpy .dot (dwts [alpha ], dexpansion_values [alpha ].T )
183+ dexpansion_values = expansion .tabulate (pts , max_deriv_order )
184+ for Q in Qs_to_ells :
185+ ells = Qs_to_ells [Q ]
186+ cur_pts = Qs_to_pts [Q ]
187+ indices = list (map (pts .index , cur_pts ))
188+ wshape = (len (ells ), * tshape , len (cur_pts ))
189+ dwts = {alpha : numpy .zeros (wshape , "d" ) for alpha in dexpansion_values if sum (alpha ) > 0 }
190+ if Q is None :
191+ for i , k in enumerate (ells ):
192+ ell = self .nodes [k ]
193+ for pt , wac_list in ell .deriv_dict .items ():
194+ j = cur_pts .index (pt )
195+ for (w , alpha , c ) in wac_list :
196+ dwts [alpha ][i ][c ][j ] = w
197+ else :
198+ for i , k in enumerate (ells ):
199+ ell = self .nodes [k ]
200+ for alpha in ell .weights :
201+ dwts [alpha ][i ][ell .comp ][:] = ell .weights [alpha ]
202+ for alpha in dwts :
203+ wts = dwts [alpha ]
204+ expansion_values = dexpansion_values [alpha ].T
205+ mat [ells ] += numpy .dot (wts , expansion_values [indices ])
193206 return mat
194207
195208 def get_indices (self , restriction_domain , take_closure = True ):
0 commit comments