这段代码展示了如何在 TensorFlow 中使用 tf.keras.callbacks.ModelCheckpoint 回调函数实现断点续训,以在 Fashion MNIST 数据集上进行模型训练。

import tensorflow as tf
import os

fashion = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

checkpoint_save_path = "./checkpoint/fashion.ckpt"  # 指定保存文件的地址
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,  # 给出模型保存路径
                                                 save_weights_only=True,  # 只保留模型参数
                                                 save_best_only=True)  # 只保留最优模型

history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])
model.summary()

代码解读:

  1. 导入库: 导入 tensorflowos 库。
  2. 加载数据: 加载 Fashion MNIST 数据集,并将数据进行归一化处理。
  3. 创建模型: 创建一个简单的 Sequential 模型,包含一个 Flatten 层和两个 Dense 层。
  4. 编译模型: 使用 'adam' 优化器,SparseCategoricalCrossentropy 损失函数和 'sparse_categorical_accuracy' 指标编译模型。
  5. 定义断点续训:
    • 设置模型保存路径 checkpoint_save_path
    • 使用 tf.keras.callbacks.ModelCheckpoint 回调函数 cp_callback,指定保存模型参数路径、是否只保存权重以及是否只保存最优模型。
  6. 开始训练: 使用 fit 函数开始训练模型,并将 cp_callback 添加到 callbacks 列表中。
  7. 加载模型: 在训练开始之前,判断是否有已保存的模型权重文件,如果有则加载权重文件继续训练。

断点续训原理:

  • ModelCheckpoint 回调函数在每个 epoch 结束时会检查验证集上的性能,如果性能有所提升,则会保存模型的权重到指定的路径中。
  • 在训练过程中,如果中断了训练,重新运行代码时,会通过判断是否存在已保存的模型权重文件来加载之前保存的模型。
  • 加载模型权重后,继续从上次保存的模型状态开始训练,实现断点续训的功能。

示例:

假设您已经训练了模型,并且保存了模型权重文件。在您重新运行代码时,代码会加载模型权重文件,并且从上次保存的模型状态继续训练。

断点续训的优点:

  • 可以节省训练时间,避免重复训练已经完成的 epoch。
  • 可以方便地恢复训练,避免由于意外中断导致的训练损失。

注意:

  • ModelCheckpoint 回调函数会保存模型权重,而不是整个模型。
  • 如果要保存整个模型,可以使用 tf.keras.models.save_model 函数。
  • 断点续训需要确保模型的定义和编译方式与之前保存的模型一致。
TensorFlow 断点续训:Fashion MNIST 模型训练

原文地址: https://www.cveoy.top/t/topic/cbEU 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录