JAX 作為一個高效能數值計算函式庫,其生態系統日漸豐富,包含 Flax、Equinox 和 Keras 等深度學習框架,提供開發者更便捷的模型建構與訓練工具。本文除了介紹這些函式庫之外,也提供 JAX 在 CPU、GPU 和 TPU 等不同硬體平臺的安裝教學,讓使用者能根據自身環境選擇合適的安裝方式。此外,本文也詳細說明如何利用 Google Colab 和 Cloud TPU 資源,加速 JAX 程式碼的執行,並探討平行化技術 xmap() 和命名軸程式設計的概念,幫助讀者更深入理解 JAX 的運作機制。
JAX 生態系統的其他成員
JAX 生態系統中還有許多其他重要的函式庫和工具,包括:
- Flax:一個高階神經網路函式庫,提供了強大的神經網路建模和訓練功能。
- Equinox:一個根據 JAX 的神經網路函式庫,提供了簡單易用的 API 和高效能的計算能力。
- Keras 3.0:一個高階神經網路函式庫,提供了簡單易用的 API 和強大的神經網路建模和訓練功能。
- 其他函式庫和工具:JAX 生態系統中還有許多其他函式庫和工具,包括用於強化學習、電腦視覺、聯邦學習、機率程式設計和進化計算等領域的函式庫和工具。
JAX 的應用領域
JAX 不僅僅適用於深度學習,還可以用於物理、化學、天體物理學、量子計算等領域。JAX 的高效能運算能力和靈活的 API,使其成為了一個非常有用的工具。
安裝 JAX
JAX 可以透過 pip 安裝,安裝過程會根據目標架構(CPU、GPU 或 TPU)而有所不同。
在 CPU 上安裝 JAX
在 CPU 上安裝 JAX 的最簡單方法是使用 pip 安裝 CPU 版本的 JAX。可以使用以下命令:
pip install --upgrade pip
pip install --upgrade "jax[cpu]"
在 GPU 上安裝 JAX
在 GPU 上安裝 JAX 的方法有幾種,包括使用 CUDA 和 CuDNN 安裝、使用自行安裝的 CUDA/CuDNN 安裝、使用 Docker 容器安裝等。以下是使用 pip 安裝 CUDA 和 JAX 的命令:
pip install --upgrade pip
# CUDA 12 安裝
pip install --upgrade "jax[cuda]"
注意:JAX 支援 NVIDIA GPU 的 Maxwell 架構或更新版本(計算能力 5.2 或更高)。
Mermaid 圖表:JAX 安裝流程
flowchart TD
A[開始] --> B[選擇目標架構]
B --> C[CPU 安裝]
B --> D[GPU 安裝]
C --> E[pip 安裝 CPU 版本的 JAX]
D --> F[pip 安裝 CUDA 和 JAX]
F --> G[安裝 NVIDIA 驅動程式]
G --> H[安裝 CUDA 和 CuDNN]
H --> I[安裝 JAX]
圖表翻譯:
此圖表示 JAX 安裝流程。首先,需要選擇目標架構(CPU 或 GPU)。如果選擇 CPU,則可以直接使用 pip 安裝 CPU 版本的 JAX。如果選擇 GPU,則需要先安裝 NVIDIA 驅動程式,然後安裝 CUDA 和 CuDNN,最後安裝 JAX。
安裝 JAX 的方法
安裝 JAX 可以透過 pip 進行,以下是幾種不同的安裝方法:
1. 使用 pip 安裝 JAX
您可以使用以下命令安裝 JAX:
pip install --upgrade "jax[cuda12]"
這是最簡單的安裝方法。
2. 使用 pip 安裝 JAX,並指定 CUDA 和 cuDNN 版本
如果您已經安裝了 CUDA 和 cuDNN,您可以使用以下命令安裝 JAX:
pip install --upgrade pip
pip install --upgrade "jax[cuda12_local]"
這個方法需要您已經安裝了 CUDA 和 cuDNN,並且版本相容。
3. 使用 Docker 容器安裝 JAX
您也可以使用 Docker 容器安裝 JAX,以下是相關的命令:
pip install jax[tpu] -f \
這個方法需要您已經安裝了 Docker,並且有相關的容器。
4. 在 Google Cloud TPU 上安裝 JAX
如果您想要在 Google Cloud TPU 上安裝 JAX,您可以使用以下命令:
pip install jax[tpu] -f \
這個方法需要您已經安裝了 Google Cloud TPU,並且有相關的設定。
使用 Google Colab
Google Colab 是一個免費的雲端計算平臺,提供了 GPU 和 TPU 的計算資源。您可以使用 Colab 來執行 JAX 的程式碼。
1. 啟動 Colab
您可以透過以下網址啟動 Colab:https://colab.research.google.com/
2. 安裝 JAX
在 Colab 中,您可以使用以下命令安裝 JAX:
!pip install jax
3. 執行 JAX 程式碼
在 Colab 中,您可以執行 JAX 程式碼,例如:
import jax
print(jax.__version__)
這個程式碼會印出 JAX 的版本號碼。
使用Google Colab進行JAX開發
Google Colab是一個免費的雲端平臺,提供了GPU和TPU加速的環境,非常適合用於深度學習和機器學習的開發。以下是使用Google Colab進行JAX開發的步驟。
建立Colab Notebook
首先,需要建立一個新的Colab Notebook。可以透過Google Drive或Google Colab網站建立一個新的Notebook。
選擇Runtime
Colab提供了多種Runtime選項,包括CPU、GPU和TPU。可以在Runtime選單中選擇所需的Runtime型別。
安裝JAX
安裝JAX可以透過以下命令實作:
!pip install jax
實驗JAX
安裝完成後,可以開始實驗JAX。以下是一個簡單的例子:
import jax
import jax.numpy as jnp
# 定義一個簡單的函式
def add(x, y):
return x + y
# 編譯函式
add_jit = jax.jit(add)
# 測試函式
print(add_jit(2, 3))
使用Cloud TPU
如果需要使用Cloud TPU,可以按照以下步驟進行:
- 建立一個新的Google Cloud專案。
- 啟用Cloud TPU API。
- 建立一個新的Cloud TPU節點。
- 連線到Cloud TPU節點。
- 安裝JAX和其他所需的函式庫。
以下是建立和連線到Cloud TPU節點的命令:
$ gcloud compute tpus tpu-vm create node-jax --zone us-central1-b --accelerator-type v2-8 --version tpu-vm-base
$ gcloud compute tpus tpu-vm ssh --zone us-central1-b node-jax -- -L 8888:localhost:8888
內容解密:
!pip install jax:安裝JAX函式庫。import jax:匯入JAX函式庫。import jax.numpy as jnp:匯入JAX NumPy函式庫。def add(x, y)::定義一個簡單的函式。return x + y:傳回函式結果。add_jit = jax.jit(add):編譯函式。print(add_jit(2, 3)):測試函式。
圖表翻譯:
graph LR
A[建立Colab Notebook] --> B[選擇Runtime]
B --> C[安裝JAX]
C --> D[實驗JAX]
D --> E[使用Cloud TPU]
E --> F[建立Cloud TPU節點]
F --> G[連線到Cloud TPU節點]
G --> H[安裝JAX和其他所需的函式庫]
圖表說明:
- 建立Colab Notebook。
- 選擇Runtime。
- 安裝JAX。
- 實驗JAX。
- 使用Cloud TPU。
- 建立Cloud TPU節點。
- 連線到Cloud TPU節點。
- 安裝JAX和其他所需的函式庫。
使用 Google Cloud TPUs
要使用 Google Cloud TPUs,您需要先安裝必要的模組。您可以使用 pip 安裝 JAX 和 TPU 支援:
pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/jax_releases.html
如果您想要使用 Google Colab 或 Jupyter 筆記本,您需要進行更多設定。首先,安裝必要的模組:
pip install -U jinja2
pip install notebook
請注意,Jupyter 會在本地安裝,並不會修改 PATH 變數。您可能需要更新 PATH 變數以包含 Jupyter 安裝目錄(在安裝過程中會顯示警告訊息)。您需要將以下路徑修改為您自己的路徑:
export PATH=$PATH:/home/your_username/.local/bin
接下來,安裝 Jupyter 擴充套件:
pip install jupyter_http_over_ws
啟動 Jupyter 伺服器並允許從 Google Colab 連線:
jupyter notebook --port=8888 --NotebookApp.port_retries=0
在啟動 Jupyter 伺服器後,會顯示存取筆記本的連結。在我的案例中,連結是 http://localhost:8888/tree?token=c2b887bc0c35841b1c7。請注意,Token 值在不同的執行中會有所不同,您需要複製這個連結。
連線到 Cloud TPU 節點
現在,您可以將 Colab 連線到這個新的執行環境。為此,在 Google Colab 中,前往「重新連線 -> 連線到本地執行環境」(如圖 C.1 所示)。
圖 C.1 連線 Colab 筆記本到本地執行環境
在 Backend URL 欄位中貼上連結,並點選「連線」(如圖 C.2 所示)。如果筆記本因為逾時而未能連線,請點選「重新連線」。
圖 C.2 輸入本地執行環境的 Backend URL
資源
- 更多關於管理 TPUs 的資訊
- 快速入門:如何在 Cloud TPU VM 上執行計算
- Cloud TPU VM 架構
- 使用 TPU Pod slices
- 檢查不同區域的 TPU 可用性
- 建立深度學習 VM 例項
- 深度學習 VM 映像
- Cloud TPU 定價
內容解密:
以上步驟說明瞭如何使用 Google Cloud TPUs。在開始之前,需要安裝必要的模組,包括 JAX 和 TPU 支援。然後,需要設定 Jupyter 伺服器以允許從 Google Colab 連線。最後,需要連線到 Cloud TPU 節點以使用其計算資源。
圖表翻譯:
以下是使用 Mermaid 語法繪製的流程圖,說明瞭如何使用 Google Cloud TPUs:
flowchart TD
A[安裝必要模組] --> B[設定 Jupyter 伺服器]
B --> C[啟動 Jupyter 伺服器]
C --> D[連線到 Cloud TPU 節點]
D --> E[使用 Cloud TPU]
這個流程圖顯示了使用 Google Cloud TPUs 的步驟,從安裝必要模組到啟動 Jupyter 伺服器、連線到 Cloud TPU 節點,最終到使用 Cloud TPU。
平行化技術:xmap() 和命名軸程式設計
在這個附錄中,我們將探討兩種實驗性的平行化技術:xmap() 和 pjit()。xmap() 是一個舊的技術,於 JAX 0.4.31 版本中被刪除,但它仍然對於理解 legacy 程式碼或 JAX 平行化演進有趣。
xmap() 轉換可以更容易地平行化函式,減少程式碼,取代巢狀的 pmap() 和 vmap() 呼叫,並且不需要手動 tensor 重新塑形。它還引入了命名軸程式設計模型,幫助您寫出更不容易出錯的程式碼。
xmap() 的一個可能的替代品來自 JAX 生態系統中的 Haliax 函式庫,它提供了一種使用命名 tensor 建立神經網路的方法。另一個替代品是 JAX 核心中的 shmap(),它現在具有 JAX Enhancement Proposals (JEPs) 狀態,是 xmap() 的替代品。
有時您的函式(或神經網路)可能太大,無法適應單個 GPU/TPU,因此需要在叢集上執行計算。這在大語言模型(LLMs)訓練和推理中很常見。JAX 允許您不僅可以將資料處理分佈在不同的機器上(資料平行性),還可以將大型計算分割成在不同的機器上執行的部分(所謂的模型平行性)。
使用 xmap() 和命名軸程式設計
xmap() 是一個有趣的實驗功能,可以在幾個方面簡化您的程式。首先,它引入了命名軸程式設計模型,從 tensor 軸索引切換到軸名稱,使您的程式碼更不容易出錯且更容易理解和修改。
其次,xmap() 允許您用單個函式取代 pmap() 和 vmap(),並消除一些技術問題,如巢狀 pmap()/vmap() 呼叫和重新塑形 tensor。使用 xmap(),您不需要關心可用裝置的數量,也不需要重新塑形您的資料,只需新增一個特殊維度來對映過程然後移除它。
命名軸程式設計
命名軸程式設計的想法是給 tensor 維度賦予明確的名稱,例如批次、特徵、高度、寬度或通道維度。當大多數 tensor 運算接受維度引數時,您可以避免跟蹤維度的需要,並提供額外的安全性。
JAX 有一個稍微不同的實驗實作此模型,稱為 named-axis programming。xmap() 是一個包裝器或介面卡,它接受標準陣列(tensor)具有位置軸,將某些軸轉換為命名軸(如指定),呼叫它包裝的函式,並將命名軸轉換回位置軸(使用 out_axes 引數)。
深入剖析 JAX 生態系統及其核心功能後,我們可以看到 JAX 不僅僅是一個高效能的數值計算函式庫,更是一個快速發展的機器學習生態系統。Flax、Equinox 和 Keras 3.0 的整合,為開發者提供了構建複雜模型的靈活性和易用性。然而,JAX 的應用遠超深度學習領域,其在科學計算領域的潛力同樣值得關注。雖然 xmap() 已被棄用,但其命名軸程式設計的概念,以及 shmap() 和 Haliax 等替代方案的出現,持續推動 JAX 在平行化計算方面的發展。目前 JAX 支援多種硬體平臺,從 CPU 到 GPU 和 TPU,滿足不同規模的計算需求。但對於模型平行化訓練等進階應用,仍需深入理解其硬體資源管理和排程機制。玄貓認為,JAX 作為新一代機器學習框架,展現出強大的效能優勢和靈活性,值得深度學習研究者和工程師投入時間學習和應用。隨著社群的持續貢獻和工具鏈的完善,JAX 的應用場景將會更加廣闊,並在推動科學和工程領域的創新方面發揮越來越重要的作用。