MindSpore LeNet5 Implementation for Image Classification
import numpy as np
import mindspore.dataset as ds
import os
import cv2
import mindspore
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.initializer import Normal
from mindspore import context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
# Seed NumPy's global random generator with a fixed value so that NumPy-based
# randomness is repeatable between runs.
# NOTE(review): this seeds NumPy only — confirm whether MindSpore's own seeds
# (e.g. mindspore.set_seed) should also be fixed for reproducible training.
np.random.seed(58)
class LeNet5(nn.Cell):
    """LeNet-5 style convolutional network.

    The feature extractor is two 5x5 ``'valid'`` convolutions (6 and then 16
    output channels), each followed by ReLU and a 2x2 / stride-2 max pooling.
    When ``include_top`` is True a classifier head is appended: flatten
    followed by three dense layers (784 -> 120 -> 84 -> ``num_class``) whose
    weights are initialised with ``Normal(0.02)``.

    The first dense layer expects 784 flattened features, e.g. a
    16 x 7 x 7 feature map, which is what a 40 x 40 input produces.

    Args:
        num_class (int): Number of output classes. Default: 10.
        num_channel (int): Number of input channels. Default: 3.
        include_top (bool): Build and apply the flatten + dense classifier
            head. When False the pooled feature map is returned so a custom
            head can be attached. Default: True.

    Returns:
        Tensor, the class logits when ``include_top`` is True, otherwise the
        feature map produced by the second pooling stage.

    Examples:
        >>> LeNet5(num_class=10)
    """

    def __init__(self, num_class=10, num_channel=3, include_top=True):
        super().__init__()
        # Feature extractor. The ReLU and pooling cells hold no parameters,
        # so a single instance of each is reused for both stages.
        self.conv1 = nn.Conv2d(in_channels=num_channel, out_channels=6,
                               kernel_size=5, pad_mode='valid')
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16,
                               kernel_size=5, pad_mode='valid')
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)

        self.include_top = include_top
        if include_top:
            # Classifier head; only created when requested, so with
            # include_top=False these attributes do not exist.
            self.flatten = nn.Flatten()
            self.fc1 = nn.Dense(in_channels=784, out_channels=120,
                                weight_init=Normal(0.02))
            self.fc2 = nn.Dense(in_channels=120, out_channels=84,
                                weight_init=Normal(0.02))
            self.fc3 = nn.Dense(in_channels=84, out_channels=num_class,
                                weight_init=Normal(0.02))

    def construct(self, x):
        """Run the forward pass on the input tensor ``x``."""
        # Stage 1 and stage 2: convolution -> ReLU -> max pooling.
        x = self.max_pool2d(self.relu(self.conv1(x)))
        x = self.max_pool2d(self.relu(self.conv2(x)))

        if self.include_top:
            x = self.flatten(x)
            x = self.relu(self.fc1(x))
            x = self.relu(self.fc2(x))
            # The last layer emits raw logits (no activation).
            x = self.fc3(x)
        return x
# Dataset source: reads the image files found in a directory and returns each
# sample as an (image data, label) tuple.
class TrainDatasetGenerator:
    """Random-access source of ``(image, label)`` samples.

    Every entry of ``file_path`` is treated as an image file whose name starts
    with a single digit giving its 1-based class; the returned label is that
    digit minus one (``'3_xxx.png'`` -> ``2``).

    Args:
        file_path (str): Directory that contains the image files.
    """

    def __init__(self, file_path):
        self.file_path = file_path
        self.img_names = os.listdir(file_path)

    @staticmethod
    def _label_from_name(img_name):
        """Return the zero-based class index encoded in the first character.

        Raises:
            ValueError: If ``img_name`` does not start with a digit.
        """
        try:
            return int(img_name[0]) - 1
        except ValueError as err:
            raise ValueError(
                'cannot derive a label from file name {!r}: it must start '
                'with a digit'.format(img_name)) from err

    def __getitem__(self, index):
        """Return ``(data, label)`` for the file at position ``index``.

        ``data`` is a float32 array scaled by 1/255 and ``label`` is a Python
        int.

        Raises:
            ValueError: If the file name does not start with a digit.
            OSError: If the file cannot be read or decoded as an image.
        """
        img_name = self.img_names[index]
        label = self._label_from_name(img_name)

        img_path = os.path.join(self.file_path, img_name)
        data = cv2.imread(img_path)
        if data is None:
            # cv2.imread reports failure by returning None instead of raising,
            # which would otherwise surface as an opaque AttributeError below.
            raise OSError('cannot read image file {!r}'.format(img_path))

        # transpose() without arguments reverses the axis order, so cv2's
        # (H, W, C) array becomes (C, W, H): channels first, with height and
        # width swapped.
        # NOTE(review): transpose(2, 0, 1) would give (C, H, W); the original
        # layout is kept because existing checkpoints / inference code
        # presumably rely on it — confirm before changing.
        data = data.transpose().astype(np.float32) / 255.
        return data, label

    def __len__(self):
        """Return the number of entries found in ``file_path``."""
        return len(self.img_names)
def train_lenet(train_dir='D:/code/machine vision course/digit-mindspore/dataset',
                valid_dir='D:/code/machine vision course/digit-mindspore/test',
                ckpt_dir='D:/code/machine vision course/digit-mindspore/ckpt/',
                num_class=7, batch_size=4, epoch_size=10, num_rounds=3):
    """Train LeNet5 on an image folder and print validation accuracy.

    Runs ``num_rounds`` rounds; each round trains for ``epoch_size`` epochs on
    the training set, then evaluates on the validation set and prints the
    metrics. Checkpoints are written to ``ckpt_dir`` every 10 steps, keeping
    at most 10 files. All defaults reproduce the previously hard-coded values,
    so calling ``train_lenet()`` without arguments behaves as before.

    Args:
        train_dir (str): Directory with the training images.
        valid_dir (str): Directory with the validation images.
        ckpt_dir (str): Directory the checkpoints are saved to.
        num_class (int): Number of classes the network predicts.
        batch_size (int): Batch size for both datasets.
        epoch_size (int): Epochs trained per round.
        num_rounds (int): Number of train-then-evaluate rounds.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')

    # Training data: shuffled, then grouped into full batches only.
    train_dataset_generator = TrainDatasetGenerator(train_dir)
    ds_train = ds.GeneratorDataset(train_dataset_generator, ['data', 'label'], shuffle=True)
    ds_train = ds_train.shuffle(buffer_size=10)
    ds_train = ds_train.batch(batch_size=batch_size, drop_remainder=True)

    # Validation data.
    # NOTE(review): drop_remainder=True discards a final partial batch, so up
    # to batch_size - 1 validation samples are not evaluated — confirm that
    # this is intended.
    valid_dataset_generator = TrainDatasetGenerator(valid_dir)
    ds_valid = ds.GeneratorDataset(valid_dataset_generator, ['data', 'label'], shuffle=True)
    ds_valid = ds_valid.batch(batch_size=batch_size, drop_remainder=True)

    network = LeNet5(num_class=num_class)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)

    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
    config_ck = CheckpointConfig(save_checkpoint_steps=10,
                                 keep_checkpoint_max=10)
    ckpoint_cb = ModelCheckpoint(prefix='checkpoint_lenet', directory=ckpt_dir, config=config_ck)

    model = Model(network, net_loss, net_opt, metrics={'Accuracy': Accuracy()})

    # The same model keeps training across rounds; the timing and checkpoint
    # callbacks are shared, while a new LossMonitor is created per round.
    for _ in range(num_rounds):
        print('============== Starting Training =============')
        model.train(epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()])
        acc = model.eval(ds_valid)
        print('============== {} ============='.format(acc))


# Run the training only when this file is executed as a script.
if __name__ == '__main__':
    train_lenet()
Adding Pretrained Model and Generating CKPT Files
To add a pretrained model and generate ckpt files, follow these steps:
- Define a Pretrained Model:
  - Create a class for your pretrained model (e.g., `MyPretrainedModel`).
  - Implement the model's architecture and load the pretrained weights.
- Load Weights into LeNet5:
  - After defining your `LeNet5` model, load the pretrained weights using `load_state_dict()` (in MindSpore releases that do not provide `load_state_dict()`, use `load_param_into_net()` instead).
- Initialize with Pretrained Weights:
  - Optionally, load the pretrained ckpt file using `load_checkpoint()` to initialize the `LeNet5` model during training.
- Create Checkpoint Callbacks:
  - Use `ModelCheckpoint` callbacks during training to save the model's weights at regular intervals.
- Save Final Weights:
  - After training is complete, use `save_checkpoint()` to save the final trained weights.
Example (Pseudocode):
# Illustrative pseudocode only: it is not runnable as written. The class body
# below is a placeholder, and net_loss, net_opt, epoch_size, ds_train and
# time_cb are used without being defined in this snippet.
from mindspore.train.serialization import load_checkpoint, save_checkpoint
# Define Pretrained Model
class MyPretrainedModel(nn.Cell):
# ... (Implement model architecture and load weights) ...
# Load pretrained weights into LeNet5
pretrained_model = MyPretrainedModel()
network = LeNet5(num_class=7)
# NOTE(review): load_state_dict() / state_dict() are PyTorch-style names —
# confirm the installed MindSpore exposes them on nn.Cell; otherwise use
# load_param_into_net(network, pretrained_model.parameters_dict()).
network.load_state_dict(pretrained_model.state_dict())
# ... (Rest of your training code) ...
# Load pretrained weights from ckpt file (optional)
ckpt_path = 'pretrained_model.ckpt'
# Only attempt the load when the checkpoint file is actually present.
if os.path.exists(ckpt_path):
load_checkpoint(ckpt_path, network)
# Define checkpoint callbacks
config_ck = CheckpointConfig(save_checkpoint_steps=10, keep_checkpoint_max=10)
ckpt_cb = ModelCheckpoint(prefix='lenet', directory='./ckpt', config=config_ck)
# Train with checkpoints
model = Model(network, net_loss, net_opt, metrics={'Accuracy': Accuracy()})
model.train(epoch_size, ds_train, callbacks=[time_cb, ckpt_cb, LossMonitor()])
# Save final weights
save_checkpoint(network, 'lenet_final.ckpt')
Key Points
- Replace `MyPretrainedModel` with your actual pretrained model class.
- Ensure the pretrained weights are compatible with the `LeNet5` architecture.
- The specific implementation for loading and saving weights might vary depending on your pretrained model framework.
- Adjust the checkpoint configuration (`CheckpointConfig`) as needed to control how many checkpoints are saved and how long they are kept.
原文地址: https://www.cveoy.top/t/topic/mQE0 著作权归作者所有。请勿转载和采集!