Source code for mlreco.models.layers.common.sparse_generator

import torch
import torch.nn as nn

import MinkowskiEngine as ME

from mlreco.models.layers.common.activation_normalization_factories import activations_dict, activations_construct, normalizations_construct
from mlreco.models.layers.common.configuration import setup_cnn_configuration
from mlreco.models.layers.common.blocks import ResNetBlock, ConvolutionBlock
from scipy.special import logit

# from torch_geometric.nn import BatchNorm, LayerNorm, MessageNorm

class SparseGenerator(torch.nn.Module):
    """Sparse decoder that grows a latent tensor back to full resolution.

    A latent sparse tensor is first expanded by one transposed convolution
    and then repeatedly upsampled (stride-2 transposed convolutions that
    generate new coordinates). After every upsampling level the features
    are refined by ``reps`` ``ResNetBlock`` modules, a one-channel linear
    layer scores each coordinate, and coordinates whose score does not
    exceed the configured threshold are pruned.

    Configuration
    -------------
    Read from ``cfg[name]`` (defaults in parentheses):

    reps : int (2)
        Number of ``ResNetBlock`` modules per decoding level.
    depth : int (7)
        Number of levels; must equal the number of hard-coded feature
        planes (see NOTE in ``__init__``).
    num_filters : int (16)
        Stored on the instance; not used by this class because the
        feature planes are hard-coded.
    latent_size : int (512)
        Number of channels of the incoming latent tensor.
    coordConv : bool (False)
        Stored on the instance; not used by this class.
    resolution : int (1024)
        Tensor stride assigned to the latent tensor in ``forward``.
    threshold : float (0.0)
        Probability threshold; converted with ``logit`` so it can be
        compared directly against the raw classifier scores.
    layer_limit : int (-1)
        Maximum number of decoding levels to run; a negative value
        means no limit.

    Notes
    -----
    ``spatial_size``, ``D``, ``norm``, ``norm_args``, ``activation_name``,
    ``activation_args`` and ``allow_bias`` are read from ``self`` and are
    presumably set by ``setup_cnn_configuration`` -- TODO confirm.
    """

    def __init__(self, cfg, name='sparse_generator'):
        """Build the decoder layers from ``cfg[name]``.

        Parameters
        ----------
        cfg : dict
            Full configuration; the entry ``cfg[name]`` is used.
        name : str
            Key of this module's configuration inside ``cfg``.

        Raises
        ------
        AssertionError
            If ``depth`` differs from the number of feature planes.
        """
        super(SparseGenerator, self).__init__()
        setup_cnn_configuration(self, cfg, name)
        self.model_config = cfg[name]
        self.reps = self.model_config.get('reps', 2)
        self.depth = self.model_config.get('depth', 7)
        self.num_filters = self.model_config.get('num_filters', 16)

        # NOTE: the planes are hard-coded (num_filters is ignored), so the
        # configuration must set depth to 9; the default depth of 7 fails
        # the check below.
        self.nPlanes = [16, 16, 32, 32, 64, 64, 128, 128, 256]
        print(self.nPlanes)
        assert len(self.nPlanes) == self.depth, (
            "depth ({}) must match the number of feature planes ({})".format(
                self.depth, len(self.nPlanes)))

        self.latent_size = self.model_config.get('latent_size', 512)
        # Kernel size / stride of the first transposed convolution.
        final_tensor_shape = self.spatial_size // (2**(self.depth-1))
        self.coordConv = self.model_config.get('coordConv', False)
        print("Final Tensor Shape = ", final_tensor_shape)
        self.resolution = self.model_config.get('resolution', 1024)
        # Scores are raw logits, so the probability threshold is mapped to
        # logit space once here (a threshold of 0.0 becomes -inf).
        self.threshold = logit(self.model_config.get('threshold', 0.0))
        print(self.threshold)

        self.layer_limit = self.model_config.get('layer_limit', -1)
        if self.layer_limit < 0:
            # Larger than the number of levels, i.e. never reached.
            self.layer_limit = len(self.nPlanes) + 1

        # Latent -> coarsest decoder level.
        self.linear = nn.Sequential(
            normalizations_construct(
                self.norm, self.latent_size, **self.norm_args),
            activations_construct(
                self.activation_name, **self.activation_args),
            ME.MinkowskiConvolutionTranspose(
                in_channels=self.latent_size,
                out_channels=self.nPlanes[-1],
                kernel_size=final_tensor_shape,
                stride=final_tensor_shape,
                dimension=self.D,
                bias=self.allow_bias,
                generate_new_coords=True)
        )

        # Initialize Decoder: one (upsample, refine, classify) triple per
        # level, ordered from the coarsest level to the finest.
        upsample_layers = []
        refine_layers = []
        classifier_layers = []
        for i in range(self.depth-2, -1, -1):
            upsample_layers.append(nn.Sequential(
                normalizations_construct(
                    self.norm, self.nPlanes[i+1], **self.norm_args),
                activations_construct(
                    self.activation_name, **self.activation_args),
                ME.MinkowskiConvolutionTranspose(
                    in_channels=self.nPlanes[i+1],
                    out_channels=self.nPlanes[i],
                    kernel_size=2,
                    stride=2,
                    dimension=self.D,
                    bias=self.allow_bias,
                    generate_new_coords=True)))

            blocks = [
                ResNetBlock(self.nPlanes[i],
                            self.nPlanes[i],
                            dimension=self.D,
                            activation=self.activation_name,
                            activation_args=self.activation_args,
                            normalization=self.norm,
                            normalization_args=self.norm_args,
                            has_bias=self.allow_bias)
                for _ in range(self.reps)
            ]
            refine_layers.append(nn.Sequential(*blocks))

            classifier_layers.append(
                ME.MinkowskiLinear(self.nPlanes[i], 1, bias=self.allow_bias))

        # nn.Sequential is used here only as an indexable, registered
        # container; forward() addresses the entries by index.
        self.decoding_block = nn.Sequential(*refine_layers)
        self.decoding_conv = nn.Sequential(*upsample_layers)
        self.layer_cls = nn.Sequential(*classifier_layers)

        # pruning
        self.pruning = ME.MinkowskiPruning()

    def get_batch_indices(self, out):
        """Return the per-batch row indices of ``out``.

        Delegates to ``get_row_indices_per_batch`` of the coordinate
        manager attached to ``out``.
        """
        return out.coords_man.get_row_indices_per_batch(out.coords_key)

    def get_target(self, out, target_key, kernel_size=1):
        """Build a boolean occupancy target for the rows of ``out``.

        The target coordinates are strided to the tensor stride of ``out``
        and a kernel map between ``out`` and the strided target is
        requested from the coordinate manager. Every row of ``out`` that
        appears as an input index of that map is marked ``True``.

        Parameters
        ----------
        out : sparse tensor
            Tensor whose rows are labelled; must expose ``coords_man``,
            ``coords_key``, ``tensor_stride`` and ``len()``.
        target_key : coordinate key
            Key of the target coordinates in the same coordinate manager.
        kernel_size : int
            Kernel size forwarded to ``get_kernel_map``.

        Returns
        -------
        torch.Tensor
            Boolean tensor of length ``len(out)`` (created on the CPU).
        """
        with torch.no_grad():
            target = torch.zeros(len(out), dtype=torch.bool)
            cm = out.coords_man
            strided_target_key = cm.stride(
                target_key, out.tensor_stride[0], force_creation=True)
            ins, _ = cm.get_kernel_map(
                out.coords_key,
                strided_target_key,
                kernel_size=kernel_size,
                region_type=1)
            for curr_in in ins:
                target[curr_in] = True
        return target

    def forward(self, latent, target_key):
        """Decode ``latent`` level by level, pruning after every level.

        Parameters
        ----------
        latent : sparse tensor
            Latent tensor; its tensor stride is set to ``self.resolution``
            in place before decoding.
        target_key : coordinate key
            Key of the ground-truth coordinates used to build the
            per-level targets.

        Returns
        -------
        dict
            ``'reconstruction'``: tensor after the last executed level,
            ``'out_cls'``: list of per-level classifier outputs,
            ``'targets'``: list of per-level boolean targets.
        """
        out_cls, targets = [], []
        latent.set_tensor_stride(self.resolution)
        x = self.linear(latent)

        for i, layer in enumerate(self.decoding_conv):
            if i >= self.layer_limit:
                break
            x = layer(x)
            x = self.decoding_block[i](x)
            x_cls = self.layer_cls[i](x)
            target = self.get_target(x, target_key)
            targets.append(target)
            out_cls.append(x_cls)

            # Drop only the channel axis: a bare squeeze() would also
            # remove the row axis when a single coordinate is left and
            # break the mask shape expected below.
            keep = (x_cls.F > self.threshold).cpu().squeeze(-1)
            if self.training:
                # Always keep true coordinates while training.
                keep |= target
            if not keep.any():
                # Nothing survives; stop and return the unpruned tensor.
                break
            x = self.pruning(x, keep)

        return {
            'reconstruction': x,
            'out_cls': out_cls,
            'targets': targets}
class SparseGeneratorSimple(torch.nn.Module):
    """Sparse decoder built from plain convolution blocks.

    A latent sparse tensor is first expanded by one transposed convolution
    and then repeatedly upsampled (stride-2 transposed convolutions that
    generate new coordinates). After every upsampling level the features
    are refined by ``reps`` ``ConvolutionBlock`` modules, a one-channel
    linear layer scores each coordinate, and coordinates whose score does
    not exceed the configured threshold are pruned.

    Configuration
    -------------
    Read from ``cfg[name]`` (defaults in parentheses):

    reps : int (2)
        Number of ``ConvolutionBlock`` modules per decoding level.
    depth : int (7)
        Number of levels; must equal the number of hard-coded feature
        planes (see NOTE in ``__init__``).
    num_filters : int (16)
        Stored on the instance; not used by this class because the
        feature planes are hard-coded.
    latent_size : int (512)
        Number of channels of the incoming latent tensor.
    coordConv : bool (False)
        Stored on the instance; not used by this class.
    resolution : int (1024)
        Tensor stride assigned to the latent tensor in ``forward``.
    threshold : float (0.0)
        Probability threshold; converted with ``logit`` so it can be
        compared directly against the raw classifier scores.
    layer_limit : int (-1)
        Maximum number of decoding levels to run; a negative value
        means no limit.

    Notes
    -----
    ``spatial_size``, ``D``, ``norm``, ``norm_args``, ``activation_name``,
    ``activation_args`` and ``allow_bias`` are read from ``self`` and are
    presumably set by ``setup_cnn_configuration`` -- TODO confirm.
    """

    def __init__(self, cfg, name='sparse_generator'):
        """Build the decoder layers from ``cfg[name]``.

        Parameters
        ----------
        cfg : dict
            Full configuration; the entry ``cfg[name]`` is used.
        name : str
            Key of this module's configuration inside ``cfg``.

        Raises
        ------
        AssertionError
            If ``depth`` differs from the number of feature planes.
        """
        super(SparseGeneratorSimple, self).__init__()
        setup_cnn_configuration(self, cfg, name)
        self.model_config = cfg[name]
        self.reps = self.model_config.get('reps', 2)
        self.depth = self.model_config.get('depth', 7)
        self.num_filters = self.model_config.get('num_filters', 16)

        # NOTE: the planes are hard-coded (num_filters is ignored), so the
        # configuration must set depth to 9; the default depth of 7 fails
        # the check below.
        self.nPlanes = [16, 16, 32, 32, 64, 64, 128, 128, 256]
        print(self.nPlanes)
        assert len(self.nPlanes) == self.depth, (
            "depth ({}) must match the number of feature planes ({})".format(
                self.depth, len(self.nPlanes)))

        self.latent_size = self.model_config.get('latent_size', 512)
        # Kernel size / stride of the first transposed convolution.
        final_tensor_shape = self.spatial_size // (2**(self.depth-1))
        self.coordConv = self.model_config.get('coordConv', False)
        print("Final Tensor Shape = ", final_tensor_shape)
        self.resolution = self.model_config.get('resolution', 1024)
        # Scores are raw logits, so the probability threshold is mapped to
        # logit space once here (a threshold of 0.0 becomes -inf).
        self.threshold = logit(self.model_config.get('threshold', 0.0))
        print(self.threshold)

        self.layer_limit = self.model_config.get('layer_limit', -1)
        if self.layer_limit < 0:
            # Larger than the number of levels, i.e. never reached.
            self.layer_limit = len(self.nPlanes) + 1

        # Latent -> coarsest decoder level.
        self.linear = nn.Sequential(
            normalizations_construct(
                self.norm, self.latent_size, **self.norm_args),
            activations_construct(
                self.activation_name, **self.activation_args),
            ME.MinkowskiConvolutionTranspose(
                in_channels=self.latent_size,
                out_channels=self.nPlanes[-1],
                kernel_size=final_tensor_shape,
                stride=final_tensor_shape,
                dimension=self.D,
                bias=self.allow_bias,
                generate_new_coords=True)
        )

        # Initialize Decoder: one (upsample, refine, classify) triple per
        # level, ordered from the coarsest level to the finest.
        upsample_layers = []
        refine_layers = []
        classifier_layers = []
        for i in range(self.depth-2, -1, -1):
            upsample_layers.append(nn.Sequential(
                normalizations_construct(
                    self.norm, self.nPlanes[i+1], **self.norm_args),
                activations_construct(
                    self.activation_name, **self.activation_args),
                ME.MinkowskiConvolutionTranspose(
                    in_channels=self.nPlanes[i+1],
                    out_channels=self.nPlanes[i],
                    kernel_size=2,
                    stride=2,
                    dimension=self.D,
                    bias=self.allow_bias,
                    generate_new_coords=True)))

            blocks = [
                ConvolutionBlock(self.nPlanes[i],
                                 self.nPlanes[i],
                                 dimension=self.D,
                                 activation=self.activation_name,
                                 activation_args=self.activation_args,
                                 normalization=self.norm,
                                 normalization_args=self.norm_args,
                                 has_bias=self.allow_bias)
                for _ in range(self.reps)
            ]
            refine_layers.append(nn.Sequential(*blocks))

            classifier_layers.append(
                ME.MinkowskiLinear(self.nPlanes[i], 1, bias=self.allow_bias))

        # nn.Sequential is used here only as an indexable, registered
        # container; forward() addresses the entries by index.
        self.decoding_block = nn.Sequential(*refine_layers)
        self.decoding_conv = nn.Sequential(*upsample_layers)
        self.layer_cls = nn.Sequential(*classifier_layers)

        # pruning
        self.pruning = ME.MinkowskiPruning()

    def get_batch_indices(self, out):
        """Return the per-batch row indices of ``out``.

        Delegates to ``get_row_indices_per_batch`` of the coordinate
        manager attached to ``out``.
        """
        return out.coords_man.get_row_indices_per_batch(out.coords_key)

    def get_target(self, out, target_key, kernel_size=1):
        """Build a boolean occupancy target for the rows of ``out``.

        The target coordinates are strided to the tensor stride of ``out``
        and a kernel map between ``out`` and the strided target is
        requested from the coordinate manager. Every row of ``out`` that
        appears as an input index of that map is marked ``True``.

        Parameters
        ----------
        out : sparse tensor
            Tensor whose rows are labelled; must expose ``coords_man``,
            ``coords_key``, ``tensor_stride`` and ``len()``.
        target_key : coordinate key
            Key of the target coordinates in the same coordinate manager.
        kernel_size : int
            Kernel size forwarded to ``get_kernel_map``.

        Returns
        -------
        torch.Tensor
            Boolean tensor of length ``len(out)`` (created on the CPU).
        """
        with torch.no_grad():
            target = torch.zeros(len(out), dtype=torch.bool)
            cm = out.coords_man
            strided_target_key = cm.stride(
                target_key, out.tensor_stride[0], force_creation=True)
            ins, _ = cm.get_kernel_map(
                out.coords_key,
                strided_target_key,
                kernel_size=kernel_size,
                region_type=1)
            for curr_in in ins:
                target[curr_in] = True
        return target

    def forward(self, latent, target_key):
        """Decode ``latent`` level by level, pruning after every level.

        Parameters
        ----------
        latent : sparse tensor
            Latent tensor; its tensor stride is set to ``self.resolution``
            in place before decoding.
        target_key : coordinate key
            Key of the ground-truth coordinates used to build the
            per-level targets.

        Returns
        -------
        dict
            ``'reconstruction'``: tensor after the last executed level,
            ``'out_cls'``: list of per-level classifier outputs,
            ``'targets'``: list of per-level boolean targets.
        """
        out_cls, targets = [], []
        latent.set_tensor_stride(self.resolution)
        x = self.linear(latent)

        for i, layer in enumerate(self.decoding_conv):
            if i >= self.layer_limit:
                break
            x = layer(x)
            x = self.decoding_block[i](x)
            x_cls = self.layer_cls[i](x)
            target = self.get_target(x, target_key)
            targets.append(target)
            out_cls.append(x_cls)

            # Drop only the channel axis: a bare squeeze() would also
            # remove the row axis when a single coordinate is left and
            # break the mask shape expected below.
            keep = (x_cls.F > self.threshold).cpu().squeeze(-1)
            if self.training:
                # Always keep true coordinates while training.
                keep |= target
            if not keep.any():
                # Nothing survives; stop and return the unpruned tensor.
                break
            x = self.pruning(x, keep)

        return {
            'reconstruction': x,
            'out_cls': out_cls,
            'targets': targets}