# Source code for mlreco.models.layers.cluster_cnn.graph_spice_embedder

import torch
import torch.nn as nn
import MinkowskiEngine as ME

from mlreco.models.layers.common.uresnet_layers import UResNet
# from mlreco.models.layers.cluster_cnn.losses.gs_embeddings import *

from pprint import pprint


class GraphSPICEEmbedder(UResNet):
    """UResNet backbone with per-voxel embedding heads for Graph-SPICE.

    On top of the inherited UResNet encoder/decoder, this module adds linear
    heads that map each decoded feature vector (``self.num_filters`` wide) to:

    - a spatial embedding (``spatial_embedding_dim`` wide), squashed with tanh;
    - a feature embedding (``feature_embedding_dim`` wide), unactivated;
    - a 2-wide "covariance" output passed through ``self.cov_func``;
    - a 1-wide "occupancy" output passed through ``self.occ_func``;
    - optionally (``segmentationLayer``), ``num_classes`` segmentation outputs.

    Options are read from ``cfg[name]``. ``self.num_filters``, ``self.D``,
    ``self.spatial_size``, ``self.encoder`` and ``self.decoder`` are not
    assigned in this class. They are presumably set by ``UResNet.__init__``;
    confirm there.
    """

    # NOTE(review): presumably the config sections this model consumes;
    # verify against how MODULES is read elsewhere in mlreco.
    MODULES = ['network_base', 'uresnet', 'graph_spice_embedder']

    def __init__(self, cfg, name='graph_spice_embedder'):
        """Build the backbone and the embedding heads.

        Args:
            cfg: model configuration dict. It is passed unchanged to
                ``UResNet.__init__``, and this module's options are read
                from ``cfg[name]`` (an empty dict if the key is missing).
            name: key of this module's option sub-dictionary in ``cfg``.
        """
        super().__init__(cfg)
        self.model_config = cfg.get(name, {})
        self.feature_embedding_dim = self.model_config.get(
            'feature_embedding_dim', 8)
        self.spatial_embedding_dim = self.model_config.get(
            'spatial_embedding_dim', 3)
        self.num_classes = self.model_config.get('num_classes', 5)
        self.coordConv = self.model_config.get('coordConv', True)
        self.segmentationLayer = self.model_config.get(
            'segmentationLayer', False)

        # Covariance activation: 'exp', 'softplus'; any other value -> sigmoid.
        self.covariance_mode = self.model_config.get('covariance_mode', 'exp')
        if self.covariance_mode == 'exp':
            self.cov_func = torch.exp
        elif self.covariance_mode == 'softplus':
            self.cov_func = nn.Softplus()
        else:
            self.cov_func = nn.Sigmoid()

        # Occupancy activation: 'softplus'; any other value (incl. 'exp') -> exp.
        self.occupancy_mode = self.model_config.get('occupancy_mode', 'exp')
        if self.occupancy_mode == 'softplus':
            self.occ_func = nn.Softplus()
        else:
            self.occ_func = torch.exp

        # Output heads, applied to the last decoder tensor's features.
        # Construction order is kept as-is so parameter initialization and
        # state_dict key order do not change.
        self.outputSpatialEmbeddings = nn.Linear(
            self.num_filters, self.spatial_embedding_dim)
        self.outputFeatureEmbeddings = nn.Linear(
            self.num_filters, self.feature_embedding_dim)
        if self.segmentationLayer:
            self.outputSegmentation = nn.Linear(
                self.num_filters, self.num_classes)
        self.outputCovariance = nn.Linear(self.num_filters, 2)
        self.outputOccupancy = nn.Linear(self.num_filters, 1)

        # Width of `hypergraph_features` in get_embeddings:
        # spatial + feature + 2 (covariance) + 1 (occupancy).
        self.hyper_dimension = self.spatial_embedding_dim \
            + self.feature_embedding_dim + 3

        # PyTorch activations. `sigmoid` is not used in this class; it is kept
        # in case other code references the attribute.
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

    def get_embeddings(self, input):
        """Run the backbone and every embedding head on one point cloud.

        Args:
            input: a length-1 sequence ``[point_cloud]``. It is unpacked with
                ``point_cloud, = input``, so any other length raises
                ``ValueError``. ``point_cloud`` is a 2-D tensor:
                - Columns ``0..D`` are cast to int and used as
                  MinkowskiEngine coordinates. Presumably column 0 is the
                  batch index, since only columns ``1..D`` are normalized.
                - The remaining columns are cast to float and used as
                  input features.

        Returns:
            dict mapping each key to a single-element list holding a tensor:
            - ``"spatial_embeddings"``: tanh of the spatial head plus the
              normalized coordinates. The addition needs compatible widths,
              i.e. in practice ``spatial_embedding_dim == D``.
            - ``"covariance"``: ``cov_func`` of the 2-wide covariance head.
            - ``"feature_embeddings"``: raw feature-head output.
            - ``"occupancy"``: ``occ_func`` of the 1-wide occupancy head.
            - ``"features"``: the last decoder tensor's features.
            - ``"hypergraph_features"``: concatenation along dim 1 of the
              spatial embeddings (before the coordinate offset), feature
              embeddings, covariance and occupancy; ``hyper_dimension`` wide.
            - ``"segmentation"``: segmentation-head output, present only when
              ``segmentationLayer`` is enabled.

        NOTE(review): adding ``normalized_coords`` (one row per input row) to
        the head outputs assumes the decoder output rows align one-to-one with
        the input rows; confirm MinkowskiEngine does not merge duplicate
        coordinates for the inputs used here.
        """
        point_cloud, = input
        coords = point_cloud[:, 0:self.D + 1].int()
        features = point_cloud[:, self.D + 1:].float()

        # Affine map sending spatial coordinate 0 -> -1 and spatial_size -> +1.
        half_size = float(self.spatial_size) / 2
        normalized_coords = (coords[:, 1:self.D + 1] - half_size) / half_size

        if self.coordConv:
            # CoordConv: prepend the normalized coordinates to the features.
            features = torch.cat([normalized_coords, features], dim=1)

        x = ME.SparseTensor(features, coordinates=coords)
        encoder_res = self.encoder(x)
        decoderTensors = self.decoder(encoder_res['finalTensor'],
                                      encoder_res['encoderTensors'])
        output_features = decoderTensors[-1].F

        spatial_embeddings = self.tanh(
            self.outputSpatialEmbeddings(output_features))
        covariance = self.cov_func(self.outputCovariance(output_features))
        feature_embeddings = self.outputFeatureEmbeddings(output_features)
        occupancy = self.occ_func(self.outputOccupancy(output_features))

        hypergraph_features = torch.cat([
            spatial_embeddings,
            feature_embeddings,
            covariance,
            occupancy], dim=1)

        res = {
            "spatial_embeddings": [spatial_embeddings + normalized_coords],
            "covariance": [covariance],
            "feature_embeddings": [feature_embeddings],
            "occupancy": [occupancy],
            "features": [output_features],
            "hypergraph_features": [hypergraph_features],
        }
        if self.segmentationLayer:
            res["segmentation"] = [self.outputSegmentation(output_features)]
        return res

    def forward(self, input):
        """Train-time forward pass; delegates to ``get_embeddings``.

        Args:
            input: a length-1 sequence ``[point_cloud]``. See
                ``get_embeddings`` for the expected layout.

        Returns:
            The dict produced by ``get_embeddings``.
        """
        point_cloud, = input
        return self.get_embeddings([point_cloud])