Emaad commited on
Commit
a888fd4
·
1 Parent(s): c3c07bf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +105 -68
app.py CHANGED
@@ -1,119 +1,156 @@
 
1
  import gradio as gr
2
- from prediction import run_sequence_prediction
3
  import torch
4
  import torchvision.transforms as T
5
  from celle.utils import process_image
6
  from PIL import Image
7
  from matplotlib import pyplot as plt
 
 
8
 
9
 
10
- def gradio_demo(model_name, sequence_input, image):
11
- model = f"CELL-E_2-Image_Prediction/models/{model_name}.ckpt"
12
- config = f"CELL-E_2-Image_Prediction/models/{model_name}.yaml"
 
13
 
14
- if "Finetuned" in model_name:
15
- dataset = "OpenCell"
16
-
17
- else:
18
- dataset = "HPA"
19
-
20
 
21
- nucleus_image = image['image'].convert('L')
22
- protein_image = image['mask'].convert('L')
23
-
24
- to_tensor = T.ToTensor()
25
- nucleus_tensor = to_tensor(nucleus_image)
26
- protein_tensor = to_tensor(protein_image)
27
- stacked_images = torch.stack([nucleus_tensor, protein_tensor], dim=0)
28
- processed_images = process_image(stacked_images, dataset)
29
-
30
- nucleus_image = processed_images[0].unsqueeze(0)
31
- protein_image = processed_images[1].unsqueeze(0)
32
- protein_image = protein_image > 0
33
- protein_image = 1.0 * protein_image
34
-
35
- print(f'{protein_image.sum()}')
36
-
37
-
38
- formatted_predicted_sequence = run_sequence_prediction(
39
- sequence_input=sequence_input,
40
- nucleus_image=nucleus_image,
41
- protein_image=protein_image,
42
- model_ckpt_path=model,
43
- model_config_path=config,
44
- device=device,
45
- )
46
-
47
- return T.ToPILImage()(protein_image), T.ToPILImage()(nucleus_image), formatted_predicted_sequence
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
 
49
 
50
  with gr.Blocks(theme='gradio/soft') as demo:
51
  gr.Markdown("Select the prediction model.")
52
  gr.Markdown(
53
- "- CELL-E_2_HPA_2560 is a good general purpose model for various cell types using ICC-IF."
54
  )
55
  gr.Markdown(
56
- "- CELL-E_2_OpenCell_2560 is trained on OpenCell and is good more live-cell predictions on HEK cells."
57
  )
58
  with gr.Row():
59
  model_name = gr.Dropdown(
60
- ["CELL-E_2_HPA_2560", "CELL-E_2_OpenCell_2560"],
61
- value="CELL-E_2_HPA_2560",
62
  label="Model Name",
63
  )
64
  with gr.Row():
65
  gr.Markdown(
66
- "Input the desired amino acid sequence. GFP is shown below by default. The sequence must include ```<mask>``` for a prediction to be run."
67
  )
68
 
69
  with gr.Row():
70
  sequence_input = gr.Textbox(
71
- value="M<mask><mask><mask><mask><mask>SKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
72
  label="Sequence",
73
  )
74
  with gr.Row():
75
  gr.Markdown(
76
- "Uploading a nucleus image is necessary. A random crop of 256 x 256 will be applied if larger. We provide default images in [images](https://huggingface.co/spaces/HuangLab/CELL-E_2/tree/main/images). Draw the desired localization on top of the nucelus image."
77
  )
 
78
 
79
  with gr.Row().style(equal_height=True):
80
  nucleus_image = gr.Image(
81
- source="upload",
82
- tool="sketch",
83
- invert_colors=True,
84
- label="Nucleus Image",
85
- interactive=True,
86
  image_mode="L",
87
- type="pil"
88
  )
89
 
 
90
 
91
- with gr.Row().style(equal_height=True):
92
- nucleus_crop = gr.Image(
93
- label="Nucleus Image (Crop)",
94
- image_mode="L",
95
- type="pil"
96
- )
97
-
98
- mask = gr.Image(
99
- label="Threshold Image",
100
- image_mode="L",
101
- type="pil"
102
- )
103
  with gr.Row():
104
- gr.Markdown("Sequence predictions are show below.")
105
 
106
  with gr.Row().style(equal_height=True):
107
- predicted_sequence = gr.Textbox(label='Predicted Sequence')
 
 
 
 
108
 
 
 
 
109
 
 
110
  with gr.Row():
111
  button = gr.Button("Run Model")
112
 
113
- inputs = [model_name, sequence_input, nucleus_image]
114
 
115
- outputs = [mask, nucleus_crop, predicted_sequence]
 
 
 
 
 
116
 
117
- button.click(gradio_demo, inputs, outputs)
118
 
119
- demo.launch(enable_queue=True)
 
1
+ import os
2
  import gradio as gr
3
+ from prediction import run_image_prediction
4
  import torch
5
  import torchvision.transforms as T
6
  from celle.utils import process_image
7
  from PIL import Image
8
  from matplotlib import pyplot as plt
9
+ from celle_main import instantiate_from_config
10
+ from omegaconf import OmegaConf
11
 
12
 
13
+ class model:
14
+ def __init__(self):
15
+ self.model = None
16
+ self.model_name = None
17
 
18
+ def gradio_demo(self, model_name, sequence_input, nucleus_image, protein_image):
19
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
20
 
21
+ if self.model_name != model_name:
22
+ self.model_name = model_name
23
+ model_ckpt_path = f"CELL-E_2-Image_Prediction/models/{model_name}.ckpt"
24
+ model_config_path = f"CELL-E_2-Image_Prediction/models/{model_name}.yaml"
25
+
26
+
27
+ # Load model config and set ckpt_path if not provided in config
28
+ config = OmegaConf.load(model_config_path)
29
+ if config["model"]["params"]["ckpt_path"] is None:
30
+ config["model"]["params"]["ckpt_path"] = model_ckpt_path
31
+
32
+ # Set condition_model_path and vqgan_model_path to None
33
+ config["model"]["params"]["condition_model_path"] = None
34
+ config["model"]["params"]["vqgan_model_path"] = None
35
+
36
+ base_path = os.getcwd()
37
+
38
+ os.chdir(os.path.dirname(model_ckpt_path))
39
+
40
+ # Instantiate model from config and move to device
41
+ self.model = instantiate_from_config(config.model).to(device)
42
+ self.model = torch.compile(self.model,mode='reduce-overhead')
43
+
44
+ os.chdir(base_path)
45
+
46
+
47
+ if "Finetuned" in model_name:
48
+ dataset = "OpenCell"
49
+
50
+ else:
51
+ dataset = "HPA"
52
+
53
+ nucleus_image = process_image(nucleus_image, dataset, "nucleus")
54
+ if protein_image:
55
+ protein_image = process_image(protein_image, dataset, "protein")
56
+ protein_image = protein_image > torch.median(protein_image)
57
+ protein_image = protein_image[0, 0]
58
+ protein_image = protein_image * 1.0
59
+ else:
60
+ protein_image = torch.ones((256, 256))
61
+
62
+ threshold, heatmap = run_image_prediction(
63
+ sequence_input=sequence_input,
64
+ nucleus_image=nucleus_image,
65
+ model=self.model,
66
+ device=device,
67
+ )
68
+
69
+ # Plot the heatmap
70
+ plt.imshow(heatmap.cpu(), cmap="rainbow", interpolation="bicubic")
71
+ plt.axis("off")
72
+
73
+ # Save the plot to a temporary file
74
+ plt.savefig("temp.png", bbox_inches="tight", dpi=256)
75
+
76
+ # Open the temporary file as a PIL image
77
+ heatmap = Image.open("temp.png")
78
+
79
+ return (
80
+ T.ToPILImage()(nucleus_image[0, 0]),
81
+ T.ToPILImage()(protein_image),
82
+ T.ToPILImage()(threshold),
83
+ heatmap,
84
+ )
85
 
86
+ base_class = model()
87
 
88
  with gr.Blocks(theme='gradio/soft') as demo:
89
  gr.Markdown("Select the prediction model.")
90
  gr.Markdown(
91
+ "CELL-E_2_HPA_480 is a good general purpose model for various cell types using ICC-IF."
92
  )
93
  gr.Markdown(
94
+ "CELL-E_2_HPA_Finetuned_480 is finetuned on OpenCell and is good more live-cell predictions on HEK cells."
95
  )
96
  with gr.Row():
97
  model_name = gr.Dropdown(
98
+ ["CELL-E_2_HPA_480", "CELL-E_2_HPA_Finetuned_480"],
99
+ value="CELL-E_2_HPA_480",
100
  label="Model Name",
101
  )
102
  with gr.Row():
103
  gr.Markdown(
104
+ "Input the desired amino acid sequence. GFP is shown below by default."
105
  )
106
 
107
  with gr.Row():
108
  sequence_input = gr.Textbox(
109
+ value="MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
110
  label="Sequence",
111
  )
112
  with gr.Row():
113
  gr.Markdown(
114
+ "Uploading a nucleus image is necessary. A random crop of 256 x 256 will be applied if larger. We provide default images in [images](https://huggingface.co/spaces/HuangLab/CELL-E_2/tree/main/images)"
115
  )
116
+ gr.Markdown("The protein image is optional and is just used for display.")
117
 
118
  with gr.Row().style(equal_height=True):
119
  nucleus_image = gr.Image(
120
+ type="pil",
121
+ label="Nucleus Image",
 
 
 
122
  image_mode="L",
 
123
  )
124
 
125
+ protein_image = gr.Image(type="pil", label="Protein Image (Optional)")
126
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  with gr.Row():
128
+ gr.Markdown("Image predictions are show below.")
129
 
130
  with gr.Row().style(equal_height=True):
131
+ nucleus_image_crop = gr.Image(type="pil", label="Nucleus Image", image_mode="L")
132
+
133
+ protein_threshold_image = gr.Image(
134
+ type="pil", label="Protein Threshold Image", image_mode="L"
135
+ )
136
 
137
+ predicted_threshold_image = gr.Image(
138
+ type="pil", label="Predicted Threshold image", image_mode="L"
139
+ )
140
 
141
+ predicted_heatmap = gr.Image(type="pil", label="Predicted Heatmap")
142
  with gr.Row():
143
  button = gr.Button("Run Model")
144
 
145
+ inputs = [model_name, sequence_input, nucleus_image, protein_image]
146
 
147
+ outputs = [
148
+ nucleus_image_crop,
149
+ protein_threshold_image,
150
+ predicted_threshold_image,
151
+ predicted_heatmap,
152
+ ]
153
 
154
+ button.click(base_class.gradio_demo, inputs, outputs)
155
 
156
+ demo.launch(enable_queue=True)