Skip to content

Commit dfab348

Browse files
committed
more tutorials
1 parent e72934e commit dfab348

6 files changed

Lines changed: 201 additions & 8 deletions

File tree

docs/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ jaxion
106106
pages/tutorial3
107107
pages/tutorial4
108108
pages/tutorial5
109+
pages/tutorial6
109110

110111
.. toctree::
111112
:maxdepth: 1

docs/pages/tutorial1.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,13 @@ It will also often be useful to import `jax.numpy <https://docs.jax.dev/en/lates
2525
2626
to set up initial conditions, which as JAX arrays.
2727

28+
By default, JAX runs in single-precision mode, which is sufficient for many Jaxion simulations.
29+
But if you'd like to run in double-precision mode, you can set:
30+
31+
.. code-block:: python
32+
33+
jax.config.update("jax_enable_x64", True)
34+
2835
The basic steps are:
2936

3037
(1) set simulation parameters,
@@ -59,6 +66,7 @@ Let's look at the parameters for this example:
5966
"time": {
6067
"start": 0.0,
6168
"end": 1.0,
69+
"safety_factor": 1.0
6270
},
6371
"output": {
6472
"path": f"./checkpoints",
@@ -85,6 +93,8 @@ with a base resolution of 32 grid cells per side,
8593
and a resolution multiplier of 2 (meaning the effective resolution is 64 grid cells per side).
8694

8795
In the ``"time"`` section, we set the simulation to start at time 0.0 and end at time 1.0 (in code units (kpc / (km/s))).
96+
The time step in a Jaxion simulation is, by default, the quantum kinetically-limited time step,
97+
dt = (m_per_hbar / 6.0) * (dx * dx), which can be scaled by the ``safety_factor`` parameter (here set to 1.0).
8898

8999
In the ``"output"`` section, we specify the output path for saving simulation checkpoints,
90100
the number of checkpoints to save (100), and enable saving.

docs/pages/tutorial3.rst

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,29 @@
1-
Tutorial 3: Custom Callbacks
2-
============================
1+
Tutorial 3: Sponge Boundary Conditions
2+
======================================
33

4-
TODO XXX
4+
This tutorial describes how to add sponge boundary conditions.
5+
By default, Jaxion uses periodic boundary conditions.
6+
Sponge boundary conditions can be used to absorb outgoing waves at the domain boundaries.
7+
8+
An example of sponge boundary conditions is provided in `examples/tidal_stripping <https://github.com/JaxionProject/jaxion/tree/main/examples/tidal_stripping>`_
9+
10+
It is created by adding an external potential to the simulation.
11+
We need to set ``params["physics"]["external_potential"] = True``.
12+
13+
.. code-block:: python
14+
15+
# V_0 is the depth of the sponge potential ~ G * M_tot / box_size
16+
# r_N, r_p, and r_s define the sponge region
17+
r_N = 0.5 * box_size
18+
r_p = (7 / 8) * r_N
19+
r_s = 0.5 * (r_N + r_p)
20+
delta = r_N - r_p
21+
V_sponge = (
22+
-0.5j
23+
* V_0
24+
* (2 + jnp.tanh((R - r_s) / delta) - jnp.tanh(r_s / delta))
25+
* jnp.heaviside(R - r_p, 0.0)
26+
)
27+
sim.state["V_ext"] = V_sponge
28+
29+
Since the sponge potential is imaginary, it will absorb outgoing waves in the sponge region.

docs/pages/tutorial4.rst

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,39 @@
1-
Tutorial 4: Custom Fields
2-
=========================
1+
Tutorial 4: Custom Callbacks
2+
============================
33

4-
TODO XXX
4+
This tutorial describes how to create custom callbacks in Jaxion simulations.
5+
The callback function is called at each time step during the simulation.
6+
It can be used to perform custom calculations or save info on the simulation state at each step.
7+
8+
An example of a custom callback is provided in `examples/black_hole_accretion <https://github.com/JaxionProject/jaxion/tree/main/examples/black_hole_accretion>`_
9+
where the callback is used to save the mass of a black hole at each time step.
10+
11+
It is simply done by creating new simulation state variables and defining a callback function:
12+
13+
.. code-block:: python
14+
15+
# add callback to record info about state
16+
n_buffer = sim.nt + 1
17+
sim.state["tt"] = jnp.full((n_buffer,), jnp.nan)
18+
sim.state["m_bh"] = jnp.full((n_buffer,), jnp.nan)
19+
sim.state["tt"] = sim.state["tt"].at[0].set(0.0)
20+
sim.state["m_bh"] = sim.state["m_bh"].at[0].set(M_bh)
21+
sim.callback = callback
22+
23+
def callback(i, state):
24+
# record the black hole mass at end of timestep i
25+
state["tt"] = state["tt"].at[i + 1].set(state["t"])
26+
state["m_bh"] = state["m_bh"].at[i + 1].set(state["mass"][0])
27+
return state
28+
29+
The callback function takes two arguments:
30+
- `i`: the current time step index
31+
- `state`: the current simulation state
32+
The function should return the updated state.
33+
34+
In this example, the black hole mass looks something like this:
35+
36+
.. figure:: ../../examples/black_hole_accretion/callback.png
37+
:width: 300px
38+
:align: center
39+
:alt: black_hole_accretion callback

docs/pages/tutorial5.rst

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,58 @@
1-
Tutorial 5: Multiple GPUs
1+
Tutorial 5: Custom Fields
22
=========================
33

4-
TODO XXX
4+
This tutorial describes how to create custom quantum dark matter fields in Jaxion simulations.
5+
An example is provided in `examples/two_field <https://github.com/JaxionProject/jaxion/tree/main/examples/two_field>`_
6+
7+
First, ``params["physics"]["quantum"]`` should be set to ``False`` since we will be adding out own fields.
8+
9+
For example, to create a two-field simulation, we can add the simulation states:
10+
``sim.state["psi1"]`` and ``sim.state["psi2"]``.
11+
12+
Then, we need to define custom kick, drift, and total density functions, and attach them to the simulation object:
13+
14+
.. code-block:: python
15+
16+
def custom_density(state):
17+
return jnp.abs(state["psi1"]) ** 2 + jnp.abs(state["psi2"]) ** 2
18+
19+
def custom_kick(state, V, dt):
20+
state["psi1"] = jnp.exp(-1j * m1_per_hbar * dt * V) * state["psi1"]
21+
state["psi2"] = jnp.exp(-1j * m2_per_hbar * dt * V) * state["psi2"]
22+
23+
return state
24+
25+
def custom_drift(state, k_sq, dt):
26+
psi1_hat = jd.fft.pfft3d(state["psi1"])
27+
psi1_hat = jnp.exp(dt * (-1.0j * k_sq / m1_per_hbar / 2.0)) * psi1_hat
28+
state["psi1"] = jd.fft.pifft3d(psi1_hat)
29+
30+
psi2_hat = jd.fft.pfft3d(state["psi2"])
31+
psi2_hat = jnp.exp(dt * (-1.0j * k_sq / m2_per_hbar / 2.0)) * psi2_hat
32+
state["psi2"] = jd.fft.pifft3d(psi2_hat)
33+
34+
return state
35+
36+
sim.custom_density = custom_density
37+
sim.custom_kick = custom_kick
38+
sim.custom_drift = custom_drift
39+
40+
In this example, the custom density function is used to compute the total density from both fields.
41+
The custom kick and drift functions are used to update the fields during the simulation.
42+
43+
The simulation can then be run as usual with ``sim.run()``.
44+
45+
It is also possible to add custom plotting for these fields by defining a custom plotting function and attaching it to the simulation object:
46+
``custom_plot(state, checkpoint_dir, i, params)``.
47+
48+
The Two-Field example evolves two quantum fields with different masses that interact gravitationally, and looks something like this:
49+
50+
.. figure:: ../../examples/two_field/rho1_070.png
51+
:width: 300px
52+
:align: center
53+
:alt: two_field field 1
54+
55+
.. figure:: ../../examples/two_field/rho2_070.png
56+
:width: 300px
57+
:align: center
58+
:alt: two_field field 2

docs/pages/tutorial6.rst

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
Tutorial 6: Multiple GPUs
2+
=========================
3+
4+
This tutorial describes how to run Jaxion simulations on distributed GPUs.
5+
An example is provided in `examples/soliton_binary_merger <https://github.com/JaxionProject/jaxion/tree/main/examples/soliton_binary_merger>`_
6+
7+
First, ensure that JAX is initialized for distributed GPU use, by calling:
8+
9+
.. code-block:: python
10+
11+
import jax
12+
13+
jax.distributed.initialize()
14+
15+
On some machines, you may need to specify additional parameters to the ``initialize()`` function.
16+
For most SLURM clusters, you will not need to specify anything.
17+
18+
If you'd like your Python script to print info, it should by guarded by:
19+
20+
.. code-block:: python
21+
22+
if jax.process_index() == 0:
23+
print("Using distributed GPU mode")
24+
25+
to prevent multiple processes from printing simultaneously.
26+
27+
28+
Sharding
29+
--------
30+
31+
We need to set up sharding for the simulation state arrays.
32+
Sharding splits the arrays across multiple devices for distributed computation.
33+
This can be done as follows:
34+
35+
.. code-block:: python
36+
37+
from jax.experimental import mesh_utils
38+
from jax.sharding import Mesh, PartitionSpec, NamedSharding
39+
40+
# Create mesh and sharding for distributed computation
41+
n_devices = jax.device_count()
42+
devices = mesh_utils.create_device_mesh((1, n_devices))
43+
mesh = Mesh(devices, axis_names=("x", "y"))
44+
sharding = NamedSharding(mesh, PartitionSpec("x", "y"))
45+
46+
In the example above, we create a virtual device mesh with one row and ``n_devices`` columns.
47+
Arrays (2D and 3D) are split along the "y" axis across the devices.
48+
49+
When creating the simulation object, we need to pass the sharding to it:
50+
51+
.. code-block:: python
52+
53+
sim = jaxion.Simulation(params, sharding=sharding)
54+
55+
And that's it! Jaxion will now run the simulation on multiple GPUs.
56+
57+
If you grab arrays from the simulation, such as the grid (``sim.grid``) or spectral grid (``sim.kgrid``),
58+
these will be sharded arrays too.
59+
60+
61+
Slurm Example
62+
-------------
63+
64+
An example SLURM submission script for running our example on the Flatiron Rusty cluster
65+
is provided below:
66+
67+
.. literalinclude:: ../../examples/soliton_binary_merger/sbatch_rusty_distributed.sh
68+
:language: bash

0 commit comments

Comments
 (0)