PyTorch DataLoader est lent

Dans PyTorch, DataLoader (torch.utils.data.DataLoader) est souvent utilisé pour récupérer un mini-lot à partir d'un ensemble de données, mais lors de l'expérimentation de données de grande taille, le DataLoader de PyTorch peut être utilisé. Cela s'est avéré très long. À titre de comparaison, j'ai créé mon propre italator qui récupère le mini-lot de l'ensemble de données et l'ai essayé, mais j'ai constaté que le DataLoader de Pytorch était considérablement plus lent que cela. Cela peut être un goulot d'étranglement, en particulier lors de l'utilisation de grandes tailles de données.

[Ajout: 23/03/2020] J'ai reçu un commentaire indiquant que la cause du retard est BatchSampler, qui est utilisé par défaut dans DataLoader. Voir les commentaires pour plus de détails.

Réglage

Dans ce qui suit, on suppose qu'un mini-lot d'une taille de lot de 10 000 est extrait de manière répétée de «étiquette» et «cible» avec 1 million de données. Google Colaboratory a été utilisé comme environnement de calcul.

import torch

label  = torch.randn(1000000,10)
target = torch.randn(1000000,10)
batch_size = 10000

Créez un loader pour récupérer le mini-batch et mesurez le temps d'exécution en utilisant la fonction suivante qui répète simplement la récupération du mini-batch.

def run_loader(loader):
    for label,target in loader:
        pass

Chargeur de données Pytorch

Quand j'ai créé un chargeur en utilisant torch.utils.data.DataLoader (sans shuffle) et mesuré le temps d'exécution, il était de 6,8 secondes. On a l'impression que la récupération des données prend beaucoup de temps.

dataset = torch.utils.data.TensorDataset(label,target)
loader1 = torch.utils.data.DataLoader(dataset,batch_size=batch_size,shuffle=False)

%timeit -n1 -r1 run_loader(loader1)

 1 loop, best of 1: 6.83 s per loop

Lorsque la lecture aléatoire a été effectuée, cela a pris 7,0 secondes.

dataset = torch.utils.data.TensorDataset(label,target)
loader2 = torch.utils.data.DataLoader(dataset,batch_size=batch_size,shuffle=True)

%timeit -n1 -r1 run_loader(loader2)

 1 loop, best of 1: 6.97 s per loop

Chargeur de données Homebrew

À titre de comparaison, j'ai créé un italateur qui extrait un mini-lot de l'ensemble de données et conduit la même expérience.

class DataLoader:

    def __init__(self,dataset,batch_size=1,shuffle=False):
        self.dataset = dataset 
        self.batch_size = batch_size
        self.shuffle = shuffle
        assert all([ dataset[i].size(0) == dataset[0].size(0) for i in range(len(dataset)) ]), 'all the elemtnes must have the same length'
        self.data_size = dataset[0].size(0)

    def __iter__(self):
        self._i = 0
        
        if self.shuffle:
            index_shuffle = torch.randperm(self.data_size)
            self.dataset = [ v[index_shuffle] for v in self.dataset ]

        return self

    def __next__(self):

        i1 = self.batch_size * self._i
        i2 = min( self.batch_size * ( self._i + 1 ), self.data_size )
        
        if i1 >= self.data_size:
            raise StopIteration()

        value = [ v[i1:i2] for v in self.dataset ]

        self._i += 1

        return value

Si vous utilisez votre propre DataLoader (sans shuffle), vous pouvez voir que le temps d'exécution est de 500 microsecondes et que la récupération ne prend presque pas de temps.

loader3 = DataLoader([label,target],batch_size=batch_size,shuffle=False)

%timeit -n1 -r1 run_loader(loader3)

 1 loop, best of 1: 468 µs per loop

Le temps d'exécution de la lecture aléatoire est de 300 millisecondes, ce qui est plus long que sans lui, mais il est toujours négligeable par rapport à l'utilisation du chargeur de données de Pytorch.

loader4 = DataLoader([label,target],batch_size=batch_size,shuffle=True)

%timeit -n1 -r1 run_loader(loader4)

 1 loop, best of 1: 296 ms per loop

Résumé

Il s'avère que la récupération d'un mini-lot prend beaucoup de temps à l'aide du DataLoader de PyTorch. Cet effet est très important, en particulier lorsqu'il s'agit de données de grande taille.

Recommended Posts

PyTorch DataLoader est lent
[Tutoriel PyTorch ①] Qu'est-ce que PyTorch?
pandas idxmax est lent
le type de booléen pypy est lent
[Pytorch] Mémo sur Dataset / DataLoader
Le module PyTorch indique que libcusparse.so.10 est manquant
[Tutoriel PyTorch ⑥] Qu'est-ce que torch.nn?
L'opérateur de Python est-il lent? (À partir de ABC167D)