"""Gradio app: per-residue binding-site prediction plus 3D structure display.

The script loads a token-classification model once at import time, scores each
residue of a user-supplied protein sequence, and optionally downloads a PDB
file from RCSB so the structure can be rendered next to the scores.
"""

import gradio as gr
from model_loader import load_model

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.utils.data import DataLoader

import re
import numpy as np
import os
import pandas as pd
import copy

import transformers, datasets
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.models.t5.modeling_t5 import T5Config, T5PreTrainedModel, T5Stack
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
from transformers import T5EncoderModel, T5Tokenizer
from transformers.models.esm.modeling_esm import EsmPreTrainedModel, EsmModel
from transformers import AutoTokenizer
from transformers import TrainingArguments, Trainer, set_seed
from transformers import DataCollatorForTokenClassification

from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

# for custom DataCollator
from transformers.data.data_collator import DataCollatorMixin
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy

from datasets import Dataset

from scipy.special import expit

import requests

from gradio_molecule3d import Molecule3D

#import peft
#from peft import get_peft_config, PeftModel, PeftConfig, inject_adapter_in_model, LoraConfig

# Configuration
checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
max_length = 1500

# Default representations for molecule rendering
reps = [
    {
        "model": 0,
        "chain": "",
        "resname": "",
        "style": "cartoon",
        "color": "spectrum",
        "residue_range": "",
        "around": 0,
        "byres": False,
        "visible": True
    }
]

# Residue codes O, B, U, Z and J are all replaced by the placeholder "X"
# before tokenization.
_RARE_RESIDUE_TABLE = str.maketrans("OBUZJ", "XXXXX")

# A PDB identifier is accepted only if it is exactly four alphanumeric
# characters; this keeps user input out of the download URL and file path
# unless it has that shape.
_PDB_ID_PATTERN = re.compile(r"[A-Za-z0-9]{4}")

# Seconds to wait for RCSB before giving up on a download.
_PDB_REQUEST_TIMEOUT = 30

# Load model and move to device
model, tokenizer = load_model(checkpoint, max_length)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()


def create_dataset(tokenizer, seqs, labels, checkpoint):
    """Tokenize *seqs* and attach per-token *labels* as a ``Dataset``.

    Args:
        tokenizer: Tokenizer called on ``seqs`` (truncating at the module-level
            ``max_length``, without padding).
        seqs: List of input sequences in the format the tokenizer expects.
        labels: One label sequence per input sequence.
        checkpoint: Checkpoint name; selects how many positions are reserved
            when truncating the labels.

    Returns:
        A ``datasets.Dataset`` holding the tokenizer outputs and a ``labels``
        column.
    """
    tokenized = tokenizer(seqs, max_length=max_length, padding=False, truncation=True)
    dataset = Dataset.from_dict(tokenized)

    # Adjust labels based on checkpoint: two positions are reserved for
    # "esm"/"ProstT5" checkpoints, one otherwise.
    reserved = 2 if ("esm" in checkpoint) or ("ProstT5" in checkpoint) else 1
    labels = [label[:max_length - reserved] for label in labels]

    dataset = dataset.add_column("labels", labels)
    return dataset


def convert_predictions(input_logits):
    """Turn two-class logits into class-1 probabilities.

    Each element of *input_logits* is reshaped to ``(-1, 2)``; the probability
    of class 1 is the logistic function of the logit difference, which equals
    the two-class softmax probability.

    Returns:
        A 1-D array with the probabilities of all elements concatenated.
    """
    all_probs = []
    for logits in input_logits:
        pairs = logits.reshape(-1, 2)
        all_probs.append(expit(pairs[:, 1] - pairs[:, 0]))
    return np.concatenate(all_probs)


def normalize_scores(scores):
    """Min-max scale *scores* to [0, 1].

    If all scores are equal the input is returned unchanged, which avoids a
    division by zero.
    """
    min_score = np.min(scores)
    max_score = np.max(scores)
    if max_score > min_score:
        return (scores - min_score) / (max_score - min_score)
    return scores


def predict_protein_sequence(test_one_letter_sequence):
    """Score every residue of a one-letter protein sequence.

    Args:
        test_one_letter_sequence: Protein sequence in one-letter code.
            Whitespace (spaces, newlines from pasting) is ignored.

    Returns:
        A string with one ``"<residue>: <score>"`` line per residue, where the
        score is the min-max normalized class-1 probability.

    Raises:
        ValueError: If the sequence contains no residues.
    """
    # Sanitize input sequence: drop whitespace, map rare residue codes to X.
    residues = "".join(test_one_letter_sequence.split()).translate(_RARE_RESIDUE_TABLE)
    if not residues:
        raise ValueError("Protein sequence is empty.")

    # Prepare sequence for different model types
    model_input = residues
    if ("prot_t5" in checkpoint) or ("ProstT5" in checkpoint):
        model_input = " ".join(model_input)

    if "ProstT5" in checkpoint:
        # NOTE(review): only a single space is prepended here; confirm whether
        # a ProstT5 direction tag was intended as the prefix.
        model_input = " " + model_input

    # Create dummy labels, one per residue. The model call below only receives
    # input_ids and attention_mask, so the label values are never used.
    dummy_labels = [np.zeros(len(residues), dtype=np.int64)]

    # Create dataset
    test_dataset = create_dataset(tokenizer, [model_input], dummy_labels, checkpoint)

    # The same collator is used for every checkpoint type.
    data_collator = DataCollatorForTokenClassification(tokenizer)

    # Create data loader; the dataset holds exactly one sequence, so there is
    # exactly one batch.
    test_loader = DataLoader(test_dataset, batch_size=1, collate_fn=data_collator)
    batch = next(iter(test_loader))

    # Predict
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
    logits = outputs.logits.detach().cpu().numpy()

    # Process logits
    # Remove last element for prot_t5 (presumably the end-of-sequence token).
    logits = logits[:, :-1]
    probabilities = convert_predictions(logits)

    # Normalize and format results. zip() stops at the shorter input, so a
    # sequence truncated by the tokenizer simply yields fewer lines.
    normalized_scores = normalize_scores(probabilities)
    result_str = "\n".join(
        f"{aa}: {score:.2f}" for aa, score in zip(residues, normalized_scores)
    )
    return result_str


def fetch_pdb(pdb_id):
    """Download a PDB file from RCSB into ``pdb_files/``.

    Args:
        pdb_id: Four-character alphanumeric PDB identifier. Surrounding
            whitespace is ignored.

    Returns:
        Path of the saved file, or ``None`` if the identifier is missing or
        malformed, the server does not answer with HTTP 200, or a network or
        file error occurs.
    """
    pdb_id = (pdb_id or "").strip()
    # Reject anything that is not a plain 4-character ID before it is
    # interpolated into a URL and a filesystem path.
    if not _PDB_ID_PATTERN.fullmatch(pdb_id):
        return None

    try:
        # Create a directory to store PDB files if it doesn't exist
        os.makedirs('pdb_files', exist_ok=True)

        # Fetch the PDB structure from RCSB
        pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'
        pdb_path = f'pdb_files/{pdb_id}.pdb'

        # Download the file; the timeout keeps a stalled connection from
        # blocking the request handler indefinitely.
        response = requests.get(pdb_url, timeout=_PDB_REQUEST_TIMEOUT)
        if response.status_code != 200:
            return None

        with open(pdb_path, 'wb') as f:
            f.write(response.content)
        return pdb_path
    except (requests.RequestException, OSError) as e:
        print(f"Error fetching PDB: {e}")
        return None


def process_input(sequence, pdb_id):
    """Gradio callback: predict binding sites and fetch the structure file.

    Args:
        sequence: Protein sequence entered by the user.
        pdb_id: PDB identifier entered by the user (may be empty).

    Returns:
        Tuple ``(prediction_text, pdb_path)``; ``pdb_path`` is ``None`` when no
        structure could be fetched.

    Raises:
        gr.Error: If no sequence was entered, so the UI shows a readable
            message.
    """
    if not sequence or not sequence.strip():
        raise gr.Error("Please enter a protein sequence.")

    # Predict binding sites
    binding_site_predictions = predict_protein_sequence(sequence)

    # Fetch PDB file
    pdb_path = fetch_pdb(pdb_id)

    return binding_site_predictions, pdb_path


# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Protein Binding Site Prediction")

    with gr.Row():
        with gr.Column():
            # Sequence input
            sequence_input = gr.Textbox(
                lines=2,
                placeholder="Enter protein sequence here...",
                label="Protein Sequence"
            )

            # PDB ID input
            pdb_input = gr.Textbox(
                lines=1,
                placeholder="Enter PDB ID here...",
                label="PDB ID for 3D Visualization"
            )

            # Predict button
            predict_btn = gr.Button("Predict Binding Sites")

        with gr.Column():
            # Binding site predictions output
            predictions_output = gr.Textbox(
                label="Binding Site Predictions"
            )

            # 3D Molecule visualization
            molecule_output = Molecule3D(
                label="Protein Structure",
                reps=reps
            )

    # Prediction logic
    predict_btn.click(
        process_input,
        inputs=[sequence_input, pdb_input],
        outputs=[predictions_output, molecule_output]
    )

    # Add some example inputs
    gr.Markdown("## Examples")
    gr.Examples(
        examples=[
            ["MKVLWAALLVTFLAGCQAKVEQAVETEPEPELRQQTEWQSGQRWELALGRFWDYLRWVQTLSEQVQEELLSSQVTQELRALMDETMKELKAYKSELEEQLTPVAEETRARLSKELQAAQARLGADMEDVCGRLVQYRGEVQAMLGQSTEELRVRLASHLRKLRKRLLRDADDLQKRLAVYQAGAREGAERGLSAIRERLGPLVEQGRVRAATVGSLAGQPLQERAQAWGERLRARMEEMGSRTRDRLDEVKEQVAEVRAKLEEQAQQRL", "1ABC"],
        ],
        inputs=[sequence_input, pdb_input],
        outputs=[predictions_output, molecule_output]
    )

demo.launch()