[PyTorch] Be careful with operations between different types across versions.

TL;DR The result of an operation (torch.sub) between a float and a torch.int64 tensor differs between torch <= 1.2.0 and torch >= 1.4.0.

In torch <= 1.2.0, an operation between a float and a torch.int64 tensor returns torch.int64. In torch >= 1.4.0, the same operation returns torch.float32.

With torch <= 1.2.0, the fractional part of the float is lost in the result. When performing such calculations, it is important to unify the execution environments or to cast explicitly.
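A minimal sketch of the explicit cast, assuming the goal is a floating-point result in every version: converting the int64 tensor to torch.float32 before subtracting means both operands are already floating point, so the result should no longer depend on the type-promotion rules of the installed release. The names `value` and `tensor_int` mirror the transcripts below; I use `value` instead of `float` so the built-in `float` is not shadowed.

```python
import torch

value = 3.14  # Python float (named `value` so the built-in `float` is not shadowed)
tensor_int = torch.tensor(3, dtype=torch.int64)

# Cast the integer tensor explicitly so both operands are floating point
# before the subtraction, instead of relying on implicit type promotion.
ans = torch.sub(value, tensor_int.to(torch.float32))
print(ans, ans.dtype)
```

I have only checked this reasoning against recent releases; if you still run 1.1.0 or 1.2.0, it is worth confirming the dtype there as well.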

Introduction

This is my first post on Qiita. I'm A (Twitter).

I usually use PyTorch to build neural networks and to study. While doing so, I ran into this problem when performing numerical operations across different versions of PyTorch, so I am writing it down as a memo.

If there are any mistakes, I would appreciate it if you could point them out in the comments.

Background

First of all, why do I go back and forth between different versions of PyTorch?

Because I have been going to the laboratory less often due to COVID-19, I thought research would be easier if I set up an environment on my gaming PC at home, so I built the environment as follows.

Then, when I ran the code from the laboratory at home, it worked normally, so I thought "Oh, nice!" and went on enjoying a comfortable research life.

Problem

After that, I wrote code and checked that it ran at home, then trained and validated the NN in the laboratory, and noticed that the two environments produced different results.

Below is a Python interactive session that reproduces the problem.

Laboratory environment

>>> import torch
>>> torch.__version__
'1.1.0'
>>> float = 3.14
>>> tensor_int = torch.tensor(3, dtype=torch.int64)
>>>
>>> type(float)
<class 'float'>
>>> tensor_int.dtype
torch.int64
>>>
>>> ans = torch.sub(float, tensor_int)
>>> ans
tensor(0)
>>>
>>> ans.dtype
torch.int64
>>>  

Home environment

>>> import torch
>>> torch.__version__
'1.5.0'
>>> float = 3.14
>>> tensor_int = torch.tensor(3, dtype=torch.int64)
>>>
>>> type(float)
<class 'float'>
>>> tensor_int.dtype
torch.int64
>>>
>>> ans = torch.sub(float, tensor_int)
>>> ans
tensor(0.1400)
>>>
>>> ans.dtype
torch.float32
>>>  

As you can see, the dtype of the result "ans" is torch.int64 in the laboratory environment and torch.float32 in the home environment. In other words, with torch == 1.1.0, the fractional part of "float" was lost in "ans".

This is probably caused by a difference between torch versions. (I think torch == 1.5.0 fixed the torch.int64 behavior. Thank you, PyTorch.)
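In recent releases you can ask PyTorch which dtype a mixed-type operation will produce without running it, via `torch.result_type`. A minimal sketch follows; I have not checked which release first shipped `torch.result_type`, so it may not exist in 1.1.0 or 1.2.0.

```python
import torch

tensor_int = torch.tensor(3, dtype=torch.int64)

# Ask the type-promotion rules which dtype an (int64 tensor, Python float) operation yields.
# With the default dtype left at torch.float32, a float scalar promotes an int tensor to float32.
promoted = torch.result_type(tensor_int, 3.14)
print(promoted)
```

The result follows `torch.get_default_dtype()`, so code that changes the default dtype will see a different answer.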

Verification

Suspecting that the behavior depends on the torch version, I checked where between torch == 1.1.0 and 1.5.0 it changed.

The verification environment is as follows.

The versions I verified are torch == 1.1.0, 1.2.0, 1.4.0, and 1.5.0.

(1.3.0 was not available in the PyTorch Archive, so I have not verified it.)

torch == 1.1.0 (repost)

>>> import torch
>>> torch.__version__
'1.1.0'
>>> float = 3.14
>>> tensor_int = torch.tensor(3, dtype=torch.int64)
>>>
>>> ans = torch.sub(float, tensor_int)
>>> ans
tensor(0)
>>>
>>> ans.dtype
torch.int64
>>>  

torch==1.2.0

>>> import torch
>>> torch.__version__
'1.2.0'
>>> float = 3.14
>>> tensor_int = torch.tensor(3, dtype=torch.int64)
>>>
>>> ans = torch.sub(float, tensor_int)
>>> ans
tensor(0)
>>>
>>> ans.dtype
torch.int64
>>>  

torch==1.4.0

>>> import torch
>>> torch.__version__
'1.4.0'
>>> float = 3.14
>>> tensor_int = torch.tensor(3, dtype=torch.int64)
>>>
>>> ans = torch.sub(float, tensor_int)
>>> ans
tensor(0.1400)
>>>
>>> ans.dtype
torch.float32
>>>      

torch == 1.5.0 (repost)

>>> import torch
>>> torch.__version__
'1.5.0'
>>> float = 3.14
>>> tensor_int = torch.tensor(3, dtype=torch.int64)
>>>
>>> ans = torch.sub(float, tensor_int)
>>> ans
tensor(0.1400)
>>>
>>> ans.dtype
torch.float32
>>>  

From these results, the behavior changed somewhere between torch == 1.2.0 and torch == 1.4.0 (1.3.0 was not tested), so it does depend on the torch version.

This is probably explained in the official documentation or the PyTorch 1.4 release notes (/pytorch/pytorch/releases/tag/v1.4.0), but I couldn't find it.
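To repeat the check in each environment without typing the session by hand, the same subtraction can be wrapped in a small script. This is a sketch; the function name `check_float_int64_sub` is hypothetical and just reproduces the transcripts above.

```python
import torch


def check_float_int64_sub():
    """Re-run the float - int64 subtraction from the transcripts and report the result."""
    value = 3.14  # Python float, as in the sessions above
    tensor_int = torch.tensor(3, dtype=torch.int64)
    ans = torch.sub(value, tensor_int)
    return ans.dtype, ans.item()


if __name__ == "__main__":
    dtype, result = check_float_int64_sub()
    print("torch", torch.__version__, "->", dtype, result)
    if dtype == torch.int64:
        print("Warning: the fractional part of 3.14 was lost (integer result).")
```

Running it once per environment and comparing the printed lines is enough to spot the difference.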

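In conclusion

Under the title "Operations of different types between different versions", I focused on operations between a float and a torch.int64 tensor across torch versions and verified the differences in the output.

What I can say is that in torch <= 1.2.0 such an operation returns torch.int64 and drops the fractional part, while in torch >= 1.4.0 it returns torch.float32. It is therefore important to unify the execution environments or to cast explicitly.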

Because CUDA 9.0 was installed in my laboratory environment, I had settled for torch == 1.1.0, which led to this result. Taking this opportunity, I upgraded Python, CUDA, and torch in the laboratory to match the home environment.

Why not take another look at your own development environment?
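One way to keep environments aligned is to fail fast when an older torch is installed. This is a sketch under my own assumptions: the helper `torch_version_tuple` and the constant `MIN_TORCH` are hypothetical, and 1.4.0 is chosen only because it is the oldest release in which I observed float results (1.3.0 was not tested).

```python
import re

import torch


def torch_version_tuple(version_string):
    """Turn '1.5.0', '1.5', or '2.1.0+cu118' into a (major, minor, patch) tuple."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version_string)
    if match is None:
        raise ValueError(f"unrecognized version string: {version_string!r}")
    # A missing patch component (e.g. '1.5') is treated as 0.
    major, minor, patch = match.groups(default="0")
    return int(major), int(minor), int(patch)


# Hypothetical minimum: oldest version in which float - int64 gave torch.float32 in my tests.
MIN_TORCH = (1, 4, 0)

if torch_version_tuple(torch.__version__) < MIN_TORCH:
    raise RuntimeError(f"torch {torch.__version__} is older than the required version")
```

Comparing tuples avoids the string-comparison trap where "1.10.0" sorts before "1.4.0".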
