import os

import cv2
import mindspore.dataset as ds
import mindspore.nn as nn
import numpy as np
from mindspore import context, Tensor, Model
from mindspore.common.initializer import Normal
from mindspore.nn.metrics import Accuracy
from mindspore.ops.operations import TensorAdd
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net

# Seed NumPy's global random number generator with a fixed value.
# NOTE(review): this seeds NumPy only; whether it also makes MindSpore's
# dataset shuffling or weight initialisation reproducible is not visible
# in this file - confirm.
np.random.seed(58)

class BasicBlock(nn.Cell):
    """Residual block made of two 3x3 convolutions and a shortcut connection.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by both convolutions.
        stride: Stride of the first convolution (the second always uses 1).
        downsample: Optional cell applied to the block input before it is
            added to the convolution branch. ``None`` means the input is
            added unchanged.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        # Must be the ``__init__`` dunder (and call the parent's ``__init__``),
        # otherwise the layers below are never created on instantiation.
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,
                               padding=1, pad_mode='pad', has_bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1,
                               padding=1, pad_mode='pad', has_bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample
        self.add = TensorAdd()

    def construct(self, x):
        """Run the block on ``x`` and return the activated residual sum."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Project the shortcut only when a projection cell was supplied.
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.add(out, identity)
        out = self.relu(out)

        return out

class ResNet(nn.Cell):
    """ResNet-style classifier: stem, four residual stages, pooling, dense head.

    Args:
        block: Residual block class. It is called as
            ``block(in_channels, out_channels, stride, downsample)`` for the
            first block of a stage and ``block(out_channels, out_channels)``
            for the remaining ones.
        layers: Sequence of four ints giving the number of blocks per stage.
        num_classes: Number of outputs of the final dense layer.
    """

    def __init__(self, block, layers, num_classes=10):
        # Must be the ``__init__`` dunder (and call the parent's ``__init__``),
        # otherwise the network is never built on instantiation.
        super(ResNet, self).__init__()
        # Channel count fed to the next stage; updated by ``make_layer``.
        self.in_channels = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad', has_bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
        self.layer1 = self.make_layer(block, 64, layers[0])
        self.layer2 = self.make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self.make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self.make_layer(block, 512, layers[3], stride=2)
        # NOTE(review): the fixed 7x7 pooling followed by Dense(512, ...)
        # presumably expects 224x224 inputs so that a 1x1 map reaches the
        # dense layer - confirm before feeding other resolutions.
        self.avgpool = nn.AvgPool2d(kernel_size=7, stride=1)
        self.flatten = nn.Flatten()
        self.fc = nn.Dense(512, num_classes)

    def make_layer(self, block, out_channels, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks.

        Only the first block receives ``stride`` and, when the stride is not 1
        or the channel count changes, a 1x1 convolution + batch-norm shortcut.
        ``self.in_channels`` is updated to ``out_channels`` as a side effect.

        Returns:
            An ``nn.SequentialCell`` holding the stage's blocks.
        """
        downsample = None
        if stride != 1 or self.in_channels != out_channels:
            downsample = nn.SequentialCell([
                nn.Conv2d(self.in_channels, out_channels, kernel_size=1, stride=stride, has_bias=False),
                nn.BatchNorm2d(out_channels)
            ])

        layers = [block(self.in_channels, out_channels, stride, downsample)]
        self.in_channels = out_channels
        for _ in range(1, blocks):
            layers.append(block(out_channels, out_channels))

        return nn.SequentialCell(layers)

    def construct(self, x):
        """Return the class logits for the input batch ``x``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = self.flatten(x)
        x = self.fc(x)

        return x

class TrainDatasetGenerator:
    """Random-access source of ``(image, label)`` pairs read from a directory.

    Every entry of ``file_path`` is treated as an image whose label is the
    integer that precedes the first underscore in its file name
    (e.g. ``"7_cat.jpg"`` has label 7).

    Args:
        file_path: Directory that contains the image files.
    """

    def __init__(self, file_path):
        self.file_path = file_path
        # ``os.listdir`` returns entries in arbitrary order; sort them so the
        # index -> file mapping is the same on every run.
        self.img_names = sorted(os.listdir(file_path))

    def __getitem__(self, index):
        """Load, preprocess and return the sample at ``index``.

        Returns:
            A tuple ``(data, label)``: ``data`` is a float32 array of shape
            (3, 224, 224) in RGB channel order scaled by 1/255, ``label`` is
            a NumPy int32 scalar.

        Raises:
            ValueError: If the file cannot be decoded as an image, or if its
                name does not start with an integer followed by ``_``.
        """
        img_name = self.img_names[index]
        img_path = os.path.join(self.file_path, img_name)

        data = cv2.imread(img_path)
        # ``cv2.imread`` signals failure by returning None instead of raising.
        if data is None:
            raise ValueError('Cannot read image file: {}'.format(img_path))

        # int32 is the label dtype accepted by the sparse cross-entropy loss.
        label = np.int32(int(img_name.split('_')[0]))

        data = cv2.cvtColor(data, cv2.COLOR_BGR2RGB)
        data = cv2.resize(data, (224, 224))
        # HWC -> CHW. A bare ``transpose()`` would reverse all axes and swap
        # height with width, i.e. transpose the picture itself.
        data = data.transpose(2, 0, 1).astype(np.float32) / 255.
        return data, label

    def __len__(self):
        """Return the number of entries found in the directory."""
        return len(self.img_names)

def train_resnet(train_dir='D:/pythonProject7/train1',
                 valid_dir='D:/pythonProject7/test1',
                 ckpt_dir='D:/pythonProject7/ckpt/',
                 epoch_size=10,
                 batch_size=4,
                 num_classes=100):
    """Train the ResNet, then report the checkpoint with the best accuracy.

    Args:
        train_dir: Directory with the training images.
        valid_dir: Directory with the validation images.
        ckpt_dir: Directory where checkpoints are written and later scanned.
        epoch_size: Number of training epochs.
        batch_size: Batch size used for both training and validation.
        num_classes: Number of output classes of the network.

    The name of the best checkpoint file is printed; nothing is returned.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')

    train_dataset_generator = TrainDatasetGenerator(train_dir)
    ds_train = ds.GeneratorDataset(train_dataset_generator, ['data', 'label'], shuffle=True)
    ds_train = ds_train.shuffle(buffer_size=10)
    ds_train = ds_train.batch(batch_size=batch_size, drop_remainder=True)

    valid_dataset_generator = TrainDatasetGenerator(valid_dir)
    # No shuffling here: combined with ``drop_remainder`` it would make every
    # checkpoint be scored on a different subset of the validation data.
    ds_valid = ds.GeneratorDataset(valid_dataset_generator, ['data', 'label'], shuffle=False)
    ds_valid = ds_valid.batch(batch_size=batch_size, drop_remainder=True)

    network = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.001, momentum=0.9)

    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
    config_ck = CheckpointConfig(save_checkpoint_steps=10, keep_checkpoint_max=10)
    ckpoint_cb = ModelCheckpoint(prefix='checkpoint_resnet', directory=ckpt_dir, config=config_ck)

    model = Model(network, net_loss, net_opt, metrics={'Accuracy': Accuracy()})
    print('============== Starting Training =============')
    model.train(epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()])

    # Score every checkpoint found in the directory and remember the best one.
    best_acc = 0
    best_ckpt_file = None
    for ckpt_file in sorted(os.listdir(ckpt_dir)):
        if not ckpt_file.endswith('.ckpt'):
            continue
        param_dict = load_checkpoint(os.path.join(ckpt_dir, ckpt_file))
        load_param_into_net(network, param_dict)
        # ``Model.eval`` returns a dict keyed by the metric names given to
        # ``Model``; pull out the number before comparing.
        acc = model.eval(ds_valid)['Accuracy']
        # Accept the first checkpoint even when its accuracy is 0 so that a
        # result is always reported when at least one checkpoint exists.
        if best_ckpt_file is None or acc > best_acc:
            best_acc = acc
            best_ckpt_file = ckpt_file

    print('Best ckpt file: {}'.format(best_ckpt_file))

# Run the training pipeline only when the file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    train_resnet()

# MindSpore ResNet model training code example

# Original source: https://www.cveoy.top/t/topic/jqyD - copyright belongs to the author; do not repost or scrape.

# Free AI, click here - no registration or login required.