Source code for mlreco.models.layers.cluster_cnn.losses.gs_embeddings

import torch
import torch.nn as nn
import torch.nn.functional as F

from .misc import *
from collections import defaultdict
from torch_scatter import scatter_mean

[docs]class WeightedEdgeLoss(nn.Module):
[docs] def __init__(self, loss_type='BCE', reduction='mean', invert=False): super(WeightedEdgeLoss, self).__init__() self.reduction = reduction if loss_type == 'BCE': self.loss_fn = F.binary_cross_entropy_with_logits elif loss_type == 'LogDice': self.loss_fn = BinaryCELogDiceLoss() elif loss_type == 'MinCut': self.loss_fn = MincutLoss() else: self.loss_fn = F.binary_cross_entropy_with_logits self.invert = invert
[docs] def forward(self, logits, targets): if self.invert: y = (targets < 0.5).float() else: y = targets #(targets > 0.5).float() device = logits.device weight = torch.ones(targets.shape[0]).to(device) # The crucial error are the false positives, as these will # lead to overclustering. negatives_index = (targets < 0.5) negatives = float(torch.sum(negatives_index)) positives = float(torch.sum(targets > 0.5)) # w = positives / negatives weight[negatives_index] = 1.0 # with torch.no_grad(): # num_pos = torch.sum(y).item() # num_edges = y.shape[0] # w = 1.0 / (1.0 - float(num_pos) / num_edges) # weight[~y.bool()] = w # print(logits.shape, logits.max()) loss = self.loss_fn(logits, y.float(), weight=weight) return loss
[docs]def compute_edge_weight(sp_emb: torch.Tensor, ft_emb: torch.Tensor, cov: torch.Tensor, edge_indices: torch.Tensor, occ=None, eps=0.001): device = sp_emb.device ui, vi = edge_indices[0, :], edge_indices[1, :] sp_cov_i = cov[:, 0][ui] sp_cov_j = cov[:, 0][vi] sp_i = ((sp_emb[ui] - sp_emb[vi])**2).sum(dim=1) / (sp_cov_i**2 + eps) sp_j = ((sp_emb[ui] - sp_emb[vi])**2).sum(dim=1) / (sp_cov_j**2 + eps) ft_cov_i = cov[:, 1][ui] ft_cov_j = cov[:, 1][vi] ft_i = ((ft_emb[ui] - ft_emb[vi])**2).sum(dim=1) / (ft_cov_i**2 + eps) ft_j = ((ft_emb[ui] - ft_emb[vi])**2).sum(dim=1) / (ft_cov_j**2 + eps) p_ij = torch.exp(-sp_i-ft_i) p_ji = torch.exp(-sp_j-ft_j) pvec = torch.clamp(p_ij + p_ji - p_ij * p_ji, min=0, max=1) # pvec = torch.exp(- sp - ft) if occ is not None: r1 = occ[edge_indices[0, :]] r2 = occ[edge_indices[1, :]] r = torch.max((r2 + eps) / (r1 + eps), (r1 + eps) / (r2 + eps)) pvec = pvec / r return pvec
[docs]class GraphSPICEEmbeddingLoss(nn.Module): ''' Loss function for Sparse Spatial Embeddings Model, with fixed centroids and symmetric gaussian kernels. '''
[docs] def __init__(self, cfg, name='graph_spice_loss'): super(GraphSPICEEmbeddingLoss, self).__init__() self.loss_config = cfg #[name] self.batch_column = self.loss_config.get('batch_column', 0) self.ft_interloss = self.loss_config.get('ft_interloss_margin', 1.5) self.sp_interloss = self.loss_config.get('sp_interloss_margin', 0.2) self.ft_intraloss = self.loss_config.get('ft_intraloss_margin', 1.0) self.sp_intraloss = self.loss_config.get('sp_intraloss_margin', 0.0) self.eps = self.loss_config.get('eps', 0.001) self.ft_loss_params = self.loss_config.get( 'ft_loss_params', dict(inter=1.0, intra=1.0, reg=0.1)) self.sp_loss_params = self.loss_config.get( 'sp_loss_params', dict(inter=1.0, intra=1.0)) self.kernel_lossfn_name = self.loss_config.get('kernel_lossfn', 'BCE') if self.kernel_lossfn_name == 'BCE': self.kernel_lossfn = nn.BCEWithLogitsLoss(reduction='mean') elif self.kernel_lossfn_name == 'lovasz_hinge': self.kernel_lossfn = LovaszHingeLoss(reduction='none') else: self.kernel_lossfn = nn.BCEWithLogitsLoss(reduction='none') self.seg_lossfn_name = self.loss_config.get('seg_lossfn', 'CE') if self.seg_lossfn_name == 'CE': self.seg_loss_fn = nn.CrossEntropyLoss(reduction='mean') elif self.seg_lossfn_name == 'lovasz_softmax': self.seg_loss_fn = LovaszSoftmaxWithLogitsLoss(reduction='mean') else: self.seg_loss_fn = nn.CrossEntropyLoss(reduction='mean')
[docs] def feature_embedding_loss(self, ft_emb, groups, ft_centroids): ''' Compute discriminative feature embedding loss. INPUTS: - ft_emb (N x F) - groups (N) - ft_centroids (N_c X F) ''' intercluster_loss = inter_cluster_loss(ft_centroids, margin=self.ft_interloss) intracluster_loss = intra_cluster_loss(ft_emb, ft_centroids, groups, margin=self.ft_intraloss) reg_loss = torch.mean(torch.norm(ft_centroids, dim=1)) out = {} out['intercluster_loss'] = float(intercluster_loss) out['intracluster_loss'] = float(intracluster_loss) out['regularization_loss'] = float(reg_loss) out['loss'] = self.ft_loss_params['inter'] * intercluster_loss + \ self.ft_loss_params['intra'] * intracluster_loss + \ self.ft_loss_params['reg'] * reg_loss return out
[docs] def spatial_embedding_loss(self, sp_emb, groups, sp_centroids): ''' Compute spatial centroid regression loss. INPUTS: - sp_emb (N x D) - groups (N) - ft_centroids (N_c X F) ''' out = {} intercluster_loss = inter_cluster_loss(sp_centroids, margin=self.sp_interloss) intracluster_loss = intra_cluster_loss(sp_emb, sp_centroids, groups, margin=self.sp_intraloss) out['intercluster_loss'] = float(intercluster_loss) out['intracluster_loss'] = float(intracluster_loss) out['loss'] = self.sp_loss_params['inter'] * intercluster_loss + \ self.sp_loss_params['intra'] * intracluster_loss return out
[docs] def covariance_loss(self, sp_emb, ft_emb, cov, groups, sp_centroids, ft_centroids, eps=0.001): logits, acc, targets = get_graphspice_logits(sp_emb, ft_emb, cov, groups, sp_centroids, ft_centroids, eps) # Compute kernel score loss cov_loss = self.kernel_lossfn(logits, targets) return cov_loss, acc
[docs] def occupancy_loss(self, occ, groups): ''' INPUTS: - occ (N x 1) - groups (N) ''' bincounts = torch.bincount(groups).float() bincounts[bincounts == 0] = 1 occ_truth = torch.log(bincounts) occ_loss = torch.abs(torch.gather( occ - occ_truth[None, :], 1, groups.view(-1, 1))) if len(occ_loss.squeeze().size()) and len(groups.size()): occ_loss = scatter_mean(occ_loss.squeeze(), groups) # occ_loss = occ_loss[occ_loss > 0] return occ_loss.mean() else: #print(occ_loss.squeeze().size(), groups.size()) return 0.0
[docs] def combine_multiclass(self, sp_embeddings, ft_embeddings, covariance, occupancy, slabels, clabels): ''' Wrapper function for combining different components of the loss, in particular when clustering must be done PER SEMANTIC CLASS. NOTE: When there are multiple semantic classes, we compute the DLoss by first masking out by each semantic segmentation (ground-truth/prediction) and then compute the clustering loss over each masked point cloud. INPUTS: features (torch.Tensor): pixel embeddings slabels (torch.Tensor): semantic labels clabels (torch.Tensor): group/instance/cluster labels OUTPUT: loss_segs (list): list of computed loss values for each semantic class. loss[i] = computed DLoss for semantic class <i>. acc_segs (list): list of computed clustering accuracy for each semantic class. ''' loss = defaultdict(list) loss['loss'] = [] accuracy = defaultdict(float) accuracy['accuracy'] = 0. semantic_classes = slabels.unique() counts = 0 for sc in semantic_classes: if int(sc) == 4: continue index = (slabels == sc) sp_emb = sp_embeddings[index] ft_emb = ft_embeddings[index] cov = covariance[index] occ = occupancy[index] groups = clabels[index] # Remove groups labeled -1 mask = groups != -1 sp_emb = sp_emb[mask] ft_emb = ft_emb[mask] cov = cov[mask] occ = occ[mask] groups = groups[mask] groups_unique, _ = unique_label_torch(groups) # print(torch.unique(groups_unique, return_counts=True)) if groups_unique.shape[0] < 2: continue sp_centroids = find_cluster_means(sp_emb, groups_unique) ft_centroids = find_cluster_means(ft_emb, groups_unique) # Get different loss components ft_out = self.feature_embedding_loss( ft_emb, groups_unique, ft_centroids) sp_out = self.spatial_embedding_loss( sp_emb, groups_unique, sp_centroids) cov_loss, acc = self.covariance_loss( sp_emb, ft_emb, cov, groups_unique, sp_centroids, ft_centroids, eps=self.eps) occ_loss = self.occupancy_loss(occ, groups_unique) # TODO: Combine loss with weighting, keep track for logging loss['ft_intra'].append(ft_out['intracluster_loss']) loss['ft_inter'].append(ft_out['intercluster_loss']) loss['ft_reg'].append(ft_out['regularization_loss']) loss['sp_intra'].append(sp_out['intracluster_loss']) loss['sp_inter'].append(sp_out['intercluster_loss']) loss['cov_loss'].append(float(cov_loss)) loss['occ_loss'].append(float(occ_loss)) loss['loss'].append( ft_out['loss'] + sp_out['loss'] + cov_loss + occ_loss) # TODO: Implement train-time accuracy estimation accuracy['acc_{}'.format(int(sc))] = acc accuracy['accuracy'] += acc counts += 1 # for key, val in ft_out.items(): # if val != val: # print('ft_loss: {} = NaN ({})'.format(key, val)) # # for key, val in sp_out.items(): # if val != val: # print('sp_loss: {} = NaN ({})'.format(key, val)) # # if cov_loss != cov_loss: # print('cov_loss: {}'.format(cov_loss)) # # if occ_loss != occ_loss: # print('occ_loss: {}'.format(occ_loss)) if counts > 0: accuracy['accuracy'] /= counts return loss, accuracy
[docs] def forward(self, out, segment_label, cluster_label): num_gpus = len(segment_label) loss = defaultdict(list) accuracy = defaultdict(list) segmentationLayer = 'segmentation' in out for i in range(num_gpus): slabels = segment_label[i][:, -1] #coords = segment_label[i][:, :3].float() #if torch.cuda.is_available(): # coords = coords.cuda() slabels = slabels.long() clabels = cluster_label[i][:, -1].long() # print(clabels) batch_idx = segment_label[i][:, self.batch_column] sp_embedding = out['spatial_embeddings'][i] ft_embedding = out['feature_embeddings'][i] covariance = out['covariance'][i] occupancy = out['occupancy'][i] if segmentationLayer: segmentation = out['segmentation'][i] # nbatch = batch_idx.unique().shape[0] for bidx in batch_idx.unique(sorted=True): batch_mask = batch_idx == bidx sp_embedding_batch = sp_embedding[batch_mask] ft_embedding_batch = ft_embedding[batch_mask] if segmentationLayer: segmentation_batch = segmentation[batch_mask] slabels_batch = slabels[batch_mask] clabels_batch = clabels[batch_mask] covariance_batch = covariance[batch_mask] occupancy_batch = occupancy[batch_mask] if segmentationLayer: loss_seg = self.seg_loss_fn(segmentation_batch, slabels_batch) acc_seg = float(torch.sum(torch.argmax( segmentation_batch, dim=1) == slabels_batch)) \ / float(segmentation_batch.shape[0]) loss_class, acc_class = self.combine_multiclass( sp_embedding_batch, ft_embedding_batch, covariance_batch, occupancy_batch, slabels_batch, clabels_batch) for key, val in loss_class.items(): loss[key].append(sum(val) / len(val) if len(val) else 0.) for s, acc in acc_class.items(): accuracy[s].append(acc) if segmentationLayer: loss['gs_loss_seg'].append(loss_seg) accuracy['gs_acc_seg'].append(acc_seg) loss_avg = {} acc_avg = defaultdict(float) for key, val in loss.items(): loss_avg[key] = sum(val) / len(val) if segmentationLayer: loss_avg['loss'] += loss_avg['gs_loss_seg'] for key, val in accuracy.items(): acc_avg[key] = sum(val) / len(val) res = {} res.update(loss_avg) res.update(acc_avg) return res
[docs]class NodeEdgeHybridLoss(torch.nn.modules.loss._Loss): ''' Combined Node + Edge Loss '''
[docs] def __init__(self, cfg, name='graph_spice_loss'): super(NodeEdgeHybridLoss, self).__init__() # print("CFG + ", cfg) self.loss_config = cfg #[name] self.loss_fn = GraphSPICEEmbeddingLoss(cfg) self.edge_loss_cfg = self.loss_config.get('edge_loss_cfg', {}) self.invert = cfg.get('invert', True) self.edge_loss = WeightedEdgeLoss(invert=self.invert, **self.edge_loss_cfg) # self.is_eval = cfg['eval'] self.acc_fn = IoUScore() self.use_cluster_labels = cfg.get('use_cluster_labels', True)
[docs] def forward(self, result, segment_label, cluster_label): group_label = [cluster_label[0][:, [0, 1, 2, 3, 5]]] res = self.loss_fn(result, segment_label, group_label) # print(result) edge_score = result['edge_score'][0].squeeze() x = edge_score.squeeze() pred = x >= 0 iou, edge_loss = 0, 0 if self.use_cluster_labels: edge_truth = result['edge_truth'][0] edge_loss = self.edge_loss(edge_score.squeeze(), edge_truth.float()) edge_loss = edge_loss.mean() if self.invert: edge_truth = torch.logical_not(edge_truth.long()) iou = self.acc_fn(pred, edge_truth) # iou2 = self.acc_fn(~pred, ~edge_truth) res['edge_accuracy'] = iou if 'loss' in res: res['loss'] += edge_loss else: res['loss'] = edge_loss res['edge_loss'] = float(edge_loss) return res