import torch
import torch.nn as nn
import numpy as np
from collections import defaultdict
# KERNEL FUNCTIONS FOR CLUSTERING
def gauss(centroid, sigma):
    '''
    Build an isotropic gaussian kernel function.

    INPUTS:
        - centroid: (D, ) Tensor, coordinates of the gaussian centroid.
        - sigma: gaussian bandwidth.
    RETURNS:
        - f (function): maps an (N, D) tensor x to the (N, ) tensor
          exp(-||x_i - centroid||^2 / (2 * sigma^2)).
    '''
    def _kernel(x):
        # Squared euclidean distance of every row of x to the centroid.
        sq_dist = (x - centroid).pow(2).sum(dim=1)
        # The bandwidth term stays inside the closure so that a tensor-valued
        # sigma is read fresh on every call.
        return torch.exp(-sq_dist / (2.0 * sigma ** 2))
    return _kernel
def mvgauss(centroid, L, dim=3):
    '''
    Constructor for multivariate gaussian-style kernels.

    INPUTS:
        - centroid: (D, ) Tensor for the coordinates of the kernel centroid.
        - L: 1-D Tensor with dim * (dim + 1) / 2 entries. These are the
          lower-triangular entries of a factor matrix M, in the row-major
          order returned by torch.tril_indices.
        - dim: D, the number of coordinates (size of M).
    RETURNS:
        - f (function): maps an (N, D) tensor x to the (N, ) tensor
          exp(-(x_i - centroid)^T P (x_i - centroid)), where P = M M^T.

    NOTE: P = M M^T enters the exponent directly, so it plays the role of a
    precision (inverse-covariance) matrix, not a covariance matrix. There is
    also no factor of 1/2 in the exponent.
    '''
    def f(x):
        # Build the factor inside f rather than once at construction. Later
        # in-place updates to L are then reflected, and every call gets a
        # fresh autograd graph.
        rows, cols = torch.tril_indices(row=dim, col=dim, offset=0,
                                        device=L.device)
        factor = torch.zeros(dim, dim, dtype=L.dtype, device=L.device)
        factor[rows, cols] = L
        precision = torch.matmul(factor, factor.T)
        diff = x - centroid
        # Row-wise quadratic form diff_i^T P diff_i. Reducing over dim=1
        # always yields shape (N, ), including for N == 1 and N == 0.
        dist = torch.sum(torch.matmul(diff, precision) * diff, dim=1)
        return torch.exp(-dist)
    return f
def laplace(centroid, sigma):
    '''
    Constructor for a Laplacian kernel function.

    INPUTS:
        - centroid: (D, ) Tensor for the coordinates of the kernel centroid.
        - sigma: value for the kernel bandwidth.
    RETURNS:
        - f (function): maps an (N, D) tensor x to the (N, ) tensor
          exp(-||x_i - centroid||_1 / sigma).
    '''
    def f(x):
        # Per-row L1 distance: elementwise magnitude summed over coordinates.
        # NOTE(review): the original called torch.norm without `dim`, which
        # reduces the whole matrix to a scalar and crashed. L1 is the reading
        # consistent with its sum over dim=1; confirm that L2 was not intended.
        dists = torch.sum(torch.abs(x - centroid), dim=1)
        probs = torch.exp(-dists / sigma)
        return probs
    return f
def student_t(centroid):
    '''
    Build a pairwise Student-t kernel (one degree of freedom), as used in t-SNE.

    INPUTS:
        - centroid: (D, ) Tensor, coordinates of the kernel centroid.
    RETURNS:
        - f (function): maps an (N, D) tensor x to the (N, ) tensor
          1 / (1 + ||x_i - centroid||^2).
    '''
    def _kernel(x):
        # Squared euclidean distance of every row of x to the centroid.
        sq_dist = (x - centroid).pow(2).sum(dim=1)
        return 1 / (1 + sq_dist)
    return _kernel