|
import os |
|
import torch |
|
import transformers |
|
from transformers import GenerationConfig, pipeline, AutoTokenizer, AutoModelForCausalLM, EsmForProteinFolding |
|
import os |
|
import tempfile |
|
import subprocess |
|
import pandas as pd |
|
import numpy as np |
|
import gradio as gr |
|
from time import time |
|
|
|
# Hugging Face Hub id of the protein-generation causal language model.
model_id = "Esperanto/Protein-Llama-3-8B"

# Tokenizer first: it is configured independently of the model weights.
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Reuse EOS as the padding token and pad on the left, so that any padding sits
# before the prompt and generation continues from the prompt's right edge.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

# Half-precision weights; low_cpu_mem_usage limits peak host RAM while loading.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

# Text-generation pipeline wrapping the already-loaded model and tokenizer.
generator = pipeline(task="text-generation", model=model, tokenizer=tokenizer)
|
|
|
|
|
|
|
# Device for ESMFold inference: the GPU when one is visible, otherwise the CPU.
# Defined at module level so every function that runs the folding model can
# use it (previously referenced here without ever being defined -> NameError).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ESMFold structure-prediction model and its matching tokenizer.
esm_model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1")
esm_tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")

esm_model.to(device)
|
|
|
|
|
def clean_protein_sequence(protein_seq):
    """Keep only the 20 standard amino-acid one-letter codes from *protein_seq*.

    Matching is case-sensitive: only uppercase letters survive. Lowercase
    letters, digits, punctuation, whitespace and non-standard codes (B, X, Z, ...)
    are dropped. The order of the kept characters is preserved.

    Args:
        protein_seq: Raw text that may contain a protein sequence.

    Returns:
        The filtered string (possibly empty).
    """
    standard_residues = "ACDEFGHIKLMNPQRSTVWY"
    kept = filter(standard_residues.__contains__, protein_seq)
    return "".join(kept)
|
|
|
|
|
def modify_b_factors(pdb_content, multiplier):
    """Scale the B-factor field of every ``ATOM`` record in a PDB text.

    The B-factor lives in the fixed-width slice ``[60:66]`` of a PDB line. It is
    parsed as a float, multiplied by *multiplier*, and written back with the
    ``6.2f`` format. Text before and after the field is untouched. Lines that do
    not start with ``ATOM`` are passed through unchanged.

    Args:
        pdb_content: PDB file contents as one newline-separated string.
        multiplier: Factor applied to each ATOM line's B-factor.

    Returns:
        The rewritten PDB text, with the same line structure as the input.

    Raises:
        ValueError: If an ATOM line's B-factor slice is not a number
            (including a line too short to contain it).
    """
    def _rescale(record):
        if not record.startswith("ATOM"):
            return record
        scaled = float(record[60:66].strip()) * multiplier
        return record[:60] + format(scaled, "6.2f") + record[66:]

    return "\n".join(_rescale(record) for record in pdb_content.split("\n"))
|
|
|
|
|
def save_pdb(input_sequence):
    """Fold *input_sequence* with ESMFold, write the PDB to disk, and return mean pLDDT.

    The structure is written to the fixed path
    ``Protein-Llama-3-8B-Gradio/temporary_folder/protein.pdb``, which is
    overwritten on every call. Before writing, the B-factors are multiplied
    by 100.

    Args:
        input_sequence: Protein sequence in one-letter amino-acid codes.

    Returns:
        Mean of the model's pLDDT values multiplied by 100. The mean is taken
        over all residues, and only over atom slots flagged as existing when
        the model output provides that mask.
    """
    inputs = esm_tokenizer([input_sequence], return_tensors="pt", add_special_tokens=False)
    # Move inputs to wherever the folding model actually lives instead of
    # depending on a separately maintained module-level device variable.
    inputs = inputs.to(esm_model.device)

    with torch.no_grad():
        outputs = esm_model(**inputs)

    pdb_string_unscaled = esm_model.output_to_pdb(outputs)[0]
    # NOTE(review): the x100 suggests ESMFold reports pLDDT in [0, 1]; confirm.
    pdb_string = modify_b_factors(pdb_string_unscaled, 100)

    # Presumably plddt is (batch, residue, atom37) -- TODO confirm. The
    # previous code took [0][0] and so averaged only the first residue.
    # Average over every residue of the first (only) batch item, and mask out
    # atom slots that do not exist for a residue when the mask is available.
    plddt = outputs.plddt[0].float()
    atom_mask = getattr(outputs, "atom37_atom_exists", None)
    if atom_mask is not None and atom_mask.shape == outputs.plddt.shape:
        plddt = plddt[atom_mask[0].bool()]
    plddt_values = (plddt * 100).flatten().tolist()

    file_path = os.path.join('Protein-Llama-3-8B-Gradio/temporary_folder', "protein.pdb")
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, "w") as f:
        f.write(pdb_string)

    return np.mean(plddt_values)
|
|
|
|
|
def read_prot(molpath):
    """Return the entire text content of the file at *molpath*.

    Used to load a PDB file so it can be embedded in the viewer HTML.

    Args:
        molpath: Path to a text file (typically a ``.pdb``).

    Returns:
        The file contents as one string, with line endings preserved as read
        in text mode.
    """
    # A single read() replaces readlines() plus repeated string concatenation
    # (quadratic in the worst case) and yields the identical string.
    with open(molpath, "r") as fp:
        return fp.read()
|
|
|
|
|
def protein_visual_html(input_pdb):
    """Build iframe HTML that renders the PDB file at *input_pdb* with 3Dmol.js.

    The PDB text is embedded into a standalone HTML document. That document
    loads jQuery and 3Dmol.js from CDNs and shows the model as a
    spectrum-coloured cartoon. The whole document is passed to the iframe via
    its ``srcdoc`` attribute.

    Args:
        input_pdb: Path to a PDB file.

    Returns:
        An ``<iframe>`` HTML string suitable for a ``gr.HTML`` component.
    """
    import html
    import json

    mol = read_prot(input_pdb)

    # json.dumps produces a valid JS string literal, so backticks, backslashes
    # or "${" inside the PDB text can no longer break or inject into the
    # script (the old template literal was vulnerable to all three). "</" is
    # escaped as "<\/" so a stray "</script>" cannot close the script element.
    pdb_js_literal = json.dumps(mol).replace("</", "<\\/")

    x = (
        """<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
<style>
body{
    font-family:sans-serif
}
.mol-container {
    width: 100%;
    height: 600px;
    position: relative;
}
.mol-container select{
    background-image:None;
}
</style>
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js" integrity="sha512-STof4xm1wgkfm7heWqFJVn58Hm3EtS31XFaagaa8VMReCXAkQnJZ+jEy8PCC/iT18dFy95WcExNHFTqLyp72eQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
<script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
</head>
<body>
<div id="container" class="mol-container"></div>

<script>
let pdb = """ + pdb_js_literal + """;

$(document).ready(function () {
    let element = $("#container");
    let config = { backgroundColor: "white" };
    let viewer = $3Dmol.createViewer(element, config);
    viewer.addModel(pdb, "pdb");
    viewer.getModel(0).setStyle({}, { cartoon: { color:"spectrum" } });
    viewer.zoomTo();
    viewer.render();
    viewer.zoom(0.8, 2000);
})
</script>
</body></html>"""
    )

    # srcdoc is an HTML attribute value. Escape the document, including both
    # quote characters, so an apostrophe anywhere in the PDB (e.g. an author
    # name) cannot end the single-quoted attribute early. The browser
    # unescapes the value before rendering the frame.
    srcdoc = html.escape(x, quote=True)

    return f"""<iframe style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
    display-capture; encrypted-media;" sandbox="allow-modals allow-forms
    allow-scripts allow-same-origin allow-popups
    allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
    allowpaymentrequest="" frameborder="0" srcdoc='{srcdoc}'></iframe>"""
|
|
|
|
|
def predict_structure(input_sequence):
    """Return iframe HTML visualising the 3-D structure of *input_sequence*.

    - The bundled SARS-CoV-2 spike example is rendered from the pre-computed
      ``Protein-Llama-3-8B-Gradio/sars_cov_2_6vxx.pdb`` file.
    - Any other sequence is folded via ``save_pdb``, and the PDB it writes is
      read back for display.
    - Blank input (``None``, empty or whitespace-only) returns a short HTML
      notice instead of being sent to the folding model.

    Args:
        input_sequence: Protein sequence text from the UI.

    Returns:
        An HTML string for a ``gr.HTML`` component.
    """
    sequence = (input_sequence or "").strip()
    if not sequence:
        return "<p>No protein sequence to visualize. Generate or enter a sequence first.</p>"

    if sequence == 'SNASADAQSFLNRVCGVSAARLTPCGTGTSTDVVYRAFDIYNDKVAGFAKFLKTNCCRFQEKDEDDNLIDSYFVVKRHTFSNYQHEETIYNLLKDCPAVAKHDFFKFRIDGDMVPHISRQRLTKYTMADLVYALRHFDEGNCDTLKEILVTYNCCDDDYFNKKDWYDFVENPDILRVYANLGERVRQALLKTVQFCDAMRNAGIVGVLTLDNQDLNGNWYDFGDFIQTTPGSGVPVVDSYYSLLMPILTLTRALTAESHVDTDLTKPYIKWDLLKYDFTEERLKLFDRYFKYWDQTYHPNCVNCLDDRCILHCANFNVLFSTVFPPTSFGPLVRKIFVDGVPFVVSTGYHFRELGVVHNQDVNLHSSRLSFKELLVYAADPAMHAASGNLLLDKRTTCFSVAALTNNVAFQTVKPGNFNKDFYDFAVSKGFFKEGSSVELKHFFFAQDGNAAISDYDYYRYNLPTMCDIRQLLFVVEVVDKYFDCYDGGCINANQVI':
        return protein_visual_html('Protein-Llama-3-8B-Gradio/sars_cov_2_6vxx.pdb')

    # save_pdb writes the structure to this fixed path. Its return value
    # (mean pLDDT) is not displayed here.
    save_pdb(sequence)
    pdb_path = os.path.join('Protein-Llama-3-8B-Gradio/temporary_folder', "protein.pdb")
    return protein_visual_html(pdb_path)
|
|
|
def generate_protein_sequence(sequence, seq_length, property=''):
    """Generate a protein sequence with the Protein-Llama text-generation pipeline.

    Args:
        sequence: Starting amino acids that the generated sequence extends
            (``None`` is treated as empty).
        seq_length: Maximum number of new tokens to generate.
        property: Optional class chosen in the UI dropdown. The name shadows
            the builtin but is kept for keyword-argument compatibility.
            - ``None`` or ``''`` selects unconditional generation.
            - Enzyme classes are lower-cased inside the ``[Generate ... protein]``
              prompt tag; other classes are used verbatim.
            - ``'SARS-CoV-2 Spike Protein (example)'`` skips the model and
              returns the stored example sequence.

    Returns:
        ``(cleaned_sequence, elapsed_seconds, tokens_per_second)``. These are
        always three values, matching the three Gradio output components.
    """
    enzymes = ["Non-Hemolytic", "Soluble", "Oxidoreductase", "Transferase", "Hydrolase", "Lyase", "Isomerase", "Ligase", "Translocase"]
    marker = 'Seq=<'
    sequence = sequence or ""

    if not property:
        input_prompt = marker + sequence
    elif property == 'SARS-CoV-2 Spike Protein (example)':
        start_time = time()
        example_seq = 'SNASADAQSFLNRVCGVSAARLTPCGTGTSTDVVYRAFDIYNDKVAGFAKFLKTNCCRFQEKDEDDNLIDSYFVVKRHTFSNYQHEETIYNLLKDCPAVAKHDFFKFRIDGDMVPHISRQRLTKYTMADLVYALRHFDEGNCDTLKEILVTYNCCDDDYFNKKDWYDFVENPDILRVYANLGERVRQALLKTVQFCDAMRNAGIVGVLTLDNQDLNGNWYDFGDFIQTTPGSGVPVVDSYYSLLMPILTLTRALTAESHVDTDLTKPYIKWDLLKYDFTEERLKLFDRYFKYWDQTYHPNCVNCLDDRCILHCANFNVLFSTVFPPTSFGPLVRKIFVDGVPFVVSTGYHFRELGVVHNQDVNLHSSRLSFKELLVYAADPAMHAASGNLLLDKRTTCFSVAALTNNVAFQTVKPGNFNKDFYDFAVSKGFFKEGSSVELKHFFFAQDGNAAISDYDYYRYNLPTMCDIRQLLFVVEVVDKYFDCYDGGCINANQVI'
        # Same 3-value shape as the generation path. No model call means 0 tokens/s.
        return example_seq, time() - start_time, 0
    elif property in enzymes:
        input_prompt = '[Generate ' + property.lower() + ' protein] ' + marker + sequence
    else:
        input_prompt = '[Generate ' + property + ' protein] ' + marker + sequence

    start_time = time()
    generated_text = generator(
        input_prompt,
        temperature=0.5,
        top_k=40,
        top_p=0.9,
        do_sample=True,
        repetition_penalty=1.2,
        max_new_tokens=int(seq_length),
        num_return_sequences=1,
    )[0]["generated_text"]
    end_time = time()
    elapsed = end_time - start_time

    # The sequence sits between the 'Seq=<' marker and the closing '>'.
    # Start *after* the marker: its 'S' is a valid amino-acid letter and was
    # previously leaking into every result. If generation stopped before
    # emitting '>', keep the text to the end; find() == -1 used to cut off the
    # last residue.
    start_idx = generated_text.find(marker)
    start_idx = 0 if start_idx == -1 else start_idx + len(marker)
    end_idx = generated_text.find('>', start_idx)
    if end_idx == -1:
        end_idx = len(generated_text)
    cleaned_seq = clean_protein_sequence(generated_text[start_idx:end_idx])

    tokens = tokenizer.encode(cleaned_seq, add_special_tokens=False)
    tokens_per_second = len(tokens) / elapsed if elapsed > 0 else 0.0

    return cleaned_seq, elapsed, tokens_per_second
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI: sequence generation on top, 3-D structure visualisation below.
with gr.Blocks() as demo:
    gr.Markdown("Interactive protein sequence generation and visualization")

    with gr.Row():
        input_text = gr.Textbox(
            label="Enter starting amino acids for protein sequence generation",
            placeholder="Example input: MK",
        )

    with gr.Row():
        seq_length = gr.Slider(
            minimum=2,
            maximum=200,
            value=30,
            step=1,
            label="Length",
            info="Choose the number of tokens to generate",
        )
        # Dropdown choices: the bundled example, protein families, then enzyme classes.
        classes = (
            ["SARS-CoV-2 Spike Protein (example)"]
            + [
                'Tetratricopeptide-like helical domain superfamily',
                'CheY-like superfamily',
                'S-adenosyl-L-methionine-dependent methyltransferase superfamily',
                'Thioredoxin-like superfamily',
            ]
            + ["Non-Hemolytic", "Soluble", "Oxidoreductase", "Transferase", "Hydrolase", "Lyase", "Isomerase", "Ligase", "Translocase"]
        )
        protein_property = gr.Dropdown(choices=classes, label="Class")

    with gr.Row():
        btn = gr.Button(value="Submit")

    with gr.Row():
        output_text = gr.Textbox(label="Generated protein sequence will appear here")

    with gr.Row():
        infer_time = gr.Number(label="Inference Time (s)", precision=2)
        tokens_per_sec = gr.Number(label="Tokens/sec", precision=2)

    with gr.Row():
        btn_vis = gr.Button(value="Visualize")

    with gr.Row():
        structure_visual = gr.HTML()

    # Wire the buttons: generation fills the text/metrics, visualisation renders the PDB.
    btn.click(
        fn=generate_protein_sequence,
        inputs=[input_text, seq_length, protein_property],
        outputs=[output_text, infer_time, tokens_per_sec],
    )
    btn_vis.click(
        fn=predict_structure,
        inputs=output_text,
        outputs=[structure_visual],
    )

demo.launch()