用pytorch的CBAM的代码实现
以下是使用PyTorch实现CBAM模块的示例代码:
import torch
import torch.nn as nn
class ChannelAttention(nn.Module):
def __init__(self, in_planes, ratio=16):
super(ChannelAttention, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
self.relu1 = nn.ReLU()
self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
out = avg_out + max_out
return self.sigmoid(out)
class SpatialAttention(nn.Module):
def __init__(self, kernel_size=7):
super(SpatialAttention, self).__init__()
assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
padding = 3 if kernel_size == 7 else 1
self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
avg_out = torch.mean(x, dim=1, keepdim=True)
max_out, _ = torch.max(x, dim=1, keepdim=True)
x = torch.cat([avg_out, max_out], dim=1)
x = self.conv1(x)
return self.sigmoid(x)
class CBAM(nn.Module):
def __init__(self, in_planes, ratio=16, kernel_size=7):
super(CBAM, self).__init__()
self.channel_att = ChannelAttention(in_planes, ratio)
self.spatial_att = SpatialAttention(kernel_size)
def forward(self, x):
x = x * self.channel_att(x)
x = x * self.spatial_att(x)
return x
在上面的代码中,我们首先定义了一个称为ChannelAttention的类,它实现了通道注意力模块。该模块包括一个自适应平均池化层和一个自适应最大池化层,两个卷积层和一个sigmoid函数。在前向传递过程中,我们首先将输入张量分别传递到平均池化层和最大池化层中,然后将两个池化层的输出分别传递到两个卷积层中,并对其进行ReLU激活。然后将这两个卷积层的输出相加,并通过sigmoid函数进行输出。在这个过程中,我们使用的是in_planes和ratio两个超参数,其中in_planes表示输入张量的通道数,ratio表示通道注意力模块中降低通道数的因子。
接下来,我们定义了一个称为SpatialAttention的类,它实现了空间注意力模块。该模块包括一个2通道的卷积层和一个sigmoid函数。在前向传递过程中,我们首先将输入张量分别求出其平均值和最大值,并将这两个张量拼接起来。然后将这个拼接张量传递到2通道的卷积层中,并通过sigmoid函数进行输出。在这个过程中,我们使用的是kernel_size这个超参数,它表示卷积核的大小。
最后,我们定义了一个称为CBAM的类,它将通道注意力模块和空间注意力模块组合起来。在前向传递过程中,我们首先将输入张量分别传递到通道注意力模块和空间注意力模块中,并将它们的输出相乘作为最终输出。在这个过程中,我们使用的是in_planes、ratio和kernel_size这三个超参数
原文地址: https://www.cveoy.top/t/topic/cpFd 著作权归作者所有。请勿转载和采集!