Very excited about #2176. Wanted to propose a tweak to address a potential edge case. As written, this will cast a real-valued value array to an integer array when the Poisson distribution instance was initialized with an integer rate.
return (
    xlogy(jnp.astype(value, jnp.result_type(self.rate)), self.rate)
    - gammaln(value + 1)
    - self.rate
)
That in turn means that (for example) Poisson(2).log_prob(value) and Poisson(2.0).log_prob(value) can behave differently for the same value.
I think we should avoid this by casting all rates to floats at construction. Rationale below. Let me know what you think, @fehiepsi, @Qazalbash, and @tillahoffmann.
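To make the proposal concrete, here is a rough sketch of what promoting the rate to a float could look like. The names as_float_rate and poisson_log_prob are hypothetical and are not part of numpyro; the sketch also uses jnp.asarray for the cast rather than jnp.astype, and it assumes that promoting with jnp.result_type(rate, float) is an acceptable way to pick the target dtype.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln, xlogy


def as_float_rate(rate):
    """Hypothetical helper: promote an integer rate to a float dtype."""
    # jnp.result_type(rate, float) keeps an existing float dtype and
    # promotes integer dtypes to JAX's default float dtype.
    return jnp.asarray(rate, dtype=jnp.result_type(rate, float))


def poisson_log_prob(value, rate):
    """Sketch of the log PMF expression with the rate promoted first."""
    rate = as_float_rate(rate)
    # The cast target is now always a float dtype, so value is never truncated.
    value = jnp.asarray(value, dtype=jnp.result_type(rate))
    return xlogy(value, rate) - gammaln(value + 1) - rate


# An int rate and a float rate now agree for the same near-integer value.
print(poisson_log_prob(1.999999, 2))
print(poisson_log_prob(1.999999, 2.0))
```

In the actual class the promotion would presumably happen once in the constructor, so that log_prob itself could stay as written.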
Rationale
jnp.astype truncates floats toward zero when casting to ints, which is the floor for non-negative values, e.g.
import jax.numpy as jnp
jnp.astype(1.99999, "int32")
# Array(1, dtype=int32)
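A small sketch of where the cast and the floor part ways, in case it is useful. It uses the array .astype method rather than jnp.astype, on the assumption that both follow the same conversion rule; the sample values are arbitrary.

```python
import jax.numpy as jnp

values = jnp.array([1.99999, 0.5, -0.5, -1.99999])

# Casting to an integer dtype drops the fractional part (truncation toward zero).
as_int = values.astype("int32")

# An explicit floor rounds toward negative infinity instead.
floored = jnp.floor(values).astype("int32")

print(as_int)   # [ 1  0  0 -1]
print(floored)  # [ 1  0 -1 -2]
```

For the Poisson case only non-negative values matter, so the cast and the floor coincide there.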
Even though the Poisson is supported only on the non-negative integers, the log_prob method evaluates the PMF expression for non-integer values (as opposed to returning -inf). For example:
import numpyro.distributions as dist
my_poisson = dist.Poisson(2.5)
my_poisson.log_prob(5.2)
# Array(-2.8675885, dtype=float32)
Notably, my_poisson.log_prob(0.9999) and my_poisson.log_prob(1) give very similar results.
Now suppose we construct a Poisson with an integer-valued rate, e.g. Poisson(2). Depending on whether that rate was actually an int (2) or a float (2., 2.0, etc.), evaluating log_prob can give very different results for the same near-integer value:
int_rate = dist.Poisson(2)
float_rate = dist.Poisson(2.)
int_rate.log_prob(2)
# Array(-1.306853, dtype=float32)
float_rate.log_prob(2)
# Array(-1.306853, dtype=float32)
int_rate.log_prob(1.999999)
# Array(-2.0000002, dtype=float32)
float_rate.log_prob(1.999999)
# Array(-1.3068538, dtype=float32)
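The gap can be reproduced with plain arithmetic, which may help show where it comes from. The sketch below is pure Python and is not numpyro code: poisson_log_pmf and its truncate_value flag are hypothetical, with the flag standing in for the integer-rate path, where only the value multiplying log(rate) gets cast to an int while the gammaln term still sees the original float.

```python
import math


def poisson_log_pmf(value, rate, truncate_value=False):
    """Evaluate value * log(rate) - lgamma(value + 1) - rate.

    truncate_value=True mimics the integer-rate path: only the value
    that multiplies log(rate) is truncated to an integer.
    """
    first = int(value) if truncate_value else value
    return first * math.log(rate) - math.lgamma(value + 1) - rate


near_two = 1.999999

# Integer-rate path: 1.999999 is truncated to 1 in the first term only.
int_path = poisson_log_pmf(near_two, 2, truncate_value=True)

# Float-rate path: the value is used as-is in both terms.
float_path = poisson_log_pmf(near_two, 2.0)

print(int_path)    # close to -2.0
print(float_path)  # close to -1.3069
```

The two results differ by roughly log(2), which is exactly the contribution lost when the value in the first term drops from about 2 to 1.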
The snippet quoted above is from numpyro/numpyro/distributions/discrete.py, lines 807 to 811 in e708f34.