本文詳細介紹瞭如何利用JAX及其神經網絡庫Haiku,從零開始構建並訓練一個完整的Transformer模型。內容涵蓋自注意力機制、線性層、歸一化層、嵌入層的實現,以及如何結合Optax優化器構建訓練循環,為理解和使用JAX進行深度學習開發提供了實用指南。

使用JAX從零構建Transformer模型全流程解析

在本教程中,我們將探討如何使用JAX開發神經網絡。而Transformer模型無疑是一個絕佳的選擇。隨着JAX日益流行,越來越多的開發團隊開始嘗試並將其納入項目。儘管它尚未達到Tensorflow或PyTorch的成熟度,但它為構建和訓練深度學習模型提供了一些強大的特性。

為了紮實理解JAX的基礎知識,建議先閲讀我之前的相關文章。完整代碼可在我們的GitHub倉庫中找到。

許多人在開始使用JAX時面臨的常見問題是框架的選擇。某機構的團隊似乎非常忙碌,已經在JAX之上發佈了大量框架。以下是最著名的一些框架列表:

  • Haiku:Haiku是用於深度學習的首選框架,被許多某中心和某機構的內部團隊使用。它為機器學習研究提供了簡單、可組合的抽象,以及現成的模塊和層。
  • Optax:Optax是一個梯度處理和優化庫,包含開箱即用的優化器和相關數學運算。
  • RLax:RLax是一個強化學習框架,包含許多RL子組件和操作。
  • Chex:Chex是一個用於測試和調試JAX代碼的實用程序庫。
  • Jraph:Jraph是JAX中的圖神經網絡庫。
  • Flax:Flax是另一個神經網絡庫,提供各種現成的模塊、優化器和實用程序。它很可能最接近我們理想中的一體化JAX框架。
  • Objax:Objax是第三個ML庫,專注於面向對象編程和代碼可讀性。同樣,它包含了最流行的模塊、激活函數、損失函數、優化器以及一些預訓練模型。
  • Trax:Trax是一個端到端的深度學習庫,專注於Transformer模型。
  • JAXline:JAXline是一個監督學習庫,用於分佈式JAX訓練和評估。
  • ACME:ACME是另一個強化學習研究框架。
  • JAX-MD:JAX-MD是一個處理分子動力學的專業框架。
  • Jaxchem:JAXChem是另一個強調化學建模的專業庫。

當然,問題是我該選擇哪一個?
老實説,我也不確定。
但如果我是你,並且想學習JAX,我會從最流行的開始。Haiku和Flax似乎在某中心和某機構內部被大量使用,並且擁有最活躍的GitHub社區。在本文中,我將從第一個開始,看看後續是否需要其他框架。

那麼,你準備好用JAX和Haiku構建一個Transformer了嗎?順便説一下,我假設你對Transformer有紮實的理解。如果沒有,請參考我們關於注意力和Transformer的文章。

讓我們從自注意力塊開始。

自注意力塊

首先,我們需要導入JAX和Haiku:

import jax
import jax.numpy as jnp
import haiku as hk
Import numpy as np

幸運的是,Haiku有一個內置的MultiHeadAttention塊,可以擴展以構建掩碼自注意力塊。我們的塊接收查詢、鍵、值以及掩碼,並返回一個JAX數組作為輸出。你可以看到代碼與標準的PyTorch或Tensorflow代碼非常相似。我們所做的就是使用np.trill()(該函數將數組中對角線第k個元素以上的所有元素置零)構建因果掩碼,與我們的掩碼相乘,然後將所有內容傳遞給hk.MultiHeadAttention模塊。

class SelfAttention(hk.MultiHeadAttention):
    """應用了因果掩碼的自注意力。"""
    def __call__(
            self,
            query: jnp.ndarray,
            key: Optional[jnp.ndarray] = None,
            value: Optional[jnp.ndarray] = None,
            mask: Optional[jnp.ndarray] = None,
    ) -> jnp.ndarray:
        key = key if key is not None else query
        value = value if value is not None else query

        seq_len = query.shape[1]
        causal_mask = np.tril(np.ones((seq_len, seq_len)))
        mask = mask * causal_mask if mask is not None else causal_mask

        return super().__call__(query, key, value, mask)

這段代碼片段允許我介紹Haiku的第一個關鍵原則。所有模塊都應該是hk.Module的子類。這意味着它們應該實現__init____call__方法,以及其他任何方法。從某種意義上説,它與PyTorch模塊的架構相同,我們在那裏實現__init__和一個forward方法。

為了更清楚地説明這一點,讓我們構建一個簡單的2層多層感知機作為hk.Module,它將在下面的Transformer中方便地使用。

線性層

一個簡單的2層MLP看起來像這樣。再一次,你可以注意到它看起來多麼熟悉。

class DenseBlock(hk.Module):
    """一個2層的MLP"""
    def __init__(self,
                 init_scale: float,
                 widening_factor: int = 4,
                 name: Optional[str] = None):
        super().__init__(name=name)
        self._init_scale = init_scale
        self._widening_factor = widening_factor

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        hiddens = x.shape[-1]
        initializer = hk.initializers.VarianceScaling(self._init_scale)
        x = hk.Linear(self._widening_factor * hiddens, w_init=initializer)(x)
        x = jax.nn.gelu(x)
        return hk.Linear(hiddens, w_init=initializer)(x)

這裏需要注意幾點:

  • Haiku在hk.initializers下為我們提供了一組權重初始化器,我們可以在這裏找到最常見的方法。
  • 它還有內置的許多流行層和模塊,例如hk.Linear。完整列表,請查看官方文檔。
  • 不提供激活函數,因為JAX已經有一個名為jax.nn的子包,我們可以在那裏找到relusoftmax等激活函數。

歸一化層

層歸一化是Transformer架構的另一個組成部分,我們也可以在Haiku的公共模塊中找到。

def layer_norm(x: jnp.ndarray, name: Optional[str] = None) -> jnp.ndarray:
    """使用默認設置對x應用唯一的LayerNorm。"""
    return hk.LayerNorm(axis=-1,
                        create_scale=True,
                        create_offset=True,
                        name=name)(x)

Transformer

現在是重點。下面你可以看到一個非常簡化的Transformer,它使用了我們預定義的模塊。在__init__中,我們定義了基本變量,如層數、注意力頭數和dropout率。在__call__中,我們使用for循環組合了一系列塊。

如你所見,每個塊包括:

  • 一個歸一化層
  • 一個自注意力塊
  • 兩個dropout層
  • 兩個歸一化層
  • 兩個跳躍連接 (h = h + h_attnh = h + h_dense)
  • 一個2層的密集塊

最後,我們還添加了最終的歸一化層。

class Transformer(hk.Module):
    """一個Transformer堆棧。"""
    def __init__(self,
                 num_heads: int,
                 num_layers: int,
                 dropout_rate: float,
                 name: Optional[str] = None):
        super().__init__(name=name)
        self._num_layers = num_layers
        self._num_heads = num_heads
        self._dropout_rate = dropout_rate

    def __call__(self,
                 h: jnp.ndarray,
                 mask: Optional[jnp.ndarray],
                 is_training: bool) -> jnp.ndarray:
        """連接transformer。
        Args:
          h: 輸入, [B, T, H].
          mask: 填充掩碼, [B, T].
          is_training: 是否處於訓練模式。
        Returns:
          形狀為[B, T, H]的數組。
        """

        init_scale = 2. / self._num_layers
        dropout_rate = self._dropout_rate if is_training else 0.
        if mask is not None:
            mask = mask[:, None, None, :]

        for i in range(self._num_layers):
            h_norm = layer_norm(h, name=f'h{i}_ln_1')
            h_attn = SelfAttention(
                num_heads=self._num_heads,
                key_size=64,
                w_init_scale=init_scale,
                name=f'h{i}_attn')(h_norm, mask=mask)
            h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn)
            h = h + h_attn
            h_norm = layer_norm(h, name=f'h{i}_ln_2')
            h_dense = DenseBlock(init_scale, name=f'h{i}_mlp')(h_norm)
            h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense)
            h = h + h_dense
        h = layer_norm(h, name='ln_f')

        return h

我想現在你已經意識到,用JAX構建神經網絡非常簡單。

嵌入層

為了完整起見,我們也加入嵌入層。需要知道的是,Haiku也提供了一個嵌入層,它將從我們的輸入句子中創建標記。然後將標記添加到位置嵌入中,產生最終的輸入。

def embeddings(data: Mapping[str, jnp.ndarray], vocab_size: int) :
    tokens = data['obs']
    input_mask = jnp.greater(tokens, 0)
    seq_length = tokens.shape[1]

    # 嵌入輸入標記和位置。
    embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
    token_embedding_map = hk.Embed(vocab_size, d_model, w_init=embed_init)
    token_embs = token_embedding_map(tokens)
    positional_embeddings = hk.get_parameter(
        'pos_embs', [seq_length, d_model], init=embed_init)
    input_embeddings = token_embs + positional_embeddings
    return input_embeddings, input_mask

hk.get_parameter(param_name, ...)用於訪問模塊的可訓練參數。但你可能會問,為什麼不直接使用像在PyTorch中那樣的對象屬性。這就是Haiku的第二個關鍵原則發揮作用的地方。我們使用這個API,以便可以使用hk.transform將代碼轉換為純函數。這理解起來並不簡單,但我會盡量讓它清晰明瞭。

為什麼需要純函數?

JAX的強大之處在於其函數變換:用vmap向量化函數的能力,用pmap自動並行化,用jit即時編譯。這裏需要注意的是,為了變換一個函數,它必須是純的。

純函數是具有以下屬性的函數:

  • 對於相同的參數,函數返回值是相同的(不隨局部靜態變量、非局部變量、可變引用參數或輸入流的變化而變化)。
  • 函數應用沒有副作用(不改變局部靜態變量、非局部變量、可變引用參數或輸入/輸出流)。

來源:O'Reily的Scala純函數

這實際上意味着一個純函數總是會:

  • 如果使用相同的輸入調用,則返回相同的結果
  • 所有輸入數據都通過函數參數傳遞,所有結果都通過函數結果輸出

Haiku提供了一個名為hk.transform的函數轉換,它將具有面向對象、功能上“不純”的模塊的函數轉換為可以與JAX一起使用的純函數。為了在實踐中看到這一點,讓我們繼續訓練我們的Transformer模型。

前向傳播

一個典型的前向傳播包括:

  • 獲取輸入並計算輸入嵌入
  • 通過Transformer的塊運行
  • 返回輸出

上述步驟可以很容易地用JAX組合如下:

def build_forward_fn(vocab_size: int, d_model: int, num_heads: int,
                     num_layers: int, dropout_rate: float):
    """創建模型的前向傳播。"""

    def forward_fn(data: Mapping[str, jnp.ndarray],
                   is_training: bool = True) -> jnp.ndarray:
        """前向傳播。"""
        input_embeddings, input_mask = embeddings(data, vocab_size)

        # 在輸入上運行transformer。
        transformer = Transformer(
            num_heads=num_heads, num_layers=num_layers, dropout_rate=dropout_rate)
        output_embeddings = transformer(input_embeddings, input_mask, is_training)

        # 反向嵌入(未綁定)。
        return hk.Linear(vocab_size)(output_embeddings)

    return forward_fn

雖然代碼很簡單,但其結構可能看起來有點奇怪。實際的前向傳播是通過forward_fn函數執行的。然而,我們用build_forward_fn函數包裝了這個函數,並返回forward_fn。這是怎麼回事?

接下來,我們將需要使用hk.transformforward_fn函數轉換為純函數,以便我們可以利用自動微分、並行化等。

這將通過以下方式完成:

forward_fn = build_forward_fn(vocab_size, d_model, num_heads,
                                  num_layers, dropout_rate)
forward_fn = hk.transform(forward_fn)

這就是為什麼我們不是簡單地定義一個函數,而是包裝並返回函數本身,或者更準確地説,是一個可調用對象。然後可以將這個可調用對象傳遞給hk.transform併成為一個純函數。如果清楚了這一點,讓我們繼續看損失函數。

損失函數

損失函數是我們熟知的交叉熵函數,不同之處在於我們也考慮了掩碼。同樣,JAX提供了one_hotlog_softmax功能。

def lm_loss_fn(forward_fn,
               vocab_size: int,
               params,
               rng,
               data: Mapping[str, jnp.ndarray],
               is_training: bool = True) -> jnp.ndarray:
    """計算數據相對於參數的損失。"""
    logits = forward_fn(params, rng, data, is_training)
    targets = jax.nn.one_hot(data['target'], vocab_size)
    assert logits.shape == targets.shape

    mask = jnp.greater(data['obs'], 0)
    loss = -jnp.sum(targets * jax.nn.log_softmax(logits), axis=-1)
    loss = jnp.sum(loss * mask) / jnp.sum(mask)

    return loss

如果你還在堅持,喝一口咖啡,因為從現在開始事情會變得嚴肅起來。是時候構建我們的訓練循環了。

訓練循環

因為Jax和Haiku都沒有內置優化功能,所以我們將使用另一個名為Optax的框架。如開頭所述,Optax是用於梯度處理的包。

首先,關於Optax需要了解的一些事項:

  • Optax的關鍵變換是GradientTransformation。該變換由兩個函數定義,__init____update____init__初始化狀態,__update__根據狀態和參數的當前值轉換梯度:
state = init(params)
grads, state = update(grads, state, params=None)

在看代碼之前,還需要了解Python內置的functools.partial函數。functools包處理高階函數和可調用對象的操作。

如果一個函數包含其他函數作為參數或返回一個函數作為輸出,則稱為高階函數。

partial(也可以用作註解)返回一個基於原始函數的新函數,但具有更少或固定的參數。例如,如果f將兩個值x,y相乘,則partial將創建一個新函數,其中x將被固定為2:

from functools import partial
def f(x,y):
        return x * y

# 創建一個乘以2的新函數(x將被固定為2)
g = partial(f,2)

print(g(4))
#返回 8

在這個簡短插曲之後,讓我們繼續。為了簡化我們的主函數,我們將把梯度更新提取到它自己的類中。

首先,GradientUpdater接受模型、損失函數和優化器。

  • 模型將是通過hk.transform轉換的純forward_fn函數
forward_fn = build_forward_fn(vocab_size, d_model, num_heads,
                              num_layers, dropout_rate)
forward_fn = hk.transform(forward_fn)
  • 損失函數將是具有固定forward_fnvocab_sizepartial的結果
loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)
  • 優化器是一系列將按順序運行的優化變換(可以使用optax.chain組合操作)
optimizer = optax.chain(
        optax.clip_by_global_norm(grad_clip_value),
        optax.adam(learning_rate, b1=0.9, b2=0.99))

梯度更新器將初始化如下:

updater = GradientUpdater(forward_fn.init, loss_fn, optimizer)

並將如下所示:

class GradientUpdater:
    """圍繞 init_fn/update_fn 對的無狀態抽象。
    這從訓練循環中提取了一些常見的樣板代碼。
    """

    def __init__(self, net_init, loss_fn,
                 optimizer: optax.GradientTransformation):
        self._net_init = net_init
        self._loss_fn = loss_fn
        self._opt = optimizer

    @functools.partial(jax.jit, static_argnums=0)
    def init(self, master_rng, data):
        """初始化更新器的狀態。"""
        out_rng, init_rng = jax.random.split(master_rng)
        params = self._net_init(init_rng, data)
        opt_state = self._opt.init(params)
        out = dict(
            step=np.array(0),
            rng=out_rng,
            opt_state=opt_state,
            params=params,
        )
        return out

    @functools.partial(jax.jit, static_argnums=0)
    def update(self, state: Mapping[str, Any], data: Mapping[str, jnp.ndarray]):
        """使用一些數據更新狀態並返回指標。"""
        rng, new_rng = jax.random.split(state['rng'])
        params = state['params']
        loss, g = jax.value_and_grad(self._loss_fn)(params, rng, data)

        updates, opt_state = self._opt.update(g, state['opt_state'])
        params = optax.apply_updates(params, updates)

        new_state = {
            'step': state['step'] + 1,
            'rng': new_rng,
            'opt_state': opt_state,
            'params': params,
        }

        metrics = {
            'step': state['step'],
            'loss': loss,
        }
        return new_state, metrics

__init__中,我們使用self._opt.init(params)初始化優化器,並聲明優化狀態。狀態將是一個包含以下內容的字典:

  • 當前步驟
  • 優化器狀態
  • 可訓練參數
  • (一個隨機生成器密鑰,用於傳遞給jax.random.split)

update函數將更新優化器的狀態以及可訓練參數。最後,它將返回新狀態。

updates, opt_state = self._opt.update(g, state['opt_state'])
params = optax.apply_updates(params, updates)

這裏還有兩件事需要注意:

  • jax.value_and_grad()是一個特殊的函數,它返回一個帶有梯度的可微函數。
  • __init____update__都用@functools.partial(jax.jit, static_argnums=0)註解,這將觸發即時編譯器並在運行時將其編譯為XLA。請注意,如果我們沒有將forward_fn轉換為純函數,這是不可能的。

最後,我們準備構建整個訓練循環,它結合了迄今為止提到的所有想法和代碼。

def main():
    # 創建數據集。
    train_dataset, vocab_size = load(batch_size,
                                     sequence_length)
    # 設置模型、損失和更新器。
    forward_fn = build_forward_fn(vocab_size, d_model, num_heads,
                                  num_layers, dropout_rate)
    forward_fn = hk.transform(forward_fn)
    loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)

    optimizer = optax.chain(
        optax.clip_by_global_norm(grad_clip_value),
        optax.adam(learning_rate, b1=0.9, b2=0.99))

    updater = GradientUpdater(forward_fn.init, loss_fn, optimizer)

    # 初始化參數。
    logging.info('Initializing parameters...')
    rng = jax.random.PRNGKey(428)
    data = next(train_dataset)
    state = updater.init(rng, data)

    logging.info('Starting train loop...')
    prev_time = time.time()
    for step in range(MAX_STEPS):
        data = next(train_dataset)
        state, metrics = updater.update(state, data)

注意我們是如何整合GradientUpdate的。只需要兩行代碼:

state = updater.init(rng, data)
state, metrics = updater.update(state, data)

就是這樣。我希望現在你對JAX及其功能有了更清晰的理解。