在前一篇文章中,我們探討了如何使用 Apache Airflow 將一個深度學習專案的基本流程(預處理、訓練)串聯起來。然而,一個生產級別的機器學習運維 (MLOps) 管線,不僅需要自動化執行,更需要包含嚴謹的模型評估環節,以及靈活、穩健的工作流程設計。本文將深入探討如何為我們的 Airflow DAG 加入模型評估任務,並介紹兩種關鍵的管線最佳化技巧:動態 DAG 生成模型檢查點 (Checkpointing)

第一部分:實現模型評估任務

一個模型訓練完成後,必須經過客觀的評估,我們才能判斷其效能是否達到上線標準。

1. 編寫評估函式

我們需要一個 Python 函式,它能載入已訓練好的模型和測試資料集,執行預測,並計算出一個或多個評估指標。對於迴歸任務(如股價預測),除了計算損失 (Loss) 外,R-squared (決定係數) 是一個更具解釋性的指標。

# in tasks/evaluation.py
import torch
import numpy as np

def evaluate_model(model_path: str, test_data_path: str):
    """
    載入模型與測試資料,計算並輸出評估指標。
    """
    model = ... # 從 model_path 載入模型
    test_loader = ... # 從 test_data_path 載入並準備資料
    criterion = torch.nn.MSELoss()
    
    model.eval() # 將模型設為評估模式
    with torch.no_grad(): # 在此區塊中不計算梯度
        predictions = []
        targets = []
        for inputs, labels in test_loader:
            outputs = model(inputs)
            predictions.extend(outputs.squeeze().tolist())
            targets.extend(labels.tolist())

    # 計算 R-squared
    targets_tensor = torch.tensor(targets)
    predictions_tensor = torch.tensor(predictions)
    ss_res = torch.sum((targets_tensor - predictions_tensor) ** 2)
    ss_tot = torch.sum((targets_tensor - torch.mean(targets_tensor)) ** 2)
    r_squared = 1 - (ss_res / ss_tot)
    
    print(f"模型評估完成 - R-squared: {r_squared:.4f}")
    # 在真實場景中,可以將指標寫入資料庫或日誌系統
caption="圖表一:模型評估任務循序圖。此循序圖詳細展示了 `evaluate_model` 任務的內部執行流程,從載入資料到計算指標。"
alt="一個展示模型評估任務流程的循序圖。Airflow Worker 呼叫函式,函式從儲存系統載入模型和資料,使用 PyTorch 執行預測,計算 R-squared 指標,最後由 Worker 記錄指標。"
@startuml
!theme _none_
skinparam dpi auto
skinparam defaultFontName "Microsoft JhengHei UI"
skinparam minClassWidth 100
skinparam defaultFontSize 14
title 模型評估任務循序圖

participant "Airflow Worker" as Worker
participant "evaluate_model()" as Func
participant "儲存系統" as Storage
participant "PyTorch" as Torch

Worker -> Func : 呼叫函式
Func -> Storage : 載入模型檔案
Func -> Storage : 載入測試資料
Func -> Torch : 執行預測 (model.eval(), torch.no_grad())
Torch --> Func : 回傳預測結果
Func -> Func : 計算評估指標 (R-squared)
Func -> Worker : 列印/記錄指標
@enduml

2. 將評估任務加入 DAG

現在,我們可以將這個評估函式加入到前一篇文章的 diabetes_prediction_pipeline DAG 中,使其成為訓練任務的下游。

# in dags/diabetes_prediction_dag.py
# ... (省略 DAG 和前兩個 task 的定義) ...

    task_evaluate = PythonOperator(
        task_id='evaluate_model',
        python_callable=evaluate_model,
        op_kwargs={
            'test_input_path': '/path/to/test.parquet',
            'model_input_path': '/path/to/diabetes_model.h5',
        }
    )

    # 設定新的依賴關係
    task_preprocess >> task_train >> task_evaluate

第二部分:工作流程最佳化技巧

對於複雜的 MLOps 場景,靜態的 DAG 定義可能不夠靈活。以下介紹兩種常用的最佳化技巧。

1. 動態 DAG 生成

在進行模型實驗時,我們可能需要針對不同的超參數組合(如學習率、批次大小)執行相同的訓練流程。與其手動複製多個 DAG 檔案,不如使用一個 Python 腳本來動態生成它們。

方法: 將超參數設定儲存在一個外部檔案(如 config.yaml)中,然後在 DAG 檔案中讀取該設定,並使用迴圈來動態建立多個 DAG 或多個任務。

範例 config.yaml:

experiments:
  - name: "low_lr"
    learning_rate: 0.001
  - name: "high_lr"
    learning_rate: 0.01

範例 dynamic_training_dag.py:

import yaml
from airflow import DAG
# ...

with open('/path/to/config.yaml', 'r') as file:
    config = yaml.safe_load(file)

for exp in config['experiments']:
    dag_id = f"training_experiment_{exp['name']}"
    
    with DAG(dag_id=dag_id, ...) as dag:
        train_task = PythonOperator(
            task_id=f"train_{exp['name']}",
            python_callable=train_model,
            op_kwargs={'learning_rate': exp['learning_rate']}
        )

這種模式極大地提升了進行大規模機器學習實驗的效率。

2. 模型檢查點 (Checkpointing)

深度學習模型的訓練過程可能非常耗時,從數小時到數天不等。如果訓練過程中發生任何中斷(如機器故障),所有的訓練進度都會遺失。模型檢查點是一種在訓練過程中定期儲存模型狀態的機制,是保障長時間訓練任務穩健性的關鍵。

方法: 在 train_model 函式的訓練迴圈中,定期(例如每 N 個 epoch)儲存模型和最佳化器的狀態。

# 在 train_model 函式中
def train_model(...):
    # ...
    for epoch in range(num_epochs):
        # ... (執行一個 epoch 的訓練) ...
        
        # 每 10 個 epoch 儲存一次檢查點
        if (epoch + 1) % 10 == 0:
            checkpoint_path = f"{model_output_path}_epoch_{epoch+1}.pth"
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
            }, checkpoint_path)
            print(f"已儲存檢查點至 {checkpoint_path}")
    # ...

這樣,即使任務中斷,我們也可以從最近的一個檢查點繼續訓練,而無需從頭開始,節省了大量的時間和計算資源。

透過將嚴謹的模型評估流程與動態、穩健的管線設計相結合,我們可以利用 Airflow 打造出真正適用於生產環境的 MLOps 解決方案。