Self-supervised Learning for Large-scale Item Recommendations
Tiansheng Yao, Xinyang Yi, Derek Zhiyuan Cheng, Felix Yu, Ting Chen, Aditya Menon
Lichan Hong, Ed H. Chi, Steve Tjoa, Jieqi (Jay) Kang, Evan Ettinger
Google Inc., United States
ABSTRACT
… embedding space through neural networks for both queries and items from user feedback data. However, with millions to billions of items in the corpus, users tend to provide feedback for a very small set of them, causing a power-law distribution. This makes the feedback data for long-tail items extremely sparse.

Inspired by the recent success of self-supervised representation learning research in both computer vision and natural language understanding, we propose a multi-task self-supervised learning (SSL) framework for large-scale item recommendations. The framework is designed to tackle the label sparsity problem by learning better latent relationships among item features. Specifically, SSL improves item representation learning and also serves as additional regularization to improve generalization. Furthermore, we propose a novel data augmentation method that utilizes feature correlations within the proposed framework.

We evaluate our framework using two real-world datasets with 500M and 1B training examples respectively. Our results demonstrate the effectiveness of SSL regularization and show its superior performance over state-of-the-art regularization techniques. We have also launched the proposed techniques in a web-scale commercial app-to-app recommendation system, with significant improvements in top-tier business metrics demonstrated in A/B experiments on live traffic. Our online results also verify our hypothesis that the framework indeed improves model performance even more on slices that lack supervision.
1 INTRODUCTION
Recently, neural-net models have emerged to the main stage of modern recommendation systems throughout both industry (see, e.g., [18, 31, 39, 42]) and academia ([8, 32]). Compared to conventional approaches like matrix factorization [1, 21, 22], gradient boosted decision trees [4, 29], and logistic regression based recommenders [19], these deep models handle categorical features more effectively. They also enable more complex data representations and introduce more non-linearity to better fit the complex data of recommenders.

A particular recommendation task we focus on in this paper is to identify the most relevant items given a query from a huge item catalog. This general problem of large-scale item recommendation arises in a wide variety of applications. Depending on the type of the query, a recommendation task could be: (i) personalized recommendation, when the query is a user; (ii) item-to-item recommendation, when the query is also an item; and (iii) search, when the query is a piece of free text. To model the interactions between a query and an item, a well-known approach leverages embedding-based neural networks. In the two-tower architecture, a neural network encodes a set of item features into an embedding, thus making it applicable even for indexing cold-start items. Moreover, the two-tower DNN architecture enables efficient serving for a large corpus of items in real time, by converting the top-k nearest neighbor search problem to Maximum-Inner-Product-Search (MIPS) [9], which is solvable in sublinear complexity.

Embedding-based deep models typically have a large number of parameters because they are built with high-dimensional embeddings that represent high-cardinality sparse features such as topics or item IDs. In much of the existing literature, the loss functions for training these models are formulated as supervised learning problems, where the supervision comes from collected labels (e.g., clicks). Modern recommendation systems collect billions to trillions of footprints from users, providing a huge amount of training data for building deep models. However, when it comes to modeling a huge catalogue of items on the order of millions (e.g., songs and apps [28]) to even billions (e.g., videos on YouTube [10]), data can still be highly sparse for certain slices due to:
• Highly-skewed data distribution: The interactions between queries and items are often highly skewed, following a power-law distribution [30], so a small set of popular items receives most of the interactions. This leaves the training data for long-tail items very sparse.
• Lack of explicit user feedback: Users often provide plenty of implicit positive feedback such as clicks and thumbs-up. However, they are much less likely to provide explicit feedback such as item ratings, feedback on user happiness, and relevance scores.

Self-supervised learning (SSL) offers a different angle for improving deep representation learning via unlabeled data. The basic idea is to enhance training data with various data augmentations, and to add supervised tasks that predict or reconstruct the original examples as auxiliary tasks. Self-supervised learning has been widely used in the areas of Computer Vision (CV) [15, 25, 33] and Natural Language Understanding (NLU) [12, 24]. An example work in CV [25] proposed to rotate images at random and train a model to predict how each augmented input image was rotated. In NLU, the masked language modeling task was introduced in the BERT model to help improve the pre-training of language models. Similarly, other pre-training tasks, such as predicting surrounding sentences and linked sentences in Wikipedia articles, have also been used to improve dual-encoder type models in NLU [3]. Compared to conventional supervised learning, self-supervised learning provides complementary objectives that eliminate the prerequisite of manually collecting labels. In addition, SSL enables autonomous discovery of good semantic representations by exploiting the internal relationships of input features.

Despite the wide adoption in computer vision and natural language understanding, …
• Dropout. For categorical features with multiple values, we drop out each value with a certain probability. This further reduces the input information and increases the hardness of the SSL task.

The masking step can be interpreted as a special case of dropout with a 100% dropout rate. One strategy is the complementary masking pattern, where we split the feature set into two mutually exclusive feature sets for the two augmented examples. Specifically, we could randomly split the feature set into two disjoint subsets. We call this method Random Feature Masking (RFM), and will use it as one of our baselines. We now introduce Correlated Feature Masking (CFM), where we further explore feature correlations when creating masking patterns.

Mutual Information of Categorical Features. If the set of masked features is chosen at random, (h, g) are essentially sampled from the 2^k different masking patterns over the whole feature set of k features. Different masking patterns naturally lead to different effects for the SSL task. For instance, the SSL contrastive learning task may exploit the shortcut of highly correlated features …

Heterogeneous Sample Distributions. The marginal item distribution from D_train typically follows a power-law. Therefore, using the training item distribution for L_self would cause the learned feature relationships to be biased towards head items. Instead, we sample items uniformly from the corpus for L_self; in other words, D_item is the uniform item distribution. In practice, we find that using heterogeneous distributions for the main and SSL tasks is critical for SSL to achieve superior performance.

Loss for Main Task. There could be many choices for the main loss depending on the objectives. In this paper, we consider the batch softmax loss used in both recommenders [39] and NLP [16] for optimizing top-k accuracy. In detail, let q_i and x_i be the embeddings of the i-th query and item after being encoded by the two neural networks. Then, for a batch of pairs {(q_i, x_i)}_{i=1}^N and temperature τ, the batch softmax cross-entropy loss is

\mathcal{L}_{\mathrm{main}} = -\frac{1}{N} \sum_{i \in [N]} \log \frac{\exp\left(s(\mathbf{q}_i, \mathbf{x}_i)/\tau\right)}{\sum_{j \in [N]} \exp\left(s(\mathbf{q}_i, \mathbf{x}_j)/\tau\right)}. \qquad (6)
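To make Equation (6) concrete, the following is a minimal sketch of how the batch softmax loss could be implemented in TensorFlow. It is our illustration rather than the authors' code; the names (batch_softmax_loss, query_emb, item_emb) are assumptions, and corrections for in-batch sampling bias that production systems typically apply are omitted.

```python
import tensorflow as tf

def batch_softmax_loss(query_emb, item_emb, temperature=0.07):
    """Batch softmax cross-entropy loss of Equation (6).

    query_emb, item_emb: [N, d] L2-normalized embeddings of the N positive
    (query, item) pairs in the batch. For query i, its paired item i is the
    positive and the other N - 1 in-batch items act as negatives.
    """
    # s(q_i, x_j) / tau for all pairs in the batch: an [N, N] logit matrix.
    logits = tf.matmul(query_emb, item_emb, transpose_b=True) / temperature
    # The positive item for query i is item i, so the label is the row index.
    labels = tf.range(tf.shape(logits)[0])
    losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)
    return tf.reduce_mean(losses)
```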
Figure 3: Model architecture: Two-tower model with SSL. In the SSL task, we apply feature masking and dropout on the item
features to learn item embeddings. The whole item tower (in red) is shared with the supervised task.
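As a concrete companion to Figure 3, the sketch below shows one possible way to wire up a shared item tower with feature masking and dropout for the two augmented examples, a contrastive SSL loss, and the combined multi-task objective (referred to in the text as L = L_main + α · L_self, Equation (5)). This is our own minimal illustration under assumptions: the feature names and vocabulary sizes mirror the AAI features described later, the layer sizes are arbitrary, and the random disjoint split shown corresponds to RFM; CFM would instead group highly correlated features into the same half.

```python
import random
import tensorflow as tf

ITEM_FEATURES = ["id", "developer_name", "categories", "title_unigram"]
VOCAB_SIZES = {"id": 100_000, "developer_name": 20_000,
               "categories": 500, "title_unigram": 50_000}  # illustrative sizes
EMBED_DIM = 64

class ItemTower(tf.keras.Model):
    """Item encoder shared by the supervised and SSL tasks (red tower in
    Figure 3): embed each categorical feature, average multi-valued features,
    concatenate, and apply an MLP."""

    def __init__(self):
        super().__init__()
        self.tables = {f: tf.keras.layers.Embedding(VOCAB_SIZES[f], EMBED_DIM)
                       for f in ITEM_FEATURES}
        self.mlp = tf.keras.Sequential([
            tf.keras.layers.Dense(1024, activation="relu"),
            tf.keras.layers.Dense(128),
        ])

    def call(self, features, keep_features=None, dropout_rate=0.0, training=False):
        # features[f]: [B, k_f] padded integer ids (padding handled naively here).
        parts = []
        for f in ITEM_FEATURES:
            emb = tf.reduce_mean(self.tables[f](features[f]), axis=1)   # [B, d]
            if keep_features is not None and f not in keep_features:
                emb = tf.zeros_like(emb)                                # feature masking
            if training and dropout_rate > 0.0:
                emb = tf.nn.dropout(emb, rate=dropout_rate)             # feature dropout
            parts.append(emb)
        return tf.nn.l2_normalize(self.mlp(tf.concat(parts, axis=-1)), axis=-1)

def contrastive_loss(z1, z2, temperature=0.1):
    """Batch softmax over two views: the two views of the same item are
    positives, all other in-batch items are negatives."""
    logits = tf.matmul(z1, z2, transpose_b=True) / temperature
    labels = tf.range(tf.shape(logits)[0])
    return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits))

def ssl_loss(item_tower, item_features, dropout_rate=0.3):
    """Complementary feature masking + dropout, then contrastive learning.
    Items fed here should be sampled uniformly from the corpus
    (heterogeneous distributions, as discussed above)."""
    feats = ITEM_FEATURES[:]
    random.shuffle(feats)                                   # one random split per batch (RFM)
    half = len(feats) // 2
    keep_h, keep_g = set(feats[:half]), set(feats[half:])   # complementary masks
    z1 = item_tower(item_features, keep_features=keep_h,
                    dropout_rate=dropout_rate, training=True)
    z2 = item_tower(item_features, keep_features=keep_g,
                    dropout_rate=dropout_rate, training=True)
    return contrastive_loss(z1, z2)

def total_loss(query_emb, item_emb, item_tower, uniform_item_features, alpha=1.0):
    """Multi-task objective: L = L_main + alpha * L_self (cf. Equation (5))."""
    l_main = contrastive_loss(query_emb, item_emb, temperature=0.07)  # Eq. (6)
    l_self = ssl_loss(item_tower, uniform_item_features)
    return l_main + alpha * l_self
```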
Other Baselines. As mentioned in Section 2, we use two-tower DNNs as the baseline model for the main task. The two-tower model has the unique property of encoding item features, compared to classic matrix factorization (MF) and classification models. While the latter two methods are also applicable to large-scale item retrieval, they only learn item embeddings based on IDs, and thus do not fit our proposal of using SSL to exploit item feature relations.

4 OFFLINE EXPERIMENTS
We provide empirical results to demonstrate the effectiveness of our proposed self-supervised framework both on an academic public dataset and on data from an actual large-scale recommendation product. The experiments are designed to answer the following research questions:
• RQ1: Does the proposed SSL framework improve deep models for recommendations?
• RQ2: SSL is designed to improve the primary supervised task through the introduced SSL task on unlabeled examples. What is the impact of the amount of training data on the improvement from SSL?
• RQ3: How do the SSL parameters, i.e., the loss multiplier α and the dropout rate in data augmentation, affect model quality?
• RQ4: How does RFM perform compared to CFM? What is the benefit of leveraging feature correlations in data augmentation?
The above questions are addressed in order in Sections 4.3 - 4.5.

4.1 Datasets
We conduct experiments on two large-scale datasets that both come with a rich set of item metadata features. We formulate their primary supervised task as an item-to-item recommendation problem to study the effects of SSL on training recommender (in this case, retrieval) models. See Appendix .1 for the statistics of these two datasets.

Wikipedia [14]: The first dataset focuses on the problem of link prediction between Wikipedia pages. It consists of pairs of pages (x, y) ∈ X × X, where x is a source page and y is a destination page linked from x. The goal is to predict the set of pages that are likely to be linked to a given source page from the whole corpus of web pages. Each page is represented by a feature vector x = (x_id, x_ngrams, x_cats), where all the features are categorical. Here, x_id denotes the one-hot encoding of the page URL, x_ngrams denotes a bag-of-words representation of the set of n-grams of the page's title, and x_cats denotes a bag-of-words representation of the categories that the page belongs to. We partitioned the dataset into training and evaluation sets using a (90%, 10%) split, following the same treatment as [23] and [39].

App-to-App Install (AAI): The AAI dataset was collected from the app landing pages of a commercial mobile app store. On a particular app's (the seed app's) landing page, the app installs (candidate apps) from the section of recommended apps were collected. Each training example represents a seed-candidate pair, denoted as (x_seed, x_candidate), together with their metadata features. The goal is to recommend highly similar apps given a seed app. This is also formulated as an item-to-item recommendation problem via a multi-class classification loss. Note that we only collect positive examples, i.e., x_candidate is an installed app from the landing page of x_seed. Impressed recommended apps with no installs are ignored, since we consider them to be weak positives rather than negatives for building retrieval models. Each item (app) is represented by a feature vector x with the following features:
• id: Application id as a one-hot categorical feature.
• developer_name: Name of the app developer as a one-hot categorical feature.
• categories: Semantic categories of the app as a multi-hot categorical feature.
• title_unigram: Uni-grams of the app title as a multi-hot categorical feature.
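For concreteness, a single AAI training example could look like the following. Every value here is invented purely for illustration; only the feature layout matches the description above.

```python
# One hypothetical (seed, candidate) training example from AAI.
seed_app = {
    "id": ["com.example.podcasts"],               # one-hot categorical
    "developer_name": ["Example Studios"],        # one-hot categorical
    "categories": ["Audio", "News & Magazines"],  # multi-hot categorical
    "title_unigram": ["example", "podcasts"],     # multi-hot categorical
}
candidate_app = {
    "id": ["com.example.music"],
    "developer_name": ["Example Studios"],
    "categories": ["Audio", "Music"],
    "title_unigram": ["example", "music"],
}
# Positive pair: the candidate app was installed from the seed app's landing page.
training_example = (seed_app, candidate_app)
```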
4.2 Experiment Setup
Backbone Network. For the main task that predicts relevant items given the query, we use the two-tower DNN to encode query and item features (see Figure 1) as the backbone network. The item-to-item recommendation problem is formalized as a multi-class classification problem, using the batch softmax loss presented in Equation (6) as the loss function. For a discussion of the choice of backbone network, we refer readers to the related parts of Section 2 and Section 3.3.

Hyper-parameters. For the backbone two-tower DNN, we search over hyper-parameters such as the learning rate, the softmax temperature (τ), and the model architecture, choosing the configuration that gives the highest Recall@50 on the validation set. Note that the training batch size in batch softmax is critical for model quality, as it determines the number of negatives used for each positive item. Throughout this section, we use batch sizes of 1024 and 4096 for Wikipedia and AAI respectively. We also tuned the number of hidden layers, the hidden layer sizes, and the softmax temperature τ for the baseline models.
For the Wikipedia dataset, we use softmax temperature τ = 0.07 and hidden layers with sizes [1024, 128]. For AAI, we use τ = 0.06 and hidden layers [1024, 256]. Note that the dimension of the last hidden layer is also the dimension of the final query and item embeddings. All models are trained with the Adagrad [13] optimizer with a learning rate of 0.01.

We consider two SSL parameters: 1) the SSL loss multiplier α in Equation (5), and 2) the feature dropout rate, denoted as d_r, in the second phase of data augmentation (see Section 3.2). For each augmentation method (e.g., CFM, RFM), we conduct a grid search over the two parameters with ranges α = [0.1, 0.3, 1.0, 3.0] and d_r = [0.1, 0.2, ..., 0.9], and report the best result.

Evaluation. To evaluate the recommendation performance given a seed item, we retrieve the top-K items with the highest cosine similarity from the whole corpus and evaluate the quality based on the K retrieved items. Note that this is a relatively challenging task, given the sparsity of the dataset and the large number of items in the corpus. We adopt the popular standard metrics Recall@K and mean average precision (MAP@K) to evaluate recommendation performance [18]. For each experiment configuration, we ran the experiment 5 times and report the average.
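As a reference for how these metrics can be computed, here is a small sketch. Exact conventions for MAP@K vary across papers, so this should be read as one common definition rather than the precise evaluation code used by the authors.

```python
import numpy as np

def recall_at_k(ranked, relevant, k):
    """Fraction of the relevant items that appear in the top-k ranked list."""
    if not relevant:
        return 0.0
    return len(set(ranked[:k]) & set(relevant)) / len(relevant)

def average_precision_at_k(ranked, relevant, k):
    """AP@k: precision at each relevant hit, averaged (one common convention)."""
    if not relevant:
        return 0.0
    relevant_set, hits, score = set(relevant), 0, 0.0
    for rank, item in enumerate(ranked[:k], start=1):
        if item in relevant_set:
            hits += 1
            score += hits / rank
    return score / min(len(relevant), k)

def evaluate(seed_embs, item_embs, ground_truth, k=50):
    """Mean Recall@k and MAP@k over all seed items.

    seed_embs: [Q, d] and item_embs: [M, d], both L2-normalized so that the
    dot product equals cosine similarity; ground_truth[q] lists the relevant
    item indices for seed q.
    """
    scores = seed_embs @ item_embs.T                 # [Q, M] cosine similarities
    top_k = np.argsort(-scores, axis=1)[:, :k]       # indices of the k best items
    recalls, aps = [], []
    for q, relevant in enumerate(ground_truth):
        ranked = top_k[q].tolist()
        recalls.append(recall_at_k(ranked, relevant, k))
        aps.append(average_precision_at_k(ranked, relevant, k))
    return float(np.mean(recalls)), float(np.mean(aps))
```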
4.3 Effectiveness of SSL with Correlated Feature Masking
To answer RQ1, we first evaluate the impact of SSL on model quality. We focus on using CFM followed by dropout as the data augmentation technique; we will show the superior performance of CFM over other variants in Section 4.5. We consider three baseline methods:
• Baseline: Vanilla backbone network with the two-tower DNN architecture.
• Feature Dropout (FD) [35]: Backbone model with random feature dropout on the item tower in the supervised learning task. The feature dropout on item features can be treated as data augmentation. Unlike our approach, FD does not have the additional SSL regularization.
• Spread-out Regularization (SO) [41]: Backbone model with spread-out regularization on the item tower. The SO regularization shares a similar contrastive loss with our SSL framework. However, it applies contrastive learning on the original examples without any data augmentation, and is thus different from our approach.
The latter two methods are chosen because they are (1) model-agnostic and scalable for industrial-size recommendation systems, and (2) compatible with categorical sparse features for improving generalization. In addition, FD can be viewed as an ablation study to isolate the potential improvement from contrastive learning. Similarly, SO is included to isolate the improvement from feature augmentation.

We observe that with the full datasets (see Table 1), CFM consistently performs the best compared with the non-SSL regularization techniques. On AAI, CFM out-performs the next best method by 8.69% relatively, and on Wikipedia by 3.98%. This helps answer RQ1: the proposed SSL framework and tasks indeed improve model performance for recommenders. Comparing CFM with SO shows that data augmentation is critical for the SSL regularization to achieve better performance; without any data augmentation, the proposed SSL method reduces to SO. By comparing CFM and FD, we find that feature augmentation is more effective when applied to the SSL task than when applied to the supervised task as a standard regularization technique. Note that FD, a well-known approach for improving generalization in some cases, applies feature augmentation together with supervised training.

Table 1: Results on the full Wikipedia and AAI datasets.

Wikipedia
Method       MAP@10   MAP@50   Recall@10   Recall@50
Baseline     0.0171   0.0229   0.0537      0.1930
FD [35]      0.0172   0.0229   0.0535      0.1912
SO [41]      0.0176   0.0235   0.0549      0.1956
Our method   0.0183   0.0243   0.057       0.2009

AAI
Method       MAP@10   MAP@50   Recall@10   Recall@50
Baseline     0.1257   0.1363   0.2793      0.4983
FD [35]      0.1278   0.1384   0.2840      0.5058
SO [41]      0.1300   0.1406   0.2870      0.5076
Our method   0.1413   0.1522   0.3078      0.5355

Head-tail Analysis. To understand the gain from SSL, we further break down the overall performance by looking at item slices of different popularity. The split of the head and tail test sets is described in Appendix .2. Our hypothesis is that SSL generally helps improve the performance on slices without much supervision (e.g., tail items). The results evaluated on the tail and head test sets are reported in Table 3. We observe that the proposed SSL methods improve the performance for both head and tail item recommendations, with larger gains on the tail items. For instance, on AAI, CFM improves Recall@10 on tail items by 51.5%, while the improvement on head items is 8.57%.

Effects of SSL Parameters (RQ3). Figure 5 summarizes the Recall@50 evaluated on the Wikipedia and AAI datasets w.r.t. the regularization strength α. It also shows the results of SO, which shares the same regularization parameter. We observe that with increasing α, the model performance drops below the baseline model (shown as a dashed line) after a certain threshold. This is expected, since a large SSL weight α leads to the multi-task loss L being dominated by α · L_self in Equation (5). By further comparing our approach with SO, we show that the SSL-based regularization outperforms SO over a wide range of α. Figure 6 shows the model performance across different dropout rates d_r. It also shows DO, which shares the same parameter. As d_r increases, the model performance of DO continues to deteriorate. For most choices of d_r (except d_r = 0.1), DO is worse than the baseline. For the SSL task with feature dropout, the model performance peaks at d_r = 0.3 and then deteriorates as we further increase the dropout rate. The model starts to under-perform the baseline when d_r is too large. This observation aligns with our expectation: when the rate is too large, too little input information remains to learn meaningful representations through SSL.

Visualization of Item Representations. We visualize the learned app embeddings from the AAI dataset in a t-SNE plot; we postpone the detailed setup to Appendix .3. As shown in Figure 4, we clearly see that the app embeddings learned with our SSL regularization are better clustered according to their categories than their counterparts from the baseline, which demonstrates that the representations learned through SSL have stronger semantic structure. This partially explains the gain from SSL.
Table 2: Experiment results trained on the sparse (10% down-sampled) Wikipedia and AAI datasets.

10% Wikipedia Dataset
Method       MAP@10   MAP@50   Recall@10   Recall@50
Baseline     0.0077   0.0105   0.0237      0.0924
FD [35]      0.0089   0.0120   0.0272      0.1046
SO [41]      0.0083   0.0112   0.0254      0.0978
Our method   0.0093   0.0126   0.0286      0.1093

10% AAI Dataset
Method       MAP@10   MAP@50   Recall@10   Recall@50
Baseline     0.1112   0.1194   0.2406      0.4068
FD [35]      0.1217   0.1302   0.2611      0.4324
SO [41]      0.1220   0.1308   0.2632      0.4400
Our method   0.1409   0.1507   0.3024      0.5014

Table 3: Results of Wikipedia and AAI on tail and head item slices.

Wikipedia
             Tail                    Head
Method       Recall@10   Recall@50   Recall@10   Recall@50
Baseline     0.0472      0.1621      0.0610      0.2273
FD           0.0474      0.1638      0.0593      0.2212
SO           0.0481      0.1644      0.0606      0.2268
Our method   0.0524      0.1749      0.0619      0.2283

AAI
             Tail                    Head
Method       Recall@10   Recall@50   Recall@10   Recall@50
Baseline     0.0475      0.2333      0.2846      0.4993
FD           0.0727      0.2743      0.2849      0.5069
SO           0.0661      0.2602      0.2879      0.5086
Our method   0.0720      0.2906      0.309       0.537

Figure 4: Comparison of t-SNE plots of app embeddings for (a) the baseline model and (b) our best SSL model.
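The head and tail slices used in Table 3 are based on item popularity (the exact split is described in Appendix .2, which is not part of this excerpt). A simple popularity-based slicing, together with the item-frequency CDF discussed in Appendix .1, could be computed along these lines; the head_fraction threshold below is a placeholder, not the paper's value.

```python
from collections import Counter
import numpy as np

def popularity_slices(training_items, head_fraction=0.01):
    """Split the item vocabulary into head/tail slices by training frequency.

    training_items: iterable of item ids as they appear in training examples.
    head_fraction: fraction of the most frequent items treated as "head"
    (a placeholder; the paper's actual split is defined in Appendix .2).
    """
    counts = Counter(training_items)
    ranked = [item for item, _ in counts.most_common()]
    n_head = max(1, int(len(ranked) * head_fraction))
    head, tail = set(ranked[:n_head]), set(ranked[n_head:])

    # CDF of the K most frequent items (cf. Appendix .1): for a TopPopular
    # recommender, CDF(K) also approximates its Recall@K.
    freqs = np.array([counts[item] for item in ranked], dtype=float)
    cdf = np.cumsum(freqs) / freqs.sum()
    return head, tail, cdf
```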
A APPENDIX
.1 Dataset Statistics. Table 5 shows some basic statistics for the Wikipedia and AAI datasets. Figure 9 shows the CDF of the most frequent items for the two datasets, indicating a highly skewed data distribution. For example, the top 50 items in the AAI dataset collectively account for roughly 10% of the training data. If we consider a naive baseline (i.e., the TopPopular recommender [11]) that recommends the most frequent top-K items for every query, the CDF of the K-th most frequent item essentially represents the Recall@K metric of such a baseline.

Table 5: Corpus sizes of the Wikipedia and the AAI datasets.
Dataset     # queries   # items   # examples
Wikipedia   5.3M        5.3M      490M
AAI         2.4M        2.4M      1B

.3 Visualization of Learned Embeddings. Besides better model performance, we expect the representations learned with SSL to have better quality than their counterparts learned without SSL. To verify this hypothesis, we take the app embeddings learned by the models trained on the AAI dataset and plot them with t-SNE in Figure 4. Apps from different categories are plotted in different colors, as illustrated in the legend of Figure 4. Compared to the apps in the baseline model (Figure 4a), the apps in the best SSL model (Figure 4b) tend to group much better with similar apps from the same category, and the separation between different categories is much clearer. For example, the “Sports & Recreation” apps (in red) are mixed together with “Law & Government” and “Travel” apps in Figure 4a, while in Figure 4b we clearly see the 4 categories of apps grouped among themselves. This indicates that the representations learned with SSL carry more semantic information, which is also why SSL leads to better model performance in our experiments.
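The kind of plot described above can be produced with off-the-shelf tooling. The sketch below is our illustration of the general recipe (scikit-learn's t-SNE plus matplotlib), not the authors' plotting code; the perplexity and sample size are arbitrary choices.

```python
import matplotlib.pyplot as plt
import numpy as np
from sklearn.manifold import TSNE

def plot_tsne(app_embeddings, app_categories, sample_size=5000, seed=0):
    """Project learned app embeddings to 2-D with t-SNE and color by category.

    app_embeddings: [M, d] array of item-tower embeddings.
    app_categories: length-M list of category labels used for coloring.
    """
    rng = np.random.default_rng(seed)
    idx = rng.choice(len(app_embeddings),
                     size=min(sample_size, len(app_embeddings)), replace=False)
    points = TSNE(n_components=2, perplexity=30, init="pca",
                  random_state=seed).fit_transform(app_embeddings[idx])

    labels = np.asarray(app_categories)[idx]
    for category in np.unique(labels):
        mask = labels == category
        plt.scatter(points[mask, 0], points[mask, 1], s=3, label=category)
    plt.legend(markerscale=3, fontsize="small")
    plt.title("t-SNE of app embeddings, colored by category")
    plt.show()
```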