mlreco.models.layers.gnn.losses.node_type module
class mlreco.models.layers.gnn.losses.node_type.NodeTypeLoss(loss_config, batch_col=0, coords_col=(1, 4))
Bases: torch.nn.modules.module.Module

Takes the c-channel node output of the GNN and optimizes node-wise scores such that the score corresponding to the correct class is maximized.
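Maximizing the correct-class score is usually done by minimizing a softmax cross-entropy over each node's c scores. The following is a minimal, framework-free sketch of that objective, assuming the default 'CE' loss behaves like a standard cross-entropy. It also assumes that `balance_classes` applies inverse-frequency class weights, which is a guess and not taken from the module's source. `node_cross_entropy` and `balanced_class_weights` are hypothetical names. Plain Python lists stand in for the torch tensors the real module uses, and the 'MM' option is not sketched.

```python
import math
from collections import Counter


def balanced_class_weights(targets, num_classes):
    """Hypothetical inverse-frequency weights: w_k = N / (K * n_k).

    N is the number of nodes, K the number of classes present in `targets`,
    n_k the count of class k. Classes absent from `targets` get weight 0.
    """
    counts = Counter(targets)
    n, k = len(targets), len(counts)
    return [n / (k * counts[c]) if counts[c] else 0.0 for c in range(num_classes)]


def node_cross_entropy(scores, targets, reduction="sum", class_weights=None):
    """Softmax cross-entropy over per-node class scores (illustrative sketch).

    scores:  C rows, one per node, each holding c raw class scores
    targets: C integer class labels in [0, c)
    """
    total, weight_sum = 0.0, 0.0
    for row, t in zip(scores, targets):
        # -log softmax(row)[t], with max subtraction for numerical stability
        m = max(row)
        log_norm = m + math.log(sum(math.exp(s - m) for s in row))
        w = 1.0 if class_weights is None else class_weights[t]
        total += w * (log_norm - row[t])
        weight_sum += w
    if reduction == "sum":
        return total
    if reduction == "mean":
        # Weighted mean: divide by the summed weights of the targets
        return total / weight_sum
    raise ValueError(f"unknown reduction: {reduction!r}")


scores = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0], [1.0, 1.0, 0.0]]
targets = [0, 2, 0]
print(node_cross_entropy(scores, targets))
print(node_cross_entropy(scores, targets, "mean", balanced_class_weights(targets, 3)))
```

Raising a node's score for its target class strictly lowers its term in the loss, which is the sense in which minimizing the loss maximizes the correct-class score. In PyTorch the same computation is available as `torch.nn.CrossEntropyLoss`, whose `weight` and `reduction` arguments correspond to the optional weights and reduction above.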
For use in config:

    model:
      name: cluster_gnn
      modules:
        grappa_loss:
          node_loss:
            name            : type
            batch_col       : <column in the label data that specifies the batch ids of each voxel (default 3)>
            target_col      : <column in the label data that specifies the target node class of each voxel (default 7)>
            loss            : <loss function: 'CE' or 'MM' (default 'CE')>
            reduction       : <loss reduction method: 'mean' or 'sum' (default 'sum')>
            balance_classes : <balance loss per class: True or False (default False)>
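Once parsed, the config block above becomes a nested dictionary. The sketch below shows one way the documented defaults could be layered under a user-supplied `node_loss` section. `NODE_LOSS_DEFAULTS` and `resolve_node_loss_config` are hypothetical names, and the nesting of the dictionary is assumed from the docstring. The default values are copied from the docstring text, not from the module's source, so the real class may resolve them differently; for instance, the constructor signature itself shows `batch_col=0`.

```python
# Defaults as listed in the NodeTypeLoss docstring (not read from the module)
NODE_LOSS_DEFAULTS = {
    "batch_col": 3,
    "target_col": 7,
    "loss": "CE",
    "reduction": "sum",
    "balance_classes": False,
}


def resolve_node_loss_config(loss_config):
    """Hypothetical helper: overlay user keys on the defaults and validate them."""
    cfg = {**NODE_LOSS_DEFAULTS, **loss_config}
    if cfg["loss"] not in ("CE", "MM"):
        raise ValueError(f"loss must be 'CE' or 'MM', got {cfg['loss']!r}")
    if cfg["reduction"] not in ("mean", "sum"):
        raise ValueError(f"reduction must be 'mean' or 'sum', got {cfg['reduction']!r}")
    return cfg


# Dictionary form of the YAML config (nesting assumed from the docstring)
config = {
    "model": {
        "name": "cluster_gnn",
        "modules": {
            "grappa_loss": {
                "node_loss": {"name": "type", "reduction": "mean", "balance_classes": True},
            },
        },
    },
}

node_loss_cfg = resolve_node_loss_config(
    config["model"]["modules"]["grappa_loss"]["node_loss"]
)
print(node_loss_cfg)
```

Keys the user omits (here `loss`, `batch_col`, `target_col`) fall back to the documented defaults. An unsupported `loss` or `reduction` value fails early instead of surfacing later inside the training loop.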
__init__(loss_config, batch_col=0, coords_col=(1, 4))
Initializes internal Module state, shared by both nn.Module and ScriptModule.
__module__ = 'mlreco.models.layers.gnn.losses.node_type'
training: bool
forward(out, types)
Applies the requested loss on the node predictions.
Parameters
    out (dict) –
        'node_pred' (torch.tensor): (C, c) c-channel node predictions, one row per cluster
        'clusts' ([np.ndarray]): [(N_0), (N_1), …, (N_{C-1})] voxel indices that make up each cluster
    types ([torch.tensor]) – (N, 8) [x, y, z, batchid, value, id, groupid, pdg]
Returns
    loss, accuracy, clustering metrics
Return type
    double
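The `out` and `types` arguments can be made concrete with a pure-Python sketch that turns voxel-level labels into one target per node and then scores the predictions. It assumes two rules, neither confirmed by the module's source. First, each cluster takes the majority value of its voxels' `target_col` entries. Second, accuracy is the fraction of nodes whose highest-scoring class matches that target. `cluster_targets` and `node_accuracy` are hypothetical helpers, and plain lists stand in for the tensors and arrays that the real `forward` receives.

```python
from collections import Counter


def cluster_targets(types, clusts, target_col=7):
    """Give each cluster (GNN node) one class label from its voxels.

    Assumed rule: majority vote over the target column of the cluster's voxels.
    """
    targets = []
    for clust in clusts:
        values = [int(types[i][target_col]) for i in clust]
        targets.append(Counter(values).most_common(1)[0][0])
    return targets


def node_accuracy(node_pred, targets):
    """Fraction of nodes whose highest-scoring class equals the target."""
    if not targets:
        return 1.0  # convention assumed here for an event with no nodes
    hits = sum(
        max(range(len(row)), key=row.__getitem__) == t
        for row, t in zip(node_pred, targets)
    )
    return hits / len(targets)


# Voxel labels in the documented (N, 8) layout:
# [x, y, z, batchid, value, id, groupid, pdg]; column 7 holds a class index here
types = [
    [0, 0, 0, 0, 1.0, 0, 0, 2],
    [1, 0, 0, 0, 1.0, 1, 0, 2],
    [2, 0, 0, 0, 1.0, 2, 0, 1],
    [5, 5, 5, 0, 1.0, 3, 1, 4],
    [6, 5, 5, 0, 1.0, 4, 1, 4],
]
clusts = [[0, 1, 2], [3, 4]]  # voxel indices of each cluster, one cluster per node
node_pred = [[0.1, 0.2, 2.0, 0.0, 0.0], [0.0, 0.0, 0.0, 3.0, 1.0]]

targets = cluster_targets(types, clusts)
print(targets, node_accuracy(node_pred, targets))  # [2, 4] 0.5
```

The first node is predicted correctly (class 2). The second node scores class 3 highest while its target is 4, which gives an accuracy of 0.5. These node-level targets are what a per-node loss such as cross-entropy would then be computed against.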