# Is there a method like torch.topk() that takes the first few maximum indices and still propagates back?

ecit04 (Registered Member)
2023-01-25 17:03

In PyTorch, you can use the torch.topk() function to get the values and indices of the k largest elements in a given tensor. Because the returned values are differentiable, any loss computed from them propagates gradients back to the input tensor.

For example:

```python
import torch

x = torch.randn(3, 4, requires_grad=True)
values, indices = torch.topk(x, 2)
# values holds the top-2 largest elements of each row
# indices holds their positions along the last dimension

loss = values.sum()
loss.backward()
# x.grad is nonzero only at the top-2 positions of each row
```

Also, note that if you only want to keep the gradients of the top-k values, no extra work is needed: backpropagating through torch.topk() already leaves zero gradients everywhere else. If you need the opposite, zeroing the gradients at the top-k positions, you can scatter zeros into x.grad:

```python
x.grad.scatter_(1, indices, 0.)
```

This sets the gradient of x to 0 at the top-k positions of each row and leaves the remaining entries unchanged.
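Putting the pieces above together, a minimal end-to-end check (a sketch assuming a 3×4 input and k = 2) confirms that after the backward pass x.grad is exactly 1 at each row's top-2 positions and 0 elsewhere:

```python
import torch

torch.manual_seed(0)
x = torch.randn(3, 4, requires_grad=True)

values, indices = torch.topk(x, 2)  # top-2 per row along the last dim
values.sum().backward()

# Build the expected gradient: 1 at the selected positions, 0 elsewhere.
expected = torch.zeros_like(x).scatter_(1, indices, 1.0)
print(torch.equal(x.grad, expected))  # True
```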

dragon4cn (Registered Member)
2023-01-25 17:03

1. torch.argmax(): returns the index of the maximum value. It takes an input tensor and an optional dim argument; without dim, the index is into the flattened tensor.

2. torch.max(): returns the maximum value; when called with a dim argument, it returns a (values, indices) tuple containing the maxima and their indices along that dimension.

3. torch.sort(): sorts the input tensor by element value, returning both the sorted values and their indices.

Example:

```python
import torch

# Use a floating-point tensor: requires_grad is only valid for float/complex dtypes.
x = torch.tensor([[2., 3., 4., 5.], [1., 2., 3., 4.]], requires_grad=True)

# use argmax (index into the flattened tensor)
print(torch.argmax(x))  # tensor(3)

# use max with dim: returns (values, indices) along that dimension
print(torch.max(x, dim=1))  # values=tensor([5., 4.]), indices=tensor([3, 3])

# use sort (ascending along the last dimension)
values, indices = torch.sort(x)
print(indices)
# tensor([[0, 1, 2, 3],
#         [0, 1, 2, 3]])
```
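Like torch.topk(), the values returned by torch.max(..., dim=...) and torch.sort() are differentiable, so gradients flow back only to the selected elements. A short sketch using the same tensor as above:

```python
import torch

x = torch.tensor([[2., 3., 4., 5.], [1., 2., 3., 4.]], requires_grad=True)

max_vals, max_idx = torch.max(x, dim=1)  # per-row maximum and its index
max_vals.sum().backward()

# Only the per-row maximum positions receive gradient.
print(x.grad)
# tensor([[0., 0., 0., 1.],
#         [0., 0., 0., 1.]])
```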