import torch.hub
import requests
from PIL import Image
from torch import nn

from transformers import (
    AutoTokenizer,
    CLIPModel,
    CLIPProcessor,
    CLIPTextConfig,
    CLIPTextModelWithProjection,
    CLIPVisionConfig,
    CLIPVisionModel,
    CLIPVisionModelWithProjection,
    ResNetConfig,
    ResNetModel,
)


class CLIP(nn.Module):
    def __init__(self, path):
        """Initializes the CLIP vision encoder.

        Args:
            path (str): Hugging Face checkpoint to load; if empty, the model
                is randomly initialized from the default CLIPVisionConfig.
        """
        super().__init__()
        if path == "":
            config_vision = CLIPVisionConfig()
            self.clip = CLIPVisionModel(config_vision)
        else:
            self.clip = CLIPVisionModel.from_pretrained(path)

    def forward(self, x):
        """Predicts CLIP features from an image.

        Args:
            x (dict): Input batch; x["img"] is a torch.Tensor of pixel values.

        Returns:
            torch.Tensor: Token features of shape
                (batch, num_patches + 1, hidden_size).
        """
        features = self.clip(pixel_values=x["img"])["last_hidden_state"]
        return features
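
# A minimal usage sketch (not part of the original module). The checkpoint
# name "openai/clip-vit-base-patch32" is an assumption; any CLIP vision
# checkpoint works, and images must already be preprocessed to pixel values:
#
#     model = CLIP("openai/clip-vit-base-patch32")
#     batch = {"img": torch.randn(2, 3, 224, 224)}
#     feats = model(batch)  # (2, 50, 768) for ViT-B/32 at 224x224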


class CLIPJZ(nn.Module):
    def __init__(self, path):
        """Initializes the CLIP vision encoder (same behavior as CLIP above).

        Args:
            path (str): Hugging Face checkpoint to load; if empty, the model
                is randomly initialized from the default CLIPVisionConfig.
        """
        super().__init__()
        if path == "":
            config_vision = CLIPVisionConfig()
            self.clip = CLIPVisionModel(config_vision)
        else:
            self.clip = CLIPVisionModel.from_pretrained(path)

    def forward(self, x):
        """Predicts CLIP features from an image.

        Args:
            x (dict): Input batch; x["img"] is a torch.Tensor of pixel values.
        """
        features = self.clip(pixel_values=x["img"])["last_hidden_state"]
        return features


class StreetCLIP(nn.Module):
    def __init__(self, path):
        """Initializes the StreetCLIP model.

        Args:
            path (str): Hugging Face checkpoint to load.
        """
        super().__init__()
        self.clip = CLIPModel.from_pretrained(path)
        self.transform = CLIPProcessor.from_pretrained(path)

    def forward(self, x):
        """Predicts CLIP image features from an image.

        Args:
            x (dict): Input batch; x["img"] holds the images and x["gps"] is
                a tensor used only to infer the target device.
        """
        # Preprocess inside forward, then move the pixel values to the same
        # device as the rest of the batch before encoding.
        features = self.clip.get_image_features(
            **self.transform(images=x["img"], return_tensors="pt").to(x["gps"].device)
        ).unsqueeze(1)
        return features


class CLIPText(nn.Module):
    def __init__(self, path):
        """Initializes the CLIP vision encoder with a projection head.

        Args:
            path (str): Hugging Face checkpoint to load; if empty, the model
                is randomly initialized from the default CLIPVisionConfig.
        """
        super().__init__()
        if path == "":
            config_vision = CLIPVisionConfig()
            # The projected variant is required here: forward() reads
            # .image_embeds, which plain CLIPVisionModel does not return.
            self.clip = CLIPVisionModelWithProjection(config_vision)
        else:
            self.clip = CLIPVisionModelWithProjection.from_pretrained(path)

    def forward(self, x):
        """Predicts CLIP features from an image.

        Args:
            x (dict): Input batch; x["img"] is a torch.Tensor of pixel values.

        Returns:
            tuple: (image_embeds, last_hidden_state).
        """
        features = self.clip(pixel_values=x["img"])
        return features.image_embeds, features.last_hidden_state
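
# Usage sketch (the checkpoint name is an assumption): CLIPText returns both
# the pooled, projected embedding and the patch tokens:
#
#     model = CLIPText("openai/clip-vit-base-patch32")
#     embeds, tokens = model({"img": torch.randn(2, 3, 224, 224)})
#     # embeds: (2, 512), tokens: (2, 50, 768) for ViT-B/32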


class TextEncoder(nn.Module):
    def __init__(self, path):
        """Initializes the frozen CLIP text encoder.

        Args:
            path (str): Hugging Face checkpoint to load; if empty, the model
                is randomly initialized from the default CLIPTextConfig.
        """
        super().__init__()
        if path == "":
            config_text = CLIPTextConfig()
            self.clip = CLIPTextModelWithProjection(config_text)
            # AutoTokenizer cannot be instantiated directly; fall back to the
            # tokenizer of the checkpoint matching the default CLIPTextConfig
            # (assumed here to be openai/clip-vit-base-patch32).
            self.transform = AutoTokenizer.from_pretrained(
                "openai/clip-vit-base-patch32"
            )
        else:
            self.clip = CLIPTextModelWithProjection.from_pretrained(path)
            self.transform = AutoTokenizer.from_pretrained(path)
        # Freeze the text encoder: no gradients, inference mode.
        for p in self.clip.parameters():
            p.requires_grad = False
        self.clip.eval()

    def forward(self, x):
        """Predicts CLIP features from text.

        Args:
            x (dict): Input batch; x["text"] is a list of strings and
                x["gps"] is a tensor used only to infer the target device.
        """
        features = self.clip(
            **self.transform(x["text"], padding=True, return_tensors="pt").to(
                x["gps"].device
            )
        ).text_embeds
        return features
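
# Usage sketch (the checkpoint name is an assumption): the encoder is frozen,
# so it can be called under torch.no_grad() with a raw list of strings:
#
#     encoder = TextEncoder("openai/clip-vit-base-patch32")
#     batch = {"text": ["a street in Paris"], "gps": torch.zeros(1, 2)}
#     with torch.no_grad():
#         embeds = encoder(batch)  # (1, 512) for ViT-B/32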


class DINOv2(nn.Module):
    def __init__(self, tag) -> None:
        """Initializes the DINOv2 model.

        Args:
            tag (str): torch.hub entry point, e.g. "dinov2_vitb14".
        """
        super().__init__()
        self.dino = torch.hub.load("facebookresearch/dinov2", tag)
        # DINOv2 ViTs use 14x14 patches; inputs must be cropped to a
        # multiple of this stride.
        self.stride = 14

    def forward(self, x):
        """Predicts DINOv2 features from an image."""
        x = x["img"]

        # Crop height and width down to the nearest multiple of the patch
        # size so the image tiles evenly into patches.
        _, _, H, W = x.shape
        H_new = H - H % self.stride
        W_new = W - W % self.stride
        x = x[:, :, :H_new, :W_new]

        # Return the pre-norm token sequence (CLS token + patch tokens).
        x = self.dino.forward_features(x)
        x = x["x_prenorm"]
        return x
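
# Usage sketch ("dinov2_vits14" is one of the published hub entry points;
# downloading the weights requires network access):
#
#     model = DINOv2("dinov2_vits14")
#     tokens = model({"img": torch.randn(1, 3, 224, 224)})
#     # (1, 257, 384): CLS token + 16*16 patch tokens for ViT-S/14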


class ResNet(nn.Module):
    def __init__(self, path):
        """Initializes the ResNet model.

        Args:
            path (str): Hugging Face checkpoint to load; if empty, the model
                is randomly initialized from the default ResNetConfig.
        """
        super().__init__()
        if path == "":
            config_vision = ResNetConfig()
            self.resnet = ResNetModel(config_vision)
        else:
            self.resnet = ResNetModel.from_pretrained(path)

    def forward(self, x):
        """Predicts ResNet features from an image.

        Args:
            x (dict): Input batch; x["img"] is a torch.Tensor of pixel values.
        """
        # pooler_output has shape (batch, channels, 1, 1); squeeze only the
        # spatial dims so a batch of size 1 keeps its batch dimension.
        features = self.resnet(x["img"])["pooler_output"]
        return features.squeeze(-1).squeeze(-1)
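

# A hedged end-to-end smoke test (not part of the original module). The
# "geolocal/StreetCLIP" checkpoint and the COCO image URL are assumptions
# used only for illustration; both require network access.
if __name__ == "__main__":
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)

    model = StreetCLIP("geolocal/StreetCLIP")
    # The forward pass only uses x["gps"] to pick the target device.
    batch = {"img": image, "gps": torch.zeros(1, 2)}
    with torch.no_grad():
        features = model(batch)
    print(features.shape)  # (1, 1, projection_dim)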