I didn't have a good grasp of batch normalization, so I tried it out with PyTorch. What I confirmed is that it adjusts the input data so that each column has a mean of 0 and a variance of 1. I also noticed a few things while running it, so I'll note them here.
First, the imports.
import torch
import torch.nn.functional as F
from torch import nn
Decide on the input data size and generate random values.
input_samples = 100
input_features = 10
x = torch.rand((input_samples, input_features)) * 10
The columns are not drastically different, but this produces data whose mean and variance vary from column to column.
Mean
torch.mean(x, 0)
tensor([5.0644, 5.0873, 5.0446, 5.3872, 5.2406, 5.3518, 5.3203, 4.9909, 5.0590,
5.2169])
Variance
torch.var(x, 0)
tensor([ 9.4876, 8.6519, 8.4050, 9.8280, 10.0146, 8.6054, 7.0800, 8.6111,
7.7851, 8.5604])
Let's apply batch normalization
batch_norm = nn.BatchNorm1d(input_features)
y = batch_norm(x)
Mean
torch.mean(y, 0)
tensor([ 1.9073e-08, 5.2452e-08, -4.7684e-09, 3.8743e-08, -3.8147e-08,
4.1723e-08, -7.8678e-08, -5.9605e-08, 5.7220e-08, 4.2915e-08],
grad_fn=<MeanBackward1>)
⇒ Almost 0.
Variance
torch.var(y, 0)
tensor([1.0101, 1.0101, 1.0101, 1.0101, 1.0101, 1.0101, 1.0101, 1.0101, 1.0101,
1.0101], grad_fn=<VarBackward1>)
⇒ Almost 1.
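For reference, here is a minimal sketch of what I understand BatchNorm1d to be doing at training time, assuming its default settings (eps=1e-5, and the learnable weight and bias still at their initial values of 1 and 0): each column is shifted by its mean and divided by the square root of its biased variance.
# Sketch of the per-column normalization (assumption: default eps, affine params untouched).
eps = 1e-5
manual = (x - torch.mean(x, 0)) / torch.sqrt(torch.var(x, 0, unbiased=False) + eps)
# Should match the BatchNorm1d output up to floating-point error.
print(torch.allclose(manual, y, atol=1e-5))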
Here are some notes on things I noticed while changing the input in various ways.
With only one sample, the following error is raised.
ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 10])
With a single sample, the mean is the data itself and the variance is 0, so there is nothing meaningful to compute. Forcing the mean to 0 would turn every value in the column into 0, and of course the variance would then be 0 rather than 1.
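As a quick check (an assumed snippet reusing the batch_norm defined above), this reproduces the error:
single = torch.rand((1, input_features)) * 10
try:
    batch_norm(single)  # training mode needs more than one sample per channel
except ValueError as e:
    print(e)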
With 3 samples the checked variance comes out as 1.5, with 10 samples as 1.111, and so on; it approaches 1 as the data size grows. I haven't dug into it deeply, but it appears to come from the formulas: BatchNorm normalizes using the biased variance (dividing by N), while torch.var defaults to the unbiased estimator (dividing by N-1), so the check yields roughly N/(N-1). See the documentation for details, and the quick check sketched below.
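Here is that quick check, a rough sketch that prints the unbiased variance of the normalized output for a few batch sizes:
for n in (3, 10, 100):
    bn = nn.BatchNorm1d(input_features)
    out = bn(torch.rand((n, input_features)) * 10)
    # torch.var is unbiased by default, so this prints roughly n / (n - 1).
    print(n, torch.var(out, 0)[0].item())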
If batch normalization is applied immediately after the input, can it be used to normalize the input data itself? Searching around, I found the following Q&A.
https://www.366service.com/jp/qa/9a05f9f614c8ca449ef8693928b7921c
The answer there was that it is simpler and more efficient to compute the mean and variance over the whole dataset once up front. That is certainly true, but it's a hassle!
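For completeness, a minimal sketch of that suggested approach (assuming x holds the full set of input data): compute the statistics once and normalize the inputs before training.
# Normalize the input data once using statistics over the whole dataset.
mean = torch.mean(x, 0)
std = torch.std(x, 0)
x_normalized = (x - mean) / std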