import torch
import torch.nn as nn
import torch.nn.functional as F
from gmlp import SpatialGatingUnit, gMLPBlock, gMLP


class CNN(nn.Module):
    def __init__(self):  # 初始化方法
        super(CNN, self).__init__()  # 初始化方法
        # 第一层卷积层,输入为 (1,10000,12)
        self.conv1 =nn.Conv2d(  # 二维卷积层
            in_channels=1,  # input height   # 输入通道数
            out_channels=600,  # n_filters     # 输出通道数(卷积核数量)
            kernel_size=(200, 3),  # filter size  # 卷积核大小
            stride=(200, 3),  # filter movement/step   # 卷积核移动步

        )

        self.gmlp = gMLP(d_model=600, d_ffn=1200, seq_len=200, num_layers=6)
        #  gmlp(data).shape

        self.flatten = nn.Flatten()

        self.linear = nn.Linear(in_features=120000, out_features=64)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p=0.4)  # 将张量展平

        self.out = nn.Linear(64, 6)  # 全连接层,输出6个类别


# 定义一个名为CNN的类,继承自PyTorch的nn.Module

    def forward(self, x):  # 前向传播方法
        x = self.conv1(x)  # 通过第一个卷积层序列
        # 四维转三维
        N, C, H, W = x.shape
        x = x.view(N, C, H * W).permute(0, 2, 1)
        x = self.gmlp(x)
        x = self.flatten(x)  # 换顺序了 先展平 在输入线性层中
        x = self.linear(x)
        feature = x  # 提取特征
        x = self.relu(x)
        x = self.dropout(x)
        output = self.out(x)  # 通过全连接层得到输出
        return feature, output  # 返回特征和输出

model = CNN()

This code defines a CNN model using PyTorch. The model consists of a convolutional layer, followed by a gMLP block, flattening, and a linear layer for output. The gMLP is used to introduce non-linearity and capture complex relationships between features. The model can be trained on your dataset for multi-class classification.

PyTorch CNN with gMLP for Multi-class Classification

原文地址: https://www.cveoy.top/t/topic/pcT4 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录