ESM-2 (esm2_t6_8M_UR50D) for Token Classification
This is a fine-tuned version of esm2_t6_8M_UR50D, trained on a token classification task that assigns each amino acid in a protein sequence to one of three categories: 0 (other), 1 (alpha helix), or 2 (beta strand). It was trained with this notebook and achieves 78.14% accuracy.
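For readers who want to reproduce a figure like this on their own data, the following is a minimal sketch of per-residue accuracy. It assumes the common Hugging Face convention of marking positions to skip (special tokens, padding) with the label -100; the source does not say how the notebook computes its metric, so treat the function name `token_accuracy` and the `ignore_index` parameter as hypothetical.

```python
def token_accuracy(predictions, labels, ignore_index=-100):
    """Fraction of labelled positions where the predicted class matches the true class.

    Assumption: positions whose label equals ignore_index are not scored.
    """
    scored = [(p, t) for p, t in zip(predictions, labels) if t != ignore_index]
    if not scored:
        raise ValueError("no labelled positions to score")
    return sum(p == t for p, t in scored) / len(scored)


# Toy example: the first and last positions are ignored, 3 of the remaining 4 match.
preds = [0, 1, 1, 2, 0, 0]
labels = [-100, 1, 1, 2, 1, -100]
print(f"{100 * token_accuracy(preds, labels):.2f} %")  # prints "75.00 %"
```

Accuracy computed this way weights every residue equally, so it can look high when one class dominates the data.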
Using the Model
To use, try running:
from transformers import AutoTokenizer, AutoModelForTokenClassification
import numpy as np
import torch  # needed for torch.no_grad() below
# 1. Prepare the Model and Tokenizer
# Replace with the path where your trained model is saved if you're training a new model
model_dir = "AmelieSchreiber/esm2_t6_8M_UR50D-finetuned-secondary-structure"
model = AutoModelForTokenClassification.from_pretrained(model_dir)
tokenizer = AutoTokenizer.from_pretrained(model_dir)
# Define a mapping from label IDs to their string representations
label_map = {0: "Other", 1: "Helix", 2: "Strand"}
# 2. Tokenize the New Protein Sequence
new_protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT" # Replace with your protein sequence
tokens = tokenizer.tokenize(new_protein_sequence)
inputs = tokenizer.encode(new_protein_sequence, return_tensors="pt")
# 3. Predict with the Model
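with torch.no_grad():
    outputs = model(inputs).logits
# encode() adds a <cls> token at the start and an <eos> token at the end;
# drop those two positions so the predictions line up with `tokens`
predictions = np.argmax(outputs[0].numpy(), axis=1)[1:-1]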
# 4. Decode the Predictions
predicted_labels = [label_map[label_id] for label_id in predictions]
# Print the tokens along with their predicted labels
for token, label in zip(tokens, predicted_labels):
    print(f"{token}: {label}")
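Printing one line per residue becomes hard to read for longer proteins. As a rough sketch, the per-residue label IDs (such as the `predictions` array from the usage example) can be collapsed into a compact string and a list of contiguous segments. The helpers `to_structure_string` and `to_segments` and the one-letter codes in `LABEL_CODES` are hypothetical and not part of the model repository; the input is assumed to hold one label ID per residue, with no special-token positions.

```python
from itertools import groupby

# Assumed one-letter codes for the three classes: other, alpha helix, beta strand.
LABEL_CODES = {0: "-", 1: "H", 2: "E"}


def to_structure_string(label_ids):
    """Map per-residue label IDs to a compact string such as '--HHHH-EEE'."""
    return "".join(LABEL_CODES[int(i)] for i in label_ids)


def to_segments(label_ids):
    """Run-length encode label IDs into (code, start, end) tuples.

    Positions are 1-based and inclusive, following the usual residue numbering.
    """
    segments = []
    position = 1
    for label_id, run in groupby(int(i) for i in label_ids):
        length = sum(1 for _ in run)
        segments.append((LABEL_CODES[label_id], position, position + length - 1))
        position += length
    return segments


example = [0, 0, 1, 1, 1, 1, 0, 2, 2, 2]
print(to_structure_string(example))  # --HHHH-EEE
print(to_segments(example))  # [('-', 1, 2), ('H', 3, 6), ('-', 7, 7), ('E', 8, 10)]
```

The `int(i)` conversion lets the same helpers accept plain Python lists or NumPy integer arrays.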