Source code for mlreco.models.layers.common.activation_normalization_factories

"""
Contains factories for activation functions and normalization layers.
"""

def activations_dict():
    """Return the registry of supported activation layers.

    Maps each short, lower-case activation name to the class that builds
    that activation. Most entries come straight from MinkowskiEngine;
    ``'lrelu'``, ``'mish'`` and ``'elu'`` come from the package-local
    ``nonlinearities`` module. A new dict is built on every call, so callers
    may mutate the result freely.

    NOTE(review): the imports live inside the function, presumably so that
    importing this module does not itself require MinkowskiEngine — confirm.
    """
    import MinkowskiEngine as ME
    from . import nonlinearities

    return dict(
        relu=ME.MinkowskiReLU,
        lrelu=nonlinearities.MinkowskiLeakyReLU,
        prelu=ME.MinkowskiPReLU,
        selu=ME.MinkowskiSELU,
        celu=ME.MinkowskiCELU,
        mish=nonlinearities.MinkowskiMish,
        elu=nonlinearities.MinkowskiELU,
        tanh=ME.MinkowskiTanh,
        sigmoid=ME.MinkowskiSigmoid,
    )
def activations_construct(name, *args, **kwargs):
    """Instantiate an activation layer by its registry name.

    Parameters
    ----------
    name : str
        Key into :func:`activations_dict` (e.g. ``'relu'``, ``'lrelu'``).
    *args, **kwargs
        Forwarded unchanged to the selected activation class constructor.
        Positional arguments are accepted for consistency with
        :func:`normalizations_construct`.

    Returns
    -------
    object
        A new instance of the activation class registered under ``name``.

    Raises
    ------
    ValueError
        If ``name`` is not a registered activation. ``ValueError`` subclasses
        ``Exception``, so existing ``except Exception`` handlers still catch it.
    """
    activations = activations_dict()
    # Keep the lookup separate from construction so that a KeyError raised
    # inside a constructor is not misreported as an unknown name.
    try:
        activation_cls = activations[name]
    except KeyError:
        raise ValueError(
            f"Unknown activation function name provided: {name!r} "
            f"(expected one of {sorted(activations)})"
        ) from None
    return activation_cls(*args, **kwargs)
def normalizations_dict():
    """Return the registry of supported normalization layers.

    Maps each normalization name to the class that builds it:

    * ``'none'``          -> ``Identity`` from the package-local ``blocks``
      module (presumably a pass-through layer — confirm),
    * ``'batch_norm'``    -> ``ME.MinkowskiBatchNorm``,
    * ``'instance_norm'`` -> ``ME.MinkowskiInstanceNorm``,
    * ``'pixel_norm'``    -> ``MinkowskiPixelNorm`` from the package-local
      ``normalizations`` module.

    A new dict is built on every call.
    """
    import MinkowskiEngine as ME
    from . import normalizations
    from .blocks import Identity

    return dict(
        none=Identity,
        batch_norm=ME.MinkowskiBatchNorm,
        instance_norm=ME.MinkowskiInstanceNorm,
        pixel_norm=normalizations.MinkowskiPixelNorm,
    )
def normalizations_construct(name, *args, **kwargs):
    """Instantiate a normalization layer by its registry name.

    Parameters
    ----------
    name : str
        Key into :func:`normalizations_dict` (e.g. ``'batch_norm'``,
        ``'none'``).
    *args, **kwargs
        Forwarded unchanged to the selected normalization class constructor.

    Returns
    -------
    object
        A new instance of the normalization class registered under ``name``.

    Raises
    ------
    ValueError
        If ``name`` is not a registered normalization. ``ValueError``
        subclasses ``Exception``, so existing ``except Exception`` handlers
        still catch it.
    """
    norm_layers = normalizations_dict()
    # Keep the lookup separate from construction so that a KeyError raised
    # inside a constructor is not misreported as an unknown name.
    try:
        norm_cls = norm_layers[name]
    except KeyError:
        raise ValueError(
            f"Unknown normalization layer name provided: {name!r} "
            f"(expected one of {sorted(norm_layers)})"
        ) from None
    return norm_cls(*args, **kwargs)