test_webpage / app.py
ThorbenF's picture
Update
4499595
raw
history blame
3.14 kB
import gradio as gr
from model_loader import load_model
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import re
import numpy as np
import os
import pandas as pd
import copy
import transformers, datasets
from transformers import AutoTokenizer
from transformers import DataCollatorForTokenClassification
from datasets import Dataset
from scipy.special import expit
import requests
from gradio_molecule3d import Molecule3D
# Biopython imports
from Bio.PDB import PDBParser, Select, PDBIO
from Bio.PDB.DSSP import DSSP
from Bio.PDB import PDBList
from matplotlib import cm # For color mapping
from matplotlib.colors import Normalize
# Load the model and tokenizer once at import time, then move the model to
# the GPU when CUDA is available (CPU otherwise).
checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
# Passed straight to load_model; presumably the maximum sequence length the
# model is configured for — TODO confirm against model_loader.
max_length = 1500
model, tokenizer = load_model(checkpoint, max_length)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# The app only runs inference, so put the model in evaluation mode.
model.eval()
# Representation settings handed to the Molecule3D viewer via `reps=`.
reps = [{"model": 0, "style": "cartoon", "color": "spectrum"}]
# Function to fetch a PDB file
def fetch_pdb(pdb_id):
    """Download a PDB entry from RCSB and store it under ``pdb_files/``.

    Args:
        pdb_id: PDB identifier as typed by the user (e.g. ``"1abc"``).
            Surrounding whitespace is ignored.

    Returns:
        The local path ``pdb_files/<pdb_id>.pdb`` on success, or ``None``
        when the identifier is malformed, the request fails, or the server
        does not answer with HTTP 200.
    """
    if not isinstance(pdb_id, str):
        return None
    pdb_id = pdb_id.strip()
    # The ID is interpolated into both a URL and a filesystem path, so only
    # plain identifier characters are accepted; this rejects empty input and
    # blocks "../"-style path traversal.
    if not re.fullmatch(r"[A-Za-z0-9_]+", pdb_id):
        return None

    pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'
    pdb_path = f'pdb_files/{pdb_id}.pdb'
    os.makedirs('pdb_files', exist_ok=True)
    try:
        # A timeout keeps a stalled download from hanging the UI callback.
        response = requests.get(pdb_url, timeout=30)
    except requests.RequestException:
        # Callers treat None as "could not fetch"; keep that single
        # failure signal for network errors too.
        return None
    if response.status_code != 200:
        return None
    with open(pdb_path, 'wb') as f:
        f.write(response.content)
    return pdb_path
# Extract sequence and predict binding scores
def process_pdb(pdb_id, segment):
    """Predict per-residue binding scores for one chain of a PDB entry.

    Args:
        pdb_id: PDB identifier, passed to ``fetch_pdb``.
        segment: Chain ID looked up in the first model of the structure.

    Returns:
        A ``(text, pdb_path, predictions_file)`` tuple. ``text`` holds one
        line per residue (``resname resnum one_letter score``) or an error
        message. ``pdb_path`` is ``None`` when the download failed, and
        ``predictions_file`` is ``None`` whenever no predictions were
        written.
    """
    # Local imports: both ship with Biopython, which this file already uses.
    from Bio.PDB.Polypeptide import is_aa
    from Bio.SeqUtils import seq1

    pdb_path = fetch_pdb(pdb_id)
    if not pdb_path:
        return "Failed to fetch PDB file", None, None

    parser = PDBParser(QUIET=True)
    structure = parser.get_structure('protein', pdb_path)
    try:
        chain = structure[0][segment]
    except KeyError:
        # Unknown chain ID: report it instead of crashing the callback.
        return f"Chain '{segment}' not found in {pdb_id}", pdb_path, None

    # Keep amino-acid residues only, so waters and ligands stored in the
    # same chain are not fed to the model as if they were sequence.
    residues = [residue for residue in chain if is_aa(residue)]
    if not residues:
        return (
            f"No amino-acid residues found in chain '{segment}'",
            pdb_path,
            None,
        )

    # get_resname() yields three-letter codes ("ALA"); the tokenizer input
    # below is built from one character per residue, so convert each name to
    # its one-letter code first ('X' when the name is not recognised).
    sequence = "".join(
        seq1(residue.get_resname().strip()) or "X" for residue in residues
    )

    input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        outputs = model(input_ids).logits.detach().cpu().numpy().squeeze()
    scores = outputs[:, 1] - outputs[:, 0]

    if len(scores) < len(residues):
        return (
            "Model returned fewer scores than residues in the chain",
            pdb_path,
            None,
        )

    # Scores are matched to residues by position from the start; any extra
    # trailing scores are ignored, exactly as positional indexing would do.
    result_str = "\n".join(
        f"{res.get_resname()} {res.id[1]} {letter} {score:.2f}"
        for res, letter, score in zip(residues, sequence, scores)
    )

    # Derive the output name from the downloaded file's basename rather than
    # the raw user input, so it can never point outside the working dir.
    stem = os.path.splitext(os.path.basename(pdb_path))[0]
    predictions_path = f"{stem}_predictions.txt"
    with open(predictions_path, "w", encoding="utf-8") as f:
        f.write(result_str)
    return result_str, pdb_path, predictions_path
# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# Protein Binding Site Prediction")

    # Input controls and action buttons.
    # NOTE(review): indentation was lost in the copy this was restored from;
    # placing all four controls inside the row is assumed — confirm.
    with gr.Row():
        pdb_input = gr.Textbox(label="PDB ID")
        segment_input = gr.Textbox(label="Segment (Chain ID)")
        visualize_btn = gr.Button("Visualize")
        prediction_btn = gr.Button("Predict")

    # Output widgets: 3D viewer, plain-text scores, downloadable file.
    molecule_output = Molecule3D(label="Protein Structure", reps=reps)
    predictions_output = gr.Textbox(label="Binding Site Predictions")
    download_output = gr.File(label="Download Predictions")

    # "Visualize" only downloads the structure and shows it in the viewer.
    visualize_btn.click(
        fn=fetch_pdb,
        inputs=[pdb_input],
        outputs=molecule_output,
    )

    # "Predict" fills the text box, the viewer and the download slot.
    prediction_btn.click(
        fn=process_pdb,
        inputs=[pdb_input, segment_input],
        outputs=[predictions_output, molecule_output, download_output],
    )

demo.launch(share=True)