Many people use DataLoader when loading datasets in PyTorch. (There are many good articles on how to use DataLoader itself, so the basics are not covered here.)
collate_fn is one of the arguments passed to the constructor when creating a DataLoader instance, and its role is to group the individual pieces of data retrieved from the dataset into a mini-batch.
More specifically, as described in the Official Documentation (https://pytorch.org/docs/stable/data.html#dataloader-collate-fn), collate_fn **receives a list of the data retrieved from the dataset**, and its return value is what DataLoader outputs.
Therefore, when reading data from your own dataset with DataLoader, you can handle it by writing a collate_fn like the example below.
import torch

def simple_collate_fn(list_of_data):
    # Here we assume that each piece of data is a D-dimensional vector.
    tensors = [torch.FloatTensor(data) for data in list_of_data]
    # Stack along a newly added first dimension to form an N x D mini-batch (N is the number of data).
    batched_tensor = torch.stack(tensors, dim=0)
    # This return value is what DataLoader yields in
    #   for batched_tensor in dataloader:
    return batched_tensor
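For example, it can be passed to the DataLoader constructor as follows (my_dataset here is a hypothetical dataset whose samples are D-dimensional vectors):
from torch.utils.data import DataLoader

# my_dataset is a placeholder for your own dataset of D-dimensional vectors
dataloader = DataLoader(my_dataset, batch_size=32, collate_fn=simple_collate_fn)

for batched_tensor in dataloader:
    # batched_tensor has shape (batch_size, D)
    ...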
However, to keep the implementation simple, I would like to avoid writing my own collate_fn whenever the default behavior (used when no collate_fn is given) is good enough.
When I looked into it, the default collate_fn turned out to be quite sophisticated; it is not just a torch.stack(*, dim=0) over tensors. So, as a memorandum, I would like to summarize what this default does.
In fact, the default behavior of collate_fn is well documented in the Official Documentation (https://pytorch.org/docs/stable/data.html#dataloader-collate-fn):
- It always prepends a new dimension as the batch dimension.
- It automatically converts NumPy arrays and Python numerical values into PyTorch Tensors.
- It preserves the data structure, e.g., if each sample is a dictionary, it outputs a dictionary with the same set of keys but batched Tensors as values (or lists if the values cannot be converted into Tensors). Same for lists, tuples, namedtuples, etc.
In other words, it seems to have the following functions:
- It adds a new batch dimension at the front.
- It automatically converts NumPy arrays and Python numbers into PyTorch tensors.
- It batches the data while preserving the structure of each sample (dict, list, tuple, namedtuple, etc.).
I was particularly surprised by the third function, which I had never heard of. (I'm a little embarrassed that I had been implementing a simple collate_fn that just batches data vectors...)
However, the detailed behavior cannot be understood without actually looking at the implementation, so I would like to take a look at the [actual implementation](https://github.com/pytorch/pytorch/blob/v1.5.0/torch/utils/data/_utils/collate.py).
Reading the source directly is probably the quickest way, but I will summarize it roughly here so that there is no need to read the implementation again the next time I check.
The information below is as of PyTorch version 1.5.
The default collate_fn, default_collate, is a recursive function, and the processing is dispatched according to the type of the first element of the argument batch:
elem = batch[0]
elem_type = type(elem)
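To get the big picture before going through each case, the whole dispatch can be sketched roughly as follows (a simplified sketch in plain Python, not the v1.5 source verbatim; the real code uses helper aliases such as container_abcs and has additional error handling):
import collections.abc
import numpy as np
import torch

def default_collate_sketch(batch):
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        return torch.stack(batch, 0)
    # numpy types are checked before float/int (np.float64 is also a Python float)
    elif isinstance(elem, np.ndarray):
        return default_collate_sketch([torch.as_tensor(b) for b in batch])
    elif isinstance(elem, np.number):
        return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, str):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: default_collate_sketch([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, "_fields"):  # namedtuple
        return elem_type(*(default_collate_sketch(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        transposed = zip(*batch)
        return [default_collate_sketch(samples) for samples in transposed]
    raise TypeError(f"default_collate_sketch: unsupported type {elem_type}")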
Below, we summarize the specific processing by the type of elem.
torch.Tensor
If the elements of batch are torch.Tensor, they are simply stacked along a newly added first dimension:
return torch.stack(batch, 0)
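For example, stacking three 4-dimensional vectors gives a 3 x 4 tensor (a minimal illustration):
import torch

samples = [torch.randn(4) for _ in range(3)]  # three samples, each of shape (4,)
batched = torch.stack(samples, 0)
print(batched.shape)  # torch.Size([3, 4])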
numpy
If the elements are numpy ndarray, each one is converted to a tensor and then combined in the same way as the torch.Tensor case:
return default_collate([torch.as_tensor(b) for b in batch])
On the other hand, if the elements are numpy scalars, batch itself is already a vector, so it is converted to a tensor as is:
return torch.as_tensor(batch)
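A minimal check of both numpy cases, calling default_collate directly (imported here from the private module linked above, purely for illustration):
import numpy as np
from torch.utils.data._utils.collate import default_collate

# ndarray elements: each array is converted to a tensor, then stacked
print(default_collate([np.array([1, 2]), np.array([3, 4])]))
# tensor([[1, 2],
#         [3, 4]])

# numpy scalar elements: the whole batch is converted to a tensor directly
print(default_collate([np.float64(0.5), np.float64(1.5)]))
# tensor([0.5000, 1.5000], dtype=torch.float64)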
float, int, str
In these cases as well, batch is a vector of Python scalars, so it is returned as a tensor or, for strings, as the list itself:
# float
return torch.tensor(batch, dtype=torch.float64)
# int
return torch.tensor(batch)
# str
return batch
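For example (again calling default_collate directly for illustration):
from torch.utils.data._utils.collate import default_collate

print(default_collate([0.5, 1.5]))   # tensor([0.5000, 1.5000], dtype=torch.float64)
print(default_collate([1, 2, 3]))    # tensor([1, 2, 3])
print(default_collate(["a", "b"]))   # ['a', 'b'] (the list is returned as is)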
collections.abc.Mapping such as dict
As shown below, the values for each key are batched and returned under the original key:
return {key: default_collate([d[key] for d in batch]) for key in elem}
namedtuple
In this case as well, batching is performed per attribute while keeping the attribute names of the original namedtuple:
return elem_type(*(default_collate(samples) for samples in zip(*batch)))
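For example, with a hypothetical Point namedtuple:
from collections import namedtuple
from torch.utils.data._utils.collate import default_collate

Point = namedtuple("Point", ["x", "y"])
batch = [Point(x=1, y=2), Point(x=3, y=4)]
print(default_collate(batch))
# Point(x=tensor([1, 3]), y=tensor([2, 4]))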
collections.abc.Sequence such as list
Batching is performed element-wise, as shown below:
transposed = zip(*batch)
return [default_collate(samples) for samples in transposed]
As an example, let's read a dataset with a complex structure, including dictionaries and strings, using the default collate_fn:
import numpy as np
from torch.utils.data import DataLoader

if __name__ == "__main__":
    complex_dataset = [
        [0, "Bob", {"height": 172.5, "feature": np.array([1, 2, 3])}],
        [1, "Tom", {"height": 153.1, "feature": np.array([3, 2, 1])}]
    ]
    dataloader = DataLoader(complex_dataset, batch_size=2)
    for batch in dataloader:
        print(batch)
Then, you can confirm that it is successfully batched as follows.
[
tensor([0, 1]),
('Bob', 'Tom'),
{
'height': tensor([172.5000, 153.1000], dtype=torch.float64),
'feature': tensor([[1, 2, 3],[3, 2, 1]])
}
]
Note that, as seen above, Python's float is converted to torch.float64 by default. Normally numpy.ndarray is used to represent vectors and tensors, so this is rarely a problem, but if you are not aware of it, it can trip you up.
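For instance, feeding such a float64 batch to a model whose parameters are float32 will raise a dtype mismatch error, so an explicit cast may be needed (a minimal sketch):
from torch.utils.data._utils.collate import default_collate

x = default_collate([172.5, 153.1])
print(x.dtype)   # torch.float64
x = x.float()    # cast to float32 before feeding it to a float32 model
print(x.dtype)   # torch.float32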