
General-purpose, long-context autoregressive modeling with Perceiver AR

Curtis Hawthorne * 1 Andrew Jaegle * 2 Cătălina Cangea 2 Sebastian Borgeaud 2 Charlie Nash 2
Mateusz Malinowski 2 Sander Dieleman 2 Oriol Vinyals 2 Matthew Botvinick 2 Ian Simon 1 Hannah Sheahan 2
Neil Zeghidour 1 Jean-Baptiste Alayrac † 2 João Carreira † 2 Jesse Engel † 1

*Equal contribution. †Equal contribution. 1 Google Research, Brain Team. 2 DeepMind. Correspondence to: Curtis Hawthorne <[email protected]>, Andrew Jaegle <[email protected]>.

Proceedings of the 39th International Conference on Machine Learning, Baltimore, Maryland, USA, PMLR 162, 2022. Copyright 2022 by the author(s). arXiv:2202.07765v2 [cs.LG], 14 Jun 2022.

Abstract

Real-world data is high-dimensional: a book, image, or musical performance can easily contain hundreds of thousands of elements even after compression. However, the most commonly used autoregressive models, Transformers, are prohibitively expensive to scale to the number of inputs and layers needed to capture this long-range structure. We develop Perceiver AR, an autoregressive, modality-agnostic architecture which uses cross-attention to map long-range inputs to a small number of latents while also maintaining end-to-end causal masking. Perceiver AR can directly attend to over a hundred thousand tokens, enabling practical long-context density estimation without the need for hand-crafted sparsity patterns or memory mechanisms. When trained on images or music, Perceiver AR generates outputs with clear long-term coherence and structure. Our architecture also obtains state-of-the-art likelihood on long-sequence benchmarks, including 64 × 64 ImageNet images and PG-19 books.

1. Introduction

A central goal of artificial intelligence research is the development of systems that can identify structure in the world and use it to effectively perform tasks of interest. In the past few years, autoregressive (AR) modeling with attention architectures (sometimes referred to metonymically as "language modeling") has emerged as a viable path to achieving this goal. In AR modeling, a set of outputs is generated by (i) using a model to map a set of inputs to an output, (ii) appending that output to the set of inputs, and proceeding again from step (i) until the full set of outputs has been produced. This simple recipe can in principle be followed to express any input-output relationship, and breakthrough results have been achieved using Transformers (Vaswani et al., 2017) or related models to learn the input → output mapping (Vinyals et al., 2019; Brown et al., 2020; Jumper et al., 2021; Wu et al., 2021). For this to work, the model must be able to capture patterns in the input that are useful for predicting the next output.

Patterns in real-world data often depend on the details of many inputs (or "tokens," each of which represents a row of an input array), some of which are far away in space, time, or setting from the current output. Many pieces of music, for example, begin by stating a theme with clear melodic and rhythmic elements. Over the piece, these elements are gradually varied with increasingly elaborate restatements designed to draw listeners in. How does the precise timing of a phrase relate to its antecedents? How does a tonal motif persist and develop as it is restated? How does a new harmonization recontextualize a familiar melody? The structure of the piece emerges when each component is considered alongside many others.

There is a tension between this kind of long-form, contextual structure and the computational properties of Transformers. Transformers repeatedly apply a self-attention operation to their inputs: this leads to computational requirements that simultaneously grow quadratically with input length and linearly with model depth. As the input data grows longer, more input tokens are needed to observe it, and as the patterns in the input data grow more subtle and complicated, more depth is needed to model the patterns that result. Computational constraints force users of Transformers to either truncate the inputs to the model (preventing it from observing many kinds of long-range patterns) or restrict the depth of the model (denuding it of the expressive power needed to model complex patterns).

The goal of this work is to design an architecture that retains the well-known benefits of Transformers for autoregressive modeling while enabling long-range pattern recognition without adding extraneous complexity. To do this, we build on the Perceiver family of attention architectures (Jaegle et al., 2021; 2022), which have demonstrated excellent
performance on a wide range of large-context domains. Perceivers use cross-attention to map a full input array into a smaller latent array and perform all subsequent attention operations in the resulting latent space.

This decouples the computational requirements of processing a large input array from those required to make a network very deep, allowing deep Perceivers to be used on large numbers of inputs. However, because each model latent attends to all inputs regardless of position, Perceivers cannot be used directly for autoregressive generation, which requires that each model output depend only on inputs that precede it in sequence.

Figure 1. Perceiver AR maps inputs (X ∈ R^{M×C}; M = 11 shown) to a small latent (Z ∈ R^{N×C}; N = 3 shown) by cross-attention, querying with the N most recent inputs to produce one latent for each target. Latents subsequently interact via a deep stack of L self-attention layers to produce estimates for each target. Causal masking is used in both cross- and self-attention to maintain end-to-end autoregressive ordering.

Perceiver AR solves this problem with three fixes: (i) introducing an ordering to the latents by identifying each latent with a single output element, (ii) using causally masked cross-attention to allow each latent to attend only to input elements that precede it in sequence, and (iii) using causally masked self-attention throughout the latent processing stack to preserve this autoregressive dependency structure end-to-end. These changes allow Perceiver AR's outputs to be decoded in autoregressive sequence while preserving the essential computational and memory benefits of other Perceiver architectures. As each of the architecture's outputs is conditioned on all prior inputs, the architecture is well-positioned to capture long-range dependencies.

We show that this architecture produces excellent results on several real-world domains with long-range context: RGB-level images (Section 5.2), tokenized language (Sections 5.3 to 5.5), and audio or symbolic music (Section 5.6). Input lengths for these tasks are summarized in Table 1. We demonstrate that Perceiver AR can learn to perfectly recognize long-context patterns over distances of at least 100k tokens on a synthetic copy task with known ground-truth structure (Section 5.1.1). We highlight several intriguing properties that result from decoupling input size from compute requirements: by keeping the long context, but changing the number of latents, we can (i) increase or decrease computational load at test time for improved (but slower) or faster (but somewhat worse) results (Section 5.2.1) or (ii) trade off model capacity against batch size at train time with no effect on test-time performance (Section 5.1.2).

Table 1. Maximum number of input (key-value) positions used for tasks in this paper. Most models used 1024 query positions.

    Task                           Input Positions
    Copy (Section 5.1)             131,072
    ImageNet (Section 5.2)         12,289
    PG-19 (Section 5.3)            4,096
    Books (Section 5.4)            16,384
    Wikitext-103 (Section 5.5)     8,192
    Symbolic Music (Section 5.6)   65,536
    Music Audio (Section 5.6)      65,536

We make the following contributions:

• We introduce Perceiver AR, an efficient, domain-agnostic architecture for autoregressive generation that can directly attend to over a hundred thousand tokens.

• We demonstrate the utility of using long contexts for autoregressive generation: Perceiver AR obtains state-of-the-art results on ImageNet and Project Gutenberg density estimation and produces samples with a high degree of coherence and fidelity on several challenging generation tasks (images, symbolic music, and audio).

• We explore the benefits of decoupling the computational requirements of handling long inputs from those of model depth: improved efficiency compared to the widely used decoder-only Transformer and Transformer-XL architectures, and the ability to vary the compute used at test time to match a target budget.

Model code is available at https://github.com/google-research/perceiver-ar.
2. Autoregression and long-context modeling

Autoregressive models (e.g. Schmidhuber & Heil 1994; Rosenfeld 2000; Bengio et al. 2003; Graves 2013; van den Oord et al. 2016b; Uria et al. 2016) estimate the density of an example X ∈ R^M by decomposing it sequentially using the chain rule of probability:

    p(X) = ∏_{m=0}^{M−1} p(X_m | X_{<m}).    (1)

Each X_m is typically a token (an audio sample, an RGB channel, a character, etc.). The chain rule tells us that the density of an input X composed of arbitrarily many tokens can be estimated by sequentially estimating the conditional density of each of the X's tokens. This requires that p(X_m | X_{<m}) be conditioned on all prior tokens that are useful for predicting the mth token.
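To make Eq. (1) concrete, the short Python/NumPy sketch below (ours, not from the paper) scores a sequence by summing per-token conditional log-probabilities; log_prob_next_token is a hypothetical stand-in for any autoregressive predictor, including Perceiver AR:

import numpy as np

def sequence_log_prob(tokens, log_prob_next_token):
    # Chain-rule decomposition of Eq. (1): log p(X) = sum_m log p(X_m | X_<m).
    # log_prob_next_token(prefix) returns a vector of log-probabilities over the
    # vocabulary, conditioned on the prefix of all prior tokens.
    total = 0.0
    for m, token in enumerate(tokens):
        log_probs = log_prob_next_token(tokens[:m])  # condition on all prior tokens
        total += log_probs[token]
    return total

# Toy usage with a uniform "model" over a 256-symbol vocabulary.
uniform = lambda prefix: np.full(256, -np.log(256.0))
print(sequence_log_prob([3, 141, 59, 26], uniform))  # 4 * -log(256) ≈ -22.18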
Conditioning on all relevant inputs is challenging due to the difficulty of scaling standard models beyond small window lengths. In practice, models must be carefully designed so that the tokens that can be included in the context are as relevant as possible. This usually means that spatially or temporally local tokens are included in the context, but longer-range context is ignored or subsampled. As the context becomes smaller, this leads to worse approximations for signals that contain long-term dependencies.

Perceiver AR is designed to incorporate longer-term context into powerful density estimators, thereby allowing models to contextualize on more of a signal and giving better flexibility on how data is processed, moving us closer to the goal of general-purpose density modeling.

3. Perceiver AR

Perceiver AR follows Perceiver and Perceiver IO in addressing the problem of large inputs using a Transformer-style cross-attention module to map inputs X ∈ R^{M×C} (C is the number of channels) to a smaller number of latents Z_1 ∈ R^{N×C}:

    Z_1 ← CrossAttend(X, Z_0).    (2)

The latent array Z_1 is typically small (N < M), so it is amenable to processing by more self-attention modules:

    Z_{l+1} ← SelfAttend(Z_l, Z_l).    (3)

This operation does not depend on the number of input points M, and N can be chosen so that this operation is affordable and repeated for many layers l ∈ [1, L]. This strategy leads to an architecture (Figure 1) with complexity O(MN) + O(LN^2) due to the cross-attention and the latent self-attention stack, respectively (Jaegle et al., 2022). For deep networks, the self-attention stack is where the bulk of compute occurs.

But reducing the number of points from M to N prevents us from establishing the causal dependency between all input and output points used by Transformers for autoregressive modeling. Perceiver AR adapts the latents for autoregressive modeling by introducing causal masking to both the input cross-attention and the latent self-attention layers (Figure 1) and assigning one latent to each of the final N points of the input (i.e. those with the largest number of antecedents; N = 3 in Figure 1). The influence of inputs that come after a given latent is masked at the cross-attention and all self-attention layers.

To see how this works, consider the second (middle) latent in Figure 1: this latent is constructed by querying the input with A's embedding. Because R follows A in sequence, R's embedding is masked out, both at the cross-attention and at the subsequent self-attention layers.

This procedure allows models to attend to long contexts while also preserving the autoregressive ordering of the targets through the entire network. The same procedure can be applied to any input that can be ordered, as long as masking is applied. For example, an image's RGB channels can be ordered in raster scan order, by decoding the R, G, and B color channels for each pixel in the sequence (Section 5.2) or even under different permutations (Section 5.2.2).

See Appendix C for an in-depth mathematical description of Perceivers and the Perceiver AR architecture and Appendix E for additional technical details.

4. Related work

4.1. Relationship to Transformer and Transformer-XL

Perceiver AR decouples the length of the input from the computational requirements of the model, which are primarily controlled by the number of latents. In contrast, standard decoder-only Transformer architectures (Figure 2) maintain a one-to-one correspondence between inputs and outputs throughout the network, leading to O(LM^2) scaling.

Figure 2. Perceiver AR compared to a standard decoder-only Transformer (Liu et al., 2018) and Transformer-XL (train-time configuration) (Dai et al., 2019) during training. Only the subset of configurations that fit in device memory are shown. For a given self-attention stack width (N = 3 shown here), the Transformer can process only the same number of input tokens. Transformer-XL incorporates more context while maintaining the width of the processing stack by caching and computing only the forward pass for longer-range inputs. In practice, Transformer-XL can incorporate only a moderate amount of additional context (see Figure 3). Perceiver AR uses a single masked cross-attend to enable training on much longer input contexts without requiring a wider self-attention stack.

Perceiver AR also scales better in practice to longer contexts than Transformer-XL, perhaps the most commonly used method for extending context length. Transformer-XL incorporates longer context by allowing attention over long-context positions at every layer in the forward pass and stopping gradient propagation to these positions in the backward pass. Transformer-XL also typically uses fewer positions at train than at test time, which further improves scaling at train time. But even with these modifications, the input size and model depth are still coupled, and as the context and depth increase, the forward pass becomes a compute and memory bottleneck. In practical terms, Perceiver AR scales better when both the context size and model depth increase. In Figure 3, we compare the wall-clock time per step of a decoder-only Transformer, Transformer-XL, and Perceiver AR in our codebase.

Figure 3. Training speed comparison on TPUv3 (training steps/second vs. context length, for models of 6 to 42 layers) for Perceiver AR, a standard decoder-only Transformer, and Transformer-XL. The Transformer is limited to a context length of 2,048 tokens, even with only 6 layers—larger models and larger context lengths require too much memory. Using the same 6-layer configuration, we can scale the Transformer-XL memory to a total context length of 8,192. Perceiver AR scales to a 65k context length, and can be scaled to over 100k context with further optimization.
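Before turning to other related architectures, the following is a minimal NumPy sketch (our own illustration, not the released implementation) of the attention pattern described in Section 3 and Figure 1: a single causally masked cross-attention from all M inputs to N latents tied to the final N positions, followed by a causally masked latent self-attention stack. Learned projections, multiple heads, layer norms, and MLPs are omitted.

import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attend(q, k, v, mask):
    # Scaled dot-product attention; mask[i, j] = True where query i may see key j.
    logits = q @ k.T / np.sqrt(q.shape[-1])
    logits = np.where(mask, logits, -1e30)
    return softmax(logits) @ v

def perceiver_ar_attention(x, num_latents, num_layers):
    # x: (M, C) embedded inputs. Returns one latent per target, i.e. per final position.
    m, _ = x.shape
    n = num_latents
    query_pos = np.arange(m - n, m)        # each latent is tied to one of the last N inputs
    key_pos = np.arange(m)

    # Causally masked cross-attention: a latent may only see inputs at or before its position.
    cross_mask = key_pos[None, :] <= query_pos[:, None]        # (N, M)
    z = attend(x[-n:], x, x, cross_mask)                       # (N, C)

    # Causally masked self-attention over the N latents, repeated for L layers.
    self_mask = np.tril(np.ones((n, n), dtype=bool))
    for _ in range(num_layers):
        z = z + attend(z, z, z, self_mask)                     # residual connection
    return z

x = np.random.default_rng(0).normal(size=(11, 8))              # M = 11, as in Figure 1
z = perceiver_ar_attention(x, num_latents=3, num_layers=2)
print(z.shape)                                                 # (3, 8): one output per target

Counting query-key scores for an illustrative configuration shows why this matters for the scaling comparison above: with M = 65,536, N = 1,024, and L = 60, a decoder-only Transformer needs roughly L·M² ≈ 2.6e11 scores per example, while the pattern above needs M·N + L·N² ≈ 1.3e8. These are back-of-the-envelope counts, not the wall-clock measurements of Figure 3, but they convey the same ordering. Because num_latents enters only through the mask shapes, it can also be changed at evaluation time without altering which inputs the model can see (Section 5.2.1).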
4.2. Relationship to other scalable attention architectures

Perceiver AR limits the computational requirements of processing long sequences by avoiding the use of one computational node for each input element. This strategy is also a feature of other architectures, like Set Transformer (Lee et al., 2019) and Luna (Ma et al., 2021), that use differently shaped query and key-value inputs to limit the cost of attention and produce "downsampled" intermediate representations that are smaller than the input's size. Unlike the Perceiver family of architectures, both Set Transformer and Luna alternate downsampling and upsampling attention operations throughout the architecture, which results in an architecture that mitigates some of the cost of standard Transformers, but still has compute requirements multiplicative in input size and model depth. Of this family of architectures, Luna is most similar, as it is compatible with (causally masked) autoregressive modeling.

Several other architectures (Dai et al., 2020; Nawrot et al., 2021; Clark et al., 2021) reduce the processing requirements of Transformers by sequentially compressing the input with attention or convolution, but rely on the use of multiple large, quadratic-complexity attention layers or exploit locality assumptions that limit their generality. This makes them more efficient when applied to input chunks of similar size as normal Transformers (i.e. a few thousand), but reduces their utility for longer contexts.

4.3. Relationship to encoder-decoder architectures

Perceiver AR bears a resemblance to encoder-decoder Transformers (Vaswani et al., 2017), seq2seq (Sutskever et al., 2014), and other encoder-decoder models. Encoder-decoder models pass long-context inputs into an encoder stack (e.g. "Perceive", as in Figures 1 and 2) and use the outputs of the encoder to contextualize the immediate antecedents of each output (e.g. "rAR"), which are processed by a separate decoder stack. In contrast, Perceiver AR passes both inputs and targets through a single, shared processing stack. This allows each target to learn how to use long-context and recent input as needed, with minimal architectural restrictions. From this point of view, Perceiver AR can be viewed as an encoder-decoder architecture with 0 encoder layers. By handling all inputs with a single (causally masked) cross-attention, Perceiver AR sidesteps the need for separate encoder and decoder stacks.

See Appendix D for an extended discussion of Perceiver AR's relationship to other methods for long-context modeling by efficient attention (D.1), architecture design (D.2), and input tokenization (D.3).

5. Results

We evaluate Perceiver AR on a number of different domains to demonstrate its flexibility and its ability to capture long-range structure on a variety of data. In all experiments except where mentioned, we use pre-layernorm attention modules (Xiong et al., 2020) and squared-ReLU nonlinearities (So et al., 2021).

For each domain, we tuned models against eval perplexity with ad hoc hyperparameter sweeps as our compute permitted. We typically tuned channel size, head dimensions, and model depth. See each domain's section and Appendices F and G for more details.

5.1. Copy Task

5.1.1. Long input length

We first verify that the architecture can attend to very long input lengths on a synthetic copy task. Using a model with only 6 layers of 1024 latents, we provide an input with length 2^17 (131,072). The model is trained on sequences containing a [BOS] (Beginning of Sequence) token followed by 65,535 random bytes (encoded as tokens taking one of 256 values). Those random bytes are then mirrored for the second half of the sequence and followed by an [EOS] (End of Sequence) token. This results in a maximum copy distance of 2^17 − 2 tokens. Train and eval loss are calculated on only the second half of the sequence to avoid training on noise.
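A sketch of how one such training sequence can be built from the description above (our reconstruction; the particular ids chosen for [BOS] and [EOS] are arbitrary):

import numpy as np

def make_copy_example(rng, num_random=65_535, vocab_size=256):
    # One copy-task sequence: [BOS], random bytes, the same bytes mirrored, [EOS].
    # With num_random = 65,535 the total length is 2 * 65,535 + 2 = 131,072 = 2^17.
    bos, eos = vocab_size, vocab_size + 1          # reserve two extra ids for the markers
    body = rng.integers(0, vocab_size, size=num_random)
    sequence = np.concatenate(([bos], body, body[::-1], [eos]))
    # Loss is computed only on the second half, so the model is never asked to
    # predict the unpredictable random bytes themselves.
    loss_mask = np.zeros(len(sequence), dtype=bool)
    loss_mask[len(sequence) // 2:] = True
    return sequence, loss_mask

rng = np.random.default_rng(0)
seq, mask = make_copy_example(rng, num_random=7)   # tiny version for inspection
print(seq)          # [BOS, b1..b7, b7..b1, EOS]
print(mask.sum())   # 8: half of the 16 positions contribute to the loss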
After training for 25k steps, the model was able to correctly predict the mirrored tokens plus [EOS] of 12 unseen validation sequences (a total of 786,432 tokens) with 100% accuracy. This experiment demonstrates that the model can successfully attend to individual tokens within very long sequences and that a relatively small training signal of 1024 targets per sequence can successfully propagate through the cross-attend bottleneck even when most of the inputs are irrelevant for a given target. As a historical aside, we note that one antecedent of this experiment is the copy task used to validate the Neural Turing Machine (Graves et al., 2014) and Differentiable Neural Computer (Graves et al., 2016). These models showed nearly perfect accuracy when copying up to about length-50 sequences of random 2^6- or 2^8-bit inputs, while Perceiver AR shows perfect accuracy when copying sequences of 2^16 random 2^8-bit tokenized inputs.

5.1.2. Number of training targets

In a standard decoder-only Transformer model, the input length, number of latent processing nodes per layer, and number of training targets are always the same. Perceiver AR allows the flexibility of any ratio of input length to number of latents and training targets. Changing the width of the self-attention stack affects the expressivity of the network and also alters the number of training outputs for which loss can be computed per sequence in a batch.

To illustrate this effect, we trained models on the copy task with a context length of 8,192 tokens using different numbers of latent nodes and batch sizes (Table 2). To reduce the effect of network expressivity, we use only 1 self-attention layer for these experiments. We found that after training for 25k steps with 1024 latents and a batch size of 128, the model converged and predicted the second half of sequences in the test set with 100% accuracy. If we reduced the batch size to 64, the model did not converge and achieved < 1% accuracy on the test set. However, if we kept the batch size at 64 and increased the number of latents (and therefore training targets per sequence) to 2048, the model successfully converged.

Table 2. Convergence on the copy task with a context length of 8,192 tokens. Either increasing the number of latents (and therefore the number of targets) or the batch size has a similar effect, as discussed in Section 5.1.2.

                  1024 latents        2048 latents
    64 batch      did not converge    converged
    128 batch     converged           converged

Depending on memory, compute, and expressivity requirements, increasing the batch size or the number of latents may be better: Perceiver AR enables this flexibility.

5.2. ImageNet 64 × 64

To test this architecture's capabilities in the image modality, we use the downsampled ImageNet dataset (van den Oord et al., 2016b) at the 64 × 64 resolution. Similar to the training procedure used in Sparse Transformer (Child et al., 2019), we flatten the image into a sequence of color channel bytes (256 possible values) for each pixel in raster scan order. After adding a [BOS] token to the beginning of the sequence, this results in a length of 12,289. Each input has 3 randomly initialized position embeddings added to it, for row (64), column (64), and color channel (3). No data augmentation is used.
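A sketch of the two orderings used in this section and in Section 5.2.2 (our own reconstruction of the data preparation; the [BOS] id and the exact indexing of the [BOS] position are illustrative assumptions). The first helper produces the standard raster-scan sequence of 12,289 tokens together with the row, column, and channel indices used to look up the three position embeddings; the second produces the R→G→B re-ordering studied in Section 5.2.2.

import numpy as np

def flatten_raster_scan(image, bos_id=256):
    # image: (64, 64, 3) uint8. Returns 12,289 tokens ([BOS] + 64*64*3 channel bytes)
    # plus row/column/channel indices; index 0 of each is reserved for [BOS] here.
    rows, cols, channels = image.shape
    tokens = np.concatenate(([bos_id], image.reshape(-1).astype(np.int64)))
    r, c, ch = np.meshgrid(np.arange(rows), np.arange(cols), np.arange(channels),
                           indexing="ij")
    row_idx = np.concatenate(([0], r.reshape(-1) + 1))
    col_idx = np.concatenate(([0], c.reshape(-1) + 1))
    chan_idx = np.concatenate(([0], ch.reshape(-1) + 1))
    return tokens, row_idx, col_idx, chan_idx

def flatten_rgb_planar(image):
    # All R bytes in raster order, then all G, then all B (Section 5.2.2): each
    # pixel's G and B values now depend on an R value up to 2 * 4096 positions away.
    return image.transpose(2, 0, 1).reshape(-1).astype(np.int64)

image = np.random.default_rng(0).integers(0, 256, size=(64, 64, 3), dtype=np.uint8)
tokens, row_idx, col_idx, chan_idx = flatten_raster_scan(image)
print(tokens.shape)                                  # (12289,)
planar = flatten_rgb_planar(image)
assert planar[2 * 64 * 64] == image[0, 0, 2]         # B of the top-left pixel sits 8192 tokens after its R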
We train a model with 1024 latents in 60 self-attention layers after the initial cross-attend. After 750k steps, we achieve 3.40 bits/dim on the validation set, exceeding the performance of previous autoregressive models (Table 3). Generated image samples are shown in Figure 4 and Appendix A.

Table 3. Results on Downsampled ImageNet (64 × 64) density estimation in bits/dim on the validation set, lower is better. We compare our results against the autoregressive models PixelCNN (van den Oord et al., 2016b), Sparse Transformer (Child et al., 2019), Routing Transformer (Roy et al., 2021), and Combiner (Ren et al., 2021), and the diffusion model Variational Diffusion Model (Kingma et al., 2021).

    Model                   Type    Bits/Dim
    PixelCNN                AR      3.57
    Sparse Transformer      AR      3.44
    Routing Transformer     AR      3.43
    Combiner                AR      3.42
    VDM                     Diff    3.40
    Perceiver AR (ours)     AR      3.40

Figure 4. Representative samples from the ImageNet model. The 4 images on the left were generated using only 16 latents, the middle 8 with 1024 latents (same as train time), and the right 4 with 1536 latents. The same 60-layer model is used for all configurations and all can attend to the full sequence. The full batches from which these images were drawn are shown in Appendix A.

5.2.1. Varying compute at test time

Because Perceiver AR decouples input length from the number of latents in the self-attention stack and because no position-specific parameters are learned in this model, we have the intriguing option of evaluating with a different number of latents than were used during training, while still maintaining the ability to attend to the full input sequence (illustrated further in Figure 12).

We found that increasing the number of latents up to 2x the number used during training improves model performance, and decreasing the number of latents results in gracefully degrading performance despite dramatic reductions in compute requirements, as seen in Table 4. For example, when using only 16 latents to attend to the full input sequence of 12,289 tokens, the model achieves the same bits/dim as PixelCNN (van den Oord et al., 2016b). This kind of flexibility enables a single model to be used in scenarios with varying compute, latency, and quality requirements, without requiring additional training or a distillation process.

Table 4. Results on Downsampled ImageNet (64 × 64) density estimation in bits/dim on the validation set (lower is better) when changing the number of latents used at eval time (1024 used at train time). Generation time is how long a single image takes to infer on a TPUv3 core using the activation caching described in Appendix E.3. The same 60-layer model is used for all configurations and all can attend to the full sequence.

    # Latents    Bits/Dim    Generation (minutes)
    16           3.5664      1.99
    64           3.4946      2.02
    512          3.4139      2.81
    1024         3.4025      3.68
    1536         3.3986      4.69
    2048         3.4018      5.88
    4096         3.5569      12.28

We suspect these models could be made even more flexible to the number of latents used during evaluation or inference if variable latent access is incorporated into training, but we leave a full exploration of these ideas to future work.

5.2.2. Impact of sequence ordering and length

We test Perceiver AR's ability to utilize context beyond the width of its self-attention stack by training on image sequences with strong long-range dependencies. Typically, autoregressive models of image data use a raster scan ordering, where RGB subpixels in each image location are generated in sequence before moving on to the next location. We re-order the ImageNet image data so that all the red subpixels are predicted in sequence, then the green, then the blue. This induces strong long-range dependencies between subpixels in the same spatial location.

We train 16-layer versions of the models in Section 5.2 for 30k steps (due to compute constraints), and compare short context models (1024 inputs) to full context models (12,289 inputs). As a baseline we also train both model variants on sequences with the standard ordering. The results are shown in Table 5. We find that for the standard ordering, both short and long context models perform similarly. However, on the re-ordered image data, the short context model performs significantly worse, whereas the long context model performs comparably to the standard-ordering baselines. This indicates that our long context model is able to access and process information at distant timesteps.

Table 5. Impact of context length and sequence ordering for ImageNet (64 × 64). Results are bits/dim on the validation set, lower is better. ImageNet examples contain 12,288 (64 × 64 × 3) tokens. For short contexts, the sequence ordering has a large effect on performance. For long contexts, the effect is small.

    Input Context    Standard Ordering    R→G→B
    1024             3.55                 4.63
    12289            3.54                 3.53

For general domains, we often do not know which long-range dependencies are important, and this experiment shows that the performance impact of missing existing dependencies can be severe. Perceiver AR provides a solution by extending Transformer context size in a scalable way.

5.3. Project Gutenberg (PG-19)

We next test the architecture for language modeling on the Project Gutenberg (PG-19) dataset (Rae et al., 2019), a collection of English-language books published before 1919 and scraped from the Project Gutenberg eBook depository. We use PG-19 as it is publicly available and contains a large number of words (1.97B train, 3.01M validation, 6.97M test) drawn from a reasonably large number of books (28k train, 50 validation, 100 test). We evaluate Perceiver AR on PG-19 using Subword-tokenized inputs as in prior work (Rae et al., 2019; Roy et al., 2021). We use the Subword tokenization settings reported in Section 4.2 of Rae et al. (2019), and we trained models until the loss converged on a subset of the validation set (after training for about 200k steps at batch size 2048, or about 420B total tokens).

We compare Perceiver AR to state-of-the-art numbers from the literature (Table 6). The best models reported by these papers use 36 layers (Compressive Transformer; Rae et al. 2019) and 22 layers (Routing Transformer; Roy et al. 2021). We find that Perceiver AR outperforms state-of-the-art methods on this dataset when using the same tokenization. However, consistent with previous work, we saw no evidence of improvement beyond 2k tokens on PG-19 (Sun et al., 2021).

Table 6. Results on PG-19 language modeling. Results are shown in perplexity (ppl.), lower is better. Baseline results are reproduced from the original papers. Routing Transformer does not report validation set performance. The context lengths shown include memory (Transformer-XL and Compressive Transformer) and compressive memory (the latter only).

    Model                                         Context length    # layers    Val ppl.    Test ppl.
    Transformer-XL (Rae et al., 2019)             512+1024          36          45.5        36.3
    Compressive Transformer (Rae et al., 2019)    512+512+2x512     36          43.4        33.6
    Routing Transformer (Roy et al., 2021)        8192              22          -           33.2
    Perceiver AR (ours)                           2048              60          45.9        28.9
    Perceiver AR (ours)                           4096              60          45.9        29.0

Motivated by early experiments, both models use large input embeddings (4096), large numbers of cross-attend heads (128), and high cross-attend dropout (0.96875 for the 4096-context model and 0.875 for the 2048-context model; see Appendix E.2). The results presented in Table 6 suggest that simply scaling Perceiver AR can produce excellent results even with very high levels of cross-attend dropout.

Consistent with prior work, we notice that models exhibit a consistent drop in performance between the PG-19 validation and test sets. This is likely due to the relatively small number of books contained in the PG-19 validation (50 books) and test (100 books) sets. Despite the large number of tokens, the small number of books in the validation set will lead many of these tokens to exhibit shared content and style, limiting how representative the validation set can be of the task as a whole. As noted by Sun et al. (2021), the PG-19 validation set contains at least one book (out of 50) that is arguably out of distribution for the train set.

To better understand the effect of context in language modeling, we next evaluate on a large internal book dataset where the number of documents is not a concern (Section 5.4).
5.4. Books

Here, we study the usefulness of longer input context by training on an internal dataset containing 4 million books published between 1500 and 2008. This dataset was previously used in Rae et al. 2021 as part of the MassiveText dataset. We train models with 1024 latents, 36 layers, and {1024, 4096, 8192, 16384} input context tokens.

These results (Table 7) show a clear trend: better results are obtained with contexts longer than 1024.

Table 7. Books results, shown in perplexity (ppl.), lower is better.

    Model           Context    Eval ppl.    Train Steps/sec
    Perceiver AR    1024       14.88        2.19
    Perceiver AR    4096       14.60        2.09
    Perceiver AR    8192       14.57        1.95
    Perceiver AR    16384      14.56        1.75

We also run experiments where we compare our model against a Transformer-XL baseline, with compute (measured by steps per second) matched as closely as possible. Models are trained with a batch size of 256 for 500k steps. First, we evaluate 36-layer Perceiver AR models with {1024, 4096, 8192, 16384} context lengths, respectively. Each of them is matched with a Transformer-XL trained on sequences of length 1024 and {23, 24, 25, 28} layers, respectively. Table 8 shows that our model improves consistently over the Transformer-XL in this controlled scenario.

Table 8. Books results on the test set, shown in perplexity (ppl.), lower is better, from 36-layer Perceiver AR models with varying contexts and Transformer-XL models with depth matched for compute (steps/sec), trained for 500k steps.

    Model    Context    Depth    Steps/Sec    Eval ppl.
    AR       1024       36       2.19         14.006
    T-XL     1024       23       2.17         14.822
    AR       4096       36       2.09         13.806
    T-XL     1024       24       2.06         14.719
    AR       8192       36       1.95         13.791
    T-XL     1024       25       1.97         14.593
    AR       16384      36       1.75         13.749
    T-XL     1024       28       1.76         14.276

In Table 9, we then compare the deepest Transformer-XL model (42 layers) that fits in memory against 4 Perceiver AR models with the same context lengths and respective numbers of self-attention layers {62, 61, 60, 56}. Remarkably, we were not able to increase the number of layers for some AR models—and achieve a closer match in compute—without running out of memory. Instead, we ran the deepest possible models for a specific context length. Even in this scenario, all Perceiver AR models performed better than the deepest Transformer-XL.

Table 9. Books results on the test set, shown in perplexity (ppl.), lower is better, from a 42-layer Transformer-XL model and Perceiver AR models with varying contexts and depth matched for compute (steps/sec), trained for 500k steps.

    Model    Context    Depth        Steps/Sec    Eval ppl.
    T-XL     1024       42 layers    1.17         13.253
    AR       1024       62 layers    1.19         12.849
    AR       4096       61 layers    1.21         12.680
    AR       8192       60 layers    1.25         12.660
    AR       16384      56 layers    1.26         12.816

5.5. Wikitext-103

We further evaluate on the Wikitext-103 (Merity et al., 2017) dataset, a commonly used word-level language modeling benchmark. The dataset consists of 28,475 Wikipedia articles containing between 68 and 26,993 words, averaging 3.6k words. Wikitext-103 is a small dataset where strong regularization is required to prevent severe overfitting and to obtain good performance. Nonetheless, we obtain competitive results, suggesting Perceiver AR's utility even for relatively small datasets.

Table 10. Results on the Wikitext-103 language modeling benchmark. Baseline results are reproduced from the original papers: Adaptive inputs (Baevski & Auli, 2018), Transformer-XL (Dai et al., 2019), Shortformer (Press et al., 2021), Compressive Transformer (Rae et al., 2019), and Routing Transformer (Roy et al., 2021). Transformer-XL (ours) is a reimplementation in our codebase. We train Perceiver AR models with varying context lengths.

    Model                      Valid ppl.    Test ppl.
    Adaptive inputs            18.0          18.7
    Transformer-XL             18.3          18.2
    Shortformer                17.5          18.15
    Compressive Transformer    16.0          17.1
    Routing Transformer        -             15.8
    Transformer-XL (ours)      17.58         18.42
    Perceiver AR (1024)        17.86         18.52
    Perceiver AR (2048)        17.60         18.35
    Perceiver AR (4096)        17.66         18.25
    Perceiver AR (8192)        17.58         18.37

We show results in Table 10. Perceiver AR performs on par with Transformer-XL Large (Dai et al., 2019) and Shortformer (Press et al., 2021). We also include the state-of-the-art results on Wikitext-103 (trained on Wikitext-103 only) (Rae et al., 2019; Roy et al., 2021). Increasing the context length from 1,024 to 2,048 for Perceiver AR does improve perplexity, but further increasing to 4,096 or 8,192 is harmful.
This effect has been noted in previous work (Press et al., 2021; Sun et al., 2021) and provides further evidence that current language model performance on small benchmarks is not bottlenecked by their limited-range context.

5.6. MAESTRO

We also evaluate Perceiver AR in the music domain, with samples available in the online supplement (https://bit.ly/3uF5LJg). Here, modeling long-range dependencies is essential for obtaining good representations (Huang et al., 2019). We use the MAESTRO dataset (Hawthorne et al., 2019), which contains approximately 200 hours of music in both symbolic (MIDI) and audio modalities. For symbolic data, we experiment with both the v1 and v3 versions of MAESTRO. The former allows us to compare with the existing state-of-the-art; the latter is an improved set with approximately 10% more performances. We use the MIDI tokenization described in Section A.2 of Huang et al. (2019).

After training for 1M steps with an input context of 4096 tokens and 2048 latents in 12 self-attention layers, our model obtains a lower negative log-likelihood (Table 11) than Music Transformer (Huang et al., 2019; Hawthorne et al., 2019). That model had 6 layers and was trained on random crops of 2048 tokens, but benefited from data augmentation, which ours did not. We also report results on MAESTRO v3.

Table 11. Results on MAESTRO symbolic music generation on the Test and Validation datasets. Results are shown in negative log-likelihood, lower is better. The baseline result is reproduced from (Hawthorne et al., 2019).

    Model                MAESTRO    Test    Validation
    Music Transformer    v1         -       1.84
    Perceiver AR         v1         1.82    1.82
    Perceiver AR         v3         1.91    1.90

For audio data, we generate vector-quantized embeddings using the SoundStream neural audio codec (Zeghidour et al., 2021) at several bitrates. Lower-bitrate codecs model coarser structure and enable training on a longer time window for a fixed context length, but at the expense of lower audio fidelity. We show results in Table 12.

Table 12. Perceiver AR negative log-likelihood results on SoundStream audio generation, for a fixed context length of 65536.

    SoundStream bitrate    Context    Test    Validation
    12kbps                 54.4s      2.49    2.34
    18kbps                 36.8s      2.60    2.57
    22kbps                 29.6s      2.65    2.62

5.7. Music samples

To showcase the long-term coherence of these models, we introduce another symbolic task involving a larger, private dataset of piano performances (Simon et al., 2019). These were transcribed from 10,000+ hours of audio, using a variation of the Onsets and Frames model (Hawthorne et al., 2018). From this dataset, we only used pieces resulting in 1024–32,768 tokens. Samples shorter than the lower limit are unlikely to contain song content; at the other end, there are only about 200 pieces longer than 32,768 tokens. The same tokenization as for MAESTRO is used. We train a model with 1024 latents and 24 self-attention layers on input contexts of 32,768 tokens, achieving a negative log-likelihood of 1.24 on the test set.

We generate samples from the models trained on this new task, as well as samples from the MAESTRO v3 symbolic and SoundStream models. The samples obtained from the large transcription dataset exhibit stylistic and structural coherence that spans several minutes, containing repeating musical themes, chord patterns, arpeggios, and even ritardandos. The audio domain samples exhibit the same tradeoff of audio fidelity and long-term structure seen in the previous section, but demonstrate strong correspondence to the original domain. Rendered audio samples are available in the online supplement: https://bit.ly/3uF5LJg.

6. Conclusion

We introduce Perceiver AR, an architecture designed for long-context autoregressive modeling. Perceiver AR scales to longer input sizes than the architectures typically used in practice (Transformers and Transformer-XL), while scaling to the depths needed for density estimation on real-world data. Perceiver AR decouples the computational requirements of processing many inputs from those of building deep networks. This gives us more control over how much compute is used for a given model at test time and allows us to smoothly trade off speed against performance. Perceiver AR produces good results on a number of domains. Our experiments suggest that the larger context used by Perceiver AR may open the door for more flexible autoregressive ordering strategies. Perceiver AR is a good candidate for a general-purpose, long-context autoregressive model.

7. Acknowledgements

Thanks to Chitwan Saharia for help with SR3 upsampling, Christos Kaplanis for Books guidance, and Jordan Hoffmann and Richard Tanburn for help with RoPE. Thanks to Daniel Toyama, Douglas Eck, Adam Roberts, Nando de Freitas, Dani Yogatama, Josh Gardner, Catalin Ionescu, Skanda Koppula, Rabeeh Karimi Mahabadi, Ethan Manilow, Viorica Patraucean, Oleh Rybkin, Nikolay Savinov, and others at DeepMind and Google Research for suggestions.
References

Baevski, A. and Auli, M. Adaptive input representations for neural language modeling. In Proceedings of the International Conference on Learning Representations (ICLR), 2018.

Bengio, Y., Ducharme, R., Vincent, P., and Jauvin, C. A neural probabilistic language model. Journal of Machine Learning Research (JMLR), 2003.

Brown, T. B., Mann, B., Ryder, N., Subbiah, M., Kaplan, J., Dhariwal, P., Neelakantan, A., Shyam, P., Sastry, G., Askell, A., et al. Language models are few-shot learners. arXiv preprint arXiv:2005.14165, 2020.

Chen, M., Radford, A., Child, R., Wu, J., Jun, H., Luan, D., and Sutskever, I. Generative pretraining from pixels. In Proceedings of the International Conference on Machine Learning (ICML), 2020.

Child, R., Gray, S., Radford, A., and Sutskever, I. Generating long sequences with sparse Transformers. arXiv preprint arXiv:1904.10509, 2019.

Choromanski, K. M., Likhosherstov, V., Dohan, D., Song, X., Gane, A., Sarlos, T., Hawkins, P., Davis, J. Q., Mohiuddin, A., Kaiser, L., Belanger, D. B., Colwell, L. J., and Weller, A. Rethinking attention with Performers. In Proceedings of the International Conference on Learning Representations (ICLR).

Clark, J. H., Garrette, D., Turc, I., and Wieting, J. CANINE: Pre-training an efficient tokenization-free encoder for language representation. arXiv preprint arXiv:2103.06874, 2021.

Dai, Z., Yang, Z., Yang, Y., Carbonell, J., Le, Q. V., and Salakhutdinov, R. Transformer-XL: Attentive language models beyond a fixed-length context. In Proceedings of the Annual Meetings of the Association for Computational Linguistics (ACL), 2019.

Dai, Z., Lai, G., Yang, Y., and Le, Q. Funnel-Transformer: Filtering out sequential redundancy for efficient language processing. In Proceedings of Neural Information Processing Systems (NeurIPS), 2020.

Dhariwal, P., Jun, H., Payne, C., Kim, J. W., Radford, A., and Sutskever, I. Jukebox: A generative model for music. arXiv preprint arXiv:2005.00341, 2020.

Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S., et al. An image is worth 16x16 words: Transformers for image recognition at scale. In Proceedings of the International Conference on Learning Representations (ICLR), 2021.

Graves, A. Generating sequences with recurrent neural networks. arXiv preprint arXiv:1308.0850, 2013.

Graves, A., Wayne, G., and Danihelka, I. Neural Turing Machines. arXiv preprint arXiv:1410.5401, 2014.

Graves, A., Wayne, G., Reynolds, M., Harley, T., Danihelka, I., Grabska-Barwińska, A., et al. Hybrid computing using a neural network with dynamic external memory. Nature, 538:471–476, 2016.

Gu, A., Goel, K., and Ré, C. Efficiently modeling long sequences with structured state spaces. arXiv preprint arXiv:2111.00396, 2021.

Harris, C. R., Millman, K. J., van der Walt, S. J., Gommers, R., Virtanen, P., Cournapeau, D., et al. Array programming with NumPy. Nature, 585(7825):357–362, 2020.

Hawthorne, C., Elsen, E., Song, J., Roberts, A., Simon, I., Raffel, C., Engel, J., Oore, S., and Eck, D. Onsets and frames: Dual-objective piano transcription. In Proceedings of the International Society for Music Information Retrieval Conference (ISMIR), 2018.

Hawthorne, C., Stasyuk, A., Roberts, A., Simon, I., Huang, C.-Z. A., Dieleman, S., Elsen, E., Engel, J., and Eck, D. Enabling factorized piano music modeling and generation with the MAESTRO dataset. In Proceedings of the International Conference on Learning Representations (ICLR), 2019.

Hendrycks, D. and Gimpel, K. Gaussian error linear units (GELUs). arXiv preprint arXiv:1606.08415, 2016.

Hessel, M., Budden, D., Viola, F., Rosca, M., Sezener, E., and Hennigan, T. Optax: Composable gradient transformation and optimisation, in JAX!, 2020. URL http://github.com/deepmind/optax.

Huang, C.-Z. A., Vaswani, A., Uszkoreit, J., Shazeer, N., Simon, I., Hawthorne, C., Dai, A. M., Hoffman, M. D., Dinculescu, M., and Eck, D. Music Transformer: Generating music with long-term structure. In Proceedings of the International Conference on Learning Representations (ICLR), 2019.
Jaegle, A., Gimeno, F., Brock, A., Zisserman, A., Vinyals, O., and Carreira, J. Perceiver: General perception with iterative attention. In Proceedings of the International Conference on Machine Learning (ICML), 2021.

Jaegle, A., Borgeaud, S., Alayrac, J.-B., Doersch, C., Ionescu, C., Ding, D., Koppula, S., Zoran, D., Brock, A., Shelhamer, E., Henaff, O., Botvinick, M. M., Zisserman, A., Vinyals, O., and Carreira, J. Perceiver IO: A general architecture for structured inputs & outputs. In Proceedings of the International Conference on Learning Representations (ICLR), 2022.

Jumper, J., Evans, R., Pritzel, A., Green, T., Figurnov, M., Ronneberger, O., et al. Highly accurate protein structure prediction with AlphaFold. Nature, 596:583–589, 2021.

Katharopoulos, A., Vyas, A., Pappas, N., and Fleuret, F. Transformers are RNNs: Fast autoregressive Transformers with linear attention. In Proceedings of the International Conference on Machine Learning (ICML), 2020.

Kingma, D. P. and Ba, J. Adam: A method for stochastic optimization. In Proceedings of the International Conference on Learning Representations (ICLR), 2015.

Kingma, D. P., Salimans, T., Poole, B., and Ho, J. On density estimation with diffusion models. In Proceedings of Neural Information Processing Systems (NeurIPS), 2021.

Kitaev, N., Kaiser, L., and Levskaya, A. Reformer: The efficient Transformer. In Proceedings of the International Conference on Learning Representations (ICLR), 2020.

Kudo, T. and Richardson, J. SentencePiece: A simple and language independent subword tokenizer and detokenizer for neural text processing. In Proceedings of the Annual Meetings of the Association for Computational Linguistics (ACL), 2018.

Lakhotia, K., Kharitonov, E., Hsu, W.-N., Adi, Y., Polyak, A., Bolte, B., Nguyen, T.-A., Copet, J., Baevski, A., Mohamed, A., and Dupoux, E. Generative spoken language modeling from raw audio. arXiv preprint arXiv:2102.01192, 2021.

Lee, J., Lee, Y., Kim, J., Kosiorek, A., Choi, S., and Teh, Y. W. Set Transformer: A framework for attention-based permutation-invariant neural networks. In Proceedings of the International Conference on Machine Learning (ICML), 2019.

Liu, P. J., Saleh, M., Pot, E., Goodrich, B., Sepassi, R., Kaiser, L., and Shazeer, N. Generating Wikipedia by summarizing long sequences. In Proceedings of the International Conference on Learning Representations (ICLR), 2018.

Ma, X., Kong, X., Wang, S., Zhou, C., May, J., Ma, H., and Zettlemoyer, L. LUNA: Linear unified nested attention. In Proceedings of Neural Information Processing Systems (NeurIPS), 2021.

Mehri, S., Kumar, K., Gulrajani, I., Kumar, R., Jain, S., Sotelo, J., Courville, A., and Bengio, Y. SampleRNN: An unconditional end-to-end neural audio generation model. In Proceedings of the International Conference on Learning Representations (ICLR), 2017.

Merity, S., Xiong, C., Bradbury, J., and Socher, R. Pointer sentinel mixture models. In Proceedings of the International Conference on Learning Representations (ICLR), 2017.

Nash, C., Menick, J., Dieleman, S., and Battaglia, P. W. Generating images with sparse representations. In Proceedings of the International Conference on Machine Learning (ICML), 2021.

Nawrot, P., Tworkowski, S., Tyrolski, M., Kaiser, L., Wu, Y., Szegedy, C., and Michalewski, H. Hierarchical Transformers are more efficient language models. arXiv preprint arXiv:2110.13711, 2021.

Peng, H., Pappas, N., Yogatama, D., Schwartz, R., Smith, N., and Kong, L. Random feature attention. In Proceedings of the International Conference on Learning Representations (ICLR), 2021.

Polyak, A., Adi, Y., Copet, J., Kharitonov, E., Lakhotia, K., Hsu, W.-N., Mohamed, A., and Dupoux, E. Speech resynthesis from discrete disentangled self-supervised representations. arXiv preprint arXiv:2104.00355, 2021.

Press, O., Smith, N. A., and Lewis, M. Shortformer: Better language modeling using shorter inputs. In Proceedings of the Annual Meetings of the Association for Computational Linguistics (ACL), 2021.

Rabe, M. N. and Staats, C. Self-attention does not need O(n^2) memory. arXiv preprint arXiv:2112.05682, 2021.

Rae, J. W., Potapenko, A., Jayakumar, S. M., Hillier, C., and Lillicrap, T. P. Compressive Transformers for long-range sequence modelling. In Proceedings of the International Conference on Learning Representations (ICLR), 2019.
Rae, J. W., Borgeaud, S., Cai, T., Millican, K., Hoffmann, J., Song, F., et al. Scaling language models: Methods, analysis & insights from training Gopher. arXiv preprint arXiv:2112.11446, 2021.

Raffel, C., Shazeer, N., Roberts, A., Lee, K., Narang, S., Matena, M., Zhou, Y., Li, W., and Liu, P. J. Exploring the limits of transfer learning with a unified text-to-text Transformer. Journal of Machine Learning Research (JMLR), 2020.

Rajbhandari, S., Rasley, J., Ruwase, O., and He, Y. ZeRO: Memory optimizations toward training trillion parameter models. In Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis (SC), 2020.

Ramesh, A., Pavlov, M., Goh, G., Gray, S., Voss, C., Radford, A., Chen, M., and Sutskever, I. Zero-shot text-to-image generation. arXiv preprint arXiv:2102.12092, 2021.

Ren, H., Dai, H., Dai, Z., Yang, M., Leskovec, J., Schuurmans, D., and Dai, B. Combiner: Full attention Transformer with sparse computation cost. In Proceedings of Neural Information Processing Systems (NeurIPS), 2021.

Rosenfeld, R. Two decades of statistical language modeling: Where do we go from here? Proceedings of the IEEE, 88(8):1270–1278, 2000.

Roy, A., Saffar, M., Vaswani, A., and Grangier, D. Efficient content-based sparse attention with Routing Transformers. Transactions of the Association for Computational Linguistics (TACL), 9, 2021.

Saharia, C., Ho, J., Chan, W., Salimans, T., Fleet, D. J., and Norouzi, M. Image super-resolution via iterative refinement. arXiv preprint arXiv:2104.07636, 2021.

Schmidhuber, J. and Heil, S. Sequential neural text compression. IEEE Transactions on Neural Networks, 7(1):142–146, 1994.

Shazeer, N. Fast Transformer decoding: One write-head is all you need. arXiv preprint arXiv:1911.02150, 2019.

Simon, I., Huang, C.-Z. A., Engel, J., Hawthorne, C., and Dinculescu, M. Generating piano music with Transformer. 2019. URL https://magenta.tensorflow.org/piano-transformer.

So, D. R., Mańke, W., Liu, H., Dai, Z., Shazeer, N., and Le, Q. V. Primer: Searching for efficient Transformers for language modeling. arXiv preprint arXiv:2109.08668, 2021.

Srivastava, N., Hinton, G., Krizhevsky, A., Sutskever, I., and Salakhutdinov, R. Dropout: A simple way to prevent neural networks from overfitting. Journal of Machine Learning Research (JMLR), 2014.

Su, J., Lu, Y., Pan, S., Wen, B., and Liu, Y. RoFormer: Enhanced Transformer with rotary position embedding. arXiv preprint arXiv:2104.09864, 2021.

Sun, S., Krishna, K., Mattarella-Micke, A., and Iyyer, M. Do long-range language models actually use long-range context? In Proceedings of the Annual Conference on Empirical Methods in Natural Language Processing (EMNLP), 2021.

Sutskever, I., Vinyals, O., and Le, Q. V. Sequence to sequence learning with neural networks. In Proceedings of Neural Information Processing Systems (NeurIPS), 2014.

Uria, B., Côté, M.-A., Gregor, K., Murray, I., and Larochelle, H. Neural autoregressive distribution estimation. Journal of Machine Learning Research (JMLR), 2016.

van den Oord, A., Dieleman, S., Zen, H., Simonyan, K., Vinyals, O., Graves, A., Kalchbrenner, N., Senior, A., and Kavukcuoglu, K. WaveNet: A generative model for raw audio. arXiv preprint arXiv:1609.03499, 2016a.

van den Oord, A., Kalchbrenner, N., and Kavukcuoglu, K. Pixel recurrent neural networks. In Proceedings of the International Conference on Machine Learning (ICML), 2016b.

van den Oord, A., Vinyals, O., and Kavukcuoglu, K. Neural discrete representation learning. In Proceedings of Neural Information Processing Systems (NeurIPS), 2017.

Vaswani, A., Bengio, S., Brevdo, E., Chollet, F., Gomez, A. N., Gouws, S., Jones, L., Kaiser, L., Kalchbrenner, N., Parmar, N., Sepassi, R., Shazeer, N., and Uszkoreit, J. Tensor2Tensor for neural machine translation. arXiv preprint arXiv:1803.07416.

Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, L., and Polosukhin, I. Attention is all you need. In Proceedings of Neural Information Processing Systems (NeurIPS), 2017.

Vinyals, O., Babuschkin, I., Czarnecki, W. M., Mathieu, M., Dudzik, A., Chung, J., et al. Grandmaster level in StarCraft II using multi-agent reinforcement learning. Nature, 575(7782):350–354, 2019.

Wang, B. Mesh Transformer JAX, 2021. URL https://github.com/kingoflolz/mesh-transformer-jax.

Wang, S., Li, B. Z., Khabsa, M., Fang, H., and Ma, H. Linformer: Self-attention with linear complexity. arXiv preprint arXiv:2006.04768, 2020.

Weston, J., Chopra, S., and Bordes, A. Memory networks. In Proceedings of the International Conference on Learning Representations (ICLR), 2015.

Wu, C., Liang, J., Ji, L., Yang, F., Fang, Y., Jiang, D., and Duan, N. NÜWA: Visual synthesis pre-training for neural visual world creation. arXiv preprint arXiv:2111.12417, 2021.

Wu, F., Fan, A., Baevski, A., Dauphin, Y. N., and Auli, M. Pay less attention with lightweight and dynamic convolutions. arXiv preprint arXiv:1901.10430, 2019.

Wu, Y., Rabe, M. N., Hutchins, D., and Szegedy, C. Memorizing Transformers. In Proceedings of the International Conference on Learning Representations (ICLR), 2022. URL https://openreview.net/forum?id=TrjbxzRcnf-.

Xiong, R., Yang, Y., He, D., Zheng, K., Zheng, S., Xing, C., Zhang, H., Lan, Y., Wang, L., and Liu, T.-Y. On layer normalization in the Transformer architecture. In Proceedings of the International Conference on Machine Learning (ICML), 2020.

Zaheer, M., Guruganesh, G., Dubey, K. A., Ainslie, J., Alberti, C., Ontanon, S., Pham, P., Ravula, A., Wang, Q., Yang, L., et al. Big Bird: Transformers for longer sequences. In Proceedings of Neural Information Processing Systems (NeurIPS), 33, 2020.

Zeghidour, N., Luebs, A., Omran, A., Skoglund, J., and Tagliasacchi, M. SoundStream: An end-to-end neural audio codec. IEEE/ACM Transactions on Audio, Speech, and Language Processing (TASLP), 2021.
A. ImageNet Samples

Full batches of the generated images used to populate Figure 4 can be seen in Figures 5 to 8. All images were generated with a temperature of 1.0. We also show upsampled versions of the images that were generated with 1536 latents by using SR3 (Saharia et al., 2021) to achieve a resolution of 256 × 256 in Figure 9.

B. Books

As discussed in Section 5.4, we train all models on the Books dataset with 1024 latents, 36 layers, and {1024, 4096, 8192, 16384} input context tokens. In addition to the results shown in the main paper at stride 512, we also evaluate performance on our test set of 100 books as a function of stride (Appendix G), looking at 5 values: {16, 64, 128, 512, 1024} (Figure 10).

We draw the reader's attention to two effects here. First, while the perplexity at a given context length is relatively stable for strides ≤ 128, perplexity consistently increases with stride. When using strided evaluation, the first tokens in a model's context window see a relatively small number of preceding tokens. With an evaluation stride of 1024, the first token in the context window of a 1024-context model sees only one preceding token. This property is likely responsible for the increasing gap between 1024-context models and larger-context models as the stride increases: the perplexity gain of the 16384-context model over the 1024-context model is 0.8 at stride 1024, but only 0.26 at stride 16.

Second, regardless of the evaluation stride, each twofold increase in context improves perplexity, but with diminishing returns: the gap between 1024 and 4096 context is consistently larger than that between 4096 and 8192 or 8192 and 16384 contexts. Perceiver AR models with the same depth and different context sizes use the same number of parameters (with the possible exception of differences in the parameters of the position encoding). We believe this effect points to the need for larger capacity models to exploit the increased information in longer contexts. Nonetheless, the overall trend suggests that larger context leads to improved results, even when using essentially the same model capacity.

Figure 10. Perplexity results on the Books test set, from 4 different 36-layer Perceiver AR models with 1024-, 4096-, 8192-, and 16384-token contexts, respectively. The evaluation is done for 5 stride values (16, 64, 128, 512, 1024); perplexity is plotted against evaluation stride.

C. A More Detailed Look at Perceiver AR's Internals

Like Perceiver and Perceiver IO, Perceiver AR is built on Transformer-style attention blocks. Perceivers use two types of attention: cross- and self-attention. These two types of attention share an interface (both take in two arrays and return a third) but differ in what they pass to that interface.

Zooming in: QKV attention takes in a key-value input X_KV ∈ R^{M×C} and a query input X_Q ∈ R^{N×D} (where C and D indicate the number of channels). The output of QKV attention is an array with the same index (first) dimension as the query input and a channel (second) dimension determined by an output projection:

    Q = f_Q(X_Q);   K = f_K(X_KV);   V = f_V(X_KV)        (4)
    X_QK^pre = Q K^T                                      (5)
    X_QK = softmax(X_QK^pre / sqrt(F))                    (6)
    Attn(X_Q, X_KV) = X_QKV = f_O(X_QK V),                (7)

where X_QK^pre and X_QK are the arrays of pre- and post-softmax attention maps ∈ R^{N×M}, and X_QKV is an array ∈ R^{N×D}. The functions f_{Q,K,V} are linear layers mapping each input to a shared feature dimension F, and f_O is a linear layer projecting the output to a target channel dimension D, which is often the same size as X_Q's. All linear layers are applied convolutionally over the index dimension. We have omitted batch and head dimensions (in the case of multi-headed attention) for readability.

In Perceiver AR attention blocks, QKV attention is followed by a two-layer MLP with a squared ReLU nonlinearity following the first layer. The full module has the following structure:
    X_QKV = Attn(layerNorm(X_Q), layerNorm(X_KV))         (8)
    X_QKV = X_QKV + X_Q                                   (9)
    X_QKV = X_QKV + MLP(layerNorm(X_QKV)),                (10)

slightly abusing notation for simplicity and to emphasize the residual structure. "Attn" refers to QKV attention as described above.

Figure 5. Full batch of generated samples from the model trained on ImageNet using 16 latents during inference.

Zooming out: When discussed in the main text, the operations CrossAttend : X_KV × X_Q → X_Q and SelfAttend : X_KV × X_Q → X_Q refer to the full system of equations given in Equations (8) to (10).

These two operations differ only in that X_Q ≠ X_KV for cross-attention (with N < M for the "encoder" cross-attention considered here) and X_Q = X_KV for self-attention. Cross-attention is used to reduce the shape of an input array and self-attention to keep it the same shape. In Perceiver and Perceiver IO, the initial cross-attention's query input X_Q is typically learned (its elements are "learned latents"), while in Perceiver AR, it is typically constructed¹ by taking the last N elements of the input array: X_Q = X_KV[-N:, :], using NumPy-style indexing notation (Harris et al., 2020). In either case, using fewer latents than inputs is essential to controlling compute and memory costs while keeping long-context inputs.

Perceiver AR also differs from Perceiver and Perceiver IO in its use of causally masked cross- and self-attention. In the self-attention mask, all elements of X_QK^pre at indices (m', m) (queries × keys), where m', m ∈ [0, M) and m > m', are masked. In the cross-attention mask, to compensate for the fact that latents are placed at the trailing index locations, all elements of X_QK^pre at indices (n, m) (queries × keys), where n ∈ [0, N), m ∈ [0, M), and m > n + M − N − 1, are masked. This prevents "earlier" queries from attending to "later" keys. Causal masking is implemented by setting all masked connections in the pre-softmax attention map X_QK^pre to −∞ before the softmax.

If we denote causally masked cross- and self-attention by CrossAttend_cm and SelfAttend_cm, respectively, Perceiver AR (without learned latents²) is given by the following:

    Z_0 ← CrossAttend_cm(X, X[-N:, :])                    (11)
    Z_{l+1} ← SelfAttend_cm(Z_l, Z_l),                    (12)

where Equation (12) is applied once per self-attend layer. The final output is obtained by layer-norming and projecting Z_L to the vocabulary size, followed by a softmax to produce output logits.

¹ With the exception of experiments on Wikitext-103, where learned latents are used.
² For latents, replace the second (query) input to the RHS of Equation (11) with Z_0 (or Z_{−1} if you like), which is learned.

D. Further Related Work

In this section we describe additional background with the goal of elucidating Perceiver AR's problem setting and method.

D.1. Efficient attention

Many recent Transformer variants seek to avoid the O(N²) memory requirements of self-attention. This is often done by introducing sparsity when computing the attention matrix – as in Sparse Transformer (Child et al., 2019), Big Bird (Zaheer et al., 2020), and Combiner (Ren et al., 2021) – or by approximating this computation at lower cost – e.g. as in Linear Transformer (Katharopoulos et al., 2020), Linformer (Wang et al., 2020), Reformer (Kitaev et al., 2020), Random-Feature Attention (Peng et al., 2021), and Performer (Choromanski et al.).

The downside of methods that use sparsity is that this sparsity must be hand-tuned or created with heuristics that are often domain specific and can be hard to tune. In contrast, our work does not force a hand-crafted sparsity pattern on attention layers, but rather allows the network to learn which long-context inputs to attend to and propagate through the network. The initial cross-attend operation, which reduces the number of positions in the sequence, can be viewed as a form of learned sparsity.
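To make the causally masked cross-attend of Appendix C (and the learned-sparsity view above) concrete, the following is a minimal single-head JAX sketch. It is an illustrative re-implementation, not the code used for our experiments: parameter names (w_q, w_k, w_v, w_o) are ours, and the residual connections, layer norms, MLP of Equations (8) to (10), multi-head structure, and the self-attend stack of Equation (12) are all omitted.

```python
# Minimal single-head sketch of Perceiver AR's causally masked cross-attend
# (Eqs. (4)-(7) and (11)). Parameter names are illustrative only.
import jax
import jax.numpy as jnp

def causal_cross_attend(x_kv, num_latents, w_q, w_k, w_v, w_o):
    """x_kv: [M, C] input embeddings -> [N, D] latents (N = num_latents)."""
    num_inputs = x_kv.shape[0]
    x_q = x_kv[-num_latents:, :]                 # queries are the last N inputs
    q = x_q @ w_q                                # [N, F]
    k = x_kv @ w_k                               # [M, F]
    v = x_kv @ w_v                               # [M, F]
    scores = (q @ k.T) / jnp.sqrt(q.shape[-1])   # X_QK^pre / sqrt(F), shape [N, M]
    n_idx = jnp.arange(num_latents)[:, None]     # query (latent) index n
    m_idx = jnp.arange(num_inputs)[None, :]      # key index m
    # Keep only entries with m <= n + M - N - 1; mask the rest, as in Appendix C.
    keep = m_idx <= n_idx + num_inputs - num_latents - 1
    scores = jnp.where(keep, scores, -jnp.inf)
    weights = jax.nn.softmax(scores, axis=-1)    # X_QK
    return (weights @ v) @ w_o                   # [N, D]

# Example: 8 inputs reduced to 4 latents with random (untrained) projections.
key = jax.random.PRNGKey(0)
M, N, C, F, D = 8, 4, 16, 32, 16
kx, kq, kk, kv, ko = jax.random.split(key, 5)
x = jax.random.normal(kx, (M, C))
w_q = jax.random.normal(kq, (C, F))
w_k = jax.random.normal(kk, (C, F))
w_v = jax.random.normal(kv, (C, F))
w_o = jax.random.normal(ko, (F, D))
print(causal_cross_attend(x, N, w_q, w_k, w_v, w_o).shape)  # (4, 16)
```

After this single cross-attend, the long input no longer appears in the computation: all further processing operates on the N latents.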
Figure 6. Full batch of generated samples from the model trained on ImageNet using 1024 latents during inference, the same as the number of latents used during model training.

Because Perceiver AR does not depend on hand-tuned sparsity patterns or e.g. structured dilation patterns like those used in WaveNet (van den Oord et al., 2016a), it can model arbitrarily complex dependency patterns between any of its inputs immediately after the cross-attend. Models with fixed sparsity patterns, on the other hand, typically require several layers (a number that depends logarithmically on the distance between points in the input array in the case of WaveNet) or precise and fragile conjunctions of input receptive fields (in the case of hand-tuned sparsity patterns) to allow a given set of points to interact. The net effect is that the amount of processing effectively applied to a given set of features is much smaller in these architectures than in densely connected architectures like Perceiver AR.

D.2. Other efficient architectures

Memory-based models. Longer effective context size can also be achieved by reducing the compute requirement on tokens that are far in the past using stop gradients (Dai et al., 2019), recurrence (Mehri et al., 2017), memory (Weston et al., 2015; Rae et al., 2019), or other forms of compression (Wu et al., 2022) to process long-term structure. These strategies typically impose bottlenecks with local structure, which limits the flexibility with which context can be exploited by a given target. Although Perceiver AR performs processing using latents that are fewer in number than the inputs, each latent is given direct access to all inputs, rather than communicating with the past through a narrow or precomputed mechanism.

Alternatives to attention. Efficiency can also be obtained using domain-tuned architectures such as lightweight or dynamic convolutions (Wu et al., 2019). The recently introduced S4 (Gu et al., 2021) is a very efficient model that avoids attention altogether while producing interesting results, but it is not yet competitive on standard datasets like WikiText-103.

Many of the insights presented in this prior work are complementary to Perceiver AR's architectural design, and we expect they can be hybridized to produce even more efficient models suitable for long-context, general-purpose model design in future work.

D.3. Input Tokenization

Another approach to reducing the memory and compute requirements of self-attention is to directly reduce the length
of the input data by using tokenization to group multiple inputs into single tokens. Such approaches have led to excellent results on a variety of domains of interest and are often an implicit feature of the data arrays in standard datasets.

Figure 7. Full batch of generated samples from the model trained on ImageNet using 1536 latents during inference. The same random seed was used as when generating the images with 1024 latents, which explains how the image of the white dog appears in both batches.

Figure 8. Full batch of generated samples from the model trained on ImageNet using 2048 latents during inference.

In NLP, subword chunks are often grouped together (Kudo & Richardson, 2018) based on their frequency in a training text corpus. In vision, previous work has explored using K-means to group RGB values into a single token (Chen et al., 2020) or to group individual pixels into patches (Dosovitskiy et al., 2021), sometimes followed by quantization (Ramesh et al., 2021). Others have also leveraged DCT coefficients (Nash et al., 2021) as used in JPEG to convert fixed-size images to variable-length sequences. Arguably the most widely applied tokenization method is neural vector quantization, where an encoder network is trained to map images to a spatially downsampled collection of discrete codes (van den Oord et al., 2017; Ramesh et al., 2021).

Similar techniques have been developed for audio, where vector-quantized encodings of raw waveforms have effectively been used as vocabularies for end-to-end speech synthesis (Lakhotia et al., 2021; Polyak et al., 2021) and music generation (Dhariwal et al., 2020). The SoundStream codec we use in this paper builds upon these techniques by using residual vector quantization and adversarial reconstruction losses to achieve high audio fidelity with fewer discrete tokens (Zeghidour et al., 2021).

Tokenization is a broadly useful strategy, but it has its downsides. For many domains, effective tokenization schemes rely on lossy compression. Data is discarded in the tokenization process, and inputs cannot be recovered exactly. Care is required to ensure that the data can be reconstructed at a fidelity adequate for the required application. Neural compression schemes such as VQ-VAE require users to train and maintain additional encoder and decoder networks, which can hinder ready application to new datasets and domains. And, perhaps most tellingly, effective tokenization is typically designed and used in a domain-specific fashion, which limits the ease of adaptation and scaling to new domains. By effectively modeling long sequences, Perceiver AR can in some cases eliminate the need for tokenization and in others can reduce the need for heavy lossy compression. But in the near term, we anticipate that tokenization, of one form or another, will remain a necessary tool for incorporating more context.

E. Additional Details of the Methods

E.1. Memory Usage

The single largest source of memory usage in a Perceiver AR model is typically the attention map in the initial cross-attend layer, which results in a matrix of size [heads, input length, self-attention length]. For experiments in this paper where this matrix caused out-of-memory errors (Section 5.1.1), we found that processing attention heads in subgroups rather than all at once was sufficient to reduce our memory usage. Using the chunked approach of (Kitaev et al., 2020; Rabe & Staats, 2021; Jumper et al., 2021) for the cross-attend layer allows scaling input length beyond even those limits without requiring any architecture changes or adding compute requirements for layers other than the cross-attend. These memory-saving tricks do result in reductions to training throughput (in steps per second), so we avoid them where possible.
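A minimal sketch of the head-grouping trick described above: rather than materializing the full [heads, input length, latents] attention map at once, the cross-attend is computed for subgroups of heads and the results are concatenated. Names and shapes are illustrative (causal masking and the input/output projections are omitted), and the finer-grained chunking of Kitaev et al. (2020) and Rabe & Staats (2021) is not shown.

```python
import jax
import jax.numpy as jnp

def multihead_cross_attend(q, k, v):
    """q: [H, N, F]; k, v: [H, M, F] -> [H, N, F]. Masking omitted for brevity."""
    scores = jnp.einsum('hnf,hmf->hnm', q, k) / jnp.sqrt(q.shape[-1])
    weights = jax.nn.softmax(scores, axis=-1)   # [H, N, M] attention maps
    return jnp.einsum('hnm,hmf->hnf', weights, v)

def chunked_cross_attend(q, k, v, group_size):
    """Same result, but only group_size heads' attention maps are live at once."""
    outs = []
    for start in range(0, q.shape[0], group_size):
        sl = slice(start, start + group_size)
        outs.append(multihead_cross_attend(q[sl], k[sl], v[sl]))
    return jnp.concatenate(outs, axis=0)

# 16 heads processed in 4 groups of 4, as in the copy-task experiments
# (sequence lengths here are scaled down so the demo runs quickly).
key = jax.random.PRNGKey(0)
H, N, M, F = 16, 1024, 256, 64
kq, kk, kv = jax.random.split(key, 3)
q = jax.random.normal(kq, (H, N, F))
k = jax.random.normal(kk, (H, M, F))
v = jax.random.normal(kv, (H, M, F))
out = chunked_cross_attend(q, k, v, group_size=4)
assert jnp.allclose(out, multihead_cross_attend(q, k, v), atol=1e-4)
```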
Figure 9. Images from the batch generated with 1536 latents (Figure 7) upsampled to 256 × 256 pixels using SR3 (Saharia et al., 2021).

E.2. Cross-attend Dropout

Dropout (Srivastava et al., 2014) is used by default in many Transformer implementations and is an essential tool for mitigating overfitting on small datasets. Dropout is typically imposed on linear layers after the attention softmax or within the Transformer MLP block. Perceiver AR supports this kind of dropout, but the cross-attend layer also enables interesting possibilities for dropout before the attention softmax.

We find that for some tasks, masking out positions in the initial cross-attend is an effective way of preventing overfitting. This can also be used to save memory by setting a budget for a certain number of inputs and then selecting that number of inputs randomly from the maximum input context. Because no position-specific parameters are learned for the cross-attend layer, a smaller number of inputs can be used at train time than during evaluation or inference.

Imposing cross-attend dropout can be interpreted as enforcing high sparsity at training time, but allowing less extreme sparsity at evaluation time. Because the attention layer itself is scale invariant, the uniform scaling normally imposed by dropout at train time is unnecessary.
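A minimal sketch of this kind of input-budget cross-attend dropout: at train time, sample a random subset of the context to serve as keys and values for the initial cross-attend. The function and variable names are ours for illustration (a real implementation would also ensure the latent query positions themselves are retained), and no rescaling is applied, for the scale-invariance reason just noted.

```python
import jax
import jax.numpy as jnp

def subsample_context(rng, inputs, positions, budget):
    """Randomly keep `budget` of the M input positions for the cross-attend.

    inputs: [M, C] token embeddings; positions: [M] position indices.
    Returns ([budget, C], [budget]) in the original temporal order.
    """
    m = inputs.shape[0]
    keep = jax.random.choice(rng, m, shape=(budget,), replace=False)
    keep = jnp.sort(keep)          # preserve temporal order for position encoding
    return inputs[keep], positions[keep]

# Train-time budget of 1024 inputs drawn from a 4096-token context; at eval
# time the full context can be used, since the cross-attend has no
# position-specific parameters.
rng = jax.random.PRNGKey(0)
x = jnp.zeros((4096, 1024))
pos = jnp.arange(4096)
x_small, pos_small = subsample_context(rng, x, pos, budget=1024)
print(x_small.shape, pos_small.shape)   # (1024, 1024) (1024,)
```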
Figure 11. Figure best viewed on a screen. Caching at generation time allows previously computed states to be reused but introduces long-term dependencies not seen at training time when the input size and latent size differ. We illustrate this effect here for a Perceiver AR with N = 4 latents and an input context M = 8. Blue triangles indicate which attention operations are performed in the current step, while gray triangles indicate reused (cached) computations. The top row (no caching) matches what happens at training time. If caching is applied naively (middle row), the amount of computation per step can be greatly reduced. However, as the model is run out for more steps than there are latents, caching introduces dependencies on latents that are no longer active but have already been used to compute latents that are active (shown in red). In other words, caching previous latents allows longer-distance latents (which are no longer active) to influence the current generation, introducing long-term dependencies that were not encountered at train time. We find in practice that these long-term dependencies lead to degraded performance when caching is run out for too many steps. To avoid this problem, we cache by periodically resetting memory (bottom row), allowing some computation to be reused but avoiding the introduction of long-term dependencies not encountered at training time. This strategy is described in detail in Appendix E.3.
E.3. Activation Caching for Inference

Naively sampling from a Transformer for inference can be very slow because activations for all positions must be calculated at every step. Caching the key/value activations for previously inferred positions is a common technique for improving generation speed (Vaswani et al.; Shazeer, 2019). Perceiver AR can use a similar approach, but the exact technique cannot be applied directly because the number of output positions is smaller than the number of input positions, so preserving all previous activations would result in an effective self-attention stack as wide as the number of inputs.

Transformer-XL (Dai et al., 2019) solves this problem by keeping a buffer of activations for only the last N positions. This works because the model is presented with activations for that number of previous positions during training, which is not the case for Perceiver AR. We also cannot simply restrict the buffer size to be the same as the number of targets because even when activations for a given position are expired, they have already influenced other positions within the buffer (Figure 11).

Instead, we apply a simple trick to ensure that no cached activations are influenced by positions beyond what was seen at training time. We use a fixed activation buffer the same width as the self-attention stack at train time. When the buffer is full, we do a full forward pass without any cached activations for the next position, but using half the number of latents. The activations from that pass are saved in the buffer, leaving it half full. Inference then proceeds until the buffer is full again. The activations from the cross-attend layer do not require this trick unless inference will extend beyond the length of the inputs used at train time.

We find that these occasional full forward passes add minimal overhead, and the speed gains from using cached activations are still significant. We performed a test inferring a single image with a sequence length of 12,289 using a model with 1024 latents (see Section 5.2 for details) on a single TPUv3 core to compare speeds. Inference without any caching took 7.93 minutes. With caching, the same task took 3.68 minutes, less than half the time.

E.4. Varying Compute at Test Time

Figure 12 illustrates how the number of latents can be changed at test time without changing either the trained parameters of the model or the model's input context length. In Section 5.2.1 we discuss how this possibility can be used to scale up or down compute requirements and output quality.

Figure 12. Because Perceiver AR decouples input length from the width of the self-attention stack, the number of latents can be different at train time and test time. This does not require additional training because no per-position parameters are learned in the self-attention stack. Section 5.2.1 discusses how this possibility can be used to scale up or down compute requirements and output quality.

F. Training Details

F.1. Common

Unless otherwise stated, models were trained with the following configuration.

We use the Adam optimizer (Kingma & Ba, 2015) as implemented in the Optax framework (Hessel et al., 2020) with b1 = 0.1, b2 = 0.999, eps = 1e−8, a base learning rate of 3e−4, and a 10k-step linear warmup. To reduce memory usage, we used ZeRO Stage 1 optimizer state partitioning (Rajbhandari et al., 2020).

For training stability, we used a global max norm of 1.0. We also added an additional loss term z_loss · log(z)², with z_loss = 1e−4, as used in training the T5 family of models (Raffel et al., 2020).
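A sketch of how this configuration might be assembled with Optax. It mirrors the hyperparameters stated above but is not the training code used for the paper; it also assumes that z in the auxiliary term denotes the softmax normalizer, and ZeRO-style optimizer state partitioning is a systems-level detail not shown here.

```python
import jax
import jax.numpy as jnp
import optax

# 10k-step linear warmup to a base learning rate of 3e-4 (held afterwards),
# a global gradient-norm clip of 1.0, and Adam with the stated hyperparameters.
schedule = optax.linear_schedule(init_value=0.0, end_value=3e-4,
                                 transition_steps=10_000)
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(learning_rate=schedule, b1=0.1, b2=0.999, eps=1e-8),
)

def loss_fn(logits, targets, z_loss_coef=1e-4):
    """Token cross-entropy plus the auxiliary z-loss, z_loss * log(z)^2,
    where z is assumed to be the softmax normalizer of the logits."""
    log_z = jax.nn.logsumexp(logits, axis=-1)          # log of the normalizer
    log_probs = logits - log_z[..., None]
    nll = -jnp.take_along_axis(log_probs, targets[..., None], axis=-1)
    return jnp.mean(nll.squeeze(-1) + z_loss_coef * jnp.square(log_z))
```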
Both the input embedding and latent vectors were size 1024, and 16 attention heads were used for the initial cross-attend and within the self-attention stack. Within the MLP layers of the cross-attend and self-attention stack, the input dimensionality is projected to 4× its size and squared ReLU activations (So et al., 2021) were used. We used a cross-attend dropout probability of 0.1.

Training and evaluation were done on either TPUv2 or TPUv3 clusters.

F.2. Rotary Position Encodings

We encode the position of tokens using rotary position encoding (Su et al., 2021). With this method, rotation matrices (built from sine and cosine functions, just as with sinusoidal/Fourier position encodings) rotate pairs of dimensions in each of the key and query heads to reflect the absolute position of the key or query in the sequence. When used to compare a given key and query pair, these rotations produce attention weights that reflect only the relative distance between tokens. The result is a memory-efficient relative position mechanism.

Intuitively, this mechanism represents position using a strategy somewhat like that used by analogue clock hands. Like a clock hand, each pair of dimensions (sine/cosine) in the rotary position encoding is responsible for indicating position at some frequency (seconds, minutes, hours, etc.). When we have multiple clock hands (multiple pairs of dimensions within each token vector), each of which moves at a different frequency, we can both tell the current time (the query/key position) precisely and resolve it at different resolutions, depending on what's needed for a given task.

We found in early experiments that rotating only a fraction of the channel dimensions led to better results, an effect that has also been noticed elsewhere (rotating the first 50%) (Wang, 2021). Fractional rotation of this kind reduces the fidelity with which we encode position information (because it results in fewer frequency bands). The result of fractional rotation is that only some channel dimensions are modulated by the relative position between a query and key. This may encourage the network to exploit both position-dependent and position-agnostic relationships between queries and keys when computing attention weights. Prior work characterizing the effect of position on attention weights suggests that many common position encoding strategies bias the network towards attending to recent tokens; fractional rotation may mitigate this effect. Fractional rotation also makes the rotation computation significantly cheaper, as it requires fewer matrix multiplies.
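To make fractional rotation concrete, the following is a minimal sketch that applies rotary position encoding to only the first fraction of a head's channels (e.g. 25%, as in the music models below) and passes the remaining channels through unchanged. It is an illustrative re-implementation rather than the code used in our models.

```python
import jax.numpy as jnp

def rotary(x, positions, max_wavelength=10_000.0):
    """Standard rotary encoding: rotate consecutive (even, odd) channel pairs
    of x ([T, F], F even) by position-dependent angles."""
    f = x.shape[-1]
    freqs = max_wavelength ** (-jnp.arange(0, f, 2) / f)    # [F/2] frequencies
    angles = positions[:, None] * freqs[None, :]            # [T, F/2]
    cos, sin = jnp.cos(angles), jnp.sin(angles)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    rotated = jnp.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
    return rotated.reshape(x.shape)

def fractional_rotary(x, positions, fraction=0.25):
    """Rotate only the first `fraction` of channels; leave the rest unchanged."""
    f = x.shape[-1]
    n_rot = int(f * fraction)
    n_rot -= n_rot % 2             # an even number of channels is rotated
    return jnp.concatenate(
        [rotary(x[..., :n_rot], positions), x[..., n_rot:]], axis=-1)

# Example: a [positions, head_dim] block of queries with 25% of dims rotated.
q = jnp.ones((8, 64))
q_rot = fractional_rotary(q, jnp.arange(8), fraction=0.25)
print(q_rot.shape)                 # (8, 64)
```

In practice this transform would be applied to the queries and keys of each head before the attention scores are computed.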
F.3. Copy Task

The copy task was trained with 1024 latents in a 6-layer self-attention stack. For position encoding, we used fixed sinusoidal embeddings to denote absolute position, as described in the original Transformer paper (Vaswani et al., 2017). There were 4096 sequences per batch.

To reduce the instantaneous memory requirements created by attending to the input sequence of length 131,072, the 16 cross-attention heads were split into 4 groups of 4 and computed separately (see Appendix E.1).

The first 1K steps were a linear learning rate warmup and the remaining steps followed a cosine learning rate decay schedule.

F.4. ImageNet 64 × 64

The model trained on ImageNet 64 × 64 has 770.1M parameters. Training proceeded for a total of 750k steps. After an initial 10k-step linear warmup, a constant learning rate of 3e−4 was used until the final 50k steps, which used a cosine decay to 0.

F.5. Wikitext-103

We train 18-layer Perceiver AR models with 1,024 latents, adaptive inputs embeddings (Baevski & Auli, 2018), and increasing context lengths from 1,024 to 8,192 tokens. One key difference from the other experiments is that we use learned input latents rather than the input embeddings, as illustrated in Figure 13. We observe that this variant helps reduce overfitting. We also apply cross-attend dropout (Appendix E.2) with p = 0.15 to the context tokens for which we predict the next token.

Figure 13. Variant of Perceiver AR where we feed learned input latents instead of the input embeddings directly. This is used for our Wikitext-103 experiments as we noticed it helped reduce overfitting.

The Transformer-XL baseline used a memory length of 384 at train time and 1600 at eval time, for a full effective context length of 2624. We used a memory dropout rate of 0.25.

Perceiver AR models trained on Wikitext-103 have the following parameter counts: 356.5M (1024 context), 357.7M (2048 context), 359.8M (4096 context), 364.0M (8192 context). Note that the parameter count increases slightly with increasing context length because of the use of absolute position encodings on Wikitext-103. The baseline Transformer-XL has 285.2M parameters.

F.6. PG-19

Both Perceiver AR models trained on PG-19 use 974.6M parameters.

F.7. Books

All Books models reported in Table 7 have 498.9M parameters.

Perceiver AR models reported in Table 8 all have 498.9M parameters. Compute-matched Transformer-XL models have the following parameter counts, by number of layers: 346.6M (23), 360.3M (24), 373.9M (25), 414.8M (28).

Perceiver AR models reported in Table 9 have the following parameter counts, by context length and number of layers: 826.4M (1024, 62L), 813.8M (4096, 61L), 801.2M (8192, 60L), 750.8M (16384, 56L). The 42-layer Transformer-XL with matched compute has 605.8M parameters.

We use the same memory settings for Transformer-XL baselines as on Wikitext-103: all Transformer-XL models have a full effective context length of 1600 + 1024 = 2624.

F.8. Music Generation Tasks

All models use a value of 0.7 for the cross-attend dropout described in Appendix E.2. All but the MAESTRO SoundStream-task models apply rotation to 25% of the attention dimensions. Additionally, the models trained on the piano dataset and MAESTRO SoundStream tasks use a learning rate of 2e−4. The ones trained on the symbolic MAESTRO data use a learning rate of 1e−4, 1 initial cross-attend head, and 4 heads within the self-attention stack. The model trained on the piano transcription dataset uses 0.1 for cross-attend dropout and a 0.25 dropout rate on the latents. Finally, this model pre-embeds the query inputs using a 3-layer MLP with GELU (Hendrycks & Gimpel, 2016) activations.

Figure 14. Validation loss on MAESTRO v3 as a function of the cross-attend dropout applied to the model. Each value {0, 0.5, 0.75} is applied at 3 different model depths {12, 24, 48}.

Audio from MAESTRO v3 was first processed using the SoundStream 12kbps codec. This bitrate of 12kbps corresponds to a vocabulary of 1024 tokens (10 bits), 24 tokens/frame, and 50 frames/second. This reconstructs audio with reasonable quality, compressing each second of audio (16k waveform samples) into 1.2k sequential tokens. Our trained model has an input context of length 8192 (∼6.8 seconds) and uses 1024 latents in 12 self-attention layers. We also train and evaluate this configuration using higher-bitrate codecs: 18kbps and 22kbps, with 36 and 44 tokens/frame, respectively.

Table 13. Perceiver AR negative log-likelihood results on SoundStream audio generation, from two models with context lengths of 8192 (8k) and 32768 (32k) tokens, respectively.

SoundStream bitrate   Context (32k)   Context (8k)   Test (32k)   Test (8k)   Validation (32k)   Validation (8k)
12kbps                27.2s           6.8s           2.31         2.25        2.28               2.27
18kbps                18.4s           4.6s           2.52         2.53        2.51               2.52
22kbps                14.8s           3.7s           2.55         2.60        2.55               2.60
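As a quick sanity check of this bookkeeping (the numbers are approximate; small differences from the context durations reported in Table 13 come from rounding):

```python
# SoundStream token bookkeeping: 10-bit tokens, 50 frames/second.
bits_per_token, frames_per_sec = 10, 50
for kbps, tokens_per_frame in [(12, 24), (18, 36), (22, 44)]:
    tokens_per_sec = tokens_per_frame * frames_per_sec
    assert tokens_per_sec * bits_per_token == kbps * 1000
    for context in (8192, 32768):
        print(f"{kbps}kbps: {context} tokens ≈ {context / tokens_per_sec:.1f}s of audio")
# e.g. 12kbps: 8192 tokens ≈ 6.8s, 32768 tokens ≈ 27.3s;
#      22kbps: 8192 tokens ≈ 3.7s, 32768 tokens ≈ 14.9s.
```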
G. Evaluation Details

As discussed in Shortformer (Press et al., 2021), there is a quality vs. speed tradeoff when evaluating long sequences with regard to stride. A stride of 1 (maximum overlap) is the slowest but highest quality, and a stride equal to the input length (no overlap) is the fastest but lowest quality. Because Perceiver AR decouples the number of targets from the number of inputs, our stride options range from 1 to the number of latents. We found that a stride of half the number of latents gave a good balance between speed and quality, and that is what we used for all the evaluations in this paper, unless otherwise mentioned. For a related set of issues for inference, see the discussion in Appendix E.3.
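A sketch of this strided evaluation scheme, assuming a hypothetical log_probs_fn wrapper that returns per-target log-probabilities for the final num_latents positions of a window; handling of the first window, where all of its targets can be scored, is omitted for brevity.

```python
import numpy as np

def strided_nll(log_probs_fn, tokens, context_len, num_latents, stride):
    """Average negative log-likelihood of `tokens` under strided evaluation.

    log_probs_fn(window) -> [num_latents] per-token log-probs for the last
    num_latents positions of `window` (a hypothetical model wrapper).
    Only the final `stride` positions of each window are scored, so each
    scored token is evaluated exactly once; stride may be at most num_latents.
    """
    assert 1 <= stride <= num_latents
    total, count = 0.0, 0
    for end in range(context_len, len(tokens) + 1, stride):
        window = tokens[max(0, end - context_len):end]
        lp = log_probs_fn(window)          # [num_latents]
        total += -np.sum(lp[-stride:])
        count += stride
    return total / count

# e.g. stride = num_latents // 2, the default used for evaluations in this paper.
```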

H. Dropout Ablations

We study the effects of applying different forms of dropout to Perceiver AR, at several model depths, when training on the MAESTRO v3 symbolic dataset.

Figure 14 illustrates the behaviour of the model when varying the amount of cross-attend dropout applied to it. The model has an input context of 4,096 and a post-attention dropout value of 0.5; both hyperparameters are kept constant across all runs. This type of dropout appears less useful as model depth increases, up to having no effect at all during training for the 48-layer model.

For the second experiment, we vary the amount of (post-attention) dropout in the same settings as previously described, while keeping a constant value of 0.7 for the cross-attend dropout. Figure 15 shows that higher dropout rates become more useful at larger depths: while at depth 12 applying 0.75 dropout actually hurts model training, it becomes beneficial at depth 48, where applying no dropout leads to quick overfitting.

Figure 15. Validation loss on MAESTRO v3 as a function of the post-attention dropout applied to the model. Each value {0, 0.5, 0.75} is applied at 3 different model depths {12, 24, 48}.

I. MAESTRO SoundStream

In Table 13, we compare results on all MAESTRO SoundStream tasks obtained from models that were trained on contexts of 8192 and 32768 tokens, respectively. While a smaller context length yields better results at the lowest bitrate, the 32k-context model obtains a significantly lower negative log-likelihood than the 8k-context one on SoundStream 22kbps. Furthermore, the gap between corresponding NLLs widens as the bitrate increases. This result suggests that a shorter context becomes less and less useful as we attempt to produce higher-fidelity audio. The 32k-context models reported here have the following parameter counts: 366.3M (12kbps), 391.5M (18kbps), 408.3M (22kbps). The 8k-context models
reported here have the following parameter counts: 215.2M (12kbps), 240.4M (18kbps), 257.1M (22kbps).

The 65536-context SoundStream models reported in Table 12 have the following parameter counts: 668.6M (12kbps), 693.8M (18kbps), 710.6M (22kbps).