JAX 作為一個高效能運算框架,能夠在 CPU、GPU 和 TPU 等多種裝置上執行,其 Pallas 擴充套件允許開發者針對特定硬體編寫自定義核心,進一步提升效能。非同步排程機制讓 JAX 能夠立即傳回控制權給 Python 程式,避免阻塞,而 block_until_ready() 方法則允許在需要時等待計算完成。理解這些核心概念對於充分發揮 JAX 的效能至關重要,尤其是在基準測試和效能分析時,正確使用 block_until_ready() 可以避免得到過於樂觀的結果。在實際應用中,需要根據任務需求選擇合適的裝置,並考慮非同步排程的影響,才能最大化 JAX 的運算效率。

在不同裝置上進行計算

當在不同裝置(如GPU和CPU)上進行計算時,需要確保張量(tensors)正確地提交到目標裝置。否則,可能會出現錯誤,如「Received incompatible devices for jitted computation.」。

Pallas:JAX的自定義核心語言

JAX擁有一個名為Pallas的擴充套件,允許使用者為GPU和TPU編寫自定義核心。這使得使用者可以最佳化計算效能,並針對特定的硬體進行最佳化。

非同步排程

JAX使用非同步排程,這意味著當一個操作被執行時,JAX不會等待操作完成,而是立即傳回控制權給Python程式。這使得Python程式可以在不等待計算完成的情況下繼續執行。

等待計算完成

如果需要等待計算完成,可以使用block_until_ready()方法。這個方法可以阻塞Python程式,直到計算完成為止。

Benchmarking

在進行benchmarking時,需要注意非同步排程的影響。否則,可能會得到過於樂觀的結果。為了避免這個問題,可以使用block_until_ready()方法來等待計算完成。

示例

以下示例展示瞭如何使用JAX在GPU和CPU上進行計算,並展示了非同步排程的影響。

import jax.numpy as jnp

# 建立一個大型陣列
a = jnp.array(range(1000000)).reshape((1000, 1000))

# 在GPU上進行計算
%time x = jnp.dot(a, a)
%time x = jnp.dot(a, a).block_until_ready()

# 在CPU上進行計算
%time x = jnp.dot(a, a).block_until_ready()

結果顯示,GPU計算比CPU計算快30倍。

使用 JAX 在 TPU 上執行計算

在前面的章節中,我們已經瞭解瞭如何使用 JAX 在 CPU 和 GPU 上執行計算。現在,我們將探討如何在 TPU(Tensor Processing Unit)上執行計算。TPU 是由 Google 開發的專用晶片,旨在加速機器學習和深度學習計算。

什麼是 TPU?

TPU 是一種專用晶片,設計用於加速機器學習和深度學習計算。它可以提供比傳統 CPU 和 GPU 更快的計算速度和更低的延遲。TPU 通常用於大規模的機器學習和深度學習任務,例如影像和語音識別、自然語言處理等。

如何在 TPU 上執行 JAX 計算

要在 TPU 上執行 JAX 計算,我們需要先準備一個 TPU 虛擬機器(VM)。這可以透過 Google Cloud Console 完成。具體步驟如下:

  1. 建立一個 TPU 虛擬機器:我們可以使用 gcloud compute tpus tpu-vm create 命令建立一個 TPU 虛擬機器。
  2. 連線到 TPU 虛擬機器:我們可以使用 gcloud compute tpus tpu-vm ssh 命令連線到 TPU 虛擬機器。
  3. 安裝 JAX 和其他所需的軟體:我們可以使用 pip install jax 命令安裝 JAX 和其他所需的軟體。
  4. 啟動 Jupyter 伺服器:我們可以使用 jupyter notebook 命令啟動 Jupyter 伺服器。
  5. 連線 Colab 到 TPU 虛擬機器:我們可以使用 Colab 的「連線到本地執行時」功能連線到 TPU 虛擬機器。

在 TPU 上執行 JAX 計算的範例

以下是使用 JAX 在 TPU 上執行計算的範例:

from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
>>> tpu

這個範例顯示瞭如何使用 JAX 在 TPU 上執行計算。首先,我們匯入 xla_bridge 模組,然後使用 get_backend() 函式取得目前的後端。最後,我們使用 print() 函式印出後端的平臺,結果為 tpu

使用JAX與TPU進行高效運算

首先,我們需要了解JAX(Java Advanced eXtensions)是一個什麼樣的框架。JAX是一個開源的神經網路框架,能夠在多種硬體平臺上執行,包括CPU、GPU和TPU(Tensor Processing Unit)。在這裡,我們將使用JAX與TPU進行高效運算。

安裝JAX與TPU支援

要使用JAX與TPU,首先需要安裝JAX並啟用TPU支援。這可以透過以下命令完成:

import jax
jax.local_devices()

這將會顯示可用的TPU裝置列表。在這個例子中,我們可以看到有八個TPU裝置可用,每個裝置都有一個唯一的ID。

建立陣列並進行運算

接下來,我們建立一個大型陣列並進行矩陣乘法運算:

import jax.numpy as jnp
a = jnp.array(range(1000000)).reshape((1000,1000))
x = jnp.dot(a, a)

這裡,我們使用jnp.array建立一個大型陣列,並使用jnp.dot進行矩陣乘法運算。

測試運算時間

現在,我們可以測試運算時間了。首先,我們使用%time指令測試運算時間:

%time x = jnp.dot(a, a)

這將會顯示運算時間。然後,我們可以使用block_until_ready方法等待運算完成:

%time x = jnp.dot(a, a).block_until_ready()

這將會顯示實際的運算時間。

結果分析

從結果中,我們可以看到使用TPU進行運算可以大大提高效率。使用單個TPU核心進行運算,可以在幾毫秒內完成大型矩陣乘法運算。

TPU裝置資訊

在這裡,我們可以看到每個TPU裝置都有一個唯一的ID,同時還有一個coords tuple包含了TPU晶片的二進位制坐標,以及core_on_chip屬性標記了TPU晶片內的核心編號。

內容解密:
  • jax.local_devices(): 顯示可用的TPU裝置列表。
  • jnp.array(): 建立一個大型陣列。
  • jnp.dot(): 進行矩陣乘法運算。
  • %time: 測試運算時間。
  • block_until_ready(): 等待運算完成。

圖表翻譯:

  flowchart TD
    A[開始] --> B[建立陣列]
    B --> C[進行矩陣乘法運算]
    C --> D[測試運算時間]
    D --> E[等待運算完成]
    E --> F[顯示結果]

這個流程圖描述了使用JAX與TPU進行高效運算的過程。首先,建立一個大型陣列,然後進行矩陣乘法運算。接下來,測試運算時間,並等待運算完成。最後,顯示結果。

第三章:使用陣列

在使用張量的過程中,我們可以看到裝置現在是一個TPU(Tensor Processing Unit),更有趣的是,它是一個特定的TPU核心,總共有八個核心(第一個裝置,是預設裝置)。點積計算也發生在這個特定的核心上。此外,我們還可以看到process_index屬性。所有的TPUs都連線到一個單一的JAX過程,索引為0。在多過程組態中,我們可能會看到更多的多樣性。每個JAX過程都可以使用jax.process_index()函式獲得其過程索引。

警告:請記得在不使用Cloud GPU或TPU機器時停止和刪除它們,以避免產生重大費用。

現在,我們已經完成了對TPU的探索。正如你所見,開始使用TPU相對容易,只是第一次準備可能需要一些時間。

接下來,我們將探討JAX和NumPy之間的差異。

3.3 JAX和NumPy之間的差異

如果你不需要JAX提供的任何優勢,特別是在執行小型一次性計算時,你可能仍然想使用純NumPy。但是,如果你想要使用JAX提供的優勢,你可能需要從NumPy切換到JAX,並對程式碼進行一些修改。

儘管JAX的NumPy-like API嘗試盡可能地遵循原始NumPy API,但仍然存在一些重要的區別。一個明顯的區別是我們已經知道的加速器支援。張量可以居住在不同的後端(CPU、GPU、TPU),你可以精確地管理張量裝置放置。此外,非同步排程也屬於這一類別,因為它被設計用於有效地使用加速計算。

另一個我們已經提到的區別是非陣列輸入的行為,如3.2.2節所述。記住,許多JAX函式不接受列表或元組作為輸入,以防止效能惡化。其他差異包括不可變性以及與支援的資料型別和型別提升相關的特殊主題。讓我們深入探討這些主題。

3.3.1 不可變性

JAX陣列是不可變的。試著改變任何張量,你將會看到一個錯誤。為什麼會出現錯誤?讓我們改變一個張量並看看會發生什麼。

import jax.numpy as jnp
import numpy as np

a_jnp = jnp.array(range(10))
a_np = np.array(range(10))

print(a_jnp[5], a_np[5])  # (5, 5)

a_np[5] = 100
print(a_np[5])  # 100

try:
    a_jnp[5] = 100
except TypeError as e:
    print(e)

JAX陣列是不可變的,這是因為JAX被設計為遵循函式語言程式設計正規化。這就是為什麼JAX轉換如此強大的原因。函式語言程式設計的基本原則之一是程式碼必須不具有副作用,也就是說,程式碼不應該修改原始引數。唯一能夠建立修改後張量的方法是根據原始張量建立一個新的張量。

這與一些NumPy程式設計實踐相矛盾。在NumPy中,一個典型的操作是索引更新,即透過更改陣列中的值來修改張量內的值。這在NumPy中是完全可以接受的,但在JAX中會引發一個錯誤。

幸運的是,JAX錯誤訊息非常具體,並建議了一個解決方案。讓我們看看如何解決清單3.21中的錯誤。

對於所有典型的就地表達式,用於更新張量元素的值,在JAX中都有一個對應的函式式純等價物。你可以在表3.1中找到JAX函式式操作與NumPy風格就地表達式的列表。

就地元素指定在NumPy中(允許) 就地元素指定在JAX中(不允許)
a[5] = 100 a = a.at[5].set(100)

透過使用at[]方法,我們可以建立一個新的張量,其元素與原始張量相同,但某些索引處的值已經修改。這種方法保證了JAX陣列的不可變性,並使程式碼更容易推理和最佳化。

陣列操作:JAX 的索引更新功能

在 JAX 中,陣列操作與 NumPy 相似,但有一些重要的差異。JAX 的陣列是不可變的(immutable),這意味著當你更新一個陣列時,會建立一個新的陣列,而不是修改原有的陣列。以下是 JAX 中的索引更新功能與 NumPy 的對應關係:

NumPy 風格的運算 JAX 對應語法
x[idx] = y x = x.at[idx].set(y)
x[idx] += y x = x.at[idx].add(y)
x[idx] *= y x = x.at[idx].multiply(y)
x[idx] /= y x = x.at[idx].divide(y)
x[idx] **= y x = x.at[idx].power(y)
x[idx] = minimum(x[idx], y) x = x.at[idx].min(y)
x[idx] = maximum(x[idx], y) x = x.at[idx].max(y)
ufunc.at(x, idx) x = x.at[idx].apply(ufunc)
x = x[idx] x = x.at[idx].get()

所有這些 x.at 運算都會傳回修改後的陣列副本,而不是修改原始陣列。雖然這可能會導致效率略低,但由於 JAX 的即時編譯(JIT),低階別的運算如 x = x.at[idx].set(y) 將保證在原始陣列不再使用時進行就地修改,從而保持計算效率。

更新陣列元素

以下是使用 JAX 更新陣列元素的範例:

a_jnp = a_jnp.at[5].set(100)
print(a_jnp[5])  # Output: Array(100, dtype=int32)

這段程式碼更新了 a_jnp 陣列中索引為 5 的元素為 100。

超出陣列界限的索引

當索引超出陣列界限時,JAX 會採取特殊的行為以避免錯誤。對於索引更新運算,JAX 會跳過超出界限的更新;對於索引檢索運算,JAX 會將索引限制在陣列界限內,以確保傳回有效值。這種行為類別似於浮點數計算中使用特殊值(如 NaN)來處理錯誤。

建立陣列副本

當你更新一個陣列時,JAX 會建立一個新的陣列副本,而不是修改原始陣列。這確保了 JAX 的陣列操作是安全且可預測的。

圖表翻譯

  graph LR
    A[原始陣列] -->|更新|> B[陣列副本]
    B -->|傳回|> C[修改後的陣列]
    C -->|索引檢索|> D[傳回元素]
    D -->|索引超出界限|> E[跳過更新或限制索引]

這個圖表展示了 JAX 的陣列更新和索引檢索過程,以及超出界限的索引行為。

JAX 中的索引超出範圍行為

JAX 與 NumPy 相比,有著不同的索引超出範圍行為。JAX 預設假設所有索引都在合法範圍內,但也提供了實驗性的支援,允許使用者透過 mode 引數指定索引更新函式的行為。這些選項包括:

  • "promise_in_bounds"(預設):使用者保證所有索引都在合法範圍內,因此不進行額外的檢查。在實踐中,這意味著所有超出範圍的索引在 get() 中被截斷,在 set()add() 和其他修改函式中被丟棄。
  • "clip":將超出範圍的索引限制在有效範圍內。
  • "drop":忽略超出範圍的索引。
  • "fill""drop" 的別名,但對於 get(),它會傳回在 fill_value 引數中指定的值。

以下示例展示了使用不同選項的示例:

import jax.numpy as jnp

# 建立一個 JAX 陣列
a_jnp = jnp.array(range(10))
print(a_jnp)

# 預設行為(promise_in_bounds)
print(a_jnp[42])  # 會傳回陣列的最後一個元素

# 使用 'drop' 模式
print(a_jnp.at[42].get(mode='drop'))  # 會傳回 -2147483648

# 使用 'fill' 模式並指定填充值
print(a_jnp.at[42].get(mode='fill', fill_value=-1))  # 會傳回 -1

# 將值設定為 100 並使用 'clip' 模式
a_jnp = a_jnp.at[42].set(100, mode='clip')
print(a_jnp)  # 會將陣列的最後一個元素設定為 100

如您所見,JAX 中的索引超出範圍不會產生錯誤;相反,它會傳回某個值,您可以控制這種情況下的行為。

內容解密:

在上述程式碼中,我們首先匯入 jax.numpy 模組並建立一個 JAX 陣列 a_jnp。然後,我們展示了使用不同 mode 選項的示例,包括預設行為、"drop" 模式、"fill" 模式以及使用 mode='clip' 將值設定為 100。

每個選項都對索引超出範圍的行為有著不同的影響,包括截斷、丟棄或填充。透過選擇合適的 mode,使用者可以控制 JAX 中索引超出範圍的情況下的行為。

圖表翻譯:

  flowchart TD
    A[建立 JAX 陣列] --> B[索引超出範圍]
    B --> C{模式選擇}
    C -->|promise_in_bounds| D[截斷索引]
    C -->|clip| E[限制索引]
    C -->|drop| F[忽略索引]
    C -->|fill| G[填充索引]
    D --> H[傳回最後一個元素]
    E --> I[傳回限制後的索引]
    F --> J[傳回 -2147483648]
    G --> K[傳回填充值]

在這個流程圖中,我們展示了 JAX 中索引超出範圍的不同模式選擇和對應的行為。根據使用者的選擇,JAX 會截斷索引、限制索引、忽略索引或填充索引。每個模式都對索引超出範圍的情況有著不同的影響。

3.3.2 資料型別

在 JAX 中,與 NumPy 相比,有幾個關於資料型別的不同之處。這包括低精確度和高精確度浮點數格式的支援,以及型別提升語義(type promotion semantics),它們規定了當操作元的型別為特定型別(可能不同)時,操作結果的型別將是什麼。

浮點數支援

雖然 NumPy 將操作元積極地提升到雙精確度(或 float64)型別,但 JAX 則強制使用單精確度(或 float32)數字。當您直接建立一個 float64 陣列時,您可能會驚訝於 JAX 默默地將其轉換為 float32。對於許多機器學習(尤其是深度學習)工作負載,這是完全可以接受的。但對於一些高精確度科學計算,可能不是理想的選擇。

浮點數型別:float64、float32、float16、bfloat16 在科學計算和深度學習中,使用了許多不同的浮點數型別。IEEE 標準浮點數運算(IEEE 754)定義了不同精確度的幾種格式,這些格式被廣泛使用。 科學計算的預設浮點數資料型別是一個雙精確度浮點數,或 float64,因為這個浮點數的大小是 64 位。IEEE 754 雙精確度二進位制浮點數格式有一個 1 位符號位、11 位指數位和 52 位小數部分。它的範圍是 ~2.23e-308 到 ~1.80e308,具有完整的 15-17 位小數精確度。 對於某些情況,還有更高精確度的型別,例如長雙精確度或擴充套件精確度浮點數,它通常是一個 80 位浮點數,在 x86 平臺上(但是,有很多注意事項)。NumPy 支援 np.longdouble 型別以獲得擴充套件精確度,而 JAX 對這種型別沒有支援。 深度學習應用程式往往對較低精確度具有強健性,因此單精確度浮點數或 float32 已經成為這種應用程式的預設資料型別,並且是 JAX 中的預設浮點數資料型別。32 位 IEEE 754 浮點數有一個 1 位符號位、8 位指數位和 23 位小數部分。它的範圍是 ~1.18e-38 到 ~3.40e38,具有 6-9 位有效小數精確度。 對於許多深度學習情況,甚至 32 位浮點數太多了,近年來,較低精確度的訓練和推理已經變得流行起來。通常,進行較低精確度推理比訓練更容易,而且有一些混合 float16/32 精確度訓練方案存在。

低精確度浮點數格式

在較低精確度浮點數格式中,有兩種 16 位浮點數:float16 和 bfloat16。IEEE 754 半精確度浮點數或 float16 有一個 1 位符號位、5 位指數位和 10 位小數部分。它的範圍是 ~5.96e−8 到 65504,具有四位有效小數。 另一個 16 位格式最初由 Google 開發,被稱為「Brain Floating Point Format」,或簡稱 bfloat16。原始的 IEEE float16 沒有考慮深度學習應用,因此其動態範圍太窄。bfloat16 型別解決了這個問題,提供了一個與 float32 相同的動態範圍。它有一個 1 位符號位、8 位指數位和 7 位小數部分。它的範圍是 ~1.18e-38 到 ~3.40e38,具有三位有效小數。

與 NumPy 的差異

bfloat16 格式作為一個截斷的 IEEE 754 float32,可以快速地轉換為 IEEE 754 float32。在轉換為 bfloat16 格式時,指數位被保留,而 significand 欄可以被減少。 還有一些其他特殊格式,您可以在我的文章中瞭解更多關於它們的資訊。

強制使用 float64 計算

要強制使用 float64 計算,您需要在啟動時設定 jax_enable_x64 組態變數。以下程式碼示範瞭如何做到這一點。

# 這只在啟動時有效!
config.update("jax_enable_x64", True)
import jax.numpy as jnp

# 這可能不適用於 TPU 後端。嘗試使用 CPU 或 GPU。
x = jnp.array(range(10), dtype=jnp.float64)
print(x.dtype)

使用 16 位浮點數型別

在深度學習中,往往使用較低精確度的格式,最常見的是半精確度或 float16,或者是一種特殊的 bfloat16,它不被 NumPy 支援。在 JAX 中,您可以輕鬆地切換到使用這些較低精確度的 16 位型別。

xb16 = jnp.array(range(10), dtype=jnp.bfloat16)
print(xb16.dtype)
print(xb16.nbytes)

x16 = jnp.array(range(10), dtype=jnp.float16)
print(x16.dtype)

圖表翻譯:

  graph LR
    A[開始] --> B[選擇資料型別]
    B --> C[使用 float64]
    B --> D[使用 float32]
    B --> E[使用 float16/bfloat16]
    C --> F[進行高精確度計算]
    D --> G[進行單精確度計算]
    E --> H[進行低精確度計算]

內容解密:

以上程式碼示範瞭如何在 JAX 中使用不同的浮點數型別,包括 float64、float32、float16 和 bfloat16。透過設定 jax_enable_x64 組態變數,可以強制使用 float64 計算。同時,也展示瞭如何使用 16 位浮點數型別,如 float16 和 bfloat16。在進行深度學習工作負載時,選擇合適的資料型別非常重要,以確保計算的準確性和效率。

使用JAX進行陣列運算

JAX是一個強大的陣列運算函式庫,提供了許多高效的運算功能。在本文中,我們將探討JAX的陣列運算功能,包括型別提升、稀疏陣列和控制流程。

型別提升

JAX的型別提升規則與NumPy不同。當進行二元運算時,JAX會根據運算元的型別自動提升結果的型別。例如,當我們將兩個16位元浮點數相加時,結果將是一個32位元浮點數。

import jax.numpy as jnp

x = jnp.array([1, 2, 3], dtype=jnp.float16)
y = jnp.array([4, 5, 6], dtype=jnp.bfloat16)

result = x + y
print(result.dtype)  # float32

稀疏陣列

JAX提供了稀疏陣列的支援,允許我們只儲存非零元素。這對於大型稀疏矩陣尤其有用,可以節省大量的記憶體空間。

import jax.experimental.sparse as jsparse

# 建立一個稀疏矩陣
matrix = jsparse.BCOO((10, 10), [1, 2, 3], [4, 5, 6])

# 進行矩陣運算
result = matrix @ matrix.T
print(result)

控制流程

JAX提供了控制流程的功能,包括lax.switchlax.while_looplax.fori_loop等。這些功能允許我們在JAX中實作複雜的控制流程。

import jax
from jax import lax

# 定義一個隨機增強函式
def random_augmentation(image, augmentations, rng_key):
    augmentation_index = jax.random.randint(
        key=rng_key, minval=0, maxval=len(augmentations))
    return lax.switch(augmentation_index, augmentations, image)

# 定義一些影像增強函式
augmentations = [
    lambda x: x + 1,
    lambda x: x * 2,
    lambda x: x - 1
]

# 進行影像增強
image = jnp.array([1, 2, 3])
rng_key = jax.random.PRNGKey(0)
result = random_augmentation(image, augmentations, rng_key)
print(result)

圖表翻譯:

  graph LR
    A[影像] -->|增強|> B[隨機增強]
    B -->|選擇增強函式|> C[增強函式]
    C -->|執行增強|> D[增強後影像]

在這個例子中,我們定義了一個隨機增強函式random_augmentation,它使用lax.switch選擇一個隨機的增強函式,並將其應用於輸入影像。然後,我們定義了一些影像增強函式,並使用random_augmentation進行影像增強。最終,我們得到了一個增強後的影像。

3.4.2 資料型別提升

jax.lax API 比 jax.numpy 嚴格。它不會隱式地提升具有混合資料型別的操作引數。使用 jax.lax 時,您必須手動進行資料型別提升。

資料型別提升範例

import jax.numpy as jnp

# jax.numpy 隱式提升資料型別
result = jnp.add(42, 42.0)
print(result)  # Array(84., dtype=float32, weak_type=True)

from jax import lax

try:
    # jax.lax 需要手動提升資料型別
    lax.add(42, 42.0)
except TypeError as e:
    print(e)  # lax.add requires arguments to have the same dtype

在這個範例中,jax.numpy 隱式地將整數 42 提升為浮點數 42.0,而 jax.lax 則需要手動進行資料型別提升。

手動資料型別提升

要使用 jax.lax 進行資料型別提升,您可以使用 jax.numpyastype 方法將引數轉換為相同的資料型別。

import jax.numpy as jnp
from jax import lax

# 手動提升資料型別
result = lax.add(jnp.array(42, dtype=jnp.float32), 42.0)
print(result)  # Array(84., dtype=float32, weak_type=True)

在這個範例中,我們使用 jax.numpyastype 方法將整數 42 轉換為浮點數 42.0,然後再使用 jax.laxadd 函式進行加法運算。

3.5 JAX 中的資料型別和運算

JAX是一個強大的深度學習和科學計算框架,它提供了多種資料型別和運算方式。在本文中,我們將介紹JAX中的資料型別和運算,包括JAX的資料型別、jax.numpy API、jax.lax API等。

3.5.1 JAX 中的資料型別

JAX 中的基本資料型別是 jax.Array,它是一種多維陣列,可以用於表示張量。JAX 也提供了一種類別似 NumPy 的 API,稱為 jax.numpy,它提供了許多與 NumPy 相同的函式和操作。

3.5.2 jax.numpy API

jax.numpy API 是 JAX 中的一種高階 API,它提供了許多與 NumPy 相同的函式和操作。jax.numpy API 試圖盡可能地遵循原始的 NumPy API,但也有一些差異。

例如,jax.numpy API 中的 add 函式可以用於將兩個陣列相加:

import jax.numpy as jnp

a = jnp.array([1, 2, 3])
b = jnp.array([4, 5, 6])

result = jnp.add(a, b)
print(result)  # [5, 7, 9]

3.5.3 jax.lax API

jax.lax API 是 JAX 中的一種低階 API,它提供了更多的控制權和靈活性。jax.lax API 中的函式可以用於進行更複雜的運算,例如控制流程和非同步運算。

例如,jax.lax API 中的 switch 函式可以用於根據條件選擇不同的運算:

import jax.lax as lax

x = 1
y = 2

result = lax.switch(x > y, lambda: x + y, lambda: x - y)
print(result)  # 3

3.5.4 資料型別轉換

JAX 中的資料型別轉換可以使用 jax.numpy API 中的函式進行。例如,可以使用 jnp.float32() 函式將一個整數轉換為浮點數:

import jax.numpy as jnp

x = 1

result = jnp.float32(x)
print(result)  # 1.0
內容解密:
  • JAX 中的基本資料型別是 jax.Array
  • jax.numpy API 提供了許多與 NumPy 相同的函式和操作。
  • jax.lax API 提供了更多的控制權和靈活性。
  • 資料型別轉換可以使用 jax.numpy API 中的函式進行。

圖表翻譯:

  graph LR
    A[JAX] -->|使用|> B[jax.numpy API]
    B -->|提供|> C[NumPy 函式]
    C -->|實作|> D[陣列運算]
    D -->|結果|> E[浮點數]
    E -->|轉換|> F[整數]
    F -->|結果|> G[最終結果]

在這個圖表中,我們展示了 JAX 中的資料型別和運算過程。JAX 使用 jax.numpy API 提供 NumPy 函式,實作陣列運算,得到浮點數結果。然後,浮點數可以轉換為整數,得到最終結果。

計算梯度

在深度學習中,計算梯度是訓練神經網路的關鍵步驟。梯度代表了函式輸出的變化率,以便我們能夠更新神經網路的權重,從而最小化損失函式。在本章中,我們將探討如何使用自動微分(autodiff)來計算梯度,並深入瞭解其在 JAX 框架中的實作。

首先,我們需要了解為什麼計算梯度如此重要。神經網路是一種複雜的數學函式,其輸出取決於輸入和權重。給定輸入和權重,我們可以計算神經網路的輸出,並根據損失函式評估其與理想輸出的差異。為了更新神經網路的權重,我們需要計算損失函式對於權重的導數,也就是梯度。梯度下降法是一種常用的最佳化演算法,透過計算梯度並沿著其反方向更新權重,以最小化損失函式。

自動微分

自動微分(autodiff)是一種計算導數的方法,透過將函式表示為一系列的基本運算,並自動計算每個運算的導數。JAX 框架提供了一種高效的自動微分實作,允許我們輕鬆地計算梯度。

在 JAX 中,我們可以使用 grad 函式來計算梯度。例如,給定一個函式 f(x) = x^4 + 12x + 1/x,我們可以使用以下程式碼計算其梯度:

import jax.numpy as jnp
from jax import grad

def f(x):
    return x**4 + 12*x + 1/x

x = 2.0
gradient = grad(f)(x)
print(gradient)

這將輸出函式 f 在點 x=2.0 處的梯度值。

手動微分

除了自動微分外,我們還可以使用手動微分來計算梯度。手動微分需要我們手動計算函式的導數,這可能是一個繁瑣且容易出錯的過程。

例如,給定函式 f(x) = x^4 + 12x + 1/x,我們可以手動計算其導數:

def f_derivative(x):
    return 4*x**3 + 12 - 1/x**2

然後,我們可以使用這個導數函式來計算梯度:

x = 2.0
gradient = f_derivative(x)
print(gradient)

這將輸出函式 f 在點 x=2.0 處的梯度值。

梯度下降法

梯度下降法是一種常用的最佳化演算法,透過計算梯度並沿著其反方向更新權重,以最小化損失函式。給定一個損失函式 L(w),我們可以使用以下公式更新權重:

w = w - alpha * gradient

其中 alpha 是學習率,gradient 是損失函式對於權重的導數。

在 JAX 中,我們可以使用 optax 函式來實作梯度下降法。例如:

import optax

def loss_function(w):
    # 定義損失函式
    return w**2

w = 2.0
alpha = 0.1

# 計算梯度
gradient = grad(loss_function)(w)

# 更新權重
w = w - alpha * gradient

這將更新權重 w 以最小化損失函式。

4.1 梯度計算

梯度計算是機器學習和深度學習中的一個基本概念,理解如何計算梯度對於最佳化模型至關重要。梯度計算的目的是找到函式的最小值或最大值。

4.1.1 手動微分

手動微分是一種傳統的計算梯度的方法,需要人工計算函式的導數。例如,給定一個函式 f(x) = x^4 + 12x - 1/x,我們可以使用微分規則計算其導數:

f'(x) = 4x^3 + 12 - 1/x^2

這個導數可以用 Python 程式碼實作如下:

def df(x):
    return 4*x**3 + 12 - 1/x**2

x = 11.0
print(df(x))

手動微分的優點是可以得到一個閉式表示式,但其缺點是需要人工計算導數,可能會出現計算錯誤。

4.1.2 自動微分

自動微分(Autodiff)是一種自動計算梯度的方法,不需要人工計算導數。Autodiff 可以透過反向傳播演算法或前向傳播演算法實作。反向傳播演算法是目前最常用的自動微分方法,其基本思想是從輸出端開始,反向計算梯度。

自動微分的優點是可以自動計算梯度,減少人工計算導數的錯誤,但其缺點是需要額外的計算資源。

4.1.3 比較不同方法

不同方法的比較如下:

方法 優點 缺點
手動微分 可得到閉式表示式 需要人工計算導數,可能會出現計算錯誤
自動微分 可自動計算梯度,減少人工計算導數的錯誤 需要額外的計算資源
圖表翻譯:
  graph LR
    A[函式] --> B[手動微分]
    B --> C[自動微分]
    C --> D[梯度計算]
    D --> E[模型最佳化]

這個圖表展示了從函式到模型最佳化的過程,手動微分和自動微分是兩種不同的計算梯度的方法。

從效能最佳化視角來看,JAX 在 TPU、GPU 和 CPU 等不同裝置上的計算效能差異顯著。雖然 JAX 的非同步排程機制能提升 Python 程式碼的執行效率,但在效能評測時務必使用 block_until_ready() 方法等待計算完成,才能獲得真實的效能資料。非同步排程雖然能掩蓋裝置間的效能差距,但透過精確測量,TPU 在處理大型矩陣運算上展現出相較於 CPU 和 GPU 的顯著優勢,尤其在深度學習的訓練和推論任務中,TPU 的高效能更為關鍵。考量到成本效益,技術團隊應根據實際需求選擇合適的硬體組態,並針對不同裝置調整程式碼以最佳化效能。玄貓認為,JAX 結合 TPU 的解決方案,為追求極致效能的大規模機器學習應用提供了可行的途徑,未來隨著硬體和軟體的持續發展,預期 JAX 在分散式運算和高效能運算領域將扮演更重要的角色。