JAX作為一個高效能的數值計算函式庫,其自動微分和JIT編譯功能對於機器學習模型的訓練和最佳化至關重要。理解動量和速度的關係是物理學和機器學習的基礎,動量是物體品質和速度的乘積,速度是位移對時間的變化率,兩者共同影響物體的運動狀態。自動微分可以高效計算複雜函式的導數,包含前向模式和反向模式。前向模式逐步計算每個中間變數及其導數,而反向模式則從輸出開始反向計算,在機器學習中被廣泛應用於反向傳播演算法。VJP(向量-雅可比乘積)是計算梯度的一種有效方法,可以用於構建雅可比矩陣,並在機器學習模型訓練中發揮重要作用。

瞭解動量和速度的關係

在物理學中,動量和速度是兩個密切相關的概念。動量是指物體的品質和速度的乘積,而速度則是指物體在單位時間內移動的距離。瞭解動量和速度之間的關係對於分析和解決物理問題至關重要。

動量和速度的計算

動量(p)可以使用以下公式計算:

p = mv

其中,m是物體的品質,v是物體的速度。

速度(v)可以使用以下公式計算:

v = Δx / Δt

其中,Δx是物體在單位時間內移動的距離,Δt是時間間隔。

例子:計算動量和速度

假設有一個物體,其品質為2.0 kg,初始速度為7.0 m/s。在經過一段時間後,其速度變為1.137 m/s。要計算物體的動量和速度,我們可以使用以下步驟:

  1. 計算初始動量:

p0 = m × v0 = 2.0 kg × 7.0 m/s = 14.0 kg·m/s

  1. 計算最終動量:

p1 = m × v1 = 2.0 kg × 1.137 m/s = 2.274 kg·m/s

  1. 計算速度變化:

Δv = v1 - v0 = 1.137 m/s - 7.0 m/s = -5.863 m/s

  1. 計算動量變化:

Δp = p1 - p0 = 2.274 kg·m/s - 14.0 kg·m/s = -11.726 kg·m/s

內容解密:

在上述例子中,我們使用了動量和速度的公式來計算物體的動量和速度。動量的計算涉及到物體的品質和速度的乘積,而速度的計算涉及到物體在單位時間內移動的距離。透過這些計算,我們可以得到物體的動量和速度,並進一步分析其運動特性。

圖表翻譯:

下面是一個簡單的Mermaid圖表,展示了動量和速度之間的關係:

  flowchart TD
    A[品質] -->|乘以|> B[速度]
    B -->|等於|> C[動量]
    C -->|變化|> D[速度變化]
    D -->|等於|> E[動量變化]

這個圖表展示了品質、速度、動量之間的關係,以及速度變化和動量變化之間的關係。透過這個圖表,我們可以更清楚地瞭解動量和速度之間的關係。

瞭解自動微分的過程

在電腦科學中,自動微分是一種強大的工具,能夠高效地計算複雜函式的導數。這個過程對於許多應用,包括最佳化、機器學習和科學模擬,都是非常重要的。

自動微分的基本概念

自動微分的核心思想是將原始函式轉換為一系列的中間變數,每個中間變數都有一個對應的導數。這些導數代表了輸出相對於每個中間變數的敏感度。

前向模式自動微分

在前向模式自動微分中,我們從輸入開始,逐步計算每個中間變數及其導數。這個過程可以用以下步驟來描述:

  1. 初始化:首先,我們初始化輸入變數和其導數。
  2. 前向傳播:然後,我們計算每個中間變數及其導數,使用前一個中間變數及其導數。
  3. 輸出:最後,我們得到最終輸出及其導數。

例子

假設我們有一個簡單的函式 y = v3,其中 v3 是一個中間變數。假設 v3 的值為 14.99,並且其導數 v'3 為 1.0。

在前向模式自動微分中,我們可以計算每個中間變數及其導數,如下所示:

  • v'1 = v'3 * dv3/dv1 = 1.0
  • v'2 = v'3 * dv3/dv2 = 1.0

這些導數代表了輸出 y 相對於每個中間變數的敏感度。

圖表翻譯:
  graph LR
    A[輸入] --> B[中間變數 v1]
    B --> C[中間變數 v2]
    C --> D[中間變數 v3]
    D --> E[輸出 y]
    E --> F[導數 v'3]
    F --> G[導數 v'2]
    G --> H[導數 v'1]

這個圖表展示了前向模式自動微分的過程,從輸入開始,逐步計算每個中間變數及其導數,直到得到最終輸出及其導數。

瞭解反向自動微分

在自動微分中,反向模式是一種計算導數的方法,與前向模式相比,它具有不同的計算流程。反向模式從輸出開始,反向計算導數,直到輸入。這種方法在某些情況下比前向模式更有效,尤其是在函式具有多個輸入但只有一個輸出的情況下。

反向模式的工作原理

在反向模式中,我們首先計算輸出的導數,然後反向計算每個變數的導數。這個過程可以被視為一個圖的反向遍歷,從輸出開始,直到輸入。每個變數的導數都是透過將其影響輸出的所有途徑的導數加總而得。

與前向模式的比較

與前向模式相比,反向模式在某些情況下具有計算上的優勢,尤其是在函式具有多個輸入但只有一個輸出的情況下。然而,反向模式也需要更多的儲存空間來儲存中間結果。

實際應用

在機器學習中,反向模式(以及反向傳播)被廣泛用於計算模型引數的導數。由於機器學習模型通常具有多個引數和一個標量輸出,因此反向模式在這種情況下尤其有效。

vjp() 函式

vjp() 函式是一種計算向量-雅可比乘積(vector-Jacobian product, VJP)的方法,它可以用於構建雅可比矩陣的一行一行。這種方法在計算寬雅可比矩陣時尤其有效。

  flowchart TD
    A[開始] --> B[計算輸出的導數]
    B --> C[反向計算每個變數的導數]
    C --> D[將每個變數的導數加總]
    D --> E[得到最終的導數]

圖表翻譯:

上述流程圖展示了反向模式的工作原理。首先,我們計算輸出的導數,然後反向計算每個變數的導數。每個變數的導數都是透過將其影響輸出的所有途徑的導數加總而得。最終,我們得到每個變數的導數。

程式碼實作

import jax.numpy as jnp

def f(x1, x2):
    return x1 * x2 + jnp.sin(x1 * x2)

x = (7.0, 2.0)
jax.grad(f, argnums=(0, 1))(*x)

內容解密:

上述程式碼實作了反向模式的計算過程。首先,我們定義了一個函式 f(x1, x2),然後使用 jax.grad() 函式計算其導數。argnums=(0, 1) 引數指定了我們想要計算哪些變數的導數。在這種情況下,我們計算了 x1x2 的導數。最終,我們得到了一個包含兩個導數值的元組。

使用JAX的VJP計算梯度

在前面的章節中,我們討論瞭如何使用JAX計算梯度。現在,我們來看看如何使用VJP(向量-雅可比乘積)來計算梯度。

什麼是VJP?

VJP是一種計算梯度的方法,它可以計算一個函式對於輸入引數的梯度。它的輸出是一個元組,包含了函式的輸出值和一個函式,用於計算向後傳遞的梯度。

VJP的型別簽名

VJP的型別簽名可以寫成:

vjp :: (a -> b) -> a -> (b, CT b -> CT a)

這意味著VJP函式接受一個函式f和一個輸入值x,並傳回一個元組,包含了函式的輸出值f(x)和一個函式,用於計算向後傳遞的梯度。

使用VJP計算梯度

讓我們看看如何使用VJP計算梯度。假設我們有一個函式f(x1, x2) = x1*x2 + sin(x1*x2),我們想計算它的梯度。

import jax.numpy as jnp

def f(x1, x2):
    return x1*x2 + jnp.sin(x1*x2)

x = (7.0, 2.0)
p, vjp_func = jax.vjp(f, *x)
print(p)  # 輸出:14.990607
print(vjp_func(1.0))  # 輸出:(2.2734745, 7.9571605)

在這個例子中,我們使用VJP計算了函式f對於輸入引數x1x2的梯度。VJP函式傳回了一個元組,包含了函式的輸出值p和一個函式vjp_func,用於計算向後傳遞的梯度。當我們呼叫vjp_func(1.0)時,它傳回了梯度值。

使用VJP還原雅可比矩陣

現在,讓我們看看如何使用VJP還原雅可比矩陣。假設我們有一個函式f2(x) = [x[0]**2 + x[1]**2 - x[1]*x[2], x[0]**2 - x[1]**2 + 3*x[0]*x[2]],我們想還原它的雅可比矩陣。

def f2(x):
    return [
        x[0]**2 + x[1]**2 - x[1]*x[2],
        x[0]**2 - x[1]**2 + 3*x[0]*x[2]
    ]

我們可以使用VJP計算雅可比矩陣的每一行。

x = (1.0, 2.0, 3.0)
p, vjp_func = jax.vjp(f2, x)
print(p)  # 輸出:[10.0, 14.0]
print(vjp_func(1.0))  # 輸出:[2.0, 4.0, 6.0]

在這個例子中,我們使用VJP計算了函式f2對於輸入引數x的梯度。VJP函式傳回了一個元組,包含了函式的輸出值p和一個函式vjp_func,用於計算向後傳遞的梯度。當我們呼叫vjp_func(1.0)時,它傳回了梯度值。

圖表翻譯:

  graph LR
    A[輸入引數] -->|VJP|> B[函式輸出值]
    B -->|向後傳遞|> C[梯度值]
    C -->|還原雅可比矩陣|> D[雅可比矩陣]

在這個圖表中,我們展示瞭如何使用VJP計算梯度和還原雅可比矩陣。

自動微分的進階應用

在上一節中,我們已經瞭解瞭如何使用JAX進行自動微分。現在,我們將更深入地探討JAX的自動微分功能,包括如何定義自訂的微分規則和使用高階功能。

自訂微分規則

JAX提供了兩個函式,jax.custom_jvp()和`jax.custom_vjp()%,用於定義自訂的微分規則。這些函式允許您定義自訂的微分規則,以便JAX可以正確地計算您的函式的導數。

import jax
import jax.numpy as jnp

# 定義一個自訂的函式
def my_func(x):
    return x**2 + 2*x + 1

# 定義自訂的微分規則
@jax.custom_jvp
def my_func_jvp(x):
    def my_func_jvp_rule(x, t):
        return my_func(x) + t * (2*x + 2)
    return my_func_jvp_rule

# 測試自訂的微分規則
x = jnp.array([3.0, 4.0, 5.0])
p, vjp_func = jax.vjp(my_func_jvp, x)
print(p)
print(vjp_func([1.0, 0.0]))
print(vjp_func([0.0, 1.0]))

高階功能

JAX還提供了許多高階功能,包括計算Hessian-vector積、對複數函式進行微分等。這些功能可以透過JAX的Autodiff Cookbook檔案進行了解。

關於JAX的核心繫統

JAX的核心繫統是根據一個稱為Primitive的抽象概念。Primitive是一個可以被JAX轉換的基本單元,例如加法、乘法等。透過瞭解JAX的核心繫統,您可以定義新的Primitive例項,並實作自訂的轉換規則。

內容解密:

在上面的程式碼中,我們定義了一個自訂的函式my_func(),然後使用jax.custom_jvp()定義了自訂的微分規則。透過這個自訂的微分規則,我們可以計算出函式的導數。同時,我們還使用了jax.vjp()函式來計算導數。

圖表翻譯:

  graph LR
    A[原始函式] -->|定義自訂微分規則|> B[自訂微分規則]
    B -->|計算導數|> C[導數]
    C -->|使用jax.vjp()|> D[最終結果]

在這個圖表中,我們展示瞭如何定義自訂的微分規則,然後使用jax.vjp()函式來計算導數。最終結果是得到函式的導數。

編譯您的程式碼

在本章中,我們將探討Just-In-Time(JIT)編譯,該技術可為CPU、GPU或TPU產生高效的程式碼。同時,我們也會深入瞭解JIT的內部運作,包括中間表示和加速線性代數編譯器。此外,我們還會討論JIT的限制。

在第1章中,我們比較了JAX函式在CPU和GPU上的效能,同時也探討了JIT的作用。在第2章中,我們使用JIT編譯兩個函式,以加速簡單神經網路的訓練。因此,您基本上已經瞭解JIT的作用。它可以編譯您的函式以適應目標硬體平臺,並使其執行更快。

從第4章開始,我們學習了JAX轉換(請記住,JAX是關於可組合函式轉換!)。那一章教導我們自動微分和grad()轉換。本章將討論編譯和對應的jit()轉換。在接下來的章節中,我們將學習更多關於自動向量化和平行化的轉換。

在數值計算的領域中,JAX作為一個強大的框架,建立在Google的XLA編譯器的基礎上。XLA編譯器不僅是一個工具,它是專門為高效能運算任務設計的。可以把它想象成一位建築師,精心設計藍圖,以建造針對特定目的而最佳化的結構。

JAX依賴XLA的重要性是深遠的。XLA已經在機器學習框架中取得了成熟的應用,特別是TensorFlow。其多功能性由玄貓所強調:CPU是通用計算的基礎;GPU是專門用於平行計算的單元;而TPU是Google為機器學習工作負載而設計的專用加速器。這種相容性譜系表明JAX的適應性,指出它已經針對多樣化的硬體環境進行了最佳化。

使用編譯

本章致力於編譯,並解釋如何使用它來使您的程式碼執行更快。我們將使用簡單的教學範例來瞭解JAX中的編譯,並強調其功能。這對於大規模和長時間的計算任務(例如大型神經網路訓練、氣候模型或其他大規模計算任務)尤其有用。同時,它也可以為不是那麼龐大的任務帶來價值。例如,加速影像濾波器可能對您的應用程式使用者具有巨大的價值。

讓我們從一個著名的啟用函式開始,稱為縮放指數線性單元(SELU)。我們將使用這個函式來演示JIT。

import jax.numpy as jnp

def selu(x, alpha=1.6732632423543772848170429916717, scale=1.0507009873554804934193349852946):
    """
    縮放指數線性單元啟用函式。
    """
    return scale * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

這個函式包含兩個分支。當輸入x為正時,它傳回x的縮放值。否則,它是一個更複雜的公式,涉及指數和常數項。

編譯SELU函式

現在,我們可以使用jit()轉換來編譯SELU函式。

from jax import jit

@jit
def compiled_selu(x, alpha=1.6732632423543772848170429916717, scale=1.0507009873554804934193349852946):
    return scale * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

透過使用@jit裝飾器,我們告訴JAX編譯SELU函式以獲得更好的效能。這個編譯過程只需執行一次,之後就可以重複使用編譯後的版本。

使用JIT編譯最佳化SELU函式的效能

SELU(Scaled Exponential Linear Units)是一種常用的神經網路啟用函式,其公式為scale * (alpha * e^x - alpha)。在本文中,我們將探討如何使用JIT(Just-In-Time)編譯來最佳化SELU函式的效能。

使用JIT編譯

假設我們需要對一百萬個啟用值應用SELU函式。雖然在單次前向執行中,計算量可能不大,但是在訓練過程中,會有多次前向執行,因此這個例子並不過於離譜。讓我們使用JIT編譯來比較SELU函式的效能,分別在有無JIT編譯的情況下進行測試。

import jax
import jax.numpy as jnp

# 定義SELU函式
def selu(x, alpha=1.6732632423543772848170429916717, scale=1.0507009873554804934193349852946):
    return scale * (alpha * jnp.exp(x) - alpha)

# 建立一個大型隨機陣列
x = jax.random.normal(jax.random.PRNGKey(42), (1_000_000,))

# 使用JIT編譯
selu_jit = jax.jit(selu)

# 測試SELU函式的效能
%timeit -n100 selu(x).block_until_ready()
%timeit -n100 selu_jit(x).block_until_ready()

結果表明,JIT編譯後的SELU函式比未編譯的版本快了五倍以上。

JIT和AOT編譯

JIT編譯是一種在程式執行時進行編譯的技術,當程式需要執行某段程式碼時,JIT編譯器會將其編譯為機器碼。另一種編譯技術是AOT(Ahead-Of-Time)編譯,即在程式執行前就將其編譯為機器碼。

JAX提供了JIT編譯的功能,可以使用jax.jit()函式或@jax.jit裝飾器來對函式進行JIT編譯。

# 使用jax.jit()函式
selu_jit = jax.jit(selu)

# 使用@jax.jit裝飾器
@jax.jit
def selu(x, alpha=1.6732632423543772848170429916717, scale=1.0507009873554804934193349852946):
    return scale * (alpha * jnp.exp(x) - alpha)

JIT 編譯技術在神經網路中的應用

在深度學習中,神經網路的運算效率對於模型的訓練和推理速度有著重要影響。為了提高運算效率,JIT(Just-In-Time)編譯技術被廣泛應用於神經網路的最佳化中。JIT 編譯技術可以將 Python 程式碼轉換為機器碼,從而提高執行速度。

JIT 編譯的原理

JIT 編譯技術的基本原理是將 Python 程式碼在執行時編譯為機器碼。這個過程包括兩個步驟:編譯和執行。編譯步驟將 Python 程式碼轉換為中間程式碼,然後中間程式碼被轉換為機器碼。執行步驟則是將機器碼載入記憶體並執行。

使用 JAX 實作 JIT 編譯

JAX 是一個根據 Python 的神經網路框架,提供了 JIT 編譯功能。使用 JAX,可以輕鬆地將 Python 程式碼轉換為 JIT 編譯版本。以下是使用 JAX 實作 JIT 編譯的示例:

import jax
import jax.numpy as jnp

def selu(x, alpha=1.6732632423543772848170429916717, scale=1.0507009873554804934193349852946):
    '''Scaled exponential linear unit activation function.'''
    return scale * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jax.jit(selu)

在這個示例中,selu 函式被轉換為 JIT 編譯版本 selu_jit

控制後端和裝置

JAX 提供了兩個引數來控制 JIT 編譯的後端和裝置:backenddevicebackend 引數指定了 XLA 後端,可以是 'cpu''gpu''tpu'device 引數指定了裝置,可以是 'cpu''gpu''tpu'

以下是控制後端和裝置的示例:

selu_jit_cpu = jax.jit(selu, backend='cpu')
selu_jit_gpu = jax.jit(selu, backend='gpu')

在這個示例中,selu_jit_cpuselu_jit_gpu 分別是為 CPU 和 GPU 編譯的 JIT 版本。

效能比較

使用 JIT 編譯可以顯著提高神經網路的運算效率。以下是使用 timeit 函式庫比較 selu_jit_cpuselu_jit_gpu 的效能:

%timeit -n100 selu_jit_cpu(x).block_until_ready()
%timeit -n100 selu_jit_gpu(x).block_until_ready()

結果顯示,GPU 版本的運算速度約為 CPU 版本的 10 倍。

編譯程式碼的重要性

在深度學習和機器學習的應用中,編譯程式碼可以大大提高效能。JAX是一個強大的函式庫,可以幫助我們編譯程式碼。在本章中,我們將探討如何使用JAX編譯程式碼和控制後端和張量裝置放置。

測試編譯功能

首先,我們需要測試編譯功能的速度。以下是測試結果:

%timeit -n100 selu_jit_cpu(x).block_until_ready()
>>> 791 μs ± 66.8 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit -n100 selu_jit_gpu(x).block_until_ready()
>>> 1.81 ms ± 178 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

由結果可見,編譯後的CPU版本比GPU版本快。但是,需要注意的是,即使用了@jit標記,JAX仍可能使用GPU或TPU進行計算。

控制後端和張量裝置放置

為了正確地測試編譯功能的速度,我們需要控制後端和張量裝置放置。以下是控制後端和張量裝置放置的例子:

%timeit -n100 selu(x_cpu).block_until_ready()
>>> 2.74 ms ± 95.8 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit -n100 selu(x_gpu).block_until_ready()
>>> 872 μs ± 80.8 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit -n100 selu_jit_cpu(x_cpu).block_until_ready()
>>> 437 μs ± 4.29 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit -n100 selu_jit_gpu(x_gpu).block_until_ready()
>>> 27.1 μs ± 4.94 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

由結果可見,編譯後的CPU版本比非編譯版本快6倍,編譯後的GPU版本比非編譯版本快32倍。

內容解密:

上述程式碼使用了JAX的@jit標記來編譯程式碼。@jit標記可以幫助我們編譯程式碼,提高效能。但是,需要注意的是,即使用了@jit標記,JAX仍可能使用GPU或TPU進行計算。為了正確地測試編譯功能的速度,我們需要控制後端和張量裝置放置。

  flowchart TD
    A[開始] --> B[測試編譯功能]
    B --> C[控制後端和張量裝置放置]
    C --> D[測試編譯後的CPU版本]
    D --> E[測試編譯後的GPU版本]
    E --> F[比較結果]

圖表翻譯:

上述Mermaid圖表展示了測試編譯功能的流程。首先,我們需要測試編譯功能的速度。然後,我們需要控制後端和張量裝置放置。接下來,我們需要測試編譯後的CPU版本和GPU版本。最後,我們需要比較結果。

使用靜態引數解決JIT編譯錯誤

在JAX中,靜態引數是一種機制,可以用於不同的場合。當您想要編譯一個函式,其引數不是陣列,而是某個類別的例項或函式時,靜態引數就很有用。例如,您想要編譯一個神經網路層,並將啟用函式作為引數傳遞。在這種情況下,JIT編譯將失敗,因為函式的引數和傳回值應該是陣列、標量或標準Python容器(如元組、列表或字典)。

靜態引數可以解決這個問題,因為被標記為靜態的引數可以是任何東西,只要它們是可雜湊的並且具有相等操作。您可以使用static_argnumsstatic_argnames引數來標記某些引數為靜態。這兩個引數都是可選的。

static_argnums引數是一個整數或整數集合,指定哪些位置引數應該被視為靜態。static_argnames是一個字串或字串集合,指定哪些名稱引數應該被視為靜態。如果沒有提供static_argnumsstatic_argnames,則沒有引數被視為靜態。

技術上,這意味著在Python的追蹤過程中,只依賴靜態引數的操作將被常數折疊。常數折疊是一種最佳化技術,消除了可以在程式碼執行前確定值的表示式。

以下示例實作了這個概念:

import jax.numpy as jnp
from jax import jit

def dense_layer(x, w, b, activation_func):
    return activation_func(x * w + b)

x = jnp.array([1.0, 2.0, 3.0])
w = jnp.ones((3, 3))
b = jnp.ones(3)

# JIT編譯失敗,因為activation_func不是有效的JAX型別
dense_layer_jit = jit(dense_layer)
try:
    dense_layer_jit(x, w, b, selu)
except TypeError as e:
    print(e)

# 使用靜態引數解決JIT編譯錯誤
dense_layer_jit = jit(dense_layer, static_argnums=3)
result = dense_layer_jit(x, w, b, selu)
print(result)

在這個示例中,dense_layer函式的activation_func引數被標記為靜態,然後JIT編譯就成功了。結果是一個陣列,包含了啟用函式應用的結果。

內容解密:

  • static_argnumsstatic_argnames引數用於標記某些引數為靜態。
  • 靜態引數可以是任何東西,只要它們是可雜湊的並且具有相等操作。
  • 常數折疊是一種最佳化技術,消除了可以在程式碼執行前確定值的表示式。
  • JIT編譯失敗是因為activation_func不是有效的JAX型別。
  • 使用靜態引數可以解決JIT編譯錯誤。

圖表翻譯:

  flowchart TD
    A[開始] --> B[定義dense_layer函式]
    B --> C[建立JIT編譯例項]
    C --> D[嘗試JIT編譯]
    D --> E[捕捉TypeError]
    E --> F[使用靜態引數解決JIT編譯錯誤]
    F --> G[成功JIT編譯]
    G --> H[計算結果]

這個圖表展示了程式碼的執行流程,從定義dense_layer函式到成功JIT編譯和計算結果。

編譯您的程式碼

在本章中,我們將探討如何使用JAX編譯您的程式碼。JAX是一個強大的函式庫,允許您編譯Python程式碼以獲得更好的效能。

靜態引數

當您使用JAX編譯函式時,您可以指定某些引數為靜態引數。靜態引數是指那些不會在編譯過程中改變的引數。透過指定靜態引數,您可以告訴JAX只編譯一次函式,並為不同的靜態引數值建立多個版本的編譯函式。

例如,假設您有一個函式dense_layer(),它接受一個啟用函式作為引數。如果您不指定啟用函式為靜態引數,JAX將會在編譯過程中報錯。因為啟用函式不是一個允許的輸入或輸出型別。

import jax.numpy as jnp
from jax import jit

def dense_layer(activation, x, w, b):
    return activation(jnp.dot(x, w) + b)

# 報錯,因為啟用函式不是一個允許的輸入或輸出型別
dense_layer_jit = jit(dense_layer)

為瞭解決這個問題,您可以指定啟用函式為靜態引數。這樣,JAX就會為不同的啟用函式建立多個版本的編譯函式。

dense_layer_jit = jit(dense_layer, static_argnums=0)

Minkowski距離

另一個例子是計算兩個向量之間的Minkowski距離。Minkowski距離是一種度量兩個向量之間距離的方法。它的計算公式為:

$$d(x, y) = \left(\sum_{i=1}^n |x_i - y_i|^p\right)^{1/p}$$

其中,$x$和$y$是兩個向量,$p$是距離的階數。

以下是計算Minkowski距離的函式:

def dist(order, x, y):
    return jnp.power(jnp.sum(jnp.abs(x-y)**order), 1.0/order)

如果您不指定order引數為靜態引數,JAX將會在編譯過程中報錯。因為order引數不是一個允許的輸入或輸出型別。

dist_jit = jit(dist)

為瞭解決這個問題,您可以指定order引數為靜態引數。這樣,JAX就會為不同的order值建立多個版本的編譯函式。

dist_jit = jit(dist, static_argnums=0)

結果

現在,您可以使用編譯後的函式計算Minkowski距離了。

print(dist_jit(1, jnp.array([0.0, 0.0]), jnp.array([2.0, 2.0])))
print(dist_jit(2, jnp.array([0.0, 0.0]), jnp.array([2.0, 2.0])))
print(dist_jit(1, jnp.array([10.0, 10.0]), jnp.array([2.0, 2.0])))

結果如下:

4.0
2.828427
16.0

如您所見,編譯後的函式可以正確地計算Minkowski距離。

內容解密:

  • jax.jit()函式用於編譯Python程式碼。
  • static_argnums引數用於指定靜態引數。
  • 靜態引數是指那些不會在編譯過程中改變的引數。
  • 編譯後的函式可以為不同的靜態引數值建立多個版本。

圖表翻譯:

以下是Minkowski距離計算過程的Mermaid圖表:

  graph LR
    A[計算Minkowski距離] --> B[指定order引數]
    B --> C[計算距離]
    C --> D[傳回結果]

這個圖表展示了計算Minkowski距離的過程。首先,指定order引數,然後計算距離,最後傳回結果。

編譯與最佳化

編譯是將程式碼轉換成機器碼的過程,讓電腦能夠直接執行。JAX是一個根據Python的編譯框架,能夠將Python程式碼編譯成高效的機器碼。

靜態引數編譯

JAX提供了一種編譯方式,稱為靜態引數編譯(static_argnums)。這種編譯方式可以指定某些引數為靜態引數,不需要在執行時計算。這樣可以簡化編譯過程,提高編譯速度。

例如,以下程式碼使用了靜態引數編譯:

from functools import partial
from jax import jit

@partial(jit, static_argnums=0)
def dist(order, x, y):
    return jnp.power(jnp.sum(jnp.abs(x-y)**order), 1.0/order)

在這個例子中,order引數被指定為靜態引數,編譯器會在編譯時計算這個引數的值,而不是在執行時計算。

最佳化相關引數

JAX提供了一些最佳化相關引數,可以用來控制編譯過程。例如,donate_argnums引數可以指定哪些引數可以被捐贈給計算,用來減少記憶體使用量。keep_unused引數可以控制是否保留未使用的引數。

純函式與編譯過程

JAX設計用來工作於純函式上,即沒有全域狀態和副作用的函式。雖然你仍然可以寫和執行不純的函式,但JAX不保證它們會正確工作。

JAX使用了一種稱為追蹤(tracing)的過程來轉換Python函式成一種稱為Jaxpr的中間語言。然後,轉換會在Jaxpr表現上工作。有了編譯,Jaxpr表現會進一步被編譯成XLA(Accelerated Linear Algebra)程式碼。

內容解密:

在上面的程式碼中,我們使用了@partial(jit, static_argnums=0)裝飾器來指定order引數為靜態引數。這樣可以簡化編譯過程,提高編譯速度。

  flowchart TD
    A[Python Code] --> B[JAX Compilation]
    B --> C[Jaxpr Representation]
    C --> D[XLA Compilation]
    D --> E[Machine Code]

圖表翻譯:

上面的流程圖描述了JAX編譯過程。首先,Python程式碼被轉換成Jaxpr表現。然後,Jaxpr表現被編譯成XLA程式碼。最後,XLA程式碼被轉換成機器碼。

這個過程可以提高程式碼的執行效率,因為XLA程式碼可以直接在GPU或CPU上執行,而不需要經過Python直譯器。同時,JAX也提供了一些最佳化相關引數,可以用來控制編譯過程,進一步提高執行效率。

從技術架構視角來看,JAX 的 JIT 編譯機制提供了一個將 Python 程式碼轉換為高效機器碼的有效途徑,尤其在深度學習模型訓練等計算密集型任務中,效能提升顯著。藉由 XLA 編譯器,JAX 能夠針對 CPU、GPU 和 TPU 等不同硬體後端進行最佳化,展現了其跨平臺的應用彈性。然而,JIT 編譯並非沒有限制,例如處理非純函式和靜態引數的挑戰。對於非純函式,JAX 的行為並非總是可預測的,而靜態引數的使用則需要開發者仔細考量函式的特性和引數的變化模式。對於追求極致效能的開發者,深入理解 Jaxpr 中間表示和 XLA 編譯器的運作原理至關重要,這有助於編寫更最佳化的 JAX 程式碼並充分發揮硬體的計算潛力。玄貓認為,JAX 的 JIT 編譯機制與 Python 的易用性相結合,為科學計算和機器學習領域提供了強大的工具,隨著 JAX 生態系統的持續發展,其應用前景值得期待。