JAX 作為一個高效能數值計算 Python 函式庫,其核心優勢在於 XLA 編譯器和 JIT 機制。XLA 能將多個運算融合成單一核函式,減少記憶體存取和運算時間,大幅提升執行效率。JAX 則利用 XLA 將 Python 程式碼轉換為 Jaxpr 中間表示,再轉換為 MHLO 和 HLO,最終生成高度最佳化的機器碼。OpenXLA 作為開源 ML 編譯器生態系統,提供標準化模型表示和硬體特定最佳化,進一步增強了 JAX 的跨平臺能力。然而,JIT 編譯也存在一些限制,例如對副作用和全域狀態的處理、浮點數運算精確度、控制流和編譯時間等。為瞭解決這些問題,JAX 提供了 jax.lax.scan 等工具,可以有效處理迴圈和遞迴,避免產生過大的中間表示,並縮短編譯時間。

XLA 的優點

XLA 可以將多個運算合併為一個單一的核函式,從而減少記憶體存取和運算時間。例如,以下程式碼:

def f(x, y, z):
    return jnp.sum(x + y * z)

在沒有 XLA 的情況下,可能需要三個不同的運算:乘法、加法和總和。XLA 可以將這些運算合併為一個單一的核函式,從而提高執行效率。

XLA 的工作原理

XLA 的工作原理如下:

  1. 目標無關最佳化:XLA 對輸入的高階運算進行最佳化,包括共同子表示式消除、目標無關操作融合和緩衝區分析。
  2. 目標相關最佳化:XLA 的後端可以對高階運算進行進一步的最佳化,包括決定如何更好地分割計算到 GPU 流中和執行其他融合。
  3. 目標特定程式碼生成:XLA 的後端生成低階中間表示、最佳化和程式碼生成。

OpenXLA

OpenXLA 是一個開源的 ML 編譯器生態系統, 由玄貓和其他業界長官者共同開發。OpenXLA 提供了一個模組化的工具鏈,支援標準化的模型表示,提供了強大的目標無關和硬體特定最佳化。

XLA HLO

XLA HLO(High-Level Operations)是一種特殊的輸入語言,用於描述計算。XLA 將 HLO 轉換為目標硬體架構的特殊機器指令。

編譯過程

編譯過程如下:

  1. 目標無關最佳化:XLA 對輸入的 HLO 進行最佳化。
  2. 目標相關最佳化:XLA 的後端對 HLO 進行進一步的最佳化。
  3. 目標特定程式碼生成:XLA 的後端生成低階中間表示、最佳化和程式碼生成。

JAX 和 XLA

JAX 使用 XLA 生成高效的程式碼 для特定的後端。JAX 將 Python 函式轉換為 Jaxpr,然後轉換為 MHLO(Machine Learning Intermediate Representation),最後轉換為 HLO。XLA 將 HLO 轉換為最佳化的 HLO。

內容解密:

  • XLA 是一種編譯器,專門用於加速線性代數運算。
  • XLA 可以將多個運算合併為一個單一的核函式,從而提高執行效率。
  • OpenXLA 是一個開源的 ML 編譯器生態系統。
  • XLA HLO 是一種特殊的輸入語言,用於描述計算。
  • 編譯過程包括目標無關最佳化、目標相關最佳化和目標特定程式碼生成。

圖表翻譯:

  graph LR
    A[Python 函式] --> B[Jaxpr]
    B --> C[MHLO]
    C --> D[HLO]
    D --> E[最佳化的 HLO]
    E --> F[機器碼]

圖表描述了 JAX 和 XLA 的編譯過程。

編譯技術之旅:從 Python 到 HLO

在深度學習應用中,編譯器技術扮演著至關重要的角色。它們能夠將高階語言(如 Python)轉換為低階別、可執行的程式碼,從而提高執行效率。在本文中,我們將探討兩種重要的編譯器技術:XLA 和 MLIR。

XLA:跨平臺編譯器

XLA(Accelerated Linear Algebra)是一種高效能的線性代數編譯器,能夠將高階語言轉換為低階別、可執行的程式碼。XLA 的一個重要特點是它可以跨平臺執行,支援 CPU、GPU 和 TPU 等多種硬體平臺。XLA 使用 LLVM 作為其後端,從而可以生成高效的機器碼。

近期,JAX 框架已經從 MHLO(Machine Learning HLO)轉換為 StableHLO(Stable High-Level Optimizer)。StableHLO 是根據 MHLO 的,並且增加了序列化和版本控制等功能。這使得 ML 框架和編譯器之間的相容性得到了提高。

MLIR:多級別中間表示

MLIR(Machine Learning Intermediate Representation)是一種多級別中間表示,旨在為異構硬體提供可重用和可擴充套件的編譯器基礎結構。MLIR 支援多種不同需求的統一基礎結構,包括:

  • 資料流圖的表示(如 TensorFlow)
  • 最佳化和轉換(如迴圈最佳化)
  • 程式碼生成和下降轉換(如 DMA 插入、快取管理)
  • 量化和圖轉換

MLIR 的目標是提供一個強大的表示形式,但它也有一些非目標。例如,它不嘗試支援低階別機器碼生成演算法,而是專注於高階別的最佳化和轉換。

編譯 Python 程式碼

在以下範例中,我們將展示如何將 Python 程式碼編譯為 StableHLO 和 HLO。首先,我們定義了一個簡單的 Python 函式,包含乘法、加法和求和操作。然後,我們使用 JAX 框架將這個函式轉換為 StableHLO 程式碼。接下來,我們將這個 StableHLO 程式碼編譯為 HLO,從而可以在目標後端上執行。

import jax.numpy as jnp

def f(x, y, z):
    return jnp.sum(x + y * z)

x = jnp.array([1.0, 1.0, 1.0])
y = jnp.ones((3, 3)) * 2.0
z = jnp.array([2.0, 1.0, 0.0]).T

JIT 編譯器的內部運作

在瞭解 JIT 編譯器的基本概念後,我們來深入探討其內部運作機制。首先,我們需要了解 JIT 編譯器如何將 Python 程式碼轉換為機器碼。

使用 JAX 的 JIT 編譯

JAX 是一個流行的 Python 函式庫,提供了 JIT 編譯功能。以下是使用 JAX 的 JIT 編譯的範例:

import jax

def f(x, y, z):
    return x + y * z

f_jitted = jax.jit(f)

在這個範例中,我們定義了一個 Python 函式 f,然後使用 jax.jit 將其編譯為 JIT 版本 f_jitted

編譯過程

當我們呼叫 f_jitted 函式時,JAX 會啟動編譯過程。以下是編譯過程的步驟:

  1. Lowering:JAX 會將 Python 程式碼轉換為 StableHLO(High-Level Optimizer)程式碼。
  2. Compiling:StableHLO 程式碼會被編譯為 HLO(High-Level Optimizer)程式碼。
  3. Optimization:HLO 程式碼會被最佳化以提高效能。
  4. Code Generation:最佳化後的 HLO 程式碼會被轉換為機器碼。

使用 AOT 編譯

AOT(Ahead-Of-Time)編譯是一種編譯技術,允許我們在執行時間之前編譯程式碼。JAX 提供了 AOT 編譯 API,允許我們控制編譯過程。

以下是使用 JAX 的 AOT 編譯的範例:

import jax

def f(x):
    return x + 1

f_compiled = jax.jit(f).lower(x).compile()

在這個範例中,我們定義了一個 Python 函式 f,然後使用 jax.jit 將其編譯為 JIT 版本。接著,我們使用 lower 方法將 JIT 版本轉換為 StableHLO 程式碼,然後使用 compile 方法將其編譯為機器碼。

內容解密:
  • jax.jit:用於編譯 Python 函式為 JIT 版本。
  • lower:用於將 JIT 版本轉換為 StableHLO 程式碼。
  • compile:用於將 StableHLO 程式碼編譯為機器碼。
  • AOT 編譯:是一種編譯技術,允許我們在執行時間之前編譯程式碼。

圖表翻譯:

以下是 JIT 編譯過程的 Mermaid 圖表:

  graph LR
    A[Python 程式碼] -->|JIT 編譯|> B[StableHLO 程式碼]
    B -->|Compiling|> C[HLO 程式碼]
    C -->|Optimization|> D[最佳化後的 HLO 程式碼]
    D -->|Code Generation|> E[機器碼]

這個圖表展示了 JIT 編譯過程的各個步驟,包括 Lowering、Compiling、Optimization 和 Code Generation。

編譯過程中的HLO代表

在深度學習框架中,編譯過程是一個至關重要的步驟,能夠將模型轉換為高效的機器碼。其中,HLO(High-Level Optimizer)是一種高階最佳化器,負責將模型編譯為特定後端(backend)的機器碼。

HLO編譯流程

當我們呼叫jit()函式對模型進行編譯時,框架會將模型轉換為StableHLO表示法。這種表示法是一種中間表示法,能夠被不同後端所理解。然後,框架會根據特定的後端生成HLO程式碼。

內部表示法

在編譯過程中,模型會經過多個階段的轉換。首先,模型會被轉換為內部表示法,這是一種框架特有的表示法。然後,內部表示法會被轉換為Lower表示法,這是一種更低階的表示法。

Lower表示法

Lower表示法是一種更接近機器碼的表示法。它包含了模型的計算圖和最佳化器的組態。在這個階段,框架會對模型進行最佳化,例如合併運算和刪除無用計算。

Stage out

Stage out是編譯過程中的最後一個階段。在這個階段,框架會將最佳化後的模型輸出為HLO程式碼。HLO程式碼可以被不同後端所執行,例如GPU或CPU。

HLO代表的優點

HLO代表有一些優點,包括:

  • 高效: HLO代表可以被最佳化為高效的機器碼,能夠提高模型的執行速度。
  • 跨平臺: HLO代表可以被不同後端所理解,能夠在不同的硬體平臺上執行。
  • 靈活: HLO代表可以被用於不同的深度學習框架和應用場景。

內容解密:

上述內容介紹了HLO代表的基本概念和編譯流程。下面是一個簡單的例子,示範如何使用HLO代表編譯一個模型:

import tensorflow as tf

# 定義一個簡單的模型
model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(32, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

# 編譯模型為HLO程式碼
hlo_code = tf.compiler.jit(model)

# 執行HLO程式碼
output = hlo_code(input_data)

在這個例子中,我們定義了一個簡單的神經網路模型,然後使用tf.compiler.jit()函式編譯模型為HLO程式碼。最後,我們執行HLO程式碼並取得輸出結果。

圖表翻譯:

下面是一個簡單的Mermaid圖表,示範了HLO代表的編譯流程:

  graph LR
    A[模型定義] -->|編譯|> B[內部表示法]
    B -->|最佳化|> C[Lower表示法]
    C -->|Stage out|> D[HLO程式碼]
    D -->|執行|> E[輸出結果]

這個圖表示範了HLO代表的編譯流程,從模型定義到輸出結果。

編譯過程與即時編譯(JIT)和提前編譯(AOT)的比較

在深度學習框架中,編譯過程是一個至關重要的步驟,可以將模型轉換為高效的機器碼,以便在各種硬體平臺上執行。即時編譯(JIT)和提前編譯(AOT)是兩種不同的編譯策略,各有其優缺點。

即時編譯(JIT)

即時編譯是一種動態編譯技術,指的是在程式執行期間,將程式碼即時編譯為機器碼。這種方法可以提供更好的效能和更低的延遲,因為編譯過程是在執行時進行的。然而,JIT編譯也有一些缺點,例如需要額外的記憶體和計算資源,以及可能導致程式碼膨脹。

提前編譯(AOT)

提前編譯是一種靜態編譯技術,指的是在程式編譯之前,將程式碼提前編譯為機器碼。這種方法可以提供更好的效能和更低的延遲,因為編譯過程是在編譯時進行的。然而,AOT編譯也有一些缺點,例如需要額外的編譯時間和可能導致程式碼膨脹。

編譯過程

編譯過程包括以下幾個階段:

  1. 原始碼: 程式碼被寫入並儲存為原始碼檔案。
  2. 分析: 編譯器分析原始碼,並將其轉換為中間表示法(IR)。
  3. 最佳化: 編譯器對IR進行最佳化,以提高程式碼的效能和效率。
  4. 程式碼生成: 編譯器根據IR生成機器碼。
  5. 連結: 編譯器將生成的機器碼連結起來,形成可執行檔案。

JIT和AOT的比較

JIT AOT
編譯時間 執行時 編譯時
效能 更好 更好
延遲 更低 更低
記憶體使用 更高 更低
計算資源 更高 更低
程式碼膨脹 可能 可能
內容解密:
import jax
import jax.numpy as jnp

def selu(x, alpha=1.6732632423543772848170429916717, scale=1.0507009873554804934193349852946):
    print('Function run')
    return scale * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jax.jit(selu)
selu_aot = jax.jit(selu).lower(1.0).compile()

print(selu_jit(17.8))
print(selu_aot(17.8))

圖表翻譯:

  graph LR
    A[原始碼] --> B[分析]
    B --> C[最佳化]
    C --> D[程式碼生成]
    D --> E[連結]
    E --> F[可執行檔案]

在這個例子中,我們使用JAX框架實作了一個簡單的SELU啟用函式,並使用JIT和AOT編譯策略進行比較。結果表明,JIT和AOT編譯策略都可以提供更好的效能和更低的延遲,但JIT編譯需要額外的記憶體和計算資源,而AOT編譯需要額外的編譯時間和可能導致程式碼膨脹。

編譯程式碼的差異

在編譯程式碼的過程中,存在著兩種不同的編譯方式:即時編譯(JIT)和提前編譯(AOT)。這兩種編譯方式在行為和應用上有著明顯的差異。

即時編譯(JIT)

即時編譯是一種在程式碼執行時進行編譯的過程。這意味著當程式碼被呼叫時,JIT編譯器會將其編譯成機器碼。這種編譯方式可以提供更好的效能,因為它可以根據具體的輸入資料進行最佳化。

提前編譯(AOT)

提前編譯則是在程式碼被編譯成可執行檔之前進行編譯。這意味著程式碼在執行前就已經被編譯成機器碼。AOT編譯可以提供更快的啟動時間和更低的記憶體使用量,但它可能無法提供與JIT編譯相同級別的最佳化。

行為差異

在上述程式碼中,我們可以看到JIT和AOT編譯之間的行為差異。當我們呼叫selu_jit函式時,它會正常執行並傳回結果。然而,當我們呼叫selu_aot函式時,它會丟擲一個型別錯誤,因為AOT編譯需要明確的型別資訊。

import jax
import jax.numpy as jnp

# 定義selu_jit函式
def selu_jit(x):
    return jax.numpy.where(x > 0, x, jax.numpy.exp(x) - 1)

# 定義selu_aot函式
@jax.jit
def selu_aot(x):
    return jax.numpy.where(x > 0, x, jax.numpy.exp(x) - 1)

# 呼叫selu_jit函式
print(selu_jit(17))

# 呼叫selu_aot函式
print(selu_aot(17))

批次處理

我們也可以使用jax.vmap函式來對這些函式進行批次處理。然而,當我們對AOT編譯的函式進行批次處理時,它會丟擲一個錯誤,因為AOT編譯不支援批次處理。

# 定義selu_jit_batched函式
selu_jit_batched = jax.vmap(selu_jit)

# 定義selu_aot_batched函式
selu_aot_batched = jax.vmap(selu_aot)

# 呼叫selu_jit_batched函式
print(selu_jit_batched(jnp.array([42.0, 78.0, -12.3])))

# 呼叫selu_aot_batched函式
print(selu_aot_batched(jnp.array([42.0, 78.0, -12.3])))

圖表翻譯:

  graph LR
    A[JIT編譯] -->|執行時編譯|> B[機器碼]
    C[AOT編譯] -->|提前編譯|> D[機器碼]
    E[批次處理] -->|JIT編譯|> F[批次機器碼]
    E -->|AOT編譯|> G[錯誤]

內容解密:

在上述程式碼中,我們可以看到JIT和AOT編譯之間的差異。JIT編譯是在執行時進行編譯,而AOT編譯則是在提前進行編譯。這兩種編譯方式都有其優缺點,JIT編譯可以提供更好的效能,但AOT編譯可以提供更快的啟動時間和更低的記憶體使用量。批次處理也可以使用jax.vmap函式來實作,但AOT編譯不支援批次處理。

JAX 和 jaxlib

JAX 是一個 Python 套件,分為兩個獨立的套件:jaxjaxlibjax 是一個純 Python 套件,而 jaxlib 是一個主要由 C++ 實作的套件,包含了 XLA、LLVM、MLIR 基礎設施等。

JAX 分佈結構

JAX 的分佈結構是這樣設計的,因為大多數對 JAX 的修改只涉及 Python 程式碼,因此可以獨立於 C++ 程式碼進行開發和更新。這樣可以提高開發速度。

JIT 編譯

JAX 支援即時編譯(JIT)和提前編譯(AOT)。JIT 編譯可以對函式進行最佳化和轉換,而 AOT 編譯則需要事先編譯好函式,不能在執行時進行修改。

JIT 限制

JIT 編譯有一些限制,例如:

  • 只能夠正確地處理純函式,如果函式含有副作用或全域狀態,JIT 編譯可能會改變函式的行為。
  • 可能會改變函式輸出的精確度,例如由於浮點數運算的最佳化和重排。
  • 對於控制流和條件陳述式,有一些限制,例如迴圈不能被追蹤,但可以使用 jax.lax API 來替代。
  • 編譯時間可能很慢,尤其是當程式碼生成的大型 internal representation 時。

解決 JIT 限制

可以透過以下方法來解決 JIT 限制:

  • 移除迴圈或使用向量化計算。
  • 使用 jax.lax API 來替代迴圈和控制流。
  • 改變演算法以避免 JIT 限制。
  • 避免將迴圈包裹在 jit 函式中,但可以在迴圈內使用 jit 函式。

示例

以下是一個計算累積和的示例:

import jax
import jax.numpy as jnp

def cumulative_sum(x):
    acc = 0.0
    y = []
    for i in range(x.shape[0]):
        acc += x[i]
        y.append(acc)
    return jnp.array(y)

# 使用 jit 函式進行最佳化
cumulative_sum_jit = jax.jit(cumulative_sum)

# 測試
x = jnp.array([1, 2, 3, 4, 5])
print(cumulative_sum_jit(x))

這個示例展示瞭如何使用 jax.jit 函式對累積和函式進行最佳化。

瞭解JAX的JIT編譯限制

在使用JAX(Java Advanced eXtensions)進行編譯時,瞭解其JIT(Just-In-Time)編譯限制是非常重要的。JAX是一個強大的工具,允許使用者使用Python編寫高效的數值計算程式。但是,當遇到複雜的迴圈或遞迴時,JAX的JIT編譯可能會遇到瓶頸。

什麼是JAXPR?

JAXPR(JAX Program Representation)是一種中間表示形式,用於表示JAX程式的計算圖。當您使用JAX編寫程式時,JAX會將您的程式轉換為JAXPR,以便進行最佳化和編譯。JAXPR包含了一系列的方程式,每個方程式代表了一個計算步驟。

JIT編譯限制

當您使用JAX的jax.jit函式對程式進行JIT編譯時,JAX會將您的程式轉換為機器碼。但是,如果您的程式包含複雜的迴圈或遞迴,JAX可能會產生大量的中間方程式,這會導致編譯時間過長。

例如,以下程式使用了一個簡單的迴圈來計算累積和:

def cumulative_sum(j):
    y = 0
    for i in range(len(j)):
        y += j[i]
    return y

當我們使用jax.jit對這個程式進行JIT編譯時,JAX會產生一個包含30,000個方程式的JAXPR。這個過程需要超過2分鐘的時間。

解決方案:使用lax.scan

為瞭解決這個問題,我們可以使用lax.scan原始函式來替代迴圈。lax.scan允許您定義一個函式,該函式對陣列中的每個元素進行計算,並累積狀態。

def cumulative_sum(j):
    def func(carry, x):
        carry += x
        return carry, carry

    init = 0
    carry, ys = lax.scan(func, init, j)
    return ys

這個版本的程式使用lax.scan來計算累積和,避免了迴圈的使用。結果,JAXPR包含的方程式數量大大減少,編譯時間也大大縮短。

5.3.4 使用 lax.scan 來實作累積和

在前面的例子中,我們使用了 jax.lax.scan 來實作累積和的功能。這個方法可以讓我們更有效地處理序列資料,並且可以減少中間表示的大小。

def cumulative_sum_fast(x):
    result, array = jax.lax.scan(
        lambda carry, elem: (carry+elem, carry+elem), 1.0, x)
    return array

在這個例子中,我們使用 jax.lax.scan 來掃描序列 x,並且使用 lambda 函式來計算累積和。這個 lambda 函式接收兩個引數:carryelem,其中 carry 是累積和的結果,elem 是序列中的每個元素。

使用 jax.lax.scan 來實作累積和的優點是,可以減少中間表示的大小,從而加快編譯時間。事實上,使用 jax.lax.scan 來實作累積和的編譯時間比使用原始的 for 迴圈要快得多。

j = jax.make_jaxpr(cumulative_sum_fast)(jnp.ones(10000))
len(j.jaxpr.eqns)
>>> 1
%time cs = jax.jit(cumulative_sum_fast)(jnp.ones(10000))
>>> CPU times: user 145 ms, sys: 6.02 ms, total: 151 ms
>>> Wall time: 213 ms

5.3.5 使用 @jit 對類別方法進行編譯

在 JAX 中,使用 @jit 對類別方法進行編譯需要一些額外的注意。簡單地在類別方法上新增 @jit 註解可能會導致錯誤。

class ScaleClass:
    def __init__(self, scale: jnp.array):
        self.scale = scale
    @jax.jit
    def apply(self, x: jnp.array):
        return self.scale * x

scale_double = ScaleClass(2)
scale_double.apply(10)
>>> TypeError: Cannot interpret value of type <class '__main__.ScaleClass'>
as an abstract array; it does not have a dtype attribute

這是因為 JAX 的 @jit 編譯器需要知道類別方法的輸入和輸出型別,但是類別方法的輸入和輸出型別可能不明確。為瞭解決這個問題,我們需要使用 jax.jit 來編譯類別方法,並且需要指定輸入和輸出型別。

內容解密:

在上面的例子中,我們定義了一個 ScaleClass 類別,該類別有一個 apply 方法。這個方法接收一個 x 引數,並且傳回 self.scale * x。我們想要使用 @jit 對這個方法進行編譯,但是簡單地在方法上新增 @jit 註解可能會導致錯誤。

為瞭解決這個問題,我們需要使用 jax.jit 來編譯類別方法,並且需要指定輸入和輸出型別。具體地說,我們需要使用 jax.jit 來編譯 apply 方法,並且需要指定輸入型別為 jnp.array,輸出型別為 jnp.array

圖表翻譯:

以下是使用 Mermaid 圖表來描述類別方法編譯的過程:

  flowchart TD
    A[定義類別] --> B[定義方法]
    B --> C[使用 @jit 編譯]
    C --> D[指定輸入和輸出型別]
    D --> E[編譯完成]

在這個圖表中,我們首先定義一個類別,然後定義一個方法。接下來,我們使用 @jit 對方法進行編譯,並且需要指定輸入和輸出型別。最後,編譯完成。

第五章:編譯您的程式碼

在這一章中,我們將探討如何使用編譯來加速您的程式碼,並瞭解相應的 jit() 轉換。

5.3.6 簡單函式

在某些情況下,函式已經很小,使用 JIT 編譯可能不會帶來顯著的效能提升。此外,編譯還需要額外的時間,同時還有使用 MLIR/XLA 和將資料複製到硬體加速器的額外開銷。在這種情況下,測量 JIT 編譯的效果並嘗試編譯最大的計算塊是很重要的,這樣可以給編譯器更多的最佳化空間。

使用輔助函式

當我們遇到類別方法無法直接使用 JIT 編譯時,可以使用輔助函式來解決這個問題。以下是使用輔助函式的範例:

from functools import partial
import jax
import jax.numpy as jnp

class ScaleClass:
    def __init__(self, scale: jnp.array):
        self.scale = scale

    def apply(self, x: jnp.array):
        return _apply_helper(self.scale, x)

@partial(jax.jit, static_argnums=0)
def _apply_helper(scale, x):
    return scale * x

scale_double = ScaleClass(2)
result = scale_double.apply(10)
print(result)  # 輸出:20

在這個範例中,我們建立了一個 ScaleClass 類別,該類別有一個 apply 方法。由於 apply 方法無法直接使用 JIT 編譯,因此我們建立了一個輔助函式 _apply_helper 並使用 @partial(jax.jit, static_argnums=0) 進行編譯。這樣可以確保 _apply_helper 函式被編譯,而 apply 方法則呼叫編譯後的 _apply_helper 函式。

圖表翻譯:
  flowchart TD
    A[開始] --> B[定義 ScaleClass 類別]
    B --> C[定義 apply 方法]
    C --> D[定義 _apply_helper 輔助函式]
    D --> E[使用 @partial(jax.jit, static_argnums=0) 進行編譯]
    E --> F[建立 ScaleClass 例項]
    F --> G[呼叫 apply 方法]
    G --> H[取得結果]

內容解密:

上述程式碼示範瞭如何使用輔助函式來解決類別方法無法直接使用 JIT 編譯的問題。首先,我們定義了一個 ScaleClass 類別,該類別有一個 apply 方法。由於 apply 方法無法直接使用 JIT 編譯,因此我們建立了一個輔助函式 _apply_helper 並使用 @partial(jax.jit, static_argnums=0) 進行編譯。這樣可以確保 _apply_helper 函式被編譯,而 apply 方法則呼叫編譯後的 _apply_helper 函式。最終,我們建立了一個 ScaleClass 例項,呼叫 apply 方法,並取得結果。

自動向量化:加速計算的強大工具

在前幾章中,我們探討瞭如何使用Just-In-Time(JIT)編譯和編譯後的最佳化來加速計算。在本章中,我們將介紹另一種加速計算的方法:自動向量化。自動向量化是一種簡單而有效的方法,可以簡化程式設計過程並加速計算。

自動向量化的優點

自動向量化提供了多種優點。首先,它簡化了程式設計過程,允許您一次處理多個元素或陣列。其次,它可以加速計算,如果您的硬體資源和程式邏輯允許您同時執行多個專案的計算。這通常比逐一處理陣列中的每個專案要快得多。

向量化函式的不同方法

在高效能運算和深度學習中,批次處理是常見的做法。mini-batch梯度下降就是根據這個想法。神經網路中的矩陣乘法都是為了能夠同時處理多個元素而設計的。否則,處理將會非常低效。因此,使用向量化是非常重要的。

從效能最佳化視角來看,XLA 作為一個線性代數加速編譯器,其核心價值在於將多個運算融合成單一核函式,減少記憶體存取和運算時間,從而顯著提升計算效率。透過目標無關和目標相關的最佳化,XLA 能夠針對不同硬體後端生成高度最佳化的機器碼。然而,XLA 的 JIT 編譯也存在一些限制,例如對副作用和全域狀態的處理、浮點數運算精確度以及控制流的處理。對於包含複雜迴圈或遞迴的程式,直接使用 JIT 編譯可能導致編譯時間過長和 JAXPR 過於龐大。為瞭解決這個問題,可以採用 lax.scan 原始函式替代迴圈,或使用輔助函式來間接編譯類別方法。此外,自動向量化也是一種有效提升計算效率的策略,它允許同時處理多個元素或陣列,尤其適用於批次處理和矩陣運算等場景。玄貓認為,XLA 及其相關技術,例如 OpenXLA 和 StableHLO,代表了機器學習編譯器的重要發展方向,未來將在更廣泛的硬體平臺和深度學習框架中發揮關鍵作用。對於追求極致效能的開發者而言,深入理解 XLA 的工作原理和最佳實務至關重要。