
Is there a method like torch.topk() that returns the indices of the k largest values and still supports backpropagation?

ecit04 (Registered member)
2023-01-25 17:03

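In PyTorch, torch.topk() returns the values and indices of the k largest elements of a tensor (along the last dimension by default). The returned values are differentiable: during backpropagation, the gradient of each value is routed back to the input position it was taken from, and every other position of the input receives zero gradient. The indices are integer tensors and carry no gradient.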

For example:

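import torch

x = torch.randn(3, 4, requires_grad=True)  # requires_grad must be set before topk is called
values, indices = torch.topk(x, 2)
# values: the 2 largest numbers in each row
# indices: the positions of those 2 numbers in each row

loss = values.sum()
loss.backward()
# x.grad is 1 at the top-2 positions of each row and 0 everywhere else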
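To check this concretely, the sketch below uses a small fixed tensor (values chosen only for illustration, with no ties) and compares x.grad with a 0/1 mask built from indices via scatter(). It relies on the defaults of torch.topk() (dim=-1, largest=True, sorted=True).

```python
import torch

# Small fixed input (illustrative values, no ties) so the result is easy to verify by hand.
x = torch.tensor([[0.1, 0.9, 0.5, 0.3],
                  [0.7, 0.2, 0.8, 0.4]], requires_grad=True)
values, indices = torch.topk(x, 2)  # defaults: dim=-1, largest=True, sorted=True

values.sum().backward()

# topk's backward routes each incoming gradient (all ones here) to the position
# it was selected from, so x.grad should equal a 0/1 mask of the top-2 positions.
expected = torch.zeros_like(x).scatter(-1, indices, 1.0)

print(indices.tolist())  # [[1, 2], [2, 0]]
print(x.grad.tolist())   # [[0.0, 1.0, 1.0, 0.0], [1.0, 0.0, 1.0, 0.0]]
```

In other words, for a loss that depends on x only through the top-k values, no extra masking is needed: the non-selected entries already get zero gradient.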


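If x also receives gradient from other terms of the loss and you want to keep only the gradient at the k largest positions, build a 0/1 mask from indices with scatter() and multiply x.grad by it after backward(). (torch.index_select() only gathers values along a dimension; it cannot modify gradients. Also, x.grad[indices] = 0 would index whole rows rather than the selected entries.)

mask = torch.zeros_like(x).scatter(-1, indices, 1.0)
x.grad *= mask

This code sets every entry of x.grad to 0 except at the positions of the k largest values in each row.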
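An alternative to editing x.grad after the fact is to mask the gradient while it flows, using Tensor.register_hook(). This sketch assumes a hypothetical loss that also depends on every entry of x (the (x ** 2).sum() term), so without the mask the non-top-k entries would receive nonzero gradient; the hook multiplies the incoming gradient by the top-k mask before it is accumulated into x.grad.

```python
import torch

x = torch.tensor([[0.1, 0.9, 0.5, 0.3],
                  [0.7, 0.2, 0.8, 0.4]], requires_grad=True)
k = 2
values, indices = torch.topk(x, k, dim=1)

# 0/1 mask with ones at the top-k positions of each row. It is built from the
# integer indices, so it is a constant and carries no gradient itself.
mask = torch.zeros_like(x).scatter(1, indices, 1.0)

# The hook returns a replacement gradient: the incoming gradient times the mask.
handle = x.register_hook(lambda grad: grad * mask)

# Hypothetical loss that also touches the non-top-k entries of x.
loss = (x ** 2).sum() + values.sum()
loss.backward()

# Stop masking future backward passes through x.
handle.remove()

print(x.grad)  # nonzero only at the top-k positions of each row
```

Removing the hook matters if x is reused; otherwise the mask keeps being applied on every later backward pass through x.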

dragon4cn (Registered member)
2023-01-25 17:03

1. torch.argmax(): returns the index of the maximum value. It takes an input tensor and an optional dim argument; without dim, the index refers to the flattened tensor.

2. torch.max(): without dim, returns only the maximum value; with dim, returns a (values, indices) tuple containing the maximum along that dimension and its index.

3. torch.sort(): sorts the input tensor along a dimension (ascending by default) and returns the sorted values and their original indices.

The values returned by torch.max() and torch.sort() are differentiable; the indices returned by all three functions are integer tensors that carry no gradient.
Example:

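import torch

x = torch.tensor([[2., 3., 4., 5.], [1., 2., 3., 4.]], requires_grad=True)  # gradients require a floating-point dtype

# use argmax (index into the flattened tensor)
print(torch.argmax(x))  # tensor(3)

# use max along dim=1: returns (values, indices)
values, indices = torch.max(x, dim=1)
print(values, indices)  # values: [5., 4.], indices: [3, 3]

# use sort in descending order
values, indices = torch.sort(x, dim=1, descending=True)
print(indices)
# tensor([[3, 2, 1, 0],
#         [3, 2, 1, 0]])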
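Tying these functions back to the original question: the indices themselves never carry gradient, but you can use them to gather values from x, and the gathered values do propagate back. A minimal sketch, reusing the same x and an illustrative k = 2, that also checks that a descending sort truncated to k columns selects the same elements as topk (this input has no ties, so the two agree exactly):

```python
import torch

# Same input as the example above; floating point so it can require gradients.
x = torch.tensor([[2., 3., 4., 5.], [1., 2., 3., 4.]], requires_grad=True)
k = 2  # illustrative choice

# k largest per row via topk, and via a descending sort truncated to k columns.
top_vals, top_idx = torch.topk(x, k, dim=1)
sorted_vals, sorted_idx = torch.sort(x, dim=1, descending=True)
sort_vals, sort_idx = sorted_vals[:, :k], sorted_idx[:, :k]

# The indices are int64 tensors and never require grad.
print(top_idx.dtype, top_idx.requires_grad)  # torch.int64 False

# Using the indices to gather from x gives values that are part of the graph.
picked = x.gather(1, top_idx)
picked.sum().backward()
print(x.grad.tolist())  # [[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]]
```

The gather step is useful when the indices come from one tensor (for example, a score) but the values you want to backpropagate through come from another tensor of the same shape.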

csl875021 (Registered member)
2023-01-25 17:03

If this helps with your problem, please accept the answer!

Question info: published 2023-01-25 17:03; last updated 2023-01-25 17:03.
