PyTorch 中在三层卷积后添加 ResNet 模块的示例
以下是一个示例代码,展示了如何在三层卷积后加入 ResNet:
import torch
import torch.nn as nn
import torch.nn.functional as F
class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride=1):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_channels)
self.shortcut = nn.Sequential()
if stride != 1 or in_channels != out_channels:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(out_channels)
)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
out = F.relu(out)
return out
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
self.bn1 = nn.BatchNorm2d(32)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2d(64)
self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
self.bn3 = nn.BatchNorm2d(128)
self.res1 = nn.Sequential(ResidualBlock(128, 128), ResidualBlock(128, 128))
self.res2 = nn.Sequential(ResidualBlock(128, 256, stride=2), ResidualBlock(256, 256))
self.res3 = nn.Sequential(ResidualBlock(256, 512, stride=2), ResidualBlock(512, 512))
self.fc = nn.Linear(512, 10)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = F.max_pool2d(out, 2)
out = F.relu(self.bn2(self.conv2(out)))
out = F.max_pool2d(out, 2)
out = F.relu(self.bn3(self.conv3(out)))
out = self.res1(out)
out = self.res2(out)
out = self.res3(out)
out = F.avg_pool2d(out, 4)
out = out.view(out.size(0), -1)
out = self.fc(out)
return out
这个示例代码定义了一个包含三个卷积层和三个 ResNet 块的神经网络。在每个 ResNet 块中,输入通过两个卷积层和一个跳跃连接运行,以形成输出。这个网络的最后一层是一个全连接层,用于将输出映射到分类标签。
原文地址: https://www.cveoy.top/t/topic/m1iC 著作权归作者所有。请勿转载和采集!