[PyTorch] Make sure the model and dataset are in cuda mode


PyTorch makes it easy to switch tensors between the CPU and the GPU with `hoge.to(device)` and similar calls. However, I often lose track of whether a given dataset or model is currently on the CPU or the GPU, so I'll write down how to check.
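One detail that makes this confusing is that `.to(device)` behaves differently for tensors and for models. The sketch below assumes a tiny hypothetical `nn.Linear` layer standing in for a real model. `Tensor.to` returns a tensor on the target device and leaves the original where it was. `nn.Module.to` moves the module's parameters in place and returns the module itself.

```python
import torch
import torch.nn as nn

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

x = torch.zeros(3)           # tensors are created on the CPU by default
y = x.to(DEVICE)             # returns a tensor on DEVICE; x itself is not moved
print(x.device, y.device)    # x stays on cpu; y is on DEVICE

layer = nn.Linear(4, 2)      # hypothetical tiny model for illustration
same = layer.to(DEVICE)      # nn.Module.to moves parameters in place and returns self
print(same is layer, next(layer.parameters()).device)
```

For tensors, you therefore have to keep the return value (`data = data.to(DEVICE)`). For a model, `model.to(DEVICE)` on its own is enough, although reassigning is harmless.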

How to check

As a prerequisite, prepare the dataset and the model:

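```python
import os
import torch
from torchvision import datasets, transforms

TRAIN = 'train'
DATA_DIR = 'dataset/predata/'
BATCH_SIZE = 32  # assumed value; the original does not show it
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

data_transforms = {
    TRAIN: transforms.Compose([
        # the original transform list is elided; these two are assumed placeholders
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]),
}

# Data preprocessing
image_datasets = {x: datasets.ImageFolder(os.path.join(DATA_DIR, x), data_transforms[x]) for x in [TRAIN]}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                              batch_size=BATCH_SIZE, shuffle=True, num_workers=4) for x in [TRAIN]}
```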

The `is_cuda` attribute returns a bool that tells you whether a tensor is on the GPU. For a model, following the answer [here](https://discuss.pytorch.org/t/how-to-check-if-model-is-on-cuda/180/2?u=kazuki_hamaguchi)(1), check `is_cuda` on one of its parameters.
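To see `is_cuda` in isolation, without an image folder on disk, here is a small sketch. It assumes a synthetic `TensorDataset` as a stand-in for the `ImageFolder` dataset above. Batches coming out of a `DataLoader` start on the CPU and only report `True` after being moved explicitly.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# synthetic stand-in for the ImageFolder dataset: 8 fake 3x4x4 "images" with labels
dataset = TensorDataset(torch.randn(8, 3, 4, 4), torch.randint(0, 2, (8,)))
loader = DataLoader(dataset, batch_size=4, shuffle=True)

data, label = next(iter(loader))
print(data.is_cuda, label.is_cuda)   # False False: DataLoader batches start on the CPU

data, label = data.to(DEVICE), label.to(DEVICE)
print(data.is_cuda, label.is_cuda)   # True True only if a GPU was found
```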

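```python
# `model` is assumed to be an nn.Module defined earlier (not shown in this post)
data, label = next(iter(dataloaders[TRAIN]))
data = data.to(DEVICE)
label = label.to(DEVICE)
print(data.is_cuda, label.is_cuda)  # True when the batch is on the GPU

model = model.to(DEVICE)
print(next(model.parameters()).is_cuda)  # checks the model via its first parameter
```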

You can check it with this.
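Checking only the first parameter assumes the whole model sits on one device. It also raises `StopIteration` for a model with no parameters. The sketch below is a slightly more defensive variant and is not taken from the referenced thread. It collects the devices of all parameters and buffers. The helper names `model_devices` and `is_model_on_cuda`, and the small `nn.Sequential` model, are hypothetical.

```python
import torch
import torch.nn as nn

def model_devices(model: nn.Module) -> set:
    """Return the set of devices used by the model's parameters and buffers."""
    tensors = list(model.parameters()) + list(model.buffers())
    return {t.device for t in tensors}

def is_model_on_cuda(model: nn.Module) -> bool:
    """True only if the model has tensors and every one lives on a CUDA device."""
    devices = model_devices(model)
    return bool(devices) and all(d.type == "cuda" for d in devices)

# hypothetical model for illustration; BatchNorm1d adds buffers (running stats)
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8))
print(model_devices(model))      # {device(type='cpu')} right after construction
print(is_model_on_cuda(model))   # False until the model is moved to a GPU
```

If the returned set ever contains more than one device, part of the model was moved and part was not.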

At the end

I wonder what happens to the device (CPU or GPU) of a variable before and after it leaves its local scope. It might be interesting to check.
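One way to try this out is sketched below, assuming a hypothetical helper `make_batch`. The device is stored on the tensor object itself, so returning a tensor from a function should not change where its data lives. Only an explicit `.to()` or `.cpu()` does that.

```python
import torch

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def make_batch():
    # hypothetical helper: a local tensor created inside the function and moved to DEVICE
    local = torch.ones(2, 3).to(DEVICE)
    print("inside:", local.device)
    return local

batch = make_batch()
print("outside:", batch.device)   # same device as inside; leaving scope does not move data
back = batch.cpu()                # only an explicit .cpu() / .to() changes the device
print(back.device, batch.device)  # batch itself keeps its original device
```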

Reference material

(1) How to check if Model is on cuda
