"""Gradio demo that folds a protein sequence with ESMFold and shows the PDB text."""

import gradio as gr
from transformers import AutoTokenizer, EsmForProteinFolding
from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein
from transformers.models.esm.openfold_utils.feats import atom14_to_atom37
import torch

from logging import getLogger

logger = getLogger(__name__)

# Dropdown entry that is pre-selected when the page loads.
DEFAULT_SAMPLE = "Plastic degradation protein"

# Sample sequences offered in the dropdown, keyed by their display label.
_SAMPLE_SEQUENCES = {
    "Plastic degradation protein": "MGSSHHHHHHSSGLVPRGSHMRGPNPTAASLEASAGPFTVRSFTVSRPSGYGAGTVYYPTNAGGTVGAIAIVPGYTARQSSIKWWGPRLASHGFVVITIDTNSTLDQPSSRSSQQMAALRQVASLNGTSSSPIYGKVDTARMGVMGWSMGGGGSLISAANNPSLKAAAPQAPWDSSTNFSSVTVPTLIFACENDSIAPVNSSALPIYDSMSRNAKQFLEINGGSHSCANSGNSNQALIGKKGVAWMKRFMDNDTRYSTFACENPNSTRVSDFRTANCSLEDPAANKARKEAELAAATAEQ",
    "Antifreeze protein": "QCTGGADCTSCTGACTGCGNCPNAVTCTNSQHCVKANTCTGSTDCNTAQTCTNSKDCFEANTCTDSTNCYKATACTNSSGCPGH",
    "AI Generated protein": "MSGMKKLYEYTVTTLDEFLEKLKEFILNTSKDKIYKLTITNPKLIKDIGKAIAKAAEIADVDPKEIEEMIKAVEENELTKLVITIEQTDDKYVIKVELENEDGLVHSFEIYFKNKEEMEKFLELLEKLISKLSGS",
    "7-bladed propeller fold": "VKLAGNSSLCPINGWAVYSKDNSIRIGSKGDVFVIREPFISCSHLECRTFFLTQGALLNDKHSNGTVKDRSPHRTLMSCPVGEAPSPYNSRFESVAWSASACHDGTSWLTIGISGPDNGAVAVLKYNGIITDTIKSWRNNILRTQESECACVNGSCFTVMTDGPSNGQASYKIFKMEKGKVVKSVELDAPNYHYEECSCYPNAGEITCVCRDNWHGSNRPWVSFNQNLEYQIGYICSGVFGDNPRPNDGTGSCGPVSSNGAYGVKGFSFKYGNGVWIGRTKSTNSRSGFEMIWDPNGWTETDSSFSVKQDIVAITDWSGYSGSFVQHPELTGLDCIRPCFWVELIRGRPKESTIWTSGSSISFCGVNSDTVGWSWPDGAELPFTIDK",
}


def convert_outputs_to_pdb(outputs) -> str:
    """Render the first structure of a batch of folding outputs as PDB text.

    Args:
        outputs: Mapping of tensors as returned by calling the
            ``EsmForProteinFolding`` model. The keys read here are
            ``positions``, ``atom37_atom_exists``, ``aatype``,
            ``residue_index``, ``plddt`` and, when present, ``chain_index``.

    Returns:
        The PDB-formatted string for batch element 0. Other batch elements
        are ignored; the caller in this file always folds one sequence.
    """
    # atom14_to_atom37 needs the original tensors, so run it before the
    # mapping is converted to NumPy arrays below.
    final_atom_positions = atom14_to_atom37(outputs["positions"][-1], outputs)
    final_atom_positions = final_atom_positions.cpu().numpy()
    outputs = {k: v.to("cpu").numpy() for k, v in outputs.items()}

    protein = OFProtein(
        aatype=outputs["aatype"][0],
        atom_positions=final_atom_positions[0],
        atom_mask=outputs["atom37_atom_exists"][0],
        # Residue indices are shifted by one before being written out.
        residue_index=outputs["residue_index"][0] + 1,
        # The per-residue pLDDT is stored in the B-factor field.
        b_factors=outputs["plddt"][0],
        chain_index=outputs["chain_index"][0] if "chain_index" in outputs else None,
    )
    return to_pdb(protein)


def fold_prot_locally(sequence):
    """Fold one amino-acid sequence with the module-level model.

    Args:
        sequence: Amino-acid sequence as typed into the "Sequence" textbox.
            Whitespace (spaces, tabs, line breaks) is removed before folding.

    Returns:
        The predicted structure as PDB text.

    Raises:
        gr.Error: If the sequence is empty after whitespace is removed, so
            the UI reports the problem instead of failing inside the model.
    """
    cleaned = "".join((sequence or "").split())
    if not cleaned:
        raise gr.Error("Please enter a protein sequence.")

    logger.info("Folding: %s", cleaned)
    input_ids = tokenizer(
        [cleaned], return_tensors="pt", add_special_tokens=False
    )["input_ids"]
    # Follow the model's device rather than hard-coding CUDA here.
    input_ids = input_ids.to(model.device)

    with torch.no_grad():
        output = model(input_ids)

    return convert_outputs_to_pdb(output)


def suggest(option):
    """Return the sample sequence for a dropdown label.

    Args:
        option: Label selected in the "Choose a Sample Protein" dropdown.

    Returns:
        The matching sequence, or an empty string for any other label
        (for example "custom").
    """
    return _SAMPLE_SEQUENCES.get(option, "")


# Client-side snippet displayed in the UI; it is shown as text, never executed here.
sample_code = """
from gradio_client import Client
import json
import tempfile

client = Client("https://wwydmanski-esmfold.hf.space/")

def fold_huggingface(sequence, fname=None):
    result = client.predict(
        sequence,  # str in 'sequence' Textbox component
        api_name="/predict")
    if fname is None:
        with tempfile.NamedTemporaryFile("w", delete=False, suffix=".pdb", prefix="esmfold_") as fp:
            fp.write(result)
            fp.flush()
            return fp.name
    else:
        with open(fname, "w") as fp:
            fp.write(result)
            fp.flush()
            return fname

pdb_fname = fold_huggingface("MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN")
"""

tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=True).cuda()

# Run the ESM submodule in half precision and allow TF32 matmuls on CUDA.
model.esm = model.esm.half()
torch.backends.cuda.matmul.allow_tf32 = True

with gr.Blocks() as demo:
    gr.Markdown("# ESMFold")
    with gr.Row():
        with gr.Column():
            # Prefill with the sample matching the dropdown's initial value;
            # the change event below only fires when the selection changes.
            inp = gr.Textbox(lines=1, label="Sequence", value=suggest(DEFAULT_SAMPLE))
            name = gr.Dropdown(
                label="Choose a Sample Protein",
                value=DEFAULT_SAMPLE,
                choices=[
                    "Antifreeze protein",
                    "Plastic degradation protein",
                    "AI Generated protein",
                    "7-bladed propeller fold",
                    "custom",
                ],
            )
            btn = gr.Button("🔬 Predict Structure ")

    with gr.Row():
        with gr.Column():
            gr.Markdown("## Sample code")
            gr.Code(sample_code, label="Sample usage", language="python", interactive=False)

    with gr.Row():
        with gr.Column():
            gr.Markdown("## Output")
            out = gr.Code(label="Output", interactive=False)

    name.change(fn=suggest, inputs=name, outputs=inp)
    btn.click(fold_prot_locally, inputs=[inp], outputs=[out])

demo.launch()