Samuel Stevens committed · Commit 699b9c3 · 1 Parent(s): 0ab58fa

bug: SAE examples are not highlighted

Changed files:
- app.py (+110, -157)
- modeling.py (+53, -0)
app.py
CHANGED
@@ -2,7 +2,7 @@ import functools
 import io
 import json
 import logging
-import …
+import math
 import pathlib
 import typing

@@ -10,17 +10,19 @@ import beartype
 import einops
 import einops.layers.torch
 import gradio as gr
+import numpy as np
 import saev.activations
 import saev.config
 import saev.nn
 import saev.visuals
 import torch
-from jaxtyping import Float, Int, UInt8, jaxtyped
-from PIL import Image
+from jaxtyping import Bool, Float, Int, UInt8, jaxtyped
+from PIL import Image, ImageDraw
 from torch import Tensor

 import constants
 import data
+import modeling

 logger = logging.getLogger("app.py")

@@ -29,33 +31,26 @@ logger = logging.getLogger("app.py")
 ####################


-…
-"""Whether we are debugging."""
-
-max_frequency = 1e-2
+MAX_FREQ = 1e-2
 """Maximum frequency. Any feature that fires more than this is ignored."""

-n_sae_latents = 3
-"""Number of SAE latents to show."""
-
-n_sae_examples = 4
-"""Number of SAE examples per latent to show."""
-
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-"""Hardware accelerator, if any."""
-
 RESIZE_SIZE = 512
 """Resize shorter size to this size in pixels."""

 CROP_SIZE = (448, 448)
 """Crop size in pixels."""

-DEVICE = …
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 """Hardware accelerator, if any."""

 CWD = pathlib.Path(".")
 """Current working directory."""

+N_SAE_LATENTS = 3
+"""Number of SAE latents to show."""
+
+N_LATENT_EXAMPLES = 4
+"""Number of examples per SAE latent to show."""

 ##########
 # Models #
@@ -63,27 +58,7 @@ CWD = pathlib.Path(".")


 @functools.cache
-def …
-    vit = (
-        saev.activations.WrappedVisionTransformer(
-            saev.config.Activations(
-                model_family="dinov2",
-                model_ckpt="dinov2_vitb14_reg",
-                layers=[-2],
-                n_patches_per_img=256,
-            )
-        )
-        .to(DEVICE)
-        .eval()
-    )
-    vit_transform = saev.activations.make_img_transform("dinov2", "dinov2_vitb14_reg")
-    logger.info("Loaded ViT.")
-
-    return vit, vit_transform
-
-
-@functools.cache
-def load_sae() -> saev.nn.SparseAutoencoder:
+def load_sae(device: str) -> saev.nn.SparseAutoencoder:
     """
     Loads a sparse autoencoder from disk.
     """
@@ -102,37 +77,12 @@ def load_clf() -> torch.nn.Module:
         buffer = io.BytesIO(fd.read())

     model = torch.nn.Linear(**kwargs)
-    state_dict = torch.load(buffer, weights_only=True, map_location=…
+    state_dict = torch.load(buffer, weights_only=True, map_location=DEVICE)
     model.load_state_dict(state_dict)
-    model = model.to(…
+    model = model.to(DEVICE).eval()
     return model


-class RestOfDinoV2(torch.nn.Module):
-    def __init__(self, *, n_end_layers: int):
-        super().__init__()
-        self.vit = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14_reg")
-        self.n_end_layers = n_end_layers
-
-    def forward_start(self, x: Float[Tensor, "batch channels width height"]):
-        x_BPD = self.vit.prepare_tokens_with_masks(x)
-        for blk in self.vit.blocks[: -self.n_end_layers]:
-            x_BPD = blk(x_BPD)
-
-        return x_BPD
-
-    def forward_end(self, x_BPD: Float[Tensor, "batch n_patches dim"]):
-        for blk in self.vit.blocks[-self.n_end_layers :]:
-            x_BPD = blk(x_BPD)
-
-        x_BPD = self.vit.norm(x_BPD)
-        return x_BPD[:, self.vit.num_register_tokens + 1 :]
-
-
-rest_of_vit = RestOfDinoV2(n_end_layers=1)
-rest_of_vit = rest_of_vit.to(device)
-
-
 ####################
 # Global Variables #
 ####################
@@ -143,13 +93,23 @@ def load_tensor(path: str | pathlib.Path) -> Tensor:
     return torch.load(path, weights_only=True, map_location="cpu")


-…
-…
-…
+@functools.cache
+def load_tensors() -> tuple[
+    Int[Tensor, "d_sae k"],
+    UInt8[Tensor, "d_sae k n_patches"],
+    Bool[Tensor, " d_sae"],
+]:
+    """
+    Loads the tensors for the SAE for ADE20K.
+    """
+    top_img_i = load_tensor(CWD / "assets" / "top_img_i.pt")
+    top_values = load_tensor(CWD / "assets" / "top_values_uint8.pt")
+    sparsity = load_tensor(CWD / "assets" / "sparsity.pt")

+    mask = torch.ones(sparsity.shape, dtype=bool)
+    mask = mask & (sparsity < MAX_FREQ)

-…
-# mask = mask & (sparsity < max_frequency)
+    return top_img_i, top_values, mask


 ############
@@ -157,37 +117,42 @@ def load_tensor(path: str | pathlib.Path) -> Tensor:
 ############


-… (18 removed lines)
+@jaxtyped(typechecker=beartype.beartype)
+def add_highlights(
+    img: Image.Image,
+    patches: Float[np.ndarray, " n_patches"],
+    *,
+    upper: int | None = None,
+    opacity: float = 0.9,
+) -> Image.Image:
+    breakpoint()
+    if not len(patches):
+        return img
+
+    iw_np, ih_np = int(math.sqrt(len(patches))), int(math.sqrt(len(patches)))
+    iw_px, ih_px = img.size
+    pw_px, ph_px = iw_px // iw_np, ih_px // ih_np
+    assert iw_np * ih_np == len(patches)
+
+    # Create a transparent overlay
+    overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
+    draw = ImageDraw.Draw(overlay)
+
+    # Using semi-transparent red (255, 0, 0, alpha)
+    for p, val in enumerate(patches):
+        assert upper is not None
+        val /= upper + 1e-9
+        x_np, y_np = p % iw_np, p // ih_np
+        draw.rectangle(
+            [
+                (x_np * pw_px, y_np * ph_px),
+                (x_np * pw_px + pw_px, y_np * ph_px + ph_px),
+            ],
+            fill=(int(val * 256), 0, 0, int(opacity * val * 256)),
+        )

-#
-…
-# root="/research/nfs_su_809/workspace/stevens.994/datasets/ade20k/"
-# ),
-# img_transform=v2.Compose([
-# v2.Resize(size=(256, 256)),
-# v2.CenterCrop(size=(224, 224)),
-# v2.ToImage(),
-# v2.ToDtype(torch.float32, scale=True),
-# v2.Normalize(mean=[0.4850, 0.4560, 0.4060], std=[0.2290, 0.2240, 0.2250]),
-# ]),
-# )
+    # Composite the original image and the overlay
+    return Image.alpha_composite(img.convert("RGBA"), overlay)


 #######################
@@ -202,12 +167,14 @@ class Example(typing.TypedDict):
     Used to store examples of SAE latent activations for visualization.
     """

+    index: int
+    """Dataset index."""
     orig_url: str
     """The URL or path to access the original example image."""
     highlighted_url: str
    """The URL or path to access the SAE-highlighted image."""
-    …
-    """
+    seg_url: str
+    """Base64-encoded version of the colored segmentation map."""


 @beartype.beartype
@@ -249,64 +216,73 @@ def get_sae_activations(image_i: int, patches: list[int]) -> list[SaeActivation]
     if not patches:
         return []

-…
-    sae = load_sae()
+    split_vit, vit_transform = modeling.load_vit(DEVICE)
+    sae = load_sae(DEVICE)

     img = data.get_image(image_i)

-…
+    x_BCWH = vit_transform(img)[None, ...].to(DEVICE)

-…
-…
-…
-        - (constants.DINOV2_IMAGENET1K_MEAN).to(DEVICE)
+    x_BPD = split_vit.forward_start(x_BCWH)
+    x_BPD = (
+        x_BPD.clamp(-1e-5, 1e5) - (constants.DINOV2_IMAGENET1K_MEAN).to(DEVICE)
     ) / constants.DINOV2_IMAGENET1K_SCALAR

-…
-    #
-…
-…
-…
-    top_img_i, top_values = load_tensors(model_cfg)
-    logger.info("Loaded top SAE activations for '%s'.", model_name)
+    # Need to pick out the right patches
+    # + 1 + 4 for 1 [CLS] token and 4 register tokens
+    x_PD = x_BPD[0, [p + 1 + 4 for p in patches]]
+    _, f_x_PS, _ = sae(x_PD)

-…
-…
-        for i in patches
-    ]).to(device)
+    f_x_S = einops.reduce(f_x_PS, "patches n_latents -> n_latents", "sum")
+    logger.info("Got SAE activations.")

-…
-    f_x_S = f_x_MS.sum(axis=0)
+    top_img_i, top_values, mask = load_tensors()

     latents = torch.argsort(f_x_S, descending=True).cpu()
-    latents = latents[mask[latents]][:…
+    latents = latents[mask[latents]][:N_SAE_LATENTS].tolist()

-…
+    sae_activations = []
     for latent in latents:
-…
+        pairs, seen_i_im = [], set()
         for i_im, values_p in zip(top_img_i[latent].tolist(), top_values[latent]):
             if i_im in seen_i_im:
                 continue

-…
-            elems.append(
-                saev.visuals.GridElement(example["image"], example["label"], values_p)
-            )
+            pairs.append((i_im, values_p))
             seen_i_im.add(i_im)
+            if len(pairs) >= N_LATENT_EXAMPLES:
+                break

         # How to scale values.
         upper = None
         if top_values[latent].numel() > 0:
             upper = top_values[latent].max().item()

-…
+        examples = []
+        for i_im, values_p in pairs:
+            seg_sized = data.to_sized(data.get_seg(i_im))
+            img_sized = data.to_sized(data.get_image(i_im))
+
+            seg_u8_sized = data.to_u8(seg_sized)
+            seg_img_sized = data.u8_to_img(seg_u8_sized)

-…
-…
+            highlighted_sized = add_highlights(
+                img_sized, values_p.float().numpy(), upper=upper
+            )
+
+            examples.append({
+                "index": i_im,
+                "orig_url": data.img_to_base64(img_sized),
+                "highlighted_url": data.img_to_base64(highlighted_sized),
+                "seg_url": data.img_to_base64(seg_img_sized),
+            })

-…
+        sae_activations.append({
+            "latent": latent,
+            "examples": examples,
+        })

-    return …
+    return sae_activations


 @torch.inference_mode
@@ -416,29 +392,6 @@ def upsample(
     )


-@beartype.beartype
-def make_img(
-    elem: saev.visuals.GridElement, *, upper: float | None = None
-) -> Image.Image:
-    # Resize to 256x256 and crop to 224x224
-    resize_size_px = (512, 512)
-    resize_w_px, resize_h_px = resize_size_px
-    crop_size_px = (448, 448)
-    crop_w_px, crop_h_px = crop_size_px
-    crop_coords_px = (
-        (resize_w_px - crop_w_px) // 2,
-        (resize_h_px - crop_h_px) // 2,
-        (resize_w_px + crop_w_px) // 2,
-        (resize_h_px + crop_h_px) // 2,
-    )
-
-    img = elem.img.resize(resize_size_px).crop(crop_coords_px)
-    img = saev.imaging.add_highlights(
-        img, elem.patches.numpy(), upper=upper, opacity=0.5
-    )
-    return img
-
-
 with gr.Blocks() as demo:
     image_number = gr.Number(label="Validation Example")

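For context, here is a minimal, self-contained sketch of the overlay technique the new add_highlights uses: each patch's activation sets the opacity of a red rectangle drawn on a transparent layer, which is then alpha-composited onto the image. The helper name overlay_patches, the synthetic image, and the random activations are illustrative only and not part of this commit; note also that the committed function still contains a breakpoint() call and asserts that upper is provided.

    import math

    import numpy as np
    from PIL import Image, ImageDraw


    def overlay_patches(
        img: Image.Image, acts: np.ndarray, upper: float, opacity: float = 0.9
    ) -> Image.Image:
        n = int(math.sqrt(len(acts)))              # assume a square n x n patch grid
        pw, ph = img.width // n, img.height // n   # pixel size of one patch cell
        overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
        draw = ImageDraw.Draw(overlay)
        for p, val in enumerate(acts):
            val = float(val) / (upper + 1e-9)      # scale activation into [0, 1]
            x, y = p % n, p // n
            draw.rectangle(
                [(x * pw, y * ph), ((x + 1) * pw, (y + 1) * ph)],
                fill=(int(val * 255), 0, 0, int(opacity * val * 255)),
            )
        return Image.alpha_composite(img.convert("RGBA"), overlay)


    img = Image.new("RGB", (448, 448), (200, 200, 200))  # stand-in for a dataset image
    acts = np.random.rand(16 * 16)                       # fake activations, 16 x 16 grid
    overlay_patches(img, acts, upper=1.0).save("highlighted.png")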
modeling.py
ADDED
@@ -0,0 +1,53 @@
+import functools
+import logging
+import typing
+
+import beartype
+import torch
+from jaxtyping import Float, jaxtyped
+from torch import Tensor
+from torchvision.transforms import v2
+
+logger = logging.getLogger("modeling.py")
+
+
+@jaxtyped(typechecker=beartype.beartype)
+class SplitDinov2(torch.nn.Module):
+    def __init__(self, *, split_at: int):
+        super().__init__()
+
+        self.vit = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14_reg").eval()
+        self.split_at = split_at
+
+    def forward_start(
+        self, x: Float[Tensor, "batch channels width height"]
+    ) -> Float[Tensor, "batch patches dim"]:
+        x_BPD = self.vit.prepare_tokens_with_masks(x)
+        for blk in self.vit.blocks[: self.split_at]:
+            x_BPD = blk(x_BPD)
+
+        return x_BPD
+
+    def forward_end(
+        self, x_BPD: Float[Tensor, "batch n_patches dim"]
+    ) -> Float[Tensor, "batch patches dim"]:
+        for blk in self.vit.blocks[-self.split_at :]:
+            x_BPD = blk(x_BPD)
+
+        x_BPD = self.vit.norm(x_BPD)
+        return x_BPD[:, self.vit.num_register_tokens + 1 :]
+
+
+@functools.cache
+def load_vit(device: str) -> tuple[SplitDinov2, typing.Callable]:
+    vit = SplitDinov2(split_at=11).to(device)
+    vit_transform = v2.Compose([
+        v2.Resize(size=(256, 256)),
+        v2.CenterCrop(size=(224, 224)),
+        v2.ToImage(),
+        v2.ToDtype(torch.float32, scale=True),
+        v2.Normalize(mean=[0.4850, 0.4560, 0.4060], std=[0.2290, 0.2240, 0.2250]),
+    ])
+    logger.info("Loaded ViT.")
+
+    return vit, vit_transform