MindSpore LeNet-5 模型训练示例
以下示例代码使用 MindSpore 框架,基于 LeNet-5 模型实现数字识别任务的训练过程,包括数据加载、模型构建、训练和评估等步骤:
def train_lenet():
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
train_dataset_generator = TrainDatasetGenerator('D:/code/machine vision course/digit-mindspore/dataset')
ds_train = ds.GeneratorDataset(train_dataset_generator, ['data', 'label'], shuffle=True)
ds_train = ds_train.shuffle(buffer_size=10)
ds_train = ds_train.batch(batch_size=4, drop_remainder=True)
valid_dataset_generator = TrainDatasetGenerator('D:/code/machine vision course/digit-mindspore/test')
ds_valid = ds.GeneratorDataset(valid_dataset_generator, ['data', 'label'], shuffle=True)
ds_valid = ds_valid.batch(batch_size=4, drop_remainder=True)
network = LeNet5(num_class=7)
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
config_ck = CheckpointConfig(save_checkpoint_steps=10,
keep_checkpoint_max=10)
config_ckpt_path = 'D:/code/machine vision course/digit-mindspore/ckpt/'
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_lenet', directory=config_ckpt_path, config=config_ck)
model = Model(network, net_loss, net_opt, metrics={'Accuracy': Accuracy()})
epoch_size = 10
print('============== Starting Training =============')
model.train(epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()])
acc = model.eval(ds_valid)
print('============== {} ============='.format(acc))
epoch_size = 10
print('============== Starting Training =============')
model.train(epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()])
acc = model.eval(ds_valid)
print('============== {} ============='.format(acc))
epoch_size = 10
print('============== Starting Training =============')
model.train(epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()])
acc = model.eval(ds_valid)
print('============== {} ============='.format(acc))
注意: LeNet-5 模型并不适合用于人脸识别任务。人脸识别需要使用更加复杂的卷积神经网络模型,例如 ResNet、VGG 等。同时,数据集也需要使用人脸数据集,例如 LFW、CASIA 等。建议您了解相关的人脸识别模型和数据集,并进行相关的实现。
原文地址: https://www.cveoy.top/t/topic/mQeT 著作权归作者所有。请勿转载和采集!