import numpy as np
import torch
import torch.nn as nn
import MinkowskiEngine as ME
from mlreco.models.layers.common.uresnet_layers import UResNet
from collections import defaultdict
from mlreco.models.layers.common.activation_normalization_factories import activations_construct, normalizations_construct
class UResNet_Chain(nn.Module):
    """
    UResNet backbone followed by per-voxel linear classification heads.

    Typical configuration should look like:

    .. code-block:: yaml

        model:
          name: uresnet
          modules:
            uresnet_lonely:
              # Your config here

    Configuration
    -------------
    data_dim: int, default 3
    num_input: int, default 1
    allow_bias: bool, default False
    spatial_size: int, default 512
    leakiness: float, default 0.33
    activation: dict
        For activation function, defaults to `{'name': 'lrelu', 'args': {}}`
    norm_layer: dict
        For normalization function, defaults to `{'name': 'batch_norm', 'args': {}}`
    depth : int, default 5
        Depth of UResNet, also corresponds to how many times we down/upsample.
    filters : int, default 16
        Number of filters in the first convolution of UResNet.
        Will increase linearly with depth.
    reps : int, default 2
        Convolution block repetition factor
    input_kernel : int, default 3
        Receptive field size for very first convolution after input layer.
    num_classes: int, default 5
        Number of outputs of the semantic segmentation head.
    ghost: bool, default False
        If True, an additional 2-output ghost/non-ghost head is created.
    ghost_label: int, default -1
    weight_loss: bool, default False
        Whether to weight the loss using class counts.
    alpha: float, default 1.0
        Weight for UResNet semantic segmentation loss.
    beta: float, default 1.0
        Weight for ghost/non-ghost segmentation loss.

    Output
    ------
    A dictionary of lists, with one entry appended per element of the
    input passed to `forward`:

    segmentation: list of torch.Tensor
        Feature matrix (`.F`) of the segmentation head output.
    finalTensor: list
        `finalTensor` entry returned by the backbone.
    encoderTensors: list
        `encoderTensors` entry returned by the backbone.
    decoderTensors: list
        `decoderTensors` entry returned by the backbone.
    ghost: list of torch.Tensor
        Feature matrix (`.F`) of the ghost head output (only if `ghost`).
    ghost_sptensor: list
        Raw output of the ghost head (only if `ghost`).

    See Also
    --------
    SegmentationLoss, mlreco.models.layers.common.uresnet_layers
    """

    INPUT_SCHEMA = [
        ["parse_sparse3d_scn", (float,), (3, 1)]
    ]

    MODULES = ['uresnet_lonely']

    def __init__(self, cfg, name='uresnet_lonely'):
        """
        Build the backbone and the classification heads.

        Parameters
        ----------
        cfg : dict
            Model configuration; the block stored under `name` is read for
            `num_classes`, `ghost` and `ghost_label`. The full `cfg` is
            also forwarded to the `UResNet` backbone.
        name : str, default 'uresnet_lonely'
            Key of the configuration block to use.
        """
        super(UResNet_Chain, self).__init__()
        self.model_config = cfg.get(name, {})
        self.num_classes = self.model_config.get('num_classes', 5)

        # Parameters for Deghosting
        self.ghost = self.model_config.get('ghost', False)
        self.ghost_label = self.model_config.get('ghost_label', -1)

        # Backbone; its filter count fixes the input width of the heads
        self.net = UResNet(cfg, name=name)
        self.F = self.net.num_filters
        self.D = self.net.D

        # Normalization + activation applied to the last decoder tensor.
        # NOTE: the activation slope is hard-coded to 0.33 here instead of
        # being taken from `self.net.activation_args`.
        self.output = nn.Sequential(
            normalizations_construct(
                self.net.norm, self.F, **self.net.norm_args),
            activations_construct(
                self.net.activation_name, negative_slope=0.33),
        )

        self.linear_segmentation = ME.MinkowskiLinear(
            self.F, self.num_classes)

        if self.ghost:
            print("Ghost Masking is enabled for UResNet Segmentation")
            self.linear_ghost = ME.MinkowskiLinear(self.F, 2)

    def forward(self, input):
        """
        Run the backbone and the heads on every element of `input`.

        Parameters
        ----------
        input : iterable
            Each element is passed as-is to the `UResNet` backbone
            (presumably one sparse input per GPU -- confirm with caller).

        Returns
        -------
        collections.defaultdict
            Dictionary of lists keyed by `segmentation`, `finalTensor`,
            `encoderTensors`, `decoderTensors` and, when ghost masking is
            enabled, `ghost` and `ghost_sptensor`.
        """
        out = defaultdict(list)
        for x in input:
            res = self.net(x)

            # Heads operate on the last (finest) decoder output
            feats = self.output(res['decoderTensors'][-1])
            seg = self.linear_segmentation(feats)

            out['segmentation'].append(seg.F)
            out['finalTensor'].append(res['finalTensor'])
            out['encoderTensors'].append(res['encoderTensors'])
            out['decoderTensors'].append(res['decoderTensors'])

            if self.ghost:
                ghost = self.linear_ghost(feats)
                out['ghost'].append(ghost.F)
                out['ghost_sptensor'].append(ghost)
        return out
class SegmentationLoss(torch.nn.modules.loss._Loss):
    """
    Loss definition for UResNet.

    For a regular flavor UResNet, it is a cross-entropy loss.
    For deghosting, it depends on a configuration parameter `ghost`:

    - If `ghost=True`, we first compute the cross-entropy loss on the ghost
      point classification (weighted on the fly with sample statistics).
      Then we compute a mask = all non-ghost points (based on true
      information in label) and within this mask, compute a cross-entropy
      loss for the rest of classes.

    - If `ghost=False`, we compute a N+1-classes cross-entropy loss, where
      N is the number of classes, not counting the ghost point class.

    Configuration (read from the `uresnet_lonely` block)
    ----------------------------------------------------
    ghost: bool, default False
    ghost_label: int, default -1
    num_classes: int, default 5
    alpha: float, default 1.0
        Weight for the semantic segmentation loss (ghost mode only).
    beta: float, default 1.0
        Weight for the ghost/non-ghost loss (ghost mode only).
    weight_loss: bool, default False
        Whether to weight the semantic loss using per-event class counts.

    See Also
    --------
    UResNet_Chain
    """

    INPUT_SCHEMA = [
        ["parse_sparse3d_scn", (int,), (3, 1)]
    ]

    def __init__(self, cfg, reduction='sum', batch_col=0):
        """
        Parameters
        ----------
        cfg : dict
            Configuration; only the `uresnet_lonely` block is read.
        reduction : str, default 'sum'
            Forwarded to the `_Loss` base class constructor.
        batch_col : int, default 0
            Column of the label tensors that holds the batch (event) id.
        """
        super(SegmentationLoss, self).__init__(reduction=reduction)
        self._cfg = cfg.get('uresnet_lonely', {})
        self._ghost = self._cfg.get('ghost', False)
        self._ghost_label = self._cfg.get('ghost_label', -1)
        self._num_classes = self._cfg.get('num_classes', 5)
        self._alpha = self._cfg.get('alpha', 1.0)
        self._beta = self._cfg.get('beta', 1.0)
        self._weight_loss = self._cfg.get('weight_loss', False)
        self.cross_entropy = torch.nn.CrossEntropyLoss(reduction='none')
        self._batch_col = batch_col

    @staticmethod
    def _warn_invalid_labels(event_label, invalid):
        """Print a diagnostic if any entry of the boolean `invalid` is set."""
        if invalid.any():
            unique_label, unique_count = torch.unique(
                event_label, return_counts=True)
            print('Invalid semantic label found (will be ignored)')
            print('Semantic label values:', unique_label)
            print('Label counts:', unique_count)

    def _ghost_mask_terms(self, event_ghost, event_label):
        """
        Ghost/non-ghost loss and accuracies for a single event.

        Returns
        -------
        tuple
            `(loss, accuracy, ghost2ghost, nonghost2nonghost)`; the loss is
            a scalar tensor, the three accuracies are Python floats.
        """
        # 0 = not a ghost point, 1 = ghost point
        mask_label = (event_label == self._num_classes).long()
        num_points = mask_label.nelement()
        num_ghost = int(mask_label.sum().item())
        num_nonghost = num_points - num_ghost

        # Weight each class by the frequency of the *other* class. When one
        # of the two classes is absent this would give weight 0 to every
        # point and the weighted mean would be 0/0 = NaN, so fall back to
        # an unweighted cross-entropy in that case.
        weight = None
        if num_ghost > 0 and num_nonghost > 0:
            fraction = num_ghost / float(num_points)
            weight = torch.tensor([fraction, 1. - fraction],
                                  dtype=torch.float,
                                  device=event_ghost.device)
        loss = torch.nn.functional.cross_entropy(
            event_ghost, mask_label, weight=weight)

        ghost2ghost, nonghost2nonghost = 0., 0.
        with torch.no_grad():
            predicted_mask = torch.argmax(event_ghost, dim=-1)
            # Fraction of true ghost points predicted as ghost
            if num_ghost > 0:
                ghost2ghost = \
                    (predicted_mask[mask_label == 1] == 1).sum().item() \
                    / float(num_ghost)
            # Fraction of true non-ghost points predicted as non-ghost
            if num_nonghost > 0:
                valid = event_label < self._num_classes
                nonghost2nonghost = \
                    (predicted_mask[valid] == 0).sum().item() \
                    / float(num_nonghost)
            # Global ghost prediction accuracy
            accuracy = (predicted_mask == mask_label).sum().item() \
                / float(num_points)
        return loss, accuracy, ghost2ghost, nonghost2nonghost

    def _semantic_loss(self, event_segmentation, event_label, event_weights):
        """
        Semantic cross-entropy for a single, non-empty event.

        `event_weights` is either None or a float tensor of per-voxel
        weights aligned with `event_label`.
        """
        if self._weight_loss:
            # Inverse-frequency class weights; absent classes get weight 0
            counts = torch.stack(
                [(event_label == c).sum()
                 for c in range(self._num_classes)]).float()
            class_w = torch.where(
                counts > 0,
                float(len(event_label)) / counts.clamp(min=1.),
                torch.zeros_like(counts))
            if event_weights is None:
                return torch.nn.functional.cross_entropy(
                    event_segmentation, event_label, weight=class_w)
            # Combine class weights with per-voxel weights and normalize by
            # the total weight, mirroring the weighted-mean reduction.
            per_voxel = torch.nn.functional.cross_entropy(
                event_segmentation, event_label,
                weight=class_w, reduction='none')
            total_w = (class_w[event_label] * event_weights).sum()
            return (per_voxel * event_weights).sum() / total_w

        per_voxel = self.cross_entropy(event_segmentation, event_label)
        if event_weights is None:
            return torch.mean(per_voxel)
        return torch.sum(per_voxel * event_weights) / torch.sum(event_weights)

    def forward(self, result, label, weights=None):
        """
        Compute the segmentation loss and accuracies.

        result['segmentation'], label and weights are lists with one entry
        per GPU. The batch id is read from column `batch_col` of each label
        tensor and the semantic label (and per-voxel weight) from the last
        column.

        Assumptions
        ===========
        The ghost label is the last one among the classes numbering.
        If ghost = True, then num_classes should not count the ghost class.
        If ghost_label > -1, then we perform only ghost segmentation.

        Returns
        -------
        dict
            Always contains `accuracy`, `loss` and `accuracy_class_<c>` for
            every class. In ghost mode it also contains
            `ghost_mask_accuracy`, `ghost_mask_loss`, `uresnet_accuracy`,
            `uresnet_loss`, `ghost2ghost` and `nonghost2nonghost`. All
            values are averaged over events.
        """
        assert len(result['segmentation']) == len(label)

        uresnet_loss, uresnet_acc = 0., 0.
        uresnet_acc_class = [0.] * self._num_classes
        count_class = [0.] * self._num_classes
        mask_loss, mask_acc = 0., 0.
        ghost2ghost, nonghost2nonghost = 0., 0.
        count = 0

        for i, label_i in enumerate(label):
            batch_ids = label_i[:, self._batch_col]
            for b in batch_ids.unique():
                batch_index = batch_ids == b
                event_segmentation = result['segmentation'][i][batch_index]
                event_label = label_i[batch_index][:, -1].long()  # (N,)
                event_weights = None
                if weights is not None:
                    event_weights = weights[i][batch_index][:, -1].float()

                if self._ghost_label > -1:
                    # Ghost-only segmentation: binary target
                    event_label = (event_label == self._ghost_label).long()
                else:
                    if self._ghost:
                        # Label == num_classes is the (valid) ghost class
                        self._warn_invalid_labels(
                            event_label, event_label > self._num_classes)
                        terms = self._ghost_mask_terms(
                            result['ghost'][i][batch_index], event_label)
                        mask_loss += terms[0]
                        mask_acc += terms[1]
                        ghost2ghost += terms[2]
                        nonghost2nonghost += terms[3]
                    else:
                        self._warn_invalid_labels(
                            event_label, event_label >= self._num_classes)

                    # Keep only voxels with a semantic label; the per-voxel
                    # weights must be filtered with the same mask to stay
                    # aligned with the labels.
                    mask = event_label < self._num_classes
                    event_segmentation = event_segmentation[mask]
                    event_label = event_label[mask]
                    if event_weights is not None:
                        event_weights = event_weights[mask]

                if event_label.shape[0] > 0:
                    uresnet_loss += self._semantic_loss(
                        event_segmentation, event_label, event_weights)

                    with torch.no_grad():
                        predicted = torch.argmax(event_segmentation, dim=-1)
                        correct = predicted == event_label
                        uresnet_acc += correct.sum().item() \
                            / float(correct.nelement())

                        # Class accuracy
                        for c in range(self._num_classes):
                            class_mask = event_label == c
                            class_count = class_mask.sum().item()
                            if class_count > 0:
                                uresnet_acc_class[c] += \
                                    correct[class_mask].sum().item() \
                                    / float(class_count)
                                count_class[c] += 1

                # Counted per event (even with an empty mask) because the
                # ghost terms above are accumulated for every event.
                count += 1

        # Without any event, losses are returned unnormalized and
        # accuracies default to 1.
        norm = count if count else 1
        accuracy = uresnet_acc / count if count else 1.

        if self._ghost:
            results = {
                'accuracy': accuracy,
                'loss': (self._alpha * uresnet_loss
                         + self._beta * mask_loss) / norm,
                'ghost_mask_accuracy': mask_acc / count if count else 1.,
                'ghost_mask_loss': self._beta * mask_loss / norm,
                'uresnet_accuracy': accuracy,
                'uresnet_loss': self._alpha * uresnet_loss / norm,
                'ghost2ghost': ghost2ghost / count if count else 1.,
                'nonghost2nonghost':
                    nonghost2nonghost / count if count else 1.,
            }
        else:
            results = {
                'accuracy': accuracy,
                'loss': uresnet_loss / norm,
            }

        for c in range(self._num_classes):
            if count_class[c] > 0:
                results['accuracy_class_%d' % c] = \
                    uresnet_acc_class[c] / count_class[c]
            else:
                results['accuracy_class_%d' % c] = 1.
        return results