Source code for mlreco.post_processing.metrics.evidential_gnn

import numpy as np
import os
from mlreco.utils import CSVData
from scipy.special import softmax as softmax_func
from scipy.stats import entropy

from mlreco.utils.gnn.cluster import get_cluster_label, get_momenta_label


def default_gnn_metrics(cfg, processsor_cfg, data_blob, result, logdir, iteration):
    """Write per-node classification metrics of a GNN to a CSV file.

    For every batch entry, the node logits in ``result['node_pred_type']``
    are turned into probabilities with a softmax, and one CSV row is written
    for each node whose label (as returned by ``get_cluster_label``) is
    greater than -1.

    Parameters
    ----------
    cfg, processsor_cfg
        Not used by this function.
    data_blob : dict
        Must provide ``'cluster_label'`` and ``'index'``, both indexed by
        batch id.
    result : dict
        Must provide ``'clusts'``, ``'node_pred'`` (only its length is
        checked) and ``'node_pred_type'`` (per-batch 2D logits with at least
        five columns).
    logdir : str
        Directory in which ``default-gnn-metrics.csv`` is written.
    iteration : int
        When truthy, the CSV file is opened in append mode.
    """
    cluster_labels = data_blob['cluster_label']
    clusters = result['clusts']
    event_index = data_blob['index']

    # Sanity check: one node prediction entry per batch entry.
    n_batches = len(clusters)
    assert n_batches == len(result['node_pred'])

    # Any iteration past the first appends to the existing file.
    csv_path = os.path.join(logdir, 'default-gnn-metrics.csv')
    fout = CSVData(csv_path, append=bool(iteration))

    columns = ('Index', 'Truth', 'Prediction', 'Entropy',
               'p0', 'p1', 'p2', 'p3', 'p4')

    for batch_id, logits in enumerate(result['node_pred_type']):
        # NOTE(review): column 9 is presumably the particle type column of
        # the cluster label tensor -- confirm against get_cluster_label.
        node_truth = get_cluster_label(cluster_labels[batch_id],
                                       clusters[batch_id],
                                       column=9)

        # Indices of the nodes that carry a label greater than -1.
        keep = np.nonzero(node_truth > -1)[0]

        scores = softmax_func(logits, axis=1)
        node_entropy = entropy(scores, axis=1)

        kept_truth = node_truth[keep]
        kept_pred = np.argmax(logits[keep], axis=1)
        kept_entropy = node_entropy[keep]
        kept_scores = scores[keep]

        event_id = int(event_index[batch_id])
        rows = zip(kept_truth, kept_pred, kept_entropy, kept_scores)
        for truth, pred, ent, prob in rows:
            fout.record(columns,
                        (event_id, int(truth), int(pred), ent,
                         prob[0], prob[1], prob[2], prob[3], prob[4]))
            fout.write()

    fout.close()
def evidential_gnn_metrics(cfg, processor_cfg, data_blob, result, logdir, iteration):
    """Write per-node metrics of an evidential GNN classifier to a CSV file.

    For every batch entry, the per-class evidence in
    ``result['node_pred_type']`` is converted as follows::

        concentration = evidence + 1
        S             = sum(concentration) over classes
        uncertainty   = number_of_classes / S
        p             = concentration / S

    One CSV row is written for each node whose label (as returned by
    ``get_cluster_label``) is greater than -1. Batch entries without any
    such node are skipped.

    Parameters
    ----------
    cfg, processor_cfg
        Not used by this function.
    data_blob : dict
        Must provide ``'cluster_label'`` and ``'index'``, both indexed by
        batch id.
    result : dict
        Must provide ``'clusts'``, ``'node_pred'`` (only its length is
        checked) and ``'node_pred_type'`` (per-batch 2D evidence arrays with
        at least five columns).
    logdir : str
        Directory in which ``evidential-gnn-metrics.csv`` is written.
    iteration : int
        When truthy, the CSV file is opened in append mode.
    """
    clust_label = data_blob['cluster_label']
    clusts = result['clusts']
    index = data_blob['index']

    # Sanity check: predictions, clusters and labels cover the same batches.
    num_batches = len(clusts)
    assert num_batches == len(result['node_pred']) == len(clust_label)

    # Any iteration past the first appends to the existing file.
    fout = CSVData(
        os.path.join(logdir, 'evidential-gnn-metrics.csv'),
        append=bool(iteration))

    # The column list previously also contained a 'loss' entry for which no
    # value was ever supplied (12 names vs. 11 values), so names and values
    # did not line up. No loss is available here, hence the name is dropped.
    columns = ('Index', 'Truth', 'Prediction', 'Entropy', 'Uncertainty',
               'Strength', 'p0', 'p1', 'p2', 'p3', 'p4')

    for batch_id, evidence in enumerate(result['node_pred_type']):
        labels_batch = clust_label[batch_id]
        # NOTE(review): column 9 is presumably the particle type column of
        # the cluster label tensor -- confirm against get_cluster_label.
        event_particle_labels = get_cluster_label(
            labels_batch, clusts[batch_id], column=9)

        concentration = evidence + 1.0
        S = np.sum(concentration, axis=1)
        uncertainty = evidence.shape[1] / S
        p = concentration / S.reshape(-1, 1)

        # Indices of the nodes that carry a label greater than -1.
        valid = np.nonzero(event_particle_labels > -1)[0]
        num_valid = valid.shape[0]
        if num_valid < 1:
            continue

        p_valid = p[valid]
        truth_valid = event_particle_labels[valid]
        pred_valid = np.argmax(p_valid, axis=1)
        entropy_valid = entropy(p_valid, axis=1)
        uncertainty_valid = uncertainty[valid]
        # Selected once per batch entry instead of once per row.
        strength_valid = S[valid]

        event_id = int(index[batch_id])
        for i in range(num_valid):
            probs = p_valid[i]
            fout.record(columns,
                        (event_id,
                         int(truth_valid[i]),
                         int(pred_valid[i]),
                         entropy_valid[i],
                         uncertainty_valid[i],
                         strength_valid[i],
                         probs[0], probs[1], probs[2], probs[3], probs[4]))
            fout.write()

    fout.close()