import io

import gradio as gr
from huggingface_hub import hf_hub_download
from prediction import run_image_prediction
import torch
import torchvision.transforms as T
from celle.utils import process_image
from PIL import Image
from matplotlib import pyplot as plt
from matplotlib.figure import Figure

# Green fluorescent protein (GFP); used as the default sequence and in the examples.
GFP_SEQUENCE = 'MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK'


def _download_model_files(model_name):
    """Download the files for ``HuangLab/<model_name>`` from the Hugging Face Hub.

    Returns:
        Tuple ``(checkpoint_path, config_path)`` of local paths to
        ``model.ckpt`` and ``config.yaml``.
    """
    repo_id = f"HuangLab/{model_name}"
    checkpoint_path = hf_hub_download(repo_id=repo_id, filename="model.ckpt")
    config_path = hf_hub_download(repo_id=repo_id, filename="config.yaml")
    # The returned paths are discarded; presumably config.yaml refers to these
    # files and they only need to be present in the local cache (TODO confirm).
    for extra_file in ("nucleus_vqgan.yaml", "threshold_vqgan.yaml"):
        hf_hub_download(repo_id=repo_id, filename=extra_file)
    return checkpoint_path, config_path


def _threshold_protein_image(protein_image, dataset):
    """Binarize the optional protein image at its median, for display only.

    Returns:
        A 2-D float tensor of 0.0/1.0 values, or an all-ones 256x256 tensor
        when no protein image was supplied.
    """
    if protein_image is None:
        return torch.ones((256, 256))
    processed = process_image(protein_image, dataset, 'protein')
    mask = processed > torch.median(processed)
    # Take the first image/channel (the [0, 0] indexing assumes process_image
    # returns a 4-D tensor -- TODO confirm) and turn the bool mask into floats.
    return mask[0, 0] * 1.0


def _render_heatmap(heatmap):
    """Draw *heatmap* with the 'rainbow' colormap and return it as a PIL image.

    A standalone Figure (not pyplot's global current figure) is used and the
    PNG is written to an in-memory buffer. Concurrent requests therefore do not
    share a figure or overwrite a common temp file. Because the figure is never
    registered with pyplot, it is garbage-collected instead of accumulating.
    """
    fig = Figure()
    ax = fig.subplots()
    ax.imshow(heatmap.cpu(), cmap='rainbow', interpolation='bicubic')
    ax.axis('off')
    buffer = io.BytesIO()
    fig.savefig(buffer, format='png', bbox_inches='tight', dpi=256)
    buffer.seek(0)
    image = Image.open(buffer)
    # Image.open is lazy; decode now so the result is self-contained.
    image.load()
    return image


def gradio_demo(model_name, sequence_input, nucleus_image, protein_image):
    """Run CELL-E 2 on one sequence/nucleus pair and build the four UI outputs.

    Args:
        model_name: Model repo name under the ``HuangLab`` Hub organisation.
            Names containing 'Finetuned' are preprocessed as OpenCell data,
            all others as HPA.
        sequence_input: Amino-acid sequence passed to run_image_prediction.
        nucleus_image: Uploaded nucleus image (PIL); required.
        protein_image: Optional uploaded protein image (PIL); only displayed.

    Returns:
        Tuple of PIL images, in this order:
        - the processed nucleus image;
        - the median-thresholded protein image (all white when none was
          uploaded);
        - the predicted threshold image;
        - the rendered predicted heatmap.

    Raises:
        gr.Error: If the sequence is empty or no nucleus image was given, so
            the UI shows a readable message instead of a traceback.
    """
    if not sequence_input or not sequence_input.strip():
        raise gr.Error("Please enter an amino acid sequence.")
    if nucleus_image is None:
        raise gr.Error("Please upload a nucleus image.")

    model, config = _download_model_files(model_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = 'OpenCell' if 'Finetuned' in model_name else 'HPA'

    nucleus_image = process_image(nucleus_image, dataset, 'nucleus')
    protein_image = _threshold_protein_image(protein_image, dataset)

    threshold, heatmap = run_image_prediction(
        sequence_input=sequence_input,
        nucleus_image=nucleus_image,
        model_ckpt_path=model,
        model_config_path=config,
        device=device,
    )

    to_pil = T.ToPILImage()
    return (
        to_pil(nucleus_image[0, 0]),
        to_pil(protein_image),
        to_pil(threshold),
        _render_heatmap(heatmap),
    )


with gr.Blocks() as demo:
    gr.Markdown("Select the prediction model.")
    gr.Markdown("CELL-E_2_HPA_480 is a good general purpose model for various cell types using ICC-IF.")
    gr.Markdown("CELL-E_2_HPA_Finetuned_480 is finetuned on OpenCell and is better suited for live-cell predictions on HEK cells.")
    with gr.Row():
        model_name = gr.Dropdown(
            ['CELL-E_2_HPA_480', 'CELL-E_2_HPA_Finetuned_480'],
            value='CELL-E_2_HPA_480',
            label='Model Name',
        )
    with gr.Row():
        gr.Markdown("Input the desired amino acid sequence. GFP is shown below by default.")
    with gr.Row():
        sequence_input = gr.Textbox(value=GFP_SEQUENCE, label='Sequence')
    with gr.Row():
        gr.Markdown("Uploading a nucleus image is necessary. A random crop of 256 x 256 will be applied if larger.")
        gr.Markdown("The protein image is optional and is just used for display.")
    with gr.Row().style(equal_height=True):
        nucleus_image = gr.Image(
            value='images/Armadillo repeat-containing X-linked protein 5 nucleus.jpg',
            type='pil',
            label='Nucleus Image',
        )
        protein_image = gr.Image(type='pil', label='Protein Image (Optional)')
    with gr.Row():
        gr.Markdown("Image predictions are shown below.")
    with gr.Row().style(equal_height=True):
        nucleus_image_crop = gr.Image(type='pil', label='Nucleus Image')
        protein_threshold_image = gr.Image(type='pil', label='Protein Threshold Image')
        predicted_threshold_image = gr.Image(type='pil', label='Predicted Threshold image')
        predicted_heatmap = gr.Image(type='pil', label='Predicted Heatmap')
    with gr.Row():
        button = gr.Button("Run Model")

    inputs = [model_name, sequence_input, nucleus_image, protein_image]
    outputs = [nucleus_image_crop, protein_threshold_image, predicted_threshold_image, predicted_heatmap]
    button.click(gradio_demo, inputs, outputs)

    # NOTE(review): `examples` is not attached to any component. The earlier
    # gr.Interface-based wiring was commented out. Consider
    # gr.Examples(examples=examples, inputs=inputs) once the example image
    # files are confirmed to exist.
    examples = [
        [
            'CELL-E_2_HPA_Finetuned_480',
            GFP_SEQUENCE,
            'images/Proteasome activator complex subunit 3 nucleus.png',
            'images/Proteasome activator complex subunit 3 protein.png',
        ],
        [
            'CELL-E_2_HPA_480',
            GFP_SEQUENCE,
            'images/Armadillo repeat-containing X-linked protein 5 nucleus.jpg',
            'images/Armadillo repeat-containing X-linked protein 5 protein.jpg',
        ],
    ]

demo.launch(share=True)