import gradio as gr
from model_loader import load_model
import torch
from torch.utils.data import DataLoader
import numpy as np
from transformers import DataCollatorForTokenClassification
from datasets import Dataset
from scipy.special import expit
import requests
# Configuration
checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
max_length = 1500
# Load model and move to device
model, tokenizer = load_model(checkpoint, max_length)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()
def create_dataset(tokenizer, seqs, labels, checkpoint):
tokenized = tokenizer(seqs, max_length=max_length, padding=False, truncation=True)
dataset = Dataset.from_dict(tokenized)
    # Truncate labels to leave room for the tokenizer's special tokens:
    # ESM and ProstT5 add two, T5-style encoders add one (</s>)
    if ("esm" in checkpoint) or ("ProstT5" in checkpoint):
        labels = [l[:max_length-2] for l in labels]
    else:
        labels = [l[:max_length-1] for l in labels]
dataset = dataset.add_column("labels", labels)
return dataset
def convert_predictions(input_logits):
all_probs = []
for logits in input_logits:
logits = logits.reshape(-1, 2)
        # sigmoid(l1 - l0) equals the two-class softmax probability of class 1
        probabilities_class1 = expit(logits[:, 1] - logits[:, 0])
all_probs.append(probabilities_class1)
return np.concatenate(all_probs)
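# Quick sanity check for convert_predictions (pure NumPy, model-free; the
# numbers are illustrative, not app output):
#   logits = np.array([[[2.0, 0.0], [0.0, 2.0]]])  # (1 sequence, 2 positions, 2 classes)
#   convert_predictions(logits)                     # -> approx [0.119, 0.881]
# i.e. sigmoid(0 - 2) and sigmoid(2 - 0), matching the softmax class-1 probabilities.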
def normalize_scores(scores):
    # Min-max normalize to [0, 1]; return unchanged if all scores are equal
    min_score = np.min(scores)
    max_score = np.max(scores)
    return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores
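# Worked example for normalize_scores (assumed values, for illustration only):
#   normalize_scores(np.array([2.0, 5.0, 8.0]))  # -> [0.0, 0.5, 1.0]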
def predict_protein_sequence(test_one_letter_sequence):
    # Map rare/ambiguous residues (O, B, U, Z, J) to the unknown token "X"
    test_one_letter_sequence = test_one_letter_sequence.replace("O", "X") \
        .replace("B", "X").replace("U", "X") \
        .replace("Z", "X").replace("J", "X")
    # T5-family models expect residues separated by spaces;
    # ProstT5 additionally requires the "<AA2fold>" prefix
if ("prot_t5" in checkpoint) or ("ProstT5" in checkpoint):
test_one_letter_sequence = " ".join(test_one_letter_sequence)
if "ProstT5" in checkpoint:
test_one_letter_sequence = "<AA2fold> " + test_one_letter_sequence
    # Dummy labels: placeholders so the dataset has a "labels" column
    # for the collator; their values are ignored at inference
dummy_labels = [np.zeros(len(test_one_letter_sequence))]
# Create dataset
test_dataset = create_dataset(tokenizer,
[test_one_letter_sequence],
dummy_labels,
checkpoint)
    # Both model families use the standard token-classification collator
    data_collator = DataCollatorForTokenClassification(tokenizer)
# Create data loader
test_loader = DataLoader(test_dataset, batch_size=1, collate_fn=data_collator)
    # Run inference (a single sequence at batch_size=1 yields exactly one batch)
for batch in test_loader:
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
with torch.no_grad():
outputs = model(input_ids, attention_mask=attention_mask)
logits = outputs.logits.detach().cpu().numpy()
    # Drop the logits for the trailing special token (</s> appended by the T5 tokenizer)
    logits = logits[:, :-1]
logits = convert_predictions(logits)
# Normalize and format results
normalized_scores = normalize_scores(logits)
test_one_letter_sequence = test_one_letter_sequence.replace(" ", "")
result_str = "\n".join([f"{aa}: {score:.2f}" for aa, score in zip(test_one_letter_sequence, normalized_scores)])
return result_str
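# Hypothetical usage (arbitrary example sequence; scores depend on the loaded
# checkpoint):
#   predict_protein_sequence("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ")
# returns one "residue: score" line per position, e.g. "M: 0.12\nK: 0.87\n..."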
def fetch_and_display_pdb(pdb_id):
try:
# Fetch the PDB structure from RCSB
pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'
        response = requests.get(pdb_url, timeout=30)
if response.status_code != 200:
return "Failed to load PDB structure. Please check the PDB ID."
pdb_structure = response.text
# Prepare the 3D molecular visualization
visualization = f"""
<div id="container" style="width: 100%; height: 400px; position: relative;"></div>
<script src="https://3dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
<script>
let viewer = $3Dmol.createViewer(document.getElementById("container"));
viewer.addModel(`{pdb_structure}`, "pdb");
viewer.setStyle({{}}, {{"cartoon": {{"color": "spectrum"}}}});
viewer.zoomTo();
viewer.render();
</script>
"""
return visualization
except Exception as e:
return f"Error visualizing PDB: {str(e)}"
def gradio_interface(sequence, pdb_id):
# Predict binding sites
binding_site_predictions = predict_protein_sequence(sequence)
# Fetch and visualize PDB structure
pdb_structure_html = fetch_and_display_pdb(pdb_id)
return binding_site_predictions, pdb_structure_html
# Create Gradio interface
interface = gr.Interface(
fn=gradio_interface,
inputs=[
gr.Textbox(lines=2, placeholder="Enter protein sequence here...", label="Protein Sequence"),
gr.Textbox(lines=1, placeholder="Enter PDB ID here...", label="PDB ID for 3D Visualization")
],
outputs=[
gr.Textbox(label="Binding Site Predictions"),
gr.HTML(label="3D Molecular Viewer")
],
title="Protein Binding Site Prediction and 3D Structure Viewer",
description="Input a protein sequence to predict binding sites and view the protein structure in 3D using its PDB ID.",
)
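# launch() with no arguments is what Hugging Face Spaces expects; for local
# testing, interface.launch(share=True) would expose a temporary public URL.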
interface.launch()