|
import gradio as gr |
|
import sys |
|
import pandas as pd |
|
from transformers import AutoTokenizer, AutoModel, AutoConfig |
|
|
|
# Put '.' (resolved against the working directory) first on the import path
# so the local `configuration` and `modeling_metalatte` modules imported
# below can be found. This must run before those imports.
metalatte_path = '.'
sys.path.insert(0, metalatte_path)

from configuration import MetaLATTEConfig  # noqa: E402
from modeling_metalatte import MultitaskProteinModel  # noqa: E402

# Register the custom classes with the Auto* factories so the "metalatte"
# model type can be resolved by the from_pretrained calls below.
AutoConfig.register("metalatte", MetaLATTEConfig)
AutoModel.register(MetaLATTEConfig, MultitaskProteinModel)

# Hub identifiers for the tokenizer and for the MetaLATTE config/weights.
_TOKENIZER_CHECKPOINT = "facebook/esm2_t33_650M_UR50D"
_MODEL_CHECKPOINT = "ChatterjeeLab/MetaLATTE"

# Loaded once at import time and shared by every prediction request.
tokenizer = AutoTokenizer.from_pretrained(_TOKENIZER_CHECKPOINT)
config = AutoConfig.from_pretrained(_MODEL_CHECKPOINT)
model = AutoModel.from_pretrained(_MODEL_CHECKPOINT, config=config)
|
|
|
def predict(sequence):
    """Predict which metals a protein sequence binds.

    Tokenizes ``sequence`` with the module-level ``tokenizer``, passes the
    result to ``model.predict``, and marks every label from
    ``config.id2label`` whose prediction equals 1.

    Args:
        sequence: Protein sequence text as entered in the Gradio textbox.

    Returns:
        A single-row ``pandas.DataFrame`` with one column per label. A cell
        holds '✓' when that label is predicted and '' otherwise.

    Raises:
        gr.Error: If ``sequence`` is empty or contains only whitespace.
    """
    sequence = (sequence or "").strip()
    if not sequence:
        # Show a readable message in the UI instead of scoring an empty input.
        raise gr.Error("Please enter a protein sequence.")

    inputs = tokenizer(sequence, return_tensors="pt")
    # model.predict returns (raw_probs, predictions); only the per-label
    # predictions are displayed, so the probabilities are discarded.
    _, predictions = model.predict(**inputs)

    id2label = config.id2label
    results = {
        id2label[i]: '✓' if pred == 1 else ''
        for i, pred in enumerate(predictions[0])
    }
    return pd.DataFrame([results])
|
|
|
# Input widget: a three-line text area for the protein sequence.
_sequence_box = gr.Textbox(
    lines=3,
    placeholder="Enter protein sequence here...",
)

# Output widget: a table whose headers are the label names from the config.
_result_table = gr.Dataframe(headers=list(config.id2label.values()))

# Wire predict() between the two widgets.
iface = gr.Interface(
    fn=predict,
    inputs=_sequence_box,
    outputs=_result_table,
    title="MetaLATTE: Metal Binding Prediction",
    description="Enter a protein sequence to predict its metal binding properties.",
)

# Start the Gradio server when this script is executed.
iface.launch()