77.. autofunction:: make_discretization_collection
88
99.. currentmodule:: grudge.discretization
10+ .. autoclass:: PartID
1011"""
1112
1213__copyright__ = """
3435THE SOFTWARE.
3536"""
3637
37- from typing import Mapping , Optional , Union , TYPE_CHECKING , Any
38+ from typing import Sequence , Mapping , Optional , Union , Tuple , TYPE_CHECKING , Any
3839
3940from pytools import memoize_method , single_valued
4041
42+ from dataclasses import dataclass , replace
43+
4144from grudge .dof_desc import (
4245 VTAG_ALL ,
4346 DD_VOLUME_ALL ,
7174 import mpi4py .MPI
7275
7376
77+ @dataclass (frozen = True )
78+ class PartID :
79+ """Unique identifier for a piece of a partitioned mesh.
80+
81+ .. attribute:: volume_tag
82+
83+ The volume of the part.
84+
85+ .. attribute:: rank
86+
87+ The (optional) MPI rank of the part.
88+
89+ """
90+ volume_tag : VolumeTag
91+ rank : Optional [int ] = None
92+
93+
94+ # {{{ part ID normalization
95+
96+ def _normalize_mesh_part_ids (
97+ mesh : Mesh ,
98+ self_volume_tag : VolumeTag ,
99+ all_volume_tags : Sequence [VolumeTag ],
100+ mpi_communicator : Optional ["mpi4py.MPI.Intracomm" ] = None ):
101+ """Convert a mesh's configuration-dependent "part ID" into a fixed type."""
102+ from numbers import Integral
103+ if mpi_communicator is not None :
104+ # Accept PartID or rank (assume intra-volume for the latter)
105+ def as_part_id (mesh_part_id ):
106+ if isinstance (mesh_part_id , PartID ):
107+ return mesh_part_id
108+ elif isinstance (mesh_part_id , Integral ):
109+ return PartID (self_volume_tag , int (mesh_part_id ))
110+ else :
111+ raise TypeError (f"Unable to convert { mesh_part_id } to PartID." )
112+ else :
113+ # Accept PartID or volume tag
114+ def as_part_id (mesh_part_id ):
115+ if isinstance (mesh_part_id , PartID ):
116+ return mesh_part_id
117+ elif mesh_part_id in all_volume_tags :
118+ return PartID (mesh_part_id )
119+ else :
120+ raise TypeError (f"Unable to convert { mesh_part_id } to PartID." )
121+
122+ facial_adjacency_groups = mesh .facial_adjacency_groups
123+
124+ new_facial_adjacency_groups = []
125+
126+ from meshmode .mesh import InterPartAdjacencyGroup
127+ for grp_list in facial_adjacency_groups :
128+ new_grp_list = []
129+ for fagrp in grp_list :
130+ if isinstance (fagrp , InterPartAdjacencyGroup ):
131+ part_id = as_part_id (fagrp .part_id )
132+ new_fagrp = replace (
133+ fagrp ,
134+ boundary_tag = BTAG_PARTITION (part_id ),
135+ part_id = part_id )
136+ else :
137+ new_fagrp = fagrp
138+ new_grp_list .append (new_fagrp )
139+ new_facial_adjacency_groups .append (new_grp_list )
140+
141+ return mesh .copy (facial_adjacency_groups = new_facial_adjacency_groups )
142+
143+ # }}}
144+
145+
74146# {{{ discr_tag_to_group_factory normalization
75147
76148def _normalize_discr_tag_to_group_factory (
@@ -156,6 +228,9 @@ def __init__(self, array_context: ArrayContext,
156228 discr_tag_to_group_factory : Optional [
157229 Mapping [DiscretizationTag , ElementGroupFactory ]] = None ,
158230 mpi_communicator : Optional ["mpi4py.MPI.Intracomm" ] = None ,
231+ inter_part_connections : Optional [
232+ Mapping [Tuple [PartID , PartID ],
233+ DiscretizationConnection ]] = None ,
159234 ) -> None :
160235 """
161236 :arg discr_tag_to_group_factory: A mapping from discretization tags
@@ -206,6 +281,9 @@ def __init__(self, array_context: ArrayContext,
206281
207282 mesh = volume_discrs
208283
284+ mesh = _normalize_mesh_part_ids (
285+ mesh , VTAG_ALL , [VTAG_ALL ], mpi_communicator = mpi_communicator )
286+
209287 discr_tag_to_group_factory = _normalize_discr_tag_to_group_factory (
210288 dim = mesh .dim ,
211289 discr_tag_to_group_factory = discr_tag_to_group_factory ,
@@ -219,17 +297,32 @@ def __init__(self, array_context: ArrayContext,
219297
220298 del mesh
221299
300+ if inter_part_connections is not None :
301+ raise TypeError ("may not pass inter_part_connections when "
302+ "DiscretizationCollection constructor is called in "
303+ "legacy mode" )
304+
305+ self ._inter_part_connections = \
306+ _set_up_inter_part_connections (
307+ array_context = self ._setup_actx ,
308+ mpi_communicator = mpi_communicator ,
309+ volume_discrs = volume_discrs ,
310+ base_group_factory = (
311+ discr_tag_to_group_factory [DISCR_TAG_BASE ]))
312+
222313 # }}}
223314 else :
224315 assert discr_tag_to_group_factory is not None
225316 self ._discr_tag_to_group_factory = discr_tag_to_group_factory
226317
227- self ._volume_discrs = volume_discrs
318+ if inter_part_connections is None :
319+ raise TypeError ("inter_part_connections must be passed when "
320+ "DiscretizationCollection constructor is called in "
321+ "'modern' mode" )
322+
323+ self ._inter_part_connections = inter_part_connections
228324
229- self ._dist_boundary_connections = {
230- vtag : self ._set_up_distributed_communication (
231- vtag , mpi_communicator , array_context )
232- for vtag in self ._volume_discrs .keys ()}
325+ self ._volume_discrs = volume_discrs
233326
234327 # }}}
235328
@@ -252,71 +345,6 @@ def is_management_rank(self):
252345 return self .mpi_communicator .Get_rank () \
253346 == self .get_management_rank_index ()
254347
255- # {{{ distributed
256-
257- def _set_up_distributed_communication (
258- self , vtag , mpi_communicator , array_context ):
259- from_dd = DOFDesc (VolumeDomainTag (vtag ), DISCR_TAG_BASE )
260-
261- boundary_connections = {}
262-
263- from meshmode .distributed import get_connected_partitions
264- connected_parts = get_connected_partitions (self ._volume_discrs [vtag ].mesh )
265-
266- if connected_parts :
267- if mpi_communicator is None :
268- raise RuntimeError ("must supply an MPI communicator when using a "
269- "distributed mesh" )
270-
271- grp_factory = \
272- self .group_factory_for_discretization_tag (DISCR_TAG_BASE )
273-
274- local_boundary_connections = {}
275- for i_remote_part in connected_parts :
276- local_boundary_connections [i_remote_part ] = self .connection_from_dds (
277- from_dd , from_dd .trace (BTAG_PARTITION (i_remote_part )))
278-
279- from meshmode .distributed import MPIBoundaryCommSetupHelper
280- with MPIBoundaryCommSetupHelper (mpi_communicator , array_context ,
281- local_boundary_connections , grp_factory ) as bdry_setup_helper :
282- while True :
283- conns = bdry_setup_helper .complete_some ()
284- if not conns :
285- break
286- for i_remote_part , conn in conns .items ():
287- boundary_connections [i_remote_part ] = conn
288-
289- return boundary_connections
290-
291- def distributed_boundary_swap_connection (self , dd ):
292- """Provides a mapping from the base volume discretization
293- to the exterior boundary restriction on a parallel boundary
294- partition described by *dd*. This connection is used to
295- communicate across element boundaries in different parallel
296- partitions during distributed runs.
297-
298- :arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value
299- convertible to one. The domain tag must be a subclass
300- of :class:`grudge.dof_desc.BoundaryDomainTag` with an
301- associated :class:`meshmode.mesh.BTAG_PARTITION`
302- corresponding to a particular communication rank.
303- """
304- if dd .discretization_tag is not DISCR_TAG_BASE :
305- # FIXME
306- raise NotImplementedError (
307- "Distributed communication with discretization tag "
308- f"{ dd .discretization_tag } is not implemented."
309- )
310-
311- assert isinstance (dd .domain_tag , BoundaryDomainTag )
312- assert isinstance (dd .domain_tag .tag , BTAG_PARTITION )
313-
314- vtag = dd .domain_tag .volume_tag
315-
316- return self ._dist_boundary_connections [vtag ][dd .domain_tag .tag .part_nr ]
317-
318- # }}}
319-
320348 # {{{ discr_from_dd
321349
322350 @memoize_method
@@ -772,6 +800,105 @@ def normal(self, dd):
772800 # }}}
773801
774802
803+ # {{{ distributed/multi-volume setup
804+
805+ def _set_up_inter_part_connections (
806+ array_context : ArrayContext ,
807+ mpi_communicator : Optional ["mpi4py.MPI.Intracomm" ],
808+ volume_discrs : Mapping [VolumeTag , Discretization ],
809+ base_group_factory : ElementGroupFactory ,
810+ ) -> Mapping [
811+ Tuple [PartID , PartID ],
812+ DiscretizationConnection ]:
813+
814+ from meshmode .distributed import (get_connected_parts ,
815+ make_remote_group_infos , InterRankBoundaryInfo ,
816+ MPIBoundaryCommSetupHelper )
817+
818+ rank = mpi_communicator .Get_rank () if mpi_communicator is not None else None
819+
820+ # Save boundary restrictions as they're created to avoid potentially creating
821+ # them twice in the loop below
822+ cached_part_bdry_restrictions : Mapping [
823+ Tuple [PartID , PartID ],
824+ DiscretizationConnection ] = {}
825+
826+ def get_part_bdry_restriction (self_part_id , other_part_id ):
827+ cached_result = cached_part_bdry_restrictions .get (
828+ (self_part_id , other_part_id ), None )
829+ if cached_result is not None :
830+ return cached_result
831+ return cached_part_bdry_restrictions .setdefault (
832+ (self_part_id , other_part_id ),
833+ make_face_restriction (
834+ array_context , volume_discrs [self_part_id .volume_tag ],
835+ base_group_factory ,
836+ boundary_tag = BTAG_PARTITION (other_part_id )))
837+
838+ inter_part_conns : Mapping [
839+ Tuple [PartID , PartID ],
840+ DiscretizationConnection ] = {}
841+
842+ irbis = []
843+
844+ for vtag , volume_discr in volume_discrs .items ():
845+ part_id = PartID (vtag , rank )
846+ connected_part_ids = get_connected_parts (volume_discr .mesh )
847+ for connected_part_id in connected_part_ids :
848+ bdry_restr = get_part_bdry_restriction (
849+ self_part_id = part_id , other_part_id = connected_part_id )
850+
851+ if connected_part_id .rank == rank :
852+ # {{{ rank-local interface between multiple volumes
853+
854+ connected_bdry_restr = get_part_bdry_restriction (
855+ self_part_id = connected_part_id , other_part_id = part_id )
856+
857+ from meshmode .discretization .connection import \
858+ make_partition_connection
859+ inter_part_conns [connected_part_id , part_id ] = \
860+ make_partition_connection (
861+ array_context ,
862+ local_bdry_conn = bdry_restr ,
863+ remote_bdry_discr = connected_bdry_restr .to_discr ,
864+ remote_group_infos = make_remote_group_infos (
865+ array_context , part_id , connected_bdry_restr ))
866+
867+ # }}}
868+ else :
869+ # {{{ cross-rank interface
870+
871+ if mpi_communicator is None :
872+ raise RuntimeError ("must supply an MPI communicator "
873+ "when using a distributed mesh" )
874+
875+ irbis .append (
876+ InterRankBoundaryInfo (
877+ local_part_id = part_id ,
878+ remote_part_id = connected_part_id ,
879+ remote_rank = connected_part_id .rank ,
880+ local_boundary_connection = bdry_restr ))
881+
882+ # }}}
883+
884+ if irbis :
885+ assert mpi_communicator is not None
886+
887+ with MPIBoundaryCommSetupHelper (mpi_communicator , array_context ,
888+ irbis , base_group_factory ) as bdry_setup_helper :
889+ while True :
890+ conns = bdry_setup_helper .complete_some ()
891+ if not conns :
892+ # We're done.
893+ break
894+
895+ inter_part_conns .update (conns )
896+
897+ return inter_part_conns
898+
899+ # }}}
900+
901+
775902# {{{ modal group factory
776903
777904def _generate_modal_group_factory (nodal_group_factory ):
@@ -860,6 +987,8 @@ def make_discretization_collection(
860987
861988 del order
862989
990+ mpi_communicator = getattr (array_context , "mpi_communicator" , None )
991+
863992 if any (
864993 isinstance (mesh_or_discr , Discretization )
865994 for mesh_or_discr in volumes .values ()):
@@ -868,14 +997,20 @@ def make_discretization_collection(
868997 volume_discrs = {
869998 vtag : Discretization (
870999 array_context ,
871- mesh ,
1000+ _normalize_mesh_part_ids (
1001+ mesh , vtag , volumes .keys (), mpi_communicator = mpi_communicator ),
8721002 discr_tag_to_group_factory [DISCR_TAG_BASE ])
8731003 for vtag , mesh in volumes .items ()}
8741004
8751005 return DiscretizationCollection (
8761006 array_context = array_context ,
8771007 volume_discrs = volume_discrs ,
878- discr_tag_to_group_factory = discr_tag_to_group_factory )
1008+ discr_tag_to_group_factory = discr_tag_to_group_factory ,
1009+ inter_part_connections = _set_up_inter_part_connections (
1010+ array_context = array_context ,
1011+ mpi_communicator = mpi_communicator ,
1012+ volume_discrs = volume_discrs ,
1013+ base_group_factory = discr_tag_to_group_factory [DISCR_TAG_BASE ]))
8791014
8801015# }}}
8811016
0 commit comments