Try using PyTorch's collate_fn

collate_fn is an argument of PyTorch's DataLoader:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)

In this post, I'd like to check how it behaves and how to use it.

What is collate_fn?

When the samples returned by the __getitem__ defined in the Dataset are gathered into a batch, they (image, target, etc.) are first collected into a list. As described in the official PyTorch documentation, collate_fn is the function that processes this list and ultimately turns it into torch.Tensor objects. For an iterable-style dataset, loading is roughly equivalent to:

dataset_iter = iter(dataset)
for indices in batch_sampler:
    yield collate_fn([next(dataset_iter) for _ in indices])
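To see exactly what collate_fn receives, the following sketch passes an identity function as collate_fn, so each batch comes back just as the DataLoader hands it over. It relies only on the fact that a plain Python list works as a map-style dataset (it has __getitem__ and __len__); the samples and the name identity_collate are made up for illustration.

```python
from torch.utils.data import DataLoader

# A plain list works as a map-style dataset: it has __getitem__ and __len__.
# Each sample is a made-up (value, name) pair.
samples = [(0, "a"), (1, "b"), (2, "c")]

def identity_collate(batch):
    # batch is the list of samples fetched for one group of indices
    return batch

loader = DataLoader(samples, batch_size=2, shuffle=False, collate_fn=identity_collate)
batches = list(loader)
print(batches)  # [[(0, 'a'), (1, 'b')], [(2, 'c')]]
```

The last batch is shorter because drop_last defaults to False.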

By default, it essentially just converts the samples to Tensors and stacks them with torch.stack(), but by writing your own collate_fn you can build more sophisticated batches.
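A quick way to check the default behavior is to leave collate_fn unset. In this sketch (toy tensors, assumed for illustration), equal-sized samples are stacked along a new batch dimension, while targets with different numbers of rows make torch.stack raise a RuntimeError, which is exactly the situation where a custom collate_fn is needed.

```python
import torch
from torch.utils.data import DataLoader

# Hypothetical toy samples: a 3-element "image" tensor and an int label.
samples = [(torch.full((3,), float(i)), i) for i in range(4)]

loader = DataLoader(samples, batch_size=2)  # collate_fn=None -> default collation
images, labels = next(iter(loader))
print(images.shape)  # torch.Size([2, 3]): samples stacked along a new batch dimension
print(labels)        # tensor([0, 1]): Python ints are converted to a tensor

# Targets with different numbers of rows cannot be stacked by the default collate_fn.
ragged = [(torch.zeros(3), torch.zeros(n, 5)) for n in (1, 2)]
failed = False
try:
    next(iter(DataLoader(ragged, batch_size=2)))
except RuntimeError as err:
    failed = True
    print("default collate failed:", err)
```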

Make your own collate_fn

For a Dataset whose __getitem__ returns an (image, target) pair, the default behavior is roughly equivalent to the function below (the number of returned values depends on your __getitem__). It takes the batch as its argument, stacks each field, and returns the results.

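import torch

def collate_fn(batch):
    # batch is a list of (image, target) tuples returned by __getitem__
    images, targets = list(zip(*batch))
    images = torch.stack(images)    # new leading batch dimension
    targets = torch.stack(targets)  # only works if every target has the same shape
    return images, targets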
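As a sanity check, the hand-written function can be compared with the DataLoader's default collation on the same data. The random tensors and their shapes are assumptions for illustration; the only visible difference is that the default returns a list while this function returns a tuple.

```python
import torch
from torch.utils.data import DataLoader

def collate_fn(batch):
    images, targets = list(zip(*batch))
    return torch.stack(images), torch.stack(targets)

# Made-up samples: 3x4x4 "images" and fixed-size 2-element targets.
samples = [(torch.rand(3, 4, 4), torch.rand(2)) for _ in range(6)]

custom = next(iter(DataLoader(samples, batch_size=3, collate_fn=collate_fn)))
default = next(iter(DataLoader(samples, batch_size=3)))

# Same tensors either way (shuffle=False, so the order matches).
print(all(torch.equal(c, d) for c, d in zip(custom, default)))  # True
print(custom[0].shape)  # torch.Size([3, 3, 4, 4])
```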

Since it is just a function, you are free to change what your own collate_fn does.

This time, we will build batches for object detection. The targets in object detection are basically each object's bounding box and its label, but since one image may contain several boxes, we need to record which image each box belongs to when batching, so an image index must be attached to every box.

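# Per image: one row per box
[[label, xc, yc, w, h],
 [                   ],
 [                   ],...]

# Convert this into one array for the whole batch, prefixing each row with its image index

[[0, label, xc, yc, w, h],
 [0,                    ],
 [1,                    ],...]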

The implementation itself is not that difficult.

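import numpy as np
import torch

def batch_idx_fn(batch):
    # batch: list of (image, bbox); bbox has shape (num_boxes, 5) = [label, xc, yc, w, h]
    images, bboxes = list(zip(*batch))
    targets = []
    for idx, bbox in enumerate(bboxes):
        target = np.zeros((len(bbox), 6))
        target[:, 1:] = bbox  # copy [label, xc, yc, w, h]
        target[:, 0] = idx    # image index within the batch
        targets.append(target)
    images = torch.stack(images)
    targets = torch.from_numpy(np.concatenate(targets)).float()  # [[batch_idx, label, xc, yc, w, h], ...]
    return images, targets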
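If the boxes are already tensors, the same layout can be built without the NumPy round trip. The sketch below is a hypothetical variant (the name batch_idx_fn_torch is made up) that assumes each bbox is a tensor of shape (num_boxes, 5); it also works for an image with no boxes, since a (0, 5) tensor just contributes zero rows.

```python
import torch

def batch_idx_fn_torch(batch):
    """Hypothetical NumPy-free variant of batch_idx_fn.

    Assumes each sample is (image, bbox), where bbox is a tensor of shape
    (num_boxes, 5) = [label, xc, yc, w, h].
    """
    images, bboxes = zip(*batch)
    targets = [
        # Prepend a column filled with the image's index within the batch.
        torch.cat([torch.full((len(bbox), 1), float(idx)), bbox.float()], dim=1)
        for idx, bbox in enumerate(bboxes)
    ]
    return torch.stack(images), torch.cat(targets, dim=0)

# Usage with two made-up images holding 1 and 2 boxes.
batch = [
    (torch.zeros(3, 8, 8), torch.tensor([[0.0, 0.5, 0.5, 0.1, 0.1]])),
    (torch.zeros(3, 8, 8), torch.tensor([[9.0, 0.2, 0.3, 0.1, 0.2],
                                         [5.0, 0.8, 0.4, 0.1, 0.3]])),
]
images, targets = batch_idx_fn_torch(batch)
print(images.shape, targets.shape)  # torch.Size([2, 3, 8, 8]) torch.Size([3, 6])
print(targets[:, 0].tolist())       # [0.0, 1.0, 1.0]
```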

In actual use, it looks like this:

test_data_loader = torch.utils.data.DataLoader(
                       test_dataset,
                       batch_size=4,
                       shuffle=False,
                       collate_fn=batch_idx_fn
                       )
images, targets = next(iter(test_data_loader))
print(targets)  # [[batch_idx, label, xc, yc, w, h], ...]
# [[0.0000, 0.0000, 0.6001, 0.5726, 0.1583, 0.1119],
# [0.0000, 9.0000, 0.0568, 0.5476, 0.1150, 0.1143],
# [1.0000, 5.0000, 0.8316, 0.4113, 0.1080, 0.3452],
# [1.0000, 0.0000, 0.3476, 0.6494, 0.1840, 0.1548],
# [2.0000, 2.0000, 0.8276, 0.6763, 0.1720, 0.3240],
# [2.0000, 4.0000, 0.1626, 0.0496, 0.0900, 0.0880],
# [2.0000, 5.0000, 0.2476, 0.2736, 0.1400, 0.5413],
# [2.0000, 5.0000, 0.5786, 0.4523, 0.4210, 0.5480],
# [3.0000, 0.0000, 0.4636, 0.4618, 0.0400, 0.1024],
# [3.0000, 0.0000, 0.5706, 0.5061, 0.0380, 0.0683]]
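Downstream code (a loss or an evaluation step, for example) often needs the boxes of each image again. Because the first column is the image index, a boolean mask recovers them from the flat tensor. This sketch reuses the first four rows of the output shown here and assumes a batch of two images for that excerpt.

```python
import torch

# First four rows of the printed output: [batch_idx, label, xc, yc, w, h]
targets = torch.tensor([
    [0.0, 0.0, 0.6001, 0.5726, 0.1583, 0.1119],
    [0.0, 9.0, 0.0568, 0.5476, 0.1150, 0.1143],
    [1.0, 5.0, 0.8316, 0.4113, 0.1080, 0.3452],
    [1.0, 0.0, 0.3476, 0.6494, 0.1840, 0.1548],
])
batch_size = 2  # assumed for this excerpt

# A mask on the index column selects one image's boxes ([label, xc, yc, w, h]).
per_image = [targets[targets[:, 0] == i, 1:] for i in range(batch_size)]
print([len(boxes) for boxes in per_image])  # [2, 2]

# Box count per image in one call; images without boxes would get 0.
counts = torch.bincount(targets[:, 0].long(), minlength=batch_size)
print(counts)  # tensor([2, 2])
```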

In conclusion

Besides attaching batch indices as in this article, I think collate_fn is useful when the target differs from batch to batch, when the target is data that cannot be stacked as-is, and when you want to reuse the same Dataset with slightly different batching.
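As one example of the last two cases, the sketch below pads variable-length sequences (which torch.stack cannot combine) with torch.nn.utils.rnn.pad_sequence, and uses functools.partial to batch the same dataset with a different padding value. The function name pad_collate, its padding_value parameter, and the samples are assumptions for illustration.

```python
from functools import partial

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

def pad_collate(batch, padding_value=0.0):
    # batch: list of (sequence, label); sequences may differ in length
    seqs, labels = zip(*batch)
    lengths = torch.tensor([len(s) for s in seqs])
    padded = pad_sequence(list(seqs), batch_first=True, padding_value=padding_value)
    return padded, lengths, torch.tensor(labels)

# Made-up variable-length samples that torch.stack could not combine.
samples = [(torch.ones(n), n % 2) for n in (2, 4, 3)]

# The same dataset, batched two ways by changing only the collate_fn argument.
zero_pad = DataLoader(samples, batch_size=3, collate_fn=pad_collate)
neg_pad = DataLoader(samples, batch_size=3, collate_fn=partial(pad_collate, padding_value=-1.0))

padded, lengths, labels = next(iter(zero_pad))
print(padded.shape, lengths.tolist())  # torch.Size([3, 4]) [2, 4, 3]
print(next(iter(neg_pad))[0][0].tolist())  # [1.0, 1.0, -1.0, -1.0]
```

Returning the original lengths alongside the padded tensor lets later code ignore the padded positions.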
