Source code for mlreco.models.layers.common.cnn_encoder

import numpy as np
import torch
import torch.nn as nn

import MinkowskiEngine as ME
import MinkowskiFunctional as MF

from collections import defaultdict
from mlreco.models.layers.common.activation_normalization_factories import activations_dict, activations_construct, normalizations_construct
from mlreco.models.layers.common.configuration import setup_cnn_configuration
from mlreco.models.layers.common.blocks import ResNetBlock, ConvolutionBlock
from mlreco.models.layers.common.uresnet_layers import UResNetEncoder


[docs]class SparseResidualEncoder(UResNetEncoder): ''' Minkowski Net Autoencoder for sparse tensor reconstruction. '''
[docs] def __init__(self, cfg, name='res_encoder'): #print("RESENCODER = ", cfg) super(SparseResidualEncoder, self).__init__(cfg, name=name) self.model_config = cfg.get(name, {}) self.latent_size = self.model_config.get('latent_size', 512) final_tensor_shape = self.spatial_size // (2**(self.depth-1)) self.coordConv = self.model_config.get('coordConv', False) #print("Final Tensor Shape = ", final_tensor_shape) self.pool_mode = self.model_config.get('pool_mode', 'avg') # Initialize Input Layer if self.coordConv: self.input_layer = ME.MinkowskiConvolution( in_channels=self.num_input + self.D, out_channels=self.num_filters, kernel_size=self.input_kernel, stride=1, dimension=self.D) else: self.input_layer = ME.MinkowskiConvolution( in_channels=self.num_input, out_channels=self.num_filters, kernel_size=self.input_kernel, stride=1, dimension=self.D) if self.pool_mode == 'avg': self.pool = ME.MinkowskiGlobalPooling() elif self.pool_mode == 'conv': self.pool = nn.Sequential( ME.MinkowskiConvolution( in_channels=self.nPlanes[-1], out_channels=self.nPlanes[-1], kernel_size=final_tensor_shape, stride=final_tensor_shape, dimension=self.D), ME.MinkowskiGlobalPooling()) elif self.pool_mode == 'max': self.pool = nn.Sequential( ME.MinkowskiMaxPooling(final_tensor_shape, stride=final_tensor_shape, dimension=self.D), ME.MinkowskiGlobalPooling()) else: raise NotImplementedError self.linear1 = ME.MinkowskiLinear(self.nPlanes[-1], self.latent_size)
[docs] def forward(self, input_tensor): # print(input_tensor) features = input_tensor[:, -1].view(-1, 1) if self.coordConv: normalized_coords = (input_tensor[:, 1:4] - float(self.spatial_size) / 2) \ / (float(self.spatial_size) / 2) features = torch.cat([normalized_coords, features], dim=1) x = ME.SparseTensor(coordinates=input_tensor[:, :4].int(), features=features.float()) # Encoder encoderOutput = self.encoder(x) encoderTensors = encoderOutput['encoderTensors'] finalTensor = encoderOutput['finalTensor'] z = self.pool(finalTensor) latent = self.linear1(z) return latent.F