torch.argsort()函数组合的奇效

news/2024/10/18 12:30:12/

torch.argsort()函数组合的效果

前段时间在看何凯明大神MAE的代码的时候发现了下面一段代码:

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]# sort noise for each sampleids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is removeids_restore = torch.argsort(ids_shuffle, dim=1)# keep the first subsetids_keep = ids_shuffle[:, :len_keep]x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))# generate the binary mask: 0 is keep, 1 is removemask = torch.ones([N, L], device=x.device)mask[:, :len_keep] = 0# unshuffle to get the binary maskmask = torch.gather(mask, dim=1, index=ids_restore)

这个其实是在对输入序列进行随机采样的。方式是:
1.假设我们有一个需要采样的序列X[1,2,3,4,5],先创建一个随机值的noise tensor
2.根据noise的值从小到大排序,得到ids_shuffle,这个是noise的值从小到大排序后对应的下标序列。
例如noise:[5,1,6,9,8] ,则 ids_shuffle:[1,0,2,4,3]
3.以ids_shuffle的前n(以3为例)个值作为序列X的下标进行采样得到结果,得到X_ : [2,1,3]。

但是这段代码中有一行让我很好奇,那就是ids_restore = torch.argsort(ids_shuffle, dim=1)这段代码是在做什么呢?
从名字上看,取名为ids_restore ,意思是恢复ids。即恢复下标。但是我很好奇的是,将排序过的序列再排序怎么就能恢复下标了呢?

需要注意的一点是:X_可以认为是从X中随机抽取了N个数,抽取的方式可以是任意的顺序。比如我抽取 X[3] X[0] X[4]

经过思考最终得到的答案如下(还是以上面举的例子来说明):
1.在获得ids_shuffle后,ids_shuffle里面的值是我们要对X进行采样的一个index随机序列。通俗的讲就是,我现在要根据ids_shuffle中的每个index值去获取X。例如ids_shuffle:[1,0,2,4,3],则我们得到的新的序列X_就是 [X[1] , X[0] , X[2] ,…]。
2.这个时候如果对ids_shuffle的值再进行一个排序得到ids_restore,我们得到的ids_restore结果是什么呢? 因为ids_shuffle的值记录的是随机创建的X_采样子序列中每个位置的元素对应原X序列的位置,ids_restore获取的过程可以分为两步理解:(1),对ids_shuffle的值从小到大排序,即原始序列X从0-N的排序,就是X的原始序列位置。(2),获取每个原始位置在ids_shuffle中的index。也就是说如果我们是根据ids_shuffle来获取随机采样的子序列X_,那么ids_restore记录的就是我原始X中按照顺序X[0] X[1] X[2]… 在ids_shuffle中的位置。 例如这里我的X[0]在ids_shuffle中的位置为1, X[1]在ids_shuffle中的位置为0, X[2]的位置为2 ,X[3]的位置为4,X[4]的位置为3.ids_restore:[1,0,2,4,3]

3.那么如果以后的子序列都通过ids_shuffle构建的话,因为它是随机采样,没有位置顺序信息,当我们要将子序列恢复到输入图片patch原来序列的顺序的时候就可以使用ids_restore。按照ids_restore的每个值作为子序列的取值下标得到的序列就是按照原图patch的大小顺序得到的序列了/


http://www.ppmy.cn/news/714138.html

相关文章

CMake 学习笔记(检测系统环境)

CMake 学习笔记(检测系统环境) 写C/C程序时,尤其是我们的程序要能在不同的平台、操作系统下工作的时候。需要针对不同的环境写一些特定的代码。这时就需要CMake 根据不同的平台类型可以生成不同的项目文件。 CMake 在这方面做的很好。 检测…

flutter mac环境配置

在 macOS 上安装和配置 Flutter 开发环境 - Flutter 中文文档 - Flutter 中文开发者网站 - Flutter一、配置flutter环境变量在 macOS 上安装和配置 Flutter 开发环境 - Flutter 中文文档 - Flutter 中文开发者网站 - Flutter 解压文件放在我的文档里面 然后设置环境变量 1. 执…

KCC@成都首次非正式闭门会圆满成功

6月29日下午,KCC成都 召开了首次非正式闭门会,KCC 由开源社理事兼执行长庄表伟老师于2023年2月发起,全称 KAIYUANSHE City Community。本次会议由KCC成都站站长程诗杰主持召开,线下有唐云峰、水歌、肖伊、果汁、阿林等成都开源活跃…

C语言寻找第k小元素,小技巧——查找第k小的元素

今天分享一个小技巧,虽然是小技巧但是还是很有价值的,曾经是微软的面试题。题目是这样的,一个无序的数组让你找出第k小的元素,我当时看到这道题的时候也像很多人一样都是按普通的思维,先排序在去第K个,但是…

1的k次方一直加到n的k次方c语言,c语言函数求1到n的k次方和

#include #include /*----------------函数f2,求n的k次方-----------------*/ long f2(int n, int k) { long power n; /*power表示n的k次方*/ int i; for(i 1; i { power power*n; return power;/*将power作为f2的返回值*/ } } /*----------------函数f1&…

c语言编程最后j和k,C语言学习笔记:IJK运算,ijk,的,操作

表达式 i < j < k 在C语言中是合法的&#xff0c;但是它不是你所期望的意思。因为 &#xff1c; 运算符是左结合的&#xff0c; 所以这个表达式等价于 (i < j) < k . 换句话说&#xff0c; 表达式首先检测l.是否小千j, 然后用比较后产生的结果1或0来和K进行比较。 …

c语言学习-编写函数求组合数C= n! / (k! *( n-k)!)

编写函数求组合数C n! / (k! *( n-k)!) 程序流程图&#xff1a; 代码&#xff1a; #include<stdio.h> int mul(int x,int y); void main() { int n,k; double c; printf("please enter n:\tk:\t"); scanf("%d,%d",&n,&k); cmul(n,k); pr…

简洁解释k++,++k,k+1,k+=1的区别(附图)

以下为结合图进行说明 k和k两者都是递增1&#xff0c;但区别就在于k是先赋值给n再&#xff08;nk&#xff09;&#xff0c;而k是先后再赋值给n&#xff08;nk&#xff09;。 但两者不论是哪一种&#xff0c;区别也仅在于执行那一行&#xff0c;执行结束之后&#xff0c;对k来…