[PyTorch] Un peu de compréhension de CrossEntropyLoss avec des formules mathématiques

introduction

Parce que j'utilise souvent `` critère = torch.nn.CrossEntropyLoss () '' comme base de la fonction de perte de Pytorch. Il est produit pour comprendre les détails. Si vous faites une erreur, faites-le moi savoir.

CrossEntropyLoss Exemple Pytorch (1)

torch.manual_seed(42) #Graine fixe pour maintenir la reproductibilité
loss = nn.CrossEntropyLoss()
input_num = torch.randn(1, 5, requires_grad=True)
target = torch.empty(1, dtype=torch.long).random_(5)
print('input_num:',input_num)
print('target:',target)
output = loss(input_num, target)
print('output:',output)
input_num: tensor([[ 0.3367,  0.1288,  0.2345,  0.2303, -1.1229]], requires_grad=True)
target: tensor([0])
output: tensor(1.3472, grad_fn=<NllLossBackward>)

En supposant que la classe de réponse correcte est $ class $ et que le nombre de classes est $ n $, l'erreur $ loss $ de CrossEntropyLoss peut être exprimée par la formule suivante.

loss=-\log(\frac{\exp(x[class])}{\sum_{j=0}^{n} \exp(x[j])}) \\
=-(\log(\exp(x[class])- \log(\sum_{j=0}^{n} \exp(x[j])) \\
=-\log(\exp(x[class]))+\log(\sum_{j=0}^{n} \exp(x[j])) \\
=-x[class]+\log(\sum_{j=0}^{n} \exp(x[j])) \\

À partir de l'exemple de code source, la classe de réponse correcte est $ class = 0 $ et le nombre de classes est $ n = 5 $, donc si vous la cochez

loss=-x[0]+\log(\sum_{j=0}^{5} \exp(x[j]))\\
=-x[0]+\log(\exp(x[0])+\exp(x[1])+\exp(x[2])+\exp(x[3])+\exp(x[4])) \\
= -0.3367 + \log(\exp(0.3367)+\exp(0.1288)+\exp(0.2345)+\exp(0.2303)+\exp(-1.1229)) \\
= 1.34717 \cdots \\
\fallingdotseq 1.34712

Il correspondait au résultat du programme en toute sécurité! Au fait, le calcul se fait avec le code suivant (le calcul manuel est impossible ...)

from math import exp, log
x_sum = exp(0.3367)+exp( 0.1288)+exp(0.2345)+exp(0.2303)+exp(-1.1229)
x = 0.3367
ans = -x + log(x_sum)
print(ans) # 1.3471717976017477

C'est une poussée.

à la fin

L'erreur d'arrondi (génération de fractions circulaires due à l'affichage binaire des points décimaux) semble désormais inutile. Normalement, c'est random.seed (42) '', mais avec Pytorch c'est torch.manual_seed (42) '', donc c'est comme ça.

Les références

(1)TORCH.NN

Recommended Posts

[PyTorch] Un peu de compréhension de CrossEntropyLoss avec des formules mathématiques
LiNGAM (version ICA) à comprendre avec des formules mathématiques et Python
Prédiction de la moyenne Nikkei avec Pytorch 2
Prédiction de la moyenne Nikkei avec Pytorch
Un peu coincé dans le chainer
Prédiction de la moyenne Nikkei avec Pytorch ~ Makuma ~
Une collection de conseils pour accélérer l'apprentissage et le raisonnement avec PyTorch
[PyTorch] Pourquoi vous pouvez traiter une instance de CrossEntropyLoss () comme une fonction
Classification multi-étiquette d'images multi-classes avec pytorch
Une petite introduction de fonction de niche de faiss
Essayez une formule utilisant Σ avec python
Compréhension complète de la programmation asynchrone Python
Une compréhension approximative de python-fire et un mémo
Un petit examen minutieux de Pandas 1.0 et Dask
Somme des variables dans un modèle mathématique
Mémorandum sur le QueryDict de Django
Créez un quiz de dessin avec kivy + PyTorch
Compréhension complète de la programmation orientée objet de Python
Mémorandum de migration avec GORM
[AtCoder] Résoudre un problème de ABC101 ~ 169 avec Python
Résolvez A ~ D du codeur yuki 247 avec python
Histoire d'essayer d'utiliser Tensorboard avec Pytorch
Obtenir une liste d'utilisateurs IAM avec Boto3
Planification des tâches un peu avancée avec AP Schuler
[Python] Une compréhension approximative du module de journalisation
Flux de création d'un environnement virtuel avec Anaconda
[PyTorch] J'étais un peu perdu dans torch.max ()
[Python] Une compréhension approximative des itérables, des itérateurs et des générateurs
Créer une table avec le notebook IPython
C'était un peu difficile de faire flacon avec la version docker de nginx-unit
Article qui vous aidera à comprendre un peu l'algorithme de collision de sphères rigides