import gradio as gr
from model_loader import load_model
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.utils.data import DataLoader
import re
import numpy as np
import os
import pandas as pd
import copy
import transformers, datasets
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.models.t5.modeling_t5 import T5Config, T5PreTrainedModel, T5Stack
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
from transformers import T5EncoderModel, T5Tokenizer
from transformers.models.esm.modeling_esm import EsmPreTrainedModel, EsmModel
from transformers import AutoTokenizer
from transformers import TrainingArguments, Trainer, set_seed
from transformers import DataCollatorForTokenClassification
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
# for custom DataCollator
from transformers.data.data_collator import DataCollatorMixin
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
from datasets import Dataset
from scipy.special import expit
import requests
import py3Dmol
#import peft
#from peft import get_peft_config, PeftModel, PeftConfig, inject_adapter_in_model, LoraConfig
checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
max_length = 1500
model, tokenizer = load_model(checkpoint, max_length)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()
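# NOTE: the DataCollatorMixin / PreTrainedTokenizerBase / PaddingStrategy imports
# above exist for a custom DataCollatorForTokenClassificationESM, which is used
# further down for ESM/ProstT5 checkpoints but whose definition is not shown in
# this excerpt. A minimal sketch, assuming it only needs to pad the inputs via
# the tokenizer and right-pad the per-residue labels with -100:
@dataclass
class DataCollatorForTokenClassificationESM(DataCollatorMixin):
    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    label_pad_token_id: int = -100
    return_tensors: str = "pt"

    def torch_call(self, features):
        # Separate labels so tokenizer.pad() only sees tokenizer-produced fields
        labels = [feature.pop("labels") for feature in features]
        batch = self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        # Right-pad every label sequence to the padded input length
        sequence_length = batch["input_ids"].shape[1]
        batch["labels"] = torch.tensor(
            [list(label) + [self.label_pad_token_id] * (sequence_length - len(label))
             for label in labels]
        )
        return batch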
def create_dataset(tokenizer, seqs, labels, checkpoint):
    tokenized = tokenizer(seqs, max_length=max_length, padding=False, truncation=True)
    dataset = Dataset.from_dict(tokenized)
    # Truncate labels to match the tokenized length: ESM/ProstT5 add two special
    # tokens per sequence, T5-style models add one (the trailing </s>)
    if ("esm" in checkpoint) or ("ProstT5" in checkpoint):
        labels = [l[:max_length - 2] for l in labels]
    else:
        labels = [l[:max_length - 1] for l in labels]
    dataset = dataset.add_column("labels", labels)
    return dataset
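# Example (illustrative values): for a T5-style checkpoint and one spaced sequence,
#   ds = create_dataset(tokenizer, ["M K T"], [np.zeros(3)], checkpoint)
# yields a Dataset with input_ids/attention_mask plus a "labels" column whose
# entries are truncated to max_length - 1, leaving room for the trailing </s>.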
def convert_predictions(input_logits):
    all_probs = []
    for logits in input_logits:
        logits = logits.reshape(-1, 2)
        # Probability of class 1 per residue: sigmoid of the logit difference,
        # which equals the softmax probability over the two classes
        probabilities_class1 = expit(logits[:, 1] - logits[:, 0])
        all_probs.append(probabilities_class1)
    return np.concatenate(all_probs)
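# Sanity check for the logit-difference trick above (illustrative values only):
#   z = np.array([[0.3, 1.2]])
#   expit(z[:, 1] - z[:, 0])                 # ~0.711
#   np.exp(z[:, 1]) / np.exp(z).sum(axis=1)  # ~0.711, identical by algebra,
# since sigmoid(z1 - z0) == softmax([z0, z1])[1].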
def normalize_scores(scores):
    # Min-max normalize to [0, 1]; return unchanged if all scores are equal
    min_score = np.min(scores)
    max_score = np.max(scores)
    return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores
def predict_protein_sequence(test_one_letter_sequence):
    dummy_labels = [np.zeros(len(test_one_letter_sequence))]
    # Replace uncommon amino acids with "X"
    test_one_letter_sequence = (
        test_one_letter_sequence.replace("O", "X")
        .replace("B", "X")
        .replace("U", "X")
        .replace("Z", "X")
        .replace("J", "X")
    )
    # ProtT5 and ProstT5 expect space-separated residues
    if ("prot_t5" in checkpoint) or ("ProstT5" in checkpoint):
        test_one_letter_sequence = " ".join(test_one_letter_sequence)
    # ProstT5 additionally expects the <AA2fold> prefix
    if "ProstT5" in checkpoint:
        test_one_letter_sequence = "<AA2fold> " + test_one_letter_sequence
    test_dataset = create_dataset(tokenizer, [test_one_letter_sequence], dummy_labels, checkpoint)
    if ("esm" in checkpoint) or ("ProstT5" in checkpoint):
        data_collator = DataCollatorForTokenClassificationESM(tokenizer)
    else:
        data_collator = DataCollatorForTokenClassification(tokenizer)
    test_loader = DataLoader(test_dataset, batch_size=1, collate_fn=data_collator)
    for batch in test_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels']  # dummy labels from the collator; unused at inference
        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits.detach().cpu().numpy()
    logits = logits[:, :-1]  # drop the trailing special-token position (</s> for prot_t5)
    logits = convert_predictions(logits)
    normalized_scores = normalize_scores(logits)
    test_one_letter_sequence = test_one_letter_sequence.replace(" ", "")
    result_str = "\n".join(
        f"{aa}: {score:.2f}" for aa, score in zip(test_one_letter_sequence, normalized_scores)
    )
    return result_str
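# Example usage (illustrative; the sequence below is a made-up fragment):
#   print(predict_protein_sequence("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"))
# prints one "residue: score" line per position, e.g. "M: 0.42", with scores
# min-max normalized to [0, 1] across the sequence.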
#interface = gr.Interface(
# fn=predict_protein_sequence,
# inputs=gr.Textbox(lines=2, placeholder="Enter protein sequence here..."),
# outputs=gr.Textbox(), #gr.JSON(), # Use gr.JSON() for list or array-like outputs
# title="Protein sequence - Binding site prediction",
# description="Enter a protein sequence to predict its possible binding sites.",
#)
# Launch the app
#interface.launch()
def fetch_and_display_pdb(pdb_id):
    # Construct the PDB URL
    pdb_url = f"https://files.rcsb.org/download/{pdb_id}.pdb"
    # Try fetching the PDB file
    response = requests.get(pdb_url)
    if response.status_code != 200:
        return "Failed to fetch PDB file"
    # Get the structure content as text
    structure_text = response.text
    # Create the HTML content with embedded 3Dmol.js
    # (note the doubled braces: literal {} inside an f-string)
    html_content = f"""
    <html>
    <head>
        <script src="https://3Dmol.js.org/build/3Dmol-min.js"></script>
        <style>
            #viewer {{
                width: 800px;
                height: 600px;
            }}
        </style>
    </head>
    <body>
        <div id="viewer"></div>
        <script>
            const viewer = $3Dmol.createViewer("viewer", {{ backgroundColor: "white" }});
            viewer.addModel(`{structure_text}`, "pdb");
            viewer.setStyle({{}}, {{ cartoon: {{ color: "spectrum" }} }});
            viewer.zoomTo();
            viewer.render();
        </script>
    </body>
    </html>
    """
    return html_content
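# Example usage (hedged; "1CRN" is just a small structure often used in demos,
# any valid 4-character PDB ID works):
#   html = fetch_and_display_pdb("1CRN")
# returns a self-contained HTML page embedding the 3Dmol.js viewer, or the
# string "Failed to fetch PDB file" on a non-200 response.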
# Define the Gradio interface
def gradio_interface(sequence, pdb_id):
    # Call the prediction function
    binding_site_predictions = predict_protein_sequence(sequence)
    # Call the PDB structure visualization function
    pdb_structure_html = fetch_and_display_pdb(pdb_id)
    return binding_site_predictions, pdb_structure_html
interface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter protein sequence here...", label="Protein Sequence"),
        gr.Textbox(lines=1, placeholder="Enter PDB ID here...", label="PDB ID for 3D Visualization"),
    ],
    outputs=[
        gr.Textbox(label="Binding Site Predictions"),
        gr.HTML(label="3Dmol Viewer"),  # HTML output to render the 3Dmol viewer
    ],
    title="Protein Binding Site Prediction and 3D Structure Viewer",
    description="Input a protein sequence to predict binding sites and view the protein structure in 3D using its PDB ID.",
)
# Launch the Gradio app
interface.launch()