Upload 3 files
added inferencing files
- clip_inferencing.py +65 -0
- clip_model.py +53 -0
- configuration.py +33 -0
clip_inferencing.py
ADDED
@@ -0,0 +1,65 @@
import torch
import torch.nn.functional as F
from transformers import DistilBertTokenizer
from tqdm.autonotebook import tqdm
import pickle

from clip_model import CLIPModel
from configuration import CFG

import matplotlib.pyplot as plt
import cv2

def load_model(model_path):
    # Load the trained weights onto the configured device and switch to eval mode
    model = CLIPModel().to(CFG.device)
    model.load_state_dict(torch.load(model_path, map_location=CFG.device))
    model.eval()
    return model

def load_df():
    with open("pickles/valid_df.pkl", 'rb') as file:
        valid_df = pickle.load(file)
    return valid_df

def load_image_embeddings():
    with open("pickles/image_embeddings.pkl", 'rb') as file:
        image_embeddings = pickle.load(file)
    return image_embeddings

def find_matches(model, image_embeddings, query, image_filenames, n=9):
    # Encode the text query with the same tokenizer used during training
    tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
    encoded_query = tokenizer([query])
    batch = {
        key: torch.tensor(values).to(CFG.device)
        for key, values in encoded_query.items()
    }
    with torch.no_grad():
        text_features = model.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        text_embeddings = model.text_projection(text_features)

    # Cosine similarity between the query and every precomputed image embedding
    image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
    text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
    dot_similarity = text_embeddings_n @ image_embeddings_n.T

    # Take the top n*5 candidates and keep every 5th so near-duplicate rows are skipped
    values, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
    matches = [image_filenames[idx] for idx in indices[::5]]

    _, axes = plt.subplots(3, 3, figsize=(10, 10))
    for match, ax in zip(matches, axes.flatten()):
        image = cv2.imread(f"Images/{match}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        ax.imshow(image)
        ax.axis("off")

    plt.show()

def inference():
    valid_df = load_df()
    image_embeddings = load_image_embeddings()
    find_matches(load_model(model_path="model/best.pt"),
                 image_embeddings,
                 query="dogs on the grass",
                 image_filenames=valid_df['image'].values, n=9)
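The module defines inference() but no entry point. A minimal way to run it directly might look like the sketch below; this is an illustrative addition rather than part of the uploaded file, and it assumes the pickles/, model/ and Images/ paths referenced above already exist.

# Hypothetical entry point, shown only as a usage sketch
if __name__ == "__main__":
    inference()  # expects pickles/valid_df.pkl, pickles/image_embeddings.pkl, model/best.pt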
clip_model.py
ADDED
@@ -0,0 +1,53 @@
from torch import nn
import torch.nn.functional as F

from image_encoder import ImageEncoder
from text_encoder import TextEncoder
from projection_head import ProjectionHead
from configuration import CFG


class CLIPModel(nn.Module):
    def __init__(
        self,
        temperature=CFG.temperature,
        image_embedding=CFG.image_embedding,
        text_embedding=CFG.text_embedding,
    ):
        super().__init__()
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        self.temperature = temperature

    def forward(self, batch):
        # Getting Image and Text Features
        image_features = self.image_encoder(batch["image"])
        text_features = self.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        # Getting Image and Text Embeddings (with same dimension)
        image_embeddings = self.image_projection(image_features)
        text_embeddings = self.text_projection(text_features)

        # Calculating the Loss
        logits = (text_embeddings @ image_embeddings.T) / self.temperature
        images_similarity = image_embeddings @ image_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T
        targets = F.softmax(
            (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
        )
        texts_loss = cross_entropy(logits, targets, reduction='none')
        images_loss = cross_entropy(logits.T, targets.T, reduction='none')
        loss = (images_loss + texts_loss) / 2.0  # shape: (batch_size)
        return loss.mean()


def cross_entropy(preds, targets, reduction='none'):
    log_softmax = nn.LogSoftmax(dim=-1)
    loss = (-targets * log_softmax(preds)).sum(1)
    if reduction == "none":
        return loss
    elif reduction == "mean":
        return loss.mean()
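To make the symmetric contrastive loss above easier to follow in isolation, here is a small self-contained sketch that reproduces the same computation on random embeddings, without ImageEncoder, TextEncoder, or ProjectionHead (which live in files not included in this upload). Tensor names mirror forward() above; the random inputs are placeholders, not real model outputs.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, dim, temperature = 4, 256, 1.0

# Stand-ins for the projected embeddings produced inside CLIPModel.forward
image_embeddings = torch.randn(batch_size, dim)
text_embeddings = torch.randn(batch_size, dim)

# Same steps as forward(): soft targets from the two self-similarity matrices,
# then a cross-entropy in both the text->image and image->text directions
logits = (text_embeddings @ image_embeddings.T) / temperature
images_similarity = image_embeddings @ image_embeddings.T
texts_similarity = text_embeddings @ text_embeddings.T
targets = F.softmax((images_similarity + texts_similarity) / 2 * temperature, dim=-1)

texts_loss = (-targets * F.log_softmax(logits, dim=-1)).sum(1)
images_loss = (-targets.T * F.log_softmax(logits.T, dim=-1)).sum(1)
loss = ((texts_loss + images_loss) / 2.0).mean()
print(loss)  # scalar loss for this dummy batch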
configuration.py
ADDED
@@ -0,0 +1,33 @@
import torch

class CFG:
    debug = False
    batch_size = 32
    num_workers = 2
    head_lr = 1e-3
    image_encoder_lr = 1e-4
    text_encoder_lr = 1e-5
    weight_decay = 1e-3
    patience = 1
    factor = 0.8
    epochs = 1  # 4
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_name = 'resnet50'
    image_embedding = 2048
    text_encoder_model = "distilbert-base-uncased"
    text_embedding = 768
    text_tokenizer = "distilbert-base-uncased"
    max_length = 200

    pretrained = True  # for both image encoder and text encoder
    trainable = True   # for both image encoder and text encoder
    temperature = 1.0

    # image size
    size = 224

    # for projection head; used for both image and text encoders
    num_projection_layers = 1
    projection_dim = 256
    dropout = 0.1
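As a quick sanity check of these settings, a sketch like the following could confirm that the tokenizer name, max_length, and device resolve correctly before running inference; it is illustrative only and not part of the upload.

# Illustrative check: CFG.text_tokenizer loads and CFG.device resolves
import torch
from transformers import DistilBertTokenizer
from configuration import CFG

tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
tokens = tokenizer(["dogs on the grass"], padding=True, truncation=True,
                   max_length=CFG.max_length, return_tensors="pt")
print(tokens["input_ids"].shape, CFG.device)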