Source code for mlreco.models.layers.gnn.encoders.geometric

# Geometric feature extractor for Cluster GNN
import torch
import numpy as np
from torch_scatter import scatter_min

from mlreco.utils import local_cdist
from mlreco.utils.gnn.data import cluster_features, cluster_edge_features

class ClustGeoNodeEncoder(torch.nn.Module):
    """
    Produces geometric cluster node features.

    The first 19 features are composed of:
    - Center (3)
    - Covariance matrix (9)
    - Principal axis (3)
    - Voxel count (1)
    - Mean energy (1)
    - RMS energy (1)
    - Semantic type (1), i.e. most represented type in cluster

    6 features for the end points (duplicated for shower, randomly ordered for tracks)
    3 features for direction estimate (mean direction w.r.t. start point)

    Total of 28 hand-engineered features

    Note: the torch-based implementation in `forward` (`use_numpy=False`)
    only builds the first 16 features (center, normalized orientation matrix,
    weighted principal axis, voxel count), or 19 when `more_feats` is set.
    Whatever the Numba-based `cluster_features` returns is passed through
    unchanged.
    """

    def __init__(self, model_config, batch_col=0, coords_col=(1, 4)):
        """
        Initializes the encoder parameters.

        Parameters
        ----------
        model_config : dict
            Configuration. Recognized keys:
            - 'use_numpy' (default True): delegate to `cluster_features`
            - 'more_feats' (default False): append the 3 extra features
        batch_col : int, default 0
            Batch column index, forwarded to `cluster_features`
        coords_col : tuple, default (1, 4)
            [start, end) column range of the voxel coordinates
        """
        super().__init__()

        # Initialize the encoder parameters
        self.use_numpy = model_config.get('use_numpy', True)
        self.more_feats = model_config.get('more_feats', False)
        self.batch_col = batch_col
        self.coords_col = coords_col

    def forward(self, data, clusts):
        """
        Builds one feature vector per cluster.

        Parameters
        ----------
        data : torch.Tensor
            Voxel tensor; coordinates are read from `coords_col`, the value
            from column 4 and the semantic type from the last column
        clusts : sequence of index collections
            Voxel indices of each cluster

        Returns
        -------
        Output of `cluster_features` if `use_numpy` is set; otherwise a
        (len(clusts), 16) tensor, or (len(clusts), 19) with `more_feats`.
        """
        # If numpy is to be used, delegate to the project feature extractor
        if self.use_numpy:
            return cluster_features(data, clusts, extra=self.more_feats,
                                    batch_col=self.batch_col,
                                    coords_col=self.coords_col)

        # Get the voxel set
        voxels = data[:, self.coords_col[0]:self.coords_col[1]].float()
        dtype, device = voxels.dtype, voxels.device

        # Without clusters there is nothing to stack (torch.stack rejects an
        # empty list), so return an empty feature tensor of the right width
        if len(clusts) == 0:
            n_feats = 19 if self.more_feats else 16
            return torch.empty((0, n_feats), dtype=dtype, device=device)

        # Get the value & semantic types
        # NOTE(review): the value column is hard-coded to 4 — confirm it
        # should not follow coords_col[1]
        values = data[:, 4].float()
        sem_types = data[:, -1].float()

        # Below is a torch-based implementation of cluster_features
        feats = []
        for c in clusts:
            # Get list of voxels in the cluster
            x = voxels[c]
            size = torch.tensor([len(c)], dtype=dtype, device=device)

            # Do not waste time with computations with size 1 clusters:
            # position, zero-padded orientation/axis (12 values), size
            if len(c) < 2:
                parts = [x.flatten(),
                         torch.zeros(12, dtype=dtype, device=device),
                         size]
                if self.more_feats:
                    parts.append(torch.tensor(
                        [values[c[0]], 0., sem_types[c[0]]],
                        dtype=dtype, device=device))
                feats.append(torch.cat(parts))
                continue

            # Center data
            center = x.mean(dim=0)
            x = x - center

            # Get orientation matrix
            A = x.t().mm(x)

            # Get eigenvectors (eigh returns eigenvalues in ascending order),
            # normalize orientation matrix and eigenvalues to largest.
            # This step assumes points are not superimposed, i.e. that the
            # largest eigenvalue != 0
            w, v = torch.linalg.eigh(A, UPLO='U')
            dirwt = 1.0 - w[1] / w[2]
            B = A / w[2]

            # Get the principal direction
            v0 = v[:, 2]

            # Project all points, x, along the principal axis
            x0 = x.mv(v0)

            # Evaluate the distance from the points to the principal axis
            xp0 = x - torch.outer(x0, v0)
            np0 = torch.norm(xp0, dim=1)

            # Flip the principal direction if it is not pointing towards
            # the maximum spread
            if torch.dot(x0, np0) < 0:
                v0 = -v0

            # Weight direction
            v0 = dirwt * v0

            # Append (center, B.flatten(), v0, size[, mean, std, mode])
            parts = [center, B.flatten(), v0, size]
            if self.more_feats:
                parts.append(torch.tensor(
                    [values[c].mean(), values[c].std(),
                     sem_types[c].mode()[0]],
                    dtype=dtype, device=device))
            feats.append(torch.cat(parts))

        return torch.stack(feats, dim=0)
class ClustGeoEdgeEncoder(torch.nn.Module):
    """
    Produces geometric cluster edge features.

    The torch-based implementation builds 19 features per edge:
    - Closest point in the first cluster (3)
    - Closest point in the second cluster (3)
    - Normalized displacement between them (3)
    - Distance between them (1)
    - Outer product of the normalized displacement (9)
    """

    def __init__(self, model_config, batch_col=0, coords_col=(1, 4)):
        """
        Initializes the encoder parameters.

        Parameters
        ----------
        model_config : dict
            Configuration. Recognized key:
            - 'use_numpy' (default True): delegate to `cluster_edge_features`
        batch_col : int, default 0
            Batch column index, forwarded to `cluster_edge_features`
        coords_col : tuple, default (1, 4)
            [start, end) column range of the voxel coordinates
        """
        super().__init__()

        # Initialize the chain parameters
        self.use_numpy = model_config.get('use_numpy', True)
        self.batch_col = batch_col
        self.coords_col = coords_col

    def forward(self, data, clusts, edge_index):
        """
        Builds one feature vector per edge.

        Parameters
        ----------
        data : torch.Tensor
            Voxel tensor; coordinates are read from `coords_col`
        clusts : sequence of index collections
            Voxel indices of each cluster
        edge_index : (2, E) index array
            Edges as pairs of cluster indices. If its second half looks like
            the reciprocal of its first half, features are only computed for
            the first half and mirrored for the second.

        Returns
        -------
        (E, 19) edge feature tensor on the torch path; on the numpy path,
        whatever `cluster_edge_features` returns (mirrored if undirected).
        """
        # Check if the graph is undirected, select the relevant part of the
        # edge index. Only the first edge is compared against the first edge
        # of the second half; an empty edge index counts as undirected.
        n_edges = edge_index.shape[1]
        half_idx = n_edges // 2
        undirected = not n_edges or (
            not n_edges % 2
            and [edge_index[1, 0], edge_index[0, 0]]
            == edge_index[:, half_idx].tolist())
        if undirected:
            edge_index = edge_index[:, :half_idx]

        # If numpy is to be used, delegate to the project feature extractor
        # Otherwise use torch-based implementation of cluster_edge_features
        if self.use_numpy:
            # NOTE(review): the mirroring below calls .clone()/torch.cat, so
            # this presumably returns a torch tensor — confirm
            feats = cluster_edge_features(data, clusts, edge_index.T,
                                          batch_col=self.batch_col,
                                          coords_col=self.coords_col)
        else:
            # Get the voxel set
            voxels = data[:, self.coords_col[0]:self.coords_col[1]].float()

            # Here is a torch-based implementation of cluster_edge_features
            feats = []
            for e in edge_index.T:
                # Get the voxels in the clusters connected by the edge
                x1 = voxels[clusts[e[0]]]
                x2 = voxels[clusts[e[1]]]

                # Find the closest set point in each cluster; argmin returns
                # a flat index into the (len(x1), len(x2)) distance matrix
                d12 = local_cdist(x1, x2)
                imin = torch.argmin(d12)
                i1, i2 = imin // len(x2), imin % len(x2)
                v1 = x1[i1, :]  # closest point in c1
                v2 = x2[i2, :]  # closest point in c2

                # Displacement
                disp = v1 - v2

                # Distance; only normalize when the points do not coincide
                lend = torch.norm(disp)
                if lend > 0:
                    disp = disp / lend

                # Outer product
                B = torch.outer(disp, disp).flatten()

                feats.append(torch.cat([v1, v2, disp, lend.reshape(1), B]))

            if feats:
                feats = torch.stack(feats, dim=0)
            else:
                # torch.stack rejects an empty list: with no edges, return
                # an empty tensor with the 19 feature columns instead
                feats = torch.empty((0, 19), dtype=voxels.dtype,
                                    device=voxels.device)

        # If the graph is undirected, infer reciprocal features: swap the
        # two closest points and negate the displacement
        if undirected:
            feats_flip = feats.clone()
            feats_flip[:, :3] = feats[:, 3:6]
            feats_flip[:, 3:6] = feats[:, :3]
            feats_flip[:, 6:9] = -feats[:, 6:9]
            feats = torch.cat([feats, feats_flip])

        return feats