mlreco.models.uresnet module
-
class mlreco.models.uresnet.UResNet_Chain(cfg, name='uresnet_lonely')
Bases: torch.nn.modules.module.Module
UResNet implementation. A typical configuration looks like:
model:
  name: uresnet
  modules:
    uresnet_lonely:
      # Your config here
- Configuration
data_dim (int, default 3)
num_input (int, default 1)
allow_bias (bool, default False)
spatial_size (int, default 512)
leakiness (float, default 0.33)
activation (dict) – Activation function, defaults to {'name': 'lrelu', 'args': {}}
norm_layer (dict) – Normalization function, defaults to {'name': 'batch_norm', 'args': {}}
depth (int, default 5) – Depth of UResNet, also corresponds to how many times we down/upsample.
filters (int, default 16) – Number of filters in the first convolution of UResNet. The filter count increases linearly with depth.
reps (int, default 2) – Convolution block repetition factor.
input_kernel (int, default 3) – Kernel size (receptive field) of the very first convolution after the input layer.
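num_classes (int, default 5) – Number of semantic classes. When ghost=True, this does not count the ghost class.
ghost (bool, default False) – Whether to also perform ghost/non-ghost segmentation.
ghost_label (int, default -1) – If greater than -1, only ghost segmentation is performed.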
weight_loss (bool, default False) – Whether to weight the loss using class counts.
alpha (float, default 1.0) – Weight for UResNet semantic segmentation loss.
beta (float, default 1.0) – Weight for ghost/non-ghost segmentation loss.
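As a rough illustration of how the defaults listed above combine with a user configuration, the sketch below merges a uresnet_lonely block over those defaults and computes per-level filter counts. The helpers resolve_uresnet_config and filter_schedule are hypothetical (not part of mlreco), and the schedule filters * (level + 1) is an assumption about what "increases linearly with depth" means.

```python
from copy import deepcopy

# Defaults as documented in the Configuration list for uresnet_lonely.
URESNET_DEFAULTS = {
    "data_dim": 3,
    "num_input": 1,
    "allow_bias": False,
    "spatial_size": 512,
    "leakiness": 0.33,
    "activation": {"name": "lrelu", "args": {}},
    "norm_layer": {"name": "batch_norm", "args": {}},
    "depth": 5,
    "filters": 16,
    "reps": 2,
    "input_kernel": 3,
    "num_classes": 5,
    "ghost": False,
    "ghost_label": -1,
    "weight_loss": False,
    "alpha": 1.0,
    "beta": 1.0,
}


def resolve_uresnet_config(cfg):
    """Hypothetical helper: overlay the user's uresnet_lonely block on the defaults."""
    # An empty YAML block ("uresnet_lonely:" with only a comment) parses to None.
    user = cfg["modules"].get("uresnet_lonely") or {}
    resolved = deepcopy(URESNET_DEFAULTS)  # never mutate the shared defaults
    resolved.update(user)
    return resolved


def filter_schedule(filters, depth):
    """Assumed linear growth: filters * (level + 1) at each depth level."""
    return [filters * (level + 1) for level in range(depth)]


# Mirrors the `model` block of the YAML configuration.
cfg = {"name": "uresnet", "modules": {"uresnet_lonely": {"depth": 4, "ghost": True}}}
params = resolve_uresnet_config(cfg)
print(params["depth"], params["filters"], params["ghost"])  # 4 16 True
print(filter_schedule(params["filters"], params["depth"]))  # [16, 32, 48, 64]
```

Keys the user omits (here filters, reps, etc.) fall back to the documented defaults, while explicitly set keys such as depth override them.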
- Output
segmentation (torch.Tensor)
finalTensor (torch.Tensor)
encoderTensors (list of torch.Tensor)
decoderTensors (list of torch.Tensor)
ghost (torch.Tensor)
ghost_sptensor (torch.Tensor)
-
INPUT_SCHEMA = [['parse_sparse3d_scn', (<class 'float'>,), (3, 1)]]
-
MODULES = ['uresnet_lonely']
-
__init__(cfg, name='uresnet_lonely') Initializes internal Module state, shared by both nn.Module and ScriptModule.
-
forward(input) Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
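The note above describes general nn.Module behavior rather than anything specific to UResNet_Chain. A minimal demonstration with a plain nn.Linear layer (used only for brevity) shows that a registered forward hook fires when the module instance is called, but not when forward() is invoked directly:

```python
import torch
from torch import nn

calls = []

layer = nn.Linear(4, 2)
# Forward hooks receive (module, inputs, output) and run only via module(...).
handle = layer.register_forward_hook(lambda module, inputs, output: calls.append("hook"))

x = torch.randn(3, 4)
y_call = layer(x)            # runs forward() and then the registered hook
y_direct = layer.forward(x)  # runs forward() only; the hook is skipped

print(calls)                          # ['hook']
print(torch.equal(y_call, y_direct))  # True: same computation, hooks differ
handle.remove()
```

The outputs are identical; only the side effects of the hook differ, which is why silently bypassing hooks can be hard to notice.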
-
__module__ = 'mlreco.models.uresnet'
-
training: bool
-
class
mlreco.models.uresnet.SegmentationLoss(cfg, reduction='sum', batch_col=0)
Bases: torch.nn.modules.loss._Loss
Loss definition for UResNet.
For a regular flavor UResNet, it is a cross-entropy loss. For deghosting, it depends on a configuration parameter ghost:
If ghost=True, we first compute the cross-entropy loss on the ghost point classification (weighted on the fly with sample statistics). Then we compute a mask selecting all non-ghost points (based on the true labels) and, within this mask, compute a cross-entropy loss over the remaining classes.
If ghost=False, we compute an (N+1)-class cross-entropy loss, where N is the number of classes not counting the ghost point class.
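The two flavors described above can be sketched in PyTorch as follows. This is a hedged approximation, not the SegmentationLoss source: the tensor layout (separate (N, num_classes) semantic scores and (N, 2) ghost scores), the inverse-frequency class weights standing in for "weighted on the fly with sample statistics", the mean reduction, and combining the two terms as alpha * semantic + beta * ghost are all assumptions. The function names deghosting_loss and ghost_as_class_loss are hypothetical.

```python
import torch
import torch.nn.functional as F


def deghosting_loss(seg_logits, ghost_logits, labels, num_classes, alpha=1.0, beta=1.0):
    """Sketch of the ghost=True flavor.

    seg_logits:   (N, num_classes) scores for the semantic classes
    ghost_logits: (N, 2) scores for non-ghost (0) vs ghost (1)
    labels:       (N,) integer labels; the ghost label is num_classes (last index)
    """
    ghost_mask = labels == num_classes
    ghost_target = ghost_mask.long()

    # Assumed on-the-fly weighting: inverse frequency of each class in this sample.
    counts = torch.bincount(ghost_target, minlength=2).clamp(min=1).float()
    class_weights = ghost_target.numel() / (2.0 * counts)
    ghost_loss = F.cross_entropy(ghost_logits, ghost_target, weight=class_weights)

    # Semantic loss restricted to points whose true label is not ghost.
    nonghost = ~ghost_mask
    if nonghost.any():
        seg_loss = F.cross_entropy(seg_logits[nonghost], labels[nonghost])
    else:
        seg_loss = seg_logits.sum() * 0.0  # no non-ghost points: zero, still differentiable
    return alpha * seg_loss + beta * ghost_loss, seg_loss, ghost_loss


def ghost_as_class_loss(logits, labels):
    """ghost=False flavor: one (N+1)-class cross-entropy, ghost is the last index."""
    return F.cross_entropy(logits, labels)


torch.manual_seed(0)
num_classes = 5
labels = torch.tensor([0, 3, 5, 5, 1, 4])  # 5 == num_classes marks ghost points
seg_logits = torch.randn(6, num_classes)
ghost_logits = torch.randn(6, 2)
total, seg_loss, ghost_loss = deghosting_loss(seg_logits, ghost_logits, labels, num_classes)
flat = ghost_as_class_loss(torch.randn(6, num_classes + 1), labels)
```

The masking step is what distinguishes the two flavors: with ghost=True, ghost points never contribute to the semantic term, whereas with ghost=False they are simply one more class.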
See also
-
INPUT_SCHEMA = [['parse_sparse3d_scn', (<class 'int'>,), (3, 1)]]
-
__init__(cfg, reduction='sum', batch_col=0) Initializes internal Module state, shared by both nn.Module and ScriptModule.
-
__module__ = 'mlreco.models.uresnet'
-
forward(result, label, weights=None) result[0], label and weights are lists whose length is the number of GPUs (= batch_size). segmentation has as many elements as UResNet returns. label[0] has shape (N, 1), where N is the number of points across the minibatch_size events.
The ghost label is the last class index. If ghost=True, num_classes should not count the ghost class. If ghost_label > -1, only ghost segmentation is performed.
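The per-GPU list convention and the ghost-only mode can be illustrated with a small sketch. The function multi_gpu_seg_loss is hypothetical, not part of mlreco. It assumes each label tensor carries the class in its last column (shape (N, 1) as stated above), sums per-point losses across GPUs to mirror reduction='sum', and treats ghost-only mode as a two-class cross-entropy whose target is 1 where the label equals ghost_label.

```python
import torch
import torch.nn.functional as F


def multi_gpu_seg_loss(segmentation, label, ghost_label=-1):
    """Hypothetical sketch: sum a cross-entropy loss over per-GPU lists.

    segmentation: list of (N_i, C) score tensors, one entry per GPU
    label:        list of (N_i, 1) label tensors, one entry per GPU
    """
    total = torch.zeros(())
    for scores, lab in zip(segmentation, label):
        target = lab[:, -1].long()  # (N_i, 1) -> (N_i,)
        if ghost_label > -1:
            # Ghost-only mode: binary target, 1 where the point carries ghost_label.
            target = (target == ghost_label).long()
        total = total + F.cross_entropy(scores, target, reduction="sum")
    return total


torch.manual_seed(0)
# Two "GPUs" holding 4 and 2 points with 3 semantic classes.
segmentation = [torch.randn(4, 3), torch.randn(2, 3)]
label = [torch.tensor([[0.0], [1.0], [2.0], [1.0]]), torch.tensor([[2.0], [0.0]])]
semantic = multi_gpu_seg_loss(segmentation, label)

# Ghost-only mode: 2-way scores, label 5 marks ghost points.
ghost_scores = [torch.randn(3, 2)]
ghost_labels = [torch.tensor([[5.0], [0.0], [5.0]])]
ghost_only = multi_gpu_seg_loss(ghost_scores, ghost_labels, ghost_label=5)
```

Because the loss is summed, splitting the points across GPUs gives the same total as computing it on the concatenated batch.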
-
reduction: str