Introduction to PyTorch Lightning ~ From tidying up your own model to logging to TensorBoard ~

What this article covers

For people who think "I built my own DNN model, but the code is messy" or "I'm tired of the chores (saving, logging, boilerplate DNN code)":

 * Use PyTorch Lightning, a library that dramatically speeds up AI development
 * Get clean code management, training, and TensorBoard visualization

What is PyTorch Lightning?

A Python library that does exactly the above for you. It is a popular deep learning framework with one of the highest GitHub star counts.

How to use

1. First install

console


$ pip install pytorch-lightning

2. Write your deep learning model following pytorch_lightning's conventions

Inherit from `pytorch_lightning.LightningModule` and define:

 * The network
 * 3 methods: `forward(self, x)`, `training_step(self, batch, batch_idx)`, `configure_optimizers(self)`

 Once you have defined these two things, you can use the model immediately. However, note that **the method names and their arguments cannot be changed**!
 (E.g., if you define `training_step(self, batch)` without `batch_idx` because you don't need it, it will cause a bug.)
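To see why the names and arguments are fixed, it helps to picture a training loop that calls these methods for you. The following is a deliberately simplified, dependency-free sketch of that idea. `ToyTrainer`, `GoodModule`, and `BadModule` are hypothetical names used only for illustration; they are not part of Lightning, whose real `Trainer` does far more.

```python
class ToyTrainer:
    """Hypothetical stand-in for a trainer: it calls hooks by fixed names."""

    def fit(self, module, batches):
        losses = []
        for batch_idx, batch in enumerate(batches):
            # The loop always calls training_step(batch, batch_idx).
            losses.append(module.training_step(batch, batch_idx))
        return losses


class GoodModule:
    def training_step(self, batch, batch_idx):
        # Pretend the "loss" is simply the mean of the batch.
        return sum(batch) / len(batch)


class BadModule:
    # batch_idx was dropped, so the call made by the loop no longer matches.
    def training_step(self, batch):
        return sum(batch) / len(batch)


batches = [[1.0, 3.0], [2.0, 6.0]]
good_losses = ToyTrainer().fit(GoodModule(), batches)

try:
    ToyTrainer().fit(BadModule(), batches)
    bad_error = None
except TypeError as exc:
    bad_error = exc
```

In this toy loop the module with the changed signature fails with a `TypeError` as soon as the first batch is processed, which is the kind of bug the note above warns about.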


#### **`MyModel.py`**
```python
import torch
from torch import nn
from torch.nn import functional as F
from pytorch_lightning import LightningModule


class LitMyModel(LightningModule):

  def __init__(self):
    super().__init__()

    # mnist images are (1, 28, 28) (channels, height, width)
    self.layer_1 = nn.Linear(28 * 28, 128)
    self.layer_2 = nn.Linear(128, 10)

  def forward(self, x):
    batch_size, channels, height, width = x.size()

    # (b, 1, 28, 28) -> (b, 1*28*28)
    x = x.view(batch_size, -1)
    x = self.layer_1(x)
    x = F.relu(x)
    x = self.layer_2(x)

    # log-probabilities, to be paired with nll_loss in training_step
    x = F.log_softmax(x, dim=1)
    return x

  def training_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    loss = F.nll_loss(logits, y)
    return loss

  def configure_optimizers(self):
    # the third required method: return the optimizer used for training
    return torch.optim.Adam(self.parameters(), lr=1e-3)
```

The three methods respectively return the network output, do the work of one training iteration and return the loss, and return the optimizer. Any processing that fulfils those roles is fine.

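**A longer VAE example, for those who want to see one**

#FC (fully connected) example trained on MNIST
import pytorch_lightning as pl
from torch import nn, optim

class LitMyModel(pl.LightningModule):
    def __init__(self, out_size=28 * 28):
        super().__init__()
        # flattened input size; 28 * 28 is an assumed default for MNIST
        self.out_size = out_size
        # layers (the remaining encoder/decoder layers are omitted here)
        self.fc1 = nn.Linear(self.out_size, 400)
        self.fc4 = nn.Linear(400, self.out_size)

    # encode(), reparameterize(), decode() and loss_function() are omitted for brevity

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, self.out_size))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

    def training_step(self, batch, batch_idx):
        recon_batch, mu, logvar = self(batch)
        loss = self.loss_function(
            recon_batch, batch, mu, logvar, out_size=self.out_size)
        return loss

    def configure_optimizers(self):
        # use this module's own parameters, not a global `model`
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer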

Of course, if you already have a model, you only need to move its code over. After that, pass the model and the data loader to `fit()` of `pl.Trainer()` and training starts right away!

runtime


dataloader = ...  # your own DataLoader or DataModule

model = LitMyModel()
trainer = pl.Trainer()
trainer.fit(model, dataloader)

Lightning is easy and awesome.


3. Add other processing as methods of this class

Now that training works as described above, add methods for **testing, validation, and other options** to the class.

test: Just add a `test_step(self, batch, batch_idx)` method to the class. That's all. To run it:

test run time


trainer.test()

validation: This is likewise done just by adding a `validation_step()` method and a `val_dataloader()` method.

dataloader: These can also be defined as methods of the class, but **the recommended approach is to define the Dataset and DataLoader in a separate `MyDataModule` class that inherits from `pytorch_lightning.LightningDataModule`**.

**A longer DataModule example, for those who want to see one**

from torch.utils.data import DataLoader
from pytorch_lightning import LightningDataModule

# download_dataset(), tokenize(), build_vocab(), load_vocab() and load_datasets()
# are placeholders for your own code

class MyDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self.train_dims = None
        self.vocab_size = 0

    def prepare_data(self):
        # called only on 1 GPU
        download_dataset()
        tokenize()
        build_vocab()

    def setup(self, stage=None):
        # called on every GPU; Lightning passes the current stage here
        vocab = load_vocab()
        self.vocab_size = len(vocab)

        self.train, self.val, self.test = load_datasets()
        self.train_dims = self.train.next_batch.size()

    def train_dataloader(self):
        transforms = ...
        return DataLoader(self.train, batch_size=64)

    def val_dataloader(self):
        transforms = ...
        return DataLoader(self.val, batch_size=64)

    def test_dataloader(self):
        transforms = ...
        return DataLoader(self.test, batch_size=64)

If you pass this to `.fit()` when training or testing, Lightning takes the data from it, so you do not need to pass a DataLoader yourself.

runtime


datamodule = MyDataModule()

model = LitMyModel()
trainer = pl.Trainer()
trainer.fit(model, datamodule)

callback: For things like "processing that runs only at the start of training" or "processing that runs at the end of each epoch", the documentation at https://pytorch-lightning.readthedocs.io/en/latest/introduction_guide.html#callbacks has plenty of information. All you need to do is define a method for the timing at which you want your processing to run.

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback

class MyPrintingCallback(Callback):
    # Note: on_init_start / on_init_end are hooks from older Lightning
    # releases; later releases removed them.
    def on_init_start(self, trainer):
        print('Starting to init trainer!')

    def on_init_end(self, trainer):
        print('Trainer is init now')

    def on_train_end(self, trainer, pl_module):
        print('do something when training ends')

trainer = Trainer(callbacks=[MyPrintingCallback()])

Defining this in a separate class keeps the code concise.

4. Connect to TensorBoard and add logging settings

Now for the main topic: saving records. To display numerical values (loss, accuracy, etc.), images, audio, and so on in TensorBoard, with TensorFlow you would write something like this:

tensorflow example


with tf.name_scope('summary'):
  tf.summary.scalar('loss', loss)
  merged = tf.summary.merge_all()
  writer = tf.summary.FileWriter('./logs', sess.graph)

I used to end up with messy code by inserting logging calls for the values I wanted to see in the middle of the training code, but with pytorch_lightning this can be written concisely:

MyModel.py


def training_step(self, batch, batch_idx):
  # ...
  loss = ...
  self.logger.summary.scalar('loss', loss, step=self.global_step)

  # equivalent (TrainResult is the 0.9-era API; assumes `import pytorch_lightning as pl`)
  result = pl.TrainResult(minimize=loss)
  result.log('loss', loss)

  return result

As shown, either write to `logger.summary` inside the method when you want to record something, or wrap the loss you return in a `pytorch_lightning.TrainResult()` object. Do just that and the values are saved automatically to the log directory!

To set the logger, pass it to the constructor of the `Trainer()` class; the save directory is also decided there.

from pytorch_lightning import Trainer
from pytorch_lightning import loggers as pl_loggers

tb_logger = pl_loggers.TensorBoardLogger('logs/')
trainer = Trainer(logger=tb_logger)

You can also save data such as text and images using the `.add_*()` methods (for example `add_text()` or `add_image()`) of the `logger.experiment` object!

MyModel.py


def training_step(...):
  ...
  # the logger you used (in this case tensorboard)
  tensorboard = self.logger.experiment
  tensorboard.add_histogram(...)
  tensorboard.add_figure(...)

The official documentation also recommends doing this kind of logging from a Callback at the appropriate timing.

It's awesome... (it's important, so I said it twice)

In closing

My impression after using PyTorch Lightning: ~~(compared with ignite, whose readability suffers because processing gets inserted everywhere)~~ the rules are easy to understand, and the class design and documentation are well maintained, so I felt it is a deep learning framework I can recommend as the first one to use.
