HardLabelVoteHard 函数:基于硬标签投票的多客户端标签预测
这段代码定义了一个名为 HardLabelVoteHard 的函数,该函数通过对多个客户端的硬标签进行投票,来实现标签预测。
函数接收两个参数:
all_client_hard_label:包含所有客户端的硬标签列表,每个子列表对应一个客户端,子列表中的每个元素对应一个样本的硬标签。class_cat:类别数量。
函数内部首先获取客户端的数量和样本数量,然后通过循环遍历每个样本,对每个客户端的硬标签进行投票。如果某个客户端的硬标签不等于类别数量(即不是无效标签),则将其计入对应类别的票数中。如果某个样本的所有投票都是无效标签,则将其计入无效标签数量中。
接下来,函数找到每个样本中票数最多的类别,并将其作为预测的标签。最后,将预测的标签转换为 PyTorch 张量类型,并返回。
代码解释:
def HardLabelVoteHard(all_client_hard_label, class_cat):
client_cnt = len(all_client_hard_label) # 获取客户端数量
sample_cnt = len(all_client_hard_label[0]) # 获取样本数量
pred_labels = [] # 初始化预测标签列表
label_votes_none_cnt = 0 # 初始化无效标签数量
for i in range(sample_cnt): # 循环遍历每个样本
label_votes = [0] * class_cat # 初始化每个类别票数
for j in range(client_cnt): # 循环遍历每个客户端
cur_client_cur_sample_hard_label = all_client_hard_label[j][i] # 获取当前客户端当前样本的硬标签
pred_label = cur_client_cur_sample_hard_label # 预测标签为当前硬标签
if pred_label != class_cat: # 如果硬标签不是无效标签
label_votes[pred_label] += 1 # 对应类别票数加 1
if (len(label_votes) == 0): # 如果所有投票都是无效标签
label_votes_none_cnt += 1 # 无效标签数量加 1
max_vote_nums = max(label_votes) # 找到票数最多的类别
max_vote_idx = label_votes.index(max_vote_nums) # 获取票数最多的类别的索引
pred_labels.append(max_vote_idx) # 将预测标签添加到列表中
pred_labels = torch.tensor(pred_labels) # 将预测标签转换为张量类型
return pred_labels # 返回预测标签列表
总结:
该函数通过对多个客户端的硬标签进行投票,统计每个样本中票数最多的类别来预测标签。这是一个简单有效的标签预测方法,尤其适用于多客户端协作学习场景。
原文地址: https://www.cveoy.top/t/topic/eOfx 著作权归作者所有。请勿转载和采集!