TensorFlow 断点续训:Fashion MNIST 模型训练
这段代码展示了如何在 TensorFlow 中使用 tf.keras.callbacks.ModelCheckpoint 回调函数实现断点续训,以在 Fashion MNIST 数据集上进行模型训练。
import tensorflow as tf
import os
fashion = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=['sparse_categorical_accuracy'])
checkpoint_save_path = "./checkpoint/fashion.ckpt" # 指定保存文件的地址
if os.path.exists(checkpoint_save_path + '.index'):
print('-------------load the model-----------------')
model.load_weights(checkpoint_save_path)
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path, # 给出模型保存路径
save_weights_only=True, # 只保留模型参数
save_best_only=True) # 只保留最优模型
history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,
callbacks=[cp_callback])
model.summary()
代码解读:
- 导入库: 导入
tensorflow和os库。 - 加载数据: 加载 Fashion MNIST 数据集,并将数据进行归一化处理。
- 创建模型: 创建一个简单的 Sequential 模型,包含一个 Flatten 层和两个 Dense 层。
- 编译模型: 使用 'adam' 优化器,SparseCategoricalCrossentropy 损失函数和 'sparse_categorical_accuracy' 指标编译模型。
- 定义断点续训:
- 设置模型保存路径
checkpoint_save_path。 - 使用
tf.keras.callbacks.ModelCheckpoint回调函数cp_callback,指定保存模型参数路径、是否只保存权重以及是否只保存最优模型。
- 设置模型保存路径
- 开始训练: 使用
fit函数开始训练模型,并将cp_callback添加到callbacks列表中。 - 加载模型: 在训练开始之前,判断是否有已保存的模型权重文件,如果有则加载权重文件继续训练。
断点续训原理:
ModelCheckpoint回调函数在每个 epoch 结束时会检查验证集上的性能,如果性能有所提升,则会保存模型的权重到指定的路径中。- 在训练过程中,如果中断了训练,重新运行代码时,会通过判断是否存在已保存的模型权重文件来加载之前保存的模型。
- 加载模型权重后,继续从上次保存的模型状态开始训练,实现断点续训的功能。
示例:
假设您已经训练了模型,并且保存了模型权重文件。在您重新运行代码时,代码会加载模型权重文件,并且从上次保存的模型状态继续训练。
断点续训的优点:
- 可以节省训练时间,避免重复训练已经完成的 epoch。
- 可以方便地恢复训练,避免由于意外中断导致的训练损失。
注意:
ModelCheckpoint回调函数会保存模型权重,而不是整个模型。- 如果要保存整个模型,可以使用
tf.keras.models.save_model函数。 - 断点续训需要确保模型的定义和编译方式与之前保存的模型一致。
原文地址: https://www.cveoy.top/t/topic/cbEU 著作权归作者所有。请勿转载和采集!