DataLoader et DataSet de PyTorch que j'ai utilisés sans trop comprendre quand j'ai remarqué. Je vous serais reconnaissant si vous pouviez vous y référer si vous voulez faire quelque chose d'un peu élaboré.
La deuxième partie est ici.
Si vous utilisez PyTorch, vous avez probablement vu DataLoader. L'exemple PyTorch de MNIST, que tout le monde utilise pour l'apprentissage automatique, a également cette description.
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('~/dataset/MNIST',
train=True,
download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=256,
shuffle=True)
Ou si vous recherchez avec Qiita etc., vous verrez cette façon d'écrire.
train_dataset = datasets.MNIST(
'~/dataset/MNIST',
train=True,
download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]))
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=256,
shuffle=True)
Après cela, vous pouvez voir comment les données sont acquises et apprises par lots en les tournant avec l'instruction For.
for epoch in epochs:
for img, label in train_loader:
#Décrivez le processus d'apprentissage dans ce
Parlant verbalement, DataLoader est un gars pratique qui suit une certaine règle et transporte les données comme décrit dans le DataSet. Par exemple, dans l'exemple ci-dessus, 256 données MNIST (mini-lot) seront incluses dans img et label avec des données normalisées. Jetons un coup d'œil au contenu pour voir comment cela est réalisé.
Jetons un coup d'œil à l'implémentation de DataLoader. Vous pouvez immédiatement voir qu'il est en classe.
class DataLoader(object):
r"""
Data loader. Combines a dataset and a sampler, and provides an iterable over
the given dataset.
"""
#réduction
Je vais omettre les détails, mais si vous regardez de plus près les informations en tant qu'itérateur, vous trouverez l'implémentation suivante.
def __next__(self):
index = self._next_index() # may raise StopIteration
data = self.dataset_fetcher.fetch(index) # may raise StopIteration
if self.pin_memory:
data = _utils.pin_memory.pin_memory(data)
return data
Lorsque ce __next__
est appelé, les données sont renvoyées.
Et ces données semblent être créées en passant l'index à l'ensemble de données.
À ce stade, vous n'avez pas à être si nerveux à propos de la façon dont l'index est créé et de la façon dont le jeu de données est appelé, mais comme c'est un gros problème, allons plus loin.
class _MapDatasetFetcher(_BaseDatasetFetcher):
def fetch(self, possibly_batched_index):
if self.auto_collation:
data = [self.dataset[idx] for idx in possibly_batched_index]
else:
data = self.dataset[possibly_batched_index]
return self.collate_fn(data)
L'index est passé à l'ensemble de données. Appeler une instance de la classe de cette manière devrait signifier que __getitem__
est appelé dans l'ensemble de données. (Ici est détaillé. Passons à l'ensemble de données basé sur cela.
Dès que vous accédez à la définition de MNIST, vous pouvez voir qu'il s'agit d'une classe.
class MNIST(VisionDataset):
"""`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.
"""
#réduction
Allons voir __getitem__
.
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target class.
"""
img, target = self.data[index], int(self.targets[index])
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img.numpy(), mode='L')
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
Il est écrit d'une manière facile à comprendre que l'index est passé et les données correspondantes sont renvoyées. Je vois, c'est ainsi que les données MNIST sont renvoyées.
En regardant le processus pendant un moment, Image.fromarray de PIL est également écrit. En d'autres termes, si vous concevez et écrivez ce __getitem__
, il est possible de renvoyer n'importe quelle donnée.
Mais il y a encore quelque chose que je ne comprends pas. Comment l'index est-il créé? L'indice est ici.
if sampler is None: # give default samplers
if self.dataset_kind == _DatasetKind.Iterable:
# See NOTE [ Custom Samplers and IterableDataset ]
sampler = _InfiniteConstantSampler()
else: # map-style
if shuffle:
sampler = RandomSampler(dataset)
else:
sampler = SequentialSampler(dataset)
self.sampler = sampler
@property
def _index_sampler(self):
# The actual sampler used for generating indices for `_DatasetFetcher`
# (see _utils/fetch.py) to read data at each time. This would be
# `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
# We can't change `.sampler` and `.batch_sampler` attributes for BC
# reasons.
if self._auto_collation:
return self.batch_sampler
else:
return self.sampler
L'index semble être créé par échantillonneur. Par défaut, l'échantillonneur est commuté par l'argument True, False appelé shuffle. Par exemple, regardons l'implémentation lorsque shuffle = False.
class SequentialSampler(Sampler):
r"""Samples elements sequentially, always in the same order.
Arguments:
data_source (Dataset): dataset to sample from
"""
def __init__(self, data_source):
self.data_source = data_source
def __iter__(self):
return iter(range(len(self.data_source)))
def __len__(self):
return len(self.data_source)
Le data_source ici est un ensemble de données. Il semble que je sois arrivé jusqu'ici et que j'ai une idée générale. En d'autres termes, il se répète pour la longueur de l'ensemble de données. Inversement, il semble nécessaire de préparer une méthode spéciale appelée «len» dans l'ensemble de données.
Vérifions __len__
dans datasets.MNIST.
def __len__(self):
return len(self.data)
Vous renvoyez la longueur des données. Étant donné que les données dans MNIST ont une taille de 60000x28x28, 60000 seront renvoyées. Cela a été assez rafraîchissant.
L'article s'allonge, c'est donc tout pour la première partie. Dans Partie 2, vous allez créer votre propre ensemble de données.
Recommended Posts