XTR-ONNX

This model is Google's XTR-base-en model exported to ONNX format.

Original XTR model: https://huggingface.co/google/xtr-base-en

Given an input of up to 512 tokens, the model outputs a 128-dimensional vector for each token.

XTR's demo notebook uses only one special token: EOS.
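
For a quick sanity check outside LintDB, the exported graph can be run directly with onnxruntime and a SentencePiece tokenizer. This is a minimal sketch: the file locations (assets/xtr/encoder.onnx and assets/xtr/spiece.model, matching the LintDB example below) are assumptions about where you saved this repo's files, and the EOS id is taken from the tokenizer itself.

import numpy as np
import onnxruntime as ort
import sentencepiece as spm

# assumed local copies of this repo's model and tokenizer files
sp = spm.SentencePieceProcessor(model_file="assets/xtr/spiece.model")
session = ort.InferenceSession("assets/xtr/encoder.onnx")

# encode the text and append the single special token XTR uses: EOS
ids = sp.encode("Who founded google?") + [sp.eos_id()]
input_ids = np.array([ids], dtype=np.int64)
attention_mask = np.ones_like(input_ids)

(token_embeddings,) = session.run(
    None, {"input_ids": input_ids, "attention_mask": attention_mask}
)
print(token_embeddings.shape)  # (1, num_tokens, 128)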

Using this model

This model can be plugged into LintDB, which uses it to embed and index documents.

In LintDB

import lintdb as ldb  # LintDB Python bindings

# create an XTR index: 128-d token embeddings, product-quantized
# with 64 subquantizers at 4 bits each
config = ldb.Configuration()
config.num_subquantizers = 64
config.dim = 128
config.nbits = 4
config.quantizer_type = ldb.IndexEncoding_XTR
index = ldb.IndexIVF("experiments/goog", config)

# build a collection on top of the index, pointing it at the
# ONNX encoder and SentencePiece tokenizer from this repo
opts = ldb.CollectionOptions()
opts.model_file = "assets/xtr/encoder.onnx"
opts.tokenizer_file = "assets/xtr/spiece.model"

collection = ldb.Collection(index, opts)

# train the index on a list of text passages ('chunks')
collection.train(chunks, 50, 10)

# add each passage under tenant 0 with its id and metadata
for i, snip in enumerate(chunks):
    collection.add(0, i, snip, {'docid': f'{i}'})
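
Once the collection is trained and populated, it can be queried with plain text. The snippet below is a hypothetical sketch extrapolated from the add() call above; the search method name, argument order, and result shape are assumptions, so check the LintDB documentation for the exact API.

# hypothetical: retrieve the top 10 passages for tenant 0 (signature assumed)
results = collection.search(0, "who founded google?", 10)
for result in results:
    print(result)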

Creating this model

To create this model, we combined XTR's T5 encoder with its dense projection layer. Below is the code used to do this. Credit to yaman on GitHub for this solution.

import os

import numpy as np
import onnx
import torch
import torch.nn as nn
from sentence_transformers import models
from transformers import AutoTokenizer, T5EncoderModel

# https://github.com/huggingface/optimum/issues/1519

class CombinedModel(nn.Module):
    """T5 encoder followed by XTR's 768 -> 128 dense projection."""
    def __init__(self, transformer_model, dense_model):
        super().__init__()
        self.transformer = transformer_model
        self.dense = dense_model

    def forward(self, input_ids, attention_mask):
        outputs = self.transformer(input_ids, attention_mask=attention_mask)
        token_embeddings = outputs['last_hidden_state']
        # sentence_transformers' Dense module operates on a feature dict and
        # returns it with 'sentence_embedding' replaced by the projected tensor
        return self.dense({'sentence_embedding': token_embeddings})


# 'path' should point to a local clone of the base model repo
# (https://huggingface.co/google/xtr-base-en), which contains the
# tokenizer, the T5 encoder weights, and the 2_Dense projection weights
path = "xtr-base-en"

# load the tokenizer and the T5 base encoder model
tokenizer = AutoTokenizer.from_pretrained(path)
transformer_model = T5EncoderModel.from_pretrained(path)

# XTR's projection head: 768 -> 128, no bias, no activation
dense_model = models.Dense(
    in_features=768,
    out_features=128,
    bias=False,
    activation_function=nn.Identity()
)

# load the projection weights from the repo's 2_Dense module
# (the filename is assumed to be the module's pytorch_model.bin)
dense_filename = "pytorch_model.bin"
state_dict = torch.load(os.path.join(path, '2_Dense', dense_filename))
dense_model.load_state_dict(state_dict)

model = CombinedModel(transformer_model, dense_model)

model.eval()

input_text = "Who founded google"
inputs = tokenizer(input_text, padding='longest', truncation=True, max_length=128, return_tensors='pt')

input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']

torch.onnx.export(
    model,
    (input_ids, attention_mask),
    "combined_model.onnx",
    export_params=True,
    opset_version=17,
    do_constant_folding=True,
    input_names = ['input_ids', 'attention_mask'],
    output_names = ['contextual'],
    dynamic_axes={
        'input_ids': {0 : 'batch_size', 1: 'seq_length'},    # variable length axes
        'attention_mask': {0 : 'batch_size', 1: 'seq_length'},
        'contextual' : {0 : 'batch_size', 1: 'seq_length'}
    }
)

# validate the exported graph
onnx.checker.check_model("combined_model.onnx")

# run the exported model with onnxruntime
import onnxruntime as ort

ort_session = ort.InferenceSession("combined_model.onnx")
output = ort_session.run(None, {'input_ids': input_ids.numpy(), 'attention_mask': attention_mask.numpy()})


# Run the PyTorch model
pytorch_output = model(input_ids, attention_mask)
print(pytorch_output['sentence_embedding'])

print(output[0])

# Compare the outputs (note: pytorch_output is a dict, so index it first)
# print("Are the outputs close?", np.allclose(pytorch_output['sentence_embedding'].detach().numpy(), output[0], atol=1e-6))

# Calculate the differences between the outputs
differences = pytorch_output['sentence_embedding'].detach().numpy() - output[0]

# Print the standard deviation of the differences
print("Standard deviation of the differences:", np.std(differences))

print("pytorch_output size:", pytorch_output['sentence_embedding'].size())
print("onnx_output size:", output[0].shape)