import numpy as np
import mindspore.dataset as ds
import os
import cv2
import mindspore
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.initializer import Normal
from mindspore import context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
# Fix the NumPy RNG seed so data shuffling / weight init that relies on
# NumPy randomness is reproducible across runs.
np.random.seed(58)


class LeNet5(nn.Cell):
    """LeNet-5 convolutional network.

    Note: the original docstring was delimited with single quotes spanning
    multiple lines, which is a Python syntax error; it is now a proper
    triple-quoted docstring, and the documented ``num_channel`` default
    matches the code (3, not 1).

    Args:
        num_class (int): Number of output classes. Default: 10.
        num_channel (int): Number of input image channels. Default: 3.
        include_top (bool): If True, append the flatten + dense classifier
            head; if False, ``construct`` returns the raw conv feature maps.
            Default: True.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> LeNet5(num_class=10)
    """

    def __init__(self, num_class=10, num_channel=3, include_top=True):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.include_top = include_top
        # When include_top is True the classifier head is built: a Flatten
        # layer followed by three Dense layers (784->120, 120->84,
        # 84->num_class), each with Normal(0.02) weight init.
        # NOTE(review): fc1 expects 784 flattened features (16 * 7 * 7),
        # which fixes the expected input image size — confirm against the
        # actual dataset image dimensions.
        if self.include_top:
            self.flatten = nn.Flatten()
            self.fc1 = nn.Dense(784, 120, weight_init=Normal(0.02))
            self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
            self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))

    def construct(self, x):
        """Forward pass: two conv+relu+pool stages, then optional classifier."""
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        if not self.include_top:
            return x
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Training dataset generator: reads image files from the given directory
# and returns (image data, label) tuples.
class TrainDatasetGenerator:
    """Dataset generator yielding (image, label) pairs from a directory.

    Each file in ``file_path`` is one sample; the first character of the
    file name encodes the 1-based class id, so returned labels are 0-based.
    """

    def __init__(self, file_path):
        self.file_path = file_path
        self.img_names = os.listdir(file_path)

    def __getitem__(self, index):
        name = self.img_names[index]
        # First character of the file name is the 1-based class id.
        label = int(name[0]) - 1
        image = cv2.imread(os.path.join(self.file_path, name))
        # transpose() reverses all axes (HWC -> CWH) and pixels are scaled
        # into [0, 1].  NOTE(review): CHW is the usual channels-first layout;
        # this is equivalent only for square images — confirm intent.
        image = image.transpose().astype(np.float32) / 255.
        return image, label

    def __len__(self):
        return len(self.img_names)


def train_lenet():
    """Train LeNet-5 on an on-disk digit dataset and evaluate it.

    Runs three consecutive rounds, each consisting of 10 training epochs
    followed by an evaluation pass whose accuracy is printed.  Checkpoints
    are written every 10 steps to the configured directory.  The original
    code copy-pasted the train/eval pair three times; it is now a loop
    performing the identical sequence of operations.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')

    # Training pipeline: generator-level shuffle plus an explicit shuffle
    # buffer, batched by 4 with the last partial batch dropped.
    train_dataset_generator = TrainDatasetGenerator('D:/code/machine vision course/digit-mindspore/dataset')
    ds_train = ds.GeneratorDataset(train_dataset_generator, ['data', 'label'], shuffle=True)
    ds_train = ds_train.shuffle(buffer_size=10)
    ds_train = ds_train.batch(batch_size=4, drop_remainder=True)

    # Validation pipeline (no extra shuffle buffer needed).
    valid_dataset_generator = TrainDatasetGenerator('D:/code/machine vision course/digit-mindspore/test')
    ds_valid = ds.GeneratorDataset(valid_dataset_generator, ['data', 'label'], shuffle=True)
    ds_valid = ds_valid.batch(batch_size=4, drop_remainder=True)

    # 7-class LeNet with sparse softmax cross-entropy and SGD-momentum.
    network = LeNet5(num_class=7)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)

    # Callbacks: per-epoch timing and periodic checkpointing.
    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
    config_ck = CheckpointConfig(save_checkpoint_steps=10,
                                 keep_checkpoint_max=10)
    config_ckpt_path = 'D:/code/machine vision course/digit-mindspore/ckpt/'
    ckpoint_cb = ModelCheckpoint(prefix='checkpoint_lenet', directory=config_ckpt_path, config=config_ck)

    model = Model(network, net_loss, net_opt, metrics={'Accuracy': Accuracy()})

    epoch_size = 10
    for _ in range(3):
        print('============== Starting Training =============')
        model.train(epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()])
        acc = model.eval(ds_valid)
        print('============== {} ============='.format(acc))

# Script entry point: train only when executed directly, not on import.
if __name__ == '__main__':
    train_lenet()

Adding Pretrained Model and Generating CKPT Files

To add a pretrained model and generate ckpt files, follow these steps:

  1. Define a Pretrained Model:

    • Create a class for your pretrained model (e.g., MyPretrainedModel).
    • Implement the model's architecture and load the pretrained weights.
  2. Load Weights into LeNet5:

    • After defining your LeNet5 model, load the pretrained parameters with MindSpore's load_checkpoint() / load_param_into_net() (note: load_state_dict() is PyTorch API and does not exist on a MindSpore nn.Cell).
  3. Initialize with Pretrained Weights:

    • Optionally, load the pretrained ckpt file using load_checkpoint() to initialize the LeNet5 model during training.
  4. Create Checkpoint Callbacks:

    • Use ModelCheckpoint callbacks during training to save the model's weights at regular intervals.
  5. Save Final Weights:

    • After training is complete, use save_checkpoint() to save the final trained weights.

Example (Pseudocode):

from mindspore.train.serialization import load_checkpoint, load_param_into_net, save_checkpoint

# Define Pretrained Model
class MyPretrainedModel(nn.Cell):
    # ... (Implement model architecture and load weights) ...

# Load pretrained weights into LeNet5 (MindSpore API: a Cell exposes
# parameters_dict(), and load_param_into_net copies a parameter dict in)
pretrained_model = MyPretrainedModel()
network = LeNet5(num_class=7)
load_param_into_net(network, pretrained_model.parameters_dict())

# ... (Rest of your training code) ...

# Load pretrained weights from ckpt file (optional)
ckpt_path = 'pretrained_model.ckpt'
if os.path.exists(ckpt_path):
    load_checkpoint(ckpt_path, network)

# Define checkpoint callbacks
config_ck = CheckpointConfig(save_checkpoint_steps=10, keep_checkpoint_max=10)
ckpt_cb = ModelCheckpoint(prefix='lenet', directory='./ckpt', config=config_ck)

# Train with checkpoints
model = Model(network, net_loss, net_opt, metrics={'Accuracy': Accuracy()})
model.train(epoch_size, ds_train, callbacks=[time_cb, ckpt_cb, LossMonitor()])

# Save final weights
save_checkpoint(network, 'lenet_final.ckpt')

Key Points

  • Replace MyPretrainedModel with your actual pretrained model class.
  • Ensure the pretrained weights are compatible with the LeNet5 architecture.
  • The specific implementation for loading and saving weights might vary depending on your pretrained model framework.
  • Adjust the checkpoint configuration (CheckpointConfig) as needed to control how many checkpoints are saved and how long they are kept.
MindSpore LeNet5 Implementation for Image Classification

原文地址: https://www.cveoy.top/t/topic/mQE0 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录