PyTorch 数据集划分:训练集和验证集
PyTorch 数据集划分:训练集和验证集
在深度学习中,将数据集划分为训练集和验证集是至关重要的。训练集用于训练模型,而验证集用于评估模型的性能并调整超参数。
以下代码演示了如何使用 PyTorch 将 input_tensors 中的数据划分为训练集和验证集:
import torch
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data.dataset import random_split
# 指定pt文件的路径和训练数据
pt_file_path = 'path_to_your_pt_file.pt'
data = torch.load(pt_file_path)
# 将数据转换为张量
input_tensors = [torch.tensor(tensor) for tensor in data]
# 创建对应的标签张量
labels = [torch.tensor([1, 0, 0, 0]) if i == 0 else
torch.tensor([0, 1, 0, 0]) if i == 1 else
torch.tensor([0, 0, 1, 0]) if i == 2 else
torch.tensor([1, 1, 1, 1]) for i in range(len(data))]
# 划分训练集和验证集
train_size = int(0.8 * len(input_tensors))
val_size = len(input_tensors) - train_size
train_dataset, val_dataset = random_split(list(zip(input_tensors, labels)), [train_size, val_size])
# 创建训练集和验证集的数据加载器
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
# 在训练过程中使用训练集和验证集的数据加载器进行训练和验证
for epoch in range(num_epochs):
for batch_data, batch_labels in train_loader:
# 在训练集上进行训练
# ...
for batch_data, batch_labels in val_loader:
# 在验证集上进行验证
# ...
代码解释:
- 加载数据: 首先,使用
torch.load()加载 pt 文件中的数据。 - 创建张量: 将数据转换为 PyTorch 张量,以便在模型中使用。
- 创建标签: 根据您的需求创建对应的标签张量。
- 划分数据集: 使用
random_split()函数将数据划分为训练集和验证集。 - 创建 DataLoader: 使用
DataLoader创建训练集和验证集的数据加载器,以便在训练过程中方便地加载数据。
使用说明:
- 将
'path_to_your_pt_file.pt'替换为您的 pt 文件路径。 - 根据实际情况调整
batch_size和其他参数。
通过使用此代码,您可以轻松地将您的 PyTorch 数据集划分为训练集和验证集,为模型训练和评估做好准备。
原文地址: https://www.cveoy.top/t/topic/j9b 著作权归作者所有。请勿转载和采集!