修改如下代码，使其 filters=[64, 128, 256, 512, 1024]，即输入数据的第二个维度（通道维）依次由 64 变为 128、256、512，再到 1024：import torch; import torch.nn as nn; import torch.nn.functional as F; class Single_level_densenet(nn.Module): def __init__(self, filters, num_conv=…（原问题中的代码在此处被截断）
修改后的代码如下所示:
import torch
import torch.nn as nn
import torch.nn.functional as F
class Single_level_densenet(nn.Module):
    """One resolution level of the dense U-Net; keeps channel count and spatial size.

    The block runs ``num_conv`` stages of 3x3 conv -> BatchNorm -> ReLU, each with
    ``filters`` input and output channels. Stage ``i`` convolves the most recent
    feature map, then adds every feature map older than that before BatchNorm.
    The block input itself counts as one of those older maps. Only the last
    stage's output is returned.

    Args:
        filters: channel count of the input, every stage, and the output.
        num_conv: number of conv/BN/ReLU stages.
    """

    def __init__(self, filters, num_conv=4):
        super(Single_level_densenet, self).__init__()
        self.num_conv = num_conv
        # All convs are created in order, so their random initialisation draws
        # from the RNG in the same sequence as one-by-one construction would.
        self.conv_list = nn.ModuleList(
            [nn.Conv2d(filters, filters, 3, padding=1) for _ in range(num_conv)]
        )
        self.bn_list = nn.ModuleList([nn.BatchNorm2d(filters) for _ in range(num_conv)])

    def forward(self, x):
        features = [x]
        for stage in range(self.num_conv):
            conv, bn = self.conv_list[stage], self.bn_list[stage]
            # Additive dense shortcut. sum(seq, start) evaluates
            # ((start + seq[0]) + seq[1]) + ..., i.e. the same order as
            # accumulating one map at a time.
            merged = sum(features[:stage], conv(features[stage]))
            features.append(F.relu(bn(merged)))
        return features[-1]
class Down_sample(nn.Module):
    """Max-pool the input and also hand back the un-pooled tensor.

    The second return value is the tensor as it was before pooling, which the
    caller can keep as a skip connection.

    Args:
        kernel_size: pooling window size.
        stride: pooling stride.
    """

    def __init__(self, kernel_size=2, stride=2):
        super(Down_sample, self).__init__()
        self.down_sample_layer = nn.MaxPool2d(kernel_size=kernel_size, stride=stride)

    def forward(self, x):
        """Return ``(pooled, x)``."""
        return self.down_sample_layer(x), x
class Upsample_n_Concat(nn.Module):
    """Upsample ``x`` by 2, concatenate it with ``y`` on channels, then fuse.

    Both ``x`` and ``y`` must carry ``filters`` channels. The spatial size of
    ``y`` must be exactly twice that of ``x`` so the concatenation lines up.
    The result has ``filters`` channels at ``y``'s spatial size.

    Args:
        filters: channel count of ``x``, of ``y`` and of the output.
    """

    def __init__(self, filters):
        super(Upsample_n_Concat, self).__init__()
        # kernel 4 / stride 2 / padding 1 gives an output of exactly 2H x 2W.
        self.upsample_layer = nn.ConvTranspose2d(
            filters, filters, kernel_size=4, stride=2, padding=1
        )
        # Fuse the 2 * filters concatenated channels back down to filters.
        self.conv = nn.Conv2d(filters * 2, filters, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(filters)

    def forward(self, x, y):
        upsampled = self.upsample_layer(x)
        stacked = torch.cat((upsampled, y), dim=1)
        return F.relu(self.bn(self.conv(stacked)))
class Dense_Unet(nn.Module):
    """Dense U-Net whose channel width can change from one level to the next.

    Args:
        in_chan: channel count of the input tensor.
        out_chan: channel count of the output tensor.
        filters: channel widths of the five levels, from full resolution down to
            the bottleneck, e.g. ``[64, 128, 256, 512, 1024]``. A single int ``f``
            is expanded to ``[f, 2f, 4f, 8f, 16f]``.
        num_conv: conv stages inside every ``Single_level_densenet``.

    ``forward`` maps ``(N, in_chan, H, W)`` to ``(N, out_chan, H, W)``. H and W
    must be divisible by 16: the encoder max-pools four times by 2, and the
    decoder concatenates each 2x-upsampled map with a skip of the same size.

    Raises:
        ValueError: if ``filters`` is a sequence whose length is not 5.
    """

    _NUM_LEVELS = 5

    def __init__(self, in_chan, out_chan, filters, num_conv=4):
        super(Dense_Unet, self).__init__()
        f = self._expand_filters(filters)
        self.conv1 = nn.Conv2d(in_chan, f[0], 1)
        # Encoder. Every Single_level_densenet needs its input to already have
        # its own width, so a 1x1 projection after each pooling widens the
        # channels to the next level's width (e.g. 64 -> 128 -> ... -> 1024).
        self.d1 = Single_level_densenet(f[0], num_conv)
        self.down1 = Down_sample()
        self.proj1 = self._channel_proj(f[0], f[1])
        self.d2 = Single_level_densenet(f[1], num_conv)
        self.down2 = Down_sample()
        self.proj2 = self._channel_proj(f[1], f[2])
        self.d3 = Single_level_densenet(f[2], num_conv)
        self.down3 = Down_sample()
        self.proj3 = self._channel_proj(f[2], f[3])
        self.d4 = Single_level_densenet(f[3], num_conv)
        self.down4 = Down_sample()
        self.proj4 = self._channel_proj(f[3], f[4])
        self.bottom = Single_level_densenet(f[4], num_conv)
        # Decoder. Upsample_n_Concat(c) transposes c -> c channels and then
        # concatenates a c-channel skip, so first narrow the incoming map to the
        # skip's width with a 1x1 projection.
        self.reduce4 = self._channel_proj(f[4], f[3])
        self.up4 = Upsample_n_Concat(f[3])
        self.u4 = Single_level_densenet(f[3], num_conv)
        self.reduce3 = self._channel_proj(f[3], f[2])
        self.up3 = Upsample_n_Concat(f[2])
        self.u3 = Single_level_densenet(f[2], num_conv)
        self.reduce2 = self._channel_proj(f[2], f[1])
        self.up2 = Upsample_n_Concat(f[1])
        self.u2 = Single_level_densenet(f[1], num_conv)
        self.reduce1 = self._channel_proj(f[1], f[0])
        self.up1 = Upsample_n_Concat(f[0])
        self.u1 = Single_level_densenet(f[0], num_conv)
        self.outconv = nn.Conv2d(f[0], out_chan, 1)

    @classmethod
    def _expand_filters(cls, filters):
        """Normalise ``filters`` to a list of five per-level channel widths."""
        if isinstance(filters, int):
            # Doubling layout (f, 2f, 4f, 8f, 16f) that the int form was written for.
            return [filters * 2 ** level for level in range(cls._NUM_LEVELS)]
        widths = [int(w) for w in filters]
        if len(widths) != cls._NUM_LEVELS:
            raise ValueError(
                "filters must be an int or a sequence of {} ints, got {!r}".format(
                    cls._NUM_LEVELS, filters
                )
            )
        return widths

    @staticmethod
    def _channel_proj(c_in, c_out):
        """Return a 1x1 conv mapping c_in -> c_out channels, or identity if equal."""
        if c_in == c_out:
            return nn.Identity()
        return nn.Conv2d(c_in, c_out, 1)

    def forward(self, x):
        x = self.conv1(x)
        # Encoder: y1..y4 are the pre-pooling maps kept as skip connections.
        x, y1 = self.down1(self.d1(x))
        x, y2 = self.down2(self.d2(self.proj1(x)))
        x, y3 = self.down3(self.d3(self.proj2(x)))
        x, y4 = self.down4(self.d4(self.proj3(x)))
        x = self.bottom(self.proj4(x))
        # Decoder: narrow, upsample, merge with the matching skip, refine.
        x = self.u4(self.up4(self.reduce4(x), y4))
        x = self.u3(self.up3(self.reduce3(x), y3))
        x = self.u2(self.up2(self.reduce2(x), y2))
        x = self.u1(self.up1(self.reduce1(x), y1))
        return self.outconv(x)
# Smoke run executed at import time: build the network for 1-channel input and
# 14-channel output with level widths 64 -> 1024, then push one random
# 256x256 image through it on the GPU when available.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
dense_unet = Dense_Unet(1, 14, [64, 128, 256, 512, 1024]).to(device)
input_1 = torch.rand(1, 1, 256, 256).to(device)
output = dense_unet(input_1)
主要修改的地方包括：
- 在 Dense_Unet 类中，修改了 self.d2、self.d3、self.d4、self.bottom、self.up4、self.u4、self.up3、self.u3、self.up2、self.u2、self.up1、self.u1 等层的输入通道数，使其与 filters 列表中的对应元素一致。由于相邻层级的通道数不同，层级之间还需要进行通道转换（例如使用 1×1 卷积），否则通道数不匹配会在运行时报错。
- 在实例化 Dense_Unet 时，将 filters 参数修改为 [64, 128, 256, 512, 1024]，以满足输入数据的第二通道由 64 到 128 到 256 到 512 再到 1024 的要求。
原文地址: https://www.cveoy.top/t/topic/iToe 著作权归作者所有。请勿转载和采集!