torchvision.datasets (primarily for supervised or semi-supervised segmentation datasets)
Before looking at the problematic case, let's first consider a case where there is no problem.
When doing data augmentation with PyTorch, you usually define the transform as follows:
transform = torchvision.transforms.Compose([
    # Rotate by a random angle within ±degrees
    torchvision.transforms.RandomRotation(degrees),
    # Flip horizontally (with probability 0.5)
    torchvision.transforms.RandomHorizontalFlip(),
    # Flip vertically (with probability 0.5)
    torchvision.transforms.RandomVerticalFlip()
])
Then pass it to the dataset as an argument:
dataset = HogeDataset.HogeDataset(
    train=True, transform=transform
)
For tasks such as object classification this is probably not a problem. The teacher data (ground-truth label) is not an image, so only the original image has to be transformed.
Next, the problematic case. The difference from the previous case is that the teacher data is itself given as an image (a mask).
transform = torchvision.transforms.Compose([
    # Rotate by a random angle within ±degrees
    torchvision.transforms.RandomRotation(degrees),
    # Flip horizontally (with probability 0.5)
    torchvision.transforms.RandomHorizontalFlip(),
    # Flip vertically (with probability 0.5)
    torchvision.transforms.RandomVerticalFlip()
])
Then pass it to the dataset as both transform and target_transform:
dataset = HogeDataset.HogeDataset(
    train=True, transform=transform, target_transform=transform
)
However, in this case, when data is retrieved from HogeDataset, the transforms applied to the original image and to the mask image do not correspond, because each call draws its own random parameters.
Example: original image rotated by 90 degrees, mask image rotated by 270 degrees.
In that situation the data is augmented, but the mask no longer matches its image, so it cannot serve as teacher data.
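The mismatch can be reproduced without any image library. The following is a minimal sketch, not torchvision code: the helpers hflip, random_hflip and count_mismatches are hypothetical stand-ins, and a 2x2 nested list plays the role of both the image and its mask. It compares drawing the random decision separately for image and mask against drawing it once and sharing it.

```python
import random


def hflip(grid):
    """Mirror a 2-D list left to right (toy stand-in for a horizontal flip)."""
    return [row[::-1] for row in grid]


def random_hflip(grid, rng, p=0.5):
    """Flip with probability p, drawing a fresh random number on every call."""
    return hflip(grid) if rng.random() < p else grid


def count_mismatches(trials=1000, seed=0):
    """Count trials in which the transformed image and mask no longer agree."""
    rng = random.Random(seed)
    image = [[1, 2], [3, 4]]
    mask = [[1, 2], [3, 4]]  # identical to the image so alignment is easy to check
    independent = 0
    shared = 0
    for _ in range(trials):
        # Independent draws: what happens when the same random transform
        # object is passed as both transform and target_transform
        if random_hflip(image, rng) != random_hflip(mask, rng):
            independent += 1
        # Shared draw: decide once, then apply the same decision to both
        flip = rng.random() < 0.5
        out_image = hflip(image) if flip else image
        out_mask = hflip(mask) if flip else mask
        if out_image != out_mask:
            shared += 1
    return independent, shared


independent, shared = count_mismatches()
print("mismatches with independent draws:", independent)
print("mismatches with a shared draw:", shared)
```

With independent draws the pair falls out of alignment in roughly half of the trials, while the shared draw never does. The rest of this article is about achieving the shared behaviour with real images.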
So why does the target_transform argument exist at all? Probably so that non-random processing such as torchvision.transforms.Resize() and torchvision.transforms.ToTensor() can also be applied to the mask image.
So how can we apply the same processing to the mask image as to the original image? One solution is to write your own Dataset class, as shown below.
HogeDataset.py
import glob
import torch
from torchvision import transforms
from torchvision.transforms import functional as tvf
from PIL import Image

DATA_PATH = '[Original image directory path]'
MASK_PATH = '[Mask image directory path]'
TRAIN_NUM = [Number of training data]

class HogeDataset(torch.utils.data.Dataset):
    def __init__(self, transform=None, target_transform=None, train=True):
        # transform and target_transform are non-random transforms such as tensor conversion
        self.transform = transform
        self.target_transform = target_transform

        # glob does not guarantee any order, so sort both lists to keep each
        # image aligned with its mask (assumes matching file names)
        data_files = sorted(glob.glob(DATA_PATH + '/*.[File extension]'))
        mask_files = sorted(glob.glob(MASK_PATH + '/*.[File extension]'))

        self.dataset = []
        self.maskset = []

        # Load the original images (glob already returns the full path)
        for data_file in data_files:
            self.dataset.append(Image.open(data_file))

        # Load the mask images
        for mask_file in mask_files:
            self.maskset.append(Image.open(mask_file))

        # Split into training data and test data
        if train:
            self.dataset = self.dataset[:TRAIN_NUM]
            self.maskset = self.maskset[:TRAIN_NUM]
        else:
            # Start at TRAIN_NUM so that no sample is skipped
            self.dataset = self.dataset[TRAIN_NUM:]
            self.maskset = self.maskset[TRAIN_NUM:]

        # Data Augmentation
        # Random transforms are applied here
        self.augmented_dataset = []
        self.augmented_maskset = []
        for num in range(len(self.dataset)):
            data = self.dataset[num]
            mask = self.maskset[num]

            # Random crop
            for crop_num in range(16):
                # The crop position is drawn once and shared by image and mask
                i, j, h, w = transforms.RandomCrop.get_params(data, output_size=(250, 250))
                cropped_data = tvf.crop(data, i, j, h, w)
                cropped_mask = tvf.crop(mask, i, j, h, w)

                # Rotation (0, 90, 180, 270 degrees)
                for rotation_num in range(4):
                    rotated_data = tvf.rotate(cropped_data, angle=90 * rotation_num)
                    rotated_mask = tvf.rotate(cropped_mask, angle=90 * rotation_num)

                    # Use either horizontal or vertical flipping, not both
                    # Horizontal flip: p=0 never flips, p=1 always flips
                    for h_flip_num in range(2):
                        h_flipped_data = transforms.RandomHorizontalFlip(p=h_flip_num)(rotated_data)
                        h_flipped_mask = transforms.RandomHorizontalFlip(p=h_flip_num)(rotated_mask)

                        # Vertical flip (disabled)
                        # for v_flip_num in range(2):
                        #     v_flipped_data = transforms.RandomVerticalFlip(p=v_flip_num)(h_flipped_data)
                        #     v_flipped_mask = transforms.RandomVerticalFlip(p=v_flip_num)(h_flipped_mask)

                        # Store the augmented pair
                        self.augmented_dataset.append(h_flipped_data)
                        self.augmented_maskset.append(h_flipped_mask)

        self.datanum = len(self.augmented_dataset)

    # Returns the number of samples
    def __len__(self):
        return self.datanum

    # Returns one sample
    # Non-random transforms are applied here
    def __getitem__(self, idx):
        out_data = self.augmented_dataset[idx]
        out_mask = self.augmented_maskset[idx]

        if self.transform:
            out_data = self.transform(out_data)

        if self.target_transform:
            out_mask = self.target_transform(out_mask)

        return out_data, out_mask
What we are doing is simple: the data augmentation is performed inside __init__().
There, every combination of the transforms is applied exhaustively to each image pair (16 random crops × 4 rotations × 2 flips = 128 augmented pairs per original pair), and each transform uses the same parameters for the image and its mask.
For the time being, this lets you apply the same processing to the mask image as to the original image and perform data augmentation.
[Supplement] It is recommended to use only one of horizontal or vertical flipping, because combining rotation with both flips produces duplicates!
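The duplication mentioned in the supplement can be checked by counting. The sketch below is a library-free illustration, not part of the original class: rot90, hflip, vflip and count_variants are hypothetical helpers working on a small nested list whose entries are all different, so that any two distinct orientations are distinguishable.

```python
def rot90(grid):
    """Rotate a square 2-D list by 90 degrees counter-clockwise."""
    return [list(row) for row in zip(*grid)][::-1]


def hflip(grid):
    """Mirror left to right."""
    return [row[::-1] for row in grid]


def vflip(grid):
    """Mirror top to bottom."""
    return grid[::-1]


def count_variants(use_vflip):
    """Return (generated, unique) counts for rotations combined with flips."""
    current = [[1, 2], [3, 4]]  # all entries differ, so no orientation coincides by accident
    seen = set()
    generated = 0
    for _ in range(4):  # 0, 90, 180, 270 degrees
        for h in range(2):  # without / with horizontal flip
            flipped = hflip(current) if h else current
            for v in range(2 if use_vflip else 1):  # optionally also vertical flip
                out = vflip(flipped) if v else flipped
                seen.add(tuple(map(tuple, out)))
                generated += 1
        current = rot90(current)
    return generated, len(seen)


print(count_variants(use_vflip=False))  # (8, 8): every variant is distinct
print(count_variants(use_vflip=True))   # (16, 8): every variant appears twice
```

A square has only 8 distinct orientations, and 4 rotations combined with one flip direction already cover all of them. A vertical flip equals a horizontal flip followed by a 180 degree rotation, so adding it only repeats existing variants.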
Try using the custom Dataset class defined above
import torch
import torchvision
import HogeDataset
BATCH_SIZE = [Batch size]
# Preprocessing (non-random)
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
target_transform = torchvision.transforms.Compose([
    # Nearest-neighbour interpolation so that mask values are not blended;
    # older torchvision versions took interpolation=0 here instead
    torchvision.transforms.Resize(
        (224, 224),
        interpolation=torchvision.transforms.InterpolationMode.NEAREST
    ),
    torchvision.transforms.ToTensor()
])

# Prepare the training data and test data
trainset = HogeDataset.HogeDataset(
    train=True,
    transform=transform,
    target_transform=target_transform
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=BATCH_SIZE,
    shuffle=True
)
testset = HogeDataset.HogeDataset(
    train=False,
    transform=transform,
    target_transform=target_transform
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=BATCH_SIZE,
    shuffle=True
)