Samuel Stevens
committed on
Commit
·
6c9f92c
1
Parent(s):
852b07a
add mod preds; todo: add legend
Browse files
app.py
CHANGED
@@ -313,7 +313,7 @@ def get_orig_preds(i: int) -> dict[str, object]:
|
|
313 |
|
314 |
|
315 |
@beartype.beartype
def unscaled(x: float, max_obs: float | int) -> float:
    """Scale from [-10, 10] to [10 * -max_obs, 10 * max_obs].

    Args:
        x: value in the generic slider range [-10, 10].
        max_obs: maximum observed activation used as the scaling bound. Accepts
            int as well as float because values such as ``Tensor.max().item()``
            may come back as a Python int, which a strict ``float`` annotation
            would reject under ``@beartype.beartype``.

    Returns:
        ``x`` linearly remapped into [-10 * max_obs, 10 * max_obs].
    """
    return map_range(x, (-10.0, 10.0), (-10.0 * max_obs, 10.0 * max_obs))
|
319 |
|
@@ -333,12 +333,17 @@ def map_range(
|
|
333 |
|
334 |
@beartype.beartype
|
335 |
@torch.inference_mode
|
336 |
-
def get_mod_preds(i: int, latents: dict[
|
337 |
-
|
338 |
-
|
339 |
-
x = sample["image"][None, ...].to(device)
|
340 |
-
x_BPD = rest_of_vit.forward_start(x)
|
341 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
342 |
x_hat_BPD, f_x_BPS, _ = sae(x_BPD)
|
343 |
|
344 |
err_BPD = x_BPD - x_hat_BPD
|
@@ -346,18 +351,14 @@ def get_mod_preds(i: int, latents: dict[int, float]) -> dict[str, object]:
|
|
346 |
values = torch.tensor(
|
347 |
[
|
348 |
unscaled(float(value), top_values[latent].max().item())
|
349 |
-
for
|
350 |
-
(value1, latent1),
|
351 |
-
(value2, latent2),
|
352 |
-
(value3, latent3),
|
353 |
-
]
|
354 |
],
|
355 |
-
device=
|
356 |
)
|
357 |
-
f_x_BPS[..., torch.tensor(
|
358 |
|
359 |
# Reproduce the SAE forward pass after f_x
|
360 |
-
|
361 |
einops.einsum(
|
362 |
f_x_BPS,
|
363 |
sae.W_dec,
|
@@ -365,14 +366,19 @@ def get_mod_preds(i: int, latents: dict[int, float]) -> dict[str, object]:
|
|
365 |
)
|
366 |
+ sae.b_dec
|
367 |
)
|
368 |
-
|
369 |
|
370 |
-
|
|
|
371 |
|
372 |
-
|
373 |
-
|
374 |
-
pred_WH = einops.rearrange(pred_P, "(w h) -> w h", w=16, h=16)
|
375 |
-
return
|
|
|
|
|
|
|
|
|
376 |
|
377 |
|
378 |
@jaxtyped(typechecker=beartype.beartype)
|
|
|
313 |
|
314 |
|
315 |
@beartype.beartype
def unscaled(x: float, max_obs: float | int) -> float:
    """Scale from [-10, 10] to [10 * -max_obs, 10 * max_obs]."""
    # Source interval is the fixed UI slider range; destination interval is
    # that range stretched by the latent's maximum observed activation.
    source_range = (-10.0, 10.0)
    target_range = (-10.0 * max_obs, 10.0 * max_obs)
    return map_range(x, source_range, target_range)
|
319 |
|
|
|
333 |
|
334 |
@beartype.beartype
|
335 |
@torch.inference_mode
|
336 |
+
def get_mod_preds(i: int, latents: dict[str, int | float]) -> dict[str, object]:
|
337 |
+
latents = {int(k): float(v) for k, v in latents.items()}
|
338 |
+
img = data.get_img(i)
|
|
|
|
|
339 |
|
340 |
+
split_vit, vit_transform = modeling.load_vit(DEVICE)
|
341 |
+
sae = load_sae(DEVICE)
|
342 |
+
_, top_values, _ = load_tensors()
|
343 |
+
clf = load_clf()
|
344 |
+
|
345 |
+
x_BCWH = vit_transform(img)[None, ...].to(DEVICE)
|
346 |
+
x_BPD = split_vit.forward_start(x_BCWH)
|
347 |
x_hat_BPD, f_x_BPS, _ = sae(x_BPD)
|
348 |
|
349 |
err_BPD = x_BPD - x_hat_BPD
|
|
|
351 |
values = torch.tensor(
|
352 |
[
|
353 |
unscaled(float(value), top_values[latent].max().item())
|
354 |
+
for latent, value in latents.items()
|
|
|
|
|
|
|
|
|
355 |
],
|
356 |
+
device=DEVICE,
|
357 |
)
|
358 |
+
f_x_BPS[..., torch.tensor(list(latents.keys()), device=DEVICE)] = values
|
359 |
|
360 |
# Reproduce the SAE forward pass after f_x
|
361 |
+
mod_x_hat_BPD = (
|
362 |
einops.einsum(
|
363 |
f_x_BPS,
|
364 |
sae.W_dec,
|
|
|
366 |
)
|
367 |
+ sae.b_dec
|
368 |
)
|
369 |
+
mod_BPD = err_BPD + mod_x_hat_BPD
|
370 |
|
371 |
+
mod_BPD = split_vit.forward_end(mod_BPD)
|
372 |
+
mod_WHD = einops.rearrange(mod_BPD, "() (w h) dim -> w h dim", w=16, h=16)
|
373 |
|
374 |
+
logits_WHC = clf(mod_WHD)
|
375 |
+
pred_WH = logits_WHC.argmax(axis=-1)
|
376 |
+
# pred_WH = einops.rearrange(pred_P, "(w h) -> w h", w=16, h=16)
|
377 |
+
return {
|
378 |
+
"index": i,
|
379 |
+
"orig_url": data.img_to_base64(data.to_sized(img)),
|
380 |
+
"seg_url": data.img_to_base64(data.u8_to_img(upsample(pred_WH))),
|
381 |
+
}
|
382 |
|
383 |
|
384 |
@jaxtyped(typechecker=beartype.beartype)
|