ResNet 和 BiGRU 并行特征提取器用于多类别分类 (PyTorch 代码)
使用 ResNet 和 BiGRU 并行特征提取器进行多类别分类 (PyTorch 代码)
本代码演示了如何使用 ResNet 和 BiGRU 并行特征提取器,并使用特征融合和全连接层进行多类别分类。代码使用 PyTorch 实现,包含训练、验证和测试过程。
数据集描述:
- 三个 CSV 数据集:训练集、验证集和测试集。
- 每条数据包含 43 个特征,前 42 个为特征值,最后 1 个为类别标签。
- 类别标签共 10 个类别,分别为 0,1,2,...,9。
模型架构:
- **ResNet 特征提取器:**使用预训练的 ResNet18 模型作为特征提取器。原始数据需先将 1 维特征转换为 2 维特征,然后输入 ResNet。
- **BiGRU 特征提取器:**使用 BiGRU 模型作为特征提取器,输入为原始数据的 42 个特征。
- **特征融合:**将 ResNet 和 BiGRU 提取的特征进行拼接。
- **全连接层:**使用全连接层和 softmax 函数进行 10 类分类。
代码:
import torch
import torch.nn as nn
import torchvision.models as models
class ResNetFeatureExtractor(nn.Module):
def __init__(self):
super(ResNetFeatureExtractor, self).__init__()
self.resnet = models.resnet18(pretrained=True)
self.resnet.fc = nn.Identity()
def forward(self, x):
features = self.resnet(x)
return features.view(features.size(0), -1)
class BiGRUFeatureExtractor(nn.Module):
def __init__(self, input_size, hidden_size):
super(BiGRUFeatureExtractor, self).__init__()
self.gru = nn.GRU(input_size, hidden_size, bidirectional=True)
def forward(self, x):
x = x.permute(1, 0, 2) # Reshape input to (sequence_length, batch_size, input_size)
_, hidden = self.gru(x)
features = torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1)
return features
class FusionClassifier(nn.Module):
def __init__(self, input_size, num_classes):
super(FusionClassifier, self).__init__()
self.fc1 = nn.Linear(input_size, 128)
self.fc2 = nn.Linear(128, num_classes)
def forward(self, x):
x = self.fc1(x)
x = torch.relu(x)
x = self.fc2(x)
x = torch.softmax(x, dim=1)
return x
# Define the model
resnet_extractor = ResNetFeatureExtractor()
bigru_extractor = BiGRUFeatureExtractor(input_size=42, hidden_size=64)
input_size = resnet_extractor.resnet.fc.in_features + bigru_extractor.gru.hidden_size * 2
fusion_classifier = FusionClassifier(input_size, num_classes=10)
# Load and preprocess the data
# ...
# Assuming you have loaded the data into train_dataset, val_dataset, test_dataset
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(fusion_classifier.parameters(), lr=0.001)
# Training loop
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
fusion_classifier.to(device)
for epoch in range(num_epochs):
fusion_classifier.train()
running_loss = 0.0
for inputs, labels in train_dataset:
inputs = inputs.to(device)
labels = labels.to(device)
# Forward pass
resnet_features = resnet_extractor(inputs.unsqueeze(2))
bigru_features = bigru_extractor(inputs)
features = torch.cat((resnet_features, bigru_features), dim=1)
outputs = fusion_classifier(features)
# Compute loss
loss = criterion(outputs, labels)
# Backward pass and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
running_loss += loss.item()
# Validation
fusion_classifier.eval()
with torch.no_grad():
val_loss = 0.0
correct = 0
total = 0
for inputs, labels in val_dataset:
inputs = inputs.to(device)
labels = labels.to(device)
resnet_features = resnet_extractor(inputs.unsqueeze(2))
bigru_features = bigru_extractor(inputs)
features = torch.cat((resnet_features, bigru_features), dim=1)
outputs = fusion_classifier(features)
loss = criterion(outputs, labels)
val_loss += loss.item()
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
print(f"Epoch {epoch+1}/{num_epochs}: Training Loss: {running_loss:.3f}, Validation Loss: {val_loss:.3f}, Validation Accuracy: {accuracy:.2f}%")
# Test the model
fusion_classifier.eval()
with torch.no_grad():
test_loss = 0.0
correct = 0
total = 0
for inputs, labels in test_dataset:
inputs = inputs.to(device)
labels = labels.to(device)
resnet_features = resnet_extractor(inputs.unsqueeze(2))
bigru_features = bigru_extractor(inputs)
features = torch.cat((resnet_features, bigru_features), dim=1)
outputs = fusion_classifier(features)
loss = criterion(outputs, labels)
test_loss += loss.item()
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
print(f"Test Loss: {test_loss:.3f}, Test Accuracy: {accuracy:.2f}%")
注意:
- 上述代码中的数据加载、预处理和训练循环的部分需要根据实际数据集和需求进行修改。
- 确保已安装了 PyTorch 和 torchvision 库。
其他优化建议:
- 可以尝试使用不同的 ResNet 模型,例如 ResNet50 或 ResNet101,以获得更好的性能。
- 可以调整 BiGRU 的隐藏层大小和层数,以优化模型。
- 可以尝试使用不同的优化器,例如 SGD 或 AdamW。
- 可以使用数据增强技术,例如旋转、缩放和裁剪,来提高模型的鲁棒性。
- 可以使用学习率衰减策略,以加速训练过程。
- 可以使用早停法,以防止模型过拟合。
本代码提供了一个使用 ResNet 和 BiGRU 进行多类别分类的基本框架,您可以根据实际需求进行修改和优化。
原文地址: http://www.cveoy.top/t/topic/iU8L 著作权归作者所有。请勿转载和采集!