mlreco.utils.data_parallel module

class mlreco.utils.data_parallel.DataParallel(module, device_ids=None, output_device=None, dim=0)[source]

Bases: torch.nn.parallel.data_parallel.DataParallel

Scatters and gathers data for multi-gpu training.

This is a layer over torch.nn.parallel.DataParallel because we have custom inputs/outputs:

1. we want to have dict input to our networks and it is not handled by PyTorch DataParallel,

2. we want to return several outputs from the network.

Note

Reason 2. might be obsolete as it seems PyTorch DataParallel now supports dict returns.

Assumptions

Network has a single input.

__init__(module, device_ids=None, output_device=None, dim=0)[source]

Initializes internal Module state, shared by both nn.Module and ScriptModule.

scatter(inputs, kwargs, device_ids)[source]

Scatters the inputs and kwargs to several GPUs (device_ids).

  • len(inputs) = how many inputs the network takes

  • len(inputs[0]) = #GPUs * mbs
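The input layout described for scatter can be illustrated with a minimal pure-Python sketch. It is not the mlreco implementation: the function name scatter_sketch, the string placeholders, and the assumption that each GPU receives a contiguous chunk of equal size are all hypothetical, and no tensors are moved to any device.

```python
def scatter_sketch(inputs, num_gpus):
    """Split every network input into `num_gpus` contiguous chunks.

    Hypothetical illustration only: assumes len(inputs[i]) is an exact
    multiple of num_gpus and that chunks are assigned in order.
    """
    per_gpu = []
    for gpu in range(num_gpus):
        chunks = []
        for batches in inputs:
            mbs = len(batches) // num_gpus  # entries handed to each GPU
            chunks.append(batches[gpu * mbs:(gpu + 1) * mbs])
        per_gpu.append(chunks)
    return per_gpu


# One network input holding #GPUs * mbs = 2 * 2 entries.
inputs = [["b0", "b1", "b2", "b3"]]
scattered = scatter_sketch(inputs, num_gpus=2)
```

Under these assumptions, scattered has one entry per GPU, and each entry keeps the same number of network inputs as the original list.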

gather(outputs, output_device)[source]

Gathers outputs of the network from all GPUs to output_device.

  • len(outputs) = number of gpus

  • len(outputs[0]) = number of outputs of the network

  • len(outputs[0][0]) = 1 (each output is enclosed in a [])

Returns

  • len(results) = number of outputs returned by network

  • len(results[0]) = number of gpus
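The reordering from per-GPU outputs to per-output results can be sketched in plain Python. This is only an illustration of the documented shapes, not the actual gather code: gather_sketch and the string placeholders are hypothetical, and the real method also moves data to output_device.

```python
def gather_sketch(outputs):
    """Regroup outputs[gpu][output_index] into results[output_index][gpu].

    Hypothetical illustration only: each output is assumed to be wrapped
    in a one-element list, as stated in the documented layout.
    """
    num_outputs = len(outputs[0])
    results = []
    for i in range(num_outputs):
        # Unwrap the enclosing [] and collect output i from every GPU.
        results.append([gpu_outputs[i][0] for gpu_outputs in outputs])
    return results


# Two GPUs, two network outputs, each enclosed in a [].
outputs = [
    [["loss_gpu0"], ["acc_gpu0"]],
    [["loss_gpu1"], ["acc_gpu1"]],
]
results = gather_sketch(outputs)
```

Here len(results) matches the number of network outputs and len(results[0]) matches the number of GPUs, in line with the Returns description.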

__module__ = 'mlreco.utils.data_parallel'
training: bool