13.2. Attention Mechanisms
François Fleuret
https://fleuret.org/dlc/
The most classical version of attention is a context attention with a dot product as the attention function, as used by Vaswani et al. (2017) for their transformer models. We will come back to them.
\[
\forall i, \quad A_i = \operatorname{softmax}\left(\frac{K Q_i}{\sqrt{D}}\right), \qquad Y_i = V^\top A_i,
\]
or, in matrix form,
\[
A = \operatorname{softmax}_{\mathrm{row}}\left(\frac{Q K^\top}{\sqrt{D}}\right), \qquad Y = A V.
\]
The queries and the keys have the same dimension D, and there are as many keys T′ as there are values. The result Y has as many rows T as there are queries, and they are of the same dimension D′ as the values.
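As a concrete illustration, here is a minimal sketch of the matrix form with arbitrary dimensions, not taken from the original code:

import math
import torch

T, Tp, D, Dp = 5, 7, 16, 32           # nb queries, nb keys/values, key/query dim, value dim

Q = torch.randn(T, D)                 # queries,  T  x D
K = torch.randn(Tp, D)                # keys,     T' x D
V = torch.randn(Tp, Dp)               # values,   T' x D'

A = torch.softmax(Q @ K.t() / math.sqrt(D), dim = 1)   # attention, T x T', each row sums to one
Y = A @ V                                               # output,    T x D'

print(A.shape, Y.shape)               # torch.Size([5, 7]) torch.Size([5, 32])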
[Figure: standard attention. The queries Q are compared to the keys K to produce the attention matrix A, which weights the values V to produce the output Y.]
In practice, the queries, keys, and values are obtained from input tensors X and X′ through learned linear transformations:
\[
Q = X W_Q^\top, \qquad K = X' W_K^\top, \qquad V = X' W_V^\top,
\]
\[
A = \operatorname{softmax}_{\mathrm{row}}\left(\frac{Q K^\top}{\sqrt{D}}\right), \qquad Y = A V.
\]
[Figure: computation graph of such an attention layer, with Q computed from X, and K and V computed from X′.]
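As a minimal sketch (dimensions and variable names are illustrative), these projections can be implemented with bias-free linear layers whose weight matrices play the roles of W_Q, W_K, and W_V:

import torch
from torch import nn

C, Cp, D, Dp = 8, 6, 16, 32                       # dims of X, of X', of the keys/queries, of the values

w_q = nn.Linear(C, D, bias = False)               # Q = X  W_Q^T
w_k = nn.Linear(Cp, D, bias = False)              # K = X' W_K^T
w_v = nn.Linear(Cp, Dp, bias = False)             # V = X' W_V^T

X, Xp = torch.randn(5, C), torch.randn(7, Cp)     # T = 5 queries, T' = 7 keys/values

Q, K, V = w_q(X), w_k(Xp), w_v(Xp)

print(Q.shape, K.shape, V.shape)                  # [5, 16], [7, 16], [7, 32]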
To see what such an attention operation brings, consider a toy sequence-to-sequence regression problem in which pairs of shapes appearing in the input signal have to be averaged to produce the target.

[Figure: input/target pairs from the toy problem.]

A first baseline is a convolutional model:
Sequential(
(0): Conv1d(1, 64, kernel_size=(5,), stride=(1,), padding=(2,))
(1): ReLU()
(2): Conv1d(64, 64, kernel_size=(5,), stride=(1,), padding=(2,))
(3): ReLU()
(4): Conv1d(64, 64, kernel_size=(5,), stride=(1,), padding=(2,))
(5): ReLU()
(6): Conv1d(64, 64, kernel_size=(5,), stride=(1,), padding=(2,))
(7): ReLU()
(8): Conv1d(64, 1, kernel_size=(5,), stride=(1,), padding=(2,))
)
nb_parameters 62337
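As a sanity check, this parameter count adds up: the first convolution has 1 · 64 · 5 + 64 = 384 parameters, each of the three 64-to-64 convolutions has 64 · 64 · 5 + 64 = 20,544, and the last one has 64 · 1 · 5 + 1 = 321, hence 384 + 3 · 20,544 + 321 = 62,337 in total.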
batch_size = 100

for e in range(args.nb_epochs):
    # forward pass and MSE loss (variable names assumed; minibatching elided)
    output = model(train_input)
    loss = (output - train_targets).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
[Figure: training MSE of the convolutional model as a function of the number of epochs (log scale).]
[Figure: inputs and outputs of the trained convolutional model on test examples.]
However, it is more natural to equip the model with the ability to combine information from parts of the signal that it actively identifies as relevant.
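A minimal sketch of such an attention layer for these 1d signals, using the channel-first N × C × T layout and 1 × 1 convolutions for the query/key/value projections (the module and its dimensions are illustrative, not the exact model used for the experiments below):

import math
import torch
from torch import nn

class AttentionLayer(nn.Module):
    def __init__(self, in_channels, out_channels, key_channels):
        super().__init__()
        # 1x1 convolutions implement per-position linear projections
        self.conv_Q = nn.Conv1d(in_channels, key_channels, kernel_size = 1)
        self.conv_K = nn.Conv1d(in_channels, key_channels, kernel_size = 1)
        self.conv_V = nn.Conv1d(in_channels, out_channels, kernel_size = 1)

    def forward(self, x):
        Q = self.conv_Q(x)                               # N x D  x T
        K = self.conv_K(x)                               # N x D  x T
        V = self.conv_V(x)                               # N x D' x T
        A = torch.softmax(
            Q.transpose(1, 2) @ K / math.sqrt(K.size(1)), dim = 2
        )                                                # N x T x T
        return (A @ V.transpose(1, 2)).transpose(1, 2)   # N x D' x T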
The computation of the attention matrix A and of the layer's output Y can also be expressed somewhat more clearly with Einstein summations (see lecture 1.5 "High dimension tensors") as
A = torch.einsum('nct,ncs->nts', Q, K).softmax(2)
y = torch.einsum('nts,ncs->nct', A, V)
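As a quick standalone sanity check (arbitrary sizes, not part of the original code), this einsum formulation gives the same result as explicit matrix products:

import torch

N, C, Cp, T, Tp = 2, 8, 16, 5, 7                      # batch, key channels, value channels, nb queries, nb keys/values
Q, K, V = torch.randn(N, C, T), torch.randn(N, C, Tp), torch.randn(N, Cp, Tp)

A = torch.einsum('nct,ncs->nts', Q, K).softmax(2)     # N x T x T'
y = torch.einsum('nts,ncs->nct', A, V)                # N x C' x T

A_ref = torch.softmax(Q.transpose(1, 2) @ K, dim = 2)
y_ref = (A_ref @ V.transpose(1, 2)).transpose(1, 2)

print(torch.allclose(A, A_ref, atol = 1e-6), torch.allclose(y, y_ref, atol = 1e-6))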
nb_parameters 54081
[Figure: training MSE of the model with attention as a function of the number of epochs (log scale).]
[Figure: inputs and outputs of the model with attention on test examples.]
[Figure: input, output, and attention matrix for three test examples.]
Our toy problem does not require taking the positions in the tensor into account. We can modify it with a target where the pairs to average are the two rightmost and the two leftmost shapes.
[Figure: input/target pairs for this modified, position-dependent task.]
[Figure: training MSE on the modified task as a function of the number of epochs (log scale).]
[Figure: inputs and outputs of the model with attention on the modified task.]

The attention operation is invariant to a permutation of the positions of the keys and values, so the model has no simple way of identifying the leftmost and rightmost shapes. We can provide this information explicitly by concatenating a positional encoding to the input:
>>> len = 20
>>> c = math.ceil(math.log(len) / math.log(2.0))
>>> o = 2**torch.arange(c).unsqueeze(1)
>>> pe = (torch.arange(len).unsqueeze(0).div(o, rounding_mode = 'floor')) % 2
>>> pe
tensor([[0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
[0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1],
[0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1]])
>>> pe = pe[None].float()
>>> input = torch.cat((input, pe.expand(input.size(0), -1, -1)), 1)
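Since the input now has 1 + c channels instead of a single one, the first layer of the model has to accept that many channels. A one-line sketch, assuming a convolutional first layer as in the models above:

from torch import nn

# c positional channels were concatenated to the single signal channel
first_conv = nn.Conv1d(1 + c, 64, kernel_size = 5, padding = 2)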
[Figure: training MSE on the modified task, with the positional encoding, as a function of the number of epochs (log scale).]
[Figure: inputs and outputs of the model with attention and positional encoding on the modified task.]
[Figure: input, output, and attention matrix for three test examples of the modified task, with the positional encoding.]