À propos de la fonction Déplier

Arithmétique Conv2D

Considérant l'opération de convolution bidimensionnelle, entrée $ (Batch, H, W, C_ {in}) $, sortie $ (Batch, H, W, C_ {out}) $, taille du noyau $ (3,3) $, Poids de convolution $ W = (3,3, C_ {in}, C_ {out}) $

Fonctionnement efficace de Conv2D matmul(x,W)=matmul((Batch,HW,9C_{in}), (9C_{in}, C_{out}))=(Batch,HW,C_{out}) Cela revient à considérer l'opération matricielle de matmul. Ici, dans l'opération matricielle de $ c = matmul (a, b) $, si $ a = (i, j, k, m), b = (m, n) $, alors $ c = (i, j, k, n ) $.

D'autre part, pour l'entrée $ (Batch, H, W, C_ {in}) $

python


x[:,0]=input[:,0:H-2,0:W-2,:] \\
x[:,1]=input[:,0:H-2,1:W-1,:] \\
x[:,2]=input[:,0:H-2,2:W-0,:] \\ 
x[:,3]=input[:,1:H-1,0:W-2,:] \\
x[:,4]=input[:,1:H-1,1:W-1,:] \\
x[:,5]=input[:,1:H-1,2:W-0,:] \\ 
x[:,6]=input[:,2:H-0,0:W-2,:] \\
x[:,7]=input[:,2:H-0,1:W-1,:] \\
x[:,8]=input[:,2:H-0,2:W-0,:]

Extrayez $ (H-2, W-2) $ de $ (H, W) $ like et convertissez-le en une matrice comme $ (Batch, HW, 9C_ {in}) $ avant l'opération de la matrice. besoin de le faire. Une telle transformation matricielle est appelée $ im2col $. On peut considérer que ce processus double le nombre de canaux d'entrée par le nombre total de tailles de noyau. De plus, le processus $ im2col $ lui-même n'a aucun poids.

python


def im2col(input_data, filter_h, filter_w, stride=1, pad=0):

    N, C, H, W = input_data.shape
    out_h = (H + 2*pad - filter_h)//stride + 1
    out_w = (W + 2*pad - filter_w)//stride + 1

    img = np.pad(input_data, [(0,0), (0,0), (pad, pad), (pad, pad)], 'constant')
    col = np.zeros((N, C, filter_h, filter_w, out_h, out_w))

    for y in range(filter_h):
        y_max = y + stride*out_h
        for x in range(filter_w):
            x_max = x + stride*out_w
            col[:, :, y, x, :, :] = img[:, :, y:y_max:stride, x:x_max:stride]

    col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N*out_h*out_w, -1)
    return col

Dans Pytorch, la fonction im2col est appelée fonction Déplier. Par conséquent, il devrait être ** Conv2D = (im2col + matmul) = (Unfold + matmul) **. J'ai essayé de voir si le sujet principal était vraiment comme ça.

Comparaison dans PyTorch

PyTorch est le premier canal avec l'entrée $ (Batch, C_ {in}, H, W) = (25,3,32,32) $, output $ (Batch, C_ {out}, H, W) = (25,16) , 30,30) $, taille du noyau $ (3,3) $, poids $ W = (C_ {out}, 3 × 3 × C_ {in}) = (16,27) $.

(Déplier + matmul) opération

python


import numpy as np
import torch

input = torch.tensor(np.random.rand(25,3,32,32)).float()
weight = torch.tensor(np.random.rand(16,3,3,3)).float()
weight2 = weight.reshape((16,27))

print('input.shape=  ', input.shape)
print('weight.shape= ', weight.shape)
print('weight2.shape=', weight2.shape)

x = torch.nn.Unfold(kernel_size=(3,3), stride=(1,1), padding=(0,0), dilation=(1,1))(input)
output1 = torch.matmul(weight2, x).reshape((25,16,30,30))

print('x.shape=      ', x.shape)
print('output1.shape=', output1.shape)
-----------------------------------------------------------
input.shape=   torch.Size([25, 3, 32, 32])
weight.shape=  torch.Size([16, 3, 3, 3])
weight2.shape= torch.Size([16, 27])
x.shape=       torch.Size([25, 27, 900])
output1.shape= torch.Size([25, 16, 30, 30])

Si vous appliquez la fonction Déplier à l'entrée, $ x = (25, 3 × 3 × 3, 30 × 30) = (25,27,900) $, et quand $ W = (16,27) $, $ matmul (W) , x) = (25,16,30 × 30) $.

Arithmétique Conv2D

Par contre, si l'entrée $ (Batch, C_ {in}, H, W) = (25,3,32,32) $ et le poids de la fonction Conv2D est $ W = (16,3,3,3) $ Le code de la sortie est ci-dessous.

python


conv1 = torch.nn.Conv2d(3, 16, kernel_size=3, bias=False)
conv1.weight.data = weight

output2 = conv1(input)

print('conv1.weight.shape=', conv1.weight.shape)
print('output2.shape= ', output2.shape)
-----------------------------------------------------------
conv1.weight.shape= torch.Size([16, 3, 3, 3])
output2.shape=  torch.Size([25, 16, 30, 30])

En comparant ** output1 ** obtenu par (Unfold + matmul) et ** output2 ** obtenu par Conv2D à partir de ceux-ci, les valeurs étaient complètement les mêmes. Par conséquent, il a été confirmé qu'il est équivalent en calcul à ** Conv2D = (Unfold + matmul) **.

python


output1:
tensor([[[[7.4075, 7.1269, 6.2595,  ..., 6.9860, 6.5256, 7.3597],
          [6.4978, 7.3303, 6.7621,  ..., 7.2054, 6.9357, 7.3798],
          [5.9309, 5.5016, 6.3321,  ..., 5.7143, 7.0358, 6.8819],
          ...,
          [6.0168, 6.9415, 7.5508,  ..., 5.4547, 4.7888, 6.0636],
          [5.0191, 7.0944, 7.0875,  ..., 3.9413, 4.1925, 5.5689],
          [6.2448, 6.4813, 5.5424,  ..., 4.2610, 5.8013, 5.3431]],
......
output2:
tensor([[[[7.4075, 7.1269, 6.2595,  ..., 6.9860, 6.5256, 7.3597],
          [6.4979, 7.3303, 6.7621,  ..., 7.2054, 6.9357, 7.3798],
          [5.9309, 5.5016, 6.3321,  ..., 5.7143, 7.0358, 6.8819],
          ...,
          [6.0168, 6.9415, 7.5508,  ..., 5.4547, 4.7888, 6.0636],
          [5.0191, 7.0944, 7.0874,  ..., 3.9413, 4.1925, 5.5689],
          [6.2448, 6.4813, 5.5424,  ..., 4.2610, 5.8013, 5.3431]],
......

Autres utilisations de la fonction Déplier

Lorsque kernel_size et stride sont égaux, cela correspond à la division des patchs de Vision Transformer. Eh bien, le fractionnement de patch peut être remplacé par un remodelage et une transposition sans utiliser Déplier ...

python


input = torch.tensor(np.random.rand(25,3,224,224)).float()
x = torch.nn.Unfold(kernel_size=(14,14), stride=(14,14), padding=(0,0), dilation=(1,1))(input)
-----------------------------------------------------------
input.shape=   torch.Size([25, 3, 224, 224])
x.shape=       torch.Size([25, 588, 256]) #(25,3*14*14,16*16)

Dans l'histoire selon laquelle Vision Transformer n'utilise pas du tout Conv2D, puisque matmul est inclus dans le calcul du poids et de la valeur de l'attention, j'ai eu une illusion non fondée que Unfold + matmul est équivalent à Conv2D même dans ViT.

Sommaire

La fonction Déplier est la fonction im2col dans Pytorch, et ** Conv2D = (Déplier + matmul) **. Dans tensorflow, il s'agit de la fonction extract_image_patches.

Recommended Posts

À propos de la fonction Déplier
À propos de la fonction enumerate (python)
Pensez grossièrement à la fonction de perte
À propos du test
À propos de la file d'attente
A propos des arguments de la fonction setup de PyCaret
À propos des arguments de fonction (python)
La première «fonction» GOLD
Python: à propos des arguments de fonction
À propos de la commande de service
À propos de la matrice de confusion
À propos du modèle de visiteur
Concernant la fonction d'activation Gelu
Quelle est la fonction d'activation?
À propos du module Python venv
fonction de mémorandum python pour débutant
À propos de la fonction fork () et de la fonction execve ()
À propos du problème du voyageur de commerce
À propos de la compréhension du lecteur en 3 points [...]
À propos des composants de Luigi
À propos des fonctionnalités de Python
Qu'est-ce que la fonction de rappel?
Comment utiliser la fonction zip
Avertissement de tri dans la fonction pd.concat
Pensez au problème de changement minimum
[Python] Qu'est-ce que @? (À propos des décorateurs)
À propos de la valeur de retour de pthread_mutex_init ()
Précautions lors de l'utilisation de la fonction urllib.parse.quote
À propos de la valeur de retour de l'histogramme.
[Python] Faire de la fonction une fonction lambda
À propos du type de base de Go
À propos de la limite supérieure de threads-max
À propos de l'option moyenne de sklearn.metrics.f1_score
À propos du comportement de yield_per de SqlAlchemy
À propos de la taille des points dans matplotlib
À propos de la liste de base des bases de Python
[Python Kivy] À propos de la modification du thème de conception
A propos du comportement de enable_backprop de Chainer v2
À propos de l'environnement virtuel de Python version 3.7
Notes diverses sur le framework Django REST
Prenez la somme logique de List en Python (fonction zip)
[OpenCV] À propos du tableau retourné par imread
À propos de NumFOCUS, une organisation de support open source
[Python3] Réécrire l'objet code de la fonction
Pensez grossièrement à la méthode de descente de gradient
[Python] Résumez les éléments rudimentaires du multithreading
À propos de l'environnement de développement que vous utilisez
Présentation de la fonction addModuleCleanup / doModuleCleanups de unittest
Qu'en est-il de 2017 autour du langage Crystal? (Illusion)
À propos de la relation entre Git et GitHub
À propos de l'équation normale de la régression linéaire
Un mémo que j'ai essayé le tutoriel Pyramid