mlreco.models.grappa module

class mlreco.models.grappa.GNN(cfg, name='grappa', batch_col=0, coords_col=(1, 4))

Bases: torch.nn.modules.module.Module

Driver class for cluster node+edge prediction, assumed to be a GNN model.

This class mostly acts as a wrapper that hands the graph data to another model. If DBSCAN is used to form the clusters, pass the semantic label tensor as input.

A typical configuration looks like this:

model:
  name: grappa
  modules:
    grappa:
      your config goes here
Configuration
  • base (dict) – Configuration of the base Grappa model:

    base:
      source_col      : <column in the input data that specifies the source node ids of each voxel (default 5)>
      target_col      : <column in the input data that specifies the target instance ids of each voxel (default 6)>
      node_type       : <semantic class to aggregate (all classes if -1, default -1)>
      node_min_size   : <minimum number of voxels inside a cluster to be included in the aggregation (default -1, i.e. no threshold)>
      add_points      : <add label point(s) to the node features: False (none) or True (both) (default False)>
      add_local_dirs  : <add reconstructed local direction(s) to the node features: False (none), True (both) or 'start' (default False)>
      dir_max_dist    : <maximum distance between start point and cluster voxels to be used to estimate direction: a numerical value or 'optimize' (default 5 voxels)>
      add_local_dedxs : <add reconstructed local dedx(s) to the node features: False (none), True (both) or 'start' (default False)>
      dedx_max_dist   : <maximum distance between start point and cluster voxels to be used to estimate dedx (default 5 voxels)>
      network         : <type of network: 'complete', 'delaunay', 'mst', 'knn' or 'bipartite' (default 'complete')>
      edge_max_dist   : <maximum Euclidean edge length (default -1)>
      edge_dist_method: <edge length evaluation method: 'centroid' or 'voxel' (default 'voxel')>
      merge_batch     : <flag for whether to merge batches (default False)>
      merge_batch_mode: <mode of batch merging, 'const' or 'fluc': 'const' merges a fixed number of batches, 'fluc' treats merge_batch_size as a mean and samples the number of merged batches around it (default 'const')>
      merge_batch_size: <size of batch merging (default 2)>
      shuffle_clusters: <randomize cluster order (default False)>
    
  • dbscan (dict) – Dictionary of DBSCAN parameters.

  • node_encoder (dict) – Configuration of the node encoder:

    node_encoder:
      name: <name of the node encoder>
      <dictionary of arguments to pass to the encoder>
      model_path      : <path to the encoder weights>
    
  • edge_encoder (dict) – Configuration of the edge encoder:

    edge_encoder:
      name: <name of the edge encoder>
      <dictionary of arguments to pass to the encoder>
      model_path      : <path to the encoder weights>
    
  • gnn_model (dict) – Configuration of the GNN model:

    gnn_model:
      name: <name of the GNN model>
      <dictionary of arguments to pass to the model>
      model_path      : <path to the model weights>
    
  • kinematics_mlp (bool, default False) – Whether to enable MLP-like layers after the GNN to predict momentum, particle type, etc.

  • kinematics_type (bool) – Whether to add a particle identification (PID) MLP to each node.

  • kinematics_momentum (bool) – Whether to add a momentum MLP to each node.

  • type_net (dict) – Configuration for the PID MLP (if enabled). Weights can also be partially loaded here.

  • momentum_net (dict) – Configuration for the momentum MLP (if enabled). Weights can also be partially loaded here.

  • vertex_mlp (bool, default False) – Whether to add a vertex prediction MLP to each node. It predicts both whether each particle is primary and the vertex coordinates.

  • vertex_net (dict) – Configuration for the vertex MLP (if enabled). Weights can also be partially loaded here.
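The cluster-selection (`node_type`, `node_min_size`) and batch-merging (`merge_batch_mode`, `merge_batch_size`) options of `base` can be illustrated with a minimal NumPy sketch. It is not the mlreco implementation: `form_clusters` and `merge_batch_ids` are hypothetical helpers, the semantic class is assumed to sit in the last column, a cluster is assumed to be identified by its (batch id, cluster id) pair, and the 'fluc' mode is assumed to draw each merged-batch size from a Poisson distribution whose mean is `merge_batch_size`.

```python
import numpy as np


def form_clusters(data, batch_col=0, source_col=5, type_col=-1,
                  node_type=-1, node_min_size=-1):
    """Group voxel row indices into clusters (hypothetical helper).

    Assumes a cluster is identified by its (batch id, cluster id) pair and
    that the semantic class of each voxel is stored in column `type_col`.
    """
    mask = np.ones(len(data), dtype=bool)
    if node_type > -1:
        # node_type = -1 keeps every class; otherwise keep a single class
        mask &= data[:, type_col] == node_type
    if not mask.any():
        return []
    clusts = []
    for b, cid in np.unique(data[mask][:, [batch_col, source_col]], axis=0):
        idx = np.where(mask & (data[:, batch_col] == b)
                       & (data[:, source_col] == cid))[0]
        if node_min_size > -1 and len(idx) < node_min_size:
            continue  # too few voxels: leave this cluster out
        clusts.append(idx)
    return clusts


def merge_batch_ids(batch_ids, merge_size=2, mode='const', rng=None):
    """Map each original batch id to a merged batch id (hypothetical helper).

    'const': consecutive groups of exactly `merge_size` batches (last may be smaller).
    'fluc' : group sizes sampled around `merge_size`; the Poisson draw is an assumption.
    """
    batch_ids = np.asarray(batch_ids)
    unique = np.unique(batch_ids)
    if mode == 'const':
        sizes = [merge_size] * int(np.ceil(len(unique) / merge_size))
    elif mode == 'fluc':
        rng = np.random.default_rng() if rng is None else rng
        sizes, total = [], 0
        while total < len(unique):
            size = max(1, int(rng.poisson(merge_size)))  # never an empty group
            sizes.append(size)
            total += size
    else:
        raise ValueError(f"unknown merge_batch_mode: {mode!r}")
    # The i-th run of consecutive original batches becomes merged batch i
    group_of = np.repeat(np.arange(len(sizes)), np.asarray(sizes, dtype=int))[:len(unique)]
    lookup = dict(zip(unique.tolist(), group_of.tolist()))
    return np.array([lookup[b] for b in batch_ids.tolist()])


print(merge_batch_ids([0, 0, 1, 2, 3, 4], merge_size=2))  # → [0 0 0 1 1 2]
```

With 'const', every merged batch except possibly the last contains the same number of original batches; with 'fluc', the number varies from one merged batch to the next.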

Outputs

  • input_node_features

  • input_edge_features

  • clusts

  • edge_index

  • node_pred

  • edge_pred

  • node_pred_p

  • node_pred_type

  • node_pred_vtx

See also

GNNLoss

MODULES = [('grappa', ['base', 'dbscan', 'node_encoder', 'edge_encoder', 'gnn_model']), 'grappa_loss']
__init__(cfg, name='grappa', batch_col=0, coords_col=(1, 4))

Initializes internal Module state, shared by both nn.Module and ScriptModule.

__module__ = 'mlreco.models.grappa'
training: bool
forward(data, clusts=None, groups=None, points=None, extra_feats=None, batch_size=None)

Prepares particle clusters and feeds them to the GNN model.

Parameters
  • data (array) –

    data[0] ([torch.tensor]): (N,5-10) [x, y, z, batch_id(, value), part_id(, group_id, int_id, nu_id, sem_type)]

    or (N,5) [x, y, z, batch_id, sem_type] (with DBSCAN)

    data[1] ([torch.tensor]): (N,8) [first_x, first_y, first_z, batch_id, last_x, last_y, last_z, first_step_t] (optional)

  • clusts – [(N_0), (N_1), …, (N_C)] Cluster ids (optional)

  • groups – (C) vector of group IDs (one per cluster) used to enforce connections only within each group (optional)

  • points – (N,3/6) tensor of start (and end) points of clusters

  • extra_feats – (N,F) tensor of features to add to the encoded features
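How the `network: 'complete'` option interacts with `edge_max_dist`, `edge_dist_method` and the `groups` argument can be sketched in NumPy. The helper `build_edges` is hypothetical, not mlreco's graph builder; it assumes that a non-positive `edge_max_dist` disables the length cut, that 'voxel' measures the closest pair of voxels between two clusters, and that 'centroid' compares cluster centroids.

```python
import numpy as np


def complete_graph(n_nodes):
    """Every undirected node pair (i, j) with i < j, as an (E, 2) array."""
    i, j = np.triu_indices(n_nodes, k=1)
    return np.stack([i, j], axis=1)


def cluster_distance(coords, ca, cb, method='voxel'):
    """Distance between two clusters given their voxel index arrays."""
    if method == 'centroid':
        # Distance between the two cluster centroids
        return float(np.linalg.norm(coords[ca].mean(axis=0) - coords[cb].mean(axis=0)))
    # 'voxel': distance between the closest pair of voxels (assumed meaning)
    diff = coords[ca][:, None, :] - coords[cb][None, :, :]
    return float(np.sqrt((diff ** 2).sum(axis=-1)).min())


def build_edges(coords, clusts, groups=None, edge_max_dist=-1, method='voxel'):
    """Hypothetical 'complete' graph builder with optional group and length cuts."""
    edges = complete_graph(len(clusts))
    if groups is not None:
        groups = np.asarray(groups)
        # Keep only edges whose two clusters share a group ID
        edges = edges[groups[edges[:, 0]] == groups[edges[:, 1]]]
    if edge_max_dist > 0:  # assumption: a non-positive value disables the cut
        dists = np.array([cluster_distance(coords, clusts[a], clusts[b], method)
                          for a, b in edges])
        edges = edges[dists <= edge_max_dist]
    return edges


coords = np.array([[0, 0, 0], [1, 0, 0], [10, 0, 0], [11, 0, 0], [3, 0, 0]], dtype=float)
clusts = [np.array([0, 1]), np.array([2, 3]), np.array([4])]
print(build_edges(coords, clusts, edge_max_dist=2.2).tolist())                     # → [[0, 2]]
print(build_edges(coords, clusts, edge_max_dist=2.2, method='centroid').tolist())  # → []
```

The two methods can disagree: clusters 0 and 2 have voxels 2 apart but centroids 2.5 apart, so the same cut keeps the edge under 'voxel' and drops it under 'centroid'.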

Returns

  • 'node_pred' (torch.tensor) – (N,2) two-channel node predictions (split batch-wise)

  • 'edge_pred' (torch.tensor) – (E,2) two-channel edge predictions (split batch-wise)

  • 'clusts' ([np.ndarray]) – [(N_0), (N_1), …, (N_C)] cluster ids (split batch-wise)

  • 'edge_index' (np.ndarray) – (E,2) edge index, i.e. the pair of node indices connected by each edge (split batch-wise)

Return type

dict
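The "split batch-wise" outputs can be pictured with a small NumPy sketch that regroups flat predictions by the batch each cluster belongs to. The helper `split_predictions` and its `clust_batch` argument (one batch id per cluster) are hypothetical. The sketch assumes edges never cross batches, assigns each edge to the batch of its first node, and re-bases `edge_index` to batch-local node indices; mlreco's actual output layout may differ.

```python
import numpy as np


def split_predictions(node_pred, edge_pred, edge_index, clust_batch):
    """Regroup flat GNN outputs into per-batch lists (hypothetical helper).

    clust_batch holds one batch id per cluster (node). Edges are assumed to
    connect nodes of the same batch; each edge takes the batch of its first node.
    """
    clust_batch = np.asarray(clust_batch)
    edge_batch = clust_batch[edge_index[:, 0]]
    nodes, edges, indices = [], [], []
    for b in np.unique(clust_batch):
        node_mask, edge_mask = clust_batch == b, edge_batch == b
        # Batch-local index of each node in batch b (re-basing is a choice of this sketch)
        local = np.cumsum(node_mask) - 1
        nodes.append(node_pred[node_mask])
        edges.append(edge_pred[edge_mask])
        indices.append(local[edge_index[edge_mask]])
    return nodes, edges, indices


node_pred = np.arange(10, dtype=float).reshape(5, 2)
edge_index = np.array([[0, 1], [2, 3], [3, 4]])
edge_pred = np.array([[0.1, 0.9], [0.2, 0.8], [0.3, 0.7]])
nodes, edges, indices = split_predictions(node_pred, edge_pred, edge_index, [0, 0, 1, 1, 1])
print([idx.tolist() for idx in indices])  # → [[[0, 1]], [[0, 1], [1, 2]]]
```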

class mlreco.models.grappa.GNNLoss(cfg, name='grappa_loss', batch_col=0, coords_col=(1, 4))

Bases: torch.nn.modules.loss._Loss

Takes the output of the GNN and computes the total loss.

For use in config:

model:
  name: grappa
  modules:
    grappa_loss:
      node_loss:
        name: <name of the node loss>
        <dictionary of arguments to pass to the loss>
      edge_loss:
        name: <name of the edge loss>
        <dictionary of arguments to pass to the loss>
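A minimal sketch of how `GNNLoss` could combine its two terms, assuming two-channel (N,2) and (E,2) logits as returned by `GNN.forward`, softmax cross-entropy for both `node_loss` and `edge_loss`, and a plain sum as the total. The actual losses are selected by `name` in the configuration and may differ; `cross_entropy` and `total_loss` are hypothetical NumPy helpers, not mlreco functions.

```python
import numpy as np


def cross_entropy(logits, labels):
    """Mean softmax cross-entropy for (M, C) logits and (M,) integer labels."""
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(len(labels)), labels].mean())


def total_loss(node_pred, node_label, edge_pred, edge_label):
    """Hypothetical stand-in for GNNLoss: sum of a node term and an edge term."""
    node_loss = cross_entropy(node_pred, np.asarray(node_label))
    edge_loss = cross_entropy(edge_pred, np.asarray(edge_label))
    return {'node_loss': node_loss, 'edge_loss': edge_loss,
            'loss': node_loss + edge_loss}


# Uninformative (all-zero) logits give log(2) per term, so the total is 2*log(2)
out = total_loss(np.zeros((3, 2)), [0, 1, 0], np.zeros((2, 2)), [1, 1])
print(round(out['loss'], 4))  # → 1.3863
```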
__module__ = 'mlreco.models.grappa'
reduction: str
__init__(cfg, name='grappa_loss', batch_col=0, coords_col=(1, 4))

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(result, clust_label, graph=None, node_label=None, iteration=None)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.