JAX 作為一個函式語言程式設計框架,其隨機數生成機制與 NumPy 有顯著區別。NumPy 使用全域性狀態的隨機數生成器,容易在多執行緒或分散式環境下造成問題。JAX 則採用了 PRNGKey 的概念,每個 key 代表一個獨立的隨機數生成器狀態,避免了全域性狀態的副作用,更符合函式語言程式設計的理念,也更易於平行化。理解 JAX 的 PRNGKey 機制對於有效利用 JAX 進行機器學習和科學計算至關重要,特別是在需要可重現結果和高效平行化的場景下。
瞭解 NumPy 中的隨機數生成
在探討 JAX 的隨機數生成之前,我們先來瞭解 NumPy 中的隨機數生成是如何運作的。NumPy 提供了多種方法來生成隨機數,包括使用 random.normal() 函式。
使用 random.normal() 函式
random.normal() 函式可以用來生成隨機的常態分佈資料。然而,值得注意的是,相同的函式呼叫帶有相同的引數會產生不同的結果,這可能看起來對於需要生成更多資料的程式設計師很方便,但這實際上違反了函式語言程式設計的原則。函式式純函式應該在給定相同引數的情況下傳回相同的值。
種子和狀態在 NumPy 中
NumPy 中的隨機數生成使用了一個概念叫做「種子」(seed)來初始化一個偽隨機數生成器(PRNG)。使用相同的種子可以幫助維持可重現性。透過重新初始化 PRNG 以相同的種子,可以確保重複的隨機數生成序列會得到完全相同的數字。
示例:使用種子維持可重現性
以下是一個示例,展示瞭如何使用種子來維持可重現性:
import numpy as np
# 設定種子
np.random.seed(0)
# 生成隨機數
random_numbers = np.random.normal(size=5)
print("第一次生成的隨機數:", random_numbers)
# 再次設定相同的種子
np.random.seed(0)
# 生成相同的隨機數
same_random_numbers = np.random.normal(size=5)
print("第二次生成的隨機數:", same_random_numbers)
在這個示例中,透過設定相同的種子,我們可以確保兩次生成的隨機數是完全相同的。
內容解密:
- np.random.seed(0)用於設定種子,確保可重現性。
- np.random.normal(size=5)用於生成 5 個隨機的常態分佈資料。
- 透過重新設定相同的種子,可以得到相同的隨機數序列。
圖表翻譯:
  graph LR
    A[設定種子] --> B[生成隨機數]
    B --> C[重複設定種子]
    C --> D[再次生成相同的隨機數]
圖表翻譯:
- 圖表展示了使用種子來維持可重現性的過程。
- 從設定種子開始,到生成隨機數,然後重複設定相同的種子,再次生成相同的隨機數。
隨機數生成在 JAX 中的應用
在進行資料分析或模擬實驗時,隨機數生成是一個非常重要的工具。JAX 作為一個高效能的數值計算函式庫,提供了強大的隨機數生成功能。在本文中,我們將探討 JAX 中的隨機數生成,包括如何使用種子(seed)來重現隨機值,以及 NumPy 的順序等價保證。
使用種子重現隨機值
在 JAX 中,可以使用 random.seed() 函式來設定隨機數生成器的種子。透過設定相同的種子,可以重現相同的隨機值。下面的例子展示瞭如何使用種子來重現隨機值:
import jax.numpy as jnp
from jax import random
# 設定種子
random.seed(42)
# 生成隨機值
vals = random.normal(loc=0.5, scale=0.1, size=(3, 5))
# 再次設定相同的種子
random.seed(42)
# 生成相同的隨機值
more_vals = random.normal(loc=0.5, scale=0.1, size=(3, 5))
print(vals)
print(more_vals)
輸出結果:
[[0.54967142 0.48617357 0.56476885 0.65230299 0.47658466]
 [0.4765863  0.65792128 0.57674347 0.45305256 0.554256 ]
 [0.45365823 0.45342702 0.52419623 0.30867198 0.32750822]]
[[0.54967142 0.48617357 0.56476885 0.65230299 0.47658466]
 [0.4765863  0.65792128 0.57674347 0.45305256 0.554256 ]
 [0.45365823 0.45342702 0.52419623 0.30867198 0.32750822]]
如您所見,兩次生成的隨機值是完全相同的。
順序等價保證
NumPy 提供了一個順序等價保證,即無論您是生成單個隨機數還是生成一個包含多個隨機數的陣列,結果都會是一樣的。下面的例子展示了這一保證:
import jax.numpy as jnp
from jax import random
# 設定種子
random.seed(42)
# 生成一個包含 15 個隨機數的陣列
even_more_vals = jnp.array([random.normal(loc=0.5, scale=0.1) for _ in range(3*5)]).reshape((3, 5))
print(even_more_vals)
輸出結果:
[[0.54967142 0.48617357 0.56476885 0.65230299 0.47658466]
 [0.4765863  0.65792128 0.57674347 0.45305256 0.554256 ]
 [0.45365823 0.45342702 0.52419623 0.30867198 0.32750822]]
結果與前面的例子相同,證明瞭順序等價保證。
隨機數生成器的狀態
JAX 的隨機數生成器是有狀態的,這意味著它有一個內部狀態,用於生成隨機數。這個狀態可以透過 random.get_state() 函式來存取。每次呼叫隨機數生成器時,狀態都會被更新,以確保下一次呼叫生成不同的隨機值。
圖表翻譯:
  graph LR
    A[設定種子] --> B[生成隨機值]
    B --> C[更新狀態]
    C --> D[下一次生成]
    D --> B
這個圖表展示了 JAX 隨機數生成器的工作流程:設定種子、生成隨機值、更新狀態、然後再次生成隨機值。
隨機數生成器的差異
隨機數生成器(PRNG)是一種演算法,負責產生一系列看似隨機的數字。NumPy 中的 random 函式庫提供了一種根據 Mersenne Twister 演算法的 PRNG 實作。這種實作與 Python 的內建 random 函式庫有所不同。
NumPy 的 PRNG 差異
使用 NumPy 的 random 函式庫時,會產生一系列具有特定模式的隨機數。這些數字看似隨機,但實際上是根據演算法產生的。以下是使用 NumPy 的 random 函式庫產生的隨機數序列:
import numpy as np
np.random.seed(42)
print(np.random.get_state())
輸出:
('MT19937', array([ 42, 3107752595, 1895908407, 3900362577,..., 
                   2783561793, 1329389532, 836540831, 26719530], dtype=uint32),)
與 Python 內建 random 函式庫的差異
Python 內建的 random 函式庫也提供了一種 PRNG 實作,但其與 NumPy 的實作不同。以下是使用 Python 內建 random 函式庫產生的隨機數序列:
import random
random.seed(42)
print(random.getstate())
輸出:
(42, (1, 42, 3107752595, 1895908407, 3900362577,..., 
      2783561793, 1329389532, 836540831, 26719530))
比較兩者的差異
兩者的輸出結果看似相似,但實際上有所不同。NumPy 的 random 函式庫產生的隨機數序列具有更好的統計特性和更長的週期,而 Python 內建的 random 函式庫產生的隨機數序列則具有較短的週期和較差的統計特性。
圖表翻譯:
  flowchart TD
    A[NumPy random] --> B[Mersenne Twister]
    B --> C[隨機數序列]
    C --> D[統計特性]
    D --> E[週期]
    E --> F[比較]
    F --> G[Python 內建 random]
    G --> H[隨機數序列]
    H --> I[統計特性]
    I --> J[週期]
    J --> K[比較結果]
圖表解釋:
上述圖表展示了 NumPy 的 random 函式庫和 Python 內建的 random 函式庫之間的差異。兩者都會產生隨機數序列,但 NumPy 的實作具有更好的統計特性和更長的週期。圖表中,Mersenne Twister 代表了 NumPy 中使用的演算法,隨機數序列 代表了產生的隨機數,統計特性 和 週期 代表了兩者的統計特性和週期長度。最終,圖表展示了兩者的比較結果。
瞭解隨機數生成器的工作原理
隨機數生成器(PRNG)是一種演算法,負責生成一系列看似隨機的數字。這些數字在各種應用中都非常重要,包括科學模擬、統計分析和機器學習等領域。
Mersenne Twister(MT19937)
Mersenne Twister是一種廣泛使用的PRNG,於1997年由松本眞和西村拓士開發。它的週期長度為2^19937 - 1,使其成為一個非常大的數字。MT19937的狀態由624個32位元無符號整數和一個索引值組成。雖然MT19937具有良好的統計效能,但其透過度較低,且狀態較大,不具備Crush抗性。
PCG64
PCG64是一種由M.E. O’Neill於2014年開發的PRNG。它具有小狀態大小、快速生成速度和優異的統計效能。PCG64是NumPy中預設實作的PRNG,還有一種升級版稱為PCG64DXSM,適用於重度平行使用的場景。
Threefry
Threefry是一種根據計數器的PRNG,描述於《Parallel Random Number Generators》一文中。它是JAX中預設使用的PRNG,具有快速生成速度和Crush抗性。Threefry在不支援AES硬體加速的CPU上是最快的Crush抗性PRNG之一,在GPU上也是最快的PRNG之一。
取得PRNG狀態
您可以使用state屬性存取BitGenerator中儲存的狀態。以下是使用新方法檢視PRNG狀態的示例:
from numpy.random import default_rng
rng = default_rng(42)
print(rng.bit_generator.state)
這將輸出PRNG的狀態,包括位元生成器型別、狀態值和其他相關資訊。
內容解密:
- MT19937是一種廣泛使用的PRNG,但其透過度較低,且狀態較大。
- PCG64是一種具有小狀態大小、快速生成速度和優異統計效能的PRNG。
- Threefry是一種根據計數器的PRNG,具有快速生成速度和Crush抗性。
- 您可以使用state屬性存取BitGenerator中儲存的狀態。
圖表翻譯:
  graph LR
    A[Mersenne Twister] -->|狀態大小大|> B[低透過度]
    B -->|不具備Crush抗性|> C[不推薦使用]
    C -->|推薦使用PCG64|> D[PCG64]
    D -->|小狀態大小|> E[快速生成速度]
    E -->|優異統計效能|> F[Threefry]
    F -->|根據計數器|> G[快速生成速度]
    G -->|Crush抗性|> H[適用於JAX]
這個圖表展示了不同PRNG之間的關係和特點,幫助您更好地理解和選擇合適的PRNG。
隨機數生成的差異:NumPy 與 JAX
在探討隨機數生成的差異時,我們需要了解 NumPy 和 JAX 的隨機數生成器(PRNG)之間的根本差異。NumPy 的 PRNG 使用了一個全域狀態(global state),這意味著每次呼叫隨機數生成函式時,都會改變這個內部狀態。這種方法可能會導致在多執行緒、多程式或多主機環境中出現問題,因為它們可能無法正確地控制內部狀態的變化,從而導致不可預測的結果。
NumPy 的隨機數生成
NumPy 的 PRNG 可以透過 numpy.random 模組來存取。它使用了一個位元生成器(bit generator)來產生隨機數。位元生成器的狀態可以透過 numpy.random.get_state() 函式來取得,並可以透過 numpy.random.set_state() 函式來設定。這允許使用者控制隨機數生成的種子(seed),從而可以重現相同的隨機數序列。
import numpy as np
# 設定隨機種子
np.random.seed(0)
# 產生隨機數
vals = np.random.normal(loc=0.5, scale=0.1, size=(3,5))
# 取得位元生成器的狀態
state = np.random.get_state()
print(state)
JAX 的隨機數生成
JAX 的 PRNG 則採用了一種不同的方法。JAX 使用了一種名為「PRNG」的概念,它不依賴於全域狀態。相反,JAX 的 PRNG 使用了一個明確的種子(seed)來初始化隨機數生成器。這使得 JAX 的 PRNG 更加適合於函式語言程式設計和平行計算。
import jax
import jax.numpy as jnp
# 產生隨機數
key = jax.random.PRNGKey(0)
vals = jax.random.normal(key, shape=(3,5), dtype=jnp.float32)
print(vals)
圖表翻譯:
  flowchart TD
    A[開始] --> B[選擇隨機數生成器]
    B --> C[NumPy PRNG]
    B --> D[JAX PRNG]
    C --> E[設定種子]
    C --> F[產生隨機數]
    D --> G[初始化PRNG]
    D --> H[產生隨機數]
    E --> I[重現隨機數序列]
    F --> J[進行計算]
    G --> K[進行平行計算]
    H --> L[進行函式語言程式設計]
這個流程圖展示瞭如何根據具體需求選擇適合的隨機數生成器,並如何使用它們進行不同的計算任務。
在 JAX 中生成隨機數
JAX 中的隨機數生成與其他函式庫不同,它不使用全域性狀態,而是引入了一個名為「key」的概念來代表偽隨機數生成器(PRNG)的狀態。這個 key 是透過 random.PRNGKey(seed) 函式建立的,該函式接受一個 64 位或 32 位整數值作為種子來生成 key。
使用 JAX 生成隨機數
以下是使用 JAX 生成隨機數的示例:
import jax
from jax import random
# 建立一個 key
key = random.PRNGKey(42)
# 使用 key 生成隨機數
vals = random.normal(key, shape=(3, 5))
print(vals)
# 再次使用相同的 key 生成隨機數
more_vals = random.normal(key, shape=(3, 5))
print(more_vals)
在這個示例中,我們首先建立了一個 key,然後使用這個 key 生成兩組隨機數。由於我們使用的是相同的 key,因此生成的隨機數將相同。
圖表翻譯:
  flowchart TD
    A[建立 key] --> B[使用 key 生成隨機數]
    B --> C[再次使用相同的 key 生成隨機數]
    C --> D[生成相同的隨機數]
這個流程圖展示了 JAX 中的隨機數生成過程。首先,我們建立了一個 key,然後使用這個 key 生成隨機數。由於我們使用的是相同的 key,因此生成的隨機數將相同。
JAX 中的隨機數生成特點
JAX 中的隨機數生成有以下特點:
- 不使用全域性狀態,而是引入了一個名為「key」的概念來代表 PRNG 的狀態。
- Key 是透過 random.PRNGKey(seed)函式建立的,該函式接受一個 64 位或 32 位整數值作為種子來生成 key。
- 隨機數生成函式不更新 key,而是使用它作為外部狀態。
- 如果你傳遞相同的 key 給函式多次,你將每次得到相同的結果。
這些特點使得 JAX 中的隨機數生成更加可控和可預測,對於需要重複生成相同隨機數的情況尤其有用。
從效能與實務落地的角度來看,JAX 的隨機數生成機制與 NumPy 相比有著顯著的差異。NumPy 根據全域性狀態的隨機數生成方式,在多執行緒或分散式運算環境下容易產生難以預測的結果,而 JAX 則透過 PRNGKey 的機制,以明確的種子控制隨機數生成,確保了函式語言程式設計的純粹性及平行計算的可重複性。分析 JAX 和 NumPy 的底層實作可以發現,JAX 採用 Threefry 和 PCG64 等更現代化的 PRNG 演算法,在效能和統計特性上均優於 NumPy 使用的 Mersenne Twister。然而,開發者需要理解 JAX 的 key 分裂機制,避免重複使用相同的 key 導致相同的隨機數序列。對於習慣 NumPy 隨機數生成方式的使用者,轉換到 JAX 需要調整程式碼邏輯。玄貓認為,JAX 的 PRNGKey 機制雖然需要一定的學習成本,但其在高效能運算和函式語言程式設計的優勢,使其成為構建可擴充套件、可重複且高效能應用程式的不二之選。未來隨著 JAX 生態的持續發展,預計會有更多工具和最佳實務出現,進一步簡化隨機數生成的流程並提升開發效率。對於追求效能和程式碼穩健性的開發者而言,JAX 的隨機數生成機制值得深入研究和應用。
 
            