PyTorch GPU训练:确保模型参数和输入数据类型一致

在使用 GPU 进行 PyTorch 模型训练时,必须确保模型参数和输入数据位于同一设备上,并且数据类型一致。如果不一致,PyTorch 会在前向计算时直接抛出运行时错误。常见情况有两种:一是输入为 torch.float64 而模型权重为 torch.float32;二是输入在 CPU 上而模型在 GPU 上。以下是如何在代码中实现这一点:

import torch
from net import simpleconv
from torchvision import transforms, datasets
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from tensorboardX import SummaryWriter
import time
import os

def train(train_loader, model, criterion, optimizer, device, len_train, batch_size):
    """Run one training epoch; return the mean per-batch loss and the accuracy.

    Args:
        train_loader: Iterable yielding ``(data, target)`` batches.
        model: Network to train; it is switched to training mode here.
        criterion: Loss function, called as ``criterion(outputs, target)``.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        device: Device the batches are moved to (should be the model's device).
        len_train: Total number of training samples. Kept for backward
            compatibility; accuracy is computed from the samples actually seen.
        batch_size: Kept for backward compatibility; the number of batches is
            counted directly instead of being estimated from it.

    Returns:
        ``(train_loss, train_acc)``: the average of the per-batch losses and
        the fraction of correctly classified samples. Both are ``0.0`` when
        the loader yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_corrects = 0
    num_batches = 0
    num_samples = 0

    for data, target in train_loader:
        # Inputs must match the model's parameter dtype (float32) and device.
        data = data.to(device, dtype=torch.float)
        # Only the device changes for targets: class indices must stay integer
        # for CrossEntropyLoss.
        target = target.to(device)

        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        preds = outputs.argmax(dim=1)
        num_corrects += (preds == target).sum().item()
        total_loss += loss.item()
        num_batches += 1
        num_samples += target.size(0)

    # Count batches instead of using ``len_train // batch_size + 1``, which
    # over-counts by one whenever len_train is a multiple of batch_size.
    train_loss = total_loss / num_batches if num_batches else 0.0
    train_acc = num_corrects / num_samples if num_samples else 0.0
    return train_loss, train_acc

def val(val_loader, model, criterion, device, len_val, batch_size):
    """Evaluate the model; return the mean per-batch loss and the accuracy.

    Args:
        val_loader: Iterable yielding ``(data, target)`` batches.
        model: Network to evaluate; it is switched to eval mode here.
        criterion: Loss function, called as ``criterion(outputs, target)``.
        device: Device the batches are moved to (should be the model's device).
        len_val: Total number of validation samples. Kept for backward
            compatibility; accuracy is computed from the samples actually seen.
        batch_size: Kept for backward compatibility; the number of batches is
            counted directly instead of being estimated from it.

    Returns:
        ``(val_loss, val_acc)``: the average of the per-batch losses and the
        fraction of correctly classified samples. Both are ``0.0`` when the
        loader yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_corrects = 0
    num_batches = 0
    num_samples = 0

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for data, target in val_loader:
            # Inputs must match the model's parameter dtype (float32) and device.
            data = data.to(device, dtype=torch.float)
            # Only the device changes for targets: class indices must stay
            # integer for CrossEntropyLoss.
            target = target.to(device)

            outputs = model(data)
            loss = criterion(outputs, target)
            preds = outputs.argmax(dim=1)

            num_corrects += (preds == target).sum().item()
            total_loss += loss.item()
            num_batches += 1
            num_samples += target.size(0)

    # Count batches instead of using ``len_val // batch_size + 1``, which
    # over-counts by one whenever len_val is a multiple of batch_size.
    val_loss = total_loss / num_batches if num_batches else 0.0
    val_acc = num_corrects / num_samples if num_samples else 0.0
    return val_loss, val_acc

# Hyper-parameters and data location.
batch_size = 100
nclass = 13
num_epochs = 364
data_dir = './testimage'

# Select the device with torch.device: the GPU when available, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Move the model to the target device and cast its parameters to float32, the
# same dtype the inputs are converted to in train()/val().
model = simpleconv(nclass).to(device, dtype=torch.float)

train_transforms = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.RandomResizedCrop(48),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

val_transforms = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.CenterCrop(48),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

train_datasets = datasets.ImageFolder(os.path.join(data_dir, 'train'), train_transforms)
val_datasets = datasets.ImageFolder(os.path.join(data_dir, 'test'), val_transforms)
len_train = len(train_datasets)
len_val = len(val_datasets)

train_loaders = torch.utils.data.DataLoader(
    dataset=train_datasets, batch_size=batch_size, shuffle=True, num_workers=0)

# Evaluation results do not depend on sample order, so there is no need to shuffle.
val_loaders = torch.utils.data.DataLoader(
    dataset=val_datasets, batch_size=batch_size, shuffle=False, num_workers=0)

criterion = nn.CrossEntropyLoss()
optimizer_ft = optim.SGD(model.parameters(), lr=0.1, momentum=0.8)
# StepLR multiplies the learning rate by gamma=0.1 every 200 scheduler steps
# (one step per epoch below).
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=200, gamma=0.1)

writer = SummaryWriter('runs')
best_acc = 0.0

try:
    for epoch in range(num_epochs):
        start = time.time()

        train_loss, train_acc = train(train_loaders, model, criterion, optimizer_ft,
                                      device, len_train, batch_size)
        # Step the scheduler once per epoch, after the optimizer updates.
        exp_lr_scheduler.step()
        val_loss, val_acc = val(val_loaders, model, criterion, device, len_val, batch_size)

        # Keep the weights of the epoch with the best validation accuracy.
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), './best.pth')

        writer.add_scalar('trainloss', train_loss, epoch)
        writer.add_scalar('trainacc', train_acc, epoch)
        writer.add_scalar('valloss', val_loss, epoch)
        writer.add_scalar('valacc', val_acc, epoch)

        end = time.time()
        print('[{}/{}]: train_loss:{:.3f}, train_acc:{:.3f},eval_loss:{:.3f}, eval_acc:{:.3f},  time:{:.3f}'.format(
            epoch + 1, num_epochs, train_loss, train_acc, val_loss, val_acc, end - start))
finally:
    # Close the writer (flushing logged scalars) even if training is interrupted.
    writer.close()

关键点:

  • 使用 to(device, dtype=torch.float) 将模型参数和输入数据都转换为 torch.float(即 float32)类型,并将其移动到目标设备(CPU 或 GPU)。标签 target 只需调用 .to(device) 移动到目标设备,应保持整型(torch.long),因为 nn.CrossEntropyLoss 要求类别索引为整型。
  • 使用 torch.device('cuda') 或 torch.device('cpu') 创建设备对象来指定设备,而不是直接使用字符串。

通过遵循这些步骤,您可以确保模型参数和输入数据位于同一设备上且数据类型一致,从而避免因数据类型或设备不匹配而引发的运行时错误。

PyTorch GPU训练:确保模型参数和输入数据类型一致

原文地址: http://www.cveoy.top/t/topic/ROC 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录