JAX

维基百科,自由的百科全书
跳转到导航 跳转到搜索
JAX
File:Google JAX logo.svg
开发者Google, Nvidia[1]
首次发布2019年10月31日,​6年前​(2019-10-31[2]
当前版本
    Module:EditAtWikidata第29行Lua错误:attempt to index field 'wikibase' (a nil value)
    源代码库github.com/jax-ml/jax
    编程语言Python, C++
    引擎
      Module:EditAtWikidata第29行Lua错误:attempt to index field 'wikibase' (a nil value)
      操作系统Linux, macOS, Windows
      平台Python, NumPy
      类型机器学习
      许可协议Apache 2.0
      网站docs.jax.dev/en/latest/

      JAX,由Google开发、并由Nvidia做出部分贡献[3][4][5]Python机器学习框架,用于变换数值函数。JAX结合了修改版的Autograd自动微分系统[6],以及来自OpenXLA专案的编译器XLA英语Accelerated Linear Algebra[7],可加速数值线性运算。其设计目标是在界面与程式设计风格上尽可能与NumPy保持相容,使使用者能够以熟悉的方式撰写高效能运算程式。此外,JAX亦可与TensorFlowPyTorch等机器学习框架整合使用。[8][9]

      主要功能[编辑]

      JAX的主要功能是[3]

      grad[编辑]

      下面的代码演示grad函数的自动微分。

      # 导入库
      from jax import grad
      import jax.numpy as jnp
      
      # 定义logistic函数
      def logistic(x):  
          return jnp.exp(x) / (jnp.exp(x) + 1)
      
      # 获得logistic函数的梯度函数
      grad_logistic = grad(logistic)
      
      # 求值logistic函数在x = 1处的梯度 
      grad_log_out = grad_logistic(1.0)   
      print(grad_log_out)
      

      最终的输出为:

      0.19661194
      

      jit[编辑]

      下面的代码演示jit函数的优化。

      # 导入库
      from jax import jit
      import jax.numpy as jnp
      
      # 定义cube函数
      def cube(x):
          return x * x * x
      
      # 生成数据
      x = jnp.ones((10000, 10000))
      
      # 创建cube函数的jit版本
      jit_cube = jit(cube)
      
      # 应用cube函数和jit_cube函数于相同数据来比较其速度
      cube(x)
      jit_cube(x)
      

      可见jit_cube的运行时间显著的短于cube

      vmap[编辑]

      下面的代码展示vmap函数的通过SIMD的向量化。

      # 导入库
      from functools import partial
      from jax import vmap
      import jax.numpy as jnp
      
      # 定义函数
      def grads(self, inputs):
          in_grad_partial = partial(self._net_grads, self._net_params)
          grad_vmap = vmap(in_grad_partial)
          rich_grads = grad_vmap(inputs)
          flat_grads = np.asarray(self._flatten_batch(rich_grads))
          assert flat_grads.ndim == 2 and flat_grads.shape[0] == inputs.shape[0]
          return flat_grads
      

      使用JAX的库[编辑]

      一些Python库使用JAX作为后端,这包括:

      参见[编辑]

      引用[编辑]

      1. ^ jax/AUTHORS at main · jax-ml/jax. GitHub. [December 21, 2024]. 
      2. ^ jax-v0.1.49. 
      3. ^ 3.0 3.1 Bradbury, James; Frostig, Roy; Hawkins, Peter; Johnson, Matthew James; Leary, Chris; MacLaurin, Dougal; Necula, George; Paszke, Adam; Vanderplas, Jake; Wanderman-Milne, Skye; Zhang, Qiao, JAX: Autograd and XLA, Astrophysics Source Code Library (Google), 2022-06-18 [2022-06-18], Bibcode:2021ascl.soft11002B, (原始内容存档于2022-06-18) 
      4. ^ Frostig, Roy; Johnson, Matthew James; Leary, Chris. Compiling machine learning programs via high-level tracing (PDF). MLsys. 2018-02-02: 1–3. (原始内容存档 (PDF)于2022-06-21). 
      5. ^ Using JAX to accelerate our research. www.deepmind.com. [2022-06-18]. (原始内容存档于2022-06-18) (English). 
      6. ^ autograd. [2023-09-23]. (原始内容存档于2022-07-18). 
      7. ^ XLA. [2023-09-23]. (原始内容存档于2022-09-01). 
      8. ^ Lynley, Matthew. Google is quietly replacing the backbone of its AI product strategy after its last big push for dominance got overshadowed by Meta. Business Insider. [2022-06-21]. (原始内容存档于2022-06-21) (en-US). 
      9. ^ Why is Google's JAX so popular?. Analytics India Magazine. 2022-04-25 [2022-06-18]. (原始内容存档于2022-06-18) (en-US). 
      10. ^ Flax: A neural network library and ecosystem for JAX designed for flexibility, Google, 2022-07-29 [2022-07-29], (原始内容存档于2022-09-03) 
      11. ^ Kidger, Patrick, Equinox, 2022-07-29 [2022-07-29], (原始内容存档于2023-09-19) 
      12. ^ Kidger, Patrick, Diffrax, 2023-08-05 [2023-08-08], (原始内容存档于2023-08-10) 
      13. ^ Optax, DeepMind, 2022-07-28 [2022-07-29], (原始内容存档于2023-06-07) 
      14. ^ Lineax, Google, 2023-08-08 [2023-08-08], (原始内容存档于2023-08-10) 
      15. ^ RLax, DeepMind, 2022-07-29 [2022-07-29], (原始内容存档于2023-04-26) 
      16. ^ Jraph - A library for graph neural networks in jax., DeepMind, 2023-08-08 [2023-08-08], (原始内容存档于2022-11-23) 
      17. ^ jaxtyping, Google, 2023-08-08 [2023-08-08], (原始内容存档于2023-08-10) 
      18. ^ NumPyro - Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU. [2022-08-31]. (原始内容存档于2022-08-31). 
      19. ^ Brax - Massively parallel rigidbody physics simulation on accelerator hardware. [2022-08-31]. (原始内容存档于2022-08-31). 

      外部链接[编辑]