# Source code for mlreco.post_processing.metrics.bayes_segnet_mcdropout

import numpy as np
import pandas as pd
import os

from mlreco.utils import CSVData
from mlreco.utils import CSVData, ChunkCSVData

from scipy.special import softmax as softmax_func
from scipy.stats import entropy

def bayes_segnet_mcdropout(cfg, processor_cfg, data_blob, result, logdir, iteration):
    """Log per-event uncertainty metrics and sampled per-voxel rows to CSV.

    For every entry of ``data_blob['index']`` one row is written to
    ``<logdir>/bayes-segnet-metrics.csv`` with columns ``Index``,
    ``Mean Entropy``, ``Median Entropy`` and ``Accuracy``.  In addition, for
    each true class present in the event, per-voxel rows (``Index``,
    ``Truth``, ``Prediction``, one ``p<k>`` probability column per class,
    ``Entropy``) are written to ``<logdir>/bayes-segnet-metrics-voxels.csv``.
    A class with fewer than ``processor_cfg['min_samples']`` voxels is written
    in full; otherwise a random sample of ``ceil(5%)`` of its voxels is
    written.

    Parameters
    ----------
    cfg
        Full configuration. Not used by this function.
    processor_cfg
        Post-processor configuration. Reads ``'mode'`` and ``'min_samples'``.
        When ``'mode'`` is ``'mc_dropout'`` the class probabilities are read
        from ``result['softmax'][0]``; otherwise a softmax over axis 1 is
        applied to ``result['segmentation'][0]``.
    data_blob
        Input batch. Reads ``'segment_label'`` (column 0 is compared against
        the batch id, the last column is the true class), ``'index'`` (one
        event id per batch entry) and ``'input_data'`` (column 0 is the batch
        id of each input row).
    result
        Network output. Reads ``'segmentation'`` (per-row class scores; the
        prediction is their argmax over axis 1) and, in ``'mc_dropout'``
        mode, ``'softmax'``.
    logdir
        Directory in which both CSV files are created.
    iteration
        ``bool(iteration)`` is passed as ``append`` to both CSV writers, so
        the first iteration (0) does not append.
    """
    labels = data_blob['segment_label'][0]
    segmentation = result['segmentation'][0]
    if processor_cfg['mode'] != 'mc_dropout':
        softmax = softmax_func(segmentation, axis=1)
    else:
        softmax = result['softmax'][0]
    pred = np.argmax(segmentation, axis=1)

    index = np.asarray(data_blob['index'])
    batch_index = data_blob['input_data'][0][:, 0].astype(int)
    # Batch id of every label row; computed once rather than per event.
    label_batch_index = labels[:, 0].astype(int)

    min_samples = processor_cfg['min_samples']
    sample_fraction = 0.05
    append = bool(iteration)

    # One probability column per class instead of a hard-coded p0..p4, so a
    # model with a different number of classes no longer breaks the
    # DataFrame construction. Identical column names for 5 classes.
    columns = (['Index', 'Truth', 'Prediction']
               + ['p{}'.format(k) for k in range(softmax.shape[1])]
               + ['Entropy'])

    fout = CSVData(
        os.path.join(logdir, 'bayes-segnet-metrics.csv'), append=append)
    fout_voxel = ChunkCSVData(
        os.path.join(logdir, 'bayes-segnet-metrics-voxels.csv'), append=append)

    # Close the metrics writer even if processing an event raises.
    try:
        for batch_id, event_id in enumerate(index):
            batch_mask = batch_index == batch_id
            label_batch = labels[label_batch_index == batch_id][:, -1].astype(int)
            pred_batch = pred[batch_mask]
            softmax_batch = softmax[batch_mask]
            entropy_batch = entropy(softmax_batch, axis=1)

            # NOTE(review): label rows and input rows of this batch are joined
            # side by side, which assumes both hold the same voxels in the same
            # order — confirm upstream.
            table = np.concatenate([np.ones((label_batch.shape[0], 1)) * event_id,
                                    label_batch.reshape(-1, 1),
                                    pred_batch.reshape(-1, 1),
                                    softmax_batch,
                                    entropy_batch.reshape(-1, 1)], axis=1)
            df = pd.DataFrame(table, columns=columns)

            avg_entropy = df['Entropy'].mean()
            median_entropy = df['Entropy'].median()
            accuracy = (df['Truth'] == df['Prediction']).mean()

            fout.record(('Index', 'Mean Entropy', 'Median Entropy', 'Accuracy'),
                        (int(event_id), avg_entropy, median_entropy, accuracy))
            fout.write()

            for c in np.unique(label_batch):
                df_slice = df[df['Truth'] == c]
                num_total_voxels = df_slice.shape[0]
                if min_samples > num_total_voxels:
                    samples = df_slice
                else:
                    # NOTE(review): once a class has >= min_samples voxels only
                    # ceil(5%) of them are kept, which can be fewer than
                    # min_samples — confirm this is the intended rule.
                    num_samples = int(np.ceil(num_total_voxels * sample_fraction))
                    samples = df_slice.sample(num_samples)
                fout_voxel.record(samples)
    finally:
        fout.close()