import torch
import torch.nn as nn
import numpy as np
# For MinkowskiEngine
import MinkowskiEngine as ME
from .activation_normalization_factories import *
from typing import Union
# Custom Network Units/Blocks
[docs]class Identity(nn.Module):
[docs] def __init__(self):
super(Identity, self).__init__()
[docs] def forward(self, input):
return input
[docs]def dense_coordinates(shape: Union[list, torch.Size]):
"""
coordinates = dense_coordinates(tensor.shape)
"""
r"""
Assume the input to have BxCxD1xD2x....xDN format.
If the shape of the tensor do not change, use
"""
spatial_dim = len(shape) - 2
assert (
spatial_dim > 0
), "Invalid shape. Shape must be batch x channel x spatial dimensions."
# Generate coordinates
size = [i for i in shape]
B = size[0]
coordinates = torch.from_numpy(
np.stack(
[
s.reshape(-1)
for s in np.meshgrid(
np.linspace(0, B - 1, B),
*(np.linspace(0, s - 1, s) for s in size[2:]),
indexing="ij"
)
],
1,
)
).int()
return coordinates
[docs]def to_sparse(dense_tensor: torch.Tensor,
resolution: int,
coordinates: torch.Tensor = None,
coords_key = None,
coords_man = None):
r"""Converts a (differentiable) dense tensor to a sparse tensor.
Assume the input to have BxCxD1xD2x....xDN format.
If the shape of the tensor do not change,
use `dense_coordinates` to cache the coordinates.
Please refer to tests/python/dense.py for usage
Example::
>>> dense_tensor = torch.rand(3, 4, 5, 6, 7, 8) # BxCxD1xD2xD3xD4
>>> dense_tensor.requires_grad = True
>>> stensor = to_sparse(dense_tensor)
"""
spatial_dim = dense_tensor.ndim - 2
assert (
spatial_dim > 0
), "Invalid shape. Shape must be batch x channel x spatial dimensions."
if coordinates is None:
coordinates = dense_coordinates(dense_tensor.shape)
coordinates[:, 1:] *= resolution
feat_tensor = dense_tensor.permute(
0, *(2 + i for i in range(spatial_dim)), 1)
return ME.SparseTensor(
features=feat_tensor.reshape(-1, dense_tensor.size(1)),
coords=coordinates,
force_creation=True,
# coords_key=coords_key,
coords_manager=coords_man,
tensor_stride=resolution
)
[docs]class SparseToDense(nn.Module):
[docs] def __init__(self):
super(SparseToDense, self).__init__()
[docs] def forward(self, x: ME.SparseTensor):
x_dense, _, _ = x.dense()
return x_dense
[docs]class DenseResBlock(nn.Module):
[docs] def __init__(self,
in_channels,
out_channels):
super(DenseResBlock, self).__init__()
self.norm_1 = nn.BatchNorm3d(in_channels)
self.act_1 = nn.LeakyReLU(0.2)
self.conv_1 = nn.Conv3d(in_channels, out_channels, 3, 1, 1)
self.norm_2 = nn.BatchNorm3d(out_channels)
self.act_2 = nn.LeakyReLU(0.2)
self.conv_2 = nn.Conv3d(out_channels, out_channels, 3, 1, 1)
[docs] def forward(self, x):
y = self.conv_1(self.act_1(self.norm_1(x)))
y = self.conv_2(self.act_2(self.norm_2(y)))
return y
[docs]def normalize_coords(coords, spatial_size=512):
'''
Utility Method for attaching normalized coordinates to
sparse tensor features.
INPUTS:
- input (scn.SparseConvNetTensor): sparse tensor to
attach normalized coordinates with range (-1, 1)
RETURNS:
- output (scn.SparseConvNetTensor): sparse tensor with
normalized coordinate concatenated to first three dimensions.
'''
with torch.no_grad():
coords = coords.float()
normalized_coords = (coords[:, :3] - spatial_size / 2) \
/ (spatial_size / 2)
if torch.cuda.is_available():
normalized_coords = normalized_coords.cuda()
return normalized_coords
[docs]class ConvolutionBlock(ME.MinkowskiNetwork):
[docs] def __init__(self,
in_features,
out_features,
stride=1,
dilation=1,
dimension=3,
activation='relu',
activation_args={},
normalization='batch_norm',
normalization_args={},
has_bias=False):
super(ConvolutionBlock, self).__init__(dimension)
assert dimension > 0
self.act_fn1 = activations_construct(activation, **activation_args)
self.act_fn2 = activations_construct(activation, **activation_args)
self.conv1 = ME.MinkowskiConvolution(
in_features, out_features, kernel_size=3,
stride=1, dilation=dilation, dimension=dimension, bias=has_bias)
self.norm1 = normalizations_construct(
normalization, out_features, **normalization_args)
self.conv2 = ME.MinkowskiConvolution(
out_features, out_features, kernel_size=3,
stride=1, dilation=dilation, dimension=dimension, bias=has_bias)
self.norm2 = normalizations_construct(
normalization, out_features, **normalization_args)
[docs] def forward(self, x):
out = self.conv1(x)
out = self.norm1(out)
out = self.act_fn1(out)
out = self.conv2(out)
out = self.norm2(out)
out = self.act_fn2(out)
return out
[docs]class DropoutBlock(ME.MinkowskiNetwork):
[docs] def __init__(self,
in_features,
out_features,
stride=1,
dilation=1,
dimension=3,
p=0.5,
activation='relu',
activation_args={},
normalization='batch_norm',
normalization_args={},
bias=False):
super(DropoutBlock, self).__init__(dimension)
assert dimension > 0
self.act_fn1 = activations_construct(activation, **activation_args)
self.act_fn2 = activations_construct(activation, **activation_args)
self.conv1 = ME.MinkowskiConvolution(
in_features, out_features, kernel_size=3,
stride=1, dilation=dilation, dimension=dimension, bias=bias)
self.dropout1 = ME.MinkowskiDropout(p=p)
self.norm1 = normalizations_construct(
normalization, out_features, **normalization_args)
self.conv2 = ME.MinkowskiConvolution(
out_features, out_features, kernel_size=3,
stride=1, dilation=dilation, dimension=dimension, bias=bias)
self.dropout2 = ME.MinkowskiDropout(p=p)
self.norm2 = normalizations_construct(
normalization, out_features, **normalization_args)
[docs] def forward(self, x):
out = self.conv1(x)
out = self.dropout1(out)
out = self.norm1(out)
out = self.act_fn1(out)
out = self.conv2(out)
out = self.dropout2(out)
out = self.norm2(out)
out = self.act_fn2(out)
return out
[docs]class ResNetBlock(ME.MinkowskiNetwork):
'''
ResNet Block with Leaky ReLU nonlinearities.
'''
expansion = 1
[docs] def __init__(self,
in_features,
out_features,
stride=1,
dilation=1,
dimension=3,
activation='relu',
activation_args={},
normalization='batch_norm',
normalization_args={},
bias=False):
super(ResNetBlock, self).__init__(dimension)
assert dimension > 0
self.act_fn1 = activations_construct(activation, **activation_args)
self.act_fn2 = activations_construct(activation, **activation_args)
if in_features != out_features:
self.residual = ME.MinkowskiLinear(in_features, out_features,
bias=bias)
else:
self.residual = Identity()
self.conv1 = ME.MinkowskiConvolution(
in_features, out_features, kernel_size=3,
stride=stride, dilation=dilation, dimension=dimension, bias=bias)
self.norm1 = normalizations_construct(
normalization, in_features, **normalization_args)
self.conv2 = ME.MinkowskiConvolution(
out_features, out_features, kernel_size=3,
stride=stride, dilation=dilation, dimension=dimension, bias=bias)
self.norm2 = normalizations_construct(
normalization, out_features, **normalization_args)
[docs] def forward(self, x):
residual = self.residual(x)
out = self.conv1(self.act_fn1(self.norm1(x)))
out = self.conv2(self.act_fn2(self.norm2(out)))
out += residual
return out
[docs]class AtrousIIBlock(ME.MinkowskiNetwork):
'''
ResNet-type block with Atrous Convolutions, as developed in ACNN paper:
<ACNN: a Full Resolution DCNN for Medical Image Segmentation>
Original Paper: https://arxiv.org/pdf/1901.09203.pdf
'''
[docs] def __init__(self, in_features, out_features,
dimension=3,
activation='relu',
activation_args={},
normalization='batch_norm',
normalization_args={}):
super(AtrousIIBlock, self).__init__(dimension)
assert dimension > 0
self.D = dimension
self.act_fn1 = activations_construct(activation, **activation_args)
self.act_fn2 = activations_construct(activation, **activation_args)
if in_features != out_features:
self.residual = ME.MinkowskiLinear(in_features, out_features)
else:
self.residual = Identity()
self.conv1 = ME.MinkowskiConvolution(
in_features, out_features,
kernel_size=3, stride=1, dilation=1, dimension=self.D)
self.norm1 = normalizations_construct(
normalization, out_features, **normalization_args)
self.conv2 = ME.MinkowskiConvolution(
out_features, out_features,
kernel_size=3, stride=1, dilation=3, dimension=self.D)
self.norm2 = normalizations_construct(
normalization, out_features, **normalization_args)
[docs] def forward(self, x):
residual = self.residual(x)
out = self.conv1(x)
out = self.norm1(out)
out = self.act_fn1(out)
out = self.conv2(out)
out = self.norm2(out)
out += residual
out = self.act_fn2(out)
return out
[docs]class ResNeXtBlock(ME.MinkowskiNetwork):
'''
ResNeXt block with leaky relu nonlinearities and atrous convs.
CONFIGURATIONS:
-------------------------------------------------------
- in_features (int): total number of input features
- out_features (int): total number of output features
NOTE: if in_features != out_features, then the identity skip
connection is replaced with a 1x1 conv layer.
- dimension (int): dimension of dataset.
- leakiness (float): leakiness for LeakyReLUs.
- cardinality (int): number of different paths, see ResNeXt paper.
- depth (int): number of convolutions + BN + LeakyReLU layers inside
each cardinal path.
- dilations (int or list of ints): dilation rates for atrous
convolutions.
- kernel_sizes (int or list of ints): kernel sizes for each conv layers
inside cardinal paths.
- strides (int or list of ints): strides for each conv layers inside
cardinal paths.
-------------------------------------------------------
NOTE: For vanilla resnext blocks, set dilation=1 and others to default.
'''
[docs] def __init__(self, in_features, out_features,
dimension=3,
cardinality=4,
depth=1,
dilations=None,
kernel_sizes=3,
strides=1,
activation='relu',
activation_args={},
normalization='batch_norm',
normalization_args={}):
super(ResNeXtBlock, self).__init__(dimension)
assert dimension > 0
assert cardinality > 0
assert (in_features % cardinality == 0 and
out_features % cardinality == 0)
self.D = dimension
nIn = in_features // cardinality
nOut = out_features // cardinality
self.dilations = []
if dilations is None:
# Default
self.dilations = [3**i for i in range(cardinality)]
elif isinstance(dilations, int):
self.dilations = [dilations for _ in range(cardinality)]
elif isinstance(dilations, list):
assert len(dilations) == cardinality
self.dilations = dilations
else:
raise ValueError(
'Invalid type for input strides, must be int or list!')
self.kernels = []
if isinstance(kernel_sizes, int):
self.kernels = [kernel_sizes for _ in range(cardinality)]
elif isinstance(kernel_sizes, list):
assert len(kernel_sizes) == cardinality
self.kernels = kernel_sizes
else:
raise ValueError(
'Invalid type for input strides, must be int or list!')
self.strides = []
if isinstance(strides, int):
self.strides = [strides for _ in range(cardinality)]
elif isinstance(strides, list):
assert len(strides) == cardinality
self.strides = strides
else:
raise ValueError(
'Invalid type for input strides, must be int or list!')
# For each path, generate sequentials
self.paths = []
for i in range(cardinality):
m = []
m.append(ME.MinkowskiLinear(in_features, nIn))
for j in range(depth):
in_C = (nIn if j == 0 else nOut)
m.append(ME.MinkowskiConvolution(
in_channels=in_C, out_channels=nOut,
kernel_size=self.kernels[i], stride=self.strides[i],
dilation=self.dilations[i], dimension=self.D))
m.append(normalizations_construct(
normalization, nOut, **normalization_args))
m.append(activations_construct(activation, **activation_args))
m = nn.Sequential(*m)
self.paths.append(m)
self.paths = nn.Sequential(*self.paths)
self.linear = ME.MinkowskiLinear(out_features, out_features)
# Skip Connection
if in_features != out_features:
self.residual = ME.MinkowskiLinear(in_features, out_features)
else:
self.residual = Identity()
[docs] def forward(self, x):
residual = self.residual(x)
cat = tuple([layer(x) for layer in self.paths])
out = ME.cat(cat)
out = self.linear(out)
out += residual
return out
[docs]class SPP(ME.MinkowskiNetwork):
'''
Spatial Pyramid Pooling Module.
PSPNet (Pyramid Scene Parsing Network) uses vanilla SPPs, while
DeeplabV3 and DeeplabV3+ uses ASPP (atrous versions).
Default parameters will construct a global average pooling + unpooling
layer which is done in ParseNet.
CONFIGURATIONS:
-------------------------------------------------------
- in_features (int): number of input features
- out_features (int): number of output features
- D (int): dimension of dataset.
- mode (str): pooling mode. In MinkowskiEngine, currently
'avg', 'max', and 'sum' are supported.
- dilations (int or list of ints): dilation rates for atrous
convolutions.
- kernel_sizes (int or list of ints): kernel sizes for each
pooling operation. Note that kernel_size == stride for the SPP layer.
-------------------------------------------------------
'''
[docs] def __init__(self, in_features, out_features,
kernel_sizes=None, dilations=None, mode='avg', D=3):
super(SPP, self).__init__(D)
if mode == 'avg':
self.pool_fn = ME.MinkowskiAvgPooling
elif mode == 'max':
self.pool_fn = ME.MinkowskiMaxPooling
elif mode == 'sum':
self.pool_fn = ME.MinkowskiSumPooling
else:
raise ValueError("Invalid pooling mode, must be one of \
'sum', 'max' or 'average'")
self.unpool_fn = ME.MinkowskiPoolingTranspose
# Include global pooling as first modules.
self.pool = [ME.MinkowskiGlobalPooling(dimension=D)]
self.unpool = [ME.MinkowskiBroadcast(dimension=D)]
multiplier = 1
# Define subregion poolings
self.spp = []
if kernel_sizes is not None:
if isinstance(dilations, int):
dilations = [dilations for _ in range(len(kernel_sizes))]
elif isinstance(dilations, list):
assert len(kernel_sizes) == len(dilations)
else:
raise ValueError("Invalid input to dilations, must be either \
int or list of ints")
multiplier = len(kernel_sizes) + 1 # Additional 1 for globalPool
for k, d in zip(kernel_sizes, dilations):
pooling_layer = self.pool_fn(
kernel_size=k, dilation=d, stride=k, dimension=D)
unpooling_layer = self.unpool_fn(
kernel_size=k, dilation=d, stride=k, dimension=D)
self.pool.append(pooling_layer)
self.unpool.append(unpooling_layer)
self.pool = nn.Sequential(*self.pool)
self.unpool = nn.Sequential(*self.unpool)
self.linear = ME.MinkowskiLinear(in_features * multiplier, out_features)
[docs] def forward(self, input):
cat = []
for i, pool in enumerate(self.pool):
x = pool(input)
# First item is Global Pooling
if i == 0:
x = self.unpool[i](input, x)
else:
x = self.unpool[i](x)
cat.append(x)
out = ME.cat(cat)
out = self.linear(out)
return out
[docs]class CascadeDilationBlock(ME.MinkowskiNetwork):
'''
Cascaded Atrous Convolution Block
'''
[docs] def __init__(self, in_features, out_features,
dimension=3,
depth=6,
dilations=[1, 2, 4, 8, 16, 32],
activation='relu',
activation_args={}):
super(CascadeDilationBlock, self).__init__(dimension)
self.D = dimension
F = out_features
m = []
self.input_layer = ME.MinkowskiLinear(in_features, F)
for i in range(depth):
m.append(ResNetBlock(F, F, dilation=dilations[i],
activation=activation, activation_args=activation_args))
self.net = nn.Sequential(*m)
[docs] def forward(self, x):
x = self.input_layer(x)
sumTensor = x
for i, layer in enumerate(self.net):
x = layer(x)
sumTensor += x
return sumTensor
[docs]class ASPP(ME.MinkowskiNetwork):
'''
Atrous Spatial Pyramid Pooling Module
'''
[docs] def __init__(self, in_features, out_features,
dimension=3,
width=5,
dilations=[2, 4, 6, 8, 12]):
super(ASPP, self).__init__(dimension)
assert len(dilations) == width
m = []
m.append(ME.MinkowskiLinear(in_features, out_features))
for d in dilations:
m.append(ME.MinkowskiConvolution(in_features, out_features,
kernel_size=3, dilation=d, dimension=self.D))
self.net = nn.Sequential(*m)
self.pool = ME.MinkowskiGlobalPooling(dimension=self.D)
self.unpool = ME.MinkowskiBroadcast(dimension=self.D)
self.out = nn.Sequential(
ME.MinkowskiConvolution(in_features * (2 + width), out_features,
kernel_size=3, dilation=1, dimension=self.D),
ME.MinkowskiBatchNorm(out_features),
ME.MinkowskiReLU()
)
[docs] def forward(self, x):
cat = []
for i, layer in enumerate(self.net):
x_i = layer(x)
cat.append(x_i)
x_global = self.pool(x)
x_global = self.unpool(x, x_global)
cat.append(x_global)
out = ME.cat(cat)
return self.out(out)
[docs]class MBConv(ME.MinkowskiNetwork):
[docs] def __init__(self, in_features, out_features,
expand_ratio=2,
dimension=3,
dilation=1,
kernel_size=3,
stride=1,
activation='relu',
activation_args={},
normalization='batch_norm',
normalization_args={},
has_bias=False):
super(MBConv, self).__init__(dimension)
self.D = dimension
self.hidden_dim = int(expand_ratio * in_features)
if expand_ratio == 1:
self.m = nn.Sequential(
normalizations_construct(
normalization, in_features, **normalization_args),
activations_construct(activation, **activation_args),
ME.MinkowskiConvolution(
in_features, out_features,
kernel_size=3, stride=1, dilation=1,
dimension=self.D, bias=has_bias))
else:
self.m = nn.Sequential(
normalizations_construct(
normalization, in_features, **normalization_args),
activations_construct(activation, **activation_args),
ME.MinkowskiLinear(in_features, self.hidden_dim),
normalizations_construct(
normalization, self.hidden_dim, **normalization_args),
activations_construct(activation, **activation_args),
ME.MinkowskiChannelwiseConvolution(self.hidden_dim,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
bias=has_bias,
dimension=self.D),
normalizations_construct(
normalization, self.hidden_dim, **normalization_args),
activations_construct(activation, **activation_args),
ME.MinkowskiLinear(self.hidden_dim, out_features)
)
[docs] def forward(self, x):
out = self.m(x)
return out
[docs]class MBResConv(ME.MinkowskiNetwork):
[docs] def __init__(self, in_features, out_features,
expand_ratio=2,
dimension=3,
dilation=1,
kernel_size=3,
stride=1,
activation='relu',
activation_args={},
normalization='batch_norm',
normalization_args={},
has_bias=False):
super(MBResConv, self).__init__(dimension)
self.D = dimension
self.m1 = MBConv(in_features, out_features,
expand_ratio=expand_ratio,
dimension=dimension,
dilation=dilation,
kernel_size=kernel_size,
stride=stride,
activation=activation,
activation_args=activation_args,
normalization=normalization,
normalization_args=normalization_args,
has_bias=has_bias)
self.m2 = MBConv(out_features, out_features,
expand_ratio=expand_ratio,
dimension=dimension,
dilation=dilation,
kernel_size=kernel_size,
stride=stride,
activation=activation,
activation_args=activation_args,
normalization=normalization,
normalization_args=normalization_args,
has_bias=has_bias)
if in_features == out_features:
self.connection = Identity()
else:
self.connection = nn.Sequential(
normalizations_construct(
normalization, in_features, **normalization_args),
activations_construct(activation, **activation_args),
ME.MinkowskiLinear(in_features, out_features))
[docs] def forward(self, x):
x_add = self.connection(x)
x = self.m1(x)
x = self.m2(x)
out = x_add + x
return out
[docs]class SEBlock(ME.MinkowskiNetwork):
'''
Squeeze and Excitation Blocks
'''
[docs] def __init__(self, channels, ratio=8, dimension=3):
super(SEBlock, self).__init__(dimension)
assert channels // ratio > 0
assert channels % ratio == 0
self.linear1 = ME.MinkowskiLinear(channels, channels // ratio)
self.relu = ME.MinkowskiReLU()
self.linear2 = ME.MinkowskiLinear(channels // ratio, channels)
self.sigmoid = ME.MinkowskiSigmoid()
self.pool = ME.MinkowskiGlobalPooling()
self.bcst = ME.MinkowskiBroadcastMultiplication()
[docs] def forward(self, x):
g = self.pool(x)
g = self.linear1(g)
g = self.relu(g)
g = self.linear2(g)
g = self.sigmoid(g)
out = self.bcst(x, g)
return x
[docs]class SEResNetBlock(ME.MinkowskiNetwork):
'''
Squeeze and Excitation ResNet Block with Leaky ReLU nonlinearities.
'''
expansion = 1
[docs] def __init__(self,
in_features,
out_features,
se_ratio=8,
stride=1,
dilation=1,
dimension=3,
activation='relu',
activation_args={},
normalization='batch_norm',
normalization_args={}):
super(SEResNetBlock, self).__init__(dimension)
assert dimension > 0
self.act_fn1 = activations_construct(activation, **activation_args)
self.act_fn2 = activations_construct(activation, **activation_args)
if in_features != out_features:
self.residual = ME.MinkowskiLinear(in_features, out_features)
else:
self.residual = Identity()
self.conv1 = ME.MinkowskiConvolution(
in_features, out_features, kernel_size=3,
stride=1, dilation=dilation, dimension=dimension)
self.norm1 = normalizations_construct(
normalization, out_features, **normalization_args)
self.conv2 = ME.MinkowskiConvolution(
out_features, out_features, kernel_size=3,
stride=1, dilation=dilation, dimension=dimension)
self.norm2 = normalizations_construct(
normalization, out_features, **normalization_args)
self.se_block = SEBlock(out_features, ratio=se_ratio, dimension=dimension)
[docs] def forward(self, x):
residual = self.residual(x)
out = self.act_fn1(self.norm1(self.conv1(x)))
out = self.norm2(self.conv2(out))
out = self.se_block(out)
out += residual
out = self.act_fn2(out)
return out
[docs]class MBResConvSE(ME.MinkowskiNetwork):
[docs] def __init__(self, in_features, out_features, se_ratio=8,
expand_ratio=2,
dimension=3,
dilation=1,
kernel_size=3,
stride=1,
activation='relu',
activation_args={},
normalization='batch_norm',
normalization_args={},
has_bias=False):
super(MBResConvSE, self).__init__(in_features, out_features,
expand_ratio=expand_ratio,
dimension=dimension,
dilation=dilation,
kernel_size=kernel_size,
stride=stride,
activation=activation,
activation_args=activation_args,
normalization=normalization,
normalization_args=normalization_args,
has_bias=has_bias)
if in_features == out_features:
self.connection = Identity()
else:
self.connection = nn.Sequential(
normalizations_construct(
normalization, in_features, **normalization_args),
activations_construct(activation, **activation_args),
ME.MinkowskiLinear(in_features, out_features))
self.se = SEBlock(out_features, ratio=se_ratio)
[docs] def forward(self, x):
res = self.m(x)
attn = self.se(res)
out = self.connection(x) + attn
return out