Source code for mlreco.utils.numba

import numpy as np
import numba as nb
import torch
import inspect
from functools import wraps

def numba_wrapper(cast_args=(), list_args=(), keep_torch=False, ref_arg=None):
    '''
    Function which wraps a *numba* function with some checks on the input
    to make the relevant conversions to numpy where necessary.

    Args:
        cast_args ([str]): List of arguments to be cast to numpy
        list_args ([str]): List of arguments which need to be cast to a numba typed list
        keep_torch (bool): Make the output a torch object, if the reference argument is one
        ref_arg (str)    : Reference argument used to assign a type and device to the torch output

    Returns:
        Function: decorator which returns the wrapped function

    Raises:
        AssertionError: if `ref_arg`, a `cast_args` entry or a `list_args`
            entry does not name an argument of the wrapped function, or if
            an argument to cast is neither a numpy array nor a torch tensor
        TypeError: if a torch output is requested and the wrapped function
            returns something other than a numpy array or a list
    '''
    def outer(fn):
        # Resolve the signature once at decoration time rather than on every call
        params = inspect.signature(fn).parameters
        keys = list(params.keys())

        # Identity check on the `empty` sentinel: an equality test breaks when
        # a default value is a numpy array (element-wise comparison)
        defaults = {key: par.default for key, par in params.items()
                    if par.default is not inspect.Parameter.empty}

        @wraps(fn)
        def inner(*args, **kwargs):
            # Convert the positional arguments in args into key:value pairs in kwargs
            for i, val in enumerate(args):
                kwargs[keys[i]] = val

            # Fill in the default values of the parameters that were not provided
            for key, val in defaults.items():
                if key not in kwargs:
                    kwargs[key] = val

            # If a torch output is requested, register the input dtype and device location
            dtype, device = None, None
            if keep_torch:
                assert ref_arg in kwargs, \
                        'Reference argument not found in the function arguments'
                if isinstance(kwargs[ref_arg], torch.Tensor):
                    dtype = kwargs[ref_arg].dtype
                    device = kwargs[ref_arg].device

            # If the cast data is not a numpy array, cast it
            for arg in cast_args:
                assert arg in kwargs, \
                        'Argument to cast not found in the function arguments'
                if not isinstance(kwargs[arg], np.ndarray):
                    assert isinstance(kwargs[arg], torch.Tensor), \
                            'Can only cast torch tensors to numpy arrays'
                    kwargs[arg] = kwargs[arg].detach().cpu().numpy() # For now cast to CPU only

            # If there is a reflected list in the input, type it
            for arg in list_args:
                assert arg in kwargs, \
                        'List argument not found in the function arguments'
                kwargs[arg] = nb.typed.List(kwargs[arg])

            # Get the output
            ret = fn(**kwargs)

            # Cast the output back to torch only if the reference was a tensor
            if keep_torch and dtype is not None:
                if isinstance(ret, np.ndarray):
                    ret = torch.tensor(ret, dtype=dtype, device=device)
                elif isinstance(ret, list):
                    ret = [torch.tensor(r, dtype=dtype, device=device) for r in ret]
                else:
                    raise TypeError('Return type not recognized, cannot cast to torch')

            return ret

        return inner

    return outer
@nb.njit(cache=True)
def unique_nb(x: nb.int32[:]) -> (nb.int32[:], nb.int32[:]):
    """
    Numba implementation of `np.unique(x, return_counts=True)`.

    The input is flattened and sorted; the unique values and their
    number of occurrences are returned as two lists of equal length
    (both empty if the input is empty).
    """
    b = np.sort(x.flatten())
    # Seed the lists from the array itself so that they are typed even when empty
    unique = list(b[:1])
    counts = [1 for _ in unique]
    for val in b[1:]:
        if val != unique[-1]:
            unique.append(val)
            counts.append(1)
        else:
            counts[-1] += 1
    return unique, counts


@nb.njit(cache=True)
def submatrix_nb(x: nb.float32[:,:],
                 index1: nb.int32[:],
                 index2: nb.int32[:]) -> nb.float32[:,:]:
    """
    Numba implementation of matrix subsampling.

    Returns the matrix of elements `x[i1, i2]` for every `i1` in `index1`
    (rows) and every `i2` in `index2` (columns).
    """
    subx = np.empty((len(index1), len(index2)), dtype=x.dtype)
    for i, i1 in enumerate(index1):
        for j, i2 in enumerate(index2):
            subx[i,j] = x[i1,i2]
    return subx


@nb.njit(cache=True)
def cdist_nb(x1: nb.float32[:,:],
             x2: nb.float32[:,:]) -> nb.float32[:,:]:
    """
    Numba implementation of Euclidean cdist in 3D.

    Only the first three columns of `x1` and `x2` are used. The output
    has one row per row of `x1` and one column per row of `x2`.
    """
    res = np.empty((x1.shape[0], x2.shape[0]), dtype=x1.dtype)
    for i1 in range(x1.shape[0]):
        for i2 in range(x2.shape[0]):
            res[i1,i2] = np.sqrt((x1[i1][0]-x2[i2][0])**2
                                 + (x1[i1][1]-x2[i2][1])**2
                                 + (x1[i1][2]-x2[i2][2])**2)
    return res


@nb.njit(cache=True)
def mean_nb(x: nb.float32[:,:],
            axis: nb.int32) -> nb.float32[:]:
    """
    Numba implementation of `np.mean(x, axis)` for a 2D array.

    The output is stored with the dtype of `x`.
    """
    assert axis == 0 or axis == 1
    mean = np.empty(x.shape[1-axis], dtype=x.dtype)
    if axis == 0:
        for i in range(len(mean)):
            mean[i] = np.mean(x[:,i])
    else:
        for i in range(len(mean)):
            mean[i] = np.mean(x[i])
    return mean


@nb.njit(cache=True)
def argmin_nb(x: nb.float32[:,:],
              axis: nb.int32) -> nb.int32[:]:
    """
    Numba implementation of `np.argmin(x, axis)` for a 2D array.

    The indices are stored as 32-bit integers.
    """
    assert axis == 0 or axis == 1
    argmin = np.empty(x.shape[1-axis], dtype=np.int32)
    if axis == 0:
        for i in range(len(argmin)):
            argmin[i] = np.argmin(x[:,i])
    else:
        for i in range(len(argmin)):
            argmin[i] = np.argmin(x[i])
    return argmin


@nb.njit(cache=True)
def argmax_nb(x: nb.float32[:,:],
              axis: nb.int32) -> nb.int32[:]:
    """
    Numba implementation of `np.argmax(x, axis)` for a 2D array.

    The indices are stored as 32-bit integers.
    """
    assert axis == 0 or axis == 1
    argmax = np.empty(x.shape[1-axis], dtype=np.int32)
    if axis == 0:
        for i in range(len(argmax)):
            argmax[i] = np.argmax(x[:,i])
    else:
        for i in range(len(argmax)):
            argmax[i] = np.argmax(x[i])
    return argmax


@nb.njit(cache=True)
def min_nb(x: nb.float32[:,:],
           axis: nb.int32) -> nb.float32[:]:
    """
    Numba implementation of `np.min(x, axis)` for a 2D array.

    The output is stored with the dtype of `x`.
    """
    assert axis == 0 or axis == 1
    # Use the input dtype: an integer buffer would truncate float minima
    xmin = np.empty(x.shape[1-axis], dtype=x.dtype)
    if axis == 0:
        for i in range(len(xmin)):
            xmin[i] = np.min(x[:,i])
    else:
        for i in range(len(xmin)):
            xmin[i] = np.min(x[i])
    return xmin


@nb.njit(cache=True)
def max_nb(x: nb.float32[:,:],
           axis: nb.int32) -> nb.float32[:]:
    """
    Numba implementation of `np.max(x, axis)` for a 2D array.

    The output is stored with the dtype of `x`.
    """
    assert axis == 0 or axis == 1
    # Use the input dtype: an integer buffer would truncate float maxima
    xmax = np.empty(x.shape[1-axis], dtype=x.dtype)
    if axis == 0:
        for i in range(len(xmax)):
            xmax[i] = np.max(x[:,i])
    else:
        for i in range(len(xmax)):
            xmax[i] = np.max(x[i])
    return xmax


@nb.njit(cache=True)
def all_nb(x: nb.float32[:,:],
           axis: nb.int32) -> nb.boolean[:]:
    """
    Numba implementation of `np.all(x, axis)` for a 2D array.

    The output is a boolean array.
    """
    assert axis == 0 or axis == 1
    res = np.empty(x.shape[1-axis], dtype=np.bool_)
    if axis == 0:
        for i in range(len(res)):
            res[i] = np.all(x[:,i])
    else:
        for i in range(len(res)):
            res[i] = np.all(x[i])
    return res


@nb.njit(cache=True)
def softmax_nb(x: nb.float32[:,:],
               axis: nb.int32) -> nb.float32[:,:]:
    """
    Numba implementation of a softmax over one axis of a 2D array.

    The maximum along the axis is subtracted before exponentiating
    (log-sum-exp form) to limit overflow in the exponential.
    """
    assert axis == 0 or axis == 1
    if axis == 0:
        xmax = max_nb(x, axis=0)
        logsumexp = np.log(np.sum(np.exp(x-xmax), axis=0)) + xmax
        return np.exp(x - logsumexp)
    else:
        # Column vectors so that they broadcast against the rows of x
        xmax = max_nb(x, axis=1).reshape(-1,1)
        logsumexp = np.log(np.sum(np.exp(x-xmax), axis=1)).reshape(-1,1) + xmax
        return np.exp(x - logsumexp)


@nb.njit(cache=True)
def log_loss_nb(x1: nb.boolean[:],
                x2: nb.float32[:]) -> nb.float32:
    """
    Numba implementation of the mean binary log loss.

    `x1` is used as a boolean mask on `x2`: the loss sums `log(x2)` where
    `x1` is true and `log(1 - x2)` where it is false, negated and divided
    by the length of `x1`. Returns 0 if the input is empty.
    """
    if len(x1) > 0:
        return -(np.sum(np.log(x2[x1])) + np.sum(np.log(1.-x2[~x1])))/len(x1)
    else:
        return 0.