Source code for mlreco.post_processing.decorator

from mlreco.utils import CSVData
import os
import numpy as np
from mlreco.utils.deghosting import adapt_labels_numpy as adapt_labels

from functools import wraps


def post_processing(filename, data_capture, output_capture):
    """
    Decorator to capture the common boilerplate between all postprocessing
    scripts.

    The corresponding config block should have the same name as the script.

    Parameters
    ----------
    filename: string or list of string
        Name that will prefix all log files. If a list of strings, several
        log files can be created. The order of filenames must match the
        order of the script return.
    data_capture: list of string
        List of data components needed. Some of them are reserved:
        clust_data, seg_label. The rest can be any data label from the
        config `iotool` section. This sequence is never modified.
    output_capture: list of string
        List of output components needed. Some of them are reserved:
        embeddings, margins, seediness, segmentation. The rest can be
        anything from any network output.

    Returns
    -------
    callable
        A decorator. The decorated function is replaced by a wrapper with
        signature ``(cfg, module_cfg, data_blob, res, logdir, iteration)``
        which returns None and calls the original function once per entry
        of ``data_blob['index']`` with the captured products as keyword
        arguments.

    Notes
    -----
    Keys read from ``module_cfg``:

    - ``filename``: overrides the `filename` decorator argument.
    - ``ghost``: if true, ghost masks are computed from ``res['ghost']`` and
      applied to the captured labels / predictions.
    - ``store_method``: ``'per-iteration'`` (default), ``'single-file'`` or
      ``'per-event'``. Any other value creates no log file.
    - any captured key: overrides the name looked up in ``data_blob`` or
      ``res`` for that key.

    The decorated function must return either a single ``(names, values)``
    pair or a list of such pairs, one per log file. Within a pair, ``names``
    and ``values`` are either two tuples (a single row) or two equally long
    lists of tuples (several rows).
    """
    def decorator(func):
        # This mapping is hardcoded for now...
        defaultNameToIO = {
            'clust_data': 'cluster_label',
            'seg_label': 'segment_label',
            'kinematics': 'kinematics_label',
            'points_label': 'particles_label',
            'particles': 'particles_asis'
        }
        # Outputs that are sliced per event inside the event loop instead of
        # being forwarded as a whole.
        per_event_outputs = ['embeddings', 'margins', 'seediness']  # add points?

        @wraps(func)
        def wrapper(cfg, module_cfg, data_blob, res, logdir, iteration):
            # The config block should have the same name as the analysis function
            log_name = module_cfg.get('filename', filename)
            deghosting = module_cfg.get('ghost', False)

            store_method = module_cfg.get('store_method', 'per-iteration')
            store_per_event = store_method == 'per-event'

            if not isinstance(log_name, list):
                log_name = [log_name]

            # Writers shared by all events (per-iteration / single-file modes).
            # Per-event writers are created inside the event loop.
            fout = []
            for name in log_name:
                if store_method == 'per-iteration':
                    fout.append(CSVData(os.path.join(
                        logdir, '%s-iter-%07d.csv' % (name, iteration))))
                if store_method == 'single-file':
                    # Only the very first iteration starts a fresh file.
                    fout.append(CSVData(os.path.join(logdir, '%s.csv' % name),
                                        append=bool(iteration)))

            kwargs = {}
            # Get the relevant data products - index is special, no need to
            # specify it.
            kwargs['index'] = data_blob['index']

            # We need true segmentation label for deghosting masks/adapting
            # labels. Work on a copy: `data_capture` belongs to the decorator
            # call and is shared by every invocation of the wrapper.
            capture = list(data_capture)
            if 'seg_label' not in capture:
                capture.append('seg_label')

            for key in capture:
                blob_key = module_cfg.get(key, defaultNameToIO.get(key, key))
                if blob_key in data_blob:
                    kwargs[key] = data_blob[blob_key]

            for key in output_capture:
                if key in per_event_outputs:
                    continue
                res_key = module_cfg.get(key, key)
                # An empty name in the config disables this output.
                if not len(res_key):
                    continue
                kwargs[key] = res.get(res_key, None)
                if key == 'segmentation':
                    kwargs['segmentation'] = list(res['segmentation'])
                    kwargs['seg_prediction'] = [
                        scores.argmax(axis=1) for scores in res['segmentation']]

            if deghosting:
                ghost_mask = [scores.argmax(axis=1) == 0 for scores in res['ghost']]
                kwargs['ghost_mask'] = ghost_mask
                # NOTE(review): presumably the last column holds the semantic
                # class and values >= 5 denote ghosts - confirm.
                kwargs['true_ghost_mask'] = [
                    label[:, -1] < 5 for label in kwargs['seg_label']]

                if kwargs.get('clust_data') is not None:
                    # Save the clust_data before deghosting
                    kwargs['clust_data_noghost'] = kwargs['clust_data']
                    kwargs['clust_data'] = adapt_labels(
                        res, kwargs['seg_label'], kwargs['clust_data'])
                for key in ('seg_prediction', 'segmentation'):
                    if kwargs.get(key) is not None:
                        kwargs[key] = [
                            values[mask]
                            for values, mask in zip(kwargs[key], ghost_mask)]
                if kwargs.get('kinematics') is not None:
                    kwargs['kinematics'] = adapt_labels(
                        res, kwargs['seg_label'], kwargs['kinematics'])
                # This needs to come last - in adapt_labels seg_label is the
                # original one
                if kwargs.get('seg_label') is not None:
                    kwargs['seg_label_noghost'] = kwargs['seg_label']
                    kwargs['seg_label'] = [
                        label[mask]
                        for label, mask in zip(kwargs['seg_label'], ghost_mask)]

            # One batch id per row of the first available label product.
            label_key = next(
                (k for k in ('seg_label', 'kinematics', 'clust_data') if k in kwargs),
                None)
            batch_ids = []
            for data_idx in range(len(kwargs['index'])):
                if label_key is None:
                    raise Exception('Need some labels to run postprocessing')
                n = kwargs[label_key][data_idx].shape[0]
                batch_ids.append(np.ones((n,)) * data_idx)
            batch_ids = np.hstack(batch_ids)
            kwargs['batch_ids'] = batch_ids

            # Loop over events
            counter = 0
            try:
                for data_idx, tree_idx in enumerate(kwargs['index']):
                    kwargs['counter'] = counter
                    kwargs['data_idx'] = data_idx

                    # Initialize log if one per event. A fresh list is used
                    # for each event so that the number of writers always
                    # matches the number of outputs returned by `func`.
                    if store_per_event:
                        writers = [
                            CSVData(os.path.join(
                                logdir, '%s-event-%07d.csv' % (name, tree_idx)))
                            for name in log_name]
                    else:
                        writers = fout

                    try:
                        for key in per_event_outputs:
                            if key in output_capture:
                                kwargs[key] = np.array(res[key])[batch_ids == data_idx]

                        out = func(cfg, module_cfg, data_blob, res, logdir,
                                   iteration, **kwargs)
                        # A single (names, values) pair means a single log file.
                        if isinstance(out, tuple):
                            out = [out]
                        assert len(out) == len(writers)

                        for writer, (out_names, out_values) in zip(writers, out):
                            assert len(out_names) == len(out_values)

                            # A pair of tuples is a single row.
                            if isinstance(out_names, tuple):
                                assert isinstance(out_values, tuple)
                                out_names = [out_names]
                                out_values = [out_values]

                            for row_names, row_values in zip(out_names, out_values):
                                if len(row_names) and len(row_values):
                                    row_names = ('Iteration', 'Index',) + row_names
                                    row_values = (iteration, tree_idx,) + row_values

                                    writer.record(row_names, row_values)
                                    writer.write()
                            # Advance once per output whose first row is not
                            # empty.
                            counter += 1 if len(out_names) and len(out_names[0]) else 0
                    finally:
                        if store_per_event:
                            for f in writers:
                                f.close()
            finally:
                if not store_per_event:
                    for f in fout:
                        f.close()

        return wrapper
    return decorator