MindSpore LeNet5 数字识别模型训练
import numpy as np
import mindspore.dataset as ds
import os
import cv2
import mindspore
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.initializer import Normal
from mindspore import context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
np.random.seed(58)
class LeNet5(nn.Cell):
'Lenet network
Args:
num_class (int): Number of classes. Default: 10.
num_channel (int): Number of channels. Default: 1.
Returns:
Tensor, output tensor
Examples:
>>> LeNet(num_class=10)
'
def __init__(self, num_class=10, num_channel=3, include_top=True):
super(LeNet5, self).__init__()
self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.include_top = include_top
#如果include_top为True,则在模型的最后添加全连接层,
#其中包括三个Dense层,分别是784->120、120->84、84->num_class
# 每个Dense层采用正态分布初始化权重。同时,还定义了一个Flatten层,用于将输入数据压平。
if self.include_top:#include_top是一个布尔值
#如果include_top为True,则模型的顶部将包括全局平均池化层和分类器,用于进行分类
#如果include_top为False,则模型的顶部将被删除,并且可以添加自定义的分类器。
self.flatten = nn.Flatten()
self.fc1 = nn.Dense(784, 120, weight_init=Normal(0.02))
self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
def construct(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
if not self.include_top:
return x
x = self.flatten(x)
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)
return x
#生成训练数据集,从指定路径读取图像文件,并以元组形式返回图像数据和标签
class TrainDatasetGenerator:
def __init__(self, file_path):
self.file_path = file_path
self.img_names = os.listdir(file_path)
def __getitem__(self, index):
data = cv2.imread(os.path.join(self.file_path, self.img_names[index]))
label = int(self.img_names[index][0])-1#图像对应的类别
#label = np.array([label])
data = data.transpose().astype(np.float32) / 255.
#data = np.expand_dims(data, axis=0)
#data = Tensor(data)
#label = Tensor(label)
return data, label
def __len__(self):
return len(self.img_names)
def train_lenet():
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
train_dataset_generator = TrainDatasetGenerator('D:/code/machine vision course/digit-mindspore/dataset')
ds_train = ds.GeneratorDataset(train_dataset_generator, ['data', 'label'], shuffle=True)
ds_train = ds_train.shuffle(buffer_size=10)
ds_train = ds_train.batch(batch_size=4, drop_remainder=True)
valid_dataset_generator = TrainDatasetGenerator('D:/code/machine vision course/digit-mindspore/test')
ds_valid = ds.GeneratorDataset(valid_dataset_generator, ['data', 'label'], shuffle=True)
ds_valid = ds_valid.batch(batch_size=4, drop_remainder=True)
network = LeNet5(num_class=7)
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
config_ck = CheckpointConfig(save_checkpoint_steps=10,
keep_checkpoint_max=10)
config_ckpt_path = 'D:/code/machine vision course/digit-mindspore/ckpt/'
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_lenet', directory=config_ckpt_path, config=config_ck)
model = Model(network, net_loss, net_opt, metrics={'Accuracy': Accuracy()})
epoch_size = 10
print('============== Starting Training =============')
model.train(epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()])
acc = model.eval(ds_valid)
print('============== {} ============='.format(acc))
epoch_size = 10
print('============== Starting Training =============')
model.train(epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()])
acc = model.eval(ds_valid)
print('============== {} ============='.format(acc))
epoch_size = 10
print('============== Starting Training =============')
model.train(epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()])
acc = model.eval(ds_valid)
print('============== {} ============='.format(acc))
if __name__ == '__main__':
train_lenet()
该代码使用 MindSpore 框架训练 LeNet5 神经网络模型进行数字识别。
代码功能:
- 定义 LeNet5 网络结构:
- 使用
mindspore.nn模块定义卷积层、池化层、激活函数和全连接层。 - 根据
include_top参数控制是否添加全连接层。
- 使用
- 创建训练和验证数据集:
- 使用
mindspore.dataset.GeneratorDataset从自定义的数据生成器中创建数据集。 - 定义
TrainDatasetGenerator类,从指定的文件夹读取图像数据和标签。
- 使用
- 训练模型:
- 使用
mindspore.train.Model类创建模型,并指定损失函数、优化器和评估指标。 - 使用
mindspore.train.callback模块设置回调函数,用于记录训练过程、保存模型和评估性能。
- 使用
- 评估模型:
- 使用
model.eval()方法评估模型在验证数据集上的准确率。
- 使用
注意:
- 代码中使用了
'D:/code/machine vision course/digit-mindspore/dataset'和'D:/code/machine vision course/digit-mindspore/test'路径,需要根据实际情况修改。 - 代码中
num_class=7表示要识别的数字类别数为 7,可以根据需要修改。 - 训练和评估过程中,需要保证数据集中包含所有数字类别,且每个类别的数据量足够。
- 可以根据需要调整训练参数,例如学习率、batch 大小等。
ckpt 文件:
训练过程中生成的 ckpt 文件保存在 'D:/code/machine vision course/digit-mindspore/ckpt/' 文件夹中。
如何加载 ckpt 文件:
from mindspore.train.serialization import load_checkpoint, load_param_into_net
network = LeNet5(num_class=7)
param_dict = load_checkpoint('D:/code/machine vision course/digit-mindspore/ckpt/checkpoint_lenet-1_10.ckpt')
load_param_into_net(network, param_dict)
可以使用 load_checkpoint 函数加载 ckpt 文件中的参数字典,并使用 load_param_into_net 函数将参数字典加载到网络中。
总结:
该代码展示了如何使用 MindSpore 框架训练 LeNet5 模型进行数字识别,并提供了一些重要参数和使用方法的解释。使用者可以根据自身需求进行调整和修改。
原文地址: https://www.cveoy.top/t/topic/mQFj 著作权归作者所有。请勿转载和采集!