Skip to content

Conversation

@copybara-service
Copy link

[graphcast] Prepare for jax_pmap_shmap_merge=True.

This change:

  • Inlines jax.device_put_sharded to take an axis_name so we can make sure the resulting axis name in the NamedSharding matches what the new pmap expects. _device_put_sharded falls back to the original jax.device_put_sharded when jax_pmap_shmap_merge=False.
  • Strips sharding metadata from RNG keys after pmapped split_rng_fn calls. There is a small performance penalty and not ideal, but is less intrusive of a change than making sure all user predictor_fns pmap with the same devices.

See https://docs.jax.dev/en/latest/migrate_pmap.html for more information.

This change:
- Inlines `jax.device_put_sharded` to take an `axis_name` so we can make sure the resulting axis name in the `NamedSharding` matches what the new pmap expects. `_device_put_sharded` falls back to the original `jax.device_put_sharded` when `jax_pmap_shmap_merge=False`.
- Strips sharding metadata from RNG keys after pmapped `split_rng_fn` calls. There is a small performance penalty and not ideal, but is less intrusive of a change than making sure all user `predictor_fn`s `pmap` with the same devices.

See https://docs.jax.dev/en/latest/migrate_pmap.html for more information.

PiperOrigin-RevId: 846566054
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant