---
library_name: transformers
tags: []
---
# FastESM

## A faster half-precision version of ESM2-650 that leverages FlashAttention2

For the largest savings, use PyTorch 2.5 or later; see [SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html).
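
For context, here is a minimal sketch of what the linked SDPA function computes. It compares `torch.nn.functional.scaled_dot_product_attention` with a plain softmax attention written out by hand. The tensor shapes and random inputs are illustrative assumptions, not FastESM's actual internals.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical shapes for illustration: (batch, heads, seq_len, head_dim)
q = torch.randn(1, 2, 11, 16)
k = torch.randn(1, 2, 11, 16)
v = torch.randn(1, 2, 11, 16)

# Fused path: PyTorch selects an available backend for the device and dtype
fused = F.scaled_dot_product_attention(q, k, v)

# Reference path: softmax(Q K^T / sqrt(d)) V, materializing the full score matrix
scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
reference = torch.softmax(scores, dim=-1) @ v

# The two paths should agree up to floating-point tolerance
print(torch.allclose(fused, reference, atol=1e-4))
```

The fused call avoids materializing the full attention matrix where a suitable backend is available, which is where the savings come from. To load FastESM and embed a sequence: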

```python
import torch
from transformers import AutoModel, AutoTokenizer

model_path = 'Synthyra/FastESM2_650'
model = AutoModel.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True).eval()
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

sequence = 'MSEQWENCE'
tokenized = tokenizer(sequence, return_tensors='pt')
with torch.no_grad():
    embeddings = model(**tokenized).last_hidden_state

print(embeddings.shape) # (1, 11, 1280)
```
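
The example above returns one vector per token. If a single vector per sequence is needed, one common option is to average the token embeddings while ignoring padding. The sketch below assumes this masked mean pooling is appropriate for your task; `mean_pool` is a hypothetical helper, not part of FastESM, and it runs on dummy tensors so no model download is needed.

```python
import torch


def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token embeddings over non-padded positions (hypothetical helper)."""
    # (batch, seq_len) -> (batch, seq_len, 1) so the mask broadcasts over the hidden dim
    mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
    summed = (last_hidden_state * mask).sum(dim=1)
    # Guard against division by zero for an all-padding row
    counts = mask.sum(dim=1).clamp(min=1)
    return summed / counts


# Dummy stand-ins for model output and tokenizer mask: batch of 2, length 4, hidden size 8
hidden = torch.ones(2, 4, 8)
hidden[1, 2:] = 100.0  # values at padded positions, which the mask should exclude
mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])

pooled = mean_pool(hidden, mask)
print(pooled.shape)
```

With the real model, pass `model(**tokenized).last_hidden_state` and `tokenized['attention_mask']` from a call such as `tokenizer(sequences, padding=True, return_tensors='pt')`. Note that this average still includes the special tokens the tokenizer adds at the start and end of each sequence.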