使用pytorch实现多层感知机分类器
以下是使用PyTorch实现多层感知机(MLP)分类器的示例代码:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
# MLP model definition: a single-hidden-layer feed-forward classifier.
class MLP(nn.Module):
    """Multilayer perceptron with one ReLU-activated hidden layer.

    Maps an `input_size`-dim vector to `output_size` raw class logits
    (no softmax — pair with nn.CrossEntropyLoss).
    """

    def __init__(self, input_size, hidden_size, output_size):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return class logits for a batch `x` of shape (N, input_size)."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)
# Training function: one full pass over the training data.
def train(model, train_loader, optimizer, criterion):
    """Run one training epoch.

    Args:
        model: network producing class logits from flattened inputs.
        train_loader: iterable yielding (data, targets) batches; `data` is
            expected to flatten to 784 features (28x28 MNIST images).
        optimizer: optimizer bound to model.parameters().
        criterion: loss function, e.g. nn.CrossEntropyLoss.
    """
    model.train()
    # NOTE: the original wrapped this loop in enumerate() but never used the
    # batch index; iterate the loader directly instead.
    for data, targets in train_loader:
        optimizer.zero_grad()
        data = data.view(-1, 28 * 28)  # flatten images to 784-dim vectors
        output = model(data)
        loss = criterion(output, targets)
        loss.backward()
        optimizer.step()
# 定义测试函数
def test(model, test_loader):
model.eval()
correct = 0
with torch.no_grad():
for data, targets in test_loader:
data = data.view(-1, 28*28)
output = model(data)
_, predicted = torch.max(output.data, 1)
correct += (predicted == targets).sum().item()
accuracy = 100.0 * correct / len(test_loader.dataset)
return accuracy
# Load the MNIST dataset (downloads to ./data on first run; needs torchvision).
mnist_transform = torchvision.transforms.ToTensor()
train_set = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=mnist_transform)
test_set = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=mnist_transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=False)

# Build the model (784 = 28x28 flattened pixels, 10 digit classes),
# the optimizer, and the loss.
model = MLP(input_size=784, hidden_size=128, output_size=10)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Train for 10 epochs, reporting test-set accuracy after each one.
for epoch in range(10):
    train(model, train_loader, optimizer, criterion)
    accuracy = test(model, test_loader)
    print('Epoch:', epoch + 1, 'Accuracy:', accuracy)
在上面的代码中,我们定义了一个包含一个隐藏层的 MLP 模型,该模型接受 784 维的输入(即 28x28 的图像展平后的向量),输出 10 维的预测结果。我们使用交叉熵作为损失函数,Adam 作为优化器,训练模型 10 个 epoch,并在每个 epoch 后计算模型在测试集上的准确率。
要运行此代码,您需要安装 PyTorch 和 torchvision 库。您可以使用以下命令安装它们:
pip install torch torchvision
原文地址: https://www.cveoy.top/t/topic/BrH 著作权归作者所有。请勿转载和采集!