# 基于深度可分离卷积的高分辨率网络人脸关键点检测
# (Facial landmark detection with a High-Resolution Network built on
#  depthwise separable convolutions.)
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Created by Bin Xiao (Bin.Xiao@microsoft.com)
# Modified by Tianheng Cheng (tianhengcheng@gmail.com), Yang Zhao
# ------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
# Alias for the normalisation layer; BasicBlock and Bottleneck build their
# BN layers through it, so it can be swapped in one place.
BatchNorm2d = nn.BatchNorm2d
# Momentum passed to the BatchNorm2d layers (PyTorch's own default is 0.1).
BN_MOMENTUM = 0.01
# Logger named after this module (``__name__``, not an undefined ``name``).
logger = logging.getLogger(__name__)
def conv3x3(in_planes, out_planes, stride=1):
    """Build a 3x3 depthwise separable convolution stage.

    Stands in for a plain 3x3 convolution. A depthwise 3x3 convolution
    (``groups=in_planes``, so each input channel is filtered on its own) is
    followed by a pointwise 1x1 convolution that mixes channels. Each
    convolution is followed by batch norm and ReLU.

    NOTE(review): the stage ends in BatchNorm + ReLU, yet BasicBlock and
    Bottleneck apply their own BatchNorm (and ReLU) right after calling it.
    The signal is therefore normalised twice. Removing either side would
    change the state_dict layout, so it is left unchanged here.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Stride of the depthwise convolution. With kernel 3 and
            padding 1, stride 1 preserves the spatial size.

    Returns:
        nn.Sequential: depthwise conv -> BN -> ReLU -> pointwise conv -> BN -> ReLU.
    """
    return nn.Sequential(
        # Depthwise: one 3x3 filter per input channel.
        nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride,
                  padding=1, groups=in_planes, bias=False),
        # Build BN through the module-level alias and BN_MOMENTUM, as every
        # other BN in this file does. A bare nn.BatchNorm2d would silently
        # use PyTorch's default momentum (0.1) instead of BN_MOMENTUM.
        BatchNorm2d(in_planes, momentum=BN_MOMENTUM),
        nn.ReLU(inplace=True),
        # Pointwise: 1x1 convolution mixes channels and maps to out_planes.
        nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1,
                  padding=0, bias=False),
        BatchNorm2d(out_planes, momentum=BN_MOMENTUM),
        nn.ReLU(inplace=True),
    )
class BasicBlock(nn.Module):
    """Two-stage residual block whose 3x3 stages come from ``conv3x3``.

    Main path: conv1 -> bn1 -> ReLU -> conv2 -> bn2. The result is summed
    with the shortcut (``x`` itself, or ``downsample(x)`` when a downsample
    module is supplied) and passed through a final ReLU.

    NOTE(review): ``conv3x3`` already ends in BatchNorm + ReLU, so ``bn1`` and
    ``bn2`` re-normalise its output. Dropping them would change the
    state_dict keys, so they are kept.

    Args:
        inplanes: Channels of the input (assumes x is (N, inplanes, H, W);
            TODO confirm against callers).
        planes: Channels produced by the block (``planes * expansion``).
        stride: Stride of the first 3x3 stage.
        downsample: Optional module applied to ``x`` for the shortcut. It must
            make the shortcut's shape match the main path (e.g. when
            ``stride != 1`` or ``inplanes != planes``), otherwise the in-place
            addition in ``forward`` fails.
    """

    # Output channels are planes * expansion; this block keeps `planes`.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        bn_momentum = BN_MOMENTUM
        # Keep this assignment order: it fixes the order of parameters()
        # and of state_dict entries.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = BatchNorm2d(planes, momentum=bn_momentum)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = BatchNorm2d(planes, momentum=bn_momentum)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))."""
        features = self.relu(self.bn1(self.conv1(x)))
        features = self.bn2(self.conv2(features))
        # Shortcut is computed after the main path, matching the original order.
        shortcut = x if self.downsample is None else self.downsample(x)
        features += shortcut
        return self.relu(features)
class Bottleneck(nn.Module):
    """1x1 -> 3x3 (depthwise separable) -> 1x1 residual bottleneck block.

    Main path: conv1 (1x1) -> bn1 -> ReLU -> conv2 (``conv3x3``, carries the
    stride) -> bn2 -> ReLU -> conv3 (1x1, widens to ``planes * expansion``)
    -> bn3. The result is summed with the shortcut (``x`` or
    ``downsample(x)``) and passed through a final ReLU.

    NOTE(review): ``conv3x3`` already ends in BatchNorm + ReLU, so ``bn2``
    re-normalises its output. Kept to preserve the state_dict layout.

    Args:
        inplanes: Channels of the input (assumes x is (N, inplanes, H, W);
            TODO confirm against callers).
        planes: Width of the inner stages. The block outputs
            ``planes * expansion`` channels.
        stride: Stride applied by the middle 3x3 stage.
        downsample: Optional module applied to ``x`` for the shortcut. It must
            make the shortcut's shape match the main path's output,
            otherwise the in-place addition in ``forward`` fails.
    """

    # Output channels are planes * expansion (4x the inner width).
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        out_channels = planes * self.expansion
        # Keep this assignment order: it fixes the order of parameters()
        # and of state_dict entries.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
        # Middle stage replaced by a depthwise separable convolution.
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv2d(planes, out_channels, kernel_size=1, bias=False)
        self.bn3 = BatchNorm2d(out_channels, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return relu(bn3(conv3(...)) + shortcut(x)) for the bottleneck path."""
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        # Shortcut is computed after the main path, matching the original order.
        shortcut = x if self.downsample is None else self.downsample(x)
        y += shortcut
        return self.relu(y)
# ... (其余代码保持不变 / the remaining code is unchanged and not shown here)
# 原文地址: https://www.cveoy.top/t/topic/fZwj 著作权归作者所有。请勿转载和采集!
# (Original article: the URL above. Copyright belongs to the author; do not repost or scrape.)