import os
import cv2
import numpy as np
import mindspore.dataset as ds
import mindspore.nn as nn
from mindspore import context, Tensor, Model
from mindspore.common.initializer import Normal
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.nn.metrics import Accuracy
from mindspore.ops.operations import TensorAdd
from mindspore.train.serialization import load_checkpoint, load_param_into_net

# Fix NumPy's global RNG so any NumPy-driven randomness in this process is reproducible.
# NOTE(review): MindSpore's own RNGs (weight init, dataset shuffling) may need separate
# seeding (e.g. mindspore.set_seed / ds.config.set_seed) — confirm for full reproducibility.
np.random.seed(seed=58)

class BasicBlock(nn.Cell):
    """Residual block of two 3x3 convolutions with an identity/projection shortcut.

    Args:
        in_channels: channel count of the incoming feature map.
        out_channels: channel count produced by both 3x3 convolutions.
        stride: stride of the first convolution (the second always uses 1).
        downsample: optional cell applied to the input so the shortcut matches
            the main branch's shape; ``None`` passes the input through unchanged.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        # Main branch: conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN.
        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1,
            pad_mode='pad', has_bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        # The same ReLU cell is applied twice: after bn1 and after the residual add.
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1,
            pad_mode='pad', has_bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample
        self.add = TensorAdd()

    def construct(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        return self.relu(self.add(y, shortcut))

class ResNet(nn.Cell):
    """ResNet backbone assembled from ``block`` cells, followed by a linear classifier.

    Args:
        block: residual block class, called as
            ``block(in_channels, out_channels[, stride, downsample])``.
        layers: four ints giving the number of blocks in each stage
            (``[2, 2, 2, 2]`` with ``BasicBlock`` is ResNet-18).
        num_classes: size of the classifier output (default 10).
    """

    def __init__(self, block, layers, num_classes=10):
        super().__init__()
        # Channel count entering the next stage; make_layer advances it.
        self.in_channels = 64

        # Stem: 7x7/2 conv -> BN -> ReLU -> 3x3/2 max-pool.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad', has_bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')

        # Four stages; stages 2-4 start with a stride-2 block.
        self.layer1 = self.make_layer(block, 64, layers[0])
        self.layer2 = self.make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self.make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self.make_layer(block, 512, layers[3], stride=2)

        # Head. With 224x224 inputs layer4 yields 7x7 maps, which this fixed 7x7 pool
        # collapses to 1x1 so fc receives 512 features; other input sizes would need a
        # different pooling kernel.
        self.avgpool = nn.AvgPool2d(kernel_size=7, stride=1)
        self.flatten = nn.Flatten()
        self.fc = nn.Dense(512, num_classes)

    def make_layer(self, block, out_channels, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks as a ``SequentialCell``.

        Only the first block can change stride or channel count; when it does it
        receives a 1x1-conv + BN projection shortcut. Side effect: sets
        ``self.in_channels`` to ``out_channels`` for the following stage.
        """
        needs_projection = stride != 1 or self.in_channels != out_channels
        downsample = None
        if needs_projection:
            downsample = nn.SequentialCell([
                nn.Conv2d(self.in_channels, out_channels, kernel_size=1, stride=stride, has_bias=False),
                nn.BatchNorm2d(out_channels),
            ])

        stage = [block(self.in_channels, out_channels, stride, downsample)]
        self.in_channels = out_channels
        stage += [block(out_channels, out_channels) for _ in range(1, blocks)]
        return nn.SequentialCell(stage)

    def construct(self, x):
        """Map a batch of 3-channel images to ``num_classes`` logits."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        x = self.layer4(self.layer3(self.layer2(self.layer1(x))))
        x = self.flatten(self.avgpool(x))
        return self.fc(x)

class TrainDatasetGenerator:
    """Random-access image source for ``mindspore.dataset.GeneratorDataset``.

    Every regular file directly inside ``file_path`` is one sample. Its integer
    class label is the part of the file name before the first underscore, e.g.
    ``3_cat.jpg`` -> 3. Samples are returned as ``(image, label)`` where
    ``image`` is a float32 CHW array in [0, 1] resized to 224x224 RGB and
    ``label`` is a NumPy int32.

    Args:
        file_path: directory containing the image files.
    """

    def __init__(self, file_path):
        self.file_path = file_path
        # Sort so the index -> file mapping is reproducible (os.listdir order is
        # arbitrary), and skip sub-directories, which cannot be read as images.
        self.img_names = sorted(
            name for name in os.listdir(file_path)
            if os.path.isfile(os.path.join(file_path, name))
        )

    @staticmethod
    def _parse_label(img_name):
        """Return the integer label encoded before the first '_' of ``img_name``.

        Raises:
            ValueError: if that prefix is not an integer.
        """
        prefix = img_name.split('_')[0]
        try:
            return int(prefix)
        except ValueError:
            raise ValueError(
                'cannot parse label from file name {!r}: expected "<int>_..."'.format(img_name)
            ) from None

    @staticmethod
    def _to_chw_float(image):
        """Convert an HWC uint8 image to a CHW float32 array scaled to [0, 1]."""
        # Explicit axes: a bare transpose() would reverse all axes (HWC -> CWH),
        # silently swapping height and width.
        return image.transpose(2, 0, 1).astype(np.float32) / 255.

    def __getitem__(self, index):
        img_name = self.img_names[index]
        path = os.path.join(self.file_path, img_name)
        data = cv2.imread(path)
        if data is None:
            # cv2.imread reports unreadable / non-image files by returning None.
            raise ValueError('failed to read image file: {}'.format(path))
        label = self._parse_label(img_name)
        data = cv2.cvtColor(data, cv2.COLOR_BGR2RGB)
        data = cv2.resize(data, (224, 224))
        # int32 is the conventional label dtype for MindSpore's sparse cross-entropy loss.
        return self._to_chw_float(data), np.int32(label)

    def __len__(self):
        return len(self.img_names)

def _create_dataset(data_dir, batch_size, training):
    """Wrap the images in ``data_dir`` as a batched GeneratorDataset.

    Training data is shuffled by GeneratorDataset and again through a
    10-element shuffle buffer. Evaluation data keeps file order, so every
    checkpoint is scored on exactly the same samples. Incomplete final batches
    are dropped in both cases.
    """
    dataset = ds.GeneratorDataset(TrainDatasetGenerator(data_dir), ['data', 'label'], shuffle=training)
    if training:
        dataset = dataset.shuffle(buffer_size=10)
    return dataset.batch(batch_size=batch_size, drop_remainder=True)


def _select_best_checkpoint(model, network, ds_valid, ckpt_dir):
    """Score every ``.ckpt`` file in ``ckpt_dir`` on ``ds_valid``.

    Each checkpoint is loaded into ``network`` (the net wrapped by ``model``)
    and evaluated in turn. On return ``network`` holds the best checkpoint's
    weights.

    Returns:
        ``(best_file_name, best_accuracy)``; ``(None, -1.0)`` if the directory
        holds no ``.ckpt`` files.
    """
    # NOTE(review): checkpoints left in ckpt_dir by earlier runs are scored too.
    best_name, best_accuracy = None, -1.0
    for name in sorted(os.listdir(ckpt_dir)):
        if not name.endswith('.ckpt'):
            continue
        load_param_into_net(network, load_checkpoint(os.path.join(ckpt_dir, name)))
        accuracy = model.eval(ds_valid)['Accuracy']
        if accuracy > best_accuracy:
            best_name, best_accuracy = name, accuracy
    if best_name is not None:
        # Leave the network holding the winning weights, not the last ones evaluated.
        load_param_into_net(network, load_checkpoint(os.path.join(ckpt_dir, best_name)))
    return best_name, best_accuracy


def train_resnet(*,
                 train_dir='D:/pythonProject7/train1',
                 valid_dir='D:/pythonProject7/test1',
                 ckpt_dir='D:/pythonProject7/ckpt/',
                 epoch_size=10,
                 batch_size=4,
                 device_target='CPU'):
    """Train ResNet-18 on ``train_dir`` and report the best checkpoint on ``valid_dir``.

    Checkpoints are written to ``ckpt_dir`` every 10 steps (at most 10 kept).
    After training, every checkpoint there is evaluated and the one with the
    highest validation accuracy is reported.

    Returns:
        File name of the best checkpoint, or ``None`` if none was found.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
    ds_train = _create_dataset(train_dir, batch_size, training=True)
    ds_valid = _create_dataset(valid_dir, batch_size, training=False)

    network = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=10)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.001, momentum=0.9)

    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
    config_ck = CheckpointConfig(save_checkpoint_steps=10, keep_checkpoint_max=10)
    ckpoint_cb = ModelCheckpoint(prefix='checkpoint_resnet', directory=ckpt_dir, config=config_ck)

    model = Model(network, net_loss, net_opt, metrics={'Accuracy': Accuracy()})
    print('============== Starting Training =============')
    model.train(epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()])

    best_ckpt_file, best_accuracy = _select_best_checkpoint(model, network, ds_valid, ckpt_dir)
    print('Best ckpt file: {}'.format(best_ckpt_file))
    if best_ckpt_file is not None:
        print('Best accuracy: {}'.format(best_accuracy))
    return best_ckpt_file


if __name__ == '__main__':
    # Script entry point: run training and checkpoint selection with the default settings.
    train_resnet()

代码说明:

  1. 数据预处理:

    • 使用 TrainDatasetGenerator 类读取图像数据并进行预处理,包括:
      • 读取图像文件
      • 将图像转换为 RGB 格式
      • 调整图像大小为 224x224
      • 将图像数据转换为 float32 类型并归一化到 0-1 之间
  2. 模型定义:

    • 定义 ResNet18 模型,最后一层全连接层输出大小为 10,与实际标签数量一致。
    • 使用 nn.SoftmaxCrossEntropyWithLogits 作为损失函数。
    • 使用 nn.Momentum 作为优化器。
  3. 训练和评估:

    • 使用 Model 类训练模型,并使用 ModelCheckpoint、LossMonitor、TimeMonitor 等回调函数记录训练过程中的信息。
    • 评估模型性能,选择最佳的 checkpoint 文件。

代码修改说明:

  • ResNet 类中的 num_classes 参数修改为 10,以匹配实际标签数量。
  • ResNet 类中 fc 层的输出大小修改为 10。
  • TrainDatasetGenerator 类中 label 的提取方法修改为 label = self.img_names[index].split('_')[0],以提取图像文件名中的标签信息。
  • 将双引号改为单引号。

代码运行步骤:

  1. 将代码保存为 resnet_train.py 文件。
  2. 将训练数据和测试数据分别放到 D:/pythonProject7/train1 和 D:/pythonProject7/test1 文件夹中。图像文件名须以整数标签加下划线开头(例如 3_cat.jpg),代码据此解析标签。
  3. 运行代码:python resnet_train.py

注意:

  • 代码中使用的路径需根据实际情况进行修改。
  • 代码默认使用 CPU 环境;如需使用 GPU,请将 device_target 参数修改为 'GPU'。
  • 代码中的训练参数可以根据实际情况进行调整。

代码运行结果:

代码运行后,会在 D:/pythonProject7/ckpt 文件夹中生成训练过程中保存的 checkpoint 文件。代码还会打印出最佳的 checkpoint 文件名和模型的评估结果。


原文地址: https://www.cveoy.top/t/topic/jqAH 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录