import torch
import torch.nn as nn
from torchvision import transforms
from transformers import (
    CLIPVisionModelWithProjection,
    CLIPVisionConfig,
    Dinov2Model,
    Dinov2Config,
)


class ImageEncoder(nn.Module):
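    """Wrap a frozen, pretrained vision backbone and expose it as a feature extractor.

    Subclasses are expected to define `MODEL_CLASS`, `MODEL_CONFIG_CLASS`, and the
    normalization statistics `mean` / `std` used by the preprocessing transform.
    """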

    def __init__(
        self,
        version=None,
        config=None,
        use_cls_token=True,
        image_size=224,
        **kwargs,
    ):
        super().__init__()

        # Load pretrained weights from `version`, or build an untrained model from a config dict.
        if config is None:
            self.model = self.MODEL_CLASS.from_pretrained(version)
        else:
            self.model = self.MODEL_CLASS(self.MODEL_CONFIG_CLASS.from_dict(config))
        self.model.eval()
        self.model.requires_grad_(False)

        self.use_cls_token = use_cls_token
        self.image_size = image_size
        # The backbones used here have a 14x14 patch size, so the token grid is
        # (image_size // 14) per side, plus one CLS token when it is kept.
        self.size = image_size // 14
        self.num_patches = (image_size // 14) ** 2
        if self.use_cls_token:
            self.num_patches += 1

        self.transform = transforms.Compose(
            [
                transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True),
                transforms.CenterCrop(image_size),
                transforms.Normalize(
                    mean=self.mean,
                    std=self.std,
                ),
            ]
        )

    def forward(self, image, mask=None, value_range=(-1, 1)):
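        """Encode a batch of images into per-token features.

        Images are expected in `value_range` (default (-1, 1)) and are rescaled to
        [0, 1] before normalization; an optional `mask` zeroes out masked pixels.
        Returns the backbone's last hidden state of shape
        (batch, num_patches, hidden_size) for inputs at the configured image size.
        """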
        if value_range is not None:
            # Rescale the input from `value_range` to [0, 1] before normalization.
            low, high = value_range
            image = (image - low) / (high - low)

        image = image.to(self.model.device, dtype=self.model.dtype)

        if mask is not None:
            # Zero out masked (background) pixels before encoding.
            mask = mask.to(image)
            image = image * mask

        supported_sizes = [518, 530]
        if image.shape[2] not in supported_sizes or image.shape[3] not in supported_sizes:
            print(f'Image shape {image.shape} not supported. Resizing to {self.image_size}x{self.image_size}')
            inputs = self.transform(image)
        else:
            inputs = image
        outputs = self.model(inputs)

        last_hidden_state = outputs.last_hidden_state
        if not self.use_cls_token:
            # Drop the CLS token so only patch tokens are returned.
            last_hidden_state = last_hidden_state[:, 1:, :]

        return last_hidden_state

    def unconditional_embedding(self, batch_size):
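        """Return all-zero embeddings of shape (batch_size, num_patches, hidden_size),
        serving as the image-free counterpart of `forward`."""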
        device = next(self.model.parameters()).device
        dtype = next(self.model.parameters()).dtype
        zero = torch.zeros(
            batch_size,
            self.num_patches,
            self.model.config.hidden_size,
            device=device,
            dtype=dtype,
        )

        return zero


class CLIPImageEncoder(ImageEncoder):
    MODEL_CLASS = CLIPVisionModelWithProjection
    MODEL_CONFIG_CLASS = CLIPVisionConfig
    mean = [0.48145466, 0.4578275, 0.40821073]
    std = [0.26862954, 0.26130258, 0.27577711]


class DinoImageEncoder(ImageEncoder):
    MODEL_CLASS = Dinov2Model
    MODEL_CONFIG_CLASS = Dinov2Config
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]


def build_image_encoder(config):
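    """Instantiate an encoder from a config dict of the form
    {'type': <encoder class name>, 'kwargs': {...}}."""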
    if config['type'] == 'CLIPImageEncoder':
        return CLIPImageEncoder(**config['kwargs'])
    elif config['type'] == 'DinoImageEncoder':
        return DinoImageEncoder(**config['kwargs'])
    else:
        raise ValueError(f'Unknown image encoder type: {config["type"]}')


class DualImageEncoder(nn.Module):
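    """Run two image encoders on the same input and return their features under
    the keys 'main' and 'additional'."""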

    def __init__(
        self,
        main_image_encoder,
        additional_image_encoder,
    ):
        super().__init__()
        self.main_image_encoder = build_image_encoder(main_image_encoder)
        self.additional_image_encoder = build_image_encoder(additional_image_encoder)

    def forward(self, image, mask=None):
        outputs = {
            'main': self.main_image_encoder(image, mask=mask),
            'additional': self.additional_image_encoder(image, mask=mask),
        }
        return outputs

    def unconditional_embedding(self, batch_size):
        outputs = {
            'main': self.main_image_encoder.unconditional_embedding(batch_size),
            'additional': self.additional_image_encoder.unconditional_embedding(batch_size),
        }
        return outputs


class SingleImageEncoder(nn.Module):
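    """Wrap a single image encoder and return its features under the key 'main'."""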

    def __init__(
        self,
        main_image_encoder,
    ):
        super().__init__()
        self.main_image_encoder = build_image_encoder(main_image_encoder)

    def forward(self, image, mask=None):
        outputs = {
            'main': self.main_image_encoder(image, mask=mask),
        }
        return outputs

    def unconditional_embedding(self, batch_size):
        outputs = {
            'main': self.main_image_encoder.unconditional_embedding(batch_size),
        }
        return outputs
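

if __name__ == '__main__':
    # Minimal usage sketch (illustrative only): the checkpoint name and image size
    # below are assumptions for demonstration, not values prescribed by this module.
    encoder = build_image_encoder({
        'type': 'DinoImageEncoder',
        'kwargs': {'version': 'facebook/dinov2-giant', 'image_size': 518},
    })
    image = torch.rand(1, 3, 518, 518) * 2 - 1   # values in the default (-1, 1) range
    tokens = encoder(image)                      # (1, encoder.num_patches, hidden_size)
    uncond = encoder.unconditional_embedding(1)  # zeros with the same token shape
    print(tokens.shape, uncond.shape)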