nessai.flows.utils
Various utilities for implementing normalising flows.
Module Contents
Functions
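silu(x) | SiLU (Sigmoid-weighted Linear Unit) activation function.
get_base_distribution(n_inputs, distribution, **kwargs) | Get the base distribution for a flow.
get_n_neurons(n_neurons=None, n_inputs=None, default=8) | Get the number of neurons.
get_native_flow_class(name) | Get a natively implemented flow class.
get_flow_class(name) | Get the class to use for the normalizing flow from a string.
configure_model(config) | Set up the flow from a configuration dictionary.
reset_weights(module) | Reset parameters of a given module in place.
reset_permutations(module) | Resets permutations and linear transforms for a given module in place.
create_linear_transform(linear_transform, features) | Function for creating linear transforms.
create_pre_transform(pre_transform, features, **kwargs) | Create a pre transform.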
- nessai.flows.utils.silu(x)
SiLU (Sigmoid-weighted Linear Unit) activation function.
Also known as swish.
Elfwing et al. (2017): https://arxiv.org/abs/1702.03118v3
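SiLU computes x·σ(x), where σ is the logistic sigmoid. The sketch below shows that formula for a single float in plain Python; it is not nessai's implementation (which operates on tensors), and the name silu_scalar is hypothetical. PyTorch also ships this activation as torch.nn.functional.silu.

```python
import math


def silu_scalar(x: float) -> float:
    """Hypothetical scalar SiLU: x * sigmoid(x)."""
    # Two algebraically equal forms of the sigmoid are used so that
    # math.exp never receives a large positive argument (avoids OverflowError).
    if x >= 0:
        return x / (1.0 + math.exp(-x))
    z = math.exp(x)
    return x * z / (1.0 + z)


# SiLU is zero at the origin, close to x for large positive x
# and close to zero for large negative x.
print(silu_scalar(0.0), silu_scalar(1.0), silu_scalar(-1.0))
```

Splitting on the sign of x is only a numerical-stability choice; both branches compute the same function.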
- nessai.flows.utils.get_base_distribution(n_inputs: int, distribution: str | Type[glasflow.nflows.distributions.Distribution], **kwargs) -> glasflow.nflows.distributions.Distribution
Get the base distribution for a flow.
Includes special configuration for certain distributions.
- Parameters:
- n_inputs : int
Number of inputs to the distribution.
- distribution : Union[str, Type[glasflow.nflows.distributions.Distribution]]
Distribution class or name of a known distribution.
- kwargs : Any
Keyword arguments used when creating an instance of the distribution.
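The distribution argument accepts either a class or the name of a known distribution. Below is a minimal sketch of that name-or-class pattern in plain Python, so it runs without glasflow. The registry KNOWN_DISTRIBUTIONS, the class HypotheticalStandardNormal, the function resolve_base_distribution and the shape=[n_inputs] constructor argument are all hypothetical, and the sketch omits the special configuration the real function applies to certain distributions.

```python
from typing import Any, Dict, List, Union


class HypotheticalStandardNormal:
    """Stand-in for a glasflow.nflows distribution class (hypothetical)."""

    def __init__(self, shape: List[int], **kwargs: Any) -> None:
        self.shape = shape
        self.kwargs = kwargs


# Hypothetical registry mapping "known" names to classes.
KNOWN_DISTRIBUTIONS: Dict[str, type] = {"normal": HypotheticalStandardNormal}


def resolve_base_distribution(
    n_inputs: int, distribution: Union[str, type], **kwargs: Any
):
    """Sketch: accept a class or a known name and return an instance."""
    if isinstance(distribution, str):
        try:
            distribution = KNOWN_DISTRIBUTIONS[distribution.lower()]
        except KeyError:
            raise ValueError(f"Unknown distribution: {distribution!r}") from None
    # Assumption: the class takes the event shape as a keyword argument;
    # any extra keyword arguments are forwarded unchanged.
    return distribution(shape=[n_inputs], **kwargs)


dist = resolve_base_distribution(2, "normal")
print(type(dist).__name__, dist.shape)
```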
- nessai.flows.utils.get_n_neurons(n_neurons: int | None = None, n_inputs: int | None = None, default: int = 8) -> int
Get the number of neurons.
- Parameters:
- n_neurons : Optional[int]
Number of neurons.
- n_inputs : Optional[int]
Number of inputs.
- default : int
Default value if n_neurons and n_inputs are not given.
- Returns:
- int
Number of neurons.
- Raises:
- ValueError
Raised if the number of inputs could not be converted to an integer.
Notes
If n_inputs is also specified, then the options for n_neurons are either a value that can be converted to an int or one of the following:
- 'auto' or 'double': uses twice the number of inputs
- 'equal': uses the number of inputs
- 'half': uses half the number of inputs
- None: falls back to 'auto'
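The rules in the Notes can be summarised as a small function. This is a sketch, not nessai's source: the name n_neurons_sketch is hypothetical, 'half' is assumed to round down, and when n_neurons is given without n_inputs the value is assumed to be used directly.

```python
from typing import Optional, Union


def n_neurons_sketch(
    n_neurons: Optional[Union[int, str]] = None,
    n_inputs: Optional[int] = None,
    default: int = 8,
) -> int:
    """Hypothetical re-implementation of the rules listed in the Notes."""
    if n_inputs is None:
        # Without n_inputs the keywords cannot be resolved: use the value
        # as given, or the default if it is missing.
        n = default if n_neurons is None else n_neurons
    elif n_neurons is None or n_neurons in ("auto", "double"):
        n = 2 * n_inputs
    elif n_neurons == "equal":
        n = n_inputs
    elif n_neurons == "half":
        # Assumption: "half" rounds down.
        n = n_inputs // 2
    else:
        n = n_neurons
    try:
        return int(n)
    except (TypeError, ValueError):
        raise ValueError(f"Could not convert {n!r} to an integer") from None


print(n_neurons_sketch("half", n_inputs=6))
```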
- nessai.flows.utils.get_native_flow_class(name)
Get a natively implemented flow class.
- nessai.flows.utils.get_flow_class(name: str)
Get the class to use for the normalizing flow from a string.
- nessai.flows.utils.configure_model(config)
Set up the flow from a configuration dictionary.
- nessai.flows.utils.reset_weights(module)
Reset parameters of a given module in place.
Uses the module's reset_parameters method, if it has one (as many torch.nn layers do).
Also checks for the following modules from glasflow.nflows:
glasflow.nflows.transforms.normalization.BatchNorm
- Parameters:
- module : torch.nn.Module
Module to reset.
- nessai.flows.utils.reset_permutations(module)
Resets permutations and linear transforms for a given module in place.
Resets using the original initialisation method. This is needed since they do not have a reset_parameters method.
- Parameters:
- module : torch.nn.Module
Module to reset.
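Both reset functions act on a single module, so one natural way to reset a whole flow is torch.nn.Module.apply, which calls a function on every submodule and then on the module itself. The sketch below imitates that pattern in plain Python so it runs without PyTorch. Every class and function in it (Node, Layer, Permutation, reset_weights_sketch, reset_permutations_sketch) is hypothetical, and the real functions also handle glasflow-specific modules such as BatchNorm and linear transforms.

```python
import random
from typing import Callable, List


class Node:
    """Minimal stand-in for torch.nn.Module with an apply() method (hypothetical)."""

    def __init__(self, *children: "Node") -> None:
        self.children: List["Node"] = list(children)

    def apply(self, fn: Callable[["Node"], None]) -> "Node":
        # Like torch.nn.Module.apply: visit children recursively, then self.
        for child in self.children:
            child.apply(fn)
        fn(self)
        return self


class Layer(Node):
    """Layer with a reset_parameters method, like many torch.nn layers."""

    def __init__(self) -> None:
        super().__init__()
        self.weight = 0.0
        self.reset_parameters()

    def reset_parameters(self) -> None:
        self.weight = random.gauss(0.0, 1.0)


class Permutation(Node):
    """Random permutation that has no reset_parameters method."""

    def __init__(self, features: int) -> None:
        super().__init__()
        self.features = features
        self.permutation = random.sample(range(features), features)


def reset_weights_sketch(module: Node) -> None:
    # Reset in place if the module knows how to reset itself.
    if hasattr(module, "reset_parameters"):
        module.reset_parameters()


def reset_permutations_sketch(module: Node) -> None:
    # Re-draw the permutation with the same method used at initialisation.
    if isinstance(module, Permutation):
        module.permutation = random.sample(range(module.features), module.features)


model = Node(Layer(), Permutation(5), Node(Layer()))
model.apply(reset_weights_sketch)
model.apply(reset_permutations_sketch)
```

Keeping the two functions separate mirrors the split described above: one relies on reset_parameters, the other covers modules that lack it.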
- nessai.flows.utils.create_linear_transform(linear_transform, features)
Function for creating linear transforms.
- Parameters:
- linear_transform : {'permutation', 'lu', 'svd'}
Linear transform to use.
- features : int
Number of features.
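The 'permutation' option reorders the features. A permutation matrix has determinant ±1, so it adds nothing to the log-Jacobian and is trivially invertible. The plain-Python sketch below assumes, as the reset_permutations description suggests, that the ordering is drawn at random once at creation. The function names are hypothetical, and the sketch covers neither 'lu' nor 'svd', nor how nessai builds the actual glasflow transform.

```python
import random
from typing import List, Sequence, Tuple


def make_random_permutation(features: int, seed: int = 0) -> List[int]:
    """Draw a fixed random ordering of the feature indices."""
    rng = random.Random(seed)
    return rng.sample(range(features), features)


def permute_forward(x: Sequence[float], perm: Sequence[int]) -> Tuple[List[float], float]:
    # Output i takes input feature perm[i]; |det| = 1, so log|det J| = 0.
    return [x[p] for p in perm], 0.0


def permute_inverse(y: Sequence[float], perm: Sequence[int]) -> Tuple[List[float], float]:
    # Scatter each output back to the input position it came from.
    x = [0.0] * len(y)
    for i, p in enumerate(perm):
        x[p] = y[i]
    return x, 0.0


perm = make_random_permutation(4)
y, log_det = permute_forward([0.1, 0.2, 0.3, 0.4], perm)
print(perm, y, log_det)
```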
- nessai.flows.utils.create_pre_transform(pre_transform, features, **kwargs)
Create a pre transform.
- Parameters:
- pre_transform : str, {'logit', 'batch_norm'}
Name of the transform.
- features : int
Number of input features.
- kwargs
Keyword arguments passed to the transform class.
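Assuming 'logit' refers to the standard logit map, the pre-transform sends values in (0, 1) to the whole real line via y = log x - log(1 - x). Since dy/dx = 1 / (x(1 - x)), its log-Jacobian is -Σ[log x_i + log(1 - x_i)]. The plain-Python sketch below shows only that maths. The function names are hypothetical, and it omits any clipping, temperature or rescaling that the actual glasflow transform may apply.

```python
import math
from typing import List, Sequence, Tuple


def logit_forward(x: Sequence[float]) -> Tuple[List[float], float]:
    """Map values in (0, 1) to the real line and return log|det J|."""
    y = [math.log(v) - math.log1p(-v) for v in x]
    # dy/dx = 1 / (x (1 - x)), so log|det J| = -sum(log x + log(1 - x)).
    log_det = -sum(math.log(v) + math.log1p(-v) for v in x)
    return y, log_det


def _sigmoid(v: float) -> float:
    # Numerically stable logistic sigmoid (no overflow for large |v|).
    if v >= 0:
        return 1.0 / (1.0 + math.exp(-v))
    z = math.exp(v)
    return z / (1.0 + z)


def logit_inverse(y: Sequence[float]) -> List[float]:
    """Apply the logistic sigmoid, mapping real values back into (0, 1)."""
    return [_sigmoid(v) for v in y]


y, log_det = logit_forward([0.1, 0.5, 0.9])
print(y, log_det, logit_inverse(y))
```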