以下是源代码,在查看的时候发生了错误。数据集的路径是 /kaggle/input/shisuncat/cat_12_test/cat_12_test 和 /kaggle/input/shisuncat/cat_12_train/cat_12_train,应该怎么修改以下的错误?源代码:
# 数据科学包
import random  # 随机切分数据集
import numpy as np
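需要修改的地方是读取数据集时使用的路径,应改为 /kaggle/input/shisuncat/cat_12_test/cat_12_test 和 /kaggle/input/shisuncat/cat_12_train/cat_12_train。此外,代码中的双下划线很可能在网页转换时被当作加粗标记吞掉了:`init` 应为 `__init__`(否则构造时不会被调用,`self.color_jitter` 等属性不存在),`getitem` 应为 `__getitem__`(或直接写 `train_data[98]`);最后一行的 `plt.show(` 也缺少右括号。具体修改如下: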
# 按比例随机切分数据集
# Fraction of listed samples assigned to the training split; the remainder
# becomes the validation split (0.9 train / 0.1 validation).
train_ratio = 0.9


def _split_list_file(list_path, image_dir, ratio, rng=random):
    """Randomly split an image list file into training and validation sets.

    Each non-blank line of ``list_path`` must start with an image path
    (relative to ``image_dir``) followed by an integer label. Fields are
    separated by any whitespace, so both space- and tab-separated list files
    are accepted.

    Args:
        list_path: Path of the list file to read.
        image_dir: Prefix prepended verbatim to each listed image path
            (include the trailing slash).
        ratio: Probability that a given line is assigned to the training split.
        rng: Object providing ``uniform(a, b)``; defaults to the global
            ``random`` module. Pass ``random.Random(seed)`` for a reproducible
            split.

    Returns:
        Tuple ``(train_paths, train_labels, valid_paths, valid_labels)``.

    Raises:
        IndexError: If a non-blank line has no label field.
        ValueError: If a label field is not an integer.
    """
    t_paths, t_labels, v_paths, v_labels = [], [], [], []
    with open(list_path, 'r') as f:
        for line in f:
            # split() with no argument splits on any run of whitespace and
            # drops the trailing newline; split(' ') broke on tab separators.
            fields = line.split()
            if not fields:
                # Blank lines (e.g. an empty last line) used to raise IndexError.
                continue
            path = image_dir + fields[0]
            label = int(fields[1])
            if rng.uniform(0, 1) < ratio:
                t_paths.append(path)
                t_labels.append(label)
            else:
                v_paths.append(path)
                v_labels.append(label)
    return t_paths, t_labels, v_paths, v_labels


# NOTE(review): if the list entries already start with "cat_12_train/", this
# prefix would repeat that directory — verify against the actual list file.
train_paths, train_labels, valid_paths, valid_labels = _split_list_file(
    '/kaggle/input/shisuncat/cat_12_train/cat_12_train_list.txt',
    '/kaggle/input/shisuncat/cat_12_train/cat_12_train/',
    train_ratio,
)
# 定义训练数据集
class TrainData(Dataset):
    """Training dataset yielding augmented ``(features, label)`` pairs.

    ``features`` is a float32 array in C, H, W order, produced by colour
    jitter, a random 224-pixel crop (padded when the image is smaller) and
    normalisation. ``Dataset`` and ``T`` are not imported in the visible code;
    presumably they are ``paddle.io.Dataset`` and ``paddle.vision.transforms``
    — TODO confirm.
    """

    def __init__(self, paths=None, labels=None):
        """Build the augmentation pipeline.

        Args:
            paths: Image file paths; defaults to the module-level ``train_paths``.
            labels: Integer labels aligned with ``paths``; defaults to the
                module-level ``train_labels``.
        """
        # Must be ``__init__`` (not ``init``): only the dunder is run on
        # construction, otherwise the transforms below never exist.
        super().__init__()
        self.paths = train_paths if paths is None else paths
        self.labels = train_labels if labels is None else labels
        self.color_jitter = T.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05)
        # (x - 0) / 1 leaves pixel values unchanged, i.e. still in 0..255.
        self.normalize = T.Normalize(mean=0, std=1)
        self.random_crop = T.RandomCrop(224, pad_if_needed=True)

    @staticmethod
    def _read_rgb(image_path):
        """Load ``image_path`` as a uint8 H, W, 3 RGB array, closing the file."""
        # convert('RGB') handles grayscale, palette, RGBA and CMYK images
        # uniformly; the old transpose-or-stack guess under a bare ``except``
        # mis-handled palette/CMYK/LA modes and hid unrelated errors.
        with Image.open(image_path) as img:
            return np.array(img.convert('RGB'))

    def __getitem__(self, index):
        """Return the augmented ``(features, label)`` pair at ``index``."""
        image = self._read_rgb(self.paths[index])  # H, W, C
        # Colour jitter and cropping are applied to the H, W, C array.
        features = self.color_jitter(image)
        features = self.random_crop(features)
        # Normalise in C, H, W order, then cast for the model.
        features = self.normalize(features.transpose([2, 0, 1])).astype(np.float32)
        return features, self.labels[index]

    def __len__(self):
        """Return the number of training samples."""
        return len(self.paths)
# 定义验证数据集
class ValidData(Dataset):
    """Validation dataset yielding deterministic ``(features, label)`` pairs.

    ``features`` is a float32 array in C, H, W order: the whole image resized
    to 256x256, then normalised. ``Dataset`` and ``T`` are not imported in the
    visible code; presumably ``paddle.io.Dataset`` and
    ``paddle.vision.transforms`` — TODO confirm.

    NOTE(review): training uses a 224x224 random crop of the native-resolution
    image, while validation resizes the whole image to 256x256. Confirm this
    train/eval mismatch is intended.
    """

    def __init__(self, paths=None, labels=None):
        """Build the preprocessing pipeline.

        Args:
            paths: Image file paths; defaults to the module-level ``valid_paths``.
            labels: Integer labels aligned with ``paths``; defaults to the
                module-level ``valid_labels``.
        """
        # Must be ``__init__`` (not ``init``) so it actually runs on
        # construction and ``self.normalize`` exists.
        super().__init__()
        self.paths = valid_paths if paths is None else paths
        self.labels = valid_labels if labels is None else labels
        # (x - 0) / 1 leaves pixel values unchanged, i.e. still in 0..255.
        self.normalize = T.Normalize(mean=0, std=1)

    @staticmethod
    def _read_rgb(image_path):
        """Load ``image_path`` as a uint8 H, W, 3 RGB array, closing the file."""
        # convert('RGB') handles grayscale, palette, RGBA and CMYK images
        # uniformly instead of a transpose-or-stack guess under a bare except.
        with Image.open(image_path) as img:
            return np.array(img.convert('RGB'))

    def __getitem__(self, index):
        """Return the preprocessed ``(features, label)`` pair at ``index``."""
        image = self._read_rgb(self.paths[index])  # H, W, C (contiguous)
        # cv2.resize takes (width, height) and works on H, W, C arrays.
        resized = cv2.resize(image, (256, 256))
        features = resized.transpose([2, 0, 1]).astype(np.float32)  # C, H, W
        features = self.normalize(features)
        return features, self.labels[index]

    def __len__(self):
        """Return the number of validation samples."""
        return len(self.paths)
# Visual sanity check: display one augmented training sample.
train_data = TrainData()
# Indexing dispatches to __getitem__; there is no ``getitem`` method, so
# ``train_data.getitem(98)`` raised AttributeError.
img, labels = train_data[98]
plt.figure(dpi=40, figsize=(16, 16))
# Features are float C, H, W with values in 0..255; imshow wants uint8 H, W, C.
plt.imshow(img.astype(np.uint8).transpose([1, 2, 0]))
plt.show()
原文地址: http://www.cveoy.top/t/topic/hehF 著作权归作者所有。请勿转载和采集!