MindSpore ResNet-34 模型训练实现代码
import os

import cv2
import mindspore.dataset as ds
import mindspore.nn as nn
import numpy as np
from mindspore import context, ops
from mindspore.nn.metrics import Accuracy
from mindspore.train import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor

np.random.seed(58)
class ResidualBlock(nn.Cell):
    """Two-convolution residual block with an optional projection shortcut.

    The main path is conv3x3 -> BN -> ReLU -> conv3x3 -> BN. The shortcut is
    the input itself, or ``downsample(x)`` when a ``downsample`` cell is
    supplied. Both paths are summed and passed through a final ReLU.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by both convolutions.
        stride: Stride of the first convolution (the second always uses 1).
        downsample: Optional cell applied to the input so the shortcut matches
            the main path; ``None`` means the input is added unchanged.
    """

    # Ratio between the block's output channels and ``out_channels``.
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        # The constructor must be named __init__ (not ``init``); otherwise
        # nn.Cell's own constructor runs and none of the layers exist.
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, pad_mode='same', group=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, pad_mode='same', group=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample
        self.stride = stride

    def construct(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out = out + identity
        out = self.relu(out)
        return out
class ResNet(nn.Cell):
    """ResNet-style classifier assembled from four stages of ``block`` cells.

    Layout: 7x7 stride-2 convolution -> BN -> ReLU -> 3x3 stride-2 max pool ->
    four residual stages (64, 128, 256, 512 base channels) -> global average
    over the spatial axes -> dense layer producing ``num_classes`` logits.

    Args:
        block: Residual block class. It must expose an ``expansion`` class
            attribute and accept ``(in_channels, out_channels, stride,
            downsample)``.
        layers: Sequence of four ints, the number of blocks in each stage.
        num_classes: Size of the output logits vector.
        input_channels: Channel count of the input images. Defaults to 1,
            the value that used to be hard-coded in the stem convolution.
    """

    def __init__(self, block, layers, num_classes=34, input_channels=1):
        # The constructor must be named __init__ (not ``init``); otherwise
        # nn.Cell's own constructor runs and the network has no layers.
        super(ResNet, self).__init__()
        # Running channel count, advanced by make_layer after each stage.
        self.in_channels = 64

        self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2,
                               pad_mode='valid', group=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='valid')

        self.layer1 = self.make_layer(block, 64, layers[0])
        self.layer2 = self.make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self.make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self.make_layer(block, 512, layers[3], stride=2)

        # Global average over H and W. The previous fixed 3x3/stride-1 pool
        # only produced a 1x1 map when layer4's output was exactly 3x3; for
        # any other input size the flattened features no longer matched the
        # ``512 * expansion`` inputs of the dense layer below.
        self.avgpool = ops.ReduceMean(keep_dims=False)
        self.fc = nn.Dense(512 * block.expansion, num_classes)

    def make_layer(self, block, out_channels, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks.

        Only the first block receives ``stride`` and, when the stride is not 1
        or the channel count changes, a 1x1-conv + BN projection shortcut.
        ``self.in_channels`` is updated to this stage's output channel count.
        """
        downsample = None
        if (stride != 1) or (self.in_channels != out_channels * block.expansion):
            downsample = nn.SequentialCell([
                nn.Conv2d(self.in_channels, out_channels * block.expansion,
                          kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels * block.expansion)
            ])

        layers = []
        layers.append(block(self.in_channels, out_channels, stride, downsample))
        self.in_channels = out_channels * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_channels, out_channels))

        return nn.SequentialCell(layers)

    def construct(self, x):
        """Map a batch of images ``(N, C, H, W)`` to logits ``(N, num_classes)``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Reduce axes 2 and 3 (H, W): (N, C, H, W) -> (N, C).
        x = self.avgpool(x, (2, 3))
        x = self.fc(x)
        return x
class TrainDatasetGenerator:
    """Random-access image source for ``mindspore.dataset.GeneratorDataset``.

    Every entry of ``file_path`` is treated as one sample. The integer label
    is the part of the file name before the first ``'-'``
    (e.g. ``'7-0001.png'`` -> ``7``).

    Args:
        file_path: Directory containing the image files.
    """

    def __init__(self, file_path):
        self.file_path = file_path
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping reproducible between runs.
        self.img_names = sorted(os.listdir(file_path))

    def __getitem__(self, index=0):
        """Return ``(image, label)`` for the sample at ``index``.

        The image is converted from BGR to RGB, resized to 224x224, laid out
        channel-first as ``(3, 224, 224)`` and scaled to float32 in [0, 1].

        Raises:
            ValueError: If the file cannot be decoded as an image, or the
                file-name prefix is not an integer.
        """
        name = self.img_names[index]
        path = os.path.join(self.file_path, name)

        data = cv2.imread(path)
        if data is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise ValueError('cannot read image file: {}'.format(path))
        label = int(name.split('-')[0])

        data = cv2.cvtColor(data, cv2.COLOR_BGR2RGB)
        data = cv2.resize(data, (224, 224))
        # HWC -> CHW. A bare ``transpose()`` reverses every axis and would
        # also swap height and width.
        data = data.transpose(2, 0, 1).astype(np.float32) / 255.
        return data, label

    def __len__(self):
        """Return the number of entries found in ``file_path``."""
        return len(self.img_names)
def to_single_channel(image):
    """Collapse a channel-first image to one channel by averaging channels.

    ``(C, H, W)`` becomes ``(1, H, W)``; an already single-channel image is
    returned with the same values, and a 2-D ``(H, W)`` image only gains a
    leading channel axis. The result is always float32.
    """
    if image.ndim == 2:
        return image[np.newaxis, ...].astype(np.float32)
    return image.mean(axis=0, keepdims=True).astype(np.float32)


def train_resnet():
    """Train the ResNet on the training folder and report validation accuracy.

    Runs in graph mode on CPU with batches of 4, Momentum (lr 0.01,
    momentum 0.9) and sparse softmax cross-entropy. A checkpoint is written
    every 10 steps (at most 10 kept). Training is done in three rounds of 10
    epochs; validation accuracy is printed after each round.

    NOTE(review): ``[2, 2, 2, 2]`` blocks per stage is the ResNet-18 layout,
    while the file header mentions ResNet-34 (``[3, 4, 6, 3]``) -- confirm
    which one is intended.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')

    train_dataset_generator = TrainDatasetGenerator('D:/pythonproject2/digital_mindspore/dataset')
    ds_train = ds.GeneratorDataset(train_dataset_generator, ['data', 'label'], shuffle=True)
    # The network below is built with its default single-channel stem, while
    # the generator may yield multi-channel images; adapt the data here so
    # the two always agree.
    ds_train = ds_train.map(operations=to_single_channel, input_columns=['data'])
    ds_train = ds_train.shuffle(buffer_size=10)
    ds_train = ds_train.batch(batch_size=4, drop_remainder=True)

    valid_dataset_generator = TrainDatasetGenerator('D:/pythonproject2/test1')
    ds_valid = ds.GeneratorDataset(valid_dataset_generator, ['data', 'label'], shuffle=True)
    ds_valid = ds_valid.map(operations=to_single_channel, input_columns=['data'])
    ds_valid = ds_valid.batch(batch_size=4, drop_remainder=True)

    network = ResNet(ResidualBlock, [2, 2, 2, 2])
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)

    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
    config_ck = CheckpointConfig(save_checkpoint_steps=10, keep_checkpoint_max=10)
    config_ckpt_path = 'D:/pythonproject2/ckpt/'
    ckpoint_cb = ModelCheckpoint(prefix='checkpoint_resnet', directory=config_ckpt_path, config=config_ck)

    model = Model(network, net_loss, net_opt, metrics={'Accuracy': Accuracy()})

    epoch_size = 10
    # NOTE(review): the original repeated this train/eval stanza three times
    # verbatim. The loop keeps that behaviour (3 x 10 epochs); if the
    # repetition was a copy/paste artifact, reduce the range to 1.
    for _ in range(3):
        print('============== Starting Training =============')
        model.train(epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()])
        acc = model.eval(ds_valid)
        print('============== {} ============='.format(acc))
# Run training only when executed as a script, not when imported.
# The dunder names are required: a bare ``name`` is undefined.
if __name__ == '__main__':
    train_resnet()
原文地址: https://www.cveoy.top/t/topic/mSJN 著作权归作者所有。请勿转载和采集!