Learn how to cross-validate when using a Dataset with PyTorch.
You can use torch.utils.data.dataset.Subset to split a Dataset by specifying a list of indices. Combine this with the splitters in scikit-learn's sklearn.model_selection module.
train_test_split
Use sklearn.model_selection.train_test_split to split the indices into train_index and valid_index, and use Subset to split the Dataset.
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import Subset
from sklearn.model_selection import train_test_split
# Load the full dataset; the split below works on sample positions, not on the samples themselves.
dataset = get_dataset()
batch_size = 16

# Hold out 30% of the sample positions for validation.
all_indices = range(len(dataset))
train_index, valid_index = train_test_split(all_indices, test_size=0.3)

# Wrap each index list in a Subset so both views share the same underlying dataset.
train_dataset = Subset(dataset, train_index)
valid_dataset = Subset(dataset, valid_index)

# Shuffle only the training batches; validation batches keep a fixed order.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

# Training code goes here
Use sklearn.model_selection.KFold to split the indices into train_index and valid_index for each fold, and use Subset to split the Dataset.
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import Subset
from sklearn.model_selection import KFold
# Load the full dataset once; every fold is a pair of index-based views onto it.
dataset = get_dataset()
batch_size = 16
kf = KFold(n_splits=3)

# Split over the sample positions of the dataset: the indices KFold yields
# are then valid positions to hand straight to Subset.
for _fold, (train_index, valid_index) in enumerate(kf.split(range(len(dataset)))):
    # Build this fold's training and validation views of the dataset.
    train_dataset = Subset(dataset, train_index)
    train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True)
    valid_dataset = Subset(dataset, valid_index)
    valid_dataloader = DataLoader(valid_dataset, batch_size, shuffle=False)

    # Training code for this fold goes here
If it is a classification Dataset, you should be able to get the labels y by using dataset[:][1] (this works for a TensorDataset, for example), so you should be able to do StratifiedKFold as well.
Recommended Posts