[PyTorch] Pourquoi vous pouvez traiter une instance de CrossEntropyLoss () comme une fonction

Instance = fonction? ?? ?? ??

[Apprenez en faisant! Développement Deep Learning par PyTorch](https://www.amazon.co.jp/%E3%81%A4%E3%81%8F%E3%82%8A%E3%81%AA%E3%81%8C% E3% 82% 89% E5% AD% A6% E3% 81% B6% EF% BC% 81PyTorch% E3% 81% AB% E3% 82% 88% E3% 82% 8B% E7% 99% BA% E5% B1% 95% E3% 83% 87% E3% 82% A3% E3% 83% BC% E3% 83% 97% E3% 83% A9% E3% 83% BC% E3% 83% 8B% E3% 83% B3% E3% 82% B0-% E5% B0% 8F% E5% B7% 9D-% E9% 9B% 84% E5% A4% AA% E9% 83% 8E-ebook / dp / B07VPDVNKW) Il y avait une telle description dans 1-3 transfert d'apprentissage. (Vous pouvez voir tout le code sur Author GitHub)

1-3_transfer_learning.ipynb


#Importation de package
import glob
import os.path as osp
import random
import numpy as np
import json
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision
from torchvision import models, transforms

(Omission)

#Paramètres de la fonction de perte
criterion = nn.CrossEntropyLoss()

(Omission)

def train_model(net, dataloaders_dict, criterion, optimizer, num_epochs):

    #boucle d'époque
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch+1, num_epochs))
        print('-------------')

        #Boucle d'apprentissage et de vérification pour chaque époque
        for phase in ['train', 'val']:
            if phase == 'train':
                net.train()  #Mettre le modèle en mode entraînement
            else:
                net.eval()   #Mettre le modèle en mode validation

            epoch_loss = 0.0  #somme des pertes d'époque
            epoch_corrects = 0  #Nombre de bonnes réponses pour l'époque

            #Époque pour vérifier les performances de vérification lorsque non appris=0 formation omise
            if (epoch == 0) and (phase == 'train'):
                continue

            #Boucle pour récupérer un mini-lot à partir du chargeur de données
            for inputs, labels in tqdm(dataloaders_dict[phase]):

                #Initialiser l'optimiseur
                optimizer.zero_grad()

                #Calcul à terme
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = net(inputs)
                    loss = criterion(outputs, labels)  #Calculer la perte
                    _, preds = torch.max(outputs, 1)  #Étiquette de prédiction
                    
  
                    #Propagation du dos pendant l'entraînement
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                    #Calcul des résultats d'italation
                    #Mettre à jour la perte totale
                    epoch_loss += loss.item() * inputs.size(0)  
                    #Mise à jour du nombre total de bonnes réponses
                    epoch_corrects += torch.sum(preds == labels.data)

            #Afficher la perte et le taux de réponse correct pour chaque époque
            epoch_loss = epoch_loss / len(dataloaders_dict[phase].dataset)
            epoch_acc = epoch_corrects.double(
            ) / len(dataloaders_dict[phase].dataset)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

Je voudrais que vous prêtiez attention ici aux ** critères **. Il est défini comme une instance de nn.CrossEntropyLoss () comme suit:

1-3_transfer_learning.ipynb


criterion = nn.CrossEntropyLoss()

Et je traite les ** critères ** comme une fonction.

1-3_transfer_learning.ipynb


loss = criterion(outputs, labels)

Cependant, lorsque je vérifie le code source de torch.nn.CrossEntropyLoss, il n'y a pas de description de __call__ method **! Alors ** pourquoi pouvez-vous traiter une instance de CrossEntropyLoss () comme une fonction? ** ** Le but de cet article est de résoudre ce mystère. Voir ici pour savoir pourquoi la présence ou l'absence de la «méthode call» est importante.

À propos de l'héritage de classe

Le début du Code source de la classe CrossEntropyLoss est écrit comme suit.

Python:torch.nn.modules.loss


class CrossEntropyLoss(_WeightedLoss):

Tout d'abord, qu'est-ce que cela signifie de mettre quelque chose entre parenthèses lors de la définition de class en Python? Ceci est appelé ** héritage de classe **, et est utilisé lors de l'appel d'une fonction ou d'une méthode définie dans une autre classe telle quelle. (L'exemple spécifique suivant est cité ici)

#Héritage
class MyClass:
    def hello(self):
        print("Hello")

class MyClass2(MyClass):
    def world(self):
        print("World")

a = MyClass2()
a.hello() # Hello
a.world() # World

La mise en garde ici est que si une classe parent et une classe enfant ont une méthode avec le même nom défini, la méthode de la classe enfant sera écrasée. C'est ce qu'on appelle un remplacement.

#passer outre
class MyClass:
    def hello(self):
        print("Hello")

class MyClass2(MyClass):
    def hello(self):        #Classe des parents bonjour()Méthode d'écrasement
        print("HELLO")

a = MyClass2()
a.hello()                   # HELLO

Et je veux utiliser la méthode définie dans la classe parent pour la méthode de la classe enfant! Vous pouvez utiliser la fonction super () quand vous y pensez.

class MyClass1:
    def __init__(self):
       self.val1 = 123

class MyClass2(MyClass1):
    def __init__(self):
        super().__init__()
        self.val2 = 456

a = MyClass2()
print(a.val1) # 123
print(a.val2) # 456

Pour revenir à l'histoire, la classe CrossEntropyLoss hérite de la classe _WeightedLoss. Au fait, si vous vérifiez un peu plus le code de CrossEntropyLoss,

Python:torch.nn.modules.loss


class CrossEntropyLoss(_WeightedLoss):

__constants__ = ['ignore_index', 'reduction']
    ignore_index: int

    def __init__(self, weight: Optional[Tensor] = None, size_average=None, ignore_index: int = -100,
                 reduce=None, reduction: str = 'mean') -> None:
        super(CrossEntropyLoss, self).__init__(weight, size_average, reduce, reduction)
        self.ignore_index = ignore_index

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        return F.cross_entropy(input, target, weight=self.weight,
                               ignore_index=self.ignore_index, reduction=self.reduction)

La description est un peu différente de l'exemple ci-dessus car elle est super (CrossEntropyLoss, self), mais [Python official](https://docs.python.org/ja/3/library/functions.html?highlight Si vous faites référence à = super # super), vous pouvez voir que les significations des deux sont exactement les mêmes.

Du fonctionnaire


class C(B):
    def method(self, arg):
        super().method(arg)    # This does the same thing as:
                               # super(C, self).method(arg)

Jetons maintenant un œil à la description de la classe _WeitedLoss.

Python:torch.nn.modules.loss


class _WeightedLoss(_Loss):

De là, nous pouvons voir que _WeitedLoss hérite de _Loss. Jetons maintenant un œil à la description de la classe _WeitedLoss.

Python:torch.nn.modules.loss


class _Loss(Module):

De là, nous pouvons voir que _Loss hérite de Module. Jetons un œil à la description de la classe Module.

Python:torch.nn.modules.module


class Module:

Module n'hérite de rien! Vérifions donc à partir du contenu de Module.

torch.nn.Module

La _Loss class hérite de la __init__ method de la Module class, donc vérifiez ceci uniquement. Je vais essayer.

Python:torch.nn.modules.module


#Remarque:Tous les codes ne sont pas répertoriés
from collections import OrderedDict, namedtuple

class Module:
    _version: int = 1

    training: bool

    dump_patches: bool = False
    
    def __init__(self):
        """
        Initializes internal Module state, shared by both nn.Module and ScriptModule.
        """
        torch._C._log_api_usage_once("python.nn_module")

        self.training = True
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._non_persistent_buffers_set = set()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._state_dict_hooks = OrderedDict()
        self._load_state_dict_pre_hooks = OrderedDict()
        self._modules = OrderedDict()

Vous pouvez voir que de nombreux ʻOrderedDict () sont définis ici. Pour plus d'informations sur ʻOrederedDict () , veuillez vous référer à here. En termes simples, comme son nom l'indique, **" Un dict vide qui est ordonné **. En d'autres termes, cette classe définit simplement un grand nombre de dictionnaires vides.

Et, en fait, la méthode __call__ en question est définie ici!

Python:torch.nn.modules.module


def _call_impl(self, *input, **kwargs):
        for hook in itertools.chain(
                _global_forward_pre_hooks.values(),
                self._forward_pre_hooks.values()):
            result = hook(self, input)
            if result is not None:
                if not isinstance(result, tuple):
                    result = (result,)
                input = result
        if torch._C._get_tracing_state():
            result = self._slow_forward(*input, **kwargs)
        else:
            result = self.forward(*input, **kwargs)
        for hook in itertools.chain(
                _global_forward_hooks.values(),
                self._forward_hooks.values()):
            hook_result = hook(self, input, result)
            if hook_result is not None:
                result = hook_result
        if (len(self._backward_hooks) > 0) or (len(_global_backward_hooks) > 0):
            var = result
            while not isinstance(var, torch.Tensor):
                if isinstance(var, dict):
                    var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
                else:
                    var = var[0]
            grad_fn = var.grad_fn
            if grad_fn is not None:
                for hook in itertools.chain(
                        _global_backward_hooks.values(),
                        self._backward_hooks.values()):
                    wrapper = functools.partial(hook, self)
                    functools.update_wrapper(wrapper, hook)
                    grad_fn.register_hook(wrapper)
        return result

    __call__ : Callable[..., Any] = _call_impl

La dernière ligne __call__: Callable [..., Any] = _call_impl définit le contenu de __call__ sur _call_impl, donc si vous appelez l'instance comme une fonction, la fonction ci-dessus sera exécutée. Si vous ne comprenez pas la signification de «Callable [..., Any]», vous pouvez vous référer à ici. En outre, ce deux-points est une annotation de fonction, voir ici pour plus de détails. En termes simples, il "écrit simplement une expression qui sert d'annotation dans l'argument ou la valeur de retour de la fonction".

Je vais suivre la signification de ce code dans cet article.

En plus de ce qui précède, certaines méthodes sont définies dans la Classe de module, donc vérifiez si nécessaire.

Ce qui suit peut être lu en scannant.

torch.nn._Loss

La _Loss class hérite de la __init__ method de la _Loss class. Je vais vérifier.

Python:torch.nn.modules.loss


class _Loss(Module):
reduction: str

    def __init__(self, size_average=None, reduce=None, reduction: str = 'mean') -> None:
        super(_Loss, self).__init__()
        if size_average is not None or reduce is not None:
            self.reduction = _Reduction.legacy_get_string(size_average, reduce)
        else:
            self.reduction = reduction

Ici, vous pouvez voir que nous introduisons une nouvelle self.reduction. Et cette valeur semble dépendre des valeurs de «size_average» et de «reduction».

torch.nn.__WeightedLoss

La méthode __init__ de la classe _WeightedLoss est héritée de la classe CrossEntropyLoss. Je vais vérifier.

Python:torch.nn.modules.loss


class _WeightedLoss(_Loss):
    def __init__(self, weight: Optional[Tensor] = None, size_average=None, reduce=None, reduction: str = 'mean') -> None:
        super(_WeightedLoss, self).__init__(size_average, reduce, reduction)
        self.register_buffer('weight', weight)

Ici, ʻOptional [Tensor] est spécifié dans l'annotation de fonction de weight`. L'explication de ici est facile à comprendre. En termes simples, «poids» signifie que soit le «type Tensor» soit le «type Aucun» peuvent être inclus.

Revenons au sujet principal. Il y a une nouvelle fonction appelée self.register_buffer ici, qui est une fonction définie dans la classe Module. Voici le code source.

Python:torch.nn.modules.module


forward: Callable[..., Any] = _forward_unimplemented

    def register_buffer(self, name: str, tensor: Optional[Tensor], persistent: bool = True) -> None:
        r"""Adds a buffer to the module.

        This is typically used to register a buffer that should not to be
        considered a model parameter. For example, BatchNorm's ``running_mean``
        is not a parameter, but is part of the module's state. Buffers, by
        default, are persistent and will be saved alongside parameters. This
        behavior can be changed by setting :attr:`persistent` to ``False``. The
        only difference between a persistent buffer and a non-persistent buffer
        is that the latter will not be a part of this module's
        :attr:`state_dict`.

        Buffers can be accessed as attributes using given names.

        Args:
            name (string): name of the buffer. The buffer can be accessed
                from this module using the given name
            tensor (Tensor): buffer to be registered.
            persistent (bool): whether the buffer is part of this module's
                :attr:`state_dict`.

        Example::

            >>> self.register_buffer('running_mean', torch.zeros(num_features))

        """
        if persistent is False and isinstance(self, torch.jit.ScriptModule):
            raise RuntimeError("ScriptModule does not support non-persistent buffers")

        if '_buffers' not in self.__dict__:
            raise AttributeError(
                "cannot assign buffer before Module.__init__() call")
        elif not isinstance(name, torch._six.string_classes):
            raise TypeError("buffer name should be a string. "
                            "Got {}".format(torch.typename(name)))
        elif '.' in name:
            raise KeyError("buffer name can't contain \".\"")
        elif name == '':
            raise KeyError("buffer name can't be empty string \"\"")
        elif hasattr(self, name) and name not in self._buffers:
            raise KeyError("attribute '{}' already exists".format(name))
        elif tensor is not None and not isinstance(tensor, torch.Tensor):
            raise TypeError("cannot assign '{}' object to buffer '{}' "
                            "(torch Tensor or None required)"
                            .format(torch.typename(tensor), name))
        else:
            self._buffers[name] = tensor
            if persistent:
                self._non_persistent_buffers_set.discard(name)
            else:
                self._non_persistent_buffers_set.add(name)

C'est un code assez long, mais la moitié supérieure est l'explication du code, et la partie ci-dessus ʻelse de ʻif instruction ne définit l'erreur, donc l'explication est omise. Et dans ʻelse, vous mettez des éléments dans self._buffersdedict type. En d'autres termes, en définissant la classe WeightedLoss`, nous avons:

self._buffer = {'weight': weight} #Le poids à droite est de type Tensor ou None

torch.nn.CrossEntropyLoss Enfin, je suis revenu à la question. Voici le code source. Il y a un long commentaire, mais je vais tous les citer.

Python:torch.nn.modules.loss


class CrossEntropyLoss(_WeightedLoss):
    r"""This criterion combines :func:`nn.LogSoftmax` and :func:`nn.NLLLoss` in one single class.

    It is useful when training a classification problem with `C` classes.
    If provided, the optional argument :attr:`weight` should be a 1D `Tensor`
    assigning weight to each of the classes.
    This is particularly useful when you have an unbalanced training set.

    The `input` is expected to contain raw, unnormalized scores for each class.

    `input` has to be a Tensor of size either :math:`(minibatch, C)` or
    :math:`(minibatch, C, d_1, d_2, ..., d_K)`
    with :math:`K \geq 1` for the `K`-dimensional case (described later).

    This criterion expects a class index in the range :math:`[0, C-1]` as the
    `target` for each value of a 1D tensor of size `minibatch`; if `ignore_index`
    is specified, this criterion also accepts this class index (this index may not
    necessarily be in the class range).

    The loss can be described as:

    .. math::
        \text{loss}(x, class) = -\log\left(\frac{\exp(x[class])}{\sum_j \exp(x[j])}\right)
                       = -x[class] + \log\left(\sum_j \exp(x[j])\right)

    or in the case of the :attr:`weight` argument being specified:

    .. math::
        \text{loss}(x, class) = weight[class] \left(-x[class] + \log\left(\sum_j \exp(x[j])\right)\right)

    The losses are averaged across observations for each minibatch. If the
    :attr:`weight` argument is specified then this is a weighted average:

    .. math::
        \text{loss} = \frac{\sum^{N}_{i=1} loss(i, class[i])}{\sum^{N}_{i=1} weight[class[i]]}

    Can also be used for higher dimension inputs, such as 2D images, by providing
    an input of size :math:`(minibatch, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1`,
    where :math:`K` is the number of dimensions, and a target of appropriate shape
    (see below).


    Args:
        weight (Tensor, optional): a manual rescaling weight given to each class.
            If given, has to be a Tensor of size `C`
        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
            the losses are averaged over each loss element in the batch. Note that for
            some losses, there are multiple elements per sample. If the field :attr:`size_average`
            is set to ``False``, the losses are instead summed for each minibatch. Ignored
            when reduce is ``False``. Default: ``True``
        ignore_index (int, optional): Specifies a target value that is ignored
            and does not contribute to the input gradient. When :attr:`size_average` is
            ``True``, the loss is averaged over non-ignored targets.
        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
            losses are averaged or summed over observations for each minibatch depending
            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
            batch element instead and ignores :attr:`size_average`. Default: ``True``
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will
            be applied, ``'mean'``: the weighted mean of the output is taken,
            ``'sum'``: the output will be summed. Note: :attr:`size_average`
            and :attr:`reduce` are in the process of being deprecated, and in
            the meantime, specifying either of those two args will override
            :attr:`reduction`. Default: ``'mean'``

    Shape:
        - Input: :math:`(N, C)` where `C = number of classes`, or
          :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1`
          in the case of `K`-dimensional loss.
        - Target: :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1`, or
          :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of
          K-dimensional loss.
        - Output: scalar.
          If :attr:`reduction` is ``'none'``, then the same size as the target:
          :math:`(N)`, or
          :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case
          of K-dimensional loss.

    Examples::

        >>> loss = nn.CrossEntropyLoss()
        >>> input = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.empty(3, dtype=torch.long).random_(5)
        >>> output = loss(input, target)
        >>> output.backward()
    """
    __constants__ = ['ignore_index', 'reduction']
    ignore_index: int

    def __init__(self, weight: Optional[Tensor] = None, size_average=None, ignore_index: int = -100,
                 reduce=None, reduction: str = 'mean') -> None:
        super(CrossEntropyLoss, self).__init__(weight, size_average, reduce, reduction)
        self.ignore_index = ignore_index

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        return F.cross_entropy(input, target, weight=self.weight,
                               ignore_index=self.ignore_index, reduction=self.reduction)

Tout d'abord, dans la méthode «init», une nouvelle variable appelée «self.ignore_index» a été ajoutée. Et une fonction appelée forward () est également définie. Cependant, la méthode __call__ n'a pas été définie depuis la classe Module. Par conséquent, la «méthode call» de la «classe Module» était l'identité que l'instance de la «classe CrossEntropyLoss» était utilisée comme une fonction.

Dans cet article, j'aimerais voir de plus près ce qui se passe lorsque vous traitez une instance de CrossEntropyLoss () comme une fonction!

Recommended Posts

[PyTorch] Pourquoi vous pouvez traiter une instance de CrossEntropyLoss () comme une fonction
[PyTorch] Un peu de compréhension de CrossEntropyLoss avec des formules mathématiques
Créer une instance d'une classe prédéfinie à partir d'une chaîne en Python
[Road to Intermediate Python] Appelez une instance de classe comme une fonction avec __call__
Si vous donnez une liste avec l'argument par défaut de la fonction ...
Utilisation de lambda (lors du passage d'une fonction comme argument d'une autre fonction)