本文通過 MNIST 手寫數字識別案例,詳細講解了 TensorFlow/Keras 模型的搭建、訓練、評估和保存全流程,涵蓋了 Sequential 和函數式 API 兩種模型搭建方式,以及三種主流的模型保存方法。

目錄

  • 一、前期準備:環境搭建與數據集加載
  • 1. TensorFlow環境安裝
  • 2. 加載MNIST數據集
  • 3. 數據預處理
  • 二、模型搭建:Sequential與函數式API兩種方式
  • 1. Sequential序貫模型(入門首選)
  • 2. 函數式API(進階必備)
  • 三、模型編譯與訓練:核心參數解析
  • 1. 模型編譯
  • 2. 模型訓練
  • 3. 訓練過程可視化
  • 四、模型評估:測試集性能驗證
  • 五、模型保存與加載:三種常用方式
  • 1. 保存完整模型(HDF5格式)
  • 2. 分別保存模型結構和權重
  • 3. 保存為SavedModel格式(TensorFlow推薦)
  • 六、常見問題與解決方案


在深度學習領域,TensorFlow作為谷歌開源的主流框架,憑藉其靈活性和生態完整性被廣泛應用,而Keras作為TensorFlow的高層API,以簡潔直觀的接口大幅降低了模型開發的門檻。本文將從環境搭建到實戰落地,手把手教你完成TensorFlow/Keras模型的搭建、訓練、評估與保存,全程基於MNIST手寫數字識別經典案例,適合深度學習入門者和進階開發者參考。

一、前期準備:環境搭建與數據集加載

在開始模型開發前,需完成TensorFlow環境配置和數據集準備,這是後續操作的基礎。

1. TensorFlow環境安裝

TensorFlow 2.x版本已將Keras集成到核心模塊中,無需單獨安裝Keras。建議使用Python 3.8~3.11版本(兼容性最佳),通過pip命令安裝:

# 國內使用清華鏡像源加速安裝
pip install tensorflow -i https://pypi.tuna.tsinghua.edu.cn/simple

安裝完成後,可通過以下代碼驗證是否安裝成功,並檢查是否支持GPU加速(GPU加速能大幅提升訓練速度):

import tensorflow as tf
# 打印TensorFlow版本
print(tf.__version__)
# 檢查GPU是否可用
print("GPU是否可用:", tf.config.list_physical_devices('GPU'))

若輸出GPU設備信息,説明GPU加速配置成功;若僅顯示CPU,可參考TensorFlow官方文檔配置CUDA和cuDNN。

2. 加載MNIST數據集

MNIST是手寫數字識別數據集,包含60000張訓練圖片和10000張測試圖片,每張圖片為28×28的灰度圖,標籤為0~9的數字。Keras內置了該數據集的加載函數:

from tensorflow.keras.datasets import mnist

# 加載數據集,分為訓練集和測試集
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 查看數據集形狀
print("訓練集圖片形狀:", x_train.shape)  # (60000, 28, 28)
print("訓練集標籤形狀:", y_train.shape)  # (60000,)
print("測試集圖片形狀:", x_test.shape)    # (10000, 28, 28)

3. 數據預處理

原始數據需經過歸一化和維度調整,適配模型輸入要求:

# 歸一化:將像素值從0-255縮放到0-1,提升模型收斂速度
x_train = x_train / 255.0
x_test = x_test / 255.0

# 增加維度:將(60000,28,28)轉為(60000,28,28,1),適配卷積層的4維輸入(樣本數,高度,寬度,通道數)
x_train = tf.expand_dims(x_train, axis=-1)
x_test = tf.expand_dims(x_test, axis=-1)

# 查看預處理後的數據形狀
print("預處理後訓練集形狀:", x_train.shape)  # (60000, 28, 28, 1)

二、模型搭建:Sequential與函數式API兩種方式

Keras提供了Sequential序貫模型函數式API兩種搭建方式,前者適合簡單的線性堆疊模型,後者支持構建複雜的多輸入、多輸出模型。

1. Sequential序貫模型(入門首選)

Sequential模型通過add()方法逐層堆疊網絡層,適合結構簡單的深度學習模型。本文構建一個卷積神經網絡(CNN),用於手寫數字識別:

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense

# 初始化序貫模型
model = Sequential([
    # 卷積層:32個3×3卷積核,激活函數ReLU,輸入形狀為(28,28,1)
    Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    # 池化層:2×2最大池化,降低特徵維度
    MaxPooling2D((2, 2)),
    # 第二層卷積層
    Conv2D(64, (3, 3), activation='relu'),
    # 第二層池化層
    MaxPooling2D((2, 2)),
    # 展平層:將二維特徵圖轉為一維向量
    Flatten(),
    # 全連接層:128個神經元,激活函數ReLU
    Dense(128, activation='relu'),
    # 輸出層:10個神經元,激活函數Softmax,對應0-9的分類概率
    Dense(10, activation='softmax')
])

# 查看模型結構
model.summary()

model.summary()會輸出模型的各層參數和總參數量,便於檢查網絡結構是否符合預期。

2. 函數式API(進階必備)

函數式API通過定義輸入張量和輸出張量,靈活構建複雜模型。以相同的CNN結構為例,函數式API的實現方式如下:

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input

# 定義輸入層,形狀為(28,28,1)
inputs = Input(shape=(28, 28, 1))
# 卷積層1
x = Conv2D(32, (3, 3), activation='relu')(inputs)
# 池化層1
x = MaxPooling2D((2, 2))(x)
# 卷積層2
x = Conv2D(64, (3, 3), activation='relu')(x)
# 池化層2
x = MaxPooling2D((2, 2))(x)
# 展平層
x = Flatten()(x)
# 全連接層
x = Dense(128, activation='relu')(x)
# 輸出層
outputs = Dense(10, activation='softmax')(x)

# 構建模型
model = Model(inputs=inputs, outputs=outputs)

# 查看模型結構
model.summary()

函數式API的優勢在於可自由組合網絡層,比如構建多分支網絡、殘差連接等,是複雜模型開發的主流方式。

三、模型編譯與訓練:核心參數解析

模型搭建完成後,需先編譯(配置訓練參數),再通過訓練數據迭代優化模型參數。

1. 模型編譯

compile()方法用於配置模型的優化器、損失函數和評估指標,這三個參數是訓練的核心:

model.compile(
    optimizer='adam',  # 優化器:Adam是自適應矩估計,收斂速度快且穩定
    loss='sparse_categorical_crossentropy',  # 損失函數:適用於整數標籤的多分類
    metrics=['accuracy']  # 評估指標:準確率
)

參數説明

  • 優化器:除Adam外,還有SGD(隨機梯度下降)、RMSprop等,Adam是大多數場景的首選。
  • 損失函數:若標籤為獨熱編碼(如[0,1,0]),使用categorical_crossentropy;若為整數標籤(如1),使用sparse_categorical_crossentropy
  • 評估指標:常用accuracy(準確率),也可自定義指標。

2. 模型訓練

fit()方法用於執行模型訓練,核心參數包括訓練數據、迭代次數、批次大小、驗證集等:

history = model.fit(
    x_train,  # 訓練特徵數據
    y_train,  # 訓練標籤
    epochs=5,  # 迭代次數:整個訓練集遍歷5次
    batch_size=32,  # 批次大小:每次迭代用32個樣本更新參數
    validation_split=0.1  # 驗證集比例:從訓練集中劃分10%作為驗證集
)

訓練過程中會實時輸出每一輪的訓練損失、準確率和驗證損失、準確率,通過驗證集指標可判斷模型是否過擬合。history對象保存了訓練過程的指標變化,可用於後續可視化。

3. 訓練過程可視化

通過Matplotlib繪製訓練和驗證的準確率、損失曲線,直觀分析模型訓練狀態:

import matplotlib.pyplot as plt

# 提取訓練指標
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs_range = range(5)

# 繪製準確率曲線
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()

# 繪製損失曲線
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

若驗證準確率不再提升甚至下降,而訓練準確率持續上升,説明模型出現過擬合,可通過增加正則化、數據增強等方式解決。

四、模型評估:測試集性能驗證

訓練完成後,需用獨立的測試集評估模型的泛化能力,避免過擬合導致的評估偏差:

# 在測試集上評估模型
test_loss, test_acc = model.evaluate(x_test, y_test)
print("測試集損失:", test_loss)
print("測試集準確率:", test_acc)

對於MNIST數據集,本文的CNN模型測試準確率通常能達到98%以上,若準確率過低,可調整網絡結構(如增加捲積層)或訓練參數(如增加epochs)。

五、模型保存與加載:三種常用方式

訓練好的模型需要保存到本地,以便後續部署或繼續訓練。Keras提供了三種主流的保存方式,適用於不同場景。

1. 保存完整模型(HDF5格式)

將模型的結構、權重、編譯信息全部保存為.h5文件,加載後可直接使用,無需重新編譯:

# 保存完整模型
model.save('mnist_cnn.h5')

# 加載模型
from tensorflow.keras.models import load_model
loaded_model = load_model('mnist_cnn.h5')

# 用加載的模型評估測試集
loaded_test_loss, loaded_test_acc = loaded_model.evaluate(x_test, y_test)
print("加載模型後的測試準確率:", loaded_test_acc)

這種方式簡單易用,適合小型模型的保存和分享。

2. 分別保存模型結構和權重

將模型結構保存為JSON/YMAL文件,權重保存為.h5文件,便於單獨修改模型結構或更新權重:

# 保存模型結構為JSON文件
model_json = model.to_json()
with open('mnist_cnn_structure.json', 'w') as f:
    f.write(model_json)

# 保存模型權重為.h5文件
model.save_weights('mnist_cnn_weights.h5')

# 加載模型結構
from tensorflow.keras.models import model_from_json
with open('mnist_cnn_structure.json', 'r') as f:
    loaded_structure = f.read()
loaded_model_2 = model_from_json(loaded_structure)

# 加載模型權重
loaded_model_2.load_weights('mnist_cnn_weights.h5')

# 加載權重後需重新編譯模型
loaded_model_2.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# 評估加載的模型
loaded_test_loss_2, loaded_test_acc_2 = loaded_model_2.evaluate(x_test, y_test)
print("分別加載結構和權重後的測試準確率:", loaded_test_acc_2)

這種方式適合需要調整模型結構後重新加載權重的場景。

3. 保存為SavedModel格式(TensorFlow推薦)

SavedModel是TensorFlow的原生格式,支持跨平台部署(如TensorFlow Serving、TensorFlow Lite),也是TensorFlow 2.x的推薦保存方式:

# 保存為SavedModel格式
model.save('mnist_cnn_savedmodel')

# 加載SavedModel格式的模型
loaded_model_3 = tf.keras.models.load_model('mnist_cnn_savedmodel')

# 評估模型
loaded_test_loss_3, loaded_test_acc_3 = loaded_model_3.evaluate(x_test, y_test)
print("SavedModel加載後的測試準確率:", loaded_test_acc_3)

SavedModel格式包含模型的計算圖和權重,適合生產環境的部署和推理。

六、常見問題與解決方案

  1. 訓練過擬合:可通過增加數據增強(如ImageDataGenerator)、添加Dropout層、L2正則化等方式緩解。
  2. 模型加載失敗:確保加載模型時的TensorFlow版本與保存時一致,HDF5文件路徑無中文或特殊字符。
  3. 訓練速度慢:檢查GPU是否正常啓用,減小批次大小或使用更輕量的網絡結構。
  4. 損失不收斂:調整學習率(如Adam優化器設置learning_rate=0.001)、增加數據預處理步驟、檢查網絡層輸入輸出維度是否匹配。