FMPE fails to sample with 2D conditions #1702

@gmoss13

🐛 Bug Description

When the condition is more than 1-dimensional (e.g. an image instead of a vector), FMPE trains successfully but fails to sample because of a shape mismatch.

🔄 Steps to Reproduce

This is a contrived example: the parameter determines the mean, and x is a 10x10 matrix in which each entry is sampled from a normal distribution with that mean.
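For reference, the generative model can be written out as follows. This is read off the simulator code in this report (note that the code offsets the mean by 1, since it starts from a matrix of ones); the symbols $\theta$ and $x_{ij}$ are just notation for this sketch.

```latex
\theta \sim \mathcal{U}(-1, 1), \qquad
x_{ij} \mid \theta \sim \mathcal{N}\!\left(1 + \theta,\; 0.1^{2}\right),
\quad i, j = 1, \dots, 10 .
```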

import torch
from sbi.inference import FMPE
from sbi.utils import BoxUniform
from sbi.neural_nets.embedding_nets import CNNEmbedding
from sbi.neural_nets import posterior_flow_nn


prior = BoxUniform(low=torch.tensor([-1.0]), high=torch.tensor([1.0]))


def simulator(theta):
    # theta has shape (batch, 1); build one 10x10 mean matrix per parameter.
    mean = torch.ones(10, 10)
    mean = mean.unsqueeze(0).expand(theta.shape[0], -1, -1)
    # After broadcasting, every entry of each matrix equals 1 + theta.
    mean = mean + theta.unsqueeze(-1)

    # Add independent Gaussian noise (std 0.1) to every entry.
    x = mean + 0.1 * torch.randn_like(mean)
    return x

theta = prior.sample((1000,))
x = simulator(theta)

embedding_net = CNNEmbedding(
    input_shape = (10,10),
    in_channels=1,
    output_dim = 5,
    kernel_size=3
)
flow_estimator = posterior_flow_nn(
    model='mlp',
    embedding_net=embedding_net,
)

inference = FMPE(prior=prior, vf_estimator=flow_estimator)
_ = inference.append_simulations(theta, x).train(max_num_epochs=2)
posterior = inference.build_posterior()

x_o = 0.1*torch.randn(1,10,10) + 0.5

samples = posterior.sample((10,), x=x_o)

This fails with the error RuntimeError: shape '[10, 1]' is invalid for input of size 100

✅ Expected Behavior

Sampling succeeds and returns posterior samples.

...

📌 Additional Context

I came across this problem some time ago and traced it to this line in ZukoNeuralODE:

return NormalizingFlow(
    transform=transform,
    base=DiagNormal(self.mean_base, self.std_base).expand(condition.shape[:-1]),
)

Here we expand the base distribution using condition.shape[:-1] as the condition batch shape. That only holds for a 1D condition: if the condition is 2D, the batch shape is condition.shape[:-2], and so on. My proposed solution passes the condition_shape of the ConditionalVectorFieldEstimator to ZukoNeuralODE so that we can expand to the correct batch shape here. I can make a quick PR to do this if that sounds right.
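To make the shape argument concrete, here is a minimal sketch of the batch-shape logic on plain tuples (torch.Size is a tuple subclass, so the same slicing applies to condition.shape). The helper name condition_batch_shape is hypothetical and not part of sbi or zuko; it only illustrates the idea and is not the actual patch.

```python
def condition_batch_shape(full_shape, condition_shape):
    """Return the leading batch dims of a condition tensor's shape.

    full_shape: shape of the whole condition tensor, e.g. tuple(condition.shape)
    condition_shape: shape of a single condition, e.g. (10, 10) for an image
    (Hypothetical helper for illustration only.)
    """
    full_shape = tuple(full_shape)
    condition_shape = tuple(condition_shape)
    n_event = len(condition_shape)
    if n_event == 0:
        # Scalar condition: every dim is a batch dim. Guarding here also
        # avoids the slice full_shape[:-0], which would be empty.
        return full_shape
    if len(full_shape) < n_event or full_shape[-n_event:] != condition_shape:
        raise ValueError(
            f"shape {full_shape} does not end in condition shape {condition_shape}"
        )
    # Strip as many trailing dims as a single condition has.
    return full_shape[:-n_event]


# 1D condition: stripping one trailing dim is correct.
vector_batch = condition_batch_shape((1, 5), (5,))

# 2D condition (the case in this issue): two trailing dims must be stripped.
image_batch = condition_batch_shape((1, 10, 10), (10, 10))

# What shape[:-1] yields for the same 2D condition: one image dim leaks
# into the batch shape.
naive_image_batch = (1, 10, 10)[:-1]

print(vector_batch, image_batch, naive_image_batch)
```

With the 10x10 condition from the example above, stripping only the last dimension leaves one of the image dimensions in the batch shape, which is consistent with the shape mismatch seen at sampling time.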
