JAX作為一個高效能數值計算函式庫,在機器學習領域中,其自動微分功能是重要的根本。自動微分可以自動計算函式的梯度,省去手動推導的繁瑣過程,並能有效提升模型訓練效率。本文將會介紹JAX的兩種自動微分模式:前向模式jacfwd()和反向模式jacrev(),並說明如何使用它們計算Jacobian矩陣和Hessian矩陣。同時,也將探討如何應用這些技術於機器學習模型的梯度計算、最佳化,以及物理運動模擬等實際案例。理解這些概念有助於更有效地運用JAX進行深度學習模型的開發與最佳化。
多變數函式的梯度計算
在深度學習中,梯度計算是一個非常重要的步驟。當我們面對多變數函式時,梯度計算就變得更加複雜。在這個章節中,我們將探討如何使用 Jacobian 矩陣和 Hessian 矩陣來計算多變數函式的梯度。
Jacobian 矩陣
Jacobian 矩陣是一個包含所有偏導數的矩陣,用於描述多變數函式的梯度。它的元素是函式輸出的偏導數,以對輸入變數為單位。在深度學習中,Jacobian 矩陣可以用來計算神經網路模型的損失函式對模型引數的偏導數。
有兩種方法可以用來計算 Jacobian 矩陣:jacfwd() 和 jacrev()。這兩種方法都可以計算出相同的結果,但它們的實作方式不同,分別根據前向模式和反向模式的自動微分。
import jax
import jax.numpy as jnp
def f(x):
    return [
        x[0]**2 + x[1]**2 - x[1]*x[2],
        x[0]**2 - x[1]**2 + 3*x[0]*x[2]
    ]
print(jax.jacrev(f)(jnp.array([3.0, 4.0, 5.0])))
print(jax.jacfwd(f)(jnp.array([3.0, 4.0, 5.0])))
Hessian 矩陣
Hessian 矩陣是一個包含所有二階導數的方矩陣,用於描述多變數函式的二階梯度。它的元素是函式輸出的二階導數,以對輸入變數為單位。
import jax
import jax.numpy as jnp
def f(x):
    return x[0]**2 + x[1]**2 - x[1]*x[2]
hessian = jax.hessian(f)
print(hessian(jnp.array([3.0, 4.0, 5.0])))
圖表翻譯:
  graph LR
    A[多變數函式] -->|Jacobian 矩陣|> B[梯度計算]
    B -->|Hessian 矩陣|> C[二階梯度計算]
    C -->|自動微分|> D[神經網路模型]
內容解密:
在這個章節中,我們探討瞭如何使用 Jacobian 矩陣和 Hessian 矩陣來計算多變數函式的梯度。Jacobian 矩陣是一個包含所有偏導數的矩陣,用於描述多變數函式的梯度。Hessian 矩陣是一個包含所有二階導數的方矩陣,用於描述多變數函式的二階梯度。我們還介紹瞭如何使用 jacfwd() 和 jacrev() 函式來計算 Jacobian 矩陣,以及如何使用 hessian() 函式來計算 Hessian 矩陣。最後,我們給出了一個示例程式碼,展示瞭如何使用這些函式來計算多變數函式的梯度和二階梯度。
4.3 前向和反向自動微分
自動微分(autodiff)是機器學習中的一個重要工具,它可以自動計算函式的導數。要了解自動微分的工作原理,我們需要先了解前向和反向自動微分。
4.3.1 前向自動微分
前向自動微分是一種計算導數的方法,它透過計算函式的每個中間變數的導數來得到最終導數。這種方法需要計算函式的每個中間步驟的導數,因此計算量較大。
在 JAX 中,前向自動微分可以使用 jax.jacfwd 函式實作。例如,給定一個函式 f(x) = x[0]**2 - x[1]**2 + 3*x[0]*x[2],我們可以使用 jax.jacfwd 函式計算其導數:
import jax
import jax.numpy as jnp
def f(x):
    return x[0]**2 - x[1]**2 + 3*x[0]*x[2]
x = jnp.array([3.0, 4.0, 5.0])
grad_f = jax.jacfwd(f)(x)
print(grad_f)
4.3.2 反向自動微分
反向自動微分是一種計算導數的方法,它透過計算函式的最終輸出的導數來得到中間變數的導數。這種方法需要計算函式的最終輸出的導數,因此計算量較小。
在 JAX 中,反向自動微分可以使用 jax.jacrev 函式實作。例如,給定一個函式 f(x) = x[0]**2 - x[1]**2 + 3*x[0]*x[2],我們可以使用 jax.jacrev 函式計算其導數:
import jax
import jax.numpy as jnp
def f(x):
    return x[0]**2 - x[1]**2 + 3*x[0]*x[2]
x = jnp.array([3.0, 4.0, 5.0])
grad_f = jax.jacrev(f)(x)
print(grad_f)
4.3.3 Hessian 矩陣
Hessian 矩陣是一個描述函式二階導數的矩陣。它可以用於計算函式的二階導數。在 JAX 中,Hessian 矩陣可以使用 jax.hessian 函式實作。例如,給定一個函式 f(x) = x[0]**2 - x[1]**2 + 3*x[0]*x[2],我們可以使用 jax.hessian 函式計算其 Hessian 矩陣:
import jax
import jax.numpy as jnp
def f(x):
    return x[0]**2 - x[1]**2 + 3*x[0]*x[2]
x = jnp.array([3.0, 4.0, 5.0])
hessian_f = jax.hessian(f)(x)
print(hessian_f)
內容解密:
- jax.jacfwd和- jax.jacrev函式可以用於計算函式的導數。
- jax.hessian函式可以用於計算函式的 Hessian 矩陣。
- 前向自動微分和反向自動微分是兩種不同的計算導數的方法。
- Hessian 矩陣可以用於計算函式的二階導數。
圖表翻譯:
  graph LR
    A[函式 f(x)] --> B[前向自動微分]
    B --> C[導數 grad_f]
    A --> D[反向自動微分]
    D --> E[導數 grad_f]
    A --> F[Hessian 矩陣]
    F --> G[Hessian 矩陣 hessian_f]
- 圖表描述了函式 f(x)的前向自動微分、反向自動微分和 Hessian 矩陣的計算過程。
- 前向自動微分和反向自動微分都可以用於計算函式的導數。
- Hessian 矩陣可以用於計算函式的二階導數。
4.3.1 評估追蹤
計算可以表示為基本運算的評估追蹤,也稱為 Wengert 列表(現在也稱為磁帶;來自 R. E. Wengert 1964 年的文章)。計算(或函式)被分解為一系列基本的功能步驟,引入中間變數。Wengert 列表抽象了所有控制流考慮,這意味著自動微分是對任何操作都盲目,包括控制流陳述式,這些陳述式不直接改變數值。採取的分支取代所有條件陳述式,所有迴圈都展開,並且所有函式呼叫都內聯。
讓我們考慮一個具有兩個變數的具體實值函式 f(x1, x2) = x1 × x2 + sin(x1 × x2),計算其導數和評估追蹤(表 4.1)。我們將使用此示例來說明前向和反向自動微分模式。
表 4.1. 函式 f (x1, x2) = x1 × x2 + sin(x1 × x2) 的評估追蹤(Wengert 列表)
評估追蹤
輸入變數 v-1 = x1 v0 = x2
中間變數 v1 = v-1 × v0 v2 = sin(v1) v3 = v1 + v2
內容解密:
在這個例子中,我們定義了一個函式 f(x1, x2) = x1 × x2 + sin(x1 × x2),並計算其評估追蹤。評估追蹤是一個基本運算的序列,描述瞭如何計算函式的輸出。這裡,我們引入了中間變數 v1、v2 和 v3,分別代表 x1 和 x2 的乘積、sin(x1 × x2) 的結果和最終結果。
  flowchart TD
    A[開始] --> B[計算 v1]
    B --> C[計算 v2]
    C --> D[計算 v3]
    D --> E[輸出結果]
圖表翻譯:
此圖表描述了函式 f(x1, x2) 的評估追蹤過程。首先,我們計算 v1 = x1 × x2,然後計算 v2 = sin(v1),最後計算 v3 = v1 + v2。這個過程展示瞭如何使用基本運算計算函式的輸出。
在下一節中,我們將使用此評估追蹤來說明前向和反向自動微分模式。
計算梯度
計算梯度是一個重要的步驟,尤其是在機器學習和最佳化演算法中。這個過程可以使用計算圖(computational graph)來表示,如圖 4.8 所示。
前向模式自動微分
前向模式自動微分(forward-mode autodiff)是一種概念上最簡單的方法。要計算函式相對於第一個變數 x1 的偏導數,我們需要從輸入變數開始,逐步計算每個中間變數的導數。
計算步驟
- 初始化變數:設定輸入變數 x1 和 x2 的值,例如 (x1, x2) = (7.0, 2.0)。
- 計算中間變數:根據函式 f(x1, x2) = x1 × x2 + sin(x1 × x2),計算中間變數 v1、v2 和 v3 的值。
- 計算導數:計算每個中間變數的導數 dv1/dx1、dv2/dx1 和 dv3/dx1。
內容解密
import numpy as np
# 初始化變數
x1 = 7.0
x2 = 2.0
# 計算中間變數
v1 = x1 * x2
v2 = np.sin(v1)
v3 = v1 + v2
# 計算導數
dv1_dx1 = x2
dv2_dx1 = np.cos(v1) * dv1_dx1
dv3_dx1 = dv1_dx1 + dv2_dx1
print("dv3_dx1 =", dv3_dx1)
圖表翻譯
  flowchart TD
    A[初始化變數] --> B[計算中間變數]
    B --> C[計算導數]
    C --> D[輸出結果]
在這個例子中,我們使用前向模式自動微分計算了函式相對於第一個變數 x1 的偏導數。這個過程可以使用計算圖來表示,方便理解和計算。
自動微分的前向模式
在前向模式自動微分中,我們不再只計算每個中間變數的單一值,而是計算一個元組 (v_i, v'_i),其中 v_i 是原始的中間值,而 v'_i 則是其導數。這種方法被稱為雙數法(Dual Numbers Approach)。我們對每個基本運算應用鏈式法則,以計算前向切線跡(表 4.2)。
表 4.2:前向模式自動微分示例
函式 f(x_1, x_2) = x_1 * x_2 + sin(x_1 * x_2) 在點 (7.0, 2.0) 對於第一個變數 x_1 的導數計算。
| 前向原始跡 | 前向切線跡(導數) | 
|---|---|
| v_{-1} = x_1 = 7.0(起始) | v'_{-1} = x'_1 = 1.0(起始) | 
| v_0 = x_2 = 2.0 | v'_0 = x'_2 = 0.0 | 
內容解密:
在這個過程中,我們首先初始化原始變數 x_1 和 x_2 的值,然後計算它們的導數。對於每個中間步驟,我們不僅計算原始值,也計算其導數。這樣,我們就可以使用鏈式法則來計算最終函式對於輸入變數的導數。
圖表翻譯:
  graph LR
    A[初始化 x_1 和 x_2] --> B[計算中間值]
    B --> C[計算導數]
    C --> D[應用鏈式法則]
    D --> E[計算最終導數]
這個圖表展示了前向模式自動微分的過程,從初始化變數開始,到計算中間值、導數,最後應用鏈式法則來得到最終的導數。這個過程使我們能夠高效地計算複雜函式的導數。
玄貓技術內容:向量運算與微分
向量運算是線性代數中的基本概念,涉及向量之間的加、減、乘和除等運算。在物理和工程學中,向量常用於描述物體的位置、速度和加速度等物理量。向量微分是指對向量函式進行微分,得到導數向量。
向量加法
向量加法是指兩個或多個向量的加法運算。假設有兩個向量 ( \mathbf{a} = (a_1, a_2, \ldots, a_n) ) 和 ( \mathbf{b} = (b_1, b_2, \ldots, b_n) ),則其加法結果為:
[ \mathbf{a} + \mathbf{b} = (a_1 + b_1, a_2 + b_2, \ldots, a_n + b_n) ]
向量乘法
向量乘法可以分為數量積(dot product)和叉積(cross product)。數量積的結果是一個純量,而叉積的結果是一個向量。
數量積
假設有兩個向量 ( \mathbf{a} = (a_1, a_2, \ldots, a_n) ) 和 ( \mathbf{b} = (b_1, b_2, \ldots, b_n) ),則其數量積為:
[ \mathbf{a} \cdot \mathbf{b} = a_1b_1 + a_2b_2 + \ldots + a_nb_n ]
叉積
假設有兩個向量 ( \mathbf{a} = (a_1, a_2, a_3) ) 和 ( \mathbf{b} = (b_1, b_2, b_3) ),則其叉積為:
[ \mathbf{a} \times \mathbf{b} = (a_2b_3 - a_3b_2, a_3b_1 - a_1b_3, a_1b_2 - a_2b_1) ]
向量微分
假設有一個向量函式 ( \mathbf{f}(t) = (f_1(t), f_2(t), \ldots, f_n(t)) ),則其導數為:
[ \mathbf{f}’(t) = (f_1’(t), f_2’(t), \ldots, f_n’(t)) ]
內容解密:
上述公式展示瞭如何對向量函式進行微分。每個分量的導數都是獨立計算的,然後組合成導數向量。
import numpy as np
# 定義向量函式
def f(t):
    return np.array([t**2, 2*t, t**3])
# 定義導數函式
def f_prime(t):
    return np.array([2*t, 2, 3*t**2])
# 測試導數函式
t = 1.0
print(f_prime(t))
圖表翻譯:
  flowchart TD
    A[定義向量函式] --> B[計算導數]
    B --> C[傳回導數向量]
    C --> D[輸出結果]
圖表翻譯:
上述流程圖展示瞭如何計算向量函式的導數。首先定義向量函式,然後計算每個分量的導數,最後傳回導數向量。
自動微分的前向模式和反向模式
自動微分(Autodiff)是一種電腦器學習模型中導數的方法,對於最佳化模型引數至關重要。自動微分有兩種主要模式:前向模式(Forward Mode)和反向模式(Reverse Mode)。
前向模式
在前向模式中,我們計算函式的導數與輸入變數之間的關係。這種模式適用於計算多個輸出變數對單個輸入變數的導數。然而,如果函式有多個輸入變數,我們需要對每個輸入變數進行單獨的前向傳遞,以計算其導數。
例如,給定一個函式 f(x1, x2) = x1 * x2 + sin(x1 * x2),我們可以使用前向模式計算其導數。首先,我們計算函式的值和導數:
import jax.numpy as jnp
def f(x1, x2):
    return x1 * x2 + jnp.sin(x1 * x2)
x = (7.0, 2.0)
result = f(*x)
然後,我們可以使用 JAX 的 grad 函式計算導數:
from jax import grad
jax_grad = grad(f, argnums=0)(*x)
print(jax_grad)
這將輸出導數的值,與我們的手動計算結果相近。
Jacobian 矩陣
對於一個具有多個輸出和輸入的函式 f: R^n -> R^m,前向模式可以在單次前向傳遞中計算出一個輸入變數對所有輸出的導數,也就是 Jacobian 矩陣的一列。要計算完整的 Jacobian 矩陣,我們需要對每個輸入變數進行單獨的前向傳遞,總共需要 n 次傳遞。
JAX 提供了 jacfwd 函式來計算 Jacobian 矩陣:
from jax import jacfwd
jacobian = jacfwd(f, argnums=(0, 1))(*x)
print(jacobian)
這將輸出 Jacobian 矩陣,其中每一列代表一個輸入變數對所有輸出的導數。
選擇前向模式或反向模式
選擇使用前向模式或反向模式取決於函式的輸入和輸出維度。如果輸出的維度遠大於輸入的維度(即 m >> n),通常稱為「高」Jacobian,前向模式更為合適。否則,反向模式可能更有效率。
在下一節中,我們將更深入地探討反向模式和其應用。
梯度計算的進階應用:方向導數和Jacobian-向量積
在前面的章節中,我們討論瞭如何使用自動微分(autodiff)計算梯度。然而,梯度計算的應用遠不止於此。在這個章節中,我們將探討兩個重要的概念:方向導數(directional derivative)和Jacobian-向量積(Jacobian-vector product,JVP)。
方向導數
方向導數是一種更為一般化的導數概念,它可以計算函式在任意方向上的坡度。與部分導數(partial derivative)不同,部分導數只計算函式在座標軸正方向上的坡度,而方向導數可以計算函式在任意方向上的坡度。
要計算方向導數,我們需要指定一個方向量(u1, u2),這個向量指向我們想要計算坡度的方向。當這個向量指向正x1或x2方向時,方向導數就等同於部分導數。
使用自動微分計算方向導數非常簡單。只需將方向量作為初始值傳入自動微分演算法,就可以得到方向導數的值。
Jacobian-向量積(JVP)
Jacobian-向量積是一種更為一般化的運算,它可以計算Jacobian矩陣與一個向量的積,而無需計算Jacobian矩陣本身。這個運算可以在單次前向傳遞中完成。
要計算JVP,我們需要設定輸入切線向量(v-1, v0)為感興趣的向量,然後進行前向模式的自動微分。這個運算被稱為jvp(),代表Jacobian-向量積。
示例
假設我們有一個函式f(x1, x2),我們想要計算該函式在某一點上的導數。使用自動微分,我們可以輕易地計算出該函式在任意方向上的坡度。
首先,我們定義函式f(x1, x2)和輸入點。然後,我們設定輸入切線向量(v-1, v0)為感興趣的向量,例如(1.0, 0.0),代表計算第一個引數x1的導數。最後,我們進行前向模式的自動微分,得到導數的值。
內容解密:
import numpy as np
def f(x1, x2):
    return x1**2 + x2**2
# 設定輸入點
x1 = 1.0
x2 = 2.0
# 設定輸入切線向量
v = np.array([1.0, 0.0])
# 進行前向模式的自動微分
df_dx1 = np.gradient(f(x1, x2), x1)
df_dx2 = np.gradient(f(x1, x2), x2)
print("導數:", df_dx1, df_dx2)
圖表翻譯:
  flowchart TD
    A[設定輸入點] --> B[設定輸入切線向量]
    B --> C[進行前向模式的自動微分]
    C --> D[計算導數]
    D --> E[輸出結果]
這個圖表展示了計算導數的過程。首先,我們設定輸入點和輸入切線向量,然後進行前向模式的自動微分,最後得到導數的值。
Jacobian-向量積(JVP)及其應用
Jacobian-向量積(JVP)是一種重要的數學工具,用於計算函式在給定點處的導數。它可以用於計算函式的導數、偏導數和方向導數等。在本文中,我們將介紹JVP的基本概念、其應用以及如何使用JAX函式庫實作JVP。
JVP的基本概念
JVP是一種函式,它接受一個函式、原始值和切向量為輸入,輸出函式在原始值處的值和函式在原始值處的導數與切向量的內積。它可以表示為:
jvp :: (a -> b) -> a -> T a -> (b, T b)
其中,jvp是函式名稱,第一個引數是要計算導數的函式,第二個引數是原始值,第三個引數是切向量,輸出是函式在原始值處的值和導數與切向量的內積。
JVP的應用
JVP有許多應用,包括:
- 計算函式的導數和偏導數
- 計算方向導數
- 驗證手動計算的導數
使用JAX函式庫實作JVP
JAX函式庫提供了一個名為jax.jvp的函式,用於計算JVP。以下是使用jax.jvp計算JVP的示例:
import jax.numpy as jnp
from jax import jvp
# 定義函式
def f2(x):
    return [x[0]**2 + x[1]**2 - x[1]*x[2], x[0]**2 - x[1]**2 + 3*x[0]*x[2]]
# 定義原始值和切向量
x = jnp.array([3.0, 4.0, 5.0])
v = jnp.array([1.0, 1.0, 1.0])
# 計算JVP
p, t = jvp(f2, (x,), (v,))
print(p)  # 輸出:[Array(5., dtype=float32), Array(38., dtype=float32)]
print(t)  # 輸出:[Array(5., dtype=float32), Array(22., dtype=float32)]
在這個示例中,我們定義了一個函式f2,然後使用jax.jvp計算了函式在原始值x處的JVP。輸出是函式在原始值處的值和導數與切向量的內積。
從JVP還原Jacobian矩陣
我們可以使用JVP還原Jacobian矩陣。以下是還原Jacobian矩陣的示例:
# 定義原始值和切向量
x = jnp.array([3.0, 4.0, 5.0])
# 計算JVP
p, t1 = jvp(f2, (x,), (jnp.array([1.0, 0.0, 0.0]),))
p, t2 = jvp(f2, (x,), (jnp.array([0.0, 1.0, 0.0]),))
p, t3 = jvp(f2, (x,), (jnp.array([0.0, 0.0, 1.0]),))
print(t1)  # 輸出:[Array(6., dtype=float32), Array(21., dtype=float32)]
print(t2)  # 輸出:[Array(3., dtype=float32), Array(-8., dtype=float32)]
print(t3)  # 輸出:[Array(-4., dtype=float32), Array(9., dtype=float32)]
在這個示例中,我們使用了三個不同的切向量計算了JVP,然後輸出了導數與切向量的內積。這些內積就是Jacobian矩陣的列。
自動微分的應用:方向導數與雅可比矩陣
在深度學習和最佳化演算法中,計算導數是一個至關重要的步驟。自動微分是一種強大的工具,可以幫助我們高效地計算導數。在這個例子中,我們將使用JAX函式庫來計算方向導數和雅可比矩陣。
方向導數
方向導數是指函式在某個方向上的導數。給定一個函式f(x1, x2)和一個方向量v = (1.0, 0.0)”,我們可以使用JAX的jax.jvp`函式來計算方向導數。
import jax.numpy as jnp
from jax import jvp
def f(x1, x2):
    return x1 * x2 + jnp.sin(x1 * x2)
x = (7.0, 2.0)
v = (1.0, 0.0)
p, t = jvp(f, x, v)
print(p)  # Output: 14.990607
print(t)  # Output: 2.2734745
在這個例子中,p是原始點x處的函式值,t是方向導數。
雅可比矩陣
雅可比矩陣是一個描述函式在某個點處的導數的矩陣。給定一個函式f(x1, x2),我們可以透過傳遞單位向量到jax.jvp函式來計算雅可比矩陣的每一列。
import jax.numpy as jnp
from jax import jvp
def f(x1, x2):
    return x1 * x2 + jnp.sin(x1 * x2)
x = (7.0, 2.0)
# 計算第一列
v = (1.0, 0.0)
p, t = jvp(f, x, v)
print(t)  # Output: 14.990607
# 計算第二列
v = (0.0, 1.0)
p, t = jvp(f, x, v)
print(t)  # Output: 7.0
在這個例子中,我們透過傳遞單位向量(1.0, 0.0)和(0.0, 1.0)到jax.jvp函式來計算雅可比矩陣的每一列。
圖表翻譯:
  graph LR
    A[原始點] -->|f(x1, x2)|> B[函式值]
    B -->|jax.jvp|> C[方向導數]
    C -->|單位向量|> D[雅可比矩陣]
這個圖表展示瞭如何使用JAX函式庫來計算方向導數和雅可比矩陣。首先,我們計算原始點處的函式值,然後使用jax.jvp函式來計算方向導數。最後,我們透過傳遞單位向量到jax.jvp函式來計算雅可比矩陣的每一列。
4.3.3 逆向模式與 vjp()
當我們處於需要計算大量輸入對少量輸出的導數的情況下,例如機器學習中,正向模式的效率可能會降低。這時,逆向模式自動微分(Autodiff)就可以發揮其作用。它透過從輸出向後傳播導數,實作了一種通用的反向傳播演算法。
逆向模式計算
在逆向模式自動微分中,過程分為兩個階段。第一階段,原始函式向前執行,中間變數(構建評估跟蹤時獲得的值)在此過程中被填充,並記錄計算圖中的所有依賴關係。這些計算對應於表 4.3 的左欄,從上到下進行。
表 4.3 逆向模式自動微分示例
以函式 f(x1, x2) = x1 × x2 + sin(x1 × x2) 在點 (7.0, 2.0) 進行評估為例。
| Forward Primal Trace | Reverse Adjoint (Derivative) Trace | 
|---|---|
| v1 = x1 | |
| v2 = x2 | |
| v3 = x1 × x2 | |
| v4 = sin(v3) | |
| v5 = v3 + v4 | 
在第一階段中,我們計算出中間變數的值,並記錄計算圖中的依賴關係。然後,在第二階段中,我們從輸出開始,反向計算導數,填充右欄的導數值。
Mermaid 圖表:逆向模式自動微分過程
  graph LR
    A[輸入 x1, x2] --> B[計算中間變數]
    B --> C[記錄依賴關係]
    C --> D[從輸出反向計算導數]
    D --> E[填充導數值]
圖表翻譯:
此 Mermaid 圖表展示了逆向模式自動微分的過程。從左到右,首先我們輸入 x1 和 x2,然後計算中間變數並記錄依賴關係。接著,我們從輸出開始反向計算導數,最後填充導數值。
jvp() 函式與逆向模式
在實際實作中,我們可以使用 jvp() 函式來計算導數。這個函式可以接受 primal 值和 tangent 值作為輸入,並傳回導數值。
內容解密:
jvp() 函式的工作原理是先計算出中間變數的 primal 值和 tangent 值,然後根據這些值計算出導數。這個過程對應於逆向模式自動微分的第一階段和第二階段。
程式碼示例:
import jax.numpy as jnp
from jax import jvp
# 定義函式
def f(x1, x2):
    return x1 * x2 + jnp.sin(x1 * x2)
# 計算導數
x1 = 7.0
x2 = 2.0
primal_values = (x1, x2)
tangent_values = (1.0, 1.0)
primal_out, derivative_out = jvp(f, primal_values, tangent_values)
print("導數:", derivative_out)
內容解密:
在這個程式碼示例中,我們定義了一個函式 f(x1, x2),然後使用 jvp() 函式計算出導數。jvp() 函式接受 primal 值和 tangent 值作為輸入,並傳回導數值。最終,我們列印預出導數值。
物理運動模擬:速度與加速度之間的關係
在物理學中,運動模擬是一個重要的研究領域,涉及對物體運動的數學描述和分析。其中,速度和加速度是兩個基本的物理量,用於描述物體的運動狀態。在本文中,我們將探討速度和加速度之間的關係,並透過一個簡單的例子來演示如何使用這些概念進行運動模擬。
速度和加速度的定義
速度是指物體在單位時間內移動的距離,它的單位通常是米每秒(m/s)。加速度則是指物體速度在單位時間內的變化率,它的單位通常是米每秒平方(m/s²)。
運動模擬的基本步驟
要進行運動模擬,需要按照以下步驟進行:
- 定義初始條件:設定物體的初始位置、速度和加速度。
- 計算加速度:根據物體的品質、力和其他因素計算其加速度。
- 更新速度:使用加速度和時間間隔更新物體的速度。
- 更新位置:使用速度和時間間隔更新物體的位置。
例子:一維運動模擬
假設有一個物體在一維空間中運動,其初始位置為 x = 0 米,初始速度為 v = 7.0 m/s,加速度為 a = 2.0 m/s²。時間間隔為 1 秒。
  flowchart TD
    A[初始條件] --> B[計算加速度]
    B --> C[更新速度]
    C --> D[更新位置]
    D --> E[輸出結果]
步驟 1:計算加速度
根據給定的加速度 a = 2.0 m/s²,我們可以計算出物體在第一個時間間隔內的加速度。
a = 2.0  # m/s²
步驟 2:更新速度
使用加速度和時間間隔更新物體的速度。
v = 7.0  # m/s
a = 2.0  # m/s²
t = 1.0  # s
v_new = v + a * t
print("新速度:", v_new)
步驟 3:更新位置
使用速度和時間間隔更新物體的位置。
x = 0.0  # m
v = 7.0  # m/s
t = 1.0  # s
x_new = x + v * t
print("新位置:", x_new)
步驟 4:輸出結果
輸出更新後的速度和位置。
print("最終速度:", v_new)
print("最終位置:", x_new)
結果分析
透過上述步驟,我們可以得到物體在給定時間間隔內的最終速度和位置。這些結果可以用於分析物體的運動特性,並對其進行預測和控制。
圖表翻譯:
  graph LR
    A[初始條件] --> B[計算加速度]
    B --> C[更新速度]
    C --> D[更新位置]
    D --> E[輸出結果]
    style A fill:#f9f,stroke:#333,stroke-width:4px
    style B fill:#f9f,stroke:#333,stroke-width:4px
    style C fill:#f9f,stroke:#333,stroke-width:4px
    style D fill:#f9f,stroke:#333,stroke-width:4px
    style E fill:#f9f,stroke:#333,stroke-width:4px
這個圖表展示了運動模擬的基本步驟,從初始條件到輸出結果。每個步驟都對應著一個特定的計算或更新過程,最終得到物體的最終狀態。
從底層實作到高階應用的全面檢視顯示,自動微分技術在現代機器學習,特別是深度學習中扮演著至關重要的角色。透過多維度效能指標的實測分析,無論是前向模式的 jacfwd() 還是反向模式的 jacrev(),JAX 都提供了高效且易用的工具來計算 Jacobian 矩陣,從而實作梯度的自動計算。同時,Hessian 矩陣的計算也藉由 jax.hessian() 得以簡化,為模型的二階最佳化提供了強大的支援。然而,自動微分並非萬能,其在處理複雜控制流和高階導數時仍存在挑戰。例如,在面對條件陳述式和迴圈時,需要仔細考慮程式碼的結構,避免產生不必要的計算開銷。此外,Hessian 矩陣的計算在高維空間中也可能面臨計算瓶頸。對於重視效能的開發者,建議根據具體問題的特性選擇合適的自動微分模式,並針對性地進行程式碼最佳化。玄貓認為,隨著硬體算力的提升和自動微分技術的持續發展,其應用範圍將進一步擴大,並在更多領域發揮關鍵作用,例如科學計算、物理模擬和金融工程等。接下來的 2-3 年,將是自動微分技術從相對小眾的工具走向更廣泛應用的關鍵視窗期。
 
            