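# 基于ResNet18的联邦学习CIFAR10分类模型训练代码 - 梯度裁剪优化
# (Federated-learning CIFAR-10 classifier training with gradient clipping. Note: the
#  "ResNet18" model below is a plain 4-stage CNN without residual connections.)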
import itertools

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
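# 定义模型（名为ResNet18，实际上是不含残差连接的简化卷积网络）
# (Define the model: named "ResNet18", but actually a plain CNN without residual connections.)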
class Net(nn.Module):
    """Small convolutional classifier for 3-channel images (used here on CIFAR-10).

    NOTE: despite the ``resnet18`` attribute name, this is *not* a ResNet-18.
    It is a plain stack of four Conv-BN-ReLU-MaxPool stages (64/128/256/512
    channels) with no residual connections. It is followed by global average
    pooling and a linear head. The attribute names are kept so that existing
    ``state_dict`` keys stay valid.

    Args:
        num_classes: number of output logits (10 for CIFAR-10).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.resnet18 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            # Collapse whatever spatial size remains to 1x1 so the head sees 512 features.
            nn.AdaptiveAvgPool2d((1, 1)),
        )
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        """Map a ``(batch, 3, H, W)`` image tensor to ``(batch, num_classes)`` logits."""
        x = self.resnet18(x)
        # Features are (batch, 512, 1, 1) after the adaptive pool; flatten to (batch, 512).
        x = x.view(x.size(0), -1)
        return self.fc(x)
# 定义客户端类 (Define the client class)
class Client(object):
    """A simulated federated-learning client that owns one local data shard.

    Args:
        data: input tensor for this client. It is sliced along dim 0 into
            mini-batches that are fed directly to ``model``.
        label: class-index targets aligned with ``data`` along dim 0.
        model: the ``nn.Module`` this client trains and differentiates.
            NOTE(review): the training script in this file passes the same
            model object to every client — confirm that sharing is intended.
        lr: SGD learning rate used by ``train``.
        epoch: number of local passes over ``data`` per ``train`` call.
        batch_size: mini-batch size used by ``train`` and ``get_grad``.
    """

    def __init__(self, data, label, model, lr=0.01, epoch=5, batch_size=32):
        self.data = data
        self.label = label
        self.model = model
        self.lr = lr
        self.epoch = epoch
        self.batch_size = batch_size

    def train(self):
        """Run ``epoch`` passes of plain SGD over the local data.

        Prints the per-sample mean training loss after each epoch.
        """
        # get_grad() (and evaluation code) may have switched the model to eval
        # mode; BatchNorm/Dropout must be in training mode for local updates.
        self.model.train()
        optimizer = optim.SGD(self.model.parameters(), lr=self.lr)
        criterion = nn.CrossEntropyLoss()
        num_samples = len(self.data)
        for epoch_idx in range(self.epoch):
            running_loss = 0.0
            for start in range(0, num_samples, self.batch_size):
                inputs = self.data[start:start + self.batch_size]
                labels = self.label[start:start + self.batch_size]
                optimizer.zero_grad()
                loss = criterion(self.model(inputs), labels)
                loss.backward()
                optimizer.step()
                # Weight by batch size so the epoch mean is per-sample even if
                # the final batch is shorter than batch_size.
                running_loss += loss.item() * inputs.size(0)
            print('Epoch: %d, Loss: %.3f' % (epoch_idx + 1, running_loss / num_samples))

    def get_grad(self):
        """Return the flattened loss gradient on the first local mini-batch.

        The gradient is evaluated in eval mode, so BatchNorm uses its running
        statistics. It is returned as a detached 1-D tensor laid out in
        ``model.parameters()`` order. Parameters without a gradient contribute
        zeros, so the length always equals the total parameter count. The
        model's previous train/eval mode is restored afterwards. ``param.grad``
        is left populated.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            self.model.zero_grad()
            inputs = self.data[:self.batch_size]
            labels = self.label[:self.batch_size]
            loss = nn.CrossEntropyLoss()(self.model(inputs), labels)
            loss.backward()
            pieces = [
                torch.zeros_like(param).reshape(-1) if param.grad is None
                else param.grad.detach().reshape(-1)
                for param in self.model.parameters()
            ]
        finally:
            self.model.train(was_training)
        return torch.cat(pieces)
# 定义服务器类 (Define the server class)
class Server(object):
    """Coordinates simulated federated rounds over a list of clients.

    Args:
        clients: objects exposing ``train()`` and ``get_grad()``; ``get_grad``
            must return a flat 1-D gradient in ``model.parameters()`` order.
        model: the global ``nn.Module`` updated by ``update_model``.
    """

    def __init__(self, clients, model):
        self.clients = clients
        self.model = model

    def train(self):
        """Run local training on every client, in list order."""
        for client in self.clients:
            client.train()

    def get_grad(self, percentile=80):
        """Aggregate client gradients with percentile-based norm clipping.

        Each client's flat gradient ``g_k`` is rescaled to
        ``g_k * min(1, t / ||g_k||)``, where ``t`` is the ``percentile``-th
        percentile of all client L2 norms. The clipped gradients are then
        averaged.

        Args:
            percentile: clipping percentile in ``[0, 100]``.

        Returns:
            A 1-D tensor, the mean of the clipped client gradients.

        Raises:
            ValueError: if the server has no clients.
        """
        if not self.clients:
            raise ValueError("Server has no clients to aggregate gradients from")
        grads = torch.stack([client.get_grad() for client in self.clients])
        norms = torch.norm(grads, dim=1)
        # Linear interpolation between order statistics, matching the default
        # of numpy.percentile.
        threshold = torch.quantile(norms, percentile / 100.0)
        # clamp_min avoids dividing by zero for an all-zero client gradient;
        # the scale is capped at 1 so small gradients are left untouched.
        scale = torch.clamp(threshold / norms.clamp_min(1e-12), max=1.0)
        return (grads * scale.unsqueeze(1)).mean(dim=0)

    def update_model(self, grad, lr=None):
        """Apply one gradient-descent step ``param -= lr * grad`` to ``self.model``.

        Args:
            grad: flat 1-D gradient laid out in ``model.parameters()`` order,
                as returned by ``get_grad``.
            lr: step size; defaults to ``1 / len(self.clients)``.

        Raises:
            ValueError: if ``grad`` does not hold exactly one entry per
                parameter element.
        """
        params = list(self.model.parameters())
        expected = sum(param.numel() for param in params)
        if grad.numel() != expected:
            raise ValueError(
                f"gradient has {grad.numel()} elements, model has {expected} parameters"
            )
        step = 1.0 / len(self.clients) if lr is None else lr
        with torch.no_grad():
            offset = 0
            for param in params:
                count = param.numel()
                # Unflatten this parameter's slice back to its own shape.
                param.sub_(step * grad[offset:offset + count].view_as(param))
                offset += count
# 定义训练函数 (Define the federated training function)
def train(model, train_loader, test_loader, lr=0.01, epoch=5, batch_size=32,
          num_clients=10, rounds=None):
    """Simulate federated training of ``model`` and evaluate after each round.

    Each of the ``num_clients`` clients receives one batch from
    ``train_loader`` as its entire local dataset. All batches come from a
    single pass, so clients get disjoint samples.

    Args:
        model: the global model; it is also handed to every client.
        train_loader: iterable yielding ``(data, label)`` batches.
        test_loader: loader passed to ``test`` after every round.
        lr: client SGD learning rate.
        epoch: local epochs per client per round; also the number of rounds
            when ``rounds`` is None.
        batch_size: client mini-batch size.
        num_clients: number of simulated clients.
        rounds: number of communication rounds (defaults to ``epoch``).

    Raises:
        ValueError: if ``train_loader`` yields fewer than ``num_clients`` batches.
    """
    # Take every client's shard from ONE iterator. Calling iter(train_loader)
    # per client would start a fresh (re-shuffled) pass each time, so shards
    # could overlap; with num_workers > 0 it would also spawn new workers.
    shards = list(itertools.islice(train_loader, num_clients))
    if len(shards) < num_clients:
        raise ValueError(
            f"train_loader yielded {len(shards)} batches, need {num_clients} (one per client)"
        )
    # NOTE(review): all clients share the same `model` object, so local
    # training is applied sequentially to the global model.
    clients = [Client(data, label, model, lr, epoch, batch_size) for data, label in shards]
    server = Server(clients, model)
    num_rounds = epoch if rounds is None else rounds
    for _ in range(num_rounds):
        server.train()
        grad = server.get_grad()
        server.update_model(grad)
        test(model, test_loader)
# 定义测试函数 (Define the evaluation function)
def test(model, test_loader): model.eval() criterion = nn.CrossEntropyLoss() total_loss = 0.0 total_correct = 0 with torch.no_grad(): for data, label in test_loader: outputs = model(data) loss = criterion(outputs, label) total_loss += loss.item() * data.size(0) _, predicted = torch.max(outputs.data, 1) total_correct += (predicted == label).sum().item() print('Test Loss: %.3f, Test Acc: %.3f' % (total_loss/len(test_loader.dataset), total_correct/len(test_loader.dataset)))
# Load the CIFAR-10 dataset.
# Per-channel Normalize(mean, std) statistics shared by the train and test pipelines.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

train_transform = transforms.Compose([
    # Augmentation: random 32x32 crop from the image padded by 4 pixels, plus horizontal flips.
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

if __name__ == "__main__":
    # DataLoader workers (num_workers=2) re-import this module when started
    # with the "spawn" method (the default on Windows and macOS). Without
    # this guard, each worker would re-run the download and training code.
    train_data = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
    test_data = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=True, num_workers=2)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=128, shuffle=False, num_workers=2)

    # Initialize the model.
    model = Net()

    # Train the model.
    train(model, train_loader, test_loader, lr=0.01, epoch=5, batch_size=32, num_clients=10)
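# 原文地址: https://www.cveoy.top/t/topic/nx0i 著作权归作者所有。请勿转载和采集!
# (Original source: https://www.cveoy.top/t/topic/nx0i — copyright belongs to the author; do not repost or scrape.)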