PyTorch 数据集划分:训练集和验证集

在深度学习中,将数据集划分为训练集和验证集是至关重要的。训练集用于训练模型,而验证集用于评估模型的性能并调整超参数。

以下代码演示了如何使用 PyTorch 将 input_tensors 中的数据划分为训练集和验证集:

import torch
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data.dataset import random_split

# 指定pt文件的路径和训练数据
pt_file_path = 'path_to_your_pt_file.pt'
data = torch.load(pt_file_path)

# 将数据转换为张量
input_tensors = [torch.tensor(tensor) for tensor in data]

# 创建对应的标签张量
labels = [torch.tensor([1, 0, 0, 0]) if i == 0 else
          torch.tensor([0, 1, 0, 0]) if i == 1 else
          torch.tensor([0, 0, 1, 0]) if i == 2 else
          torch.tensor([1, 1, 1, 1]) for i in range(len(data))]

# 划分训练集和验证集
train_size = int(0.8 * len(input_tensors))
val_size = len(input_tensors) - train_size
train_dataset, val_dataset = random_split(list(zip(input_tensors, labels)), [train_size, val_size])

# 创建训练集和验证集的数据加载器
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# 在训练过程中使用训练集和验证集的数据加载器进行训练和验证
for epoch in range(num_epochs):
    for batch_data, batch_labels in train_loader:
        # 在训练集上进行训练
        # ...

    for batch_data, batch_labels in val_loader:
        # 在验证集上进行验证
        # ...

代码解释:

  1. 加载数据: 首先,使用 torch.load() 加载 pt 文件中的数据。
  2. 创建张量: 将数据转换为 PyTorch 张量,以便在模型中使用。
  3. 创建标签: 根据您的需求创建对应的标签张量。
  4. 划分数据集: 使用 random_split() 函数将数据划分为训练集和验证集。
  5. 创建 DataLoader: 使用 DataLoader 创建训练集和验证集的数据加载器,以便在训练过程中方便地加载数据。

使用说明:

  • 'path_to_your_pt_file.pt' 替换为您的 pt 文件路径。
  • 根据实际情况调整 batch_size 和其他参数。

通过使用此代码,您可以轻松地将您的 PyTorch 数据集划分为训练集和验证集,为模型训练和评估做好准备。

PyTorch 数据集划分:训练集和验证集

原文地址: https://www.cveoy.top/t/topic/j9b 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录