Source code for mlreco.post_processing.store.store_input
import os
from mlreco.utils import CSVData
def get_coords(row, data_dim, tree_index):
    """
    Build the CSV column labels and values identifying one point.

    Parameters
    ----------
    row: sequence
        Indexable record; its first ``data_dim`` entries are used as the
        spatial coordinates.
    data_dim: int
        Number of spatial dimensions, must be 2 or 3.
    tree_index: int
        Event identifier, stored under the ``idx`` label.

    Returns
    -------
    tuple
        ``(coords_labels, coords)``: a tuple of column names
        (``'idx', 'x', 'y'[, 'z']``) and the matching tuple of values.

    Raises
    ------
    ValueError
        If ``data_dim`` is neither 2 nor 3.
    """
    if data_dim == 2:
        coords_labels = ('idx', 'x', 'y')
        coords = (tree_index, row[0], row[1])
    elif data_dim == 3:
        coords_labels = ('idx', 'x', 'y', 'z')
        coords = (tree_index, row[0], row[1], row[2])
    else:
        # %r rather than %d: a non-numeric data_dim (e.g. a string read from
        # a config file) must still produce this message, not a TypeError
        # raised by the formatting itself.
        raise ValueError("data_dim must be 2 or 3, got %r" % (data_dim,))
    return coords_labels, coords
def _write_event_records(fout, data_index, tree_index, data_dim, threshold,
                         input_dat, label_ppn, label_seg, label_cls,
                         label_mcst):
    """Record on ``fout`` every CSV row belonging to one event of the batch."""
    labels_3d = ('idx', 'x', 'y', 'z', 'type', 'value')

    def write_nd(rows, type_id):
        # Dimension-aware rows: coordinates come from get_coords and the
        # value is read at column data_dim + 1.
        # NOTE(review): presumably column data_dim holds the batch id - confirm.
        for row in rows:
            coords_labels, coords = get_coords(row, data_dim, tree_index)
            fout.record(coords_labels + ('type', 'value'),
                        coords + (type_id, row[data_dim + 1]))
            fout.write()

    def write_3d(rows, type_id, value_col):
        # These label tensors are always written with x, y, z columns,
        # regardless of data_dim.
        # NOTE(review): looks 3D-only; confirm behaviour for data_dim == 2.
        for row in rows:
            fout.record(labels_3d,
                        (tree_index, row[0], row[1], row[2],
                         type_id, row[value_col]))
            fout.write()

    # Keep only the input rows whose last column exceeds the threshold.
    mask = input_dat[data_index][:, -1] > threshold

    # type 0 = input data
    write_nd(input_dat[data_index][mask], 0)

    # type 1 = Labels for PPN
    if label_ppn is not None:
        write_3d(label_ppn[data_index], 1, 4)

    # type 2 = UResNet labels (filtered with the same mask as the input)
    if label_seg is not None:
        write_nd(label_seg[data_index][mask], 2)

    # type 15 = group id, 16 = semantic labels, 17 = energy
    if label_cls is not None:
        write_3d(label_cls[data_index], 15, 5)
        write_3d(label_cls[data_index], 16, 6)
        write_3d(label_cls[data_index], 17, 4)

    # type 19 = cluster3d_mcst_true
    if label_mcst is not None:
        write_3d(label_mcst[data_index], 19, 4)


def store_input(cfg, data_blob, res, logdir, iteration):
    """
    Store input data blob.

    Writes the input data and the available label tensors to CSV, one line
    per point, tagged with a numeric ``type`` column. Returns without
    writing anything if the input data key is absent from ``data_blob`` or
    its value is None.

    Parameters
    ----------
    cfg: dict
        Full configuration; ``cfg['post_processing']['store_input']`` holds
        the options listed below (it may be None, in which case every
        option takes its default).
    data_blob: dict
        Batch data, looked up by the keys configured below. Its ``index``
        entry provides the event identifiers.
    res: dict
        Unused by this function.
    logdir: str
        Directory in which the CSV files are created.
    iteration: int
        Iteration number, used in the per-iteration file name.

    Configuration
    -------------
    threshold: float, optional
        Default: 0. Input rows are kept when their last column is strictly
        greater than this value.
    data_dim: int, optional
        Default: 3. Number of spatial dimensions (2 or 3).
    input_data: str, optional
    particles_label: str, optional
    segment_label: str, optional
    clusters_label: str, optional
    cluster3d_mcst_true: str, optional
        Names of the corresponding entries in ``data_blob``; each defaults
        to its own name.
    store_method: str, optional
        Can be `per-iteration` or `per-event`. Default: `per-iteration`.

    Raises
    ------
    ValueError
        If ``store_method`` is set to an unsupported value.
    """
    method_cfg = cfg['post_processing']['store_input']
    if method_cfg is None:
        # An empty section means "use the default for every option".
        method_cfg = {}

    input_key = method_cfg.get('input_data', 'input_data')
    if input_key not in data_blob:
        return

    threshold = method_cfg.get('threshold', 0.)
    data_dim = method_cfg.get('data_dim', 3)

    index = data_blob.get('index', None)
    input_dat = data_blob.get(input_key, None)
    label_ppn = data_blob.get(
        method_cfg.get('particles_label', 'particles_label'), None)
    label_seg = data_blob.get(
        method_cfg.get('segment_label', 'segment_label'), None)
    label_cls = data_blob.get(
        method_cfg.get('clusters_label', 'clusters_label'), None)
    label_mcst = data_blob.get(
        method_cfg.get('cluster3d_mcst_true', 'cluster3d_mcst_true'), None)

    # Validate explicitly: an assert would be stripped under ``python -O``
    # and an invalid value would then silently select per-event storage.
    store_method = method_cfg.get('store_method', None)
    if store_method is not None \
            and store_method not in ('per-iteration', 'per-event'):
        raise ValueError(
            "store_method must be 'per-iteration' or 'per-event', got %r"
            % (store_method,))
    store_per_iteration = store_method is None \
        or store_method == 'per-iteration'

    # Bail out before creating any writer so nothing is left open.
    if input_dat is None:
        return

    def write_event(fout, data_index, tree_index):
        _write_event_records(fout, data_index, tree_index, data_dim,
                             threshold, input_dat, label_ppn, label_seg,
                             label_cls, label_mcst)

    if store_per_iteration:
        # One file shared by every event of this iteration.
        fout = CSVData(
            os.path.join(logdir, 'input-iter-%07d.csv' % iteration))
        try:
            for data_index, tree_index in enumerate(index):
                write_event(fout, data_index, tree_index)
        finally:
            fout.close()
    else:
        # One file per event, named after the event identifier.
        for data_index, tree_index in enumerate(index):
            fout = CSVData(
                os.path.join(logdir, 'input-event-%07d.csv' % tree_index))
            try:
                write_event(fout, data_index, tree_index)
            finally:
                fout.close()