举例讲解 torchcdist 函数
torch.cdist 函数用于计算两组点之间的距离,它的输入参数包括两个张量,分别代表两组点的坐标,以及一个度量方式。它的输出是一个张量,代表两组点之间的距离矩阵。
以下是一个例子:
import torch
# 创建两组点的坐标
x = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]])
y = torch.tensor([[2, 2], [2, 3], [3, 2], [3, 3]])
# 计算两组点之间的欧几里得距离
dist = torch.cdist(x, y, p=2)
print(dist)
输出结果为:
tensor([[2.8284, 3.1623, 2.2361, 2.8284],
[2.2361, 2.8284, 1.4142, 2.2361],
[2.2361, 2.8284, 1.4142, 2.2361],
[2.8284, 3.1623, 2.2361, 2.8284]])
在这个例子中,我们创建了两组点的坐标,分别为 x 和 y。然后我们使用 torch.cdist 函数,计算了它们之间的欧几里得距离,即 p=2。输出结果是一个 4x4 的张量,代表两组点之间的距离矩阵。可以看到,每个元素代表两个点之间的距离。例如,第一行第一列的元素 2.8284,代表 x 中第一个点 (0, 0) 与 y 中第一个点 (2, 2) 之间的距离。
原文地址: http://www.cveoy.top/t/topic/bS2U 著作权归作者所有。请勿转载和采集!