
Modified version of xlm-roberta-flash-implementation for ONNX conversion.

## Brief Summary of Challenges and Modifications

### Dynamic Matrix Calculation in RoPE

The original RoPE implementation did not compute the full rotation matrix up front. Instead, it computed the matrix only for the sequence length seen so far, cached it, and recomputed it whenever a longer sequence arrived. This is incompatible with ONNX, which requires a static computation graph at inference time. To solve this, I now calculate the entire rotation matrix in advance, for the maximum supported sequence length.
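Roughly, the precomputation looks like this (a minimal sketch; the class, buffer names, and default lengths are illustrative, not the exact source):

```python
import torch

class RotaryEmbedding(torch.nn.Module):
    # Illustrative sketch: precompute the cos/sin tables once for the full
    # supported length, so the graph never has to grow a cache at runtime.
    def __init__(self, dim: int, max_position_embeddings: int = 8192, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_position_embeddings).float()
        freqs = torch.outer(t, inv_freq)         # (max_len, dim / 2)
        emb = torch.cat((freqs, freqs), dim=-1)  # (max_len, dim)
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    def forward(self, x: torch.Tensor):
        # Slice the precomputed tables instead of recomputing and re-caching,
        # so the exported graph is identical for every sequence length.
        seq_len = x.shape[1]
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
```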

### Custom Backward Functions for RoPE

The original code defines custom forward and backward functions for RoPE. ONNX does not support custom backward functions, but since ONNX inference only needs the forward pass, I removed the backward function completely.
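Conceptually, the change looks like this (a sketch with illustrative names; the original wraps the rotation in a `torch.autograd.Function` with a hand-written backward):

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Standard RoPE helper: swap and negate the two halves of the last dim.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

# Before (sketch): apply_rotary = ApplyRotaryEmb.apply, an autograd.Function
# whose hand-written backward() the ONNX exporter cannot represent.
# After: a plain function, so the export only ever traces the forward math.
def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    return x * cos + rotate_half(x) * sin
```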

### ONNX Model Size Limitation

ONNX stores the model in protobuf format, which has a maximum file size of 2 GB. Our model is too large to fit within this limit, so I had to store the model's parameters as external data files alongside the graph.
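One way to do this with the `onnx` package (a sketch; the file names are illustrative):

```python
import onnx

# Re-save the exported graph with the weights moved to a side-car file,
# keeping the .onnx protobuf itself under the 2 GB limit.
model_proto = onnx.load("model.onnx")
onnx.save_model(
    model_proto,
    "model.onnx",
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    location="model.onnx_data",  # written next to model.onnx
    size_threshold=1024,         # externalize tensors larger than 1 KB
)
```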

### Lack of Support for the unique() Function

We used the unique() function to identify the distinct task types in a batch, which matters when a single batch mixes task types. ONNX, however, does not support unique(). Since mixing task types in one batch is not important for inference, I modified the code to take a task_id argument (a single integer that applies to every text in the batch) instead of adapter_mask, a tensor that specified an independent task ID for each text.
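Conceptually (a self-contained sketch with made-up shapes and names, not the actual adapter code):

```python
import torch

NUM_TASKS, HIDDEN, RANK = 5, 1024, 4
lora_a = torch.randn(NUM_TASKS, HIDDEN, RANK)  # illustrative per-task adapter weights

# Before (sketch): adapter_mask gave one task ID per text, and unique() found
# the distinct tasks in the batch -- unsupported by the ONNX exporter:
#   unique_tasks = torch.unique(adapter_mask)
#   for task in unique_tasks: ...apply that task's adapter to its rows...

# After (sketch): a single integer task_id covers the whole batch, and the
# adapter is selected by plain indexing, which exports cleanly.
def select_adapter(task_id: int) -> torch.Tensor:
    return lora_a[task_id]
```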

## Code

```python
import torch
import torch.onnx
from transformers import AutoModel, AutoTokenizer

# Load the modified model; flash attention is disabled because its custom
# kernels cannot be traced for ONNX export.
model = AutoModel.from_pretrained(
    '/home/admin/saba/jina-embeddings-v3',
    trust_remote_code=True,
    use_flash_attn=False,
)
model.eval()

onnx_path = "/home/admin/saba/jina-embeddings-v3/onnx/model.onnx"

# Build a small example batch to trace the graph.
tokenizer = AutoTokenizer.from_pretrained('/home/admin/saba/jina-embeddings-v3')
inputs = tokenizer(["jina", "ai"], return_tensors="pt", padding="longest")
inps = inputs['input_ids']
mask = inputs['attention_mask']
task_id = 2  # a single task ID for the whole batch (see above)

torch.onnx.export(
    model,
    (inps, mask, task_id),
    onnx_path,
    export_params=True,
    do_constant_folding=True,
    input_names=['input_ids', 'attention_mask', 'task_id'],
    output_names=['text_embeds'],
    opset_version=16,
    # Batch size and sequence length stay dynamic; everything else is fixed.
    dynamic_axes={
        'input_ids': {0: 'batch_size', 1: 'sequence_length'},
        'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
        'text_embeds': {0: 'batch_size'},
    },
)
```
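To sanity-check the export, the model can be run with ONNX Runtime. A usage sketch (assuming onnxruntime is installed and that task_id was captured as a scalar int64 graph input):

```python
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

session = ort.InferenceSession("/home/admin/saba/jina-embeddings-v3/onnx/model.onnx")
tokenizer = AutoTokenizer.from_pretrained('/home/admin/saba/jina-embeddings-v3')
batch = tokenizer(["jina", "ai"], return_tensors="np", padding="longest")

embeddings, = session.run(
    ["text_embeds"],
    {
        "input_ids": batch["input_ids"].astype(np.int64),
        "attention_mask": batch["attention_mask"].astype(np.int64),
        "task_id": np.array(2, dtype=np.int64),
    },
)
print(embeddings.shape)  # (batch_size, embedding_dim)
```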