wordllama
Installation
Use the GitHub repository or install via pip: https://github.com/dleemiller/WordLlama
```bash
pip install wordllama
```
Intended Use
This model is intended for natural language processing applications that require text embeddings, such as text classification, sentiment analysis, and document clustering. It is a token embedding model comparable to word embedding models, but substantially smaller in size (16 MB for the default 256-dimensional model).
```python
from wordllama import load

wl = load()  # load the default model
similarity_score = wl.similarity("i went to the car", "i went to the pawn shop")
print(similarity_score)  # Output: 0.06641249096796882
```
Model Architecture
WordLlama is based on token embedding codebooks extracted from large language models. It is trained like a general-purpose embedding model, using MultipleNegativesRankingLoss from the sentence-transformers library together with Matryoshka Representation Learning (MRL), so that embeddings can be truncated to 64, 128, 256, 512, or 1024 dimensions.
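Because MRL trains the leading dimensions to carry most of the information, truncation amounts to keeping a prefix of the vector and re-normalizing before computing cosine similarity. The sketch below shows only that arithmetic in plain Python; the vectors are made-up toy data rather than real WordLlama embeddings, and `truncate_and_normalize` is a hypothetical helper, not part of the library's API.

```python
import math


def truncate_and_normalize(vec: list[float], dim: int) -> list[float]:
    """Keep the first `dim` components (Matryoshka-style prefix) and L2-normalize."""
    if dim > len(vec):
        raise ValueError("cannot truncate to more dimensions than the vector has")
    head = vec[:dim]
    norm = math.sqrt(sum(x * x for x in head))
    return [x / norm for x in head] if norm > 0 else head


def cosine(a: list[float], b: list[float]) -> float:
    """Dot product; equals cosine similarity because inputs are unit-length."""
    return sum(x * y for x, y in zip(a, b))


# Toy 1024-dim vectors (synthetic numbers, not model output).
full_a = [math.sin(i) for i in range(1024)]
full_b = [math.sin(i + 0.5) for i in range(1024)]

for dim in (64, 128, 256, 512, 1024):
    a = truncate_and_normalize(full_a, dim)
    b = truncate_and_normalize(full_b, dim)
    print(dim, round(cosine(a, b), 4))
```

Re-normalizing after slicing matters: a prefix of a unit vector is generally shorter than unit length, so raw dot products would not be comparable across truncation sizes.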
To create WordLlama L2 "supercat", we extract and concatenate the token embedding codebooks from several large language models that share the Llama 2 tokenizer vocabulary (32k tokens), including Llama 2 70B and Phi-3 Medium. We then add a trainable token weight parameter, initializing the weights of stopwords to a smaller value (0.1). Finally, we train a projection from the large, concatenated codebook down to a smaller dimension, followed by average pooling.
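A toy forward pass can make the "supercat" steps concrete: concatenate per-token rows from two codebooks, apply token weights (stopwords starting at 0.1), project down, and average-pool. This is a minimal sketch under stated assumptions: the vocabulary, stopword list, and dimensions are hypothetical, non-stopword weights are assumed to start at 1.0, and the token weights are assumed to act as coefficients in a weighted average. WordLlama's actual pooling may differ, and the training loop is omitted.

```python
import random

random.seed(0)

VOCAB = ["the", "a", "to", "car", "pawn", "shop", "went", "i"]  # hypothetical toy vocabulary
STOPWORDS = {"the", "a", "to", "i"}                             # hypothetical stopword list


def random_matrix(rows, cols):
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def matvec(vec, mat):
    """Row vector (length n) times matrix (n x m) -> vector of length m."""
    return [sum(v * mat[i][j] for i, v in enumerate(vec)) for j in range(len(mat[0]))]


# Stand-ins for the token embedding codebooks of two LLMs sharing one tokenizer.
codebook_model_a = random_matrix(len(VOCAB), 6)
codebook_model_b = random_matrix(len(VOCAB), 4)

# 1) Concatenate per token: each token now has a 6 + 4 = 10 dim vector.
supercat = [ra + rb for ra, rb in zip(codebook_model_a, codebook_model_b)]

# 2) Token weights: stopwords start at 0.1, others at 1.0 (assumed).
token_weight = [0.1 if tok in STOPWORDS else 1.0 for tok in VOCAB]

# 3) Projection from the concatenated width (10) to the output width (3).
projection = random_matrix(10, 3)


def embed(tokens):
    ids = [VOCAB.index(t) for t in tokens]
    projected = [matvec(supercat[i], projection) for i in ids]
    weights = [token_weight[i] for i in ids]
    total = sum(weights)
    # 4) Weighted average pooling over the token sequence.
    return [sum(w * p[j] for w, p in zip(weights, projected)) / total
            for j in range(len(projection[0]))]


print(embed(["i", "went", "to", "the", "car"]))
```

In training, `token_weight` and `projection` would be the parameters updated by the loss, while the concatenated codebook stays fixed.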
We train on popular embedding datasets from sentence-transformers, using MRL so that dimensions can be truncated. For "binary" models, we train with a straight-through estimator so that the embeddings can be binarized, e.g., (x>0).sign(), and packed into integers for Hamming distance computation.
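At inference time, the binary path reduces to three steps: threshold each component at zero, pack the bits into an integer, and count differing bits with XOR plus a popcount. The sketch below illustrates only that inference arithmetic (the straight-through estimator is a training-time detail not shown here). The helper names are hypothetical, and a production implementation would more likely pack into fixed-width words, for example with `numpy.packbits`.

```python
def binarize(vec: list[float]) -> list[int]:
    """Map each component to a bit: 1 if x > 0, else 0."""
    return [1 if x > 0 else 0 for x in vec]


def pack_bits(bits: list[int]) -> int:
    """Pack a bit list into one Python int, first bit most significant."""
    value = 0
    for b in bits:
        value = (value << 1) | b
    return value


def hamming(a: int, b: int) -> int:
    """Number of differing bits between two packed codes."""
    return bin(a ^ b).count("1")


u = [0.3, -1.2, 0.8, 0.0, -0.1, 2.5, -0.7, 0.4]
v = [0.1, 0.9, 0.5, -0.3, -0.2, 1.1, 0.6, -0.4]

code_u = pack_bits(binarize(u))  # 0b10100101
code_v = pack_bits(binarize(v))  # 0b11100110
print(hamming(code_u, code_v))   # → 3
```

A smaller Hamming distance means more similar embeddings, and XOR plus popcount over packed integers is much cheaper than floating-point dot products.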
After training, we save a new, small token embedding codebook, analogous to the vectors of a word embedding model.
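One way to see why a small codebook suffices is that a linear projection commutes with (weighted) average pooling. Projecting every row of the large codebook once, offline, gives the same result as pooling in the large space and projecting afterward. The sketch below checks that identity on toy data. It assumes the projection is purely linear, with no nonlinearity between projection and pooling; the sizes, weights, and token IDs are hypothetical and not taken from WordLlama.

```python
import random

random.seed(1)


def random_matrix(rows, cols):
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def matvec(vec, mat):
    """Row vector (length n) times matrix (n x m) -> vector of length m."""
    return [sum(v * mat[i][j] for i, v in enumerate(vec)) for j in range(len(mat[0]))]


def weighted_mean(rows, weights):
    total = sum(weights)
    return [sum(w * r[j] for w, r in zip(weights, rows)) / total for j in range(len(rows[0]))]


vocab_size, big_dim, small_dim = 5, 12, 4
large_codebook = random_matrix(vocab_size, big_dim)  # concatenated codebooks (toy)
projection = random_matrix(big_dim, small_dim)       # trained projection (toy)
token_weight = [1.0, 0.1, 1.0, 1.0, 0.1]             # hypothetical learned weights

# Offline: project each token once -> small codebook (vocab_size x small_dim).
small_codebook = [matvec(row, projection) for row in large_codebook]

token_ids = [0, 1, 3, 3, 4]
weights = [token_weight[i] for i in token_ids]

# Path 1: pool in the large space, then project.
via_large = matvec(weighted_mean([large_codebook[i] for i in token_ids], weights), projection)
# Path 2: look up rows of the small codebook, then pool.
via_small = weighted_mean([small_codebook[i] for i in token_ids], weights)

print(max(abs(a - b) for a, b in zip(via_large, via_small)))  # only float rounding error
```

Under this assumption, inference needs only table lookups and an average over a `vocab_size x small_dim` table, which is what makes the saved model comparable in size to a classic word embedding.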
MTEB Results (l2_supercat)
Metric | WL64 | WL128 | WL256 (X) | WL512 | WL1024 | GloVe 300d | Komninos | all-MiniLM-L6-v2 |
---|---|---|---|---|---|---|---|---|
Clustering | 30.27 | 32.20 | 33.25 | 33.40 | 33.62 | 27.73 | 26.57 | 42.35 |
Reranking | 50.38 | 51.52 | 52.03 | 52.32 | 52.39 | 43.29 | 44.75 | 58.04 |
Classification | 53.14 | 56.25 | 58.21 | 59.13 | 59.50 | 57.29 | 57.65 | 63.05 |
Pair Classification | 75.80 | 77.59 | 78.22 | 78.50 | 78.60 | 70.92 | 72.94 | 82.37 |
STS | 66.24 | 67.53 | 67.91 | 68.22 | 68.27 | 61.85 | 62.46 | 78.90 |
CQA DupStack | 18.76 | 22.54 | 24.12 | 24.59 | 24.83 | 15.47 | 16.79 | 41.32 |
SummEval | 30.79 | 29.99 | 30.99 | 29.56 | 29.39 | 28.87 | 30.49 | 30.81 |