Understanding and Coding The Self-Attention Mechanism of Large Language Models From Scratch
In this article, we are going to understand how self-attention works from scratch.
This means we will code it ourselves one step at a time.
Since its introduction via the original transformer paper (Attention Is All You Need),
self-attention has become a cornerstone of many state-of-the-art deep learning
models, particularly in the field of Natural Language Processing (NLP). Since self-
attention is now everywhere, it’s important to understand how it works.
Self-Attention
The concept of “attention” in deep learning has its roots in the effort to improve
Recurrent Neural Networks (RNNs) for handling longer sequences or sentences.
For instance, consider translating a sentence from one language to another.
Translating a sentence word-by-word does not work effectively.
To overcome this issue, attention mechanisms were introduced to give access to all
sequence elements at each time step. The key is to be selective and determine
which words are most important in a specific context. In 2017, the transformer
architecture introduced a standalone self-attention mechanism, eliminating the need
for RNNs altogether.
(For brevity, and to keep the article focused on the technical self-attention details,
I am skipping parts of the motivation, but my Machine Learning with PyTorch
and Scikit-Learn book has some additional details in Chapter 16 if you are
interested.)
Note that there are many variants of self-attention. A particular focus has been on
making self-attention more efficient. However, most papers still implement the
original scaled dot-product attention mechanism discussed in this article since it
usually results in superior accuracy and because self-attention is rarely a
computational bottleneck for most companies training large-scale transformers.
For simplicity, here our dictionary dc is restricted to the words that occur in the
input sentence. In a real-world application, we would consider all words in the
training dataset (typical vocabulary sizes range from 30k to 50k).
In:
import torch
Out:
tensor([0, 4, 5, 2, 1, 3])
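For reference, here is a minimal sketch of how dc and sentence_int can be constructed. It assumes the example sentence "Life is short, eat dessert first" (an assumption on our part, chosen because it reproduces the integer tensor above) and a dictionary built from the sorted, comma-stripped words:

```python
import torch

# Assumed example sentence; it reproduces tensor([0, 4, 5, 2, 1, 3])
sentence = 'Life is short, eat dessert first'

# Dictionary restricted to the words of this sentence, indexed in sorted order
words = sentence.replace(',', '').split()
dc = {s: i for i, s in enumerate(sorted(words))}
print(dc)  # {'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short': 5}

# Replace each word by its integer index
sentence_int = torch.tensor([dc[s] for s in words])
print(sentence_int)  # tensor([0, 4, 5, 2, 1, 3])
```

Because Python sorts uppercase letters before lowercase ones, 'Life' receives index 0.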
Now, using the integer-vector representation of the input sentence, we can use an
embedding layer to encode the inputs into a real-vector embedding. Here, we will
use a 16-dimensional embedding such that each input word is represented by a
16-dimensional vector. Since the sentence consists of 6 words, this will result in a
6 × 16-dimensional embedding:
In:
torch.manual_seed(123)
embed = torch.nn.Embedding(6, 16)
embedded_sentence = embed(sentence_int).detach()
print(embedded_sentence)
print(embedded_sentence.shape)
Out:
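The respective query, key and value sequences are obtained via matrix
multiplication between the weight matrices W and the embedded inputs x:

$$\mathbf{q}^{(i)} = \mathbf{W}_q \mathbf{x}^{(i)}, \quad \mathbf{k}^{(i)} = \mathbf{W}_k \mathbf{x}^{(i)}, \quad \mathbf{v}^{(i)} = \mathbf{W}_v \mathbf{x}^{(i)}, \quad i = 1, \dots, T$$

The index i refers to the token index position in the input sequence, which has
length T.
Here, both $\mathbf{q}^{(i)}$ and $\mathbf{k}^{(i)}$ are vectors of dimension $d_k$. The projection matrices $\mathbf{W}_q$
and $\mathbf{W}_k$ have a shape of $d_k \times d$, while $\mathbf{W}_v$ has the shape $d_v \times d$.
(It’s important to note that d represents the size of each word vector, $\mathbf{x}^{(i)}$.)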
Since we are computing the dot product between the query and key vectors, these
two vectors have to contain the same number of elements ($d_q = d_k$). However, the
number of elements in the value vector $\mathbf{v}^{(i)}$, which determines the size of the
resulting context vector, is arbitrary.
So, for the following code walkthrough, we will set $d_q = d_k = 24$ and use
$d_v = 28$, initializing the projection matrices as follows:
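In:
torch.manual_seed(123)
d = embedded_sentence.shape[1]   # embedding size, d = 16
d_q, d_k, d_v = 24, 24, 28       # query and key sizes must match; value size is free
W_query = torch.rand(d_q, d)
W_key = torch.rand(d_k, d)
W_value = torch.rand(d_v, d)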
In:
x_2 = embedded_sentence[1]
query_2 = W_query.matmul(x_2)
key_2 = W_key.matmul(x_2)
value_2 = W_value.matmul(x_2)
print(query_2.shape)
print(key_2.shape)
print(value_2.shape)
Out:
torch.Size([24])
torch.Size([24])
torch.Size([28])
We can then generalize this to compute the remaining keys and values for all
inputs as well, since we will need them in the next step when we compute the
unnormalized attention weights ω:
In:
keys = W_key.matmul(embedded_sentence.T).T
values = W_value.matmul(embedded_sentence.T).T
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)
Out:
keys.shape: torch.Size([6, 24])
values.shape: torch.Size([6, 28])
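As a side note, the transpose-project-transpose pattern above is equivalent to multiplying the embedded inputs (one token per row) by the transposed weight matrix. The sketch below checks this on random stand-in tensors with the same shapes as in our example (the random values are an assumption; only the shapes matter):

```python
import torch

torch.manual_seed(123)
d, d_k, d_v, T = 16, 24, 28, 6

# Random stand-ins with the shapes used in the walkthrough
embedded_sentence = torch.rand(T, d)
W_key = torch.rand(d_k, d)
W_value = torch.rand(d_v, d)

# Form used above: project the transposed inputs, then transpose back
keys_a = W_key.matmul(embedded_sentence.T).T
# Row-wise form: each row x^(i) is mapped to W_k x^(i)
keys_b = embedded_sentence @ W_key.T

print(torch.allclose(keys_a, keys_b))  # True
print(keys_b.shape)                    # torch.Size([6, 24])

values = embedded_sentence @ W_value.T
print(values.shape)                    # torch.Size([6, 28])
```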
Now that we have all the required keys and values, we can proceed to the next step
and compute the unnormalized attention weights ω , which are illustrated in the
figure below:
As illustrated in the figure above, we compute $\omega_{i,j}$ as the dot product between the
query and key sequences, $\omega_{i,j} = \mathbf{q}^{(i)\top} \mathbf{k}^{(j)}$.
For example, we can compute the unnormalized attention weight for the second query and
the 5th input element (corresponding to index position 4) as follows:
In:
omega_24 = query_2.dot(keys[4])
print(omega_24)
Out:
tensor(11.1466)
Since we will need those to compute the attention scores later, let’s compute the ω
values for all input tokens as illustrated in the previous figure:
In:
omega_2 = query_2.matmul(keys.T)
print(omega_2)
Out:
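The next step is to normalize the unnormalized attention weights ω to obtain the
normalized attention weights α by applying the softmax function over all input
positions j. Before the softmax, ω is scaled by $1/\sqrt{d_k}$:

$$\alpha_{2,j} = \frac{\exp\left(\omega_{2,j}/\sqrt{d_k}\right)}{\sum_{j'=1}^{T} \exp\left(\omega_{2,j'}/\sqrt{d_k}\right)}$$

The scaling by $1/\sqrt{d_k}$ keeps the magnitude of the dot products roughly independent of
the vector dimension $d_k$. This helps prevent the attention weights from
becoming too small or too large, which could lead to numerical instability or affect
the model’s ability to converge during training.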
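To see why the $1/\sqrt{d_k}$ factor helps, consider the common simplifying assumption (ours, for illustration) that the query and key components are independent with mean 0 and variance 1. Then the dot product of two $d_k$-dimensional vectors has variance $d_k$, while the scaled version has variance 1. The sketch below checks this empirically with random vectors:

```python
import torch

torch.manual_seed(0)
n = 5000  # number of random query/key pairs per setting

variances = {}
for d_k in (4, 64, 1024):
    # Assumption: components are i.i.d. with mean 0 and variance 1
    q = torch.randn(n, d_k)
    k = torch.randn(n, d_k)
    omega = (q * k).sum(dim=1)    # unscaled dot products q^T k
    scaled = omega / d_k ** 0.5   # scaled by 1/sqrt(d_k)
    variances[d_k] = (omega.var().item(), scaled.var().item())
    print(f"d_k={d_k:5d}  var(omega)={variances[d_k][0]:8.1f}  "
          f"var(omega/sqrt(d_k))={variances[d_k][1]:.2f}")
```

Without the scaling, a larger $d_k$ produces larger scores, which push the softmax toward a near one-hot distribution with very small gradients.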
In code, we can implement the computation of the attention weights as follows:
In:
import torch.nn.functional as F
attention_weights_2 = F.softmax(omega_2 / d_k**0.5, dim=0)
Finally, the last step is to compute the context vector $\mathbf{z}^{(2)}$, which is an
attention-weighted version of our original query input $\mathbf{x}^{(2)}$, including all the other input
elements as its context via the attention weights:

$$\mathbf{z}^{(2)} = \sum_{j=1}^{T} \alpha_{2,j} \mathbf{v}^{(j)}$$
In:
context_vector_2 = attention_weights_2.matmul(values)
print(context_vector_2.shape)
print(context_vector_2)
Out:
torch.Size([28])
tensor([-1.5993,  0.0156,  1.2670,  0.0032, -0.6460, -1.1407, -0.4908, -1.4632,
         0.4747,  1.1926,  0.4506, -0.7110,  0.0602,  0.7125, -0.1628, -2.0184,
         0.3838, -2.1188, -0.8136, -1.5694,  0.7934, -0.2911, -1.3640, -0.2366,
        -0.9564, -0.5265,  0.0624,  1.7084])
Note that this output vector has more dimensions ($d_v = 28$) than the original input
vector ($d = 16$) since we specified $d_v > d$ earlier; however, the embedding size
choice is arbitrary.
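So far, we computed the context vector for a single query, $\mathbf{x}^{(2)}$. A compact sketch of the same computation for all T queries at once, in the matrix form $\mathrm{softmax}(QK^\top/\sqrt{d_k})V$, might look as follows. The function name self_attention and the random stand-in inputs are our own; row index 1 of the result is compared against the single-query steps from above:

```python
import torch
import torch.nn.functional as F

def self_attention(x, W_query, W_key, W_value):
    """Scaled dot-product self-attention for all tokens at once (sketch).

    x: (T, d) embedded inputs; W_query, W_key: (d_k, d); W_value: (d_v, d).
    Returns the (T, d_v) context vectors and the (T, T) attention weights.
    """
    queries = x @ W_query.T          # (T, d_k)
    keys = x @ W_key.T               # (T, d_k)
    values = x @ W_value.T           # (T, d_v)
    d_k = keys.shape[1]
    omega = queries @ keys.T         # (T, T) unnormalized attention weights
    attn_weights = F.softmax(omega / d_k ** 0.5, dim=-1)  # each row sums to 1
    return attn_weights @ values, attn_weights

torch.manual_seed(123)
d, d_q, d_k, d_v = 16, 24, 24, 28
x = torch.rand(6, d)   # random stand-in for embedded_sentence
W_query, W_key, W_value = torch.rand(d_q, d), torch.rand(d_k, d), torch.rand(d_v, d)

Z, A = self_attention(x, W_query, W_key, W_value)
print(Z.shape)  # torch.Size([6, 28])

# Single-query computation for x^(2), following the steps of the walkthrough
query_2 = W_query.matmul(x[1])
omega_2 = query_2.matmul((x @ W_key.T).T)
z_2 = F.softmax(omega_2 / d_k ** 0.5, dim=0).matmul(x @ W_value.T)
print(torch.allclose(Z[1], z_2, atol=1e-4))
```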
Multi-Head Attention
In the very first figure, at the top of this article, we saw that transformers use a
module called multi-head attention. How does that relate to the self-attention
mechanism (scaled dot-product attention) we walked through above?
In the scaled dot-product attention, the input sequence was transformed using three
matrices representing the query, key, and value. These three matrices can be
considered as a single attention head in the context of multi-head attention. The
figure below summarizes this single attention head we covered previously:
As its name implies, multi-head attention involves multiple such heads, each
consisting of query, key, and value matrices. This concept is similar to the use of
multiple kernels in convolutional neural networks.
To illustrate this in code, suppose we have 3 attention heads, so we now extend the
$d' \times d$-dimensional weight matrices to $3 \times d' \times d$:
In:
h = 3
multihead_W_query = torch.rand(h, d_q, d)
multihead_W_key = torch.rand(h, d_k, d)
multihead_W_value = torch.rand(h, d_v, d)
In:
multihead_query_2 = multihead_W_query.matmul(x_2)
print(multihead_query_2.shape)
Out:
torch.Size([3, 24])
In:
multihead_key_2 = multihead_W_key.matmul(x_2)
multihead_value_2 = multihead_W_value.matmul(x_2)
Now, these key and value elements are specific to the query element. But, similar to
earlier, we will also need the keys and values for the other sequence elements in
order to compute the attention scores for the query. We can do this by expanding
the input sequence embeddings to size 3, i.e., the number of attention heads:
In:
stacked_inputs = embedded_sentence.T.repeat(3, 1, 1)
print(stacked_inputs.shape)
Out:
torch.Size([3, 16, 6])
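Now, we can compute all the keys and values via torch.bmm() (batch matrix
multiplication):
In:
multihead_keys = torch.bmm(multihead_W_key, stacked_inputs)
multihead_values = torch.bmm(multihead_W_value, stacked_inputs)
print("multihead_keys.shape:", multihead_keys.shape)
print("multihead_values.shape:", multihead_values.shape)
Out:
multihead_keys.shape: torch.Size([3, 24, 6])
multihead_values.shape: torch.Size([3, 28, 6])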
We now have tensors that represent the three attention heads in their first
dimension. The third and second dimensions refer to the number of words and the
embedding size, respectively. To make the values and keys more intuitive to
interpret, we will swap the second and third dimensions, resulting in tensors with
the same dimensional structure as the original input sequence,
embedded_sentence:
In:
multihead_keys = multihead_keys.permute(0, 2, 1)
multihead_values = multihead_values.permute(0, 2, 1)
print("multihead_keys.shape:", multihead_keys.shape)
print("multihead_values.shape:", multihead_values.shape)
Out:
multihead_keys.shape: torch.Size([3, 6, 24])
multihead_values.shape: torch.Size([3, 6, 28])
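Instead of copying the inputs with repeat() and calling torch.bmm(), we can also rely on the broadcasting behavior of torch.matmul (the @ operator), which shares a 2D operand across the batch dimension of a 3D operand. A minimal sketch with random stand-in tensors, checking that both approaches give the same keys in the (heads, tokens, d_k) layout:

```python
import torch

torch.manual_seed(123)
h, d, d_k, T = 3, 16, 24, 6
embedded_sentence = torch.rand(T, d)      # random stand-in for the embeddings
multihead_W_key = torch.rand(h, d_k, d)

# Approach from the walkthrough: one copy of the inputs per head, then bmm
stacked_inputs = embedded_sentence.T.repeat(h, 1, 1)                    # (h, d, T)
keys_bmm = torch.bmm(multihead_W_key, stacked_inputs).permute(0, 2, 1)  # (h, T, d_k)

# Broadcasting alternative: the (T, d) input is shared across the h heads
keys_bcast = embedded_sentence @ multihead_W_key.transpose(1, 2)        # (h, T, d_k)

print(keys_bcast.shape)                      # torch.Size([3, 6, 24])
print(torch.allclose(keys_bmm, keys_bcast))  # True
```

The broadcasting version avoids materializing the repeated input tensor; both produce the same result here.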
Then, we follow the same steps as previously: we compute the unnormalized attention
weights ω, apply the scaled softmax to obtain the attention weights α, and finally
obtain an $h \times d_v$ (here: $3 \times 28$)-dimensional context vector z for the input element
$\mathbf{x}^{(2)}$.
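A possible sketch of these remaining steps for all three heads at once, using random stand-in tensors with the shapes from our example and torch.bmm() for the per-head products (the names omega_2, attention_weights_2, and multihead_context_2 are our own choices here):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(123)
h, d, d_q, d_k, d_v, T = 3, 16, 24, 24, 28, 6
embedded_sentence = torch.rand(T, d)      # random stand-in for the embeddings
multihead_W_query = torch.rand(h, d_q, d)
multihead_W_key = torch.rand(h, d_k, d)
multihead_W_value = torch.rand(h, d_v, d)

x_2 = embedded_sentence[1]
multihead_query_2 = multihead_W_query.matmul(x_2)                         # (h, d_q)
multihead_keys = embedded_sentence @ multihead_W_key.transpose(1, 2)      # (h, T, d_k)
multihead_values = embedded_sentence @ multihead_W_value.transpose(1, 2)  # (h, T, d_v)

# Unnormalized attention weights: one row of T scores per head, (h, T)
omega_2 = torch.bmm(multihead_keys, multihead_query_2.unsqueeze(-1)).squeeze(-1)
# Scaled softmax over the T input tokens, separately for each head
attention_weights_2 = F.softmax(omega_2 / d_k ** 0.5, dim=-1)
# Attention-weighted sum of each head's values, (h, d_v)
multihead_context_2 = torch.bmm(attention_weights_2.unsqueeze(1),
                                multihead_values).squeeze(1)
print(multihead_context_2.shape)  # torch.Size([3, 28])
```

In the original transformer architecture, the h context vectors are additionally concatenated and passed through an output projection matrix; that step is not shown here.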
Cross-Attention
In the code walkthrough above, we set $d_q = d_k = 24$ and $d_v = 28$. In other
words, we used the same dimensions for query and key sequences. While the
value matrix $\mathbf{W}_v$ is often chosen to have the same dimension as the query and key
matrices (such as in PyTorch’s MultiheadAttention class), we can select an arbitrary
size for the value dimensions.
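For comparison, here is a short sketch using PyTorch’s torch.nn.MultiheadAttention. In this class, the query, key, and value projections all map to embed_dim, which is why in_proj_weight stacks three embed_dim × embed_dim matrices; its kdim and vdim arguments only set the feature sizes of the key and value inputs. The sizes below (embed_dim=16, num_heads=2) are arbitrary choices for illustration:

```python
import torch

torch.manual_seed(123)
embed_dim, num_heads = 16, 2   # embed_dim must be divisible by num_heads
mha = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

# Query, key, and value projections stacked into one (3 * embed_dim, embed_dim) matrix
print(mha.in_proj_weight.shape)   # torch.Size([48, 16])

x = torch.rand(1, 6, embed_dim)   # (batch, tokens, embed_dim)
attn_output, attn_weights = mha(x, x, x)   # self-attention: query = key = value

print(attn_output.shape)    # torch.Size([1, 6, 16])
print(attn_weights.shape)   # torch.Size([1, 6, 6]), averaged over heads by default
```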
Since the dimensions are sometimes a bit tricky to keep track of, let’s summarize
everything we have covered so far in the figure below, which depicts the various
tensor sizes for a single attention head.
Now, the illustration above corresponds to the self-attention mechanism used in
transformers. One particular flavor of this attention mechanism we have yet to
discuss is cross-attention.
Note that in cross-attention, the two input sequences $\mathbf{x}_1$ and $\mathbf{x}_2$ can have different
numbers of elements. However, their embedding dimensions must match.
How does that work in code? Previously, when we implemented the self-attention
mechanism at the beginning of this article, we used the following code to compute
the query of the second input element along with all the keys and values as follows:
In:
torch.manual_seed(123)
d = embedded_sentence.shape[1]
print("embedded_sentence.shape:", embedded_sentence.shape)
d_q, d_k, d_v = 24, 24, 28
W_query = torch.rand(d_q, d)
W_key = torch.rand(d_k, d)
W_value = torch.rand(d_v, d)
x_2 = embedded_sentence[1]
query_2 = W_query.matmul(x_2)
print("query.shape", query_2.shape)
keys = W_key.matmul(embedded_sentence.T).T
values = W_value.matmul(embedded_sentence.T).T
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)
Out:
embedded_sentence.shape: torch.Size([6, 16])
query.shape torch.Size([24])
keys.shape: torch.Size([6, 24])
values.shape: torch.Size([6, 28])
The only part that changes in cross-attention is that we now have a second input
sequence, for example, a second sentence with 8 instead of 6 input elements.
Here, a random 8 × 16 tensor stands in for the embedded tokens of this second sentence:
In:
embedded_sentence_2 = torch.rand(8, 16)  # 2nd input sequence (random stand-in)
keys = W_key.matmul(embedded_sentence_2.T).T
values = W_value.matmul(embedded_sentence_2.T).T
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)
Out:
keys.shape: torch.Size([8, 24])
values.shape: torch.Size([8, 28])
Notice that compared to self-attention, the keys and values now have 8 instead of 6
rows. Everything else stays the same.
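Putting the pieces together, a minimal cross-attention sketch: the queries come from the first sequence (6 tokens) and the keys and values from the second (8 tokens). Both sequences are random stand-ins with the shared embedding size d = 16, and the weight matrices are initialized as in the walkthrough:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(123)
d, d_q, d_k, d_v = 16, 24, 24, 28
x_1 = torch.rand(6, d)   # first sequence: provides the queries (6 tokens)
x_2 = torch.rand(8, d)   # second sequence: provides keys and values (8 tokens)

W_query = torch.rand(d_q, d)
W_key = torch.rand(d_k, d)
W_value = torch.rand(d_v, d)

queries = x_1 @ W_query.T   # (6, d_q)
keys = x_2 @ W_key.T        # (8, d_k)
values = x_2 @ W_value.T    # (8, d_v)

omega = queries @ keys.T                              # (6, 8): one score per query/key pair
attn_weights = F.softmax(omega / d_k ** 0.5, dim=-1)  # each row sums to 1 over the 8 keys
context = attn_weights @ values                       # (6, d_v)

print(attn_weights.shape)   # torch.Size([6, 8])
print(context.shape)        # torch.Size([6, 28])
```

The attention-weight matrix has one row per query token and one column per key token (6 × 8), and the output contains one context vector per token of the first sequence.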
Conclusion
In this article, we saw how self-attention works using a step-by-step coding
approach. We then extended this concept to multi-head attention, the widely used
component of transformer-based large language models. After discussing self-attention and
multi-head attention, we introduced yet another concept: cross-attention, which is a
flavor of self-attention that we can apply between two different sequences. This is
already a lot of information to take in. Let’s leave the training of a neural network
using this multi-head attention block to a future article.
If you liked this article, you can also find me on Twitter and LinkedIn where I share
more content related to machine learning and AI.
If you are looking for a way to support me and my work, consider purchasing one of
my books or subscribing to the paid version of my free machine learning newsletter.
If you find it valuable, please spread the word and recommend it to others.
© 2013-2023 Sebastian Raschka