优化一下下面的模型并给出理由# 生成器输入100噪声输出12828class GeneratornnModule def __init__self superGenerator self__init__ selflinear = nnSequential nnLinear100 256 nnTanh
优化后的模型如下:
生成器,输入100噪声输出(1,28,28)
class Generator(nn.Module): def init(self): super(Generator, self).init() self.linear = nn.Sequential( nn.Linear(100, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Linear(256, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Linear(512, 28*28), nn.Tanh() )
def forward(self, x):
x = self.linear(x)
x = x.view(-1, 28, 28)
return x
辨别器,输入(1,28,28),输出真假,推荐使用LeakRelu
class Discriminator(nn.Module): def init(self): super(Discriminator, self).init() self.linear = nn.Sequential( nn.Linear(28*28, 512), nn.BatchNorm1d(512), nn.LeakyReLU(0.2),
nn.Linear(512, 256),
nn.BatchNorm1d(256),
nn.LeakyReLU(0.2),
nn.Linear(256, 1),
nn.Sigmoid()
)
def forward(self, x):
x = x.view(-1, 28*28)
x = self.linear(x)
return x
优化理由:
-
在生成器中添加了BN层和ReLU激活函数,使得生成的图像更加真实。
-
在辨别器中添加了BN层,可以加速模型收敛,LeakyReLU激活函数可以避免梯度消失问题。
-
对于LeakyReLU激活函数,设置了较小的负斜率,可以保持一定的负梯度,避免神经元“死亡”。
-
对于BN层,由于GAN中生成器和辨别器的输出分别是真实样本和生成样本,因此需要针对两个模型分别进行BN处理。
原文地址: http://www.cveoy.top/t/topic/dvGz 著作权归作者所有。请勿转载和采集!