import os
import tempfile
import subprocess
from time import time

import numpy as np
import pandas as pd
import torch
import transformers
import gradio as gr
from transformers import (
    GenerationConfig,
    pipeline,
    AutoTokenizer,
    AutoModelForCausalLM,
    EsmForProteinFolding,
)

model_id = "Esperanto/Protein-Llama-3-8B"

# Device for ESMFold. `device` used to be referenced (esm_model.to / save_pdb)
# without ever being defined, so the script crashed with NameError on startup.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Dropdown entry that short-circuits generation/folding with a known protein.
EXAMPLE_PROPERTY = "SARS-CoV-2 Spike Protein (example)"
# Hard-coded SARS-CoV-2 protein sequence and its structure, for instant demo
# purposes. Shared by generation and folding so the shortcut cannot drift.
SARS_COV_2_SEQUENCE = "SNASADAQSFLNRVCGVSAARLTPCGTGTSTDVVYRAFDIYNDKVAGFAKFLKTNCCRFQEKDEDDNLIDSYFVVKRHTFSNYQHEETIYNLLKDCPAVAKHDFFKFRIDGDMVPHISRQRLTKYTMADLVYALRHFDEGNCDTLKEILVTYNCCDDDYFNKKDWYDFVENPDILRVYANLGERVRQALLKTVQFCDAMRNAGIVGVLTLDNQDLNGNWYDFGDFIQTTPGSGVPVVDSYYSLLMPILTLTRALTAESHVDTDLTKPYIKWDLLKYDFTEERLKLFDRYFKYWDQTYHPNCVNCLDDRCILHCANFNVLFSTVFPPTSFGPLVRKIFVDGVPFVVSTGYHFRELGVVHNQDVNLHSSRLSFKELLVYAADPAMHAASGNLLLDKRTTCFSVAALTNNVAFQTVKPGNFNKDFYDFAVSKGFFKEGSSVELKHFFFAQDGNAAISDYDYYRYNLPTMCDIRQLLFVVEVVDKYFDCYDGGCINANQVI"
SARS_COV_2_PDB_PATH = "Protein-Llama-3-8B-Gradio/sars_cov_2_6vxx.pdb"

# Where the ESMFold prediction is written and read back for visualization.
PDB_OUTPUT_PATH = os.path.join("Protein-Llama-3-8B-Gradio/temporary_folder", "protein.pdb")

# Classes whose name is lower-cased in the prompt ("[Generate hydrolase protein]").
# "Non-Hemolytic" and "Soluble" are not enzymes but are prompted the same way.
ENZYME_CLASSES = ["Non-Hemolytic", "Soluble", "Oxidoreductase", "Transferase", "Hydrolase",
                  "Lyase", "Isomerase", "Ligase", "Translocase"]
# Classes whose name is inserted into the prompt verbatim.
SUPERFAMILY_CLASSES = ['Tetratricopeptide-like helical domain superfamily', 'CheY-like superfamily',
                       'S-adenosyl-L-methionine-dependent methyltransferase superfamily',
                       'Thioredoxin-like superfamily']

# The 20 canonical amino-acid one-letter codes (upper case only).
VALID_AMINO_ACIDS = frozenset("ACDEFGHIKLMNPQRSTVWY")

# The model emits sequences wrapped as "Seq=<...>".
SEQ_START_MARKER = "Seq=<"
SEQ_END_MARKER = ">"

# Loading the fine-tuned LLaMA 3 model
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True
)

# Loading the tokenizer; left padding with EOS as pad token for generation.
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

# Creating the pipeline for generation
generator = pipeline('text-generation', model=model, tokenizer=tokenizer)

# Loading the ESM Model
esm_model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1")
esm_tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
esm_model.to(device)


def clean_protein_sequence(protein_seq):
    """Return *protein_seq* keeping only the 20 canonical amino-acid letters.

    Any other character (lower-case letters, punctuation, markers) is dropped.
    """
    return "".join(char for char in protein_seq if char in VALID_AMINO_ACIDS)


def modify_b_factors(pdb_content, multiplier):
    """Scale the B-factor column of every ATOM record in a PDB string.

    Used to turn ESMFold's pLDDT values (stored as B-factors) into percentages.

    Args:
        pdb_content: PDB file contents as a single string.
        multiplier: factor applied to the B-factor field (columns 61-66).

    Returns:
        The PDB string with ATOM B-factors rewritten as ``%6.2f``; all other
        lines are unchanged.
    """
    modified_pdb = []
    for line in pdb_content.split('\n'):
        if line.startswith("ATOM"):
            b_factor = float(line[60:66].strip())
            new_b_factor = b_factor * multiplier
            line = f"{line[:60]}{new_b_factor:6.2f}{line[66:]}"
        modified_pdb.append(line)
    return "\n".join(modified_pdb)


def save_pdb(input_sequence):
    """Fold *input_sequence* with ESMFold and write the PDB to PDB_OUTPUT_PATH.

    Returns:
        Mean pLDDT (0-100) over all residues, using the CA atom confidence.
    """
    inputs = esm_tokenizer([input_sequence], return_tensors="pt", add_special_tokens=False)
    inputs = inputs.to(device)
    with torch.no_grad():
        outputs = esm_model(**inputs)
    pdb_string_unscaled = esm_model.output_to_pdb(outputs)[0]
    pdb_string = modify_b_factors(pdb_string_unscaled, 100)

    # ESMFold's plddt is atom-resolved: (batch, residues, 37 atom37 slots) in
    # [0, 1]. The old `.tolist()[0][0]` averaged the 37 slots of the *first
    # residue* only. Average the CA slot (atom37 index 1) over all residues.
    ca_plddt = outputs.plddt[0, :, 1].tolist()
    plddt_values = [round(value * 100, 2) for value in ca_plddt]

    os.makedirs(os.path.dirname(PDB_OUTPUT_PATH), exist_ok=True)
    with open(PDB_OUTPUT_PATH, "w") as f:
        f.write(pdb_string)
    return np.mean(plddt_values)


def read_prot(molpath):
    """Return the full text content of the PDB file at *molpath*."""
    with open(molpath, "r") as fp:
        return fp.read()


def protein_visual_html(input_pdb):
    """Build the HTML snippet that renders the structure stored at *input_pdb*.

    NOTE(review): in this copy of the file the HTML/JS viewer template has been
    stripped (only empty string literals remained), so this returns an empty
    string and the "Visualize" button renders nothing. Restore the original
    template; it is expected to embed `mol`, the PDB text read below.
    """
    mol = read_prot(input_pdb)  # PDB text for the (missing) viewer template
    return ""


def predict_structure(input_sequence):
    """Fold *input_sequence* and return HTML visualizing the structure.

    The hard-coded SARS-CoV-2 example uses its pre-computed PDB instead of
    running ESMFold.

    Raises:
        gr.Error: if no sequence was provided.
    """
    if input_sequence == SARS_COV_2_SEQUENCE:
        return protein_visual_html(SARS_COV_2_PDB_PATH)
    if not input_sequence:
        raise gr.Error("Generate or enter a protein sequence before visualizing.")
    save_pdb(input_sequence)
    # Creating HTML visualization for the PDB file stored in the temporary folder
    return protein_visual_html(PDB_OUTPUT_PATH)


def _extract_sequence(generated_text):
    """Return the text between the first 'Seq=<' and the following '>'.

    The marker itself is excluded; previously its 'S' survived cleaning and was
    prepended to every sequence. If '>' was never generated (e.g. the token
    budget ran out), everything after the marker is returned instead of
    silently dropping the last character.
    """
    start = generated_text.find(SEQ_START_MARKER)
    start = 0 if start == -1 else start + len(SEQ_START_MARKER)
    end = generated_text.find(SEQ_END_MARKER, start)
    return generated_text[start:] if end == -1 else generated_text[start:end]


def generate_protein_sequence(sequence, seq_length, property=''):
    """Generate a protein sequence continuing *sequence*, conditioned on a class.

    Args:
        sequence: starting amino acids typed by the user.
        seq_length: maximum number of new tokens to generate.
        property: dropdown class; empty/None means unconditioned generation.

    Returns:
        (cleaned sequence, inference time in seconds, tokens per second). This
        matches the three output components wired to the Submit button.
    """
    sequence = sequence or ""

    if property == EXAMPLE_PROPERTY:
        start_time = time()
        # The demo example returns the known sequence instantly; the old code
        # returned a 4th value here, one more than the UI has outputs for.
        return SARS_COV_2_SEQUENCE, time() - start_time, 0

    if not property:
        # Covers both None (no dropdown selection) and the '' default.
        input_prompt = SEQ_START_MARKER + sequence
    elif property in ENZYME_CLASSES:
        input_prompt = '[Generate ' + property.lower() + ' protein] ' + SEQ_START_MARKER + sequence
    else:
        input_prompt = '[Generate ' + property + ' protein] ' + SEQ_START_MARKER + sequence

    start_time = time()
    generated_text = generator(
        input_prompt,
        temperature=0.5,
        top_k=40,
        top_p=0.9,
        do_sample=True,
        repetition_penalty=1.2,
        max_new_tokens=int(seq_length),  # sliders may deliver floats
        num_return_sequences=1,
    )[0]["generated_text"]
    elapsed = time() - start_time

    cleaned_seq = clean_protein_sequence(_extract_sequence(generated_text))
    tokens = tokenizer.encode(cleaned_seq, add_special_tokens=False)
    tokens_per_second = len(tokens) / elapsed if elapsed > 0 else 0.0
    return cleaned_seq, elapsed, tokens_per_second


# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("Interactive protein sequence generation and visualization")
    with gr.Row():
        input_text = gr.Textbox(label="Enter starting amino acids for protein sequence generation",
                                placeholder="Example input: MK")
    with gr.Row():
        seq_length = gr.Slider(2, 200, value=30, step=1, label="Length",
                               info="Choose the number of tokens to generate")
        classes = [EXAMPLE_PROPERTY] + SUPERFAMILY_CLASSES + ENZYME_CLASSES
        protein_property = gr.Dropdown(classes, label="Class")
    with gr.Row():
        btn = gr.Button("Submit")
    with gr.Row():
        output_text = gr.Textbox(label="Generated protein sequence will appear here")
    with gr.Row():
        infer_time = gr.Number(label="Inference Time (s)", precision=2)
        tokens_per_sec = gr.Number(label="Tokens/sec", precision=2)
    with gr.Row():
        btn_vis = gr.Button("Visualize")
    with gr.Row():
        structure_visual = gr.HTML()

    btn.click(generate_protein_sequence,
              inputs=[input_text, seq_length, protein_property],
              outputs=[output_text, infer_time, tokens_per_sec])
    btn_vis.click(predict_structure, inputs=output_text, outputs=[structure_visual])

# Run the Gradio interface
if __name__ == "__main__":
    demo.launch()