Source code for mlreco.iotools.datasets

import os, glob, inspect
import numpy as np
from torch.utils.data import Dataset
import mlreco.iotools.parsers

[docs]class LArCVDataset(Dataset): """ A generic interface for LArCV data files. This Dataset is designed to produce a batch of arbitrary number of data chunks (e.g. input data matrix, segmentation label, point proposal target, clustering labels, etc.). Each data chunk is processed by parser functions defined in the iotools.parsers module. LArCVDataset object can be configured with arbitrary number of parser functions where each function can take arbitrary number of LArCV event data objects. The assumption is that each data chunk respects the LArCV event boundary. """
[docs] def __init__(self, data_schema, data_keys, limit_num_files=0, limit_num_samples=0, event_list=None, skip_event_list=None): """ Instantiates the LArCVDataset. Parameters ---------- data_schema : dict A dictionary of (string, dictionary) pairs. The key is a unique name of a data chunk in a batch and the associated dictionary must include: - parser: name of the parser - args: (key, value) pairs that correspond to parser argument names and their values The nested dictionaries can replaced be lists, in which case they will be considered as parser argument values, in order. data_keys : list a list of strings that is required to be present in the file paths limit_num_files : int an integer limiting number of files to be taken per data directory limit_num_samples : int an integer limiting number of samples to be taken per data event_list : list a list of integers to specify which event (ttree index) to process skip_event_list : list a list of integers to specify which events (ttree index) to skip """ # Create file list self._files = [] for key in data_keys: fs = glob.glob(key) for f in fs: self._files.append(f) if len(self._files) >= limit_num_files: break if len(self._files) >= limit_num_files: break if len(self._files)<1: raise FileNotFoundError elif len(self._files)>10: print(len(self._files),'files loaded') else: for f in self._files: print('Loading file:',f) # Instantiate parsers self._data_keys = [] self._data_parsers = [] self._trees = {} for key, value in data_schema.items(): # Check that the schema is a dictionary if not isinstance(value, dict): raise ValueError('A data schema must be expressed as a dictionary') # Identify the parser and its parameter names assert 'parser' in value, 'A parser needs to be specified for %s' % key if not hasattr(mlreco.iotools.parsers, value['parser']): print('The specified parser name %s does not exist!' % value['parser']) assert 'args' in value, 'Parser arguments must be provided for %s' % key fn = getattr(mlreco.iotools.parsers, value['parser']) keys = list(inspect.signature(fn).parameters.keys()) assert isinstance(value['args'], dict), 'Parser arguments must be a list or dictionary for %s' % key for k in value['args'].keys(): assert k in keys, 'Argument %s does not exist in parser %s' % (k, value['parser']) # Append data key and parsers self._data_keys.append(key) self._data_parsers.append((getattr(mlreco.iotools.parsers, value['parser']), value['args'])) for arg_name, data_key in value['args'].items(): if 'event' not in arg_name: continue if 'event_list' not in arg_name: data_key = [data_key] for k in data_key: if k not in self._trees: self._trees[k] = None self._data_keys.append('index') # Prepare TTrees and load files from ROOT import TChain self._entries = None for data_key in self._trees.keys(): # Check data TTree exists, and entries are identical across >1 trees. # However do NOT register these TTrees in self._trees yet in order to support >1 workers by DataLoader print('Loading tree',data_key) chain = TChain(data_key + "_tree") for f in self._files: chain.AddFile(f) if self._entries is not None: assert(self._entries == chain.GetEntries()) else: self._entries = chain.GetEntries() # If event list is provided, register if event_list is None: self._event_list = np.arange(0, self._entries) elif isinstance(event_list, tuple): event_list = np.arange(event_list[0], event_list[1]) self._event_list = event_list self._entries = len(self._event_list) else: if isinstance(event_list,list): event_list = np.array(event_list).astype(np.int32) assert(len(event_list.shape)==1) where = np.where(event_list >= self._entries) removed = event_list[where] if len(removed): print('WARNING: ignoring some of specified events in event_list as they do not exist in the sample.') print(removed) self._event_list = event_list[np.where(event_list < self._entries)] self._entries = len(self._event_list) if skip_event_list is not None: self._event_list = self._event_list[~np.isin(self._event_list, skip_event_list)] self._entries = len(self._event_list) # Set total sample size if limit_num_samples > 0 and self._entries > limit_num_samples: self._entries = limit_num_samples print('Found %d events in file(s)' % len(self._event_list)) # Flag to identify if Trees are initialized or not self._trees_ready=False
[docs] @staticmethod def list_data(f): from ROOT import TFile f=TFile.Open(f,"READ") data={'sparse3d':[],'cluster3d':[],'particle':[]} for k in f.GetListOfKeys(): name = k.GetName() if not name.endswith('_tree'): continue if not len(name.split('_')) < 3: continue key = name.split('_')[0] if not key in data.keys(): continue data[key] = name[:name.rfind('_')] return data
[docs] @staticmethod def get_event_list(cfg, key): event_list = None if key in cfg: if os.path.isfile(cfg[key]): event_list = [int(val) for val in open(cfg[key],'r').read().replace(',',' ').split() if val.isdigit()] else: try: import ast event_list = ast.literal_eval(cfg[key]) except SyntaxError: print('iotool.dataset.%s has invalid representation:' % key,event_list) raise ValueError return event_list
[docs] @staticmethod def create(cfg): data_schema = cfg['schema'] data_keys = cfg['data_keys'] lnf = 0 if not 'limit_num_files' in cfg else int(cfg['limit_num_files']) lns = 0 if not 'limit_num_samples' in cfg else int(cfg['limit_num_samples']) event_list = LArCVDataset.get_event_list(cfg, 'event_list') skip_event_list = LArCVDataset.get_event_list(cfg, 'skip_event_list') return LArCVDataset(data_schema=data_schema, data_keys=data_keys, limit_num_files=lnf, event_list=event_list, skip_event_list=skip_event_list)
[docs] def data_keys(self): return self._data_keys
[docs] def __len__(self): return self._entries
[docs] def __getitem__(self,idx): # convert to actual index: by default, it is idx, but not if event_list provided event_idx = self._event_list[idx] # If this is the first data loading, instantiate chains if not self._trees_ready: from ROOT import TChain for key in self._trees.keys(): chain = TChain(key + '_tree') for f in self._files: chain.AddFile(f) self._trees[key] = chain self._trees_ready=True # Move the event pointer for tree in self._trees.values(): tree.GetEntry(event_idx) # Create data chunks result = {} for index, (parser, args) in enumerate(self._data_parsers): kwargs = {} for k, v in args.items(): if 'event_list' in k: kwargs[k] = [getattr(self._trees[vi], vi+'_branch') for vi in v] elif 'event' in k: kwargs[k] = getattr(self._trees[v], v+'_branch') else: kwargs[k] = v name = self._data_keys[index] result[name] = parser(**kwargs) result['index'] = event_idx return result