# Source code for mlreco.post_processing.metrics.vertex_metrics

import numpy as np
from scipy.spatial.distance import cdist
from sklearn.cluster import DBSCAN

from mlreco.post_processing import post_processing
from mlreco.utils.metrics import *
from mlreco.utils.gnn.network import get_fragment_edges
from mlreco.utils.gnn.cluster import get_cluster_label, get_momenta_label
from mlreco.utils.vertex import predict_vertex, get_vertex


@post_processing(['vertex-candidates', 'vertex-distances', 'vertex-distances-others', 'vertex-distances-primaries'],
                 ['input_data', 'seg_label', 'clust_data', 'particles', 'kinematics'],
                 ['node_pred_vtx', 'inter_particles', 'inter_group_pred', 'particles_seg'])
def vertex_metrics(cfg, module_cfg, data_blob, res, logdir, iteration,
                   seg_label=None, clust_data=None, particles=None, kinematics=None,
                   node_pred_vtx=None, inter_particles=None, inter_group_pred=None,
                   data_idx=None, input_data=None, particles_seg=None, **kwargs):
    """
    Collect per-interaction vertex reconstruction metrics for one entry.

    For every predicted interaction id found in ``inter_group_pred[data_idx]``
    this calls ``predict_vertex`` to obtain a vertex candidate, records the
    distance lists it returns, and (when the interaction contains at least
    one predicted primary particle) writes one summary row comparing the
    candidate with the label vertex averaged over the clusters flagged as
    vertex positives.

    Configuration
    -------------
    Read from ``module_cfg`` (default in parentheses):

    - ``spatial_size`` (768): label vertex coordinates are divided by it;
      clusters whose normalized label vertex lies outside [-1, 1] in any
      coordinate are ignored when averaging the label vertex.
    - ``vtx_col`` (9): first of three consecutive ``kinematics`` columns
      holding the label vertex coordinates.
    - ``vtx_positives_col`` (12): ``kinematics`` column whose per-cluster
      maximum is used as the "vertex positive" flag.
    - ``nu_col`` (8): ``clust_data`` column used for the ``is_nu`` majority
      vote over primary clusters.
    - ``min_overlap_count`` (10): a true interaction is matched only if it
      shares strictly more than this many voxel indices.
    - ``primary_label`` (1), ``track_label`` (1), ``shower_label`` (0),
      ``coords_col`` ((1, 4)), ``attaching_threshold`` (10),
      ``inter_threshold`` (20), ``other_primaries_threshold`` (10),
      ``other_primaries_gamma_threshold`` (100), ``pca_radius`` (28),
      ``min_voxel_count`` (10): forwarded to ``predict_vertex``
      (``primary_label`` is also used here to select primary particles).

    Returns
    -------
    list of 4 tuples ``(names, values)``
        In order: vertex candidate rows, distances, distances to others,
        distances to primaries. ``names`` and ``values`` are parallel lists
        of tuples (one per output row). Presumably each pair feeds the
        output of the same position declared in the decorator -- confirm
        against ``post_processing``.

    Notes
    -----
    - When ``predict_vertex`` returns no distances to others / primaries,
      the corresponding min/max columns are filled with the sentinel -1.
    - True interaction ids are read from column 7 of ``clust_data``.
    - Progress information is printed to stdout.
    """
    spatial_size = module_cfg.get('spatial_size', 768)
    vtx_col = module_cfg.get('vtx_col', 9)
    vtx_positives_col = module_cfg.get('vtx_positives_col', 12)
    primary_label = module_cfg.get('primary_label', 1)
    track_label = module_cfg.get('track_label', 1)
    shower_label = module_cfg.get('shower_label', 0)
    nu_col = module_cfg.get('nu_col', 8)
    coords_col = module_cfg.get('coords_col', (1, 4))
    attaching_threshold = module_cfg.get('attaching_threshold', 10)
    inter_threshold = module_cfg.get('inter_threshold', 20)
    other_primaries_threshold = module_cfg.get('other_primaries_threshold', 10)
    other_primaries_gamma_threshold = module_cfg.get('other_primaries_gamma_threshold', 100)
    min_overlap_count = module_cfg.get('min_overlap_count', 10)
    pca_radius = module_cfg.get('pca_radius', 28)
    min_voxel_count = module_cfg.get('min_voxel_count', 10)

    node_pred_vtx = node_pred_vtx[data_idx]
    clusts = inter_particles[data_idx]
    inter_ids = inter_group_pred[data_idx]

    # Label vertex of each cluster, normalized by the spatial size.
    node_assn_vtx = np.stack(
        [get_cluster_label(kinematics[data_idx], clusts, column=vtx_col + axis)
         for axis in range(3)],
        axis=1)
    node_assn_vtx = node_assn_vtx / spatial_size

    # Keep only clusters whose label vertex falls inside the normalized volume.
    good_index = np.all(np.abs(node_assn_vtx) <= 1., axis=1)

    # A cluster is a vertex positive if any of its voxels is flagged.
    positives = np.array(
        [kinematics[data_idx][c, vtx_positives_col].max().item() for c in clusts])

    node_assn_vtx = node_assn_vtx[good_index]
    positive_mask = positives[good_index].astype(bool)

    candidate_columns = (
        "inter_id", "num_ppn_candidates", "num_ppn_associations",
        "vtx_resolution",
        "vtx_candidate_x", "vtx_candidate_y", "vtx_candidate_z",
        "vtx_true_x", "vtx_true_y", "vtx_true_z",
        "vtx_std_x", "vtx_std_y", "vtx_std_z",
        "num_primaries", "is_nu", "num_pix", "overlap",
        "distance_others_min", "distance_others_max",
        "distance_primaries_min", "distance_primaries_max",
        "matched_inter_id")

    row_candidates_names, row_candidates_values = [], []
    row_distances_names, row_distances_values = [], []
    row_distances_others_names, row_distances_others_values = [], []
    row_distances_primaries_names, row_distances_primaries_values = [], []

    # Look at each predicted interaction
    for inter_idx in np.unique(inter_ids):
        (ppn_candidates, c_candidates, vtx_candidate, vtx_std,
         ppn_candidates_old, distances, distances_others,
         distances_primaries) = predict_vertex(
            inter_idx, data_idx, input_data, res,
            coords_col=coords_col,
            primary_label=primary_label,
            shower_label=shower_label,
            track_label=track_label,
            attaching_threshold=attaching_threshold,
            inter_threshold=inter_threshold,
            return_distances=True,
            other_primaries_threshold=other_primaries_threshold,
            other_primaries_gamma_threshold=other_primaries_gamma_threshold,
            pca_radius=pca_radius,
            min_voxel_count=min_voxel_count)

        inter_mask = inter_ids == inter_idx
        interaction = clusts[inter_mask]
        # Score columns start at index 3 of the node predictions.
        primary_particles = np.argmax(node_pred_vtx[inter_mask][:, 3:], axis=1) == primary_label

        if len(distances):
            for x in np.hstack(distances):
                row_distances_names.append(("d",))
                row_distances_values.append((x,))

        if len(distances_others):
            distances_others = np.hstack(distances_others)
            row_distances_others_names.extend([("d",) for _ in distances_others])
            row_distances_others_values.extend([(x,) for x in distances_others])
        else:
            # Deterministic sentinel so the min/max columns below never
            # contain uninitialized memory.
            distances_others = np.full((1,), -1.)

        if len(distances_primaries):
            row_distances_primaries_names.extend([('d',) for _ in distances_primaries])
            row_distances_primaries_values.extend([(x,) for x in distances_primaries])
        else:
            distances_primaries = np.full((1,), -1.)

        # No primary particle in interaction
        primary_clusts = interaction[primary_particles]
        if len(primary_clusts) == 0:
            continue

        # Majority vote of the nu label over primary clusters (-1 if none valid)
        is_nu = get_cluster_label(clust_data[data_idx], primary_clusts, column=nu_col)
        is_nu = is_nu[is_nu > -1]
        is_nu = np.argmax(np.bincount(is_nu.astype(int))) if len(is_nu) else -1

        inter_voxels = np.hstack(interaction)
        print('Predicted interaction ', inter_idx, len(inter_voxels), is_nu, [len(x) for x in interaction])

        # Defaults recorded when no vertex candidate is available
        vtx_resolution = -1
        clust_assn_vtx = [-1, -1, -1]
        max_overlap = 0
        matched_inter_idx = -1

        if len(ppn_candidates):
            # Label vertex: mean over the positive clusters of this interaction
            clust_assn_vtx = node_assn_vtx[positive_mask & (inter_mask[good_index])] * spatial_size
            clust_assn_vtx = clust_assn_vtx.mean(axis=0)
            print('Clust assn vtx', clust_assn_vtx)

            # Match to the true interaction sharing the most voxels
            true_inter_ids = clust_data[data_idx][:, 7]
            for true_inter_idx in np.unique(true_inter_ids):
                if true_inter_idx < 0:
                    continue
                true_inter_mask = true_inter_ids == true_inter_idx
                intersection = np.intersect1d(inter_voxels, np.where(true_inter_mask)[0])
                print('check ', inter_idx, true_inter_idx, intersection)
                if len(intersection) > max_overlap and len(intersection) > min_overlap_count:
                    max_overlap = len(intersection)
                    matched_inter_idx = true_inter_idx
            print('We have candidates and overlap is ', max_overlap)

            if matched_inter_idx > -1:
                vtx = get_vertex(kinematics, clust_data, data_idx, matched_inter_idx, vtx_col=vtx_col)
                print("Vertex candidate and true = ", vtx_candidate, clust_assn_vtx, vtx)
            else:
                print("Vertex candidate and true = ", vtx_candidate, clust_assn_vtx, None)

            vtx_resolution = np.linalg.norm(clust_assn_vtx - vtx_candidate)
            print("resolution = ", vtx_resolution)

        row_candidates_names.append(candidate_columns)
        row_candidates_values.append((
            inter_idx,
            len(ppn_candidates),
            len(ppn_candidates_old),
            vtx_resolution,
            vtx_candidate[0], vtx_candidate[1], vtx_candidate[2],
            clust_assn_vtx[0], clust_assn_vtx[1], clust_assn_vtx[2],
            vtx_std[0], vtx_std[1], vtx_std[2],
            np.sum(primary_particles),
            is_nu,
            np.sum([len(c) for c in primary_clusts]),
            max_overlap,
            np.amin(distances_others), np.amax(distances_others),
            np.amin(distances_primaries), np.amax(distances_primaries),
            matched_inter_idx))

    return [(row_candidates_names, row_candidates_values),
            (row_distances_names, row_distances_values),
            (row_distances_others_names, row_distances_others_values),
            (row_distances_primaries_names, row_distances_primaries_values)]