卷積神經網絡算法(CNN)是一種專門用來處理具有網格結構數據(如圖像、視頻、時間序列等)的深度神經網絡。
它通過模仿人類視覺皮層的工作機制,通過局部感受野、權重共享和池化等設計,極大地降低了模型的複雜度,並有效提取了數據的空間層次特徵。
卷積神經網絡算法在計算機視覺領域,如圖像識別、目標檢測、圖像分割中取得了巨大成功。
核心原理
在 CNN 出現之前,我們處理圖像通常使用全連接網絡。但處理高分辨率圖片時,全連接網絡存在兩個致命問題。
- 參數量爆炸:一張 的彩色圖片,如果第一個隱藏層有 1000 個節點,參數量就高達 30 億個,根本無法訓練。
- 丟失空間信息:全連接網絡需要把圖片 “拉平” 成一維向量,丟失了像素之間的空間結構關係。
CNN 通過局部連接、權值共享和池化三大核心機制,完美解決了這些痛點。
局部連接
與全連接網絡中每個神經元連接輸入層所有神經元不同,CNN 的卷積層中的神經元只與輸入數據的局部區域相連。
這使得網絡可以專注於學習局部的特徵,例如圖像中的邊緣、紋理等。
權重共享
在同一個卷積層中,用於掃描整個輸入的卷積核的權重在整個輸入上是共享的。
這意味着如果一個卷積核學會了識別 “垂直邊緣”,它就能在圖像的任何位置識別垂直邊緣。這極大地減少了模型參數。
池化
通過池化操作降低特徵圖的分辨率,減少計算量,並賦予模型一定的“平移不變性”(即物體在圖中稍微挪動位置,依然能被識別)。
核心架構
一個典型的 CNN 由卷積層、激活層、池化層和全連接層組成。
1.卷積層
卷積層是 CNN 的核心層。它通過卷積操作來提取輸入數據的局部特徵。
卷積操作本質上是通過一組可學習的卷積核(也稱濾波器)在輸入圖像上滑動,計算局部區域的加權和,生成特徵圖(局部特徵)。
關鍵概念
- 卷積核:一個小的權重矩陣(如 或 ),每個卷積核負責學習一種不同的特徵,如邊緣、紋理、角點等。
- 步長:卷積核每次滑動的距離。步長越大,輸出的特徵圖尺寸越小。
- 填充:在輸入邊緣補零,用於控制輸出特徵圖的大小,防止信息丟失。
2.激活層
由於卷積是線性運算,為了讓模型能擬合複雜的非線性函數,必須引入激活函數。
最常用的激活函數是 ReLU。
它計算速度極快,且在正區間解決了梯度消失問題,使深層網絡的訓練成為可能。
其數學公式為
3.池化層
池化(也叫下采樣)的主要目的是降低特徵圖的維度,減少計算量,防止過擬合。
常見的操作有最大池化或平均池化
- 最大池化:取窗口內的最大值,它能保留最顯著的特徵。
- 平均池化:取局部區域內的平均值作為輸出,用於平滑特徵。
池化層使得網絡對物體在圖像中的微小位置移動不敏感。即使貓在圖中挪動了幾個像素,經過池化後的特徵圖依然能捕捉到相似的信息。
4. 全連接層
在經過多層卷積和池化後,網絡提取到了高維的抽象特徵。
此時,我們將這些特徵圖“壓平”,輸入到全連接層中進行分類或迴歸任務。
訓練過程
CNN 的訓練過程與所有神經網絡一樣,依賴於反向傳播算法和梯度下降算法。
- 前向傳播
數據從輸入層進入,經過每一層卷積、激活、池化、全連接層,最終在輸出層得到預測值(比如:這張圖是貓的概率為 90%)。 - 計算損失
衡量預測值與真實標籤的差距。
分類任務通常使用交叉熵損失。 - 反向傳播
根據鏈式法則,計算損失函數對網絡中所有參數的梯度。 - 參數更新
使用隨機梯度下降(SGD)及其變體(如 Adam, RMSProp),根據計算出的梯度更新權重。
其中 是學習率。通過成千上萬次的迭代,模型逐漸“學會”瞭如何識別特徵。
案例分享
下面是一個使用卷積神經網絡算法實現 MNIST 手寫數字識別的完整示例代碼。
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
# ===============================
# 1. 設備配置
# ===============================
device = torch.device("cuda"if torch.cuda.is_available() else"cpu")
print("Using device:", device)
# ===============================
# 2. 數據加載與預處理
# ===============================
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST(
root="./data",
train=True,
download=True,
transform=transform
)
test_dataset = datasets.MNIST(
root="./data",
train=False,
download=True,
transform=transform
)
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=64,
shuffle=True
)
test_loader = torch.utils.data.DataLoader(
test_dataset,
batch_size=1,
shuffle=True
)
# ===============================
# 3. 定義 CNN 網絡
# ===============================
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
self.pool = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(64 * 5 * 5, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = torch.relu(self.conv1(x))
x = self.pool(x)
x = torch.relu(self.conv2(x))
x = self.pool(x)
x = x.view(x.size(0), -1)
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
model = CNN().to(device)
# ===============================
# 4. 損失函數與優化器
# ===============================
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# ===============================
# 5. 模型訓練
# ===============================
epochs = 3
for epoch in range(epochs):
model.train()
running_loss = 0.0
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f"Epoch [{epoch + 1}/{epochs}], Loss: {running_loss:.4f}")
# ===============================
# 6. 模型預測
# ===============================
model.eval()
images_list = []
true_labels = []
pred_labels = []
with torch.no_grad():
for i, (images, labels) in enumerate(test_loader):
images = images.to(device)
outputs = model(images)
_, predicted = torch.max(outputs, 1)
images_list.append(images.cpu().numpy())
true_labels.append(labels.item())
pred_labels.append(predicted.item())
if i >= 9:
break
# ===============================
# 7. 預測結果可視化
# ===============================
plt.figure(figsize=(12, 4))
for i in range(10):
plt.subplot(2, 5, i + 1)
plt.imshow(images_list[i].squeeze(), cmap="gray")
plt.title(f"True: {true_labels[i]} | Pred: {pred_labels[i]}")
plt.axis("off")
plt.show()