Hello JAX

JAX is quickly becoming a popular framework for training deep learning models. I spent some time playing around with it last year, so I decided to share some basics I learned in case someone finds them useful. At its core, JAX is a numerical computation library, like NumPy but with support for accelerators such as GPUs and TPUs, and with robust support for automatic differentiation (Autograd). In this tutorial, we’ll implement logistic regression from scratch in JAX using Autograd, as a very basic introduction to the library. The GIF below shows what we’ll be implementing today. ...
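As a taste of the idea, here is a minimal sketch of logistic regression trained with `jax.grad`: the loss is a plain Python function, and JAX differentiates it automatically. The names (`loss`, `W`, `b`), the toy data, and the learning rate are all illustrative, not taken from the tutorial itself.

```python
import jax
import jax.numpy as jnp

def loss(params, X, y):
    # Logistic regression: sigmoid of a linear model + binary cross-entropy.
    W, b = params
    p = jax.nn.sigmoid(X @ W + b)
    return -jnp.mean(y * jnp.log(p) + (1 - y) * jnp.log(1 - p))

# Tiny toy dataset (illustrative): y is the logical OR of the two features.
X = jnp.array([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]])
y = jnp.array([1.0, 1.0, 1.0, 0.0])

params = (jnp.zeros(2), 0.0)      # (weights W, bias b)
grad_fn = jax.grad(loss)          # Autograd: gradient w.r.t. params

for _ in range(100):
    grads = grad_fn(params, X, y)
    # Gradient-descent step applied leaf-wise over the params pytree.
    params = jax.tree_util.tree_map(lambda p, g: p - 0.5 * g, params, grads)
```

The key point is that `grad_fn` is just another Python function: JAX traces `loss` and returns its gradient, with no manual derivative code.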

October 30, 2022 · 5 min · Gagan Madan