Samuel Stevens committed on
Commit af47b42 · 1 Parent(s): 6c9f92c

Add legend; add image uploader

Files changed (2)
  1. app.py +77 -74
  2. data.py +9 -2
app.py CHANGED
@@ -52,6 +52,44 @@ N_SAE_LATENTS = 2
N_LATENT_EXAMPLES = 4
"""Number of examples per SAE latent to show."""

+
+ @beartype.beartype
+ class Example(typing.TypedDict):
+     """Represents an example image and its associated label.
+
+     Used to store examples of SAE latent activations for visualization.
+     """
+
+     orig_url: str
+     """The URL or path to access the original example image."""
+     highlighted_url: typing.NotRequired[str]
+     """The URL or path to access the SAE-highlighted image."""
+     seg_url: str
+     """Base64-encoded version of the colored segmentation map."""
+     classes: list[int]
+     """Unique list of all classes in the seg_url."""
+
+
+ @beartype.beartype
+ class SaeActivation(typing.TypedDict):
+     """Represents the activation pattern of a single SAE latent across patches.
+
+     This captures how strongly a particular SAE latent fires on different patches of an input image.
+     """
+
+     latent: int
+     """The index of the SAE latent being measured."""
+
+     highlighted_url: str
+     """The image with the colormaps applied."""
+
+     activations: list[float]
+     """The activation values of this latent across different patches. Each value represents how strongly this latent fired on a particular patch."""
+
+     examples: list[Example]
+     """Top examples for this latent."""
+
+
##########
# Models #
##########
@@ -112,9 +150,9 @@ def load_tensors() -> tuple[
    return top_img_i, top_values, mask


- ############
- # Datasets #
- ############
+ ###########
+ # Imaging #
+ ###########


@jaxtyped(typechecker=beartype.beartype)
@@ -154,65 +192,43 @@ def add_highlights(
    return Image.alpha_composite(img.convert("RGBA"), overlay)


+ @jaxtyped(typechecker=beartype.beartype)
+ @torch.inference_mode
+ def upsample(
+     x_WH: Int[Tensor, "width_ps height_ps"],
+ ) -> UInt8[Tensor, "width_px height_px"]:
+     return (
+         torch.nn.functional.interpolate(
+             x_WH.view((1, 1, 16, 16)).float(),
+             scale_factor=28,
+         )
+         .view((448, 448))
+         .type(torch.uint8)
+     )
+
+
#######################
# Inference Functions #
#######################


@beartype.beartype
- class Example(typing.TypedDict):
-     """Represents an example image and its associated label.
-
-     Used to store examples of SAE latent activations for visualization.
-     """
-
-     index: int
-     """Dataset index."""
-     orig_url: str
-     """The URL or path to access the original example image."""
-     highlighted_url: str
-     """The URL or path to access the SAE-highlighted image."""
-     seg_url: str
-     """Base64-encoded version of the colored segmentation map."""
-
-
- @beartype.beartype
- class SaeActivation(typing.TypedDict):
-     """Represents the activation pattern of a single SAE latent across patches.
-
-     This captures how strongly a particular SAE latent fires on different patches of an input image.
-     """
-
-     latent: int
-     """The index of the SAE latent being measured."""
-
-     highlighted_url: str
-     """The image with the colormaps applied."""
-
-     activations: list[float]
-     """The activation values of this latent across different patches. Each value represents how strongly this latent fired on a particular patch."""
-
-     examples: list[Example]
-     """Top examples for this latent."""
-
-
- @beartype.beartype
- def get_img(i: int) -> dict[str, object]:
+ def get_img(i: int) -> Example:
    img_sized = data.to_sized(data.get_img(i))
    seg_sized = data.to_sized(data.get_seg(i))
    seg_u8_sized = data.to_u8(seg_sized)
    seg_img_sized = data.u8_to_img(seg_u8_sized)

    return {
-         "index": i,
        "orig_url": data.img_to_base64(img_sized),
        "seg_url": data.img_to_base64(seg_img_sized),
+         "classes": data.to_classes(seg_u8_sized),
    }


@beartype.beartype
@torch.inference_mode
- def get_sae_latents(img_i: int, patches: list[int]) -> list[SaeActivation]:
+ def get_sae_latents(img: Image.Image, patches: list[int]) -> list[SaeActivation]:
    """
    Given a particular cell, returns some highlighted images showing what feature fires most on this cell.
    """
@@ -222,9 +238,7 @@ def get_sae_latents(img_i: int, patches: list[int]) -> list[SaeActivation]:
    split_vit, vit_transform = modeling.load_vit(DEVICE)
    sae = load_sae(DEVICE)

-     img = data.get_img(img_i)
-
-     x_BCWH = vit_transform(img)[None, ...].to(DEVICE)
+     x_BCWH = vit_transform(img.convert("RGB"))[None, ...].to(DEVICE)

    x_BPD = split_vit.forward_start(x_BCWH)
    x_BPD = (
@@ -274,10 +288,10 @@ def get_sae_latents(img_i: int, patches: list[int]) -> list[SaeActivation]:
        )

        examples.append({
-             "index": i_im,
            "orig_url": data.img_to_base64(img_sized),
            "highlighted_url": data.img_to_base64(highlighted_sized),
            "seg_url": data.img_to_base64(seg_img_sized),
+             "classes": data.to_classes(seg_u8_sized),
        })

        sae_activations.append({
@@ -288,12 +302,12 @@ def get_sae_latents(img_i: int, patches: list[int]) -> list[SaeActivation]:
    return sae_activations


+ @beartype.beartype
@torch.inference_mode
- def get_orig_preds(i: int) -> dict[str, object]:
-     img = data.get_img(i)
+ def get_orig_preds(img: Image.Image) -> Example:
    split_vit, vit_transform = modeling.load_vit(DEVICE)

-     x_BCWH = vit_transform(img)[None, ...].to(DEVICE)
+     x_BCWH = vit_transform(img.convert("RGB"))[None, ...].to(DEVICE)

    x_BPD = split_vit.forward_start(x_BCWH)
    x_BPD = split_vit.forward_end(x_BPD)
@@ -304,11 +318,10 @@ def get_orig_preds(i: int) -> dict[str, object]:
    logits_WHC = clf(x_WHD)

    pred_WH = logits_WHC.argmax(axis=-1)
-     # preds = einops.rearrange(pred_WH, "w h -> (w h)").tolist()
    return {
-         "index": i,
        "orig_url": data.img_to_base64(data.to_sized(img)),
        "seg_url": data.img_to_base64(data.u8_to_img(upsample(pred_WH))),
+         "classes": data.to_classes(pred_WH),
    }


@@ -333,16 +346,15 @@ def map_range(

@beartype.beartype
@torch.inference_mode
- def get_mod_preds(i: int, latents: dict[str, int | float]) -> dict[str, object]:
+ def get_mod_preds(img: Image.Image, latents: dict[str, int | float]) -> Example:
    latents = {int(k): float(v) for k, v in latents.items()}
-     img = data.get_img(i)

    split_vit, vit_transform = modeling.load_vit(DEVICE)
    sae = load_sae(DEVICE)
    _, top_values, _ = load_tensors()
    clf = load_clf()

-     x_BCWH = vit_transform(img)[None, ...].to(DEVICE)
+     x_BCWH = vit_transform(img.convert("RGB"))[None, ...].to(DEVICE)
    x_BPD = split_vit.forward_start(x_BCWH)
    x_hat_BPD, f_x_BPS, _ = sae(x_BPD)

@@ -375,27 +387,12 @@ def get_mod_preds(i: int, latents: dict[str, int | float]) -> dict[str, object]:
    pred_WH = logits_WHC.argmax(axis=-1)
    # pred_WH = einops.rearrange(pred_P, "(w h) -> w h", w=16, h=16)
    return {
-         "index": i,
        "orig_url": data.img_to_base64(data.to_sized(img)),
        "seg_url": data.img_to_base64(data.u8_to_img(upsample(pred_WH))),
+         "classes": data.to_classes(pred_WH),
    }


- @jaxtyped(typechecker=beartype.beartype)
- @torch.inference_mode
- def upsample(
-     x_WH: Int[Tensor, "width_ps height_ps"],
- ) -> UInt8[Tensor, "width_px height_px"]:
-     return (
-         torch.nn.functional.interpolate(
-             x_WH.view((1, 1, 16, 16)).float(),
-             scale_factor=28,
-         )
-         .view((448, 448))
-         .type(torch.uint8)
-     )
-
-
with gr.Blocks() as demo:
    ###########
    # get-img #
@@ -418,13 +415,19 @@ with gr.Blocks() as demo:

    # Inputs
    patches_json = gr.JSON(label="Patches", value=[])
+     input_img = gr.Image(
+         label="Input Image",
+         sources=["upload", "clipboard"],
+         type="pil",
+         interactive=True,
+     )
    # Outputs
    get_sae_latents_out = gr.JSON(label="get_sae_latents_out", value=[])

    get_sae_latents_btn = gr.Button(value="Get SAE Latents")
    get_sae_latents_btn.click(
        get_sae_latents,
-         inputs=[img_number, patches_json],
+         inputs=[input_img, patches_json],
        outputs=[get_sae_latents_out],
        api_name="get-sae-latents",
    )
@@ -439,7 +442,7 @@ with gr.Blocks() as demo:
    get_pred_labels_btn = gr.Button(value="Get Predictions")
    get_pred_labels_btn.click(
        get_orig_preds,
-         inputs=[img_number],
+         inputs=[input_img],
        outputs=[get_orig_preds_out],
        api_name="get-orig-preds",
    )
@@ -457,7 +460,7 @@ with gr.Blocks() as demo:
    get_pred_labels_btn = gr.Button(value="Get Predictions")
    get_pred_labels_btn.click(
        get_mod_preds,
-         inputs=[img_number, latents_json],
+         inputs=[input_img, latents_json],
        outputs=[get_mod_preds_out],
        api_name="get-mod-preds",
    )
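Note on the new API surface (not part of the diff itself): with the image uploader, the three named endpoints now take an uploaded image instead of a dataset index. A minimal sketch of calling them through gradio_client, assuming the Space is reachable; the Space id, image path, and patch indices below are placeholders:

from gradio_client import Client, handle_file

client = Client("<user>/<space>")  # placeholder Space id; a full URL of the running app also works

# get-sae-latents now takes the uploaded image plus a JSON list of patch indices.
sae_latents = client.predict(
    handle_file("example.jpg"),  # input_img (gr.Image, type="pil"); placeholder path
    [0, 1, 17],                  # patches_json; placeholder patch indices
    api_name="/get-sae-latents",
)

# get-orig-preds and get-mod-preds take the same image input.
orig_preds = client.predict(handle_file("example.jpg"), api_name="/get-orig-preds")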
data.py CHANGED
@@ -8,7 +8,7 @@ import beartype
import einops.layers.torch
import numpy as np
import requests
- from jaxtyping import UInt8, jaxtyped
+ from jaxtyping import Integer, UInt8, jaxtyped
from PIL import Image
from torch import Tensor
from torchvision.transforms import v2
@@ -48,12 +48,13 @@ def make_colors() -> UInt8[np.ndarray, "n 3"]:
    random.Random(42).shuffle(colors)
    colors = np.array(colors, dtype=np.uint8)

-     # Fixed colors for example 3122
+     # Fixed colors. Must be synced with Segmentation.elm.
    colors[2] = np.array([201, 249, 255], dtype=np.uint8)
    colors[4] = np.array([151, 204, 4], dtype=np.uint8)
    colors[13] = np.array([104, 139, 88], dtype=np.uint8)
    colors[16] = np.array([54, 48, 32], dtype=np.uint8)
    colors[26] = np.array([45, 125, 210], dtype=np.uint8)
+     colors[29] = np.array([116, 142, 84], dtype=np.uint8)
    colors[46] = np.array([238, 185, 2], dtype=np.uint8)
    colors[52] = np.array([88, 91, 86], dtype=np.uint8)
    colors[72] = np.array([76, 46, 5], dtype=np.uint8)
@@ -97,6 +98,12 @@ def u8_to_img(map: UInt8[Tensor, "width height"]) -> Image.Image:
    return Image.fromarray(colored)


+ @jaxtyped(typechecker=beartype.beartype)
+ def to_classes(map: Integer[Tensor, "width height"]) -> list[int]:
+     # Integer is any signed or unsigned int: https://docs.kidger.site/jaxtyping/api/array/#dtype
+     return list(set(map.view(-1).tolist()))
+
+
@beartype.beartype
def img_to_base64(img: Image.Image) -> str:
    buf = io.BytesIO()
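For reference, a tiny sketch of how the new data.to_classes helper behaves; it returns the unique class ids in an integer segmentation map, presumably feeding the legend mentioned in the commit message. The toy tensor below is made up; real maps are the 448x448 uint8 maps produced by data.to_u8 and app.upsample:

import torch

import data

# Placeholder 2x2 segmentation map; any signed or unsigned integer dtype is accepted.
seg = torch.tensor([[2, 2], [4, 13]], dtype=torch.uint8)
print(data.to_classes(seg))  # e.g. [2, 4, 13]; order is not guaranteed because it goes through a set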