PyTorch 数据集划分:训练集与验证集的分割方法
当划分训练集和验证集时,可以使用torch.utils.data.random_split函数将数据集按比例划分为训练集和验证集。以下是一个示例代码:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import random_split
# 定义自定义数据集类
class MyDataset(Dataset):
def __init__(self):
# 加载数据集,初始化数据
def __getitem__(self, index):
# 获取数据集中的一个样本
def __len__(self):
# 获取数据集的样本数量
# 创建自定义数据集实例
dataset = MyDataset()
# 划分训练集和验证集
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
# 创建训练集和验证集的数据加载器
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
# 在训练过程中使用训练集和验证集的数据加载器进行训练和验证
for epoch in range(num_epochs):
for batch_data in train_loader:
# 在训练集上进行训练
for batch_data in val_loader:
# 在验证集上进行验证
上述代码中,首先定义了一个自定义的数据集类MyDataset,根据实际情况加载数据并实现__getitem__和__len__方法。
然后,使用random_split函数将数据集按比例划分为训练集和验证集,可以根据需要自行调整划分比例。
接下来,使用划分好的训练集和验证集创建数据加载器train_loader和val_loader,可以设置批量大小和是否打乱数据顺序。
最后,在训练过程中,使用训练集和验证集的数据加载器进行训练和验证。
希望这个示例能帮助到你,如果还有其他问题,请随时提问。
原文地址: https://www.cveoy.top/t/topic/j7I 著作权归作者所有。请勿转载和采集!