import gradio as gr
from model_loader import load_model

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import re
import numpy as np
import os
import pandas as pd
import copy

import transformers, datasets
from transformers import AutoTokenizer
from transformers import DataCollatorForTokenClassification

from datasets import Dataset

from scipy.special import expit

import requests

from gradio_molecule3d import Molecule3D

# Biopython imports
from Bio.PDB import PDBParser, Select, PDBIO
from Bio.PDB.DSSP import DSSP
from Bio.PDB import PDBList
from Bio.PDB.Polypeptide import is_aa
from Bio.SeqUtils import seq1

from matplotlib import cm  # For color mapping
from matplotlib.colors import Normalize

# Load model and move to device
checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
max_length = 1500
model, tokenizer = load_model(checkpoint, max_length)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()

reps = [{"model": 0, "style": "cartoon", "color": "spectrum"}]

# Seconds to wait for RCSB before giving up. Without a timeout a stalled
# connection would block the Gradio worker indefinitely.
PDB_DOWNLOAD_TIMEOUT = 30

# Only letters, digits and underscore are accepted as a PDB ID. The ID is
# interpolated into file paths, so this also prevents user input (the app is
# shared publicly) from escaping the pdb_files/ directory via '/' or '..'.
_PDB_ID_RE = re.compile(r"[A-Za-z0-9_]+")


def _clean_pdb_id(pdb_id):
    """Return *pdb_id* stripped of surrounding whitespace, or None if unsafe/empty."""
    if pdb_id is None:
        return None
    pdb_id = str(pdb_id).strip()
    return pdb_id if _PDB_ID_RE.fullmatch(pdb_id) else None


def fetch_pdb(pdb_id):
    """Download a PDB entry from RCSB and save it under pdb_files/.

    Args:
        pdb_id: PDB identifier entered by the user.

    Returns:
        Path of the saved ``.pdb`` file, or None if the ID is invalid, the
        request fails or times out, or RCSB does not answer with HTTP 200.
    """
    pdb_id = _clean_pdb_id(pdb_id)
    if pdb_id is None:
        return None

    pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'
    pdb_path = f'pdb_files/{pdb_id}.pdb'
    os.makedirs('pdb_files', exist_ok=True)

    try:
        response = requests.get(pdb_url, timeout=PDB_DOWNLOAD_TIMEOUT)
    except requests.RequestException:
        # Network errors are reported the same way as a bad HTTP status.
        return None
    if response.status_code != 200:
        return None

    with open(pdb_path, 'wb') as f:
        f.write(response.content)
    return pdb_path


def process_pdb(pdb_id, segment):
    """Predict per-residue binding scores for one chain of a PDB entry.

    Args:
        pdb_id: PDB identifier to download from RCSB.
        segment: Chain ID to analyse within the first model of the structure.

    Returns:
        A ``(result_text, pdb_path, predictions_file)`` tuple. ``result_text``
        has one line per amino-acid residue:
        ``"<resname> <resnum> <one-letter code> <score>"``. On failure the
        first element is an error message and the file entries are None
        (``pdb_path`` is still returned when the structure was downloaded).
    """
    pdb_id = _clean_pdb_id(pdb_id)
    pdb_path = fetch_pdb(pdb_id) if pdb_id else None
    if not pdb_path:
        return "Failed to fetch PDB file", None, None

    parser = PDBParser(QUIET=True)
    structure = parser.get_structure('protein', pdb_path)

    segment = (segment or "").strip()
    try:
        chain = structure[0][segment]
    except KeyError:
        return f"Chain '{segment}' not found in {pdb_id}", pdb_path, None

    # Keep only amino-acid residues. Waters, ligands and ions are also
    # children of the chain but must not enter the protein sequence.
    residues = [res for res in chain if is_aa(res)]
    if not residues:
        return f"No amino-acid residues found in chain '{segment}'", pdb_path, None

    # The model consumes one-letter amino-acid codes, not three-letter
    # residue names; seq1 maps unrecognised names to 'X'.
    sequence = "".join(seq1(res.get_resname()) for res in residues)
    # ProtT5-style preprocessing: rare/ambiguous amino acids become X.
    model_sequence = re.sub(r"[UZOB]", "X", sequence)

    input_ids = tokenizer(" ".join(model_sequence), return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        # Drop the batch dimension of size 1 -> (positions, classes).
        logits = model(input_ids).logits.detach().squeeze(0).cpu().numpy()
    scores = logits[:, 1] - logits[:, 0]

    # zip pairs each residue with its own score and ignores any extra
    # positions at the end (e.g. a special token appended by the tokenizer).
    result_str = "\n".join(
        f"{res.get_resname()} {res.id[1]} {code} {score:.2f}"
        for res, code, score in zip(residues, sequence, scores)
    )

    predictions_path = f"{pdb_id}_predictions.txt"
    with open(predictions_path, "w", encoding="utf-8") as f:
        f.write(result_str)

    return result_str, pdb_path, predictions_path


# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# Protein Binding Site Prediction")

    with gr.Row():
        pdb_input = gr.Textbox(label="PDB ID")
        segment_input = gr.Textbox(label="Segment (Chain ID)")
        visualize_btn = gr.Button("Visualize")
        prediction_btn = gr.Button("Predict")

    molecule_output = Molecule3D(label="Protein Structure", reps=reps)
    predictions_output = gr.Textbox(label="Binding Site Predictions")
    download_output = gr.File(label="Download Predictions")

    visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output)
    prediction_btn.click(
        process_pdb,
        inputs=[pdb_input, segment_input],
        outputs=[predictions_output, molecule_output, download_output],
    )

demo.launch(share=True)