import os
import cv2
import numpy as np
import mindspore.dataset as ds
import mindspore.nn as nn
from mindspore import context, Tensor, Model
from mindspore.common.initializer import Normal
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.nn.metrics import Accuracy
from mindspore.ops.operations import TensorAdd
from mindspore.train.serialization import load_checkpoint, load_param_into_net

# Seed NumPy's global random generator with a fixed value
# (presumably so that runs are repeatable — confirm which components rely on it).
np.random.seed(58)

class BasicBlock(nn.Cell):
    """Residual block built from two 3x3 convolutions.

    The main branch is conv -> batch norm -> ReLU -> conv -> batch norm. Its
    result is added to the shortcut branch (the block input, or
    ``downsample(input)`` when a downsample cell was supplied) and the sum is
    passed through a final ReLU.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by both convolutions.
        stride: Stride of the first convolution; the second always uses 1.
        downsample: Optional cell applied to the input before the addition.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        # Settings shared by both 3x3 convolutions of the main branch.
        conv3x3 = dict(kernel_size=3, padding=1, pad_mode='pad', has_bias=False)

        self.conv1 = nn.Conv2d(in_channels, out_channels, stride=stride, **conv3x3)
        self.bn1 = nn.BatchNorm2d(out_channels)
        # A single ReLU cell is reused after bn1 and after the residual sum.
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, stride=1, **conv3x3)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample
        self.add = TensorAdd()

    def construct(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        # Main branch: conv1 -> bn1 -> relu -> conv2 -> bn2.
        hidden = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(hidden))

        # Shortcut branch: identity unless a downsample cell was provided.
        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        return self.relu(self.add(residual, shortcut))

class ResNet(nn.Cell):
    """ResNet-style image classifier assembled from residual blocks.

    Layout: 7x7 stride-2 convolution, batch norm, ReLU and 3x3 stride-2 max
    pooling, followed by four stages of ``block`` cells (64, 128, 256 and 512
    channels), 7x7 average pooling, flattening and a dense classifier.

    Args:
        block: Cell class used for every residual block. It is called as
            ``block(in_channels, out_channels, stride, downsample)`` for the
            first block of a stage and ``block(out_channels, out_channels)``
            for the remaining ones.
        layers: Indexable collection; entries 0-3 give the number of blocks
            in stages 1-4.
        num_classes: Size of the classifier output.
    """

    def __init__(self, block, layers, num_classes=100):
        super().__init__()
        # Channel count entering the next stage; make_layer updates it.
        self.in_channels = 64

        # Stem.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad', has_bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')

        # Residual stages; every stage after the first uses stride 2.
        self.layer1 = self.make_layer(block, 64, layers[0])
        self.layer2 = self.make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self.make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self.make_layer(block, 512, layers[3], stride=2)

        # Classification head.
        self.avgpool = nn.AvgPool2d(kernel_size=7, stride=1)
        self.flatten = nn.Flatten()
        self.fc = nn.Dense(512, num_classes)  # output size set to num_classes

    def make_layer(self, block, out_channels, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks.

        Only the first block receives ``stride`` and, when the stride is not 1
        or the channel count changes, a 1x1 convolution + batch norm shortcut.
        ``self.in_channels`` is set to ``out_channels`` as a side effect.

        Returns:
            An ``nn.SequentialCell`` holding the blocks in order.
        """
        same_shape = stride == 1 and self.in_channels == out_channels
        shortcut = None
        if not same_shape:
            # Projection so the shortcut matches the main branch's output.
            shortcut = nn.SequentialCell([
                nn.Conv2d(self.in_channels, out_channels, kernel_size=1, stride=stride, has_bias=False),
                nn.BatchNorm2d(out_channels)
            ])

        head = block(self.in_channels, out_channels, stride, shortcut)
        self.in_channels = out_channels
        tail = [block(out_channels, out_channels) for _ in range(blocks - 1)]

        return nn.SequentialCell([head] + tail)

    def construct(self, x):
        """Run the stem, the four stages and the head; return the logits."""
        stem = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        features = self.layer4(self.layer3(self.layer2(self.layer1(stem))))
        pooled = self.flatten(self.avgpool(features))
        return self.fc(pooled)

class TrainDatasetGenerator:
    """Random-access image source for ``GeneratorDataset``.

    Every entry of the directory ``file_path`` is treated as one sample. The
    label is the integer found before the first underscore of the file name
    (for example ``12_cat.jpg`` -> ``12``).

    Args:
        file_path: Directory containing the image files.
    """

    def __init__(self, file_path):
        self.file_path = file_path
        # os.listdir returns entries in arbitrary order; sort so that an index
        # always maps to the same file.
        self.img_names = sorted(os.listdir(file_path))

    @staticmethod
    def _to_chw(image):
        """Convert an HWC array to a float32 CHW array divided by 255."""
        # An argument-less transpose() reverses all axes (HWC -> CWH), which
        # also swaps height and width; permute the axes explicitly instead.
        return image.transpose(2, 0, 1).astype(np.float32) / 255.

    def __getitem__(self, index):
        """Load sample ``index``.

        Returns:
            A ``(data, label)`` tuple: ``data`` is a float32 array of shape
            (3, 224, 224) in RGB channel order, ``label`` is an ``int``.

        Raises:
            ValueError: If the file cannot be decoded as an image, or its
                name does not start with an integer followed by ``_``.
        """
        name = self.img_names[index]
        path = os.path.join(self.file_path, name)

        data = cv2.imread(path)
        if data is None:
            # cv2.imread reports failure by returning None instead of raising;
            # fail here with the offending path rather than inside cvtColor.
            raise ValueError('Cannot read image file: {}'.format(path))

        label = int(name.split('_')[0])
        data = cv2.cvtColor(data, cv2.COLOR_BGR2RGB)
        data = cv2.resize(data, (224, 224))
        return self._to_chw(data), label

    def __len__(self):
        """Return the number of entries found in the directory."""
        return len(self.img_names)

def train_resnet(train_dir='D:/pythonProject7/train1',
                 valid_dir='D:/pythonProject7/test1',
                 ckpt_dir='D:/pythonProject7/ckpt/',
                 epoch_size=10,
                 batch_size=4,
                 num_classes=100,
                 device_target='CPU'):
    """Train a [2, 2, 2, 2] BasicBlock ResNet and pick the best checkpoint.

    The network is trained with Momentum (lr 0.001, momentum 0.9) and sparse
    softmax cross-entropy, saving a checkpoint every 10 steps. Afterwards
    every ``.ckpt`` file found in ``ckpt_dir`` is loaded and evaluated on the
    validation set; the name of the most accurate one is printed.

    Args:
        train_dir: Directory with training images.
        valid_dir: Directory with validation images.
        ckpt_dir: Directory where checkpoints are written and later scanned.
        epoch_size: Number of training epochs.
        batch_size: Batch size for both pipelines.
        num_classes: Number of output classes of the network.
        device_target: Value passed to ``context.set_context``.

    Returns:
        File name of the best checkpoint, or ``None`` if ``ckpt_dir`` holds
        no ``.ckpt`` file.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target=device_target)

    # shuffle=True already shuffles the generator's samples, so no additional
    # buffer shuffle is applied.
    ds_train = ds.GeneratorDataset(TrainDatasetGenerator(train_dir), ['data', 'label'], shuffle=True)
    ds_train = ds_train.batch(batch_size=batch_size, drop_remainder=True)

    # Validation keeps every sample: with drop_remainder=True up to
    # batch_size - 1 images would be left out of the accuracy. Order does not
    # affect accuracy, so shuffling is disabled.
    ds_valid = ds.GeneratorDataset(TrainDatasetGenerator(valid_dir), ['data', 'label'], shuffle=False)
    ds_valid = ds_valid.batch(batch_size=batch_size, drop_remainder=False)

    network = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.001, momentum=0.9)

    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
    config_ck = CheckpointConfig(save_checkpoint_steps=10, keep_checkpoint_max=10)
    ckpoint_cb = ModelCheckpoint(prefix='checkpoint_resnet', directory=ckpt_dir, config=config_ck)

    model = Model(network, net_loss, net_opt, metrics={'Accuracy': Accuracy()})
    print('============== Starting Training =============')
    model.train(epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()])

    # Start below any reachable accuracy so that a checkpoint is selected even
    # when every evaluation yields 0.
    best_accuracy = -1.0
    best_ckpt_file = None
    # NOTE: every .ckpt file in the directory is evaluated, including any
    # left there by earlier runs.
    for ckpt_file in sorted(os.listdir(ckpt_dir)):
        if not ckpt_file.endswith('.ckpt'):
            continue
        param_dict = load_checkpoint(os.path.join(ckpt_dir, ckpt_file))
        load_param_into_net(network, param_dict)
        accuracy = model.eval(ds_valid)['Accuracy']
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            best_ckpt_file = ckpt_file

    print('Best ckpt file: {}'.format(best_ckpt_file))
    return best_ckpt_file


# Run the training pipeline only when this file is executed as a script.
if __name__ == '__main__':
    train_resnet()

代码解释:

  1. 导入必要库: 导入 MindSpore 相关的库,例如 mindspore.dataset、mindspore.nn 等,以及其他常用的库,例如 os、cv2 等。

  2. 定义 BasicBlock: 定义 ResNet 网络中的基本块,该块由两个卷积层、两个批归一化层和一个 ReLU 激活函数组成,并通过残差(捷径)连接将块的输入与主分支的输出相加。

  3. 定义 ResNet: 定义 ResNet 网络,该网络由多个 BasicBlock 组成,并包含初始卷积层、最大池化层、平均池化层和全连接层。

  4. 定义 TrainDatasetGenerator: 定义训练数据生成器,该生成器读取图像文件并将其转换为模型输入格式。

  5. 训练 ResNet 模型: 定义训练函数,该函数加载训练数据,构建模型,设置优化器和损失函数,并使用回调函数监控训练过程和保存模型。

  6. 评估模型: 使用验证数据集评估模型性能并找到最佳模型。

注意:

  • 代码中的 num_classes 参数应根据实际任务中类别的数量进行修改。
  • 训练数据路径和保存模型路径应根据实际情况进行修改。
  • 本代码示例使用 CPU 进行训练,如果需要使用 GPU,请修改 context.set_context() 中的 device_target 参数。

优化建议:

  • 可以使用更大的训练数据和更复杂的模型来提高模型性能。
  • 可以使用不同的优化器和超参数来寻找最佳模型。
  • 可以使用数据增强技术来提高模型的泛化能力。

相关资源:

希望本文可以帮助您使用 MindSpore 训练 ResNet 模型。

ResNet 模型训练:MindSpore 实现

原文地址: http://www.cveoy.top/t/topic/jqAj 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录