Source code for mlreco.post_processing.metrics.cluster_cnn_metrics

import numpy as np
from mlreco.utils.cluster.dense_cluster import (gaussian_kernel,
                                                fit_predict_np,
                                                find_cluster_means)
from mlreco.utils.metrics import *
from mlreco.post_processing import post_processing
from mlreco.post_processing.common import extent

@post_processing('cluster-cnn-metrics',
                 ['seg_label', 'clust_data', 'particles'],
                 ['segmentation', 'embeddings', 'margins', 'seediness'])
def cluster_cnn_metrics(cfg, module_cfg, data_blob, res, logdir, iteration,
                        data_idx=None, seg_label=None, clust_data=None,
                        particles=None, embeddings=None, margins=None,
                        seediness=None, ghost_mask=None, true_ghost_mask=None,
                        seg_label_noghost=None, clust_data_noghost=None,
                        seg_prediction=None, **kwargs):
    """
    Compute metrics for the SPICE stage (CNN particle instance clustering).

    For every semantic class ``c < 4`` found in the last column of
    ``seg_label``, the SPICE outputs (embeddings, seediness, margins) of the
    voxels of that class are clustered with ``fit_predict_np`` and the
    prediction is compared with the true cluster labels (column 6 of
    ``clust_data``).

    Two output layouts are available:

    * default: per-class ARI / SBD / purity / efficiency / F-score, plus
      per-true-cluster margin, spread and particle information;
    * ``enable_physics_metrics``: each true cluster is matched to the
      predicted cluster holding most of its voxels, and per-pair purity,
      efficiency, spatial extent, voxel-value sum (column 4) and distance to
      the boundary of a cube of side ``spatial_size`` are reported, together
      with the same quantities computed on ``clust_data_noghost``.

    TODO assumes ghost points for now

    Parameters
    ----------
    cfg: dict
        Full configuration. ``cfg['model']['modules']['spice']
        ['spice_fragment_manager']`` may define ``min_voxels`` (default 2):
        classes with fewer voxels are not clustered.
    module_cfg: dict
        Configuration of this post-processing module. Recognized keys:

        - ``s_thresholds`` (or the legacy ``s_threshold``): per-class
          seediness thresholds, default ``[0, 0, 0, 0]``;
        - ``p_thresholds``: per-class probability thresholds, default
          ``[0.5, 0.5, 0.5, 0.5]``;
        - ``spatial_size``: side of the volume used for the
          distance-to-boundary columns, default 768;
        - ``enable_physics_metrics``: select the physics layout, default
          False;
        - ``coords_col``: ``(start, stop)`` columns holding the coordinates,
          default ``(1, 4)``.
    data_blob: dict
        The input data dictionary from iotools (not read here).
    res: dict
        The output of the network, formatted using `analysis_keys`
        (not read here).
    logdir: string
        Path to folder where CSV logs can be stored (not read here).
    iteration: int
        Current iteration number (not read here).
    data_idx: int
        Index of the batch entry to process.
    seg_label, clust_data, particles:
        Per-entry label arrays / particle lists, indexed by ``data_idx``.
        ``seg_label`` holds coordinates in ``coords_col`` and the semantic
        class in its last column. ``clust_data`` must be row-aligned with
        ``seg_label`` (it is indexed with masks built from it); its column 6
        is the cluster id used to index ``particles``, and column 7 is
        counted as the interaction id.
    embeddings, margins, seediness:
        SPICE outputs, indexed directly with per-voxel masks (presumably
        already restricted to this entry by the decorator -- confirm).
    ghost_mask, true_ghost_mask, seg_label_noghost, clust_data_noghost, seg_prediction:
        Only used when ``enable_physics_metrics`` is set.

    Returns
    -------
    row_names: tuple of str
        CSV column names of the selected layout.
    row_values: tuple
        Values for the *last* cluster processed (see NOTE below).
    """
    # ``p_thresholds`` is plural; accept the matching plural key for the
    # seediness thresholds and fall back to the historical singular key.
    s_thresholds = module_cfg.get('s_thresholds',
                                  module_cfg.get('s_threshold', [0, 0, 0, 0]))
    p_thresholds = module_cfg.get('p_thresholds', [0.5, 0.5, 0.5, 0.5])
    spatial_size = module_cfg.get('spatial_size', 768)
    enable_physics_metrics = module_cfg.get('enable_physics_metrics', False)
    spice_min_voxels = (cfg['model']['modules']['spice']['spice_fragment_manager']
                        .get('min_voxels', 2))
    coords_col = module_cfg.get('coords_col', (1, 4))
    coord_slice = slice(coords_col[0], coords_col[1])

    event_labels = seg_label[data_idx]
    event_clust = clust_data[data_idx]
    coords = event_labels[:, coord_slice]
    semantic_classes = np.unique(event_labels[:, -1])

    # NOTE(review): the loops below overwrite ``row_values`` on every
    # cluster, so only the last cluster processed is returned. If no class
    # yields a cluster, ``row_values`` is never bound and the return raises
    # UnboundLocalError. Left unchanged pending a check of what the
    # ``post_processing`` decorator accepts as a multi-row return.

    if enable_physics_metrics:
        row_names = ('Class', 'true_id', 'pred_id',
                     'true_pixel_count', 'pred_pixel_count', 'overlap_pixel_count',
                     'purity', 'efficiency',
                     'spatial_extent', 'spatial_std',
                     'pred_spatial_extent', 'pred_spatial_std',
                     'true_voxels_sum', 'pred_voxels_sum', 'distance_to_boundary',
                     'pdg', 'px', 'py', 'pz', 'energy_init', 'energy_deposit',
                     'original_pixel_count', 'original_voxels_sum',
                     'original_spatial_extent', 'original_spatial_std',
                     'original_distance_to_boundary')
        for c in semantic_classes:
            if int(c) >= 4:
                continue
            # True clusters of this class, taken from the ghost-free labels
            noghost_clust = clust_data_noghost[data_idx]
            original_clust_data = noghost_clust[noghost_clust[:, -1] == c]
            # Voxels *predicted* to belong to this class
            semantic_mask = seg_prediction[data_idx] == c
            embedding_class = embeddings[semantic_mask]
            seed_class = seediness[semantic_mask]
            margins_class = margins[semantic_mask]
            if len(embedding_class) < spice_min_voxels:
                continue
            pred = fit_predict_np(embedding_class, seed_class, margins_class,
                                  gaussian_kernel,
                                  s_threshold=s_thresholds[int(c)],
                                  p_threshold=p_thresholds[int(c)])

            coords_class = coords[semantic_mask]
            class_clust = event_clust[semantic_mask]
            clabels = class_clust[:, 6]
            original_coords_class = original_clust_data[:, coord_slice]
            original_clabels = original_clust_data[:, 6]

            # Match each true cluster to the predicted cluster that holds
            # the majority of its voxels.
            for true_id in np.unique(clabels):
                cluster_mask = clabels == true_id
                pred_id = np.bincount(pred[cluster_mask]).argmax()
                pred_mask = pred == pred_id
                original_mask = original_clabels == true_id

                # "Purity" + efficiency of the matched (true, predicted) pair
                overlap_pixel_count = np.count_nonzero(pred[cluster_mask] == pred_id)
                true_pixel_count = np.count_nonzero(cluster_mask)
                pred_pixel_count = np.count_nonzero(pred_mask)
                original_pixel_count = np.count_nonzero(original_mask)
                efficiency = overlap_pixel_count / true_pixel_count
                purity = overlap_pixel_count / pred_pixel_count

                # True particle information (column-6 id indexes ``particles``)
                p = particles[data_idx][int(true_id)]

                # Voxel information
                true_voxels = class_clust[cluster_mask, :5]
                pred_voxels = class_clust[pred_mask, :5]
                original_voxels = original_clust_data[original_mask, :5]
                cluster_coords = coords_class[cluster_mask]
                d = extent(cluster_coords)
                pred_d = extent(coords_class[pred_mask])
                # Smallest distance of any voxel to a face of the cube
                # [0, spatial_size]^3
                boundaries = np.min(np.concatenate(
                    [cluster_coords, spatial_size - cluster_coords], axis=1))

                if original_pixel_count:
                    # Turn both selections into indices of the arrays masked
                    # by ghost_mask / true_ghost_mask (presumably the same
                    # full point list -- confirm) so they can be intersected.
                    pred_index = np.where(ghost_mask[data_idx])[0][semantic_mask][pred_mask]
                    true_nonghost = true_ghost_mask[data_idx]
                    true_class_mask = seg_label_noghost[data_idx][true_nonghost][:, -1] == c
                    true_index = np.where(true_nonghost)[0][true_class_mask][original_mask]
                    # NOTE(review): computed but not written to the output row.
                    original_overlap_pixel_count = len(np.intersect1d(pred_index, true_index))
                    original_cluster_coords = original_coords_class[original_mask]
                    original_d = extent(original_cluster_coords)
                    original_boundaries = np.min(np.concatenate(
                        [original_cluster_coords, spatial_size - original_cluster_coords],
                        axis=1))
                else:
                    # Sentinels: the true cluster has no ghost-free voxels
                    original_overlap_pixel_count = -1
                    original_d = np.array([-1])
                    original_boundaries = -1

                row_values = (c, true_id, pred_id,
                              true_pixel_count, pred_pixel_count, overlap_pixel_count,
                              purity, efficiency,
                              d.max(), d.std(), pred_d.max(), pred_d.std(),
                              true_voxels[:, -1].sum(), pred_voxels[:, -1].sum(),
                              boundaries,
                              p.pdg_code(), p.px(), p.py(), p.pz(),
                              p.energy_init(), p.energy_deposit(),
                              original_pixel_count, original_voxels[:, -1].sum(),
                              original_d.max(), original_d.std(), original_boundaries)
    else:
        row_names = ('Class', 'ARI', 'SBD', 'Purity', 'Efficiency', 'FScore',
                     'num_clusters', 'true_num_clusters', 'margin', 'true_size',
                     'event_num_particles', 'cluster_num_pix', 'event_num_pix',
                     'class_num_particles', 'class_num_pix', 'event_num_interactions',
                     'particle_energy_deposit', 'particle_energy_init',
                     'particle_pdg_code', 'particle_interaction_type',
                     'particle_px', 'particle_py', 'particle_pz',
                     'cluster_id')
        # Event-level counts do not depend on the semantic class.
        event_num_particles = len(np.unique(event_clust[:, 6]))
        event_num_pix = event_labels.shape[0]
        event_num_interactions = len(np.unique(event_clust[:, 7]))

        for c in semantic_classes:
            if int(c) >= 4:
                continue
            semantic_mask = event_labels[:, -1] == c
            embedding_class = embeddings[semantic_mask]
            seed_class = seediness[semantic_mask]
            margins_class = margins[semantic_mask]
            coords_class = coords[semantic_mask]
            clabels = event_clust[semantic_mask][:, 6]
            if len(embedding_class) < spice_min_voxels:
                continue
            pred = fit_predict_np(embedding_class, seed_class, margins_class,
                                  gaussian_kernel,
                                  s_threshold=s_thresholds[int(c)],
                                  p_threshold=p_thresholds[int(c)])

            # Class-level clustering scores
            purity, efficiency = purity_efficiency(pred, clabels)
            fscore = 2 * (purity * efficiency) / (purity + efficiency)
            ari = ARI(pred, clabels)
            sbd = SBD(pred, clabels)

            cluster_ids = np.unique(clabels)
            # NOTE(review): 'num_clusters', 'true_num_clusters' and
            # 'class_num_particles' all count the *true* labels of this class
            # and are therefore identical; 'num_clusters' may have been meant
            # to count predicted clusters. Kept as-is to preserve the output.
            nclusters = len(cluster_ids)
            true_num_clusters = nclusters
            class_num_particles = nclusters
            class_num_pix = event_labels[semantic_mask].shape[0]

            # Centroids are indexed in the order of np.unique(clabels)
            _, true_centroids = find_cluster_means(coords_class, clabels)
            for j, cluster_id in enumerate(cluster_ids):
                cluster_mask = clabels == cluster_id
                margin = np.mean(margins_class[cluster_mask])
                true_size = np.std(np.linalg.norm(
                    coords_class[cluster_mask] - true_centroids[j], axis=1))
                cluster_num_pix = cluster_mask.sum()
                # Look the particle up by its cluster id, as the physics
                # branch does; ``j`` only orders clusters within this class.
                p = particles[data_idx][int(cluster_id)]
                row_values = (c, ari, sbd, purity, efficiency, fscore,
                              nclusters, true_num_clusters, margin, true_size,
                              event_num_particles, cluster_num_pix, event_num_pix,
                              class_num_particles, class_num_pix,
                              event_num_interactions,
                              p.energy_deposit(), p.energy_init(), p.pdg_code(),
                              p.nu_interaction_type(), p.px(), p.py(), p.pz(),
                              cluster_id)

    return row_names, row_values