To add a dropout layer, modify the code where the model is defined. Here is the modified code:

import torch
from net import simpleconv
from torchvision import transforms, datasets
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from tensorboardX import SummaryWriter
import time
import os


def train(train_loader, model, criterion, optimizer, device, len_train):
    num_loss = 0.0
    num_corrects = 0.0
    model.train()  # training mode: dropout is active

    for i, (data, target) in enumerate(train_loader):
        data = data.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        outputs = model(data)
        _, preds = torch.max(outputs, 1)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        num_corrects += torch.sum(preds == target).item()
        num_loss = num_loss + loss.item()

    train_loss = num_loss / len(train_loader)  # average loss over the actual number of batches
    train_acc = num_corrects / len_train

    return train_loss, train_acc


def val(val_loader, model, criterion, device, len_val):
    num_loss = 0.0
    num_corrects = 0.0
    model.eval()  # eval mode: dropout is disabled

    with torch.no_grad():
        for i, (data, target) in enumerate(val_loader):
            data = data.to(device)
            target = target.to(device)

            outputs = model(data)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, target)

            num_corrects += torch.sum(preds == target).item()
            num_loss = num_loss + loss.item()

    val_loss = num_loss / len(val_loader)  # average loss over the actual number of batches
    val_acc = num_corrects / len_val

    return val_loss, val_acc


batch_size = 100
nclass = 13
num_epochs = 364
data_dir = './testimage'

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model = simpleconv(nclass, dropout_rate=0.5).to(device)  # add a dropout layer with dropout_rate=0.5

train_transforms = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.RandomResizedCrop(48),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

val_transforms = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.CenterCrop(48),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

train_datasets = datasets.ImageFolder(os.path.join(data_dir, 'train'), train_transforms)
val_datasets = datasets.ImageFolder(os.path.join(data_dir, 'test'), val_transforms)
len_train = len(train_datasets)
len_val = len(val_datasets)

train_loaders = torch.utils.data.DataLoader(
    dataset=train_datasets,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0
)

val_loaders = torch.utils.data.DataLoader(
    dataset=val_datasets,
    batch_size=batch_size,
    shuffle=False,  # no need to shuffle the validation set
    num_workers=0
)

criterion = nn.CrossEntropyLoss()
optimizer_ft = optim.SGD(model.parameters(), lr=0.1, momentum=0.8)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=200, gamma=0.1)

writer = SummaryWriter("runs")
best_acc = 0.0

for epoch in range(num_epochs):

    start = time.time()

    train_loss, train_acc = train(train_loaders, model, criterion, optimizer_ft, device, len_train)

    exp_lr_scheduler.step()
    val_loss, val_acc = val(val_loaders, model, criterion, device, len_val)

    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), './best.pth')
    writer.add_scalar('trainloss', train_loss, epoch)
    writer.add_scalar('trainacc', train_acc, epoch)
    writer.add_scalar('valloss', val_loss, epoch)
    writer.add_scalar('valacc', val_acc, epoch)
    end = time.time()
    print('[{}/{}]: train_loss: {:.3f}, train_acc: {:.3f}, val_loss: {:.3f}, val_acc: {:.3f}, time: {:.3f}s'.format(
        epoch + 1, num_epochs, train_loss, train_acc, val_loss, val_acc, end - start))
writer.close()

Where the model is defined, the initialization is changed to model = simpleconv(nclass, dropout_rate=0.5).to(device), which adds a dropout layer with dropout_rate set to 0.5. The value can be adjusted as needed; the rest of the code stays unchanged. A sketch of what the corresponding model definition might look like is shown below.
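Note that this only works if the simpleconv constructor in net.py actually accepts a dropout_rate argument. Since net.py is not shown here, the following is a minimal sketch of what that definition might look like; the backbone and layer sizes are illustrative assumptions, and only the nn.Dropout placement is the point:

import torch
import torch.nn as nn
import torch.nn.functional as F


class simpleconv(nn.Module):
    def __init__(self, nclass, dropout_rate=0.5):
        super().__init__()
        # hypothetical backbone; the transforms above produce 3x48x48 inputs
        self.conv1 = nn.Conv2d(3, 16, 3, stride=2, padding=1)   # -> 16x24x24
        self.conv2 = nn.Conv2d(16, 32, 3, stride=2, padding=1)  # -> 32x12x12
        self.conv3 = nn.Conv2d(32, 64, 3, stride=2, padding=1)  # -> 64x6x6
        self.dropout = nn.Dropout(p=dropout_rate)  # the added dropout layer
        self.fc = nn.Linear(64 * 6 * 6, nclass)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = torch.flatten(x, 1)
        x = self.dropout(x)  # zeroes activations at random, in training mode only
        return self.fc(x)

Placing the dropout right before the final fully connected layer is a common choice; it can also be inserted between convolutional blocks.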

This adds a dropout layer to the model.
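One detail worth verifying: dropout only fires in training mode, which is why the training loop calls model.train() and the validation loop calls model.eval(). A quick standalone check of nn.Dropout's behavior in the two modes:

import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(8)

drop.train()    # training mode: roughly half the values are zeroed,
print(drop(x))  # and the survivors are scaled by 1 / (1 - p) = 2.0

drop.eval()     # eval mode: dropout is the identity function
print(drop(x))  # prints all ones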

Adding a Dropout Layer in PyTorch to Prevent Overfitting
