gUNet模型代码详解:每一行代码的注释
定义一个卷积层
class ConvLayer(nn.Module):
    """Gated convolution layer: out = proj(Wv(x) * Wg(x)).

    A pointwise+depthwise "value" branch is modulated by a pointwise
    "gate" branch, then projected back with a 1x1 convolution.

    Args:
        net_depth: total number of blocks in the network; used only to
            scale the weight-initialization gain.
        dim: number of input/output channels (preserved).
        kernel_size: spatial size of the depthwise convolution.
        gate_act: activation class used on the gate branch.
    """

    def __init__(self, net_depth, dim, kernel_size=3, gate_act=nn.Sigmoid):
        super().__init__()
        self.dim = dim
        # Network depth, kept for the depth-scaled init below.
        self.net_depth = net_depth
        self.kernel_size = kernel_size

        # Value branch: 1x1 pointwise conv, then a depthwise conv
        # (groups=dim) with reflect padding that keeps spatial size.
        self.Wv = nn.Sequential(
            nn.Conv2d(dim, dim, 1),
            nn.Conv2d(dim, dim, kernel_size=kernel_size,
                      padding=kernel_size // 2, groups=dim,
                      padding_mode='reflect')
        )

        # Gate branch: 1x1 conv followed by the gating activation.
        # Sigmoid/Tanh do not accept `inplace`, so only other
        # activations are constructed with inplace=True.
        self.Wg = nn.Sequential(
            nn.Conv2d(dim, dim, 1),
            gate_act() if gate_act in [nn.Sigmoid, nn.Tanh]
            else gate_act(inplace=True)
        )

        # Output projection.
        self.proj = nn.Conv2d(dim, dim, 1)

        # Apply depth-aware initialization to every conv in this layer.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal init whose std shrinks with network depth."""
        if isinstance(m, nn.Conv2d):
            # Gain decays as (8 * depth)^(-1/4) so deeper nets start smaller.
            gain = (8 * self.net_depth) ** (-1 / 4)
            fan_in, fan_out = _calculate_fan_in_and_fan_out(m.weight)
            # Xavier-style std, scaled by the depth-dependent gain.
            std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
            trunc_normal_(m.weight, std=std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, X):
        """Gated forward pass; output has the same shape as the input."""
        out = self.Wv(X) * self.Wg(X)
        out = self.proj(out)
        return out
定义一个基本块
class BasicBlock(nn.Module):
    """Pre-norm residual block: x + conv(norm(x)).

    Args:
        net_depth: total network depth, forwarded to the conv layer's init.
        dim: channel count (unchanged by the block).
        kernel_size: depthwise kernel size of the conv layer.
        conv_layer: convolution block class (default ConvLayer).
        norm_layer: normalization class (default nn.BatchNorm2d).
        gate_act: gating activation class for the conv layer.
    """

    def __init__(self, net_depth, dim, kernel_size=3, conv_layer=ConvLayer,
                 norm_layer=nn.BatchNorm2d, gate_act=nn.Sigmoid):
        super().__init__()
        # Normalization applied before the convolution (pre-norm).
        self.norm = norm_layer(dim)
        # Gated convolution layer.
        self.conv = conv_layer(net_depth, dim, kernel_size, gate_act)

    def forward(self, x):
        """Residual forward pass; output shape equals input shape."""
        identity = x
        x = self.norm(x)
        x = self.conv(x)
        x = identity + x
        return x
定义一个基本层
class BasicLayer(nn.Module):
    """A stage of `depth` BasicBlocks run sequentially at one resolution.

    Args:
        net_depth: total network depth, forwarded to every block.
        dim: channel count of this stage.
        depth: number of BasicBlocks stacked in this stage.
        kernel_size: depthwise kernel size used by each block.
        conv_layer / norm_layer / gate_act: forwarded to BasicBlock.
    """

    def __init__(self, net_depth, dim, depth, kernel_size=3,
                 conv_layer=ConvLayer, norm_layer=nn.BatchNorm2d,
                 gate_act=nn.Sigmoid):
        super().__init__()
        self.dim = dim
        self.depth = depth

        # Build `depth` identical residual blocks.
        self.blocks = nn.ModuleList([
            BasicBlock(net_depth, dim, kernel_size, conv_layer,
                       norm_layer, gate_act)
            for _ in range(depth)])

    def forward(self, x):
        """Apply every block in order; shape is preserved throughout."""
        for blk in self.blocks:
            x = blk(x)
        return x
定义一个patch embedding层
class PatchEmbed(nn.Module):
    """Downsample by `patch_size` and project to `embed_dim` channels.

    A single strided convolution with reflect padding; the padding term
    (kernel_size - patch_size + 1) // 2 keeps output size = input / patch_size.

    Args:
        patch_size: downsampling stride.
        in_chans: input channel count.
        embed_dim: output channel count.
        kernel_size: conv kernel size; defaults to patch_size when None.
    """

    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, kernel_size=None):
        super().__init__()
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        if kernel_size is None:
            kernel_size = patch_size

        # Strided projection conv performing the embed + downsample in one op.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size,
                              stride=patch_size,
                              padding=(kernel_size - patch_size + 1) // 2,
                              padding_mode='reflect')

    def forward(self, x):
        """Return the embedded (and spatially downsampled) feature map."""
        x = self.proj(x)
        return x
定义一个patch unembedding层
class PatchUnEmbed(nn.Module):
    """Upsample by `patch_size` and project back to `out_chans` channels.

    A convolution expands channels to out_chans * patch_size**2, then
    PixelShuffle rearranges those channels into a patch_size x upsampled map.

    Args:
        patch_size: upsampling factor.
        out_chans: output channel count (e.g. 3 for RGB).
        embed_dim: input channel count.
        kernel_size: conv kernel size; defaults to 1 when None.
    """

    def __init__(self, patch_size=4, out_chans=3, embed_dim=96, kernel_size=None):
        super().__init__()
        self.out_chans = out_chans
        self.embed_dim = embed_dim

        if kernel_size is None:
            kernel_size = 1

        # Channel expansion followed by sub-pixel (PixelShuffle) upsampling.
        self.proj = nn.Sequential(
            nn.Conv2d(embed_dim, out_chans * patch_size ** 2,
                      kernel_size=kernel_size, padding=kernel_size // 2,
                      padding_mode='reflect'),
            nn.PixelShuffle(patch_size)
        )

    def forward(self, x):
        """Return the upsampled, channel-reduced feature map."""
        x = self.proj(x)
        return x
SK融合层
class SKFusion(nn.Module):
    """Selective-kernel fusion of `height` same-shape feature maps.

    Produces per-branch channel attention from the summed inputs and
    returns the attention-weighted sum of the branches.

    Args:
        dim: channel count of each input feature map.
        height: number of feature maps to fuse (default 2).
        reduction: channel reduction ratio for the bottleneck MLP.
    """

    def __init__(self, dim, height=2, reduction=8):
        super(SKFusion, self).__init__()
        self.height = height
        # Bottleneck width, clamped to at least 4 channels.
        d = max(int(dim / reduction), 4)

        # Squeeze-and-excite style MLP: global pool -> dim->d -> ReLU
        # -> d -> dim*height, producing one attention map per branch.
        self.mlp = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(dim, d, 1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(d, dim * height, 1, bias=False)
        )

        # Softmax across the `height` (branch) dimension, so per-channel
        # branch weights sum to 1.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, in_feats):
        """Fuse a list of `height` tensors of shape (B, C, H, W)."""
        B, C, H, W = in_feats[0].shape

        # Stack branches along a new "height" axis.
        in_feats = torch.cat(in_feats, dim=1)
        in_feats = in_feats.view(B, self.height, C, H, W)

        # Attention is computed from the branch sum.
        feats_sum = torch.sum(in_feats, dim=1)
        attn = self.mlp(feats_sum)
        attn = self.softmax(attn.view(B, self.height, C, 1, 1))

        # Weighted sum over branches.
        out = torch.sum(in_feats * attn, dim=1)
        return out
定义一个gUNet模型
class gUNet(nn.Module):
    """Gated U-Net for image restoration (e.g. dehazing).

    Encoder-decoder with `len(depths)` stages (must be odd). Channel
    width doubles at each downsampling and mirrors back up; skip paths
    are 1x1 convs merged with the decoder via SKFusion, and the final
    output is predicted as a residual added to the input image.

    Args:
        kernel_size: depthwise kernel size used by every ConvLayer.
        base_dim: channel width of the first stage.
        depths: number of blocks per stage; length must be odd.
        conv_layer / norm_layer / gate_act: block building options.
        fusion_layer: skip-fusion module class (default SKFusion).
    """

    def __init__(self, kernel_size=5, base_dim=32, depths=(4, 4, 4, 4, 4, 4, 4),
                 conv_layer=ConvLayer, norm_layer=nn.BatchNorm2d,
                 gate_act=nn.Sigmoid, fusion_layer=SKFusion):
        super(gUNet, self).__init__()
        # An odd stage count gives a symmetric encoder/bottleneck/decoder.
        assert len(depths) % 2 == 1
        stage_num = len(depths)
        half_num = stage_num // 2
        # Total depth drives the weight-init gain in ConvLayer.
        net_depth = sum(depths)
        # Widths double per encoder stage, then mirror for the decoder:
        # [d, 2d, ..., 2**half_num * d, ..., 2d, d].
        embed_dims = [2 ** i * base_dim for i in range(half_num)]
        embed_dims = embed_dims + [2 ** half_num * base_dim] + embed_dims[::-1]

        # Overall downsampling factor; inputs should be divisible by this.
        self.patch_size = 2 ** (stage_num // 2)
        self.stage_num = stage_num
        self.half_num = half_num

        # Input convolution (no downsampling, 3x3 kernel).
        self.inconv = PatchEmbed(patch_size=1, in_chans=3,
                                 embed_dim=embed_dims[0], kernel_size=3)

        # Backbone stages.
        self.layers = nn.ModuleList()
        # Downsampling (patch embedding) layers.
        self.downs = nn.ModuleList()
        # Upsampling (patch unembedding) layers.
        self.ups = nn.ModuleList()
        # 1x1 convs on the skip connections.
        self.skips = nn.ModuleList()
        # SK fusion layers merging skip and decoder features.
        self.fusions = nn.ModuleList()

        for i in range(self.stage_num):
            # One BasicLayer per stage.
            self.layers.append(BasicLayer(dim=embed_dims[i], depth=depths[i],
                                          net_depth=net_depth,
                                          kernel_size=kernel_size,
                                          conv_layer=conv_layer,
                                          norm_layer=norm_layer,
                                          gate_act=gate_act))

        for i in range(self.half_num):
            # Encoder downsample: embed_dims[i] -> embed_dims[i+1].
            self.downs.append(PatchEmbed(patch_size=2, in_chans=embed_dims[i],
                                         embed_dim=embed_dims[i + 1]))
            # Decoder upsample: embed_dims[i+1] -> embed_dims[i].
            self.ups.append(PatchUnEmbed(patch_size=2, out_chans=embed_dims[i],
                                         embed_dim=embed_dims[i + 1]))
            # Skip-path projection.
            self.skips.append(nn.Conv2d(embed_dims[i], embed_dims[i], 1))
            # Fusion of decoder feature with its skip.
            self.fusions.append(fusion_layer(embed_dims[i]))

        # Output convolution (no upsampling, 3x3 kernel), back to RGB.
        self.outconv = PatchUnEmbed(patch_size=1, out_chans=3,
                                    embed_dim=embed_dims[-1], kernel_size=3)

    def forward(self, x):
        """Restore image `x`; output shape equals input shape (B, 3, H, W)."""
        feat = self.inconv(x)

        skips = []
        # Encoder half: stage, record projected skip, downsample.
        for i in range(self.half_num):
            feat = self.layers[i](feat)
            skips.append(self.skips[i](feat))
            feat = self.downs[i](feat)

        # Bottleneck + decoder half: upsample, fuse with skip, stage.
        for i in range(self.half_num - 1, -1, -1):
            feat = self.ups[i](feat)
            feat = self.fusions[i]([feat, skips[i]])
            feat = self.layers[self.stage_num - i - 1](feat)

        # Predict a residual and add it to the input.
        x = self.outconv(feat) + x
        return x
# Public API of this module (`all` would shadow the builtin and do nothing).
__all__ = ['gUNet', 'gunet_t', 'gunet_s', 'gunet_b', 'gunet_d']
# Normalization batch size of 16~32 may be good
def gunet_t():  # 4 cards 2080Ti
    """Tiny gUNet: 24 base channels, depths [2, 2, 2, 4, 2, 2, 2]."""
    return gUNet(kernel_size=5, base_dim=24, depths=[2, 2, 2, 4, 2, 2, 2],
                 conv_layer=ConvLayer, norm_layer=nn.BatchNorm2d,
                 gate_act=nn.Sigmoid, fusion_layer=SKFusion)
def gunet_s():  # 4 cards 3090
    """Small gUNet: 24 base channels, depths [4, 4, 4, 8, 4, 4, 4]."""
    return gUNet(kernel_size=5, base_dim=24, depths=[4, 4, 4, 8, 4, 4, 4],
                 conv_layer=ConvLayer, norm_layer=nn.BatchNorm2d,
                 gate_act=nn.Sigmoid, fusion_layer=SKFusion)
def gunet_b():  # 4 cards 3090
    """Base gUNet: 24 base channels, depths [8, 8, 8, 16, 8, 8, 8]."""
    return gUNet(kernel_size=5, base_dim=24, depths=[8, 8, 8, 16, 8, 8, 8],
                 conv_layer=ConvLayer, norm_layer=nn.BatchNorm2d,
                 gate_act=nn.Sigmoid, fusion_layer=SKFusion)
def gunet_d():  # 4 cards 3090
    """Deep gUNet: 24 base channels, depths [16, 16, 16, 32, 16, 16, 16]."""
    return gUNet(kernel_size=5, base_dim=24, depths=[16, 16, 16, 32, 16, 16, 16],
                 conv_layer=ConvLayer, norm_layer=nn.BatchNorm2d,
                 gate_act=nn.Sigmoid, fusion_layer=SKFusion)
原文地址: https://www.cveoy.top/t/topic/nwXA 著作权归作者所有。请勿转载和采集!