Last commit not found
import os | |
import torch | |
import torch.nn as nn | |
import pandas as pd | |
import torch.nn.functional as F | |
from lavis.models.protein_models.protein_function_opt import Blip2ProteinMistral | |
from lavis.models.base_model import FAPMConfig | |
import spaces | |
import gradio as gr | |
# from esm_scripts.extract import run_demo | |
from esm import pretrained, FastaBatchedDataset | |
from data.evaluate_data.utils import Ontology | |
import difflib | |
import re | |
from transformers import MistralForCausalLM | |
# Load the trained model | |
def get_model(type='Molecule Function'):
    """Build a FAPM model and load the weights for one GO branch.

    Args:
        type: Name of the ontology branch. Accepted values are
            'Molecule Function', 'Biological Process' and
            'Cellular Component'. The historical misspelling
            'Cellar Component' is still accepted so existing callers keep
            working. (The parameter name shadows the builtin ``type``; it is
            kept for backward compatibility with keyword callers.)

    Returns:
        A ``Blip2ProteinMistral`` with the branch checkpoint loaded, its
        Q-Former BERT replaced by the branch-specific pickle, moved to CUDA.

    Raises:
        ValueError: If ``type`` is not one of the accepted branch names.
            Previously an unknown name silently returned a model with no
            checkpoint loaded that was never moved to the GPU.
    """
    checkpoint_tags = {
        'Molecule Function': 'mf2',
        'Biological Process': 'bp1',
        'Cellular Component': 'cc2',
        'Cellar Component': 'cc2',  # legacy spelling used by older callers
    }
    # Validate before building the (large) model so bad names fail fast.
    if type not in checkpoint_tags:
        raise ValueError(
            f"Unknown model type {type!r}; expected one of {sorted(checkpoint_tags)}"
        )
    tag = checkpoint_tags[type]

    model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
    model.load_checkpoint(f"model/checkpoint_{tag}.pth")
    # NOTE(review): this unpickles a whole module object, not a state dict;
    # newer torch releases may require an explicit weights_only=False - confirm
    # against the pinned torch version.
    model.Qformer.bert = torch.load(f'model/{tag}_bert.pth', map_location=torch.device('cpu'))
    model.to('cuda')
    return model
# One FAPM model per GO branch, keyed by the branch label used throughout
# this file. The second element of each pair is the name handed to
# get_model(); 'Cellar Component' (sic) is the spelling it is called with here.
models = {
    label: get_model(loader_key)
    for label, loader_key in (
        ('Molecule Function', 'Molecule Function'),
        ('Biological Process', 'Biological Process'),
        ('Cellular Component', 'Cellar Component'),
    )
}
# Mistral language model, loaded in half precision and placed on the GPU.
mistral_model = MistralForCausalLM.from_pretrained(
    "teknium/OpenHermes-2.5-Mistral-7B",
    torch_dtype=torch.float16,
)
mistral_model.to('cuda')

# ESM2 model and its alphabet, placed on the GPU and switched to eval mode.
model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
model_esm.to('cuda')
model_esm.eval()
# Gene Ontology graph.
godb = Ontology('data/go1.4-basic.obo', with_rels=True)

# GO id <-> description tables, read from a '|'-separated two-column file.
go_des = pd.read_csv('data/go_descriptions1.4.txt', sep='|', header=None)
go_des.columns = ['id', 'text']
go_des = go_des.dropna()
# Ids use '_' in the file; replace it with ':'.
go_des['id'] = go_des['id'].apply(lambda go_id: re.sub('_', ':', go_id))
go_obo_set = set(go_des['id'].tolist())
go_des['text'] = go_des['text'].apply(lambda text: text.lower())
GO_dict = dict(zip(go_des['text'], go_des['id']))
Func_dict = dict(zip(go_des['id'], go_des['text']))


def _choices_from_terms(terms):
    """Map lower-cased description -> description for every GO id in terms['gos']."""
    descriptions = [Func_dict[go_id] for go_id in set(terms['gos'])]
    return {text.lower(): text for text in descriptions}


# Candidate descriptions per branch, used to match generated text.
terms_mf = pd.read_pickle('data/terms/mf_terms.pkl')
choices_mf = _choices_from_terms(terms_mf)
terms_bp = pd.read_pickle('data/terms/bp_terms.pkl')
choices_bp = _choices_from_terms(terms_bp)
terms_cc = pd.read_pickle('data/terms/cc_terms.pkl')
choices_cc = _choices_from_terms(terms_cc)

choices = {
    'Molecule Function': choices_mf,
    'Biological Process': choices_bp,
    'Cellular Component': choices_cc,
}
# Sequences are truncated to this many residues and embeddings are
# zero-padded up to it before being handed to the FAPM models.
_MAX_SEQ_LEN = 1024
# Index of the ESM2 layer whose per-token output is used as the embedding.
_ESM_REPR_LAYER = 36

# Result section per branch; '{}' receives the bullet list of matched terms.
_SECTION_TEMPLATES = {
    'Molecule Function': "Based on the given amino acid sequence, the protein appears to have a primary function of \n{} \n",
    'Biological Process': "It is likely involved in the following process: \n{} \n",
    'Cellular Component': "It's subcellular localization is within the: \n{}",
}


def _embed_sequence(protein_seq):
    """Return the per-token ESM2 embedding of one sequence as a CPU tensor."""
    dataset = FastaBatchedDataset(['protein_name'], [protein_seq])
    batches = dataset.get_batch_indices(4096, extra_toks_per_seq=1)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=alphabet.get_batch_converter(_MAX_SEQ_LEN),
        batch_sampler=batches,
    )
    embedding = None
    with torch.no_grad():
        for batch_idx, (_labels, strs, toks) in enumerate(data_loader):
            print(
                f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)"
            )
            if torch.cuda.is_available():
                toks = toks.to(device="cuda", non_blocking=True)
            out = model_esm(toks, repr_layers=[_ESM_REPR_LAYER], return_contacts=False)
            layer_output = out["representations"][_ESM_REPR_LAYER].to(device="cpu")
            # The dataset holds exactly one sequence, so it is row 0.
            truncate_len = min(_MAX_SEQ_LEN, len(strs[0]))
            # Skip the BOS token at position 0; clone so the result is not a
            # view into the larger batch tensor
            # (see https://github.com/pytorch/pytorch/issues/1995).
            embedding = layer_output[0, 1: truncate_len + 1].clone()
    if embedding is None:
        raise RuntimeError("ESM2 produced no embedding for the given sequence")
    return embedding


def _match_go_terms(prediction_text, branch_choices):
    """Turn one '; '-separated prediction string into de-duplicated 'term(prob)' strings.

    Only generated texts that closely match (difflib cutoff 0.9) one of
    ``branch_choices`` are kept; the matched choice replaces the generated text.
    """
    matched = []
    seen = set()
    for item in prediction_text.split('; '):
        if not item.strip():
            # An empty prediction string would otherwise make eval() raise.
            continue
        # SECURITY NOTE(review): eval() runs on model-generated text, which is
        # influenced by user input. If each item is a plain literal such as
        # "('text', 0.5)", ast.literal_eval would be a safe replacement -
        # confirm the output format of model.generate before switching.
        parsed = eval(item)
        txt, prob = parsed[0], parsed[1]
        close = difflib.get_close_matches(txt.lower(), branch_choices, n=1, cutoff=0.9)
        if close and close[0] not in seen:
            seen.add(close[0])
            matched.append(f'{close[0]}({prob})')
    return matched


def generate_caption(protein, prompt):
    """Predict GO terms for a protein sequence and format them as Markdown text.

    Args:
        protein: Amino-acid sequence as a string. Leading/trailing whitespace
            is ignored; sequences longer than 1024 residues are truncated.
        prompt: Optional taxonomy prompt. ``None`` is sent to the models as
            'none'; any other value is lower-cased and sent as is.

    Returns:
        A string with one bullet list per ontology branch that produced
        matches, or an explanatory message when the sequence is empty or no
        branch produced a match.
    """
    if protein is None or not protein.strip():
        return "Please provide a protein sequence."

    esm_emb = _embed_sequence(protein.strip())
    # Zero-pad along the sequence axis up to the fixed length, then move to GPU.
    esm_emb = F.pad(esm_emb.t(), (0, _MAX_SEQ_LEN - len(esm_emb))).t().to('cuda')

    prompt = 'none' if prompt is None else prompt.lower()
    samples = {'name': ['protein_name'],
               'image': torch.unsqueeze(esm_emb, dim=0),
               'text_input': ['none'],
               'prompt': [prompt]}

    # Matched terms per branch, in the iteration order of ``models``.
    terms_by_branch = {}
    for model_id, model in models.items():
        prediction = model.generate(mistral_model, samples, length_penalty=0., num_beams=15,
                                    num_captions=10, temperature=1., repetition_penalty=1.0)
        terms_by_branch[model_id] = _match_go_terms(prediction[0], choices[model_id])

    if not any(terms_by_branch.values()):
        # Suggest removing the prompt only when the user actually supplied one
        # (the original condition was inverted).
        if prompt.strip() not in ('', 'none'):
            return "No available predictions for this protein, you can use other two types of model, remove prompt or try another sequence!"
        return "No available predictions for this protein, you can use other two types of model or try another sequence!"

    res_str = ''
    for model_id, terms in terms_by_branch.items():
        if terms:
            bullets = ''.join(f'- {term}\n' for term in terms)
            res_str += _SECTION_TEMPLATES[model_id].format(bullets)
    return res_str
def save_feedback(inputs):
    """Append one line of user feedback to ``feedback.txt`` in the working directory.

    Args:
        inputs: Feedback text from the UI; ``None`` is stored as an empty line.

    Returns:
        The acknowledgement string shown to the user.
    """
    text = '' if inputs is None else str(inputs)
    print(text)
    # Explicit UTF-8 so non-ASCII feedback does not depend on the platform's
    # default encoding; plain append mode because the file is never read here.
    with open('feedback.txt', 'a', encoding='utf-8') as f:
        f.write(text + '\n')
    return "Thanks your advice!"
# In-memory log of (vote_id, "upvote" | "downvote") tuples.
feedback_data = []


def chatbot_respond(message, history=None):
    """Return a fixed reply and the chat history extended with this exchange.

    Args:
        message: The user's message.
        history: Previous (message, response) pairs; ``None`` means no history.
            The passed list is not modified.

    Returns:
        A ``(response, new_history)`` tuple; the response is always "yes".
    """
    # A None default avoids sharing one list object between calls.
    previous = [] if history is None else history
    response = "yes"
    return response, previous + [(message, response)]
# Functions to handle like/dislike | |
def upvote(vote_id):
    """Log an upvote for ``vote_id`` in ``feedback_data`` and return the confirmation text."""
    entry = (vote_id, "upvote")
    feedback_data.append(entry)
    print("Current feedback data: {}".format(feedback_data))
    return "You liked this prediction"
def downvote(vote_id):
    """Log a downvote for ``vote_id`` in ``feedback_data`` and return the confirmation text."""
    entry = (vote_id, "downvote")
    feedback_data.append(entry)
    print("Current feedback data: {}".format(feedback_data))
    return "You disliked this prediction"
# Define the FAPM interface | |
# Markdown shown at the top of the page; one source line per rendered line.
description = (
    "Quick demonstration of the FAPM model for protein function prediction. "
    "Upload an protein sequence to generate a function description. "
    "Modify the Prompt to provide the taxonomy information.\n"
    "Our paper is available at [BioRxiv](https://www.biorxiv.org/content/10.1101/2024.05.07.593067v1)\n"
    "The model used in this app is available at "
    "[Hugging Face Model Hub](https://huggingface.co/wenkai/FAPM) "
    "and the source code can be found on "
    "[GitHub](https://github.com/xiangwenkai/FAPM/tree/main).\n"
    "Thanks for the support from ProtonUnfold Tech. Co., Ltd (https://www.protonunfold.com/)."
)
# iface = gr.Interface( | |
# fn=generate_caption, | |
# inputs=[gr.Textbox(type="text", label="Upload sequence"), gr.Textbox(type="text", label="Prompt")], | |
# outputs=gr.Textbox(label="Generated description"), | |
# description=description | |
# ) | |
# # Launch the interface | |
# iface.launch() | |
# Stylesheet passed to gr.Blocks(css=...).
css = """
#output {
    height: 500px;
    overflow: auto;
    border: 1px solid #ccc;
}
/* Style for the upvote button */
.upvote-button {
    width: 500px; /* Set button width */
    height: 50px; /* Set button height */
    font-size: 20px; /* Set font size */
    background-color: #d4edda; /* Set background color */
    border-radius: 5px; /* Rounded corners */
}
/* Style for the downvote button */
.downvote-button {
    width: 50px; /* Set button width */
    height: 50px; /* Set button height */
    font-size: 20px; /* Set font size */
    background-color: #f8d7da; /* Set background color */
    border-radius: 5px; /* Rounded corners */
}
.feedback {
    width: 40px; /* Set button width */
    height: 40px; /* Set button height */
    font-size: 16px; /* Set font size */
    background-color: #f8d7da; /* Set background color */
}
"""
# Page layout: sequence/prompt inputs on the left; prediction, vote buttons
# and free-text feedback on the right; example sequences underneath.
# NOTE(review): the source lost its indentation, so the nesting of the
# Row/Column containers below is reconstructed from the statement order -
# confirm against the deployed layout.
with gr.Blocks(css=css) as demo:
    gr.Markdown(description)
    # vote_id = gr.State(0)
    with gr.Tab(label="Protein caption"):
        with gr.Row():
            with gr.Column():
                input_protein = gr.Textbox(type="text", label="Upload sequence")
                prompt = gr.Textbox(type="text", label="Taxonomy Prompt (Optional)")
                submit_btn = gr.Button(value="Submit")
            with gr.Column():
                # output_text = gr.Textbox(label="Output Text")
                with gr.Accordion('Prediction:', open=True):
                    output_markdown = gr.Markdown(label="Output")
                with gr.Row():
                    # The two labels had both degraded to the same mojibake
                    # character, making the buttons indistinguishable.
                    with gr.Column():
                        upvote_button = gr.Button("👍")
                    with gr.Column():
                        downvote_button = gr.Button("👎")
                    with gr.Column():
                        vote_markdown = gr.Markdown(label="Output")
                    with gr.Column():
                        vote_temp = gr.Markdown()
                with gr.Row():
                    inputs = gr.Textbox(type="text", label="Your feedback")
                    feedback_markdown = gr.Markdown(label="Output")
                with gr.Row():
                    with gr.Column():
                        feedback_btn = gr.Button(value="Feedback")
                    # feedback_temp1 = gr.Markdown(label="Output")
                    with gr.Column():
                        feedback_temp2 = gr.Markdown(label="Output")
        gr.Examples(
            examples=[
                ["MDYSYLNSYDSCVAAMEASAYGDFGACSQPGGFQYSPLRPAFPAAGPPCPALGSSNCALGALRDHQPAPYSAVPYKFFPEPSGLHEKRKQRRIRTTFTSAQLKELERVFAETHYPDIYTREELALKIDLTEARVQVWFQNRRAKFRKQERAASAKGAAGAAGAKKGEARCSSEDDDSKESTCSPTPDSTASLPPPPAPGLASPRLSPSPLPVALGSGPGPGPGPQPLKGALWAGVAGGGGGGPGAGAAELLKAWQPAESGPGPFSGVLSSFHRKPGPALKTNLF", ''],
                ["MKTLALFLVLVCVLGLVQSWEWPWNRKPTKFPIPSPNPRDKWCRLNLGPAWGGRC", ''],
                ["MAAAGGARLLRAASAVLGGPAGRWLHHAGSRAGSSGLLRNRGPGGSAEASRSLSVSARARSSSEDKITVHFINRDGETLTTKGKVGDSLLDVVVENNLDIDGFGACEGTLACSTCHLIFEDHIYEKLDAITDEENDMLDLAYGLTDRSRLGCQICLTKSMDNMTVRVPETVADARQSIDVGKTS", 'Homo'],
                ['MASAELSREENVYMAKLAEQAERYEEMVEFMEKVAKTVDSEELTVEERNLLSVAYKNVIGARRASWRIISSIEQKEEGRGNEDRVTLIKDYRGKIETELTKICDGILKLLETHLVPSSTAPESKVFYLKMKGDYYRYLAEFKTGAERKDAAENTMVAYKAAQDIALAELAPTHPIRLGLALNFSVFYYEILNSPDRACSLAKQAFDEAISELDTLSEESYKDSTLIMQLLRDNLTLWTSDISEDPAEEIREAPKRDSSEGQ', 'Zea'],
                ['MIKAAVTKESLYRMNTLMEAFQGFLGLDLGEFTFKVKPGVFLLTDVKSYLIGDKYDDAFNALIDFVLRNDRDAVEGTETDVSIRLGLSPSDMVVKRQDKTFTFTHGDLEFEVHWINL', 'Bacteriophage'],
                ['MNDLMIQLLDQFEMGLRERAIKVMATINDEKHRFPMELNKKQCSLMLLGTTDTTTFDMRFNSKKDFPRIKGAREKYPRDAVIEWYHQNWMRTEVKQ', 'Bacteriophage'],
            ],
            inputs=[input_protein, prompt],
            outputs=[output_markdown],
            fn=generate_caption,
            cache_examples=True,
            label='Try examples',
        )
    # Event wiring. The vote handlers receive the current sequence text as
    # their vote_id.
    submit_btn.click(generate_caption, [input_protein, prompt], [output_markdown])
    upvote_button.click(upvote, input_protein, vote_markdown)
    downvote_button.click(downvote, input_protein, vote_markdown)
    feedback_btn.click(save_feedback, [inputs], [feedback_markdown])

demo.launch(debug=True)