The figure below shows the complete ViT architecture: the input image is split into patches, processed by the Transformer encoder, and finally classified through the classification head. The flow is straightforward, so let's implement it step by step.
1 Environment Setup and Data Preparation
1.1 Environment Configuration
First, make sure Python and MindSpore are installed locally. This tutorial is best run on a GPU; on CPU, training is painfully slow.
The dataset is a subset of ImageNet and is downloaded automatically on first run:
from download import download
dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vit_imagenet_dataset.zip"
path = "./"
path = download(dataset_url, path, kind="zip", replace=True)
After downloading, the directory structure looks like this:
./dataset/
├── ILSVRC2012_devkit_t12.tar.gz
├── train/
├── infer/
└── val/
1.2 Data Preprocessing
The preprocessing here is fairly standard: decoding with random crop and resize, horizontal flip, and normalization:
import os
import mindspore as ms
from mindspore.dataset import ImageFolderDataset
import mindspore.dataset.vision as transforms
data_path = './dataset/'
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
dataset_train = ImageFolderDataset(os.path.join(data_path, "train"), shuffle=True)
trans_train = [
transforms.RandomCropDecodeResize(size=224,
scale=(0.08, 1.0),
ratio=(0.75, 1.333)),
transforms.RandomHorizontalFlip(prob=0.5),
transforms.Normalize(mean=mean, std=std),
transforms.HWC2CHW()
]
dataset_train = dataset_train.map(operations=trans_train, input_columns=["image"])
dataset_train = dataset_train.batch(batch_size=16, drop_remainder=True)
The mean and std here are the standard ImageNet values, multiplied by 255 because MindSpore's Normalize runs on the raw [0, 255] pixel range of the decoded image rather than on values rescaled to [0, 1].
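To see why this works, here is a quick NumPy check (a sketch, not part of the tutorial) showing that normalizing raw [0, 255] pixels with the ×255 values is equivalent to first rescaling to [0, 1] and then normalizing with the standard values:
import numpy as np
img = np.random.randint(0, 256, (4, 4, 3)).astype(np.float32)  # a fake decoded HWC image
a = (img - np.array(mean)) / np.array(std)  # what the pipeline above computes
b = (img / 255 - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])
print(np.allclose(a, b))  # True: same normalization, different scale convention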
2 ViT Model Principles
2.1 The Core of the Transformer: Self-Attention
To understand ViT, you first need to understand the Transformer's core mechanism: Self-Attention. Put simply, it lets the model learn to attend to the relationships between different positions of the input sequence.
The Self-Attention computation:
- The input vectors pass through three different linear transformations to produce Q (Query), K (Key), and V (Value)
- The scaled dot product of Q and K yields the attention scores
- These scores, after normalization, are used to take a weighted sum of V
In math form, the input $X$ is first projected into the three matrices:

$$Q = XW^Q,\qquad K = XW^K,\qquad V = XW^V$$

Then the attention scores are computed:

$$\mathrm{scores} = \frac{QK^\top}{\sqrt{d_k}}$$

After Softmax normalization, the final output is:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

where $d_k$ is the per-head dimension, matching the scale factor head_dim ** -0.5 in the code below.
The figure above walks through the Self-Attention computation in detail: the input sequence X is linearly transformed into the Q, K, and V matrices, attention scores are computed, Softmax turns them into weights, and a weighted sum yields the output. This mechanism lets the model dynamically attend to different parts of the input sequence.
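Before looking at the MindSpore implementation, here is a minimal single-head sketch in NumPy (toy shapes and random weights, purely illustrative) that follows the three steps above:
import numpy as np

def self_attention(X, Wq, Wk, Wv):
    Q, K, V = X @ Wq, X @ Wk, X @ Wv            # step 1: three linear projections
    scores = Q @ K.T / np.sqrt(Q.shape[-1])     # step 2: scaled dot-product scores
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)  # softmax over the keys
    return weights @ V                          # step 3: weighted sum of the values

X = np.random.randn(4, 8)                       # 4 tokens, embedding dim 8
Wq, Wk, Wv = (np.random.randn(8, 8) for _ in range(3))
print(self_attention(X, Wq, Wk, Wv).shape)      # (4, 8)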
2.2 Implementing Multi-Head Attention
Multi-head attention splits the input into several "heads", each computing attention independently, and concatenates the results at the end. This lets the model look at the input from multiple perspectives:
from mindspore import nn, ops
class Attention(nn.Cell):
def __init__(self,
dim: int,
num_heads: int = 8,
keep_prob: float = 1.0,
attention_keep_prob: float = 1.0):
super(Attention, self).__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = ms.Tensor(head_dim ** -0.5)
self.qkv = nn.Dense(dim, dim * 3)
self.attn_drop = nn.Dropout(p=1.0-attention_keep_prob)
self.out = nn.Dense(dim, dim)
self.out_drop = nn.Dropout(p=1.0-keep_prob)
self.attn_matmul_v = ops.BatchMatMul()
self.q_matmul_k = ops.BatchMatMul(transpose_b=True)
self.softmax = nn.Softmax(axis=-1)
def construct(self, x):
b, n, c = x.shape
qkv = self.qkv(x)
qkv = ops.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads))
qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))
q, k, v = ops.unstack(qkv, axis=0)
attn = self.q_matmul_k(q, k)
attn = ops.mul(attn, self.scale)
attn = self.softmax(attn)
attn = self.attn_drop(attn)
out = self.attn_matmul_v(attn, v)
out = ops.transpose(out, (0, 2, 1, 3))
out = ops.reshape(out, (b, n, c))
out = self.out(out)
out = self.out_drop(out)
return out
The key points of this code:
- qkv = self.qkv(x) produces the Q, K, and V matrices with a single projection
- The reshape and transpose operations reorganize the data into multi-head form
- Finally, the results of all heads are concatenated back together (see the shape trace below)
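To make the tensor gymnastics easier to follow, here is the shape trace for the construct method above (b = batch, n = number of tokens, c = embedding dim, h = num_heads):
# x: (b, n, c) --qkv--> (b, n, 3c) --reshape--> (b, n, 3, h, c/h)
# --transpose--> (3, b, h, n, c/h) --unstack--> q, k, v: (b, h, n, c/h)
# attn = q @ k^T: (b, h, n, n); out = attn @ v: (b, h, n, c/h)
# --transpose--> (b, n, h, c/h) --reshape--> (b, n, c): heads concatenated back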
2.3 Feed Forward and Residual Connections
Besides the attention mechanism, the Transformer also needs a Feed Forward network and residual connections:
from typing import Optional
class FeedForward(nn.Cell):
def __init__(self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
activation: nn.Cell = nn.GELU,
keep_prob: float = 1.0):
super(FeedForward, self).__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.dense1 = nn.Dense(in_features, hidden_features)
self.activation = activation()
self.dense2 = nn.Dense(hidden_features, out_features)
self.dropout = nn.Dropout(p=1.0-keep_prob)
def construct(self, x):
x = self.dense1(x)
x = self.activation(x)
x = self.dropout(x)
x = self.dense2(x)
x = self.dropout(x)
return x
class ResidualCell(nn.Cell):
def __init__(self, cell):
super(ResidualCell, self).__init__()
self.cell = cell
def construct(self, x):
return self.cell(x) + x
The residual connection is simple: the input is added directly to the output, which helps counteract vanishing gradients in deep networks.
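A toy usage sketch (not from the tutorial) makes the behavior concrete: wrapping any Cell f in ResidualCell turns its output into f(x) + x:
import numpy as np
block = ResidualCell(nn.Dense(16, 16))
x = ms.Tensor(np.ones((2, 16)), ms.float32)
print(block(x).shape)  # (2, 16): the Dense output plus the original input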
2.4 The Complete TransformerEncoder
Combining the attention mechanism, Feed Forward, and residual connections yields the TransformerEncoder:
class TransformerEncoder(nn.Cell):
def __init__(self,
dim: int,
num_layers: int,
num_heads: int,
mlp_dim: int,
keep_prob: float = 1.,
attention_keep_prob: float = 1.0,
drop_path_keep_prob: float = 1.0,
activation: nn.Cell = nn.GELU,
norm: nn.Cell = nn.LayerNorm):
super(TransformerEncoder, self).__init__()
layers = []
for _ in range(num_layers):
normalization1 = norm((dim,))
normalization2 = norm((dim,))
attention = Attention(dim=dim,
num_heads=num_heads,
keep_prob=keep_prob,
attention_keep_prob=attention_keep_prob)
feedforward = FeedForward(in_features=dim,
hidden_features=mlp_dim,
activation=activation,
keep_prob=keep_prob)
layers.append(
nn.SequentialCell([
ResidualCell(nn.SequentialCell([normalization1, attention])),
ResidualCell(nn.SequentialCell([normalization2, feedforward]))
])
)
self.layers = nn.SequentialCell(layers)
def construct(self, x):
return self.layers(x)
A detail worth noting: ViT places LayerNorm before the attention and Feed Forward blocks, unlike the post-norm arrangement of the original Transformer, and experiments show this works better.
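In pseudocode, the difference between the two orderings looks like this (a sketch, not runnable code):
# pre-norm, as built above:    x = x + attention(norm1(x)); x = x + feedforward(norm2(x))
# post-norm, original design:  x = norm1(x + attention(x)); x = norm2(x + feedforward(x))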
3 ViT's Key Innovation: Turning Images into Sequences
The figure above shows ViT's complete image-processing pipeline: the raw image is split into patches, converted by the embedding layer, augmented with position encodings and the CLS token, processed by the Transformer encoder, and finally the CLS token is extracted for the classification prediction.
3.1 Patch Embedding
ViT's cleverest idea is converting an image into a sequence. Concretely, the image is cut into small blocks (patches), and each patch is flattened into a one-dimensional vector:
class PatchEmbedding(nn.Cell):
MIN_NUM_PATCHES = 4
def __init__(self,
image_size: int = 224,
patch_size: int = 16,
embed_dim: int = 768,
input_channels: int = 3):
super(PatchEmbedding, self).__init__()
self.image_size = image_size
self.patch_size = patch_size
self.num_patches = (image_size // patch_size) ** 2
self.conv = nn.Conv2d(input_channels, embed_dim,
kernel_size=patch_size, stride=patch_size, has_bias=True)
def construct(self, x):
x = self.conv(x)
b, c, h, w = x.shape
x = ops.reshape(x, (b, c, h * w))
x = ops.transpose(x, (0, 2, 1))
return x
Here a convolution implements the patch splitting, which is more efficient than slicing manually. For a 224×224 image with 16×16 patches, this produces 14×14 = 196 patches.
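A quick shape check with a dummy input (a sketch) confirms the patch count:
import numpy as np
patch_embed = PatchEmbedding()  # defaults: 224x224 image, 16x16 patches, embed_dim 768
dummy = ms.Tensor(np.zeros((1, 3, 224, 224)), ms.float32)
print(patch_embed(dummy).shape)  # (1, 196, 768): (224/16)^2 = 196 patch tokens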
3.2 Position Encoding and the Classification Token
After the image is split into patches, position information and a classification token still need to be added:
# In the ViT class's __init__ (the init helper used here is defined in the full listing in 3.3)
self.cls_token = init(init_type=Normal(sigma=1.0),
shape=(1, 1, embed_dim),
dtype=ms.float32,
name='cls',
requires_grad=True)
self.pos_embedding = init(init_type=Normal(sigma=1.0),
shape=(1, num_patches + 1, embed_dim),
dtype=ms.float32,
name='pos_embedding',
requires_grad=True)
The classification token borrows BERT's idea: a special token is prepended to the sequence, and its final output is used for classification. The position encoding tells the model where each patch sits in the image.
3.3 The Complete ViT Model
Putting all the components together gives the complete ViT model:
from mindspore.common.initializer import Normal, initializer
from mindspore import Parameter
def init(init_type, shape, dtype, name, requires_grad):
initial = initializer(init_type, shape, dtype).init_data()
return Parameter(initial, name=name, requires_grad=requires_grad)
class ViT(nn.Cell):
def __init__(self,
image_size: int = 224,
input_channels: int = 3,
patch_size: int = 16,
embed_dim: int = 768,
num_layers: int = 12,
num_heads: int = 12,
mlp_dim: int = 3072,
num_classes: int = 1000,
keep_prob: float = 1.0,
attention_keep_prob: float = 1.0,
drop_path_keep_prob: float = 1.0,
activation: nn.Cell = nn.GELU,
norm: Optional[nn.Cell] = nn.LayerNorm,
pool: str = 'cls') -> None:
super(ViT, self).__init__()
self.patch_embedding = PatchEmbedding(image_size=image_size,
patch_size=patch_size,
embed_dim=embed_dim,
input_channels=input_channels)
num_patches = self.patch_embedding.num_patches
self.cls_token = init(init_type=Normal(sigma=1.0),
shape=(1, 1, embed_dim),
dtype=ms.float32,
name='cls',
requires_grad=True)
self.pos_embedding = init(init_type=Normal(sigma=1.0),
shape=(1, num_patches + 1, embed_dim),
dtype=ms.float32,
name='pos_embedding',
requires_grad=True)
self.pool = pool
self.pos_dropout = nn.Dropout(p=1.0-keep_prob)
self.norm = norm((embed_dim,))
self.transformer = TransformerEncoder(dim=embed_dim,
num_layers=num_layers,
num_heads=num_heads,
mlp_dim=mlp_dim,
keep_prob=keep_prob,
attention_keep_prob=attention_keep_prob,
drop_path_keep_prob=drop_path_keep_prob,
activation=activation,
norm=norm)
self.dropout = nn.Dropout(p=1.0-keep_prob)
self.dense = nn.Dense(embed_dim, num_classes)
def construct(self, x):
x = self.patch_embedding(x)
cls_tokens = ops.tile(self.cls_token.astype(x.dtype), (x.shape[0], 1, 1))
x = ops.concat((cls_tokens, x), axis=1)
x += self.pos_embedding
x = self.pos_dropout(x)
x = self.transformer(x)
x = self.norm(x)
x = x[:, 0]  # take the classification token's output
if self.training:
x = self.dropout(x)
x = self.dense(x)
return x
The whole flow: image → patch embedding → add the cls token and position encoding → Transformer encoder → classification head.
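An end-to-end sanity check with a dummy image (a sketch) confirms the shapes:
import numpy as np
net = ViT()  # defaults: 224x224 input, 1000 classes
dummy = ms.Tensor(np.zeros((1, 3, 224, 224)), ms.float32)
print(net(dummy).shape)  # (1, 1000): one logit per ImageNet class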
4 Training and Validation
4.1 Training Configuration
Before training, set up the loss function, optimizer, and related pieces:
from mindspore.nn import LossBase
from mindspore.train import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint
from mindspore import train
# Hyperparameters
epoch_size = 10
momentum = 0.9
num_classes = 1000
resize = 224
step_size = dataset_train.get_dataset_size()
# Build the model
network = ViT()
# Load pretrained weights
vit_url = "https://download.mindspore.cn/vision/classification/vit_b_16_224.ckpt"
path = "./ckpt/vit_b_16_224.ckpt"
vit_path = download(vit_url, path, replace=True)
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)
# Learning-rate schedule
lr = nn.cosine_decay_lr(min_lr=float(0),
max_lr=0.00005,
total_step=epoch_size * step_size,
step_per_epoch=step_size,
decay_epoch=10)
# Optimizer
network_opt = nn.Adam(network.trainable_params(), lr, momentum)
Because we start from a pretrained model, the learning rate is kept small. The cosine-annealing schedule makes training more stable.
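Note that nn.cosine_decay_lr returns a plain Python list with one learning rate per training step; a quick check (assuming the variables above):
print(len(lr))        # epoch_size * step_size entries, one per step
print(lr[0], lr[-1])  # starts at max_lr and decays toward min_lr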
4.2 Loss Function
We use cross-entropy with label smoothing:
class CrossEntropySmooth(LossBase):
def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
super(CrossEntropySmooth, self).__init__()
self.onehot = ops.OneHot()
self.sparse = sparse
self.on_value = ms.Tensor(1.0 - smooth_factor, ms.float32)
self.off_value = ms.Tensor(1.0 * smooth_factor / (num_classes - 1), ms.float32)
self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)
def construct(self, logit, label):
if self.sparse:
label = self.onehot(label, ops.shape(logit)[1], self.on_value, self.off_value)
loss = self.ce(logit, label)
return loss
network_loss = CrossEntropySmooth(sparse=True,
reduction="mean",
smooth_factor=0.1,
num_classes=num_classes)
Label smoothing helps prevent overfitting and improves generalization.
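For intuition, here is the arithmetic with smooth_factor=0.1 and 1000 classes (a quick check, not tutorial code):
on_value = 1.0 - 0.1          # 0.9 for the true class
off_value = 0.1 / (1000 - 1)  # ~0.0001 for each of the other 999 classes
print(on_value + 999 * off_value)  # 1.0 (up to rounding): the target still sums to 1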
4.3 Training
# Set up checkpointing
ckpt_config = CheckpointConfig(save_checkpoint_steps=step_size, keep_checkpoint_max=100)
ckpt_callback = ModelCheckpoint(prefix='vit_b_16', directory='./ViT', config=ckpt_config)
# Initialize the model
ascend_target = (ms.get_context("device_target") == "Ascend")
if ascend_target:
model = train.Model(network, loss_fn=network_loss, optimizer=network_opt,
metrics={"acc"}, amp_level="O2")
else:
model = train.Model(network, loss_fn=network_loss, optimizer=network_opt,
metrics={"acc"}, amp_level="O0")
# Start training
model.train(epoch_size,
dataset_train,
callbacks=[ckpt_callback, LossMonitor(125), TimeMonitor(125)],
dataset_sink_mode=False)
During training you will see output like this:
epoch: 1 step: 125, loss is 1.903618335723877
Train epoch time: 99857.517 ms, per step time: 798.860 ms
epoch: 2 step: 125, loss is 1.448015570640564
Train epoch time: 95555.111 ms, per step time: 764.441 ms
The loss is decreasing steadily, which shows training is proceeding normally.
The figure above shows the training run: the loss curve on the left, the accuracy curve on the right, and a table below summarizing the training configuration and final results. The model converges stably and reaches solid performance.
4.4 Validation
After training, check the model's accuracy:
# Validation data preprocessing
dataset_val = ImageFolderDataset(os.path.join(data_path, "val"), shuffle=True)
trans_val = [
transforms.Decode(),
transforms.Resize(224 + 32),
transforms.CenterCrop(224),
transforms.Normalize(mean=mean, std=std),
transforms.HWC2CHW()
]
dataset_val = dataset_val.map(operations=trans_val, input_columns=["image"])
dataset_val = dataset_val.batch(batch_size=16, drop_remainder=True)
# Evaluation metrics
eval_metrics = {'Top_1_Accuracy': train.Top1CategoricalAccuracy(),
'Top_5_Accuracy': train.Top5CategoricalAccuracy()}
if ascend_target:
model = train.Model(network, loss_fn=network_loss, optimizer=network_opt,
metrics=eval_metrics, amp_level="O2")
else:
model = train.Model(network, loss_fn=network_loss, optimizer=network_opt,
metrics=eval_metrics, amp_level="O0")
# Run evaluation
result = model.eval(dataset_val)
print(result)
The output:
{'Top_1_Accuracy': 0.75, 'Top_5_Accuracy': 0.928}
Top-1 accuracy is 75% and Top-5 accuracy is 92.8%: decent results.
5 Inference
5.1 Preparing the Inference Data
dataset_infer = ImageFolderDataset(os.path.join(data_path, "infer"), shuffle=True)
trans_infer = [
transforms.Decode(),
transforms.Resize([224, 224]),
transforms.Normalize(mean=mean, std=std),
transforms.HWC2CHW()
]
dataset_infer = dataset_infer.map(operations=trans_infer,
input_columns=["image"],
num_parallel_workers=1)
dataset_infer = dataset_infer.batch(1)
5.2 Inference and Result Visualization
import cv2
import numpy as np
from PIL import Image
from scipy import io
def index2label():
"""獲取ImageNet類別標籤"""
metafile = os.path.join(data_path, "ILSVRC2012_devkit_t12/data/meta.mat")
meta = io.loadmat(metafile, squeeze_me=True)['synsets']
nums_children = list(zip(*meta))[4]
meta = [meta[idx] for idx, num_children in enumerate(nums_children) if num_children == 0]
_, wnids, classes = list(zip(*meta))[:3]
clssname = [tuple(clss.split(', ')) for clss in classes]
wnid2class = {wnid: clss for wnid, clss in zip(wnids, clssname)}
wind2class_name = sorted(wnid2class.items(), key=lambda x: x[0])
mapping = {}
for index, (_, class_name) in enumerate(wind2class_name):
mapping[index] = class_name[0]
return mapping
# Run inference (build the label mapping once, outside the loop)
mapping = index2label()
for data in dataset_infer.create_dict_iterator(output_numpy=True):
    image = ms.Tensor(data["image"])
    prob = model.predict(image)
    label = np.argmax(prob.asnumpy(), axis=1)
    output = {int(label): mapping[int(label)]}
    print(output)
The inference output:
{236: 'Doberman'}
The model correctly identifies the Doberman, so inference works as expected.
6 Summary and Reflections
6.1 Strengths of ViT
A few strengths of ViT that stood out in this exercise:
- Clean architecture: compared with the intricate convolutional stacks of CNNs, ViT's architecture is more uniform and concise
- Strong scalability: the Transformer's parallelism makes it easy to scale the model to larger sizes
- Good transferability: after pretraining on a large dataset, it transfers well to downstream tasks
6.2 Pitfalls in Practice
- High compute requirements: ViT demands a lot of GPU memory, so the batch size cannot be set too large
- Data hunger: compared with CNNs, ViT depends more heavily on large-scale pretraining data
- Position encoding matters: removing it causes a clear drop in performance
6.3 Highlights of the Implementation
The MindSpore implementation has several nice aspects:
- Modular design: each component is cleanly encapsulated, which makes it easy to understand and modify
- Automatic mixed precision: the amp_level argument makes mixed-precision training easy to enable
- Flexible data handling: the preprocessing pipeline is flexible by design
The whole walkthrough went smoothly; the code quality is good and the comments are clear. For anyone wanting to understand ViT's principles and implementation, this tutorial is a solid starting point.
Of course, truly mastering ViT takes more paper-reading and more experiments. This is only a start: next steps could include fine-tuning on your own dataset or implementing some ViT variants.