mlreco.models.layers.gnn.losses.edge_channel module
class mlreco.models.layers.gnn.losses.edge_channel.EdgeChannelLoss(loss_config, batch_col=0, coords_col=(1, 4))
Bases: torch.nn.modules.module.Module
Takes the two-channel edge output of the GNN and optimizes edge-wise scores so that edges connecting nodes that belong to a common instance receive a high score.
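One way to read "optimizes edge-wise scores" is as a two-class cross-entropy over the (E,2) edge logits, where channel 1 stands for "both endpoints belong to the same instance". The sketch below assumes that reading of the default 'CE' option and is written in plain Python rather than PyTorch so it has no dependencies; `two_channel_ce` is a hypothetical name, not part of mlreco.

```python
import math
from typing import List, Sequence


def two_channel_ce(logits: Sequence[Sequence[float]], labels: Sequence[int],
                   reduction: str = "sum") -> float:
    """Cross-entropy over (E, 2) edge logits; label 1 means 'same instance'."""
    losses: List[float] = []
    for (off, on), y in zip(logits, labels):
        # log-sum-exp over the two channels, shifted by the max for stability
        m = max(off, on)
        log_norm = m + math.log(math.exp(off - m) + math.exp(on - m))
        chosen = on if y == 1 else off
        losses.append(log_norm - chosen)  # equals -log softmax(logits)[y]
    if reduction == "sum":
        return sum(losses)
    if reduction == "mean":
        return sum(losses) / len(losses) if losses else 0.0
    raise ValueError(f"unknown reduction: {reduction!r}")


# Two edges: the first is scored as 'on', the second as 'off'
edge_logits = [[-2.0, 3.0], [1.5, -1.0]]
edge_labels = [1, 0]
print(round(two_channel_ce(edge_logits, edge_labels, "mean"), 4))
```

Edges whose higher logit agrees with their label contribute a small loss; flipping the labels makes the loss grow sharply, which is what pushes same-instance edges toward a high channel-1 score.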
For use in config:
    model:
      name: cluster_gnn
      modules:
        grappa_loss:
          edge_loss:
            name: channel
            source_col: <column in the label data that specifies the source node ids of each voxel (default 5)>
            target_col: <column in the label data that specifies the target group ids of each voxel (default 6)>
            batch_col: <column in the label data that specifies the batch ids of each voxel (default 3)>
            loss: <loss function: 'CE' or 'MM' (default 'CE')>
            reduction: <loss reduction method: 'mean' or 'sum' (default 'sum')>
            balance_classes: <balance loss per class: True or False (default False)>
            target: <type of target adjacency matrix: 'group', 'forest', 'particle_forest' (default 'group')>
            high_purity: <only penalize loss on groups with a primary (default False)>
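The edge_loss keys above can be thought of as user settings layered over the documented defaults. The sketch below assumes that behavior; `EDGE_LOSS_DEFAULTS`, `ALLOWED` and `resolve_edge_loss_config` are hypothetical names, and the actual `EdgeChannelLoss.__init__` may parse `loss_config` differently.

```python
from typing import Any, Dict

# Defaults copied from the documented edge_loss block; the helper itself is hypothetical.
EDGE_LOSS_DEFAULTS: Dict[str, Any] = {
    "source_col": 5,
    "target_col": 6,
    "batch_col": 3,
    "loss": "CE",
    "reduction": "sum",
    "balance_classes": False,
    "target": "group",
    "high_purity": False,
}

# Values the documentation lists for the enum-like keys
ALLOWED = {
    "loss": {"CE", "MM"},
    "reduction": {"mean", "sum"},
    "target": {"group", "forest", "particle_forest"},
}


def resolve_edge_loss_config(user_cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Overlay user settings on the documented defaults and check enum-valued keys."""
    cfg = {**EDGE_LOSS_DEFAULTS, **user_cfg}
    for key, choices in ALLOWED.items():
        if cfg[key] not in choices:
            raise ValueError(f"{key}={cfg[key]!r} not in {sorted(choices)}")
    return cfg


cfg = resolve_edge_loss_config({"name": "channel", "reduction": "mean"})
print(cfg["reduction"], cfg["loss"], cfg["target"])
```

Unspecified keys fall back to the documented defaults, so a minimal block containing only `name: channel` would train with summed cross-entropy on the 'group' target.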
__init__(loss_config, batch_col=0, coords_col=(1, 4))
Initializes internal Module state, shared by both nn.Module and ScriptModule.
__module__ = 'mlreco.models.layers.gnn.losses.edge_channel'
training: bool
forward(out, clusters, graph=None)
Applies the requested loss on the edge prediction.
Parameters
  out (dict) –
    ‘edge_pred’ (torch.tensor): (E,2) Two-channel edge predictions
    ‘clusts’ ([np.ndarray]): [(N_0), (N_1), …, (N_C)] Cluster ids
    ‘edge_index’ (np.ndarray): (E,2) Incidence matrix
  clusters ([torch.tensor]) – (N,8) [x, y, z, batchid, value, id, groupid, shape]
  graph ([torch.tensor], optional) – (N,3) True edges
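For the default 'group' target, the inputs above suffice to derive a binary label per edge. The sketch below assumes each cluster in `clusts` takes the majority group id of its voxels (read from the groupid column, index 6 in the layout listed for `clusters`) and that an edge is labeled 1 when both endpoint clusters share that id; `cluster_group_ids` and `group_edge_labels` are hypothetical helpers, and the 'forest' and 'particle_forest' targets are not covered.

```python
from collections import Counter
from typing import List, Sequence, Tuple


def cluster_group_ids(voxels: Sequence[Sequence[float]],
                      clusts: Sequence[Sequence[int]],
                      group_col: int = 6) -> List[int]:
    """Assign each cluster the majority group id of its voxels (an assumption)."""
    ids = []
    for clust in clusts:
        votes = Counter(int(voxels[v][group_col]) for v in clust)
        ids.append(votes.most_common(1)[0][0])
    return ids


def group_edge_labels(edge_index: Sequence[Tuple[int, int]],
                      group_ids: Sequence[int]) -> List[int]:
    """Label an edge 1 when both endpoint clusters share a group id, else 0."""
    return [int(group_ids[i] == group_ids[j]) for i, j in edge_index]


# Toy (N, 8) rows laid out as [x, y, z, batchid, value, id, groupid, shape]
voxels = [
    [0, 0, 0, 0, 1.0, 0, 7, 0],
    [1, 0, 0, 0, 1.0, 0, 7, 0],
    [5, 5, 5, 0, 1.0, 1, 7, 0],
    [9, 9, 9, 0, 1.0, 2, 8, 0],
]
clusts = [[0, 1], [2], [3]]  # voxel indices belonging to each cluster
edge_index = [(0, 1), (0, 2), (1, 2)]
groups = cluster_group_ids(voxels, clusts)
print(groups, group_edge_labels(edge_index, groups))
```

Here clusters 0 and 1 share group 7, so only the edge (0, 1) is an "on" edge; both edges touching cluster 2 (group 8) are "off".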
Returns
  loss, accuracy, edge count
Return type
  double
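The accuracy and edge count returned alongside the loss can be understood as follows. This sketch assumes accuracy is the fraction of edges whose higher-scoring channel matches the label, and treats an empty edge set as fully correct; both choices, and the name `edge_accuracy`, are assumptions rather than the module's confirmed behavior.

```python
from typing import Sequence, Tuple


def edge_accuracy(logits: Sequence[Sequence[float]],
                  labels: Sequence[int]) -> Tuple[float, int]:
    """Fraction of edges whose argmax channel matches the label, plus the edge count."""
    n_edges = len(labels)
    if n_edges == 0:
        return 1.0, 0  # assumption: an empty edge set counts as fully correct
    # Predict channel 1 only when its logit is strictly larger; ties go to channel 0,
    # mirroring an argmax that returns the first maximum.
    correct = sum(int((on > off) == bool(y)) for (off, on), y in zip(logits, labels))
    return correct / n_edges, n_edges


acc, n = edge_accuracy([[-2.0, 3.0], [1.5, -1.0], [0.2, 0.1]], [1, 0, 1])
print(acc, n)
```

The third edge is labeled "on" but scored slightly "off", so two of the three edges are classified correctly.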