# Uploaded by SagiPolaczek ("Upload 3 files"), commit ff01709 (verified).
import gradio as gr
import torch
from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
from mammal.model import Mammal
from mammal.keys import *
# Checkpoint identifier (presumably a Hugging Face Hub repo id) used to load
# both the model and its tokenizer.
model_path = "ibm/biomed.omics.bl.sm.ma-ted-458m"
# Load Model once at import time; run_prompt reuses this global instance.
model = Mammal.from_pretrained(model_path)
# Put the model in evaluation mode before serving predictions.
model.eval()
# Load Tokenizer from the same checkpoint as the model.
tokenizer_op = ModularTokenizerOp.from_pretrained(model_path)
# Token id of "<1>", the binding-affinity class for "interacting" (per the
# note rendered at the bottom of the UI); run_prompt reads this token's score
# as the PPI score.
positive_token_id = tokenizer_op.get_token_id("<1>")
# Default input proteins: example sequences (named calmodulin / calcineurin)
# pre-filled in the UI's two sequence textboxes.
protein_calmodulin = "MADQLTEEQIAEFKEAFSLFDKDGDGTITTKELGTVMRSLGQNPTEAELQDMISELDQDGFIDKEDLHDGDGKISFEEFLNLVNKEMTADVDGDGQVNYEEFVTMMTSK"
protein_calcineurin = "MSSKLLLAGLDIERVLAEKNFYKEWDTWIIEAMNVGDEEVDRIKEFKEDEIFEEAKTLGTAEMQEYKKQKLEEAIEGAFDIFDKDGNGYISAAELRHVMTNLGEKLTDEEVDEMIRQMWDQNGDWDRIKELKFGEIKKLSAKDTRGTIFIKVFENLGTGVDSEYEDVSKYMLKHQ"
def format_prompt(prot1, prot2):
    """Build the MAMMAL binding-affinity-class prompt for a pair of proteins.

    The layout matches the model's pre-training syntax: an AA tokenizer
    directive, the ``<BINDING_AFFINITY_CLASS>`` task token, a
    ``<SENTINEL_ID_0>`` slot for the predicted class, then each protein
    wrapped as a general-protein molecular entity, terminated by ``<EOS>``.

    Args:
        prot1: First protein amino-acid sequence.
        prot2: Second protein amino-acid sequence.

    Returns:
        The prompt string. All whitespace (spaces, tabs, newlines) is removed
        from each sequence first, so sequences pasted with line breaks or
        spacing, as is common when copying from FASTA files, do not leak
        whitespace into the tokenized sequence.
    """

    def _protein_entity(seq):
        # Whitespace is never part of an amino-acid sequence; drop all of it.
        clean_seq = "".join(str(seq).split())
        return (
            "<MOLECULAR_ENTITY><MOLECULAR_ENTITY_GENERAL_PROTEIN>"
            f"<SEQUENCE_NATURAL_START>{clean_seq}<SEQUENCE_NATURAL_END>"
        )

    return (
        "<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_CLASS><SENTINEL_ID_0>"
        f"{_protein_entity(prot1)}{_protein_entity(prot2)}<EOS>"
    )
def run_prompt(prompt):
    """Run a formatted MAMMAL prompt through the model and return its prediction.

    Args:
        prompt: A complete prompt string, e.g. as produced by ``format_prompt``.

    Returns:
        A ``(generated_output, score)`` tuple: the decoded generated tokens
        (``batch_dict[CLS_PRED][0]``) and the scalar score the model assigns to
        the ``<1>`` (interacting) token, read from
        ``batch_dict["model.out.scores"]``.
    """
    # Tokenize: the op returns the sample dict with token ids and the attention
    # mask stored under the requested output keys.
    sample_dict = tokenizer_op(
        sample_dict={ENCODER_INPUTS_STR: prompt},
        key_in=ENCODER_INPUTS_STR,
        key_out_tokens_ids=ENCODER_INPUTS_TOKENS,
        key_out_attention_mask=ENCODER_INPUTS_ATTENTION_MASK,
    )
    # Convert the tokenizer outputs to tensors before generation.
    for key in (ENCODER_INPUTS_TOKENS, ENCODER_INPUTS_ATTENTION_MASK):
        sample_dict[key] = torch.tensor(sample_dict[key])

    # Pure inference: disable autograd so no computation graph is kept,
    # regardless of whether Mammal.generate already does so internally.
    with torch.no_grad():
        batch_dict = model.generate(
            [sample_dict],
            output_scores=True,
            return_dict_in_generate=True,
            max_new_tokens=5,
        )

    generated_output = tokenizer_op._tokenizer.decode(batch_dict[CLS_PRED][0])
    # NOTE(review): [0][1] presumably selects batch item 0 and generated
    # position 1 (the class token following <SENTINEL_ID_0>); confirm against
    # Mammal's "model.out.scores" layout.
    score = batch_dict["model.out.scores"][0][1][positive_token_id].item()
    return generated_output, score
def create_and_run_prompt(prot1, prot2):
    """Format a PPI prompt for two protein sequences and run it through the model.

    Args:
        prot1: First protein amino-acid sequence.
        prot2: Second protein amino-acid sequence.

    Returns:
        ``(prompt, generated_output, score)``, in the order of the three
        Gradio output components wired to the run button.

    Raises:
        gr.Error: If either sequence is missing, empty, or whitespace-only.
            Gradio reports this message in the UI instead of running the model
            on an empty sequence.
    """
    for label, seq in (("Protein 1", prot1), ("Protein 2", prot2)):
        if seq is None or not str(seq).strip():
            raise gr.Error(
                f"{label} sequence is empty; please enter an amino-acid sequence."
            )
    prompt = format_prompt(prot1, prot2)
    # Prepend the prompt to run_prompt's (generated_output, score) pair.
    return (prompt, *run_prompt(prompt=prompt))
def create_application():
    """Assemble the Gradio Blocks UI for the MAMMAL protein-protein interaction demo.

    Layout, top to bottom: a markdown header naming the model, two editable
    sequence inputs pre-filled with the example proteins, a run button, the
    generated prompt, the decoded model output next to the PPI score, and a
    footnote explaining the class tokens.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    header_md = f"""
# Mammal based Protein-Protein Interaction (PPI) demonstration
Given two protein sequences, estimate if the proteins interact or not.
### Using the model from
```{model_path} ```
"""
    footer_md = "```<SENTINEL_ID_0>``` contains the binding affinity class, which is ```<1>``` for interacting and ```<0>``` for non-interacting"

    def sequence_input(label, default_seq):
        # One editable, single-line textbox holding a protein sequence.
        return gr.Textbox(label=label, interactive=True, lines=1, value=default_seq)

    with gr.Blocks() as demo:
        gr.Markdown(header_md)
        with gr.Row():
            seq_a = sequence_input("Protein 1 sequence", protein_calmodulin)
            seq_b = sequence_input("Protein 2 sequence", protein_calcineurin)
        with gr.Row():
            run_button = gr.Button(
                "Run Mammal prompt for Protein-Protein Interaction", variant="primary"
            )
        with gr.Row():
            prompt_out = gr.Textbox(label="Mammal prompt", lines=5)
        with gr.Row():
            decoded_out = gr.Textbox(label="Mammal output")
            score_out = gr.Number(label="PPI score")
        # Outputs map positionally onto create_and_run_prompt's 3-tuple.
        run_button.click(
            fn=create_and_run_prompt,
            inputs=[seq_a, seq_b],
            outputs=[prompt_out, decoded_out, score_out],
        )
        with gr.Row():
            gr.Markdown(footer_md)
    return demo
def main():
    """Build the demo UI and serve it.

    ``show_error=True`` surfaces Python exceptions in the browser;
    ``share=True`` asks Gradio for a public share link.
    """
    app = create_application()
    app.launch(share=True, show_error=True)
# Launch the demo only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()