# Train 400x faster Static Embedding Models with Sentence Transformers
Published January 15, 2025 by Tom Aarsen (tomaarsen)

## TL;DR

This blog post introduces a method to train static embedding models that run 100x to 400x faster on CPU than state-of-the-art embedding models, while retaining most of the quality. This unlocks a lot of exciting use cases, including on-device and in-browser execution, edge computing, and low-power and embedded applications.

We apply this recipe to train two extremely efficient embedding models: sentence-transformers/static-retrieval-mrl-en-v1 for English retrieval, and sentence-transformers/static-similarity-mrl-multilingual-v1 for multilingual similarity tasks. These models are 100x to 400x faster on CPU than common counterparts like all-mpnet-base-v2 and multilingual-e5-small, while reaching at least 85% of their performance on various benchmarks.

Today, we are releasing:

- The two models (for English retrieval and for multilingual similarity) mentioned above.
- The detailed training strategy we followed, from ideation to dataset selection to implementation and evaluation.
- Two training scripts, based on the open-source Sentence Transformers library.
- Two Weights and Biases reports with training and evaluation metrics collected during training.
- The detailed list of datasets we used: 30 for training and 13 for evaluation.

We also discuss potential enhancements, and encourage the community to explore them and build on this work!
## Usage Snippets for the released models

The usage of these models is very straightforward, identical to the normal Sentence Transformers flow:

### English Retrieval

```python
from sentence_transformers import SentenceTransformer

# Download from the 🤗 Hub
model = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1", device="cpu")
# Run inference
sentences = [
    'Gadofosveset-enhanced MR angiography of carotid arteries: does steady-state imaging improve accuracy of first-pass imaging?',
    'To evaluate the diagnostic accuracy of gadofosveset-enhanced magnetic resonance (MR) angiography in the assessment of carotid artery stenosis, with digital subtraction angiography (DSA) as the reference standard, and to determine the value of reading first-pass, steady-state, and "combined" (first-pass plus steady-state) MR angiograms.',
    'In a longitudinal study we investigated in vivo alterations of CVO during neuroinflammation, applying Gadofluorine M- (Gf) enhanced magnetic resonance imaging (MRI) in experimental autoimmune encephalomyelitis, an animal model of multiple sclerosis. SJL/J mice were monitored by Gadopentate dimeglumine- (Gd-DTPA) and Gf-enhanced MRI after adoptive transfer of proteolipid-protein-specific T cells. Mean Gf intensity ratios were calculated individually for different CVO and correlated to the clinical disease course. Subsequently, the tissue distribution of fluorescence-labeled Gf as well as the extent of cellular inflammation was assessed in corresponding histological slices.',
]
embeddings = model.encode(sentences)
print(embeddings.shape)
# [3, 1024]

# Get the similarity scores for the embeddings
similarities = model.similarity(embeddings[0], embeddings[1:])
print(similarities)
# tensor([[0.7649, 0.3279]])
```

### Multilingual Similarity

```python
from sentence_transformers import SentenceTransformer

# Download from the 🤗 Hub
model = SentenceTransformer("sentence-transformers/static-similarity-mrl-multilingual-v1", device="cpu")
# Run inference
sentences = [
    'It is known for its dry red chili powder.',
    'It is popular for dried red chili powder.',
    'These monsters will move in large groups.',
]
embeddings = model.encode(sentences)
print(embeddings.shape)
# [3, 1024]

# Get the similarity scores for the embeddings
similarities = model.similarity(embeddings, embeddings)
print(similarities)
# tensor([[ 1.0000,  0.8388, -0.0012],
#         [ 0.8388,  1.0000,  0.0445],
#         [-0.0012,  0.0445,  1.0000]])
```

## Table of Contents

- TL;DR
- Table of Contents
- What are Embeddings?
  - Modern Embeddings
  - Static Embeddings
- Our Method
- Training Details
  - Training Requirements
  - Model Inspiration
    - English Retrieval
    - Multilingual Similarity
  - Training Dataset Selection
    - English Retrieval
    - Multilingual Similarity
    - Code
  - Loss Function Selection
    - Code
  - Matryoshka Representation Learning
    - Code
  - Training Arguments Selection
    - Code
  - Evaluator Selection
    - Code
  - Hardware Details
  - Overall Training Scripts
    - English Retrieval
    - Multilingual Similarity
- Usage
  - English Retrieval
  - Multilingual Similarity
  - Matryoshka Dimensionality Truncation
  - Third Party libraries
    - LangChain
    - LlamaIndex
    - Haystack
    - txtai
- Performance
  - English Retrieval
    - NanoBEIR
      - GPU
      - CPU
    - Matryoshka Evaluation
  - Multilingual Similarity
    - Matryoshka Evaluation
- Conclusion
- Next Steps

## What are Embeddings?
Embeddings are one of the most versatile tools in natural language processing, enabling practitioners to solve a large variety of tasks. In essence, an embedding is a numerical representation of a more complex object, like text, images, audio, etc. The embedding model will always produce embeddings of the same fixed size. You can then compute the similarity of complex objects by computing the similarity of the respective embeddings.

This has a large number of use cases, and serves as the backbone for recommendation systems, retrieval, outlier detection, one-shot or few-shot learning, similarity search, clustering, paraphrase detection, classification, and much more.

### Modern Embeddings

Many of today's embedding models consist of a handful of conversion steps; following these steps is called "inference". The Tokenizer and Pooler are responsible for pre- and post-processing for the Encoder, respectively. The former chops texts up into tokens (i.e. words or subwords) that can be understood by the Encoder, whereas the latter combines the embeddings of all tokens into one embedding for the entire text.

Within this pipeline, the Encoder is often a language model with attention layers, which allows each token to be computed within the context of the other tokens. For example, bank might be a token, but the embedding for that token will likely differ depending on whether the text refers to a "river bank" or a financial institution. Large encoder models with many attention layers are effective at using context to produce useful embeddings, but they do so at the high price of slow inference. Notably, the Encoder step is generally responsible for almost all of the computational time in the pipeline.

### Static Embeddings

Static Embeddings refers to a group of Encoder models that don't use large, slow attention-based models, but instead rely on pre-computed token embeddings. Static embeddings were used years before the transformer architecture was developed.
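The lookup-based encoding that pre-computed token embeddings enable can be sketched in plain Python. This is a toy illustration only: the two-token vocabulary and 4-dimensional vectors below are made up, whereas real static models use a full tokenizer vocabulary and far larger dimensionalities.

```python
# Toy sketch of static-embedding inference: token lookup followed by mean pooling.
# The vocabulary and vectors below are made up for illustration only.
table = {
    "river": [1.0, 0.0, 0.0, 0.0],
    "bank":  [0.0, 1.0, 0.0, 0.0],
    "money": [0.0, 0.0, 1.0, 0.0],
}

def encode(tokens):
    # "Encoder" step: a plain dictionary lookup per token
    vectors = [table[token] for token in tokens]
    # "Pooler" step: average each dimension over all token embeddings
    return [sum(column) / len(vectors) for column in zip(*vectors)]

print(encode(["river", "bank"]))  # [0.5, 0.5, 0.0, 0.0]
```

Note that the same token always maps to the same vector here: unlike an attention-based Encoder, "bank" gets one embedding regardless of context.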
Common examples include GloVe and word2vec. Recently, Model2Vec has been used to convert pre-trained embedding models into Static Embedding models.

For Static Embeddings, the Encoder step is as simple as a dictionary lookup: given the token, return the pre-computed token embedding. Consequently, inference is no longer bottlenecked by the Encoder phase, resulting in speedups of several orders of magnitude. This blog post shows that the hit on quality can be quite small!

## Our Method

We set out to revisit Static Embedding models, using modern techniques to train them. Most of our gains come from the use of a contrastive learning loss function, as we'll explain shortly. Optionally, we can get additional speed improvements by using Matryoshka Representation Learning, which makes it possible to use truncated versions of the embedding vectors.

We'll be using the Sentence Transformers library for training. For a more general overview of how this library can be used to train embedding models, consider reading the Training and Finetuning Embedding Models with Sentence Transformers v3 blog post or the Sentence Transformers Training Overview documentation.

## Training Details

The objective with these reimagined Static Embeddings is to experiment with modern embedding model finetuning techniques on these highly efficient models. In particular, unlike GloVe and word2vec, we will be using:

**Contrastive Learning**: With most machine learning, you take an input $X$ and expect an output $Y$, then train a model such that $X$ fed through the model produces something close to $Y$. For embedding models, we don't have $Y$: we don't know beforehand what a good embedding would be. Instead, with contrastive learning, we have multiple inputs $X_1$ and $X_2$, together with a similarity. We feed both inputs through the model, after which we can contrast the two embeddings, resulting in a predicted similarity.
We can then push the embeddings further apart if the true similarity is low, or pull them closer together if the true similarity is high.

**Matryoshka Representation Learning (MRL)**: Matryoshka Embedding Models (see the blog post) are a clever training approach that allows users to truncate embeddings to smaller dimensions at a minimal performance hit. It involves applying the contrastive loss function not just to the full-size embeddings, but also to truncated versions of them. Consequently, the model learns to store information primarily at the start of the embeddings. Truncated embeddings will be faster in downstream applications, such as retrieval, classification, and clustering.

For future research, we leave various other modern training approaches for improving data quality; see Next Steps for concrete ideas.

### Training Requirements

As shown in the Training Overview documentation in Sentence Transformers, training consists of 3 to 5 components:

- Dataset
- Loss Function
- Training Arguments (Optional)
- Evaluator (Optional)
- Trainer

In the following sections, we'll go through our thought process for each of these.

### Model Inspiration

In our experience, embedding models are either used 1) exclusively for retrieval, or 2) for every task under the sun (classification, clustering, semantic textual similarity, etc.). We set out to train one of each.

For the retrieval model, only a limited amount of multilingual retrieval training data is available, so we opted for an English-only model. In contrast, we decided to train a multilingual general similarity model because multilingual data was much easier to acquire for that task.

For these models, we would like to use the StaticEmbedding module, which implements an efficient tokenize method that avoids padding, and an efficient forward method that takes care of computing and pooling embeddings. It's as simple as using a torch EmbeddingBag, which is nothing more than an efficient Embedding (i.e.
a lookup table for embeddings) with mean pooling.

We can initialize it in a few ways:

- `StaticEmbedding.from_model2vec` to load a Model2Vec model,
- `StaticEmbedding.from_distillation` to perform Model2Vec-style distillation, or
- initializing it with a Tokenizer and an embedding dimension to get random weights.

Based on our findings, the last option works best when fully training with a large amount of data. Matching common models like all-mpnet-base-v2 or bge-large-en-v1.5, we chose an embedding dimensionality of 1024, i.e. our embedding vectors consist of 1024 values each.

#### English Retrieval

For the English Retrieval model, we rely on the google-bert/bert-base-uncased tokenizer. Initializing the model looks like this:

```python
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import StaticEmbedding
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_pretrained("google-bert/bert-base-uncased")
static_embedding = StaticEmbedding(tokenizer, embedding_dim=1024)

model = SentenceTransformer(modules=[static_embedding])
```

The first entry in the modules list must implement tokenize, and the last one must produce pooled embeddings. Both are the case here, so we're good to start training this model.
#### Multilingual Similarity

For the Multilingual Similarity model, we instead rely on the google-bert/bert-base-multilingual-uncased tokenizer; that's the only thing we change in our initialization code:

```python
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import StaticEmbedding
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_pretrained("google-bert/bert-base-multilingual-uncased")
static_embedding = StaticEmbedding(tokenizer, embedding_dim=1024)

model = SentenceTransformer(modules=[static_embedding])
```

### Training Dataset Selection

Alongside dozens of Sentence Transformer models, the Sentence Transformers organization on Hugging Face also hosts 70+ datasets (at the time of writing): see the Embedding Model Datasets listing. Beyond that, many datasets have been tagged with sentence-transformers to mark that they're useful for training embedding models: see the datasets with the sentence-transformers tag.

#### English Retrieval

For the English Retrieval datasets, we are primarily looking for any dataset with:

- question-answer pairs, optionally with negatives (i.e. wrong answers) as well, and
- no overlap with the BEIR benchmark, a.k.a. the Retrieval tab on MTEB. Our goal is to avoid training on these datasets so we can use MTEB as a 0-shot benchmark.

We selected the following datasets:

- gooaq
- msmarco - the "triplet" subset
- squad
- s2orc - the "title-abstract-pair" subset
- allnli - the "triplet" subset
- paq
- trivia_qa
- msmarco_10m
- swim_ir - the "en" subset
- pubmedqa - the "triplet-20" subset
- miracl - the "en-triplet-all" subset
- mldr - the "en-triplet-all" subset
- mr_tydi - the "en-triplet-all" subset

#### Multilingual Similarity

For the Multilingual Similarity datasets, we aimed for datasets with:

- parallel sentences across languages, i.e. the same text in multiple languages, or
- positive pairs, i.e. pairs with high similarity, optionally with negatives (i.e. low similarity).
We selected the following datasets as they contain parallel sentences:

- wikititles
- tatoeba
- talks
- europarl
- global_voices
- muse
- wikimatrix
- opensubtitles

And these datasets as they contain positive pairs of some kind:

- stackexchange - the "post-post-pair" subset
- quora - the "triplet" subset
- wikianswers_duplicates
- all_nli - the "triplet" subset
- simple_wiki
- altlex
- flickr30k_captions
- coco_captions
- nli_for_simcse
- negation

#### Code

Loading these datasets is rather simple, e.g.:

```python
from datasets import load_dataset, Dataset

gooaq_dataset = load_dataset("sentence-transformers/gooaq", split="train")
gooaq_dataset_dict = gooaq_dataset.train_test_split(test_size=10_000, seed=12)
gooaq_train_dataset: Dataset = gooaq_dataset_dict["train"]
gooaq_eval_dataset: Dataset = gooaq_dataset_dict["test"]

print(gooaq_train_dataset)
"""
Dataset({
    features: ['question', 'answer'],
    num_rows: 3002496
})
"""
print(gooaq_eval_dataset)
"""
Dataset({
    features: ['question', 'answer'],
    num_rows: 10000
})
"""
```

The gooaq dataset doesn't already have a train-eval split, so we can make one with train_test_split. Otherwise, we can just load a precomputed split with e.g. split="eval".

Note that train_test_split does mean that the dataset has to be loaded into memory, whereas it is otherwise just kept on disk. This increased memory usage is not ideal when training, so it's recommended to 1) load the data, 2) split it, and 3) save it to disk with save_to_disk. Before training, you can then use load_from_disk to load it again.

### Loss Function Selection

Within Sentence Transformers, your loss function must match your training data format. The Loss Overview is designed as a guide to which losses are compatible with which formats.
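Loss compatibility is determined purely by the dataset's column layout. As an illustration, rows in the label-free formats used in this training run might look like this (column names and texts below are made up for the example):

```python
# Illustrative rows for label-free contrastive data formats.
pair = {"anchor": "What is the capital of France?",
        "positive": "Paris is the capital of France."}

# A triplet adds one known-wrong answer.
triplet = {**pair, "negative": "Berlin is the capital of Germany."}

# Some datasets ship several negatives per anchor.
multiple_negatives = {**pair,
                      "negative_1": "Berlin is the capital of Germany.",
                      "negative_2": "Madrid is the capital of Spain."}
```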
In particular, we currently have the following formats in our data:

- (anchor, positive) pair, no label
- (anchor, positive, negative) triplet, no label
- (anchor, positive, negative_1, ..., negative_n) tuple, no label

For these formats, we have some excellent choices:

1. **MultipleNegativesRankingLoss (MNRL)**: Also known as in-batch negatives loss or InfoNCE loss, this loss has been used to train modern embedding models for a handful of years. In short, the loss optimizes the following: given an anchor (e.g. a question), assign the highest similarity to the corresponding positive (i.e. answer) out of all positives and negatives (e.g. all answers) in the batch. If you provide the optional negatives, they will only be used as extra options (also known as in-batch negatives) from which the model must pick the correct positive. Within reason, the harder this "picking" is, the stronger the model will become. Because of this, higher batch sizes result in more in-batch negatives, which then increases performance (to a point).

2. **CachedMultipleNegativesRankingLoss (CMNRL)**: This is an extension of MNRL that implements GradCache, an approach that allows arbitrarily increasing the batch size without increasing memory usage. This loss is recommended over MNRL unless you can already fit a large enough batch size in memory with just MNRL. In that case, you can use MNRL to save the 20% training speed cost that CMNRL adds.

3. **GISTEmbedLoss (GIST)**: This is also an extension of MNRL; it uses a guide Sentence Transformer model to remove potential false negatives from the list of options from which the model must "pick" the correct positive. False negatives can hurt performance, but hard true negatives (texts that are close to correct, but not quite) can help performance, so this filtering is a fine line to walk.
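The in-batch negatives idea behind MNRL can be sketched in a few lines of plain Python. This is a toy sketch with made-up 2-dimensional embeddings: the real loss operates on model outputs, compares every anchor against every candidate, and applies a similarity scaling factor, but the cross-entropy-over-in-batch-candidates structure is the same.

```python
import math

# Toy sketch of in-batch negatives (MNRL / InfoNCE) with made-up 2D embeddings.
anchors   = [[1.0, 0.0], [0.0, 1.0]]   # e.g. two questions in a batch
positives = [[0.9, 0.1], [0.1, 0.9]]   # their answers, in the same order

def mnrl_like_loss(anchors, positives):
    total = 0.0
    for i, anchor in enumerate(anchors):
        # Similarity of this anchor to every positive in the batch;
        # positives belonging to *other* anchors act as in-batch negatives.
        scores = [sum(x * y for x, y in zip(anchor, p)) for p in positives]
        exps = [math.exp(s) for s in scores]
        # Cross-entropy: push the probability of the true pair (index i) towards 1
        total += -math.log(exps[i] / sum(exps))
    return total / len(anchors)

print(round(mnrl_like_loss(anchors, positives), 4))  # ≈ 0.3711
```

A larger batch adds more terms to `scores`, making the "picking" harder, which is why bigger batch sizes tend to help up to a point.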
Because these static embedding models are extremely small, it is possible to fit our desired batch size of 2048 samples on our hardware (a single RTX 3090 with 24GB of VRAM), so we don't need to use CMNRL. Additionally, because we're training such fast models, the guide model from the GISTEmbedLoss would make training much slower. Because of this, we've opted to use MultipleNegativesRankingLoss for our models.

If we were to try these experiments again, we would pick a larger batch size, e.g. 16384 with CMNRL. If you try, please let us know how it goes!

#### Code

The usage is rather simple:

```python
from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.models import StaticEmbedding
from tokenizers import Tokenizer

# Prepare a model to train
tokenizer = Tokenizer.from_pretrained("google-bert/bert-base-uncased")
static_embedding = StaticEmbedding(tokenizer, embedding_dim=1024)
model = SentenceTransformer(modules=[static_embedding])

# Initialize the MNRL loss given the model
loss = MultipleNegativesRankingLoss(model)
```

### Matryoshka Representation Learning

Beyond regular loss functions, Sentence Transformers also implements a handful of loss modifiers. These work on top of standard loss functions, but apply them in different ways to try and instil useful properties into the trained embedding model.

A very interesting one is the MatryoshkaLoss, which turns the trained model into a Matryoshka Model. This allows users to truncate the output embeddings at a minimal loss of performance, meaning that retrieval or clustering can be sped up thanks to the smaller dimensionalities.

#### Code

The MatryoshkaLoss is applied on top of a normal loss.
It's recommended to also include the normal embedding dimensionality in the list of matryoshka_dims:

```python
from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import MultipleNegativesRankingLoss, MatryoshkaLoss
from sentence_transformers.models import StaticEmbedding
from tokenizers import Tokenizer

# Prepare a model to train
tokenizer = Tokenizer.from_pretrained("google-bert/bert-base-uncased")
static_embedding = StaticEmbedding(tokenizer, embedding_dim=1024)
model = SentenceTransformer(modules=[static_embedding])

# Initialize the MNRL loss given the model, then wrap it in MatryoshkaLoss
base_loss = MultipleNegativesRankingLoss(model)
loss = MatryoshkaLoss(model, base_loss, matryoshka_dims=[1024, 768, 512, 256, 128, 64, 32])
```

### Training Arguments Selection

Sentence Transformers supports a lot of training arguments, the most valuable of which are listed in the Training Overview > Training Arguments documentation. We used the same core training parameters to train both models:

- num_train_epochs: 1
  - We have sufficient data; should we want to train for more, we can add more data instead of training with the same data multiple times.
- per_device_tra