license: cc-by-nc-sa-4.0

Our best attempt at reproducing RankT5 Enc-Softmax, with a few important differences:

  1. We mine negatives with a SPLADE first stage, whereas the paper uses GTR.
  2. We train with PyTorch, whereas the paper uses Flax.
  3. We use the original t5-3b. (The paper also uses t5-3b, not Flan-T5-3B, so this is not actually a difference.)
  4. The scoring head is not the same. The paper uses a single dense layer; we use Linear -> LayerNorm -> Linear and, by mistake, omit a nonlinearity between them. Fixing this could improve our results, since the extra layers are not being used as intended.
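To make difference 4 and the Enc-Softmax objective concrete, here is a minimal PyTorch sketch. `make_head`, `softmax_ce_loss`, the variant names, and the use of GELU in the "fixed" head are our own illustrative assumptions, not code from this repository or the paper. The loss assumes each training list places the positive passage at index 0, followed by its negatives. The `nn.Sequential` modules here do not share parameter names with the checkpoint, so they cannot load its weights directly.

```python
import torch
import torch.nn.functional as F
from torch import nn


def make_head(d_model: int, variant: str) -> nn.Module:
    """Hypothetical helper: map a d_model vector to a single relevance score.

    variant:
      "repro" - Linear -> LayerNorm -> Linear, the structure of this checkpoint's head
      "fixed" - same, with an activation added (GELU here is an assumption)
      "paper" - a single dense layer, as described in the RankT5 paper
    """
    if variant == "repro":
        return nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.LayerNorm(d_model, eps=1e-12),
            nn.Linear(d_model, 1),
        )
    if variant == "fixed":
        return nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),  # the nonlinearity missing from the "repro" head
            nn.LayerNorm(d_model, eps=1e-12),
            nn.Linear(d_model, 1),
        )
    if variant == "paper":
        return nn.Linear(d_model, 1)
    raise ValueError(f"unknown head variant: {variant}")


def softmax_ce_loss(scores: torch.Tensor) -> torch.Tensor:
    """Listwise softmax cross-entropy over scores of shape [num_queries, list_size].

    Assumes the relevant passage is at index 0 of every list, so the target
    class is 0 for each row; the result is averaged over queries.
    """
    targets = torch.zeros(scores.size(0), dtype=torch.long, device=scores.device)
    return F.cross_entropy(scores, targets)


if __name__ == "__main__":
    torch.manual_seed(0)
    num_queries, list_size, d_model = 2, 8, 16
    # Random stand-ins for first-token encoder states: 2 queries x 8 candidates each
    pooled = torch.randn(num_queries * list_size, d_model)
    head = make_head(d_model, "repro")
    scores = head(pooled).view(num_queries, list_size)
    print(softmax_ce_loss(scores))
```

Note that the "repro" head is not purely linear (LayerNorm divides by a per-example standard deviation), which is why we describe the missing activation as a misuse of the extra layers rather than as a head that collapses to a single dense layer.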

These differences lead to slightly worse performance than the paper reports (42.8 vs 43.? in the paper), and results on BEIR also appear slightly worse.

To use this model, first clone the Hugging Face repository (Git LFS must be installed for the weight file to be downloaded):

```
git clone https://huggingface.co/naver/trecdl22-crossencoder-rankT53b-repro
```

We then suggest loading it as follows:

```python
import torch
from transformers import T5EncoderModel, AutoTokenizer
from transformers.modeling_outputs import SequenceClassifierOutput


class T5EncoderRerank(torch.nn.Module):
    """T5 encoder with a small scoring head on the first token's hidden state."""

    def __init__(self, model_type_or_dir):
        super().__init__()
        self.model = T5EncoderModel.from_pretrained(model_type_or_dir)
        self.config = self.model.config
        # Head as trained for this checkpoint: Linear -> LayerNorm -> Linear,
        # with no activation in between (see the list of differences above).
        self.first_transform = torch.nn.Linear(self.config.d_model, self.config.d_model)
        self.layer_norm = torch.nn.LayerNorm(self.config.d_model, eps=1e-12)
        self.linear = torch.nn.Linear(self.config.d_model, 1)

    def forward(self, **kwargs):
        # Hidden state of the first input token: [batch, d_model]
        result = self.model(**kwargs).last_hidden_state[:, 0, :]
        first_transformed = self.first_transform(result)
        layer_normed = self.layer_norm(first_transformed)
        logits = self.linear(layer_normed)  # one relevance score per pair: [batch, 1]
        return SequenceClassifierOutput(logits=logits)


original_model = "t5-3b"
path_checkpoint = "trecdl22-crossencoder-rankT53b-repro/pytorch_model.bin"

print("Loading")
# Build the t5-3b encoder plus head, then replace all weights with the fine-tuned checkpoint.
model = T5EncoderRerank(original_model)
model.load_state_dict(torch.load(path_checkpoint, map_location=torch.device("cpu")))
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
model.eval()  # inference mode: disables dropout
tokenizer = AutoTokenizer.from_pretrained(original_model)
print("loaded")
```
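Once the model is loaded, reranking amounts to scoring each (query, passage) pair and sorting by score. The helper below is a hypothetical sketch: `score_pairs` is our own name, and the `Query: ... Document: ...` input template follows the RankT5 paper's description, which we assume, but have not confirmed, matches how this checkpoint was trained.

```python
import torch


@torch.no_grad()
def score_pairs(model, tokenizer, query, passages, device, max_length=512):
    """Return a 1-D tensor with one relevance score per passage for a single query.

    Assumption: inputs use the "Query: ... Document: ..." template from the
    RankT5 paper; the template used to train this checkpoint may differ.
    """
    model.eval()  # make sure dropout is disabled
    texts = [f"Query: {query} Document: {p}" for p in passages]
    batch = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    # Move input_ids / attention_mask to the model's device
    batch = {k: v.to(device) for k, v in batch.items()}
    logits = model(**batch).logits  # shape [num_passages, 1]
    return logits.squeeze(-1).cpu()
```

For example, `scores = score_pairs(model, tokenizer, query, passages, device)` followed by `scores.argsort(descending=True)` gives passage indices from most to least relevant. Because a softmax loss is unchanged when all scores for a query shift by the same constant, these scores are best used to order passages for one query rather than compared across queries; with a 3B model, long candidate lists may also need to be scored in smaller chunks.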