---
library_name: transformers
tags: []
---

# FastESM

A faster, half-precision version of ESM2-650 that leverages FlashAttention2.

Requires PyTorch 2.5+ for the largest speedups; see PyTorch's scaled dot product attention (SDPA) documentation.

Because SDPA does not return attention weights, outputting attentions and predicting contacts are not supported. Various other optimizations also make this implementation slightly different from the base Hugging Face one.
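For context, a minimal sketch of the SDPA call that replaces the explicit attention computation. `torch.nn.functional.scaled_dot_product_attention` dispatches to fused kernels (including FlashAttention) when the hardware and dtypes allow, and returns only the attention output, never the attention weights, which is why attentions and contact prediction are unavailable. The tensor shapes below are illustrative:

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, num_heads, seq_len, head_dim)
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# Fused attention: returns only the output, not the weight matrix.
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 8, 16, 64])
```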

```python
import torch
from transformers import AutoModel, AutoTokenizer

model_path = 'Synthyra/FastESM2_650'
model = AutoModel.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True).eval()
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

sequence = 'MSEQWENCE'
tokenized = tokenizer(sequence, return_tensors='pt')
with torch.no_grad():
    embeddings = model(**tokenized).last_hidden_state

print(embeddings.shape)  # (1, 11, 1280)