HighResolutionNet 代码优化:深度可分离卷积、残差连接和批量归一化
一种改进方式是使用深度可分离卷积代替传统的卷积操作,以减少模型参数和计算量。另外,可以使用残差连接和批量归一化来加速收敛和提高模型性能。
改进后的代码如下:
import torch
import torch.nn as nn
import torch.nn.functional as F
class HighResolutionNet(nn.Module):
    """HRNet-style backbone: a stem, one Bottleneck layer, then three
    multi-resolution stages whose branches are widened by transition layers
    and processed by ``HighResolutionModule`` instances, followed by a
    prediction head producing ``config.MODEL.NUM_JOINTS`` output maps.

    Relies on module-level names defined elsewhere in this file:
    ``BatchNorm2d``, ``BN_MOMENTUM``, ``Bottleneck``, ``blocks_dict`` and
    ``HighResolutionModule``.
    """

    def __init__(self, config, **kwargs):
        """Build the network from ``config.MODEL.EXTRA`` stage configs."""
        self.inplanes = 64
        extra = config.MODEL.EXTRA
        super(HighResolutionNet, self).__init__()

        # Stem: two stride-2 3x3 convs reduce spatial resolution by 4x.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1,
                               bias=False)
        self.bn1 = BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1,
                               bias=False)
        self.bn2 = BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.sf = nn.Softmax(dim=1)
        self.layer1 = self._make_layer(Bottleneck, 64, 64, 4)

        # Stage 2: widen channels per the config, then run the stage modules.
        self.stage2_cfg = extra['STAGE2']
        num_channels = self.stage2_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage2_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))]
        # layer1 ends with Bottleneck expansion (64 * 4) -> 256 channels.
        self.transition1 = self._make_transition_layer(
            [256], num_channels)
        self.stage2, pre_stage_channels = self._make_stage(
            self.stage2_cfg, num_channels)

        # Stage 3.
        self.stage3_cfg = extra['STAGE3']
        num_channels = self.stage3_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage3_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))]
        self.transition2 = self._make_transition_layer(
            pre_stage_channels, num_channels)
        self.stage3, pre_stage_channels = self._make_stage(
            self.stage3_cfg, num_channels)

        # Stage 4: keeps all branch outputs (multi_scale_output=True).
        self.stage4_cfg = extra['STAGE4']
        num_channels = self.stage4_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage4_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion for i in range(len(num_channels))]
        self.transition3 = self._make_transition_layer(
            pre_stage_channels, num_channels)
        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg, num_channels, multi_scale_output=True)

        # Head: 1x1 mix of the concatenated branch channels, then the final
        # conv down to the number of joints.
        final_inp_channels = sum(pre_stage_channels)
        self.head = nn.Sequential(
            nn.Conv2d(
                in_channels=final_inp_channels,
                out_channels=final_inp_channels,
                kernel_size=1,
                stride=1,
                # NOTE(review): padding is keyed to FINAL_CONV_KERNEL but this
                # conv's kernel_size is fixed at 1; with FINAL_CONV_KERNEL == 3
                # a 1x1 conv gets padding=1 and the map grows by 2 px per side
                # — confirm this is intended.
                padding=1 if extra.FINAL_CONV_KERNEL == 3 else 0),
            BatchNorm2d(final_inp_channels, momentum=BN_MOMENTUM),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                in_channels=final_inp_channels,
                out_channels=config.MODEL.NUM_JOINTS,
                kernel_size=extra.FINAL_CONV_KERNEL,
                stride=1,
                padding=1 if extra.FINAL_CONV_KERNEL == 3 else 0)
        )

    def _make_layer(self, block, inplanes, planes, blocks, stride=1):
        """Stack ``blocks`` instances of ``block``; the first may downsample
        the residual path when stride or channel count changes."""
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
            )
        layers = []
        layers.append(block(inplanes, planes, stride, downsample))
        inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(inplanes, planes))
        return nn.Sequential(*layers)

    def _make_transition_layer(self, num_channels_pre_layer,
                               num_channels_cur_layer):
        """Adapt the previous stage's branches to the current stage's channel
        layout.

        Existing branches get a 3x3 conv when channel counts differ (``None``
        when they already match); each new, lower-resolution branch is created
        from the last previous branch via stride-2 3x3 conv(s).

        FIX: the original paste fused ``else:`` and the trailing ``append``
        onto single lines, which is a Python syntax error; the control flow is
        restored here.
        """
        num_branches_pre = len(num_channels_pre_layer)
        num_branches_cur = len(num_channels_cur_layer)
        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(nn.Sequential(
                        nn.Conv2d(num_channels_pre_layer[i],
                                  num_channels_cur_layer[i],
                                  kernel_size=3, stride=1, padding=1,
                                  bias=False),
                        nn.BatchNorm2d(num_channels_cur_layer[i],
                                       momentum=BN_MOMENTUM),
                        nn.ReLU(inplace=True)))
                else:
                    # Channel counts already match: identity transition.
                    transition_layers.append(None)
            else:
                # New branch: chain stride-2 convs from the last pre-branch;
                # only the final conv switches to the target channel count.
                conv3x3s = []
                for j in range(i + 1 - num_branches_pre):
                    inchannels = num_channels_pre_layer[-1]
                    outchannels = num_channels_cur_layer[i] \
                        if j == i - num_branches_pre else inchannels
                    conv3x3s.append(nn.Sequential(
                        nn.Conv2d(
                            inchannels, outchannels, 3, 2, 1, bias=False),
                        nn.BatchNorm2d(outchannels, momentum=BN_MOMENTUM),
                        nn.ReLU(inplace=True)))
                transition_layers.append(nn.Sequential(*conv3x3s))
        return nn.ModuleList(transition_layers)

    def _make_stage(self, layer_config, num_inchannels,
                    multi_scale_output=True):
        """Build one stage as a Sequential of ``HighResolutionModule``s.

        Returns the stage and the per-branch output channel counts (which the
        modules may update via block expansion).
        """
        num_modules = layer_config['NUM_MODULES']
        num_branches = layer_config['NUM_BRANCHES']
        num_blocks = layer_config['NUM_BLOCKS']
        num_channels = layer_config['NUM_CHANNELS']
        block = blocks_dict[layer_config['BLOCK']]
        fuse_method = layer_config['FUSE_METHOD']
        modules = []
        for i in range(num_modules):
            # multi_scale_output is only relevant for the last module.
            if not multi_scale_output and i == num_modules - 1:
                reset_multi_scale_output = False
            else:
                reset_multi_scale_output = True
            modules.append(HighResolutionModule(
                num_branches,
                block,
                num_blocks,
                num_inchannels,
                num_channels,
                fuse_method,
                reset_multi_scale_output
            ))
            num_inchannels = modules[-1].get_num_inchannels()
        return nn.Sequential(*modules), num_inchannels
class DepthwiseSeparableConv(nn.Module):
    """Depthwise-separable convolution block.

    A per-channel (grouped) spatial conv followed by a 1x1 pointwise conv
    that mixes channels, then batch normalization. Uses the module-level
    ``BN_MOMENTUM`` constant. No activation is applied here; callers add
    their own ReLU where needed.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0):
        super(DepthwiseSeparableConv, self).__init__()
        # groups=in_channels makes the first conv purely spatial (depthwise).
        self.depthwise = nn.Conv2d(
            in_channels, in_channels, kernel_size, stride, padding,
            groups=in_channels, bias=False)
        # 1x1 conv performs the cross-channel mixing (pointwise).
        self.pointwise = nn.Conv2d(
            in_channels, out_channels, 1, 1, 0, bias=False)
        self.bn = nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM)

    def forward(self, x):
        """Apply depthwise conv, pointwise conv, then batch norm."""
        return self.bn(self.pointwise(self.depthwise(x)))
class HighResolutionModule(nn.Module):
    """One multi-branch module of the HRNet stage.

    Runs ``num_branches`` parallel branches of ``block``s, then (optionally)
    fuses them via layers built from ``DepthwiseSeparableConv``. Fusion
    behavior is controlled by ``fuse_method`` ('SUM' or 'CONCAT') and
    ``reset_multi_scale_output``.

    NOTE(review): several spots below look suspect for a generic config —
    they are flagged inline rather than changed, because the exact intended
    semantics cannot be confirmed from this file alone.
    """

    def __init__(self, num_branches, block, num_blocks, num_inchannels,
                 num_channels, fuse_method, reset_multi_scale_output):
        super(HighResolutionModule, self).__init__()
        self.fuse_method = fuse_method
        self.num_branches = num_branches
        self.block = block
        self.num_blocks = num_blocks
        # Per-branch input channel counts; mutated by _make_one_branch.
        self.num_inchannels = num_inchannels
        self.num_channels = num_channels
        self.reset_multi_scale_output = reset_multi_scale_output
        self.branches = self._make_branches(
            self.num_branches, self.block, self.num_blocks, self.num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(inplace=True)

    def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
                         stride=1):
        """Build one branch as ``num_blocks`` stacked ``block``s.

        NOTE(review): ``num_channels`` arrives here as the full per-branch
        list (see _make_branches), not ``num_channels[branch_index]``, so
        ``num_channels * block.expansion`` below repeats a list rather than
        scaling an int — verify against the ``block`` implementations used.
        """
        layers = []
        layers.append(block(self.num_inchannels[branch_index],
                            num_channels,
                            stride))
        # Record the branch's post-expansion channel count for later modules.
        self.num_inchannels[branch_index] = num_channels * block.expansion
        for i in range(1, num_blocks):
            layers.append(block(self.num_inchannels[branch_index],
                                num_channels))
        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        """Build all parallel branches as a ModuleList."""
        branches = []
        for i in range(num_branches):
            branches.append(
                self._make_one_branch(i, block, num_blocks, num_channels))
        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        """Build cross-branch fusion layers (None when only one branch).

        For 'SUM' fusion one row is built per branch; otherwise only one row.
        Row i: j > i downsamples branch j once; j == i is identity (None);
        j < i chains i-j stride-2 depthwise-separable convs.
        """
        if self.num_branches == 1:
            return None
        fuse_layers = []
        for i in range(self.num_branches if self.fuse_method == 'SUM' else 1):
            fuse_layer = []
            for j in range(self.num_branches):
                if j > i:
                    fuse_layer.append(nn.Sequential(
                        DepthwiseSeparableConv(
                            self.num_channels[j], self.num_channels[i], 3, 2, 1),
                        nn.BatchNorm2d(self.num_channels[i],
                                       momentum=BN_MOMENTUM)))
                elif j == i:
                    # Same branch: no transform needed.
                    fuse_layer.append(None)
                else:
                    conv3x3s = []
                    for k in range(i - j):
                        inchannels = self.num_inchannels[j]
                        # Only the last conv in the chain switches channels.
                        outchannels = self.num_channels[i] \
                            if k == i - j - 1 else inchannels
                        conv3x3s.append(nn.Sequential(
                            DepthwiseSeparableConv(
                                inchannels, outchannels, 3, 2, 1),
                            nn.BatchNorm2d(outchannels, momentum=BN_MOMENTUM),
                            nn.ReLU(inplace=True)))
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))
        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        """Return the per-branch output channel counts of this module."""
        return self.num_inchannels

    def forward(self, x):
        """Run branches on the list of inputs ``x`` and fuse the results.

        Returns a list of tensors. ``x`` is mutated in place (x[i] is
        overwritten with each branch's output).
        """
        if self.num_branches == 1:
            return [self.branches[0](x[0])]
        for i in range(self.num_branches):
            x[i] = self.branches[i](x[i])
        # Combine branch outputs into a single tensor y.
        if self.fuse_method == 'SUM':
            # NOTE(review): sums raw branch outputs without resizing; this
            # only works if all branches share a spatial size — confirm.
            y = sum(x)
        elif self.fuse_method == 'CONCAT':
            y = []
            for i in range(self.num_branches):
                if i == 0:
                    y.append(x[i])
                elif i > 0 and i < self.num_branches - 1:
                    # Upsample middle branches to branch-0 resolution.
                    # NOTE(review): the last branch is appended unresized in
                    # the else arm below, which would break torch.cat unless
                    # it already matches x[0]'s spatial size — confirm.
                    y.append(F.interpolate(x[i], size=x[0].shape[2:],
                                           mode='bilinear',
                                           align_corners=True))
                else:
                    y.append(x[i])
            y = torch.cat(y, 1)
        else:
            raise NotImplementedError
        if self.fuse_layers is not None:
            if self.reset_multi_scale_output:
                self.multi_scale_output = []
            for i in range(len(self.fuse_layers)):
                # NOTE(review): fuse_layers[i][0] is None when i == 0 under
                # 'SUM' fusion (the j == i identity slot), so calling it here
                # would raise TypeError — verify intended indexing.
                y = self.fuse_layers[i][0](y)
                for j in range(len(self.fuse_layers[i]) - 1):
                    if self.fuse_layers[i][j + 1] is not None:
                        if not self.reset_multi_scale_output:
                            # NOTE(review): reads self.multi_scale_output[j]
                            # which is never initialized on the first call
                            # when reset_multi_scale_output is False — likely
                            # an AttributeError/IndexError path; confirm.
                            self.multi_scale_output[j] = \
                                self.fuse_layers[i][j + 1](
                                    self.multi_scale_output[j])
                        else:
                            self.multi_scale_output.append(
                                self.fuse_layers[i][j + 1](x[j + 1]))
        if not self.reset_multi_scale_output:
            return self.multi_scale_output + [y]
        else:
            return [y]
原文地址: https://www.cveoy.top/t/topic/n3qJ 著作权归作者所有。请勿转载和采集!