用pytorch写一个自动编码器的代码
下面是一个简单的自动编码器的PyTorch代码:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
# ---- Hyperparameters ----
num_epochs = 10
batch_size = 128
learning_rate = 0.001

# ---- Dataset & preprocessing ----
# ToTensor scales pixels to [0, 1]; Normalize((0.5,), (0.5,)) shifts them to
# roughly [-1, 1], matching the Tanh output range of the decoder.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

train_data = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=mnist_transform)
test_data = torchvision.datasets.MNIST(
    root='./data', train=False, download=True, transform=mnist_transform)

trainloader = torch.utils.data.DataLoader(
    train_data, batch_size=batch_size, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(
    test_data, batch_size=batch_size, shuffle=False, num_workers=2)
# 定义自动编码器的类
class Autoencoder(nn.Module):
    """Fully-connected autoencoder for flattened 28x28 MNIST images.

    Compresses an ``input_dim``-dimensional input to a small latent code and
    reconstructs it back. The final Tanh keeps outputs in [-1, 1], matching
    inputs normalized with ``Normalize((0.5,), (0.5,))``.
    """

    def __init__(self, input_dim: int = 784, latent_dim: int = 2):
        """
        Args:
            input_dim: Size of the flattened input (default 784 = 28*28),
                so existing callers keep the original behavior.
            latent_dim: Size of the bottleneck code (default 2, convenient
                for plotting the latent space in 2-D).
        """
        super().__init__()
        # Encoder: progressively shrink input_dim -> 128 -> 64 -> 12 -> latent_dim.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 12),
            nn.ReLU(),
            nn.Linear(12, latent_dim),
        )
        # Decoder: mirror of the encoder, Tanh bounds the output to [-1, 1].
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 12),
            nn.ReLU(),
            nn.Linear(12, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, input_dim),
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode then decode a batch of flattened images.

        Args:
            x: Tensor of shape (batch, input_dim).

        Returns:
            Reconstruction of shape (batch, input_dim), values in [-1, 1].
        """
        return self.decoder(self.encoder(x))
# ---- Model, loss, optimizer ----
autoencoder = Autoencoder()
criterion = nn.MSELoss()  # pixel-wise reconstruction error
optimizer = optim.Adam(autoencoder.parameters(), lr=learning_rate)

# ---- Training loop ----
autoencoder.train()  # explicit train mode (good hygiene even without dropout/BN)
for epoch in range(num_epochs):
    epoch_loss = 0.0
    num_batches = 0
    for img, _ in trainloader:  # labels are unused: reconstruction target is the input
        img = img.view(img.size(0), -1)  # flatten (B, 1, 28, 28) -> (B, 784)
        optimizer.zero_grad()
        outputs = autoencoder(img)
        loss = criterion(outputs, img)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        num_batches += 1
    # Report the mean loss over the whole epoch; the original printed only the
    # last batch's loss, which is a noisy and misleading progress signal.
    print('Epoch [{}/{}], Loss: {:.4f}'.format(
        epoch + 1, num_epochs, epoch_loss / max(num_batches, 1)))
# ---- Qualitative evaluation: reconstruct one test batch ----
autoencoder.eval()  # switch to eval mode before inference
with torch.no_grad():
    # Take just the first batch; equivalent to the original for/break pattern.
    img, _ = next(iter(testloader))
    img = img.view(img.size(0), -1)  # flatten to (B, 784) to match the model
    outputs = autoencoder(img)

# ---- Show originals (top row) vs reconstructions (bottom row) ----
import matplotlib.pyplot as plt
fig, axes = plt.subplots(nrows=2, ncols=10, sharex=True, sharey=True,
                         figsize=(25, 4))
for images, row in zip([img, outputs], axes):
    # Use a distinct inner loop variable: the original reused `img`, silently
    # clobbering the batch tensor (it only worked because the [img, outputs]
    # list was built before the first clobber).
    for image, ax in zip(images, row):
        ax.imshow(image.view(28, 28), cmap='gray')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
plt.show()
这个自动编码器使用了MNIST数据集,并且将图像从784维压缩到2维,再从2维重建为784维。在训练完成后,代码还会显示一些原始图像和重建图像,以便您可以看到自动编码器的效果。
原文地址: https://www.cveoy.top/t/topic/bzpV 著作权归作者所有。请勿转载和采集!