File size: 3,853 Bytes
6963cf4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e29637
 
8bd6bbb
 
11bcc1a
a28eeb5
11bcc1a
e0408f3
6963cf4
8bd6bbb
 
 
4499595
a6b7cf0
 
 
aae512c
 
 
6963cf4
4499595
1705a41
4499595
 
 
 
 
 
 
 
 
 
 
 
1f960e0
 
 
 
 
 
4499595
 
 
 
 
 
a28eeb5
 
4499595
a28eeb5
4499595
a28eeb5
4499595
 
 
6963cf4
1f960e0
 
 
4499595
1f960e0
4499595
 
a6b7cf0
4499595
 
a6b7cf0
4499595
8bd6bbb
4499595
01ff8b6
 
8bd6bbb
01ff8b6
1f960e0
 
 
 
 
 
 
 
4499595
 
 
 
 
 
 
 
 
 
01ff8b6
 
1f960e0
 
 
 
 
 
 
 
 
 
 
4499595
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import gradio as gr
from model_loader import load_model

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import re
import numpy as np
import os
import pandas as pd
import copy

import transformers, datasets
from transformers import AutoTokenizer
from transformers import DataCollatorForTokenClassification

from datasets import Dataset

from scipy.special import expit

import requests

from gradio_molecule3d import Molecule3D

# Biopython imports
from Bio.PDB import PDBParser, Select, PDBIO
from Bio.PDB.DSSP import DSSP
from Bio.PDB import PDBList

from matplotlib import cm  # For color mapping
from matplotlib.colors import Normalize

# Model configuration: Hugging Face checkpoint ID plus the maximum sequence
# length that is forwarded to the project's `load_model` helper.
checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
max_length = 1500

# Build the model/tokenizer pair, then move the model onto the GPU when CUDA
# is available (CPU otherwise) and switch it to evaluation mode for inference.
model, tokenizer = load_model(checkpoint, max_length)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
model.eval()

# Default Molecule3D representation: model 0 drawn as a cartoon using the
# "spectrum" colour scheme.
reps = [
    {
        "model": 0,
        "style": "cartoon",
        "color": "spectrum",
    }
]

# Function to fetch a PDB file
def fetch_pdb(pdb_id):
    pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'
    pdb_path = f'pdb_files/{pdb_id}.pdb'
    os.makedirs('pdb_files', exist_ok=True)
    response = requests.get(pdb_url)
    if response.status_code == 200:
        with open(pdb_path, 'wb') as f:
            f.write(response.content)
        return pdb_path
    return None


def normalize_scores(scores):
    """Min-max scale *scores* to the range [0, 1].

    Args:
        scores: 1-D array-like of numeric scores.

    Returns:
        A float ``numpy.ndarray``. If every score is identical (including the
        single-element case), the values are returned unscaled, because
        min-max scaling would divide by zero. An empty input yields an empty
        array instead of raising.
    """
    values = np.asarray(scores, dtype=float)
    if values.size == 0:
        # np.min/np.max raise on zero-size arrays.
        return values
    min_score = values.min()
    max_score = values.max()
    if max_score > min_score:
        return (values - min_score) / (max_score - min_score)
    return values

# Extract sequence and predict binding scores

# Three-letter residue names -> one-letter amino-acid codes: the 20 standard
# residues, selenocysteine/pyrrolysine, and a few common modified residues
# mapped to their parent amino acid.
_THREE_TO_ONE = {
    "ALA": "A", "ARG": "R", "ASN": "N", "ASP": "D", "CYS": "C",
    "GLN": "Q", "GLU": "E", "GLY": "G", "HIS": "H", "ILE": "I",
    "LEU": "L", "LYS": "K", "MET": "M", "PHE": "F", "PRO": "P",
    "SER": "S", "THR": "T", "TRP": "W", "TYR": "Y", "VAL": "V",
    "SEC": "U", "PYL": "O",
    "MSE": "M", "SEP": "S", "TPO": "T", "PTR": "Y",
}


def _protein_residues(chain):
    """Return ``(residues, sequence)`` for the amino-acid residues of *chain*.

    Waters and ligands (hetero residues whose name is not in
    ``_THREE_TO_ONE``) are skipped so that residues, sequence letters and
    model scores stay index-aligned. Other non-hetero residues that carry a
    CA atom are kept and written as ``X``.
    """
    residues = []
    letters = []
    for res in chain:
        name = res.get_resname().strip().upper()
        # res.id[0] is the hetero flag: " " for standard ATOM residues.
        if name in _THREE_TO_ONE or (res.id[0] == " " and "CA" in res):
            residues.append(res)
            letters.append(_THREE_TO_ONE.get(name, "X"))
    return residues, "".join(letters)


def _predict_scores(sequence):
    """Return one binding probability per letter of the one-letter *sequence*."""
    # ProtT5 input convention (model card): space-separated residues, with the
    # rare amino acids U/Z/O/B replaced by X.
    model_input = " ".join(re.sub(r"[UZOB]", "X", sequence))
    input_ids = tokenizer(model_input, return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        logits = model(input_ids).logits
    # Assumes logits are (batch=1, tokens, 2). Keep only the first
    # len(sequence) rows so trailing special tokens (e.g. </s>) are dropped and
    # cannot distort the normalisation. `.float()` lets half-precision
    # (fp16/bf16) outputs convert to NumPy.
    logits = logits[0, :len(sequence)].float().cpu().numpy()
    # expit(l1 - l0) equals the two-class softmax probability of class 1
    # (presumably the "binding" class).
    return expit(logits[:, 1] - logits[:, 0])


def process_pdb(pdb_id, segment):
    """Predict per-residue ligand-binding scores for one chain of a PDB entry.

    Args:
        pdb_id: PDB identifier to download from RCSB.
        segment: Chain identifier inside the first model of the structure.

    Returns:
        A ``(text, pdb_path, predictions_path)`` tuple for the Gradio outputs.
        ``text`` has one line per amino-acid residue,
        ``<resname> <residue number> <one-letter code> <score>``, with scores
        min-max normalised over the chain; the same text is written to
        ``<pdb_id>_predictions.txt``. On failure ``text`` is an error message
        and ``predictions_path`` is ``None``.
    """
    pdb_path = fetch_pdb(pdb_id)
    if not pdb_path:
        return "Failed to fetch PDB file", None, None

    structure = PDBParser(QUIET=True).get_structure('protein', pdb_path)
    first_model = structure[0]
    segment = (segment or "").strip()
    if segment not in first_model:
        # Still return the structure so the user can inspect which chains exist.
        return f"Chain '{segment}' not found in {pdb_id}", pdb_path, None

    residues, sequence = _protein_residues(first_model[segment])
    if not sequence:
        return f"Chain '{segment}' of {pdb_id} contains no amino-acid residues", pdb_path, None

    normalized_scores = normalize_scores(_predict_scores(sequence))

    result_str = "\n".join(
        f"{res.get_resname()} {res.id[1]} {letter} {score:.2f}"
        for res, letter, score in zip(residues, sequence, normalized_scores)
    )

    predictions_path = f"{pdb_id}_predictions.txt"
    with open(predictions_path, "w") as f:
        f.write(result_str)

    return result_str, pdb_path, predictions_path

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# Protein Binding Site Prediction")

    with gr.Row():
        pdb_input = gr.Textbox(value="2IWI",
                label="PDB ID",
                placeholder="Enter PDB ID here...")
        segment_input = gr.Textbox(value="A",
                label="Chain ID (Segment)",
                placeholder="Enter Chain ID here...")
        visualize_btn = gr.Button("Visualize Structure")
        prediction_btn = gr.Button("Predict Ligand Binding Site")

    molecule_output = Molecule3D(label="Protein Structure", reps=reps)
    predictions_output = gr.Textbox(label="Binding Site Predictions")
    download_output = gr.File(label="Download Predictions")

    # "Visualize" only downloads the structure; the file path returned by
    # fetch_pdb is sent straight to the Molecule3D viewer.
    visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output)
    prediction_btn.click(
        process_pdb,
        inputs=[pdb_input, segment_input],
        outputs=[predictions_output, molecule_output, download_output]
    )

    gr.Markdown("## Examples")
    # Each example row must supply one value per component in `inputs`
    # (PDB ID, chain ID); "A" matches the chain textbox default.
    gr.Examples(
        examples=[
            ["2IWI", "A"],
            ["7RPZ", "A"],
            ["3TJN", "A"]
        ],
        inputs=[pdb_input, segment_input],
        outputs=[predictions_output, molecule_output, download_output]
    )

# Launch only when run as a script, so importing this module (e.g. by Gradio's
# reload mode or tests) does not start a server.
if __name__ == "__main__":
    demo.launch(share=True)