# Source code for mlreco.models.layers.common.dbscan

import torch
import numpy as np
import sklearn
from larcv import larcv
from mlreco.utils.track_clustering import track_clustering


class DBSCANFragmenter(torch.nn.Module):
    """
    DBSCAN Layer that uses sklearn's DBSCAN implementation to fragment each
    of the particle classes into dense instances.

    Runs DBSCAN on each requested class separately, in one of three ways:
    - Run pure DBSCAN on all the voxels in that class
    - Run DBSCAN on PPN point-masked voxels, associate leftovers based on proximity
    - Use a graph-based method to cluster tracks based on PPN vertices (track only)

    The two instance-breaking modes are delegated to
    `mlreco.utils.track_clustering.track_clustering`.

    Forward args:
        data (np.ndarray or torch.Tensor): (N, C) voxel tensor. The batch ID
            is read from column `batch_col`, the coordinates from columns
            `coords_col[0]:coords_col[1]` and the semantic type from the
            last column. With the default settings the layout is
            [batch_id, x, y, z, ..., sem_type].
        output (dict, optional): UResNet+PPN output, used to select the
            break points when `points` is not provided
        points (np.ndarray or torch.Tensor, optional): break points, whose
            last column holds their semantic type

    Returns:
        np.ndarray: one array of voxel indices (rows of `data`) per cluster,
            ordered by batch ID, then by class in `cluster_classes` order.
            The array has dtype=object when clusters differ in size and is
            a 2D int64 array when they all have the same size.
    """

    def __init__(self, cfg, name='dbscan_frag', batch_col=0, coords_col=(1, 4)):
        """
        Args:
            cfg (dict): configuration, this module's parameters are read from cfg[name]
            name (str): key of this module's parameter block in `cfg`
            batch_col (int): column of the batch ID in the input data
            coords_col (tuple): (start, stop) column range of the coordinates
        """
        super().__init__()
        model_cfg = cfg[name]
        self.batch_col = batch_col
        self.coords_col = coords_col

        # Global DBSCAN clustering parameters
        self.dim = model_cfg.get('dim', 3)
        self.eps = model_cfg.get('eps', 1.999)
        self.metric = model_cfg.get('metric', 'euclidean')
        self.min_samples = model_cfg.get('min_samples', 1)
        self.min_size = model_cfg.get('min_size', 3)
        self.num_classes = model_cfg.get('num_classes', 4)
        self.cluster_classes = model_cfg.get('cluster_classes', list(np.arange(self.num_classes)))

        # Instance breaking parameters
        self.break_classes = model_cfg.get('break_classes', [1])
        self.track_include_delta = model_cfg.get('track_include_delta', False)
        self.track_clustering_method = model_cfg.get('track_clustering_method', 'masked_dbscan')
        self.ppn_score_threshold = model_cfg.get('ppn_score_threshold', 0.5)
        self.ppn_type_threshold = model_cfg.get('ppn_type_threshold', 1.999)
        self.ppn_type_score_threshold = model_cfg.get('ppn_type_score_threshold', 0.5)
        self.ppn_mask_radius = model_cfg.get('ppn_mask_radius', 5)

        # Deprecated boolean flag, only kept for backward compatibility:
        # True * [1] -> [1] (break class 1), False * [1] -> [] (break nothing)
        if 'break_tracks' in model_cfg:
            assert 'break_classes' not in model_cfg, 'break_tracks is deprecated, only specify break_classes'
            self.break_classes = model_cfg['break_tracks']*[1]

        # Broadcast scalar parameters to one value per clustered class
        if not isinstance(self.cluster_classes, list):
            self.cluster_classes = [self.cluster_classes]
        if not isinstance(self.eps, list):
            self.eps = [self.eps for _ in self.cluster_classes]
        if not isinstance(self.min_samples, list):
            self.min_samples = [self.min_samples for _ in self.cluster_classes]
        if not isinstance(self.min_size, list):
            self.min_size = [self.min_size for _ in self.cluster_classes]
        if not isinstance(self.break_classes, list):
            self.break_classes = [self.break_classes]

        # Assert consistency between parameter sizes
        assert len(self.eps) == len(self.min_samples) == len(self.min_size) == len(self.cluster_classes)

    def _merges_delta(self, s):
        """Return whether delta voxels are clustered together with the voxels of class `s`."""
        return self.track_include_delta and s == larcv.kShapeTrack and s in self.break_classes

    def get_clusts(self, data, bids, segmentation, break_points=None):
        """
        Cluster the voxels of each requested class in each batch entry.

        Args:
            data (np.ndarray): (N, C) voxel array, segmentation column removed
            bids (np.ndarray): batch IDs to loop over
            segmentation (np.ndarray): (N) semantic type of each voxel
            break_points (np.ndarray, optional): points used to break up the
                classes listed in `break_classes` (required if one of them
                is present in the data)

        Returns:
            np.ndarray: clusters, as arrays of row indices into `data`
        """
        start, stop = self.coords_col[0], self.coords_col[1]
        clusts = []
        for bid in bids:
            # Batch mask
            batch_mask = data[:, self.batch_col] == bid

            # Coordinates of the break points in this batch entry, computed
            # once per entry rather than once per class
            batch_points = None
            if break_points is not None:
                points_mask = break_points[:, self.batch_col] == bid
                batch_points = break_points[points_mask, start:stop]

            for k, s in enumerate(self.cluster_classes):
                # Batch and segmentation mask (tracks may absorb delta voxels)
                merge_delta = self._merges_delta(s)
                class_mask = segmentation == s
                if merge_delta:
                    class_mask = class_mask | (segmentation == larcv.kShapeDelta)
                selection = np.where(batch_mask & class_mask)[0]
                if not len(selection):
                    continue

                # Restrict voxel set, run clustering
                voxels = data[selection, start:stop]
                if s in self.break_classes:
                    assert break_points is not None
                    breaking_method = self.track_clustering_method if s == larcv.kShapeTrack else 'masked_dbscan'
                    labels = track_clustering(voxels=voxels,
                                              points=batch_points,
                                              method=breaking_method,
                                              eps=self.eps[k],
                                              min_samples=self.min_samples[k],
                                              metric=self.metric,
                                              mask_radius=self.ppn_mask_radius)
                else:
                    # `import sklearn` alone does not guarantee that the
                    # `sklearn.cluster` submodule is loaded, import it explicitly
                    from sklearn.cluster import DBSCAN
                    labels = DBSCAN(eps=self.eps[k],
                                    min_samples=self.min_samples[k],
                                    metric=self.metric).fit(voxels).labels_

                # Delta voxels only guide the track clustering, they are not
                # assigned to any track cluster
                if merge_delta:
                    labels[segmentation[selection] == larcv.kShapeDelta] = -1

                # Build clusters for this class (skip noise and small clusters)
                for i in np.unique(labels):
                    if i < 0:
                        continue
                    clust = selection[labels == i]
                    if len(clust) >= self.min_size[k]:
                        clusts.append(clust)

        # Regular 2D array when all clusters have the same size, ragged object array otherwise
        same_length = np.all([len(c) == len(clusts[0]) for c in clusts])
        clusts = np.array(clusts, dtype=object if not same_length else np.int64)
        return clusts

    def forward(self, data, output=None, points=None):
        """
        Fragment the input voxels into dense instances.

        Args:
            data (np.ndarray or torch.Tensor): (N, C) voxel tensor, last column is the semantic type
            output (dict, optional): UResNet+PPN output, used to select break
                points when `points` is not provided
            points (np.ndarray or torch.Tensor, optional): break points, last
                column is their semantic type

        Returns:
            np.ndarray: clusters, as arrays of row indices into `data`
        """
        if isinstance(data, torch.Tensor):
            data = data.detach().cpu().numpy()
        if isinstance(points, torch.Tensor):
            points = points.detach().cpu().numpy()

        # If instances are to be broken up, either provide a set of points or get them from the PPN output
        break_points = None
        if len(self.break_classes):
            assert output is not None or points is not None
            if points is None:
                from mlreco.utils.ppn import uresnet_ppn_type_point_selector
                numpy_output = {'segmentation': [output['segmentation'][0].detach().cpu().numpy()],
                                'points'      : [output['points'][0].detach().cpu().numpy()],
                                'mask_ppn'    : [x.detach().cpu().numpy() for x in output['mask_ppn'][0]],
                                'ppn_coords'  : [x.detach().cpu().numpy() for x in output['ppn_coords'][0]]}
                points = uresnet_ppn_type_point_selector(data, numpy_output,
                                                         score_threshold=self.ppn_score_threshold,
                                                         type_threshold=self.ppn_type_threshold,
                                                         type_score_threshold=self.ppn_type_score_threshold)
                # NOTE(review): column 12 of the selector output is used as the
                # point semantic type -- confirm against uresnet_ppn_type_point_selector
                point_labels = points[:, 12]
            else:
                point_labels = points[:, -1]

            # Do not include delta points. Only the first dim+1 columns are kept;
            # assumes batch_col and coords_col lie within them -- TODO confirm
            break_points = points[point_labels != larcv.kShapeDelta, :self.dim+1]

        # Break down the input data to its components
        bids = np.unique(data[:, self.batch_col].astype(int))
        segmentation = data[:, -1]
        data = data[:, :-1]

        return self.get_clusts(data, bids, segmentation, break_points)