多頭注意力機制允許模型從不同角度理解輸入資料,捕捉更豐富的語義資訊。本文從單頭注意力出發,逐步解釋如何建構多頭注意力模組,並使用 PyTorch 框架提供具體的程式碼實作。其中,MultiHeadAttentionWrapper 類別以模組化的方式封裝了多個單頭注意力例項,而 MultiHeadAttention 類別則整合了權重計算和注意力計算,提供更精簡的實作。程式碼中使用張量操作如 .view().transpose() 進行資料的重塑和維度調整,並透過矩陣乘法高效地計算注意力權重和上下文向量。最後,文章也討論瞭如何調整注意力頭數、改進計算效率以及結合其他技術來最佳化多頭注意力機制。

多頭注意力機制的實作與擴充套件

在前面的章節中,我們探討了因果注意力(causal attention)的概念和實作,並在神經網路中實作了該機制。接下來,我們將進一步擴充套件這一概念,實作多頭注意力(multi-head attention)模組,該模組能夠平行實作多個因果注意力機制。

從單頭注意力擴充套件到多頭注意力

我們的最終目標是將先前實作的因果注意力類別擴充套件到多個頭部,這也就是所謂的多頭注意力。多頭注意力是指將注意力機制分成多個「頭」,每個頭獨立運作。在這種情況下,單一的因果注意力模組可以被視為單頭注意力,其中只有一組注意力權重按順序處理輸入。

堆積疊多個單頭注意力層

在實際操作中,實作多頭注意力涉及建立多個自注意力機制(self-attention mechanism)的例項(如圖3.18所示),每個例項都有自己的權重,然後合併它們的輸出。使用多個自注意力機制的例項可能會帶來較高的計算強度,但對於像根據Transformer的大語言模型(LLM)這樣的複雜模式識別至關重要。

多頭注意力模組的結構

圖3.24展示了多頭注意力模組的結構,該模組由多個單頭注意力模組堆積疊而成。如前所述,多頭注意力背後的主要思想是透過不同的學習線性投影(即將輸入資料(如注意力機制中的查詢、鍵和值向量)乘以權重矩陣)多次(平行)執行注意力機制。

class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList([
            CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
            for _ in range(num_heads)
        ])

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)

多頭注意力的運作範例

例如,如果我們使用具有兩個注意力頭(num_heads=2)的MultiHeadAttentionWrapper類別,並將CausalAttention的輸出維度設為2,那麼我們將獲得一個四維的上下文向量(d_out*num_heads=4),如圖3.25所示。

torch.manual_seed(123)
context_length = batch.shape[1]  # 這是 token 的數量
d_in, d_out = 3, 2

內容解密:

  1. MultiHeadAttentionWrapper 類別初始化:在初始化函式中,我們建立了一個 nn.ModuleList,其中包含了 num_headsCausalAttention 例項。這使得我們能夠平行執行多個注意力機制。
  2. forward 方法:在 forward 方法中,我們將輸入 x 分別傳遞給每個 CausalAttention 頭,然後將所有頭的輸出沿著最後一個維度(dim=-1)進行拼接,得到最終的輸出。
  3. 多頭注意力的優點:透過使用多個注意力頭,我們的模型能夠捕捉到輸入資料的不同方面的資訊,從而提高模型的表達能力和泛化能力。

3.6 擴充套件單頭注意力機制至多頭注意力機制

在前面的章節中,我們實作了一個單頭注意力機制(single-head attention)。現在,我們將擴充套件這個概念到多頭注意力機制(multi-head attention)。

3.6.1 使用多頭注意力包裝器實作多頭注意力

首先,我們實作了一個 MultiHeadAttentionWrapper 類別,將多個單頭注意力模組結合起來。這個類別的實作如下:

mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

輸出結果如下:

tensor([[[-0.4519, 0.2216, 0.4772, 0.1063],
[-0.5874, 0.0058, 0.5891, 0.3257],
[-0.6300, -0.0632, 0.6202, 0.3860],
[-0.5675, -0.0843, 0.5478, 0.3589],
[-0.5526, -0.0981, 0.5321, 0.3428],
[-0.5299, -0.1081, 0.5077, 0.3493]],
[[-0.4519, 0.2216, 0.4772, 0.1063],
[-0.5874, 0.0058, 0.5891, 0.3257],
[-0.6300, -0.0632, 0.6202, 0.3860],
[-0.5675, -0.0843, 0.5478, 0.3589],
[-0.5526, -0.0981, 0.5321, 0.3428],
[-0.5299, -0.1081, 0.5077, 0.3493]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 4])

內容解密:

  1. 輸出結果解析:輸出結果是一個張量(tensor),代表上下文向量(context vectors)。第一維度代表輸入文字的數量(本例中為2),第二維度代表每個輸入文字中的 token 數量(本例中為6),第三維度代表每個 token 的嵌入維度(本例中為4)。
  2. 形狀分析context_vecs.shape 的輸出結果表明,輸出張量的形狀為 (2, 6, 4),對應於批次大小、序列長度和嵌入維度。
  3. 多頭注意力的作用:多頭注意力機制允許多個注意力頭平行處理輸入序列,從而捕捉不同的語義資訊。

練習3.2:傳回二維嵌入向量

修改 MultiHeadAttentionWrapper 的輸入引數,使得輸出上下文向量的維度變為2,而不是4,同時保持 num_heads=2 的設定。

解題思路:

  • 無需修改類別實作,只需更改輸入引數。
  • 調整 d_out 的值,使其能夠被 num_heads 整除,並且最終輸出維度為2。

3.6.2 使用權重分割實作多頭注意力

除了使用 MultiHeadAttentionWrapper 外,我們還可以透過合併 CausalAttentionMultiHeadAttentionWrapper 的功能,實作一個更高效的多頭注意力類別:MultiHeadAttention

MultiHeadAttention類別實作

class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        attn_scores = queries @ keys.transpose(2, 3)
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = (attn_weights @ values).transpose(1, 2)
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)
        return context_vec

內容解密:

  1. MultiHeadAttention類別概述:該類別整合了多頭注意力機制,透過對輸入進行線性變換並分割成多個頭來計算注意力。
  2. 張量操作解析
    • 使用 .view() 方法將查詢、鍵和值的張量重塑,以表示多個頭。
    • 使用 .transpose() 方法調整張量的維度順序,以便進行後續的矩陣乘法運算。
  3. 注意力權重計算:透過查詢和鍵的點積計算注意力分數,並使用掩碼和 softmax 函式得到注意力權重。
  4. 上下文向量計算:將注意力權重與值相乘,得到上下文向量,並透過線性層進行最終的投影變換。

圖示說明

此圖示展示了 MultiHeadAttention 的內部運作機制:

@startuml
skinparam backgroundColor #FEFEFE
skinparam defaultTextAlignment center
skinparam rectangleBackgroundColor #F5F5F5
skinparam rectangleBorderColor #333333
skinparam arrowColor #333333

title 圖示說明

rectangle "線性變換" as node1
rectangle "重塑與轉置" as node2
rectangle "計算注意力" as node3
rectangle "與值相乘" as node4
rectangle "合併與投影" as node5

node1 --> node2
node2 --> node3
node3 --> node4
node4 --> node5

@enduml

圖示解析:

  • 輸入張量經過線性變換得到查詢、鍵和值張量。
  • 重塑與轉置操作將張量調整為適合多頭注意力的形狀。
  • 計算注意力步驟透過查詢和鍵的互動得到注意力權重。
  • 上下文向量由注意力權重和值計算而得,最終經過合併和投影得到輸出。

多頭注意力機制的實作與最佳化

在深度學習的自然語言處理領域中,注意力機制(Attention Mechanism)扮演著至關重要的角色,特別是在 Transformer 架構的模型中。本篇文章將探討多頭注意力(Multi-Head Attention)的實作細節及其最佳化方法。

多頭注意力的基本概念

多頭注意力是注意力機制的一種擴充套件,它允許模型同時關注輸入序列的不同部分,從而捕捉更豐富的上下文資訊。這種機制透過將輸入序列對映到多個不同的表示空間(或稱為“頭”),並在每個空間中計算注意力權重,最終結合所有頭的輸出來得到最終的表示。

多頭注意力的實作

在 PyTorch 中實作多頭注意力涉及以下幾個關鍵步驟:

  1. 初始化權重矩陣:為每個頭初始化獨立的權重矩陣,用於將輸入序列轉換為查詢(Query)、鍵(Key)和值(Value)向量。
  2. 計算查詢、鍵和值向量:透過矩陣乘法將輸入序列與權重矩陣相乘,得到查詢、鍵和值向量。
  3. 轉置和重塑張量:將得到的向量轉置和重塑,以便進行批次矩陣乘法。
  4. 計算注意力權重:透過查詢和鍵向量的點積計算注意力權重,並進行縮放和 softmax 操作。
  5. 計算上下文向量:將注意力權重與值向量相乘,得到上下文向量。
  6. 合併多頭輸出:將所有頭的上下文向量合併,並透過一個輸出投影層得到最終的表示。

程式碼範例

import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads):
        super().__init__()
        # 初始化權重矩陣
        self.W_query = nn.Linear(d_in, d_out)
        self.W_key = nn.Linear(d_in, d_out)
        self.W_value = nn.Linear(d_in, d_out)
        self.out_proj = nn.Linear(d_out, d_out)
        # 其他初始化...

    def forward(self, x):
        # 計算查詢、鍵和值向量
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)
        
        # 轉置和重塑張量
        queries = queries.view(-1, self.num_heads, self.context_length, self.head_dim)
        keys = keys.view(-1, self.num_heads, self.context_length, self.head_dim)
        values = values.view(-1, self.num_heads, self.context_length, self.head_dim)
        
        # 計算注意力權重和上下文向量
        attention_weights = torch.matmul(queries, keys.transpose(-1, -2)) / math.sqrt(self.head_dim)
        attention_weights = nn.functional.softmax(attention_weights, dim=-1)
        context_vectors = torch.matmul(attention_weights, values)
        
        # 合併多頭輸出
        context_vectors = context_vectors.view(-1, self.context_length, self.d_out)
        output = self.out_proj(context_vectors)
        return output

# 初始化模型並進行前向傳播
torch.manual_seed(123)
batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
print(context_vecs.shape)

內容解密:

  1. 多頭注意力的實作:上述程式碼展示瞭如何在 PyTorch 中實作多頭注意力機制。主要步驟包括初始化權重矩陣、計算查詢/鍵/值向量、轉置和重塑張量、計算注意力權重和上下文向量,以及合併多頭輸出。
  2. 批次矩陣乘法:在計算注意力權重時,使用了批次矩陣乘法,這使得模型能夠高效地處理多個頭和多個輸入序列。
  3. 輸出投影層:在合併多頭輸出後,使用了一個輸出投影層來得到最終的表示。這一步驟是可選的,但在許多大語言模型中被廣泛採用。

最佳化和應用

多頭注意力機制由於其能夠捕捉輸入序列中複雜的依賴關係,因此在許多自然語言處理任務中取得了巨大的成功。為了進一步最佳化多頭注意力的效能,可以考慮以下幾個方向:

  • 動態調整注意力頭的數量:根據輸入序列的複雜度和任務需求動態調整注意力頭的數量,可以提高模型的效率和效能。
  • 改進注意力機制的計算效率:研究更高效的注意力機制計算方法,如稀疏注意力機制,可以減少模型的計算開銷。
  • 結合其他技術提升模型效能:將多頭注意力與其他先進技術(如相對位置編碼、層歸一化等)相結合,可以進一步提升模型的效能。