MobileNetV2 微调训练 - 使用 PyTorch 进行图像分类
import os
import sys
import json

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm

from model_v2 import MobileNetV2
def main():
    """Fine-tune a pretrained MobileNetV2 on an image-folder dataset.

    Steps performed:
      1. Load the ``train`` and ``val`` splits found under
         ``<cwd>/../../data_set/flower_data`` with ``datasets.ImageFolder``.
      2. Write the index -> class-name mapping to ``class_indices.json``.
      3. Build ``MobileNetV2`` with one output per dataset class and load
         the pretrained tensors from ``./mobilenet_v2.pth`` whose shapes
         match the model (the classifier head is skipped when its shape
         differs).
      4. Freeze ``net.features`` and train the remaining parameters with
         Adam and cross-entropy loss.
      5. After every epoch, evaluate on the validation split and save the
         weights to ``./MobileNetV2.pth`` whenever accuracy improves.

    Raises:
        FileNotFoundError: if the dataset directory or the pretrained
            weight file does not exist.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print('using {} device.'.format(device))

    batch_size = 16
    epochs = 5

    # Per-channel mean / std passed to Normalize for both splits.
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]
    data_transform = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(norm_mean, norm_std)]),
        'val': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(norm_mean, norm_std)])}

    # The dataset is expected two directory levels above the working directory.
    data_root = os.path.abspath(os.path.join(os.getcwd(), '../..'))
    image_path = os.path.join(data_root, 'data_set', 'flower_data')
    # Explicit raise instead of assert: asserts are stripped under ``python -O``.
    if not os.path.exists(image_path):
        raise FileNotFoundError('{} path does not exist.'.format(image_path))

    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, 'train'),
                                         transform=data_transform['train'])
    train_num = len(train_dataset)

    # class_to_idx maps class name -> index; invert it so that predictions
    # can be turned back into names, and persist the mapping as JSON.
    flower_list = train_dataset.class_to_idx
    cla_dict = {idx: name for name, idx in flower_list.items()}
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    # os.cpu_count() may return None when the count cannot be determined;
    # fall back to 1 so that min() never compares None with an int.
    cpu_total = os.cpu_count() or 1
    nw = min([cpu_total, batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, 'val'),
                                            transform=data_transform['val'])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)

    print('using {} images for training, {} images for validation.'.format(train_num,
                                                                           val_num))

    # Size the classifier from the dataset instead of hard-coding the class
    # count, so the same script works for datasets with any number of classes.
    net = MobileNetV2(num_classes=len(flower_list))

    # Load pretrained weights.
    # download url: https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
    model_weight_path = './mobilenet_v2.pth'
    if not os.path.exists(model_weight_path):
        raise FileNotFoundError('file {} dose not exist.'.format(model_weight_path))
    # NOTE: torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    pre_weights = torch.load(model_weight_path, map_location='cpu')

    # Keep only tensors that exist in this model with an identical shape.
    # This drops the pretrained classifier head when the class count differs
    # and tolerates checkpoint keys the model does not have.
    model_state = net.state_dict()
    pre_dict = {k: v for k, v in pre_weights.items()
                if k in model_state and model_state[k].shape == v.shape}
    net.load_state_dict(pre_dict, strict=False)

    # Freeze the feature extractor; only the remaining parameters are trained.
    for param in net.features.parameters():
        param.requires_grad = False

    net.to(device)

    loss_function = nn.CrossEntropyLoss()

    # Hand only the trainable parameters to the optimizer.
    params = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.0001)

    best_acc = 0.0
    save_path = './MobileNetV2.pth'
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # ---- train ----
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for images, labels in train_bar:
            optimizer.zero_grad()
            logits = net(images.to(device))
            loss = loss_function(logits, labels.to(device))
            loss.backward()
            optimizer.step()

            # Convert to a Python float once; reuse it for the running sum
            # and the progress-bar text.
            loss_value = loss.item()
            running_loss += loss_value

            train_bar.desc = 'train epoch[{}/{}] loss:{:.3f}'.format(epoch + 1,
                                                                     epochs,
                                                                     loss_value)

        # ---- validate ----
        net.eval()
        acc = 0.0  # number of correct predictions accumulated over the epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_images, val_labels in val_bar:
                outputs = net(val_images.to(device))
                # Index of the highest logit is the predicted class.
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

                val_bar.desc = 'valid epoch[{}/{}]'.format(epoch + 1,
                                                           epochs)

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        # Checkpoint only when validation accuracy improves.
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')
if __name__ == '__main__':
    main()

解释代码内容:该代码是一个用于训练MobileNetV2模型的脚本。它使用PyTorch库加载预训练的MobileNetV2模型,并在自定义的数据集上进行微调训练。
脚本首先设置了设备(GPU或CPU)和一些超参数,如批大小和训练轮数。然后定义了数据的预处理步骤,并加载训练和验证数据集。
接下来创建了MobileNetV2模型,并加载了预训练的权重。然后冻结了模型的特征层的权重,只训练分类器层的权重。
定义了损失函数和优化器,并开始训练。在每个epoch中,通过迭代训练数据集计算损失并进行梯度更新。然后在验证集上计算准确率。
每个epoch结束后,如果验证准确率超过此前的最佳值,就将当前模型权重保存到文件中,因此最终保存的是验证集上准确率最高的模型。
请注意,模型训练过程中使用了tqdm库来显示训练进度条。在每个epoch中,会打印出训练损失和验证准确率。
这段代码的主要目的是展示如何使用MobileNetV2模型进行图像分类任务的微调训练,并保存训练得到的模型。
原文地址: http://www.cveoy.top/t/topic/i3BH 著作权归作者所有。请勿转载和采集!