Samuel Stevens committed
Commit 699b9c3 · 1 Parent(s): 0ab58fa

bug: SAE examples are not highlighted

Files changed (2)
  1. app.py +110 -157
  2. modeling.py +53 -0
app.py CHANGED
@@ -2,7 +2,7 @@ import functools
  import io
  import json
  import logging
- import os.path
  import pathlib
  import typing

@@ -10,17 +10,19 @@ import beartype
  import einops
  import einops.layers.torch
  import gradio as gr
  import saev.activations
  import saev.config
  import saev.nn
  import saev.visuals
  import torch
- from jaxtyping import Float, Int, UInt8, jaxtyped
- from PIL import Image
  from torch import Tensor

  import constants
  import data

  logger = logging.getLogger("app.py")

@@ -29,33 +31,26 @@ logger = logging.getLogger("app.py")
  ####################


- DEBUG = False
- """Whether we are debugging."""
-
- max_frequency = 1e-2
  """Maximum frequency. Any feature that fires more than this is ignored."""

- n_sae_latents = 3
- """Number of SAE latents to show."""
-
- n_sae_examples = 4
- """Number of SAE examples per latent to show."""
-
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- """Hardware accelerator, if any."""
-
  RESIZE_SIZE = 512
  """Resize shorter size to this size in pixels."""

  CROP_SIZE = (448, 448)
  """Crop size in pixels."""

- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  """Hardware accelerator, if any."""

  CWD = pathlib.Path(".")
  """Current working directory."""

  ##########
  # Models #
@@ -63,27 +58,7 @@ CWD = pathlib.Path(".")


  @functools.cache
- def load_vit() -> tuple[saev.activations.WrappedVisionTransformer, typing.Callable]:
-     vit = (
-         saev.activations.WrappedVisionTransformer(
-             saev.config.Activations(
-                 model_family="dinov2",
-                 model_ckpt="dinov2_vitb14_reg",
-                 layers=[-2],
-                 n_patches_per_img=256,
-             )
-         )
-         .to(DEVICE)
-         .eval()
-     )
-     vit_transform = saev.activations.make_img_transform("dinov2", "dinov2_vitb14_reg")
-     logger.info("Loaded ViT.")
-
-     return vit, vit_transform
-
-
- @functools.cache
- def load_sae() -> saev.nn.SparseAutoencoder:
      """
      Loads a sparse autoencoder from disk.
      """
@@ -102,37 +77,12 @@ def load_clf() -> torch.nn.Module:
      buffer = io.BytesIO(fd.read())

      model = torch.nn.Linear(**kwargs)
-     state_dict = torch.load(buffer, weights_only=True, map_location=device)
      model.load_state_dict(state_dict)
-     model = model.to(device).eval()
      return model


- class RestOfDinoV2(torch.nn.Module):
-     def __init__(self, *, n_end_layers: int):
-         super().__init__()
-         self.vit = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14_reg")
-         self.n_end_layers = n_end_layers
-
-     def forward_start(self, x: Float[Tensor, "batch channels width height"]):
-         x_BPD = self.vit.prepare_tokens_with_masks(x)
-         for blk in self.vit.blocks[: -self.n_end_layers]:
-             x_BPD = blk(x_BPD)
-
-         return x_BPD
-
-     def forward_end(self, x_BPD: Float[Tensor, "batch n_patches dim"]):
-         for blk in self.vit.blocks[-self.n_end_layers :]:
-             x_BPD = blk(x_BPD)
-
-         x_BPD = self.vit.norm(x_BPD)
-         return x_BPD[:, self.vit.num_register_tokens + 1 :]
-
-
- rest_of_vit = RestOfDinoV2(n_end_layers=1)
- rest_of_vit = rest_of_vit.to(device)
-
-
  ####################
  # Global Variables #
  ####################
@@ -143,13 +93,23 @@ def load_tensor(path: str | pathlib.Path) -> Tensor:
      return torch.load(path, weights_only=True, map_location="cpu")


- # top_img_i = load_tensor(CWD / "assets" / "top_img_i.pt")
- # top_values = load_tensor(CWD / "assets" / "top_values_uint8.pt")
- # sparsity = load_tensor(CWD / "assets" / "sparsity.pt")


- # mask = torch.ones((sae.cfg.d_sae), dtype=bool)
- # mask = mask & (sparsity < max_frequency)


  ############
@@ -157,37 +117,42 @@ def load_tensor(path: str | pathlib.Path) -> Tensor:
  ############


- # in1k_dataset = saev.activations.get_dataset(
- #     saev.config.ImagenetDataset(),
- #     img_transform=v2.Compose([
- #         v2.Resize(size=(512, 512)),
- #         v2.CenterCrop(size=(448, 448)),
- #     ]),
- # )
-
-
- # acts_dataset = saev.activations.Dataset(
- #     saev.config.DataLoad(
- #         shard_root="/local/scratch/stevens.994/cache/saev/a1f842330bb568b2fb05c15d4fa4252fb7f5204837335000d9fd420f120cd03e",
- #         scale_mean=not DEBUG,
- #         scale_norm=not DEBUG,
- #         layer=-2,
- #     )
- # )
-
-
- # vit_dataset = saev.activations.Ade20k(
- #     saev.config.Ade20kDataset(
- #         root="/research/nfs_su_809/workspace/stevens.994/datasets/ade20k/"
- #     ),
- #     img_transform=v2.Compose([
- #         v2.Resize(size=(256, 256)),
- #         v2.CenterCrop(size=(224, 224)),
- #         v2.ToImage(),
- #         v2.ToDtype(torch.float32, scale=True),
- #         v2.Normalize(mean=[0.4850, 0.4560, 0.4060], std=[0.2290, 0.2240, 0.2250]),
- #     ]),
- # )


  #######################
@@ -202,12 +167,14 @@ class Example(typing.TypedDict):
      Used to store examples of SAE latent activations for visualization.
      """

      orig_url: str
      """The URL or path to access the original example image."""
      highlighted_url: str
      """The URL or path to access the SAE-highlighted image."""
-     index: int
-     """Dataset index."""


  @beartype.beartype
@@ -249,64 +216,73 @@ def get_sae_activations(image_i: int, patches: list[int]) -> list[SaeActivation]
      if not patches:
          return []

-     vit, vit_transform = load_vit()
-     sae = load_sae()

      img = data.get_image(image_i)

-     x = vit_transform(img)[None, ...].to(DEVICE)

-     _, vit_acts_BLPD = vit(x)
-     vit_acts_PD = (
-         vit_acts_BLPD[0, 0, 1:].to(DEVICE).clamp(-1e-5, 1e5)
-         - (constants.DINOV2_IMAGENET1K_MEAN).to(DEVICE)
      ) / constants.DINOV2_IMAGENET1K_SCALAR

-     _, f_x_PS, _ = sae(vit_acts_PD)
-     # Ignore [CLS] token and get just the requested latents.
-     acts_SP = einops.rearrange(f_x_PS, "patches n_latents -> n_latents patches")
-     logger.info("Got SAE activations.")
-
-     top_img_i, top_values = load_tensors(model_cfg)
-     logger.info("Loaded top SAE activations for '%s'.", model_name)

-     vit_acts_MD = torch.stack([
-         acts_dataset[image_i * acts_dataset.metadata.n_patches_per_img + i]["act"]
-         for i in patches
-     ]).to(device)

-     _, f_x_MS, _ = sae(vit_acts_MD)
-     f_x_S = f_x_MS.sum(axis=0)

      latents = torch.argsort(f_x_S, descending=True).cpu()
-     latents = latents[mask[latents]][:n_sae_latents].tolist()

-     images = []
      for latent in latents:
-         elems, seen_i_im = [], set()
          for i_im, values_p in zip(top_img_i[latent].tolist(), top_values[latent]):
              if i_im in seen_i_im:
                  continue

-             example = in1k_dataset[i_im]
-             elems.append(
-                 saev.visuals.GridElement(example["image"], example["label"], values_p)
-             )
              seen_i_im.add(i_im)

          # How to scale values.
          upper = None
          if top_values[latent].numel() > 0:
              upper = top_values[latent].max().item()

-         latent_images = [make_img(elem, upper=upper) for elem in elems[:n_sae_examples]]

-         while len(latent_images) < n_sae_examples:
-             latent_images += [None]

-         images.extend(latent_images)

-     return images + latents


  @torch.inference_mode
@@ -416,29 +392,6 @@ def upsample(
      )


- @beartype.beartype
- def make_img(
-     elem: saev.visuals.GridElement, *, upper: float | None = None
- ) -> Image.Image:
-     # Resize to 256x256 and crop to 224x224
-     resize_size_px = (512, 512)
-     resize_w_px, resize_h_px = resize_size_px
-     crop_size_px = (448, 448)
-     crop_w_px, crop_h_px = crop_size_px
-     crop_coords_px = (
-         (resize_w_px - crop_w_px) // 2,
-         (resize_h_px - crop_h_px) // 2,
-         (resize_w_px + crop_w_px) // 2,
-         (resize_h_px + crop_h_px) // 2,
-     )
-
-     img = elem.img.resize(resize_size_px).crop(crop_coords_px)
-     img = saev.imaging.add_highlights(
-         img, elem.patches.numpy(), upper=upper, opacity=0.5
-     )
-     return img
-
-
  with gr.Blocks() as demo:
      image_number = gr.Number(label="Validation Example")

 
app.py (new side of the diff)

@@ -2,7 +2,7 @@ import functools
  import io
  import json
  import logging
+ import math
  import pathlib
  import typing

@@ -10,17 +10,19 @@ import beartype
  import einops
  import einops.layers.torch
  import gradio as gr
+ import numpy as np
  import saev.activations
  import saev.config
  import saev.nn
  import saev.visuals
  import torch
+ from jaxtyping import Bool, Float, Int, UInt8, jaxtyped
+ from PIL import Image, ImageDraw
  from torch import Tensor

  import constants
  import data
+ import modeling

  logger = logging.getLogger("app.py")

@@ -29,33 +31,26 @@ logger = logging.getLogger("app.py")
  ####################


+ MAX_FREQ = 1e-2
  """Maximum frequency. Any feature that fires more than this is ignored."""

  RESIZE_SIZE = 512
  """Resize shorter size to this size in pixels."""

  CROP_SIZE = (448, 448)
  """Crop size in pixels."""

+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
  """Hardware accelerator, if any."""

  CWD = pathlib.Path(".")
  """Current working directory."""

+ N_SAE_LATENTS = 3
+ """Number of SAE latents to show."""
+
+ N_LATENT_EXAMPLES = 4
+ """Number of examples per SAE latent to show."""

  ##########
  # Models #

@@ -63,27 +58,7 @@ CWD = pathlib.Path(".")


  @functools.cache
+ def load_sae(device: str) -> saev.nn.SparseAutoencoder:
      """
      Loads a sparse autoencoder from disk.
      """

@@ -102,37 +77,12 @@ def load_clf() -> torch.nn.Module:
      buffer = io.BytesIO(fd.read())

      model = torch.nn.Linear(**kwargs)
+     state_dict = torch.load(buffer, weights_only=True, map_location=DEVICE)
      model.load_state_dict(state_dict)
+     model = model.to(DEVICE).eval()
      return model


  ####################
  # Global Variables #
  ####################

@@ -143,13 +93,23 @@ def load_tensor(path: str | pathlib.Path) -> Tensor:
      return torch.load(path, weights_only=True, map_location="cpu")


+ @functools.cache
+ def load_tensors() -> tuple[
+     Int[Tensor, "d_sae k"],
+     UInt8[Tensor, "d_sae k n_patches"],
+     Bool[Tensor, " d_sae"],
+ ]:
+     """
+     Loads the tensors for the SAE for ADE20K.
+     """
+     top_img_i = load_tensor(CWD / "assets" / "top_img_i.pt")
+     top_values = load_tensor(CWD / "assets" / "top_values_uint8.pt")
+     sparsity = load_tensor(CWD / "assets" / "sparsity.pt")
+
+     mask = torch.ones(sparsity.shape, dtype=bool)
+     mask = mask & (sparsity < MAX_FREQ)
+
+     return top_img_i, top_values, mask


  ############
 
@@ -157,37 +117,42 @@ def load_tensor(path: str | pathlib.Path) -> Tensor:
  ############


+ @jaxtyped(typechecker=beartype.beartype)
+ def add_highlights(
+     img: Image.Image,
+     patches: Float[np.ndarray, " n_patches"],
+     *,
+     upper: int | None = None,
+     opacity: float = 0.9,
+ ) -> Image.Image:
+     if not len(patches):
+         return img
+
+     iw_np, ih_np = int(math.sqrt(len(patches))), int(math.sqrt(len(patches)))
+     iw_px, ih_px = img.size
+     pw_px, ph_px = iw_px // iw_np, ih_px // ih_np
+     assert iw_np * ih_np == len(patches)
+
+     # Create a transparent overlay
+     overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
+     draw = ImageDraw.Draw(overlay)
+
+     # Using semi-transparent red (255, 0, 0, alpha)
+     for p, val in enumerate(patches):
+         assert upper is not None
+         val /= upper + 1e-9
+         x_np, y_np = p % iw_np, p // ih_np
+         draw.rectangle(
+             [
+                 (x_np * pw_px, y_np * ph_px),
+                 (x_np * pw_px + pw_px, y_np * ph_px + ph_px),
+             ],
+             fill=(int(val * 256), 0, 0, int(opacity * val * 256)),
+         )

+     # Composite the original image and the overlay
+     return Image.alpha_composite(img.convert("RGBA"), overlay)


  #######################
 
@@ -202,12 +167,14 @@ class Example(typing.TypedDict):
      Used to store examples of SAE latent activations for visualization.
      """

+     index: int
+     """Dataset index."""
      orig_url: str
      """The URL or path to access the original example image."""
      highlighted_url: str
      """The URL or path to access the SAE-highlighted image."""
+     seg_url: str
+     """Base64-encoded version of the colored segmentation map."""


  @beartype.beartype
 
@@ -249,64 +216,73 @@ def get_sae_activations(image_i: int, patches: list[int]) -> list[SaeActivation]
      if not patches:
          return []

+     split_vit, vit_transform = modeling.load_vit(DEVICE)
+     sae = load_sae(DEVICE)

      img = data.get_image(image_i)

+     x_BCWH = vit_transform(img)[None, ...].to(DEVICE)

+     x_BPD = split_vit.forward_start(x_BCWH)
+     x_BPD = (
+         x_BPD.clamp(-1e-5, 1e5) - (constants.DINOV2_IMAGENET1K_MEAN).to(DEVICE)
      ) / constants.DINOV2_IMAGENET1K_SCALAR

+     # Need to pick out the right patches
+     # + 1 + 4 for 1 [CLS] token and 4 register tokens
+     x_PD = x_BPD[0, [p + 1 + 4 for p in patches]]
+     _, f_x_PS, _ = sae(x_PD)

+     f_x_S = einops.reduce(f_x_PS, "patches n_latents -> n_latents", "sum")
+     logger.info("Got SAE activations.")

+     top_img_i, top_values, mask = load_tensors()

      latents = torch.argsort(f_x_S, descending=True).cpu()
+     latents = latents[mask[latents]][:N_SAE_LATENTS].tolist()

+     sae_activations = []
      for latent in latents:
+         pairs, seen_i_im = [], set()
          for i_im, values_p in zip(top_img_i[latent].tolist(), top_values[latent]):
              if i_im in seen_i_im:
                  continue

+             pairs.append((i_im, values_p))
              seen_i_im.add(i_im)
+             if len(pairs) >= N_LATENT_EXAMPLES:
+                 break

          # How to scale values.
          upper = None
          if top_values[latent].numel() > 0:
              upper = top_values[latent].max().item()

+         examples = []
+         for i_im, values_p in pairs:
+             seg_sized = data.to_sized(data.get_seg(i_im))
+             img_sized = data.to_sized(data.get_image(i_im))
+
+             seg_u8_sized = data.to_u8(seg_sized)
+             seg_img_sized = data.u8_to_img(seg_u8_sized)
+
+             highlighted_sized = add_highlights(
+                 img_sized, values_p.float().numpy(), upper=upper
+             )
+
+             examples.append({
+                 "index": i_im,
+                 "orig_url": data.img_to_base64(img_sized),
+                 "highlighted_url": data.img_to_base64(highlighted_sized),
+                 "seg_url": data.img_to_base64(seg_img_sized),
+             })

+         sae_activations.append({
+             "latent": latent,
+             "examples": examples,
+         })

+     return sae_activations


  @torch.inference_mode
 
@@ -416,29 +392,6 @@ def upsample(
      )


  with gr.Blocks() as demo:
      image_number = gr.Number(label="Validation Example")
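To sanity-check the new highlighting path in isolation, here is a minimal sketch that drives add_highlights directly. It assumes app.py can be imported as a module and that the 448x448 crop corresponds to a 16x16 grid of patches (256 activation values); the random values and the output filename are illustrative only, not part of the commit.

    import numpy as np
    from PIL import Image

    from app import add_highlights  # assumes app.py is on the import path

    img = Image.new("RGB", (448, 448), color="gray")
    values = np.random.rand(256).astype(np.float32)  # one activation per patch
    highlighted = add_highlights(img, values, upper=1)
    highlighted.convert("RGB").save("highlighted_demo.png")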
 
modeling.py ADDED
@@ -0,0 +1,53 @@
+ import functools
+ import logging
+ import typing
+
+ import beartype
+ import torch
+ from jaxtyping import Float, jaxtyped
+ from torch import Tensor
+ from torchvision.transforms import v2
+
+ logger = logging.getLogger("modeling.py")
+
+
+ @jaxtyped(typechecker=beartype.beartype)
+ class SplitDinov2(torch.nn.Module):
+     def __init__(self, *, split_at: int):
+         super().__init__()
+
+         self.vit = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14_reg").eval()
+         self.split_at = split_at
+
+     def forward_start(
+         self, x: Float[Tensor, "batch channels width height"]
+     ) -> Float[Tensor, "batch patches dim"]:
+         x_BPD = self.vit.prepare_tokens_with_masks(x)
+         for blk in self.vit.blocks[: self.split_at]:
+             x_BPD = blk(x_BPD)
+
+         return x_BPD
+
+     def forward_end(
+         self, x_BPD: Float[Tensor, "batch n_patches dim"]
+     ) -> Float[Tensor, "batch patches dim"]:
+         for blk in self.vit.blocks[-self.split_at :]:
+             x_BPD = blk(x_BPD)
+
+         x_BPD = self.vit.norm(x_BPD)
+         return x_BPD[:, self.vit.num_register_tokens + 1 :]
+
+
+ @functools.cache
+ def load_vit(device: str) -> tuple[SplitDinov2, typing.Callable]:
+     vit = SplitDinov2(split_at=11).to(device)
+     vit_transform = v2.Compose([
+         v2.Resize(size=(256, 256)),
+         v2.CenterCrop(size=(224, 224)),
+         v2.ToImage(),
+         v2.ToDtype(torch.float32, scale=True),
+         v2.Normalize(mean=[0.4850, 0.4560, 0.4060], std=[0.2290, 0.2240, 0.2250]),
+     ])
+     logger.info("Loaded ViT.")
+
+     return vit, vit_transform
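A minimal sketch of how modeling.load_vit and SplitDinov2.forward_start are meant to be wired together, mirroring the first half of get_sae_activations in app.py. The placeholder image, the local device variable, and the printed shape are illustrative assumptions, not part of the commit.

    import torch
    from PIL import Image

    import modeling

    device = "cuda" if torch.cuda.is_available() else "cpu"
    split_vit, vit_transform = modeling.load_vit(device)

    # A solid 256x256 placeholder stands in for data.get_image(i).
    img = Image.new("RGB", (256, 256), color="gray")
    x_BCWH = vit_transform(img)[None, ...].to(device)

    with torch.inference_mode():
        x_BPD = split_vit.forward_start(x_BCWH)

    # 1 [CLS] + 4 register tokens + 256 patches for dinov2_vitb14_reg at 224x224.
    print(x_BPD.shape)  # expected: (1, 261, 768)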