De nombreuses personnes utilisent DataLoader lors du chargement des ensembles de données avec PyTorch. (Il existe de nombreux bons articles sur l'utilisation de DataLoader. Par exemple, cet article est facile à comprendre.)
collate_fn
est l'un des arguments donnés au constructeur lors de la création d'une instance DataLoader
, et a pour rôle de regrouper les données individuelles extraites de l'ensemble de données dans un mini-lot.
Plus précisément, collate_fn
provient de l'ensemble de données **, comme décrit dans la documentation officielle (https://pytorch.org/docs/stable/data.html#dataloader-collate-fn). Entrez la liste des données récupérées **. Ensuite, la valeur de retour de collate_fn
sera sortie de DataLoader
.
Par conséquent, lorsque vous lisez des données de votre propre ensemble de données avec DataLoader
, vous pouvez les gérer en créant votre propre collate_fn
comme indiqué dans l'exemple ci-dessous.
def simple_collate_fn(list_of_data):
#Ici, on suppose que chaque donnée est un vecteur D-dimensionnel.
tensors = [torch.FloatTensor(data) for data in list_of_data]
#Combinez les dimensions nouvellement ajoutées dans un mini-lot dans une matrice N x D.(N est le nombre de données)
batched_tensor = tensor.stack(tensors, dim=0)
#Cette valeur de retour est
# for batched_tensor in dataloader:
#Est sorti du chargeur de données.
return batched_tensor
Afin de simplifier l'implémentation, je voudrais éviter d'implémenter mon propre collate_fn
si le comportement par défaut sans donner collate_fn
peut être utilisé.
Quand je l'ai recherché, collate_fn
est assez sophistiqué même par défaut, et il semble que ce ne soit pas seulement une combinaison de tenseurs comme torch.stack (*, dim = 0)
, donc cette fois comme mémorandum ce défaut Je voudrais résumer les fonctions.
En fait, le comportement par défaut de collate_fn
est bien documenté dans la documentation officielle (https://pytorch.org/docs/stable/data.html#dataloader-collate-fn).
- It always prepends a new dimension as the batch dimension.
- It automatically converts NumPy arrays and Python numerical values into PyTorch Tensors.
- It preserves the data structure, e.g., if each sample is a dictionary, it outputs a dictionary with the same set of keys but batched Tensors as values (or lists if the values can not be converted into Tensors). Same for list s, tuple s, namedtuple s, etc.
En d'autres termes, il semble avoir les fonctions suivantes.
dict
, list
, tuple
, namedtuple
, etc.)J'ai été particulièrement surpris car je n'avais jamais entendu parler de l'existence de la troisième fonction. (Je suis gêné d'avoir implémenté un simple collate_fn
qui regroupe respectivement plusieurs vecteurs de données ...)
Cependant, comme le comportement détaillé ne peut être compris sans examiner réellement l'implémentation, Implémentation réelle Je voudrais jeter un œil à (/collate.py).
Je pense que c'est le plus rapide pour le lire, mais je vais le résumer grossièrement pour que vous n'ayez pas à relire l'implémentation lorsque vous la vérifierez à nouveau à l'avenir.
Informations à partir de la version 1.5.
La valeur par défaut collate_fn
, default_collate
, est un processus récursif, et le processus est classé en fonction du type du premier élément de l'argument batch
.
elem = batch[0]
elem_type = type(elem)
Ci-dessous, nous résumerons le traitement spécifique par type d''elem '.
torch.Tensor
Si batch
est torch.Tensor
, il ajoute simplement une dimension en premier et se joint.
return torch.stack(batch, 0)
Dans le cas de "ndarray" de numpy, il est tensorisé puis combiné comme dans le cas de "torch.Tensor".
return default_collate([torch.as_tensor(b) for b in batch])
Par contre, dans le cas du scalaire numpy, le "batch" courant est un vecteur, il est donc converti en un tenseur tel quel.
return torch.as_tensor(batch)
float
, int
, str
Dans ce cas également, «batch» est un vecteur, il est donc renvoyé sous forme de tensorisé ou de liste comme indiqué ci-dessous.
# float
return torch.tensor(batch, dtype=torch.float64)
# int
return torch.tensor(batch)
# str
return batch
collections.abc.Mapping
telles que dict
Comme indiqué ci-dessous, chaque clé est groupée et renvoyée en tant que valeur de clé d'origine.
return {key: default_collate([d[key] for d in batch]) for key in elem}
namedtuple
Dans ce cas également, le traitement par lots est effectué pour chaque attribut tout en conservant le même nom d'attribut que le «namedtuple» d'origine.
return elem_type(*(default_collate(samples) for samples in zip(*batch)))
collections.abc.Sequence
telles que list
Le traitement par lots est effectué pour chaque élément comme indiqué ci-dessous.
transposed = zip(*batch)
return [default_collate(samples) for samples in transposed]
Par exemple, essayez de lire un ensemble de données avec une structure compliquée qui inclut des dictionnaires et des chaînes de caractères avec la valeur par défaut collate_fn
comme indiqué ci-dessous.
import numpy as np
from torch.utils.data import DataLoader
if __name__=="__main__":
complex_dataset = [
[0, "Bob", {"height": 172.5, "feature": np.array([1,2,3])}],
[1, "Tom", {"height": 153.1, "feature": np.array([3,2,1])}]
]
dataloader = DataLoader(complex_dataset, batch_size=2)
for batch in dataloader:
print(batch)
Ensuite, vous pouvez confirmer qu'il est correctement groupé comme suit.
[
tensor([0, 1]),
('Bob', 'Tom'),
{
'height': tensor([172.5000, 153.1000], dtype=torch.float64),
'feature': tensor([[1, 2, 3],[3, 2, 1]])
}
]
Au fait, le «float» de python est converti en «torch.float64» par défaut. Normalement, numpy.ndarray
exprime un vecteur ou un tenseur, donc je pense qu'il n'y a pas de problème, mais si vous ne le savez pas, vous tomberez dans un piège.
Recommended Posts