Source code for mlreco.models.grappa

import random
import torch
import numpy as np

from mlreco.models.layers.common.dbscan import DBSCANFragmenter
from mlreco.models.layers.common.momentum import DeepVertexNet, EvidentialMomentumNet, MomentumNet, VertexNet
from mlreco.models.experimental.transformers.transformer import TransformerEncoderLayer
from mlreco.models.layers.gnn import gnn_model_construct, node_encoder_construct, edge_encoder_construct, node_loss_construct, edge_loss_construct

from mlreco.utils.gnn.data import merge_batch, split_clusts, split_edge_index
from mlreco.utils.gnn.cluster import form_clusters, get_cluster_batch, get_cluster_label, get_cluster_primary_label, get_cluster_points_label, get_cluster_directions, get_cluster_dedxs
from mlreco.utils.gnn.network import complete_graph, delaunay_graph, mst_graph, bipartite_graph, inter_cluster_distance, knn_graph, restrict_graph

class GNN(torch.nn.Module):
    """
    Driver class for cluster node+edge prediction, assumed to be a GNN model.

    This class mostly acts as a wrapper that will hand the graph data to
    another model. If DBSCAN is used, use the semantic label tensor as an input.

    Typical configuration can look like this:

    .. code-block:: yaml

        model:
          name: grappa
          modules:
            grappa:
              your config goes here

    Configuration
    -------------
    base: dict
        Configuration of base Grappa :

        .. code-block:: yaml

          base:
            source_col      : <column in the input data that specifies the source node ids of each voxel (default 5)>
            target_col      : <column in the input data that specifies the target instance ids of each voxel (default 6)>
            node_type       : <semantic class to aggregate (all classes if -1, default -1)>
            node_min_size   : <minimum number of voxels inside a cluster to be included in the aggregation (default -1, i.e. no threshold)>
            add_points      : <add label point(s) to the node features: False (none) or True (both) (default False)>
            add_local_dirs  : <add reconstructed local direction(s) to the node features: False (none), True (both) or 'start' (default False)>
            dir_max_dist    : <maximum distance between start point and cluster voxels to be used to estimate direction: support value or 'optimize' (default 5 voxels)>
            add_local_dedxs : <add reconstructed local dedx(s) to the node features: False (none), True (both) or 'start' (default False)>
            dedx_max_dist   : <maximum distance between start point and cluster voxels to be used to estimate dedx (default 5 voxels)>
            break_clusters  : <split each formed cluster into its DBSCAN-connected pieces (default False)>
            network         : <type of network: 'complete', 'delaunay', 'mst', 'knn' or 'bipartite' (default 'complete')>
            edge_max_dist   : <maximal edge Euclidean length; a scalar, or a list filling the upper triangle of a symmetric per-class matrix (default -1)>
            edge_dist_metric: <edge length evaluation method: 'centroid' or 'voxel' (default 'voxel')>
            edge_knn_k      : <number of neighbors used by the 'knn' network (default 5)>
            edge_max_count  : <maximum number of edges above which the forward pass returns early without predictions (default 2e6, -1 to disable)>
            merge_batch     : <flag for whether to merge batches (default False)>
            merge_batch_mode: <mode of batch merging, 'const' or 'fluc'; 'const' use a fixed size of batch for merging,
                               'fluc' takes the input size a mean and sample based on it (default 'const')>
            merge_batch_size: <size of batch merging (default 2)>
            shuffle_clusters: <randomize cluster order (default False)>

    dbscan: dict
        dictionary of dbscan parameters
    node_encoder: dict

        .. code-block:: yaml

          node_encoder:
            name: <name of the node encoder>
            <dictionary of arguments to pass to the encoder>
            model_path      : <path to the encoder weights>

    edge_encoder: dict

        .. code-block:: yaml

          edge_encoder:
            name: <name of the edge encoder>
            <dictionary of arguments to pass to the encoder>
            model_path      : <path to the encoder weights>

    gnn_model: dict

        .. code-block:: yaml

          gnn_model:
            name: <name of the node model>
            <dictionary of arguments to pass to the model>
            model_path      : <path to the model weights>

    kinematics_mlp: bool, default False
        Whether to enable MLP-like layers after the GNN to predict
        momentum, particle type, etc.
    kinematics_type: bool
        Whether to add PID MLP to each node.
    kinematics_momentum: bool
        Whether to add momentum MLP to each node.
    type_net: dict
        Configuration for the PID MLP (if enabled).
        Can partial load weights here too.
    momentum_net: dict
        Configuration for the Momentum MLP (if enabled).
        Can partial load weights here too.
    vertex_mlp: bool, default False
        Whether to add vertex prediction MLP to each node.
        Includes primary particle + vertex coordinates predictions.
    vertex_net: dict
        Configuration for the Vertex MLP (if enabled).
        Can partial load weights here too.

    Outputs
    -------
    input_node_features:
    input_edge_features:
    clusts:
    edge_index:
    node_pred:
    edge_pred:
    node_pred_p:
    node_pred_type:
    node_pred_vtx:

    See Also
    --------
    GNNLoss
    """

    MODULES = [('grappa', ['base', 'dbscan', 'node_encoder', 'edge_encoder', 'gnn_model']), 'grappa_loss']

    def __init__(self, cfg, name='grappa', batch_col=0, coords_col=(1, 4)):
        """
        Builds the GNN chain from its configuration.

        Args:
            cfg (dict): full model configuration; the block `cfg[name]` is used
            name (str): key of this module's configuration block in `cfg`
            batch_col (int): column of the input tensor holding the batch id
            coords_col (tuple): (start, stop) column range of the coordinates

        Raises:
            ValueError: if the batch merging mode, the PID network mode or
                the vertex network name is not recognized
        """
        super(GNN, self).__init__()

        # Get the chain input parameters
        base_config = cfg[name].get('base', {})
        self.name = name
        self.batch_index = batch_col
        self.coords_index = coords_col

        # Choose what type of node to use
        self.source_col = base_config.get('source_col', 5)
        self.target_col = base_config.get('target_col', 6)
        self.node_type = base_config.get('node_type', -1)
        self.node_min_size = base_config.get('node_min_size', -1)
        self.add_points = base_config.get('add_points', False)
        self.add_local_dirs = base_config.get('add_local_dirs', False)
        self.dir_max_dist = base_config.get('dir_max_dist', 5)
        self.opt_dir_max_dist = self.dir_max_dist == 'optimize'
        self.add_local_dedxs = base_config.get('add_local_dedxs', False)
        self.dedx_max_dist = base_config.get('dedx_max_dist', 5)
        self.break_clusters = base_config.get('break_clusters', False)
        self.shuffle_clusters = base_config.get('shuffle_clusters', False)

        # *Deprecated* but kept for backward compatibility:
        if 'add_start_point' in base_config:
            self.add_points = base_config['add_start_point']
        if 'add_start_dir' in base_config:
            self.add_local_dirs = 'start' if base_config['add_start_dir'] else False
        if 'add_start_dedx' in base_config:
            self.add_local_dedxs = 'start' if base_config['add_start_dedx'] else False
        if 'start_dir_max_dist' in base_config:
            self.dir_max_dist = self.dedx_max_dist = base_config['start_dir_max_dist']
        if 'start_dir_opt' in base_config:
            self.opt_dir_max_dist = base_config['start_dir_opt']

        # Interpret node type as list of classes to cluster, -1 means all classes
        if isinstance(self.node_type, int):
            self.node_type = [self.node_type]

        # Choose what type of network to use
        self.network = base_config.get('network', 'complete')
        self.edge_max_dist = base_config.get('edge_max_dist', -1)
        self.edge_dist_metric = base_config.get('edge_dist_metric', 'voxel')
        self.edge_knn_k = base_config.get('edge_knn_k', 5)
        self.edge_max_count = base_config.get('edge_max_count', 2e6)

        # Turn the edge_max_dist value into a symmetric matrix: the list of
        # values fills the upper triangle, which is then mirrored
        if not isinstance(self.edge_max_dist, list):
            self.edge_max_dist = [self.edge_max_dist]
        mat_size = int((np.sqrt(8*len(self.edge_max_dist)+1)-1)/2)
        max_dist_mat = np.zeros((mat_size, mat_size), dtype=float)
        max_dist_mat[np.triu_indices(mat_size)] = self.edge_max_dist
        max_dist_mat += max_dist_mat.T - np.diag(np.diag(max_dist_mat))
        self.edge_max_dist = max_dist_mat

        # If requested, merge images together within the batch
        self.merge_batch = base_config.get('merge_batch', False)
        self.merge_batch_mode = base_config.get('merge_batch_mode', 'const')
        self.merge_batch_size = base_config.get('merge_batch_size', 2)
        if self.merge_batch_mode not in ['const', 'fluc']:
            raise ValueError('Batch merging mode not supported, must be one of const or fluc')
        self.merge_batch_fluc = self.merge_batch_mode == 'fluc'

        # If requested, use DBSCAN to form clusters from semantics
        if 'dbscan' in cfg[name]:
            cfg[name]['dbscan']['cluster_classes'] = self.node_type if self.node_type[0] > -1 else [0,1,2,3]
            cfg[name]['dbscan']['min_size'] = self.node_min_size
            self.dbscan = DBSCANFragmenter(cfg[name], name='dbscan',
                                           batch_col=self.batch_index,
                                           coords_col=self.coords_index)

        # If requested, initialize two MLPs for kinematics predictions
        self.kinematics_mlp = base_config.get('kinematics_mlp', False)
        self.kinematics_type = base_config.get('kinematics_type', False)
        self.kinematics_momentum = base_config.get('kinematics_momentum', False)
        if self.kinematics_mlp:
            node_output_feats = cfg[name]['gnn_model'].get('node_output_feats', 64)
            if self.kinematics_type:
                type_config = cfg[name].get('type_net', {})
                type_net_mode = type_config.get('mode', 'standard')
                if type_net_mode == 'linear':
                    self.type_net = torch.nn.Linear(node_output_feats, 5)
                elif type_net_mode in ('standard', 'edl'):
                    # The two modes only differ in the default value of positive_outputs
                    self.type_net = MomentumNet(node_output_feats,
                                                num_output=5,
                                                num_hidden=type_config.get('num_hidden', 128),
                                                positive_outputs=type_config.get('positive_outputs', type_net_mode == 'edl'))
                else:
                    raise ValueError('Unrecognized Particle ID Type Net Mode: {}'.format(type_net_mode))
            if self.kinematics_momentum:
                momentum_config = cfg[name].get('momentum_net', {})
                softplus_and_shift = momentum_config.get('eps', 0.0)
                logspace = momentum_config.get('logspace', False)
                if momentum_config.get('mode', 'standard') == 'edl':
                    self.momentum_net = EvidentialMomentumNet(node_output_feats,
                                                              num_output=4,
                                                              num_hidden=momentum_config.get('num_hidden', 128),
                                                              eps=softplus_and_shift,
                                                              logspace=logspace)
                else:
                    self.momentum_net = MomentumNet(node_output_feats,
                                                    num_output=1,
                                                    num_hidden=momentum_config.get('num_hidden', 128))

        # If requested, initialize the vertex prediction network
        self.vertex_mlp = base_config.get('vertex_mlp', False)
        if self.vertex_mlp:
            node_feats = cfg[name]['gnn_model'].get('node_feats')
            node_output_feats = cfg[name]['gnn_model'].get('node_output_feats')
            vertex_config = cfg[name].get('vertex_net', {'name': 'momentum_net'})
            self.pred_vtx_positions = vertex_config.get('pred_vtx_positions', True)
            self.use_vtx_input_features = vertex_config.get('use_vtx_input_features', False)
            self.add_vtx_input_features = vertex_config.get('add_vtx_input_features', False)
            # Only rely on node_feats when the input features are actually
            # appended, so that a missing value does not break the default path
            num_input = node_output_feats + (node_feats if self.add_vtx_input_features else 0)
            num_output = 2 + 3 * self.pred_vtx_positions
            vertex_net_name = vertex_config.get('name', 'momentum_net')
            if vertex_net_name == 'linear':
                self.vertex_net = torch.nn.Linear(num_input, num_output)
            elif vertex_net_name == 'momentum_net':
                self.vertex_net = VertexNet(num_input, num_output,
                                            num_hidden=vertex_config.get('num_hidden', 64),
                                            positive_outputs=vertex_config.get('positive_outputs', False))
            elif vertex_net_name == 'attention_net':
                self.vertex_net = TransformerEncoderLayer(num_input, num_output, **vertex_config)
            elif vertex_net_name == 'deep_vertex_net':
                self.vertex_net = DeepVertexNet(num_input, num_output,
                                                num_hidden=vertex_config.get('num_hidden', 64),
                                                num_layers=vertex_config.get('num_layers', 5),
                                                positive_outputs=vertex_config.get('positive_outputs', False))
            else:
                raise ValueError('Vertex MLP {} not recognized!'.format(vertex_net_name))

        # Initialize encoders
        self.node_encoder = node_encoder_construct(cfg[name], batch_col=self.batch_index, coords_col=self.coords_index)
        self.edge_encoder = edge_encoder_construct(cfg[name], batch_col=self.batch_index, coords_col=self.coords_index)

        # Construct the GNN
        self.gnn_model = gnn_model_construct(cfg[name])

    def forward(self, data, clusts=None, groups=None, points=None, extra_feats=None, batch_size=None):
        """
        Prepares particle clusters and feed them to the GNN model.

        Args:
            data (list):
                data[0] ([torch.tensor]): (N,5-10) voxel tensor. With the default
                    column settings: [batch_id, x, y, z, value(, part_id, group_id,
                    int_id, nu_id, sem_type)] or [batch_id, x, y, z, sem_type]
                    (with DBSCAN)
                data[1] ([torch.tensor]): (N,8) particle end points tensor (optional)
            clusts: [(N_0), (N_1), ..., (N_C)] Cluster ids (optional)
            groups: (C) vectors of groups IDs (one per cluster) to enforce
                connections only within each group
            points: (N,3/6) tensor of start (and end) points of clusters
            extra_feats: (N,F) tensor of features to add to the encoded features
            batch_size: number of entries in the batch; if given, the per-batch
                outputs are padded to this length even if some batch ids are
                missing from the input

        Returns:
            dict:
                'node_pred' (torch.tensor): (N,2) Two-channel node predictions (split batch-wise)
                'edge_pred' (torch.tensor): (E,2) Two-channel edge predictions (split batch-wise)
                'clusts' ([np.ndarray])   : [(N_0), (N_1), ..., (N_C)] Cluster ids (split batch-wise)
                'edge_index' (np.ndarray) : (E,2) Incidence matrix (split batch-wise)

            Depending on the configuration, the dictionary also holds
            'batch_counts', 'input_node_features', 'input_edge_features',
            'node_pred_type', 'node_pred_p' (and its uncertainties) and
            'node_pred_vtx'. If there are no clusters, or if the graph exceeds
            `edge_max_count`, only the keys filled so far are returned.

        Raises:
            ValueError: if the network type is not recognized, or if the
                cluster classes cannot be determined for the configured source column
        """
        cluster_data = data[0]
        # Always bind particles so that the downstream uses do not hit a NameError
        particles = data[1] if len(data) > 1 else None
        result = {}

        # Form list of list of voxel indices, one list per cluster in the requested class
        if clusts is None:
            if hasattr(self, 'dbscan'):
                clusts = self.dbscan(cluster_data, points=particles)
            else:
                clusts = form_clusters(cluster_data.detach().cpu().numpy(),
                                       self.node_min_size,
                                       self.source_col,
                                       cluster_classes=self.node_type)
            if self.break_clusters:
                from sklearn.cluster import DBSCAN
                dbscan = DBSCAN(eps=1.1, min_samples=1, metric='chebyshev')
                broken_clusts = []
                for c in clusts:
                    coords = cluster_data[c, self.coords_index[0]:self.coords_index[1]]
                    labels = dbscan.fit(coords.detach().cpu().numpy()).labels_
                    for l in np.unique(labels):
                        broken_clusts.append(c[labels==l])
                clusts = broken_clusts

        # If requested, shuffle the order in which the clusters are listed (used for debugging)
        if self.shuffle_clusters:
            random.shuffle(clusts)

        # If requested, merge images together within the batch
        if self.merge_batch:
            cluster_data, particles, batch_list = merge_batch(cluster_data, particles,
                                                              self.merge_batch_size,
                                                              self.merge_batch_fluc,
                                                              self.batch_index)
            batch_counts = np.unique(batch_list, return_counts=True)[1]
            result['batch_counts'] = [batch_counts]

        # Update result with a list of clusters for each batch id
        batches, bcounts = np.unique(cluster_data[:,self.batch_index].detach().cpu().numpy(), return_counts=True)
        if not len(clusts):
            return {**result, 'clusts': [[np.array([]) for _ in batches]]}

        # If an event is missing from the input data - e.g., deghosting
        # erased everything (extreme case but possible if very few voxels)
        # then we might be miscounting batches. Ensure that batches is the
        # same length as batch_size if specified.
        if batch_size is not None:
            new_bcounts = np.zeros(batch_size, dtype=np.int64)
            new_bcounts[batches.astype(np.int64)] = bcounts
            bcounts = new_bcounts
            batches = np.arange(batch_size)

        batch_ids = get_cluster_batch(cluster_data, clusts, batch_index=self.batch_index)
        clusts_split, cbids = split_clusts(clusts, batch_ids, batches, bcounts)
        result['clusts'] = [clusts_split]

        # Bail out early if the complete graph alone would exceed the edge budget
        if self.edge_max_count > -1:
            _, cnts = np.unique(batch_ids, return_counts=True)
            if np.sum([c*(c-1) for c in cnts]) > 2*self.edge_max_count:
                return result

        # If necessary, compute the cluster distance matrix
        dist_mat = None
        if np.any(self.edge_max_dist > -1) or self.network == 'mst' or self.network == 'knn':
            dist_mat = inter_cluster_distance(cluster_data[:,self.coords_index[0]:self.coords_index[1]].float(),
                                              clusts, batch_ids, self.edge_dist_metric)

        # Form the requested network
        if len(clusts) == 1:
            edge_index = np.empty((2,0), dtype=np.int64)
        elif self.network == 'complete':
            edge_index = complete_graph(batch_ids)
        elif self.network == 'delaunay':
            import numba as nb
            edge_index = delaunay_graph(cluster_data.detach().cpu().numpy(), nb.typed.List(clusts),
                                        batch_ids, self.batch_index, self.coords_index)
        elif self.network == 'mst':
            edge_index = mst_graph(batch_ids, dist_mat)
        elif self.network == 'knn':
            edge_index = knn_graph(batch_ids, self.edge_knn_k, dist_mat)
        elif self.network == 'bipartite':
            clust_ids = get_cluster_label(cluster_data, clusts, self.source_col)
            group_ids = get_cluster_label(cluster_data, clusts, self.target_col)
            edge_index = bipartite_graph(batch_ids, clust_ids==group_ids, dist_mat)
        else:
            raise ValueError('Network type not recognized: '+self.network)

        # If groups is specified, only keep edges that belong to the same group (cluster graph)
        if groups is not None:
            mask = groups[edge_index[0]] == groups[edge_index[1]]
            edge_index = edge_index[:,mask]

        # Restrict the input graph based on edge distance, if requested
        if np.any(self.edge_max_dist > -1):
            if self.edge_max_dist.shape[0] == 1:
                edge_index = restrict_graph(edge_index, dist_mat, self.edge_max_dist)
            else:
                # Here get_cluster_primary_label is used to ensure that Michel/Delta
                # showers are given the appropriate semantic label
                if extra_feats is not None:
                    classes = extra_feats[:,-1].cpu().numpy().astype(int)
                elif self.source_col == 5:
                    classes = get_cluster_label(cluster_data, clusts, -1).astype(int)
                elif self.source_col == 6:
                    classes = get_cluster_primary_label(cluster_data, clusts, -1).astype(int)
                else:
                    raise ValueError('Cannot determine the cluster classes for source_col {}, '
                                     'provide them through extra_feats'.format(self.source_col))
                edge_index = restrict_graph(edge_index, dist_mat, self.edge_max_dist, classes)

        # Update result with a list of edges for each batch id
        edge_index_split, ebids = split_edge_index(edge_index, batch_ids, batches)
        result['edge_index'] = [edge_index_split]
        # Only enforce the edge budget when it is enabled: with a limit of -1
        # any edge count would otherwise exceed it and predictions would never run
        if self.edge_max_count > -1 and edge_index.shape[1] > self.edge_max_count:
            return result

        # Obtain node and edge features
        x = self.node_encoder(cluster_data, clusts)
        e = self.edge_encoder(cluster_data, clusts, edge_index)

        # If extra features are provided separately, add them
        if extra_feats is not None:
            x = torch.cat([x, extra_feats.float()], dim=1)

        # Add end points and/or local directions to node features, if requested
        if self.add_points or points is not None:
            if points is None:
                points = get_cluster_points_label(cluster_data, particles, clusts, coords_index=self.coords_index)
            x = torch.cat([x, points.float()], dim=1)
            voxels = cluster_data[:, self.coords_index[0]:self.coords_index[1]]
            if self.add_local_dirs:
                dirs_start = get_cluster_directions(voxels, points[:,:3], clusts,
                                                    self.dir_max_dist, self.opt_dir_max_dist)
                if self.add_local_dirs != 'start':
                    dirs_end = get_cluster_directions(voxels, points[:,3:6], clusts,
                                                      self.dir_max_dist, self.opt_dir_max_dist)
                    x = torch.cat([x, dirs_start.float(), dirs_end.float()], dim=1)
                else:
                    x = torch.cat([x, dirs_start.float()], dim=1)
            if self.add_local_dedxs:
                dedxs_start = get_cluster_dedxs(voxels, cluster_data[:,4], points[:,:3],
                                                clusts, self.dedx_max_dist)
                if self.add_local_dedxs != 'start':
                    dedxs_end = get_cluster_dedxs(voxels, cluster_data[:,4], points[:,3:6],
                                                  clusts, self.dedx_max_dist)
                    x = torch.cat([x, dedxs_start.reshape(-1,1).float(), dedxs_end.reshape(-1,1).float()], dim=1)
                else:
                    x = torch.cat([x, dedxs_start.reshape(-1,1).float()], dim=1)

        # Bring edge_index and batch_ids to device
        index = torch.tensor(edge_index, device=cluster_data.device, dtype=torch.long)
        xbatch = torch.tensor(batch_ids, device=cluster_data.device)

        result['input_node_features'] = [[x[b] for b in cbids]]
        result['input_edge_features'] = [[e[b] for b in ebids]]

        # Pass through the model, update results
        out = self.gnn_model(x, index, e, xbatch)

        result['node_pred'] = [[out['node_pred'][0][b] for b in cbids]]
        result['edge_pred'] = [[out['edge_pred'][0][b] for b in ebids]]

        # If requested, pass the node features through two MLPs for kinematics predictions
        if self.kinematics_mlp:
            if self.kinematics_type:
                node_pred_type = self.type_net(out['node_features'][0])
                result['node_pred_type'] = [[node_pred_type[b] for b in cbids]]
            if self.kinematics_momentum:
                node_pred_p = self.momentum_net(out['node_features'][0])
                result['node_pred_p'] = [[node_pred_p[b] for b in cbids]]
                if isinstance(self.momentum_net, EvidentialMomentumNet):
                    # Uncertainties derived from the four evidential output channels
                    aleatoric = node_pred_p[:, 3] / (node_pred_p[:, 2] - 1.0 + 0.001)
                    epistemic = node_pred_p[:, 3] / (node_pred_p[:, 1] * (node_pred_p[:, 2] - 1.0 + 0.001))
                    result['node_pred_p_aleatoric'] = [[aleatoric[b] for b in cbids]]
                    result['node_pred_p_epistemic'] = [[epistemic[b] for b in cbids]]
        else:
            # If final post-gnn MLP is not given, set type features to node_pred.
            result['node_pred_type'] = result['node_pred']

        if self.vertex_mlp:
            if self.use_vtx_input_features:
                node_pred_vtx = self.vertex_net(x)
            elif self.add_vtx_input_features:
                node_pred_vtx = self.vertex_net(torch.cat([x, out['node_features'][0]], dim=1))
            else:
                node_pred_vtx = self.vertex_net(out['node_features'][0])
            result['node_pred_vtx'] = [[node_pred_vtx[b] for b in cbids]]

        return result
class GNNLoss(torch.nn.modules.loss._Loss):
    """
    Takes the output of the GNN and computes the total loss.

    For use in config:

    .. code-block:: yaml

        model:
          name: grappa
          modules:
            grappa_loss:
              node_loss:
                name: <name of the node loss>
                <dictionary of arguments to pass to the loss>
              edge_loss:
                name: <name of the edge loss>
                <dictionary of arguments to pass to the loss>
    """

    def __init__(self, cfg, name='grappa_loss', batch_col=0, coords_col=(1, 4)):
        """
        Instantiates the node and/or edge losses listed in `cfg[name]`.

        Args:
            cfg (dict): full model configuration; the block `cfg[name]` is used
            name (str): key of the loss configuration block in `cfg`
            batch_col (int): column of the input tensor holding the batch id
            coords_col (tuple): (start, stop) column range of the coordinates
        """
        super(GNNLoss, self).__init__()

        self.batch_index = batch_col
        self.coords_index = coords_col

        # A loss is applied if and only if its block is present in the configuration
        loss_config = cfg[name]
        self.apply_node_loss = 'node_loss' in loss_config
        self.apply_edge_loss = 'edge_loss' in loss_config

        if self.apply_node_loss:
            self.node_loss = node_loss_construct(loss_config, batch_col=batch_col, coords_col=coords_col)
        if self.apply_edge_loss:
            self.edge_loss = edge_loss_construct(loss_config, batch_col=batch_col, coords_col=coords_col)

    def forward(self, result, clust_label, graph=None, node_label=None, iteration=None):
        """
        Applies the instantiated losses and gathers their outputs.

        Args:
            result (dict): output of the GNN forward pass
            clust_label: cluster labels, passed to the edge loss and, unless
                `node_label` is given, to the node loss
            graph: graph information passed to the edge loss (optional)
            node_label: labels passed to the node loss instead of
                `clust_label` (optional)
            iteration: if not None, forwarded to the node loss as the
                `iteration` keyword argument

        Returns:
            dict: the merged outputs of the applied losses, along with the
            'node_loss'/'node_accuracy' and 'edge_loss'/'edge_accuracy' entries
            of the losses that ran. When both losses are applied, 'loss' is
            their sum and 'accuracy' the mean of their accuracies.
        """
        loss = {}

        if self.apply_node_loss:
            labels = clust_label if node_label is None else node_label
            # Only forward the iteration keyword when it is provided
            extra_args = {} if iteration is None else {'iteration': iteration}
            node_out = self.node_loss(result, labels, **extra_args)
            loss.update(node_out)
            loss['node_loss'] = node_out['loss']
            loss['node_accuracy'] = node_out['accuracy']

        if self.apply_edge_loss:
            edge_out = self.edge_loss(result, clust_label, graph)
            loss.update(edge_out)
            loss['edge_loss'] = edge_out['loss']
            loss['edge_accuracy'] = edge_out['accuracy']

        # With both losses active, the merged 'loss'/'accuracy' entries hold the
        # edge values at this point: overwrite them with the combined values
        if self.apply_node_loss and self.apply_edge_loss:
            loss['loss'] = loss['node_loss'] + loss['edge_loss']
            loss['accuracy'] = (loss['node_accuracy'] + loss['edge_accuracy'])/2

        return loss