Two-Step Knowledge Distillation For Tiny Speech Enhancement
∗ These authors contributed equally to this work.

ABSTRACT

Tiny, causal models are crucial for embedded audio machine learning applications. Model compression can be achieved via distilling knowledge from a large teacher into a smaller student model. In this work, we propose a novel two-step approach for tiny speech enhancement model distillation. In contrast to the standard approach of a weighted mixture of distillation and supervised losses, we firstly pre-train the student using only the knowledge distillation (KD) objective, after which we switch to a fully supervised training regime. We also propose a novel fine-grained similarity-preserving KD loss, which aims to match the student's intra-activation Gram matrices to those of the teacher. Our method demonstrates broad improvements, but particularly shines in adverse conditions including high compression and low signal-to-noise ratios (SNR), yielding signal-to-distortion ratio gains of 0.9 dB and 1.1 dB, respectively, at -5 dB input SNR and 63× compression compared to baseline.

Index Terms— speech enhancement, knowledge distillation, tinyML, model compression

1. INTRODUCTION

In recent years, deep neural network (DNN) models have become a common approach to the speech enhancement (SE) problem, due to their performance [1, 2, 3]. However, large, powerful models are often unsuitable for resource-constrained platforms, like hearing aids or wearables, because of their memory footprint, computational latency, and power consumption [2, 4, 5, 6]. To meet these constraints, audio TinyML research tends to focus on designing model architectures with small numbers of parameters, using model compression techniques to reduce the size of large models, or both [4, 5, 6, 7].

Pruning is a popular method for reducing the size of DNN models for SE [4, 5, 6, 8]. The goal of pruning is to remove the weights contributing least to model performance. In its simplest form, this can be performed post-training by removing weights with the lowest magnitudes. Online pruning, where the model is trained and pruned concurrently, builds on post-training pruning by exposing the model to pruning errors while training, allowing it to adapt to this form of compression noise [4]. Unstructured pruning of individual weights can yield impressive model size reduction with little performance sacrifice, but corresponding savings in computational throughput are not possible without hardware support for sparse inference, which is unusual in embedded hardware. Structured pruning of blocks of weights and/or neurons is often designed with broader hardware compatibility in mind, but the performance drop tends to be larger than for the unstructured case [6].

In contrast to pruning, knowledge distillation (KD) adopts a different framework. The goal of KD is to utilize a strong pre-trained teacher model to guide the training of the smaller student [9, 10, 11]. The underlying assumption is that the pre-trained teacher network offers additional useful context compared to the ground truth data by itself. Unlike pruning, this process does not involve modifying the student network from its original dense form, which reduces the complexity of the deployment process. In this work, we focus on KD due to its above-outlined benefits over pruning.

KD methods have been applied to various classification tasks in the audio domain [12, 13]. However, KD has not been extensively explored for causal low-latency SE, which often requires tiny networks (sub-100k parameters) optimized for low-resource wearable devices, such as hearing aids [5, 6]. So-called response-based KD approaches use the pre-trained teacher model's outputs to train a student network [14, 15]. However, distillation can be further facilitated by taking advantage of intermediate representations of the two models, not just their outputs [10]. One common obstacle in such feature-based KD is the dimensionality mismatch between teacher and student activations due to the model size difference. To alleviate this issue, [16] proposed aligning intermediate features, while [17] used attention maps to do so. The latter was applied in the context of SE in [18] using considerably large, non-causal student models intended for offline applications. In [19], the authors addressed the dimensionality mismatch problem for causal SE models by using frame-level Similarity Preserving KD (SPKD) [20]. SPKD captures the similarity patterns between network activations for different training examples and aims to match those patterns between the student and the frozen pre-trained teacher models. The authors of [19] also introduced fusion blocks, analogous to [21], to distill relationships between consecutive layers.

Here, we show that the efficacy of conventional KD methods is limited for tiny, causal SE models. To improve distillation efficacy, we first extend the method from [19] by computing SPKD for each bin of the latent representations, corresponding to the time frame (as in [19]) but also the frequency bin of the input, thus providing more resolution for KD loss optimization. The proposed extension outperforms other similarity-based KD methods which we also explore. Second, we hypothesize that matching a large teacher model might be challenging for small student models and thus may lead to suboptimal performance. Inspired by [22], we propose a novel two-step framework for distilling tiny SE models. In the first step, the student is pre-trained using only the KD criterion to match the activation patterns of the teacher, with no additional ground truth supervision. The goal of this unsupervised KD pre-training is to make the student similar to the teacher prior to the main training. Then, the pre-trained student model is further optimized in a supervised fashion and/or using KD routines. We find that pre-training using the proposed SPKD method at the level of the individual bin of the latent representation, followed by fully supervised training, yields superior performance compared to other distillation approaches utilizing weighted mixtures of KD and supervised losses. We report the performance of our method across various student model sizes and input mixture signal-to-noise ratios (SNRs), and finally assess the similarity between the activation patterns of the teacher and the distilled student.
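To make the two-step schedule concrete, a minimal sketch is given below, assuming a PyTorch-style setup. The objects `teacher`, `student`, `loader`, `kd_loss` (a similarity-preserving KD criterion as described below), `se_loss` (a standard supervised SE loss), and all hyperparameters are illustrative placeholders, not the exact configuration used in this work.

```python
# Minimal sketch of the proposed two-step distillation schedule.
# `teacher`, `student`, `loader`, `kd_loss`, `se_loss` and the hyperparameters
# are placeholders, not the exact objects or settings used in the paper.
import torch

def two_step_distillation(teacher, student, loader, kd_loss, se_loss,
                          kd_epochs=10, sup_epochs=100, lr=1e-3):
    teacher.eval()                                   # teacher stays frozen
    opt = torch.optim.Adam(student.parameters(), lr=lr)

    # Step 1: KD-only pre-training -- match the teacher's activation
    # patterns, with no ground-truth supervision.
    for _ in range(kd_epochs):
        for noisy, clean in loader:
            with torch.no_grad():
                _, t_acts = teacher(noisy, return_activations=True)
            _, s_acts = student(noisy, return_activations=True)
            loss = kd_loss(t_acts, s_acts)
            opt.zero_grad(); loss.backward(); opt.step()

    # Step 2: fully supervised training of the KD pre-trained student.
    for _ in range(sup_epochs):
        for noisy, clean in loader:
            s_out, _ = student(noisy, return_activations=True)
            loss = se_loss(s_out, clean)
            opt.zero_grad(); loss.backward(); opt.step()
    return student
```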
Fig. 1: (a) Distillation process overview. (b) Self-Similarity Gram matrices computation. (c) Flow matrices computation (⊗ denotes matrix multiplication). Note that, for clarity, transpositions and matrix multiplications are applied only to the last two dimensions of each tensor.
2.1. Model architecture

Our backbone architecture for the exploration of tiny SE KD is the Convolutional Recurrent U-Net for SE (CRUSE) topology [7]. However, note that the methodology developed here can, in principle, be applied to any other architecture. The CRUSE model operates in the time-frequency domain and takes power-law compressed log-mel spectrograms (LMS) as input. The LMS is obtained by taking the magnitude of the complex short-time Fourier transform (STFT, 512/256 samples frame/hop size), processing it through a Mel-space filterbank (80 bins, covering the 50 Hz - 8 kHz range) and finally compressing the result by raising it to the power of 0.3. The model output is a real-valued time-frequency mask bounded within the range (0, 1) through the sigmoid activation of the final block. The mask is applied to the noisy model input and reconstituted into the time domain using the inverse STFT and the noisy input phase.

The model comprises four encoder/decoder blocks and a bottleneck with grouped GRU units (4 groups), reducing the computational complexity compared to a conventional GRU layer with the same number of units [23]. The encoder/decoder blocks are composed of 2D convolution/transpose convolution layers with (2, 3) kernels (time, frequency) and (1, 2) strides, followed by cumulative layer normalization [24] and leaky ReLU activation (α = 0.2). To further reduce the model complexity, skip connections between the encoder and decoder used in the classic U-Net are replaced with 1x1 convolutions, whose outputs are summed into the decoder inputs. We enforce the model's frame-level causality by using causal convolutions and causal cumulative layer norms. The total algorithmic latency of the model is 32 ms (single STFT frame) [2].

In our experiments, both teacher and student are CRUSE models and their sizes are adjusted by changing the number of units in each block. In particular, the teacher uses {32, 64, 128, 192} encoder/decoder channels and 960 bottleneck GRU units, resulting in 1.9M parameters and 13.34 MOps/frame (i.e. the number of operations required to process a single STFT frame). Our default student uses {8, 16, 32, 32} encoder/decoder channels and 160 bottleneck GRU units, resulting in 62k parameters (3.3% of the teacher) and 0.84 MOps/frame (6.3% of the teacher).
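As an illustration, the LMS front end described above can be sketched as follows, assuming 16 kHz input audio and the torchaudio MelSpectrogram transform; this is a sketch of the described pipeline, not the authors' implementation.

```python
# Sketch of the power-law compressed log-mel (LMS) front end described above.
# Assumes 16 kHz audio and torchaudio; not the authors' exact implementation.
import torch
import torchaudio

mel = torchaudio.transforms.MelSpectrogram(
    sample_rate=16000,   # assumed sampling rate
    n_fft=512,           # 512-sample STFT frame
    hop_length=256,      # 256-sample hop
    f_min=50.0,          # filterbank covers 50 Hz ...
    f_max=8000.0,        # ... to 8 kHz
    n_mels=80,           # 80 mel bins
    power=1.0,           # magnitude (not power) spectrogram, as described above
)

def lms_features(waveform: torch.Tensor) -> torch.Tensor:
    """waveform: [batch, samples] -> LMS features: [batch, 80, frames]."""
    return mel(waveform) ** 0.3   # power-law compression with exponent 0.3
```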
2.2. Similarity-preserving knowledge distillation

Inspired by previous work [19, 20], we address the issue of dimensionality mismatch between teacher and student models by computing similarity-based distillation losses. The method captures and compares the relationship between batch items at each layer output, between teacher and student (Fig. 1a, L_KD^local). We refer to this relationship as the self-similarity Gram matrix G_x.

Self-similarity matrices (Fig. 1b) can be computed for an example network latent activation X of shape [b, c, t, f], where b is the batch size, c the channel, t the activation width (corresponding to the input time dimension), and f the activation height (corresponding to the input frequency dimension), as shown in Fig. 1b. The original implementation from [20] involves reshaping X to [b, ctf] and matrix multiplying it by its transpose X^T to obtain the [b, b] symmetric self-similarity matrix G. Analogously, this operation can be performed for each t or f dimension independently, with resulting G_{t/f} matrices of size [t/f, b, b]. Such an increase in granularity improved the KD performance in [19]. Here, we obtain even more detailed intra-activation Gram matrices by considering each (t, f) bin separately, resulting in the G_tf self-similarity matrix with shape [t, f, b, b].

Finally, the local KD loss is computed using self-similarity matrices G_x of any kind x, obtained from teacher T and student S, as:

\mathcal{L}_{KD}^{local} = \frac{1}{b^2} \sum_i \left\| G^{T}_{x_i} - G^{S}_{x_i} \right\|_F^2 ,    (1)

where i is the layer index and ‖·‖_F^2 is the squared Frobenius norm.
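For clarity, the self-similarity matrices and the local KD loss of Eq. 1 can be sketched as follows, assuming PyTorch activations of shape [b, c, t, f]; this is an illustrative implementation of the description above, not the authors' code.

```python
# Sketch of the self-similarity Gram matrices of Fig. 1b and the local KD
# loss of Eq. 1, for activations of shape [b, c, t, f]; illustrative only.
import torch

def gram_batch(x):            # G:    [b, b]       (original SPKD [20])
    b = x.shape[0]
    flat = x.reshape(b, -1)                     # [b, c*t*f]
    return flat @ flat.T

def gram_t(x):                # G_t:  [t, b, b]    (frame-level SPKD [19])
    b, c, t, f = x.shape
    xt = x.permute(2, 0, 1, 3).reshape(t, b, c * f)
    return xt @ xt.transpose(1, 2)

def gram_tf(x):               # G_tf: [t, f, b, b] (proposed bin-level SPKD)
    b, c, t, f = x.shape
    xtf = x.permute(2, 3, 0, 1)                 # [t, f, b, c]
    return xtf @ xtf.transpose(2, 3)

def local_kd_loss(teacher_acts, student_acts, gram=gram_tf):
    """Eq. 1: squared Frobenius distance between teacher and student
    self-similarity matrices, summed over layers, scaled by 1/b^2."""
    b = teacher_acts[0].shape[0]
    loss = 0.0
    for xt, xs in zip(teacher_acts, student_acts):
        loss = loss + torch.sum((gram(xt) - gram(xs)) ** 2)
    return loss / b**2
```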
2.3. Information flow knowledge distillation

The above-outlined local similarity losses can be extended to capture relationships between activations of subsequent layers of the teacher and student models (Fig. 1a, L_KD^flow). The method is inspired by the Flow of Solution Procedure (FSP) matrices introduced in [22] and aims to not only match local similarity between the teacher and student in the corresponding layers but also global inter-layer relations.

We propose two versions of flow matrices between layers i and j in our model (Fig. 1c). The first one, G_t^{i→j}, leverages G_t self-similarity matrices. Each self-similarity block thereby shares the t-dimension, and thus the interaction between the layers' self-similarity can be captured by performing matrix multiplication of G_t^i and the transposed G_t^j (both sized [t, b, b]) for each time frame t.

The second version leverages G_tf self-similarity matrices. However, the f dimension in our model changes for each block due to the strided convolutions. To quantify the relationship between layers i and j of different dimensions, we reshape G_tf to the size of [t, b, f_{i/j}, b]. Then, for each time-batch-item pair (t, b), we obtain an [f_{i/j}, b] sub-matrix, which can be matrix multiplied with its transpose to obtain the flow matrix G_tf^{i→j} of size [t, b, f_i, f_j]. We define the loss similarly to Eq. 1 by comparing the teacher G_x^{T,i→j} flow matrix with the student G_x^{S,i→j} flow matrix, of the same kind x, for every 2-layer combination (i, j):

\mathcal{L}_{KD}^{flow} = \frac{1}{b^2} \sum_i \sum_{j>i} \left\| G^{T,i\to j}_{x} - G^{S,i\to j}_{x} \right\|_F^2    (2)
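Analogously, the flow matrices and the loss of Eq. 2 can be sketched as follows, building on the self-similarity matrices above; again an illustrative PyTorch sketch rather than the authors' implementation.

```python
# Sketch of the flow matrices of Fig. 1c and the flow KD loss of Eq. 2,
# built on the self-similarity matrices above; illustrative only.
import torch

def flow_t(g_t_i, g_t_j):
    """G_t flow between layers i and j: both [t, b, b] -> [t, b, b]."""
    return g_t_i @ g_t_j.transpose(1, 2)

def flow_tf(g_tf_i, g_tf_j):
    """G_tf flow between layers i and j with different f sizes:
    [t, f_i, b, b] and [t, f_j, b, b] -> [t, b, f_i, f_j]."""
    a = g_tf_i.permute(0, 2, 1, 3)              # [t, b, f_i, b] sub-matrices
    c = g_tf_j.permute(0, 2, 3, 1)              # [t, b, b, f_j] (transposed)
    return a @ c                                # [t, b, f_i, f_j]

def flow_kd_loss(teacher_grams, student_grams, flow=flow_t):
    """Eq. 2: squared Frobenius distance between teacher and student flow
    matrices over all layer pairs (i, j) with j > i, scaled by 1/b^2."""
    b = teacher_grams[0].shape[-1]
    loss = 0.0
    n = len(teacher_grams)
    for i in range(n):
        for j in range(i + 1, n):
            ft = flow(teacher_grams[i], teacher_grams[j])
            fs = flow(student_grams[i], student_grams[j])
            loss = loss + torch.sum((ft - fs) ** 2)
    return loss / b**2
```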
Table 1: One-step KD for tiny SE. Output: L_KD comparing teacher and student outputs (similar to [15]). G_x: feature-based L_KD using a self-similarity matrix of type x (Fig. 1b). All models are initialized with the same random weights and use γ = 0.5 (Eq. 3).

Model        | ∆SDR (dB) | ∆PESQ (MOS) | ∆eSTOI (%) | ∆DNS-MOS (BAK / OVRL / SIG)
Teacher      | 8.65      | 1.25        | 10.07      | 1.44 / 0.69 / 0.06
Student      | 6.34      | 0.75        | 5.82       | 1.27 / 0.55 / -0.02
Distillation:
Output [15]  | 6.35      | 0.75        | 5.59       | 1.33 / 0.56 / -0.03
G [20]       | 6.32      | 0.75        | 5.70       | 1.29 / 0.56 / -0.02
G_t [19]     | 6.50      | 0.77        | 5.95       | 1.33 / 0.55 / -0.04
G_f          | 6.47      | 0.74        | 6.03       | 1.29 / 0.56 / -0.02
G_tf (ours)  | 6.68      | 0.77        | 5.99       | 1.36 / 0.57 / -0.04

Fig. 2: Block-wise CKA similarity between student and teacher networks, averaged over the MS-DNS test set (three panels; x-axis: Teacher block; colorbar: Similarity, 0.0-0.6). Mean(diag) and Mean(all) denote the average similarity for the corresponding blocks (diagonal) or all the block combinations, respectively.

Results across input mixture SNRs:

SNR (dB) | Model    | ∆SDR (dB) | ∆PESQ (MOS) | ∆eSTOI (%) | ∆DNS-MOS (BAK / OVRL / SIG)
-5       | Teacher  | 14.05     | 0.62        | 19.12      | 2.16 / 1.02 / 0.64
-5       | Student  | 10.82     | 0.30        | 10.07      | 1.86 / 0.79 / 0.51
-5       | Proposed | 11.73     | 0.35        | 11.61      | 1.98 / 0.81 / 0.47
0        | Teacher  | 12.30     | 0.92        | 17.83      | 1.99 / 0.98 / 0.40
0        | Student  | 9.65      | 0.49        | 10.56      | 1.75 / 0.75 / 0.26
0        | Proposed | 10.23     | 0.56        | 11.51      | 1.84 / 0.79 / 0.25
5        | Teacher  | 10.27     | 1.21        | 13.98      | 1.65 / 0.78 / 0.02
5        | Student  | 7.97      | 0.69        | 8.58       | 1.44 / 0.59 / -0.10
5        | Proposed | 8.43      | 0.76        | 9.32       | 1.51 / 0.62 / -0.09