# Flatten the spatial dims of f: [batch_size, C, H, W] -> [batch_size, C, H*W].
f = f.view(n, -1, h * w)
# L2-normalise every pixel's feature vector over the channel axis (dim 1);
# the 1e-5 keeps the division finite for all-zero features.
f = f / (f.norm(dim=1, keepdim=True) + 1e-5)
# Pixel-to-pixel similarity of the normalised features, shape
# [batch_size, H*W, H*W]; negative similarities are zeroed in place.
aff = F.relu(torch.bmm(f.permute(0, 2, 1), f), inplace=True)
# Normalise over dim 1 so each column sums to ~1 (entries end up in [0, 1]).
aff = aff / (aff.sum(dim=1, keepdim=True) + 1e-5)
# Flatten the CAM the same way.
# NOTE(review): assumes cam has the same H*W spatial size as f -- confirm.
cam = cam.view(n, -1, h * w)
# Each output pixel is an affinity-weighted average of the CAM; restore
# the [batch_size, K, H, W] layout.
cam_rv = torch.bmm(cam, aff).view(n, -1, h, w)
# Alternative approach (另外一种方法):
# Compute the raw similarity between every query position and every key position.
# NOTE(review): assumes query/key are 3-D [batch, C, N] tensors, giving
# [batch, N_query, N_key] -- confirm against the caller.
sim_map = query.transpose(1, 2).bmm(key)
# Two successive scalings: first /16, then /0.1. 16 is presumably sqrt of the
# channel dim and 0.1 a softmax temperature -- TODO confirm. They are kept as
# two separate divisions so the floating-point result is unchanged.
sim_map = sim_map / 16 / 0.1
# Normalise with the softmax module owned by the enclosing object; the
# reduction dim is configured wherever self.softmax is defined.
sim_map = self.softmax(sim_map)