|
from transformers import PretrainedConfig, PreTrainedModel, BertModel, BertConfig
|
|
import timm
|
|
import torch.nn as nn
|
|
import torch
|
|
import numpy
|
|
from torchvision import transforms
|
|
from PIL import Image
|
|
|
|
class KEEPConfig(PretrainedConfig):
    """Configuration for the KEEP image/text dual encoder.

    Args:
        vision_config: Mapping of keyword overrides for the timm ViT image
            encoder (e.g. ``img_size``, ``patch_size``, ``init_values``,
            ``num_classes``, ``dynamic_img_size``). ``None`` is stored as an
            empty dict, i.e. "no overrides".
        text_config: Mapping of keyword arguments used to build a
            ``BertConfig`` for the text encoder. ``None`` is stored as an
            empty dict, i.e. a default ``BertConfig``.
        projection_dim: Output width of the image projection head.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "keep"

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        projection_dim=768,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # The model reads these with ``.get(...)`` / ``**``; storing ``None``
        # (the default) would make building a model from a bare config crash.
        self.vision_config = {} if vision_config is None else vision_config
        self.text_config = {} if text_config is None else text_config
        self.projection_dim = projection_dim
|
|
|
|
|
|
class KEEPModel(PreTrainedModel):
    """Dual encoder pairing a timm ViT-L/16 image tower with a BERT text tower.

    Both towers return embeddings L2-normalised along the last dimension
    (see ``encode_image`` / ``encode_text``); ``forward`` returns them
    side by side in a dict.
    """

    config_class = KEEPConfig

    def __init__(self, config):
        super().__init__(config)

        # Either sub-config may be ``None``; treat that as "use the defaults"
        # instead of failing on ``None.get`` / ``**None``.
        vision_config = config.vision_config
        if vision_config is None:
            vision_config = {}

        self.visual = timm.create_model(
            "vit_large_patch16_224",
            pretrained=False,
            img_size=vision_config.get("img_size", 224),
            patch_size=vision_config.get("patch_size", 16),
            init_values=vision_config.get("init_values", 1e-5),
            # NOTE(review): in timm, num_classes=0 presumably removes the
            # classifier so the ViT emits pooled features — confirm.
            num_classes=vision_config.get("num_classes", 0),
            dynamic_img_size=vision_config.get("dynamic_img_size", True),
        )

        # Two-layer MLP mapping ViT features to ``config.projection_dim``.
        self.visual_head = nn.Sequential(
            nn.Linear(self.visual.num_features, config.projection_dim),
            nn.GELU(),
            nn.Linear(config.projection_dim, config.projection_dim),
        )

        self.text = BertModel(self._to_bert_config(config.text_config))

        # Learnable scalar initialised to log(1 / 0.04); presumably the
        # contrastive-loss temperature. It is not used by ``forward``.
        self.logit_scale = nn.Parameter(torch.ones([]) * numpy.log(1 / 0.04))

    @staticmethod
    def _to_bert_config(text_config):
        """Return a ``BertConfig`` from a mapping, a ``BertConfig``, or ``None``."""
        if text_config is None:
            return BertConfig()
        if isinstance(text_config, BertConfig):
            return text_config
        return BertConfig(**text_config)

    def encode_image(self, image_inputs):
        """Embed a batch of images.

        Args:
            image_inputs: Pixel tensor passed straight to the timm ViT
                (presumably ``(batch, 3, H, W)`` — confirm against callers).

        Returns:
            ``visual_head`` output, L2-normalised along the last dimension.
        """
        projected = self.visual_head(self.visual(image_inputs))
        return torch.nn.functional.normalize(projected, dim=-1)

    def encode_text(self, text_inputs):
        """Embed a batch of texts.

        Args:
            text_inputs: Mapping of keyword arguments for ``BertModel``
                (presumably tokenizer output such as ``input_ids`` and
                ``attention_mask`` — confirm against callers).

        Returns:
            BERT ``pooler_output``, L2-normalised along the last dimension.
        """
        # NOTE(review): unlike the image path there is no projection head, so
        # the width here is presumably BERT's hidden size; comparing against
        # image embeddings relies on it equalling ``config.projection_dim``.
        pooled = self.text(**text_inputs).pooler_output
        return torch.nn.functional.normalize(pooled, dim=-1)

    def forward(self, image_inputs, text_inputs):
        """Encode images and texts together.

        Returns:
            ``{"vision_features": ..., "text_features": ...}`` holding the
            outputs of ``encode_image`` and ``encode_text`` respectively.
        """
        vision_features = self.encode_image(image_inputs)
        text_features = self.encode_text(text_inputs)
        return {
            "vision_features": vision_features,
            "text_features": text_features,
        }