import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

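# 定义ResNet18模型 (Define the model. Note: despite the "ResNet18" name it has no residual blocks; it is a plain Conv/BN/ReLU/max-pool stack.)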

class Net(nn.Module):
    """Small convolutional image classifier.

    Despite the ``resnet18`` attribute name, this network has no residual
    connections. It consists of:

    * four stages of Conv3x3 -> BatchNorm -> ReLU -> MaxPool(3, stride 2),
      with channels 3 -> 64 -> 128 -> 256 -> 512;
    * a global average pool;
    * one linear classifier.

    The attribute names ``resnet18`` and ``fc`` are kept unchanged so that
    existing ``state_dict`` checkpoints keep loading.

    Args:
        num_classes: Number of output logits. Defaults to 10, the value the
            original hard-coded.
    """

    def __init__(self, num_classes=10):
        # Must be spelled ``__init__``: under the old name ``init`` nn.Module never
        # called it, so none of the layers below were ever created.
        super().__init__()
        self.resnet18 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.AdaptiveAvgPool2d((1, 1)),
        )
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return class logits of shape (N, num_classes).

        ``x`` must have 3 input channels, i.e. shape (N, 3, H, W), to match the
        first Conv2d.
        """
        x = self.resnet18(x)     # (N, 512, 1, 1) after the adaptive average pool
        x = torch.flatten(x, 1)  # (N, 512)
        return self.fc(x)

# 定义客户端类 (Define the client class)

class Client:
    """One federated-learning participant that owns a fixed local dataset.

    Args:
        data: Local inputs. They are sliced along dim 0 into mini-batches and fed
            to ``model``. Presumably an image tensor of shape (N, 3, H, W);
            TODO confirm.
        label: Targets aligned with ``data`` along dim 0, passed to
            ``nn.CrossEntropyLoss``. Presumably class indices.
        model: The ``nn.Module`` to train. It may be the same object that other
            clients and the server hold.
        lr: SGD learning rate for local training.
        epoch: Number of passes over the local data per ``train`` call.
        batch_size: Mini-batch size used by ``train`` and ``get_grad``.
    """

    def __init__(self, data, label, model, lr=0.01, epoch=5, batch_size=32):
        # Must be ``__init__``: the old name ``init`` was never invoked by Python.
        self.data = data
        self.label = label
        self.model = model
        self.lr = lr
        self.epoch = epoch
        self.batch_size = batch_size

    def train(self):
        """Run ``self.epoch`` passes of mini-batch SGD over the local data.

        Prints the sample-weighted mean loss after each epoch.
        """
        # Evaluation code may have left the (possibly shared) model in eval mode.
        # Training in that state would normalize with frozen BatchNorm statistics.
        self.model.train()
        optimizer = optim.SGD(self.model.parameters(), lr=self.lr)
        criterion = nn.CrossEntropyLoss()
        num_samples = len(self.data)
        for i in range(self.epoch):
            running_loss = 0.0
            for j in range(0, num_samples, self.batch_size):
                inputs = self.data[j:j + self.batch_size]
                labels = self.label[j:j + self.batch_size]
                optimizer.zero_grad()
                outputs = self.model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                # Weight by batch size so the epoch average is per sample, even
                # when the last batch is short.
                running_loss += loss.item() * inputs.size(0)
            # max(..., 1) avoids ZeroDivisionError when the client has no data.
            print('Epoch: %d, Loss: %.3f' % (i + 1, running_loss / max(num_samples, 1)))

    def get_grad(self):
        """Return the loss gradient on the first local mini-batch as one flat vector.

        The gradient is computed in eval mode. The model's previous train/eval
        mode is restored afterwards.

        Returns:
            A 1-D tensor: the concatenation of every parameter's gradient in
            ``self.model.parameters()`` order. Parameters that received no
            gradient contribute zeros, so the layout always matches the
            parameter list.
        """
        was_training = self.model.training
        self.model.eval()
        criterion = nn.CrossEntropyLoss()
        # Clearing the module's grads is all the throw-away SGD optimizer was used for.
        self.model.zero_grad()
        inputs = self.data[0:self.batch_size]
        labels = self.label[0:self.batch_size]
        try:
            loss = criterion(self.model(inputs), labels)
            loss.backward()
        finally:
            self.model.train(was_training)
        grads = [
            torch.zeros_like(param).reshape(-1) if param.grad is None
            else param.grad.detach().reshape(-1)
            for param in self.model.parameters()
        ]
        return torch.cat(grads)

# 定义服务器类 (Define the server class)

class Server:
    """Coordinates the clients and applies their aggregated gradient to ``model``.

    Args:
        clients: Objects with ``train()`` and ``get_grad()`` methods, e.g. ``Client``.
        model: The global ``nn.Module`` that ``update_model`` modifies in place.
    """

    def __init__(self, clients, model):
        # Must be ``__init__``: the old name ``init`` was never invoked by Python.
        self.clients = clients
        self.model = model

    def train(self):
        """Run local training on every client, one after another."""
        for client in self.clients:
            client.train()

    def get_grad(self, percentile=80):
        """Aggregate the clients' flat gradients, discarding outliers by L2 norm.

        Steps:

        1. Stack every client's flat gradient into a
           ``(num_clients, num_params)`` matrix.
        2. Zero the rows of clients whose L2 norm is strictly above the
           ``percentile``-th percentile of all client norms.
        3. Return the mean over *all* clients, zeroed rows included, as the
           original code did.

        Args:
            percentile: Cut-off percentile (0-100) of the per-client norms.
                Defaults to 80, the value the original hard-coded.

        Returns:
            A 1-D tensor with the same layout as each client's gradient.

        Raises:
            ValueError: if the server has no clients.
        """
        if not self.clients:
            raise ValueError('Server has no clients to aggregate gradients from')
        grads = torch.stack([client.get_grad() for client in self.clients])
        l2_norms = torch.norm(grads, dim=1)
        # torch.quantile interpolates linearly, like np.percentile's default.
        # The previous np.percentile(..., p=80) call raised TypeError (the keyword is
        # `q`), and len(self.model.parameters()) failed because it is a generator.
        threshold = torch.quantile(l2_norms, percentile / 100.0)
        # The old per-parameter loop indexed the flattened matrix by parameter
        # position (grads[:, i] is a single scalar column), so it could not work.
        # Filter whole client rows against the whole-gradient threshold instead.
        grads[l2_norms > threshold] = 0
        return grads.mean(dim=0)

    def update_model(self, grad):
        """Apply a flat aggregated gradient to ``self.model`` in place.

        ``grad`` is split back into per-parameter chunks in
        ``self.model.parameters()`` order, the order in which clients flatten
        their gradients. This assumes the clients' models share this
        architecture (TODO confirm).

        Each parameter is updated as ``param -= chunk / len(self.clients)``, so
        ``1 / len(self.clients)`` acts as the step size.

        Raises:
            ValueError: if there are no clients, or if ``grad`` does not have
                exactly one entry per model parameter element.
        """
        if not self.clients:
            raise ValueError('Server has no clients; cannot scale the update')
        params = list(self.model.parameters())
        total = sum(param.numel() for param in params)
        flat = grad.reshape(-1)
        if flat.numel() != total:
            raise ValueError(
                'gradient has %d elements but the model has %d parameter elements'
                % (flat.numel(), total)
            )
        offset = 0
        with torch.no_grad():
            for param in params:
                count = param.numel()
                # Previously parameters were zipped with *scalar* entries of the
                # flat vector, shifting a whole weight tensor by a single number.
                param.sub_(flat[offset:offset + count].view_as(param) / len(self.clients))
                offset += count

# 定义训练函数 (Define the federated training function)

def train(model, train_loader, test_loader, lr=0.01, epoch=5, batch_size=32, num_clients=10):
    """Simulate federated training of ``model``, evaluating after every round.

    Setup:

    * Each of ``num_clients`` clients receives one batch drawn from
      ``train_loader`` as its entire local dataset.
    * ``epoch`` is used twice: as the number of server rounds, and as every
      client's number of local epochs.

    In each round, every client trains locally, the server aggregates the
    clients' gradients and applies them, and ``test`` reports test-set metrics.

    NOTE(review): every client and the server receive the *same* ``model``
    object. Local training therefore mutates the global model directly, and
    clients train one after another rather than independently. Confirm this is
    the intended scheme; FedAvg-style training would give each client its own copy.

    Raises:
        ValueError: if ``train_loader`` yields fewer than ``num_clients`` batches.
    """
    # Draw every client's batch from ONE iterator. Calling iter() once per client
    # restarts the loader each time: with shuffle=True the clients' batches can
    # overlap, and with num_workers > 0 new worker processes start on every call.
    batches = iter(train_loader)
    clients = []
    for i in range(num_clients):
        try:
            data, label = next(batches)
        except StopIteration:
            raise ValueError(
                'train_loader yielded only %d batches but num_clients=%d' % (i, num_clients)
            ) from None
        clients.append(Client(data, label, model, lr, epoch, batch_size))
    server = Server(clients, model)
    for _ in range(epoch):
        server.train()
        grad = server.get_grad()
        server.update_model(grad)
        test(model, test_loader)

# 定义测试函数 (Define the evaluation function)

def test(model, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print the mean loss and accuracy.

    The model is switched to eval mode for the pass. Its previous train/eval
    mode is restored afterwards, even on error, so a caller that keeps
    training the same model object is not silently left in eval mode.

    Args:
        model: Classifier whose outputs are treated as class logits along dim 1.
        test_loader: DataLoader yielding ``(inputs, labels)`` batches. The length
            of its ``dataset`` is used as the sample count.

    Returns:
        ``(mean_loss, accuracy)``: the per-sample mean cross-entropy and the
        fraction of correct top-1 predictions. These are the same values that
        are printed.
    """
    was_training = model.training
    model.eval()
    criterion = nn.CrossEntropyLoss()
    total_loss = 0.0
    total_correct = 0
    try:
        with torch.no_grad():
            for data, label in test_loader:
                outputs = model(data)
                loss = criterion(outputs, label)
                # Weight by batch size so the final average is per sample.
                total_loss += loss.item() * data.size(0)
                total_correct += (outputs.argmax(dim=1) == label).sum().item()
    finally:
        model.train(was_training)
    num_samples = len(test_loader.dataset)
    mean_loss = total_loss / num_samples
    accuracy = total_correct / num_samples
    print('Test Loss: %.3f, Test Acc: %.3f' % (mean_loss, accuracy))
    return mean_loss, accuracy


# The name starts with "test", so pytest would try to collect this helper as a
# test case (and fail on its non-fixture arguments) wherever it is imported.
test.__test__ = False

# 加载数据集 (Load the CIFAR-10 dataset)

train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Everything below has side effects: it downloads CIFAR-10, starts DataLoader
# worker processes, and trains. Guard it so it runs only when the file is
# executed as a script. Under the "spawn" start method (the default on Windows
# and macOS) each DataLoader worker re-imports this module, and unguarded
# top-level code would be re-executed in every worker.
if __name__ == '__main__':
    train_data = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
    test_data = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=True, num_workers=2)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=128, shuffle=False, num_workers=2)

    # Initialize the model.
    model = Net()

    # Train the model: 10 clients, 5 rounds, each client running 5 local epochs.
    train(model, train_loader, test_loader, lr=0.01, epoch=5, batch_size=32, num_clients=10)


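# 原文地址 (original source): https://www.cveoy.top/t/topic/nx0i
# 著作权归作者所有。请勿转载和采集! (Copyright belongs to the author; do not repost or scrape.)
#
# 免费AI点我,无需注册和登录 (site advertisement: "free AI, click here, no sign-up or login required")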