import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100
from torchvision.models import resnet101

# Hyperparameters
num_classes = 100  # number of classes in the CIFAR-100 dataset
batch_size = 32  # samples per mini-batch for both data loaders

# Data preprocessing: resize to 256, center-crop to 224x224, convert to a
# tensor and normalize each channel.
# NOTE(review): these mean/std values look like the statistics commonly quoted
# for CIFAR-10 -- confirm they are the intended ones for CIFAR-100 with an
# ImageNet-pretrained backbone.
data_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2023, 0.1994, 0.2010],
    ),
])

# Load the CIFAR-100 training and validation splits (downloaded on demand).
# NOTE(review): the two splits use different root directories, so the archive
# is presumably fetched once per root -- confirm whether a shared root was
# intended.
train_dataset = CIFAR100(
    root='path/to/train_data',
    train=True,
    download=True,
    transform=data_transforms,
)
val_dataset = CIFAR100(
    root='path/to/val_data',
    train=False,
    download=True,
    transform=data_transforms,
)

# Create the data loaders; only the training set is shuffled.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

# Load the pretrained ResNet-101 model.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases
# in favor of the `weights=` argument; kept unchanged here so the script still
# runs on the torchvision version it was written for.
model = resnet101(pretrained=True)

# Replace the final fully connected layer with a new classifier head that
# outputs one logit per target class.
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Define the loss function (cross-entropy over the class logits).
criterion = nn.CrossEntropyLoss()

# Define the optimizer: Adam over every model parameter.
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Define the learning-rate scheduler: multiply the learning rate by 0.1 after
# every 7 scheduler steps.
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# Training setup: run on the GPU when CUDA is available, otherwise on the CPU,
# and move the model onto that device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Train for a fixed number of epochs; each epoch runs one pass over the
# training data, then measures accuracy on the validation data.
num_epochs = 10
for epoch in range(num_epochs):
    # ---- training phase ----
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        # Clear gradients accumulated by the previous batch.
        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Mean of the per-batch losses over the epoch.
    epoch_loss = running_loss / len(train_loader)

    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}')

    # ---- validation phase ----
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            # Predicted class is the index of the largest logit per sample.
            predicted = outputs.argmax(dim=1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_accuracy = 100 * correct / total
    print(f'Validation Accuracy: {val_accuracy:.2f}%')

    # Advance the learning-rate schedule once per epoch.
    lr_scheduler.step()
# 使用PyTorch进行CNN图像处理:ResNet-101模型训练与CIFAR-100数据集应用
#
# 原文地址: http://www.cveoy.top/t/topic/byuC 著作权归作者所有。请勿转载和采集!
#
# 免费AI点我,无需注册和登录