1.1.1 JAX是什么
JAX官方文档的解释是:“JAX是CPU、GPU和TPU上的NumPy,具有出色的自动差异化功能,可用于高性能机器学习研究。”
就像JAX官方文档解释的那样,最简单的JAX是加速器支持的NumPy,它具有一些便利的功能,可用于常见的机器学习操作。
更具体地说,JAX的前身是Autograd,也就是Autograd的升级版本。JAX可以对Python和NumPy程序进行自动微分,可通过Python的大量特征子集进行区分,包括循环、分支、递归和闭包语句进行自动求导,也可以求三阶导数(三阶导数是由原函数导数的导数的导数,即将原函数进行三次求导)。通过grad,JAX完全支持反向模式和正向模式的求导,而且这两种模式可以任意组合成任何顺序,具有一定灵活性。
开发JAX的出发点是什么?说到这,就不得不提NumPy。NumPy是Python中的一个基础数值运算库,被广泛使用,但是NumPy不支持GPU或其他硬件加速器,也没有对反向传播的内置支持。此外,Python本身的速度限制了NumPy使用,所以少有研究者在生产环境下直接用NumPy训练或部署深度学习模型。
在此情况下,出现了众多的深度学习框架,如PyTorch、TensorFlow等。但是NumPy具有灵活、调试方便、API稳定等独特的优势,而JAX的主要出发点就是将NumPy的优势与硬件加速相结合。
目前,基于JAX已有很多优秀的开源项目,如谷歌的神经网络库团队开发了Haiku,这是一个面向JAX的深度学习代码库,通过Haiku,用户可以在JAX上进行面向对象开发;又比如RLax,这是一个基于JAX的强化学习库,用户使用RLax就能进行Q-learning模型的搭建和训练;此外还包括基于JAX的深度学习库JAXnet,该库一行代码就能定义计算图,可进行GPU加速。可以说,在过去几年中,JAX掀起了深度学习研究的风暴,推动了其相关科学研究的迅速发展。