JAX作為新興的深度學習框架,憑藉其函式語言程式設計正規化和高效的硬體加速能力,逐漸受到開發者的關注。它不僅提供了底層的張量運算和自動微分功能,還整合了Flax、Optax等高階函式庫,簡化了深度學習模型的構建和訓練流程。尤其在處理複雜資料結構,例如 Pytree 時,JAX 的 tree_maptree_reduce 等函式,更能展現其便捷性和靈活性。此外,JAX 在隨機數生成方面也提供了更精細的控制和更高的效率,這對於深度學習研究和應用至關重要。

9 JAX 中的隨機數生成

JAX 是一個根據 Python 的神經網路框架,它提供了高效的隨機數生成功能。在本文中,我們將介紹 JAX 中的隨機數生成方法。

9.1 生成隨機資料(Generating Random Data)

在 JAX 中,我們可以使用 jax.random 模組生成隨機資料。這個模組提供了多種隨機數生成器,可以根據具體的情況選擇合適的生成器。

範例:載入資料集(Loading the Dataset)

假設我們有一個大型資料集,需要載入到記憶體中進行計算。為了加速載入過程,我們可以使用 JAX 的隨機數生成功能生成隨機索引,然後使用這些索引載入資料集。

瞭解隨機數生成的重要性

在各種應用中,尤其是在機器學習和資料增強中,生成高品質的隨機數是一個至關重要的步驟。這些隨機數可以用於初始化神經網路、建立資料增強管道等。

NumPy 中的隨機數生成

NumPy 是一個流行的 Python 函式庫,提供了強大的陣列運算功能。它也提供了隨機數生成的功能,包括 numpy.random 模組。然而,NumPy 的隨機數生成有一些侷限性,例如種子和狀態的管理。

種子和狀態在 NumPy 中

在 NumPy 中,種子(seed)用於初始化隨機數生成器,而狀態(state)則用於記錄生成器的當前狀態。這些概念對於生成可重複的隨機數序列非常重要。

JAX 中的隨機數生成

JAX 是一個相對較新的 Python 函式庫,提供了高效能的陣列運算和自動微分功能。JAX 也提供了隨機數生成的功能,包括 jax.random 模組。JAX 的隨機數生成比 NumPy 更加先進,提供了更多的組態選項和更好的效能。

高階 JAX 隨機數生成組態

JAX 提供了多種方式來組態隨機數生成,包括設定種子、狀態和 PRNG(偽隨機數生成器)組態。這些組態選項可以用於建立複雜的隨機數生成管道。

在實際應用中生成隨機數

在實際應用中,生成隨機數可以用於各種目的,例如:

  • 資料增強:生成隨機數可以用於建立資料增強管道,例如影像翻轉、旋轉和縮放。
  • 神經網路初始化:生成隨機數可以用於初始化神經網路的權重和偏差。

使用 Pytrees 代表複雜資料結構

Pytree 是 JAX 中的一個重要概念,用於代表複雜的資料結構。Pytree 可以用於建立和操作複雜的資料結構,例如樹狀結構和圖結構。

用於 Pytrees 的函式

JAX 提供了多種函式用於建立和操作 Pytrees,包括 jax.tree_mapjax.tree_reduce。這些函式可以用於建立和操作複雜的資料結構。

以下是使用 JAX 生成隨機數和建立 Pytrees 的示例:

import jax
import jax.numpy as jnp

# 生成隨機數
key = jax.random.PRNGKey(0)
random_number = jax.random.uniform(key, (3, 3))

# 建立 Pytree
pytree = {'a': 1, 'b': 2, 'c': 3}

# 使用 jax.tree_map 函式操作 Pytree
def square(x):
    return x ** 2

squared_pytree = jax.tree_map(square, pytree)

在這個示例中,我們首先生成了一個隨機數,然後建立了一個 Pytree。最後,我們使用 jax.tree_map 函式操作 Pytree,將每個值平方。

使用樹狀結構對映(Tree Map)進行資料轉換

在資料處理中,樹狀結構對映(Tree Map)是一種強大的工具,能夠幫助我們將複雜的資料結構轉換為更簡單、更易於管理的形式。下面,我們將探討如何使用 tree_map() 進行資料轉換。

什麼是樹狀結構對映?

樹狀結構對映是一種將樹狀結構的資料轉換為平面結構的過程。這個過程可以幫助我們簡化複雜的資料結構,使其更容易被分析和處理。

使用 tree_map() 進行資料轉換

tree_map() 是一個強大的函式,能夠幫助我們進行樹狀結構對映。下面是一個簡單的例子:

import jax

# 定義一個樹狀結構的資料
data = {'a': 1, 'b': {'c': 2, 'd': 3}}

# 使用 tree_map() 進行資料轉換
result = jax.tree_map(lambda x: x * 2, data)

print(result)  # 輸出:{'a': 2, 'b': {'c': 4, 'd': 6}}

在這個例子中,我們使用 tree_map() 將樹狀結構的資料轉換為平面結構。函式 lambda x: x * 2 被應用到每個節點的值上,將其乘以 2。

使用 tree_reduce() 進行資料聚合

在某些情況下,我們可能需要對樹狀結構的資料進行聚合操作。這時,我們可以使用 tree_reduce() 進行資料聚合。

import jax

# 定義一個樹狀結構的資料
data = {'a': 1, 'b': {'c': 2, 'd': 3}}

# 使用 tree_reduce() 進行資料聚合
result = jax.tree_reduce(lambda x, y: x + y, data)

print(result)  # 輸出:6

在這個例子中,我們使用 tree_reduce() 將樹狀結構的資料聚合為一個單一的值。函式 lambda x, y: x + y 被應用到每個節點的值上,將其相加。

轉置 pytree

在某些情況下,我們可能需要轉置 pytree。這時,我們可以使用 jax.tree_transpose() 進行轉置操作。

import jax

# 定義一個 pytree
data = {'a': 1, 'b': {'c': 2, 'd': 3}}

# 使用 tree_transpose() 進行轉置
result = jax.tree_transpose(data)

print(result)  # 輸出:{'a': 1, 'b': {'c': 2, 'd': 3}}

在這個例子中,我們使用 tree_transpose() 將 pytree 轉置為另一個 pytree。

內容解密:

  • jax.tree_map() 是一個強大的函式,能夠幫助我們進行樹狀結構對映。
  • jax.tree_reduce() 是一個強大的函式,能夠幫助我們進行資料聚合操作。
  • jax.tree_transpose() 是一個強大的函式,能夠幫助我們進行 pytree 轉置操作。

圖表翻譯:

  graph LR
    A[樹狀結構] -->|tree_map()|> B[平面結構]
    B -->|tree_reduce()|> C[單一值]
    C -->|tree_transpose()|> D[轉置 pytree]

在這個圖表中,我們展示瞭如何使用 tree_map()tree_reduce()tree_transpose() 進行資料轉換和聚合操作。

深入探索JAX生態系統

JAX是一個強大的開源神經網路函式庫,提供了高效的自動微分、向量化和平行化等功能。除了JAX本身之外,其生態系統還包括了許多其他有用的函式庫和工具。這些函式庫和工具可以幫助開發者更好地使用JAX,提高開發效率和模型效能。

12.1 深度學習生態系統

在深度學習領域,JAX生態系統提供了多種高階神經網路函式庫,包括Flax和Haiku等。這些函式庫提供了簡單易用的API,讓開發者可以快速地構建和訓練神經網路模型。

高階神經網路函式庫

  • Flax:Flax是一個根據JAX的高階神經網路函式庫,提供了簡單易用的API,讓開發者可以快速地構建和訓練神經網路模型。
  • Haiku:Haiku是一個根據JAX的高階神經網路函式庫,提供了簡單易用的API,讓開發者可以快速地構建和訓練神經網路模型。

大語言模型在JAX

JAX也支援大語言模型(LLMs)的開發和訓練。開發者可以使用JAX提供的API和工具來構建和訓練自己的LLMs。

公用函式庫

除了高階神經網路函式庫之外,JAX生態系統還提供了多種公用函式庫,包括Optax和Chex等。這些函式庫提供了實用的工具和函式,讓開發者可以更好地使用JAX。

12.2 機器學習模組

JAX生態系統還提供了多種機器學習模組,包括強化學習和其他機器學習函式庫等。這些模組可以幫助開發者更好地使用JAX,提高模型效能和開發效率。

強化學習

JAX提供了強化學習的支援,讓開發者可以使用JAX來構建和訓練自己的強化學習模型。

其他機器學習函式庫

除了強化學習之外,JAX生態系統還提供了多種其他機器學習函式庫,包括Optax和Chex等。這些函式庫提供了實用的工具和函式,讓開發者可以更好地使用JAX。

12.3 JAX模組適用於其他領域

JAX生態系統還提供了多種模組,適用於其他領域,包括資料分析和視覺化等。這些模組可以幫助開發者更好地使用JAX,提高開發效率和模型效能。

附錄A 安裝JAX

要使用JAX,首先需要安裝JAX函式庫。以下是安裝JAX的步驟:

  1. 安裝Python:JAX需要Python 3.7或以上版本。
  2. 安裝JAX:可以使用pip安裝JAX,命令如下:pip install jax
  3. 安裝其他函式庫:根據需要,可以安裝其他函式庫,例如Flax和Haiku等。

附錄B 使用Google Colab

Google Colab是一個免費的雲端平臺,提供了JAX和其他機器學習函式庫的支援。以下是使用Google Colab的步驟:

  1. 建立Google Colab帳戶:如果沒有Google Colab帳戶,需要建立一個。
  2. 建立新的Colab筆記本:登入Google Colab後,可以建立新的Colab筆記本。
  3. 安裝JAX:可以使用以下命令安裝JAX:!pip install jax
  4. 匯入JAX:可以使用以下命令匯入JAX:import jax

附錄C 使用Google Cloud TPUs

Google Cloud TPUs是一種雲端硬體加速器,提供了高效的計算能力。以下是使用Google Cloud TPUs的步驟:

  1. 建立Google Cloud帳戶:如果沒有Google Cloud帳戶,需要建立一個。
  2. 建立新的Google Cloud專案:登入Google Cloud後,可以建立新的專案。
  3. 啟用TPUs:可以在Google Cloud Console中啟用TPUs。
  4. 安裝JAX:可以使用以下命令安裝JAX:!pip install jax
  5. 匯入JAX:可以使用以下命令匯入JAX:import jax

前言

JAX是一個強大的Python函式庫,由玄貓建立,廣泛應用於機器學習研究,排名第三,僅次於TensorFlow和PyTorch。值得注意的是,它是像DeepMind和Google等公司的首選框架。JAX的生態系統正在迅速擴充套件,儘管它已經存在了幾年,但仍然缺乏全面性的資源供初學者使用。

我真正欣賞JAX的地方是它在深度學習中對函式語言程式設計的強調。它提供了強大的函式轉換,包括梯度計算、JIT編譯、自動向量化和平行化。JAX支援GPU和TPU,提供了令人印象深刻的效能。

現在是深入探索JAX的最佳時機,因為其生態系統正在迅速擴充套件。儘管JAX的網站提供了良好的檔案和支援性的社群,但將所有東西拼湊起來,尤其是在整合其他函式庫時,可能會令人感到不知所措。

本文是為了那些渴望掌握JAX的人而創作的。我的目標是將關鍵資訊集中在一個地方,並引導您瞭解JAX概念,增強您的技能和能力,以便在您的專案和研究中應用JAX。

本文假設您具有基本的深度學習知識和Python程式設計技能。它不涵蓋深度學習的基礎知識,因為已經有很多相關資源可供使用。相反,它專注於JAX,雖然我會在必要時簡要介紹關鍵的深度學習概念。

JAX不僅是一個深度學習框架,其超越深度學習的模組範圍表明了它在可微程式設計、大規模物理模擬等領域的潛力。我的希望是,這本文也能服務於對這些應用感興趣的人。

JAX繼續演進,我不得不更新了幾個章節。請不要擔心未來可能的變化;您將獲得的核心知識仍然適用於未來版本的JAX。

致謝

本文的創作比我預期的要長。我在途中更換了幾個國家,JAX版本也發生了變化。有些章節不得不被重寫。但現在一切都完成了!

首先,我想感謝我的家人,我的妻子Mila和我的孩子Danya和Fedya。你們長期缺乏我的關注!然而,你們始終支援我。

我想感謝亞美尼亞的人民,在我們居住了一段時間後,他們對我們的善良和熱情好客。特別感謝耶烈萬科技創業社群的幫助和支援。感謝Hrant Khachatrian、Zaven Navoyan、Arsen Yeghiazaryan、Andranik Khachatryan、Ashot Arzumanyan、Ash Vardanian、Adam Bittlingmayer、Artur Aleksanyan、Erik Arakelyan、Karén Gyulbudaghyan和許多其他人。

感謝您,Enterprise Armenia,亞美尼亞國家投資促進局,您做得很好,您的幫助無價。

我還想感謝我的編輯Patrick Barb、Becky Whitney和Frances Lefkowitz。在三位連續編輯和許多變化的情況下,每一位都為本文增添了價值。感謝Mike Stephens和Marjan Bace,他們從我最初的提案開始就相信這本文。

我感謝我的技術編輯Nick McGreivy,他除了是普林斯頓大學博士生外,還使用JAX進行科學實驗最佳化和數值模擬中的深度學習整合。感謝技術校對員Kostas Passadis和我的審稿人Arslan Gabdulkhakov、Chansung Park、Fillipe Dornelas、James Black、James Wang、Jun Jiang、Keith Kim、Lucian-Paul Torje、Maxim Volgin、Najeeb Arif、Or Golan、Ritobrata Ghosh、Seunghyun Lee、Simone De Bonis、Stephen Oates、Tony Holdroyd、Vidhya Vinay和Vojta Tuma。他們提供了許多寶貴的評論和建議,以改進本文。

最後,我感謝我的GDE(Google開發者專家)朋友和Google支援這樣一個偉大的倡議。GDE社群非常出色!許多GDE查看了早期版本並提供了有用的反饋。特別感謝David Cardozo的卓越反饋!

關於本文

《使用JAX進行深度學習》旨在幫助您瞭解和開始在您的專案和研究中使用JAX。它將關鍵資訊集中在一個地方,並透過一系列易於理解的示例引導您瞭解JAX概念,建立您對該主題的直覺。

誰應該閱讀這本文? 《使用JAX進行深度學習》導向熟悉PyTorch和TensorFlow等框架的深度學習從業者和研究人員,他們希望開始使用JAX。讀者應該具有基本的深度學習知識和Python程式設計技能。來自其他領域(例如物理或最佳化)的研究人員或專注於深度學習、數值最佳化或分散式計算的研究生也會發現這本文對他們的學習和實踐有益。

本文如何組織:一份路線圖 本文分為三個部分,涵蓋12個章節。

第1部分是介紹和展示JAX:

  • 第1章回答了關鍵問題“為什麼選擇JAX?”我們將探討什麼是JAX,它與其他框架如TensorFlow和PyTorch相比的優缺點,以及何時它可能是您專案的最佳工具。
  • 第2章引導您體驗JAX的第一個實踐。我們將構建一個簡單的神經網路進行影像分類別,介紹關鍵概念,如JAX轉換、自動向量化、梯度計算和即時編譯。您還將學習如何儲存和載入模型以及瞭解JAX中的純函式和不純函式之間的區別。

…(內容續篇)

JAX核心功能與生態系統

JAX是一個強大的深度學習框架,提供了多種工具和技術來加速和最佳化深度學習任務。以下是JAX核心功能和生態系統的概覽:

核心功能

JAX的核心功能包括:

  • 張量和多維陣列:JAX提供了強大的張量和多維陣列操作功能,允許使用者高效地進行資料操作和計算。
  • 自動微分:JAX的自動微分功能可以自動計算梯度,簡化了深度學習模型的訓練過程。
  • 即時編譯(JIT):JAX的JIT功能可以將Python程式碼編譯為高效的機器碼,從而提高執行速度。
  • 自動向量化:JAX的自動向量化功能可以自動將程式碼向量化,從而提高計算效率。
  • 平行化:JAX的平行化功能可以將計算任務分配到多個裝置上,從而提高計算速度。

生態系統

JAX的生態系統包括了多種函式庫和工具,例如:

  • Flax:Flax是一個高階神經網路函式庫,提供了便捷的API來構建和訓練複雜模型。
  • Optax:Optax是一個最佳化器函式庫,提供了多種最佳化演算法來加速模型訓練。
  • Hugging Face Transformers:Hugging Face Transformers是一個流行的變換器函式庫,提供了預訓練模型和便捷的API來使用變換器。

其他資源

除了JAX的核心功能和生態系統外,還有許多其他資源可以幫助使用者學習和使用JAX,例如:

  • 線上論壇:Manning的線上論壇提供了一個平臺讓使用者可以提問和討論JAX相關問題。
  • Colab筆記本:Colab筆記本提供了一個便捷的方式來執行JAX程式碼和實驗不同的模型和演算法。

關於作者

Grigory Sapunov是Intento的共同創始人和CTO,他是一位軟體工程師,具有20多年的經驗,並且是Google Developer Expert in Machine Learning。他擁有人工智慧博士學位,並且是Manning出版社《Deep Learning with JAX》的作者。

從技術生態圈的動態變化來看,JAX 作為一個深度學習框架,憑藉其函式語言程式設計正規化、高效的自動微分和JIT編譯,以及對GPU/TPU的支援,展現出獨特的優勢。尤其在 DeepMind 和 Google 等公司的應用實踐中,更證明瞭其在處理複雜深度學習任務上的實力。然而,JAX 較高的學習曲線和相對年輕的社群生態,也限制了其更廣泛的普及。與 TensorFlow 和 PyTorch 等成熟框架相比,JAX 需要在易用性和社群資源方面持續投入,才能吸引更多開發者。對於追求極致效能和前瞻性技術的團隊,深入學習 JAX 的核心概念和生態工具,例如 Flax、Optax 和 Hugging Face Transformers 的整合應用,將有助於提升模型效能和開發效率。玄貓認為,JAX 在特定領域,例如科學計算和高效能運算方面,擁有巨大的潛力,值得密切關注其未來的發展和應用。隨著社群的成長和工具的完善,JAX 的應用門檻將逐步降低,其影響力也將持續擴大。