About the Unfold function

Conv2D operation

Consider a 2D convolution with input $(Batch, H, W, C_{in})$, output $(Batch, H, W, C_{out})$ (assuming padding that preserves the spatial size), kernel size $(3, 3)$, and convolution weight $W = (3, 3, C_{in}, C_{out})$.

In effect, the Conv2D operation is equivalent to the matrix product $matmul(x, W) = matmul((Batch, HW, 9C_{in}), (9C_{in}, C_{out})) = (Batch, HW, C_{out})$. Here, for a matrix product $c = matmul(a, b)$, if $a = (i, j, k, m)$ and $b = (m, n)$, then $c = (i, j, k, n)$: the leading dimensions of $a$ are carried through and only its last dimension is contracted with $b$.
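A quick NumPy check of this shape rule, first for the generic $(i, j, k, m) \times (m, n)$ case and then for the Conv2D shapes. The sizes are arbitrary and chosen only for illustration; with all-ones inputs every output element simply equals the length of the contracted dimension.

```python
import numpy as np

# Generic rule: (i, j, k, m) @ (m, n) -> (i, j, k, n)
i, j, k, m, n = 2, 3, 4, 5, 6
a = np.ones((i, j, k, m))
b = np.ones((m, n))
c = np.matmul(a, b)
print(c.shape)  # (2, 3, 4, 6)

# Conv2D case: (Batch, HW, 9*C_in) @ (9*C_in, C_out) -> (Batch, HW, C_out)
batch, H, W, c_in, c_out = 2, 5, 5, 3, 8
x = np.ones((batch, H * W, 9 * c_in))
w = np.ones((9 * c_in, c_out))
y = np.matmul(x, w)
print(y.shape)  # (2, 25, 8)
```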

On the other hand, for the input $(Batch, H, W, C_{in})$, the nine shifted windows of a $3 \times 3$ kernel can be sliced out as follows:

```python
# x has shape (Batch, 9, H-2, W-2, C_in): one slice per kernel position
x[:,0]=input[:,0:H-2,0:W-2,:]
x[:,1]=input[:,0:H-2,1:W-1,:]
x[:,2]=input[:,0:H-2,2:W-0,:]
x[:,3]=input[:,1:H-1,0:W-2,:]
x[:,4]=input[:,1:H-1,1:W-1,:]
x[:,5]=input[:,1:H-1,2:W-0,:]
x[:,6]=input[:,2:H-0,0:W-2,:]
x[:,7]=input[:,2:H-0,1:W-1,:]
x[:,8]=input[:,2:H-0,2:W-0,:]
```

Before the matrix product, we need to extract $(H-2, W-2)$ windows from $(H, W)$ in this way and rearrange them into a matrix of shape $(Batch, HW, 9C_{in})$. (With no padding, as in these slices, the matrix actually has $(H-2)(W-2)$ rows; padding the input by one pixel on each side gives $HW$ rows.) This transformation is called im2col. It can be viewed as multiplying the number of input channels by the number of kernel elements ($3 \times 3 = 9$). The im2col step itself has no weights.
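The slicing above can be made concrete with a small NumPy sketch, assuming channels-last input, a 3x3 kernel, stride 1 and no padding. The helper names `im2col_nhwc` and `conv2d_naive_nhwc` are hypothetical; the second is a direct loop used only as a reference to check that im2col followed by a single matmul gives the same result.

```python
import numpy as np

def im2col_nhwc(inp):
    """(Batch, H, W, C_in) -> (Batch, (H-2)*(W-2), 9*C_in) for a 3x3 kernel, no padding."""
    batch, H, W, c_in = inp.shape
    # Slice number 3*dy + dx corresponds to x[:, 3*dy + dx] in the slices above
    slices = [inp[:, dy:H - 2 + dy, dx:W - 2 + dx, :]
              for dy in range(3) for dx in range(3)]
    x = np.stack(slices, axis=3)  # (Batch, H-2, W-2, 9, C_in)
    return x.reshape(batch, (H - 2) * (W - 2), 9 * c_in)

def conv2d_naive_nhwc(inp, weight):
    """Reference 3x3 'valid' convolution (cross-correlation, as in DL frameworks)."""
    batch, H, W, c_in = inp.shape
    c_out = weight.shape[-1]
    out = np.zeros((batch, H - 2, W - 2, c_out))
    for i in range(H - 2):
        for j in range(W - 2):
            patch = inp[:, i:i + 3, j:j + 3, :]  # (Batch, 3, 3, C_in)
            out[:, i, j, :] = np.tensordot(patch, weight, axes=([1, 2, 3], [0, 1, 2]))
    return out

rng = np.random.default_rng(0)
inp = rng.random((2, 6, 7, 3))     # (Batch, H, W, C_in)
weight = rng.random((3, 3, 3, 4))  # (3, 3, C_in, C_out)

x = im2col_nhwc(inp)               # (2, 20, 27)
# (Batch, (H-2)(W-2), 9*C_in) @ (9*C_in, C_out), then restore the spatial layout
out = np.matmul(x, weight.reshape(9 * 3, 4)).reshape(2, 4, 5, 4)
print(x.shape, np.allclose(out, conv2d_naive_nhwc(inp, weight)))  # (2, 20, 27) True
```

The weight reshape works because its flattened index $(3\,dy + dx) \cdot C_{in} + c$ matches the column order produced by stacking the nine slices.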

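```python
import numpy as np

def im2col(input_data, filter_h, filter_w, stride=1, pad=0):
    # input_data: (N, C, H, W), channels first
    N, C, H, W = input_data.shape
    # Output spatial size for the given stride and padding
    out_h = (H + 2*pad - filter_h)//stride + 1
    out_w = (W + 2*pad - filter_w)//stride + 1

    # Zero-pad the spatial dimensions only
    img = np.pad(input_data, [(0,0), (0,0), (pad, pad), (pad, pad)], 'constant')
    col = np.zeros((N, C, filter_h, filter_w, out_h, out_w))

    # For each kernel offset (y, x), copy the strided slice of the image it touches
    for y in range(filter_h):
        y_max = y + stride*out_h
        for x in range(filter_w):
            x_max = x + stride*out_w
            col[:, :, y, x, :, :] = img[:, :, y:y_max:stride, x:x_max:stride]

    # Rows: one per output position (N*out_h*out_w); columns: C*filter_h*filter_w
    col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N*out_h*out_w, -1)
    return col
```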
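As an alternative sketch (not the author's code), NumPy 1.20+ provides numpy.lib.stride_tricks.sliding_window_view, which produces every window without the explicit loop. The hypothetical helper `im2col_windows` below reproduces the column layout of `im2col` above (rows per output position, columns ordered C, filter_h, filter_w), and `conv2d_im2col` turns it into a full convolution with one matrix product; `conv2d_naive` is a direct loop used only as a check, here with stride 2 and padding 1.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view  # NumPy >= 1.20

def im2col_windows(input_data, filter_h, filter_w, stride=1, pad=0):
    """Returns (col, out_h, out_w); col is (N*out_h*out_w, C*filter_h*filter_w)."""
    N, C, H, W = input_data.shape
    img = np.pad(input_data, [(0, 0), (0, 0), (pad, pad), (pad, pad)], 'constant')
    # All stride-1 windows: (N, C, H+2p-fh+1, W+2p-fw+1, filter_h, filter_w)
    win = sliding_window_view(img, (filter_h, filter_w), axis=(2, 3))
    win = win[:, :, ::stride, ::stride]  # keep every stride-th window
    out_h, out_w = win.shape[2], win.shape[3]
    col = win.transpose(0, 2, 3, 1, 4, 5).reshape(N * out_h * out_w, -1)
    return col, out_h, out_w

def conv2d_im2col(x, w, stride=1, pad=0):
    """x: (N, C, H, W), w: (F, C, fh, fw) -> (N, F, out_h, out_w)."""
    N = x.shape[0]
    F, _, fh, fw = w.shape
    col, out_h, out_w = im2col_windows(x, fh, fw, stride, pad)
    out = col @ w.reshape(F, -1).T  # (N*out_h*out_w, F)
    return out.reshape(N, out_h, out_w, F).transpose(0, 3, 1, 2)

def conv2d_naive(x, w, stride=1, pad=0):
    """Reference: loop over output positions."""
    N, C, H, W = x.shape
    F, _, fh, fw = w.shape
    img = np.pad(x, [(0, 0), (0, 0), (pad, pad), (pad, pad)])
    out_h = (H + 2 * pad - fh) // stride + 1
    out_w = (W + 2 * pad - fw) // stride + 1
    out = np.zeros((N, F, out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            patch = img[:, :, i * stride:i * stride + fh, j * stride:j * stride + fw]
            out[:, :, i, j] = np.tensordot(patch, w, axes=([1, 2, 3], [1, 2, 3]))
    return out

rng = np.random.default_rng(0)
x = rng.random((2, 3, 7, 8))
w = rng.random((4, 3, 3, 3))
y = conv2d_im2col(x, w, stride=2, pad=1)
print(y.shape, np.allclose(y, conv2d_naive(x, w, stride=2, pad=1)))  # (2, 4, 4, 4) True
```

Taking every stride-th stride-1 window yields exactly $\lfloor (H + 2p - f_h)/s \rfloor + 1$ windows, the same out_h formula used in `im2col`.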

In PyTorch, im2col is provided as the Unfold function (torch.nn.Unfold). It follows that **Conv2D = (im2col + matmul) = (Unfold + matmul)**. I checked whether this really holds.

Comparison in PyTorch

PyTorch uses a channels-first layout. Here the input is $(Batch, C_{in}, H, W) = (25, 3, 32, 32)$, the output is $(Batch, C_{out}, H-2, W-2) = (25, 16, 30, 30)$, the kernel size is $(3, 3)$, and the weight is $W = (C_{out}, 3 \times 3 \times C_{in}) = (16, 27)$.

(Unfold + matmul) operation

```python
import numpy as np
import torch

input = torch.tensor(np.random.rand(25,3,32,32)).float()
weight = torch.tensor(np.random.rand(16,3,3,3)).float()
weight2 = weight.reshape((16,27))

print('input.shape=  ', input.shape)
print('weight.shape= ', weight.shape)
print('weight2.shape=', weight2.shape)

x = torch.nn.Unfold(kernel_size=(3,3), stride=(1,1), padding=(0,0), dilation=(1,1))(input)
output1 = torch.matmul(weight2, x).reshape((25,16,30,30))

print('x.shape=      ', x.shape)
print('output1.shape=', output1.shape)
-----------------------------------------------------------
input.shape=   torch.Size([25, 3, 32, 32])
weight.shape=  torch.Size([16, 3, 3, 3])
weight2.shape= torch.Size([16, 27])
x.shape=       torch.Size([25, 27, 900])
output1.shape= torch.Size([25, 16, 30, 30])
```

Here, applying the Unfold function to the input gives $x = (25, 3 \times 3 \times 3, 30 \times 30) = (25, 27, 900)$, and with $W = (16, 27)$ we get $matmul(W, x) = (25, 16, 30 \times 30)$, which is then reshaped to $(25, 16, 30, 30)$.
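The same bookkeeping works with padding. As a sketch with arbitrary sizes (not taken from the article), padding=1 keeps the spatial size, so Unfold produces $L = HW$ columns, which matches the $(Batch, HW, 9C_{in})$ form from the first section. It uses torch.nn.functional.unfold and torch.nn.functional.conv2d, the functional forms of Unfold and Conv2d.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
inp = torch.rand(2, 3, 8, 8)      # (Batch, C_in, H, W)
weight = torch.rand(16, 3, 3, 3)  # (C_out, C_in, 3, 3)

# padding=1 preserves H and W, so Unfold yields L = H*W = 64 columns
x = F.unfold(inp, kernel_size=3, padding=1)  # (2, 27, 64)
# (16, 27) @ (2, 27, 64) broadcasts over the batch -> (2, 16, 64)
out = (weight.reshape(16, -1) @ x).reshape(2, 16, 8, 8)

ref = F.conv2d(inp, weight, padding=1)
print(x.shape, torch.allclose(out, ref, atol=1e-4))  # torch.Size([2, 27, 64]) True
```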

Conv2D operation

On the other hand, the code below computes the output with the Conv2D function (torch.nn.Conv2d) for the same input $(Batch, C_{in}, H, W) = (25, 3, 32, 32)$ and weight $W = (16, 3, 3, 3)$.

```python
conv1 = torch.nn.Conv2d(3, 16, kernel_size=3, bias=False)
conv1.weight.data = weight

output2 = conv1(input)

print('conv1.weight.shape=', conv1.weight.shape)
print('output2.shape= ', output2.shape)
-----------------------------------------------------------
conv1.weight.shape= torch.Size([16, 3, 3, 3])
output2.shape=  torch.Size([25, 16, 30, 30])
```

Comparing **output1**, obtained by (Unfold + matmul), with **output2**, obtained by Conv2D, the values agree up to floating-point rounding (a few entries differ only in the last printed digit, e.g. 6.4978 vs 6.4979). This confirms that, computationally, **Conv2D = (Unfold + matmul)**.

```python
output1:
tensor([[[[7.4075, 7.1269, 6.2595,  ..., 6.9860, 6.5256, 7.3597],
          [6.4978, 7.3303, 6.7621,  ..., 7.2054, 6.9357, 7.3798],
          [5.9309, 5.5016, 6.3321,  ..., 5.7143, 7.0358, 6.8819],
          ...,
          [6.0168, 6.9415, 7.5508,  ..., 5.4547, 4.7888, 6.0636],
          [5.0191, 7.0944, 7.0875,  ..., 3.9413, 4.1925, 5.5689],
          [6.2448, 6.4813, 5.5424,  ..., 4.2610, 5.8013, 5.3431]],
......
output2:
tensor([[[[7.4075, 7.1269, 6.2595,  ..., 6.9860, 6.5256, 7.3597],
          [6.4979, 7.3303, 6.7621,  ..., 7.2054, 6.9357, 7.3798],
          [5.9309, 5.5016, 6.3321,  ..., 5.7143, 7.0358, 6.8819],
          ...,
          [6.0168, 6.9415, 7.5508,  ..., 5.4547, 4.7888, 6.0636],
          [5.0191, 7.0944, 7.0874,  ..., 3.9413, 4.1925, 5.5689],
          [6.2448, 6.4813, 5.5424,  ..., 4.2610, 5.8013, 5.3431]],
......
```
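Eyeballing printed digits can be replaced by a numerical check. The sketch below repeats the same setup in a self-contained form, with two assumptions that differ from the code above: torch.rand with a fixed seed instead of np.random, and the weight loaded via copy_ under torch.no_grad(). The tolerance of 1e-4 is an assumption chosen to be well above float32 rounding error for sums of 27 products.

```python
import torch

torch.manual_seed(0)
inp = torch.rand(25, 3, 32, 32)
weight = torch.rand(16, 3, 3, 3)

# (Unfold + matmul)
x = torch.nn.Unfold(kernel_size=(3, 3))(inp)  # (25, 27, 900)
output1 = torch.matmul(weight.reshape(16, 27), x).reshape(25, 16, 30, 30)

# Conv2D with the same weights
conv1 = torch.nn.Conv2d(3, 16, kernel_size=3, bias=False)
with torch.no_grad():
    conv1.weight.copy_(weight)
    output2 = conv1(inp)

# Largest absolute difference: a rounding-level value whose exact size depends on the backend
max_diff = (output1 - output2).abs().max().item()
print(max_diff)
print(torch.allclose(output1, output2, atol=1e-4))
```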

Other uses of the Unfold function

When kernel_size and stride are equal, Unfold performs the patch splitting used in the Vision Transformer. That said, patch splitting can also be done with reshape and transpose, without Unfold.

```python
import numpy as np
import torch

input = torch.tensor(np.random.rand(25,3,224,224)).float()
x = torch.nn.Unfold(kernel_size=(14,14), stride=(14,14), padding=(0,0), dilation=(1,1))(input)

print('input.shape=  ', input.shape)
print('x.shape=      ', x.shape)
-----------------------------------------------------------
input.shape=   torch.Size([25, 3, 224, 224])
x.shape=       torch.Size([25, 588, 256]) #(25,3*14*14,16*16)
```
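A sketch of the reshape/transpose route, assuming H and W are divisible by the patch size p (batch size 2 here instead of 25, to keep it light). The input is split into a grid of patches, the within-patch axes are moved next to the channel axis so the column order matches Unfold's (C, ki, kj), and the grid axes are flattened last. Because both routes only move data, the results should be exactly equal, not just close.

```python
import torch

torch.manual_seed(0)
N, C, H, W, p = 2, 3, 224, 224, 14
inp = torch.rand(N, C, H, W)

# Unfold with kernel_size == stride == p: (N, C*p*p, (H/p)*(W/p))
x_unfold = torch.nn.Unfold(kernel_size=p, stride=p)(inp)

# Same patches via reshape + permute only
nh, nw = H // p, W // p
x_reshape = (inp.reshape(N, C, nh, p, nw, p)  # (N, C, row-block, ki, col-block, kj)
                .permute(0, 1, 3, 5, 2, 4)    # (N, C, ki, kj, row-block, col-block)
                .reshape(N, C * p * p, nh * nw))

print(x_unfold.shape, torch.equal(x_unfold, x_reshape))  # torch.Size([2, 588, 256]) True
```

A further .transpose(1, 2) gives the (N, num_patches, patch_dim) layout that a Transformer consumes as a token sequence.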

Vision Transformer is said not to use Conv2D at all, but since matmul appears both in computing the attention weights and in applying them to the values, I had the unfounded idea that Unfold + matmul might amount to Conv2D in ViT as well.
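For what it is worth, there is one place in ViT where the equivalence does hold exactly: the patch-embedding step. Projecting each flattened patch with a weight matrix is the same computation as a Conv2d whose kernel_size and stride both equal the patch size. The sketch below checks this with patch size 14 as above and an assumed embedding size D = 64; it says nothing about the attention layers themselves.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, C, H, W, p, D = 2, 3, 224, 224, 14, 64  # D: embedding size (arbitrary here)
inp = torch.rand(N, C, H, W)
weight = torch.randn(D, C, p, p) * 0.01

# Patch embedding as Unfold + matmul: (D, C*p*p) @ (N, C*p*p, L) -> (N, D, L)
tokens_a = weight.reshape(D, -1) @ F.unfold(inp, kernel_size=p, stride=p)

# The same thing as a strided convolution, flattened to (N, D, L)
tokens_b = F.conv2d(inp, weight, stride=p).flatten(2)

print(tokens_a.shape, torch.allclose(tokens_a, tokens_b, atol=1e-4))  # torch.Size([2, 64, 256]) True
```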

Summary

The Unfold function is PyTorch's im2col, and **Conv2D = (Unfold + matmul)**. In TensorFlow, the corresponding function is extract_image_patches (tf.image.extract_patches in TensorFlow 2).
