SSFL_IDS: A Distributed Intrusion Detection System with Feature Selection and Label Distillation
def SSFL_IDS(conf, dev, clients, server, test_dataset, open_dataset):
comm_cnt = conf['comm_cnt']
open_idx_set_cnt = conf['open_idx_set_cnt']
batchsize = conf['batchsize']
train_rounds = conf['train_rounds']
dis_rounds = conf['discri_rounds']
dist_rounds = conf['dist_rounds']
theta = conf['theta']
labels = conf['labels']
first_train_rounds = conf['first_train_rounds']
class_cat = conf['classify_model_out_len'] if conf['classify_model_out_len'] > 1 else 2
dis_train_cnt = 10000
start_idx = 0
end_idx = start_idx + open_idx_set_cnt
open_len = len(open_dataset)
for e in range(comm_cnt):
sure_unknown_none = set()
all_client_hard_label = []
open_feature, open_label = GetFeatureFromOpenDataset(open_dataset, start_idx, end_idx)
if open_idx_set_cnt > open_len:
global_logits = torch.zeros(open_len, len(labels))
else:
global_logits = torch.zeros(open_idx_set_cnt, len(labels))
client_cnt = len(clients)
participate = 0
print('Round {} Stage I'.format(e+1))
for c_idx in range(client_cnt):
print('Client {} Training...'.format(c_idx+1))
cur_client = clients[c_idx]
cur_train_rounds = train_rounds if e != 0 else first_train_rounds
if len(cur_client.classify_dataset) == 0:
continue
for train_r in range(cur_train_rounds):
TrainWithDataset(dev, cur_client.classify_dataset, batchsize, cur_client.classify_model,
cur_client.classify_opt, cur_client.hard_label_loss_func)
if sum(i > 0 for i in cur_client.each_class_cnt) == 1:
continue
else:
participate += 1
dis_train_feature, _ = GetFeatureFromOpenDataset(open_dataset, 0, dis_train_cnt)
succ = DisUnknown(dev, cur_client, dis_rounds, batchsize, dis_train_feature, theta)
if succ == False:
sure_unknown_none.add(c_idx)
cur_client_open_feature = open_feature.detach().clone()
if c_idx not in sure_unknown_none:
local_logit = PredictWithDisUnknown(dev, cur_client_open_feature,
cur_client.classify_model, cur_client.classify_model_out_len,
cur_client.discri_model, cur_client.discri_model_out_len,
len(labels))
copy_local_logit = local_logit.detach().clone()
hard_label = HardLabel(copy_local_logit)
all_client_hard_label.append(hard_label)
print()
global_logits = HardLabelVoteHard(all_client_hard_label, class_cat)
print('Round {} Stage II'.format(e+1))
for c_idx in range(len(clients)):
cur_client = clients[c_idx]
print('Client {} Distillation Training...'.format(c_idx+1))
for r in range(dist_rounds):
cur_global_logits = global_logits.detach().clone()
cur_client_open_feature = open_feature.detach().clone()
if cur_client.classify_model_out_len != 1:
TrainWithFeatureLabel(dev, cur_client_open_feature, cur_global_logits, batchsize,
cur_client.classify_model, cur_client.classify_opt,
cur_client.hard_label_loss_func)
else:
cur_global_logits = OneHot2Label(cur_global_logits)
TrainWithFeatureLabel(dev, cur_client_open_feature, cur_global_logits, batchsize,
cur_client.classify_model, cur_client.classify_opt,
cur_client.hard_label_loss_func)
print()
print('Server Training...')
for dist_i in range(dist_rounds):
cur_global_logits = global_logits.detach().clone()
server_open_feature = open_feature.detach().clone()
if server.model_out_len != 1:
TrainWithFeatureLabel(dev, server_open_feature, cur_global_logits, batchsize,
server.model, server.dist_opt, server.hard_label_loss_func)
else:
cur_global_logits = OneHot2Label(cur_global_logits)
TrainWithFeatureLabel(dev, server_open_feature, cur_global_logits, batchsize,
server.model, server.dist_opt, server.hard_label_loss_func)
test_feature, test_label = test_dataset[:]
pred_label = Predict(dev, test_feature, server.model, server.model_out_len)
correct_num, test_acc = Metrics(test_label, pred_label)
print('Round {} Test Acc = {} '.format(e+1, test_acc))
print()
这段代码实现了一个名为SSFL_IDS的函数。该函数接受一些配置参数(conf)、设备(dev)、客户端列表(clients)、服务器(server)、测试数据集(test_dataset)和开放数据集(open_dataset)作为输入。
函数中的变量定义和初始化如下:
- comm_cnt:通信轮数,从配置参数conf中获取。
- open_idx_set_cnt:开放数据集索引集合的数量,从配置参数conf中获取。
- batchsize:批大小,从配置参数conf中获取。
- train_rounds:训练轮数,从配置参数conf中获取。
- dis_rounds:判别器训练轮数,从配置参数conf中获取。
- dist_rounds:蒸馏训练轮数,从配置参数conf中获取。
- theta:判别器阈值,从配置参数conf中获取。
- labels:类别标签列表,从配置参数conf中获取。
- first_train_rounds:首次训练轮数,从配置参数conf中获取。
- class_cat:分类模型输出长度,如果大于1则为该长度,否则为2。
- dis_train_cnt:判别器训练数量,设为10000。
- start_idx:开放数据集起始索引,设为0。
- end_idx:开放数据集结束索引,为start_idx加上open_idx_set_cnt。
- open_len:开放数据集的长度,即open_dataset的长度。
接下来进入通信轮数的循环。在每一轮中,进行两个阶段的操作。
第一阶段(Stage I):
- 初始化一些变量,包括sure_unknown_none(存储确定为未知类的客户端索引的集合)和all_client_hard_label(存储所有客户端的硬标签)。
- 从开放数据集中获取特征和标签,存储在open_feature和open_label中。
- 根据开放数据集的长度和open_idx_set_cnt,初始化一个全零的全局logits矩阵global_logits。
- 获取客户端数量,并初始化参与训练的客户端数量为0。
- 对于每个客户端,进行以下操作:
- 进行分类模型的训练,训练轮数为train_rounds,如果当前客户端的分类数据集为空,则跳过。
- 如果当前客户端的每个类别的数量中只有一个类别的数量大于0,则跳过。
- 否则,将参与训练的客户端数量加1,并从开放数据集中获取前dis_train_cnt个样本的特征,存储在dis_train_feature中。
- 调用DisUnknown函数,进行未知类判别器的训练,返回成功与否的标志succ。
- 如果succ为False,则将当前客户端的索引加入sure_unknown_none集合。
- 否则,进行以下操作:
- 克隆open_feature和local_logit,并计算硬标签hard_label。
- 将硬标签存储在all_client_hard_label中。
第二阶段(Stage II):
-
对于每个客户端,进行以下操作:
- 进行蒸馏模型的训练,训练轮数为dist_rounds。
- 克隆global_logits和open_feature,并根据当前客户端的分类模型输出长度进行以下操作:
- 如果不等于1,则调用TrainWithFeatureLabel函数进行特征和标签的训练。
- 否则,将global_logits转换为标签形式,并调用TrainWithFeatureLabel函数进行特征和标签的训练。
-
进行服务器的训练,训练轮数为dist_rounds。
- 克隆global_logits和open_feature,并根据服务器的模型输出长度进行以下操作:
- 如果不等于1,则调用TrainWithFeatureLabel函数进行特征和标签的训练。
- 否则,将global_logits转换为标签形式,并调用TrainWithFeatureLabel函数进行特征和标签的训练。
- 克隆global_logits和open_feature,并根据服务器的模型输出长度进行以下操作:
-
获取测试数据集的特征和标签,并使用训练好的服务器模型进行预测。
-
计算预测准确率并打印输出。
总结:该函数实现了一个分布式的入侵检测系统(SSFL_IDS),包括两个阶段的操作。第一阶段通过训练分类模型和判别器模型来对客户端进行分类和未知类判别。第二阶段通过蒸馏模型和服务器模型的训练来改进分类模型的性能。最后,根据测试数据集对服务器模型进行评估并输出预测准确率。
原文地址: https://www.cveoy.top/t/topic/pe4u 著作权归作者所有。请勿转载和采集!