from torchvision import datasets
from xml.etree.ElementInclude import default_loader
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import numpy as np
import os


class MNISTDataset(Dataset):
    """Dataset that pairs ``.png`` images with ``.txt`` label files.

    Images are listed from ``image_folder`` and labels from ``label_folder``.
    Both path lists are sorted and then paired by position, so the i-th image
    (in sorted order) is matched with the i-th label file.

    Labels are returned as the stripped text content of the label file;
    callers that need integer class indices must convert them.

    Args:
        image_folder: Directory containing the ``.png`` image files.
        label_folder: Directory containing the ``.txt`` label files.
        transform: Optional callable applied to each loaded PIL image.

    Raises:
        ValueError: If the two folders contain a different number of
            matching files, because positional pairing would be misaligned.
    """

    def __init__(self, image_folder, label_folder, transform=None):
        self.image_folder = image_folder
        self.label_folder = label_folder
        self.transform = transform

        # Collect the sorted image and label paths from both folders.
        self.image_filenames, self.label_filenames = self.get_filenames()

        # __len__ counts labels while __getitem__ indexes images, so unequal
        # counts would surface later as an IndexError or as silently wrong
        # image/label pairs. Fail early with a clear message instead.
        if len(self.image_filenames) != len(self.label_filenames):
            raise ValueError(
                f"Found {len(self.image_filenames)} images in "
                f"{self.image_folder!r} but {len(self.label_filenames)} "
                f"labels in {self.label_folder!r}; counts must match."
            )

    def __len__(self):
        """Return the number of (image, label) pairs."""
        return len(self.label_filenames)

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair at ``index``.

        The image is passed through ``self.transform`` when one was given.
        """
        # Image.open is lazy and keeps the file handle open; load the pixel
        # data inside the context manager so the handle is closed promptly.
        with Image.open(self.image_filenames[index]) as image:
            image.load()
        label = self.load_label(self.label_filenames[index])

        # Apply the preprocessing pipeline, if any.
        if self.transform is not None:
            image = self.transform(image)

        return image, label

    def get_filenames(self):
        """Return ``(image_paths, label_paths)``, each sorted by path.

        Only ``.png`` files are taken from the image folder and only
        ``.txt`` files from the label folder; other entries are ignored.
        """
        image_filenames = sorted(
            os.path.join(self.image_folder, name)
            for name in os.listdir(self.image_folder)
            if name.endswith(".png")
        )
        label_filenames = sorted(
            os.path.join(self.label_folder, name)
            for name in os.listdir(self.label_folder)
            if name.endswith(".txt")
        )

        # Sorting both lists the same way is what keeps images and labels
        # aligned by position.
        return image_filenames, label_filenames

    def load_label(self, filename):
        """Read a label file and return its content with whitespace stripped."""
        # The label file is assumed to be plain text.
        with open(filename, "r") as f:
            label = f.read().strip()

        # Any task-specific label parsing would go here.
        return label


# Build the training dataset object.
# NOTE: this runs at import time and lists both folders on disk.
train_image_folder = './data/train'
train_label_folder = './data/train-labels'

# The transform must be a callable pipeline. Passing the bare `transforms`
# module (as before) raises "TypeError: 'module' object is not callable" on
# the first item access.
dataset = MNISTDataset(
    train_image_folder,
    train_label_folder,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]),
)


class CNN(nn.Module):
    """Small convolutional classifier for single-channel 28x28 images.

    Architecture: two ``Conv2d(3x3, padding=1) -> ReLU -> MaxPool(2x2)``
    stages (1 -> 16 -> 32 channels) followed by two fully connected layers.
    Each pooling stage halves the spatial size (28 -> 14 -> 7), which is why
    the first linear layer expects ``7 * 7 * 32`` input features.

    Args:
        num_classes: Number of output logits. Defaults to 1, which keeps the
            original single-output head; pass e.g. 10 for digit
            classification.
    """

    def __init__(self, num_classes=1):
        super().__init__()
        # Attribute names are kept stable so saved state_dict keys still load.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(7 * 7 * 32, 128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return raw logits of shape ``(batch, num_classes)``.

        ``x`` must be shaped ``(batch, 1, 28, 28)`` so that the flattened
        feature map matches the input size of ``fc1``.
        """
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        # Flatten everything except the batch dimension.
        x = torch.flatten(x, start_dim=1)
        x = self.relu3(self.fc1(x))
        return self.fc2(x)


def train(model, train_loader, criterion, optimizer, device):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Module to optimise; it is switched to training mode.
        train_loader: Iterable yielding ``(images, labels)`` tensor batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer bound to the model's parameters.
        device: Device the batches are moved to before the forward pass.

    Returns:
        The average of the per-batch loss values, or ``0.0`` when the loader
        yields no batches.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    # Count the batches actually seen rather than calling len(): this avoids
    # a ZeroDivisionError on an empty loader and also supports iterables
    # that do not define __len__.
    if num_batches == 0:
        return 0.0
    return running_loss / num_batches


def test(model, test_loader, device):
    """Evaluate ``model`` and return its classification accuracy.

    Args:
        model: Module to evaluate; it is switched to evaluation mode.
        test_loader: Iterable yielding ``(images, labels)`` tensor batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        The fraction of samples whose arg-max prediction equals the label,
        or ``0.0`` when the loader yields no samples.
    """
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)

            # Gradients are already disabled here, so the legacy `.data`
            # access is unnecessary; take the arg-max over the class axis.
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Guard against an empty loader instead of raising ZeroDivisionError.
    if total == 0:
        return 0.0
    return correct / total


# The public name starts with "test", so pytest would try to collect this
# function as a test case whenever the module is imported into a test file.
# Flag it as a non-test to prevent spurious "fixture not found" errors.
test.__test__ = False


def main():
    """Train the CNN on PNG image folders and save the learned weights.

    Reads training images from ``./data/train`` and validation images from
    ``./data/validation`` (one sub-folder per class, as required by
    ``DatasetFolder``), trains for a fixed number of epochs while printing
    the training loss and validation accuracy, then writes the model's
    ``state_dict`` to ``./train/model.h5``.
    """
    # Local import: the module-level `default_loader` comes from
    # xml.etree.ElementInclude, which loads XML includes and cannot read
    # images. torchvision's loader is the one DatasetFolder needs.
    from torchvision.datasets.folder import default_loader

    # Build the paths with os.path.join. The previous literals ended in a
    # backslash, which escaped the closing quote (a syntax error), and
    # contained "\t" / "\v" escape sequences that would corrupt the path.
    train_folder = os.path.join('.', 'data', 'train')
    validation_folder = os.path.join('.', 'data', 'validation')
    batch_size = 64
    num_epochs = 10
    learning_rate = 0.001
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    transform = transforms.Compose([
        # torchvision's default loader returns RGB images, but the first
        # convolution of CNN takes a single input channel.
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    train_dataset = datasets.DatasetFolder(
        train_folder, loader=default_loader, extensions='.png', transform=transform)
    validation_dataset = datasets.DatasetFolder(
        validation_folder, loader=default_loader, extensions='.png', transform=transform)

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True)
    validation_loader = DataLoader(
        validation_dataset, batch_size=batch_size, shuffle=False)

    # NOTE(review): CNN() is built with its default output size; confirm it
    # matches the number of class sub-folders in the dataset.
    model = CNN().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        train_loss = train(model, train_loader, criterion, optimizer, device)
        validation_acc = test(model, validation_loader, device)

        print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Validation Accuracy: {validation_acc:.4f}')

    # torch.save does not create missing directories, so make sure the
    # output folder exists before writing the weights.
    model_path = './train/model.h5'
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    torch.save(model.state_dict(), model_path)


# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()

# 请注意:这段代码仅适用于只有一个类别的训练集和测试集,且训练、测试集的文件名须按顺序编号(例如 image_0.png、image_1.png)。如果您的训练、测试集不符合这些条件,请提供更多关于数据集的详细信息,以便我能够帮助您进行修改。

# MNIST 数据集分类模型训练和评估

# 原文地址:https://www.cveoy.top/t/topic/pY6H (著作权归作者所有,请勿转载和采集!)

# 免费 AI 点我,无需注册和登录