mlreco.models.layers.cluster_cnn.losses.gs_embeddings module
-
class mlreco.models.layers.cluster_cnn.losses.gs_embeddings.WeightedEdgeLoss(loss_type='BCE', reduction='mean', invert=False)
Bases: torch.nn.modules.module.Module
-
__init__(loss_type='BCE', reduction='mean', invert=False)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
-
forward(logits, targets)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
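This page gives no formula for WeightedEdgeLoss. As a minimal sketch, assuming "weighted" means inverse class-frequency weighting of a binary cross-entropy over edge logits, and assuming `invert=True` flips the 0/1 edge targets (both are assumptions, not confirmed here), such a loss could look like the following. The class name `BalancedEdgeBCE` is hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BalancedEdgeBCE(nn.Module):
    """Hypothetical class-balanced BCE over edge logits (illustrative only)."""

    def __init__(self, reduction: str = "mean", invert: bool = False):
        super().__init__()
        self.reduction = reduction
        # Assumption: invert flips which edges count as positive.
        self.invert = invert

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        targets = targets.float()
        if self.invert:
            targets = 1.0 - targets
        n_pos = targets.sum()
        n_neg = targets.numel() - n_pos
        # Weight each edge by the inverse frequency of its class so that the
        # (often rare) positive edges are not swamped by the negatives.
        w_pos = targets.numel() / (2.0 * n_pos.clamp(min=1.0))
        w_neg = targets.numel() / (2.0 * n_neg.clamp(min=1.0))
        weight = torch.where(targets > 0.5, w_pos, w_neg)
        return F.binary_cross_entropy_with_logits(
            logits, targets, weight=weight, reduction=self.reduction
        )
```

Balancing by class frequency is one common choice for edge classification, where most candidate edges in a graph usually connect different clusters.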
-
__module__ = 'mlreco.models.layers.cluster_cnn.losses.gs_embeddings'
-
training: bool
-
-
mlreco.models.layers.cluster_cnn.losses.gs_embeddings.compute_edge_weight(sp_emb: torch.Tensor, ft_emb: torch.Tensor, cov: torch.Tensor, edge_indices: torch.Tensor, occ=None, eps=0.001)
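compute_edge_weight has no docstring. The sketch below shows one plausible way to turn per-node spatial and feature embeddings plus a per-node bandwidth into an edge probability. It assumes `cov` is an (N, 2) tensor whose columns are the spatial and feature bandwidths, that `edge_indices` is a (2, E) tensor of endpoint indices, and it ignores `occ`; the kernel form and the name `edge_weight_sketch` are assumptions, not the module's verified implementation.

```python
import torch


def edge_weight_sketch(sp_emb, ft_emb, cov, edge_indices, eps=1e-3):
    """Hypothetical GraphSPICE-style edge probability in [0, 1].

    Assumes cov is (N, 2): column 0 = spatial bandwidth,
    column 1 = feature bandwidth. The occupancy input is not modeled.
    """
    ui, vi = edge_indices[0], edge_indices[1]
    d_sp = ((sp_emb[ui] - sp_emb[vi]) ** 2).sum(dim=1)
    d_ft = ((ft_emb[ui] - ft_emb[vi]) ** 2).sum(dim=1)
    # Gaussian kernel evaluated once with each endpoint's bandwidth.
    p_ij = torch.exp(-d_sp / (cov[ui, 0] ** 2 + eps) - d_ft / (cov[ui, 1] ** 2 + eps))
    p_ji = torch.exp(-d_sp / (cov[vi, 0] ** 2 + eps) - d_ft / (cov[vi, 1] ** 2 + eps))
    # Probabilistic OR makes the weight independent of edge direction.
    return torch.clamp(p_ij + p_ji - p_ij * p_ji, min=0.0, max=1.0)
```

Identical embeddings give a weight of 1, and the weight decays toward 0 as the embeddings move apart relative to the bandwidths.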
-
class mlreco.models.layers.cluster_cnn.losses.gs_embeddings.GraphSPICEEmbeddingLoss(cfg, name='graph_spice_loss')
Bases: torch.nn.modules.module.Module
Loss function for the Sparse Spatial Embeddings model, with fixed centroids and symmetric Gaussian kernels.
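For reference, one common form of a symmetric Gaussian kernel around a fixed cluster centroid is written below. This is an assumption about the parameterization: the symbols are introduced here for illustration and are not defined elsewhere on this page.

```latex
\phi_k(i) = \exp\!\left(-\frac{\lVert \mathbf{e}_i - \mathbf{c}_k \rVert^2}{2\sigma_k^2}\right),
\qquad
\mathbf{c}_k = \frac{1}{|S_k|}\sum_{j \in S_k} \mathbf{e}_j
```

Here e_i is the embedding of voxel i, S_k is the set of voxels in cluster k, c_k is that cluster's centroid (computed from the labels rather than learned), and a single bandwidth sigma_k shared across all embedding dimensions is what makes the kernel symmetric.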
-
__init__(cfg, name='graph_spice_loss')
Initializes internal Module state, shared by both nn.Module and ScriptModule.
-
feature_embedding_loss(ft_emb, groups, ft_centroids)
Compute the discriminative feature embedding loss.
INPUTS:
ft_emb (N x F)
groups (N)
ft_centroids (N_c x F)
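A "discriminative" embedding loss usually means a pull term that draws each point toward its cluster centroid and a push term that separates centroids. The sketch below follows the pull/push form of De Brabandere et al. (2017), omitting that paper's centroid-norm regularizer. It assumes `groups` holds ids relabeled to 0..N_c-1 and uses hypothetical margins `delta_v` and `delta_d`; it is not this module's confirmed implementation.

```python
import torch


def discriminative_loss_sketch(ft_emb, groups, ft_centroids,
                               delta_v=0.5, delta_d=1.5):
    """Hypothetical pull/push loss.

    ft_emb: (N, F), groups: (N,) int64 ids in [0, N_c),
    ft_centroids: (N_c, F), one centroid per cluster.
    """
    n_c = ft_centroids.shape[0]
    # Pull term: hinged distance of each point to its own cluster centroid,
    # averaged within each cluster, then across clusters.
    dist = torch.norm(ft_emb - ft_centroids[groups], dim=1)
    hinge = torch.clamp(dist - delta_v, min=0.0) ** 2
    counts = torch.bincount(groups, minlength=n_c).clamp(min=1).to(ft_emb.dtype)
    per_cluster = ft_emb.new_zeros(n_c).index_add_(0, groups, hinge) / counts
    pull = per_cluster.mean()
    # Push term: hinged pairwise distance between distinct centroids.
    push = ft_emb.new_zeros(())
    if n_c > 1:
        cdist = torch.cdist(ft_centroids, ft_centroids)
        off_diag = ~torch.eye(n_c, dtype=torch.bool, device=ft_centroids.device)
        push = (torch.clamp(2 * delta_d - cdist[off_diag], min=0.0) ** 2).mean()
    return pull + push
```

Points closer than `delta_v` to their centroid and centroids farther apart than `2 * delta_d` contribute nothing, so the loss only acts on ambiguous configurations.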
-
spatial_embedding_loss(sp_emb, groups, sp_centroids)
Compute the spatial centroid regression loss.
INPUTS:
sp_emb (N x D)
groups (N)
sp_centroids (N_c x D)
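A minimal sketch of a centroid regression loss, assuming it is the mean squared distance between each spatial embedding and its cluster's centroid, averaged per cluster so that large clusters do not dominate. The per-cluster averaging and the name `centroid_regression_sketch` are assumptions; `groups` is assumed relabeled to 0..N_c-1.

```python
import torch


def centroid_regression_sketch(sp_emb, groups, sp_centroids):
    """Hypothetical spatial loss: mean squared distance to the assigned centroid.

    sp_emb: (N, D), groups: (N,) int64 ids in [0, N_c), sp_centroids: (N_c, D).
    Each cluster contributes equally, regardless of its size.
    """
    n_c = sp_centroids.shape[0]
    # Squared Euclidean distance of every point to its own centroid.
    sq = ((sp_emb - sp_centroids[groups]) ** 2).sum(dim=1)
    counts = torch.bincount(groups, minlength=n_c).clamp(min=1).to(sp_emb.dtype)
    per_cluster = sp_emb.new_zeros(n_c).index_add_(0, groups, sq) / counts
    return per_cluster.mean()
```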
-
combine_multiclass(sp_embeddings, ft_embeddings, covariance, occupancy, slabels, clabels)
Wrapper function for combining the different components of the loss, in particular when clustering must be done per semantic class.
NOTE: When there are multiple semantic classes, we compute the DLoss by first masking the point cloud by each semantic class (ground truth or prediction) and then computing the clustering loss over each masked point cloud.
INPUTS:
sp_embeddings, ft_embeddings (torch.Tensor): pixel embeddings (spatial and feature)
slabels (torch.Tensor): semantic labels
clabels (torch.Tensor): group/instance/cluster labels
OUTPUT:
loss_segs (list) – list of computed loss values for each semantic class; loss_segs[i] is the computed DLoss for semantic class i.
acc_segs (list) – list of computed clustering accuracies for each semantic class.
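The masking strategy described in the NOTE can be sketched as a loop over the semantic classes present in the event. This is a minimal illustration assuming a generic `loss_fn(embeddings, cluster_labels)` callable; the function name is hypothetical and the accuracy list is not computed. Because it iterates over classes that actually occur, list position equals class id only when every class 0..C-1 is present.

```python
import torch


def per_class_loss_sketch(embeddings, slabels, clabels, loss_fn):
    """Hypothetical per-semantic-class wrapper.

    For each semantic class present in slabels (ascending order), mask the
    point cloud and apply loss_fn(masked_embeddings, masked_cluster_labels).
    """
    loss_segs = []
    for sc in torch.unique(slabels):  # torch.unique returns sorted values
        mask = slabels == sc
        loss_segs.append(loss_fn(embeddings[mask], clabels[mask]))
    return loss_segs
```

Masking first means that cluster ids from different semantic classes never compete in the same clustering loss.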
-
forward(out, segment_label, cluster_label)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
-
__module__ = 'mlreco.models.layers.cluster_cnn.losses.gs_embeddings'
-
training: bool
-
-
class mlreco.models.layers.cluster_cnn.losses.gs_embeddings.NodeEdgeHybridLoss(cfg, name='graph_spice_loss')
Bases: torch.nn.modules.loss._Loss
Combined node + edge loss.
-
__init__(cfg, name='graph_spice_loss')
Initializes internal Module state, shared by both nn.Module and ScriptModule.
-
__module__ = 'mlreco.models.layers.cluster_cnn.losses.gs_embeddings'
-
reduction: str
-
forward(result, segment_label, cluster_label)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
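"Combined node + edge loss" suggests a weighted sum of a per-node clustering loss and a per-edge classification loss. A minimal sketch, assuming the node loss arrives as a scalar tensor, edge targets are 1 when both endpoints share a cluster, and the edge term is a plain BCE; the weights `node_weight` and `edge_weight` and the function name are hypothetical.

```python
import torch
import torch.nn.functional as F


def hybrid_loss_sketch(node_loss, edge_logits, edge_truth,
                       node_weight=1.0, edge_weight=1.0):
    """Hypothetical combination of a node (embedding) loss and an edge loss.

    node_loss: scalar tensor from a clustering loss.
    edge_logits, edge_truth: (E,) predicted logits and 0/1 targets
    (assumed 1 = both endpoints belong to the same cluster).
    """
    edge_loss = F.binary_cross_entropy_with_logits(edge_logits, edge_truth.float())
    total = node_weight * node_loss + edge_weight * edge_loss
    # Return the components too, which is handy for logging.
    return {"loss": total, "node_loss": node_loss, "edge_loss": edge_loss}
```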
-