---
library_name: transformers
tags: []
---

# FastESM

A faster half-precision version of ESM2-650 that leverages FlashAttention2.

Requires PyTorch 2.5+ for the largest speedups, which come from PyTorch's scaled dot product attention (SDPA).
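As a rough illustration of what SDPA fuses under the hood, the sketch below runs `torch.nn.functional.scaled_dot_product_attention` on dummy tensors; the shapes here are arbitrary examples, not ESM2-650's actual dimensions.

```python
import torch
import torch.nn.functional as F

# Dummy attention inputs: (batch, heads, seq_len, head_dim).
# Shapes are illustrative only.
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# SDPA dispatches to a fused kernel (e.g. FlashAttention) when available,
# avoiding materializing the full (seq_len x seq_len) attention matrix.
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 8, 16, 64])
```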

```python
import torch
from transformers import AutoModel, AutoTokenizer

model_path = 'Synthyra/FastESM2_650'
# Load in half precision; trust_remote_code is required for the custom model class.
model = AutoModel.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True).eval()
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

sequence = 'MSEQWENCE'
tokenized = tokenizer(sequence, return_tensors='pt')
with torch.no_grad():
    embeddings = model(**tokenized).last_hidden_state

print(embeddings.shape)  # torch.Size([1, 11, 1280]) - 9 residues + CLS and EOS tokens
```
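The model returns one 1280-dimensional vector per token. If you need a single vector per sequence, a common approach is mask-aware mean pooling; `mean_pool` below is a hypothetical helper (not part of this repository), shown on dummy tensors with the same shapes as the example above.

```python
import torch

def mean_pool(last_hidden_state, attention_mask):
    """Average token embeddings into one vector per sequence, ignoring padding."""
    mask = attention_mask.unsqueeze(-1).float()       # (batch, seq_len, 1)
    summed = (last_hidden_state * mask).sum(dim=1)    # (batch, hidden)
    counts = mask.sum(dim=1).clamp(min=1.0)           # (batch, 1)
    return summed / counts

# Stand-ins for model output and tokenizer attention mask.
embeddings = torch.randn(1, 11, 1280)
attention_mask = torch.ones(1, 11, dtype=torch.long)

pooled = mean_pool(embeddings, attention_mask)
print(pooled.shape)  # torch.Size([1, 1280])
```

With an all-ones mask this reduces to a plain mean over the sequence dimension; the mask only matters for batches with padded sequences.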