这几天各大科技媒体都在唱衰TensorFlow,鼓吹JAX。恰好前两个月我都在用JAX,算是从JAX新人进阶为小白,过来吹吹牛。
吃瓜群众都在疯狂吐槽TensorFlow的API多混乱,PyTorch多好用,但是好像,并没有多少人真正说到JAX。
JAX到底是啥?简单说,JAX是一种自动微分的NumPy。所以JAX并不是一个深度学习框架,而是一个科学计算框架。深度学习是JAX功能的一个子集。
既然是NumPy,那就可以用NumPy接口做各类科学计算。
而且还带自动微分,科学计算世界中,微分是最常用的一种计算。JAX的自动微分包含了前向微分、反向微分等各种接口。反正各类花式微分,几乎都可以用JAX实现。
除了"NumPy" + "自动微分",JAX还有几个其他的功能:
将NumPy接口写的计算转成高效的二进制代码,可以在CPU/GPU/TPU上获得极高加速比。JIT编译主要还是基于XLA(accelerated linear algebra)。XLA是一种编译器,可以将TF/JAX的代码在CPU/GPU/TPU上加速。
说到JAX速度快,主要就靠XLA!
比起简单的NumPy,JAX提供了大量接口做并行。无论是tf还是torch,一个简单的并行方法是:batch size。JAX用 vmap
做并行, 用户只用实现一条数据的处理,JAX帮我们将做拓展,可以拓展到batch size大小。vmap
的思想与 Spark 中的 map
一样。用户关注 map
里面的一条数据的处理方法,JAX 帮我们做并行化。
到这就不得不提JAX的函数式编程。函数式编程相对“面向对象”(Object Oriented)就难很多了。毕竟,绝大多数中国程序员都没有系统学习过函数式编程。
JAX是纯函数式的。
第一让人不适应的就是数据的不可变(Immutable)。不能原地改数据,只能创建新数据。
第二就是各类闭包。“闭包”这个名字就很抽象,更不用说真正写起来了。
然后就是partial
。
这些东西在torch用户那里可能一辈子都用不到。
来到JAX世界,你都会怀疑自己到底学没学过Python。
JAX并不是一个深度学习框架。想要做深度学习,还要再在JAX上套一层。
要想在JAX上实现一个全连接网络,要 np.dot(w, x) + b
。竟然没有现成的 nn.Dense
或者 nn.Linear
。
于是有了DeepMind的 haiku ,Google的 flax,和其他各种各样的库。
JAX是纯函数的,代码写起来和tf、torch也不太一样。没有了 .fit()
这样傻瓜式的接口,没有 MSELoss
这样的损失函数。而且要适应数据的不可变:模型参数先初始化init
,才能使用。
不过,flax 和 haiku 也有不少市场了。大名鼎鼎的AlphaFold就是用 haiku 写的。
JAX到底好不好我不敢说。但是大家都在学它。看看PyTorch刚发布的 torchfunc
,里面的vmap
就是学得JAX。还有各个框架都开始提供的前向微分 jvp,都是JAX的影子。