"""
Contains factories for activation functions and normalization layers.
"""
def activations_dict():
    """Return the registry of supported sparse activation layers.

    The mapping goes from a short, config-friendly name (e.g. ``'relu'``,
    ``'lrelu'``) to the *class* (not an instance) implementing that
    activation. Most entries come from MinkowskiEngine directly. ``'lrelu'``,
    ``'mish'`` and ``'elu'`` come from the local ``nonlinearities`` module.

    A fresh dictionary is built on every call, so callers may modify the
    result without affecting other callers.

    Returns:
        dict: activation name -> activation layer class.
    """
    # Imports are kept local, presumably so this module can be imported
    # without MinkowskiEngine installed -- NOTE(review): confirm intent.
    from . import nonlinearities
    import MinkowskiEngine as ME

    return {
        'relu': ME.MinkowskiReLU,
        'lrelu': nonlinearities.MinkowskiLeakyReLU,
        'prelu': ME.MinkowskiPReLU,
        'selu': ME.MinkowskiSELU,
        'celu': ME.MinkowskiCELU,
        'mish': nonlinearities.MinkowskiMish,
        'elu': nonlinearities.MinkowskiELU,
        'tanh': ME.MinkowskiTanh,
        'sigmoid': ME.MinkowskiSigmoid,
    }
def activations_construct(name, *args, **kwargs):
    """Instantiate the activation layer registered under ``name``.

    Args:
        name: Key into ``activations_dict()`` (e.g. ``'relu'``, ``'lrelu'``).
        *args: Positional arguments forwarded to the layer constructor.
        **kwargs: Keyword arguments forwarded to the layer constructor.

    Returns:
        A new instance of the selected activation layer class.

    Raises:
        ValueError: If ``name`` is not a registered activation. ``ValueError``
            subclasses ``Exception``, so existing ``except Exception``
            handlers keep working.
    """
    activations = activations_dict()
    try:
        layer_cls = activations[name]
    except KeyError:
        # Report the offending name and the valid choices; `from None` hides
        # the internal KeyError, which adds no information.
        raise ValueError(
            "Unknown activation function name provided: "
            f"{name!r} (expected one of {sorted(activations)})"
        ) from None
    # The constructor call stays outside the try so that a KeyError raised
    # inside the layer's own __init__ is not misreported as an unknown name.
    return layer_cls(*args, **kwargs)
def normalizations_dict():
    """Return the registry of supported sparse normalization layers.

    The mapping goes from a short, config-friendly name to the layer *class*:

    - ``'none'``: the local ``Identity`` block (no normalization).
    - ``'batch_norm'``, ``'instance_norm'``: MinkowskiEngine layers.
    - ``'pixel_norm'``: from the local ``normalizations`` module.

    A fresh dictionary is built on every call.

    Returns:
        dict: normalization name -> normalization layer class.
    """
    # Local imports, presumably to defer the MinkowskiEngine dependency
    # until a registry is actually requested -- NOTE(review): confirm intent.
    from .blocks import Identity
    from . import normalizations
    import MinkowskiEngine as ME

    return {
        'none': Identity,
        'batch_norm': ME.MinkowskiBatchNorm,
        'instance_norm': ME.MinkowskiInstanceNorm,
        'pixel_norm': normalizations.MinkowskiPixelNorm,
    }
def normalizations_construct(name, *args, **kwargs):
    """Instantiate the normalization layer registered under ``name``.

    Args:
        name: Key into ``normalizations_dict()`` (e.g. ``'batch_norm'``,
            ``'none'``).
        *args: Positional arguments forwarded to the layer constructor.
        **kwargs: Keyword arguments forwarded to the layer constructor.

    Returns:
        A new instance of the selected normalization layer class.

    Raises:
        ValueError: If ``name`` is not a registered normalization.
            ``ValueError`` subclasses ``Exception``, so existing
            ``except Exception`` handlers keep working.
    """
    norm_layers = normalizations_dict()
    try:
        layer_cls = norm_layers[name]
    except KeyError:
        # Report the offending name and the valid choices; `from None` hides
        # the internal KeyError, which adds no information.
        raise ValueError(
            "Unknown normalization layer name provided: "
            f"{name!r} (expected one of {sorted(norm_layers)})"
        ) from None
    # The constructor call stays outside the try so that a KeyError raised
    # inside the layer's own __init__ is not misreported as an unknown name.
    return layer_cls(*args, **kwargs)