linuxsir首页 LinuxSir.Org | Linux、BSD、Solaris、Unix | 开源传万世,因有我参与欢迎您!
网站首页 | 设为首页 | 加入收藏
您所在的位置:主页 > 小企鹅新闻图书馆 >

谷歌开源TensorFlow的简化库JAX

时间:2018-12-17  来源:未知  作者:admin666

谷歌开源了一个 TensorFlow 的简化库 JAX。


JAX 结合了 Autograd 和 XLA,专门用于高性能机器学习研究。

凭借 Autograd,JAX 可以求导循环、分支、递归和闭包函数,并且它可以进行三阶求导。通过 grad,它支持自动模式反向求导(反向传播)和正向求导,且二者可以任何顺序任意组合。

得力于 XLA,可以在 GPU 和 TPU 上编译和运行 NumPy 程序。默认情况下,编译发生在底层,库调用实时编译和执行。但是 JAX 还允许使用单一函数 API jit 将 Python 函数及时编译为 XLA 优化的内核。编译和自动求导可以任意组合,因此可以在 Python 环境下实现复杂的算法并获得最大的性能。

demo:

import jax.numpy as np
from jax import grad, jit, vmap
from functools import partial
def predict(params, inputs):
 for W, b in params:
 outputs = np.dot(inputs, W) + b
 inputs = np.tanh(outputs)
 return outputs
def logprob_fun(params, inputs, targets):
 preds = predict(params, inputs)
 return np.sum((preds - targets)**2)
grad_fun = jit(grad(logprob_fun)) # compiled gradient evaluation function
perex_grads = jit(vmap(grad_fun, in_axes=(None, 0, 0))) # fast per-example grads

更深入地看,JAX 实际上是一个可扩展的可组合函数转换系统,grad 和 jit 都是这种转换的实例。

项目地址:https://github.com/google/JAX

友情链接