mlreco.models.layers.cluster_cnn.losses.spatial_embeddings_fast module

class mlreco.models.layers.cluster_cnn.losses.spatial_embeddings_fast.SPICELoss(cfg, name='spice_loss')[source]

Bases: torch.nn.modules.module.Module

Loss function for Sparse Spatial Embeddings Model, with fixed centroids and symmetric gaussian kernels.

__init__(cfg, name='spice_loss')[source]

Initializes internal Module state, shared by both nn.Module and ScriptModule.

find_cluster_means(features, labels)[source]

For a given image, compute the centroids mu_c for each cluster label in the embedding space.

Inputs:

  • features (torch.Tensor) – the pixel embeddings, shape=(N, d), where N is the number of pixels and d is the embedding space dimension.

  • labels (torch.Tensor) – ground-truth group labels, shape=(N,).

Returns

cluster_means: an (n_c, d) tensor, where n_c is the number of distinct instances. Row i holds the coordinates of the i-th centroid.

Return type

torch.Tensor
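The centroid computation described above can be sketched as follows. This is a minimal illustration, not the library's implementation: the function name `cluster_means_sketch` is hypothetical, and the ordering of rows by sorted label value is an assumption.

```python
import torch


def cluster_means_sketch(features: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: return an (n_c, d) tensor of per-label centroids."""
    # torch.unique returns the distinct labels in sorted order by default.
    unique_labels = torch.unique(labels)
    # Average the embeddings of all pixels that share a label.
    means = [features[labels == c].mean(dim=0) for c in unique_labels]
    return torch.stack(means, dim=0)


# Four pixels in a 2-dimensional embedding space, two instances (labels 0 and 5).
features = torch.tensor([[0.0, 0.0], [2.0, 0.0], [1.0, 1.0], [3.0, 3.0]])
labels = torch.tensor([0, 0, 5, 5])
centroids = cluster_means_sketch(features, labels)  # shape (2, 2)
```

Labels do not need to be contiguous; each distinct value gets one row.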

get_per_class_probabilities(embeddings, margins, labels, eps=1e-06)[source]

Computes binary foreground/background loss.
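One way to picture a foreground/background loss built on symmetric gaussian kernels is sketched below. Everything here is an assumption made for illustration: the function name, the use of the mean embedding as the centroid, the mean margin as the kernel width, and binary cross-entropy as the per-cluster loss. The actual method may differ on each of these points.

```python
import torch
import torch.nn.functional as F


def gaussian_mask_loss_sketch(embeddings, margins, labels, eps=1e-6):
    """Hypothetical sketch: per-cluster gaussian probabilities scored with BCE."""
    losses = []
    for c in torch.unique(labels):
        mask = labels == c                       # foreground pixels of this cluster
        centroid = embeddings[mask].mean(dim=0)  # assumed: centroid = mean embedding
        sigma = margins[mask].mean()             # assumed: one isotropic width per cluster
        sq_dist = ((embeddings - centroid) ** 2).sum(dim=1)
        # Symmetric gaussian kernel: close to 1 near the centroid, close to 0 far away.
        probs = torch.exp(-sq_dist / (2.0 * sigma ** 2 + eps))
        probs = probs.clamp(eps, 1.0 - eps)      # keep log() finite inside BCE
        losses.append(F.binary_cross_entropy(probs, mask.float()))
    return torch.stack(losses).mean()


# Two well-separated clusters in a 2-dimensional embedding space.
embeddings = torch.tensor([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]])
margins = torch.full((4,), 0.5)
labels = torch.tensor([0, 0, 1, 1])
loss = gaussian_mask_loss_sketch(embeddings, margins, labels)
```

With well-separated clusters, as in this toy input, the loss is small because each pixel is assigned a high probability only for its own cluster.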

combine_multiclass(embeddings, margins, seediness, slabels, clabels, coords)[source]

Wrapper function for combining different components of the loss, in particular when clustering must be done PER SEMANTIC CLASS.

NOTE: When there are multiple semantic classes, the DLoss is computed by first masking the points by semantic class (ground truth/prediction) and then computing the clustering loss over each masked point cloud.

INPUTS:

  • features (torch.Tensor) – pixel embeddings

  • slabels (torch.Tensor) – semantic labels

  • clabels (torch.Tensor) – group/instance/cluster labels

OUTPUT
  • loss_segs (list) – list of computed loss values, one per semantic class; entry i is the computed DLoss for semantic class i.

  • acc_segs (list) – list of computed clustering accuracy for each semantic class.
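The per-class masking described in the NOTE can be sketched as below. The helper `per_class_losses_sketch` and the stand-in `toy_intra_loss` are hypothetical names introduced for illustration; the real method also takes margins and seediness and returns accuracies, which this sketch leaves out.

```python
import torch


def toy_intra_loss(emb, clabels):
    """Hypothetical stand-in loss: mean squared distance to each cluster centroid."""
    total = torch.zeros((), dtype=emb.dtype)
    for c in torch.unique(clabels):
        pts = emb[clabels == c]
        total = total + ((pts - pts.mean(dim=0)) ** 2).sum(dim=1).mean()
    return total


def per_class_losses_sketch(embeddings, slabels, clabels, loss_fn):
    """Mask the point cloud by semantic class, then apply loss_fn to each subset."""
    loss_segs = []
    for sc in torch.unique(slabels):
        class_mask = slabels == sc
        loss_segs.append(loss_fn(embeddings[class_mask], clabels[class_mask]))
    return loss_segs


embeddings = torch.tensor([[0.0, 0.0], [2.0, 0.0], [1.0, 1.0], [1.0, 3.0]])
slabels = torch.tensor([0, 0, 1, 1])  # two semantic classes
clabels = torch.tensor([0, 0, 1, 2])  # three instances overall
loss_segs = per_class_losses_sketch(embeddings, slabels, clabels, toy_intra_loss)
```

Masking first means that instances from different semantic classes never compete with each other inside the clustering loss.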

forward(out, segment_label, group_label)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
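The difference between calling the module instance and calling forward directly can be seen with a forward hook. The sketch below uses a toy module (`ToyLoss`, a hypothetical name) rather than SPICELoss, so it runs without a configuration dictionary.

```python
import torch
from torch import nn


class ToyLoss(nn.Module):
    """Hypothetical stand-in module used only to demonstrate hook behavior."""

    def forward(self, x):
        return x.sum()


calls = []
loss_fn = ToyLoss()
# The hook records every call that goes through the module's __call__ path.
handle = loss_fn.register_forward_hook(
    lambda module, inputs, output: calls.append("hook")
)

x = torch.ones(3)
via_call = loss_fn(x)             # runs forward() and the registered hook
n_after_call = len(calls)
via_forward = loss_fn.forward(x)  # runs forward() only; the hook is skipped
n_after_forward = len(calls)
handle.remove()
```

Both calls return the same value, but only the first one triggers the hook.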

training: bool
class mlreco.models.layers.cluster_cnn.losses.spatial_embeddings_fast.SPICEInterLoss(cfg, name='spice_loss')[source]

Bases: mlreco.models.layers.cluster_cnn.losses.spatial_embeddings_fast.SPICELoss

__init__(cfg, name='spice_loss')[source]

Initializes internal Module state, shared by both nn.Module and ScriptModule.

regularization(cluster_means)[source]

Implementation of the regularization loss in Discriminative Loss.

Inputs:

  • cluster_means (torch.Tensor) – output from find_cluster_means.

Returns

reg_loss: the computed regularization loss (see paper).

Return type

float
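In the usual discriminative loss formulation, the regularization term is the mean L2 norm of the centroids, which keeps the clusters from drifting far from the origin. The sketch below assumes that form; the function name is hypothetical and the source's exact normalization is not confirmed here.

```python
import torch


def regularization_sketch(cluster_means: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: mean Euclidean norm over the (n_c, d) centroid rows."""
    return torch.linalg.norm(cluster_means, dim=1).mean()


# Two centroids with norms 5 and 0.
centroids = torch.tensor([[3.0, 4.0], [0.0, 0.0]])
reg = regularization_sketch(centroids)
```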

inter_cluster_loss(cluster_means, margin=0.2)[source]

Implementation of the distance loss in Discriminative Loss.

Inputs:

  • cluster_means (torch.Tensor) – output from find_cluster_means.

  • margin (float/int) – the magnitude of the margin delta_d in the paper. Think of it as the distance between separate clusters in embedding space.

Returns

inter_loss: the computed cross-centroid distance loss (see paper). A factor of 2 is included for proper normalization.

Return type

float
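A hinge-style distance term of the kind used in discriminative losses can be sketched as follows: pairs of centroids closer than 2 * margin are penalized quadratically, and pairs farther apart contribute nothing. The function name, the sum over ordered pairs, and the division by n_c * (n_c - 1) are assumptions for illustration; the source's own normalization may differ.

```python
import torch


def inter_cluster_loss_sketch(cluster_means: torch.Tensor, margin: float = 0.2) -> torch.Tensor:
    """Hypothetical sketch: squared hinge on pairwise centroid distances."""
    n_c = cluster_means.shape[0]
    if n_c < 2:
        # A single cluster has no neighbor to be pushed away from.
        return torch.zeros((), dtype=cluster_means.dtype, device=cluster_means.device)
    # Pairwise Euclidean distances between centroids via broadcasting, shape (n_c, n_c).
    diff = cluster_means[:, None, :] - cluster_means[None, :, :]
    dist = (diff ** 2).sum(dim=2).sqrt()
    hinge = torch.clamp(2.0 * margin - dist, min=0.0) ** 2
    # Exclude the diagonal (distance of each centroid to itself).
    off_diag = ~torch.eye(n_c, dtype=torch.bool, device=cluster_means.device)
    return hinge[off_diag].sum() / (n_c * (n_c - 1))


# Centroids 0 and 1 are 0.3 apart (inside 2 * margin = 0.4); centroid 2 is far away.
centroids = torch.tensor([[0.0, 0.0], [0.3, 0.0], [10.0, 0.0]])
inter = inter_cluster_loss_sketch(centroids, margin=0.2)
```

Only the close pair contributes here, once in each direction, so the loss is small but nonzero.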

training: bool
get_per_class_probabilities(embeddings, margins, labels, eps=1e-06)[source]

Computes binary foreground/background loss.

combine_multiclass(embeddings, margins, seediness, slabels, clabels)[source]

Wrapper function for combining different components of the loss, in particular when clustering must be done PER SEMANTIC CLASS.

NOTE: When there are multiple semantic classes, the DLoss is computed by first masking the points by semantic class (ground truth/prediction) and then computing the clustering loss over each masked point cloud.

INPUTS:

  • features (torch.Tensor) – pixel embeddings

  • slabels (torch.Tensor) – semantic labels

  • clabels (torch.Tensor) – group/instance/cluster labels

OUTPUT
  • loss_segs (list) – list of computed loss values, one per semantic class; entry i is the computed DLoss for semantic class i.

  • acc_segs (list) – list of computed clustering accuracy for each semantic class.