Exploring Low Rank Training of Deep Neural Networks
Siddhartha Rao Kamalakara * 1 2 Acyr Locatelli * 2 Bharat Venkitesh * 1 Jimmy Ba 3 Yarin Gal 4
Aidan N. Gomez 1 2 4
imposes sparse low rank structure. (Jaderberg et al., 2014) also considered a trained network upon which a low rank structure is imposed through filter and data reconstruction objectives. (Tai et al., 2016) focused on low rank training of CNNs from scratch; they proposed a horizontal and vertical filter decomposition of a convolutional kernel and reproject it into orthogonal vectors at every step. One of the reasons why prior work has focused on post-training low rank approximations is that the training dynamics of neural networks are poorly understood. Moreover, it has been found that naively training in the low rank space from scratch suffers a gap in performance (see Section 4). To resolve this to an extent, many recent attempts have been made to understand the implicit bias of gradient descent (GD) in matrix factorisation in both linear and non-linear networks. (Arora et al., 2019) investigated the behaviour of GD in deep linear networks and found that as the depth of factorisation increases, GD tends to find low rank solutions. They also present evidence for the hypothesis that the language of norms such as nuclear norm, Frobenius norm, etc., may not be enough to describe the behaviour of GD. (Martin & Mahoney, 2018) presented an empirical analysis of commonly used architectures and characterised the dynamics of GD in deep non-linear networks in terms of Empirical Spectral Distributions (ESD) and phases of training. They define a set of rank measures, which we use in our work to analyse low rank training juxtaposed with analysis on unfactorised training. (Wang et al., 2021) used low rank training with unfactorised pretraining in the context of efficient communication in a distributed setting. (Khodak et al., 2021) proposed a low rank training procedure by investigating initialisation and regularisation in factorised layers. They analysed SVD based initialisation (Spectral Initialisation) and properties of L2 regularisation, which we study independently in our work. They conjecture that there is an interplay between normalisation and weight decay and formalise this behaviour through factorised update equations.

3. Low Rank Training

In this section, we present the formulation we choose for factorising layers. We discuss and critique the assumptions and conjectures associated with the low rank formulation in the context of SVD initialisation and L2 regularisation.

3.1. Factorisation

In all our experiments and analyses, we factorise a weight matrix W at each layer into two components U and V such that W = U V^⊤.

We focus on a factorisation depth of 2, taking into consideration memory-speedup tradeoffs: as the depth of factorisation at each layer increases, more activations need to be stored in memory for backpropagation. A depth of two provides speedups across all our experiments while ensuring minimal activation memory overhead.

Consider the difference between the vanilla gradient descent update (unfactorised), W_{t+1} = W_t − α ∇_W, and the update performed in the factorised setting:

    W_{t+1} = U_{t+1} V_{t+1}^⊤
            = (U_t − α ∇_U) (V_t − α ∇_V)^⊤
            = W_t − α (∇_{W_t} V_t V_t^⊤ + U_t U_t^⊤ ∇_{W_t}) + α² ∇_{W_t} W_t^⊤ ∇_{W_t},        (1)

where ∇_t denotes the first-order term in parentheses.
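As a quick sanity check on Eq. (1), the following numpy sketch (our own illustration, using an arbitrary quadratic loss rather than anything from the paper) confirms that one gradient step taken on the factors U and V expands to exactly the update above:

```python
import numpy as np

# Minimal numerical check of Eq. (1): a gradient step on the factors (U, V)
# of W = U V^T equals W - alpha*(dW V V^T + U U^T dW) + alpha^2 * dW W^T dW.
# The quadratic loss L(W) = 0.5 * ||W - T||_F^2 is an illustrative choice only.
rng = np.random.default_rng(0)
m, n, r, alpha = 8, 6, 3, 0.1

U = rng.normal(size=(m, r))
V = rng.normal(size=(n, r))
T = rng.normal(size=(m, n))      # arbitrary target defining the loss

W = U @ V.T
dW = W - T                       # dL/dW for the quadratic loss
dU = dW @ V                      # chain rule through W = U V^T
dV = dW.T @ U

# Update performed in the factorised parameterisation.
W_factorised = (U - alpha * dU) @ (V - alpha * dV).T

# The same update expanded in terms of W, as in Eq. (1).
nabla_t = dW @ V @ V.T + U @ U.T @ dW
W_expanded = W - alpha * nabla_t + alpha**2 * dW @ W.T @ dW

print(np.allclose(W_factorised, W_expanded))  # True
```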
(Khodak et al., 2021) extend the update equation above to normalised layers. Most modern architectures rely on normalisation layers to train networks that generalise well. This includes batch normalisation (Ioffe & Szegedy, 2015) in ResNets and layer normalisation (Ba et al., 2016) in Transformers. We refer the reader to (Khodak et al., 2021) for a more detailed discussion on the type and role of normalisation in factorised layers and use their formulation of the normalised update equation, which is given by

    ŵ_{t+1} = ŵ_t − (α / ‖W‖_F²) (I_{mn} − ŵ_t ŵ_t^⊤) vec(∇̂_t) + O(α²),        (2)

where ∇̂_t is ∇_t with gradients taken with respect to the normalised weight matrix Ŵ = W / ‖W‖_F and ŵ = vec(Ŵ).
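The factor α/‖W‖_F² rescales the step, while the projector (I_{mn} − ŵ_t ŵ_t^⊤) removes the component of the gradient that lies along ŵ_t. The short sketch below (our own illustration, not from the paper) verifies this orthogonality, which means the normalised weights stay on the unit sphere to first order:

```python
import numpy as np

# Illustration of the projector in Eq. (2): (I - w w^T) g is orthogonal to w
# whenever ||w|| = 1, so the first-order step does not change the norm of the
# normalised weights.
rng = np.random.default_rng(1)
mn = 12
w_hat = rng.normal(size=mn)
w_hat /= np.linalg.norm(w_hat)            # plays the role of vec(W / ||W||_F)
g = rng.normal(size=mn)                   # stands in for vec of the gradient
step = (np.eye(mn) - np.outer(w_hat, w_hat)) @ g
print(np.isclose(w_hat @ step, 0.0))      # True: the step is orthogonal to w_hat
```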
We see that gradient descent in the factorised setting does not perfectly align with the vanilla gradient descent update. In the subsequent sections, we empirically explore and work to overcome the implicit biases of this factorised update so that we can make low rank training an effective and efficient training method.

3.1.1. Fully Connected Layer

Let W ∈ R^{m×n} be the weight matrix of a fully-connected layer. We factorise W as W = U V^⊤ with U ∈ R^{m×r} and V^⊤ ∈ R^{r×n}, where 0 < r ≤ min(m, n). At inference, when r < mn/(m+n), factorising the fully connected weight matrix reduces both the memory footprint and the number of floating point operations (flops) from O(mn) to O(mr + rn). For training, the memory requirements change from O(mn + n) to O(mr + rn + n + r), as we need to store the intermediate activations for backpropagation.
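To make the bookkeeping concrete, here is a minimal PyTorch-style sketch of a factorised fully connected layer (our own illustrative module, not code from the paper; the plain random scaling below is a placeholder initialisation, not the spectral initialisation discussed later):

```python
import torch
import torch.nn as nn

class LowRankLinear(nn.Module):
    """Fully connected layer with W = U V^T, holding O(mr + rn) parameters
    instead of O(mn)."""
    def __init__(self, in_features: int, out_features: int, rank: int):
        super().__init__()
        # x @ V projects to the rank-r space, then @ U.T maps to the outputs.
        self.V = nn.Parameter(torch.randn(in_features, rank) / in_features ** 0.5)
        self.U = nn.Parameter(torch.randn(out_features, rank) / rank ** 0.5)
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return (x @ self.V) @ self.U.T + self.bias

# The factorisation only shrinks the layer when r < mn / (m + n); for example,
# with m = 1024 and n = 4096 the break-even rank is 1024*4096/5120 = 819.2.
m, n = 1024, 4096
layer = LowRankLinear(n, m, rank=512)
y = layer(torch.randn(32, n))             # batch of 32 inputs
print(y.shape, m * n / (m + n))           # torch.Size([32, 1024]) 819.2
```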
3.1.2. Convolutional Layer

We factorise convolution kernels in a way that supports rewriting the single convolution as two convolutions.
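One common way to realise such a factorisation, which we assume here purely for illustration (it is not necessarily the paper's exact scheme), is to flatten the k×k kernel into a (c_in·k·k)×c_out matrix, factorise it as U V^⊤, and apply a k×k convolution to r intermediate channels followed by a 1×1 convolution back to c_out channels. The sketch below checks that the two convolutions reproduce a single convolution with the rank-r kernel:

```python
import torch
import torch.nn.functional as F

# Illustrative only: rewrite a k x k convolution with a rank-r kernel as a
# k x k convolution to r channels (from U) followed by a 1 x 1 convolution
# (from V). The truncated SVD is just one way to obtain the factors.
c_in, c_out, k, r = 16, 32, 3, 8
x = torch.randn(1, c_in, 14, 14)

K = torch.randn(c_out, c_in, k, k)                        # dense kernel
W = K.permute(1, 2, 3, 0).reshape(c_in * k * k, c_out)    # flatten to a matrix

U_full, S, Vh = torch.linalg.svd(W, full_matrices=False)
U = U_full[:, :r] * S[:r]                                 # (c_in*k*k, r)
V = Vh[:r, :].T                                           # (c_out, r)

kernel1 = U.T.reshape(r, c_in, k, k)                      # k x k conv -> r channels
kernel2 = V.reshape(c_out, r, 1, 1)                       # 1 x 1 conv -> c_out channels
y_two_convs = F.conv2d(F.conv2d(x, kernel1, padding=1), kernel2)

# Reference: a single convolution with the rank-r kernel U V^T.
K_low_rank = (U @ V.T).T.reshape(c_out, c_in, k, k)
y_single = F.conv2d(x, K_low_rank, padding=1)
print(torch.allclose(y_two_convs, y_single, atol=1e-3))   # True
```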
beliefs about why these techniques work. We hope to put forth the theoretical reasons behind the effectiveness of these techniques in a future work. Additionally, we demonstrated pretraining as an effective strategy to improve low-rank performance and presented insights on the nature of solutions found by networks with pretraining.
References
Achille, A., Rovere, M., and Soatto, S. Critical learning periods in deep neural networks. CoRR, abs/1711.08856, 2017. URL http://arxiv.org/abs/1711.08856.
Arora, S., Cohen, N., Hu, W., and Luo, Y. Implicit regularization in deep matrix factorization, 2019.
Ba, J. L., Kiros, J. R., and Hinton, G. E. Layer normalization, 2016.
Brown, T. B., Mann, B., Ryder, N., Subbiah, M., Kaplan, J., Dhariwal, P., Neelakantan, A., Shyam, P., Sastry, G., Askell, A., Agarwal, S., Herbert-Voss, A., Krueger, G., Henighan, T., Child, R., Ramesh, A., Ziegler, D. M., Wu, J., Winter, C., Hesse, C., Chen, M., Sigler, E., Litwin, M., Gray, S., Chess, B., Clark, J., Berner, C., McCandlish, S., Radford, A., Sutskever, I., and Amodei, D. Language models are few-shot learners, 2020.
Chelba, C., Mikolov, T., Schuster, M., Ge, Q., Brants, T., and Koehn, P. One billion word benchmark for measuring progress in statistical language modeling. CoRR, abs/1312.3005, 2013. URL http://arxiv.org/abs/1312.3005.
Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S., Uszkoreit, J., and Houlsby, N. An image is worth 16x16 words: Transformers for image recognition at scale. CoRR, abs/2010.11929, 2020. URL https://arxiv.org/abs/2010.11929.
Evci, U., Pedregosa, F., Gomez, A. N., and Elsen, E. The difficulty of training sparse neural networks. CoRR, abs/1906.10732, 2019. URL http://arxiv.org/abs/1906.10732.
Fedus, W., Zoph, B., and Shazeer, N. Switch transformers: Scaling to trillion parameter models with simple and efficient sparsity. CoRR, abs/2101.03961, 2021. URL https://arxiv.org/abs/2101.03961.
Frankle, J., Dziugaite, G. K., Roy, D. M., and Carbin, M. The lottery ticket hypothesis at scale. CoRR, abs/1903.01611, 2019a. URL http://arxiv.org/abs/1903.01611.
Frankle, J., Dziugaite, G. K., Roy, D. M., and Carbin, M. Linear mode connectivity and the lottery ticket hypothesis. CoRR, abs/1912.05671, 2019b. URL http://arxiv.org/abs/1912.05671.
He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. CoRR, abs/1512.03385, 2015. URL http://arxiv.org/abs/1512.03385.
Ioffe, S. and Szegedy, C. Batch normalization: Accelerating deep network training by reducing internal covariate shift, 2015.
Jaderberg, M., Vedaldi, A., and Zisserman, A. Speeding up convolutional neural networks with low rank expansions, 2014.
Khodak, M., Tenenholtz, N. A., Mackey, L., and Fusi, N. Initialization and regularization of factorized neural layers. In International Conference on Learning Representations, 2021. URL https://openreview.net/forum?id=KTlJT1nof6d.
Lee, N., Ajanthan, T., Gould, S., and Torr, P. H. S. A signal propagation perspective for pruning neural networks at initialization. CoRR, abs/1906.06307, 2019. URL http://arxiv.org/abs/1906.06307.
Martin, C. H. and Mahoney, M. W. Implicit self-regularization in deep neural networks: Evidence from random matrix theory and implications for learning, 2018.
Radford, A., Wu, J., Child, R., Luan, D., Amodei, D., and Sutskever, I. Language models are unsupervised multitask learners. 2019.
Srebro, N. and Shraibman, A. Rank, trace-norm and max-norm. In Auer, P. and Meir, R. (eds.), Learning Theory, pp. 545–560, Berlin, Heidelberg, 2005. Springer Berlin Heidelberg. ISBN 978-3-540-31892-7.
Tai, C., Xiao, T., Zhang, Y., Wang, X., and E, W. Convolutional neural networks with low-rank regularization, 2016.
Wang, H., Agarwal, S., and Papailiopoulos, D. Pufferfish: Communication-efficient models at no extra cost, 2021.
Yu, X., Liu, T., Wang, X., and Tao, D. On compressing deep models by low rank and sparse decomposition. In IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 67–76, 2017. doi: 10.1109/CVPR.2017.15.
Zagoruyko, S. and Komodakis, N. Wide residual networks. CoRR, abs/1605.07146, 2016. URL http://arxiv.org/abs/1605.07146.
Table 6. Comparison between Frobenius Decay and L2 regularisation on ImageNet.

Rank  Regularisation   lr scaling  Top-1  Top-5
0.5   L2               0.5         75.04  92.36
0.5   L2               1.0         74.83  92.25
0.5   Frobenius Decay  0.5         75.97  92.85
0.5   Frobenius Decay  1.0         76.13  93.09

The corresponding comparison for the transformer LM:

Rank  Regularisation   lr scaling  Perplexity
0.62  L2               0.5         38.87
0.62  L2               1.0         39.01
0.62  Frobenius Decay  0.5         38.78
0.62  Frobenius Decay  1.0         39.2

[Figure 3: plot omitted. Caption: Total parameters vs performance of GPT-2 on LM1B as the model is scaled up. Each point on the line corresponds to a different model size starting from 1024 hidden dimensions (on the top left) to 2560 (in the bottom right) with increments of 256. Axes: Total Parameters (Millions) vs. Perplexity.]
A.5. Pre-training Results

Table 9. Pre-training results for ResNet50 on ImageNet.

Rank  Pretraining  Top-1  Top-5
0.5   5            76.07  92.88
0.5   10           75.96  93.04
0.5   15           76.12  92.96
0.5   20           76.08  92.94
0.5   25           76.15  93.00
0.5   30           76.05  92.9
0.5   35           76.24  93.06
0.5   40           76.21  93.09
0.5   45           76.29  93.12

[Figure 5: plot omitted. Caption: Comparison of interpolation of low rank and pretrained networks for transformer LM. Curves: Low Rank, Pretrain: 40K, Pretrain: 120K; x-axis: Interpolation Step.]