ResNet 模型训练 - MindSpore 代码示例
import numpy as np
import mindspore.dataset as ds
import os
import cv2
import mindspore
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.initializer import Normal
from mindspore import context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from mindspore.ops.operations import TensorAdd
from scipy.integrate._ivp.radau import P
np.random.seed(58)
class BasicBlock(nn.Cell):
def __init__(self, in_channels, out_channels, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, pad_mode='pad',has_bias=False)
self.bn1 = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, pad_mode='pad', has_bias=False)
self.bn2 = nn.BatchNorm2d(out_channels)
self.downsample = downsample
self.add = TensorAdd()
def construct(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out = self.add(out, identity)
out = self.relu(out)
return out
class ResNet(nn.Cell):
def __init__(self, block, layers, num_classes=10):
super(ResNet, self).__init__()
self.in_channels = 64
self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad', has_bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
self.layer1 = self.make_layer(block, 64, layers[0])
self.layer2 = self.make_layer(block, 128, layers[1], stride=2)
self.layer3 = self.make_layer(block, 256, layers[2], stride=2)
self.layer4 = self.make_layer(block, 512, layers[3], stride=2)
self.avgpool = nn.AvgPool2d(kernel_size=10, stride=1)
self.flatten = nn.Flatten()
self.fc = nn.Dense(512, num_classes)
def make_layer(self, block, out_channels, blocks, stride=1):
downsample = None
if stride != 1 or self.in_channels != out_channels:
downsample = nn.SequentialCell([
nn.Conv2d(self.in_channels, out_channels, kernel_size=1, stride=stride, has_bias=False),
nn.BatchNorm2d(out_channels)
])
layers = []
layers.append(block(self.in_channels, out_channels, stride, downsample))
self.in_channels = out_channels
for _ in range(1, blocks):
layers.append(block(out_channels, out_channels))
return nn.SequentialCell(layers)
def construct(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = self.flatten(x)
x = self.fc(x)
return x
class TrainDatasetGenerator:
def __init__(self, file_path):
self.file_path = file_path
self.img_names = os.listdir(file_path)
def __getitem__(self, index):
data = cv2.imread(os.path.join(self.file_path, self.img_names[index]))
label = self.img_names[index].split('_')[0]
label = int(label)
data = cv2.cvtColor(data, cv2.COLOR_BGR2RGB)
data = cv2.resize(data, (224, 224))
data = data.transpose().astype(np.float32) / 255.
return data, label
def __len__(self):
return len(self.img_names)
def train_resnet():
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
train_dataset_generator = TrainDatasetGenerator('D:/pythonProject7/train')
ds_train = ds.GeneratorDataset(train_dataset_generator, ['data', 'label'], shuffle=True)
ds_train = ds_train.shuffle(buffer_size=10)
ds_train = ds_train.batch(batch_size=4, drop_remainder=True)
valid_dataset_generator = TrainDatasetGenerator('D:/pythonProject7/test')
ds_valid = ds.GeneratorDataset(valid_dataset_generator, ['data', 'label'], shuffle=True)
ds_valid = ds_valid.batch(batch_size=4, drop_remainder=True)
network = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=100)
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
config_ck = CheckpointConfig(save_checkpoint_steps=10, keep_checkpoint_max=10)
config_ckpt_path = 'D:/pythonProject7/ckpt/'
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_resnet', directory=config_ckpt_path, config=config_ck)
model = Model(network, net_loss, net_opt, metrics={'Accuracy': Accuracy()})
epoch_size = 10
print('============== Starting Training =============')
model.train(epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()])
acc = model.eval(ds_valid)
print('============== {} ============='.format(acc))
epoch_size = 10
print('============== Starting Training =============')
model.train(epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()])
acc = model.eval(ds_valid)
print('============== {} ============='.format(acc))
epoch_size = 10
print('============== Starting Training =============')
model.train(epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()])
acc = model.eval(ds_valid)
print('============== {} ============='.format(acc))
return model
if __name__ == '__main__':
model = train_resnet()
print(model.eval(ds_valid))
This code is a basic example of training a ResNet model using MindSpore. Here's what it does:
- Imports: Necessary libraries like
mindspore,numpy,cv2, etc. are imported. - Model Definition: The
BasicBlockclass defines a single building block of the ResNet architecture, and theResNetclass constructs the entire network using these blocks. - Dataset: The
TrainDatasetGeneratorclass defines a custom dataset generator that loads and processes images from specified directories. - Training: The
train_resnetfunction performs the training and evaluation.- Data Loading: It creates
ds_trainandds_validdatasets from the training and validation directories. - Model Initialization: The ResNet model is initialized with the specified architecture and loss function.
- Training Loop: The
model.trainmethod trains the model for a set number of epochs, andmodel.evalevaluates the trained model on the validation set.
- Data Loading: It creates
The code also includes checkpoints and a time monitor to track training progress.
This is a simplified example, and you can customize it further based on your specific needs.
This code defines a ResNet model, prepares a training dataset and a validation dataset, trains the ResNet model, and evaluates the model's performance on the validation dataset. It includes details like the definition of BasicBlock, the definition of ResNet, the custom dataset generator, and the training and evaluation steps. It also uses the MindSpore framework for building and training the ResNet model.
Remember to update the file paths for your training and validation datasets in the train_resnet function.
This code provides a starting point for using MindSpore for ResNet model training. You can expand this code by implementing more sophisticated data augmentation, adding more advanced learning rate schedulers, or integrating with other tools for visualization and analysis.
原文地址: https://www.cveoy.top/t/topic/mQ5w 著作权归作者所有。请勿转载和采集!