---
library_name: transformers
tags: []
---
# FastESM
A faster half-precision version of ESM2-650 that leverages FlashAttention2 through PyTorch's scaled dot product attention (SDPA).
Requires PyTorch 2.5+ for the most savings; see the SDPA documentation.
Because SDPA does not return attention weights, outputting attentions and predicting contacts are not possible. Various other optimizations also make the base implementation slightly different from the HF one.
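To see why attention weights are unavailable, here is a minimal sketch of SDPA itself, using toy tensors (not the model's actual shapes, and float32 for portability; the model runs these in half precision on GPU):

```python
import torch
import torch.nn.functional as F

# Toy shapes standing in for one attention call: (batch, heads, seq_len, head_dim)
q = torch.randn(1, 2, 11, 64)
k = torch.randn(1, 2, 11, 64)
v = torch.randn(1, 2, 11, 64)

# SDPA fuses softmax(QK^T / sqrt(d)) @ V into one kernel and, on supported
# hardware, dispatches to FlashAttention. It returns only the attention output,
# never the attention matrix itself, so contact prediction cannot be supported.
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 2, 11, 64])
```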
```python
import torch
from transformers import AutoModel, AutoTokenizer

model_path = 'Synthyra/FastESM2_650'
model = AutoModel.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True).eval()
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

sequence = 'MSEQWENCE'
tokenized = tokenizer(sequence, return_tensors='pt')
with torch.no_grad():
    embeddings = model(**tokenized).last_hidden_state

print(embeddings.shape)  # (1, 11, 1280)
```
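The model returns one embedding per token. If you want a single vector per sequence, a common approach (not part of this repo) is masked mean pooling over the non-padding positions. A sketch with a dummy tensor standing in for `last_hidden_state`:

```python
import torch

# Dummy stand-ins for the model output above: batch of 1, 11 tokens, 1280 dims.
last_hidden_state = torch.randn(1, 11, 1280)
# In real use this comes from tokenized['attention_mask']; here no padding.
attention_mask = torch.ones(1, 11, dtype=torch.long)

# Zero out padding positions, then average over the remaining real tokens.
mask = attention_mask.unsqueeze(-1).float()
pooled = (last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1)
print(pooled.shape)  # torch.Size([1, 1280])
```

With batches of different-length sequences, pass `padding=True` to the tokenizer and reuse its attention mask here so padded positions do not dilute the average.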