mlreco.models.bayes_uresnet module
-
class mlreco.models.bayes_uresnet.BayesianUResNet(cfg, name='mcdropout_uresnet')
Bases: torch.nn.modules.module.Module
UResNet with Uncertainty Quantification
The backbone is a UResNet encoder-decoder that uses standard residual layers in the shallow half of the network and dropout residual layers in the deep half.
- Configuration
mode (str) – string indicator for slight changes in network behavior/architecture. Supports three options:
- standard: standard dropout segmentation network. This also includes the MCDropout segnet, since training behavior is identical for both standard and MCDropout networks.
- evd: changes the network into an evidential segmentation network.
num_samples (int) – if used as MCDropout Segnet, the number of stochastic forward samples to be taken.
num_classes (int) – number of segmentation classes (default: 5)
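As an illustration only, the three documented keys could be collected in a Python dictionary like the one below. Where this fragment sits inside the full cfg is not stated on this page, so the variable name `cfg_fragment` and the chosen values (other than the documented default of 5 classes) are assumptions.

```python
# Hypothetical configuration fragment using only the keys documented above.
# The nesting of this fragment inside the full `cfg` is not specified here.
cfg_fragment = {
    "mode": "standard",   # or "evd" for the evidential segmentation network
    "num_samples": 20,    # illustrative value; only used for MC Dropout sampling
    "num_classes": 5,     # documented default
}

for key, value in cfg_fragment.items():
    print(f"{key}: {value}")
```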
-
MODULES = []
-
__init__(cfg, name='mcdropout_uresnet')
Initializes internal Module state, shared by both nn.Module and ScriptModule.
-
mc_forward(input, num_samples=None)
Forwarding operation for MC Dropout segmentation network.
- Parameters
num_samples – number of stochastic forward samples to be taken
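The idea behind taking several stochastic forward samples can be sketched with a toy network. This is not the implementation of mc_forward: `TinyDropoutNet` and `mc_dropout_predict` are hypothetical names, and averaging softmax scores and reporting predictive entropy is one common way to summarize the samples, assumed here for illustration.

```python
import torch
import torch.nn as nn


class TinyDropoutNet(nn.Module):
    """Hypothetical stand-in for a dropout segmentation network."""

    def __init__(self, num_features=4, num_classes=5, p=0.5):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(num_features, 16),
            nn.ReLU(),
            nn.Dropout(p=p),
            nn.Linear(16, num_classes),
        )

    def forward(self, x):
        return self.net(x)


def mc_dropout_predict(model, x, num_samples=10):
    """Average softmax scores over `num_samples` stochastic forward passes."""
    model.eval()
    # Re-enable only the dropout layers so every pass draws a fresh mask.
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.train()
    with torch.no_grad():
        probs = torch.stack(
            [torch.softmax(model(x), dim=1) for _ in range(num_samples)]
        )
    mean_probs = probs.mean(dim=0)
    # Predictive entropy of the averaged scores as a simple uncertainty measure.
    entropy = -(mean_probs * torch.log(mean_probs.clamp_min(1e-12))).sum(dim=1)
    return mean_probs, entropy


torch.manual_seed(0)
model = TinyDropoutNet()
x = torch.randn(8, 4)  # 8 points with 4 features each
mean_probs, entropy = mc_dropout_predict(model, x, num_samples=20)
print(mean_probs.shape, entropy.shape)
```

Larger values of num_samples give a smoother average at the cost of proportionally more forward passes.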
-
__module__ = 'mlreco.models.bayes_uresnet'
-
training: bool
-
class mlreco.models.bayes_uresnet.DUQUResNet(cfg, name='duq_uresnet')
Bases: torch.nn.modules.module.Module
Single Pass Deep Uncertainty Quantification Network
Original Paper: https://arxiv.org/abs/2003.02037
Implementation adapted from the DUQ main GitHub repository: https://github.com/y0ast/deterministic-uncertainty-quantification
Author: Joost van Amersfoort
-
MODULES = []
-
__init__(cfg, name='duq_uresnet')
Initializes internal Module state, shared by both nn.Module and ScriptModule.
-
forward(input)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
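The difference between calling the instance and calling forward directly can be seen with a plain torch.nn.Linear layer and a forward hook. The layer and the `calls` list are illustrative and unrelated to DUQUResNet itself.

```python
import torch
import torch.nn as nn

calls = []  # records every time the hook fires

layer = nn.Linear(3, 2)
# The hook returns None (list.append returns None), so the output is unchanged.
handle = layer.register_forward_hook(
    lambda module, inputs, output: calls.append("hook")
)

x = torch.zeros(1, 3)

layer(x)                      # calling the instance runs registered hooks
n_after_call = len(calls)

layer.forward(x)              # calling forward directly bypasses the hooks
n_after_forward = len(calls)

handle.remove()
print(n_after_call, n_after_forward)
```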
-
__module__ = 'mlreco.models.bayes_uresnet'
-
training: bool
-
class mlreco.models.bayes_uresnet.SegmentationLoss(cfg, name='mcdropout_uresnet')
Bases: torch.nn.modules.module.Module
-
__init__(cfg, name='mcdropout_uresnet')
Initializes internal Module state, shared by both nn.Module and ScriptModule.
-
forward(outputs, label, iteration=0, weight=None)
segmentation[0], label and weight are lists of size #gpus = batch_size. segmentation has as many elements as UResNet returns. label[0] has shape (N, dim + batch_id + 1), where N is the number of points across minibatch_size events.
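The label layout described above can be illustrated with a small synthetic tensor. This sketch assumes that "dim + batch_id + 1" means dim coordinate columns, one batch-id column, and a final class column, and that the columns appear in that order. The column order, the random data, and the use of plain cross-entropy are assumptions, not the behavior of SegmentationLoss.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

num_classes = 5
dim = 3       # spatial dimension
N = 10        # number of points across all events in the minibatch

# Assumed column layout: [coordinates..., batch_id, class label]
coords = torch.randint(0, 64, (N, dim)).float()
batch_id = torch.randint(0, 2, (N, 1)).float()
classes = torch.randint(0, num_classes, (N, 1)).float()
label = torch.cat([coords, batch_id, classes], dim=1)  # shape (N, dim + 1 + 1)

# Hypothetical per-point class scores standing in for the network output.
logits = torch.randn(N, num_classes)

target = label[:, -1].long()           # class label taken from the last column
loss = F.cross_entropy(logits, target)
accuracy = (logits.argmax(dim=1) == target).float().mean()
print(label.shape, loss.item(), accuracy.item())
```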
-
__module__ = 'mlreco.models.bayes_uresnet'
-
training: bool
-
class mlreco.models.bayes_uresnet.DUQSegmentationLoss(cfg, name='duq_uresnet')
Bases: torch.nn.modules.module.Module
-
__init__(cfg, name='duq_uresnet')
Initializes internal Module state, shared by both nn.Module and ScriptModule.
-
__module__ = 'mlreco.models.bayes_uresnet'
-
static calc_gradient_penalty(x, y_pred)
Code from the DUQ main GitHub repository: https://github.com/y0ast/deterministic-uncertainty-quantification
Author: Joost van Amersfoort
-
training: bool
-
forward(out, type_labels)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
-