import logging

import gradio as gr
import torch
from transformers import AutoTokenizer, EsmForProteinFolding
from transformers.models.esm.openfold_utils.feats import atom14_to_atom37
from transformers.models.esm.openfold_utils.protein import Protein as OFProtein
from transformers.models.esm.openfold_utils.protein import to_pdb

logger = logging.getLogger(__name__)


def convert_outputs_to_pdb(outputs):
    """Convert ESMFold model outputs into PDB-format strings, one per batch item.

    Args:
        outputs: The output of ``EsmForProteinFolding``. This function reads its
            ``positions``, ``aatype``, ``atom37_atom_exists``, ``residue_index``,
            ``plddt`` and (if present) ``chain_index`` entries, and moves every
            entry to CPU/NumPy.

    Returns:
        A list of PDB strings, in batch order.
    """
    # Convert the last entry of ``positions`` from the atom14 to the atom37
    # layout while the tensors are still on the model's device.
    final_atom_positions = atom14_to_atom37(outputs["positions"][-1], outputs)
    outputs = {k: v.to("cpu").numpy() for k, v in outputs.items()}
    final_atom_positions = final_atom_positions.cpu().numpy()
    final_atom_mask = outputs["atom37_atom_exists"]
    chain_index = outputs.get("chain_index")

    pdbs = []
    for i in range(outputs["aatype"].shape[0]):
        pred = OFProtein(
            aatype=outputs["aatype"][i],
            atom_positions=final_atom_positions[i],
            atom_mask=final_atom_mask[i],
            # Shift residue numbering by one (presumably 0-based -> 1-based
            # PDB numbering).
            residue_index=outputs["residue_index"][i] + 1,
            # Store per-residue pLDDT in the B-factor column.
            b_factors=outputs["plddt"][i],
            chain_index=chain_index[i] if chain_index is not None else None,
        )
        pdbs.append(to_pdb(pred))
    return pdbs


def fold_prot_locally(sequence):
    """Fold one amino-acid sequence with ESMFold and return its PDB text.

    All whitespace (e.g. line breaks from a pasted FASTA body) is removed and the
    sequence is upper-cased before tokenization.

    Args:
        sequence: One-letter amino-acid sequence as entered in the text box.

    Returns:
        A one-element list containing the PDB string. The "text" output shows
        it as the list's string form, which is what the sample client parses.

    Raises:
        ValueError: If the sequence is empty after removing whitespace.
    """
    # ESM's residue vocabulary uses upper-case one-letter codes; lower-case
    # letters would otherwise be tokenized as unknown residues.
    sequence = "".join((sequence or "").split()).upper()
    if not sequence:
        raise ValueError("Empty sequence: please enter an amino-acid sequence.")

    logger.info("Folding: %s", sequence)
    tokenized_input = tokenizer(
        [sequence], return_tensors="pt", add_special_tokens=False
    )["input_ids"].to(model.device)
    with torch.no_grad():
        output = model(tokenized_input)
    return convert_outputs_to_pdb(output)


# Usage example rendered (as Markdown) under the Gradio interface.
sample_code = """
## Sample usage

```python
import tempfile
from ast import literal_eval

from gradio_client import Client

client = Client("https://wwydmanski-esmfold.hf.space/")


def fold_huggingface(sequence, fname=None):
    result = client.predict(
        sequence,  # str in 'sequence' Textbox component
        api_name="/predict",
    )
    # The endpoint returns the string form of a list of PDB strings;
    # parse it safely instead of eval()-ing text received over the network.
    result = literal_eval(result)[0]
    if fname is None:
        with tempfile.NamedTemporaryFile("w", delete=False, suffix=".pdb", prefix="esmfold_") as fp:
            fp.write(result)
            fp.flush()
            return fp.name
    else:
        with open(fname, "w") as fp:
            fp.write(result)
            fp.flush()
            return fname


pdb_fname = fold_huggingface("MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN")
```
"""

tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
model = EsmForProteinFolding.from_pretrained(
    "facebook/esmfold_v1", low_cpu_mem_usage=True
).cuda()
# Cast only the ESM language-model submodule to fp16 (presumably to cut GPU
# memory/latency); the rest of the model keeps its default precision.
model.esm = model.esm.half()
# Allow TF32 matmuls on supporting GPUs: faster, slightly reduced precision.
torch.backends.cuda.matmul.allow_tf32 = True

iface = gr.Interface(
    fn=fold_prot_locally,
    inputs="text",
    outputs="text",
    article=sample_code,
    title="ESMFold",
)

if __name__ == "__main__":
    # Without a configured handler/level, logger.info() calls are dropped.
    logging.basicConfig(level=logging.INFO)
    iface.launch()