Source code for mlreco.models.layers.cluster_cnn.losses.single_layers

import torch
from collections import defaultdict


class DiscriminativeLoss(torch.nn.Module):
    '''
    Implementation of the Discriminative Loss Function in Pytorch.
    https://arxiv.org/pdf/1708.02551.pdf

    Note that there are many other implementations in Github, yet here
    we tailor it for use in conjunction with Sparse UResNet.

    The total loss is a weighted sum of three terms computed from the
    per-cluster centroids in embedding space:
        - intra-cluster (variance) loss: pulls points towards their centroid,
        - inter-cluster (distance) loss: pushes centroids apart,
        - regularization loss: keeps centroids close to the origin.
    '''

    def __init__(self, cfg, reduction='sum'):
        '''
        Read the loss configuration from cfg['spice_loss'].

        Inputs:
            cfg (dict): configuration dictionary; must contain the key
                'spice_loss' mapping to a dict of options. Every option
                is optional and falls back to the default shown below.
            reduction (str): accepted for interface compatibility; it is
                not used by this class.
        '''
        super().__init__()
        self.loss_config = cfg['spice_loss']
        self.num_classes = self.loss_config.get('num_classes', 5)
        self.depth = self.loss_config.get('stride', 5)

        # Clustering Loss Parameters. The keys match the keyword names
        # read by `combine`, so the dict can be splatted into it directly.
        self.loss_hyperparams = {
            'intra_weight': self.loss_config.get('intra_weight', 1.0),
            'inter_weight': self.loss_config.get('inter_weight', 1.0),
            'reg_weight': self.loss_config.get('reg_weight', 0.001),
            'intra_margin': self.loss_config.get('intracluster_margin', 0.5),
            'inter_margin': self.loss_config.get('intercluster_margin', 1.5),
        }

        self.dimension = self.loss_config.get('data_dim', 3)
        self.norm = self.loss_config.get('norm', 2)
        self.use_segmentation = self.loss_config.get('use_segmentation', True)

    def find_cluster_means(self, features, labels):
        '''
        For a given image, compute the centroids mu_c for each
        cluster label in the embedding space.

        Inputs:
            features (torch.Tensor): the pixel embeddings, shape=(N, d) where
                N is the number of pixels and d is the embedding space
                dimension.
            labels (torch.Tensor): ground-truth group labels, shape=(N, )

        Returns:
            cluster_means (torch.Tensor): (n_c, d) tensor where n_c is the
                number of distinct instances. Row i is the centroid of the
                i-th label in ascending label order.
        '''
        # sorted=True fixes the row order; intra_cluster_loss relies on it.
        return torch.stack([features[labels == c].mean(0)
                            for c in labels.unique(sorted=True)])

    def intra_cluster_loss(self, features, labels, cluster_means, margin=0.5):
        '''
        Implementation of variance loss in Discriminative Loss.

        Inputs:
            features (torch.Tensor): pixel embedding, same as in
                find_cluster_means.
            labels (torch.Tensor): ground truth instance labels
            cluster_means (torch.Tensor): output from find_cluster_means
            margin (float/int): constant used to specify delta_v in paper.
                Think of it as the size of each clusters in embedding space.

        Returns:
            intra_loss: variance loss (see paper), averaged over clusters.
        '''
        intra_loss = 0.0
        n_clusters = len(cluster_means)
        # Same sorted ordering as find_cluster_means, so index i of
        # cluster_means is the centroid of the i-th unique label.
        for i, c in enumerate(labels.unique(sorted=True)):
            # The small offset is kept from the original implementation; it
            # keeps the norm away from exactly zero.
            dists = torch.norm(features[labels == c] - cluster_means[i] + 1e-8,
                               p=self.norm, dim=1)
            # Points closer than `margin` to their centroid are not penalized.
            hinge = torch.clamp(dists - margin, min=0)
            intra_loss += torch.mean(torch.pow(hinge, 2))
        intra_loss /= n_clusters
        return intra_loss

    def inter_cluster_loss(self, cluster_means, margin=1.5):
        '''
        Implementation of distance loss in Discriminative Loss.

        Inputs:
            cluster_means (torch.Tensor): output from find_cluster_means
            margin (float/int): the magnitude of the margin delta_d in the
                paper. Think of it as the distance between each separate
                clusters in embedding space.

        Returns:
            inter_loss: computed cross-centroid distance loss (see paper).
                Factor of 2 is included for proper normalization. Returns
                the float 0.0 when fewer than two clusters are given.
        '''
        n_clusters = len(cluster_means)
        if n_clusters < 2:
            # Inter-cluster loss is zero if only one instance exists for
            # a semantic label.
            return 0.0

        # All ordered pairs (i, j) at once: diffs[i, j] = c_i - c_j + 1e-8.
        diffs = cluster_means.unsqueeze(1) - cluster_means.unsqueeze(0) + 1e-8
        dists = torch.norm(diffs, p=self.norm, dim=2)
        hinge_sq = torch.pow(torch.clamp(2.0 * margin - dists, min=0), 2)
        # Drop the i == j terms, which are not part of the loss.
        off_diagonal = 1.0 - torch.eye(n_clusters, dtype=hinge_sq.dtype,
                                       device=hinge_sq.device)
        inter_loss = (hinge_sq * off_diagonal).sum()
        return inter_loss / float((n_clusters - 1) * n_clusters)

    def regularization(self, cluster_means):
        '''
        Implementation of regularization loss in Discriminative Loss

        Inputs:
            cluster_means (torch.Tensor): output from find_cluster_means,
                shape (n_c, d).

        Returns:
            reg_loss: mean norm of the centroids (see paper).
        '''
        return torch.norm(cluster_means + 1e-8, p=self.norm, dim=1).mean()

    def compute_heuristic_accuracy(self, embedding, truth):
        '''
        Compute Adjusted Rand Index Score for given embedding coordinates,
        where predicted cluster labels are obtained from distance to closest
        centroid (computes heuristic accuracy).

        Inputs:
            embedding (torch.Tensor): (N, d) Tensor where 'd' is the
                embedding dimension.
            truth (torch.Tensor): (N, ) Tensor for the ground truth
                clustering labels.

        Returns:
            score (float): Computed ARI Score
        '''
        # Imported lazily so that sklearn is only needed when accuracy is
        # actually evaluated.
        from sklearn.metrics import adjusted_rand_score

        with torch.no_grad():
            cmeans = self.find_cluster_means(embedding, truth)
            # Column j holds the squared distance of every point to centroid j.
            sq_dists = torch.stack(
                [torch.sum((embedding - centroid) ** 2, dim=1)
                 for centroid in cmeans], dim=1)
            nearest = torch.argmin(sq_dists, dim=1)

        pred = nearest.cpu().numpy()
        grd = truth.cpu().numpy()
        # ARI is symmetric in its two arguments; pass (true, pred) to match
        # the documented sklearn signature.
        return adjusted_rand_score(grd, pred)

    def combine(self, features, labels, **kwargs):
        '''
        Wrapper function for combining different components of the loss
        function.

        Inputs:
            features (torch.Tensor): pixel embeddings
            labels (torch.Tensor): ground-truth instance labels
            **kwargs: optional overrides 'intra_margin' (0.5),
                'inter_margin' (1.5), 'intra_weight' (1.0),
                'inter_weight' (1.0) and 'reg_weight' (0.001).

        Returns:
            (dict): 'loss' is the combined (weighted) loss; 'intra_loss',
                'inter_loss' and 'reg_loss' are the weighted components
                converted to Python floats.
        '''
        # Clustering Loss Hyperparameters
        # We allow changing the parameters at each computation in order
        # to alter the margins at each spatial resolution in multi-scale
        # losses.
        intra_margin = kwargs.get('intra_margin', 0.5)
        inter_margin = kwargs.get('inter_margin', 1.5)
        intra_weight = kwargs.get('intra_weight', 1.0)
        inter_weight = kwargs.get('inter_weight', 1.0)
        reg_weight = kwargs.get('reg_weight', 0.001)

        c_means = self.find_cluster_means(features, labels)
        inter_loss = self.inter_cluster_loss(c_means, margin=inter_margin)
        intra_loss = self.intra_cluster_loss(features, labels, c_means,
                                             margin=intra_margin)
        reg_loss = self.regularization(c_means)

        loss = (intra_weight * intra_loss
                + inter_weight * inter_loss
                + reg_weight * reg_loss)

        return {
            'loss': loss,
            'intra_loss': intra_weight * float(intra_loss),
            'inter_loss': inter_weight * float(inter_loss),
            'reg_loss': reg_weight * float(reg_loss)
        }

    def combine_multiclass(self, features, slabels, clabels, **kwargs):
        '''
        Wrapper function for combining different components of the loss,
        in particular when clustering must be done PER SEMANTIC CLASS.

        NOTE: When there are multiple semantic classes, we compute the DLoss
        by first masking out by each semantic segmentation
        (ground-truth/prediction) and then compute the clustering loss over
        each masked point cloud.

        INPUTS:
            features (torch.Tensor): pixel embeddings
            slabels (torch.Tensor): semantic labels
            clabels (torch.Tensor): group/instance/cluster labels
            **kwargs: forwarded unchanged to `combine`.

        OUTPUT:
            loss (defaultdict(list)): for each key returned by `combine`,
                the list of values obtained for each semantic class present
                in `slabels`.
            acc_segs (defaultdict(float)): maps 'accuracy_<class>' to the
                heuristic clustering accuracy of that semantic class.
        '''
        loss, acc_segs = defaultdict(list), defaultdict(float)
        for sc in slabels.unique():
            index = (slabels == sc)
            class_features = features[index]
            class_clabels = clabels[index]

            loss_blob = self.combine(class_features, class_clabels, **kwargs)
            for key, val in loss_blob.items():
                loss[key].append(val)

            acc_segs['accuracy_{}'.format(sc.item())] = \
                self.compute_heuristic_accuracy(class_features, class_clabels)
        return loss, acc_segs

    def forward(self, out, semantic_labels, group_labels):
        '''
        Forward function for the Discriminative Loss Module.

        Inputs:
            out (dict): network output; out['cluster_feature'][i] holds the
                embedding-space coordinates matching semantic_labels[i].
            semantic_labels (list of torch.Tensor): ground-truth semantic
                labels. For each entry, column 3 is read as the batch index
                and the last column as the semantic label.
            group_labels (list of torch.Tensor): ground-truth instance
                labels; the second-to-last column is read as the
                cluster label.

        Returns:
            (dict): A dictionary containing key-value pairs for
            loss, accuracy, etc., each averaged over all batch entries.
        '''
        loss = defaultdict(list)
        accuracy = defaultdict(list)

        # One list entry per GPU, going by the original variable naming.
        for i, sem in enumerate(semantic_labels):
            slabels = sem[:, -1].int()
            clabels = group_labels[i][:, -2]
            batch_idx = sem[:, 3]
            embedding = out['cluster_feature'][i]

            for bidx in batch_idx.unique(sorted=True):
                batch_mask = (batch_idx == bidx)
                embedding_batch = embedding[batch_mask]
                clabels_batch = clabels[batch_mask]

                if self.use_segmentation:
                    loss_dict, acc_segs = self.combine_multiclass(
                        embedding_batch, slabels[batch_mask], clabels_batch,
                        **self.loss_hyperparams)
                    # Average each loss component over semantic classes.
                    for key, val in loss_dict.items():
                        loss[key].append(sum(val) / len(val))
                    for key, acc in acc_segs.items():
                        accuracy[key].append(acc)
                    accuracy['accuracy'].append(
                        sum(acc_segs.values()) / len(acc_segs))
                else:
                    # `combine` returns a dict of loss components; record
                    # each one under its own key so they can be averaged.
                    loss_dict = self.combine(embedding_batch, clabels_batch,
                                             **self.loss_hyperparams)
                    for key, val in loss_dict.items():
                        loss[key].append(val)
                    # compute_heuristic_accuracy returns a single float.
                    accuracy['accuracy'].append(
                        self.compute_heuristic_accuracy(embedding_batch,
                                                        clabels_batch))

        res = {key: sum(val) / len(val) for key, val in loss.items()}
        res.update({key: sum(val) / len(val)
                    for key, val in accuracy.items()})
        return res