---
license: other
pipeline_tag: sequence-classification
tags:
- biology
- Protein
- Pfam family
- classification
---

# `ProtAlBert-Pfam`

## Model Description

`ProtAlBert-Pfam` is a `ProtAlBert` language model fine-tuned to predict the Pfam family of a protein from its amino acid sequence.

It classifies a given protein sequence into one of ten Pfam families, listed in the label mapping of the usage snippet.

This model is a proof of concept and is not intended as a general solution to the Pfam family prediction task.

**Key Features**

* Pfam family prediction
* Handles inputs of up to 128 tokens (roughly one per amino acid, plus special tokens); longer sequences are truncated

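Because inputs are cut off at 128 tokens, the tail of a longer protein never reaches the model. One possible workaround, which is not part of this model card, is to split the sequence into windows and classify each window separately. The sketch below is only an illustration: `split_into_windows` is a hypothetical helper, and the default of 126 residues per window assumes that two of the 128 positions are taken by special tokens.

```python
def split_into_windows(sequence: str, window: int = 126) -> list[str]:
    """Split a protein sequence into consecutive, non-overlapping windows.

    `window` is the maximum number of residues per chunk. The default of 126
    is an assumption (128 tokens minus two special tokens), not a value
    documented for this model.
    """
    if window <= 0:
        raise ValueError("window must be a positive integer")
    # Step through the sequence in strides of `window`; the last chunk may be shorter.
    return [sequence[i:i + window] for i in range(0, len(sequence), window)]


# Made-up 300-residue sequence, used only to show the chunk sizes.
long_sequence = "ACDEFGHIKLMNPQRSTVWY" * 15
chunks = split_into_windows(long_sequence)
print([len(chunk) for chunk in chunks])
```

Each chunk could then be passed through the usage snippet on its own. How to combine the per-window predictions (for example by majority vote) is left open here, and the model was not necessarily trained on fragments of this kind.
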
## Usage

Get started predicting Pfam families with `ProtAlBert-Pfam` by using the following code snippet:

```python
import re

from transformers import AutoModel, AlbertTokenizer


def convert_sequence_to_input(sequence: str, model_name="Rostlab/prot_albert"):
    """Tokenize a protein sequence for the model."""
    # The tokenizer expects amino acids separated by spaces.
    seq = " ".join(sequence)
    # Map the rare or ambiguous amino acids U, Z, O and B to X.
    seq = re.sub(r"[UZOB]", "X", seq)
    tokenizer = AlbertTokenizer.from_pretrained(model_name, trust_remote_code=True, do_lower_case=False)
    # Pad or truncate every input to exactly 128 tokens.
    params = dict(return_tensors="pt", padding="max_length",
                  max_length=128,
                  truncation=True)
    x = tokenizer(seq, **params)
    return x


def convert_pfam_idx_to_class(pfam_idx: int) -> str:
    """
    Convert the prediction of the model to the corresponding class.
    :param pfam_idx: index of the pfam class
    :return: the Pfam family
    """
    conversion = {"0": "Methyltransf_25", "1": "LRR_1", "2": "Acetyltransf_7", "3": "His_kinase",
                  "4": "Bac_transf", "5": "Lum_binding", "6": "DNA_binding_1", "7": "Chromate_transp",
                  "8": "Lipase_GDSL_2", "9": "DnaJ_CXXCXGXG"}
    return conversion.get(str(pfam_idx), "Unknown")


# trust_remote_code=True runs the custom model code shipped in the repository.
model_name = "sayby/prot_albert_pfam"
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)

sequence = "ILDVGTGTGKLESLAEFKRDFIGLDVTKEMMALNRNKGKLLLASATQMPIKDGTFDAIVSSFVLRNLPSTKGYFSEGFRTLKEGG"
x = convert_sequence_to_input(sequence)
output = model(x)
# Take the class with the highest logit as the predicted family.
pfam_idx = output["logits"].argmax(dim=-1).item()
pfam = convert_pfam_idx_to_class(pfam_idx)
print(f"The Pfam family is: {pfam}")
```

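The snippet above reports only the top class. To see how the model ranks several families, the logits can be converted to probabilities with a softmax and sorted. The following is a minimal sketch in plain Python, not part of the original model card: `rank_families` and `PFAM_LABELS` are hypothetical names, the label order is copied from the mapping above, and the example logits are made up.

```python
import math

# Label order copied from the conversion table in the usage snippet.
PFAM_LABELS = [
    "Methyltransf_25", "LRR_1", "Acetyltransf_7", "His_kinase", "Bac_transf",
    "Lum_binding", "DNA_binding_1", "Chromate_transp", "Lipase_GDSL_2", "DnaJ_CXXCXGXG",
]


def rank_families(logits, labels=PFAM_LABELS, top_k=3):
    """Return the top_k (label, probability) pairs, most probable first."""
    if len(logits) != len(labels):
        raise ValueError("expected one logit per label")
    # Subtract the maximum before exponentiating for numerical stability.
    max_logit = max(logits)
    exps = [math.exp(value - max_logit) for value in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:top_k]


# Made-up logits for illustration only; they do not come from the model.
example_logits = [2.0, 0.1, -1.0, 0.5, 0.0, -0.5, 1.0, -2.0, 0.3, -0.1]
for name, prob in rank_families(example_logits):
    print(f"{name}: {prob:.3f}")
```

With the real model, the input would presumably be `output["logits"][0].tolist()`, assuming the logits tensor has shape `(1, 10)` for a single sequence.
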