HardLabelVoteOneHot函数:基于硬标签投票生成One-Hot编码张量

该函数用于对多个客户端的硬标签进行投票,并将其转化为one-hot形式的张量返回。

函数定义:

def HardLabelVoteOneHot(all_client_hard_label, class_cat):
    client_cnt = len(all_client_hard_label)
    sample_cnt = len(all_client_hard_label[0])
    pred_labels = []
    all_vote_tensor = []
    label_votes_none_cnt = 0
    for i in range(sample_cnt):
        label_votes = [0] * class_cat
        for j in range(client_cnt):
            cur_client_cur_sample_hard_label = all_client_hard_label[j][i]
            pred_label = cur_client_cur_sample_hard_label
            if pred_label != class_cat:
                label_votes[pred_label] += 1
        if (len(label_votes) == 0):
            label_votes_none_cnt += 1
        max_vote_nums = max(label_votes)
        max_vote_idx = label_votes.index(max_vote_nums)
        pred_labels.append(max_vote_idx)
    all_one_hot_tensor = []
    for i in range(len(pred_labels)):
        cur_label = [0.0] * class_cat
        cur_label[pred_labels[i]] = 1.0
        all_one_hot_tensor.append(cur_label)
    all_vote_tensor = torch.tensor(all_one_hot_tensor)
    print('len of pred = {}'.format(len(pred_labels)))
    print()
    return all_vote_tensor

参数说明:

  • all_client_hard_label: 包含所有客户端硬标签的列表,每个元素是一个客户端的硬标签列表。
  • class_cat: 类别数量。

函数执行步骤:

  1. 获取客户端数量和样本数量。
  2. 初始化预测标签列表 pred_labels 和投票张量 all_vote_tensor
  3. 遍历每个样本,统计每个类别的投票数。
  4. 如果某个样本的所有类别都没有投票,则将 label_votes_none_cnt 加1。
  5. 找到投票数最多的类别,并将其作为预测标签添加到 pred_labels 中。
  6. 根据 pred_labels 生成所有样本的one-hot形式的张量 all_one_hot_tensor
  7. all_one_hot_tensor 转化为 torch.tensor 类型的 all_vote_tensor
  8. 打印预测标签的长度。
  9. 返回 all_vote_tensor

函数返回值:

一个one-hot形式的张量,表示所有样本的预测结果。

应用场景:

该函数可用于将多个客户端的硬标签整合为一个统一的预测结果,并以one-hot形式方便后续模型使用。例如,在联邦学习中,可以使用该函数将不同客户端训练得到的模型预测结果进行整合,从而获得更准确的整体预测结果。

HardLabelVoteOneHot函数:基于硬标签投票生成One-Hot编码张量

原文地址: http://www.cveoy.top/t/topic/eOaZ 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录