The class neighborhood of a dataset can be learned using the soft nearest neighbor loss. In this article, we discuss how to implement the soft nearest neighbor loss, which we also talked about here.
Representation learning is the task of learning the most salient features in a given dataset by a deep neural network. It is usually an implicit task done in a supervised learning paradigm, and it is a crucial factor in the success of deep learning (Krizhevsky et al., 2012; He et al., 2016; Simonyan et al., 2014). In other words, representation learning automates the process of feature extraction. With this, we can use the learned representations for downstream tasks such as classification, regression, and synthesis.
We can also influence how the learned representations are formed to cater to specific use cases. In the case of classification, the representations are primed to have data points from the same class flock together, while for generation (e.g. in GANs), the representations are primed to have points of real data flock with the synthesized ones.
In the same sense, we have enjoyed the use of principal components analysis (PCA) to encode features for downstream tasks. However, PCA-encoded representations carry no class or label information, so the performance on downstream tasks may be further improved. We can improve the encoded representations by approximating the class or label information in them, which we do by learning the neighborhood structure of the dataset, i.e. which features are clustered together. Such clusters would imply that the features belong to the same class, as per the cluster assumption in the semi-supervised learning literature (Chapelle et al., 2009).
To integrate the neighborhood structure in the representations, manifold learning techniques have been introduced, such as locally linear embeddings or LLE (Roweis & Saul, 2000), neighborhood components analysis or NCA (Hinton et al., 2004), and t-distributed stochastic neighbor embedding or t-SNE (Maaten & Hinton, 2008).
However, the aforementioned manifold learning techniques have their own drawbacks. For instance, both LLE and NCA encode linear embeddings instead of nonlinear embeddings. Meanwhile, t-SNE embeddings result in different structures depending on the hyperparameters used.
To avoid such drawbacks, we can use an improved NCA algorithm which is the soft nearest neighbor loss or SNNL (Salakhutdinov & Hinton, 2007; Frosst et al., 2019). The SNNL improves the NCA algorithm by introducing nonlinearity, and it is computed for each hidden layer of a neural network instead of solely on the last encoding layer. This loss function is used to optimize the entanglement of points in a dataset.
In this context, entanglement is defined as how close class-similar data points are to each other compared to class-different data points. A low entanglement means that class-similar data points are much closer to each other than class-different data points (see Figure 1). Having such a set of data points renders downstream tasks much easier to accomplish, with even better performance. Frosst et al. (2019) expanded the SNNL objective by introducing a temperature factor T, giving us the following as the final loss function,
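Written out, with b as the batch size and y as the class labels (symbols borrowed from Frosst et al., 2019, as an assumption about the notation intended here), the loss reads:

```latex
l_{sn}(x, y, T) = -\frac{1}{b} \sum_{i \in 1..b} \log \left(
  \frac{\sum_{j \in 1..b,\; j \neq i,\; y_i = y_j} \exp\left(-\frac{d(x_i, x_j)}{T}\right)}
       {\sum_{k \in 1..b,\; k \neq i} \exp\left(-\frac{d(x_i, x_k)}{T}\right)}
\right)
```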
where d is a distance metric on either raw input features or hidden layer representations of a neural network, and T is the temperature factor that is directly proportional to the distances among data points in a hidden layer. For this implementation, we use the cosine distance as our distance metric for more stable computations.
The purpose of this article is to help readers understand and implement the soft nearest neighbor loss, and so we shall dissect the loss function in order to understand it better.
Distance Metric
The first thing we should compute is the distances among data points, which are either the raw input features or hidden layer representations of the network.
For our implementation, we use the cosine distance metric (Figure 3) for more stable computations. For the time being, let us ignore the subsets denoted ij and ik in the figure above, and focus on computing the cosine distance among our input data points. We accomplish this through the following PyTorch code:
```python
import torch

# `features` is a (batch_size, num_features) tensor.
normalized_a = torch.nn.functional.normalize(features, dim=1, p=2)
normalized_b = torch.nn.functional.normalize(features, dim=1, p=2)
# Conjugate transpose, so that complex-valued features are handled too.
normalized_b = torch.conj(normalized_b).T
product = torch.matmul(normalized_a, normalized_b)  # cosine similarity
distance_matrix = torch.sub(torch.tensor(1.0), product)  # cosine distance
```
In the code snippet above, we first normalize the input features using the Euclidean (L2) norm. We then take the conjugate transpose of the second copy of the normalized features, which accounts for complex vectors. Finally, the matrix product of the two gives the pairwise cosine similarity of the input features, and subtracting it from 1 gives their cosine distance.
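The same steps can be wrapped in a small helper and cross-checked against PyTorch's built-in `torch.nn.functional.cosine_similarity`. This is only a sketch: the function name `pairwise_cosine_distance` and the random test features are assumptions made for illustration, not part of the original implementation.

```python
import torch
import torch.nn.functional as F


def pairwise_cosine_distance(features: torch.Tensor) -> torch.Tensor:
    """Return the (n, n) matrix of cosine distances between rows of `features`."""
    normalized = F.normalize(features, p=2, dim=1)
    # conj() leaves real tensors unchanged; it only matters for complex inputs.
    similarity = normalized @ normalized.conj().T
    return 1.0 - similarity


# Hypothetical input: 4 random points with 8 features each.
torch.manual_seed(0)
features = torch.randn(4, 8)
distance_matrix = pairwise_cosine_distance(features)

# Reference values from the built-in, broadcasting (4, 1, 8) against (1, 4, 8).
reference = 1.0 - F.cosine_similarity(
    features.unsqueeze(1), features.unsqueeze(0), dim=2
)
print(torch.allclose(distance_matrix, reference, atol=1e-5))
```

Since both copies of the normalized features are identical, normalizing once and reusing the result gives the same matrix with one fewer pass over the data.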
Concretely, consider the following set of features,
```
tensor([[ 1.0999, -0.9438,  0.7996, -0.4247],
        [ 1.2150, -0.2953,  0.0417, -1.2913],
        [ 1.3218,  0.4214, -0.1541,  0.0961],
        [-0.7253,  1.1685, -0.1070,  1.3683]])
```
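Using the distance metric we defined above, we obtain the following distance matrix (the tiny negative values on the diagonal are floating-point rounding error and are effectively zero),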
```
tensor([[ 0.0000e+00,  2.8502e-01,  6.2687e-01,  1.7732e+00],
        [ 2.8502e-01,  0.0000e+00,  4.6293e-01,  1.8581e+00],
        [ 6.2687e-01,  4.6293e-01, -1.1921e-07,  1.1171e+00],
        [ 1.7732e+00,  1.8581e+00,  1.1171e+00, -1.1921e-07]])
```
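As a sanity check, the entry for the first two rows can be reproduced by hand. The sketch below uses plain Python rather than PyTorch, and the helper name `cosine_distance` is hypothetical. Because the features above are printed to only four decimal places, the result should agree with the matrix to roughly four significant digits rather than exactly.

```python
import math

# First two rows of the example feature tensor, as printed above.
a = [1.0999, -0.9438, 0.7996, -0.4247]
b = [1.2150, -0.2953, 0.0417, -1.2913]


def cosine_distance(u, v):
    """Cosine distance between two real vectors given as lists of floats."""
    dot = sum(x * y for x, y in zip(u, v))
    norm_u = math.sqrt(sum(x * x for x in u))
    norm_v = math.sqrt(sum(y * y for y in v))
    return 1.0 - dot / (norm_u * norm_v)


d_01 = cosine_distance(a, b)
print(round(d_01, 3))  # about 0.285, matching the 2.8502e-01 entry
```

The distance of a point to itself is zero up to rounding, which is why the diagonal of the matrix is zero or a value on the order of 1e-07.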
Sampling Probability
We can now compute the matrix that represents the probability of picking each point given its pairwise distances to all other points. That is, for each point i, it gives the probability of picking another point (j or k in the loss function) based on its distance to i.
We can compute this through the following code:
“`html
pairwise_distance_matrix = torch.exp(-(distance_matrix / temperature)) - torch.eye(features.shape[0]).to(model.device)