Wrap an existing Dataset for N-fold cross-validation by assigning its samples to folds in a repeating 1, 2, ..., N, 1, 2, ..., N, ... pattern instead of in contiguous blocks. When splitting time-series data, this keeps any single fold from being concentrated in a particular month or season.
from torch.utils.data import Dataset
class LayeredFoldWrapper(Dataset):
    """Split a dataset into interleaved (strided) folds for cross-validation.

    Samples are assigned to folds round-robin by index, so every fold draws
    evenly from the whole sequence. For time-series data this keeps a fold
    from being concentrated in one period (e.g. a single month or season),
    which a contiguous K-fold split would do.

    Index ``i`` belongs to the validation split of fold ``fold`` when
    ``i % n_splits == n_splits - fold - 1``; all other indices form the
    training split.

    Args:
        dataset: Object supporting ``len()`` and integer indexing.
        n_splits: Number of folds (must be >= 1).
        fold: Which fold to use, ``0 <= fold <= n_splits - 1``.
        valid: If True, expose the validation indices; otherwise the
            training indices. Read on every ``len()``/index access, so it
            can be toggled after construction.

    Raises:
        ValueError: If ``fold`` is outside ``[0, n_splits - 1]``.
    """

    def __init__(self, dataset, n_splits=5, fold=0, valid=False):
        self.dataset = dataset
        self.n_splits = n_splits
        self.fold = fold
        self.valid = valid
        n_total = len(dataset)
        self.valid_index = list(self._valid_index(n_total, n_splits, fold))
        # Build the complement in ascending index order; a set difference
        # does not guarantee any particular ordering.
        valid_set = set(self.valid_index)
        self.train_index = [i for i in range(n_total) if i not in valid_set]

    def __len__(self):
        return len(self._get_index_list(self.valid))

    def __getitem__(self, i):
        # Map the split-local position ``i`` to an index in the wrapped dataset.
        return self.dataset[self._get_index_list(self.valid)[i]]

    def _valid_index(self, N, n_splits, fold):
        """Return the validation indices of ``fold`` as a ``range``.

        Args:
            N: Total number of samples in the wrapped dataset.
            n_splits: Number of folds.
            fold: Fold selector, ``0 <= fold <= n_splits - 1``.

        Raises:
            ValueError: If ``fold`` is out of range.
        """
        # Explicit check instead of ``assert``, which is stripped under -O.
        if not 0 <= fold <= n_splits - 1:
            raise ValueError(
                f"fold must satisfy 0 <= fold <= n_splits - 1, "
                f"got fold={fold}, n_splits={n_splits}"
            )
        # Stop at N (exclusive): valid dataset indices are 0 .. N-1.
        return range(n_splits - fold - 1, N, n_splits)

    def _get_index_list(self, valid):
        """Return the validation index list if ``valid`` else the training one."""
        return self.valid_index if valid else self.train_index
Recommended Posts