# Source code for mlreco.models.layers.gnn.losses.edge_channel

import torch
import numpy as np
from mlreco.utils.gnn.cluster import get_cluster_label
from mlreco.utils.gnn.network import get_fragment_edges
from mlreco.utils.gnn.evaluation import edge_assignment, edge_assignment_from_graph, edge_purity_mask
from mlreco.models.experimental.bayes.evidential import EVDLoss

class EdgeChannelLoss(torch.nn.Module):
    """
    Takes the two-channel edge output of the GNN and optimizes
    edge-wise scores such that edges that connect nodes that belong
    to common instance are given a high score.

    For use in config:

    model:
      name: cluster_gnn
      modules:
        grappa_loss:
          edge_loss:
            name            : channel
            source_col      : <column in the label data that specifies the source node ids of each voxel (default 5)>
            target_col      : <column in the label data that specifies the target group ids of each voxel (default 6)>
            primary_col     : <column in the label data used to build the primary ids for the high_purity mask (default 10)>
            loss            : <loss function: 'CE', 'MM' or 'EVD' (default 'CE')>
            reduction       : <loss reduction method: 'mean' or 'sum' (default 'sum')>
            balance_classes : <balance loss per class: True or False (default False)>
            target          : <type of target adjacency matrix: 'group', 'forest', 'particle_forest' (default 'group')>
            high_purity     : <only penalize loss on groups with a primary (default False)>
            p               : <MultiMarginLoss norm degree, 'MM' only (default 1)>
            margin          : <MultiMarginLoss margin, 'MM' only (default 1.0)>
            evd_loss_name   : <EVDLoss name, 'EVD' only (default 'evd_nll')>
            T               : <EVDLoss T parameter, 'EVD' only (default 50000)>

    The batch column is NOT read from the configuration block: it is the
    `batch_col` constructor argument (default 0).
    """

    def __init__(self, loss_config, batch_col=0, coords_col=(1, 4)):
        """
        Reads the loss configuration and instantiates the loss function.

        Args:
            loss_config (dict): Edge loss configuration (see class docstring)
            batch_col (int): Column of the label tensors that holds the batch ids
            coords_col (tuple): Column range; only coords_col[0] is used, as
                the first of the two columns read from `graph` in the
                'particle_forest' target mode

        Raises:
            ValueError: If the requested loss is not 'CE', 'MM' or 'EVD'
        """
        super().__init__()

        # Set the source and target for the loss
        self.batch_col = batch_col
        self.coords_col = coords_col
        self.source_col = loss_config.get('source_col', 5)
        self.target_col = loss_config.get('target_col', 6)
        self.primary_col = loss_config.get('primary_col', 10)

        # Set the loss
        self.loss = loss_config.get('loss', 'CE')
        self.reduction = loss_config.get('reduction', 'sum')
        self.balance_classes = loss_config.get('balance_classes', False)
        self.target = loss_config.get('target', 'group')
        self.high_purity = loss_config.get('high_purity', False)

        if self.loss == 'CE':
            self.lossfn = torch.nn.CrossEntropyLoss(reduction=self.reduction)
        elif self.loss == 'MM':
            p = loss_config.get('p', 1)
            margin = loss_config.get('margin', 1.0)
            self.lossfn = torch.nn.MultiMarginLoss(p=p, margin=margin, reduction=self.reduction)
        elif self.loss == 'EVD':
            evd_loss_name = loss_config.get('evd_loss_name', 'evd_nll')
            T = loss_config.get('T', 50000)
            self.lossfn = EVDLoss(evd_loss_name, reduction=self.reduction, T=T,
                                  num_classes=2, mode='evidence')
        else:
            raise ValueError('Loss not recognized: ' + self.loss)

    @staticmethod
    def _forest_assignment(edge_index, edge_pred, group_ids, edge_assn, n_clusts):
        """
        Builds the 'forest' edge targets: one spanning tree per group.

        For each group, finds the minimum spanning tree of the predicted
        "off" scores and labels its edges (both directions) as 1. Every
        other edge is labeled 0. Only edges for which this label agrees with
        the group-based assignment are kept, i.e. in-group edges that are not
        part of the tree do not contribute to the loss.

        Args:
            edge_index (np.ndarray): (E,2) Integer edge list
            edge_pred (torch.tensor): (E,2) Two-channel edge predictions
            group_ids (np.ndarray): (C) Group id of each cluster
            edge_assn (np.ndarray): (E) Group-based edge assignment
            n_clusts (int): Number of clusters C

        Returns:
            np.ndarray: Edge labels of the retained edges
            torch.tensor: Edge predictions of the retained edges
        """
        # Local import, as in the original implementation
        from scipy.sparse.csgraph import minimum_spanning_tree

        # Score matrix of "off" probabilities; pairs without an edge get 2.0,
        # which is worse than any probability
        off_scores = torch.softmax(edge_pred, dim=1)[:, 0].detach().cpu().numpy()
        score_mat = np.full((n_clusts, n_clusts), 2.0)
        score_mat[tuple(edge_index.T)] = off_scores

        # Mark the tree edges in a boolean adjacency matrix, so that matching
        # them against edge_index is a single lookup rather than a nested loop
        tree_mat = np.zeros((n_clusts, n_clusts), dtype=bool)
        for g in np.unique(group_ids):
            members = np.where(group_ids == g)[0]
            if len(members) < 2:
                continue

            # The small offset keeps zero scores from being read as "no edge"
            mst_mat = minimum_spanning_tree(
                score_mat[np.ix_(members, members)] + 1e-6).toarray()
            rows, cols = np.nonzero(mst_mat > 0.)
            tree_mat[members[rows], members[cols]] = True
            tree_mat[members[cols], members[rows]] = True  # reciprocal connections

        edge_assn_max = tree_mat[tuple(edge_index.T)].astype(float)

        # Keep only the edges on which both assignments agree
        max_mask = edge_assn == edge_assn_max

        return edge_assn_max[max_mask], edge_pred[np.where(max_mask)[0]]

    def forward(self, out, clusters, graph=None):
        """
        Applies the requested loss on the edge prediction.

        Args:
            out (dict):
                'edge_pred' (torch.tensor): (E,2) Two-channel edge predictions
                'clusts' ([np.ndarray])   : [(N_0), (N_1), ..., (N_C)] Cluster ids
                'edge_index' (np.ndarray) : (E,2) Incidence matrix
                Each entry is indexed as out[key][i][j], with i the index in
                `clusters` and j the batch id.
            clusters ([torch.tensor]): Label tensors, one row per voxel. The
                batch, source, target and primary columns are selected by
                `batch_col`, `source_col`, `target_col` and `primary_col`.
            graph ([torch.tensor]): True edges; required only when the target
                is 'particle_forest'

        Returns:
            dict:
                'accuracy': Fraction of correctly assigned edges (1. if no edge)
                'loss'    : Summed loss divided by the edge count (0. if no edge)
                'n_edges' : Number of edges that contributed to the loss

        Raises:
            ValueError: If the prediction target is not recognized, or if the
                target is 'particle_forest' and no `graph` is provided
        """
        total_loss, total_acc = 0., 0.
        n_edges = 0

        # If there is no edge prediction at all, no entry can contribute
        label_tensors = clusters if 'edge_pred' in out else []
        for i, event_labels in enumerate(label_tensors):

            # Get the list of batch ids, loop over individual batches
            # NOTE(review): assumes batch ids are 0..nbatches-1 -- confirm
            batches = event_labels[:, self.batch_col]
            nbatches = len(batches.unique())
            for j in range(nbatches):

                # Narrow down the tensor to the rows in the batch
                labels = event_labels[batches == j]
                if not labels.shape[0]:
                    continue

                # Get the output of the forward function
                edge_pred = out['edge_pred'][i][j]
                if not edge_pred.shape[0]:
                    continue
                edge_index = out['edge_index'][i][j]
                clusts = out['clusts'][i][j]
                clust_ids = get_cluster_label(labels, clusts, self.source_col)
                group_ids = get_cluster_label(labels, clusts, self.target_col)

                # If a cluster target is -1, none of its edges contribute to the loss
                valid_clust_mask = group_ids > -1
                valid_mask = np.all(valid_clust_mask[edge_index], axis=-1)

                # If high purity is requested, remove edges in groups without a single primary
                if self.high_purity:
                    primary_ids = get_cluster_label(labels, clusts, self.primary_col)
                    valid_mask &= edge_purity_mask(edge_index, clust_ids, group_ids, primary_ids)

                # Apply valid mask to edges and their predictions
                if not valid_mask.any():
                    continue
                edge_index = edge_index[valid_mask]
                edge_pred = edge_pred[np.where(valid_mask)[0]]

                # Use group information or particle tree to determine the true edge assignment
                if self.target == 'group':
                    edge_assn = edge_assignment(edge_index, group_ids)
                elif self.target == 'forest':
                    edge_assn, edge_pred = self._forest_assignment(
                        edge_index, edge_pred, group_ids,
                        edge_assignment(edge_index, group_ids), len(clusts))
                    if not len(edge_pred):
                        continue
                elif 'particle_forest' in self.target:
                    if graph is None:
                        raise ValueError('Prediction target ' + str(self.target)
                                         + ' requires the true graph to be provided')
                    first_col = self.coords_col[0]
                    subgraph = graph[i][graph[i][:, self.batch_col] == j,
                                        first_col:first_col + 2]
                    true_edge_index = get_fragment_edges(subgraph, clust_ids)
                    edge_assn = edge_assignment_from_graph(edge_index, true_edge_index)
                else:
                    raise ValueError('Prediction target not recognized: ' + str(self.target))

                edge_assn = torch.tensor(edge_assn, device=edge_pred.device,
                                         dtype=torch.long, requires_grad=False).view(-1)

                # Increment the loss, balance classes if requested
                if self.balance_classes:
                    # Weight each class by the inverse of its frequency
                    vals, counts = torch.unique(edge_assn, return_counts=True)
                    for v, count in zip(vals, counts):
                        class_mask = edge_assn == v
                        weight = float(count) / len(edge_assn)
                        total_loss += (1. / weight) * self.lossfn(
                            edge_pred[class_mask], edge_assn[class_mask])
                else:
                    total_loss += self.lossfn(edge_pred, edge_assn)

                # Compute accuracy of assignment (fraction of correctly assigned edges)
                total_acc += torch.sum(torch.argmax(edge_pred, dim=1) == edge_assn).float()

                # Increment the number of edges
                n_edges += len(edge_pred)

        if not n_edges:
            # Return a differentiable zero so that a backward pass is still
            # possible; an empty `clusters` list falls back to the default device
            device = clusters[0].device if len(clusters) else None
            return {
                'accuracy': 1.,
                'loss': torch.tensor(0., requires_grad=True, device=device),
                'n_edges': n_edges
            }

        return {
            'accuracy': total_acc / n_edges,
            'loss': total_loss / n_edges,
            'n_edges': n_edges
        }