import functools
import logging
import typing

import beartype
import torch
from jaxtyping import Float, jaxtyped
from torch import Tensor
from torchvision.transforms import v2

logger = logging.getLogger("modeling.py")


@jaxtyped(typechecker=beartype.beartype)
class SplitDinov2(torch.nn.Module):
    """DINOv2 ViT-B/14 (with registers) split in two at block `split_at`.

    `forward_start` runs the first `split_at` transformer blocks and
    `forward_end` runs the rest, so composing the two reproduces the full
    forward pass. This lets callers record or edit activations at the split
    point (e.g. with an SAE) before resuming.
    """

    def __init__(self, *, split_at: int):
        super().__init__()
        self.vit = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14_reg").eval()
        self.split_at = split_at

    def forward_start(
        self, x: Float[Tensor, "batch channels width height"]
    ) -> Float[Tensor, "batch patches dim"]:
        # Embed the image into tokens (CLS + register + patch tokens), then
        # run the first `split_at` transformer blocks.
        x_BPD = self.vit.prepare_tokens_with_masks(x)
        for blk in self.vit.blocks[: self.split_at]:
            x_BPD = blk(x_BPD)
        return x_BPD

    def forward_end(
        self, x_BPD: Float[Tensor, "batch n_patches dim"]
    ) -> Float[Tensor, "batch patches dim"]:
        # Run the remaining blocks: blocks[self.split_at:] is the complement
        # of forward_start's blocks[:self.split_at].
        for blk in self.vit.blocks[self.split_at :]:
            x_BPD = blk(x_BPD)
        x_BPD = self.vit.norm(x_BPD)
        # Drop the CLS token and register tokens, keeping only patch tokens.
        return x_BPD[:, self.vit.num_register_tokens + 1 :]
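

@torch.no_grad()
def _check_split(vit: SplitDinov2, x: Float[Tensor, "batch channels width height"]):
    # Illustrative sanity check (a sketch, not part of the original file):
    # composing the two halves should reproduce the underlying model's own
    # forward pass. Assumes the hub model's `forward_features` returns a dict
    # with an "x_norm_patchtokens" entry, as in facebookresearch/dinov2.
    expected = vit.vit.forward_features(x)["x_norm_patchtokens"]
    actual = vit.forward_end(vit.forward_start(x))
    torch.testing.assert_close(actual, expected)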


@functools.cache
def load_vit(device: str) -> tuple[SplitDinov2, typing.Callable]:
    """Load (and memoize, per device) the split ViT and its preprocessing transform."""
    vit = SplitDinov2(split_at=11).to(device)
    # Resize, center-crop to 224x224, convert to float, and normalize with
    # ImageNet statistics.
    vit_transform = v2.Compose([
        v2.Resize(size=(256, 256)),
        v2.CenterCrop(size=(224, 224)),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.4850, 0.4560, 0.4060], std=[0.2290, 0.2240, 0.2250]),
    ])
    logger.info("Loaded ViT.")
    return vit, vit_transform
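

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not in the original file). A random
    # tensor stands in for a transformed image batch; with the 224x224 crop
    # above, ViT-B/14 yields 16 x 16 = 256 patch tokens of width 768.
    vit, vit_transform = load_vit("cpu")
    x = torch.rand(1, 3, 224, 224)
    with torch.no_grad():
        acts_BPD = vit.forward_start(x)  # tokens at the split point
        out_BPD = vit.forward_end(acts_BPD)  # normalized patch tokens
    print(acts_BPD.shape)  # (1, 261, 768): CLS + 4 register + 256 patch tokens
    print(out_BPD.shape)  # (1, 256, 768)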