很多人剛接觸JAX都會有點懵——參數為啥要單獨傳?隨機數還要自己管key?這跟PyTorch的畫風完全不一樣啊。
其實根本原因就一個:JAX是函數式編程而不是面向對象那套,想明白這點很多設計就都説得通了。
先説個核心區別
PyTorch裏,模型是個對象,權重藏在裏面,訓練的時候自己更新自己。這是典型的面向對象思路,狀態封裝在對象內部。
JAX的思路完全反過來。模型定義是模型定義,參數是參數,兩邊分得清清楚楚。函數本身不持有任何狀態,每次調用都把參數從外面傳進去。
這麼做的好處?JAX可以把你的函數當純數學表達式來處理。求導、編譯、並行,想怎麼折騰都行,因為函數裏沒有藏着掖着的東西,行為完全可預測。
代碼對比一下就明白了
PyTorch這麼寫:
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(10, 1)
def forward(self, x):
return self.linear(x)
model = Model()
x = torch.randn(5, 10)
output = model(x)
權重在self.linear裏,模型自己管自己。
JAX配Flax是這樣:
import jax
import jax.numpy as jnp
from flax import linen as nn
class Model(nn.Module):
@nn.compact
def __call__(self, x):
return nn.Dense(1)(x)
model = Model()
key = jax.random.PRNGKey(0)
dummy = jnp.ones((1, 10))
params = model.init(key, dummy)['params']
x = jnp.ones((5, 10))
output = model.apply({'params': params}, x)
參數要先init出來,用的時候再apply進去。麻煩是麻煩了點,但參數流向一目瞭然,想做什麼騷操作都很方便。
隨機數那個key是怎麼回事
這個確實是JAX最讓新手頭疼的地方。不能直接random.normal()完事,非得帶個key:
key = jax.random.PRNGKey(42)
x = jax.random.normal(key, (3,))
原因還是那個——函數式編程不允許隱藏狀態。
普通框架的隨機數生成器內部維護一個種子狀態,每次調用偷偷改一下。JAX不幹這事。你得顯式給它一個key,它用完就扔,下次想生成隨機數再給個新的。
好處是隨機性完全可控可復現。jit編譯、多卡訓練、梯度計算,不管代碼怎麼變換,只要key一樣結果就一樣。調試的時候不會遇到那種"明明代碼沒改怎麼結果不一樣了"的玄學問題。
key不能複用,用之前要split
還有個規矩:同一個key只能用一次。要生成多個隨機數,得先split:
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
a = jax.random.normal(subkey)
key, subkey = jax.random.split(key)
b = jax.random.uniform(subkey)
每次split出來的subkey都是獨立的隨機源。這套機制在分佈式場景下特別香,不同機器拿不同的key,隨機性既獨立又可追溯。
合在一起看個完整例子
def forward(params, x):
w, b = params
return w * x + b
def init_params(key):
key_w, key_b = jax.random.split(key)
w = jax.random.normal(key_w)
b = jax.random.normal(key_b)
return w, b
key = jax.random.PRNGKey(0)
params = init_params(key)
x = jnp.array(2.0)
output = forward(params, x)
forward是純函數,輸入決定輸出,沒有副作用。隨機性在init_params裏一次性處理完。參數獨立存放,想存哪存哪。
這種代碼JAX處理起來特別順手——jit編譯、自動微分、vmap批處理、多卡並行,都是開箱即用。
什麼場景下JAX更合適
説實話JAX學習曲線是陡了點。但有些場景下它的優勢很明顯:做研究需要魔改模型結構的時候;物理仿真對數值精度和可復現性要求高的時候;大規模分佈式訓練不想被隱藏狀態坑的時候;想自己擼optimizer或者自定義layer的時候。
適應了這套顯式風格之後其實挺舒服的。參數在哪、隨機數哪來的、函數幹了啥,全都擺在明面上。沒有黑魔法,debug的時候心裏有底。
https://avoid.overfit.cn/post/52fcdfd1d8054dcbb31783ed0547850e
作者:Ali Nawaz