.tf 格式是 TensorFlow 的 SavedModel 格式的文件擴展名,它是 TensorFlow 推薦的標準模型保存格式。

什麼是 .tf 格式?

當你在 TensorFlow 中調用 model.save('my_model') 時,TensorFlow 會創建一個名為 my_model 的目錄,裏面包含:

my_model/
├── saved_model.pb      # 模型架構和計算圖定義
├── variables/          # 模型權重
│   ├── variables.data-00000-of-00001
│   └── variables.index
└── assets/            # 輔助文件(可選)
└── fingerprint.pb     # 模型指紋信息

這裏的 .pb 文件是 Protocol Buffer 格式,但整個保存的模型結構我們通常稱為 “.tf 格式”“SavedModel 格式”

.tf 格式 vs .h5 格式

特性

.tf 格式 (SavedModel)

.h5 格式

文件結構

目錄(多個文件)

單個文件

保存內容

模型架構 + 權重 + 計算圖 + 訓練配置

模型架構 + 權重

TensorFlow 版本

TF2.x 推薦

TF1.x 常用

加載方式

tf.keras.models.load_model('model_dir')

tf.keras.models.load_model('model.h5')

自定義對象

支持較好

需要額外配置

生產部署

推薦

不推薦

代碼示例

保存為 .tf 格式

import tensorflow as tf

# 創建並訓練一個簡單模型
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam', loss='categorical_crossentropy')

# 保存為 .tf 格式(默認)
model.save('my_model')  # 創建 my_model/ 目錄

# 或者明確指定格式
model.save('my_model', save_format='tf')

加載 .tf 格式模型

# 加載模型
loaded_model = tf.keras.models.load_model('my_model')

# 使用模型進行預測
predictions = loaded_model.predict(test_data)

查看 .tf 格式內容

import os

model_dir = 'my_model'
print("SavedModel 目錄內容:")
for root, dirs, files in os.walk(model_dir):
    level = root.replace(model_dir, '').count(os.sep)
    indent = ' ' * 2 * level
    print(f'{indent}{os.path.basename(root)}/')
    subindent = ' ' * 2 * (level + 1)
    for file in files:
        print(f'{subindent}{file}')

輸出示例:

SavedModel 目錄內容:
my_model/
  saved_model.pb
  variables/
    variables.data-00000-of-00001
    variables.index
  assets/
  fingerprint.pb

.tf 格式的優勢

1. 完整的模型信息

# 保存完整的模型信息
model.save('complete_model')  # 保存:架構、權重、優化器狀態、訓練配置

# 之後可以繼續訓練
loaded_model = tf.keras.models.load_model('complete_model')
loaded_model.fit(more_data, more_labels)  # 可以繼續訓練!

2. 支持自定義層和模型

class CustomLayer(tf.keras.layers.Layer):
    def __init__(self, units=32):
        super().__init__()
        self.units = units
    
    def build(self, input_shape):
        self.w = self.add_weight(shape=(input_shape[-1], self.units))
        self.b = self.add_weight(shape=(self.units,))
    
    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b

# 使用自定義層
model = tf.keras.Sequential([
    CustomLayer(64),
    tf.keras.layers.Dense(10)
])

# .tf 格式可以正確保存和加載自定義層
model.save('custom_model')
loaded_model = tf.keras.models.load_model('custom_model')

3. 生產環境友好

# 使用 TensorFlow Serving 部署
import subprocess

# 啓動 TensorFlow Serving(示例命令)
serve_command = [
    'tensorflow_model_server',
    '--rest_api_port=8501',
    '--model_name=my_model',
    '--model_base_path=/path/to/my_model'
]

# .tf 格式可以直接用於生產環境部署

4. 跨平台兼容性

# 在 Python 中保存
model.save('my_model')

# 可以在其他環境中加載:
# - TensorFlow.js
# - TensorFlow Lite
# - TensorFlow Serving
# - 其他編程語言的 TensorFlow 綁定

實際應用場景

場景1:訓練後部署

# 訓練模型
model.fit(train_data, train_labels, epochs=10)

# 保存為生產就緒的格式
model.save('production_model')

# 之後在任何地方加載使用
production_model = tf.keras.models.load_model('production_model')

場景2:模型版本管理

import datetime

# 帶時間戳保存不同版本
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
model.save(f'models/mnist_model_{timestamp}')

# 目錄結構:
# models/
#   ├── mnist_model_20231201-143022/
#   ├── mnist_model_20231202-093015/
#   └── mnist_model_20231203-162345/

場景3:轉換為其他格式

# 從 .tf 格式轉換為 TensorFlow Lite
converter = tf.lite.TFLiteConverter.from_saved_model('my_model')
tflite_model = converter.convert()

with open('model.tflite', 'wb') as f:
    f.write(tflite_model)

總結

.tf 格式是:

  1. TensorFlow 2.x 的默認模型保存格式
  2. 一個目錄,包含多個文件,而不是單個文件
  3. 功能完整,支持保存模型架構、權重、優化器狀態等所有信息
  4. 生產就緒,可以直接用於 TensorFlow Serving 等生產環境
  5. 靈活性強,支持自定義層和複雜模型結構

對於新的 TensorFlow 項目,推薦始終使用 .tf 格式來保存模型,除非你有特定的理由需要使用 .h5 格式。