import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Dropout, Conv2D, LayerNormalization, GlobalAveragePooling1D
CFGS = {
'swin_tiny_224': dict(input_size=(224, 224), window_size=7, embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24]),
'swin_small_224': dict(input_size=(224, 224), window_size=7, embed_dim=96, depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24]),
'swin_base_224': dict(input_size=(224, 224), window_size=7, embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32]),
'swin_base_384': dict(input_size=(384, 384), window_size=12, embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32]),
'swin_large_224': dict(input_size=(224, 224), window_size=7, embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48]),
'swin_large_384': dict(input_size=(384, 384), window_size=12, embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48])
}
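# Illustrative usage (not part of the model definition): pick a configuration by key,
# e.g. CFGS['swin_tiny_224'] gives a 224x224 input, 7x7 windows, embed_dim=96 and
# four stages of depths [2, 2, 6, 2] with [3, 6, 12, 24] attention heads.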
class Mlp(tf.keras.layers.Layer):
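    """Feed-forward block used inside each Swin block: Dense -> GELU -> Dropout -> Dense -> Dropout."""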
def __init__(self, in_features, hidden_features=None, out_features=None, drop=0., prefix=''):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = Dense(hidden_features, name=f'{prefix}/mlp/fc1')
self.fc2 = Dense(out_features, name=f'{prefix}/mlp/fc2')
self.drop = Dropout(drop)
def call(self, x):
x = self.fc1(x)
x = tf.keras.activations.gelu(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def window_partition(x, window_size):
B, H, W, C = x.get_shape().as_list()
x = tf.reshape(x, shape=[-1, H // window_size,
window_size, W // window_size, window_size, C])
x = tf.transpose(x, perm=[0, 1, 3, 2, 4, 5])
windows = tf.reshape(x, shape=[-1, window_size, window_size, C])
return windows
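# Shape walk-through (illustrative): for x of shape (B, 56, 56, 96) and window_size=7,
# window_partition returns (B * 8 * 8, 7, 7, 96), i.e. every 7x7 window becomes its own
# "batch" entry so that attention can be computed independently per window.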
def window_reverse(windows, window_size, H, W, C):
x = tf.reshape(windows, shape=[-1, H // window_size,
W // window_size, window_size, window_size, C])
x = tf.transpose(x, perm=[0, 1, 3, 2, 4, 5])
x = tf.reshape(x, shape=[-1, H, W, C])
return x
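# window_reverse is the exact inverse of window_partition: feeding the windows from the
# example above back with H=W=56 and C=96 reconstructs the original (B, 56, 56, 96) tensor.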
class WindowAttention(tf.keras.layers.Layer):
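    """Window-based multi-head self-attention with a learned relative position bias.

    Operates on inputs of shape (num_windows * B, window_size * window_size, dim) and
    optionally applies an additive attention mask for shifted windows.
    """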
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., prefix=''):
super().__init__()
self.dim = dim
self.window_size = window_size
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.prefix = prefix
self.qkv = Dense(dim * 3, use_bias=qkv_bias,
name=f'{self.prefix}/attn/qkv')
self.attn_drop = Dropout(attn_drop)
self.proj = Dense(dim, name=f'{self.prefix}/attn/proj')
self.proj_drop = Dropout(proj_drop)
def build(self, input_shape):
        self.relative_position_bias_table = self.add_weight(
            name=f'{self.prefix}/attn/relative_position_bias_table',
            shape=((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), self.num_heads),
            initializer=tf.initializers.Zeros(), trainable=True)
coords_h = np.arange(self.window_size[0])
coords_w = np.arange(self.window_size[1])
coords = np.stack(np.meshgrid(coords_h, coords_w, indexing='ij'))
coords_flatten = coords.reshape(2, -1)
relative_coords = coords_flatten[:, :,
None] - coords_flatten[:, None, :]
relative_coords = relative_coords.transpose([1, 2, 0])
relative_coords[:, :, 0] += self.window_size[0] - 1
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1).astype(np.int64)
self.relative_position_index = tf.Variable(initial_value=tf.convert_to_tensor(
relative_position_index), trainable=False, name=f'{self.prefix}/attn/relative_position_index')
self.built = True
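    # Worked example (illustrative) for a 2x2 window: the bias table has (2*2-1)*(2*2-1) = 9
    # entries per head, and relative_position_index[i, j] = (dh + 1) * 3 + (dw + 1) with
    # dh, dw in {-1, 0, 1}, so index 4 means "same position" while 0 and 8 are the two
    # diagonal extremes.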
def call(self, x, mask=None):
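        # x: (num_windows * B, N, C) with N = window_size[0] * window_size[1];
        # mask (optional): (num_windows, N, N) additive mask used for shifted windows.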
B_, N, C = x.get_shape().as_list()
qkv = tf.transpose(tf.reshape(self.qkv(
x), shape=[-1, N, 3, self.num_heads, C // self.num_heads]), perm=[2, 0, 3, 1, 4])
q, k, v = qkv[0], qkv[1], qkv[2]
q = q * self.scale
attn = (q @ tf.transpose(k, perm=[0, 1, 3, 2]))
relative_position_bias = tf.gather(self.relative_position_bias_table, tf.reshape(
self.relative_position_index, shape=[-1]))
relative_position_bias = tf.reshape(relative_position_bias, shape=[
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1])
relative_position_bias = tf.transpose(
relative_position_bias, perm=[2, 0, 1])
attn = attn + tf.expand_dims(relative_position_bias, axis=0)
if mask is not None:
nW = mask.get_shape()[0] # tf.shape(mask)[0]
attn = tf.reshape(attn, shape=[-1, nW, self.num_heads, N, N]) + tf.cast(
tf.expand_dims(tf.expand_dims(mask, axis=1), axis=0), attn.dtype)
attn = tf.reshape(attn, shape=[-1, self.num_heads, N, N])
attn = tf.nn.softmax(attn, axis=-1)
else:
attn = tf.nn.softmax(attn, axis=-1)
attn = self.attn_drop(attn)
x = tf.transpose((attn @ v), perm=[0, 2, 1, 3])
x = tf.reshape(x, shape=[-1, N, C])
x = self.proj(x)
x = self.proj_drop(x)
return x
def drop_path(inputs, drop_prob, is_training):
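    # Stochastic depth: during training, whole samples in the batch are randomly zeroed
    # with probability drop_prob and the survivors are rescaled by 1 / keep_prob.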
if (not is_training) or (drop_prob == 0.):
return inputs
# Compute keep_prob
keep_prob = 1.0 - drop_prob
# Compute drop_connect tensor
random_tensor = keep_prob
    shape = (tf.shape(inputs)[0],) + (1,) * (len(inputs.shape) - 1)
random_tensor += tf.random.uniform(shape, dtype=inputs.dtype)
binary_tensor = tf.floor(random_tensor)
output = tf.math.divide(inputs, keep_prob) * binary_tensor
return output
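# Numeric sketch (illustrative): with drop_prob=0.1, keep_prob=0.9, each sample survives with
# probability 0.9 and is scaled by 1/0.9, so the expected output equals the input.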
class DropPath(tf.keras.layers.Layer):
    def __init__(self, drop_prob=0.):
super().__init__()
self.drop_prob = drop_prob
def call(self, x, training=None):
return drop_path(x, self.drop_prob, training)
class SwinTransformerBlock(tf.keras.layers.Layer):
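    """One Swin Transformer block: LayerNorm -> (shifted) window attention -> residual,
    then LayerNorm -> MLP -> residual, with optional stochastic depth on both branches.
    """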
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, mlp_ratio=4.,
qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path_prob=0., norm_layer=LayerNormalization, prefix=''):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
if min(self.input_resolution) <= self.window_size:
self.shift_size = 0
self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
self.prefix = prefix
self.norm1 = norm_layer(epsilon=1e-5, name=f'{self.prefix}/norm1')
self.attn = WindowAttention(dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, prefix=self.prefix)
self.drop_path = DropPath(
drop_path_prob if drop_path_prob > 0. else 0.)
self.norm2 = norm_layer(epsilon=1e-5, name=f'{self.prefix}/norm2')
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
drop=drop, prefix=self.prefix)
def build(self, input_shape):
if self.shift_size > 0:
            H, W = self.input_resolution