深度學習模型訓練的核心在於利用最佳化器調整模型引數,使其逐步擬合訓練資料。本篇從最佳化器狀態初始化開始,逐步介紹如何取得引數更新、應用更新,並以 Optax 與 Flax 框架為例,演示構建訓練迴圈、計算指標、使用 ResNet 進行影像分類別等實務技巧。同時也討論瞭如何提升訓練效率、評估模型效能,以及現代電腦視覺技術的發展趨勢。
最佳化器狀態與模型引數更新
在深度學習中,最佳化器(optimizer)扮演著調整模型引數以最小化損失函式的關鍵角色。以下是最佳化器狀態和模型引數更新的過程:
初始化最佳化器狀態
首先,我們需要初始化最佳化器的狀態。這通常涉及設定最佳化器的超引數,例如學習率、動量等。假設我們使用的是 optax 最佳化器函式庫,則可以使用 optimizer.init 方法來初始化最佳化器狀態。
import optax
# 定義模型引數
params =...
# 初始化最佳化器狀態
opt_state = optax.init(optimizer)
取得引數更新
接下來,我們需要根據模型引數的梯度來計算引數更新。這通常涉及呼叫最佳化器的 update 方法,傳入梯度和當前的最佳化器狀態。
# 取得梯度
grads = jax.grad(loss)(params)
# 取得引數更新和新的最佳化器狀態
updates, opt_state = optimizer.update(grads, opt_state)
應用引數更新
最後,我們需要將引數更新應用到模型引數上。這通常涉及呼叫 optax.apply_updates 方法,傳入當前的模型引數和引數更新。
# 應用引數更新
params = optax.apply_updates(params, updates)
內容解密
上述過程展示瞭如何使用最佳化器來更新模型引數。首先,我們需要初始化最佳化器狀態,然後根據模型引數的梯度來計算引數更新,最後將引數更新應用到模型引數上。
圖表翻譯
以下是上述過程的視覺化表示:
flowchart TD
A[初始化最佳化器狀態] --> B[取得梯度]
B --> C[取得引數更新]
C --> D[應用引數更新]
在這個流程圖中,首先我們初始化最佳化器狀態,然後根據模型引數的梯度來計算引數更新,最後將引數更新應用到模型引數上。這個過程不斷迭代,直到模型收斂或達到停止條件。
使用 Optax 最佳化器的步驟
在使用 Optax 進行神經網路最佳化時,需要按照特定的步驟進行。以下是使用 Optax 最佳化器的詳細步驟:
1. 建立最佳化器物件
首先,需要建立一個最佳化器物件,例如 Adam 最佳化器。這可以透過 optax.adam(learning_rate) 函式實作,該函式傳回一個最佳化器物件。
2. 初始化最佳化器狀態
接下來,需要初始化最佳化器狀態,例如動量向量。這可以透過 optimizer.init(params) 函式實作,該函式傳回最佳化器的初始狀態。
3. 計算梯度
在模型更新迴圈中,需要計算損失函式的梯度。這可以透過 jax.grad(loss_function)(params, x, y) 函式實作,該函式傳回損失函式對於模型引數的梯度。
4. 更新模型引數
最後,需要使用最佳化器更新模型引數。這可以透過 optimizer.update(opt_state, grads, params) 函式實作,該函式傳回更新後的模型引數和最佳化器狀態。
內容解密:
上述步驟可以透過以下程式碼實作:
import optax
import jax
# 建立最佳化器物件
optimizer = optax.adam(learning_rate=0.001)
# 初始化最佳化器狀態
opt_state = optimizer.init(params)
# 計算梯度
grads = jax.grad(loss_function)(params, x, y)
# 更新模型引數
opt_state, params = optimizer.update(opt_state, grads, params)
圖表翻譯:
以下是使用 Optax 最佳化器的流程圖:
flowchart TD
A[建立最佳化器物件] --> B[初始化最佳化器狀態]
B --> C[計算梯度]
C --> D[更新模型引數]
D --> E[傳回更新後的模型引數和最佳化器狀態]
這個流程圖展示了使用 Optax 最佳化器的步驟,從建立最佳化器物件到更新模型引數。
使用 Optax 和 Flax 進行神經網路訓練
Optax 是一個根據 JAX 的最佳化器函式庫,提供了多種最佳化器和工具來進行神經網路訓練。Flax 是一個根據 JAX 的神經網路函式庫,提供了簡單易用的 API 來進行神經網路建模和訓練。在本文中,我們將介紹如何使用 Optax 和 Flax 進行神經網路訓練。
Optax 最佳化器
Optax 提供了多種最佳化器,包括 SGD、Adam、RMSProp 等。這些最佳化器都實作了 GradientTransformation 介面,可以用來更新模型引數。Optax 也提供了工具來組合不同的最佳化器和轉換,例如 chain() 和 multi_transform()。
Flax TrainState
Flax 提供了 TrainState 類別來簡化使用 Optax 最佳化器的過程。TrainState 類別包含了模型引數、最佳化器狀態和其他訓練相關的狀態。使用 TrainState 類別,可以簡化訓練迴圈的實作。
訓練迴圈
訓練迴圈是神經網路訓練的核心部分。使用 Optax 和 Flax,可以簡化訓練迴圈的實作。以下是使用 Optax 和 Flax 進行神經網路訓練的基本步驟:
- 建立模型和最佳化器
- 初始化模型引數和最佳化器狀態
- 計算損失函式和梯度
- 更新模型引數
以下是使用 Optax 和 Flax 進行神經網路訓練的範例程式碼:
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
# 定義模型
class MyModel(nn.Module):
def __call__(self, x):
#...
# 建立模型和最佳化器
model = MyModel()
optimizer = optax.sgd(learning_rate=0.01)
# 初始化模型引數和最佳化器狀態
state = train_state.TrainState.create(
apply_fn=model.apply,
params=model.init(jax.random.PRNGKey(0), jnp.ones((1, 10))),
tx=optimizer
)
# 訓練迴圈
for epoch in range(10):
# 計算損失函式和梯度
loss, grads = jax.value_and_grad(model.apply)(state.params, jnp.ones((1, 10)))
# 更新模型引數
state = state.apply_gradients(grads)
在這個範例中,我們定義了一個簡單的模型 MyModel,並建立了一個 Optax 最佳化器 optimizer。我們初始化模型引數和最佳化器狀態,然後進入訓練迴圈。在每個 epoch 中,我們計算損失函式和梯度,然後更新模型引數。
訓練迴圈的實作
在深度學習中,訓練迴圈是模型學習的核心部分。以下是如何實作一個基本的訓練迴圈:
單步訓練
def update(train_state, x, y):
"""
單步訓練函式
"""
# 定義損失函式
def loss(params, images, targets):
"""
分類別交叉熵損失函式
"""
logits = train_state.apply_fn(params, images)
log_preds = logits - jax.nn.logsumexp(logits)
return -jnp.mean(targets * log_preds)
# 計算損失值和梯度
loss_value, grads = jax.value_and_grad(loss)(train_state.params, x, y)
# 更新模型引數
train_state = train_state.apply_gradients(grads=grads)
return train_state, loss_value
多步訓練迴圈
for epoch in range(num_epochs):
start_time = time.time()
losses = []
for x, y in train_data:
# 對輸入資料進行reshape
x = jnp.reshape(x, (len(x), NUM_PIXELS))
# 對標籤進行one-hot編碼
y = jax.nn.one_hot(y, NUM_LABELS)
# 執行單步訓練
train_state, loss_value = update(train_state, x, y)
# 記錄損失值
losses.append(loss_value)
# 計算平均損失值
avg_loss = jnp.mean(losses)
# 顯示訓練進度
print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}, Time: {time.time() - start_time:.2f} seconds")
圖表翻譯:
flowchart TD
A[開始] --> B[讀取資料]
B --> C[reshape輸入資料]
C --> D[one-hot編碼標籤]
D --> E[執行單步訓練]
E --> F[記錄損失值]
F --> G[計算平均損失值]
G --> H[顯示訓練進度]
圖表翻譯:
上述Mermaid圖表展示了訓練迴圈的流程。首先,讀取資料,然後對輸入資料進行reshape和one-hot編碼標籤。接下來,執行單步訓練,記錄損失值,計算平均損失值,最後顯示訓練進度。
內容解密:
在上述程式碼中,update函式實作了單步訓練,包括計算損失值和梯度,更新模型引數。多步訓練迴圈則對整個資料集進行了遍歷,執行了多次單步訓練,並記錄了損失值。最後,計算了平均損失值,並顯示了訓練進度。這些步驟都是深度學習中模型訓練的核心部分。
執行高階神經網路訓練
在深度學習中,訓練一個神經網路模型涉及多個步驟,包括定義模型架構、選擇最佳化器、計算梯度以及更新模型引數。以下是使用高階API實作的MNIST手寫字型識別任務的訓練過程。
建立訓練狀態
首先,我們需要建立一個訓練狀態(state),這包括初始化模型引數、最佳化器狀態等。這一步驟通常涉及到選擇合適的最佳化器,例如隨機梯度下降(SGD)最佳化器。
import optax
# 建立一個SGD最佳化器
optimizer = optax.sgd(learning_rate=0.01)
# 初始化模型引數和最佳化器狀態
params, opt_state = initialize_model_and_optimizer()
應用模型和計算梯度
接下來,我們需要將輸入資料(在這裡是MNIST影像)應用到模型中,並計算梯度。這一步驟對於模型的更新至關重要,因為它告訴我們模型引數應該如何調整以最小化損失函式。
# 定義模型
def model(params, x):
# 在這裡實作你的神經網路模型
pass
# 計算梯度
grads = jax.grad(loss_fn)(params, x)
更新模型引數和最佳化器狀態
使用計算出的梯度,我們可以更新模型引數和最佳化器狀態。這通常涉及到最佳化器的update方法。
# 更新模型引數和最佳化器狀態
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
訓練迴圈
整個訓練過程通常是在一個迴圈中進行的,每次迭代都會更新一次模型引數和最佳化器狀態,並計算損失值以評估模型的效能。
for epoch in range(num_epochs):
for x, y in train_dataset:
# 前向傳播、計算損失和反向傳播
loss_value = loss_fn(params, x, y)
# 更新模型引數和最佳化器狀態
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
# 儲存損失值以便後續分析
losses.append(loss_value)
時間效率和效能評估
最後,評估訓練過程的時間效率和模型效能是非常重要的。這可以透過計算每個epoch的時間以及評估模型在驗證集上的效能來實作。
start_time = time.time()
for epoch in range(num_epochs):
# 訓練過程...
epoch_time = time.time() - start_time
print(f"Epoch {epoch+1}, Time: {epoch_time:.2f} seconds")
內容解密:
以上程式碼片段展示瞭如何使用高階API進行神經網路訓練的各個步驟,包括建立訓練狀態、應用模型、計算梯度、更新模型引數和最佳化器狀態,以及評估訓練過程的時間效率和模型效能。這些步驟對於構建和訓練一個高效能的神經網路模型至關重要。
圖表翻譯:
graph LR
A[初始化模型和最佳化器] --> B[應用模型和計算梯度]
B --> C[更新模型引數和最佳化器狀態]
C --> D[評估模型效能]
D --> E[儲存損失值和時間資訊]
此圖表展示了神經網路訓練過程中的主要步驟,從初始化模型和最佳化器開始,到應用模型、計算梯度、更新引數、評估效能,最後儲存相關資訊以便分析。
使用CLU函式庫計算指標
在Flax中,計算指標可以使用CLU函式庫來實作。CLU函式庫提供了一個功能性指標計算介面,稱為Metric,它依賴於指標累積中間值,然後使用這些中間值計算最終指標值。這個過程可以分為三個步驟:
- 計算區域性批次指標:對於每個批次,計算區域性批次指標從模型輸出。模型輸出是一個包含唯一鍵的字典,每個鍵都有一個特定的含義(例如「loss」、「logits」和「labels」)。每個指標都依賴於至少一個這樣的模型輸出。
Metric介面提供了from_output()函式來指定用於指標計算的模型輸出名稱。 - 聚合中間指標:使用
merge()函式聚合不同批次的區域性或中間指標。 - 計算最終指標:使用
compute()函式從聚合的中間值計算最終指標。
新增指標到訓練狀態
要新增指標到訓練狀態,首先需要宣告一個資料類別來儲存指標、損失和準確率。然後,建立一個子類別繼承自TrainState,並包含指標。最後,例項化新的TrainState類別,並初始化所有相關欄位,包括指標。
from flax.training import train_state
from clu import metrics
import flax
import optax
@flax.struct.dataclass
class Metrics:
accuracy: metrics.Accuracy
loss: metrics.Average.from_output('loss')
class TrainState(train_state.TrainState):
metrics: Metrics
state = TrainState.create(
apply_fn=model.apply,
params=params,
tx=optax.sgd(learning_rate=0.01, momentum=0.9),
metrics=Metrics.empty()
)
修改訓練迴圈以包含指標計算
接下來,需要修改訓練迴圈以包含指標計算。可以定義一個函式來計算所有指標,並在每個訓練批次上呼叫這個函式。內部使用Optax函式計算交叉熵損失,並使用CLU函式庫中的現有指標計算準確率。如果需要其他指標,如精確度或召回率,可能需要獨立實作,但這並不困難。
@jax.jit
def compute_metrics(state, x, y):
# 計算指標的實作
pass
測試集指標計算
要計算測試集指標,可以克隆一個具有空指標的TrainState,並以與訓練過程相同的方式計算所有指標。
test_state = state.replace(metrics=Metrics.empty())
# 計算測試集指標的實作
pass
透過這些步驟,可以使用CLU函式庫和Flax來實作指標計算,並將其整合到訓練迴圈中。
深度學習模型訓練流程
在深度學習中,模型的訓練是一個迭代的過程,涉及到多個步驟,包括前向傳播、損失計算、反向傳播和引數更新。以下是模型訓練流程的詳細步驟:
1. 前向傳播
首先,模型接收輸入資料 x,並將其傳遞給模型的前向傳播函式,計算輸出 logits。這一步驟可以使用以下程式碼實作:
logits = state.apply_fn(state.params, x)
其中,state.apply_fn 是模型的前向傳播函式,state.params 是模型的引數,x 是輸入資料。
2. 損失計算
接下來,模型計算損失函式,使用 softmax 交叉熵損失函式,計算輸出 logits 和真實標籤 y 之間的差異。這一步驟可以使用以下程式碼實作:
loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=y).mean()
其中,optax.softmax_cross_entropy_with_integer_labels 是 softmax 交叉熵損失函式,logits 是模型的輸出,y 是真實標籤。
3. 指標更新
模型計算指標更新,使用 state.metrics.single_from_model_output 函式,計算輸出 logits 和真實標籤 y 之間的差異。這一步驟可以使用以下程式碼實作:
metric_updates = state.metrics.single_from_model_output(logits=logits, labels=y, loss=loss)
其中,state.metrics.single_from_model_output 是指標更新函式,logits 是模型的輸出,y 是真實標籤,loss 是損失函式。
4. 指標合併
模型合併指標更新,使用 state.metrics.merge 函式,合併指標更新和現有的指標。這一步驟可以使用以下程式碼實作:
metrics = state.metrics.merge(metric_updates)
其中,state.metrics.merge 是指標合併函式,metric_updates 是指標更新。
5. 狀態更新
模型更新狀態,使用 state.replace 函式,更新狀態中的指標。這一步驟可以使用以下程式碼實作:
state = state.replace(metrics=metrics)
其中,state.replace 是狀態更新函式,metrics 是更新後的指標。
6. 訓練迴圈
模型進行訓練迴圈,使用 for 迴圈,迭代多次,直到達到指定的 epoch 數量。這一步驟可以使用以下程式碼實作:
for epoch in range(num_epochs):
#...
其中,num_epochs 是指定的 epoch 數量。
7. 測試
模型進行測試,使用 test_state 狀態,計算測試資料的指標。這一步驟可以使用以下程式碼實作:
test_state = state
for x, y in test_data:
#...
其中,test_state 是測試狀態,test_data 是測試資料。
內容解密:
以上程式碼實作了深度學習模型的訓練流程,包括前向傳播、損失計算、反向傳播和引數更新。模型使用 softmax 交叉熵損失函式計算損失,並使用 optax 函式庫進行最佳化。模型的指標更新和合併使用 state.metrics 函式進行。最終,模型進行測試,計算測試資料的指標。
圖表翻譯:
以下是模型訓練流程的 Mermaid 圖表:
graph LR
A[輸入資料] --> B[前向傳播]
B --> C[損失計算]
C --> D[指標更新]
D --> E[指標合併]
E --> F[狀態更新]
F --> G[訓練迴圈]
G --> H[測試]
這個圖表展示了模型訓練流程的各個步驟,包括前向傳播、損失計算、指標更新、指標合併、狀態更新、訓練迴圈和測試。
訓練狀態指標新增
為了評估模型的效能,我們需要新增指標(metrics)到訓練狀態(TrainState)中。這個過程涉及初始化指標集合、計算指標值以及更新訓練狀態。
初始化指標
指標集合初始化為空,等待計算出來的指標值被新增進去。這一步驟確保了指標的準確性和可靠性。
計算指標
計算指標的函式負責進行所有指標的計算,並將新的指標更新到訓練狀態中。這個過程涉及多個步驟,包括取得模型的輸出、使用softmax損失函式等。
取得模型輸出
模型的輸出是計算指標的基礎。透過模型的輸出,可以計算出各種指標,例如準確率、損失值等。
使用softmax損失函式
Optax函式庫中的softmax損失函式被用於計算模型的損失值。這個函式對於評估模型的效能至關重要。
計算中間指標
模型輸出字典被提供給計算中間指標的函式,以便計算出所有必要的指標。這些指標對於評估模型的效能和調整模型引數非常重要。
內容解密:
import optax
def calculate_metrics(train_state, model_output):
# 初始化指標集合
metrics = {}
# 取得模型輸出
output = model_output
# 使用softmax損失函式
loss_fn = optax.softmax_cross_entropy
# 計算損失值
loss = loss_fn(output)
# 更新指標集合
metrics['loss'] = loss
# 更新訓練狀態
train_state.metrics = metrics
return train_state
圖表翻譯:
flowchart TD
A[初始化指標] --> B[計算指標]
B --> C[取得模型輸出]
C --> D[使用softmax損失函式]
D --> E[計算中間指標]
E --> F[更新訓練狀態]
這個過程確保了模型的效能被準確地評估和跟蹤,為模型的調整和最佳化提供了基礎。
影像分類別使用ResNet
在深度學習中,影像分類別是一個基本任務,旨在根據影像的視覺特徵將其分類別到預先定義的類別中。ResNet(Residual Network)是一種成功的卷積神經網路架構,特別適合於影像分類別任務。
ResNet架構
ResNet的核心思想是引入殘差連線(residual connection),使得網路可以學習到更深層次的特徵表示。這種架構允許網路更容易地最佳化,並且可以達到更高的準確率。
殘差塊
ResNet中的殘差塊是其核心組成部分。每個殘差塊包含兩個卷積層和一個殘差連線。殘差連線允許網路學習到更抽象的特徵,並且可以減少梯度消失問題。
flowchart TD
A[輸入] --> B[卷積層1]
B --> C[啟用函式]
C --> D[卷積層2]
D --> E[啟用函式]
E --> F[殘差連線]
F --> G[輸出]
網路架構
ResNet的網路架構通常由多個殘差塊堆積疊而成。每個殘差塊都會學習到不同的特徵,並且可以根據需要堆積疊多個殘差塊以達到所需的深度。
flowchart TD
A[輸入] --> B[殘差塊1]
B --> C[殘差塊2]
C --> D[殘差塊3]
D --> E[全連線層]
E --> F[輸出]
訓練過程
在訓練ResNet時,我們需要更新網路的引數以最小化損失函式。這個過程通常涉及多個epoch,每個epoch都會對整個訓練集進行一次迭代。
更新引數
在每個epoch中,我們會計算損失函式並更新網路的引數以最小化損失。
for epoch in range(num_epochs):
for x, y in train_loader:
# 前向傳播
output = model(x)
loss = criterion(output, y)
# 後向傳播
optimizer.zero_grad()
loss.backward()
optimizer.step()
評估模型
在每個epoch結束後,我們會評估模型在驗證集上的效能,以確保模型沒有過度擬合。
for epoch in range(num_epochs):
# 訓練模型
for x, y in train_loader:
#...
# 評估模型
model.eval()
with torch.no_grad():
total_correct = 0
for x, y in val_loader:
output = model(x)
_, predicted = torch.max(output, 1)
total_correct += (predicted == y).sum().item()
accuracy = total_correct / len(val_loader.dataset)
print(f'Epoch {epoch+1}, Accuracy: {accuracy:.4f}')
圖表翻譯:
上述Mermaid圖表展示了ResNet的架構,包括殘差塊和網路架構。圖表中,每個盒子代表一個層或一個模組,箭頭代表了資料的流動方向。這種架構允許ResNet學習到更抽象的特徵,並且可以減少梯度消失問題。
影像分類別使用ResNet
在之前的章節中,我們使用Flax訓練了一個簡單的神經網路,但該模型仍然是一個非常基礎的多層感知器(MLP),距離目前的最佳實踐還有很大差距。下一節將實作更先進的殘差網路(ResNet)進行影像分類別。
11.2 影像分類別使用ResNet
您可能知道,使用MLP進行影像處理任務通常不是最佳選擇,因為存在更專業和高效的解決方案,如卷積神經網路(CNN)。CNN在多年來不斷演進,通常人們使用殘差網路(ResNets)。在這裡,我們將使用第9章中的Dogs vs. Cats資料集,並實作一個相對簡單的ResNet,這可能是此任務的合適選擇。
現代電腦視覺
典型的影像分類別問題的實際解決方案可能涉及CNN。CNN是一種適合用於影像的神經網路,具有卓越的性質,如平移等變性和學習區域性特徵的能力。CNN長期被認為是最好的神經網路型別,適合於影像處理任務。但是在近幾年,情況發生了變化。
從技術架構視角來看,本文深入探討了深度學習模型訓練的關鍵環節,涵蓋了最佳化器狀態管理、引數更新策略、訓練迴圈架構以及指標計算方法。分析比較了MLP和CNN等不同網路架構在影像分類別任務中的適用性,並闡述了ResNet的核心思想——殘差連線如何提升模型效能。然而,僅僅依靠ResNet架構並不能保證最佳效能,模型的超引數調整、資料增強策略以及訓練過程的監控同樣至關重要。展望未來,根據Transformer的模型在電腦視覺領域的應用日益廣泛,Vision Transformer等架構展現出巨大的潛力,可能在特定任務上超越CNN。對於追求高效能的影像分類別任務,技術團隊應密切關注Transformer模型的發展趨勢,並積極探索其與ResNet等CNN架構的融合方案。玄貓認為,深度學習模型的訓練是一個持續最佳化的過程,需要不斷探索新的架構和訓練策略,才能在特定任務上達到最佳效能。