使用 ResNet 和 BiGRU 并行特征提取器进行多类别分类 (PyTorch 代码)

本代码演示了如何使用 ResNet 和 BiGRU 并行特征提取器,并使用特征融合和全连接层进行多类别分类。代码使用 PyTorch 实现,包含训练、验证和测试过程。

数据集描述:

  • 三个 CSV 数据集:训练集、验证集和测试集。
  • 每条数据包含 43 个特征,前 42 个为特征值,最后 1 个为类别标签。
  • 类别标签共 10 个类别,分别为 0,1,2,...,9。

模型架构:

  1. **ResNet 特征提取器:**使用预训练的 ResNet18 模型作为特征提取器。原始数据需先将 1 维特征转换为 2 维特征,然后输入 ResNet。
  2. **BiGRU 特征提取器:**使用 BiGRU 模型作为特征提取器,输入为原始数据的 42 个特征。
  3. **特征融合:**将 ResNet 和 BiGRU 提取的特征进行拼接。
  4. **全连接层:**使用全连接层和 softmax 函数进行 10 类分类。

代码:

import torch
import torch.nn as nn
import torchvision.models as models

class ResNetFeatureExtractor(nn.Module):
    def __init__(self):
        super(ResNetFeatureExtractor, self).__init__()
        self.resnet = models.resnet18(pretrained=True)
        self.resnet.fc = nn.Identity()
        
    def forward(self, x):
        features = self.resnet(x)
        return features.view(features.size(0), -1)

class BiGRUFeatureExtractor(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(BiGRUFeatureExtractor, self).__init__()
        self.gru = nn.GRU(input_size, hidden_size, bidirectional=True)
        
    def forward(self, x):
        x = x.permute(1, 0, 2) # Reshape input to (sequence_length, batch_size, input_size)
        _, hidden = self.gru(x)
        features = torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1)
        return features

class FusionClassifier(nn.Module):
    def __init__(self, input_size, num_classes):
        super(FusionClassifier, self).__init__()
        self.fc1 = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, num_classes)
        
    def forward(self, x):
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.fc2(x)
        x = torch.softmax(x, dim=1)
        return x

# Define the model
resnet_extractor = ResNetFeatureExtractor()
bigru_extractor = BiGRUFeatureExtractor(input_size=42, hidden_size=64)
input_size = resnet_extractor.resnet.fc.in_features + bigru_extractor.gru.hidden_size * 2
fusion_classifier = FusionClassifier(input_size, num_classes=10)

# Load and preprocess the data
# ...
# Assuming you have loaded the data into train_dataset, val_dataset, test_dataset

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(fusion_classifier.parameters(), lr=0.001)

# Training loop
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
fusion_classifier.to(device)
for epoch in range(num_epochs):
    fusion_classifier.train()
    running_loss = 0.0
    for inputs, labels in train_dataset:
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        # Forward pass
        resnet_features = resnet_extractor(inputs.unsqueeze(2))
        bigru_features = bigru_extractor(inputs)
        features = torch.cat((resnet_features, bigru_features), dim=1)
        outputs = fusion_classifier(features)
        
        # Compute loss
        loss = criterion(outputs, labels)
        
        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    # Validation
    fusion_classifier.eval()
    with torch.no_grad():
        val_loss = 0.0
        correct = 0
        total = 0
        for inputs, labels in val_dataset:
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            resnet_features = resnet_extractor(inputs.unsqueeze(2))
            bigru_features = bigru_extractor(inputs)
            features = torch.cat((resnet_features, bigru_features), dim=1)
            outputs = fusion_classifier(features)
            
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        
        accuracy = 100 * correct / total
        print(f"Epoch {epoch+1}/{num_epochs}: Training Loss: {running_loss:.3f}, Validation Loss: {val_loss:.3f}, Validation Accuracy: {accuracy:.2f}%")

# Test the model
fusion_classifier.eval()
with torch.no_grad():
    test_loss = 0.0
    correct = 0
    total = 0
    for inputs, labels in test_dataset:
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        resnet_features = resnet_extractor(inputs.unsqueeze(2))
        bigru_features = bigru_extractor(inputs)
        features = torch.cat((resnet_features, bigru_features), dim=1)
        outputs = fusion_classifier(features)
        
        loss = criterion(outputs, labels)
        test_loss += loss.item()
        
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    
    accuracy = 100 * correct / total
    print(f"Test Loss: {test_loss:.3f}, Test Accuracy: {accuracy:.2f}%")

注意:

  • 上述代码中的数据加载、预处理和训练循环的部分需要根据实际数据集和需求进行修改。
  • 确保已安装了 PyTorch 和 torchvision 库。

其他优化建议:

  • 可以尝试使用不同的 ResNet 模型,例如 ResNet50 或 ResNet101,以获得更好的性能。
  • 可以调整 BiGRU 的隐藏层大小和层数,以优化模型。
  • 可以尝试使用不同的优化器,例如 SGD 或 AdamW。
  • 可以使用数据增强技术,例如旋转、缩放和裁剪,来提高模型的鲁棒性。
  • 可以使用学习率衰减策略,以加速训练过程。
  • 可以使用早停法,以防止模型过拟合。

本代码提供了一个使用 ResNet 和 BiGRU 进行多类别分类的基本框架,您可以根据实际需求进行修改和优化。

ResNet 和 BiGRU 并行特征提取器用于多类别分类 (PyTorch 代码)

原文地址: http://www.cveoy.top/t/topic/iU8L 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录