Spaces:
Sleeping
Sleeping
import gradio as gr | |
from model_loader import load_model | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss | |
from torch.utils.data import DataLoader | |
import re | |
import numpy as np | |
import os | |
import pandas as pd | |
import copy | |
import transformers, datasets | |
from transformers.modeling_outputs import TokenClassifierOutput | |
from transformers.models.t5.modeling_t5 import T5Config, T5PreTrainedModel, T5Stack | |
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map | |
from transformers import T5EncoderModel, T5Tokenizer | |
from transformers.models.esm.modeling_esm import EsmPreTrainedModel, EsmModel | |
from transformers import AutoTokenizer | |
from transformers import TrainingArguments, Trainer, set_seed | |
from transformers import DataCollatorForTokenClassification | |
from dataclasses import dataclass | |
from typing import Dict, List, Optional, Tuple, Union | |
# for custom DataCollator | |
from transformers.data.data_collator import DataCollatorMixin | |
from transformers.tokenization_utils_base import PreTrainedTokenizerBase | |
from transformers.utils import PaddingStrategy | |
from datasets import Dataset | |
from scipy.special import expit | |
import requests | |
import py3Dmol | |
#import peft | |
#from peft import get_peft_config, PeftModel, PeftConfig, inject_adapter_in_model, LoraConfig | |
# Configuration: the checkpoint to serve and the tokenizer truncation length.
checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
max_length = 1500

# Choose the inference device: CUDA when torch reports it available, else CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Load model and tokenizer, move the model to the device and call eval() on it.
model, tokenizer = load_model(checkpoint, max_length)
model.to(device)
model.eval()
def create_dataset(tokenizer, seqs, labels, checkpoint):
    """Tokenize ``seqs`` and pair the result with truncated labels in a ``Dataset``.

    Args:
        tokenizer: Callable tokenizer; invoked with the module-level
            ``max_length``, truncation enabled and padding disabled.
        seqs: Sequences handed to the tokenizer.
        labels: One sliceable label sequence per entry of ``seqs``.
        checkpoint: Checkpoint name; selects how many label positions are kept.

    Returns:
        A ``Dataset`` built from the tokenizer output with an added
        ``labels`` column.
    """
    encoded = tokenizer(seqs, max_length=max_length, padding=False, truncation=True)

    # Checkpoints whose name contains "esm" or "ProstT5" keep max_length - 2
    # label positions, all others max_length - 1 (presumably mirroring the
    # number of special tokens the tokenizer adds -- TODO confirm).
    is_esm_or_prostt5 = ("esm" in checkpoint) or ("ProstT5" in checkpoint)
    keep = max_length - 2 if is_esm_or_prostt5 else max_length - 1
    trimmed = [label[:keep] for label in labels]

    return Dataset.from_dict(encoded).add_column("labels", trimmed)
def convert_predictions(input_logits):
    """Convert two-class logits into class-1 probabilities.

    Args:
        input_logits: Iterable of arrays; each one is reshaped to ``(-1, 2)``
            so that every row holds the logits for class 0 and class 1.

    Returns:
        A 1-D array with the class-1 probability of every row, concatenated
        over all entries of ``input_logits`` in order.
    """
    def _class1_probability(raw):
        pairs = raw.reshape(-1, 2)
        # The sigmoid of (logit_1 - logit_0) equals the two-way softmax
        # probability of class 1.
        return expit(pairs[:, 1] - pairs[:, 0])

    return np.concatenate([_class1_probability(raw) for raw in input_logits])
def normalize_scores(scores):
    """Min-max scale ``scores`` into the range [0, 1].

    Args:
        scores: Array of numeric scores.

    Returns:
        ``(scores - min) / (max - min)`` when the maximum is strictly greater
        than the minimum; otherwise ``scores`` itself, unchanged.
    """
    lowest = np.min(scores)
    highest = np.max(scores)
    if highest > lowest:
        return (scores - lowest) / (highest - lowest)
    # All values are equal: there is no spread to scale by.
    return scores
def predict_protein_sequence(test_one_letter_sequence):
    """Score every residue of a one-letter protein sequence with the loaded model.

    The input is cleaned (whitespace removed, upper-cased, the residue codes
    O/B/U/Z/J mapped to X), formatted according to the module-level
    ``checkpoint`` name, run through the module-level ``model`` and the
    resulting class-1 probabilities are min-max normalised.

    Args:
        test_one_letter_sequence: Protein sequence in one-letter code. May
            contain whitespace or line breaks; ``None`` is treated as empty.

    Returns:
        One ``"<residue>: <score>"`` line per scored residue, joined with
        newlines, or a short message when no sequence was supplied.
    """
    # Textbox input can carry spaces, line breaks and lower-case letters; strip
    # them here so the tokenizer only ever sees residue codes.
    residues = "".join((test_one_letter_sequence or "").split()).upper()
    if not residues:
        # Without this guard the empty score array would crash np.min further down.
        return "No protein sequence provided."

    # Sanitize input sequence: map the rare/ambiguous codes to X in one pass.
    residues = residues.translate(str.maketrans("OBUZJ", "XXXXX"))

    # Prepare sequence for different model types. The clean residue string is
    # kept separately so the output is never built from the spaced/prefixed
    # model input.
    model_input = residues
    if ("prot_t5" in checkpoint) or ("ProstT5" in checkpoint):
        model_input = " ".join(model_input)
    if "ProstT5" in checkpoint:
        model_input = "<AA2fold> " + model_input

    # Create dummy labels: one per residue (not per character of the spaced input).
    dummy_labels = [np.zeros(len(residues))]

    # Create dataset
    test_dataset = create_dataset(tokenizer, [model_input], dummy_labels, checkpoint)

    # The same collator is used for every checkpoint type.
    data_collator = DataCollatorForTokenClassification(tokenizer)

    # Create data loader; the dataset holds exactly one sequence, hence one batch.
    test_loader = DataLoader(test_dataset, batch_size=1, collate_fn=data_collator)
    batch = next(iter(test_loader))

    # Predict
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
    logits = outputs.logits.detach().cpu().numpy()

    # Process logits
    # NOTE(review): dropping only the last position is the prot_t5 layout;
    # other checkpoint families may need different trimming -- confirm before
    # switching `checkpoint`.
    logits = logits[:, :-1]  # Remove last element for prot_t5
    probabilities = convert_predictions(logits)

    # Normalize and format results. zip() stops at the shorter input, so
    # residues beyond the tokenizer's truncation limit are simply not listed.
    normalized_scores = normalize_scores(probabilities)
    return "\n".join(
        f"{aa}: {score:.2f}" for aa, score in zip(residues, normalized_scores)
    )
def fetch_and_display_pdb(pdb_id):
    """Download a PDB entry from RCSB and return HTML embedding a 3Dmol.js viewer.

    Args:
        pdb_id: PDB identifier typed by the user. Surrounding whitespace is
            ignored; ``None`` is treated as empty.

    Returns:
        An HTML snippet (container ``div`` plus ``script`` tags) on success,
        or a plain error message when the identifier is malformed, the
        download does not return HTTP 200, or the request itself fails.
    """
    import json  # local import: only needed to embed the PDB text safely

    failure_message = "Failed to load PDB structure. Please check the PDB ID."

    # Only plain identifiers may reach the URL; this also short-circuits empty
    # input without a network round trip.
    pdb_id = (pdb_id or "").strip()
    if not re.fullmatch(r"[A-Za-z0-9_]{4,12}", pdb_id):
        return failure_message

    # Fetch the PDB structure from RCSB
    pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'
    try:
        # A timeout keeps a stalled connection from blocking the UI callback forever.
        response = requests.get(pdb_url, timeout=30)
    except requests.RequestException as e:
        return f"Error visualizing PDB: {str(e)}"

    if response.status_code != 200:
        return failure_message

    # Embed the file as a JSON string literal (valid JavaScript) instead of a
    # template literal, so backticks or "${" in the PDB text cannot break the
    # script; "</" is escaped so the text cannot close the <script> element.
    pdb_literal = json.dumps(response.text).replace("</", "<\\/")

    # Prepare the 3D molecular visualization
    # NOTE(review): whether inline <script> tags are executed depends on how
    # the front end inserts this HTML -- confirm the viewer actually renders.
    return f"""
<div id="container" style="width: 100%; height: 400px; position: relative;"></div>
<script src="https://3dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
<script>
let viewer = $3Dmol.createViewer(document.getElementById("container"));
viewer.addModel({pdb_literal}, "pdb");
viewer.setStyle({{}}, {{"cartoon": {{"color": "spectrum"}}}});
viewer.zoomTo();
viewer.render();
</script>
"""
def gradio_interface(sequence, pdb_id):
    """Gradio callback producing both outputs of the app.

    Args:
        sequence: Protein sequence passed to ``predict_protein_sequence``.
        pdb_id: PDB identifier passed to ``fetch_and_display_pdb``.

    Returns:
        A ``(prediction_text, viewer_html)`` tuple; the prediction is computed
        first, then the structure is fetched.
    """
    return predict_protein_sequence(sequence), fetch_and_display_pdb(pdb_id)
# Create Gradio interface: two text inputs feed gradio_interface, which fills
# one text output and one HTML output.
_sequence_box = gr.Textbox(lines=2, placeholder="Enter protein sequence here...", label="Protein Sequence")
_pdb_id_box = gr.Textbox(lines=1, placeholder="Enter PDB ID here...", label="PDB ID for 3D Visualization")
_prediction_box = gr.Textbox(label="Binding Site Predictions")
_viewer_html = gr.HTML(label="3D Molecular Viewer")

interface = gr.Interface(
    fn=gradio_interface,
    inputs=[_sequence_box, _pdb_id_box],
    outputs=[_prediction_box, _viewer_html],
    title="Protein Binding Site Prediction and 3D Structure Viewer",
    description="Input a protein sequence to predict binding sites and view the protein structure in 3D using its PDB ID.",
)

# Start the web app.
interface.launch()