Automatic differentiation is convenient when you want to find the derivative of a complicated function. I used to use PyTorch's automatic differentiation for this, but when all you need is automatic differentiation the PyTorch package is quite heavy, so while looking for a lighter package I came across JAX.
According to the official description, JAX is an updated version of Autograd (which is no longer actively maintained). It can compute automatic differentiation at high speed on the GPU (and of course it also works on the CPU).
CPU-only version
pip install --upgrade pip
pip install --upgrade jax jaxlib # CPU-only version
If you want to use the GPU, please refer to the pip installation guide on the official site.
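Once the install finishes, a quick sanity check like the following sketch (my own addition, not from the original article) confirms that JAX is importable and shows which devices it sees.
import jax
import jax.numpy as jnp
print(jax.__version__)   # installed JAX version
print(jax.devices())     # e.g. [CpuDevice(id=0)] for the CPU-only build
print(jnp.log(jnp.e))    # should print roughly 1.0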
Let's try to find the second derivative of the log function.
import jax.numpy as jnp
from jax import grad
# Define the log function
fn = lambda x0: jnp.log(x0)
# Differentiate around x = 1 (grad requires a float input, so use 1.0 rather than the integer 1)
x = 1.0
# Function value
y0 = fn(x)
# First derivative
y1 = grad(fn)(x)
# Second derivative
y2 = grad(grad(fn))(x)
Execution result
>>> float(y0), float(y1), float(y2)
(0.0, 1.0, -1.0)
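These values match the analytic derivatives d/dx log(x) = 1/x and d²/dx² log(x) = -1/x² at x = 1. As a further check, here is a short sketch (my own addition, not from the original article) that evaluates both derivatives at a few points at once using jax.vmap and compares them with the analytic formulas.
import jax.numpy as jnp
from jax import grad, vmap
fn = lambda x0: jnp.log(x0)
xs = jnp.array([0.5, 1.0, 2.0])
first = vmap(grad(fn))(xs)           # automatic first derivative at each point
second = vmap(grad(grad(fn)))(xs)    # automatic second derivative at each point
print(first, 1.0 / xs)               # both roughly [2.0, 1.0, 0.5]
print(second, -1.0 / xs**2)          # both roughly [-4.0, -1.0, -0.25]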
With JAX, you can use automatic differentiation quickly and easily.