博客 / 詳情

返回

JAX 訓練加速指南:8 個讓 TPU 滿跑的工程實戰習慣

TPU 訓練的真實效率往往取決於兩個核心要素:Shape 的穩定性算子的融合度

很多時候,JAX 任務之所以出現嚴重的性能瓶頸,並非算法本身設計有問題,而是忽視了 XLA 編譯器與底層硬件對“確定性”的極度偏好。基於大量實戰調優經驗,本文總結了八條能讓 JAX 訓練任務從“甚至跑不通”蜕變為“跑滿 TPU 算力”的工程經驗。

1、儘早鎖定 Shape

TPU 喜歡靜態 Shape,JAX 也是,所以動態 Shape 是性能殺手,它會觸發重新編譯(Recompile)。一旦發生重編譯,Step time 和內存佔用都會直接炸裂。所以解決方法也很簡單,選定幾個規範的尺寸,剩下的全填(Pad)滿。

全局 Batch Size 要能被 TPU 核心數整除,然後就是對於變長序列,別指望它原本多長就多長,把它 Pad 到幾個固定的“桶(Bucket)”裏,比如 128、256 或 512,這步工作最好在輸入(Input Pipeline)裏就做完。

Python層面的條件判斷儘量別依賴 Shape,真要分支邏輯,就老老實實讓

lax.cond

lax.switch

來接管。

     # Example: bucketing & padding (conceptual)  
    def pad_to_length(arr, L):  
        pad = L - arr.shape[0]  
        return jnp.pad(arr, ((0, pad), (0, 0)), mode='constant')  
      
    bucket_sizes = [128, 256, 512]  
    def bucket_len(n):   
        return next(b for b in bucket_sizes if n <= b)  
      
    def preprocess_batch(batch):  
        L = bucket_len(batch["tokens"].shape[1])  
        batch["tokens"] = pad_to_length(batch["tokens"], L)  
        batch["mask"]   = pad_to_length(batch["mask"], L)  
         return batch

每個 Step 餵給 TPU 的 Shape 只要是固定的,XLA 編譯器就不會找麻煩。

2、激活值默認用 bfloat16,主權重要 FP32

在 TPU 上

bfloat16

(bf16) 是個好東西,兼顧了速度、內存和數值穩定性。

工程上的常規操作是:激活(Activations)和梯度(Gradients)存成 bf16。但是,優化器狀態裏的權重必須保留一份 FP32 的“主副本”,不然跑久了數值就會漂移。所欲需要在模型邊界做類型轉換(Cast)的時候小心點。

     class MLP(nn.Module):  
        features: int  
        @nn.compact  
        def __call__(self, x):  
            x = x.astype(jnp.bfloat16)     # fast path on TPUs  
            x = nn.Dense(self.features, dtype=jnp.bfloat16)(x)  
            x = nn.gelu(x)  
            x = nn.Dense(self.features, dtype=jnp.bfloat16)(x)  
            return x  
      
    # Optimizer state stays in FP32 (conceptual)  
    params_fp32 = params.astype(jnp.float32)  
    grads_bf16  = compute_grads_bf16(...)  
     updates_fp32 = opt.update(grads_bf16.astype(jnp.float32), opt_state, params_fp32)

3、pjit和命名網格:切分要明確,別靠猜

JAX 在 TPU 上最強的一點就是通過

pjit

實現了 GSPMD。你通過 PartitionSpecs 告訴它想要什麼切分方式,XLA 負責搞定如何在設備間搬運數據。

在 TPU 核心上建個命名網格(Mesh)。做數據並行(Data Parallelism)時,用

PartitionSpec('data', None)

這種模式。如果模型太大需要張量並行(Tensor Model Parallelism),就加個

'model'

軸。

     import numpy as np  
    import jax  
    import jax.numpy as jnp  
    from jax.sharding import Mesh, PartitionSpec as P  
    from jax.experimental import pjit  
      
    devices = np.array(jax.devices()).reshape(1, -1)  # 1 x N mesh  
    mesh = Mesh(devices, ('data',))  
      
    def loss_fn(params, batch):  
        logits = model_apply(params, batch['x'])  
        return cross_entropy(logits, batch['y'])  
      
    @pjit.pjit(  
        in_shardings=(P(None), P('data')),   # params replicated, batch sharded on 'data'  
        out_shardings=P(None),               # scalar loss replicated  
    )  
    def step(params, batch):  
        grads = jax.grad(loss_fn)(params, batch)  
        # aggregate grads across cores  
        grads = jax.tree.map(lambda g: jax.lax.pmean(g, axis_name='data'), grads)  
        return grads  
      
    with mesh:  
         grads = step(params, sharded_batch)

切分(Sharding)這事必須顯式。如果偷懶依賴自動推導,等到後期 debug 那些悄無聲息的跨設備數據傳輸時,絕對會很痛苦。

4、jit, vmap, scan 三件套

TPU 喜歡大塊頭的 Kernel,討厭成千上萬個細碎的小算子。訓練 Step 和任何中大型計算邏輯,必須用

jit

包起來。遇到 Python 循環,如果是時間步邏輯就換成

lax.scan

,如果是批次並行就用

vmap

把 Loss 計算、梯度計算和參數更新塞進同一個 jitted 函數裏,這樣編譯器才有機會把它們融合成一個大算子。

     import optax  
    import jax  
      
    optimizer = optax.adamw(3e-4)  
      
    def loss_and_grads(params, batch):  
        def _loss(p):  
            logits = model_apply(p, batch['x'])  
            return cross_entropy(logits, batch['y'])  
        loss, grads = jax.value_and_grad(_loss)(params)  
        return loss, grads  
      
    @jax.jit  
    def train_step(state, batch):  
        loss, grads = loss_and_grads(state.params, batch)  
        grads = jax.lax.pmean(grads, axis_name='data')  
        updates, new_opt_state = optimizer.update(grads, state.opt_state, state.params)  
        new_params = optax.apply_updates(state.params, updates)  
         return state.replace(params=new_params, opt_state=new_opt_state), loss

5、別讓輸入管道拖後腿

Host 到 Device 的數據傳輸一旦停頓,吞吐量就掉下來了,所以永遠別讓計算單元等數據。

tf.data

或者高效的 NumPy loader 配合 prefetch。數據預取到設備(Stage to device) 最好做雙重緩衝。全局 Batch 儘量大(當然要能被核心數整除),數據增強這種髒活累活在 Host 上一次性做完。

     # tf.data pipeline (conceptual)  
    ds = (tf.data.TFRecordDataset(files)  
          .map(parse_example, num_parallel_calls=tf.data.AUTOTUNE)  
          .batch(global_batch_size, drop_remainder=True)  
          .prefetch(tf.data.AUTOTUNE))  
      
    # Convert to NumPy and prefetch onto devices  
    from flax.jax_utils import prefetch_to_device  
    it = prefetch_to_device(map(npify, ds.as_numpy_iterator()), size=2)  
      
    with mesh:  
        for step_i in range(num_steps):  
            batch = next(it)     # already sharded/prefetched  
             state, loss = train_step(state, batch)

6、PRNG要Fold 進 Step 和 Device ID

JAX 的 PRNG 是無狀態的,這意味如果不小心,很容易在不同 Step 或者不同設備上用了一樣的隨機數 Key。

每個 Step 都要 Split 一次絕對別複用。所以説為了保證獨立性必須把 Global StepDevice IndexFold 進去。數據增強/Dropout 的 Key 和參數初始化的 Key 得分開管理。

     def make_step_rng(rng, step):  
        step_key = jax.random.fold_in(rng, step)  
        dev_key  = jax.random.fold_in(step_key, jax.lax.axis_index('data'))  
        return jax.random.split(dev_key, 1)[0]  
      
    @jax.jit  
    def train_step(state, batch, base_rng):  
        rng = make_step_rng(base_rng, state.step)  
        logits = model_apply(state.params, batch['x'], rngs={'dropout': rng})  
         ...

7、Remat,智能 Checkpoint,梯度累積

TPU 內存看着大,模型一跑起來就不夠用。深層網絡可以直接用 Activation Checkpointing(

jax.checkpoint

nn.remat

),用計算換顯存。想跑大 Batch 但顯存不夠,就用梯度累積(Gradient Accumulation) 把它切成小的 micro-step。

存盤的時候,推薦用 Orbax 做異步、分片(Sharded)的 Checkpoint,穩。

     from flax import linen as nn  
      
    class DeepBlock(nn.Module):  
        @nn.compact  
        def __call__(self, x):  
            # recompute on backward to trim activation memory  
            f = nn.remat(lambda y: nn.gelu(nn.Dense(x.shape[-1])(y)))  
            return f(x)  
      
    # Gradient accumulation (conceptual)  
    @jax.jit  
    def accum_step(state, batch_slices):  
        def body(carry, micro):  
            state, grad_sum = carry  
            _, grads = loss_and_grads(state.params, micro)  
            return (state, jax.tree_util.tree_map(jnp.add, grad_sum, grads)), None  
        init_grads = jax.tree_util.tree_map(jnp.zeros_like, state.params)  
        (state, grad_sum), _ = jax.lax.scan(body, (state, init_grads), batch_slices)  
        grads = jax.tree_map(lambda g: g / len(batch_slices), grad_sum)  
         ...

8、一定要跑 Profiler

把關鍵代碼段用 Profiler Annotations 包起來,看 Step Timeline。重點找 Host Waits、Recompiles 和那些沒融合好的細碎算子(Small op soup)。

穩態運行的時候,盯着 Tokens/sec 或者Images/sec,還有硬件利用率。

     from jax.experimental import host_callback as hcb  
    from jax import profiler  
      
    def tagged(name, fn, *a, **k):  
        profiler.annotate_function(name=name)  
        return fn(*a, **k)  
      
    @jax.jit  
    def train_step(state, batch):  
        profiler.annotate_function(name="train_step")  
        # do work...  
         return state, loss

一定要在鎖定 Shape 並且 JIT 完熱點路徑之後再做 Profile,不然全是噪音,根本看不到真正的瓶頸。

極簡 TPU 訓練示例

這基本包含了上面所有的內容

     # Pseudo-skeleton (Flax + JAX + TPU)  
    mesh = Mesh(np.array(jax.devices()).reshape(1, -1), ('data',))  
      
    @pjit.pjit(in_shardings=(P(None), P('data'), P(None)), out_shardings=(P(None), P(None)))  
    def train_step(state, batch, base_rng):  
        rng = jax.random.fold_in(base_rng, state.step)  
        rng = jax.random.fold_in(rng, jax.lax.axis_index('data'))  
        def loss_fn(p):  
            logits = model_apply(p, batch['x'].astype(jnp.bfloat16),  
                                 rngs={'dropout': rng})  
            return cross_entropy(logits, batch['y'])  
        loss, grads = jax.value_and_grad(loss_fn)(state.params)  
        grads = jax.tree_map(lambda g: jax.lax.pmean(g, 'data'), grads)  
        updates, opt_state = optimizer.update(grads, state.opt_state, state.params)  
        params = optax.apply_updates(state.params, updates)  
        return state.replace(params=params, opt_state=opt_state, step=state.step+1), loss  
      
    with mesh:  
        for step_i, batch in enumerate(prefetched_iterator):  
            state, loss = train_step(state, batch, base_rng)  
            if step_i % log_every == 0:  
                # Pull back just tiny scalars; keep big tensors on device  
                host_loss = jax.device_get(loss)  
                 print(f"[{step_i}] loss={host_loss:.4f}")

總結

TPU 需要的是 一致性:穩定的 Shape,融合的 Kernel,目的明確的切分,不掉鏈子的數據管道,把上面的這八件事做好,寫 JAX 訓練循環就非常順暢了。

https://avoid.overfit.cn/post/16b582a493ba4eca8333314859665dd2

作者:Modexa

user avatar
0 位用戶收藏了這個故事!

發佈 評論

Some HTML is okay.