MindSpore ResNet模型训练与最佳模型参数查找

本文介绍使用MindSpore框架训练ResNet模型,并提供查找最佳模型参数文件的示例代码。

模型训练

以下代码展示了使用MindSpore框架训练ResNet模型的示例代码:

import numpy as np
import mindspore.dataset as ds
import os
import cv2
import mindspore
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.initializer import Normal
from mindspore import context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from mindspore.ops.operations import TensorAdd
from scipy.integrate._ivp.radau import P
from mindspore import Model # 承载网络结构
from mindspore.nn.metrics import Accuracy # 测试模型用

np.random.seed(58)


class BasicBlock(nn.Cell):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, pad_mode='pad',has_bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, pad_mode='pad', has_bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample
        self.add = TensorAdd()

    def construct(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.add(out, identity)
        out = self.relu(out)

        return out

class ResNet(nn.Cell):
    def __init__(self, block, layers, num_classes=10):
        super(ResNet, self).__init__()
        self.in_channels = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad', has_bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
        self.layer1 = self.make_layer(block, 64, layers[0])
        self.layer2 = self.make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self.make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self.make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(kernel_size=7, stride=1)
        self.flatten = nn.Flatten()
        self.fc = nn.Dense(512, num_classes)

    def make_layer(self, block, out_channels, blocks, stride=1):
        downsample = None
        if stride != 1 or self.in_channels != out_channels:
            downsample = nn.SequentialCell([
                nn.Conv2d(self.in_channels, out_channels, kernel_size=1, stride=stride, has_bias=False),
                nn.BatchNorm2d(out_channels)
            ])

        layers = []
        layers.append(block(self.in_channels, out_channels, stride, downsample))
        self.in_channels = out_channels
        for _ in range(1, blocks):
            layers.append(block(out_channels, out_channels))

        return nn.SequentialCell(layers)

    def construct(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = self.flatten(x)
        x = self.fc(x)

        return x


class TrainDatasetGenerator:
    def __init__(self, file_path):
        self.file_path = file_path
        self.img_names = os.listdir(file_path)

    def __getitem__(self, index):
        data = cv2.imread(os.path.join(self.file_path, self.img_names[index]))
        label = self.img_names[index].split('_')[0]
        label = int(label)
        data = cv2.cvtColor(data, cv2.COLOR_BGR2RGB)
        data = cv2.resize(data, (224, 224))
        data = data.transpose().astype(np.float32) / 255.
        return data, label

    def __len__(self):
        return len(self.img_names)


def train_resnet():
    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
    train_dataset_generator = TrainDatasetGenerator('D:/pythonProject7/train1')
    ds_train = ds.GeneratorDataset(train_dataset_generator, ['data', 'label'], shuffle=True)
    ds_train = ds_train.shuffle(buffer_size=10)
    ds_train = ds_train.batch(batch_size=4, drop_remainder=True)
    valid_dataset_generator = TrainDatasetGenerator('D:/pythonProject7/test1')
    ds_valid = ds.GeneratorDataset(valid_dataset_generator, ['data', 'label'], shuffle=True)
    ds_valid = ds_valid.batch(batch_size=4, drop_remainder=True)
    network = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=100)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.001, momentum=0.9)
    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
    config_ck = CheckpointConfig(save_checkpoint_steps=10, keep_checkpoint_max=10)
    config_ckpt_path = 'D:/pythonProject7/ckpt/'
    ckpoint_cb = ModelCheckpoint(prefix='checkpoint_resnet', directory=config_ckpt_path, config=config_ck)

    model = Model(network, net_loss, net_opt, metrics={'Accuracy': Accuracy()})
    epoch_size = 10
    print('============== Starting Training =============')
    model.train(epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()])

    acc = model.eval(ds_valid)
    print('============== {} ============='.format(acc))
    epoch_size = 10
    print('============== Starting Training =============')
    model.train(epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()])

    acc = model.eval(ds_valid)
    print('============== {} ============='.format(acc))
    epoch_size = 10
    print('============== Starting Training =============')
    model.train(epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()])

    acc = model.eval(ds_valid)
    print('============== {} ============='.format(acc))

if __name__ == '__main__':
    train_resnet()

查找最佳模型参数文件

在训练完成后,需要找到最佳的模型参数文件。可以使用ModelCheckpoint回调函数来保存训练过程中的模型参数文件,并且可以指定保存的文件名格式。在训练过程中,ModelCheckpoint会在每个指定的保存步数保存一次模型参数文件,同时也可以设置保留的最大文件数,超过最大文件数后会自动删除旧的文件。

以下代码展示了使用ModelCheckpoint回调函数保存模型参数文件的示例代码:

from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

# 定义保存模型参数文件的路径和文件名格式
config_ckpt_path = 'ckpt/'
config_prefix = 'checkpoint_resnet'

# 定义ModelCheckpoint回调函数
config_ck = CheckpointConfig(save_checkpoint_steps=10, keep_checkpoint_max=10)
ckpoint_cb = ModelCheckpoint(prefix=config_prefix, directory=config_ckpt_path, config=config_ck)

# 在训练模型时添加回调函数
model.train(epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()])

在上述代码中,config_ckpt_path变量指定了保存模型参数文件的路径,config_prefix变量指定了文件名的前缀,ModelCheckpoint回调函数会在此基础上自动生成文件名。save_checkpoint_steps参数指定了保存模型参数文件的步数间隔,keep_checkpoint_max参数指定了最大保留的模型参数文件数目。

在训练完成后,可以在指定的路径下找到保存的模型参数文件,文件名类似于“checkpoint_resnet-10_10.ckpt”,其中“10_10”表示训练的第10个epoch,保存的步数间隔为10。

如何找到最佳模型参数文件?

通常,最佳模型参数文件对应着验证集上的最高准确率。可以通过评估每个模型参数文件在验证集上的性能,并选择准确率最高的模型参数文件作为最佳模型参数文件。

以下是评估模型参数文件在验证集上的性能的示例代码:

# 加载模型参数文件
model = Model(network, net_loss, net_opt, metrics={'Accuracy': Accuracy()})
model.load_checkpoint('checkpoint_resnet-10_10.ckpt')

# 评估模型性能
acc = model.eval(ds_valid)
print('Accuracy: ', acc)

通过评估每个模型参数文件在验证集上的性能,并选择准确率最高的模型参数文件作为最佳模型参数文件,即可找到训练后的最佳模型参数文件。

注意事项:

  • 在训练过程中,可以根据实际情况调整save_checkpoint_stepskeep_checkpoint_max参数,以控制模型参数文件的保存频率和数量。
  • 模型参数文件的大小会随着模型的复杂性和训练数据的规模而变化,建议将模型参数文件保存到存储空间充足的路径下。
  • 为了避免重复评估,建议在训练过程中记录每个模型参数文件在验证集上的性能,并保存到一个文件或数据库中,方便后续查找最佳模型参数文件。

希望本文能够帮助您更好地理解MindSpore ResNet模型训练过程中的最佳模型参数查找方法。

MindSpore ResNet模型训练与最佳模型参数查找

原文地址: http://www.cveoy.top/t/topic/jqBK 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录