深度學習模型訓練的核心在於利用最佳化器調整模型引數,使其逐步擬合訓練資料。本篇從最佳化器狀態初始化開始,逐步介紹如何取得引數更新、應用更新,並以 Optax 與 Flax 框架為例,演示構建訓練迴圈、計算指標、使用 ResNet 進行影像分類別等實務技巧。同時也討論瞭如何提升訓練效率、評估模型效能,以及現代電腦視覺技術的發展趨勢。

最佳化器狀態與模型引數更新

在深度學習中,最佳化器(optimizer)扮演著調整模型引數以最小化損失函式的關鍵角色。以下是最佳化器狀態和模型引數更新的過程:

初始化最佳化器狀態

首先,我們需要初始化最佳化器的狀態。這通常涉及設定最佳化器的超引數,例如學習率、動量等。假設我們使用的是 optax 最佳化器函式庫,則可以使用 optimizer.init 方法來初始化最佳化器狀態。

import optax

# 定義模型引數
params =...

# 初始化最佳化器狀態
opt_state = optax.init(optimizer)

取得引數更新

接下來,我們需要根據模型引數的梯度來計算引數更新。這通常涉及呼叫最佳化器的 update 方法,傳入梯度和當前的最佳化器狀態。

# 取得梯度
grads = jax.grad(loss)(params)

# 取得引數更新和新的最佳化器狀態
updates, opt_state = optimizer.update(grads, opt_state)

應用引數更新

最後,我們需要將引數更新應用到模型引數上。這通常涉及呼叫 optax.apply_updates 方法,傳入當前的模型引數和引數更新。

# 應用引數更新
params = optax.apply_updates(params, updates)

內容解密

上述過程展示瞭如何使用最佳化器來更新模型引數。首先,我們需要初始化最佳化器狀態,然後根據模型引數的梯度來計算引數更新,最後將引數更新應用到模型引數上。

圖表翻譯

以下是上述過程的視覺化表示:

  flowchart TD
    A[初始化最佳化器狀態] --> B[取得梯度]
    B --> C[取得引數更新]
    C --> D[應用引數更新]

在這個流程圖中,首先我們初始化最佳化器狀態,然後根據模型引數的梯度來計算引數更新,最後將引數更新應用到模型引數上。這個過程不斷迭代,直到模型收斂或達到停止條件。

使用 Optax 最佳化器的步驟

在使用 Optax 進行神經網路最佳化時,需要按照特定的步驟進行。以下是使用 Optax 最佳化器的詳細步驟:

1. 建立最佳化器物件

首先,需要建立一個最佳化器物件,例如 Adam 最佳化器。這可以透過 optax.adam(learning_rate) 函式實作,該函式傳回一個最佳化器物件。

2. 初始化最佳化器狀態

接下來,需要初始化最佳化器狀態,例如動量向量。這可以透過 optimizer.init(params) 函式實作,該函式傳回最佳化器的初始狀態。

3. 計算梯度

在模型更新迴圈中,需要計算損失函式的梯度。這可以透過 jax.grad(loss_function)(params, x, y) 函式實作,該函式傳回損失函式對於模型引數的梯度。

4. 更新模型引數

最後,需要使用最佳化器更新模型引數。這可以透過 optimizer.update(opt_state, grads, params) 函式實作,該函式傳回更新後的模型引數和最佳化器狀態。

內容解密:

上述步驟可以透過以下程式碼實作:

import optax
import jax

# 建立最佳化器物件
optimizer = optax.adam(learning_rate=0.001)

# 初始化最佳化器狀態
opt_state = optimizer.init(params)

# 計算梯度
grads = jax.grad(loss_function)(params, x, y)

# 更新模型引數
opt_state, params = optimizer.update(opt_state, grads, params)

圖表翻譯:

以下是使用 Optax 最佳化器的流程圖:

  flowchart TD
    A[建立最佳化器物件] --> B[初始化最佳化器狀態]
    B --> C[計算梯度]
    C --> D[更新模型引數]
    D --> E[傳回更新後的模型引數和最佳化器狀態]

這個流程圖展示了使用 Optax 最佳化器的步驟,從建立最佳化器物件到更新模型引數。

使用 Optax 和 Flax 進行神經網路訓練

Optax 是一個根據 JAX 的最佳化器函式庫,提供了多種最佳化器和工具來進行神經網路訓練。Flax 是一個根據 JAX 的神經網路函式庫,提供了簡單易用的 API 來進行神經網路建模和訓練。在本文中,我們將介紹如何使用 Optax 和 Flax 進行神經網路訓練。

Optax 最佳化器

Optax 提供了多種最佳化器,包括 SGD、Adam、RMSProp 等。這些最佳化器都實作了 GradientTransformation 介面,可以用來更新模型引數。Optax 也提供了工具來組合不同的最佳化器和轉換,例如 chain()multi_transform()

Flax TrainState

Flax 提供了 TrainState 類別來簡化使用 Optax 最佳化器的過程。TrainState 類別包含了模型引數、最佳化器狀態和其他訓練相關的狀態。使用 TrainState 類別,可以簡化訓練迴圈的實作。

訓練迴圈

訓練迴圈是神經網路訓練的核心部分。使用 Optax 和 Flax,可以簡化訓練迴圈的實作。以下是使用 Optax 和 Flax 進行神經網路訓練的基本步驟:

  1. 建立模型和最佳化器
  2. 初始化模型引數和最佳化器狀態
  3. 計算損失函式和梯度
  4. 更新模型引數

以下是使用 Optax 和 Flax 進行神經網路訓練的範例程式碼:

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax

# 定義模型
class MyModel(nn.Module):
    def __call__(self, x):
        #...

# 建立模型和最佳化器
model = MyModel()
optimizer = optax.sgd(learning_rate=0.01)

# 初始化模型引數和最佳化器狀態
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=model.init(jax.random.PRNGKey(0), jnp.ones((1, 10))),
    tx=optimizer
)

# 訓練迴圈
for epoch in range(10):
    # 計算損失函式和梯度
    loss, grads = jax.value_and_grad(model.apply)(state.params, jnp.ones((1, 10)))
    
    # 更新模型引數
    state = state.apply_gradients(grads)

在這個範例中,我們定義了一個簡單的模型 MyModel,並建立了一個 Optax 最佳化器 optimizer。我們初始化模型引數和最佳化器狀態,然後進入訓練迴圈。在每個 epoch 中,我們計算損失函式和梯度,然後更新模型引數。

訓練迴圈的實作

在深度學習中,訓練迴圈是模型學習的核心部分。以下是如何實作一個基本的訓練迴圈:

單步訓練

def update(train_state, x, y):
    """
    單步訓練函式
    """
    # 定義損失函式
    def loss(params, images, targets):
        """
        分類別交叉熵損失函式
        """
        logits = train_state.apply_fn(params, images)
        log_preds = logits - jax.nn.logsumexp(logits)
        return -jnp.mean(targets * log_preds)

    # 計算損失值和梯度
    loss_value, grads = jax.value_and_grad(loss)(train_state.params, x, y)
    
    # 更新模型引數
    train_state = train_state.apply_gradients(grads=grads)
    
    return train_state, loss_value

多步訓練迴圈

for epoch in range(num_epochs):
    start_time = time.time()
    losses = []
    
    for x, y in train_data:
        # 對輸入資料進行reshape
        x = jnp.reshape(x, (len(x), NUM_PIXELS))
        
        # 對標籤進行one-hot編碼
        y = jax.nn.one_hot(y, NUM_LABELS)
        
        # 執行單步訓練
        train_state, loss_value = update(train_state, x, y)
        
        # 記錄損失值
        losses.append(loss_value)
    
    # 計算平均損失值
    avg_loss = jnp.mean(losses)
    
    # 顯示訓練進度
    print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}, Time: {time.time() - start_time:.2f} seconds")

圖表翻譯:

  flowchart TD
    A[開始] --> B[讀取資料]
    B --> C[reshape輸入資料]
    C --> D[one-hot編碼標籤]
    D --> E[執行單步訓練]
    E --> F[記錄損失值]
    F --> G[計算平均損失值]
    G --> H[顯示訓練進度]

圖表翻譯:

上述Mermaid圖表展示了訓練迴圈的流程。首先,讀取資料,然後對輸入資料進行reshape和one-hot編碼標籤。接下來,執行單步訓練,記錄損失值,計算平均損失值,最後顯示訓練進度。

內容解密:

在上述程式碼中,update函式實作了單步訓練,包括計算損失值和梯度,更新模型引數。多步訓練迴圈則對整個資料集進行了遍歷,執行了多次單步訓練,並記錄了損失值。最後,計算了平均損失值,並顯示了訓練進度。這些步驟都是深度學習中模型訓練的核心部分。

執行高階神經網路訓練

在深度學習中,訓練一個神經網路模型涉及多個步驟,包括定義模型架構、選擇最佳化器、計算梯度以及更新模型引數。以下是使用高階API實作的MNIST手寫字型識別任務的訓練過程。

建立訓練狀態

首先,我們需要建立一個訓練狀態(state),這包括初始化模型引數、最佳化器狀態等。這一步驟通常涉及到選擇合適的最佳化器,例如隨機梯度下降(SGD)最佳化器。

import optax

# 建立一個SGD最佳化器
optimizer = optax.sgd(learning_rate=0.01)

# 初始化模型引數和最佳化器狀態
params, opt_state = initialize_model_and_optimizer()

應用模型和計算梯度

接下來,我們需要將輸入資料(在這裡是MNIST影像)應用到模型中,並計算梯度。這一步驟對於模型的更新至關重要,因為它告訴我們模型引數應該如何調整以最小化損失函式。

# 定義模型
def model(params, x):
    # 在這裡實作你的神經網路模型
    pass

# 計算梯度
grads = jax.grad(loss_fn)(params, x)

更新模型引數和最佳化器狀態

使用計算出的梯度,我們可以更新模型引數和最佳化器狀態。這通常涉及到最佳化器的update方法。

# 更新模型引數和最佳化器狀態
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)

訓練迴圈

整個訓練過程通常是在一個迴圈中進行的,每次迭代都會更新一次模型引數和最佳化器狀態,並計算損失值以評估模型的效能。

for epoch in range(num_epochs):
    for x, y in train_dataset:
        # 前向傳播、計算損失和反向傳播
        loss_value = loss_fn(params, x, y)
        
        # 更新模型引數和最佳化器狀態
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        
        # 儲存損失值以便後續分析
        losses.append(loss_value)

時間效率和效能評估

最後,評估訓練過程的時間效率和模型效能是非常重要的。這可以透過計算每個epoch的時間以及評估模型在驗證集上的效能來實作。

start_time = time.time()
for epoch in range(num_epochs):
    # 訓練過程...
    epoch_time = time.time() - start_time
    print(f"Epoch {epoch+1}, Time: {epoch_time:.2f} seconds")

內容解密:

以上程式碼片段展示瞭如何使用高階API進行神經網路訓練的各個步驟,包括建立訓練狀態、應用模型、計算梯度、更新模型引數和最佳化器狀態,以及評估訓練過程的時間效率和模型效能。這些步驟對於構建和訓練一個高效能的神經網路模型至關重要。

圖表翻譯:

  graph LR
    A[初始化模型和最佳化器] --> B[應用模型和計算梯度]
    B --> C[更新模型引數和最佳化器狀態]
    C --> D[評估模型效能]
    D --> E[儲存損失值和時間資訊]

此圖表展示了神經網路訓練過程中的主要步驟,從初始化模型和最佳化器開始,到應用模型、計算梯度、更新引數、評估效能,最後儲存相關資訊以便分析。

使用CLU函式庫計算指標

在Flax中,計算指標可以使用CLU函式庫來實作。CLU函式庫提供了一個功能性指標計算介面,稱為Metric,它依賴於指標累積中間值,然後使用這些中間值計算最終指標值。這個過程可以分為三個步驟:

  1. 計算區域性批次指標:對於每個批次,計算區域性批次指標從模型輸出。模型輸出是一個包含唯一鍵的字典,每個鍵都有一個特定的含義(例如「loss」、「logits」和「labels」)。每個指標都依賴於至少一個這樣的模型輸出。Metric介面提供了from_output()函式來指定用於指標計算的模型輸出名稱。
  2. 聚合中間指標:使用merge()函式聚合不同批次的區域性或中間指標。
  3. 計算最終指標:使用compute()函式從聚合的中間值計算最終指標。

新增指標到訓練狀態

要新增指標到訓練狀態,首先需要宣告一個資料類別來儲存指標、損失和準確率。然後,建立一個子類別繼承自TrainState,並包含指標。最後,例項化新的TrainState類別,並初始化所有相關欄位,包括指標。

from flax.training import train_state
from clu import metrics
import flax
import optax

@flax.struct.dataclass
class Metrics:
    accuracy: metrics.Accuracy
    loss: metrics.Average.from_output('loss')

class TrainState(train_state.TrainState):
    metrics: Metrics

state = TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optax.sgd(learning_rate=0.01, momentum=0.9),
    metrics=Metrics.empty()
)

修改訓練迴圈以包含指標計算

接下來,需要修改訓練迴圈以包含指標計算。可以定義一個函式來計算所有指標,並在每個訓練批次上呼叫這個函式。內部使用Optax函式計算交叉熵損失,並使用CLU函式庫中的現有指標計算準確率。如果需要其他指標,如精確度或召回率,可能需要獨立實作,但這並不困難。

@jax.jit
def compute_metrics(state, x, y):
    # 計算指標的實作
    pass

測試集指標計算

要計算測試集指標,可以克隆一個具有空指標的TrainState,並以與訓練過程相同的方式計算所有指標。

test_state = state.replace(metrics=Metrics.empty())
# 計算測試集指標的實作
pass

透過這些步驟,可以使用CLU函式庫和Flax來實作指標計算,並將其整合到訓練迴圈中。

深度學習模型訓練流程

在深度學習中,模型的訓練是一個迭代的過程,涉及到多個步驟,包括前向傳播、損失計算、反向傳播和引數更新。以下是模型訓練流程的詳細步驟:

1. 前向傳播

首先,模型接收輸入資料 x,並將其傳遞給模型的前向傳播函式,計算輸出 logits。這一步驟可以使用以下程式碼實作:

logits = state.apply_fn(state.params, x)

其中,state.apply_fn 是模型的前向傳播函式,state.params 是模型的引數,x 是輸入資料。

2. 損失計算

接下來,模型計算損失函式,使用 softmax 交叉熵損失函式,計算輸出 logits 和真實標籤 y 之間的差異。這一步驟可以使用以下程式碼實作:

loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=y).mean()

其中,optax.softmax_cross_entropy_with_integer_labels 是 softmax 交叉熵損失函式,logits 是模型的輸出,y 是真實標籤。

3. 指標更新

模型計算指標更新,使用 state.metrics.single_from_model_output 函式,計算輸出 logits 和真實標籤 y 之間的差異。這一步驟可以使用以下程式碼實作:

metric_updates = state.metrics.single_from_model_output(logits=logits, labels=y, loss=loss)

其中,state.metrics.single_from_model_output 是指標更新函式,logits 是模型的輸出,y 是真實標籤,loss 是損失函式。

4. 指標合併

模型合併指標更新,使用 state.metrics.merge 函式,合併指標更新和現有的指標。這一步驟可以使用以下程式碼實作:

metrics = state.metrics.merge(metric_updates)

其中,state.metrics.merge 是指標合併函式,metric_updates 是指標更新。

5. 狀態更新

模型更新狀態,使用 state.replace 函式,更新狀態中的指標。這一步驟可以使用以下程式碼實作:

state = state.replace(metrics=metrics)

其中,state.replace 是狀態更新函式,metrics 是更新後的指標。

6. 訓練迴圈

模型進行訓練迴圈,使用 for 迴圈,迭代多次,直到達到指定的 epoch 數量。這一步驟可以使用以下程式碼實作:

for epoch in range(num_epochs):
    #...

其中,num_epochs 是指定的 epoch 數量。

7. 測試

模型進行測試,使用 test_state 狀態,計算測試資料的指標。這一步驟可以使用以下程式碼實作:

test_state = state
for x, y in test_data:
    #...

其中,test_state 是測試狀態,test_data 是測試資料。

內容解密:

以上程式碼實作了深度學習模型的訓練流程,包括前向傳播、損失計算、反向傳播和引數更新。模型使用 softmax 交叉熵損失函式計算損失,並使用 optax 函式庫進行最佳化。模型的指標更新和合併使用 state.metrics 函式進行。最終,模型進行測試,計算測試資料的指標。

圖表翻譯:

以下是模型訓練流程的 Mermaid 圖表:

  graph LR
    A[輸入資料] --> B[前向傳播]
    B --> C[損失計算]
    C --> D[指標更新]
    D --> E[指標合併]
    E --> F[狀態更新]
    F --> G[訓練迴圈]
    G --> H[測試]

這個圖表展示了模型訓練流程的各個步驟,包括前向傳播、損失計算、指標更新、指標合併、狀態更新、訓練迴圈和測試。

訓練狀態指標新增

為了評估模型的效能,我們需要新增指標(metrics)到訓練狀態(TrainState)中。這個過程涉及初始化指標集合、計算指標值以及更新訓練狀態。

初始化指標

指標集合初始化為空,等待計算出來的指標值被新增進去。這一步驟確保了指標的準確性和可靠性。

計算指標

計算指標的函式負責進行所有指標的計算,並將新的指標更新到訓練狀態中。這個過程涉及多個步驟,包括取得模型的輸出、使用softmax損失函式等。

取得模型輸出

模型的輸出是計算指標的基礎。透過模型的輸出,可以計算出各種指標,例如準確率、損失值等。

使用softmax損失函式

Optax函式庫中的softmax損失函式被用於計算模型的損失值。這個函式對於評估模型的效能至關重要。

計算中間指標

模型輸出字典被提供給計算中間指標的函式,以便計算出所有必要的指標。這些指標對於評估模型的效能和調整模型引數非常重要。

內容解密:

import optax

def calculate_metrics(train_state, model_output):
    # 初始化指標集合
    metrics = {}
    
    # 取得模型輸出
    output = model_output
    
    # 使用softmax損失函式
    loss_fn = optax.softmax_cross_entropy
    
    # 計算損失值
    loss = loss_fn(output)
    
    # 更新指標集合
    metrics['loss'] = loss
    
    # 更新訓練狀態
    train_state.metrics = metrics
    
    return train_state

圖表翻譯:

  flowchart TD
    A[初始化指標] --> B[計算指標]
    B --> C[取得模型輸出]
    C --> D[使用softmax損失函式]
    D --> E[計算中間指標]
    E --> F[更新訓練狀態]

這個過程確保了模型的效能被準確地評估和跟蹤,為模型的調整和最佳化提供了基礎。

影像分類別使用ResNet

在深度學習中,影像分類別是一個基本任務,旨在根據影像的視覺特徵將其分類別到預先定義的類別中。ResNet(Residual Network)是一種成功的卷積神經網路架構,特別適合於影像分類別任務。

ResNet架構

ResNet的核心思想是引入殘差連線(residual connection),使得網路可以學習到更深層次的特徵表示。這種架構允許網路更容易地最佳化,並且可以達到更高的準確率。

殘差塊

ResNet中的殘差塊是其核心組成部分。每個殘差塊包含兩個卷積層和一個殘差連線。殘差連線允許網路學習到更抽象的特徵,並且可以減少梯度消失問題。

  flowchart TD
    A[輸入] --> B[卷積層1]
    B --> C[啟用函式]
    C --> D[卷積層2]
    D --> E[啟用函式]
    E --> F[殘差連線]
    F --> G[輸出]

網路架構

ResNet的網路架構通常由多個殘差塊堆積疊而成。每個殘差塊都會學習到不同的特徵,並且可以根據需要堆積疊多個殘差塊以達到所需的深度。

  flowchart TD
    A[輸入] --> B[殘差塊1]
    B --> C[殘差塊2]
    C --> D[殘差塊3]
    D --> E[全連線層]
    E --> F[輸出]

訓練過程

在訓練ResNet時,我們需要更新網路的引數以最小化損失函式。這個過程通常涉及多個epoch,每個epoch都會對整個訓練集進行一次迭代。

更新引數

在每個epoch中,我們會計算損失函式並更新網路的引數以最小化損失。

for epoch in range(num_epochs):
    for x, y in train_loader:
        # 前向傳播
        output = model(x)
        loss = criterion(output, y)
        
        # 後向傳播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

評估模型

在每個epoch結束後,我們會評估模型在驗證集上的效能,以確保模型沒有過度擬合。

for epoch in range(num_epochs):
    # 訓練模型
    for x, y in train_loader:
        #...
    
    # 評估模型
    model.eval()
    with torch.no_grad():
        total_correct = 0
        for x, y in val_loader:
            output = model(x)
            _, predicted = torch.max(output, 1)
            total_correct += (predicted == y).sum().item()
    
    accuracy = total_correct / len(val_loader.dataset)
    print(f'Epoch {epoch+1}, Accuracy: {accuracy:.4f}')
圖表翻譯:

上述Mermaid圖表展示了ResNet的架構,包括殘差塊和網路架構。圖表中,每個盒子代表一個層或一個模組,箭頭代表了資料的流動方向。這種架構允許ResNet學習到更抽象的特徵,並且可以減少梯度消失問題。

影像分類別使用ResNet

在之前的章節中,我們使用Flax訓練了一個簡單的神經網路,但該模型仍然是一個非常基礎的多層感知器(MLP),距離目前的最佳實踐還有很大差距。下一節將實作更先進的殘差網路(ResNet)進行影像分類別。

11.2 影像分類別使用ResNet

您可能知道,使用MLP進行影像處理任務通常不是最佳選擇,因為存在更專業和高效的解決方案,如卷積神經網路(CNN)。CNN在多年來不斷演進,通常人們使用殘差網路(ResNets)。在這裡,我們將使用第9章中的Dogs vs. Cats資料集,並實作一個相對簡單的ResNet,這可能是此任務的合適選擇。

現代電腦視覺

典型的影像分類別問題的實際解決方案可能涉及CNN。CNN是一種適合用於影像的神經網路,具有卓越的性質,如平移等變性和學習區域性特徵的能力。CNN長期被認為是最好的神經網路型別,適合於影像處理任務。但是在近幾年,情況發生了變化。

從技術架構視角來看,本文深入探討了深度學習模型訓練的關鍵環節,涵蓋了最佳化器狀態管理、引數更新策略、訓練迴圈架構以及指標計算方法。分析比較了MLP和CNN等不同網路架構在影像分類別任務中的適用性,並闡述了ResNet的核心思想——殘差連線如何提升模型效能。然而,僅僅依靠ResNet架構並不能保證最佳效能,模型的超引數調整、資料增強策略以及訓練過程的監控同樣至關重要。展望未來,根據Transformer的模型在電腦視覺領域的應用日益廣泛,Vision Transformer等架構展現出巨大的潛力,可能在特定任務上超越CNN。對於追求高效能的影像分類別任務,技術團隊應密切關注Transformer模型的發展趨勢,並積極探索其與ResNet等CNN架構的融合方案。玄貓認為,深度學習模型的訓練是一個持續最佳化的過程,需要不斷探索新的架構和訓練策略,才能在特定任務上達到最佳效能。