对以下代码进行注释:n_train n_test num_inputs batch_size = 20 100 200 5true_w true_b = torchonesnum_inputs 1 001 005train_data = d2lsynthetic_datatrue_w true_b n_traintrain_iter = d2lload_arraytrain_data batch
定义训练数据集大小,测试数据集大小,输入特征数和批次大小
n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
定义真实的权重和偏差
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05
生成训练数据集
train_data = d2l.synthetic_data(true_w, true_b, n_train)
加载训练数据集,每次迭代取出 batch_size 个样本
train_iter = d2l.load_array(train_data, batch_size)
生成测试数据集
test_data = d2l.synthetic_data(true_w, true_b, n_test)
加载测试数据集,每次迭代取出 batch_size 个样本,但不进行训练
test_iter = d2l.load_array(test_data, batch_size, is_train=False)
原文地址: https://www.cveoy.top/t/topic/ckVM 著作权归作者所有。请勿转载和采集!