从本地文件夹中读取Fashion-MNIST数据(以CSV格式读取),训练集和测试集分别读取并构建Dataset。由第1步的Dataset生成Dataloader。遍历两次测试用的Dataloader,每次绘制出前10个样本,并标出对应的类别名称。二值化:选10张图像进行二值化,选择合适的阈值,将每个像素用0或1表示。边缘检测:使用2维卷积运算对二值化后的10张图像进行处理,实现边缘检测。降分辨率:任选一张原图……
以下是代码实现:
# 导入相关库
import pandas as pd import numpy as np import matplotlib.pyplot as plt import torch from torch.utils.data import Dataset, DataLoader from torchvision import transforms from PIL import Image
# 读取数据集
# Load the Fashion-MNIST training and test splits from CSV files that are
# resolved relative to the current working directory.
train_df, test_df = (
    pd.read_csv(csv_name)
    for csv_name in ('fashion-mnist_train.csv', 'fashion-mnist_test.csv')
)
# 构建Dataset
class FashionMNISTDataset(Dataset):
    """Map-style dataset over a DataFrame of Fashion-MNIST rows.

    Column 0 of every row is read as the class label and the remaining
    columns as the pixel values of one image, which are reshaped to
    ``(28, 28, 1)`` (so each row must carry 784 pixel values).

    Args:
        df: DataFrame with one sample per row (label first, pixels after).
        transform: Optional callable applied to the ``(28, 28, 1)`` uint8
            array before it is returned.
    """

    # The constructor must be spelled ``__init__``; a method named ``init``
    # is never invoked on instantiation, so passing ``df`` would raise
    # ``TypeError`` and ``self.df`` would never exist.
    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __len__(self):
        """Return the number of samples (rows of the DataFrame)."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the row at position ``idx``.

        The image is a ``(28, 28, 1)`` uint8 array, or whatever ``transform``
        returns for it when a transform was supplied. The label is a plain
        Python ``int`` so it can index lists directly.
        """
        label = int(self.df.iloc[idx, 0])
        pixels = self.df.iloc[idx, 1:].values.astype(np.uint8)
        img = pixels.reshape((28, 28, 1))
        if self.transform:
            img = self.transform(img)
        return img, label
# 定义变换
# Augmentation pipeline used for the training set. The steps run in order:
# array -> PIL image -> random rotation / flips -> tensor.
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.RandomRotation(90),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
    ]
)
# 构建Dataset和Dataloader
# Training samples go through the augmentation pipeline; test samples are
# left untransformed, so the test set yields the raw (28, 28, 1) uint8 arrays.
train_dataset = FashionMNISTDataset(train_df, transform=transform)
test_dataset = FashionMNISTDataset(test_df)

# Both loaders serve batches of 10; only the training loader shuffles.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=10, shuffle=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=10)
# 遍历测试用的Dataloader,绘制前10个样本
# Display names looked up by integer class label.
classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
           'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

# Walk the test DataLoader twice. Each pass plots only its first batch (the
# first 10 samples) instead of opening one figure per batch of the whole set.
for _ in range(2):
    images, labels = next(iter(test_dataloader))
    # squeeze=False keeps `axes` two-dimensional even for a batch of one.
    fig, axes = plt.subplots(1, len(images), figsize=(20, 2), squeeze=False)
    for ax, image, label in zip(axes[0], images, labels):
        # test_dataset has no transform, so each sample is (28, 28, 1);
        # squeeze() drops the size-1 channel axis wherever it sits, whereas
        # indexing with [0] would select a single pixel row here.
        ax.imshow(image.numpy().squeeze(), cmap='gray')
        ax.set_title(classes[int(label)])
        ax.axis('off')
    plt.show()
# 二值化
# Binarise 10 test images: pixels at or above `threshold` become 1, all
# others 0. The test set is untransformed, so pixel values are 0-255 and
# 128 splits that range in half.
threshold = 128

# Only the first batch (10 images) is needed, not every batch of the set.
images, labels = next(iter(test_dataloader))
for image, label in zip(images, labels):
    # squeeze() removes the size-1 channel axis to get a 2-D (28, 28) image.
    gray = image.numpy().squeeze()
    # One comparison yields a fresh 0/1 array and leaves the batch untouched.
    img = (gray >= threshold).astype(np.uint8)
    # Fixed colour limits keep 0 black and 1 white even for uniform images.
    plt.imshow(img, cmap='gray', vmin=0, vmax=1)
    plt.title(classes[int(label)])
    plt.axis('off')
    plt.show()
# 边缘检测
# Edge detection: convolve the binarised images with a 3x3 kernel whose
# weights sum to zero, so flat regions map to 0 and only intensity changes
# produce a response.
kernel = np.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]])

# Build the fixed-weight convolution once instead of once per image.
# padding=1 keeps the output the same size as the input.
conv = torch.nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, bias=False)
conv.weight.data = torch.from_numpy(kernel).float().unsqueeze(0).unsqueeze(0)

# Only the first batch (10 images) is processed.
images, labels = next(iter(test_dataloader))
for image, label in zip(images, labels):
    # Binarise first (the task asks for edges of the binarised images), using
    # the same `threshold` as the binarisation step. squeeze() drops the
    # size-1 channel axis so the image is 2-D.
    binary = (image.squeeze() >= threshold).float()
    # Conv2d is fed (batch, channel, height, width): add both leading axes.
    img = binary[None, None]
    out = conv(img).detach().squeeze().numpy()
    plt.imshow(out, cmap='gray')
    plt.title(classes[int(label)])
    plt.axis('off')
    plt.show()
# 降分辨率
# Down-sampling: halve the resolution of one test image with 2x2 max pooling.
pool = torch.nn.MaxPool2d(kernel_size=2, stride=2)

# One image is enough, so take the first batch rather than looping over the
# whole test set.
images, labels = next(iter(test_dataloader))
# squeeze() drops the size-1 channel axis of the untransformed sample; the
# two leading axes then give the (batch, channel, height, width) layout.
# float() is used because the raw samples are uint8.
img = images[0].squeeze().float()[None, None]
out = pool(img).detach().squeeze().numpy()
plt.imshow(out, cmap='gray')
plt.title(classes[int(labels[0])])
plt.axis('off')
plt.show()
# 升分辨率
# Up-sampling: double the resolution of one test image with nearest-neighbour
# interpolation.
upsample = torch.nn.Upsample(scale_factor=2, mode='nearest')

# One image is enough, so take the first batch rather than looping over the
# whole test set.
images, labels = next(iter(test_dataloader))
# Upsample decides which axes are spatial from the input rank: a 3-D input is
# read as (batch, channel, length) and only its last axis is stretched. The
# image is therefore passed as 4-D (batch, channel, height, width) so both
# height and width are doubled.
img = images[0].squeeze().float()[None, None]
out = upsample(img).detach().squeeze().numpy()
plt.imshow(out, cmap='gray')
plt.title(classes[int(labels[0])])
plt.axis('off')
plt.show()
# 旋转和翻转
# Rotation and flips of the first test sample.
sample, sample_label = test_dataset[0]

# np.asarray accepts both a NumPy array (what the untransformed test set
# returns) and a CPU tensor; squeeze() drops the size-1 channel axis.
img = np.asarray(sample).squeeze()
if np.issubdtype(img.dtype, np.floating):
    # Floating-point samples are taken to be in [0, 1] and are rescaled;
    # uint8 samples are already 0-255 and must not be multiplied again.
    img = np.uint8(img * 255)
pil_img = Image.fromarray(img)

# The titles promise fixed results, so the random transforms are pinned:
# a degree range of (90, 90) always rotates by 90 degrees and p=1.0 always
# applies the flip.
rot_90 = transforms.RandomRotation((90, 90))
img_rot_90 = rot_90(pil_img)

flip_ud = transforms.RandomVerticalFlip(p=1.0)
img_flip_ud = flip_ud(pil_img)

flip_lr = transforms.RandomHorizontalFlip(p=1.0)
img_flip_lr = flip_lr(pil_img)

# Show the original first, then each transformed version.
for picture, title in (
    (img, classes[int(sample_label)]),
    (img_rot_90, 'Rotated 90 degrees'),
    (img_flip_ud, 'Flipped up-down'),
    (img_flip_lr, 'Flipped left-right'),
):
    plt.imshow(picture, cmap='gray')
    plt.title(title)
    plt.axis('off')
    plt.show()
原文地址: https://www.cveoy.top/t/topic/dxhi 著作权归作者所有。请勿转载和采集!