PyTorch CNN with gMLP for Multi-class Classification
import torch
import torch.nn as nn
import torch.nn.functional as F
from gmlp import SpatialGatingUnit, gMLPBlock, gMLP
class CNN(nn.Module):
def __init__(self): # 初始化方法
super(CNN, self).__init__() # 初始化方法
# 第一层卷积层,输入为 (1,10000,12)
self.conv1 =nn.Conv2d( # 二维卷积层
in_channels=1, # input height # 输入通道数
out_channels=600, # n_filters # 输出通道数(卷积核数量)
kernel_size=(200, 3), # filter size # 卷积核大小
stride=(200, 3), # filter movement/step # 卷积核移动步
)
self.gmlp = gMLP(d_model=600, d_ffn=1200, seq_len=200, num_layers=6)
# gmlp(data).shape
self.flatten = nn.Flatten()
self.linear = nn.Linear(in_features=120000, out_features=64)
self.relu = nn.ReLU(inplace=True)
self.dropout = nn.Dropout(p=0.4) # 将张量展平
self.out = nn.Linear(64, 6) # 全连接层,输出6个类别
# 定义一个名为CNN的类,继承自PyTorch的nn.Module
def forward(self, x): # 前向传播方法
x = self.conv1(x) # 通过第一个卷积层序列
# 四维转三维
N, C, H, W = x.shape
x = x.view(N, C, H * W).permute(0, 2, 1)
x = self.gmlp(x)
x = self.flatten(x) # 换顺序了 先展平 在输入线性层中
x = self.linear(x)
feature = x # 提取特征
x = self.relu(x)
x = self.dropout(x)
output = self.out(x) # 通过全连接层得到输出
return feature, output # 返回特征和输出
model = CNN()
This code defines a CNN model using PyTorch. The model consists of a convolutional layer, followed by a gMLP block, flattening, and a linear layer for output. The gMLP is used to introduce non-linearity and capture complex relationships between features. The model can be trained on your dataset for multi-class classification.
原文地址: https://www.cveoy.top/t/topic/pcT4 著作权归作者所有。请勿转载和采集!