File size: 9,300 Bytes
6963cf4
c85a5b0
 
 
 
 
 
 
6963cf4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4499595
a6b7cf0
 
 
aae512c
 
 
6963cf4
c85a5b0
 
 
 
 
 
 
 
 
 
4499595
 
c85a5b0
4499595
 
 
 
 
c85a5b0
 
1f960e0
4499595
 
 
 
 
a28eeb5
 
c85a5b0
 
 
 
 
a28eeb5
f506fb3
8bef2d8
 
 
 
 
 
 
 
 
f506fb3
 
 
8bef2d8
f506fb3
fd6cc24
 
 
 
a28eeb5
8bef2d8
4499595
 
 
6963cf4
8bef2d8
1f960e0
 
c85a5b0
fd6cc24
 
 
ff689ae
 
 
 
1f960e0
c85a5b0
 
 
4499595
a6b7cf0
fd6cc24
c85a5b0
fd6cc24
c85a5b0
 
 
 
fd6cc24
 
 
 
 
c85a5b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8bd6bbb
fd6cc24
 
c85a5b0
fd6cc24
c85a5b0
 
 
fd6cc24
c85a5b0
fd6cc24
c85a5b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd6cc24
c85a5b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8bef2d8
4499595
01ff8b6
c85a5b0
 
 
 
 
 
8bd6bbb
01ff8b6
1f10fe3
c85a5b0
 
 
 
4499595
 
c85a5b0
 
 
 
 
1f960e0
 
 
fd6cc24
 
c85a5b0
1f960e0
c85a5b0
1f960e0
 
 
4499595
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
import gradio as gr
import requests
from Bio.PDB import PDBParser
import numpy as np
import os
from gradio_molecule3d import Molecule3D


from model_loader import load_model

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import re
import pandas as pd
import copy

import transformers, datasets
from transformers import AutoTokenizer
from transformers import DataCollatorForTokenClassification

from datasets import Dataset

from scipy.special import expit

# Load the checkpoint once at import time; process_pdb reuses this single
# model/tokenizer pair for every request.
checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
max_length = 1500  # forwarded to load_model
model, tokenizer = load_model(checkpoint, max_length)

# Prefer a GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model.to(device)
model.eval()  # evaluation mode for inference

def normalize_scores(scores):
    """Min-max scale ``scores`` into the range [0, 1].

    Args:
        scores: Array-like of numeric scores (a NumPy array or a plain
            Python sequence).

    Returns:
        An ``np.ndarray`` with the same shape as the input. When every value
        is equal (including the single-element case) the range is zero, so the
        values are returned unscaled. An empty input yields an empty array.
    """
    scores = np.asarray(scores)
    if scores.size == 0:
        # np.min / np.max raise ValueError on empty arrays.
        return scores
    min_score = scores.min()
    max_score = scores.max()
    if max_score > min_score:
        return (scores - min_score) / (max_score - min_score)
    # Zero range: scaling would divide by zero, so leave the values as they are.
    return scores
    
def read_mol(pdb_path):
    """Return the complete text content of the PDB file at ``pdb_path``.

    The file is opened in text mode with the platform's default encoding.
    """
    with open(pdb_path) as handle:
        content = handle.read()
    return content

def fetch_pdb(pdb_id):
    """Download an entry from RCSB and save it as ``<pdb_id>.pdb`` in the CWD.

    Args:
        pdb_id: PDB identifier as typed by the user (e.g. ``"2IWI"``).
            Surrounding whitespace is ignored.

    Returns:
        The local file path on success. Returns ``None`` if:

        - the identifier contains characters other than ``[A-Za-z0-9_]``;
        - the HTTP request fails or times out;
        - RCSB does not answer with status 200.
    """
    pdb_id = (pdb_id or "").strip()
    # The ID becomes part of both a URL and a local file name. Restricting it
    # to [A-Za-z0-9_] rules out path traversal ("../") and URL injection.
    if not re.fullmatch(r"[A-Za-z0-9_]+", pdb_id):
        return None

    pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'
    pdb_path = f'{pdb_id}.pdb'
    try:
        # Without a timeout, requests can block indefinitely on a stalled server.
        response = requests.get(pdb_url, timeout=30)
    except requests.RequestException:
        return None
    if response.status_code != 200:
        return None
    with open(pdb_path, 'wb') as f:
        f.write(response.content)
    return pdb_path

def process_pdb(pdb_id, segment):
    """Predict per-residue binding-site scores for one chain of a PDB entry.

    Downloads the structure and extracts the amino-acid sequence of chain
    ``segment`` from the first model. Runs the module-level ``model`` on that
    sequence, then min-max normalizes the resulting per-residue scores.

    Args:
        pdb_id: RCSB PDB identifier (e.g. ``"2IWI"``).
        segment: Chain ID within the first model of the structure.

    Returns:
        A 3-tuple ``(report, viewer_html, prediction_file)``:

        - ``report`` has one line per residue,
          ``"<resname> <resnum> <one-letter> <score>"``;
        - ``viewer_html`` is the 3D viewer iframe;
        - ``prediction_file`` is the path of the saved report.

        On failure, ``report`` is an error message and the other two are
        ``None``.
    """
    pdb_path = fetch_pdb(pdb_id)
    if not pdb_path:
        return "Failed to fetch PDB file", None, None

    parser = PDBParser(QUIET=1)
    structure = parser.get_structure('protein', pdb_path)

    try:
        chain = structure[0][segment]
    except KeyError:
        return "Invalid Chain ID", None, None

    # Three-letter residue names -> one-letter codes. A few common modified
    # residues (MSE, SEP, TPO, CSO, PTR, HYP) map to their parent amino acid.
    aa_dict = {
        'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',
        'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',
        'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',
        'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',
        'MSE': 'M', 'SEP': 'S', 'TPO': 'T', 'CSO': 'C', 'PTR': 'Y', 'HYP': 'P'
    }

    # One filtered residue list is shared by the sequence, the scores and the
    # report, so position i refers to the same residue everywhere. Residues
    # not in aa_dict (waters, ligands, other HETATMs) are excluded.
    residues = [res for res in chain if res.get_resname().strip() in aa_dict]
    if not residues:
        return "No amino acid residues found in chain", None, None
    sequence = "".join(aa_dict[res.get_resname().strip()] for res in residues)

    # The tokenizer receives the sequence with residues separated by spaces.
    input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        outputs = model(input_ids).logits.detach().cpu().numpy().squeeze()
    # squeeze() also collapses the token axis for a one-token input;
    # restore a 2-D (tokens, classes) layout so the column indexing works.
    outputs = np.atleast_2d(outputs)

    # One score per token: sigmoid of (class-1 logit - class-0 logit).
    # Keep only the first len(residues) tokens so that any trailing special
    # token (presumably T5's </s>) cannot skew the min-max normalization.
    scores = expit(outputs[:, 1] - outputs[:, 0])[:len(residues)]
    normalized_scores = normalize_scores(scores)

    # (residue number, score) pairs consumed by the 3D viewer.
    residue_scores = [(res.id[1], score) for res, score in zip(residues, normalized_scores)]

    result_str = "\n".join(
        f"{res.get_resname()} {res.id[1]} {aa} {score:.2f}"
        for res, aa, score in zip(residues, sequence, normalized_scores)
    )

    # Save the predictions to a file
    prediction_file = f"{pdb_id}_predictions.txt"
    with open(prediction_file, "w") as f:
        f.write(result_str)

    return result_str, molecule(pdb_path, residue_scores, segment), prediction_file

def molecule(input_pdb, residue_scores=None, segment='A'):
    """Build an iframe that renders a PDB file with 3Dmol.js.

    Args:
        input_pdb: Path to a PDB file on disk.
        residue_scores: Optional iterable of ``(residue_number, score)`` pairs.
            Residues scoring above 0.75 are drawn as red sticks. Residues in
            (0.5, 0.75] are drawn as orange sticks.
        segment: Chain ID to display.

    Returns:
        An ``<iframe>`` HTML string whose ``srcdoc`` attribute holds a
        self-contained 3Dmol.js viewer page.
    """
    import html
    import json

    def _js_literal(value):
        # JSON text is a valid JavaScript literal. Escaping "</" keeps the
        # embedded text from terminating the enclosing <script> element.
        return json.dumps(value).replace("</", "<\\/")

    mol_js = _js_literal(read_mol(input_pdb))  # PDB file content as a JS string
    chain_js = _js_literal(segment)  # quoted and escaped chain ID

    # Prepare high-scoring residues script if scores are provided
    high_score_script = ""
    if residue_scores is not None:
        # Materialize once: the iterable is traversed twice below.
        residue_scores = list(residue_scores)
        # Bucket residues by score: > 0.75 -> red, (0.5, 0.75] -> orange.
        high_score_residues = [resi for resi, score in residue_scores if score > 0.75]
        mid_score_residues = [resi for resi, score in residue_scores if 0.5 < score <= 0.75]

        high_score_script = """
        // Reset all styles first
        viewer.getModel(0).setStyle({}, {});

        // Show only the selected chain
        viewer.getModel(0).setStyle(
            {"chain": %s},
            { cartoon: {colorscheme:"whiteCarbon"} }
        );

        // Highlight high-scoring residues only for the selected chain
        let highScoreResidues = [%s];
        viewer.getModel(0).setStyle(
            {"chain": %s, "resi": highScoreResidues},
            {"stick": {"color": "red"}}
        );

        // Highlight medium-scoring residues only for the selected chain
        let midScoreResidues = [%s];
        viewer.getModel(0).setStyle(
            {"chain": %s, "resi": midScoreResidues},
            {"stick": {"color": "orange"}}
        );
        """ % (chain_js,
               ", ".join(str(resi) for resi in high_score_residues),
               chain_js,
               ", ".join(str(resi) for resi in mid_score_residues),
               chain_js)

    html_content = f"""
    <!DOCTYPE html>
    <html>
    <head>
        <meta http-equiv="content-type" content="text/html; charset=UTF-8" />
        <style>
        .mol-container {{
            width: 100%;
            height: 700px;
            position: relative;
        }}
        </style>
        <script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js"></script>
        <script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
    </head>
    <body>
        <div id="container" class="mol-container"></div>
        <script>
            let pdb = {mol_js}; // PDB text embedded as a JSON string literal
            $(document).ready(function () {{
                let element = $("#container");
                let config = {{ backgroundColor: "white" }};
                let viewer = $3Dmol.createViewer(element, config);
                viewer.addModel(pdb, "pdb");

                // Draw the selected chain as a cartoon
                viewer.getModel(0).setStyle(
                    {{"chain": {chain_js}}},
                    {{ cartoon: {{ colorscheme:"whiteCarbon" }} }}
                );

                {high_score_script}

                // Add hover functionality
                viewer.setHoverable(
                    {{}},
                    true,
                    function(atom, viewer, event, container) {{
                        if (!atom.label) {{
                            atom.label = viewer.addLabel(
                                atom.resn + ":" +atom.resi + ":" + atom.atom,
                                {{
                                    position: atom,
                                    backgroundColor: 'mintcream',
                                    fontColor: 'black',
                                    fontSize: 12,
                                    padding: 2
                                }}
                            );
                        }}
                    }},
                    function(atom, viewer) {{
                        if (atom.label) {{
                            viewer.removeLabel(atom.label);
                            delete atom.label;
                        }}
                    }}
                );

                viewer.zoomTo();
                viewer.render();
                viewer.zoom(0.8, 2000);
            }});
        </script>
    </body>
    </html>
    """

    # srcdoc is an HTML attribute value: escape &, <, >, " and ' so that the
    # browser decodes it back to exactly html_content.
    return f'<iframe width="100%" height="700" srcdoc="{html.escape(html_content, quote=True)}"></iframe>'

# Initial representation list for the Molecule3D viewer: model 0 drawn as a
# "whiteCarbon" cartoon, with an empty residue_range, around=0 and byres=False.
reps = [
    dict(
        model=0,
        style="cartoon",
        color="whiteCarbon",
        residue_range="",
        around=0,
        byres=False,
    ),
]

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# Protein Binding Site Prediction")
    with gr.Row():
        pdb_input = gr.Textbox(value="2IWI", label="PDB ID", placeholder="Enter PDB ID here...")
        visualize_btn = gr.Button("Visualize Structure")

    # Interactive viewer for the raw structure, filled by "Visualize Structure".
    molecule_output2 = Molecule3D(label="Protein Structure", reps=reps)

    with gr.Row():
        segment_input = gr.Textbox(value="A", label="Chain ID", placeholder="Enter Chain ID here...")
        prediction_btn = gr.Button("Predict Binding Site Scores")

    # Outputs of process_pdb: highlighted structure, text report, downloadable file.
    molecule_output = gr.HTML(label="Protein Structure")
    predictions_output = gr.Textbox(label="Binding Site Predictions")
    download_output = gr.File(label="Download Predictions")

    # fetch_pdb returns the local .pdb path (or None); presumably Molecule3D
    # loads that file path directly.
    visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output2)

    prediction_btn.click(process_pdb, inputs=[pdb_input, segment_input], outputs=[predictions_output, molecule_output, download_output])

    gr.Markdown("## Examples")
    gr.Examples(
        examples=[
            ["7RPZ", "A"],
            ["2IWI", "B"],
            ["3TJN", "C"]
        ],
        inputs=[pdb_input, segment_input],
        outputs=[predictions_output, molecule_output, download_output]
    )

# Only start the server when run as a script, not when the module is imported.
if __name__ == "__main__":
    demo.launch(share=True)