metadata
base_model: google/xtr-base-en
license: apache-2.0
tags:
- arxiv:2304.01982
XTR-ONNX
This is Google's XTR-base-en model exported to ONNX format.
Original XTR model: https://huggingface.co/google/xtr-base-en
For inputs of up to 512 tokens, the model outputs a 128-dimensional vector for each token.
XTR's demo notebook uses only one special token: EOS.
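Because the model emits one vector per token position, a padded batch also carries vectors for the padding positions. The sketch below shows one way to keep only the real tokens, assuming the output has shape (batch, seq_length, 128) and that the tokenizer's attention mask marks real tokens with 1. The helper name `select_token_vectors` is hypothetical, and random numbers stand in for real model output.

```python
import numpy as np


def select_token_vectors(contextual, attention_mask):
    """Return one (num_real_tokens, dim) array per input sequence.

    Hypothetical helper: it assumes `contextual` has shape
    (batch, seq_length, dim) and `attention_mask` has shape
    (batch, seq_length) with 1 for real tokens and 0 for padding.
    """
    contextual = np.asarray(contextual)
    mask = np.asarray(attention_mask).astype(bool)
    if contextual.shape[:2] != mask.shape:
        raise ValueError("attention_mask must match the first two axes of contextual")
    # Boolean indexing drops the rows that belong to padding positions.
    return [seq[m] for seq, m in zip(contextual, mask)]


# Stand-in data: random values in place of real model output.
rng = np.random.default_rng(0)
contextual = rng.standard_normal((2, 6, 128)).astype(np.float32)
attention_mask = np.array([
    [1, 1, 1, 1, 0, 0],  # 4 real tokens, 2 padding positions
    [1, 1, 1, 1, 1, 1],  # 6 real tokens
])

per_sequence = select_token_vectors(contextual, attention_mask)
print([v.shape for v in per_sequence])  # [(4, 128), (6, 128)]
```

Returning a list rather than a single array keeps each sequence at its own length, which matches the per-token nature of the output.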
Using this model
This model can be plugged into LintDB to index data into a database.
In LintDB
# create an XTR index
config = ldb.Configuration()
config.num_subquantizers = 64
config.dim = 128
config.nbits = 4
config.quantizer_type = ldb.IndexEncoding_XTR
index = ldb.IndexIVF("experiments/goog", config)
# build a collection on top of the index
opts = ldb.CollectionOptions()
opts.model_file = "assets/xtr/encoder.onnx"
opts.tokenizer_file = "assets/xtr/spiece.model"
collection = ldb.Collection(index, opts)
collection.train(chunks, 50, 10)
for i, snip in enumerate(chunks):
    collection.add(0, i, snip, {'docid': f'{i}'})
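The snippet above does not show how `chunks` is built. A minimal sketch, assuming `chunks` is a list of plain-text strings, is a sliding word window over the source text. The function name `split_into_chunks` and the `max_words` and `overlap` defaults are illustrative assumptions, not part of LintDB. Word counts are only a rough proxy for the model's token limit.

```python
def split_into_chunks(text, max_words=100, overlap=20):
    """Split text into overlapping word windows (hypothetical helper)."""
    if max_words <= 0:
        raise ValueError("max_words must be positive")
    if not 0 <= overlap < max_words:
        raise ValueError("overlap must satisfy 0 <= overlap < max_words")

    words = text.split()
    step = max_words - overlap
    chunks = []
    for start in range(0, len(words), step):
        chunks.append(" ".join(words[start:start + max_words]))
        # Stop once a window has reached the end of the text.
        if start + max_words >= len(words):
            break
    return chunks


sample = " ".join(str(i) for i in range(10))
print(split_into_chunks(sample, max_words=4, overlap=1))
# ['0 1 2 3', '3 4 5 6', '6 7 8 9']
```

The overlap keeps a little shared context between neighbouring chunks; set it to 0 for disjoint windows.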
Creating this model
To create this model, we combined XTR's T5 encoder with a dense layer. The code used to do this is below. Credit to yaman on GitHub for this solution.
import os
from pathlib import Path
import numpy as np
import onnx
import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer
from sentence_transformers import models
from transformers import AutoTokenizer, T5EncoderModel
# https://github.com/huggingface/optimum/issues/1519
class CombinedModel(nn.Module):
    def __init__(self, transformer_model, dense_model):
        super(CombinedModel, self).__init__()
        self.transformer = transformer_model
        self.dense = dense_model

    def forward(self, input_ids, attention_mask):
        outputs = self.transformer(input_ids, attention_mask=attention_mask)
        token_embeddings = outputs['last_hidden_state']
        # Apply the dense projection to every token embedding.
        return self.dense({'sentence_embedding': token_embeddings})
save_directory = "onnx/"
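# Load a model from transformers and export it to ONNX.
# NOTE: `path` is not defined in this snippet; it must point to a local copy of
# the XTR checkpoint (the directory that contains the `2_Dense` folder).
tokenizer = AutoTokenizer.from_pretrained(path)
# Load the T5 base encoder model.
transformer_model = T5EncoderModel.from_pretrained(path)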
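# Dense layer that projects the 768-dimensional encoder output to 128 dimensions.
dense_model = models.Dense(
    in_features=768,
    out_features=128,
    bias=False,
    activation_function=nn.Identity(),
)
# NOTE: `dense_filename` is not defined in this snippet; it is the name of the
# weights file inside the checkpoint's `2_Dense` folder.
state_dict = torch.load(os.path.join(path, '2_Dense', dense_filename))
dense_model.load_state_dict(state_dict)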
model = CombinedModel(transformer_model, dense_model)
model.eval()
input_text = "Who founded google"
inputs = tokenizer(input_text, padding='longest', truncation=True, max_length=128, return_tensors='pt')
input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']
torch.onnx.export(
    model,
    (input_ids, attention_mask),
    "combined_model.onnx",
    export_params=True,
    opset_version=17,
    do_constant_folding=True,
    input_names=['input_ids', 'attention_mask'],
    output_names=['contextual'],
    dynamic_axes={
        'input_ids': {0: 'batch_size', 1: 'seq_length'},  # variable length axes
        'attention_mask': {0: 'batch_size', 1: 'seq_length'},
        'contextual': {0: 'batch_size', 1: 'seq_length'},
    },
)
onnx.checker.check_model("combined_model.onnx")
combined_model = onnx.load("combined_model.onnx")
import onnxruntime as ort
ort_session = ort.InferenceSession("combined_model.onnx")
output = ort_session.run(None, {'input_ids': input_ids.numpy(), 'attention_mask': attention_mask.numpy()})
# Run the PyTorch model
pytorch_output = model(input_ids, attention_mask)
print(pytorch_output['sentence_embedding'])
print(output[0])
# Compare the outputs
# print("Are the outputs close?", np.allclose(pytorch_output['sentence_embedding'].detach().numpy(), output[0], atol=1e-6))
# Calculate the differences between the outputs
differences = pytorch_output['sentence_embedding'].detach().numpy() - output[0]
# Print the standard deviation of the differences
print("Standard deviation of the differences:", np.std(differences))
print("pytorch_output size:", pytorch_output['sentence_embedding'].size())
print("onnx_output size:", output[0].shape)
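The comparison at the end of the script can be wrapped in a small helper that reports several statistics at once. This is a sketch only: the function name `compare_outputs` and the default tolerance of `1e-4` are assumptions, and synthetic arrays stand in for the PyTorch and ONNX Runtime outputs.

```python
import numpy as np


def compare_outputs(reference, candidate, atol=1e-4):
    """Summarize the element-wise difference between two arrays.

    Hypothetical helper; `atol` is an assumed tolerance, not a value
    taken from the original export script.
    """
    reference = np.asarray(reference, dtype=np.float64)
    candidate = np.asarray(candidate, dtype=np.float64)
    if reference.shape != candidate.shape:
        raise ValueError(f"shape mismatch: {reference.shape} vs {candidate.shape}")
    diff = reference - candidate
    return {
        "max_abs_diff": float(np.max(np.abs(diff))),
        "std_diff": float(np.std(diff)),
        "allclose": bool(np.allclose(reference, candidate, atol=atol)),
    }


# Stand-in arrays with the same layout as the model output: (batch, seq_length, 128).
reference = np.zeros((1, 5, 128), dtype=np.float32)
candidate = reference + 1e-6
report = compare_outputs(reference, candidate)
print(report)
```

With the real outputs, the call would take `pytorch_output['sentence_embedding'].detach().numpy()` and `output[0]` as its two arguments. The maximum absolute difference is often more informative than the standard deviation, since a single badly mismatched value barely moves the latter.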