JAX 的隨機數生成機制與 NumPy 等其他數值計算函式庫有所不同,它採用 PRNG(偽隨機數生成器)和金鑰(key)的機制來生成隨機數。理解 JAX 中的金鑰管理和 PRNG 的使用方式對於有效地生成和控制隨機數至關重要。透過使用 random.split() 函式,開發者可以從單一金鑰衍生出多個子金鑰,每個子金鑰都能生成獨立的隨機數序列,避免了潛在的相關性問題。此外,random.fold_in() 函式允許開發者將額外資訊整合到金鑰中,進一步提升隨機數生成的靈活性。除了預設的 Threefry PRNG,JAX 也提供了實驗性的 rbg 和 unsafe_rbg PRNG,它們在特定場景下可能提供更佳的效能。開發者可以根據需求選擇合適的 PRNG,並透過設定相關標誌來微調 PRNG 的行為。

瞭解JAX中的隨機數生成

在JAX中,隨機數的生成與其他函式庫(如NumPy)不同。JAX使用了一種不同的隨機數生成器(PRNG),這使得它的行為在某些方面更具可預測性和控制性。下面,我們將探討JAX中的隨機數生成,包括如何使用鍵(key)來控制隨機數的生成。

使用鍵生成隨機數

在JAX中,生成隨機數需要先建立一個鍵(key)。這個鍵可以被用來生成多個隨機數。下面的例子展示瞭如何建立一個鍵並使用它來生成隨機數:

import jax
import jax.random as random

# 建立一個鍵
key = random.PRNGKey(42)

# 使用鍵生成隨機數
random_values = random.normal(key, (3, 5))
print(random_values)

這裡,我們建立了一個鍵,並使用它來生成一個3x5的隨機數矩陣。

鍵的結構

JAX中的鍵其實是一個包含兩個32位無符號整數的陣列。下面的例子展示了鍵的結構:

key = random.PRNGKey(42)
print(key)  # Output: Array([ 0, 42], dtype=uint32)

如你所見,鍵包含兩個32位無符號整數。

多次使用同一鍵

如果你多次使用同一鍵來生成隨機數,JAX會生成相同的隨機數序列。這與NumPy不同,NumPy會在每次呼叫後修改其內部的PRNG狀態。下面的例子展示了這一點:

key = random.PRNGKey(42)
random_values1 = random.normal(key, (3, 5))
random_values2 = random.normal(key, (3, 5))
print(random_values1 == random_values2)  # Output: True

如你所見,使用同一鍵多次生成隨機數會得到相同的結果。

分割鍵

如果你需要在多個地方使用不同的隨機數,可以分割鍵。下面的例子展示瞭如何分割鍵:

key = random.PRNGKey(42)
key1, key2 = random.split(key)
random_values1 = random.normal(key1, (3, 5))
random_values2 = random.normal(key2, (3, 5))
print(random_values1!= random_values2)  # Output: True

如你所見,分割鍵可以得到不同的隨機數序列。

瞭解JAX中隨機數生成的差異

與NumPy相比,JAX提供了一種更為高效和靈活的隨機數生成方法。當您需要大量的隨機數時,或者您不知道需要多少隨機數時,JAX的隨機數生成方法就顯得尤為重要。

使用random.split()函式

JAX提供了一個random.split()函式,可以將一個隨機數鍵(key)分割成多個新的鍵。這個函式需要兩個引數:一個是要分割的鍵,另一個是要生成的新鍵的數量(預設值為2)。傳回的是一個包含所請求數量的新鍵的陣列樣物件。

import jax
from jax import random

# 生成一個隨機數鍵
key = random.PRNGKey(42)
print(key)

# 分割鍵
key1, key2 = random.split(key, num=2)
print(key1)
print(key2)

獨立的隨機值生成

使用random.split()函式生成的新鍵可以用來生成獨立的隨機值。這意味著您可以使用不同的鍵來生成不同的隨機值,而不會影響到其他鍵生成的隨機值。

# 使用第一個鍵生成隨機正態矩陣
vals = random.normal(key1, shape=(3,5))
print(vals)

# 使用第二個鍵生成更多的隨機正態矩陣
more_vals = random.normal(key2, shape=(3,5))
print(more_vals)

鍵和子鍵的概念

在JAX中,使用來生成其他鍵的鍵被稱為「鍵」,而使用來生成隨機值的鍵被稱為「子鍵」。但是,這兩種鍵之間沒有等級關係,它們都具有相同的地位。無論您使用哪一個來進行分割或生成隨機值,都不會影響結果。

生成隨機數字的方法

在 JAX 中,生成隨機數字可以使用 random.normal() 函式。然而,為了確保生成的隨機數字不同,我們需要使用不同的金鑰(key)。下面是一個示例,展示如何使用 random.split() 函式將一個金鑰分割成多個不同的金鑰。

import jax
import jax.numpy as jnp
from jax import random

# 生成一個初始金鑰
key = random.PRNGKey(42)

# 將金鑰分割成 100 個不同的金鑰
key, *subkeys = random.split(key, num=100)

# 使用不同的金鑰生成隨機數字
vals = random.normal(key, (3, 5))
more_vals = random.normal(subkeys[0], (3, 5))

print(vals)
print(more_vals)

在這個示例中,我們首先生成一個初始金鑰 key,然後使用 random.split() 函式將其分割成 100 個不同的金鑰。接著,我們使用不同的金鑰生成兩組隨機數字 valsmore_vals

注意,random.normal() 函式需要一個金鑰作為輸入,以確保生成的隨機數字不同。如果你使用相同的金鑰,則會生成相同的隨機數字。

金鑰的重要性

在 JAX 中,金鑰是用於生成隨機數字的。每個金鑰都對應著一串唯一的隨機數字序列。如果你使用相同的金鑰,則會生成相同的隨機數字序列。因此,為了確保生成的隨機數字不同,你需要使用不同的金鑰。

使用 fold_in() 函式

另一個建立新金鑰的方法是使用 fold_in() 函式。這個函式可以將現有的金鑰和一些資料合併,生成一個新的金鑰。這個新的金鑰可以用於生成新的隨機數字序列。

new_key = random.fold_in(key, 123)
new_vals = random.normal(new_key, (3, 5))

在這個示例中,我們使用 fold_in() 函式將現有的金鑰 key 和一些資料 123 合併,生成一個新的金鑰 new_key。接著,我們使用這個新的金鑰生成一組新的隨機數字 new_vals

產生隨機數與金鑰的技術探討

在進行隨機數生成的過程中,瞭解如何產生不同數字序列以及如何管理金鑰是非常重要的。這涉及到使用特定的函式和演算法來確保生成的數字是真正隨機且不可預測的。

金鑰生成與分割

首先,我們需要生成一個初始金鑰,這個金鑰將作為後續隨機數生成的基礎。這個初始金鑰可以透過某種隨機過程生成,例如使用一個種子值。接下來,這個初始金鑰可以被分割成多個子金鑰,每個子金鑰都可以用於生成一系列獨立的隨機數。

import numpy as np

# 生成初始金鑰
initial_key = np.random.randint(0, 1000000)

# 分割初始金鑰為100個子金鑰
sub_keys = [np.random.randint(0, 1000000) for _ in range(100)]

隨機數生成

使用這些子金鑰,我們可以生成一系列隨機數。每個子金鑰都可以獨立地用於生成一系列隨機數,從而確保了不同序列之間的獨立性。

# 使用子金鑰生成隨機數
random_numbers = []
for key in sub_keys:
    random_numbers.append(np.random.normal(key, size=(3, 5)))

與NumPy的比較

在Python中,NumPy是一個非常流行的數值計算函式庫,它提供了強大的隨機數生成功能。然而,NumPy的隨機數生成是根據偽隨機數生成器的,這意味著它們不是真正的隨機數。相比之下,使用金鑰和子金鑰的方法可以提供更高的隨機性和安全性。

  flowchart TD
    A[初始金鑰生成] --> B[金鑰分割]
    B --> C[子金鑰生成]
    C --> D[隨機數生成]
    D --> E[應用]

圖表翻譯:

上述流程圖描述了從初始金鑰生成到最終隨機數應用的整個過程。首先,生成一個初始金鑰,然後將其分割為多個子金鑰。每個子金鑰都可以獨立地用於生成一系列隨機數。這些隨機數可以應用於各種需要隨機性的情境中。

使用隨機數生成新金鑰

在進行隨機數生成時,我們可以使用 random.fold_in() 函式來生成新的金鑰。這個函式需要一個整數作為額外的資料。如果我們想要折疊其他資料型別,例如字串,我們需要將其轉換為整數。

將字串轉換為整數

為了將字串轉換為整數,我們可以使用 SHA-1 雜湊函式。Python 的 hashlib 函式庫提供了不同的雜湊選項。以下是使用 SHA-1 雜湊來生成整數的範例:

import hashlib

def my_hash(s):
    return int(hashlib.sha1(s.encode()).hexdigest()[:8], base=16)

some_string = 'layer7_2'
some_int = my_hash(some_string)
print(some_int)

生成新金鑰

使用 random.fold_in() 函式,我們可以生成新的金鑰。以下是範例:

import random
import hashlib

def my_hash(s):
    return int(hashlib.sha1(s.encode()).hexdigest()[:8], base=16)

key = random.PRNGKey(42)
some_string = 'layer7_2'
some_int = my_hash(some_string)
new_key = random.fold_in(key, some_int)
print(new_key)

這個程式碼會生成一個新的金鑰,使用 random.fold_in() 函式和 SHA-1 雜湊函式。新金鑰是根據原始金鑰和字串的雜湊值生成的。

圖表翻譯:

  flowchart TD
    A[原始金鑰] --> B[SHA-1 雜湊]
    B --> C[整數轉換]
    C --> D[新金鑰生成]
    D --> E[輸出新金鑰]

這個流程圖描述了從原始金鑰到新金鑰的生成過程。首先,原始金鑰被用於生成 SHA-1 雜湊值,然後這個雜湊值被轉換為整數,最後,使用 random.fold_in() 函式生成新的金鑰。

產生隨機數字與金鑰的技術實作

在進行隨機數字生成時,通常需要一個基礎金鑰(key)來確保產生的隨機數字序列的一致性和可重現性。然而,在某些情況下,直接使用Python的內建hash()函式來從字串生成整數可能會導致不可重現的結果,因為hash()函式的隨機性可能會干擾向量化運算,特別是在SIMD硬體上。

生成隨機整數的函式

為了避免這種問題,我們可以定義一個自定的函式來生成32位整數從字串。這個函式可以根據現有的金鑰和一個遞增的計數器來生成新的金鑰。這樣可以確保在相同的初始金鑰和計數器的情況下,產生的隨機數字序列是一致的。

JAX中的隨機數字生成

在JAX中,random.PRNGKey可以用來建立一個初始的隨機金鑰,而random.split函式可以用來從這個初始金鑰中生成多個子金鑰。這些子金鑰可以用來生成不同的隨機數字序列。

示例:無順序等價保證的JAX隨機陣列

以下示例展示瞭如何使用單一函式呼叫和單一金鑰生成一個5x5的隨機陣列,以及如何使用從原始金鑰產生的子金鑰順序生成相同大小的隨機陣列,並比較兩者的差異。

import jax.numpy as jnp
from jax import random

# 建立一個初始的隨機金鑰
key = random.PRNGKey(42)

# 使用單一金鑰和單一函式呼叫生成5x5隨機陣列
subkeys = random.split(key, num=3*5)
random_array_single_call = jnp.stack([random.uniform(subkey, (5,), minval=0, maxval=1) for subkey in subkeys])

# 使用從原始金鑰產生的子金鑰順序生成5x5隨機陣列
subkeys_sequential = random.split(key, num=3*5)
random_array_sequential = jnp.stack([random.uniform(subkey, (5,), minval=0, maxval=1) for subkey in subkeys_sequential])

print("單一函式呼叫生成的隨機陣列:")
print(random_array_single_call)

print("\n順序生成的隨機陣列:")
print(random_array_sequential)

內容解密:

  • random.PRNGKey(42)用於建立一個初始的隨機金鑰,引數42是種子,用於確保結果的一致性。
  • random.split(key, num=3*5)用於從初始金鑰中生成15個子金鑰,這些子金鑰用於生成隨機數字。
  • random.uniform(subkey, (5,), minval=0, maxval=1)用於生成一個5x5的隨機浮點數陣列,範圍從0到1。
  • jnp.stack用於堆積疊生成的隨機陣列,以便形成最終的5x5陣列。

圖表翻譯:

  graph LR
    A[初始金鑰] -->|split|> B[子金鑰]
    B -->|uniform|> C[隨機陣列]
    C -->|stack|> D[最終隨機陣列]

這個圖表展示了從初始金鑰到最終隨機陣列的生成過程。首先,初始金鑰被分裂成多個子金鑰,然後每個子金鑰用於生成一個隨機陣列,最終這些陣列被堆積疊起來形成最終的隨機陣列。

JAX 中的高階隨機數生成器組態

JAX 提供了多種隨機數生成器(PRNG)的實作,包括 Threefry counter-based PRNG 和兩種實驗性的 PRNG:rbg 和 unsafe_rbg。Threefry PRNG 的優點在於它在不同硬體平臺(如 CPU、GPU 和 TPU)和 JAX 版本之間保持一致。

Threefry PRNG

Threefry PRNG 是 JAX 中的預設 PRNG。它使用三個 32 位整數作為狀態,並使用 XOR 和位移操作來生成隨機數。Threefry PRNG 的優點是:

  • 跨平臺一致性:Threefry PRNG 在不同硬體平臺和 JAX 版本之間保持一致。
  • 高品質隨機數:Threefry PRNG 可以生成高品質的隨機數,適合於大多數應用場景。

實驗性的 PRNG

除了 Threefry PRNG 之外,JAX 還提供了兩種實驗性的 PRNG:rbg 和 unsafe_rbg。這些 PRNG 使用 XLA RngBitGenerator 來生成隨機數。實驗性的 PRNG 的優點是:

  • 更快的編譯和執行速度:實驗性的 PRNG 可能比 Threefry PRNG 更快地編譯和執行,特別是在 TPU 上。
  • 更好的 sharding 效率:實驗性的 PRNG 可能比 Threefry PRNG 更好地支援 sharding,特別是在大規模的平行計算中。

注意事項

  • 實驗性的 PRNG 尚未經過徹底的實驗測試,可能會在未來的 JAX 版本中發生變化。
  • 使用實驗性的 PRNG 時,需要謹慎評估其安全性和效率。

示例程式碼

import jax
import jax.numpy as jnp

# 生成一個隨機數鍵
key = jax.random.PRNGKey(0)

# 將鍵分割成 15 個新的鍵
subkeys = jax.random.split(key, 15)

# 使用單一函式呼叫生成一個 3x5 矩陣
vals = jax.random.normal(key, shape=(3, 5))

# 使用循序方式生成一個 3x5 矩陣
more_vals = jnp.array([jax.random.normal(subkey) for subkey in subkeys]).reshape((3, 5))

print(vals)
print(more_vals)

使用實驗性PRNG

JAX提供了多種實驗性PRNG(偽隨機數生成器),包括rbg和unsafe_rbg。rbg PRNG使用Threefry PRNG進行分割,但使用XLA RBG進行資料生成。unsafe_rbg PRNG則使用RBG進行分割和生成,僅供示範用途。

啟用實驗性PRNG

要啟用這些實驗性PRNG,可以使用以下程式碼:

config.update("jax_default_prng_impl", "rbg")
import jax
from jax import random
key = random.PRNGKey(42)

這將啟用rbg PRNG,並生成一個初始金鑰。

Threefry分割PRNG

另有一個標誌叫做jax_threefry_partitionable,可以啟用新的Threefry PRNG實作,這個實作更高效地支援分割。原始的Threefry PRNG由於歷史原因,其實作並不自動支援分割,因此需要一些跨裝置的通訊來產生分割輸出。

設定標誌

設定jax_threefry_partitionable標誌為True(預設值為False)可以啟用新的實作。新的實作移除了通訊開銷,雖然生成的隨機值可能與未設定標誌時不同,但仍然是決定性的,並且在給定的JAX版本中保持一致。

從底層實作到高階應用的全面檢視顯示,JAX 的隨機數生成機制與其他框架(如 NumPy)有著顯著區別,其根據鍵(key)和分割(split)的設計,確保了隨機性的同時,也提供了可重複性和精細控制。透過多維度效能指標的實測分析,JAX 的 PRNG 機制在處理大量隨機數生成,特別是在需要向量化運算、平行處理和硬體加速的場景下,展現出顯著的效能優勢。然而,JAX 的 PRNG 系統也存在一些限制,例如不同 PRNG 實作(Threefry、rbg、unsafe_rbg)之間的差異以及實驗性 PRNG 的穩定性問題。對於追求效能的使用者,建議深入理解不同 PRNG 實作的特性,並根據實際需求選擇合適的方案。同時,密切關注 JAX 官方檔案和社群動態,以取得最新的 PRNG 更新和最佳實踐。玄貓認為,隨著 JAX 生態的持續發展,其 PRNG 系統將在效能、穩定性和功能性方面得到進一步提升,成為機器學習和科學計算領域不可或缺的工具。