aayush-shah committed on
Commit
0711b9c
1 Parent(s): 9b56526

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +225 -0
app.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import transformers
4
+ from transformers import GenerationConfig, pipeline, AutoTokenizer, AutoModelForCausalLM, EsmForProteinFolding
5
+ import os
6
+ import tempfile
7
+ import subprocess
8
+ import pandas as pd
9
+ import numpy as np
10
+ import gradio as gr
11
+ from time import time
12
+
13
model_id = "Esperanto/Protein-Llama-3-8B"

# Device that the ESMFold model is moved to below. The rest of the file
# refers to this name, so it must exist before the first `.to(device)` call.
# Prefer CUDA when torch can see a GPU, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the fine-tuned LLaMA 3 model in half precision.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

# Load the tokenizer that belongs to the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Reuse the end-of-sequence token for padding and pad on the left.
# NOTE(review): presumably the checkpoint ships without a pad token - confirm.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

# Text-generation pipeline used by generate_protein_sequence().
generator = pipeline('text-generation', model=model, tokenizer=tokenizer)


# Load the ESMFold model and its tokenizer (used for structure prediction).
esm_model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1")
esm_tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")

esm_model.to(device)
36
+
37
def clean_protein_sequence(protein_seq):
    """Strip everything that is not a standard amino-acid letter.

    Only the upper-case one-letter codes ``ACDEFGHIKLMNPQRSTVWY`` are kept;
    all other characters (lower-case letters, digits, punctuation, ...) are
    dropped. The relative order of the kept characters is preserved.
    """
    allowed_residues = "ACDEFGHIKLMNPQRSTVWY"
    kept = filter(lambda residue: residue in allowed_residues, protein_seq)
    return ''.join(kept)
46
+
47
def modify_b_factors(pdb_content, multiplier):
    """Scale the B-factor field of every ``ATOM`` record in a PDB string.

    The caller uses this to turn pLDDT values into percentages. For each line
    starting with ``ATOM`` the text in character positions 60-65 is parsed as
    a float, multiplied by ``multiplier`` and written back right-aligned in a
    6-character field with two decimals. All other lines pass through
    unchanged. Lines are split and re-joined on ``"\\n"``.
    """
    def _rescale(record):
        # Non-ATOM records (headers, TER, END, blank lines, ...) are untouched.
        if not record.startswith("ATOM"):
            return record
        scaled = float(record[60:66].strip()) * multiplier
        return f"{record[:60]}{scaled:6.2f}{record[66:]}"

    return "\n".join(_rescale(record) for record in pdb_content.split('\n'))
59
+
60
def save_pdb(input_sequence):
    """Fold ``input_sequence`` with ESMFold and write the result to disk.

    The predicted structure is converted to PDB text, its B-factor column is
    multiplied by 100 (see ``modify_b_factors``) and the text is written to
    ``Protein-Llama-3-8B-Gradio/temporary_folder/protein.pdb``, creating the
    folder if needed. The file name is fixed, so each call overwrites the
    previous prediction.

    Returns the mean of the selected pLDDT values, scaled by 100 and rounded
    to two decimals before averaging.
    """
    inputs = esm_tokenizer([input_sequence], return_tensors="pt", add_special_tokens=False)
    # Move the inputs to whatever device the model itself lives on instead of
    # relying on a separately maintained global device name.
    inputs = inputs.to(esm_model.device)

    with torch.no_grad():
        outputs = esm_model(**inputs)

    pdb_string_unscaled = esm_model.output_to_pdb(outputs)[0]
    pdb_string = modify_b_factors(pdb_string_unscaled, 100)

    # NOTE(review): `[0][0]` keeps only the first entry of the first batch
    # item. If `outputs.plddt` is laid out per residue, this averages a single
    # residue rather than the whole chain - confirm against the ESMFold
    # output documentation.
    plddt_values = outputs.plddt.tolist()[0][0]
    plddt_values = [round(value * 100, 2) for value in plddt_values]

    file_path = os.path.join('Protein-Llama-3-8B-Gradio/temporary_folder', "protein.pdb")
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, "w") as f:
        f.write(pdb_string)

    return np.mean(plddt_values)
76
+
77
def read_prot(molpath):
    """Return the full text content of the PDB file at ``molpath``."""
    with open(molpath, "r") as fp:
        # A single read() yields the same text as concatenating readlines()
        # without building the string piece by piece.
        return fp.read()
85
+
86
+
87
def protein_visual_html(input_pdb):
    """Build an ``<iframe>`` snippet that renders a PDB file with 3Dmol.js.

    ``input_pdb`` is the path of a PDB file. Its text is embedded in a
    stand-alone HTML document (jQuery + 3Dmol.js loaded from their CDNs) that
    draws the model as a spectrum-coloured cartoon on a white background.
    That document is placed in the ``srcdoc`` attribute of the returned
    iframe markup.

    Returns the iframe HTML as a string.
    """
    # Local imports keep this function self-contained.
    import html
    import json

    mol = read_prot(input_pdb)

    # Embed the PDB text as a JSON string literal (valid JavaScript) so that
    # backticks, "${", backslashes or quotes inside the file cannot break the
    # script. "</" is additionally escaped so the text can never close the
    # surrounding <script> element early.
    pdb_literal = json.dumps(mol).replace("</", "<\\/")

    x = (
        """<!DOCTYPE html>
        <html>
        <head>
    <meta http-equiv="content-type" content="text/html; charset=UTF-8" />
    <style>
    body{
        font-family:sans-serif
    }
    .mol-container {
    width: 100%;
    height: 600px;
    position: relative;
    }
    .mol-container select{
        background-image:None;
    }
    </style>
    <script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js" integrity="sha512-STof4xm1wgkfm7heWqFJVn58Hm3EtS31XFaagaa8VMReCXAkQnJZ+jEy8PCC/iT18dFy95WcExNHFTqLyp72eQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
    <script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
    </head>
    <body>
    <div id="container" class="mol-container"></div>

            <script>
               let pdb = """ + pdb_literal + """;

             $(document).ready(function () {
                let element = $("#container");
                let config = { backgroundColor: "white" };
                let viewer = $3Dmol.createViewer(element, config);
                viewer.addModel(pdb, "pdb");
                viewer.getModel(0).setStyle({}, { cartoon: { color:"spectrum" } });
                viewer.zoomTo();
                viewer.render();
                viewer.zoom(0.8, 2000);
              })
        </script>
        </body></html>"""
    )

    # The document is placed inside a quoted attribute, so it must be
    # HTML-escaped; the browser decodes the entities again when it parses
    # the srcdoc value.
    escaped_doc = html.escape(x, quote=True)

    return f"""<iframe style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
    display-capture; encrypted-media;" sandbox="allow-modals allow-forms
    allow-scripts allow-same-origin allow-popups
    allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
    allowpaymentrequest="" frameborder="0" srcdoc='{escaped_doc}'></iframe>"""
137
+
138
+
139
def predict_structure(input_sequence):
    """Return the HTML visualisation for ``input_sequence``.

    * Empty or missing input: a short notice is returned and no model is run.
    * The hard-coded SARS-CoV-2 example sequence: the bundled structure file
      ``Protein-Llama-3-8B-Gradio/sars_cov_2_6vxx.pdb`` is shown directly, so
      the demo responds instantly.
    * Anything else: the sequence is folded via ``save_pdb`` and the PDB file
      it writes to the temporary folder is visualised.
    """
    # Nothing to fold yet (e.g. "Visualize" pressed before generating).
    if not input_sequence:
        return "<p>No protein sequence to visualize. Generate a sequence first.</p>"

    # Hard-coded SARS-CoV-2 protein sequence used for the instant demo.
    sars_cov_2_example = 'SNASADAQSFLNRVCGVSAARLTPCGTGTSTDVVYRAFDIYNDKVAGFAKFLKTNCCRFQEKDEDDNLIDSYFVVKRHTFSNYQHEETIYNLLKDCPAVAKHDFFKFRIDGDMVPHISRQRLTKYTMADLVYALRHFDEGNCDTLKEILVTYNCCDDDYFNKKDWYDFVENPDILRVYANLGERVRQALLKTVQFCDAMRNAGIVGVLTLDNQDLNGNWYDFGDFIQTTPGSGVPVVDSYYSLLMPILTLTRALTAESHVDTDLTKPYIKWDLLKYDFTEERLKLFDRYFKYWDQTYHPNCVNCLDDRCILHCANFNVLFSTVFPPTSFGPLVRKIFVDGVPFVVSTGYHFRELGVVHNQDVNLHSSRLSFKELLVYAADPAMHAASGNLLLDKRTTCFSVAALTNNVAFQTVKPGNFNKDFYDFAVSKGFFKEGSSVELKHFFFAQDGNAAISDYDYYRYNLPTMCDIRQLLFVVEVVDKYFDCYDGGCINANQVI'
    if input_sequence == sars_cov_2_example:
        return protein_visual_html('Protein-Llama-3-8B-Gradio/sars_cov_2_6vxx.pdb')

    # save_pdb writes the prediction to the temporary folder; its return
    # value (mean pLDDT) is not needed here.
    save_pdb(input_sequence)

    # Create the HTML visualisation for the PDB file stored in the temporary folder.
    pdb_path = os.path.join('Protein-Llama-3-8B-Gradio/temporary_folder', "protein.pdb")
    return protein_visual_html(pdb_path)
148
+
149
def generate_protein_sequence(sequence, seq_length, property=''):
    """Generate a protein sequence that continues ``sequence``.

    Parameters
    ----------
    sequence:
        Starting amino acids typed by the user (may be empty).
    seq_length:
        Maximum number of new tokens to generate.
    property:
        Optional class selected in the UI. Empty / ``None`` means
        unconditional generation. The special value
        ``'SARS-CoV-2 Spike Protein (example)'`` returns a hard-coded
        sequence without running the model. (The name shadows the builtin
        but is kept because it is part of the existing signature.)

    Returns
    -------
    A 3-tuple ``(cleaned_sequence, inference_seconds, tokens_per_second)``,
    matching the three Gradio outputs wired to this function.
    """
    enzymes = ["Non-Hemolytic", "Soluble", "Oxidoreductase", "Transferase", "Hydrolase", "Lyase", "Isomerase", "Ligase", "Translocase"]
    start_time = time()

    # Demo shortcut: return the canned sequence. The tuple has exactly three
    # entries so it lines up with the three output components.
    if property == 'SARS-CoV-2 Spike Protein (example)':
        cleaned_seq = 'SNASADAQSFLNRVCGVSAARLTPCGTGTSTDVVYRAFDIYNDKVAGFAKFLKTNCCRFQEKDEDDNLIDSYFVVKRHTFSNYQHEETIYNLLKDCPAVAKHDFFKFRIDGDMVPHISRQRLTKYTMADLVYALRHFDEGNCDTLKEILVTYNCCDDDYFNKKDWYDFVENPDILRVYANLGERVRQALLKTVQFCDAMRNAGIVGVLTLDNQDLNGNWYDFGDFIQTTPGSGVPVVDSYYSLLMPILTLTRALTAESHVDTDLTKPYIKWDLLKYDFTEERLKLFDRYFKYWDQTYHPNCVNCLDDRCILHCANFNVLFSTVFPPTSFGPLVRKIFVDGVPFVVSTGYHFRELGVVHNQDVNLHSSRLSFKELLVYAADPAMHAASGNLLLDKRTTCFSVAALTNNVAFQTVKPGNFNKDFYDFAVSKGFFKEGSSVELKHFFFAQDGNAAISDYDYYRYNLPTMCDIRQLLFVVEVVDKYFDCYDGGCINANQVI'
        end_time = time()
        return cleaned_seq, end_time - start_time, 0

    sequence = sequence or ''
    marker = 'Seq=<'

    # Treat both None (nothing selected) and '' (the default) as "no class",
    # so no "[Generate  protein]" prefix with an empty class is produced.
    if not property:
        input_prompt = marker + sequence
    elif property in enzymes:
        # These class names are lower-cased in the prompt.
        input_prompt = '[Generate ' + property.lower() + ' protein] ' + marker + sequence
    else:
        input_prompt = '[Generate ' + property + ' protein] ' + marker + sequence

    start_time = time()
    protein_seq = generator(
        input_prompt,
        temperature=0.5,
        top_k=40,
        top_p=0.9,
        do_sample=True,
        repetition_penalty=1.2,
        # Sliders may deliver a float; generation expects an integer count.
        max_new_tokens=int(seq_length),
        num_return_sequences=1,
    )[0]["generated_text"]
    end_time = time()
    elapsed = end_time - start_time

    # Keep only the text between "Seq=<" and the closing ">". Start after the
    # marker itself: otherwise its capital "S" survives cleaning and shows up
    # as a spurious leading serine.
    start_idx = protein_seq.find(marker)
    start_idx = start_idx + len(marker) if start_idx != -1 else 0
    end_idx = protein_seq.find('>', start_idx)
    if end_idx == -1:
        # No terminator generated: keep everything up to the end of the text
        # instead of silently dropping the last character.
        end_idx = len(protein_seq)
    protein_seq = protein_seq[start_idx:end_idx]

    cleaned_seq = clean_protein_sequence(protein_seq)
    tokens = tokenizer.encode(cleaned_seq, add_special_tokens=False)
    # Guard against a zero-length interval from a coarse clock.
    tokens_per_second = len(tokens) / elapsed if elapsed > 0 else 0.0

    return cleaned_seq, elapsed, tokens_per_second
186
+
187
+
188
+
189
# Create the Gradio interface

# Choices offered in the "Class" dropdown; the first entry is the canned
# SARS-CoV-2 example handled specially by the callbacks.
classes = [
    "SARS-CoV-2 Spike Protein (example)",
    'Tetratricopeptide-like helical domain superfamily',
    'CheY-like superfamily',
    'S-adenosyl-L-methionine-dependent methyltransferase superfamily',
    'Thioredoxin-like superfamily',
    "Non-Hemolytic",
    "Soluble",
    "Oxidoreductase",
    "Transferase",
    "Hydrolase",
    "Lyase",
    "Isomerase",
    "Ligase",
    "Translocase",
]

with gr.Blocks() as demo:
    gr.Markdown(value="Interactive protein sequence generation and visualization")

    # Generation inputs: seed residues, length and optional class.
    with gr.Row():
        input_text = gr.Textbox(
            label="Enter starting amino acids for protein sequence generation",
            placeholder="Example input: MK",
        )

    with gr.Row():
        seq_length = gr.Slider(
            minimum=2,
            maximum=200,
            value=30,
            step=1,
            label="Length",
            info="Choose the number of tokens to generate",
        )
        protein_property = gr.Dropdown(choices=classes, label="Class")

    with gr.Row():
        btn = gr.Button(value="Submit")

    # Generation outputs: the sequence plus timing statistics.
    with gr.Row():
        output_text = gr.Textbox(label="Generated protein sequence will appear here")

    with gr.Row():
        infer_time = gr.Number(label="Inference Time (s)", precision=2)
        tokens_per_sec = gr.Number(label="Tokens/sec", precision=2)

    # Structure visualisation of whatever is in the output textbox.
    with gr.Row():
        btn_vis = gr.Button(value="Visualize")

    with gr.Row():
        structure_visual = gr.HTML()

    btn.click(
        fn=generate_protein_sequence,
        inputs=[input_text, seq_length, protein_property],
        outputs=[output_text, infer_time, tokens_per_sec],
    )

    btn_vis.click(
        fn=predict_structure,
        inputs=output_text,
        outputs=[structure_visual],
    )


# Run the Gradio interface
demo.launch()