It's been half a year since I started studying machine learning, and I finally managed to build a custom Dataset with PyTorch, so I'm writing it up as a memo. While studying GANs, I learned by downloading code from GitHub, but that code only loaded MNIST and CIFAR. I wanted to run it on my own dataset, so I wrote my own Dataset class. (This post is partly practice at writing articles, and it's been a while since my last one, so please bear with me...)
Inheriting PyTorch's Dataset

To feed your own data to a training model, you pass an instance of a class that inherits from `torch.utils.data.Dataset` to `DataLoader`. That class implements two methods:

- `__getitem__`: returns one sample's data and label as a tuple.
- `__len__`: as the name suggests, returns the number of samples.
So, the basic configuration is like this.
```python
import torch.utils.data


class MyDataset(torch.utils.data.Dataset):
    def __init__(self, imageSize, dir_path, transform=None):
        pass

    def __len__(self):
        pass

    def __getitem__(self, idx):
        pass
```
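To see the two methods in action without any image files, here is a toy dataset. It is a hypothetical example (`SquaresDataset` is not from the original post), written as a plain class so it runs without PyTorch installed; in real code you would inherit from `torch.utils.data.Dataset` as in the skeleton above.

```python
class SquaresDataset:
    """Hypothetical toy map-style dataset: sample i is the pair (i, i * i)."""

    def __init__(self, n):
        self.n = n  # number of samples

    def __len__(self):
        # Tells the caller (e.g. a DataLoader's sampler) which indices exist
        return self.n

    def __getitem__(self, idx):
        # Return (data, label) as a tuple, like MyDataset does
        if not 0 <= idx < self.n:
            raise IndexError(idx)
        return idx, idx * idx


ds = SquaresDataset(5)
print(len(ds))  # 5
print(ds[3])    # (3, 9)
```

Raising `IndexError` for out-of-range indices also lets a plain `for sample in ds` loop stop at the end, because Python falls back to calling `__getitem__` with 0, 1, 2, ... until that error appears. In `MyDataset` the list lookup `self.image_paths[idx]` already raises it.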
Besides the path to the data, the class takes the input image size and an optional preprocessing transform as arguments.
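The default preprocessing ends with `transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))`. torchvision's `Normalize` computes (input - mean) / std for each channel c, and `ToTensor` has already scaled 8-bit pixel values into [0, 1], so each channel value x becomes:

```latex
x' = \frac{x - \mu_c}{\sigma_c} = \frac{x - 0.5}{0.5} = 2x - 1,
\qquad x \in [0, 1] \;\Rightarrow\; x' \in [-1, 1]
```

So this step rescales pixels to [-1, 1] rather than standardizing them to zero mean and unit variance. That range is the usual choice when a GAN generator ends in a tanh layer.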
The constructor, which is called automatically when an instance is created, does the following:
```python
def __init__(self, imageSize, dir_path, transform=None):
    # Use the transform passed in; otherwise fall back to this default pipeline
    if transform is None:
        transform = transforms.Compose([
            transforms.Resize(imageSize),  # scale the shorter side to imageSize
            transforms.ToTensor(),  # PIL image -> float tensor with values in [0, 1]
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # [0, 1] -> [-1, 1] per channel
        ])
    self.transform = transform
    # Collect the input image paths; labels are derived from these paths in __getitem__
    self.image_paths = [str(p) for p in Path(dir_path).glob("**/*.png")]
    self.data_num = len(self.image_paths)  # returned by __len__
    self.classes = ['carpet', 'ceramic', 'cloth', 'dirt', 'drywall', 'glass', 'grass', 'gravel', 'leaf', 'metal']
    self.class_to_idx = {'carpet':0, 'ceramic':1, 'cloth':2, 'dirt':3, 'drywall':4, 'glass':5, 'grass':6,'gravel':7, 'leaf':8, 'metal':9}
```
I had a multi-class dataset of material images, so I used its ten classes here.
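Instead of typing the class list and dictionary by hand, they can be built from the folder names. The sketch below assumes a hypothetical layout `dir_path/<class_name>/*.png` (one folder per class); `find_classes` is an illustrative helper, not part of the original code. Sorting the names keeps the mapping deterministic, which is the same convention torchvision's `ImageFolder` uses.

```python
import tempfile
from pathlib import Path


def find_classes(dir_path):
    """Return (classes, class_to_idx) built from the subfolders of dir_path.

    Assumes the hypothetical layout dir_path/<class_name>/*.png.
    Sorting the names makes the class -> index mapping deterministic.
    """
    classes = sorted(p.name for p in Path(dir_path).iterdir() if p.is_dir())
    class_to_idx = {name: i for i, name in enumerate(classes)}
    return classes, class_to_idx


# Demo on a throwaway directory tree; folders are created in non-sorted order
with tempfile.TemporaryDirectory() as root:
    for name in ["metal", "leaf", "gravel", "grass", "glass",
                 "drywall", "dirt", "cloth", "ceramic", "carpet"]:
        (Path(root) / name).mkdir()
    (Path(root) / "notes.txt").write_text("not a class")  # plain files are ignored
    classes, class_to_idx = find_classes(root)

print(classes[:3])             # ['carpet', 'ceramic', 'cloth']
print(class_to_idx["gravel"])  # 7
```

Because the ten material names happen to be in alphabetical order already, the sorted folder names reproduce exactly the hand-written `class_to_idx` from the constructor.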
`__getitem__` is the method that reads one sample and its correct label during training, so it is implemented using the information collected in the constructor.
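```python
def __getitem__(self, idx):
    p = self.image_paths[idx]
    # Convert to RGB so every image has 3 channels, matching Normalize's 3 means/stds
    image = Image.open(p).convert("RGB")
    out_data = self.transform(image) if self.transform else image
    # The class name is the image's parent folder (.../carpet/xxx.png -> 'carpet');
    # this assumes images sit directly inside class-named folders
    out_label = self.class_to_idx[Path(p).parent.name]
    return out_data, out_label
```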
The images could also be loaded all at once in the constructor, but I was worried about memory usage with a large dataset, so `__getitem__` opens each image only when it is requested. For the labels I took a slightly roundabout route: the class name is taken from the file path and then converted to an index through a dictionary.
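Taking the class name from the path is easy to get wrong: splitting on `"\\"` and taking element `[3]` only works for Windows-style paths at one particular directory depth. A more portable sketch, assuming the class folder is always the image's immediate parent, uses `pathlib` (`label_from_path` is a hypothetical helper name):

```python
from pathlib import Path

classes = ['carpet', 'ceramic', 'cloth', 'dirt', 'drywall',
           'glass', 'grass', 'gravel', 'leaf', 'metal']
# Same mapping as the hand-written dictionary, built from the list
class_to_idx = {name: i for i, name in enumerate(classes)}


def label_from_path(path, class_to_idx):
    """Return the class index for an image stored as .../<class_name>/<file>.

    Assumes the class folder is the image's immediate parent, so neither the
    path separator nor the depth of dir_path matters.
    """
    return class_to_idx[Path(path).parent.name]


print(label_from_path(Path("data") / "carpet" / "0001.png", class_to_idx))  # 0
print(label_from_path(Path("/home/user/datasets/materials/metal/0042.png"), class_to_idx))  # 9
```

Since `Path.glob` yields paths in the running OS's own format, `Path(p).parent.name` works the same on Windows and on Linux/macOS.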
To use the dataset for training, wrap it in a `DataLoader` as follows (the `shuffle=True` argument makes the DataLoader visit the samples in a new random order every epoch).
```python
data_set = MyDataset(32, dir_path=root_data)
dataloader = torch.utils.data.DataLoader(data_set, batch_size=100, shuffle=True)
```
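What `DataLoader` does with `__len__`, `__getitem__`, `batch_size` and `shuffle` can be illustrated in plain Python. This is a simplified sketch, not PyTorch's implementation: the real DataLoader uses `Sampler` objects, stacks samples into tensors with `collate_fn`, and can load in worker processes. `iterate_batches` is a hypothetical name.

```python
import random


def iterate_batches(dataset, batch_size, shuffle=True, seed=None):
    """Yield (data_list, label_list) batches from a map-style dataset.

    Simplified illustration of the index logic only, not PyTorch's code.
    """
    indices = list(range(len(dataset)))  # relies on __len__
    if shuffle:
        # seed=None -> a fresh random order on every call, i.e. every epoch
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        samples = [dataset[i] for i in indices[start:start + batch_size]]  # relies on __getitem__
        data, labels = zip(*samples)  # unzip the (data, label) tuples
        yield list(data), list(labels)  # the last batch may be smaller


# A list of (data, label) tuples already supports len() and indexing
toy = [(f"img{i}", i % 3) for i in range(10)]
batches = list(iterate_batches(toy, batch_size=4, seed=0))
print([len(data) for data, _ in batches])  # [4, 4, 2]
```

As in this sketch, `DataLoader` keeps the incomplete last batch by default (`drop_last=False`), so with `batch_size=100` the final batch holds whatever is left over when the dataset size is not a multiple of 100.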
The complete code:

```python
import torch.utils.data
import torchvision.transforms as transforms
from pathlib import Path
from PIL import Image


class MyDataset(torch.utils.data.Dataset):
    def __init__(self, imageSize, dir_path, transform=None):
        # Default preprocessing, used only when no transform is passed in
        if transform is None:
            transform = transforms.Compose([
                transforms.Resize(imageSize),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])
        self.transform = transform
        self.image_paths = [str(p) for p in Path(dir_path).glob("**/*.png")]
        self.data_num = len(self.image_paths)
        self.classes = ['carpet', 'ceramic', 'cloth', 'dirt', 'drywall', 'glass', 'grass', 'gravel', 'leaf', 'metal']
        self.class_to_idx = {'carpet':0, 'ceramic':1, 'cloth':2, 'dirt':3, 'drywall':4, 'glass':5, 'grass':6,'gravel':7, 'leaf':8, 'metal':9}

    def __len__(self):
        return self.data_num

    def __getitem__(self, idx):
        p = self.image_paths[idx]
        image = Image.open(p).convert("RGB")  # always 3 channels
        out_data = self.transform(image) if self.transform else image
        # Label = name of the folder that directly contains the image
        out_label = self.class_to_idx[Path(p).parent.name]
        return out_data, out_label


if __name__ == "__main__":
    root_data = 'Path to data'
    data_set = MyDataset(32, dir_path=root_data)
    dataloader = torch.utils.data.DataLoader(data_set, batch_size=100, shuffle=True)
```
I implemented this while referring to the following articles. Thank you very much.

- Explanation of transforms, Datasets, Dataloader of pyTorch and creation and use of self-made Dataset
- PyTorch: Dataset and DataLoader (Image Processing Task)