PyTorch 数据集划分:训练集与验证集
使用 PyTorch 划分训练集和验证集
本篇博客将介绍如何使用 PyTorch 将数据集划分为训练集和验证集,并创建 DataLoader 用于模型训练。
以下是代码示例:
import torch
from torch.utils.data import DataLoader, random_split
# 文件路径
file_path = 'C:\Users\18105\Desktop\MVSA-Single\MVSA_Single\biaoqian.txt'
pt_file_path = 'C:\Users\18105\PycharmProjects\tuwenqingganfenxi\expanded.pt'
# 读取标签数据
tensor_list = []
with open(file_path, 'r') as file:
lines = file.readlines()
for line in lines:
line = line.strip()
numbers = line.split()
tensor = torch.tensor([float(num) for num in numbers])
tensor_list.append(tensor)
# 读取预处理后的数据
data = torch.load(pt_file_path)
# 将数据转换为张量
input_tensors = [torch.tensor(tensor) for tensor in data]
# 划分训练集和验证集
train_size = int(0.8 * len(input_tensors))
val_size = len(input_tensors) - train_size
train_tensors, val_tensors = random_split(input_tensors, [train_size, val_size])
# 创建训练集和验证集的数据加载器
batch_size = 64
train_loader = DataLoader(train_tensors, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_tensors, batch_size=batch_size, shuffle=False)
代码解释:
- 导入必要的库: 导入
torch和DataLoader,random_split。 - 定义文件路径: 设置标签数据文件路径
file_path和预处理数据文件路径pt_file_path。 - 读取标签数据: 从
file_path读取标签数据,并将其转换为张量存储在tensor_list中。 - 读取预处理数据: 从
pt_file_path加载预处理后的数据。 - 将数据转换为张量: 确保数据类型为张量。
- 划分数据集: 使用
random_split函数将数据集划分为训练集和验证集,比例为 8:2。 - 创建 DataLoader: 使用
DataLoader创建训练集和验证集的数据加载器,用于迭代加载数据进行模型训练。
注意:
- 将
shuffle参数设置为True可以打乱训练集数据,提高模型泛化能力。 - 验证集数据不需要打乱。
希望这份代码能够帮助您更好地理解如何使用 PyTorch 划分数据集。如有任何疑问,请随时提出。
原文地址: https://www.cveoy.top/t/topic/klo 著作权归作者所有。请勿转载和采集!