# Spaces: Running
import gradio as gr | |
from model_loader import load_model | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.utils.data import DataLoader | |
import re | |
import numpy as np | |
import os | |
import pandas as pd | |
import copy | |
import transformers, datasets | |
from transformers import AutoTokenizer | |
from transformers import DataCollatorForTokenClassification | |
from datasets import Dataset | |
from scipy.special import expit | |
import requests | |
from gradio_molecule3d import Molecule3D | |
# Biopython imports | |
from Bio.PDB import PDBParser, Select, PDBIO | |
from Bio.PDB.DSSP import DSSP | |
from Bio.PDB import PDBList | |
from matplotlib import cm # For color mapping | |
from matplotlib.colors import Normalize | |
# Model setup: load the checkpoint once at import time and place the model on
# the GPU when CUDA is available, otherwise on the CPU.
checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
max_length = 1500
model, tokenizer = load_model(checkpoint, max_length)

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model.to(device)
model.eval()  # evaluation mode; the app only runs forward passes

# Default representation list handed to the Molecule3D viewer component.
reps = [dict(model=0, style="cartoon", color="spectrum")]
# Function to fetch a PDB file
def fetch_pdb(pdb_id):
    """Download a PDB entry from RCSB and store it under ``pdb_files/``.

    Args:
        pdb_id: PDB identifier as typed by the user. Surrounding whitespace
            is ignored.

    Returns:
        Path of the saved ``.pdb`` file, or ``None`` when the identifier is
        malformed, the request fails, or the server does not answer 200.
    """
    # The identifier is interpolated into both a URL and a local file path,
    # so reject anything that is not a plain alphanumeric/underscore token
    # before it can reach either (e.g. "../../etc/passwd").
    pdb_id = str(pdb_id or "").strip()
    if not re.fullmatch(r"[A-Za-z0-9_]+", pdb_id):
        return None

    pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'
    pdb_path = f'pdb_files/{pdb_id}.pdb'
    os.makedirs('pdb_files', exist_ok=True)

    try:
        # A timeout keeps a stalled connection from hanging the UI handler.
        response = requests.get(pdb_url, timeout=30)
    except requests.RequestException:
        # Network failures follow the same "None on failure" contract as a
        # non-200 response so callers need only one error path.
        return None

    if response.status_code != 200:
        return None

    with open(pdb_path, 'wb') as f:
        f.write(response.content)
    return pdb_path
# Three-letter -> one-letter codes for the 20 standard amino acids. Any other
# residue name is mapped to "X" (unknown) when building the model input.
_THREE_TO_ONE = {
    "ALA": "A", "ARG": "R", "ASN": "N", "ASP": "D", "CYS": "C",
    "GLN": "Q", "GLU": "E", "GLY": "G", "HIS": "H", "ILE": "I",
    "LEU": "L", "LYS": "K", "MET": "M", "PHE": "F", "PRO": "P",
    "SER": "S", "THR": "T", "TRP": "W", "TYR": "Y", "VAL": "V",
}


# Extract sequence and predict binding scores
def process_pdb(pdb_id, segment):
    """Predict per-residue binding scores for one chain of a PDB entry.

    Args:
        pdb_id: PDB identifier; passed to ``fetch_pdb`` to obtain the file.
        segment: Chain ID to analyse. Surrounding whitespace is ignored.

    Returns:
        A 3-tuple ``(text, pdb_path, predictions_path)``. ``text`` holds one
        line per residue (``<resname> <resnum> <one-letter> <score>``), or an
        error message. ``pdb_path`` is ``None`` when the download failed, and
        ``predictions_path`` is ``None`` whenever no predictions were made.
    """
    pdb_path = fetch_pdb(pdb_id)
    if not pdb_path:
        return "Failed to fetch PDB file", None, None

    parser = PDBParser(QUIET=1)
    structure = parser.get_structure('protein', pdb_path)

    segment = (segment or "").strip()
    try:
        chain = structure[0][segment]
    except KeyError:
        # Unknown chain ID (or a file without any model): report it instead
        # of crashing the UI handler; the structure itself can still be shown.
        return f"Chain '{segment}' not found in {pdb_id}", pdb_path, None

    # Keep only regular polymer residues: Biopython marks waters and other
    # HETATM records with a non-blank hetero flag in residue.id[0].
    residues = [res for res in chain if res.id[0] == " "]
    if not residues:
        return f"No standard residues found in chain '{segment}'", pdb_path, None

    # get_resname() yields three-letter codes; the tokenizer is fed a
    # space-separated string of one-letter codes, so convert first.
    sequence = "".join(
        _THREE_TO_ONE.get(res.get_resname().strip().upper(), "X")
        for res in residues
    )

    input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        # assumes logits are (1, tokens, 2) so squeeze() leaves (tokens, 2) — TODO confirm
        outputs = model(input_ids).logits.detach().cpu().numpy().squeeze()

    # Score = logit of class 1 minus logit of class 0 for every token.
    scores = outputs[:, 1] - outputs[:, 0]

    # zip() pairs residue i with token i and drops any trailing token scores
    # (e.g. an end-of-sequence token) that have no matching residue.
    result_str = "\n".join(
        f"{res.get_resname()} {res.id[1]} {aa} {score:.2f}"
        for res, aa, score in zip(residues, sequence, scores)
    )

    predictions_path = f"{pdb_id}_predictions.txt"
    with open(predictions_path, "w") as f:
        f.write(result_str)
    return result_str, pdb_path, predictions_path
# Gradio UI
# NOTE(review): the original indentation of this section was lost, so the
# nesting below (both text boxes and both buttons inside the Row) is a
# reconstruction — confirm against the deployed layout.
with gr.Blocks() as demo:
    gr.Markdown("# Protein Binding Site Prediction")
    # Inputs: PDB identifier, chain ID and the two action buttons.
    with gr.Row():
        pdb_input = gr.Textbox(label="PDB ID")
        segment_input = gr.Textbox(label="Segment (Chain ID)")
        visualize_btn = gr.Button("Visualize")
        prediction_btn = gr.Button("Predict")
    # Outputs: 3D viewer, prediction text and a downloadable predictions file.
    molecule_output = Molecule3D(label="Protein Structure", reps=reps)
    predictions_output = gr.Textbox(label="Binding Site Predictions")
    download_output = gr.File(label="Download Predictions")
    # "Visualize" only downloads the structure and hands its path to the viewer.
    visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output)
    # "Predict" runs the model; process_pdb returns (text, pdb path, file path),
    # matching the three outputs listed here in order.
    prediction_btn.click(
        process_pdb,
        inputs=[pdb_input, segment_input],
        outputs=[predictions_output, molecule_output, download_output]
    )
# Start the app; share=True asks Gradio for a public share link.
demo.launch(share=True)