本文通過 MNIST 手寫數字識別案例,詳細講解了 TensorFlow/Keras 模型的搭建、訓練、評估和保存全流程,涵蓋了 Sequential 和函數式 API 兩種模型搭建方式,以及三種主流的模型保存方法。
目錄
- 一、前期準備:環境搭建與數據集加載
- 1. TensorFlow環境安裝
- 2. 加載MNIST數據集
- 3. 數據預處理
- 二、模型搭建:Sequential與函數式API兩種方式
- 1. Sequential序貫模型(入門首選)
- 2. 函數式API(進階必備)
- 三、模型編譯與訓練:核心參數解析
- 1. 模型編譯
- 2. 模型訓練
- 3. 訓練過程可視化
- 四、模型評估:測試集性能驗證
- 五、模型保存與加載:三種常用方式
- 1. 保存完整模型(HDF5格式)
- 2. 分別保存模型結構和權重
- 3. 保存為SavedModel格式(TensorFlow推薦)
- 六、常見問題與解決方案
在深度學習領域,TensorFlow作為谷歌開源的主流框架,憑藉其靈活性和生態完整性被廣泛應用,而Keras作為TensorFlow的高層API,以簡潔直觀的接口大幅降低了模型開發的門檻。本文將從環境搭建到實戰落地,手把手教你完成TensorFlow/Keras模型的搭建、訓練、評估與保存,全程基於MNIST手寫數字識別經典案例,適合深度學習入門者和進階開發者參考。
一、前期準備:環境搭建與數據集加載
在開始模型開發前,需完成TensorFlow環境配置和數據集準備,這是後續操作的基礎。
1. TensorFlow環境安裝
TensorFlow 2.x版本已將Keras集成到核心模塊中,無需單獨安裝Keras。建議使用Python 3.8~3.11版本(兼容性最佳),通過pip命令安裝:
# 國內使用清華鏡像源加速安裝
pip install tensorflow -i https://pypi.tuna.tsinghua.edu.cn/simple
安裝完成後,可通過以下代碼驗證是否安裝成功,並檢查是否支持GPU加速(GPU加速能大幅提升訓練速度):
import tensorflow as tf
# 打印TensorFlow版本
print(tf.__version__)
# 檢查GPU是否可用
print("GPU是否可用:", tf.config.list_physical_devices('GPU'))
若輸出GPU設備信息,説明GPU加速配置成功;若僅顯示CPU,可參考TensorFlow官方文檔配置CUDA和cuDNN。
2. 加載MNIST數據集
MNIST是手寫數字識別數據集,包含60000張訓練圖片和10000張測試圖片,每張圖片為28×28的灰度圖,標籤為0~9的數字。Keras內置了該數據集的加載函數:
from tensorflow.keras.datasets import mnist
# 加載數據集,分為訓練集和測試集
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 查看數據集形狀
print("訓練集圖片形狀:", x_train.shape) # (60000, 28, 28)
print("訓練集標籤形狀:", y_train.shape) # (60000,)
print("測試集圖片形狀:", x_test.shape) # (10000, 28, 28)
3. 數據預處理
原始數據需經過歸一化和維度調整,適配模型輸入要求:
# 歸一化:將像素值從0-255縮放到0-1,提升模型收斂速度
x_train = x_train / 255.0
x_test = x_test / 255.0
# 增加維度:將(60000,28,28)轉為(60000,28,28,1),適配卷積層的4維輸入(樣本數,高度,寬度,通道數)
x_train = tf.expand_dims(x_train, axis=-1)
x_test = tf.expand_dims(x_test, axis=-1)
# 查看預處理後的數據形狀
print("預處理後訓練集形狀:", x_train.shape) # (60000, 28, 28, 1)
二、模型搭建:Sequential與函數式API兩種方式
Keras提供了Sequential序貫模型和函數式API兩種搭建方式,前者適合簡單的線性堆疊模型,後者支持構建複雜的多輸入、多輸出模型。
1. Sequential序貫模型(入門首選)
Sequential模型通過add()方法逐層堆疊網絡層,適合結構簡單的深度學習模型。本文構建一個卷積神經網絡(CNN),用於手寫數字識別:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
# 初始化序貫模型
model = Sequential([
# 卷積層:32個3×3卷積核,激活函數ReLU,輸入形狀為(28,28,1)
Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
# 池化層:2×2最大池化,降低特徵維度
MaxPooling2D((2, 2)),
# 第二層卷積層
Conv2D(64, (3, 3), activation='relu'),
# 第二層池化層
MaxPooling2D((2, 2)),
# 展平層:將二維特徵圖轉為一維向量
Flatten(),
# 全連接層:128個神經元,激活函數ReLU
Dense(128, activation='relu'),
# 輸出層:10個神經元,激活函數Softmax,對應0-9的分類概率
Dense(10, activation='softmax')
])
# 查看模型結構
model.summary()
model.summary()會輸出模型的各層參數和總參數量,便於檢查網絡結構是否符合預期。
2. 函數式API(進階必備)
函數式API通過定義輸入張量和輸出張量,靈活構建複雜模型。以相同的CNN結構為例,函數式API的實現方式如下:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
# 定義輸入層,形狀為(28,28,1)
inputs = Input(shape=(28, 28, 1))
# 卷積層1
x = Conv2D(32, (3, 3), activation='relu')(inputs)
# 池化層1
x = MaxPooling2D((2, 2))(x)
# 卷積層2
x = Conv2D(64, (3, 3), activation='relu')(x)
# 池化層2
x = MaxPooling2D((2, 2))(x)
# 展平層
x = Flatten()(x)
# 全連接層
x = Dense(128, activation='relu')(x)
# 輸出層
outputs = Dense(10, activation='softmax')(x)
# 構建模型
model = Model(inputs=inputs, outputs=outputs)
# 查看模型結構
model.summary()
函數式API的優勢在於可自由組合網絡層,比如構建多分支網絡、殘差連接等,是複雜模型開發的主流方式。
三、模型編譯與訓練:核心參數解析
模型搭建完成後,需先編譯(配置訓練參數),再通過訓練數據迭代優化模型參數。
1. 模型編譯
compile()方法用於配置模型的優化器、損失函數和評估指標,這三個參數是訓練的核心:
model.compile(
optimizer='adam', # 優化器:Adam是自適應矩估計,收斂速度快且穩定
loss='sparse_categorical_crossentropy', # 損失函數:適用於整數標籤的多分類
metrics=['accuracy'] # 評估指標:準確率
)
參數説明:
- 優化器:除Adam外,還有SGD(隨機梯度下降)、RMSprop等,Adam是大多數場景的首選。
- 損失函數:若標籤為獨熱編碼(如[0,1,0]),使用
categorical_crossentropy;若為整數標籤(如1),使用sparse_categorical_crossentropy。 - 評估指標:常用
accuracy(準確率),也可自定義指標。
2. 模型訓練
fit()方法用於執行模型訓練,核心參數包括訓練數據、迭代次數、批次大小、驗證集等:
history = model.fit(
x_train, # 訓練特徵數據
y_train, # 訓練標籤
epochs=5, # 迭代次數:整個訓練集遍歷5次
batch_size=32, # 批次大小:每次迭代用32個樣本更新參數
validation_split=0.1 # 驗證集比例:從訓練集中劃分10%作為驗證集
)
訓練過程中會實時輸出每一輪的訓練損失、準確率和驗證損失、準確率,通過驗證集指標可判斷模型是否過擬合。history對象保存了訓練過程的指標變化,可用於後續可視化。
3. 訓練過程可視化
通過Matplotlib繪製訓練和驗證的準確率、損失曲線,直觀分析模型訓練狀態:
import matplotlib.pyplot as plt
# 提取訓練指標
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs_range = range(5)
# 繪製準確率曲線
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()
# 繪製損失曲線
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()
若驗證準確率不再提升甚至下降,而訓練準確率持續上升,説明模型出現過擬合,可通過增加正則化、數據增強等方式解決。
四、模型評估:測試集性能驗證
訓練完成後,需用獨立的測試集評估模型的泛化能力,避免過擬合導致的評估偏差:
# 在測試集上評估模型
test_loss, test_acc = model.evaluate(x_test, y_test)
print("測試集損失:", test_loss)
print("測試集準確率:", test_acc)
對於MNIST數據集,本文的CNN模型測試準確率通常能達到98%以上,若準確率過低,可調整網絡結構(如增加捲積層)或訓練參數(如增加epochs)。
五、模型保存與加載:三種常用方式
訓練好的模型需要保存到本地,以便後續部署或繼續訓練。Keras提供了三種主流的保存方式,適用於不同場景。
1. 保存完整模型(HDF5格式)
將模型的結構、權重、編譯信息全部保存為.h5文件,加載後可直接使用,無需重新編譯:
# 保存完整模型
model.save('mnist_cnn.h5')
# 加載模型
from tensorflow.keras.models import load_model
loaded_model = load_model('mnist_cnn.h5')
# 用加載的模型評估測試集
loaded_test_loss, loaded_test_acc = loaded_model.evaluate(x_test, y_test)
print("加載模型後的測試準確率:", loaded_test_acc)
這種方式簡單易用,適合小型模型的保存和分享。
2. 分別保存模型結構和權重
將模型結構保存為JSON/YMAL文件,權重保存為.h5文件,便於單獨修改模型結構或更新權重:
# 保存模型結構為JSON文件
model_json = model.to_json()
with open('mnist_cnn_structure.json', 'w') as f:
f.write(model_json)
# 保存模型權重為.h5文件
model.save_weights('mnist_cnn_weights.h5')
# 加載模型結構
from tensorflow.keras.models import model_from_json
with open('mnist_cnn_structure.json', 'r') as f:
loaded_structure = f.read()
loaded_model_2 = model_from_json(loaded_structure)
# 加載模型權重
loaded_model_2.load_weights('mnist_cnn_weights.h5')
# 加載權重後需重新編譯模型
loaded_model_2.compile(
optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy']
)
# 評估加載的模型
loaded_test_loss_2, loaded_test_acc_2 = loaded_model_2.evaluate(x_test, y_test)
print("分別加載結構和權重後的測試準確率:", loaded_test_acc_2)
這種方式適合需要調整模型結構後重新加載權重的場景。
3. 保存為SavedModel格式(TensorFlow推薦)
SavedModel是TensorFlow的原生格式,支持跨平台部署(如TensorFlow Serving、TensorFlow Lite),也是TensorFlow 2.x的推薦保存方式:
# 保存為SavedModel格式
model.save('mnist_cnn_savedmodel')
# 加載SavedModel格式的模型
loaded_model_3 = tf.keras.models.load_model('mnist_cnn_savedmodel')
# 評估模型
loaded_test_loss_3, loaded_test_acc_3 = loaded_model_3.evaluate(x_test, y_test)
print("SavedModel加載後的測試準確率:", loaded_test_acc_3)
SavedModel格式包含模型的計算圖和權重,適合生產環境的部署和推理。
六、常見問題與解決方案
- 訓練過擬合:可通過增加數據增強(如
ImageDataGenerator)、添加Dropout層、L2正則化等方式緩解。 - 模型加載失敗:確保加載模型時的TensorFlow版本與保存時一致,HDF5文件路徑無中文或特殊字符。
- 訓練速度慢:檢查GPU是否正常啓用,減小批次大小或使用更輕量的網絡結構。
- 損失不收斂:調整學習率(如Adam優化器設置
learning_rate=0.001)、增加數據預處理步驟、檢查網絡層輸入輸出維度是否匹配。