nessai.flows.utils
==================

.. py:module:: nessai.flows.utils

.. autoapi-nested-parse::

   Various utilities for implementing normalising flows.

   ..
       !! processed by numpydoc !!


Functions
---------

.. autoapisummary::

   nessai.flows.utils.silu
   nessai.flows.utils.get_base_distribution
   nessai.flows.utils.get_n_neurons
   nessai.flows.utils.get_native_flow_class
   nessai.flows.utils.get_flow_class
   nessai.flows.utils.get_activation_function
   nessai.flows.utils.configure_model
   nessai.flows.utils.reset_weights
   nessai.flows.utils.reset_permutations
   nessai.flows.utils.create_linear_transform
   nessai.flows.utils.create_pre_transform


Module Contents
---------------

.. py:function:: silu(x)

   
   SiLU (Sigmoid-weighted Linear Unit) activation function.

   Also known as swish.

   Elfwing et al. 2017: https://arxiv.org/abs/1702.03118v3
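
   As an illustration of the definition, SiLU multiplies the input by its
   logistic sigmoid, :code:`x * sigmoid(x)`. The scalar helper below,
   :code:`silu_sketch`, is a hypothetical name used only for this sketch; it is
   not the implementation in this module, which is expected to operate on
   tensors.

   .. code-block:: python

       import math


       def silu_sketch(x: float) -> float:
           """Scalar SiLU: x multiplied by the logistic sigmoid of x.

           Illustration only; not guarded against overflow for large negative x.
           """
           sigmoid = 1.0 / (1.0 + math.exp(-x))
           return x * sigmoid


       # SiLU is zero at the origin and approaches the identity for large x.
       for value in (-2.0, 0.0, 2.0):
           print(value, silu_sketch(value))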















   ..
       !! processed by numpydoc !!

.. py:function:: get_base_distribution(n_inputs: int, distribution: Union[str, Type[glasflow.nflows.distributions.Distribution]], **kwargs) -> glasflow.nflows.distributions.Distribution

   
   Get the base distribution for a flow.

   Includes special configuration for certain distributions.

   :Parameters:

       **n_inputs** : int
           Number of inputs to the distribution.

       **distribution** : Union[str, Type[glasflow.nflows.distributions.Distribution]]
           Distribution class or name of a known distribution.

       **kwargs** : Any
           Keyword arguments used when creating an instance of distribution.














   ..
       !! processed by numpydoc !!

.. py:function:: get_n_neurons(n_neurons: Optional[int] = None, n_inputs: Optional[int] = None, default: int = 8) -> int

   
   Get the number of neurons.


   :Parameters:

       **n_neurons** : Optional[int]
           Number of neurons.

       **n_inputs** : Optional[int]
           Number of inputs.

       **default** : int
           Default value if :code:`n_neurons` and :code:`n_inputs` are not given.



   :Returns:

       int
           Number of neurons.




   :Raises:

       ValueError
           Raised if the number of neurons could not be converted to an integer.




   .. rubric:: Notes

   If :code:`n_inputs` is also specified then the options for
   :code:`n_neurons` are either a value that can be converted to an
   :code:`int` or one of the following:

       - :code:`'auto'` or :code:`'double'`: uses twice the number of inputs
       - :code:`'equal'`: uses the number of inputs
       - :code:`'half'`: uses half the number of inputs
       - :code:`None`: falls back to :code:`'auto'`
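
   The rules above can be sketched in plain Python as follows.
   :code:`resolve_n_neurons` is a hypothetical helper written for this
   illustration and is not the library's implementation. Rounding down for
   :code:`'half'` with an odd number of inputs is an assumption, as is the
   behaviour when :code:`n_neurons` is given without :code:`n_inputs`.

   .. code-block:: python

       from typing import Optional, Union


       def resolve_n_neurons(
           n_neurons: Union[int, str, None] = None,
           n_inputs: Optional[int] = None,
           default: int = 8,
       ) -> int:
           """Hypothetical restatement of the documented rules."""
           if n_inputs is None:
               # Keywords cannot be resolved without n_inputs: use the value
               # given, or the default if nothing was given.
               value = default if n_neurons is None else n_neurons
           elif n_neurons is None or n_neurons in ("auto", "double"):
               value = 2 * n_inputs
           elif n_neurons == "equal":
               value = n_inputs
           elif n_neurons == "half":
               # Assumption: halving rounds down for odd n_inputs.
               value = n_inputs // 2
           else:
               value = n_neurons
           try:
               return int(value)
           except (TypeError, ValueError) as exc:
               raise ValueError(
                   f"Could not convert {value!r} to an integer"
               ) from exc


       print(resolve_n_neurons(n_inputs=4))             # None falls back to 'auto'
       print(resolve_n_neurons("equal", n_inputs=4))
       print(resolve_n_neurons("half", n_inputs=4))
       print(resolve_n_neurons())                       # neither given: default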



   ..
       !! processed by numpydoc !!

.. py:function:: get_native_flow_class(name)

   
   Get a natively implemented flow class.
















   ..
       !! processed by numpydoc !!

.. py:function:: get_flow_class(name: str)

   
   Get the class to use for the normalizing flow from a string.
















   ..
       !! processed by numpydoc !!

.. py:function:: get_activation_function(name: str) -> Callable

   
   Get the activation function from a string.
















   ..
       !! processed by numpydoc !!

.. py:function:: configure_model(config)

   
   Set up the flow from a configuration dictionary.
















   ..
       !! processed by numpydoc !!

.. py:function:: reset_weights(module)

   
   Reset parameters of a given module in place.

   Uses the ``reset_parameters`` method from ``torch.nn.Module``.

   Also checks the following modules from glasflow.nflows:

   - glasflow.nflows.transforms.normalization.BatchNorm

   :Parameters:

       **module** : :obj:`torch.nn.Module`
           Module to reset
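
   The general pattern described above, calling ``reset_parameters`` only when
   a module provides it, can be sketched without PyTorch. Both
   ``StandInLayer`` and ``reset_if_possible`` are hypothetical names used for
   illustration; they are not part of nessai or torch, and the sketch does not
   cover the additional glasflow checks.

   .. code-block:: python

       class StandInLayer:
           """Hypothetical stand-in for a layer with ``reset_parameters``."""

           def __init__(self) -> None:
               self.weight = 0.0
               self.reset_parameters()

           def reset_parameters(self) -> None:
               # A real layer would re-draw random weights; a fixed value
               # keeps this illustration deterministic.
               self.weight = 1.0


       def reset_if_possible(module: object) -> bool:
           """Reset ``module`` in place if it defines ``reset_parameters``."""
           reset = getattr(module, "reset_parameters", None)
           if callable(reset):
               reset()
               return True
           return False


       layer = StandInLayer()
       layer.weight = 5.0                  # pretend training changed the weight
       print(reset_if_possible(layer), layer.weight)
       print(reset_if_possible(object()))  # nothing to reset on a plain object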














   ..
       !! processed by numpydoc !!

.. py:function:: reset_permutations(module)

   
   Resets permutations and linear transforms for a given module in place.

   Resets using the original initialisation method. This is needed since they
   do not have a ``reset_parameters`` method.

   :Parameters:

       **module** : :obj:`torch.nn.Module`
           Module to reset














   ..
       !! processed by numpydoc !!

.. py:function:: create_linear_transform(linear_transform, features)

   
   Function for creating linear transforms.


   :Parameters:

       **linear_transform** : {'permutation', 'lu', 'svd'}
           Linear transform to use.

       **features** : int
           Number of features.














   ..
       !! processed by numpydoc !!

.. py:function:: create_pre_transform(pre_transform, features, **kwargs)

   
   Create a pre transform.


   :Parameters:

       **pre_transform** : {'logit', 'batch_norm'}
           Name of the transform.

       **features** : int
           Number of input features

       **kwargs**
           Keyword arguments passed to the transform class.














   ..
       !! processed by numpydoc !!

