---
license: apache-2.0
pipeline_tag: feature-extraction
tags:
- embedding
- text embedding
---
# flan-ul2-text-encoder
The encoder from flan-ul2. This model is 17.44 GB in `bfloat16` precision.
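For reference, the sketch below shows one way an encoder-only checkpoint like this can be extracted from the full seq2seq model (assuming the base model is `google/flan-ul2`). It is illustrative only and downloads the full checkpoint, so it is not needed to use this repo:

```python
# Illustrative sketch: pull the encoder out of the full seq2seq checkpoint.
# Assumes the base model is google/flan-ul2 (a T5-architecture model).
import torch
from transformers import AutoTokenizer, T5EncoderModel

encoder = T5EncoderModel.from_pretrained("google/flan-ul2", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained("google/flan-ul2")

# Save only the encoder weights locally for reuse as a text encoder
encoder.save_pretrained("flan-ul2-text-encoder")
tokenizer.save_pretrained("flan-ul2-text-encoder")
```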
## basic usage
Note: this is *a* way of using the encoder, not *the only* way. Suggestions and ideas are welcome.
This guide provides a set of functions for computing the cosine similarity between the embeddings of different texts, where the embeddings are produced by this encoder.
## Functions
### load_model_and_tokenizer
This function loads the model and tokenizer based on the given model name. It returns a tuple containing the loaded model and tokenizer.
```python
from typing import List, Tuple

import torch
from transformers import AutoModel, AutoTokenizer


def load_model_and_tokenizer(model_name: str) -> Tuple[AutoModel, AutoTokenizer]:
    """
    Load the model and tokenizer based on the given model name.

    Args:
        model_name (str): The name of the model to be loaded.

    Returns:
        Tuple[AutoModel, AutoTokenizer]: The loaded model and tokenizer.
    """
    model = AutoModel.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model.eval()  # deactivate dropout
    return model, tokenizer
```
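Note that `device_map="auto"` requires the `accelerate` package and enough memory to hold the ~17 GB encoder on your GPU(s). As a rough alternative (not part of the original example), a CPU-only load could look like:

```python
# CPU-only fallback: slower, and float32 roughly doubles the memory footprint
model = AutoModel.from_pretrained(model_name, torch_dtype=torch.float32)
model.eval()
```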
### get_embeddings
This function gets the embeddings for the given texts using the provided model and tokenizer. It returns the calculated embeddings.
```python
def get_embeddings(model: AutoModel, tokenizer: AutoTokenizer, texts: List[str]) -> torch.Tensor:
    """
    Get the embeddings for the given texts using the provided model and tokenizer.

    Args:
        model (AutoModel): The model to be used for getting embeddings.
        tokenizer (AutoTokenizer): The tokenizer to be used for tokenizing the texts.
        texts (List[str]): The texts for which embeddings are to be calculated.

    Returns:
        torch.Tensor: The calculated embeddings.
    """
    # Tokenize input texts
    batch_tokens = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

    # Get the last hidden states
    with torch.no_grad():
        last_hidden_state = model(
            **batch_tokens, output_hidden_states=True, return_dict=True
        ).last_hidden_state

    # Position weights: 1, 2, ..., seq_len (later tokens weighted more heavily)
    weights = (
        torch.arange(start=1, end=last_hidden_state.shape[1] + 1)
        .unsqueeze(0)
        .unsqueeze(-1)
        .expand(last_hidden_state.size())
        .float()
        .to(last_hidden_state.device)
    )

    # Attention mask expanded to the hidden dimension (moved to the model's device)
    input_mask_expanded = (
        batch_tokens["attention_mask"]
        .unsqueeze(-1)
        .expand(last_hidden_state.size())
        .float()
        .to(last_hidden_state.device)
    )

    # Weighted mean pooling across seq_len: (bs, seq_len, hidden_dim) -> (bs, hidden_dim)
    sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded * weights, dim=1)
    sum_mask = torch.sum(input_mask_expanded * weights, dim=1)
    embeddings = sum_embeddings / sum_mask

    return embeddings
```
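For reference, the pooling above is the position-weighted mean pooling used in SGPT: with last hidden state h_i at position i, attention-mask value m_i, and padded sequence length S, the sentence embedding is

$$
v = \frac{\sum_{i=1}^{S} i \cdot m_i \cdot h_i}{\sum_{i=1}^{S} i \cdot m_i}
$$

so later (non-padding) tokens contribute more to the embedding than earlier ones.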
### calculate_cosine_similarity
This function calculates and prints the cosine similarity between the first text and all other texts. It does not return anything.
```python
from scipy.spatial.distance import cosine


def calculate_cosine_similarity(embeddings: torch.Tensor, texts: List[str]) -> None:
    """
    Calculate and print the cosine similarity between the first text and all other texts.

    Args:
        embeddings (torch.Tensor): The embeddings for the texts.
        texts (List[str]): The texts for which cosine similarity is to be calculated.
    """
    # scipy operates on numpy arrays, so move to CPU and cast to float32 first
    embeddings = embeddings.float().cpu()
    for i in range(1, len(embeddings)):
        cosine_sim = 1 - cosine(embeddings[0], embeddings[i])
        print('Cosine similarity between "%s" and "%s" is: %.3f' % (texts[0], texts[i], cosine_sim))
```
## Usage
To use these functions, you need to have the `transformers`, `accelerate`, and `scipy` libraries installed (along with `torch`). You can install them with pip:

```bash
pip install transformers accelerate scipy
```
Then, you can use the functions in your Python code as needed. For example:
```python
model_name = "pszemraj/flan-ul2-text-encoder"
model, tokenizer = load_model_and_tokenizer(model_name)

texts = [
    "deep learning",
    "artificial intelligence",
    "deep diving",
    "artificial snow",
]

embeddings = get_embeddings(model, tokenizer, texts)
calculate_cosine_similarity(embeddings, texts)
```
This will print the cosine similarity between the first text and each of the other texts in the `texts` list.
## Customization
You can customize the inputs by modifying the `texts` list. You can also use a different model by changing the `model_name` variable.
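If you want to embed many texts, a single forward pass over the whole list may not fit in memory. A simple batching helper (hypothetical, built on top of `get_embeddings`) could look like this:

```python
def embed_in_batches(
    model: AutoModel, tokenizer: AutoTokenizer, texts: List[str], batch_size: int = 8
) -> torch.Tensor:
    # Embed the texts in chunks and concatenate the per-chunk embeddings
    chunks = [
        get_embeddings(model, tokenizer, texts[i : i + batch_size])
        for i in range(0, len(texts), batch_size)
    ]
    return torch.cat(chunks, dim=0)
```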
## References
This guide is based on the examples provided in the [SGPT repository](https://github.com/Muennighoff/sgpt).