Fixing the 'Add' broadcasting error in a ResNet
In ResNet models, a broadcasting error often occurs at the 'Add' operation. The 'Add' operation requires the shapes of its two input tensors to satisfy the broadcasting rules: each pair of corresponding dimensions must either be equal, or one of the two must have length 1.
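As a quick illustration of these rules (using NumPy, whose broadcasting semantics match what the 'Add' operator enforces):

```python
import numpy as np

a = np.ones((25, 64, 19, 19))
b = np.ones((25, 64, 1, 1))   # trailing dims of length 1 broadcast
print((a + b).shape)          # (25, 64, 19, 19)

c = np.ones((25, 64, 23, 23))
try:
    a + c                     # 19 vs 23: neither equal nor length 1
except ValueError as e:
    print("broadcast failed:", e)
```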
The following code shows a typical ResNet residual block and the broadcasting error raised by its 'Add' operation:
import mindspore.nn as nn

class ResidualBlock(nn.Cell):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        # pad_mode='valid' shrinks the feature map by 2 pixels per 3x3 conv
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, pad_mode='valid')
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, pad_mode='valid')
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample
        self.stride = stride

    def construct(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity  # fails: out is now smaller than identity
        out = self.relu(out)
        return out
ValueError: For 'Add', x_shape and y_shape are supposed to broadcast, where broadcast means that x_shape[i] = 1 or -1 or y_shape[i] = 1 or -1 or x_shape[i] = y_shape[i], but now x_shape and y_shape can not broadcast, got i: -2, x_shape: [25, 64, 19, 19], y_shape: [25, 64, 23, 23].
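The shapes in the error message follow directly from the convolution arithmetic: with pad_mode='valid', each 3x3 convolution at stride 1 shrinks the feature map by 2 pixels, while the identity branch is untouched. A small sketch of the standard output-size formula (the helper name is mine):

```python
def conv2d_out_size(size, kernel=3, stride=1, padding=0):
    # Standard convolution output formula:
    # floor((size + 2*padding - kernel) / stride) + 1
    return (size + 2 * padding - kernel) // stride + 1

h = 23                   # input height/width
h = conv2d_out_size(h)   # after conv1 (valid, no padding): 21
h = conv2d_out_size(h)   # after conv2: 19
print(h)                 # 19, while identity is still 23 -> Add fails

# With padding=1, each 3x3 conv at stride 1 preserves the size:
h2 = conv2d_out_size(conv2d_out_size(23, padding=1), padding=1)
print(h2)                # 23
```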
We can add padding so that x_shape and y_shape become broadcastable:
import mindspore.nn as nn
from mindspore import ops

class ResidualBlock(nn.Cell):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        # MindSpore requires pad_mode='pad' when padding != 0;
        # padding=1 keeps the spatial size of a 3x3 conv at stride 1
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, pad_mode='pad', padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, pad_mode='pad', padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample
        self.stride = stride

    def construct(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        # Safety net: pad identity to match out if a size mismatch remains.
        # This only works when out is at least as large as identity.
        _, _, h, w = out.shape
        _, _, ih, iw = identity.shape
        if h != ih or w != iw:
            pad_top = (h - ih) // 2
            pad_bottom = h - ih - pad_top   # odd differences go to the far side
            pad_left = (w - iw) // 2
            pad_right = w - iw - pad_left
            # ops.pad takes flat padding ordered from the last dimension:
            # (left, right, top, bottom)
            identity = ops.pad(identity, (pad_left, pad_right, pad_top, pad_bottom))
        out += identity
        out = self.relu(out)
        return out
By setting padding=1 on the Conv2d layers (MindSpore additionally requires pad_mode='pad' whenever padding is nonzero), each 3x3 convolution preserves the spatial size of its input, so out and identity keep matching shapes and the broadcasting problem disappears. As a safety net, padding identity before the 'Add' brings its shape in line with out and avoids the broadcasting error.
Note that the exact padding values depend on kernel size and stride (for an odd kernel at stride 1, padding = (kernel_size - 1) // 2 preserves the size) and must be adjusted so the final output shape meets your requirements.
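One detail worth checking: splitting the size difference as (h - ih) // 2 on both sides silently loses a pixel when the difference is odd. A hypothetical helper (the name is mine) that splits the difference asymmetrically instead:

```python
def pad_amounts(target, current):
    # Split the total difference into (before, after); the extra pixel
    # of an odd difference goes to the "after" side.
    diff = target - current
    assert diff >= 0, "identity is larger than out; padding cannot shrink it"
    return diff // 2, diff - diff // 2

print(pad_amounts(23, 23))  # (0, 0) - shapes already match
print(pad_amounts(23, 19))  # (2, 2) - even difference, symmetric
print(pad_amounts(23, 20))  # (1, 2) - odd difference, asymmetric
```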