# MindSpore ResNet 模型训练与评估: 使用 GeneratorDataset 和 ModelCheckpoint
import os
import cv2
import numpy as np
import mindspore.dataset as ds
import mindspore.nn as nn
from mindspore import context, Tensor, Model
from mindspore.common.initializer import Normal
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.nn.metrics import Accuracy
from mindspore.ops.operations import TensorAdd
from mindspore.train.serialization import load_checkpoint
# Seed NumPy's global random number generator with a fixed value.
np.random.seed(58)
class BasicBlock(nn.Cell):
    """Residual block built from two 3x3 convolutions.

    Main path: conv3x3 -> BatchNorm -> ReLU -> conv3x3 -> BatchNorm. The
    result is added to the shortcut branch and passed through a final ReLU.

    Args:
        in_channels: Number of channels of the incoming tensor.
        out_channels: Number of channels produced by both convolutions.
        stride: Stride of the first convolution (the second always uses 1).
        downsample: Optional cell applied to the input to build the shortcut.
            When ``None`` the unmodified input is used as the shortcut.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        # Settings shared by both 3x3 convolutions of the block.
        conv_kwargs = dict(kernel_size=3, padding=1, pad_mode='pad', has_bias=False)
        self.conv1 = nn.Conv2d(in_channels, out_channels, stride=stride, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, stride=1, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample
        self.add = TensorAdd()

    def construct(self, x):
        """Run the block on ``x`` and return the activated residual sum."""
        # Shortcut branch: identity unless a downsample cell was supplied.
        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        # Main branch.
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))

        return self.relu(self.add(y, shortcut))
class ResNet(nn.Cell):
    """ResNet-style classifier assembled from four stages of residual blocks.

    Layout: 7x7 stride-2 convolution -> BatchNorm -> ReLU -> 3x3 stride-2
    max-pool -> four stages (64, 128, 256, 512 channels) -> 7x7 average pool
    -> flatten -> dense layer with ``num_classes`` outputs.

    Args:
        block: Residual block class, called as
            ``block(in_channels, out_channels, stride, downsample)``.
        layers: Sequence of four ints, the number of blocks in each stage.
        num_classes: Size of the output of the final dense layer.
    """

    def __init__(self, block, layers, num_classes=10):
        super().__init__()
        # Running channel count; make_layer() updates it after every stage.
        self.in_channels = 64

        # Stem.
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad', has_bias=False
        )
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')

        # Residual stages; every stage after the first uses stride 2.
        self.layer1 = self.make_layer(block, 64, layers[0])
        self.layer2 = self.make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self.make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self.make_layer(block, 512, layers[3], stride=2)

        # Classification head.
        # NOTE(review): the 7x7 pooling window presumably matches 224x224
        # inputs (five stride-2 reductions give 7x7 maps) - confirm with caller.
        self.avgpool = nn.AvgPool2d(kernel_size=7, stride=1)
        self.flatten = nn.Flatten()
        self.fc = nn.Dense(512, num_classes)

    def make_layer(self, block, out_channels, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks.

        Only the first block receives ``stride`` and the optional projection
        shortcut; the remaining blocks map ``out_channels`` to
        ``out_channels`` with the block's defaults. ``self.in_channels`` is
        set to ``out_channels`` as a side effect.

        Returns:
            An ``nn.SequentialCell`` holding the blocks in order.
        """
        shape_preserved = stride == 1 and self.in_channels == out_channels
        projection = None
        if not shape_preserved:
            # 1x1 convolution + BatchNorm so the shortcut matches the main path.
            projection = nn.SequentialCell([
                nn.Conv2d(self.in_channels, out_channels, kernel_size=1, stride=stride, has_bias=False),
                nn.BatchNorm2d(out_channels)
            ])

        stage = [block(self.in_channels, out_channels, stride, projection)]
        self.in_channels = out_channels
        stage.extend(block(out_channels, out_channels) for _ in range(1, blocks))
        return nn.SequentialCell(stage)

    def construct(self, x):
        """Map an input batch ``x`` to the output of the final dense layer."""
        stem = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        features = self.layer4(self.layer3(self.layer2(self.layer1(stem))))
        pooled = self.flatten(self.avgpool(features))
        return self.fc(pooled)
class TrainDatasetGenerator:
    """Random-access image source for ``mindspore.dataset.GeneratorDataset``.

    Every entry of ``file_path`` is treated as one sample. The class label is
    the integer prefix of the file name, i.e. the text before the first
    ``_`` (``"7_cat.jpg"`` -> label ``7``).
    """

    def __init__(self, file_path):
        """Index the directory.

        Args:
            file_path: Directory that contains the image files.
        """
        self.file_path = file_path
        # os.listdir() returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible across runs and platforms.
        self.img_names = sorted(os.listdir(file_path))

    def __getitem__(self, index):
        """Load sample ``index``.

        Returns:
            A ``(data, label)`` pair: ``data`` is a float32 array of shape
            ``(3, 224, 224)`` (channels, height, width) in RGB order scaled
            to ``[0, 1]``; ``label`` is a 0-d int32 array.

        Raises:
            ValueError: If the file cannot be read or decoded as an image,
                or if the file name has no integer prefix.
        """
        name = self.img_names[index]
        path = os.path.join(self.file_path, name)

        data = cv2.imread(path)
        if data is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise ValueError('Cannot read image file: {}'.format(path))

        # int32 labels are accepted by the sparse cross-entropy loss across
        # MindSpore versions, whereas a plain Python int becomes int64.
        label = np.array(int(name.split('_')[0]), dtype=np.int32)

        data = cv2.cvtColor(data, cv2.COLOR_BGR2RGB)
        data = cv2.resize(data, (224, 224))
        # HWC -> CHW. A bare transpose() would reverse all axes (giving CWH)
        # and thereby swap the image's height and width.
        data = data.transpose(2, 0, 1).astype(np.float32) / 255.
        return data, label

    def __len__(self):
        """Return the number of indexed files."""
        return len(self.img_names)
def train_resnet(train_dir='D:/pythonProject7/train1',
                 valid_dir='D:/pythonProject7/test1',
                 ckpt_dir='D:/pythonProject7/ckpt/',
                 epoch_size=10):
    """Train a ResNet-18 style network and report the best checkpoint.

    The network is trained on the images in ``train_dir`` with a checkpoint
    written every 10 steps. Afterwards every ``.ckpt`` file in ``ckpt_dir``
    is loaded into the network and scored on the images in ``valid_dir``; the
    name of the checkpoint with the highest accuracy is printed.

    Args:
        train_dir: Directory with the training images.
        valid_dir: Directory with the validation images.
        ckpt_dir: Directory where checkpoints are written and later read.
        epoch_size: Number of training epochs.
    """
    # Imported locally because the module-level imports only provide
    # load_checkpoint.
    from mindspore.train.serialization import load_param_into_net

    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')

    # Training pipeline.
    ds_train = ds.GeneratorDataset(TrainDatasetGenerator(train_dir), ['data', 'label'], shuffle=True)
    ds_train = ds_train.shuffle(buffer_size=10)
    ds_train = ds_train.batch(batch_size=4, drop_remainder=True)

    # Validation pipeline. It is not shuffled: together with drop_remainder a
    # shuffled set would score every checkpoint on a different subset, making
    # the accuracies incomparable.
    ds_valid = ds.GeneratorDataset(TrainDatasetGenerator(valid_dir), ['data', 'label'], shuffle=False)
    ds_valid = ds_valid.batch(batch_size=4, drop_remainder=True)

    network = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=100)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.001, momentum=0.9)

    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
    config_ck = CheckpointConfig(save_checkpoint_steps=10, keep_checkpoint_max=10)
    ckpoint_cb = ModelCheckpoint(prefix='checkpoint_resnet', directory=ckpt_dir, config=config_ck)

    model = Model(network, net_loss, net_opt, metrics={'Accuracy': Accuracy()})

    print('============== Starting Training ===============')
    model.train(epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()])

    # Start below any possible accuracy so that a checkpoint is reported even
    # when every one of them scores 0.
    best_acc = -1.0
    best_ckpt_file = None
    for ckpt_file in sorted(os.listdir(ckpt_dir)):
        if not ckpt_file.endswith('.ckpt'):
            continue
        param_dict = load_checkpoint(os.path.join(ckpt_dir, ckpt_file))
        # Model has no load_state_dict(); parameters are restored into the
        # underlying network, which the Model evaluates.
        load_param_into_net(network, param_dict)
        # Model.eval() returns a dict keyed by the metric names given above.
        acc = model.eval(ds_valid)['Accuracy']
        if acc > best_acc:
            best_acc = acc
            best_ckpt_file = ckpt_file
    print('Best ckpt file: {}'.format(best_ckpt_file))
# Script entry point: start training only when the file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    train_resnet()
# 原文地址: https://www.cveoy.top/t/topic/jqx2 著作权归作者所有。请勿转载和采集!