Objekterkennung in Videos mit PyTorch
Dr. William Clemens
Selbstfahrende Autos haben immer noch Schwierigkeiten, vor ihnen liegende Objekte mit ausreichender Zuverlässigkeit zu erkennen. Im Allgemeinen ist die Leistung modernster Objekterkennungsmodelle jedoch bereits sehr beeindruckend - und sie sind nicht allzu schwierig anzuwenden.
Hier werde ich Sie durch das Streaming eines YouTube-Videos in Python und die anschließende Anwendung eines vortrainierten PyTorch-Modells zur Erkennung von Objekten führen.
Wir werden ein auf dem Objekterkennungsdatensatz COCO vortrainiertes Modell. (In Wirklichkeit würden wir natürlich selbst die Feinabstimmung des Modells vornehmen!)
Von YouTube zu OpenCV
Zuerst die Importe. Die meisten davon sind ziemlich gängig im Bereich Bildverarbeitung und Computer Vision. Pafy ist eine Video-Streaming-Bibliothek, und wir werden später die Colourmaps aus der Matplotlib für die Bounding-Boxen benötigen.
COCO_CLASSES
ist nur ein Wörterbuch, das die COCO-Klassennamen enthält.
Wir werden die NVIDIA-Implementierung des SSD-Modells mit Hilfe von torch hub verwenden. Wenn Sie an den Details des Netzwerks interessiert sind, können Sie das Papier hier lesen.
import numpy as np
import cv2
import pafy
import matplotlib.pyplot as plt
from matplotlib import cm
from PIL import Image
import torch
from torch import nn
from torchvision import transforms
utils = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd_processing_utils')
url = "https://www.youtube.com/watch?v=wqctLW0Hb_0"
Lassen Sie uns eine Hilfsfunktion schreiben, um ein OpenCV-VideoCapture
-Objekt zu erhalten, das unser YouTube-Video enthält:
def get_youtube_cap(url):
play = pafy.new(url).streams[-1] # we will take the lowest quality stream
assert play is not None # makes sure we get an error if the video failed to load
return cv2.VideoCapture(play.url)
Jetzt können wir die Ausgabe dieser Funktion einfach als normales OpenCV-VideoCapture
-Objekt verwenden, wie von einer Webcam!
Wir öffnen den ersten Frame eines Videos, um einen Blick darauf zu werfen.
cap = get_youtube_cap("https://www.youtube.com/watch?v=usf5nltlu1E")
ret, frame = cap.read()
cap.release()
plt.imshow(frame[:,:,::-1]) # OpenCV uses BGR, whereas matplotlib uses RGB
plt.show()
Sie sollten diese Ausgabe erhalten:
Objekte erkennen
Ok, wir können bequem ein YouTube-Video laden, jetzt werden wir eine Objekterkennung durchführen.
Damit unser Code weiterhin gut aussieht, werden wir alle Details der Implementierung in einer aufrufbaren Klasse zusammenfassen.
class ObjectDetectionPipeline:
def __init__(self, threshold=0.5, device="cpu", cmap_name="tab10_r"):
# First we need a Transform object to turn numpy arrays to normalised tensors.
# We are using an SSD300 model that requires 300x300 images.
# The normalisation values are standard for pretrained pytorch models.
self.tfms = transforms.Compose([
transforms.Resize(300),
transforms.CenterCrop(300),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# Next we need a model. We're setting it to evaluation mode and sending it to the correct device.
# We get some speedup from the gpu but not as much as we could.
# A more efficient way to do this would be to collect frames to a buffer,
# run them through the network as a batch, then output them one by one
self.model = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd').eval().to(device)
# Stop the network from keeping gradients.
# It's not required but it gives some speedup / reduces memory use.
for param in self.model.parameters():
param.requires_grad = False
self.device = device
self.threshold = threshold # Confidence threshold for displaying boxes.
self.cmap = cm.get_cmap(cmap_name) # colour map
self.classes_to_labels = utils.get_coco_object_dictionary()
@staticmethod
def _crop_img(img):
"""Crop an image or batch of images to square"""
if len(img.shape) == 3:
y = img.shape[0]
x = img.shape[1]
elif len(img.shape) == 4:
y = img.shape[1]
x = img.shape[2]
else:
raise ValueError(f"Image shape: {img.shape} invalid")
out_size = min((y, x))
startx = x // 2 - out_size // 2
starty = y // 2 - out_size // 2
if len(img.shape) == 3:
return img[starty:starty+out_size, startx:startx+out_size]
elif len(img.shape) == 4:
return img[:, starty:starty+out_size, startx:startx+out_size]
def _plot_boxes(self, output_img, labels, boxes):
"""Plot boxes on an image"""
for label, (x1, y1, x2, y2) in zip(labels, boxes):
if (x2 - x1) * (y2 - y1) < 0.25:
# The model seems to output some large boxes that we know cannot be possible.
# This is a simple rule to remove them.
x1 = int(x1*output_img.shape[1])
y1 = int(y1*output_img.shape[0])
x2 = int(x2*output_img.shape[1])
y2 = int(y2*output_img.shape[0])
rgba = self.cmap(label)
bgr = rgba[2]*255, rgba[1]*255, rgba[0]*255
cv2.rectangle(output_img, (x1, y1), (x2, y2), bgr, 2)
cv2.putText(output_img, self.classes_to_labels[label - 1], (int(x1), int(y1)-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, bgr, 2)
return output_img
def __call__(self, img):
"""
Now the call method This takes a raw frame from opencv finds the boxes and draws on it.
"""
if type(img) == np.ndarray:
# single image case
# First convert the image to a tensor, reverse the channels, unsqueeze and send to the right device.
img_tens = self.tfms(Image.fromarray(img[:,:,::-1])).unsqueeze(0).to(self.device)
# Run the tensor through the network.
# We'll use NVIDIAs utils to decode.
results = utils.decode_results(self.model(img_tens))
boxes, labels, conf = utils.pick_best(results[0], self.threshold)
# Crop the image to match what we've been predicting on.
output_img = self._crop_img(img)
return self._plot_boxes(output_img, labels, boxes)
elif type(img) == list:
# batch case
if len(img) == 0:
# Catch empty batch case
return None
tens_batch = torch.cat([self.tfms(Image.fromarray(x[:,:,::-1])).unsqueeze(0) for x in img]).to(self.device)
results = utils.decode_results(self.model(tens_batch))
output_imgs = []
for im, result in zip(img, results):
boxes, labels, conf = utils.pick_best(result, self.threshold)
output_imgs.append(self._plot_boxes(self._crop_img(im), labels, boxes))
return output_imgs
else:
raise TypeError(f"Type {type(img)} not understood")
Probieren wir's aus
Jetzt haben wir im Grunde den gesamten Code geschrieben!
Probieren wir es am ersten Videobild aus.
obj_detect = ObjectDetectionPipeline(device="cpu", threshold=0.5)
plt.figure(figsize=(10,10))
plt.imshow(obj_detect(frame)[:,:,::-1])
plt.show()
Wir können dann einfach über das Video laufen und wie üblich in OpenCV in eine Videodatei schreiben.
batch_size = 16
cap = get_youtube_cap(url)
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
size = min([width, height])
fourcc = cv2.VideoWriter_fourcc(*"MJPG")
out = cv2.VideoWriter("out.avi", fourcc, 20, (size, size))
obj_detect = ObjectDetectionPipeline(device="cuda", threshold=0.5)
exit_flag = True
while exit_flag:
batch_inputs = []
for _ in range(batch_size):
ret, frame = cap.read()
if ret:
batch_inputs.append(frame)
else:
exit_flag = False
break
outputs = obj_detect(batch_inputs)
if outputs is not None:
for output in outputs:
out.write(output)
else:
exit_flag = False
cap.release()
Fazit
Das war's. Oben habe ich alles vorgestellt und erklärt, was Sie brauchen, um Ihr eigenes Objekterkennungsmodell auf einem beliebigen YouTube-Video auszuführen. Falls Sie weitere Details zu einem Thema im Bereich Machine Learning brauchen um Ihr Projekt erfolgreich umzusetzen, können Sie einen kostenlosen Tech Lunch mit uns buchen.