找回密码
 立即注册
首页 业界区 安全 Geotransform代码解读

Geotransform代码解读

懵诬哇 2 小时前
网络整体流程

1.png

网络流程较为简单,主要分为四部分
(1)对原始点云和目标点云进行基于Kpconv-FPN的骨干网络进行特征提取,然后我们将下采样进行到最底层的这层的点称为超点
(2)对超点部分进行Geotransformer变换,然后利用双归一化挑选出前top-k个作为处理好的精细点对
(3)对刚刚处理过的点和上采样后的点进行精细点的处理,实际上是进行置信度计算,挑选出较为可靠的点对
(4)使用局部到全局的方法,具体是加权SVD算法来对每个局部的点云进行处理,得到不同的R和t,最终选内点最多的作为最终的R和t
代码解读

Feature Extraction

首先它是使用的基于Kpconv的FPN骨干网络,所以我们这里首先看特征提取部分
这里主要有三部分,
第一部分是基础块,它主要是对函数进行Kpconv提取特征,然后使用归一化和激活函数处理,接下来提取完特征进行残差块,准备开始下采样操作。
  1. class ConvBlock(nn.Module):
  2.     def __init__(
  3.         self,
  4.         in_channels,
  5.         out_channels,
  6.         kernel_size,#卷积核的大小
  7.         radius,#Kpconv中设定的半径
  8.         sigma,#Kpconv中设定的权重
  9.         group_norm,#分组归一化的组数
  10.         negative_slope=0.1,#leaky relu的负斜率
  11.         bias=True,
  12.         layer_norm=False,
  13.     ):
  14.         super(ConvBlock, self).__init__()
  15.         self.in_channels = in_channels#输入特征维度
  16.         self.out_channels = out_channels#输出特征维度
  17.         self.KPConv = KPConv(in_channels, out_channels, kernel_size, radius, sigma, bias=bias)#KPConv卷积层
  18.         if layer_norm:#根据layer_norm参数选择归一化方式
  19.             self.norm = nn.LayerNorm(out_channels)
  20.         else:#使用分组归一化
  21.             self.norm = GroupNorm(group_norm, out_channels)
  22.         self.leaky_relu = nn.LeakyReLU(negative_slope=negative_slope)#定义LeakyReLU激活函数
  23.     def forward(self, s_feats, q_points, s_points, neighbor_indices):#前向传播函数  
  24.         x = self.KPConv(s_feats, q_points, s_points, neighbor_indices)#KPConv卷积操作
  25.         x = self.norm(x)#归一化
  26.         x = self.leaky_relu(x)#激活函数
  27.         return x
复制代码
第二部分是残差块模块,它主要是对代码进行下采样,具体操作是使用线性变换维度,然后使用Kpconv卷积提取特征
  1. class ResidualBlock(nn.Module):
  2.     def __init__(
  3.         self,
  4.         in_channels,
  5.         out_channels,
  6.         kernel_size,
  7.         radius,#Kpconv中设定的半径
  8.         sigma,#Kpconv中设定的权重
  9.         group_norm,#分组归一化的组数
  10.         strided=False,#是否进行下采样
  11.         bias=True,
  12.         layer_norm=False,#是否使用层归一化替代分组归一化
  13.     ):
  14.         
  15.         super(ResidualBlock, self).__init__()
  16.         self.in_channels = in_channels#输入特征维度
  17.         self.out_channels = out_channel#s#输出特征维度
  18.         self.strided = strided#是否进行下采样
  19.         mid_channels = out_channels // 4 #中间特征维度,通常是输出维度的1/4
  20.         if in_channels != mid_channels: #输入维度和中间维度不相等时,使用一层线性变换将输入特征映射到中间维度
  21.             self.unary1 = UnaryBlock(in_channels, mid_channels, group_norm, bias=bias, layer_norm=layer_norm)
  22.         else:#相等时,使用恒等映射
  23.             self.unary1 = nn.Identity()
  24.         self.KPConv = KPConv(mid_channels, mid_channels, kernel_size, radius, sigma, bias=bias)#KPConv卷积层,将中间维度的特征进行卷积操作
  25.         if layer_norm:#根据layer_norm参数选择归一化方式
  26.             self.norm_conv = nn.LayerNorm(mid_channels)
  27.         else:
  28.             self.norm_conv = GroupNorm(group_norm, mid_channels)
  29.         self.unary2 = UnaryBlock(
  30.             mid_channels, out_channels, group_norm, has_relu=False, bias=bias, layer_norm=layer_norm
  31.         )#第二个一层线性变换,将卷积后的特征映射到输出维度,且不使用ReLU激活函数
  32.         if in_channels != out_channels:#输入维度和输出维度不相等时,使用一层线性变换将输入特征映射到输出维度,以便进行残差连接
  33.             self.unary_shortcut = UnaryBlock(
  34.                 in_channels, out_channels, group_norm, has_relu=False, bias=bias, layer_norm=layer_norm
  35.             )
  36.         else:#相等时,使用恒等映射
  37.             self.unary_shortcut = nn.Identity()
  38.         self.leaky_relu = nn.LeakyReLU(0.1)#定义LeakyReLU激活函数
  39.     def forward(self, s_feats, q_points, s_points, neighbor_indices):
  40.         x = self.unary1(s_feats)#第一层线性变换
  41.         x = self.KPConv(x, q_points, s_points, neighbor_indices)#KPConv卷积操作
  42.         x = self.norm_conv(x)#归一化
  43.         x = self.leaky_relu(x)#激活函数
  44.         x = self.unary2(x)#第二层线性变换
  45.         if self.strided:#如果进行下采样,使用最大池化操作对输入特征进行下采样
  46.             shortcut = maxpool(s_feats, neighbor_indices)
  47.         else:#否则直接使用输入特征
  48.             shortcut = s_feats
  49.         shortcut = self.unary_shortcut(shortcut)#将shortcut映射到输出维度
  50.         x = x + shortcut#残差连接
  51.         x = self.leaky_relu(x)#激活函数
  52.         return x
复制代码
第三部分则是解码器的上采样部分,这里主要是将原点云数恢复到目标部分:
首先我们使用torch.zeros_like(x[:1, :])取出第一行,以0填充数值,然后cat函数中的最后的0,表示以行的形式融合到x中,也就是给x多添加一行全为0的。
接下来是使用upsample_indices[:, 0]函数,表示取出所有行的第0列,也就是找到上采样的列进行排序,然后我们根据它对应的列的数值作为索引,找到每一个对应的行,比如upsample_indices[:, 0]的值是[1,-1],那么它就会取出x中第一行和倒数第一行,作为最终的x
  1. def nearest_upsample(x, upsample_indices):
  2.     # Add a last row with minimum features for shadow pools
  3.     x = torch.cat((x, torch.zeros_like(x[:1, :])), 0)
  4.     # Get features for each pooling location [n2, d]
  5.     x = index_select(x, upsample_indices[:, 0], dim=0)
  6.     return x
复制代码
这个时候再来看Kpconv的整体部分,就好理解多了,首先进行编码器的编码,不断变换通道数进行深化特征,然后进行下采样,再进行下采样之后,通过我们的上采样进行恢复尺度,然后进行使用UnaryBlock和LastUnaryBlock进行尺度的块融合,这两个函数实际上都是简单的MLP函数,然后前者多加了归一化和激活函数处理。
  1. class KPConvFPN(nn.Module):
  2.     def __init__(self, input_dim, output_dim, init_dim, kernel_size, init_radius, init_sigma, group_norm):
  3.         super(KPConvFPN, self).__init__()
  4.         self.encoder1_1 = ConvBlock(input_dim, init_dim, kernel_size, init_radius, init_sigma, group_norm)#第一层卷积块,将输入特征映射到初始维度   
  5.         self.encoder1_2 = ResidualBlock(init_dim, init_dim * 2, kernel_size, init_radius, init_sigma, group_norm)#第二层残差块,将初始维度的特征映射到2倍初始维度
  6.         self.encoder2_1 = ResidualBlock(
  7.             init_dim * 2, init_dim * 2, kernel_size, init_radius, init_sigma, group_norm, strided=True
  8.         )#第三层残差块,进行下采样操作,同时保持特征维度不变   
  9.         self.encoder2_2 = ResidualBlock(
  10.             init_dim * 2, init_dim * 4, kernel_size, init_radius * 2, init_sigma * 2, group_norm
  11.         )#第四层残差块,将特征维度映射到4倍初始维度
  12.         self.encoder2_3 = ResidualBlock(
  13.             init_dim * 4, init_dim * 4, kernel_size, init_radius * 2, init_sigma * 2, group_norm
  14.         )#第五层残差块,保持特征维度不变
  15.         self.encoder3_1 = ResidualBlock(
  16.             init_dim * 4, init_dim * 4, kernel_size, init_radius * 2, init_sigma * 2, group_norm, strided=True
  17.         )#
  18.         self.encoder3_2 = ResidualBlock(
  19.             init_dim * 4, init_dim * 8, kernel_size, init_radius * 4, init_sigma * 4, group_norm
  20.         )
  21.         self.encoder3_3 = ResidualBlock(
  22.             init_dim * 8, init_dim * 8, kernel_size, init_radius * 4, init_sigma * 4, group_norm
  23.         )
  24.         self.decoder2 = UnaryBlock(init_dim * 12, init_dim * 4, group_norm)#尺度融合块,将上采样后的特征和对应编码器特征拼接后映射到4倍初始维度
  25.         self.decoder1 = LastUnaryBlock(init_dim * 6, output_dim)#将上采样后的特征和对应编码器特征拼接后映射到输出维度
  26.     def forward(self, feats, data_dict):
  27.         feats_list = []
  28.         points_list = data_dict['points']
  29.         neighbors_list = data_dict['neighbors']
  30.         subsampling_list = data_dict['subsampling']
  31.         upsampling_list = data_dict['upsampling']
  32.         feats_s1 = feats
  33.         feats_s1 = self.encoder1_1(feats_s1, points_list[0], points_list[0], neighbors_list[0])
  34.         feats_s1 = self.encoder1_2(feats_s1, points_list[0], points_list[0], neighbors_list[0])
  35.         feats_s2 = self.encoder2_1(feats_s1, points_list[1], points_list[0], subsampling_list[0])
  36.         feats_s2 = self.encoder2_2(feats_s2, points_list[1], points_list[1], neighbors_list[1])
  37.         feats_s2 = self.encoder2_3(feats_s2, points_list[1], points_list[1], neighbors_list[1])
  38.         feats_s3 = self.encoder3_1(feats_s2, points_list[2], points_list[1], subsampling_list[1])
  39.         feats_s3 = self.encoder3_2(feats_s3, points_list[2], points_list[2], neighbors_list[2])
  40.         feats_s3 = self.encoder3_3(feats_s3, points_list[2], points_list[2], neighbors_list[2])
  41.         latent_s3 = feats_s3
  42.         feats_list.append(feats_s3)
  43.         latent_s2 = nearest_upsample(latent_s3, upsampling_list[1])
  44.         latent_s2 = torch.cat([latent_s2, feats_s2], dim=1)
  45.         latent_s2 = self.decoder2(latent_s2)
  46.         feats_list.append(latent_s2)
  47.         latent_s1 = nearest_upsample(latent_s2, upsampling_list[0])
  48.         latent_s1 = torch.cat([latent_s1, feats_s1], dim=1)
  49.         latent_s1 = self.decoder1(latent_s1)
  50.         feats_list.append(latent_s1)
  51.         feats_list.reverse()
  52.         return feats_list
复制代码
接下来完成了初步特征提取,就到了超点的GeoTransfrom处理阶段
Superpoint Matching

GeoTrasnformer

这里我们来看它是如何实现的
在实现具体的GeoTransfromer前,他首先定义了编码函数,我们知道这里的注意力多加的R是有两部分的,距离和角度编码
2.png

他们的具体计算方式都是需要进行编码的,如下所示
3.png

4.png

所以这里我们首先定义GeometricStructureEmbedding函数,完成这些编码的处理
  1. class GeometricStructureEmbedding(nn.Module):
  2.     def __init__(self, hidden_dim, sigma_d, sigma_a, angle_k, reduction_a='max'):
  3.         super(GeometricStructureEmbedding, self).__init__()
  4.         self.sigma_d = sigma_d#给定距离的温度参数
  5.         self.sigma_a = sigma_a#给定角度的温度参数
  6.         self.factor_a = 180.0 / (self.sigma_a * np.pi)#定义角度缩放因子
  7.         self.angle_k = angle_k #用于角度嵌入的最近邻数量
  8.         self.embedding = SinusoidalPositionalEmbedding(hidden_dim)#位置编码嵌入,这里其实也就是PE编码嵌入的部分
  9.         self.proj_d = nn.Linear(hidden_dim, hidden_dim)
  10.         self.proj_a = nn.Linear(hidden_dim, hidden_dim)
  11.         self.reduction_a = reduction_a
  12.         if self.reduction_a not in ['max', 'mean']:#检查角度嵌入的归约方式是否有效
  13.             raise ValueError(f'Unsupported reduction mode: {self.reduction_a}.')
  14.     @torch.no_grad()#禁用梯度计算,因为这里是纯粹的索引计算
  15.     def get_embedding_indices(self, points):#计算距离和角度嵌入的索引
  16.         batch_size, num_point, _ = points.shape#获取批量大小和点的数量
  17.         dist_map = torch.sqrt(pairwise_distance(points, points))  # (B, N, N),计算点云之间的成对距离
  18.         d_indices = dist_map / self.sigma_d # (B, N, N),计算距离嵌入的索引,对应公式中的正弦位置编码部分,现在再用PE编码就已经实现了
  19.         k = self.angle_k #获取用于角度嵌入的最近邻数量
  20.         knn_indices = dist_map.topk(k=k + 1, dim=2, largest=False)[1][:, :, 1:]  # (B, N, k) ,找到每个点的k个最近邻点的索引,排除自身
  21.         knn_indices = knn_indices.unsqueeze(3).expand(batch_size, num_point, k, 3)  # (B, N, k, 3),在第三个维度后加一个新维度,并扩展以匹配点的坐标维度
  22.         expanded_points = points.unsqueeze(1).expand(batch_size, num_point, num_point, 3)  # (B, N, N, 3),扩展点云以便与每个点进行比较
  23.         knn_points = torch.gather(expanded_points, dim=2, index=knn_indices)  # (B, N, k, 3),获取每个点的k个最近邻点的坐标
  24.         ref_vectors = knn_points - points.unsqueeze(2)  # (B, N, k, 3),计算参考向量,即从当前点指向其k个最近邻点的向量
  25.         anc_vectors = points.unsqueeze(1) - points.unsqueeze(2)  # (B, N, N, 3),计算锚点向量,即从每个点指向所有其他点的向量
  26.         ref_vectors = ref_vectors.unsqueeze(2).expand(batch_size, num_point, num_point, k, 3)  # (B, N, N, k, 3),扩展参考向量以匹配锚点向量的维度  
  27.         anc_vectors = anc_vectors.unsqueeze(3).expand(batch_size, num_point, num_point, k, 3)  # (B, N, N, k, 3),扩展锚点向量以匹配参考向量的维度
  28.         sin_values = torch.linalg.norm(torch.cross(ref_vectors, anc_vectors, dim=-1), dim=-1)  # (B, N, N, k),计算参考向量和锚点向量的叉积的模长,得到正弦值
  29.         cos_values = torch.sum(ref_vectors * anc_vectors, dim=-1)  # (B, N, N, k),计算参考向量和锚点向量的点积,得到余弦值
  30.         angles = torch.atan2(sin_values, cos_values)  # (B, N, N, k),计算角度值,使用atan2函数结合正弦值和余弦值
  31.         a_indices = angles * self.factor_a# (B, N, N, k),计算角度嵌入的索引,对应公式中的正弦位置编码部分
  32.         return d_indices, a_indices#返回距离和角度嵌入的索引
  33.     def forward(self, points):
  34.         d_indices, a_indices = self.get_embedding_indices(points)#获取距离和角度嵌入的索引
  35.         d_embeddings = self.embedding(d_indices)#距离嵌入,通过正弦位置编码获取嵌入特征
  36.         d_embeddings = self.proj_d(d_embeddings)#线性变换映射到隐藏维度
  37.         a_embeddings = self.embedding(a_indices)#角度嵌入,通过正弦位置编码获取嵌入特征
  38.         a_embeddings = self.proj_a(a_embeddings)#线性变换映射到隐藏维度
  39.         if self.reduction_a == 'max':#对角度嵌入进行归约操作
  40.             a_embeddings = a_embeddings.max(dim=3)[0]#取最大值
  41.         else:
  42.             a_embeddings = a_embeddings.mean(dim=3)#取均值
  43.         embeddings = d_embeddings + a_embeddings#将距离嵌入和角度嵌入相加,得到最终的几何结构嵌入特征
  44.         return embeddings
复制代码
看着很复杂,实则就是计算每个点与其他所有点之间的相对距离和相对角度,来编码点云内部的几何结构信息
接下来就是GeoTransfrom的实现了,这里实际上就是利用刚刚所定义的函数,对源点云和目标点云分别进行编码,然后进行Transformer变换,具体代码如下
  1. class GeometricTransformer(nn.Module):
  2.     def __init__(
  3.         self,
  4.         input_dim,
  5.         output_dim,
  6.         hidden_dim,
  7.         num_heads,
  8.         blocks,
  9.         sigma_d,
  10.         sigma_a,
  11.         angle_k,
  12.         dropout=None,
  13.         activation_fn='ReLU',
  14.         reduction_a='max',
  15.     ):
  16.         super(GeometricTransformer, self).__init__()
  17.         self.embedding = GeometricStructureEmbedding(hidden_dim, sigma_d, sigma_a, angle_k, reduction_a=reduction_a)#几何结构嵌入模块
  18.         self.in_proj = nn.Linear(input_dim, hidden_dim)#输入特征映射到隐藏维度
  19.         self.transformer = RPEConditionalTransformer(
  20.             blocks, hidden_dim, num_heads, dropout=dropout, activation_fn=activation_fn
  21.         )#堆叠的自注意力和交叉注意力模块组成的变换器
  22.         self.out_proj = nn.Linear(hidden_dim, output_dim)
  23.     def forward(
  24.         self,
  25.         ref_points,
  26.         src_points,
  27.         ref_feats,
  28.         src_feats,
  29.         ref_masks=None,
  30.         src_masks=None,
  31.     ):
  32.         ref_embeddings = self.embedding(ref_points)#对源点云位置进行几何结构嵌入   
  33.         src_embeddings = self.embedding(src_points)#对目标点云位置进行几何结构嵌入
  34.         ref_feats = self.in_proj(ref_feats)#对源点云特征进行线性映射
  35.         src_feats = self.in_proj(src_feats)#对目标点云特征进行线性映射
  36.         ref_feats, src_feats = self.transformer(
  37.             ref_feats,
  38.             src_feats,
  39.             ref_embeddings,
  40.             src_embeddings,
  41.             masks0=ref_masks,
  42.             masks1=src_masks,
  43.         )#通过变换器模块进行特征变换
  44.         ref_feats = self.out_proj(ref_feats)#对源点云特征进行线性映射
  45.         src_feats = self.out_proj(src_feats)#对目标点云特征进行线性映射
  46.         return ref_feats, src_feats
复制代码
这里我们发现了这个RPEConditionalTransformer函数,这里其实就可以发现它是跟我们之前所学理论知识完全对应,当是自注意力机制时,QKV均来自自身,而如果是交叉注意力,那么它的就是Q来自本身,而KV来自于另一点云,代码如下:
  1. class RPEConditionalTransformer(nn.Module):
  2.     def __init__(
  3.         self,
  4.         blocks,
  5.         d_model,
  6.         num_heads,
  7.         dropout=None,
  8.         activation_fn='ReLU',
  9.         return_attention_scores=False,
  10.         parallel=False,
  11.     ):
  12.         super(RPEConditionalTransformer, self).__init__()
  13.         self.blocks = blocks
  14.         layers = []
  15.         for block in self.blocks:#遍历每个块的类型
  16.             _check_block_type(block)
  17.             if block == 'self':#如果是自注意力块
  18.                 layers.append(RPETransformerLayer(d_model, num_heads, dropout=dropout, activation_fn=activation_fn))
  19.             else:#如果是交叉注意力块   
  20.                 layers.append(TransformerLayer(d_model, num_heads, dropout=dropout, activation_fn=activation_fn))
  21.         self.layers = nn.ModuleList(layers)
  22.         self.return_attention_scores = return_attention_scores
  23.         self.parallel = parallel
  24.     def forward(self, feats0, feats1, embeddings0, embeddings1, masks0=None, masks1=None):
  25.         attention_scores = []
  26.         for i, block in enumerate(self.blocks):
  27.             if block == 'self':#如果是自注意力块
  28.                 feats0, scores0 = self.layers[i](feats0, feats0, embeddings0, memory_masks=masks0)
  29.                 feats1, scores1 = self.layers[i](feats1, feats1, embeddings1, memory_masks=masks1)
  30.             else:#如果是交叉注意力块
  31.                 if self.parallel:#如果是并行计算
  32.                     new_feats0, scores0 = self.layers[i](feats0, feats1, memory_masks=masks1)#计算源点云特征的新表示和注意力分数
  33.                     new_feats1, scores1 = self.layers[i](feats1, feats0, memory_masks=masks0)#计算目标点云特征的新表示和注意力分数
  34.                     feats0 = new_feats0#更新源点云特征
  35.                     feats1 = new_feats1#更新目标点云特征
  36.                 else:
  37.                     feats0, scores0 = self.layers[i](feats0, feats1, memory_masks=masks1)
  38.                     feats1, scores1 = self.layers[i](feats1, feats0, memory_masks=masks0)
  39.             if self.return_attention_scores:
  40.                 attention_scores.append([scores0, scores1])
  41.         if self.return_attention_scores:
  42.             return feats0, feats1, attention_scores
  43.         else:
  44.             return feats0, feats1
复制代码
这里可以发现的是他们实现自注意力机制和交叉注意力使用的函数分别是RPETransformerLayer和TransformerLayer,接下来我们来分别看一下对应代码
  1. class RPETransformerLayer(nn.Module):
  2.     def __init__(self, d_model, num_heads, dropout=None, activation_fn='ReLU'):
  3.         super(RPETransformerLayer, self).__init__()
  4.         self.attention = RPEAttentionLayer(d_model, num_heads, dropout=dropout)#注意力层
  5.         self.output = AttentionOutput(d_model, dropout=dropout, activation_fn=activation_fn)#输出层
  6.     def forward(
  7.         self,
  8.         input_states,
  9.         memory_states,
  10.         position_states,
  11.         memory_weights=None,
  12.         memory_masks=None,
  13.         attention_factors=None,
  14.     ):
  15.         hidden_states, attention_scores = self.attention(
  16.             input_states,
  17.             memory_states,
  18.             position_states,
  19.             memory_weights=memory_weights,
  20.             memory_masks=memory_masks,
  21.             attention_factors=attention_factors,
  22.         )#通过注意力层计算隐藏状态和注意力分数
  23.         output_states = self.output(hidden_states)#通过输出层计算最终的输出状态
  24.         return output_states, attention_scores#返回输出状态和注意力分数
复制代码
这里首先使用RPEAttentionLayer进行注意力层的计算,然后一般注意力处理后的还需要进行线性层和dropout挑选后才能输出,所以还有一个output函数进行处理,这里我们继续跟进看注意力层的实现
  1. class RPEAttentionLayer(nn.Module):
  2.     def __init__(self, d_model, num_heads, dropout=None):
  3.         super(RPEAttentionLayer, self).__init__()
  4.         self.attention = RPEMultiHeadAttention(d_model, num_heads, dropout=dropout)
  5.         self.linear = nn.Linear(d_model, d_model)
  6.         self.dropout = build_dropout_layer(dropout)
  7.         self.norm = nn.LayerNorm(d_model)
  8.     def forward(
  9.         self,
  10.         input_states,
  11.         memory_states,
  12.         position_states,
  13.         memory_weights=None,
  14.         memory_masks=None,
  15.         attention_factors=None,
  16.     ):
  17.         hidden_states, attention_scores = self.attention(
  18.             input_states,
  19.             memory_states,
  20.             memory_states,
  21.             position_states,
  22.             key_weights=memory_weights,
  23.             key_masks=memory_masks,
  24.             attention_factors=attention_factors,
  25.         )
  26.         hidden_states = self.linear(hidden_states)
  27.         hidden_states = self.dropout(hidden_states)
  28.         output_states = self.norm(hidden_states + input_states)
  29.         return output_states, attention_scores
复制代码
这里可以发现它是使用RPEMultiHeadAttention多头注意力函数实现的,然后加上了线性层,dropout和标准化这些,最终得到了输出状态和注意力分数,跟进这个多头注意力函数观察其实现
  1. class RPEMultiHeadAttention(nn.Module):
  2.     def __init__(self, d_model, num_heads, dropout=None):
  3.         super(RPEMultiHeadAttention, self).__init__()
  4.         if d_model % num_heads != 0:
  5.             raise ValueError('`d_model` ({}) must be a multiple of `num_heads` ({}).'.format(d_model, num_heads))
  6.         self.d_model = d_model#总的特征维度
  7.         self.num_heads = num_heads#头的数量
  8.         self.d_model_per_head = d_model // num_heads#每个头的特征维度
  9.         self.proj_q = nn.Linear(self.d_model, self.d_model)#即q=Wq*x+b_q的实现,查询的线性映射
  10.         self.proj_k = nn.Linear(self.d_model, self.d_model)#即k=Wk*x+b_k的实现,键的线性映射
  11.         self.proj_v = nn.Linear(self.d_model, self.d_model)#值的线性映射
  12.         self.proj_p = nn.Linear(self.d_model, self.d_model)#相对位置嵌入的线性映射
  13.         self.dropout = build_dropout_layer(dropout)#dropout层
  14.     def forward(self, input_q, input_k, input_v, embed_qk, key_weights=None, key_masks=None, attention_factors=None):
  15.         q = rearrange(self.proj_q(input_q), 'b n (h c) -> b h n c', h=self.num_heads)#查询的多头表示,从b n (h c)变换到b h n c
  16.         k = rearrange(self.proj_k(input_k), 'b m (h c) -> b h m c', h=self.num_heads)
  17.         v = rearrange(self.proj_v(input_v), 'b m (h c) -> b h m c', h=self.num_heads)
  18.         p = rearrange(self.proj_p(embed_qk), 'b n m (h c) -> b h n m c', h=self.num_heads)
  19.         attention_scores_p = torch.einsum('bhnc,bhnmc->bhnm', q, p)#计算查询和相对位置嵌入之间的注意力分数
  20.         attention_scores_e = torch.einsum('bhnc,bhmc->bhnm', q, k)#计算查询和键之间的注意力分数
  21.         attention_scores = (attention_scores_e + attention_scores_p) / self.d_model_per_head ** 0.5#缩放注意力分数
  22.         if attention_factors is not None:#如果提供了注意力因子
  23.             attention_scores = attention_factors.unsqueeze(1) * attention_scores#调整注意力分数
  24.         if key_weights is not None:#如果提供了键的权重
  25.             attention_scores = attention_scores * key_weights.unsqueeze(1).unsqueeze(1)
  26.         if key_masks is not None:#如果提供了键的掩码
  27.             attention_scores = attention_scores.masked_fill(key_masks.unsqueeze(1).unsqueeze(1), float('-inf'))
  28.         attention_scores = F.softmax(attention_scores, dim=-1)#计算注意力分数的softmax
  29.         attention_scores = self.dropout(attention_scores)#应用dropout
  30.         hidden_states = torch.matmul(attention_scores, v)#计算加权值
  31.         hidden_states = rearrange(hidden_states, 'b h n c -> b n (h c)')#重新排列隐藏状态的形状
  32.         return hidden_states, attention_scores#返回隐藏状态和注意力分数
复制代码
这就是它的实现了,然后另一个交叉注意力机制代码与此类似,不过这里的不同的是自注意力是RPE,即带有位置编码的transformer,而交叉注意力则是普通的,只有QKV。在进行过transform变换后,接下来就是超点的筛选了
SuperPointTargetGenerator&&Matching

这里有两个函数,一个是SuperPointTargetGenerator从真值数据中筛选高质量的点对应关系,为训练提供监督目标,另一个是SuperPointMatching,它是基于学习到的特征相似度,预测两个点云之间的对应匹配关系。
具体实现如下,这里主要是筛选出重叠度高于阈值对应关系,如果数量过多就会随机挑选数量为目标数量的点对。
  1. class SuperPointTargetGenerator(nn.Module):
  2.     def __init__(self, num_targets, overlap_threshold):
  3.         super(SuperPointTargetGenerator, self).__init__()
  4.         self.num_targets = num_targets#目标数量
  5.         self.overlap_threshold = overlap_threshold#重叠阈值
  6.     @torch.no_grad()#禁用梯度计算
  7.     def forward(self, gt_corr_indices, gt_corr_overlaps):
  8.         gt_corr_masks = torch.gt(gt_corr_overlaps, self.overlap_threshold)#筛选出重叠度高于阈值的对应关系
  9.         gt_corr_overlaps = gt_corr_overlaps[gt_corr_masks]#筛选对应的重叠度
  10.         gt_corr_indices = gt_corr_indices[gt_corr_masks]#筛选对应的索引
  11.         if gt_corr_indices.shape[0] > self.num_targets:#如果筛选后的对应关系数量超过目标数量
  12.             indices = np.arange(gt_corr_indices.shape[0])#生成对应关系的索引数组
  13.             sel_indices = np.random.choice(indices, self.num_targets, replace=False)#随机选择目标数量的索引
  14.             sel_indices = torch.from_numpy(sel_indices).cuda()#转换为GPU张量
  15.             gt_corr_indices = gt_corr_indices[sel_indices]#选择对应的索引
  16.             gt_corr_overlaps = gt_corr_overlaps[sel_indices]#选择对应的重叠度
  17.         gt_ref_corr_indices = gt_corr_indices[:, 0]#选择参考点云中的超点索引
  18.         gt_src_corr_indices = gt_corr_indices[:, 1]#选择源点云中的超点索引
  19.         return gt_ref_corr_indices, gt_src_corr_indices, gt_corr_overlaps#返回选择的超点索引和对应的重叠度
复制代码
这个则是主要对点进行双归一化,然后得到匹配得分矩阵,选出前top-k个作为匹配好的点对。
  1. class SuperPointMatching(nn.Module):
  2.     def __init__(self, num_correspondences, dual_normalization=True):
  3.         super(SuperPointMatching, self).__init__()
  4.         self.num_correspondences = num_correspondences#最大对应点对数量
  5.         self.dual_normalization = dual_normalization#是否使用双重归一化
  6.     def forward(self, ref_feats, src_feats, ref_masks=None, src_masks=None):
  7.         if ref_masks is None:#如果参考点云的掩码为空
  8.             ref_masks = torch.ones(size=(ref_feats.shape[0],), dtype=torch.bool).cuda()
  9.         if src_masks is None:#如果源点云的掩码为空
  10.             src_masks = torch.ones(size=(src_feats.shape[0],), dtype=torch.bool).cuda()
  11.         # remove empty patch
  12.         ref_indices = torch.nonzero(ref_masks, as_tuple=True)[0]#获取参考点云中非空超点的索引
  13.         src_indices = torch.nonzero(src_masks, as_tuple=True)[0]#获取源点云中非空超点的索引
  14.         ref_feats = ref_feats[ref_indices]#选择参考点云非空超点的特征
  15.         src_feats = src_feats[src_indices]#选择源点云非空超点的特征
  16.         # select top-k proposals
  17.         matching_scores = torch.exp(-pairwise_distance(ref_feats, src_feats, normalized=True))#计算参考点云和源点云超点特征之间的匹配分数
  18.         if self.dual_normalization:#如果使用双重归一化
  19.             ref_matching_scores = matching_scores / matching_scores.sum(dim=1, keepdim=True)#归一化参考点云的匹配分数
  20.             src_matching_scores = matching_scores / matching_scores.sum(dim=0, keepdim=True)#归一化源点云的匹配分数
  21.             matching_scores = ref_matching_scores * src_matching_scores#结合两种归一化的匹配分数
  22.         num_correspondences = min(self.num_correspondences, matching_scores.numel())#确定实际的对应点对数量
  23.         corr_scores, corr_indices = matching_scores.view(-1).topk(k=num_correspondences, largest=True)#选择得分最高的对应点对
  24.         ref_sel_indices = corr_indices // matching_scores.shape[1]#计算参考点云中选择的超点索引
  25.         src_sel_indices = corr_indices % matching_scores.shape[1]#计算源点云中选择的超点索引
  26.         # recover original indices
  27.         ref_corr_indices = ref_indices[ref_sel_indices]#选择参考点云中对应的超点索引
  28.         src_corr_indices = src_indices[src_sel_indices]#选择源点云中对应的超点索引
  29.         return ref_corr_indices, src_corr_indices, corr_scores
复制代码
这里得到top-k个点对,就来到了第三部分,将这部分和上采样部分一起输入到Point Matching Module
Point Matching Module

这里主要是进行精细匹配,它会挑选出对应的点对,然后分别对源点云和目标点云的点进行置信度矩阵计算,得到两个矩阵,如果是双向一致,就进行与操作,否则两矩阵就进行或操作,然后移除无效匹配,返回最终的索引和坐标
  1. class PointMatching(nn.Module):
  2.     def __init__(
  3.         self,
  4.         k: int,
  5.         mutual: bool = True,
  6.         confidence_threshold: float = 0.05,
  7.         use_dustbin: bool = False,
  8.         use_global_score: bool = False,
  9.         remove_duplicate: bool = False,
  10.     ):
  11.         r"""Point Matching with Local-to-Global Registration.
  12.         Args:
  13.             k (int): top-k selection for matching.
  14.             mutual (bool=True): mutual or non-mutual matching.
  15.             confidence_threshold (float=0.05): ignore matches whose scores are below this threshold.
  16.             use_dustbin (bool=False): whether dustbin row/column is used in the score matrix.
  17.             use_global_score (bool=False): whether use patch correspondence scores.
  18.         """
  19.         super(PointMatching, self).__init__()
  20.         self.k = k#给定的top-k值
  21.         self.mutual = mutual#是否进行互相匹配
  22.         self.confidence_threshold = confidence_threshold#置信度阈值
  23.         self.use_dustbin = use_dustbin#是否使用尘箱
  24.         self.use_global_score = use_global_score#是否使用全局分数
  25.         self.remove_duplicate = remove_duplicate#是否移除重复匹配
  26.     def compute_correspondence_matrix(self, score_mat, ref_knn_masks, src_knn_masks):#定义计算对应矩阵的方法
  27.         r"""Compute matching matrix and score matrix for each patch correspondence."""
  28.         mask_mat = torch.logical_and(ref_knn_masks.unsqueeze(2), src_knn_masks.unsqueeze(1))#计算掩码矩阵
  29.         batch_size, ref_length, src_length = score_mat.shape#获取批量大小、参考点云长度和源点云长度
  30.         batch_indices = torch.arange(batch_size).cuda()#生成批量索引
  31.         # correspondences from reference side,即参考点云侧的对应关系
  32.         ref_topk_scores, ref_topk_indices = score_mat.topk(k=self.k, dim=2)  # (B, N, K),给定top-k值,获取每个参考点的k个最高分数及其索引
  33.         ref_batch_indices = batch_indices.view(batch_size, 1, 1).expand(-1, ref_length, self.k)  # (B, N, K),扩展批量索引以匹配参考点和top-k维度
  34.         ref_indices = torch.arange(ref_length).cuda().view(1, ref_length, 1).expand(batch_size, -1, self.k)  # (B, N, K),扩展参考点索引以匹配批量和top-k维度
  35.         ref_score_mat = torch.zeros_like(score_mat)#初始化参考点云的分数矩阵
  36.         ref_score_mat[ref_batch_indices, ref_indices, ref_topk_indices] = ref_topk_scores#将top-k分数填充到参考点云的分数矩阵中
  37.         ref_corr_mat = torch.gt(ref_score_mat, self.confidence_threshold)#生成参考点云的对应矩阵,基于置信度阈值   
  38.         # correspondences from source side,即源点云侧的对应关系
  39.         src_topk_scores, src_topk_indices = score_mat.topk(k=self.k, dim=1)  # (B, K, N)
  40.         src_batch_indices = batch_indices.view(batch_size, 1, 1).expand(-1, self.k, src_length)  # (B, K, N)
  41.         src_indices = torch.arange(src_length).cuda().view(1, 1, src_length).expand(batch_size, self.k, -1)  # (B, K, N)
  42.         src_score_mat = torch.zeros_like(score_mat)
  43.         src_score_mat[src_batch_indices, src_topk_indices, src_indices] = src_topk_scores
  44.         src_corr_mat = torch.gt(src_score_mat, self.confidence_threshold)
  45.         # merge results from two sides
  46.         if self.mutual:#如果是互相匹配  
  47.             corr_mat = torch.logical_and(ref_corr_mat, src_corr_mat)#参考点云和源点云的对应矩阵进行逻辑与操作   
  48.         else:#如果不是互相匹配
  49.             corr_mat = torch.logical_or(ref_corr_mat, src_corr_mat)#参考点云和源点云的对应矩阵进行逻辑或操作
  50.         if self.use_dustbin:#如果使用尘箱
  51.             corr_mat = corr_mat[:, -1:, -1]#保留尘箱对应的行和列
  52.         corr_mat = torch.logical_and(corr_mat, mask_mat)#应用掩码矩阵,移除无效匹配
  53.         return corr_mat
  54.     def forward(
  55.         self,
  56.         ref_knn_points,
  57.         src_knn_points,
  58.         ref_knn_masks,
  59.         src_knn_masks,
  60.         ref_knn_indices,
  61.         src_knn_indices,
  62.         score_mat,
  63.         global_scores,
  64.     ):
  65.         score_mat = torch.exp(score_mat)#将对数似然转换为概率   
  66.         corr_mat = self.compute_correspondence_matrix(score_mat, ref_knn_masks, src_knn_masks)  # (B, K, K),计算对应矩阵
  67.         if self.use_dustbin:#如果使用尘箱
  68.             score_mat = score_mat[:, :-1, :-1]
  69.         if self.use_global_score:#如果使用全局分数
  70.             score_mat = score_mat * global_scores.view(-1, 1, 1)#结合全局分数调整匹配分数
  71.         score_mat = score_mat * corr_mat.float()#应用对应矩阵调整匹配分数
  72.         batch_indices, ref_indices, src_indices = torch.nonzero(corr_mat, as_tuple=True)#获取非零元素的批量索引、参考点索引和源点索引
  73.         ref_corr_indices = ref_knn_indices[batch_indices, ref_indices]#获取参考点云中对应的超点索引
  74.         src_corr_indices = src_knn_indices[batch_indices, src_indices]#获取源点云中对应的超点索引
  75.         ref_corr_points = ref_knn_points[batch_indices, ref_indices]#获取参考点云中对应的超点坐标
  76.         src_corr_points = src_knn_points[batch_indices, src_indices]#获取源点云中对应的超点坐标
  77.         corr_scores = score_mat[batch_indices, ref_indices, src_indices]#获取对应点对的匹配分数
  78.         return ref_corr_points, src_corr_points, ref_corr_indices, src_corr_indices, corr_scores#返回对应点对的坐标、索引和匹配分数
复制代码
Local to Global Registraion

这里看主要函数的实现,它其实就是对点对进行分块,然后从惊喜的匹配中不断计算R和t,找出内点数量最多的作为最终的R和t。
  1.    def local_to_global_registration(self, ref_knn_points, src_knn_points, score_mat, corr_mat):#本地到全局注册
  2.         # extract dense correspondences
  3.         batch_indices, ref_indices, src_indices = torch.nonzero(corr_mat, as_tuple=True)#获取非零元素的批量索引、参考点索引和源点索引
  4.         global_ref_corr_points = ref_knn_points[batch_indices, ref_indices]#获取参考点云的对应点
  5.         global_src_corr_points = src_knn_points[batch_indices, src_indices]#获取源点云的对应点
  6.         global_corr_scores = score_mat[batch_indices, ref_indices, src_indices]#获取对应点的分数
  7.         # build verification set,即建立验证集
  8.         if self.correspondence_limit is not None and global_corr_scores.shape[0] > self.correspondence_limit:#限制对应点的数量
  9.             corr_scores, sel_indices = global_corr_scores.topk(k=self.correspondence_limit, largest=True)#选择得分最高的对应点
  10.             ref_corr_points = global_ref_corr_points[sel_indices]#选择参考点云中对应的点
  11.             src_corr_points = global_src_corr_points[sel_indices]#选择源点云中对应的点
  12.         else:#不限制对应点的数量
  13.             ref_corr_points = global_ref_corr_points#选择参考点云中对应的点
  14.             src_corr_points = global_src_corr_points#选择源点云中对应的点
  15.             corr_scores = global_corr_scores#选择对应点的分数
  16.         # compute starting and ending index of each patch correspondence.
  17.         # torch.nonzero is row-major, so the correspondences from the same patch correspondence are consecutive.
  18.         # find the first occurrence of each batch index, then the chunk of this batch can be obtained.
  19.         unique_masks = torch.ne(batch_indices[1:], batch_indices[:-1])#找到每个批次索引的第一次出现
  20.         unique_indices = torch.nonzero(unique_masks, as_tuple=True)[0] + 1#调整索引以匹配原始张量
  21.         unique_indices = unique_indices.detach().cpu().numpy().tolist()#将张量转换为列表
  22.         unique_indices = [0] + unique_indices + [batch_indices.shape[0]]#添加起始和结束索引
  23.         chunks = [
  24.             (x, y) for x, y in zip(unique_indices[:-1], unique_indices[1:]) if y - x >= self.correspondence_threshold
  25.         ]#为每个批次创建块,确保每个块至少有最小数量的对应点
  26.         batch_size = len(chunks)#计算批次大小
  27.         if batch_size > 0:#如果批次大小大于0
  28.             # local registration
  29.             batch_ref_corr_points, batch_src_corr_points, batch_corr_scores = self.convert_to_batch(
  30.                 global_ref_corr_points, global_src_corr_points, global_corr_scores, chunks
  31.             )#转换为批量点
  32.             batch_transforms = self.procrustes(batch_src_corr_points, batch_ref_corr_points, batch_corr_scores)#计算变换矩阵
  33.             batch_aligned_src_corr_points = apply_transform(src_corr_points.unsqueeze(0), batch_transforms)#应用变换矩阵
  34.             batch_corr_residuals = torch.linalg.norm(
  35.                 ref_corr_points.unsqueeze(0) - batch_aligned_src_corr_points, dim=2
  36.             )#计算残差
  37.             batch_inlier_masks = torch.lt(batch_corr_residuals, self.acceptance_radius)  # (P, N),即进行lier掩码
  38.             best_index = batch_inlier_masks.sum(dim=1).argmax()#选择具有最多内点的批次作为最佳索引
  39.             cur_corr_scores = corr_scores * batch_inlier_masks[best_index].float()#更新当前对应分数
  40.         else:#如果批次大小为0
  41.             # degenerate: initialize transformation with all correspondences
  42.             estimated_transform = self.procrustes(src_corr_points, ref_corr_points, corr_scores)#计算初始变换矩阵
  43.             cur_corr_scores = self.recompute_correspondence_scores(
  44.                 ref_corr_points, src_corr_points, corr_scores, estimated_transform
  45.             )#更新当前对应分数
  46.         # global refinement
  47.         estimated_transform = self.procrustes(src_corr_points, ref_corr_points, cur_corr_scores)#计算变换矩阵
  48.         for _ in range(self.num_refinement_steps - 1):
  49.             cur_corr_scores = self.recompute_correspondence_scores(
  50.                 ref_corr_points, src_corr_points, corr_scores, estimated_transform
  51.             )# 根据当前变换重新计算内点
  52.             estimated_transform = self.procrustes(src_corr_points, ref_corr_points, cur_corr_scores)#计算变换矩阵
  53.         return global_ref_corr_points, global_src_corr_points, global_corr_scores, estimated_transform#返回参考点云对应点、源点云对应点、对应分数和估计变换矩阵
复制代码
来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!

相关推荐

您需要登录后才可以回帖 登录 | 立即注册