Source code for mlreco.models.layers.cluster_cnn.losses.spatial_embeddings

import inspect
from collections import defaultdict

import torch
import torch.nn as nn

from .lovasz import lovasz_hinge_flat, StableBCELoss, iou_binary
from .misc import *

class MaskBCELoss(nn.Module):
    '''
    Loss function for Sparse Spatial Embeddings Model, with fixed
    centroids and symmetric gaussian kernels.

    Centroids are the per-cluster means of the voxel coordinates, so they
    do not depend on the network output.

    Configuration keys read from ``cfg[name]``:
        seediness_weight (float, default 0.0): weight of the seediness term.
        embedding_weight (float, default 1.0): weight of the mask term.
        smoothing_weight (float, default 1.0): weight of the margin
            smoothing term.
        spatial_size (int, default 512): stored on the instance; not used
            by any method of this class.
        loss_scheme (str, default 'BCE'): 'BCE' or 'lovasz_hinge'.
    '''

    def __init__(self, cfg, name='spice_loss'):
        '''
        Inputs:
            cfg (dict): configuration; the loss settings are looked up
                under the key <name> (an empty dict is used if missing).
            name (str): key of the loss configuration inside <cfg>.

        Raises:
            NotImplementedError: if loss_scheme is 'focal'.
            ValueError: if loss_scheme is any other unknown value.
        '''
        super(MaskBCELoss, self).__init__()
        self.loss_config = cfg.get(name, {})
        self.seediness_weight = self.loss_config.get('seediness_weight', 0.0)
        self.embedding_weight = self.loss_config.get('embedding_weight', 1.0)
        self.smoothing_weight = self.loss_config.get('smoothing_weight', 1.0)
        self.spatial_size = self.loss_config.get('spatial_size', 512)

        self.loss_scheme = self.loss_config.get('loss_scheme', 'BCE')
        if self.loss_scheme == 'BCE':
            self.mask_loss = StableBCELoss()
        elif self.loss_scheme == 'lovasz_hinge':
            self.mask_loss = lovasz_hinge_flat
        elif self.loss_scheme == 'focal':
            raise NotImplementedError
        else:
            raise ValueError(
                'Invalid loss scheme: {}'.format(self.loss_scheme))

        # BCELoss for Embedding Loss
        self.bceloss = StableBCELoss()
        # L2 (mean squared error) loss for the seediness regression. The
        # attribute has always been named/commented as an L2 loss, so it is
        # instantiated as one instead of L1Loss.
        self.l2loss = torch.nn.MSELoss(reduction='mean')

    def find_cluster_means(self, features, labels):
        '''
        For a given image, compute the centroids mu_c for each
        cluster label in the embedding space.

        Inputs:
            features (torch.Tensor): the pixel embeddings, shape=(N, d) where
            N is the number of pixels and d is the embedding space dimension.
            labels (torch.Tensor): ground-truth group labels, shape=(N, )

        Returns:
            cluster_means (torch.Tensor): (n_c, d) tensor where n_c is the
            number of distinct instances. Row i is the centroid of the i-th
            label in sorted order.
        '''
        return torch.stack([features[labels == c].mean(0)
                            for c in labels.unique(sorted=True)])

    def get_per_class_probabilities(self, embeddings, margins, labels, coords):
        '''
        Computes binary foreground/background loss for every cluster.

        Inputs:
            embeddings (torch.Tensor): (N, d) voxel embeddings.
            margins (torch.Tensor): per-voxel gaussian widths
                (assumed (N, 1) -- TODO confirm against the model output).
            labels (torch.Tensor): (N, ) cluster labels.
            coords (torch.Tensor): (N, d) voxel coordinates; the fixed
                centroids are their per-cluster means.

        Returns:
            loss: mask loss averaged over clusters.
            smoothing_loss: squared deviation of the margins from their
                (detached) cluster mean, averaged over clusters.
            probs (torch.Tensor): (N, ) probability of each voxel under the
                kernel of its own cluster.
            acc: binary IoU averaged over clusters.
        '''
        device = embeddings.device
        centroids = self.find_cluster_means(coords, labels)
        n_clusters = len(centroids)
        cluster_labels = labels.unique(sorted=True)
        probs = torch.zeros(embeddings.shape[0]).float().to(device)
        loss, smoothing_loss, acc = 0.0, 0.0, 0.0

        for i, c in enumerate(cluster_labels):
            index = (labels == c)
            mask = torch.zeros(embeddings.shape[0]).to(device)
            mask[index] = 1.0
            sigma = torch.mean(margins[index], dim=0)
            dists = torch.sum(torch.pow(embeddings - centroids[i], 2), dim=1)
            p = torch.clamp(
                torch.exp(-dists / (2 * torch.pow(sigma, 2))), min=0, max=1)
            probs[index] = p[index]
            loss += self.bceloss(p, mask)
            # per_image=False matches every other iou_binary call in this
            # module: p and mask describe one flat point cloud.
            acc += iou_binary(p > 0.5, mask, per_image=False)
            # The target width is detached so that smoothing only pulls the
            # individual margins towards the cluster mean.
            sigma_detach = sigma.detach()
            smoothing_loss += torch.sum(
                torch.pow(margins[index] - sigma_detach, 2))

        loss /= n_clusters
        smoothing_loss /= n_clusters
        acc /= n_clusters

        return loss, smoothing_loss, probs, acc

    def combine_multiclass(self, embeddings, margins,
                           seediness, slabels, clabels, coords):
        '''
        Wrapper function for combining different components of the loss,
        in particular when clustering must be done PER SEMANTIC CLASS.

        NOTE: When there are multiple semantic classes, we compute the DLoss
        by first masking out by each semantic segmentation (ground-truth/
        prediction) and then compute the clustering loss over each masked
        point cloud.

        INPUTS:
            embeddings (torch.Tensor): pixel embeddings
            margins (torch.Tensor): per-pixel gaussian widths
            seediness (torch.Tensor): per-pixel seediness, squeezed along
                dimension 1 before use
            slabels (torch.Tensor): semantic labels
            clabels (torch.Tensor): group/instance/cluster labels
            coords (torch.Tensor): pixel coordinates

        OUTPUT:
            loss (defaultdict(list)): one entry per processed semantic
                class under the keys 'loss' (tensor), 'mask_loss',
                'seed_loss', 'smoothing_loss', 'mask_loss_<sc>' and
                'seed_loss_<sc>' (floats).
            accuracy (defaultdict(float)): 'accuracy_<sc>' for each
                processed semantic class.
        '''
        loss = defaultdict(list)
        accuracy = defaultdict(float)
        semantic_classes = slabels.unique()
        for sc in semantic_classes:
            # Semantic class 4 is never clustered.
            # NOTE(review): presumably a class without instance labels --
            # confirm against the dataset definition.
            if int(sc) == 4:
                continue
            index = (slabels == sc)
            mask_loss, smoothing_loss, probs, acc = \
                self.get_per_class_probabilities(
                    embeddings[index], margins[index],
                    clabels[index], coords[index])
            # The seediness regresses towards the (detached) probabilities.
            prob_truth = probs.detach()
            seed_loss = self.l2loss(prob_truth, seediness[index].squeeze(1))
            total_loss = self.embedding_weight * mask_loss \
                + self.seediness_weight * seed_loss \
                + self.smoothing_weight * smoothing_loss
            loss['loss'].append(total_loss)
            loss['mask_loss'].append(
                float(self.embedding_weight * mask_loss))
            loss['seed_loss'].append(
                float(self.seediness_weight * seed_loss))
            loss['smoothing_loss'].append(
                float(self.smoothing_weight * smoothing_loss))
            loss['mask_loss_{}'.format(int(sc))].append(float(mask_loss))
            loss['seed_loss_{}'.format(int(sc))].append(float(seed_loss))
            accuracy['accuracy_{}'.format(int(sc))] = acc

        return loss, accuracy

    def forward(self, out, segment_label, group_label):
        '''
        Compute the loss and accuracy averaged over every batch entry.

        Inputs:
            out (dict): network output with the keys 'embeddings',
                'seediness' and 'margins'; each value is indexable per GPU.
            segment_label (list of torch.Tensor): per GPU, a tensor whose
                columns 0-2 are the voxel coordinates, column 3 the batch
                index and the last column the semantic label.
            group_label (list of torch.Tensor): per GPU, a tensor whose
                last column is the cluster label.

        Returns:
            res (dict): every loss and accuracy entry averaged over batch
            entries, plus the overall 'accuracy'. Batch entries in which no
            semantic class was processed contribute nothing; if that is
            true of all of them the dict is empty.
        '''
        # Some subclasses override combine_multiclass without a `coords`
        # parameter, so coordinates are only passed to implementations
        # that declare it.
        takes_coords = 'coords' in inspect.signature(
            self.combine_multiclass).parameters

        loss = defaultdict(list)
        accuracy = defaultdict(list)

        for i in range(len(segment_label)):
            slabels = segment_label[i][:, -1].int()
            # NOTE(review): raw voxel coordinates are used, as in the
            # previously commented-out code; confirm whether they must be
            # rescaled (e.g. by spatial_size) to match the embedding space.
            coords = segment_label[i][:, :3].float()
            clabels = group_label[i][:, -1]
            batch_idx = segment_label[i][:, 3]
            embedding = out['embeddings'][i]
            seediness = out['seediness'][i]
            margins = out['margins'][i]

            for bidx in batch_idx.unique(sorted=True):
                sel = (batch_idx == bidx)
                args = [embedding[sel], margins[sel], seediness[sel],
                        slabels[sel], clabels[sel]]
                if takes_coords:
                    args.append(coords[sel].to(embedding.device))
                loss_class, acc_class = self.combine_multiclass(*args)
                for key, val in loss_class.items():
                    loss[key].append(sum(val) / len(val))
                for s, acc in acc_class.items():
                    accuracy[s].append(acc)
                # Nothing to average when every semantic class was skipped.
                if acc_class:
                    accuracy['accuracy'].append(
                        sum(acc_class.values()) / len(acc_class))

        res = {}
        res.update({key: sum(val) / len(val) for key, val in loss.items()})
        res.update(
            {key: sum(val) / len(val) for key, val in accuracy.items()})

        return res
class MaskBCELoss2(MaskBCELoss):
    '''
    Spatial Embeddings Loss with trainable center of attention.

    Unlike the parent class, the centroid of each cluster is the mean of
    its voxel embeddings rather than of its coordinates.
    '''

    def __init__(self, cfg, name='spice_loss'):
        super().__init__(cfg, name)

    def get_per_class_probabilities(self, embeddings, margins, labels, coords):
        '''
        Computes binary foreground/background loss.

        Inputs:
            embeddings (torch.Tensor): (N, d) voxel embeddings.
            margins (torch.Tensor): per-voxel gaussian widths.
            labels (torch.Tensor): (N, ) cluster labels.
            coords: accepted for interface compatibility; not used.

        Returns:
            (loss, smoothing_loss, probs, acc): mask loss, margin smoothing
            loss and binary IoU, each averaged over clusters, and the (N, )
            probability of each voxel under its own cluster's kernel.
        '''
        n_voxels = embeddings.shape[0]
        target_device = embeddings.device
        centers = self.find_cluster_means(embeddings, labels)
        num_clusters = len(centers)
        voxel_probs = torch.zeros(n_voxels).float().to(target_device)

        mask_total = 0.0
        smooth_total = 0.0
        iou_total = 0.0
        for center, cid in zip(centers, labels.unique(sorted=True)):
            member = labels == cid
            target = torch.zeros(n_voxels).to(target_device)
            target[member] = 1.0

            width = margins[member].mean(dim=0)
            sq_dist = (embeddings - center).pow(2).sum(dim=1)
            gauss = torch.exp(-sq_dist / (2 * width.pow(2)))
            gauss = gauss.clamp(min=0, max=1)
            voxel_probs[member] = gauss[member]

            mask_total = mask_total + self.bceloss(gauss, target)
            iou_total = iou_total + iou_binary(
                gauss > 0.5, target, per_image=False)
            # Margins are pulled towards the detached cluster mean.
            deviation = margins[member] - width.detach()
            smooth_total = smooth_total + deviation.pow(2).sum()

        return (mask_total / num_clusters,
                smooth_total / num_clusters,
                voxel_probs,
                iou_total / num_clusters)
class MaskBCELossBivariate(MaskBCELoss):
    '''
    Spatial Embeddings Loss with trainable center of attraction and
    bivariate gaussian probability kernels.

    The squared distance along each embedding dimension is scaled by the
    matching component of the mean margin before being summed.
    '''

    def __init__(self, cfg, name='spice_loss'):
        super(MaskBCELossBivariate, self).__init__(cfg, name)

    def get_per_class_probabilities(self, embeddings, margins, labels, coords):
        '''
        Computes binary foreground/background loss.

        Inputs:
            embeddings (torch.Tensor): (N, d) voxel embeddings.
            margins (torch.Tensor): per-voxel gaussian widths; must
                broadcast against the embeddings
                (assumed (N, d) -- TODO confirm).
            labels (torch.Tensor): (N, ) cluster labels.
            coords: accepted for interface compatibility; not used.

        Returns:
            (loss, smoothing_loss, probs, acc): mask loss, margin smoothing
            loss and binary IoU, each averaged over clusters, and the (N, )
            probability of each voxel under its own cluster's kernel.
        '''
        device = embeddings.device
        loss = 0.0
        smoothing_loss = 0.0
        centroids = self.find_cluster_means(embeddings, labels)
        n_clusters = len(centroids)
        cluster_labels = labels.unique(sorted=True)
        probs = torch.zeros(embeddings.shape[0]).float().to(device)
        acc = 0.0

        for i, c in enumerate(cluster_labels):
            index = (labels == c)
            mask = torch.zeros(embeddings.shape[0]).to(device)
            mask[index] = 1.0
            sigma = torch.mean(margins[index], dim=0)
            # Scale each dimension by its own width before summing.
            dists = torch.pow(embeddings - centroids[i], 2)
            dists = dists / (2 * torch.pow(sigma, 2))
            p = torch.clamp(torch.exp(-torch.sum(dists, dim=1)), min=0, max=1)
            probs[index] = p[index]
            loss += self.bceloss(p, mask)
            # The accuracy used to stay at 0 because it was never
            # accumulated; use the same IoU as the sibling losses.
            acc += iou_binary(p > 0.5, mask, per_image=False)
            sigma_detach = sigma.detach()
            smoothing_loss += torch.sum(
                torch.pow(margins[index] - sigma_detach, 2))

        loss /= n_clusters
        smoothing_loss /= n_clusters
        acc /= n_clusters

        return loss, smoothing_loss, probs, acc
class MaskLovaszHingeLoss(MaskBCELoss2):
    '''
    Spatial Embeddings Loss using Lovasz Hinge for foreground/background
    segmentation and trainable center of attention.
    '''

    def __init__(self, cfg, name='spice_loss'):
        super(MaskLovaszHingeLoss, self).__init__(cfg, name)

    def get_per_class_probabilities(self, embeddings, margins, labels,
                                    coords=None):
        '''
        Computes binary foreground/background loss.

        Inputs:
            embeddings (torch.Tensor): (N, d) voxel embeddings.
            margins (torch.Tensor): per-voxel gaussian widths.
            labels (torch.Tensor): (N, ) cluster labels.
            coords: optional and unused. It is accepted because the
                inherited combine_multiclass passes the coordinates as a
                fourth argument.

        Returns:
            (loss, smoothing_loss, probs, acc): Lovasz hinge loss, margin
            smoothing loss and binary IoU, each averaged over clusters, and
            the (N, ) probability of each voxel under its own cluster's
            kernel.
        '''
        device = embeddings.device
        loss = 0.0
        smoothing_loss = 0.0
        centroids = self.find_cluster_means(embeddings, labels)
        n_clusters = len(centroids)
        cluster_labels = labels.unique(sorted=True)
        probs = torch.zeros(embeddings.shape[0]).float().to(device)
        acc = 0.0

        for i, c in enumerate(cluster_labels):
            index = (labels == c)
            mask = torch.zeros(embeddings.shape[0]).to(device)
            mask[index] = 1
            sigma = torch.mean(margins[index], dim=0)
            dists = torch.sum(torch.pow(embeddings - centroids[i], 2), dim=1)
            p = torch.exp(-dists / (2 * torch.pow(sigma, 2) + 1e-6))
            probs[index] = p[index]
            # Map the probability from [0, 1] to a score in [-1, 1].
            loss += lovasz_hinge_flat(2 * p - 1, mask)
            # The accuracy used to stay at 0 because it was never
            # accumulated; use the same IoU as the sibling losses.
            acc += iou_binary(p > 0.5, mask, per_image=False)
            sigma_detach = sigma.detach()
            smoothing_loss += torch.sum(
                torch.pow(margins[index] - sigma_detach, 2))

        loss /= n_clusters
        smoothing_loss /= n_clusters
        acc /= n_clusters

        return loss, smoothing_loss, probs, acc
class CELovaszLoss(MaskBCELoss2):
    '''
    Spatial Embeddings Loss whose mask term is the average of the BCE
    loss and the Lovasz hinge loss, with centroids computed from the
    embeddings.
    '''

    def __init__(self, cfg, name='spice_loss'):
        super().__init__(cfg, name)

    def get_per_class_probabilities(self, embeddings, margins, labels, coords):
        '''
        Computes binary foreground/background loss.

        Inputs:
            embeddings (torch.Tensor): (N, d) voxel embeddings.
            margins (torch.Tensor): per-voxel gaussian widths; the norm
                along dimension 1 is taken, so it must be 2-dimensional.
            labels (torch.Tensor): (N, ) cluster labels.
            coords: accepted for interface compatibility; not used.

        Returns:
            (loss, smoothing_loss, probs, acc): mask loss, margin smoothing
            loss and binary IoU, each averaged over clusters, and the (N, )
            probability of each voxel under its own cluster's kernel.
        '''
        n_voxels = embeddings.shape[0]
        target_device = embeddings.device
        centers = self.find_cluster_means(embeddings, labels)
        num_clusters = len(centers)
        voxel_probs = torch.zeros(n_voxels).float().to(target_device)

        mask_total = 0.0
        smooth_total = 0.0
        iou_total = 0.0
        for center, cid in zip(centers, labels.unique(sorted=True)):
            member = labels == cid
            target = torch.zeros(n_voxels).to(target_device)
            target[member] = 1.0

            width = margins[member].mean(dim=0)
            sq_dist = (embeddings - center).pow(2).sum(dim=1)
            gauss = torch.exp(-sq_dist / (2 * width.pow(2) + 1e-8))
            gauss = gauss.clamp(min=0, max=1)
            voxel_probs[member] = gauss[member]

            bce_term = self.bceloss(gauss, target)
            hinge_term = lovasz_hinge_flat(2.0 * gauss - 1, target)
            mask_total = mask_total + (bce_term + hinge_term) / 2
            iou_total = iou_total + iou_binary(
                gauss > 0.5, target, per_image=False)
            # Margins are pulled towards the detached cluster mean.
            spread = torch.norm(margins[member] - width.detach(), dim=1)
            smooth_total = smooth_total + spread.mean()

        return (mask_total / num_clusters,
                smooth_total / num_clusters,
                voxel_probs,
                iou_total / num_clusters)
class MaskLovaszInterLoss(MaskLovaszHingeLoss):
    '''
    Lovasz hinge Spatial Embeddings Loss with an additional inter-cluster
    term that pushes the centroids of different clusters apart.

    Additional configuration key:
        inter_weight (float, default 1.0): weight of the inter-cluster
            term, both in the optimised loss and in the logged value.
    '''

    def __init__(self, cfg, name='spice_loss'):
        super(MaskLovaszInterLoss, self).__init__(cfg, name)
        self.inter_weight = self.loss_config.get('inter_weight', 1.0)
        # Order of the norm used for centroid distances.
        self.norm = 2

    def regularization(self, cluster_means):
        '''
        Implementation of regularization loss in Discriminative Loss

        Inputs:
            cluster_means (torch.Tensor): output from find_cluster_means

        Returns:
            reg_loss (torch.Tensor): mean norm of the centroids
            (see paper).
        '''
        # The small offset keeps the norm differentiable at the origin.
        return torch.norm(cluster_means + 1e-8, p=self.norm, dim=1).mean()

    def inter_cluster_loss(self, cluster_means, margin=0.2):
        '''
        Implementation of distance loss in Discriminative Loss.

        Inputs:
            cluster_means (torch.Tensor): output from find_cluster_means
            margin (float/int): the magnitude of the margin delta_d in the
            paper. Think of it as the distance between each separate
            clusters in embedding space.

        Returns:
            inter_loss: squared hinge on the pairwise centroid distances,
            averaged over all ordered pairs of distinct clusters (see
            paper). Factor of 2 is included for proper normalization.
            It is the float 0.0 when fewer than two clusters exist.
        '''
        n_clusters = len(cluster_means)
        if n_clusters < 2:
            # Inter-cluster loss is zero if only one instance exists for
            # a semantic label.
            return 0.0
        # All pairwise differences at once instead of a double Python
        # loop. The small offset keeps the norm differentiable when two
        # centroids coincide.
        diffs = cluster_means.unsqueeze(1) - cluster_means.unsqueeze(0) + 1e-8
        dists = torch.norm(diffs, p=self.norm, dim=2)
        hinge = torch.clamp(2.0 * margin - dists, min=0)
        # Drop the diagonal: a cluster is not compared with itself.
        off_diagonal = ~torch.eye(
            n_clusters, dtype=torch.bool, device=cluster_means.device)
        inter_loss = torch.pow(hinge[off_diagonal], 2).sum()
        return inter_loss / float((n_clusters - 1) * n_clusters)

    def get_per_class_probabilities(self, embeddings, margins, labels):
        '''
        Computes binary foreground/background loss.

        Inputs:
            embeddings (torch.Tensor): (N, d) voxel embeddings.
            margins (torch.Tensor): per-voxel gaussian widths; flattened
                to one column for the smoothing term.
            labels (torch.Tensor): (N, ) cluster labels.

        Returns:
            loss: Lovasz hinge loss averaged over clusters plus
                inter_weight times the inter-cluster loss.
            smoothing_loss: margin smoothing loss averaged over clusters.
            inter_loss (float): the unweighted inter-cluster loss.
            probs (torch.Tensor): (N, ) probability of each voxel under its
                own cluster's kernel.
            accuracy (float): binary IoU averaged over clusters.
        '''
        device = embeddings.device
        loss = 0.0
        smoothing_loss = 0.0
        centroids = self.find_cluster_means(embeddings, labels)
        inter_loss = self.inter_cluster_loss(centroids)
        n_clusters = len(centroids)
        cluster_labels = labels.unique(sorted=True)
        probs = torch.zeros(embeddings.shape[0]).float().to(device)
        accuracy = 0.0

        for i, c in enumerate(cluster_labels):
            index = (labels == c)
            mask = torch.zeros(embeddings.shape[0]).to(device)
            mask[index] = 1.0
            sigma = torch.mean(margins[index], dim=0)
            dists = torch.sum(torch.pow(embeddings - centroids[i], 2), dim=1)
            p = torch.clamp(
                torch.exp(-dists / (2 * torch.pow(sigma, 2))), min=0, max=1)
            logits = logit_fn(p, eps=1e-6)
            probs[index] = p[index]
            loss += lovasz_hinge_flat(logits, mask).mean()
            accuracy += float(iou_binary(p > 0.5, mask, per_image=False))
            sigma_detach = sigma.detach()
            smoothing_loss += torch.mean(
                torch.norm(margins[index].view(-1, 1) - sigma_detach, dim=1))

        loss /= n_clusters
        smoothing_loss /= n_clusters
        accuracy /= n_clusters
        # Apply the configured weight to the optimised loss too; it used
        # to affect only the logged value.
        loss += self.inter_weight * inter_loss

        return loss, smoothing_loss, float(inter_loss), probs, accuracy

    def combine_multiclass(self, embeddings, margins,
                           seediness, slabels, clabels):
        '''
        Wrapper function for combining different components of the loss,
        in particular when clustering must be done PER SEMANTIC CLASS.

        NOTE: When there are multiple semantic classes, we compute the DLoss
        by first masking out by each semantic segmentation (ground-truth/
        prediction) and then compute the clustering loss over each masked
        point cloud.

        INPUTS:
            embeddings (torch.Tensor): pixel embeddings
            margins (torch.Tensor): per-pixel gaussian widths
            seediness (torch.Tensor): per-pixel seediness, squeezed along
                dimension 1 before use
            slabels (torch.Tensor): semantic labels
            clabels (torch.Tensor): group/instance/cluster labels

        OUTPUT:
            loss (defaultdict(list)): one entry per processed semantic
                class under the keys 'loss' (tensor), 'mask_loss',
                'seed_loss', 'smoothing_loss', 'inter_loss',
                'mask_loss_<sc>' and 'seed_loss_<sc>' (floats).
            accuracy (defaultdict(float)): 'accuracy_<sc>' for each
                processed semantic class.
        '''
        loss = defaultdict(list)
        accuracy = defaultdict(float)
        semantic_classes = slabels.unique()
        for sc in semantic_classes:
            # Semantic class 4 is never clustered.
            # NOTE(review): presumably a class without instance labels --
            # confirm against the dataset definition.
            if int(sc) == 4:
                continue
            index = (slabels == sc)
            # Only the first value returned by the helper is used.
            clabels_unique, _ = unique_label_torch(clabels[index])
            mask_loss, smoothing_loss, inter_loss, probs, acc = \
                self.get_per_class_probabilities(
                    embeddings[index], margins[index], clabels_unique)
            # The seediness regresses towards the (detached) probabilities.
            prob_truth = probs.detach()
            seed_loss = self.l2loss(prob_truth, seediness[index].squeeze(1))
            # mask_loss already contains the inter-cluster term.
            total_loss = self.embedding_weight * mask_loss \
                + self.seediness_weight * seed_loss \
                + self.smoothing_weight * smoothing_loss
            loss['loss'].append(total_loss)
            loss['mask_loss'].append(
                float(self.embedding_weight * mask_loss))
            loss['seed_loss'].append(
                float(self.seediness_weight * seed_loss))
            loss['smoothing_loss'].append(
                float(self.smoothing_weight * smoothing_loss))
            loss['inter_loss'].append(
                float(self.inter_weight * inter_loss))
            loss['mask_loss_{}'.format(int(sc))].append(float(mask_loss))
            loss['seed_loss_{}'.format(int(sc))].append(float(seed_loss))
            accuracy['accuracy_{}'.format(int(sc))] = acc

        return loss, accuracy
class MaskLovaszInterLoss2(MaskLovaszInterLoss):
    '''
    Lovasz hinge Spatial Embeddings Loss with a log(1 + sigma) penalty on
    the cluster widths and a binary seediness target.

    Additional configuration key:
        seed_threshold (float, default 0.98): voxels whose probability
            exceeds it are the positive seediness targets.
    '''

    def __init__(self, cfg, name='spice_loss'):
        super().__init__(cfg, name)
        self.inter_weight = self.loss_config.get('inter_weight', 1.0)
        self.norm = 2
        self.seed_loss = torch.nn.BCEWithLogitsLoss()
        self.seed_threshold = self.loss_config.get('seed_threshold', 0.98)

    def get_per_class_probabilities(self, embeddings, margins, labels):
        '''
        Computes binary foreground/background loss.

        Inputs:
            embeddings (torch.Tensor): (N, d) voxel embeddings.
            margins (torch.Tensor): per-voxel gaussian widths; the norm
                along dimension 1 is taken, so it must be 2-dimensional.
            labels (torch.Tensor): (N, ) cluster labels.

        Returns:
            (loss, smoothing_loss, probs, accuracy): Lovasz hinge loss plus
            the width penalty, margin smoothing loss and binary IoU, each
            averaged over clusters, and the (N, ) probability of each voxel
            under its own cluster's kernel.
        '''
        n_voxels = embeddings.shape[0]
        target_device = embeddings.device
        centers = self.find_cluster_means(embeddings, labels)
        num_clusters = len(centers)
        voxel_probs = torch.zeros(n_voxels).float().to(target_device)

        hinge_total = 0.0
        smooth_total = 0.0
        width_penalty = 0.0
        iou_total = 0.0
        for center, cid in zip(centers, labels.unique(sorted=True)):
            member = labels == cid
            target = torch.zeros(n_voxels).to(target_device)
            target[member] = 1

            # One scalar width per cluster.
            width = margins[member].mean()
            sq_dist = (embeddings - center).pow(2).sum(dim=1)
            gauss = torch.exp(-sq_dist / (2 * width.pow(2) + 1e-8))
            voxel_probs[member] = gauss[member]

            hinge_total = hinge_total + lovasz_hinge_flat(
                2 * gauss - 1, target)
            iou_total = iou_total + float(
                iou_binary(gauss > 0.5, target, per_image=False))
            spread = torch.norm(margins[member] - width.detach(), dim=1)
            smooth_total = smooth_total + spread.mean()
            width_penalty = width_penalty + torch.log(1.0 + width)

        mask_loss = hinge_total / num_clusters + width_penalty / num_clusters
        return (mask_loss,
                smooth_total / num_clusters,
                voxel_probs,
                iou_total / num_clusters)

    def combine_multiclass(self, embeddings, margins,
                           seediness, slabels, clabels):
        '''
        Combine the loss components separately for every semantic class.

        The clustering loss is evaluated on the voxels of one semantic
        class at a time.

        INPUTS:
            embeddings (torch.Tensor): pixel embeddings
            margins (torch.Tensor): per-pixel gaussian widths
            seediness (torch.Tensor): per-pixel seediness logits, squeezed
                along dimension 1 before use
            slabels (torch.Tensor): semantic labels
            clabels (torch.Tensor): group/instance/cluster labels

        OUTPUT:
            loss (defaultdict(list)): one entry per processed semantic
                class under the keys 'loss' (tensor), 'mask_loss',
                'seed_loss', 'smoothing_loss', 'mask_loss_<sc>' and
                'seed_loss_<sc>' (floats).
            accuracy (defaultdict(float)): 'accuracy_<sc>' for each
                processed semantic class.
        '''
        losses = defaultdict(list)
        accuracies = defaultdict(float)
        for sem_class in slabels.unique():
            class_id = int(sem_class)
            # Semantic class 4 is never clustered.
            if class_id == 4:
                continue
            selected = slabels == sem_class
            mask_loss, smoothing_loss, probs, acc = \
                self.get_per_class_probabilities(
                    embeddings[selected], margins[selected],
                    clabels[selected])
            # Binary target: voxels that are confidently inside a cluster.
            seed_target = (probs > self.seed_threshold).float()
            seed_loss = self.seed_loss(
                seediness[selected].squeeze(1), seed_target)

            weighted_mask = self.embedding_weight * mask_loss
            weighted_seed = self.seediness_weight * seed_loss
            weighted_smooth = self.smoothing_weight * smoothing_loss

            losses['loss'].append(
                weighted_mask + weighted_seed + weighted_smooth)
            losses['mask_loss'].append(float(weighted_mask))
            losses['seed_loss'].append(float(weighted_seed))
            losses['smoothing_loss'].append(float(weighted_smooth))
            losses['mask_loss_{}'.format(class_id)].append(float(mask_loss))
            losses['seed_loss_{}'.format(class_id)].append(float(seed_loss))
            accuracies['accuracy_{}'.format(class_id)] = acc

        return losses, accuracies
class MaskLovaszInterBC(MaskLovaszInterLoss):
    '''
    Lovasz hinge Spatial Embeddings Loss whose inter-cluster term is the
    mean pairwise Bhattacharyya coefficient between clusters.

    Additional configuration keys:
        inter_weight (float, default 1.0): weight of the inter-cluster
            term, both in the optimised loss and in the logged value.
        seed_threshold (float, default 0.98): stored on the instance; not
            used by any method of this class.
    '''

    def __init__(self, cfg, name='spice_loss'):
        super(MaskLovaszInterBC, self).__init__(cfg, name)
        self.inter_weight = self.loss_config.get('inter_weight', 1.0)
        self.norm = 2
        self.seed_loss = torch.nn.BCEWithLogitsLoss()
        self.seed_threshold = self.loss_config.get('seed_threshold', 0.98)

    def inter_cluster_loss(self, embeddings, margins, labels):
        '''
        Mean Bhattacharyya coefficient over all pairs of distinct clusters.

        Inputs:
            embeddings (torch.Tensor): (N, d) voxel embeddings.
            margins (torch.Tensor): per-voxel gaussian widths, concatenated
                to the embeddings along dimension 1.
            labels (torch.Tensor): (N, ) cluster labels.

        Returns:
            The mean of the strictly upper-triangular entries of the
            coefficient matrix, or the int 0 if there is no such entry.
        '''
        # Each cluster is summarised by its mean embedding and mean margin.
        vecs = torch.cat([embeddings, margins], dim=1)
        v = self.find_cluster_means(vecs, labels)
        mat = bhattacharyya_coeff_matrix(v, v)
        # Every unordered pair of distinct clusters exactly once.
        ind = torch.triu_indices(v.shape[0], v.shape[0], offset=1)
        inter = mat[ind[0], ind[1]]
        if inter.shape[0] == 0:
            return 0
        return torch.mean(inter)

    def get_per_class_probabilities(self, embeddings, margins, labels):
        '''
        Computes binary foreground/background loss.

        Inputs:
            embeddings (torch.Tensor): (N, d) voxel embeddings.
            margins (torch.Tensor): per-voxel gaussian widths; the norm
                along dimension 1 is taken, so it must be 2-dimensional.
            labels (torch.Tensor): (N, ) cluster labels.

        Returns:
            loss: Lovasz hinge loss and log(1 + sigma) width penalty, both
                averaged over clusters, plus inter_weight times the
                inter-cluster loss.
            smoothing_loss: margin smoothing loss averaged over clusters.
            inter_loss (float): the unweighted inter-cluster loss.
            probs (torch.Tensor): (N, ) probability of each voxel under its
                own cluster's kernel.
            accuracy (float): binary IoU averaged over clusters.
        '''
        device = embeddings.device
        loss = 0.0
        smoothing_loss = 0.0
        sigma_reg_loss = 0.0
        centroids = self.find_cluster_means(embeddings, labels)
        inter_loss = self.inter_cluster_loss(embeddings, margins, labels)
        n_clusters = len(centroids)
        cluster_labels = labels.unique(sorted=True)
        probs = torch.zeros(embeddings.shape[0]).float().to(device)
        accuracy = 0.0

        for i, c in enumerate(cluster_labels):
            index = (labels == c)
            mask = torch.zeros(embeddings.shape[0]).to(device)
            mask[index] = 1
            # One scalar width per cluster.
            sigma = torch.mean(margins[index])
            dists = torch.sum(torch.pow(embeddings - centroids[i], 2), dim=1)
            p = torch.exp(-dists / (2 * torch.pow(sigma, 2) + 1e-8))
            probs[index] = p[index]
            loss += lovasz_hinge_flat(2 * p - 1, mask)
            accuracy += float(iou_binary(p > 0.5, mask, per_image=False))
            sigma_detach = sigma.detach()
            smoothing_loss += torch.mean(
                torch.norm(margins[index] - sigma_detach, dim=1))
            sigma_reg_loss += torch.log(1.0 + sigma)

        loss /= n_clusters
        smoothing_loss /= n_clusters
        accuracy /= n_clusters
        sigma_reg_loss /= n_clusters
        loss += sigma_reg_loss
        # Apply the configured weight to the optimised loss too; it used
        # to affect only the logged value.
        loss += self.inter_weight * inter_loss

        return loss, smoothing_loss, float(inter_loss), probs, accuracy

    def combine_multiclass(self, embeddings, margins,
                           seediness, slabels, clabels):
        '''
        Wrapper function for combining different components of the loss,
        in particular when clustering must be done PER SEMANTIC CLASS.

        NOTE: When there are multiple semantic classes, we compute the DLoss
        by first masking out by each semantic segmentation (ground-truth/
        prediction) and then compute the clustering loss over each masked
        point cloud.

        INPUTS:
            embeddings (torch.Tensor): pixel embeddings
            margins (torch.Tensor): per-pixel gaussian widths
            seediness (torch.Tensor): per-pixel seediness, squeezed along
                dimension 1 before use
            slabels (torch.Tensor): semantic labels
            clabels (torch.Tensor): group/instance/cluster labels

        OUTPUT:
            loss (defaultdict(list)): one entry per processed semantic
                class under the keys 'loss' (tensor), 'mask_loss',
                'seed_loss', 'smoothing_loss', 'inter_loss',
                'mask_loss_<sc>' and 'seed_loss_<sc>' (floats).
            accuracy (defaultdict(float)): 'accuracy_<sc>' for each
                processed semantic class.
        '''
        loss = defaultdict(list)
        accuracy = defaultdict(float)
        semantic_classes = slabels.unique()
        for sc in semantic_classes:
            # Semantic class 4 is never clustered.
            # NOTE(review): presumably a class without instance labels --
            # confirm against the dataset definition.
            if int(sc) == 4:
                continue
            index = (slabels == sc)
            mask_loss, smoothing_loss, inter_loss, probs, acc = \
                self.get_per_class_probabilities(
                    embeddings[index], margins[index], clabels[index])
            # The seediness regresses towards the (detached) probabilities.
            prob_truth = probs.detach()
            seed_loss = self.l2loss(prob_truth, seediness[index].squeeze(1))
            # mask_loss already contains the inter-cluster term.
            total_loss = self.embedding_weight * mask_loss \
                + self.seediness_weight * seed_loss \
                + self.smoothing_weight * smoothing_loss
            loss['loss'].append(total_loss)
            loss['mask_loss'].append(
                float(self.embedding_weight * mask_loss))
            loss['seed_loss'].append(
                float(self.seediness_weight * seed_loss))
            loss['smoothing_loss'].append(
                float(self.smoothing_weight * smoothing_loss))
            loss['inter_loss'].append(
                float(self.inter_weight * inter_loss))
            loss['mask_loss_{}'.format(int(sc))].append(float(mask_loss))
            loss['seed_loss_{}'.format(int(sc))].append(float(seed_loss))
            accuracy['accuracy_{}'.format(int(sc))] = acc

        return loss, accuracy
class MaskFocalLoss(MaskBCELoss2):
    '''
    Spatial Embeddings Loss with trainable center of attention.

    The mask term uses FocalLoss (applied to probabilities, not logits)
    in place of the parent's BCE loss.
    '''

    def __init__(self, cfg, name='spice_loss'):
        super(MaskFocalLoss, self).__init__(cfg, name)
        # Replaces the BCE loss instantiated by the parent constructor.
        self.bceloss = FocalLoss(logits=False)

    def get_per_class_probabilities(self, embeddings, margins, labels, coords):
        '''
        Computes binary foreground/background loss.

        Inputs:
            embeddings (torch.Tensor): (N, d) voxel embeddings.
            margins (torch.Tensor): per-voxel gaussian widths.
            labels (torch.Tensor): (N, ) cluster labels.
            coords: accepted for interface compatibility; not used.

        Returns:
            (loss, smoothing_loss, probs, acc): focal mask loss, margin
            smoothing loss and binary IoU, each averaged over clusters, and
            the (N, ) probability of each voxel under its own cluster's
            kernel.
        '''
        device = embeddings.device
        loss = 0.0
        smoothing_loss = 0.0
        centroids = self.find_cluster_means(embeddings, labels)
        n_clusters = len(centroids)
        cluster_labels = labels.unique(sorted=True)
        probs = torch.zeros(embeddings.shape[0]).float().to(device)
        acc = 0.0

        for i, c in enumerate(cluster_labels):
            index = (labels == c)
            mask = torch.zeros(embeddings.shape[0]).to(device)
            mask[index] = 1.0
            sigma = torch.mean(margins[index], dim=0)
            dists = torch.sum(torch.pow(embeddings - centroids[i], 2), dim=1)
            p = torch.clamp(
                torch.exp(-dists / (2 * torch.pow(sigma, 2))), min=0, max=1)
            probs[index] = p[index]
            loss += self.bceloss(p, mask)
            acc += iou_binary(p > 0.5, mask, per_image=False)
            # Margins are pulled towards the detached cluster mean.
            sigma_detach = sigma.detach()
            smoothing_loss += torch.sum(
                torch.pow(margins[index] - sigma_detach, 2))

        loss /= n_clusters
        smoothing_loss /= n_clusters
        acc /= n_clusters

        return loss, smoothing_loss, probs, acc
class MultiVariateLovasz(MaskLovaszInterLoss):
    '''
    Lovasz hinge Spatial Embeddings Loss in which the probability of each
    cluster is given by multivariate_kernel, combined with the
    inter-cluster and centroid regularization terms of the parent class.
    '''

    def __init__(self, cfg, name='spice_loss'):
        super(MultiVariateLovasz, self).__init__(cfg, name)

    def get_per_class_probabilities(self, embeddings, margins, labels,
                                    coords=None):
        '''
        Computes binary foreground/background loss.

        Inputs:
            embeddings (torch.Tensor): (N, d) voxel embeddings.
            margins (torch.Tensor): per-voxel kernel parameters; their
                per-cluster mean is passed to multivariate_kernel.
            labels (torch.Tensor): (N, ) cluster labels.
            coords: optional and unused. It defaults to None because the
                inherited combine_multiclass calls this method with three
                arguments only.

        Returns:
            loss: Lovasz hinge loss averaged over clusters, plus
                inter_weight times the inter-cluster loss, plus the
                regularization loss divided by the number of clusters.
            smoothing_loss: margin smoothing loss averaged over clusters.
            inter_loss: the unweighted inter-cluster loss.
            probs (torch.Tensor): (N, ) probability of each voxel under its
                own cluster's kernel.
            accuracy: binary IoU averaged over clusters.

            With fewer than two voxels every scalar is 0.0 and probs is all
            zeros.
        '''
        device = embeddings.device
        probs = torch.zeros(embeddings.shape[0]).float().to(device)
        # Checked before any centroid is computed. probs stays a tensor
        # because the caller detaches it.
        if embeddings.shape[0] < 2:
            return 0.0, 0.0, 0.0, probs, 0.0

        loss = 0.0
        smoothing_loss = 0.0
        centroids = self.find_cluster_means(embeddings, labels)
        inter_loss = self.inter_cluster_loss(centroids)
        reg_loss = self.regularization(centroids)
        n_clusters = len(centroids)
        cluster_labels = labels.unique(sorted=True)
        accuracy = 0.0

        for i, c in enumerate(cluster_labels):
            index = (labels == c)
            mask = torch.zeros(embeddings.shape[0]).to(device)
            mask[index] = 1
            sigma = torch.mean(margins[index], dim=0)
            f = multivariate_kernel(centroids[i], sigma)
            p = f(embeddings)
            probs[index] = p[index]
            # Map the probability from [0, 1] to a score in [-1, 1].
            loss += lovasz_hinge_flat(2 * p - 1, mask)
            accuracy += iou_binary(p > 0.5, mask, per_image=False)
            sigma_detach = sigma.detach()
            smoothing_loss += torch.sum(
                torch.pow(margins[index] - sigma_detach, 2))

        loss /= n_clusters
        smoothing_loss /= n_clusters
        accuracy /= n_clusters
        # Apply the configured weight to the optimised loss too; it used
        # to affect only the logged value.
        loss += self.inter_weight * inter_loss
        # NOTE(review): regularization() already averages over clusters,
        # so this divides by n_clusters a second time -- kept as is,
        # confirm it is intended.
        loss += reg_loss / n_clusters

        return loss, smoothing_loss, inter_loss, probs, accuracy