gMLP 模型架构:用作与 CNN-gMLP 架构进行性能对比的基准模型。
import torch
import torch.nn as nn
class SpatialGatingUnit(nn.Module):
    """Gate a sequence with a learned projection taken along the sequence axis.

    For an input ``x`` whose second-to-last axis has length ``seq_len``, the
    unit computes ``sigmoid(proj(x)) * x`` where ``proj`` is a linear map
    ``seq_len -> seq_len`` applied across positions. The gate therefore has
    exactly the same shape as ``x`` for any size of the last (channel) axis.

    Args:
        d_model: Accepted so existing ``SpatialGatingUnit(d_model, seq_len)``
            call sites keep working; the projection does not depend on it.
        seq_len: Length of the sequence axis (second-to-last axis of the
            input) that the gate projection mixes over.
    """

    def __init__(self, d_model, seq_len):
        super().__init__()
        # The projection acts on the sequence axis, so both its input and
        # output sizes are seq_len; this keeps the gate broadcast-free.
        self.fc = nn.Linear(seq_len, seq_len)

    def forward(self, x):
        """Return ``x`` scaled element-wise by its sequence-mixed sigmoid gate.

        Args:
            x: Tensor with at least two axes and ``x.size(-2) == seq_len``.
               # NOTE(review): callers presumably pass (batch, seq_len,
               # channels) — confirm against the training code.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``x`` has fewer than two axes or its
                second-to-last axis does not have length ``seq_len``.
        """
        expected_len = self.fc.in_features
        if x.dim() < 2 or x.size(-2) != expected_len:
            raise ValueError(
                "SpatialGatingUnit expects input of shape "
                f"(..., {expected_len}, channels), got {tuple(x.shape)}"
            )
        # nn.Linear works on the last axis, so move the sequence axis there,
        # project, and move it back before gating.
        gate = self.fc(x.transpose(-1, -2)).transpose(-1, -2)
        return torch.sigmoid(gate) * x
class gMLPBlock(nn.Module):
    """One gMLP block: expand channels, apply the gating unit, project back.

    Args:
        d_model: Size of the last axis of the block's input and output.
        d_ffn: Size of the last axis between ``fc1`` and ``fc2``.
        seq_len: Sequence length forwarded to the gating unit.
    """

    def __init__(self, d_model, d_ffn, seq_len):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ffn)
        self.fc2 = nn.Linear(d_ffn, d_model)
        # The gating unit receives the output of fc1, whose last axis has
        # size d_ffn, so it must be built with d_ffn rather than d_model.
        self.sgu = SpatialGatingUnit(d_ffn, seq_len)

    def forward(self, x):
        """Run ``fc1 -> sgu -> fc2`` on ``x`` and return the result.

        Args:
            x: Tensor whose last axis has size ``d_model``.

        Returns:
            Tensor whose last axis has size ``d_model`` again (set by
            ``fc2``).
        """
        hidden = self.fc1(x)
        gated = self.sgu(hidden)
        return self.fc2(gated)
class gMLP(nn.Module):
    """Sequential stack of ``num_layers`` gMLP blocks.

    Args:
        d_model: Passed to every block as its ``d_model``.
        d_ffn: Passed to every block as its ``d_ffn``.
        seq_len: Passed to every block as its ``seq_len``.
        num_layers: Number of blocks to create; ``0`` yields an empty stack.
    """

    def __init__(self, d_model, d_ffn, seq_len, num_layers):
        super().__init__()
        layers = []
        for _ in range(num_layers):
            layers.append(gMLPBlock(d_model, d_ffn, seq_len))
        # ModuleList registers each block so its parameters are tracked.
        self.gmlp_blocks = nn.ModuleList(layers)

    def forward(self, x):
        """Feed ``x`` through every block in order and return the final output."""
        out = x
        for layer in self.gmlp_blocks:
            out = layer(out)
        return out
您可以用此框架替换原始代码中的 gMLP 类。输入和输出形状将保持不变。
该 gMLP 模型可以作为 CNN-gMLP 架构的基准模型,以便比较两者的性能。gMLP 不使用自注意力机制,而是通过空间门控单元(Spatial Gating Unit)对特征进行门控,在没有卷积层的情况下仍然能够有效地提取特征。
原文地址: https://www.cveoy.top/t/topic/bKSA 著作权归作者所有。请勿转载和采集!