I implemented Shake-Shake Regularization (ShakeNet) with PyTorch

What is Shake-Shake Regularization?

It is a regularization method. Because it effectively increases the variety of the training data (a kind of pseudo data augmentation), it has the advantage that the network can be trained slowly over a long schedule.

I wonder whether it works well as a form of data augmentation when data is scarce. For now, I will try it on CIFAR10.

shake-shake-regularization.jpg

The figure above gives a brief overview of Shake-Shake Regularization. In a ResNet, two residual blocks are placed in parallel, and the following operations are applied to their outputs:

- During training, in forward propagation, multiply the outputs of the two blocks by a random number α between 0 and 1 and by 1 − α, respectively
- During backpropagation, multiply the gradients by a random number β between 0 and 1 (drawn independently of α) and by 1 − β
- At inference, use the constant 0.5 instead of a random number
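The three rules can also be written without a custom autograd class. The sketch below uses an alternative formulation based on `Tensor.detach()` (not the approach this article uses; the article defines a ShakeShake class later). The forward value equals α·x1 + (1 − α)·x2, while gradients flow as if the weights were β and 1 − β. The function name `shake_shake_mix` is hypothetical.

```python
import torch


def shake_shake_mix(x1, x2, training=True):
    """Combine two residual-block outputs Shake-Shake style (hypothetical helper)."""
    if not training:
        # Inference: fixed weight 0.5 on each block
        return 0.5 * (x1 + x2)
    alpha = torch.rand(1).item()  # forward weight
    beta = torch.rand(1).item()   # backward weight, drawn independently
    # This term carries the gradient: d/dx1 = beta, d/dx2 = 1 - beta
    grad_path = beta * x1 + (1 - beta) * x2
    # Detached correction (no gradient) shifts the value to alpha*x1 + (1-alpha)*x2
    correction = ((alpha - beta) * (x1 - x2)).detach()
    return grad_path + correction
```

Adding the detached term changes only the value, not the gradient, which is why one expression can use α in the forward pass and β in the backward pass.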

For details, please refer to other write-ups. I found this article helpful for understanding it: https://qiita.com/masataka46/items/fc7f31073c89b02f8a04

Other details from the paper

The paper also describes various other details:

- The flow of the Plain architecture is ReLU → Conv → BN → ReLU → Conv → BN → Mul (by the random number α)
- The network is divided into 3 stages, each with 4 residual blocks
- The stages have 32, 64 and 128 channels, respectively
- A 3x3 Conv is applied before stage 1
- 8x8 average pooling is applied after stage 3
- The last layer is a fully connected (fc) layer
- The initial learning rate is 0.2, and it is varied along a cosine curve
- Training runs for 1800 epochs
- A special shortcut is used for downsampling
- Images are standardized and randomly flipped
- The mini-batch size is 128
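To see how the stage layout above fits together, here is a small sketch that expands it into per-block (in_channels, out_channels, stride) settings and tracks the feature-map size for a 32x32 CIFAR10 input. The 16-channel output of the initial 3x3 Conv and the stride of 2 in the first block of stages 2 and 3 are assumptions (common in Shake-Shake implementations but not stated above). The helper name `block_specs` is hypothetical.

```python
def block_specs(stem_channels=16, stage_channels=(32, 64, 128),
                blocks_per_stage=4, input_size=32):
    """Expand the stage layout into per-block settings (hypothetical helper)."""
    specs = []
    in_ch, size = stem_channels, input_size  # assumed output of the 3x3 stem conv
    for stage, out_ch in enumerate(stage_channels):
        for b in range(blocks_per_stage):
            # Assumption: only the first block of stages 2 and 3 downsamples
            stride = 2 if (b == 0 and stage > 0) else 1
            size //= stride
            specs.append({"stage": stage + 1, "in": in_ch, "out": out_ch,
                          "stride": stride, "size": size})
            in_ch = out_ch
    return specs


specs = block_specs()
for s in specs:
    print(s)
```

Under these assumptions the last stage outputs an 8x8 feature map, which is why an 8x8 average pool reduces it to 1x1 before the fc layer. Every block that changes the channel count doubles it (16→32, 32→64, 64→128), which the downsampling shortcut in the code below relies on when it concatenates two in_channels-wide paths.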

Creating a residual block

ResNet comes in two block types, the plain architecture and the bottleneck architecture. Following the paper, I use the Plain architecture here.

test.py


import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualPlainBlock(nn.Module):

    def __init__(self, in_channels, out_channels, stride, padding=0):
        super(ResidualPlainBlock, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        # 1st residual block: ReLU -> Conv -> BN -> ReLU -> Conv -> BN
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # 2nd residual block: same structure, separate weights
        self.conv1_2 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1_2 = nn.BatchNorm2d(out_channels)

        self.conv2_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2_2 = nn.BatchNorm2d(out_channels)

        self.identity = nn.Identity()

        if in_channels != out_channels:
            # Two-path shortcut: each path outputs in_channels, so the concatenation
            # has 2 * in_channels channels (assumes out_channels == 2 * in_channels).
            # AvgPool2d with kernel_size=1, stride=1 is effectively a no-op.
            self.down_avg1 = nn.AvgPool2d(kernel_size=1, stride=1)
            self.down_conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=stride, padding=0)
            # Pad right/bottom so that x[:, :, 1:, 1:] becomes a one-pixel shift of x
            self.down_pad1 = nn.ZeroPad2d((0, 1, 0, 1))
            self.down_avg2 = nn.AvgPool2d(kernel_size=1, stride=1)
            self.down_conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=stride, padding=0)

    # Special processing during downsampling
    def shortcut(self, x):
        x = F.relu(x)
        h1 = self.down_avg1(x)
        h1 = self.down_conv1(h1)
        h2 = self.down_pad1(x[:, :, 1:, 1:])
        h2 = self.down_avg2(h2)
        h2 = self.down_conv2(h2)
        return torch.cat((h1, h2), dim=1)

    def forward(self, x):
        # 1st residual block
        out = self.bn1(self.conv1(F.relu(x)))
        out = self.bn2(self.conv2(F.relu(out)))

        # 2nd residual block
        out2 = self.bn1_2(self.conv1_2(F.relu(x)))
        out2 = self.bn2_2(self.conv2_2(F.relu(out2)))

        if self.training:
            # Training: random weights alpha (forward) and beta (backward);
            # ShakeShake is defined in the next snippet
            branches = ShakeShake.apply(out, out2)
        else:
            # Inference: fixed weight 0.5 on each block
            branches = (out + out2) * 0.5

        if self.in_channels != self.out_channels:
            return self.shortcut(x) + branches
        return self.identity(x) + branches

The constructor is verbose, but the forward function shows what the block does.

The forward function does the following:

1. Pass the input x to both residual blocks
2. Get their outputs, out and out2
3. Combine out and out2 with **ShakeShake.apply()** during training (at inference, average them with weight 0.5 each)
4. Add the shortcut to the result of step 3 and return the sum
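The special downsampling shortcut used in step 4 (when the channel count changes) is easier to follow in isolation. Below is a minimal functional sketch, assuming the usual formulation: ReLU, then two paths that each apply a strided 1x1 conv, with the second path shifted by one pixel, concatenated along the channel axis. Many implementations also add a BatchNorm after the concatenation; it is omitted here. The function name `downsample_shortcut` and the random weights are illustrative only.

```python
import torch
import torch.nn.functional as F


def downsample_shortcut(x, w1, w2, stride=2):
    """Two-path shortcut for blocks whose channel count changes (sketch)."""
    x = F.relu(x)
    # Path 1: strided 1x1 conv samples even rows/columns (0, 2, 4, ...)
    h1 = F.conv2d(x, w1, stride=stride)
    # Path 2: shift by one pixel (drop first row/column, zero-pad bottom/right),
    # so the same stride now samples odd rows/columns (1, 3, 5, ...)
    shifted = F.pad(x[:, :, 1:, 1:], (0, 1, 0, 1))
    h2 = F.conv2d(shifted, w2, stride=stride)
    # Concatenate along channels: out_channels = 2 * (channels of each path)
    return torch.cat((h1, h2), dim=1)


in_ch, out_ch = 32, 64
w1 = torch.randn(out_ch // 2, in_ch, 1, 1)
w2 = torch.randn(out_ch // 2, in_ch, 1, 1)
y = downsample_shortcut(torch.randn(8, in_ch, 32, 32), w1, w2)
print(y.shape)  # torch.Size([8, 64, 16, 16])
```

Padding on the right and bottom is what makes the second path a genuine one-pixel shift, so together the two paths see complementary pixels. Padding on the left and top after dropping the first row and column would instead leave every pixel in place and only zero out the border.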

Note 1: ShakeShake.apply()

By subclassing torch.autograd.Function, you can define both the forward and the backward computation yourself. Here the class is called ShakeShake.

test.py


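import random

import torch


class ShakeShake(torch.autograd.Function):
  @staticmethod
  def forward(ctx, i1, i2):
    # One random alpha in [0, 1) for the whole mini-batch
    alpha = random.random()
    result = i1 * alpha + i2 * (1 - alpha)
    return result

  @staticmethod
  def backward(ctx, grad_output):
    # A new random beta, drawn independently of alpha
    beta = random.random()
    # Gradients for i1 and i2, in the same order as forward's inputs
    return grad_output * beta, grad_output * (1 - beta)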

In forward, a random number alpha is generated, and out and out2 are multiplied by alpha and 1 − alpha, respectively.

In backward, a new random number beta is generated, and grad_output (the gradient propagated back from later layers) is multiplied by beta and 1 − beta to give the gradients for the two blocks.
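The ShakeShake class above draws a single alpha (and beta) for the whole mini-batch. The paper also compares drawing the coefficients independently for each image and uses that per-image setting for its best CIFAR10 results. Here is a sketch of the per-image version; the class name `ShakeShakeImage` is hypothetical, and the coefficient shape (N, 1, 1, 1) assumes 4-D NCHW inputs.

```python
import torch


class ShakeShakeImage(torch.autograd.Function):
    """Shake-Shake with one alpha/beta per image instead of per mini-batch (sketch)."""

    @staticmethod
    def forward(ctx, i1, i2):
        # One coefficient per sample, broadcast over channels, height and width
        alpha = torch.rand(i1.size(0), 1, 1, 1, device=i1.device, dtype=i1.dtype)
        return i1 * alpha + i2 * (1 - alpha)

    @staticmethod
    def backward(ctx, grad_output):
        # Fresh per-sample beta, drawn independently of alpha
        beta = torch.rand(grad_output.size(0), 1, 1, 1,
                          device=grad_output.device, dtype=grad_output.dtype)
        return grad_output * beta, grad_output * (1 - beta)
```

It is called the same way, e.g. `ShakeShakeImage.apply(out, out2)` in the training branch of the block.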

Note 2: Varying the learning rate along a cosine curve

PyTorch allows you to schedule learning rates.

It can be implemented as follows.

test.py


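import torch.optim as optim

# `net` is assumed to be the ShakeNet model (its definition is not shown here)
learning_rate = 0.02
optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0.0001)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=0.001)

for i in range(200):
  # Train for one epoch here (optimizer.step() once per mini-batch)...
  scheduler.step()  # update the learning rate once per epoch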

With this setup, the learning rate follows a cosine curve and is updated once per epoch.

After defining the optimizer, create a **CosineAnnealingLR** scheduler.

The first argument is the optimizer, the second (T_max) is the number of steps (here, epochs) in half a cosine cycle, and the third (eta_min) is the minimum learning rate.

In the example above, the learning rate falls from 0.02 to 0.001 over the first 50 epochs, rises back to 0.02 over the next 50, falls again over the following 50, and so on.
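To check the schedule without training a model, the sketch below records the learning rate used in each of the 200 epochs with the same optimizer and scheduler settings. A single dummy parameter stands in for `net.parameters()` (an assumption for illustration); since it never receives a gradient, `optimizer.step()` leaves it unchanged and only keeps the call order that PyTorch expects.

```python
import torch
import torch.optim as optim

# Dummy parameter in place of net.parameters(); only the schedule matters here
param = torch.nn.Parameter(torch.zeros(1))
optimizer = optim.SGD([param], lr=0.02, momentum=0.9, weight_decay=0.0001)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=0.001)

lrs = []
for epoch in range(200):
    lrs.append(optimizer.param_groups[0]["lr"])  # learning rate used in this epoch
    optimizer.step()   # no gradients, so the parameter is not updated
    scheduler.step()   # advance the schedule once per epoch

for epoch in (0, 25, 50, 75, 100):
    print(epoch, round(lrs[epoch], 5))
```

If you want one smooth decay over the whole run instead of repeated cycles, set T_max to the total number of epochs (200 here) rather than 50.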

Execution result

**The accuracy was 89.43%.**

This is a bit disappointing, since the paper reports over 95%.

Still, an ordinary ResNet without any special tricks reaches only about 80%, so this appears to be a clear improvement.

The blue curve is train_acc and the orange curve is test_acc.

17-89.43.png

Some parts of my implementation do not follow the paper:

- Only 200 epochs are trained instead of 1800
- The paper sets the maximum learning rate to 0.2, but with that value the loss became NaN, so I changed it to 0.02. Perhaps I misread something in the paper.

Finally

Shake-Shake Regularization is attracting attention as a powerful regularization method.

A newer method called ShakeDrop has also been proposed, so I plan to implement that as well.
