Source code for mlreco.post_processing.metrics.duq_metrics

import numpy as np
import os
from mlreco.utils import CSVData
from scipy.stats import entropy


def duq_metrics(cfg, processor_cfg, data_blob, result, logdir, iteration):
    """Write per-event DUQ classifier metrics to a CSV file.

    For every entry of ``data_blob['index']`` one row is written to
    ``<logdir>/duq-singlep-metrics.csv`` with columns:

    * ``Index`` - the event id from ``data_blob['index']``
    * ``Truth`` - first column of ``data_blob['label'][0]``
    * ``Prediction`` - argmax of the score
    * ``p0``, ``p1``, ... - score normalized per event (one column per score column)
    * ``uncertainty`` - Euclidean distance of the event embedding to the
      centroid of its predicted class
    * ``entropy`` - entropy of the normalized score
    * ``x``, ``y`` - 2-D UMAP projection of the embedding of the predicted class

    The centroids are also saved with ``np.save`` to ``<logdir>/centroids``
    and printed to stdout.

    Parameters
    ----------
    cfg, processor_cfg
        Not used by this function; kept for the post-processor call signature.
    data_blob : dict
        Must provide ``'label'`` and ``'index'``.
    result : dict
        Must provide ``'score'``, ``'embedding'`` and ``'centroids'``; only
        element ``[0]`` of each is used. Assumed shapes (TODO confirm):
        score ``(N, C)``, embedding ``(N, D, C)``, centroids reshapeable to
        ``(D, C)``.
    logdir : str
        Directory receiving the CSV and centroid files.
    iteration
        A falsy value starts a new CSV file; otherwise rows are appended.
    """
    # Function-scope import: the module can be imported without umap installed.
    import umap

    labels = data_blob['label'][0][:, 0]
    index = data_blob['index']
    score = result['score'][0]
    embedding = result['embedding'][0]
    centroids = result['centroids'][0]

    num_events = score.shape[0]
    # Derive class counts from the data instead of hard-coding 5 classes.
    num_embed_classes = embedding.shape[2]
    num_score_columns = score.shape[1]

    pred = np.argmax(score, axis=1)
    # Small epsilon keeps every entry strictly positive before normalizing.
    shifted = score + 1e-6
    probability = shifted / np.sum(shifted, axis=1, keepdims=True)

    # Distance of each event's embedding to every class centroid, then keep
    # only the distance to the predicted class.
    distances = np.linalg.norm(
        centroids.reshape(1, -1, num_embed_classes) - embedding, axis=1)
    uncertainty = distances[np.arange(num_events), pred]

    np.save(os.path.join(logdir, 'centroids'), centroids)
    print(centroids)
    pred_entropy = entropy(probability, axis=1)

    # Fit a separate 2-D UMAP on each class slice of the embedding, then keep,
    # for every event, the projection belonging to its predicted class.
    latent = np.zeros((embedding.shape[0], 2, num_embed_classes))
    for c in range(num_embed_classes):
        latent[:, :, c] = umap.UMAP(n_components=2).fit_transform(
            embedding[:, :, c])
    latent = latent[np.arange(embedding.shape[0]), :, pred]  # -> (N, 2)

    columns = (('Index', 'Truth', 'Prediction')
               + tuple('p%d' % i for i in range(num_score_columns))
               + ('uncertainty', 'entropy', 'x', 'y'))

    fout = CSVData(os.path.join(logdir, 'duq-singlep-metrics.csv'),
                   append=bool(iteration))
    try:
        for batch_id, event_id in enumerate(index):
            values = ((int(event_id), int(labels[batch_id]),
                       int(pred[batch_id]))
                      + tuple(probability[batch_id])
                      + (uncertainty[batch_id], pred_entropy[batch_id],
                         latent[batch_id][0], latent[batch_id][1]))
            fout.record(columns, values)
            fout.write()
    finally:
        # Always release the CSV handle, even if a row fails to record.
        fout.close()