File size: 3,135 Bytes
6963cf4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e29637
 
8bd6bbb
 
11bcc1a
a28eeb5
11bcc1a
e0408f3
6963cf4
8bd6bbb
 
 
4499595
a6b7cf0
 
 
aae512c
 
 
6963cf4
4499595
1705a41
4499595
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a28eeb5
 
4499595
a28eeb5
4499595
a28eeb5
4499595
 
 
6963cf4
4499595
 
 
 
 
a6b7cf0
4499595
 
a6b7cf0
4499595
8bd6bbb
4499595
01ff8b6
 
8bd6bbb
01ff8b6
4499595
 
 
 
 
 
 
 
 
 
 
 
 
 
01ff8b6
 
4499595
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import gradio as gr
from model_loader import load_model

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import re
import numpy as np
import os
import pandas as pd
import copy

import transformers, datasets
from transformers import AutoTokenizer
from transformers import DataCollatorForTokenClassification

from datasets import Dataset

from scipy.special import expit

import requests

from gradio_molecule3d import Molecule3D

# Biopython imports
from Bio.PDB import PDBParser, Select, PDBIO
from Bio.PDB.DSSP import DSSP
from Bio.PDB import PDBList

from matplotlib import cm  # For color mapping
from matplotlib.colors import Normalize

# Model setup: choose the compute device, load the checkpoint and tokenizer,
# move the model onto the device and switch it to inference mode.
checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
max_length = 1500  # forwarded to load_model; how it is used is defined there

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

model, tokenizer = load_model(checkpoint, max_length)
model.to(device)
model.eval()

# Default 3D representation for the Molecule3D viewer: cartoon, spectrum colored.
reps = [dict(model=0, style="cartoon", color="spectrum")]

# Function to fetch a PDB file
def fetch_pdb(pdb_id, timeout=30):
    """Download a PDB-format structure from RCSB into ``pdb_files/``.

    Args:
        pdb_id: PDB identifier as typed by the user. Surrounding whitespace is
            ignored. Only letters, digits and underscores are accepted.
        timeout: Seconds to wait for the RCSB server before giving up.

    Returns:
        The local path ``pdb_files/<pdb_id>.pdb`` on success. Returns ``None``
        if the id is malformed, the request fails, or RCSB does not answer
        with HTTP 200.
    """
    pdb_id = (pdb_id or "").strip()
    # The id is interpolated into both a URL and a local file path, so reject
    # anything that could escape the directory or alter the URL ("../", "/", "?").
    if not re.fullmatch(r"[A-Za-z0-9_]+", pdb_id):
        return None

    pdb_url = f'https://files.rcsb.org/download/{pdb_id}.pdb'
    pdb_path = f'pdb_files/{pdb_id}.pdb'

    try:
        # Without a timeout, a stalled server would block this worker forever.
        response = requests.get(pdb_url, timeout=timeout)
    except requests.RequestException:
        # Report network failures the same way as a missing entry: callers
        # already treat None as "could not fetch".
        return None
    if response.status_code != 200:
        return None

    os.makedirs('pdb_files', exist_ok=True)
    with open(pdb_path, 'wb') as f:
        f.write(response.content)
    return pdb_path

# Three-letter -> one-letter codes for the 20 standard amino acids. Residue
# names not in this table are mapped to "X" (unknown).
_THREE_TO_ONE = {
    "ALA": "A", "ARG": "R", "ASN": "N", "ASP": "D", "CYS": "C",
    "GLN": "Q", "GLU": "E", "GLY": "G", "HIS": "H", "ILE": "I",
    "LEU": "L", "LYS": "K", "MET": "M", "PHE": "F", "PRO": "P",
    "SER": "S", "THR": "T", "TRP": "W", "TYR": "Y", "VAL": "V",
}


def _to_one_letter(resname):
    """Return the one-letter code for a three-letter residue name ("X" if unknown)."""
    return _THREE_TO_ONE.get(resname.strip().upper(), "X")


# Extract sequence and predict binding scores
def process_pdb(pdb_id, segment):
    """Predict per-residue binding scores for one chain of a PDB entry.

    Args:
        pdb_id: PDB identifier; it is passed to ``fetch_pdb`` to obtain the file.
        segment: Chain identifier inside the first model of the structure.

    Returns:
        A 3-tuple ``(text, pdb_path, predictions_path)``:
        - ``text`` has one line per residue, "<RESNAME> <resseq> <aa> <score>",
          or holds an error message.
        - ``pdb_path`` is None only if the fetch failed.
        - ``predictions_path`` is the file the text was written to, or None
          on error.
    """
    pdb_id = (pdb_id or "").strip()
    pdb_path = fetch_pdb(pdb_id)
    if not pdb_path:
        return "Failed to fetch PDB file", None, None

    parser = PDBParser(QUIET=1)
    structure = parser.get_structure('protein', pdb_path)
    first_model = structure[0]

    chain_id = (segment or "").strip()
    if chain_id not in first_model:
        # Still return the structure so it can be displayed.
        return f"Chain '{chain_id}' not found in {pdb_id}", pdb_path, None

    # Keep only standard polymer residues (hetero flag ' '). Waters ('W') and
    # ligands / modified residues ('H_...') are not part of the protein sequence.
    residues = [res for res in first_model[chain_id] if res.id[0] == ' ']
    if not residues:
        return f"No amino-acid residues found in chain '{chain_id}'", pdb_path, None

    # The model input must be one-letter amino-acid codes, one per residue.
    # Concatenating raw three-letter names would make every residue 3 tokens.
    sequence = "".join(_to_one_letter(res.get_resname()) for res in residues)

    # Residues are space-separated so each one becomes its own token
    # (assumes a per-residue tokenizer as used by ProtT5 - confirm with load_model).
    input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        # Take batch element 0 instead of squeeze(), so a 1-residue sequence
        # keeps its sequence axis. Assumes logits are (batch, tokens, classes).
        logits = model(input_ids).logits.detach().cpu().numpy()[0]

    # Binding score = logit(class 1) - logit(class 0). The tokenizer may append
    # special tokens (e.g. an end-of-sequence token), so there can be more rows
    # than residues; zip() stops at the shortest sequence.
    scores = logits[:, 1] - logits[:, 0]
    result_str = "\n".join(
        f"{res.get_resname()} {res.id[1]} {aa} {score:.2f}"
        for res, aa, score in zip(residues, sequence, scores)
    )

    predictions_path = f"{pdb_id}_predictions.txt"
    with open(predictions_path, "w", encoding="utf-8") as f:
        f.write(result_str)

    return result_str, pdb_path, predictions_path

# Gradio UI: an input row, a 3D structure viewer, a prediction text box and a
# downloadable predictions file.
with gr.Blocks() as demo:
    gr.Markdown(value="# Protein Binding Site Prediction")

    # Input controls laid out side by side.
    with gr.Row():
        pdb_input = gr.Textbox(label="PDB ID")
        segment_input = gr.Textbox(label="Segment (Chain ID)")
        visualize_btn = gr.Button(value="Visualize")
        prediction_btn = gr.Button(value="Predict")

    molecule_output = Molecule3D(label="Protein Structure", reps=reps)
    predictions_output = gr.Textbox(label="Binding Site Predictions")
    download_output = gr.File(label="Download Predictions")

    # "Visualize" only downloads the structure. fetch_pdb's return value (a
    # file path or None) goes straight to the 3D viewer.
    visualize_btn.click(fn=fetch_pdb, inputs=[pdb_input], outputs=molecule_output)

    # "Predict" fans process_pdb's 3-tuple out to the text box, viewer and file slot.
    prediction_btn.click(
        fn=process_pdb,
        inputs=[pdb_input, segment_input],
        outputs=[predictions_output, molecule_output, download_output],
    )

# share=True requests a public share link in addition to the local server.
demo.launch(share=True)