如何利用pytorch保存训练出来的模型
- 使用torch.save()函数保存模型
可以使用torch.save()函数将训练出来的模型保存到文件中。该函数需要两个参数:要保存的模型和文件名。
示例代码:
import torch
# 定义模型
model = torch.nn.Linear(10, 1)
# 训练模型...
# 保存模型
torch.save(model.state_dict(), 'model.pth')
- 使用torch.load()函数加载模型
使用torch.load()函数可以加载保存在文件中的模型。该函数需要一个参数:模型文件的路径。
示例代码:
import torch
# 定义模型
model = torch.nn.Linear(10, 1)
# 加载模型
model.load_state_dict(torch.load('model.pth'))
注意:加载模型时,需要先定义模型的结构,然后再加载模型参数。如果模型的结构和保存模型时不同,加载模型时会出现错误。
原文地址: https://www.cveoy.top/t/topic/BMk 著作权归作者所有。请勿转载和采集!