Source code for mlreco.post_processing.store.store_uresnet_ppn

from mlreco.utils import CSVData
import numpy as np
import scipy
import os
from mlreco.utils.ppn import uresnet_ppn_type_point_selector


def store_uresnet_ppn(cfg, data_blob, res, logdir, iteration, nms_score_threshold=0.8, window_size=3, score_threshold=0.9, type_threshold=2, **kwargs):
    """
    Dump UResNet + PPN outputs to CSV, one row per stored point/voxel.

    Every row has the columns ``idx, x, y, z, type, value`` where ``type``
    is an integer code identifying what was stored:

    - 3: raw PPN predictions
    - 4: segmentation prediction (argmax) per input voxel
    - 5: PPN predictions after NMS
    - 6: PPN predictions after score thresholding
    - 7: PPN predictions after masking
    - 8 / 9: PPN1 / PPN2 entries whose softmax score exceeds 0.5
    - 10: masking + score threshold
    - 11: masking + score threshold + NMS
    - 12: masking + score threshold + PPN points of type X within
      ``type_threshold`` of a voxel predicted as type X
    - 13: same as 12, followed by NMS
    - 14: ghost prediction (argmax) per input voxel, and also the points
      returned by ``uresnet_ppn_type_point_selector``

    Configuration
    -------------
    Read from ``cfg['post_processing']['store_uresnet_ppn']`` (may be None,
    in which case all defaults apply):

    input_data: str, optional
        Key of the input data in ``data_blob`` (default ``'input_data'``).
    coords_col: tuple, optional
        ``(start, stop)`` column slice holding coordinates (default (1, 4)).
    store_method: str, optional
        ``'per-iteration'`` (default) or ``'per-event'``.
    ppn_score_threshold, ppn_type_threshold: optional
        Forwarded to ``uresnet_ppn_type_point_selector`` (defaults 0.2, 2).

    Parameters
    ----------
    cfg: dict
    data_blob: dict
        Must contain ``'index'`` and the configured input data key.
    res: dict
        Network outputs. Nothing is stored unless it contains ``'points'``.
    logdir: str
        Directory in which the CSV files are created.
    iteration: int
        Used in the file name when storing per iteration.
    nms_score_threshold, window_size:
        NMS parameters, passed to ``nms_numpy`` as ``threshold`` and ``size``.
    score_threshold:
        To filter based on score only (no NMS).
    type_threshold:
        Maximum distance used for the type 12/13 proximity filter.

    Returns
    -------
    None
    """
    method_cfg = cfg['post_processing']['store_uresnet_ppn']
    # A missing (None) configuration means "use every default". Normalizing
    # it here avoids dereferencing None when reading the options below.
    if method_cfg is None:
        method_cfg = {}

    input_key = method_cfg.get('input_data', 'input_data')
    coords_col = method_cfg.get('coords_col', (1, 4))

    if input_key not in data_blob:
        return
    if 'points' not in res:
        return

    index = data_blob['index']
    input_dat = data_blob[input_key]
    output_pts = res.get('points', None)
    output_seg = res.get('segmentation', None)
    # ---
    # TODO change ppn1 and ppn2 to use new output of ME version of PPN
    # ---
    output_ppn1 = res.get('ppn1', None)
    output_ppn2 = res.get('ppn2', None)
    output_mask = res.get('mask_ppn', None)
    output_ghost = res.get('ghost', None)

    ppn_score_threshold = method_cfg.get('ppn_score_threshold', 0.2)
    ppn_type_threshold = method_cfg.get('ppn_type_threshold', 2)

    store_per_iteration = True
    store_method = method_cfg.get('store_method', None)
    if store_method is not None:
        assert(store_method in ['per-iteration', 'per-event'])
        store_per_iteration = store_method == 'per-iteration'

    fout = None
    if store_per_iteration:
        fout = CSVData(os.path.join(logdir, 'uresnet-ppn-iter-%07d.csv' % iteration))

    for data_idx, tree_idx in enumerate(index):

        if not store_per_iteration:
            fout = CSVData(os.path.join(logdir, 'uresnet-ppn-event-%07d.csv' % tree_idx))

        voxels = input_dat[data_idx]

        if output_pts is not None:
            points = output_pts[data_idx]
            scores = scipy.special.softmax(points[:, 3:5], axis=1)

            # type 3 = raw PPN predictions
            _record_ppn_points(fout, tree_idx, voxels, points, scores, 3)

            # type 5 = PPN predictions after NMS
            keep = nms_numpy(points[:, :3], scores[:, 1], nms_score_threshold, window_size)
            _record_ppn_points(fout, tree_idx, voxels[keep], points[keep], scores[keep], 5)

            # type 6 = PPN predictions after score thresholding
            keep = scores[:, 1] > score_threshold
            _record_ppn_points(fout, tree_idx, voxels[keep], points[keep], scores[keep], 6)

            # The mask is optional in `res`; skip mask-based outputs without it.
            if output_mask is not None:
                # type 7 = PPN predictions after masking
                mask = (~(output_mask[data_idx] == 0)).any(axis=1)
                _record_ppn_points(fout, tree_idx, voxels[mask], points[mask], scores[mask], 7)

                # type 10 = masking + score threshold
                mask = mask & (scores[:, 1] > score_threshold)
                _record_ppn_points(fout, tree_idx, voxels[mask], points[mask], scores[mask], 10)

                # type 11 = masking + score threshold + NMS
                keep = nms_numpy(points[mask][:, :3], scores[mask][:, 1], nms_score_threshold, window_size)
                _record_ppn_points(fout, tree_idx, voxels[mask][keep], points[mask][keep], scores[mask][keep], 11)

            # Store PPN1 and PPN2 output (8 = PPN1, 9 = PPN2)
            for layer_output, type_code in ((output_ppn1, 8), (output_ppn2, 9)):
                if layer_output is None:
                    continue
                layer = layer_output[data_idx]
                layer_scores = scipy.special.softmax(layer[:, -2:], axis=1)
                keep_layer = layer_scores[:, 1] > 0.5
                layer_coords = layer[keep_layer][:, coords_col[0]:coords_col[1]]
                kept_scores = layer_scores[keep_layer]
                for i, coords in enumerate(layer_coords):
                    _write_csv_row(fout, tree_idx,
                                   coords[0] + 0.5, coords[1] + 0.5, coords[2] + 0.5,
                                   type_code, kept_scores[i, 1])

        # type 4 = segmentation prediction per input voxel
        if output_seg is not None:
            predictions = np.argmax(output_seg[data_idx], axis=1)
            for row_idx, prediction in enumerate(predictions):
                event = voxels[row_idx]
                _write_csv_row(fout, tree_idx, event[0], event[1], event[2], 4, prediction)

        # type 14 = ghost prediction per input voxel
        if output_ghost is not None:
            predictions = np.argmax(output_ghost[data_idx], axis=1)
            for row_idx, prediction in enumerate(predictions):
                event = voxels[row_idx]
                _write_csv_row(fout, tree_idx, event[0], event[1], event[2], 14, prediction)

        if (output_seg is not None and output_pts is not None
                and output_mask is not None and len(output_pts[0][0]) > 5):
            # 12 = masking + score threshold + filter PPN points of type X within N pixels of type X
            # 13 = masking + score threshold + filter PPN points of type X within N pixels of type X + NMS
            points = output_pts[data_idx]
            mask = ((~(output_mask[data_idx] == 0)).any(axis=1)) & (scores[:, 1] > score_threshold)
            sel_voxels = voxels[mask]
            sel_points = points[mask]
            sel_scores = scores[mask]
            uresnet_predictions = np.argmax(output_seg[data_idx][mask], axis=1)
            num_classes = output_seg[data_idx].shape[1]
            ppn_type_predictions = np.argmax(scipy.special.softmax(sel_points[:, 5:], axis=1), axis=1)
            for c in range(num_classes):
                uresnet_points = uresnet_predictions == c
                ppn_points = ppn_type_predictions == c
                # Nothing can be stored for this class unless both a PPN
                # point and a voxel of that class survived the mask.
                if not (ppn_points.any() and uresnet_points.any()):
                    continue
                class_voxels = sel_voxels[ppn_points]
                class_points = sel_points[ppn_points]
                class_scores = sel_scores[ppn_points]
                d = scipy.spatial.distance.cdist(
                    class_points[:, :3] + class_voxels[:, coords_col[0]:coords_col[1]] + 0.5,
                    sel_voxels[uresnet_points][:, coords_col[0]:coords_col[1]])
                ppn_mask = (d < type_threshold).any(axis=1)
                near_voxels = class_voxels[ppn_mask]
                near_points = class_points[ppn_mask]
                near_scores = class_scores[ppn_mask]
                _record_ppn_points(fout, tree_idx, near_voxels, near_points, near_scores, 12)

                keep = nms_numpy(near_points[:, :3], near_scores[:, 1], nms_score_threshold, window_size)
                _record_ppn_points(fout, tree_idx, near_voxels[keep], near_points[keep], near_scores[keep], 13)

        # 14 = points returned by the UResNet+PPN point selector
        # NOTE(review): type code 14 is also used for the ghost predictions
        # above; kept as-is because consumers of the CSV may rely on it.
        pts = uresnet_ppn_type_point_selector(voxels, res, entry=data_idx,
                                              score_threshold=ppn_score_threshold,
                                              type_threshold=ppn_type_threshold)
        for row in pts:
            _write_csv_row(fout, tree_idx, row[1], row[2], row[3], 14, row[-1])

        if not store_per_iteration:
            fout.close()

    if store_per_iteration:
        fout.close()


def _write_csv_row(fout, tree_idx, x, y, z, type_code, value):
    """Record and flush one (idx, x, y, z, type, value) row to `fout`."""
    fout.record(('idx', 'x', 'y', 'z', 'type', 'value'),
                (tree_idx, x, y, z, type_code, value))
    fout.write()


def _record_ppn_points(fout, tree_idx, voxels, points, scores, type_code):
    """Write one row per PPN point, using row-aligned `voxels`, `points` and `scores`.

    The stored position is the voxel's first three columns plus 0.5 plus the
    point's first three columns. The stored value is the argmax over the
    point's columns from 5 onward when present, otherwise the score in
    column 1 of `scores`.
    """
    # NOTE(review): positions use voxel columns 0..2 whereas other parts of
    # this module slice coordinates with `coords_col` -- confirm which
    # column layout the input data actually has.
    for row_idx, row in enumerate(points):
        event = voxels[row_idx]
        if len(row) > 5:  # Includes prediction of point type
            value = np.argmax(scipy.special.softmax(row[5:]))
        else:
            value = scores[row_idx, 1]
        _write_csv_row(fout, tree_idx,
                       event[0] + 0.5 + row[0],
                       event[1] + 0.5 + row[1],
                       event[2] + 0.5 + row[2],
                       type_code, value)
def nms_numpy(im_proposals, im_scores, threshold, size):
    """
    Greedy non-maximum suppression over point proposals.

    Each proposal is treated as an axis-aligned box spanning
    ``[p - size, p + size]`` along every dimension. Proposals are visited
    by decreasing score; each visited proposal is kept and every remaining
    proposal whose intersection-over-union with it exceeds ``threshold``
    is dropped.

    Parameters
    ----------
    im_proposals: np.ndarray
        Proposal coordinates, one row per proposal.
    im_scores: np.ndarray
        One score per proposal.
    threshold: float
        Maximum overlap allowed between a kept proposal and later ones.
    size: float
        Half-width of the box built around each proposal.

    Returns
    -------
    list
        Indices of the kept proposals, ordered by decreasing score.
    """
    dim = im_proposals.shape[-1]
    # Box corners, each of shape (dim, n_proposals).
    lower = np.array([im_proposals[:, d] - size for d in range(dim)])
    upper = np.array([im_proposals[:, d] + size for d in range(dim)])
    areas = np.prod(upper - lower + 1, axis=0)

    remaining = im_scores.argsort()[::-1]
    keep = []
    while remaining.size > 0:
        best = remaining[0]
        others = remaining[1:]
        keep.append(best)

        # Per-dimension overlap between the best box and every other box.
        overlap_lo = np.maximum(lower[:, best][:, np.newaxis], lower[:, others])
        overlap_hi = np.minimum(upper[:, best][:, np.newaxis], upper[:, others])
        extent = np.maximum(0.0, overlap_hi - overlap_lo + 1)
        inter = np.prod(extent, axis=0)
        ovr = inter / (areas[best] + areas[others] - inter)

        # Only proposals that do not overlap too much stay in the queue.
        remaining = others[ovr <= threshold]
    return keep