# Source code for mlreco.models.layers.gnn.normalizations

import numpy as np
import torch
import torch.nn as nn
from torch.nn import Sequential as Seq, Linear as Lin, ReLU, BatchNorm1d, LeakyReLU
import torch.nn.functional as F

from torch.nn import BatchNorm1d
from torch.nn.modules.instancenorm import _InstanceNorm


class BatchNorm(BatchNorm1d):
    r"""Applies batch normalization over a batch of node features as described
    in the `"Batch Normalization: Accelerating Deep Network Training by
    Reducing Internal Covariate Shift" <https://arxiv.org/abs/1502.03167>`_
    paper

    .. math::
        \mathbf{x}^{\prime}_i = \frac{\mathbf{x} -
        \textrm{E}[\mathbf{x}]}{\sqrt{\textrm{Var}[\mathbf{x}] + \epsilon}}
        \odot \gamma + \beta

    This is a thin wrapper around :class:`torch.nn.BatchNorm1d`: the only
    differences are that the feature count is called ``in_channels`` and the
    module prints a compact, constructor-like ``repr``.

    Args:
        in_channels (int): Size of each input sample.
        eps (float, optional): A value added to the denominator for numerical
            stability. (default: :obj:`1e-5`)
        momentum (float, optional): The value used for the running mean and
            running variance computation. (default: :obj:`0.1`)
        affine (bool, optional): If set to :obj:`True`, this module has
            learnable affine parameters :math:`\gamma` and :math:`\beta`.
            (default: :obj:`True`)
        track_running_stats (bool, optional): If set to :obj:`True`, this
            module tracks the running mean and variance, and when set to
            :obj:`False`, this module does not track such statistics and always
            uses batch statistics in both training and eval modes.
            (default: :obj:`True`)

    © Copyright 2020, Matthias Fey. Revision 18da46c2.
    """

    def __init__(self, in_channels, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        # Forward everything to BatchNorm1d; ``in_channels`` becomes its
        # ``num_features``.
        super().__init__(in_channels, eps, momentum, affine,
                         track_running_stats)

    def forward(self, x):
        """Normalize ``x`` exactly as :class:`torch.nn.BatchNorm1d` does."""
        return super().forward(x)

    def __repr__(self):
        template = ('{name}({channels}, eps={eps}, momentum={momentum}, '
                    'affine={affine}, track_running_stats={track})')
        return template.format(
            name=type(self).__name__,
            channels=self.num_features,
            eps=self.eps,
            momentum=self.momentum,
            affine=self.affine,
            track=self.track_running_stats,
        )
class InstanceNorm(_InstanceNorm):
    r"""Applies instance normalization over each individual example in a batch
    of node features as described in the `"Instance Normalization: The Missing
    Ingredient for Fast Stylization" <https://arxiv.org/abs/1607.08022>`_
    paper

    .. math::
        \mathbf{x}^{\prime}_i = \frac{\mathbf{x} -
        \textrm{E}[\mathbf{x}]}{\sqrt{\textrm{Var}[\mathbf{x}] + \epsilon}}
        \odot \gamma + \beta

    Here an "example" is the set of nodes sharing the same ``batch`` index:
    mean and (biased) variance are computed separately for each such group.
    When running statistics are tracked, they are updated with the average
    over groups of the per-group mean and of the per-group unbiased variance.

    Args:
        in_channels (int): Size of each input sample.
        eps (float, optional): A value added to the denominator for numerical
            stability. (default: :obj:`1e-5`)
        momentum (float, optional): The value used for the running mean and
            running variance computation. Must be a number; :obj:`None` is not
            supported. (default: :obj:`0.1`)
        affine (bool, optional): If set to :obj:`True`, this module has
            learnable affine parameters :math:`\gamma` and :math:`\beta`.
            (default: :obj:`False`)
        track_running_stats (bool, optional): If set to :obj:`True`, this
            module tracks the running mean and variance, and when set to
            :obj:`False`, this module does not track such statistics and always
            uses instance statistics in both training and eval modes.
            (default: :obj:`False`)

    © Copyright 2020, Matthias Fey. Revision 18da46c2.
    """

    def __init__(self, in_channels, eps=1e-5, momentum=0.1, affine=False,
                 track_running_stats=False):
        super(InstanceNorm, self).__init__(in_channels, eps, momentum, affine,
                                           track_running_stats)

    @staticmethod
    def _scatter_sum(src, index, dim_size):
        """Sum rows of ``src`` into ``dim_size`` buckets selected by ``index``."""
        out = src.new_zeros((dim_size,) + tuple(src.shape[1:]))
        return out.index_add_(0, index, src)

    def forward(self, x, batch=None):
        """Normalize ``x`` group by group.

        Args:
            x (torch.Tensor): Node features, expected to be
                ``(num_nodes, in_channels)`` (the affine step reshapes the
                parameters to ``(1, in_channels)``).
            batch (torch.LongTensor, optional): Group (graph) index of every
                node, one entry per row of ``x``. :obj:`None` treats all nodes
                as a single group. (default: :obj:`None`)

        Returns:
            torch.Tensor: Normalized features with the same shape as ``x``.
        """
        # Nothing to normalize; ``batch.max()`` below would raise on an
        # empty tensor.
        if x.size(0) == 0:
            return x.clone()

        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        if self.training or not self.track_running_stats:
            batch_size = batch.max().item() + 1

            # Number of nodes per group; clamped so that empty group ids
            # (gaps in ``batch``) do not divide by zero.
            count = torch.bincount(batch, minlength=batch_size)
            count = count.to(x.dtype).view(-1, 1)
            norm = count.clamp(min=1)

            mean = self._scatter_sum(x, batch, batch_size) / norm
            centered = x - mean[batch]
            sq_sum = self._scatter_sum(centered * centered, batch, batch_size)
            var = sq_sum / norm

            if self.training and self.track_running_stats:
                unbiased_var = sq_sum / (count - 1).clamp(min=1)
                # Update the buffers in place and outside autograd: assigning
                # graph-attached tensors to them would chain autograd history
                # across iterations and make the module non-deep-copyable.
                with torch.no_grad():
                    momentum = self.momentum
                    self.running_mean.mul_(1 - momentum).add_(
                        momentum * mean.mean(dim=0))
                    self.running_var.mul_(1 - momentum).add_(
                        momentum * unbiased_var.mean(dim=0))

            out = centered / torch.sqrt(var[batch] + self.eps)
        else:
            # Eval mode with tracked statistics: every node shares the
            # running mean/variance.
            mean = self.running_mean.view(1, -1)
            var = self.running_var.view(1, -1)
            out = (x - mean) / torch.sqrt(var + self.eps)

        if self.affine:
            out = out * self.weight.view(1, -1) + self.bias.view(1, -1)
        return out