import gradio as gr
import spaces
import torch
import torch.nn.functional as F
from esm import pretrained, FastaBatchedDataset
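# Load the ESM-2 checkpoint (and its alphabet/tokenizer) once at startup so
# every request reuses the cached weights instead of reloading the 3B model.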
models = {
'facebook/esm2_t36_3B_UR50D': pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D').to("cuda").eval(),
}
DESCRIPTION = "# ESM-2 per-residue embedding generation"
@spaces.GPU
def run_example(protein, model_id='facebook/esm2_t36_3B_UR50D'):
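    """Compute per-token ESM-2 embeddings for one sequence, zero-pad them to a
    fixed length of 1024 residues, and save the tensor to example.pt for download."""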
model_esm, alphabet = models[model_id]
protein_name = 'protein_name'
protein_seq = protein
    include = 'per_tok'            # extract per-residue representations
    repr_layers = [36]             # final transformer layer of esm2_t36_3B_UR50D
    truncation_seq_length = 1024   # sequences longer than this are truncated
    toks_per_batch = 4096          # per-batch token budget for the batch sampler
print("start")
dataset = FastaBatchedDataset([protein_name], [protein_seq])
print("dataset prepared")
batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
print("batches prepared")
data_loader = torch.utils.data.DataLoader(
dataset, collate_fn=alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
)
print(f"Read sequences")
return_contacts = "contacts" in include
assert all(-(model_esm.num_layers + 1) <= i <= model_esm.num_layers for i in repr_layers)
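    # Map possibly-negative layer indices (e.g. -1 for the last layer) onto
    # absolute layer numbers in [0, num_layers].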
repr_layers = [(i + model_esm.num_layers + 1) % (model_esm.num_layers + 1) for i in repr_layers]
with torch.no_grad():
for batch_idx, (labels, strs, toks) in enumerate(data_loader):
print(
f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)"
)
if torch.cuda.is_available():
toks = toks.to(device="cuda", non_blocking=True)
out = model_esm(toks, repr_layers=repr_layers, return_contacts=return_contacts)
representations = {
layer: t.to(device="cpu") for layer, t in out["representations"].items()
}
if return_contacts:
contacts = out["contacts"].to(device="cpu")
for i, label in enumerate(labels):
result = {"label": label}
truncate_len = min(truncation_seq_length, len(strs[i]))
# Call clone on tensors to ensure tensors are not views into a larger representation
# See https://github.com/pytorch/pytorch/issues/1995
if "per_tok" in include:
result["representations"] = {
layer: t[i, 1: truncate_len + 1].clone()
for layer, t in representations.items()
}
if "mean" in include:
result["mean_representations"] = {
layer: t[i, 1: truncate_len + 1].mean(0).clone()
for layer, t in representations.items()
}
if "bos" in include:
result["bos_representations"] = {
layer: t[i, 0].clone() for layer, t in representations.items()
}
if return_contacts:
result["contacts"] = contacts[i, : truncate_len, : truncate_len].clone()
esm_emb = result['representations'][36]
print("esm embedding generated")
esm_emb = F.pad(esm_emb.t(), (0, 1024 - len(esm_emb))).t()
torch.save(esm_emb, 'example.pt')
return gr.File.update(value="example.pt", visible=True)
css = """
#output {
height: 500px;
overflow: auto;
border: 1px solid #ccc;
}
"""
with gr.Blocks(css=css) as demo:
gr.Markdown(DESCRIPTION)
    with gr.Tab(label="ESM-2 embedding generation"):
with gr.Row():
with gr.Column():
                input_protein = gr.Textbox(type="text", label="Protein sequence")
                model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value='facebook/esm2_t36_3B_UR50D')
submit_btn = gr.Button(value="Submit")
with gr.Column():
button = gr.Button("Export")
pt = gr.File(interactive=False, visible=False)
    submit_btn.click(run_example, [input_protein, model_selector], pt)
    button.click(run_example, [input_protein, model_selector], pt)
demo.launch(debug=True)