import torch
import torch.nn as nn
from .misc import *
from collections import defaultdict
from pprint import pprint
[docs]class SPICELoss(nn.Module):
'''
Loss function for Sparse Spatial Embeddings Model, with fixed
centroids and symmetric gaussian kernels.
'''
[docs] def __init__(self, cfg, name='spice_loss'):
super(SPICELoss, self).__init__()
self.loss_config = cfg[name]
self.seediness_weight = self.loss_config.get('seediness_weight', 1.0)
self.embedding_weight = self.loss_config.get('embedding_weight', 1.0)
self.smoothing_weight = self.loss_config.get('smoothing_weight', 1.0)
self.spatial_size = self.loss_config.get('spatial_size', 768)
self.mask_loss_fn = self.loss_config.get('mask_loss_fn', 'BCE')
self.seed_loss_fn = self.loss_config.get('seed_loss_fn', 'L1')
self.batch_loc = self.loss_config.get('batch_loc', 0)
if self.mask_loss_fn == 'BCE':
self.mask_loss = nn.BCEWithLogitsLoss(reduction='none')
elif self.mask_loss_fn == 'lovasz_hinge':
self.mask_loss = LovaszHingeLoss()
elif self.mask_loss_fn == 'focal':
raise NotImplementedError
else:
raise ValueError(
'Invalid loss scheme: {}'.format(self.loss_scheme))
# L2 Loss for Seediness
if self.seed_loss_fn == 'L1':
self.seed_loss = torch.nn.L1Loss(reduction='mean')
elif self.seed_loss_fn == 'L1':
self.seed_loss = torch.nn.MSELoss(reduction='mean')
elif self.seed_loss_fn == 'huber':
self.seed_loss = torch.nn.SmoothL1Loss(reduction='mean')
else:
raise ValueError(
'Invalid loss scheme: {}'.format(self.loss_scheme))
[docs] def find_cluster_means(self, features, labels):
'''
For a given image, compute the centroids mu_c for each
cluster label in the embedding space.
Inputs:
features (torch.Tensor): the pixel embeddings, shape=(N, d) where
N is the number of pixels and d is the embedding space dimension.
labels (torch.Tensor): ground-truth group labels, shape=(N, )
Returns:
cluster_means (torch.Tensor): (n_c, d) tensor where n_c is the number of
distinct instances. Each row is a (1,d) vector corresponding to
the coordinates of the i-th centroid.
'''
cluster_means = find_cluster_means(features, labels)
return cluster_means
[docs] def get_per_class_probabilities(self, embeddings, margins, labels, eps=1e-6):
'''
Computes binary foreground/background loss.
'''
device = embeddings.device
n = labels.shape[0]
centroids = self.find_cluster_means(embeddings, labels)
sigma = scatter_mean(margins.squeeze(), labels)
num_clusters = labels.unique().shape[0]
# Compute spatial term
em = embeddings[:, None, :]
centroids = centroids[None, :, :]
sqdists = ((em - centroids)**2).sum(-1)
p = sqdists / (2.0 * sigma.view(1, -1)**2)
p = torch.clamp(torch.exp(-p), min=eps, max=1-eps)
logits = logit_fn(p, eps=eps)
eye = torch.eye(len(labels.unique()), dtype=torch.float32, device=device)
targets = eye[labels]
loss_tensor = self.mask_loss(logits, targets)
loss = loss_tensor.mean(dim=0).mean()
with torch.no_grad():
acc = iou_batch(logits > 0, targets.bool())
smoothing_loss = margin_smoothing_loss(margins.squeeze(), sigma.detach(), labels, margin=0)
p = torch.gather(p, 1, labels.view(-1, 1))
return loss, smoothing_loss, p.squeeze(), acc
[docs] def combine_multiclass(self, embeddings, margins, seediness, slabels, clabels, coords):
'''
Wrapper function for combining different components of the loss,
in particular when clustering must be done PER SEMANTIC CLASS.
NOTE: When there are multiple semantic classes, we compute the DLoss
by first masking out by each semantic segmentation (ground-truth/prediction)
and then compute the clustering loss over each masked point cloud.
INPUTS:
features (torch.Tensor): pixel embeddings
slabels (torch.Tensor): semantic labels
clabels (torch.Tensor): group/instance/cluster labels
OUTPUT:
loss_segs (list): list of computed loss values for each semantic class.
loss[i] = computed DLoss for semantic class <i>.
acc_segs (list): list of computed clustering accuracy for each semantic class.
'''
loss = defaultdict(list)
accuracy = defaultdict(float)
semantic_classes = slabels.unique()
for sc in semantic_classes:
if int(sc) == 4: # Skip low energy deposits
continue
index = (slabels == sc)
clabels_unique, _ = unique_label_torch(clabels[index])
mask_loss, smoothing_loss, probs, acc = self.get_per_class_probabilities(
embeddings[index], margins[index], clabels_unique)
prob_truth = probs.detach()
seed_loss = self.seed_loss(prob_truth, seediness[index].squeeze(1))
total_loss = self.embedding_weight * mask_loss \
+ self.seediness_weight * seed_loss \
+ self.smoothing_weight * smoothing_loss
loss['loss'].append(total_loss)
loss['mask_loss'].append(float(self.embedding_weight * mask_loss))
loss['seed_loss'].append(float(self.seediness_weight * seed_loss))
loss['smoothing_loss'].append(float(self.smoothing_weight * smoothing_loss))
loss['mask_loss_{}'.format(int(sc))].append(float(mask_loss))
loss['seed_loss_{}'.format(int(sc))].append(float(seed_loss))
accuracy['accuracy_{}'.format(int(sc))] = acc
return loss, accuracy
[docs] def forward(self, out, segment_label, group_label):
num_gpus = len(segment_label)
loss = defaultdict(list)
accuracy = defaultdict(list)
for i in range(num_gpus):
slabels = segment_label[i][:, -1]
#coords = segment_label[i][:, :3].float()
#if torch.cuda.is_available():
# coords = coords.cuda()
slabels = slabels.int()
clabels = group_label[i][:, -1]
batch_idx = segment_label[i][:, self.batch_loc]
embedding = out['embeddings'][i]
seediness = out['seediness'][i]
margins = out['margins'][i]
nbatch = batch_idx.unique().shape[0]
for bidx in batch_idx.unique(sorted=True):
embedding_batch = embedding[batch_idx == bidx]
slabels_batch = slabels[batch_idx == bidx]
clabels_batch = clabels[batch_idx == bidx]
seed_batch = seediness[batch_idx == bidx]
margins_batch = margins[batch_idx == bidx]
loss_class, acc_class = self.combine_multiclass(
embedding_batch, margins_batch,
seed_batch, slabels_batch, clabels_batch)
if len(acc_class.values()):
for key, val in loss_class.items():
loss[key].append(sum(val) / len(val))
for s, acc in acc_class.items():
accuracy[s].append(acc)
acc = sum(acc_class.values()) / len(acc_class.values())
accuracy['accuracy'].append(acc)
loss_avg = {}
acc_avg = defaultdict(float)
for key, val in loss.items():
loss_avg[key] = sum(val) / len(val)
for key, val in accuracy.items():
acc_avg[key] = sum(val) / len(val)
res = {}
res.update(loss_avg)
res.update(acc_avg)
return res
[docs]class SPICEInterLoss(SPICELoss):
[docs] def __init__(self, cfg, name='spice_loss'):
super(SPICEInterLoss, self).__init__(cfg, name)
self.inter_weight = self.loss_config.get('inter_weight', 1.0)
self.inter_margin = self.loss_config.get('inter_margin', 0.2)
self.norm = 2
self._min_voxels = self.loss_config.get('min_voxels', 2)
[docs] def regularization(self, cluster_means):
'''
Implementation of regularization loss in Discriminative Loss
Inputs:
cluster_means (torch.Tensor): output from find_cluster_means
Returns:
reg_loss (float): computed regularization loss (see paper).
'''
reg_loss = regularization_loss(cluster_means)
return reg_loss
[docs] def inter_cluster_loss(self, cluster_means, margin=0.2):
'''
Implementation of distance loss in Discriminative Loss.
Inputs:
cluster_means (torch.Tensor): output from find_cluster_means
margin (float/int): the magnitude of the margin delta_d in the paper.
Think of it as the distance between each separate clusters in
embedding space.
Returns:
inter_loss (float): computed cross-centroid distance loss (see paper).
Factor of 2 is included for proper normalization.
'''
inter_loss = inter_cluster_loss(cluster_means, margin=margin)
return inter_loss
[docs] def get_per_class_probabilities(self, embeddings, margins, labels, eps=1e-6):
'''
Computes binary foreground/background loss.
'''
device = embeddings.device
n = labels.shape[0]
centroids = self.find_cluster_means(embeddings, labels)
sigma = self.find_cluster_means(margins, labels).view(-1, 1)
smoothing_loss = margin_smoothing_loss(margins.view(-1), sigma.view(-1).detach(), labels, margin=0)
num_clusters = labels.unique().shape[0]
inter_loss = self.inter_cluster_loss(centroids, margin=self.inter_margin)
# Compute spatial term
em = embeddings[:, None, :]
centroids = centroids[None, :, :]
cov = torch.clamp(sigma[:, 0][None, :], min=eps)
sqdists = ((em - centroids)**2).sum(-1) / (2.0 * cov**2)
pvec = torch.exp(-sqdists)
logits = logit_fn(pvec, eps=eps)
# print(logits)
eye = torch.eye(len(labels.unique()), dtype=torch.float32, device=device)
targets = eye[labels]
loss = self.mask_loss(logits, targets).mean()
# loss = loss_tensor
with torch.no_grad():
acc = iou_batch(logits > 0, targets.bool())
p = torch.gather(pvec, 1, labels.view(-1, 1))
loss += inter_loss
return loss, smoothing_loss, float(inter_loss), p.squeeze(), acc
[docs] def combine_multiclass(self, embeddings, margins, seediness, slabels, clabels):
'''
Wrapper function for combining different components of the loss,
in particular when clustering must be done PER SEMANTIC CLASS.
NOTE: When there are multiple semantic classes, we compute the DLoss
by first masking out by each semantic segmentation (ground-truth/prediction)
and then compute the clustering loss over each masked point cloud.
INPUTS:
features (torch.Tensor): pixel embeddings
slabels (torch.Tensor): semantic labels
clabels (torch.Tensor): group/instance/cluster labels
OUTPUT:
loss_segs (list): list of computed loss values for each semantic class.
loss[i] = computed DLoss for semantic class <i>.
acc_segs (list): list of computed clustering accuracy for each semantic class.
'''
loss = defaultdict(list)
accuracy = defaultdict(float)
semantic_classes = slabels.unique()
for sc in semantic_classes:
if int(sc) == 4:
continue
index = (slabels == sc)
if len(embeddings[index]) < self._min_voxels:
continue
clabels_unique, _ = unique_label_torch(clabels[index])
mask_loss, smoothing_loss, inter_loss, probs, acc = self.get_per_class_probabilities(
embeddings[index], margins[index], clabels_unique)
prob_truth = probs.detach()
seed_loss = self.seed_loss(prob_truth, seediness[index].squeeze(1))
total_loss = self.embedding_weight * mask_loss \
+ self.seediness_weight * seed_loss \
+ self.smoothing_weight * smoothing_loss
loss['loss'].append(total_loss)
loss['mask_loss'].append(
float(self.embedding_weight * mask_loss))
loss['seed_loss'].append(
float(self.seediness_weight * seed_loss))
loss['smoothing_loss'].append(
float(self.smoothing_weight * smoothing_loss))
loss['inter_loss'].append(
float(self.inter_weight * inter_loss))
loss['mask_loss_{}'.format(int(sc))].append(float(mask_loss))
loss['seed_loss_{}'.format(int(sc))].append(float(seed_loss))
accuracy['accuracy_{}'.format(int(sc))] = acc
return loss, accuracy