Understanding VQ-VAE

Introduction

VQ-VAE is a VAE that uses a technique called Vector Quantization. In a conventional VAE, training pushes the latent variable z toward a vector drawn from a normal (Gaussian) distribution, whereas VQ-VAE trains the latent variable to take discretized values. The model consists of (Encoder)-(quantization part)-(Decoder), and the Encoder and Decoder are not much different from those of a convolutional VAE. When I looked through the VQ-VAE paper and an implementation, my understanding of how the quantization part works changed, so I summarize my understanding here as a memo.

What is Embedding

Embedding is unavoidable when talking about VQ-VAE. If, like me, you do not have a firm grasp of it, it is hard to understand what VQ-VAE is doing.

For me it was easiest to look at an example. Consider the case below, where the input matrix is $(2,4)$ with values that are index values, and the embedding matrix is $(10,3)$. If you convert the input matrix to one-hot form to make it $(2,4,10)$ and multiply it by the embedding matrix $(10,3)$, you get the embedded matrix of shape $(2,4,3)$. In short, **Embedding is nothing more than one-hot encoding the input and multiplying it by the embedding matrix.**

python


>>> import torch
>>> from torch import nn
>>> # an Embedding module containing 10 tensors of size 3
>>> embedding = nn.Embedding(10, 3)
>>> # a batch of 2 samples of 4 indices each
>>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
>>> embedding(input)
tensor([[[-0.0251, -1.6902,  0.7172],
         [-0.6431,  0.0748,  0.6969],
         [ 1.4970,  1.3448, -0.9685],
         [-0.3677, -2.7265, -0.1685]],

        [[ 1.4970,  1.3448, -0.9685],
         [ 0.4362, -0.4004,  0.9400],
         [-0.6431,  0.0748,  0.6969],
         [ 0.9124, -2.3616,  1.1151]]])
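To check this claim, one can one-hot encode the input by hand and multiply it by the embedding weight; the result matches `embedding(input)`. The following is a small sketch of my own for confirmation (`F.one_hot` and `embedding.weight` are standard PyTorch).

python

import torch
import torch.nn.functional as F
from torch import nn

embedding = nn.Embedding(10, 3)                          # embedding matrix of shape (10, 3)
input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])   # index matrix of shape (2, 4)

onehot = F.one_hot(input, num_classes=10).float()        # (2, 4) -> (2, 4, 10)
manual = onehot @ embedding.weight                       # (2, 4, 10) x (10, 3) -> (2, 4, 3)

print(torch.allclose(manual, embedding(input)))          # True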

First understanding of VQ-VAE (wrong)

(Figure: my first picture of the quantization step)

At first, I understood it as shown in the figure above. To say it up front, this is a mistake. The input $z_e$ is $(10,10,32)$, the embedding matrix is $(32,128)$, and I considered vector quantization (discretization of the latent space) into $128$ latent vectors. Multiply the input $z_e$ by the embedding matrix and, for each position, take the index whose value is closest to 1 (that is, the index of the nearest one-hot vector). This gives $q(z|x)$, a $(10,10)$ matrix whose values are index values. Converting it to one-hot and multiplying by the inverse of the embedding matrix gives $z_q$. Note that converting $q(z|x)$ to one-hot and multiplying by a matrix is nothing but the embedding operation explained at the beginning. The loss function is $(z_e - z_q)^2$, because the output $z_q$ should approach the input $z_e$ if the discretization of the latent space works well.

Writing the transformation in the figure with numpy gives the following.

python


import numpy as np

input = np.random.rand(10,10,32)              # z_e: encoder output
embed = np.random.rand(32,128)                # embedding matrix (first understanding)
embed_inv = np.linalg.pinv(embed)             # its (pseudo-)inverse, shape (128, 32)
dist = (np.dot(input, embed) - np.ones((10,10,128)))**2   # distance to each one-hot vector
embed_ind = np.argmin(dist, axis=2)           # q(z|x): index of the nearest one-hot vector
embed_onehot = np.identity(128)[embed_ind]    # convert the indices to one-hot
output = np.dot(embed_onehot, embed_inv)      # z_q: map back via the inverse matrix

print("input.shape=", input.shape)
print("embed.shape=", embed.shape)
print("embed_inv.shape=", embed_inv.shape)
print("dist.shape=", dist.shape)
print("embed_ind.shape=", embed_ind.shape)
print("embed_onehot.shape=", embed_onehot.shape)
print("output.shape=", output.shape)
----------------------------------------------
input.shape= (10, 10, 32)
embed.shape= (32, 128)
embed_inv.shape= (128, 32)
dist.shape= (10, 10, 128)
embed_ind.shape= (10, 10)
embed_onehot.shape= (10, 10, 128)
output.shape= (10, 10, 32)

What's wrong

Compared with the actual implementation (https://github.com/rosinality/vq-vae-2-pytorch/blob/master/vqvae.py), the interpretation above is incorrect. One reason is that computing the inverse of the embedding matrix is probably not something that is actually done. Therefore, $q(z|x)$ and $z_q$ must be obtained in a way that does not use the inverse of the embedding matrix.

The other reason is that it differs from the paper's definition of $q(z|x)$. The paper defines it as follows:

q(z=k|x)=1 \ \ \mathrm{for}\ k=argmin_j \|z_e(x)-e_j\|_2,\ \mathrm{otherwise}\ 0 \\
z_q(x)=e_k \ \ \mathrm{where}\ k=argmin_j \|z_e(x)-e_j\|_2

The interpretation above, on the other hand, corresponds to the following formula, which does not match:

q(z|x)=argmin((z_e \cdot e_{mbed} - I)^2)

Multiplying the expression inside the argmin by the square of the inverse embedding matrix $e_{mbed\ inv}$, it can be rearranged as follows.

((z_e \cdot e_{mbed} - I)^2 \cdot e_{mbed\ inv}^2)=(z_e \cdot e_{mbed} \cdot e_{mbed\ inv}- I \cdot e_{mbed\ inv})^2=(z_e - e_{mbed\ inv})^2

This has the same form as the formula in the paper.

It is then more convenient to swap the names of the embedding matrix and its inverse. In other words, from here on $e_{mbed\ inv}$ will be called $e_{mbed}$, and $e_{mbed}$ will be called $e_{mbed\ inv}$.

Second understanding of VQ-VAE

With the above corrections, my understanding becomes the following. Note that $e_{mbed\ inv}$ no longer appears in the expressions for either $q(z|x)$ or $z_q$: both can be computed from the input $z_e$ and $e_{mbed}$, so there is no need to compute an inverse matrix. Specifically, $q(z|x)$ is computed from $z_e$ and $e_{mbed}$, and $z_q$ is computed from $q(z|x)$ and $e_{mbed}$.

(Figure: second understanding of the quantization step)
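Mirroring the earlier numpy snippet, the corrected picture can be sketched as follows (my own illustration; here the embedding matrix is stored as $(128,32)$, one row per latent vector, which matches the renamed $e_{mbed}$ above). $q(z|x)$ is the argmin of the squared distances between $z_e$ and each embedding vector, and $z_q$ is obtained by one-hot encoding those indices and multiplying by the embedding matrix, i.e., an ordinary embedding lookup; no inverse matrix appears.

python

import numpy as np

z_e = np.random.rand(10, 10, 32)             # encoder output
embed = np.random.rand(128, 32)              # embedding matrix: 128 latent vectors of dim 32

# squared distance from every z_e vector to every embedding vector: (10, 10, 128)
dist = ((z_e[:, :, None, :] - embed[None, None, :, :])**2).sum(axis=-1)
embed_ind = np.argmin(dist, axis=-1)         # q(z|x): (10, 10) matrix of index values
embed_onehot = np.identity(128)[embed_ind]   # one-hot: (10, 10, 128)
z_q = np.dot(embed_onehot, embed)            # embedding lookup -> z_q: (10, 10, 32)

print(embed_ind.shape, z_q.shape)            # (10, 10) (10, 10, 32)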

Gradient propagation of loss function

You might think that the loss related to vector quantization is simply $(z_e - z_q)^2$, the difference before and after quantization, but it is actually different. It is written as $(sg(z_e) - z_q)^2 + (z_e - sg(z_q))^2$ using the stop-gradient function $sg()$. The reason it differs from $(z_e - z_q)^2$ is presumably that backpropagating the error through $q(z|x)$ to $z_e$ and $e_{mbed}$ is difficult, so the gradient flow through that part is cut off.

Also, what gets updated differs between the second and third terms of the loss function. The second term updates the embedding matrix, but its gradient does not propagate to the input (Encoder). The third term propagates its gradient to the input (Encoder) but does not update the embedding matrix. As for the first term of the loss function, its gradient starts from the Decoder, skips over the quantization part from $z_q$ to $z_e$, and is passed on to the Encoder; this is no different from the loss of an ordinary AutoEncoder. A rough sketch of these gradient paths in code is shown below.

(Figures: gradient propagation paths for each term of the loss)
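The following is a minimal PyTorch sketch of my own (not the actual implementation) showing how $sg()$ corresponds to `.detach()`: the second term $(sg(z_e)-z_q)^2$ only updates the embedding matrix, the third term $(z_e-sg(z_q))^2$ only reaches the Encoder side, and the straight-through trick `z_e + (z_q - z_e).detach()` lets the first (reconstruction) term skip the quantization part.

python

import torch

z_e = torch.randn(10, 10, 32, requires_grad=True)    # stands in for the Encoder output
embed = torch.randn(128, 32, requires_grad=True)      # embedding matrix (one row per vector)

dist = torch.cdist(z_e.reshape(-1, 32), embed)        # distances to every embedding vector
z_q = embed[dist.argmin(dim=1)].reshape(10, 10, 32)   # nearest embedding vectors

codebook_loss = (z_e.detach() - z_q).pow(2).mean()    # (sg(z_e) - z_q)^2: updates embed only
commitment_loss = (z_e - z_q.detach()).pow(2).mean()  # (z_e - sg(z_q))^2: reaches z_e only
z_q_st = z_e + (z_q - z_e).detach()                   # straight-through: a Decoder applied to
                                                      # z_q_st sends its gradient back to z_e

(codebook_loss + commitment_loss).backward()
print(z_e.grad.shape, embed.grad.shape)               # both receive gradients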

argmin function

Let us write an argmin function that returns the index of the smallest value in an array, using the Heaviside step function $H(x)$.

argmin(a,b) = H(b-a) \cdot 0 + H(a-b) \cdot 1 \\
argmin(a,b,c) = H(b-a) \cdot H(c-a) \cdot 0 + H(a-b) \cdot H(c-b) \cdot 1 + H(a-c) \cdot H(b-c) \cdot 2\\
H(x) =\left\{
\begin{array}{ll}
1 & (x \geq 0) \\
0 & (x \lt 0)
\end{array}
\right.

Here, only the term in which the minimum value is subtracted survives, because all of its step-function factors become 1; for a term that subtracts a non-minimum value, at least one factor becomes zero, so the term vanishes. Therefore $argmin(a_{1}, \cdots, a_{128})$ is a sum of high-order products of step functions, so even if the step function were replaced by a continuously differentiable function (for example, a sigmoid) when computing gradients, it would still be difficult to differentiate. However, this can be avoided by using the stop-gradient function $sg()$ as explained earlier, as the sketch below illustrates.
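As a sanity check, the formula above can be written directly in numpy and compared against `np.argmin` (a sketch of my own; `np.heaviside(x, 1.0)` gives $H(x)$ with $H(0)=1$).

python

import numpy as np

def heaviside_argmin(a):
    # argmin built from H(x) = np.heaviside(x, 1.0); assumes no ties
    a = np.asarray(a, dtype=float)
    n = len(a)
    result = 0.0
    for j in range(n):
        # the product is 1 only when a[j] is the minimum, otherwise some factor is 0
        factors = [np.heaviside(a[k] - a[j], 1.0) for k in range(n) if k != j]
        result += np.prod(factors) * j
    return int(result)

a = np.random.rand(128)
print(heaviside_argmin(a), np.argmin(a))   # the two indices agree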

Summary

I thought a quick look at the implementation was enough to understand it, but I realized that the embedding matrix I had in mind at the beginning was actually the inverse of the real embedding matrix. The point that is easy to misunderstand is that the embedding matrix is not the matrix you multiply the input by when vector-quantizing it; it is the matrix you multiply by when converting the quantized latent variable back into the latent space. I also felt that vector quantization resembles Semantic Segmentation, i.e., pixel-wise object recognition: Semantic Segmentation uses softmax to produce a one-hot vector for each pixel, while VQ uses squared distance and argmin to produce the quantization vector.

Reference: pytorch VQ-VAE

From the actual implementation example:

class Quantize(nn.Module):
    def __init__(self, dim, n_embed, decay=0.99, eps=1e-5):
        embed = torch.randn(dim, n_embed)  # embedding matrix of shape (dim, n_embed)
        ...

    def forward(self, input):
        flatten = input.reshape(-1, self.dim)
        # squared distance between each input vector and each embedding vector,
        # expanded term by term
        dist = (
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ self.embed
            + self.embed.pow(2).sum(0, keepdim=True)
        )
        _, embed_ind = (-dist).max(1)      # argmin of dist: q(z|x)
        embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype)
        embed_ind = embed_ind.view(*input.shape[:-1])
        quantize = self.embed_code(embed_ind)   # look up z_q from the indices
        ...
        diff = (quantize.detach() - input).pow(2).mean()   # (z_e - sg(z_q))^2 term
        quantize = input + (quantize - input).detach()     # straight-through estimator

        return quantize, diff, embed_ind

    def embed_code(self, embed_id):
        # embedding lookup: equivalent to one-hot(embed_id) times embed^T
        return F.embedding(embed_id, self.embed.transpose(0, 1))
