metadata
library_name: peft
datasets:
- xnli
license: cc-by-nc-4.0
pipeline_tag: sentence-similarity
These are LoRA adaptation weights for the mT5 encoder.
Multilingual Sentence T5
This model is a multilingual extension of Sentence T5, built on the mT5 encoder. It was proposed in this paper. It is a sentence-embedding encoder, and its performance has been verified on cross-lingual STS and sentence retrieval tasks.
Training Data
The model was trained on the XNLI dataset.
Framework versions
- PEFT 0.4.0.dev0
How to use
- If you have not installed peft, please do so.
pip install -q git+https://github.com/huggingface/transformers.git@main git+https://github.com/huggingface/peft.git
- Load the model.
from transformers import AutoTokenizer, MT5EncoderModel
from peft import PeftModel
model = MT5EncoderModel.from_pretrained("google/mt5-xxl")
model.enable_input_require_grads()
model.gradient_checkpointing_enable()
model: PeftModel = PeftModel.from_pretrained(model, "pkshatech/m-ST5")
- To obtain sentence embeddings, apply mean pooling over the encoder's token outputs.
tokenizer = AutoTokenizer.from_pretrained("google/mt5-xxl", use_fast=False)
model.eval()
texts = ["I am a dog.","You are a cat."]
inputs = tokenizer(
texts,
padding=True,
truncation=True,
return_tensors="pt",
)
outputs = model(**inputs)
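last_hidden_state = outputs.last_hidden_state
# Zero out the hidden states at padding positions so they do not contribute to the sum.
last_hidden_state[inputs.attention_mask == 0, :] = 0
# Number of real (non-padding) tokens in each sentence.
sent_len = inputs.attention_mask.sum(dim=1, keepdim=True)
# Mean pooling: sum over the token axis, divided by the token count.
sent_emb = last_hidden_state.sum(dim=1) / sent_len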
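The mean pooling step above can also be written as a small reusable function that leaves the model output unmodified. The following is a minimal sketch, not part of the original model card: the function name `mean_pool` and the dummy tensors are hypothetical, standing in for `outputs.last_hidden_state` and `inputs.attention_mask`, so the snippet runs without downloading the model. It also shows one way to compare two sentence embeddings with cosine similarity.

```python
import torch
import torch.nn.functional as F


def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token vectors over non-padding positions only (hypothetical helper)."""
    # (batch, seq_len) -> (batch, seq_len, 1) so the mask broadcasts over the hidden dimension.
    mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
    summed = (last_hidden_state * mask).sum(dim=1)
    # Clamp guards against division by zero if a row contains only padding.
    counts = mask.sum(dim=1).clamp(min=1.0)
    return summed / counts


# Dummy stand-ins for the encoder output and the tokenizer's attention mask.
hidden = torch.tensor(
    [
        [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]],  # last token is padding
        [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]],
    ]
)
mask = torch.tensor([[1, 1, 0], [1, 1, 1]])

sent_emb = mean_pool(hidden, mask)
# Cosine similarity between the first and second sentence embeddings.
similarity = F.cosine_similarity(sent_emb[0], sent_emb[1], dim=0)
```

With the real model, `hidden` and `mask` would be replaced by `outputs.last_hidden_state` and `inputs.attention_mask`.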
Benchmarks
- Tatoeba: Sentence retrieval task that pairs English sentences with sentences in other languages.
- BUCC: Bitext mining task. Each language pair consists of English and one of four languages (German, French, Russian and Chinese).
- XSTS: Cross-lingual semantic textual similarity task.
Please check the paper for more details.
Model | Tatoeba-14 | Tatoeba-36 | BUCC | XSTS (ar-ar) | XSTS (ar-en) | XSTS (es-es) | XSTS (es-en) | XSTS (tr-en) |
---|---|---|---|---|---|---|---|---|
m-ST5 | 96.3 | 94.7 | 97.6 | 76.2 | 78.6 | 84.4 | 76.2 | 75.1 |
LaBSE | 95.3 | 95.0 | 93.5 | 69.1 | 74.5 | 80.8 | 65.5 | 72.0 |
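To make the sentence retrieval setting more concrete, the sketch below shows one common way such a task can be scored: each query embedding is matched to its nearest candidate by cosine similarity, and accuracy is the fraction of correct matches. This is an assumption about a typical setup, not a description of the evaluation protocol used in the paper. The function name `retrieve` and the toy embeddings are hypothetical.

```python
import torch
import torch.nn.functional as F


def retrieve(query_emb: torch.Tensor, candidate_emb: torch.Tensor) -> torch.Tensor:
    """Return, for each query, the index of the most cosine-similar candidate."""
    q = F.normalize(query_emb, dim=-1)
    c = F.normalize(candidate_emb, dim=-1)
    # (num_queries, num_candidates) matrix of cosine similarities.
    scores = q @ c.T
    return scores.argmax(dim=1)


# Toy embeddings (assumed data): query i is a perturbed copy of candidate i.
candidates = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
queries = torch.tensor([[0.9, 0.1, 0.0], [0.1, 0.8, 0.1], [0.0, 0.2, 0.7]])

predicted = retrieve(queries, candidates)
# The gold match for query i is candidate i in this toy setup.
gold = torch.arange(len(queries))
accuracy = (predicted == gold).float().mean().item()
```

In practice, `queries` and `candidates` would be sentence embeddings produced by the model for the two sides of a parallel corpus.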