PyTorch 代码注释:生成训练和测试数据集
定义训练数据集大小,测试数据集大小,输入特征数和批次大小
n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
定义真实的权重和偏差
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05
生成训练数据集
train_data = d2l.synthetic_data(true_w, true_b, n_train)
加载训练数据集,每次迭代取出 batch_size 个样本
train_iter = d2l.load_array(train_data, batch_size)
生成测试数据集
test_data = d2l.synthetic_data(true_w, true_b, n_test)
加载测试数据集,每次迭代取出 batch_size 个样本,但不进行训练
test_iter = d2l.load_array(test_data, batch_size, is_train=False)
原文地址: https://www.cveoy.top/t/topic/ntIR 著作权归作者所有。请勿转载和采集!