Conversation
```python
momentum, _ = jax.flatten_util.ravel_pytree(momentum)
kinetic = 0.5 * jnp.dot(momentum, momentum)
hamiltonian = kinetic + state.log_prob
accept_prob = jnp.minimum(1.0, jnp.exp(hamiltonian - state.hamiltonian))
```
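For context, here is a minimal runnable sketch of the kinetic-energy computation in the hunk above, using a hypothetical pytree of momentum variables (the names `momentum`, `w`, `b` are illustrative, not from the PR):

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical pytree of momentum variables, e.g. per-parameter arrays.
momentum = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}

# Flatten the pytree into a single vector, as the diff does.
flat, _ = ravel_pytree(momentum)

# Gaussian kinetic energy with identity mass matrix: 0.5 * p^T p.
kinetic = 0.5 * jnp.dot(flat, flat)
print(kinetic)  # 3.0: six ones contribute 0.5 each, the zeros nothing
```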
As discussed, you can avoid the `minimum` and the exponential here. You can define

```python
log_accept_ratio = hamiltonian - state.hamiltonian
```

See later for the accept/reject part.
```python
return revert_updates, state.params, state.hamiltonian


updates, new_params, new_hamiltonian = jax.lax.cond(
    jax.random.uniform(uniform_key) < accept_prob,
```
Following the comment above, this line should become

```python
jnp.log(jax.random.uniform(uniform_key)) < log_accept_ratio
```

This is equivalent to what you have written but with one operation fewer. Alternatively, notice that `-log(U) ~ Exponential(1)` if `U ~ Uniform(0, 1)`. This means that you can also write

```python
-jax.random.exponential(uniform_key) < log_accept_ratio
```

All of these should be equivalent. Please check that the lines I wrote are correct :-)
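A quick sketch checking the first equivalence: since `log` is monotone, comparing `u` against `min(1, exp(log_ratio))` gives the same decision as comparing `log(u)` against `log_ratio` (the `min(1, .)` is redundant because `log(u) < 0` almost surely). The exponential variant is equivalent in distribution only, not pointwise for the same key, so it is not checked here. The value of `log_accept_ratio` below is an arbitrary example:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
log_accept_ratio = jnp.log(0.3)  # example value; any real number works

u = jax.random.uniform(key)

# Original form: compare U against min(1, exp(log_ratio)).
accept_prob = jnp.minimum(1.0, jnp.exp(log_accept_ratio))
original = u < accept_prob

# Log-space form: one fewer op, same decision.
log_space = jnp.log(u) < log_accept_ratio

print(bool(original) == bool(log_space))  # True
```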
| """ | ||
|
|
||
| encoded_name: jnp.ndarray = convert_string_to_jnp_array("HMCState") | ||
| _encoded_which_params: Optional[Dict[str, List[Array]]] = None |
I was expecting to see the stored `_hamiltonian` here too?
```python
    **kwargs,
)
state = state.replace(
    opt_state=state.opt_state._replace(log_prob=aux["loss"]),
```
Should `opt_state` be added to the parameters of `HMCState`?
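For illustration of the `_replace` pattern used in the hunk above: assuming `opt_state` is a `NamedTuple`-style container (the class and field names below are hypothetical stand-ins, not the PR's actual definitions), updating one field produces a new tuple and leaves the original untouched:

```python
from typing import NamedTuple
import jax.numpy as jnp

class HMCOptState(NamedTuple):  # hypothetical stand-in for the real opt_state
    log_prob: jnp.ndarray
    hamiltonian: jnp.ndarray

state = HMCOptState(log_prob=jnp.array(-10.0), hamiltonian=jnp.array(5.0))

# _replace returns a new tuple with the given fields updated;
# the remaining fields are carried over unchanged.
new_state = state._replace(log_prob=jnp.array(-8.5))
print(new_state.log_prob)  # log_prob updated to -8.5; hamiltonian unchanged
```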
Add full-batch Hamiltonian Monte Carlo implementation.