Source code for mlreco.models.layers.cluster_cnn.factories

[docs]def backbone_dict(): """ returns dictionary of clustering models """ from mlreco.models.layers.common import uresnet, fpn models = { # Encoder-Decoder Type Backbone Architecture. "uresnet": uresnet.UResNet, "fpn": fpn.FPN } return models
[docs]def gs_kernel_dict(): ''' Returns dictionary of kernel function models. ''' from . import gs_kernels kernels = { 'default': gs_kernels.DefaultKernel, 'bilinear': gs_kernels.BilinearKernel, 'bilinear_mlp': gs_kernels.BilinearNNKernel, 'mixed': gs_kernels.MixedKernel } return kernels
[docs]def gs_kernel_construct(cfg): models = gs_kernel_dict() name = cfg.get('name', 'default') num_features = cfg.get('num_features', 1) args = cfg.get('args', {}) if not name in models: raise Exception("Unknown kernel function name provided") return models[name](num_features=num_features, **args)
[docs]def cluster_model_dict(): ''' Returns dictionary of implemented clustering layers. ''' # from mlreco.models.scn.cluster_cnn import spatial_embeddings # from mlreco.models.scn.cluster_cnn import graph_spice from mlreco.models.layers.cluster_cnn.embeddings import SPICE as MinkSPICE models = { # "spice_cnn": spatial_embeddings.SpatialEmbeddings, "spice_cnn_me": MinkSPICE, # "graph_spice_embedder": graph_spice.GraphSPICEEmbedder, # "graph_spice_geo_embedder": graph_spice.GraphSPICEGeoEmbedder # "graphgnn_spice": graphgnn_spice.SparseOccuSegGNN } return models
[docs]def spice_loss_dict(): ''' Returns dictionary of various clustering losses with enhancements. ''' from . import losses # from .graphgnn_spice import SparseOccuSegGNNLoss loss = { # Hyperspace Clustering Losses 'single': losses.single_layers.DiscriminativeLoss, #'multi': losses.multi_layers.MultiScaleLoss, #'multi-weighted': losses.multi_layers.DistanceEstimationLoss3, #'multi-repel': losses.multi_layers.DistanceEstimationLoss2, #'multi-distance': losses.multi_layers.DistanceEstimationLoss, # SPICE Losses 'se_bce': losses.spatial_embeddings.MaskBCELoss2, 'se_bce_ellipse': losses.spatial_embeddings.MaskBCELossBivariate, 'se_lovasz': losses.spatial_embeddings.MaskLovaszHingeLoss, 'se_lovasz_inter': losses.spatial_embeddings.MaskLovaszInterLoss, 'se_focal': losses.spatial_embeddings.MaskFocalLoss, 'se_multivariate': losses.spatial_embeddings.MultiVariateLovasz, 'se_ce_lovasz': losses.spatial_embeddings.CELovaszLoss, 'se_lovasz_inter_2': losses.spatial_embeddings.MaskLovaszInterLoss2, 'se_lovasz_inter_bc': losses.spatial_embeddings.MaskLovaszInterBC, # SPICE Losses Vectorized 'se_vectorized': losses.spatial_embeddings_fast.SPICELoss, 'se_vectorized_inter': losses.spatial_embeddings_fast.SPICEInterLoss, 'graph_spice_edge_loss': losses.gs_embeddings.NodeEdgeHybridLoss, 'graph_spice_loss': losses.gs_embeddings.GraphSPICEEmbeddingLoss # 'graphgnn_spice_loss': SparseOccuSegGNNLoss } return loss
[docs]def backbone_construct(name): models = backbone_dict() if not name in models: raise Exception("Unknown backbone architecture name provided") return models[name]
[docs]def cluster_model_construct(cfg, name): models = cluster_model_dict() if not name in models: raise Exception("Unknown clustering model name provided") return models[name](cfg)
[docs]def spice_loss_construct(name): loss_fns = spice_loss_dict() if not name in loss_fns: raise Exception("Unknown clustering loss function name provided") return loss_fns[name]