.tf 格式是 TensorFlow 的 SavedModel 格式的文件擴展名,它是 TensorFlow 推薦的標準模型保存格式。
什麼是 .tf 格式?
當你在 TensorFlow 中調用 model.save('my_model') 時,TensorFlow 會創建一個名為 my_model 的目錄,裏面包含:
my_model/
├── saved_model.pb # 模型架構和計算圖定義
├── variables/ # 模型權重
│ ├── variables.data-00000-of-00001
│ └── variables.index
└── assets/ # 輔助文件(可選)
└── fingerprint.pb # 模型指紋信息
這裏的 .pb 文件是 Protocol Buffer 格式,但整個保存的模型結構我們通常稱為 “.tf 格式” 或 “SavedModel 格式”。
.tf 格式 vs .h5 格式
|
特性
|
.tf 格式 (SavedModel)
|
.h5 格式
|
|
文件結構 |
目錄(多個文件)
|
單個文件
|
|
保存內容 |
模型架構 + 權重 + 計算圖 + 訓練配置
|
模型架構 + 權重
|
|
TensorFlow 版本 |
TF2.x 推薦
|
TF1.x 常用
|
|
加載方式 |
|
|
|
自定義對象 |
支持較好
|
需要額外配置
|
|
生產部署 |
推薦 |
不推薦
|
代碼示例
保存為 .tf 格式
import tensorflow as tf
# 創建並訓練一個簡單模型
model = tf.keras.Sequential([
tf.keras.layers.Dense(64, activation='relu', input_shape=(784,)),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam', loss='categorical_crossentropy')
# 保存為 .tf 格式(默認)
model.save('my_model') # 創建 my_model/ 目錄
# 或者明確指定格式
model.save('my_model', save_format='tf')
加載 .tf 格式模型
# 加載模型
loaded_model = tf.keras.models.load_model('my_model')
# 使用模型進行預測
predictions = loaded_model.predict(test_data)
查看 .tf 格式內容
import os
model_dir = 'my_model'
print("SavedModel 目錄內容:")
for root, dirs, files in os.walk(model_dir):
level = root.replace(model_dir, '').count(os.sep)
indent = ' ' * 2 * level
print(f'{indent}{os.path.basename(root)}/')
subindent = ' ' * 2 * (level + 1)
for file in files:
print(f'{subindent}{file}')
輸出示例:
SavedModel 目錄內容:
my_model/
saved_model.pb
variables/
variables.data-00000-of-00001
variables.index
assets/
fingerprint.pb
.tf 格式的優勢
1. 完整的模型信息
# 保存完整的模型信息
model.save('complete_model') # 保存:架構、權重、優化器狀態、訓練配置
# 之後可以繼續訓練
loaded_model = tf.keras.models.load_model('complete_model')
loaded_model.fit(more_data, more_labels) # 可以繼續訓練!
2. 支持自定義層和模型
class CustomLayer(tf.keras.layers.Layer):
def __init__(self, units=32):
super().__init__()
self.units = units
def build(self, input_shape):
self.w = self.add_weight(shape=(input_shape[-1], self.units))
self.b = self.add_weight(shape=(self.units,))
def call(self, inputs):
return tf.matmul(inputs, self.w) + self.b
# 使用自定義層
model = tf.keras.Sequential([
CustomLayer(64),
tf.keras.layers.Dense(10)
])
# .tf 格式可以正確保存和加載自定義層
model.save('custom_model')
loaded_model = tf.keras.models.load_model('custom_model')
3. 生產環境友好
# 使用 TensorFlow Serving 部署
import subprocess
# 啓動 TensorFlow Serving(示例命令)
serve_command = [
'tensorflow_model_server',
'--rest_api_port=8501',
'--model_name=my_model',
'--model_base_path=/path/to/my_model'
]
# .tf 格式可以直接用於生產環境部署
4. 跨平台兼容性
# 在 Python 中保存
model.save('my_model')
# 可以在其他環境中加載:
# - TensorFlow.js
# - TensorFlow Lite
# - TensorFlow Serving
# - 其他編程語言的 TensorFlow 綁定
實際應用場景
場景1:訓練後部署
# 訓練模型
model.fit(train_data, train_labels, epochs=10)
# 保存為生產就緒的格式
model.save('production_model')
# 之後在任何地方加載使用
production_model = tf.keras.models.load_model('production_model')
場景2:模型版本管理
import datetime
# 帶時間戳保存不同版本
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
model.save(f'models/mnist_model_{timestamp}')
# 目錄結構:
# models/
# ├── mnist_model_20231201-143022/
# ├── mnist_model_20231202-093015/
# └── mnist_model_20231203-162345/
場景3:轉換為其他格式
# 從 .tf 格式轉換為 TensorFlow Lite
converter = tf.lite.TFLiteConverter.from_saved_model('my_model')
tflite_model = converter.convert()
with open('model.tflite', 'wb') as f:
f.write(tflite_model)
總結
.tf 格式是:
- TensorFlow 2.x 的默認模型保存格式
- 一個目錄,包含多個文件,而不是單個文件
- 功能完整,支持保存模型架構、權重、優化器狀態等所有信息
- 生產就緒,可以直接用於 TensorFlow Serving 等生產環境
- 靈活性強,支持自定義層和複雜模型結構
對於新的 TensorFlow 項目,推薦始終使用 .tf 格式來保存模型,除非你有特定的理由需要使用 .h5 格式。