Description
Vision Transformer (ViT) is a deep learning model for computer vision built on the Transformer architecture, and it performs remarkably well. Before training, it is best to download pretrained weights: transfer learning makes training far more effective, while training from scratch on a small dataset can give very poor results.
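As a concrete example, a downloaded checkpoint can be loaded and its classification head discarded before fine-tuning on your own classes. This is only a minimal sketch: the checkpoint path, the `head.*` key names, and the `load_pretrained` helper are assumptions for illustration, not part of the code below.

import torch

def load_pretrained(model, weights_path="./vit_base_patch16_224.pth"):
    # Path above is a placeholder; point it at the weight file you downloaded.
    state_dict = torch.load(weights_path, map_location="cpu")
    # Drop the classification head: its shape matches the original 1000-class
    # ImageNet setup and usually differs from the fine-tuning task.
    for k in ("head.weight", "head.bias"):
        state_dict.pop(k, None)
    # strict=False tolerates the removed head keys; everything else must match.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    return model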
Code
"""
original code from rwightman:
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
from functools import partial
from collections import OrderedDict
import torch
from torch import nn
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample, applied in the main path of residual blocks."""
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize: each sample is kept (1) or dropped (0)
    output = x.div(keep_prob) * random_tensor  # rescale so the expected value is unchanged
    return output

class DropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

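# Illustrative sanity check of DropPath (an example added for this post, not
# part of the original file): in training mode each sample is either zeroed
# or rescaled by 1/keep_prob, so the expected activation is preserved.
def _demo_drop_path():
    dp = DropPath(drop_prob=0.2)
    dp.train()                      # DropPath is only active in training mode
    out = dp(torch.ones(8, 4))      # 8 samples, each kept with probability 0.8
    print(out[:, 0])                # entries are either 0.0 or 1 / 0.8 = 1.25
    dp.eval()
    print(torch.equal(dp(torch.ones(8, 4)), torch.ones(8, 4)))  # True: identity at inference
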
class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):
        super().__init__()
        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        # each non-overlapping patch is projected to embed_dim by a stride-patch_size convolution
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        x = self.proj(x).flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x
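
# Illustrative shape check of PatchEmbed (an example added for this post, not
# part of the original file): a 224x224 RGB image cut into 16x16 patches
# yields (224/16)**2 = 196 tokens, each a 768-dim embedding.
def _demo_patch_embed():
    patch_embed = PatchEmbed(img_size=224, patch_size=16, in_c=3, embed_dim=768)
    imgs = torch.randn(2, 3, 224, 224)   # dummy batch of two images
    tokens = patch_embed(imgs)
    print(tokens.shape)                  # torch.Size([2, 196, 768])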