
Commit afc56b6: set up connections between volumes
Parent: 0aa399a

1 file changed: grudge/discretization.py (208 additions, 73 deletions)
@@ -7,6 +7,7 @@
 .. autofunction:: make_discretization_collection
 
 .. currentmodule:: grudge.discretization
+.. autoclass:: PartID
 """
 
 __copyright__ = """
@@ -34,10 +35,12 @@
 THE SOFTWARE.
 """
 
-from typing import Mapping, Optional, Union, TYPE_CHECKING, Any
+from typing import Sequence, Mapping, Optional, Union, Tuple, TYPE_CHECKING, Any
 
 from pytools import memoize_method, single_valued
 
+from dataclasses import dataclass, replace
+
 from grudge.dof_desc import (
     VTAG_ALL,
     DD_VOLUME_ALL,
@@ -71,6 +74,75 @@
     import mpi4py.MPI
 
 
+@dataclass(frozen=True)
+class PartID:
+    """Unique identifier for a piece of a partitioned mesh.
+
+    .. attribute:: volume_tag
+
+        The volume of the part.
+
+    .. attribute:: rank
+
+        The (optional) MPI rank of the part.
+
+    """
+    volume_tag: VolumeTag
+    rank: Optional[int] = None
+
+
+# {{{ part ID normalization
+
+def _normalize_mesh_part_ids(
+        mesh: Mesh,
+        self_volume_tag: VolumeTag,
+        all_volume_tags: Sequence[VolumeTag],
+        mpi_communicator: Optional["mpi4py.MPI.Intracomm"] = None):
+    """Convert a mesh's configuration-dependent "part ID" into a fixed type."""
+    from numbers import Integral
+    if mpi_communicator is not None:
+        # Accept PartID or rank (assume intra-volume for the latter)
+        def as_part_id(mesh_part_id):
+            if isinstance(mesh_part_id, PartID):
+                return mesh_part_id
+            elif isinstance(mesh_part_id, Integral):
+                return PartID(self_volume_tag, int(mesh_part_id))
+            else:
+                raise TypeError(f"Unable to convert {mesh_part_id} to PartID.")
+    else:
+        # Accept PartID or volume tag
+        def as_part_id(mesh_part_id):
+            if isinstance(mesh_part_id, PartID):
+                return mesh_part_id
+            elif mesh_part_id in all_volume_tags:
+                return PartID(mesh_part_id)
+            else:
+                raise TypeError(f"Unable to convert {mesh_part_id} to PartID.")
+
+    facial_adjacency_groups = mesh.facial_adjacency_groups
+
+    new_facial_adjacency_groups = []
+
+    from meshmode.mesh import InterPartAdjacencyGroup
+    for grp_list in facial_adjacency_groups:
+        new_grp_list = []
+        for fagrp in grp_list:
+            if isinstance(fagrp, InterPartAdjacencyGroup):
+                part_id = as_part_id(fagrp.part_id)
+                new_fagrp = replace(
+                        fagrp,
+                        boundary_tag=BTAG_PARTITION(part_id),
+                        part_id=part_id)
+            else:
+                new_fagrp = fagrp
+            new_grp_list.append(new_fagrp)
+        new_facial_adjacency_groups.append(new_grp_list)
+
+    return mesh.copy(facial_adjacency_groups=new_facial_adjacency_groups)
+
+# }}}
+
+
 # {{{ discr_tag_to_group_factory normalization
 
 def _normalize_discr_tag_to_group_factory(
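
A note on the design: `PartID` is a frozen dataclass, so instances are immutable, compared by value, and hashable, which is what lets `(PartID, PartID)` pairs serve as keys of the connection mappings introduced further down. A minimal self-contained sketch of that property; the "fluid"/"wall" volume tags are made-up placeholders, and `str` stands in for grudge's `VolumeTag`:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass(frozen=True)
    class PartID:
        # Mirrors the class added above; str stands in for VolumeTag.
        volume_tag: str
        rank: Optional[int] = None

    # frozen=True gives value equality and a __hash__, so PartID pairs
    # can be used as dictionary keys for inter-part connections.
    conns = {}
    conns[PartID("fluid", 0), PartID("wall", 0)] = "placeholder connection"

    assert PartID("fluid", 0) == PartID("fluid", 0)
    assert PartID("fluid") != PartID("fluid", 0)  # rank participates in equality
    assert (PartID("fluid", 0), PartID("wall", 0)) in conns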
@@ -156,6 +228,9 @@ def __init__(self, array_context: ArrayContext,
             discr_tag_to_group_factory: Optional[
                 Mapping[DiscretizationTag, ElementGroupFactory]] = None,
             mpi_communicator: Optional["mpi4py.MPI.Intracomm"] = None,
+            inter_part_connections: Optional[
+                Mapping[Tuple[PartID, PartID],
+                    DiscretizationConnection]] = None,
             ) -> None:
         """
         :arg discr_tag_to_group_factory: A mapping from discretization tags
@@ -206,6 +281,9 @@ def __init__(self, array_context: ArrayContext,
 
             mesh = volume_discrs
 
+            mesh = _normalize_mesh_part_ids(
+                mesh, VTAG_ALL, [VTAG_ALL], mpi_communicator=mpi_communicator)
+
             discr_tag_to_group_factory = _normalize_discr_tag_to_group_factory(
                 dim=mesh.dim,
                 discr_tag_to_group_factory=discr_tag_to_group_factory,
@@ -219,17 +297,32 @@ def __init__(self, array_context: ArrayContext,
 
             del mesh
 
+            if inter_part_connections is not None:
+                raise TypeError("may not pass inter_part_connections when "
+                    "DiscretizationCollection constructor is called in "
+                    "legacy mode")
+
+            self._inter_part_connections = \
+                _set_up_inter_part_connections(
+                    array_context=self._setup_actx,
+                    mpi_communicator=mpi_communicator,
+                    volume_discrs=volume_discrs,
+                    base_group_factory=(
+                        discr_tag_to_group_factory[DISCR_TAG_BASE]))
+
             # }}}
         else:
             assert discr_tag_to_group_factory is not None
             self._discr_tag_to_group_factory = discr_tag_to_group_factory
 
-        self._volume_discrs = volume_discrs
+            if inter_part_connections is None:
+                raise TypeError("inter_part_connections must be passed when "
+                    "DiscretizationCollection constructor is called in "
+                    "'modern' mode")
+
+            self._inter_part_connections = inter_part_connections
 
-        self._dist_boundary_connections = {
-            vtag: self._set_up_distributed_communication(
-                vtag, mpi_communicator, array_context)
-            for vtag in self._volume_discrs.keys()}
+        self._volume_discrs = volume_discrs
 
         # }}}
 
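
To make the two constructor modes concrete, here is a hedged sketch of the intended call patterns; `actx`, `mesh`, `volume_discrs`, `factories`, and `connections` are placeholders, and the legacy signature is assumed from the surrounding code rather than shown in this diff:

    # Legacy mode (a raw mesh is passed): the collection builds its own
    # inter-part connections, so passing them explicitly raises TypeError.
    dcoll = DiscretizationCollection(actx, mesh, order=4)

    # "Modern" mode (prebuilt volume discretizations are passed): the caller,
    # normally make_discretization_collection, must supply the connections.
    dcoll = DiscretizationCollection(
        actx,
        volume_discrs,
        discr_tag_to_group_factory=factories,
        inter_part_connections=connections)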

@@ -252,71 +345,6 @@ def is_management_rank(self):
         return self.mpi_communicator.Get_rank() \
                 == self.get_management_rank_index()
 
-    # {{{ distributed
-
-    def _set_up_distributed_communication(
-            self, vtag, mpi_communicator, array_context):
-        from_dd = DOFDesc(VolumeDomainTag(vtag), DISCR_TAG_BASE)
-
-        boundary_connections = {}
-
-        from meshmode.distributed import get_connected_partitions
-        connected_parts = get_connected_partitions(self._volume_discrs[vtag].mesh)
-
-        if connected_parts:
-            if mpi_communicator is None:
-                raise RuntimeError("must supply an MPI communicator when using a "
-                        "distributed mesh")
-
-            grp_factory = \
-                self.group_factory_for_discretization_tag(DISCR_TAG_BASE)
-
-            local_boundary_connections = {}
-            for i_remote_part in connected_parts:
-                local_boundary_connections[i_remote_part] = self.connection_from_dds(
-                        from_dd, from_dd.trace(BTAG_PARTITION(i_remote_part)))
-
-            from meshmode.distributed import MPIBoundaryCommSetupHelper
-            with MPIBoundaryCommSetupHelper(mpi_communicator, array_context,
-                    local_boundary_connections, grp_factory) as bdry_setup_helper:
-                while True:
-                    conns = bdry_setup_helper.complete_some()
-                    if not conns:
-                        break
-                    for i_remote_part, conn in conns.items():
-                        boundary_connections[i_remote_part] = conn
-
-        return boundary_connections
-
-    def distributed_boundary_swap_connection(self, dd):
-        """Provides a mapping from the base volume discretization
-        to the exterior boundary restriction on a parallel boundary
-        partition described by *dd*. This connection is used to
-        communicate across element boundaries in different parallel
-        partitions during distributed runs.
-
-        :arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value
-            convertible to one. The domain tag must be a subclass
-            of :class:`grudge.dof_desc.BoundaryDomainTag` with an
-            associated :class:`meshmode.mesh.BTAG_PARTITION`
-            corresponding to a particular communication rank.
-        """
-        if dd.discretization_tag is not DISCR_TAG_BASE:
-            # FIXME
-            raise NotImplementedError(
-                "Distributed communication with discretization tag "
-                f"{dd.discretization_tag} is not implemented."
-            )
-
-        assert isinstance(dd.domain_tag, BoundaryDomainTag)
-        assert isinstance(dd.domain_tag.tag, BTAG_PARTITION)
-
-        vtag = dd.domain_tag.volume_tag
-
-        return self._dist_boundary_connections[vtag][dd.domain_tag.tag.part_nr]
-
-    # }}}
-
     # {{{ discr_from_dd
 
     @memoize_method
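
The per-volume `_dist_boundary_connections` table removed above (keyed by volume tag, then remote rank) is superseded by the flat `_inter_part_connections` mapping. Assuming the keys are `(remote PartID, local PartID)` pairs, as the setup code below suggests, the replacement lookup would read roughly as follows; `vtag`, `remote_rank`, and `local_rank` are placeholders:

    # Old (removed): two-level lookup by volume tag, then remote rank
    conn = dcoll._dist_boundary_connections[vtag][remote_rank]

    # New (assumed key convention): one flat lookup by PartID pair
    conn = dcoll._inter_part_connections[
        PartID(vtag, remote_rank), PartID(vtag, local_rank)]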
@@ -772,6 +800,105 @@ def normal(self, dd):
     # }}}
 
 
+# {{{ distributed/multi-volume setup
+
+def _set_up_inter_part_connections(
+        array_context: ArrayContext,
+        mpi_communicator: Optional["mpi4py.MPI.Intracomm"],
+        volume_discrs: Mapping[VolumeTag, Discretization],
+        base_group_factory: ElementGroupFactory,
+        ) -> Mapping[
+            Tuple[PartID, PartID],
+            DiscretizationConnection]:
+
+    from meshmode.distributed import (get_connected_parts,
+            make_remote_group_infos, InterRankBoundaryInfo,
+            MPIBoundaryCommSetupHelper)
+
+    rank = mpi_communicator.Get_rank() if mpi_communicator is not None else None
+
+    # Save boundary restrictions as they're created to avoid potentially creating
+    # them twice in the loop below
+    cached_part_bdry_restrictions: Mapping[
+            Tuple[PartID, PartID],
+            DiscretizationConnection] = {}
+
+    def get_part_bdry_restriction(self_part_id, other_part_id):
+        cached_result = cached_part_bdry_restrictions.get(
+                (self_part_id, other_part_id), None)
+        if cached_result is not None:
+            return cached_result
+        return cached_part_bdry_restrictions.setdefault(
+                (self_part_id, other_part_id),
+                make_face_restriction(
+                    array_context, volume_discrs[self_part_id.volume_tag],
+                    base_group_factory,
+                    boundary_tag=BTAG_PARTITION(other_part_id)))
+
+    inter_part_conns: Mapping[
+            Tuple[PartID, PartID],
+            DiscretizationConnection] = {}
+
+    irbis = []
+
+    for vtag, volume_discr in volume_discrs.items():
+        part_id = PartID(vtag, rank)
+        connected_part_ids = get_connected_parts(volume_discr.mesh)
+        for connected_part_id in connected_part_ids:
+            bdry_restr = get_part_bdry_restriction(
+                    self_part_id=part_id, other_part_id=connected_part_id)
+
+            if connected_part_id.rank == rank:
+                # {{{ rank-local interface between multiple volumes
+
+                connected_bdry_restr = get_part_bdry_restriction(
+                        self_part_id=connected_part_id, other_part_id=part_id)
+
+                from meshmode.discretization.connection import \
+                        make_partition_connection
+                inter_part_conns[connected_part_id, part_id] = \
+                        make_partition_connection(
+                            array_context,
+                            local_bdry_conn=bdry_restr,
+                            remote_bdry_discr=connected_bdry_restr.to_discr,
+                            remote_group_infos=make_remote_group_infos(
+                                array_context, part_id, connected_bdry_restr))
+
+                # }}}
+            else:
+                # {{{ cross-rank interface
+
+                if mpi_communicator is None:
+                    raise RuntimeError("must supply an MPI communicator "
+                            "when using a distributed mesh")
+
+                irbis.append(
+                        InterRankBoundaryInfo(
+                            local_part_id=part_id,
+                            remote_part_id=connected_part_id,
+                            remote_rank=connected_part_id.rank,
+                            local_boundary_connection=bdry_restr))
+
+                # }}}
+
+    if irbis:
+        assert mpi_communicator is not None
+
+        with MPIBoundaryCommSetupHelper(mpi_communicator, array_context,
+                irbis, base_group_factory) as bdry_setup_helper:
+            while True:
+                conns = bdry_setup_helper.complete_some()
+                if not conns:
+                    # We're done.
+                    break
+
+                inter_part_conns.update(conns)
+
+    return inter_part_conns
+
+# }}}
+
+
 # {{{ modal group factory
 
 def _generate_modal_group_factory(nodal_group_factory):
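
Two details worth noting in `_set_up_inter_part_connections`: rank-local connections are stored under `(connected_part_id, part_id)`, that is, keyed by the remote part first and the local part second, and both directions of each rank-local interface end up filled in because each of the two volumes inserts its own entry as the loop visits it. The boundary-restriction cache is a plain get-or-create dictionary; a self-contained sketch of that pattern, with `build_restriction` standing in for the expensive `make_face_restriction` call:

    cache = {}

    def build_restriction(key):
        # Stand-in for the expensive make_face_restriction call.
        print(f"building {key}")
        return f"restriction for {key}"

    def get_restriction(self_part_id, other_part_id):
        key = (self_part_id, other_part_id)
        result = cache.get(key)
        if result is not None:
            return result
        # setdefault stores the freshly built value and returns it
        return cache.setdefault(key, build_restriction(key))

    get_restriction("fluid", "wall")  # builds and caches
    get_restriction("fluid", "wall")  # returns the cached value, no rebuild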
@@ -860,6 +987,8 @@ def make_discretization_collection(
 
     del order
 
+    mpi_communicator = getattr(array_context, "mpi_communicator", None)
+
     if any(
             isinstance(mesh_or_discr, Discretization)
             for mesh_or_discr in volumes.values()):
@@ -868,14 +997,20 @@ def make_discretization_collection(
     volume_discrs = {
         vtag: Discretization(
             array_context,
-            mesh,
+            _normalize_mesh_part_ids(
+                mesh, vtag, volumes.keys(), mpi_communicator=mpi_communicator),
             discr_tag_to_group_factory[DISCR_TAG_BASE])
         for vtag, mesh in volumes.items()}
 
     return DiscretizationCollection(
         array_context=array_context,
         volume_discrs=volume_discrs,
-        discr_tag_to_group_factory=discr_tag_to_group_factory)
+        discr_tag_to_group_factory=discr_tag_to_group_factory,
+        inter_part_connections=_set_up_inter_part_connections(
+            array_context=array_context,
+            mpi_communicator=mpi_communicator,
+            volume_discrs=volume_discrs,
+            base_group_factory=discr_tag_to_group_factory[DISCR_TAG_BASE]))
 
 # }}}
 
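
Finally, a hedged sketch of how a caller might exercise the multi-volume path through `make_discretization_collection`; the "fluid"/"wall" tags, the mesh objects, and the array context are placeholders, and the keyword names are assumed from the surrounding code:

    from grudge.discretization import make_discretization_collection

    # volumes maps each volume tag to its mesh; inter-part connections
    # (rank-local between volumes, and cross-rank) are then set up
    # automatically via _set_up_inter_part_connections.
    dcoll = make_discretization_collection(
        actx,                                    # an ArrayContext
        volumes={"fluid": fluid_mesh, "wall": wall_mesh},
        order=3)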
