aayush-shah
commited on
Commit
•
0711b9c
1
Parent(s):
9b56526
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import transformers
|
4 |
+
from transformers import GenerationConfig, pipeline, AutoTokenizer, AutoModelForCausalLM, EsmForProteinFolding
|
5 |
+
import os
|
6 |
+
import tempfile
|
7 |
+
import subprocess
|
8 |
+
import pandas as pd
|
9 |
+
import numpy as np
|
10 |
+
import gradio as gr
|
11 |
+
from time import time
|
12 |
+
|
13 |
+
model_id = "Esperanto/Protein-Llama-3-8B"

# Device used for the ESMFold model. The original script referenced `device`
# without ever defining it, which raised NameError at start-up.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the fine-tuned LLaMA 3 model in half precision.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

# Load the matching tokenizer. The EOS token is reused as the padding token
# and padding is applied on the left.
tokenizer = AutoTokenizer.from_pretrained(model_id)

tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

# Text-generation pipeline used by generate_protein_sequence().
generator = pipeline('text-generation', model=model, tokenizer=tokenizer)


# Load the ESMFold model and its tokenizer (used for structure prediction).
esm_model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1")
esm_tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")

esm_model.to(device)
|
36 |
+
|
37 |
+
#Ensures that final output contains only valid amino acids
|
38 |
+
def clean_protein_sequence(protein_seq):
    """Return *protein_seq* with every character that is not one of the
    20 upper-case amino-acid letters removed.

    Characters are kept in their original order; lower-case letters,
    digits, punctuation and whitespace are all dropped.
    """
    allowed_residues = "ACDEFGHIKLMNPQRSTVWY"
    kept = filter(lambda residue: residue in allowed_residues, protein_seq)
    return ''.join(kept)
|
46 |
+
|
47 |
+
#convert pLDDT to percentage
|
48 |
+
def modify_b_factors(pdb_content, multiplier):
    """Scale the B-factor field of every ATOM record in a PDB string.

    Args:
        pdb_content: PDB file contents as a single newline-separated string.
        multiplier: Factor applied to the value in characters 60-66 of each
            line that starts with "ATOM".

    Returns:
        The PDB text with the scaled values written back as ``%6.2f``;
        all other lines are passed through unchanged.
    """
    def rescale(record):
        # Only ATOM records are rewritten.
        if not record.startswith("ATOM"):
            return record
        scaled = float(record[60:66].strip()) * multiplier
        return f"{record[:60]}{scaled:6.2f}{record[66:]}"

    return "\n".join(map(rescale, pdb_content.split('\n')))
|
59 |
+
|
60 |
+
#saves the structure output from ESMFold as a PDB file in a temporary folder
|
61 |
+
def save_pdb(input_sequence):
    """Fold *input_sequence* with ESMFold and write the result as a PDB file.

    The structure is written to
    ``Protein-Llama-3-8B-Gradio/temporary_folder/protein.pdb`` (the folder is
    created if needed), with B-factors multiplied by 100.

    Args:
        input_sequence: Amino-acid sequence string to fold.

    Returns:
        The mean of the pLDDT values extracted below, scaled by 100.
    """
    inputs = esm_tokenizer([input_sequence], return_tensors="pt", add_special_tokens=False)
    # Put the inputs on whatever device the model lives on. The original code
    # used a global `device` that was never defined (NameError).
    inputs = inputs.to(esm_model.device)
    with torch.no_grad():
        outputs = esm_model(**inputs)

    pdb_string_unscaled = esm_model.output_to_pdb(outputs)[0]
    pdb_string = modify_b_factors(pdb_string_unscaled, 100)

    # NOTE(review): [0][0] selects the first entry of the first batch item.
    # If plddt is laid out as (batch, residue, atom) this is only the first
    # residue's values rather than the whole chain - confirm against the
    # EsmForProteinFolding output documentation.
    plddt_values = outputs.plddt.tolist()[0][0]
    plddt_values = [round(value * 100, 2) for value in plddt_values]

    # NOTE(review): fixed output path, so concurrent requests overwrite each
    # other's file.
    file_path = os.path.join('Protein-Llama-3-8B-Gradio/temporary_folder', "protein.pdb")
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, "w") as f:
        f.write(pdb_string)

    return np.mean(plddt_values)
|
76 |
+
|
77 |
+
#reads the PDB file
|
78 |
+
def read_prot(molpath):
    """Return the entire text content of the file at *molpath*."""
    with open(molpath, "r") as handle:
        return handle.read()
|
85 |
+
|
86 |
+
|
87 |
+
def protein_visual_html(input_pdb):
    """Build an ``<iframe>`` snippet that renders a PDB file with 3Dmol.js.

    Args:
        input_pdb: Path of the PDB file to display.

    Returns:
        An HTML string containing an iframe whose ``srcdoc`` is a
        self-contained page that loads jQuery and 3Dmol.js and shows the
        structure as a spectrum-coloured cartoon.
    """
    # Function-scope imports keep this block self-contained.
    import html
    import json

    mol = read_prot(input_pdb)

    # Embed the PDB text as a JSON string literal (valid JavaScript) instead
    # of splicing it raw into a template literal, so backticks, "${" or
    # backslashes in the file cannot break the script. "</" is escaped so the
    # data can never close the surrounding <script> element.
    pdb_literal = json.dumps(mol).replace("</", "<\\/")

    x = (
        """<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
<style>
body{
font-family:sans-serif
}
.mol-container {
width: 100%;
height: 600px;
position: relative;
}
.mol-container select{
background-image:None;
}
</style>
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js" integrity="sha512-STof4xm1wgkfm7heWqFJVn58Hm3EtS31XFaagaa8VMReCXAkQnJZ+jEy8PCC/iT18dFy95WcExNHFTqLyp72eQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
<script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
</head>
<body>
<div id="container" class="mol-container"></div>

<script>
let pdb = """ + pdb_literal + """;

$(document).ready(function () {
let element = $("#container");
let config = { backgroundColor: "white" };
let viewer = $3Dmol.createViewer(element, config);
viewer.addModel(pdb, "pdb");
viewer.getModel(0).setStyle({}, { cartoon: { color:"spectrum" } });
viewer.zoomTo();
viewer.render();
viewer.zoom(0.8, 2000);
})
</script>
</body></html>"""
    )

    # Escape the whole document for use as an attribute value; the browser
    # decodes the entities again when it parses srcdoc. Previously any quote
    # character in the PDB text terminated the attribute early.
    srcdoc = html.escape(x, quote=True)

    return f"""<iframe style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc="{srcdoc}"></iframe>"""
|
137 |
+
|
138 |
+
|
139 |
+
def predict_structure(input_sequence):
    """Return the HTML visualisation for *input_sequence*.

    Args:
        input_sequence: Amino-acid sequence taken from the output textbox.

    Returns:
        An HTML string: the 3D viewer iframe, or a short notice when no
        sequence has been provided yet.
    """
    # "Visualize" can be clicked before anything was generated; do not hand an
    # empty sequence to the folding model.
    if not input_sequence or not input_sequence.strip():
        return "<p>Generate or enter a protein sequence before visualizing.</p>"

    #Hard coding the SARS-CoV 2 protein sequence and structure for instant demo purposes
    if input_sequence == 'SNASADAQSFLNRVCGVSAARLTPCGTGTSTDVVYRAFDIYNDKVAGFAKFLKTNCCRFQEKDEDDNLIDSYFVVKRHTFSNYQHEETIYNLLKDCPAVAKHDFFKFRIDGDMVPHISRQRLTKYTMADLVYALRHFDEGNCDTLKEILVTYNCCDDDYFNKKDWYDFVENPDILRVYANLGERVRQALLKTVQFCDAMRNAGIVGVLTLDNQDLNGNWYDFGDFIQTTPGSGVPVVDSYYSLLMPILTLTRALTAESHVDTDLTKPYIKWDLLKYDFTEERLKLFDRYFKYWDQTYHPNCVNCLDDRCILHCANFNVLFSTVFPPTSFGPLVRKIFVDGVPFVVSTGYHFRELGVVHNQDVNLHSSRLSFKELLVYAADPAMHAASGNLLLDKRTTCFSVAALTNNVAFQTVKPGNFNKDFYDFAVSKGFFKEGSSVELKHFFFAQDGNAAISDYDYYRYNLPTMCDIRQLLFVVEVVDKYFDCYDGGCINANQVI':
        return protein_visual_html('Protein-Llama-3-8B-Gradio/sars_cov_2_6vxx.pdb')

    # save_pdb() is called for its side effect (it writes the PDB file); the
    # value it returns is not used here.
    save_pdb(input_sequence)
    #Creating HTML visualization for the PDB file stored in the temporary folder
    pdb_path = os.path.join('Protein-Llama-3-8B-Gradio/temporary_folder', "protein.pdb")
    return protein_visual_html(pdb_path)
|
148 |
+
|
149 |
+
def generate_protein_sequence(sequence, seq_length, property=''):
    """Generate a protein sequence that continues *sequence*.

    Args:
        sequence: Starting amino acids used to seed the prompt.
        seq_length: Maximum number of new tokens to generate.
        property: Optional class/property tag placed in the prompt. ``None``
            or an empty string means unconditional generation. The special
            value 'SARS-CoV-2 Spike Protein (example)' returns a hard-coded
            sequence without running the model.

    Returns:
        A 3-tuple ``(cleaned_sequence, elapsed_seconds, tokens_per_second)``,
        matching the three Gradio outputs bound to this callback.
    """
    # Property tags that are written in lower case inside the prompt.
    enzymes = ["Non-Hemolytic", "Soluble", "Oxidoreductase", "Transferase", "Hydrolase", "Lyase", "Isomerase", "Ligase", "Translocase"]
    start_time = time()

    if property == 'SARS-CoV-2 Spike Protein (example)':
        cleaned_seq = 'SNASADAQSFLNRVCGVSAARLTPCGTGTSTDVVYRAFDIYNDKVAGFAKFLKTNCCRFQEKDEDDNLIDSYFVVKRHTFSNYQHEETIYNLLKDCPAVAKHDFFKFRIDGDMVPHISRQRLTKYTMADLVYALRHFDEGNCDTLKEILVTYNCCDDDYFNKKDWYDFVENPDILRVYANLGERVRQALLKTVQFCDAMRNAGIVGVLTLDNQDLNGNWYDFGDFIQTTPGSGVPVVDSYYSLLMPILTLTRALTAESHVDTDLTKPYIKWDLLKYDFTEERLKLFDRYFKYWDQTYHPNCVNCLDDRCILHCANFNVLFSTVFPPTSFGPLVRKIFVDGVPFVVSTGYHFRELGVVHNQDVNLHSSRLSFKELLVYAADPAMHAASGNLLLDKRTTCFSVAALTNNVAFQTVKPGNFNKDFYDFAVSKGFFKEGSSVELKHFFFAQDGNAAISDYDYYRYNLPTMCDIRQLLFVVEVVDKYFDCYDGGCINANQVI'
        # Three values, one per bound output. The original returned four,
        # which does not match the outputs list of the click handler.
        return cleaned_seq, time() - start_time, 0

    if not property:
        # None (nothing selected in the dropdown) or '' (the default).
        input_prompt = 'Seq=<' + sequence
    elif property in enzymes:
        input_prompt = '[Generate ' + property.lower() + ' protein] ' + 'Seq=<' + sequence
    else:
        input_prompt = '[Generate ' + property + ' protein] ' + 'Seq=<' + sequence

    # Time only the model call.
    start_time = time()
    generated_text = generator(input_prompt, temperature=0.5,
                               top_k=40,
                               top_p=0.9,
                               do_sample=True,
                               repetition_penalty=1.2,
                               max_new_tokens=int(seq_length),
                               num_return_sequences=1)[0]["generated_text"]
    elapsed = time() - start_time

    # Take the text between 'Seq=<' and the next '>'. Start after the marker:
    # the 'S' of 'Seq' is itself a valid amino-acid letter and would otherwise
    # be prepended to every result.
    marker = 'Seq=<'
    marker_idx = generated_text.find(marker)
    body_start = marker_idx + len(marker) if marker_idx != -1 else 0
    end_idx = generated_text.find('>', body_start)
    if end_idx == -1:
        # No closing '>' was generated; keep everything up to the end instead
        # of silently dropping the last character.
        end_idx = len(generated_text)

    cleaned_seq = clean_protein_sequence(generated_text[body_start:end_idx])
    tokens = tokenizer.encode(cleaned_seq, add_special_tokens=False)
    tokens_per_second = len(tokens) / elapsed if elapsed > 0 else 0

    return cleaned_seq, elapsed, tokens_per_second
|
186 |
+
|
187 |
+
|
188 |
+
|
189 |
+
# Create the Gradio interface
|
190 |
+
|
191 |
+
# Entries offered in the "Class" dropdown.
classes = ["SARS-CoV-2 Spike Protein (example)", 'Tetratricopeptide-like helical domain superfamily', 'CheY-like superfamily', 'S-adenosyl-L-methionine-dependent methyltransferase superfamily', 'Thioredoxin-like superfamily', "Non-Hemolytic" ,"Soluble", "Oxidoreductase", "Transferase", "Hydrolase", "Lyase", "Isomerase", "Ligase", "Translocase"]

with gr.Blocks() as demo:
    gr.Markdown("Interactive protein sequence generation and visualization")

    # Prompt input.
    with gr.Row():
        input_text = gr.Textbox(
            label="Enter starting amino acids for protein sequence generation",
            placeholder="Example input: MK",
        )

    # Generation settings.
    with gr.Row():
        seq_length = gr.Slider(
            minimum=2,
            maximum=200,
            value=30,
            step=1,
            label="Length",
            info="Choose the number of tokens to generate",
        )
        protein_property = gr.Dropdown(choices=classes, label="Class")

    with gr.Row():
        btn = gr.Button("Submit")

    # Generation results.
    with gr.Row():
        output_text = gr.Textbox(label="Generated protein sequence will appear here")

    with gr.Row():
        infer_time = gr.Number(label="Inference Time (s)", precision=2)
        tokens_per_sec = gr.Number(label="Tokens/sec", precision=2)

    # Structure visualisation.
    with gr.Row():
        btn_vis = gr.Button("Visualize")

    with gr.Row():
        structure_visual = gr.HTML()

    # "Submit" runs generation; "Visualize" takes the generated sequence from
    # the output textbox and renders its structure.
    btn.click(
        fn=generate_protein_sequence,
        inputs=[input_text, seq_length, protein_property],
        outputs=[output_text, infer_time, tokens_per_sec],
    )

    btn_vis.click(fn=predict_structure, inputs=output_text, outputs=[structure_visual])


# Run the Gradio interface
demo.launch()
|