# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------
from __future__ import annotations
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from .encoder import get_encoder_from_type, get_encoder_output_dim, MultiImageEncoder
from .modules import MLP, MultiTaskModel
from .types import ImageModelOutput


class BaseImageModel(nn.Module, ABC):
    """Abstract class for image models."""

    @abstractmethod
    def forward(self, *args: Any, **kwargs: Any) -> ImageModelOutput:
        raise NotImplementedError

    @abstractmethod
    def get_patchwise_projected_embeddings(self, input_img: torch.Tensor, normalize: bool) -> torch.Tensor:
        raise NotImplementedError


def get_module_device(module: torch.nn.Module) -> torch.device:
    """Return the device on which the module's parameters are stored."""
    device = next(module.parameters()).device  # type: ignore
    assert isinstance(device, torch.device)
    return device


class ImageModel(BaseImageModel):
    """Image encoder module."""

def __init__(self,
img_encoder_type: str,
joint_feature_size: int,
freeze_encoder: bool = False,
pretrained_model_path: Optional[Union[str, Path]] = None,
**downstream_classifier_kwargs: Any):
super().__init__()
        # Instantiate the encoder, projector, and downstream classifier
self.encoder = get_encoder_from_type(img_encoder_type)
self.feature_size = get_encoder_output_dim(self.encoder, device=get_module_device(self.encoder))
self.projector = MLP(input_dim=self.feature_size, output_dim=joint_feature_size,
hidden_dim=joint_feature_size, use_1x1_convs=True)
self.downstream_classifier_kwargs = downstream_classifier_kwargs
self.classifier = self.create_downstream_classifier() if downstream_classifier_kwargs else None
        # Set the training/evaluation mode of the submodules
        self.freeze_encoder = freeze_encoder
        self.train()
        self.image_processor = None  # TODO
if pretrained_model_path is not None:
if not isinstance(pretrained_model_path, (str, Path)):
raise TypeError(f"Expected a string or Path, got {type(pretrained_model_path)}")
state_dict = torch.load(pretrained_model_path, map_location="cpu")
            # Optionally drop the projector weights from the checkpoint before loading:
            # for k in list(state_dict.keys()):
            #     if k.startswith("projector"):
            #         state_dict.pop(k)
self.load_state_dict(state_dict, strict=False)

    def train(self, mode: bool = True) -> Any:
"""Switch the model between training and evaluation modes."""
super().train(mode=mode)
if self.freeze_encoder:
self.encoder.train(mode=False)
self.projector.train(mode=False)
return self
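
    # Illustrative sketch (not library API): with ``freeze_encoder=True``, calling
    # ``.train()`` keeps the encoder and projector in eval mode, so e.g. BatchNorm
    # statistics stay frozen. The encoder name "resnet50" is an assumption; the
    # valid names are defined in the ``.encoder`` module.
    #
    #   model = ImageModel("resnet50", joint_feature_size=128, freeze_encoder=True)
    #   model.train()
    #   assert model.training and not model.encoder.training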

    def forward(self, x: torch.Tensor) -> ImageModelOutput:  # type: ignore[override]
with torch.set_grad_enabled(not self.freeze_encoder):
patch_x, pooled_x = self.encoder(x, return_patch_embeddings=True)
return self.forward_post_encoder(patch_x, pooled_x)

    def forward_post_encoder(self, patch_x: torch.Tensor, pooled_x: torch.Tensor) -> ImageModelOutput:
with torch.set_grad_enabled(not self.freeze_encoder):
projected_patch_embeddings = self.projector(patch_x)
projected_global_embedding = torch.mean(projected_patch_embeddings, dim=(2, 3))
logits = self.classifier(pooled_x) if self.classifier else None
return ImageModelOutput(img_embedding=pooled_x,
patch_embeddings=patch_x,
class_logits=logits,
projected_patch_embeddings=projected_patch_embeddings,
projected_global_embedding=projected_global_embedding)
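
    # Shape sketch for ``forward_post_encoder`` (the example sizes are assumptions,
    # not guaranteed by any particular encoder):
    #   patch_x:                    [B, feature_size, H', W']        e.g. [2, 2048, 15, 15]
    #   projected_patch_embeddings: [B, joint_feature_size, H', W']  (1x1-conv MLP projector)
    #   projected_global_embedding: [B, joint_feature_size]          (mean over H' and W')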

    def create_downstream_classifier(self, **kwargs: Any) -> MultiTaskModel:
"""Create the classification module for the downstream task."""
downstream_classifier_kwargs = kwargs if kwargs else self.downstream_classifier_kwargs
return MultiTaskModel(self.feature_size, **downstream_classifier_kwargs)
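
    # Usage sketch: the classifier kwargs are forwarded to ``MultiTaskModel``; the
    # kwarg names below are assumptions about its signature, not part of this module.
    #
    #   model = ImageModel("resnet50", joint_feature_size=128,
    #                      classifier_hidden_dim=256, num_classes=5, num_tasks=1)
    #   logits = model(images).class_logits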

    @torch.no_grad()
def get_patchwise_projected_embeddings(self, input_img: torch.Tensor, normalize: bool) -> torch.Tensor:
"""Get patch-wise projected embeddings from the CNN model.
:param input_img: input tensor image [B, C, H, W].
:param normalize: If ``True``, the embeddings are L2-normalized.
:returns projected_embeddings: tensor of embeddings in shape [batch, n_patches_h, n_patches_w, feature_size].
"""
assert not self.training, "This function is only implemented for evaluation mode"
outputs = self.forward(input_img)
projected_embeddings = outputs.projected_patch_embeddings.detach() # type: ignore
if normalize:
projected_embeddings = F.normalize(projected_embeddings, dim=1)
projected_embeddings = projected_embeddings.permute([0, 2, 3, 1]) # B D H W -> B H W D (D: Features)
return projected_embeddings
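
    # Usage sketch (illustrative; the encoder name and image size are assumptions):
    #
    #   model = ImageModel("resnet50", joint_feature_size=128)
    #   model.eval()  # required: the method asserts the model is not in training mode
    #   img = torch.rand(1, 3, 480, 480)
    #   emb = model.get_patchwise_projected_embeddings(img, normalize=True)
    #   # emb has shape [1, n_patches_h, n_patches_w, 128]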


class MultiImageModel(ImageModel):
    """Image model that encodes a current image together with an optional previous image."""

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        assert isinstance(self.encoder, MultiImageEncoder), "MultiImageModel only supports MultiImageEncoder"

    def forward(self,  # type: ignore[override]
current_image: torch.Tensor,
previous_image: Optional[torch.Tensor] = None) -> ImageModelOutput:
with torch.set_grad_enabled(not self.freeze_encoder):
patch_x, pooled_x = self.encoder(current_image=current_image,
previous_image=previous_image,
return_patch_embeddings=True)
return self.forward_post_encoder(patch_x, pooled_x)
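

# Minimal smoke-test sketch. The encoder name "resnet50_multi_image" is an
# assumption; the actual registered names live in the ``.encoder`` module.
if __name__ == "__main__":
    model = MultiImageModel(img_encoder_type="resnet50_multi_image", joint_feature_size=128)
    model.eval()
    current = torch.rand(1, 3, 448, 448)
    previous = torch.rand(1, 3, 448, 448)
    with torch.no_grad():
        output = model(current_image=current, previous_image=previous)
    print(output.projected_global_embedding.shape)  # expected: torch.Size([1, 128])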