使用 PyTorch 实现生成对抗网络(GAN)教程
GAN(Generative Adversarial Network,生成对抗网络)是一种生成式模型,由两个神经网络组成:生成器(Generator)和判别器(Discriminator)。生成器的目标是生成与真实数据相似的虚假数据,而判别器的目标是区分真实数据和虚假数据。两个网络在训练中不断对抗、交替优化,最终生成器能够生成越来越逼真的数据。
下面是一个使用 PyTorch 实现 GAN 的简单示例,以 MNIST 手写数字数据集作为训练数据:
首先,需要导入必要的库:
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
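接着,定义生成器和判别器的网络结构。两者都是三层全连接网络:生成器将随机噪声向量映射为一个展平的图像向量,并用 Tanh 把输出限制在 [-1, 1];判别器将展平的图像映射为一个标量,并用 Sigmoid 输出该样本为真实数据的概率: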
class Generator(nn.Module):
    """Fully connected generator: latent noise -> flat sample in [-1, 1].

    Architecture: ``input_size -> 128 -> 256 -> output_size``, with ReLU after
    each hidden layer and Tanh on the output.

    Args:
        input_size: length of the noise vector fed to ``forward``.
        output_size: length of the flat sample produced (e.g. ``28 * 28``).
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, 256)
        self.fc3 = nn.Linear(256, output_size)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Map noise ``x`` (last dim ``input_size``) to a Tanh-bounded sample."""
        hidden = x
        # Both hidden layers share the same Linear -> ReLU pattern.
        for layer in (self.fc1, self.fc2):
            hidden = self.relu(layer(hidden))
        # Tanh bounds the output to [-1, 1].
        return self.tanh(self.fc3(hidden))
class Discriminator(nn.Module):
    """Fully connected discriminator: sample -> probability that it is real.

    Architecture: ``input_size -> 256 -> 128 -> 1``, with ReLU after each
    hidden layer and Sigmoid on the output.

    ``forward`` accepts flat inputs whose last dimension is ``input_size``.
    It also accepts batches shaped ``(batch, ...)`` whose trailing dimensions
    multiply to ``input_size``, such as image batches ``(B, 1, 28, 28)`` for
    ``input_size=784``. Such batches are flattened to ``(batch, input_size)``
    first.

    Args:
        input_size: number of features in one flattened sample.
    """

    def __init__(self, input_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return probabilities in (0, 1) of shape ``(..., 1)`` for input ``x``."""
        # Flatten only when the last dim does not already match fc1. This lets
        # image batches be fed directly, while every previously valid input
        # (last dim == input_size) passes through unchanged.
        if x.dim() > 1 and x.size(-1) != self.fc1.in_features:
            x = x.flatten(start_dim=1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        # Sigmoid keeps outputs in (0, 1), as nn.BCELoss requires.
        return self.sigmoid(self.fc3(x))
然后,定义训练过程。每个批次分两步:先训练判别器,让它把真实数据判为 1、把生成数据判为 0;再训练生成器,目标是让判别器把生成数据判为 1:
def train(generator, discriminator, dataloader, num_epochs=50, lr=0.0002,
          beta1=0.5, beta2=0.999, latent_dim=100, save_dir='generated_images'):
    """Train ``generator`` and ``discriminator`` adversarially with BCE loss.

    For each batch, the discriminator is first updated to score real samples
    as 1 and generated samples as 0. The generator is then updated so that
    the discriminator scores its samples as 1.

    Args:
        generator: module mapping noise of shape ``(batch, latent_dim)`` to
            flat samples.
        discriminator: module mapping flat samples ``(batch, features)`` to
            probabilities of shape ``(batch, 1)``. Values must lie in [0, 1]
            because ``nn.BCELoss`` is used.
        dataloader: sized iterable of batches whose first element is the data
            tensor (e.g. ``(images, labels)``); labels are ignored. Each data
            batch is flattened to ``(batch, -1)`` before use.
        num_epochs: number of passes over ``dataloader``.
        lr, beta1, beta2: Adam hyper-parameters shared by both optimizers.
        latent_dim: length of the noise vector fed to ``generator``.
        save_dir: directory that receives a 64-sample image grid
            ``epoch{N}.png`` after every epoch. Pass ``None`` to skip saving.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    generator.to(device)
    discriminator.to(device)
    betas = (beta1, beta2)
    generator_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=betas)
    discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=betas)
    criterion = nn.BCELoss()

    if save_dir is not None:
        # Create the output directory up front so per-epoch saving does not
        # fail on a missing path.
        os.makedirs(save_dir, exist_ok=True)

    # Per-sample shape of the real data (e.g. (1, 28, 28)). It is used to turn
    # flat generated samples back into images before saving them.
    sample_shape = None

    for epoch in range(num_epochs):
        for i, batch in enumerate(dataloader):
            real_images = batch[0].to(device)
            batch_size = real_images.size(0)
            sample_shape = real_images.shape[1:]
            # Both networks are MLPs: flatten e.g. (B, 1, 28, 28) -> (B, 784).
            real_data = real_images.reshape(batch_size, -1)
            real_labels = torch.ones(batch_size, 1, device=device)
            fake_labels = torch.zeros(batch_size, 1, device=device)

            # --- Discriminator step ---
            discriminator.zero_grad()
            real_loss = criterion(discriminator(real_data), real_labels)
            noise = torch.randn(batch_size, latent_dim, device=device)
            fake_data = generator(noise)
            # detach(): this step must not backpropagate into the generator.
            fake_loss = criterion(discriminator(fake_data.detach()), fake_labels)
            total_loss = real_loss + fake_loss
            total_loss.backward()
            discriminator_optimizer.step()

            # --- Generator step ---
            generator.zero_grad()
            # Target "real" for generated data: the generator is rewarded for
            # fooling the (just updated) discriminator.
            generator_loss = criterion(discriminator(fake_data), real_labels)
            generator_loss.backward()
            generator_optimizer.step()

            if i % 100 == 0:
                print('[%d/%d][%d/%d] Discriminator Loss: %.4f Generator Loss: %.4f'
                      % (epoch, num_epochs, i, len(dataloader), total_loss.item(), generator_loss.item()))

        if save_dir is not None:
            with torch.no_grad():
                noise = torch.randn(64, latent_dim, device=device)
                fake_images = generator(noise).cpu()
            if sample_shape is not None:
                # save_image expects image-shaped tensors, not flat vectors.
                fake_images = fake_images.reshape(-1, *sample_shape)
            save_image(fake_images, os.path.join(save_dir, 'epoch{}.png'.format(epoch + 1)),
                       normalize=True)
最后,加载数据集并开始训练:
# ToTensor scales pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then maps
# them to [-1, 1], matching the range of the generator's Tanh output.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# MNIST training split, downloaded to ./data on first use.
dataset = datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

# 100-dim noise in; flattened 28x28 images out of the generator and into the
# discriminator.
generator = Generator(input_size=100, output_size=28 * 28)
discriminator = Discriminator(input_size=28 * 28)

train(generator, discriminator, dataloader, num_epochs=50)
以上就是一个使用 PyTorch 实现 GAN 的简单示例。需要注意的是,GAN 的训练过程往往不稳定,需要耐心调试超参数和网络结构、不断优化模型。
原文地址: https://www.cveoy.top/t/topic/nf7j 著作权归作者所有。请勿转载和采集!