# --- Scraped Hugging Face Spaces page chrome (not Python code) ---
# Spaces: Sleeping
# File size: 4,681 Bytes
# Commit hashes: 6963cf4 4ed9ef0 a2460df 4ed9ef0 a2460df 6963cf4 a2460df 6963cf4 8a9bdb5 6963cf4 0c30782 6963cf4 3ae17ec 6963cf4
# (Line-number gutter 1..131 from the page viewer omitted — it is not source.)
import gradio as gr
from model_loader import load_model
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.utils.data import DataLoader
import re
import numpy as np
import os
import pandas as pd
import copy
import transformers, datasets
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.models.t5.modeling_t5 import T5Config, T5PreTrainedModel, T5Stack
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
from transformers import T5EncoderModel, T5Tokenizer
from transformers.models.esm.modeling_esm import EsmPreTrainedModel, EsmModel
from transformers import AutoTokenizer
from transformers import TrainingArguments, Trainer, set_seed
from transformers import DataCollatorForTokenClassification
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
# for custom DataCollator
from transformers.data.data_collator import DataCollatorMixin
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
from datasets import Dataset
from scipy.special import expit
#import peft
#from peft import get_peft_config, PeftModel, PeftConfig, inject_adapter_in_model, LoraConfig
# Hugging Face model repo to load. The "prot_t5" substring in this name drives
# branching below (space-separated input formatting, label truncation offset).
checkpoint='ThorbenF/prot_t5_xl_uniref50'
# Maximum tokenized length; longer sequences/labels are truncated to fit.
max_length=1500
# load_model is project-local (model_loader.py); presumably returns a
# token-classification model plus its matching tokenizer — TODO confirm.
model, tokenizer = load_model(checkpoint,max_length)
def create_dataset(tokenizer,seqs,labels,checkpoint):
    """Tokenize sequences and return a Dataset with aligned per-residue labels.

    Labels are trimmed so they line up with the truncated token sequence:
    ESM/ProstT5 tokenizers add two special tokens, others (ProtT5-style) one.
    """
    encodings = tokenizer(seqs, max_length=max_length, padding=False, truncation=True)
    dataset = Dataset.from_dict(encodings)
    # Number of special tokens the tokenizer appends within max_length.
    special_tokens = 2 if ("esm" in checkpoint or "ProstT5" in checkpoint) else 1
    trimmed = [seq_labels[:max_length - special_tokens] for seq_labels in labels]
    return dataset.add_column("labels", trimmed)
def convert_predictions(input_logits):
    """Collapse per-token 2-class logits into class-1 probabilities.

    Each item is viewed as (-1, 2) logit pairs; sigmoid of the logit
    difference (class 1 minus class 0) equals the two-class softmax
    probability of class 1. Returns one concatenated 1-D array.
    """
    def _class1_probs(item):
        pairs = item.reshape(-1, 2)
        return expit(pairs[:, 1] - pairs[:, 0])

    return np.concatenate([_class1_probs(item) for item in input_logits])
def normalize_scores(scores):
    """Min-max scale scores into [0, 1]; constant arrays pass through unchanged."""
    lo = np.min(scores)
    hi = np.max(scores)
    if hi > lo:
        return (scores - lo) / (hi - lo)
    # Degenerate case: all values identical — nothing to scale.
    return scores
def predict_protein_sequence(test_one_letter_sequence):
    """Predict per-residue binding-site scores for a one-letter protein sequence.

    Args:
        test_one_letter_sequence: protein sequence in one-letter amino-acid code.

    Returns:
        Newline-separated "residue: score" lines; scores are min-max
        normalized to [0, 1] and formatted to two decimals.
    """
    dummy_labels = [np.zeros(len(test_one_letter_sequence))]
    # Replace uncommon amino acids with "X" (replacement preserves length)
    test_one_letter_sequence = test_one_letter_sequence.replace("O", "X").replace("B", "X").replace("U", "X").replace("Z", "X").replace("J", "X")
    # Keep the plain residue string for the output: the model-input string may
    # gain spaces / a special-token prefix below, which must NOT be paired
    # with scores (zipping the spaced string misaligns residue/score pairs).
    residues = test_one_letter_sequence
    # Add spaces between each amino acid for ProtT5 and ProstT5 models
    if ("prot_t5" in checkpoint) or ("ProstT5" in checkpoint):
        test_one_letter_sequence = " ".join(test_one_letter_sequence)
    # Add <AA2fold> for ProstT5 model input format
    if "ProstT5" in checkpoint:
        test_one_letter_sequence = "<AA2fold> " + test_one_letter_sequence
    test_dataset = create_dataset(tokenizer, [test_one_letter_sequence], dummy_labels, checkpoint)
    # NOTE(review): DataCollatorForTokenClassificationESM is not defined or
    # imported in this file — this branch would raise NameError; confirm where
    # it is meant to come from.
    if ("esm" in checkpoint) or ("ProstT5" in checkpoint):
        data_collator = DataCollatorForTokenClassificationESM(tokenizer)
    else:
        data_collator = DataCollatorForTokenClassification(tokenizer)
    test_loader = DataLoader(test_dataset, batch_size=1, collate_fn=data_collator)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    normalized_scores = np.array([])
    for batch in test_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # Inference only — disable autograd to avoid building a graph.
        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits.detach().cpu().numpy()
        probs = convert_predictions(logits)
        normalized_scores = normalize_scores(probs)
    # BUG FIX: zip against the plain residue string rather than the
    # spaced/prefixed model input, so every residue lines up with its score.
    result_str = "\n".join(f"{aa}: {score:.2f}" for aa, score in zip(residues, normalized_scores))
    return result_str
# Gradio UI: one free-text sequence in, formatted per-residue scores out.
# (Fix: removed a stray trailing "|" scrape artifact after interface.launch().)
interface = gr.Interface(
    fn=predict_protein_sequence,
    inputs=gr.Textbox(lines=2, placeholder="Enter protein sequence here..."),
    outputs=gr.Textbox(),  # gr.JSON() would suit list/array-like outputs
    title="Protein sequence - Binding site prediction",
    description="Enter a protein sequence to predict its possible binding sites.",
)
# Launch the app
interface.launch()