mlreco.trainval module
-
class mlreco.trainval.trainval(cfg)
Bases: object
Groups all relevant functions for the forward and backward passes of a network.
-
get_data_minibatched(data_iter)
Reads the data for one compute cycle of the forward pass, on a single CPU/GPU or on several.
INPUT
data_iter is an iterator that returns one mini-batch of data (the data for one GPU in one compute cycle) on each call to next(data_iter).
OUTPUT
Returns a data_blob: a dictionary whose values are arrays of mini-batch data.
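A minimal sketch of the gathering step described above. It assumes (hypothetically) that each next(data_iter) returns a dictionary mapping every data key to one mini-batch, and that one mini-batch is pulled per device. The helper name gather_minibatches and its num_devices parameter are illustrative only and are not part of mlreco.

```python
from typing import Any, Dict, Iterator, List


def gather_minibatches(data_iter: Iterator[Dict[str, Any]],
                       num_devices: int) -> Dict[str, List[Any]]:
    """Hypothetical helper: collect one mini-batch per device into a data_blob.

    Assumes every next(data_iter) returns a dict {key: minibatch} and that
    all dicts share the same keys.
    """
    if num_devices < 1:
        raise ValueError("num_devices must be at least 1")
    data_blob: Dict[str, List[Any]] = {}
    for _ in range(num_devices):
        minibatch = next(data_iter)  # one mini-batch for one device
        for key, value in minibatch.items():
            # Append this device's mini-batch to the list stored under key.
            data_blob.setdefault(key, []).append(value)
    return data_blob


# Usage: two devices, toy mini-batches.
batches = iter([{"input_data": [1, 2], "label": [0, 1]},
                {"input_data": [3], "label": [1]}])
blob = gather_minibatches(batches, num_devices=2)
```

Each value in the resulting data_blob is a list with one entry per device, which is the shape the next step is described as consuming.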
-
make_input_forward(data_blob)
Given one compute cycle's worth of data (the return value of get_data_minibatched), arranges it into the format used with torch DataParallel (i.e. multi-GPU training).
INPUT
data_blob is a dictionary in which each key maps to an array whose length equals the number of CPUs/GPUs to be used.
OUTPUT
Returns an input_blob and a loss_blob. Both are arrays of arrays of arrays such as … len(input_blob) = number of compute cycles = batch_size / (minibatch_size * len(GPUs)).
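The cycle count in the formula above can be checked with a few lines of arithmetic. This sketch assumes at least one device and that batch_size is an exact multiple of minibatch_size * len(GPUs); the function name num_compute_cycles is hypothetical.

```python
def num_compute_cycles(batch_size: int, minibatch_size: int, num_gpus: int) -> int:
    """Hypothetical helper: batch_size / (minibatch_size * num_gpus).

    Sketch only: assumes num_gpus >= 1 and exact divisibility.
    """
    if minibatch_size < 1 or num_gpus < 1:
        raise ValueError("minibatch_size and num_gpus must be positive")
    per_cycle = minibatch_size * num_gpus  # events processed in one compute cycle
    if batch_size % per_cycle != 0:
        raise ValueError("batch_size must be a multiple of minibatch_size * num_gpus")
    return batch_size // per_cycle


# 32 events per iteration, 4 events per GPU, 2 GPUs -> 4 compute cycles,
# so len(input_blob) would be 4.
cycles = num_compute_cycles(batch_size=32, minibatch_size=4, num_gpus=2)
```

Raising on a non-divisible batch size is a design choice of this sketch; it makes a silent truncation of the batch impossible.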
-
train_step(data_iter, iteration=None, log_time=True)
data_blob is the output of get_data_minibatched: a dictionary in which each data_blob[key] is a list of length BATCH_SIZE / (MINIBATCH_SIZE * len(GPUS)).
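Because every data_blob[key] list has one entry per compute cycle, the blob can be re-sliced into one dictionary per cycle. The sketch below does that and rejects blobs whose lists disagree in length; split_by_cycle is a hypothetical helper, not an mlreco function.

```python
from typing import Any, Dict, List


def split_by_cycle(data_blob: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
    """Hypothetical helper: {key: [c0, c1, ...]} -> [{key: c0}, {key: c1}, ...].

    All lists must share one length, the number of compute cycles
    BATCH_SIZE / (MINIBATCH_SIZE * len(GPUS)).
    """
    lengths = {len(values) for values in data_blob.values()}
    if len(lengths) > 1:
        raise ValueError(f"inconsistent list lengths in data_blob: {sorted(lengths)}")
    n_cycles = lengths.pop() if lengths else 0
    # Build one dictionary per compute cycle, keeping every key.
    return [{key: values[i] for key, values in data_blob.items()}
            for i in range(n_cycles)]


blob = {"input_data": ["d0", "d1"], "label": ["l0", "l1"]}
per_cycle = split_by_cycle(blob)
```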
-
forward(data_iter, iteration=None)
Runs the forward pass flags.BATCH_SIZE / (flags.MINIBATCH_SIZE * len(flags.GPUS)) times.
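The repetition can be pictured as a plain loop. In this sketch forward_once is a hypothetical stand-in for one (possibly DataParallel) forward pass, and averaging the per-cycle losses is only one possible way to combine them, not necessarily what mlreco does.

```python
from typing import Any, Callable, Dict, List


def run_forward_cycles(forward_once: Callable[[int], Dict[str, Any]],
                       n_cycles: int) -> List[Dict[str, Any]]:
    """Call a hypothetical single-cycle forward function n_cycles times.

    n_cycles stands in for BATCH_SIZE / (MINIBATCH_SIZE * len(GPUS));
    results are collected in cycle order.
    """
    results = []
    for cycle in range(n_cycles):
        results.append(forward_once(cycle))
    return results


# Toy forward function that reports a fake loss for each cycle.
outputs = run_forward_cycles(lambda c: {"loss": 0.5 * c}, n_cycles=3)
mean_loss = sum(out["loss"] for out in outputs) / len(outputs)
```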
-
_forward(train_blob, loss_blob, iteration=None)
data, label and weight are lists whose length is the minibatch size.
For sparse uresnet: data[0] has shape (N, 5), where N is the total number of points in all events of the minibatch.
For dense uresnet: data[0] has shape (minibatch size, channel, spatial size, spatial size, spatial size).
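To make the two layouts concrete, the sketch below converts a single-channel dense batch into an (N, 5) sparse array with NumPy, keeping only nonzero voxels. The column order (x, y, z, batch index, value) is an assumption for illustration; the actual convention should be checked against the mlreco parsers. dense_to_sparse is a hypothetical helper.

```python
import numpy as np


def dense_to_sparse(dense: np.ndarray) -> np.ndarray:
    """Hypothetical helper: (B, 1, S, S, S) dense batch -> (N, 5) sparse array.

    Assumed column layout: x, y, z, batch index, value.
    Only single-channel input is handled in this sketch.
    """
    if dense.ndim != 5 or dense.shape[1] != 1:
        raise ValueError("expected shape (B, 1, S, S, S)")
    # Indices of all nonzero voxels across every event in the minibatch.
    b, _, x, y, z = np.nonzero(dense)
    values = dense[b, 0, x, y, z]
    return np.stack([x, y, z, b, values], axis=1).astype(np.float32)


# Two events in a 4x4x4 volume, one nonzero voxel each -> N = 2.
dense = np.zeros((2, 1, 4, 4, 4))
dense[0, 0, 1, 2, 3] = 0.5
dense[1, 0, 0, 0, 0] = 1.0
sparse = dense_to_sparse(dense)
```

Here N counts the nonzero points of all events together, which is why the sparse layout needs an explicit batch-index column while the dense layout encodes the event in its first axis.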
-
Undocumented members: __init__, backward, save_state, initialize_calibrator, freeze_weights, load_weights, initialize.