nessai.flows.base

Base objects for implementing normalising flows.

Module Contents

Classes

BaseFlow

Base class for all normalising flows.

NFlow

Base class for flow objects from glasflow.nflows.

class nessai.flows.base.BaseFlow(*args, **kwargs)

Bases: torch.nn.Module, abc.ABC

Base class for all normalising flows.

If implementing a flow using distributions and transforms, see NFlow.

to(device)

Wrapper around torch.nn.Module.to that stores the device before moving the flow.

abstract forward(x, context=None)

Apply the forward transformation and return samples in the latent space and the log-Jacobian determinant.

Returns:
torch.Tensor

Tensor of samples in the latent space

torch.Tensor

Tensor of log determinants of the Jacobian of the forward transformation

abstract inverse(z, context=None)

Apply the inverse transformation and return samples in the data space and the log-Jacobian determinant.

Returns:
torch.Tensor

Tensor of samples in the data space

torch.Tensor

Tensor of log determinants of the Jacobian of the inverse transformation
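The relationship between forward and inverse can be illustrated with a one-dimensional affine map. The following is a minimal sketch in plain Python rather than PyTorch. The names affine_forward, affine_inverse, SCALE and SHIFT are hypothetical and not part of nessai, and a real subclass would operate on batches of torch.Tensor.

```python
import math

# Hypothetical parameters of a toy affine flow (not part of nessai).
SCALE = 2.0
SHIFT = 1.0


def affine_forward(x):
    """Map data-space values to the latent space: z = (x - SHIFT) / SCALE."""
    z = [(xi - SHIFT) / SCALE for xi in x]
    # dz/dx = 1 / SCALE for every sample, so log|J| = -log(SCALE).
    log_j = [-math.log(SCALE)] * len(x)
    return z, log_j


def affine_inverse(z):
    """Map latent values back to the data space: x = SCALE * z + SHIFT."""
    x = [SCALE * zi + SHIFT for zi in z]
    # dx/dz = SCALE, so the log-Jacobian determinant is +log(SCALE).
    log_j = [math.log(SCALE)] * len(z)
    return x, log_j


x = [0.0, 1.0, 3.0]
z, log_j_forward = affine_forward(x)
x_back, log_j_inverse = affine_inverse(z)
```

In this sketch the inverse undoes the forward map, and the two log-Jacobian determinants differ only in sign, which is the property a concrete flow is expected to satisfy.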

abstract sample(n, context=None)

Generate n samples in the data space

Returns:
torch.Tensor

Tensor of samples in the data space

abstract log_prob(x, context=None)

Compute the log probability for a set of samples in the data space

Returns:
torch.Tensor

Tensor of log probabilities of the samples
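For a flow, the log probability in the data space usually follows from the change of variables formula, log p(x) = log p(z) + log|J|, where J is the Jacobian of the forward transformation. The sketch below checks this for a toy affine map with a standard normal latent distribution. All function names are hypothetical and the code uses plain floats rather than tensors.

```python
import math


def standard_normal_log_prob(z):
    """Log density of a standard normal, used here as the latent distribution."""
    return -0.5 * (z * z + math.log(2.0 * math.pi))


def toy_log_prob(x, scale=2.0, shift=1.0):
    """log p(x) = log p(z) + log|J| for the forward map z = (x - shift) / scale."""
    z = (x - shift) / scale
    log_j = -math.log(scale)  # log|dz/dx|
    return standard_normal_log_prob(z) + log_j


def normal_log_pdf(x, mean, std):
    """Closed-form log density of N(mean, std**2), used only as a check."""
    return (
        -0.5 * ((x - mean) / std) ** 2
        - math.log(std)
        - 0.5 * math.log(2.0 * math.pi)
    )


flow_value = toy_log_prob(3.0)
exact_value = normal_log_pdf(3.0, mean=1.0, std=2.0)
```

Both values agree because an affine map of a standard normal is itself a normal distribution with mean shift and standard deviation scale.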

abstract sample_latent_distribution(n, context=None)

Sample from the latent distribution.

abstract base_distribution_log_prob(z, context=None)

Compute the log probability of samples in the latent space under the base distribution of the flow.

Returns:
torch.Tensor

Tensor of log probabilities of the latent samples

abstract forward_and_log_prob(x, context=None)

Apply the forward transformation and compute the log probability of each sample

Returns:
torch.Tensor

Tensor of samples in the latent space

torch.Tensor

Tensor of log probabilities of the samples

abstract sample_and_log_prob(n, context=None)

Generate samples from the flow, together with their log probabilities in the data space, where log p(x) = log p(z) + log|J|.

For flows, this is more efficient than calling sample and log_prob separately.

Returns:
torch.Tensor

Tensor of samples in the data space

torch.Tensor

Tensor of log probabilities of the samples
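One way to see where the saving comes from: the latent sample and its log probability are already available when the inverse transformation is applied, so no separate forward pass is needed. The sketch below is an illustration under that assumption, using a toy affine map and hypothetical function names. It is not the nessai implementation.

```python
import math
import random


def base_log_prob(z):
    """Log density of a standard normal latent distribution."""
    return -0.5 * (z * z + math.log(2.0 * math.pi))


def toy_sample_and_log_prob(n, scale=2.0, shift=1.0, seed=0):
    """Draw n samples and reuse the latent values to get log p(x)."""
    rng = random.Random(seed)
    samples, log_probs = [], []
    for _ in range(n):
        z = rng.gauss(0.0, 1.0)          # draw from the latent distribution
        x = scale * z + shift            # single inverse pass
        log_j_inverse = math.log(scale)  # log|dx/dz|
        # The forward log|J| is the negative of the inverse one.
        log_probs.append(base_log_prob(z) - log_j_inverse)
        samples.append(x)
    return samples, log_probs


def toy_log_prob(x, scale=2.0, shift=1.0):
    """Separate evaluation, which needs an extra forward pass."""
    z = (x - shift) / scale
    return base_log_prob(z) - math.log(scale)


samples, log_probs = toy_sample_and_log_prob(5)
recomputed = [toy_log_prob(x) for x in samples]
```

The recomputed values match those returned with the samples, but they cost an additional pass through the transformation.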

finalise()

Finalise the flow after training.

Will be called after training the flow and loading the best weights. For example, it can be used to finalise the Monte Carlo estimate of the normalising constant used in a LARS-based flow.

By default this does nothing; override it in a subclass if a post-training step is needed.

end_iteration()

Update the model at the end of an iteration.

Will be called between training and validation.

By default this does nothing and should be overridden by any class that inherits from this class and needs this behaviour.
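The order in which the finalise and end_iteration hooks are described as being called can be sketched with a stand-in object that records each step. RecordingFlow and toy_training_loop are hypothetical, and the sketch assumes one iteration corresponds to one pass of training followed by validation. The actual training loop in nessai may differ in detail.

```python
class RecordingFlow:
    """Hypothetical stand-in for a flow that records which hooks are called."""

    def __init__(self):
        self.calls = []

    def end_iteration(self):
        self.calls.append("end_iteration")

    def finalise(self):
        self.calls.append("finalise")


def toy_training_loop(flow, n_iterations):
    """Illustrative loop showing only where the hooks would be called."""
    for _ in range(n_iterations):
        flow.calls.append("train")
        flow.end_iteration()  # between training and validation
        flow.calls.append("validate")
    flow.calls.append("load_best_weights")
    flow.finalise()  # after training and loading the best weights


flow = RecordingFlow()
toy_training_loop(flow, n_iterations=2)
```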

abstract freeze_transform()

Freeze the transform part of the flow.

Must be implemented by the child class.

abstract unfreeze_transform()

Unfreeze the transform part of the flow.

Must be implemented by the child class.

class nessai.flows.base.NFlow(transform, distribution)

Bases: BaseFlow

Base class for flow objects from glasflow.nflows.

This replaces Flow from glasflow.nflows. It includes additional methods which are called in FlowModel.

Parameters:
transform : glasflow.nflows.transforms.Transform

Object that applies the transformation. Must have forward and inverse methods. See glasflow.nflows for more details.

distribution : glasflow.nflows.distributions.Distribution

Object that serves as the base distribution used when sampling and computing the log probability. Must have log_prob and sample methods. See glasflow.nflows for more details.
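The division of labour between the two constructor arguments can be sketched with duck-typed stand-ins. ToyTransform, ToyDistribution and ToyFlow below are hypothetical classes written in plain Python. They only mimic the method names listed above and are not the glasflow.nflows or nessai classes, which work with torch.Tensor inputs.

```python
import math
import random


class ToyTransform:
    """Affine map exposing forward and inverse, each returning (value, log|J|)."""

    def __init__(self, scale, shift):
        self.scale = scale
        self.shift = shift

    def forward(self, x):
        return (x - self.shift) / self.scale, -math.log(self.scale)

    def inverse(self, z):
        return self.scale * z + self.shift, math.log(self.scale)


class ToyDistribution:
    """Standard normal exposing log_prob and sample."""

    def __init__(self, seed=0):
        self._rng = random.Random(seed)

    def log_prob(self, z):
        return -0.5 * (z * z + math.log(2.0 * math.pi))

    def sample(self, n):
        return [self._rng.gauss(0.0, 1.0) for _ in range(n)]


class ToyFlow:
    """Combines a transform and a base distribution, as NFlow is described to do."""

    def __init__(self, transform, distribution):
        self._transform = transform
        self._distribution = distribution

    def log_prob(self, x):
        z, log_j = self._transform.forward(x)
        return self._distribution.log_prob(z) + log_j

    def sample(self, n):
        # Draw latent samples, then map each one to the data space.
        return [self._transform.inverse(z)[0] for z in self._distribution.sample(n)]


flow = ToyFlow(ToyTransform(scale=2.0, shift=1.0), ToyDistribution(seed=42))
draws = flow.sample(3)
value = flow.log_prob(1.0)
```

Keeping the transform and the distribution as separate objects means either can be swapped without touching the sampling or log-probability logic.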

forward(x, context=None)

Apply the forward transformation and return samples in the latent space and log-Jacobian determinant.

inverse(z, context=None)

Apply the inverse transformation and return samples in the data space and log-Jacobian determinant (not log probability).

sample(num_samples, context=None)

Produces num_samples samples in the data space by drawing from the base distribution and then applying the inverse transform.

Does NOT need to be specified by the user.

log_prob(inputs, context=None)

Computes the log probability of the input samples by applying the transform.

Does NOT need to be specified by the user.

sample_latent_distribution(n, context=None)

Sample from the latent distribution.

base_distribution_log_prob(z, context=None)

Compute the log probability of samples in the latent space under the base distribution of the flow.

forward_and_log_prob(x, context=None)

Apply the forward transformation and compute the log probability of each sample

Returns:
torch.Tensor

Tensor of samples in the latent space

torch.Tensor

Tensor of log probabilities of the samples

sample_and_log_prob(N, context=None)

Generate samples from the flow, together with their log probabilities in the data space, where log p(x) = log p(z) + log|J|.

For flows, this is more efficient than calling sample and log_prob separately.

finalise()

Finalise the flow after training.

Checks whether the base distribution and the transform have finalise methods and calls any that exist.
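A common way to implement this kind of optional delegation is to look up the method by name and call it only if it is present. The sketch below shows that pattern with hypothetical components. It is an assumption about how such a check could be written, not a copy of the nessai source.

```python
class PlainTransform:
    """Hypothetical component without a finalise method."""


class LearnedBase:
    """Hypothetical component that needs a post-training step."""

    def __init__(self):
        self.finalised = False

    def finalise(self):
        self.finalised = True


def finalise_components(*components):
    """Call finalise on each component that defines it and report which did."""
    called = []
    for component in components:
        finalise = getattr(component, "finalise", None)
        if callable(finalise):
            finalise()
            called.append(type(component).__name__)
    return called


base = LearnedBase()
called = finalise_components(base, PlainTransform())
```

Components without a finalise method are skipped silently, so existing transforms and distributions do not need to change.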

end_iteration()

Update the model at the end of an iteration.

Will be called between training and validation.

freeze_transform()

Freeze the transform part of the flow

unfreeze_transform()

Unfreeze the transform part of the flow