nessai.flows.base
Base objects for implementing normalising flows.
Module Contents
Classes
BaseFlow | Base class for all normalising flows.
NFlow | Base class for flow objects from glasflow.nflows.
- class nessai.flows.base.BaseFlow(*args, **kwargs)
Bases: torch.nn.Module, abc.ABC
Base class for all normalising flows.
If implementing flows using distributions and transforms, see NFlow.
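The methods below all rest on the change-of-variables relation that the sample_and_log_prob entries state as log p(x) = log p(z) + log|J|, where J is the Jacobian of the forward (data-to-latent) transformation. As a worked illustration (a textbook case, not an example taken from nessai), consider a one-dimensional affine flow with a standard normal base distribution; μ and σ are assumed symbols for the shift and scale:

```latex
% Illustrative 1D affine flow with a standard normal base distribution
z = f(x) = \frac{x - \mu}{\sigma}, \qquad \sigma > 0, \qquad
\log|J| = \log\left|\frac{\partial z}{\partial x}\right| = -\log\sigma

\log p(x) = \log p(z) + \log|J|
          = -\frac{1}{2}\left(\frac{x - \mu}{\sigma}\right)^{2} - \frac{1}{2}\log 2\pi - \log\sigma

x = f^{-1}(z) = \mu + \sigma z, \qquad
\log\left|\frac{\partial x}{\partial z}\right| = \log\sigma = -\log|J|
```

The second line is the log-density of a normal distribution with mean μ and standard deviation σ. The last line shows why the log-Jacobian of the inverse transformation is the negative of the forward one, so log p(x) can also be computed from a latent draw as log p(z) minus the inverse log-Jacobian.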
- to(device)
Wrapper that stores the device before moving the flow to it.
- abstract forward(x, context=None)
Apply the forward transformation and return samples in the latent space and the log-Jacobian determinant.
- Returns:
torch.Tensor
Tensor of samples in the latent space
torch.Tensor
Tensor of log determinants of the Jacobian of the forward transformation
- abstract inverse(z, context=None)
Apply the inverse transformation and return samples in the data space and the log-Jacobian determinant.
- Returns:
torch.Tensor
Tensor of samples in the data space
torch.Tensor
Tensor of log determinants of the Jacobian of the inverse transformation
- abstract sample(n, context=None)
Generate n samples in the data space.
- Returns:
torch.Tensor
Tensor of samples in the data space
- abstract log_prob(x, context=None)
Compute the log probability of a set of samples in the data space.
- Returns:
torch.Tensor
Tensor of log probabilities of the samples
- abstract sample_latent_distribution(n, context=None)
Sample from the latent distribution.
- abstract base_distribution_log_prob(z, context=None)
Computes the log probability of samples in the latent space under the base distribution of the flow.
- Returns:
torch.Tensor
Tensor of log probabilities of the latent samples
- abstract forward_and_log_prob(x, context=None)
Apply the forward transformation and compute the log probability of each sample.
- Returns:
torch.Tensor
Tensor of samples in the latent space
torch.Tensor
Tensor of log probabilities of the samples
- abstract sample_and_log_prob(n, context=None)
Generates samples from the flow, together with their log probabilities in the data space, log p(x) = log p(z) + log|J|. For flows, this is more efficient than calling sample and log_prob separately.
- Returns:
torch.Tensor
Tensor of samples in the data space
torch.Tensor
Tensor of log probabilities of the samples
- finalise()
Finalise the flow after training.
Will be called after training the flow and loading the best weights. For example, it can be used to finalise the Monte Carlo estimate of the normalising constant used in a LARS-based flow.
By default it does nothing and should be implemented by the user where needed.
- end_iteration()
Update the model at the end of an iteration.
Will be called between training and validation.
By default it does nothing and should be overridden by any class that inherits from this class.
- abstract freeze_transform()
Freeze the transform part of the flow.
Must be implemented by the child class.
- abstract unfreeze_transform()
Unfreeze the transform part of the flow.
Must be implemented by the child class.
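The finalise and end_iteration entries above fix when the two non-abstract hooks run: end_iteration between the training and validation steps of each iteration, and finalise once after training, after the best weights have been loaded. The sketch below is a hypothetical, torch-free training loop (it is not nessai's FlowModel training code) that records the order in which a toy flow's hooks are called. RecordingFlow, train, train_one_epoch, validate and the scalar weights attribute are all illustrative assumptions.

```python
from typing import Callable, List


class RecordingFlow:
    """Hypothetical toy flow that only records when its hooks are called."""

    def __init__(self) -> None:
        self.calls: List[str] = []
        self.weights = 0.0  # stand-in for the network parameters

    def end_iteration(self) -> None:
        self.calls.append("end_iteration")

    def finalise(self) -> None:
        self.calls.append("finalise")


def train(
    flow: RecordingFlow,
    n_iterations: int,
    train_one_epoch: Callable[[RecordingFlow], None],
    validate: Callable[[RecordingFlow], float],
) -> float:
    """Hypothetical loop: train, update the flow, validate, keep the best weights."""
    best_loss = float("inf")
    best_weights = flow.weights
    for _ in range(n_iterations):
        train_one_epoch(flow)
        flow.end_iteration()  # called between training and validation
        val_loss = validate(flow)
        if val_loss < best_loss:
            best_loss, best_weights = val_loss, flow.weights
    flow.weights = best_weights  # load the best weights first...
    flow.finalise()  # ...then finalise the flow
    return best_loss


def train_one_epoch(flow: RecordingFlow) -> None:
    flow.calls.append("train")
    flow.weights += 1.0  # pretend an optimiser step changed the parameters


def validate(flow: RecordingFlow) -> float:
    flow.calls.append("validate")
    return (flow.weights - 2.0) ** 2  # toy validation loss, minimal at 2.0


flow = RecordingFlow()
best_loss = train(flow, 3, train_one_epoch, validate)
```

Because the base class provides both hooks as no-ops, a loop like this can call them unconditionally; only flows that need a per-iteration update or a post-training step have to override them.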
- class nessai.flows.base.NFlow(transform, distribution)
Bases: BaseFlow
Base class for flow objects from glasflow.nflows.
This replaces Flow from glasflow.nflows. It includes additional methods which are called in FlowModel.
- Parameters:
- transform : glasflow.nflows.transforms.Transform
Object that applies the transformation; must have forward and inverse methods. See glasflow.nflows for more details.
- distribution : glasflow.nflows.distributions.Distribution
Object that serves as the base distribution used when sampling and computing the log probability. Must have log_prob and sample methods. See glasflow.nflows for more details.
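NFlow combines two objects: a transform with forward and inverse methods, and a distribution with log_prob and sample methods. The sketch below mimics that composition in plain Python for a one-dimensional affine transform and a standard normal base, so it runs without torch or glasflow; lists of floats stand in for tensors. AffineTransform1D, StandardNormal1D and ComposedFlow are hypothetical names, and the convention that forward and inverse each return (outputs, logabsdet) is assumed from the nflows-style Transform interface.

```python
import math
import random
from typing import List, Tuple

Vector = List[float]


class AffineTransform1D:
    """Hypothetical transform: z = (x - shift) / scale."""

    def __init__(self, shift: float, scale: float) -> None:
        if scale <= 0:
            raise ValueError("scale must be positive")
        self.shift, self.scale = shift, scale

    def forward(self, inputs: Vector, context=None) -> Tuple[Vector, Vector]:
        # Data -> latent; log|dz/dx| = -log(scale) for every sample.
        outputs = [(x - self.shift) / self.scale for x in inputs]
        return outputs, [-math.log(self.scale)] * len(inputs)

    def inverse(self, inputs: Vector, context=None) -> Tuple[Vector, Vector]:
        # Latent -> data; log|dx/dz| = +log(scale).
        outputs = [z * self.scale + self.shift for z in inputs]
        return outputs, [math.log(self.scale)] * len(inputs)


class StandardNormal1D:
    """Hypothetical base distribution: N(0, 1)."""

    def log_prob(self, inputs: Vector, context=None) -> Vector:
        return [-0.5 * (z * z + math.log(2.0 * math.pi)) for z in inputs]

    def sample(self, num_samples: int, context=None) -> Vector:
        return [random.gauss(0.0, 1.0) for _ in range(num_samples)]


class ComposedFlow:
    """Sketch of the transform + distribution composition used by NFlow."""

    def __init__(self, transform, distribution) -> None:
        self._transform = transform
        self._distribution = distribution

    def forward(self, x: Vector, context=None) -> Tuple[Vector, Vector]:
        return self._transform.forward(x, context)

    def inverse(self, z: Vector, context=None) -> Tuple[Vector, Vector]:
        return self._transform.inverse(z, context)

    def log_prob(self, inputs: Vector, context=None) -> Vector:
        # log p(x) = log p(z) + log|J|, J = Jacobian of the forward transform.
        z, log_j = self.forward(inputs, context)
        log_pz = self._distribution.log_prob(z, context)
        return [lp + lj for lp, lj in zip(log_pz, log_j)]

    def sample(self, num_samples: int, context=None) -> Vector:
        # Draw in the latent space, then map to the data space.
        z = self._distribution.sample(num_samples, context)
        return self.inverse(z, context)[0]

    def sample_and_log_prob(self, n: int, context=None) -> Tuple[Vector, Vector]:
        # Single inverse pass: log p(x) = log p(z) - log|dx/dz|,
        # so no extra forward pass is needed.
        z = self._distribution.sample(n, context)
        x, log_j_inv = self.inverse(z, context)
        log_pz = self._distribution.log_prob(z, context)
        return x, [lp - lj for lp, lj in zip(log_pz, log_j_inv)]


random.seed(1)
flow = ComposedFlow(AffineTransform1D(shift=1.0, scale=2.0), StandardNormal1D())
x = [0.5, 1.0, 3.0]
z, log_j = flow.forward(x)
x_back, log_j_inv = flow.inverse(z)
samples, sample_log_probs = flow.sample_and_log_prob(5)
```

Here sample_and_log_prob obtains log p(x) from the latent draws and the inverse log-Jacobian in one pass, whereas calling sample and then log_prob would run the transform twice.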
- forward(x, context=None)
Apply the forward transformation and return samples in the latent space and log-Jacobian determinant.
- inverse(z, context=None)
Apply the inverse transformation and return samples in the data space and log-Jacobian determinant (not log probability).
- sample(num_samples, context=None)
Draws num_samples samples in the data space by sampling from the base distribution and then applying the inverse transform.
Does NOT need to be specified by the user.
- log_prob(inputs, context=None)
Computes the log probability of the input samples by applying the transform.
Does NOT need to be specified by the user.
- sample_latent_distribution(n, context=None)
Sample from the latent distribution.
- base_distribution_log_prob(z, context=None)
Computes the log probability of samples in the latent space under the base distribution of the flow.
- forward_and_log_prob(x, context=None)
Apply the forward transformation and compute the log probability of each sample.
- Returns:
torch.Tensor
Tensor of samples in the latent space
torch.Tensor
Tensor of log probabilities of the samples
- sample_and_log_prob(N, context=None)
Generates samples from the flow, together with their log probabilities in the data space, log p(x) = log p(z) + log|J|. For flows, this is more efficient than calling sample and log_prob separately.
- finalise()
Finalise the flow after training.
Checks whether the base distribution or the transform has a finalise method and, if so, calls it.
- end_iteration()
Update the model at the end of an iteration.
Will be called between training and validation.
- freeze_transform()
Freeze the transform part of the flow.
- unfreeze_transform()
Unfreeze the transform part of the flow.
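NFlow.finalise is described above as checking whether the base distribution or the transform has a finalise method and calling it if so. One way to express that check in Python is getattr with a default followed by callable, sketched below with stand-in components. Part, FinalisablePart and DelegatingFlow are hypothetical, and the order in which the two components are visited is an assumption of this sketch, not something taken from nessai.

```python
from typing import List


class Part:
    """Hypothetical transform/distribution stand-in without a finalise method."""

    def __init__(self, name: str, log: List[str]) -> None:
        self.name = name
        self.log = log


class FinalisablePart(Part):
    """Hypothetical component that needs a post-training step."""

    def finalise(self) -> None:
        self.log.append(f"{self.name}.finalise")


class DelegatingFlow:
    """Sketch of a flow whose finalise delegates to its components."""

    def __init__(self, transform, distribution) -> None:
        self._transform = transform
        self._distribution = distribution

    def finalise(self) -> None:
        # Call finalise only on the components that actually define it.
        for part in (self._distribution, self._transform):
            method = getattr(part, "finalise", None)
            if callable(method):
                method()


calls_a: List[str] = []
DelegatingFlow(Part("transform", calls_a), FinalisablePart("distribution", calls_a)).finalise()

calls_b: List[str] = []
DelegatingFlow(FinalisablePart("transform", calls_b), FinalisablePart("distribution", calls_b)).finalise()

calls_c: List[str] = []
DelegatingFlow(Part("transform", calls_c), Part("distribution", calls_c)).finalise()
```

Checking for the method at call time means components that need no post-training step, such as a plain standard normal base, do not have to define an empty finalise.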