import torch.hub
from torch import nn
from transformers import (
    AutoTokenizer,
    CLIPModel,
    CLIPProcessor,
    CLIPTextConfig,
    CLIPTextModelWithProjection,
    CLIPVisionConfig,
    CLIPVisionModel,
    CLIPVisionModelWithProjection,
    ResNetConfig,
    ResNetModel,
)


class CLIP(nn.Module):
    def __init__(self, path):
        """Initializes the CLIP model."""
        super().__init__()
        if path == "":
            config_vision = CLIPVisionConfig()
            self.clip = CLIPVisionModel(config_vision)
        else:
            self.clip = CLIPVisionModel.from_pretrained(path)

    def forward(self, x):
        """Predicts CLIP features from an image.

        Args:
            x (dict that contains "img": torch.Tensor): Input batch
        """
        features = self.clip(pixel_values=x["img"])["last_hidden_state"]
        return features


class CLIPJZ(nn.Module):
    def __init__(self, path):
        """Initializes the CLIP model (same implementation as CLIP above)."""
        super().__init__()
        if path == "":
            config_vision = CLIPVisionConfig()
            self.clip = CLIPVisionModel(config_vision)
        else:
            self.clip = CLIPVisionModel.from_pretrained(path)

    def forward(self, x):
        """Predicts CLIP features from an image.

        Args:
            x (dict that contains "img": torch.Tensor): Input batch
        """
        features = self.clip(pixel_values=x["img"])["last_hidden_state"]
        return features


class StreetCLIP(nn.Module):
    def __init__(self, path):
        """Initializes the StreetCLIP model."""
        super().__init__()
        self.clip = CLIPModel.from_pretrained(path)
        self.transform = CLIPProcessor.from_pretrained(path)

    def forward(self, x):
        """Predicts CLIP features from an image.

        Args:
            x (dict that contains "img" and "gps"): Input batch; "gps" is
                only read to infer the target device.
        """
        # Run the processor on the raw images, move the resulting tensors to
        # the batch device, and add a singleton token dimension.
        features = self.clip.get_image_features(
            **self.transform(images=x["img"], return_tensors="pt").to(x["gps"].device)
        ).unsqueeze(1)
        return features


class CLIPText(nn.Module):
    def __init__(self, path):
        """Initializes the CLIP model."""
        super().__init__()
        if path == "":
            config_vision = CLIPVisionConfig()
            # Use the projection variant so forward() can return image_embeds.
            self.clip = CLIPVisionModelWithProjection(config_vision)
        else:
            self.clip = CLIPVisionModelWithProjection.from_pretrained(path)

    def forward(self, x):
        """Predicts CLIP features from an image.

        Args:
            x (dict that contains "img": torch.Tensor): Input batch
        """
        features = self.clip(pixel_values=x["img"])
        return features.image_embeds, features.last_hidden_state


class TextEncoder(nn.Module):
    def __init__(self, path):
        """Initializes the CLIP text model."""
        super().__init__()
        if path == "":
            config_text = CLIPTextConfig()
            self.clip = CLIPTextModelWithProjection(config_text)
            # AutoTokenizer has no public constructor; without a checkpoint
            # path, fall back to the reference CLIP tokenizer.
            self.transform = AutoTokenizer.from_pretrained(
                "openai/clip-vit-base-patch32"
            )
        else:
            self.clip = CLIPTextModelWithProjection.from_pretrained(path)
            self.transform = AutoTokenizer.from_pretrained(path)
        # Freeze the text encoder; it is used as a fixed feature extractor.
        for p in self.clip.parameters():
            p.requires_grad = False
        self.clip.eval()

    def forward(self, x):
        """Predicts CLIP features from text.

        Args:
            x (dict that contains "text": list): Input batch; "gps" is only
                read to infer the target device.
        """
        features = self.clip(
            **self.transform(x["text"], padding=True, return_tensors="pt").to(
                x["gps"].device
            )
        ).text_embeds
        return features
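
# Usage sketch for TextEncoder (the checkpoint name is an example, not fixed
# by this file; any CLIP text checkpoint should work, and "gps" is only read
# for its device):
#   encoder = TextEncoder("openai/clip-vit-base-patch32")
#   batch = {"text": ["a street in Paris"], "gps": torch.zeros(1, 2)}
#   embeds = encoder(batch)  # shape: (batch_size, projection_dim)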


class DINOv2(nn.Module):
    def __init__(self, tag) -> None:
        """Initializes the DINOv2 model from torch.hub."""
        super().__init__()
        self.dino = torch.hub.load("facebookresearch/dinov2", tag)
        self.stride = 14  # DINOv2 patch size: input sides must be multiples of 14

    def forward(self, x):
        """Predicts DINOv2 features from an image."""
        x = x["img"]
        # Crop height and width down to the nearest multiple of the stride.
        _, _, H, W = x.shape
        H_new = H - H % self.stride
        W_new = W - W % self.stride
        x = x[:, :, :H_new, :W_new]
        # Forward pass, keeping the pre-norm token sequence.
        x = self.dino.forward_features(x)
        x = x["x_prenorm"]
        return x
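
# Shape sketch for the crop above: with stride 14, a (B, C, 250, 300) batch is
# trimmed to (B, C, 238, 294), since 250 - 250 % 14 = 238 and
# 300 - 300 % 14 = 294.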


class ResNet(nn.Module):
    def __init__(self, path):
        """Initializes the ResNet model."""
        super().__init__()
        if path == "":
            config_vision = ResNetConfig()
            self.resnet = ResNetModel(config_vision)
        else:
            self.resnet = ResNetModel.from_pretrained(path)

    def forward(self, x):
        """Predicts ResNet features from an image.

        Args:
            x (dict that contains "img": torch.Tensor): Input batch
        """
        features = self.resnet(x["img"])["pooler_output"]
        # pooler_output is (B, C, 1, 1); squeeze only the spatial dims so a
        # batch of size 1 keeps its batch dimension.
        return features.squeeze(-1).squeeze(-1)
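

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original API): an empty
    # path builds randomly initialized models, so nothing is downloaded.
    # Classes that need pretrained weights (StreetCLIP, DINOv2, TextEncoder)
    # are skipped here.
    dummy = {"img": torch.randn(2, 3, 224, 224), "gps": torch.zeros(2, 2)}
    print(CLIP("")(dummy).shape)    # (2, 50, 768) with the default ViT-B/32 config
    print(ResNet("")(dummy).shape)  # (2, 2048) with the default ResNet config
    embeds, hidden = CLIPText("")(dummy)
    print(embeds.shape, hidden.shape)  # (2, 512) and (2, 50, 768)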