Spaces:
Running
Running
File size: 3,853 Bytes
6963cf4 9e29637 8bd6bbb 11bcc1a a28eeb5 11bcc1a e0408f3 6963cf4 8bd6bbb 4499595 a6b7cf0 aae512c 6963cf4 4499595 1705a41 4499595 1f960e0 4499595 a28eeb5 4499595 a28eeb5 4499595 a28eeb5 4499595 6963cf4 1f960e0 4499595 1f960e0 4499595 a6b7cf0 4499595 a6b7cf0 4499595 8bd6bbb 4499595 01ff8b6 8bd6bbb 01ff8b6 1f960e0 4499595 01ff8b6 1f960e0 4499595 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import gradio as gr
from model_loader import load_model
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import re
import numpy as np
import os
import pandas as pd
import copy
import transformers, datasets
from transformers import AutoTokenizer
from transformers import DataCollatorForTokenClassification
from datasets import Dataset
from scipy.special import expit
import requests
from gradio_molecule3d import Molecule3D
# Biopython imports
from Bio.PDB import PDBParser, Select, PDBIO
from Bio.PDB.DSSP import DSSP
from Bio.PDB import PDBList
from matplotlib import cm # For color mapping
from matplotlib.colors import Normalize
# Load the model and tokenizer once at import time, then move the model to the
# selected device.
# `load_model` comes from the project-local `model_loader` module; how it uses
# `max_length` is not visible here (presumably a sequence-length cap — TODO confirm).
checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
max_length = 1500
model, tokenizer = load_model(checkpoint, max_length)
# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# Assumes `model` is a torch module, so eval() switches it to inference mode — TODO confirm.
model.eval()
# Representation settings handed to the Molecule3D viewer below: model 0 drawn
# in "cartoon" style with "spectrum" colouring.
reps = [{"model": 0, "style": "cartoon", "color": "spectrum"}]
# Function to fetch a PDB file
def fetch_pdb(pdb_id):
    """Download a PDB entry from RCSB and store it under ``pdb_files/``.

    Args:
        pdb_id: Four-character alphanumeric PDB identifier (e.g. ``"2IWI"``).
            Surrounding whitespace is ignored.

    Returns:
        The path of the saved ``.pdb`` file, or ``None`` when the identifier
        is malformed, the request fails, or the server does not answer with
        HTTP 200.
    """
    if not isinstance(pdb_id, str):
        return None
    pdb_id = pdb_id.strip()
    # The ID comes from a UI text box and is interpolated into both a URL and
    # a file path, so anything other than a plain 4-character ID is rejected
    # (this blocks path traversal such as "../../x").
    if not re.fullmatch(r'[A-Za-z0-9]{4}', pdb_id):
        return None

    pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'
    pdb_path = f'pdb_files/{pdb_id}.pdb'
    try:
        # A timeout keeps a stalled download from hanging the UI callback.
        response = requests.get(pdb_url, timeout=30)
    except requests.RequestException:
        return None
    if response.status_code != 200:
        return None

    os.makedirs('pdb_files', exist_ok=True)
    with open(pdb_path, 'wb') as f:
        f.write(response.content)
    return pdb_path
def normalize_scores(scores):
    """Min-max scale ``scores`` to the range [0, 1].

    Args:
        scores: Array-like of numeric scores.

    Returns:
        A NumPy array rescaled so the minimum maps to 0 and the maximum to 1.
        When the input is empty, or all values are equal (no range to scale
        by), ``scores`` is returned unchanged.
    """
    values = np.asarray(scores)
    # np.min / np.max raise ValueError on an empty array, so bail out early.
    if values.size == 0:
        return scores
    min_score = values.min()
    max_score = values.max()
    if max_score > min_score:
        return (values - min_score) / (max_score - min_score)
    # Flat input: dividing by a zero range is undefined, so leave it as given.
    return scores
# Extract sequence and predict binding scores
def process_pdb(pdb_id, segment):
    """Predict per-residue binding scores for one chain of a PDB entry.

    The entry is downloaded with ``fetch_pdb``, the amino-acid sequence of the
    requested chain is passed through the module-level ``tokenizer`` and
    ``model``, and the min-max normalised scores are written to
    ``{pdb_id}_predictions.txt``.

    Args:
        pdb_id: PDB identifier of the structure to download.
        segment: Chain ID to analyse within the first model of the structure.

    Returns:
        A 3-tuple ``(result_text, pdb_path, predictions_file)``. When the
        download fails, the chain is missing, the chain has no standard
        amino-acid residues, or the model yields too few scores, the first
        element is an error message and the other two are ``None``.
    """
    # Imported locally because these Biopython helpers are not among the
    # file-level imports.
    from Bio.PDB.Polypeptide import is_aa
    from Bio.SeqUtils import seq1

    pdb_path = fetch_pdb(pdb_id)
    if not pdb_path:
        return "Failed to fetch PDB file", None, None

    parser = PDBParser(QUIET=1)
    structure = parser.get_structure('protein', pdb_path)
    first_model = structure[0]
    # Report a missing chain as a message instead of a raw KeyError.
    if segment not in first_model:
        return f"Chain '{segment}' not found in {pdb_id}", None, None
    chain = first_model[segment]

    # A chain also holds waters and ligands; keep only standard amino acids so
    # every retained residue maps to exactly one sequence letter.
    residues = [residue for residue in chain if is_aa(residue, standard=True)]
    if not residues:
        return f"No amino-acid residues found in chain '{segment}' of {pdb_id}", None, None

    # get_resname() yields three-letter codes ("ALA"); convert them to
    # one-letter codes so the space-separated tokenizer input has one symbol
    # per residue rather than the individual letters of "ALAGLY...".
    sequence = "".join(seq1(residue.get_resname()) for residue in residues)
    input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        logits = model(input_ids).logits.detach().cpu().numpy()

    # Collapse any leading batch dimension to one row of class logits per token.
    logits = logits.reshape(-1, logits.shape[-1])
    if logits.shape[0] < len(residues):
        return f"Model returned {logits.shape[0]} scores for {len(residues)} residues", None, None
    # Presumably the tokenizer appends special tokens (e.g. end-of-sequence)
    # after the residues — TODO confirm. Dropping the surplus rows keeps them
    # out of the normalisation range.
    logits = logits[:len(residues)]

    # Sigmoid of the logit difference between class 1 and class 0.
    scores = expit(logits[:, 1] - logits[:, 0])
    normalized_scores = normalize_scores(scores)

    # One line per residue: three-letter name, residue number, one-letter
    # code, normalised score.
    result_str = "\n".join(
        f"{residue.get_resname()} {residue.id[1]} {letter} {score:.2f}"
        for residue, letter, score in zip(residues, sequence, normalized_scores)
    )
    with open(f"{pdb_id}_predictions.txt", "w") as f:
        f.write(result_str)
    return result_str, pdb_path, f"{pdb_id}_predictions.txt"
# Gradio UI
# Builds the Blocks interface: a PDB ID box, a chain ID box, one button that
# shows the downloaded structure and one that runs the binding-site prediction.
with gr.Blocks() as demo:
    gr.Markdown("# Protein Binding Site Prediction")

    # NOTE(review): block nesting was reconstructed; this assumes the two text
    # boxes and the two buttons share one row — confirm the intended layout.
    with gr.Row():
        pdb_input = gr.Textbox(value="2IWI",
                               label="PDB ID",
                               placeholder="Enter PDB ID here...")
        segment_input = gr.Textbox(value="A",
                                   label="Chain ID (Segment)",
                                   placeholder="Enter Chain ID here...")
        visualize_btn = gr.Button("Visualize Structure")
        prediction_btn = gr.Button("Predict Ligand Binding Site")

    molecule_output = Molecule3D(label="Protein Structure", reps=reps)
    predictions_output = gr.Textbox(label="Binding Site Predictions")
    download_output = gr.File(label="Download Predictions")

    # fetch_pdb returns the saved file path (or None), which feeds the viewer.
    visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output)

    # process_pdb returns (text, pdb path, predictions file), matching the
    # three outputs in order.
    prediction_btn.click(
        process_pdb,
        inputs=[pdb_input, segment_input],
        outputs=[predictions_output, molecule_output, download_output]
    )

    gr.Markdown("## Examples")
    # NOTE(review): each example row supplies one value while two inputs are
    # listed — confirm Gradio fills only the PDB ID box for these rows.
    gr.Examples(
        examples=[
            ["2IWI"],
            ["7RPZ"],
            ["3TJN"]
        ],
        inputs=[pdb_input, segment_input],
        outputs=[predictions_output, molecule_output, download_output]
    )

demo.launch(share=True)