大型語言模型的預訓練是現代自然語言處理領域最重要的技術突破之一。這種技術使模型能夠從大規模無標籤文字資料中學習語言的內在結構和語義知識,為後續的各種下游任務奠定堅實基礎。預訓練過程涉及資料準備、模型架構設計、損失函式計算和訓練策略最佳化等多個關鍵環節。本文將深入探討這些技術細節,透過完整的程式碼範例展示如何使用 PyTorch 實作大型語言模型的預訓練流程。

大型語言模型預訓練的基本概念

大型語言模型的預訓練本質上是一個自監督學習過程。模型透過預測序列中的下一個詞彙來學習語言模式,這種任務被稱為語言模型建模或因果語言模型建模。與傳統的監督學習不同,預訓練不需要人工標註的標籤,而是直接利用文字資料本身作為訓練信號。

在預訓練過程中,模型接收一段文字作為輸入,並嘗試預測每個位置的下一個詞彙。透過比較預測結果與實際詞彙,模型計算損失並調整參數。經過大量資料的訓練後,模型能夠學習到豐富的語言知識,包括語法結構、語義關係、世界知識甚至推理能力。

預訓練大型語言模型需要巨大的計算資源。以 Llama 2 為例,其訓練成本高達約 69 萬美元,使用了數千個 GPU 運算數週才完成。這種龐大的資源需求凸顯了最佳化訓練流程和提高效率的重要性。

以下是大型語言模型預訓練流程的架構圖:

@startuml
!define PLANTUML_FORMAT svg
!theme _none_

skinparam dpi auto
skinparam shadowing false
skinparam linetype ortho
skinparam roundcorner 5
skinparam defaultFontName "Microsoft JhengHei UI"
skinparam defaultFontSize 16
skinparam minClassWidth 100

rectangle "LLM 預訓練流程" {
    rectangle "原始文字資料" as raw
    rectangle "分詞與編碼" as tokenize
    rectangle "資料分割" as split
    rectangle "批次載入" as load
    rectangle "模型前向傳播" as forward
    rectangle "損失計算" as loss
    rectangle "反向傳播" as backward
    rectangle "參數更新" as update
}

raw --> tokenize : 文字處理
tokenize --> split : token 序列
split --> load : 訓練/驗證集
load --> forward : 批次資料
forward --> loss : logits
loss --> backward : 交叉熵
backward --> update : 梯度
update --> forward : 迭代

@enduml

資料準備與分割策略

高品質的資料準備是預訓練成功的基礎。在開始訓練之前,我們需要將原始文字資料轉換為模型能夠處理的格式,並將其分割為訓練集和驗證集。驗證集用於監控模型在訓練過程中的泛化能力,幫助我們識別過擬合問題。

資料分割的比例通常為 90% 用於訓練,10% 用於驗證。這種分割確保了模型有足夠的資料進行學習,同時保留了獨立的資料集來評估模型表現。對於大規模預訓練,驗證集的大小可能會更小,因為訓練資料已經足夠多樣化。

以下是完整的資料準備和分割系統實作:

# 大型語言模型資料準備系統
# 實作文字載入、分詞、分割和批次處理功能

import torch
from torch.utils.data import Dataset, DataLoader
from typing import List, Tuple, Optional, Dict
import tiktoken
from dataclasses import dataclass

@dataclass
class DataConfig:
    """
    資料配置參數

    儲存資料處理相關的各種參數設定
    """
    train_ratio: float = 0.9        # 訓練集比例
    batch_size: int = 4             # 批次大小
    max_length: int = 256           # 最大序列長度
    stride: int = 128               # 滑動視窗步長
    shuffle: bool = True            # 是否打亂訓練資料
    num_workers: int = 0            # 資料載入工作執行緒數

class TextDataset(Dataset):
    """
    文字資料集類別

    將分詞後的文字資料封裝為 PyTorch Dataset,
    支援滑動視窗方式建立輸入輸出對
    """

    def __init__(
        self,
        token_ids: List[int],
        max_length: int,
        stride: int
    ):
        """
        初始化文字資料集

        Args:
            token_ids: 分詞後的 token ID 列表
            max_length: 每個樣本的最大序列長度
            stride: 滑動視窗步長
        """
        self.token_ids = token_ids
        self.max_length = max_length
        self.stride = stride

        # 計算樣本數量
        # 每個樣本包含 max_length 個輸入 token
        # 和 max_length 個目標 token(向右偏移一位)
        self.num_samples = max(0, (len(token_ids) - max_length) // stride)

    def __len__(self) -> int:
        """回傳資料集大小"""
        return self.num_samples

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        取得指定索引的樣本

        Args:
            idx: 樣本索引

        Returns:
            輸入 tensor 和目標 tensor 的元組
        """
        # 計算起始位置
        start_pos = idx * self.stride

        # 建立輸入序列(長度為 max_length)
        input_ids = self.token_ids[start_pos:start_pos + self.max_length]

        # 建立目標序列(向右偏移一位)
        target_ids = self.token_ids[start_pos + 1:start_pos + self.max_length + 1]

        # 轉換為 tensor
        input_tensor = torch.tensor(input_ids, dtype=torch.long)
        target_tensor = torch.tensor(target_ids, dtype=torch.long)

        return input_tensor, target_tensor

class LLMDataPreparer:
    """
    大型語言模型資料準備器

    負責文字載入、分詞、分割和建立資料載入器
    """

    def __init__(self, tokenizer_name: str = "gpt2"):
        """
        初始化資料準備器

        Args:
            tokenizer_name: 分詞器名稱
        """
        # 載入分詞器
        # tiktoken 是 OpenAI 開發的高效分詞器
        self.tokenizer = tiktoken.get_encoding(tokenizer_name)
        print(f"已載入分詞器: {tokenizer_name}")
        print(f"詞彙表大小: {self.tokenizer.n_vocab}")

    def load_text(self, file_path: str) -> str:
        """
        載入文字檔案

        Args:
            file_path: 文字檔案路徑

        Returns:
            文字內容字串
        """
        with open(file_path, "r", encoding="utf-8") as f:
            text = f.read()

        # 顯示資料統計
        print(f"載入檔案: {file_path}")
        print(f"字元數: {len(text)}")
        print(f"預估 token 數: {len(self.tokenizer.encode(text))}")

        return text

    def tokenize(self, text: str) -> List[int]:
        """
        將文字轉換為 token ID

        Args:
            text: 輸入文字

        Returns:
            token ID 列表
        """
        token_ids = self.tokenizer.encode(text)
        return token_ids

    def decode(self, token_ids: List[int]) -> str:
        """
        將 token ID 轉換回文字

        Args:
            token_ids: token ID 列表

        Returns:
            解碼後的文字
        """
        return self.tokenizer.decode(token_ids)

    def split_data(
        self,
        token_ids: List[int],
        train_ratio: float = 0.9
    ) -> Tuple[List[int], List[int]]:
        """
        分割資料為訓練集和驗證集

        Args:
            token_ids: 完整的 token ID 列表
            train_ratio: 訓練集比例

        Returns:
            訓練集和驗證集 token ID 的元組
        """
        # 計算分割點
        split_idx = int(len(token_ids) * train_ratio)

        # 分割資料
        train_ids = token_ids[:split_idx]
        val_ids = token_ids[split_idx:]

        print(f"訓練集 token 數: {len(train_ids)}")
        print(f"驗證集 token 數: {len(val_ids)}")

        return train_ids, val_ids

    def create_dataloaders(
        self,
        text: str,
        config: DataConfig
    ) -> Tuple[DataLoader, DataLoader]:
        """
        建立訓練和驗證資料載入器

        Args:
            text: 原始文字
            config: 資料配置參數

        Returns:
            訓練和驗證資料載入器的元組
        """
        # 分詞
        token_ids = self.tokenize(text)

        # 分割資料
        train_ids, val_ids = self.split_data(token_ids, config.train_ratio)

        # 建立資料集
        train_dataset = TextDataset(
            train_ids,
            config.max_length,
            config.stride
        )
        val_dataset = TextDataset(
            val_ids,
            config.max_length,
            config.stride
        )

        print(f"訓練樣本數: {len(train_dataset)}")
        print(f"驗證樣本數: {len(val_dataset)}")

        # 建立資料載入器
        train_loader = DataLoader(
            train_dataset,
            batch_size=config.batch_size,
            shuffle=config.shuffle,
            num_workers=config.num_workers,
            drop_last=True  # 丟棄不完整的最後一個批次
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=config.batch_size,
            shuffle=False,  # 驗證集不需要打亂
            num_workers=config.num_workers,
            drop_last=False
        )

        print(f"訓練批次數: {len(train_loader)}")
        print(f"驗證批次數: {len(val_loader)}")

        return train_loader, val_loader

    def prepare_from_file(
        self,
        file_path: str,
        config: DataConfig
    ) -> Tuple[DataLoader, DataLoader]:
        """
        從檔案準備資料載入器

        這是一個便利方法,整合載入和處理流程

        Args:
            file_path: 文字檔案路徑
            config: 資料配置參數

        Returns:
            訓練和驗證資料載入器的元組
        """
        # 載入文字
        text = self.load_text(file_path)

        # 建立資料載入器
        return self.create_dataloaders(text, config)

# 使用範例
if __name__ == "__main__":
    # 建立資料準備器
    preparer = LLMDataPreparer()

    # 設定資料配置
    config = DataConfig(
        train_ratio=0.9,
        batch_size=4,
        max_length=256,
        stride=128
    )

    # 準備範例文字
    sample_text = """
    Large language models have revolutionized natural language processing.
    They learn from vast amounts of text data without explicit labels.
    The pretraining process involves predicting the next token in a sequence.
    This self-supervised learning approach enables models to acquire rich
    linguistic knowledge and world understanding.
    """ * 100  # 重複以產生足夠的訓練資料

    # 建立資料載入器
    train_loader, val_loader = preparer.create_dataloaders(sample_text, config)

    # 檢查批次資料
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        print(f"\n批次 {batch_idx}:")
        print(f"  輸入形狀: {inputs.shape}")
        print(f"  目標形狀: {targets.shape}")

        # 只顯示前兩個批次
        if batch_idx >= 1:
            break

    # 顯示解碼後的文字
    print("\n解碼範例:")
    sample_input = inputs[0].tolist()
    decoded = preparer.decode(sample_input[:50])
    print(f"  {decoded}...")

這個資料準備系統實作了從原始文字到可用於訓練的批次資料的完整流程。滑動視窗機制確保了文字資料被充分利用,每個位置都會作為多個樣本的一部分參與訓練。這種設計提高了資料效率,特別適合處理有限的訓練資料。

交叉熵損失與困惑度計算

損失函式是訓練大型語言模型的核心組件,它量化了模型預測與實際目標之間的差距。對於語言模型,我們通常使用交叉熵損失函式,它衡量預測機率分布與真實分布之間的差異。

交叉熵損失的計算過程涉及幾個步驟。首先,模型輸出每個位置的 logits,這是對所有可能詞彙的未正規化分數。然後,透過 softmax 函式將 logits 轉換為機率分布。最後,計算真實詞彙的負對數機率並取平均。

困惑度是交叉熵損失的指數形式,它提供了更直觀的解釋。困惑度可以理解為模型在預測下一個詞彙時面臨的「有效選擇數」。例如,困惑度為 100 表示模型平均需要在 100 個等可能的候選中進行選擇。困惑度越低,表示模型的預測越確定,性能越好。

以下是損失計算系統的完整實作:

# 大型語言模型損失計算系統
# 實作交叉熵損失和困惑度計算

import torch
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass

@dataclass
class LossStatistics:
    """
    損失統計資料

    儲存訓練過程中的各種損失指標
    """
    loss: float                 # 交叉熵損失
    perplexity: float          # 困惑度
    num_tokens: int            # 處理的 token 數
    num_correct: int           # 正確預測數(用於 accuracy)

class LossCalculator:
    """
    損失計算器

    提供各種損失計算和評估功能
    """

    def __init__(self, vocab_size: int, pad_token_id: Optional[int] = None):
        """
        初始化損失計算器

        Args:
            vocab_size: 詞彙表大小
            pad_token_id: 填充 token 的 ID(用於遮罩)
        """
        self.vocab_size = vocab_size
        self.pad_token_id = pad_token_id

    def compute_batch_loss(
        self,
        logits: torch.Tensor,
        targets: torch.Tensor,
        reduction: str = "mean"
    ) -> torch.Tensor:
        """
        計算單一批次的交叉熵損失

        Args:
            logits: 模型輸出的 logits
                   形狀: (batch_size, seq_length, vocab_size)
            targets: 目標 token ID
                    形狀: (batch_size, seq_length)
            reduction: 損失縮減方式("mean", "sum", "none")

        Returns:
            計算後的損失 tensor
        """
        # 取得維度
        batch_size, seq_length, vocab_size = logits.shape

        # 將 logits 扁平化以符合 cross_entropy 的輸入格式
        # 從 (batch_size, seq_length, vocab_size)
        # 變為 (batch_size * seq_length, vocab_size)
        logits_flat = logits.view(-1, vocab_size)

        # 將 targets 扁平化
        # 從 (batch_size, seq_length) 變為 (batch_size * seq_length)
        targets_flat = targets.view(-1)

        # 計算交叉熵損失
        # cross_entropy 內部會自動套用 log_softmax
        loss = F.cross_entropy(
            logits_flat,
            targets_flat,
            reduction=reduction,
            ignore_index=self.pad_token_id if self.pad_token_id is not None else -100
        )

        return loss

    def compute_perplexity(self, loss: torch.Tensor) -> torch.Tensor:
        """
        計算困惑度

        困惑度 = exp(交叉熵損失)

        Args:
            loss: 交叉熵損失

        Returns:
            困惑度值
        """
        return torch.exp(loss)

    def compute_token_probabilities(
        self,
        logits: torch.Tensor,
        targets: torch.Tensor
    ) -> torch.Tensor:
        """
        計算目標 token 的機率

        Args:
            logits: 模型輸出的 logits
            targets: 目標 token ID

        Returns:
            每個位置目標 token 的機率
        """
        # 套用 softmax 取得機率分布
        probabilities = F.softmax(logits, dim=-1)

        # 取得目標 token 的機率
        # 使用 gather 從機率分布中選取目標 token 的機率
        batch_size, seq_length, _ = probabilities.shape
        target_probs = probabilities.gather(
            dim=-1,
            index=targets.unsqueeze(-1)
        ).squeeze(-1)

        return target_probs

    def compute_accuracy(
        self,
        logits: torch.Tensor,
        targets: torch.Tensor
    ) -> float:
        """
        計算預測準確率

        Args:
            logits: 模型輸出的 logits
            targets: 目標 token ID

        Returns:
            準確率(0 到 1 之間)
        """
        # 取得預測結果(機率最高的 token)
        predictions = logits.argmax(dim=-1)

        # 計算正確預測數
        correct = (predictions == targets).sum().item()
        total = targets.numel()

        return correct / total

    def compute_detailed_statistics(
        self,
        logits: torch.Tensor,
        targets: torch.Tensor
    ) -> LossStatistics:
        """
        計算詳細的損失統計

        Args:
            logits: 模型輸出的 logits
            targets: 目標 token ID

        Returns:
            LossStatistics 物件
        """
        # 計算損失
        loss = self.compute_batch_loss(logits, targets)

        # 計算困惑度
        perplexity = self.compute_perplexity(loss)

        # 計算準確率
        predictions = logits.argmax(dim=-1)
        num_correct = (predictions == targets).sum().item()
        num_tokens = targets.numel()

        return LossStatistics(
            loss=loss.item(),
            perplexity=perplexity.item(),
            num_tokens=num_tokens,
            num_correct=num_correct
        )

class DataLoaderLossEvaluator:
    """
    資料載入器損失評估器

    在整個資料集上評估模型損失
    """

    def __init__(self, loss_calculator: LossCalculator):
        """
        初始化評估器

        Args:
            loss_calculator: 損失計算器實例
        """
        self.loss_calculator = loss_calculator

    def evaluate(
        self,
        model: torch.nn.Module,
        data_loader: torch.utils.data.DataLoader,
        device: torch.device,
        num_batches: Optional[int] = None
    ) -> Dict[str, float]:
        """
        在資料載入器上評估模型

        Args:
            model: 要評估的模型
            data_loader: 資料載入器
            device: 計算裝置
            num_batches: 要評估的批次數(None 表示全部)

        Returns:
            包含各種指標的字典
        """
        # 設定模型為評估模式
        model.eval()

        # 初始化累積變數
        total_loss = 0.0
        total_tokens = 0
        total_correct = 0

        # 決定要評估的批次數
        if num_batches is None:
            num_batches = len(data_loader)
        else:
            num_batches = min(num_batches, len(data_loader))

        # 如果資料載入器為空,回傳 NaN
        if len(data_loader) == 0:
            return {
                "loss": float("nan"),
                "perplexity": float("nan"),
                "accuracy": float("nan")
            }

        # 禁用梯度計算以提高效率
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(data_loader):
                if batch_idx >= num_batches:
                    break

                # 將資料移到指定裝置
                inputs = inputs.to(device)
                targets = targets.to(device)

                # 前向傳播
                logits = model(inputs)

                # 計算損失
                batch_loss = self.loss_calculator.compute_batch_loss(
                    logits, targets, reduction="sum"
                )

                # 計算正確預測數
                predictions = logits.argmax(dim=-1)
                batch_correct = (predictions == targets).sum().item()

                # 累積統計
                total_loss += batch_loss.item()
                total_tokens += targets.numel()
                total_correct += batch_correct

        # 計算平均指標
        avg_loss = total_loss / total_tokens
        perplexity = torch.exp(torch.tensor(avg_loss)).item()
        accuracy = total_correct / total_tokens

        # 恢復模型為訓練模式
        model.train()

        return {
            "loss": avg_loss,
            "perplexity": perplexity,
            "accuracy": accuracy,
            "num_tokens": total_tokens
        }

def demonstrate_loss_calculation():
    """
    展示損失計算的完整流程
    """
    # 設定參數
    batch_size = 2
    seq_length = 5
    vocab_size = 1000

    # 建立模擬資料
    # logits: 模型對每個位置每個詞彙的分數
    logits = torch.randn(batch_size, seq_length, vocab_size)

    # targets: 每個位置的正確 token ID
    targets = torch.randint(0, vocab_size, (batch_size, seq_length))

    print("輸入資料形狀:")
    print(f"  Logits: {logits.shape}")
    print(f"  Targets: {targets.shape}")

    # 建立損失計算器
    calculator = LossCalculator(vocab_size)

    # 計算基本損失
    loss = calculator.compute_batch_loss(logits, targets)
    print(f"\n交叉熵損失: {loss.item():.4f}")

    # 計算困惑度
    perplexity = calculator.compute_perplexity(loss)
    print(f"困惑度: {perplexity.item():.4f}")

    # 計算目標 token 機率
    target_probs = calculator.compute_token_probabilities(logits, targets)
    print(f"\n目標 token 機率:")
    print(f"  形狀: {target_probs.shape}")
    print(f"  第一個樣本: {target_probs[0].tolist()}")

    # 計算準確率
    accuracy = calculator.compute_accuracy(logits, targets)
    print(f"\n預測準確率: {accuracy:.4f}")

    # 計算詳細統計
    stats = calculator.compute_detailed_statistics(logits, targets)
    print(f"\n詳細統計:")
    print(f"  損失: {stats.loss:.4f}")
    print(f"  困惑度: {stats.perplexity:.4f}")
    print(f"  Token 數: {stats.num_tokens}")
    print(f"  正確預測數: {stats.num_correct}")

    # 展示損失計算的數學原理
    print("\n\n損失計算數學原理展示:")
    print("=" * 50)

    # 選取一個位置進行詳細說明
    sample_logits = logits[0, 0]  # 第一個樣本,第一個位置
    target_id = targets[0, 0].item()

    # 步驟 1: 計算 softmax 機率
    probs = F.softmax(sample_logits, dim=0)
    target_prob = probs[target_id].item()
    print(f"\n步驟 1 - Softmax 機率:")
    print(f"  目標 token ID: {target_id}")
    print(f"  目標 token 機率: {target_prob:.6f}")

    # 步驟 2: 計算負對數機率
    neg_log_prob = -torch.log(probs[target_id]).item()
    print(f"\n步驟 2 - 負對數機率:")
    print(f"  -log(p) = {neg_log_prob:.4f}")

    # 步驟 3: 驗證與 cross_entropy 結果一致
    ce_loss = F.cross_entropy(
        sample_logits.unsqueeze(0),
        torch.tensor([target_id])
    )
    print(f"\n步驟 3 - 驗證:")
    print(f"  手動計算: {neg_log_prob:.4f}")
    print(f"  cross_entropy: {ce_loss.item():.4f}")

# 執行展示
if __name__ == "__main__":
    demonstrate_loss_calculation()

這個損失計算系統展示了交叉熵損失的完整計算流程。理解損失函式的數學原理對於除錯和最佳化模型至關重要。困惑度作為損失的指數形式,提供了更直觀的評估指標。

完整訓練迴圈實作

訓練迴圈是將所有組件整合在一起的核心部分。一個典型的訓練迴圈包含多個 epoch 的迭代,每個 epoch 會遍歷整個訓練資料集。在每個批次中,模型執行前向傳播計算預測,然後計算損失,接著執行反向傳播計算梯度,最後更新模型參數。

有效的訓練迴圈還需要包含監控和評估功能。定期在驗證集上評估模型可以幫助我們追蹤訓練進度並識別過擬合問題。生成樣本文字則提供了模型能力的直觀展示,讓我們能夠觀察模型學習到的語言模式。

以下是完整的訓練系統實作:

# 大型語言模型訓練系統
# 實作完整的訓練迴圈和監控功能

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from typing import List, Dict, Tuple, Optional, Callable
from dataclasses import dataclass, field
import time
from datetime import timedelta

@dataclass
class TrainingConfig:
    """
    訓練配置參數

    儲存訓練過程的各種超參數
    """
    num_epochs: int = 10            # 訓練 epoch 數
    learning_rate: float = 5e-4     # 學習率
    weight_decay: float = 0.1       # 權重衰減
    eval_freq: int = 100            # 評估頻率(步數)
    eval_iters: int = 10            # 評估時的批次數
    generate_freq: int = 500        # 生成樣本頻率
    max_grad_norm: float = 1.0      # 梯度裁剪閾值
    warmup_steps: int = 100         # 學習率預熱步數

@dataclass
class TrainingState:
    """
    訓練狀態

    追蹤訓練過程中的各種狀態變數
    """
    epoch: int = 0                          # 當前 epoch
    global_step: int = 0                    # 全域步數
    tokens_seen: int = 0                    # 已處理的 token 數
    train_losses: List[float] = field(default_factory=list)
    val_losses: List[float] = field(default_factory=list)
    learning_rates: List[float] = field(default_factory=list)
    best_val_loss: float = float("inf")     # 最佳驗證損失

class LRScheduler:
    """
    學習率排程器

    實作學習率預熱和餘弦退火
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        warmup_steps: int,
        total_steps: int,
        min_lr: float = 1e-6
    ):
        """
        初始化學習率排程器

        Args:
            optimizer: 最佳化器
            warmup_steps: 預熱步數
            total_steps: 總步數
            min_lr: 最小學習率
        """
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.min_lr = min_lr
        self.base_lr = optimizer.param_groups[0]["lr"]
        self.current_step = 0

    def step(self):
        """更新學習率"""
        self.current_step += 1

        if self.current_step < self.warmup_steps:
            # 線性預熱
            lr = self.base_lr * self.current_step / self.warmup_steps
        else:
            # 餘弦退火
            progress = (self.current_step - self.warmup_steps) / \
                      max(1, self.total_steps - self.warmup_steps)
            lr = self.min_lr + 0.5 * (self.base_lr - self.min_lr) * \
                 (1 + torch.cos(torch.tensor(progress * 3.14159)).item())

        # 更新最佳化器的學習率
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

        return lr

class TextGenerator:
    """
    文字生成器

    使用訓練中的模型生成樣本文字
    """

    def __init__(self, tokenizer, device: torch.device):
        """
        初始化文字生成器

        Args:
            tokenizer: 分詞器
            device: 計算裝置
        """
        self.tokenizer = tokenizer
        self.device = device

    def generate(
        self,
        model: nn.Module,
        start_text: str,
        max_new_tokens: int = 50,
        temperature: float = 1.0,
        top_k: int = 50
    ) -> str:
        """
        生成文字

        Args:
            model: 語言模型
            start_text: 起始文字
            max_new_tokens: 要生成的最大 token 數
            temperature: 取樣溫度
            top_k: top-k 取樣參數

        Returns:
            生成的文字
        """
        model.eval()

        # 將起始文字編碼為 token
        input_ids = self.tokenizer.encode(start_text)
        input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)

        with torch.no_grad():
            for _ in range(max_new_tokens):
                # 取得模型輸出
                # 只使用最後一個位置的 logits
                logits = model(input_tensor)
                next_token_logits = logits[:, -1, :]

                # 套用溫度
                next_token_logits = next_token_logits / temperature

                # Top-k 過濾
                if top_k > 0:
                    indices_to_remove = next_token_logits < torch.topk(
                        next_token_logits, top_k
                    )[0][..., -1, None]
                    next_token_logits[indices_to_remove] = float("-inf")

                # 取樣下一個 token
                probs = torch.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

                # 將新 token 加入序列
                input_tensor = torch.cat([input_tensor, next_token], dim=1)

        # 解碼生成的序列
        output_ids = input_tensor[0].tolist()
        generated_text = self.tokenizer.decode(output_ids)

        model.train()
        return generated_text

class LLMTrainer:
    """
    大型語言模型訓練器

    負責完整的訓練流程管理
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        config: TrainingConfig,
        device: torch.device,
        tokenizer=None
    ):
        """
        初始化訓練器

        Args:
            model: 要訓練的模型
            train_loader: 訓練資料載入器
            val_loader: 驗證資料載入器
            config: 訓練配置
            device: 計算裝置
            tokenizer: 分詞器(用於文字生成)
        """
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        self.device = device
        self.tokenizer = tokenizer

        # 建立最佳化器
        self.optimizer = AdamW(
            model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )

        # 計算總步數
        total_steps = len(train_loader) * config.num_epochs

        # 建立學習率排程器
        self.lr_scheduler = LRScheduler(
            self.optimizer,
            config.warmup_steps,
            total_steps
        )

        # 建立文字生成器
        if tokenizer is not None:
            self.text_generator = TextGenerator(tokenizer, device)
        else:
            self.text_generator = None

        # 初始化訓練狀態
        self.state = TrainingState()

    def compute_batch_loss(
        self,
        inputs: torch.Tensor,
        targets: torch.Tensor
    ) -> torch.Tensor:
        """
        計算單一批次的損失

        Args:
            inputs: 輸入 tensor
            targets: 目標 tensor

        Returns:
            損失值
        """
        # 移動資料到裝置
        inputs = inputs.to(self.device)
        targets = targets.to(self.device)

        # 前向傳播
        logits = self.model(inputs)

        # 計算交叉熵損失
        # 扁平化 logits 和 targets
        batch_size, seq_length, vocab_size = logits.shape
        loss = torch.nn.functional.cross_entropy(
            logits.view(-1, vocab_size),
            targets.view(-1)
        )

        return loss

    def evaluate(self, num_batches: Optional[int] = None) -> Dict[str, float]:
        """
        在驗證集上評估模型

        Args:
            num_batches: 要評估的批次數

        Returns:
            評估結果字典
        """
        self.model.eval()

        if num_batches is None:
            num_batches = len(self.val_loader)
        else:
            num_batches = min(num_batches, len(self.val_loader))

        total_loss = 0.0
        total_batches = 0

        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(self.val_loader):
                if batch_idx >= num_batches:
                    break

                loss = self.compute_batch_loss(inputs, targets)
                total_loss += loss.item()
                total_batches += 1

        avg_loss = total_loss / total_batches if total_batches > 0 else float("nan")
        perplexity = torch.exp(torch.tensor(avg_loss)).item()

        self.model.train()

        return {
            "loss": avg_loss,
            "perplexity": perplexity
        }

    def train_step(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
        """
        執行單一訓練步驟

        Args:
            inputs: 輸入 tensor
            targets: 目標 tensor

        Returns:
            損失值
        """
        # 清除梯度
        self.optimizer.zero_grad()

        # 計算損失
        loss = self.compute_batch_loss(inputs, targets)

        # 反向傳播
        loss.backward()

        # 梯度裁剪
        torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            self.config.max_grad_norm
        )

        # 更新參數
        self.optimizer.step()

        # 更新學習率
        current_lr = self.lr_scheduler.step()
        self.state.learning_rates.append(current_lr)

        return loss.item()

    def train(self, start_context: str = "The"):
        """
        執行完整的訓練流程

        Args:
            start_context: 用於生成樣本的起始文字
        """
        print("開始訓練...")
        print(f"配置: {self.config}")
        print(f"訓練批次數: {len(self.train_loader)}")
        print(f"驗證批次數: {len(self.val_loader)}")
        print()

        start_time = time.time()

        for epoch in range(self.config.num_epochs):
            self.state.epoch = epoch
            self.model.train()

            epoch_loss = 0.0
            epoch_steps = 0

            for batch_idx, (inputs, targets) in enumerate(self.train_loader):
                # 執行訓練步驟
                loss = self.train_step(inputs, targets)

                # 更新狀態
                self.state.global_step += 1
                self.state.tokens_seen += inputs.numel()
                epoch_loss += loss
                epoch_steps += 1

                # 定期評估
                if self.state.global_step % self.config.eval_freq == 0:
                    # 計算訓練損失(當前 epoch 的平均)
                    train_loss = epoch_loss / epoch_steps

                    # 計算驗證損失
                    val_metrics = self.evaluate(self.config.eval_iters)
                    val_loss = val_metrics["loss"]

                    # 記錄損失
                    self.state.train_losses.append(train_loss)
                    self.state.val_losses.append(val_loss)

                    # 更新最佳驗證損失
                    if val_loss < self.state.best_val_loss:
                        self.state.best_val_loss = val_loss

                    # 輸出訓練進度
                    elapsed = time.time() - start_time
                    print(
                        f"Epoch {epoch + 1}/{self.config.num_epochs} | "
                        f"Step {self.state.global_step} | "
                        f"Train Loss: {train_loss:.4f} | "
                        f"Val Loss: {val_loss:.4f} | "
                        f"Val PPL: {val_metrics['perplexity']:.2f} | "
                        f"LR: {self.state.learning_rates[-1]:.6f} | "
                        f"Time: {timedelta(seconds=int(elapsed))}"
                    )

                # 定期生成樣本
                if (self.state.global_step % self.config.generate_freq == 0 and
                    self.text_generator is not None):
                    generated = self.text_generator.generate(
                        self.model,
                        start_context,
                        max_new_tokens=50
                    )
                    print(f"\n生成樣本:\n{generated}\n")

        # 訓練完成
        total_time = time.time() - start_time
        print(f"\n訓練完成!")
        print(f"總時間: {timedelta(seconds=int(total_time))}")
        print(f"最佳驗證損失: {self.state.best_val_loss:.4f}")
        print(f"總處理 token 數: {self.state.tokens_seen}")

        return self.state

    def save_checkpoint(self, path: str):
        """
        儲存訓練檢查點

        Args:
            path: 儲存路徑
        """
        checkpoint = {
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "training_state": {
                "epoch": self.state.epoch,
                "global_step": self.state.global_step,
                "tokens_seen": self.state.tokens_seen,
                "best_val_loss": self.state.best_val_loss
            },
            "config": self.config
        }
        torch.save(checkpoint, path)
        print(f"檢查點已儲存至: {path}")

    def load_checkpoint(self, path: str):
        """
        載入訓練檢查點

        Args:
            path: 檢查點路徑
        """
        checkpoint = torch.load(path, map_location=self.device)

        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

        state_dict = checkpoint["training_state"]
        self.state.epoch = state_dict["epoch"]
        self.state.global_step = state_dict["global_step"]
        self.state.tokens_seen = state_dict["tokens_seen"]
        self.state.best_val_loss = state_dict["best_val_loss"]

        print(f"檢查點已載入自: {path}")
        print(f"從 Epoch {self.state.epoch + 1}, Step {self.state.global_step} 繼續")

# 使用範例
if __name__ == "__main__":
    import tiktoken

    # 這是一個展示用的簡化模型
    # 實際應用中應使用完整的 Transformer 模型
    class SimpleLM(nn.Module):
        def __init__(self, vocab_size, embed_dim, num_layers):
            super().__init__()
            self.embedding = nn.Embedding(vocab_size, embed_dim)
            self.layers = nn.ModuleList([
                nn.TransformerEncoderLayer(
                    d_model=embed_dim,
                    nhead=4,
                    dim_feedforward=embed_dim * 4,
                    batch_first=True
                )
                for _ in range(num_layers)
            ])
            self.output = nn.Linear(embed_dim, vocab_size)

        def forward(self, x):
            x = self.embedding(x)
            for layer in self.layers:
                x = layer(x)
            return self.output(x)

    # 設定參數
    vocab_size = 50257  # GPT-2 詞彙表大小
    embed_dim = 256
    num_layers = 2

    # 建立模型
    model = SimpleLM(vocab_size, embed_dim, num_layers)
    print(f"模型參數數: {sum(p.numel() for p in model.parameters()):,}")

    # 建立模擬資料載入器
    # 實際應用中應使用真實的文字資料
    class DummyDataset(torch.utils.data.Dataset):
        def __init__(self, num_samples, seq_length, vocab_size):
            self.num_samples = num_samples
            self.seq_length = seq_length
            self.vocab_size = vocab_size

        def __len__(self):
            return self.num_samples

        def __getitem__(self, idx):
            inputs = torch.randint(0, self.vocab_size, (self.seq_length,))
            targets = torch.randint(0, self.vocab_size, (self.seq_length,))
            return inputs, targets

    train_dataset = DummyDataset(1000, 128, vocab_size)
    val_dataset = DummyDataset(100, 128, vocab_size)

    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

    # 訓練配置
    config = TrainingConfig(
        num_epochs=2,
        learning_rate=1e-4,
        eval_freq=50,
        generate_freq=100
    )

    # 建立訓練器
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    trainer = LLMTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        config=config,
        device=device,
        tokenizer=None  # 可以傳入 tiktoken 分詞器以啟用文字生成
    )

    # 執行訓練
    state = trainer.train()

    # 儲存檢查點
    trainer.save_checkpoint("model_checkpoint.pt")

這個訓練系統實作了大型語言模型訓練的所有關鍵組件。學習率排程器實作了預熱和餘弦退火策略,這對於穩定訓練非常重要。文字生成器讓我們能夠在訓練過程中觀察模型的生成能力。檢查點功能確保了訓練可以被中斷和恢復。

以下是訓練迴圈的詳細流程圖:

@startuml
!define PLANTUML_FORMAT svg
!theme _none_

skinparam dpi auto
skinparam shadowing false
skinparam linetype ortho
skinparam roundcorner 5
skinparam defaultFontName "Microsoft JhengHei UI"
skinparam defaultFontSize 16
skinparam minClassWidth 100

start

:初始化模型和最佳化器;

repeat
    :開始新的 Epoch;

    repeat
        :取得批次資料;
        :清除梯度;
        :前向傳播;
        :計算損失;
        :反向傳播;
        :梯度裁剪;
        :更新參數;
        :更新學習率;

        if (達到評估頻率?) then (是)
            :評估驗證集;
            :記錄損失;
        endif

        if (達到生成頻率?) then (是)
            :生成樣本文字;
        endif

    repeat while (還有批次?) is (是)

repeat while (還有 Epoch?) is (是)

:儲存最終模型;

stop

@enduml

模型權重管理與遷移學習

訓練完成後,妥善管理模型權重對於後續使用和分享至關重要。PyTorch 提供了靈活的權重儲存和載入機制,支援儲存完整檢查點或僅儲存模型參數。對於預訓練模型,我們通常只需要儲存模型的 state_dict,這樣可以減少檔案大小並提高載入效率。

遷移學習是預訓練模型的主要應用方式之一。透過在大規模資料上預訓練,模型學習到通用的語言表示。然後,我們可以在特定任務的較小資料集上進行微調,讓模型適應特定的應用場景。這種方法大幅降低了訓練成本,同時保持了優異的性能。

以下是模型權重管理的完整實作:

# 模型權重管理系統
# 實作模型儲存、載入和遷移學習功能

import torch
import torch.nn as nn
from typing import Dict, Optional, List
from pathlib import Path
import json

class ModelManager:
    """
    模型權重管理器

    負責模型的儲存、載入和版本管理
    """

    def __init__(self, save_dir: str = "checkpoints"):
        """
        初始化模型管理器

        Args:
            save_dir: 儲存目錄
        """
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

    def save_model(
        self,
        model: nn.Module,
        name: str,
        config: Optional[Dict] = None,
        metadata: Optional[Dict] = None
    ):
        """
        儲存模型權重

        Args:
            model: 要儲存的模型
            name: 模型名稱
            config: 模型配置
            metadata: 額外的元資料
        """
        # 建立儲存路徑
        model_path = self.save_dir / f"{name}.pt"
        config_path = self.save_dir / f"{name}_config.json"

        # 儲存模型權重
        torch.save(model.state_dict(), model_path)
        print(f"模型權重已儲存至: {model_path}")

        # 儲存配置和元資料
        if config is not None or metadata is not None:
            save_data = {}
            if config is not None:
                save_data["config"] = config
            if metadata is not None:
                save_data["metadata"] = metadata

            with open(config_path, "w", encoding="utf-8") as f:
                json.dump(save_data, f, indent=2, ensure_ascii=False)
            print(f"配置已儲存至: {config_path}")

    def load_model(
        self,
        model: nn.Module,
        name: str,
        strict: bool = True
    ) -> nn.Module:
        """
        載入模型權重

        Args:
            model: 要載入權重的模型
            name: 模型名稱
            strict: 是否嚴格匹配所有鍵

        Returns:
            載入權重後的模型
        """
        model_path = self.save_dir / f"{name}.pt"

        # 載入權重
        state_dict = torch.load(model_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=strict)

        print(f"已從 {model_path} 載入模型權重")

        return model

    def load_config(self, name: str) -> Dict:
        """
        載入模型配置

        Args:
            name: 模型名稱

        Returns:
            配置字典
        """
        config_path = self.save_dir / f"{name}_config.json"

        with open(config_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        return data

    def load_pretrained_weights(
        self,
        model: nn.Module,
        pretrained_path: str,
        exclude_layers: Optional[List[str]] = None,
        freeze_loaded: bool = False
    ) -> nn.Module:
        """
        載入預訓練權重用於遷移學習

        這個方法支援部分載入,可以排除某些層或
        處理架構不完全匹配的情況

        Args:
            model: 目標模型
            pretrained_path: 預訓練權重路徑
            exclude_layers: 要排除的層名稱列表
            freeze_loaded: 是否凍結載入的權重

        Returns:
            載入權重後的模型
        """
        # 載入預訓練權重
        pretrained_dict = torch.load(pretrained_path, map_location="cpu")

        # 取得目標模型的 state_dict
        model_dict = model.state_dict()

        # 過濾要載入的權重
        loaded_keys = []
        skipped_keys = []

        for key, value in pretrained_dict.items():
            # 檢查是否在排除列表中
            if exclude_layers is not None:
                if any(exc in key for exc in exclude_layers):
                    skipped_keys.append(key)
                    continue

            # 檢查鍵是否存在且形狀匹配
            if key in model_dict:
                if value.shape == model_dict[key].shape:
                    model_dict[key] = value
                    loaded_keys.append(key)
                else:
                    skipped_keys.append(f"{key} (形狀不匹配)")
            else:
                skipped_keys.append(f"{key} (不存在)")

        # 載入權重
        model.load_state_dict(model_dict)

        # 輸出載入結果
        print(f"已載入 {len(loaded_keys)} 個權重")
        if skipped_keys:
            print(f"已跳過 {len(skipped_keys)} 個權重:")
            for key in skipped_keys[:5]:  # 只顯示前 5 個
                print(f"  - {key}")
            if len(skipped_keys) > 5:
                print(f"  ... 還有 {len(skipped_keys) - 5} 個")

        # 凍結載入的權重
        if freeze_loaded:
            for name, param in model.named_parameters():
                if name in loaded_keys:
                    param.requires_grad = False
            print("已凍結載入的權重")

        return model

class FineTuner:
    """
    微調訓練器

    專門用於在預訓練模型上進行微調
    """

    def __init__(
        self,
        model: nn.Module,
        device: torch.device,
        learning_rate: float = 1e-5
    ):
        """
        初始化微調訓練器

        Args:
            model: 預訓練模型
            device: 計算裝置
            learning_rate: 微調學習率
        """
        self.model = model.to(device)
        self.device = device

        # 微調時使用較小的學習率
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=learning_rate
        )

    def prepare_for_finetuning(
        self,
        freeze_embeddings: bool = True,
        freeze_layers: Optional[int] = None
    ):
        """
        準備模型進行微調

        Args:
            freeze_embeddings: 是否凍結詞嵌入層
            freeze_layers: 要凍結的 Transformer 層數(從底部開始)
        """
        # 凍結詞嵌入
        if freeze_embeddings:
            for name, param in self.model.named_parameters():
                if "embedding" in name.lower():
                    param.requires_grad = False
                    print(f"已凍結: {name}")

        # 凍結指定數量的層
        if freeze_layers is not None and freeze_layers > 0:
            layer_count = 0
            for name, param in self.model.named_parameters():
                if "layer" in name.lower():
                    # 提取層編號
                    try:
                        layer_num = int(name.split(".")[1])
                        if layer_num < freeze_layers:
                            param.requires_grad = False
                            print(f"已凍結: {name}")
                    except (ValueError, IndexError):
                        pass

        # 統計可訓練參數
        trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.model.parameters())
        print(f"\n可訓練參數: {trainable:,} / {total:,} ({100*trainable/total:.1f}%)")

    def unfreeze_all(self):
        """解凍所有參數"""
        for param in self.model.parameters():
            param.requires_grad = True
        print("已解凍所有參數")

    def get_parameter_groups(
        self,
        lr_decay: float = 0.9
    ) -> List[Dict]:
        """
        取得分層學習率的參數群組

        較深的層使用較大的學習率,因為它們需要
        更多的調整來適應新任務

        Args:
            lr_decay: 每層學習率衰減係數

        Returns:
            參數群組列表
        """
        base_lr = self.optimizer.param_groups[0]["lr"]
        groups = []

        # 按層分組
        layer_params = {}
        other_params = []

        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue

            # 嘗試提取層編號
            try:
                if "layer" in name.lower():
                    layer_num = int(name.split(".")[1])
                    if layer_num not in layer_params:
                        layer_params[layer_num] = []
                    layer_params[layer_num].append(param)
                else:
                    other_params.append(param)
            except (ValueError, IndexError):
                other_params.append(param)

        # 建立參數群組
        # 較淺的層使用較小的學習率
        num_layers = max(layer_params.keys()) + 1 if layer_params else 0

        for layer_num in sorted(layer_params.keys()):
            decay_factor = lr_decay ** (num_layers - layer_num - 1)
            groups.append({
                "params": layer_params[layer_num],
                "lr": base_lr * decay_factor
            })

        # 其他參數使用基礎學習率
        if other_params:
            groups.append({
                "params": other_params,
                "lr": base_lr
            })

        return groups

# 使用範例
if __name__ == "__main__":
    # 建立簡單模型
    class SimpleModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.embedding = nn.Embedding(1000, 128)
            self.layer0 = nn.Linear(128, 128)
            self.layer1 = nn.Linear(128, 128)
            self.output = nn.Linear(128, 1000)

        def forward(self, x):
            x = self.embedding(x)
            x = self.layer0(x)
            x = self.layer1(x)
            return self.output(x)

    # 建立模型管理器
    manager = ModelManager("./test_checkpoints")

    # 建立並儲存模型
    model = SimpleModel()
    config = {"vocab_size": 1000, "embed_dim": 128}
    metadata = {"version": "1.0", "trained_on": "sample_data"}

    manager.save_model(model, "test_model", config, metadata)

    # 載入模型
    new_model = SimpleModel()
    manager.load_model(new_model, "test_model")

    # 載入配置
    loaded_config = manager.load_config("test_model")
    print(f"載入的配置: {loaded_config}")

    # 準備微調
    device = torch.device("cpu")
    fine_tuner = FineTuner(new_model, device, learning_rate=1e-5)
    fine_tuner.prepare_for_finetuning(freeze_embeddings=True, freeze_layers=1)

    print("\n微調準備完成!")

這個模型權重管理系統提供了完整的模型持久化和遷移學習支援。分層學習率策略讓我們能夠在微調時更精細地控制不同層的學習速度,這對於保留預訓練知識並適應新任務非常有效。

預訓練的挑戰與最佳實踐

大型語言模型的預訓練面臨許多挑戰,包括資料品質、訓練穩定性、計算資源和環境影響等。理解這些挑戰並採用適當的策略對於成功的預訓練至關重要。

資料品質直接影響模型的能力和偏見。預訓練資料應該經過仔細的清理和過濾,移除重複內容、低品質文字和有害內容。同時,資料的多樣性也很重要,模型需要接觸各種主題、風格和來源的文字才能學習到全面的語言知識。

訓練穩定性是另一個重要考量。大型模型容易出現梯度爆炸或消失、損失震盪等問題。適當的初始化、學習率排程、梯度裁剪和正則化技術都有助於穩定訓練。混合精度訓練可以在保持訓練穩定性的同時減少記憶體使用和加速計算。

計算資源的有效利用對於降低成本至關重要。資料並行和模型並行技術讓我們能夠在多個 GPU 或節點上分散訓練負載。梯度累積允許在有限的 GPU 記憶體下模擬更大的批次大小。Flash Attention 等效率最佳化技術可以顯著加速 Transformer 模型的訓練。

總結而言,大型語言模型的預訓練是一個複雜但回報豐厚的過程。透過理解基本概念、掌握關鍵技術並採用最佳實踐,我們能夠訓練出強大的語言模型,為各種下游應用奠定堅實基礎。隨著技術的持續進步,預訓練方法將變得更加高效和可及,讓更多研究者和開發者能夠參與這一令人興奮的領域。