本文詳細介紹瞭如何利用JAX及其神經網絡庫Haiku,從零開始構建並訓練一個完整的Transformer模型。內容涵蓋自注意力機制、線性層、歸一化層、嵌入層的實現,以及如何結合Optax優化器構建訓練循環,為理解和使用JAX進行深度學習開發提供了實用指南。
使用JAX從零構建Transformer模型全流程解析
在本教程中,我們將探討如何使用JAX開發神經網絡。而Transformer模型無疑是一個絕佳的選擇。隨着JAX日益流行,越來越多的開發團隊開始嘗試並將其納入項目。儘管它尚未達到Tensorflow或PyTorch的成熟度,但它為構建和訓練深度學習模型提供了一些強大的特性。
為了紮實理解JAX的基礎知識,建議先閲讀我之前的相關文章。完整代碼可在我們的GitHub倉庫中找到。
許多人在開始使用JAX時面臨的常見問題是框架的選擇。某機構的團隊似乎非常忙碌,已經在JAX之上發佈了大量框架。以下是最著名的一些框架列表:
- Haiku:Haiku是用於深度學習的首選框架,被許多某中心和某機構的內部團隊使用。它為機器學習研究提供了簡單、可組合的抽象,以及現成的模塊和層。
- Optax:Optax是一個梯度處理和優化庫,包含開箱即用的優化器和相關數學運算。
- RLax:RLax是一個強化學習框架,包含許多RL子組件和操作。
- Chex:Chex是一個用於測試和調試JAX代碼的實用程序庫。
- Jraph:Jraph是JAX中的圖神經網絡庫。
- Flax:Flax是另一個神經網絡庫,提供各種現成的模塊、優化器和實用程序。它很可能最接近我們理想中的一體化JAX框架。
- Objax:Objax是第三個ML庫,專注於面向對象編程和代碼可讀性。同樣,它包含了最流行的模塊、激活函數、損失函數、優化器以及一些預訓練模型。
- Trax:Trax是一個端到端的深度學習庫,專注於Transformer模型。
- JAXline:JAXline是一個監督學習庫,用於分佈式JAX訓練和評估。
- ACME:ACME是另一個強化學習研究框架。
- JAX-MD:JAX-MD是一個處理分子動力學的專業框架。
- Jaxchem:JAXChem是另一個強調化學建模的專業庫。
當然,問題是我該選擇哪一個?
老實説,我也不確定。
但如果我是你,並且想學習JAX,我會從最流行的開始。Haiku和Flax似乎在某中心和某機構內部被大量使用,並且擁有最活躍的GitHub社區。在本文中,我將從第一個開始,看看後續是否需要其他框架。
那麼,你準備好用JAX和Haiku構建一個Transformer了嗎?順便説一下,我假設你對Transformer有紮實的理解。如果沒有,請參考我們關於注意力和Transformer的文章。
讓我們從自注意力塊開始。
自注意力塊
首先,我們需要導入JAX和Haiku:
import jax
import jax.numpy as jnp
import haiku as hk
Import numpy as np
幸運的是,Haiku有一個內置的MultiHeadAttention塊,可以擴展以構建掩碼自注意力塊。我們的塊接收查詢、鍵、值以及掩碼,並返回一個JAX數組作為輸出。你可以看到代碼與標準的PyTorch或Tensorflow代碼非常相似。我們所做的就是使用np.trill()(該函數將數組中對角線第k個元素以上的所有元素置零)構建因果掩碼,與我們的掩碼相乘,然後將所有內容傳遞給hk.MultiHeadAttention模塊。
class SelfAttention(hk.MultiHeadAttention):
"""應用了因果掩碼的自注意力。"""
def __call__(
self,
query: jnp.ndarray,
key: Optional[jnp.ndarray] = None,
value: Optional[jnp.ndarray] = None,
mask: Optional[jnp.ndarray] = None,
) -> jnp.ndarray:
key = key if key is not None else query
value = value if value is not None else query
seq_len = query.shape[1]
causal_mask = np.tril(np.ones((seq_len, seq_len)))
mask = mask * causal_mask if mask is not None else causal_mask
return super().__call__(query, key, value, mask)
這段代碼片段允許我介紹Haiku的第一個關鍵原則。所有模塊都應該是hk.Module的子類。這意味着它們應該實現__init__和__call__方法,以及其他任何方法。從某種意義上説,它與PyTorch模塊的架構相同,我們在那裏實現__init__和一個forward方法。
為了更清楚地説明這一點,讓我們構建一個簡單的2層多層感知機作為hk.Module,它將在下面的Transformer中方便地使用。
線性層
一個簡單的2層MLP看起來像這樣。再一次,你可以注意到它看起來多麼熟悉。
class DenseBlock(hk.Module):
"""一個2層的MLP"""
def __init__(self,
init_scale: float,
widening_factor: int = 4,
name: Optional[str] = None):
super().__init__(name=name)
self._init_scale = init_scale
self._widening_factor = widening_factor
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
hiddens = x.shape[-1]
initializer = hk.initializers.VarianceScaling(self._init_scale)
x = hk.Linear(self._widening_factor * hiddens, w_init=initializer)(x)
x = jax.nn.gelu(x)
return hk.Linear(hiddens, w_init=initializer)(x)
這裏需要注意幾點:
- Haiku在
hk.initializers下為我們提供了一組權重初始化器,我們可以在這裏找到最常見的方法。 - 它還有內置的許多流行層和模塊,例如
hk.Linear。完整列表,請查看官方文檔。 - 不提供激活函數,因為JAX已經有一個名為
jax.nn的子包,我們可以在那裏找到relu或softmax等激活函數。
歸一化層
層歸一化是Transformer架構的另一個組成部分,我們也可以在Haiku的公共模塊中找到。
def layer_norm(x: jnp.ndarray, name: Optional[str] = None) -> jnp.ndarray:
"""使用默認設置對x應用唯一的LayerNorm。"""
return hk.LayerNorm(axis=-1,
create_scale=True,
create_offset=True,
name=name)(x)
Transformer
現在是重點。下面你可以看到一個非常簡化的Transformer,它使用了我們預定義的模塊。在__init__中,我們定義了基本變量,如層數、注意力頭數和dropout率。在__call__中,我們使用for循環組合了一系列塊。
如你所見,每個塊包括:
- 一個歸一化層
- 一個自注意力塊
- 兩個dropout層
- 兩個歸一化層
- 兩個跳躍連接 (
h = h + h_attn和h = h + h_dense) - 一個2層的密集塊
最後,我們還添加了最終的歸一化層。
class Transformer(hk.Module):
"""一個Transformer堆棧。"""
def __init__(self,
num_heads: int,
num_layers: int,
dropout_rate: float,
name: Optional[str] = None):
super().__init__(name=name)
self._num_layers = num_layers
self._num_heads = num_heads
self._dropout_rate = dropout_rate
def __call__(self,
h: jnp.ndarray,
mask: Optional[jnp.ndarray],
is_training: bool) -> jnp.ndarray:
"""連接transformer。
Args:
h: 輸入, [B, T, H].
mask: 填充掩碼, [B, T].
is_training: 是否處於訓練模式。
Returns:
形狀為[B, T, H]的數組。
"""
init_scale = 2. / self._num_layers
dropout_rate = self._dropout_rate if is_training else 0.
if mask is not None:
mask = mask[:, None, None, :]
for i in range(self._num_layers):
h_norm = layer_norm(h, name=f'h{i}_ln_1')
h_attn = SelfAttention(
num_heads=self._num_heads,
key_size=64,
w_init_scale=init_scale,
name=f'h{i}_attn')(h_norm, mask=mask)
h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn)
h = h + h_attn
h_norm = layer_norm(h, name=f'h{i}_ln_2')
h_dense = DenseBlock(init_scale, name=f'h{i}_mlp')(h_norm)
h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense)
h = h + h_dense
h = layer_norm(h, name='ln_f')
return h
我想現在你已經意識到,用JAX構建神經網絡非常簡單。
嵌入層
為了完整起見,我們也加入嵌入層。需要知道的是,Haiku也提供了一個嵌入層,它將從我們的輸入句子中創建標記。然後將標記添加到位置嵌入中,產生最終的輸入。
def embeddings(data: Mapping[str, jnp.ndarray], vocab_size: int) :
tokens = data['obs']
input_mask = jnp.greater(tokens, 0)
seq_length = tokens.shape[1]
# 嵌入輸入標記和位置。
embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
token_embedding_map = hk.Embed(vocab_size, d_model, w_init=embed_init)
token_embs = token_embedding_map(tokens)
positional_embeddings = hk.get_parameter(
'pos_embs', [seq_length, d_model], init=embed_init)
input_embeddings = token_embs + positional_embeddings
return input_embeddings, input_mask
hk.get_parameter(param_name, ...)用於訪問模塊的可訓練參數。但你可能會問,為什麼不直接使用像在PyTorch中那樣的對象屬性。這就是Haiku的第二個關鍵原則發揮作用的地方。我們使用這個API,以便可以使用hk.transform將代碼轉換為純函數。這理解起來並不簡單,但我會盡量讓它清晰明瞭。
為什麼需要純函數?
JAX的強大之處在於其函數變換:用vmap向量化函數的能力,用pmap自動並行化,用jit即時編譯。這裏需要注意的是,為了變換一個函數,它必須是純的。
純函數是具有以下屬性的函數:
- 對於相同的參數,函數返回值是相同的(不隨局部靜態變量、非局部變量、可變引用參數或輸入流的變化而變化)。
- 函數應用沒有副作用(不改變局部靜態變量、非局部變量、可變引用參數或輸入/輸出流)。
來源:O'Reily的Scala純函數
這實際上意味着一個純函數總是會:
- 如果使用相同的輸入調用,則返回相同的結果
- 所有輸入數據都通過函數參數傳遞,所有結果都通過函數結果輸出
Haiku提供了一個名為hk.transform的函數轉換,它將具有面向對象、功能上“不純”的模塊的函數轉換為可以與JAX一起使用的純函數。為了在實踐中看到這一點,讓我們繼續訓練我們的Transformer模型。
前向傳播
一個典型的前向傳播包括:
- 獲取輸入並計算輸入嵌入
- 通過Transformer的塊運行
- 返回輸出
上述步驟可以很容易地用JAX組合如下:
def build_forward_fn(vocab_size: int, d_model: int, num_heads: int,
num_layers: int, dropout_rate: float):
"""創建模型的前向傳播。"""
def forward_fn(data: Mapping[str, jnp.ndarray],
is_training: bool = True) -> jnp.ndarray:
"""前向傳播。"""
input_embeddings, input_mask = embeddings(data, vocab_size)
# 在輸入上運行transformer。
transformer = Transformer(
num_heads=num_heads, num_layers=num_layers, dropout_rate=dropout_rate)
output_embeddings = transformer(input_embeddings, input_mask, is_training)
# 反向嵌入(未綁定)。
return hk.Linear(vocab_size)(output_embeddings)
return forward_fn
雖然代碼很簡單,但其結構可能看起來有點奇怪。實際的前向傳播是通過forward_fn函數執行的。然而,我們用build_forward_fn函數包裝了這個函數,並返回forward_fn。這是怎麼回事?
接下來,我們將需要使用hk.transform將forward_fn函數轉換為純函數,以便我們可以利用自動微分、並行化等。
這將通過以下方式完成:
forward_fn = build_forward_fn(vocab_size, d_model, num_heads,
num_layers, dropout_rate)
forward_fn = hk.transform(forward_fn)
這就是為什麼我們不是簡單地定義一個函數,而是包裝並返回函數本身,或者更準確地説,是一個可調用對象。然後可以將這個可調用對象傳遞給hk.transform併成為一個純函數。如果清楚了這一點,讓我們繼續看損失函數。
損失函數
損失函數是我們熟知的交叉熵函數,不同之處在於我們也考慮了掩碼。同樣,JAX提供了one_hot和log_softmax功能。
def lm_loss_fn(forward_fn,
vocab_size: int,
params,
rng,
data: Mapping[str, jnp.ndarray],
is_training: bool = True) -> jnp.ndarray:
"""計算數據相對於參數的損失。"""
logits = forward_fn(params, rng, data, is_training)
targets = jax.nn.one_hot(data['target'], vocab_size)
assert logits.shape == targets.shape
mask = jnp.greater(data['obs'], 0)
loss = -jnp.sum(targets * jax.nn.log_softmax(logits), axis=-1)
loss = jnp.sum(loss * mask) / jnp.sum(mask)
return loss
如果你還在堅持,喝一口咖啡,因為從現在開始事情會變得嚴肅起來。是時候構建我們的訓練循環了。
訓練循環
因為Jax和Haiku都沒有內置優化功能,所以我們將使用另一個名為Optax的框架。如開頭所述,Optax是用於梯度處理的包。
首先,關於Optax需要了解的一些事項:
- Optax的關鍵變換是
GradientTransformation。該變換由兩個函數定義,__init__和__update__。__init__初始化狀態,__update__根據狀態和參數的當前值轉換梯度:
state = init(params)
grads, state = update(grads, state, params=None)
在看代碼之前,還需要了解Python內置的functools.partial函數。functools包處理高階函數和可調用對象的操作。
如果一個函數包含其他函數作為參數或返回一個函數作為輸出,則稱為高階函數。
partial(也可以用作註解)返回一個基於原始函數的新函數,但具有更少或固定的參數。例如,如果f將兩個值x,y相乘,則partial將創建一個新函數,其中x將被固定為2:
from functools import partial
def f(x,y):
return x * y
# 創建一個乘以2的新函數(x將被固定為2)
g = partial(f,2)
print(g(4))
#返回 8
在這個簡短插曲之後,讓我們繼續。為了簡化我們的主函數,我們將把梯度更新提取到它自己的類中。
首先,GradientUpdater接受模型、損失函數和優化器。
- 模型將是通過
hk.transform轉換的純forward_fn函數
forward_fn = build_forward_fn(vocab_size, d_model, num_heads,
num_layers, dropout_rate)
forward_fn = hk.transform(forward_fn)
- 損失函數將是具有固定
forward_fn和vocab_size的partial的結果
loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)
- 優化器是一系列將按順序運行的優化變換(可以使用
optax.chain組合操作)
optimizer = optax.chain(
optax.clip_by_global_norm(grad_clip_value),
optax.adam(learning_rate, b1=0.9, b2=0.99))
梯度更新器將初始化如下:
updater = GradientUpdater(forward_fn.init, loss_fn, optimizer)
並將如下所示:
class GradientUpdater:
"""圍繞 init_fn/update_fn 對的無狀態抽象。
這從訓練循環中提取了一些常見的樣板代碼。
"""
def __init__(self, net_init, loss_fn,
optimizer: optax.GradientTransformation):
self._net_init = net_init
self._loss_fn = loss_fn
self._opt = optimizer
@functools.partial(jax.jit, static_argnums=0)
def init(self, master_rng, data):
"""初始化更新器的狀態。"""
out_rng, init_rng = jax.random.split(master_rng)
params = self._net_init(init_rng, data)
opt_state = self._opt.init(params)
out = dict(
step=np.array(0),
rng=out_rng,
opt_state=opt_state,
params=params,
)
return out
@functools.partial(jax.jit, static_argnums=0)
def update(self, state: Mapping[str, Any], data: Mapping[str, jnp.ndarray]):
"""使用一些數據更新狀態並返回指標。"""
rng, new_rng = jax.random.split(state['rng'])
params = state['params']
loss, g = jax.value_and_grad(self._loss_fn)(params, rng, data)
updates, opt_state = self._opt.update(g, state['opt_state'])
params = optax.apply_updates(params, updates)
new_state = {
'step': state['step'] + 1,
'rng': new_rng,
'opt_state': opt_state,
'params': params,
}
metrics = {
'step': state['step'],
'loss': loss,
}
return new_state, metrics
在__init__中,我們使用self._opt.init(params)初始化優化器,並聲明優化狀態。狀態將是一個包含以下內容的字典:
- 當前步驟
- 優化器狀態
- 可訓練參數
- (一個隨機生成器密鑰,用於傳遞給
jax.random.split)
update函數將更新優化器的狀態以及可訓練參數。最後,它將返回新狀態。
updates, opt_state = self._opt.update(g, state['opt_state'])
params = optax.apply_updates(params, updates)
這裏還有兩件事需要注意:
jax.value_and_grad()是一個特殊的函數,它返回一個帶有梯度的可微函數。__init__和__update__都用@functools.partial(jax.jit, static_argnums=0)註解,這將觸發即時編譯器並在運行時將其編譯為XLA。請注意,如果我們沒有將forward_fn轉換為純函數,這是不可能的。
最後,我們準備構建整個訓練循環,它結合了迄今為止提到的所有想法和代碼。
def main():
# 創建數據集。
train_dataset, vocab_size = load(batch_size,
sequence_length)
# 設置模型、損失和更新器。
forward_fn = build_forward_fn(vocab_size, d_model, num_heads,
num_layers, dropout_rate)
forward_fn = hk.transform(forward_fn)
loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)
optimizer = optax.chain(
optax.clip_by_global_norm(grad_clip_value),
optax.adam(learning_rate, b1=0.9, b2=0.99))
updater = GradientUpdater(forward_fn.init, loss_fn, optimizer)
# 初始化參數。
logging.info('Initializing parameters...')
rng = jax.random.PRNGKey(428)
data = next(train_dataset)
state = updater.init(rng, data)
logging.info('Starting train loop...')
prev_time = time.time()
for step in range(MAX_STEPS):
data = next(train_dataset)
state, metrics = updater.update(state, data)
注意我們是如何整合GradientUpdate的。只需要兩行代碼:
state = updater.init(rng, data)
state, metrics = updater.update(state, data)
就是這樣。我希望現在你對JAX及其功能有了更清晰的理解。