使用 PyTorch 划分训练集和验证集

本篇博客将介绍如何使用 PyTorch 将数据集划分为训练集和验证集,并创建 DataLoader 用于模型训练。

以下是代码示例:

import torch
from torch.utils.data import DataLoader, random_split

# 文件路径
file_path = 'C:\Users\18105\Desktop\MVSA-Single\MVSA_Single\biaoqian.txt'
pt_file_path = 'C:\Users\18105\PycharmProjects\tuwenqingganfenxi\expanded.pt'

# 读取标签数据
tensor_list = []
with open(file_path, 'r') as file:
    lines = file.readlines()
    for line in lines:
        line = line.strip()
        numbers = line.split()
        tensor = torch.tensor([float(num) for num in numbers])
        tensor_list.append(tensor)

# 读取预处理后的数据
data = torch.load(pt_file_path)

# 将数据转换为张量
input_tensors = [torch.tensor(tensor) for tensor in data]

# 划分训练集和验证集
train_size = int(0.8 * len(input_tensors))
val_size = len(input_tensors) - train_size
train_tensors, val_tensors = random_split(input_tensors, [train_size, val_size])

# 创建训练集和验证集的数据加载器
batch_size = 64
train_loader = DataLoader(train_tensors, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_tensors, batch_size=batch_size, shuffle=False)

代码解释:

  1. 导入必要的库: 导入 torchDataLoader, random_split
  2. 定义文件路径: 设置标签数据文件路径 file_path 和预处理数据文件路径 pt_file_path
  3. 读取标签数据: 从 file_path 读取标签数据,并将其转换为张量存储在 tensor_list 中。
  4. 读取预处理数据: 从 pt_file_path 加载预处理后的数据。
  5. 将数据转换为张量: 确保数据类型为张量。
  6. 划分数据集: 使用 random_split 函数将数据集划分为训练集和验证集,比例为 8:2。
  7. 创建 DataLoader: 使用 DataLoader 创建训练集和验证集的数据加载器,用于迭代加载数据进行模型训练。

注意:

  • shuffle 参数设置为 True 可以打乱训练集数据,提高模型泛化能力。
  • 验证集数据不需要打乱。

希望这份代码能够帮助您更好地理解如何使用 PyTorch 划分数据集。如有任何疑问,请随时提出。

PyTorch 数据集划分:训练集与验证集

原文地址: https://www.cveoy.top/t/topic/klo 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录