"""
These factories instantiate `torch.utils.data.DataLoader`
based on the YAML configuration that was provided.
"""
from torch.utils.data import DataLoader
[docs]def dataset_factory(cfg,event_list=None):
"""
Instantiates dataset based on type specified in configuration under
`iotool.dataset.name`. The name must match the name of a class under
mlreco.iotools.datasets.
Note
----
Currently the choice is limited to `LArCVDataset` only.
"""
import mlreco.iotools.datasets
params = cfg['iotool']['dataset']
if event_list is not None:
params['event_list'] = str(list(event_list))
return getattr(mlreco.iotools.datasets, params['name']).create(params)
[docs]def loader_factory(cfg,event_list=None):
"""
Instantiates a DataLoader based on configuration.
Dataset comes from `dataset_factory`.
Parameters
----------
cfg : dict
Configuration dictionary. Expects a field `iotool`.
event_list: list, optional
List of tree idx.
Returns
-------
loader : torch.utils.data.DataLoader
"""
params = cfg['iotool']
minibatch_size = int(params['minibatch_size'])
shuffle = True if not 'shuffle' in params else bool(params['shuffle' ])
num_workers = 1 if not 'num_workers' in params else int (params['num_workers'])
collate_fn = None if not 'collate_fn' in params else str (params['collate_fn' ])
collate_kwargs = {}
if collate_fn is None:
collate_params = params.get('collate', {})
collate_fn = None if not 'collate_fn' in collate_params else str(collate_params['collate_fn'])
collate_params.pop('collate_fn', None)
collate_kwargs = collate_params
if not int(params['batch_size']) % int(params['minibatch_size']) == 0:
print('iotools.batch_size (',params['batch_size'],'must be divisble by iotools.minibatch_size',params['minibatch_size'])
raise ValueError
import mlreco.iotools.collates
import mlreco.iotools.samplers
from functools import partial
ds = dataset_factory(cfg,event_list)
sampler = None
if 'sampler' in cfg['iotool']:
sam_cfg = cfg['iotool']['sampler']
sam_cfg['minibatch_size']=cfg['iotool']['minibatch_size']
sampler = getattr(mlreco.iotools.samplers,sam_cfg['name']).create(ds,sam_cfg)
if collate_fn is not None:
collate_fn = partial(getattr(mlreco.iotools.collates,collate_fn), **collate_kwargs)
loader = DataLoader(ds,
batch_size = minibatch_size,
shuffle = shuffle,
sampler = sampler,
num_workers = num_workers,
collate_fn = collate_fn)
else:
loader = DataLoader(ds,
batch_size = minibatch_size,
shuffle = shuffle,
sampler = sampler,
num_workers = num_workers)
return loader