Samuel Stevens commited on
Commit
852b07a
·
1 Parent(s): c4ee5c3

wip: adding mod preds

Browse files
Files changed (1) hide show
  1. app.py +30 -36
app.py CHANGED
@@ -46,7 +46,7 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
46
  CWD = pathlib.Path(".")
47
  """Current working directory."""
48
 
49
- N_SAE_LATENTS = 3
50
  """Number of SAE latents to show."""
51
 
52
  N_LATENT_EXAMPLES = 4
@@ -289,7 +289,7 @@ def get_sae_latents(img_i: int, patches: list[int]) -> list[SaeActivation]:
289
 
290
 
291
  @torch.inference_mode
292
- def get_preds(i: int) -> dict[str, object]:
293
  img = data.get_img(i)
294
  split_vit, vit_transform = modeling.load_vit(DEVICE)
295
 
@@ -331,16 +331,10 @@ def map_range(
331
  return c + (x - a) * (d - c) / (b - a)
332
 
333
 
 
334
  @torch.inference_mode
335
- def get_modified_labels(
336
- i: int,
337
- latent1: int,
338
- latent2: int,
339
- latent3: int,
340
- value1: float,
341
- value2: float,
342
- value3: float,
343
- ) -> list[Image.Image | list[int]]:
344
  sample = vit_dataset[i]
345
  x = sample["image"][None, ...].to(device)
346
  x_BPD = rest_of_vit.forward_start(x)
@@ -429,38 +423,38 @@ with gr.Blocks() as demo:
429
  api_name="get-sae-latents",
430
  )
431
 
432
- #############
433
- # get-preds #
434
- #############
435
 
436
  # Outputs
437
- get_preds_out = gr.JSON(label="get_preds_out", value=[])
438
 
439
  get_pred_labels_btn = gr.Button(value="Get Predictions")
440
  get_pred_labels_btn.click(
441
- get_preds, inputs=[img_number], outputs=[get_preds_out], api_name="get-preds"
 
 
 
442
  )
443
 
444
- # get_true_labels_btn = gr.Button(value="Get True Label")
445
- # get_true_labels_btn.click(
446
- # get_true_labels,
447
- # inputs=[img_number],
448
- # outputs=semseg_img,
449
- # api_name="get-true-labels",
450
- # )
451
-
452
- # latent_numbers = [gr.Number(label=f"Latent {i + 1}") for i in range(3)]
453
- # value_sliders = [
454
- # gr.Slider(label=f"Value {i + 1}", minimum=-10, maximum=10) for i in range(3)
455
- # ]
456
-
457
- # get_modified_labels_btn = gr.Button(value="Get Modified Label")
458
- # get_modified_labels_btn.click(
459
- # get_modified_labels,
460
- # inputs=[img_number] + latent_numbers + value_sliders,
461
- # outputs=[semseg_img, semseg_colors],
462
- # api_name="get-modified-labels",
463
- # )
464
 
465
  if __name__ == "__main__":
466
  demo.launch()
 
46
  CWD = pathlib.Path(".")
47
  """Current working directory."""
48
 
49
+ N_SAE_LATENTS = 2
50
  """Number of SAE latents to show."""
51
 
52
  N_LATENT_EXAMPLES = 4
 
289
 
290
 
291
  @torch.inference_mode
292
+ def get_orig_preds(i: int) -> dict[str, object]:
293
  img = data.get_img(i)
294
  split_vit, vit_transform = modeling.load_vit(DEVICE)
295
 
 
331
  return c + (x - a) * (d - c) / (b - a)
332
 
333
 
334
+ @beartype.beartype
335
  @torch.inference_mode
336
+ def get_mod_preds(i: int, latents: dict[int, float]) -> dict[str, object]:
337
+ breakpoint()
 
 
 
 
 
 
 
338
  sample = vit_dataset[i]
339
  x = sample["image"][None, ...].to(device)
340
  x_BPD = rest_of_vit.forward_start(x)
 
423
  api_name="get-sae-latents",
424
  )
425
 
426
+ ##################
427
+ # get-orig-preds #
428
+ ##################
429
 
430
  # Outputs
431
+ get_orig_preds_out = gr.JSON(label="get_orig_preds_out", value=[])
432
 
433
  get_pred_labels_btn = gr.Button(value="Get Predictions")
434
  get_pred_labels_btn.click(
435
+ get_orig_preds,
436
+ inputs=[img_number],
437
+ outputs=[get_orig_preds_out],
438
+ api_name="get-orig-preds",
439
  )
440
 
441
+ #################
442
+ # get-mod-preds #
443
+ #################
444
+
445
+ # Inputs
446
+ latents_json = gr.JSON(label="Modified Latents", value={})
447
+
448
+ # Outputs
449
+ get_mod_preds_out = gr.JSON(label="get_mod_preds_out", value=[])
450
+
451
+ get_pred_labels_btn = gr.Button(value="Get Predictions")
452
+ get_pred_labels_btn.click(
453
+ get_mod_preds,
454
+ inputs=[img_number, latents_json],
455
+ outputs=[get_mod_preds_out],
456
+ api_name="get-mod-preds",
457
+ )
 
 
 
458
 
459
  if __name__ == "__main__":
460
  demo.launch()