import torch
import torch.nn as nn
from mlreco.models.layers.cluster_cnn.embeddings import SPICE
from mlreco.models.layers.cluster_cnn import spice_loss_construct
class MinkSPICE(SPICE):
    '''
    SPICE model whose construction is delegated entirely to the ``SPICE``
    base class; this subclass only overrides ``MODULES``.
    '''
    # Presumably the names of the configuration sections / sub-networks the
    # SPICE base consumes -- TODO confirm against SPICE.
    MODULES = ['network_base', 'uresnet_encoder', 'embedding_decoder', 'seediness_decoder']

    def __init__(self, cfg):
        '''
        Parameters
        ----------
        cfg : dict
            Model configuration, forwarded unchanged to ``SPICE.__init__``.
        '''
        super().__init__(cfg)
class SPICELoss(nn.Module):
    '''
    Loss function for Proposal-Free Mask Generators.

    Wraps the loss selected by ``cfg[name]['name']`` (default
    ``'se_lovasz_inter'``), built through ``spice_loss_construct``, and feeds
    it label tensors from which rows of the classes in ``skip_classes`` have
    been removed.
    '''

    def __init__(self, cfg, name='spice_loss'):
        '''
        Parameters
        ----------
        cfg : dict
            Full configuration. ``cfg['spice']['skip_classes']`` (default
            ``[2, 3, 4]``) lists the class values dropped before computing
            the loss; ``cfg[name]['name']`` selects the loss implementation,
            which is then instantiated with the full ``cfg``.
        name : str, optional
            Key of the loss configuration section in ``cfg``.
        '''
        super().__init__()
        self.model_config = cfg.get('spice', {})
        self.skip_classes = self.model_config.get('skip_classes', [2, 3, 4])
        self.loss_config = cfg.get(name, {})
        self.loss_func_name = self.loss_config.get('name', 'se_lovasz_inter')
        # spice_loss_construct returns a callable (presumably a loss class)
        # that is instantiated with the full configuration.
        self.loss_func = spice_loss_construct(self.loss_func_name)(cfg)

    def class_mask(self, cluster_label):
        '''
        Filter classes according to segmentation label.

        Parameters
        ----------
        cluster_label : torch.Tensor
            2D label tensor whose last column holds the class value that is
            compared against ``self.skip_classes``.

        Returns
        -------
        torch.Tensor
            Boolean mask of length ``len(cluster_label)`` on the same device,
            True for rows whose last-column value is not in
            ``self.skip_classes``.
        '''
        mask = torch.ones(len(cluster_label), dtype=torch.bool,
                          device=cluster_label.device)
        classes = cluster_label[:, -1]
        for c in self.skip_classes:
            mask &= classes != c
        return mask

    def forward(self, result, cluster_label):
        '''
        Compute the configured SPICE loss.

        Parameters
        ----------
        result
            Model output, passed through unchanged to ``self.loss_func``.
        cluster_label : sequence of torch.Tensor
            Only the first entry is used. Columns 0-3 are kept in both label
            views (presumably batch id + coordinates -- TODO confirm),
            column 5 is used as the group label, and the last column as the
            segmentation label.

        Returns
        -------
        Whatever ``self.loss_func(result, segment_label, group_label)``
        returns.
        '''
        labels = cluster_label[0]
        # Apply the (copying) boolean mask once; both views slice from it.
        kept = labels[self.class_mask(labels)]
        segment_label = [kept[:, [0, 1, 2, 3, -1]]]
        group_label = [kept[:, [0, 1, 2, 3, 5]]]
        return self.loss_func(result, segment_label, group_label)