How to use PyTorch-based image processing library "Kornia"

Introduction

I often use PyTorch as a deep learning framework, and I recently learned about "Kornia", an image processing library built on PyTorch. I looked into its basic features and usage, and I am leaving my notes here as a memorandum.

What is Kornia

Kornia is an open-source computer vision library implemented with PyTorch as its backend. (Kornia GitHub)

It consists of a set of routines and differentiable modules to solve generic computer vision problems. At its core, the package uses PyTorch as its main backend both for efficiency and to take advantage of the reverse-mode auto-differentiation to define and compute the gradient of complex functions.

It implements low-level image processing operations similar to OpenCV's, such as filtering, color conversion, and geometric transformations. Because PyTorch is the backend, these operations also benefit directly from GPU acceleration and automatic differentiation.

Installation and basic usage

Kornia can be installed with pip or similar tools, as described in the README (PyTorch is installed automatically as a dependency):

pip install kornia

You will also need OpenCV, matplotlib, and torchvision to run the tutorials (https://kornia.readthedocs.io/en/latest/tutorials/index.html).

As an example of basic usage, applying a Gaussian blur to an image looks like this.

import cv2
import kornia

# Read the image with OpenCV and convert BGR -> RGB
img_src = cv2.imread('./data/lena.jpg')
img_src = cv2.cvtColor(img_src, cv2.COLOR_BGR2RGB)

# Convert to torch.Tensor (HxWxC uint8 -> 1xCxHxW float, values still in 0-255)
tensor_src = kornia.utils.image_to_tensor(img_src, keepdim=False).float()  # 1xCxHxW

# Gaussian blur: 11x11 kernel, sigma = 10.5 in both directions
gauss = kornia.filters.GaussianBlur2d((11, 11), (10.5, 10.5))
tensor_blur = gauss(tensor_src)

# Convert back to an OpenCV-style image (numpy.ndarray, HxWxC uint8)
img_blur = kornia.utils.tensor_to_image(tensor_blur.byte())

# --> show [img_src | img_blur]
[Figure: original image (left) and Gaussian-blurred image (right)]

In this way, the processing is applied directly to a torch.Tensor. (Incidentally, kornia.filters.GaussianBlur2d inherits from torch.nn.Module.)

Other image processing examples

Examples of other blurring operations and color adjustments are shown below.

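# Module paths follow recent Kornia releases; older versions (e.g. 0.3)
# also exposed these at the top level (kornia.box_blur, kornia.adjust_brightness, ...)

# Box Blur
tensor_blur = kornia.filters.box_blur(tensor_src, (9, 9))

# Median Blur
tensor_blur = kornia.filters.median_blur(tensor_src, (5, 5))

# The color adjustments below expect float images scaled to [0, 1]
tensor_norm = tensor_src / 255.0

# Adjust Brightness
tensor_brightness = kornia.enhance.adjust_brightness(tensor_norm, 0.6)

# Adjust Contrast
tensor_contrast = kornia.enhance.adjust_contrast(tensor_norm, 0.2)

# Adjust Gamma
tensor_gamma = kornia.enhance.adjust_gamma(tensor_norm, gamma=3., gain=1.5)

# Adjust Saturation
tensor_saturated = kornia.enhance.adjust_saturation(tensor_norm, 0.2)

# Adjust Hue (the factor is a shift in radians, within [-pi, pi])
tensor_hue = kornia.enhance.adjust_hue(tensor_norm, 0.5)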

Combination with torch.nn.Sequential

By chaining the operations above in nn.Sequential, image preprocessing can be written concisely, as in the following example. It assumes an environment with a GPU and runs the processing there.

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import kornia

class DummyDataset(Dataset):
    def __init__(self):
        self.data_index = range(100)

    def __len__(self):
        return len(self.data_index)

    def __getitem__(self, idx):
        # generate dummy image and label
        image = torch.rand(3, 240, 320)
        label = torch.randint(5, (1,))
        return image, label

device = torch.device('cuda')

dataset = DummyDataset()
loader = DataLoader(dataset, batch_size=16, shuffle=True)

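# In recent Kornia releases these modules live in kornia.enhance
# (older versions, e.g. 0.3, provided them as kornia.color.AdjustSaturation etc.)
transform = nn.Sequential(
    kornia.enhance.AdjustSaturation(0.2),
    kornia.enhance.AdjustBrightness(0.5),
    kornia.enhance.AdjustContrast(0.7),
)

for i, (images, labels) in enumerate(loader):
    print(f'iter: {i}, images: {images.shape}, labels: {labels.shape}')

    images = images.to(device)  # move the batch to the GPU
    images_tr = transform(images)  # apply the transform on the GPU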

    # training etc ...

Example of using automatic differentiation

As an example of using PyTorch's automatic differentiation, here is an excerpt from the tutorial total_variation_denoising.py (noise removal via total variation).

import cv2
import numpy as np
import torch
import kornia

# read the image with OpenCV, scale to [0, 1] and add Gaussian noise
img: np.ndarray = cv2.imread('./data/doraemon.png')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
img = img + np.random.normal(loc=0.0, scale=0.1, size=img.shape)
img = np.clip(img, 0.0, 1.0)

# convert to torch tensor
noisy_image: torch.Tensor = kornia.utils.image_to_tensor(img).squeeze()  # CxHxW

# define the total variation denoising network
class TVDenoise(torch.nn.Module):
    def __init__(self, noisy_image):
        super().__init__()
        self.l2_term = torch.nn.MSELoss(reduction='mean')
        self.regularization_term = kornia.losses.TotalVariation()
        # create the variable which will be optimized to produce the noise free image
        self.clean_image = torch.nn.Parameter(data=noisy_image.clone(), requires_grad=True)
        self.noisy_image = noisy_image

    def forward(self):
        # .sum() keeps the loss a scalar; some newer Kornia versions return one TV value per channel
        return self.l2_term(self.clean_image, self.noisy_image) + 0.0001 * self.regularization_term(self.clean_image).sum()

    def get_clean_image(self):
        return self.clean_image

tv_denoiser = TVDenoise(noisy_image)

# define the optimizer to optimize the 1 parameter of tv_denoiser
optimizer = torch.optim.SGD(tv_denoiser.parameters(), lr=0.1, momentum=0.9)

# run the optimization loop
num_iters = 500
for i in range(num_iters):
    optimizer.zero_grad()
    loss = tv_denoiser()
    if i % 25 == 0:
        print("Loss in iteration {} of {}: {:.3f}".format(i, num_iters, loss.item()))
    loss.backward()
    optimizer.step()

# convert back to numpy (detach from the autograd graph first)
img_clean: np.ndarray = kornia.utils.tensor_to_image(tv_denoiser.get_clean_image().detach())
[Figure: total variation denoising result]

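Here, clean_image is a torch.nn.Parameter initialized with a copy of noisy_image; it is the tensor the optimizer updates. Kornia's TotalVariation() is used as the regularization term, so the loss balances staying close to the noisy input (MSE) against keeping the image smooth (total variation, weighted by 0.0001).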

Summary

We looked at how to use Kornia, a PyTorch-based image processing library, mainly through its tutorials. It likely offers many more useful functions than the ones covered here. Because its operations can be used not only for preprocessing but also inside a network's forward pass, it may be handy when you need processing that plain torch/torchvision cannot handle on its own.
