mlreco.models.layers.cluster_cnn.losses.single_layers module

class mlreco.models.layers.cluster_cnn.losses.single_layers.DiscriminativeLoss(cfg, reduction='sum')

Bases: torch.nn.modules.module.Module

Implementation of the Discriminative Loss Function in PyTorch (https://arxiv.org/pdf/1708.02551.pdf). Many other implementations exist on GitHub; this one is tailored for use in conjunction with Sparse UResNet.
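For reference, the cited paper defines the loss as a weighted sum of three terms. Here C is the number of instances, N_c the number of pixels in instance c, x_i a pixel embedding, mu_c the centroid of instance c, delta_v and delta_d the two margins, and [x]_+ = max(0, x). The paper uses alpha = beta = 1 and gamma = 0.001. This is the paper's formulation; the norm and normalization used by this module may differ in detail.

```latex
\begin{aligned}
L_{\text{var}}  &= \frac{1}{C}\sum_{c=1}^{C}\frac{1}{N_c}\sum_{i=1}^{N_c}\big[\lVert \mu_c - x_i\rVert - \delta_v\big]_+^2 \\
L_{\text{dist}} &= \frac{1}{C(C-1)}\sum_{c_A=1}^{C}\sum_{\substack{c_B=1\\ c_B\neq c_A}}^{C}\big[2\delta_d - \lVert \mu_{c_A} - \mu_{c_B}\rVert\big]_+^2 \\
L_{\text{reg}}  &= \frac{1}{C}\sum_{c=1}^{C}\lVert \mu_c\rVert \\
L &= \alpha\, L_{\text{var}} + \beta\, L_{\text{dist}} + \gamma\, L_{\text{reg}}
\end{aligned}
```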

__init__(cfg, reduction='sum')

Initializes internal Module state, shared by both nn.Module and ScriptModule.

find_cluster_means(features, labels)

For a given image, compute the centroid mu_c of each cluster label in the embedding space.

Inputs:

  • features (torch.Tensor): the pixel embeddings, shape=(N, d), where N is the number of pixels and d is the embedding space dimension.

  • labels (torch.Tensor): ground-truth group labels, shape=(N, ).

Returns

(n_c, d) tensor, where n_c is the number of distinct instances. Row i is a (1, d) vector holding the coordinates of the i-th centroid.

Return type

cluster_means (torch.Tensor)
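The centroid computation can be sketched as a standalone function. This is a minimal illustration assuming rows are ordered by the sorted unique labels; the name find_cluster_means is reused here for readability, but this is not the module's own method.

```python
import torch


def find_cluster_means(features: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Standalone sketch (not the module's method): one centroid per unique label.

    Returns an (n_c, d) tensor whose rows follow labels.unique(sorted=True).
    """
    cluster_labels = labels.unique(sorted=True)
    # Mean embedding of the pixels that belong to each instance label.
    means = [features[labels == c].mean(dim=0) for c in cluster_labels]
    return torch.stack(means)


features = torch.tensor([[0.0, 0.0], [2.0, 0.0], [10.0, 10.0]])
labels = torch.tensor([0, 0, 1])
print(find_cluster_means(features, labels))  # centroids (1, 0) and (10, 10)
```

The per-label loop is easy to read; for many instances, a vectorized variant using `torch.Tensor.index_add_` with the inverse indices from `unique` would avoid the Python loop.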

intra_cluster_loss(features, labels, cluster_means, margin=0.5)

Implementation of the variance loss in the Discriminative Loss.

Inputs:

  • features (torch.Tensor): pixel embeddings, same as in find_cluster_means.

  • labels (torch.Tensor): ground-truth instance labels.

  • cluster_means (torch.Tensor): output from find_cluster_means.

  • margin (float/int): the constant delta_v in the paper. Think of it as the size of each cluster in the embedding space: in the paper, pixels within this distance of their centroid are not penalized.

Returns

(float) variance loss (see paper).

Return type

intra_loss
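A minimal sketch of the variance term as defined in the paper (squared hinge on each pixel's distance to its centroid, averaged per instance and then over instances). It assumes Euclidean distance and that the rows of cluster_means follow the sorted unique labels; it is a standalone illustration, not the module's code.

```python
import torch


def intra_cluster_loss(features, labels, cluster_means, margin=0.5):
    """Standalone sketch of the paper's variance term (margin = delta_v)."""
    cluster_labels = labels.unique(sorted=True)
    loss = features.new_zeros(())
    for i, c in enumerate(cluster_labels):
        # Euclidean distance of every pixel in instance c to its centroid.
        dists = torch.norm(features[labels == c] - cluster_means[i], dim=1)
        # Only the part of the distance beyond the margin is penalized (squared hinge).
        hinge = torch.clamp(dists - margin, min=0.0)
        loss = loss + torch.mean(hinge ** 2)
    # Average over instances.
    return loss / len(cluster_labels)


features = torch.tensor([[0.0, 0.0], [2.0, 0.0], [10.0, 10.0]])
labels = torch.tensor([0, 0, 1])
means = torch.tensor([[1.0, 0.0], [10.0, 10.0]])
print(intra_cluster_loss(features, labels, means).item())  # 0.125
```

Both pixels of instance 0 sit at distance 1 from their centroid, i.e. 0.5 beyond the margin, giving 0.25 for that instance and 0 for the single-pixel instance; the mean over the two instances is 0.125.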

inter_cluster_loss(cluster_means, margin=1.5)

Implementation of the distance loss in the Discriminative Loss.

Inputs:

  • cluster_means (torch.Tensor): output from find_cluster_means.

  • margin (float/int): the magnitude of the margin delta_d in the paper. Think of it as the separation between distinct clusters in the embedding space: in the paper, centroids closer than 2 * delta_d are pushed apart.

Returns

The computed cross-centroid distance loss (see paper). A factor of 2 is included for proper normalization.

Return type

inter_loss (float)
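A minimal sketch of the paper's distance term, assuming Euclidean distance between centroids. Rather than looping over every ordered pair, it visits each unordered pair once via torch.triu_indices and doubles the sum, which equals the paper's sum over c_A != c_B. It also avoids evaluating a centroid's zero distance to itself, where the norm is not differentiable. This is an illustration, not the module's code.

```python
import torch


def inter_cluster_loss(cluster_means, margin=1.5):
    """Standalone sketch of the paper's distance term (margin = delta_d)."""
    n_c = cluster_means.shape[0]
    if n_c < 2:
        # With a single instance there is nothing to push apart.
        return cluster_means.new_zeros(())
    # Indices of each unordered pair (i < j) of centroids.
    i, j = torch.triu_indices(n_c, n_c, offset=1, device=cluster_means.device)
    dists = torch.norm(cluster_means[i] - cluster_means[j], dim=1)
    # Squared hinge: penalize centroids closer than 2 * margin.
    hinge = torch.clamp(2.0 * margin - dists, min=0.0) ** 2
    # Doubling the unordered-pair sum reproduces the sum over ordered pairs.
    return 2.0 * hinge.sum() / (n_c * (n_c - 1))


means = torch.tensor([[0.0, 0.0], [1.0, 0.0]])
print(inter_cluster_loss(means).item())  # 4.0
```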

regularization(cluster_means)

Implementation of the regularization loss in the Discriminative Loss.

Inputs:

  • cluster_means (torch.Tensor): output from find_cluster_means.

Returns

computed regularization loss (see paper).

Return type

reg_loss (float)
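In the paper the regularization term is simply the mean norm of the centroids, which keeps the embedding anchored near the origin. A one-line sketch, assuming the Euclidean norm (not necessarily the norm this module uses):

```python
import torch


def regularization(cluster_means):
    """Standalone sketch: mean L2 norm of the centroids (paper's L_reg)."""
    return torch.norm(cluster_means, dim=1).mean()


print(regularization(torch.tensor([[3.0, 4.0], [6.0, 8.0]])).item())  # 7.5
```

The paper weights this term by gamma = 0.001, so it acts only as a weak pull toward the origin.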

compute_heuristic_accuracy(embedding, truth)

Compute the Adjusted Rand Index (ARI) score for the given embedding coordinates, where each point's predicted cluster label is that of its closest centroid. The score serves as a heuristic accuracy.

Inputs:

  • embedding (torch.Tensor): (N, d) tensor, where d is the embedding dimension.

  • truth (torch.Tensor): (N, ) tensor of ground-truth clustering labels.

Returns

  • score (float): the computed ARI score.

  • clustering (array): the predicted cluster labels.

Return type

score (float)
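The heuristic can be sketched as nearest-centroid assignment followed by an ARI comparison. This sketch assumes that the centroids are the ground-truth cluster means (as produced by find_cluster_means) and uses scikit-learn's adjusted_rand_score; the function name heuristic_accuracy is hypothetical. Because ARI is invariant to label permutations, the predicted labels can simply be centroid indices.

```python
import torch
from sklearn.metrics import adjusted_rand_score


def heuristic_accuracy(embedding, truth):
    """Hypothetical sketch: ARI of nearest-centroid labels vs. ground truth.

    Assumes the centroids are the ground-truth cluster means.
    """
    with torch.no_grad():
        labels = truth.unique(sorted=True)
        centroids = torch.stack([embedding[truth == c].mean(dim=0) for c in labels])
        # (N, n_c) Euclidean distances from each point to each centroid.
        dists = torch.cdist(embedding, centroids)
        # Predicted label = index of the closest centroid.
        pred = dists.argmin(dim=1)
    score = adjusted_rand_score(truth.cpu().numpy(), pred.cpu().numpy())
    return score, pred


emb = torch.tensor([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]])
score, pred = heuristic_accuracy(emb, torch.tensor([0, 0, 1, 1]))
print(score, pred.tolist())  # 1.0 [0, 0, 1, 1]
```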

combine(features, labels, **kwargs)

Wrapper function for combining the different components of the loss function.

Inputs:

  • features (torch.Tensor): pixel embeddings.

  • labels (torch.Tensor): ground-truth instance labels.

Returns

The combined loss, typically computed over a single semantic class.

Return type

loss
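A self-contained sketch of how the three terms combine into the paper's weighted sum. The weights default to the paper's alpha = beta = 1 and gamma = 0.001, and the margins to the defaults in the intra_cluster_loss and inter_cluster_loss signatures above. The function name and the returned dictionary keys are hypothetical; how this module reads its weights from cfg is not shown here.

```python
import torch


def discriminative_loss(features, labels, delta_v=0.5, delta_d=1.5,
                        alpha=1.0, beta=1.0, gamma=0.001):
    """Hypothetical sketch of L = alpha*L_var + beta*L_dist + gamma*L_reg."""
    cluster_labels = labels.unique(sorted=True)
    n_c = len(cluster_labels)
    means = torch.stack([features[labels == c].mean(dim=0) for c in cluster_labels])

    # Variance term: squared hinge on pixel-to-centroid distance beyond delta_v.
    var = sum(
        torch.clamp(torch.norm(features[labels == c] - means[k], dim=1) - delta_v,
                    min=0.0).pow(2).mean()
        for k, c in enumerate(cluster_labels)
    ) / n_c

    # Distance term: squared hinge on centroid separation below 2 * delta_d.
    if n_c > 1:
        i, j = torch.triu_indices(n_c, n_c, offset=1, device=features.device)
        gap = torch.clamp(2.0 * delta_d - torch.norm(means[i] - means[j], dim=1), min=0.0)
        dist = 2.0 * gap.pow(2).sum() / (n_c * (n_c - 1))
    else:
        dist = features.new_zeros(())

    # Regularization term: mean centroid norm.
    reg = torch.norm(means, dim=1).mean()

    loss = alpha * var + beta * dist + gamma * reg
    # Hypothetical keys, chosen for illustration only.
    return {"loss": loss, "var_loss": var, "dist_loss": dist, "reg_loss": reg}


out = discriminative_loss(torch.tensor([[0.0, 0.0], [2.0, 0.0], [10.0, 10.0]]),
                          torch.tensor([0, 0, 1]))
print({k: round(v.item(), 4) for k, v in out.items()})
```

Returning the individual terms alongside the total makes it easy to log each component separately during training.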

combine_multiclass(features, slabels, clabels, **kwargs)

Wrapper function for combining the different components of the loss when clustering must be done per semantic class.

NOTE: When there are multiple semantic classes, we compute the DLoss by first masking the point cloud by each semantic class (using the ground-truth or predicted segmentation) and then computing the clustering loss over each masked point cloud.

Inputs:

  • features (torch.Tensor): pixel embeddings.

  • slabels (torch.Tensor): semantic labels.

  • clabels (torch.Tensor): group/instance/cluster labels.

Outputs:

  • loss_segs (list) – list of computed loss values for each semantic class; loss_segs[i] is the DLoss computed for semantic class i.

  • acc_segs (list) – list of computed clustering accuracies for each semantic class.
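The per-class masking described above can be sketched as a loop over the semantic classes present, applying a per-class function to the masked embeddings and cluster labels. In this hypothetical sketch, per_class_fn stands in for the combined loss (or a function returning both loss and accuracy); the returned class ids make the list positions explicit, since positions match class ids only when every class is present.

```python
from typing import Callable, List, Tuple

import torch


def combine_multiclass(features: torch.Tensor, slabels: torch.Tensor,
                       clabels: torch.Tensor,
                       per_class_fn: Callable) -> Tuple[List[int], list]:
    """Hypothetical sketch: apply per_class_fn to each semantic class separately."""
    classes, results = [], []
    for s in slabels.unique(sorted=True):
        # Keep only the points labelled as semantic class s.
        mask = slabels == s
        # Instance labels are only compared within the same semantic class.
        results.append(per_class_fn(features[mask], clabels[mask]))
        classes.append(int(s))
    return classes, results


feats = torch.randn(6, 3)
slabels = torch.tensor([0, 0, 0, 1, 1, 1])
clabels = torch.tensor([0, 0, 1, 5, 5, 5])
# Stand-in for the per-class loss: report (number of points, number of instances).
classes, results = combine_multiclass(feats, slabels, clabels,
                                      lambda f, c: (f.shape[0], c.unique().numel()))
print(classes, results)  # [0, 1] [(3, 2), (3, 1)]
```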

forward(out, semantic_labels, group_labels)

Forward function for the Discriminative Loss Module.

Inputs:

  • out: output of UResNet; embedding-space coordinates.

  • semantic_labels: ground-truth semantic labels.

  • group_labels: ground-truth instance labels.

Returns

A dictionary containing key-value pairs for loss, accuracy, etc.

Return type

(dict)

training: bool