# Source code for mlreco.post_processing.metrics.single_particle

import numpy as np
import pandas as pd
import sys, os, re

from mlreco.post_processing import post_processing
from mlreco.utils import CSVData

from scipy.special import softmax
from scipy.stats import entropy

import torch

def single_particle(cfg, processor_cfg, data_blob, result, logdir, iteration):
    """Record per-event single-particle classification metrics to a CSV file.

    For every event in the batch, the class logits are converted to softmax
    probabilities and one row is recorded with: the event identifier, the
    true label, the predicted class (argmax of the logits), one probability
    column per class (``p0``, ``p1``, ...), and the entropy of the
    probability vector as computed by ``scipy.stats.entropy``.

    Parameters
    ----------
    cfg, processor_cfg : dict
        Full and post-processor configurations. Not used by this function;
        kept for the common post-processing call signature.
    data_blob : dict
        Input data. Reads ``data_blob['label'][0][:, 0]`` as the per-event
        true class and ``data_blob['index']`` as the per-event identifiers
        written to the ``Index`` column (presumably dataset event indices --
        confirm with the caller).
    result : dict
        Network output. Reads ``result['logits'][0]``; assumed to be a 2-D
        array of shape (n_events, n_classes) whose rows line up with
        ``data_blob['index']`` -- TODO confirm.
    logdir : str
        Directory in which ``single-particle-metrics.csv`` is written.
    iteration : int
        Current iteration. A falsy value (e.g. 0) creates the writer with
        ``append=False``; any other value uses ``append=True``.
    """
    labels = data_blob['label'][0][:, 0]
    index = np.asarray(data_blob['index'])
    logits = result['logits'][0]

    # One probability column per class. For a 5-class model this is exactly
    # ('p0', 'p1', 'p2', 'p3', 'p4'), matching the historical CSV layout.
    prob_keys = tuple('p{}'.format(c) for c in range(np.shape(logits)[-1]))
    keys = ('Index', 'Truth', 'Prediction') + prob_keys + ('entropy',)

    fout = CSVData(os.path.join(logdir, 'single-particle-metrics.csv'),
                   append=bool(iteration))
    try:
        for batch_id, event_id in enumerate(index):
            logit_batch = logits[batch_id]
            probs = softmax(logit_batch)
            values = ((int(event_id),
                       int(labels[batch_id]),
                       int(np.argmax(logit_batch)))
                      + tuple(probs)
                      + (entropy(probs),))
            fout.record(keys, values)
            fout.write()
    finally:
        # Close the writer even if recording a row raises.
        fout.close()