File size: 4,569 Bytes
ce6b085
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e29637
 
8bd6bbb
 
11bcc1a
a28eeb5
11bcc1a
e0408f3
ce6b085
8bd6bbb
 
 
4499595
a6b7cf0
 
 
aae512c
 
 
ce6b085
4499595
 
 
 
 
 
 
 
 
 
 
 
1f960e0
 
 
 
 
 
4499595
 
 
 
 
a28eeb5
 
4499595
a28eeb5
8bef2d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a28eeb5
8bef2d8
4499595
 
 
ce6b085
8bef2d8
1f960e0
 
 
8bef2d8
4499595
1f960e0
8bef2d8
4499595
a6b7cf0
8bef2d8
4499595
 
a6b7cf0
4499595
8bd6bbb
8bef2d8
 
4499595
01ff8b6
 
8bd6bbb
01ff8b6
1f960e0
 
 
 
 
 
 
 
4499595
 
 
 
 
 
 
 
 
 
01ff8b6
 
1f960e0
 
 
 
 
 
 
 
 
 
 
4499595
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import gradio as gr
from model_loader import load_model

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import re
import numpy as np
import os
import pandas as pd
import copy

import transformers, datasets
from transformers import AutoTokenizer
from transformers import DataCollatorForTokenClassification

from datasets import Dataset

from scipy.special import expit

import requests

from gradio_molecule3d import Molecule3D

# Biopython imports
from Bio.PDB import PDBParser, Select, PDBIO
from Bio.PDB.DSSP import DSSP
from Bio.PDB import PDBList

from matplotlib import cm  # For color mapping
from matplotlib.colors import Normalize

# Load model and move to device.
# `load_model` comes from the local `model_loader` module and returns the model
# together with its tokenizer for the given checkpoint name.
# NOTE(review): `max_length` is presumably the maximum sequence length the
# loader configures — confirm in model_loader.
checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
max_length = 1500
model, tokenizer = load_model(checkpoint, max_length)
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# Put the model in evaluation mode; below it is only called under torch.no_grad().
model.eval()

# Function to fetch a PDB file
def fetch_pdb(pdb_id):
    """Download a PDB entry from RCSB and store it under ``pdb_files/``.

    Args:
        pdb_id: PDB identifier as typed by the user (e.g. ``"2IWI"``).
            Surrounding whitespace is ignored.

    Returns:
        The local path ``pdb_files/<pdb_id>.pdb`` on success, or ``None`` when
        the identifier is malformed, the request fails or times out, or the
        server does not answer with HTTP 200.
    """
    if not isinstance(pdb_id, str):
        return None
    pdb_id = pdb_id.strip()
    # The ID is interpolated into both a URL and a local file path, so accept
    # only plain identifier characters (rejects "../"-style input and "").
    if not re.fullmatch(r'[A-Za-z0-9_]+', pdb_id):
        return None

    pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'
    pdb_path = f'pdb_files/{pdb_id}.pdb'
    try:
        # A timeout keeps a stalled connection from hanging the UI callback.
        response = requests.get(pdb_url, timeout=30)
    except requests.RequestException:
        return None
    if response.status_code != 200:
        return None

    os.makedirs('pdb_files', exist_ok=True)
    with open(pdb_path, 'wb') as f:
        f.write(response.content)
    return pdb_path


def normalize_scores(scores):
    """Min-max rescale *scores* so the smallest maps to 0 and the largest to 1.

    When the maximum is not strictly greater than the minimum (e.g. all values
    are equal) the input is returned unchanged.
    """
    lowest = np.min(scores)
    highest = np.max(scores)
    if highest > lowest:
        return (scores - lowest) / (highest - lowest)
    return scores

# Three-letter residue names (the 20 standard amino acids plus a few common
# modified residues) mapped to one-letter codes. Residues with any other name
# (waters, ligands, ...) are left out of the sequence.
_AA_THREE_TO_ONE = {
    'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',
    'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',
    'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',
    'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',
    'MSE': 'M', 'SEP': 'S', 'TPO': 'T', 'CSO': 'C', 'PTR': 'Y', 'HYP': 'P',
}


def process_pdb(pdb_id, segment):
    """Predict per-residue binding scores for one chain of a PDB entry.

    Fetches the PDB file, extracts the amino-acid sequence of chain *segment*
    from the first model, runs the module-level model on it and writes the
    normalized scores to ``<pdb_id>_predictions.txt``.

    Args:
        pdb_id: PDB identifier, passed to ``fetch_pdb``.
        segment: Chain identifier inside the first model of the structure.

    Returns:
        A 3-tuple ``(text, pdb_path, predictions_path)``. On success ``text``
        holds one line per residue: residue name, residue number, one-letter
        code and normalized score. On failure ``text`` is an error message and
        ``predictions_path`` is ``None`` (``pdb_path`` is also ``None`` when
        the download itself failed).
    """
    pdb_path = fetch_pdb(pdb_id)
    if not pdb_path:
        return "Failed to fetch PDB file", None, None

    parser = PDBParser(QUIET=1)
    structure = parser.get_structure('protein', pdb_path)
    try:
        chain = structure[0][segment]
    except KeyError:
        # Unknown chain ID (or a structure without models): report it to the
        # user instead of surfacing a raw KeyError from the callback.
        return f"Chain '{segment}' not found in {pdb_id}", pdb_path, None

    # Collect residues and their one-letter codes together so the output lines
    # stay aligned with the model input even when the chain contains residues
    # that are skipped.
    residues = []
    letters = []
    for residue in chain:
        letter = _AA_THREE_TO_ONE.get(residue.get_resname().strip())
        if letter is not None:
            residues.append(residue)
            letters.append(letter)

    if not letters:
        return (
            f"No amino acid residues found in chain '{segment}' of {pdb_id}",
            pdb_path,
            None,
        )
    sequence = "".join(letters)

    # Prepare input for model prediction (residues separated by spaces)
    input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        outputs = model(input_ids).logits.detach().cpu().numpy().squeeze()

    # Calculate scores and normalize them
    scores = expit(outputs[:, 1] - outputs[:, 0])
    normalized_scores = normalize_scores(scores)
    # NOTE(review): if the tokenizer appends special tokens (e.g. an EOS
    # token), `outputs` has more rows than residues. zip() below pairs by
    # position and ignores trailing extras, but the normalization above still
    # includes them — confirm whether they should be trimmed first.

    # Prepare the result string, including only amino acid residues
    result_str = "\n".join(
        f"{res.get_resname()} {res.id[1]} {letter} {score:.2f}"
        for res, letter, score in zip(residues, sequence, normalized_scores)
    )

    # Save predictions to file
    predictions_path = f"{pdb_id}_predictions.txt"
    with open(predictions_path, "w") as f:
        f.write(result_str)

    return result_str, pdb_path, predictions_path

# Representation list handed to the Molecule3D viewer below via `reps=`:
# a single entry for model 0 with style "cartoon" and color "spectrum".
reps = [{"model": 0, "style": "cartoon", "color": "spectrum"}]

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# Protein Binding Site Prediction")

    # Input row: PDB ID, chain ID and the two action buttons.
    with gr.Row():
        pdb_input = gr.Textbox(value="2IWI",
                label="PDB ID",
                placeholder="Enter PDB ID here...")
        segment_input = gr.Textbox(value="A",
                label="Chain ID (Segment)",
                placeholder="Enter Chain ID here...")
        visualize_btn = gr.Button("Visualize Structure")
        prediction_btn = gr.Button("Predict Ligand Binding Site")

    # Outputs: 3D structure viewer, prediction text and downloadable file.
    molecule_output = Molecule3D(label="Protein Structure", reps=reps)
    predictions_output = gr.Textbox(label="Binding Site Predictions")
    download_output = gr.File(label="Download Predictions")

    # "Visualize" only downloads the structure and shows it in the viewer.
    visualize_btn.click(fetch_pdb, inputs=[pdb_input], outputs=molecule_output)
    # "Predict" runs the model and fills all three outputs.
    prediction_btn.click(
        process_pdb, 
        inputs=[pdb_input, segment_input], 
        outputs=[predictions_output, molecule_output, download_output]
    )

    gr.Markdown("## Examples")
    # Each example row supplies one value per input component (PDB ID, chain).
    gr.Examples(
        examples=[
            ["2IWI", "A"],
            ["7RPZ", "A"],
            ["3TJN", "A"]
        ],
        inputs=[pdb_input, segment_input], 
        outputs=[predictions_output, molecule_output, download_output]
    )

demo.launch(share=True)