The figure below shows the complete ViT architecture: the input image is split into patches, processed by the Transformer encoder, and finally passed through a classification head to produce the result. The flow is clear, so let's implement it step by step.

[Figure: overall ViT architecture — image patches → Transformer encoder → classification head]

1 Environment Setup and Data Preparation

1.1 Environment Configuration

First, make sure Python and MindSpore are installed locally. A GPU is recommended for this tutorial; on a CPU it is slow enough to make you question your life choices.

The dataset is a subset of ImageNet, downloaded automatically on the first run:

from download import download

dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vit_imagenet_dataset.zip"
path = "./"

path = download(dataset_url, path, kind="zip", replace=True)

Once the download finishes, the data is laid out like this:

./dataset/
    ├── ILSVRC2012_devkit_t12.tar.gz
    ├── train/
    ├── infer/
    └── val/

1.2 Data Preprocessing

The preprocessing here is fairly standard, mainly resize, random crop, and normalization:

import os
import mindspore as ms
from mindspore.dataset import ImageFolderDataset
import mindspore.dataset.vision as transforms

data_path = './dataset/'
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

dataset_train = ImageFolderDataset(os.path.join(data_path, "train"), shuffle=True)

trans_train = [
    transforms.RandomCropDecodeResize(size=224,
                                      scale=(0.08, 1.0),
                                      ratio=(0.75, 1.333)),
    transforms.RandomHorizontalFlip(prob=0.5),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]

dataset_train = dataset_train.map(operations=trans_train, input_columns=["image"])
dataset_train = dataset_train.batch(batch_size=16, drop_remainder=True)

The mean and std here are the standard ImageNet statistics. They are multiplied by 255 because Normalize runs on the raw decoded pixel values in [0, 255] rather than on floats rescaled to [0, 1].
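
A quick NumPy check (illustrative, not part of the tutorial code) confirms the two conventions are equivalent:

import numpy as np

pixel = np.array([128.0, 64.0, 200.0])              # raw [0, 255] channel values
imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])

# Normalizing raw pixels with the x255 stats...
scaled = (pixel - imagenet_mean * 255) / (imagenet_std * 255)
# ...matches normalizing [0, 1] floats with the standard stats.
rescaled = (pixel / 255 - imagenet_mean) / imagenet_std
print(np.allclose(scaled, rescaled))                # True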

2 ViT Model Principles

2.1 The Core of the Transformer: Self-Attention

To understand ViT, you first have to understand the Transformer's core mechanism: Self-Attention. Put simply, it lets the model learn the relationships between different positions in the input sequence.

The Self-Attention computation proceeds as follows:

  1. The input vectors pass through three different linear projections to produce Q (Query), K (Key), and V (Value)
  2. The dot products of Q and K give the attention weights
  3. Those weights are used to take a weighted sum of V

In formulas, the input X is first projected:

Q = X W_Q,    K = X W_K,    V = X W_V

Then the attention scores are computed:

scores = Q K^T / sqrt(d_k)

After Softmax normalization, the final output is:

Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V

where d_k is the per-head dimension; dividing by sqrt(d_k) keeps the dot products from growing too large (this is the `self.scale = head_dim ** -0.5` in the code below).

[Figure: the Self-Attention computation flow]

The figure above details the Self-Attention computation: the input sequence X is linearly projected into the Q, K, and V matrices, attention scores are computed and turned into weights by Softmax, and a weighted sum gives the output. This mechanism lets the model dynamically attend to different parts of the input sequence.
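
To make these three steps concrete, here is a minimal single-head sketch in NumPy (toy sizes and random weights, for illustration only):

import numpy as np

np.random.seed(0)
n, d = 4, 8                                   # sequence length, model dimension
X = np.random.randn(n, d)                     # toy input sequence
Wq, Wk, Wv = (np.random.randn(d, d) for _ in range(3))

Q, K, V = X @ Wq, X @ Wk, X @ Wv              # step 1: three linear projections
scores = Q @ K.T / np.sqrt(d)                 # step 2: scaled dot products
scores -= scores.max(axis=-1, keepdims=True)  # stabilize the softmax
weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
out = weights @ V                             # step 3: weighted sum of V
print(out.shape)                              # (4, 8)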

2.2 Implementing Multi-Head Attention

Multi-head attention splits the input into several "heads", each computing attention independently, and concatenates the results. This lets the model interpret the input from several perspectives:

from mindspore import nn, ops

class Attention(nn.Cell):
    def __init__(self,
                 dim: int,
                 num_heads: int = 8,
                 keep_prob: float = 1.0,
                 attention_keep_prob: float = 1.0):
        super(Attention, self).__init__()

        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = ms.Tensor(head_dim ** -0.5)

        self.qkv = nn.Dense(dim, dim * 3)
        self.attn_drop = nn.Dropout(p=1.0-attention_keep_prob)
        self.out = nn.Dense(dim, dim)
        self.out_drop = nn.Dropout(p=1.0-keep_prob)
        self.attn_matmul_v = ops.BatchMatMul()
        self.q_matmul_k = ops.BatchMatMul(transpose_b=True)
        self.softmax = nn.Softmax(axis=-1)

    def construct(self, x):
        b, n, c = x.shape
        qkv = self.qkv(x)
        qkv = ops.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads))
        qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))
        q, k, v = ops.unstack(qkv, axis=0)
        
        attn = self.q_matmul_k(q, k)
        attn = ops.mul(attn, self.scale)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        
        out = self.attn_matmul_v(attn, v)
        out = ops.transpose(out, (0, 2, 1, 3))
        out = ops.reshape(out, (b, n, c))
        out = self.out(out)
        out = self.out_drop(out)

        return out

The key points of this code (see the shape check below):

  • qkv = self.qkv(x) generates the Q, K, and V matrices in a single projection
  • the reshape and transpose operations reorganize the data into multi-head form
  • at the end, the per-head results are merged back together
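
A quick smoke test of the shapes (a sketch reusing the imports above; the sizes match ViT-B/16):

attn = Attention(dim=768, num_heads=12)
dummy = ops.ones((2, 197, 768), ms.float32)   # batch 2, 197 tokens, dim 768
print(attn(dummy).shape)                      # (2, 197, 768): shape is preserved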

2.3 Feed Forward and Residual Connections

Besides the attention mechanism, the Transformer also needs a Feed Forward network and residual connections:

from typing import Optional

class FeedForward(nn.Cell):
    def __init__(self,
                 in_features: int,
                 hidden_features: Optional[int] = None,
                 out_features: Optional[int] = None,
                 activation: nn.Cell = nn.GELU,
                 keep_prob: float = 1.0):
        super(FeedForward, self).__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.dense1 = nn.Dense(in_features, hidden_features)
        self.activation = activation()
        self.dense2 = nn.Dense(hidden_features, out_features)
        self.dropout = nn.Dropout(p=1.0-keep_prob)

    def construct(self, x):
        x = self.dense1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.dense2(x)
        x = self.dropout(x)
        return x

class ResidualCell(nn.Cell):
    def __init__(self, cell):
        super(ResidualCell, self).__init__()
        self.cell = cell

    def construct(self, x):
        return self.cell(x) + x

The residual connection is as simple as it looks: the input is added directly to the output, which mitigates vanishing gradients in deep networks.
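
For example (illustrative), wrapping a FeedForward block in ResidualCell yields the usual x + FFN(x) form:

ffn = ResidualCell(FeedForward(in_features=768, hidden_features=3072))
dummy = ops.ones((2, 197, 768), ms.float32)
print(ffn(dummy).shape)  # (2, 197, 768): same shape, with the input added back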

2.4 The Complete TransformerEncoder

Combining the attention mechanism, Feed Forward, and residual connections gives the TransformerEncoder:

class TransformerEncoder(nn.Cell):
    def __init__(self,
                 dim: int,
                 num_layers: int,
                 num_heads: int,
                 mlp_dim: int,
                 keep_prob: float = 1.,
                 attention_keep_prob: float = 1.0,
                 drop_path_keep_prob: float = 1.0,
                 activation: nn.Cell = nn.GELU,
                 norm: nn.Cell = nn.LayerNorm):
        super(TransformerEncoder, self).__init__()
        layers = []

        for _ in range(num_layers):
            normalization1 = norm((dim,))
            normalization2 = norm((dim,))
            attention = Attention(dim=dim,
                                  num_heads=num_heads,
                                  keep_prob=keep_prob,
                                  attention_keep_prob=attention_keep_prob)

            feedforward = FeedForward(in_features=dim,
                                      hidden_features=mlp_dim,
                                      activation=activation,
                                      keep_prob=keep_prob)

            layers.append(
                nn.SequentialCell([
                    ResidualCell(nn.SequentialCell([normalization1, attention])),
                    ResidualCell(nn.SequentialCell([normalization2, feedforward]))
                ])
            )
        self.layers = nn.SequentialCell(layers)

    def construct(self, x):
        return self.layers(x)

One detail here: ViT places LayerNorm before the attention and Feed Forward blocks (pre-LN), unlike the original Transformer, which applies it afterwards (post-LN); experiments show the pre-LN arrangement works better.
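
In pseudocode, the two orderings differ as follows (schematic only; dropout and shapes omitted):

# Post-LN (original "Attention Is All You Need" encoder):
#   x = LayerNorm(x + Attention(x))
#   x = LayerNorm(x + FeedForward(x))
#
# Pre-LN (ViT, matching the SequentialCell wiring above):
#   x = x + Attention(LayerNorm(x))
#   x = x + FeedForward(LayerNorm(x))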

3 ViT's Key Innovation: Turning Images into Sequences

[Figure: the full ViT image-processing pipeline]

The figure above shows the complete pipeline ViT applies to an image: the raw image is split into patches, converted by the embedding, combined with position encodings and a CLS token, processed by the Transformer encoder, and finally the CLS token is extracted for the classification prediction.

3.1 Patch Embedding

ViT's cleverest idea is converting an image into a sequence. Concretely, the image is cut into small blocks (patches), and each patch is flattened into a one-dimensional vector:

class PatchEmbedding(nn.Cell):
    MIN_NUM_PATCHES = 4

    def __init__(self,
                 image_size: int = 224,
                 patch_size: int = 16,
                 embed_dim: int = 768,
                 input_channels: int = 3):
        super(PatchEmbedding, self).__init__()

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        self.conv = nn.Conv2d(input_channels, embed_dim, 
                             kernel_size=patch_size, stride=patch_size, has_bias=True)

    def construct(self, x):
        x = self.conv(x)
        b, c, h, w = x.shape
        x = ops.reshape(x, (b, c, h * w))
        x = ops.transpose(x, (0, 2, 1))
        return x

The patch split is implemented with a convolution, which is more efficient than slicing by hand: with kernel size and stride both equal to patch_size, each output position corresponds to exactly one patch. For a 224×224 image and 16×16 patches, this yields 14×14 = 196 patches.
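
A shape check (illustrative) confirms the patch count:

patch_embed = PatchEmbedding(image_size=224, patch_size=16, embed_dim=768)
dummy = ops.ones((1, 3, 224, 224), ms.float32)
print(patch_embed(dummy).shape)  # (1, 196, 768): 196 patches, each a 768-dim vector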

3.2 Position Encoding and the Classification Token

After the image is cut into patches, position information and a classification token still need to be added (using the init helper defined in the next subsection):

# In the __init__ of the ViT class
self.cls_token = init(init_type=Normal(sigma=1.0),
                      shape=(1, 1, embed_dim),
                      dtype=ms.float32,
                      name='cls',
                      requires_grad=True)

self.pos_embedding = init(init_type=Normal(sigma=1.0),
                          shape=(1, num_patches + 1, embed_dim),
                          dtype=ms.float32,
                          name='pos_embedding',
                          requires_grad=True)

The classification token borrows BERT's idea: a special token is prepended to the sequence, and that token's output is used for classification. The position encoding tells the model where each patch sits in the image.
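
The shape bookkeeping is easy to verify with NumPy stand-ins for the real Parameters (illustrative only):

import numpy as np

patches = np.zeros((2, 196, 768))                   # output of PatchEmbedding
cls_token = np.zeros((1, 1, 768))                   # learnable CLS token
cls_tiled = np.tile(cls_token, (2, 1, 1))           # one copy per sample
seq = np.concatenate((cls_tiled, patches), axis=1)  # (2, 197, 768)
pos_embedding = np.zeros((1, 197, 768))             # learnable position encoding
seq = seq + pos_embedding                           # broadcast over the batch
print(seq.shape)                                    # (2, 197, 768)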

3.3 The Complete ViT Model

Putting all the components together gives the complete ViT model:

from mindspore.common.initializer import Normal, initializer
from mindspore import Parameter

def init(init_type, shape, dtype, name, requires_grad):
    initial = initializer(init_type, shape, dtype).init_data()
    return Parameter(initial, name=name, requires_grad=requires_grad)

class ViT(nn.Cell):
    def __init__(self,
                 image_size: int = 224,
                 input_channels: int = 3,
                 patch_size: int = 16,
                 embed_dim: int = 768,
                 num_layers: int = 12,
                 num_heads: int = 12,
                 mlp_dim: int = 3072,
                 keep_prob: float = 1.0,
                 attention_keep_prob: float = 1.0,
                 drop_path_keep_prob: float = 1.0,
                 activation: nn.Cell = nn.GELU,
                 norm: Optional[nn.Cell] = nn.LayerNorm,
                 num_classes: int = 1000,
                 pool: str = 'cls') -> None:
        super(ViT, self).__init__()

        self.patch_embedding = PatchEmbedding(image_size=image_size,
                                              patch_size=patch_size,
                                              embed_dim=embed_dim,
                                              input_channels=input_channels)
        num_patches = self.patch_embedding.num_patches

        self.cls_token = init(init_type=Normal(sigma=1.0),
                              shape=(1, 1, embed_dim),
                              dtype=ms.float32,
                              name='cls',
                              requires_grad=True)

        self.pos_embedding = init(init_type=Normal(sigma=1.0),
                                  shape=(1, num_patches + 1, embed_dim),
                                  dtype=ms.float32,
                                  name='pos_embedding',
                                  requires_grad=True)

        self.pool = pool
        self.pos_dropout = nn.Dropout(p=1.0-keep_prob)
        self.norm = norm((embed_dim,))
        self.transformer = TransformerEncoder(dim=embed_dim,
                                              num_layers=num_layers,
                                              num_heads=num_heads,
                                              mlp_dim=mlp_dim,
                                              keep_prob=keep_prob,
                                              attention_keep_prob=attention_keep_prob,
                                              drop_path_keep_prob=drop_path_keep_prob,
                                              activation=activation,
                                              norm=norm)
        self.dropout = nn.Dropout(p=1.0-keep_prob)
        self.dense = nn.Dense(embed_dim, num_classes)

    def construct(self, x):
        x = self.patch_embedding(x)
        cls_tokens = ops.tile(self.cls_token.astype(x.dtype), (x.shape[0], 1, 1))
        x = ops.concat((cls_tokens, x), axis=1)
        x += self.pos_embedding

        x = self.pos_dropout(x)
        x = self.transformer(x)
        x = self.norm(x)
        x = x[:, 0]  # take the classification token's output
        if self.training:
            x = self.dropout(x)
        x = self.dense(x)

        return x

The whole pipeline is: image → patch embedding → add the cls token and position encoding → Transformer encoder → classification head.
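
As a sanity check, a forward pass through the assembled model maps an image batch straight to logits (a sketch; instantiating the full ViT-B/16 takes a noticeable amount of memory):

network = ViT(num_classes=1000)
dummy = ops.ones((1, 3, 224, 224), ms.float32)
print(network(dummy).shape)  # (1, 1000)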

4 Training and Validation in Practice

4.1 Training Configuration

Before training, set up the loss function, optimizer, and related pieces:

from mindspore.nn import LossBase
from mindspore.train import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint
from mindspore import train

# Hyperparameters
epoch_size = 10
momentum = 0.9
num_classes = 1000
resize = 224
step_size = dataset_train.get_dataset_size()

# Build the model
network = ViT()

# Load pretrained weights
vit_url = "https://download.mindspore.cn/vision/classification/vit_b_16_224.ckpt"
path = "./ckpt/vit_b_16_224.ckpt"
vit_path = download(vit_url, path, replace=True)
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)

# Learning-rate schedule
lr = nn.cosine_decay_lr(min_lr=0.0,
                        max_lr=0.00005,
                        total_step=epoch_size * step_size,
                        step_per_epoch=step_size,
                        decay_epoch=10)

# Optimizer
network_opt = nn.Adam(network.trainable_params(), lr, momentum)

Since a pretrained model is loaded, the learning rate is set quite low. Cosine-decay scheduling keeps training stable.
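
Note that nn.cosine_decay_lr returns a plain Python list with one learning rate per training step, which you can inspect directly (illustrative):

print(len(lr))   # epoch_size * step_size entries, one per step
print(lr[0])     # starts at max_lr (5e-05)
print(lr[-1])    # has decayed toward min_lr (here 0) by the last step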

4.2 Loss Function

The loss is cross-entropy with label smoothing:

class CrossEntropySmooth(LossBase):
    def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
        super(CrossEntropySmooth, self).__init__()
        self.onehot = ops.OneHot()
        self.sparse = sparse
        self.on_value = ms.Tensor(1.0 - smooth_factor, ms.float32)
        self.off_value = ms.Tensor(1.0 * smooth_factor / (num_classes - 1), ms.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)

    def construct(self, logit, label):
        if self.sparse:
            label = self.onehot(label, ops.shape(logit)[1], self.on_value, self.off_value)
        loss = self.ce(logit, label)
        return loss

network_loss = CrossEntropySmooth(sparse=True,
                                  reduction="mean",
                                  smooth_factor=0.1,
                                  num_classes=num_classes)

Label smoothing helps prevent overfitting and improves generalization.
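
For intuition, here is what the smoothed target distribution looks like with the values used above (a worked example):

smooth_factor, num_classes = 0.1, 1000
on_value = 1.0 - smooth_factor                 # 0.9 for the true class
off_value = smooth_factor / (num_classes - 1)  # ~0.0001 for every other class
print(on_value, off_value)
# The on value plus the 999 off values still sum exactly to 1.0.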

4.3 Starting Training

# Checkpoint settings
ckpt_config = CheckpointConfig(save_checkpoint_steps=step_size, keep_checkpoint_max=100)
ckpt_callback = ModelCheckpoint(prefix='vit_b_16', directory='./ViT', config=ckpt_config)

# Initialize the model (O2 mixed precision on Ascend, O0 elsewhere)
ascend_target = (ms.get_context("device_target") == "Ascend")
if ascend_target:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, 
                       metrics={"acc"}, amp_level="O2")
else:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, 
                       metrics={"acc"}, amp_level="O0")

# Start training
model.train(epoch_size,
            dataset_train,
            callbacks=[ckpt_callback, LossMonitor(125), TimeMonitor(125)],
            dataset_sink_mode=False)

During training you'll see output like this:

epoch: 1 step: 125, loss is 1.903618335723877
Train epoch time: 99857.517 ms, per step time: 798.860 ms
epoch: 2 step: 125, loss is 1.448015570640564
Train epoch time: 95555.111 ms, per step time: 764.441 ms

The loss decreases steadily, which shows training is proceeding normally.

[Figure: training curves — loss (left), accuracy (right), and a table of the training configuration and final results]

The figure above shows the training run: the loss curve on the left trends downward, the accuracy curve on the right climbs, and the table underneath summarizes the training configuration and final results. The model converges stably and ends up with solid performance.

4.4 Model Validation

After training, verify the model's performance:

# Validation data preprocessing
dataset_val = ImageFolderDataset(os.path.join(data_path, "val"), shuffle=True)

trans_val = [
    transforms.Decode(),
    transforms.Resize(224 + 32),
    transforms.CenterCrop(224),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]

dataset_val = dataset_val.map(operations=trans_val, input_columns=["image"])
dataset_val = dataset_val.batch(batch_size=16, drop_remainder=True)

# Evaluation metrics
eval_metrics = {'Top_1_Accuracy': train.Top1CategoricalAccuracy(),
                'Top_5_Accuracy': train.Top5CategoricalAccuracy()}

if ascend_target:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, 
                       metrics=eval_metrics, amp_level="O2")
else:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, 
                       metrics=eval_metrics, amp_level="O0")

# Run evaluation
result = model.eval(dataset_val)
print(result)

The results:

{'Top_1_Accuracy': 0.75, 'Top_5_Accuracy': 0.928}

A Top-1 accuracy of 75% and a Top-5 accuracy of 92.8%: not bad at all.

5 Inference Test

5.1 Preparing the Inference Data

dataset_infer = ImageFolderDataset(os.path.join(data_path, "infer"), shuffle=True)

trans_infer = [
    transforms.Decode(),
    transforms.Resize([224, 224]),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]

dataset_infer = dataset_infer.map(operations=trans_infer,
                                  input_columns=["image"],
                                  num_parallel_workers=1)
dataset_infer = dataset_infer.batch(1)

5.2 Inference and Visualizing Results

import cv2
import numpy as np
from PIL import Image
from scipy import io

def index2label():
    """獲取ImageNet類別標籤"""
    metafile = os.path.join(data_path, "ILSVRC2012_devkit_t12/data/meta.mat")
    meta = io.loadmat(metafile, squeeze_me=True)['synsets']
    
    nums_children = list(zip(*meta))[4]
    meta = [meta[idx] for idx, num_children in enumerate(nums_children) if num_children == 0]
    
    _, wnids, classes = list(zip(*meta))[:3]
    clssname = [tuple(clss.split(', ')) for clss in classes]
    wnid2class = {wnid: clss for wnid, clss in zip(wnids, clssname)}
    wind2class_name = sorted(wnid2class.items(), key=lambda x: x[0])
    
    mapping = {}
    for index, (_, class_name) in enumerate(wind2class_name):
        mapping[index] = class_name[0]
    return mapping

# Inference: build the index-to-label mapping once, then classify each image
mapping = index2label()
for i, image in enumerate(dataset_infer.create_dict_iterator(output_numpy=True)):
    image = ms.Tensor(image["image"])
    prob = model.predict(image)
    label = np.argmax(prob.asnumpy(), axis=1)
    output = {int(label): mapping[int(label)]}
    print(output)

Inference result:

{236: 'Doberman'}

[Figure: the test image, a Doberman]

The model correctly identifies the Doberman, so inference works as expected.

6 Summary and Reflections

6.1 Strengths of ViT

A few strengths of ViT stood out during this exercise:

  1. Clean architecture: compared with the intricate convolutional layer designs of CNNs, ViT's architecture is more uniform and concise
  2. Strong scalability: the Transformer's parallelism makes it easy to scale the model to much larger sizes
  3. Good transferability: after pretraining on large datasets, it transfers well to downstream tasks

6.2 Pitfalls in Practice

  1. High compute requirements: ViT demands a lot of GPU memory, so the batch size cannot be set too large
  2. Data hunger: compared with CNNs, ViT depends more heavily on large-scale pretraining data
  3. Position encoding matters: removing it causes a clear drop in performance

6.3 Highlights of the Implementation

The MindSpore implementation has several nice touches:

  1. Modular design: each component is well encapsulated, making it easy to understand and modify
  2. Automatic mixed precision: the amp_level argument turns on mixed-precision training with a single parameter
  3. Flexible data handling: the preprocessing pipeline is designed to be very flexible

The whole walkthrough went fairly smoothly; the code quality is good and the comments are clear. For anyone wanting to understand ViT's principles and implementation, this tutorial is a solid starting point.

Of course, truly mastering ViT takes more paper-reading and more experiments. This is just a start; next steps could be fine-tuning on your own dataset or implementing some ViT variants.