# Source code for mlreco.models.layers.common.momentum

import torch
import torch.nn as nn

class MomentumNet(nn.Module):
    '''
    Small three-layer MLP mapping a batch of feature vectors to
    ``num_output`` values per row.

    Layer order in ``forward``:
    BatchNorm1d -> Linear -> LeakyReLU -> BatchNorm1d -> Linear ->
    LeakyReLU -> Linear -> final activation (Softplus or Identity).

    USAGE:
        net = MomentumNet(16, num_output=1)
        x = torch.randn(32, 16)
        out = net(x)    # (32, 1)
    '''

    def __init__(self, num_input, num_output=1, num_hidden=128,
                 positive_outputs=False):
        '''
        Parameters
        ----------
        num_input : int
            Number of features per input row.
        num_output : int, default 1
            Number of values produced per input row.
        num_hidden : int, default 128
            Width of the two hidden layers.
        positive_outputs : bool, default False
            If True the output goes through ``nn.Softplus`` (strictly
            positive values); otherwise it is returned unchanged.
        '''
        super().__init__()
        # NOTE: attribute names and registration order are kept as-is so
        # that state_dict keys of existing checkpoints still match.
        self.linear1 = nn.Linear(num_input, num_hidden)
        # norm1 acts on the raw input, hence num_input features.
        self.norm1 = nn.BatchNorm1d(num_input)
        self.linear2 = nn.Linear(num_hidden, num_hidden)
        self.norm2 = nn.BatchNorm1d(num_hidden)
        self.linear3 = nn.Linear(num_hidden, num_output)
        self.lrelu = nn.LeakyReLU(negative_slope=0.33)
        self.final = nn.Softplus() if positive_outputs else nn.Identity()

    def forward(self, x):
        '''
        Run the MLP on ``x``.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape (N, num_input).

        Returns
        -------
        torch.Tensor
            Output of shape (N, num_output).
        '''
        # Batch normalisation is skipped for single-row batches (in both
        # training and eval mode): BatchNorm1d cannot compute batch
        # statistics from one sample while training.
        use_norm = x.shape[0] > 1

        if use_norm:
            x = self.norm1(x)
        x = self.lrelu(self.linear1(x))

        if use_norm:
            x = self.norm2(x)
        x = self.lrelu(self.linear2(x))

        return self.final(self.linear3(x))
class VertexNet(MomentumNet):
    '''
    Small MLP for handling vertex regression and particle primary
    prediction.

    Reuses the layers built by ``MomentumNet`` but makes batch
    normalisation optional and applies the final activation only to the
    first three output columns, and only when ``num_output == 5``.
    '''

    def __init__(self, num_input, num_output=1, num_hidden=128,
                 positive_outputs=False, batch_norm=False):
        '''
        Parameters
        ----------
        num_input : int
            Number of features per input row.
        num_output : int, default 1
            Number of values produced per input row.
        num_hidden : int, default 128
            Width of the two hidden layers.
        positive_outputs : bool, default False
            Selects Softplus (True) or Identity (False) as ``self.final``.
            Only takes effect in ``forward`` when ``num_output == 5``.
        batch_norm : bool, default False
            If True, the two BatchNorm1d layers are applied (for batches
            of more than one row). The layers are created either way.
        '''
        super().__init__(num_input, num_output, num_hidden, positive_outputs)
        self.num_output = num_output
        self.batch_norm = batch_norm

    def forward(self, x):
        '''
        Run the MLP on ``x``.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape (N, num_input).

        Returns
        -------
        torch.Tensor
            Shape (N, num_output). When ``num_output == 5`` the first
            three columns (named ``vtx_pred`` in the original code) have
            passed through ``self.final``; the remaining columns, and all
            columns for any other ``num_output``, are raw linear outputs.
        '''
        # The row count is not changed by the linear layers, so the
        # decision can be taken once up front.
        normalize = self.batch_norm and x.shape[0] > 1

        hidden = self.norm1(x) if normalize else x
        hidden = self.lrelu(self.linear1(hidden))
        if normalize:
            hidden = self.norm2(hidden)
        hidden = self.lrelu(self.linear2(hidden))
        raw = self.linear3(hidden)

        if self.num_output != 5:
            return raw

        head = self.final(raw[:, :3])
        return torch.cat([head, raw[:, 3:]], dim=1)
class DeepVertexNet(nn.Module):
    '''
    Deeper MLP variant producing ``num_output`` values per input row.

    The body is ``num_layers`` repetitions of
    BatchNorm1d -> LeakyReLU -> Linear, followed by a LeakyReLU and an
    output projection to ``num_output`` columns.

    USAGE:
        net = DeepVertexNet(16, num_output=5)
        x = torch.randn(32, 16)
        out = net(x)    # (32, 5)
    '''

    def __init__(self, num_input, num_output=1, num_hidden=512,
                 num_layers=5, positive_outputs=False):
        '''
        Parameters
        ----------
        num_input : int
            Number of features per input row.
        num_output : int, default 1
            Number of values produced per input row.
        num_hidden : int, default 512
            Width of every hidden layer.
        num_layers : int, default 5
            Number of BatchNorm1d/LeakyReLU/Linear stages.
        positive_outputs : bool, default False
            Selects Softplus (True) or Identity (False) as ``self.final``.
            Only takes effect in ``forward`` when ``num_output == 5``.
        '''
        super().__init__()
        self.num_output = num_output
        self.num_layers = num_layers
        self.linear = nn.ModuleList()
        self.norm = nn.ModuleList()

        in_features = num_input
        for _ in range(num_layers):
            self.norm.append(nn.BatchNorm1d(in_features))
            self.linear.append(nn.Linear(in_features, num_hidden))
            in_features = num_hidden

        # Output projection. Previously this layer was stored in
        # ``self.final`` and then immediately overwritten by the
        # activation below, so it never existed and the network returned
        # ``num_hidden`` columns instead of ``num_output``.
        self.linear_out = nn.Linear(in_features, num_output)

        self.lrelu = nn.LeakyReLU(negative_slope=0.33)
        self.final = nn.Softplus() if positive_outputs else nn.Identity()

    def forward(self, x):
        '''
        Run the MLP on ``x``.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape (N, num_input).

        Returns
        -------
        torch.Tensor
            Shape (N, num_output). When ``num_output == 5`` the first
            three columns have passed through ``self.final``; otherwise
            the raw projection is returned.
        '''
        # BatchNorm1d raises on a single-row batch while training, so the
        # normalisation is skipped in exactly that case.
        skip_norm = self.training and x.shape[0] == 1

        for norm, linear in zip(self.norm, self.linear):
            if not skip_norm:
                x = norm(x)
            x = self.lrelu(x)
            x = linear(x)

        # Non-linearity before the projection: two stacked Linear layers
        # without one would collapse into a single linear map.
        x = self.linear_out(self.lrelu(x))

        if self.num_output == 5:
            vtx_pred = self.final(x[:, :3])
            return torch.cat([vtx_pred, x[:, 3:]], dim=1)
        return x
class EvidentialMomentumNet(nn.Module):
    '''
    Three-layer MLP whose output is post-processed into four constrained
    columns, returned in the order ``[gamma, v, alpha, beta]``.

    NOTE(review): the column names suggest the parameters of an
    evidential (Normal-Inverse-Gamma) regression head; confirm against
    the loss that consumes this output.
    '''

    def __init__(self, num_input, num_output=4, num_hidden=128,
                 eps=0.0, logspace=False):
        '''
        Parameters
        ----------
        num_input : int
            Number of features per input row.
        num_output : int, default 4
            Width of the last linear layer. ``forward`` reads columns
            0..3, so this must be at least 4; extra columns are ignored.
        num_hidden : int, default 128
            Width of the two hidden layers.
        eps : float, default 0.0
            Added to the Softplus outputs and, when ``logspace`` is
            False, used as the lower clamp of the returned tensor.
        logspace : bool, default False
            If True, column 3 is passed through Identity and the result
            is not clamped; otherwise it goes through Sigmoid and the
            result is clamped to ``>= eps``.
        '''
        super().__init__()
        self.linear1 = nn.Linear(num_input, num_hidden)
        # norm1 acts on the raw input, hence num_input features.
        self.norm1 = nn.BatchNorm1d(num_input)
        self.linear2 = nn.Linear(num_hidden, num_hidden)
        self.norm2 = nn.BatchNorm1d(num_hidden)
        self.linear3 = nn.Linear(num_hidden, num_output)
        # Attribute name kept for compatibility; it holds a LeakyReLU.
        self.elu = nn.LeakyReLU(negative_slope=0.33)
        self.softplus = nn.Softplus()
        self.logspace = logspace
        self.gamma = nn.Identity() if logspace else nn.Sigmoid()
        self.eps = eps

    def forward(self, x):
        '''
        Run the MLP on ``x`` and build the constrained output.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape (N, num_input).

        Returns
        -------
        torch.Tensor
            Shape (N, 4) with columns ``[gamma, v, alpha, beta]``:
            ``gamma = 2 * self.gamma(raw[:, 3])``,
            ``v = softplus(raw[:, 0]) + eps``,
            ``alpha = max(softplus(raw[:, 1]) + eps + 1, 1)``,
            ``beta = softplus(raw[:, 2]) + eps``.
            Clamped to ``>= eps`` unless ``logspace`` is True.
        '''
        # Batch normalisation is skipped for single-row batches.
        use_norm = x.shape[0] > 1

        if use_norm:
            # The result must be assigned: previously the normalised
            # tensor was discarded, so norm1 never affected the output.
            x = self.norm1(x)
        x = self.elu(self.linear1(x))

        if use_norm:
            x = self.norm2(x)
        x = self.elu(self.linear2(x))

        x = self.linear3(x)

        vab = self.softplus(x[:, :3]) + self.eps
        v = vab[:, 0].view(-1, 1)
        alpha = torch.clamp(vab[:, 1] + 1.0, min=1.0).view(-1, 1)
        beta = vab[:, 2].view(-1, 1)
        gamma = 2.0 * self.gamma(x[:, 3]).view(-1, 1)

        out = torch.cat([gamma, v, alpha, beta], dim=1)

        if self.logspace:
            return out
        return torch.clamp(out, min=self.eps)