|
import gradio as gr |
|
|
|
import torch |
|
from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp |
|
from mammal.model import Mammal |
|
from mammal.keys import * |
|
|
|
|
|
# Identifier handed to ``from_pretrained`` for both the model and the
# tokenizer below; it is also interpolated into the UI header text.
model_path = "ibm/biomed.omics.bl.sm.ma-ted-458m"

# Load the model once at import time and call ``eval()`` on it
# (presumably switching a torch module to inference mode -- confirm that
# ``Mammal`` is a ``torch.nn.Module``).
model = Mammal.from_pretrained(model_path)
model.eval()

# The tokenizer is loaded from the same identifier as the model.
tokenizer_op = ModularTokenizerOp.from_pretrained(model_path)

# Token id of the "<1>" token; ``run_prompt`` uses it to index the
# generation scores.
positive_token_id = tokenizer_op.get_token_id("<1>")
|
|
|
|
|
# Default example sequences used to pre-fill the two input textboxes.
# The literals are split with implicit string concatenation purely for
# readability; the resulting values are unchanged.
protein_calmodulin = (
    "MADQLTEEQIAEFKEAFSLFDKDGDGTITTKELGTVMRSL"
    "GQNPTEAELQDMISELDQDGFIDKEDLHDGDGKISFEEFL"
    "NLVNKEMTADVDGDGQVNYEEFVTMMTSK"
)

protein_calcineurin = (
    "MSSKLLLAGLDIERVLAEKNFYKEWDTWIIEAMNVGDEEV"
    "DRIKEFKEDEIFEEAKTLGTAEMQEYKKQKLEEAIEGAFD"
    "IFDKDGNGYISAAELRHVMTNLGEKLTDEEVDEMIRQMWD"
    "QNGDWDRIKELKFGEIKKLSAKDTRGTIFIKVFENLGTGV"
    "DSEYEDVSKYMLKHQ"
)
|
|
|
|
|
def format_prompt(prot1, prot2):
    """Return the model prompt string for a pair of protein sequences.

    The prompt consists of a fixed task header, one tagged entity segment
    per protein (``prot1`` first, then ``prot2``), and a closing ``<EOS>``.

    Args:
        prot1: First protein sequence, inserted verbatim.
        prot2: Second protein sequence, inserted verbatim.

    Returns:
        The assembled prompt as a single string.
    """
    task_header = "<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_CLASS><SENTINEL_ID_0>"
    entity_open = (
        "<MOLECULAR_ENTITY><MOLECULAR_ENTITY_GENERAL_PROTEIN>"
        "<SEQUENCE_NATURAL_START>"
    )
    entity_close = "<SEQUENCE_NATURAL_END>"

    # Both proteins are wrapped in the same opening/closing tag sequence.
    first_entity = f"{entity_open}{prot1}{entity_close}"
    second_entity = f"{entity_open}{prot2}{entity_close}"

    return task_header + first_entity + second_entity + "<EOS>"
|
|
|
|
|
def run_prompt(prompt):
    """Tokenize ``prompt``, run generation, and return the output and score.

    Uses the module-level ``tokenizer_op``, ``model`` and
    ``positive_token_id``.

    Args:
        prompt: Prompt string to feed to the model (e.g. the result of
            ``format_prompt``).

    Returns:
        A ``(generated_output, score)`` tuple: the decoded prediction for
        the single sample in the batch, and the value found at
        ``positive_token_id`` in the generation scores, converted with
        ``.item()``.
    """
    # Tokenize the prompt; the tokenizer op is asked to store token ids and
    # the attention mask under the two ENCODER_INPUTS_* keys.
    encoded = tokenizer_op(
        sample_dict={ENCODER_INPUTS_STR: prompt},
        key_in=ENCODER_INPUTS_STR,
        key_out_tokens_ids=ENCODER_INPUTS_TOKENS,
        key_out_attention_mask=ENCODER_INPUTS_ATTENTION_MASK,
    )

    # Convert both tokenizer outputs to tensors before generation.
    for tensor_key in (ENCODER_INPUTS_TOKENS, ENCODER_INPUTS_ATTENTION_MASK):
        encoded[tensor_key] = torch.tensor(encoded[tensor_key])

    # Generate for a batch containing just this one sample.
    generation = model.generate(
        [encoded],
        output_scores=True,
        return_dict_in_generate=True,
        max_new_tokens=5,
    )

    # Index 0 selects the only sample in the batch.
    generated_output = tokenizer_op._tokenizer.decode(generation[CLS_PRED][0])

    # NOTE(review): index 1 presumably selects the generated position that
    # carries the class token -- confirm against the model's output layout.
    sample_scores = generation["model.out.scores"][0]
    score = sample_scores[1][positive_token_id].item()

    return generated_output, score
|
|
|
|
|
def create_and_run_prompt(prot1, prot2):
    """Build the prompt for two proteins, run it, and return all results.

    Args:
        prot1: First protein sequence.
        prot2: Second protein sequence.

    Returns:
        A ``(prompt, generated_output, score)`` tuple, in the order the UI
        click handler maps onto its output components.
    """
    prompt = format_prompt(prot1, prot2)
    generated_output, score = run_prompt(prompt=prompt)
    return prompt, generated_output, score
|
|
|
|
|
def create_application():
    """Build the Gradio Blocks UI for the PPI demo.

    The layout has a header, two pre-filled sequence inputs, a run button,
    a box showing the generated prompt, the model output with its score,
    and a short legend explaining the output classes.

    Returns:
        The constructed ``gr.Blocks`` object; it is not launched here.
    """
    markup_text = f"""
# Mammal based Protein-Protein Interaction (PPI) demonstration

Given two protein sequences, estimate if the proteins interact or not.

### Using the model from

```{model_path} ```
"""

    def _sequence_box(label, default_sequence):
        # Single-line, editable textbox pre-filled with an example sequence.
        return gr.Textbox(
            label=label,
            interactive=True,
            lines=1,
            value=default_sequence,
        )

    with gr.Blocks() as demo:
        gr.Markdown(markup_text)

        with gr.Row():
            prot1 = _sequence_box("Protein 1 sequence", protein_calmodulin)
            prot2 = _sequence_box("Protein 2 sequence", protein_calcineurin)

        with gr.Row():
            run_mammal = gr.Button(
                "Run Mammal prompt for Protein-Protein Interaction", variant="primary"
            )

        with gr.Row():
            prompt_box = gr.Textbox(label="Mammal prompt", lines=5)

        with gr.Row():
            decoded = gr.Textbox(label="Mammal output")
            # NOTE(review): the source's indentation was lost; the score box
            # is assumed to share this row with the decoded output -- confirm
            # against the intended layout.
            score_box = gr.Number(label="PPI score")
            # The handler returns (prompt, output, score), matching the
            # order of the output components.
            run_mammal.click(
                fn=create_and_run_prompt,
                inputs=[prot1, prot2],
                outputs=[prompt_box, decoded, score_box],
            )

        with gr.Row():
            gr.Markdown(
                "```<SENTINEL_ID_0>``` contains the binding affinity class, which is ```<1>``` for interacting and ```<0>``` for non-interacting"
            )

    return demo
|
|
|
|
|
def main():
    """Build the demo application and launch it."""
    # ``show_error`` and ``share`` are passed straight to Gradio's launch().
    create_application().launch(show_error=True, share=True)


# Start the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|