import torch
import torch.nn as nn
import MinkowskiEngine as ME
import torch.nn.functional as F
from collections import defaultdict
from mlreco.models.layers.common.activation_normalization_factories import (activations_dict,
activations_construct,
normalizations_construct)
from mlreco.models.layers.common.configuration import setup_cnn_configuration
from mlreco.models.layers.common.uresnet_layers import UResNet
from mlreco.models.experimental.bayes.encoder import MCDropoutEncoder
from mlreco.models.experimental.bayes.decoder import MCDropoutDecoder
from mlreco.models.experimental.bayes.evidential import EVDLoss
class BayesianUResNet(torch.nn.Module):
    """
    UResNet with Uncertainty Quantification

    The backbone model consists of UResNet Encoder-Decoder format with
    standard residual layers for the shallow half and dropout residual layers
    for the deep half of the network.

    Configuration
    -------------
    mode: str
        String indicator for slight changes in network behavior/architecture.
        Supports three options (any other value falls back to ``standard``):

        - standard: standard dropout segmentation network. Training behavior
          is identical for standard and MC dropout networks, so this mode is
          also the one used to train an MC dropout segnet.
        - mc_dropout: MC dropout inference. Dropout layers are kept active and
          ``num_samples`` stochastic forward passes are averaged.
        - evidential (alias: ``evd``): evidential segmentation network.
    num_samples: int
        If used as MC dropout segnet, the number of stochastic forward
        samples to be taken (default: 20).
    num_classes: int
        Number of segmentation classes (default: 5).
    loss_fn: str
        If the name contains ``'edl'``, the linear classifier is followed by a
        Softplus so that its output (the evidence) is non-negative.
    """
    MODULES = []

    def __init__(self, cfg, name='mcdropout_uresnet'):
        super(BayesianUResNet, self).__init__()
        setup_cnn_configuration(self, cfg, 'network_base')
        self.model_config = cfg.get(name, {})
        self.num_classes = self.model_config.get('num_classes', 5)
        self.num_samples = self.model_config.get('num_samples', 20)
        self.encoder = MCDropoutEncoder(cfg)
        self.decoder = MCDropoutDecoder(cfg)
        self.mode = self.model_config.get('mode', 'standard')
        if 'edl' in self.model_config.get('loss_fn', 'cross_entropy'):
            # Evidential losses expect non-negative evidence.
            self.classifier = nn.Sequential(
                ME.MinkowskiLinear(self.encoder.num_filters, self.num_classes),
                ME.MinkowskiSoftplus()
            )
        else:
            self.classifier = ME.MinkowskiLinear(self.encoder.num_filters,
                                                 self.num_classes)

    @staticmethod
    def _to_sparse(x):
        """Wrap a dense point tensor as an ``ME.SparseTensor``.

        Columns 0-3 become the integer coordinates and the last column the
        single input feature (presumably batch id + 3D voxel coordinates and
        the voxel value -- TODO confirm against the data parser).
        """
        return ME.SparseTensor(coordinates=x[:, :4].int(),
                               features=x[:, -1].view(-1, 1).float())

    def _segment(self, x_sparse):
        """Run encoder, decoder and classifier; return the classifier output."""
        res_encoder = self.encoder.encoder(x_sparse)
        decoder_tensors = self.decoder(res_encoder['finalTensor'],
                                       res_encoder['encoderTensors'])
        return self.classifier(decoder_tensors[-1])

    def mc_forward(self, input, num_samples=None):
        """
        Forwarding operation for MC Dropout segmentation network.

        Dropout layers are switched to train mode for the duration of the
        call (even if the model is in eval mode) and restored afterwards.

        Args:
            input: list of dense point tensors, one per GPU.
            num_samples: number of stochastic forward samples to be taken
                (defaults to ``self.num_samples``).

        Returns:
            defaultdict(list) with, per GPU, 'softmax' (sample-averaged
            softmax probabilities) and 'segmentation' (sample-averaged logits).
        """
        if num_samples is None:
            num_samples = self.num_samples

        # Keep dropout stochastic at inference time. Only modules whose class
        # is exactly named 'Dropout' are toggled (presumably this also catches
        # the torch Dropout wrapped inside sparse dropout layers -- confirm).
        dropout_layers = [m for m in self.modules()
                          if m.__class__.__name__ == 'Dropout']
        previous_modes = [m.training for m in dropout_layers]
        for m in dropout_layers:
            m.train()

        res = defaultdict(list)
        try:
            for x in input:
                x_sparse = self._to_sparse(x)
                pvec = torch.zeros((x.shape[0], self.num_classes),
                                   device=x.device)
                logits = torch.zeros_like(pvec)
                for _ in range(num_samples):
                    out = self._segment(x_sparse)
                    logits += out.F
                    pvec += F.softmax(out.F, dim=1)
                res['softmax'].append(pvec / num_samples)
                res['segmentation'].append(logits / num_samples)
        finally:
            # Do not leak the forced train mode into later calls.
            for m, was_training in zip(dropout_layers, previous_modes):
                m.train(was_training)
        return res

    def evidential_forward(self, input):
        """
        Forwarding operation for evidential segmentation network.

        Returns a defaultdict(list) with, per GPU: 'segmentation' and
        'evidence' (raw classifier output), 'concentration' (evidence + 1),
        'expected_probability' (concentration / its row sum) and
        'uncertainty' (num_classes / row sum).
        """
        out = defaultdict(list)
        for x in input:
            # For evidential models, logits correspond to collected evidence.
            ev = self._segment(self._to_sparse(x)).F
            concentration = ev + 1.0
            S = torch.sum(concentration, dim=1, keepdim=True)
            # Small epsilon guards the division.
            uncertainty = self.num_classes / (S + 0.000001)
            out['segmentation'].append(ev)
            out['evidence'].append(ev)
            out['uncertainty'].append(uncertainty)
            out['concentration'].append(concentration)
            out['expected_probability'].append(concentration / S)
        return out

    def standard_forward(self, input):
        """
        Forwarding operation for standard dropout segmentation network.

        Returns a defaultdict(list) with the per-GPU classifier logits under
        'segmentation'.
        """
        out = defaultdict(list)
        for x in input:
            logits = self._segment(self._to_sparse(x))
            out['segmentation'].append(logits.F)
        return out

    def forward(self, input):
        """
        Dispatch to the forward pass selected by ``self.mode``:
        'mc_dropout' -> mc_forward, 'evidential'/'evd' -> evidential_forward,
        anything else -> standard_forward.
        """
        if self.mode == 'mc_dropout':
            return self.mc_forward(input)
        if self.mode in ('evidential', 'evd'):
            return self.evidential_forward(input)
        return self.standard_forward(input)
class DUQUResNet(torch.nn.Module):
    """
    Single Pass Deep Uncertainty Quantification Network

    Original Paper: https://arxiv.org/abs/2003.02037
    Implementation adapted from the DUQ main Github Repository:
    https://github.com/y0ast/deterministic-uncertainty-quantification
    Author: Joost van Amersfoort

    Each row of the UResNet decoder's final feature tensor is projected to
    one ``embedding_dim``-dimensional embedding per class. Each embedding is
    scored with an RBF kernel against a class centroid ``m / N``; ``m`` and
    ``N`` are buffers updated as exponential moving averages by
    ``update_buffers``.

    Configuration (read from ``cfg[name]``): num_classes (5), num_samples (20),
    gamma (0.999, moving-average decay), sigma (0.3, RBF length scale),
    embedding_dim (10), latent_size (32, must match the decoder feature size).
    """
    MODULES = []

    def __init__(self, cfg, name='duq_uresnet'):
        super(DUQUResNet, self).__init__()
        setup_cnn_configuration(self, cfg, name)
        self.model_config = cfg.get(name, {})
        self.num_classes = self.model_config.get('num_classes', 5)
        self.num_samples = self.model_config.get('num_samples', 20)
        self.net = UResNet(cfg)
        self.gamma = self.model_config.get('gamma', 0.999)
        self.sigma = self.model_config.get('sigma', 0.3)
        self.embedding_dim = self.model_config.get('embedding_dim', 10)
        self.latent_size = self.model_config.get('latent_size', 32)
        # Per-class projection: (embedding_dim, num_classes, latent_size).
        self.w = nn.Parameter(torch.zeros(self.embedding_dim,
                                          self.num_classes,
                                          self.latent_size))
        nn.init.kaiming_normal_(self.w, nonlinearity='relu')
        # Moving-average class counts and embedding sums; centroids = m / N.
        self.register_buffer('N', torch.ones(self.num_classes) * 20)
        self.register_buffer('m', torch.normal(
            torch.zeros(self.embedding_dim, self.num_classes), 0.05))
        self.m = self.m * self.N.unsqueeze(0)

    def embed(self, x):
        """Project decoder features to per-class embeddings.

        Returns a tensor of shape (num_rows, embedding_dim, num_classes),
        from the einsum 'ij,mnj->imn' of the decoder features with ``w``.
        """
        res = self.net(x)
        feats = res['decoderTensors'][-1]
        return torch.einsum('ij,mnj->imn', feats.F, self.w)

    def bilinear(self, z):
        """RBF score of each embedding against its class centroid.

        Computes exp(-mean_over_embedding_dim((z - m / N)^2) / (2 sigma^2)),
        giving a (num_rows, num_classes) tensor with values in (0, 1].
        """
        embeddings = self.m / self.N.unsqueeze(0)
        diff = z - embeddings.unsqueeze(0)
        return (- diff**2).mean(1).div(2 * self.sigma**2).exp()

    def forward(self, input):
        """Score the (single) input point cloud against every class.

        Only a single-element ``input`` list (one GPU) is supported. In
        training mode the point cloud is marked as requiring grad, presumably
        so a gradient penalty can be taken with respect to it by the loss.

        Returns a dict of one-element lists: 'score' (RBF scores),
        'embedding', 'input' (the point cloud) and 'centroids' (numpy m / N).
        """
        point_cloud, = input
        if self.training:
            point_cloud.requires_grad_(True)
        z = self.embed(point_cloud)
        y_pred = self.bilinear(z)
        res = {
            'score': [y_pred],
            'embedding': [z],
            'input': [point_cloud],
            'centroids': [self.m.detach().cpu().numpy()
                          / self.N.detach().cpu().numpy()],
        }
        # Cached for update_buffers().
        self.z = z
        self.y_pred = y_pred
        return res

    def update_buffers(self, y=None):
        """Update the centroid moving averages from the last forward pass.

        Parameters
        ----------
        y: torch.Tensor, optional
            Ground-truth labels for the rows scored in the last forward pass:
            either one-hot of shape (num_rows, num_classes) or integer class
            ids of shape (num_rows,). The DUQ reference implementation updates
            the centroids with the true labels. If omitted, the cached soft
            predictions ``self.y_pred`` are used instead (previous behavior).
        """
        with torch.no_grad():
            if y is None:
                weights = self.y_pred
            else:
                if y.dim() == 1:
                    y = torch.nn.functional.one_hot(
                        y.long(), num_classes=self.N.shape[0])
                weights = y.to(device=self.z.device, dtype=self.z.dtype)
            # Per-class normalizer, clamped at 1 to keep m / N well defined.
            self.N = torch.max(self.gamma * self.N + (1 - self.gamma)
                               * weights.sum(0), torch.ones_like(self.N))
            # Sum of embeddings on a class-by-class basis.
            features_sum = torch.einsum('ijk,ik->jk', self.z, weights)
            self.m = self.gamma * self.m + (1 - self.gamma) * features_sum
class SegmentationLoss(nn.Module):
    """
    Semantic segmentation loss for BayesianUResNet outputs.

    Configuration (read from ``cfg[name]``)
    ---------------------------------------
    loss_fn: str
        ``'cross_entropy'`` or any name containing ``'edl'`` (evidential loss,
        handled by ``EVDLoss``). Default: ``'edl_sumsq'``.
    loss_fn_args: dict
        Extra keyword arguments forwarded to ``EVDLoss``.
    one_hot: bool
        If True, labels are one-hot encoded before calling the loss and the
        iteration number is forwarded to it as ``t``.
    num_classes: int
        Number of classes used for one-hot encoding (default: 5).
    """

    def __init__(self, cfg, name='mcdropout_uresnet'):
        super(SegmentationLoss, self).__init__()
        self.loss_config = cfg.get(name, {})
        self.loss_fn_name = self.loss_config.get('loss_fn', 'edl_sumsq')
        self.loss_fn_args = self.loss_config.get('loss_fn_args', {})
        if 'edl' in self.loss_fn_name:
            self.loss_fn = EVDLoss(self.loss_fn_name, **self.loss_fn_args)
        elif self.loss_fn_name == 'cross_entropy':
            self.loss_fn = self._per_voxel_cross_entropy
        else:
            raise ValueError('Loss function {} not recognized'.format(self.loss_fn_name))
        self.one_hot = self.loss_config.get('one_hot', False)
        self.num_classes = self.loss_config.get('num_classes', 5)

    @staticmethod
    def _per_voxel_cross_entropy(logits, target, t=None):
        """Unreduced cross-entropy, so optional weights apply per voxel.

        ``t`` is accepted (and ignored) because ``forward`` passes the
        iteration number as ``t`` when ``one_hot`` is set.
        """
        return torch.nn.functional.cross_entropy(logits, target, reduction='none')

    def forward(self, outputs, label, iteration=0, weight=None):
        '''
        Compute the event-averaged segmentation loss and accuracy.

        Parameters
        ----------
        outputs: dict
            Must contain 'segmentation': a list (one entry per GPU) of
            per-voxel scores of shape (N, num_classes) -- logits, or evidence
            when an 'edl' loss is configured.
        label: list of torch.Tensor
            One tensor per GPU, of shape (N, dim + batch_id + 1), where N is
            #pts across minibatch_size events. Column 0 is the batch id and
            the last column the class label.
        iteration: int
            Forwarded to the loss as ``t`` when ``one_hot`` is set.
        weight: list of torch.Tensor, optional
            Per-voxel weights, one tensor per GPU.

        Returns
        -------
        dict with 'accuracy' (float) and 'loss' (tensor), both averaged over
        all events of all GPUs.
        '''
        logits = outputs['segmentation']
        if 'edl' in self.loss_fn_name:
            # Convert evidence to alpha concentration params, for every GPU.
            segmentation = [evidence + 1.0 for evidence in logits]
        else:
            segmentation = logits
        assert len(segmentation) == len(label)

        total_loss = 0
        total_acc = 0
        count = 0
        # Loop over GPUs, then over events (batch ids) within each GPU.
        for i, (gpu_segmentation, gpu_label) in enumerate(zip(segmentation, label)):
            device = gpu_segmentation.device
            batch_ids = gpu_label[:, 0]
            for b in batch_ids.unique():
                batch_index = batch_ids == b
                event_segmentation = gpu_segmentation[batch_index]
                # Already 1-D: squeezing would turn a single-voxel event into
                # a 0-d tensor and break the loss call.
                event_label = gpu_label[:, -1][batch_index].long()
                if self.one_hot:
                    loss_label = torch.eye(self.num_classes, device=device)[event_label]
                    loss_seg = self.loss_fn(event_segmentation, loss_label,
                                            t=iteration)
                else:
                    loss_seg = self.loss_fn(event_segmentation, event_label)
                if weight is not None:
                    event_weight = torch.squeeze(weight[i][batch_index], dim=-1).float()
                    total_loss += torch.mean(loss_seg * event_weight)
                else:
                    total_loss += torch.mean(loss_seg)
                # Accuracy
                predicted_labels = torch.argmax(event_segmentation, dim=-1)
                acc = (predicted_labels == event_label).sum().item() \
                    / float(predicted_labels.nelement())
                total_acc += acc
                count += 1
        return {
            'accuracy': total_acc / count,
            'loss': total_loss / count,
        }
class DUQSegmentationLoss(nn.Module):
    """
    Loss for DUQUResNet: per-class binary cross-entropy on the RBF scores,
    plus an optional gradient penalty with respect to the network input.

    Configuration (read from ``cfg[name]``): num_classes (5), grad_w (0.0,
    weight of the gradient penalty in the total loss), grad_penalty (True,
    whether the penalty is computed at all).
    """

    def __init__(self, cfg, name='duq_uresnet'):
        super(DUQSegmentationLoss, self).__init__()
        loss_config = cfg.get(name, {})
        self.xentropy = nn.BCELoss(reduction='none')
        self.num_classes = loss_config.get('num_classes', 5)
        self.grad_w = loss_config.get('grad_w', 0.0)
        self.grad_penalty = loss_config.get('grad_penalty', True)

    @staticmethod
    def calc_gradient_penalty(x, y_pred):
        '''
        Two-sided gradient penalty: mean over rows of (||d y_pred / d x||_2 - 1)^2.

        Code From the DUQ main Github Repository:
        https://github.com/y0ast/deterministic-uncertainty-quantification
        Author: Joost van Amersfoort
        '''
        gradients = torch.autograd.grad(
            outputs=y_pred,
            inputs=x,
            grad_outputs=torch.ones_like(y_pred),
            create_graph=True,  # the penalty itself must be differentiable
        )[0]
        gradients = gradients.flatten(start_dim=1)
        # L2 norm
        grad_norm = gradients.norm(2, dim=1)
        # Two sided penalty
        gradient_penalty = ((grad_norm - 1) ** 2).mean()
        # One sided penalty - down
        # gradient_penalty = F.relu(grad_norm - 1).mean()
        return gradient_penalty

    def forward(self, out, type_labels):
        """
        Compute the DUQ loss for a single GPU's output.

        Parameters
        ----------
        out: dict
            DUQUResNet output; uses out['score'][0] (scores in [0, 1]) and,
            when the gradient penalty is enabled, out['input'][0].
        type_labels: list of torch.Tensor
            The class label is read from the last column of type_labels[0].

        Returns
        -------
        dict with 'loss' (tensor), 'loss_embedding', 'loss_grad_penalty',
        'accuracy', and one 'accuracy_<c>' entry per class present.
        """
        probas = out['score'][0]
        device = probas.device
        labels = type_labels[0][:, -1].long()
        # Build the lookup table on the same device as the labels indexing it.
        labels_one_hot = torch.eye(self.num_classes, device=device)[labels.to(device)]
        loss1 = self.xentropy(probas, labels_one_hot).sum(dim=1).mean()
        pred = torch.argmax(probas, dim=1)
        # Compute gradient penalty
        loss2 = 0
        if self.grad_penalty:
            loss2 = self.calc_gradient_penalty(out['input'][0], probas)
        loss = loss1 + self.grad_w * loss2
        accuracy = float(torch.sum(pred == labels)) / float(labels.shape[0])
        res = {
            'loss': loss,
            'loss_embedding': float(loss1),
            'loss_grad_penalty': float(loss2),
            'accuracy': accuracy,
        }
        # Per-class accuracy for every class present in this batch.
        for c in labels.unique():
            mask = labels == c
            res['accuracy_{}'.format(int(c))] = \
                float(torch.sum(pred[mask] == labels[mask])) / float(torch.sum(mask))
        return res