---
library_name: transformers
tags: []
---
# FastESM
A faster half-precision version of ESM2-650 that leverages FlashAttention2 through PyTorch's scaled dot-product attention (SDPA). Requires PyTorch 2.5+ for the largest speedups; see the SDPA documentation for details.
```python
import torch
from transformers import AutoModel, AutoTokenizer

model_path = 'Synthyra/FastESM2_650'
model = AutoModel.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True).eval()
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

sequence = 'MSEQWENCE'
tokenized = tokenizer(sequence, return_tensors='pt')
with torch.no_grad():
    embeddings = model(**tokenized).last_hidden_state

print(embeddings.shape)  # (1, 11, 1280)
```
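The `last_hidden_state` above holds one vector per token (including special tokens). If you need a single fixed-length embedding per protein, a common approach is to mean-pool over non-padding positions using the attention mask. A minimal sketch of that pooling step, using NumPy arrays as stand-ins for the `last_hidden_state` and `attention_mask` tensors produced above (the dummy shapes here are illustrative, not from the model):

```python
import numpy as np

# Dummy stand-ins: batch of 2 sequences, max length 5, hidden size 4.
# In practice these come from model(**tokenized).last_hidden_state and
# tokenized['attention_mask'].
hidden = np.random.rand(2, 5, 4).astype(np.float32)
mask = np.array([[1, 1, 1, 0, 0],
                 [1, 1, 1, 1, 1]], dtype=np.float32)

# Zero out padding positions, then average over the real tokens only.
masked = hidden * mask[:, :, None]
pooled = masked.sum(axis=1) / mask.sum(axis=1, keepdims=True)

print(pooled.shape)  # (2, 4): one embedding per sequence
```

The same two lines work on the torch tensors from the snippet above (`mask.unsqueeze(-1)` in place of `mask[:, :, None]`).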