Probably the most common form of problem we tackle with machine learning is classification, that is, taking new data points and assigning each one to one of a fixed number of classes. But what if we don’t know all the classes when we train the model? A good example of this is face recognition, where we want a system that can store faces and then identify whether any new image it sees contains one of them. Obviously, we can’t retrain the model every time we add someone new to the database, so we need a better solution.
One way to solve this problem is metric learning. In metric learning, our goal is to learn a metric or distance measure between different data points. If we train our model correctly then this distance measure will put examples of the same class close together and different classes further apart.
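To make the face-recognition scenario concrete, here is a minimal sketch in plain Python. It assumes an encoder has already turned each face into a vector. The gallery, the names, the 2-D toy vectors and the `threshold` value are all hypothetical. Enrolling a new person only adds a stored vector, and identification is a nearest-neighbour search under the distance measure.

```python
import math


def euclidean(x, y):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(x, y)))


def identify(query, gallery, threshold):
    """Return the name whose stored embedding is closest to `query`,
    or None if nothing is within `threshold` (an unknown person)."""
    best_name, best_dist = None, float("inf")
    for name, emb in gallery.items():
        dist = euclidean(query, emb)
        if dist < best_dist:
            best_name, best_dist = name, dist
    return best_name if best_dist <= threshold else None


# Hypothetical toy 2-D "embeddings"; in practice these come from the trained encoder.
gallery = {"alice": [0.0, 1.0], "bob": [1.0, 0.0]}
print(identify([0.1, 0.9], gallery, threshold=0.5))   # alice

# Adding a new person just stores another vector: no retraining needed.
gallery["carol"] = [-1.0, 0.0]
print(identify([-0.9, 0.1], gallery, threshold=0.5))  # carol
print(identify([5.0, 5.0], gallery, threshold=0.5))   # None
```

Choosing the threshold is a separate tuning step, typically done on held-out pairs. A query whose nearest stored embedding is still too far away is treated as unknown rather than being forced into an existing class.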
As mentioned earlier, the most obvious use for metric learning is face recognition, but there is an extremely wide range of applications (with open-source datasets), including bird species, vehicles, and products. These are all computer-vision tasks, but there’s no reason you can’t use this technique with any other kind of data you have: just replace the encoder with something that makes sense for your data!
So what does a model that measures the distance between two data points look like? Well, the model itself just maps each data point to a vector (an embedding). We then have to decide on a measure of distance between those vectors.
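As a rough illustration, the distance between two raw data points is simply a distance between their embeddings, with the same encoder applied to both points. In this sketch, `encoder` is a hypothetical hand-picked linear projection standing in for a trained network such as a CNN:

```python
import math


def encoder(x, weights):
    """Hypothetical linear encoder: projects a raw feature vector to an
    embedding. A trained model (e.g. a CNN) would take its place."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]


def euclidean(u, v):
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def model_distance(x, y, weights):
    """Embed both points with the SAME encoder, then compare embeddings."""
    return euclidean(encoder(x, weights), encoder(y, weights))


# Projects 3-D inputs to 2-D embeddings by keeping the first two features.
W = [[1.0, 0.0, 0.0],
     [0.0, 1.0, 0.0]]
print(model_distance([1.0, 2.0, 100.0], [1.0, 2.0, -50.0], W))  # 0.0
```

Because this projection ignores the third feature, two inputs that differ only there end up at distance zero. Training decides which features the encoder keeps.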
The most obvious of these is just the normal Euclidean distance:
$$d(x, y) = |x - y| = \sqrt{\sum_i (x_i - y_i)^2}$$
where $$x$$ and $$y$$ are the two vectors. But this distance is unbounded from above, so (depending on our choice of loss function) the model could “cheat” and get low losses by pushing easy unmatched pairs far apart instead of improving on the difficult cases. Instead, it’s more common to use the cosine distance:
$$d(x, y) = 1 - \frac{x \cdot y}{|x|\,|y|}$$
where $$|x|$$ is the L2 norm of $$x$$. This has the advantage of being bounded (it always lies between 0 and 2), because it depends only on the angle between the vectors: we are effectively forcing all our embeddings to lie on a high-dimensional sphere. That said, there are a great many distance measures in the literature. I won’t go through them all, but I will give a special shout-out to hyperbolic embeddings, which seem to be the state of the art on a few datasets at the moment.
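The following plain-Python sketch compares the two measures as one vector is scaled up. The function names are our own and are not from any library. The Euclidean distance grows without limit, while the cosine distance stays fixed because it only depends on the angle between the vectors:

```python
import math


def euclidean_distance(x, y):
    """Straight-line distance; grows without bound as vectors move apart."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(x, y)))


def cosine_distance(x, y):
    """1 - cos(angle between x and y); always in [0, 2] for non-zero vectors."""
    dot = sum(a * b for a, b in zip(x, y))
    norm_x = math.sqrt(sum(a * a for a in x))
    norm_y = math.sqrt(sum(b * b for b in y))
    return 1.0 - dot / (norm_x * norm_y)


x, y = [1.0, 0.0], [0.0, 1.0]
for scale in (1, 10, 100):
    scaled = [scale * a for a in x]
    # Euclidean distance keeps growing; cosine distance stays at 1.0
    print(scale, euclidean_distance(scaled, y), cosine_distance(scaled, y))
```

Normalising embeddings to unit length before taking the Euclidean distance gives an equivalent picture. For unit vectors, the squared Euclidean distance is exactly twice the cosine distance.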
After choosing our distance measure we need to pick a loss function. I’ll only talk about one option here, the classic triplet loss, but there is an equally dizzying array of loss functions from surrogate classification tasks to CLIP-style contrastive loss (some examples can be found here).
The triplet loss function samples three examples from the dataset: an “anchor” example as a basis for comparison, a “positive” example of the same class, and a “negative” example of a different class. The loss is then simply
$$L(A, P, N) = \max\left(d(A, P) - d(A, N) + margin,\ 0\right)$$
where $$d$$ is the distance measure, $$A$$ is the anchor vector, $$P$$ is the positive vector, $$N$$ is the negative vector, and $$margin$$ is a hyperparameter. This loss function is nice because it is bounded from below (it cannot be negative). The margin also stops the model from simply making all distances small: if every embedding collapsed to the same point, each triplet would still incur a loss equal to the margin. The loss reaches its minimum of zero only when, for every triplet, the anchor-negative distance exceeds the anchor-positive distance by at least the margin.
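Here is a minimal plain-Python sketch of the triplet loss for a single triplet, plus naive uniform triplet sampling. The margin of 0.2, the choice of cosine distance and the helper names are assumptions made for illustration. Real implementations work on batches of tensors and usually mine hard triplets rather than sampling uniformly. The sampler also assumes there are at least two classes and that every class has at least two examples.

```python
import math
import random


def cosine_distance(x, y):
    """1 - cosine similarity between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(x, y))
    norm = math.sqrt(sum(a * a for a in x)) * math.sqrt(sum(b * b for b in y))
    return 1.0 - dot / norm


def triplet_loss(anchor, positive, negative, margin=0.2, d=cosine_distance):
    """Hinge-style triplet loss: max(d(A, P) - d(A, N) + margin, 0)."""
    return max(d(anchor, positive) - d(anchor, negative) + margin, 0.0)


def sample_triplet(embeddings, labels, rng=random):
    """Uniformly sample (anchor, positive, negative) embeddings.

    Assumes at least two classes and at least two examples per class;
    no hard-negative mining is done.
    """
    i = rng.randrange(len(labels))
    positives = [j for j, lab in enumerate(labels) if lab == labels[i] and j != i]
    negatives = [j for j, lab in enumerate(labels) if lab != labels[i]]
    return (embeddings[i],
            embeddings[rng.choice(positives)],
            embeddings[rng.choice(negatives)])


A, P, N = [1.0, 0.0], [0.9, 0.1], [0.0, 1.0]
print(triplet_loss(A, P, N))  # 0.0: the negative is already far enough away
print(triplet_loss(A, N, P))  # large: here the "positive" is the far one
print(triplet_loss(A, A, A))  # 0.2: collapsed embeddings still pay the margin

# Hypothetical toy data; average the loss over a small sampled batch.
embeddings = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
labels = ["cat", "cat", "dog", "dog"]
rng = random.Random(0)
batch = [sample_triplet(embeddings, labels, rng) for _ in range(8)]
print(sum(triplet_loss(*t) for t in batch) / len(batch))
```

The third call shows why the margin matters. Without it, mapping everything to the same point would give a loss of zero.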
Once we have the loss function, we can train the model using all our favorite deep learning training tricks.
Useful Python libraries
If you want to try out metric learning at home, I have a few Python library recommendations (I use PyTorch, so these are all based on that framework, but I’m sure similar resources exist for TensorFlow):
PyTorch Lightning: There are a large number of wrappers for PyTorch that remove the need for boilerplate code. I find most of them too restrictive, but PyTorch Lightning gives you enough freedom over your training process.
PyTorch Image Models (timm): This is a great resource for pretrained image encoders, including some state-of-the-art architectures. If you’re doing any kind of computer vision, this is a good place to start.
PyTorch Metric Learning: This is a huge library for metric learning. I’ve only used its datasets and loss functions, but it also has some ready-made training loops.