diff --git a/firedrake/preconditioners/asm.py b/firedrake/preconditioners/asm.py index de537cf123..0aa51b4c9f 100644 --- a/firedrake/preconditioners/asm.py +++ b/firedrake/preconditioners/asm.py @@ -6,6 +6,7 @@ from firedrake.dmhooks import get_function_space from firedrake.mesh import DistributedMeshOverlapType from firedrake.logging import warning +from firedrake.exceptions import NonUniqueMeshSequenceError from tinyasm import _tinyasm as tinyasm from mpi4py import MPI import numpy @@ -152,52 +153,30 @@ class ASMStarPC(ASMPatchPC): _prefix = "pc_star_" def get_patches(self, V): - mesh = V._mesh - if len(set(mesh)) == 1: - mesh_unique = mesh.unique() - else: + try: + mesh = V.mesh().unique() + except NonUniqueMeshSequenceError: raise NotImplementedError("Not implemented for general mixed meshes") - mesh_dm = mesh_unique.topology_dm - if mesh_unique.cell_set._extruded: + mesh_dm = mesh.topology_dm + if mesh.cell_set._extruded: warning("applying ASMStarPC on an extruded mesh") # Obtain the topological entities to use to construct the stars opts = PETSc.Options(self.prefix) depth = opts.getInt("construct_dim", default=0) + validate_overlap(mesh, depth, "star") + + use_coloring = opts.getBool("use_coloring", default=False) ordering = opts.getString("mat_ordering_type", default="natural") - validate_overlap(mesh_unique, depth, "star") # Accessing .indices causes the allocation of a global array, # so we need to cache these for efficiency V_local_ises_indices = tuple(iset.indices for iset in V.dof_dset.local_ises) # Build index sets for the patches - ises = [] - (start, end) = mesh_dm.getDepthStratum(depth) - for seed in range(start, end): - # Only build patches over owned DoFs - if mesh_dm.getLabelValue("pyop2_ghost", seed) != -1: - continue - - # Create point list from mesh DM - pt_array, _ = mesh_dm.getTransitiveClosure(seed, useCone=False) - pt_array = order_points(mesh_dm, pt_array, ordering, self.prefix) - - # Get DoF indices for patch - indices = [] - for (i, W) in enumerate(V): - section = W.dm.getDefaultSection() - for p in pt_array.tolist(): - dof = section.getDof(p) - if dof <= 0: - continue - off = section.getOffset(p) - # Local indices within W - W_indices = slice(off*W.block_size, W.block_size * (off + dof)) - indices.extend(V_local_ises_indices[i][W_indices]) - iset = PETSc.IS().createGeneral(indices, comm=PETSc.COMM_SELF) - ises.append(iset) - + colors = get_colors(mesh, use_coloring, depth, distance=1) + ises = [build_star_indices(V, V_local_ises_indices, mesh_dm, ordering, self.prefix, color) + for color in colors] return ises @@ -213,13 +192,12 @@ class ASMVankaPC(ASMPatchPC): _prefix = "pc_vanka_" def get_patches(self, V): - mesh = V._mesh - if len(set(mesh)) == 1: - mesh_unique = mesh.unique() - else: + try: + mesh = V.mesh().unique() + except NonUniqueMeshSequenceError: raise NotImplementedError("Not implemented for general mixed meshes") - mesh_dm = mesh_unique.topology_dm - if mesh_unique.layers: + mesh_dm = mesh.topology_dm + if mesh.layers: warning("applying ASMVankaPC on an extruded mesh") # Obtain the topological entities to use to construct the stars @@ -228,62 +206,33 @@ def get_patches(self, V): height = opts.getInt("construct_codim", default=-1) if (depth == -1 and height == -1) or (depth != -1 and height != -1): raise ValueError(f"Must set exactly one of {self.prefix}construct_dim or {self.prefix}construct_codim") + if depth == -1: + depth = mesh_dm.getDimension() - height + validate_overlap(mesh, depth, "vanka") exclude_subspaces = list(map(int, opts.getString("exclude_subspaces", default="-1").split(","))) + include_subspaces = [i for i in range(len(V)) if i not in exclude_subspaces] include_type = opts.getString("include_type", default="star").lower() if include_type not in ["star", "entity"]: raise ValueError(f"{self.prefix}include_type must be either 'star' or 'entity', not {include_type}") include_star = include_type == "star" + use_coloring = opts.getBool("use_coloring", default=False) ordering = opts.getString("mat_ordering_type", default="natural") + + def splitting(V): + return (tuple(V[i] for i in include_subspaces), tuple(V[i] for i in exclude_subspaces)) + + Z = splitting(V) # Accessing .indices causes the allocation of a global array, # so we need to cache these for efficiency V_local_ises_indices = tuple(iset.indices for iset in V.dof_dset.local_ises) + Z_local_ises_indices = splitting(V_local_ises_indices) # Build index sets for the patches - ises = [] - if depth != -1: - (start, end) = mesh_dm.getDepthStratum(depth) - patch_dim = depth - else: - (start, end) = mesh_dm.getHeightStratum(height) - patch_dim = mesh_dm.getDimension() - height - validate_overlap(mesh_unique, patch_dim, "vanka") - - for seed in range(start, end): - # Only build patches over owned DoFs - if mesh_dm.getLabelValue("pyop2_ghost", seed) != -1: - continue - - # Create point list from mesh DM - star, _ = mesh_dm.getTransitiveClosure(seed, useCone=False) - star = order_points(mesh_dm, star, ordering, self.prefix) - pt_array = [] - for pt in reversed(star): - closure, _ = mesh_dm.getTransitiveClosure(pt, useCone=True) - pt_array.extend(closure) - # Grab unique points with stable ordering - pt_array = list(reversed(dict.fromkeys(pt_array))) - - # Get DoF indices for patch - indices = [] - for (i, W) in enumerate(V): - section = W.dm.getDefaultSection() - if i in exclude_subspaces: - loop_list = star if include_star else [seed] - else: - loop_list = pt_array - for p in loop_list: - dof = section.getDof(p) - if dof <= 0: - continue - off = section.getOffset(p) - # Local indices within W - W_indices = slice(off*W.block_size, W.block_size * (off + dof)) - indices.extend(V_local_ises_indices[i][W_indices]) - iset = PETSc.IS().createGeneral(indices, comm=PETSc.COMM_SELF) - ises.append(iset) - + colors = get_colors(mesh, use_coloring, depth, distance=2) + ises = [build_vanka_indices(Z, Z_local_ises_indices, mesh_dm, ordering, self.prefix, + include_star, color) for color in colors] return ises @@ -309,13 +258,12 @@ class ASMLinesmoothPC(ASMPatchPC): _prefix = "pc_linesmooth_" def get_patches(self, V): - mesh = V._mesh - if len(set(mesh)) == 1: - mesh_unique = mesh.unique() - else: + try: + mesh = V.mesh().unique() + except NonUniqueMeshSequenceError: raise NotImplementedError("Not implemented for general mixed meshes") - assert mesh_unique.cell_set._extruded - dm = mesh_unique.topology_dm + assert mesh.cell_set._extruded + dm = mesh.topology_dm section = V.dm.getDefaultSection() # Obtain the codimensions to loop over from options, if present opts = PETSc.Options(self.prefix) @@ -419,14 +367,13 @@ class ASMExtrudedStarPC(ASMStarPC): _prefix = 'pc_star_' def get_patches(self, V): - mesh = V.mesh() - if len(set(mesh)) == 1: - mesh_unique = mesh.unique() - else: + try: + mesh = V.mesh().unique() + except NonUniqueMeshSequenceError: raise NotImplementedError("Not implemented for general mixed meshes") - mesh_dm = mesh_unique.topology_dm - nlayers = mesh_unique.layers - if not mesh_unique.cell_set._extruded: + mesh_dm = mesh.topology_dm + nlayers = mesh.layers + if not mesh.cell_set._extruded: return super(ASMExtrudedStarPC, self).get_patches(V) periodic = mesh.extruded_periodic @@ -434,6 +381,7 @@ def get_patches(self, V): opts = PETSc.Options(self.prefix) depth = opts.getInt("construct_dim", default=0) ordering = opts.getString("mat_ordering_type", default="natural") + use_coloring = opts.getBool("use_coloring", default=False) # Accessing .indices causes the allocation of a global array, # so we need to cache these for efficiency @@ -475,59 +423,57 @@ def get_patches(self, V): else: continue - validate_overlap(mesh_unique, base_depth, "star") - start, end = mesh_dm.getDepthStratum(base_depth) - for seed in range(start, end): - # Only build patches over owned DoFs - if mesh_dm.getLabelValue("pyop2_ghost", seed) != -1: - continue + validate_overlap(mesh, base_depth, "star") - # Create point list from mesh DM - points, _ = mesh_dm.getTransitiveClosure(seed, useCone=False) - points = order_points(mesh_dm, points, ordering, self.prefix) + num_layer_seeds = nlayers-1 if (periodic or interval_depth) else nlayers + num_layer_colors = 2 if use_coloring else num_layer_seeds + + colors = get_colors(mesh, use_coloring, base_depth, distance=1) + for color in colors: + points = get_star_points(mesh_dm, ordering, self.prefix, color) + if len(points) == 0: + continue + points = numpy.asarray(points) points -= pstart # offset by chart start - num_seeds = nlayers - if periodic or interval_depth: - num_seeds -= 1 - for layer_seed in range(num_seeds): + for layer_color in range(num_layer_colors): indices = [] - # Get DoF indices for patch - for i, W in enumerate(V): - iset = V_ises[i] - for layer_dim, layer_shift in layer_entities: - layer = layer_seed - layer_shift - if periodic: - # Handle periodic case - layer = layer % (nlayers-1) - elif layer < 0 or (layer + layer_dim) >= nlayers: - # We are out of bounds - continue - - for p in points: - # How to walk up one layer - blayer_offset = basemeshlayeroffsets[i][p] - if blayer_offset <= 0: - # In this case we don't have any dofs on - # this entity. + for layer_seed in range(layer_color, num_layer_seeds, num_layer_colors): + # Get DoF indices for patch + for i, W in enumerate(V): + iset = V_ises[i] + for layer_dim, layer_shift in layer_entities: + layer = layer_seed - layer_shift + if periodic: + # Handle periodic case + layer = layer % (nlayers-1) + elif layer < 0 or (layer + layer_dim) >= nlayers: + # We are out of bounds continue - # Offset in the global array for the bottom of - # the column - off = basemeshoff[i][p] - # Number of dofs in the interior of the - # vertical interval cell on top of this base - # entity - dof = basemeshdof[i][p] - # Hard-code taking the star - if layer_dim == 0: - begin = off + layer * blayer_offset - end = off + layer * blayer_offset + dof - else: - begin = off + layer * blayer_offset + dof - end = off + (layer + 1) * blayer_offset - zlice = slice(W.block_size * begin, W.block_size * end) - indices.extend(iset[zlice]) + for p in points: + # How to walk up one layer + blayer_offset = basemeshlayeroffsets[i][p] + if blayer_offset <= 0: + # In this case we don't have any dofs on + # this entity. + continue + # Offset in the global array for the bottom of + # the column + off = basemeshoff[i][p] + # Number of dofs in the interior of the + # vertical interval cell on top of this base + # entity + dof = basemeshdof[i][p] + # Hard-code taking the star + if layer_dim == 0: + begin = off + layer * blayer_offset + end = off + layer * blayer_offset + dof + else: + begin = off + layer * blayer_offset + dof + end = off + (layer + 1) * blayer_offset + zlice = slice(W.block_size * begin, W.block_size * end) + indices.extend(iset[zlice]) iset = PETSc.IS().createGeneral(indices, comm=PETSc.COMM_SELF) ises.append(iset) return ises @@ -554,3 +500,97 @@ def validate_overlap(mesh, patch_dim, patch_type): if overlap_depth < patch_depth: warning(f"Mesh overlap depth of {overlap_depth} does not support {patch_type}-patches. " "Did you forget to set overlap_type in your mesh's distribution_parameters?") + + +def get_colors(mesh, use_coloring, depth, distance=1): + """For a given entity dimension (depth), constructs a coloring of the + entities if use_coloring=True, otherwise returns all entities visible by + this process. + """ + mesh_dm = mesh.topology_dm + point_subset = None + if use_coloring: + colors = mesh_dm.createColoring(depth=depth, distance=distance) + if point_subset is not None: + colors = tuple(numpy.intersect1d(point_subset, color.indices) for color in colors) + else: + if point_subset is None: + colors = range(*mesh_dm.getDepthStratum(depth)) + else: + colors = point_subset + return colors + + +def get_entity_dofs(V, V_local_ises_indices, points): + """Extract degrees of freedom associated to mesh entities (points of the DMPlex).""" + indices = [] + for (i, W) in enumerate(V): + section = W.dm.getDefaultSection() + for p in points: + dof = section.getDof(p) + if dof <= 0: + continue + off = section.getOffset(p) + # Local indices within W + W_slice = slice(off*W.block_size, W.block_size * (off + dof)) + indices.extend(V_local_ises_indices[i][W_slice]) + return indices + + +def get_star_points(mesh_dm, ordering, prefix, seed_points): + """Get DMPlex points in the star of each point in seed_points.""" + if isinstance(seed_points, PETSc.IS): + seed_points = seed_points.indices + elif numpy.isscalar(seed_points): + seed_points = (seed_points,) + points = [] + for seed in seed_points: + # Only build patches over owned DoFs + if mesh_dm.getLabelValue("pyop2_ghost", seed) != -1: + continue + # Create point list from mesh DM + star, _ = mesh_dm.getTransitiveClosure(seed, useCone=False) + star = order_points(mesh_dm, star, ordering, prefix) + points.extend(star) + return points + + +def build_star_indices(V, V_local_ises_indices, mesh_dm, ordering, prefix, seed_points): + """Build index sets for star patches.""" + points = get_star_points(mesh_dm, ordering, prefix, seed_points) + indices = get_entity_dofs(V, V_local_ises_indices, points) + iset = PETSc.IS().createGeneral(indices, comm=PETSc.COMM_SELF) + return iset + + +def build_vanka_indices(Z, Z_local_ises_indices, mesh_dm, ordering, prefix, include_star, seed_points): + """Build index sets for Vanka patches.""" + if isinstance(seed_points, PETSc.IS): + seed_points = seed_points.indices + elif numpy.isscalar(seed_points): + seed_points = (seed_points,) + V_points = [] + Q_points = [] + for seed in seed_points: + # Only build patches over owned DoFs + if mesh_dm.getLabelValue("pyop2_ghost", seed) != -1: + continue + # Create point list from mesh DM + star, _ = mesh_dm.getTransitiveClosure(seed, useCone=False) + star = order_points(mesh_dm, star, ordering, prefix) + if include_star: + Q_points.extend(star) + else: + Q_points.append(seed) + + closure = [] + for s in reversed(star): + cs, _ = mesh_dm.getTransitiveClosure(s, useCone=True) + closure.extend(cs) + # Grab unique points with stable ordering + V_points.extend(reversed(dict.fromkeys(closure))) + + indices = get_entity_dofs(Z[0], Z_local_ises_indices[0], V_points) + indices.extend(get_entity_dofs(Z[1], Z_local_ises_indices[1], Q_points)) + iset = PETSc.IS().createGeneral(indices, comm=PETSc.COMM_SELF) + return iset diff --git a/tests/firedrake/regression/test_star_pc.py b/tests/firedrake/regression/test_star_pc.py index d6e1383163..9f00caa979 100644 --- a/tests/firedrake/regression/test_star_pc.py +++ b/tests/firedrake/regression/test_star_pc.py @@ -11,7 +11,7 @@ def problem_type(request): return request.param -@pytest.fixture(params=["petscasm", "tinyasm"]) +@pytest.fixture(params=["petscasm", "tinyasm", "petscasm-coloring"]) def backend(request): return request.param @@ -25,6 +25,10 @@ def filter_warnings(caller): def test_star_equivalence(problem_type, backend): distribution_parameters = {"partition": True, "overlap_type": (DistributedMeshOverlapType.VERTEX, 1)} + use_coloring = False + if backend.endswith("coloring"): + backend, _ = backend.split("-") + use_coloring = True if problem_type == "scalar": base = UnitSquareMesh(10, 10, distribution_parameters=distribution_parameters) @@ -169,6 +173,7 @@ def test_star_equivalence(problem_type, backend): "mg_coarse_mat_type": "aij", "mg_coarse_pc_type": "lu"} + star_params["mg_levels_pc_star_use_coloring"] = use_coloring star_params["mg_levels_pc_star_backend"] = backend star_params["mg_levels_pc_star_mat_ordering_type"] = "rcm" nvproblem = NonlinearVariationalProblem(a, u, bcs=bcs) @@ -182,6 +187,10 @@ def test_star_equivalence(problem_type, backend): comp_its = comp_solver.snes.getLinearSolveIterations() assert star_its == comp_its + if use_coloring: + l = len(mh) - 1 + star_patches = len(star_solver.snes.ksp.pc.getMGSmoother(l).pc.getPythonContext().asmpc.getASMSubKSP()) + assert star_patches < 12 def test_vanka_equivalence(problem_type):