Rust 在機器學習領域的應用日益增長,其高效能和安全性使其成為資料處理的理想選擇。本文示範如何使用 Rust 生成模擬資料,並將其輸出為 CSV 格式,方便與其他工具和平台整合。同時,我們也將探討如何使用 TOML 組態檔案管理實驗引數,提升程式碼的可維護性和實驗的可重複性。更進一步,我們將使用 Plotters 函式庫視覺化資料,並使用 K-Means 演算法進行叢集分析,展示 Rust 在機器學習工作流程中的實用性。透過這些技術,開發者可以更有效率地進行資料處理、模型訓練和結果分析。
人工智慧與機器學習中的資料生成與組態管理
在開發機器學習應用程式的過程中,資料生成是一個重要的步驟。本篇文章將介紹如何使用 Rust 語言生成模擬資料,並將其輸出為 CSV 格式。此外,我們還將探討如何將組態引數從硬編碼轉移到組態檔案中,以提高實驗的靈活性和可重現性。
生成模擬資料
首先,我們需要生成模擬資料。在本例中,我們將生成一些模擬的貓的高度和長度資料。我們使用 Array2 來表示這些資料,並使用 generate_data 函式來生成模擬資料。
程式碼範例:生成模擬資料
use ndarray::Array2;
fn generate_data(centroids: &Array2<f64>, samples_per_centroid: usize, noise: f64) -> Array2<f64> {
// 生成模擬資料的實作細節
// ...
}
內容解密:
generate_data函式接受三個引數:centroids、samples_per_centroid和noise。centroids是一個Array2<f64>,代表不同類別的中心點。samples_per_centroid是一個usize,代表每個中心點要生成的樣本數量。noise是一個f64,代表生成資料時的雜訊程度。- 函式傳回一個
Array2<f64>,代表生成的模擬資料。
輸出 CSV 檔案
生成模擬資料後,我們需要將其輸出為 CSV 檔案。我們使用 csv crate 來實作這一步驟。
程式碼範例:輸出 CSV 檔案
use csv::Writer;
use std::io;
fn main() -> Result<(), Box<dyn std::error::Error>> {
// 生成模擬資料
let centroids = Array2::from_shape_vec((3, 2), CENTROIDS.to_vec())?;
let samples = generate_data(¢roids, SAMPLES_PER_CENTROID, NOISE)?;
// 輸出 CSV 檔案
let mut writer = Writer::from_writer(io::stdout());
writer.write_record(&["height", "length"])?;
for sample in samples.rows() {
let mut sample_iter = sample.into_iter();
writer.serialize((sample_iter.next().unwrap(), sample_iter.next().unwrap()))?;
}
Ok(())
}
內容解密:
- 我們使用
csv::Writer來寫入 CSV 檔案。 - 首先,我們寫入 CSV 的標頭行,包含 “height” 和 “length” 兩個欄位。
- 然後,我們遍歷生成的模擬資料,並將每個樣本序列化為 CSV 行。
- 最後,我們將 CSV 資料輸出到標準輸出。
將組態引數移到組態檔案中
在開發機器學習應用程式時,我們經常需要嘗試不同的組態引數。將組態引數硬編碼在程式碼中會使得實驗過程變得繁瑣。因此,我們將組態引數移到組態檔案中,以提高實驗的靈活性。
組態檔案範例:config/generate.toml
centroids = [
22.5, 40.5, # persian
38.0, 50.0, # British shorthair
25.5, 48.0, # Ragdoll
]
noise = 1.8
samples_per_centroid = 2000
程式碼範例:讀取組態檔案
use serde::Deserialize;
use toml;
#[derive(Deserialize)]
struct Config {
centroids: [f64; 6],
noise: f64,
samples_per_centroid: usize,
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
// 讀取組態檔案
let config_str = std::fs::read_to_string("config/generate.toml")?;
let config: Config = toml::from_str(&config_str)?;
// 使用組態檔案中的引數生成模擬資料
// ...
}
內容解密:
- 我們定義了一個
Config結構體,用於反序列化組態檔案中的引數。 - 使用
tomlcrate 將組態檔案讀入為字串,並將其反序列化為Config結構體。 - 然後,我們可以使用
Config結構體中的引數來生成模擬資料。
組態檔案與命令列引數的整合應用
在開發機器學習模型時,組態檔案的管理與命令列引數的靈活運用至關重要。本章節將介紹如何結合TOML組態檔案與Clap函式庫來實作動態組態引數,並透過Plotters函式庫進行資料視覺化。
使用TOML組態檔案
首先,我們定義一個Config結構體來反序列化TOML組態檔案:
#[derive(Deserialize)]
struct Config {
// 組態欄位定義
centroids: Vec<f64>,
samples_per_centroid: usize,
noise: f64,
}
fn main() -> Result<(), Box<dyn Error>> {
let toml_config_str = read_to_string("config/generate.toml")?;
let config: Config = toml::from_str(&toml_config_str)?;
// 使用組態引數
let centroids = Array2::from_shape_vec((3, 2), config.centroids.to_vec())?;
let samples = generate_data(¢roids, config.samples_per_centroid, config.noise)?;
// ...
}
內容解密:
Config結構體使用Deserialize衍生宏來實作TOML反序列化。main函式讀取TOML組態檔案並將其解析為Config例項。- 使用組態引數取代硬編碼的常數,使程式更具彈性。
動態設定組態檔案路徑
為了提高程式的靈活性,我們使用Clap函式庫來解析命令列引數:
use clap::Parser;
#[derive(Parser)]
struct Args {
#[arg(short = 'c', long = "config-file")]
/// 組態檔案路徑
config_file_path: std::path::PathBuf,
}
fn main() -> Result<(), Box<dyn Error>> {
let args = Args::parse();
let toml_config_str = read_to_string(args.config_file_path)?;
// ...
}
內容解密:
- 使用Clap的
Parser衍生宏來定義命令列引數結構體Args。 config_file_path欄位對應到--config-file命令列選項。- 在
main函式中解析命令列引數並讀取指定的組態檔案。
資料視覺化
使用Plotters函式庫來視覺化生成的資料:
use plotters::prelude::*;
fn main() -> Result<(), Box<dyn Error>> {
let mut x: Vec<f64> = Vec::new();
let mut y: Vec<f64> = Vec::new();
// 從標準輸入讀取CSV資料
let mut reader = csv::Reader::from_reader(io::stdin());
for result in reader.records() {
let record = result?;
x.push(record[0].parse()?);
y.push(record[1].parse()?);
}
// 建立繪圖區域
let root_drawing_area = BitMapBackend::new("plot.png", (900, 600)).into_drawing_area();
root_drawing_area.fill(&WHITE)?;
// 建立圖表並繪製資料
let mut chart = ChartBuilder::on(&root_drawing_area)
.build_cartesian_2d(15.0..45.0, 30.0..55.0)?;
chart.configure_mesh().disable_mesh().draw()?;
chart.draw_series(
x.into_iter()
.zip(y)
.map(|point| Cross::new(point, 3, Into::<ShapeStyle>::into(&BLUE).stroke_width(2))),
)?;
Ok(())
}
內容解密:
- 從標準輸入讀取CSV資料並儲存在
x和y向量中。 - 使用Plotters建立一個PNG圖片作為繪圖區域。
- 建立一個2D笛卡爾座標系圖表並繪製資料點。
- 將圖表儲存為PNG檔案。
命令列操作示範
- 生成資料:
cargo run --bin generate -- --config-file config/generate.toml > training_data.csv - 繪製資料:
cat training_data.csv | cargo run --bin plot
這種方法結合了組態檔案的靈活性、命令列引數的動態性以及資料視覺化的直觀性,為機器學習模型的開發和除錯提供了有力的支援。
K-Means 叢集分析在貓身體測量資料上的應用
在前面的章節中,我們已經成功地使用 Plotters 繪製了貓身體測量的散佈圖。從圖中可以明顯看出,這些資料點形成了三個叢集。接下來,我們將使用 K-Means 演算法對這些資料進行叢集分析。
設定 K-Means 模型
K-Means 是一種無監督學習演算法,用於將資料分成 K 個叢集。在我們的例子中,K 等於 3,因為從散佈圖中可以看出資料形成了三個叢集。K-Means 模型的實作位於 linfa_clustering 套件中。
首先,我們需要在 Cargo.toml 中加入必要的相依套件:
[dependencies]
linfa-clustering = "0.5.0"
linfa-nn = "0.5.0"
ndarray = "0.15.3"
K-Means 模型的訓練與預測
K-Means 模型的訓練與預測過程如下所示:
use std::error::Error;
use linfa::DatasetBase;
use linfa::traits::Fit;
use linfa::traits::Predict;
use linfa_clustering::KMeans;
use linfa_nn::distance::L2Dist;
use rand::thread_rng;
const CLUSTER_COUNT: usize = 3;
fn main() -> Result<(), Box<dyn Error>> {
let samples = read_data_from_stdin()?;
let training_data = DatasetBase::from(samples);
let rng = thread_rng();
let model = KMeans::params_with(CLUSTER_COUNT, rng, L2Dist)
.max_n_iterations(200)
.tolerance(1e-5)
.fit(&training_data)?;
let dataset = model.predict(training_data);
let DatasetBase { records, targets, .. } = dataset;
export_result_to_stdout(records, targets)?;
Ok(())
}
程式碼解析:
- 讀取資料:使用
read_data_from_stdin函式從標準輸入讀取 CSV 格式的資料。 - 轉換資料:將讀取到的資料轉換成
DatasetBase格式,以便進行 K-Means 叢集分析。 - 初始化 K-Means 模型:使用
KMeans::params_with初始化 K-Means 模型,設定叢集數量為 3,並指定亂數產生器和距離計算方法。 - 訓練模型:呼叫
fit方法對模型進行訓練,傳入訓練資料。 - 進行預測:使用訓練好的模型對資料進行叢集分析,得到每個資料點所屬的叢集標籤。
- 輸出結果:將叢集分析的結果輸出到標準輸出。
資料讀取與輸出函式
read_data_from_stdin 和 export_result_to_stdout 函式的實作與之前章節中的實作類別似,分別用於讀取 CSV 資料和輸出叢集分析結果。
read_data_from_stdin 函式實作
use csv::ReaderBuilder;
use ndarray::Array2;
fn read_data_from_stdin() -> Result<Array2<f64>, Box<dyn Error>> {
let mut reader = ReaderBuilder::new().from_reader(io::stdin());
let mut records = Vec::new();
for result in reader.records() {
let record = result?;
let x: f64 = record[0].parse()?;
let y: f64 = record[1].parse()?;
records.push(vec![x, y]);
}
let array = Array2::from_shape_vec((records.len(), 2), records.into_iter().flatten().collect())?;
Ok(array)
}
程式碼解析:
- 建立 CSV 讀取器:使用
ReaderBuilder建立一個 CSV 讀取器,從標準輸入讀取資料。 - 逐行讀取資料:使用
records方法逐行讀取 CSV 資料,並將每行資料解析為f64型別的數值。 - 儲存資料:將解析好的資料儲存到一個向量中。
- 轉換為 Array2:將向量中的資料轉換為
Array2格式,以便進行後續的叢集分析。
export_result_to_stdout 函式實作
use csv::WriterBuilder;
fn export_result_to_stdout(records: Array2<f64>, targets: Array1<usize>) -> Result<(), Box<dyn Error>> {
let mut writer = WriterBuilder::new().from_writer(io::stdout());
for i in 0..records.nrows() {
writer.write_record(&[
records.row(i)[0].to_string(),
records.row(i)[1].to_string(),
targets[i].to_string(),
])?;
}
writer.flush()?;
Ok(())
}
程式碼解析:
- 建立 CSV 輸出器:使用
WriterBuilder建立一個 CSV 輸出器,將結果輸出到標準輸出。 - 逐行輸出結果:使用
write_record方法逐行輸出叢集分析的結果,包括原始資料和所屬叢集的標籤。 - 清空輸出緩衝區:呼叫
flush方法清空輸出緩衝區,確保結果被正確輸出。