Train 400x faster Static Embedding Models with Sentence Transformers
Tom Aarsen
TL;DR
This blog post introduces a method to train static embedding models that run
100x to 400x faster on CPU than state-of-the-art embedding models, while
retaining most of the quality. This unlocks a lot of exciting use cases, including on-
device and in-browser execution, edge computing, low power and embedded
applications.
We apply this recipe to train two extremely efficient embedding models: sentence-
transformers/static-retrieval-mrl-en-v1 for English Retrieval, and sentence-
transformers/static-similarity-mrl-multilingual-v1 for Multilingual Similarity
tasks. These models are 100x to 400x faster on CPU than common counterparts
like all-mpnet-base-v2 and multilingual-e5-small, while reaching at least 85% of
their performance on various benchmarks.
Alongside this blog post, we also release:
The two models (for English retrieval and for multilingual similarity)
mentioned above.
Two Weights and Biases reports with training and evaluation metrics
collected during training.
The detailed list of datasets we used: 30 for training and 13 for evaluation.
Table of Contents
TL;DR
Modern Embeddings
Static Embeddings
Our Method
    Training Details
    Training Requirements
    Model Inspiration
        English Retrieval
        Multilingual Similarity
    Training Dataset Selection
        English Retrieval
        Multilingual Similarity
    Loss Function Selection
    Training Arguments Selection
    Evaluator Selection
    Hardware Details
    Overall Training Scripts
        English Retrieval
        Multilingual Similarity
Usage
    English Retrieval
    Multilingual Similarity
    LangChain
    LlamaIndex
    Haystack
    txtai
Performance
    English Retrieval
        NanoBEIR
        Matryoshka Evaluation
    Multilingual Similarity
        Matryoshka Evaluation
Conclusion
Next Steps
Acknowledgements
Embeddings are one of the most versatile tools in natural language processing,
enabling practitioners to solve a large variety of tasks. In essence, an embedding is
a numerical representation of a more complex object, like text, images, audio, etc.
The embedding model will always produce embeddings of the same fixed size.
You can then compute the similarity of complex objects by computing the
similarity of the respective embeddings.
This has a large number of use cases, and serves as the backbone for
recommendation systems, retrieval, outlier detection, one-shot or few-shot
learning, similarity search, clustering, paraphrase detection, classification, and
much more.
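For example, here is a small sketch using the Sentence Transformers API with an off-the-shelf model (the texts and model choice are illustrative):

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

# Texts are mapped to fixed-size vectors...
embeddings = model.encode([
    "The weather is lovely today.",
    "It's so sunny outside!",
    "He drove to the stadium.",
])
print(embeddings.shape)  # (3, 768)

# ...and the similarity of the vectors approximates the similarity of the texts
similarities = model.similarity(embeddings, embeddings)
print(similarities)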
Modern Embeddings
Many of today's embedding models consist of a handful of conversion steps.
Following these steps is called "inference".
The Tokenizer and Pooler are responsible for pre- and post-processing for the
Encoder , respectively. The former chops texts up into tokens (a.k.a. words or
subwords) which can be understood by the Encoder , whereas the latter combines
the embeddings for all tokens into one embedding for the entire text.
Within this pipeline, the Encoder is often a language model with attention layers,
which allows the embedding of each token to be computed in the context of the other tokens.
For example, bank might be a token, but the token embedding for that token will
likely be different if the text refers to a "river bank" or the financial institution.
Large encoder models with a lot of attention layers will be effective at using the
context to produce useful embeddings, but they do so at a high price of slow
inference. Notably, in the pipeline, the Encoder step is generally responsible for
almost all of the computational time.
Static Embeddings
Static Embeddings refers to a group of Encoder models that don't use large and
slow attention-based models, but instead rely on pre-computed token
embeddings. Static embeddings were used years before the transformer
architecture was developed. Common examples include GloVe and word2vec.
Recently, Model2Vec has been used to convert pre-trained embedding models into
Static Embedding models.
For Static Embeddings, the Encoder step is as simple as a dictionary lookup: given
the token, return the pre-computed token embedding. Consequently, inference is
suddenly no longer bottlenecked by the Encoder phase, resulting in speedups of
several orders of magnitude. This blogpost shows that the hit on quality can be
quite small!
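Conceptually, the static Encoder plus a mean Pooler boils down to a table lookup and an average. Here is a simplified sketch with made-up token IDs and a random embedding table, purely for illustration:

import numpy as np

# A pre-computed token embedding table: vocab_size x embedding_dim
vocab_size, embedding_dim = 30_522, 1024
token_embeddings = np.random.rand(vocab_size, embedding_dim).astype(np.float32)

def encode(token_ids: list[int]) -> np.ndarray:
    # "Encoder": a simple lookup of pre-computed token embeddings
    looked_up = token_embeddings[token_ids]
    # "Pooler": mean pooling into one embedding for the entire text
    return looked_up.mean(axis=0)

text_embedding = encode([2023, 2003, 1037, 7099])  # token IDs from a tokenizer
print(text_embedding.shape)  # (1024,)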
Our Method
We set out to revisit Static Embeddings models, using modern techniques to train
them. Most of our gains come from the use of a contrastive learning loss function,
as we'll explain shortly. Optionally, we can get additional speed improvements by
using Matryoshka Representation Learning, which makes it possible to use
truncated versions of the embedding vectors.
We'll be using the Sentence Transformers library for training. For a more general
overview on how this library can be used to train embedding models, consider
reading the Training and Finetuning Embedding Models with Sentence
Transformers v3 blogpost or the Sentence Transformers Training Overview
documentation.
Training Details
We leave various other modern training approaches, e.g. for further improving data
quality, to future research. See Next Steps for concrete ideas.
Training Requirements
Training Sentence Transformer models requires the following components:
1. Dataset
2. Loss Function
3. Training Arguments (Optional)
4. Evaluator (Optional)
5. Trainer
In the following sections, we'll go through our thought processes for each of these.
Model Inspiration
In our experience, embedding models are either used 1) exclusively for retrieval
or 2) for every task under the sun (classification, clustering, semantic textual
similarity, etc.). We set out to train one of each.
For the retrieval model, there is only a limited amount of multilingual retrieval
training data available, and hence we opted for an English-only model. In
contrast, we decided to train a multilingual general similarity model because
multilingual data was much easier to acquire for this task.
For these models, we would like to use the StaticEmbedding module, which
implements an efficient tokenize method that avoids padding, and an efficient
forward method that takes care of computing and pooling embeddings. It's as
simple as the following initialization code:
English Retrieval
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import StaticEmbedding
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_pretrained("google-bert/bert-base-uncased")
static_embedding = StaticEmbedding(tokenizer, embedding_dim=1024)
model = SentenceTransformer(modules=[static_embedding])
The first entry in the modules list must implement tokenize , and the last one
must produce pooled embeddings. Both are the case here, so we're good to start
training this model.
Multilingual Similarity
The initialization code is analogous:
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import StaticEmbedding
from tokenizers import Tokenizer
tokenizer = Tokenizer.from_pretrained("google-bert/bert-base-multilingual-uncased")
static_embedding = StaticEmbedding(tokenizer, embedding_dim=1024)
model = SentenceTransformer(modules=[static_embedding])
Training Dataset Selection
English Retrieval
For the English Retrieval datasets, we are primarily looking for datasets with
question-answer pairs. Among others, we used:
gooaq
squad
paq
trivia_qa
msmarco_10m
Multilingual Similarity
For the Multilingual Similarity datasets, we are primarily looking for datasets with either:
parallel sentences across languages, i.e. the same text in multiple languages,
or
positive pairs, i.e. pairs with high similarity, optionally with negatives (i.e.
texts with low similarity).
Among others, we used:
wikititles
tatoeba
talks
europarl
global_voices
muse
wikimatrix
opensubtitles
wikianswers_duplicates
simple_wiki
altlex
flickr30k_captions
coco_captions
nli_for_simcse
negation
Code
print(gooaq_train_dataset)
"""
Dataset({
features: ['question', 'answer'],
num_rows: 3002496
})
"""
print(gooaq_eval_dataset)
"""
Dataset({
features: ['question', 'answer'],
num_rows: 10000
})
"""
The gooaq dataset doesn't already have a train-eval split, so we can make one with
train_test_split. Otherwise, we can just load a precomputed split with e.g.
split="eval" .
Note that train_test_split does mean that the dataset has to be loaded into
memory, whereas it is otherwise just kept on disk. This increased memory is not
ideal when training, so it's recommended to 1) load the data, 2) split it, and 3)
save it to disk with save_to_disk. Before training, you can then use
load_from_disk to load it again.
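As a small sketch of that recommended pattern (the local paths and split parameters are illustrative):

from datasets import load_dataset, load_from_disk

# 1) Load the full GooAQ dataset and 2) split off an evaluation set
gooaq_dataset = load_dataset("sentence-transformers/gooaq", split="train")
gooaq_dataset_dict = gooaq_dataset.train_test_split(test_size=10_000, seed=12)

# 3) Save both splits to disk once...
gooaq_dataset_dict["train"].save_to_disk("datasets/gooaq_train")
gooaq_dataset_dict["test"].save_to_disk("datasets/gooaq_eval")

# ...so the training script can memory-map them instead of loading them into RAM
gooaq_train_dataset = load_from_disk("datasets/gooaq_train")
gooaq_eval_dataset = load_from_disk("datasets/gooaq_eval")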
Loss Function Selection
Within Sentence Transformers, your loss function must match your training data
format. The Loss Overview is designed as an overview of which losses are
compatible with which formats.
For our (anchor, positive) pairs, the main candidates are MultipleNegativesRankingLoss
(MNRL), its memory-efficient variant CachedMultipleNegativesRankingLoss (CMNRL),
and GISTEmbedLoss:
CMNRL is recommended over MNRL unless you can already fit a large
enough batch size in memory with just MNRL. In that case, you can use
MNRL to save the roughly 20% training speed cost that CMNRL adds.
GISTEmbedLoss uses a guide model to filter out likely false negatives from the
in-batch negatives. False negatives can hurt performance, but hard true negatives
(texts that are close to correct, but not quite) can help performance, so this
filtering is a fine line to walk.
Because these static embedding models are extremely small, it is possible to fit
our desired batch size of 2048 samples on our hardware: a single RTX 3090 with
24GB, so we don't need to use CMNRL.
Additionally, because we're training such fast models, the guide model from
GISTEmbedLoss would make the training much slower. Because of this, we've
opted to use the plain MultipleNegativesRankingLoss.
Code
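A minimal sketch of the loss setup, assuming the model initialized earlier:

from sentence_transformers.losses import MultipleNegativesRankingLoss

# In-batch negatives: for each (question, answer) pair, all other answers
# in the batch act as negatives
loss = MultipleNegativesRankingLoss(model)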
Sentence Transformers also supports loss modifiers that wrap another loss function. A
very interesting one is the MatryoshkaLoss, which turns the trained model into
a Matryoshka Model. This allows users to truncate the output embeddings at a
minimal loss of performance, meaning that retrieval or clustering can be sped up
due to the smaller dimensionalities.
Code
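For example, a sketch of wrapping the loss; the exact dimension schedule below is illustrative:

from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss

base_loss = MultipleNegativesRankingLoss(model)
# Also optimize the truncated embeddings, from the full 1024 dimensions down to 32
loss = MatryoshkaLoss(model, base_loss, matryoshka_dims=[1024, 512, 256, 128, 64, 32])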
Training Arguments Selection
We selected the following notable training arguments:
num_train_epochs : 1
We have sufficient data; should we want to train more, we can add more
data instead of training on the same data multiple times.
per_device_train_batch_size / per_device_eval_batch_size : 2048
Batches of 2048 samples fit comfortably on our RTX 3090. Various papers (Xiao
et al., Li et al.) show that even larger batch sizes still improve
performance. For future versions, we will apply
CachedMultipleNegativesRankingLoss with a larger batch size, e.g.
16384.
learning_rate : 2e-1
Note! This is much larger than with normal embedding model training,
which often uses a learning rate around 2e-5.
warmup_ratio : 0.1
bf16 : True
If your GPU supports bf16, it tends to make sense to train with it.
Otherwise, you can use fp16=True if that's supported instead.
batch_sampler : BatchSamplers.NO_DUPLICATES
All losses with in-batch negatives (such as MNRL) benefit from this
batch sampler that avoids duplicates within the batch. Duplicates often
result in false negatives, weakening the trained model.
multi_dataset_batch_sampler :
MultiDatasetBatchSamplers.PROPORTIONAL
When you're training with multiple datasets, it's common that not all
datasets are the same size. When that happens, you can either:
Round Robin: sample the same amount of batches from each
dataset until one is exhausted. You'll have an equal distribution of
data, but not all data will be used.
Proportional: sample each dataset until all are exhausted. You'll use
up all data, but you won't have an equal distribution of data. We
chose this one as we're not too concerned with a data imbalance.
Beyond these core arguments, we also set a few training arguments for tracking
and debugging: eval_strategy , eval_steps , save_strategy , save_steps ,
save_total_limit , logging_steps , logging_first_step , and run_name .
Code
run_name = "static-retrieval-mrl-en-v1"
# or
# run_name = "static-similarity-mrl-multilingual-v1"
args = SentenceTransformerTrainingArguments(
# Required parameter:
output_dir=f"models/{run_name}",
# Optional training parameters:
num_train_epochs=1,
per_device_train_batch_size=2048,
per_device_eval_batch_size=2048,
learning_rate=2e-1,
warmup_ratio=0.1,
fp16=False, # Set to False if you get an error that your GPU can
bf16=True, # Set to True if you have a GPU that supports BF16
batch_sampler=BatchSamplers.NO_DUPLICATES, # MultipleNegativesRa
multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONA
# Optional tracking/debugging parameters:
eval_strategy="steps",
eval_steps=1000,
save_strategy="steps",
save_steps=1000,
save_total_limit=2,
logging_steps=1000,
logging_first_step=True,
run_name=run_name, # Used if `wandb`, `tensorboard`, or `neptune
)
Evaluator Selection
Due to its simplicity, we will be using the NanoBEIREvaluator for the retrieval
model. This evaluator runs Information Retrieval benchmarks on the NanoBEIR
collection of datasets. NanoBEIR is a subset of the much larger (and thus
slower) BEIR benchmark, which is commonly used for the Retrieval tab in the
MTEB Leaderboard.
Code
Because all datasets are already pre-defined, we can load the evaluator without
any arguments:
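A minimal version of that setup:

from sentence_transformers.evaluation import NanoBEIREvaluator

# Evaluates on all 13 NanoBEIR datasets and reports metrics such as NDCG@10
evaluator = NanoBEIREvaluator()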
Hardware Details
GPU: 1x RTX 3090 (24GB)
CPU: i7-13700K
RAM: 32GB
Overall Training Scripts
This section contains the final training scripts for both models with all of the
previously described components (datasets, loss functions, training arguments,
evaluator, trainer) combined.
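The full scripts combine 30 training datasets and the losses described above; the following is a heavily condensed, single-dataset sketch of how the components fit together, not the exact script we ran:

from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.evaluation import NanoBEIREvaluator
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
from sentence_transformers.models import StaticEmbedding
from sentence_transformers.training_args import BatchSamplers
from tokenizers import Tokenizer

# 1. Model: static embeddings on top of the bert-base-uncased tokenizer
static_embedding = StaticEmbedding(
    Tokenizer.from_pretrained("google-bert/bert-base-uncased"), embedding_dim=1024
)
model = SentenceTransformer(modules=[static_embedding])

# 2. Dataset: a single (question, answer) dataset for brevity
dataset = load_dataset("sentence-transformers/gooaq", split="train")
dataset = dataset.train_test_split(test_size=10_000, seed=12)

# 3. Loss: in-batch negatives, wrapped for Matryoshka-style truncation
loss = MatryoshkaLoss(
    model,
    MultipleNegativesRankingLoss(model),
    matryoshka_dims=[1024, 512, 256, 128, 64, 32],
)

# 4. Training arguments and 5. Evaluator
args = SentenceTransformerTrainingArguments(
    output_dir="models/static-retrieval-sketch",
    num_train_epochs=1,
    per_device_train_batch_size=2048,
    per_device_eval_batch_size=2048,
    learning_rate=2e-1,
    warmup_ratio=0.1,
    bf16=True,
    batch_sampler=BatchSamplers.NO_DUPLICATES,
    eval_strategy="steps",
    eval_steps=1000,
)
evaluator = NanoBEIREvaluator()

# 6. Trainer: ties everything together
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    loss=loss,
    evaluator=evaluator,
)
trainer.train()

# Save the trained model to disk
model.save_pretrained("models/static-retrieval-sketch/final")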
English Retrieval
See our Weights and Biases report for the training and evaluation metrics
collected during training.
Multilingual Similarity
See our Weights and Biases report for the training and evaluation losses collected
during training.
Usage
English Retrieval
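A minimal usage sketch (the queries and documents are illustrative):

from sentence_transformers import SentenceTransformer

# Download the model from the Hugging Face Hub and run it on CPU
model = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1", device="cpu")

queries = ["What is the capital of France?"]
documents = [
    "Paris is the capital and most populous city of France.",
    "The Eiffel Tower is a wrought-iron lattice tower in Paris.",
]
query_embeddings = model.encode(queries)
document_embeddings = model.encode(documents)

# Rank documents by similarity to the query
similarities = model.similarity(query_embeddings, document_embeddings)
print(similarities)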
The upcoming Performance > English Retrieval section will show that these
results are quite solid, within 15% of commonly used Transformer-based encoder
models like all-mpnet-base-v2.
SentenceTransformer API Reference.
Multilingual Similarity
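A minimal usage sketch (the example sentences are illustrative):

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/static-similarity-mrl-multilingual-v1", device="cpu")

# The same sentence in English, German, and Spanish
sentences = [
    "It's nice weather outside today.",
    "Es ist heute schönes Wetter draußen.",
    "Hoy hace buen tiempo afuera.",
]
embeddings = model.encode(sentences)
similarities = model.similarity(embeddings, embeddings)
print(similarities)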
This model only loses about 8% of performance compared to the popular but
much slower multilingual-e5-small, as shown in the upcoming Performance >
Multilingual Similarity section.
To reduce the dimensionality of your calculated embeddings, you can simply pass
the truncate_dim parameter. This works for all Sentence Transformer models.
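For example (256 dimensions chosen arbitrarily here):

from sentence_transformers import SentenceTransformer

# Truncate the 1024-dimensional embeddings down to 256 dimensions
model = SentenceTransformer(
    "sentence-transformers/static-retrieval-mrl-en-v1",
    device="cpu",
    truncate_dim=256,
)
embeddings = model.encode(["Static embeddings are fast."])
print(embeddings.shape)  # (1, 256)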
This model also works out of the box in various third party libraries, for example
LangChain, LlamaIndex, Haystack, and txtai.
LangChain
model_name = "sentence-transformers/static-retrieval-mrl-en-v1"
model_kwargs = {'device': 'cpu'} # you can use 'truncate_dim' here
model = HuggingFaceEmbeddings(
model_name=model_name,
model_kwargs=model_kwargs,
)
HuggingFaceEmbeddings documentation.
LlamaIndex
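A sketch using LlamaIndex's Hugging Face embedding integration (assuming the llama-index-embeddings-huggingface package is installed; truncate_dim support depends on your version):

from llama_index.embeddings.huggingface import HuggingFaceEmbedding

model_name = "sentence-transformers/static-retrieval-mrl-en-v1"
embed_model = HuggingFaceEmbedding(
    model_name=model_name,
    device="cpu",
    # truncate_dim=256,  # if supported by your llama-index version
)
embedding = embed_model.get_text_embedding("Static embeddings are fast.")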
Haystack
model_name = "sentence-transformers/static-retrieval-mrl-en-v1"
device = "cpu"
document_embedder = SentenceTransformersDocumentEmbedder(
model=model_name,
device=device,
# truncate_dim=256, # you can use 'truncate_dim' here
)
text_embedder = SentenceTransformersTextEmbedder(
model=model_name,
device=device,
# truncate_dim=256, # you can use 'truncate_dim' here
)
SentenceTransformersDocumentEmbedder documentation.
SentenceTransformersTextEmbedder documentation.
txtai
model_name = "sentence-transformers/static-retrieval-mrl-en-v1"
embeddings = Embeddings(path=model_name)
Embeddings documentation
Performance
English Retrieval
NanoBEIR
We've evaluated sentence-transformers/static-retrieval-mrl-en-v1 on NanoBEIR
and plotted it against the inference speed computed on our hardware. For the
inference speed tests, we measured how many query embeddings from the GooAQ
dataset could be computed per second, either on CPU or GPU.
NOTE: Many of the attention-based dense embedding models are finetuned on the
training splits of the (Nano)BEIR evaluation datasets. This gives the models an unfair
advantage in this benchmark and can result in lower downstream performance on real
retrieval tasks.
[Figures: NanoBEIR NDCG@10 versus inference speed (texts per second) on GPU and CPU.]
We can draw a notable conclusion from these figures: static-retrieval-mrl-en-v1 reaches
within 15% of the performance of commonly used attention-based models like
all-mpnet-base-v2, while being 100x to 400x faster on CPU.
Matryoshka Evaluation
Evaluating the model with truncated embeddings shows that reducing the dimensionality
by e.g. 2x only causes a 1.47% reduction in performance (0.5031 NDCG@10 vs
0.4957 NDCG@10), while realistically resulting in a 2x speedup in retrieval speed.
Multilingual Similarity
Matryoshka Evaluation
The Matryoshka evaluation of this model shows that you can easily reduce the
dimensionality by 2x or 4x with minor (0.15% or 0.56%) performance hits. If the
speed of your downstream task or your storage costs are a bottleneck, this should
allow you to alleviate some of those concerns.
Conclusion
This blogpost described all of the steps that we undertook from ideation to
finished models, in addition to details regarding usage and evaluation of the two
resulting models: static-retrieval-mrl-en-v1 and static-similarity-mrl-multilingual-
v1.
Should you need an efficient CPU-only dense embedding model for your retrieval
or similarity tasks, then static-retrieval-mrl-en-v1 and static-similarity-mrl-
multilingual-v1 will be extremely performant solutions at minimal costs that get
surprisingly close to the attention-based dense models.
Next Steps
Try it out! If you already use a Sentence Transformer model somewhere, feel free
to swap it out for static-retrieval-mrl-en-v1 or static-similarity-mrl-multilingual-
v1. Or, better yet: train your own models on data that is representative of the task
and language of your interest.
More experiments are required to determine a good cutoff point for input text length.
For now, we leave the maximum sequence length, chunking, etc. to the user.
Additionally, there are quite a few possible extensions that are likely to improve
the performance of this model, which we happily leave to other model authors.
We are also open to collaborations:
1. Hard Negatives Mining: Search for similar, but not quite relevant, texts to
improve training data difficulty.
Acknowledgements
I would like to thank Stéphan Tulkens and Thomas van Dongen of The Minish Lab
for bringing Static Embedding models to my attention via their Model2Vec work.
Additionally, I would like to thank Vaibhav Srivastav and Pedro Cuenca for their
assistance with this blogpost, and Antoine Chaffin for brainstorming the release
checkpoints.
Community
NickyNicky/StaticEmbedding-MatryoshkaLoss-gemma-2-2b-en-es
NickyNicky/StaticEmbedding-MatryoshkaLoss-gemma-2-2b-gooaq-en
I would like to know how to increase or decrease the
'max_length example 371'
Hello!
Nice work on those models! Am I correct in understanding that one of those models
reaches 0.5623 NDCG@10 on NanoBEIR across all datasets? That's a pretty huge jump
from the 0.5032 NDCG@10 for static-retrieval-mrl-en-v1.
That is simply some approximate statistics on the training data; taken from the first
1000 samples. Although it's not always recommended to use texts with (much) larger
sequence lengths than the training data, the actual maximum sequence length is
indeed infinity. It is defined here: https://github.com/UKPLab/sentence-transformers/blob/cccab8303aaf6e18f069b0da578b3d162bf8442a/sentence_transformers/models/StaticEmbedding.py#L106-L108
In short: the model will never truncate sequences, because the approach
has linear complexity (2x more data -> 2x slower), unlike Transformer models (2x
more data -> (much) slower than 2x).
So, Static Models don't have a maximum sequence length. They just require care by
the user to make sure that they're not feeding documents that are too large, as all
documents will eventually embed very similarly if they are long enough.
Tom Aarsen
This is really cool! I'm surprised you do better than model2vec - is the difference really
just the use of a (better) contrastive loss pretraining formula?
Yes! The architecture is identical. In fact, the StaticEmbedding module that is used
for the models described in this blogpost is actually the same that is used when
loading a Model2Vec model in Sentence Transformers:
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import StaticEmbedding

# Pre-distilled embeddings:
static_embedding = StaticEmbedding.from_model2vec("minishlab/M2V_base_output")
model = SentenceTransformer(modules=[static_embedding])
StaticEmbedding docs