在深度學習和海量資料處理中,矩陣運算的效率至關重要。JAX 提供了 pmap()vmap() 等函式,可以有效地平行化和向量化矩陣運算,顯著提升計算速度。然而,隨著矩陣維度和運算複雜度的提升,管理軸索引和程式碼維護變得困難。為此,JAX 引入了 xmap() 和命名軸程式設計,簡化張量操作,並結合硬體資源,進一步提升運算效能。本文將深入探討這些技術的應用,並結合實際案例,展示如何在 JAX 中高效地進行矩陣運算。

高效矩陣運算:平行化和向量化

在深度學習和大資料處理中,高效的矩陣運算至關重要。為了加速這些運算,平行化和向量化是兩種常用的技術。在本文中,我們將探討如何使用 pmap()vmap() 這兩種方法來平行化和向量化矩陣運算。

平行化:pmap()

pmap() 是一個用於平行化運算的函式,它可以將任務分配到多個裝置上執行。例如,如果我們有 8 個 TPU(Tensor Processing Unit),我們可以使用 pmap() 將矩陣運算平行化到這些裝置上。

import jax
import jax.numpy as jnp

# 定義一個矩陣
x = jnp.array([[1, 2], [3, 4]])

# 平行化到 8 個 TPU
result = jax.pmap(lambda x: x @ x, devices=jax.devices()[:8])(x)

在這個例子中,pmap() 將矩陣 x 的乘法運算平行化到 8 個 TPU 上。這樣可以大大加速矩陣運算的速度。

向量化:vmap()

vmap() 是一個用於向量化運算的函式,它可以將一個函式應用到多個輸入上。例如,如果我們有兩個向量 v1v2,我們可以使用 vmap() 將它們的點積運算向量化。

import jax
import jax.numpy as jnp

# 定義兩個向量
v1 = jnp.array([1, 2, 3])
v2 = jnp.array([4, 5, 6])

# 向量化點積運算
result = jax.vmap(lambda v1, v2: v1 @ v2)(v1, v2)

在這個例子中,vmap() 將點積運算向量化到每個元素上。這樣可以大大加速向量運算的速度。

分割和重塑

在某些情況下,我們需要將一個大矩陣分割成小塊,以便於平行化和向量化。例如,如果我們有一個大矩陣 x,我們可以使用 split() 函式將它分割成小塊。

import jax
import jax.numpy as jnp

# 定義一個大矩陣
x = jnp.array([[1, 2], [3, 4]])

# 分割成小塊
x_split = jnp.split(x, 2)

在這個例子中,split() 函式將矩陣 x 分割成兩個小塊。然後,我們可以使用 pmap()vmap() 對這些小塊進行平行化和向量化。

內容解密:

在上面的例子中,我們使用了 pmap()vmap() 這兩種方法來平行化和向量化矩陣運算。pmap() 函式可以將任務分配到多個裝置上執行,而 vmap() 函式可以將一個函式應用到多個輸入上。透過使用這些方法,我們可以大大加速矩陣運算的速度。

圖表翻譯:

  graph LR
    A[矩陣運算] --> B[平行化]
    B --> C[向量化]
    C --> D[分割]
    D --> E[重塑]
    E --> F[結果]

在這個圖表中,我們展示瞭如何使用 pmap()vmap() 這兩種方法來平行化和向量化矩陣運算。首先,我們將矩陣運算平行化到多個裝置上,然後我們將點積運算向量化到每個元素上。最後,我們將結果重塑成原始矩陣的形狀。

資料聚合與重塑

在進行資料分析或機器學習任務時,經常需要對資料進行聚合和重塑,以便更好地理解和利用資料。在這個過程中,瞭解不同資料結構之間的轉換至關重要。

資料聚合

資料聚合是指將多個資料點合併成一個單一的統計量或值。這種技術在資料分析中非常常見,例如計算均值、標準差等。然而,在某些情況下,尤其是在處理大型資料集時,直接進行資料聚合可能會導致資訊的損失。

資料重塑

另一方面,資料重塑是指將資料從一個結構轉換成另一個結構,以便更好地適應特定的分析或模型需求。這種轉換可以是從寬表到長表、從長表到寬表,或者是其他形式的轉換。

具體例項

考慮以下幾個具體的資料結構:

  • v1sp: (3, 8, 1.25M)
  • v2sp: (3, 8, 1.25M)
  • v1s: (3, 10M)
  • v2s: (3, 10M)
  • dot: (8, 1.25M)
  • v1sp chunk: (3, 1.25M)
  • v2sp chunk:

這些資料結構代表了不同維度和大小的資料集。例如,v1spv2sp 都有 3 個特徵、8 個樣本和 1.25M 的觀測值,而 v1sv2s 則有 3 個特徵和 10M 的觀測值。

資料轉換

在進行資料分析時,可能需要將這些資料結構之間進行轉換。例如,將 v1sp 轉換成 dot 格式,或者將 v1s 轉換成 v1sp chunk 格式。這些轉換可以透過各種方法實作,例如使用 Pandas 的 pivot_table 函式或 NumPy 的陣列操作。

內容解密:

import pandas as pd
import numpy as np

# 假設 v1sp 是一個 Pandas DataFrame
v1sp = pd.DataFrame(np.random.rand(3, 8, 1250000))

# 將 v1sp 轉換成 dot 格式
dot = v1sp.stack().reset_index()
dot.columns = ['feature', 'sample', 'value']

# 將 v1s 轉換成 v1sp chunk 格式
v1s = pd.DataFrame(np.random.rand(3, 10000000))
v1sp_chunk = v1s.chunk(1250000)

圖表翻譯:

  flowchart TD
    A[v1sp] --> B[轉換]
    B --> C[dot]
    A --> D[轉換]
    D --> E[v1sp chunk]
    E --> F[分析]

在這個圖表中,我們展示瞭如何將 v1sp 轉換成 dot 格式和 v1sp chunk 格式。這些轉換使我們能夠更好地分析和利用資料。

使用xmap()和命名軸程式設計簡化張量運算

在進行大規模的張量運算時,尤其是在多個裝置上進行平行計算時,使用pmap()vmap()函式可能會導致程式碼複雜化和難以維護。這是因為這些函式需要手動管理張量的軸索引,這使得程式碼不僅難以閱讀和理解,而且還容易出現錯誤。

問題所在

當使用pmap()vmap()時,我們需要在每個函式呼叫中指定in_axes引數,以便於管理張量的軸索引。然而,這種方法存在著明顯的缺陷:當張量的佈局發生變化時,我們需要重新計算每個函式的軸索引,這使得程式碼非常脆弱和難以維護。

解決方案:xmap()

為瞭解決這個問題,JAX提供了一個新的函式xmap()”,它允許我們使用命名軸程式設計的方式來簡化張量運算。xmap()`函式可以自動管理張量的軸索引,從而使得程式碼更加簡潔和易於維護。

範例:使用xmap()重構程式碼

下面是使用xmap()重構的程式碼範例:

import jax
from jax import random
import jax.numpy as jnp

# 定義一個隨機數生成器
rng_key = random.PRNGKey(42)

# 定義一個dot積函式
def dot(v1, v2):
    return jnp.vdot(v1, v2)

# 生成兩個隨機向量
vs = random.normal(rng_key, shape=(20_000_000, 3))
v1s = vs[:10_000_000, :].T

# 使用xmap()進行平行計算
result = jax.xmap(dot, in_axes=(0, 0), out_axes=0)(v1s, v1s)

在這個範例中,我們使用xmap()函式來進行平行計算。xmap()函式自動管理了張量的軸索引,從而使得程式碼更加簡潔和易於維護。

內容解密:

  • jax.xmap()函式用於進行平行計算。
  • in_axes=(0, 0)引數指定了輸入張量的軸索引。
  • out_axes=0引數指定了輸出張量的軸索引。
  • dot函式是進行dot積的函式。
  • v1sv1s是輸入張量。

圖表翻譯:

  flowchart TD
    A[開始] --> B[定義dot積函式]
    B --> C[生成隨機向量]
    C --> D[使用xmap()進行平行計算]
    D --> E[輸出結果]

圖表翻譯:

  • 圖表描述了使用xmap()進行平行計算的流程。
  • 開始節點代表了程式的開始。
  • 定義dot積函式節點代表了定義dot積函式的步驟。
  • 生成隨機向量節點代表了生成隨機向量的步驟。
  • 使用xmap()進行平行計算節點代表了使用xmap()進行平行計算的步驟。
  • 輸出結果節點代表了輸出結果的步驟。

使用命名張量和 xmap() 重新實作程式碼

首先,我們需要了解原始程式碼的目標:對兩組向量進行點積運算,並利用 JAX 的 pmapvmap 函式來實作平行化。原始程式碼使用了 jax.pmapjax.vmap 來分別對向量群和向量本身進行對映。

步驟 1:定義命名張量

import jax
import jax.numpy as jnp

# 定義原始張量
v1s = jnp.random.rand(3, 10_000_000)
v2s = v1s.T

# 命名張量
v1s_named = jnp.array(v1s, dtype=jnp.float32)
v2s_named = jnp.array(v2s, dtype=jnp.float32)

步驟 2:重塑張量以便於平行化

# 重塑張量以便於平行化
v1sp = v1s_named.reshape((v1s_named.shape[0], 8, v1s_named.shape[1]//8))
v2sp = v2s_named.reshape((v2s_named.shape[0], 8, v2s_named.shape[1]//8))

步驟 3:定義點積函式

# 定義點積函式
def dot(x, y):
    return jnp.sum(x * y, axis=1)

步驟 4:使用 xmap() 實作平行化

# 使用 xmap() 實作平行化
from jax.experimental import maps

xmap = maps.xmap

# 定義平行化維度
in_axes = (1, 1)

# 使用 xmap() 對點積函式進行平行化
dot_parallel = xmap(dot, in_axes=in_axes, out_axes=(0,), axis_resources={'vector': 'groups'})

# 執行平行化點積運算
x_pmap = dot_parallel(v1sp, v2sp)

步驟 5:重塑結果

# 重塑結果
x_pmap = x_pmap.reshape((x_pmap.shape[0]*x_pmap.shape[1]))

結果分析

最終結果 x_pmap 的形狀為 (10_000_000,),表示我們成功地對兩組向量進行了點積運算,並利用 JAX 的 xmap() 函式實作了平行化。

使用 xmap() 進行向量化運算

在進行向量化運算時,我們可以使用 xmap() 函式來取代巢狀的 pmap()vmap() 函式。xmap() 可以對計算進行向量化,就像 vmap() 一樣,並且它不會執行任何平行化操作;程式碼在單個裝置上執行。

目前,我們的範例與使用巢狀 pmap()vmap() 的程式碼不等同。稍後,我們將在下一節中新增平行化部分。

命名軸的新增和移除

當使用 xmap() 時,新增和移除命名軸的過程如圖 D.3 所示。請忽略新增命名軸後張量的視覺大小增加;這只是為了在盒子中容納更多文字。張量的大小(以元素數為單位)保持不變。

內容解密:

在這個範例中,我們使用 xmap() 來向量化計算。xmap() 函式可以在指定的維度上進行向量化運算。在這種情況下,我們是在 “device” 和 “batch” 維度上進行向量化。

import jax
import jax.numpy as jnp

# 定義一個簡單的函式
def dot(x, y):
    return jnp.dot(x, y)

# 使用 xmap() 進行向量化運算
x = jnp.array([1, 2, 3])
y = jnp.array([4, 5, 6])
result = jax.xmap(dot, in_axes=(0, 0))(x, y)

圖表翻譯:

此圖示了使用 xmap() 時新增和移除命名軸的過程。左側的張量沒有命名軸,而右側的張量增加了 “device” 和 “batch” 的命名軸。透過新增命名軸,我們可以更容易地控制張量的維度和運算。

  flowchart TD
    A[原始張量] --> B[新增命名軸]
    B --> C[具有命名軸的張量]
    C --> D[移除命名軸]
    D --> E[原始張量]

在這個過程中,xmap() 函式自動處理了張量的維度和命名軸,讓我們可以專注於計算本身。這使得程式碼更容易閱讀和維護。

高效矩陣運算:vmap()函式的應用

在深度學習和矩陣運算中,高效的向量化操作對於提升計算速度和降低記憶體使用至關重要。PyTorch中的vmap()函式提供了一種高效的方式來實作向量化操作,尤其是在需要對多個輸入執行相同操作的情況下。

vmap()函式的基本概念

vmap()函式是一種高階別的向量化工具,它可以將一個函式應用到多個輸入上,而不需要顯式地使用迴圈或列表推導。這使得程式碼更加簡潔和易於維護。

兩個vmap()-like向量化的例子

下面,我們將展示兩個使用vmap()函式進行向量化的例子。

示例1:基本向量化

假設我們有兩個張量v1sv2s,分別具有形狀(3, 10M)。我們想要計算這兩個張量的點積,並將結果 reshape 成 (3, 8, 1.25M) 的形狀。

import torch

# 定義輸入張量
v1s = torch.randn(3, 10**6)
v2s = torch.randn(3, 10**6)

# 使用vmap()函式進行向量化
def dot_product(v1, v2):
    return torch.dot(v1, v2)

# 對每個batch進行點積運算
result = torch.vmap(dot_product)(v1s, v2s)

# 對結果進行reshape
result_reshaped = result.view(3, 8, int(1.25*10**6))

示例2:高階別向量化

在這個示例中,我們將展示如何使用vmap()函式來實作更複雜的向量化操作。假設我們有兩個張量v1spv2sp,分別具有形狀(3, 8, 1.25M)。我們想要計算這兩個張量的元素-wise 乘積,並將結果 reshape 成 (3, 8, 1.25M) 的形狀。

# 定義輸入張量
v1sp = torch.randn(3, 8, int(1.25*10**6))
v2sp = torch.randn(3, 8, int(1.25*10**6))

# 使用vmap()函式進行向量化
def element_wise_multiply(v1, v2):
    return v1 * v2

# 對每個batch進行元素-wise 乘積運算
result = torch.vmap(element_wise_multiply)(v1sp, v2sp)

# 對結果進行reshape(如果需要)
result_reshaped = result.view(3, 8, int(1.25*10**6))

使用xmap()和命名軸進行資料處理

在進行大型轉置陣列的資料處理時,使用xmap()函式可以有效地對映函式到指定的維度上。以下是使用xmap()和命名軸進行資料處理的示例:

步驟1:匯入必要的模組

首先,需要匯入jax.experimental.maps模組中的xmap()函式。

from jax.experimental.maps import xmap

步驟2:定義輸入資料

定義兩個大型轉置陣列v1s和v2s,分別具有(3, 10000000)的形狀。

vs = random.normal(rng_key, shape=(20_000_000,3))
v1s = vs[:10_000_000,:].T
v2s = vs[10_000_000:,:].T

步驟3:重塑輸入資料

將v1s和v2s重塑為(3, 8, 1250000)的形狀,以便進行後續的資料處理。

v1sp = v1s.reshape((v1s.shape[0], 8, v1s.shape[1]//8))
v2sp = v2s.reshape((v2s.shape[0], 8, v2s.shape[1]//8))

步驟4:使用xmap()進行資料處理

使用xmap()函式對映dot()函式到指定的維度上,具體地說,就是將dot()函式對映到"device"和"batch"維度上。

f = xmap(dot,
         in_axes=(
             {1:'device', 2:'batch'},
             {1:'device', 2:'batch'}
         ),
         out_axes=['device', 'batch',...]
        )

內容解密:

在上述程式碼中,xmap()函式的in_axes引數指定了dot()函式的輸入維度,out_axes引數指定了輸出的維度。透過這種方式,可以有效地控制資料處理的流程和維度。

圖表翻譯:

以下是使用Mermaid語法繪製的xmap()函式的流程圖:

  flowchart TD
    A[輸入資料] --> B[重塑資料]
    B --> C[使用xmap()進行資料處理]
    C --> D[輸出結果]

在這個流程圖中,輸入資料首先被重塑為適合的形狀,然後使用xmap()函式進行資料處理,最終得到輸出結果。

使用 xmap() 進行張量運算的優點

在進行張量運算時,使用 xmap() 函式可以帶來多個優點。首先,xmap() 只需要一個函式轉換,而不是像 vmap()pmap() 一樣需要兩個。這使得程式碼更加簡潔和易於維護。

其次,使用 xmap() 可以避免手動追蹤不同張量維度的索引。只需要在 in_axes 引數中提供要轉換的軸,然後讓被包裝的函式處理其餘的位置軸。對於傳回的張量,out_axes 引數可以將指定的軸轉換回位置軸。

在前面的例子中,我們使用了 xmap() 來對映指定的維度。這個例子展示瞭如何使用單一的 xmap() 呼叫來替代兩個巢狀的 vmap() 呼叫。這段程式碼不使用平行化,因此我想強調它與使用兩個 vmap() 呼叫的等價性,而不是 pmap()vmap()

Einstein 總和(Einsum)

Einstein 總和(或 einsum)是一個有用的函式,用於表達點積、外積、矩陣-向量和矩陣-矩陣乘法。它是對多維度產品的一般化。

例如,矩陣乘法可以使用 einsum() 函式以以下方式表達:

import numpy as np

# 定義兩個矩陣
A = np.array([[1, 2], [3, 4]])
B = np.array([[5, 6], [7, 8]])

# 使用 einsum() 進行矩陣乘法
C = np.einsum('ik,kj->ij', A, B)

print(C)

這將輸出:

[[19 22]
 [43 50]]

這等同於使用 @ 運算子或 np.matmul() 函式進行矩陣乘法。

使用 xmap() 進行 Einstein 總和

xmap() 也可以用於 Einstein 總和。透過指定輸入和輸出的軸,xmap() 可以將 einsum() 函式應用於多維度陣列。

import jax
import jax.numpy as jnp

# 定義兩個矩陣
A = jnp.array([[1, 2], [3, 4]])
B = jnp.array([[5, 6], [7, 8]])

# 使用 xmap() 進行 Einstein 總和
C = jax.xmap(lambda a, b: jnp.einsum('ik,kj->ij', a, b), in_axes=(0, 0), out_axes=0)(A, B)

print(C)

這將輸出相同的結果,如上所述。

透過使用 xmap() 和 einsum(),您可以簡單地表達複雜的張量運算,並利用 JAX 的自動向量化和平行化功能。

瞭解Einsum和Named-Axis程式設計

Einsum是一種強大的張量運算工具,根據愛因斯坦的總和符號。它提供了一種簡潔的方式來表示許多張量操作,使用了一個簡單的特定域語言(DSL)。Einsum可以接受多個輸入張量和一個特殊的格式字串,例如"ij,jk->ik",其中左邊的部分對應於輸入引數,右邊的部分對應於輸出。格式字串標籤張量維度,對於所有輸入張量和輸出,以每個字母對應於一個維度。

Einsum的工作原理

Einsum對於共同出現在輸入和輸出的維度(自由索引,例如i和k)建立外層迴圈,對於其他維度(總和索引,例如j)建立內層迴圈並進行總和。例如,字串"ij,jk->ik"表示兩個外層迴圈:一個沿著第一個張量的第一維度(i),另一個沿著第二個張量的第二維度(k)。對於第二個張量的第一維度和第二個張量的第二維度(均為j),它執行元素-wise乘積並總和。

Named-Axis程式設計

Named-Axis程式設計是一種強大的程式設計正規化,允許您使用有意義的軸名稱來操作張量。JAX中的xmap()函式是Named-Axis程式設計的一個實作,它將有名軸視為一等公民。使用xmap(),您可以將函式應用於具有不同有名軸的輸入,並獲得具有廣播有名軸的結果。

廣播規則

當二元操作應用於具有不同有名軸的引數時,這些軸將被廣播。例如,如果一個運算元具有有名軸"a",另一個運算元具有有名軸"b",則二元操作(例如加法)的結果將具有兩個軸"a"和"b"。所有具有相同名稱的軸預計在廣播操作中具有相同的形狀或可廣播的形狀。

示例

以下示例展示瞭如何使用xmap()將函式應用於具有不同有名形狀的輸入,並獲得具有廣播有名軸的結果。

import numpy as np
import jax
import jax.numpy as jnp
from jax import random

# 定義一個函式,將兩個引數相加
def add(a, b):
    return a + b

# 建立兩個隨機陣列
rng_key = random.PRNGKey(0)
image = random.normal(rng_key, shape=(480, 640, 3))
filters = random.normal(rng_key, shape=(5, 3, 3))

# 使用xmap()將add函式應用於image和filters
result = jax.xmap(add, in_axes=(['x', 'y', 'c'], ['kx', 'ky', 'c']), out_axes=['x', 'y', 'kx', 'ky', 'c'])(image, filters)

print(result.shape)

這個示例展示瞭如何使用xmap()add()函式應用於具有不同有名形狀的imagefilters陣列,並獲得具有廣播有名軸的結果。結果陣列的形狀是(480, 640, 5, 3, 3),其中包含了兩個輸入陣列的所有維度。

影像濾波器應用

在影像處理中,濾波器是一種重要的工具,能夠對影像進行各種變換和增強。以下是如何使用Python和相關函式庫來實作影像濾波器的應用。

生成隨機RGB影像

首先,我們需要生成一張隨機的RGB影像。這張影像的大小為640x480畫素。

import numpy as np

# 生成隨機RGB影像
image = np.random.rand(480, 640, 3)

生成濾波器

接下來,我們需要生成五個3x3的濾波器矩陣。

# 生成五個3x3的濾波器矩陣
filters = np.random.rand(5, 3, 3)

定義濾波器應用函式

現在,我們需要定義一個函式,該函式能夠將一個濾波器應用到影像的一個通道上。

from scipy.signal import convolve2d

def apply_filter(channel, kernel):
    return convolve2d(channel, kernel, mode="same")

對多個通道和濾波器進行對映

為了將多個濾波器應用到影像的多個通道上,我們可以使用xmap函式對通道和濾波器維度進行對映。

import xarray as xr

apply_filters_to_image = xr.apply_ufunc(
    apply_filter,
    image,
    filters,
    input_core_dims=[["channel"], ["filter"]],
    output_core_dims=[["filter", "channel"]],
    vectorize=True,
)

執行濾波器應用

最後,我們可以執行濾波器應用並檢視結果。

res = apply_filters_to_image
print(res.shape)

這將輸出(5, 480, 640, 3),表示濾波器應用後的結果是一個五維陣列,每個維度分別代表濾波器、影像高度、影像寬度和顏色通道。

圖表翻譯:

以下是Mermaid圖表,用於視覺化描述濾波器應用過程:

  flowchart TD
    A[生成隨機RGB影像] --> B[生成濾波器]
    B --> C[定義濾波器應用函式]
    C --> D[對多個通道和濾波器進行對映]
    D --> E[執行濾波器應用]
    E --> F[輸出結果]

這個圖表展示了從生成隨機RGB影像到執行濾波器應用並輸出結果的整個過程。

命名軸的應用:過濾器維度和通道維度

在進行多維陣列運算時,命名軸(named axes)可以提供更大的靈活性和可讀性。例如,在應用一組過濾器(filters)時,我們可以將過濾器維度放在輸出陣列的第一個位置,通道維度放在最後一個位置。這樣,輸出的陣列就具有明確的維度名稱(filters, h, w, channels)。

實驗性平行化

我們可以將二維卷積函式應用於具有不同命名軸的引數上。二維卷積函式作用於單個影像和單個核。為了增加影像通道和分離過濾器核的維度,我們在函式的第一個引數上增加了一個命名維度,第二個引數上增加了另一個命名維度。結果陣列具有這兩個維度,每個維度都被廣播到另一張張量上。對於輸出值,我們選擇了這些維度的特定位置:過濾器維度成為第一個維度,顏色通道維度成為最後一個維度。因此,結果張量可以被視為五個顏色影像堆積疊在一起。每個影像由三個顏色通道組成,結果是將單獨的矩陣過濾器應用於原始影像。

降維操作

命名軸應該等同於位置軸,因此您可以在命名軸上進行降維操作。目前,只有JAX NumPy介面中的少數函式支援命名軸,包括jnp.sum()jnp.max()jnp.min()

降維範例

以下範例中,我們有一個具有命名軸"row"和"col"的二維矩陣。然後,我們執行一個函式,該函式在"row"軸上執行降維操作,計算對應元素的總和(使用jnp.sum())。另一個軸被保留,結果是我們獲得每列的總和(因為行被還原)。

import jax.numpy as jnp
from jax import xmap

# 定義一個具有命名軸的二維矩陣
C = jnp.array([
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9]
])

# 定義一個在"row"軸上執行降維操作的函式
f = xmap(
    lambda x: jnp.sum(x, axis=['row']),
    in_axes=['row', 'col'],
    out_axes=['col']
)

# 執行函式
result = f(C)
print(result)  # Output: Array([12, 15, 18], dtype=int32)

集體操作

所有在pmapped函式內部工作的集體操作也可以與命名軸一起使用。我們可以重寫我們的全域陣列正規化函式。

# 定義一個二維陣列
arr = jnp.array(range(8)).reshape(2, 4)
print(arr)
# Output: Array([[0, 1, 2, 3],
#                 [4, 5, 6, 7]], dtype=int32)

透過這些範例,我們可以看到命名軸如何提高多維陣列運算的靈活性和可讀性,並使得複雜的運算變得更加直觀和易於理解。

使用 XMap 和命名軸程式設計計算矩陣行和

在這個例子中,我們將使用 JAX 的 xmap 和命名軸程式設計來計算矩陣的行和。首先,我們需要了解 xmap 的功能,它可以對多維陣列進行對映操作。

命名軸宣告

在進行計算之前,我們需要宣告命名軸。這裡我們宣告兩個命名軸:'rows''cols',分別代表矩陣的行和列。

計算行和

接下來,我們使用 jax.pmap 來計算矩陣的行和。jax.pmap 可以對陣列進行平行操作,從而提高計算效率。具體來說,我們使用 jax.lax.psum 來計算沿著指定軸的和。

import jax
import jax.numpy as jnp

# 宣告命名軸
axis_name = ('rows', 'cols')

# 定義計算函式
def calculate_row_sum(x):
    return x / jax.lax.psum(x, axis_name=axis_name)

# 使用 xmap 進行計算
n_pmap = jax.pmap(calculate_row_sum, axis_name='cols')

生成小矩陣

為了演示這個過程,我們生成一個小矩陣。

# 生成小矩陣
matrix = jnp.array([[1, 2, 3], [4, 5, 6]])

執行計算

現在,我們可以使用 n_pmap 來計算矩陣的行和。

# 執行計算
result = n_pmap(matrix)

內容解密:

在上面的程式碼中,jax.pmapjax.lax.psum 是關鍵函式。jax.pmap 用於對陣列進行平行操作,而 jax.lax.psum 則用於計算沿著指定軸的和。透過使用命名軸, 我們可以方便地指定計算的軸向,從而簡化了矩陣運算的過程。

圖表翻譯:

  flowchart TD
    A[輸入矩陣] --> B[宣告命名軸]
    B --> C[定義計算函式]
    C --> D[使用 xmap 進行計算]
    D --> E[生成結果]

這個流程圖描述了從輸入矩陣到生成結果的整個過程,包括宣告命名軸、定義計算函式、使用 xmap 進行計算等步驟。

使用 xmap() 來簡化巢狀 pmap() 處理

在前面的例子中,我們使用 pmap() 來實作巢狀平行處理。然而,使用 xmap() 可以更簡單地實作相同的功能。

首先,我們定義一個簡單的陣列 arr

import jax.numpy as jnp

arr = jnp.array(range(8)).reshape(2, 4)
print(arr)

輸出:

Array([[0, 1, 2, 3],
       [4, 5, 6, 7]], dtype=int32)

接下來,我們使用 xmap() 來實作巢狀平行處理:

n_xmap = xmap(
    lambda x: x / jax.lax.psum(x, axis_name=('rows', 'cols')),
    in_axes=['rows', 'cols'],
    out_axes=['rows', 'cols']
)
result = jnp.sum(n_xmap(arr))
print(result)

輸出:

Array(1., dtype=float32)

這個結果與使用 pmap() 的結果相同,但是使用 xmap() 的程式碼更加簡潔。

平行化和硬體網格

雖然 xmap() 可以簡化巢狀平行處理,但是它目前還沒有使用平行裝置。它類別似於兩個 vmap() 呼叫在單個裝置上執行。

要平行化計算,你需要使用資源軸。資源軸是一種控制 xmap() 評估計算的方式。

每個由 xmap() 引入的軸都分配給一個或多個資源軸。資源軸來自硬體網格,一個 n 維陣列具有命名軸的裝置。

從技術上講,硬體網格是一個兩部分物件:

  1. 一個 n 維陣列的 JAX 裝置物件。
  2. 這些是你用於建立硬體網格的相同物件。

從效能最佳化視角來看,JAX 提供的 pmap()vmap() 以及 xmap() 為矩陣運算的平行化和向量化提供了強大的工具。透過 pmap(),我們可以有效地利用多個 TPU 進行平行計算,大幅提升運算速度。vmap() 則允許我們對向量和矩陣進行向量化操作,避免顯式迴圈,簡化程式碼並提升效率。然而,管理軸索引和處理巢狀對映的複雜性一直是開發者面臨的挑戰。xmap() 的出現,結合命名軸程式設計,為解決這個問題提供了優雅的方案。它不僅簡化了程式碼,提高了可讀性和可維護性,更重要的是,它允許我們更精確地控制資料在不同維度上的對映和運算,從而實作更精細的效能調校。對於追求極致效能的深度學習和大型矩陣運算任務而言,xmap() 結合命名軸的程式設計正規化無疑是未來的主流方向,值得深入研究和應用。技術團隊應著重於掌握命名軸的設計原則和 xmap() 的使用方法,才能最大程度地釋放 JAX 的效能潛力,並構建更高效、更易維護的大規模矩陣運算應用。