Notez le Dataset / DataLoader utilisé lors de la création d'un dataset avec Pytorch
référence:
Pour le prétraitement des données, il existe une bibliothèque par torchvision.transforms
ou ʻalbumentations. L'opération de base est la même pour les deux. Créez une instance en compressant l'instance de la classe de prétraitement dans la liste et en l'utilisant comme argument de
Compose ().
Compose a une méthode
call (self, img)`, donc si vous mettez une image dans l'argument de l'instance créée, elle sera prétraitée.
import alubmentations as alb
def get_augmentation(phase):
transform_list = []
if phase == 'train':
transform_list.extend([albu.HorizonFlip(p=0.5),
albu.VerticalFlip(p=0.5)])
transform_list.extend([albu.Normalize(mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
p=1),
albu.ToTensor()
])
return albu.Compose(transform_list)
Dataset ** Un module qui récupère les données d'entrée et les étiquettes correspondantes une par une **. Lors du prétraitement des données, ** les transformations doivent être utilisées pour renvoyer les données prétraitées **.
**
--Héritage de Dataset
__getitem__
, __len__
Fondamentalement OK si ce qui précède est satisfait! Une instance de la classe d'héritage Dataset est le premier argument de DataLoder. (Plus tard pour Data Lodaer)
Par exemple, supposons que l'ensemble de données ait la structure de répertoires suivante.
datasets/ ____ train_images/
|__ test_images/
|__ train.csv
Cette fois, nous supposons que le fichier .csv contient des informations de chemin de données et d'étiquette pour l'ensemble de données.
import os.path as osp
import cv2
import pandas as pd
from torch.utils.data import Dataset
import torchvision.transforms as transforms
class MyDataManager(Dataset):
"""My Dataset
Args:
root(str): root path of dataset directory
df(DataFrame): DataFrame object from csv file
phase(str): train or test
"""
def __init__(self, root, df, phase):
super(MyDataManager, self).__init__()
self.root = root
self.df = df
self.phase = phase
self.transfoms = get_augmentation()
def __getitem__(self, idx):
img_path = osp.join(self.root, self.df.iloc[idx].name)
img = cv2.imread(img_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = self.transform(image=img)
label = osp.join(eslf.root, self.df.iloc[idx].value)
ret = {'image': img, 'label': label}
return ret
def __len__(self):
return len(self.df)
Cette fois, la valeur de retour est de type dict, mais il n'y a pas de problème avec return image, label
.
Lors de l'exécution d'une segmentation, etc., il est nécessaire de donner l'étiquette en tant qu'image de masque, donc dans ce cas, transférez également l'image de masque.
DataLoader
Les données récupérées par Datset peuvent être utilisées comme argument de DataLoader. La structure des arguments de DataLoader est la suivante
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None)
Par conséquent, créez la fonction suivante.
def dataloader(dir_path,phase,batch_size, num_workers, shuffle=False):
df_path = osp.join(dir_path, 'train.csv')
df = pd.read_csv(df_path)
dataset = MyDataManager(dir_path, df, phase)
dl = DataLoader(dataset,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=True,
shuffle=shuffle)
return dl
Recommended Posts