[PyTorch] RuntimeError: Expected object of scalar type Float but got scalar type Double for argument #4 'mat1'

Error in PyTorch

I'm writing this down for anyone who gets stuck on the same problem. While using PyTorch, I got the following error:

RuntimeError: Expected object of scalar type Float but got scalar type Double for argument #4 'mat1'

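After some investigation, the cause seems to be the data type of the tensors. NumPy arrays are float64 by default, and torch.from_numpy keeps the NumPy dtype, so the resulting tensor is torch.double (torch.float64). Many PyTorch modules, such as nn.Linear, create their parameters as torch.float (torch.float32) by default, so feeding them a double tensor causes this type mismatch.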
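The mismatch can be reproduced in a few lines. This is a minimal sketch, assuming a plain nn.Linear layer and random NumPy data as stand-ins for the real model and dataset; the exact wording of the error message differs between PyTorch versions, but it is a RuntimeError either way.

```python
import numpy as np
import torch
import torch.nn as nn

# NumPy creates float64 arrays by default (np.random.rand is one example).
X = np.random.rand(4, 3)
print(X.dtype)  # float64

# torch.from_numpy keeps the NumPy dtype, so this tensor is torch.float64 (torch.double).
X_tensor = torch.from_numpy(X)
print(X_tensor.dtype)  # torch.float64

# nn.Linear creates its weights as torch.float32 (torch.float) by default.
layer = nn.Linear(3, 2)
print(layer.weight.dtype)  # torch.float32

# Passing a double tensor into a float layer raises a dtype-mismatch RuntimeError.
try:
    layer(X_tensor)
    failed = False
except RuntimeError as e:
    failed = True
    print("RuntimeError:", e)
```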

So I changed the conversion as follows.

Before the fix:

X_train = torch.from_numpy(X_train)
y_train = torch.from_numpy(y_train)
X_test = torch.from_numpy(X_test)
y_test = torch.from_numpy(y_test)

After the fix:

X_train = torch.from_numpy(X_train).float()
y_train = torch.from_numpy(y_train).long()
X_test = torch.from_numpy(X_test).float()
y_test = torch.from_numpy(y_test).long() 

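Converting the features with .float() and the labels with .long() like this fixes the error. (.long() is for the labels: it converts them to int64, which is the dtype loss functions such as nn.CrossEntropyLoss expect for class-index targets.)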
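To check that the converted tensors actually work together, here is a small end-to-end sketch. The toy data, the nn.Linear model, and nn.CrossEntropyLoss are assumptions standing in for whatever the original script used; only the .float() / .long() conversion comes from the fix above.

```python
import numpy as np
import torch
import torch.nn as nn

# Hypothetical toy data in place of the real X_train / y_train.
X_train = np.random.rand(8, 3)              # float64 features
y_train = np.random.randint(0, 2, size=8)   # integer labels (integer width is platform-dependent)

# The fix: features to float32, labels to int64.
X_train = torch.from_numpy(X_train).float()
y_train = torch.from_numpy(y_train).long()

model = nn.Linear(3, 2)            # float32 weights by default
criterion = nn.CrossEntropyLoss()  # expects int64 (long) class-index targets

logits = model(X_train)            # no dtype mismatch any more
loss = criterion(logits, y_train)
loss.backward()
print(logits.dtype, y_train.dtype, loss.item())
```

Note that .float() returns a new float32 tensor, so the result no longer shares memory with the original NumPy array the way the plain torch.from_numpy tensor does.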

References: 2nd PyTorch Tensor & Data Type Cheat Sheet
