masked_fill_(mask, value)
Masking operation
Fills elements of the self tensor with value wherever mask is True. The shape of mask must be broadcastable with the shape of the underlying tensor.
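To illustrate what "broadcastable" means here, the sketch below fills a 2x3 tensor using masks of shape (3,) and (2, 1). The variable names (t, u, col_mask, row_mask) are illustrative only. Because the operation is in place, the mask is broadcast to the shape of self, not the other way around.

```python
import torch

# Mask of shape (3,): broadcast across both rows, so the same columns are filled in each row.
t = torch.ones(2, 3)
col_mask = torch.tensor([True, False, True])
t.masked_fill_(col_mask, 0.0)
print(t)  # columns 0 and 2 are now 0 in every row

# Mask of shape (2, 1): broadcast along the columns, so whole rows are filled.
u = torch.ones(2, 3)
row_mask = torch.tensor([[False], [True]])
u.masked_fill_(row_mask, -1.0)
print(u)  # row 0 unchanged, row 1 is all -1
```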
Parameters
mask (BoolTensor) – the boolean mask
value (float) – the value to fill in with
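In practice the mask is usually produced by a comparison, which already returns a torch.bool tensor. Older code passed uint8 (ByteTensor) masks, which were deprecated in favour of bool; depending on the PyTorch version they may trigger a warning or an error. The small sketch below (values chosen arbitrarily) zeroes out negative entries.

```python
import torch

x = torch.tensor([-1.5, 0.0, 2.0, -0.3])
neg = x < 0               # comparison yields a torch.bool mask
print(neg.dtype)          # torch.bool
x.masked_fill_(neg, 0.0)  # value is a plain Python float
print(x)                  # negatives replaced by 0
```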
masked_fill(mask, value) → Tensor
Out-of-place version of torch.Tensor.masked_fill_()
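The out-of-place form returns a new tensor and leaves the input untouched. That also makes it the safe choice inside autograd graphs, since in-place ops on a leaf tensor that requires grad raise an error. The sketch below (tensor values are arbitrary) shows both points: filled positions receive zero gradient.

```python
import torch

a = torch.arange(6, dtype=torch.float32).reshape(2, 3)
b = a.masked_fill(a > 3, -1.0)  # new tensor; a is not modified
print(a)
print(b)

# Out-of-place masked_fill on a tensor that requires grad.
w = torch.ones(3, requires_grad=True)
out = w.masked_fill(torch.tensor([True, False, False]), 0.0)
out.sum().backward()
print(w.grad)  # the filled position gets gradient 0, the others 1
```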
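Code example
import torch

a = torch.randn(5, 6)
x = [5, 4, 3, 2, 1]  # valid length of each row; positions from src_len onward are padding
mask = torch.zeros(5, 6, dtype=torch.float)
for e_id, src_len in enumerate(x):
    mask[e_id, src_len:] = 1  # mark padding positions with 1
mask = mask.to(device='cpu')  # the mask must live on the same device as a
print(mask)
# masked_fill_ expects a bool mask; uint8 masks (.byte()) are deprecated
a.masked_fill_(mask.bool(), -float('inf'))
print(a)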
----------------------------Output
tensor([[0., 0., 0., 0., 0., 1.],
[0., 0., 0., 0., 1., 1.],
[0., 0., 0., 1., 1., 1.],
[0., 0., 1., 1., 1., 1.],
[0., 1., 1., 1., 1., 1.]])
tensor([[-0.1053, -0.0352, 1.4759, 0.8849, -0.7233, -inf],
[-0.0529, 0.6663, -0.1082, -0.7243, -inf, -inf],
[-0.0364, -1.0657, 0.8359, -inf, -inf, -inf],
[ 1.4160, 1.1594, -inf, -inf, -inf, -inf],
[ 0.4163, -inf, -inf, -inf, -inf, -inf]])
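The Python loop above can also be written as a single broadcast comparison between positions and lengths. Filling with -inf is a common way to mask padding before a softmax (for example in attention), because exp(-inf) = 0 gives those positions zero probability. The sketch below assumes that use case; lengths, pad_mask and scores are illustrative names, and the lengths match the list x from the example.

```python
import torch

lengths = torch.tensor([5, 4, 3, 2, 1])  # same lengths as x in the example
max_len = 6

# Compare every position (shape 1 x 6) with every length (shape 5 x 1).
positions = torch.arange(max_len)
pad_mask = positions.unsqueeze(0) >= lengths.unsqueeze(1)  # (5, 6) bool, True = padding

scores = torch.randn(5, max_len)
scores = scores.masked_fill(pad_mask, float('-inf'))
probs = torch.softmax(scores, dim=-1)  # padded positions get probability 0
print(pad_mask)
print(probs)
```

Every row keeps at least one unmasked position here; a row that is entirely -inf would produce NaN after softmax.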