PyTorch实现精简版MobileNetV1:高效的卷积神经网络模型
import torch.nn as nn
def conv_bn(inp, oup, stride = 1):
return nn.Sequential(
nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
nn.BatchNorm2d(oup),
nn.ReLU6()
)
def conv_dw(inp, oup, stride = 1):
return nn.Sequential(
nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
nn.BatchNorm2d(inp),
nn.ReLU6(),
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
nn.ReLU6(),
)
class MobileNetV1(nn.Module):
def __init__(self):
super(MobileNetV1, self).__init__()
self.stage1 = nn.Sequential(
# 160,160,3 -> 80,80,32
conv_bn(3, 32, 2),
# 80,80,32 -> 80,80,64
conv_dw(32, 64, 1),
# 80,80,64 -> 40,40,128
conv_dw(64, 128, 2),
conv_dw(128, 128, 1),
# 40,40,128 -> 20,20,256
conv_dw(128, 256, 2),
conv_dw(256, 256, 1),
)
self.stage2 = nn.Sequential(
# 20,20,256 -> 10,10,512
conv_dw(256, 512, 2),
conv_dw(512, 512, 1),
conv_dw(512, 512, 1),
conv_dw(512, 512, 1),
conv_dw(512, 512, 1),
conv_dw(512, 512, 1),
)
self.stage3 = nn.Sequential(
# 10,10,512 -> 5,5,1024
conv_dw(512, 1024, 2),
conv_dw(1024, 1024, 1),
)
self.avg = nn.AdaptiveAvgPool2d((1,1))
self.fc = nn.Linear(1024, 1000)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.normal_(m.weight, 0, 0.1)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self, x):
x = self.stage1(x)
x = self.stage2(x)
x = self.stage3(x)
x = self.avg(x)
# x = self.model(x)
x = x.view(-1, 1024)
x = self.fc(x)
return x
这段代码使用PyTorch实现了一个简化版的MobileNetV1模型。
模型结构
该模型包含三个主要阶段(stage1, stage2, stage3),每个阶段都包含多个深度可分离卷积块(conv_dw)。深度可分离卷积块是MobileNet系列模型的核心,它将标准卷积分解为深度卷积和逐点卷积,有效减少了模型参数和计算量,使得模型更加轻量级。
代码解析
conv_bn函数定义了一个包含卷积、批归一化和ReLU6激活函数的模块。conv_dw函数定义了一个深度可分离卷积块,它包含两个卷积层:第一个是深度卷积,对每个输入通道分别进行卷积;第二个是逐点卷积,使用1x1卷积核将深度卷积的输出通道进行组合。MobileNetV1类定义了整个模型结构,包括三个阶段的卷积块、自适应平均池化层和全连接层。forward方法定义了模型的前向传播过程。
模型应用
这个简化的MobileNetV1模型可以用于图像分类任务。由于其轻量级的特性,它特别适用于移动设备或嵌入式系统等资源受限的场景。
进一步学习
- 阅读MobileNetV1论文:https://arxiv.org/abs/1704.04861
- 学习更多关于PyTorch的知识:https://pytorch.org/
原文地址: https://www.cveoy.top/t/topic/mHE 著作权归作者所有。请勿转载和采集!