🐛 Bug Description
When the condition is more than 1-dimensional (e.g. an image instead of a vector), FMPE trains successfully but fails to sample because of a shape mismatch.
🔄 Steps to Reproduce
This is a toy example: the parameter shifts the mean, and x is a 10x10 matrix whose entries are each sampled from a normal distribution with that mean (1 + theta, standard deviation 0.1).
```python
import torch
from sbi.inference import FMPE
from sbi.utils import BoxUniform
from sbi.neural_nets.embedding_nets import CNNEmbedding
from sbi.neural_nets import posterior_flow_nn

prior = BoxUniform(low=torch.tensor([-1.0]), high=torch.tensor([1.0]))


def simulator(theta):
    # theta has shape (batch, 1); the returned x has shape (batch, 10, 10).
    mean = torch.ones(10, 10)
    mean = mean.unsqueeze(0).expand(theta.shape[0], -1, -1)
    # Every entry of the 10x10 mean matrix equals 1 + theta.
    mean = mean + theta.unsqueeze(-1)
    x = mean + 0.1 * torch.randn_like(mean)
    return x


theta = prior.sample((1000,))
x = simulator(theta)

embedding_net = CNNEmbedding(
    input_shape=(10, 10),
    in_channels=1,
    output_dim=5,
    kernel_size=3,
)
flow_estimator = posterior_flow_nn(
    model="mlp",
    embedding_net=embedding_net,
)

inference = FMPE(prior=prior, vf_estimator=flow_estimator)
_ = inference.append_simulations(theta, x).train(max_num_epochs=2)  # training succeeds
posterior = inference.build_posterior()

x_o = 0.1 * torch.randn(1, 10, 10) + 0.5
samples = posterior.sample((10,), x=x_o)  # sampling fails
```

This fails with the error `RuntimeError: shape '[10, 1]' is invalid for input of size 100`.
✅ Expected Behavior
`posterior.sample` returns posterior samples without error.
...
📌 Additional Context
I came across this problem some time ago and traced it to these lines in `ZukoNeuralODE` (`sbi/samplers/ode_solvers/zuko_ode.py` in the sbi repository, lines 95 to 98 at commit 06f13a8):
```python
return NormalizingFlow(
    transform=transform,
    base=DiagNormal(self.mean_base, self.std_base).expand(condition.shape[:-1]),
)
```
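Here the base distribution is expanded to the batch shape `condition.shape[:-1]`, i.e. every dimension of the condition except the last is treated as a batch dimension. That is only correct for 1D conditions: if the condition is 2D, the batch shape is `condition.shape[:-2]`, and so on. For example, for `x_o` of shape `(1, 10, 10)`, `condition.shape[:-1]` is `(1, 10)` rather than `(1,)`. My proposed solution passes the `condition_shape` of the `ConditionalVectorFieldEstimator` to `ZukoNeuralODE`, so that the base can be expanded to the correct batch shape. I can make a quick PR for this if that sounds right.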
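A minimal sketch of the shape logic behind the proposal, using plain tuples so it runs without torch (`torch.Size` is a tuple subclass, so `condition.shape` can be passed directly). The helper names `current_batch_shape` and `batch_shape_from_condition` are hypothetical and not part of sbi; the sketch assumes `condition_shape` holds the unbatched shape of a single condition, e.g. `(10, 10)` in the example above.

```python
from typing import Sequence, Tuple


def current_batch_shape(full_shape: Sequence[int]) -> Tuple[int, ...]:
    """Mirrors the current code: treat everything but the last dim as batch."""
    return tuple(full_shape[:-1])


def batch_shape_from_condition(
    full_shape: Sequence[int], condition_shape: Sequence[int]
) -> Tuple[int, ...]:
    """Hypothetical helper: strip the trailing dims given by condition_shape."""
    n_event = len(condition_shape)
    n_batch = len(full_shape) - n_event
    # Check that the condition really ends with the expected event shape.
    if n_batch < 0 or tuple(full_shape[n_batch:]) != tuple(condition_shape):
        raise ValueError(
            f"condition of shape {tuple(full_shape)} does not end with "
            f"{tuple(condition_shape)}"
        )
    # Slice with an explicit index; full_shape[:-n_event] would break for n_event == 0.
    return tuple(full_shape[:n_batch])


# Vector condition: both versions agree.
print(current_batch_shape((1, 5)), batch_shape_from_condition((1, 5), (5,)))
# → (1,) (1,)

# Image condition as in the report: only the second is the batch shape.
print(current_batch_shape((1, 10, 10)), batch_shape_from_condition((1, 10, 10), (10, 10)))
# → (1, 10) (1,)
```

In the actual fix, the resulting shape would be passed to `.expand(...)` on the base distribution in place of `condition.shape[:-1]`; how `condition_shape` reaches `ZukoNeuralODE` is left to the PR.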