Metric Learning - eine Einführung
Dr. William Clemens
Die wahrscheinlich häufigste Form von Problemen, die wir mit maschinellem Lernen angehen, ist die Klassifizierung, d. h. die Einordnung neuer Datenpunkte in eine von mehreren festgelegten Gruppen oder Klassen. Was aber, wenn wir nicht unbedingt alle Klassen kennen, wenn wir das Modell trainieren? Ein gutes Beispiel hierfür ist die Gesichtserkennung, bei der wir ein System benötigen, das Gesichter speichern und dann erkennen kann, ob neue Bilder, die es sieht, eines dieser Gesichter enthalten. Natürlich können wir das Modell nicht jedes Mal neu trainieren, wenn wir der Datenbank ein neues Gesicht hinzufügen, also brauchen wir eine bessere Lösung.
Eine Möglichkeit, dieses Problem zu lösen, ist das Metric Learning. Beim Metric Learning besteht unser Ziel darin, eine Metrik oder ein Abstandsmaß zwischen verschiedenen Datenpunkten zu lernen. Wenn wir unser Modell richtig trainieren, wird dieses Abstandsmaß Beispiele der gleichen Klasse nahe beieinander und verschiedene Klassen weiter auseinander liegen lassen.
Anwendungen
Wie bereits erwähnt, ist die naheliegendste Anwendung für Metric Learning die Gesichtserkennung, aber es gibt eine extrem breite Palette von Anwendungen (mit Open-Source-Datensätzen), darunter Vogelarten, Fahrzeuge und Produkte. Diese basieren alle auf Computer Vision, aber es gibt keinen Grund, warum Sie diese Technik nicht auch für jede andere Art von Daten verwenden können, die Sie haben. Ersetzen Sie einfach den Kodierer durch etwas, das für Ihre Daten sinnvoll ist!
Abstandsmaße
Wie sieht nun ein Modell aus, das den Abstand zwischen zwei Datenpunkten misst? Nun, das eigentliche Modell bildet einfach einen Datenpunkt auf einen Vektor ab. Dann müssen wir uns für ein Abstandsmaß zwischen diesen Punkten entscheiden.
Das naheliegendste davon ist der normale euklidische Abstand:
wobei $$x$$ und $$y$$ die beiden Vektoren sind. Dieser Abstand hat jedoch das Problem, dass er nach oben unbegrenzt ist und (je nach Wahl der Verlustfunktion) das Modell "schummeln" und schöne Verluste erzielen könnte, indem es einfache nicht übereinstimmende Paare weit auseinander schickt und sich bei schwierigen Fälle nicht verbessert. Stattdessen ist es üblicher, den Kosinusabstand zu verwenden:
wobei $$|x|$$ die L2-Norm von $$x$$ ist. Dieser hat den Vorteil, dass er begrenzt ist (wir erzwingen im Grunde, dass alle unsere Einbettungen auf einer hochdimensionalen Sphäre liegen). Ich sollte jedoch darauf hinweisen, dass es in der Literatur eine große Anzahl von Abstandsmaßen gibt. Ich werde sie nicht alle durchgehen, aber ich möchte besonders die hyperbolischen Einbettungen hervorheben, die bei einigen Datensätzen derzeit der Stand der Technik zu sein scheinen.
Verlustfunktionen
Nachdem wir unser Abstandsmaß gewählt haben, müssen wir eine Verlustfunktion auswählen. Ich werde hier nur über eine Option sprechen, den klassischen Triplett-Verlust, aber es gibt eine ebenso schwindelerregende Anzahl von Verlustfunktionen, von Ersatzklassifizierungsaufgaben bis hin zum kontrastiven Verlust im Stil von CLIP (einige Beispiele finden Sie hier).
Bei der Triplett-Verlustfunktion werden einfach drei Beispiele aus dem Datensatz ausgewählt, ein "Anker"-Beispiel als Vergleichsbasis, ein "positives" Beispiel derselben Klasse und ein "negatives" Beispiel einer anderen Klasse. Der Verlust ist dann einfach
wobei d das Abstandsmaß ist, $$A$$ der Ankervektor, $$N$$ der negative Vektor, $$P$$ der positive Vektor und $$margin$$ ein Hyperparameter ist. Diese Verlustfunktion ist gut, weil sie von unten begrenzt ist (sie kann nicht negativ sein) und der "margin" bedeutet, dass das Modell gezwungen ist, kleinere Werte für positive Paare vorherzusagen, anstatt einfach alle Paare klein zu machen. Der Verlust ist nur dann minimiert, wenn der Abstand für jedes negative Paar mindestens um die Marge größer ist als der Abstand zwischen jedem positiven Paar.
Sobald wir die Verlustfunktion haben, können Sie das Modell einfach mit all Ihren bevorzugten Deep-Learning-Trainingstricks trainieren.
Nützliche Pythonbibliotheken
Wenn Sie Metric Learning zu Hause ausprobieren wollen, habe ich ein paar Empfehlungen für Python-Bibliotheken (ich benutze PyTorch, also basieren diese alle auf diesem Framework, aber ich bin sicher, dass es ähnliche Ressourcen für TensorFlow gibt):
Pytorch Lightning: Es gibt eine große Anzahl von Wrappern für PyTorch, die die Notwendigkeit für Boilerplate-Code beseitigen. Die meisten von ihnen sind zu restriktiv, aber ich finde, dass PyTorch Lightning Ihnen genug Freiheit über Ihren Trainingsprozess gibt.
Pytorch Image Models (timm): Dies ist eine großartige Ressource für vortrainierte Bildkodierer, einschließlich einiger hochmoderner Architekturen; wenn Sie irgendeine Art von Computer Vision betreiben, ist dies ein großartiger Ausgangspunkt.
Pytorch Metric Learning: Dies ist eine riesige Bibliothek für Metric Learning. Ich habe nur die Datensätze und Verlustfunktionen verwendet, aber es gibt auch einige vorgefertigte Trainingsschleifen.
Falls Ihnen diese Einführung ins Metric Learning nicht ausreicht um Ihr Machine Learning Projekt zu verwirklichen, buchen sie doch einen kostenlosen Tech Lunch, indem wir Ihnen noch tiefgründigere Informationen zu Machine Learning Themen geben.