JAX 是 Google 和 NVIDIA 聯合開發的高性能數值計算庫,這兩年 JAX 生態快速發展,周邊工具鏈也日益完善了。如果你用過 NumPy 或 PyTorch,但還沒接觸過 JAX,這篇文章能幫助你快速上手。
圍繞 JAX 已經涌現出一批好用的庫:Flax 用來搭神經網絡,Optax 處理梯度和優化,Equinox 提供類似 PyTorch 的接口,Haiku 則是簡潔的函數式 API,Jraph 用於圖神經網絡,RLax 是強化學習工具庫,Chex 提供測試和調試工具,Orbax 負責模型檢查點和持久化。
純函數是硬需求
JAX 對函數有個基本要求:必須是純函數。這意味着函數不能有副作用,對同樣的輸入必須總是返回同樣的輸出。
這個約束來自函數式編程範式。JAX 內部做各種變換(編譯、自動微分等)依賴純函數的特性,用不純的函數可能導致錯誤或靜默失敗,結果完全不對。
# 純函數,沒問題
def pure_addition(a, b):
return a + b
# 不純的函數,JAX 不接受
counter = 0
def impure_addition(a, b):
global counter
counter += 1
return a + b
JAX NumPy 與原生 NumPy
JAX 提供了類 NumPy 的接口,核心優勢在於能自動高效地在 CPU、GPU 甚至 TPU 上運行,支持本地或分佈式執行。這套能力來自 XLA(Accelerated Linear Algebra) 編譯器,它把 JAX 代碼翻譯成針對不同硬件優化的機器碼。
NumPy 默認只在 CPU 上跑,JAX NumPy 則不同。用法上兩者很相似,這也是 JAX 容易上手的原因。
# JAX 也差不多
import jax.numpy as jnp
print(jnp.sqrt(4))# NumPy 的寫法
import numpy as np
print(np.sqrt(4))
# JAX 也差不多
import jax.numpy as jnp
print(jnp.sqrt(4))
常見的操作兩者看起來基本一樣:
import numpy as np
import jax.numpy as jnp
# 創建數組
np_a = np.array([1.0, 2.0, 3.0])
jnp_a = jnp.array([1.0, 2.0, 3.0])
# 元素級操作
print(np_a + 2)
print(jnp_a + 2)
# 廣播
np_b = np.array([[1, 2, 3]])
jnp_b = jnp.array([[1, 2, 3]])
print(np_b + np.arange(3))
print(jnp_b + jnp.arange(3))
# 求和
print(np.sum(np_a))
print(jnp.sum(jnp_a))
# 平均值
print(np.mean(np_a))
print(jnp.mean(jnp_a))
# 點積
print(np.dot(np_a, np_a))
print(jnp.dot(jnp_a, jnp_a))
但有個重要差異需要注意:
JAX 數組是不可變的,對數組的修改操作會返回新數組而不是改變原數組。
NumPy 數組則可以直接修改:
import numpy as np
x = np.array([1, 2, 3])
x[0] = 10 # 直接修改,沒問題
JAX 這邊就不行了:
import jax.numpy as jnp
x = jnp.array([1, 2, 3])
x[0] = 10 # 報錯
但是JAX 提供了專門的 API 來處理這種情況,通過返回一個新數組的方式實現"修改":
z=x.at[idx].set(y)
完整的例子:
x = jnp.array([1, 2, 3])
y = x.at[0].set(10)
print(y) # [10, 2, 3]
print(x) # [1, 2, 3](沒變)
JIT 編譯加速
即時編譯(JIT)是 JAX 裏一個核心特性,通過 XLA 把 Python/JAX 代碼編譯成優化後的機器碼。
直接用 Python 解釋器跑函數會很慢。加上
@jit
裝飾器後,函數會被編譯成快速的原生代碼:
from jax import jit
# 不編譯的版本
def square(x):
return x * x
# 編譯過的版本
@jit
def jit_square(x):
return x * x
jit_square
快好幾個數量級。函數首次調用時,JIT 引擎會:
- 追蹤函數邏輯,構建計算圖
- 把圖編譯成優化的 XLA 代碼
- 緩存編譯結果
- 後續調用直接用緩存的版本
自動微分
JAX 的 grad 模塊能自動計算函數的導數。
import jax.numpy as jnp
from jax import grad
# 定義函數:f(x) = x² + 2x + 2
def f(x):
return x**2 + 2 * x + 2
# 計算導數
df_dx = grad(f)
# 在 x = 2.0 處求值
print(df_dx(2.0)) # 6.0
隨機數處理
NumPy 用全局隨機狀態生成隨機數。每次調用
np.random.random()
時,NumPy 會更新隱藏的內部狀態:
import numpy as np
np.random.random()
# 0.9539264374520571
JAX 的做法完全不同。作為純函數庫,它不能維護全局狀態,所以要求顯式傳入一個偽隨機數生成器(PRNG)密鑰。每次生成隨機數前要先分割密鑰:
from jax import random
# 初始化密鑰
key = random.PRNGKey(0)
# 每次生成前分割
key, subkey = random.split(key)
# 從正態分佈採樣
x = random.normal(subkey, ())
print(x) # -2.4424558
# 從均勻分佈採樣
key, subkey = random.split(key)
u = random.uniform(subkey, (), minval=0.0, maxval=1.0)
print(u) # 0.104290366
一個常見的坑:同一個密鑰生成的隨機數始終相同。
# 用同一個 subkey,結果重複
x = random.normal(subkey, ())
print(x) # -2.4424558
x = random.normal(subkey, ())
print(x) # -2.4424558(還是這個值)
所以要記住總是用新密鑰。
向量化:vmap
vmap 自動把函數轉換成能處理批量數據的版本。邏輯上就像循環遍歷每個樣本,但執行效率遠高於 Python 循環。
import jax.numpy as jnp
from jax import vmap
def f(x):
return x * x + 1
arr = jnp.array([1., 2., 3., 4.])
# Python 循環(慢)
outputs_loop = jnp.array([f(x) for x in arr])
# vmap 版本(快)
f_vectorized = vmap(f)
outputs_vmap = f_vectorized(arr)
並行化:pmap
pmap 不同於 vmap。vmap 在單個設備上做批處理,pmap 把計算分散到多個設備(GPU/TPU 核心),每個設備處理輸入的一部分。
VMAP:單設備批處理向量化
PMAP:跨多設備並行執行
import jax.numpy as jnp
from jax import pmap
# 查看可用設備
print(jax.devices()) # [TpuDevice(id=0), TpuDevice(id=1), ..., TpuDevice(id=7)]
def f(x):
return x * x + 1
arr = jnp.array([1., 2., 3., 4.])
# pmap 會把數組分配到不同設備
ys = pmap(f)(arr)
PyTrees
PyTree 在 JAX 裏是個常見的概念:任何嵌套的 Python 容器(列表、字典、元組等)加上基本類型的組合。JAX 裏用它來組織模型參數、優化器狀態、梯度等。
import jax.numpy as jnp
from jax import tree_util as tu
# 構建 PyTree
pytree = {
"a": jnp.array([1, 2]),
"b": [jnp.array([3, 4]), 5]
}
# 獲取所有葉子節點
leaves = tu.tree_leaves(pytree)
# 對每個葉子應用函數
doubled = tu.tree_map(lambda x: x * 2, pytree)
Optax:梯度處理和優化
Optax 是 JAX 生態裏的優化庫。它包含損失函數、優化器、梯度變換、學習率調度等一套工具。
損失函數:
logits = jnp.array([[2.0, -1.0]])
labels_onehot = jnp.array([[1.0, 0.0]])
labels_int = jnp.array([0])
# Softmax 交叉熵(獨熱編碼)
loss_softmax_onehot = optax.softmax_cross_entropy(logits, labels_onehot).mean()
# Softmax 交叉熵(整數標籤)
loss_softmax_int = optax.softmax_cross_entropy_with_integer_labels(logits, labels_int).mean()
# 二元交叉熵
loss_bce = optax.sigmoid_binary_cross_entropy(logits, labels_onehot).mean()
# L2 損失
loss_l2 = optax.l2_loss(jnp.array([1., 2.]), jnp.array([0., 1.])).mean()
# Huber 損失
loss_huber = optax.huber_loss(jnp.array([1.,2.]), jnp.array([0.,1.])).mean()
優化器:
# SGD
opt_sgd = optax.sgd(learning_rate=1e-2)
# SGD with momentum
opt_momentum = optax.sgd(learning_rate=1e-2, momentum=0.9)
# RMSProp
opt_rmsprop = optax.rmsprop(1e-3)
# Adafactor
opt_adafactor = optax.adafactor(learning_rate=1e-3)
# Adam
opt_adam = optax.adam(1e-3)
# AdamW
opt_adamw = optax.adamw(1e-3, weight_decay=1e-4)
梯度變換:
# 梯度裁剪
tx_clip = optax.clip(1.0)
# 全局梯度範數裁剪
tx_clip_global = optax.clip_by_global_norm(1.0)
# 權重衰減(L2)
tx_weight_decay = optax.add_decayed_weights(1e-4)
# 添加梯度噪聲
tx_noise = optax.add_noise(0.01)
學習率調度:
# 指數衰減
lr_exp = optax.exponential_decay(init_value=1e-3, transition_steps=1000, decay_rate=0.99)
# 餘弦衰減
lr_cos = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=10_000)
# 線性預熱
lr_linear = optax.linear_schedule(init_value=0.0, end_value=1e-3, transition_steps=500)
更新步驟:
# 計算梯度
loss, grads = jax.value_and_grad(loss_fn)(params)
# 生成優化器更新
updates, opt_state = optimizer.update(grads, opt_state)
# 應用更新
params = optax.apply_updates(params, updates)
鏈式組合:
# 把多個操作鏈起來
optimizer = optax.chain(
optax.clip_by_global_norm(1.0), # 梯度裁剪
optax.add_decayed_weights(1e-4), # 權重衰減
optax.adam(1e-3) # Adam 優化
)
Flax 與神經網絡
JAX 本身只是數值計算庫,Flax 在其基礎上提供了神經網絡定義和訓練的高級 API。Flax 代碼風格接近 PyTorch,如果你用過 PyTorch 會很快上手。
Flax 提供了豐富的層和操作。基礎層 包括全連接層
Dense
、卷積
Conv
、嵌入
Embed
、多頭注意力
MultiHeadDotProductAttention
等:
flax.linen.Dense(features=128)
flax.linen.Conv(features=64, kernel_size=(3, 3))
flax.linen.Embed(num_embeddings=10000, features=256)
flax.linen.MultiHeadDotProductAttention(num_heads=8)
flax.linen.SelfAttention(num_heads=8)
歸一化 支持多種方式:
flax.linen.BatchNorm()
flax.linen.LayerNorm()
flax.linen.GroupNorm(num_groups=32)
flax.linen.RMSNorm()
激活和 Dropout:
flax.linen.relu(x)
flax.linen.gelu(x)
flax.linen.sigmoid(x)
flax.linen.tanh(x)
flax.linen.Dropout(rate=0.1)
池化:
flax.linen.avg_pool(x, window_shape=(2,2), strides=(2,2))
flax.linen.max_pool(x, window_shape=(2,2), strides=(2,2))
循環層:
flax.linen.LSTMCell()
flax.linen.GRUCell()
flax.linen.OptimizedLSTMCell()
下面是一個簡單的多層感知機(MLP)例子:
import jax
import jax.numpy as jnp
from flax import linen as nn
class MLP(nn.Module):
features: list
@nn.compact
def __call__(self, x):
for f in self.features[:-1]:
x = nn.Dense(f)(x)
x = nn.relu(x)
x = nn.Dense(self.features[-1])(x)
return x
model = MLP([32, 16, 10])
key = jax.random.PRNGKey(0)
# 輸入:batch_size=1, 特徵數=4
x = jnp.ones((1, 4))
# 初始化參數
params = model.init(key, x)
# 前向傳播
y = model.apply(params, x)
print("Input:", x)
# Input: [[1. 1. 1. 1.]]
print("Input shape:", x.shape)
# Input shape: (1, 4)
print("Output:", y)
# Output: [[ 0.51415515 0.36979797 0.6212194 -0.74496573 -0.8318489 0.6590691 0.89224255 0.00737424 0.33062232 0.34577468]]
print("Output shape:", y.shape)
# Output shape: (1, 10)
Flax 用
@nn.compact
裝飾器,讓你在
__call__
方法裏直接定義層。參數是獨立於模型對象存儲的,需要通過
init
方法顯式初始化,然後在
apply
方法中使用。
總結
JAX 的出現解決了一個長期存在的問題:如何讓 Python 科學計算既保持靈活性,又能獲得接近 C/CUDA 的性能。
不過 JAX 的學習曲線確實比 PyTorch 陡。純函數的約束、不可變數組的特性、顯式密鑰管理等細節起初會有些彆扭。但一旦習慣會發現它帶來的優雅和靈活性。
https://avoid.overfit.cn/post/a16194fdc3ea450f858515d7cb3d49c4
作者:Ashish Bamania