Is there a method like torch.topk() that returns the indices of the k largest values and still supports backpropagation?

ecit04 Registered Member
2023-01-25 17:03

In PyTorch, torch.topk() returns the k largest values of a tensor along a dimension, together with their indices. The returned values are differentiable: calling backward() on a loss computed from them propagates gradients to the input tensor. The indices are integers and carry no gradient themselves.

For example:

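import torch

x = torch.randn(3, 4, requires_grad=True)  # enable gradient tracking before calling topk
values, indices = torch.topk(x, 2)
# values holds the 2 largest entries of each row
# indices holds their positions along the last dimension

loss = values.sum()
loss.backward()
# x.grad is 1 at the positions of the 2 largest entries in each row and 0 elsewhere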
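
The indices returned by torch.topk() can also be used to select entries from a different float tensor, and that selection stays differentiable. The following is a minimal sketch of that pattern using torch.gather(); the names scores and features are hypothetical and chosen only for illustration.

```python
import torch

torch.manual_seed(0)
scores = torch.randn(3, 4)                         # used only to rank positions
features = torch.randn(3, 4, requires_grad=True)   # hypothetical tensor we want gradients for

# The indices are int64 and do not carry a gradient.
_, indices = torch.topk(scores, k=2, dim=1)

# gather picks the entries of `features` at the top-k positions of `scores`
# and is differentiable with respect to `features`.
selected = torch.gather(features, 1, indices)
selected.sum().backward()

print(indices.dtype, indices.requires_grad)
print(features.grad)  # 1 at the selected positions, 0 elsewhere
```

The gradient reaches features only at the selected positions; nothing flows back into scores, because the ranking step itself is not differentiable.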

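Also, the backward pass of torch.topk() already assigns a zero gradient to every element outside the top k, so no extra step (and no torch.index_select()) is needed to zero the rest. If you need an explicit 0/1 mask of the selected positions, for example to mask gradients that also come from other loss terms, you can build it from indices with scatter_():

# run after loss.backward(), so that x.grad is populated
mask = torch.zeros_like(x).scatter_(1, indices, 1.0)
x.grad *= mask

This code sets the gradient of x to 0 everywhere except at the k largest values of each row.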
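
How gradients from torch.topk() are distributed can be checked directly. The sketch below, under the assumption of a random 3x4 input with no ties, compares x.grad against a 0/1 mask built from the returned indices.

```python
import torch

torch.manual_seed(0)
x = torch.randn(3, 4, requires_grad=True)

values, indices = torch.topk(x, k=2, dim=1)
values.sum().backward()

# Build a mask with 1 at each top-k position and 0 elsewhere.
expected = torch.zeros(3, 4).scatter_(1, indices, 1.0)

# The gradient of sum(top-k values) with respect to x equals that mask.
print(torch.equal(x.grad, expected))
```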

dragon4cn Registered Member
2023-01-25 17:03

1. torch.argmax(): returns the index of the maximum value. It takes an input tensor and an optional dimension argument; without a dimension, the index refers to the flattened tensor.

2. torch.max(): returns the maximum value. Called with only a tensor, it returns the single largest value; called with a dimension argument, it returns a tuple of the maximum values and their indices along that dimension.

3. torch.sort(): sorts the input tensor by element value along a dimension (ascending by default) and returns the sorted values together with their indices in the original tensor.
Example:

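import torch

# requires_grad=True needs a floating-point tensor
x = torch.tensor([[2., 3., 4., 5.], [1., 2., 3., 4.]], requires_grad=True)

# use argmax: without dim, the index refers to the flattened tensor
print(torch.argmax(x))  # tensor(3)

# use max: with dim, returns (values, indices) per row
print(torch.max(x, dim=1))  # values [5., 4.], indices [3, 3]

# use sort: ascending along the last dimension by default
values, indices = torch.sort(x)
# both rows are already sorted, so indices is
# tensor([[0, 1, 2, 3],
#         [0, 1, 2, 3]])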
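
To connect torch.sort() back to the original question: taking the first k columns of a descending sort gives the same values and indices as torch.topk() when the entries are distinct, and the values remain differentiable. A minimal sketch, with an input chosen here purely for illustration:

```python
import torch

x = torch.tensor([[2., 5., 3., 4.], [4., 1., 3., 2.]], requires_grad=True)
k = 2

# Sort each row in descending order and keep the first k columns.
sorted_values, sorted_indices = torch.sort(x, dim=1, descending=True)
top_values, top_indices = sorted_values[:, :k], sorted_indices[:, :k]

# Compare with torch.topk, which returns the k largest entries in descending order.
ref_values, ref_indices = torch.topk(x, k, dim=1)
print(torch.equal(top_values, ref_values))
print(torch.equal(top_indices, ref_indices))

# Gradients flow through the sorted values back to x.
top_values.sum().backward()
print(x.grad)
```

torch.topk() is usually the more direct choice when only the k largest entries are needed; the sort-based version is shown only to illustrate that both routes are differentiable with respect to the values.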

Question Info

Publish Time
2023-01-25 17:03
Update Time
2023-01-25 17:03
