Let us start from the overall architecture of PCT. The basic building blocks it needs are the encoder used as the Input Embedding and the Attention modules; in addition, the sampling stage relies on the knn and FPS (farthest point sampling) algorithms. Below we implement each of them in turn.
1. Attention
1.1 Self Attention
import math

import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    def __init__(self, in_f, dim_k=None, dim_v=None, transform='SS'):
        """
        Self-attention mechanism.
        :param in_f: dim of input feature
        :param dim_k: dim of query/key vectors (default: in_f)
        :param dim_v: dim of value vectors, and also the 3rd dim of the output (default: in_f)
        :param transform: 'SS' (default) means Scale + SoftMax, 'SL' means SoftMax + L1Norm
        """
        super().__init__()
        self.dim_k = dim_k if dim_k else in_f
        self.dim_v = dim_v if dim_v else in_f
        self.transform = transform
        self.Q = nn.Linear(in_f, self.dim_k)
        self.K = nn.Linear(in_f, self.dim_k)
        self.V = nn.Linear(in_f, self.dim_v)

    def forward(self, x):
        # x: (batch_size, num_points, in_f)
        Q = self.Q(x)
        K = self.K(x)
        V = self.V(x)
        QK = torch.matmul(Q, K.permute(0, 2, 1))  # (B, N, N) attention logits
        if self.transform == 'SS':
            # Scale + SoftMax: scaled dot-product attention, softmax over the key dimension
            att_score = torch.softmax(QK / math.sqrt(self.dim_k), dim=-1)
        elif self.transform == 'SL':
            # SoftMax + L1Norm: softmax over the query dimension,
            # then L1-normalise the softmaxed scores over the key dimension
            sm = torch.softmax(QK, dim=1)
            att_score = sm / (sm.sum(dim=2, keepdim=True) + 1e-9)
        else:
            raise ValueError(f'unknown transform: {self.transform}')
        Z = torch.matmul(att_score, V)  # (B, N, dim_v)
        return Z
We first implement a self-attention module. Besides letting you specify the dimensions of Q, K and V, it also lets you choose between the two normalisation schemes mentioned in the PCT paper, namely SS (Scale & SoftMax) and SL (SoftMax & L1Norm), where SS is the normalisation of the original self-attention. The paper gives the corresponding formulas (SS first, then SL).
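Reconstructed here in LaTeX from the PCT paper (the notation is mine and may differ slightly from the figures in the original):

SS (Scale & SoftMax), i.e. standard scaled dot-product attention:

\tilde{\alpha}_{i,j} = \frac{q_i \cdot k_j}{\sqrt{d_k}}, \qquad
\bar{\alpha}_{i,j} = \operatorname{softmax}_j(\tilde{\alpha}_{i,j}) = \frac{\exp(\tilde{\alpha}_{i,j})}{\sum_k \exp(\tilde{\alpha}_{i,k})}

SL (SoftMax & L1Norm), with the softmax taken over the first (query) index and an L1 normalisation over the second (key) index:

\tilde{\alpha}_{i,j} = \frac{\exp(q_i \cdot k_j)}{\sum_k \exp(q_k \cdot k_j)}, \qquad
\bar{\alpha}_{i,j} = \frac{\tilde{\alpha}_{i,j}}{\sum_k \tilde{\alpha}_{i,k}}

In both cases the output is Z = \bar{A} V, i.e. z_i = \sum_j \bar{\alpha}_{i,j} v_j.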
1.2 Offset Attention
Judging from the Offset-Attention architecture given in the PCT paper, it is essentially a refinement of self-attention; the authors state that it is inspired by the benefit of replacing the adjacency matrix E with the Laplacian matrix L = D − E in GNNs. It is not hard to implement: take the difference between the Z obtained from SA and the original input, pass it through an LBR (Linear, BatchNorm, ReLU) block, and add the result back onto the input to get the output.
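Written as a formula (again my paraphrase of the PCT paper, with F_{in} the input features and F_{sa} = \mathrm{SA}(F_{in}) the self-attention output):

F_{out} = \mathrm{LBR}(F_{in} - F_{sa}) + F_{in}

The code below computes the offset as F_{sa} - F_{in} rather than F_{in} - F_{sa}; since the first layer of the LBR is linear, the sign can be absorbed into its weights.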
class OffsetAttention(nn.Module):
    def __init__(self, num_points, in_f, dim_k=None, dim_v=None):
        """
        Offset-Attention.
        :param num_points: number of points per cloud
        :param in_f: dim of input feature
        :param dim_k: dim of query/key vectors (default: in_f)
        :param dim_v: dim of value vectors, and also the 3rd dim of the output (default: in_f);
                      it must equal in_f so that the offset and the residual are well-defined
        """
        super().__init__()
        self.dim_k = dim_k if dim_k else in_f
        self.dim_v = dim_v if dim_v else in_f
        # self-attention with the SL (SoftMax + L1Norm) normalisation
        self.sa = SelfAttention(in_f, self.dim_k, self.dim_v, 'SL')
        # LBR block: Linear + BatchNorm (here applied over the point dimension) + ReLU
        self.fc = nn.Linear(self.dim_v, self.dim_v)
        self.bn = nn.BatchNorm1d(num_points)
        self.relu = nn.ReLU()

    def forward(self, x):
        z = self.sa(x)  # self-attention features
        # offset between the attention features and the input, passed through LBR,
        # plus a residual connection back to the input
        x = self.relu(self.bn(self.fc(z.sub(x)))).add(x)
        return x
Here we reuse the SelfAttention implemented above as a sub-module, instantiated with the SL normalisation mentioned earlier.
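As a quick sanity check, both modules can be exercised on a random point cloud tensor; the shapes below are purely illustrative assumptions, not values taken from the paper:

# assumes the SelfAttention and OffsetAttention classes above are in scope
import torch

x = torch.rand(8, 1024, 128)                     # (batch_size, num_points, in_f)
sa = SelfAttention(in_f=128)                     # SS (Scale & SoftMax) by default
oa = OffsetAttention(num_points=1024, in_f=128)  # uses the SL variant internally
print(sa(x).shape)  # torch.Size([8, 1024, 128])
print(oa(x).shape)  # torch.Size([8, 1024, 128])

Both outputs keep the input shape, which is what allows the attention blocks to be stacked.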
2. Sampling and Grouping
This part implements the sampling and grouping networks.
2.1 KNN
from sklearn.neighbors import NearestNeighbors


def knn(x, k):
    """
    Group each point with its k nearest neighbours.
    :param x: (batch_size, num_point, in_f)
    :param k: number of neighbours per point
    :return: (batch_size, num_point, k, in_f)
    """
    batch_size, num_point, in_f = x.size()
    neigh = []
    for b in range(batch_size):
        # one sklearn NearestNeighbors searcher per point cloud in the batch
        neigh.append(NearestNeighbors(n_neighbors=k))
    new_x = torch.zeros(batch_size, num_point, k, in_f)
    for b in range(batch_size):