实现WGAN-GP这个神经网络模型的代码
WGAN-GP是一种比较复杂的神经网络模型,实现它需要一定的编程能力和深度学习知识。以下是一个基于PyTorch实现WGAN-GP的代码示例:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as vutils
import numpy as np
import argparse
import os
import random
import time
import datetime
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
# Hyperparameters and runtime options, read from the command line.
def _str2bool(value):
    """Convert a command-line token such as ``true``/``False``/``0`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value (including ``"False"``) would be parsed as ``True``.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=64, help='batch size')
parser.add_argument('--image_size', type=int, default=64, help='image size')
parser.add_argument('--z_dim', type=int, default=100, help='dimension of noise vector')
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate')
parser.add_argument('--n_critic', type=int, default=5, help='number of training steps for discriminator per iter')
parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs')
parser.add_argument('--data_path', type=str, default='./data', help='data path')
parser.add_argument('--save_path', type=str, default='./output', help='save path')
parser.add_argument('--cuda', type=_str2bool, default=True, help='use cuda or not')
opt = parser.parse_args()
# The default requests CUDA; fall back to the CPU instead of crashing on the
# first .cuda() call when no CUDA device is available.
if opt.cuda and not torch.cuda.is_available():
    print('CUDA is not available; falling back to CPU.')
    opt.cuda = False
# Image preprocessing: resize and center-crop to a square of side
# opt.image_size, convert to a tensor, then normalize every RGB channel with
# mean 0.5 and std 0.5.
# The pipeline is bound to `transform` (singular) so that the imported
# `transforms` module is not shadowed by its own Compose instance.
transform = transforms.Compose([
    transforms.Resize(opt.image_size),
    transforms.CenterCrop(opt.image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# Load the dataset: one sub-directory of opt.data_path per class (ImageFolder
# layout); the class labels are ignored during training.
dataset = datasets.ImageFolder(root=opt.data_path, transform=transform)
dataloader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=4)
# Generator network
class Generator(nn.Module):
    """DCGAN-style generator mapping a noise vector to a 3x64x64 image.

    The noise is treated as a ``z_dim x 1 x 1`` feature map and upsampled by
    five transposed convolutions (1 -> 4 -> 8 -> 16 -> 32 -> 64 pixels per
    side). The final ``Tanh`` bounds the output to [-1, 1].

    Args:
        z_dim: dimensionality of the input noise vector.
    """

    def __init__(self, z_dim=100):
        super().__init__()
        self.z_dim = z_dim
        self.main = nn.Sequential(
            # z_dim x 1 x 1 -> 512 x 4 x 4
            nn.ConvTranspose2d(self.z_dim, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            # 512 x 4 x 4 -> 256 x 8 x 8
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            # 256 x 8 x 8 -> 128 x 16 x 16
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            # 128 x 16 x 16 -> 64 x 32 x 32
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            # 64 x 32 x 32 -> 3 x 64 x 64
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, z):
        """Generate images from noise.

        Args:
            z: noise tensor with ``z_dim`` values per sample; it is reshaped
                to ``(-1, z_dim, 1, 1)`` before entering the network.

        Returns:
            Tensor of shape ``(N, 3, 64, 64)`` with values in [-1, 1].
        """
        output = self.main(z.view(-1, self.z_dim, 1, 1))
        return output
# Discriminator (critic) network
class Discriminator(nn.Module):
    """Critic mapping a 3x64x64 image to one unbounded score per sample.

    Four stride-2 convolutions halve the resolution (64 -> 32 -> 16 -> 8 -> 4)
    and a final 4x4 convolution reduces the 4x4 map to a single value. There
    is no sigmoid: a Wasserstein critic outputs a raw score.

    Normalization is per-sample (``InstanceNorm2d``) rather than batch
    normalization. The WGAN-GP gradient penalty is defined on each input
    independently, and batch normalization would make a sample's score depend
    on the other samples in the batch.
    """

    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            # 3 x 64 x 64 -> 64 x 32 x 32
            nn.Conv2d(3, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 64 x 32 x 32 -> 128 x 16 x 16
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(128, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
            # 128 x 16 x 16 -> 256 x 8 x 8
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(256, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
            # 256 x 8 x 8 -> 512 x 4 x 4
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(512, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
            # 512 x 4 x 4 -> 1 x 1 x 1
            nn.Conv2d(512, 1, 4, 1, 0, bias=False),
        )

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: image tensor of shape ``(N, 3, 64, 64)``.

        Returns:
            Tensor of shape ``(N, 1)`` holding one critic score per image.
        """
        output = self.main(x)
        return output.view(-1, 1)
# WGAN-GP model: networks, optimizers and the training loop
class WGAN_GP:
    """Wasserstein GAN with gradient penalty.

    Owns a generator, a critic (``discriminator``) and one Adam optimizer for
    each. ``train`` reads batches from the module-level ``dataloader`` and
    writes sample grids generated from the module-level ``fixed_z``.

    Args:
        opt: options object providing ``z_dim``, ``lr``, ``n_critic``,
            ``n_epochs``, ``save_path`` and ``cuda``.
    """

    def __init__(self, opt):
        self.opt = opt
        self.generator = Generator(self.opt.z_dim)
        self.discriminator = Discriminator()
        self.generator.apply(weights_init)
        self.discriminator.apply(weights_init)
        self.optimizer_G = optim.Adam(self.generator.parameters(), lr=self.opt.lr, betas=(0.5, 0.999))
        self.optimizer_D = optim.Adam(self.discriminator.parameters(), lr=self.opt.lr, betas=(0.5, 0.999))

    def train(self):
        """Run the full training loop.

        For every batch the critic is updated ``opt.n_critic`` times (on the
        same real batch, with fresh noise each time) and the generator once.
        Losses are printed every 100 batches and a grid of samples generated
        from ``fixed_z`` is saved every 1000 batches.
        """
        print("Start Training...")
        device = torch.device('cuda' if self.opt.cuda else 'cpu')
        # save_image does not create missing directories.
        os.makedirs(self.opt.save_path, exist_ok=True)
        for epoch in range(self.opt.n_epochs):
            for i, (real_imgs, _) in enumerate(dataloader):
                # Move the batch once; it is reused by every critic step.
                real_imgs = real_imgs.to(device)
                batch_size = real_imgs.size(0)
                # Train the critic
                for _ in range(self.opt.n_critic):
                    self.discriminator.zero_grad()
                    real_validity = self.discriminator(real_imgs)
                    z = torch.randn(batch_size, self.opt.z_dim, device=device)
                    # The generator is not updated in this step, so no graph
                    # is needed through it.
                    with torch.no_grad():
                        fake_imgs = self.generator(z)
                    fake_validity = self.discriminator(fake_imgs)
                    gradient_penalty = self.compute_gradient_penalty(real_imgs, fake_imgs)
                    # Wasserstein critic loss: E[D(fake)] - E[D(real)] + penalty.
                    # The penalty is evaluated on real/fake interpolates only;
                    # the input batch itself does not require grad, so it
                    # cannot be differentiated against directly.
                    d_loss = torch.mean(fake_validity) - torch.mean(real_validity) + gradient_penalty
                    d_loss.backward()
                    self.optimizer_D.step()
                # Train the generator
                self.generator.zero_grad()
                z = torch.randn(batch_size, self.opt.z_dim, device=device)
                fake_imgs = self.generator(z)
                fake_validity = self.discriminator(fake_imgs)
                g_loss = -torch.mean(fake_validity)
                g_loss.backward()
                self.optimizer_G.step()
                # Print the training log
                if i % 100 == 0:
                    print("[Epoch %d/%d] [Batch %d/%d] [D loss: %.4f] [G loss: %.4f]"
                          % (epoch, self.opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item()))
                # Save generated samples
                if i % 1000 == 0:
                    with torch.no_grad():
                        samples = self.generator(fixed_z)
                    # The epoch is part of the file name so that later epochs
                    # do not overwrite earlier snapshots.
                    vutils.save_image(samples[:64],
                                      '%s/generated_%d_%d.png' % (self.opt.save_path, epoch, i),
                                      normalize=True)

    def compute_gradient_penalty(self, real_imgs, fake_imgs, lambda_gp=10):
        """Return the WGAN-GP gradient penalty for one real/fake batch pair.

        Each real sample is mixed with the corresponding fake sample using a
        random weight in [0, 1). The penalty is
        ``lambda_gp * mean((||grad_x D(x)||_2 - 1) ** 2)`` over those mixes.

        Args:
            real_imgs: batch of real images, shape ``(N, C, H, W)``.
            fake_imgs: batch of generated images with the same shape.
            lambda_gp: penalty weight.

        Returns:
            Scalar tensor, differentiable w.r.t. the critic parameters.
        """
        alpha = torch.rand(real_imgs.size(0), 1, 1, 1, device=real_imgs.device)
        # Detach both inputs: the penalty only trains the critic and must not
        # back-propagate into the generator.
        interpolates = (alpha * real_imgs.detach() + (1 - alpha) * fake_imgs.detach()).requires_grad_(True)
        d_interpolates = self.discriminator(interpolates)
        gradients = torch.autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones_like(d_interpolates),
            create_graph=True,  # the penalty itself is differentiated later
            retain_graph=True,
            only_inputs=True,
        )[0]
        gradients_norm = gradients.reshape(gradients.size(0), -1).norm(2, dim=1)
        gradient_penalty = ((gradients_norm - 1) ** 2).mean() * lambda_gp
        return gradient_penalty
# Parameter initialization function
def weights_init(m):
    """Initialize one module in place; intended for ``nn.Module.apply``.

    Modules whose class name contains ``Conv`` get weights drawn from
    N(0, 0.02). Modules whose class name contains ``BatchNorm`` get weights
    drawn from N(1, 0.02) and zero biases. Every other module is untouched.

    Args:
        m: the module to initialize.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        # weight/bias are None when the layer was built with affine=False.
        if m.weight is not None:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
# Fixed noise vectors: the same 64 inputs are used for every saved sample
# grid, so snapshots taken at different times are directly comparable.
fixed_z = torch.randn(64, opt.z_dim).cuda() if opt.cuda else torch.randn(64, opt.z_dim)

# Build the WGAN-GP model and start training, only when run as a script.
# DataLoader worker processes started with the "spawn" method re-import this
# module; without the guard each worker would try to start training itself.
if __name__ == '__main__':
    model = WGAN_GP(opt)
    if opt.cuda:
        model.generator.cuda()
        model.discriminator.cuda()
    model.train()
原文地址: https://www.cveoy.top/t/topic/b4E1 著作权归作者所有。请勿转载和采集!