"""Image feature extraction with the microsoft/rad-dino encoder, a DINOv2-based
Transformers model (weights distributed as Safetensors)."""

import torch
from einops import rearrange
from jaxtyping import Float
from PIL import Image
from torch import Tensor
from torch import nn
from transformers import AutoImageProcessor
from transformers import AutoModel
from transformers.feature_extraction_utils import BatchFeature


__version__ = "0.1.0"

# Shape annotations (jaxtyping); used here for documentation, not runtime checking.
TypeClsToken = Float[Tensor, "batch_size embed_dim"]
TypePatchTokensFlat = Float[Tensor, "batch_size height*width embed_dim"]
TypePatchTokens = Float[Tensor, "batch_size embed_dim height width"]
TypeInputImages = Image.Image | list[Image.Image]


class RadDino(nn.Module):
    """Thin wrapper around the RAD-DINO encoder exposing CLS and patch tokens."""

    _REPO = "microsoft/rad-dino"

    def __init__(self):
        super().__init__()
        # Load the pretrained encoder in eval mode together with its image processor.
        self.model = AutoModel.from_pretrained(self._REPO).eval()
        self.processor = AutoImageProcessor.from_pretrained(self._REPO, use_fast=False)

    @property
    def device(self) -> torch.device:
        return next(self.model.parameters()).device

    def preprocess(self, image_or_images: TypeInputImages) -> BatchFeature:
        return self.processor(image_or_images, return_tensors="pt")

    def encode(self, inputs: BatchFeature) -> tuple[TypeClsToken, TypePatchTokensFlat]:
        outputs = self.model(**inputs)
        # Token 0 is the CLS token; the remaining tokens correspond to image patches.
        cls_token = outputs.last_hidden_state[:, 0]
        patch_tokens = outputs.last_hidden_state[:, 1:]
        return cls_token, patch_tokens

    def reshape_patch_tokens(
        self,
        patch_tokens_flat: TypePatchTokensFlat,
    ) -> TypePatchTokens:
        # The patch sequence is a flattened square grid whose side is the input size
        # divided by the patch size (e.g. 518 // 14 = 37 with the default processor).
        input_size = self.processor.crop_size["height"]
        patch_size = self.model.config.patch_size
        embeddings_size = input_size // patch_size
        patches_grid = rearrange(
            patch_tokens_flat,
            "batch (height width) embed_dim -> batch embed_dim height width",
            height=embeddings_size,
        )
        return patches_grid

    @torch.inference_mode()
    def extract_features(
        self,
        image_or_images: TypeInputImages,
    ) -> tuple[TypeClsToken, TypePatchTokens]:
        inputs = self.preprocess(image_or_images).to(self.device)
        cls_token, patch_tokens_flat = self.encode(inputs)
        patch_tokens = self.reshape_patch_tokens(patch_tokens_flat)
        return cls_token, patch_tokens

    def extract_cls_token(self, image_or_images: TypeInputImages) -> TypeClsToken:
        cls_token, _ = self.extract_features(image_or_images)
        return cls_token

    def extract_patch_tokens(self, image_or_images: TypeInputImages) -> TypePatchTokens:
        _, patch_tokens = self.extract_features(image_or_images)
        return patch_tokens

    def forward(self, *args) -> tuple[TypeClsToken, TypePatchTokens]:
        return self.extract_features(*args)
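

# Minimal usage sketch, not part of the module above: "chest_xray.png" is a
# hypothetical placeholder path; any RGB image (or list of images) works.
if __name__ == "__main__":
    rad_dino = RadDino()
    image = Image.open("chest_xray.png").convert("RGB")  # placeholder input
    cls_token, patch_tokens = rad_dino.extract_features(image)
    # Expected shapes: cls_token (1, embed_dim), patch_tokens (1, embed_dim, height, width).
    print(cls_token.shape, patch_tokens.shape)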