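In ResNet models, a broadcast error often shows up at the 'Add' operation, where the output of the convolution branch is summed with the shortcut (identity) branch. 'Add' requires the shapes of its two input tensors to satisfy the broadcasting rules: each pair of corresponding dimensions must either be equal, or one of the two must have length 1.

The following code shows a typical ResNet residual block, together with the broadcast error raised at the 'Add' operation: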
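The broadcasting rule described above can be sketched in plain Python. The helper name `can_broadcast` is an assumption made for illustration, not a MindSpore API; it only mirrors the rule, and the shapes are the ones reported in the error message later in this post.

```python
def can_broadcast(x_shape, y_shape):
    """Return True if the two shapes are broadcast-compatible."""
    # Compare dimensions from the trailing side; a missing leading
    # dimension is treated as compatible with anything.
    for a, b in zip(reversed(x_shape), reversed(y_shape)):
        if a != b and a != 1 and b != 1:
            return False
    return True


# Shapes from the error message: 19 vs 23 is neither equal nor 1.
print(can_broadcast((25, 64, 19, 19), (25, 64, 23, 23)))  # False
# A length-1 dimension can be stretched to match the other side.
print(can_broadcast((25, 64, 19, 19), (25, 64, 1, 1)))    # True
```

For a residual connection, broadcasting a length-1 dimension is rarely what is wanted, so in practice the two shapes should simply be equal.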

class ResidualBlock(nn.Cell):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,   pad_mode='valid')
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1,   pad_mode='valid')
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample
        self.stride = stride

    def construct(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

Running this block raises the following error:

ValueError: For 'Add', x_shape and y_shape are supposed to broadcast, where broadcast means that x_shape[i] = 1 or -1 or y_shape[i] = 1 or -1 or x_shape[i] = y_shape[i], but now x_shape and y_shape can not broadcast, got i: -2, x_shape: [25, 64, 19, 19], y_shape: [25, 64, 23, 23].
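The two shapes in the message can be reproduced by hand. The sketch below uses the standard output-size formula for a convolution with explicit symmetric padding and dilation 1. The function name `conv2d_out_size` is hypothetical, and the 23x23 input size is taken from y_shape in the error.

```python
def conv2d_out_size(size, kernel_size, stride=1, padding=0):
    """Spatial output size of a convolution (dilation 1, symmetric padding)."""
    return (size + 2 * padding - kernel_size) // stride + 1


# pad_mode='valid' means no padding, so each 3x3 convolution removes 2 pixels.
h = 23                                  # identity branch keeps 23
h = conv2d_out_size(h, kernel_size=3)   # after conv1: 21
h = conv2d_out_size(h, kernel_size=3)   # after conv2: 19
print(h)  # 19
```

So `out` ends up 19x19 while `identity` is still 23x23, which is exactly the mismatch reported at dimension -2.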

One way to fix this is to add padding so that x_shape and y_shape match and can therefore broadcast.

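import mindspore.nn as nn
from mindspore.ops import functional as F


class ResidualBlock(nn.Cell):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        # MindSpore only accepts a non-zero `padding` when pad_mode='pad'.
        # With kernel_size=3 and stride=1, padding=1 keeps the spatial size.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,
                               pad_mode='pad', padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1,
                               pad_mode='pad', padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample
        self.stride = stride

    def construct(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        # Fallback: zero-pad identity to match the output shape.
        # This assumes identity is not larger than out; if it is
        # (e.g. stride > 1), pass a matching downsample instead.
        _, _, h, w = out.shape
        _, _, ih, iw = identity.shape
        if h != ih or w != iw:
            # Split the gap so that an odd difference is also fully covered.
            pad_top = (h - ih) // 2
            pad_bottom = (h - ih) - pad_top
            pad_left = (w - iw) // 2
            pad_right = (w - iw) - pad_left
            # Note: the paddings layout expected by F.pad differs between
            # MindSpore versions; check the documentation of the installed version.
            identity = F.pad(identity, ((0, 0), (0, 0), (pad_top, pad_bottom), (pad_left, pad_right)))

        out += identity
        out = self.relu(out)

        return out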

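Adding padding=1 to the Conv2d layers (in MindSpore this also requires pad_mode='pad') keeps the output of each 3x3, stride-1 convolution the same spatial size as its input, which resolves the broadcast problem. In addition, padding identity before the 'Add' operation guarantees that its shape matches the shape of out, so the broadcast error cannot occur.

Note that the exact padding values depend on the kernel size, the stride, and the input shape, and need to be adjusted so that the final output shape is the one you expect.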
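As a rough guide for choosing those values, the sketch below computes the per-side padding that preserves the spatial size for an odd kernel at stride 1, and splits a size gap into two pad amounts when the identity branch has to be padded. Both helper names (`same_padding`, `split_padding`) are assumptions introduced for illustration; they are not part of MindSpore.

```python
def same_padding(kernel_size):
    """Per-side padding that preserves size for stride 1, dilation 1, odd kernel."""
    if kernel_size % 2 == 0:
        raise ValueError("symmetric padding cannot preserve size for an even kernel")
    return (kernel_size - 1) // 2


def split_padding(target, current):
    """Split the gap into (before, after) with before + after == target - current."""
    diff = target - current
    if diff < 0:
        raise ValueError("current is larger than target; downsample instead of padding")
    before = diff // 2
    return before, diff - before


print(same_padding(3))         # 1
print(split_padding(23, 19))   # (2, 2)
print(split_padding(24, 19))   # (2, 3)
```

The last call shows why a plain `diff // 2` on both sides is not enough: an odd gap needs one extra pixel on one side, otherwise the shapes still differ by 1.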

Resolving the broadcast error of the 'Add' operation in ResNet

Original article: https://www.cveoy.top/t/topic/mH6u. Copyright belongs to the author.
