Source code for mlreco.post_processing.analysis.michel_reconstruction

import numpy as np
import os
from sklearn.cluster import DBSCAN
from scipy.spatial.distance import cdist
from mlreco.utils import CSVData
from mlreco.utils.deghosting import adapt_labels_numpy as adapt_labels


[docs]def michel_reconstruction(cfg, module_cfg, data_blob, res, logdir, iteration): """ Very simple algorithm to reconstruct Michel clusters from UResNet semantic segmentation output. Notes ----- Assumes 3D Configuration ------------- ``` michel_reconstruction: store_method: per-iteration dbscan: False particles: particles_asis ``` Output ------ Writes 2 CSV files: - `michel_reconstruction-reco-*` - `michel_reconstruction-true-*` """ coords_col = module_cfg.get('coords_col', (1, 4)) particles_label = module_cfg.get('particles', 'particles_label') cluster_label = module_cfg.get('cluster_label', 'cluster_label') segment_label = module_cfg.get('segment_label', 'segment_label') # should_adapt_labels = module_cfg.get('adapt_labels', False) # Create output CSV store_per_iteration = True if module_cfg is not None and module_cfg.get('store_method',None) is not None: assert(module_cfg['store_method'] in ['per-iteration','per-event']) store_per_iteration = module_cfg['store_method'] == 'per-iteration' fout_reco,fout_true=None,None if store_per_iteration: fout_reco=CSVData(os.path.join(logdir, 'michel-reconstruction-reco-iter-%07d.csv' % iteration)) fout_true=CSVData(os.path.join(logdir, 'michel-reconstruction-true-iter-%07d.csv' % iteration)) fout_matched=CSVData(os.path.join(logdir, 'michel-reconstruction-matched-iter-%07d.csv' % iteration)) # Loop over events for batch_id,data in enumerate(data_blob['input_data']): event_idx = data_blob['index' ][batch_id] if not store_per_iteration: fout_reco=CSVData(os.path.join(logdir, 'michel-reconstruction-reco-event-%07d.csv' % event_idx)) fout_true=CSVData(os.path.join(logdir, 'michel-reconstruction-true-event-%07d.csv' % event_idx)) fout_matched=CSVData(os.path.join(logdir, 'michel-reconstruction-matched-event-%07d.csv' % event_idx)) # from input/labels label = data_blob[segment_label][batch_id][:,-1] #label_raw = data_blob['sparse3d_pcluster_semantics'][batch_id] clusters = data_blob[cluster_label][batch_id] # if should_adapt_labels: # clusters = adapt_labels(res, data_blob[segment_label], data_blob[cluster_label])[batch_id] particles = data_blob[particles_label][batch_id] true_ghost_mask = label < 5 data_masked = data[true_ghost_mask] label_masked = label[true_ghost_mask] one_pixel = 5#2.8284271247461903 # Retrieve semantic labels corresponding to clusters clusters_semantics = clusters[:, -1] # from network output segmentation = res['segmentation'][batch_id] predictions = np.argmax(segmentation,axis=1) ghost_mask = (np.argmax(res['ghost'][batch_id],axis=1) == 0) data_pred = data[ghost_mask] # coords label_pred = label[ghost_mask] # labels # FIXME this is temporary because of 2-class deghosting only, assumes we do perfect 5-class segmentation # after deghosting #predictions = label_pred predictions = (np.argmax(segmentation,axis=1))[ghost_mask] segmentation = segmentation[ghost_mask] Michel_label = 2 MIP_label = 1 # 0. Retrieve coordinates of true and predicted Michels # MIP_coords = data[(label == 1).reshape((-1,)), ...][:, :3] # Michel_coords = data[(label == 4).reshape((-1,)), ...][:, :3] # Michel_particles = particles[particles[:, 4] == Michel_label] MIP_coords = data[label == MIP_label][:, coords_col[0]:coords_col[1]] # Michel_coords = data[label == Michel_label][:, :3] Michel_all = clusters[clusters_semantics == Michel_label] if Michel_all.shape[0] == 0: # FIXME continue #print(Michel_all.shape) if module_cfg.get('dbscan', False): shower_label = 0 lowE_label = 4 #Michel_coords = clusters[(clusters_semantics == Michel_label) | (clusters_semantics == shower_label) | (clusters_semantics == lowE_label)][:, :3] shower_lowE = clusters[(clusters_semantics == shower_label) | (clusters_semantics == lowE_label)] if shower_lowE.shape[0] > 0: shower_lowE_clusters = DBSCAN(eps=one_pixel, min_samples=1).fit(shower_lowE[:, coords_col[0]:coords_col[1]]).labels_ d = cdist(Michel_all[:, coords_col[0]:coords_col[1]], shower_lowE[:, coords_col[0]:coords_col[1]]) select_shower = (d.min(axis=0) < module_cfg.get('threshold', 2.)) select_shower_clusters = np.unique(shower_lowE_clusters[(shower_lowE_clusters>-1) & select_shower]) fragments = [] for shower_id in select_shower_clusters: fragment = shower_lowE[shower_lowE_clusters == shower_id] print(batch_id, "merging", fragment.shape[0]) #print(fragment[:, -3][:20]) temp = fragment[:, -3] print(temp, Michel_all[d[:, shower_lowE_clusters == shower_id].min(axis=1).argmin(), -3]) fragment[:, -3] = Michel_all[d[:, shower_lowE_clusters == shower_id].min(axis=1).argmin(), -3] print(np.count_nonzero(temp != fragment[:, -3])) fragments.append(fragment) Michel_all = np.concatenate([Michel_all] + fragments, axis=0) print(Michel_all.shape) # Michel_true_clusters = DBSCAN(eps=2., min_samples=5).fit(Michel_coords).labels_ # Michel_coords = Michel_coords[Michel_true_clusters>-1] # Michel_true_clusters = Michel_true_clusters[Michel_true_clusters>-1] # Michel_all = Michel_all[Michel_true_clusters>-1] Michel_coords = Michel_all[:, coords_col[0]:coords_col[1]] MIP_coords_pred = data_pred[(predictions == MIP_label).reshape((-1,)), ...][:, coords_col[0]:coords_col[1]] Michel_coords_pred = data_pred[(predictions == Michel_label).reshape((-1,)), ...][:, coords_col[0]:coords_col[1]] # 1. Find true particle information matching the true Michel cluster # Michel_true_clusters = DBSCAN(eps=one_pixel, min_samples=5).fit(Michel_coords).labels_ # Michel_true_clusters = [Michel_coords[Michel_coords[:, -2] == gid] for gid in np.unique(Michel_coords[:, -2])] #print(clusters.shape, label.shape) Michel_true_clusters = Michel_all[:, 6].astype(np.int64) # Michel_start = Michel_particles[:, :3] for cluster in np.unique(Michel_true_clusters): # print("True", np.count_nonzero(Michel_true_clusters == cluster)) # TODO sum_pix fout_true.record(('batch_id', 'iteration', 'event_idx', 'num_pix', 'sum_pix'), (batch_id, iteration, event_idx, np.count_nonzero(Michel_true_clusters == cluster), Michel_all[Michel_true_clusters == cluster][:, -4].sum())) fout_true.write() # TODO how do we count events where there are no predictions but true? if MIP_coords_pred.shape[0] == 0 or Michel_coords_pred.shape[0] == 0: continue # print("Also predicted!") # 2. Compute true and predicted clusters MIP_clusters = DBSCAN(eps=one_pixel, min_samples=5).fit(MIP_coords_pred).labels_ if np.count_nonzero(MIP_clusters>-1) == 0: continue Michel_pred_clusters = DBSCAN(eps=2*one_pixel, min_samples=5).fit(Michel_coords_pred).labels_ Michel_pred_clusters_id = np.unique(Michel_pred_clusters[Michel_pred_clusters>-1]) Michel_true_clusters_id = np.unique(Michel_true_clusters[Michel_true_clusters > -1]) # Loop over predicted Michel clusters Michel_is_attached, Michel_is_edge = [], [] for Michel_id in Michel_pred_clusters_id: current_index = Michel_pred_clusters == Michel_id # 3. Check whether predicted Michel is attached to a predicted MIP # and at the edge of the predicted MIP distances = cdist(Michel_coords_pred[current_index], MIP_coords_pred[MIP_clusters>-1]) # is_attached = np.min(distances) < 2.8284271247461903 is_attached = np.min(distances) < 5 is_edge = False # default # print("Min distance:", np.min(distances)) if is_attached: Michel_min, MIP_min = np.unravel_index(np.argmin(distances), distances.shape) MIP_id = MIP_clusters[MIP_clusters>-1][MIP_min] MIP_min_coords = MIP_coords_pred[MIP_clusters>-1][MIP_min] MIP_cluster_coords = MIP_coords_pred[MIP_clusters==MIP_id] ablated_cluster = MIP_cluster_coords[np.linalg.norm(MIP_cluster_coords-MIP_min_coords, axis=1)>15.0] if ablated_cluster.shape[0] > 0: new_cluster = DBSCAN(eps=one_pixel, min_samples=5).fit(ablated_cluster).labels_ is_edge = len(np.unique(new_cluster[new_cluster>-1])) == 1 else: is_edge = True # print(is_attached, is_edge) Michel_is_attached.append(is_attached) Michel_is_edge.append(is_edge) Michel_is_attached = np.array(Michel_is_attached, dtype=np.bool) Michel_is_edge = np.array(Michel_is_edge, dtype=np.bool) candidates = np.isin(Michel_pred_clusters, Michel_pred_clusters_id[Michel_is_edge & Michel_is_attached]) # Record all predicted Michel cluster for idx, Michel_id in enumerate(Michel_pred_clusters_id): current_index = Michel_pred_clusters == Michel_id fout_reco.record(('batch_id', 'iteration', 'event_idx', 'is_attached', 'is_edge', 'num_pix'), (batch_id, iteration, event_idx, Michel_is_attached[idx], Michel_is_edge[idx], np.count_nonzero(current_index))) fout_reco.write() if np.count_nonzero(candidates) == 0: continue for Michel_id in Michel_true_clusters_id: current_index = Michel_true_clusters == Michel_id distances = cdist(Michel_coords_pred[candidates], Michel_coords[current_index]) closest_clusters = Michel_pred_clusters[candidates][np.argmin(distances, axis=0)] closest_clusters_matching = closest_clusters[(closest_clusters > -1) & (np.min(distances, axis=0)<one_pixel)] if len(closest_clusters_matching) == 0: continue # Index of Michel predicted cluster that overlaps the most closest_pred_id = np.bincount(closest_clusters_matching).argmax() if closest_pred_id < 0: continue closest_Michel_pred = Michel_coords_pred[Michel_pred_clusters == closest_pred_id] # Intersection michel_pred_num_pix_true = 0 michel_pred_sum_pix_true = 0. for v in data_masked[clusters_semantics == Michel_label][current_index]: count = int(np.any(np.all(v[coords_col[0]:coords_col[1]] == closest_Michel_pred, axis=1))) michel_pred_num_pix_true += count if count > 0: michel_pred_sum_pix_true += v[-1] michel_true_num_pix = particles[Michel_id].num_voxels() #np.count_nonzero(current_index) michel_true_sum_pix = data_masked[clusters_semantics == Michel_label][current_index, 4].sum() michel_pred_num_pix = np.count_nonzero(Michel_pred_clusters == closest_pred_id) michel_pred_sum_pix = data_pred[predictions == Michel_label][Michel_pred_clusters == closest_pred_id, 4].sum() michel_true_energy = particles[Michel_id].energy_init() michel_true_num_pix_cluster = np.count_nonzero(current_index) # for v in Michel_coords[Michel_true_clusters == closest_true_id]: # if (v == data[label == Michel_label, coords_col[0]:coords_col[1]]).all(axis=1).any(): # michel_true_num_pix_cluster += 1 # Record every predicted Michel cluster in CSV fout_matched.record(('batch_id', 'iteration', 'event_idx', 'pred_num_pix', 'pred_sum_pix', 'pred_num_pix_true', 'pred_sum_pix_true', 'true_num_pix', 'true_sum_pix', 'is_attached', 'is_edge', 'michel_true_energy', 'true_num_pix_cluster'), (batch_id, iteration, event_idx, michel_pred_num_pix, michel_pred_sum_pix, michel_pred_num_pix_true, michel_pred_sum_pix_true, michel_true_num_pix, michel_true_sum_pix, is_attached, is_edge, michel_true_energy, michel_true_num_pix_cluster)) fout_matched.write() if not store_per_iteration: fout_reco.close() fout_true.close() fout_matched.close() if store_per_iteration: fout_reco.close() fout_true.close() fout_matched.close()