Try Semantic Segmentation (Pytorch)

Introduction

Semantic segmentation is an image recognition technique that classifies an image pixel by pixel. seg.png
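Concretely, a segmentation network outputs one score per pixel, and the label map is obtained by deciding each pixel independently. A minimal NumPy sketch, assuming a single foreground class and a probability map in [0, 1]; the 0.5 threshold and the names `to_mask` and `pixel_accuracy` are illustrative choices, not from the original article:

```python
import numpy as np

def to_mask(prob_map, threshold=0.5):
    """Turn a per-pixel probability map into a binary label map (1 = object, 0 = background)."""
    return (prob_map >= threshold).astype(np.uint8)

def pixel_accuracy(pred_mask, true_mask):
    """Fraction of pixels whose predicted label matches the ground truth."""
    return float(np.mean(pred_mask == true_mask))

prob = np.array([[0.1, 0.7],
                 [0.9, 0.4]])
truth = np.array([[0, 1],
                  [1, 1]])
mask = to_mask(prob)
print(mask)                         # [[0 1]
                                    #  [1 0]]
print(pixel_accuracy(mask, truth))  # 0.75
```

With several classes, the network would instead output one channel per class, and each pixel's label would be the argmax over the channel axis.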

I will leave the detailed theory to other sources; here I would like to try semantic segmentation with PyTorch. Rather than a deep, complex architecture such as SegNet, U-Net, or PSPNet, this time we will use a shallower, simpler network that can be trained adequately even on a laptop.

The environment is as follows: CPU: Intel(R) Core(TM) i5-7200U, memory: 8 GB, OS: Windows 10, Python 3.6.9, PyTorch 1.3.1, NumPy 1.17.4.

Creating a dataset

This time, I use images that I generated myself. The upper image, which contains only rectangle outlines, is the input data, and the lower image, with the rectangles filled in, is the teacher (ground-truth) data. In other words, the goal is a network that automatically fills in shapes, like the fill tool in paint software. input_auto.png correct_auto.png
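For this synthetic task the teacher image can also be computed without any learning: every pixel that cannot be reached from the image border without crossing a line belongs to a rectangle. The following sketch is a classic breadth-first flood fill started from the border (the name `fill_closed_shapes` is hypothetical); it reproduces what a paint-bucket fill of the background would leave behind, which can be handy for checking generated data. It assumes the outlines are closed, which holds for the rectangles generated here.

```python
from collections import deque

import numpy as np

def fill_closed_shapes(outline):
    """Return a mask that is 1 on the outlines and everything they enclose.

    Background is found by a 4-connected flood fill that starts from every
    border pixel that is not part of a line; all other pixels are "inside".
    """
    h, w = outline.shape
    outside = np.zeros((h, w), dtype=bool)
    queue = deque()
    # Seed the fill with all empty border pixels.
    for r in range(h):
        for c in range(w):
            on_border = r in (0, h - 1) or c in (0, w - 1)
            if on_border and outline[r, c] == 0:
                outside[r, c] = True
                queue.append((r, c))
    # Spread through empty pixels; lines (value 1) block the fill.
    while queue:
        r, c = queue.popleft()
        for dr, dc in ((1, 0), (-1, 0), (0, 1), (0, -1)):
            nr, nc = r + dr, c + dc
            if 0 <= nr < h and 0 <= nc < w and not outside[nr, nc] and outline[nr, nc] == 0:
                outside[nr, nc] = True
                queue.append((nr, nc))
    return (~outside).astype(np.float64)

# A 3x3 square outline inside a 6x6 image.
img = np.zeros((6, 6))
img[1, 1:4] = 1.0
img[3, 1:4] = 1.0   # top and bottom rows
img[1:4, 1] = 1.0
img[1:4, 3] = 1.0   # left and right columns
filled = fill_closed_shapes(img)
print(int(filled.sum()))  # 9
```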

Now create the data needed for training. imgs holds 1,000 input images and imgs_ano holds the corresponding 1,000 output (teacher) images. The rectangles are placed so that they do not overlap each other, and both their side lengths and their number (up to 6 per image) are chosen at random.
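Why the placement rule in rectangle() below prevents overlaps: each half side is drawn with np.random.randint(3, max_side), so it is at most max_side - 1, and a new center is rejected only when it lies closer than 2 * max_side + 1 to an existing center along both axes. An accepted center is therefore at least 2 * max_side + 1 away along some axis, which exceeds the largest possible sum of two half sides. A brute-force check of that argument for small max_side values (the helper names are hypothetical and only mirror the rule):

```python
import itertools

def boxes_overlap(c1, s1, c2, s2):
    """True if two axis-aligned boxes (center, half sides), pixel-inclusive, share a pixel."""
    return all(abs(c1[k] - c2[k]) <= s1[k] + s2[k] for k in range(2))

def far_enough(c1, c2, max_side):
    """Acceptance rule of rectangle(): reject only if close along BOTH axes."""
    limit = 2 * max_side + 1
    return not (abs(c1[0] - c2[0]) < limit and abs(c1[1] - c2[1]) < limit)

def count_violations(max_side, max_offset):
    """Count accepted center pairs whose rectangles would still overlap."""
    half_sides = range(3, max_side)   # randint(3, max_side) never returns max_side
    sizes = list(itertools.product(half_sides, repeat=2))
    violations = 0
    for dx, dy in itertools.product(range(-max_offset, max_offset + 1), repeat=2):
        for s1, s2 in itertools.product(sizes, repeat=2):
            if far_enough((0, 0), (dx, dy), max_side) and boxes_overlap((0, 0), s1, (dx, dy), s2):
                violations += 1
    return violations

print(count_violations(5, 15))  # 0
```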

import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import TensorDataset, DataLoader

def rectangle(img, img_ano, centers, max_side):
    """
    img      : 2D image containing only rectangle outlines (input data)
    img_ano  : the corresponding annotation image with filled rectangles (teacher data)
    centers  : list of center coordinates of the rectangles drawn so far
    max_side : upper bound (exclusive) on half the side length
    """
    if max_side < 4:  # max_side too small: np.random.randint(3, max_side) needs max_side > 3
        max_side = 4
    # Half side lengths along each axis
    side_x = np.random.randint(3, int(max_side))
    side_y = np.random.randint(3, int(max_side))

    # Center coordinates (x = row, y = column)
    x = np.random.randint(max_side + 1, img.shape[0] - (max_side + 1))
    y = np.random.randint(max_side + 1, img.shape[1] - (max_side + 1))

    # If the new center is too close to an existing one, return the inputs unchanged
    for center in centers:
        if np.abs(center[0] - x) < (2 * max_side + 1):
            if np.abs(center[1] - y) < (2 * max_side + 1):
                return img, img_ano, centers

    img[x - side_x : x + side_x + 1, y - side_y] = 1.0  # left side
    img[x - side_x : x + side_x + 1, y + side_y] = 1.0  # right side
    img[x - side_x, y - side_y : y + side_y + 1] = 1.0  # top side
    img[x + side_x, y - side_y : y + side_y + 1] = 1.0  # bottom side
    img_ano[x - side_x : x + side_x + 1, y - side_y : y + side_y + 1] = 1.0  # filled rectangle
    centers.append([x, y])
    return img, img_ano, centers


num_images = 1000                                     # number of images to generate
length = 64                                           # image height and width
imgs = np.zeros([num_images, 1, length, length])      # input images (outlines)
imgs_ano = np.zeros([num_images, 1, length, length])  # output images (teacher data)

for i in range(num_images):
    centers = []
    img = np.zeros([length, length])
    img_ano = np.zeros([length, length])
    for j in range(6):                                # try to place up to 6 rectangles
        img, img_ano, centers = rectangle(img, img_ano, centers, 12)
    imgs[i, 0, :, :] = img
    imgs_ano[i, 0, :, :] = img_ano

imgs = torch.tensor(imgs, dtype = torch.float32)          # ndarray -> torch.tensor
imgs_ano = torch.tensor(imgs_ano, dtype = torch.float32)  # ndarray -> torch.tensor
data_set = TensorDataset(imgs, imgs_ano)
data_loader = DataLoader(data_set, batch_size = 100, shuffle = True)
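With 1,000 samples and batch_size = 100, each epoch consists of 10 mini-batches, and shuffle = True draws a fresh random order every epoch. Ignoring details such as worker processes and collation, this behaves roughly like the following NumPy sketch (`iterate_minibatches` is a hypothetical helper, not part of PyTorch):

```python
import numpy as np

def iterate_minibatches(x, y, batch_size, rng):
    """Yield shuffled (x, y) mini-batches that together cover the data set exactly once."""
    order = rng.permutation(len(x))          # new random order for each epoch
    for start in range(0, len(x), batch_size):
        idx = order[start:start + batch_size]
        yield x[idx], y[idx]

rng = np.random.default_rng(0)
x = np.zeros((1000, 1, 64, 64), dtype=np.float32)
y = np.zeros((1000, 1, 64, 64), dtype=np.float32)
batches = list(iterate_minibatches(x, y, 100, rng))
print(len(batches), batches[0][0].shape)  # 10 (100, 1, 64, 64)
```

If the data set size were not a multiple of the batch size, the last batch would simply be smaller, which matches DataLoader's default drop_last=False.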

Network 1

Next, define the network class in PyTorch. To begin with, I reused the network defined for the convolutional autoencoder in my previous article as is. Because both an autoencoder and a segmentation network produce an image of the same size as the input, the same class could be used unchanged (at least in PyTorch; I am not sure how this would go in TensorFlow).

class ConvAutoencoder(nn.Module):
    def __init__(self):
        super(ConvAutoencoder, self).__init__()
        #Encoder Layers
        self.conv1 = nn.Conv2d(in_channels = 1,
                               out_channels = 16,
                               kernel_size = 3,
                               padding = 1)
        self.conv2 = nn.Conv2d(in_channels = 16,
                               out_channels = 4,
                               kernel_size = 3,
                               padding = 1)
        #Decoder Layers
        self.t_conv1 = nn.ConvTranspose2d(in_channels = 4, out_channels = 16,
                                          kernel_size = 2, stride = 2)
        self.t_conv2 = nn.ConvTranspose2d(in_channels = 16, out_channels = 1,
                                          kernel_size = 2, stride = 2)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.sigmoid = nn.Sigmoid()
    
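    def forward(self, x):
        # encoder: (N, 1, 64, 64) -> (N, 4, 16, 16)
        x = self.relu(self.conv1(x))       # (N, 16, 64, 64)
        x = self.pool(x)                   # (N, 16, 32, 32)
        x = self.relu(self.conv2(x))       # (N, 4, 32, 32)
        x = self.pool(x)                   # (N, 4, 16, 16)
        # decoder: (N, 4, 16, 16) -> (N, 1, 64, 64)
        x = self.relu(self.t_conv1(x))     # (N, 16, 32, 32)
        x = self.sigmoid(self.t_conv2(x))  # (N, 1, 64, 64), values in (0, 1)
        return x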
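Why the output of the class above has the same size as its input can be checked with the standard size formulas (dilation 1): a Conv2d with kernel_size = 3 and padding = 1 keeps the height and width, MaxPool2d(2, 2) halves them, and a ConvTranspose2d with kernel_size = 2 and stride = 2 doubles them, since its output length is (in - 1) * stride - 2 * padding + kernel_size. A small pure-Python sketch tracing the spatial size through ConvAutoencoder.forward (the helper names are hypothetical):

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of Conv2d / MaxPool2d along one axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose_out(size, kernel, stride=1, padding=0, output_padding=0):
    """Output length of ConvTranspose2d along one axis (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding

def trace_autoencoder(size):
    """Spatial size after each layer of ConvAutoencoder.forward for a size x size input."""
    sizes = [size]
    size = conv_out(size, 3, padding=1)           # conv1
    sizes.append(size)
    size = conv_out(size, 2, stride=2)            # pool
    sizes.append(size)
    size = conv_out(size, 3, padding=1)           # conv2
    sizes.append(size)
    size = conv_out(size, 2, stride=2)            # pool
    sizes.append(size)
    size = conv_transpose_out(size, 2, stride=2)  # t_conv1
    sizes.append(size)
    size = conv_transpose_out(size, 2, stride=2)  # t_conv2
    sizes.append(size)
    return sizes

print(trace_autoencoder(64))  # [64, 64, 32, 32, 16, 32, 64]
print(trace_autoencoder(63))  # [63, 63, 31, 31, 15, 30, 60]
```

The sizes only line up because 64 is divisible by 4; a 63 x 63 input would come back as 60 x 60. Passing an explicit output size, as the deeper network later in this article does with F.max_unpool2d, is one way to pin the output shape.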

Let's train this network.

#******Select a network******
net = ConvAutoencoder()                               
loss_fn = nn.MSELoss()                               # loss function
optimizer = optim.Adam(net.parameters(), lr = 0.01)

losses = []                                          # mean loss of each epoch
epoch_time = 30                                      # number of epochs
for epoch in range(epoch_time):
    running_loss = 0.0                               # sum of the batch losses in this epoch
    net.train()
    for i, (XX, yy) in enumerate(data_loader):
        optimizer.zero_grad()
        y_pred = net(XX)
        loss = loss_fn(y_pred, yy)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print("epoch:", epoch, " loss:", running_loss / (i + 1))
    losses.append(running_loss / (i + 1))

# Visualize the loss
plt.plot(losses)
plt.ylabel("loss")
plt.xlabel("epoch")
plt.savefig("loss_auto")
plt.show()
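In the loop above, nn.MSELoss with its default reduction='mean' averages the squared error over every element of a batch (all images and all pixels), and running_loss / (i + 1) then averages those batch losses over the batches of the epoch. The same arithmetic in NumPy, with small made-up arrays:

```python
import numpy as np

def mse(pred, target):
    """Mean squared error over all elements, like nn.MSELoss(reduction='mean')."""
    return float(np.mean((pred - target) ** 2))

def epoch_loss(batch_losses):
    """Average of the per-batch losses, as computed by running_loss / (i + 1)."""
    return sum(batch_losses) / len(batch_losses)

pred = np.array([[0.5, 1.0], [0.0, 0.0]])
target = np.array([[1.0, 1.0], [0.0, 0.0]])
print(mse(pred, target))         # 0.0625
print(epoch_loss([0.25, 0.75]))  # 0.5
```

Because every batch here holds exactly 100 images, this mean of batch means equals the mean over all 1,000 samples.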

This plot shows the loss for each epoch. After 30 epochs, the loss seems to have converged to some extent. loss_auto.png
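"Converged to some extent" can be made a little more concrete, for example by checking that the loss improved by less than a relative tolerance over the last few epochs. The window of 5 epochs and the 1% tolerance below are arbitrary, illustrative choices, and `has_plateaued` is a hypothetical helper; it could be called on the losses list from the training loop.

```python
def has_plateaued(losses, window=5, rel_tol=0.01):
    """True if the loss improved by at most rel_tol (relative) over the last `window` epochs."""
    if len(losses) <= window:
        return False                    # not enough history to decide
    old, new = losses[-window - 1], losses[-1]
    return old - new <= rel_tol * abs(old)

print(has_plateaued([1.0, 0.5, 0.3, 0.2, 0.2, 0.2, 0.2, 0.2, 0.199]))  # True
print(has_plateaued([1.0, 0.8, 0.6, 0.4, 0.3, 0.2]))                   # False
```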

Next, try an image that was not used for training. The network finds the rough positions of the rectangles, but my impression is that the regions around the boundaries are not captured well. output_auto.png

net.eval()                                   # evaluation mode
# Generate one image that was not used for training
num_images = 1
img_test = np.zeros([num_images, 1, length, length])
for i in range(num_images):
    centers = []
    img = np.zeros([length, length])
    img_ano = np.zeros([length, length])
    for j in range(6):
        img, img_ano, centers = rectangle(img, img_ano, centers, 7)
    img_test[i, 0, :, :] = img

img_test = torch.tensor(img_test, dtype = torch.float32)  # shape (1, 1, 64, 64)
with torch.no_grad():                        # no gradients needed for inference
    img_test = net(img_test)                 # feed the new image to the trained network
img_test = img_test.numpy()                  # torch.tensor -> ndarray
img_test = img_test[0, 0, :, :]

plt.imshow(img, cmap = "gray")               # input data
plt.savefig("input_auto")
plt.show()
plt.imshow(img_test, cmap = "gray")          # network output
plt.savefig("output_auto")
plt.show()
plt.imshow(img_ano, cmap = "gray")           # correct (teacher) data
plt.savefig("correct_auto")
plt.show()
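The impression that the boundaries are not captured well can be quantified. One common measure (not used in the original article) is the intersection over union (IoU) between the thresholded output and the correct image; the 0.5 threshold is an assumption. With the variables from the test code above, it could be called as iou(img_test, img_ano).

```python
import numpy as np

def iou(pred, target, threshold=0.5):
    """Intersection over union of the foreground (value >= threshold) regions."""
    p = pred >= threshold
    t = target >= threshold
    union = np.logical_or(p, t).sum()
    if union == 0:
        return 1.0  # both masks empty: treat as perfect agreement
    return float(np.logical_and(p, t).sum() / union)

pred = np.array([[0.9, 0.8, 0.1],
                 [0.7, 0.2, 0.0]])
target = np.array([[1.0, 1.0, 1.0],
                   [1.0, 0.0, 0.0]])
print(iou(pred, target))  # 0.75
```

Unlike pixel accuracy, IoU ignores the large, easy background area, so errors along the rectangle edges weigh more heavily.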

Deepening the network

The previous model did not perform well enough, so I would like to add more layers. Rather than simply making it deeper, I also add batch normalization, with the aim of preventing overfitting, and upsampling (max unpooling) in the decoder. For a detailed explanation of upsampling, I found this article easy to understand.
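As background for the nn.BatchNorm2d layers used below: in training mode, each channel is normalized with the mean and (biased) variance taken over the batch and both spatial axes, then scaled and shifted by learnable gamma and beta. A NumPy sketch of that computation; it omits the running statistics PyTorch keeps for eval mode, and eps = 1e-5 matches PyTorch's default:

```python
import numpy as np

def batch_norm_2d(x, gamma, beta, eps=1e-5):
    """Training-mode batch norm for x of shape (N, C, H, W); gamma and beta have shape (C,)."""
    mean = x.mean(axis=(0, 2, 3), keepdims=True)  # one mean per channel
    var = x.var(axis=(0, 2, 3), keepdims=True)    # biased variance per channel
    x_hat = (x - mean) / np.sqrt(var + eps)
    return gamma.reshape(1, -1, 1, 1) * x_hat + beta.reshape(1, -1, 1, 1)

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(8, 6, 16, 16))
y = batch_norm_2d(x, gamma=np.ones(6), beta=np.zeros(6))
print(y.mean(axis=(0, 2, 3)))  # ≈ 0 for every channel
print(y.std(axis=(0, 2, 3)))   # ≈ 1 for every channel
```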

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        #encoder
        self.encoder_conv_1 = nn.Sequential(*[
                                            nn.Conv2d(in_channels = 1, 
                                                      out_channels = 6,
                                                      kernel_size = 3,
                                                      padding = 1),
                                            nn.BatchNorm2d(6)
                                            ])
        
        self.encoder_conv_2 = nn.Sequential(*[
                                            nn.Conv2d(in_channels = 6,
                                                      out_channels = 16,
                                                      kernel_size = 3,
                                                      padding = 1),
                                            nn.BatchNorm2d(16)
                                            ])
        self.encoder_conv_3 = nn.Sequential(*[
                                            nn.Conv2d(in_channels = 16,
                                                      out_channels = 32,
                                                      kernel_size = 3,
                                                      padding = 1),
                                            nn.BatchNorm2d(32)
                                            ])
        
        #decoder
        self.decoder_convt_3 = nn.Sequential(*[
                                            nn.ConvTranspose2d(in_channels = 32,
                                                               out_channels = 16,
                                                               kernel_size = 3,
                                                               padding = 1),
                                            nn.BatchNorm2d(16)
                                            ])
        
        self.decoder_convt_2 = nn.Sequential(*[
                                            nn.ConvTranspose2d(in_channels = 16,
                                                               out_channels = 6,
                                                               kernel_size = 3,
                                                               padding = 1),
                                            nn.BatchNorm2d(6)
                                            ])
        
        self.decoder_convt_1 = nn.Sequential(*[
                                            nn.ConvTranspose2d(in_channels = 6,
                                                               out_channels = 1,
                                                               kernel_size = 3,
                                                               padding = 1)
                                            ])
    
    def forward(self, x):
        #encoder
        dim_0 = x.size()                    
        x = F.relu(self.encoder_conv_1(x))                            
        x, indices_1 = F.max_pool2d(x, kernel_size = 2,
                                    stride = 2,
                                    return_indices = True)  # record where each max came from (used by max_unpool2d)
        dim_1 = x.size()
        x = F.relu(self.encoder_conv_2(x))                            
        x, indices_2 = F.max_pool2d(x, kernel_size = 2,
                                    stride = 2, 
                                    return_indices = True)            
        
        dim_2 = x.size()
        x = F.relu(self.encoder_conv_3(x))
        x, indices_3 = F.max_pool2d(x, kernel_size = 2,
                                    stride = 2, 
                                    return_indices = True)
        
        #decoder
        x = F.max_unpool2d(x, indices_3, kernel_size = 2,
                           stride = 2, output_size = dim_2)
        x = F.relu(self.decoder_convt_3(x))
        
        x = F.max_unpool2d(x, indices_2, kernel_size = 2,
                           stride = 2, output_size = dim_1)           
        x = F.relu(self.decoder_convt_2(x))                           
        
        x = F.max_unpool2d(x, indices_1, kernel_size = 2,
                           stride = 2, output_size = dim_0)           
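        # No ReLU on the last layer: sigmoid(relu(z)) can never go below 0.5,
        # so background pixels (target 0) could not be predicted
        x = self.decoder_convt_1(x)
        x = torch.sigmoid(x)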
        
        return x
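The decoder above upsamples with F.max_unpool2d, which puts each value back at the position recorded by the matching F.max_pool2d(..., return_indices = True) and fills every other pixel with zero. A NumPy sketch of the 2 x 2, stride 2 case for a single channel; the function names are hypothetical, and like PyTorch the indices are flat positions within the H x W plane:

```python
import numpy as np

def max_pool_2x2_with_indices(x):
    """2x2 / stride-2 max pooling of an (H, W) array; also returns flat argmax positions."""
    h, w = x.shape
    out = np.zeros((h // 2, w // 2), dtype=x.dtype)
    idx = np.zeros((h // 2, w // 2), dtype=np.int64)
    for i in range(h // 2):
        for j in range(w // 2):
            window = x[2 * i:2 * i + 2, 2 * j:2 * j + 2]
            r, c = np.unravel_index(np.argmax(window), window.shape)
            out[i, j] = window[r, c]
            idx[i, j] = (2 * i + r) * w + (2 * j + c)  # flat index into the H x W input
    return out, idx

def max_unpool_2x2(x, idx, out_shape):
    """Place each pooled value back at its recorded position; all other pixels are 0."""
    out = np.zeros(out_shape, dtype=x.dtype)
    out.flat[idx.ravel()] = x.ravel()
    return out

x = np.array([[1., 5., 2., 0.],
              [3., 4., 8., 6.],
              [0., 2., 1., 1.],
              [7., 1., 3., 9.]])
pooled, idx = max_pool_2x2_with_indices(x)
print(pooled)                                  # [[5. 8.]
                                               #  [7. 9.]]
print(max_unpool_2x2(pooled, idx, x.shape))    # 5, 8, 7, 9 back in place, zeros elsewhere
```

Unlike the transposed convolution of the first model, unpooling restores values at the exact locations where the encoder found its maxima; this reuse of pooling indices is also the idea behind SegNet's decoder.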

Switching to this network is easy. In the training code, the network is selected here:

#******Select a network******
net = ConvAutoencoder()

Just replace it with the newly created class:

#******Select a network******
net = Net()
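To see how much larger the new network is, the trainable parameters can be counted by hand: a Conv2d or ConvTranspose2d layer has in_channels * out_channels * k * k weights plus out_channels biases, and BatchNorm2d has 2 * C learnable parameters (its running mean and variance are buffers, not parameters). A pure-Python tally that should agree with sum(p.numel() for p in net.parameters()) for each class (the helper names are hypothetical):

```python
def conv_params(c_in, c_out, k):
    """Weights plus biases of a Conv2d / ConvTranspose2d layer (groups=1, bias=True)."""
    return c_in * c_out * k * k + c_out

def bn_params(c):
    """Learnable affine parameters (gamma and beta) of BatchNorm2d."""
    return 2 * c

autoencoder = (conv_params(1, 16, 3) + conv_params(16, 4, 3)       # encoder
               + conv_params(4, 16, 2) + conv_params(16, 1, 2))    # decoder
deeper = (conv_params(1, 6, 3) + bn_params(6)
          + conv_params(6, 16, 3) + bn_params(16)
          + conv_params(16, 32, 3) + bn_params(32)
          + conv_params(32, 16, 3) + bn_params(16)
          + conv_params(16, 6, 3) + bn_params(6)
          + conv_params(6, 1, 3))
print(autoencoder, deeper)  # 1077 11281
```

The new network has roughly ten times as many parameters, which is still small enough to train on a laptop CPU.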

Plot how the loss changes over the epochs.

loss_auto.png

Feed in data that was not used for training and compare the output with the correct image. output.png

You can see that the segmentation works.

Conclusion

This time I tried a simple segmentation task. The model is simple and far from practical use, but I feel I have gotten a sense of how semantic segmentation works.
