BEVFormer组件分析

news/2024/10/29 10:22:59/

BEVFormerEncoder中的get_reference_points


@staticmethoddef get_reference_points(H, W, Z=8, num_points_in_pillar=4, dim='3d', bs=1, device='cuda', dtype=torch.float):"""Get the reference points used in SCA and TSA.Args:H, W: spatial shape of bev.Z: hight of pillar.D: sample D points uniformly from each pillar.device (obj:`device`): The device wherereference_points should be.Returns:Tensor: reference points used in decoder, has \shape (bs, num_keys, num_levels, 2)."""# reference points in 3D space, used in spatial cross-attention (SCA)if dim == '3d':zs = torch.linspace(0.5, Z - 0.5, num_points_in_pillar, dtype=dtype,device=device).view(-1, 1, 1).expand(num_points_in_pillar, H, W) / Zxs = torch.linspace(0.5, W - 0.5, W, dtype=dtype,device=device).view(1, 1, W).expand(num_points_in_pillar, H, W) / Wys = torch.linspace(0.5, H - 0.5, H, dtype=dtype,device=device).view(1, H, 1).expand(num_points_in_pillar, H, W) / Href_3d = torch.stack((xs, ys, zs), -1)ref_3d = ref_3d.permute(0, 3, 1, 2).flatten(2).permute(0, 2, 1)ref_3d = ref_3d[None].repeat(bs, 1, 1, 1)return ref_3d# reference points on 2D bev plane, used in temporal self-attention (TSA).elif dim == '2d':ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H - 0.5, H, dtype=dtype, device=device),torch.linspace(0.5, W - 0.5, W, dtype=dtype, device=device))ref_y = ref_y.reshape(-1)[None] / Href_x = ref_x.reshape(-1)[None] / Wref_2d = torch.stack((ref_x, ref_y), -1)ref_2d = ref_2d.repeat(bs, 1, 1).unsqueeze(2)return ref_2d

根据上面的代码可以看出来,如果输入的是3d, 则是
按照:

  • X方向: 从0.5, 到W-0.5分成W份.
  • Y方向: 从0.5, 到H-0.5分成H份.
  • Z方向: 从0.5, 到Z-0.5, 分成 num_points_in_pillar份.
    其中num_points_in_pillar 默认给的是4.

配置文件里面给的其实也是4.
在这里插入图片描述

BEVFormerEncoder中的point_sampling

  # This function must use fp32!!!@force_fp32(apply_to=('reference_points', 'img_metas'))def point_sampling(self, reference_points, pc_range,  img_metas):lidar2img = []for img_meta in img_metas:lidar2img.append(img_meta['lidar2img'])lidar2img = np.asarray(lidar2img)lidar2img = reference_points.new_tensor(lidar2img)  # (B, N, 4, 4)reference_points = reference_points.clone()# 变换到点云的范围内. 这也是为何get_reference_points中会/H, /W, /Z, 先化到[0, 1]变成ratio.reference_points[..., 0:1] = reference_points[..., 0:1] * \(pc_range[3] - pc_range[0]) + pc_range[0]reference_points[..., 1:2] = reference_points[..., 1:2] * \(pc_range[4] - pc_range[1]) + pc_range[1]reference_points[..., 2:3] = reference_points[..., 2:3] * \(pc_range[5] - pc_range[2]) + pc_range[2]# 由(x, y, z) 变成(x, y, z, 1) 便于与4*4的参数矩阵相乘.reference_points = torch.cat((reference_points, torch.ones_like(reference_points[..., :1])), -1)# 此时reference_points可以当成是点云的点了.reference_points = reference_points.permute(1, 0, 2, 3)# num_query等于H*W*Z. 等于grid_points的数量.D, B, num_query = reference_points.size()[:3]num_cam = lidar2img.size(1)# 要往每个相机上去投影. 因此先申请num_cam份.# reference_points的shape就变成了, (D, b, num_cam, num_query, 4, 1) 便于和4*4的矩阵做matmul.reference_points = reference_points.view(D, B, 1, num_query, 4).repeat(1, 1, num_cam, 1, 1).unsqueeze(-1)# 相机参数由(b,num_cam, 4, 4) 变成(1, b, num_cam, 1, 4, 4) 再变成(D,b,num_cam,num_query,4,4)lidar2img = lidar2img.view(1, B, num_cam, 1, 4, 4).repeat(D, 1, 1, num_query, 1, 1)reference_points_cam = torch.matmul(lidar2img.to(torch.float32),reference_points.to(torch.float32)).squeeze(-1)eps = 1e-5# 把每个相机后面的点mask掉. 因为相机后面的点投过来之后第三位是负的.bev_mask = (reference_points_cam[..., 2:3] > eps)# 再做齐次化. 得到像素坐标.reference_points_cam = reference_points_cam[..., 0:2] / torch.maximum(reference_points_cam[..., 2:3], torch.ones_like(reference_points_cam[..., 2:3]) * eps)# 由像素坐标转成相对于图像的ratio..# NOTE 这里如果不同相机size不一样的话.要除以对应的相机的sizereference_points_cam[..., 0] /= img_metas[0]['img_shape'][0][1]reference_points_cam[..., 1] /= img_metas[0]['img_shape'][0][0]# 再把超出图像fov范围的点给去掉.bev_mask = (bev_mask & (reference_points_cam[..., 1:2] > 0.0)& (reference_points_cam[..., 1:2] < 1.0)& (reference_points_cam[..., 0:1] < 1.0)& (reference_points_cam[..., 0:1] > 0.0))if digit_version(TORCH_VERSION) >= digit_version('1.8'):bev_mask = torch.nan_to_num(bev_mask)else:bev_mask = bev_mask.new_tensor(np.nan_to_num(bev_mask.cpu().numpy()))# 由(D, b, num_cam, num_query, 2) 变成 (num_cam, b, num_query, D, 2)reference_points_cam = reference_points_cam.permute(2, 1, 3, 0, 4)bev_mask = bev_mask.permute(2, 1, 3, 0, 4).squeeze(-1)# 至此. reference_points_cam代表的就是像素点相对于各个相机的ratio.# bev_mask就代表哪些点是有效的return reference_points_cam, bev_mask

SpatialCrossAttention

个人理解SpatialCrossAttention其实就是正常的Deformable Attention, 只不过原始Deformable Attention中的
refer points是由网络产生的,
而现在的refer points 是由 虚拟的grid points往图像上投影得到的. 在相机参数固定的情况下, 此时的refer points是固定的.

下面是 SpatialCrossAttention这个模块的forward函数的部分代码

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

问题: 给固定的这些refer points 的收益是多大? 文章好像并没有提. 这一块儿感觉不充分.

另外, 显然这样虚拟的grid points 是不合理的, 因为有些地方可能就没有点, 但是还是能够投影到图像上的. 这里用真值的点应该会更好,
比如用lidar的points. 但是BEVFormer paper里面没有对比加入lidar后的效果.

在这里插入图片描述


http://www.ppmy.cn/news/209557.html

相关文章

hdu 5687 Problem C 字典树

传送门&#xff1a;hdu 5687 Problem C 中文题目就不做过多的解释 解题思路 定义一个结构体&#xff0c;里面有26个字母&#xff0c;就像下面这样&#xff1a; struct Node{int next[26];int sum;void init(){sum 0;memset(next,-1,sizeof next);} }; 然后定义一个这个类…

字典树 HDU1251 HDU 5687

模板解决了五个问题&#xff0c;插入&#xff0c;删除&#xff0c;查询是否存在这个单词&#xff0c;查询以这个字符串为前缀的单词数量&#xff0c;查询这个单词是否有前缀 https://leetcode-cn.com/problems/implement-trie-prefix-tree/solution/trie-tree-de-shi-xian-gua…

leetcode 5687. 执行乘法运算的最大分数

给你两个长度分别 n 和 m 的整数数组 nums 和 multipliers &#xff0c;其中 n > m &#xff0c;数组下标 从 1 开始 计数。 初始时&#xff0c;你的分数为 0 。你需要执行恰好 m 步操作。在第 i 步操作&#xff08;从 1 开始 计数&#xff09;中&#xff0c;需要&#xff…

【题解】LuoGu5687:[CSP-SJX2019]网格图

原题传送门 肯定是让你一排一排的考虑 那就一排一排的考虑 暴力做法就是暴力 k r u s k a l kruskal kruskal&#xff0c;那么怎么把一整排看成一条边&#xff1f;这样就可以满分了 根据 k r u s k a l kruskal kruskal我们可以知道&#xff0c;只要每次选取最小的边且还未联通…

LeetCode 5687.执行乘法运算的最大分数

题目描述 给你两个长度分别 n 和 m 的整数数组 nums 和 multipliers &#xff0c;其中 n > m &#xff0c;数组下标 从 1 开始 计数。 初始时&#xff0c;你的分数为 0 。你需要执行恰好 m 步操作。在第 i 步操作&#xff08;从 1 开始 计数&#xff09;中&#xff0c;需要…

Problem C(HDU-5687)

Problem Description 度熊手上有一本神奇的字典&#xff0c;你可以在它里面做如下三个操作&#xff1a; 1、insert : 往神奇字典中插入一个单词 2、delete: 在神奇字典中删除所有前缀等于给定字符串的单词 3、search: 查询是否在神奇字典中有一个字符串的前缀等于给定的字符…

5687. 执行乘法运算的最大分数

难度中等 给你两个长度分别 n 和 m 的整数数组 nums 和 multipliers &#xff0c;其中 n > m &#xff0c;数组下标 从 1 开始 计数。 初始时&#xff0c;你的分数为 0 。你需要执行恰好 m 步操作。在第 i 步操作&#xff08;从 1 开始 计数&#xff09;中&#xff0c;需要…

HUD 5687(字典树)

反思&#xff1a;之前在删除单词的时候&#xff0c;只是删除掉单词前缀以后的字符&#xff0c;而没有把整个单词都删除掉&#xff0c;导致WA了很多次。 高手请略过... AC代码&#xff1a; #define _CRT_SECURE_NO_WARNINGS #include<iostream> #include<string.h&g…