Cast rate to float when constructing Poisson to ensure consistent log prob behavior #2181

@dylanhmorris

Very excited about #2176. Wanted to propose a tweak to address a potential edge case. As written, this will cast real-valued arrays to integer arrays when the Poisson distribution instance was initialized with an integer rate.

    return (
        xlogy(jnp.astype(value, jnp.result_type(self.rate)), self.rate)
        - gammaln(value + 1)
        - self.rate
    )

That in turn means that (for example) Poisson(2).log_prob(value) and Poisson(2.0).log_prob(value) can behave differently for the same value.

I think we should avoid this by casting all rates to floats at construction. Rationale below. Let me know what you think, @fehiepsi, @Qazalbash, and @tillahoffmann.
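As a rough illustration of the proposal, here is a plain-Python sketch of the arithmetic. It is not NumPyro code: the function names are hypothetical, math.lgamma stands in for gammaln, and int() stands in for the integer cast on the assumption that the cast truncates toward zero. It also assumes rate > 0.

```python
import math


def log_prob_cast_to_rate_type(value, rate):
    """Evaluate the PMF expression, casting value to the rate's type first."""
    if isinstance(rate, int):
        # Integer rate: the fractional part of value is dropped in the first term.
        cast_value = int(value)
    else:
        cast_value = float(value)
    return cast_value * math.log(rate) - math.lgamma(value + 1) - rate


def log_prob_float_rate(value, rate):
    """Same expression, but with the rate promoted to float up front."""
    return log_prob_cast_to_rate_type(value, float(rate))


# Integer rate: the near-integer value is truncated to 1 in the first term.
print(log_prob_cast_to_rate_type(1.999999, 2))  # close to -2.0
# Rate promoted to float: int and float inputs now agree.
print(log_prob_float_rate(1.999999, 2))  # close to -1.30685
print(log_prob_float_rate(1.999999, 2.0))  # same as the line above
```

With the promotion in place, the result no longer depends on whether the caller wrote 2 or 2.0 for the rate.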

Rationale

jnp.astype casts floats to ints by truncating toward zero, which for non-negative values is the floor, e.g.

import jax.numpy as jnp
jnp.astype(1.99999, "int32")
# Array(1, dtype=int32)

Even though the Poisson distribution is supported only on the non-negative integers, the log_prob method evaluates the PMF expression for non-integer values (as opposed to returning -inf). For example:

import numpyro.distributions as dist
my_poisson = dist.Poisson(2.5)
my_poisson.log_prob(5.2)
# Array(-2.8675885, dtype=float32)

Notably, my_poisson.log_prob(0.9999) and my_poisson.log_prob(1) give very similar results.
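To see why the two are close, the expression can be evaluated by hand. This is a minimal stdlib sketch, not the NumPyro implementation: poisson_log_pmf_expr is a hypothetical helper, math.lgamma stands in for gammaln, and rate > 0 is assumed.

```python
import math


def poisson_log_pmf_expr(value, rate):
    """value * log(rate) - log Gamma(value + 1) - rate, for rate > 0."""
    return value * math.log(rate) - math.lgamma(value + 1) - rate


at_one = poisson_log_pmf_expr(1, 2.5)
near_one = poisson_log_pmf_expr(0.9999, 2.5)

# The expression is continuous in value, so nearby inputs give nearby outputs.
print(at_one, near_one, abs(at_one - near_one))
```

Because every term is continuous in value, a small perturbation of the input produces only a small change in the result, as long as nothing is cast to an integer along the way.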

Now suppose we construct a Poisson with an integer-valued rate, e.g. Poisson(2). Depending on whether that rate was passed as an int (2) or as a float (2., 2.0, etc.), evaluating log_prob can give very different results for the same near-integer value.

int_rate = dist.Poisson(2)
float_rate = dist.Poisson(2.)
int_rate.log_prob(2)
# Array(-1.306853, dtype=float32)
float_rate.log_prob(2)
# Array(-1.306853, dtype=float32)
int_rate.log_prob(1.999999)
# Array(-2.0000002, dtype=float32)
float_rate.log_prob(1.999999)
# Array(-1.3068538, dtype=float32)
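The size of the gap follows from the expression itself. Only the first term sees the cast, so, assuming the integer cast truncates toward zero (written trunc here, a notation introduced only for this note), the two results differ by the fractional part of the value times the log of the rate:

```latex
\begin{align*}
\log p_{\text{float}}(x \mid \lambda) &= x \log \lambda - \log \Gamma(x + 1) - \lambda \\
\log p_{\text{int}}(x \mid \lambda) &= \operatorname{trunc}(x) \log \lambda - \log \Gamma(x + 1) - \lambda \\
\log p_{\text{float}}(x \mid \lambda) - \log p_{\text{int}}(x \mid \lambda) &= \bigl(x - \operatorname{trunc}(x)\bigr) \log \lambda
\end{align*}
```

For x = 1.999999 and rate 2 this gives 0.999999 × log 2 ≈ 0.6931, which matches the difference between the two outputs above, -1.3068538 and -2.0000002.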

Labels: bug (Something isn't working)
