esmfold / app.py
wwydmanski's picture
Update app.py
df06262
raw
history blame
2.76 kB
import gradio as gr
from transformers import AutoTokenizer, EsmForProteinFolding
from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein
from transformers.models.esm.openfold_utils.feats import atom14_to_atom37
import torch
from logging import getLogger
# Module-level logger named after this module; used by fold_prot_locally.
logger = getLogger(name=__name__)
def convert_outputs_to_pdb(outputs):
    """Serialize every structure in a batched ESMFold output to PDB text.

    Args:
        outputs: The mapping returned by ``model(...)`` in ``fold_prot_locally``.
            Every value must be a tensor (each one is moved to CPU and turned
            into a numpy array). Keys read: ``positions``, ``aatype``,
            ``atom37_atom_exists``, ``residue_index``, ``plddt`` and, only when
            present, ``chain_index``.

    Returns:
        A list holding one PDB string (produced by ``to_pdb``) per batch entry,
        where the batch size is taken from ``outputs["aatype"].shape[0]``.
    """
    # Only the last entry of ``positions`` is used (presumably the final
    # structure-module iteration -- confirm). The atom14 -> atom37 remapping
    # runs on the original tensors, before anything is converted to numpy.
    atom37_tensor = atom14_to_atom37(outputs["positions"][-1], outputs)
    host = {key: value.to("cpu").numpy() for key, value in outputs.items()}
    atom37_xyz = atom37_tensor.cpu().numpy()
    chains = host["chain_index"] if "chain_index" in host else None

    def _protein_at(b):
        """Build the OpenFold ``Protein`` record for batch entry ``b``."""
        return OFProtein(
            aatype=host["aatype"][b],
            atom_positions=atom37_xyz[b],
            atom_mask=host["atom37_atom_exists"][b],
            # +1 presumably shifts to 1-based residue numbering for the PDB file.
            residue_index=host["residue_index"][b] + 1,
            # pLDDT confidence values are written into the B-factor column.
            b_factors=host["plddt"][b],
            chain_index=None if chains is None else chains[b],
        )

    batch_size = host["aatype"].shape[0]
    return [to_pdb(_protein_at(b)) for b in range(batch_size)]
def fold_prot_locally(sequence):
    """Fold one protein sequence with ESMFold and return the structure as PDB text.

    Uses the module-level ``tokenizer`` and ``model`` globals and the
    ``convert_outputs_to_pdb`` helper.

    Args:
        sequence: Amino-acid sequence in one-letter codes. All whitespace
            (e.g. line breaks from a pasted FASTA body) is removed and letters
            are upper-cased before tokenization.

    Returns:
        A list of PDB-format strings, one per batch entry. Only one sequence is
        submitted, so the list has a single entry. Gradio's ``"text"`` output
        shows the list's ``str()``, which the sample client code in this file
        parses back into a list and indexes with ``[0]``.

    Raises:
        ValueError: If ``sequence`` is empty or contains only whitespace.
    """
    # Remove embedded whitespace and normalise case. The facebook/esmfold_v1
    # vocabulary holds upper-case residue tokens, so lower-case letters would
    # otherwise be tokenized as unknowns.
    cleaned = "".join(sequence.split()).upper()
    if not cleaned:
        raise ValueError("Empty sequence: please provide an amino-acid sequence.")
    logger.info("Folding: %s", cleaned)
    input_ids = tokenizer([cleaned], return_tensors="pt", add_special_tokens=False)["input_ids"]
    # Send the input to wherever the model lives instead of hard-coding CUDA.
    input_ids = input_ids.to(model.device)
    with torch.no_grad():
        output = model(input_ids)
    return convert_outputs_to_pdb(output)
# Markdown shown under the Gradio interface (passed as ``article``): a sample
# client for this Space's HTTP API. The snippet sits in a code fence so the
# markdown renderer keeps its indentation, it imports everything it uses, and
# it parses the response with ast.literal_eval rather than eval() because the
# text arrives over the network.
sample_code = """
## Sample usage

```python
import ast
import tempfile

from gradio_client import Client

client = Client("https://wwydmanski-esmfold.hf.space/")


def fold_huggingface(sequence, fname=None):
    result = client.predict(
        sequence,  # str in 'sequence' Textbox component
        api_name="/predict",
    )
    # The Space returns the str() of a list of PDB strings; parse it safely.
    result = ast.literal_eval(result)[0]
    if fname is None:
        with tempfile.NamedTemporaryFile("w", delete=False, suffix=".pdb", prefix="esmfold_") as fp:
            fp.write(result)
            return fp.name
    with open(fname, "w") as fp:
        fp.write(result)
    return fname


pdb_fname = fold_huggingface("MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN")
```
"""
# --- Model and UI setup (runs at import time) ---
# Allow TF32 for CUDA matrix multiplications (a global PyTorch backend flag).
torch.backends.cuda.matmul.allow_tf32 = True

ESMFOLD_CHECKPOINT = "facebook/esmfold_v1"
tokenizer = AutoTokenizer.from_pretrained(ESMFOLD_CHECKPOINT)
model = EsmForProteinFolding.from_pretrained(ESMFOLD_CHECKPOINT, low_cpu_mem_usage=True)
model = model.cuda()
# Only the ``esm`` submodule is cast to half precision here.
model.esm = model.esm.half()

iface = gr.Interface(
    fn=fold_prot_locally,
    inputs="text",
    outputs="text",
    title="ESMFold",
    article=sample_code,
)
iface.launch()