torchwhere用法
torch.where(condition, x, y)函数的作用是根据条件(condition)返回一个新的张量,该张量的每个元素从x和y中的相应位置中获取。
参数:
- condition:一个布尔张量,用于确定返回哪个输入张量的每个元素。
- x:一个张量,与condition形状相同,在condition为True的位置上的元素将被选择。
- y:一个张量,与condition形状相同,在condition为False的位置上的元素将被选择。
返回值:
一个新的张量,与condition形状相同,其中condition为True的位置上的元素来自x,condition为False的位置上的元素来自y。
示例:
import torch
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
y = torch.tensor([[7, 8, 9], [10, 11, 12]])
condition = torch.tensor([[True, False, True], [False, True, False]])
result = torch.where(condition, x, y)
print(result)
输出:
tensor([[ 1, 8, 3],
[10, 5, 12]])
解释:在condition中,第一行的第一列和第三列为True,所以返回的结果中第一行的第一列和第三列的元素来自x,第二行的第二列为True,所以返回的结果中第二行的第二列的元素来自y。
原文地址: https://www.cveoy.top/t/topic/Yis 著作权归作者所有。请勿转载和采集!