Source code for mlreco.post_processing.metrics.cluster_gnn_metrics

# GNN clustering prediction
import numpy as np

from mlreco.post_processing import post_processing
from mlreco.utils.gnn.evaluation import (node_purity_mask,
                                         edge_purity_mask,
                                         edge_assignment_score,
                                         edge_assignment,
                                         node_assignment,
                                         node_assignment_score,
                                         node_assignment_bipartite,
                                         clustering_metrics,
                                         primary_assignment)
from mlreco.utils.gnn.cluster import form_clusters
from mlreco.post_processing.common import extent


[docs]@post_processing('cluster-gnn-metrics', ['clust_data', 'particles'], ['edge_pred', 'clusts', 'node_pred', 'edge_index']) def cluster_gnn_metrics(cfg, module_cfg, data_blob, res, logdir, iteration, edge_pred=None, clusts=None, node_pred=None, edge_index=None, clust_data=None, particles=None, data_idx=None, clust_data_noghost=None, **kwargs): """ Compute metrics for GRAPPA stage (GNN clustering). `enable_physics_metrics` = compute detailed cluster-wise metrics `integrated_metrics` = compute voxel-wise ARI/Purity/Efficiency metrics (only if `enable_physics_metrics: False`) Parameters ---------- data_blob: dict The input data dictionary from iotools. res: dict The output of the network, formatted using `analysis_keys`. cfg: dict Configuration. logdir: string Path to folder where CSV logs can be stored. iteration: int Current iteration number. Notes ----- N/A. """ # If there is no prediction, proceed edge_pred_label = module_cfg.get('edge_pred', 'edge_pred') # Get the post processor parameters coords_col = module_cfg.get('coords_col', (1, 4)) column = module_cfg.get('target_col', 6) column_source = module_cfg.get('source_col', 5) column_primary = module_cfg.get('primary_col', 10) chain = module_cfg.get('chain', 'chain') enable_physics_metrics = module_cfg.get('enable_physics_metrics', False) spatial_size = module_cfg.get('spatial_size', 768) #if not edge_pred_label in res: continue bipartite = cfg['model']['modules'][chain]['base'].get('network', 'complete') == 'bipartite' group_pred_alg = cfg['model']['modules'][chain + '_loss'].get('node_loss', {}).get('group_pred_alg', 'score') high_purity = cfg['model']['modules'][chain + '_loss'].get('edge_loss', {}).get('high_purity', False) node_predictions = node_pred original_clust_data = clust_data_noghost if clust_data_noghost is not None else clust_data if not len(clusts) or not len(clust_data): return (), () # If there is no node, append default if not len(clusts[data_idx]) or not len(clust_data[data_idx]): # fout.record(['ite', 'idx', 'ari', 'ami', 'sbd', 'pur', 'eff', 'num_clusts', 'num_pix'], # [iteration, tree_idx, -1, -1, -1, -1, -1, -1, -1]) return (), () # Use group id to make node labels group_ids, cluster_ids, primary_ids = [], [], [] for c in clusts[data_idx]: v, cts = np.unique(clust_data[data_idx][c,column], return_counts=True) group_ids.append(int(v[cts.argmax()])) v, cts = np.unique(clust_data[data_idx][c,column_source], return_counts=True) cluster_ids.append(int(v[cts.argmax()])) v, cts = np.unique(clust_data[data_idx][c,column_primary], return_counts=True) primary_ids.append(int(v[cts.argmax()])) group_ids = np.array(group_ids, dtype=np.int64) cluster_ids = np.array(cluster_ids, dtype=np.int64) primary_ids = np.array(primary_ids, dtype=np.int64) #print('clusts', [len(c) for c in clusts[data_idx]], len(group_ids), edge_index[data_idx].shape) #edge_assn = edge_assignment(edge_index[data_idx], group_ids) #purity_mask = edge_purity_mask(edge_index[data_idx], cluster_ids, group_ids) #edge_predictions, _, _ = edge_assignment_score(edge_index[data_idx], edge_pred[data_idx], n) #print('metrics' , edge_pred[data_idx].shape, edge_assn.shape, edge_pred[data_idx][:10], np.unique(edge_assn, return_counts=True)) #print(np.sum(edge_assn == np.argmax(edge_pred[data_idx], axis=1))/len(edge_assn)) #print(np.sum(edge_assn[purity_mask] == np.argmax(edge_pred[data_idx][purity_mask], axis=1))/len(edge_assn[purity_mask])) # Assign predicted group ids n = len(clusts[data_idx]) num_pix = np.sum([len(c) for c in clusts[data_idx]]) if not bipartite: # Determine the predicted group IDs by using union find edge_assn = np.argmax(edge_pred[data_idx], axis=1) if group_pred_alg == 'threshold': node_pred = node_assignment(edge_index[data_idx], edge_assn, n) elif group_pred_alg == 'score': node_pred = node_assignment_score(edge_index[data_idx], edge_pred[data_idx], n) else: raise ValueError('Group prediction algorithm not recognized: ' + group_pred_alg) else: # Determine the predicted group by chosing the most likely primary for each secondary primary_ids = np.unique(edge_index[data_idx][:,0]) node_pred = node_assignment_bipartite(edge_index[data_idx], edge_pred[data_idx][:,1], primary_ids, n) node_pred = np.array(node_pred, dtype=np.int64) # primary prediction node_pred_primary = None if node_predictions is not None: node_pred_primary = primary_assignment(node_predictions[data_idx], group_ids=node_pred) node_true_primary = np.equal(cluster_ids, group_ids) if enable_physics_metrics: # Loop over true clusters for true_id in np.unique(group_ids): true_cluster = clusts[data_idx][group_ids == true_id] pred_id = np.bincount(node_pred[group_ids == true_id]).argmax() pred_cluster = clusts[data_idx][node_pred == pred_id] overlap_cluster = clusts[data_idx][(group_ids == true_id) & (node_pred == pred_id)] original_indices = np.where(original_clust_data[data_idx][:, column] == true_id)[0] original_cluster = [np.where(original_clust_data[data_idx][original_indices][:, column_source] == x)[0] for x in np.unique(original_clust_data[data_idx][original_indices][:, column_source])] #original_cluster = form_clusters(original_clust_data[data_idx][original_indices], column=column_source) original_cluster = [original_indices[c] for c in original_cluster] # Purity + efficiency true_voxel_count = np.sum([len(c) for c in true_cluster]) pred_voxel_count = np.sum([len(c) for c in pred_cluster]) original_voxel_count = np.sum([len(c) for c in original_cluster]) overlap_voxel_count = np.sum([len(c) for c in overlap_cluster]) efficiency = overlap_voxel_count / true_voxel_count purity = overlap_voxel_count / pred_voxel_count # Primary identification pred_primaries_accuracy = -1 if node_pred_primary is not None: pred_primaries = node_true_primary[node_pred == pred_id] & node_pred_primary[node_pred == pred_id] pred_primaries_accuracy = pred_primaries.sum() # True particle information true_particles_idx = np.unique(clust_data[data_idx][np.hstack(true_cluster), 6]) # Remove -1 true_particles_idx = true_particles_idx[true_particles_idx >= 0] energy_deposit = 0. energy_init = 0. pdg, px, py, pz = [], [], [], [] for j in true_particles_idx: p = particles[data_idx][int(j)] energy_deposit += p.energy_deposit() energy_init += p.energy_init() pdg.append(p.pdg_code()) px.append(p.px()) py.append(p.py()) pz.append(p.pz()) if len(pdg) == 0: pdg = [-1] # True interaction information true_interaction_idx = np.unique(clust_data[data_idx][clust_data[data_idx][:, column] == true_id, 7]) # Remove -1 true_interaction_idx = true_interaction_idx[true_interaction_idx >= 0] nu_id = [] for j in true_interaction_idx: nu_idx = np.unique(clust_data[data_idx][(clust_data[data_idx][:, 7] == j) & (clust_data[data_idx][:, column] == true_id), 8]) nu_id.append(nu_idx[0]) if len(nu_id) == 0: nu_id = [-2] # Voxels information true_voxels = clust_data[data_idx][np.hstack(true_cluster), :5] pred_voxels = clust_data[data_idx][np.hstack(pred_cluster), :5] original_voxels = original_clust_data[data_idx][np.hstack(original_cluster), :5] true_d = extent(true_voxels) pred_d = extent(pred_voxels) original_d = extent(original_voxels) boundaries = np.min(np.concatenate([true_voxels[:, coords_col[0]:coords_col[1]], spatial_size - true_voxels[:, coords_col[0]:coords_col[1]]], axis=1)) true_fragments_count = len(true_cluster) pred_fragments_count = len(pred_cluster) overlap_fragments_count = len(overlap_cluster) original_fragments_count = len(original_cluster) row_names = ('true_id', 'pred_id', 'true_voxel_count', 'pred_voxel_count', 'overlap_voxel_count', 'original_voxel_count', 'purity', 'efficiency', 'true_voxels_sum', 'pred_voxels_sum', 'original_voxels_sum', 'true_fragments_count', 'pred_fragments_count', 'overlap_fragments_count', 'original_fragments_count', 'true_spatial_extent', 'true_spatial_std', 'distance_to_boundary', 'pred_spatial_extent', 'pred_spatial_std', 'particle_count', 'original_spatial_extent', 'original_spatial_std', 'true_energy_deposit', 'true_energy_init', 'true_pdg', 'true_px', 'true_py', 'true_pz', 'nu_idx', 'pred_primaries_accuracy') row_values = (true_id, pred_id, true_voxel_count, pred_voxel_count, overlap_voxel_count, original_voxel_count, purity, efficiency, true_voxels[:, -1].sum(), pred_voxels[:, -1].sum(), original_voxels[:, -1].sum(), true_fragments_count, pred_fragments_count, overlap_fragments_count, original_fragments_count, true_d.max(), true_d.std(), boundaries, pred_d.max(), pred_d.std(), len(true_particles_idx), original_d.max(), original_d.std(), energy_deposit, energy_init, pdg[0], np.sum(px), np.sum(py), np.sum(pz), nu_id[0], pred_primaries_accuracy) else: integrated_metrics = module_cfg.get('integrated_metrics', False) # Evaluate clustering metrics pixel-wise if integrated_metrics: pixel_group_ids = np.hstack([[g] * len(clusts[data_idx][c_idx]) for c_idx, g in enumerate(group_ids)]) pixel_cluster_ids = np.hstack([[g] * len(clusts[data_idx][c_idx]) for c_idx, g in enumerate(cluster_ids)]) pixel_clusts = np.hstack(clusts[data_idx])[:, None] pixel_node_pred = np.hstack([[g] * len(clusts[data_idx][c_idx]) for c_idx, g in enumerate(node_pred)]) if high_purity: purity_mask = node_purity_mask(cluster_ids, group_ids, primary_ids) if not purity_mask.any(): return (), () pixel_group_ids = np.hstack([[g] * len(clusts[data_idx][purity_mask][c_idx]) for c_idx, g in enumerate(group_ids[purity_mask])]) pixel_clusts = np.hstack(clusts[data_idx][purity_mask])[:, None] pixel_node_pred = np.hstack([[g] * len(clusts[data_idx][purity_mask][c_idx]) for c_idx, g in enumerate(node_pred[purity_mask])]) ari, ami, sbd, pur, eff = clustering_metrics(pixel_clusts, pixel_group_ids, pixel_node_pred) else: if not high_purity: ari, ami, sbd, pur, eff = clustering_metrics(clusts[data_idx], group_ids, node_pred) else: purity_mask = node_purity_mask(cluster_ids, group_ids, primary_ids) if not purity_mask.any(): return (), () ari, ami, sbd, pur, eff = clustering_metrics(clusts[data_idx][purity_mask], group_ids[purity_mask], node_pred[purity_mask]) #print(ari, pur, eff) primary_accuracy = -1. high_purity_value = -1 if high_purity and node_pred_primary is not None: primary_accuracy = np.count_nonzero(node_pred_primary == node_true_primary) / len(node_pred_primary) high_purity_value = purity_mask.any() #print(data_idx, "primary accuracy", primary_accuracy, high_purity) # Store row_names = ('ari', 'ami', 'sbd', 'pur', 'eff', 'num_fragments', 'num_pix', 'num_true_clusts', 'num_pred_clusts', 'primary_accuracy', 'high_purity') row_values = (ari, ami, sbd, pur, eff, n, num_pix, len(np.unique(group_ids)), len(np.unique(node_pred)), primary_accuracy, high_purity_value) return row_names, row_values