JAX

維基百科,自由的百科全書
跳至導覽 跳至搜尋
JAX
File:Google JAX logo.svg
開發者Google, Nvidia[1]
首次發布2019年10月31日,​6年前​(2019-10-31[2]
當前版本
    Module:EditAtWikidata第29行Lua錯誤:attempt to index field 'wikibase' (a nil value)
    原始碼庫github.com/jax-ml/jax
    程式語言Python, C++
    引擎
      Module:EditAtWikidata第29行Lua錯誤:attempt to index field 'wikibase' (a nil value)
      作業系統Linux, macOS, Windows
      平台Python, NumPy
      類型機器學習
      許可協議Apache 2.0
      網站docs.jax.dev/en/latest/

      JAX,由Google開發、並由Nvidia做出部分貢獻[3][4][5]Python機器學習框架,用於變換數值函數。JAX結合了修改版的Autograd自動微分系統[6],以及來自OpenXLA專案的編譯器XLA英語Accelerated Linear Algebra[7],可加速數值線性運算。其設計目標是在介面與程式設計風格上盡可能與NumPy保持相容,使使用者能夠以熟悉的方式撰寫高效能運算程式。此外,JAX亦可與TensorFlowPyTorch等機器學習框架整合使用。[8][9]

      主要功能[編輯]

      JAX的主要功能是[3]

      grad[編輯]

      下面的代碼演示grad函數的自動微分。

      # 导入库
      from jax import grad
      import jax.numpy as jnp
      
      # 定义logistic函数
      def logistic(x):  
          return jnp.exp(x) / (jnp.exp(x) + 1)
      
      # 获得logistic函数的梯度函数
      grad_logistic = grad(logistic)
      
      # 求值logistic函数在x = 1处的梯度 
      grad_log_out = grad_logistic(1.0)   
      print(grad_log_out)
      

      最終的輸出為:

      0.19661194
      

      jit[編輯]

      下面的代碼演示jit函數的優化。

      # 导入库
      from jax import jit
      import jax.numpy as jnp
      
      # 定义cube函数
      def cube(x):
          return x * x * x
      
      # 生成数据
      x = jnp.ones((10000, 10000))
      
      # 创建cube函数的jit版本
      jit_cube = jit(cube)
      
      # 应用cube函数和jit_cube函数于相同数据来比较其速度
      cube(x)
      jit_cube(x)
      

      可見jit_cube的運行時間顯著的短於cube

      vmap[編輯]

      下面的代碼展示vmap函數的通過SIMD的向量化。

      # 导入库
      from functools import partial
      from jax import vmap
      import jax.numpy as jnp
      
      # 定义函数
      def grads(self, inputs):
          in_grad_partial = partial(self._net_grads, self._net_params)
          grad_vmap = vmap(in_grad_partial)
          rich_grads = grad_vmap(inputs)
          flat_grads = np.asarray(self._flatten_batch(rich_grads))
          assert flat_grads.ndim == 2 and flat_grads.shape[0] == inputs.shape[0]
          return flat_grads
      

      使用JAX的庫[編輯]

      一些Python庫使用JAX作為後端,這包括:

      參見[編輯]

      引用[編輯]

      1. ^ jax/AUTHORS at main · jax-ml/jax. GitHub. [December 21, 2024]. 
      2. ^ jax-v0.1.49. 
      3. ^ 3.0 3.1 Bradbury, James; Frostig, Roy; Hawkins, Peter; Johnson, Matthew James; Leary, Chris; MacLaurin, Dougal; Necula, George; Paszke, Adam; Vanderplas, Jake; Wanderman-Milne, Skye; Zhang, Qiao, JAX: Autograd and XLA, Astrophysics Source Code Library (Google), 2022-06-18 [2022-06-18], Bibcode:2021ascl.soft11002B, (原始內容存檔於2022-06-18) 
      4. ^ Frostig, Roy; Johnson, Matthew James; Leary, Chris. Compiling machine learning programs via high-level tracing (PDF). MLsys. 2018-02-02: 1–3. (原始內容存檔 (PDF)於2022-06-21). 
      5. ^ Using JAX to accelerate our research. www.deepmind.com. [2022-06-18]. (原始內容存檔於2022-06-18) (English). 
      6. ^ autograd. [2023-09-23]. (原始內容存檔於2022-07-18). 
      7. ^ XLA. [2023-09-23]. (原始內容存檔於2022-09-01). 
      8. ^ Lynley, Matthew. Google is quietly replacing the backbone of its AI product strategy after its last big push for dominance got overshadowed by Meta. Business Insider. [2022-06-21]. (原始內容存檔於2022-06-21) (en-US). 
      9. ^ Why is Google's JAX so popular?. Analytics India Magazine. 2022-04-25 [2022-06-18]. (原始內容存檔於2022-06-18) (en-US). 
      10. ^ Flax: A neural network library and ecosystem for JAX designed for flexibility, Google, 2022-07-29 [2022-07-29], (原始內容存檔於2022-09-03) 
      11. ^ Kidger, Patrick, Equinox, 2022-07-29 [2022-07-29], (原始內容存檔於2023-09-19) 
      12. ^ Kidger, Patrick, Diffrax, 2023-08-05 [2023-08-08], (原始內容存檔於2023-08-10) 
      13. ^ Optax, DeepMind, 2022-07-28 [2022-07-29], (原始內容存檔於2023-06-07) 
      14. ^ Lineax, Google, 2023-08-08 [2023-08-08], (原始內容存檔於2023-08-10) 
      15. ^ RLax, DeepMind, 2022-07-29 [2022-07-29], (原始內容存檔於2023-04-26) 
      16. ^ Jraph - A library for graph neural networks in jax., DeepMind, 2023-08-08 [2023-08-08], (原始內容存檔於2022-11-23) 
      17. ^ jaxtyping, Google, 2023-08-08 [2023-08-08], (原始內容存檔於2023-08-10) 
      18. ^ NumPyro - Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU. [2022-08-31]. (原始內容存檔於2022-08-31). 
      19. ^ Brax - Massively parallel rigidbody physics simulation on accelerator hardware. [2022-08-31]. (原始內容存檔於2022-08-31). 

      外部連結[編輯]