File size: 10,558 Bytes
2b26389 72b0e49 2b26389 892748f 9b993cf f3ed046 1a0324b 2b26389 892748f e95deab 892748f 08b9eb6 3daa625 f3ed046 1a0324b 72b0e49 ca55ed1 3daa625 c8e59d5 3daa625 3705c34 77b966b 3705c34 77b966b 3705c34 cdf31f1 61cedea 2b26389 2bc812b 2b26389 892748f 2b26389 f3ed046 1167137 f3ed046 1167137 43b95d1 1167137 e95deab 1167137 2b26389 c3846ee 892748f c3846ee 892748f c3846ee 892748f c3846ee ca55ed1 c3846ee ca55ed1 c3846ee ca55ed1 c3846ee 2b26389 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 |
import os
import torch
import torch.nn as nn
import pandas as pd
import torch.nn.functional as F
from lavis.models.protein_models.protein_function_opt import Blip2ProteinMistral
from lavis.models.base_model import FAPMConfig
import spaces
import gradio as gr
# from esm_scripts.extract import run_demo
from esm import pretrained, FastaBatchedDataset
from data.evaluate_data.utils import Ontology
import difflib
import re
# Model construction: one fine-tuned FAPM checkpoint per GO aspect.
def get_model(type='Molecule Function'):
    """Build a FAPM model and load the fine-tuned checkpoint for ``type``.

    Args:
        type: one of 'Molecule Function', 'Biological Process' or
            'Cellar Component'. These strings are also the UI labels / keys
            used elsewhere, so the spelling (including 'Cellar') must not
            change. The parameter name shadows the builtin but is kept for
            keyword-argument compatibility.

    Returns:
        A ``Blip2ProteinMistral`` built with ``esm_size='3b'``, with the
        matching checkpoint loaded and moved to CUDA.

    Raises:
        ValueError: if ``type`` is not one of the supported tasks.
    """
    checkpoints = {
        'Molecule Function': "model/checkpoint_mf2.pth",
        'Biological Process': "model/checkpoint_bp1.pth",
        'Cellar Component': "model/checkpoint_cc2.pth",
    }
    # Validate before constructing the (large) model so a bad label fails
    # fast instead of silently returning an un-loaded model.
    try:
        checkpoint = checkpoints[type]
    except KeyError:
        raise ValueError(
            f"Unknown model type {type!r}; expected one of {sorted(checkpoints)}"
        ) from None
    model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
    model.load_checkpoint(checkpoint)
    model.to('cuda')
    return model
# One fine-tuned FAPM model per GO aspect, keyed by the label shown in the UI.
models = {
    task: get_model(task)
    for task in ('Molecule Function', 'Biological Process', 'Cellar Component')
}

# ESM-2 protein language model used to embed the input sequence.
model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
model_esm.to('cuda')
model_esm.eval()

godb = Ontology('data/go1.4-basic.obo', with_rels=True)

# GO term descriptions: a '|'-separated file of (id, text) rows, no header.
go_des = pd.read_csv('data/go_descriptions1.4.txt', sep='|', header=None)
go_des.columns = ['id', 'text']
go_des = go_des.dropna()
# Replace '_' with ':' in the ids (presumably 'GO_xxxxxxx' -> 'GO:xxxxxxx').
go_des['id'] = go_des['id'].apply(lambda raw_id: re.sub('_', ':', raw_id))
go_obo_set = set(go_des['id'].tolist())
go_des['text'] = go_des['text'].apply(lambda text: text.lower())

# Lookups in both directions between lower-cased description text and GO id
# (later rows win on duplicate keys).
GO_dict = {text: go_id for text, go_id in zip(go_des['text'], go_des['id'])}
Func_dict = {go_id: text for go_id, text in zip(go_des['id'], go_des['text'])}

# Vocabulary that generated captions are fuzzy-matched against.
# NOTE(review): only the Molecule Function term list is loaded here, yet
# generate_caption matches the output of every model against ``choices`` —
# confirm the BP/CC models are meant to share this vocabulary.
terms_mf = pd.read_pickle('data/terms/mf_terms.pkl')
choices_mf = [Func_dict[go_id] for go_id in set(terms_mf['gos'])]
choices = {label.lower(): label for label in choices_mf}
def _embed_sequence(protein_seq, repr_layer=36, truncation_seq_length=1024):
    """Return the per-residue ESM-2 hidden states for one protein sequence.

    The sequence is tokenised with the ESM ``alphabet`` (truncated to
    ``truncation_seq_length`` residues) and run through ``model_esm``. The
    hidden states of ``repr_layer`` (a non-negative layer index) are returned
    on the CPU for the residue positions only, i.e. the leading BOS token is
    dropped.
    """
    batch_converter = alphabet.get_batch_converter(truncation_seq_length)
    # The converter takes (label, sequence) pairs — the same items a
    # FastaBatchedDataset yields — so a single sequence needs no DataLoader.
    _, strs, toks = batch_converter([('protein_name', protein_seq)])
    if torch.cuda.is_available():
        toks = toks.to(device="cuda", non_blocking=True)
    with torch.no_grad():
        out = model_esm(toks, repr_layers=[repr_layer], return_contacts=False)
    hidden = out["representations"][repr_layer].to(device="cpu")
    truncate_len = min(truncation_seq_length, len(strs[0]))
    # clone() so the result is not a view into the larger batch tensor
    # (see https://github.com/pytorch/pytorch/issues/1995).
    return hidden[0, 1:truncate_len + 1].clone()


def _parse_predictions(raw):
    """Parse the model's ``'; '``-separated ``(text, prob)`` literals.

    Each piece is parsed with ``ast.literal_eval``, which accepts only Python
    literals. The text is model output conditioned on the user-supplied
    prompt, so it must never be passed to ``eval``. Pieces that do not parse,
    or that are not a tuple/list of at least two items, are skipped.

    Returns:
        A list of ``(text, prob)`` pairs in their original order.
    """
    import ast

    pairs = []
    for piece in raw.split('; '):
        try:
            item = ast.literal_eval(piece)
        except (ValueError, SyntaxError, TypeError, RecursionError):
            continue
        if isinstance(item, (tuple, list)) and len(item) >= 2:
            pairs.append((str(item[0]), item[1]))
    return pairs


@spaces.GPU
def generate_caption(model_id, protein, prompt):
    """Predict GO terms for ``protein`` with the FAPM model ``model_id``.

    Args:
        model_id: key of ``models`` ('Molecule Function',
            'Biological Process' or 'Cellar Component').
        protein: amino-acid sequence as entered by the user. All whitespace
            (e.g. line breaks from a pasted sequence) is removed and letters
            are upper-cased before embedding.
        prompt: optional taxonomy prompt. ``None`` becomes ``'none'``;
            anything else is lower-cased.

    Returns:
        A sentence listing the matched terms with their probabilities, or an
        explanatory message when there is no sequence or nothing matched.
    """
    max_len = 1024  # ESM truncation length == fixed length FAPM is fed
    protein_seq = ''.join((protein or '').split()).upper()
    if not protein_seq:
        return "Please enter a protein sequence."

    esm_emb = _embed_sequence(protein_seq, truncation_seq_length=max_len)
    print("esm embedding generated")
    # Zero-pad the residue axis up to max_len positions.
    esm_emb = F.pad(esm_emb.t(), (0, max_len - len(esm_emb))).t().to('cuda')

    prompt = 'none' if prompt is None else prompt.lower()
    samples = {'name': ['protein_name'],
               'image': torch.unsqueeze(esm_emb, dim=0),
               'text_input': ['none'],
               'prompt': [prompt]}
    model = models[model_id]
    prediction = model.generate(samples, length_penalty=0., num_beams=15,
                                num_captions=10, temperature=1.,
                                repetition_penalty=1.0)

    # Snap each generated description onto a known term (>= 0.9 similarity),
    # keeping only the first occurrence of each term.
    # NOTE(review): ``choices`` is built from the Molecule Function term list
    # only, but it is used here for every model — confirm this is intended.
    pred_terms = []
    seen = set()
    for txt, prob in _parse_predictions(prediction[0]):
        sim_list = difflib.get_close_matches(txt.lower(), choices, n=1, cutoff=0.9)
        if not sim_list or sim_list[0] in seen:
            continue
        seen.add(sim_list[0])
        pred_terms.append(f'{sim_list[0]}({prob})')

    no_result = ("No available predictions for this protein! "
                 "You can try the other two types of model or remove prompt.")
    templates = {
        'Molecule Function': "Based on the given amino acid sequence, the protein appears to have a primary function of {}",
        'Biological Process': "Based on the given amino acid sequence, it is likely involved in the {}",
        'Cellar Component': "Based on the given amino acid sequence, its subcellular localization is within the {}",
    }
    if not pred_terms or model_id not in templates:
        return no_result
    return templates[model_id].format(', '.join(pred_terms))
# Gradio front end for the FAPM demo.
description = """Quick demonstration of the FAPM model for protein function prediction. Upload a protein sequence to generate a function description. Modify the Prompt to provide the taxonomy information.
The model used in this app is available at [Hugging Face Model Hub](https://huggingface.co/wenkai/FAPM) and the source code can be found on [GitHub](https://github.com/xiangwenkai/FAPM/tree/main)."""

# NOTE(review): no component below sets elem_id="output", so this rule does
# not currently match anything — confirm whether output_text should carry it.
css = """
#output {
    height: 500px;
    overflow: auto;
    border: 1px solid #ccc;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown(description)
    with gr.Tab(label="Protein caption"):
        with gr.Row():
            with gr.Column():
                model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value='Molecule Function')
                input_protein = gr.Textbox(type="text", label="Upload sequence")
                prompt = gr.Textbox(type="text", label="Taxonomy Prompt (Optional)")
                submit_btn = gr.Button(value="Submit")
            with gr.Column():
                output_text = gr.Textbox(label="Output Text")
        # O14813 train index 127, 266, 738, 1060 test index 4
        # cache_examples=True runs generate_caption on every example at startup.
        gr.Examples(
            examples=[
                ["Molecule Function", "MDYSYLNSYDSCVAAMEASAYGDFGACSQPGGFQYSPLRPAFPAAGPPCPALGSSNCALGALRDHQPAPYSAVPYKFFPEPSGLHEKRKQRRIRTTFTSAQLKELERVFAETHYPDIYTREELALKIDLTEARVQVWFQNRRAKFRKQERAASAKGAAGAAGAKKGEARCSSEDDDSKESTCSPTPDSTASLPPPPAPGLASPRLSPSPLPVALGSGPGPGPGPQPLKGALWAGVAGGGGGGPGAGAAELLKAWQPAESGPGPFSGVLSSFHRKPGPALKTNLF", ''],
                ["Molecule Function", "MKTLALFLVLVCVLGLVQSWEWPWNRKPTKFPIPSPNPRDKWCRLNLGPAWGGRC", ''],
                ["Molecule Function", "MAAAGGARLLRAASAVLGGPAGRWLHHAGSRAGSSGLLRNRGPGGSAEASRSLSVSARARSSSEDKITVHFINRDGETLTTKGKVGDSLLDVVVENNLDIDGFGACEGTLACSTCHLIFEDHIYEKLDAITDEENDMLDLAYGLTDRSRLGCQICLTKSMDNMTVRVPETVADARQSIDVGKTS", 'Homo'],
                ["Molecule Function", 'MASAELSREENVYMAKLAEQAERYEEMVEFMEKVAKTVDSEELTVEERNLLSVAYKNVIGARRASWRIISSIEQKEEGRGNEDRVTLIKDYRGKIETELTKICDGILKLLETHLVPSSTAPESKVFYLKMKGDYYRYLAEFKTGAERKDAAENTMVAYKAAQDIALAELAPTHPIRLGLALNFSVFYYEILNSPDRACSLAKQAFDEAISELDTLSEESYKDSTLIMQLLRDNLTLWTSDISEDPAEEIREAPKRDSSEGQ', 'Zea'],
                ["Molecule Function", 'MIKAAVTKESLYRMNTLMEAFQGFLGLDLGEFTFKVKPGVFLLTDVKSYLIGDKYDDAFNALIDFVLRNDRDAVEGTETDVSIRLGLSPSDMVVKRQDKTFTFTHGDLEFEVHWINL', 'Bacteriophage'],
                ["Molecule Function", 'MNDLMIQLLDQFEMGLRERAIKVMATINDEKHRFPMELNKKQCSLMLLGTTDTTTFDMRFNSKKDFPRIKGAREKYPRDAVIEWYHQNWMRTEVKQ', 'Bacteriophage'],
            ],
            inputs=[model_selector, input_protein, prompt],
            outputs=[output_text],
            fn=generate_caption,
            cache_examples=True,
            label='Try examples'
        )
    submit_btn.click(generate_caption, [model_selector, input_protein, prompt], [output_text])
demo.launch(debug=True)
|