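Cosine Similarity: Formula and Java Implementation

1. Formula

Cosine similarity measures how similar two vectors are by the cosine of the angle between them. The formula is: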

$$\mathrm{COS} = \frac{\sum_{n} x_n y_n}{\sqrt{\sum_{n} x_n^2}\,\sqrt{\sum_{n} y_n^2}}$$

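where:

  • x and y are the two vectors
  • the numerator of COS is the dot product of the two vectors
  • the denominator of COS is the product of the two vectors' norms (magnitudes)

Example: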

| id_x | tag_x | weight_x | | id_y | tag_y | weight_y |
|---|---|---|---|---|---|---|
| x1 | 1 | 1 | | y1 | 0.5 | 0.5 |
| x2 | 0 | 0 | | y2 | 0.6 | 0.6 |
| x3 | 1 | 1 | | y3 | 0.1 | 0.1 |

COS numerator = 1 × 0.5 + 0 × 0.6 + 1 × 0.1 = 0.6

COS denominator = sqrt(1 × 1 + 0 × 0 + 1 × 1) × sqrt(0.5 × 0.5 + 0.6 × 0.6 + 0.1 × 0.1) = 1.113552872566

COS = 0.6 / 1.113552872566 = 0.538815906080325
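
The worked example can be checked with a few lines of Java. The sketch below is a minimal dense-vector version of the formula. The class name `DenseCosineExample` and the choice to return 0 when either vector has zero norm are assumptions made for illustration, not part of the original task.

```java
import java.util.Locale;

class DenseCosineExample {
    /** Cosine similarity of two equal-length dense vectors; returns 0 if either norm is zero (assumption). */
    static double cosine(double[] x, double[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException("vectors must have the same length");
        }
        double dot = 0, normX = 0, normY = 0;
        for (int n = 0; n < x.length; n++) {
            dot += x[n] * y[n];      // numerator: sum of x_n * y_n
            normX += x[n] * x[n];    // sum of x_n^2
            normY += y[n] * y[n];    // sum of y_n^2
        }
        double denominator = Math.sqrt(normX) * Math.sqrt(normY);
        return denominator == 0 ? 0 : dot / denominator;
    }

    public static void main(String[] args) {
        double[] x = {1, 0, 1};        // weight_x column of the example table
        double[] y = {0.5, 0.6, 0.1};  // weight_y column of the example table
        System.out.println(String.format(Locale.US, "%.6f", cosine(x, y)));  // prints 0.538816
    }
}
```

This version uses `double` so that the result can be compared with the hand calculation; the task itself asks for single precision.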

2. Requirements

(1) The input data is given in the table below:

| caseid | tagid | weight |
|---|---|---|
| 10 | 100001 | 2.391216368 |
| 10 | 100002 | 3.678794412 |
| 10 | 100011 | 4.357588823 |
| 20 | 100002 | 5.518191618 |
| 20 | 100003 | 1.839397206 |
| 20 | 100004 | 12.87578044 |
| 30 | 100003 | 59.21755365 |
| 30 | 100004 | 1.839397206 |
| 30 | 100005 | 1.839397206 |
| 40 | 100004 | 33.10914971 |
| 40 | 100005 | 9.196986029 |
| 40 | 100006 | 183.9397206 |
| 50 | 100006 | 11.03638324 |
| 50 | 100007 | 15.45093653 |
| 50 | 100008 | 16.55457485 |
| 60 | 100006 | 2.023336926 |
| 60 | 100008 | 1.839397206 |
| 60 | 100009 | 59.21755365 |
| 70 | 100006 | 1.839397206 |
| 70 | 100009 | 1.839397206 |
| 70 | 100010 | 2.575156088 |

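(2) Compute COS for every pair of caseids

(3) For each caseid, output its 3 highest COS values

(4) Implement in Java or any other language, and submit results computed in single precision

(5) Sample output: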

caseid caseid COS
20 20 1
20 10 0.232349
20 40 0.161248
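
The sample output suggests that requirement (3) is a per-caseid ranking that includes the case itself (`20 20 1`). A minimal sketch of that reading follows. The class name `TopKPerCase`, the helpers `topK` and `sampleCases`, and the decision to load only caseids 10 to 40 are assumptions for illustration; the weights are copied from the input table.

```java
import java.util.*;

class TopKPerCase {
    /** Cosine similarity of two sparse vectors (tagid -> weight), accumulated in single precision. */
    static float cosine(Map<Integer, Float> a, Map<Integer, Float> b) {
        float dot = 0f, normA = 0f, normB = 0f;
        for (Map.Entry<Integer, Float> e : a.entrySet()) {
            Float w = b.get(e.getKey());
            if (w != null) dot += e.getValue() * w;   // only shared tagids contribute
            normA += e.getValue() * e.getValue();
        }
        for (float w : b.values()) normB += w * w;
        float denominator = (float) (Math.sqrt(normA) * Math.sqrt(normB));
        return denominator == 0f ? 0f : dot / denominator;
    }

    /** Up to k (caseid, COS) entries ranked by descending COS against target; target must be a key of cases. */
    static List<Map.Entry<Integer, Float>> topK(Map<Integer, Map<Integer, Float>> cases, int target, int k) {
        List<Map.Entry<Integer, Float>> scored = new ArrayList<>();
        for (Map.Entry<Integer, Map<Integer, Float>> other : cases.entrySet()) {
            scored.add(new AbstractMap.SimpleEntry<>(other.getKey(),
                    cosine(cases.get(target), other.getValue())));
        }
        scored.sort((p, q) -> Float.compare(q.getValue(), p.getValue()));
        return scored.subList(0, Math.min(k, scored.size()));
    }

    /** Caseids 10, 20, 30 and 40 from the input table (a subset, to keep the sketch short). */
    static Map<Integer, Map<Integer, Float>> sampleCases() {
        Map<Integer, Map<Integer, Float>> cases = new TreeMap<>();
        cases.put(10, Map.of(100001, 2.391216368f, 100002, 3.678794412f, 100011, 4.357588823f));
        cases.put(20, Map.of(100002, 5.518191618f, 100003, 1.839397206f, 100004, 12.87578044f));
        cases.put(30, Map.of(100003, 59.21755365f, 100004, 1.839397206f, 100005, 1.839397206f));
        cases.put(40, Map.of(100004, 33.10914971f, 100005, 9.196986029f, 100006, 183.9397206f));
        return cases;
    }

    public static void main(String[] args) {
        Map<Integer, Map<Integer, Float>> cases = sampleCases();
        System.out.println("caseid caseid COS");
        for (int caseid : cases.keySet()) {
            for (Map.Entry<Integer, Float> e : topK(cases, caseid, 3)) {
                System.out.printf(Locale.US, "%d %d %.6f%n", caseid, e.getKey(), e.getValue());
            }
        }
    }
}
```

For caseid 20 this ranks 20 first, then 10, then 40, in the same order as the sample output. `Map.of` requires Java 9 or later.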

3. Java Implementation

import java.util.*;

public class CosineSimilarity {
    public static void main(String[] args) {
        // Read whitespace-separated "caseid tagid weight" records from standard input.
        // Locale.US makes nextFloat() accept '.' as the decimal separator on any system.
        Scanner scanner = new Scanner(System.in).useLocale(Locale.US);
        // caseid -> (tagid -> weight): each case is a sparse vector over tagids.
        Map<Integer, Map<Integer, Float>> cases = new HashMap<>();
        while (scanner.hasNext()) {
            int caseid = scanner.nextInt();
            int tagid = scanner.nextInt();
            float weight = scanner.nextFloat();
            cases.computeIfAbsent(caseid, k -> new HashMap<>()).put(tagid, weight);
        }
        List<Integer> caseids = new ArrayList<>(cases.keySet());
        Collections.sort(caseids);

        // Compute each unordered pair once: print it and queue it for the top-3 selection.
        PriorityQueue<Pair> queue = new PriorityQueue<>();
        System.out.println("caseid caseid COS");
        for (int i = 0; i < caseids.size(); i++) {
            int caseidi = caseids.get(i);
            for (int j = i + 1; j < caseids.size(); j++) {
                int caseidj = caseids.get(j);
                float cos = cosine(cases.get(caseidi), cases.get(caseidj));
                System.out.printf("%d %d %.3f\n", caseidi, caseidj, cos);
                queue.offer(new Pair(caseidi, caseidj, cos));
            }
        }

        System.out.println();
        System.out.println("Top 3 highest cosine similarities:");
        // Poll pairs in descending COS order, skipping any pair that reuses a printed caseid.
        Set<Integer> visited = new HashSet<>();
        int count = 0;
        while (!queue.isEmpty() && count < 3) {
            Pair pair = queue.poll();
            if (!visited.contains(pair.caseid1) && !visited.contains(pair.caseid2)) {
                System.out.printf("%d %d %.3f\n", pair.caseid1, pair.caseid2, pair.cos);
                visited.add(pair.caseid1);
                visited.add(pair.caseid2);
                count++;
            }
        }
    }

    // Cosine similarity of two sparse vectors (tagid -> weight), accumulated in single precision.
    static float cosine(Map<Integer, Float> a, Map<Integer, Float> b) {
        float numerator = 0;
        float normA = 0;
        float normB = 0;
        for (Map.Entry<Integer, Float> entry : a.entrySet()) {
            Float other = b.get(entry.getKey());
            if (other != null) {
                numerator += entry.getValue() * other;  // only shared tagids contribute
            }
            normA += entry.getValue() * entry.getValue();
        }
        for (float weight : b.values()) {
            normB += weight * weight;
        }
        float denominator = (float) (Math.sqrt(normA) * Math.sqrt(normB));
        // Guard against an all-zero vector, which would otherwise produce NaN.
        return denominator == 0 ? 0 : numerator / denominator;
    }

    static class Pair implements Comparable<Pair> {
        int caseid1;
        int caseid2;
        float cos;

        public Pair(int caseid1, int caseid2, float cos) {
            this.caseid1 = caseid1;
            this.caseid2 = caseid2;
            this.cos = cos;
        }

        // Reverse order, so the PriorityQueue yields the largest COS first.
        @Override
        public int compareTo(Pair other) {
            return Float.compare(other.cos, this.cos);
        }
    }
}

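Notes on the code:

  1. A HashMap stores the tagid and weight values of each caseid
  2. A PriorityQueue orders the similarities of all caseid pairs from highest to lowest, and pairs are taken from the front until 3 have been printed
  3. A Set records the caseids that have already been printed, so each caseid appears in at most one of the 3 printed pairs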
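
The effect of the Set in note 3 is easy to miss: it turns a plain "top 3 pairs" into "top 3 pairs that share no caseid". The sketch below isolates that selection step. The class `DisjointTopPairs`, its `select` and `samplePairs` helpers, and the ids and scores are all hypothetical values chosen for illustration, not results from the input table.

```java
import java.util.*;

class DisjointTopPairs {
    /** A scored pair of ids (hypothetical illustration data). */
    static final class ScoredPair {
        final int a;
        final int b;
        final float score;

        ScoredPair(int a, int b, float score) {
            this.a = a;
            this.b = b;
            this.score = score;
        }

        @Override
        public String toString() {
            return a + "-" + b;
        }
    }

    /** Takes up to k pairs in descending score order; if disjoint is true, skips pairs that reuse an id. */
    static List<ScoredPair> select(List<ScoredPair> pairs, int k, boolean disjoint) {
        PriorityQueue<ScoredPair> queue =
                new PriorityQueue<>((p, q) -> Float.compare(q.score, p.score));
        queue.addAll(pairs);
        Set<Integer> seen = new HashSet<>();
        List<ScoredPair> result = new ArrayList<>();
        while (!queue.isEmpty() && result.size() < k) {
            ScoredPair p = queue.poll();
            if (disjoint && (seen.contains(p.a) || seen.contains(p.b))) {
                continue;  // one of the ids was already used by a higher-scoring pair
            }
            seen.add(p.a);
            seen.add(p.b);
            result.add(p);
        }
        return result;
    }

    /** Made-up scores in which id 2 and id 3 each appear in two high-scoring pairs. */
    static List<ScoredPair> samplePairs() {
        return Arrays.asList(
                new ScoredPair(1, 2, 0.9f),
                new ScoredPair(2, 3, 0.8f),
                new ScoredPair(3, 4, 0.7f),
                new ScoredPair(5, 6, 0.6f));
    }

    public static void main(String[] args) {
        System.out.println(select(samplePairs(), 3, false));  // prints [1-2, 2-3, 3-4]
        System.out.println(select(samplePairs(), 3, true));   // prints [1-2, 3-4, 5-6]
    }
}
```

With the Set enabled, the pair 2-3 is dropped even though it has the second-highest score, because id 2 was already used. Whether that is the intended behavior depends on how the "top 3" requirement is read.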

Output (for the input table in section 2; pairs that share no tagid have a COS of 0):

caseid caseid COS
10 20 0.232
10 30 0.000
10 40 0.000
10 50 0.000
10 60 0.000
10 70 0.000
20 30 0.158
20 40 0.161
20 50 0.000
20 60 0.000
20 70 0.000
30 40 0.007
30 50 0.000
30 60 0.000
30 70 0.000
40 50 0.431
40 60 0.034
40 70 0.494
50 60 0.035
50 70 0.220
60 70 0.519

Top 3 highest cosine similarities:
60 70 0.519
40 50 0.431
10 20 0.232

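The code computes the cosine similarity of every pair of cases and reports the three highest-scoring pairs in which no caseid repeats, giving a quick view of how the cases relate to one another.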


Original source: https://www.cveoy.top/t/topic/ooow. Copyright belongs to the author. Do not repost or scrape.
