import gradio as gr
from model_loader import load_model
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.utils.data import DataLoader
import re
import numpy as np
import os
import pandas as pd
import copy
import transformers, datasets
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.models.t5.modeling_t5 import T5Config, T5PreTrainedModel, T5Stack
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
from transformers import T5EncoderModel, T5Tokenizer
from transformers.models.esm.modeling_esm import EsmPreTrainedModel, EsmModel
from transformers import AutoTokenizer
from transformers import TrainingArguments, Trainer, set_seed
from transformers import DataCollatorForTokenClassification
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
# for custom DataCollator
from transformers.data.data_collator import DataCollatorMixin
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
from datasets import Dataset
from scipy.special import expit
import requests
from gradio_molecule3d import Molecule3D
#import peft
#from peft import get_peft_config, PeftModel, PeftConfig, inject_adapter_in_model, LoraConfig
# Configuration
checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
max_length = 1500
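# max_length caps the tokenized input (residues plus special tokens) fed to the model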
# Default representations for molecule rendering
reps = [
{
"model": 0,
"chain": "",
"resname": "",
"style": "cartoon",
"color": "spectrum",
"residue_range": "",
"around": 0,
"byres": False,
"visible": True
}
]
# Load model and move to device
model, tokenizer = load_model(checkpoint, max_length)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()
def create_dataset(tokenizer, seqs, labels, checkpoint):
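    # Tokenize sequences without padding here; the data collator pads each batch later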
tokenized = tokenizer(seqs, max_length=max_length, padding=False, truncation=True)
dataset = Dataset.from_dict(tokenized)
    # Truncate labels to match the tokenized length:
    # ESM and ProstT5 consume two extra positions (prefix + terminal special token),
    # while ProtT5 only appends a trailing </s>
    if ("esm" in checkpoint) or ("ProstT5" in checkpoint):
        labels = [l[:max_length - 2] for l in labels]
    else:
        labels = [l[:max_length - 1] for l in labels]
dataset = dataset.add_column("labels", labels)
return dataset
def convert_predictions(input_logits):
all_probs = []
for logits in input_logits:
logits = logits.reshape(-1, 2)
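        # Probability of class 1 (binding) via sigmoid of the logit difference,
        # equivalent to a two-class softmax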
probabilities_class1 = expit(logits[:, 1] - logits[:, 0])
all_probs.append(probabilities_class1)
return np.concatenate(all_probs)
def normalize_scores(scores):
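    # Min-max scale to [0, 1]; if all scores are equal, return them unchanged to avoid division by zero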
min_score = np.min(scores)
max_score = np.max(scores)
return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores
def predict_protein_sequence(test_one_letter_sequence):
    # Replace non-standard/ambiguous amino acids with X
    test_one_letter_sequence = (test_one_letter_sequence
                                .replace("O", "X").replace("B", "X")
                                .replace("U", "X").replace("Z", "X")
                                .replace("J", "X"))
    # Placeholder labels (ignored at inference), one per residue,
    # built before model-specific formatting changes the string length
    dummy_labels = [np.zeros(len(test_one_letter_sequence))]
    # Prepare sequence for different model types
    if ("prot_t5" in checkpoint) or ("ProstT5" in checkpoint):
        # ProtT5/ProstT5 tokenizers expect space-separated residues
        test_one_letter_sequence = " ".join(test_one_letter_sequence)
    if "ProstT5" in checkpoint:
        test_one_letter_sequence = "<AA2fold> " + test_one_letter_sequence
# Create dataset
test_dataset = create_dataset(tokenizer,
[test_one_letter_sequence],
dummy_labels,
checkpoint)
    # Token-classification collator pads inputs and labels to a common length per batch
    data_collator = DataCollatorForTokenClassification(tokenizer)
# Create data loader
test_loader = DataLoader(test_dataset, batch_size=1, collate_fn=data_collator)
# Predict
for batch in test_loader:
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
with torch.no_grad():
outputs = model(input_ids, attention_mask=attention_mask)
logits = outputs.logits.detach().cpu().numpy()
    # Drop the trailing special-token position (</s>) so logits align with residues
    logits = logits[:, :-1]
logits = convert_predictions(logits)
# Normalize and format results
normalized_scores = normalize_scores(logits)
test_one_letter_sequence = test_one_letter_sequence.replace(" ", "")
result_str = "\n".join([f"{aa}: {score:.2f}" for aa, score in zip(test_one_letter_sequence, normalized_scores)])
return result_str
def fetch_pdb(pdb_id):
try:
# Create a directory to store PDB files if it doesn't exist
os.makedirs('pdb_files', exist_ok=True)
# Fetch the PDB structure from RCSB
pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'
pdb_path = f'pdb_files/{pdb_id}.pdb'
# Download the file
response = requests.get(pdb_url)
if response.status_code == 200:
with open(pdb_path, 'wb') as f:
f.write(response.content)
return pdb_path
else:
return None
except Exception as e:
print(f"Error fetching PDB: {e}")
return None
def process_input(sequence, pdb_id):
# Predict binding sites
binding_site_predictions = predict_protein_sequence(sequence)
# Fetch PDB file
pdb_path = fetch_pdb(pdb_id)
return binding_site_predictions, pdb_path
# Create Gradio interface
with gr.Blocks() as demo:
gr.Markdown("# Protein Binding Site Prediction")
with gr.Row():
with gr.Column():
# Sequence input
sequence_input = gr.Textbox(
lines=2,
placeholder="Enter protein sequence here...",
label="Protein Sequence"
)
# PDB ID input
pdb_input = gr.Textbox(
lines=1,
placeholder="Enter PDB ID here...",
label="PDB ID for 3D Visualization"
)
# Predict button
predict_btn = gr.Button("Predict Binding Sites")
with gr.Column():
# Binding site predictions output
predictions_output = gr.Textbox(
label="Binding Site Predictions"
)
# 3D Molecule visualization
molecule_output = Molecule3D(
label="Protein Structure",
reps=reps
)
# Prediction logic
predict_btn.click(
process_input,
inputs=[sequence_input, pdb_input],
outputs=[predictions_output, molecule_output]
)
# Add some example inputs
gr.Markdown("## Examples")
gr.Examples(
examples=[
["MKVLWAALLVTFLAGCQAKVEQAVETEPEPELRQQTEWQSGQRWELALGRFWDYLRWVQTLSEQVQEELLSSQVTQELRALMDETMKELKAYKSELEEQLTPVAEETRARLSKELQAAQARLGADMEDVCGRLVQYRGEVQAMLGQSTEELRVRLASHLRKLRKRLLRDADDLQKRLAVYQAGAREGAERGLSAIRERLGPLVEQGRVRAATVGSLAGQPLQERAQAWGERLRARMEEMGSRTRDRLDEVKEQVAEVRAKLEEQAQQRL", "1ABC"],
],
inputs=[sequence_input, pdb_input],
outputs=[predictions_output, molecule_output]
)
demo.launch()