CNN-gMLP模型解析:相比普通CNN,性能提升的关键是什么?
CNN-gMLP模型解析:相比普通CNN,性能提升的关键是什么?
本文将介绍一种结合了卷积神经网络 (CNN) 和 gMLP 模块的模型——CNN-gMLP,并解释它为什么比普通CNN模型性能更优。
CNN模型回顾
首先,我们回顾一下普通的CNN模型代码:pythonimport torch.nn as nn
class CNN(nn.Module): def init(self): super(CNN, self).init() self.conv1 = nn.Sequential( # input shape (1,10000,12) nn.Conv2d( in_channels=1, # input height out_channels=5, # n_filters kernel_size=(200, 3), # filter size stride=(50, 1), # filter movement/step padding=1, ), nn.ReLU(), nn.MaxPool2d(kernel_size=2, padding=1), ) self.conv2 = nn.Sequential( # input shape (5,99,7) nn.Conv2d(5, 10, (20, 2), (4, 1), 1), # output shape nn.ReLU(), # activation nn.MaxPool2d(kernel_size=2), # output shape (10,10,4) ) self.out = nn.Linear(10 * 10 * 4, 6) # fully connected layer, output 6 classes
def forward(self, x): x = self.conv1(x) x = self.conv2(x) x = x.view(x.size(0), -1) feature = x output = self.out(x) return feature, output
CNN-gMLP模型详解
下面是CNN-gMLP模型的代码:pythonimport torchimport torch.nn as nnimport torch.nn.functional as Ffrom gmlp import SpatialGatingUnit ,gMLPBlock,gMLP
class CNN(nn.Module): def init(self): super(CNN, self).init() # 第一层卷积层,输入为 (1,10000,12) self.conv1 = nn.Sequential( nn.Conv2d( in_channels=1, out_channels=5, kernel_size=(200, 3), stride=(50, 1), padding=1, ), nn.ReLU(), nn.MaxPool2d(kernel_size=2, padding=1), ) # 第二层卷积层,输入为 (5,99,7) self.conv2 = nn.Sequential( nn.Conv2d(5, 10, (20, 2), (4, 1), 1), nn.ReLU(), nn.MaxPool2d(kernel_size=2), )
self.gmlp = gMLP(d_model=10, d_ffn=20, seq_len=40, num_layers=6) self.flatten = nn.Flatten()
self.linear = nn.Linear(in_features=400, out_features=64) self.relu = nn.ReLU(inplace=True) self.dropout = nn.Dropout(p=0.4) self.out = nn.Linear(64, 6)
def forward(self, x): x = self.conv1(x) x = self.conv2(x) # x = x.view(x.size(0), -1) N, C, H, W = x.shape x = x.view(N, C, H * W).permute(0, 2, 1) x = self.gmlp(x) x = self.flatten(x) x = self.linear(x) feature = x x = self.relu(x) x = self.dropout(x) output = self.out(x) return feature, output
CNN-gMLP优势分析
相比于普通的CNN模型,CNN-gMLP模型主要通过引入gMLP模块带来了以下优势:
-
处理长序列: gMLP模块可以有效处理长序列数据,避免了普通CNN模型在处理长序列时容易出现的信息丢失问题。
-
捕捉全局信息: gMLP模块采用全局自注意力机制,能够捕捉到序列中的全局信息,而普通CNN模型只能通过局部感受野获取信息。
-
参数效率: gMLP模块使用全局池化操作,可以有效减少参数数量,提高模型的参数效率。
-
防止过拟合: CNN-gMLP模型中引入了Dropout层,可以有效防止模型过拟合,提高模型的泛化能力。
总结
总而言之,CNN-gMLP模型通过引入gMLP模块,解决了普通CNN模型在处理长序列数据时存在的一些问题,从而提升了模型性能。 gMLP模块的全局信息捕捉能力、参数效率和防止过拟合的能力,都为CNN-gMLP模型的性能提升做出了重要贡献
原文地址: https://www.cveoy.top/t/topic/ftOf 著作权归作者所有。请勿转载和采集!