EmaadKhwaja commited on
Commit
5cbd5ac
1 Parent(s): 0677b2b

fix prediction.py

Browse files
Files changed (1) hide show
  1. prediction.py +2 -69
prediction.py CHANGED
@@ -136,8 +136,7 @@ def run_sequence_prediction(
136
 
137
  def run_image_prediction(
138
  sequence_input,
139
- nucleus_image_path,
140
- protein_image_path,
141
  model_ckpt_path,
142
  model_config_path,
143
  device
@@ -173,17 +172,6 @@ def run_image_prediction(
173
  # Convert SEQUENCE to sequence using dataset.tokenize_sequence()
174
  sequence = dataset.tokenize_sequence(sequence_input)
175
 
176
- # Check if nucleus image path is provided and valid
177
- if not os.path.exists(nucleus_image_path):
178
- # Use default nucleus image from dataset and print warning
179
- nucleus_image = dataset[0]["nucleus"]
180
- print(
181
- "Warning: No nucleus image provided. Using default nucleus image from dataset."
182
- )
183
- else:
184
- # Load nucleus image from provided path
185
- nucleus_image = process_image(nucleus_image_path)
186
-
187
  # Load model config and set ckpt_path if not provided in config
188
  config = OmegaConf.load(model_config_path)
189
  if config["model"]["params"]["ckpt_path"] is None:
@@ -209,59 +197,4 @@ def run_image_prediction(
209
  predicted_threshold = predicted_threshold.cpu()[0, 0]
210
  predicted_heatmap = predicted_heatmap.cpu()[0, 0]
211
 
212
- # Create 3 or 4 panel plot depending on whether protein image path is provided
213
- fig, axs = plt.subplots(1, 3 if protein_image_path is None else 4)
214
- axs[0].imshow(nucleus_image)
215
- axs[0].set_title("Nucleus Input")
216
- axs[1].imshow(predicted_threshold)
217
- axs[1].set_title("Predicted Threshold")
218
- if protein_image_path is not None:
219
- protein_image = process_image(protein_image_path)
220
- axs[2].imshow(protein_image)
221
- axs[2].set_title("Protein Image")
222
- axs[-1].imshow(predicted_heatmap)
223
- axs[-1].set_title("Predicted Heatmap")
224
- plt.show()
225
-
226
-
227
- if __name__ == "__main__":
228
- # Parse command line arguments for input parameters
229
- parser = argparse.ArgumentParser(
230
- description="Run Celle model with provided inputs."
231
- )
232
- parser.add_argument("--mode", type=str, default="", help="Sequence or Image")
233
- parser.add_argument(
234
- "--sequence", type=str, default="", help="Path to sequence file"
235
- )
236
- parser.add_argument(
237
- "--nucleus_image_path",
238
- type=str,
239
- default="images/nucleus.jpg",
240
- help="Path to nucleus image",
241
- )
242
- parser.add_argument(
243
- "--protein_image_path",
244
- type=str,
245
- default=None,
246
- help="Path to protein image (optional)",
247
- )
248
- parser.add_argument(
249
- "--model_ckpt_path", type=str, required=True, help="Path to model checkpoint"
250
- )
251
- parser.add_argument(
252
- "--model_config_path", type=str, required=True, help="Path to model config"
253
- )
254
- parser.add_argument(
255
- "--device", type=str, default="cpu", required=True, help="device"
256
- )
257
- args = parser.parse_args()
258
-
259
- run_model(
260
- args.mode,
261
- args.sequence,
262
- args.nucleus_image_path,
263
- args.protein_image_path,
264
- args.model_ckpt_path,
265
- args.model_config_path,
266
- args.device
267
- )
 
136
 
137
  def run_image_prediction(
138
  sequence_input,
139
+ nucleus_image,
 
140
  model_ckpt_path,
141
  model_config_path,
142
  device
 
172
  # Convert SEQUENCE to sequence using dataset.tokenize_sequence()
173
  sequence = dataset.tokenize_sequence(sequence_input)
174
 
 
 
 
 
 
 
 
 
 
 
 
175
  # Load model config and set ckpt_path if not provided in config
176
  config = OmegaConf.load(model_config_path)
177
  if config["model"]["params"]["ckpt_path"] is None:
 
197
  predicted_threshold = predicted_threshold.cpu()[0, 0]
198
  predicted_heatmap = predicted_heatmap.cpu()[0, 0]
199
 
200
+ return predicted_threshold, predicted_heatmap