Skip to content

Commit b062a56

Browse files
committed
fix NaN in fp32 sampling
1 parent 3db131b commit b062a56

File tree

3 files changed

+3
-2
lines changed

3 files changed

+3
-2
lines changed

MODEL_CARD.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,4 @@ params, transformer = make_transformer(
4848

4949
## Speeds, Sizes, Times
5050
- Both models contain ~13.8M parameters
51-
- Generating 29,000 crystal samples on a single A100 GPU takes ~1,058 seconds (~37 ms per sample)
51+
- Generating 45,000 crystal samples on a single A100 GPU takes ~440 seconds (~10 ms per sample)

crystalformer/src/sample.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def sample_x(key, h_x, Kx, top_p, temperature, batchsize):
5555
k = sample_top_p(key_k, x_logit, top_p, temperature)
5656
loc = loc.reshape(batchsize, Kx)[jnp.arange(batchsize), k]
5757
kappa = kappa.reshape(batchsize, Kx)[jnp.arange(batchsize), k]
58+
kappa = jnp.clip(kappa, a_min=1e-6) # to avoid numerical issue
5859
x = sample_von_mises(key_x, loc, kappa/temperature, (batchsize,))
5960
x = (x+ jnp.pi)/(2.0*jnp.pi) # wrap into [0, 1]
6061
return key, x

main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@
188188
else:
189189

190190
print("\n========== Start sampling ==========")
191-
jax.config.update("jax_enable_x64", True) # to get off compilation warning, and to prevent sample nan lattice
191+
# jax.config.update("jax_enable_x64", True) # to get off compilation warning, and to prevent sample nan lattice
192192
#FYI, the error was [Compiling module extracted] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
193193

194194
if args.formula is not None:

0 commit comments

Comments
 (0)