ResNet模型训练代码:解决矩阵乘法维度不匹配问题
ResNet模型训练代码:解决矩阵乘法维度不匹配问题
本文提供了一个ResNet模型的训练代码,并解决了在训练过程中出现的矩阵乘法维度不匹配问题。代码使用MindSpore框架,并详细介绍了模型结构、数据集生成、训练过程以及如何解决错误。
import os
import cv2
import numpy as np
import mindspore.dataset as ds
import mindspore.nn as nn
from mindspore import context, Tensor, Model
from mindspore.common.initializer import Normal
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.nn.metrics import Accuracy
from mindspore.ops.operations import TensorAdd
from mindspore.train.serialization import load_checkpoint, load_param_into_net
# Seed NumPy's global random number generator so NumPy-based randomness is repeatable between runs.
np.random.seed(58)
class BasicBlock(nn.Cell):
    """Residual block made of two 3x3 convolutions plus a shortcut branch.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by both convolutions.
        stride: Stride of the first convolution (the second always uses 1).
        downsample: Optional cell applied to the block input before it is
            added to the convolution branch. ``None`` means the input is
            added unchanged.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        # Settings shared by both 3x3 convolutions.
        conv_kwargs = dict(kernel_size=3, padding=1, pad_mode='pad', has_bias=False)
        self.conv1 = nn.Conv2d(in_channels, out_channels, stride=stride, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, stride=1, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample
        self.add = TensorAdd()

    def construct(self, x):
        """Return ``relu(conv_branch(x) + shortcut(x))``."""
        # Convolution branch: conv -> bn -> relu -> conv -> bn.
        branch = self.bn1(self.conv1(x))
        branch = self.relu(branch)
        branch = self.bn2(self.conv2(branch))

        # Shortcut branch: the input itself unless a downsample cell was given.
        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        merged = self.add(branch, shortcut)
        return self.relu(merged)
class ResNet(nn.Cell):
    """ResNet-style image classifier assembled from four stages of ``block``.

    Args:
        block: Cell class used for every residual unit; it is called as
            ``block(in_channels, out_channels, stride, downsample)``.
        layers: Sequence of four ints, the number of blocks in each stage.
        num_classes: Size of the final dense layer's output.
    """

    def __init__(self, block, layers, num_classes=10):
        super().__init__()
        # Running channel count; make_layer() reads and updates it.
        self.in_channels = 64

        # Stem: 7x7 stride-2 convolution followed by a stride-2 max pool.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               pad_mode='pad', has_bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')

        # Four residual stages; every stage after the first uses stride 2.
        self.layer1 = self.make_layer(block, 64, layers[0])
        self.layer2 = self.make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self.make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self.make_layer(block, 512, layers[3], stride=2)

        # Head: 7x7 average pool, flatten, dense classifier.
        self.avgpool = nn.AvgPool2d(kernel_size=7, stride=1)
        self.flatten = nn.Flatten()
        # The dense layer takes 512 input features; its output size is num_classes.
        self.fc = nn.Dense(512, num_classes)

    def make_layer(self, block, out_channels, blocks, stride=1):
        """Build one stage of ``blocks`` residual units.

        Only the first unit receives ``stride`` and, when the stride is not 1
        or the channel count changes, a 1x1 convolution + batch-norm shortcut.
        ``self.in_channels`` is set to ``out_channels`` afterwards.
        """
        needs_projection = stride != 1 or self.in_channels != out_channels
        shortcut = None
        if needs_projection:
            shortcut = nn.SequentialCell([
                nn.Conv2d(self.in_channels, out_channels, kernel_size=1,
                          stride=stride, has_bias=False),
                nn.BatchNorm2d(out_channels),
            ])

        stage = [block(self.in_channels, out_channels, stride, shortcut)]
        self.in_channels = out_channels
        stage.extend(block(out_channels, out_channels) for _ in range(1, blocks))
        return nn.SequentialCell(stage)

    def construct(self, x):
        """Run the stem, the four stages and the classification head on ``x``."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        x = self.layer2(self.layer1(x))
        x = self.layer4(self.layer3(x))
        x = self.flatten(self.avgpool(x))
        return self.fc(x)
class TrainDatasetGenerator:
    """Random-access image source that yields ``(image, label)`` pairs.

    Every regular file in ``file_path`` is treated as one sample. The integer
    label is parsed from the part of the file name before the first ``_``
    (for example ``3_cat.png`` has label 3).

    Args:
        file_path: Directory that contains the image files.
    """

    def __init__(self, file_path):
        self.file_path = file_path
        # Sort so that index -> file is the same on every run (os.listdir
        # order is arbitrary) and skip sub-directories, which are not images.
        self.img_names = sorted(
            name for name in os.listdir(file_path)
            if os.path.isfile(os.path.join(file_path, name))
        )

    def __getitem__(self, index):
        """Load sample ``index``.

        Returns:
            A tuple ``(data, label)``: ``data`` is a float32 array of shape
            (3, 224, 224) in RGB order scaled to [0, 1]; ``label`` is an int.

        Raises:
            IndexError: If ``index`` is out of range.
            ValueError: If the file cannot be decoded as an image or its name
                does not start with an integer label.
        """
        img_name = self.img_names[index]
        img_path = os.path.join(self.file_path, img_name)

        data = cv2.imread(img_path)
        if data is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise ValueError('Failed to read image: {}'.format(img_path))

        label = int(img_name.split('_')[0])

        data = cv2.cvtColor(data, cv2.COLOR_BGR2RGB)
        data = cv2.resize(data, (224, 224))
        # HWC -> CHW. A bare transpose() would reverse all axes (giving CWH)
        # and therefore swap the image's rows and columns.
        data = data.transpose(2, 0, 1).astype(np.float32) / 255.
        return data, label

    def __len__(self):
        """Return the number of samples found in the directory."""
        return len(self.img_names)
def train_resnet(train_dir='D:/pythonProject7/train1',
                 valid_dir='D:/pythonProject7/test1',
                 ckpt_dir='D:/pythonProject7/ckpt/',
                 num_classes=100, epoch_size=10, batch_size=4):
    """Train a ResNet and report the checkpoint with the best validation accuracy.

    Args:
        train_dir: Directory with the training images.
        valid_dir: Directory with the validation images.
        ckpt_dir: Directory where checkpoints are written and later read back.
        num_classes: Output size of the network's final dense layer.
            NOTE(review): labels are parsed from file names, so this should be
            larger than the largest label present — confirm against the data.
        epoch_size: Number of training epochs.
        batch_size: Samples per batch for both training and validation.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')

    # shuffle=True on the source dataset already randomizes the sample order,
    # so no additional shuffle() stage is applied.
    ds_train = ds.GeneratorDataset(TrainDatasetGenerator(train_dir),
                                   ['data', 'label'], shuffle=True)
    ds_train = ds_train.batch(batch_size=batch_size, drop_remainder=True)

    # Validation data is NOT shuffled: together with drop_remainder=True a
    # shuffled order would drop different samples on every eval, so the
    # checkpoints compared below would be scored on different subsets.
    ds_valid = ds.GeneratorDataset(TrainDatasetGenerator(valid_dir),
                                   ['data', 'label'], shuffle=False)
    ds_valid = ds_valid.batch(batch_size=batch_size, drop_remainder=True)

    network = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.001, momentum=0.9)

    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
    config_ck = CheckpointConfig(save_checkpoint_steps=10, keep_checkpoint_max=10)
    ckpoint_cb = ModelCheckpoint(prefix='checkpoint_resnet', directory=ckpt_dir,
                                 config=config_ck)

    model = Model(network, net_loss, net_opt, metrics={'Accuracy': Accuracy()})

    print('============== Starting Training =============')
    model.train(epoch_size, ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()])

    # Evaluate every checkpoint found in ckpt_dir and keep the most accurate
    # one. Sorting makes the evaluation order (and tie-breaking) deterministic.
    # NOTE(review): checkpoints left over from earlier runs in the same
    # directory are evaluated as well.
    best_accuracy = 0
    best_ckpt_file = None
    for ckpt_file in sorted(os.listdir(ckpt_dir)):
        if not ckpt_file.endswith('.ckpt'):
            continue
        param_dict = load_checkpoint(os.path.join(ckpt_dir, ckpt_file))
        load_param_into_net(network, param_dict)
        acc = model.eval(ds_valid)
        if acc['Accuracy'] > best_accuracy:
            best_accuracy = acc['Accuracy']
            best_ckpt_file = ckpt_file

    print('Best ckpt file: {}'.format(best_ckpt_file))
# Start training only when this file is executed as a script.
if __name__ == "__main__":
    train_resnet()
错误原因分析
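根据错误提示,问题出在全连接层(fc)的矩阵乘法上:参与运算的两个张量维度不匹配。具体来说,输入x的shape为[4, 512](batch大小为4,每个样本有512个特征),而fc层的权重参数y的shape为[100, 10]。nn.Dense的权重形状为[输出维度, 输入维度],计算时要求x的最后一个维度(512)与权重的第二个维度相等;这里权重的第二个维度是10而不是512,两者不一致,因此无法进行矩阵乘法。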
解决方法
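解决方法是修改ResNet类中fc层的定义:输入维度固定为512(与flatten之后的特征数一致),输出维度由num_classes决定(默认值为10),即nn.Dense(512, num_classes)。注意:上面的训练脚本在创建网络时传入了num_classes=100,此时fc层的输出大小为100,请根据数据集的实际类别数设置该参数。修改后的代码如下: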
class ResNet(nn.Cell):
    def __init__(self, block, layers, num_classes=10):
        super().__init__()
        # Running channel count used by make_layer().
        self.in_channels = 64

        # Stem.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               pad_mode='pad', has_bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')

        # Residual stages.
        self.layer1 = self.make_layer(block, 64, layers[0])
        self.layer2 = self.make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self.make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self.make_layer(block, 512, layers[3], stride=2)

        # Head.
        self.avgpool = nn.AvgPool2d(kernel_size=7, stride=1)
        self.flatten = nn.Flatten()
        # 512 input features; the output size is num_classes (10 by default).
        self.fc = nn.Dense(512, num_classes)
    # ... (remaining code unchanged)
通过修改fc层的输出大小,确保输入和权重参数的维度匹配,从而解决矩阵乘法维度不匹配的问题。
总结
本文提供了一个ResNet模型训练代码,并详细介绍了如何解决训练过程中出现的矩阵乘法维度不匹配问题。通过修改fc层的输出大小,可以确保输入和权重参数的维度匹配,从而使模型训练正常进行。希望本文可以帮助您理解ResNet模型的训练过程以及如何解决常见问题。
原文地址: https://www.cveoy.top/t/topic/jqA7 著作权归作者所有。请勿转载和采集!