import numpy as np
import torch
# 导入 PyTorch 内置的 mnist 数据
from torchvision.datasets import mnist
# 导入预处理模块
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# 导入 nn 及优化器
import torch
这段代码有以下几处问题:
- 导入部分缺少 `import torchvision` 语句(原代码只从 torchvision 的子模块中导入了名称)。
- 使用 matplotlib 绘图时,for 循环体内的 plt.subplot() 没有缩进。
- plt.tight_layout() 的缩进同样不正确。
- 在 plt.imshow() 中,应将 example_data[i][0] 改为 example_data[i][0].numpy(),即先把张量显式转换为 NumPy 数组,再传给 imshow()。
以下是修改后的代码:
import numpy as np
import torch
import torchvision
from torchvision.datasets import mnist
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
from torch import nn
# --- Hyperparameters -------------------------------------------------------
# Mini-batch sizes used by the two loaders built below.
train_batch_size = 64
test_batch_size = 128

# Training settings. `lr` carries the same value as `learning_rate`; both
# names are kept because code outside this section may read either one.
num_epoches = 20
learning_rate = 0.01
lr = 0.01
momentum = 0.5

# --- Preprocessing ---------------------------------------------------------
# Convert each image to a tensor, then normalise with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# --- Datasets --------------------------------------------------------------
# Both splits live under ./data and share the same transform. Only the
# training split is created with download=True.
train_dataset = mnist.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
test_dataset = mnist.MNIST(
    root='./data',
    train=False,
    transform=transform,
)

# --- Loaders ---------------------------------------------------------------
# Training batches are shuffled; test batches keep the dataset order.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=train_batch_size,
    shuffle=True,
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=test_batch_size,
    shuffle=False,
)
import matplotlib.pyplot as plt
%matplotlib inline
# Preview a handful of test samples together with their labels.
# Take the first batch from the test loader; `batch_idx` is the index that
# enumerate() attaches to that batch.
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)

fig = plt.figure()
for i in range(6):
    # Lay the six samples out on a 2x3 grid (subplot positions are 1-based).
    plt.subplot(2, 3, i + 1)
    # example_data[i][0] is sample i at index 0 of its leading axis
    # (presumably the single image channel -- confirm against the transform).
    # It is converted to a NumPy array before being drawn.
    plt.imshow(example_data[i][0].numpy(), cmap='gray', interpolation='none')
    plt.title("Ground Truth: {}".format(example_targets[i]))
    # Hide the axis ticks; they carry no information for an image preview.
    plt.xticks([])
    plt.yticks([])
# Adjust the spacing once, after every subplot has its image and title, so
# the layout accounts for all of them.
plt.tight_layout()
``
原文地址: https://www.cveoy.top/t/topic/hLPH 著作权归作者所有。请勿转载和采集!