Image Feature Extraction
Transformers
Safetensors
dinov2
rad-dino / src /rad_dino /__init__.py
fepegar's picture
Refactor docs (#10)
f50d9fb verified
import torch
from einops import rearrange
from jaxtyping import Float
from PIL import Image
from torch import Tensor
from torch import nn
from transformers import AutoImageProcessor
from transformers import AutoModel
from transformers.feature_extraction_utils import BatchFeature
# Package version string.
__version__ = "0.1.0"
# jaxtyping shape annotations for the tensors handled by RadDino.
# One embedding vector per image in the batch.
TypeClsToken = Float[Tensor, "batch_size embed_dim"]
# Patch embeddings with the spatial grid flattened into a single token axis.
TypePatchTokensFlat = Float[Tensor, "batch_size (height width) embed_dim"]
# Patch embeddings laid out as a channels-first 2D grid.
TypePatchTokens = Float[Tensor, "batch_size embed_dim height width"]
# Either a single PIL image or a list of PIL images.
TypeInputImages = Image.Image | list[Image.Image]
class RadDino(nn.Module):
    """Wrapper pairing the ``microsoft/rad-dino`` model with its image processor.

    The class loads both objects from the same repository identifier and
    exposes helpers to preprocess PIL images, run the model, and return the
    CLS token together with the patch tokens arranged as a 2D grid.

    Note that :meth:`extract_features` (and therefore :meth:`forward`,
    :meth:`extract_cls_token` and :meth:`extract_patch_tokens`) runs under
    ``torch.inference_mode()``.
    """

    # Repository identifier passed to both ``from_pretrained`` calls.
    _REPO = "microsoft/rad-dino"

    def __init__(self):
        """Load the model (switched to eval mode) and the image processor."""
        super().__init__()
        backbone = AutoModel.from_pretrained(self._REPO)
        self.model = backbone.eval()
        self.processor = AutoImageProcessor.from_pretrained(
            self._REPO,
            use_fast=False,
        )

    @property
    def device(self) -> torch.device:
        """Device of the first parameter of the wrapped model."""
        first_parameter = next(self.model.parameters())
        return first_parameter.device

    def preprocess(self, image_or_images: TypeInputImages) -> BatchFeature:
        """Run the image processor and request PyTorch tensors.

        Args:
            image_or_images: A single PIL image or a list of PIL images.

        Returns:
            The processor output, created with ``return_tensors="pt"``.
        """
        features = self.processor(image_or_images, return_tensors="pt")
        return features

    def encode(self, inputs: BatchFeature) -> tuple[TypeClsToken, TypePatchTokensFlat]:
        """Run the model and split its last hidden state into CLS and patch tokens.

        Args:
            inputs: Preprocessed inputs, unpacked as keyword arguments to the model.

        Returns:
            A tuple ``(cls_token, patch_tokens)``: the token at sequence index 0
            and all remaining tokens, in their original (flat) order.
        """
        hidden_states = self.model(**inputs).last_hidden_state
        # Sequence position 0 is treated as the CLS token; the rest are patches.
        return hidden_states[:, 0], hidden_states[:, 1:]

    def reshape_patch_tokens(
        self,
        patch_tokens_flat: TypePatchTokensFlat,
    ) -> TypePatchTokens:
        """Fold flat patch tokens into a channels-first 2D grid.

        The grid height is the processor's crop height divided (integer
        division) by the model's patch size; the grid width is left for
        ``rearrange`` to infer from the number of tokens.

        Args:
            patch_tokens_flat: Patch tokens with a single flattened token axis.

        Returns:
            The same tokens rearranged to ``batch embed_dim height width``.
        """
        grid_height = self.processor.crop_size["height"] // self.model.config.patch_size
        return rearrange(
            patch_tokens_flat,
            "batch (height width) embed_dim -> batch embed_dim height width",
            height=grid_height,
        )

    @torch.inference_mode()
    def extract_features(
        self,
        image_or_images: TypeInputImages,
    ) -> tuple[TypeClsToken, TypePatchTokens]:
        """Preprocess images, encode them, and return CLS and gridded patch tokens.

        Args:
            image_or_images: A single PIL image or a list of PIL images.

        Returns:
            A tuple ``(cls_token, patch_tokens)`` where ``patch_tokens`` has
            been reshaped by :meth:`reshape_patch_tokens`.
        """
        batch = self.preprocess(image_or_images)
        # Move the preprocessed tensors to wherever the model lives.
        batch_on_device = batch.to(self.device)
        cls_token, flat_patches = self.encode(batch_on_device)
        return cls_token, self.reshape_patch_tokens(flat_patches)

    def extract_cls_token(self, image_or_images: TypeInputImages) -> TypeClsToken:
        """Return only the CLS token produced by :meth:`extract_features`."""
        features = self.extract_features(image_or_images)
        return features[0]

    def extract_patch_tokens(self, image_or_images: TypeInputImages) -> TypePatchTokens:
        """Return only the gridded patch tokens produced by :meth:`extract_features`."""
        features = self.extract_features(image_or_images)
        return features[1]

    def forward(self, *args) -> tuple[TypeClsToken, TypePatchTokens]:
        """Forward positional arguments unchanged to :meth:`extract_features`."""
        return self.extract_features(*args)