Rust 在機器學習領域的應用日益增長,其高效能和安全性使其成為資料處理的理想選擇。本文示範如何使用 Rust 生成模擬資料,並將其輸出為 CSV 格式,方便與其他工具和平台整合。同時,我們也將探討如何使用 TOML 組態檔案管理實驗引數,提升程式碼的可維護性和實驗的可重複性。更進一步,我們將使用 Plotters 函式庫視覺化資料,並使用 K-Means 演算法進行叢集分析,展示 Rust 在機器學習工作流程中的實用性。透過這些技術,開發者可以更有效率地進行資料處理、模型訓練和結果分析。

人工智慧與機器學習中的資料生成與組態管理

在開發機器學習應用程式的過程中,資料生成是一個重要的步驟。本篇文章將介紹如何使用 Rust 語言生成模擬資料,並將其輸出為 CSV 格式。此外,我們還將探討如何將組態引數從硬編碼轉移到組態檔案中,以提高實驗的靈活性和可重現性。

生成模擬資料

首先,我們需要生成模擬資料。在本例中,我們將生成一些模擬的貓的高度和長度資料。我們使用 Array2 來表示這些資料,並使用 generate_data 函式來生成模擬資料。

程式碼範例:生成模擬資料

use ndarray::Array2;

fn generate_data(centroids: &Array2<f64>, samples_per_centroid: usize, noise: f64) -> Array2<f64> {
    // 生成模擬資料的實作細節
    // ...
}

內容解密:

  1. generate_data 函式接受三個引數:centroidssamples_per_centroidnoise
  2. centroids 是一個 Array2<f64>,代表不同類別的中心點。
  3. samples_per_centroid 是一個 usize,代表每個中心點要生成的樣本數量。
  4. noise 是一個 f64,代表生成資料時的雜訊程度。
  5. 函式傳回一個 Array2<f64>,代表生成的模擬資料。

輸出 CSV 檔案

生成模擬資料後,我們需要將其輸出為 CSV 檔案。我們使用 csv crate 來實作這一步驟。

程式碼範例:輸出 CSV 檔案

use csv::Writer;
use std::io;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // 生成模擬資料
    let centroids = Array2::from_shape_vec((3, 2), CENTROIDS.to_vec())?;
    let samples = generate_data(&centroids, SAMPLES_PER_CENTROID, NOISE)?;

    // 輸出 CSV 檔案
    let mut writer = Writer::from_writer(io::stdout());
    writer.write_record(&["height", "length"])?;
    for sample in samples.rows() {
        let mut sample_iter = sample.into_iter();
        writer.serialize((sample_iter.next().unwrap(), sample_iter.next().unwrap()))?;
    }
    Ok(())
}

內容解密:

  1. 我們使用 csv::Writer 來寫入 CSV 檔案。
  2. 首先,我們寫入 CSV 的標頭行,包含 “height” 和 “length” 兩個欄位。
  3. 然後,我們遍歷生成的模擬資料,並將每個樣本序列化為 CSV 行。
  4. 最後,我們將 CSV 資料輸出到標準輸出。

將組態引數移到組態檔案中

在開發機器學習應用程式時,我們經常需要嘗試不同的組態引數。將組態引數硬編碼在程式碼中會使得實驗過程變得繁瑣。因此,我們將組態引數移到組態檔案中,以提高實驗的靈活性。

組態檔案範例:config/generate.toml

centroids = [
    22.5, 40.5, # persian
    38.0, 50.0, # British shorthair
    25.5, 48.0, # Ragdoll
]
noise = 1.8
samples_per_centroid = 2000

程式碼範例:讀取組態檔案

use serde::Deserialize;
use toml;

#[derive(Deserialize)]
struct Config {
    centroids: [f64; 6],
    noise: f64,
    samples_per_centroid: usize,
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // 讀取組態檔案
    let config_str = std::fs::read_to_string("config/generate.toml")?;
    let config: Config = toml::from_str(&config_str)?;

    // 使用組態檔案中的引數生成模擬資料
    // ...
}

內容解密:

  1. 我們定義了一個 Config 結構體,用於反序列化組態檔案中的引數。
  2. 使用 toml crate 將組態檔案讀入為字串,並將其反序列化為 Config 結構體。
  3. 然後,我們可以使用 Config 結構體中的引數來生成模擬資料。

組態檔案與命令列引數的整合應用

在開發機器學習模型時,組態檔案的管理與命令列引數的靈活運用至關重要。本章節將介紹如何結合TOML組態檔案與Clap函式庫來實作動態組態引數,並透過Plotters函式庫進行資料視覺化。

使用TOML組態檔案

首先,我們定義一個Config結構體來反序列化TOML組態檔案:

#[derive(Deserialize)]
struct Config {
    // 組態欄位定義
    centroids: Vec<f64>,
    samples_per_centroid: usize,
    noise: f64,
}

fn main() -> Result<(), Box<dyn Error>> {
    let toml_config_str = read_to_string("config/generate.toml")?;
    let config: Config = toml::from_str(&toml_config_str)?;
    // 使用組態引數
    let centroids = Array2::from_shape_vec((3, 2), config.centroids.to_vec())?;
    let samples = generate_data(&centroids, config.samples_per_centroid, config.noise)?;
    // ...
}

內容解密:

  1. Config結構體使用Deserialize衍生宏來實作TOML反序列化。
  2. main函式讀取TOML組態檔案並將其解析為Config例項。
  3. 使用組態引數取代硬編碼的常數,使程式更具彈性。

動態設定組態檔案路徑

為了提高程式的靈活性,我們使用Clap函式庫來解析命令列引數:

use clap::Parser;

#[derive(Parser)]
struct Args {
    #[arg(short = 'c', long = "config-file")]
    /// 組態檔案路徑
    config_file_path: std::path::PathBuf,
}

fn main() -> Result<(), Box<dyn Error>> {
    let args = Args::parse();
    let toml_config_str = read_to_string(args.config_file_path)?;
    // ...
}

內容解密:

  1. 使用Clap的Parser衍生宏來定義命令列引數結構體Args
  2. config_file_path欄位對應到--config-file命令列選項。
  3. main函式中解析命令列引數並讀取指定的組態檔案。

資料視覺化

使用Plotters函式庫來視覺化生成的資料:

use plotters::prelude::*;

fn main() -> Result<(), Box<dyn Error>> {
    let mut x: Vec<f64> = Vec::new();
    let mut y: Vec<f64> = Vec::new();
    
    // 從標準輸入讀取CSV資料
    let mut reader = csv::Reader::from_reader(io::stdin());
    for result in reader.records() {
        let record = result?;
        x.push(record[0].parse()?);
        y.push(record[1].parse()?);
    }
    
    // 建立繪圖區域
    let root_drawing_area = BitMapBackend::new("plot.png", (900, 600)).into_drawing_area();
    root_drawing_area.fill(&WHITE)?;
    
    // 建立圖表並繪製資料
    let mut chart = ChartBuilder::on(&root_drawing_area)
        .build_cartesian_2d(15.0..45.0, 30.0..55.0)?;
    chart.configure_mesh().disable_mesh().draw()?;
    chart.draw_series(
        x.into_iter()
            .zip(y)
            .map(|point| Cross::new(point, 3, Into::<ShapeStyle>::into(&BLUE).stroke_width(2))),
    )?;
    
    Ok(())
}

內容解密:

  1. 從標準輸入讀取CSV資料並儲存在xy向量中。
  2. 使用Plotters建立一個PNG圖片作為繪圖區域。
  3. 建立一個2D笛卡爾座標系圖表並繪製資料點。
  4. 將圖表儲存為PNG檔案。

命令列操作示範

  1. 生成資料:cargo run --bin generate -- --config-file config/generate.toml > training_data.csv
  2. 繪製資料:cat training_data.csv | cargo run --bin plot

這種方法結合了組態檔案的靈活性、命令列引數的動態性以及資料視覺化的直觀性,為機器學習模型的開發和除錯提供了有力的支援。

K-Means 叢集分析在貓身體測量資料上的應用

在前面的章節中,我們已經成功地使用 Plotters 繪製了貓身體測量的散佈圖。從圖中可以明顯看出,這些資料點形成了三個叢集。接下來,我們將使用 K-Means 演算法對這些資料進行叢集分析。

設定 K-Means 模型

K-Means 是一種無監督學習演算法,用於將資料分成 K 個叢集。在我們的例子中,K 等於 3,因為從散佈圖中可以看出資料形成了三個叢集。K-Means 模型的實作位於 linfa_clustering 套件中。

首先,我們需要在 Cargo.toml 中加入必要的相依套件:

[dependencies]
linfa-clustering = "0.5.0"
linfa-nn = "0.5.0"
ndarray = "0.15.3"

K-Means 模型的訓練與預測

K-Means 模型的訓練與預測過程如下所示:

use std::error::Error;
use linfa::DatasetBase;
use linfa::traits::Fit;
use linfa::traits::Predict;
use linfa_clustering::KMeans;
use linfa_nn::distance::L2Dist;
use rand::thread_rng;

const CLUSTER_COUNT: usize = 3;

fn main() -> Result<(), Box<dyn Error>> {
    let samples = read_data_from_stdin()?;
    let training_data = DatasetBase::from(samples);
    let rng = thread_rng();
    let model = KMeans::params_with(CLUSTER_COUNT, rng, L2Dist)
        .max_n_iterations(200)
        .tolerance(1e-5)
        .fit(&training_data)?;

    let dataset = model.predict(training_data);
    let DatasetBase { records, targets, .. } = dataset;
    export_result_to_stdout(records, targets)?;

    Ok(())
}

程式碼解析:

  1. 讀取資料:使用 read_data_from_stdin 函式從標準輸入讀取 CSV 格式的資料。
  2. 轉換資料:將讀取到的資料轉換成 DatasetBase 格式,以便進行 K-Means 叢集分析。
  3. 初始化 K-Means 模型:使用 KMeans::params_with 初始化 K-Means 模型,設定叢集數量為 3,並指定亂數產生器和距離計算方法。
  4. 訓練模型:呼叫 fit 方法對模型進行訓練,傳入訓練資料。
  5. 進行預測:使用訓練好的模型對資料進行叢集分析,得到每個資料點所屬的叢集標籤。
  6. 輸出結果:將叢集分析的結果輸出到標準輸出。

資料讀取與輸出函式

read_data_from_stdinexport_result_to_stdout 函式的實作與之前章節中的實作類別似,分別用於讀取 CSV 資料和輸出叢集分析結果。

read_data_from_stdin 函式實作

use csv::ReaderBuilder;
use ndarray::Array2;

fn read_data_from_stdin() -> Result<Array2<f64>, Box<dyn Error>> {
    let mut reader = ReaderBuilder::new().from_reader(io::stdin());
    let mut records = Vec::new();
    for result in reader.records() {
        let record = result?;
        let x: f64 = record[0].parse()?;
        let y: f64 = record[1].parse()?;
        records.push(vec![x, y]);
    }
    let array = Array2::from_shape_vec((records.len(), 2), records.into_iter().flatten().collect())?;
    Ok(array)
}

程式碼解析:

  1. 建立 CSV 讀取器:使用 ReaderBuilder 建立一個 CSV 讀取器,從標準輸入讀取資料。
  2. 逐行讀取資料:使用 records 方法逐行讀取 CSV 資料,並將每行資料解析為 f64 型別的數值。
  3. 儲存資料:將解析好的資料儲存到一個向量中。
  4. 轉換為 Array2:將向量中的資料轉換為 Array2 格式,以便進行後續的叢集分析。

export_result_to_stdout 函式實作

use csv::WriterBuilder;

fn export_result_to_stdout(records: Array2<f64>, targets: Array1<usize>) -> Result<(), Box<dyn Error>> {
    let mut writer = WriterBuilder::new().from_writer(io::stdout());
    for i in 0..records.nrows() {
        writer.write_record(&[
            records.row(i)[0].to_string(),
            records.row(i)[1].to_string(),
            targets[i].to_string(),
        ])?;
    }
    writer.flush()?;
    Ok(())
}

程式碼解析:

  1. 建立 CSV 輸出器:使用 WriterBuilder 建立一個 CSV 輸出器,將結果輸出到標準輸出。
  2. 逐行輸出結果:使用 write_record 方法逐行輸出叢集分析的結果,包括原始資料和所屬叢集的標籤。
  3. 清空輸出緩衝區:呼叫 flush 方法清空輸出緩衝區,確保結果被正確輸出。