---
widget:
- text: "MEPLDDLDLLLLEEDSGAEAVPRMEILQKKADAFFAETVLSRGVDNRYLVLAVETKLNERGAEEKHLLITVSQEGEQEVLCILRNGWSSVPVEPGDIIHIEGDCTSEPWIVDDDFGYFILSPDMLISGTSVASSIRCLRRAVLSETFRVSDTATRQMLIGTILHEVFQKAISESFAPEKLQELALQTLREVRHLKEMYRLNLSQDEVRCEVEEYLPSFSKWADEFMHKGTKAEFPQMHLSLPSDSSDRSSPCNIEVVKSLDIEESIWSPRFGLKGKIDVTVGVKIHRDCKTKYKIMPLELKTGKESNSIEHRGQVILYTLLSQERREDPEAGWLLYLKTGQMYPVPANHLDKRELLKLRNQLAFSLLHRVSRAAAGEEARLLALPQIIEEEKTCKYCSQMGNCALYSRAVEQVHDTSIPEGMRSKIQEGTQHLTRAHLKYFSLWCLMLTLESQSKDTKKSHQSIWLTPASKLEESGNCIGSLVRTEPVKRVCDGHYLHNFQRKNGPMPATNLMAGDRIILSGEERKLFALSKGYVKRIDTAAVTCLLDRNLSTLPETTLFRLDREEKHGDINTPLGNLSKLMENTDSSKRLRELIIDFKEPQFIAYLSSVLPHDAKDTVANILKGLNKPQRQAMKKVLLSKDYTLIVGMPGTGKTTTICALVRILSACGFSVLLTSYTHSAVDNILLKLAKFKIGFLRLGQSHKVHPDIQKFTEEEMCRLRSIASLAHLEELYNSHPVVATTCMGISHPMFSRKTFDFCIVDEASQISQPICLGPLFFSRRFVLVGDHKQLPPLVLNREARALGMSESLFKRLERNESAVVQLTIQYRMNRKIMSLSNKLTYEGKLECGSDRVANAVITLPNLKDVRLEFYADYSDNPWLAGVFEPDNPVCFLNTDKVPAPEQIENGGVSNVTEARLIVFLTSTFIKAGCSPSDIGIIAPYRQQLRTITDLLARSSVGMVEVNTVDKYQGRDKSLILVSFVRSNEDGTLGELLKDWRRLNVAITRAKHKLILLGSVSSLKRF"
  example_title: "Protein Sequence 1"
- text: "MNSVTVSHAPYYIVYHDDWEPVMSQLVEFYNEVASWLLRDETSPIPPKFFIQLKQMLRNKRVCVCGILPYPIDGTGVPFESPNFTKKSIKEIASSISRLTGVIDYKGYNLNIIDGVIPWNYYLSCKLGETKSHAIYWDKISKLLLQHITKHVSVLYCLGKTDFSNIRAKLESPVTTIVGYHPAARDRQFEKDRSFEIINELLELDNKVPINWAQGFIY"
  example_title: "Protein Sequence 2"
- text: "MNSVTVSHAPYTIAYHDDWEPVMSQLVEFYNEAASWLLRDETSPIPSKFNIQLKQPLRNKRVCVFGIDPYPKDGTGVPFESPNFTKKSIKEIASSISRLMGVIDYEGYNLNIIDGVIPWNYYLSCKLGETKSHAIYWDKISKLLLQHITKHVSVLYCLGKTDFSNIRAKLESPVTTIVGYHPSARDRQFEKDRSFEIINVLLELDNKVPLNWAQGFIY"
  example_title: "Protein Sequence 3"
license: mit
datasets:
- AmelieSchreiber/general_binding_sites
language:
- en
metrics:
- precision
- recall
- f1
library_name: transformers
tags:
- biology
- esm
- esm2
- ESM-2
- protein language model
---
|
|
|
# ESM-2 for General Protein Binding Site Prediction
|
|
|
This model is trained to predict general binding sites of proteins using only the sequence. It is a fine-tuned version of `esm2_t6_8M_UR50D` ([see here](https://huggingface.co/facebook/esm2_t6_8M_UR50D) and [also here](https://huggingface.co/docs/transformers/model_doc/esm) for more info on the base model), trained on [this dataset](https://huggingface.co/datasets/AmelieSchreiber/general_binding_sites). The data is not filtered by family, so the model may be overfit to some degree. The Hugging Face Inference API widget to the right contains three example protein sequences. The first is a DNA-binding protein truncated to its first 1022 amino acid residues ([see UniProt entry here](https://www.uniprot.org/uniprotkb/D3ZG52/entry)).
|
|
|
The second and third were obtained using [EvoProtGrad](https://github.com/Amelie-Schreiber/sampling_protein_language_models/blob/main/EvoProtGrad_copy.ipynb), a Markov chain Monte Carlo method for (*in silico*) directed evolution of proteins based on a form of Gibbs sampling. In theory, these mutant protein sequences should have binding sites similar to those of the wild-type sequence, perhaps with higher binding affinity. When we run them through the model, the two proteins are indeed predicted to have the same binding sites. This offers some evidence that the model has learned to predict binding sites well, and that EvoProtGrad works as intended.
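One way to make "the same binding sites" concrete is to reduce each prediction to the set of residue positions labeled as binding sites. You can then measure how much two such sets overlap. The sketch below is hypothetical: the helper names and the choice of Jaccard overlap are ours, not part of this model card. It also assumes residue-level 0/1 predictions (1 = binding site) for substitution-only mutants of equal length, so that positions line up one-to-one.

```python
from typing import List, Set


def binding_positions(residue_labels: List[int]) -> Set[int]:
    """Return the 0-based residue positions predicted as binding sites (label 1)."""
    return {i for i, label in enumerate(residue_labels) if label == 1}


def jaccard_agreement(labels_a: List[int], labels_b: List[int]) -> float:
    """Jaccard overlap of predicted binding-site positions for two equal-length sequences."""
    if len(labels_a) != len(labels_b):
        raise ValueError("Sequences must have the same length to compare positions directly")
    sites_a, sites_b = binding_positions(labels_a), binding_positions(labels_b)
    union = sites_a | sites_b
    if not union:
        return 1.0  # neither sequence has any predicted binding site
    return len(sites_a & sites_b) / len(union)


# Toy residue-level predictions (1 = binding site), not real model output
mutant_1 = [0, 1, 1, 0, 0, 1]
mutant_2 = [0, 1, 1, 0, 0, 0]
print(jaccard_agreement(mutant_1, mutant_2))  # 0.6666666666666666
```

A score of 1.0 means the two mutants are predicted to have identical binding-site positions.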
|
|
|
## Training
|
|
|
This model was trained on approximately 70,000 proteins with binding site and active site annotations in UniProt. For this version the data was split randomly (85/15), without taking protein family or sequence similarity into account. Newer iterations of the model have been trained on larger datasets (over 200,000 proteins) using splits with no overlapping families. However, they seem to overfit much earlier and perform significantly worse on the training metrics (precision, recall, and F1). To address this, we plan to implement LoRA (and hopefully QLoRA).
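The difference between the two splitting strategies can be sketched in plain Python. This is an illustration only, not the code used to build the dataset. The function names and the seed are assumptions. So is the `family_of` mapping from protein ID to a family label, which could come from a family annotation such as Pfam or InterPro.

```python
import random
from collections import defaultdict
from typing import Dict, List, Tuple


def random_split(ids: List[str], test_fraction: float = 0.15, seed: int = 0) -> Tuple[List[str], List[str]]:
    """Random split that ignores family and sequence similarity (like the 85/15 split above)."""
    shuffled = ids[:]
    random.Random(seed).shuffle(shuffled)
    n_test = round(len(shuffled) * test_fraction)
    return shuffled[n_test:], shuffled[:n_test]


def family_split(family_of: Dict[str, str], test_fraction: float = 0.15, seed: int = 0) -> Tuple[List[str], List[str]]:
    """Assign whole families to one side so no family appears in both train and test."""
    members: Dict[str, List[str]] = defaultdict(list)
    for protein_id, family in family_of.items():
        members[family].append(protein_id)
    families = sorted(members)
    random.Random(seed).shuffle(families)
    target = test_fraction * len(family_of)
    train: List[str] = []
    test: List[str] = []
    for family in families:
        # Fill the test set one family at a time until it reaches roughly the target size
        bucket = test if len(test) < target else train
        bucket.extend(members[family])
    return train, test


# Hypothetical protein IDs and family labels
example_families = {"P1": "A", "P2": "A", "P3": "B", "P4": "C", "P5": "C", "P6": "D"}
fam_train, fam_test = family_split(example_families, test_fraction=0.3, seed=1)
print(fam_train, fam_test)
```

With a random split, close homologs of a test protein can still appear in the training set, which can inflate validation metrics. The family-aware split removes that leakage at the cost of a harder, more realistic evaluation.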
|
|
|
Training metrics for the model, in the form of the `trainer_state.json` file, can be [found here](https://huggingface.co/AmelieSchreiber/esm2_t6_8M_general_binding_sites_v2/blob/main/trainer_state.json). Metrics after epoch 3:

| Epoch | Training Loss | Validation Loss | Precision | Recall | F1 | AUC |
|---|---|---|---|---|---|---|
| 3 | 0.031100 | 0.074720 | 0.684798 | 0.966856 | 0.801743 | 0.980853 |
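The model card does not show how these metrics are computed. The sketch below assumes binary token labels (1 = binding site) and the Hugging Face convention of marking positions excluded from the loss, such as special tokens and padding, with `-100`. The function name is hypothetical. It also checks that the reported F1 is the harmonic mean of the reported precision and recall.

```python
from typing import List, Tuple


def token_metrics(predictions: List[int], labels: List[int], ignore_index: int = -100) -> Tuple[float, float, float]:
    """Precision, recall and F1 for the positive class (1), skipping ignored positions."""
    tp = fp = fn = 0
    for pred, label in zip(predictions, labels):
        if label == ignore_index:
            continue  # e.g. special tokens or padding
        if pred == 1 and label == 1:
            tp += 1
        elif pred == 1 and label == 0:
            fp += 1
        elif pred == 0 and label == 1:
            fn += 1
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


# F1 is the harmonic mean of precision and recall; check it against the epoch 3 row
precision, recall = 0.684798, 0.966856
print(round(2 * precision * recall / (precision + recall), 6))  # 0.801743
```

Recall (about 0.97) is much higher than precision (about 0.68). The model therefore recovers almost all annotated binding residues, but a substantial share of the residues it flags are not annotated as binding sites.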
|
The hyperparameters are:

```
wandb: lr: 0.0004977045729600779
wandb: lr_scheduler_type: cosine
wandb: max_grad_norm: 0.5
wandb: num_train_epochs: 3
wandb: per_device_train_batch_size: 8
wandb: weight_decay: 0.025
```
|
|
|
## Using the Model
|
|
|
To use the model, try running:

```python
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer


def predict_binding_sites(model_path, protein_sequences):
    """
    Predict binding sites for a collection of protein sequences.

    Parameters:
    - model_path (str): Hugging Face Hub ID or local path of the saved model.
    - protein_sequences (List[str]): List of protein sequences.

    Returns:
    - List[List[str]]: Predicted labels for each sequence, one per token
      (including the <cls>/<eos> special tokens and any padding).
    """
    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForTokenClassification.from_pretrained(model_path)

    # Use a GPU if one is available, and put the model in evaluation mode
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    # Tokenize sequences (padded to the longest one, truncated to the tokenizer's max length)
    inputs = tokenizer(protein_sequences, return_tensors="pt", padding=True, truncation=True)

    # Move inputs to the same device as the model and obtain logits
    inputs = inputs.to(device)
    with torch.no_grad():
        logits = model(**inputs).logits

    # Obtain predicted label IDs as plain Python ints
    predicted_labels = torch.argmax(logits, dim=-1).cpu().tolist()

    # Convert label IDs to human-readable labels
    id2label = model.config.id2label
    human_readable_labels = [[id2label[label_id] for label_id in sequence] for sequence in predicted_labels]

    return human_readable_labels


# Usage:
model_path = "AmelieSchreiber/esm2_t6_8M_general_binding_sites_v2"  # Replace with your model's path
|
unseen_proteins = [
    "MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEYVFVGSYARNTWLKGSLEIDVFLLFPEEFSKEELRERGLEIGKAVLDSYEIRYAEHPYVHGVVKGVEVDVVPCYKLKEPKNIKSAVDRTPFHHKWLEGRIKGKENEVRLLKGFLKANGIYGAEYKVRGFSGYLCELLIVFYGSFLETVKNARRWTRRTVIDVAKGEVRKGEEFFVVDPVDEKRNVAANLSLDNLARFVHLCREFMEAPSLGFFKPKHPLEIEPERLRKIVEERGTAVFAVKFRKPDIVDDNLYPQLERASRKIFEFLERENFMPLRSAFKASEEFCYLLFECQIKEISRVFRRMGPQFEDERNVKKFLSRNRAFRPFIENGRWWAFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIISGEKLFKEPVTAELCEMMGVKD",
    "MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEYVFVGSYARNTWLKGSLEIAVFLLFPEEFSKEELRERGLEIGKAVLDSYEIRYAEHPYVHGVVKGVEVDVVPCYKLKEPKNIKSAVDRTPFHHKWLEGRIKGKENEVRLLKGFLKANGIYGAEYKVRGFSGYLCELLIVFYGSFLETVKNARRWTRRTVIDVAKGEVRKGEEFFVVDPVDEKRNVAANLSLDNLARFVHLCREFMEAPSLGFFKVKHPLEIEPERLRKIVEERGTAVFAVKFRKPDIVDDNLYPQLERASRKIFEFLERENFMPLRSAFKASEEFCYLLFECQIKEISRVFRRMGPQFEDERNVKKFLSRNRAFRPFIENGRWWAFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIISGEKLFKEPVTAELCEMMGVKD",
    "MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEAVFVGSYARNTWLKGSLEIAVFLLFPEEFSKEELRERGLEIEKAVLDSYEIRYAEHPYVHGVVKGVEVDVVPCYKLKEPKNIKSAVDRTPFHHKELEGRIKGKENEVRLLKGFLKANGIYGAEYAVRGFSGYLCELLIVFYGSFLETVKNARRWTRRTVIDVAKGEVRKGEEFFVVDPVDEKRNVAANLSLDNLARFVHLCREFMEAPSLGFFKVKHPLEIEPERLRKIVEERGTAVFMVKFRKPDIVDDNLYPQLRRASRKIFEFLERNNFMPLRSAFKASEEFCYLLFECQIKEISDVFRRMGPLFEDERNVKKFLSRNRALRPFIENGRWWIFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIISGEKLFKEPVTAELCRMMGVKD",
    "MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEAVFVGSYARNTWLKGSLEIAVFLLFPEEFSKEELRERGLEIEKAVLDSYGIRYAEHPYVHGVVKGVELDVVPCYKLKEPKNIKSAVDRTPFHHKELEGRIKGKENEYRSLKGFLKANGIYGAEYAVRGFSGYLCELLIVFYGSFLETVKNARRWTRKTVIDVAKGEVRKGEEFFVVDPVDEKRNVAALLSLDNLARFVHLCREFMEAVSLGFFKVKHPLEIEPERLRKIVEERGTAVFMVKFRKPDIVDDNLYPQLRRASRKIFEFLERNNFMPLRRAFKASEEFCYLLFEQQIKEISDVFRRMGPLFEDERNVKKFLSRNRALRPFIENGRWWIFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIIEGEKLFKEPVTAELCRMMGVKD"
]  # Replace with your protein sequences

predictions = predict_binding_sites(model_path, unseen_proteins)
print(predictions)
```
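`predict_binding_sites` returns one label per token. The first position of each row is therefore the `<cls>` token, and each sequence is followed by `<eos>` and, for shorter sequences, padding. The hypothetical helper below maps these predictions back to 1-based residue positions. It assumes the ESM tokenizer emits exactly one token per residue plus `<cls>`/`<eos>`. The positive label name (`"LABEL_1"`) is a placeholder: check `model.config.id2label` for the label this model actually uses.

```python
from typing import List


def residue_binding_sites(token_labels: List[str], sequence: str, positive_label: str) -> List[int]:
    """Return 1-based residue positions whose predicted label equals `positive_label`.

    Assumes the ESM tokenizer layout: <cls> first, then one token per residue, then <eos>
    and any padding. For truncated sequences only the residues that fit are kept.
    """
    n_residues = min(len(sequence), len(token_labels) - 2)  # exclude <cls> and <eos>
    residue_labels = token_labels[1 : 1 + n_residues]
    return [i + 1 for i, label in enumerate(residue_labels) if label == positive_label]


# Toy example: <cls>, three residues, <eos>, one padding position
toy_labels = ["LABEL_0", "LABEL_1", "LABEL_0", "LABEL_1", "LABEL_0", "LABEL_0"]
print(residue_binding_sites(toy_labels, "MKV", "LABEL_1"))  # [1, 3]

# With real output (positive label name is an assumption):
# sites = [residue_binding_sites(labels, seq, "LABEL_1")
#          for labels, seq in zip(predictions, unseen_proteins)]
```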