nessai.flows.utils
Various utilities for implementing normalising flows.
Functions
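| silu(x) | SiLU (Sigmoid-weighted Linear Unit) activation function. |
| get_base_distribution(n_inputs, distribution, **kwargs) | Get the base distribution for a flow. |
| get_n_neurons([n_neurons, n_inputs, default]) | Get the number of neurons. |
| get_native_flow_class(name) | Get a natively implemented flow class. |
| get_flow_class(name) | Get the class to use for the normalizing flow from a string. |
| get_activation_function(name) | Get the activation function from a string. |
| configure_model(config) | Set up the flow from a configuration dictionary. |
| reset_weights(module) | Reset parameters of a given module in place. |
| reset_permutations(module) | Resets permutations and linear transforms for a given module in place. |
| create_linear_transform(linear_transform, features) | Function for creating linear transforms. |
| create_pre_transform(pre_transform, features, **kwargs) | Create a pre-transform. |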
Module Contents
- nessai.flows.utils.silu(x)
SiLU (Sigmoid-weighted Linear Unit) activation function.
Also known as swish.
Elfwing et al. (2017): https://arxiv.org/abs/1702.03118v3
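SiLU is defined as x * sigmoid(x) = x / (1 + exp(-x)). The scalar sketch below shows that formula using only the standard library; the name silu_scalar is hypothetical, and nessai's own silu presumably operates on torch tensors rather than Python floats. The branch on the sign of x is a numerical-stability choice made for this sketch, so that exp never overflows for large negative inputs.

```python
import math


def silu_scalar(x: float) -> float:
    """Hypothetical scalar SiLU (swish): x * sigmoid(x)."""
    if x >= 0:
        # sigmoid(x) = 1 / (1 + exp(-x)); exp(-x) <= 1 here, so no overflow
        return x / (1.0 + math.exp(-x))
    # For x < 0 use the equivalent form sigmoid(x) = exp(x) / (1 + exp(x)),
    # which avoids computing exp of a large positive number
    z = math.exp(x)
    return x * z / (1.0 + z)
```

For tensors, recent PyTorch releases also ship torch.nn.functional.silu, which computes the same function.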
- nessai.flows.utils.get_base_distribution(n_inputs: int, distribution: str | Type[glasflow.nflows.distributions.Distribution], **kwargs) -> glasflow.nflows.distributions.Distribution
Get the base distribution for a flow.
Includes special configuration for certain distributions.
- Parameters:
- n_inputs : int
Number of inputs to the distribution.
- distribution : Union[str, Type[glasflow.nflows.distributions.Distribution]]
Distribution class or name of a known distribution.
- kwargs : Any
Keyword arguments used when creating an instance of the distribution.
- nessai.flows.utils.get_n_neurons(n_neurons: int | None = None, n_inputs: int | None = None, default: int = 8) -> int
Get the number of neurons.
- Parameters:
- n_neurons : Optional[int]
Number of neurons.
- n_inputs : Optional[int]
Number of inputs.
- default : int
Default value if n_neurons and n_inputs are not given.
- Returns:
- int
Number of neurons.
- Raises:
- ValueError
Raised if the number of inputs could not be converted to an integer.
Notes
If n_inputs is also specified, then n_neurons can either be a value that can be converted to an int or one of the following:
- 'auto' or 'double': uses twice the number of inputs
- 'equal': uses the number of inputs
- 'half': uses half the number of inputs
- None: falls back to 'auto'
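The resolution rules in the Notes can be sketched as a small standalone function. This is an illustration under stated assumptions, not nessai's implementation: the name get_n_neurons_sketch is hypothetical, 'half' is assumed to use integer division, and when n_inputs is missing the sketch simply tries int(n_neurons), raising ValueError for values such as 'auto' that cannot be converted.

```python
from typing import Optional, Union


def get_n_neurons_sketch(
    n_neurons: Optional[Union[int, str]] = None,
    n_inputs: Optional[int] = None,
    default: int = 8,
) -> int:
    """Hypothetical re-statement of the documented get_n_neurons rules."""
    if n_inputs is None:
        # Neither value given: fall back to the default
        if n_neurons is None:
            return default
        # String options need n_inputs, so only numeric values work here;
        # int() raises ValueError for e.g. 'auto'
        return int(n_neurons)
    # None falls back to 'auto', which is the same as 'double'
    if n_neurons is None or n_neurons in ("auto", "double"):
        return 2 * n_inputs
    if n_neurons == "equal":
        return n_inputs
    if n_neurons == "half":
        return n_inputs // 2  # assumption: integer division
    # Anything else must be convertible to an int (raises ValueError if not)
    return int(n_neurons)
```

For example, with four inputs, None and 'double' both give 8, 'equal' gives 4 and 'half' gives 2.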
- nessai.flows.utils.get_native_flow_class(name)
Get a natively implemented flow class.
- nessai.flows.utils.get_flow_class(name: str)
Get the class to use for the normalizing flow from a string.
- nessai.flows.utils.get_activation_function(name: str) -> Callable
Get the activation function from a string.
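A string-to-function lookup like this is commonly written as a dictionary from names to callables. The sketch below is purely illustrative: the registry, the supported names, the case-insensitive matching and the ValueError on unknown names are all assumptions, and the real function presumably returns callables that act on torch tensors rather than Python floats.

```python
import math
from typing import Callable, Dict

# Hypothetical registry: the names nessai actually accepts may differ
_ACTIVATIONS: Dict[str, Callable[[float], float]] = {
    "relu": lambda x: max(0.0, x),
    "tanh": math.tanh,
    # sigmoid(x) = (1 + tanh(x / 2)) / 2, which does not overflow
    "sigmoid": lambda x: 0.5 * (1.0 + math.tanh(0.5 * x)),
}


def get_activation_function_sketch(name: str) -> Callable[[float], float]:
    """Look up an activation function by case-insensitive name."""
    try:
        return _ACTIVATIONS[name.lower()]
    except KeyError:
        raise ValueError(f"Unknown activation function: {name!r}") from None
```

Keeping the mapping in one module-level dictionary means adding a new activation only requires a new entry, and unknown names fail with a clear error instead of a bare KeyError.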
- nessai.flows.utils.configure_model(config)
Set up the flow from a configuration dictionary.
- nessai.flows.utils.reset_weights(module)
Reset parameters of a given module in place.
Uses the reset_parameters method from torch.nn.Module. Also checks the following modules from glasflow.nflows:
glasflow.nflows.transforms.normalization.BatchNorm
- Parameters:
- module : torch.nn.Module
Module to reset.
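The in-place reset pattern can be sketched with duck typing: walk a module and all of its submodules and call reset_parameters wherever a layer defines it. The name reset_weights_sketch is hypothetical, and the sketch deliberately omits the glasflow BatchNorm special case described above. It relies only on torch.nn.Module.modules(), which yields the module itself followed by every submodule, and on the reset_parameters method that many torch.nn layers (for example torch.nn.Linear) define; torch.nn.Module itself does not provide one, hence the getattr check.

```python
def reset_weights_sketch(module) -> None:
    """Hypothetical: call reset_parameters on a module and its submodules, in place."""
    # modules() yields the module itself and then every nested submodule,
    # so layers at any depth are reached
    for submodule in module.modules():
        reset = getattr(submodule, "reset_parameters", None)
        # Containers and parameter-free layers (e.g. ReLU) have no such method
        if callable(reset):
            reset()
```

With PyTorch installed, calling reset_weights_sketch on torch.nn.Sequential(torch.nn.Linear(2, 4), torch.nn.ReLU()) would re-initialise the Linear layer and skip the Sequential container and the ReLU.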
- nessai.flows.utils.reset_permutations(module)
Resets permutations and linear transforms for a given module in place.
Resets using the original initialisation method. This is needed since they do not have a reset_parameters method.
- Parameters:
- module : torch.nn.Module
Module to reset.
- nessai.flows.utils.create_linear_transform(linear_transform, features)
Function for creating linear transforms.
- Parameters:
- linear_transform : {'permutation', 'lu', 'svd'}
Linear transform to use.
- features : int
Number of features.
- nessai.flows.utils.create_pre_transform(pre_transform, features, **kwargs)
Create a pre-transform.
- Parameters:
- pre_transform : str, {'logit', 'batch_norm'}
Name of the transform.
- features : int
Number of input features.
- kwargs
Keyword arguments passed to the transform class.