Text Classification With Switch Transformer
Introduction
This example demonstrates the implementation of the Switch Transformer model for text
classification.
The Switch Transformer replaces the feedforward network (FFN) layer in the standard Transformer
with a Mixture of Experts (MoE) routing layer, where each expert operates independently on the
tokens in the sequence. This allows increasing the model size without increasing the computation
needed to process each example.
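As a rough, purely illustrative sketch of the idea (not part of the original example), a top-1 MoE layer keeps several independent feedforward "experts" plus a small router that picks one expert per token. The naive layer below runs every expert densely for clarity, whereas the real Switch layer dispatches each token so that only its selected expert's FFN is evaluated; all names and sizes here are hypothetical.

# Illustrative sketch only: a naive top-1 MoE feedforward layer.
class NaiveTopOneMoE(layers.Layer):
    def __init__(self, num_experts, embed_dim, ff_dim):
        super().__init__()
        self.router = layers.Dense(num_experts)  # per-token expert logits
        self.experts = [
            keras.Sequential(
                [layers.Dense(ff_dim, activation="relu"), layers.Dense(embed_dim)]
            )
            for _ in range(num_experts)
        ]

    def call(self, x):
        # x shape: [batch, seq_len, embed_dim]
        probs = keras.activations.softmax(self.router(x), axis=-1)
        # Run every expert here for simplicity; a real Switch layer routes each
        # token to a single expert so only one FFN runs per token.
        expert_outputs = ops.stack([expert(x) for expert in self.experts], axis=-2)
        gate, index = ops.top_k(probs, k=1)  # top-1 expert per token
        one_hot = ops.one_hot(ops.squeeze(index, -1), len(self.experts))
        # Keep only the selected expert's output, scaled by its router probability.
        return gate * ops.sum(expert_outputs * ops.expand_dims(one_hot, -1), axis=-2)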
Note that, for training the Switch Transformer efficiently, data and model parallelism need to be
applied, so that expert modules can run simultaneously, each on its own accelerator. While the
implementation described in the paper uses the TensorFlow Mesh framework for distributed
training, this example presents a simple, non-distributed implementation of the Switch Transformer
model for demonstration purposes.
Setup
import keras
from keras import ops
from keras import layers
Implement the token and position embedding layer
It consists of two separate embedding layers, one for tokens and one for token index (positions).
class TokenAndPositionEmbedding(layers.Layer):
    def __init__(self, maxlen, vocab_size, embed_dim):
        super().__init__()
        self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
        self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)

    def call(self, x):
        # Add the position embedding to each token embedding.
        positions = ops.arange(start=0, stop=ops.shape(x)[-1], step=1)
        return self.token_emb(x) + self.pos_emb(positions)
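For example (hyperparameter values here are arbitrary, not the ones used later in the example), the layer maps a batch of token IDs to summed token-plus-position embeddings:

# Quick shape check with arbitrary, illustrative hyperparameters.
embedding_layer = TokenAndPositionEmbedding(maxlen=200, vocab_size=20000, embed_dim=32)
dummy_tokens = ops.ones((4, 200), dtype="int32")  # a batch of 4 sequences of 200 token IDs
print(embedding_layer(dummy_tokens).shape)  # (4, 200, 32)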
The routing logic below forms the body of the router's call method: each token is assigned to its
top-1 expert, a load-balancing auxiliary loss is added, and a fixed per-expert capacity is enforced so
that overflowing tokens are dropped. The router logits are assumed to come from a dense routing
layer (self.route) applied to the token representations.

def call(self, inputs, training=False):
    # inputs shape: [tokens_per_batch, embed_dim]
    # router_logits shape: [tokens_per_batch, num_experts]
    router_logits = self.route(inputs)
    if training:
        # Add noise for exploration across experts.
        router_logits += keras.random.uniform(
            shape=router_logits.shape, minval=0.9, maxval=1.1
        )
    # Probabilities for each token of what expert it should be sent to.
    router_probs = keras.activations.softmax(router_logits, axis=-1)
    # Get the top-1 expert for each token. expert_gate is the top-1 probability
    # from the router for each token. expert_index is what expert each token
    # is going to be routed to.
    expert_gate, expert_index = ops.top_k(router_probs, k=1)
    # expert_mask shape: [tokens_per_batch, num_experts]
    expert_mask = ops.one_hot(expert_index, self.num_experts)
    # Compute the load balancing loss.
    aux_loss = load_balanced_loss(router_probs, expert_mask)
    self.add_loss(aux_loss)
    # Experts have a fixed capacity; ensure we do not exceed it. Construct
    # the batch indices per expert, with each token's position in its expert,
    # so that no more than expert_capacity tokens are routed to any expert.
    position_in_expert = ops.cast(
        ops.cumsum(expert_mask, axis=0) * expert_mask, "int32"
    )
    # Keep only tokens that fit within expert capacity.
    expert_mask *= ops.cast(
        ops.less(ops.cast(position_in_expert, "int32"), self.expert_capacity),
        "float32",
    )
    expert_mask_flat = ops.sum(expert_mask, axis=-1)
    # Mask out the experts that have overflowed the expert capacity.
    expert_gate *= expert_mask_flat
    # Combine expert outputs and scaling with router probability.
    # combined_tensor shape: [tokens_per_batch, num_experts, expert_capacity]
    combined_tensor = ops.expand_dims(
        expert_gate
        * expert_mask_flat
        * ops.squeeze(ops.one_hot(expert_index, self.num_experts), 1),
        -1,
    ) * ops.squeeze(ops.one_hot(position_in_expert, self.expert_capacity), 1)
    # Create binary dispatch_tensor [tokens_per_batch, num_experts, expert_capacity]
    # that is 1 if the token gets routed to the corresponding expert.
    dispatch_tensor = ops.cast(combined_tensor, "float32")

    return dispatch_tensor, combined_tensor
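The dispatch and combined tensors returned by the router are what a Switch (MoE) layer uses to
send each token to its assigned expert and then merge the expert outputs back, weighted by the
router probability. The load_balanced_loss function called above, which encourages uniform routing
across experts, is defined in the full example and omitted from this extract. Below is a minimal,
hedged sketch of the dispatch-and-combine step, assuming inputs of shape
[tokens_per_batch, embed_dim] and a list of expert feedforward networks; it is not the example's
exact Switch layer.

# Illustrative sketch: consuming the router's outputs.
# `experts` is assumed to be a list of feedforward networks mapping
# [expert_capacity, embed_dim] -> [expert_capacity, embed_dim].
def dispatch_and_combine(inputs, dispatch_tensor, combined_tensor, experts):
    # Gather the tokens assigned to each expert slot.
    # expert_inputs shape: [num_experts, expert_capacity, embed_dim]
    expert_inputs = ops.einsum("ab,acd->cdb", inputs, dispatch_tensor)
    # Run each expert on its own slice of tokens.
    # expert_outputs shape: [num_experts, expert_capacity, embed_dim]
    expert_outputs = ops.stack(
        [expert(expert_inputs[i]) for i, expert in enumerate(experts)], axis=0
    )
    # Scatter expert outputs back to token positions, scaled by the router gate.
    # Output shape: [tokens_per_batch, embed_dim]
    return ops.einsum("abc,dab->dc", expert_outputs, combined_tensor)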
Train and evaluate the model

classifier = create_classifier()
run_experiment(classifier)
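The helpers create_classifier and run_experiment are defined earlier in the full example and are not
reproduced in this extract. As a rough, hypothetical sketch, the experiment runner essentially
compiles the classifier and fits it on the tokenized IMDB data for a few epochs, producing the log
below; the names x_train/y_train/x_val/y_val and the settings shown are placeholders, not
necessarily the example's exact configuration.

# Hypothetical sketch of an experiment runner; x_train/y_train/x_val/y_val are
# assumed to be pre-tokenized IMDB arrays prepared earlier in the full example.
def run_experiment(classifier):
    classifier.compile(
        optimizer="adam",
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    return classifier.fit(
        x_train,
        y_train,
        batch_size=50,  # placeholder value
        epochs=3,
        validation_data=(x_val, y_val),
    )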
Epoch 1/3
500/500 ━━━━━━━━━━━━━━━━━━━━ 251s 485ms/step - accuracy: 0.7121 - loss: 1.5394 -
val_accuracy: 0.8748 - val_loss: 1.2891
Epoch 2/3
500/500 ━━━━━━━━━━━━━━━━━━━━ 240s 480ms/step - accuracy: 0.9243 - loss: 1.2063 -
val_accuracy: 0.8752 - val_loss: 1.3090
Epoch 3/3
500/500 ━━━━━━━━━━━━━━━━━━━━ 242s 485ms/step - accuracy: 0.9572 - loss: 1.1222 -
val_accuracy: 0.8614 - val_loss: 1.3744
<keras.src.callbacks.history.History at 0x7efb79d82a90>
Conclusion
Compared to the standard Transformer architecture, the Switch Transformer can have a much
larger number of parameters, leading to increased model capacity, while maintaining a reasonable
computational cost.