ResNet 模型训练:MindSpore 实现
import os
import cv2
import numpy as np
import mindspore.dataset as ds
import mindspore.nn as nn
from mindspore import context, Tensor, Model
from mindspore.common.initializer import Normal
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.nn.metrics import Accuracy
from mindspore.ops.operations import TensorAdd
from mindspore.train.serialization import load_checkpoint, load_param_into_net
np.random.seed(58)
class BasicBlock(nn.Cell):
    """Residual block made of two 3x3 convolutions with batch normalisation.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by both convolutions.
        stride: Stride of the first convolution (the second always uses 1).
        downsample: Optional cell applied to the block input before it is
            added to the convolution output; ``None`` adds the input as is.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        # Settings shared by both 3x3 convolutions.
        conv_kwargs = dict(kernel_size=3, padding=1, pad_mode='pad', has_bias=False)
        self.conv1 = nn.Conv2d(in_channels, out_channels, stride=stride, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, stride=1, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample
        self.add = TensorAdd()

    def construct(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        # Main path: conv -> bn -> relu -> conv -> bn.
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))

        # Shortcut path: the raw input unless a downsample cell was supplied.
        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        return self.relu(self.add(y, shortcut))
class ResNet(nn.Cell):
    """ResNet feature extractor followed by a dense classification layer.

    Args:
        block: Residual block class; it is called as
            ``block(in_channels, out_channels, stride, downsample)``.
        layers: Sequence of four ints giving the number of blocks in each of
            the four stages.
        num_classes: Output size of the final dense layer.
    """

    def __init__(self, block, layers, num_classes=100):
        super().__init__()
        # Channel count fed to the next stage; make_layer updates it.
        self.in_channels = 64

        # Stem: 7x7 stride-2 convolution followed by 3x3 stride-2 max pooling.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad', has_bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')

        # Four residual stages; every stage after the first halves the
        # spatial resolution through stride 2.
        self.layer1 = self.make_layer(block, 64, layers[0])
        self.layer2 = self.make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self.make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self.make_layer(block, 512, layers[3], stride=2)

        # Head. NOTE(review): the 7x7 pooling window and Dense(512, ...)
        # presumably rely on 224x224 inputs (7x7 final feature map) - confirm.
        self.avgpool = nn.AvgPool2d(kernel_size=7, stride=1)
        self.flatten = nn.Flatten()
        self.fc = nn.Dense(512, num_classes)  # output size is num_classes

    def make_layer(self, block, out_channels, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks.

        Only the first block receives ``stride`` and, when the stride or the
        channel count changes, a 1x1 convolution + batch-norm shortcut. The
        remaining blocks map ``out_channels`` to ``out_channels``.
        """
        shape_changes = stride != 1 or self.in_channels != out_channels
        projection = None
        if shape_changes:
            projection = nn.SequentialCell([
                nn.Conv2d(self.in_channels, out_channels, kernel_size=1, stride=stride, has_bias=False),
                nn.BatchNorm2d(out_channels)
            ])

        stage = [block(self.in_channels, out_channels, stride, projection)]
        self.in_channels = out_channels
        stage.extend(block(out_channels, out_channels) for _ in range(1, blocks))
        return nn.SequentialCell(stage)

    def construct(self, x):
        """Run the stem, the four residual stages and the classifier head."""
        # Stem.
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        # Residual stages.
        x = self.layer2(self.layer1(x))
        x = self.layer4(self.layer3(x))
        # Head.
        x = self.flatten(self.avgpool(x))
        return self.fc(x)
class TrainDatasetGenerator:
    """Random-access image source for ``mindspore.dataset.GeneratorDataset``.

    Every entry of ``file_path`` is treated as one sample. The integer class
    label is the part of the file name before the first underscore, e.g.
    ``12_cat.jpg`` has label ``12``.

    Args:
        file_path: Directory that contains the image files.
    """

    def __init__(self, file_path):
        self.file_path = file_path
        # os.listdir returns entries in arbitrary order; sorting makes a given
        # index refer to the same file on every run and platform.
        self.img_names = sorted(os.listdir(file_path))

    def __getitem__(self, index):
        """Load the sample at ``index``.

        Returns:
            A ``(data, label)`` tuple: ``data`` is a float32 array of shape
            ``(3, 224, 224)`` (RGB, channel first) scaled to ``[0, 1]`` and
            ``label`` is the int parsed from the file name.

        Raises:
            ValueError: If the file name does not start with an integer label.
            OSError: If OpenCV cannot read or decode the file.
        """
        name = self.img_names[index]
        label = int(name.split('_')[0])

        path = os.path.join(self.file_path, name)
        data = cv2.imread(path)
        if data is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the offending path rather than inside cvtColor.
            raise OSError('Cannot read image file: {}'.format(path))

        data = cv2.cvtColor(data, cv2.COLOR_BGR2RGB)
        data = cv2.resize(data, (224, 224))
        # HWC -> CHW. A bare transpose() reverses every axis and would yield
        # (C, W, H), i.e. an image with height and width swapped.
        data = data.transpose(2, 0, 1).astype(np.float32) / 255.
        return data, label

    def __len__(self):
        """Return the number of entries found in ``file_path``."""
        return len(self.img_names)
def train_resnet(train_dir='D:/pythonProject7/train1',
                 valid_dir='D:/pythonProject7/test1',
                 ckpt_dir='D:/pythonProject7/ckpt/',
                 num_classes=100,
                 epoch_size=10,
                 batch_size=4):
    """Train a ResNet on CPU and report the checkpoint with the best accuracy.

    The network is ``ResNet(BasicBlock, [2, 2, 2, 2])`` trained with Momentum
    and sparse softmax cross entropy. A checkpoint is written every 10 steps
    (at most 10 are kept); after training each checkpoint is evaluated on the
    validation set and the name of the most accurate one is printed.

    Args:
        train_dir: Directory with training images (``<label>_*.ext`` names).
        valid_dir: Directory with validation images, same naming scheme.
        ckpt_dir: Directory where checkpoints are written and read back from.
        num_classes: Number of output classes of the network.
        epoch_size: Number of training epochs.
        batch_size: Batch size used for both training and validation.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')

    # shuffle=True already shuffles the whole training set, so no additional
    # buffer-based shuffle is needed.
    ds_train = ds.GeneratorDataset(TrainDatasetGenerator(train_dir), ['data', 'label'], shuffle=True)
    ds_train = ds_train.batch(batch_size=batch_size, drop_remainder=True)

    # Validation must not be shuffled: together with drop_remainder=True a
    # shuffle would drop different samples on every eval, so checkpoints
    # would be scored on different subsets and could not be compared fairly.
    ds_valid = ds.GeneratorDataset(TrainDatasetGenerator(valid_dir), ['data', 'label'], shuffle=False)
    ds_valid = ds_valid.batch(batch_size=batch_size, drop_remainder=True)

    network = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.001, momentum=0.9)

    ckpt_prefix = 'checkpoint_resnet'
    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
    config_ck = CheckpointConfig(save_checkpoint_steps=10, keep_checkpoint_max=10)
    ckpoint_cb = ModelCheckpoint(prefix=ckpt_prefix, directory=ckpt_dir, config=config_ck)
    model = Model(network, net_loss, net_opt, metrics={'Accuracy': Accuracy()})

    print('============== Starting Training =============')
    model.train(epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()])

    # Evaluate every checkpoint written with our prefix. Files with another
    # prefix are skipped because they may belong to a different network.
    # NOTE(review): checkpoints left by earlier runs with the same prefix are
    # still picked up - clear ckpt_dir first if that is not wanted.
    best_accuracy = None
    best_ckpt_file = None
    for ckpt_file in sorted(os.listdir(ckpt_dir)):
        if not (ckpt_file.startswith(ckpt_prefix) and ckpt_file.endswith('.ckpt')):
            continue
        param_dict = load_checkpoint(os.path.join(ckpt_dir, ckpt_file))
        load_param_into_net(network, param_dict)
        accuracy = model.eval(ds_valid)['Accuracy']
        # Start from "no result" instead of 0 so that a checkpoint is still
        # reported when every accuracy is 0.
        if best_accuracy is None or accuracy > best_accuracy:
            best_accuracy = accuracy
            best_ckpt_file = ckpt_file

    print('Best ckpt file: {}'.format(best_ckpt_file))
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    train_resnet()
代码解释:
- 导入必要库: 导入 MindSpore 相关的库,例如 `mindspore.dataset`、`mindspore.nn` 等,以及其他常用的库,例如 `os`、`cv2` 等。
- 定义 BasicBlock: 定义 ResNet 网络中的基本块,该块由两个卷积层、两个批归一化层和 ReLU 激活函数组成,并通过残差连接将输入(必要时先经过下采样)与卷积输出相加。
- 定义 ResNet: 定义 ResNet 网络,该网络由多个 BasicBlock 组成,并包含初始卷积层、最大池化层、平均池化层和全连接层。
- 定义 TrainDatasetGenerator: 定义训练数据生成器,该生成器读取图像文件并将其转换为模型输入格式。
- 训练 ResNet 模型: 定义训练函数,该函数加载训练数据,构建模型,设置优化器和损失函数,并使用回调函数监控训练过程和保存模型。
- 评估模型: 使用验证数据集评估每个已保存的检查点,并找到准确率最高的模型。
注意:
- 代码中的 `num_classes` 参数应根据实际任务中类别的数量进行修改。
- 训练数据路径和保存模型路径应根据实际情况进行修改。
- 本代码示例使用 CPU 进行训练,如果需要使用 GPU,请修改 `context.set_context()` 中的 `device_target` 参数。
优化建议:
- 可以使用更大的训练数据和更复杂的模型来提高模型性能。
- 可以使用不同的优化器和超参数来寻找最佳模型。
- 可以使用数据增强技术来提高模型的泛化能力。
相关资源:
- MindSpore 文档: https://www.mindspore.cn/
- ResNet 模型: https://arxiv.org/abs/1512.03385
希望本文可以帮助您使用 MindSpore 训练 ResNet 模型。
原文地址: http://www.cveoy.top/t/topic/jqAj 著作权归作者所有。请勿转载和采集!