Introduction à l'utilisation de Pytorch Lightning ~ Jusqu'à ce que vous formatez votre propre modèle et le produisez sur tensorboard ~

Que faire dans cet article

Pour les personnes qui "J'ai fait mon propre modèle DNN mais le code est sale" et "J'en ai marre du travail de bureau (sauvegarde, journalisation, code commun DNN)"

--Avec la bibliothèque d'explosifs de développement IA Pytorch Lightning

Qu'est-ce que Pytorch Lightning?

Une bibliothèque python qui le fait. C'est le numéro d'étoile Github le plus populaire et le cadre d'apprentissage en profondeur populaire.

Comment utiliser

1. Première installation

console


$ pip install pytorch-lightning

2. Rédigez un modèle d'apprentissage en profondeur selon pytorch_lightning

pytorch_lightning.Hériter de LightningModule,



 * Réseau
 * 3 méthodes: forward (self, x), training_step (self, batch, batch_idx), configure_optimizers (self)

 Si vous définissez les deux, vous pouvez l'utiliser immédiatement. Cependant, notez que le ** nom de la fonction et la paire d'arguments ne peuvent pas être modifiés **!
 (Par exemple batch_idx Si vous le définissez comme `` `` training_step (self, batch) '' `` même si vous n'en avez pas besoin, ce sera bogué)


#### **`MyModel.py`**
```python

import torch
from torch.nn import functional as F
from torch import nn
from pytorch_lightning.core.lightning import LightningModule

class LitMyModel(LightningModule):

  def __init__(self):
    super().__init__()

    # mnist images are (1, 28, 28) (channels, width, height)
    self.layer_1 = torch.nn.Linear(28 * 28, 128)
    self.layer_2 = torch.nn.Linear(128, 10)

  def forward(self, x):
    batch_size, channels, width, height = x.size()

    # (b, 1, 28, 28) -> (b, 1*28*28)
    x = x.view(batch_size, -1)
    x = self.layer_1(x)
    x = F.relu(x)
    x = self.layer_2(x)

    x = F.log_softmax(x, dim=1)
    return x

  def training_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    loss = F.nll_loss(logits, y)
    return loss

Chacune des trois fonctions "Sortie réseau de retour" "Fonction de travail en 1 boucle et perte de retour" "Optimiseur de retour" Tout traitement est OK

** Pour ceux qui sont longs mais veulent voir l'exemple VAE (Cliquez) **
#Exemple de FC apprentissage MNIST
import pytorch_lightning as pl

class LitMyModel(pl.LightningModule):
    def __init__(self):
        # layers
        self.fc1 = nn.Linear(self.out_size, 400)
        self.fc4 = nn.Linear(400, self.out_size)

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, self.out_size))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

    def training_step(self, batch, batch_idx):
        recon_batch, mu, logvar = self.forward(batch)
        loss = self.loss_function(
            recon_batch, batch, mu, logvar, out_size=self.out_size)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(model.parameters(), lr=1e-3)
        return optimizer

Bien sûr, si vous avez déjà un modèle, déplacez simplement le code Après cela, mettez le chargeur de données et le modèle dans pl.Trainer () fit () '' et commencez à apprendre!

Durée


dataloader = #Your own dataloader or datamodule

model = LitMyModel()
trainer = pl.Trainer()
trainer.fit(model, dataloader)

éclair Facile et génial.


3. Ajoutez d'autres travaux aux méthodes de cette classe

Maintenant que vous pouvez apprendre tout ce qui précède, ajoutez les méthodes ** test, validation et autres options ** à la classe.

test Ajoutez `` test_step (self, batch, batch_idx) '' à la méthode de classe. Seulement. Exécution

Lors de l'exécution du test


trainer.test()

validation Ceci est également complété par l'ajout de la méthode val_step () '' et de la méthode val_dataloader () '' ~

dataloader Cela peut également être regroupé en méthodes de classe, mais ** Dataset & Data Loader est recommandé d'hériter pytorch_lightning.LightningDataModule d'une autre classe et de définir la classe `` MyDataModule ** ..

** Pour ceux qui sont longs mais veulent voir l'exemple MNIST (Cliquez) **

class MyDataModule(LightningDataModule): def init(self): super().init() self.train_dims = None self.vocab_size = 0

def prepare_data(self):
    # called only on 1 GPU
    download_dataset()
    tokenize()
    build_vocab()

def setup(self):
    # called on every GPU
    vocab = load_vocab()
    self.vocab_size = len(vocab)

    self.train, self.val, self.test = load_datasets()
    self.train_dims = self.train.next_batch.size()

def train_dataloader(self):
    transforms = ...
    return DataLoader(self.train, batch_size=64)

def val_dataloader(self):
    transforms = ...
    return DataLoader(self.val, batch_size=64)

def test_dataloader(self):
    transforms = ...
    return DataLoader(self.test, batch_size=64)
Si vous mordez ceci dans `` .fit () '' au moment de l'apprentissage et des tests, cela sera interprété sans passer data_loader.

Durée


datamodule = MyDataModule()

model = LitMyModel()
trainer = pl.Trainer()
trainer.fit(model, datamodule)

callback Quelque chose comme "processus à effectuer uniquement au début du train" et "processus à effectuer à la fin de l'époque" https://pytorch-lightning.readthedocs.io/en/latest/introduction_guide.html#callbacks Il y a beaucoup d'informations autour. C'est OK si vous définissez une fonction pour le timing que vous souhaitez traiter

from pytorch_lightning.callbacks import Callback

class MyPrintingCallback(Callback):
    def on_init_start(self, trainer):
        print('Starting to init trainer!')

    def on_init_end(self, trainer):
        print('Trainer is init now')

    def on_train_end(self, trainer, pl_module):
        print('do something when training ends')

trainer = Trainer(callbacks=[MyPrintingCallback()])

Si vous le définissez dans une autre classe, vous pouvez l'écrire de manière concise ~

4. Lien avec tensorboard et ajouter des paramètres d'enregistrement

Maintenant, à partir de maintenant, la principale relation de préservation des enregistrements. Pour afficher des valeurs numériques (perte, précision, etc.), des images, des sons, etc. sur le tensorboard

exemple de tensorflow


with tf.name_scope('summary'):
  tf.summary.scalar('loss', loss)
  merged = tf.summary.merge_all()
  writer = tf.summary.FileWriter('./logs', sess.graph)

J'avais tendance à faire du code sale en collant le code que je voulais voir au milieu, mais pytorch_lightning peut être écrit de manière concise,

MyModel.py


def training_step(self, batch, batch_idx):
  # ...
  loss = ...
  self.logger.summary.scalar('loss', loss, step=self.global_step)

  # equivalent
  result = TrainResult()
  result.log('loss', loss)

  return result

Ajoutez à logger.summary dans la méthode lors de l'enregistrement comme, ou ajoutez une fois la partie `` `` return loss à la classe pytorch_lightning.LightningModule.TrainResult () ``. Il suffit de le mordre et il sera automatiquement enregistré dans le répertoire de sauvegarde!

logger est OK si vous l'ajoutez au constructeur de la classe Trainer () '', et le répertoire de stockage est également décidé ici.

from pytorch_lightning import loggers as pl_loggers

tb_logger = pl_loggers.TensorBoardLogger('logs/')
trainer = Trainer(logger=tb_logger)

Vous pouvez également enregistrer des données telles que du texte et des images en utilisant le .add_hogehoge () '' de l'objet `` logger.experiment```!

MyModel.py


def training_step(...):
  ...
  # the logger you used (in this case tensorboard)
  tensorboard = self.logger.experiment
  tensorboard.add_histogram(...)
  tensorboard.add_figure(...)

Le fonctionnaire dit que le moment du rappel est également recommandé.

C'est génial ... (C'est important, alors je vais le dire deux fois

À la fin

Comme impression d'utiliser Pytorch Lightning ~ ~ (par rapport à la mauvaise lisibilité de ignite en raison du traitement inséré) ~ ~ Les règles sont faciles à comprendre, et la conception de la classe et la maintenance des documents étaient correctes, je vais donc les utiliser en premier J'ai senti que c'était un cadre d'apprentissage en profondeur recommandé pour

Recommended Posts

Introduction à l'utilisation de Pytorch Lightning ~ Jusqu'à ce que vous formatez votre propre modèle et le produisez sur tensorboard ~
Créez votre propre exception
Introduction à l'utilisation de Pytorch Lightning ~ Jusqu'à ce que vous formatez votre propre modèle et le produisez sur tensorboard ~
Jusqu'à ce que vous hébergiez vous-même votre propre interprète
Jusqu'à ce que vous obteniez un instantané du service Amazon Elasticsearch et que vous le restauriez
Comment utiliser pyenv et pyenv-virtualenv à votre manière
Comment installer le détecteur Cascade et comment l'utiliser
Introduction à Lightning Pytorch
Comment utiliser Decorator dans Django et comment le créer
Qu'est-ce que pip et comment l'utilisez vous?
[Python] Lorsque vous souhaitez importer et utiliser votre propre package dans le répertoire supérieur
[Introduction à l'application Udemy Python3 +] 36. Utilisation de In et Not
Introduction de DataLiner ver.1.3 et comment utiliser Union Append
[Introduction] Comment utiliser open3d
Comment retourner les données contenues dans le modèle django au format json et les mapper sur le dépliant
Comment utiliser Google Colaboratory et exemple d'utilisation (PyTorch × DCGAN)
[Introduction à Python] Comment utiliser l'opérateur booléen (et ・ ou ・ non)
Comment installer et utiliser Tesseract-OCR
Jusqu'à ce que vous hébergiez vous-même votre propre interprète
Comment utiliser .bash_profile et .bashrc
Comment installer et utiliser Graphviz
De l'introduction de l'API GoogleCloudPlatform Natural Language à son utilisation
Introduction du cadre de cybersécurité "MITRE CALDERA": utilisation et formation
Il est pratique d'utiliser stac_info et exc_info lorsque vous souhaitez afficher la traceback dans la sortie du journal par journalisation.
[Introduction à Python] Comment utiliser la classe en Python?
Comment installer et utiliser pandas_datareader [Python]
[TF] Comment utiliser Tensorboard de Keras
Jusqu'à ce que vous installiez votre propre bibliothèque Python
Comment installer votre propre autorité de certification (racine)
python: Comment utiliser les locals () et globals ()
Bases de PyTorch (1) -Comment utiliser Tensor-
Comment utiliser le zip Python et énumérer
Comment utiliser is et == en Python
Comment utiliser les pandas Timestamp et date_range
[Python] Comment nommer les données de table et les sortir avec csv (méthode to_csv)