import numpy as np
import torch
import torch.nn as nn
import MinkowskiEngine as ME
import MinkowskiFunctional as MF
from mlreco.utils import local_cdist
from mlreco.models.layers.common.blocks import ResNetBlock, SPP, ASPP
from mlreco.models.layers.common.activation_normalization_factories import activations_construct
from mlreco.models.layers.common.configuration import setup_cnn_configuration
from mlreco.models.layers.common.extract_feature_map import MinkGhostMask
from collections import Counter
from mlreco.models.layers.cluster_cnn.losses.misc import BinaryCELogDiceLoss
class AttentionMask(torch.nn.Module):
    '''
    Returns a masked tensor of x according to mask, where the number of
    coordinates between x and mask may differ.

    Both tensors are first placed on a common set of coordinates by adding
    them to zero-filled placeholder tensors, and the result is pruned with
    ``ME.MinkowskiPruning`` using the (integer-truncated) mask values.

    Configuration
    -------------
    score_threshold: float, default 0.5
        Stored on the module; it is not read by ``forward``.
    '''

    def __init__(self, score_threshold=0.5):
        super(AttentionMask, self).__init__()
        self.prune = ME.MinkowskiPruning()
        self.score_threshold = score_threshold

    def forward(self, x, mask):
        '''
        Parameters
        ----------
        x: ME.SparseTensor
            Tensor to be pruned.
        mask: ME.SparseTensor
            Mask tensor with the same tensor stride as ``x``.
            Assumed to carry a single feature column -- TODO confirm.

        Returns
        -------
        ME.SparseTensor
            Output of ``self.prune`` applied to ``x`` expanded onto the
            merged coordinates.
        '''
        assert x.tensor_stride == mask.tensor_stride

        device = x.F.device
        # Create a mask sparse tensor in x-coordinates: a zero placeholder
        # living on x.C, to which the mask is added.
        x0 = ME.SparseTensor(
            coordinates=x.C,
            features=torch.zeros(x.F.shape[0], mask.F.shape[1],
                                 device=device),
            coordinate_manager=x.coordinate_manager,
            tensor_stride=x.tensor_stride)
        mask_in_xcoords = x0 + mask

        # Same trick the other way round: a zero placeholder on the merged
        # coordinates, to which x is added, so that x and the mask have
        # matching rows.
        x_expanded = ME.SparseTensor(
            coordinates=mask_in_xcoords.C,
            features=torch.zeros(mask_in_xcoords.F.shape[0], x.F.shape[1],
                                 device=device),
            coordinate_manager=x.coordinate_manager,
            tensor_stride=x.tensor_stride)
        x_expanded = x_expanded + x

        # .int() truncates towards zero before the boolean cast.
        # Only the feature axis is squeezed: a bare .squeeze() would also
        # drop the row axis when there is a single coordinate, handing a
        # 0-d mask to the pruning layer.
        target = mask_in_xcoords.F.int().bool().squeeze(dim=1)
        x_pruned = self.prune(x_expanded, target)
        return x_pruned
class MergeConcat(torch.nn.Module):
    '''
    Concatenates the features of two sparse tensors whose coordinates may
    differ.

    Each tensor is added to a zero-filled placeholder so that both end up
    on the same coordinates, then ``ME.cat`` is applied with the features
    of ``other`` first and those of ``input`` second.
    '''

    def __init__(self):
        super().__init__()

    def forward(self, input, other):
        '''
        Parameters
        ----------
        input: ME.SparseTensor
        other: ME.SparseTensor
            Must have the same tensor stride as ``input``.

        Returns
        -------
        ME.SparseTensor
            Result of ``ME.cat`` on the two coordinate-aligned tensors.
        '''
        assert input.tensor_stride == other.tensor_stride

        manager = input.coordinate_manager
        stride = input.tensor_stride
        dev = input.F.device

        # Zero placeholder on input.C with as many channels as `other`;
        # the sparse addition fills in the features of `other`.
        other_placeholder = ME.SparseTensor(
            coordinates=input.C,
            features=torch.zeros(input.F.shape[0], other.F.shape[1]).to(dev),
            coordinate_manager=manager,
            tensor_stride=stride)
        other_aligned = other_placeholder + other

        # Same procedure for `input`, this time on the coordinates of the
        # tensor obtained above.
        input_placeholder = ME.SparseTensor(
            coordinates=other_aligned.C,
            features=torch.zeros(other_aligned.F.shape[0],
                                 input.F.shape[1]).to(dev),
            coordinate_manager=manager,
            tensor_stride=stride)
        input_aligned = input_placeholder + input

        # Both operands now share coordinates and row count.
        return ME.cat(other_aligned, input_aligned)
class ExpandAs(nn.Module):
    '''
    Broadcasts the features of a sparse tensor to a requested shape and
    wraps them in a new sparse tensor on the same coordinates.
    '''

    def __init__(self):
        super(ExpandAs, self).__init__()

    def forward(self, x, shape, labels=None):
        '''
        Parameters
        ----------
        x: ME.SparseTensor
            Input sparse tensor; its feature tensor is (N x F).
        shape: sequence of int
            Target shape handed to ``torch.Tensor.expand``.
        labels: torch.Tensor, optional
            Index into the rows of ``x.F``; the selected entries are set to
            1.0. Presumably a boolean mask or index tensor -- TODO confirm.
            NOTE: when given, ``x.F`` is modified in place.

        Returns
        -------
        ME.SparseTensor
            Tensor sharing the coordinate map key and coordinate manager of
            ``x``, with the expanded features.
        '''
        features = x.F
        # Indexing with None selects the whole tensor, so an unguarded
        # `features[labels] = 1.0` would overwrite every score with 1.0
        # whenever no labels are supplied.
        if labels is not None:
            features[labels] = 1.0
        features = features.expand(*shape)
        output = ME.SparseTensor(
            features=features,
            coordinate_map_key=x.coordinate_map_key,
            coordinate_manager=x.coordinate_manager)
        return output
class PPN(torch.nn.Module):
    '''
    Point Proposal Network (PPN) implementation using MinkowskiEngine

    It requires a UResNet network as a backbone.

    Configuration
    -------------
    data_dim: int, default 3
    num_input: int, default 1
    allow_bias: bool, default False
    spatial_size: int, default 512
    leakiness: float, default 0.33
    activation: dict
        For activation function, defaults to `{'name': 'lrelu', 'args': {}}`
    norm_layer: dict
        For normalization function, defaults to `{'name': 'batch_norm', 'args': {}}`
    depth: int, default 5
        Depth of UResNet, also corresponds to how many times we down/upsample.
    filters: int, default 16
        Number of filters in the first convolution of UResNet.
        Will increase linearly with depth.
    reps: int, default 2
        Convolution block repetition factor
    input_kernel: int, default 3
        Receptive field size for very first convolution after input layer.
    num_classes: int, default 5
    score_threshold: float, default 0.5
    classify_endpoints: bool, default False
        Enable classification of points into start vs end points.
    ppn_resolution: float, default 1.0
    ghost: bool, default False
    downsample_ghost: bool, default True
    use_true_ghost_mask: bool, default False
    mask_loss_name: str, default 'BCE'
        Can be 'BCE' or 'LogDice'
    particles_label_seg_col: int, default -2
        Which column corresponds to particles' semantic label
    track_label: int, default 1

    Output
    ------
    points: torch.Tensor
        Contains X, Y, Z predictions, semantic class prediction logits, and prob score
    mask_ppn: list of torch.Tensor
        Binary mask at various spatial scales of PPN predictions (voxel-wise score > some threshold)
    ppn_coords: list of torch.Tensor
        List of XYZ coordinates at various spatial scales.
    ppn_layers: list of torch.Tensor
        List of score features at various spatial scales.
    ppn_output_coordinates: torch.Tensor
        XYZ coordinates tensor at the very last layer of PPN (initial spatial scale)
    classify_endpoints: torch.Tensor
        Logits for end/start point classification.

    See Also
    --------
    PPNLonelyLoss, mlreco.models.uresnet_ppn_chain
    '''

    def __init__(self, cfg, name='ppn'):
        super(PPN, self).__init__()
        # Presumably sets self.D, self.activation_name and
        # self.activation_args, which are read below -- TODO confirm.
        setup_cnn_configuration(self, cfg, name)

        self.model_cfg = cfg.get(name, {})

        # UResNet Configurations
        self.reps = self.model_cfg.get('reps', 2)
        self.depth = self.model_cfg.get('depth', 5)
        self.num_classes = self.model_cfg.get('num_classes', 5)
        self.num_filters = self.model_cfg.get('filters', 16)
        self.nPlanes = [i * self.num_filters for i in range(1, self.depth+1)]
        self.ppn_score_threshold = self.model_cfg.get('score_threshold', 0.5)
        self.input_kernel = self.model_cfg.get('input_kernel', 3)
        self._classify_endpoints = self.model_cfg.get('classify_endpoints', False)

        # Initialize Decoder: one transpose-convolution stage, one stack of
        # residual blocks and one single-channel score head per level,
        # built from the deepest level upwards.
        decoding_block = []
        decoding_conv = []
        self.ppn_pred = nn.ModuleList()
        for i in range(self.depth-2, -1, -1):
            m = []
            m.append(ME.MinkowskiBatchNorm(self.nPlanes[i+1]))
            m.append(activations_construct(
                self.activation_name, **self.activation_args))
            m.append(ME.MinkowskiConvolutionTranspose(
                in_channels=self.nPlanes[i+1],
                out_channels=self.nPlanes[i],
                kernel_size=2,
                stride=2,
                dimension=self.D))
            decoding_conv.append(nn.Sequential(*m))
            m = []
            for j in range(self.reps):
                # The first block of each level receives the concatenation
                # of two feature maps, hence twice the channels.
                m.append(ResNetBlock(self.nPlanes[i] * (2 if j == 0 else 1),
                                     self.nPlanes[i],
                                     dimension=self.D,
                                     activation=self.activation_name,
                                     activation_args=self.activation_args))
            decoding_block.append(nn.Sequential(*m))
            self.ppn_pred.append(ME.MinkowskiLinear(self.nPlanes[i], 1))
        self.decoding_block = nn.Sequential(*decoding_block)
        self.decoding_conv = nn.Sequential(*decoding_conv)

        self.sigmoid = ME.MinkowskiSigmoid()
        self.expand_as = ExpandAs()

        self.final_block = ResNetBlock(self.nPlanes[0],
                                       self.nPlanes[0],
                                       dimension=self.D,
                                       activation=self.activation_name,
                                       activation_args=self.activation_args)

        # Output heads applied to the final feature map.
        self.ppn_pixel_pred = ME.MinkowskiConvolution(self.nPlanes[0],
                                                      self.D,
                                                      kernel_size=3,
                                                      stride=1,
                                                      dimension=self.D)
        self.ppn_type = ME.MinkowskiConvolution(self.nPlanes[0],
                                                self.num_classes,
                                                kernel_size=3,
                                                stride=1,
                                                dimension=self.D)
        self.ppn_final_score = ME.MinkowskiConvolution(self.nPlanes[0],
                                                       2,
                                                       kernel_size=3,
                                                       stride=1,
                                                       dimension=self.D)
        if self._classify_endpoints:
            self.ppn_endpoint = ME.MinkowskiConvolution(self.nPlanes[0],
                                                        2,
                                                        kernel_size=3,
                                                        stride=1,
                                                        dimension=self.D)

        self.resolution = self.model_cfg.get('ppn_resolution', 1.0)

        # Ghost point removal options
        self.ghost = self.model_cfg.get('ghost', False)
        self.masker = AttentionMask()
        self.merge_concat = MergeConcat()
        if self.ghost:
            self.ghost_mask = MinkGhostMask(self.D)
        # Always defined, so they can be inspected whether or not ghost
        # masking is enabled.
        self.use_true_ghost_mask = self.model_cfg.get(
            'use_true_ghost_mask', False)
        self.downsample_ghost = self.model_cfg.get('downsample_ghost', True)

    def forward(self, final, decoderTensors, ghost=None, ghost_labels=None):
        '''
        Parameters
        ----------
        final: ME.SparseTensor
            Tensor fed to the first decoding stage.
        decoderTensors: list of ME.SparseTensor
            Feature maps merged with the upsampled tensor, one per decoding
            stage, indexed in the order the stages are applied.
        ghost: ME.SparseTensor, optional
            Ghost prediction; read only when ghost masking is enabled.
        ghost_labels: torch.Tensor, optional
            Required when ``use_true_ghost_mask`` is set. The first four
            columns are used as coordinates and the last column is
            compared against ``num_classes``.

        Returns
        -------
        dict
            Keys 'points', 'mask_ppn', 'ppn_layers', 'ppn_coords',
            'ppn_output_coordinates' and, when endpoint classification is
            enabled, 'classify_endpoints'. Every value is wrapped in a
            one-element list.
        '''
        ppn_coords = []
        raw_scores = []
        mask_ppn = []

        if self.ghost:
            # Downsample stride 1 ghost mask to all intermediate decoder layers
            with torch.no_grad():
                if self.use_true_ghost_mask:
                    assert ghost_labels is not None
                    # TODO: Not sure what's going on here
                    ghost_mask_tensor = ghost_labels[:, -1] < self.num_classes
                    ghost_coords = ghost_labels[:, :4]
                else:
                    ghost_mask_tensor = 1.0 - torch.argmax(ghost.F,
                                                           dim=1,
                                                           keepdim=True)
                    ghost_coords = ghost.C
                # Bound for both branches so that the true-ghost-mask path
                # does not reference undefined names.
                # NOTE(review): this reads ghost.tensor_stride, so `ghost`
                # must be supplied in both modes -- confirm against callers.
                ghost_coords_man = final.coordinate_manager
                ghost_tensor_stride = ghost.tensor_stride
                ghost_mask = ME.SparseTensor(
                    features=ghost_mask_tensor,
                    coordinates=ghost_coords,
                    coordinate_manager=ghost_coords_man,
                    tensor_stride=ghost_tensor_stride)

            # Built over the reversed list, then flipped back so the result
            # is indexed like decoderTensors.
            decoder_feature_maps = []
            for t in decoderTensors[::-1]:
                scaled_ghost_mask = self.ghost_mask(ghost_mask, t)
                nonghost_tensor = self.masker(t, scaled_ghost_mask)
                decoder_feature_maps.append(nonghost_tensor)
            decoder_feature_maps = decoder_feature_maps[::-1]
        else:
            decoder_feature_maps = decoderTensors

        x = final
        for i, layer in enumerate(self.decoding_conv):
            decTensor = decoder_feature_maps[i]
            x = layer(x)
            # After ghost pruning the two tensors may not share coordinates,
            # hence the dedicated merge.
            if self.ghost:
                x = self.merge_concat(decTensor, x)
            else:
                x = ME.cat(decTensor, x)
            x = self.decoding_block[i](x)
            scores = self.ppn_pred[i](x)
            # Raw (pre-sigmoid) scores are what gets returned in ppn_layers.
            raw_scores.append(scores.F)
            ppn_coords.append(scores.C)
            scores = self.sigmoid(scores)
            # Broadcast the score column to the shape of x.F so it can be
            # multiplied feature-wise.
            s_expanded = self.expand_as(scores, x.F.shape)
            mask_ppn.append((scores.F > self.ppn_score_threshold))
            # The attention weights are detached before the multiplication.
            x = x * s_expanded.detach()

        # Note that we skipped ghost masking for the final sparse tensor,
        # namely the tensor with the same resolution as the input to uresnet.
        # This is done at the full chain cnn stage, for consistency with SCN
        device = x.F.device
        ppn_output_coordinates = x.C
        ppn_layers = [p.to(dtype=torch.float32, device=device)
                      for p in raw_scores]

        x = self.final_block(x)
        pixel_pred = self.ppn_pixel_pred(x)
        ppn_type = self.ppn_type(x)
        ppn_final_score = self.ppn_final_score(x)
        if self._classify_endpoints:
            ppn_endpoint = self.ppn_endpoint(x)

        # X, Y, Z, logits, and prob score
        points = torch.cat([pixel_pred.F, ppn_type.F, ppn_final_score.F], dim=1)

        res = {
            'points': [points],
            'mask_ppn': [mask_ppn],
            'ppn_layers': [ppn_layers],
            'ppn_coords': [ppn_coords],
            'ppn_output_coordinates': [ppn_output_coordinates],
        }
        if self._classify_endpoints:
            res['classify_endpoints'] = [ppn_endpoint.F]
        return res
class PPNLonelyLoss(torch.nn.modules.loss._Loss):
    """
    Loss function for PPN.

    Configuration
    -------------
    mask_loss_name: str, default 'BCE'
        Can be 'BCE' or 'LogDice'
    ppn_resolution: float, default 1.0
    particles_label_seg_col: int, default -2
        Which column corresponds to particles' semantic label
    classify_endpoints: bool, default False
    track_label: int, default 1
    point_classes: list, default []
        If non-empty, only label points of these classes are used.

    Output
    ------
    reg_loss: float
        Distance loss
    mask_loss: float
        Binary voxel-wise prediction (is there an object of interest or not)
    type_loss: float
        Semantic prediction loss.
    classify_endpoints_loss: float
    classify_endpoints_accuracy: float
    loss: total loss
    accuracy: float

    See Also
    --------
    PPN, mlreco.models.uresnet_ppn_chain
    """

    def __init__(self, cfg, name='ppn'):
        """
        Raises
        ------
        NotImplementedError
            If ``mask_loss_name`` is neither 'BCE' nor 'LogDice'.
        """
        super(PPNLonelyLoss, self).__init__()
        self.loss_config = cfg.get(name, {})
        self.mask_loss_name = self.loss_config.get('mask_loss_name', 'BCE')
        if self.mask_loss_name == "BCE":
            self.lossfn = torch.nn.functional.binary_cross_entropy_with_logits
        elif self.mask_loss_name == "LogDice":
            self.lossfn = BinaryCELogDiceLoss()
        else:
            # Fail at construction time instead of leaving self.lossfn
            # undefined until the first forward pass.
            raise NotImplementedError(
                "Unknown mask_loss_name '%s' (expected 'BCE' or 'LogDice')"
                % self.mask_loss_name)

        self.resolution = self.loss_config.get('ppn_resolution', 1.0)
        self.regloss = torch.nn.MSELoss()
        self.segloss = torch.nn.functional.cross_entropy
        self.particles_label_seg_col = self.loss_config.get(
            'particles_label_seg_col', -2)

        # Endpoint classification (optional)
        self._classify_endpoints = self.loss_config.get('classify_endpoints', False)
        self._track_label = self.loss_config.get('track_label', 1)

        # Restrict the label points to specific classes (pass a list if needed)
        self._point_classes = self.loss_config.get('point_classes', [])

    def forward(self, result, segment_label, particles_label):
        """
        Parameters
        ----------
        result: dict
            Reads the keys 'ppn_coords', 'ppn_layers', 'points' and, when
            endpoint classification is enabled, 'classify_endpoints'.
        segment_label: list of torch.Tensor
            Only its length and the device of its first entry are used.
        particles_label: list of torch.Tensor
            One tensor per entry of ``segment_label``. Column 0 is used as
            batch id, columns 1:4 as point coordinates, column
            ``particles_label_seg_col`` as semantic type, and the last
            column as endpoint class when endpoint classification is
            enabled.

        Returns
        -------
        dict
            Keys 'reg_loss', 'mask_loss', 'type_loss',
            'classify_endpoints_loss', 'classify_endpoints_accuracy',
            'loss' and 'accuracy'.
        """
        # TODO Add weighting
        assert len(particles_label) == len(segment_label)

        # Batch ids are taken from the last PPN layer of the first entry.
        # NOTE(review): this list has a single element, so batch_ids[igpu]
        # below only works for igpu == 0 -- confirm multi-GPU is unused.
        batch_ids = [result['ppn_coords'][0][-1][:, 0]]
        num_batches = len(batch_ids[0].unique())
        total_loss = 0
        total_acc = 0
        device = segment_label[0].device

        res = {
            'reg_loss': 0.,
            'mask_loss': 0.,
            'type_loss': 0.,
            'classify_endpoints_loss': 0.,
            'classify_endpoints_accuracy': 0.
        }

        # Semantic Segmentation Loss
        for igpu in range(len(segment_label)):
            particles = particles_label[igpu]
            if len(self._point_classes) > 0:
                classes = particles[:, self.particles_label_seg_col]
                class_mask = torch.zeros(len(particles), dtype=torch.bool,
                                         device=particles.device)
                for c in self._point_classes:
                    class_mask |= classes == c
                particles = particles[class_mask]

            ppn_layers = result['ppn_layers'][igpu]
            ppn_coords = result['ppn_coords'][igpu]
            points = result['points'][igpu]
            num_layers = len(ppn_layers)
            loss_gpu = 0.0

            for layer in range(num_layers):
                ppn_score_layer = ppn_layers[layer]
                coords_layer = ppn_coords[layer]
                loss_layer = 0.0

                for b in batch_ids[igpu].int().unique():
                    batch_index_layer = coords_layer[:, 0].int() == b
                    batch_particle_index = batch_ids[igpu].int() == b
                    event_particles = particles[particles[:, 0].int() == b]
                    points_label = event_particles[:, 1:4]
                    scores_event = ppn_score_layer[batch_index_layer].squeeze()
                    points_event = coords_layer[batch_index_layer]

                    # squeeze() yields a 0-d tensor for a single voxel;
                    # such events are skipped.
                    if len(scores_event.shape) == 0:
                        continue

                    d_true = local_cdist(
                        points_label,
                        points_event[:, 1:4].float().to(device))

                    # The distance cut halves with each successive layer.
                    d_positives = (d_true < self.resolution *
                                   2**(num_layers - layer)).any(dim=0)

                    num_positives = d_positives.sum()
                    num_negatives = d_positives.nelement() - num_positives

                    # Class balancing: each voxel is weighted by the
                    # fraction of the opposite class.
                    w = num_positives.float() / \
                        (num_positives + num_negatives).float()

                    weight_ppn = torch.zeros(d_positives.shape[0], device=device)
                    weight_ppn[d_positives] = 1 - w
                    weight_ppn[~d_positives] = w

                    loss_batch = self.lossfn(scores_event,
                                             d_positives.float(),
                                             weight=weight_ppn,
                                             reduction='mean')
                    loss_layer += loss_batch

                    if layer == num_layers - 1:
                        # Get Final Layers
                        anchors = coords_layer[batch_particle_index][:, 1:4].float().to(device) + 0.5
                        pixel_score = points[batch_particle_index][:, -1]
                        pixel_logits = points[batch_particle_index][:, 3:8]
                        pixel_pred = points[batch_particle_index][:, :3] + anchors

                        d = local_cdist(points_label, pixel_pred)
                        positives = (d < self.resolution).any(dim=0)
                        if (torch.sum(positives) < 1):
                            continue
                        acc = (positives == (pixel_score > 0)).sum().float() / float(pixel_score.shape[0])
                        total_acc += acc

                        # Mask Loss
                        # NOTE(review): reuses the weights derived from
                        # d_positives above rather than from `positives`.
                        mask_loss_final = self.lossfn(pixel_score,
                                                      positives.float(),
                                                      weight=weight_ppn,
                                                      reduction='mean')

                        # Type Segmentation Loss
                        distance_positives = d[:, positives]
                        event_types_label = particles[particles[:, 0] == b]\
                            [:, self.particles_label_seg_col]
                        # Inverse-frequency class weights over labels 0..3,
                        # with a fifth entry whose count is fixed to 0.
                        counter = Counter({0: 0, 1: 0, 2: 0, 3: 0})
                        counter.update(list(event_types_label.int().cpu().numpy()))

                        w = torch.Tensor([counter[0],
                                          counter[1],
                                          counter[2],
                                          counter[3], 0]).float()
                        w = float(sum(counter.values())) / (w + 1.0)

                        # Each positive voxel takes the type of its closest
                        # label point.
                        positive_labels = event_types_label[torch.argmin(distance_positives, dim=0)]
                        type_loss = self.segloss(pixel_logits[positives],
                                                 positive_labels.long(),
                                                 weight=w.to(device))

                        # --- Endpoint classification loss
                        if self._classify_endpoints:
                            tracks = positive_labels == self._track_label
                            loss_point_class, acc_point_class, point_class_count = 0., 0., 0.
                            loss_classify_endpoints, acc_classify_endpoints = 0., 1.
                            if tracks.sum().item() > 0:
                                # Start and end points separately in case of overlap
                                for point_class in range(2):
                                    point_class_mask = event_particles[:, -1] == point_class
                                    point_class_positives = (d_true[point_class_mask, :] < self.resolution).any(dim=0)
                                    point_class_index = d[point_class_mask, :][:, point_class_positives]

                                    if point_class_index.nelement():
                                        point_class_index = torch.argmin(point_class_index, dim=0)

                                        true = event_particles[point_class_mask][point_class_index, -1]
                                        pred = result['classify_endpoints'][igpu][batch_index_layer][point_class_positives]
                                        # NOTE(review): point_class_index is
                                        # relative to the point_class_mask
                                        # subset but indexes the full event
                                        # label list here -- confirm intent.
                                        tracks = event_types_label[point_class_index] == self._track_label
                                        if tracks.sum().item():
                                            loss_point_class += torch.mean(self.segloss(pred[tracks], true[tracks].long()))
                                            acc_point_class += (torch.argmax(pred[tracks], dim=-1) == true[tracks]).sum().item() / float(true[tracks].nelement())
                                            point_class_count += 1

                                if point_class_count:
                                    loss_classify_endpoints = loss_point_class / point_class_count
                                    acc_classify_endpoints = acc_point_class / point_class_count
                            res['classify_endpoints_loss'] += float(loss_classify_endpoints) / num_batches
                            res['classify_endpoints_accuracy'] += float(acc_classify_endpoints) / num_batches
                        # --- end of Endpoint classification

                        # Distance Loss
                        d2, _ = torch.min(distance_positives, dim=0)
                        reg_loss = d2.mean()
                        res['reg_loss'] += float(reg_loss) / num_batches if num_batches else float(reg_loss)
                        res['type_loss'] += float(type_loss) / num_batches if num_batches else float(type_loss)
                        res['mask_loss'] += float(mask_loss_final) / num_batches if num_batches else float(mask_loss_final)
                        total_loss += (reg_loss + type_loss + mask_loss_final) / num_batches if num_batches else reg_loss + type_loss + mask_loss_final
                        if self._classify_endpoints:
                            total_loss += loss_classify_endpoints / num_batches if num_batches else loss_classify_endpoints

                loss_layer /= max(1, num_batches)
                loss_gpu += loss_layer

            loss_gpu /= len(ppn_layers)
            total_loss += loss_gpu

        total_acc = total_acc / num_batches if num_batches else 1.
        res['loss'] = total_loss
        res['accuracy'] = float(total_acc)
        return res