---
base_model: google/xtr-base-en
license: apache-2.0
tags:
  - arxiv:2304.01982
---

XTR-ONNX

This model is Google's XTR-base-en model exported to ONNX format.

Original XTR model: https://huggingface.co/google/xtr-base-en

Given an input of up to 512 tokens, the model outputs a 128-dimensional vector for each token.
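Because the output holds one vector per token, a padded batch comes back as an array of shape `(batch_size, seq_length, 128)`, and the rows for padding positions should be dropped before indexing. A minimal NumPy sketch, assuming the model output and the tokenizer's `attention_mask` are available as arrays; `split_token_embeddings` is a hypothetical helper name, and the random values only stand in for real embeddings:

```python
import numpy as np


def split_token_embeddings(contextual: np.ndarray, attention_mask: np.ndarray) -> list:
    """Return one (num_tokens, dim) array per sequence, dropping padded positions.

    contextual: (batch_size, seq_length, dim) token embeddings from the model.
    attention_mask: (batch_size, seq_length) array of 1s (real tokens) and 0s (padding).
    """
    if contextual.shape[:2] != attention_mask.shape:
        raise ValueError("contextual and attention_mask must agree on batch and sequence axes")
    # Boolean-mask each sequence's rows so only real tokens remain.
    return [emb[mask.astype(bool)] for emb, mask in zip(contextual, attention_mask)]


# Random stand-in embeddings: two sequences padded to length 5.
rng = np.random.default_rng(0)
contextual = rng.standard_normal((2, 5, 128)).astype(np.float32)
attention_mask = np.array([[1, 1, 1, 1, 1],
                           [1, 1, 1, 0, 0]])
per_doc = split_token_embeddings(contextual, attention_mask)
print([e.shape for e in per_doc])  # [(5, 128), (3, 128)]
```

Each resulting array is the set of token vectors for one passage, which is the unit a multi-vector index stores.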

XTR's demo notebook uses only one special token: EOS (`</s>`), which is appended to the end of each input.
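T5's SentencePiece vocabulary uses id 0 for padding and id 1 for EOS (`</s>`), and no BOS token is added. The sketch below shows the resulting input layout for a batch of already-tokenized ids, truncated so that EOS still fits within 512 positions. It is illustrative only: `build_inputs` is a hypothetical helper, the token ids in the example are made up, and in practice `AutoTokenizer` handles this for you.

```python
import numpy as np

PAD_ID = 0  # T5 padding token id
EOS_ID = 1  # T5 end-of-sequence token id ("</s>")
MAX_LENGTH = 512


def build_inputs(token_ids_batch, max_length=MAX_LENGTH):
    """Truncate each sequence, append EOS, and right-pad to the longest sequence.

    Returns (input_ids, attention_mask) as int64 arrays of shape (batch, seq_length).
    """
    # Reserve one position for EOS so it is never truncated away.
    sequences = [list(ids[: max_length - 1]) + [EOS_ID] for ids in token_ids_batch]
    seq_length = max(len(s) for s in sequences)
    input_ids = np.full((len(sequences), seq_length), PAD_ID, dtype=np.int64)
    attention_mask = np.zeros((len(sequences), seq_length), dtype=np.int64)
    for row, seq in enumerate(sequences):
        input_ids[row, : len(seq)] = seq
        attention_mask[row, : len(seq)] = 1
    return input_ids, attention_mask


# Hypothetical SentencePiece ids for two short texts.
input_ids, attention_mask = build_inputs([[571, 7, 3], [8774]])
print(input_ids.tolist())       # [[571, 7, 3, 1], [8774, 1, 0, 0]]
print(attention_mask.tolist())  # [[1, 1, 1, 1], [1, 1, 0, 0]]
```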

Using this model

This model can be plugged into LintDB, which runs the ONNX encoder to embed text and index it in a database.

In LintDB

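```python
import lintdb as ldb

# Assumes `chunks` is a list of text passages to index.

# Create an XTR index. dim must match the model's 128-dimensional token embeddings.
config = ldb.Configuration()
config.num_subquantizers = 64
config.dim = 128
config.nbits = 4
config.quantizer_type = ldb.IndexEncoding_XTR
index = ldb.IndexIVF("experiments/goog", config)

# Build a collection on top of the index, pointing it at the ONNX encoder
# and the T5 SentencePiece tokenizer model.
opts = ldb.CollectionOptions()
opts.model_file = "assets/xtr/encoder.onnx"
opts.tokenizer_file = "assets/xtr/spiece.model"

collection = ldb.Collection(index, opts)

# Train the index on the passages before adding them.
collection.train(chunks, 50, 10)

# Add each passage under tenant 0, using its position as the document id
# and storing that id as metadata.
for i, snip in enumerate(chunks):
    collection.add(0, i, snip, {'docid': f'{i}'})
```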

Creating this model

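To create this model, we combined XTR's T5 encoder with its dense projection layer, which maps each 768-dimensional token state to 128 dimensions, so that a single ONNX graph produces the final token embeddings. Below is the code used to do this. Credit to yaman on GitHub for this solution.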
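Conceptually, the dense layer is a bias-free linear map with an identity activation, applied independently to every token: each 768-dimensional T5 hidden state is multiplied by a 128 × 768 weight matrix. A NumPy sketch of that shape arithmetic, with random weights standing in for the trained `2_Dense` weights (`project_tokens` is a hypothetical name):

```python
import numpy as np

HIDDEN_DIM = 768  # T5-base encoder hidden size
OUTPUT_DIM = 128  # XTR token embedding size


def project_tokens(hidden_states: np.ndarray, weight: np.ndarray) -> np.ndarray:
    """Apply a bias-free linear layer to every token.

    hidden_states: (batch, seq_length, HIDDEN_DIM) encoder outputs.
    weight: (OUTPUT_DIM, HIDDEN_DIM), the same layout as torch.nn.Linear.weight.
    Returns (batch, seq_length, OUTPUT_DIM).
    """
    # Matrix multiply broadcasts over the batch and sequence axes.
    return hidden_states @ weight.T


rng = np.random.default_rng(0)
hidden_states = rng.standard_normal((2, 7, HIDDEN_DIM))
weight = rng.standard_normal((OUTPUT_DIM, HIDDEN_DIM))  # stand-in for trained weights
projected = project_tokens(hidden_states, weight)
print(projected.shape)  # (2, 7, 128)
```

Because the projection is per token, exporting it together with the encoder keeps the sequence axis dynamic.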

```python
import os

import numpy as np
import onnx
import onnxruntime as ort
import torch
import torch.nn as nn
from sentence_transformers import models
from transformers import AutoTokenizer, T5EncoderModel

# Workaround based on https://github.com/huggingface/optimum/issues/1519


class CombinedModel(nn.Module):
    """T5 encoder followed by XTR's 768 -> 128 dense projection."""

    def __init__(self, transformer_model, dense_model):
        super().__init__()
        self.transformer = transformer_model
        self.dense = dense_model

    def forward(self, input_ids, attention_mask):
        outputs = self.transformer(input_ids, attention_mask=attention_mask)
        token_embeddings = outputs['last_hidden_state']  # (batch, seq_len, 768)
        # models.Dense reads and overwrites the 'sentence_embedding' key; here it
        # holds per-token embeddings, so the result is (batch, seq_len, 128).
        return self.dense({'sentence_embedding': token_embeddings})['sentence_embedding']


# Hypothetical local directory containing the google/xtr-base-en checkpoint.
path = "xtr-base-en"
# Assumption: name of the dense-layer weights file inside 2_Dense/; check your checkpoint.
dense_filename = "pytorch_model.bin"
save_directory = "onnx/"
os.makedirs(save_directory, exist_ok=True)
onnx_path = os.path.join(save_directory, "combined_model.onnx")

tokenizer = AutoTokenizer.from_pretrained(path)

# Load the T5 base encoder model.
transformer_model = T5EncoderModel.from_pretrained(path)

# XTR's linear projection from 768 to 128 dimensions (no bias, no activation).
dense_model = models.Dense(
    in_features=768,
    out_features=128,
    bias=False,
    activation_function=nn.Identity(),
)
state_dict = torch.load(os.path.join(path, '2_Dense', dense_filename), map_location='cpu')
dense_model.load_state_dict(state_dict)

model = CombinedModel(transformer_model, dense_model)
model.eval()

# Sample input used to trace the graph.
input_text = "Who founded google"
inputs = tokenizer(input_text, padding='longest', truncation=True, max_length=512, return_tensors='pt')
input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']

torch.onnx.export(
    model,
    (input_ids, attention_mask),
    onnx_path,
    export_params=True,
    opset_version=17,
    do_constant_folding=True,
    input_names=['input_ids', 'attention_mask'],
    output_names=['contextual'],
    dynamic_axes={
        # Variable batch size and sequence length.
        'input_ids': {0: 'batch_size', 1: 'seq_length'},
        'attention_mask': {0: 'batch_size', 1: 'seq_length'},
        'contextual': {0: 'batch_size', 1: 'seq_length'},
    },
)

onnx.checker.check_model(onnx_path)

# Run the exported model with ONNX Runtime.
ort_session = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])
onnx_output = ort_session.run(
    None, {'input_ids': input_ids.numpy(), 'attention_mask': attention_mask.numpy()}
)[0]

# Run the PyTorch model on the same input.
with torch.no_grad():
    pytorch_output = model(input_ids, attention_mask).numpy()

# Compare the two outputs.
differences = pytorch_output - onnx_output
print("pytorch_output shape:", pytorch_output.shape)
print("onnx_output shape:", onnx_output.shape)
print("Standard deviation of the differences:", np.std(differences))
print("Max absolute difference:", np.abs(differences).max())
```
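The verification step only prints the standard deviation of the differences; an explicit pass/fail check makes a broken export easier to spot. A small helper is sketched below. The name `compare_outputs` and the default tolerance of `1e-4` are assumptions: float32 results from ONNX Runtime can differ from PyTorch by rounding error, so a tolerance as tight as `1e-6` may reject a correct export. The synthetic arrays only stand in for the two model outputs.

```python
import numpy as np


def compare_outputs(reference: np.ndarray, candidate: np.ndarray, atol: float = 1e-4) -> dict:
    """Compare PyTorch (reference) and ONNX Runtime (candidate) outputs of the same shape."""
    if reference.shape != candidate.shape:
        raise ValueError(f"shape mismatch: {reference.shape} vs {candidate.shape}")
    differences = reference - candidate
    return {
        "max_abs_diff": float(np.abs(differences).max()),
        "std_diff": float(np.std(differences)),
        "allclose": bool(np.allclose(reference, candidate, atol=atol)),
    }


# Synthetic stand-ins for the two model outputs.
reference = np.ones((1, 5, 128), dtype=np.float32)
candidate = reference + np.float32(1e-5)
print(compare_outputs(reference, candidate))
```

With the export script's variables, the call would be `compare_outputs(pytorch_output, onnx_output)`.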