import gradio as gr
import spaces
import torch
import torch.nn.functional as F

from esm import pretrained, FastaBatchedDataset

# pretrained.load_model_and_alphabet returns a (model, alphabet) tuple, so the
# model must be moved to the GPU and put in eval mode on its own; calling
# .to() on the tuple itself would raise an AttributeError.
def load_esm(name):
    model, alphabet = pretrained.load_model_and_alphabet(name)
    return model.to("cuda").eval(), alphabet


models = {
    'facebook/esm2_t36_3B_UR50D': load_esm('esm2_t36_3B_UR50D'),
}
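# Note: loading straight onto "cuda" assumes a GPU runtime (this Space uses
# @spaces.GPU below); on a CPU-only machine the .to("cuda") calls here and in
# run_example would have to be dropped.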

DESCRIPTION = "ESM-2 embedding"

@spaces.GPU
def run_example(protein, model_id='facebook/esm2_t36_3B_UR50D'):
    model_esm, alphabet = models[model_id]
    protein_name = 'protein_name'
    protein_seq = protein
    include = 'per_tok'
    repr_layers = [36]
    truncation_seq_length = 1024
    toks_per_batch = 4096

    dataset = FastaBatchedDataset([protein_name], [protein_seq])
    batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=alphabet.get_batch_converter(truncation_seq_length),
        batch_sampler=batches,
    )
    print("Read sequences")
    return_contacts = "contacts" in include

    # Map any negative layer indices onto the non-negative range [0, num_layers].
    assert all(-(model_esm.num_layers + 1) <= i <= model_esm.num_layers for i in repr_layers)
    repr_layers = [(i + model_esm.num_layers + 1) % (model_esm.num_layers + 1) for i in repr_layers]
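    # For example, with num_layers = 36 a request for layer -1 maps to
    # (-1 + 37) % 37 = 36 (the final layer), while 36 maps to itself.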

    with torch.no_grad():
        for batch_idx, (labels, strs, toks) in enumerate(data_loader):
            print(f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)")
            if torch.cuda.is_available():
                toks = toks.to(device="cuda", non_blocking=True)
            out = model_esm(toks, repr_layers=repr_layers, return_contacts=return_contacts)
            # Move the hidden states to CPU so they can be sliced and saved.
            representations = {
                layer: t.to(device="cpu") for layer, t in out["representations"].items()
            }
            if return_contacts:
                contacts = out["contacts"].to(device="cpu")
            for i, label in enumerate(labels):
                result = {"label": label}
                truncate_len = min(truncation_seq_length, len(strs[i]))
                if "per_tok" in include:
                    # Skip the BOS token at position 0 and anything beyond the
                    # true sequence length (EOS/padding).
                    result["representations"] = {
                        layer: t[i, 1: truncate_len + 1].clone()
                        for layer, t in representations.items()
                    }
                if "mean" in include:
                    result["mean_representations"] = {
                        layer: t[i, 1: truncate_len + 1].mean(0).clone()
                        for layer, t in representations.items()
                    }
                if "bos" in include:
                    result["bos_representations"] = {
                        layer: t[i, 0].clone() for layer, t in representations.items()
                    }
                if return_contacts:
                    result["contacts"] = contacts[i, :truncate_len, :truncate_len].clone()

    # A single sequence is submitted, so `result` holds the per-token
    # representations of that one protein from layer 36.
    esm_emb = result['representations'][36]
    print("esm embedding generated")
    # Zero-pad the (length, embed_dim) embedding to a fixed 1024 positions;
    # truncation_seq_length guarantees length <= 1024.
    esm_emb = F.pad(esm_emb.t(), (0, 1024 - len(esm_emb))).t()
    torch.save(esm_emb, 'example.pt')
    # Gradio 4 style: return an updated component rather than gr.File.update.
    return gr.File(value="example.pt", visible=True)
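
# A minimal usage sketch (assumptions: GPU available, model loaded above, and
# an illustrative sequence): run_example writes 'example.pt' holding a
# (1024, 2560) tensor, since esm2_t36_3B_UR50D has a 2560-dimensional hidden
# state and the output is padded to 1024 positions. For example:
#
#     run_example("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ")
#     emb = torch.load('example.pt')  # emb.shape == torch.Size([1024, 2560])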

css = """
#output {
    height: 500px;
    overflow: auto;
    border: 1px solid #ccc;
}
"""
with gr.Blocks(css=css) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Tab(label="ESM-2 embedding generation"):
        with gr.Row():
            with gr.Column():
                input_protein = gr.Textbox(type="text", label="Protein sequence")
                # The dropdown must default to a key that exists in `models`;
                # the previous default pointed at a Florence-2 checkpoint.
                model_selector = gr.Dropdown(
                    choices=list(models.keys()),
                    label="Model",
                    value='facebook/esm2_t36_3B_UR50D',
                )
            with gr.Column():
                button = gr.Button("Export")
                pt = gr.File(interactive=False, visible=False)

    button.click(run_example, [input_protein, model_selector], pt)

demo.launch(debug=True)