import numpy as np
import torch
import torch.nn as nn
import time
import MinkowskiEngine as ME
import MinkowskiFunctional as MF
from mlreco.models.layers.common.ppnplus import PPN, PPNLonelyLoss
from mlreco.models.uresnet import SegmentationLoss
from collections import defaultdict
from mlreco.models.uresnet import UResNet_Chain
class UResNetPPN(nn.Module):
    """
    A model made of UResNet backbone and PPN layers. Typical configuration:

    .. code-block:: yaml

        model:
          name: uresnet_ppn_chain
          modules:
            uresnet_lonely:
              # Your uresnet config here
            ppn:
              # Your ppn config here

    Configuration
    -------------
    data_dim: int, default 3
    num_input: int, default 1
    allow_bias: bool, default False
    spatial_size: int, default 512
    leakiness: float, default 0.33
    activation: dict
        For activation function, defaults to `{'name': 'lrelu', 'args': {}}`
    norm_layer: dict
        For normalization function, defaults to `{'name': 'batch_norm', 'args': {}}`
    depth: int, default 5
        Depth of UResNet, also corresponds to how many times we down/upsample.
    filters: int, default 16
        Number of filters in the first convolution of UResNet.
        Will increase linearly with depth.
    reps: int, default 2
        Convolution block repetition factor
    input_kernel: int, default 3
        Receptive field size for very first convolution after input layer.
    num_classes: int, default 5
    score_threshold: float, default 0.5
    classify_endpoints: bool, default False
        Enable classification of points into start vs end points.
    ppn_resolution: float, default 1.0
    ghost: bool, default False
    downsample_ghost: bool, default True
    use_true_ghost_mask: bool, default False
    mask_loss_name: str, default 'BCE'
        Can be 'BCE' or 'LogDice'
    particles_label_seg_col: int, default -2
        Which column corresponds to particles' semantic label
    track_label: int, default 1

    See Also
    --------
    mlreco.models.uresnet.UResNet_Chain, mlreco.models.layers.common.ppnplus.PPN
    """

    MODULES = ['mink_uresnet', 'mink_uresnet_ppn_chain', 'mink_ppn']

    def __init__(self, cfg):
        """
        Build the UResNet backbone, the PPN layers and the segmentation head.

        Parameters
        ----------
        cfg: dict
            Model configuration. The ``ghost`` flag is read from both the
            ``uresnet_lonely`` and ``ppn`` entries (defaulting to False when
            absent) and must be the same in both. The full ``cfg`` is handed
            to ``UResNet_Chain`` and ``PPN`` unchanged.

        Raises
        ------
        AssertionError
            If the ``ghost`` flag differs between the two entries.
        """
        super().__init__()
        self.model_config = cfg

        # The backbone and PPN must agree on whether ghost points are handled.
        self.ghost = cfg.get('uresnet_lonely', {}).get('ghost', False)
        assert self.ghost == cfg.get('ppn', {}).get('ghost', False), \
            "'ghost' must be set identically in the 'uresnet_lonely' " \
            "and 'ppn' configurations"

        self.backbone = UResNet_Chain(cfg)
        self.ppn = PPN(cfg)
        self.num_classes = self.backbone.num_classes
        self.num_filters = self.backbone.F

        # Linear head mapping backbone features to per-class scores.
        self.segmentation = ME.MinkowskiLinear(
            self.num_filters, self.num_classes)

    def forward(self, input):
        """
        Run the backbone, the PPN layers and the segmentation head.

        Parameters
        ----------
        input: sequence of length 1 or 2
            ``input[0]`` is the data passed to the backbone. The optional
            ``input[1]`` holds labels which are forwarded to PPN as
            ``ghost_labels`` only when ghost mode is enabled and
            ``self.ppn.use_true_ghost_mask`` is set.

        Returns
        -------
        defaultdict(list)
            Holds the backbone's ``'ghost'`` entry, a ``'segmentation'`` list
            with the features (``.F``) of the segmentation head output, and
            every key returned by PPN.

        Raises
        ------
        ValueError
            If ``input`` does not contain exactly 1 or 2 elements.
        """
        num_inputs = len(input)
        if num_inputs not in (1, 2):
            raise ValueError(
                f"UResNetPPN.forward expects 1 or 2 inputs, got {num_inputs}")

        input_tensors = [input[0]]
        # Labels are only supplied for PPN with true ghost mask propagation.
        labels = input[1] if num_inputs == 2 else None

        out = defaultdict(list)

        # input_tensors always holds a single entry, so igpu is 0 here and
        # indexes the single-element lists produced by self.backbone([x]).
        for igpu, x in enumerate(input_tensors):
            res = self.backbone([x])
            out.update({'ghost': res['ghost']})

            # Extra PPN arguments depend on the ghost configuration.
            ppn_kwargs = {}
            if self.ghost:
                ppn_kwargs['ghost'] = res['ghost_sptensor'][igpu]
                if self.ppn.use_true_ghost_mask:
                    ppn_kwargs['ghost_labels'] = labels

            res_ppn = self.ppn(res['finalTensor'][igpu],
                               res['decoderTensors'][igpu],
                               **ppn_kwargs)

            # Segmentation is computed from the last decoder tensor.
            segmentation = self.segmentation(res['decoderTensors'][igpu][-1])
            out['segmentation'].append(segmentation.F)
            out.update(res_ppn)

        return out
class UResNetPPNLoss(nn.Module):
    """
    Combined loss for the UResNet + PPN model: evaluates the segmentation
    loss and the PPN loss and merges their results.

    See Also
    --------
    mlreco.models.uresnet.SegmentationLoss, mlreco.models.layers.common.ppnplus.PPNLonelyLoss
    """

    def __init__(self, cfg):
        """
        Instantiate both loss modules from the same configuration.

        Parameters
        ----------
        cfg: dict
            Configuration handed unchanged to ``PPNLonelyLoss`` and
            ``SegmentationLoss``.
        """
        super().__init__()
        self.ppn_loss = PPNLonelyLoss(cfg)
        self.segmentation_loss = SegmentationLoss(cfg)

    def forward(self, outputs, segment_label, particles_label, weights=None):
        """
        Compute the segmentation and PPN losses and combine them.

        Parameters
        ----------
        outputs
            Model outputs, passed to both loss modules.
        segment_label
            Segmentation labels, passed to both loss modules.
        particles_label
            Particle labels, passed to the PPN loss only.
        weights: optional
            Forwarded to the segmentation loss as ``weights``.

        Returns
        -------
        dict
            ``'loss'`` is the sum of both losses and ``'accuracy'`` the mean
            of both accuracies. Every entry of the segmentation result is
            also included under a ``'segmentation_'`` prefix, and every entry
            of the PPN result under a ``'ppn_'`` prefix.
        """
        res_segmentation = self.segmentation_loss(
            outputs, segment_label, weights=weights)
        res_ppn = self.ppn_loss(
            outputs, segment_label, particles_label)

        total_loss = res_segmentation['loss'] + res_ppn['loss']
        mean_accuracy = (res_segmentation['accuracy'] + res_ppn['accuracy'])/2
        res = {'loss': total_loss, 'accuracy': mean_accuracy}

        # Expose each component's individual results under a prefix.
        for key, value in res_segmentation.items():
            res['segmentation_' + key] = value
        for key, value in res_ppn.items():
            res['ppn_' + key] = value
        return res