Soft-Label Dataset Distillation and Text Dataset Distillation
Ilia Sucholutsky
Matthias Schonlau
University of Waterloo
Waterloo, Ontario, Canada
Editor:
Abstract
Dataset distillation is a method for reducing dataset sizes by learning a small number of
synthetic samples containing all the information of a large dataset. This has several benefits,
such as speeding up model training, reducing energy consumption, and reducing required
storage space. However, each synthetic sample is currently assigned a single ‘hard’ label, and
dataset distillation can currently only be used with image data.
We propose to simultaneously distill both images and their labels, thus assigning each
synthetic sample a ‘soft’ label (a distribution of labels). Our algorithm increases accuracy
by 2-4% over the original algorithm for several image classification tasks. Using ‘soft’
labels also enables distilled datasets to consist of fewer samples than there are classes, as
each sample can encode information for multiple classes. For example, training a LeNet
model with 10 distilled images (one per class) results in over 96% accuracy on MNIST, and
almost 92% accuracy when trained on just 5 distilled images.
We also extend the dataset distillation algorithm to distill sequential datasets, including
text. We demonstrate that text distillation outperforms other methods across multiple
datasets. For example, models attain almost their original accuracy on the IMDB sentiment
analysis task using just 20 distilled sentences.
Our code can be found at https://github.com/ilia10000/dataset-distillation
Keywords: Dataset Distillation, Knowledge Distillation, Neural Networks, Synthetic
Data, Gradient Descent
1. Introduction
The increase in computational requirements for modern deep learning presents a range
of issues. The training of deep learning models has an extremely high energy consump-
tion (Strubell et al., 2019), on top of the already problematic financial cost and time re-
quirement. One path for mitigating these issues is to reduce network sizes. Hinton et al.
(2015) proposed knowledge distillation as a method for imbuing smaller, more efficient net-
works with all the knowledge of their larger counterparts. Instead of decreasing network
size, a second path to efficiency may be to decrease dataset size. Dataset distillation (DD) was recently proposed as a way of distilling the knowledge of an entire training dataset into a small number of synthetic samples that can train models to nearly their original accuracy (Wang et al., 2018).
Figure 1: 10 MNIST images learned by SLDD can train networks with fixed initializations
from 11.13% distillation accuracy to 96.13% (r10 = 97.1). Each image is labeled with its
top 3 classes and their associated logits. The full labels for these 10 images are in Table 1.
Table 1: Learned distilled labels for the 10 distilled MNIST images in Figure 1. Distilled
labels are allowed to take on any real value. If a probability distribution is needed, a softmax
function can be applied to each row.
Distilled Label    Digit:    0      1      2      3      4      5      6      7      8      9
1 2.34 -0.33 0.23 0.04 -0.03 -0.23 -0.32 0.54 -0.39 0.49
2 -0.17 2.58 0.32 0.37 -0.68 -0.19 -0.75 0.53 0.27 -0.89
3 -0.26 -0.35 2.00 0.07 0.08 0.42 0.02 -0.08 -1.09 0.10
4 -0.28 0.04 0.59 2.08 -0.61 -1.11 0.52 0.19 -0.20 0.32
5 -0.11 -0.52 -0.08 0.90 2.63 -0.44 -0.72 -0.39 -0.29 0.87
6 0.25 -0.20 -0.19 0.51 -0.02 2.47 0.62 -0.42 -0.52 -0.63
7 0.42 0.55 -0.09 -1.07 0.83 -0.19 2.16 -0.30 0.26 -0.91
8 0.18 -0.33 -0.25 0.06 -0.91 0.55 -1.17 2.11 0.94 0.47
9 0.46 -0.48 0.24 0.09 -0.78 0.75 0.47 -0.40 2.45 -0.71
10 -0.53 0.52 -0.74 -1.32 1.03 0.23 0.05 0.55 0.31 2.45
Figure 2: Left: An example of a ‘hard’ label where the second class is selected. Center:
An example of a ‘soft’ label restricted to being a valid probability distribution. The second
class has the highest probability. Right: An example of an unrestricted ‘soft’ label. The
second class has the highest weight. ‘Hard’ labels can be derived from unrestricted ‘soft’
labels by applying the softmax function and then setting the highest probability element to
1, and the rest to 0.
Hard label:                   (0, 1, 0, 0, 0, 0, 0, 0, 0, 0)
    ↑ set the largest element to 1 and the rest to 0
Soft label (probabilities):   (0.01, 0.69, 0.02, 0.02, 0.03, 0.05, 0.03, 0.10, 0.01, 0.04)
    ↑ apply softmax
Unrestricted soft label:      (0.8, 5.1, 1.5, 1.5, 2, 2.5, 2, 3.2, 0.8, 2)
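For concreteness, the conversions illustrated in Figure 2 can be written in a few lines of PyTorch (our own illustrative snippet, not part of the distillation code):

```python
import torch
import torch.nn.functional as F

# Unrestricted 'soft' label (the right-hand column of Figure 2).
logits = torch.tensor([0.8, 5.1, 1.5, 1.5, 2.0, 2.5, 2.0, 3.2, 0.8, 2.0])

# Apply softmax to obtain a valid probability distribution (center column).
probs = F.softmax(logits, dim=0)

# Set the largest probability to 1 and the rest to 0 to recover a 'hard' label.
hard = F.one_hot(probs.argmax(), num_classes=logits.numel()).float()

print([round(p, 2) for p in probs.tolist()])  # [0.01, 0.69, 0.02, ..., 0.04]
print(hard)                                   # tensor([0., 1., 0., ..., 0.])
```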
We improve their already impressive results by learning ‘soft’ labels as a part of the distillation
process. The original dataset distillation algorithm uses fixed, or ‘hard’, labels for the
synthetic samples (e.g. the ten synthetic MNIST images each have a label corresponding to
a different digit). In other words, each label is a one-hot vector: a vector where all entries
are set to zero aside from a single entry, the one corresponding to the correct class, which
is set to one. We relax this one-hot restriction and make the synthetic labels learnable.
The resulting distilled labels are thus similar to those used for knowledge distillation as a
single image can now correspond to multiple classes. An example comparing a ‘hard’ label
to a ‘soft’ label is shown in Figure 2. A ‘hard’ label can be derived from a ‘soft’ label by
applying the softmax function and setting the element with the highest probability to one,
while the remaining elements are set to zero. Our soft-label dataset distillation (SLDD)
not only achieves over 96% accuracy on MNIST when using ten distilled images (as seen
in Figure 1), a 2% increase over the state-of-the-art (SOTA), but also achieves almost 92%
accuracy with just five distilled images, which is less than one image per class. In addition
to soft labels, we also extend dataset distillation to the natural language/sequence modeling
domain and enable it to be used with several additional neural network architectures. For
example, we show that Text Dataset Distillation (TDD) can train a custom convolutional
neural network (CNN) (LeCun et al., 1999) with known initialization up to 90% of its
original accuracy on the IMDB sentiment classification task (Maas et al., 2011) using just
two synthetic sentences.
The rest of this work is divided into four sections. In Section 2, we discuss related work in
the fields of knowledge distillation, dataset reduction, and example generation. In Section 3,
we propose improvements and extensions to dataset distillation and associated theory. In
Section 4, we empirically validate SLDD and TDD in a wide range of experiments. Finally,
in Section 5, we discuss the significance of SLDD and TDD, and our outlook for the future.
2. Related Work
2.1 Knowledge Distillation
Dataset distillation was originally inspired by network distillation (Hinton et al., 2015) which
is a form of knowledge distillation or model compression (Bucilu et al., 2006). Network
distillation has been studied in various contexts including when working with sequential
data (Kim and Rush, 2016). Network distillation aims to distill the knowledge of large,
or even multiple, networks into a smaller network. Similarly, dataset distillation aims to
distill the knowledge of large, or even multiple, datasets into a small number of synthetic
samples. ‘Soft’ labels were recently proposed as an effective way of distilling networks by
feeding the output probabilities of a larger network directly to a smaller network (Hinton
et al., 2015), and have previously been studied in the context of different machine learning
algorithms (El Gayar et al., 2006). Our soft-label dataset distillation (SLDD) algorithm
also uses ‘soft’ labels, but these are persistent and learned over the training phase of a
network (rather than being produced during the inference phase as in the case of network
distillation).
Prototype selection methods produce reduced datasets that are typically just subsets of the original training dataset. Prototype generation methods typically
create samples that are not found in the training data; however, these methods are designed
specifically for use with nearest-neighbor classification algorithms.
All the dataset reduction methods discussed above also share another restriction. They
all use fixed labels. Soft-label dataset distillation removes this restriction and allows the
label distribution to be optimized simultaneously with the samples (or prototypes) them-
selves.
Figure 3: kNN models are fitted on 3 points obtained from the Iris flower dataset using four
methods: prototype selection, prototype generation, soft labels, and prototype generation
combined with soft labels. Each column contains 4 steps of the associated method used
to update the 3 points used to fit the associated kNN. The pie charts represent the label
distributions assigned to each of the 3 points. Selection method: A different random
point from each class is chosen to represent its class in each of the steps. Generation
method: The middle point associated with the ’green’ label is moved diagonally in each
step. Soft labels method: The label distribution of the middle point is changed each
step to contain a larger proportion of both other classes. Combined method: The middle
point is simultaneously moved and has its label distribution updated in each step.
Figure 4: kNN model fitted on 2 points obtained using a combination of prototype genera-
tion and soft labels. The pie charts represent the label distributions assigned to each of the
2 points. From left-to-right, in each plot, the locations of the 2 points are slightly shifted
and the values associated with their ‘green’ label are increased. By modifying the location
and soft labels of the 2 points, the space can still be separated into 3 classes.
Consider a training dataset x = {x_i}_{i=1}^N with corresponding labels y = {y_i}_{i=1}^N. Network parameters θ are typically obtained by minimizing the empirical loss over this dataset:

$$\theta^{*} = \arg\min_{\theta} \frac{1}{N}\sum_{i=1}^{N} \ell(x_i, y_i, \theta) \triangleq \arg\min_{\theta} \ell(\mathbf{x}, \mathbf{y}, \theta). \qquad (1)$$
In general, training with stochastic gradient descent (SGD) involves repeatedly sampling
mini-batches of training data and updating network parameters by their error gradient
scaled by learning rate η.
$$\theta_{t+1} = \theta_t - \eta \nabla_{\theta_t} \ell(x_t, y_t, \theta_t) \qquad (2)$$
With dataset distillation, the goal is to perform just one such step while still achieving
the same accuracy. We do this by learning a very small number of synthetic samples x̃ that
minimize L, a one-step loss objective, for θ1 = θ0 − η̃∇θ0 ℓ(x̃, ỹ, θ0):

$$\tilde{x}^{*}, \tilde{\eta}^{*} = \arg\min_{\tilde{x}, \tilde{\eta}} L(\tilde{x}, \tilde{y}, \tilde{\eta}; \theta_0) = \arg\min_{\tilde{x}, \tilde{\eta}} \ell\!\left(\mathbf{x}, \mathbf{y}, \theta_0 - \tilde{\eta}\nabla_{\theta_0}\ell(\tilde{x}, \tilde{y}, \theta_0)\right) \qquad (4)$$
Note that, currently, we are minimizing over x̃ and η̃, but not ỹ, as the distilled labels are
fixed for the original dataset distillation algorithm. We minimize this objective, or in other
words ‘learn the distilled samples’, by using standard gradient descent.
For soft-label dataset distillation, we additionally make the distilled labels ỹ learnable and optimize them jointly with x̃ and η̃:

$$\tilde{x}^{*}, \tilde{y}^{*}, \tilde{\eta}^{*} = \arg\min_{\tilde{x}, \tilde{y}, \tilde{\eta}} L(\tilde{x}, \tilde{y}, \tilde{\eta}; \theta_0) = \arg\min_{\tilde{x}, \tilde{y}, \tilde{\eta}} \ell\!\left(\mathbf{x}, \mathbf{y}, \theta_0 - \tilde{\eta}\nabla_{\theta_0}\ell(\tilde{x}, \tilde{y}, \theta_0)\right) \qquad (5)$$
Algorithm 1a details this soft-label dataset distillation (SLDD) algorithm. We note that
in our experiments, we generally initialize ỹ with the one-hot values that ‘hard’ labels
would have. We found that this tends to increase accuracy when compared to random
initialization, perhaps because it encourages more differentiation between classes early on
in the distillation process.
    η̃ ← η̃ − α∇_η̃ Σ_j L^(j)
11: end for
Output: distilled data x̃; distilled labels ỹ; optimized learning rate η̃
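Only the tail of Algorithm 1a is reproduced above, so the following is a minimal PyTorch sketch of one outer SLDD update implementing Equation (5). All names are ours, and the choice of a KL loss on softmax-normalized soft labels is an assumption rather than the released implementation; x_dist, y_dist, and lr_dist are tensors created with requires_grad=True, and theta0 can be dict(model.named_parameters()) for a fixed initialization θ0.

```python
import torch
import torch.nn.functional as F
from torch.func import functional_call  # PyTorch >= 2.0

def sldd_update(model, theta0, x_dist, y_dist, lr_dist, x_real, y_real, alpha=0.01):
    """One outer SLDD update (Equation (5)): unroll a single inner GD step on
    the distilled data, evaluate the updated weights on real data, and take a
    gradient step on the distilled images, soft labels, and learning rate."""
    # Inner loss on the distilled samples; the soft labels are turned into a
    # distribution with a softmax (one possible choice of soft-label loss).
    logits = functional_call(model, theta0, (x_dist,))
    inner_loss = F.kl_div(F.log_softmax(logits, dim=1),
                          F.softmax(y_dist, dim=1), reduction='batchmean')
    grads = torch.autograd.grad(inner_loss, list(theta0.values()),
                                create_graph=True)

    # theta_1 = theta_0 - lr_dist * grad  (the unrolled inner step).
    theta1 = {name: p - lr_dist * g
              for (name, p), g in zip(theta0.items(), grads)}

    # Outer loss: the updated network evaluated on a real mini-batch.
    outer_loss = F.cross_entropy(functional_call(model, theta1, (x_real,)), y_real)

    # Backpropagate through the inner step into the distilled variables.
    g_x, g_y, g_lr = torch.autograd.grad(outer_loss, (x_dist, y_dist, lr_dist))
    with torch.no_grad():
        x_dist -= alpha * g_x
        y_dist -= alpha * g_y
        lr_dist -= alpha * g_lr
    return outer_loss.item()
```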
To convert a distilled embedding matrix back into text, for every column vector in the matrix, the nearest embedding
vector from the original dictionary must be found. These embedding vectors must then be
converted back into their corresponding words, and those words joined into a sentence. The
resulting algorithm for text dataset distillation (TDD) is detailed in Algorithm 1b which is
a modification of the SLDD Algorithm 1a.
To obtain distilled data that do not depend on a single fixed initialization, the objective can instead be minimized in expectation over a distribution of initializations p(θ0):

$$\tilde{x}^{*}, \tilde{y}^{*}, \tilde{\eta}^{*} = \arg\min_{\tilde{x}, \tilde{y}, \tilde{\eta}} \mathbb{E}_{\theta_0 \sim p(\theta_0)}\, L(\tilde{x}, \tilde{y}, \tilde{\eta}; \theta_0) \qquad (6)$$
The resulting images, especially for MNIST, appear to have much clearer patterns and much
less random noise, and the results detailed in Section 4 suggest that this method generalizes
fairly well to other randomly sampled initializations from the same distribution.
Additionally, Wang et al. (2018) suggest that the above methods can work with multiple
gradient descent (GD) steps. If we want to perform multiple gradient descent steps, each
with a different mini-batch of distilled data, we simply backpropagate the gradient through
every one of these additional steps. Finally, it may also be beneficial to train the neural
networks on the distilled data for more than one epoch. The experimental results suggest
that multiple steps and multiple epochs improve distillation performance for both image
and text data, particularly when using random network initializations.
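A sketch of how the single unrolled step from the previous snippet extends to multiple GD steps and epochs: the distilled data are split into per-step mini-batches (our own convention; a per-step learning rate could also be used), and the computation graph is kept through every update so that the outer loss can be backpropagated through all of them.

```python
import torch
import torch.nn.functional as F
from torch.func import functional_call

def unroll_distilled_steps(model, theta0, dist_batches, lr_dist, epochs=1):
    """Unroll several inner GD steps, one per distilled mini-batch, repeated
    for several epochs, keeping the graph so the outer loss can backpropagate
    through every step. `dist_batches` is a list of (x_dist, y_dist) pairs."""
    theta = dict(theta0)
    for _ in range(epochs):
        for x_d, y_d in dist_batches:
            logits = functional_call(model, theta, (x_d,))
            inner = F.kl_div(F.log_softmax(logits, dim=1),
                             F.softmax(y_d, dim=1), reduction='batchmean')
            grads = torch.autograd.grad(inner, list(theta.values()),
                                        create_graph=True)
            theta = {k: p - lr_dist * g
                     for (k, p), g in zip(theta.items(), grads)}
    return theta  # differentiable w.r.t. every distilled batch and lr_dist
```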
4. Experiments
4.1 Metrics
The simplest metric for gauging distillation performance is to train a model on distilled sam-
ples and then test it on real samples. We refer to the accuracy achieved on these real samples
as the ‘distillation accuracy’. However, several of the models we use in our experiments do
not achieve SOTA accuracy on the datasets they are paired with, so it is useful to construct
a relative metric that compares distillation accuracy to original accuracy. The first such
metric is the ‘distillation ratio’ which we define as the ratio of distillation accuracy to origi-
nal accuracy. The distillation ratio is heavily dependent on the number of distilled samples, so the notation we use is r_M = 100% × [distillation accuracy] / [original accuracy], where M is the number of distilled samples. We may refer to this metric as the ‘M-sample distillation ratio’ when clarification is needed.
It may also be of interest to find the minimum number of distilled images required to achieve
a certain distillation ratio. To this end we can define a second relative metric that we call the
‘A% distillation size’, and we write d_A = M where M is the minimum number of distilled
samples required to achieve a distillation ratio of A%.
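These two metrics can be computed with a small helper (our own illustration; the roughly 99% original accuracy used in the example is only implied by the numbers reported in Figure 1):

```python
def distillation_ratio(distill_acc, original_acc):
    """M-sample distillation ratio r_M, expressed as a percentage."""
    return 100.0 * distill_acc / original_acc

def distillation_size(results, original_acc, target_ratio):
    """A% distillation size d_A: the smallest M whose distillation accuracy
    reaches target_ratio percent of the original accuracy.
    `results` maps M (number of distilled samples) -> distillation accuracy."""
    feasible = [m for m, acc in results.items()
                if distillation_ratio(acc, original_acc) >= target_ratio]
    return min(feasible) if feasible else None

# Example: the 10-image MNIST result from Figure 1 (96.13% vs. ~99% original).
print(round(distillation_ratio(96.13, 99.0), 1))  # -> 97.1
```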
    η̃ ← η̃ − α∇_η̃ Σ_j L^(j)
13: end for
14: for i = 1 to M do
15:     Compute the nearest embedding for every distilled word: x̃*_i = {NearestEmbed(x̃_{i,j})}_{j=1}^{s}
16:     Decode the embeddings into text: z̃_i = {WordFromEmbed(x̃*_{i,j})}_{j=1}^{s}
17: end for
18: z̃ = {z̃_i}_{i=1}^{M}
Output: distilled data x̃; distilled labels ỹ; optimized learning rate η̃; nearest sentences z̃
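Steps 14–17 above amount to a nearest-neighbour lookup in the embedding dictionary. A minimal NumPy sketch of this decoding step follows, with emb_matrix and vocab standing in (hypothetically) for the GloVe embedding matrix and its aligned word list:

```python
import numpy as np

def decode_sentence(x_dist_i, emb_matrix, vocab):
    """Map each distilled word vector to the nearest dictionary embedding
    (Euclidean distance) and return the corresponding words (z~_i).
    x_dist_i:   (s, d) distilled embedding matrix for one sentence.
    emb_matrix: (V, d) embeddings for the V words in the vocabulary.
    vocab:      list of V words aligned with the rows of emb_matrix."""
    # Squared Euclidean distance between every distilled vector and every word.
    d2 = ((x_dist_i[:, None, :] - emb_matrix[None, :, :]) ** 2).sum(-1)
    nearest = d2.argmin(axis=1)           # NearestEmbed(x~_{i,j})
    return [vocab[j] for j in nearest]    # WordFromEmbed(.)
```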
Baselines. We compare against four baseline methods for constructing small sets of training images; two of them are also used to train K-Nearest Neighbors classifiers, giving six baselines in total (Wang et al., 2018).

• Random real images: We randomly sample the same number of real images per
class from the training data. These images are used for two baselines: training neural
networks and training K-Nearest Neighbors classifiers.
• Optimized real images: We sample several sets of random real images as above,
but now we choose the 20% of these sets that have the best performance on training
data. These images are used for one baseline: training neural networks.
• k-means: We use k-means to learn clusters for each class, and keep the resulting
centroids. These images are used for two baselines: training neural networks and
training K-Nearest Neighbors classifiers.
• Average real images: We compute the average image for each class and use it for
training. These images are used for one baseline: training neural networks.
Each of these baseline methods produces a small set of images that can be used to train
models. All four of the baseline methods are used to train and test LeNet and AlexCifarNet
on their respective datasets. Additionally, two of the baseline methods are used to also
train K-Nearest Neighbor classifiers to compare performance against neural networks. The
results for these six baselines, as determined by Wang et al. (2018), are shown in Table 2.
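As a rough sketch (the exact preprocessing used by Wang et al. (2018) may differ), the k-means and average-image baselines could be constructed as follows:

```python
import numpy as np
from sklearn.cluster import KMeans

def kmeans_baseline(images, labels, per_class):
    """k-means baseline: `per_class` centroids for every class."""
    xs, ys = [], []
    for c in np.unique(labels):
        cls_imgs = images[labels == c]
        km = KMeans(n_clusters=per_class, n_init=10)
        km.fit(cls_imgs.reshape(len(cls_imgs), -1))
        xs.append(km.cluster_centers_.reshape((per_class,) + images.shape[1:]))
        ys.append(np.full(per_class, c))
    return np.concatenate(xs), np.concatenate(ys)

def average_baseline(images, labels):
    """Average-image baseline: one mean image per class."""
    classes = np.unique(labels)
    return np.stack([images[labels == c].mean(axis=0) for c in classes]), classes
```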
Fixed initialization. When the network initialization is fixed between the distillation and
training phases, synthetic images produced by dataset distillation result in high distillation
accuracies. The SLDD algorithm produces images that result in equal or higher accuracies
when compared to the original DD algorithm. For example, DD can produce 10 distilled
images that train a LeNet model up to 93.76% accuracy on MNIST (Wang et al., 2018).
Meanwhile, SLDD can produce 10 distilled images that train the same model up to 96.13%
accuracy (Figure 1). The full distilled labels for these 10 images are laid out in Table 1.
SLDD can even produce a tiny set of just 5 distilled images that train LeNet to 91.56%
accuracy. As can be seen in Figure 6, the 90% distillation size (i.e. the minimum number of
images needed to achieve 90% of the original accuracy) of MNIST with fixed initializations
is d_90 = 5, and while adding more distilled images typically increases distillation accuracy,
this begins to plateau after five images. Similarly, SLDD provides a 7.5% increase in 100-
sample distillation ratio (6% increase in distillation accuracy) on CIFAR10 over DD. Based
on these results, detailed further in Table 2, it appears that SLDD is even more effective
than DD at distilling image data into a small number of samples. This intuitively makes
sense as the learnable labels used by SLDD increase the capacity of the distilled dataset for
storing information.
(a) Step 0
(b) Step 5
(c) Step 9
Figure 5: SLDD can learn 100 distilled CIFAR10 images that train networks with fixed
initializations from 12.9% distillation accuracy to 60.0% (r100 = 75.0). Each image is
labeled with its top 3 classes and their associated logits. Only 3 of the 10 steps are shown.
Random initialization. It is also of interest to know whether distilled data are robust to
network initialization. Specifically, we aim to identify if distilled samples store information
only about the network initializations, or whether they can store information contained
within the training data. To this end, we perform experiments by sampling random network
initializations generated using the Xavier Initialization (Glorot and Bengio, 2010). The
distilled images produced in this way are more representative of the training data but
generally result in lower accuracies when models are trained on them. Once again, images
distilled using SLDD lead to higher distillation accuracies than DD when the number of
distilled images is held constant. For example, 100 MNIST images learned by DD result
in accuracies of 79.5 ± 8.1%, while 100 images learned by SLDD result in accuracies of
82.75 ± 2.75%. There is similarly a 3.8% increase in 100-sample distillation ratio (3%
increase in distillation accuracy) when using SLDD instead of DD on CIFAR10 using 100
distilled images each. These results are detailed in Table 2. It is also interesting to note
that the actual distilled images, as seen in Figures 7 and 8, appear to have much clearer
patterns emerging than in the fixed initialization case. These results suggest that DD, and
even more so SLDD, can be generalized to work with random initializations and distill
knowledge about the dataset itself when they are trained this way. All the mean and
standard deviation results for random initializations in Table 2 are derived by testing with
200 randomly initialized networks.
Figure 6: Distillation accuracy on MNIST with LeNet for different distilled dataset sizes.
(a) Step 0
(b) Step 5
(c) Step 9
Figure 7: SLDD can learn 100 distilled MNIST images that train networks with random
initializations from 10.09% ± 2.54% distillation accuracy to 82.75% ± 2.75% (r100 = 83.6).
Each image is labeled with its top 3 classes and their associated logits. Only 3 of the 10
steps are shown.
(a) Step 0
(b) Step 5
(c) Step 9
Figure 8: SLDD can learn 100 distilled CIFAR10 images that train networks with random
initializations from 10.17% ± 1.23% distillation accuracy to 39.82% ± 0.83% (r100 = 49.8).
Each image is labeled with its top 3 classes and their associated logits. Only 3 of the 10
steps are shown.
Table 2: Means and standard deviations of distillation and baseline accuracies on image data. All values are percentages. The
first four baselines are used to train the same neural network as in the distillation experiments. The last two baselines are used
to train a K-Nearest Neighbors classifier. Experiments with random initializations have their results listed in the form [mean ±
standard deviation] and are based on the resulting performance of 200 randomly initialized networks.
SLDD accuracy DD accuracy Used as training data in same # of GD steps Used in K-NN
Fixed Random Fixed Random Rand. real Optim. real k-means Avg. real Rand. real k-means
MNIST 98.6 82.7 ± 2.8 96.6 79.5 ± 8.1 68.6 ± 9.8 73.0 ± 7.6 76.4 ± 9.5 77.1 ± 2.7 71.5 ± 2.1 92.2 ± 0.1
CIFAR10 60.0 39.8 ± 0.8 54.0 36.8 ± 1.2 21.3 ± 1.5 23.4 ± 1.3 22.5 ± 3.1 22.3 ± 0.3 18.8 ± 1.3 29.4 ± 0.3
Table 3: Means and standard deviations of TDD and baseline accuracies on text data using TextConvNet. All values are
percentages. The first four baselines are used to train the same neural network as in the distillation experiments. The last two
baselines are used to train a K-Nearest Neighbors classifier. Each result uses 10 GD steps, aside from IMDB with k-means and
TREC50, which had to be done with 2 GD steps due to GPU memory constraints and insufficient training samples for some
classes in TREC50. The second TREC50 row uses TDD with 5 GD steps and 4 sentences per class. Experiments with random
initializations have their results listed in the form [mean ± standard deviation] and are based on the resulting performance of
200 randomly initialized networks.
TREC50 57.6 11.0 ± 0.0 8.2 ± 6.0 9.9 ± 6.6 14.7 ± 5.5 12.5 ± 6.4 15.4 ± 5.1 45.1 ± 6.6
TREC502 67.4 42.1 ± 2.1
Table 5: Distillation ratios for text datasets and their associated neural networks. Exper-
iments with random initializations have their results listed in the form [mean ± standard
deviation] and are based on the resulting performance of 200 randomly initialized networks.
Baselines. We consider the same six baselines as in the image case but modify them
slightly so that they work with text data.
• Random real sentences: We randomly sample the same number of real sentences
per class, pad/truncate them, and look up their embeddings. These sentences are
used for two baselines: training neural networks and training K-Nearest Neighbors
classifiers.
• k-means: First, we pre-process the sentences. Then, we use k-means to learn clusters
for each class, and use the resulting centroids to train. These sentences are used for
two baselines: training neural networks and training K-Nearest Neighbors classifiers.
Each of these baseline methods produces a small set of sentences, or sentence embeddings,
that can be used to train models. All four of the baseline methods are used to train and test
the TextConvNet on each of the text datasets. Additionally, two of the baseline methods
are used to also train K-Nearest Neighbor classifiers to compare performance against neural
networks. The baseline results are shown in Table 3.
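The preprocessing assumed by these baselines, padding or truncating each tokenized sentence to a fixed length s and stacking its GloVe vectors into an (s, d) matrix, can be sketched as follows; glove is a hypothetical word-to-vector dictionary, not an object from the released code:

```python
import numpy as np

def embed_sentence(tokens, glove, s, d, pad_token="<pad>"):
    """Pad/truncate a tokenized sentence to length s and look up embeddings,
    giving the (s, d) matrix that both the baselines and TDD operate on."""
    tokens = (tokens + [pad_token] * s)[:s]
    unknown = np.zeros(d)
    return np.stack([glove.get(tok, unknown) for tok in tokens])

# The k-means baseline then clusters the flattened (s*d)-dimensional sentence
# matrices of each class and keeps the resulting centroids, as in the image case.
```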
Fixed initialization. When the network initialization is fixed between the distillation
and training phases, synthetic text produced by text dataset distillation also results in
high model accuracies. For example, TDD can produce 2 distilled sentences that train the
TextConvNet up to a distillation ratio of 89.88% on the IMDB dataset. Even for far more
difficult language tasks, TDD still has impressive results but with larger distilled datasets.
For example, for the 50-class TREC50 task, it can produce 1000 distilled sentences that train
the TextConvNet to a distillation ratio of 79.86%. Some examples of TDD performance are
detailed in Table 3 and Table 5. The distilled text embeddings from the six-sentence TREC6
experiments are visualized in Figure 9. However, since these distilled text embeddings are
still in the GloVe embedding space, it may be difficult to interpret them visually. We provide
a more natural method for analyzing distilled sentences by using nearest-word decoding to
reverse the GloVe embedding. We find the nearest word to each distilled vector based
on Euclidean distances. The result of this decoding is an approximation of the distilled
sentence in the original text space. We list the decoded distilled sentences corresponding
to the matrices from Figure 9 in Table 6, along with their respective label distributions.
These sentences can contain any tokens found in the TREC6 dataset, including punctuation,
numbers, abbreviations, etc. The sentences do not have much overlap. This is consistent
with the distilled labels which suggest that each sentence corresponds strongly to a different
class. It appears that the TDD algorithm encourages the separation of classes, at least when
there are enough distilled samples to have one or more per class. Additional results and
visualizations for TDD with fixed initialization can be found in the online appendix.
Figure 9: TDD can learn 6 distilled sentences of length 30 that train networks with fixed
initializations from 12.6% to 87.4% (r6 = 97.8). Each image corresponds to a distilled
embedding and is labeled with its top 3 classes and their associated logits.
Table 6: Nearest sentence decodings corresponding to the distilled embeddings in Figure 9. Each sentence is accompanied by its distilled label (logits for TREC6 classes 0–5).

1. allan milk banned yellow planted successfully introduced bombay 1936 grass mines iron delhi 1942 male heir throne oath clouds 7th occur millennium smoking flows truth powder judiciary pact slim profit
   Labels (0–5): 2.72, -0.48, -0.07, -0.62, -0.53, -0.27

2. whom engineer grandfather joan officer entered victoria 1940s taxi romania motorcycle italian businessman photographer powerful driving u brilliant affect princess 1940s enemies conflicts southwestern retired cola appearances super dow consumption
   Labels (0–5): -0.05, 3.21, -1.15, -0.79, -0.71, -0.64

3. necessarily factors pronounced pronounced define bow destroying belonged balls 1923 storms buildings 1925 victorian sank dragged reputation sailed nn occurs darkness blockade residence traveled banner chef ruth rick lion psychology
   Labels (0–5): -0.67, -0.66, 3.28, -0.27, -0.77, 0.47

4. accommodate accommodate peak 2.5 adults thin teenagers hike aged nurse policeman admit aged median philippines define baghdad libya ambassador admit baseman burma inning bills trillion donor fined visited stationed clean
   Labels (0–5): -0.98, -0.14, -1.12, 5.57, -0.85, -1.85

5. suburb ports adjacent mountains nearest compare hilton volcano igor nebraska correspondent 1926 suburb sailed hampshire hampshire gathering lesson proposition metric copy carroll sacred moral lottery whatever fix o completed ultimate
   Labels (0–5): -0.29, -0.22, -0.36, -1.01, 3.86, -0.44

6. advertising racism excuse d nancy solved continuing congo diameter oxygen accommodate provider commercials spread pregnancy mideast ghana attraction volleyball zones kills partner serves serves congressman advisory displays ranges profit evil
   Labels (0–5): 0.94, 0.53, -0.85, 0.08, -0.83, 1.99
Figure 10: TDD can learn 2 distilled sentences of length 400 that train networks with
random initializations from 50.0% to 69.6% ± 5.5% (r2 = 79.96). Each image corresponds
to a distilled embedding and is labeled with its soft label value. The soft label is a scalar
as this is a binary classification task.
Random initialization. When using random initialization the performance decrease for
TDD is similar to that for SLDD. TDD can produce two distilled sentences that train the
TextConvNet with random initialization up to a distillation ratio of 79.96% on IMDB, or
20 distilled sentences that train it up to a distillation ratio of 85.22%. This is only slightly
lower performance than in the fixed initialization case. However, there is a larger difference
in performance between fixed and random initializations for the recurrent networks. For
TREC6, TDD can produce six distilled sentences that train a randomly initialized Bi-LSTM
to a distillation ratio of 69.33%, or 120 distilled sentences that train it to a distillation ratio
of 78.87%. All the mean and standard deviation results for random initializations in Table 3
and Table 5 are derived by testing with 200 randomly initialized networks. The distilled
text embeddings from the IMDB experiment with two distilled sentences are visualized in
Figure 10. We list the decoded distilled sentences corresponding to these matrices in Table 7.
These sentences can contain any tokens found in the IMDB dataset, including punctuation,
numbers, abbreviations, etc. Since this is a binary sentiment classification task, each label
is a scalar. If a probability is needed, a sigmoid function can be applied to the scalar soft
labels. In this case, the distillation algorithm appears to have produced one sentence with
a positive associated sentiment, and one with a negative sentiment. Curiously, the model
appears to have overcome the challenge of having to describe these long sentences with a
single scalar by using duplication. For example, in the second sentence, corresponding to
the negative label, negative words like ‘dump’, ‘stupid’, and ‘shoddy’ are all repeated several
times. In this way, the algorithm is likely assigning lower sentiment scores to these words
than to others, while using only a single label. Additional results for TDD with random
initialization can be found in the online appendix.
5. Conclusion
By introducing learnable distilled labels we have increased distillation accuracy across mul-
tiple datasets by up to 6%. By enabling text distillation, we have also greatly increased the
types of datasets and architectures with which distillation can be used.
Table 7: Nearest sentence decodings corresponding to the distilled embeddings in Figure 10.
Each sentence is accompanied by its associated soft label. Only the first 200 (out of 400)
words are shown for each sentence.
However, even with SLDD and TDD, there are still some limitations to dataset distil-
lation. The network initializations used for both SLDD and TDD all come from the same
distribution, and no testing has yet been done on whether a single distilled dataset can
be used to train networks with different architectures. Further investigations are needed
to determine more precisely how well dataset distillation can be generalized to work with
more variation in initializations, and even across networks with different architectures.
Interestingly, the initialization of distilled labels appears to affect the performance of
dataset distillation. Initializing the distilled labels with ‘hard’ label values leads to better
performance than with random initialization, possibly because it encourages class separation
earlier on in the distillation process. However, it is not immediately clear whether it is better
to separate similar classes (e.g. ‘3’ and ‘8’ in MNIST), thereby increasing the network’s
ability to discern between them, or to instead keep those classes together, thereby allowing
soft-label information to be shared between them. It may be interesting to explore the
dynamics of the distillation process when using a variety of label initialization methods.
We have shown that dataset distillation works with CNNs, bi-directional RNNs, and LSTMs.
There is nothing in the dataset distillation algorithm that would limit it to these network
types. As long as a network has a twice-differentiable loss function and the gradient can be
back-propagated all the way to the inputs, then that network is compatible with dataset
distillation.
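A rough way to test these two requirements on a candidate architecture is to check that a gradient of the parameter gradients can be propagated back to the inputs; a minimal sketch with placeholder model and loss_fn:

```python
import torch

def supports_distillation(model, loss_fn, x, y):
    """Return True if the loss is twice-differentiable and gradients can be
    back-propagated all the way to the inputs, as required above."""
    x = x.clone().requires_grad_(True)
    loss = loss_fn(model(x), y)
    grads = torch.autograd.grad(loss, list(model.parameters()),
                                create_graph=True)   # first-order, graph kept
    probe = sum(g.sum() for g in grads)              # scalar function of the parameter gradients
    try:
        torch.autograd.grad(probe, x)                # second-order path back to the inputs
        return True
    except RuntimeError:
        return False
```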
Another promising direction is to use distilled datasets for speeding up Neural Archi-
tecture Search and other very compute-intensive meta-algorithms. If distilled datasets are
a good proxy for performance evaluation, they can reduce search times by multiple orders
of magnitude. In general, dataset distillation is an exciting new branch of knowledge distil-
lation; improvements may help us not only better understand our datasets but also enable
several applications related to efficient machine learning.
Acknowledgments
We would like to thank Dr. Sebastian Fischmeister for providing us with the computational
resources that enabled us to perform many of the experiments found in this work.
References
Anelia Angelova, Yaser Abu-Mostafam, and Pietro Perona. Pruning training sets for learn-
ing of object categories. In 2005 IEEE Computer Society Conference on Computer Vision
and Pattern Recognition (CVPR’05), volume 1, pages 494–501. IEEE, 2005.
Olivier Bachem, Mario Lucic, and Andreas Krause. Practical coreset constructions for
machine learning. arXiv preprint arXiv:1703.06476, 2017.
Yunjey Choi, Minje Choi, Munyoung Kim, Jung-Woo Ha, Sunghun Kim, and Jaegul Choo.
StarGAN: Unified generative adversarial networks for multi-domain image-to-image translation. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 8789–8797. IEEE, 2018.
David A Cohn, Zoubin Ghahramani, and Michael I Jordan. Active learning with statistical
models. Journal of Artificial Intelligence Research, 4:129–145, 1996.
Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. BERT: Pre-
training of deep bidirectional Transformers for language understanding. arXiv preprint
arXiv:1810.04805, 2018.
Neamat El Gayar, Friedhelm Schwenker, and Günther Palm. A study of the robustness of
kNN classifiers trained using soft labels. In IAPR Workshop on Artificial Neural Networks
in Pattern Recognition, pages 67–80. Springer, 2006.
Salvador Garcia, Joaquin Derrac, Jose Cano, and Francisco Herrera. Prototype selection
for nearest neighbor classification: Taxonomy and empirical study. IEEE Transactions on
Pattern Analysis and Machine Intelligence, 34(3):417–435, 2012.
Xavier Glorot and Yoshua Bengio. Understanding the difficulty of training deep feedforward
neural networks. In Proceedings of the Thirteenth International Conference on Artificial
Intelligence and Statistics, pages 249–256. PMLR, 2010.
Ian Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil
Ozair, Aaron Courville, and Yoshua Bengio. Generative adversarial nets. In Advances in
Neural Information Processing Systems, pages 2672–2680. Curran Associates, Inc., 2014.
Geoffrey Hinton, Oriol Vinyals, and Jeff Dean. Distilling the knowledge in a neural network.
arXiv preprint arXiv:1503.02531, 2015.
Sepp Hochreiter and Jürgen Schmidhuber. Long short-term memory. Neural computation,
9(8):1735–1780, 1997.
Yoon Kim and Alexander M Rush. Sequence-level knowledge distillation. arXiv preprint
arXiv:1606.07947, 2016.
Alex Krizhevsky, Geoffrey Hinton, et al. Learning multiple layers of features from tiny
images. Technical report, Citeseer, 2009. URL https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf.
Yann LeCun, Léon Bottou, Yoshua Bengio, Patrick Haffner, et al. Gradient-based learning
applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.
Yann LeCun, Patrick Haffner, Léon Bottou, and Yoshua Bengio. Object Recognition with
Gradient-Based Learning, pages 319–345. Springer Berlin Heidelberg, Berlin, Heidelberg,
1999. ISBN 978-3-540-46805-9. doi: 10.1007/3-540-46805-6_19. URL https://doi.org/10.1007/3-540-46805-6_19.
Christian Ledig, Lucas Theis, Ferenc Huszár, Jose Caballero, Andrew Cunningham, Ale-
jandro Acosta, Andrew Aitken, Alykhan Tejani, Johannes Totz, Zehan Wang, et al.
Photo-realistic single image super-resolution using a generative adversarial network. In
Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages
4681–4690, 2017.
Chunyuan Li, Heerad Farkhoor, Rosanne Liu, and Jason Yosinski. Measuring the intrinsic
dimension of objective landscapes. arXiv preprint arXiv:1804.08838, 2018.
Xuezhe Ma and Eduard Hovy. End-to-end sequence labeling via bi-directional LSTM-CNNs-
CRF. arXiv preprint arXiv:1603.01354, 2016.
Andrew L. Maas, Raymond E. Daly, Peter T. Pham, Dan Huang, Andrew Y. Ng, and
Christopher Potts. Learning word vectors for sentiment analysis. In Proceedings of the
49th Annual Meeting of the Association for Computational Linguistics: Human Language
Technologies, pages 142–150, Portland, Oregon, USA, June 2011. Association for Com-
putational Linguistics.
Jeffrey Pennington, Richard Socher, and Christopher Manning. Glove: Global vectors for
word representation. In Proceedings of the 2014 Conference on Empirical Methods in
Natural Language Processing (EMNLP), pages 1532–1543, 2014.
Matthew E Peters, Mark Neumann, Mohit Iyyer, Matt Gardner, Christopher Clark, Kenton
Lee, and Luke Zettlemoyer. Deep contextualized word representations. arXiv preprint
arXiv:1802.05365, 2018.
Alec Radford, Luke Metz, and Soumith Chintala. Unsupervised representation learning with
deep convolutional generative adversarial networks. arXiv preprint arXiv:1511.06434,
2015.
Scott Reed, Zeynep Akata, Xinchen Yan, Lajanugen Logeswaran, Bernt Schiele, and
Honglak Lee. Generative adversarial text to image synthesis. arXiv preprint
arXiv:1605.05396, 2016.
Mike Schuster and Kuldip K Paliwal. Bidirectional recurrent neural networks. IEEE Transactions on Signal Processing, 45(11):2673–2681, 1997.
Ozan Sener and Silvio Savarese. Active learning for convolutional neural networks: A
core-set approach. arXiv preprint arXiv:1708.00489, 2017.
Richard Socher, Alex Perelygin, Jean Wu, Jason Chuang, Christopher D Manning, Andrew
Ng, and Christopher Potts. Recursive deep models for semantic compositionality over
a sentiment treebank. In Proceedings of the 2013 Conference on Empirical Methods in
Natural Language Processing, pages 1631–1642, 2013.
Emma Strubell, Ananya Ganesh, and Andrew McCallum. Energy and policy considerations
for deep learning in NLP. arXiv preprint arXiv:1906.02243, 2019.
Simon Tong and Daphne Koller. Support vector machine active learning with applications
to text classification. Journal of Machine Learning Research, 2(Nov):45–66, 2001.
Isaac Triguero, Joaquı́n Derrac, Salvador Garcia, and Francisco Herrera. A taxonomy and
experimental study on prototype generation for nearest neighbor classification. IEEE
Transactions on Systems, Man, and Cybernetics, Part C (Applications and Reviews), 42
(1):86–100, 2011.
Ivor W Tsang, James T Kwok, and Pak-Ming Cheung. Core vector machines: Fast SVM
training on very large data sets. Journal of Machine Learning Research, 6(Apr):363–392,
2005.
Ellen M Voorhees et al. The TREC-8 question answering track report. In the Proceedings
of the Eighth Text Retrieval Conference (TREC-8), volume 99, pages 77–82, 1999.
Li Wan, Matthew Zeiler, Sixin Zhang, Yann Le Cun, and Rob Fergus. Regularization of
neural networks using dropconnect. In International Conference on Machine Learning,
pages 1058–1066, 2013.
Tongzhou Wang, Jun-Yan Zhu, Antonio Torralba, and Alexei A Efros. Dataset distillation.
arXiv preprint arXiv:1811.10959, 2018.
Lantao Yu, Weinan Zhang, Jun Wang, and Yong Yu. SeqGAN: Sequence generative adversarial nets with
policy gradient. In AAAI-17: Thirty-First AAAI Conference on Artificial Intelligence,
volume 31, pages 2852–2858. Association for the Advancement of Artificial Intelligence
(AAAI), 2017.