MindSpore ResNet Model Training: A Hands-On Image Classification Tutorial
This article provides a detailed tutorial on training a ResNet model for image classification with the MindSpore framework. It covers dataset preparation, model construction, training, and evaluation, and analyzes common errors along with their solutions, so you can quickly get started with deep learning in MindSpore.

### 1. Prerequisites

First, make sure MindSpore is installed in your environment and that you have image datasets prepared for training and testing.

#### 1.1 Installing MindSpore

You can install the framework by following the official MindSpore documentation at https://www.mindspore.cn/.

#### 1.2 Preparing the Dataset

This tutorial assumes your images are already split into a training set and a test set, and that each image's filename carries its class label, for example:

* Training set: D:/pythonProject7/train/
* Test set: D:/pythonProject7/test/

The data loader below reads files directly from these folders and takes the part of the filename before the first underscore as the class label, e.g. D:/pythonProject7/train/1_1.jpg, where 1 is the class label and 1_1.jpg is the image filename. Filenames that do not follow this convention trigger the error analyzed in Section 4.1.

### 2. Building the Model

We use ResNet as the image classification model. ResNet is a deep convolutional neural network with strong image feature extraction capability.

```python
import os

import cv2
import numpy as np
import mindspore.dataset as ds
import mindspore.nn as nn
from mindspore import context
from mindspore.train import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.nn.metrics import Accuracy
from mindspore.ops.operations import TensorAdd  # renamed to ops.Add in newer MindSpore releases

np.random.seed(58)


class BasicBlock(nn.Cell):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,
                               padding=1, pad_mode='pad', has_bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1,
                               padding=1, pad_mode='pad', has_bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample
        self.add = TensorAdd()

    def construct(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.add(out, identity)
        out = self.relu(out)

        return out


class ResNet(nn.Cell):
    def __init__(self, block, layers, num_classes=10):
        super(ResNet, self).__init__()
        self.in_channels = 64

        # The loader converts images to RGB, so the first convolution takes 3 input channels.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad', has_bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
        self.layer1 = self.make_layer(block, 64, layers[0])
        self.layer2 = self.make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self.make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self.make_layer(block, 512, layers[3], stride=2)
        # A 224x224 input has been downsampled to a 7x7 feature map by this point,
        # so the average pool covers the whole map.
        self.avgpool = nn.AvgPool2d(kernel_size=7, stride=1)
        self.flatten = nn.Flatten()
        self.fc = nn.Dense(512, num_classes)

    def make_layer(self, block, out_channels, blocks, stride=1):
        downsample = None
        if stride != 1 or self.in_channels != out_channels:
            downsample = nn.SequentialCell([
                nn.Conv2d(self.in_channels, out_channels, kernel_size=1, stride=stride, has_bias=False),
                nn.BatchNorm2d(out_channels)
            ])

        layers = []
        layers.append(block(self.in_channels, out_channels, stride, downsample))
        self.in_channels = out_channels
        for _ in range(1, blocks):
            layers.append(block(out_channels, out_channels))

        return nn.SequentialCell(layers)

    def construct(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = self.flatten(x)
        x = self.fc(x)

        return x


class TrainDatasetGenerator:
    def __init__(self, file_path):
        self.file_path = file_path
        self.img_names = os.listdir(file_path)

    def __getitem__(self, index):
        data = cv2.imread(os.path.join(self.file_path, self.img_names[index]))
        # The label is the number before the first underscore in the filename.
        label = self.img_names[index].split('_')[0]
        try:
            label = int(label)
        except ValueError:
            # If the label cannot be parsed as an integer, report it and skip the file.
            print(f'Error: Invalid label format for file {self.img_names[index]}. Skipping...')
            return None, None
        data = cv2.cvtColor(data, cv2.COLOR_BGR2RGB)
        data = cv2.resize(data, (224, 224))
        # Reversing the axes yields channels-first layout (C, W, H) for a square image.
        data = data.transpose().astype(np.float32) / 255.
        return data, label

    def __len__(self):
        return len(self.img_names)


def train_resnet():
    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
    train_dataset_generator = TrainDatasetGenerator('D:/pythonProject7/train')
    ds_train = ds.GeneratorDataset(train_dataset_generator, ['data', 'label'], shuffle=True)
    ds_train = ds_train.shuffle(buffer_size=10)
    ds_train = ds_train.batch(batch_size=4, drop_remainder=True)
    valid_dataset_generator = TrainDatasetGenerator('D:/pythonProject7/test')
    ds_valid = ds.GeneratorDataset(valid_dataset_generator, ['data', 'label'], shuffle=True)
    ds_valid = ds_valid.batch(batch_size=4, drop_remainder=True)
    network = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=100)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
    config_ck = CheckpointConfig(save_checkpoint_steps=10, keep_checkpoint_max=10)
    config_ckpt_path = 'D:/pythonProject7/ckpt/'
    ckpoint_cb = ModelCheckpoint(prefix='checkpoint_resnet', directory=config_ckpt_path, config=config_ck)

    model = Model(network, net_loss, net_opt, metrics={'Accuracy': Accuracy()})
    epoch_size = 10
    print('============== Starting Training ==============')
    model.train(epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()])

    acc = model.eval(ds_valid)
    print('============== {} =============='.format(acc))


if __name__ == '__main__':
    train_resnet()
```

### 3. Training and Evaluation

Once the model is built, we can train and evaluate it.

#### 3.1 Training

1. Define the training and test datasets:

```python
train_dataset_generator = TrainDatasetGenerator('D:/pythonProject7/train')
ds_train = ds.GeneratorDataset(train_dataset_generator, ['data', 'label'], shuffle=True)
ds_train = ds_train.shuffle(buffer_size=10)
ds_train = ds_train.batch(batch_size=4, drop_remainder=True)

valid_dataset_generator = TrainDatasetGenerator('D:/pythonProject7/test')
ds_valid = ds.GeneratorDataset(valid_dataset_generator, ['data', 'label'], shuffle=True)
ds_valid = ds_valid.batch(batch_size=4, drop_remainder=True)
```

2. Build the network, loss function, and optimizer:

```python
network = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=100)
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
```

3. Set up the training parameters and callbacks:

```python
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
config_ck = CheckpointConfig(save_checkpoint_steps=10, keep_checkpoint_max=10)
config_ckpt_path = 'D:/pythonProject7/ckpt/'
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_resnet', directory=config_ckpt_path, config=config_ck)
```

4. Train the model:

```python
model = Model(network, net_loss, net_opt, metrics={'Accuracy': Accuracy()})
epoch_size = 10
print('============== Starting Training ==============')
model.train(epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()])
```

#### 3.2 Evaluation

After training completes, evaluate the model on the test dataset:

```python
acc = model.eval(ds_valid)
print('============== {} =============='.format(acc))
```

### 4. Error Analysis and Solutions

During training you may encounter errors such as data loading failures or model construction mistakes.

#### 4.1 Data Loading Errors

The most common error in this tutorial is a data loading failure: a label that cannot be parsed as an integer interrupts the training process.

Error message:

```
[ERROR] MD(13744,153,?):2023-4-4 10:19:35 [mindspore/ccsrc/minddata/dataset/util/task_manager.cc:217] InterruptMaster] Task is terminated with err msg(more detail in info level log):Exception thrown from PyFunc. ValueError: invalid literal for int() with base 10: '1-6.jpg'
```

Cause:

The message shows that the filename 1-6.jpg could not be parsed as an integer. The filename does not follow the expected naming convention, so the correct label information cannot be extracted.

Solutions:

1. Check every file in the dataset to ensure all filenames follow the naming convention and contain valid class label information.
2. Add exception handling to the data loader so that a malformed filename does not interrupt training:

```python
try:
    label = int(label)
except ValueError:
    print(f'Error: Invalid label format for file {self.img_names[index]}. Skipping...')
    return None, None
```

#### 4.2 Model Construction Errors

Model construction errors are usually caused by mistakes in the architecture design or in parameter settings.

Solutions:

1. Review the model code carefully to make sure the architecture is correct and the parameters are reasonable.
2. Consult the official MindSpore documentation to learn the conventions and common techniques for model construction.
3. Use pretrained models from the MindSpore model zoo to reduce the chance of construction errors.

### 5. Summary

This article walked through training a ResNet model for image classification with MindSpore: dataset preparation, model construction, training, and evaluation, plus an analysis of common errors and their solutions, to help you get started with deep learning in MindSpore.

In practice, you can adapt the dataset, model architecture, and training parameters to your own needs and optimize accordingly.

We hope this article has been helpful.
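As a follow-up to the first remediation step in Section 4.1 (auditing filenames before training), a small helper can list files whose names do not start with an integer label. This is a minimal sketch; the function name `find_invalid_filenames` and the `<label>_<index>.jpg` convention are assumptions taken from the parsing logic in `TrainDatasetGenerator`, not part of any MindSpore API.

```python
import os


def find_invalid_filenames(folder, sep='_'):
    """Return the filenames in `folder` whose prefix (the part before the
    first `sep`) cannot be parsed as an integer class label."""
    invalid = []
    for name in os.listdir(folder):
        prefix = name.split(sep)[0]
        try:
            int(prefix)
        except ValueError:
            invalid.append(name)
    return invalid
```

Running this over `D:/pythonProject7/train` before training would flag files like `1-6.jpg` (no underscore, so the whole name is treated as the label), letting you rename them instead of relying on the skip-on-error path in the loader.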
Original article: https://www.cveoy.top/t/topic/mQSc — copyright belongs to the author. Do not reproduce or scrape!