Training:

For a report on the training please see here and here.

Metrics:

Train:
({'accuracy': 0.9406146072672105,
  'precision': 0.2947122459102886,
  'recall': 0.952624323712029,
  'f1': 0.4501592605994876,
  'auc': 0.9464622170085311,
  'mcc': 0.5118390407598565},
Test:
 {'accuracy': 0.9266827008067329,
  'precision': 0.22378953253253775,
  'recall': 0.7790246675002842,
  'f1': 0.3476966444342296,
  'auc': 0.8547531675185658,
  'mcc': 0.3930283737012391})

Using the Model

Using the Model on Your Protein Sequences

To use the model on one of your protein sequences, try running the following:

from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel
import torch

# Hub id of the saved LoRA adapter
model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_cp1"
# Hub id of the ESM2 base model the adapter is applied to
base_model_path = "facebook/esm2_t12_35M_UR50D"

# Load the base model, then wrap it with the LoRA adapter weights
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
loaded_model = PeftModel.from_pretrained(base_model, model_path)

# Ensure the model is in evaluation mode
loaded_model.eval()

# Load the tokenizer (it comes from the base model, not the adapter)
loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)

# Protein sequence for inference
protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT"  # Replace with your actual sequence

# Tokenize the sequence. A single sequence needs no padding, so the input is
# only truncated to at most 1024 tokens; padding it out to 1024 would make
# the forward pass process many <pad> positions that are discarded anyway.
inputs = loaded_tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024)

# Run the model without tracking gradients (inference only)
with torch.no_grad():
    logits = loaded_model(**inputs).logits

# Get predictions: one class id per token, taken over the class dimension
tokens = loaded_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])  # Convert input ids back to tokens
predictions = torch.argmax(logits, dim=2)

# Define labels
id2label = {
    0: "No binding site",
    1: "Binding site"
}

# Tokens that do not correspond to a residue and are skipped when printing
skipped_tokens = {'<pad>', '<cls>', '<eos>'}

# Print the predicted label for each residue token.
# ``tolist()`` yields plain Python ints, which are the keys of ``id2label``.
for token, prediction in zip(tokens, predictions[0].tolist()):
    if token not in skipped_tokens:
        print((token, id2label[prediction]))

Getting the Train/Test Metrics:

First, download the dataset from here. Once you have the pickle files downloaded locally, run the following:

from datasets import Dataset
from transformers import AutoTokenizer
import pickle

# Load the tokenizer that belongs to the ESM2 base checkpoint
ESM2_CHECKPOINT = "facebook/esm2_t12_35M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(ESM2_CHECKPOINT)

# Clip every label sequence to a common maximum length
def truncate_labels(labels, max_length):
    """Return a new list in which each entry of ``labels`` is cut to its
    first ``max_length`` items.

    Entries shorter than ``max_length`` are returned in full; the input
    sequences themselves are not modified.
    """
    head = slice(max_length)  # equivalent to the ``[:max_length]`` slice
    clipped = []
    for sequence_labels in labels:
        clipped.append(sequence_labels[head])
    return clipped

# Upper bound applied to both the tokenized inputs and the label lists
max_sequence_length = 1000


def _load_pickle(path):
    """Read and return the object stored in the pickle file at ``path``."""
    # NOTE: unpickling can execute arbitrary code; only load files that
    # come from a source you trust.
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Load the sequences and their per-residue labels from the pickle files
train_sequences = _load_pickle("train_sequences_chunked_by_family.pkl")
test_sequences = _load_pickle("test_sequences_chunked_by_family.pkl")
train_labels = _load_pickle("train_labels_chunked_by_family.pkl")
test_labels = _load_pickle("test_labels_chunked_by_family.pkl")

# Tokenize both splits with identical settings
tokenizer_options = {
    "padding": True,
    "truncation": True,
    "max_length": max_sequence_length,
    "return_tensors": "pt",
    "is_split_into_words": False,
}
train_tokenized = tokenizer(train_sequences, **tokenizer_options)
test_tokenized = tokenizer(test_sequences, **tokenizer_options)

# Cut the labels to the same maximum length used for tokenization.
# NOTE(review): the tokenizer limit presumably includes special tokens while
# the labels are cut on raw positions - confirm the two stay aligned.
train_labels = truncate_labels(train_labels, max_sequence_length)
test_labels = truncate_labels(test_labels, max_sequence_length)

# Build the datasets from the tokenizer output and attach the labels
train_features = dict(train_tokenized.items())
test_features = dict(test_tokenized.items())
train_dataset = Dataset.from_dict(train_features).add_column("labels", train_labels)
test_dataset = Dataset.from_dict(test_features).add_column("labels", test_labels)

Then run the following to get the train/test metrics:

from sklearn.metrics import(
    matthews_corrcoef, 
    accuracy_score, 
    precision_recall_fscore_support, 
    roc_auc_score
)
from peft import PeftModel
from transformers import DataCollatorForTokenClassification, AutoModelForTokenClassification
from transformers import Trainer
from accelerate import Accelerator

# Accelerator used to prepare the model below
accelerator = Accelerator()

# Hub ids of the base model and of the LoRA adapter
base_model_path = "facebook/esm2_t12_35M_UR50D"
lora_model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_cp1" # "path/to/your/lora/model" Replace with the correct path to your LoRA model

# Load the base token-classification model
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)

# Wrap the base model with the LoRA adapter and hand the result to the
# accelerator in one step
model = accelerator.prepare(PeftModel.from_pretrained(base_model, lora_model_path))

# Label mappings in both directions
id2label = {0: "No binding site", 1: "Binding site"}
label2id = {name: index for index, name in id2label.items()}

# Collator that batches token-classification examples using the tokenizer
data_collator = DataCollatorForTokenClassification(tokenizer)

# Define a function to compute the metrics
def compute_metrics(dataset):
    """Run the LoRA model over ``dataset`` and return token-level metrics.

    Predictions are obtained with a ``Trainer`` built from the module-level
    ``model`` and ``data_collator``. Positions whose label is ``-100`` are
    excluded before scoring.

    Args:
        dataset: The dataset passed to ``Trainer.predict``.

    Returns:
        A dict with the keys ``accuracy``, ``precision``, ``recall``,
        ``f1``, ``auc`` and ``mcc``.
    """
    # Imported locally: the snippet's import list does not bring numpy into
    # scope, so the ``np.argmax`` call below used to raise a NameError.
    import numpy as np

    # Get the predictions using the trained model
    trainer = Trainer(model=model, data_collator=data_collator)
    predictions, labels, _ = trainer.predict(test_dataset=dataset)

    # Remove padding and special tokens (positions labelled -100)
    mask = labels != -100
    true_labels = labels[mask].flatten()
    flat_predictions = np.argmax(predictions, axis=2)[mask].flatten().tolist()

    # Compute the metrics
    accuracy = accuracy_score(true_labels, flat_predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(true_labels, flat_predictions, average='binary')
    # NOTE: the AUC is computed from the hard class predictions rather than
    # from class scores; this is kept as-is so the returned values do not
    # change.
    auc = roc_auc_score(true_labels, flat_predictions)
    mcc = matthews_corrcoef(true_labels, flat_predictions)  # Compute the MCC

    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1, "auc": auc, "mcc": mcc}  # Include the MCC in the returned dictionary

# Get the metrics for the training and test datasets.
# Each call builds a Trainer and runs prediction over the whole split.
train_metrics = compute_metrics(train_dataset)
test_metrics = compute_metrics(test_dataset)

# Bare expression: in a notebook or REPL this displays both metric dicts as
# a tuple; when run as a plain script it produces no output.
train_metrics, test_metrics
Downloads last month
30
Inference Providers NEW
This model is not currently available via any of the supported Inference Providers.
The model cannot be deployed to the HF Inference API: The model has no pipeline_tag.

Collection including AmelieSchreiber/esm2_t12_35M_lora_binding_sites_cp1