from transformers import PretrainedConfig, PreTrainedModel, BertModel, BertConfig
import timm
import torch.nn as nn
import torch
import numpy
from torchvision import transforms
from PIL import Image


class RenameLayerScale(nn.Module):
    """Drop-in replacement for timm's LayerScale.

    timm stores the scale parameter as ``gamma``, but Hugging Face's
    checkpoint loader renames state-dict keys containing "gamma" to
    "weight", so the parameter is named ``weight`` from the start.
    """

    def __init__(
        self,
        dim: int,
        init_values: float = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.weight = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.mul_(self.weight) if self.inplace else x * self.weight


# Patch timm so every ViT created below uses the renamed LayerScale.
timm.models.vision_transformer.LayerScale = RenameLayerScale


class KEEPConfig(PretrainedConfig):
    model_type = "keep"

    def __init__(
        self,
        vision_config=None,  # kwargs for the timm vision encoder
        text_config=None,    # kwargs for the BERT text encoder
        projection_dim=768,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vision_config = vision_config
        self.text_config = text_config
        self.projection_dim = projection_dim


class KEEPModel(PreTrainedModel):
    config_class = KEEPConfig

    def __init__(self, config):
        super().__init__(config)

        # Vision encoder: ViT-L/16 backbone from timm
        vision_config = config.vision_config
        self.visual = timm.create_model(
            "vit_large_patch16_224",
            pretrained=False,
            img_size=vision_config["img_size"],
            patch_size=vision_config["patch_size"],
            init_values=vision_config["init_values"],
            num_classes=vision_config["num_classes"],
        )
        # Projection head mapping ViT features into the shared embedding space
        self.visual_head = nn.Sequential(
            nn.Linear(self.visual.num_features, config.projection_dim),
            nn.GELU(),
            nn.Linear(config.projection_dim, config.projection_dim),
        )

        # Text encoder: BERT built from the provided config
        text_config = BertConfig(**config.text_config)
        self.text = BertModel(text_config)

        # CLIP-style learnable temperature, initialized to log(1 / 0.04)
        self.logit_scale = nn.Parameter(torch.ones([]) * numpy.log(1 / 0.04))

    def encode_image(self, image_inputs):
        vision_features = self.visual(image_inputs)  # [batch_size, vision_dim]
        vision_features = torch.nn.functional.normalize(
            self.visual_head(vision_features), dim=-1
        )  # [batch_size, projection_dim]
        return vision_features

    def encode_text(self, text_inputs):
        text_features = torch.nn.functional.normalize(
            self.text(**text_inputs).pooler_output, dim=-1
        )  # [batch_size, text_dim]
        return text_features

    def forward(self, image_inputs, text_inputs):
        vision_features = self.encode_image(image_inputs)
        text_features = self.encode_text(text_inputs)

        return {
            "vision_features": vision_features,
            "text_features": text_features,
        }
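

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the model definition). It builds a
# randomly initialized KEEPModel from a minimal config and runs one dummy
# image/text pair through it. The config values, the "bert-base-uncased"
# tokenizer, and the ImageNet normalization statistics are assumptions for
# demonstration only; in practice, load the released KEEP checkpoint and use
# the preprocessing it ships with.
if __name__ == "__main__":
    from transformers import BertTokenizer

    config = KEEPConfig(
        vision_config={
            "img_size": 224,
            "patch_size": 16,
            "init_values": 1e-5,
            "num_classes": 0,  # 0 makes timm return pooled features, no classifier
        },
        text_config={"vocab_size": 30522},  # other fields fall back to BertConfig defaults
        projection_dim=768,
    )
    model = KEEPModel(config).eval()

    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])
    image_inputs = preprocess(Image.new("RGB", (256, 256))).unsqueeze(0)  # dummy image

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    text_inputs = tokenizer(["a photomicrograph of tumor tissue"],
                            return_tensors="pt", padding=True)

    with torch.no_grad():
        outputs = model(image_inputs, text_inputs)
        # CLIP-style image-text similarity, scaled by the learnable temperature
        logits = model.logit_scale.exp() * (
            outputs["vision_features"] @ outputs["text_features"].t()
        )
    print(logits.shape)  # torch.Size([1, 1])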