Trouvez le différentiel du second ordre avec la différenciation automatique JAX

introduction

La différenciation automatique est pratique lorsque vous souhaitez rechercher la différenciation d'une fonction compliquée. A cette époque, j'utilisais la différenciation automatique de Pytorch. Cependant, je voulais n'utiliser que la différenciation automatique, mais le paquet pytorch est assez lourd, donc quand je cherchais un paquet plus léger, je suis arrivé à JAX. ..

Qu'est-ce que JAX?

Officiel image.png Une version mise à jour de Autograd (non actuellement maintenue). Vous pouvez utiliser le GPU pour calculer la différenciation automatique à grande vitesse (bien sûr, cela fonctionne également sur le CPU).

Comment installer

CPC only version

pip install --upgrade pip
pip install --upgrade jax jaxlib  # CPU-only version

Si vous souhaitez utiliser GPU, veuillez vous reporter aux instructions d'installation de pip dans Officiel.

Différenciation de second ordre dans JAX

Maintenant, essayons de trouver la différenciation de second ordre de la fonction de journal.

import jax.numpy as jnp
from jax import grad

#définition de la fonction log
fn = lambda x0: jnp.log(x0)

# x =Différencier autour de 1
x = 1

#Substitution
y0 = fn(x)
#Différenciation unique
y1 = grad(fn)(x)
#Différentiel de second ordre
y2 = grad(grad(fn))(x)

Résultat d'exécution


>>> float(y0), float(y1), float(y2)
(0.0, 1.0, -1.0)

finalement

Avec JAX, vous pouvez facilement et facilement utiliser la différenciation automatique.

Recommended Posts

Trouvez le différentiel du second ordre avec la différenciation automatique JAX
Trouvez la distance d'édition (distance de Levenshtein) avec python
Trouvez la valeur SHA256 avec R (avec bonus)
La deuxième nuit de la boucle avec pour
Trouvez la valeur de l'humeur avec python (Rike Koi)
Trouvez la position au-dessus du seuil avec NumPy
Prédisez le deuxième tour de l'été 2016 avec scikit-learn
Découvrez le jour par date / heure
Trouvez l'itinéraire le plus court avec l'algorithme de Python Dijkstra
Calculer la somme des valeurs uniques par tabulation croisée des pandas
Découvrez l'emplacement des packages installés avec pip