Source code for mlreco.models.layers.gnn.losses.node_grouping

import torch
import torch.nn as nn

from mlreco.models.layers.cluster_cnn.losses.lovasz import mean, lovasz_hinge_flat, StableBCELoss, iou_binary


def bc_distance(gauss1, gauss2, eps=1e-6, debug=False):
    """
    Bhattacharyya similarity between one spherical Gaussian and a set of others.

    Each Gaussian is stored as a row whose first three columns are the mean and
    whose remaining column(s) hold the width sigma. ``gauss1`` is broadcast
    (``expand_as``) against every row of ``gauss2``, so it may be given either
    as a single row of shape (1, D) or as a 1-D tensor of shape (D,).

    The value computed is, with every sigma regularised by ``eps``::

        (2 s1 s2 / (s1^2 + s2^2)) ** 1.5 * exp(-|mu1 - mu2|^2 / (4 (s1^2 + s2^2)))

    Args:
        gauss1: reference Gaussian parameters.
        gauss2: (N, D) Gaussian parameters to compare against.
        eps: small constant keeping the width ratios and variance finite.
        debug: accepted for backward compatibility; not used.

    Returns:
        Similarities clamped to [1e-6, 1 - 1e-6], one per row of ``gauss2``
        (assuming a single sigma column — TODO confirm for wider inputs).
    """
    ref = gauss1.expand_as(gauss2)
    mu_ref, sig_ref = ref[:, :3], ref[:, 3:]
    mu_oth, sig_oth = gauss2[:, :3], gauss2[:, 3:]

    # Pooled variance s1^2 + s2^2, with eps added once per Gaussian.
    pooled_var = sig_ref ** 2 + sig_oth ** 2 + 2 * eps

    # Width-mismatch prefactor: ((s1/s2 + s2/s1) / 2) ** -1.5, which equals 1
    # for equal widths and decays as the widths diverge.
    ratio_sum = (sig_ref + eps) / (sig_oth + eps) + (sig_oth + eps) / (sig_ref + eps)
    prefactor = torch.pow(0.5 * ratio_sum, -1.5)

    # Separation term |mu1 - mu2|^2 / (4 (s1^2 + s2^2)).
    sq_sep = torch.pow(torch.norm(mu_ref - mu_oth, dim=1), 2)
    separation = 0.25 * sq_sep / pooled_var.squeeze()

    similarity = prefactor.squeeze() * torch.exp(-separation)
    # Keep strictly inside (0, 1) so downstream log-based losses stay finite.
    return torch.clamp(similarity, min=1e-6, max=1-1e-6)
class GNNGroupingLoss(nn.Module):
    """
    Node-grouping loss for GNN node embeddings.

    For every (batch entry, group label) pair, a reference Gaussian is formed
    as the mean of the node rows belonging to that group. ``self.kernel`` then
    scores that reference against every node of the same batch entry. The
    resulting scores are trained with ``self.bceloss`` against the
    group-membership mask.
    """

    def __init__(self, cfg, name='gnn_grouping_loss', batch_col=0, coords_col=(1, 4)):
        """
        Args:
            cfg: configuration mapping; ``cfg[name]`` is stored as
                ``self.loss_config``.
            name: key of this loss' section in ``cfg``.
            batch_col: stored as ``self.batch_col`` (not read by ``forward``).
            coords_col: stored as ``self.coords_col`` (not read by ``forward``).
        """
        super(GNNGroupingLoss, self).__init__()
        self.loss_config = cfg[name]
        # Similarity kernel: kernel(reference_row, nodes) -> per-node score.
        self.kernel = bc_distance
        # NOTE(review): the kernel output passed to this loss is a probability
        # (bc_distance clamps it to [1e-6, 1 - 1e-6]). If StableBCELoss is the
        # logits-based BCE from the Lovasz reference code, it expects logits,
        # not probabilities — confirm this pairing is intended.
        self.bceloss = StableBCELoss()
        self.batch_col = batch_col
        self.coords_col = coords_col

    def forward(self, nodes, node_batch_labels, node_group_labels):
        """
        Compute the grouping loss and accuracy.

        Args:
            nodes: per-node Gaussian parameters, one row per node. Rows are
                handed to ``self.kernel``; for ``bc_distance`` they are
                ``[x, y, z, sigma]``.
            node_batch_labels: batch index of each node (presumably shape (N,)).
            node_group_labels: group label of each node (presumably shape (N,)).

        Returns:
            dict with
              'loss': mean of the per-group ``self.bceloss`` values (tensor);
              'accuracy': mean per-group IoU between ``p > 0.5`` and the true
                  membership mask (python float).
            With no nodes at all, 'loss' is a zero tensor still attached to
            ``nodes`` and 'accuracy' is 1.0.
        """
        losses, accuracies = [], []
        for bidx in node_batch_labels.unique():
            batch_mask = node_batch_labels == bidx
            groups = node_group_labels[batch_mask]
            nodes_batch = nodes[batch_mask]
            for g in groups.unique():
                group_mask = groups == g
                # Reference Gaussian = mean over the group's members. A single
                # member is its own mean; the kernel broadcasts the 1-D result.
                centroid = nodes_batch[group_mask].mean(dim=0)
                p = self.kernel(centroid, nodes_batch)
                losses.append(self.bceloss(p, group_mask.float()))
                with torch.no_grad():
                    acc = iou_binary(p > 0.5, group_mask, per_image=False)
                accuracies.append(float(acc))

        if not losses:
            # Empty input: nothing to group. Return a zero loss that stays
            # connected to the graph (so backward() works), and report perfect
            # accuracy by convention instead of raising ZeroDivisionError.
            return {'loss': nodes.sum() * 0., 'accuracy': 1.}

        loss = sum(losses) / len(losses)
        accuracy = sum(accuracies) / len(accuracies)
        return {'loss': loss, 'accuracy': accuracy}