---
datasets:
- tattabio/OMG
license: apache-2.0
---
# gLM2_650M_embed

gLM2_embed is a fine-tuned version of [`tattabio/gLM2_650M`](https://huggingface.co/tattabio/gLM2_650M) for embedding and retrieval. Fine-tuning proceeds in two stages:

- The first stage fine-tunes gLM2 for one epoch on UniRef50.
- The second stage trains an adapter layer to align mean-pooled representations with AlphaFold structural [clusters](https://www.nature.com/articles/s41586-023-06510-w).
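
To make "mean-pooled" concrete, here is a minimal, dependency-free sketch. It assumes that mean pooling means averaging per-token hidden states over non-padding positions; the model's own pooling code is not shown in this card, and `mean_pool` is a hypothetical helper, not part of the released model.

```python
def mean_pool(token_vectors, attention_mask):
    """Average the token vectors whose mask entry is 1 (padding is skipped).

    Hypothetical helper for illustration only; the real model pools internally.
    """
    kept = [vec for vec, keep in zip(token_vectors, attention_mask) if keep]
    if not kept:
        raise ValueError("attention_mask selects no tokens")
    dim = len(kept[0])
    return [sum(vec[i] for vec in kept) / len(kept) for i in range(dim)]


# Toy input: three tokens with 2-dimensional hidden states; the last is padding.
hidden = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
mask = [1, 1, 0]
pooled = mean_pool(hidden, mask)
print(pooled)  # [2.0, 3.0]
```

The padded position does not affect the result, so sequences of different lengths in one batch yield comparable vectors.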

## Getting Started
```python
import torch
from transformers import AutoModel, AutoTokenizer
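# Load the model and tokenizer. `trust_remote_code=True` is needed because the
# architecture is defined by custom code in the model repository.
# `.cuda()` assumes a CUDA-capable GPU is available.
model = AutoModel.from_pretrained('tattabio/gLM2_650M_embed', torch_dtype=torch.bfloat16, trust_remote_code=True).cuda()
tokenizer = AutoTokenizer.from_pretrained('tattabio/gLM2_650M_embed', trust_remote_code=True)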

# NOTE: Prepend with `<+>` to match gLM2 pre-training.
sequence = "<+>MALTKVEKRNRIKRRVRGKISGTQASPRLSVYKSNK"

# Tokenize the sequence.
encodings = tokenizer([sequence], return_tensors='pt')
# Extract embeddings.
with torch.no_grad():
    embeddings = model(encodings.input_ids.cuda()).pooler_output

print(embeddings.shape)  # torch.Size([1, 512])
```
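
Since the model targets retrieval, a common next step is to rank candidate sequences by cosine similarity to a query embedding. The sketch below is dependency-free and uses toy 3-dimensional vectors in place of the 512-dimensional `pooler_output`; `cosine_similarity`, `rank_by_similarity`, and the `seq_*` names are hypothetical and chosen for illustration. Real embeddings could be converted to lists with `embeddings.float().cpu().tolist()`.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def rank_by_similarity(query, candidates):
    """Return (name, score) pairs sorted from most to least similar."""
    scores = [(name, cosine_similarity(query, vec)) for name, vec in candidates.items()]
    return sorted(scores, key=lambda item: item[1], reverse=True)


# Toy stand-ins for embeddings (assumption: real ones come from `pooler_output`).
query = [1.0, 0.0, 0.0]
candidates = {
    "seq_a": [0.9, 0.1, 0.0],
    "seq_b": [0.0, 1.0, 0.0],
    "seq_c": [-1.0, 0.0, 0.0],
}
ranking = rank_by_similarity(query, candidates)
print([name for name, _ in ranking])  # ['seq_a', 'seq_b', 'seq_c']
```

For large candidate sets, computing the scores on tensors directly, for example with `torch.nn.functional.cosine_similarity`, would be considerably faster than this pure-Python loop.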