JAX 提供了高效的自動微分和向量化功能,簡化了神經網路的訓練流程。首先,我們需要對資料進行預處理,例如將資料轉換為浮點數並歸一化到特定範圍,以提高模型的穩定性和準確性。接著,我們可以定義神經網路的架構,包括輸入層、隱藏層和輸出層,並使用 JAX 的 stax 模組初始化模型引數。為了提高訓練效率,我們可以使用 vmap 函式自動批次處理資料。在訓練過程中,我們使用梯度下降法來最佳化模型引數,並使用學習率衰減策略來調整學習率,以避免模型陷入區域性最優解。

在使用 JAX 訓練神經網路時,我們通常需要定義兩個函式:一個用於初始化模型引數,另一個用於執行前向傳播。前向傳播函式接受模型引數和輸入資料作為輸入,並傳回模型的輸出。為了計算梯度,我們可以使用 JAX 的 grad 函式。在更新模型引數時,我們可以使用指數衰減的學習率策略,以根據 epoch 數調整學習率。最後,我們可以評估模型在訓練集和測試集上的準確率,並監控訓練時間以最佳化訓練流程。

神經網路基礎:資料預處理與神經網路結構

在進行深度學習任務時,資料預處理是一個至關重要的步驟。為了將整數值轉換為浮點數,並使其落在[0, 1]範圍內,我們會將整數值除以整數值的最大可能值。這樣做的好處是可以使所有特徵值維持在相同的尺度上,從而提高模型的穩定性和準確性。

首先,我們需要將整數值轉換為float32浮點數,以便進行浮點數運算。然後,我們將這個浮點數除以最大整數值,以得到一個介於0和1之間的值。這個過程可以確保所有的資料都處於相同的尺度上,避免了不同尺度之間的差異對模型的影響。

import jax.numpy as jnp

def preprocess_data(data, max_value):
    # 將整數值轉換為float32浮點數
    data = jnp.array(data, dtype=jnp.float32)
    
    # 將浮點數除以最大整數值
    data = data / max_value
    
    return data

接下來,我們需要將預處理好的資料分割為訓練集和測試集,並使用批次生成器(batch generator)來產生批次資料。每個批次包含32張影像,並預先載入下一個批次,以加速訓練過程。

import jax
from jax.experimental import jax2tf

# 定義批次大小
batch_size = 32

# 將資料分割為訓練集和測試集
train_data, test_data =...

# 建立批次生成器
train_batches = jax2tf.convert_batch(train_data, batch_size)
test_batches = jax2tf.convert_batch(test_data, batch_size)

最後,我們來設計一個簡單的神經網路模型。這個模型包含兩層:一層輸入層,接收784個畫素點的影像;一層隱藏層,包含512個神經元,用於提取影像特徵。

import flax
from flax import linen as nn

class SimpleNeuralNetwork(nn.Module):
    @nn.compact
    def __call__(self, x):
        # 輸入層
        x = nn.Dense(512)(x)
        x = nn.relu(x)
        
        # 輸出層
        x = nn.Dense(10)(x)
        x = nn.softmax(x)
        
        return x

這個簡單的神經網路模型可以用於影像分類別任務,例如MNIST資料集。透過調整模型的引數和結構,可以進一步提高模型的準確性和泛化能力。

深度學習神經網路架構

在深度學習中,神經網路是一種非常重要的模型結構。下面,我們將探討一個簡單的神經網路架構,並使用JAX進行實作。

神經網路結構

首先,我們需要了解神經網路的基本結構。一個典型的神經網路由多層神經元組成,每層神經元都會接收前一層的輸出作為輸入。最常見的神經網路結構包括輸入層、隱藏層和輸出層。

輸入層

輸入層負責接收外部輸入資料。在影像分類別任務中,輸入層通常會接收一張影像的畫素資料。假設我們有一張28x28的影像,則輸入層需要有784個神經元(28x28=784)。

隱藏層

隱藏層是神經網路中的一個或多個全連線(或稀疏連線)的層,負責提取輸入資料中的特徵。隱藏層的神經元數目可以根據任務的複雜度進行調整。在本例中,我們使用了一個具有512個神經元的全連線隱藏層。

輸出層

輸出層負責產生最終的預測結果。在影像分類別任務中,輸出層通常會有與類別數量相同的神經元數目。假設我們有10個類別,則輸出層需要有10個神經元。

JAX實作

在JAX中,建立一個神經網路模型需要注意以下幾點:

  1. 隨機數生成器:JAX需要外部提供隨機數生成器的狀態,以保證模型引數初始化的可重現性。
  2. 前向傳遞函式:JAX要求前向傳遞函式必須是無狀態和純函式,這意味著模型引數需要作為輸入資料傳遞給前向傳遞函式。

以下是使用JAX實作簡單神經網路模型的步驟:

  1. 初始化模型引數:首先,需要初始化模型引數,包括權重和偏差。
  2. 定義前向傳遞函式:定義一個前向傳遞函式,該函式接受模型引數和輸入資料作為輸入,並傳回模型的輸出。
  3. 應用模型:使用前向傳遞函式對輸入資料進行預測。

Mermaid圖表

  graph LR
    A[輸入層] --> B[隱藏層]
    B --> C[輸出層]
    C --> D[預測結果]

圖表翻譯:

上述Mermaid圖表描述了神經網路的基本結構。輸入層接收外部輸入資料,隱藏層提取輸入資料中的特徵,輸出層產生最終的預測結果。

內容解密:

在本文中,我們探討了神經網路的基本結構和使用JAX進行實作的步驟。透過瞭解神經網路的工作原理和JAX的特點,可以更好地應用深度學習技術解決實際問題。

程式碼實作示例:

import jax
import jax.numpy as jnp

# 初始化模型引數
params = {
    'weights': jnp.random.normal(0, 1, (784, 512)),
    'biases': jnp.zeros((512,)),
    'output_weights': jnp.random.normal(0, 1, (512, 10)),
    'output_biases': jnp.zeros((10,))
}

# 定義前向傳遞函式
def predict(params, image):
    # 輸入層
    input_layer = image
    
    # 隱藏層
    hidden_layer = jnp.dot(input_layer, params['weights']) + params['biases']
    hidden_layer = jax.nn.relu(hidden_layer)
    
    # 輸出層
    output_layer = jnp.dot(hidden_layer, params['output_weights']) + params['output_biases']
    output_layer = jax.nn.softmax(output_layer)
    
    return output_layer

# 應用模型
image = jnp.random.normal(0, 1, (784,))
output = predict(params, image)
print(output)

內容解密:

在上述程式碼中,我們定義了一個簡單的神經網路模型,包括輸入層、隱藏層和輸出層。使用JAX的jax.nn.relujax.nn.softmax函式實作了隱藏層和輸出層的啟用函式。最終,應用模型對隨機生成的輸入資料進行預測,並列印預出預測結果。

神經網路初始化

在訓練神經網路之前,我們需要初始化所有的引數,包括權重(w)和偏差(b)。這些引數將使用隨機數字進行初始化,以確保神經網路的多樣性和避免過度適應。

初始化神經網路引數

以下是使用JAX初始化神經網路引數的過程:

  1. 定義層大小:首先,我們需要定義神經網路的層大小。這通常包括輸入層、隱藏層和輸出層的大小。

  2. 初始化引數:然後,我們需要初始化每一層的權重和偏差。這可以使用隨機數字進行,通常使用高斯分佈或均勻分佈。

  3. 設定隨機種子:為了確保結果的可重複性,我們需要設定一個隨機種子。這可以使用JAX的random.PRNGKey函式進行。

實作初始化功能

以下是使用JAX實作神經網路初始化功能的範例:

import jax
import jax.numpy as jnp
from jax import random

# 定義層大小
LAYER_SIZES = [28*28, 512, 10]

# 定義引數縮放因子
PARAM_SCALE = 0.01

def init_network_params(sizes, key=random.PRNGKey(0), scale=1e-2):
    """初始化所有層的引數"""
    
    def random_layer_params(m, n, key, scale=1e-2):
        """隨機初始化權重和偏差"""
        
        w_key, b_key = random.split(key)
        return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))
    
    # 初始化引數
    params = []
    for i in range(len(sizes) - 1):
        m, n = sizes[i], sizes[i + 1]
        w, b = random_layer_params(m, n, key, scale)
        params.append((w, b))
        key = random.fold_in(key, i)
    
    return params

# 初始化引數
params = init_network_params(LAYER_SIZES)

# 列印初始化引數
for i, (w, b) in enumerate(params):
    print(f"Layer {i+1} weights shape: {w.shape}, biases shape: {b.shape}")

圖示化神經網路流程

以下是神經網路初始化和應用過程的Mermaid圖表:

  flowchart TD
    A[初始化引數] --> B[定義層大小]
    B --> C[設定隨機種子]
    C --> D[初始化權重和偏差]
    D --> E[設定引數縮放因子]
    E --> F[初始化神經網路]
    F --> G[訓練神經網路]

圖表翻譯:

此圖表展示了神經網路初始化和應用的過程。首先,我們需要定義層大小和設定隨機種子。然後,我們可以初始化權重和偏差,並設定引數縮放因子。最後,我們可以初始化神經網路並進行訓練。

神經網路引數初始化與前向傳遞

在深度學習框架中,神經網路的引數初始化和前向傳遞是兩個非常重要的步驟。以下,我們將探討如何使用JAX進行這些操作。

引數初始化

首先,我們需要初始化神經網路的引數。這通常涉及為每個層的權重(w)和偏差(b)分配隨機值。在JAX中,這可以透過使用random模組來完成。然而,JAX的隨機數生成器與NumPy不同,因為它需要純函式,而NumPy的隨機數生成器不是純函式。

import jax
import jax.numpy as jnp
from jax.experimental import jax2tf
from jax.experimental import stax

# 定義層大小
LAYER_SIZES = [784, 256, 10]

# 定義引數縮放因子
PARAM_SCALE = 0.1

# 生成隨機鍵
key = jax.random.PRNGKey(0)

# 將鍵分割為多個層所需的隨機鍵
keys = jax.random.split(key, len(LAYER_SIZES) - 1)

# 初始化神經網路引數
params = []
for m, n, k in zip(LAYER_SIZES[:-1], LAYER_SIZES[1:], keys):
    # 生成隨機權重和偏差
    w = jax.random.normal(k, (m, n)) * PARAM_SCALE
    b = jax.random.normal(k, (n,)) * PARAM_SCALE
    params.append((w, b))

前向傳遞

接下來,我們需要定義神經網路的前向傳遞函式。這個函式將接受輸入x和引數params,並傳回輸出。

import jax.nn as nn

def predict(params, x):
    # 逐層進行前向傳遞
    for w, b in params:
        x = nn.swish(jnp.dot(x, w) + b)
    return x

在這個例子中,我們使用了Swish啟用函式,它是一種流行的啟用函式,能夠替代ReLU和Sigmoid等傳統啟用函式。

啟動神經網路

現在,我們可以啟動神經網路了。首先,我們需要生成一個隨機輸入。

# 生成隨機輸入
x = jax.random.normal(key, (1, LAYER_SIZES[0]))

接下來,我們可以呼叫predict函式來進行前向傳遞。

# 進行前向傳遞
output = predict(params, x)

這樣,我們就完成了神經網路的引數初始化和前向傳遞。

圖表翻譯:

  graph LR
    A[引數初始化] --> B[前向傳遞]
    B --> C[輸出]
    C --> D[結果]
    style A fill:#f9f,stroke:#333,stroke-width:2px
    style B fill:#f9f,stroke:#333,stroke-width:2px
    style C fill:#f9f,stroke:#333,stroke-width:2px
    style D fill:#f9f,stroke:#333,stroke-width:2px

內容解密:

以上程式碼展示瞭如何使用JAX進行神經網路的引數初始化和前向傳遞。首先,我們定義了層大小和引數縮放因子。然後,我們生成了隨機鍵和分割了鍵以獲得多個層所需的隨機鍵。接下來,我們初始化了神經網路引數,包括權重和偏差。然後,我們定義了前向傳遞函式,該函式接受輸入和引數,並傳回輸出。最後,我們啟動了神經網路,生成了隨機輸入,並進行了前向傳遞。

第二章:使用JAX建立您的第一個程式

啟動函式

啟動函式是深度學習世界中的基本元件。它們在神經網路計算中提供非線性。沒有非線性,多層前饋神經網路就會等同於單層網路。由於簡單的數學運算,線性組合的線性組合仍然是線性組合,這就是單個神經元所做的工作。我們知道單個神經元解決複雜分類別問題的能力僅限於線性可分割任務(您可能聽說過著名的XOR問題,無法使用線性分類別器解決)。因此,啟動函式確保神經網路的表達能力並防止簡化為更簡單的模型。

許多不同的啟動函式已被發現。該領域始於簡單且易於理解的函式,如sigmoid或雙曲正切函式。它們平滑且具有數學家喜愛的屬性,例如在每個點都可微分。然後,出現了一種新的函式,即修正線性單元(ReLU)。ReLU不平滑,其導數在點x = 0處不存在。然而,實踐者發現神經網路使用ReLU學習速度更快。

然後,發現了許多其他啟動函式,其中一些是透過實驗發現的,另一些是透過玄貓設計的。其中流行的設計函式包括高斯誤差線性單元(GELU)。

深度學習中最新的趨勢之一是自動發現。它通常被稱為神經架構搜尋(NAS)。該方法的想法是設計一個豐富但可管理的搜尋空間,以描述感興趣的元件。它可以是啟動函式、層型別、最佳化器更新方程等。然後,我們執行一個自動程式來智慧地搜尋此空間。不同的方法也可以使用強化學習、進化計算甚至梯度下降。

JAX中的簡單神經網路

NAS是一個令人興奮的故事,我相信JAX豐富的表達能力可以對該領域做出重大貢獻。也許我們的一些讀者會在深度學習中取得令人興奮的進展!

以下,我們開發了一個前向傳遞函式,通常被稱為預測函式。它接受一張影像進行分類別,並執行所有前向傳遞計算以在輸出層神經元上產生啟用。啟用最高的神經元確定輸入影像的類別(因此,如果最高啟用是在神經元5上,那麼根據最直接的方法,神經網路檢測到輸入影像包含手寫數字5)。

清單2.5:神經網路的前向傳遞

import jax.numpy as jnp
from jax.nn import swish

def predict(params, image):
    """每個例子的預測函式"""
    activations = image
    for w, b in params[:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = swish(outputs)
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits

注意我們如何在這裡傳遞引數列表。它與PyTorch或TensorFlow中的典型程式不同,在那裡這些引數通常隱藏在類別內,並且函式使用類別變數來存取它們。

請關注神經網路計算如何結構化。在JAX中,您通常為神經網路有兩個函式:一個用於初始化引數,另一個用於將神經網路應用於某些輸入資料。第一個函式傳回引數作為某些資料結構(在這裡是一個陣列列表;稍後,它將是一個特殊的資料結構,稱為pytree)。第二個函式接受引數和資料並傳回將神經網路應用於資料的結果。這種模式將在未來出現很多次,即使在高階神經網路框架中也是如此。

就是這樣。我們可以使用新的函式進行每個例子的預測。在清單2.6中,我們生成了一張與我們的資料集相同大小且具有隨機畫素值的影像。

啟動函式匯入

from jax.nn import swish

注意我們匯入了swish啟動函式。

注意我們傳遞了兩個引數:權重和偏差項。

影像預測與神經網路引數

在神經網路中,影像預測是一個重要的應用。為了進行預測,我們需要將影像和網路引數傳遞給函式。首先,我們初始化啟用值(activations)以輸入影像的畫素值。

接下來,我們循序遍歷神經網路的每一層,從第一層到倒數第二層。對於每一層,我們都會更新啟用值,使用該層的輸出作為下一層的輸入。然而,對於最後一層,我們不會套用啟用函式。

預測過程

在進行預測時,我們可以使用隨機生成的影像或實際資料集中的影像。由於我們的神經網路尚未經過訓練,因此預測結果可能不佳。這裡,我們關注的是預測的輸出形狀,發現預測傳回一個包含10個啟用值的元組,對應於10個類別。

程式碼實作

import jax.numpy as jnp

# 初始化隨機影像
random_flattened_image = jnp.random.rand(784)

# 匯入神經網路模型
model =...

# 進行預測
prediction = model.predict(random_flattened_image)

print(prediction.shape)  # 輸出:(10,)

內容解密:

在上述程式碼中,我們首先匯入必要的函式庫,包括jax.numpy,然後初始化一個隨機的扁平化影像。接下來,我們匯入神經網路模型,並使用它來進行預測。最後,我們印出預測結果的形狀,以確認它是否符合我們的期望。

圖表翻譯:

  graph LR
    A[影像初始化] --> B[神經網路模型]
    B --> C[預測]
    C --> D[輸出形狀]

圖表翻譯:

上述流程圖描述了從影像初始化到預測輸出形狀的整個過程。首先,我們初始化影像,然後將其傳遞給神經網路模型,以進行預測。最後,預測結果的形狀被輸出,以確認是否符合預期。

使用vmap實作自動批次處理

在之前的步驟中,我們設計了一個預測函式(predict),它只能處理單一影像。然而,在實際應用中,我們通常需要同時處理多張影像,也就是批次處理。為了實作這一點,我們可以使用JAX中的vmap函式來自動向量化計算,使其能夠支援批次處理。

步驟1:生成隨機批次影像

首先,我們生成一個批次的隨機影像。這個批次包含32張影像,每張影像的大小為28x28畫素,且只有單一顏色通道。

import jax
import jax.numpy as jnp
from jax import random

# 生成隨機批次影像
random_flattened_images = random.normal(random.PRNGKey(1), (32, 28*28*1))

步驟2:嘗試使用預測函式處理批次影像

接下來,我們嘗試直接將這個批次影像傳遞給預測函式,看看會發生什麼。

try:
    preds = predict(params, random_flattened_images)
except TypeError as e:
    print(e)

執行這段程式碼後,會出現一個TypeError,提示我們dot_general要求收縮維度具有相同的形狀,但得到的形狀卻是(784,)(32,)。這表明預測函式不支援批次處理。

步驟3:使用vmap實作自動批次處理

為瞭解決這個問題,我們可以使用vmap來自動向量化預測函式,使其能夠支援批次處理。

# 使用vmap實作自動批次處理
batched_predict = jax.vmap(predict, in_axes=(None, 0))

# 測試批次預測
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)

在這裡,jax.vmap函式被用來向量化predict函式。in_axes引數指定了哪些維度應該被向量化。在這個例子中,(None, 0)表示只有第二個引數(即批次影像)應該被向量化。

執行這段程式碼後,會輸出批次預測的結果形狀,從而證明我們成功實作了自動批次處理。

圖表翻譯:

  flowchart TD
    A[生成隨機批次影像] --> B[嘗試使用預測函式]
    B --> C[出現TypeError]
    C --> D[使用vmap實作自動批次處理]
    D --> E[測試批次預測]
    E --> F[輸出批次預測結果]

內容解密:

在上述程式碼中,我們首先生成了一個隨機批次影像。然後,我們嘗試直接將這個批次影像傳遞給預測函式,但由於預測函式不支援批次處理,因此出現了TypeError。為瞭解決這個問題,我們使用vmap來自動向量化預測函式,使其能夠支援批次處理。最後,我們測試了批次預測,並輸出了批次預測的結果形狀。

使用JAX進行自動向量化

在深度學習中,多維陣列(Tensors)是用於溝通神經網路和其層之間的主要資料結構。它們也被稱為張量。在數學或物理學中,張量具有更嚴格和更複雜的含義,但在深度學習中,它們只是多維陣列的同義詞。如果您曾經使用過NumPy,您幾乎知道了所有您需要知道的東西。

張量或多維陣列有特定的形式。矩陣是一個具有兩個維度的張量(或rank-2張量),向量是一個具有一個維度的張量(rank-1張量),而標量(或只是一個數字)是一個具有零個維度的張量(rank-0張量)。因此,張量是對標量、向量和矩陣的推廣,以任意數量的維度(rank)表示。

例如,您的損失函式值是一個標量(只有一個數字)。一個分類別神經網路的輸出為單個輸入的類別機率陣列是一個具有k個維度(k是類別數量)和一個維度(不要混淆大小和rank)的向量。一個批次的預測陣列是具有k × m大小(k是類別數量,m是批次大小)的矩陣。一個RGB影像是具有三個維度(寬度、高度和顏色通道)的rank-3張量。一個批次的RGB影像是具有四個維度(新的維度是批次維度)的rank-4張量。一個影片幀流也可以被視為具有四個維度(時間是新的維度)的rank-4張量。一個批次的影片是具有五個維度的rank-5張量,依此類別推。在深度學習中,您通常使用具有不超過四或五個維度的張量。

現在,我們來看看如何解決批次處理問題。首先,有一個簡單的解決方案:我們可以寫一個迴圈來分解批次為單個影像,並順序處理它們。這將有效,但它將非常低效,因為大多數硬體可以在單位時間內執行更多計算。如果這樣,硬體將被嚴重低估。如果您已經使用過MATLAB、NumPy或類別似的工具,您就知道向量化的好處了。它將是一個高效的解決方案。

因此,第二個選擇是重寫和手動向量化predict()函式,以便它可以接受批次資料作為輸入。這通常意味著您的輸入張量將具有額外的批次維度,您需要重寫計算以使用它。對於簡單的計算,這很直接,但對於複雜的函式,這可能很複雜。您通常在使用NumPy或TensorFlow/PyTorch低階別原始碼編寫神經網路時這樣做。

第三個選擇是自動向量化。JAX提供了vmap()轉換,可以將單元素函式轉換為批次函式。這是您在JAX中最常用的方法,因為它是最方便的方法,並且產生了優異的效能。我相信您會喜歡它。

以下是使用vmap()自動向量化predict()函式的示例:

from jax import vmap
batched_predict = vmap(predict, in_axes=(None, 0))

in_axes引數控制哪些輸入陣列軸要對映(或向量化)。其長度必須等於函式的位置引數數量。None表示我們不需要對映任何軸,在我們的示例中,它對應於predict()函式的第一個引數params。這個引數對於任何前向傳遞都保持不變,因此我們不需要對其進行批次處理(至少目前尚未)。

現在,我們可以將修改過的函式應用於批次並產生正確的輸出:

batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)

輸出:

(32, 10)

這表明batched_predict()函式已經正確地對批次進行了向量化,並產生了32個影像的預測結果,每個影像有10個類別機率。

使用 vmap() 進行批次化處理

vmap() 是 JAX 中的一個強大工具,允許我們將單個專案的函式轉換為批次處理函式。這意味著我們可以輕鬆地將原本只適用於單個專案的函式擴充套件到批次處理。

建立批次化函式

首先,我們需要定義一個單個專案的函式。假設我們有一個名為 predict() 的函式,它接受一個單個影像作為輸入並傳回其類別啟用值。現在,我們想要建立一個批次化版本的 predict() 函式,名為 batched_predict(),它可以接受一個批次影像作為輸入並傳回每個影像的類別啟用值。

import jax
import jax.numpy as jnp

# 定義單個專案的函式
def predict(image):
    # 對影像進行預測
    return jnp.array([0.1, 0.2, 0.3, 0.4])  # 例如

# 使用 vmap() 建立批次化函式
batched_predict = jax.vmap(predict)

# 測試批次化函式
batch_images = jnp.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
output = batched_predict(batch_images)
print(output)

自動微分:無需瞭解導數即可計算梯度

自動微分(Autodiff)是 JAX 中的一個強大功能,允許我們在不瞭解導數的情況下計算梯度。這對於神經網路訓練尤其有用,因為我們通常需要計算損失函式對於模型引數的梯度。

神經網路訓練

要訓練一個神經網路,我們需要定義一個損失函式、選擇一個最佳化器並執行梯度下降迭代。JAX 提供了自動微分功能,使得計算梯度變得容易。

# 定義損失函式
def loss(params, images, labels):
    # 對影像進行預測
    predictions = batched_predict(images)
    # 計算損失
    return jnp.mean((predictions - labels) ** 2)

# 初始化引數
params = jnp.array([0.1, 0.2, 0.3, 0.4])

# 定義最佳化器
def optimizer(params, gradient):
    # 更新引數
    return params - 0.01 * gradient

# 執行梯度下降迭代
for i in range(100):
    # 計算梯度
    gradient = jax.grad(loss)(params, batch_images, jnp.array([0, 1]))
    # 更新引數
    params = optimizer(params, gradient)
    # 印出損失
    print(loss(params, batch_images, jnp.array([0, 1])))

這樣,我們就完成了使用 JAX 進行神經網路訓練的基本過程。JAX 的自動微分功能使得計算梯度變得容易,而 vmap() 函式則允許我們輕鬆地將單個專案的函式擴充套件到批次處理。

第二章:您的第一個JAX程式

在上一節中,我們簡要介紹了JAX的基本概念和優點。現在,讓我們一起建立您的第一個JAX程式。

Loss曲線和梯度下降

在機器學習中,Loss曲線(也稱為損失函式)是用於衡量模型預測值與實際值之間差異的指標。梯度下降是一種常用的最佳化演算法,透過遞迴地更新模型引數以最小化Loss值。

Loss曲線的視覺化

想象一下,一個高維度空間中的Loss曲線,其形狀就像是一個複雜的山谷和山峰的景觀。每一個點代表著模型引數的一種組合,而Loss值則代表著該點的高度。我們的目標是找到這個空間中Loss值最小的點,也就是所謂的全域最小值。

import numpy as np
import matplotlib.pyplot as plt

# 定義Loss函式
def loss(x):
    return x**2 + 10*np.sin(x)

# 生成x值
x = np.linspace(-10, 10, 400)

# 計算Loss值
y = loss(x)

# 繪製Loss曲線
plt.plot(x, y)
plt.xlabel('x')
plt.ylabel('Loss')
plt.title('Loss Curve')
plt.show()

梯度下降的工作原理

梯度下降演算法從一個隨機的起始點開始,然後遞迴地更新模型引數以最小化Loss值。在每一步中,演算法計算當前點的梯度(即Loss函式對於模型引數的偏導數),然後沿著梯度的反方向更新模型引數。

import numpy as np

# 定義Loss函式和其梯度
def loss(x):
    return x**2 + 10*np.sin(x)

def gradient(x):
    return 2*x + 10*np.cos(x)

# 起始點
x = 5.0

# 學習率
lr = 0.1

# 更新步數
steps = 100

# 更新模型引數
for i in range(steps):
    # 計算梯度
    grad = gradient(x)
    
    # 更新模型引數
    x -= lr * grad
    
    # 輸出當前Loss值
    print(f'Step {i+1}, Loss: {loss(x)}')
內容解密:
  • Loss曲線是用於衡量模型預測值與實際值之間差異的指標。
  • 梯度下降是一種常用的最佳化演算法,透過遞迴地更新模型引數以最小化Loss值。
  • JAX提供了一種高效的方式來計算梯度和更新模型引數。

圖表翻譯:

  • Loss曲線圖顯示了Loss值隨著模型引數變化的趨勢。
  • 梯度下降演算法的工作原理可以透過圖表來視覺化,展示出模型引數如何被更新以最小化Loss值。
  flowchart TD
    A[起始點] --> B[計算梯度]
    B --> C[更新模型引數]
    C --> D[計算Loss值]
    D --> E[輸出結果]

圖表翻譯:

  • 圖表顯示了梯度下降演算法的工作流程,從起始點開始,計算梯度,更新模型引數,計算Loss值,最終輸出結果。

梯度下降法的實作

梯度下降法是一種常用的最佳化演算法,用於找到模型引數的最佳值。其基本思想是透過迭代更新模型引數,沿著損失函式的梯度方向下降,直到找到區域性最小值。

損失函式

損失函式是用於評估模型預測結果與真實標籤之間的差異。常用的損失函式包括均方差、交叉熵等。在本例中,我們使用的是多類別分類別的交叉熵損失函式。

import jax.numpy as jnp
from jax.nn import logsumexp

def loss(params, images, targets):
    """交叉熵損失函式"""
    logits = batched_predict(params, images)
    log_preds = logits - logsumexp(logits)
    return -jnp.mean(targets * log_preds)

取得梯度

要實作梯度下降法,需要計算損失函式對模型引數的梯度。這裡使用了JAX的grad轉換來計算梯度。

from jax import grad

# 定義損失函式
def loss(params, images, targets):
   ...

# 取得梯度
grad_loss = grad(loss, argnums=0)

更新模型引數

有了梯度後,可以更新模型引數了。梯度下降法的更新規則是:

params -= learning_rate * grad_loss(params, images, targets)

其中,learning_rate是學習率,控制更新步長。

實作梯度下降法

現在,可以實作梯度下降法了。以下是完整的程式碼:

import jax.numpy as jnp
from jax.nn import logsumexp
from jax import grad

def loss(params, images, targets):
    """交叉熵損失函式"""
    logits = batched_predict(params, images)
    log_preds = logits - logsumexp(logits)
    return -jnp.mean(targets * log_preds)

def gradient_descent(params, images, targets, learning_rate):
    """梯度下降法"""
    grad_loss = grad(loss, argnums=0)
    params -= learning_rate * grad_loss(params, images, targets)
    return params

# 初始化模型引數
params =...

# 載入資料
images =...
targets =...

# 設定學習率
learning_rate = 0.01

# 執行梯度下降法
for i in range(100):
    params = gradient_descent(params, images, targets, learning_rate)
    print(f"第 {i+1} 次迭代,損失值:{loss(params, images, targets)}")

這樣,就完成了梯度下降法的實作。注意,這裡的程式碼是簡化版本,實際上可能需要考慮更多因素,如正則化、批次歸一化等。

2.6.3 梯度更新步驟

在 JAX 中,梯度更新與其他框架如 TensorFlow 和 PyTorch 有所不同。後兩者通常在前向傳遞後計算梯度,並追蹤所有對感興趣的張量進行的操作。JAX 則採用了一種不同的方法:它轉換您的函式並生成另一個計算梯度的函式。然後,您可以透過將權重、資料和其他引數傳入這個新函式來計算梯度。

在這裡,我們計算梯度並更新所有引數,使其朝著梯度相反的方向移動(這就是為什麼權重更新公式中有負號)。所有梯度都會根據學習率引數進行縮放,而學習率則取決於epoch數(一個epoch是一次完整的訓練集遍歷)。我們實作了一個指數衰減的學習率,這意味著在後面的epoch中,學習率將低於早期的學習率。

實作梯度更新步驟

from jax import grad

INIT_LR = 1.0  # 初始學習率
DECAY_RATE = 0.95  # 衰減率
DECAY_STEPS = 5  # 衰減步驟

def update(params, x, y, epoch_number):
    """
    更新模型引數。
    
    引數:
    - params: 模型引數(權重和偏差)
    - x: 輸入資料
    - y: 真實標籤
    - epoch_number: 當前epoch數
    
    傳回:
    - 更新後的模型引數
    """
    # 計算梯度
    grads = grad(loss)(params, x, y)
    
    # 計算當前epoch的學習率
    lr = INIT_LR * (DECAY_RATE ** (epoch_number / DECAY_STEPS))
    
    # 更新模型引數
    return [(w - lr * dw, b - lr * db) for (w, b), (dw, db) in zip(params, grads)]

在這個例子中,您不直接計算損失函式,而是隻計算梯度。在許多情況下,您也想要追蹤損失值,JAX 提供了另一個函式 value_and_grad(),它既計算函式的值又計算其梯度。您可以修改 update() 函式以使用這個新功能。

修改更新函式以計算損失值和梯度

from jax import value_and_grad

def update(params, x, y, epoch_number):
    """
    更新模型引數,並計算損失值和梯度。
    
    引數:
    - params: 模型引數(權重和偏差)
    - x: 輸入資料
    - y: 真實標籤
    - epoch_number: 當前epoch數
    
    傳回:
    - 更新後的模型引數、損失值和梯度
    """
    # 計算損失值和梯度
    loss_value, grads = value_and_grad(loss)(params, x, y)
    
    # 計算當前epoch的學習率
    lr = INIT_LR * (DECAY_RATE ** (epoch_number / DECAY_STEPS))
    
    # 更新模型引數
    updated_params = [(w - lr * dw, b - lr * db) for (w, b), (dw, db) in zip(params, grads)]
    
    return updated_params, loss_value

圖表翻譯:

  flowchart TD
    A[開始] --> B[計算梯度]
    B --> C[計算學習率]
    C --> D[更新模型引數]
    D --> E[傳回更新後的模型引數]

這個流程圖展示了模型引數更新的過程,從計算梯度到更新模型引數,並傳回更新後的模型引數。

學習率衰減引數

學習率衰減引數決定了在多少個 epoch 之後,學習率會再次衰減。這個引數對於模型的訓練過程有著重要的影響,因為它控制了學習率如何隨著時間的推移而調整。

動態學習率調整

在實際的模型訓練中,我們經常需要根據模型的表現動態調整學習率。為了實作這一點,我們可以定義一個函式,用於計算梯度,並將其應用於當前的引數和資料,以獲得梯度值。

梯度計算

梯度計算是機器學習中的一個關鍵步驟,它涉及到計算損失函式相對於模型引數的導數。這個過程可以透過自動微分(autodiff)技術來實作,自動微分允許我們在不需要手動計算導數的情況下,計算出梯度。

學習率更新

在每個訓練步驟中,學習率都需要根據當前的epoch數量進行更新。這個過程可以透過以下公式實作:

[ \text{lr} = \text{INIT_LR} \times \text{DECAY_RATE}^{(\text{epoch_number} / \text{DECAY_STEPS})} ]

其中,INIT_LR是初始學習率,DECAY_RATE是衰減率,epoch_number是當前的epoch數量,DECAY_STEPS是衰減步長。

引數更新

根據梯度和學習率,模型引數可以按照以下公式進行更新:

[ (w, b) = (w - \text{lr} \times dw, b - \text{lr} \times db) ]

其中,( w )和( b )分別是模型的權重和偏置,( dw )和( db )分別是權重和偏置的梯度。

實作細節

在實際實作中,我們可以使用如下的程式碼來計算梯度和更新引數:

import numpy as np

def value_and_grad(loss):
    # 自動微分函式,計算梯度
    def grad_func(params, x, y):
        # 計算梯度
        loss_value = loss(params, x, y)
        grads = np.gradient(loss_value)
        return loss_value, grads
    return grad_func

# 初始化引數和學習率
params = (np.random.rand(), np.random.rand())
x = np.random.rand()
y = np.random.rand()
INIT_LR = 0.1
DECAY_RATE = 0.9
DECAY_STEPS = 10

# 計算梯度和更新引數
loss_value, grads = value_and_grad(lambda params, x, y: np.mean((params[0]*x + params[1] - y)**2))(params, x, y)
lr = INIT_LR * DECAY_RATE ** (1 / DECAY_STEPS)
params_updated = (params[0] - lr * grads[0], params[1] - lr * grads[1])

print("更新後的引數:", params_updated)

內容解密:訓練迴圈實作

在這個步驟中,我們將實作訓練迴圈,以便對模型進行多次迭代的訓練。首先,我們需要定義一些工具函式來計算準確率和跟蹤相關資訊。

import jax.numpy as jnp
from jax.nn import one_hot

# 定義訓練迴圈的引數
num_epochs = 25  # 訓練迴圈的次數

# 定義批次準確率計算函式
def batch_accuracy(params, images, targets):
    """
    計算批次的準確率。
    
    Args:
    params: 模型引數
    images: 輸入影像
    targets: 目標標籤
    
    Returns:
    準確率
    """
    images = jnp.reshape(images, (len(images), NUM_PIXELS))  # 重塑影像為 (batch_size, NUM_PIXELS)
    predicted_class = jnp.argmax(batched_predict(params, images), axis=1)  # 預測類別
    return jnp.mean(predicted_class == targets)  # 準確率

# 定義整體準確率計算函式
def accuracy(params, data):
    """
    計算整體準確率。
    
    Args:
    params: 模型引數
    data: 資料集
    
    Returns:
    準確率
    """
    accs = []  # 準確率列表
    for images, targets in data:
        accs.append(batch_accuracy(params, images, targets))  # 計算批次準確率
    return jnp.mean(jnp.array(accs))  # 整體準確率

import time  # 引入時間模組

# 訓練迴圈
for epoch in range(num_epochs):
    start_time = time.time()  # 記錄開始時間
    
    losses = []  # 損失列表
    
    for x, y in train_data:
        x = jnp.reshape(x, (len(x), NUM_PIXELS))  # 重塑輸入影像
        
        y = one_hot(y, NUM_LABELS)  # 對目標標籤進行 one-hot 編碼
        
        params, loss_value = update(params, x, y, epoch)  # 更新模型引數
        
        losses.append(loss_value)  # 記錄損失值
    
    epoch_time = time.time() - start_time  # 記錄本次迭代的時間
    
    # 可以在這裡新增列印或儲存相關資訊的程式碼

圖表翻譯:訓練迴圈流程

  flowchart TD
    A[開始] --> B[初始化模型引數]
    B --> C[開始訓練迴圈]
    C --> D[計算批次準確率]
    D --> E[更新模型引數]
    E --> F[記錄損失值]
    F --> G[結束訓練迴圈]
    G --> H[輸出最終結果]

在這個流程中,我們首先初始化模型引數,然後開始訓練迴圈。在每次迭代中,我們計算批次準確率,更新模型引數,記錄損失值。最後,我們結束訓練迴圈並輸出最終結果。

訓練模型的評估與最佳化

在訓練模型的過程中,評估模型的效能是一個非常重要的步驟。這涉及計算模型在訓練集和測試集上的準確率,以及評估模型訓練和評估的時間消耗。

計算準確率

準確率是評估模型效能的一個重要指標,它代表了模型正確預測的樣本數佔總樣本數的比例。以下是計算準確率的步驟:

# 計算訓練集上的準確率
train_acc = accuracy(params, train_data)

# 計算測試集上的準確率
test_acc = accuracy(params, test_data)

計算時間消耗

計算模型訓練和評估的時間消耗對於最佳化模型的效能也是非常重要的。以下是計算時間消耗的步驟:

# 記錄開始時間
start_time = time.time()

# 執行評估
eval_time = time.time() - start_time

# 列印評估時間
print("Eval in {:0.2f} sec".format(eval_time))

最佳化模型

最佳化模型涉及計算損失函式和梯度,以更新模型引數。以下是最佳化模型的步驟:

# 定義損失函式和梯度計算
def calculate_loss_and_gradient(params, data):
    # 計算損失函式
    loss = calculate_loss(params, data)
    
    # 計算梯度
    gradient = calculate_gradient(params, data)
    
    return loss, gradient

# 更新模型引數
def update_params(params, gradient):
    # 更新引數
    updated_params = params - learning_rate * gradient
    
    return updated_params

生成one-hot編碼

one-hot編碼是一種將類別標籤轉換為數值向量的方法。以下是生成one-hot編碼的步驟:

# 定義one-hot編碼函式
def one_hot_encoding(label):
    # 建立one-hot編碼向量
    one_hot_vector = np.zeros((num_classes,))
    one_hot_vector[label] = 1
    
    return one_hot_vector

訓練模型

訓練模型涉及多個epoch的迭代,每個epoch都涉及前向傳播、計算損失函式和梯度、更新模型引數等步驟。以下是訓練模型的步驟:

# 定義訓練模型函式
def train_model(params, train_data, num_epochs):
    for epoch in range(num_epochs):
        # 前向傳播
        outputs = forward_propagate(params, train_data)
        
        # 計算損失函式和梯度
        loss, gradient = calculate_loss_and_gradient(params, train_data)
        
        # 更新模型引數
        params = update_params(params, gradient)
        
        # 列印訓練過程資訊
        print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
        print("Training set loss {}".format(jnp.mean(jnp.array(losses))))
        print("Training set accuracy {}".format(train_acc))
        print("Test set accuracy {}".format(test_acc))

透過以上步驟,可以完成模型的評估和最佳化,從而得到一個高效能的模型。

使用JAX進行神經網路訓練和最佳化

在本文中,我們將探討如何使用JAX(Java Advanced eXtensions)進行神經網路訓練和最佳化。JAX是一種高效能的機器學習框架,提供了強大的工具和功能來加速和最佳化神經網路的訓練。

從模型訓練、效能評估到引數調整的完整流程來看,本文展示瞭如何利用 JAX 建構並訓練一個簡單的神經網路。分析Jax 的核心特性,如 vmap 的自動向量化和 grad 的自動微分,可以發現它們極大地簡化了批次處理和梯度計算的過程,有效提升了訓練效率。然而,JAX 的函式語言程式設計正規化和純函式要求,對習慣於物件導向程式設計的開發者來說,初期可能存在一定的學習曲線。整合 JAX 至現有專案時,需要仔細考量程式碼風格的轉換和團隊成員的技術儲備。對於追求極致效能且願意投入學習成本的團隊,JAX 提供了優異的硬體加速和程式碼最佳化潛力,尤其在處理大規模資料集和複雜模型時,其優勢將更加顯著。玄貓認為,JAX 作為新一代深度學習框架,展現了強大的技術優勢,值得關注效能的核心繫統採用。隨著社群的持續發展和工具鏈的日漸完善,我們預見 JAX 的應用門檻將大幅降低,並在未來深度學習領域扮演更重要的角色。