mlreco.models.graph_spice module

class mlreco.models.graph_spice.MinkGraphSPICE(cfg, name='graph_spice')

Bases: torch.nn.modules.module.Module

Neighbor-graph embedding based particle clustering.

GraphSPICE has two components:

1. Voxel Embedder: UNet-type CNN architecture used for feature extraction and feature embeddings.

2. Edge Probability Kernel function: A kernel function (any callable that takes two node attribute vectors and returns an edge probability score).
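Because the kernel can be any callable of two node attribute vectors, a bilinear form passed through a sigmoid is one simple choice. The kernel_cfg example in the Configuration section names a 'bilinear' kernel, but the NumPy version below, the helper name make_bilinear_kernel, and the use of num_features as the node vector length are illustrative assumptions, not the library's implementation.

```python
import numpy as np


def make_bilinear_kernel(num_features, seed=0):
    """Return a hypothetical kernel k(x, y) = sigmoid(x^T W y + b).

    W is random here for illustration; in a trained model W and b would be
    learned parameters. W is symmetrized so that k(x, y) == k(y, x), which
    suits an undirected neighbor graph.
    """
    rng = np.random.default_rng(seed)
    A = rng.normal(scale=1.0 / num_features, size=(num_features, num_features))
    W = 0.5 * (A + A.T)
    b = 0.0

    def kernel(x, y):
        logit = float(x @ W @ y) + b
        return float(1.0 / (1.0 + np.exp(-logit)))  # probability in (0, 1)

    return kernel


kernel = make_bilinear_kernel(num_features=32)
x = np.ones(32)
y = np.zeros(32)
print(kernel(x, y))  # x^T W 0 = 0, so this prints sigmoid(0) = 0.5
```

Symmetrizing W is a design choice made for this sketch; an unconstrained bilinear form would give different scores depending on edge direction.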

Prediction is done in three steps:

1. A neighbor graph (e.g. KNN or radius) is constructed, and the kernel computes an edge probability for each pair of neighboring voxels.

2. Edges with low probability scores are dropped.

3. The voxels are clustered by counting connected components.
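The prediction steps can be sketched end to end as follows. This is a minimal brute-force illustration, assuming a KNN graph, a hypothetical Gaussian kernel on embedding distance, and a union-find connected-components pass; the names knn_edges, connected_components, gaussian_kernel and cluster_voxels are hypothetical, and the threshold argument plays the role of edge_cut_threshold.

```python
import numpy as np


def knn_edges(coords, k):
    """Undirected k-nearest-neighbor edges (i, j) with i < j, brute force."""
    n = len(coords)
    k = min(k, n - 1)  # a voxel cannot have more than n - 1 neighbors
    d2 = ((coords[:, None, :] - coords[None, :, :]) ** 2).sum(axis=-1)
    np.fill_diagonal(d2, np.inf)  # exclude self-loops
    edges = set()
    for i in range(n):
        for j in np.argsort(d2[i])[:k]:
            j = int(j)
            edges.add((min(i, j), max(i, j)))
    return sorted(edges)


def connected_components(n, edges):
    """Label n nodes by connected component using union-find."""
    parent = list(range(n))

    def find(a):
        while parent[a] != a:
            parent[a] = parent[parent[a]]  # path halving
            a = parent[a]
        return a

    for i, j in edges:
        parent[find(i)] = find(j)
    roots = np.array([find(i) for i in range(n)])
    # Relabel the roots to consecutive integers 0, 1, 2, ...
    _, labels = np.unique(roots, return_inverse=True)
    return labels


def gaussian_kernel(a, b):
    """Hypothetical kernel: edge probability decays with embedding distance."""
    return float(np.exp(-np.sum((a - b) ** 2)))


def cluster_voxels(coords, embeddings, kernel, k=5, threshold=0.1):
    edges = knn_edges(coords, k)  # step 1: neighbor graph
    probs = [kernel(embeddings[i], embeddings[j]) for i, j in edges]  # step 1: edge probabilities
    kept = [e for e, p in zip(edges, probs) if p >= threshold]  # step 2: drop weak edges
    return connected_components(len(coords), kept)  # step 3: connected components


# Two well-separated blobs; the coordinates double as embeddings here.
coords = np.array([[0, 0, 0], [0.1, 0, 0], [0, 0.1, 0],
                   [10, 0, 0], [10.1, 0, 0], [10, 0.1, 0]], dtype=float)
labels = cluster_voxels(coords, coords, gaussian_kernel, k=3)
# Voxels 0-2 share one label and voxels 3-5 another: with k=3 the graph
# contains cross-blob edges, but their probabilities fall below 0.1.
```

With threshold=0.0 every edge survives and the same input collapses into a single cluster, which illustrates why a positive cut is needed at inference time. A production version would replace the O(n^2) distance matrix with a spatial index.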

Configuration
  • skip_classes (list, default [2, 3, 4]) – Semantic labels for which voxel clustering is skipped (e.g. Michel electrons, delta rays, and low-energy scatters rarely require neural-network clustering).

  • dimension (int, default 3) – Spatial dimension (2 or 3).

  • min_points (int, default 0) – If set to a value > 0, orphan assignment is applied to every predicted cluster whose voxel count is below min_points.

    Warning

    min_points is set to 0 at training time.

  • node_dim (int)

  • use_raw_features (bool)

  • constructor_cfg (dict) – Configuration for the ClusterGraphConstructor instance. A typical configuration:

    constructor_cfg:
      mode: 'knn'
      seg_col: -1
      cluster_col: 5
      edge_mode: 'attributes'
      hyper_dimension: 22
      edge_cut_threshold: 0.1
    

    Warning

    edge_cut_threshold is set to 0 at training time. At inference time, set it to a value > 0; as a rule of thumb, 0.1 is a good starting point, and the exact value can be optimized.

  • embedder_cfg (dict) – A typical configuration would look like:

    embedder_cfg:
      graph_spice_embedder:
        segmentationLayer: False
        feature_embedding_dim: 16
        spatial_embedding_dim: 3
        num_classes: 5
        occupancy_mode: 'softplus'
        covariance_mode: 'softplus'
      uresnet:
        filters: 32
        input_kernel: 5
        depth: 5
        reps: 2
        spatial_size: 768
        num_input: 4 # 1 feature + 3 normalized coords
        allow_bias: False
        activation:
          name: lrelu
          args:
            negative_slope: 0.33
        norm_layer:
          name: batch_norm
          args:
            eps: 0.0001
            momentum: 0.01
    
  • kernel_cfg (dict) – A typical configuration:

    kernel_cfg:
      name: 'bilinear'
      num_features: 32
    
  Warning

  Train-time and test-time configurations are slightly different for GraphSPICE (see min_points and edge_cut_threshold).
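The min_points option enables orphan assignment for small predicted clusters. The documentation does not say how orphans are reassigned, so the sketch below assumes one plausible rule: each voxel of an undersized cluster takes the label of its nearest voxel in a cluster that meets the size threshold. The function name assign_orphans is hypothetical.

```python
import numpy as np


def assign_orphans(coords, labels, min_points):
    """Reassign voxels of clusters smaller than min_points to the label of the
    nearest voxel belonging to a cluster with at least min_points voxels."""
    labels = np.asarray(labels).copy()
    if min_points <= 0:
        return labels  # orphan assignment disabled (training-time setting)
    ids, counts = np.unique(labels, return_counts=True)
    small = ids[counts < min_points]
    orphan = np.isin(labels, small)
    if orphan.all() or not orphan.any():
        return labels  # nothing to reassign to, or nothing to reassign
    anchors = np.flatnonzero(~orphan)  # voxels in large-enough clusters
    for i in np.flatnonzero(orphan):
        d2 = ((coords[anchors] - coords[i]) ** 2).sum(axis=1)
        labels[i] = labels[anchors[np.argmin(d2)]]
    return labels


coords = np.array([[0.0], [1.0], [2.0], [10.0], [11.0], [2.5]])
labels = [0, 0, 0, 1, 1, 2]
print(assign_orphans(coords, labels, min_points=2).tolist())  # [0, 0, 0, 1, 1, 0]
```

Anchors are fixed before any reassignment, so the result does not depend on the order in which orphan voxels are visited.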

Output
  • graph

  • graph_info

  • coordinates

  • batch_indices

  • hypergraph_features

See also

GraphSPICELoss

MODULES = ['constructor_cfg', 'embedder_cfg', 'kernel_cfg', 'gspice_fragment_manager']
__init__(cfg, name='graph_spice')

Initializes internal Module state, shared by both nn.Module and ScriptModule.

weight_initialization()
filter_class(input)

Filter classes according to segmentation label.
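A minimal sketch of the class filtering, assuming the semantic label sits in the last column of a NumPy voxel array (as the seg_col: -1 entry in the constructor_cfg example suggests). The real method operates on the model's input tensor, so the array layout and return values here are assumptions.

```python
import numpy as np


def filter_class(voxels, skip_classes=(2, 3, 4), seg_col=-1):
    """Split off voxels whose semantic label is in skip_classes.

    Returns the rows to cluster and the boolean mask used to select them.
    """
    semantic = voxels[:, seg_col]
    keep = ~np.isin(semantic, skip_classes)
    return voxels[keep], keep


# Columns (assumed): x, y, z, value, semantic label
voxels = np.array([
    [0, 0, 0, 1.0, 0],  # class 0
    [1, 0, 0, 1.0, 1],  # class 1
    [2, 0, 0, 1.0, 2],  # Michel
    [3, 0, 0, 1.0, 3],  # delta
    [4, 0, 0, 1.0, 4],  # low-energy scatter
])
kept, mask = filter_class(voxels)
print(mask.tolist())  # [True, True, False, False, False]
```

Returning the mask alongside the filtered rows makes it easy to scatter cluster predictions back onto the full voxel set afterwards.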

forward(input)
__module__ = 'mlreco.models.graph_spice'
training: bool
class mlreco.models.graph_spice.GraphSPICELoss(cfg, name='graph_spice_loss')

Bases: torch.nn.modules.module.Module

Loss function for GraphSPICE.

Configuration
  • name (str, default ‘se_lovasz_inter’) – Loss function to use.

  • invert (bool, default True) – Leave this set to True for statistical weighting purposes.

  • kernel_lossfn (str)

  • edge_loss_cfg (dict) – For example:

    edge_loss_cfg:
      loss_type: 'LogDice'
    
  • eval (bool, default False) – Whether we are in inference mode or not.

    Warning

    Currently, you must manually set eval to True to run inference, because there is apparently no way to tell from within the loss function whether it is being called during training or inference.
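The edge loss compares predicted edge probabilities against binary edge labels. The sketch below assumes an edge is labeled positive when both endpoints share a cluster label (the cluster_col: 5 entry in constructor_cfg suggests where that label lives), and uses one common soft-Dice formulation, loss = -log(Dice). Whether the library's LogDice matches this exactly (smoothing term, reduction) is an assumption; edge_truth and log_dice_loss are hypothetical names.

```python
import numpy as np


def edge_truth(edges, cluster_labels):
    """1.0 if both endpoints of an edge share a cluster label, else 0.0."""
    return np.array([float(cluster_labels[i] == cluster_labels[j]) for i, j in edges])


def log_dice_loss(probs, truth, eps=1e-6):
    """-log of the soft Dice coefficient between predicted and true edges."""
    probs = np.asarray(probs, dtype=float)
    truth = np.asarray(truth, dtype=float)
    dice = (2.0 * (probs * truth).sum() + eps) / (probs.sum() + truth.sum() + eps)
    return -np.log(dice)


edges = [(0, 1), (1, 2), (2, 3)]
truth = edge_truth(edges, cluster_labels=[7, 7, 7, 9])  # [1.0, 1.0, 0.0]
perfect = log_dice_loss(truth, truth)             # zero for a perfect prediction
noisy = log_dice_loss([0.9, 0.6, 0.4], truth)     # positive for an imperfect one
```

The eps term keeps the ratio defined when a graph has no positive edges and no predicted probability mass.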

Output

To be completed.

See also

MinkGraphSPICE

__init__(cfg, name='graph_spice_loss')

Initializes internal Module state, shared by both nn.Module and ScriptModule.

filter_class(segment_label, cluster_label)

Filter classes according to segmentation label.

forward(result, segment_label, cluster_label)
__module__ = 'mlreco.models.graph_spice'
training: bool