Samuel Stevens committed on
Commit
6c9f92c
·
1 Parent(s): 852b07a

add mod preds; todo: add legend

Browse files
Files changed (1) hide show
  1. app.py +26 -20
app.py CHANGED
@@ -313,7 +313,7 @@ def get_orig_preds(i: int) -> dict[str, object]:
313
 
314
 
315
  @beartype.beartype
316
- def unscaled(x: float, max_obs: float) -> float:
317
  """Scale from [-10, 10] to [10 * -max_obs, 10 * max_obs]."""
318
  return map_range(x, (-10.0, 10.0), (-10.0 * max_obs, 10.0 * max_obs))
319
 
@@ -333,12 +333,17 @@ def map_range(
333
 
334
  @beartype.beartype
335
  @torch.inference_mode
336
- def get_mod_preds(i: int, latents: dict[int, float]) -> dict[str, object]:
337
- breakpoint()
338
- sample = vit_dataset[i]
339
- x = sample["image"][None, ...].to(device)
340
- x_BPD = rest_of_vit.forward_start(x)
341
 
 
 
 
 
 
 
 
342
  x_hat_BPD, f_x_BPS, _ = sae(x_BPD)
343
 
344
  err_BPD = x_BPD - x_hat_BPD
@@ -346,18 +351,14 @@ def get_mod_preds(i: int, latents: dict[int, float]) -> dict[str, object]:
346
  values = torch.tensor(
347
  [
348
  unscaled(float(value), top_values[latent].max().item())
349
- for value, latent in [
350
- (value1, latent1),
351
- (value2, latent2),
352
- (value3, latent3),
353
- ]
354
  ],
355
- device=device,
356
  )
357
- f_x_BPS[..., torch.tensor([latent1, latent2, latent3], device=device)] = values
358
 
359
  # Reproduce the SAE forward pass after f_x
360
- modified_x_hat_BPD = (
361
  einops.einsum(
362
  f_x_BPS,
363
  sae.W_dec,
@@ -365,14 +366,19 @@ def get_mod_preds(i: int, latents: dict[int, float]) -> dict[str, object]:
365
  )
366
  + sae.b_dec
367
  )
368
- modified_BPD = err_BPD + modified_x_hat_BPD
369
 
370
- modified_BPD = rest_of_vit.forward_end(modified_BPD)
 
371
 
372
- logits_BPC = head(modified_BPD)
373
- pred_P = logits_BPC[0].argmax(axis=-1)
374
- pred_WH = einops.rearrange(pred_P, "(w h) -> w h", w=16, h=16)
375
- return seg_to_img(upsample(pred_WH)), pred_P.tolist()
 
 
 
 
376
 
377
 
378
  @jaxtyped(typechecker=beartype.beartype)
 
313
 
314
 
315
@beartype.beartype
def unscaled(x: float, max_obs: float | int) -> float:
    """Map a value in [-10, 10] linearly onto [-10 * max_obs, 10 * max_obs]."""
    src_lo, src_hi = -10.0, 10.0
    dst_range = (src_lo * max_obs, src_hi * max_obs)
    return map_range(x, (src_lo, src_hi), dst_range)
319
 
 
333
 
334
  @beartype.beartype
335
  @torch.inference_mode
336
+ def get_mod_preds(i: int, latents: dict[str, int | float]) -> dict[str, object]:
337
+ latents = {int(k): float(v) for k, v in latents.items()}
338
+ img = data.get_img(i)
 
 
339
 
340
+ split_vit, vit_transform = modeling.load_vit(DEVICE)
341
+ sae = load_sae(DEVICE)
342
+ _, top_values, _ = load_tensors()
343
+ clf = load_clf()
344
+
345
+ x_BCWH = vit_transform(img)[None, ...].to(DEVICE)
346
+ x_BPD = split_vit.forward_start(x_BCWH)
347
  x_hat_BPD, f_x_BPS, _ = sae(x_BPD)
348
 
349
  err_BPD = x_BPD - x_hat_BPD
 
351
  values = torch.tensor(
352
  [
353
  unscaled(float(value), top_values[latent].max().item())
354
+ for latent, value in latents.items()
 
 
 
 
355
  ],
356
+ device=DEVICE,
357
  )
358
+ f_x_BPS[..., torch.tensor(list(latents.keys()), device=DEVICE)] = values
359
 
360
  # Reproduce the SAE forward pass after f_x
361
+ mod_x_hat_BPD = (
362
  einops.einsum(
363
  f_x_BPS,
364
  sae.W_dec,
 
366
  )
367
  + sae.b_dec
368
  )
369
+ mod_BPD = err_BPD + mod_x_hat_BPD
370
 
371
+ mod_BPD = split_vit.forward_end(mod_BPD)
372
+ mod_WHD = einops.rearrange(mod_BPD, "() (w h) dim -> w h dim", w=16, h=16)
373
 
374
+ logits_WHC = clf(mod_WHD)
375
+ pred_WH = logits_WHC.argmax(axis=-1)
376
+ # pred_WH = einops.rearrange(pred_P, "(w h) -> w h", w=16, h=16)
377
+ return {
378
+ "index": i,
379
+ "orig_url": data.img_to_base64(data.to_sized(img)),
380
+ "seg_url": data.img_to_base64(data.u8_to_img(upsample(pred_WH))),
381
+ }
382
 
383
 
384
  @jaxtyped(typechecker=beartype.beartype)