博客 / 詳情

返回

JAX 核心特性詳解:純函數、JIT 編譯、自動微分等十大必知概念

JAX 是 Google 和 NVIDIA 聯合開發的高性能數值計算庫,這兩年 JAX 生態快速發展,周邊工具鏈也日益完善了。如果你用過 NumPy 或 PyTorch,但還沒接觸過 JAX,這篇文章能幫助你快速上手。

圍繞 JAX 已經涌現出一批好用的庫:Flax 用來搭神經網絡,Optax 處理梯度和優化,Equinox 提供類似 PyTorch 的接口,Haiku 則是簡潔的函數式 API,Jraph 用於圖神經網絡,RLax 是強化學習工具庫,Chex 提供測試和調試工具,Orbax 負責模型檢查點和持久化。

純函數是硬需求

JAX 對函數有個基本要求:必須是純函數。這意味着函數不能有副作用,對同樣的輸入必須總是返回同樣的輸出。

這個約束來自函數式編程範式。JAX 內部做各種變換(編譯、自動微分等)依賴純函數的特性,用不純的函數可能導致錯誤或靜默失敗,結果完全不對。

 # 純函數,沒問題
def pure_addition(a, b):  
  return a + b  
      
# 不純的函數,JAX 不接受
counter = 0  
def impure_addition(a, b):  
  global counter  
  counter += 1  
   return a + b

JAX NumPy 與原生 NumPy

JAX 提供了類 NumPy 的接口,核心優勢在於能自動高效地在 CPU、GPU 甚至 TPU 上運行,支持本地或分佈式執行。這套能力來自 XLA(Accelerated Linear Algebra) 編譯器,它把 JAX 代碼翻譯成針對不同硬件優化的機器碼。

NumPy 默認只在 CPU 上跑,JAX NumPy 則不同。用法上兩者很相似,這也是 JAX 容易上手的原因。

 
# JAX 也差不多
import jax.numpy as jnp  
      
print(jnp.sqrt(4))# NumPy 的寫法
import numpy as np  
      
print(np.sqrt(4))
# JAX 也差不多
import jax.numpy as jnp  
      
 print(jnp.sqrt(4))

常見的操作兩者看起來基本一樣:

 import numpy as np  
import jax.numpy as jnp  
      
# 創建數組
np_a = np.array([1.0, 2.0, 3.0])  
jnp_a = jnp.array([1.0, 2.0, 3.0])  
      
# 元素級操作
print(np_a + 2)  
print(jnp_a + 2)  
      
# 廣播
np_b = np.array([[1, 2, 3]])  
jnp_b = jnp.array([[1, 2, 3]])  
print(np_b + np.arange(3))  
print(jnp_b + jnp.arange(3))  
      
# 求和
print(np.sum(np_a))  
print(jnp.sum(jnp_a))  
      
# 平均值
print(np.mean(np_a))   
print(jnp.mean(jnp_a))  
      
# 點積
print(np.dot(np_a, np_a))   
 print(jnp.dot(jnp_a, jnp_a))

但有個重要差異需要注意:

JAX 數組是不可變的,對數組的修改操作會返回新數組而不是改變原數組。

NumPy 數組則可以直接修改:

 import numpy as np  
       
 x = np.array([1, 2, 3])  
 x[0] = 10  # 直接修改,沒問題

JAX 這邊就不行了:

 import jax.numpy as jnp  
       
 x = jnp.array([1, 2, 3])  
 x[0] = 10  # 報錯

但是JAX 提供了專門的 API 來處理這種情況,通過返回一個新數組的方式實現"修改":

 z=x.at[idx].set(y)

完整的例子:

 x = jnp.array([1, 2, 3])  
 y = x.at[0].set(10)  
       
 print(y)  # [10, 2, 3]  
 print(x)  # [1, 2, 3](沒變)

JIT 編譯加速

即時編譯(JIT)是 JAX 裏一個核心特性,通過 XLA 把 Python/JAX 代碼編譯成優化後的機器碼。

直接用 Python 解釋器跑函數會很慢。加上

@jit

裝飾器後,函數會被編譯成快速的原生代碼:

 from jax import jit  
      
# 不編譯的版本
def square(x):  
  return x * x   
      
# 編譯過的版本
@jit   
def jit_square(x):  
   return x * x
jit_square

快好幾個數量級。函數首次調用時,JIT 引擎會:

  1. 追蹤函數邏輯,構建計算圖
  2. 把圖編譯成優化的 XLA 代碼
  3. 緩存編譯結果
  4. 後續調用直接用緩存的版本

自動微分

JAX 的 grad 模塊能自動計算函數的導數。

 import jax.numpy as jnp  
from jax import grad  
      
# 定義函數:f(x) = x² + 2x + 2
def f(x):  
  return x**2 + 2 * x + 2  
      
# 計算導數
df_dx = grad(f)  
      
# 在 x = 2.0 處求值
 print(df_dx(2.0))  # 6.0

隨機數處理

NumPy 用全局隨機狀態生成隨機數。每次調用

np.random.random()

時,NumPy 會更新隱藏的內部狀態:

 import numpy as np  
       
 np.random.random()  
 # 0.9539264374520571

JAX 的做法完全不同。作為純函數庫,它不能維護全局狀態,所以要求顯式傳入一個偽隨機數生成器(PRNG)密鑰。每次生成隨機數前要先分割密鑰:

 from jax import random  
      
# 初始化密鑰
key = random.PRNGKey(0)  
      
# 每次生成前分割
key, subkey = random.split(key)  
      
# 從正態分佈採樣
x = random.normal(subkey, ())  
print(x)  # -2.4424558  
      
# 從均勻分佈採樣
key, subkey = random.split(key)  
u = random.uniform(subkey, (), minval=0.0, maxval=1.0)  
 print(u)  # 0.104290366

一個常見的坑:同一個密鑰生成的隨機數始終相同。

 # 用同一個 subkey,結果重複
 x = random.normal(subkey, ())  
 print(x)  # -2.4424558  
       
 x = random.normal(subkey, ())  
 print(x)  # -2.4424558(還是這個值)

所以要記住總是用新密鑰。

向量化:vmap

vmap 自動把函數轉換成能處理批量數據的版本。邏輯上就像循環遍歷每個樣本,但執行效率遠高於 Python 循環。

 import jax.numpy as jnp  
from jax import vmap  
      
def f(x):  
    return x * x + 1  
      
arr = jnp.array([1., 2., 3., 4.])  
      
# Python 循環(慢)
outputs_loop = jnp.array([f(x) for x in arr])  
      
# vmap 版本(快)
f_vectorized = vmap(f)  
 outputs_vmap = f_vectorized(arr)

並行化:pmap

pmap 不同於 vmap。vmap 在單個設備上做批處理,pmap 把計算分散到多個設備(GPU/TPU 核心),每個設備處理輸入的一部分。

VMAP:單設備批處理向量化

PMAP:跨多設備並行執行

 import jax.numpy as jnp  
from jax import pmap  
      
# 查看可用設備
print(jax.devices())  # [TpuDevice(id=0), TpuDevice(id=1), ..., TpuDevice(id=7)]  
      
def f(x):  
    return x * x + 1  
      
arr = jnp.array([1., 2., 3., 4.])  
      
# pmap 會把數組分配到不同設備
 ys = pmap(f)(arr)

PyTrees

PyTree 在 JAX 裏是個常見的概念:任何嵌套的 Python 容器(列表、字典、元組等)加上基本類型的組合。JAX 裏用它來組織模型參數、優化器狀態、梯度等。

 import jax.numpy as jnp  
from jax import tree_util as tu

# 構建 PyTree
pytree = {  
    "a": jnp.array([1, 2]),  
    "b": [jnp.array([3, 4]), 5]  
}  
      
# 獲取所有葉子節點
leaves = tu.tree_leaves(pytree)  
      
# 對每個葉子應用函數
 doubled = tu.tree_map(lambda x: x * 2, pytree)

Optax:梯度處理和優化

Optax 是 JAX 生態裏的優化庫。它包含損失函數、優化器、梯度變換、學習率調度等一套工具。

損失函數:

 logits = jnp.array([[2.0, -1.0]])  
labels_onehot = jnp.array([[1.0, 0.0]])  
labels_int = jnp.array([0])  
      
# Softmax 交叉熵(獨熱編碼)
loss_softmax_onehot = optax.softmax_cross_entropy(logits, labels_onehot).mean()  
      
# Softmax 交叉熵(整數標籤)
loss_softmax_int = optax.softmax_cross_entropy_with_integer_labels(logits, labels_int).mean()  
      
# 二元交叉熵
loss_bce = optax.sigmoid_binary_cross_entropy(logits, labels_onehot).mean()  
      
# L2 損失
loss_l2 = optax.l2_loss(jnp.array([1., 2.]), jnp.array([0., 1.])).mean()  
      
# Huber 損失
 loss_huber = optax.huber_loss(jnp.array([1.,2.]), jnp.array([0.,1.])).mean()

優化器:

 # SGD
opt_sgd = optax.sgd(learning_rate=1e-2)  
      
# SGD with momentum
opt_momentum = optax.sgd(learning_rate=1e-2, momentum=0.9)   
      
# RMSProp
opt_rmsprop = optax.rmsprop(1e-3)  
      
# Adafactor
opt_adafactor = optax.adafactor(learning_rate=1e-3)  
      
# Adam
opt_adam = optax.adam(1e-3)  
      
# AdamW
 opt_adamw = optax.adamw(1e-3, weight_decay=1e-4)

梯度變換:

 # 梯度裁剪
tx_clip = optax.clip(1.0)  
      
# 全局梯度範數裁剪
tx_clip_global = optax.clip_by_global_norm(1.0)  
      
# 權重衰減(L2)
tx_weight_decay = optax.add_decayed_weights(1e-4)  
      
# 添加梯度噪聲
 tx_noise = optax.add_noise(0.01)

學習率調度:

 # 指數衰減
 lr_exp = optax.exponential_decay(init_value=1e-3, transition_steps=1000, decay_rate=0.99)  
       
 # 餘弦衰減
 lr_cos = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=10_000)  
       
 # 線性預熱
 lr_linear = optax.linear_schedule(init_value=0.0, end_value=1e-3, transition_steps=500)

更新步驟:

 # 計算梯度
 loss, grads = jax.value_and_grad(loss_fn)(params)  
       
 # 生成優化器更新
 updates, opt_state = optimizer.update(grads, opt_state)  
       
 # 應用更新
 params = optax.apply_updates(params, updates)

鏈式組合:

 # 把多個操作鏈起來
 optimizer = optax.chain(  
     optax.clip_by_global_norm(1.0),  # 梯度裁剪
     optax.add_decayed_weights(1e-4),  # 權重衰減
     optax.adam(1e-3)  # Adam 優化
 )

Flax 與神經網絡

JAX 本身只是數值計算庫,Flax 在其基礎上提供了神經網絡定義和訓練的高級 API。Flax 代碼風格接近 PyTorch,如果你用過 PyTorch 會很快上手。

Flax 提供了豐富的層和操作。基礎層 包括全連接層

Dense

、卷積

Conv

、嵌入

Embed

、多頭注意力

MultiHeadDotProductAttention

等:

 flax.linen.Dense(features=128)  
 flax.linen.Conv(features=64, kernel_size=(3, 3))  
 flax.linen.Embed(num_embeddings=10000, features=256)  
 flax.linen.MultiHeadDotProductAttention(num_heads=8)  
 flax.linen.SelfAttention(num_heads=8)

歸一化 支持多種方式:

 flax.linen.BatchNorm()  
 flax.linen.LayerNorm()  
 flax.linen.GroupNorm(num_groups=32)  
 flax.linen.RMSNorm()

激活和 Dropout:

 flax.linen.relu(x)  
 flax.linen.gelu(x)  
 flax.linen.sigmoid(x)  
 flax.linen.tanh(x)  
 flax.linen.Dropout(rate=0.1)

池化:

 flax.linen.avg_pool(x, window_shape=(2,2), strides=(2,2))  
 flax.linen.max_pool(x, window_shape=(2,2), strides=(2,2))

循環層:

 flax.linen.LSTMCell()  
 flax.linen.GRUCell()  
 flax.linen.OptimizedLSTMCell()

下面是一個簡單的多層感知機(MLP)例子:

 import jax  
import jax.numpy as jnp  
from flax import linen as nn  
      
class MLP(nn.Module):  
    features: list  
          
    @nn.compact  
    def __call__(self, x):  
        for f in self.features[:-1]:  
            x = nn.Dense(f)(x)
            x = nn.relu(x)
              
        x = nn.Dense(self.features[-1])(x)  
        return x  
      
model = MLP([32, 16, 10])  
key = jax.random.PRNGKey(0)  
      
# 輸入:batch_size=1, 特徵數=4
x = jnp.ones((1, 4))  
      
# 初始化參數
params = model.init(key, x)  
      
# 前向傳播
y = model.apply(params, x)  
      
print("Input:", x)  
# Input: [[1. 1. 1. 1.]]  
      
print("Input shape:", x.shape)  
# Input shape: (1, 4)   
      
print("Output:", y)  
# Output: [[ 0.51415515  0.36979797  0.6212194  -0.74496573 -0.8318489   0.6590691 0.89224255  0.00737424  0.33062232  0.34577468]]  
      
print("Output shape:", y.shape)  
 # Output shape: (1, 10)

Flax 用

@nn.compact

裝飾器,讓你在

__call__

方法裏直接定義層。參數是獨立於模型對象存儲的,需要通過

init

方法顯式初始化,然後在

apply

方法中使用。

總結

JAX 的出現解決了一個長期存在的問題:如何讓 Python 科學計算既保持靈活性,又能獲得接近 C/CUDA 的性能。

不過 JAX 的學習曲線確實比 PyTorch 陡。純函數的約束、不可變數組的特性、顯式密鑰管理等細節起初會有些彆扭。但一旦習慣會發現它帶來的優雅和靈活性。

https://avoid.overfit.cn/post/a16194fdc3ea450f858515d7cb3d49c4

作者:Ashish Bamania

user avatar
0 位用戶收藏了這個故事!

發佈 評論

Some HTML is okay.