import torch.nn as nn


def conv_bn(inp, oup, stride = 1):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6()
    )
    
def conv_dw(inp, oup, stride = 1):
    return nn.Sequential(
        nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
        nn.BatchNorm2d(inp),
        nn.ReLU6(),

        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(),
    )

class MobileNetV1(nn.Module):
    def __init__(self):
        super(MobileNetV1, self).__init__()
        self.stage1 = nn.Sequential(
            # 160,160,3 -> 80,80,32
            conv_bn(3, 32, 2), 
            # 80,80,32 -> 80,80,64
            conv_dw(32, 64, 1), 

            # 80,80,64 -> 40,40,128
            conv_dw(64, 128, 2),
            conv_dw(128, 128, 1),

            # 40,40,128 -> 20,20,256
            conv_dw(128, 256, 2),
            conv_dw(256, 256, 1),
        )
        self.stage2 = nn.Sequential(
            # 20,20,256 -> 10,10,512
            conv_dw(256, 512, 2),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
        )
        self.stage3 = nn.Sequential(
            # 10,10,512 -> 5,5,1024
            conv_dw(512, 1024, 2),
            conv_dw(1024, 1024, 1),
        )

        self.avg = nn.AdaptiveAvgPool2d((1,1))
        self.fc = nn.Linear(1024, 1000)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0, 0.1)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
                
    def forward(self, x):
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.avg(x)
        # x = self.model(x)
        x = x.view(-1, 1024)
        x = self.fc(x)
        return x

这段代码使用PyTorch实现了一个简化版的MobileNetV1模型。

模型结构

该模型包含三个主要阶段(stage1, stage2, stage3),每个阶段都包含多个深度可分离卷积块(conv_dw)。深度可分离卷积块是MobileNet系列模型的核心,它将标准卷积分解为深度卷积和逐点卷积,有效减少了模型参数和计算量,使得模型更加轻量级。

代码解析

  • conv_bn函数定义了一个包含卷积、批归一化和ReLU6激活函数的模块。
  • conv_dw函数定义了一个深度可分离卷积块,它包含两个卷积层:第一个是深度卷积,对每个输入通道分别进行卷积;第二个是逐点卷积,使用1x1卷积核将深度卷积的输出通道进行组合。
  • MobileNetV1类定义了整个模型结构,包括三个阶段的卷积块、自适应平均池化层和全连接层。
  • forward方法定义了模型的前向传播过程。

模型应用

这个简化的MobileNetV1模型可以用于图像分类任务。由于其轻量级的特性,它特别适用于移动设备或嵌入式系统等资源受限的场景。

进一步学习

PyTorch实现精简版MobileNetV1:高效的卷积神经网络模型

原文地址: https://www.cveoy.top/t/topic/mHE 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录