# Page header captured with this file (not code): avatar "Emaad's picture",
# commit 72dc396 "Update app.py", links "raw" / "history" / "blame", size 4.22 kB.
import io

import gradio as gr
import torch
import torchvision.transforms as T
from matplotlib import pyplot as plt
from matplotlib.figure import Figure
from PIL import Image

from celle.utils import process_image
from prediction import run_image_prediction


# Model names gradio_demo will load. Each must have matching <name>.ckpt and
# <name>.yaml files under CELL-E_2-Image_Prediction/models. The value is sent by
# the client, so it is checked against this list before being used in a path.
_MODEL_NAMES = ("CELL-E_2-HPA_480", "CELL-E_2-HPA_Finetuned_480")


def _threshold_protein_image(protein_image, dataset):
    """Return a 0.0/1.0 float mask of *protein_image* thresholded at its median.

    When no protein image was supplied (``None``), an all-ones 256 x 256 tensor
    is returned instead so the display panel is still populated.
    """
    if protein_image is None:
        return torch.ones((256, 256))
    processed = process_image(protein_image, dataset, "protein")
    mask = processed > torch.median(processed)
    # NOTE(review): assumes process_image returns a (batch, channel, H, W)
    # tensor; [0, 0] keeps the first image's first channel -- confirm.
    return mask[0, 0] * 1.0


def _render_heatmap(heatmap):
    """Draw *heatmap* with a rainbow colormap and return it as a PIL image.

    Uses a standalone ``Figure`` and an in-memory PNG buffer rather than pyplot
    plus a shared ``temp.png``. Pyplot's implicit current figure was never
    cleared, so every call stacked another image onto the same axes. Concurrent
    requests would also overwrite each other's temp file.
    """
    figure = Figure()
    axes = figure.subplots()
    axes.imshow(heatmap.cpu(), cmap="rainbow", interpolation="bicubic")
    axes.axis("off")
    buffer = io.BytesIO()
    figure.savefig(buffer, format="png", bbox_inches="tight", dpi=256)
    buffer.seek(0)
    image = Image.open(buffer)
    image.load()  # Image.open is lazy; decode now while the buffer is in hand
    return image


def gradio_demo(model_name, sequence_input, nucleus_image, protein_image):
    """Run CELL-E 2 image prediction for one sequence and one nucleus image.

    Args:
        model_name: One of ``_MODEL_NAMES``. Selects ``<name>.ckpt`` and
            ``<name>.yaml`` under ``CELL-E_2-Image_Prediction/models``. Names
            containing "Finetuned" use the "OpenCell" preprocessing; all others
            use "HPA".
        sequence_input: Amino-acid sequence passed to ``run_image_prediction``.
        nucleus_image: Required nucleus image (a PIL image from the UI).
        protein_image: Optional protein image (PIL image or ``None``). It is
            only thresholded for display and is not passed to the model.

    Returns:
        Tuple of four PIL images: processed nucleus, thresholded protein,
        predicted threshold, and predicted heatmap.

    Raises:
        gr.Error: If the model name is unknown or no nucleus image was given.
    """
    if model_name not in _MODEL_NAMES:
        raise gr.Error(f"Unknown model: {model_name!r}")
    if nucleus_image is None:
        raise gr.Error("Please upload a nucleus image.")

    model = f"CELL-E_2-Image_Prediction/models/{model_name}.ckpt"
    config = f"CELL-E_2-Image_Prediction/models/{model_name}.yaml"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = "OpenCell" if "Finetuned" in model_name else "HPA"

    nucleus_image = process_image(nucleus_image, dataset, "nucleus")
    protein_image = _threshold_protein_image(protein_image, dataset)

    threshold, heatmap = run_image_prediction(
        sequence_input=sequence_input,
        nucleus_image=nucleus_image,
        model_ckpt_path=model,
        model_config_path=config,
        device=device,
    )

    return (
        # NOTE(review): assumes a (batch, channel, H, W) tensor here as well.
        T.ToPILImage()(nucleus_image[0, 0]),
        T.ToPILImage()(protein_image),
        T.ToPILImage()(threshold),
        _render_heatmap(heatmap),
    )

# Gradio UI: choose a model, enter a sequence, upload a nucleus image (plus an
# optional protein image), then click "Run Model" to call gradio_demo.
with gr.Blocks(theme='gradio/soft') as demo:
    gr.Markdown("Select the prediction model.")
    gr.Markdown(
        "CELL-E_2-HPA_480 is a good general purpose model for various cell types using ICC-IF."
    )
    gr.Markdown(
        "CELL-E_2-HPA_Finetuned_480 is finetuned on OpenCell and is better for live-cell predictions on HEK cells."
    )
    with gr.Row():
        model_name = gr.Dropdown(
            ["CELL-E_2-HPA_480", "CELL-E_2-HPA_Finetuned_480"],
            value="CELL-E_2-HPA_480",
            label="Model Name",
        )
    with gr.Row():
        gr.Markdown(
            "Input the desired amino acid sequence. GFP is shown below by default."
        )
    with gr.Row():
        sequence_input = gr.Textbox(
            value="MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
            label="Sequence",
        )
    with gr.Row():
        gr.Markdown(
            "Uploading a nucleus image is necessary. A random crop of 256 x 256 will be applied if larger. We provide default images in [images](https://huggingface.co/spaces/HuangLab/CELL-E_2/tree/main/images)"
        )
        gr.Markdown("The protein image is optional and is just used for display.")
    # Input images.
    with gr.Row().style(equal_height=True):
        nucleus_image = gr.Image(
            type="pil",
            label="Nucleus Image",
            image_mode="L",
        )
        protein_image = gr.Image(type="pil", label="Protein Image (Optional)")
    with gr.Row():
        gr.Markdown("Image predictions are shown below.")
    # Output images, in the same order as gradio_demo's return tuple.
    with gr.Row().style(equal_height=True):
        nucleus_image_crop = gr.Image(type="pil", label="Nucleus Image", image_mode="L")
        protein_threshold_image = gr.Image(
            type="pil", label="Protein Threshold Image", image_mode="L"
        )
        predicted_threshold_image = gr.Image(
            type="pil", label="Predicted Threshold image", image_mode="L"
        )
        predicted_heatmap = gr.Image(type="pil", label="Predicted Heatmap")
    with gr.Row():
        button = gr.Button("Run Model")

    inputs = [model_name, sequence_input, nucleus_image, protein_image]
    outputs = [
        nucleus_image_crop,
        protein_threshold_image,
        predicted_threshold_image,
        predicted_heatmap,
    ]
    button.click(gradio_demo, inputs, outputs)

if __name__ == "__main__":
    # Enable request queuing on the Blocks object; this replaces the deprecated
    # `launch(enable_queue=True)` flag.
    demo.queue().launch()