# Hugging Face Space app (scraped page header removed; Space status: Sleeping)
import gradio as gr | |
from model_loader import load_model | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss | |
from torch.utils.data import DataLoader | |
import re | |
import numpy as np | |
import os | |
import pandas as pd | |
import copy | |
import transformers, datasets | |
from transformers.modeling_outputs import TokenClassifierOutput | |
from transformers.models.t5.modeling_t5 import T5Config, T5PreTrainedModel, T5Stack | |
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map | |
from transformers import T5EncoderModel, T5Tokenizer | |
from transformers.models.esm.modeling_esm import EsmPreTrainedModel, EsmModel | |
from transformers import AutoTokenizer | |
from transformers import TrainingArguments, Trainer, set_seed | |
from transformers import DataCollatorForTokenClassification | |
from dataclasses import dataclass | |
from typing import Dict, List, Optional, Tuple, Union | |
# for custom DataCollator | |
from transformers.data.data_collator import DataCollatorMixin | |
from transformers.tokenization_utils_base import PreTrainedTokenizerBase | |
from transformers.utils import PaddingStrategy | |
from datasets import Dataset | |
from scipy.special import expit | |
#import peft | |
#from peft import get_peft_config, PeftModel, PeftConfig, inject_adapter_in_model, LoraConfig | |
# Hugging Face checkpoint of the fine-tuned ProtT5 model used for binding-site prediction.
checkpoint='ThorbenF/prot_t5_xl_uniref50'
# Maximum tokenized sequence length (in tokens, including special tokens).
max_length=1500
# load_model is project-local (model_loader) — presumably returns (model, tokenizer); confirm against model_loader.py.
model, tokenizer = load_model(checkpoint,max_length)
def create_dataset(tokenizer, seqs, labels, checkpoint):
    """Tokenize sequences and attach per-residue labels as a `datasets.Dataset`.

    Labels are truncated to leave room for the special tokens the tokenizer
    appends: ESM/ProstT5 checkpoints add two, other (ProtT5-style) models add one.
    Uses the module-level ``max_length``.
    """
    encoded = tokenizer(seqs, max_length=max_length, padding=False, truncation=True)
    ds = Dataset.from_dict(encoded)
    # Number of label positions reserved for special tokens.
    reserved = 2 if ("esm" in checkpoint or "ProstT5" in checkpoint) else 1
    trimmed = [seq_labels[:max_length - reserved] for seq_labels in labels]
    return ds.add_column("labels", trimmed)
def convert_predictions(input_logits):
    """Collapse per-token two-class logits into class-1 probabilities.

    Each element of ``input_logits`` is reshaped to (n_tokens, 2); the
    probability of class 1 is the sigmoid of the logit difference, which
    equals the two-class softmax probability. Results from all elements
    are concatenated into one flat array.
    """
    def _to_probs(logit_block):
        pair = logit_block.reshape(-1, 2)
        return expit(pair[:, 1] - pair[:, 0])

    return np.concatenate([_to_probs(block) for block in input_logits])
def normalize_scores(scores):
    """Min-max scale ``scores`` into [0, 1]; a constant array is returned unchanged."""
    lo = np.min(scores)
    hi = np.max(scores)
    if hi > lo:
        return (scores - lo) / (hi - lo)
    return scores
def predict_protein_sequence(test_one_letter_sequence):
    """Predict per-residue binding-site scores for a one-letter protein sequence.

    Returns a newline-separated string of ``"residue: score"`` pairs, with
    scores min-max normalized to [0, 1].
    """
    # Replace uncommon/ambiguous amino acids with "X" (length-preserving).
    test_one_letter_sequence = (
        test_one_letter_sequence.replace("O", "X").replace("B", "X")
        .replace("U", "X").replace("Z", "X").replace("J", "X")
    )
    # Keep the unspaced residue string for output formatting: scores are
    # per residue, but the model-input string may gain spaces / a prefix
    # token below, which previously misaligned the residue<->score zip.
    residues = test_one_letter_sequence

    dummy_labels = [np.zeros(len(test_one_letter_sequence))]

    # ProtT5/ProstT5 expect residues separated by spaces.
    if ("prot_t5" in checkpoint) or ("ProstT5" in checkpoint):
        test_one_letter_sequence = " ".join(test_one_letter_sequence)
    # ProstT5 additionally needs the <AA2fold> direction prefix.
    if "ProstT5" in checkpoint:
        test_one_letter_sequence = "<AA2fold> " + test_one_letter_sequence

    test_dataset = create_dataset(tokenizer, [test_one_letter_sequence], dummy_labels, checkpoint)

    if ("esm" in checkpoint) or ("ProstT5" in checkpoint):
        # NOTE(review): DataCollatorForTokenClassificationESM is not defined
        # anywhere in this file — this branch will raise NameError for
        # ESM/ProstT5 checkpoints. Confirm the intended import.
        data_collator = DataCollatorForTokenClassificationESM(tokenizer)
    else:
        data_collator = DataCollatorForTokenClassification(tokenizer)

    test_loader = DataLoader(test_dataset, batch_size=1, collate_fn=data_collator)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()  # inference mode: disable dropout etc.

    normalized_scores = None
    with torch.no_grad():  # no gradients needed for prediction
        for batch in test_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits.detach().cpu().numpy()
            probs = convert_predictions(logits)
            normalized_scores = normalize_scores(probs)

    # Pair each residue (not the spaced/prefixed model-input string) with its
    # score; zip truncates at the shorter of the two, dropping special-token
    # positions at the end.
    return "\n".join(
        f"{aa}: {score:.2f}" for aa, score in zip(residues, normalized_scores)
    )
# Gradio UI: one free-text sequence in, plain-text per-residue scores out.
interface = gr.Interface(
    fn=predict_protein_sequence,
    inputs=gr.Textbox(lines=2, placeholder="Enter protein sequence here..."),
    outputs=gr.Textbox(), #gr.JSON(), # Use gr.JSON() for list or array-like outputs
    title="Protein sequence - Binding site prediction",
    description="Enter a protein sequence to predict its possible binding sites.",
)
# Launch the app
interface.launch()