mlreco.models.bayes_uresnet module

class mlreco.models.bayes_uresnet.BayesianUResNet(cfg, name='mcdropout_uresnet')[source]

Bases: torch.nn.modules.module.Module

UResNet with Uncertainty Quantification

The backbone model is a UResNet encoder-decoder, with standard residual layers in the shallow half of the network and dropout residual layers in the deep half.

Configuration
  • mode (str) – string indicator for slight changes in network behavior/architecture. Supported options:

    • standard: standard dropout segmentation network. This also includes the MCDropout segnet, since training behavior is identical for both standard and MCDropout networks.

    • evd: changes the network into an evidential segmentation network.

  • num_samples (int) – if used as MCDropout Segnet, the number of stochastic forward samples to be taken.

  • num_classes (int) – number of segmentation classes (default: 5)
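
The options above might be collected in a configuration fragment like the one below. This is only a sketch: the key names mode, num_samples and num_classes come from the list above, while the variable name bayes_uresnet_cfg, the value chosen for num_samples, and the way this fragment nests inside the full cfg passed to the constructor are assumptions.

```python
# Hypothetical configuration fragment; only the three key names are documented.
bayes_uresnet_cfg = {
    "mode": "standard",   # or "evd" for the evidential segmentation network
    "num_samples": 20,    # assumed value; only relevant for MC Dropout sampling
    "num_classes": 5,     # documented default
}

# Light sanity checks a caller might run before building the model.
assert bayes_uresnet_cfg["mode"] in {"standard", "evd"}
assert bayes_uresnet_cfg["num_samples"] > 0
```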

MODULES = []
__init__(cfg, name='mcdropout_uresnet')[source]

Initializes internal Module state, shared by both nn.Module and ScriptModule.

mc_forward(input, num_samples=None)[source]

Forward pass for the MC Dropout segmentation network.

Parameters

num_samples – number of stochastic forward samples to be taken
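
As a rough illustration of what taking num_samples stochastic forward samples means, the sketch below averages class probabilities over repeated passes with dropout left active. It is a dependency-free toy in plain Python, not the module's implementation: the linear "network", the helper names and the use of predictive entropy as the uncertainty score are all assumptions made for the example.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def stochastic_logits(features, weights, drop_prob, rng):
    """One forward pass of a toy linear classifier with dropout on its inputs."""
    keep = 1.0 - drop_prob
    # Inverted dropout: zero each feature with probability drop_prob, rescale the rest.
    dropped = [f / keep if rng.random() < keep else 0.0 for f in features]
    return [sum(w * f for w, f in zip(row, dropped)) for row in weights]


def mc_dropout_predict(features, weights, num_samples, drop_prob=0.5, seed=0):
    """Average class probabilities over num_samples stochastic passes."""
    rng = random.Random(seed)
    mean_probs = [0.0] * len(weights)
    for _ in range(num_samples):
        probs = softmax(stochastic_logits(features, weights, drop_prob, rng))
        mean_probs = [m + p / num_samples for m, p in zip(mean_probs, probs)]
    # Predictive entropy of the averaged distribution as an uncertainty score.
    entropy = -sum(p * math.log(p) for p in mean_probs if p > 0.0)
    return mean_probs, entropy


weights = [[1.0, -0.5, 0.2], [-0.3, 0.8, 0.1]]  # toy 2-class, 3-feature model
mean_probs, entropy = mc_dropout_predict([0.4, 1.2, -0.7], weights, num_samples=50)
```

More samples give a smoother estimate of the averaged probabilities at the cost of proportionally more forward passes.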

evidential_forward(input)[source]

Forward pass for the evidential segmentation network.
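
The documentation does not say which evidential formulation is used. The sketch below assumes the common Dirichlet-based one, in which non-negative per-class evidence e_k gives concentration parameters alpha_k = e_k + 1, expected probabilities alpha_k / S and an uncertainty of K / S, with S the sum of the alpha_k and K the number of classes. The choice of softplus as the evidence function and all function names are assumptions for illustration.

```python
import math


def softplus(x):
    """Smooth non-negative activation, log(1 + exp(x)), computed stably."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def evidential_outputs(logits):
    """Map raw per-class outputs to Dirichlet-based probabilities and uncertainty."""
    evidence = [softplus(v) for v in logits]   # non-negative evidence per class
    alpha = [e + 1.0 for e in evidence]        # Dirichlet concentration parameters
    strength = sum(alpha)                      # Dirichlet strength S
    probs = [a / strength for a in alpha]      # expected class probabilities
    uncertainty = len(logits) / strength       # K / S, in (0, 1]
    return probs, uncertainty


# Strong evidence for class 0 versus almost no evidence for any class.
confident = evidential_outputs([12.0, -8.0, -8.0, -8.0, -8.0])
unsure = evidential_outputs([-8.0] * 5)
```

With almost no evidence the uncertainty approaches 1, and it shrinks as evidence accumulates for any class.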

standard_forward(input)[source]

Forward pass for the standard dropout segmentation network.

forward(input)[source]
__module__ = 'mlreco.models.bayes_uresnet'
training: bool
class mlreco.models.bayes_uresnet.DUQUResNet(cfg, name='duq_uresnet')[source]

Bases: torch.nn.modules.module.Module

Single Pass Deep Uncertainty Quantification Network

Original Paper: https://arxiv.org/abs/2003.02037

Implementation adapted from the DUQ main GitHub repository: https://github.com/y0ast/deterministic-uncertainty-quantification

Author: Joost van Amersfoort
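
For orientation, the DUQ paper scores each class with an RBF kernel between a class-specific embedding W_c f(x) and a learned class centroid e_c, using the mean squared distance over embedding dimensions and a length scale sigma. The plain-Python sketch below illustrates that scoring rule only. How the embed and bilinear methods of this class map onto it is an assumption, and the function names are hypothetical.

```python
import math


def rbf_class_scores(embedding_per_class, centroids, length_scale):
    """RBF kernel between each class-specific embedding and that class's centroid.

    embedding_per_class[c] plays the role of W_c f(x); centroids[c] is e_c.
    """
    scores = []
    for z_c, e_c in zip(embedding_per_class, centroids):
        mean_sq_dist = sum((z - e) ** 2 for z, e in zip(z_c, e_c)) / len(z_c)
        scores.append(math.exp(-mean_sq_dist / (2.0 * length_scale ** 2)))
    return scores


def predict_with_certainty(scores):
    """Predicted class is the closest centroid; its kernel value is the certainty."""
    best = max(range(len(scores)), key=scores.__getitem__)
    return best, scores[best]


centroids = [[0.0, 0.0], [2.0, 2.0]]
near_class_1 = [[1.9, 2.1], [1.9, 2.1]]  # same toy embedding reused for both classes
scores = rbf_class_scores(near_class_1, centroids, length_scale=0.5)
label, certainty = predict_with_certainty(scores)
```

An input far from every centroid gets a low score for all classes, which is how a single deterministic pass can flag it as uncertain.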

MODULES = []
__init__(cfg, name='duq_uresnet')[source]

Initializes internal Module state, shared by both nn.Module and ScriptModule.

embed(x)[source]
bilinear(z)[source]
forward(input)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

update_buffers()[source]
__module__ = 'mlreco.models.bayes_uresnet'
training: bool
class mlreco.models.bayes_uresnet.SegmentationLoss(cfg, name='mcdropout_uresnet')[source]

Bases: torch.nn.modules.module.Module

__init__(cfg, name='mcdropout_uresnet')[source]

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(outputs, label, iteration=0, weight=None)[source]

segmentation[0], label, and weight are lists whose length is the number of GPUs (= batch_size). segmentation has as many elements as UResNet returns. label[0] has shape (N, dim + batch_id + 1), where N is the total number of points across minibatch_size events.
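
To make the label shape concrete, the sketch below splits a toy label[0] with dim = 3 into coordinates, batch ids and class labels, then counts points per event. The column order (coordinates first, then batch id, then class label) is an assumption, since the docstring gives only the total width, and the helper names are hypothetical.

```python
def split_label_rows(rows, dim=3):
    """Split rows of length dim + 2 into coordinates, batch ids, and class labels.

    Assumed column order: [coords..., batch_id, class_label].
    """
    coords = [row[:dim] for row in rows]
    batch_ids = [int(row[dim]) for row in rows]
    classes = [int(row[dim + 1]) for row in rows]
    return coords, batch_ids, classes


def points_per_event(batch_ids):
    """Count how many of the N points belong to each event in the minibatch."""
    counts = {}
    for b in batch_ids:
        counts[b] = counts.get(b, 0) + 1
    return counts


# Toy label[0] with N = 4 points, dim = 3, spread over 2 events.
label_0 = [
    [10.0, 4.0, 7.0, 0.0, 2.0],
    [11.0, 4.0, 7.0, 0.0, 2.0],
    [3.0, 9.0, 1.0, 1.0, 0.0],
    [3.0, 9.0, 2.0, 1.0, 4.0],
]
coords, batch_ids, classes = split_label_rows(label_0)
counts = points_per_event(batch_ids)
```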

__module__ = 'mlreco.models.bayes_uresnet'
training: bool
class mlreco.models.bayes_uresnet.DUQSegmentationLoss(cfg, name='duq_uresnet')[source]

Bases: torch.nn.modules.module.Module

__init__(cfg, name='duq_uresnet')[source]

Initializes internal Module state, shared by both nn.Module and ScriptModule.

__module__ = 'mlreco.models.bayes_uresnet'
static calc_gradient_penalty(x, y_pred)[source]

Code from the DUQ main GitHub repository: https://github.com/y0ast/deterministic-uncertainty-quantification

Author: Joost van Amersfoort
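
The DUQ paper regularizes training with a two-sided gradient penalty that pushes the L2 norm of the input gradient of the summed class outputs towards 1. The sketch below shows only the penalty arithmetic, assuming the per-sample gradients are already available as lists. In the real method they would come from automatic differentiation of y_pred with respect to x, and the function name here is hypothetical.

```python
import math


def two_sided_gradient_penalty(gradients):
    """Mean over samples of (||g||_2 - 1)^2, where g is d(sum of outputs)/d(input)."""
    penalties = []
    for g in gradients:
        norm = math.sqrt(sum(v * v for v in g))
        penalties.append((norm - 1.0) ** 2)
    return sum(penalties) / len(penalties)


# Per-sample input gradients, written by hand here instead of coming from autograd.
grads = [[0.6, 0.8], [3.0, 4.0], [0.0, 0.0]]
penalty = two_sided_gradient_penalty(grads)
```

The penalty is two-sided because gradients with norm below 1 are penalized as well as those above it, so a vanishing gradient is not a free minimum.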

training: bool
forward(out, type_labels)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.