PyTorch实现Dense-UNet:修改通道数,打造高效语义分割模型
import torch
import torch.nn as nn
import torch.nn.functional as F
class Single_level_densenet(nn.Module):
    """One resolution level of the Dense-UNet: ``num_conv`` conv-BN-ReLU steps of constant width.

    Step ``k`` convolves the most recent feature map and, before BatchNorm and
    ReLU, adds every feature map older than that one (the block input included).
    Only the last feature map is returned; 3x3 convolutions with padding 1 keep
    the spatial size, and the channel count never changes, so the output has
    the same shape as the input.

    Args:
        filters: channel count of the input, of every intermediate map, and of the output.
        num_conv: number of conv-BN-ReLU steps; with 0 the input itself is returned.
    """

    def __init__(self, filters, num_conv=4):
        super(Single_level_densenet, self).__init__()
        self.num_conv = num_conv
        # BatchNorm construction draws no random numbers, so creating all convs
        # first yields the same conv initialisation as interleaving conv/BN.
        self.conv_list = nn.ModuleList(
            nn.Conv2d(filters, filters, 3, padding=1) for _ in range(num_conv)
        )
        self.bn_list = nn.ModuleList(nn.BatchNorm2d(filters) for _ in range(num_conv))

    def forward(self, x):
        feats = [x]
        for step in range(self.num_conv):
            latest = feats[step]
            # sum() folds left to right starting from the conv output, i.e.
            # conv(latest) + feats[0] + ... + feats[step - 1].
            pre_act = sum(feats[:step], self.conv_list[step](latest))
            feats.append(F.relu(self.bn_list[step](pre_act)))
        return feats[-1]
class Down_sample(nn.Module):
    """Max-pool the input and also hand back the un-pooled tensor for the skip connection.

    Args:
        kernel_size: pooling window size (default 2).
        stride: pooling stride (default 2, i.e. halve H and W).

    ``forward`` returns ``(pooled, x)`` where ``x`` is the very same input object.
    """

    def __init__(self, kernel_size=2, stride=2):
        super(Down_sample, self).__init__()
        self.down_sample_layer = nn.MaxPool2d(kernel_size=kernel_size, stride=stride)

    def forward(self, x):
        skip = x
        pooled = self.down_sample_layer(skip)
        return pooled, skip
class Upsample_n_Concat(nn.Module):
    """Upsample ``x`` 2x, concatenate it with the skip tensor ``y``, and fuse with conv-BN-ReLU.

    Args:
        filters: channel count of the skip tensor ``y``, of the upsampled ``x``,
            and of the output.
        in_filters: channel count of ``x`` as fed to the transposed convolution.
            Defaults to ``filters`` (the original behaviour). Set it when the
            deeper level is wider than this one (e.g. 1024 -> 512) so the
            transposed convolution also reduces the channel count.
    """

    def __init__(self, filters, in_filters=None):
        super(Upsample_n_Concat, self).__init__()
        if in_filters is None:
            in_filters = filters
        # kernel 4, stride 2, padding 1 gives exactly 2x the spatial size:
        # (H - 1) * 2 - 2 * 1 + 4 = 2H.
        self.upsample_layer = nn.ConvTranspose2d(in_filters, filters, 4, padding=1, stride=2)
        self.conv = nn.Conv2d(2 * filters, filters, 3, padding=1)
        self.bn = nn.BatchNorm2d(filters)

    def forward(self, x, y):
        x = self.upsample_layer(x)
        # If the skip map had an odd height/width, the earlier 2x2 pooling
        # floored it and the 2x-upsampled map comes out smaller than y.
        # Pad (or, for a negative difference, crop) so the concatenation lines up.
        diff_h = y.size(-2) - x.size(-2)
        diff_w = y.size(-1) - x.size(-1)
        if diff_h or diff_w:
            x = F.pad(
                x,
                [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
            )
        x = torch.cat([x, y], dim=1)
        return F.relu(self.bn(self.conv(x)))
class Dense_Unet(nn.Module):
    """Dense-UNet: a five-level U-Net whose stages are ``Single_level_densenet`` blocks.

    Args:
        in_chan: number of channels of the input tensor.
        out_chan: number of channels produced by the final 1x1 convolution.
        filters: sequence of channel widths. ``filters[0]`` .. ``filters[3]`` are
            the four encoder/decoder levels and ``filters[4]`` is the bottleneck,
            e.g. ``[64, 128, 256, 512, 1024]``. Entries past the fifth are ignored.
        num_conv: number of conv-BN-ReLU steps inside every ``Single_level_densenet``.

    ``Single_level_densenet`` keeps its channel count fixed, so wherever two
    adjacent levels have different widths a 1x1 convolution adapts the channel
    count. When the widths are equal an ``nn.Identity`` is used instead, so
    equal-width configurations keep exactly the same parameters (and
    ``state_dict`` keys) as a network without adapters.

    Raises:
        ValueError: if ``filters`` has fewer than five entries.
    """

    @staticmethod
    def _adapt_channels(in_channels, out_channels):
        """Return a 1x1 conv mapping in_channels -> out_channels, or Identity if they match."""
        if in_channels == out_channels:
            return nn.Identity()
        return nn.Conv2d(in_channels, out_channels, 1)

    def __init__(self, in_chan, out_chan, filters, num_conv=4):
        super(Dense_Unet, self).__init__()
        filters = list(filters)
        if len(filters) < 5:
            raise ValueError(
                "filters needs 5 channel widths (4 levels + bottleneck), "
                f"got {len(filters)}"
            )
        f0, f1, f2, f3, f4 = filters[:5]

        self.conv1 = nn.Conv2d(in_chan, f0, 1)

        # Encoder. Max-pooling keeps the channel count, so each enc* adapter maps
        # the pooled output of the previous level to the width of the next one.
        self.d1 = Single_level_densenet(f0, num_conv)
        self.down1 = Down_sample()
        self.enc2 = self._adapt_channels(f0, f1)
        self.d2 = Single_level_densenet(f1, num_conv)
        self.down2 = Down_sample()
        self.enc3 = self._adapt_channels(f1, f2)
        self.d3 = Single_level_densenet(f2, num_conv)
        self.down3 = Down_sample()
        self.enc4 = self._adapt_channels(f2, f3)
        self.d4 = Single_level_densenet(f3, num_conv)
        self.down4 = Down_sample()
        self.enc5 = self._adapt_channels(f3, f4)
        self.bottom = Single_level_densenet(f4, num_conv)

        # Decoder. Upsample_n_Concat(f) expects an f-channel input, so each dec*
        # adapter narrows the deeper level's output to the width of its skip map.
        self.dec4 = self._adapt_channels(f4, f3)
        self.up4 = Upsample_n_Concat(f3)
        self.u4 = Single_level_densenet(f3, num_conv)
        self.dec3 = self._adapt_channels(f3, f2)
        self.up3 = Upsample_n_Concat(f2)
        self.u3 = Single_level_densenet(f2, num_conv)
        self.dec2 = self._adapt_channels(f2, f1)
        self.up2 = Upsample_n_Concat(f1)
        self.u2 = Single_level_densenet(f1, num_conv)
        self.dec1 = self._adapt_channels(f1, f0)
        self.up1 = Upsample_n_Concat(f0)
        self.u1 = Single_level_densenet(f0, num_conv)

        self.outconv = nn.Conv2d(f0, out_chan, 1)

    def forward(self, x):
        """Run the network.

        ``x`` is expected to be ``(N, in_chan, H, W)``. H and W should be
        divisible by 16: four 2x2 max-pools floor odd sizes while each up-step
        exactly doubles the size, so otherwise the upsampled maps may not match
        their skip connections. Returns ``(N, out_chan, H, W)``.
        """
        x = self.conv1(x)
        x, y1 = self.down1(self.d1(x))
        x, y2 = self.down2(self.d2(self.enc2(x)))
        x, y3 = self.down3(self.d3(self.enc3(x)))
        x, y4 = self.down4(self.d4(self.enc4(x)))
        x = self.bottom(self.enc5(x))
        x = self.u4(self.up4(self.dec4(x), y4))
        x = self.u3(self.up3(self.dec3(x), y3))
        x = self.u2(self.up2(self.dec2(x), y2))
        x = self.u1(self.up1(self.dec1(x), y1))
        return self.outconv(x)
if __name__ == "__main__":
    # Smoke test: a 1-channel 256x256 input, 14 output channels, and channel
    # widths doubling per level. Guarded so importing this module does not build
    # and run a large network as a side effect.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dense_unet = Dense_Unet(1, 14, [64, 128, 256, 512, 1024]).to(device)
    input_1 = torch.rand(1, 1, 256, 256).to(device)
    with torch.no_grad():  # inference only; no need to keep the autograd graph
        output = dense_unet(input_1)
    print(output.shape)
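主要修改的地方包括:
- 在 Dense_Unet 类中,按照 filters 列表设置各层级的通道数:编码器的 self.d1 到 self.bottom 依次使用 filters[0] 到 filters[4],解码器的 self.up4/self.u4 到 self.up1/self.u1 依次使用 filters[3] 到 filters[0]。
- 实例化 Dense_Unet 时,将 filters 参数设为 [64, 128, 256, 512, 1024],使特征图的通道维(张量的第 2 维,即 dim=1)沿编码器依次为 64、128、256、512、1024。
- 注意:相邻层级的通道数不同,因此在进入下一层级的 Single_level_densenet 之前、以及在每个上采样步骤之前,都需要一个通道适配层(例如 1×1 卷积)将通道数转换为该层级的宽度;若只修改 filters 列表而没有这样的适配层,前向传播会因卷积的输入通道数不匹配而报错(RuntimeError)。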
通过以上修改,Dense-UNet 各层级的特征通道数即可按照要求进行调整。您可以根据实际需求灵活修改 filters 列表中的元素(共 5 个,前 4 个对应编码器/解码器的各层级,最后 1 个对应瓶颈层),以构建适合您任务的模型。
原文地址: https://www.cveoy.top/t/topic/fRGG 著作权归作者所有。请勿转载和采集!