# Source code for mlreco.models.layers.cluster_cnn.losses.misc

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.nn import fps, knn

from .lovasz import StableBCELoss, lovasz_hinge, lovasz_softmax_flat

# Collection of miscellaneous loss functions not yet implemented in PyTorch.

def logit_fn(input, eps=1e-6):
    """Return the logit (inverse sigmoid) of ``input``.

    The input is clamped to ``[eps, 1 - eps]`` first so that the result is
    always finite, even for exact 0 or 1 probabilities.
    """
    prob = input.clamp(min=eps, max=1 - eps)
    odds = prob / (1 - prob)
    return torch.log(odds)
def unique_label_torch(label):
    """Map ``label`` onto contiguous ids.

    Returns:
        inverse: for each entry of ``label``, the index of its value among
            the sorted distinct values.
        counts: number of occurrences of each distinct value.
    """
    unique_out = torch.unique(label, return_inverse=True, return_counts=True)
    inverse, counts = unique_out[1], unique_out[2]
    return inverse, counts
def iou_batch(pred: torch.BoolTensor, labels: torch.BoolTensor, eps=0.0):
    '''
    Column-wise intersection-over-union, averaged over columns.

    Inputs:
        pred: N x C boolean prediction mask.
        labels: N x C boolean (one-hot) target mask.
        eps: smoothing term added to numerator and denominator. With the
            default of 0.0 a column whose union is empty gives 0/0 = NaN,
            which propagates into the mean; pass eps > 0 to avoid that.
    Returns:
        Scalar tensor: mean IoU over the C columns.
    '''
    inter = (pred & labels).float().sum(0)
    union = (pred | labels).float().sum(0)
    per_column = (inter + eps) / (union + eps)
    return per_column.mean()
class VectorEstimationLoss(nn.Module):
    """Align predicted per-point direction vectors with local PCA axes.

    Anchors are sampled with farthest point sampling, the k nearest points of
    each anchor form a neighbourhood, and the leading principal direction of
    that neighbourhood is the (sign-agnostic) target for the predicted
    vectors of the points inside it.
    """

    def __init__(self, cfg, name='vector_estimation_loss'):
        super(VectorEstimationLoss, self).__init__()
        self.loss_config = cfg[name]
        self.fps_ratio = self.loss_config.get('fps_ratio', 0.1)
        self.k = self.loss_config.get('k', 20)
        self.D = 3  # spatial dimension of ``pos``
        self.cos = nn.CosineSimilarity(dim=1)
        self.eps = self.loss_config.get('eps', 1e-6)

    def compute_loss_single_graph(self, vec_pred, pos):
        '''
        INPUTS:
            - vec_pred (N x 3 Tensor): predicted direction vector per point
              (``forward`` passes columns 3:6 of the node features).
            - pos (N x 3 Tensor): spatial coordinates of the points.
        RETURNS:
            dict with
            - 'vec_loss': mean of -log(|cos| + eps) over all
              (anchor, neighbour) pairs.
            - 'abs_cos': |cosine| between each neighbour's predicted vector
              and its anchor's leading principal direction.
        '''
        with torch.no_grad():
            anchors = fps(pos, ratio=self.fps_ratio)
            anchors_pos = pos[anchors]
            # knn(x, y) returns [2, num_y * k]: row 0 indexes the anchors (y),
            # row 1 indexes the points of ``pos`` (x), grouped by anchor.
            index = knn(pos, anchors_pos, k=self.k)
            nbhds = pos[index[1, :]].view(-1, self.k, self.D)
            _, _, V = torch.pca_lowrank(nbhds)
            vecs = V[:, :, 0]  # leading principal direction per anchor
        # Compare each neighbour point's prediction (row 1 ids are point ids)
        # with its anchor's axis (row 0 ids are anchor ids).
        abs_cos = torch.abs(self.cos(vec_pred[index[1, :]], vecs[index[0, :]]))
        vec_loss = -torch.log(abs_cos + self.eps).mean()
        return {
            'vec_loss': vec_loss,
            'abs_cos': abs_cos,
        }

    def forward(self, graph):
        '''Average ``vec_loss`` over the graphs contained in a batched ``graph``.'''
        num_graphs = torch.unique(graph.batch).shape[0]
        losses = []
        for i in range(num_graphs):
            subgraph = graph.get_example(i)
            res = self.compute_loss_single_graph(
                subgraph.x[:, 3:6], subgraph.pos
            )
            losses.append(res['vec_loss'])
        return sum(losses) / len(losses)
class BinaryLogDiceLoss(torch.nn.Module):
    """Negative log of the soft Dice coefficient for binary segmentation."""

    def __init__(self, gamma=1):
        super(BinaryLogDiceLoss, self).__init__()
        # NOTE(review): ``gamma`` is accepted for interface compatibility but
        # is not used by this loss.

    def forward(self, logits, targets, eps=1e-6):
        '''
        Inputs:
            logits: raw scores, same shape as ``targets``.
            targets: boolean (or 0/1) mask of positive entries.
            eps: smoothing for the ratio and clamp bound for the Dice value.
        Returns:
            Scalar tensor -log(dice), where
            dice = 2 * sum(p over positives) / (sum(p) + #positives).
        '''
        # Soft probabilities keep the loss differentiable w.r.t. ``logits``.
        p = torch.sigmoid(logits)
        mask = targets.bool()
        num = 2.0 * p[mask].sum()
        denom = p.sum() + mask.sum()
        dice = torch.clamp((num + eps) / (denom + eps), min=eps, max=1 - eps)
        return -torch.log(dice)
class IoUScore(torch.nn.Module):
    """Intersection-over-union between two binary (0/1) tensors."""

    def __init__(self):
        super(IoUScore, self).__init__()

    def forward(self, y_pred, y_true):
        '''
        Returns a Python float: |pred & true| / |pred | true|, counting only
        entries equal to 1. Returns the integer 0 when both are empty of 1s.
        '''
        pred_on = y_pred.long() == 1
        true_on = y_true.long() == 1
        union = pred_on | true_on
        if not union.any():
            return 0
        overlap = pred_on & true_on
        return float(overlap.sum()) / float(union.sum())
class BinaryCELogDiceLoss(torch.nn.Module):
    """Weighted sum of binary cross-entropy and negative log soft Dice."""

    def __init__(self, gamma=0.3, w_ce=0.2, w_dice=0.8):
        super(BinaryCELogDiceLoss, self).__init__()
        self.ce = F.binary_cross_entropy_with_logits
        self.gamma = gamma  # stored, but not used by forward
        self.w_ce = w_ce
        self.w_dice = w_dice

    def forward(self, logits, targets, weight=None, eps=0.001, reduction='none'):
        '''
        Inputs:
            logits: raw scores.
            targets: float targets; entries > 0.5 count as positives for Dice.
            weight: optional per-element BCE weight.
            eps: Dice smoothing and clamp bound.
            reduction: passed to BCE; the result is then averaged.
        Returns:
            w_ce * mean(BCE) + w_dice * (-log(dice)), with
            dice = 2 * sum(p over positives) / (sum(p^2) + sum(targets^2)).
        '''
        bce = self.ce(logits, targets, weight=weight, reduction=reduction).mean()

        probs = torch.sigmoid(logits)
        overlap = 2.0 * probs[targets > 0.5].sum()
        norm = (probs**2).sum() + (targets**2).sum()
        dice = torch.clamp((overlap + eps) / (norm + eps), min=eps, max=1 - eps)
        dice_loss = -torch.log(dice)

        return self.w_ce * bce + self.w_dice * dice_loss
class MincutLoss(BinaryCELogDiceLoss):
    """BinaryCELogDiceLoss plus a penalty on under-confident positives.

    The extra term is sum(1 - sigmoid(logit)) over entries whose target is
    > 0.5, scaled by ``mincut_weight``.
    """

    def __init__(self, mincut_weight=1.0, **kwargs):
        super(MincutLoss, self).__init__(**kwargs)
        self.w_mc = mincut_weight

    def forward(self, logits, targets, weight=None, eps=0.001, reduction='none'):
        '''
        Same inputs as BinaryCELogDiceLoss.forward.

        Returns:
            w_ce * BCE + w_dice * (-log dice) + w_mc * mincut, where the first
            two terms are exactly the parent's loss.
        '''
        # Reuse the parent's CE + log-Dice computation instead of duplicating it.
        base = super(MincutLoss, self).forward(
            logits, targets, weight=weight, eps=eps, reduction=reduction)
        p = torch.sigmoid(logits)
        mincut_loss = (1.0 - p[targets > 0.5]).sum()
        return base + self.w_mc * mincut_loss
class LovaszHingeLoss(torch.nn.modules.loss._Loss):
    """Lovasz hinge loss applied independently to each column (cluster)."""

    def __init__(self, reduction='none'):
        super(LovaszHingeLoss, self).__init__(reduction=reduction)

    def forward(self, logits, targets):
        '''
        logits, targets: presumably (N, C) with one column per cluster; C is
        read from targets.shape[1]. Each column is reshaped to (C, 1, N)
        before being passed to ``lovasz_hinge``. ``self.reduction`` is not
        used here.
        '''
        n_clusters = targets.shape[1]
        per_cluster_logits = logits.T.view(n_clusters, 1, -1)
        per_cluster_targets = targets.T.view(n_clusters, 1, -1)
        return lovasz_hinge(per_cluster_logits, per_cluster_targets)
class LovaszSoftmaxWithLogitsLoss(torch.nn.modules.loss._Loss):
    """Lovasz-softmax loss computed from raw class scores."""

    def __init__(self, reduction='none'):
        super(LovaszSoftmaxWithLogitsLoss, self).__init__(reduction=reduction)
        # Softmax over dim 1 turns the logits into class probabilities.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, logits, targets):
        '''Softmax ``logits`` over dim 1, then apply ``lovasz_softmax_flat``.'''
        return lovasz_softmax_flat(self.softmax(logits), targets)
def find_cluster_means(features, labels):
    '''
    For a given image, compute the centroids mu_c for each cluster label in
    the embedding space.

    Inputs:
        features (torch.Tensor): the pixel embeddings, shape=(N, d) where N
            is the number of pixels and d is the embedding space dimension.
        labels (torch.Tensor): non-negative integer group labels, shape=(N, )

    Returns:
        cluster_means (torch.Tensor): (n_c, d) tensor where n_c is the number
            of distinct labels, ordered by label value. Label values that
            never occur are skipped. Same dtype and device as ``features``.
    '''
    counts = torch.bincount(labels)
    present = counts > 0
    # Absent labels would divide by zero; their rows are dropped below anyway.
    counts = counts.clamp(min=1)
    # Accumulate in the features' own dtype/device; index_add requires the
    # destination and source dtypes to match.
    sums = torch.zeros(counts.shape[0], features.shape[1],
                       dtype=features.dtype, device=features.device)
    sums = sums.index_add(0, labels, features)
    centroids = sums / counts.view(-1, 1)
    return centroids[present]
def intra_cluster_loss(features, cluster_means, labels, margin=1.0):
    '''
    Hinged pull loss between each embedding and its own cluster centroid.

    Inputs:
        features (N, d): embeddings.
        cluster_means (C, d): centroids; row ``labels[i]`` is the centroid of
            point i (labels assumed contiguous 0..C-1 — TODO confirm).
        labels (N,): integer cluster id per point.
        margin: distances below this are not penalised.
    Returns:
        Mean over clusters of the per-cluster mean of
        clamp(||x - mu|| - margin, 0)^2, or 0.0 when there are no points.
    '''
    from torch_scatter import scatter_mean
    labels = labels.view(-1)
    if labels.numel() == 0:
        return 0.0
    # Only each point's own centroid is needed, so gather it directly instead
    # of building the full N x C distance matrix.
    dist = torch.norm(features - cluster_means[labels], dim=-1)
    l = torch.clamp(dist - margin, min=0)**2
    return torch.mean(scatter_mean(l, labels))
def inter_cluster_loss(cluster_means, margin=0.2):
    '''
    Hinged push-apart loss between every pair of distinct centroids.

    Inputs:
        cluster_means (C, d): centroids.
        margin: pairs whose squared distance is below 2 * margin are
            penalised (note: squared distance is compared to 2 * margin).
    Returns:
        Mean over pairs i < j of clamp(2 * margin - ||mu_i - mu_j||^2, 0)^2,
        or 0.0 when fewer than two centroids are given.
    '''
    n = cluster_means.shape[0]
    if n < 2:
        # With a single instance there is nothing to push apart.
        return 0.0
    rows, cols = torch.triu_indices(n, n, 1)
    pair_sq = squared_distances(cluster_means, cluster_means)[rows, cols]
    hinge = torch.clamp(2.0 * margin - pair_sq, min=0)
    return torch.pow(hinge, 2).mean()
def regularization_loss(cluster_means):
    '''Mean Euclidean norm of the centroid rows of ``cluster_means`` (C, d).'''
    row_norms = cluster_means.norm(dim=1)
    return row_norms.mean()
def margin_smoothing_loss(sigma, sigma_means, labels, margin=0):
    '''
    Penalise deviation of each point's sigma from its cluster's mean sigma.

    Inputs:
        sigma (N,): per-point margin values.
        sigma_means (C,): per-cluster mean sigma, indexed by ``labels``.
        labels (N,): integer cluster id per point.
        margin: deviations below this are not penalised.
    Returns:
        Mean over clusters of the per-cluster mean of
        sqrt(clamp(|sigma - sigma_mean| - margin, 0)^2 + 1e-6).
    '''
    from torch_scatter import scatter_mean
    deviation = torch.abs(sigma[:, None] - sigma_means[None, :])
    # The 1e-6 keeps sqrt differentiable at zero deviation.
    smooth = torch.sqrt(torch.clamp(deviation - margin, min=0)**2 + 1e-6)
    per_point = torch.gather(smooth, 1, labels.view(-1, 1)).view(-1)
    return torch.mean(scatter_mean(per_point, labels))
def get_probs(embeddings, margins, labels, eps=1e-6):
    '''
    Gaussian-kernel cluster membership loss.

    Inputs:
        embeddings (N, d): point embeddings.
        margins: per-point spread, squeezed to (N,) before use.
        labels (N,): integer cluster ids; presumably contiguous 0..C-1 so
            that centroids, per-cluster sigma and one-hot targets line up —
            TODO confirm against callers.
        eps: clamp bound for the probabilities / logits.
    Returns:
        loss: BCE between kernel logits and one-hot targets.
        smoothing_loss: margin_smoothing_loss of the margins.
        p (N,): kernel probability of each point for its own cluster.
        acc: iou_batch of (logits > 0) against the targets.
    '''
    from torch_scatter import scatter_mean
    device = embeddings.device
    centroids = find_cluster_means(embeddings, labels)
    sigma = scatter_mean(margins.squeeze(), labels)

    # Spatial term: Gaussian kernel between every point and every centroid.
    sqdists = ((embeddings[:, None, :] - centroids[None, :, :])**2).sum(-1)
    p = sqdists / (2.0 * sigma.view(1, -1)**2)
    p = torch.clamp(torch.exp(-p), min=eps, max=1 - eps)
    logits = logit_fn(p, eps=eps)

    num_clusters = labels.unique().shape[0]
    eye = torch.eye(num_clusters, dtype=torch.float32, device=device)
    targets = eye[labels]
    loss_tensor = nn.BCEWithLogitsLoss(reduction='none')(logits, targets)
    loss = loss_tensor.mean(dim=0).mean()
    with torch.no_grad():
        acc = iou_batch(logits > 0, targets.bool())
    smoothing_loss = margin_smoothing_loss(margins.squeeze(), sigma.detach(), labels, margin=0)
    p = torch.gather(p, 1, labels.view(-1, 1))
    return loss, smoothing_loss, p.squeeze(), acc
def multivariate_kernel(centroid, log_sigma, Lprime, eps=1e-8):
    '''
    Build a 3-D kernel exp(-(x - c)^T cov^{-1} (x - c)) around ``centroid``.

    The covariance is ``cov = L @ L.T`` with ``L`` lower triangular: its three
    strictly-lower entries are ``Lprime`` and its diagonal is
    ``exp(log_sigma) + eps``.

    Returns:
        f(x): maps an (N, 3) tensor to the kernel value per row (squeezed).
    '''
    def f(x):
        N = x.shape[0]
        # Build L with the parameters' dtype/device so the kernel works on GPU
        # and in double precision (index assignment needs matching dtypes).
        L = torch.zeros(3, 3, dtype=log_sigma.dtype, device=log_sigma.device)
        tril_indices = torch.tril_indices(row=3, col=3, offset=-1, device=log_sigma.device)
        L[tril_indices[0], tril_indices[1]] = Lprime
        sigma = torch.exp(log_sigma) + eps
        L = L + torch.diag(sigma)
        cov = torch.matmul(L, L.T)
        diff = x - centroid
        dist = torch.matmul(diff, torch.inverse(cov))
        dist = torch.bmm(dist.view(N, 1, -1), diff.view(N, -1, 1)).squeeze()
        return torch.exp(-dist)
    return f
def bhattacharyya_distance_matrix(v1, v2, eps=1e-8):
    '''
    Pairwise Bhattacharyya distance between isotropic Gaussians.

    Inputs:
        v1 (n1, 4), v2 (n2, 4): each row is (x, y, z, s) — a mean in the
            first three columns and a scalar spread s in the fourth.
        eps: guards the variance divisions.
    Returns:
        (n1, n2) tensor with
        D[i, j] = 1/4 * log(1/4 * (s1_i^2 / s2_j^2 + s2_j^2 / s1_i^2 + 2))
                + 1/4 * ||x1_i - x2_j||^2 / (s1_i^2 + s2_j^2).
        The distance term comes from squared_distances, which computes in
        float64.
    '''
    x1, s1 = v1[:, :3], v1[:, 3].view(-1)
    # Spreads of the second set come from v2, not v1.
    x2, s2 = v2[:, :3], v2[:, 3].view(-1)
    var1, var2 = s1**2, s2**2
    g1 = torch.ger(var1, 1.0 / (var2 + eps))  # s1_i^2 / s2_j^2
    g2 = torch.ger(1.0 / (var1 + eps), var2)  # s2_j^2 / s1_i^2 (valid for n1 != n2)
    dist = squared_distances(x1.contiguous(), x2.contiguous())
    var_sum = eps + var1.unsqueeze(1) + var2.unsqueeze(0)
    # The mean-separation term is divided by the summed variances.
    out = 0.25 * torch.log(0.25 * (g1 + g2 + 2)) + 0.25 * dist / var_sum
    return out
def squared_distances(v1, v2):
    '''
    Pairwise squared Euclidean distances, computed in float64.

    Inputs:
        v1 (n1, d), v2 (n2, d).
    Returns:
        (n1, n2) float64 tensor with out[i, j] = ||v2_j - v1_i||^2.
    '''
    shape = (v1.size(0), v2.size(0), v1.size(1))
    left = v1[:, None, :].expand(shape).double()
    right = v2[None, :, :].expand(shape).double()
    diff = right - left
    return (diff ** 2).sum(2)
def bhattacharyya_coeff_matrix(v1, v2, eps=1e-6):
    '''
    Pairwise Bhattacharyya coefficient BC = exp(-D_B) between isotropic
    Gaussians.

    Inputs:
        v1 (n1, 4), v2 (n2, 4): each row is (x, y, z, s) — a mean in the
            first three columns and a scalar spread s in the fourth.
        eps: guards the variance divisions.
    Returns:
        (n1, n2) tensor in (0, 1]; 1 for identical Gaussians, where
        D_B[i, j] = 1/4 * log(1/4 * (s1_i^2 / s2_j^2 + s2_j^2 / s1_i^2 + 2))
                  + 1/4 * ||x1_i - x2_j||^2 / (s1_i^2 + s2_j^2).
    '''
    x1, s1 = v1[:, :3], v1[:, 3].view(-1)
    x2, s2 = v2[:, :3], v2[:, 3].view(-1)
    var1, var2 = s1**2, s2**2
    g1 = torch.ger(var1, 1.0 / (var2 + eps))  # s1_i^2 / s2_j^2
    g2 = torch.ger(1.0 / (var1 + eps), var2)  # s2_j^2 / s1_i^2
    # The Bhattacharyya formula needs the *squared* distance between means.
    dist = ((x1[:, None, :] - x2[None, :, :])**2).sum(-1)
    var_sum = eps + var1.unsqueeze(1) + var2.unsqueeze(0)
    out = 0.25 * torch.log(0.25 * (g1 + g2 + 2)) + 0.25 * dist / var_sum
    return torch.exp(-out)
def get_graphspice_logits(sp_emb, ft_emb, cov, groups, sp_centroids,
                          ft_centroids, eps=0.001, compute_accuracy=True):
    '''
    Per-point, per-cluster logits from a joint spatial/feature kernel.

    Inputs:
        sp_emb (N, d_sp), ft_emb (N, d_ft): spatial and feature embeddings.
        cov (N, k): per-point spread predictions, averaged per group with
            find_cluster_means. Column 0 scales the spatial term; column 1,
            when present, scales the feature term.
        groups (N,): cluster id per point; presumably contiguous 0..C-1 so
            the one-hot targets line up with centroid rows — TODO confirm.
        sp_centroids (C, d_sp), ft_centroids (C, d_ft): cluster centroids.
        eps: lower clamp on the per-cluster spreads.
        compute_accuracy: if True, also return iou_batch(logits > 0, targets).
    Returns:
        logits (N, C), acc (scalar tensor or None), targets (N, C) one-hot.
    '''
    device = sp_emb.device
    cov_means = find_cluster_means(cov, groups)
    # NOTE(review): the feature spread presumably lives in column 1; fall
    # back to column 0 when cov provides a single column.
    ft_col = 1 if cov_means.shape[1] > 1 else 0

    # Compute spatial term
    sp_cov = torch.clamp(cov_means[:, 0][None, :], min=eps)
    sp_sqdists = ((sp_emb[:, None, :] - sp_centroids[None, :, :])**2).sum(-1) / (sp_cov**2)

    # Compute feature term
    ft_cov = torch.clamp(cov_means[:, ft_col][None, :], min=eps)
    ft_sqdists = ((ft_emb[:, None, :] - ft_centroids[None, :, :])**2).sum(-1) / (ft_cov**2)

    # Compute joint kernel score
    pvec = torch.exp(-sp_sqdists - ft_sqdists)
    logits = logit_fn(pvec)

    eye = torch.eye(len(groups.unique()), dtype=torch.float32, device=device)
    targets = eye[groups]
    acc = None
    if compute_accuracy:
        acc = iou_batch(logits > 0, targets.bool())
    return logits, acc, targets
class FocalLoss(nn.Module):
    '''
    Original Paper: https://arxiv.org/abs/1708.02002
    Implementation: https://www.kaggle.com/c/tgs-salt-identification-challenge/discussion/65938
    '''

    def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        # True: inputs are raw scores; False: inputs go to StableBCELoss.
        self.logits = logits
        self.reduce = reduce  # True: return the mean; False: per-element loss
        self.stable_bce = StableBCELoss()

    def forward(self, inputs, targets):
        '''Return alpha * (1 - pt)^gamma * BCE, averaged if ``self.reduce``.'''
        if not self.logits:
            bce = self.stable_bce(inputs, targets, reduction='none')
        else:
            bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        # For 0/1 targets, exp(-BCE) is the predicted probability of the true class.
        pt = torch.exp(-bce)
        focal = self.alpha * (1 - pt)**self.gamma * bce
        return torch.mean(focal) if self.reduce else focal
class WeightedFocalLoss(FocalLoss):
    """FocalLoss with positive entries re-weighted by class imbalance.

    Positives get weight #negatives / (1 + #positives); negatives weight 1.
    """

    def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
        super(WeightedFocalLoss, self).__init__(alpha=alpha, gamma=gamma, logits=logits, reduce=reduce)

    def forward(self, inputs, targets):
        '''
        Inputs:
            inputs: scores (logits if ``self.logits``), presumably 1-D (N,)
                since the weight vector has length inputs.shape[0].
            targets: 0/1 targets, same shape as inputs.
        Returns:
            Weighted focal loss; mean if ``self.reduce`` else per element.
        '''
        with torch.no_grad():
            pos_weight = torch.sum(targets == 0) / (1.0 + torch.sum(targets == 1))
            # Allocate on the inputs' device rather than assuming CUDA.
            weight = torch.ones(inputs.shape[0], device=inputs.device)
            weight[targets == 1] = pos_weight
        if self.logits:
            BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        else:
            BCE_loss = self.stable_bce(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1 - pt)**self.gamma * BCE_loss
        F_loss = torch.mul(F_loss, weight)
        if self.reduce:
            return torch.mean(F_loss)
        return F_loss