PyTorch 代码添加进度条:使用 tqdm 库实现训练过程可视化
使用 tqdm 库可以方便地在 PyTorch 代码中添加进度条,实现训练过程的可视化,方便观察训练进度和损失函数变化。
from tqdm import tqdm
d_epoch_loss = 0
g_epoch_loss = 0
count = len(dataloader)
# 包装 dataloader 为 tqdm 对象
dataloader = tqdm(dataloader)
for step, (img, _) in enumerate(dataloader):
img = img.to(device)
size = img.size(0)
random_noise = torch.randn(size, 100, device=device)
d_optim.zero_grad()
# 判别器输入真实图片
real_output = Dis(img)
# 判别器在真实图像上的损失
d_real_loss = loss_function(real_output, torch.ones_like(real_output))
d_real_loss.backward()
gen_img = Gen(random_noise)
fake_output = Dis(gen_img.detach()) # 判别器输入生成图片,fake_output对生成图片的预测
# gen_img是由生成器得来的,但我们现在只对判别器更新,所以要截断对Gen的更新
# detach()得到了没有梯度的tensor,求导到这里就停止了,backward的时候就不会求导到Gen了
d_fake_loss = loss_function(fake_output, torch.zeros_like(fake_output))
d_fake_loss.backward()
d_loss = d_real_loss + d_fake_loss
d_optim.step()
g_optim.zero_grad()
fake_output = Dis(gen_img)
g_loss = loss_function(fake_output, torch.ones_like(real_output))
g_loss.backward()
g_optim.step()
with torch.no_grad():
d_epoch_loss += d_loss
g_epoch_loss += g_loss
# 更新进度条状态
dataloader.set_description(f'Epoch {epoch + 1}')
dataloader.set_postfix(d_loss=d_loss.item(), g_loss=g_loss.item())
dataloader.update(1)
# 关闭进度条
dataloader.close()
with torch.no_grad(): # 之后的内容不进行梯度的计算(图的构建)
d_epoch_loss /= count
g_epoch_loss /= count
D_loss.append(d_epoch_loss)
G_loss.append(g_epoch_loss)
print('Epoch:', epoch + 1)
gen_img_plot(Gen, test_input, epoch == 0, (epoch == Epoch))
步骤:
-
导入 tqdm 库:
from tqdm import tqdm -
包装 dataloader 为 tqdm 对象:
dataloader = tqdm(dataloader) -
在循环中更新进度条:
- 使用
dataloader.set_description()设置进度条描述,例如当前 epoch 编号。 - 使用
dataloader.set_postfix()设置进度条后缀,显示当前迭代的损失值等信息。 - 使用
dataloader.update(1)更新进度条,每迭代一次更新一次。
- 使用
-
关闭进度条:
dataloader.close()
代码解释:
tqdm(dataloader):将dataloader包装成一个tqdm对象,该对象会自动在迭代过程中显示进度条。dataloader.set_description(f'Epoch {epoch + 1}'):设置进度条描述,显示当前 epoch 编号。dataloader.set_postfix(d_loss=d_loss.item(), g_loss=g_loss.item()):设置进度条后缀,显示当前迭代的判别器损失d_loss和生成器损失g_loss。dataloader.update(1):更新进度条,每迭代一次更新一次。dataloader.close():关闭进度条。
通过以上步骤,您就可以在 PyTorch 代码中添加进度条,方便地观察训练过程和损失函数变化。tqdm 库还可以用于其他可迭代对象,例如列表、字典等,方便地实现进度条显示。
原文地址: https://www.cveoy.top/t/topic/nF5P 著作权归作者所有。请勿转载和采集!