import torch
from collections import defaultdict
class DiscriminativeLoss(torch.nn.Module):
    '''
    Implementation of the Discriminative Loss Function in Pytorch.
    https://arxiv.org/pdf/1708.02551.pdf
    Note that there are many other implementations in Github, yet here
    we tailor it for use in conjunction with Sparse UResNet.

    The loss is a weighted sum of three terms computed over one set of
    points at a time:
      - intra-cluster (variance) loss: penalizes embeddings that lie farther
        than `intra_margin` from their cluster centroid;
      - inter-cluster (distance) loss: penalizes pairs of centroids that are
        closer than `2 * inter_margin`;
      - regularization loss: the mean norm of the centroids.
    '''

    def __init__(self, cfg, reduction='sum'):
        '''
        Inputs:
            cfg (dict): configuration; loss options are read from
                cfg['spice_loss'], with defaults for any missing key.
            reduction (str): accepted for interface compatibility; it is
                not used by this class.
        '''
        super(DiscriminativeLoss, self).__init__()
        self.loss_config = cfg['spice_loss']
        # num_classes and depth are stored but not read by any method here.
        self.num_classes = self.loss_config.get('num_classes', 5)
        self.depth = self.loss_config.get('stride', 5)
        # Clustering Loss Parameters (forwarded as kwargs to `combine`).
        self.loss_hyperparams = {}
        self.loss_hyperparams['intra_weight'] = self.loss_config.get('intra_weight', 1.0)
        self.loss_hyperparams['inter_weight'] = self.loss_config.get('inter_weight', 1.0)
        self.loss_hyperparams['reg_weight'] = self.loss_config.get('reg_weight', 0.001)
        self.loss_hyperparams['intra_margin'] = self.loss_config.get('intracluster_margin', 0.5)
        self.loss_hyperparams['inter_margin'] = self.loss_config.get('intercluster_margin', 1.5)
        self.dimension = self.loss_config.get('data_dim', 3)
        # Order p of the norm used for all distances in the loss terms.
        self.norm = self.loss_config.get('norm', 2)
        # If True, the loss is computed separately for each semantic class.
        self.use_segmentation = self.loss_config.get('use_segmentation', True)

    def find_cluster_means(self, features, labels):
        '''
        For a given image, compute the centroids mu_c for each
        cluster label in the embedding space.

        Inputs:
            features (torch.Tensor): the pixel embeddings, shape=(N, d) where
                N is the number of pixels and d is the embedding space dimension.
            labels (torch.Tensor): ground-truth group labels, shape=(N, )
        Returns:
            cluster_means (torch.Tensor): (n_c, d) tensor where n_c is the number
                of distinct labels. Row i is the centroid of the i-th label in
                ascending label order.
        '''
        clabels = labels.unique(sorted=True)
        cluster_means = [features[labels == c].mean(0) for c in clabels]
        return torch.stack(cluster_means)

    def intra_cluster_loss(self, features, labels, cluster_means, margin=0.5):
        '''
        Implementation of variance loss in Discriminative Loss.

        Inputs:
            features (torch.Tensor): pixel embedding, same as in find_cluster_means.
            labels (torch.Tensor): ground truth instance labels
            cluster_means (torch.Tensor): output from find_cluster_means
                (rows ordered by ascending label).
            margin (float/int): constant used to specify delta_v in paper. Think
                of it as the size of each cluster in embedding space.
        Returns:
            intra_loss: variance loss (see paper), averaged over clusters.
        '''
        intra_loss = 0.0
        n_clusters = len(cluster_means)
        # Same ascending order as find_cluster_means, so row i matches label c.
        cluster_labels = labels.unique(sorted=True)
        for i, c in enumerate(cluster_labels):
            # The +1e-8 offset presumably keeps the norm away from zero,
            # where its gradient is undefined.
            dists = torch.norm(features[labels == c] - cluster_means[i] + 1e-8,
                               p=self.norm,
                               dim=1)
            hinge = torch.clamp(dists - margin, min=0)
            intra_loss += torch.mean(torch.pow(hinge, 2))
        intra_loss /= n_clusters
        return intra_loss

    def inter_cluster_loss(self, cluster_means, margin=1.5):
        '''
        Implementation of distance loss in Discriminative Loss.

        Inputs:
            cluster_means (torch.Tensor): output from find_cluster_means
            margin (float/int): the magnitude of the margin delta_d in the paper.
                Think of it as the distance between each separate cluster in
                embedding space.
        Returns:
            inter_loss: computed cross-centroid distance loss (see paper),
                averaged over all ordered pairs (i, j) with i != j. Returns the
                float 0.0 when fewer than two centroids are given.
        '''
        n_clusters = len(cluster_means)
        if n_clusters < 2:
            # Inter-cluster loss is zero if only one instance exists for
            # a semantic label.
            return 0.0
        # diffs[i, j] = mu_i - mu_j (+1e-8, same offset as the other terms).
        diffs = cluster_means.unsqueeze(1) - cluster_means.unsqueeze(0) + 1e-8
        dists = torch.norm(diffs, p=self.norm, dim=2)
        hinge = torch.clamp(2.0 * margin - dists, min=0)
        # Only pairs of distinct centroids contribute.
        off_diagonal = ~torch.eye(n_clusters, device=cluster_means.device).bool()
        inter_loss = torch.pow(hinge[off_diagonal], 2).sum()
        return inter_loss / float((n_clusters - 1) * n_clusters)

    def regularization(self, cluster_means):
        '''
        Implementation of regularization loss in Discriminative Loss.

        Inputs:
            cluster_means (torch.Tensor): output from find_cluster_means
        Returns:
            reg_loss: mean norm of the centroids (see paper).
        '''
        norms = torch.norm(cluster_means + 1e-8, p=self.norm, dim=1)
        return norms.mean()

    def compute_heuristic_accuracy(self, embedding, truth):
        '''
        Compute Adjusted Rand Index Score for given embedding coordinates,
        where predicted cluster labels are obtained from distance to closest
        centroid (computes heuristic accuracy).

        The centroids are the ground-truth cluster means, and "closest" uses
        squared Euclidean distance regardless of `self.norm`.

        Inputs:
            embedding (torch.Tensor): (N, d) Tensor where 'd' is the embedding dimension.
            truth (torch.Tensor): (N, ) Tensor for the ground truth clustering labels.
        Returns:
            score (float): Computed ARI Score
        '''
        from sklearn.metrics import adjusted_rand_score
        with torch.no_grad():
            cmeans = self.find_cluster_means(embedding, truth)
            # dists[k, i] = squared distance of point k to centroid i.
            dists = torch.stack(
                [torch.sum((embedding - centroid) ** 2, dim=1) for centroid in cmeans],
                dim=1)
            nearest = torch.argmin(dists, dim=1)
        pred = nearest.cpu().numpy()
        grd = truth.cpu().numpy()
        return adjusted_rand_score(pred, grd)

    def combine(self, features, labels, **kwargs):
        '''
        Wrapper function for combining different components of the loss function.

        Inputs:
            features (torch.Tensor): pixel embeddings
            labels (torch.Tensor): ground-truth instance labels
            **kwargs: optional intra_margin, inter_margin, intra_weight,
                inter_weight, reg_weight overriding the defaults below.
        Returns:
            dict with
                'loss': weighted total loss (differentiable),
                'intra_loss', 'inter_loss', 'reg_loss': weighted component
                values as Python floats (for logging).
        '''
        # Clustering Loss Hyperparameters
        # We allow changing the parameters at each computation in order
        # to alter the margins at each spatial resolution in multi-scale losses.
        intra_margin = kwargs.get('intra_margin', 0.5)
        inter_margin = kwargs.get('inter_margin', 1.5)
        intra_weight = kwargs.get('intra_weight', 1.0)
        inter_weight = kwargs.get('inter_weight', 1.0)
        reg_weight = kwargs.get('reg_weight', 0.001)
        c_means = self.find_cluster_means(features, labels)
        inter_loss = self.inter_cluster_loss(c_means, margin=inter_margin)
        intra_loss = self.intra_cluster_loss(features,
                                             labels,
                                             c_means,
                                             margin=intra_margin)
        reg_loss = self.regularization(c_means)
        loss = (intra_weight * intra_loss
                + inter_weight * inter_loss
                + reg_weight * reg_loss)
        return {
            'loss': loss,
            'intra_loss': intra_weight * float(intra_loss),
            'inter_loss': inter_weight * float(inter_loss),
            'reg_loss': reg_weight * float(reg_loss)
        }

    def combine_multiclass(self, features, slabels, clabels, **kwargs):
        '''
        Wrapper function for combining different components of the loss,
        in particular when clustering must be done PER SEMANTIC CLASS.

        NOTE: When there are multiple semantic classes, we compute the DLoss
        by first masking out by each semantic segmentation (ground-truth/prediction)
        and then compute the clustering loss over each masked point cloud.

        INPUTS:
            features (torch.Tensor): pixel embeddings
            slabels (torch.Tensor): semantic labels
            clabels (torch.Tensor): group/instance/cluster labels
            **kwargs: forwarded to `combine`.
        OUTPUT:
            loss (defaultdict(list)): for each key returned by `combine`
                ('loss', 'intra_loss', ...), the list of values, one per
                semantic class present.
            acc_segs (defaultdict(float)): clustering accuracy (ARI) per
                semantic class, keyed 'accuracy_<class>'.
        '''
        loss, acc_segs = defaultdict(list), defaultdict(float)
        for sc in slabels.unique():
            index = (slabels == sc)
            loss_blob = self.combine(features[index], clabels[index], **kwargs)
            for key, val in loss_blob.items():
                loss[key].append(val)
            acc = self.compute_heuristic_accuracy(features[index], clabels[index])
            acc_segs['accuracy_{}'.format(sc.item())] = acc
        return loss, acc_segs

    def forward(self, out, semantic_labels, group_labels):
        '''
        Forward function for the Discriminative Loss Module.

        Inputs:
            out: output of UResNet; out['cluster_feature'][i] holds the
                embedding-space coordinates for entry i.
            semantic_labels: sequence (presumably one entry per GPU) of label
                tensors; column 3 is read as the batch index and the last
                column as the semantic label.
            group_labels: sequence of label tensors aligned with
                semantic_labels; the second-to-last column is read as the
                instance label.
        Returns:
            (dict): loss components ('loss', 'intra_loss', 'inter_loss',
                'reg_loss') and accuracies ('accuracy', plus
                'accuracy_<class>' when use_segmentation is True), each
                averaged over all processed batches.
        '''
        loss = defaultdict(list)
        accuracy = defaultdict(list)
        for i in range(len(semantic_labels)):
            slabels = semantic_labels[i][:, -1].int()
            clabels = group_labels[i][:, -2]
            # NOTE(review): batch index column is hard-coded to 3; confirm it
            # matches the data layout if data_dim differs from 3.
            batch_idx = semantic_labels[i][:, 3]
            embedding = out['cluster_feature'][i]
            for bidx in batch_idx.unique(sorted=True):
                in_batch = (batch_idx == bidx)
                embedding_batch = embedding[in_batch]
                slabels_batch = slabels[in_batch]
                clabels_batch = clabels[in_batch]
                if self.use_segmentation:
                    loss_dict, acc_segs = self.combine_multiclass(
                        embedding_batch, slabels_batch, clabels_batch,
                        **self.loss_hyperparams)
                    # Average each component over the semantic classes.
                    for key, val in loss_dict.items():
                        loss[key].append(sum(val) / len(val))
                    for s, acc in acc_segs.items():
                        accuracy[s].append(acc)
                    acc = sum(acc_segs.values()) / len(acc_segs.values())
                    accuracy['accuracy'].append(acc)
                else:
                    # combine() returns a dict of components; record each one
                    # so the averaging below works on numbers, not dicts.
                    loss_blob = self.combine(embedding_batch, clabels_batch,
                                             **self.loss_hyperparams)
                    for key, val in loss_blob.items():
                        loss[key].append(val)
                    # compute_heuristic_accuracy returns a single score.
                    acc = self.compute_heuristic_accuracy(embedding_batch,
                                                          clabels_batch)
                    accuracy['accuracy'].append(acc)
        res = {}
        res.update({key: sum(val) / len(val) for key, val in loss.items()})
        res.update({key: sum(val) / len(val) for key, val in accuracy.items()})
        return res