大語言模型在自然語言處理領域展現強大能力,但應用於特定分類別任務時,需要進行微調以提升效能。本篇著重於實作層面,以程式碼示例說明如何調整模型以適應二元分類別任務。首先,我們需要理解模型的輸出張量結構,其中包含批次大小、輸入令牌數量和嵌入維度等資訊。接著,由於因果注意力機制的作用,我們可以提取最後一個輸出令牌的嵌入向量,它蘊含了前面所有令牌的上下文資訊,作為分類別依據。為了將其轉換為類別標籤,我們可以使用 softmax 函式電腦率分佈,並透過 argmax 函式取得最可能標籤。在評估模型效能時,我們可以計算訓練集、驗證集和測試集的準確率,並使用交叉熵損失函式作為最佳化目標。

微調大語言模型進行分類別任務

在前面的章節中,我們探討瞭如何對大語言模型(LLM)進行微調,以使其適應特定的分類別任務。在本章中,我們將探討如何實作這一目標。

瞭解模型的輸出

首先,讓我們瞭解模型的輸出。當我們將輸入傳遞給模型時,它會輸出一個張量,其中包含與輸入令牌相對應的嵌入向量。在我們的例子中,輸出張量的形狀是 [1, 4, 2],其中 1 表示批次大小,4 表示輸入令牌的數量,2 表示輸出嵌入的維度。

with torch.no_grad():
    outputs = model(inputs)
print("Outputs:\n", outputs)
print("Outputs dimensions:", outputs.shape)

輸出結果如下:

Outputs:
 tensor([[[-1.5854, 0.9904],
         [-3.7235, 7.4548],
         [-2.2661, 6.6049],
         [-3.5983, 3.9902]]])
Outputs dimensions: torch.Size([1, 4, 2])

內容解密:

  1. 輸出張量的第一維度代表批次大小,在本例中為 1,表示我們正在處理單個輸入序列。
  2. 第二維度代表輸入令牌的數量,本例中為 4,對應於輸入序列中的四個令牌。
  3. 第三維度代表輸出嵌入的維度,本例中為 2,這是因為我們修改了模型的輸出層,以使其適應二元分類別任務。

提取最後一個輸出令牌

在我們的分類別任務中,我們感興趣的是最後一個輸出令牌,因為它包含了前面所有令牌的資訊。我們可以使用以下程式碼提取最後一個輸出令牌:

print("Last output token:", outputs[:, -1, :])

輸出結果如下:

Last output token: tensor([[-3.5983, 3.9902]])

內容解密:

  1. outputs[:, -1, :] 這行程式碼提取了輸出張量中的最後一個令牌。
  2. -1 索引表示最後一個元素,因此 outputs[:, -1, :] 傳回了最後一個令牌的嵌入向量。

因果注意力機制的作用

因果注意力機制是一種特殊的注意力機制,它確保每個令牌只能關注其自身和之前的令牌,而不能關注之後的令牌。這種機制使得最後一個令牌能夠累積前面所有令牌的資訊,因此在我們的分類別任務中,我們重點關注最後一個令牌。

@startuml
skinparam backgroundColor #FEFEFE
skinparam defaultTextAlignment center
skinparam rectangleBackgroundColor #F5F5F5
skinparam rectangleBorderColor #333333
skinparam arrowColor #333333

title 因果注意力機制的作用

rectangle "因果注意力機制" as node1
rectangle "包含前面所有令牌的資訊" as node2

node1 --> node2

@enduml

此圖示說明瞭因果注意力機制如何使得最後一個令牌包含前面所有令牌的資訊,從而用於分類別任務。

將輸出轉換為類別標籤預測

為了將模型的輸出轉換為類別標籤預測,我們需要對輸出進行 softmax 處理,然後使用 argmax 函式取得最高機率的索引位置。

# 將輸出轉換為機率
probabilities = torch.softmax(outputs[:, -1, :], dim=1)
# 取得類別標籤預測
predicted_labels = torch.argmax(probabilities, dim=1)

內容解密:

  1. torch.softmax(outputs[:, -1, :], dim=1) 將最後一個令牌的輸出轉換為機率分佈。
  2. torch.argmax(probabilities, dim=1) 取得機率分佈中最高機率的索引位置,即類別標籤預測。

評估模型的效能

在微調模型之前,我們需要實作評估函式來計算模型的分類別損失和準確率。

def calculate_accuracy(logits, labels):
    probabilities = torch.softmax(logits, dim=1)
    predicted_labels = torch.argmax(probabilities, dim=1)
    accuracy = (predicted_labels == labels).sum().item() / len(labels)
    return accuracy

內容解密:

  1. 此函式計算模型的準確率,方法是比較預測標籤和真實標籤。
  2. torch.softmax(logits, dim=1) 將 logits 轉換為機率分佈。
  3. torch.argmax(probabilities, dim=1) 取得機率分佈中最高機率的索引位置,即預測標籤。

分類別任務中的模型微調

計算分類別準確率

在進行分類別任務時,評估模型的表現是非常重要的。分類別準確率是衡量模型預測正確率的指標。為了計算這個指標,我們需要對資料集中的所有樣本進行預測,並比較預測結果與真實標籤。

以下是一個具體的例子,展示瞭如何計算模型的最後一個輸出標記(token)的類別標籤:

print("Last output token:", outputs[:, -1, :])

假設輸出的張量(tensor)對應於最後一個標記的值為:

Last output token: tensor([[-3.5983, 3.9902]])

我們可以透過 softmax 函式將這些值轉換為機率,並獲得預測的類別標籤:

probas = torch.softmax(outputs[:, -1, :], dim=-1)
label = torch.argmax(probas)
print("Class label:", label.item())

由於最大的輸出值直接對應於最高的機率分數,因此這裡使用 softmax 函式是可選的。我們可以直接簡化程式碼而不使用 softmax:

logits = outputs[:, -1, :]
label = torch.argmax(logits)
print("Class label:", label.item())

內容解密:

  1. outputs[:, -1, :]: 這行程式碼提取了模型輸出的最後一個標記的 logits 值。
  2. torch.softmax(outputs[:, -1, :], dim=-1): 對最後一個標記的 logits 值應用 softmax 函式,將其轉換為機率分佈。
  3. torch.argmax(probas): 找到機率分佈中最大值的索引,即預測的類別標籤。
  4. 由於 logits 值已經代表了未歸一化的機率,因此可以直接使用 torch.argmax(logits) 獲得預測的類別標籤。

定義計算分類別準確率的函式

為了評估模型的表現,我們定義了一個名為 calc_accuracy_loader 的函式,用於計算資料載入器(data loader)中樣本的分類別準確率。

def calc_accuracy_loader(data_loader, model, device, num_batches=None):
    model.eval()
    correct_predictions, num_examples = 0, 0
    if num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            input_batch = input_batch.to(device)
            target_batch = target_batch.to(device)
            with torch.no_grad():
                logits = model(input_batch)[:, -1, :]
                predicted_labels = torch.argmax(logits, dim=-1)
                num_examples += predicted_labels.shape[0]
                correct_predictions += (predicted_labels == target_batch).sum().item()
        else:
            break
    return correct_predictions / num_examples

內容解密:

  1. model.eval(): 將模型設定為評估模式,關閉 dropout 和 BatchNorm 的訓練特定行為。
  2. correct_predictionsnum_examples: 用於跟蹤正確預測的數量和總樣本數。
  3. logits = model(input_batch)[:, -1, :]: 取得模型對輸入批次的輸出,並提取最後一個標記的 logits 值。
  4. predicted_labels = torch.argmax(logits, dim=-1): 計算預測的類別標籤。
  5. correct_predictions += (predicted_labels == target_batch).sum().item(): 累積正確預測的數量。

計算訓練、驗證和測試集的準確率

使用定義好的 calc_accuracy_loader 函式,我們可以計算模型在訓練集、驗證集和測試集上的準確率。

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
torch.manual_seed(123)
train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=10)
val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=10)
test_accuracy = calc_accuracy_loader(test_loader, model, device, num_batches=10)
print(f"Training accuracy: {train_accuracy*100:.2f}%")
print(f"Validation accuracy: {val_accuracy*100:.2f}%")
print(f"Test accuracy: {test_accuracy*100:.2f}%")

輸出結果可能類別似如下:

Training accuracy: 46.25%
Validation accuracy: 45.00%
Test accuracy: 48.75%

這些結果表明模型的預測準確率接近隨機猜測(本例中為 50%)。為了提高預測準確率,我們需要對模型進行微調。

定義損失函式

為了最佳化模型的表現,我們需要定義一個損失函式。在分類別任務中,交叉熵損失(cross-entropy loss)是一個常用的選擇。

def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch = input_batch.to(device)
    target_batch = target_batch.to(device)
    logits = model(input_batch)[:, -1, :]
    loss = torch.nn.functional.cross_entropy(logits, target_batch)
    return loss

內容解密:

  1. logits = model(input_batch)[:, -1, :]: 取得模型對輸入批次的輸出,並提取最後一個標記的 logits 值。
  2. torch.nn.functional.cross_entropy(logits, target_batch): 計算 logits 和真實標籤之間的交叉熵損失。

計算資料載入器的損失

為了計算整個資料載入器的平均損失,我們定義了 calc_loss_loader 函式。

def calc_loss_loader(data_loader, model, device, num_batches=None):
    total_loss = 0.
    if len(data_loader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches

內容解密:

  1. total_loss: 用於累積所有批次的損失總和。
  2. loss = calc_loss_batch(input_batch, target_batch, model, device): 計算每個批次的損失。
  3. return total_loss / num_batches: 傳回平均損失。

微調模型

在定義了損失函式之後,我們可以開始微調模型,以最小化訓練集上的損失,從而提高分類別準確率。