Samuel Stevens
commited on
Commit
·
852b07a
1
Parent(s):
c4ee5c3
wip: adding mod preds
Browse files
app.py
CHANGED
@@ -46,7 +46,7 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
46 |
CWD = pathlib.Path(".")
|
47 |
"""Current working directory."""
|
48 |
|
49 |
-
N_SAE_LATENTS =
|
50 |
"""Number of SAE latents to show."""
|
51 |
|
52 |
N_LATENT_EXAMPLES = 4
|
@@ -289,7 +289,7 @@ def get_sae_latents(img_i: int, patches: list[int]) -> list[SaeActivation]:
|
|
289 |
|
290 |
|
291 |
@torch.inference_mode
|
292 |
-
def
|
293 |
img = data.get_img(i)
|
294 |
split_vit, vit_transform = modeling.load_vit(DEVICE)
|
295 |
|
@@ -331,16 +331,10 @@ def map_range(
|
|
331 |
return c + (x - a) * (d - c) / (b - a)
|
332 |
|
333 |
|
|
|
334 |
@torch.inference_mode
|
335 |
-
def
|
336 |
-
|
337 |
-
latent1: int,
|
338 |
-
latent2: int,
|
339 |
-
latent3: int,
|
340 |
-
value1: float,
|
341 |
-
value2: float,
|
342 |
-
value3: float,
|
343 |
-
) -> list[Image.Image | list[int]]:
|
344 |
sample = vit_dataset[i]
|
345 |
x = sample["image"][None, ...].to(device)
|
346 |
x_BPD = rest_of_vit.forward_start(x)
|
@@ -429,38 +423,38 @@ with gr.Blocks() as demo:
|
|
429 |
api_name="get-sae-latents",
|
430 |
)
|
431 |
|
432 |
-
|
433 |
-
# get-preds #
|
434 |
-
|
435 |
|
436 |
# Outputs
|
437 |
-
|
438 |
|
439 |
get_pred_labels_btn = gr.Button(value="Get Predictions")
|
440 |
get_pred_labels_btn.click(
|
441 |
-
|
|
|
|
|
|
|
442 |
)
|
443 |
|
444 |
-
|
445 |
-
#
|
446 |
-
|
447 |
-
|
448 |
-
#
|
449 |
-
|
450 |
-
|
451 |
-
|
452 |
-
|
453 |
-
|
454 |
-
|
455 |
-
|
456 |
-
|
457 |
-
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
# outputs=[semseg_img, semseg_colors],
|
462 |
-
# api_name="get-modified-labels",
|
463 |
-
# )
|
464 |
|
465 |
if __name__ == "__main__":
|
466 |
demo.launch()
|
|
|
46 |
CWD = pathlib.Path(".")
|
47 |
"""Current working directory."""
|
48 |
|
49 |
+
N_SAE_LATENTS = 2
|
50 |
"""Number of SAE latents to show."""
|
51 |
|
52 |
N_LATENT_EXAMPLES = 4
|
|
|
289 |
|
290 |
|
291 |
@torch.inference_mode
|
292 |
+
def get_orig_preds(i: int) -> dict[str, object]:
|
293 |
img = data.get_img(i)
|
294 |
split_vit, vit_transform = modeling.load_vit(DEVICE)
|
295 |
|
|
|
331 |
return c + (x - a) * (d - c) / (b - a)
|
332 |
|
333 |
|
334 |
+
@beartype.beartype
|
335 |
@torch.inference_mode
|
336 |
+
def get_mod_preds(i: int, latents: dict[int, float]) -> dict[str, object]:
|
337 |
+
breakpoint()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
338 |
sample = vit_dataset[i]
|
339 |
x = sample["image"][None, ...].to(device)
|
340 |
x_BPD = rest_of_vit.forward_start(x)
|
|
|
423 |
api_name="get-sae-latents",
|
424 |
)
|
425 |
|
426 |
+
##################
|
427 |
+
# get-orig-preds #
|
428 |
+
##################
|
429 |
|
430 |
# Outputs
|
431 |
+
get_orig_preds_out = gr.JSON(label="get_orig_preds_out", value=[])
|
432 |
|
433 |
get_pred_labels_btn = gr.Button(value="Get Predictions")
|
434 |
get_pred_labels_btn.click(
|
435 |
+
get_orig_preds,
|
436 |
+
inputs=[img_number],
|
437 |
+
outputs=[get_orig_preds_out],
|
438 |
+
api_name="get-orig-preds",
|
439 |
)
|
440 |
|
441 |
+
#################
|
442 |
+
# get-mod-preds #
|
443 |
+
#################
|
444 |
+
|
445 |
+
# Inputs
|
446 |
+
latents_json = gr.JSON(label="Modified Latents", value={})
|
447 |
+
|
448 |
+
# Outputs
|
449 |
+
get_mod_preds_out = gr.JSON(label="get_mod_preds_out", value=[])
|
450 |
+
|
451 |
+
get_pred_labels_btn = gr.Button(value="Get Predictions")
|
452 |
+
get_pred_labels_btn.click(
|
453 |
+
get_mod_preds,
|
454 |
+
inputs=[img_number, latents_json],
|
455 |
+
outputs=[get_mod_preds_out],
|
456 |
+
api_name="get-mod-preds",
|
457 |
+
)
|
|
|
|
|
|
|
458 |
|
459 |
if __name__ == "__main__":
|
460 |
demo.launch()
|