Source code for mlreco.post_processing.store.store_uresnet

import numpy as np
import scipy
import os
from mlreco.utils import CSVData

[docs]def store_uresnet(cfg, data_blob, res, logdir, iteration): # UResNet prediction if not 'segmentation' in res: return method_cfg = cfg['post_processing']['store_uresnet'] index = data_blob['index'] segment_data = res['segmentation'] input_data = data_blob.get('input_data' if method_cfg is None else method_cfg.get('input_data', 'input_data'), None) store_per_iteration = True if method_cfg is not None and method_cfg.get('store_method',None) is not None: assert(method_cfg['store_method'] in ['per-iteration','per-event']) store_per_iteration = method_cfg['store_method'] == 'per-iteration' fout=None if store_per_iteration: fout=CSVData(os.path.join(logdir, 'uresnet-segmentation-iter-%07d.csv' % iteration)) for data_idx, tree_idx in enumerate(index): if not store_per_iteration: fout=CSVData(os.path.join(logdir, 'uresnet-segmentation-event-%07d.csv' % tree_idx)) predictions = np.argmax(segment[data_idx],axis=1) for row in predictions: event = input_data[i] fout.record(('idx','x', 'y', 'z', 'type', 'value'), (idx,event[0], event[1], event[2], 4, row)) fout.write() if not store_per_iteration: fout.close() if store_per_iteration: fout.close()