"""Source code for mlreco.models.layers.common.extract_feature_map."""

import torch
import numpy as np
import MinkowskiEngine as ME

class MinkGhostMask(torch.nn.Module):
    '''
    Ghost mask downsampler for MinkowskiEngine Backend.

    Repeatedly applies a kernel-2 / stride-2 max pooling to a ghost mask
    until it reaches the tensor stride of a reference feature map.
    '''

    def __init__(self, data_dim):
        """
        Parameters
        ----------
        data_dim : int
            Spatial dimension of the sparse tensors; forwarded as the
            ``dimension`` of the pooling layer (it was previously stored but
            ignored in favour of a hard-coded 3).
        """
        super(MinkGhostMask, self).__init__()
        self._data_dim = data_dim
        # kernel_size=2, stride=2: one application halves the mask resolution.
        self.downsample = ME.MinkowskiMaxPooling(2, 2, dimension=data_dim)
        # The module holds no trainable state; it is put in eval mode here.
        self.eval()

    def forward(self, ghost_mask, premask_tensor):
        """
        Downsample ``ghost_mask`` to the tensor stride of ``premask_tensor``.

        ``premask_tensor`` is only read for its ``tensor_stride``; it is not
        pruned or modified.

        Parameters
        ----------
        ghost_mask : ME.SparseTensor
            Ghost mask to downsample. NOTE(review): the number of pooling
            steps is derived from ``premask_tensor``'s stride alone, which
            assumes ``ghost_mask`` is at tensor stride 1 - confirm with callers.
        premask_tensor : ME.SparseTensor
            Feature map whose ``tensor_stride[0]`` sets the target stride.

        Returns
        -------
        ME.SparseTensor
            ``ghost_mask`` after ``log2(tensor_stride[0])`` stride-2 max
            poolings (returned unchanged when the stride is 1).

        Raises
        ------
        ValueError
            If ``tensor_stride[0]`` is not a positive power of two, since
            repeated 2x pooling cannot reach such a stride exactly.
        """
        factor = int(premask_tensor.tensor_stride[0])
        if factor < 1 or factor & (factor - 1):
            raise ValueError(
                "MinkGhostMask expects a power-of-two tensor stride, "
                "got {}".format(factor))
        # Exact integer log2 (avoids float rounding of np.log2).
        num_steps = factor.bit_length() - 1
        with torch.no_grad():
            gm = ghost_mask
            for _ in range(num_steps):
                gm = self.downsample(gm)
        return gm